@ -1,4 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2015-2015 breakwa11
#
@ -38,17 +39,31 @@ from shadowsocks import common, lru_cache, encrypt
from shadowsocks . obfsplugin import plain
from shadowsocks . common import to_bytes , to_str , ord , chr
def create_auth_chain_a ( method ) :
return auth_chain_a ( method )
def create_auth_chain_b ( method ) :
return auth_chain_b ( method )
def create_auth_chain_c ( method ) :
return auth_chain_c ( method )
def create_auth_chain_d ( method ) :
return auth_chain_d ( method )
obfs_map = {
' auth_chain_a ' : ( create_auth_chain_a , ) ,
' auth_chain_b ' : ( create_auth_chain_b , ) ,
' auth_chain_c ' : ( create_auth_chain_c , ) ,
' auth_chain_d ' : ( create_auth_chain_d , ) ,
}
class xorshift128plus ( object ) :
max_int = ( 1 << 64 ) - 1
mov_mask = ( 1 << ( 64 - 23 ) ) - 1
@ -80,12 +95,14 @@ class xorshift128plus(object):
for i in range ( 4 ) :
self . next ( )
def match_begin ( str1 , str2 ) :
if len ( str1 ) > = len ( str2 ) :
if str1 [ : len ( str2 ) ] == str2 :
return True
return False
class auth_base ( plain . plain ) :
def __init__ ( self , method ) :
super ( auth_base , self ) . __init__ ( method )
@ -121,6 +138,7 @@ class auth_base(plain.plain):
return ( b ' E ' * 2048 , False )
return ( buf , False )
class client_queue ( object ) :
def __init__ ( self , begin_id ) :
self . front = begin_id - 64
@ -175,6 +193,7 @@ class client_queue(object):
self . addref ( )
return True
class obfs_auth_chain_data ( object ) :
def __init__ ( self , name ) :
self . name = name
@ -229,6 +248,7 @@ class obfs_auth_chain_data(object):
if client_id in local_client_id :
local_client_id [ client_id ] . delref ( )
class auth_chain_a ( auth_base ) :
def __init__ ( self , method ) :
super ( auth_chain_a , self ) . __init__ ( method )
@ -362,14 +382,16 @@ class auth_chain_a(auth_base):
if self . user_key is None :
self . user_key = self . server_info . key
encryptor = encrypt . Encryptor ( to_bytes ( base64 . b64encode ( self . user_key ) ) + self . salt , ' aes-128-cbc ' , b ' \x00 ' * 16 )
encryptor = encrypt . Encryptor (
to_bytes ( base64 . b64encode ( self . user_key ) ) + self . salt , ' aes-128-cbc ' , b ' \x00 ' * 16 )
uid = struct . unpack ( ' <I ' , uid ) [ 0 ] ^ struct . unpack ( ' <I ' , self . last_client_hash [ 8 : 12 ] ) [ 0 ]
uid = struct . pack ( ' <I ' , uid )
data = uid + encryptor . encrypt ( data ) [ 16 : ]
self . last_server_hash = hmac . new ( self . user_key , data , self . hashfunc ) . digest ( )
data = check_head + data + self . last_server_hash [ : 4 ]
self . encryptor = encrypt . Encryptor ( to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( self . last_client_hash ) ) , ' rc4 ' )
self . encryptor = encrypt . Encryptor (
to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( self . last_client_hash ) ) , ' rc4 ' )
return data + self . pack_client_data ( buf )
def auth_data ( self ) :
@ -420,7 +442,8 @@ class auth_chain_a(auth_base):
server_hash = hmac . new ( mac_key , self . recv_buf [ : length + 2 ] , self . hashfunc ) . digest ( )
if server_hash [ : 2 ] != self . recv_buf [ length + 2 : length + 4 ] :
logging . info ( ' %s : checksum error, data %s ' % ( self . no_compatible_method , binascii . hexlify ( self . recv_buf [ : length ] ) ) )
logging . info ( ' %s : checksum error, data %s '
% ( self . no_compatible_method , binascii . hexlify ( self . recv_buf [ : length ] ) ) )
self . raw_trans = True
self . recv_buf = b ' '
raise Exception ( ' client_post_decrypt data uncorrect checksum ' )
@ -488,7 +511,10 @@ class auth_chain_a(auth_base):
md5data = hmac . new ( self . user_key , self . recv_buf [ 12 : 12 + 20 ] , self . hashfunc ) . digest ( )
if md5data [ : 4 ] != self . recv_buf [ 32 : 36 ] :
logging . error ( ' %s data uncorrect auth HMAC-MD5 from %s : %d , data %s ' % ( self . no_compatible_method , self . server_info . client , self . server_info . client_port , binascii . hexlify ( self . recv_buf ) ) )
logging . error ( ' %s data uncorrect auth HMAC-MD5 from %s : %d , data %s ' % (
self . no_compatible_method , self . server_info . client , self . server_info . client_port ,
binascii . hexlify ( self . recv_buf )
) )
if len ( self . recv_buf ) < 36 :
return ( b ' ' , False )
return self . not_match_return ( self . recv_buf )
@ -503,7 +529,9 @@ class auth_chain_a(auth_base):
connection_id = struct . unpack ( ' <I ' , head [ 8 : 12 ] ) [ 0 ]
time_dif = common . int32 ( utc_time - ( int ( time . time ( ) ) & 0xffffffff ) )
if time_dif < - self . max_time_dif or time_dif > self . max_time_dif :
logging . info ( ' %s : wrong timestamp, time_dif %d , data %s ' % ( self . no_compatible_method , time_dif , binascii . hexlify ( head ) ) )
logging . info ( ' %s : wrong timestamp, time_dif %d , data %s ' % (
self . no_compatible_method , time_dif , binascii . hexlify ( head )
) )
return self . not_match_return ( self . recv_buf )
elif self . server_info . data . insert ( self . user_id , client_id , connection_id ) :
self . has_recv_header = True
@ -513,7 +541,8 @@ class auth_chain_a(auth_base):
logging . info ( ' %s : auth fail, data %s ' % ( self . no_compatible_method , binascii . hexlify ( out_buf ) ) )
return self . not_match_return ( self . recv_buf )
self . encryptor = encrypt . Encryptor ( to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( self . last_client_hash ) ) , ' rc4 ' )
self . encryptor = encrypt . Encryptor (
to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( self . last_client_hash ) ) , ' rc4 ' )
self . recv_buf = self . recv_buf [ 36 : ]
self . has_recv_header = True
sendback = True
@ -537,7 +566,9 @@ class auth_chain_a(auth_base):
client_hash = hmac . new ( mac_key , self . recv_buf [ : length + 2 ] , self . hashfunc ) . digest ( )
if client_hash [ : 2 ] != self . recv_buf [ length + 2 : length + 4 ] :
logging . info ( ' %s : checksum error, data %s ' % ( self . no_compatible_method , binascii . hexlify ( self . recv_buf [ : length ] ) ) )
logging . info ( ' %s : checksum error, data %s ' % (
self . no_compatible_method , binascii . hexlify ( self . recv_buf [ : length ] )
) )
self . raw_trans = True
self . recv_buf = b ' '
if self . recv_id == 0 :
@ -577,7 +608,8 @@ class auth_chain_a(auth_base):
uid = struct . unpack ( ' <I ' , self . user_id ) [ 0 ] ^ struct . unpack ( ' <I ' , md5data [ : 4 ] ) [ 0 ]
uid = struct . pack ( ' <I ' , uid )
rand_len = self . udp_rnd_data_len ( md5data , self . random_client )
encryptor = encrypt . Encryptor ( to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( md5data ) ) , ' rc4 ' )
encryptor = encrypt . Encryptor (
to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( md5data ) ) , ' rc4 ' )
out_buf = encryptor . encrypt ( buf )
buf = out_buf + os . urandom ( rand_len ) + authdata + uid
return buf + hmac . new ( self . user_key , buf , self . hashfunc ) . digest ( ) [ : 1 ]
@ -590,7 +622,8 @@ class auth_chain_a(auth_base):
mac_key = self . server_info . key
md5data = hmac . new ( mac_key , buf [ - 8 : - 1 ] , self . hashfunc ) . digest ( )
rand_len = self . udp_rnd_data_len ( md5data , self . random_server )
encryptor = encrypt . Encryptor ( to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( md5data ) ) , ' rc4 ' )
encryptor = encrypt . Encryptor (
to_bytes ( base64 . b64encode ( self . user_key ) ) + to_bytes ( base64 . b64encode ( md5data ) ) , ' rc4 ' )
return encryptor . decrypt ( buf [ : - 8 - rand_len ] )
def server_udp_pre_encrypt ( self , buf , uid ) :
@ -634,11 +667,16 @@ class auth_chain_a(auth_base):
def dispose ( self ) :
self . server_info . data . remove ( self . user_id , self . client_id )
class auth_chain_b ( auth_chain_a ) :
def __init__ ( self , method ) :
super ( auth_chain_b , self ) . __init__ ( method )
self . salt = b " auth_chain_b "
self . no_compatible_method = ' auth_chain_b '
# NOTE
# 补全后长度数组
# 随机在其中选择一个补全到的长度
# 为每个连接初始化一个固定内容的数组
self . data_size_list = [ ]
self . data_size_list2 = [ ]
@ -648,10 +686,12 @@ class auth_chain_b(auth_chain_a):
self . data_size_list2 = [ ]
random = xorshift128plus ( )
random . init_from_bin ( key )
# 补全数组长为4~12-1
list_len = random . next ( ) % 8 + 4
for i in range ( 0 , list_len ) :
self . data_size_list . append ( ( int ) ( random . next ( ) % 2340 % 2040 % 1440 ) )
self . data_size_list . sort ( )
# 补全数组长为8~24-1
list_len = random . next ( ) % 16 + 8
for i in range ( 0 , list_len ) :
self . data_size_list2 . append ( ( int ) ( random . next ( ) % 2340 % 2040 % 1440 ) )
@ -672,15 +712,21 @@ class auth_chain_b(auth_chain_a):
random . init_from_bin_len ( last_hash , buf_size )
pos = bisect . bisect_left ( self . data_size_list , buf_size + self . server_info . overhead )
final_pos = pos + random . next ( ) % ( len ( self . data_size_list ) )
# 假设random均匀分布,则越长的原始数据长度越容易if false
if final_pos < len ( self . data_size_list ) :
return self . data_size_list [ final_pos ] - buf_size - self . server_info . overhead
# 上面if false后选择2号补全数组,此处有更精细的长度分段
pos = bisect . bisect_left ( self . data_size_list2 , buf_size + self . server_info . overhead )
final_pos = pos + random . next ( ) % ( len ( self . data_size_list2 ) )
if final_pos < len ( self . data_size_list2 ) :
return self . data_size_list2 [ final_pos ] - buf_size - self . server_info . overhead
# final_pos 总是分布在pos~(data_size_list2.len-1)之间
if final_pos < pos + len ( self . data_size_list2 ) - 1 :
return 0
# 有1/len(self.data_size_list2)的概率不满足上一个if ?
# 理论上不会运行到此处,因此可以插入运行断言 ?
# assert False
if buf_size > 1300 :
return random . next ( ) % 31
@ -690,3 +736,105 @@ class auth_chain_b(auth_chain_a):
return random . next ( ) % 521
return random . next ( ) % 1021
class auth_chain_c ( auth_chain_b ) :
def __init__ ( self , method ) :
super ( auth_chain_c , self ) . __init__ ( method )
self . salt = b " auth_chain_c "
self . no_compatible_method = ' auth_chain_c '
self . data_size_list0 = [ ]
def init_data_size ( self , key ) :
if self . data_size_list0 :
self . data_size_list0 = [ ]
random = xorshift128plus ( )
random . init_from_bin ( key )
# 补全数组长为12~24-1
list_len = random . next ( ) % ( 8 + 16 ) + ( 4 + 8 )
for i in range ( 0 , list_len ) :
self . data_size_list0 . append ( ( int ) ( random . next ( ) % 2340 % 2040 % 1440 ) )
self . data_size_list0 . sort ( )
def set_server_info ( self , server_info ) :
self . server_info = server_info
try :
max_client = int ( server_info . protocol_param . split ( ' # ' ) [ 0 ] )
except :
max_client = 64
self . server_info . data . set_max_client ( max_client )
self . init_data_size ( self . server_info . key )
def rnd_data_len ( self , buf_size , last_hash , random ) :
other_data_size = buf_size + self . server_info . overhead
# 一定要在random使用前初始化,以保证服务器与客户端同步,保证包大小验证结果正确
random . init_from_bin_len ( last_hash , buf_size )
# final_pos 总是分布在pos~(data_size_list0.len-1)之间
# 除非data_size_list0中的任何值均过小使其全部都无法容纳buf
if other_data_size > = self . data_size_list0 [ - 1 ] :
if other_data_size > = 1440 :
return 0
if other_data_size > 1300 :
return random . next ( ) % 31
if other_data_size > 900 :
return random . next ( ) % 127
if other_data_size > 400 :
return random . next ( ) % 521
return random . next ( ) % 1021
pos = bisect . bisect_left ( self . data_size_list0 , other_data_size )
# random select a size in the leftover data_size_list0
final_pos = pos + random . next ( ) % ( len ( self . data_size_list0 ) - pos )
return self . data_size_list0 [ final_pos ] - other_data_size
class auth_chain_d ( auth_chain_b ) :
def __init__ ( self , method ) :
super ( auth_chain_d , self ) . __init__ ( method )
self . salt = b " auth_chain_d "
self . no_compatible_method = ' auth_chain_d '
self . data_size_list0 = [ ]
def check_and_patch_data_size ( self , random ) :
# append new item
# when the biggest item(first time) or the last append item(other time) are not big enough.
# but set a limit size (64) to avoid stack overflow.
if self . data_size_list0 [ - 1 ] < 1300 and len ( self . data_size_list0 ) < 64 :
self . data_size_list0 . append ( ( int ) ( random . next ( ) % 2340 % 2040 % 1440 ) )
self . check_and_patch_data_size ( random )
def init_data_size ( self , key ) :
if self . data_size_list0 :
self . data_size_list0 = [ ]
random = xorshift128plus ( )
random . init_from_bin ( key )
# 补全数组长为12~24-1
list_len = random . next ( ) % ( 8 + 16 ) + ( 4 + 8 )
for i in range ( 0 , list_len ) :
self . data_size_list0 . append ( ( int ) ( random . next ( ) % 2340 % 2040 % 1440 ) )
self . data_size_list0 . sort ( )
old_len = len ( self . data_size_list0 )
self . check_and_patch_data_size ( random )
# if check_and_patch_data_size are work, re-sort again.
if old_len != len ( self . data_size_list0 ) :
self . data_size_list0 . sort ( )
def set_server_info ( self , server_info ) :
self . server_info = server_info
try :
max_client = int ( server_info . protocol_param . split ( ' # ' ) [ 0 ] )
except :
max_client = 64
self . server_info . data . set_max_client ( max_client )
self . init_data_size ( self . server_info . key )
def rnd_data_len ( self , buf_size , last_hash , random ) :
other_data_size = buf_size + self . server_info . overhead
# if other_data_size > the bigest item in data_size_list0, not padding any data
if other_data_size > = self . data_size_list0 [ - 1 ] :
return 0
random . init_from_bin_len ( last_hash , buf_size )
pos = bisect . bisect_left ( self . data_size_list0 , other_data_size )
# random select a size in the leftover data_size_list0
final_pos = pos + random . next ( ) % ( len ( self . data_size_list0 ) - pos )
return self . data_size_list0 [ final_pos ] - other_data_size