From 774134ffadb51de2c3bca86b619df2a54be2987f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A0=B4=E5=A8=83=E9=85=B1?= Date: Tue, 25 Apr 2017 10:27:51 +0800 Subject: [PATCH] remove auth_sha1, auth_sha1_v2, verify_simple, auth_simple, verify_sha1 remove TCP over UDP add UDP part of auth_chain_a --- db_transfer.py | 4 +- shadowsocks/obfs.py | 7 +- shadowsocks/obfsplugin/auth.py | 450 +------------- shadowsocks/obfsplugin/auth_chain.py | 84 ++- shadowsocks/obfsplugin/plain.py | 2 +- shadowsocks/obfsplugin/verify.py | 206 ------- shadowsocks/tcprelay.py | 2 +- shadowsocks/udprelay.py | 837 +-------------------------- 8 files changed, 87 insertions(+), 1505 deletions(-) diff --git a/db_transfer.py b/db_transfer.py index 80d1d87..a5fb862 100644 --- a/db_transfer.py +++ b/db_transfer.py @@ -6,7 +6,7 @@ import time import sys from server_pool import ServerPool import traceback -from shadowsocks import common, shell, lru_cache +from shadowsocks import common, shell, lru_cache, obfs from configloader import load_config, get_config import importloader @@ -123,7 +123,7 @@ class TransferBase(object): if allow: allow_users[port] = passwd - if 'protocol' in cfg and 'protocol_param' in cfg and common.to_str(cfg['protocol']) in ['auth_aes128_md5', 'auth_aes128_sha1']: + if 'protocol' in cfg and 'protocol_param' in cfg and common.to_str(cfg['protocol']) in obfs.mu_protocol(): if '#' in common.to_str(cfg['protocol_param']): mu_servers[port] = passwd del allow_users[port] diff --git a/shadowsocks/obfs.py b/shadowsocks/obfs.py index 46867d8..3dfdb14 100644 --- a/shadowsocks/obfs.py +++ b/shadowsocks/obfs.py @@ -34,6 +34,9 @@ method_supported.update(verify.obfs_map) method_supported.update(auth.obfs_map) method_supported.update(auth_chain.obfs_map) +def mu_protocol(): + return ["auth_aes128_md5", "auth_aes128_sha1", "auth_chain_a"] + class server_info(object): def __init__(self, data): self.data = data @@ -99,8 +102,8 @@ class obfs(object): def client_udp_post_decrypt(self, buf): return self.obfs.client_udp_post_decrypt(buf) - def server_udp_pre_encrypt(self, buf): - return self.obfs.server_udp_pre_encrypt(buf) + def server_udp_pre_encrypt(self, buf, uid): + return self.obfs.server_udp_pre_encrypt(buf, uid) def server_udp_post_decrypt(self, buf): return self.obfs.server_udp_post_decrypt(buf) diff --git a/shadowsocks/obfsplugin/auth.py b/shadowsocks/obfsplugin/auth.py index 13c4411..51a1b00 100755 --- a/shadowsocks/obfsplugin/auth.py +++ b/shadowsocks/obfsplugin/auth.py @@ -37,12 +37,6 @@ from shadowsocks import common, lru_cache, encrypt from shadowsocks.obfsplugin import plain from shadowsocks.common import to_bytes, to_str, ord, chr -def create_auth_sha1(method): - return auth_sha1(method) - -def create_auth_sha1_v2(method): - return auth_sha1_v2(method) - def create_auth_sha1_v4(method): return auth_sha1_v4(method) @@ -53,10 +47,6 @@ def create_auth_aes128_sha1(method): return auth_aes128_sha1(method, hashlib.sha1) obfs_map = { - 'auth_sha1': (create_auth_sha1,), - 'auth_sha1_compatible': (create_auth_sha1,), - 'auth_sha1_v2': (create_auth_sha1_v2,), - 'auth_sha1_v2_compatible': (create_auth_sha1_v2,), 'auth_sha1_v4': (create_auth_sha1_v4,), 'auth_sha1_v4_compatible': (create_auth_sha1_v4,), 'auth_aes128_md5': (create_auth_aes128_md5,), @@ -69,10 +59,6 @@ def match_begin(str1, str2): return True return False -class obfs_verify_data(object): - def __init__(self): - pass - class auth_base(plain.plain): def __init__(self, method): super(auth_base, self).__init__(method) @@ -153,239 +139,6 @@ class client_queue(object): self.front += 1 return True -class obfs_auth_data(object): - def __init__(self): - self.client_id = {} - self.startup_time = int(time.time() - 30) & 0xFFFFFFFF - self.local_client_id = b'' - self.connection_id = 0 - self.set_max_client(64) # max active client count - - def update(self, client_id, connection_id): - if client_id in self.client_id: - self.client_id[client_id].update() - - def set_max_client(self, max_client): - self.max_client = max_client - self.max_buffer = max(self.max_client * 2, 256) - - def insert(self, client_id, connection_id): - if client_id not in self.client_id or not self.client_id[client_id].enable: - active = 0 - for c_id in self.client_id: - if self.client_id[c_id].is_active(): - active += 1 - if active >= self.max_client: - logging.warn('obfs auth: max active clients exceeded') - return False - - if len(self.client_id) < self.max_client: - if client_id not in self.client_id: - self.client_id[client_id] = client_queue(connection_id) - else: - self.client_id[client_id].re_enable(connection_id) - return self.client_id[client_id].insert(connection_id) - keys = self.client_id.keys() - random.shuffle(keys) - for c_id in keys: - if not self.client_id[c_id].is_active() and self.client_id[c_id].enable: - if len(self.client_id) >= self.max_buffer: - del self.client_id[c_id] - else: - self.client_id[c_id].enable = False - if client_id not in self.client_id: - self.client_id[client_id] = client_queue(connection_id) - else: - self.client_id[client_id].re_enable(connection_id) - return self.client_id[client_id].insert(connection_id) - logging.warn('obfs auth: no inactive client [assert]') - return False - else: - return self.client_id[client_id].insert(connection_id) - -class auth_sha1(auth_base): - def __init__(self, method): - super(auth_sha1, self).__init__(method) - self.recv_buf = b'' - self.unit_len = 8000 - self.decrypt_packet_num = 0 - self.raw_trans = False - self.has_sent_header = False - self.has_recv_header = False - self.client_id = 0 - self.connection_id = 0 - self.max_time_dif = 60 * 60 # time dif (second) setting - self.no_compatible_method = 'auth_sha1' - - def init_data(self): - return obfs_auth_data() - - def set_server_info(self, server_info): - self.server_info = server_info - try: - max_client = int(server_info.protocol_param) - except: - max_client = 64 - self.server_info.data.set_max_client(max_client) - - def pack_data(self, buf): - rnd_data = os.urandom(common.ord(os.urandom(1)[0]) % 16) - data = common.chr(len(rnd_data) + 1) + rnd_data + buf - data = struct.pack('>H', len(data) + 6) + data - adler32 = zlib.adler32(data) & 0xFFFFFFFF - data += struct.pack('H', len(data) + 16) + data - crc = binascii.crc32(self.server_info.key) & 0xFFFFFFFF - data = struct.pack(' 0xFF000000: - self.server_info.data.local_client_id = b'' - if not self.server_info.data.local_client_id: - self.server_info.data.local_client_id = os.urandom(4) - logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),)) - self.server_info.data.connection_id = struct.unpack(' self.unit_len: - ret += self.pack_data(buf[:self.unit_len]) - buf = buf[self.unit_len:] - ret += self.pack_data(buf) - return ret - - def client_post_decrypt(self, buf): - if self.raw_trans: - return buf - self.recv_buf += buf - out_buf = b'' - while len(self.recv_buf) > 2: - length = struct.unpack('>H', self.recv_buf[:2])[0] - if length >= 8192 or length < 7: - self.raw_trans = True - self.recv_buf = b'' - raise Exception('client_post_decrypt data error') - if length > len(self.recv_buf): - break - - if struct.pack(' self.unit_len: - ret += self.pack_data(buf[:self.unit_len]) - buf = buf[self.unit_len:] - ret += self.pack_data(buf) - return ret - - def server_post_decrypt(self, buf): - if self.raw_trans: - return (buf, False) - self.recv_buf += buf - out_buf = b'' - if not self.has_recv_header: - if len(self.recv_buf) < 6: - return (b'', False) - crc = struct.pack('H', self.recv_buf[4:6])[0] - if length > 2048: - return self.not_match_return(self.recv_buf) - if length > len(self.recv_buf): - return (b'', False) - sha1data = hmac.new(self.server_info.recv_iv + self.server_info.key, self.recv_buf[:length - 10], hashlib.sha1).digest()[:10] - if sha1data != self.recv_buf[length - 10:length]: - logging.error('auth_sha1 data uncorrect auth HMAC-SHA1') - return self.not_match_return(self.recv_buf) - pos = common.ord(self.recv_buf[6]) + 6 - out_buf = self.recv_buf[pos:length - 10] - if len(out_buf) < 12: - logging.info('auth_sha1: too short, data %s' % (binascii.hexlify(self.recv_buf),)) - return self.not_match_return(self.recv_buf) - utc_time = struct.unpack(' self.max_time_dif \ - or common.int32(utc_time - self.server_info.data.startup_time) < -self.max_time_dif / 2: - logging.info('auth_sha1: wrong timestamp, time_dif %d, data %s' % (time_dif, binascii.hexlify(out_buf),)) - return self.not_match_return(self.recv_buf) - elif self.server_info.data.insert(client_id, connection_id): - self.has_recv_header = True - out_buf = out_buf[12:] - self.client_id = client_id - self.connection_id = connection_id - else: - logging.info('auth_sha1: auth fail, data %s' % (binascii.hexlify(out_buf),)) - return self.not_match_return(self.recv_buf) - self.recv_buf = self.recv_buf[length:] - self.has_recv_header = True - - while len(self.recv_buf) > 2: - length = struct.unpack('>H', self.recv_buf[:2])[0] - if length >= 8192 or length < 7: - self.raw_trans = True - self.recv_buf = b'' - if self.decrypt_packet_num == 0: - logging.info('auth_sha1: over size') - return (b'E'*2048, False) - else: - raise Exception('server_post_decrype data error') - if length > len(self.recv_buf): - break - - if struct.pack(' 1300: - return b'\x01' - - if buf_size > 400: - rnd_data = os.urandom(common.ord(os.urandom(1)[0]) % 128) - return common.chr(len(rnd_data) + 1) + rnd_data - - rnd_data = os.urandom(struct.unpack('>H', os.urandom(2))[0] % 1024) - return common.chr(255) + struct.pack('>H', len(rnd_data) + 3) + rnd_data - - def pack_data(self, buf): - data = self.rnd_data(len(buf)) + buf - data = struct.pack('>H', len(data) + 6) + data - adler32 = zlib.adler32(data) & 0xFFFFFFFF - data += struct.pack('H', len(data) + 16) + data - crc = binascii.crc32(self.salt + self.server_info.key) & 0xFFFFFFFF - data = struct.pack(' 0xFF000000: - self.server_info.data.local_client_id = b'' - if not self.server_info.data.local_client_id: - self.server_info.data.local_client_id = os.urandom(8) - logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),)) - self.server_info.data.connection_id = struct.unpack(' self.unit_len: - ret += self.pack_data(buf[:self.unit_len]) - buf = buf[self.unit_len:] - ret += self.pack_data(buf) - return ret - - def client_post_decrypt(self, buf): - if self.raw_trans: - return buf - self.recv_buf += buf - out_buf = b'' - while len(self.recv_buf) > 2: - length = struct.unpack('>H', self.recv_buf[:2])[0] - if length >= 8192 or length < 7: - self.raw_trans = True - self.recv_buf = b'' - raise Exception('client_post_decrypt data error') - if length > len(self.recv_buf): - break - - if struct.pack('H', self.recv_buf[3:5])[0] + 2 - out_buf += self.recv_buf[pos:length - 4] - self.recv_buf = self.recv_buf[length:] - - if out_buf: - self.decrypt_packet_num += 1 - return out_buf - - def server_pre_encrypt(self, buf): - if self.raw_trans: - return buf - ret = b'' - while len(buf) > self.unit_len: - ret += self.pack_data(buf[:self.unit_len]) - buf = buf[self.unit_len:] - ret += self.pack_data(buf) - return ret - - def server_post_decrypt(self, buf): - if self.raw_trans: - return (buf, False) - self.recv_buf += buf - out_buf = b'' - sendback = False - - if not self.has_recv_header: - if len(self.recv_buf) < 6: - return (b'', False) - crc = struct.pack('H', self.recv_buf[4:6])[0] - if length > 2048: - return self.not_match_return(self.recv_buf) - if length > len(self.recv_buf): - return (b'', False) - sha1data = hmac.new(self.server_info.recv_iv + self.server_info.key, self.recv_buf[:length - 10], hashlib.sha1).digest()[:10] - if sha1data != self.recv_buf[length - 10:length]: - logging.error('auth_sha1_v2 data uncorrect auth HMAC-SHA1') - return self.not_match_return(self.recv_buf) - pos = common.ord(self.recv_buf[6]) - if pos < 255: - pos += 6 - else: - pos = struct.unpack('>H', self.recv_buf[7:9])[0] + 6 - out_buf = self.recv_buf[pos:length - 10] - if len(out_buf) < 12: - logging.info('auth_sha1_v2: too short, data %s' % (binascii.hexlify(self.recv_buf),)) - return self.not_match_return(self.recv_buf) - client_id = struct.unpack(' 2: - length = struct.unpack('>H', self.recv_buf[:2])[0] - if length >= 8192 or length < 7: - self.raw_trans = True - self.recv_buf = b'' - if self.decrypt_packet_num == 0: - logging.info('auth_sha1_v2: over size') - return (b'E'*2048, False) - else: - raise Exception('server_post_decrype data error') - if length > len(self.recv_buf): - break - - if struct.pack('H', self.recv_buf[3:5])[0] + 2 - out_buf += self.recv_buf[pos:length - 4] - self.recv_buf = self.recv_buf[length:] - if pos == length - 4: - sendback = True - - if out_buf: - self.server_info.data.update(self.client_id, self.connection_id) - self.decrypt_packet_num += 1 - return (out_buf, sendback) - class auth_sha1_v4(auth_base): def __init__(self, method): super(auth_sha1_v4, self).__init__(method) @@ -1215,7 +767,7 @@ class auth_aes128_sha1(auth_base): return b'' return buf[:-4] - def server_udp_pre_encrypt(self, buf): + def server_udp_pre_encrypt(self, buf, uid): user_key = self.server_info.key return buf + hmac.new(user_key, buf, self.hashfunc).digest()[:4] diff --git a/shadowsocks/obfsplugin/auth_chain.py b/shadowsocks/obfsplugin/auth_chain.py index 2bb21da..ff0a812 100644 --- a/shadowsocks/obfsplugin/auth_chain.py +++ b/shadowsocks/obfsplugin/auth_chain.py @@ -289,6 +289,10 @@ class auth_chain_a(auth_base): return random.next() % 521 return random.next() % 1021 + def udp_rnd_data_len(self, last_hash, random): + random.init_from_bin(last_hash) + return random.next() % 127 + def rnd_start_pos(self, rand_len, random): if rand_len > 0: return random.next() % 8589934609 % rand_len @@ -457,20 +461,12 @@ class auth_chain_a(auth_base): return (b'', False) self.last_client_hash = md5data - md5data = hmac.new(mac_key, self.recv_buf[12 : 12 + 20], self.hashfunc).digest() - if md5data[:4] != self.recv_buf[32:36]: - logging.error('%s data uncorrect auth HMAC-SHA1 from %s:%d, data %s' % (self.no_compatible_method, self.server_info.client, self.server_info.client_port, binascii.hexlify(self.recv_buf))) - if len(self.recv_buf) < 36: - return (b'', False) - return self.not_match_return(self.recv_buf) - - self.last_server_hash = md5data uid = struct.unpack('H', len(data) + 6) + data - crc = (0xffffffff - binascii.crc32(data)) & 0xffffffff - data += struct.pack(' self.unit_len: - ret += self.pack_data(buf[:self.unit_len]) - buf = buf[self.unit_len:] - ret += self.pack_data(buf) - return ret - - def client_post_decrypt(self, buf): - if self.raw_trans: - return buf - self.recv_buf += buf - out_buf = b'' - while len(self.recv_buf) > 2: - length = struct.unpack('>H', self.recv_buf[:2])[0] - if length >= 8192 or length < 7: - self.raw_trans = True - self.recv_buf = b'' - raise Exception('client_post_decrypt data error') - if length > len(self.recv_buf): - break - - if (binascii.crc32(self.recv_buf[:length]) & 0xffffffff) != 0xffffffff: - self.raw_trans = True - self.recv_buf = b'' - raise Exception('client_post_decrypt data uncorrect CRC32') - - pos = common.ord(self.recv_buf[2]) + 2 - out_buf += self.recv_buf[pos:length - 4] - self.recv_buf = self.recv_buf[length:] - - if out_buf: - self.decrypt_packet_num += 1 - return out_buf - - def server_pre_encrypt(self, buf): - ret = b'' - while len(buf) > self.unit_len: - ret += self.pack_data(buf[:self.unit_len]) - buf = buf[self.unit_len:] - ret += self.pack_data(buf) - return ret - - def server_post_decrypt(self, buf): - if self.raw_trans: - return (buf, False) - self.recv_buf += buf - out_buf = b'' - while len(self.recv_buf) > 2: - length = struct.unpack('>H', self.recv_buf[:2])[0] - if length >= 8192 or length < 7: - self.raw_trans = True - self.recv_buf = b'' - if self.decrypt_packet_num == 0: - return (b'E'*2048, False) - else: - raise Exception('server_post_decrype data error') - if length > len(self.recv_buf): - break - - if (binascii.crc32(self.recv_buf[:length]) & 0xffffffff) != 0xffffffff: - self.raw_trans = True - self.recv_buf = b'' - if self.decrypt_packet_num == 0: - return (b'E'*2048, False) - else: - raise Exception('server_post_decrype data uncorrect CRC32') - - pos = common.ord(self.recv_buf[2]) + 2 - out_buf += self.recv_buf[pos:length - 4] - self.recv_buf = self.recv_buf[length:] - - if out_buf: - self.decrypt_packet_num += 1 - return (out_buf, False) - class verify_deflate(verify_base): def __init__(self, method): super(verify_deflate, self).__init__(method) @@ -258,103 +152,3 @@ class verify_deflate(verify_base): self.decrypt_packet_num += 1 return (out_buf, False) -class verify_sha1(verify_base): - def __init__(self, method): - super(verify_sha1, self).__init__(method) - self.recv_buf = b'' - self.unit_len = 8100 - self.raw_trans = False - self.pack_id = 0 - self.recv_id = 0 - self.has_sent_header = False - self.has_recv_header = False - - def pack_data(self, buf): - if len(buf) == 0: - return b'' - sha1data = hmac.new(self.server_info.iv + struct.pack('>I', self.pack_id), buf, hashlib.sha1).digest() - data = struct.pack('>H', len(buf)) + sha1data[:10] + buf - self.pack_id += 1 - return data - - def pack_auth_data(self, buf): - data = chr(ord(buf[0]) | 0x10) + buf[1:] - data += hmac.new(self.server_info.iv + self.server_info.key, data, hashlib.sha1).digest()[:10] - return data - - def client_pre_encrypt(self, buf): - ret = b'' - if not self.has_sent_header: - datalen = self.get_head_size(buf, 30) - ret += self.pack_auth_data(buf[:datalen]) - buf = buf[datalen:] - self.has_sent_header = True - while len(buf) > self.unit_len: - ret += self.pack_data(buf[:self.unit_len]) - buf = buf[self.unit_len:] - ret += self.pack_data(buf) - return ret - - def client_post_decrypt(self, buf): - return buf - - def server_pre_encrypt(self, buf): - return buf - - def not_match_return(self, buf): - self.raw_trans = True - if self.method == 'verify_sha1': - return (b'E'*2048, False) - return (buf, False) - - def server_post_decrypt(self, buf): - if self.raw_trans: - return (buf, False) - self.recv_buf += buf - out_buf = b'' - if not self.has_recv_header: - if len(self.recv_buf) < 2: - return (b'', False) - if (ord(self.recv_buf[0]) & 0x10) != 0x10: - return self.not_match_return(self.recv_buf) - head_size = self.get_head_size(self.recv_buf, 65536) - if len(self.recv_buf) < head_size + 10: - return self.not_match_return(self.recv_buf) - sha1data = hmac.new(self.server_info.recv_iv + self.server_info.key, self.recv_buf[:head_size], hashlib.sha1).digest()[:10] - if sha1data != self.recv_buf[head_size:head_size + 10]: - logging.error('server_post_decrype data uncorrect auth HMAC-SHA1') - return self.not_match_return(self.recv_buf) - out_buf = to_bytes(chr(ord(self.recv_buf[0]) & 0xEF)) + self.recv_buf[1:head_size] - self.recv_buf = self.recv_buf[head_size + 10:] - self.has_recv_header = True - while len(self.recv_buf) > 2: - length = struct.unpack('>H', self.recv_buf[:2])[0] + 12 - if length > len(self.recv_buf): - break - - data = self.recv_buf[12:length] - sha1data = hmac.new(self.server_info.recv_iv + struct.pack('>I', self.recv_id), data, hashlib.sha1).digest()[:10] - if sha1data != self.recv_buf[2:12]: - raise Exception('server_post_decrype data uncorrect chunk HMAC-SHA1') - - self.recv_id = (self.recv_id + 1) & 0xFFFFFFFF - out_buf += data - self.recv_buf = self.recv_buf[length:] - - return (out_buf, False) - - def client_udp_pre_encrypt(self, buf): - ret = self.pack_auth_data(buf) - return chr(ord(buf[0]) | 0x10) + buf[1:] - - def server_udp_post_decrypt(self, buf): - if buf and ((ord(buf[0]) & 0x10) == 0x10): - if len(buf) <= 11: - return (b'', None) - sha1data = hmac.new(self.server_info.recv_iv + self.server_info.key, buf[:-10], hashlib.sha1).digest()[:10] - if sha1data != buf[-10:]: - return (b'', None) - return (to_bytes(chr(ord(buf[0]) & 0xEF)) + buf[1:-10], None) - else: - return (buf, None) - diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 3616761..467c0f8 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -1116,7 +1116,7 @@ class TCPRelay(object): listen_port = config['server_port'] self._listen_port = listen_port - if common.to_bytes(config['protocol']) in [b"auth_aes128_md5", b"auth_aes128_sha1"]: + if common.to_str(config['protocol']) in obfs.mu_protocol(): self._update_users(None, None) addrs = socket.getaddrinfo(listen_addr, listen_port, 0, diff --git a/shadowsocks/udprelay.py b/shadowsocks/udprelay.py index 1b4aae9..98b8e83 100644 --- a/shadowsocks/udprelay.py +++ b/shadowsocks/udprelay.py @@ -123,751 +123,6 @@ RSP_STATE_ERROR = b"\x03" RSP_STATE_DISCONNECT = b"\x04" RSP_STATE_REDIRECT = b"\x05" -class UDPLocalAddress(object): - def __init__(self, addr): - self.addr = addr - self.last_activity = time.time() - - def is_timeout(self): - return time.time() - self.last_activity > 30 - -class PacketInfo(object): - def __init__(self, data): - self.data = data - self.time = time.time() - -class SendingQueue(object): - def __init__(self): - self.queue = {} - self.begin_id = 0 - self.end_id = 1 - self.interval = 0.5 - - def append(self, data): - self.queue[self.end_id] = PacketInfo(data) - self.end_id += 1 - return self.end_id - 1 - - def empty(self): - return self.begin_id + 1 == self.end_id - - def size(self): - return self.end_id - self.begin_id - 1 - - def get_begin_id(self): - return self.begin_id - - def get_end_id(self): - return self.end_id - - def get_data_list(self, pack_id_base, pack_id_list): - ret_list = [] - curtime = time.time() - for pack_id in pack_id_list: - offset = pack_id_base + pack_id - if offset <= self.begin_id or self.end_id <= offset: - continue - ret_data = self.queue[offset] - if curtime - ret_data.time > self.interval: - ret_data.time = curtime - ret_list.append( (offset, ret_data.data) ) - return ret_list - - def set_finish(self, begin_id, done_list): - while self.begin_id < begin_id: - self.begin_id += 1 - del self.queue[self.begin_id] - -class RecvQueue(object): - def __init__(self): - self.queue = {} - self.miss_queue = set() - self.begin_id = 0 - self.end_id = 1 - - def empty(self): - return self.begin_id + 1 == self.end_id - - def insert(self, pack_id, data): - if (pack_id not in self.queue) and pack_id > self.begin_id: - self.queue[pack_id] = PacketInfo(data) - if self.end_id == pack_id: - self.end_id = pack_id + 1 - elif self.end_id < pack_id: - eid = self.end_id - while eid < pack_id: - self.miss_queue.add(eid) - eid += 1 - self.end_id = pack_id + 1 - else: - self.miss_queue.remove(pack_id) - - def set_end(self, end_id): - if end_id > self.end_id: - eid = self.end_id - while eid < end_id: - self.miss_queue.add(eid) - eid += 1 - self.end_id = end_id - - def get_begin_id(self): - return self.begin_id - - def has_data(self): - return (self.begin_id + 1) in self.queue - - def get_data(self): - if (self.begin_id + 1) in self.queue: - self.begin_id += 1 - pack_id = self.begin_id - ret_data = self.queue[pack_id] - del self.queue[pack_id] - return (pack_id, ret_data.data) - - def get_missing_id(self, begin_id): - missing = [] - if begin_id == 0: - begin_id = self.begin_id - for i in self.miss_queue: - if i - begin_id > 32768: - break - missing.append(i - begin_id) - return (begin_id, missing) - -class AddressMap(object): - def __init__(self): - self._queue = [] - self._addr_map = {} - - def add(self, addr): - if addr in self._addr_map: - self._addr_map[addr] = UDPLocalAddress(addr) - else: - self._addr_map[addr] = UDPLocalAddress(addr) - self._queue.append(addr) - - def keys(self): - return self._queue - - def get(self): - if self._queue: - while True: - if len(self._queue) == 1: - return self._queue[0] - index = random.randint(0, len(self._queue) - 1) - addr = self._queue[index] - if self._addr_map[addr].is_timeout(): - self._queue[index] = self._queue[len(self._queue) - 1] - del self._queue[len(self._queue) - 1] - del self._addr_map[addr] - else: - break - return addr - else: - return None - -class TCPRelayHandler(object): - def __init__(self, server, reqid_to_handlers, fd_to_handlers, loop, - local_sock, local_id, client_param, config, - dns_resolver, is_local): - self._server = server - self._reqid_to_handlers = reqid_to_handlers - self._fd_to_handlers = fd_to_handlers - self._loop = loop - self._local_sock = local_sock - self._remote_sock = None - self._remote_udp = False - self._config = config - self._dns_resolver = dns_resolver - self._local_id = local_id - - self._is_local = is_local - self._stage = STAGE_INIT - self._password = config['password'] - self._method = config['method'] - self._fastopen_connected = False - self._data_to_write_to_local = [] - self._data_to_write_to_remote = [] - self._upstream_status = WAIT_STATUS_READING - self._downstream_status = WAIT_STATUS_INIT - self._request_id = 0 - self._client_address = AddressMap() - self._remote_address = None - self._sendingqueue = SendingQueue() - self._recvqueue = RecvQueue() - if 'forbidden_ip' in config: - self._forbidden_iplist = config['forbidden_ip'] - else: - self._forbidden_iplist = None - if 'forbidden_port' in config: - self._forbidden_portset = config['forbidden_port'] - else: - self._forbidden_portset = None - #fd_to_handlers[local_sock.fileno()] = self - #local_sock.setblocking(False) - #loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR) - self.last_activity = 0 - self._update_activity() - self._random_mtu_size = [random.randint(POST_MTU_MIN, POST_MTU_MAX) for i in range(1024)] - self._random_mtu_index = 0 - - self._rand_data = b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10" * 4 - - def __hash__(self): - # default __hash__ is id / 16 - # we want to eliminate collisions - return id(self) - - @property - def remote_address(self): - return self._remote_address - - def add_local_address(self, addr): - self._client_address.add(addr) - - def get_local_address(self): - return self._client_address.get() - - def _update_activity(self): - # tell the TCP Relay we have activities recently - # else it will think we are inactive and timed out - self._server.update_activity(self) - - def _update_stream(self, stream, status): - # update a stream to a new waiting status - - # check if status is changed - # only update if dirty - dirty = False - if stream == STREAM_DOWN: - if self._downstream_status != status: - self._downstream_status = status - dirty = True - elif stream == STREAM_UP: - if self._upstream_status != status: - self._upstream_status = status - dirty = True - if dirty: - ''' - if self._local_sock: - event = eventloop.POLL_ERR - if self._downstream_status & WAIT_STATUS_WRITING: - event |= eventloop.POLL_OUT - if self._upstream_status & WAIT_STATUS_READING: - event |= eventloop.POLL_IN - self._loop.modify(self._local_sock, event) - ''' - if self._remote_sock: - event = eventloop.POLL_ERR - if self._downstream_status & WAIT_STATUS_READING: - event |= eventloop.POLL_IN - if self._upstream_status & WAIT_STATUS_WRITING: - event |= eventloop.POLL_OUT - self._loop.modify(self._remote_sock, event) - - def _write_to_sock(self, data, sock, addr = None): - # write data to sock - # if only some of the data are written, put remaining in the buffer - # and update the stream to wait for writing - if not data or not sock: - return False - - uncomplete = False - retry = 0 - if sock == self._local_sock: - data = encrypt.encrypt_all(self._password, self._method, 1, data) - if addr is None: - return False - try: - self._server.write_to_server_socket(data, addr) - except (OSError, IOError) as e: - error_no = eventloop.errno_from_exception(e) - uncomplete = True - if error_no in (errno.EAGAIN, errno.EINPROGRESS, - errno.EWOULDBLOCK): - pass - else: - #traceback.print_exc() - shell.print_exception(e) - self.destroy() - return False - else: - try: - l = len(data) - s = sock.send(data) - if s < l: - data = data[s:] - uncomplete = True - except (OSError, IOError) as e: - error_no = eventloop.errno_from_exception(e) - if error_no in (errno.EAGAIN, errno.EINPROGRESS, - errno.EWOULDBLOCK): - uncomplete = True - else: - #logging.error(traceback.extract_stack()) - #traceback.print_exc() - shell.print_exception(e) - self.destroy() - return False - if uncomplete: - if sock == self._local_sock: - self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) - elif sock == self._remote_sock: - self._data_to_write_to_remote.append(data) - self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) - else: - logging.error('write_all_to_sock:unknown socket') - else: - if sock == self._local_sock: - if self._sendingqueue.size() > SENDING_WINDOW_SIZE: - self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) - else: - self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) - elif sock == self._remote_sock: - self._update_stream(STREAM_UP, WAIT_STATUS_READING) - else: - logging.error('write_all_to_sock:unknown socket') - return True - - def _create_remote_socket(self, ip, port): - addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM, socket.SOL_TCP) - if len(addrs) == 0: - raise Exception("getaddrinfo failed for %s:%d" % (ip, port)) - af, socktype, proto, canonname, sa = addrs[0] - if self._forbidden_iplist: - if common.to_str(sa[0]) in self._forbidden_iplist: - raise Exception('IP %s is in forbidden list, reject' % - common.to_str(sa[0])) - remote_sock = socket.socket(af, socktype, proto) - self._remote_sock = remote_sock - - self._fd_to_handlers[remote_sock.fileno()] = self - - remote_sock.setblocking(False) - remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - return remote_sock - - def _handle_dns_resolved(self, result, error): - if error: - self._log_error(error) - self.destroy() - return - if result: - ip = result[1] - if ip: - - try: - self._stage = STAGE_CONNECTING - remote_addr = ip - remote_port = self._remote_address[1] - logging.info("connect to %s : %d" % (remote_addr, remote_port)) - - remote_sock = self._create_remote_socket(remote_addr, - remote_port) - try: - remote_sock.connect((remote_addr, remote_port)) - except (OSError, IOError) as e: - if eventloop.errno_from_exception(e) in (errno.EINPROGRESS, - errno.EWOULDBLOCK): - pass # always goto here - else: - raise e - - self._loop.add(remote_sock, - eventloop.POLL_ERR | eventloop.POLL_OUT, - self._server) - self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) - self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) - self._stage = STAGE_STREAM - - addr = self.get_local_address() - - for i in range(2): - rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, RSP_STATE_CONNECTEDREMOTE) - self._write_to_sock(rsp_data, self._local_sock, addr) - - return - except Exception as e: - shell.print_exception(e) - if self._config['verbose']: - traceback.print_exc() - self.destroy() - - def _on_local_read(self): - # handle all local read events and dispatch them to methods for - # each stage - self._update_activity() - if not self._local_sock: - return - data = None - try: - data = self._local_sock.recv(BUF_SIZE) - except (OSError, IOError) as e: - if eventloop.errno_from_exception(e) in \ - (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK): - return - if not data: - self.destroy() - return - if not data: - return - self._server.server_transfer_ul += len(data) - #TODO ============================================================ - if self._stage == STAGE_STREAM: - self._write_to_sock(data, self._remote_sock) - return - - def _on_remote_read(self): - # handle all remote read events - self._update_activity() - data = None - try: - data = self._remote_sock.recv(BUF_SIZE) - except (OSError, IOError) as e: - if eventloop.errno_from_exception(e) in \ - (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): #errno.WSAEWOULDBLOCK - return - if not data: - self.destroy() - return - try: - self._server.server_transfer_dl += len(data) - recv_data = data - beg_pos = 0 - max_len = len(recv_data) - while beg_pos < max_len: - if beg_pos + POST_MTU_MAX >= max_len: - split_pos = max_len - else: - split_pos = beg_pos + self._random_mtu_size[self._random_mtu_index] - self._random_mtu_index = (self._random_mtu_index + 1) & 0x3ff - #split_pos = beg_pos + random.randint(POST_MTU_MIN, POST_MTU_MAX) - data = recv_data[beg_pos:split_pos] - beg_pos = split_pos - - pack_id = self._sendingqueue.append(data) - post_data = self._pack_post_data(CMD_POST, pack_id, data) - addr = self.get_local_address() - self._write_to_sock(post_data, self._local_sock, addr) - if pack_id <= DOUBLE_SEND_BEG_IDS: - post_data = self._pack_post_data(CMD_POST, pack_id, data) - self._write_to_sock(post_data, self._local_sock, addr) - - except Exception as e: - shell.print_exception(e) - if self._config['verbose']: - traceback.print_exc() - # TODO use logging when debug completed - self.destroy() - - def _on_local_write(self): - # handle local writable event - if self._data_to_write_to_local: - data = b''.join(self._data_to_write_to_local) - self._data_to_write_to_local = [] - self._write_to_sock(data, self._local_sock) - else: - self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) - - def _on_remote_write(self): - # handle remote writable event - self._stage = STAGE_STREAM - if self._data_to_write_to_remote: - data = b''.join(self._data_to_write_to_remote) - self._data_to_write_to_remote = [] - self._write_to_sock(data, self._remote_sock) - else: - self._update_stream(STREAM_UP, WAIT_STATUS_READING) - - def _on_local_error(self): - logging.debug('got local error') - if self._local_sock: - logging.error(eventloop.get_sock_error(self._local_sock)) - self.destroy() - - def _on_remote_error(self): - logging.debug('got remote error') - if self._remote_sock: - logging.error(eventloop.get_sock_error(self._remote_sock)) - self.destroy() - - def _pack_rsp_data(self, cmd, data): - reqid_str = struct.pack(">H", self._request_id) - return b''.join([CMD_VER_STR, common.chr(cmd), reqid_str, data, self._rand_data[:random.randint(0, len(self._rand_data))], reqid_str]) - - def _pack_rnd_data(self, data): - length = random.randint(0, len(self._rand_data)) - if length == 0: - return data - elif length == 1: - return b"\x81" + data - elif length < 256: - return b"\x80" + common.chr(length) + self._rand_data[:length - 2] + data - else: - return b"\x82" + struct.pack(">H", length) + self._rand_data[:length - 3] + data - - def _pack_post_data(self, cmd, pack_id, data): - reqid_str = struct.pack(">H", self._request_id) - recv_id = self._recvqueue.get_begin_id() - rsp_data = b''.join([CMD_VER_STR, common.chr(cmd), reqid_str, struct.pack(">I", recv_id), struct.pack(">I", pack_id), data, reqid_str]) - return rsp_data - - def _pack_post_data_64(self, cmd, send_id, pack_id, data): - reqid_str = struct.pack(">H", self._request_id) - recv_id = self._recvqueue.get_begin_id() - rsp_data = b''.join([CMD_VER_STR, common.chr(cmd), reqid_str, struct.pack(">Q", recv_id), struct.pack(">Q", pack_id), data, reqid_str]) - return rsp_data - - def sweep_timeout(self): - logging.info("sweep_timeout") - if self._stage == STAGE_STREAM: - pack_id, missing = self._recvqueue.get_missing_id(0) - logging.info("sweep_timeout %s %s" % (pack_id, missing)) - data = b'' - for pid in missing: - data += struct.pack(">H", pid) - rsp_data = self._pack_post_data(CMD_SYN_STATUS, pack_id, data) - addr = self.get_local_address() - self._write_to_sock(rsp_data, self._local_sock, addr) - - def handle_stream_sync_status(self, addr, cmd, request_id, pack_id, max_send_id, data): - missing_list = [] - while len(data) >= 2: - pid = struct.unpack(">H", data[0:2])[0] - data = data[2:] - missing_list.append(pid) - done_list = [] - self._recvqueue.set_end(max_send_id) - self._sendingqueue.set_finish(pack_id, done_list) - - if self._stage == STAGE_DESTROYED and self._sendingqueue.empty(): - self.destroy_local() - return - - # post CMD_SYN_STATUS - send_id = self._sendingqueue.get_end_id() - post_pack_id, missing = self._recvqueue.get_missing_id(0) - pack_ids_data = b'' - for pid in missing: - pack_ids_data += struct.pack(">H", pid) - - rsp_data = self._pack_rnd_data(self._pack_post_data(CMD_SYN_STATUS, send_id, pack_ids_data)) - self._write_to_sock(rsp_data, self._local_sock, addr) - - send_list = self._sendingqueue.get_data_list(pack_id, missing_list) - for post_pack_id, post_data in send_list: - rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data) - self._write_to_sock(rsp_data, self._local_sock, addr) - if post_pack_id <= DOUBLE_SEND_BEG_IDS: - rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data) - self._write_to_sock(rsp_data, self._local_sock, addr) - - def handle_client(self, addr, cmd, request_id, data): - self.add_local_address(addr) - if cmd == CMD_DISCONNECT: - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) - self._write_to_sock(rsp_data, self._local_sock, addr) - self.destroy() - self.destroy_local() - return - if self._stage == STAGE_INIT: - if cmd == CMD_CONNECT: - self._request_id = request_id - self._stage = STAGE_RSP_ID - return - if self._request_id != request_id: - return - - if self._stage == STAGE_RSP_ID: - if cmd == CMD_CONNECT: - for i in range(2): - rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, RSP_STATE_CONNECTED) - self._write_to_sock(rsp_data, self._local_sock, addr) - elif cmd == CMD_CONNECT_REMOTE: - local_id = data[0:4] - if self._local_id == local_id: - data = data[4:] - header_result = parse_header(data) - if header_result is None: - return - connecttype, remote_addr, remote_port, header_length = header_result - self._remote_address = (common.to_str(remote_addr), remote_port) - self._stage = STAGE_DNS - self._dns_resolver.resolve(remote_addr, - self._handle_dns_resolved) - common.connect_log('TCPonUDP connect %s:%d from %s:%d' % (remote_addr, remote_port, addr[0], addr[1])) - else: - # ileagal request - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) - self._write_to_sock(rsp_data, self._local_sock, addr) - elif self._stage == STAGE_CONNECTING: - if cmd == CMD_CONNECT_REMOTE: - local_id = data[0:4] - if self._local_id == local_id: - for i in range(2): - rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, RSP_STATE_CONNECTEDREMOTE) - self._write_to_sock(rsp_data, self._local_sock, addr) - else: - # ileagal request - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) - self._write_to_sock(rsp_data, self._local_sock, addr) - elif self._stage == STAGE_STREAM: - if len(data) < 4: - # ileagal request - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) - self._write_to_sock(rsp_data, self._local_sock, addr) - return - local_id = data[0:4] - if self._local_id != local_id: - # ileagal request - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) - self._write_to_sock(rsp_data, self._local_sock, addr) - return - else: - data = data[4:] - if cmd == CMD_CONNECT_REMOTE: - rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, RSP_STATE_CONNECTEDREMOTE) - self._write_to_sock(rsp_data, self._local_sock, addr) - elif cmd == CMD_POST: - recv_id = struct.unpack(">I", data[0:4])[0] - pack_id = struct.unpack(">I", data[4:8])[0] - self._recvqueue.insert(pack_id, data[8:]) - self._sendingqueue.set_finish(recv_id, []) - elif cmd == CMD_POST_64: - recv_id = struct.unpack(">Q", data[0:8])[0] - pack_id = struct.unpack(">Q", data[8:16])[0] - self._recvqueue.insert(pack_id, data[16:]) - self._sendingqueue.set_finish(recv_id, []) - elif cmd == CMD_DISCONNECT: - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) - self._write_to_sock(rsp_data, self._local_sock, addr) - self.destroy() - self.destroy_local() - return - elif cmd == CMD_SYN_STATUS: - pack_id = struct.unpack(">I", data[0:4])[0] - max_send_id = struct.unpack(">I", data[4:8])[0] - data = data[8:] - self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) - elif cmd == CMD_SYN_STATUS_64: - pack_id = struct.unpack(">Q", data[0:8])[0] - max_send_id = struct.unpack(">Q", data[8:16])[0] - data = data[16:] - self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) - while self._recvqueue.has_data(): - pack_id, post_data = self._recvqueue.get_data() - self._write_to_sock(post_data, self._remote_sock) - elif self._stage == STAGE_DESTROYED: - local_id = data[0:4] - if self._local_id != local_id: - # ileagal request - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) - self._write_to_sock(rsp_data, self._local_sock, addr) - return - else: - data = data[4:] - if cmd == CMD_SYN_STATUS: - pack_id = struct.unpack(">I", data[0:4])[0] - max_send_id = struct.unpack(">I", data[4:8])[0] - data = data[8:] - self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) - elif cmd == CMD_SYN_STATUS_64: - pack_id = struct.unpack(">Q", data[0:8])[0] - max_send_id = struct.unpack(">Q", data[8:16])[0] - data = data[16:] - self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) - - def handle_event(self, sock, event): - # handle all events in this handler and dispatch them to methods - handle = False - if self._stage == STAGE_DESTROYED: - logging.debug('ignore handle_event: destroyed') - return True - # order is important - if sock == self._remote_sock: - if event & eventloop.POLL_ERR: - handle = True - self._on_remote_error() - if self._stage == STAGE_DESTROYED: - return True - if event & (eventloop.POLL_IN | eventloop.POLL_HUP): - handle = True - self._on_remote_read() - if self._stage == STAGE_DESTROYED: - return True - if event & eventloop.POLL_OUT: - handle = True - self._on_remote_write() - elif sock == self._local_sock: - if event & eventloop.POLL_ERR: - handle = True - self._on_local_error() - if self._stage == STAGE_DESTROYED: - return True - if event & (eventloop.POLL_IN | eventloop.POLL_HUP): - handle = True - self._on_local_read() - if self._stage == STAGE_DESTROYED: - return True - if event & eventloop.POLL_OUT: - handle = True - self._on_local_write() - else: - logging.warn('unknown socket') - - return handle - - def _log_error(self, e): - logging.error('%s when handling connection from %s' % - (e, self._client_address.keys())) - - def destroy(self): - # destroy the handler and release any resources - # promises: - # 1. destroy won't make another destroy() call inside - # 2. destroy releases resources so it prevents future call to destroy - # 3. destroy won't raise any exceptions - # if any of the promises are broken, it indicates a bug has been - # introduced! mostly likely memory leaks, etc - #logging.info('tcp destroy called') - if self._stage == STAGE_DESTROYED: - # this couldn't happen - logging.debug('already destroyed') - return - self._stage = STAGE_DESTROYED - if self._remote_address: - logging.debug('destroy: %s:%d' % - self._remote_address) - else: - logging.debug('destroy') - if self._remote_sock: - logging.debug('destroying remote') - self._loop.remove(self._remote_sock) - try: - del self._fd_to_handlers[self._remote_sock.fileno()] - except Exception as e: - pass - self._remote_sock.close() - self._remote_sock = None - if self._sendingqueue.empty(): - self.destroy_local() - self._dns_resolver.remove_callback(self._handle_dns_resolved) - - def destroy_local(self): - if self._local_sock: - logging.debug('disconnect local') - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) - addr = None - addr = self.get_local_address() - self._write_to_sock(rsp_data, self._local_sock, addr) - self._local_sock = None - try: - del self._reqid_to_handlers[self._request_id] - except Exception as e: - pass - - self._server.remove_handler(self) - def client_key(source_addr, server_af): # notice this is server af, not dest af return '%s:%s:%d' % (source_addr[0], source_addr[1], server_af) @@ -908,7 +163,7 @@ class UDPRelay(object): self.server_user_transfer_ul = {} self.server_user_transfer_dl = {} - if common.to_bytes(config['protocol']) in [b"auth_aes128_md5", b"auth_aes128_sha1"]: + if common.to_bytes(config['protocol']) in obfs.mu_protocol(): self._update_users(None, None) self.protocol_data = obfs.obfs(config['protocol']).init_data() @@ -1261,71 +516,6 @@ class UDPRelay(object): else: shell.print_exception(e) - def _handle_tcp_over_udp(self, data, r_addr): - #(cmd, request_id, data) - #logging.info("UDP data %d %d %s" % (data[0], data[1], binascii.hexlify(data[2]))) - try: - self.server_transfer_ul += len(data[2]) - if data[0] == 0: - if len(data[2]) >= 4: - for i in range(64): - req_id = random.randint(1, 65535) - if req_id not in self._reqid_to_hd: - break - if req_id in self._reqid_to_hd: - for i in range(64): - req_id = random.randint(1, 65535) - if type(self._reqid_to_hd[req_id]) is tuple: - break - # return req id - self._reqid_to_hd[req_id] = (data[2][0:4], None) - rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, req_id, RSP_STATE_CONNECTED) - data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) - self.write_to_server_socket(data_to_send, r_addr) - elif data[0] == CMD_CONNECT_REMOTE: - if len(data[2]) > 4 and data[1] in self._reqid_to_hd: - # create - if type(self._reqid_to_hd[data[1]]) is tuple: - if data[2][0:4] == self._reqid_to_hd[data[1]][0]: - handle = TCPRelayHandler(self, self._reqid_to_hd, self._fd_to_handlers, - self._eventloop, self._server_socket, - self._reqid_to_hd[data[1]][0], self._reqid_to_hd[data[1]][1], - self._config, self._dns_resolver, self._is_local) - self._reqid_to_hd[data[1]] = handle - handle.handle_client(r_addr, CMD_CONNECT, data[1], data[2]) - handle.handle_client(r_addr, *data) - self.update_activity(handle) - else: - # disconnect - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], RSP_STATE_EMPTY) - data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) - self.write_to_server_socket(data_to_send, r_addr) - else: - self.update_activity(self._reqid_to_hd[data[1]]) - self._reqid_to_hd[data[1]].handle_client(r_addr, *data) - else: - # disconnect - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], RSP_STATE_EMPTY) - data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) - self.write_to_server_socket(data_to_send, r_addr) - elif data[0] > CMD_CONNECT_REMOTE and data[0] <= CMD_DISCONNECT: - if data[1] in self._reqid_to_hd: - if type(self._reqid_to_hd[data[1]]) is tuple: - pass - else: - self.update_activity(self._reqid_to_hd[data[1]]) - self._reqid_to_hd[data[1]].handle_client(r_addr, *data) - else: - # disconnect - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], RSP_STATE_EMPTY) - data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) - self.write_to_server_socket(data_to_send, r_addr) - return - except Exception as e: - trace = traceback.format_exc() - logging.error(trace) - return - def _handle_client(self, sock): data, r_addr = sock.recvfrom(BUF_SIZE) if not data: @@ -1333,6 +523,18 @@ class UDPRelay(object): return if self._stat_callback: self._stat_callback(self._listen_port, len(data)) + + client_addr = self._client_fd_to_server_addr.get(sock.fileno()) + client_uid = None + if client_addr: + key = client_key(client_addr[0], client_addr[1]) + client_pair = self._cache.get(key, None) + client_dns_pair = self._cache_dns_client.get(key, None) + if client_pair: + client, client_uid = client_pair + elif client_dns_pair: + client, client_uid = client_dns_pair + if not self._is_local: addrlen = len(r_addr[0]) if addrlen > 255: @@ -1341,7 +543,7 @@ class UDPRelay(object): data = pack_addr(r_addr[0]) + struct.pack('>H', r_addr[1]) + data ref_iv = [encrypt.encrypt_new_iv(self._method)] self._protocol.obfs.server_info.iv = ref_iv[0] - data = self._protocol.server_udp_pre_encrypt(data) + data = self._protocol.server_udp_pre_encrypt(data, client_uid) response = encrypt.encrypt_all_iv(self._protocol.obfs.server_info.key, self._method, 1, data, ref_iv) if not response: @@ -1361,16 +563,9 @@ class UDPRelay(object): #logging.debug('UDP handle_client %s:%d to %s:%d' % (common.to_str(r_addr[0]), r_addr[1], dest_addr, dest_port)) response = b'\x00\x00\x00' + data - client_addr = self._client_fd_to_server_addr.get(sock.fileno()) + if client_addr: - key = client_key(client_addr[0], client_addr[1]) - client_pair = self._cache.get(key, None) - client_dns_pair = self._cache_dns_client.get(key, None) - if client_pair: - client, client_uid = client_pair - self.add_transfer_d(client_uid, len(response)) - elif client_dns_pair: - client, client_uid = client_dns_pair + if client_uid: self.add_transfer_d(client_uid, len(response)) else: self.server_transfer_dl += len(response)