diff --git a/shadowsocks/obfsplugin/auth_chain.py b/shadowsocks/obfsplugin/auth_chain.py index cf2b341..3cff2ec 100644 --- a/shadowsocks/obfsplugin/auth_chain.py +++ b/shadowsocks/obfsplugin/auth_chain.py @@ -38,17 +38,21 @@ from shadowsocks import common, lru_cache, encrypt from shadowsocks.obfsplugin import plain from shadowsocks.common import to_bytes, to_str, ord, chr + def create_auth_chain_a(method): return auth_chain_a(method) + def create_auth_chain_b(method): return auth_chain_b(method) + obfs_map = { - 'auth_chain_a': (create_auth_chain_a,), - 'auth_chain_b': (create_auth_chain_b,), + 'auth_chain_a': (create_auth_chain_a,), + 'auth_chain_b': (create_auth_chain_b,), } + class xorshift128plus(object): max_int = (1 << 64) - 1 mov_mask = (1 << (64 - 23)) - 1 @@ -80,12 +84,14 @@ class xorshift128plus(object): for i in range(4): self.next() + def match_begin(str1, str2): if len(str1) >= len(str2): if str1[:len(str2)] == str2: return True return False + class auth_base(plain.plain): def __init__(self, method): super(auth_base, self).__init__(method) @@ -96,7 +102,7 @@ class auth_base(plain.plain): def init_data(self): return '' - def get_overhead(self, direction): # direction: true for c->s false for s->c + def get_overhead(self, direction): # direction: true for c->s false for s->c return self.overhead def set_server_info(self, server_info): @@ -118,9 +124,10 @@ class auth_base(plain.plain): self.raw_trans = True self.overhead = 0 if self.method == self.no_compatible_method: - return (b'E'*2048, False) + return (b'E' * 2048, False) return (buf, False) + class client_queue(object): def __init__(self, begin_id): self.front = begin_id - 64 @@ -175,13 +182,14 @@ class client_queue(object): self.addref() return True + class obfs_auth_chain_data(object): def __init__(self, name): self.name = name self.user_id = {} self.local_client_id = b'' self.connection_id = 0 - self.set_max_client(64) # max active client count + self.set_max_client(64) # max active client count def update(self, user_id, client_id, connection_id): if user_id not in self.user_id: @@ -203,7 +211,7 @@ class obfs_auth_chain_data(object): if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable: if local_client_id.first() is None or len(local_client_id) < self.max_client: if client_id not in local_client_id: - #TODO: check + # TODO: check local_client_id[client_id] = client_queue(connection_id) else: local_client_id[client_id].re_enable(connection_id) @@ -212,7 +220,7 @@ class obfs_auth_chain_data(object): if not local_client_id[local_client_id.first()].is_active(): del local_client_id[local_client_id.first()] if client_id not in local_client_id: - #TODO: check + # TODO: check local_client_id[client_id] = client_queue(connection_id) else: local_client_id[client_id].re_enable(connection_id) @@ -229,6 +237,7 @@ class obfs_auth_chain_data(object): if client_id in local_client_id: local_client_id[client_id].delref() + class auth_chain_a(auth_base): def __init__(self, method): super(auth_chain_a, self).__init__(method) @@ -240,7 +249,7 @@ class auth_chain_a(auth_base): self.has_recv_header = False self.client_id = 0 self.connection_id = 0 - self.max_time_dif = 60 * 60 * 24 # time dif (second) setting + self.max_time_dif = 60 * 60 * 24 # time dif (second) setting self.salt = b"auth_chain_a" self.no_compatible_method = 'auth_chain_a' self.pack_id = 1 @@ -259,7 +268,7 @@ class auth_chain_a(auth_base): def init_data(self): return obfs_auth_chain_data(self.method) - def get_overhead(self, direction): # direction: true for c->s false for s->c + def get_overhead(self, direction): # direction: true for c->s false for s->c return self.overhead def set_server_info(self, server_info): @@ -362,14 +371,16 @@ class auth_chain_a(auth_base): if self.user_key is None: self.user_key = self.server_info.key - encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc', b'\x00' * 16) + encryptor = encrypt.Encryptor( + to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc', b'\x00' * 16) uid = struct.unpack(' 0 and rand_len > 0: pos = 2 + self.rnd_start_pos(rand_len, self.random_server) - out_buf += self.encryptor.decrypt(self.recv_buf[pos : data_len + pos]) + out_buf += self.encryptor.decrypt(self.recv_buf[pos: data_len + pos]) self.last_server_hash = server_hash if self.recv_id == 1: self.server_info.tcp_mss = struct.unpack(' self.max_time_dif: - logging.info('%s: wrong timestamp, time_dif %d, data %s' % (self.no_compatible_method, time_dif, binascii.hexlify(head))) + logging.info('%s: wrong timestamp, time_dif %d, data %s' % ( + self.no_compatible_method, time_dif, binascii.hexlify(head) + )) return self.not_match_return(self.recv_buf) elif self.server_info.data.insert(self.user_id, client_id, connection_id): self.has_recv_header = True @@ -513,7 +530,8 @@ class auth_chain_a(auth_base): logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf))) return self.not_match_return(self.recv_buf) - self.encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4') + self.encryptor = encrypt.Encryptor( + to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4') self.recv_buf = self.recv_buf[36:] self.has_recv_header = True sendback = True @@ -528,7 +546,7 @@ class auth_chain_a(auth_base): self.recv_buf = b'' if self.recv_id == 0: logging.info(self.no_compatible_method + ': over size') - return (b'E'*2048, False) + return (b'E' * 2048, False) else: raise Exception('server_post_decrype data error') @@ -536,12 +554,14 @@ class auth_chain_a(auth_base): break client_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest() - if client_hash[:2] != self.recv_buf[length + 2 : length + 4]: - logging.info('%s: checksum error, data %s' % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]))) + if client_hash[:2] != self.recv_buf[length + 2: length + 4]: + logging.info('%s: checksum error, data %s' % ( + self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]) + )) self.raw_trans = True self.recv_buf = b'' if self.recv_id == 0: - return (b'E'*2048, False) + return (b'E' * 2048, False) else: raise Exception('server_post_decrype data uncorrect checksum') @@ -549,7 +569,7 @@ class auth_chain_a(auth_base): pos = 2 if data_len > 0 and rand_len > 0: pos = 2 + self.rnd_start_pos(rand_len, self.random_client) - out_buf += self.encryptor.decrypt(self.recv_buf[pos : data_len + pos]) + out_buf += self.encryptor.decrypt(self.recv_buf[pos: data_len + pos]) self.last_client_hash = client_hash self.recv_buf = self.recv_buf[length + 4:] if data_len == 0: @@ -577,7 +597,8 @@ class auth_chain_a(auth_base): uid = struct.unpack(' 400: return random.next() % 521 return random.next() % 1021 -