diff --git a/shadowsocks/common.py b/shadowsocks/common.py index ca5fb21..4f7aac6 100644 --- a/shadowsocks/common.py +++ b/shadowsocks/common.py @@ -70,19 +70,24 @@ if not hasattr(socket, 'inet_ntop'): socket.inet_ntop = inet_ntop +ADDRTYPE_IPV4 = 1 +ADDRTYPE_IPV6 = 4 +ADDRTYPE_HOST = 3 + + def parse_header(data): addrtype = ord(data[0]) dest_addr = None dest_port = None header_length = 0 - if addrtype == 1: + if addrtype == ADDRTYPE_IPV4: if len(data) >= 7: dest_addr = socket.inet_ntoa(data[1:5]) dest_port = struct.unpack('>H', data[5:7])[0] header_length = 7 else: logging.warn('header is too short') - elif addrtype == 3: + elif addrtype == ADDRTYPE_HOST: if len(data) > 2: addrlen = ord(data[1]) if len(data) >= 2 + addrlen: @@ -94,7 +99,7 @@ def parse_header(data): logging.warn('header is too short') else: logging.warn('header is too short') - elif addrtype == 4: + elif addrtype == ADDRTYPE_IPV6: if len(data) >= 19: dest_addr = socket.inet_ntop(socket.AF_INET6, data[1:17]) dest_port = struct.unpack('>H', data[17:19])[0] diff --git a/shadowsocks/eventloop.py b/shadowsocks/eventloop.py index 1c48718..b432252 100644 --- a/shadowsocks/eventloop.py +++ b/shadowsocks/eventloop.py @@ -26,6 +26,7 @@ import os +import socket import select from collections import defaultdict @@ -192,7 +193,5 @@ def errno_from_exception(e): # from tornado def get_sock_error(sock): - errno = sock.getsockopt(socket.SOL_SOCKET, - socket.SO_ERROR) + errno = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) return socket.error(errno, os.strerror(errno)) - diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 039eed9..88abf30 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -30,6 +30,7 @@ import threading import eventloop from common import parse_header +CMD_CONNECT = 1 # local: # stage 0 init @@ -42,6 +43,18 @@ from common import parse_header # stage 4 addr received, reply sent # stage 5 remote connected +STAGE_INIT = 0 +STAGE_HELLO = 1 +STAGE_REPLY = 4 +STAGE_STREAM = 5 + +# stream direction +STREAM_UP = 0 +STREAM_DOWN = 1 + +# stream status +STATUS_WAIT_READING = 0 +STATUS_WAIT_WRITING = 1 BUF_SIZE = 8 * 1024 @@ -54,34 +67,69 @@ class TCPRelayHandler(object): self._remote_sock = None self._config = config self._is_local = is_local - self._stage = 0 + self._stage = STAGE_INIT self._encryptor = encrypt.Encryptor(config['password'], config['method']) self._data_to_write_to_local = [] self._data_to_write_to_remote = [] + self._upstream_status = STATUS_WAIT_READING + self._downstream_status = STATUS_WAIT_READING fd_to_handlers[local_sock.fileno()] = self local_sock.setblocking(False) loop.add(local_sock, eventloop.POLL_IN) - def resume_reading(self, sock): - pass - - def pause_reading(self, sock): - pass - - def resume_writing(self, sock): - pass - - def pause_writing(self, sock): - pass + def update_stream(self, stream, status): + dirty = False + if stream == STREAM_DOWN: + if self._downstream_status != status: + self._downstream_status = status + dirty = True + elif stream == STREAM_UP: + if self._upstream_status != status: + self._upstream_status = status + dirty = True + if dirty: + if self._local_sock: + event = eventloop.POLL_ERR + if self._downstream_status == STATUS_WAIT_WRITING: + event |= eventloop.POLL_OUT + if self._upstream_status == STATUS_WAIT_READING: + event |= eventloop.POLL_IN + self._loop.modify(self._local_sock, event) + if self._remote_sock: + event = eventloop.POLL_ERR + if self._downstream_status == STATUS_WAIT_READING: + event |= eventloop.POLL_IN + if self._upstream_status == STATUS_WAIT_WRITING: + event |= eventloop.POLL_OUT + self._loop.modify(self._remote_sock, event) def write_all_to_sock(self, data, sock): - # write to sock - # put remaining bytes into buffer - # return true if all written - # return false if some bytes left in buffer - # raise if encounter error - return True + if not data or not sock: + return + uncomplete = False + try: + l = len(data) + s = sock.send(data) + if s < l: + data = data[s:] + uncomplete = True + except (OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS): + uncomplete = True + else: + logging.error(e) + self.destroy() + if uncomplete: + if sock == self._local_sock: + self._data_to_write_to_local.append(data) + self.update_stream(STREAM_DOWN, STATUS_WAIT_WRITING) + elif sock == self._remote_sock: + self._data_to_write_to_remote.append(data) + self.update_stream(STREAM_UP, STATUS_WAIT_WRITING) + else: + logging.error('write_all_to_sock:unknown socket') def on_local_read(self): if not self._local_sock: @@ -90,26 +138,25 @@ class TCPRelayHandler(object): data = self._local_sock.recv(BUF_SIZE) if not is_local: data = self._encryptor.decrypt(data) - if self._stage == 5: + if self._stage == STAGE_STREAM: if self._is_local: data = self._encryptor.encrypt(data) - if not self.write_all_to_sock(data, self._remote_sock): - self.pause_reading(self._local_sock) + self.write_all_to_sock(data, self._remote_sock) return - if is_local and self._stage == 0: + if is_local and self._stage == STAGE_INIT: # TODO check auth method self.write_all_to_sock('\x05\00', self._local_sock) - self._stage = 1 + self._stage = STAGE_HELLO return - if self._stage == 4: + if self._stage == STAGE_REPLY: self._data_to_write_to_remote.append(data) - if (is_local and self._stage == 0) or \ - (not is_local and self._stage == 1): + if (is_local and self._stage == STAGE_HELLO) or \ + (not is_local and self._stage == STAGE_INIT): try: if is_local: cmd = ord(data[1]) # TODO check cmd == 1 - assert cmd == 1 + assert cmd == CMD_CONNECT # just trim VER CMD RSV data = data[3:] header_result = parse_header(data) @@ -145,7 +192,7 @@ class TCPRelayHandler(object): self._data_to_write_to_remote.append(data[header_length:]) self._stage = 4 - self.pause_reading(self._local_sock) + self.update_stream(STREAM_UP, STATUS_WAIT_WRITING) return except Exception: import traceback @@ -153,7 +200,7 @@ class TCPRelayHandler(object): # TODO use logging when debug completed self.destroy() - if self._stage == 4: + elif self._stage == STAGE_REPLY: self._data_to_write_to_remote.append(data) def on_remote_read(self): @@ -161,9 +208,7 @@ class TCPRelayHandler(object): if self._is_local: data = self._encryptor.decrypt(data) try: - if not self.write_all_to_sock(data, self._local_sock): - self.pause_reading(self._remote_sock) - self.resume_writing(self._local_sock) + self.write_all_to_sock(data, self._local_sock) except Exception: import traceback traceback.print_exc() @@ -172,26 +217,26 @@ class TCPRelayHandler(object): def on_local_write(self): if self._data_to_write_to_local: - written = self.write_all_to_sock( - ''.join(self._data_to_write_to_local), self._local_sock) - if written: - self.pause_writing(self._local_sock) + data = ''.join(self._data_to_write_to_local) + self._data_to_write_to_local = [] + self.write_all_to_sock(data, self._local_sock) else: - self.pause_writing(self._local_sock) + self.update_stream(STREAM_DOWN, STATUS_WAIT_READING) def on_remote_write(self): if self._data_to_write_to_remote: - written = self.write_all_to_sock( - ''.join(self._data_to_write_to_remote), self._remote_sock) - if written: - self.pause_writing(self._remote_sock) + data = ''.join(self._data_to_write_to_remote) + self._data_to_write_to_remote = [] + self.write_all_to_sock(data, self._remote_sock) else: - self.pause_writing(self._remote_sock) + self.update_stream(STREAM_DOWN, STATUS_WAIT_READING) def on_local_error(self): + logging.error(eventloop.get_sock_error(self._local_sock)) self.destroy() def on_remote_error(self): + logging.error(eventloop.get_sock_error(self._remote_sock)) self.destroy() def handle_event(self, sock, event): @@ -276,7 +321,7 @@ class TCPRelay(object): conn, self._config, self._is_local) except (OSError, IOError) as e: error_no = eventloop.errno_from_exception(e) - if error_no in [errno.EAGAIN, errno.EINPROGRESS]: + if error_no in (errno.EAGAIN, errno.EINPROGRESS): continue else: handler = self._fd_to_handlers.get(sock.fileno(), None) @@ -291,6 +336,7 @@ class TCPRelay(object): last_time = now def start(self): + # TODO combine loops on multiple ports into one single loop if self._closed: raise Exception('closed') t = threading.Thread(target=self._run)