diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index dcfa8ed..298b9dd 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -94,7 +94,7 @@ class TCPRelayHandler(object): local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR) self.last_activity = 0 - self.update_activity() + self._update_activity() def __hash__(self): # default __hash__ is id / 16 @@ -105,10 +105,10 @@ class TCPRelayHandler(object): def remote_address(self): return self._remote_address - def update_activity(self): + def _update_activity(self): self._server.update_activity(self) - def update_stream(self, stream, status): + def _update_stream(self, stream, status): dirty = False if stream == STREAM_DOWN: if self._downstream_status != status: @@ -134,7 +134,7 @@ class TCPRelayHandler(object): event |= eventloop.POLL_OUT self._loop.modify(self._remote_sock, event) - def write_to_sock(self, data, sock): + def _write_to_sock(self, data, sock): if not data or not sock: return uncomplete = False @@ -154,22 +154,133 @@ class TCPRelayHandler(object): if uncomplete: if sock == self._local_sock: self._data_to_write_to_local.append(data) - self.update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) + self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) elif sock == self._remote_sock: self._data_to_write_to_remote.append(data) - self.update_stream(STREAM_UP, WAIT_STATUS_WRITING) + self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) else: logging.error('write_all_to_sock:unknown socket') else: if sock == self._local_sock: - self.update_stream(STREAM_DOWN, WAIT_STATUS_READING) + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) elif sock == self._remote_sock: - self.update_stream(STREAM_UP, WAIT_STATUS_READING) + self._update_stream(STREAM_UP, WAIT_STATUS_READING) else: logging.error('write_all_to_sock:unknown socket') - def on_local_read(self): - self.update_activity() + def _handle_stage_reply(self, data): + if self._is_local: + data = self._encryptor.encrypt(data) + self._data_to_write_to_remote.append(data) + if self._is_local and self._upstream_status == WAIT_STATUS_INIT and \ + self._config['fast_open']: + try: + data = ''.join(self._data_to_write_to_local) + l = len(data) + s = self._remote_sock.sendto(data, MSG_FASTOPEN, + self.remote_address) + if s < l: + data = data[s:] + self._data_to_write_to_local = [data] + self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + else: + self._data_to_write_to_local = [] + self._update_stream(STREAM_UP, WAIT_STATUS_READING) + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + self._stage = STAGE_STREAM + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) == errno.EINPROGRESS: + self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + elif eventloop.errno_from_exception(e) == errno.ENOTCONN: + logging.error('fast open not supported on this OS') + self._config['fast_open'] = False + self.destroy() + else: + logging.error(e) + self.destroy() + + def _handle_stage_hello(self, data): + try: + if self._is_local: + cmd = ord(data[1]) + if cmd == CMD_UDP_ASSOCIATE: + logging.debug('UDP associate') + if self._local_sock.family == socket.AF_INET6: + header = '\x05\x00\x00\x04' + else: + header = '\x05\x00\x00\x01' + addr, port = self._local_sock.getsockname() + addr_to_send = socket.inet_pton(self._local_sock.family, + addr) + port_to_send = struct.pack('>H', port) + self._write_to_sock(header + addr_to_send + port_to_send, + self._local_sock) + self._stage = STAGE_UDP_ASSOC + # just wait for the client to disconnect + return + elif cmd == CMD_CONNECT: + # just trim VER CMD RSV + data = data[3:] + else: + logging.error('unknown command %d', cmd) + self.destroy() + return + header_result = parse_header(data) + if header_result is None: + raise Exception('can not parse header') + addrtype, remote_addr, remote_port, header_length = header_result + logging.info('connecting %s:%d' % (remote_addr, remote_port)) + self._remote_address = (remote_addr, remote_port) + if self._is_local: + # forward address to remote + self._write_to_sock('\x05\x00\x00\x01\x00\x00\x00\x00\x10\x10', + self._local_sock) + data_to_send = self._encryptor.encrypt(data) + self._data_to_write_to_remote.append(data_to_send) + remote_addr = self._config['server'] + remote_port = self._config['server_port'] + else: + if len(data) > header_length: + self._data_to_write_to_remote.append(data[header_length:]) + + # TODO async DNS + addrs = socket.getaddrinfo(remote_addr, remote_port, 0, + socket.SOCK_STREAM, socket.SOL_TCP) + if len(addrs) == 0: + raise Exception("can't get addrinfo for %s:%d" % + (remote_addr, remote_port)) + af, socktype, proto, canonname, sa = addrs[0] + remote_sock = socket.socket(af, socktype, proto) + self._remote_sock = remote_sock + self._fd_to_handlers[remote_sock.fileno()] = self + remote_sock.setblocking(False) + remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) + + if self._is_local and self._config['fast_open']: + # wait for more data to arrive and send them in one SYN + self._stage = STAGE_REPLY + # TODO when there is already data in this packet + else: + try: + remote_sock.connect(sa) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) == errno.EINPROGRESS: + pass + self._loop.add(remote_sock, + eventloop.POLL_ERR | eventloop.POLL_OUT) + self._stage = STAGE_REPLY + self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + except Exception: + import traceback + traceback.print_exc() + # TODO use logging when debug completed + self.destroy() + + def _on_local_read(self): + self._update_activity() if not self._local_sock: return is_local = self._is_local @@ -190,131 +301,21 @@ class TCPRelayHandler(object): if self._stage == STAGE_STREAM: if self._is_local: data = self._encryptor.encrypt(data) - self.write_to_sock(data, self._remote_sock) + self._write_to_sock(data, self._remote_sock) return elif is_local and self._stage == STAGE_INIT: # TODO check auth method - self.write_to_sock('\x05\00', self._local_sock) + self._write_to_sock('\x05\00', self._local_sock) self._stage = STAGE_HELLO return elif self._stage == STAGE_REPLY: - if is_local: - data = self._encryptor.encrypt(data) - self._data_to_write_to_remote.append(data) - if is_local and self._upstream_status == WAIT_STATUS_INIT and \ - self._config['fast_open']: - try: - data = ''.join(self._data_to_write_to_local) - l = len(data) - s = self._remote_sock.sendto(data, MSG_FASTOPEN, - self.remote_address) - if s < l: - data = data[s:] - self._data_to_write_to_local = [data] - self.update_stream(STREAM_UP, WAIT_STATUS_READWRITING) - self.update_stream(STREAM_DOWN, WAIT_STATUS_READING) - else: - self._data_to_write_to_local = [] - self.update_stream(STREAM_UP, WAIT_STATUS_READING) - self.update_stream(STREAM_DOWN, WAIT_STATUS_READING) - self._stage = STAGE_STREAM - except (OSError, IOError) as e: - if eventloop.errno_from_exception(e) == errno.EINPROGRESS: - self.update_stream(STREAM_UP, WAIT_STATUS_READWRITING) - self.update_stream(STREAM_DOWN, WAIT_STATUS_READING) - elif eventloop.errno_from_exception(e) == errno.ENOTCONN: - logging.error('fast open not supported on this OS') - self._config['fast_open'] = False - self.destroy() - else: - logging.error(e) - self.destroy() + self._handle_stage_reply(data) elif (is_local and self._stage == STAGE_HELLO) or \ (not is_local and self._stage == STAGE_INIT): - try: - if is_local: - cmd = ord(data[1]) - if cmd == CMD_UDP_ASSOCIATE: - logging.debug('UDP associate') - if self._local_sock.family == socket.AF_INET6: - header = '\x05\x00\x00\x04' - else: - header = '\x05\x00\x00\x01' - addr, port = self._local_sock.getsockname() - addr_to_send = socket.inet_pton(self._local_sock.family, - addr) - port_to_send = struct.pack('>H', port) - self.write_to_sock(header + addr_to_send + port_to_send, - self._local_sock) - self._stage = STAGE_UDP_ASSOC - # just wait for the client to disconnect - return - elif cmd == CMD_CONNECT: - # just trim VER CMD RSV - data = data[3:] - else: - logging.error('unknown command %d', cmd) - self.destroy() - return - header_result = parse_header(data) - if header_result is None: - raise Exception('can not parse header') - addrtype, remote_addr, remote_port, header_length =\ - header_result - logging.info('connecting %s:%d' % (remote_addr, remote_port)) - self._remote_address = (remote_addr, remote_port) - if is_local: - # forward address to remote - self.write_to_sock('\x05\x00\x00\x01' + - '\x00\x00\x00\x00\x10\x10', - self._local_sock) - data_to_send = self._encryptor.encrypt(data) - self._data_to_write_to_remote.append(data_to_send) - remote_addr = self._config['server'] - remote_port = self._config['server_port'] - else: - if len(data) > header_length: - self._data_to_write_to_remote.append( - data[header_length:]) - - # TODO async DNS - addrs = socket.getaddrinfo(remote_addr, remote_port, 0, - socket.SOCK_STREAM, socket.SOL_TCP) - if len(addrs) == 0: - raise Exception("can't get addrinfo for %s:%d" % - (remote_addr, remote_port)) - af, socktype, proto, canonname, sa = addrs[0] - remote_sock = socket.socket(af, socktype, proto) - self._remote_sock = remote_sock - self._fd_to_handlers[remote_sock.fileno()] = self - remote_sock.setblocking(False) - remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - - if self._is_local and self._config['fast_open']: - # wait for more data to arrive and send them in one SYN - self._stage = STAGE_REPLY - # TODO when there is already data in this packet - else: - try: - remote_sock.connect(sa) - except (OSError, IOError) as e: - if eventloop.errno_from_exception(e) == \ - errno.EINPROGRESS: - pass - self._loop.add(remote_sock, - eventloop.POLL_ERR | eventloop.POLL_OUT) - self._stage = STAGE_REPLY - self.update_stream(STREAM_UP, WAIT_STATUS_READWRITING) - self.update_stream(STREAM_DOWN, WAIT_STATUS_READING) - return - except Exception: - import traceback - traceback.print_exc() - # TODO use logging when debug completed - self.destroy() + self._handle_stage_hello(data) - def on_remote_read(self): - self.update_activity() + def _on_remote_read(self): + self._update_activity() data = None try: data = self._remote_sock.recv(BUF_SIZE) @@ -330,36 +331,36 @@ class TCPRelayHandler(object): else: data = self._encryptor.encrypt(data) try: - self.write_to_sock(data, self._local_sock) + self._write_to_sock(data, self._local_sock) except Exception: import traceback traceback.print_exc() # TODO use logging when debug completed self.destroy() - def on_local_write(self): + def _on_local_write(self): if self._data_to_write_to_local: data = ''.join(self._data_to_write_to_local) self._data_to_write_to_local = [] - self.write_to_sock(data, self._local_sock) + self._write_to_sock(data, self._local_sock) else: - self.update_stream(STREAM_DOWN, WAIT_STATUS_READING) + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) - def on_remote_write(self): + def _on_remote_write(self): self._stage = STAGE_STREAM if self._data_to_write_to_remote: data = ''.join(self._data_to_write_to_remote) self._data_to_write_to_remote = [] - self.write_to_sock(data, self._remote_sock) + self._write_to_sock(data, self._remote_sock) else: - self.update_stream(STREAM_UP, WAIT_STATUS_READING) + self._update_stream(STREAM_UP, WAIT_STATUS_READING) - def on_local_error(self): + def _on_local_error(self): if self._local_sock: logging.error(eventloop.get_sock_error(self._local_sock)) self.destroy() - def on_remote_error(self): + def _on_remote_error(self): if self._remote_sock: logging.error(eventloop.get_sock_error(self._remote_sock)) self.destroy() @@ -368,18 +369,18 @@ class TCPRelayHandler(object): # order is important if sock == self._remote_sock: if event & eventloop.POLL_IN: - self.on_remote_read() + self._on_remote_read() if event & eventloop.POLL_OUT: - self.on_remote_write() + self._on_remote_write() if event & eventloop.POLL_ERR: - self.on_remote_error() + self._on_remote_error() elif sock == self._local_sock: if event & eventloop.POLL_IN: - self.on_local_read() + self._on_local_read() if event & eventloop.POLL_OUT: - self.on_local_write() + self._on_local_write() if event & eventloop.POLL_ERR: - self.on_local_error() + self._on_local_error() else: logging.warn('unknown socket')