Browse Source

add typing hint and reformat code

akkariiin/Experimental
Akkariiin 6 years ago
parent
commit
5c3137a327
  1. 87
      shadowsocks/tcprelay.py

87
shadowsocks/tcprelay.py

@ -28,6 +28,7 @@ import traceback
import random
import platform
import threading
import typing
from shadowsocks import encrypt, obfs, eventloop, shell, common, lru_cache, version
from shadowsocks.common import pre_parse_header, parse_header
@ -95,8 +96,9 @@ TCP_MSS = NETWORK_MTU - 40
BUF_SIZE = 32 * 1024
UDP_MAX_BUF_SIZE = 65536
class SpeedTester(object):
def __init__(self, max_speed = 0):
def __init__(self, max_speed=0):
self.max_speed = max_speed * 1024
self.last_time = time.time()
self.sum_len = 0
@ -123,6 +125,7 @@ class SpeedTester(object):
return self.sum_len >= self.max_speed
return False
class TCPRelayHandler(object):
def __init__(self, server, fd_to_handlers, loop, local_sock, config,
dns_resolver, is_local):
@ -160,8 +163,8 @@ class TCPRelayHandler(object):
server_info = obfs.server_info(server.obfs_data)
server_info.host = config['server']
server_info.port = server._listen_port
#server_info.users = server.server_users
#server_info.update_user_func = self._update_user
# server_info.users = server.server_users
# server_info.update_user_func = self._update_user
server_info.client = self._client_address[0]
server_info.client_port = self._client_address[1]
server_info.protocol_param = ''
@ -328,7 +331,7 @@ class TCPRelayHandler(object):
if self._remote_udp and sock == self._remote_sock:
try:
self._udp_data_send_buffer += data
#logging.info('UDP over TCP sendto %d %s' % (len(data), binascii.hexlify(data)))
# logging.info('UDP over TCP sendto %d %s' % (len(data), binascii.hexlify(data)))
while len(self._udp_data_send_buffer) > 6:
length = struct.unpack('>H', self._udp_data_send_buffer[:2])[0]
@ -352,15 +355,18 @@ class TCPRelayHandler(object):
af = common.is_ip(dest_addr)
if af == False:
handler = common.UDPAsyncDNSHandler(data[header_length:])
handler.resolve(self._dns_resolver, (dest_addr, dest_port), self._handle_server_dns_resolved)
handler.resolve(self._dns_resolver, (dest_addr, dest_port),
self._handle_server_dns_resolved)
else:
return self._handle_server_dns_resolved("", (dest_addr, dest_port), dest_addr, data[header_length:])
return self._handle_server_dns_resolved("", (dest_addr, dest_port), dest_addr,
data[header_length:])
else:
return self._handle_server_dns_resolved("", (dest_addr, dest_port), dest_addr, data[header_length:])
return self._handle_server_dns_resolved("", (dest_addr, dest_port), dest_addr,
data[header_length:])
except Exception as e:
#trace = traceback.format_exc()
#logging.error(trace)
# trace = traceback.format_exc()
# logging.error(trace)
error_no = eventloop.errno_from_exception(e)
if error_no in (errno.EAGAIN, errno.EINPROGRESS,
errno.EWOULDBLOCK):
@ -391,7 +397,7 @@ class TCPRelayHandler(object):
errno.EWOULDBLOCK):
uncomplete = True
else:
#traceback.print_exc()
# traceback.print_exc()
shell.print_exception(e)
logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1]))
self.destroy()
@ -409,14 +415,16 @@ class TCPRelayHandler(object):
self._data_to_write_to_remote.append(data)
self._update_stream(STREAM_UP, WAIT_STATUS_WRITING)
else:
logging.error('write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1]))
logging.error(
'write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1]))
else:
if sock == self._local_sock:
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
elif sock == self._remote_sock:
self._update_stream(STREAM_UP, WAIT_STATUS_READING)
else:
logging.error('write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1]))
logging.error(
'write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1]))
return True
def _handle_server_dns_resolved(self, error, remote_addr, server_addr, data):
@ -432,14 +440,16 @@ class TCPRelayHandler(object):
if self._udpv6_send_pack_id == 0:
addr, port = self._remote_sock_v6.getsockname()[:2]
common.connect_log('UDPv6 sendto %s(%s):%d from %s:%d by user %d' %
(common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr, port, self._user_id))
(common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr,
port, self._user_id))
self._udpv6_send_pack_id += 1
else:
self._remote_sock.sendto(data, (server_addr, remote_addr[1]))
if self._udp_send_pack_id == 0:
addr, port = self._remote_sock.getsockname()[:2]
common.connect_log('UDP sendto %s(%s):%d from %s:%d by user %d' %
(common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr, port, self._user_id))
(common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr,
port, self._user_id))
self._udp_send_pack_id += 1
return True
except Exception as e:
@ -517,9 +527,10 @@ class TCPRelayHandler(object):
return ("0.0.0.0", 0)
def _handel_protocol_error(self, client_address, ogn_data):
logging.warn("Protocol ERROR, TCP ogn data %s from %s:%d via port %d by UID %d" % (binascii.hexlify(ogn_data), client_address[0], client_address[1], self._server._listen_port, self._user_id))
logging.warn("Protocol ERROR, TCP ogn data %s from %s:%d via port %d by UID %d" % (
binascii.hexlify(ogn_data), client_address[0], client_address[1], self._server._listen_port, self._user_id))
self._encrypt_correct = False
#create redirect or disconnect by hash code
# create redirect or disconnect by hash code
host, port = self._get_redirect_host(client_address, ogn_data)
if port == 0:
raise Exception('can not parse header')
@ -635,7 +646,7 @@ class TCPRelayHandler(object):
connecttype, addrtype, remote_addr, remote_port, header_length = header_result
if connecttype != 0:
pass
#common.connect_log('UDP over TCP by user %d' %
# common.connect_log('UDP over TCP by user %d' %
# (self._user_id, ))
else:
common.connect_log('TCP request %s:%d by user %d' %
@ -709,14 +720,16 @@ class TCPRelayHandler(object):
if common.to_str(sa[0]) in self._forbidden_iplist:
if self._remote_address:
raise Exception('IP %s is in forbidden list, when connect to %s:%d via port %d by UID %d' %
(common.to_str(sa[0]), self._remote_address[0], self._remote_address[1], self._server._listen_port, self._user_id))
(common.to_str(sa[0]), self._remote_address[0], self._remote_address[1],
self._server._listen_port, self._user_id))
raise Exception('IP %s is in forbidden list, reject' %
common.to_str(sa[0]))
if self._forbidden_portset:
if sa[1] in self._forbidden_portset:
if self._remote_address:
raise Exception('Port %d is in forbidden list, when connect to %s:%d via port %d by UID %d' %
(sa[1], self._remote_address[0], self._remote_address[1], self._server._listen_port, self._user_id))
(sa[1], self._remote_address[0], self._remote_address[1],
self._server._listen_port, self._user_id))
raise Exception('Port %d is in forbidden list, reject' % sa[1])
remote_sock = socket.socket(af, socktype, proto)
self._remote_sock = remote_sock
@ -790,7 +803,8 @@ class TCPRelayHandler(object):
raise e
addr, port = self._remote_sock.getsockname()[:2]
common.connect_log('TCP connecting %s(%s):%d from %s:%d by user %d' %
(common.to_str(self._remote_address[0]), common.to_str(remote_addr), remote_port, addr, port, self._user_id))
(common.to_str(self._remote_address[0]), common.to_str(remote_addr),
remote_port, addr, port, self._user_id))
self._loop.add(remote_sock,
eventloop.POLL_ERR | eventloop.POLL_OUT,
@ -858,7 +872,8 @@ class TCPRelayHandler(object):
try:
obfs_decode = self._obfs.server_decode(data)
if self._stage == STAGE_INIT:
self._overhead = self._obfs.get_overhead(self._is_local) + self._protocol.get_overhead(self._is_local)
self._overhead = self._obfs.get_overhead(self._is_local) + self._protocol.get_overhead(
self._is_local)
server_info = self._protocol.get_server_info()
server_info.overhead = self._overhead
except Exception as e:
@ -896,7 +911,8 @@ class TCPRelayHandler(object):
shell.print_exception(e)
if self._config['verbose']:
traceback.print_exc()
logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1]))
logging.error(
"exception from %s:%d" % (self._client_address[0], self._client_address[1]))
self.destroy()
return
except Exception as e:
@ -943,7 +959,7 @@ class TCPRelayHandler(object):
data = b'\x00\x04' + ip + port + data
size = len(data) + 2
data = struct.pack('>H', size) + data
#logging.info('UDP over TCP recvfrom %s:%d %d bytes to %s:%d' % (addr[0], addr[1], len(data), self._client_address[0], self._client_address[1]))
# logging.info('UDP over TCP recvfrom %s:%d %d bytes to %s:%d' % (addr[0], addr[1], len(data), self._client_address[0], self._client_address[1]))
else:
if self._is_local:
recv_buffer_size = BUF_SIZE
@ -953,7 +969,7 @@ class TCPRelayHandler(object):
self._recv_pack_id += 1
except (OSError, IOError) as e:
if eventloop.errno_from_exception(e) in \
(errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): #errno.WSAEWOULDBLOCK
(errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): # errno.WSAEWOULDBLOCK
return
if not data:
self.destroy()
@ -1037,9 +1053,11 @@ class TCPRelayHandler(object):
if err.errno not in [errno.ECONNRESET]:
logging.error(err)
if self._remote_address:
logging.error("remote error, when connect to %s:%d" % (self._remote_address[0], self._remote_address[1]))
logging.error(
"remote error, when connect to %s:%d" % (self._remote_address[0], self._remote_address[1]))
else:
logging.error("remote error, exception from %s:%d" % (self._client_address[0], self._client_address[1]))
logging.error(
"remote error, exception from %s:%d" % (self._client_address[0], self._client_address[1]))
self.destroy()
def handle_event(self, sock, fd, event):
@ -1171,9 +1189,10 @@ class TCPRelayHandler(object):
self._server.add_connection(-1)
self._server.stat_add(self._client_address[0], -1)
#import gc
#gc.collect()
#logging.debug("gc %s" % (gc.garbage,))
# import gc
# gc.collect()
# logging.debug("gc %s" % (gc.garbage,))
class TCPRelay(object):
def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_counter=None):
@ -1185,7 +1204,7 @@ class TCPRelay(object):
self._fd_to_handlers = {}
self.server_transfer_ul = 0
self.server_transfer_dl = 0
self.server_users = {}
self.server_users: typing.Dict[bytes, bytes] = {}
self.server_users_cfg = {}
self.server_user_transfer_ul = {}
self.server_user_transfer_dl = {}
@ -1277,7 +1296,7 @@ class TCPRelay(object):
self.del_user(uid)
else:
passwd = items[1]
self.add_user(uid, {'password':passwd})
self.add_user(uid, {'password': passwd})
def _update_user(self, id, passwd):
uid = struct.pack('<I', id)
@ -1292,7 +1311,7 @@ class TCPRelay(object):
uid = struct.pack('<I', id)
self.add_user(uid, users[id])
def add_user(self, uid, cfg): # user: binstr[4], passwd: str
def add_user(self, uid: bytes, cfg): # user: binstr[4], passwd: str
passwd = cfg['password']
self.server_users[uid] = common.to_bytes(passwd)
self.server_users_cfg[uid] = cfg
@ -1332,7 +1351,7 @@ class TCPRelay(object):
def speed_tester_u(self, uid):
if uid not in self._speed_tester_u:
if self.mu: #TODO
if self.mu: # TODO
self._speed_tester_u[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0))
else:
self._speed_tester_u[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0))
@ -1340,7 +1359,7 @@ class TCPRelay(object):
def speed_tester_d(self, uid):
if uid not in self._speed_tester_d:
if self.mu: #TODO
if self.mu: # TODO
self._speed_tester_d[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0))
else:
self._speed_tester_d[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0))

Loading…
Cancel
Save