Browse Source

add timeout support

auth
clowwindy 10 years ago
parent
commit
c5bcb9a050
  1. 114
      shadowsocks/tcprelay.py
  2. 6
      shadowsocks/utils.py

114
shadowsocks/tcprelay.py

@ -30,6 +30,10 @@ import encrypt
import eventloop
from common import parse_header
TIMEOUTS_CLEAN_SIZE = 512
TIMEOUT_PRECISION = 4
CMD_CONNECT = 1
CMD_BIND = 2
CMD_UDP_ASSOCIATE = 3
@ -66,7 +70,9 @@ BUF_SIZE = 8 * 1024
class TCPRelayHandler(object):
def __init__(self, fd_to_handlers, loop, local_sock, config, is_local):
def __init__(self, server, fd_to_handlers, loop, local_sock, config,
is_local):
self._server = server
self._fd_to_handlers = fd_to_handlers
self._loop = loop
self._local_sock = local_sock
@ -80,10 +86,25 @@ class TCPRelayHandler(object):
self._data_to_write_to_remote = []
self._upstream_status = WAIT_STATUS_READING
self._downstream_status = WAIT_STATUS_INIT
self._remote_address = None
fd_to_handlers[local_sock.fileno()] = self
local_sock.setblocking(False)
local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR)
self.last_activity = 0
self.update_activity()
def __hash__(self):
# default __hash__ is id / 16
# we want to eliminate collisions
return id(self)
@property
def remote_address(self):
return self._remote_address
def update_activity(self):
self._server.update_activity(self)
def update_stream(self, stream, status):
dirty = False
@ -146,7 +167,7 @@ class TCPRelayHandler(object):
logging.error('write_all_to_sock:unknown socket')
def on_local_read(self):
# TODO update timeout
self.update_activity()
if not self._local_sock:
return
is_local = self._is_local
@ -211,6 +232,7 @@ class TCPRelayHandler(object):
addrtype, remote_addr, remote_port, header_length =\
header_result
logging.debug('connecting %s:%d' % (remote_addr, remote_port))
self._remote_address = (remote_addr, remote_port)
if is_local:
# forward address to remote
self.write_to_sock('\x05\x00\x00\x01' +
@ -257,7 +279,7 @@ class TCPRelayHandler(object):
self.destroy()
def on_remote_read(self):
# TODO update timeout
self.update_activity()
data = None
try:
data = self._remote_sock.recv(BUF_SIZE)
@ -325,7 +347,11 @@ class TCPRelayHandler(object):
logging.warn('unknown socket')
def destroy(self):
logging.debug('destroy')
if self._remote_address:
logging.debug('destroy: %s:%d' %
self._remote_address)
else:
logging.debug('destroy')
if self._remote_sock:
self._loop.remove(self._remote_sock)
del self._fd_to_handlers[self._remote_sock.fileno()]
@ -336,6 +362,7 @@ class TCPRelayHandler(object):
del self._fd_to_handlers[self._local_sock.fileno()]
self._local_sock.close()
self._local_sock = None
self._server.remove_handler(self)
class TCPRelay(object):
@ -347,6 +374,12 @@ class TCPRelay(object):
self._fd_to_handlers = {}
self._last_time = time.time()
self._timeout = config['timeout']
self._timeouts = [] # a list for all the handlers
self._timeout_offset = 0 # last checked position for timeout
# we trim the timeouts once a while
self._handler_to_timeouts = {} # key: handler value: index in timeouts
if is_local:
listen_addr = config['local_address']
listen_port = config['local_port']
@ -376,19 +409,74 @@ class TCPRelay(object):
self._eventloop.add(self._server_socket,
eventloop.POLL_IN | eventloop.POLL_ERR)
def remove_handler(self, handler):
index = self._handler_to_timeouts.get(hash(handler), -1)
if index >= 0:
# delete is O(n), so we just set it to None
self._timeouts[index] = None
del self._handler_to_timeouts[hash(handler)]
def update_activity(self, handler):
""" set handler to active """
now = int(time.time())
if now - handler.last_activity < TIMEOUT_PRECISION:
# thus we can lower timeout modification frequency
return
handler.last_activity = now
index = self._handler_to_timeouts.get(hash(handler), -1)
if index >= 0:
# delete is O(n), so we just set it to None
self._timeouts[index] = None
length = len(self._timeouts)
self._timeouts.append(handler)
self._handler_to_timeouts[hash(handler)] = length
def _sweep_timeout(self):
# tornado's timeout memory management is more flexible that we need
# we just need a sorted last_activity queue and it's faster that heapq
# in fact we can do O(1) insertion/remove so we invent our own
if self._timeouts:
now = time.time()
length = len(self._timeouts)
pos = self._timeout_offset
while pos < length:
handler = self._timeouts[pos]
if handler:
if now - handler.last_activity < self._timeout:
break
else:
if handler.remote_address:
logging.warn('timed out: %s:%d' %
handler.remote_address)
else:
logging.warn('timed out')
handler.destroy()
self._timeouts[pos] = None # free memory
pos += 1
else:
pos += 1
if pos > TIMEOUTS_CLEAN_SIZE and pos > length >> 1:
# clean up the timeout queue when it gets larger than half
# of the queue
self._timeouts = self._timeouts[pos:]
for key in self._handler_to_timeouts:
self._handler_to_timeouts[key] -= pos
pos = 0
self._timeout_offset = pos
def _handle_events(self, events):
for sock, fd, event in events:
if sock:
logging.debug('fd %d %s', fd,
eventloop.EVENT_NAMES.get(event, event))
# if sock:
# logging.debug('fd %d %s', fd,
# eventloop.EVENT_NAMES.get(event, event))
if sock == self._server_socket:
if event & eventloop.POLL_ERR:
# TODO
raise Exception('server_socket error')
try:
logging.debug('accept')
# logging.debug('accept')
conn = self._server_socket.accept()
TCPRelayHandler(self._fd_to_handlers, self._eventloop,
TCPRelayHandler(self, self._fd_to_handlers, self._eventloop,
conn[0], self._config, self._is_local)
except (OSError, IOError) as e:
error_no = eventloop.errno_from_exception(e)
@ -404,10 +492,10 @@ class TCPRelay(object):
else:
logging.warn('poll removed fd')
now = time.time()
if now - self._last_time > 5:
# TODO sweep timeouts
self._last_time = now
now = time.time()
if now - self._last_time > TIMEOUT_PRECISION:
self._sweep_timeout()
self._last_time = now
def close(self):
self._closed = True

6
shadowsocks/utils.py

@ -64,10 +64,10 @@ def check_config(config):
if (config.get('method', '') or '').lower() == 'rc4':
logging.warn('warning: RC4 is not safe; please use a safer cipher, '
'like AES-256-CFB')
if (int(config.get('timeout', 300)) or 300) < 100:
if config.get('timeout', 300) < 100:
logging.warn('warning: your timeout %d seems too short' %
int(config.get('timeout')))
if (int(config.get('timeout', 300)) or 300) > 600:
if config.get('timeout', 300) > 600:
logging.warn('warning: your timeout %d seems too long' %
int(config.get('timeout')))
@ -114,6 +114,8 @@ def get_config(is_local):
config['local_address'] = value
elif key == '-v':
config['verbose'] = True
elif key == '-t':
config['timeout'] = int(value)
elif key == '--fast-open':
config['fast_open'] = True
elif key == '--workers':

Loading…
Cancel
Save