Browse Source

format code

akkariiin/master
Akkariiin 8 years ago
parent
commit
2f012ab04f
  1. 80
      shadowsocks/obfsplugin/auth_chain.py

80
shadowsocks/obfsplugin/auth_chain.py

@ -38,17 +38,21 @@ from shadowsocks import common, lru_cache, encrypt
from shadowsocks.obfsplugin import plain
from shadowsocks.common import to_bytes, to_str, ord, chr
def create_auth_chain_a(method):
return auth_chain_a(method)
def create_auth_chain_b(method):
return auth_chain_b(method)
obfs_map = {
'auth_chain_a': (create_auth_chain_a,),
'auth_chain_b': (create_auth_chain_b,),
'auth_chain_a': (create_auth_chain_a,),
'auth_chain_b': (create_auth_chain_b,),
}
class xorshift128plus(object):
max_int = (1 << 64) - 1
mov_mask = (1 << (64 - 23)) - 1
@ -80,12 +84,14 @@ class xorshift128plus(object):
for i in range(4):
self.next()
def match_begin(str1, str2):
if len(str1) >= len(str2):
if str1[:len(str2)] == str2:
return True
return False
class auth_base(plain.plain):
def __init__(self, method):
super(auth_base, self).__init__(method)
@ -96,7 +102,7 @@ class auth_base(plain.plain):
def init_data(self):
return ''
def get_overhead(self, direction): # direction: true for c->s false for s->c
def get_overhead(self, direction): # direction: true for c->s false for s->c
return self.overhead
def set_server_info(self, server_info):
@ -118,9 +124,10 @@ class auth_base(plain.plain):
self.raw_trans = True
self.overhead = 0
if self.method == self.no_compatible_method:
return (b'E'*2048, False)
return (b'E' * 2048, False)
return (buf, False)
class client_queue(object):
def __init__(self, begin_id):
self.front = begin_id - 64
@ -175,13 +182,14 @@ class client_queue(object):
self.addref()
return True
class obfs_auth_chain_data(object):
def __init__(self, name):
self.name = name
self.user_id = {}
self.local_client_id = b''
self.connection_id = 0
self.set_max_client(64) # max active client count
self.set_max_client(64) # max active client count
def update(self, user_id, client_id, connection_id):
if user_id not in self.user_id:
@ -203,7 +211,7 @@ class obfs_auth_chain_data(object):
if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable:
if local_client_id.first() is None or len(local_client_id) < self.max_client:
if client_id not in local_client_id:
#TODO: check
# TODO: check
local_client_id[client_id] = client_queue(connection_id)
else:
local_client_id[client_id].re_enable(connection_id)
@ -212,7 +220,7 @@ class obfs_auth_chain_data(object):
if not local_client_id[local_client_id.first()].is_active():
del local_client_id[local_client_id.first()]
if client_id not in local_client_id:
#TODO: check
# TODO: check
local_client_id[client_id] = client_queue(connection_id)
else:
local_client_id[client_id].re_enable(connection_id)
@ -229,6 +237,7 @@ class obfs_auth_chain_data(object):
if client_id in local_client_id:
local_client_id[client_id].delref()
class auth_chain_a(auth_base):
def __init__(self, method):
super(auth_chain_a, self).__init__(method)
@ -240,7 +249,7 @@ class auth_chain_a(auth_base):
self.has_recv_header = False
self.client_id = 0
self.connection_id = 0
self.max_time_dif = 60 * 60 * 24 # time dif (second) setting
self.max_time_dif = 60 * 60 * 24 # time dif (second) setting
self.salt = b"auth_chain_a"
self.no_compatible_method = 'auth_chain_a'
self.pack_id = 1
@ -259,7 +268,7 @@ class auth_chain_a(auth_base):
def init_data(self):
return obfs_auth_chain_data(self.method)
def get_overhead(self, direction): # direction: true for c->s false for s->c
def get_overhead(self, direction): # direction: true for c->s false for s->c
return self.overhead
def set_server_info(self, server_info):
@ -362,14 +371,16 @@ class auth_chain_a(auth_base):
if self.user_key is None:
self.user_key = self.server_info.key
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc', b'\x00' * 16)
encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc', b'\x00' * 16)
uid = struct.unpack('<I', uid)[0] ^ struct.unpack('<I', self.last_client_hash[8:12])[0]
uid = struct.pack('<I', uid)
data = uid + encryptor.encrypt(data)[16:]
self.last_server_hash = hmac.new(self.user_key, data, self.hashfunc).digest()
data = check_head + data + self.last_server_hash[:4]
self.encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4')
self.encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4')
return data + self.pack_client_data(buf)
def auth_data(self):
@ -382,8 +393,8 @@ class auth_chain_a(auth_base):
self.server_info.data.connection_id = struct.unpack('<I', os.urandom(4))[0] & 0xFFFFFF
self.server_info.data.connection_id += 1
return b''.join([struct.pack('<I', utc_time),
self.server_info.data.local_client_id,
struct.pack('<I', self.server_info.data.connection_id)])
self.server_info.data.local_client_id,
struct.pack('<I', self.server_info.data.connection_id)])
def client_pre_encrypt(self, buf):
ret = b''
@ -419,8 +430,9 @@ class auth_chain_a(auth_base):
break
server_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest()
if server_hash[:2] != self.recv_buf[length + 2 : length + 4]:
logging.info('%s: checksum error, data %s' % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length])))
if server_hash[:2] != self.recv_buf[length + 2: length + 4]:
logging.info('%s: checksum error, data %s'
% (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length])))
self.raw_trans = True
self.recv_buf = b''
raise Exception('client_post_decrypt data uncorrect checksum')
@ -428,7 +440,7 @@ class auth_chain_a(auth_base):
pos = 2
if data_len > 0 and rand_len > 0:
pos = 2 + self.rnd_start_pos(rand_len, self.random_server)
out_buf += self.encryptor.decrypt(self.recv_buf[pos : data_len + pos])
out_buf += self.encryptor.decrypt(self.recv_buf[pos: data_len + pos])
self.last_server_hash = server_hash
if self.recv_id == 1:
self.server_info.tcp_mss = struct.unpack('<H', out_buf[:2])[0]
@ -486,16 +498,19 @@ class auth_chain_a(auth_base):
else:
self.user_key = self.server_info.recv_iv
md5data = hmac.new(self.user_key, self.recv_buf[12 : 12 + 20], self.hashfunc).digest()
md5data = hmac.new(self.user_key, self.recv_buf[12: 12 + 20], self.hashfunc).digest()
if md5data[:4] != self.recv_buf[32:36]:
logging.error('%s data uncorrect auth HMAC-MD5 from %s:%d, data %s' % (self.no_compatible_method, self.server_info.client, self.server_info.client_port, binascii.hexlify(self.recv_buf)))
logging.error('%s data uncorrect auth HMAC-MD5 from %s:%d, data %s' % (
self.no_compatible_method, self.server_info.client, self.server_info.client_port,
binascii.hexlify(self.recv_buf)
))
if len(self.recv_buf) < 36:
return (b'', False)
return self.not_match_return(self.recv_buf)
self.last_server_hash = md5data
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc')
head = encryptor.decrypt(b'\x00' * 16 + self.recv_buf[16:32] + b'\x00') # need an extra byte or recv empty
head = encryptor.decrypt(b'\x00' * 16 + self.recv_buf[16:32] + b'\x00') # need an extra byte or recv empty
self.client_over_head = struct.unpack('<H', head[12:14])[0]
utc_time = struct.unpack('<I', head[:4])[0]
@ -503,7 +518,9 @@ class auth_chain_a(auth_base):
connection_id = struct.unpack('<I', head[8:12])[0]
time_dif = common.int32(utc_time - (int(time.time()) & 0xffffffff))
if time_dif < -self.max_time_dif or time_dif > self.max_time_dif:
logging.info('%s: wrong timestamp, time_dif %d, data %s' % (self.no_compatible_method, time_dif, binascii.hexlify(head)))
logging.info('%s: wrong timestamp, time_dif %d, data %s' % (
self.no_compatible_method, time_dif, binascii.hexlify(head)
))
return self.not_match_return(self.recv_buf)
elif self.server_info.data.insert(self.user_id, client_id, connection_id):
self.has_recv_header = True
@ -513,7 +530,8 @@ class auth_chain_a(auth_base):
logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf)))
return self.not_match_return(self.recv_buf)
self.encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4')
self.encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4')
self.recv_buf = self.recv_buf[36:]
self.has_recv_header = True
sendback = True
@ -528,7 +546,7 @@ class auth_chain_a(auth_base):
self.recv_buf = b''
if self.recv_id == 0:
logging.info(self.no_compatible_method + ': over size')
return (b'E'*2048, False)
return (b'E' * 2048, False)
else:
raise Exception('server_post_decrype data error')
@ -536,12 +554,14 @@ class auth_chain_a(auth_base):
break
client_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest()
if client_hash[:2] != self.recv_buf[length + 2 : length + 4]:
logging.info('%s: checksum error, data %s' % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length])))
if client_hash[:2] != self.recv_buf[length + 2: length + 4]:
logging.info('%s: checksum error, data %s' % (
self.no_compatible_method, binascii.hexlify(self.recv_buf[:length])
))
self.raw_trans = True
self.recv_buf = b''
if self.recv_id == 0:
return (b'E'*2048, False)
return (b'E' * 2048, False)
else:
raise Exception('server_post_decrype data uncorrect checksum')
@ -549,7 +569,7 @@ class auth_chain_a(auth_base):
pos = 2
if data_len > 0 and rand_len > 0:
pos = 2 + self.rnd_start_pos(rand_len, self.random_client)
out_buf += self.encryptor.decrypt(self.recv_buf[pos : data_len + pos])
out_buf += self.encryptor.decrypt(self.recv_buf[pos: data_len + pos])
self.last_client_hash = client_hash
self.recv_buf = self.recv_buf[length + 4:]
if data_len == 0:
@ -577,7 +597,8 @@ class auth_chain_a(auth_base):
uid = struct.unpack('<I', self.user_id)[0] ^ struct.unpack('<I', md5data[:4])[0]
uid = struct.pack('<I', uid)
rand_len = self.udp_rnd_data_len(md5data, self.random_client)
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4')
encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4')
out_buf = encryptor.encrypt(buf)
buf = out_buf + os.urandom(rand_len) + authdata + uid
return buf + hmac.new(self.user_key, buf, self.hashfunc).digest()[:1]
@ -590,7 +611,8 @@ class auth_chain_a(auth_base):
mac_key = self.server_info.key
md5data = hmac.new(mac_key, buf[-8:-1], self.hashfunc).digest()
rand_len = self.udp_rnd_data_len(md5data, self.random_server)
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4')
encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4')
return encryptor.decrypt(buf[:-8 - rand_len])
def server_udp_pre_encrypt(self, buf, uid):
@ -634,6 +656,7 @@ class auth_chain_a(auth_base):
def dispose(self):
self.server_info.data.remove(self.user_id, self.client_id)
class auth_chain_b(auth_chain_a):
def __init__(self, method):
super(auth_chain_b, self).__init__(method)
@ -701,4 +724,3 @@ class auth_chain_b(auth_chain_a):
if buf_size > 400:
return random.next() % 521
return random.next() % 1021

Loading…
Cancel
Save