diff --git a/shadowsocks/asyncdns.py b/shadowsocks/asyncdns.py index 6f60dc9..6fee6b9 100644 --- a/shadowsocks/asyncdns.py +++ b/shadowsocks/asyncdns.py @@ -233,18 +233,6 @@ def parse_response(data): return None -def is_ip(address): - for family in (socket.AF_INET, socket.AF_INET6): - try: - if type(address) != str: - address = address.decode('utf8') - socket.inet_pton(family, address) - return family - except (TypeError, ValueError, OSError, IOError): - pass - return False - - def is_valid_hostname(hostname): if len(hostname) > 255: return False @@ -296,7 +284,7 @@ class DNSResolver(object): parts = line.split() if len(parts) >= 2: server = parts[1] - if is_ip(server) == socket.AF_INET: + if common.is_ip(server) == socket.AF_INET: if type(server) != str: server = server.decode('utf8') self._servers.append(server) @@ -316,7 +304,7 @@ class DNSResolver(object): parts = line.split() if len(parts) >= 2: ip = parts[0] - if is_ip(ip): + if common.is_ip(ip): for i in range(1, len(parts)): hostname = parts[i] if hostname: @@ -423,7 +411,7 @@ class DNSResolver(object): hostname = hostname.encode('utf8') if not hostname: callback(None, Exception('empty hostname')) - elif is_ip(hostname): + elif common.is_ip(hostname): callback((hostname, hostname), None) elif hostname in self._hosts: logging.debug('hit hosts: %s', hostname) diff --git a/shadowsocks/common.py b/shadowsocks/common.py index e4f698c..0c4e278 100644 --- a/shadowsocks/common.py +++ b/shadowsocks/common.py @@ -101,6 +101,18 @@ def inet_pton(family, addr): raise RuntimeError("What family?") +def is_ip(address): + for family in (socket.AF_INET, socket.AF_INET6): + try: + if type(address) != str: + address = address.decode('utf8') + inet_pton(family, address) + return family + except (TypeError, ValueError, OSError, IOError): + pass + return False + + def patch_socket(): if not hasattr(socket, 'inet_pton'): socket.inet_pton = inet_pton @@ -172,6 +184,61 @@ def parse_header(data): return addrtype, to_bytes(dest_addr), dest_port, header_length +class IPNetwork(object): + ADDRLENGTH = {socket.AF_INET: 32, socket.AF_INET6: 128, False: 0} + + def __init__(self, addrs): + self._network_list_v4 = [] + self._network_list_v6 = [] + if type(addrs) == str: + addrs = addrs.split(',') + list(map(self.add_network, addrs)) + + def add_network(self, addr): + if addr is "": + return + block = addr.split('/') + addr_family = is_ip(block[0]) + addr_len = IPNetwork.ADDRLENGTH[addr_family] + if addr_family is socket.AF_INET: + ip, = struct.unpack("!I", socket.inet_aton(block[0])) + elif addr_family is socket.AF_INET6: + hi, lo = struct.unpack("!QQ", inet_pton(addr_family, block[0])) + ip = (hi << 64) | lo + else: + raise SyntaxError("Not a valid CIDR notation: %s" % addr) + if len(block) is 1: + prefix_size = 0 + while (ip & 1) == 0 and ip is not 0: + ip >>= 1 + prefix_size += 1 + logging.warn("You did't specify CIDR routing prefix size for %s, " + "implicit treated as %s/%d" % (addr, addr, addr_len)) + elif block[1].isdigit() and int(block[1]) <= addr_len: + prefix_size = addr_len - int(block[1]) + ip >>= prefix_size + else: + raise SyntaxError("Not a valid CIDR notation: %s" % addr) + if addr_family is socket.AF_INET: + self._network_list_v4.append((ip, prefix_size)) + else: + self._network_list_v6.append((ip, prefix_size)) + + def __contains__(self, addr): + addr_family = is_ip(addr) + if addr_family is socket.AF_INET: + ip, = struct.unpack("!I", socket.inet_aton(addr)) + return any(map(lambda n_ps: n_ps[0] == ip >> n_ps[1], + self._network_list_v4)) + elif addr_family is socket.AF_INET6: + hi, lo = struct.unpack("!QQ", inet_pton(addr_family, addr)) + ip = (hi << 64) | lo + return any(map(lambda n_ps: n_ps[0] == ip >> n_ps[1], + self._network_list_v6)) + else: + return False + + def test_inet_conv(): ipv4 = b'8.8.4.4' b = inet_pton(socket.AF_INET, ipv4) @@ -198,7 +265,23 @@ def test_pack_header(): assert pack_addr(b'www.google.com') == b'\x03\x0ewww.google.com' +def test_ip_network(): + ip_network = IPNetwork('127.0.0.0/24,::ff:1/112,::1,192.168.1.1,192.0.2.0') + assert '127.0.0.1' in ip_network + assert '127.0.1.1' not in ip_network + assert ':ff:ffff' in ip_network + assert '::ffff:1' not in ip_network + assert '::1' in ip_network + assert '::2' not in ip_network + assert '192.168.1.1' in ip_network + assert '192.168.1.2' not in ip_network + assert '192.0.2.1' in ip_network + assert '192.0.3.1' in ip_network # 192.0.2.0 is treated as 192.0.2.0/23 + assert 'www.google.com' not in ip_network + + if __name__ == '__main__': test_inet_conv() test_parse_header() test_pack_header() + test_ip_network() diff --git a/shadowsocks/utils.py b/shadowsocks/utils.py index a51c965..6ea3daa 100644 --- a/shadowsocks/utils.py +++ b/shadowsocks/utils.py @@ -29,7 +29,7 @@ import json import sys import getopt import logging -from shadowsocks.common import to_bytes, to_str +from shadowsocks.common import to_bytes, to_str, IPNetwork VERBOSE_LEVEL = 5 @@ -193,6 +193,8 @@ def get_config(is_local): sys.exit(2) else: config['server'] = config.get('server', '0.0.0.0') + config['forbidden_ip'] = \ + IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128')) config['server_port'] = config.get('server_port', 8388) if is_local and not config.get('password', None): diff --git a/tests/test_large_file.sh b/tests/test_large_file.sh index e8acd79..33bcb59 100755 --- a/tests/test_large_file.sh +++ b/tests/test_large_file.sh @@ -8,7 +8,7 @@ mkdir -p tmp $PYTHON shadowsocks/local.py -c tests/aes.json & LOCAL=$! -$PYTHON shadowsocks/server.py -c tests/aes.json & +$PYTHON shadowsocks/server.py -c tests/aes.json --forbidden-ip "" & SERVER=$! sleep 3