Browse Source

DNSResolver add black_hostname_list

TODO read black_hostname_list from config
akkariiin/master
Akkariiin 7 years ago
parent
commit
15b4d97b6c
  1. 29
      shadowsocks/asyncdns.py

29
shadowsocks/asyncdns.py

@ -27,12 +27,12 @@ import logging
if __name__ == '__main__': if __name__ == '__main__':
import sys import sys
import inspect import inspect
file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))) file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe())))
sys.path.insert(0, os.path.join(file_path, '../')) sys.path.insert(0, os.path.join(file_path, '../'))
from shadowsocks import common, lru_cache, eventloop, shell from shadowsocks import common, lru_cache, eventloop, shell
CACHE_SWEEP_INTERVAL = 30 CACHE_SWEEP_INTERVAL = 30
VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d_-]{1,63}(?<!-)$", re.IGNORECASE) VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d_-]{1,63}(?<!-)$", re.IGNORECASE)
@ -77,6 +77,7 @@ QTYPE_CNAME = 5
QTYPE_NS = 2 QTYPE_NS = 2
QCLASS_IN = 1 QCLASS_IN = 1
def detect_ipv6_supprot(): def detect_ipv6_supprot():
if 'has_ipv6' in dir(socket): if 'has_ipv6' in dir(socket):
try: try:
@ -89,8 +90,10 @@ def detect_ipv6_supprot():
print('IPv6 not support') print('IPv6 not support')
return False return False
IPV6_CONNECTION_SUPPORT = detect_ipv6_supprot() IPV6_CONNECTION_SUPPORT = detect_ipv6_supprot()
def build_address(address): def build_address(address):
address = address.strip(b'.') address = address.strip(b'.')
labels = address.split(b'.') labels = address.split(b'.')
@ -175,7 +178,7 @@ def parse_record(data, offset, question=False):
) )
ip = parse_ip(record_type, data, record_rdlength, offset + nlen + 10) ip = parse_ip(record_type, data, record_rdlength, offset + nlen + 10)
return nlen + 10 + record_rdlength, \ return nlen + 10 + record_rdlength, \
(name, ip, record_type, record_class, record_ttl) (name, ip, record_type, record_class, record_ttl)
else: else:
record_type, record_class = struct.unpack( record_type, record_class = struct.unpack(
'!HH', data[offset + nlen:offset + nlen + 4] '!HH', data[offset + nlen:offset + nlen + 4]
@ -209,7 +212,7 @@ def parse_response(data):
if not header: if not header:
return None return None
res_id, res_qr, res_tc, res_ra, res_rcode, res_qdcount, \ res_id, res_qr, res_tc, res_ra, res_rcode, res_qdcount, \
res_ancount, res_nscount, res_arcount = header res_ancount, res_nscount, res_arcount = header
qds = [] qds = []
ans = [] ans = []
@ -266,14 +269,18 @@ STATUS_IPV6 = 1
class DNSResolver(object): class DNSResolver(object):
def __init__(self, black_hostname_list=None):
def __init__(self):
self._loop = None self._loop = None
self._hosts = {} self._hosts = {}
self._hostname_status = {} self._hostname_status = {}
self._hostname_to_cb = {} self._hostname_to_cb = {}
self._cb_to_hostname = {} self._cb_to_hostname = {}
self._cache = lru_cache.LRUCache(timeout=300) self._cache = lru_cache.LRUCache(timeout=300)
# TODO read black_hostname_list from config
if type(black_hostname_list) != list:
self._black_hostname_list = []
else:
self._black_hostname_list = black_hostname_list
self._sock = None self._sock = None
self._servers = None self._servers = None
self._parse_resolv() self._parse_resolv()
@ -377,7 +384,7 @@ class DNSResolver(object):
ip = None ip = None
for answer in response.answers: for answer in response.answers:
if answer[1] in (QTYPE_A, QTYPE_AAAA) and \ if answer[1] in (QTYPE_A, QTYPE_AAAA) and \
answer[2] == QCLASS_IN: answer[2] == QCLASS_IN:
ip = answer[0] ip = answer[0]
break break
if IPV6_CONNECTION_SUPPORT: if IPV6_CONNECTION_SUPPORT:
@ -465,16 +472,18 @@ class DNSResolver(object):
logging.debug('hit cache: %s', hostname) logging.debug('hit cache: %s', hostname)
ip = self._cache[hostname] ip = self._cache[hostname]
callback((hostname, ip), None) callback((hostname, ip), None)
elif hostname in self._black_hostname_list:
callback(None, Exception('hostname <%s> in the black hostname list' % hostname))
else: else:
if not is_valid_hostname(hostname): if not is_valid_hostname(hostname):
callback(None, Exception('invalid hostname: %s' % hostname)) callback(None, Exception('invalid hostname: %s' % hostname))
return return
if False: if False:
addrs = socket.getaddrinfo(hostname, 0, 0, addrs = socket.getaddrinfo(hostname, 0, 0,
socket.SOCK_DGRAM, socket.SOL_UDP) socket.SOCK_DGRAM, socket.SOL_UDP)
if addrs: if addrs:
af, socktype, proto, canonname, sa = addrs[0] af, socktype, proto, canonname, sa = addrs[0]
logging.debug('DNS resolve %s %s' % (hostname, sa[0]) ) logging.debug('DNS resolve %s %s' % (hostname, sa[0]))
self._cache[hostname] = sa[0] self._cache[hostname] = sa[0]
callback((hostname, sa[0]), None) callback((hostname, sa[0]), None)
return return
@ -524,10 +533,11 @@ def test():
if counter == 9: if counter == 9:
dns_resolver.close() dns_resolver.close()
loop.stop() loop.stop()
a_callback = callback a_callback = callback
return a_callback return a_callback
assert(make_callback() != make_callback()) assert (make_callback() != make_callback())
dns_resolver.resolve(b'google.com', make_callback()) dns_resolver.resolve(b'google.com', make_callback())
dns_resolver.resolve('google.com', make_callback()) dns_resolver.resolve('google.com', make_callback())
@ -552,4 +562,3 @@ def test():
if __name__ == '__main__': if __name__ == '__main__':
test() test()

Loading…
Cancel
Save