Browse Source

format code

akkariiin/master
Akkariiin 7 years ago
parent
commit
2f012ab04f
  1. 42
      shadowsocks/obfsplugin/auth_chain.py

42
shadowsocks/obfsplugin/auth_chain.py

@ -38,17 +38,21 @@ from shadowsocks import common, lru_cache, encrypt
from shadowsocks.obfsplugin import plain from shadowsocks.obfsplugin import plain
from shadowsocks.common import to_bytes, to_str, ord, chr from shadowsocks.common import to_bytes, to_str, ord, chr
def create_auth_chain_a(method): def create_auth_chain_a(method):
return auth_chain_a(method) return auth_chain_a(method)
def create_auth_chain_b(method): def create_auth_chain_b(method):
return auth_chain_b(method) return auth_chain_b(method)
obfs_map = { obfs_map = {
'auth_chain_a': (create_auth_chain_a,), 'auth_chain_a': (create_auth_chain_a,),
'auth_chain_b': (create_auth_chain_b,), 'auth_chain_b': (create_auth_chain_b,),
} }
class xorshift128plus(object): class xorshift128plus(object):
max_int = (1 << 64) - 1 max_int = (1 << 64) - 1
mov_mask = (1 << (64 - 23)) - 1 mov_mask = (1 << (64 - 23)) - 1
@ -80,12 +84,14 @@ class xorshift128plus(object):
for i in range(4): for i in range(4):
self.next() self.next()
def match_begin(str1, str2): def match_begin(str1, str2):
if len(str1) >= len(str2): if len(str1) >= len(str2):
if str1[:len(str2)] == str2: if str1[:len(str2)] == str2:
return True return True
return False return False
class auth_base(plain.plain): class auth_base(plain.plain):
def __init__(self, method): def __init__(self, method):
super(auth_base, self).__init__(method) super(auth_base, self).__init__(method)
@ -121,6 +127,7 @@ class auth_base(plain.plain):
return (b'E' * 2048, False) return (b'E' * 2048, False)
return (buf, False) return (buf, False)
class client_queue(object): class client_queue(object):
def __init__(self, begin_id): def __init__(self, begin_id):
self.front = begin_id - 64 self.front = begin_id - 64
@ -175,6 +182,7 @@ class client_queue(object):
self.addref() self.addref()
return True return True
class obfs_auth_chain_data(object): class obfs_auth_chain_data(object):
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
@ -229,6 +237,7 @@ class obfs_auth_chain_data(object):
if client_id in local_client_id: if client_id in local_client_id:
local_client_id[client_id].delref() local_client_id[client_id].delref()
class auth_chain_a(auth_base): class auth_chain_a(auth_base):
def __init__(self, method): def __init__(self, method):
super(auth_chain_a, self).__init__(method) super(auth_chain_a, self).__init__(method)
@ -362,14 +371,16 @@ class auth_chain_a(auth_base):
if self.user_key is None: if self.user_key is None:
self.user_key = self.server_info.key self.user_key = self.server_info.key
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc', b'\x00' * 16) encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc', b'\x00' * 16)
uid = struct.unpack('<I', uid)[0] ^ struct.unpack('<I', self.last_client_hash[8:12])[0] uid = struct.unpack('<I', uid)[0] ^ struct.unpack('<I', self.last_client_hash[8:12])[0]
uid = struct.pack('<I', uid) uid = struct.pack('<I', uid)
data = uid + encryptor.encrypt(data)[16:] data = uid + encryptor.encrypt(data)[16:]
self.last_server_hash = hmac.new(self.user_key, data, self.hashfunc).digest() self.last_server_hash = hmac.new(self.user_key, data, self.hashfunc).digest()
data = check_head + data + self.last_server_hash[:4] data = check_head + data + self.last_server_hash[:4]
self.encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4') self.encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4')
return data + self.pack_client_data(buf) return data + self.pack_client_data(buf)
def auth_data(self): def auth_data(self):
@ -420,7 +431,8 @@ class auth_chain_a(auth_base):
server_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest() server_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest()
if server_hash[:2] != self.recv_buf[length + 2: length + 4]: if server_hash[:2] != self.recv_buf[length + 2: length + 4]:
logging.info('%s: checksum error, data %s' % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]))) logging.info('%s: checksum error, data %s'
% (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length])))
self.raw_trans = True self.raw_trans = True
self.recv_buf = b'' self.recv_buf = b''
raise Exception('client_post_decrypt data uncorrect checksum') raise Exception('client_post_decrypt data uncorrect checksum')
@ -488,7 +500,10 @@ class auth_chain_a(auth_base):
md5data = hmac.new(self.user_key, self.recv_buf[12: 12 + 20], self.hashfunc).digest() md5data = hmac.new(self.user_key, self.recv_buf[12: 12 + 20], self.hashfunc).digest()
if md5data[:4] != self.recv_buf[32:36]: if md5data[:4] != self.recv_buf[32:36]:
logging.error('%s data uncorrect auth HMAC-MD5 from %s:%d, data %s' % (self.no_compatible_method, self.server_info.client, self.server_info.client_port, binascii.hexlify(self.recv_buf))) logging.error('%s data uncorrect auth HMAC-MD5 from %s:%d, data %s' % (
self.no_compatible_method, self.server_info.client, self.server_info.client_port,
binascii.hexlify(self.recv_buf)
))
if len(self.recv_buf) < 36: if len(self.recv_buf) < 36:
return (b'', False) return (b'', False)
return self.not_match_return(self.recv_buf) return self.not_match_return(self.recv_buf)
@ -503,7 +518,9 @@ class auth_chain_a(auth_base):
connection_id = struct.unpack('<I', head[8:12])[0] connection_id = struct.unpack('<I', head[8:12])[0]
time_dif = common.int32(utc_time - (int(time.time()) & 0xffffffff)) time_dif = common.int32(utc_time - (int(time.time()) & 0xffffffff))
if time_dif < -self.max_time_dif or time_dif > self.max_time_dif: if time_dif < -self.max_time_dif or time_dif > self.max_time_dif:
logging.info('%s: wrong timestamp, time_dif %d, data %s' % (self.no_compatible_method, time_dif, binascii.hexlify(head))) logging.info('%s: wrong timestamp, time_dif %d, data %s' % (
self.no_compatible_method, time_dif, binascii.hexlify(head)
))
return self.not_match_return(self.recv_buf) return self.not_match_return(self.recv_buf)
elif self.server_info.data.insert(self.user_id, client_id, connection_id): elif self.server_info.data.insert(self.user_id, client_id, connection_id):
self.has_recv_header = True self.has_recv_header = True
@ -513,7 +530,8 @@ class auth_chain_a(auth_base):
logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf))) logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf)))
return self.not_match_return(self.recv_buf) return self.not_match_return(self.recv_buf)
self.encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4') self.encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4')
self.recv_buf = self.recv_buf[36:] self.recv_buf = self.recv_buf[36:]
self.has_recv_header = True self.has_recv_header = True
sendback = True sendback = True
@ -537,7 +555,9 @@ class auth_chain_a(auth_base):
client_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest() client_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest()
if client_hash[:2] != self.recv_buf[length + 2: length + 4]: if client_hash[:2] != self.recv_buf[length + 2: length + 4]:
logging.info('%s: checksum error, data %s' % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]))) logging.info('%s: checksum error, data %s' % (
self.no_compatible_method, binascii.hexlify(self.recv_buf[:length])
))
self.raw_trans = True self.raw_trans = True
self.recv_buf = b'' self.recv_buf = b''
if self.recv_id == 0: if self.recv_id == 0:
@ -577,7 +597,8 @@ class auth_chain_a(auth_base):
uid = struct.unpack('<I', self.user_id)[0] ^ struct.unpack('<I', md5data[:4])[0] uid = struct.unpack('<I', self.user_id)[0] ^ struct.unpack('<I', md5data[:4])[0]
uid = struct.pack('<I', uid) uid = struct.pack('<I', uid)
rand_len = self.udp_rnd_data_len(md5data, self.random_client) rand_len = self.udp_rnd_data_len(md5data, self.random_client)
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4') encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4')
out_buf = encryptor.encrypt(buf) out_buf = encryptor.encrypt(buf)
buf = out_buf + os.urandom(rand_len) + authdata + uid buf = out_buf + os.urandom(rand_len) + authdata + uid
return buf + hmac.new(self.user_key, buf, self.hashfunc).digest()[:1] return buf + hmac.new(self.user_key, buf, self.hashfunc).digest()[:1]
@ -590,7 +611,8 @@ class auth_chain_a(auth_base):
mac_key = self.server_info.key mac_key = self.server_info.key
md5data = hmac.new(mac_key, buf[-8:-1], self.hashfunc).digest() md5data = hmac.new(mac_key, buf[-8:-1], self.hashfunc).digest()
rand_len = self.udp_rnd_data_len(md5data, self.random_server) rand_len = self.udp_rnd_data_len(md5data, self.random_server)
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4') encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4')
return encryptor.decrypt(buf[:-8 - rand_len]) return encryptor.decrypt(buf[:-8 - rand_len])
def server_udp_pre_encrypt(self, buf, uid): def server_udp_pre_encrypt(self, buf, uid):
@ -634,6 +656,7 @@ class auth_chain_a(auth_base):
def dispose(self): def dispose(self):
self.server_info.data.remove(self.user_id, self.client_id) self.server_info.data.remove(self.user_id, self.client_id)
class auth_chain_b(auth_chain_a): class auth_chain_b(auth_chain_a):
def __init__(self, method): def __init__(self, method):
super(auth_chain_b, self).__init__(method) super(auth_chain_b, self).__init__(method)
@ -701,4 +724,3 @@ class auth_chain_b(auth_chain_a):
if buf_size > 400: if buf_size > 400:
return random.next() % 521 return random.next() % 521
return random.next() % 1021 return random.next() % 1021

Loading…
Cancel
Save