Browse Source

Clean & reformat code.

2to3
Aspirin Geyer 7 years ago
parent
commit
2312f3150c
No known key found for this signature in database GPG Key ID: C0E90B22AB36C6DF
  1. 75
      shadowsocks/obfsplugin/obfs_tls.py
  2. 6
      shadowsocks/tcprelay.py

75
shadowsocks/obfsplugin/obfs_tls.py

@ -35,9 +35,11 @@ from shadowsocks.obfsplugin import plain
from shadowsocks.common import to_bytes, to_str, ord from shadowsocks.common import to_bytes, to_str, ord
from shadowsocks import lru_cache from shadowsocks import lru_cache
def create_tls_ticket_auth_obfs(method): def create_tls_ticket_auth_obfs(method):
return tls_ticket_auth(method) return tls_ticket_auth(method)
obfs_map = { obfs_map = {
'tls1.2_ticket_auth': (create_tls_ticket_auth_obfs,), 'tls1.2_ticket_auth': (create_tls_ticket_auth_obfs,),
'tls1.2_ticket_auth_compatible': (create_tls_ticket_auth_obfs,), 'tls1.2_ticket_auth_compatible': (create_tls_ticket_auth_obfs,),
@ -45,20 +47,25 @@ obfs_map = {
'tls1.2_ticket_fastauth_compatible': (create_tls_ticket_auth_obfs,), 'tls1.2_ticket_fastauth_compatible': (create_tls_ticket_auth_obfs,),
} }
def match_begin(str1, str2): def match_begin(str1, str2):
if len(str1) >= len(str2): if len(str1) >= len(str2):
if str1[:len(str2)] == str2: if str1[:len(str2)] == str2:
return True return True
return False return False
class obfs_auth_data(object): class obfs_auth_data(object):
def __init__(self): def __init__(self):
self.client_data = lru_cache.LRUCache(60 * 5) self.client_data = lru_cache.LRUCache(60 * 5)
self.client_id = os.urandom(32) self.client_id = os.urandom(32)
self.startup_time = int(time.time() - 60 * 30) & 0xFFFFFFFF self.startup_time = int(time.time() - 60 * 30) & 0xFFFFFFFF
self.ticket_buf = {} self.ticket_buf = {}
class tls_ticket_auth(plain.plain): class tls_ticket_auth(plain.plain):
def __init__(self, method): def __init__(self, method):
self.method = method self.method = method
self.handshake_status = 0 self.handshake_status = 0
@ -101,9 +108,15 @@ class tls_ticket_auth(plain.plain):
return ret return ret
if len(buf) > 0: if len(buf) > 0:
self.send_buffer += b"\x17" + self.tls_version + struct.pack('>H', len(buf)) + buf self.send_buffer += b"\x17" + self.tls_version + struct.pack('>H', len(buf)) + buf
if self.handshake_status == 0: if self.handshake_status == 0:
self.handshake_status = 1 self.handshake_status = 1
data = self.tls_version + self.pack_auth_data(self.server_info.data.client_id) + b"\x20" + self.server_info.data.client_id + binascii.unhexlify(b"001cc02bc02fcca9cca8cc14cc13c00ac014c009c013009c0035002f000a" + b"0100") data = self.tls_version \
+ self.pack_auth_data(self.server_info.data.client_id) \
+ b"\x20" \
+ self.server_info.data.client_id \
+ binascii.unhexlify(b"001cc02bc02fcca9cca8cc14cc13c00ac014c009c013009c0035002f000a" + b"0100")
ext = binascii.unhexlify(b"ff01000100") ext = binascii.unhexlify(b"ff01000100")
host = self.server_info.obfs_param or self.server_info.host host = self.server_info.obfs_param or self.server_info.host
if host and host[-1] in string.digits: if host and host[-1] in string.digits:
@ -113,7 +126,9 @@ class tls_ticket_auth(plain.plain):
ext += self.sni(host) ext += self.sni(host)
ext += b"\x00\x17\x00\x00" ext += b"\x00\x17\x00\x00"
if host not in self.server_info.data.ticket_buf: if host not in self.server_info.data.ticket_buf:
self.server_info.data.ticket_buf[host] = os.urandom((struct.unpack('>H', os.urandom(2))[0] % 17 + 8) * 16) self.server_info.data.ticket_buf[host] = os.urandom((struct.unpack('>H',
os.urandom(2))[0] % 17 + 8) * 16)
ext += b"\x00\x23" + struct.pack('>H', len(self.server_info.data.ticket_buf[host])) + self.server_info.data.ticket_buf[host] ext += b"\x00\x23" + struct.pack('>H', len(self.server_info.data.ticket_buf[host])) + self.server_info.data.ticket_buf[host]
ext += binascii.unhexlify(b"000d001600140601060305010503040104030301030302010203") ext += binascii.unhexlify(b"000d001600140601060305010503040104030301030302010203")
ext += binascii.unhexlify(b"000500050100000000") ext += binascii.unhexlify(b"000500050100000000")
@ -137,7 +152,7 @@ class tls_ticket_auth(plain.plain):
def client_decode(self, buf): def client_decode(self, buf):
if self.handshake_status == -1: if self.handshake_status == -1:
return (buf, False) return buf, False
if self.handshake_status == 8: if self.handshake_status == 8:
ret = b'' ret = b''
@ -152,7 +167,7 @@ class tls_ticket_auth(plain.plain):
buf = self.recv_buffer[5:size+5] buf = self.recv_buffer[5:size+5]
ret += buf ret += buf
self.recv_buffer = self.recv_buffer[size+5:] self.recv_buffer = self.recv_buffer[size+5:]
return (ret, False) return ret, False
if len(buf) < 11 + 32 + 1 + 32: if len(buf) < 11 + 32 + 1 + 32:
raise Exception('client_decode data error') raise Exception('client_decode data error')
@ -161,7 +176,7 @@ class tls_ticket_auth(plain.plain):
raise Exception('client_decode data error') raise Exception('client_decode data error')
if hmac.new(self.server_info.key + self.server_info.data.client_id, buf[:-10], hashlib.sha1).digest()[:10] != buf[-10:]: if hmac.new(self.server_info.key + self.server_info.data.client_id, buf[:-10], hashlib.sha1).digest()[:10] != buf[-10:]:
raise Exception('client_decode data error') raise Exception('client_decode data error')
return (b'', True) return b'', True
def server_encode(self, buf): def server_encode(self, buf):
if self.handshake_status == -1: if self.handshake_status == -1:
@ -176,19 +191,25 @@ class tls_ticket_auth(plain.plain):
ret += b"\x17" + self.tls_version + struct.pack('>H', len(buf)) + buf ret += b"\x17" + self.tls_version + struct.pack('>H', len(buf)) + buf
return ret return ret
self.handshake_status |= 8 self.handshake_status |= 8
data = self.tls_version + self.pack_auth_data(self.client_id) + b"\x20" + self.client_id + binascii.unhexlify(b"c02f000005ff01000100") data = self.tls_version + self.pack_auth_data(self.client_id) \
+ b"\x20" + self.client_id \
+ binascii.unhexlify(b"c02f000005ff01000100")
data = b"\x02\x00" + struct.pack('>H', len(data)) + data # server hello data = b"\x02\x00" + struct.pack('>H', len(data)) + data # server hello
data = b"\x16" + self.tls_version + struct.pack('>H', len(data)) + data data = b"\x16" + self.tls_version + struct.pack('>H', len(data)) + data
if random.randint(0, 8) < 1: if random.randint(0, 8) < 1:
ticket = os.urandom((struct.unpack('>H', os.urandom(2))[0] % 164) * 2 + 64) ticket = os.urandom((struct.unpack('>H', os.urandom(2))[0] % 164) * 2 + 64)
ticket = struct.pack('>H', len(ticket) + 4) + b"\x04\x00" + struct.pack('>H', len(ticket)) + ticket ticket = struct.pack('>H', len(ticket) + 4) + b"\x04\x00" + struct.pack('>H', len(ticket)) + ticket
data += b"\x16" + self.tls_version + ticket # New session ticket data += b"\x16" + self.tls_version + ticket # New session ticket
data += b"\x14" + self.tls_version + b"\x00\x01\x01" # ChangeCipherSpec data += b"\x14" + self.tls_version + b"\x00\x01\x01" # ChangeCipherSpec
finish_len = random.choice([32, 40]) finish_len = random.choice([32, 40])
data += b"\x16" + self.tls_version + struct.pack('>H', finish_len) + os.urandom(finish_len - 10) # Finished data += b"\x16" + self.tls_version + struct.pack('>H', finish_len) + os.urandom(finish_len - 10) # Finished
data += hmac.new(self.server_info.key + self.client_id, data, hashlib.sha1).digest()[:10] data += hmac.new(self.server_info.key + self.client_id, data, hashlib.sha1).digest()[:10]
if buf: if buf:
data += self.server_encode(buf) data += self.server_encode(buf)
return data return data
def decode_error_return(self, buf): def decode_error_return(self, buf):
@ -197,26 +218,31 @@ class tls_ticket_auth(plain.plain):
self.server_info.overhead -= self.overhead self.server_info.overhead -= self.overhead
self.overhead = 0 self.overhead = 0
if self.method in ['tls1.2_ticket_auth', 'tls1.2_ticket_fastauth']: if self.method in ['tls1.2_ticket_auth', 'tls1.2_ticket_fastauth']:
return (b'E'*2048, False, False) return b'E'*2048, False, False
return (buf, True, False)
return buf, True, False
def server_decode(self, buf): def server_decode(self, buf):
if self.handshake_status == -1: if self.handshake_status == -1:
return (buf, True, False) return buf, True, False
if (self.handshake_status & 4) == 4: if (self.handshake_status & 4) == 4:
ret = b'' ret = b''
self.recv_buffer += buf self.recv_buffer += buf
while len(self.recv_buffer) > 5: while len(self.recv_buffer) > 5:
if ord(self.recv_buffer[0]) != 0x17 or ord(self.recv_buffer[1]) != 0x3 or ord(self.recv_buffer[2]) != 0x3: if ord(self.recv_buffer[0]) != 0x17 \
or ord(self.recv_buffer[1]) != 0x3 \
or ord(self.recv_buffer[2]) != 0x3:
logging.info("data = %s" % (binascii.hexlify(self.recv_buffer))) logging.info("data = %s" % (binascii.hexlify(self.recv_buffer)))
raise Exception('server_decode appdata error') raise Exception('server_decode appdata error')
size = struct.unpack('>H', self.recv_buffer[3:5])[0] size = struct.unpack('>H', self.recv_buffer[3:5])[0]
if len(self.recv_buffer) < size + 5: if len(self.recv_buffer) < size + 5:
break break
ret += self.recv_buffer[5:size+5] ret += self.recv_buffer[5:size+5]
self.recv_buffer = self.recv_buffer[size+5:] self.recv_buffer = self.recv_buffer[size+5:]
return (ret, True, False)
return ret, True, False
if (self.handshake_status & 1) == 1: if (self.handshake_status & 1) == 1:
self.recv_buffer += buf self.recv_buffer += buf
@ -224,34 +250,43 @@ class tls_ticket_auth(plain.plain):
verify = buf verify = buf
if len(buf) < 11: if len(buf) < 11:
raise Exception('server_decode data error') raise Exception('server_decode data error')
if not match_begin(buf, b"\x14" + self.tls_version + b"\x00\x01\x01"): # ChangeCipherSpec if not match_begin(buf, b"\x14" + self.tls_version + b"\x00\x01\x01"): # ChangeCipherSpec
raise Exception('server_decode data error') raise Exception('server_decode data error')
buf = buf[6:] buf = buf[6:]
if not match_begin(buf, b"\x16" + self.tls_version + b"\x00"): # Finished if not match_begin(buf, b"\x16" + self.tls_version + b"\x00"): # Finished
raise Exception('server_decode data error') raise Exception('server_decode data error')
verify_len = struct.unpack('>H', buf[3:5])[0] + 1 # 11 - 10 verify_len = struct.unpack('>H', buf[3:5])[0] + 1 # 11 - 10
if len(verify) < verify_len + 10: if len(verify) < verify_len + 10:
return (b'', False, False) return b'', False, False
if hmac.new(self.server_info.key + self.client_id, verify[:verify_len], hashlib.sha1).digest()[:10] != verify[verify_len:verify_len+10]:
if hmac.new(self.server_info.key + self.client_id,
verify[:verify_len],
hashlib.sha1).digest()[:10] != verify[verify_len:verify_len+10]:
raise Exception('server_decode data error') raise Exception('server_decode data error')
self.recv_buffer = verify[verify_len + 10:] self.recv_buffer = verify[verify_len + 10:]
status = self.handshake_status status = self.handshake_status
self.handshake_status |= 4 self.handshake_status |= 4
ret = self.server_decode(b'') ret = self.server_decode(b'')
return ret; return ret
#raise Exception("handshake data = %s" % (binascii.hexlify(buf))) #raise Exception("handshake data = %s" % (binascii.hexlify(buf)))
self.recv_buffer += buf self.recv_buffer += buf
buf = self.recv_buffer buf = self.recv_buffer
ogn_buf = buf ogn_buf = buf
if len(buf) < 3: if len(buf) < 3:
return (b'', False, False) return b'', False, False
if not match_begin(buf, b'\x16\x03\x01'): if not match_begin(buf, b'\x16\x03\x01'):
return self.decode_error_return(ogn_buf) return self.decode_error_return(ogn_buf)
buf = buf[3:] buf = buf[3:]
header_len = struct.unpack('>H', buf[:2])[0] header_len = struct.unpack('>H', buf[:2])[0]
if header_len > len(buf) - 2: if header_len > len(buf) - 2:
return (b'', False, False) return b'', False, False
self.recv_buffer = self.recv_buffer[header_len + 5:] self.recv_buffer = self.recv_buffer[header_len + 5:]
self.handshake_status = 1 self.handshake_status = 1
@ -259,14 +294,17 @@ class tls_ticket_auth(plain.plain):
if not match_begin(buf, b'\x01\x00'): # client hello if not match_begin(buf, b'\x01\x00'): # client hello
logging.info("tls_auth not client hello message") logging.info("tls_auth not client hello message")
return self.decode_error_return(ogn_buf) return self.decode_error_return(ogn_buf)
buf = buf[2:] buf = buf[2:]
if struct.unpack('>H', buf[:2])[0] != len(buf) - 2: if struct.unpack('>H', buf[:2])[0] != len(buf) - 2:
logging.info("tls_auth wrong message size") logging.info("tls_auth wrong message size")
return self.decode_error_return(ogn_buf) return self.decode_error_return(ogn_buf)
buf = buf[2:] buf = buf[2:]
if not match_begin(buf, self.tls_version): if not match_begin(buf, self.tls_version):
logging.info("tls_auth wrong tls version") logging.info("tls_auth wrong tls version")
return self.decode_error_return(ogn_buf) return self.decode_error_return(ogn_buf)
buf = buf[2:] buf = buf[2:]
verifyid = buf[:32] verifyid = buf[:32]
buf = buf[32:] buf = buf[32:]
@ -299,7 +337,8 @@ class tls_ticket_auth(plain.plain):
self.server_info.data.client_data[verifyid[:22]] = sessionid self.server_info.data.client_data[verifyid[:22]] = sessionid
if len(self.recv_buffer) >= 11: if len(self.recv_buffer) >= 11:
ret = self.server_decode(b'') ret = self.server_decode(b'')
return (ret[0], True, True) return ret[0], True, True
# (buffer_to_recv, is_need_decrypt, is_need_to_encode_and_send_back) # (buffer_to_recv, is_need_decrypt, is_need_to_encode_and_send_back)
return (b'', False, True) return b'', False, True

6
shadowsocks/tcprelay.py

@ -95,6 +95,7 @@ TCP_MSS = NETWORK_MTU - 40
BUF_SIZE = 32 * 1024 BUF_SIZE = 32 * 1024
UDP_MAX_BUF_SIZE = 65536 UDP_MAX_BUF_SIZE = 65536
class SpeedTester(object): class SpeedTester(object):
def __init__(self, max_speed=0): def __init__(self, max_speed=0):
self.max_speed = max_speed * 1024 self.max_speed = max_speed * 1024
@ -123,6 +124,7 @@ class SpeedTester(object):
return self.sum_len >= self.max_speed return self.sum_len >= self.max_speed
return False return False
class TCPRelayHandler(object): class TCPRelayHandler(object):
def __init__(self, server, fd_to_handlers, loop, local_sock, config, def __init__(self, server, fd_to_handlers, loop, local_sock, config,
dns_resolver, is_local): dns_resolver, is_local):
@ -1175,6 +1177,7 @@ class TCPRelayHandler(object):
#gc.collect() #gc.collect()
#logging.debug("gc %s" % (gc.garbage,)) #logging.debug("gc %s" % (gc.garbage,))
class TCPRelay(object): class TCPRelay(object):
def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_counter=None): def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_counter=None):
self._config = config self._config = config
@ -1200,8 +1203,7 @@ class TCPRelay(object):
common.connect_log = logging.info common.connect_log = logging.info
self._timeout = config['timeout'] self._timeout = config['timeout']
self._timeout_cache = lru_cache.LRUCache(timeout=self._timeout, self._timeout_cache = lru_cache.LRUCache(timeout=self._timeout, close_callback=self._close_tcp_client)
close_callback=self._close_tcp_client)
if is_local: if is_local:
listen_addr = config['local_address'] listen_addr = config['local_address']

Loading…
Cancel
Save