From 2312f3150c4df2697cc7c15a58219400535c26b2 Mon Sep 17 00:00:00 2001 From: Aspirin Geyer Date: Thu, 9 Nov 2017 23:59:29 +0800 Subject: [PATCH] Clean & reformat code. --- shadowsocks/obfsplugin/obfs_tls.py | 99 +++++++++++++++++++++--------- shadowsocks/tcprelay.py | 12 ++-- 2 files changed, 76 insertions(+), 35 deletions(-) diff --git a/shadowsocks/obfsplugin/obfs_tls.py b/shadowsocks/obfsplugin/obfs_tls.py index 7f1f233..2bac89f 100644 --- a/shadowsocks/obfsplugin/obfs_tls.py +++ b/shadowsocks/obfsplugin/obfs_tls.py @@ -35,9 +35,11 @@ from shadowsocks.obfsplugin import plain from shadowsocks.common import to_bytes, to_str, ord from shadowsocks import lru_cache + def create_tls_ticket_auth_obfs(method): return tls_ticket_auth(method) + obfs_map = { 'tls1.2_ticket_auth': (create_tls_ticket_auth_obfs,), 'tls1.2_ticket_auth_compatible': (create_tls_ticket_auth_obfs,), @@ -45,34 +47,39 @@ obfs_map = { 'tls1.2_ticket_fastauth_compatible': (create_tls_ticket_auth_obfs,), } + def match_begin(str1, str2): if len(str1) >= len(str2): if str1[:len(str2)] == str2: return True return False + class obfs_auth_data(object): + def __init__(self): self.client_data = lru_cache.LRUCache(60 * 5) self.client_id = os.urandom(32) self.startup_time = int(time.time() - 60 * 30) & 0xFFFFFFFF self.ticket_buf = {} + class tls_ticket_auth(plain.plain): + def __init__(self, method): self.method = method self.handshake_status = 0 self.send_buffer = b'' self.recv_buffer = b'' self.client_id = b'' - self.max_time_dif = 60 * 60 * 24 # time dif (second) setting + self.max_time_dif = 60 * 60 * 24 # time dif (second) setting self.tls_version = b'\x03\x03' self.overhead = 5 def init_data(self): return obfs_auth_data() - def get_overhead(self, direction): # direction: true for c->s false for s->c + def get_overhead(self, direction): # direction: true for c->s false for s->c return self.overhead def sni(self, url): @@ -101,9 +108,15 @@ class tls_ticket_auth(plain.plain): return ret if len(buf) > 0: self.send_buffer += b"\x17" + self.tls_version + struct.pack('>H', len(buf)) + buf + if self.handshake_status == 0: self.handshake_status = 1 - data = self.tls_version + self.pack_auth_data(self.server_info.data.client_id) + b"\x20" + self.server_info.data.client_id + binascii.unhexlify(b"001cc02bc02fcca9cca8cc14cc13c00ac014c009c013009c0035002f000a" + b"0100") + data = self.tls_version \ + + self.pack_auth_data(self.server_info.data.client_id) \ + + b"\x20" \ + + self.server_info.data.client_id \ + + binascii.unhexlify(b"001cc02bc02fcca9cca8cc14cc13c00ac014c009c013009c0035002f000a" + b"0100") + ext = binascii.unhexlify(b"ff01000100") host = self.server_info.obfs_param or self.server_info.host if host and host[-1] in string.digits: @@ -113,7 +126,9 @@ class tls_ticket_auth(plain.plain): ext += self.sni(host) ext += b"\x00\x17\x00\x00" if host not in self.server_info.data.ticket_buf: - self.server_info.data.ticket_buf[host] = os.urandom((struct.unpack('>H', os.urandom(2))[0] % 17 + 8) * 16) + self.server_info.data.ticket_buf[host] = os.urandom((struct.unpack('>H', + os.urandom(2))[0] % 17 + 8) * 16) + ext += b"\x00\x23" + struct.pack('>H', len(self.server_info.data.ticket_buf[host])) + self.server_info.data.ticket_buf[host] ext += binascii.unhexlify(b"000d001600140601060305010503040104030301030302010203") ext += binascii.unhexlify(b"000500050100000000") @@ -126,8 +141,8 @@ class tls_ticket_auth(plain.plain): data = b"\x16\x03\x01" + struct.pack('>H', len(data)) + data return data elif self.handshake_status == 1 and len(buf) == 0: - data = b"\x14" + self.tls_version + b"\x00\x01\x01" #ChangeCipherSpec - data += b"\x16" + self.tls_version + b"\x00\x20" + os.urandom(22) #Finished + data = b"\x14" + self.tls_version + b"\x00\x01\x01" # ChangeCipherSpec + data += b"\x16" + self.tls_version + b"\x00\x20" + os.urandom(22) # Finished data += hmac.new(self.server_info.key + self.server_info.data.client_id, data, hashlib.sha1).digest()[:10] ret = data + self.send_buffer self.send_buffer = b'' @@ -137,7 +152,7 @@ class tls_ticket_auth(plain.plain): def client_decode(self, buf): if self.handshake_status == -1: - return (buf, False) + return buf, False if self.handshake_status == 8: ret = b'' @@ -152,7 +167,7 @@ class tls_ticket_auth(plain.plain): buf = self.recv_buffer[5:size+5] ret += buf self.recv_buffer = self.recv_buffer[size+5:] - return (ret, False) + return ret, False if len(buf) < 11 + 32 + 1 + 32: raise Exception('client_decode data error') @@ -161,7 +176,7 @@ class tls_ticket_auth(plain.plain): raise Exception('client_decode data error') if hmac.new(self.server_info.key + self.server_info.data.client_id, buf[:-10], hashlib.sha1).digest()[:10] != buf[-10:]: raise Exception('client_decode data error') - return (b'', True) + return b'', True def server_encode(self, buf): if self.handshake_status == -1: @@ -176,19 +191,25 @@ class tls_ticket_auth(plain.plain): ret += b"\x17" + self.tls_version + struct.pack('>H', len(buf)) + buf return ret self.handshake_status |= 8 - data = self.tls_version + self.pack_auth_data(self.client_id) + b"\x20" + self.client_id + binascii.unhexlify(b"c02f000005ff01000100") - data = b"\x02\x00" + struct.pack('>H', len(data)) + data #server hello + data = self.tls_version + self.pack_auth_data(self.client_id) \ + + b"\x20" + self.client_id \ + + binascii.unhexlify(b"c02f000005ff01000100") + + data = b"\x02\x00" + struct.pack('>H', len(data)) + data # server hello data = b"\x16" + self.tls_version + struct.pack('>H', len(data)) + data + if random.randint(0, 8) < 1: ticket = os.urandom((struct.unpack('>H', os.urandom(2))[0] % 164) * 2 + 64) ticket = struct.pack('>H', len(ticket) + 4) + b"\x04\x00" + struct.pack('>H', len(ticket)) + ticket - data += b"\x16" + self.tls_version + ticket #New session ticket - data += b"\x14" + self.tls_version + b"\x00\x01\x01" #ChangeCipherSpec + data += b"\x16" + self.tls_version + ticket # New session ticket + + data += b"\x14" + self.tls_version + b"\x00\x01\x01" # ChangeCipherSpec finish_len = random.choice([32, 40]) - data += b"\x16" + self.tls_version + struct.pack('>H', finish_len) + os.urandom(finish_len - 10) #Finished + data += b"\x16" + self.tls_version + struct.pack('>H', finish_len) + os.urandom(finish_len - 10) # Finished data += hmac.new(self.server_info.key + self.client_id, data, hashlib.sha1).digest()[:10] if buf: data += self.server_encode(buf) + return data def decode_error_return(self, buf): @@ -197,26 +218,31 @@ class tls_ticket_auth(plain.plain): self.server_info.overhead -= self.overhead self.overhead = 0 if self.method in ['tls1.2_ticket_auth', 'tls1.2_ticket_fastauth']: - return (b'E'*2048, False, False) - return (buf, True, False) + return b'E'*2048, False, False + + return buf, True, False def server_decode(self, buf): if self.handshake_status == -1: - return (buf, True, False) + return buf, True, False if (self.handshake_status & 4) == 4: ret = b'' self.recv_buffer += buf while len(self.recv_buffer) > 5: - if ord(self.recv_buffer[0]) != 0x17 or ord(self.recv_buffer[1]) != 0x3 or ord(self.recv_buffer[2]) != 0x3: + if ord(self.recv_buffer[0]) != 0x17 \ + or ord(self.recv_buffer[1]) != 0x3 \ + or ord(self.recv_buffer[2]) != 0x3: logging.info("data = %s" % (binascii.hexlify(self.recv_buffer))) raise Exception('server_decode appdata error') + size = struct.unpack('>H', self.recv_buffer[3:5])[0] if len(self.recv_buffer) < size + 5: break ret += self.recv_buffer[5:size+5] self.recv_buffer = self.recv_buffer[size+5:] - return (ret, True, False) + + return ret, True, False if (self.handshake_status & 1) == 1: self.recv_buffer += buf @@ -224,49 +250,61 @@ class tls_ticket_auth(plain.plain): verify = buf if len(buf) < 11: raise Exception('server_decode data error') - if not match_begin(buf, b"\x14" + self.tls_version + b"\x00\x01\x01"): #ChangeCipherSpec + + if not match_begin(buf, b"\x14" + self.tls_version + b"\x00\x01\x01"): # ChangeCipherSpec raise Exception('server_decode data error') + buf = buf[6:] - if not match_begin(buf, b"\x16" + self.tls_version + b"\x00"): #Finished + if not match_begin(buf, b"\x16" + self.tls_version + b"\x00"): # Finished raise Exception('server_decode data error') - verify_len = struct.unpack('>H', buf[3:5])[0] + 1 # 11 - 10 + + verify_len = struct.unpack('>H', buf[3:5])[0] + 1 # 11 - 10 if len(verify) < verify_len + 10: - return (b'', False, False) - if hmac.new(self.server_info.key + self.client_id, verify[:verify_len], hashlib.sha1).digest()[:10] != verify[verify_len:verify_len+10]: + return b'', False, False + + if hmac.new(self.server_info.key + self.client_id, + verify[:verify_len], + hashlib.sha1).digest()[:10] != verify[verify_len:verify_len+10]: raise Exception('server_decode data error') + self.recv_buffer = verify[verify_len + 10:] status = self.handshake_status self.handshake_status |= 4 ret = self.server_decode(b'') - return ret; + return ret #raise Exception("handshake data = %s" % (binascii.hexlify(buf))) self.recv_buffer += buf buf = self.recv_buffer ogn_buf = buf if len(buf) < 3: - return (b'', False, False) + return b'', False, False + if not match_begin(buf, b'\x16\x03\x01'): return self.decode_error_return(ogn_buf) + buf = buf[3:] header_len = struct.unpack('>H', buf[:2])[0] if header_len > len(buf) - 2: - return (b'', False, False) + return b'', False, False self.recv_buffer = self.recv_buffer[header_len + 5:] self.handshake_status = 1 buf = buf[2:header_len + 2] - if not match_begin(buf, b'\x01\x00'): #client hello + if not match_begin(buf, b'\x01\x00'): # client hello logging.info("tls_auth not client hello message") return self.decode_error_return(ogn_buf) + buf = buf[2:] if struct.unpack('>H', buf[:2])[0] != len(buf) - 2: logging.info("tls_auth wrong message size") return self.decode_error_return(ogn_buf) + buf = buf[2:] if not match_begin(buf, self.tls_version): logging.info("tls_auth wrong tls version") return self.decode_error_return(ogn_buf) + buf = buf[2:] verifyid = buf[:32] buf = buf[32:] @@ -299,7 +337,8 @@ class tls_ticket_auth(plain.plain): self.server_info.data.client_data[verifyid[:22]] = sessionid if len(self.recv_buffer) >= 11: ret = self.server_decode(b'') - return (ret[0], True, True) + return ret[0], True, True + # (buffer_to_recv, is_need_decrypt, is_need_to_encode_and_send_back) - return (b'', False, True) + return b'', False, True diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index ce84d84..03eead4 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -95,8 +95,9 @@ TCP_MSS = NETWORK_MTU - 40 BUF_SIZE = 32 * 1024 UDP_MAX_BUF_SIZE = 65536 + class SpeedTester(object): - def __init__(self, max_speed = 0): + def __init__(self, max_speed=0): self.max_speed = max_speed * 1024 self.last_time = time.time() self.sum_len = 0 @@ -123,6 +124,7 @@ class SpeedTester(object): return self.sum_len >= self.max_speed return False + class TCPRelayHandler(object): def __init__(self, server, fd_to_handlers, loop, local_sock, config, dns_resolver, is_local): @@ -954,7 +956,7 @@ class TCPRelayHandler(object): self._recv_pack_id += 1 except (OSError, IOError) as e: if eventloop.errno_from_exception(e) in \ - (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): #errno.WSAEWOULDBLOCK + (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): # errno.WSAEWOULDBLOCK return if not data: self.destroy() @@ -1175,6 +1177,7 @@ class TCPRelayHandler(object): #gc.collect() #logging.debug("gc %s" % (gc.garbage,)) + class TCPRelay(object): def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_counter=None): self._config = config @@ -1200,8 +1203,7 @@ class TCPRelay(object): common.connect_log = logging.info self._timeout = config['timeout'] - self._timeout_cache = lru_cache.LRUCache(timeout=self._timeout, - close_callback=self._close_tcp_client) + self._timeout_cache = lru_cache.LRUCache(timeout=self._timeout, close_callback=self._close_tcp_client) if is_local: listen_addr = config['local_address'] @@ -1277,7 +1279,7 @@ class TCPRelay(object): self.del_user(uid) else: passwd = items[1] - self.add_user(uid, {'password':passwd}) + self.add_user(uid, {'password': passwd}) def _update_user(self, id, passwd): uid = struct.pack('