Browse Source

add typing hint and reformat code

akkariiin/Experimental
Akkariiin 6 years ago
parent
commit
1e6383685c
  1. 175
      shadowsocks/obfsplugin/auth_akarin.py

175
shadowsocks/obfsplugin/auth_akarin.py

@ -29,6 +29,8 @@ import math
import struct
import hmac
import bisect
import typing
from ..obfs import (server_info as ServerInfo)
import shadowsocks
from shadowsocks import common, lru_cache, encrypt
@ -38,6 +40,7 @@ from shadowsocks.crypto import openssl
rand_bytes = openssl.rand_bytes
def create_auth_akarin_rand(method):
return auth_akarin_rand(method)
@ -46,7 +49,7 @@ def create_auth_akarin_spec_a(method):
return auth_akarin_spec_a(method)
obfs_map = {
obfs_map: typing.Dict[str, tuple] = {
'auth_akarin_rand': (create_auth_akarin_rand,),
'auth_akarin_spec_a': (create_auth_akarin_spec_a,),
}
@ -60,7 +63,7 @@ class xorshift128plus(object):
self.v0 = 0
self.v1 = 0
def next(self):
def next(self) -> int:
x = self.v0
y = self.v1
self.v0 = y
@ -69,19 +72,20 @@ class xorshift128plus(object):
self.v1 = x
return (x + y) & xorshift128plus.max_int
def init_from_bin(self, bin):
def init_from_bin(self, bin: bytes):
if len(bin) < 16:
bin += b'\0' * 16
self.v0 = struct.unpack('<Q', bin[:8])[0]
self.v1 = struct.unpack('<Q', bin[8:16])[0]
def init_from_bin_len(self, bin, length):
def init_from_bin_len(self, bin: bytes, length: int):
if len(bin) < 16:
bin += b'\0' * 16
self.v0 = struct.unpack('<Q', struct.pack('<H', length) + bin[2:8])[0]
self.v1 = struct.unpack('<Q', bin[8:16])[0]
def match_begin(str1, str2):
def match_begin(str1: str, str2: str):
if len(str1) >= len(str2):
if str1[:len(str2)] == str2:
return True
@ -89,34 +93,35 @@ def match_begin(str1, str2):
class auth_base(plain.plain):
def __init__(self, method):
def __init__(self, method: str):
super(auth_base, self).__init__(method)
self.method = method
self.method: str = method
self.no_compatible_method = ''
self.overhead = 4
self.overhead: int = 4
self.raw_trans: bool = False
def init_data(self):
return ''
def get_overhead(self, direction): # direction: true for c->s false for s->c
def get_overhead(self, direction: bool) -> int: # direction: true for c->s false for s->c
return self.overhead
def set_server_info(self, server_info):
self.server_info = server_info
def set_server_info(self, server_info: ServerInfo):
self.server_info: ServerInfo = server_info
def client_encode(self, buf):
def client_encode(self, buf: bytes) -> bytes:
return buf
def client_decode(self, buf):
def client_decode(self, buf: bytes) -> typing.Tuple[bytes, bool]:
return (buf, False)
def server_encode(self, buf):
def server_encode(self, buf: bytes) -> bytes:
return buf
def server_decode(self, buf):
def server_decode(self, buf: bytes) -> typing.Tuple[bytes, bool, bool]:
return (buf, True, False)
def not_match_return(self, buf):
def not_match_return(self, buf: bytes) -> typing.Tuple[bytes, bool]:
self.raw_trans = True
self.overhead = 0
if self.method == self.no_compatible_method:
@ -125,13 +130,13 @@ class auth_base(plain.plain):
class client_queue(object):
def __init__(self, begin_id):
self.front = begin_id - 64
self.back = begin_id + 1
self.alloc = {}
self.enable = True
self.last_update = time.time()
self.ref = 0
def __init__(self, begin_id: int):
self.front: int = begin_id - 64
self.back: int = begin_id + 1
self.alloc: typing.Dict[int, bool] = {}
self.enable: bool = True
self.last_update: float = time.time()
self.ref: int = 0
def update(self):
self.last_update = time.time()
@ -146,13 +151,13 @@ class client_queue(object):
def is_active(self):
return (self.ref > 0) and (time.time() - self.last_update < 60 * 10)
def re_enable(self, connection_id):
def re_enable(self, connection_id: int):
self.enable = True
self.front = connection_id - 64
self.back = connection_id + 1
self.alloc = {}
def insert(self, connection_id):
def insert(self, connection_id: int) -> bool:
if not self.enable:
logging.warn('obfs auth: not enable')
return False
@ -180,29 +185,31 @@ class client_queue(object):
class obfs_auth_akarin_data(object):
def __init__(self, name):
self.name = name
self.user_id = {}
self.local_client_id = b''
self.connection_id = 0
def __init__(self, name: str):
self.name: str = name
self.user_id: typing.Dict[int, lru_cache.LRUCache[int, client_queue]] = {}
self.local_client_id: bytes = b''
self.connection_id: int = 0
self.max_client: int = 0
self.max_buffer: int = 0
self.set_max_client(64) # max active client count
def update(self, user_id, client_id, connection_id):
def update(self, user_id: int, client_id: int, connection_id: int):
if user_id not in self.user_id:
self.user_id[user_id] = lru_cache.LRUCache()
local_client_id = self.user_id[user_id]
local_client_id: lru_cache.LRUCache[int, client_queue] = self.user_id[user_id]
if client_id in local_client_id:
local_client_id[client_id].update()
def set_max_client(self, max_client):
self.max_client = max_client
self.max_buffer = max(self.max_client * 2, 1024)
def set_max_client(self, max_client: int):
self.max_client: int = max_client
self.max_buffer: int = max(self.max_client * 2, 1024)
def insert(self, user_id, client_id, connection_id):
def insert(self, user_id: int, client_id: int, connection_id: int):
if user_id not in self.user_id:
self.user_id[user_id] = lru_cache.LRUCache()
local_client_id = self.user_id[user_id]
local_client_id: lru_cache.LRUCache[int, client_queue] = self.user_id[user_id]
if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable:
if local_client_id.first() is None or len(local_client_id) < self.max_client:
@ -229,49 +236,49 @@ class obfs_auth_akarin_data(object):
def remove(self, user_id, client_id):
if user_id in self.user_id:
local_client_id = self.user_id[user_id]
local_client_id: lru_cache.LRUCache[int, client_queue] = self.user_id[user_id]
if client_id in local_client_id:
local_client_id[client_id].delref()
class auth_akarin_rand(auth_base):
def __init__(self, method):
def __init__(self, method: str):
super(auth_akarin_rand, self).__init__(method)
self.hashfunc = hashlib.md5
self.recv_buf = b''
self.unit_len = 2800
self.raw_trans = False
self.has_sent_header = False
self.has_recv_header = False
self.client_id = 0
self.connection_id = 0
self.max_time_dif = 60 * 60 * 24 # time dif (second) setting
self.salt = b"auth_akarin_rand"
self.no_compatible_method = 'auth_akarin_rand'
self.pack_id = 1
self.recv_id = 1
self.user_id = None
self.user_id_num = 0
self.user_key = None
self.overhead = 4
self.client_over_head = self.overhead
self.last_client_hash = b''
self.last_server_hash = b''
self.random_client = xorshift128plus()
self.random_server = xorshift128plus()
self.encryptor = None
self.new_send_tcp_mss = 2000
self.send_tcp_mss = 2000
self.recv_tcp_mss = 2000
self.send_back_cmd = []
def init_data(self):
self.hashfunc: function = hashlib.md5
self.recv_buf: bytes = b''
self.unit_len: int = 2800
self.raw_trans: bool = False
self.has_sent_header: bool = False
self.has_recv_header: bool = False
self.client_id: int = 0
self.connection_id: int = 0
self.max_time_dif: int = 60 * 60 * 24 # time dif (second) setting
self.salt: bytes = b"auth_akarin_rand"
self.no_compatible_method: str = 'auth_akarin_rand'
self.pack_id: int = 1
self.recv_id: int = 1
self.user_id: bytes = None
self.user_id_num: int = 0
self.user_key: bytes = None
self.overhead: int = 4
self.client_over_head: int = self.overhead
self.last_client_hash: bytes = b''
self.last_server_hash: bytes = b''
self.random_client: xorshift128plus = xorshift128plus()
self.random_server: xorshift128plus = xorshift128plus()
self.encryptor: encrypt.Encryptor = None
self.new_send_tcp_mss: int = 2000
self.send_tcp_mss: int = 2000
self.recv_tcp_mss: int = 2000
self.send_back_cmd: typing.List[bytes] = []
def init_data(self) -> obfs_auth_akarin_data:
return obfs_auth_akarin_data(self.method)
def get_overhead(self, direction): # direction: true for c->s false for s->c
def get_overhead(self, direction: bool) -> int: # direction: true for c->s false for s->c
return self.overhead
def set_server_info(self, server_info):
def set_server_info(self, server_info: ServerInfo):
self.server_info = server_info
try:
max_client = int(server_info.protocol_param.split('#')[0])
@ -290,7 +297,7 @@ class auth_akarin_rand(auth_base):
v = self.trapezoid_random_float(d)
return int(v * max_val)
def send_rnd_data_len(self, buf_size, last_hash, random):
def send_rnd_data_len(self, buf_size: int, last_hash, random: xorshift128plus) -> int:
if buf_size + self.server_info.overhead > self.send_tcp_mss:
random.init_from_bin_len(last_hash, buf_size)
return random.next() % 521
@ -305,7 +312,7 @@ class auth_akarin_rand(auth_base):
return random.next() % 521
return random.next() % (self.send_tcp_mss - buf_size - self.server_info.overhead)
def recv_rnd_data_len(self, buf_size, last_hash, random):
def recv_rnd_data_len(self, buf_size, last_hash, random: xorshift128plus) -> int:
if buf_size + self.server_info.overhead > self.recv_tcp_mss:
random.init_from_bin_len(last_hash, buf_size)
return random.next() % 521
@ -320,11 +327,11 @@ class auth_akarin_rand(auth_base):
return random.next() % 521
return random.next() % (self.recv_tcp_mss - buf_size - self.server_info.overhead)
def udp_rnd_data_len(self, last_hash, random):
def udp_rnd_data_len(self, last_hash, random: xorshift128plus) -> int:
random.init_from_bin(last_hash)
return random.next() % 127
def rnd_data(self, buf_size, buf, last_hash, random):
def rnd_data(self, buf_size: int, buf: bytes, last_hash, random: xorshift128plus) -> bytes:
rand_len = self.send_rnd_data_len(buf_size, last_hash, random)
rnd_data_buf = rand_bytes(rand_len)
@ -337,7 +344,7 @@ class auth_akarin_rand(auth_base):
else:
return buf
def pack_client_data(self, buf):
def pack_client_data(self, buf: bytes) -> bytes:
buf = self.encryptor.encrypt(buf)
if self.send_back_cmd:
cmd_len = 2
@ -401,7 +408,8 @@ class auth_akarin_rand(auth_base):
self.last_server_hash = hmac.new(self.user_key, data, self.hashfunc).digest()
data = check_head + data + self.last_server_hash[:4]
self.encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'chacha20', self.last_client_hash[:8])
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'chacha20',
self.last_client_hash[:8])
self.encryptor.encrypt(b'')
self.encryptor.decrypt(self.last_server_hash[:8])
return data + self.pack_client_data(buf)
@ -563,7 +571,8 @@ class auth_akarin_rand(auth_base):
self.on_recv_auth_data(utc_time)
self.encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'chacha20', self.last_server_hash[:8])
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)),
'chacha20', self.last_server_hash[:8])
self.encryptor.encrypt(b'')
self.encryptor.decrypt(self.last_client_hash[:8])
self.recv_buf = self.recv_buf[36:]
@ -580,7 +589,8 @@ class auth_akarin_rand(auth_base):
cmd_len += 2
self.recv_tcp_mss = self.send_tcp_mss
recv_buf = recv_buf[2:]
data_len = struct.unpack('<H', recv_buf[:2])[0] ^ struct.unpack('<H', self.last_client_hash[12:14])[0]
data_len = struct.unpack('<H', recv_buf[:2])[0] ^ struct.unpack('<H', self.last_client_hash[12:14])[
0]
else:
self.raw_trans = True
self.recv_buf = b''
@ -680,7 +690,8 @@ class auth_akarin_rand(auth_base):
mac_key = self.server_info.key
md5data = hmac.new(mac_key, authdata, self.hashfunc).digest()
rand_len = self.udp_rnd_data_len(md5data, self.random_server)
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(user_key)) + to_bytes(base64.b64encode(md5data)), 'chacha20', mac_key[:8])
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(user_key)) + to_bytes(base64.b64encode(md5data)),
'chacha20', mac_key[:8])
encryptor.encrypt(b'')
out_buf = encryptor.encrypt(buf)
buf = out_buf + rand_bytes(rand_len) + authdata
@ -702,7 +713,8 @@ class auth_akarin_rand(auth_base):
if hmac.new(user_key, buf[:-1], self.hashfunc).digest()[:1] != buf[-1:]:
return (b'', None)
rand_len = self.udp_rnd_data_len(md5data, self.random_client)
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(user_key)) + to_bytes(base64.b64encode(md5data)), 'chacha20')
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(user_key)) + to_bytes(base64.b64encode(md5data)),
'chacha20')
encryptor.decrypt(mac_key[:8])
out_buf = encryptor.decrypt(buf[:-8 - rand_len])
return (out_buf, uid)
@ -770,7 +782,6 @@ class auth_akarin_spec_a(auth_akarin_rand):
return random.next() % 521
return random.next() % 1021
def recv_rnd_data_len(self, buf_size, last_hash, random):
if buf_size + self.server_info.overhead > self.recv_tcp_mss:
random.init_from_bin_len(last_hash, buf_size)
@ -797,5 +808,3 @@ class auth_akarin_spec_a(auth_akarin_rand):
if buf_size > 400:
return random.next() % 521
return random.next() % 1021

Loading…
Cancel
Save