Browse Source

Finish 3.2.0

add auth_chain_c/d
3.2.0
Akkariiin 7 years ago
parent
commit
0aa49b1ff8
  1. 65
      shadowsocks/asyncdns.py
  2. 204
      shadowsocks/obfsplugin/auth_chain.py
  3. 1
      shadowsocks/obfsplugin/http_simple.py
  4. 14
      shadowsocks/server.py
  5. 18
      shadowsocks/shell.py

65
shadowsocks/asyncdns.py

@ -27,12 +27,12 @@ import logging
if __name__ == '__main__': if __name__ == '__main__':
import sys import sys
import inspect import inspect
file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))) file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe())))
sys.path.insert(0, os.path.join(file_path, '../')) sys.path.insert(0, os.path.join(file_path, '../'))
from shadowsocks import common, lru_cache, eventloop, shell from shadowsocks import common, lru_cache, eventloop, shell
CACHE_SWEEP_INTERVAL = 30 CACHE_SWEEP_INTERVAL = 30
VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d_-]{1,63}(?<!-)$", re.IGNORECASE) VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d_-]{1,63}(?<!-)$", re.IGNORECASE)
@ -77,6 +77,7 @@ QTYPE_CNAME = 5
QTYPE_NS = 2 QTYPE_NS = 2
QCLASS_IN = 1 QCLASS_IN = 1
def detect_ipv6_supprot(): def detect_ipv6_supprot():
if 'has_ipv6' in dir(socket): if 'has_ipv6' in dir(socket):
try: try:
@ -89,8 +90,10 @@ def detect_ipv6_supprot():
print('IPv6 not support') print('IPv6 not support')
return False return False
IPV6_CONNECTION_SUPPORT = detect_ipv6_supprot() IPV6_CONNECTION_SUPPORT = detect_ipv6_supprot()
def build_address(address): def build_address(address):
address = address.strip(b'.') address = address.strip(b'.')
labels = address.split(b'.') labels = address.split(b'.')
@ -175,7 +178,7 @@ def parse_record(data, offset, question=False):
) )
ip = parse_ip(record_type, data, record_rdlength, offset + nlen + 10) ip = parse_ip(record_type, data, record_rdlength, offset + nlen + 10)
return nlen + 10 + record_rdlength, \ return nlen + 10 + record_rdlength, \
(name, ip, record_type, record_class, record_ttl) (name, ip, record_type, record_class, record_ttl)
else: else:
record_type, record_class = struct.unpack( record_type, record_class = struct.unpack(
'!HH', data[offset + nlen:offset + nlen + 4] '!HH', data[offset + nlen:offset + nlen + 4]
@ -209,7 +212,7 @@ def parse_response(data):
if not header: if not header:
return None return None
res_id, res_qr, res_tc, res_ra, res_rcode, res_qdcount, \ res_id, res_qr, res_tc, res_ra, res_rcode, res_qdcount, \
res_ancount, res_nscount, res_arcount = header res_ancount, res_nscount, res_arcount = header
qds = [] qds = []
ans = [] ans = []
@ -266,14 +269,22 @@ STATUS_IPV6 = 1
class DNSResolver(object): class DNSResolver(object):
def __init__(self, black_hostname_list=None):
def __init__(self):
self._loop = None self._loop = None
self._hosts = {} self._hosts = {}
self._hostname_status = {} self._hostname_status = {}
self._hostname_to_cb = {} self._hostname_to_cb = {}
self._cb_to_hostname = {} self._cb_to_hostname = {}
self._cache = lru_cache.LRUCache(timeout=300) self._cache = lru_cache.LRUCache(timeout=300)
# read black_hostname_list from config
if type(black_hostname_list) != list:
self._black_hostname_list = []
else:
self._black_hostname_list = list(map(
(lambda t: t if type(t) == bytes else t.encode('utf8')),
black_hostname_list
))
logging.info('black_hostname_list init as : ' + str(self._black_hostname_list))
self._sock = None self._sock = None
self._servers = None self._servers = None
self._parse_resolv() self._parse_resolv()
@ -377,7 +388,7 @@ class DNSResolver(object):
ip = None ip = None
for answer in response.answers: for answer in response.answers:
if answer[1] in (QTYPE_A, QTYPE_AAAA) and \ if answer[1] in (QTYPE_A, QTYPE_AAAA) and \
answer[2] == QCLASS_IN: answer[2] == QCLASS_IN:
ip = answer[0] ip = answer[0]
break break
if IPV6_CONNECTION_SUPPORT: if IPV6_CONNECTION_SUPPORT:
@ -462,19 +473,22 @@ class DNSResolver(object):
ip = self._hosts[hostname] ip = self._hosts[hostname]
callback((hostname, ip), None) callback((hostname, ip), None)
elif hostname in self._cache: elif hostname in self._cache:
logging.debug('hit cache: %s', hostname) logging.debug('hit cache: %s ==>> %s', hostname, self._cache[hostname])
ip = self._cache[hostname] ip = self._cache[hostname]
callback((hostname, ip), None) callback((hostname, ip), None)
elif any(hostname.endswith(t) for t in self._black_hostname_list):
callback(None, Exception('hostname <%s> is block by the black hostname list' % hostname))
return
else: else:
if not is_valid_hostname(hostname): if not is_valid_hostname(hostname):
callback(None, Exception('invalid hostname: %s' % hostname)) callback(None, Exception('invalid hostname: %s' % hostname))
return return
if False: if False:
addrs = socket.getaddrinfo(hostname, 0, 0, addrs = socket.getaddrinfo(hostname, 0, 0,
socket.SOCK_DGRAM, socket.SOL_UDP) socket.SOCK_DGRAM, socket.SOL_UDP)
if addrs: if addrs:
af, socktype, proto, canonname, sa = addrs[0] af, socktype, proto, canonname, sa = addrs[0]
logging.debug('DNS resolve %s %s' % (hostname, sa[0]) ) logging.debug('DNS resolve %s %s' % (hostname, sa[0]))
self._cache[hostname] = sa[0] self._cache[hostname] = sa[0]
callback((hostname, sa[0]), None) callback((hostname, sa[0]), None)
return return
@ -506,7 +520,11 @@ class DNSResolver(object):
def test(): def test():
dns_resolver = DNSResolver() black_hostname_list = [
'baidu.com',
'yahoo.com',
]
dns_resolver = DNSResolver(black_hostname_list=black_hostname_list)
loop = eventloop.EventLoop() loop = eventloop.EventLoop()
dns_resolver.add_to_loop(loop) dns_resolver.add_to_loop(loop)
@ -521,16 +539,20 @@ def test():
# TODO: what can we assert? # TODO: what can we assert?
print(result, error) print(result, error)
counter += 1 counter += 1
if counter == 9: if counter == 12:
dns_resolver.close() dns_resolver.close()
loop.stop() loop.stop()
a_callback = callback a_callback = callback
return a_callback return a_callback
assert(make_callback() != make_callback()) assert (make_callback() != make_callback())
dns_resolver.resolve(b'google.com', make_callback()) dns_resolver.resolve(b'google.com', make_callback())
dns_resolver.resolve('google.com', make_callback()) dns_resolver.resolve('google.com', make_callback())
dns_resolver.resolve('baidu.com', make_callback())
dns_resolver.resolve('map.baidu.com', make_callback())
dns_resolver.resolve('yahoo.com', make_callback())
dns_resolver.resolve('example.com', make_callback()) dns_resolver.resolve('example.com', make_callback())
dns_resolver.resolve('ipv6.google.com', make_callback()) dns_resolver.resolve('ipv6.google.com', make_callback())
dns_resolver.resolve('www.facebook.com', make_callback()) dns_resolver.resolve('www.facebook.com', make_callback())
@ -546,10 +568,25 @@ def test():
'ooooooooooooooooooooooooooooooooooooooooooooooooooo' 'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
'ooooooooooooooooooooooooooooooooooooooooooooooooooo' 'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
'long.hostname', make_callback()) 'long.hostname', make_callback())
loop.run() loop.run()
# test black_hostname_list
dns_resolver = DNSResolver(black_hostname_list=[])
assert type(dns_resolver._black_hostname_list) == list
assert len(dns_resolver._black_hostname_list) == 0
dns_resolver.close()
dns_resolver = DNSResolver(black_hostname_list=123)
assert type(dns_resolver._black_hostname_list) == list
assert len(dns_resolver._black_hostname_list) == 0
dns_resolver.close()
dns_resolver = DNSResolver(black_hostname_list=None)
assert type(dns_resolver._black_hostname_list) == list
assert len(dns_resolver._black_hostname_list) == 0
dns_resolver.close()
dns_resolver = DNSResolver()
assert type(dns_resolver._black_hostname_list) == list
assert dns_resolver._black_hostname_list.__len__() == 0
dns_resolver.close()
if __name__ == '__main__': if __name__ == '__main__':
test() test()

