From c5bcb9a050a395c19b82b4340866a627dcd23c18 Mon Sep 17 00:00:00 2001 From: clowwindy Date: Mon, 2 Jun 2014 17:01:35 +0800 Subject: [PATCH] add timeout support --- shadowsocks/tcprelay.py | 114 +++++++++++++++++++++++++++++++++++----- shadowsocks/utils.py | 6 ++- 2 files changed, 105 insertions(+), 15 deletions(-) diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 51e186a..e1b8cf9 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -30,6 +30,10 @@ import encrypt import eventloop from common import parse_header + +TIMEOUTS_CLEAN_SIZE = 512 +TIMEOUT_PRECISION = 4 + CMD_CONNECT = 1 CMD_BIND = 2 CMD_UDP_ASSOCIATE = 3 @@ -66,7 +70,9 @@ BUF_SIZE = 8 * 1024 class TCPRelayHandler(object): - def __init__(self, fd_to_handlers, loop, local_sock, config, is_local): + def __init__(self, server, fd_to_handlers, loop, local_sock, config, + is_local): + self._server = server self._fd_to_handlers = fd_to_handlers self._loop = loop self._local_sock = local_sock @@ -80,10 +86,25 @@ class TCPRelayHandler(object): self._data_to_write_to_remote = [] self._upstream_status = WAIT_STATUS_READING self._downstream_status = WAIT_STATUS_INIT + self._remote_address = None fd_to_handlers[local_sock.fileno()] = self local_sock.setblocking(False) local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR) + self.last_activity = 0 + self.update_activity() + + def __hash__(self): + # default __hash__ is id / 16 + # we want to eliminate collisions + return id(self) + + @property + def remote_address(self): + return self._remote_address + + def update_activity(self): + self._server.update_activity(self) def update_stream(self, stream, status): dirty = False @@ -146,7 +167,7 @@ class TCPRelayHandler(object): logging.error('write_all_to_sock:unknown socket') def on_local_read(self): - # TODO update timeout + self.update_activity() if not self._local_sock: return is_local = self._is_local @@ -211,6 +232,7 @@ class TCPRelayHandler(object): addrtype, remote_addr, remote_port, header_length =\ header_result logging.debug('connecting %s:%d' % (remote_addr, remote_port)) + self._remote_address = (remote_addr, remote_port) if is_local: # forward address to remote self.write_to_sock('\x05\x00\x00\x01' + @@ -257,7 +279,7 @@ class TCPRelayHandler(object): self.destroy() def on_remote_read(self): - # TODO update timeout + self.update_activity() data = None try: data = self._remote_sock.recv(BUF_SIZE) @@ -325,7 +347,11 @@ class TCPRelayHandler(object): logging.warn('unknown socket') def destroy(self): - logging.debug('destroy') + if self._remote_address: + logging.debug('destroy: %s:%d' % + self._remote_address) + else: + logging.debug('destroy') if self._remote_sock: self._loop.remove(self._remote_sock) del self._fd_to_handlers[self._remote_sock.fileno()] @@ -336,6 +362,7 @@ class TCPRelayHandler(object): del self._fd_to_handlers[self._local_sock.fileno()] self._local_sock.close() self._local_sock = None + self._server.remove_handler(self) class TCPRelay(object): @@ -347,6 +374,12 @@ class TCPRelay(object): self._fd_to_handlers = {} self._last_time = time.time() + self._timeout = config['timeout'] + self._timeouts = [] # a list for all the handlers + self._timeout_offset = 0 # last checked position for timeout + # we trim the timeouts once a while + self._handler_to_timeouts = {} # key: handler value: index in timeouts + if is_local: listen_addr = config['local_address'] listen_port = config['local_port'] @@ -376,19 +409,74 @@ class TCPRelay(object): self._eventloop.add(self._server_socket, eventloop.POLL_IN | eventloop.POLL_ERR) + def remove_handler(self, handler): + index = self._handler_to_timeouts.get(hash(handler), -1) + if index >= 0: + # delete is O(n), so we just set it to None + self._timeouts[index] = None + del self._handler_to_timeouts[hash(handler)] + + def update_activity(self, handler): + """ set handler to active """ + now = int(time.time()) + if now - handler.last_activity < TIMEOUT_PRECISION: + # thus we can lower timeout modification frequency + return + handler.last_activity = now + index = self._handler_to_timeouts.get(hash(handler), -1) + if index >= 0: + # delete is O(n), so we just set it to None + self._timeouts[index] = None + length = len(self._timeouts) + self._timeouts.append(handler) + self._handler_to_timeouts[hash(handler)] = length + + def _sweep_timeout(self): + # tornado's timeout memory management is more flexible that we need + # we just need a sorted last_activity queue and it's faster that heapq + # in fact we can do O(1) insertion/remove so we invent our own + if self._timeouts: + now = time.time() + length = len(self._timeouts) + pos = self._timeout_offset + while pos < length: + handler = self._timeouts[pos] + if handler: + if now - handler.last_activity < self._timeout: + break + else: + if handler.remote_address: + logging.warn('timed out: %s:%d' % + handler.remote_address) + else: + logging.warn('timed out') + handler.destroy() + self._timeouts[pos] = None # free memory + pos += 1 + else: + pos += 1 + if pos > TIMEOUTS_CLEAN_SIZE and pos > length >> 1: + # clean up the timeout queue when it gets larger than half + # of the queue + self._timeouts = self._timeouts[pos:] + for key in self._handler_to_timeouts: + self._handler_to_timeouts[key] -= pos + pos = 0 + self._timeout_offset = pos + def _handle_events(self, events): for sock, fd, event in events: - if sock: - logging.debug('fd %d %s', fd, - eventloop.EVENT_NAMES.get(event, event)) + # if sock: + # logging.debug('fd %d %s', fd, + # eventloop.EVENT_NAMES.get(event, event)) if sock == self._server_socket: if event & eventloop.POLL_ERR: # TODO raise Exception('server_socket error') try: - logging.debug('accept') + # logging.debug('accept') conn = self._server_socket.accept() - TCPRelayHandler(self._fd_to_handlers, self._eventloop, + TCPRelayHandler(self, self._fd_to_handlers, self._eventloop, conn[0], self._config, self._is_local) except (OSError, IOError) as e: error_no = eventloop.errno_from_exception(e) @@ -404,10 +492,10 @@ class TCPRelay(object): else: logging.warn('poll removed fd') - now = time.time() - if now - self._last_time > 5: - # TODO sweep timeouts - self._last_time = now + now = time.time() + if now - self._last_time > TIMEOUT_PRECISION: + self._sweep_timeout() + self._last_time = now def close(self): self._closed = True diff --git a/shadowsocks/utils.py b/shadowsocks/utils.py index c087ee3..d1941f9 100644 --- a/shadowsocks/utils.py +++ b/shadowsocks/utils.py @@ -64,10 +64,10 @@ def check_config(config): if (config.get('method', '') or '').lower() == 'rc4': logging.warn('warning: RC4 is not safe; please use a safer cipher, ' 'like AES-256-CFB') - if (int(config.get('timeout', 300)) or 300) < 100: + if config.get('timeout', 300) < 100: logging.warn('warning: your timeout %d seems too short' % int(config.get('timeout'))) - if (int(config.get('timeout', 300)) or 300) > 600: + if config.get('timeout', 300) > 600: logging.warn('warning: your timeout %d seems too long' % int(config.get('timeout'))) @@ -114,6 +114,8 @@ def get_config(is_local): config['local_address'] = value elif key == '-v': config['verbose'] = True + elif key == '-t': + config['timeout'] = int(value) elif key == '--fast-open': config['fast_open'] = True elif key == '--workers':