Browse Source

Clean & reformat code.

2to3
Aspirin Geyer 7 years ago
parent
commit
2312f3150c
No known key found for this signature in database GPG Key ID: C0E90B22AB36C6DF
  1. 99
      shadowsocks/obfsplugin/obfs_tls.py
  2. 12
      shadowsocks/tcprelay.py

99
shadowsocks/obfsplugin/obfs_tls.py

@ -35,9 +35,11 @@ from shadowsocks.obfsplugin import plain
from shadowsocks.common import to_bytes, to_str, ord
from shadowsocks import lru_cache
def create_tls_ticket_auth_obfs(method):
return tls_ticket_auth(method)
obfs_map = {
'tls1.2_ticket_auth': (create_tls_ticket_auth_obfs,),
'tls1.2_ticket_auth_compatible': (create_tls_ticket_auth_obfs,),
@ -45,34 +47,39 @@ obfs_map = {
'tls1.2_ticket_fastauth_compatible': (create_tls_ticket_auth_obfs,),
}
def match_begin(str1, str2):
if len(str1) >= len(str2):
if str1[:len(str2)] == str2:
return True
return False
class obfs_auth_data(object):
def __init__(self):
self.client_data = lru_cache.LRUCache(60 * 5)
self.client_id = os.urandom(32)
self.startup_time = int(time.time() - 60 * 30) & 0xFFFFFFFF
self.ticket_buf = {}
class tls_ticket_auth(plain.plain):
def __init__(self, method):
self.method = method
self.handshake_status = 0
self.send_buffer = b''
self.recv_buffer = b''
self.client_id = b''
self.max_time_dif = 60 * 60 * 24 # time dif (second) setting
self.max_time_dif = 60 * 60 * 24 # time dif (second) setting
self.tls_version = b'\x03\x03'
self.overhead = 5
def init_data(self):
return obfs_auth_data()
def get_overhead(self, direction): # direction: true for c->s false for s->c
def get_overhead(self, direction): # direction: true for c->s false for s->c
return self.overhead
def sni(self, url):
@ -101,9 +108,15 @@ class tls_ticket_auth(plain.plain):
return ret
if len(buf) > 0:
self.send_buffer += b"\x17" + self.tls_version + struct.pack('>H', len(buf)) + buf
if self.handshake_status == 0:
self.handshake_status = 1
data = self.tls_version + self.pack_auth_data(self.server_info.data.client_id) + b"\x20" + self.server_info.data.client_id + binascii.unhexlify(b"001cc02bc02fcca9cca8cc14cc13c00ac014c009c013009c0035002f000a" + b"0100")
data = self.tls_version \
+ self.pack_auth_data(self.server_info.data.client_id) \
+ b"\x20" \
+ self.server_info.data.client_id \
+ binascii.unhexlify(b"001cc02bc02fcca9cca8cc14cc13c00ac014c009c013009c0035002f000a" + b"0100")
ext = binascii.unhexlify(b"ff01000100")
host = self.server_info.obfs_param or self.server_info.host
if host and host[-1] in string.digits:
@ -113,7 +126,9 @@ class tls_ticket_auth(plain.plain):
ext += self.sni(host)
ext += b"\x00\x17\x00\x00"
if host not in self.server_info.data.ticket_buf:
self.server_info.data.ticket_buf[host] = os.urandom((struct.unpack('>H', os.urandom(2))[0] % 17 + 8) * 16)
self.server_info.data.ticket_buf[host] = os.urandom((struct.unpack('>H',
os.urandom(2))[0] % 17 + 8) * 16)
ext += b"\x00\x23" + struct.pack('>H', len(self.server_info.data.ticket_buf[host])) + self.server_info.data.ticket_buf[host]
ext += binascii.unhexlify(b"000d001600140601060305010503040104030301030302010203")
ext += binascii.unhexlify(b"000500050100000000")
@ -126,8 +141,8 @@ class tls_ticket_auth(plain.plain):
data = b"\x16\x03\x01" + struct.pack('>H', len(data)) + data
return data
elif self.handshake_status == 1 and len(buf) == 0:
data = b"\x14" + self.tls_version + b"\x00\x01\x01" #ChangeCipherSpec
data += b"\x16" + self.tls_version + b"\x00\x20" + os.urandom(22) #Finished
data = b"\x14" + self.tls_version + b"\x00\x01\x01" # ChangeCipherSpec
data += b"\x16" + self.tls_version + b"\x00\x20" + os.urandom(22) # Finished
data += hmac.new(self.server_info.key + self.server_info.data.client_id, data, hashlib.sha1).digest()[:10]
ret = data + self.send_buffer
self.send_buffer = b''
@ -137,7 +152,7 @@ class tls_ticket_auth(plain.plain):
def client_decode(self, buf):
if self.handshake_status == -1:
return (buf, False)
return buf, False
if self.handshake_status == 8:
ret = b''
@ -152,7 +167,7 @@ class tls_ticket_auth(plain.plain):
buf = self.recv_buffer[5:size+5]
ret += buf
self.recv_buffer = self.recv_buffer[size+5:]
return (ret, False)
return ret, False
if len(buf) < 11 + 32 + 1 + 32:
raise Exception('client_decode data error')
@ -161,7 +176,7 @@ class tls_ticket_auth(plain.plain):
raise Exception('client_decode data error')
if hmac.new(self.server_info.key + self.server_info.data.client_id, buf[:-10], hashlib.sha1).digest()[:10] != buf[-10:]:
raise Exception('client_decode data error')
return (b'', True)
return b'', True
def server_encode(self, buf):
if self.handshake_status == -1:
@ -176,19 +191,25 @@ class tls_ticket_auth(plain.plain):
ret += b"\x17" + self.tls_version + struct.pack('>H', len(buf)) + buf
return ret
self.handshake_status |= 8
data = self.tls_version + self.pack_auth_data(self.client_id) + b"\x20" + self.client_id + binascii.unhexlify(b"c02f000005ff01000100")
data = b"\x02\x00" + struct.pack('>H', len(data)) + data #server hello
data = self.tls_version + self.pack_auth_data(self.client_id) \
+ b"\x20" + self.client_id \
+ binascii.unhexlify(b"c02f000005ff01000100")
data = b"\x02\x00" + struct.pack('>H', len(data)) + data # server hello
data = b"\x16" + self.tls_version + struct.pack('>H', len(data)) + data
if random.randint(0, 8) < 1:
ticket = os.urandom((struct.unpack('>H', os.urandom(2))[0] % 164) * 2 + 64)
ticket = struct.pack('>H', len(ticket) + 4) + b"\x04\x00" + struct.pack('>H', len(ticket)) + ticket
data += b"\x16" + self.tls_version + ticket #New session ticket
data += b"\x14" + self.tls_version + b"\x00\x01\x01" #ChangeCipherSpec
data += b"\x16" + self.tls_version + ticket # New session ticket
data += b"\x14" + self.tls_version + b"\x00\x01\x01" # ChangeCipherSpec
finish_len = random.choice([32, 40])
data += b"\x16" + self.tls_version + struct.pack('>H', finish_len) + os.urandom(finish_len - 10) #Finished
data += b"\x16" + self.tls_version + struct.pack('>H', finish_len) + os.urandom(finish_len - 10) # Finished
data += hmac.new(self.server_info.key + self.client_id, data, hashlib.sha1).digest()[:10]
if buf:
data += self.server_encode(buf)
return data
def decode_error_return(self, buf):
@ -197,26 +218,31 @@ class tls_ticket_auth(plain.plain):
self.server_info.overhead -= self.overhead
self.overhead = 0
if self.method in ['tls1.2_ticket_auth', 'tls1.2_ticket_fastauth']:
return (b'E'*2048, False, False)
return (buf, True, False)
return b'E'*2048, False, False
return buf, True, False
def server_decode(self, buf):
if self.handshake_status == -1:
return (buf, True, False)
return buf, True, False
if (self.handshake_status & 4) == 4:
ret = b''
self.recv_buffer += buf
while len(self.recv_buffer) > 5:
if ord(self.recv_buffer[0]) != 0x17 or ord(self.recv_buffer[1]) != 0x3 or ord(self.recv_buffer[2]) != 0x3:
if ord(self.recv_buffer[0]) != 0x17 \
or ord(self.recv_buffer[1]) != 0x3 \
or ord(self.recv_buffer[2]) != 0x3:
logging.info("data = %s" % (binascii.hexlify(self.recv_buffer)))
raise Exception('server_decode appdata error')
size = struct.unpack('>H', self.recv_buffer[3:5])[0]
if len(self.recv_buffer) < size + 5:
break
ret += self.recv_buffer[5:size+5]
self.recv_buffer = self.recv_buffer[size+5:]
return (ret, True, False)
return ret, True, False
if (self.handshake_status & 1) == 1:
self.recv_buffer += buf
@ -224,49 +250,61 @@ class tls_ticket_auth(plain.plain):
verify = buf
if len(buf) < 11:
raise Exception('server_decode data error')
if not match_begin(buf, b"\x14" + self.tls_version + b"\x00\x01\x01"): #ChangeCipherSpec
if not match_begin(buf, b"\x14" + self.tls_version + b"\x00\x01\x01"): # ChangeCipherSpec
raise Exception('server_decode data error')
buf = buf[6:]
if not match_begin(buf, b"\x16" + self.tls_version + b"\x00"): #Finished
if not match_begin(buf, b"\x16" + self.tls_version + b"\x00"): # Finished
raise Exception('server_decode data error')
verify_len = struct.unpack('>H', buf[3:5])[0] + 1 # 11 - 10
verify_len = struct.unpack('>H', buf[3:5])[0] + 1 # 11 - 10
if len(verify) < verify_len + 10:
return (b'', False, False)
if hmac.new(self.server_info.key + self.client_id, verify[:verify_len], hashlib.sha1).digest()[:10] != verify[verify_len:verify_len+10]:
return b'', False, False
if hmac.new(self.server_info.key + self.client_id,
verify[:verify_len],
hashlib.sha1).digest()[:10] != verify[verify_len:verify_len+10]:
raise Exception('server_decode data error')
self.recv_buffer = verify[verify_len + 10:]
status = self.handshake_status
self.handshake_status |= 4
ret = self.server_decode(b'')
return ret;
return ret
#raise Exception("handshake data = %s" % (binascii.hexlify(buf)))
self.recv_buffer += buf
buf = self.recv_buffer
ogn_buf = buf
if len(buf) < 3:
return (b'', False, False)
return b'', False, False
if not match_begin(buf, b'\x16\x03\x01'):
return self.decode_error_return(ogn_buf)
buf = buf[3:]
header_len = struct.unpack('>H', buf[:2])[0]
if header_len > len(buf) - 2:
return (b'', False, False)
return b'', False, False
self.recv_buffer = self.recv_buffer[header_len + 5:]
self.handshake_status = 1
buf = buf[2:header_len + 2]
if not match_begin(buf, b'\x01\x00'): #client hello
if not match_begin(buf, b'\x01\x00'): # client hello
logging.info("tls_auth not client hello message")
return self.decode_error_return(ogn_buf)
buf = buf[2:]
if struct.unpack('>H', buf[:2])[0] != len(buf) - 2:
logging.info("tls_auth wrong message size")
return self.decode_error_return(ogn_buf)
buf = buf[2:]
if not match_begin(buf, self.tls_version):
logging.info("tls_auth wrong tls version")
return self.decode_error_return(ogn_buf)
buf = buf[2:]
verifyid = buf[:32]
buf = buf[32:]
@ -299,7 +337,8 @@ class tls_ticket_auth(plain.plain):
self.server_info.data.client_data[verifyid[:22]] = sessionid
if len(self.recv_buffer) >= 11:
ret = self.server_decode(b'')
return (ret[0], True, True)
return ret[0], True, True
# (buffer_to_recv, is_need_decrypt, is_need_to_encode_and_send_back)
return (b'', False, True)
return b'', False, True

12
shadowsocks/tcprelay.py

@ -95,8 +95,9 @@ TCP_MSS = NETWORK_MTU - 40
BUF_SIZE = 32 * 1024
UDP_MAX_BUF_SIZE = 65536
class SpeedTester(object):
def __init__(self, max_speed = 0):
def __init__(self, max_speed=0):
self.max_speed = max_speed * 1024
self.last_time = time.time()
self.sum_len = 0
@ -123,6 +124,7 @@ class SpeedTester(object):
return self.sum_len >= self.max_speed
return False
class TCPRelayHandler(object):
def __init__(self, server, fd_to_handlers, loop, local_sock, config,
dns_resolver, is_local):
@ -954,7 +956,7 @@ class TCPRelayHandler(object):
self._recv_pack_id += 1
except (OSError, IOError) as e:
if eventloop.errno_from_exception(e) in \
(errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): #errno.WSAEWOULDBLOCK
(errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): # errno.WSAEWOULDBLOCK
return
if not data:
self.destroy()
@ -1175,6 +1177,7 @@ class TCPRelayHandler(object):
#gc.collect()
#logging.debug("gc %s" % (gc.garbage,))
class TCPRelay(object):
def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_counter=None):
self._config = config
@ -1200,8 +1203,7 @@ class TCPRelay(object):
common.connect_log = logging.info
self._timeout = config['timeout']
self._timeout_cache = lru_cache.LRUCache(timeout=self._timeout,
close_callback=self._close_tcp_client)
self._timeout_cache = lru_cache.LRUCache(timeout=self._timeout, close_callback=self._close_tcp_client)
if is_local:
listen_addr = config['local_address']
@ -1277,7 +1279,7 @@ class TCPRelay(object):
self.del_user(uid)
else:
passwd = items[1]
self.add_user(uid, {'password':passwd})
self.add_user(uid, {'password': passwd})
def _update_user(self, id, passwd):
uid = struct.pack('<I', id)

Loading…
Cancel
Save