|
@ -26,6 +26,7 @@ import socket |
|
|
import errno |
|
|
import errno |
|
|
import struct |
|
|
import struct |
|
|
import logging |
|
|
import logging |
|
|
|
|
|
import traceback |
|
|
import encrypt |
|
|
import encrypt |
|
|
import eventloop |
|
|
import eventloop |
|
|
from common import parse_header |
|
|
from common import parse_header |
|
@ -57,6 +58,7 @@ STAGE_HELLO = 1 |
|
|
STAGE_UDP_ASSOC = 2 |
|
|
STAGE_UDP_ASSOC = 2 |
|
|
STAGE_REPLY = 4 |
|
|
STAGE_REPLY = 4 |
|
|
STAGE_STREAM = 5 |
|
|
STAGE_STREAM = 5 |
|
|
|
|
|
STAGE_DESTROYED = -1 |
|
|
|
|
|
|
|
|
# stream direction |
|
|
# stream direction |
|
|
STREAM_UP = 0 |
|
|
STREAM_UP = 0 |
|
@ -137,7 +139,7 @@ class TCPRelayHandler(object): |
|
|
|
|
|
|
|
|
def _write_to_sock(self, data, sock): |
|
|
def _write_to_sock(self, data, sock): |
|
|
if not data or not sock: |
|
|
if not data or not sock: |
|
|
return |
|
|
return False |
|
|
uncomplete = False |
|
|
uncomplete = False |
|
|
try: |
|
|
try: |
|
|
l = len(data) |
|
|
l = len(data) |
|
@ -151,7 +153,9 @@ class TCPRelayHandler(object): |
|
|
uncomplete = True |
|
|
uncomplete = True |
|
|
else: |
|
|
else: |
|
|
logging.error(e) |
|
|
logging.error(e) |
|
|
|
|
|
traceback.print_exc() |
|
|
self.destroy() |
|
|
self.destroy() |
|
|
|
|
|
return False |
|
|
if uncomplete: |
|
|
if uncomplete: |
|
|
if sock == self._local_sock: |
|
|
if sock == self._local_sock: |
|
|
self._data_to_write_to_local.append(data) |
|
|
self._data_to_write_to_local.append(data) |
|
@ -168,6 +172,7 @@ class TCPRelayHandler(object): |
|
|
self._update_stream(STREAM_UP, WAIT_STATUS_READING) |
|
|
self._update_stream(STREAM_UP, WAIT_STATUS_READING) |
|
|
else: |
|
|
else: |
|
|
logging.error('write_all_to_sock:unknown socket') |
|
|
logging.error('write_all_to_sock:unknown socket') |
|
|
|
|
|
return True |
|
|
|
|
|
|
|
|
def _handle_stage_reply(self, data): |
|
|
def _handle_stage_reply(self, data): |
|
|
if self._is_local: |
|
|
if self._is_local: |
|
@ -199,6 +204,7 @@ class TCPRelayHandler(object): |
|
|
self.destroy() |
|
|
self.destroy() |
|
|
else: |
|
|
else: |
|
|
logging.error(e) |
|
|
logging.error(e) |
|
|
|
|
|
traceback.print_exc() |
|
|
self.destroy() |
|
|
self.destroy() |
|
|
|
|
|
|
|
|
def _handle_stage_hello(self, data): |
|
|
def _handle_stage_hello(self, data): |
|
@ -274,8 +280,8 @@ class TCPRelayHandler(object): |
|
|
self._stage = STAGE_REPLY |
|
|
self._stage = STAGE_REPLY |
|
|
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) |
|
|
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) |
|
|
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) |
|
|
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) |
|
|
except Exception: |
|
|
except Exception as e: |
|
|
import traceback |
|
|
logging.error(e) |
|
|
traceback.print_exc() |
|
|
traceback.print_exc() |
|
|
# TODO use logging when debug completed |
|
|
# TODO use logging when debug completed |
|
|
self.destroy() |
|
|
self.destroy() |
|
@ -333,8 +339,8 @@ class TCPRelayHandler(object): |
|
|
data = self._encryptor.encrypt(data) |
|
|
data = self._encryptor.encrypt(data) |
|
|
try: |
|
|
try: |
|
|
self._write_to_sock(data, self._local_sock) |
|
|
self._write_to_sock(data, self._local_sock) |
|
|
except Exception: |
|
|
except Exception as e: |
|
|
import traceback |
|
|
logging.error(e) |
|
|
traceback.print_exc() |
|
|
traceback.print_exc() |
|
|
# TODO use logging when debug completed |
|
|
# TODO use logging when debug completed |
|
|
self.destroy() |
|
|
self.destroy() |
|
@ -358,34 +364,49 @@ class TCPRelayHandler(object): |
|
|
|
|
|
|
|
|
def _on_local_error(self): |
|
|
def _on_local_error(self): |
|
|
if self._local_sock: |
|
|
if self._local_sock: |
|
|
|
|
|
logging.debug('got local error') |
|
|
logging.error(eventloop.get_sock_error(self._local_sock)) |
|
|
logging.error(eventloop.get_sock_error(self._local_sock)) |
|
|
self.destroy() |
|
|
self.destroy() |
|
|
|
|
|
|
|
|
def _on_remote_error(self): |
|
|
def _on_remote_error(self): |
|
|
if self._remote_sock: |
|
|
if self._remote_sock: |
|
|
|
|
|
logging.debug('got remote error') |
|
|
logging.error(eventloop.get_sock_error(self._remote_sock)) |
|
|
logging.error(eventloop.get_sock_error(self._remote_sock)) |
|
|
self.destroy() |
|
|
self.destroy() |
|
|
|
|
|
|
|
|
def handle_event(self, sock, event): |
|
|
def handle_event(self, sock, event): |
|
|
|
|
|
if self._stage == STAGE_DESTROYED: |
|
|
|
|
|
return |
|
|
# order is important |
|
|
# order is important |
|
|
if sock == self._remote_sock: |
|
|
if sock == self._remote_sock: |
|
|
|
|
|
if event & eventloop.POLL_ERR: |
|
|
|
|
|
self._on_remote_error() |
|
|
|
|
|
if self._stage == STAGE_DESTROYED: |
|
|
|
|
|
return |
|
|
if event & eventloop.POLL_IN: |
|
|
if event & eventloop.POLL_IN: |
|
|
self._on_remote_read() |
|
|
self._on_remote_read() |
|
|
|
|
|
if self._stage == STAGE_DESTROYED: |
|
|
|
|
|
return |
|
|
if event & eventloop.POLL_OUT: |
|
|
if event & eventloop.POLL_OUT: |
|
|
self._on_remote_write() |
|
|
self._on_remote_write() |
|
|
if event & eventloop.POLL_ERR: |
|
|
|
|
|
self._on_remote_error() |
|
|
|
|
|
elif sock == self._local_sock: |
|
|
elif sock == self._local_sock: |
|
|
|
|
|
if event & eventloop.POLL_ERR: |
|
|
|
|
|
self._on_local_error() |
|
|
|
|
|
if self._stage == STAGE_DESTROYED: |
|
|
|
|
|
return |
|
|
if event & eventloop.POLL_IN: |
|
|
if event & eventloop.POLL_IN: |
|
|
self._on_local_read() |
|
|
self._on_local_read() |
|
|
|
|
|
if self._stage == STAGE_DESTROYED: |
|
|
|
|
|
return |
|
|
if event & eventloop.POLL_OUT: |
|
|
if event & eventloop.POLL_OUT: |
|
|
self._on_local_write() |
|
|
self._on_local_write() |
|
|
if event & eventloop.POLL_ERR: |
|
|
|
|
|
self._on_local_error() |
|
|
|
|
|
else: |
|
|
else: |
|
|
logging.warn('unknown socket') |
|
|
logging.warn('unknown socket') |
|
|
|
|
|
|
|
|
def destroy(self): |
|
|
def destroy(self): |
|
|
|
|
|
if self._stage == STAGE_DESTROYED: |
|
|
|
|
|
return |
|
|
|
|
|
self._stage = STAGE_DESTROYED |
|
|
if self._remote_address: |
|
|
if self._remote_address: |
|
|
logging.debug('destroy: %s:%d' % |
|
|
logging.debug('destroy: %s:%d' % |
|
|
self._remote_address) |
|
|
self._remote_address) |
|
@ -510,9 +531,9 @@ class TCPRelay(object): |
|
|
|
|
|
|
|
|
def _handle_events(self, events): |
|
|
def _handle_events(self, events): |
|
|
for sock, fd, event in events: |
|
|
for sock, fd, event in events: |
|
|
# if sock: |
|
|
if sock: |
|
|
# logging.debug('fd %d %s', fd, |
|
|
logging.debug('fd %d %s', fd, |
|
|
# eventloop.EVENT_NAMES.get(event, event)) |
|
|
eventloop.EVENT_NAMES.get(event, event)) |
|
|
if sock == self._server_socket: |
|
|
if sock == self._server_socket: |
|
|
if event & eventloop.POLL_ERR: |
|
|
if event & eventloop.POLL_ERR: |
|
|
# TODO |
|
|
# TODO |
|
@ -528,6 +549,7 @@ class TCPRelay(object): |
|
|
continue |
|
|
continue |
|
|
else: |
|
|
else: |
|
|
logging.error(e) |
|
|
logging.error(e) |
|
|
|
|
|
traceback.print_exc() |
|
|
else: |
|
|
else: |
|
|
if sock: |
|
|
if sock: |
|
|
handler = self._fd_to_handlers.get(fd, None) |
|
|
handler = self._fd_to_handlers.get(fd, None) |
|
|