@ -1,4 +1,5 @@
#!/usr/bin/env python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
#
# Copyright 2015-2015 breakwa11
# Copyright 2015-2015 breakwa11
#
#
@ -38,17 +39,31 @@ from shadowsocks import common, lru_cache, encrypt
from shadowsocks . obfsplugin import plain
from shadowsocks . obfsplugin import plain
from shadowsocks . common import to_bytes , to_str , ord , chr
from shadowsocks . common import to_bytes , to_str , ord , chr
def create_auth_chain_a ( method ) :
def create_auth_chain_a ( method ) :
return auth_chain_a ( method )
return auth_chain_a ( method )
def create_auth_chain_b ( method ) :
def create_auth_chain_b ( method ) :
return auth_chain_b ( method )
return auth_chain_b ( method )
def create_auth_chain_c ( method ) :
return auth_chain_c ( method )
def create_auth_chain_d ( method ) :
return auth_chain_d ( method )
obfs_map = {
obfs_map = {
' auth_chain_a ' : ( create_auth_chain_a , ) ,
' auth_chain_a ' : ( create_auth_chain_a , ) ,
' auth_chain_b ' : ( create_auth_chain_b , ) ,
' auth_chain_b ' : ( create_auth_chain_b , ) ,
' auth_chain_c ' : ( create_auth_chain_c , ) ,
' auth_chain_d ' : ( create_auth_chain_d , ) ,
}
}
class xorshift128plus ( object ) :
class xorshift128plus ( object ) :
max_int = ( 1 << 64 ) - 1
max_int = ( 1 << 64 ) - 1
mov_mask = ( 1 << ( 64 - 23 ) ) - 1
mov_mask = ( 1 << ( 64 - 23 ) ) - 1
@ -80,12 +95,14 @@ class xorshift128plus(object):
for i in range ( 4 ) :
for i in range ( 4 ) :
self . next ( )
self . next ( )
def match_begin ( str1 , str2 ) :
def match_begin ( str1 , str2 ) :
if len ( str1 ) > = len ( str2 ) :
if len ( str1 ) > = len ( str2 ) :
if str1 [ : len ( str2 ) ] == str2 :
if str1 [ : len ( str2 ) ] == str2 :
return True
return True
return False
return False
class auth_base ( plain . plain ) :
class auth_base ( plain . plain ) :
def __init__ ( self , method ) :
def __init__ ( self , method ) :
super ( auth_base , self ) . __init__ ( method )
super ( auth_base , self ) . __init__ ( method )
@ -96,7 +113,7 @@ class auth_base(plain.plain):
def init_data ( self ) :
def init_data ( self ) :
return ' '
return ' '
def get_overhead ( self , direction ) : # direction: true for c->s false for s->c
def get_overhead ( self , direction ) : # direction: true for c->s false for s->c
return self . overhead
return self . overhead
def set_server_info ( self , server_info ) :
def set_server_info ( self , server_info ) :
@ -118,9 +135,10 @@ class auth_base(plain.plain):
self . raw_trans = True
self . raw_trans = True
self . overhead = 0
self . overhead = 0
if self . method == self . no_compatible_method :
if self . method == self . no_compatible_method :
return ( b ' E ' * 2048 , False )
return ( b ' E ' * 2048 , False )
return ( buf , False )
return ( buf , False )
class client_queue ( object ) :
class client_queue ( object ) :
def __init__ ( self , begin_id ) :
def __init__ ( self , begin_id ) :
self . front = begin_id - 64
self . front = begin_id - 64
@ -175,13 +193,14 @@ class client_queue(object):
self . addref ( )
self . addref ( )
return True
return True
class obfs_auth_chain_data ( object ) :
class obfs_auth_chain_data ( object ) :
def __init__ ( self , name ) :
def __init__ ( self , name ) :
self . name = name
self . name = name
self . user_id = { }
self . user_id = { }
self . local_client_id = b ' '
self . local_client_id = b ' '
self . connection_id = 0
self . connection_id = 0
self . set_max_client ( 64 ) # max active client count
self . set_max_client ( 64 ) # max active client count
def update ( self , user_id , client_id , connection_id ) :
def update ( self , user_id , client_id , connection_id ) :
if user_id not in self . user_id :
if user_id not in self . user_id :
@ -203,7 +222,7 @@ class obfs_auth_chain_data(object):
if local_client_id . get ( client_id , None ) is None or not local_client_id [ client_id ] . enable :
if local_client_id . get ( client_id , None ) is None or not local_client_id [ client_id ] . enable :
if local_client_id . first ( ) is None or len ( local_client_id ) < self . max_client :
if local_client_id . first ( ) is None or len ( local_client_id ) < self . max_client :
if client_id not in local_client_id :
if client_id not in local_client_id :
#TODO: check
# TODO: check
local_client_id [ client_id ] = client_queue ( connection_id )
local_client_id [ client_id ] = client_queue ( connection_id )
else :
else :
local_client_id [ client_id ] . re_enable ( connection_id )
local_client_id [ client_id ] . re_enable ( connection_id )
@ -212,7 +231,7 @@ class obfs_auth_chain_data(object):
if not local_client_id [ local_client_id . first ( ) ] . is_active ( ) :
if not local_client_id [ local_client_id . first ( ) ] . is_active ( ) :
del local_client_id [ local_client_id . first ( ) ]
del local_client_id [ local_client_id . first ( ) ]
if client_id not in local_client_id :
if client_id not in local_client_id :
#TODO: check
# TODO: check
local_client_id [ client_id ] = client_queue ( connection_id )
local_client_id [ client_id ] = client_queue ( connection_id )
else :
else :
local_client_id [ client_id ] . re_enable ( connection_id )
local_client_id [ client_id ] . re_enable ( connection_id )
@ -229,6 +248,7 @@ class obfs_auth_chain_data(object):
if client_id in local_client_id :
if client_id in local_client_id :
local_client_id [ client_id ] . delref ( )
local_client_id [ client_id ] . delref ( )
class auth_chain_a ( auth_base ) :
class auth_chain_a ( auth_base ) :
def __init__ ( self , method ) :
def __init__ ( self , method ) :
super ( auth_chain_a , self ) . __init__ ( method )
super ( auth_chain_a , self ) . __init__ ( method )
@ -240,7 +260,7 @@ class auth_chain_a(auth_base):
self . has_recv_header = False
self . has_recv_header = False
self . client_id = 0
self . client_id = 0
self . connection_id = 0
self . connection_id = 0
self . max_time_dif = 60 * 60 * 24 # time dif (second) setting
self . max_time_dif = 60 * 60 * 24 # time dif (second) setting
self . salt = b " auth_chain_a "
self . salt = b " auth_chain_a "
self . no_compatible_method = ' auth_chain_a '
self . no_compatible_method = ' auth_chain_a '
self . pack_id = 1
self . pack_id = 1
@ -259,7 +279,7 @@ class auth_chain_a(auth_base):
def init_data ( self ) :
def init_data ( self ) :
return obfs_auth_chain_data ( self . method )
return obfs_auth_chain_data ( self . method )
def get_overhead ( self , direction ) : # direction: true for c->s false for s->c
def get_overhead ( self , direction ) : # direction: true for c->s false for s->c
return self . overhead
return self . overhead
def set_server_info ( self , server_info ) :
def set_server_info ( self , server_info ) :
@ -362,14 +382,16 @@ class auth_chain_a(auth_base):
if self . user_key is None :
if self . user_key is None :
self . user_key = self . server_info . key
self . user_key = self . server_info . key
encryptor = encrypt . Encryptor ( to_bytes ( base64 . b64encode ( self . user_key ) ) + self . salt , ' aes-128-cbc ' , b ' \x00 ' * 16 )
encryptor = encrypt . Encryptor (
to_bytes ( base64 . b64encode ( self . user_key ) ) + self . salt , ' aes-128-cbc ' , b ' \x00 ' * 16 )
uid = struct . unpack ( ' <I ' , uid ) [ 0 ] ^ struct . unpack ( ' <I ' , self . last_client_hash [ 8 : 12 ] ) [ 0 ]
uid = struct . unpack ( ' <I ' , uid ) [ 0 ] ^ struct . unpack ( ' <I ' , self . last_client_hash [ 8 : 12 ] ) [ 0 ]
uid = struct . pack ( ' <I ' , uid )
uid = struct . pack ( ' <I ' , uid )
data = uid + encryptor . encrypt ( data ) [ 16 : ]
data = uid + encryptor . encrypt ( data ) [ 16 : ]
self . last_server_hash = hmac . new ( self . user_key , data , self . hashfunc ) . digest ( )
self . last_server_hash = hmac . new ( self . user_key , data , self . hashfunc ) . digest ( )
data = check_head + data + self . last_server_hash [ : 4 ]
data = check_head + data + self . last_server_hash [ : 4 ]
self . encryptor = encrypt . Encryptor ( to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( self . last_client_hash ) ) , ' rc4 ' )
self . encryptor = encrypt . Encryptor (
to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( self . last_client_hash ) ) , ' rc4 ' )
return data + self . pack_client_data ( buf )
return data + self . pack_client_data ( buf )
def auth_data ( self ) :
def auth_data ( self ) :
@ -382,8 +404,8 @@ class auth_chain_a(auth_base):
self . server_info . data . connection_id = struct . unpack ( ' <I ' , os . urandom ( 4 ) ) [ 0 ] & 0xFFFFFF
self . server_info . data . connection_id = struct . unpack ( ' <I ' , os . urandom ( 4 ) ) [ 0 ] & 0xFFFFFF
self . server_info . data . connection_id + = 1
self . server_info . data . connection_id + = 1
return b ' ' . join ( [ struct . pack ( ' <I ' , utc_time ) ,
return b ' ' . join ( [ struct . pack ( ' <I ' , utc_time ) ,
self . server_info . data . local_client_id ,
self . server_info . data . local_client_id ,
struct . pack ( ' <I ' , self . server_info . data . connection_id ) ] )
struct . pack ( ' <I ' , self . server_info . data . connection_id ) ] )
def client_pre_encrypt ( self , buf ) :
def client_pre_encrypt ( self , buf ) :
ret = b ' '
ret = b ' '
@ -419,8 +441,9 @@ class auth_chain_a(auth_base):
break
break
server_hash = hmac . new ( mac_key , self . recv_buf [ : length + 2 ] , self . hashfunc ) . digest ( )
server_hash = hmac . new ( mac_key , self . recv_buf [ : length + 2 ] , self . hashfunc ) . digest ( )
if server_hash [ : 2 ] != self . recv_buf [ length + 2 : length + 4 ] :
if server_hash [ : 2 ] != self . recv_buf [ length + 2 : length + 4 ] :
logging . info ( ' %s : checksum error, data %s ' % ( self . no_compatible_method , binascii . hexlify ( self . recv_buf [ : length ] ) ) )
logging . info ( ' %s : checksum error, data %s '
% ( self . no_compatible_method , binascii . hexlify ( self . recv_buf [ : length ] ) ) )
self . raw_trans = True
self . raw_trans = True
self . recv_buf = b ' '
self . recv_buf = b ' '
raise Exception ( ' client_post_decrypt data uncorrect checksum ' )
raise Exception ( ' client_post_decrypt data uncorrect checksum ' )
@ -428,7 +451,7 @@ class auth_chain_a(auth_base):
pos = 2
pos = 2
if data_len > 0 and rand_len > 0 :
if data_len > 0 and rand_len > 0 :
pos = 2 + self . rnd_start_pos ( rand_len , self . random_server )
pos = 2 + self . rnd_start_pos ( rand_len , self . random_server )
out_buf + = self . encryptor . decrypt ( self . recv_buf [ pos : data_len + pos ] )
out_buf + = self . encryptor . decrypt ( self . recv_buf [ pos : data_len + pos ] )
self . last_server_hash = server_hash
self . last_server_hash = server_hash
if self . recv_id == 1 :
if self . recv_id == 1 :
self . server_info . tcp_mss = struct . unpack ( ' <H ' , out_buf [ : 2 ] ) [ 0 ]
self . server_info . tcp_mss = struct . unpack ( ' <H ' , out_buf [ : 2 ] ) [ 0 ]
@ -486,16 +509,19 @@ class auth_chain_a(auth_base):
else :
else :
self . user_key = self . server_info . recv_iv
self . user_key = self . server_info . recv_iv
md5data = hmac . new ( self . user_key , self . recv_buf [ 12 : 12 + 20 ] , self . hashfunc ) . digest ( )
md5data = hmac . new ( self . user_key , self . recv_buf [ 12 : 12 + 20 ] , self . hashfunc ) . digest ( )
if md5data [ : 4 ] != self . recv_buf [ 32 : 36 ] :
if md5data [ : 4 ] != self . recv_buf [ 32 : 36 ] :
logging . error ( ' %s data uncorrect auth HMAC-MD5 from %s : %d , data %s ' % ( self . no_compatible_method , self . server_info . client , self . server_info . client_port , binascii . hexlify ( self . recv_buf ) ) )
logging . error ( ' %s data uncorrect auth HMAC-MD5 from %s : %d , data %s ' % (
self . no_compatible_method , self . server_info . client , self . server_info . client_port ,
binascii . hexlify ( self . recv_buf )
) )
if len ( self . recv_buf ) < 36 :
if len ( self . recv_buf ) < 36 :
return ( b ' ' , False )
return ( b ' ' , False )
return self . not_match_return ( self . recv_buf )
return self . not_match_return ( self . recv_buf )
self . last_server_hash = md5data
self . last_server_hash = md5data
encryptor = encrypt . Encryptor ( to_bytes ( base64 . b64encode ( self . user_key ) ) + self . salt , ' aes-128-cbc ' )
encryptor = encrypt . Encryptor ( to_bytes ( base64 . b64encode ( self . user_key ) ) + self . salt , ' aes-128-cbc ' )
head = encryptor . decrypt ( b ' \x00 ' * 16 + self . recv_buf [ 16 : 32 ] + b ' \x00 ' ) # need an extra byte or recv empty
head = encryptor . decrypt ( b ' \x00 ' * 16 + self . recv_buf [ 16 : 32 ] + b ' \x00 ' ) # need an extra byte or recv empty
self . client_over_head = struct . unpack ( ' <H ' , head [ 12 : 14 ] ) [ 0 ]
self . client_over_head = struct . unpack ( ' <H ' , head [ 12 : 14 ] ) [ 0 ]
utc_time = struct . unpack ( ' <I ' , head [ : 4 ] ) [ 0 ]
utc_time = struct . unpack ( ' <I ' , head [ : 4 ] ) [ 0 ]
@ -503,7 +529,9 @@ class auth_chain_a(auth_base):
connection_id = struct . unpack ( ' <I ' , head [ 8 : 12 ] ) [ 0 ]
connection_id = struct . unpack ( ' <I ' , head [ 8 : 12 ] ) [ 0 ]
time_dif = common . int32 ( utc_time - ( int ( time . time ( ) ) & 0xffffffff ) )
time_dif = common . int32 ( utc_time - ( int ( time . time ( ) ) & 0xffffffff ) )
if time_dif < - self . max_time_dif or time_dif > self . max_time_dif :
if time_dif < - self . max_time_dif or time_dif > self . max_time_dif :
logging . info ( ' %s : wrong timestamp, time_dif %d , data %s ' % ( self . no_compatible_method , time_dif , binascii . hexlify ( head ) ) )
logging . info ( ' %s : wrong timestamp, time_dif %d , data %s ' % (
self . no_compatible_method , time_dif , binascii . hexlify ( head )
) )
return self . not_match_return ( self . recv_buf )
return self . not_match_return ( self . recv_buf )
elif self . server_info . data . insert ( self . user_id , client_id , connection_id ) :
elif self . server_info . data . insert ( self . user_id , client_id , connection_id ) :
self . has_recv_header = True
self . has_recv_header = True
@ -513,7 +541,8 @@ class auth_chain_a(auth_base):
logging . info ( ' %s : auth fail, data %s ' % ( self . no_compatible_method , binascii . hexlify ( out_buf ) ) )
logging . info ( ' %s : auth fail, data %s ' % ( self . no_compatible_method , binascii . hexlify ( out_buf ) ) )
return self . not_match_return ( self . recv_buf )
return self . not_match_return ( self . recv_buf )
self . encryptor = encrypt . Encryptor ( to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( self . last_client_hash ) ) , ' rc4 ' )
self . encryptor = encrypt . Encryptor (
to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( self . last_client_hash ) ) , ' rc4 ' )
self . recv_buf = self . recv_buf [ 36 : ]
self . recv_buf = self . recv_buf [ 36 : ]
self . has_recv_header = True
self . has_recv_header = True
sendback = True
sendback = True
@ -528,7 +557,7 @@ class auth_chain_a(auth_base):
self . recv_buf = b ' '
self . recv_buf = b ' '
if self . recv_id == 0 :
if self . recv_id == 0 :
logging . info ( self . no_compatible_method + ' : over size ' )
logging . info ( self . no_compatible_method + ' : over size ' )
return ( b ' E ' * 2048 , False )
return ( b ' E ' * 2048 , False )
else :
else :
raise Exception ( ' server_post_decrype data error ' )
raise Exception ( ' server_post_decrype data error ' )
@ -536,12 +565,14 @@ class auth_chain_a(auth_base):
break
break
client_hash = hmac . new ( mac_key , self . recv_buf [ : length + 2 ] , self . hashfunc ) . digest ( )
client_hash = hmac . new ( mac_key , self . recv_buf [ : length + 2 ] , self . hashfunc ) . digest ( )
if client_hash [ : 2 ] != self . recv_buf [ length + 2 : length + 4 ] :
if client_hash [ : 2 ] != self . recv_buf [ length + 2 : length + 4 ] :
logging . info ( ' %s : checksum error, data %s ' % ( self . no_compatible_method , binascii . hexlify ( self . recv_buf [ : length ] ) ) )
logging . info ( ' %s : checksum error, data %s ' % (
self . no_compatible_method , binascii . hexlify ( self . recv_buf [ : length ] )
) )
self . raw_trans = True
self . raw_trans = True
self . recv_buf = b ' '
self . recv_buf = b ' '
if self . recv_id == 0 :
if self . recv_id == 0 :
return ( b ' E ' * 2048 , False )
return ( b ' E ' * 2048 , False )
else :
else :
raise Exception ( ' server_post_decrype data uncorrect checksum ' )
raise Exception ( ' server_post_decrype data uncorrect checksum ' )
@ -549,7 +580,7 @@ class auth_chain_a(auth_base):
pos = 2
pos = 2
if data_len > 0 and rand_len > 0 :
if data_len > 0 and rand_len > 0 :
pos = 2 + self . rnd_start_pos ( rand_len , self . random_client )
pos = 2 + self . rnd_start_pos ( rand_len , self . random_client )
out_buf + = self . encryptor . decrypt ( self . recv_buf [ pos : data_len + pos ] )
out_buf + = self . encryptor . decrypt ( self . recv_buf [ pos : data_len + pos ] )
self . last_client_hash = client_hash
self . last_client_hash = client_hash
self . recv_buf = self . recv_buf [ length + 4 : ]
self . recv_buf = self . recv_buf [ length + 4 : ]
if data_len == 0 :
if data_len == 0 :
@ -577,7 +608,8 @@ class auth_chain_a(auth_base):
uid = struct . unpack ( ' <I ' , self . user_id ) [ 0 ] ^ struct . unpack ( ' <I ' , md5data [ : 4 ] ) [ 0 ]
uid = struct . unpack ( ' <I ' , self . user_id ) [ 0 ] ^ struct . unpack ( ' <I ' , md5data [ : 4 ] ) [ 0 ]
uid = struct . pack ( ' <I ' , uid )
uid = struct . pack ( ' <I ' , uid )
rand_len = self . udp_rnd_data_len ( md5data , self . random_client )
rand_len = self . udp_rnd_data_len ( md5data , self . random_client )
encryptor = encrypt . Encryptor ( to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( md5data ) ) , ' rc4 ' )
encryptor = encrypt . Encryptor (
to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( md5data ) ) , ' rc4 ' )
out_buf = encryptor . encrypt ( buf )
out_buf = encryptor . encrypt ( buf )
buf = out_buf + os . urandom ( rand_len ) + authdata + uid
buf = out_buf + os . urandom ( rand_len ) + authdata + uid
return buf + hmac . new ( self . user_key , buf , self . hashfunc ) . digest ( ) [ : 1 ]
return buf + hmac . new ( self . user_key , buf , self . hashfunc ) . digest ( ) [ : 1 ]
@ -590,7 +622,8 @@ class auth_chain_a(auth_base):
mac_key = self . server_info . key
mac_key = self . server_info . key
md5data = hmac . new ( mac_key , buf [ - 8 : - 1 ] , self . hashfunc ) . digest ( )
md5data = hmac . new ( mac_key , buf [ - 8 : - 1 ] , self . hashfunc ) . digest ( )
rand_len = self . udp_rnd_data_len ( md5data , self . random_server )
rand_len = self . udp_rnd_data_len ( md5data , self . random_server )
encryptor = encrypt . Encryptor ( to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( md5data ) ) , ' rc4 ' )
encryptor = encrypt . Encryptor (
to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( md5data ) ) , ' rc4 ' )
return encryptor . decrypt ( buf [ : - 8 - rand_len ] )
return encryptor . decrypt ( buf [ : - 8 - rand_len ] )
def server_udp_pre_encrypt ( self , buf , uid ) :
def server_udp_pre_encrypt ( self , buf , uid ) :
@ -634,11 +667,16 @@ class auth_chain_a(auth_base):
def dispose ( self ) :
def dispose ( self ) :
self . server_info . data . remove ( self . user_id , self . client_id )
self . server_info . data . remove ( self . user_id , self . client_id )
class auth_chain_b ( auth_chain_a ) :
class auth_chain_b ( auth_chain_a ) :
def __init__ ( self , method ) :
def __init__ ( self , method ) :
super ( auth_chain_b , self ) . __init__ ( method )
super ( auth_chain_b , self ) . __init__ ( method )
self . salt = b " auth_chain_b "
self . salt = b " auth_chain_b "
self . no_compatible_method = ' auth_chain_b '
self . no_compatible_method = ' auth_chain_b '
# NOTE
# 补全后长度数组
# 随机在其中选择一个补全到的长度
# 为每个连接初始化一个固定内容的数组
self . data_size_list = [ ]
self . data_size_list = [ ]
self . data_size_list2 = [ ]
self . data_size_list2 = [ ]
@ -648,10 +686,12 @@ class auth_chain_b(auth_chain_a):
self . data_size_list2 = [ ]
self . data_size_list2 = [ ]
random = xorshift128plus ( )
random = xorshift128plus ( )
random . init_from_bin ( key )
random . init_from_bin ( key )
# 补全数组长为4~12-1
list_len = random . next ( ) % 8 + 4
list_len = random . next ( ) % 8 + 4
for i in range ( 0 , list_len ) :
for i in range ( 0 , list_len ) :
self . data_size_list . append ( ( int ) ( random . next ( ) % 2340 % 2040 % 1440 ) )
self . data_size_list . append ( ( int ) ( random . next ( ) % 2340 % 2040 % 1440 ) )
self . data_size_list . sort ( )
self . data_size_list . sort ( )
# 补全数组长为8~24-1
list_len = random . next ( ) % 16 + 8
list_len = random . next ( ) % 16 + 8
for i in range ( 0 , list_len ) :
for i in range ( 0 , list_len ) :
self . data_size_list2 . append ( ( int ) ( random . next ( ) % 2340 % 2040 % 1440 ) )
self . data_size_list2 . append ( ( int ) ( random . next ( ) % 2340 % 2040 % 1440 ) )
@ -672,15 +712,21 @@ class auth_chain_b(auth_chain_a):
random . init_from_bin_len ( last_hash , buf_size )
random . init_from_bin_len ( last_hash , buf_size )
pos = bisect . bisect_left ( self . data_size_list , buf_size + self . server_info . overhead )
pos = bisect . bisect_left ( self . data_size_list , buf_size + self . server_info . overhead )
final_pos = pos + random . next ( ) % ( len ( self . data_size_list ) )
final_pos = pos + random . next ( ) % ( len ( self . data_size_list ) )
# 假设random均匀分布,则越长的原始数据长度越容易if false
if final_pos < len ( self . data_size_list ) :
if final_pos < len ( self . data_size_list ) :
return self . data_size_list [ final_pos ] - buf_size - self . server_info . overhead
return self . data_size_list [ final_pos ] - buf_size - self . server_info . overhead
# 上面if false后选择2号补全数组,此处有更精细的长度分段
pos = bisect . bisect_left ( self . data_size_list2 , buf_size + self . server_info . overhead )
pos = bisect . bisect_left ( self . data_size_list2 , buf_size + self . server_info . overhead )
final_pos = pos + random . next ( ) % ( len ( self . data_size_list2 ) )
final_pos = pos + random . next ( ) % ( len ( self . data_size_list2 ) )
if final_pos < len ( self . data_size_list2 ) :
if final_pos < len ( self . data_size_list2 ) :
return self . data_size_list2 [ final_pos ] - buf_size - self . server_info . overhead
return self . data_size_list2 [ final_pos ] - buf_size - self . server_info . overhead
# final_pos 总是分布在pos~(data_size_list2.len-1)之间
if final_pos < pos + len ( self . data_size_list2 ) - 1 :
if final_pos < pos + len ( self . data_size_list2 ) - 1 :
return 0
return 0
# 有1/len(self.data_size_list2)的概率不满足上一个if ?
# 理论上不会运行到此处,因此可以插入运行断言 ?
# assert False
if buf_size > 1300 :
if buf_size > 1300 :
return random . next ( ) % 31
return random . next ( ) % 31
@ -690,3 +736,105 @@ class auth_chain_b(auth_chain_a):
return random . next ( ) % 521
return random . next ( ) % 521
return random . next ( ) % 1021
return random . next ( ) % 1021
class auth_chain_c ( auth_chain_b ) :
def __init__ ( self , method ) :
super ( auth_chain_c , self ) . __init__ ( method )
self . salt = b " auth_chain_c "
self . no_compatible_method = ' auth_chain_c '
self . data_size_list0 = [ ]
def init_data_size ( self , key ) :
if self . data_size_list0 :
self . data_size_list0 = [ ]
random = xorshift128plus ( )
random . init_from_bin ( key )
# 补全数组长为12~24-1
list_len = random . next ( ) % ( 8 + 16 ) + ( 4 + 8 )
for i in range ( 0 , list_len ) :
self . data_size_list0 . append ( ( int ) ( random . next ( ) % 2340 % 2040 % 1440 ) )
self . data_size_list0 . sort ( )
def set_server_info ( self , server_info ) :
self . server_info = server_info
try :
max_client = int ( server_info . protocol_param . split ( ' # ' ) [ 0 ] )
except :
max_client = 64
self . server_info . data . set_max_client ( max_client )
self . init_data_size ( self . server_info . key )
def rnd_data_len ( self , buf_size , last_hash , random ) :
other_data_size = buf_size + self . server_info . overhead
# 一定要在random使用前初始化,以保证服务器与客户端同步,保证包大小验证结果正确
random . init_from_bin_len ( last_hash , buf_size )
# final_pos 总是分布在pos~(data_size_list0.len-1)之间
# 除非data_size_list0中的任何值均过小使其全部都无法容纳buf
if other_data_size > = self . data_size_list0 [ - 1 ] :
if other_data_size > = 1440 :
return 0
if other_data_size > 1300 :
return random . next ( ) % 31
if other_data_size > 900 :
return random . next ( ) % 127
if other_data_size > 400 :
return random . next ( ) % 521
return random . next ( ) % 1021
pos = bisect . bisect_left ( self . data_size_list0 , other_data_size )
# random select a size in the leftover data_size_list0
final_pos = pos + random . next ( ) % ( len ( self . data_size_list0 ) - pos )
return self . data_size_list0 [ final_pos ] - other_data_size
class auth_chain_d ( auth_chain_b ) :
def __init__ ( self , method ) :
super ( auth_chain_d , self ) . __init__ ( method )
self . salt = b " auth_chain_d "
self . no_compatible_method = ' auth_chain_d '
self . data_size_list0 = [ ]
def check_and_patch_data_size ( self , random ) :
# append new item
# when the biggest item(first time) or the last append item(other time) are not big enough.
# but set a limit size (64) to avoid stack overflow.
if self . data_size_list0 [ - 1 ] < 1300 and len ( self . data_size_list0 ) < 64 :
self . data_size_list0 . append ( ( int ) ( random . next ( ) % 2340 % 2040 % 1440 ) )
self . check_and_patch_data_size ( random )
def init_data_size ( self , key ) :
if self . data_size_list0 :
self . data_size_list0 = [ ]
random = xorshift128plus ( )
random . init_from_bin ( key )
# 补全数组长为12~24-1
list_len = random . next ( ) % ( 8 + 16 ) + ( 4 + 8 )
for i in range ( 0 , list_len ) :
self . data_size_list0 . append ( ( int ) ( random . next ( ) % 2340 % 2040 % 1440 ) )
self . data_size_list0 . sort ( )
old_len = len ( self . data_size_list0 )
self . check_and_patch_data_size ( random )
# if check_and_patch_data_size are work, re-sort again.
if old_len != len ( self . data_size_list0 ) :
self . data_size_list0 . sort ( )
def set_server_info ( self , server_info ) :
self . server_info = server_info
try :
max_client = int ( server_info . protocol_param . split ( ' # ' ) [ 0 ] )
except :
max_client = 64
self . server_info . data . set_max_client ( max_client )
self . init_data_size ( self . server_info . key )
def rnd_data_len ( self , buf_size , last_hash , random ) :
other_data_size = buf_size + self . server_info . overhead
# if other_data_size > the bigest item in data_size_list0, not padding any data
if other_data_size > = self . data_size_list0 [ - 1 ] :
return 0
random . init_from_bin_len ( last_hash , buf_size )
pos = bisect . bisect_left ( self . data_size_list0 , other_data_size )
# random select a size in the leftover data_size_list0
final_pos = pos + random . next ( ) % ( len ( self . data_size_list0 ) - pos )
return self . data_size_list0 [ final_pos ] - other_data_size