204
shadowsocks/obfsplugin/auth_chain.py

@ -1,4 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*-
# #
# Copyright 2015-2015 breakwa11 # Copyright 2015-2015 breakwa11
# #
@ -38,17 +39,31 @@ from shadowsocks import common, lru_cache, encrypt
from shadowsocks.obfsplugin import plain from shadowsocks.obfsplugin import plain
from shadowsocks.common import to_bytes, to_str, ord, chr from shadowsocks.common import to_bytes, to_str, ord, chr
def create_auth_chain_a(method): def create_auth_chain_a(method):
return auth_chain_a(method) return auth_chain_a(method)
def create_auth_chain_b(method): def create_auth_chain_b(method):
return auth_chain_b(method) return auth_chain_b(method)
def create_auth_chain_c(method):
return auth_chain_c(method)
def create_auth_chain_d(method):
return auth_chain_d(method)
obfs_map = { obfs_map = {
'auth_chain_a': (create_auth_chain_a,), 'auth_chain_a': (create_auth_chain_a,),
'auth_chain_b': (create_auth_chain_b,), 'auth_chain_b': (create_auth_chain_b,),
'auth_chain_c': (create_auth_chain_c,),
'auth_chain_d': (create_auth_chain_d,),
} }
class xorshift128plus(object): class xorshift128plus(object):
max_int = (1 << 64) - 1 max_int = (1 << 64) - 1
mov_mask = (1 << (64 - 23)) - 1 mov_mask = (1 << (64 - 23)) - 1
@ -80,12 +95,14 @@ class xorshift128plus(object):
for i in range(4): for i in range(4):
self.next() self.next()
def match_begin(str1, str2): def match_begin(str1, str2):
if len(str1) >= len(str2): if len(str1) >= len(str2):
if str1[:len(str2)] == str2: if str1[:len(str2)] == str2:
return True return True
return False return False
class auth_base(plain.plain): class auth_base(plain.plain):
def __init__(self, method): def __init__(self, method):
super(auth_base, self).__init__(method) super(auth_base, self).__init__(method)
@ -96,7 +113,7 @@ class auth_base(plain.plain):
def init_data(self): def init_data(self):
return '' return ''
def get_overhead(self, direction): # direction: true for c->s false for s->c def get_overhead(self, direction): # direction: true for c->s false for s->c
return self.overhead return self.overhead
def set_server_info(self, server_info): def set_server_info(self, server_info):
@ -118,9 +135,10 @@ class auth_base(plain.plain):
self.raw_trans = True self.raw_trans = True
self.overhead = 0 self.overhead = 0
if self.method == self.no_compatible_method: if self.method == self.no_compatible_method:
return (b'E'*2048, False) return (b'E' * 2048, False)
return (buf, False) return (buf, False)
class client_queue(object): class client_queue(object):
def __init__(self, begin_id): def __init__(self, begin_id):
self.front = begin_id - 64 self.front = begin_id - 64
@ -175,13 +193,14 @@ class client_queue(object):
self.addref() self.addref()
return True return True
class obfs_auth_chain_data(object): class obfs_auth_chain_data(object):
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
self.user_id = {} self.user_id = {}
self.local_client_id = b'' self.local_client_id = b''
self.connection_id = 0 self.connection_id = 0
self.set_max_client(64) # max active client count self.set_max_client(64) # max active client count
def update(self, user_id, client_id, connection_id): def update(self, user_id, client_id, connection_id):
if user_id not in self.user_id: if user_id not in self.user_id:
@ -203,7 +222,7 @@ class obfs_auth_chain_data(object):
if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable: if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable:
if local_client_id.first() is None or len(local_client_id) < self.max_client: if local_client_id.first() is None or len(local_client_id) < self.max_client:
if client_id not in local_client_id: if client_id not in local_client_id:
#TODO: check # TODO: check
local_client_id[client_id] = client_queue(connection_id) local_client_id[client_id] = client_queue(connection_id)
else: else:
local_client_id[client_id].re_enable(connection_id) local_client_id[client_id].re_enable(connection_id)
@ -212,7 +231,7 @@ class obfs_auth_chain_data(object):
if not local_client_id[local_client_id.first()].is_active(): if not local_client_id[local_client_id.first()].is_active():
del local_client_id[local_client_id.first()] del local_client_id[local_client_id.first()]
if client_id not in local_client_id: if client_id not in local_client_id:
#TODO: check # TODO: check
local_client_id[client_id] = client_queue(connection_id) local_client_id[client_id] = client_queue(connection_id)
else: else:
local_client_id[client_id].re_enable(connection_id) local_client_id[client_id].re_enable(connection_id)
@ -229,6 +248,7 @@ class obfs_auth_chain_data(object):
if client_id in local_client_id: if client_id in local_client_id:
local_client_id[client_id].delref() local_client_id[client_id].delref()
class auth_chain_a(auth_base): class auth_chain_a(auth_base):
def __init__(self, method): def __init__(self, method):
super(auth_chain_a, self).__init__(method) super(auth_chain_a, self).__init__(method)
@ -240,7 +260,7 @@ class auth_chain_a(auth_base):
self.has_recv_header = False self.has_recv_header = False
self.client_id = 0 self.client_id = 0
self.connection_id = 0 self.connection_id = 0
self.max_time_dif = 60 * 60 * 24 # time dif (second) setting self.max_time_dif = 60 * 60 * 24 # time dif (second) setting
self.salt = b"auth_chain_a" self.salt = b"auth_chain_a"
self.no_compatible_method = 'auth_chain_a' self.no_compatible_method = 'auth_chain_a'
self.pack_id = 1 self.pack_id = 1
@ -259,7 +279,7 @@ class auth_chain_a(auth_base):
def init_data(self): def init_data(self):
return obfs_auth_chain_data(self.method) return obfs_auth_chain_data(self.method)
def get_overhead(self, direction): # direction: true for c->s false for s->c def get_overhead(self, direction): # direction: true for c->s false for s->c
return self.overhead return self.overhead
def set_server_info(self, server_info): def set_server_info(self, server_info):
@ -362,14 +382,16 @@ class auth_chain_a(auth_base):
if self.user_key is None: if self.user_key is None:
self.user_key = self.server_info.key self.user_key = self.server_info.key
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc', b'\x00' * 16) encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc', b'\x00' * 16)
uid = struct.unpack('<I', uid)[0] ^ struct.unpack('<I', self.last_client_hash[8:12])[0] uid = struct.unpack('<I', uid)[0] ^ struct.unpack('<I', self.last_client_hash[8:12])[0]
uid = struct.pack('<I', uid) uid = struct.pack('<I', uid)
data = uid + encryptor.encrypt(data)[16:] data = uid + encryptor.encrypt(data)[16:]
self.last_server_hash = hmac.new(self.user_key, data, self.hashfunc).digest() self.last_server_hash = hmac.new(self.user_key, data, self.hashfunc).digest()
data = check_head + data + self.last_server_hash[:4] data = check_head + data + self.last_server_hash[:4]
self.encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4') self.encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4')
return data + self.pack_client_data(buf) return data + self.pack_client_data(buf)
def auth_data(self): def auth_data(self):
@ -382,8 +404,8 @@ class auth_chain_a(auth_base):
self.server_info.data.connection_id = struct.unpack('<I', os.urandom(4))[0] & 0xFFFFFF self.server_info.data.connection_id = struct.unpack('<I', os.urandom(4))[0] & 0xFFFFFF
self.server_info.data.connection_id += 1 self.server_info.data.connection_id += 1
return b''.join([struct.pack('<I', utc_time), return b''.join([struct.pack('<I', utc_time),
self.server_info.data.local_client_id, self.server_info.data.local_client_id,
struct.pack('<I', self.server_info.data.connection_id)]) struct.pack('<I', self.server_info.data.connection_id)])
def client_pre_encrypt(self, buf): def client_pre_encrypt(self, buf):
ret = b'' ret = b''
@ -419,8 +441,9 @@ class auth_chain_a(auth_base):
break break
server_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest() server_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest()
if server_hash[:2] != self.recv_buf[length + 2 : length + 4]: if server_hash[:2] != self.recv_buf[length + 2: length + 4]:
logging.info('%s: checksum error, data %s' % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]))) logging.info('%s: checksum error, data %s'
% (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length])))
self.raw_trans = True self.raw_trans = True
self.recv_buf = b'' self.recv_buf = b''
raise Exception('client_post_decrypt data uncorrect checksum') raise Exception('client_post_decrypt data uncorrect checksum')
@ -428,7 +451,7 @@ class auth_chain_a(auth_base):
pos = 2 pos = 2
if data_len > 0 and rand_len > 0: if data_len > 0 and rand_len > 0:
pos = 2 + self.rnd_start_pos(rand_len, self.random_server) pos = 2 + self.rnd_start_pos(rand_len, self.random_server)
out_buf += self.encryptor.decrypt(self.recv_buf[pos : data_len + pos]) out_buf += self.encryptor.decrypt(self.recv_buf[pos: data_len + pos])
self.last_server_hash = server_hash self.last_server_hash = server_hash
if self.recv_id == 1: if self.recv_id == 1:
self.server_info.tcp_mss = struct.unpack('<H', out_buf[:2])[0] self.server_info.tcp_mss = struct.unpack('<H', out_buf[:2])[0]
@ -486,16 +509,19 @@ class auth_chain_a(auth_base):
else: else:
self.user_key = self.server_info.recv_iv self.user_key = self.server_info.recv_iv
md5data = hmac.new(self.user_key, self.recv_buf[12 : 12 + 20], self.hashfunc).digest() md5data = hmac.new(self.user_key, self.recv_buf[12: 12 + 20], self.hashfunc).digest()
if md5data[:4] != self.recv_buf[32:36]: if md5data[:4] != self.recv_buf[32:36]:
logging.error('%s data uncorrect auth HMAC-MD5 from %s:%d, data %s' % (self.no_compatible_method, self.server_info.client, self.server_info.client_port, binascii.hexlify(self.recv_buf))) logging.error('%s data uncorrect auth HMAC-MD5 from %s:%d, data %s' % (
self.no_compatible_method, self.server_info.client, self.server_info.client_port,
binascii.hexlify(self.recv_buf)
))
if len(self.recv_buf) < 36: if len(self.recv_buf) < 36:
return (b'', False) return (b'', False)
return self.not_match_return(self.recv_buf) return self.not_match_return(self.recv_buf)
self.last_server_hash = md5data self.last_server_hash = md5data
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc') encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc')
head = encryptor.decrypt(b'\x00' * 16 + self.recv_buf[16:32] + b'\x00') # need an extra byte or recv empty head = encryptor.decrypt(b'\x00' * 16 + self.recv_buf[16:32] + b'\x00') # need an extra byte or recv empty
self.client_over_head = struct.unpack('<H', head[12:14])[0] self.client_over_head = struct.unpack('<H', head[12:14])[0]
utc_time = struct.unpack('<I', head[:4])[0] utc_time = struct.unpack('<I', head[:4])[0]
@ -503,7 +529,9 @@ class auth_chain_a(auth_base):
connection_id = struct.unpack('<I', head[8:12])[0] connection_id = struct.unpack('<I', head[8:12])[0]
time_dif = common.int32(utc_time - (int(time.time()) & 0xffffffff)) time_dif = common.int32(utc_time - (int(time.time()) & 0xffffffff))
if time_dif < -self.max_time_dif or time_dif > self.max_time_dif: if time_dif < -self.max_time_dif or time_dif > self.max_time_dif:
logging.info('%s: wrong timestamp, time_dif %d, data %s' % (self.no_compatible_method, time_dif, binascii.hexlify(head))) logging.info('%s: wrong timestamp, time_dif %d, data %s' % (
self.no_compatible_method, time_dif, binascii.hexlify(head)
))
return self.not_match_return(self.recv_buf) return self.not_match_return(self.recv_buf)
elif self.server_info.data.insert(self.user_id, client_id, connection_id): elif self.server_info.data.insert(self.user_id, client_id, connection_id):
self.has_recv_header = True self.has_recv_header = True
@ -513,7 +541,8 @@ class auth_chain_a(auth_base):
logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf))) logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf)))
return self.not_match_return(self.recv_buf) return self.not_match_return(self.recv_buf)
self.encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4') self.encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4')
self.recv_buf = self.recv_buf[36:] self.recv_buf = self.recv_buf[36:]
self.has_recv_header = True self.has_recv_header = True
sendback = True sendback = True
@ -528,7 +557,7 @@ class auth_chain_a(auth_base):
self.recv_buf = b'' self.recv_buf = b''
if self.recv_id == 0: if self.recv_id == 0:
logging.info(self.no_compatible_method + ': over size') logging.info(self.no_compatible_method + ': over size')
return (b'E'*2048, False) return (b'E' * 2048, False)
else: else:
raise Exception('server_post_decrype data error') raise Exception('server_post_decrype data error')
@ -536,12 +565,14 @@ class auth_chain_a(auth_base):
break break
client_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest() client_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest()
if client_hash[:2] != self.recv_buf[length + 2 : length + 4]: if client_hash[:2] != self.recv_buf[length + 2: length + 4]:
logging.info('%s: checksum error, data %s' % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]))) logging.info('%s: checksum error, data %s' % (
self.no_compatible_method, binascii.hexlify(self.recv_buf[:length])
))
self.raw_trans = True self.raw_trans = True
self.recv_buf = b'' self.recv_buf = b''
if self.recv_id == 0: if self.recv_id == 0:
return (b'E'*2048, False) return (b'E' * 2048, False)
else: else:
raise Exception('server_post_decrype data uncorrect checksum') raise Exception('server_post_decrype data uncorrect checksum')
@ -549,7 +580,7 @@ class auth_chain_a(auth_base):
pos = 2 pos = 2
if data_len > 0 and rand_len > 0: if data_len > 0 and rand_len > 0:
pos = 2 + self.rnd_start_pos(rand_len, self.random_client) pos = 2 + self.rnd_start_pos(rand_len, self.random_client)
out_buf += self.encryptor.decrypt(self.recv_buf[pos : data_len + pos]) out_buf += self.encryptor.decrypt(self.recv_buf[pos: data_len + pos])
self.last_client_hash = client_hash self.last_client_hash = client_hash
self.recv_buf = self.recv_buf[length + 4:] self.recv_buf = self.recv_buf[length + 4:]
if data_len == 0: if data_len == 0:
@ -577,7 +608,8 @@ class auth_chain_a(auth_base):
uid = struct.unpack('<I', self.user_id)[0] ^ struct.unpack('<I', md5data[:4])[0] uid = struct.unpack('<I', self.user_id)[0] ^ struct.unpack('<I', md5data[:4])[0]
uid = struct.pack('<I', uid) uid = struct.pack('<I', uid)
rand_len = self.udp_rnd_data_len(md5data, self.random_client) rand_len = self.udp_rnd_data_len(md5data, self.random_client)
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4') encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4')
out_buf = encryptor.encrypt(buf) out_buf = encryptor.encrypt(buf)
buf = out_buf + os.urandom(rand_len) + authdata + uid buf = out_buf + os.urandom(rand_len) + authdata + uid
return buf + hmac.new(self.user_key, buf, self.hashfunc).digest()[:1] return buf + hmac.new(self.user_key, buf, self.hashfunc).digest()[:1]
@ -590,7 +622,8 @@ class auth_chain_a(auth_base):
mac_key = self.server_info.key mac_key = self.server_info.key
md5data = hmac.new(mac_key, buf[-8:-1], self.hashfunc).digest() md5data = hmac.new(mac_key, buf[-8:-1], self.hashfunc).digest()
rand_len = self.udp_rnd_data_len(md5data, self.random_server) rand_len = self.udp_rnd_data_len(md5data, self.random_server)
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4') encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4')
return encryptor.decrypt(buf[:-8 - rand_len]) return encryptor.decrypt(buf[:-8 - rand_len])
def server_udp_pre_encrypt(self, buf, uid): def server_udp_pre_encrypt(self, buf, uid):
@ -634,11 +667,16 @@ class auth_chain_a(auth_base):
def dispose(self): def dispose(self):
self.server_info.data.remove(self.user_id, self.client_id) self.server_info.data.remove(self.user_id, self.client_id)
class auth_chain_b(auth_chain_a): class auth_chain_b(auth_chain_a):
def __init__(self, method): def __init__(self, method):
super(auth_chain_b, self).__init__(method) super(auth_chain_b, self).__init__(method)
self.salt = b"auth_chain_b" self.salt = b"auth_chain_b"
self.no_compatible_method = 'auth_chain_b' self.no_compatible_method = 'auth_chain_b'
# NOTE
# 补全后长度数组
# 随机在其中选择一个补全到的长度
# 为每个连接初始化一个固定内容的数组
self.data_size_list = [] self.data_size_list = []
self.data_size_list2 = [] self.data_size_list2 = []
@ -648,10 +686,12 @@ class auth_chain_b(auth_chain_a):
self.data_size_list2 = [] self.data_size_list2 = []
random = xorshift128plus() random = xorshift128plus()
random.init_from_bin(key) random.init_from_bin(key)
# 补全数组长为4~12-1
list_len = random.next() % 8 + 4 list_len = random.next() % 8 + 4
for i in range(0, list_len): for i in range(0, list_len):
self.data_size_list.append((int)(random.next() % 2340 % 2040 % 1440)) self.data_size_list.append((int)(random.next() % 2340 % 2040 % 1440))
self.data_size_list.sort() self.data_size_list.sort()
# 补全数组长为8~24-1
list_len = random.next() % 16 + 8 list_len = random.next() % 16 + 8
for i in range(0, list_len): for i in range(0, list_len):
self.data_size_list2.append((int)(random.next() % 2340 % 2040 % 1440)) self.data_size_list2.append((int)(random.next() % 2340 % 2040 % 1440))
@ -672,15 +712,21 @@ class auth_chain_b(auth_chain_a):
random.init_from_bin_len(last_hash, buf_size) random.init_from_bin_len(last_hash, buf_size)
pos = bisect.bisect_left(self.data_size_list, buf_size + self.server_info.overhead) pos = bisect.bisect_left(self.data_size_list, buf_size + self.server_info.overhead)
final_pos = pos + random.next() % (len(self.data_size_list)) final_pos = pos + random.next() % (len(self.data_size_list))
# 假设random均匀分布,则越长的原始数据长度越容易if false
if final_pos < len(self.data_size_list): if final_pos < len(self.data_size_list):
return self.data_size_list[final_pos] - buf_size - self.server_info.overhead return self.data_size_list[final_pos] - buf_size - self.server_info.overhead
# 上面if false后选择2号补全数组,此处有更精细的长度分段
pos = bisect.bisect_left(self.data_size_list2, buf_size + self.server_info.overhead) pos = bisect.bisect_left(self.data_size_list2, buf_size + self.server_info.overhead)
final_pos = pos + random.next() % (len(self.data_size_list2)) final_pos = pos + random.next() % (len(self.data_size_list2))
if final_pos < len(self.data_size_list2): if final_pos < len(self.data_size_list2):
return self.data_size_list2[final_pos] - buf_size - self.server_info.overhead return self.data_size_list2[final_pos] - buf_size - self.server_info.overhead
# final_pos 总是分布在pos~(data_size_list2.len-1)之间
if final_pos < pos + len(self.data_size_list2) - 1: if final_pos < pos + len(self.data_size_list2) - 1:
return 0 return 0
# 有1/len(self.data_size_list2)的概率不满足上一个if ?
# 理论上不会运行到此处,因此可以插入运行断言 ?
# assert False
if buf_size > 1300: if buf_size > 1300:
return random.next() % 31 return random.next() % 31
@ -690,3 +736,105 @@ class auth_chain_b(auth_chain_a):
return random.next() % 521 return random.next() % 521
return random.next() % 1021 return random.next() % 1021
class auth_chain_c(auth_chain_b):
def __init__(self, method):
super(auth_chain_c, self).__init__(method)
self.salt = b"auth_chain_c"
self.no_compatible_method = 'auth_chain_c'
self.data_size_list0 = []
def init_data_size(self, key):
if self.data_size_list0:
self.data_size_list0 = []
random = xorshift128plus()
random.init_from_bin(key)
# 补全数组长为12~24-1
list_len = random.next() % (8 + 16) + (4 + 8)
for i in range(0, list_len):
self.data_size_list0.append((int)(random.next() % 2340 % 2040 % 1440))
self.data_size_list0.sort()
def set_server_info(self, server_info):
self.server_info = server_info
try:
max_client = int(server_info.protocol_param.split('#')[0])
except:
max_client = 64
self.server_info.data.set_max_client(max_client)
self.init_data_size(self.server_info.key)
def rnd_data_len(self, buf_size, last_hash, random):
other_data_size = buf_size + self.server_info.overhead
# 一定要在random使用前初始化,以保证服务器与客户端同步,保证包大小验证结果正确
random.init_from_bin_len(last_hash, buf_size)
# final_pos 总是分布在pos~(data_size_list0.len-1)之间
# 除非data_size_list0中的任何值均过小使其全部都无法容纳buf
if other_data_size >= self.data_size_list0[-1]:
if other_data_size >= 1440:
return 0
if other_data_size > 1300:
return random.next() % 31
if other_data_size > 900:
return random.next() % 127
if other_data_size > 400:
return random.next() % 521
return random.next() % 1021
pos = bisect.bisect_left(self.data_size_list0, other_data_size)
# random select a size in the leftover data_size_list0
final_pos = pos + random.next() % (len(self.data_size_list0) - pos)
return self.data_size_list0[final_pos] - other_data_size
class auth_chain_d(auth_chain_b):
def __init__(self, method):
super(auth_chain_d, self).__init__(method)
self.salt = b"auth_chain_d"
self.no_compatible_method = 'auth_chain_d'
self.data_size_list0 = []
def check_and_patch_data_size(self, random):
# append new item
# when the biggest item(first time) or the last append item(other time) are not big enough.
# but set a limit size (64) to avoid stack overflow.
if self.data_size_list0[-1] < 1300 and len(self.data_size_list0) < 64:
self.data_size_list0.append((int)(random.next() % 2340 % 2040 % 1440))
self.check_and_patch_data_size(random)
def init_data_size(self, key):
if self.data_size_list0:
self.data_size_list0 = []
random = xorshift128plus()
random.init_from_bin(key)
# 补全数组长为12~24-1
list_len = random.next() % (8 + 16) + (4 + 8)
for i in range(0, list_len):
self.data_size_list0.append((int)(random.next() % 2340 % 2040 % 1440))
self.data_size_list0.sort()
old_len = len(self.data_size_list0)
self.check_and_patch_data_size(random)
# if check_and_patch_data_size are work, re-sort again.
if old_len != len(self.data_size_list0):
self.data_size_list0.sort()
def set_server_info(self, server_info):
self.server_info = server_info
try:
max_client = int(server_info.protocol_param.split('#')[0])
except:
max_client = 64
self.server_info.data.set_max_client(max_client)
self.init_data_size(self.server_info.key)
def rnd_data_len(self, buf_size, last_hash, random):
other_data_size = buf_size + self.server_info.overhead
# if other_data_size > the bigest item in data_size_list0, not padding any data
if other_data_size >= self.data_size_list0[-1]:
return 0
random.init_from_bin_len(last_hash, buf_size)
pos = bisect.bisect_left(self.data_size_list0, other_data_size)
# random select a size in the leftover data_size_list0
final_pos = pos + random.next() % (len(self.data_size_list0) - pos)
return self.data_size_list0[final_pos] - other_data_size

