|
@ -30,6 +30,7 @@ import threading |
|
|
import eventloop |
|
|
import eventloop |
|
|
from common import parse_header |
|
|
from common import parse_header |
|
|
|
|
|
|
|
|
|
|
|
CMD_CONNECT = 1 |
|
|
|
|
|
|
|
|
# local: |
|
|
# local: |
|
|
# stage 0 init |
|
|
# stage 0 init |
|
@ -42,6 +43,18 @@ from common import parse_header |
|
|
# stage 4 addr received, reply sent |
|
|
# stage 4 addr received, reply sent |
|
|
# stage 5 remote connected |
|
|
# stage 5 remote connected |
|
|
|
|
|
|
|
|
|
|
|
STAGE_INIT = 0 |
|
|
|
|
|
STAGE_HELLO = 1 |
|
|
|
|
|
STAGE_REPLY = 4 |
|
|
|
|
|
STAGE_STREAM = 5 |
|
|
|
|
|
|
|
|
|
|
|
# stream direction |
|
|
|
|
|
STREAM_UP = 0 |
|
|
|
|
|
STREAM_DOWN = 1 |
|
|
|
|
|
|
|
|
|
|
|
# stream status |
|
|
|
|
|
STATUS_WAIT_READING = 0 |
|
|
|
|
|
STATUS_WAIT_WRITING = 1 |
|
|
|
|
|
|
|
|
BUF_SIZE = 8 * 1024 |
|
|
BUF_SIZE = 8 * 1024 |
|
|
|
|
|
|
|
@ -54,34 +67,69 @@ class TCPRelayHandler(object): |
|
|
self._remote_sock = None |
|
|
self._remote_sock = None |
|
|
self._config = config |
|
|
self._config = config |
|
|
self._is_local = is_local |
|
|
self._is_local = is_local |
|
|
self._stage = 0 |
|
|
self._stage = STAGE_INIT |
|
|
self._encryptor = encrypt.Encryptor(config['password'], |
|
|
self._encryptor = encrypt.Encryptor(config['password'], |
|
|
config['method']) |
|
|
config['method']) |
|
|
self._data_to_write_to_local = [] |
|
|
self._data_to_write_to_local = [] |
|
|
self._data_to_write_to_remote = [] |
|
|
self._data_to_write_to_remote = [] |
|
|
|
|
|
self._upstream_status = STATUS_WAIT_READING |
|
|
|
|
|
self._downstream_status = STATUS_WAIT_READING |
|
|
fd_to_handlers[local_sock.fileno()] = self |
|
|
fd_to_handlers[local_sock.fileno()] = self |
|
|
local_sock.setblocking(False) |
|
|
local_sock.setblocking(False) |
|
|
loop.add(local_sock, eventloop.POLL_IN) |
|
|
loop.add(local_sock, eventloop.POLL_IN) |
|
|
|
|
|
|
|
|
def resume_reading(self, sock): |
|
|
def update_stream(self, stream, status): |
|
|
pass |
|
|
dirty = False |
|
|
|
|
|
if stream == STREAM_DOWN: |
|
|
def pause_reading(self, sock): |
|
|
if self._downstream_status != status: |
|
|
pass |
|
|
self._downstream_status = status |
|
|
|
|
|
dirty = True |
|
|
def resume_writing(self, sock): |
|
|
elif stream == STREAM_UP: |
|
|
pass |
|
|
if self._upstream_status != status: |
|
|
|
|
|
self._upstream_status = status |
|
|
def pause_writing(self, sock): |
|
|
dirty = True |
|
|
pass |
|
|
if dirty: |
|
|
|
|
|
if self._local_sock: |
|
|
|
|
|
event = eventloop.POLL_ERR |
|
|
|
|
|
if self._downstream_status == STATUS_WAIT_WRITING: |
|
|
|
|
|
event |= eventloop.POLL_OUT |
|
|
|
|
|
if self._upstream_status == STATUS_WAIT_READING: |
|
|
|
|
|
event |= eventloop.POLL_IN |
|
|
|
|
|
self._loop.modify(self._local_sock, event) |
|
|
|
|
|
if self._remote_sock: |
|
|
|
|
|
event = eventloop.POLL_ERR |
|
|
|
|
|
if self._downstream_status == STATUS_WAIT_READING: |
|
|
|
|
|
event |= eventloop.POLL_IN |
|
|
|
|
|
if self._upstream_status == STATUS_WAIT_WRITING: |
|
|
|
|
|
event |= eventloop.POLL_OUT |
|
|
|
|
|
self._loop.modify(self._remote_sock, event) |
|
|
|
|
|
|
|
|
def write_all_to_sock(self, data, sock): |
|
|
def write_all_to_sock(self, data, sock): |
|
|
# write to sock |
|
|
if not data or not sock: |
|
|
# put remaining bytes into buffer |
|
|
return |
|
|
# return true if all written |
|
|
uncomplete = False |
|
|
# return false if some bytes left in buffer |
|
|
try: |
|
|
# raise if encounter error |
|
|
l = len(data) |
|
|
return True |
|
|
s = sock.send(data) |
|
|
|
|
|
if s < l: |
|
|
|
|
|
data = data[s:] |
|
|
|
|
|
uncomplete = True |
|
|
|
|
|
except (OSError, IOError) as e: |
|
|
|
|
|
error_no = eventloop.errno_from_exception(e) |
|
|
|
|
|
if error_no in (errno.EAGAIN, errno.EINPROGRESS): |
|
|
|
|
|
uncomplete = True |
|
|
|
|
|
else: |
|
|
|
|
|
logging.error(e) |
|
|
|
|
|
self.destroy() |
|
|
|
|
|
if uncomplete: |
|
|
|
|
|
if sock == self._local_sock: |
|
|
|
|
|
self._data_to_write_to_local.append(data) |
|
|
|
|
|
self.update_stream(STREAM_DOWN, STATUS_WAIT_WRITING) |
|
|
|
|
|
elif sock == self._remote_sock: |
|
|
|
|
|
self._data_to_write_to_remote.append(data) |
|
|
|
|
|
self.update_stream(STREAM_UP, STATUS_WAIT_WRITING) |
|
|
|
|
|
else: |
|
|
|
|
|
logging.error('write_all_to_sock:unknown socket') |
|
|
|
|
|
|
|
|
def on_local_read(self): |
|
|
def on_local_read(self): |
|
|
if not self._local_sock: |
|
|
if not self._local_sock: |
|
@ -90,26 +138,25 @@ class TCPRelayHandler(object): |
|
|
data = self._local_sock.recv(BUF_SIZE) |
|
|
data = self._local_sock.recv(BUF_SIZE) |
|
|
if not is_local: |
|
|
if not is_local: |
|
|
data = self._encryptor.decrypt(data) |
|
|
data = self._encryptor.decrypt(data) |
|
|
if self._stage == 5: |
|
|
if self._stage == STAGE_STREAM: |
|
|
if self._is_local: |
|
|
if self._is_local: |
|
|
data = self._encryptor.encrypt(data) |
|
|
data = self._encryptor.encrypt(data) |
|
|
if not self.write_all_to_sock(data, self._remote_sock): |
|
|
self.write_all_to_sock(data, self._remote_sock) |
|
|
self.pause_reading(self._local_sock) |
|
|
|
|
|
return |
|
|
return |
|
|
if is_local and self._stage == 0: |
|
|
if is_local and self._stage == STAGE_INIT: |
|
|
# TODO check auth method |
|
|
# TODO check auth method |
|
|
self.write_all_to_sock('\x05\00', self._local_sock) |
|
|
self.write_all_to_sock('\x05\00', self._local_sock) |
|
|
self._stage = 1 |
|
|
self._stage = STAGE_HELLO |
|
|
return |
|
|
return |
|
|
if self._stage == 4: |
|
|
if self._stage == STAGE_REPLY: |
|
|
self._data_to_write_to_remote.append(data) |
|
|
self._data_to_write_to_remote.append(data) |
|
|
if (is_local and self._stage == 0) or \ |
|
|
if (is_local and self._stage == STAGE_HELLO) or \ |
|
|
(not is_local and self._stage == 1): |
|
|
(not is_local and self._stage == STAGE_INIT): |
|
|
try: |
|
|
try: |
|
|
if is_local: |
|
|
if is_local: |
|
|
cmd = ord(data[1]) |
|
|
cmd = ord(data[1]) |
|
|
# TODO check cmd == 1 |
|
|
# TODO check cmd == 1 |
|
|
assert cmd == 1 |
|
|
assert cmd == CMD_CONNECT |
|
|
# just trim VER CMD RSV |
|
|
# just trim VER CMD RSV |
|
|
data = data[3:] |
|
|
data = data[3:] |
|
|
header_result = parse_header(data) |
|
|
header_result = parse_header(data) |
|
@ -145,7 +192,7 @@ class TCPRelayHandler(object): |
|
|
self._data_to_write_to_remote.append(data[header_length:]) |
|
|
self._data_to_write_to_remote.append(data[header_length:]) |
|
|
|
|
|
|
|
|
self._stage = 4 |
|
|
self._stage = 4 |
|
|
self.pause_reading(self._local_sock) |
|
|
self.update_stream(STREAM_UP, STATUS_WAIT_WRITING) |
|
|
return |
|
|
return |
|
|
except Exception: |
|
|
except Exception: |
|
|
import traceback |
|
|
import traceback |
|
@ -153,7 +200,7 @@ class TCPRelayHandler(object): |
|
|
# TODO use logging when debug completed |
|
|
# TODO use logging when debug completed |
|
|
self.destroy() |
|
|
self.destroy() |
|
|
|
|
|
|
|
|
if self._stage == 4: |
|
|
elif self._stage == STAGE_REPLY: |
|
|
self._data_to_write_to_remote.append(data) |
|
|
self._data_to_write_to_remote.append(data) |
|
|
|
|
|
|
|
|
def on_remote_read(self): |
|
|
def on_remote_read(self): |
|
@ -161,9 +208,7 @@ class TCPRelayHandler(object): |
|
|
if self._is_local: |
|
|
if self._is_local: |
|
|
data = self._encryptor.decrypt(data) |
|
|
data = self._encryptor.decrypt(data) |
|
|
try: |
|
|
try: |
|
|
if not self.write_all_to_sock(data, self._local_sock): |
|
|
self.write_all_to_sock(data, self._local_sock) |
|
|
self.pause_reading(self._remote_sock) |
|
|
|
|
|
self.resume_writing(self._local_sock) |
|
|
|
|
|
except Exception: |
|
|
except Exception: |
|
|
import traceback |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
traceback.print_exc() |
|
@ -172,26 +217,26 @@ class TCPRelayHandler(object): |
|
|
|
|
|
|
|
|
def on_local_write(self): |
|
|
def on_local_write(self): |
|
|
if self._data_to_write_to_local: |
|
|
if self._data_to_write_to_local: |
|
|
written = self.write_all_to_sock( |
|
|
data = ''.join(self._data_to_write_to_local) |
|
|
''.join(self._data_to_write_to_local), self._local_sock) |
|
|
self._data_to_write_to_local = [] |
|
|
if written: |
|
|
self.write_all_to_sock(data, self._local_sock) |
|
|
self.pause_writing(self._local_sock) |
|
|
|
|
|
else: |
|
|
else: |
|
|
self.pause_writing(self._local_sock) |
|
|
self.update_stream(STREAM_DOWN, STATUS_WAIT_READING) |
|
|
|
|
|
|
|
|
def on_remote_write(self): |
|
|
def on_remote_write(self): |
|
|
if self._data_to_write_to_remote: |
|
|
if self._data_to_write_to_remote: |
|
|
written = self.write_all_to_sock( |
|
|
data = ''.join(self._data_to_write_to_remote) |
|
|
''.join(self._data_to_write_to_remote), self._remote_sock) |
|
|
self._data_to_write_to_remote = [] |
|
|
if written: |
|
|
self.write_all_to_sock(data, self._remote_sock) |
|
|
self.pause_writing(self._remote_sock) |
|
|
|
|
|
else: |
|
|
else: |
|
|
self.pause_writing(self._remote_sock) |
|
|
self.update_stream(STREAM_DOWN, STATUS_WAIT_READING) |
|
|
|
|
|
|
|
|
def on_local_error(self): |
|
|
def on_local_error(self): |
|
|
|
|
|
logging.error(eventloop.get_sock_error(self._local_sock)) |
|
|
self.destroy() |
|
|
self.destroy() |
|
|
|
|
|
|
|
|
def on_remote_error(self): |
|
|
def on_remote_error(self): |
|
|
|
|
|
logging.error(eventloop.get_sock_error(self._remote_sock)) |
|
|
self.destroy() |
|
|
self.destroy() |
|
|
|
|
|
|
|
|
def handle_event(self, sock, event): |
|
|
def handle_event(self, sock, event): |
|
@ -276,7 +321,7 @@ class TCPRelay(object): |
|
|
conn, self._config, self._is_local) |
|
|
conn, self._config, self._is_local) |
|
|
except (OSError, IOError) as e: |
|
|
except (OSError, IOError) as e: |
|
|
error_no = eventloop.errno_from_exception(e) |
|
|
error_no = eventloop.errno_from_exception(e) |
|
|
if error_no in [errno.EAGAIN, errno.EINPROGRESS]: |
|
|
if error_no in (errno.EAGAIN, errno.EINPROGRESS): |
|
|
continue |
|
|
continue |
|
|
else: |
|
|
else: |
|
|
handler = self._fd_to_handlers.get(sock.fileno(), None) |
|
|
handler = self._fd_to_handlers.get(sock.fileno(), None) |
|
@ -291,6 +336,7 @@ class TCPRelay(object): |
|
|
last_time = now |
|
|
last_time = now |
|
|
|
|
|
|
|
|
def start(self): |
|
|
def start(self): |
|
|
|
|
|
# TODO combine loops on multiple ports into one single loop |
|
|
if self._closed: |
|
|
if self._closed: |
|
|
raise Exception('closed') |
|
|
raise Exception('closed') |
|
|
t = threading.Thread(target=self._run) |
|
|
t = threading.Thread(target=self._run) |
|
|