|
@ -28,7 +28,7 @@ import traceback |
|
|
import random |
|
|
import random |
|
|
import platform |
|
|
import platform |
|
|
|
|
|
|
|
|
from shadowsocks import encrypt, obfs, eventloop, shell, common |
|
|
from shadowsocks import encrypt, obfs, eventloop, shell, common, lru_cache |
|
|
from shadowsocks.common import pre_parse_header, parse_header |
|
|
from shadowsocks.common import pre_parse_header, parse_header |
|
|
|
|
|
|
|
|
# we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time |
|
|
# we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time |
|
@ -961,10 +961,8 @@ class TCPRelay(object): |
|
|
common.connect_log = logging.info |
|
|
common.connect_log = logging.info |
|
|
|
|
|
|
|
|
self._timeout = config['timeout'] |
|
|
self._timeout = config['timeout'] |
|
|
self._timeouts = [] # a list for all the handlers |
|
|
self._timeout_cache = lru_cache.LRUCache(timeout=self._timeout, |
|
|
# we trim the timeouts once a while |
|
|
close_callback=self._close_tcp_client) |
|
|
self._timeout_offset = 0 # last checked position for timeout |
|
|
|
|
|
self._handler_to_timeouts = {} # key: handler value: index in timeouts |
|
|
|
|
|
|
|
|
|
|
|
if is_local: |
|
|
if is_local: |
|
|
listen_addr = config['local_address'] |
|
|
listen_addr = config['local_address'] |
|
@ -1005,12 +1003,9 @@ class TCPRelay(object): |
|
|
eventloop.POLL_IN | eventloop.POLL_ERR, self) |
|
|
eventloop.POLL_IN | eventloop.POLL_ERR, self) |
|
|
self._eventloop.add_periodic(self.handle_periodic) |
|
|
self._eventloop.add_periodic(self.handle_periodic) |
|
|
|
|
|
|
|
|
def remove_handler(self, handler): |
|
|
def remove_handler(self, client): |
|
|
index = self._handler_to_timeouts.get(hash(handler), -1) |
|
|
if hash(client) in self._timeout_cache: |
|
|
if index >= 0: |
|
|
del self._timeout_cache[hash(client)] |
|
|
# delete is O(n), so we just set it to None |
|
|
|
|
|
self._timeouts[index] = None |
|
|
|
|
|
del self._handler_to_timeouts[hash(handler)] |
|
|
|
|
|
|
|
|
|
|
|
def add_connection(self, val): |
|
|
def add_connection(self, val): |
|
|
self.server_connections += val |
|
|
self.server_connections += val |
|
@ -1052,57 +1047,22 @@ class TCPRelay(object): |
|
|
logging.info('Total connections down to %d' % newval) |
|
|
logging.info('Total connections down to %d' % newval) |
|
|
self._stat_counter[-1] = self._stat_counter.get(-1, 0) - connections_step |
|
|
self._stat_counter[-1] = self._stat_counter.get(-1, 0) - connections_step |
|
|
|
|
|
|
|
|
def update_activity(self, handler, data_len): |
|
|
def update_activity(self, client, data_len): |
|
|
if data_len and self._stat_callback: |
|
|
if data_len and self._stat_callback: |
|
|
self._stat_callback(self._listen_port, data_len) |
|
|
self._stat_callback(self._listen_port, data_len) |
|
|
|
|
|
|
|
|
# set handler to active |
|
|
self._timeout_cache[hash(client)] = client |
|
|
now = int(time.time()) |
|
|
|
|
|
if now - handler.last_activity < eventloop.TIMEOUT_PRECISION: |
|
|
|
|
|
# thus we can lower timeout modification frequency |
|
|
|
|
|
return |
|
|
|
|
|
handler.last_activity = now |
|
|
|
|
|
index = self._handler_to_timeouts.get(hash(handler), -1) |
|
|
|
|
|
if index >= 0: |
|
|
|
|
|
# delete is O(n), so we just set it to None |
|
|
|
|
|
self._timeouts[index] = None |
|
|
|
|
|
length = len(self._timeouts) |
|
|
|
|
|
self._timeouts.append(handler) |
|
|
|
|
|
self._handler_to_timeouts[hash(handler)] = length |
|
|
|
|
|
|
|
|
|
|
|
def _sweep_timeout(self): |
|
|
def _sweep_timeout(self): |
|
|
# tornado's timeout memory management is more flexible than we need |
|
|
self._timeout_cache.sweep() |
|
|
# we just need a sorted last_activity queue and it's faster than heapq |
|
|
|
|
|
# in fact we can do O(1) insertion/remove so we invent our own |
|
|
def _close_tcp_client(self, client): |
|
|
if self._timeouts: |
|
|
if client.remote_address: |
|
|
logging.log(shell.VERBOSE_LEVEL, 'sweeping timeouts') |
|
|
logging.debug('timed out: %s:%d' % |
|
|
now = time.time() |
|
|
client.remote_address) |
|
|
length = len(self._timeouts) |
|
|
else: |
|
|
pos = self._timeout_offset |
|
|
logging.debug('timed out') |
|
|
while pos < length: |
|
|
client.destroy() |
|
|
handler = self._timeouts[pos] |
|
|
|
|
|
if handler: |
|
|
|
|
|
if now - handler.last_activity < self._timeout: |
|
|
|
|
|
break |
|
|
|
|
|
else: |
|
|
|
|
|
if handler.remote_address: |
|
|
|
|
|
logging.debug('timed out: %s:%d' % |
|
|
|
|
|
handler.remote_address) |
|
|
|
|
|
else: |
|
|
|
|
|
logging.debug('timed out') |
|
|
|
|
|
handler.destroy() |
|
|
|
|
|
self._timeouts[pos] = None # free memory |
|
|
|
|
|
pos += 1 |
|
|
|
|
|
else: |
|
|
|
|
|
pos += 1 |
|
|
|
|
|
if pos > TIMEOUTS_CLEAN_SIZE and pos > length >> 1: |
|
|
|
|
|
# clean up the timeout queue when it gets larger than half |
|
|
|
|
|
# of the queue |
|
|
|
|
|
self._timeouts = self._timeouts[pos:] |
|
|
|
|
|
for key in self._handler_to_timeouts: |
|
|
|
|
|
self._handler_to_timeouts[key] -= pos |
|
|
|
|
|
pos = 0 |
|
|
|
|
|
self._timeout_offset = pos |
|
|
|
|
|
|
|
|
|
|
|
def handle_event(self, sock, fd, event): |
|
|
def handle_event(self, sock, fd, event): |
|
|
# handle events and dispatch to handlers |
|
|
# handle events and dispatch to handlers |
|
|