|
|
@ -30,6 +30,10 @@ import encrypt |
|
|
|
import eventloop |
|
|
|
from common import parse_header |
|
|
|
|
|
|
|
|
|
|
|
TIMEOUTS_CLEAN_SIZE = 512 |
|
|
|
TIMEOUT_PRECISION = 4 |
|
|
|
|
|
|
|
CMD_CONNECT = 1 |
|
|
|
CMD_BIND = 2 |
|
|
|
CMD_UDP_ASSOCIATE = 3 |
|
|
@ -66,7 +70,9 @@ BUF_SIZE = 8 * 1024 |
|
|
|
|
|
|
|
|
|
|
|
class TCPRelayHandler(object): |
|
|
|
def __init__(self, fd_to_handlers, loop, local_sock, config, is_local): |
|
|
|
def __init__(self, server, fd_to_handlers, loop, local_sock, config, |
|
|
|
is_local): |
|
|
|
self._server = server |
|
|
|
self._fd_to_handlers = fd_to_handlers |
|
|
|
self._loop = loop |
|
|
|
self._local_sock = local_sock |
|
|
@ -80,10 +86,25 @@ class TCPRelayHandler(object): |
|
|
|
self._data_to_write_to_remote = [] |
|
|
|
self._upstream_status = WAIT_STATUS_READING |
|
|
|
self._downstream_status = WAIT_STATUS_INIT |
|
|
|
self._remote_address = None |
|
|
|
fd_to_handlers[local_sock.fileno()] = self |
|
|
|
local_sock.setblocking(False) |
|
|
|
local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) |
|
|
|
loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR) |
|
|
|
self.last_activity = 0 |
|
|
|
self.update_activity() |
|
|
|
|
|
|
|
def __hash__(self): |
|
|
|
# default __hash__ is id / 16 |
|
|
|
# we want to eliminate collisions |
|
|
|
return id(self) |
|
|
|
|
|
|
|
@property |
|
|
|
def remote_address(self): |
|
|
|
return self._remote_address |
|
|
|
|
|
|
|
def update_activity(self): |
|
|
|
self._server.update_activity(self) |
|
|
|
|
|
|
|
def update_stream(self, stream, status): |
|
|
|
dirty = False |
|
|
@ -146,7 +167,7 @@ class TCPRelayHandler(object): |
|
|
|
logging.error('write_all_to_sock:unknown socket') |
|
|
|
|
|
|
|
def on_local_read(self): |
|
|
|
# TODO update timeout |
|
|
|
self.update_activity() |
|
|
|
if not self._local_sock: |
|
|
|
return |
|
|
|
is_local = self._is_local |
|
|
@ -211,6 +232,7 @@ class TCPRelayHandler(object): |
|
|
|
addrtype, remote_addr, remote_port, header_length =\ |
|
|
|
header_result |
|
|
|
logging.debug('connecting %s:%d' % (remote_addr, remote_port)) |
|
|
|
self._remote_address = (remote_addr, remote_port) |
|
|
|
if is_local: |
|
|
|
# forward address to remote |
|
|
|
self.write_to_sock('\x05\x00\x00\x01' + |
|
|
@ -257,7 +279,7 @@ class TCPRelayHandler(object): |
|
|
|
self.destroy() |
|
|
|
|
|
|
|
def on_remote_read(self): |
|
|
|
# TODO update timeout |
|
|
|
self.update_activity() |
|
|
|
data = None |
|
|
|
try: |
|
|
|
data = self._remote_sock.recv(BUF_SIZE) |
|
|
@ -325,7 +347,11 @@ class TCPRelayHandler(object): |
|
|
|
logging.warn('unknown socket') |
|
|
|
|
|
|
|
def destroy(self): |
|
|
|
logging.debug('destroy') |
|
|
|
if self._remote_address: |
|
|
|
logging.debug('destroy: %s:%d' % |
|
|
|
self._remote_address) |
|
|
|
else: |
|
|
|
logging.debug('destroy') |
|
|
|
if self._remote_sock: |
|
|
|
self._loop.remove(self._remote_sock) |
|
|
|
del self._fd_to_handlers[self._remote_sock.fileno()] |
|
|
@ -336,6 +362,7 @@ class TCPRelayHandler(object): |
|
|
|
del self._fd_to_handlers[self._local_sock.fileno()] |
|
|
|
self._local_sock.close() |
|
|
|
self._local_sock = None |
|
|
|
self._server.remove_handler(self) |
|
|
|
|
|
|
|
|
|
|
|
class TCPRelay(object): |
|
|
@ -347,6 +374,12 @@ class TCPRelay(object): |
|
|
|
self._fd_to_handlers = {} |
|
|
|
self._last_time = time.time() |
|
|
|
|
|
|
|
self._timeout = config['timeout'] |
|
|
|
self._timeouts = [] # a list for all the handlers |
|
|
|
self._timeout_offset = 0 # last checked position for timeout |
|
|
|
# we trim the timeouts once a while |
|
|
|
self._handler_to_timeouts = {} # key: handler value: index in timeouts |
|
|
|
|
|
|
|
if is_local: |
|
|
|
listen_addr = config['local_address'] |
|
|
|
listen_port = config['local_port'] |
|
|
@ -376,19 +409,74 @@ class TCPRelay(object): |
|
|
|
self._eventloop.add(self._server_socket, |
|
|
|
eventloop.POLL_IN | eventloop.POLL_ERR) |
|
|
|
|
|
|
|
def remove_handler(self, handler): |
|
|
|
index = self._handler_to_timeouts.get(hash(handler), -1) |
|
|
|
if index >= 0: |
|
|
|
# delete is O(n), so we just set it to None |
|
|
|
self._timeouts[index] = None |
|
|
|
del self._handler_to_timeouts[hash(handler)] |
|
|
|
|
|
|
|
def update_activity(self, handler): |
|
|
|
""" set handler to active """ |
|
|
|
now = int(time.time()) |
|
|
|
if now - handler.last_activity < TIMEOUT_PRECISION: |
|
|
|
# thus we can lower timeout modification frequency |
|
|
|
return |
|
|
|
handler.last_activity = now |
|
|
|
index = self._handler_to_timeouts.get(hash(handler), -1) |
|
|
|
if index >= 0: |
|
|
|
# delete is O(n), so we just set it to None |
|
|
|
self._timeouts[index] = None |
|
|
|
length = len(self._timeouts) |
|
|
|
self._timeouts.append(handler) |
|
|
|
self._handler_to_timeouts[hash(handler)] = length |
|
|
|
|
|
|
|
def _sweep_timeout(self): |
|
|
|
# tornado's timeout memory management is more flexible that we need |
|
|
|
# we just need a sorted last_activity queue and it's faster that heapq |
|
|
|
# in fact we can do O(1) insertion/remove so we invent our own |
|
|
|
if self._timeouts: |
|
|
|
now = time.time() |
|
|
|
length = len(self._timeouts) |
|
|
|
pos = self._timeout_offset |
|
|
|
while pos < length: |
|
|
|
handler = self._timeouts[pos] |
|
|
|
if handler: |
|
|
|
if now - handler.last_activity < self._timeout: |
|
|
|
break |
|
|
|
else: |
|
|
|
if handler.remote_address: |
|
|
|
logging.warn('timed out: %s:%d' % |
|
|
|
handler.remote_address) |
|
|
|
else: |
|
|
|
logging.warn('timed out') |
|
|
|
handler.destroy() |
|
|
|
self._timeouts[pos] = None # free memory |
|
|
|
pos += 1 |
|
|
|
else: |
|
|
|
pos += 1 |
|
|
|
if pos > TIMEOUTS_CLEAN_SIZE and pos > length >> 1: |
|
|
|
# clean up the timeout queue when it gets larger than half |
|
|
|
# of the queue |
|
|
|
self._timeouts = self._timeouts[pos:] |
|
|
|
for key in self._handler_to_timeouts: |
|
|
|
self._handler_to_timeouts[key] -= pos |
|
|
|
pos = 0 |
|
|
|
self._timeout_offset = pos |
|
|
|
|
|
|
|
def _handle_events(self, events): |
|
|
|
for sock, fd, event in events: |
|
|
|
if sock: |
|
|
|
logging.debug('fd %d %s', fd, |
|
|
|
eventloop.EVENT_NAMES.get(event, event)) |
|
|
|
# if sock: |
|
|
|
# logging.debug('fd %d %s', fd, |
|
|
|
# eventloop.EVENT_NAMES.get(event, event)) |
|
|
|
if sock == self._server_socket: |
|
|
|
if event & eventloop.POLL_ERR: |
|
|
|
# TODO |
|
|
|
raise Exception('server_socket error') |
|
|
|
try: |
|
|
|
logging.debug('accept') |
|
|
|
# logging.debug('accept') |
|
|
|
conn = self._server_socket.accept() |
|
|
|
TCPRelayHandler(self._fd_to_handlers, self._eventloop, |
|
|
|
TCPRelayHandler(self, self._fd_to_handlers, self._eventloop, |
|
|
|
conn[0], self._config, self._is_local) |
|
|
|
except (OSError, IOError) as e: |
|
|
|
error_no = eventloop.errno_from_exception(e) |
|
|
@ -404,10 +492,10 @@ class TCPRelay(object): |
|
|
|
else: |
|
|
|
logging.warn('poll removed fd') |
|
|
|
|
|
|
|
now = time.time() |
|
|
|
if now - self._last_time > 5: |
|
|
|
# TODO sweep timeouts |
|
|
|
self._last_time = now |
|
|
|
now = time.time() |
|
|
|
if now - self._last_time > TIMEOUT_PRECISION: |
|
|
|
self._sweep_timeout() |
|
|
|
self._last_time = now |
|
|
|
|
|
|
|
def close(self): |
|
|
|
self._closed = True |
|
|
|