From 1e6383685ca4153cd34bf4f8efaeb60d2e022c67 Mon Sep 17 00:00:00 2001 From: Akkariiin Date: Mon, 13 Aug 2018 00:24:04 +0800 Subject: [PATCH] add typing hint and reformat code --- shadowsocks/obfsplugin/auth_akarin.py | 177 ++++++++++++++------------ 1 file changed, 93 insertions(+), 84 deletions(-) diff --git a/shadowsocks/obfsplugin/auth_akarin.py b/shadowsocks/obfsplugin/auth_akarin.py index 58f84a3..903d5f1 100644 --- a/shadowsocks/obfsplugin/auth_akarin.py +++ b/shadowsocks/obfsplugin/auth_akarin.py @@ -29,6 +29,8 @@ import math import struct import hmac import bisect +import typing +from ..obfs import (server_info as ServerInfo) import shadowsocks from shadowsocks import common, lru_cache, encrypt @@ -38,6 +40,7 @@ from shadowsocks.crypto import openssl rand_bytes = openssl.rand_bytes + def create_auth_akarin_rand(method): return auth_akarin_rand(method) @@ -46,7 +49,7 @@ def create_auth_akarin_spec_a(method): return auth_akarin_spec_a(method) -obfs_map = { +obfs_map: typing.Dict[str, tuple] = { 'auth_akarin_rand': (create_auth_akarin_rand,), 'auth_akarin_spec_a': (create_auth_akarin_spec_a,), } @@ -60,7 +63,7 @@ class xorshift128plus(object): self.v0 = 0 self.v1 = 0 - def next(self): + def next(self) -> int: x = self.v0 y = self.v1 self.v0 = y @@ -69,19 +72,20 @@ class xorshift128plus(object): self.v1 = x return (x + y) & xorshift128plus.max_int - def init_from_bin(self, bin): + def init_from_bin(self, bin: bytes): if len(bin) < 16: bin += b'\0' * 16 self.v0 = struct.unpack('= len(str2): if str1[:len(str2)] == str2: return True @@ -89,34 +93,35 @@ def match_begin(str1, str2): class auth_base(plain.plain): - def __init__(self, method): + def __init__(self, method: str): super(auth_base, self).__init__(method) - self.method = method + self.method: str = method self.no_compatible_method = '' - self.overhead = 4 + self.overhead: int = 4 + self.raw_trans: bool = False def init_data(self): return '' - def get_overhead(self, direction): # direction: true for c->s false for s->c + def get_overhead(self, direction: bool) -> int: # direction: true for c->s false for s->c return self.overhead - def set_server_info(self, server_info): - self.server_info = server_info + def set_server_info(self, server_info: ServerInfo): + self.server_info: ServerInfo = server_info - def client_encode(self, buf): + def client_encode(self, buf: bytes) -> bytes: return buf - def client_decode(self, buf): + def client_decode(self, buf: bytes) -> typing.Tuple[bytes, bool]: return (buf, False) - def server_encode(self, buf): + def server_encode(self, buf: bytes) -> bytes: return buf - def server_decode(self, buf): + def server_decode(self, buf: bytes) -> typing.Tuple[bytes, bool, bool]: return (buf, True, False) - def not_match_return(self, buf): + def not_match_return(self, buf: bytes) -> typing.Tuple[bytes, bool]: self.raw_trans = True self.overhead = 0 if self.method == self.no_compatible_method: @@ -125,13 +130,13 @@ class auth_base(plain.plain): class client_queue(object): - def __init__(self, begin_id): - self.front = begin_id - 64 - self.back = begin_id + 1 - self.alloc = {} - self.enable = True - self.last_update = time.time() - self.ref = 0 + def __init__(self, begin_id: int): + self.front: int = begin_id - 64 + self.back: int = begin_id + 1 + self.alloc: typing.Dict[int, bool] = {} + self.enable: bool = True + self.last_update: float = time.time() + self.ref: int = 0 def update(self): self.last_update = time.time() @@ -146,13 +151,13 @@ class client_queue(object): def is_active(self): return (self.ref > 0) and (time.time() - self.last_update < 60 * 10) - def re_enable(self, connection_id): + def re_enable(self, connection_id: int): self.enable = True self.front = connection_id - 64 self.back = connection_id + 1 self.alloc = {} - def insert(self, connection_id): + def insert(self, connection_id: int) -> bool: if not self.enable: logging.warn('obfs auth: not enable') return False @@ -180,29 +185,31 @@ class client_queue(object): class obfs_auth_akarin_data(object): - def __init__(self, name): - self.name = name - self.user_id = {} - self.local_client_id = b'' - self.connection_id = 0 + def __init__(self, name: str): + self.name: str = name + self.user_id: typing.Dict[int, lru_cache.LRUCache[int, client_queue]] = {} + self.local_client_id: bytes = b'' + self.connection_id: int = 0 + self.max_client: int = 0 + self.max_buffer: int = 0 self.set_max_client(64) # max active client count - def update(self, user_id, client_id, connection_id): + def update(self, user_id: int, client_id: int, connection_id: int): if user_id not in self.user_id: self.user_id[user_id] = lru_cache.LRUCache() - local_client_id = self.user_id[user_id] + local_client_id: lru_cache.LRUCache[int, client_queue] = self.user_id[user_id] if client_id in local_client_id: local_client_id[client_id].update() - def set_max_client(self, max_client): - self.max_client = max_client - self.max_buffer = max(self.max_client * 2, 1024) + def set_max_client(self, max_client: int): + self.max_client: int = max_client + self.max_buffer: int = max(self.max_client * 2, 1024) - def insert(self, user_id, client_id, connection_id): + def insert(self, user_id: int, client_id: int, connection_id: int): if user_id not in self.user_id: self.user_id[user_id] = lru_cache.LRUCache() - local_client_id = self.user_id[user_id] + local_client_id: lru_cache.LRUCache[int, client_queue] = self.user_id[user_id] if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable: if local_client_id.first() is None or len(local_client_id) < self.max_client: @@ -229,49 +236,49 @@ class obfs_auth_akarin_data(object): def remove(self, user_id, client_id): if user_id in self.user_id: - local_client_id = self.user_id[user_id] + local_client_id: lru_cache.LRUCache[int, client_queue] = self.user_id[user_id] if client_id in local_client_id: local_client_id[client_id].delref() class auth_akarin_rand(auth_base): - def __init__(self, method): + def __init__(self, method: str): super(auth_akarin_rand, self).__init__(method) - self.hashfunc = hashlib.md5 - self.recv_buf = b'' - self.unit_len = 2800 - self.raw_trans = False - self.has_sent_header = False - self.has_recv_header = False - self.client_id = 0 - self.connection_id = 0 - self.max_time_dif = 60 * 60 * 24 # time dif (second) setting - self.salt = b"auth_akarin_rand" - self.no_compatible_method = 'auth_akarin_rand' - self.pack_id = 1 - self.recv_id = 1 - self.user_id = None - self.user_id_num = 0 - self.user_key = None - self.overhead = 4 - self.client_over_head = self.overhead - self.last_client_hash = b'' - self.last_server_hash = b'' - self.random_client = xorshift128plus() - self.random_server = xorshift128plus() - self.encryptor = None - self.new_send_tcp_mss = 2000 - self.send_tcp_mss = 2000 - self.recv_tcp_mss = 2000 - self.send_back_cmd = [] - - def init_data(self): + self.hashfunc: function = hashlib.md5 + self.recv_buf: bytes = b'' + self.unit_len: int = 2800 + self.raw_trans: bool = False + self.has_sent_header: bool = False + self.has_recv_header: bool = False + self.client_id: int = 0 + self.connection_id: int = 0 + self.max_time_dif: int = 60 * 60 * 24 # time dif (second) setting + self.salt: bytes = b"auth_akarin_rand" + self.no_compatible_method: str = 'auth_akarin_rand' + self.pack_id: int = 1 + self.recv_id: int = 1 + self.user_id: bytes = None + self.user_id_num: int = 0 + self.user_key: bytes = None + self.overhead: int = 4 + self.client_over_head: int = self.overhead + self.last_client_hash: bytes = b'' + self.last_server_hash: bytes = b'' + self.random_client: xorshift128plus = xorshift128plus() + self.random_server: xorshift128plus = xorshift128plus() + self.encryptor: encrypt.Encryptor = None + self.new_send_tcp_mss: int = 2000 + self.send_tcp_mss: int = 2000 + self.recv_tcp_mss: int = 2000 + self.send_back_cmd: typing.List[bytes] = [] + + def init_data(self) -> obfs_auth_akarin_data: return obfs_auth_akarin_data(self.method) - def get_overhead(self, direction): # direction: true for c->s false for s->c + def get_overhead(self, direction: bool) -> int: # direction: true for c->s false for s->c return self.overhead - def set_server_info(self, server_info): + def set_server_info(self, server_info: ServerInfo): self.server_info = server_info try: max_client = int(server_info.protocol_param.split('#')[0]) @@ -290,7 +297,7 @@ class auth_akarin_rand(auth_base): v = self.trapezoid_random_float(d) return int(v * max_val) - def send_rnd_data_len(self, buf_size, last_hash, random): + def send_rnd_data_len(self, buf_size: int, last_hash, random: xorshift128plus) -> int: if buf_size + self.server_info.overhead > self.send_tcp_mss: random.init_from_bin_len(last_hash, buf_size) return random.next() % 521 @@ -305,7 +312,7 @@ class auth_akarin_rand(auth_base): return random.next() % 521 return random.next() % (self.send_tcp_mss - buf_size - self.server_info.overhead) - def recv_rnd_data_len(self, buf_size, last_hash, random): + def recv_rnd_data_len(self, buf_size, last_hash, random: xorshift128plus) -> int: if buf_size + self.server_info.overhead > self.recv_tcp_mss: random.init_from_bin_len(last_hash, buf_size) return random.next() % 521 @@ -320,11 +327,11 @@ class auth_akarin_rand(auth_base): return random.next() % 521 return random.next() % (self.recv_tcp_mss - buf_size - self.server_info.overhead) - def udp_rnd_data_len(self, last_hash, random): + def udp_rnd_data_len(self, last_hash, random: xorshift128plus) -> int: random.init_from_bin(last_hash) return random.next() % 127 - def rnd_data(self, buf_size, buf, last_hash, random): + def rnd_data(self, buf_size: int, buf: bytes, last_hash, random: xorshift128plus) -> bytes: rand_len = self.send_rnd_data_len(buf_size, last_hash, random) rnd_data_buf = rand_bytes(rand_len) @@ -337,7 +344,7 @@ class auth_akarin_rand(auth_base): else: return buf - def pack_client_data(self, buf): + def pack_client_data(self, buf: bytes) -> bytes: buf = self.encryptor.encrypt(buf) if self.send_back_cmd: cmd_len = 2 @@ -401,7 +408,8 @@ class auth_akarin_rand(auth_base): self.last_server_hash = hmac.new(self.user_key, data, self.hashfunc).digest() data = check_head + data + self.last_server_hash[:4] self.encryptor = encrypt.Encryptor( - to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'chacha20', self.last_client_hash[:8]) + to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'chacha20', + self.last_client_hash[:8]) self.encryptor.encrypt(b'') self.encryptor.decrypt(self.last_server_hash[:8]) return data + self.pack_client_data(buf) @@ -563,7 +571,8 @@ class auth_akarin_rand(auth_base): self.on_recv_auth_data(utc_time) self.encryptor = encrypt.Encryptor( - to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'chacha20', self.last_server_hash[:8]) + to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), + 'chacha20', self.last_server_hash[:8]) self.encryptor.encrypt(b'') self.encryptor.decrypt(self.last_client_hash[:8]) self.recv_buf = self.recv_buf[36:] @@ -580,7 +589,8 @@ class auth_akarin_rand(auth_base): cmd_len += 2 self.recv_tcp_mss = self.send_tcp_mss recv_buf = recv_buf[2:] - data_len = struct.unpack(' self.recv_tcp_mss: random.init_from_bin_len(last_hash, buf_size) @@ -797,5 +808,3 @@ class auth_akarin_spec_a(auth_akarin_rand): if buf_size > 400: return random.next() % 521 return random.next() % 1021 - -