|  |  | @ -30,6 +30,10 @@ import encrypt | 
			
		
	
		
			
				
					|  |  |  | import eventloop | 
			
		
	
		
			
				
					|  |  |  | from common import parse_header | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | TIMEOUTS_CLEAN_SIZE = 512 | 
			
		
	
		
			
				
					|  |  |  | TIMEOUT_PRECISION = 4 | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | CMD_CONNECT = 1 | 
			
		
	
		
			
				
					|  |  |  | CMD_BIND = 2 | 
			
		
	
		
			
				
					|  |  |  | CMD_UDP_ASSOCIATE = 3 | 
			
		
	
	
		
			
				
					|  |  | @ -66,7 +70,9 @@ BUF_SIZE = 8 * 1024 | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | class TCPRelayHandler(object): | 
			
		
	
		
			
				
					|  |  |  |     def __init__(self, fd_to_handlers, loop, local_sock, config, is_local): | 
			
		
	
		
			
				
					|  |  |  |     def __init__(self, server, fd_to_handlers, loop, local_sock, config, | 
			
		
	
		
			
				
					|  |  |  |                  is_local): | 
			
		
	
		
			
				
					|  |  |  |         self._server = server | 
			
		
	
		
			
				
					|  |  |  |         self._fd_to_handlers = fd_to_handlers | 
			
		
	
		
			
				
					|  |  |  |         self._loop = loop | 
			
		
	
		
			
				
					|  |  |  |         self._local_sock = local_sock | 
			
		
	
	
		
			
				
					|  |  | @ -80,10 +86,25 @@ class TCPRelayHandler(object): | 
			
		
	
		
			
				
					|  |  |  |         self._data_to_write_to_remote = [] | 
			
		
	
		
			
				
					|  |  |  |         self._upstream_status = WAIT_STATUS_READING | 
			
		
	
		
			
				
					|  |  |  |         self._downstream_status = WAIT_STATUS_INIT | 
			
		
	
		
			
				
					|  |  |  |         self._remote_address = None | 
			
		
	
		
			
				
					|  |  |  |         fd_to_handlers[local_sock.fileno()] = self | 
			
		
	
		
			
				
					|  |  |  |         local_sock.setblocking(False) | 
			
		
	
		
			
				
					|  |  |  |         local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) | 
			
		
	
		
			
				
					|  |  |  |         loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR) | 
			
		
	
		
			
				
					|  |  |  |         self.last_activity = 0 | 
			
		
	
		
			
				
					|  |  |  |         self.update_activity() | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def __hash__(self): | 
			
		
	
		
			
				
					|  |  |  |         # default __hash__ is id / 16 | 
			
		
	
		
			
				
					|  |  |  |         # we want to eliminate collisions | 
			
		
	
		
			
				
					|  |  |  |         return id(self) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     @property | 
			
		
	
		
			
				
					|  |  |  |     def remote_address(self): | 
			
		
	
		
			
				
					|  |  |  |         return self._remote_address | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def update_activity(self): | 
			
		
	
		
			
				
					|  |  |  |         self._server.update_activity(self) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def update_stream(self, stream, status): | 
			
		
	
		
			
				
					|  |  |  |         dirty = False | 
			
		
	
	
		
			
				
					|  |  | @ -146,7 +167,7 @@ class TCPRelayHandler(object): | 
			
		
	
		
			
				
					|  |  |  |                 logging.error('write_all_to_sock:unknown socket') | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def on_local_read(self): | 
			
		
	
		
			
				
					|  |  |  |         # TODO update timeout | 
			
		
	
		
			
				
					|  |  |  |         self.update_activity() | 
			
		
	
		
			
				
					|  |  |  |         if not self._local_sock: | 
			
		
	
		
			
				
					|  |  |  |             return | 
			
		
	
		
			
				
					|  |  |  |         is_local = self._is_local | 
			
		
	
	
		
			
				
					|  |  | @ -211,6 +232,7 @@ class TCPRelayHandler(object): | 
			
		
	
		
			
				
					|  |  |  |                 addrtype, remote_addr, remote_port, header_length =\ | 
			
		
	
		
			
				
					|  |  |  |                     header_result | 
			
		
	
		
			
				
					|  |  |  |                 logging.debug('connecting %s:%d' % (remote_addr, remote_port)) | 
			
		
	
		
			
				
					|  |  |  |                 self._remote_address = (remote_addr, remote_port) | 
			
		
	
		
			
				
					|  |  |  |                 if is_local: | 
			
		
	
		
			
				
					|  |  |  |                     # forward address to remote | 
			
		
	
		
			
				
					|  |  |  |                     self.write_to_sock('\x05\x00\x00\x01' + | 
			
		
	
	
		
			
				
					|  |  | @ -257,7 +279,7 @@ class TCPRelayHandler(object): | 
			
		
	
		
			
				
					|  |  |  |                 self.destroy() | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def on_remote_read(self): | 
			
		
	
		
			
				
					|  |  |  |         # TODO update timeout | 
			
		
	
		
			
				
					|  |  |  |         self.update_activity() | 
			
		
	
		
			
				
					|  |  |  |         data = None | 
			
		
	
		
			
				
					|  |  |  |         try: | 
			
		
	
		
			
				
					|  |  |  |             data = self._remote_sock.recv(BUF_SIZE) | 
			
		
	
	
		
			
				
					|  |  | @ -325,6 +347,10 @@ class TCPRelayHandler(object): | 
			
		
	
		
			
				
					|  |  |  |             logging.warn('unknown socket') | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def destroy(self): | 
			
		
	
		
			
				
					|  |  |  |         if self._remote_address: | 
			
		
	
		
			
				
					|  |  |  |             logging.debug('destroy: %s:%d' % | 
			
		
	
		
			
				
					|  |  |  |                           self._remote_address) | 
			
		
	
		
			
				
					|  |  |  |         else: | 
			
		
	
		
			
				
					|  |  |  |             logging.debug('destroy') | 
			
		
	
		
			
				
					|  |  |  |         if self._remote_sock: | 
			
		
	
		
			
				
					|  |  |  |             self._loop.remove(self._remote_sock) | 
			
		
	
	
		
			
				
					|  |  | @ -336,6 +362,7 @@ class TCPRelayHandler(object): | 
			
		
	
		
			
				
					|  |  |  |             del self._fd_to_handlers[self._local_sock.fileno()] | 
			
		
	
		
			
				
					|  |  |  |             self._local_sock.close() | 
			
		
	
		
			
				
					|  |  |  |             self._local_sock = None | 
			
		
	
		
			
				
					|  |  |  |         self._server.remove_handler(self) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | class TCPRelay(object): | 
			
		
	
	
		
			
				
					|  |  | @ -347,6 +374,12 @@ class TCPRelay(object): | 
			
		
	
		
			
				
					|  |  |  |         self._fd_to_handlers = {} | 
			
		
	
		
			
				
					|  |  |  |         self._last_time = time.time() | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         self._timeout = config['timeout'] | 
			
		
	
		
			
				
					|  |  |  |         self._timeouts = []  # a list for all the handlers | 
			
		
	
		
			
				
					|  |  |  |         self._timeout_offset = 0  # last checked position for timeout | 
			
		
	
		
			
				
					|  |  |  |                                   # we trim the timeouts once a while | 
			
		
	
		
			
				
					|  |  |  |         self._handler_to_timeouts = {}  # key: handler value: index in timeouts | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         if is_local: | 
			
		
	
		
			
				
					|  |  |  |             listen_addr = config['local_address'] | 
			
		
	
		
			
				
					|  |  |  |             listen_port = config['local_port'] | 
			
		
	
	
		
			
				
					|  |  | @ -376,19 +409,74 @@ class TCPRelay(object): | 
			
		
	
		
			
				
					|  |  |  |         self._eventloop.add(self._server_socket, | 
			
		
	
		
			
				
					|  |  |  |                             eventloop.POLL_IN | eventloop.POLL_ERR) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def remove_handler(self, handler): | 
			
		
	
		
			
				
					|  |  |  |         index = self._handler_to_timeouts.get(hash(handler), -1) | 
			
		
	
		
			
				
					|  |  |  |         if index >= 0: | 
			
		
	
		
			
				
					|  |  |  |             # delete is O(n), so we just set it to None | 
			
		
	
		
			
				
					|  |  |  |             self._timeouts[index] = None | 
			
		
	
		
			
				
					|  |  |  |             del self._handler_to_timeouts[hash(handler)] | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def update_activity(self, handler): | 
			
		
	
		
			
				
					|  |  |  |         """ set handler to active """ | 
			
		
	
		
			
				
					|  |  |  |         now = int(time.time()) | 
			
		
	
		
			
				
					|  |  |  |         if now - handler.last_activity < TIMEOUT_PRECISION: | 
			
		
	
		
			
				
					|  |  |  |             # thus we can lower timeout modification frequency | 
			
		
	
		
			
				
					|  |  |  |             return | 
			
		
	
		
			
				
					|  |  |  |         handler.last_activity = now | 
			
		
	
		
			
				
					|  |  |  |         index = self._handler_to_timeouts.get(hash(handler), -1) | 
			
		
	
		
			
				
					|  |  |  |         if index >= 0: | 
			
		
	
		
			
				
					|  |  |  |             # delete is O(n), so we just set it to None | 
			
		
	
		
			
				
					|  |  |  |             self._timeouts[index] = None | 
			
		
	
		
			
				
					|  |  |  |         length = len(self._timeouts) | 
			
		
	
		
			
				
					|  |  |  |         self._timeouts.append(handler) | 
			
		
	
		
			
				
					|  |  |  |         self._handler_to_timeouts[hash(handler)] = length | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def _sweep_timeout(self): | 
			
		
	
		
			
				
					|  |  |  |         # tornado's timeout memory management is more flexible that we need | 
			
		
	
		
			
				
					|  |  |  |         # we just need a sorted last_activity queue and it's faster that heapq | 
			
		
	
		
			
				
					|  |  |  |         # in fact we can do O(1) insertion/remove so we invent our own | 
			
		
	
		
			
				
					|  |  |  |         if self._timeouts: | 
			
		
	
		
			
				
					|  |  |  |             now = time.time() | 
			
		
	
		
			
				
					|  |  |  |             length = len(self._timeouts) | 
			
		
	
		
			
				
					|  |  |  |             pos = self._timeout_offset | 
			
		
	
		
			
				
					|  |  |  |             while pos < length: | 
			
		
	
		
			
				
					|  |  |  |                 handler = self._timeouts[pos] | 
			
		
	
		
			
				
					|  |  |  |                 if handler: | 
			
		
	
		
			
				
					|  |  |  |                     if now - handler.last_activity < self._timeout: | 
			
		
	
		
			
				
					|  |  |  |                         break | 
			
		
	
		
			
				
					|  |  |  |                     else: | 
			
		
	
		
			
				
					|  |  |  |                         if handler.remote_address: | 
			
		
	
		
			
				
					|  |  |  |                             logging.warn('timed out: %s:%d' % | 
			
		
	
		
			
				
					|  |  |  |                                          handler.remote_address) | 
			
		
	
		
			
				
					|  |  |  |                         else: | 
			
		
	
		
			
				
					|  |  |  |                             logging.warn('timed out') | 
			
		
	
		
			
				
					|  |  |  |                         handler.destroy() | 
			
		
	
		
			
				
					|  |  |  |                         self._timeouts[pos] = None  # free memory | 
			
		
	
		
			
				
					|  |  |  |                         pos += 1 | 
			
		
	
		
			
				
					|  |  |  |                 else: | 
			
		
	
		
			
				
					|  |  |  |                     pos += 1 | 
			
		
	
		
			
				
					|  |  |  |             if pos > TIMEOUTS_CLEAN_SIZE and pos > length >> 1: | 
			
		
	
		
			
				
					|  |  |  |                 # clean up the timeout queue when it gets larger than half | 
			
		
	
		
			
				
					|  |  |  |                 # of the queue | 
			
		
	
		
			
				
					|  |  |  |                 self._timeouts = self._timeouts[pos:] | 
			
		
	
		
			
				
					|  |  |  |                 for key in self._handler_to_timeouts: | 
			
		
	
		
			
				
					|  |  |  |                     self._handler_to_timeouts[key] -= pos | 
			
		
	
		
			
				
					|  |  |  |                 pos = 0 | 
			
		
	
		
			
				
					|  |  |  |             self._timeout_offset = pos | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def _handle_events(self, events): | 
			
		
	
		
			
				
					|  |  |  |         for sock, fd, event in events: | 
			
		
	
		
			
				
					|  |  |  |             if sock: | 
			
		
	
		
			
				
					|  |  |  |                 logging.debug('fd %d %s', fd, | 
			
		
	
		
			
				
					|  |  |  |                               eventloop.EVENT_NAMES.get(event, event)) | 
			
		
	
		
			
				
					|  |  |  |             # if sock: | 
			
		
	
		
			
				
					|  |  |  |             #     logging.debug('fd %d %s', fd, | 
			
		
	
		
			
				
					|  |  |  |             #                   eventloop.EVENT_NAMES.get(event, event)) | 
			
		
	
		
			
				
					|  |  |  |             if sock == self._server_socket: | 
			
		
	
		
			
				
					|  |  |  |                 if event & eventloop.POLL_ERR: | 
			
		
	
		
			
				
					|  |  |  |                     # TODO | 
			
		
	
		
			
				
					|  |  |  |                     raise Exception('server_socket error') | 
			
		
	
		
			
				
					|  |  |  |                 try: | 
			
		
	
		
			
				
					|  |  |  |                     logging.debug('accept') | 
			
		
	
		
			
				
					|  |  |  |                     # logging.debug('accept') | 
			
		
	
		
			
				
					|  |  |  |                     conn = self._server_socket.accept() | 
			
		
	
		
			
				
					|  |  |  |                     TCPRelayHandler(self._fd_to_handlers, self._eventloop, | 
			
		
	
		
			
				
					|  |  |  |                     TCPRelayHandler(self, self._fd_to_handlers, self._eventloop, | 
			
		
	
		
			
				
					|  |  |  |                                     conn[0], self._config, self._is_local) | 
			
		
	
		
			
				
					|  |  |  |                 except (OSError, IOError) as e: | 
			
		
	
		
			
				
					|  |  |  |                     error_no = eventloop.errno_from_exception(e) | 
			
		
	
	
		
			
				
					|  |  | @ -405,8 +493,8 @@ class TCPRelay(object): | 
			
		
	
		
			
				
					|  |  |  |                     logging.warn('poll removed fd') | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         now = time.time() | 
			
		
	
		
			
				
					|  |  |  |             if now - self._last_time > 5: | 
			
		
	
		
			
				
					|  |  |  |                 # TODO sweep timeouts | 
			
		
	
		
			
				
					|  |  |  |         if now - self._last_time > TIMEOUT_PRECISION: | 
			
		
	
		
			
				
					|  |  |  |             self._sweep_timeout() | 
			
		
	
		
			
				
					|  |  |  |             self._last_time = now | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def close(self): | 
			
		
	
	
		
			
				
					|  |  | 
 |