1
shadowsocks/obfsplugin/http_simple.py

@ -63,6 +63,7 @@ class http_simple(plain.plain):
self.host = None self.host = None
self.port = 0 self.port = 0
self.recv_buffer = b'' self.recv_buffer = b''
# TODO user config user_agent
self.user_agent = [b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/40.0", self.user_agent = [b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/40.0",
b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/44.0", b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/44.0",
b"Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36", b"Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36",

14
shadowsocks/server.py

@ -25,6 +25,7 @@ import signal
if __name__ == '__main__': if __name__ == '__main__':
import inspect import inspect
file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))) file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe())))
sys.path.insert(0, os.path.join(file_path, '../')) sys.path.insert(0, os.path.join(file_path, '../'))
@ -43,7 +44,8 @@ def main():
try: try:
import resource import resource
logging.info('current process RLIMIT_NOFILE resource: soft %d hard %d' % resource.getrlimit(resource.RLIMIT_NOFILE)) logging.info(
'current process RLIMIT_NOFILE resource: soft %d hard %d' % resource.getrlimit(resource.RLIMIT_NOFILE))
except ImportError: except ImportError:
pass pass
@ -68,7 +70,7 @@ def main():
tcp_servers = [] tcp_servers = []
udp_servers = [] udp_servers = []
dns_resolver = asyncdns.DNSResolver() dns_resolver = asyncdns.DNSResolver(config['black_hostname_list'])
if int(config['workers']) > 1: if int(config['workers']) > 1:
stat_counter_dict = None stat_counter_dict = None
else: else:
@ -103,10 +105,11 @@ def main():
a_config = config.copy() a_config = config.copy()
ipv6_ok = False ipv6_ok = False
logging.info("server start with protocol[%s] password [%s] method [%s] obfs [%s] obfs_param [%s]" % logging.info("server start with protocol[%s] password [%s] method [%s] obfs [%s] obfs_param [%s]" %
(protocol, password, method, obfs, obfs_param)) (protocol, password, method, obfs, obfs_param))
if 'server_ipv6' in a_config: if 'server_ipv6' in a_config:
try: try:
if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == "[" and a_config['server_ipv6'][-1] == "]": if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == "[" and a_config['server_ipv6'][
-1] == "]":
a_config['server_ipv6'] = a_config['server_ipv6'][1:-1] a_config['server_ipv6'] = a_config['server_ipv6'][1:-1]
a_config['server_port'] = int(port) a_config['server_port'] = int(port)
a_config['password'] = password a_config['password'] = password
@ -151,11 +154,13 @@ def main():
logging.warn('received SIGQUIT, doing graceful shutting down..') logging.warn('received SIGQUIT, doing graceful shutting down..')
list(map(lambda s: s.close(next_tick=True), list(map(lambda s: s.close(next_tick=True),
tcp_servers + udp_servers)) tcp_servers + udp_servers))
signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM), signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM),
child_handler) child_handler)
def int_handler(signum, _): def int_handler(signum, _):
sys.exit(1) sys.exit(1)
signal.signal(signal.SIGINT, int_handler) signal.signal(signal.SIGINT, int_handler)
try: try:
@ -191,6 +196,7 @@ def main():
except OSError: # child may already exited except OSError: # child may already exited
pass pass
sys.exit() sys.exit()
signal.signal(signal.SIGTERM, handler) signal.signal(signal.SIGTERM, handler)
signal.signal(signal.SIGQUIT, handler) signal.signal(signal.SIGQUIT, handler)
signal.signal(signal.SIGINT, handler) signal.signal(signal.SIGINT, handler)

