diff --git a/shadowsocks/obfsplugin/auth.py b/shadowsocks/obfsplugin/auth.py index 313a170..c805796 100644 --- a/shadowsocks/obfsplugin/auth.py +++ b/shadowsocks/obfsplugin/auth.py @@ -1106,6 +1106,53 @@ class auth_aes128(auth_base): return (b'', None) return (data, None) +class obfs_auth_mu_data(object): + def __init__(self): + self.user_id = {} + self.local_client_id = b'' + self.connection_id = 0 + self.set_max_client(64) # max active client count + + def update(self, user_id, client_id, connection_id): + if user_id not in self.user_id: + self.user_id[user_id] = lru_cache.LRUCache() + local_client_id = self.user_id[user_id] + + if client_id in local_client_id: + local_client_id[client_id].update() + + def set_max_client(self, max_client): + self.max_client = max_client + self.max_buffer = max(self.max_client * 2, 1024) + + def insert(self, user_id, client_id, connection_id): + if user_id not in self.user_id: + self.user_id[user_id] = lru_cache.LRUCache() + local_client_id = self.user_id[user_id] + + if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable: + if local_client_id.first() is None or len(local_client_id) < self.max_client: + if client_id not in local_client_id: + #TODO: check + local_client_id[client_id] = client_queue(connection_id) + else: + local_client_id[client_id].re_enable(connection_id) + return local_client_id[client_id].insert(connection_id) + + if not local_client_id[local_client_id.first()].is_active(): + del local_client_id[local_client_id.first()] + if client_id not in local_client_id: + #TODO: check + local_client_id[client_id] = client_queue(connection_id) + else: + local_client_id[client_id].re_enable(connection_id) + return local_client_id[client_id].insert(connection_id) + + logging.warn('auth_aes128: no inactive client') + return False + else: + return local_client_id[client_id].insert(connection_id) + class auth_aes128_sha1(auth_base): def __init__(self, method, hashfunc): super(auth_aes128_sha1, self).__init__(method) @@ -1123,10 +1170,11 @@ class auth_aes128_sha1(auth_base): self.extra_wait_size = struct.unpack('>H', os.urandom(2))[0] % 1024 self.pack_id = 1 self.recv_id = 1 + self.user_id = None self.user_key = None def init_data(self): - return obfs_auth_v2_data() + return obfs_auth_mu_data() def set_server_info(self, server_info): self.server_info = server_info @@ -1282,6 +1330,7 @@ class auth_aes128_sha1(auth_base): uid = self.recv_buf[7:11] if uid in self.server_info.users: + self.user_id = uid self.user_key = self.hashfunc(self.server_info.users[uid]).digest() self.server_info.update_user_func(uid) else: @@ -1306,7 +1355,7 @@ class auth_aes128_sha1(auth_base): if time_dif < -self.max_time_dif or time_dif > self.max_time_dif: logging.info('%s: wrong timestamp, time_dif %d, data %s' % (self.no_compatible_method, time_dif, binascii.hexlify(head))) return self.not_match_return(self.recv_buf) - elif self.server_info.data.insert(client_id, connection_id): + elif self.server_info.data.insert(self.user_id, client_id, connection_id): self.has_recv_header = True out_buf = self.recv_buf[31 + rnd_len:length - 4] self.client_id = client_id @@ -1362,7 +1411,7 @@ class auth_aes128_sha1(auth_base): sendback = True if out_buf: - self.server_info.data.update(self.client_id, self.connection_id) + self.server_info.data.update(self.user_id, self.client_id, self.connection_id) return (out_buf, sendback) def client_udp_pre_encrypt(self, buf):