diff --git a/shadowsocks/asyncdns.py b/shadowsocks/asyncdns.py index 797704e..868ea61 100644 --- a/shadowsocks/asyncdns.py +++ b/shadowsocks/asyncdns.py @@ -27,12 +27,12 @@ import logging if __name__ == '__main__': import sys import inspect + file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))) sys.path.insert(0, os.path.join(file_path, '../')) from shadowsocks import common, lru_cache, eventloop, shell - CACHE_SWEEP_INTERVAL = 30 VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d_-]{1,63}(?> %s', hostname, self._cache[hostname]) ip = self._cache[hostname] callback((hostname, ip), None) + elif any(hostname.endswith(t) for t in self._black_hostname_list): + callback(None, Exception('hostname <%s> is block by the black hostname list' % hostname)) + return else: if not is_valid_hostname(hostname): callback(None, Exception('invalid hostname: %s' % hostname)) return if False: addrs = socket.getaddrinfo(hostname, 0, 0, - socket.SOCK_DGRAM, socket.SOL_UDP) + socket.SOCK_DGRAM, socket.SOL_UDP) if addrs: af, socktype, proto, canonname, sa = addrs[0] - logging.debug('DNS resolve %s %s' % (hostname, sa[0]) ) + logging.debug('DNS resolve %s %s' % (hostname, sa[0])) self._cache[hostname] = sa[0] callback((hostname, sa[0]), None) return @@ -506,7 +520,11 @@ class DNSResolver(object): def test(): - dns_resolver = DNSResolver() + black_hostname_list = [ + 'baidu.com', + 'yahoo.com', + ] + dns_resolver = DNSResolver(black_hostname_list=black_hostname_list) loop = eventloop.EventLoop() dns_resolver.add_to_loop(loop) @@ -521,16 +539,20 @@ def test(): # TODO: what can we assert? print(result, error) counter += 1 - if counter == 9: + if counter == 12: dns_resolver.close() loop.stop() + a_callback = callback return a_callback - assert(make_callback() != make_callback()) + assert (make_callback() != make_callback()) dns_resolver.resolve(b'google.com', make_callback()) dns_resolver.resolve('google.com', make_callback()) + dns_resolver.resolve('baidu.com', make_callback()) + dns_resolver.resolve('map.baidu.com', make_callback()) + dns_resolver.resolve('yahoo.com', make_callback()) dns_resolver.resolve('example.com', make_callback()) dns_resolver.resolve('ipv6.google.com', make_callback()) dns_resolver.resolve('www.facebook.com', make_callback()) @@ -546,10 +568,25 @@ def test(): 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' 'long.hostname', make_callback()) - loop.run() + # test black_hostname_list + dns_resolver = DNSResolver(black_hostname_list=[]) + assert type(dns_resolver._black_hostname_list) == list + assert len(dns_resolver._black_hostname_list) == 0 + dns_resolver.close() + dns_resolver = DNSResolver(black_hostname_list=123) + assert type(dns_resolver._black_hostname_list) == list + assert len(dns_resolver._black_hostname_list) == 0 + dns_resolver.close() + dns_resolver = DNSResolver(black_hostname_list=None) + assert type(dns_resolver._black_hostname_list) == list + assert len(dns_resolver._black_hostname_list) == 0 + dns_resolver.close() + dns_resolver = DNSResolver() + assert type(dns_resolver._black_hostname_list) == list + assert dns_resolver._black_hostname_list.__len__() == 0 + dns_resolver.close() if __name__ == '__main__': test() - diff --git a/shadowsocks/obfsplugin/auth_chain.py b/shadowsocks/obfsplugin/auth_chain.py index 26097bf..b11d545 100644 --- a/shadowsocks/obfsplugin/auth_chain.py +++ b/shadowsocks/obfsplugin/auth_chain.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +# -*- coding: utf-8 -*- # # Copyright 2015-2015 breakwa11 # @@ -38,17 +39,31 @@ from shadowsocks import common, lru_cache, encrypt from shadowsocks.obfsplugin import plain from shadowsocks.common import to_bytes, to_str, ord, chr + def create_auth_chain_a(method): return auth_chain_a(method) + def create_auth_chain_b(method): return auth_chain_b(method) + +def create_auth_chain_c(method): + return auth_chain_c(method) + + +def create_auth_chain_d(method): + return auth_chain_d(method) + + obfs_map = { - 'auth_chain_a': (create_auth_chain_a,), - 'auth_chain_b': (create_auth_chain_b,), + 'auth_chain_a': (create_auth_chain_a,), + 'auth_chain_b': (create_auth_chain_b,), + 'auth_chain_c': (create_auth_chain_c,), + 'auth_chain_d': (create_auth_chain_d,), } + class xorshift128plus(object): max_int = (1 << 64) - 1 mov_mask = (1 << (64 - 23)) - 1 @@ -80,12 +95,14 @@ class xorshift128plus(object): for i in range(4): self.next() + def match_begin(str1, str2): if len(str1) >= len(str2): if str1[:len(str2)] == str2: return True return False + class auth_base(plain.plain): def __init__(self, method): super(auth_base, self).__init__(method) @@ -96,7 +113,7 @@ class auth_base(plain.plain): def init_data(self): return '' - def get_overhead(self, direction): # direction: true for c->s false for s->c + def get_overhead(self, direction): # direction: true for c->s false for s->c return self.overhead def set_server_info(self, server_info): @@ -118,9 +135,10 @@ class auth_base(plain.plain): self.raw_trans = True self.overhead = 0 if self.method == self.no_compatible_method: - return (b'E'*2048, False) + return (b'E' * 2048, False) return (buf, False) + class client_queue(object): def __init__(self, begin_id): self.front = begin_id - 64 @@ -175,13 +193,14 @@ class client_queue(object): self.addref() return True + class obfs_auth_chain_data(object): def __init__(self, name): self.name = name self.user_id = {} self.local_client_id = b'' self.connection_id = 0 - self.set_max_client(64) # max active client count + self.set_max_client(64) # max active client count def update(self, user_id, client_id, connection_id): if user_id not in self.user_id: @@ -203,7 +222,7 @@ class obfs_auth_chain_data(object): if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable: if local_client_id.first() is None or len(local_client_id) < self.max_client: if client_id not in local_client_id: - #TODO: check + # TODO: check local_client_id[client_id] = client_queue(connection_id) else: local_client_id[client_id].re_enable(connection_id) @@ -212,7 +231,7 @@ class obfs_auth_chain_data(object): if not local_client_id[local_client_id.first()].is_active(): del local_client_id[local_client_id.first()] if client_id not in local_client_id: - #TODO: check + # TODO: check local_client_id[client_id] = client_queue(connection_id) else: local_client_id[client_id].re_enable(connection_id) @@ -229,6 +248,7 @@ class obfs_auth_chain_data(object): if client_id in local_client_id: local_client_id[client_id].delref() + class auth_chain_a(auth_base): def __init__(self, method): super(auth_chain_a, self).__init__(method) @@ -240,7 +260,7 @@ class auth_chain_a(auth_base): self.has_recv_header = False self.client_id = 0 self.connection_id = 0 - self.max_time_dif = 60 * 60 * 24 # time dif (second) setting + self.max_time_dif = 60 * 60 * 24 # time dif (second) setting self.salt = b"auth_chain_a" self.no_compatible_method = 'auth_chain_a' self.pack_id = 1 @@ -259,7 +279,7 @@ class auth_chain_a(auth_base): def init_data(self): return obfs_auth_chain_data(self.method) - def get_overhead(self, direction): # direction: true for c->s false for s->c + def get_overhead(self, direction): # direction: true for c->s false for s->c return self.overhead def set_server_info(self, server_info): @@ -362,14 +382,16 @@ class auth_chain_a(auth_base): if self.user_key is None: self.user_key = self.server_info.key - encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc', b'\x00' * 16) + encryptor = encrypt.Encryptor( + to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc', b'\x00' * 16) uid = struct.unpack(' 0 and rand_len > 0: pos = 2 + self.rnd_start_pos(rand_len, self.random_server) - out_buf += self.encryptor.decrypt(self.recv_buf[pos : data_len + pos]) + out_buf += self.encryptor.decrypt(self.recv_buf[pos: data_len + pos]) self.last_server_hash = server_hash if self.recv_id == 1: self.server_info.tcp_mss = struct.unpack(' self.max_time_dif: - logging.info('%s: wrong timestamp, time_dif %d, data %s' % (self.no_compatible_method, time_dif, binascii.hexlify(head))) + logging.info('%s: wrong timestamp, time_dif %d, data %s' % ( + self.no_compatible_method, time_dif, binascii.hexlify(head) + )) return self.not_match_return(self.recv_buf) elif self.server_info.data.insert(self.user_id, client_id, connection_id): self.has_recv_header = True @@ -513,7 +541,8 @@ class auth_chain_a(auth_base): logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf))) return self.not_match_return(self.recv_buf) - self.encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4') + self.encryptor = encrypt.Encryptor( + to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4') self.recv_buf = self.recv_buf[36:] self.has_recv_header = True sendback = True @@ -528,7 +557,7 @@ class auth_chain_a(auth_base): self.recv_buf = b'' if self.recv_id == 0: logging.info(self.no_compatible_method + ': over size') - return (b'E'*2048, False) + return (b'E' * 2048, False) else: raise Exception('server_post_decrype data error') @@ -536,12 +565,14 @@ class auth_chain_a(auth_base): break client_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest() - if client_hash[:2] != self.recv_buf[length + 2 : length + 4]: - logging.info('%s: checksum error, data %s' % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]))) + if client_hash[:2] != self.recv_buf[length + 2: length + 4]: + logging.info('%s: checksum error, data %s' % ( + self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]) + )) self.raw_trans = True self.recv_buf = b'' if self.recv_id == 0: - return (b'E'*2048, False) + return (b'E' * 2048, False) else: raise Exception('server_post_decrype data uncorrect checksum') @@ -549,7 +580,7 @@ class auth_chain_a(auth_base): pos = 2 if data_len > 0 and rand_len > 0: pos = 2 + self.rnd_start_pos(rand_len, self.random_client) - out_buf += self.encryptor.decrypt(self.recv_buf[pos : data_len + pos]) + out_buf += self.encryptor.decrypt(self.recv_buf[pos: data_len + pos]) self.last_client_hash = client_hash self.recv_buf = self.recv_buf[length + 4:] if data_len == 0: @@ -577,7 +608,8 @@ class auth_chain_a(auth_base): uid = struct.unpack(' 1300: return random.next() % 31 @@ -690,3 +736,105 @@ class auth_chain_b(auth_chain_a): return random.next() % 521 return random.next() % 1021 + +class auth_chain_c(auth_chain_b): + def __init__(self, method): + super(auth_chain_c, self).__init__(method) + self.salt = b"auth_chain_c" + self.no_compatible_method = 'auth_chain_c' + self.data_size_list0 = [] + + def init_data_size(self, key): + if self.data_size_list0: + self.data_size_list0 = [] + random = xorshift128plus() + random.init_from_bin(key) + # 补全数组长为12~24-1 + list_len = random.next() % (8 + 16) + (4 + 8) + for i in range(0, list_len): + self.data_size_list0.append((int)(random.next() % 2340 % 2040 % 1440)) + self.data_size_list0.sort() + + def set_server_info(self, server_info): + self.server_info = server_info + try: + max_client = int(server_info.protocol_param.split('#')[0]) + except: + max_client = 64 + self.server_info.data.set_max_client(max_client) + self.init_data_size(self.server_info.key) + + def rnd_data_len(self, buf_size, last_hash, random): + other_data_size = buf_size + self.server_info.overhead + # 一定要在random使用前初始化,以保证服务器与客户端同步,保证包大小验证结果正确 + random.init_from_bin_len(last_hash, buf_size) + # final_pos 总是分布在pos~(data_size_list0.len-1)之间 + # 除非data_size_list0中的任何值均过小使其全部都无法容纳buf + if other_data_size >= self.data_size_list0[-1]: + if other_data_size >= 1440: + return 0 + if other_data_size > 1300: + return random.next() % 31 + if other_data_size > 900: + return random.next() % 127 + if other_data_size > 400: + return random.next() % 521 + return random.next() % 1021 + + pos = bisect.bisect_left(self.data_size_list0, other_data_size) + # random select a size in the leftover data_size_list0 + final_pos = pos + random.next() % (len(self.data_size_list0) - pos) + return self.data_size_list0[final_pos] - other_data_size + + +class auth_chain_d(auth_chain_b): + def __init__(self, method): + super(auth_chain_d, self).__init__(method) + self.salt = b"auth_chain_d" + self.no_compatible_method = 'auth_chain_d' + self.data_size_list0 = [] + + def check_and_patch_data_size(self, random): + # append new item + # when the biggest item(first time) or the last append item(other time) are not big enough. + # but set a limit size (64) to avoid stack overflow. + if self.data_size_list0[-1] < 1300 and len(self.data_size_list0) < 64: + self.data_size_list0.append((int)(random.next() % 2340 % 2040 % 1440)) + self.check_and_patch_data_size(random) + + def init_data_size(self, key): + if self.data_size_list0: + self.data_size_list0 = [] + random = xorshift128plus() + random.init_from_bin(key) + # 补全数组长为12~24-1 + list_len = random.next() % (8 + 16) + (4 + 8) + for i in range(0, list_len): + self.data_size_list0.append((int)(random.next() % 2340 % 2040 % 1440)) + self.data_size_list0.sort() + old_len = len(self.data_size_list0) + self.check_and_patch_data_size(random) + # if check_and_patch_data_size are work, re-sort again. + if old_len != len(self.data_size_list0): + self.data_size_list0.sort() + + def set_server_info(self, server_info): + self.server_info = server_info + try: + max_client = int(server_info.protocol_param.split('#')[0]) + except: + max_client = 64 + self.server_info.data.set_max_client(max_client) + self.init_data_size(self.server_info.key) + + def rnd_data_len(self, buf_size, last_hash, random): + other_data_size = buf_size + self.server_info.overhead + # if other_data_size > the bigest item in data_size_list0, not padding any data + if other_data_size >= self.data_size_list0[-1]: + return 0 + + random.init_from_bin_len(last_hash, buf_size) + pos = bisect.bisect_left(self.data_size_list0, other_data_size) + # random select a size in the leftover data_size_list0 + final_pos = pos + random.next() % (len(self.data_size_list0) - pos) + return self.data_size_list0[final_pos] - other_data_size diff --git a/shadowsocks/obfsplugin/http_simple.py b/shadowsocks/obfsplugin/http_simple.py index 6f1a05e..ff3c5fd 100644 --- a/shadowsocks/obfsplugin/http_simple.py +++ b/shadowsocks/obfsplugin/http_simple.py @@ -63,6 +63,7 @@ class http_simple(plain.plain): self.host = None self.port = 0 self.recv_buffer = b'' + # TODO user config user_agent self.user_agent = [b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/40.0", b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/44.0", b"Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36", diff --git a/shadowsocks/server.py b/shadowsocks/server.py index c18ad1c..0815389 100755 --- a/shadowsocks/server.py +++ b/shadowsocks/server.py @@ -25,6 +25,7 @@ import signal if __name__ == '__main__': import inspect + file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))) sys.path.insert(0, os.path.join(file_path, '../')) @@ -43,7 +44,8 @@ def main(): try: import resource - logging.info('current process RLIMIT_NOFILE resource: soft %d hard %d' % resource.getrlimit(resource.RLIMIT_NOFILE)) + logging.info( + 'current process RLIMIT_NOFILE resource: soft %d hard %d' % resource.getrlimit(resource.RLIMIT_NOFILE)) except ImportError: pass @@ -68,7 +70,7 @@ def main(): tcp_servers = [] udp_servers = [] - dns_resolver = asyncdns.DNSResolver() + dns_resolver = asyncdns.DNSResolver(config['black_hostname_list']) if int(config['workers']) > 1: stat_counter_dict = None else: @@ -103,10 +105,11 @@ def main(): a_config = config.copy() ipv6_ok = False logging.info("server start with protocol[%s] password [%s] method [%s] obfs [%s] obfs_param [%s]" % - (protocol, password, method, obfs, obfs_param)) + (protocol, password, method, obfs, obfs_param)) if 'server_ipv6' in a_config: try: - if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == "[" and a_config['server_ipv6'][-1] == "]": + if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == "[" and a_config['server_ipv6'][ + -1] == "]": a_config['server_ipv6'] = a_config['server_ipv6'][1:-1] a_config['server_port'] = int(port) a_config['password'] = password @@ -151,11 +154,13 @@ def main(): logging.warn('received SIGQUIT, doing graceful shutting down..') list(map(lambda s: s.close(next_tick=True), tcp_servers + udp_servers)) + signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM), child_handler) def int_handler(signum, _): sys.exit(1) + signal.signal(signal.SIGINT, int_handler) try: @@ -191,6 +196,7 @@ def main(): except OSError: # child may already exited pass sys.exit() + signal.signal(signal.SIGTERM, handler) signal.signal(signal.SIGQUIT, handler) signal.signal(signal.SIGINT, handler) diff --git a/shadowsocks/shell.py b/shadowsocks/shell.py index 6246d98..a1547d0 100755 --- a/shadowsocks/shell.py +++ b/shadowsocks/shell.py @@ -26,7 +26,6 @@ import logging from shadowsocks.common import to_bytes, to_str, IPNetwork, PortRange from shadowsocks import encrypt - VERBOSE_LEVEL = 5 verbose = 0 @@ -52,6 +51,7 @@ def print_exception(e): import traceback traceback.print_exc() + def __version(): version_str = '' try: @@ -65,9 +65,11 @@ def __version(): pass return version_str + def print_shadowsocks(): print('ShadowsocksR %s' % __version()) + def log_shadowsocks_version(): logging.info('ShadowsocksR %s' % __version()) @@ -84,6 +86,7 @@ def find_config(): return sub_find(user_config_path) or sub_find(config_path) + def check_config(config, is_local): if config.get('daemon', None) == 'stop': # no need to specify configuration for daemon stop @@ -110,13 +113,13 @@ def check_config(config, is_local): logging.warning('warning: local set to listen on 0.0.0.0, it\'s not safe') if config.get('server', '') in ['127.0.0.1', 'localhost']: logging.warning('warning: server set to listen on %s:%s, are you sure?' % - (to_str(config['server']), config['server_port'])) + (to_str(config['server']), config['server_port'])) if config.get('timeout', 300) < 100: logging.warning('warning: your timeout %d seems too short' % - int(config.get('timeout'))) + int(config.get('timeout'))) if config.get('timeout', 300) > 600: logging.warning('warning: your timeout %d seems too long' % - int(config.get('timeout'))) + int(config.get('timeout'))) if config.get('password') in [b'mypassword']: logging.error('DON\'T USE DEFAULT PASSWORD! Please change it in your ' 'config.json!') @@ -160,7 +163,6 @@ def get_config(is_local): if config_path is None: config_path = find_config() - if config_path: logging.debug('loading config from %s' % config_path) with open(config_path, 'rb') as f: @@ -170,7 +172,6 @@ def get_config(is_local): logging.error('found an error in config.json: %s', str(e)) sys.exit(1) - v_count = 0 for key, value in optlist: if key == '-p': @@ -260,6 +261,9 @@ def get_config(is_local): config['server'] = to_str(config['server']) else: config['server'] = to_str(config.get('server', '0.0.0.0')) + config['black_hostname_list'] = to_str(config.get('black_hostname_list', '')).split(',') + if len(config['black_hostname_list']) == 1 and config['black_hostname_list'][0] == '': + config['black_hostname_list'] = [] try: config['forbidden_ip'] = \ IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128')) @@ -398,6 +402,7 @@ def _decode_dict(data): rv[key] = value return rv + class JSFormat: def __init__(self): self.state = 0 @@ -435,6 +440,7 @@ class JSFormat: return "\n" return "" + def remove_comment(json): fmt = JSFormat() return "".join([fmt.push(c) for c in json])