18
shadowsocks/shell.py

@ -26,7 +26,6 @@ import logging
from shadowsocks.common import to_bytes, to_str, IPNetwork, PortRange from shadowsocks.common import to_bytes, to_str, IPNetwork, PortRange
from shadowsocks import encrypt from shadowsocks import encrypt
VERBOSE_LEVEL = 5 VERBOSE_LEVEL = 5
verbose = 0 verbose = 0
@ -52,6 +51,7 @@ def print_exception(e):
import traceback import traceback
traceback.print_exc() traceback.print_exc()
def __version(): def __version():
version_str = '' version_str = ''
try: try:
@ -65,9 +65,11 @@ def __version():
pass pass
return version_str return version_str
def print_shadowsocks(): def print_shadowsocks():
print('ShadowsocksR %s' % __version()) print('ShadowsocksR %s' % __version())
def log_shadowsocks_version(): def log_shadowsocks_version():
logging.info('ShadowsocksR %s' % __version()) logging.info('ShadowsocksR %s' % __version())
@ -84,6 +86,7 @@ def find_config():
return sub_find(user_config_path) or sub_find(config_path) return sub_find(user_config_path) or sub_find(config_path)
def check_config(config, is_local): def check_config(config, is_local):
if config.get('daemon', None) == 'stop': if config.get('daemon', None) == 'stop':
# no need to specify configuration for daemon stop # no need to specify configuration for daemon stop
@ -110,13 +113,13 @@ def check_config(config, is_local):
logging.warning('warning: local set to listen on 0.0.0.0, it\'s not safe') logging.warning('warning: local set to listen on 0.0.0.0, it\'s not safe')
if config.get('server', '') in ['127.0.0.1', 'localhost']: if config.get('server', '') in ['127.0.0.1', 'localhost']:
logging.warning('warning: server set to listen on %s:%s, are you sure?' % logging.warning('warning: server set to listen on %s:%s, are you sure?' %
(to_str(config['server']), config['server_port'])) (to_str(config['server']), config['server_port']))
if config.get('timeout', 300) < 100: if config.get('timeout', 300) < 100:
logging.warning('warning: your timeout %d seems too short' % logging.warning('warning: your timeout %d seems too short' %
int(config.get('timeout'))) int(config.get('timeout')))
if config.get('timeout', 300) > 600: if config.get('timeout', 300) > 600:
logging.warning('warning: your timeout %d seems too long' % logging.warning('warning: your timeout %d seems too long' %
int(config.get('timeout'))) int(config.get('timeout')))
if config.get('password') in [b'mypassword']: if config.get('password') in [b'mypassword']:
logging.error('DON\'T USE DEFAULT PASSWORD! Please change it in your ' logging.error('DON\'T USE DEFAULT PASSWORD! Please change it in your '
'config.json!') 'config.json!')
@ -160,7 +163,6 @@ def get_config(is_local):
if config_path is None: if config_path is None:
config_path = find_config() config_path = find_config()
if config_path: if config_path:
logging.debug('loading config from %s' % config_path) logging.debug('loading config from %s' % config_path)
with open(config_path, 'rb') as f: with open(config_path, 'rb') as f:
@ -170,7 +172,6 @@ def get_config(is_local):
logging.error('found an error in config.json: %s', str(e)) logging.error('found an error in config.json: %s', str(e))
sys.exit(1) sys.exit(1)
v_count = 0 v_count = 0
for key, value in optlist: for key, value in optlist:
if key == '-p': if key == '-p':
@ -260,6 +261,9 @@ def get_config(is_local):
config['server'] = to_str(config['server']) config['server'] = to_str(config['server'])
else: else:
config['server'] = to_str(config.get('server', '0.0.0.0')) config['server'] = to_str(config.get('server', '0.0.0.0'))
config['black_hostname_list'] = to_str(config.get('black_hostname_list', '')).split(',')
if len(config['black_hostname_list']) == 1 and config['black_hostname_list'][0] == '':
config['black_hostname_list'] = []
try: try:
config['forbidden_ip'] = \ config['forbidden_ip'] = \
IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128')) IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128'))
@ -398,6 +402,7 @@ def _decode_dict(data):
rv[key] = value rv[key] = value
return rv return rv
class JSFormat: class JSFormat:
def __init__(self): def __init__(self):
self.state = 0 self.state = 0
@ -435,6 +440,7 @@ class JSFormat:
return "\n" return "\n"
return "" return ""
def remove_comment(json): def remove_comment(json):
fmt = JSFormat() fmt = JSFormat()
return "".join([fmt.push(c) for c in json]) return "".join([fmt.push(c) for c in json])

Loading…
Cancel
Save