diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 3096e00..c387373 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -147,7 +147,6 @@ class TCPRelayHandler(object): # TCP Relay works as either sslocal or ssserver # if is_local, this is sslocal self._is_local = is_local - self._stage = STAGE_INIT self._encrypt_correct = True self._obfs = obfs.obfs(config['obfs']) self._protocol = obfs.obfs(config['protocol']) @@ -211,11 +210,6 @@ class TCPRelayHandler(object): if is_local: self._chosen_server = self._get_a_server() - fd_to_handlers[local_sock.fileno()] = self - local_sock.setblocking(False) - local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR, self._server) - self.last_activity = 0 self._update_activity() self._server.add_connection(1) @@ -226,6 +220,12 @@ class TCPRelayHandler(object): self._udp_send_pack_id = 0 self._udpv6_send_pack_id = 0 + local_sock.setblocking(False) + local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) + fd_to_handlers[local_sock.fileno()] = self + loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR, self._server) + self._stage = STAGE_INIT + def __hash__(self): # default __hash__ is id / 16 # we want to eliminate collisions @@ -311,7 +311,6 @@ class TCPRelayHandler(object): # and update the stream to wait for writing if not sock: return False - #logging.debug("_write_to_sock %s %s %s" % (self._remote_sock, sock, self._remote_udp)) uncomplete = False if self._remote_udp and sock == self._remote_sock: try: @@ -832,7 +831,15 @@ class TCPRelayHandler(object): return if obfs_decode[2]: data = self._obfs.server_encode(b'') - self._write_to_sock(data, self._local_sock) + try: + self._write_to_sock(data, self._local_sock) + except Exception as e: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) + self.destroy() + return if obfs_decode[1]: if not self._protocol.obfs.server_info.recv_iv: iv_len = len(self._protocol.obfs.server_info.iv) @@ -859,6 +866,7 @@ class TCPRelayHandler(object): shell.print_exception(e) logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) self.destroy() + return else: return if not data: @@ -870,12 +878,10 @@ class TCPRelayHandler(object): data = self._encryptor.encrypt(data) data = self._obfs.client_encode(data) self._write_to_sock(data, self._remote_sock) - return elif is_local and self._stage == STAGE_INIT: # TODO check auth method self._write_to_sock(b'\x05\00', self._local_sock) self._stage = STAGE_ADDR - return elif self._stage == STAGE_CONNECTING: self._handle_stage_connecting(data) elif (is_local and self._stage == STAGE_ADDR) or \ @@ -1016,36 +1022,34 @@ class TCPRelayHandler(object): if event & eventloop.POLL_ERR: handle = True self._on_remote_error() - if self._stage == STAGE_DESTROYED: - return True - if event & (eventloop.POLL_IN | eventloop.POLL_HUP): + elif event & (eventloop.POLL_IN | eventloop.POLL_HUP): if not self.speed_tester_d.isExceed(): if not self._server.speed_tester_d(self._user_id).isExceed(): handle = True self._on_remote_read(sock == self._remote_sock) - if self._stage == STAGE_DESTROYED: - return True - if event & eventloop.POLL_OUT: + elif event & eventloop.POLL_OUT: handle = True self._on_remote_write() elif sock == self._local_sock: if event & eventloop.POLL_ERR: handle = True self._on_local_error() - if self._stage == STAGE_DESTROYED: - return True - if event & (eventloop.POLL_IN | eventloop.POLL_HUP): + elif event & (eventloop.POLL_IN | eventloop.POLL_HUP): if not self.speed_tester_u.isExceed(): if not self._server.speed_tester_u(self._user_id).isExceed(): handle = True self._on_local_read() - if self._stage == STAGE_DESTROYED: - return True - if event & eventloop.POLL_OUT: + elif event & eventloop.POLL_OUT: handle = True self._on_local_write() else: logging.warn('unknown socket from %s:%d' % (self._client_address[0], self._client_address[1])) + try: + self._loop.remove(sock) + except Exception as e: + shell.print_exception(e) + del self._fd_to_handlers[sock.fileno()] + sock.close() return handle @@ -1079,16 +1083,16 @@ class TCPRelayHandler(object): try: self._loop.remove(self._remote_sock) except Exception as e: - pass + shell.print_exception(e) del self._fd_to_handlers[self._remote_sock.fileno()] self._remote_sock.close() self._remote_sock = None if self._remote_sock_v6: - logging.debug('destroying remote') + logging.debug('destroying remote_v6') try: self._loop.remove(self._remote_sock_v6) except Exception as e: - pass + shell.print_exception(e) del self._fd_to_handlers[self._remote_sock_v6.fileno()] self._remote_sock_v6.close() self._remote_sock_v6 = None @@ -1334,6 +1338,7 @@ class TCPRelay(object): if event & eventloop.POLL_ERR: # TODO raise Exception('server_socket error') + handler = None try: logging.debug('accept') conn = self._server_socket.accept() @@ -1351,11 +1356,15 @@ class TCPRelay(object): shell.print_exception(e) if self._config['verbose']: traceback.print_exc() + if handler: + handler.destroy() else: if sock: handler = self._fd_to_handlers.get(fd, None) if handler: handler.handle_event(sock, event) + else: + logging.warn('unknown fd') else: logging.warn('poll removed fd')