Browse Source

add typing hint and reformat code

akkariiin/Experimental
Akkariiin 6 years ago
parent
commit
5c3137a327
  1. 53
      shadowsocks/tcprelay.py

53
shadowsocks/tcprelay.py

@ -28,6 +28,7 @@ import traceback
import random
import platform
import threading
import typing
from shadowsocks import encrypt, obfs, eventloop, shell, common, lru_cache, version
from shadowsocks.common import pre_parse_header, parse_header
@ -95,6 +96,7 @@ TCP_MSS = NETWORK_MTU - 40
BUF_SIZE = 32 * 1024
UDP_MAX_BUF_SIZE = 65536
class SpeedTester(object):
def __init__(self, max_speed=0):
self.max_speed = max_speed * 1024
@ -123,6 +125,7 @@ class SpeedTester(object):
return self.sum_len >= self.max_speed
return False
class TCPRelayHandler(object):
def __init__(self, server, fd_to_handlers, loop, local_sock, config,
dns_resolver, is_local):
@ -352,11 +355,14 @@ class TCPRelayHandler(object):
af = common.is_ip(dest_addr)
if af == False:
handler = common.UDPAsyncDNSHandler(data[header_length:])
handler.resolve(self._dns_resolver, (dest_addr, dest_port), self._handle_server_dns_resolved)
handler.resolve(self._dns_resolver, (dest_addr, dest_port),
self._handle_server_dns_resolved)
else:
return self._handle_server_dns_resolved("", (dest_addr, dest_port), dest_addr, data[header_length:])
return self._handle_server_dns_resolved("", (dest_addr, dest_port), dest_addr,
data[header_length:])
else:
return self._handle_server_dns_resolved("", (dest_addr, dest_port), dest_addr, data[header_length:])
return self._handle_server_dns_resolved("", (dest_addr, dest_port), dest_addr,
data[header_length:])
except Exception as e:
# trace = traceback.format_exc()
@ -409,14 +415,16 @@ class TCPRelayHandler(object):
self._data_to_write_to_remote.append(data)
self._update_stream(STREAM_UP, WAIT_STATUS_WRITING)
else:
logging.error('write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1]))
logging.error(
'write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1]))
else:
if sock == self._local_sock:
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
elif sock == self._remote_sock:
self._update_stream(STREAM_UP, WAIT_STATUS_READING)
else:
logging.error('write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1]))
logging.error(
'write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1]))
return True
def _handle_server_dns_resolved(self, error, remote_addr, server_addr, data):
@ -432,14 +440,16 @@ class TCPRelayHandler(object):
if self._udpv6_send_pack_id == 0:
addr, port = self._remote_sock_v6.getsockname()[:2]
common.connect_log('UDPv6 sendto %s(%s):%d from %s:%d by user %d' %
(common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr, port, self._user_id))
(common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr,
port, self._user_id))
self._udpv6_send_pack_id += 1
else:
self._remote_sock.sendto(data, (server_addr, remote_addr[1]))
if self._udp_send_pack_id == 0:
addr, port = self._remote_sock.getsockname()[:2]
common.connect_log('UDP sendto %s(%s):%d from %s:%d by user %d' %
(common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr, port, self._user_id))
(common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr,
port, self._user_id))
self._udp_send_pack_id += 1
return True
except Exception as e:
@ -517,7 +527,8 @@ class TCPRelayHandler(object):
return ("0.0.0.0", 0)
def _handel_protocol_error(self, client_address, ogn_data):
logging.warn("Protocol ERROR, TCP ogn data %s from %s:%d via port %d by UID %d" % (binascii.hexlify(ogn_data), client_address[0], client_address[1], self._server._listen_port, self._user_id))
logging.warn("Protocol ERROR, TCP ogn data %s from %s:%d via port %d by UID %d" % (
binascii.hexlify(ogn_data), client_address[0], client_address[1], self._server._listen_port, self._user_id))
self._encrypt_correct = False
# create redirect or disconnect by hash code
host, port = self._get_redirect_host(client_address, ogn_data)
@ -709,14 +720,16 @@ class TCPRelayHandler(object):
if common.to_str(sa[0]) in self._forbidden_iplist:
if self._remote_address:
raise Exception('IP %s is in forbidden list, when connect to %s:%d via port %d by UID %d' %
(common.to_str(sa[0]), self._remote_address[0], self._remote_address[1], self._server._listen_port, self._user_id))
(common.to_str(sa[0]), self._remote_address[0], self._remote_address[1],
self._server._listen_port, self._user_id))
raise Exception('IP %s is in forbidden list, reject' %
common.to_str(sa[0]))
if self._forbidden_portset:
if sa[1] in self._forbidden_portset:
if self._remote_address:
raise Exception('Port %d is in forbidden list, when connect to %s:%d via port %d by UID %d' %
(sa[1], self._remote_address[0], self._remote_address[1], self._server._listen_port, self._user_id))
(sa[1], self._remote_address[0], self._remote_address[1],
self._server._listen_port, self._user_id))
raise Exception('Port %d is in forbidden list, reject' % sa[1])
remote_sock = socket.socket(af, socktype, proto)
self._remote_sock = remote_sock
@ -790,7 +803,8 @@ class TCPRelayHandler(object):
raise e
addr, port = self._remote_sock.getsockname()[:2]
common.connect_log('TCP connecting %s(%s):%d from %s:%d by user %d' %
(common.to_str(self._remote_address[0]), common.to_str(remote_addr), remote_port, addr, port, self._user_id))
(common.to_str(self._remote_address[0]), common.to_str(remote_addr),
remote_port, addr, port, self._user_id))
self._loop.add(remote_sock,
eventloop.POLL_ERR | eventloop.POLL_OUT,
@ -858,7 +872,8 @@ class TCPRelayHandler(object):
try:
obfs_decode = self._obfs.server_decode(data)
if self._stage == STAGE_INIT:
self._overhead = self._obfs.get_overhead(self._is_local) + self._protocol.get_overhead(self._is_local)
self._overhead = self._obfs.get_overhead(self._is_local) + self._protocol.get_overhead(
self._is_local)
server_info = self._protocol.get_server_info()
server_info.overhead = self._overhead
except Exception as e:
@ -896,7 +911,8 @@ class TCPRelayHandler(object):
shell.print_exception(e)
if self._config['verbose']:
traceback.print_exc()
logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1]))
logging.error(
"exception from %s:%d" % (self._client_address[0], self._client_address[1]))
self.destroy()
return
except Exception as e:
@ -1037,9 +1053,11 @@ class TCPRelayHandler(object):
if err.errno not in [errno.ECONNRESET]:
logging.error(err)
if self._remote_address:
logging.error("remote error, when connect to %s:%d" % (self._remote_address[0], self._remote_address[1]))
logging.error(
"remote error, when connect to %s:%d" % (self._remote_address[0], self._remote_address[1]))
else:
logging.error("remote error, exception from %s:%d" % (self._client_address[0], self._client_address[1]))
logging.error(
"remote error, exception from %s:%d" % (self._client_address[0], self._client_address[1]))
self.destroy()
def handle_event(self, sock, fd, event):
@ -1175,6 +1193,7 @@ class TCPRelayHandler(object):
# gc.collect()
# logging.debug("gc %s" % (gc.garbage,))
class TCPRelay(object):
def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_counter=None):
self._config = config
@ -1185,7 +1204,7 @@ class TCPRelay(object):
self._fd_to_handlers = {}
self.server_transfer_ul = 0
self.server_transfer_dl = 0
self.server_users = {}
self.server_users: typing.Dict[bytes, bytes] = {}
self.server_users_cfg = {}
self.server_user_transfer_ul = {}
self.server_user_transfer_dl = {}
@ -1292,7 +1311,7 @@ class TCPRelay(object):
uid = struct.pack('<I', id)
self.add_user(uid, users[id])
def add_user(self, uid, cfg): # user: binstr[4], passwd: str
def add_user(self, uid: bytes, cfg): # user: binstr[4], passwd: str
passwd = cfg['password']
self.server_users[uid] = common.to_bytes(passwd)
self.server_users_cfg[uid] = cfg

Loading…
Cancel
Save