Browse Source

add typing hint and reformat code

akkariiin/Experimental
Akkariiin 6 years ago
parent
commit
1e6383685c
  1. 177
      shadowsocks/obfsplugin/auth_akarin.py

177
shadowsocks/obfsplugin/auth_akarin.py

@ -29,6 +29,8 @@ import math
import struct import struct
import hmac import hmac
import bisect import bisect
import typing
from ..obfs import (server_info as ServerInfo)
import shadowsocks import shadowsocks
from shadowsocks import common, lru_cache, encrypt from shadowsocks import common, lru_cache, encrypt
@ -38,6 +40,7 @@ from shadowsocks.crypto import openssl
rand_bytes = openssl.rand_bytes rand_bytes = openssl.rand_bytes
def create_auth_akarin_rand(method): def create_auth_akarin_rand(method):
return auth_akarin_rand(method) return auth_akarin_rand(method)
@ -46,7 +49,7 @@ def create_auth_akarin_spec_a(method):
return auth_akarin_spec_a(method) return auth_akarin_spec_a(method)
obfs_map = { obfs_map: typing.Dict[str, tuple] = {
'auth_akarin_rand': (create_auth_akarin_rand,), 'auth_akarin_rand': (create_auth_akarin_rand,),
'auth_akarin_spec_a': (create_auth_akarin_spec_a,), 'auth_akarin_spec_a': (create_auth_akarin_spec_a,),
} }
@ -60,7 +63,7 @@ class xorshift128plus(object):
self.v0 = 0 self.v0 = 0
self.v1 = 0 self.v1 = 0
def next(self): def next(self) -> int:
x = self.v0 x = self.v0
y = self.v1 y = self.v1
self.v0 = y self.v0 = y
@ -69,19 +72,20 @@ class xorshift128plus(object):
self.v1 = x self.v1 = x
return (x + y) & xorshift128plus.max_int return (x + y) & xorshift128plus.max_int
def init_from_bin(self, bin): def init_from_bin(self, bin: bytes):
if len(bin) < 16: if len(bin) < 16:
bin += b'\0' * 16 bin += b'\0' * 16
self.v0 = struct.unpack('<Q', bin[:8])[0] self.v0 = struct.unpack('<Q', bin[:8])[0]
self.v1 = struct.unpack('<Q', bin[8:16])[0] self.v1 = struct.unpack('<Q', bin[8:16])[0]
def init_from_bin_len(self, bin, length): def init_from_bin_len(self, bin: bytes, length: int):
if len(bin) < 16: if len(bin) < 16:
bin += b'\0' * 16 bin += b'\0' * 16
self.v0 = struct.unpack('<Q', struct.pack('<H', length) + bin[2:8])[0] self.v0 = struct.unpack('<Q', struct.pack('<H', length) + bin[2:8])[0]
self.v1 = struct.unpack('<Q', bin[8:16])[0] self.v1 = struct.unpack('<Q', bin[8:16])[0]
def match_begin(str1, str2):
def match_begin(str1: str, str2: str):
if len(str1) >= len(str2): if len(str1) >= len(str2):
if str1[:len(str2)] == str2: if str1[:len(str2)] == str2:
return True return True
@ -89,34 +93,35 @@ def match_begin(str1, str2):
class auth_base(plain.plain): class auth_base(plain.plain):
def __init__(self, method): def __init__(self, method: str):
super(auth_base, self).__init__(method) super(auth_base, self).__init__(method)
self.method = method self.method: str = method
self.no_compatible_method = '' self.no_compatible_method = ''
self.overhead = 4 self.overhead: int = 4
self.raw_trans: bool = False
def init_data(self): def init_data(self):
return '' return ''
def get_overhead(self, direction): # direction: true for c->s false for s->c def get_overhead(self, direction: bool) -> int: # direction: true for c->s false for s->c
return self.overhead return self.overhead
def set_server_info(self, server_info): def set_server_info(self, server_info: ServerInfo):
self.server_info = server_info self.server_info: ServerInfo = server_info
def client_encode(self, buf): def client_encode(self, buf: bytes) -> bytes:
return buf return buf
def client_decode(self, buf): def client_decode(self, buf: bytes) -> typing.Tuple[bytes, bool]:
return (buf, False) return (buf, False)
def server_encode(self, buf): def server_encode(self, buf: bytes) -> bytes:
return buf return buf
def server_decode(self, buf): def server_decode(self, buf: bytes) -> typing.Tuple[bytes, bool, bool]:
return (buf, True, False) return (buf, True, False)
def not_match_return(self, buf): def not_match_return(self, buf: bytes) -> typing.Tuple[bytes, bool]:
self.raw_trans = True self.raw_trans = True
self.overhead = 0 self.overhead = 0
if self.method == self.no_compatible_method: if self.method == self.no_compatible_method:
@ -125,13 +130,13 @@ class auth_base(plain.plain):
class client_queue(object): class client_queue(object):
def __init__(self, begin_id): def __init__(self, begin_id: int):
self.front = begin_id - 64 self.front: int = begin_id - 64
self.back = begin_id + 1 self.back: int = begin_id + 1
self.alloc = {} self.alloc: typing.Dict[int, bool] = {}
self.enable = True self.enable: bool = True
self.last_update = time.time() self.last_update: float = time.time()
self.ref = 0 self.ref: int = 0
def update(self): def update(self):
self.last_update = time.time() self.last_update = time.time()
@ -146,13 +151,13 @@ class client_queue(object):
def is_active(self): def is_active(self):
return (self.ref > 0) and (time.time() - self.last_update < 60 * 10) return (self.ref > 0) and (time.time() - self.last_update < 60 * 10)
def re_enable(self, connection_id): def re_enable(self, connection_id: int):
self.enable = True self.enable = True
self.front = connection_id - 64 self.front = connection_id - 64
self.back = connection_id + 1 self.back = connection_id + 1
self.alloc = {} self.alloc = {}
def insert(self, connection_id): def insert(self, connection_id: int) -> bool:
if not self.enable: if not self.enable:
logging.warn('obfs auth: not enable') logging.warn('obfs auth: not enable')
return False return False
@ -180,29 +185,31 @@ class client_queue(object):
class obfs_auth_akarin_data(object): class obfs_auth_akarin_data(object):
def __init__(self, name): def __init__(self, name: str):
self.name = name self.name: str = name
self.user_id = {} self.user_id: typing.Dict[int, lru_cache.LRUCache[int, client_queue]] = {}
self.local_client_id = b'' self.local_client_id: bytes = b''
self.connection_id = 0 self.connection_id: int = 0
self.max_client: int = 0
self.max_buffer: int = 0
self.set_max_client(64) # max active client count self.set_max_client(64) # max active client count
def update(self, user_id, client_id, connection_id): def update(self, user_id: int, client_id: int, connection_id: int):
if user_id not in self.user_id: if user_id not in self.user_id:
self.user_id[user_id] = lru_cache.LRUCache() self.user_id[user_id] = lru_cache.LRUCache()
local_client_id = self.user_id[user_id] local_client_id: lru_cache.LRUCache[int, client_queue] = self.user_id[user_id]
if client_id in local_client_id: if client_id in local_client_id:
local_client_id[client_id].update() local_client_id[client_id].update()
def set_max_client(self, max_client): def set_max_client(self, max_client: int):
self.max_client = max_client self.max_client: int = max_client
self.max_buffer = max(self.max_client * 2, 1024) self.max_buffer: int = max(self.max_client * 2, 1024)
def insert(self, user_id, client_id, connection_id): def insert(self, user_id: int, client_id: int, connection_id: int):
if user_id not in self.user_id: if user_id not in self.user_id:
self.user_id[user_id] = lru_cache.LRUCache() self.user_id[user_id] = lru_cache.LRUCache()
local_client_id = self.user_id[user_id] local_client_id: lru_cache.LRUCache[int, client_queue] = self.user_id[user_id]
if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable: if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable:
if local_client_id.first() is None or len(local_client_id) < self.max_client: if local_client_id.first() is None or len(local_client_id) < self.max_client:
@ -229,49 +236,49 @@ class obfs_auth_akarin_data(object):
def remove(self, user_id, client_id): def remove(self, user_id, client_id):
if user_id in self.user_id: if user_id in self.user_id:
local_client_id = self.user_id[user_id] local_client_id: lru_cache.LRUCache[int, client_queue] = self.user_id[user_id]
if client_id in local_client_id: if client_id in local_client_id:
local_client_id[client_id].delref() local_client_id[client_id].delref()
class auth_akarin_rand(auth_base): class auth_akarin_rand(auth_base):
def __init__(self, method): def __init__(self, method: str):
super(auth_akarin_rand, self).__init__(method) super(auth_akarin_rand, self).__init__(method)
self.hashfunc = hashlib.md5 self.hashfunc: function = hashlib.md5
self.recv_buf = b'' self.recv_buf: bytes = b''
self.unit_len = 2800 self.unit_len: int = 2800
self.raw_trans = False self.raw_trans: bool = False
self.has_sent_header = False self.has_sent_header: bool = False
self.has_recv_header = False self.has_recv_header: bool = False
self.client_id = 0 self.client_id: int = 0
self.connection_id = 0 self.connection_id: int = 0
self.max_time_dif = 60 * 60 * 24 # time dif (second) setting self.max_time_dif: int = 60 * 60 * 24 # time dif (second) setting
self.salt = b"auth_akarin_rand" self.salt: bytes = b"auth_akarin_rand"
self.no_compatible_method = 'auth_akarin_rand' self.no_compatible_method: str = 'auth_akarin_rand'
self.pack_id = 1 self.pack_id: int = 1
self.recv_id = 1 self.recv_id: int = 1
self.user_id = None self.user_id: bytes = None
self.user_id_num = 0 self.user_id_num: int = 0
self.user_key = None self.user_key: bytes = None
self.overhead = 4 self.overhead: int = 4
self.client_over_head = self.overhead self.client_over_head: int = self.overhead
self.last_client_hash = b'' self.last_client_hash: bytes = b''
self.last_server_hash = b'' self.last_server_hash: bytes = b''
self.random_client = xorshift128plus() self.random_client: xorshift128plus = xorshift128plus()
self.random_server = xorshift128plus() self.random_server: xorshift128plus = xorshift128plus()
self.encryptor = None self.encryptor: encrypt.Encryptor = None
self.new_send_tcp_mss = 2000 self.new_send_tcp_mss: int = 2000
self.send_tcp_mss = 2000 self.send_tcp_mss: int = 2000
self.recv_tcp_mss = 2000 self.recv_tcp_mss: int = 2000
self.send_back_cmd = [] self.send_back_cmd: typing.List[bytes] = []
def init_data(self): def init_data(self) -> obfs_auth_akarin_data:
return obfs_auth_akarin_data(self.method) return obfs_auth_akarin_data(self.method)
def get_overhead(self, direction): # direction: true for c->s false for s->c def get_overhead(self, direction: bool) -> int: # direction: true for c->s false for s->c
return self.overhead return self.overhead
def set_server_info(self, server_info): def set_server_info(self, server_info: ServerInfo):
self.server_info = server_info self.server_info = server_info
try: try:
max_client = int(server_info.protocol_param.split('#')[0]) max_client = int(server_info.protocol_param.split('#')[0])
@ -290,7 +297,7 @@ class auth_akarin_rand(auth_base):
v = self.trapezoid_random_float(d) v = self.trapezoid_random_float(d)
return int(v * max_val) return int(v * max_val)
def send_rnd_data_len(self, buf_size, last_hash, random): def send_rnd_data_len(self, buf_size: int, last_hash, random: xorshift128plus) -> int:
if buf_size + self.server_info.overhead > self.send_tcp_mss: if buf_size + self.server_info.overhead > self.send_tcp_mss:
random.init_from_bin_len(last_hash, buf_size) random.init_from_bin_len(last_hash, buf_size)
return random.next() % 521 return random.next() % 521
@ -305,7 +312,7 @@ class auth_akarin_rand(auth_base):
return random.next() % 521 return random.next() % 521
return random.next() % (self.send_tcp_mss - buf_size - self.server_info.overhead) return random.next() % (self.send_tcp_mss - buf_size - self.server_info.overhead)
def recv_rnd_data_len(self, buf_size, last_hash, random): def recv_rnd_data_len(self, buf_size, last_hash, random: xorshift128plus) -> int:
if buf_size + self.server_info.overhead > self.recv_tcp_mss: if buf_size + self.server_info.overhead > self.recv_tcp_mss:
random.init_from_bin_len(last_hash, buf_size) random.init_from_bin_len(last_hash, buf_size)
return random.next() % 521 return random.next() % 521
@ -320,11 +327,11 @@ class auth_akarin_rand(auth_base):
return random.next() % 521 return random.next() % 521
return random.next() % (self.recv_tcp_mss - buf_size - self.server_info.overhead) return random.next() % (self.recv_tcp_mss - buf_size - self.server_info.overhead)
def udp_rnd_data_len(self, last_hash, random): def udp_rnd_data_len(self, last_hash, random: xorshift128plus) -> int:
random.init_from_bin(last_hash) random.init_from_bin(last_hash)
return random.next() % 127 return random.next() % 127
def rnd_data(self, buf_size, buf, last_hash, random): def rnd_data(self, buf_size: int, buf: bytes, last_hash, random: xorshift128plus) -> bytes:
rand_len = self.send_rnd_data_len(buf_size, last_hash, random) rand_len = self.send_rnd_data_len(buf_size, last_hash, random)
rnd_data_buf = rand_bytes(rand_len) rnd_data_buf = rand_bytes(rand_len)
@ -337,7 +344,7 @@ class auth_akarin_rand(auth_base):
else: else:
return buf return buf
def pack_client_data(self, buf): def pack_client_data(self, buf: bytes) -> bytes:
buf = self.encryptor.encrypt(buf) buf = self.encryptor.encrypt(buf)
if self.send_back_cmd: if self.send_back_cmd:
cmd_len = 2 cmd_len = 2
@ -401,7 +408,8 @@ class auth_akarin_rand(auth_base):
self.last_server_hash = hmac.new(self.user_key, data, self.hashfunc).digest() self.last_server_hash = hmac.new(self.user_key, data, self.hashfunc).digest()
data = check_head + data + self.last_server_hash[:4] data = check_head + data + self.last_server_hash[:4]
self.encryptor = encrypt.Encryptor( self.encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'chacha20', self.last_client_hash[:8]) to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'chacha20',
self.last_client_hash[:8])
self.encryptor.encrypt(b'') self.encryptor.encrypt(b'')
self.encryptor.decrypt(self.last_server_hash[:8]) self.encryptor.decrypt(self.last_server_hash[:8])
return data + self.pack_client_data(buf) return data + self.pack_client_data(buf)
@ -563,7 +571,8 @@ class auth_akarin_rand(auth_base):
self.on_recv_auth_data(utc_time) self.on_recv_auth_data(utc_time)
self.encryptor = encrypt.Encryptor( self.encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'chacha20', self.last_server_hash[:8]) to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)),
'chacha20', self.last_server_hash[:8])
self.encryptor.encrypt(b'') self.encryptor.encrypt(b'')
self.encryptor.decrypt(self.last_client_hash[:8]) self.encryptor.decrypt(self.last_client_hash[:8])
self.recv_buf = self.recv_buf[36:] self.recv_buf = self.recv_buf[36:]
@ -580,7 +589,8 @@ class auth_akarin_rand(auth_base):
cmd_len += 2 cmd_len += 2
self.recv_tcp_mss = self.send_tcp_mss self.recv_tcp_mss = self.send_tcp_mss
recv_buf = recv_buf[2:] recv_buf = recv_buf[2:]
data_len = struct.unpack('<H', recv_buf[:2])[0] ^ struct.unpack('<H', self.last_client_hash[12:14])[0] data_len = struct.unpack('<H', recv_buf[:2])[0] ^ struct.unpack('<H', self.last_client_hash[12:14])[
0]
else: else:
self.raw_trans = True self.raw_trans = True
self.recv_buf = b'' self.recv_buf = b''
@ -648,7 +658,7 @@ class auth_akarin_rand(auth_base):
uid = struct.pack('<I', uid) uid = struct.pack('<I', uid)
rand_len = self.udp_rnd_data_len(md5data, self.random_client) rand_len = self.udp_rnd_data_len(md5data, self.random_client)
encryptor = encrypt.Encryptor( encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'chacha20', mac_key[:8]) to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'chacha20', mac_key[:8])
encryptor.encrypt(b'') encryptor.encrypt(b'')
out_buf = encryptor.encrypt(buf) out_buf = encryptor.encrypt(buf)
buf = out_buf + rand_bytes(rand_len) + authdata + uid buf = out_buf + rand_bytes(rand_len) + authdata + uid
@ -680,7 +690,8 @@ class auth_akarin_rand(auth_base):
mac_key = self.server_info.key mac_key = self.server_info.key
md5data = hmac.new(mac_key, authdata, self.hashfunc).digest() md5data = hmac.new(mac_key, authdata, self.hashfunc).digest()
rand_len = self.udp_rnd_data_len(md5data, self.random_server) rand_len = self.udp_rnd_data_len(md5data, self.random_server)
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(user_key)) + to_bytes(base64.b64encode(md5data)), 'chacha20', mac_key[:8]) encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(user_key)) + to_bytes(base64.b64encode(md5data)),
'chacha20', mac_key[:8])
encryptor.encrypt(b'') encryptor.encrypt(b'')
out_buf = encryptor.encrypt(buf) out_buf = encryptor.encrypt(buf)
buf = out_buf + rand_bytes(rand_len) + authdata buf = out_buf + rand_bytes(rand_len) + authdata
@ -702,7 +713,8 @@ class auth_akarin_rand(auth_base):
if hmac.new(user_key, buf[:-1], self.hashfunc).digest()[:1] != buf[-1:]: if hmac.new(user_key, buf[:-1], self.hashfunc).digest()[:1] != buf[-1:]:
return (b'', None) return (b'', None)
rand_len = self.udp_rnd_data_len(md5data, self.random_client) rand_len = self.udp_rnd_data_len(md5data, self.random_client)
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(user_key)) + to_bytes(base64.b64encode(md5data)), 'chacha20') encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(user_key)) + to_bytes(base64.b64encode(md5data)),
'chacha20')
encryptor.decrypt(mac_key[:8]) encryptor.decrypt(mac_key[:8])
out_buf = encryptor.decrypt(buf[:-8 - rand_len]) out_buf = encryptor.decrypt(buf[:-8 - rand_len])
return (out_buf, uid) return (out_buf, uid)
@ -770,7 +782,6 @@ class auth_akarin_spec_a(auth_akarin_rand):
return random.next() % 521 return random.next() % 521
return random.next() % 1021 return random.next() % 1021
def recv_rnd_data_len(self, buf_size, last_hash, random): def recv_rnd_data_len(self, buf_size, last_hash, random):
if buf_size + self.server_info.overhead > self.recv_tcp_mss: if buf_size + self.server_info.overhead > self.recv_tcp_mss:
random.init_from_bin_len(last_hash, buf_size) random.init_from_bin_len(last_hash, buf_size)
@ -797,5 +808,3 @@ class auth_akarin_spec_a(auth_akarin_rand):
if buf_size > 400: if buf_size > 400:
return random.next() % 521 return random.next() % 521
return random.next() % 1021 return random.next() % 1021

Loading…
Cancel
Save