Browse Source

add typing hint and reformat code

akkariiin/Experimental
Akkariiin 6 years ago
parent
commit
5c3137a327
  1. 111
      shadowsocks/tcprelay.py

111
shadowsocks/tcprelay.py

@ -28,6 +28,7 @@ import traceback
import random import random
import platform import platform
import threading import threading
import typing
from shadowsocks import encrypt, obfs, eventloop, shell, common, lru_cache, version from shadowsocks import encrypt, obfs, eventloop, shell, common, lru_cache, version
from shadowsocks.common import pre_parse_header, parse_header from shadowsocks.common import pre_parse_header, parse_header
@ -95,8 +96,9 @@ TCP_MSS = NETWORK_MTU - 40
BUF_SIZE = 32 * 1024 BUF_SIZE = 32 * 1024
UDP_MAX_BUF_SIZE = 65536 UDP_MAX_BUF_SIZE = 65536
class SpeedTester(object): class SpeedTester(object):
def __init__(self, max_speed = 0): def __init__(self, max_speed=0):
self.max_speed = max_speed * 1024 self.max_speed = max_speed * 1024
self.last_time = time.time() self.last_time = time.time()
self.sum_len = 0 self.sum_len = 0
@ -123,6 +125,7 @@ class SpeedTester(object):
return self.sum_len >= self.max_speed return self.sum_len >= self.max_speed
return False return False
class TCPRelayHandler(object): class TCPRelayHandler(object):
def __init__(self, server, fd_to_handlers, loop, local_sock, config, def __init__(self, server, fd_to_handlers, loop, local_sock, config,
dns_resolver, is_local): dns_resolver, is_local):
@ -160,8 +163,8 @@ class TCPRelayHandler(object):
server_info = obfs.server_info(server.obfs_data) server_info = obfs.server_info(server.obfs_data)
server_info.host = config['server'] server_info.host = config['server']
server_info.port = server._listen_port server_info.port = server._listen_port
#server_info.users = server.server_users # server_info.users = server.server_users
#server_info.update_user_func = self._update_user # server_info.update_user_func = self._update_user
server_info.client = self._client_address[0] server_info.client = self._client_address[0]
server_info.client_port = self._client_address[1] server_info.client_port = self._client_address[1]
server_info.protocol_param = '' server_info.protocol_param = ''
@ -328,7 +331,7 @@ class TCPRelayHandler(object):
if self._remote_udp and sock == self._remote_sock: if self._remote_udp and sock == self._remote_sock:
try: try:
self._udp_data_send_buffer += data self._udp_data_send_buffer += data
#logging.info('UDP over TCP sendto %d %s' % (len(data), binascii.hexlify(data))) # logging.info('UDP over TCP sendto %d %s' % (len(data), binascii.hexlify(data)))
while len(self._udp_data_send_buffer) > 6: while len(self._udp_data_send_buffer) > 6:
length = struct.unpack('>H', self._udp_data_send_buffer[:2])[0] length = struct.unpack('>H', self._udp_data_send_buffer[:2])[0]
@ -352,15 +355,18 @@ class TCPRelayHandler(object):
af = common.is_ip(dest_addr) af = common.is_ip(dest_addr)
if af == False: if af == False:
handler = common.UDPAsyncDNSHandler(data[header_length:]) handler = common.UDPAsyncDNSHandler(data[header_length:])
handler.resolve(self._dns_resolver, (dest_addr, dest_port), self._handle_server_dns_resolved) handler.resolve(self._dns_resolver, (dest_addr, dest_port),
self._handle_server_dns_resolved)
else: else:
return self._handle_server_dns_resolved("", (dest_addr, dest_port), dest_addr, data[header_length:]) return self._handle_server_dns_resolved("", (dest_addr, dest_port), dest_addr,
data[header_length:])
else: else:
return self._handle_server_dns_resolved("", (dest_addr, dest_port), dest_addr, data[header_length:]) return self._handle_server_dns_resolved("", (dest_addr, dest_port), dest_addr,
data[header_length:])
except Exception as e: except Exception as e:
#trace = traceback.format_exc() # trace = traceback.format_exc()
#logging.error(trace) # logging.error(trace)
error_no = eventloop.errno_from_exception(e) error_no = eventloop.errno_from_exception(e)
if error_no in (errno.EAGAIN, errno.EINPROGRESS, if error_no in (errno.EAGAIN, errno.EINPROGRESS,
errno.EWOULDBLOCK): errno.EWOULDBLOCK):
@ -391,7 +397,7 @@ class TCPRelayHandler(object):
errno.EWOULDBLOCK): errno.EWOULDBLOCK):
uncomplete = True uncomplete = True
else: else:
#traceback.print_exc() # traceback.print_exc()
shell.print_exception(e) shell.print_exception(e)
logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1]))
self.destroy() self.destroy()
@ -409,14 +415,16 @@ class TCPRelayHandler(object):
self._data_to_write_to_remote.append(data) self._data_to_write_to_remote.append(data)
self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) self._update_stream(STREAM_UP, WAIT_STATUS_WRITING)
else: else:
logging.error('write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1])) logging.error(
'write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1]))
else: else:
if sock == self._local_sock: if sock == self._local_sock:
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
elif sock == self._remote_sock: elif sock == self._remote_sock:
self._update_stream(STREAM_UP, WAIT_STATUS_READING) self._update_stream(STREAM_UP, WAIT_STATUS_READING)
else: else:
logging.error('write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1])) logging.error(
'write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1]))
return True return True
def _handle_server_dns_resolved(self, error, remote_addr, server_addr, data): def _handle_server_dns_resolved(self, error, remote_addr, server_addr, data):
@ -424,7 +432,7 @@ class TCPRelayHandler(object):
return return
try: try:
addrs = socket.getaddrinfo(server_addr, remote_addr[1], 0, socket.SOCK_DGRAM, socket.SOL_UDP) addrs = socket.getaddrinfo(server_addr, remote_addr[1], 0, socket.SOCK_DGRAM, socket.SOL_UDP)
if not addrs: # drop if not addrs: # drop
return return
af, socktype, proto, canonname, sa = addrs[0] af, socktype, proto, canonname, sa = addrs[0]
if af == socket.AF_INET6: if af == socket.AF_INET6:
@ -432,14 +440,16 @@ class TCPRelayHandler(object):
if self._udpv6_send_pack_id == 0: if self._udpv6_send_pack_id == 0:
addr, port = self._remote_sock_v6.getsockname()[:2] addr, port = self._remote_sock_v6.getsockname()[:2]
common.connect_log('UDPv6 sendto %s(%s):%d from %s:%d by user %d' % common.connect_log('UDPv6 sendto %s(%s):%d from %s:%d by user %d' %
(common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr, port, self._user_id)) (common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr,
port, self._user_id))
self._udpv6_send_pack_id += 1 self._udpv6_send_pack_id += 1
else: else:
self._remote_sock.sendto(data, (server_addr, remote_addr[1])) self._remote_sock.sendto(data, (server_addr, remote_addr[1]))
if self._udp_send_pack_id == 0: if self._udp_send_pack_id == 0:
addr, port = self._remote_sock.getsockname()[:2] addr, port = self._remote_sock.getsockname()[:2]
common.connect_log('UDP sendto %s(%s):%d from %s:%d by user %d' % common.connect_log('UDP sendto %s(%s):%d from %s:%d by user %d' %
(common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr, port, self._user_id)) (common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr,
port, self._user_id))
self._udp_send_pack_id += 1 self._udp_send_pack_id += 1
return True return True
except Exception as e: except Exception as e:
@ -517,9 +527,10 @@ class TCPRelayHandler(object):
return ("0.0.0.0", 0) return ("0.0.0.0", 0)
def _handel_protocol_error(self, client_address, ogn_data): def _handel_protocol_error(self, client_address, ogn_data):
logging.warn("Protocol ERROR, TCP ogn data %s from %s:%d via port %d by UID %d" % (binascii.hexlify(ogn_data), client_address[0], client_address[1], self._server._listen_port, self._user_id)) logging.warn("Protocol ERROR, TCP ogn data %s from %s:%d via port %d by UID %d" % (
binascii.hexlify(ogn_data), client_address[0], client_address[1], self._server._listen_port, self._user_id))
self._encrypt_correct = False self._encrypt_correct = False
#create redirect or disconnect by hash code # create redirect or disconnect by hash code
host, port = self._get_redirect_host(client_address, ogn_data) host, port = self._get_redirect_host(client_address, ogn_data)
if port == 0: if port == 0:
raise Exception('can not parse header') raise Exception('can not parse header')
@ -635,11 +646,11 @@ class TCPRelayHandler(object):
connecttype, addrtype, remote_addr, remote_port, header_length = header_result connecttype, addrtype, remote_addr, remote_port, header_length = header_result
if connecttype != 0: if connecttype != 0:
pass pass
#common.connect_log('UDP over TCP by user %d' % # common.connect_log('UDP over TCP by user %d' %
# (self._user_id, )) # (self._user_id, ))
else: else:
common.connect_log('TCP request %s:%d by user %d' % common.connect_log('TCP request %s:%d by user %d' %
(common.to_str(remote_addr), remote_port, self._user_id)) (common.to_str(remote_addr), remote_port, self._user_id))
self._remote_address = (common.to_str(remote_addr), remote_port) self._remote_address = (common.to_str(remote_addr), remote_port)
self._remote_udp = (connecttype != 0) self._remote_udp = (connecttype != 0)
# pause reading # pause reading
@ -709,14 +720,16 @@ class TCPRelayHandler(object):
if common.to_str(sa[0]) in self._forbidden_iplist: if common.to_str(sa[0]) in self._forbidden_iplist:
if self._remote_address: if self._remote_address:
raise Exception('IP %s is in forbidden list, when connect to %s:%d via port %d by UID %d' % raise Exception('IP %s is in forbidden list, when connect to %s:%d via port %d by UID %d' %
(common.to_str(sa[0]), self._remote_address[0], self._remote_address[1], self._server._listen_port, self._user_id)) (common.to_str(sa[0]), self._remote_address[0], self._remote_address[1],
self._server._listen_port, self._user_id))
raise Exception('IP %s is in forbidden list, reject' % raise Exception('IP %s is in forbidden list, reject' %
common.to_str(sa[0])) common.to_str(sa[0]))
if self._forbidden_portset: if self._forbidden_portset:
if sa[1] in self._forbidden_portset: if sa[1] in self._forbidden_portset:
if self._remote_address: if self._remote_address:
raise Exception('Port %d is in forbidden list, when connect to %s:%d via port %d by UID %d' % raise Exception('Port %d is in forbidden list, when connect to %s:%d via port %d by UID %d' %
(sa[1], self._remote_address[0], self._remote_address[1], self._server._listen_port, self._user_id)) (sa[1], self._remote_address[0], self._remote_address[1],
self._server._listen_port, self._user_id))
raise Exception('Port %d is in forbidden list, reject' % sa[1]) raise Exception('Port %d is in forbidden list, reject' % sa[1])
remote_sock = socket.socket(af, socktype, proto) remote_sock = socket.socket(af, socktype, proto)
self._remote_sock = remote_sock self._remote_sock = remote_sock
@ -777,24 +790,25 @@ class TCPRelayHandler(object):
self._server) self._server)
if self._remote_sock_v6: if self._remote_sock_v6:
self._loop.add(self._remote_sock_v6, self._loop.add(self._remote_sock_v6,
eventloop.POLL_IN, eventloop.POLL_IN,
self._server) self._server)
else: else:
try: try:
remote_sock.connect((remote_addr, remote_port)) remote_sock.connect((remote_addr, remote_port))
except (OSError, IOError) as e: except (OSError, IOError) as e:
if eventloop.errno_from_exception(e) in (errno.EINPROGRESS, if eventloop.errno_from_exception(e) in (errno.EINPROGRESS,
errno.EWOULDBLOCK): errno.EWOULDBLOCK):
pass # always goto here pass # always goto here
else: else:
raise e raise e
addr, port = self._remote_sock.getsockname()[:2] addr, port = self._remote_sock.getsockname()[:2]
common.connect_log('TCP connecting %s(%s):%d from %s:%d by user %d' % common.connect_log('TCP connecting %s(%s):%d from %s:%d by user %d' %
(common.to_str(self._remote_address[0]), common.to_str(remote_addr), remote_port, addr, port, self._user_id)) (common.to_str(self._remote_address[0]), common.to_str(remote_addr),
remote_port, addr, port, self._user_id))
self._loop.add(remote_sock, self._loop.add(remote_sock,
eventloop.POLL_ERR | eventloop.POLL_OUT, eventloop.POLL_ERR | eventloop.POLL_OUT,
self._server) self._server)
self._stage = STAGE_CONNECTING self._stage = STAGE_CONNECTING
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
@ -858,7 +872,8 @@ class TCPRelayHandler(object):
try: try:
obfs_decode = self._obfs.server_decode(data) obfs_decode = self._obfs.server_decode(data)
if self._stage == STAGE_INIT: if self._stage == STAGE_INIT:
self._overhead = self._obfs.get_overhead(self._is_local) + self._protocol.get_overhead(self._is_local) self._overhead = self._obfs.get_overhead(self._is_local) + self._protocol.get_overhead(
self._is_local)
server_info = self._protocol.get_server_info() server_info = self._protocol.get_server_info()
server_info.overhead = self._overhead server_info.overhead = self._overhead
except Exception as e: except Exception as e:
@ -896,7 +911,8 @@ class TCPRelayHandler(object):
shell.print_exception(e) shell.print_exception(e)
if self._config['verbose']: if self._config['verbose']:
traceback.print_exc() traceback.print_exc()
logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) logging.error(
"exception from %s:%d" % (self._client_address[0], self._client_address[1]))
self.destroy() self.destroy()
return return
except Exception as e: except Exception as e:
@ -943,7 +959,7 @@ class TCPRelayHandler(object):
data = b'\x00\x04' + ip + port + data data = b'\x00\x04' + ip + port + data
size = len(data) + 2 size = len(data) + 2
data = struct.pack('>H', size) + data data = struct.pack('>H', size) + data
#logging.info('UDP over TCP recvfrom %s:%d %d bytes to %s:%d' % (addr[0], addr[1], len(data), self._client_address[0], self._client_address[1])) # logging.info('UDP over TCP recvfrom %s:%d %d bytes to %s:%d' % (addr[0], addr[1], len(data), self._client_address[0], self._client_address[1]))
else: else:
if self._is_local: if self._is_local:
recv_buffer_size = BUF_SIZE recv_buffer_size = BUF_SIZE
@ -953,7 +969,7 @@ class TCPRelayHandler(object):
self._recv_pack_id += 1 self._recv_pack_id += 1
except (OSError, IOError) as e: except (OSError, IOError) as e:
if eventloop.errno_from_exception(e) in \ if eventloop.errno_from_exception(e) in \
(errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): #errno.WSAEWOULDBLOCK (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): # errno.WSAEWOULDBLOCK
return return
if not data: if not data:
self.destroy() self.destroy()
@ -1037,9 +1053,11 @@ class TCPRelayHandler(object):
if err.errno not in [errno.ECONNRESET]: if err.errno not in [errno.ECONNRESET]:
logging.error(err) logging.error(err)
if self._remote_address: if self._remote_address:
logging.error("remote error, when connect to %s:%d" % (self._remote_address[0], self._remote_address[1])) logging.error(
"remote error, when connect to %s:%d" % (self._remote_address[0], self._remote_address[1]))
else: else:
logging.error("remote error, exception from %s:%d" % (self._client_address[0], self._client_address[1])) logging.error(
"remote error, exception from %s:%d" % (self._client_address[0], self._client_address[1]))
self.destroy() self.destroy()
def handle_event(self, sock, fd, event): def handle_event(self, sock, fd, event):
@ -1171,9 +1189,10 @@ class TCPRelayHandler(object):
self._server.add_connection(-1) self._server.add_connection(-1)
self._server.stat_add(self._client_address[0], -1) self._server.stat_add(self._client_address[0], -1)
#import gc # import gc
#gc.collect() # gc.collect()
#logging.debug("gc %s" % (gc.garbage,)) # logging.debug("gc %s" % (gc.garbage,))
class TCPRelay(object): class TCPRelay(object):
def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_counter=None): def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_counter=None):
@ -1185,7 +1204,7 @@ class TCPRelay(object):
self._fd_to_handlers = {} self._fd_to_handlers = {}
self.server_transfer_ul = 0 self.server_transfer_ul = 0
self.server_transfer_dl = 0 self.server_transfer_dl = 0
self.server_users = {} self.server_users: typing.Dict[bytes, bytes] = {}
self.server_users_cfg = {} self.server_users_cfg = {}
self.server_user_transfer_ul = {} self.server_user_transfer_ul = {}
self.server_user_transfer_dl = {} self.server_user_transfer_dl = {}
@ -1201,7 +1220,7 @@ class TCPRelay(object):
self._timeout = config['timeout'] self._timeout = config['timeout']
self._timeout_cache = lru_cache.LRUCache(timeout=self._timeout, self._timeout_cache = lru_cache.LRUCache(timeout=self._timeout,
close_callback=self._close_tcp_client) close_callback=self._close_tcp_client)
if is_local: if is_local:
listen_addr = config['local_address'] listen_addr = config['local_address']
@ -1277,7 +1296,7 @@ class TCPRelay(object):
self.del_user(uid) self.del_user(uid)
else: else:
passwd = items[1] passwd = items[1]
self.add_user(uid, {'password':passwd}) self.add_user(uid, {'password': passwd})
def _update_user(self, id, passwd): def _update_user(self, id, passwd):
uid = struct.pack('<I', id) uid = struct.pack('<I', id)
@ -1292,7 +1311,7 @@ class TCPRelay(object):
uid = struct.pack('<I', id) uid = struct.pack('<I', id)
self.add_user(uid, users[id]) self.add_user(uid, users[id])
def add_user(self, uid, cfg): # user: binstr[4], passwd: str def add_user(self, uid: bytes, cfg): # user: binstr[4], passwd: str
passwd = cfg['password'] passwd = cfg['password']
self.server_users[uid] = common.to_bytes(passwd) self.server_users[uid] = common.to_bytes(passwd)
self.server_users_cfg[uid] = cfg self.server_users_cfg[uid] = cfg
@ -1332,7 +1351,7 @@ class TCPRelay(object):
def speed_tester_u(self, uid): def speed_tester_u(self, uid):
if uid not in self._speed_tester_u: if uid not in self._speed_tester_u:
if self.mu: #TODO if self.mu: # TODO
self._speed_tester_u[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0)) self._speed_tester_u[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0))
else: else:
self._speed_tester_u[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0)) self._speed_tester_u[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0))
@ -1340,7 +1359,7 @@ class TCPRelay(object):
def speed_tester_d(self, uid): def speed_tester_d(self, uid):
if uid not in self._speed_tester_d: if uid not in self._speed_tester_d:
if self.mu: #TODO if self.mu: # TODO
self._speed_tester_d[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0)) self._speed_tester_d[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0))
else: else:
self._speed_tester_d[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0)) self._speed_tester_d[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0))
@ -1400,7 +1419,7 @@ class TCPRelay(object):
def _close_tcp_client(self, client): def _close_tcp_client(self, client):
if client.remote_address: if client.remote_address:
logging.debug('timed out: %s:%d' % logging.debug('timed out: %s:%d' %
client.remote_address) client.remote_address)
else: else:
logging.debug('timed out') logging.debug('timed out')
client.destroy() client.destroy()
@ -1421,8 +1440,8 @@ class TCPRelay(object):
logging.debug('accept') logging.debug('accept')
conn = self._server_socket.accept() conn = self._server_socket.accept()
handler = TCPRelayHandler(self, self._fd_to_handlers, handler = TCPRelayHandler(self, self._fd_to_handlers,
self._eventloop, conn[0], self._config, self._eventloop, conn[0], self._config,
self._dns_resolver, self._is_local) self._dns_resolver, self._is_local)
if handler.stage() == STAGE_DESTROYED: if handler.stage() == STAGE_DESTROYED:
conn[0].close() conn[0].close()
except (OSError, IOError) as e: except (OSError, IOError) as e:

Loading…
Cancel
Save