diff --git a/shadowsocks/common.py b/shadowsocks/common.py index d582923..0f94e07 100644 --- a/shadowsocks/common.py +++ b/shadowsocks/common.py @@ -185,7 +185,7 @@ def parse_header(data): class IPNetwork(object): - ADDRLENGTH = {socket.AF_INET: 32, socket.AF_INET6: 128} + ADDRLENGTH = {socket.AF_INET: 32, socket.AF_INET6: 128, False: 0} def __init__(self, addrs): self._network_list_v4 = [] @@ -197,6 +197,7 @@ class IPNetwork(object): def add_network(self, addr): block = addr.split('/') addr_family = is_ip(block[0]) + addr_len = IPNetwork.ADDRLENGTH[addr_family] if addr_family is socket.AF_INET: ip, = struct.unpack("!I", socket.inet_aton(block[0])) elif addr_family is socket.AF_INET6: @@ -210,10 +211,9 @@ class IPNetwork(object): ip >>= 1 prefix_size += 1 logging.warn("You did't specify CIDR routing prefix size for %s, " - "implicit treated as %s/%d" % (addr, addr, - IPNetwork.ADDRLENGTH[addr_family] - prefix_size)) - elif block[1].isdigit() and int(block[1]) <= IPNetwork.ADDRLENGTH[addr_family]: - prefix_size = IPNetwork.ADDRLENGTH[addr_family] - int(block[1]) + "implicit treated as %s/%d" % (addr, addr, addr_len)) + elif block[1].isdigit() and int(block[1]) <= addr_len: + prefix_size = addr_len - int(block[1]) ip >>= prefix_size else: raise SyntaxError("Not a valid CIDR notation: %s" % addr) @@ -226,11 +226,13 @@ class IPNetwork(object): addr_family = is_ip(addr) if addr_family is socket.AF_INET: ip, = struct.unpack("!I", socket.inet_aton(addr)) - return any(map(lambda (naddr, ps): naddr == ip >> ps, self._network_list_v4)) + return any(map(lambda (n, ps): n == ip >> ps, + self._network_list_v4)) elif addr_family is socket.AF_INET6: hi, lo = struct.unpack("!QQ", inet_pton(addr_family, addr)) ip = (hi << 64) | lo - return any(map(lambda (naddr, ps): naddr == ip >> ps, self._network_list_v6)) + return any(map(lambda (n, ps): n == ip >> ps, + self._network_list_v6)) else: return False @@ -262,7 +264,7 @@ def test_pack_header(): def test_ip_network(): - ip_network = IPNetwork('127.0.0.0/24,::ff:1/112,::1,192.168.1.1,192.168.2.0') + ip_network = IPNetwork('127.0.0.0/24,::ff:1/112,::1,192.168.1.1,192.0.2.0') assert '127.0.0.1' in ip_network assert '127.0.1.1' not in ip_network assert ':ff:ffff' in ip_network @@ -271,8 +273,8 @@ def test_ip_network(): assert '::2' not in ip_network assert '192.168.1.1' in ip_network assert '192.168.1.2' not in ip_network - assert '192.168.2.1' in ip_network - assert '192.168.3.1' in ip_network # 192.168.2.0 is treated as 192.168.2.0/23 + assert '192.0.2.1' in ip_network + assert '192.0.3.1' in ip_network # 192.0.2.0 is treated as 192.0.2.0/23 assert 'www.google.com' not in ip_network