diff --git a/shadowsocks/asyncdns.py b/shadowsocks/asyncdns.py index 45676ad..02311d2 100644 --- a/shadowsocks/asyncdns.py +++ b/shadowsocks/asyncdns.py @@ -22,6 +22,7 @@ # SOFTWARE. import time +import os import socket import struct import logging @@ -173,7 +174,7 @@ def parse_response(data): # res_ra = header[2] & 128 res_rcode = header[2] & 15 assert res_tc == 0 - assert res_rcode == 0 + assert res_rcode in [0, 3] res_qdcount = header[3] res_ancount = header[4] res_nscount = header[5] @@ -238,15 +239,19 @@ class DNSResolver(object): def __init__(self): self._loop = None self._request_id = 1 + self._hosts = {} self._hostname_status = {} self._hostname_to_cb = {} self._cb_to_hostname = {} self._cache = lru_cache.LRUCache(timeout=300) self._last_time = time.time() self._sock = None - self._parse_config() + self._parse_resolv() + self._parse_hosts() + # TODO monitor hosts change and reload hosts + # TODO parse /etc/gai.conf and follow its rules - def _parse_config(self): + def _parse_resolv(self): try: with open('/etc/resolv.conf', 'rb') as f: servers = [] @@ -255,7 +260,7 @@ class DNSResolver(object): line = line.strip() if line: if line.startswith('nameserver'): - parts = line.split(' ') + parts = line.split() if len(parts) >= 2: server = parts[1] if is_ip(server): @@ -268,6 +273,25 @@ class DNSResolver(object): pass self._dns_server = ('8.8.8.8', 53) + def _parse_hosts(self): + etc_path = '/etc/hosts' + if os.environ.__contains__('WINDIR'): + etc_path = os.environ['WINDIR'] + '/system32/drivers/etc/hosts' + try: + with open(etc_path, 'rb') as f: + for line in f.readlines(): + line = line.strip() + parts = line.split() + if len(parts) >= 2: + ip = parts[0] + if is_ip(ip): + for i in xrange(1, len(parts)): + hostname = parts[i] + if hostname: + self._hosts[hostname] = ip + except IOError: + self._hosts['localhost'] = '127.0.0.1' + def add_to_loop(self, loop): if self._loop: raise Exception('already add to loop') @@ -356,11 +380,15 @@ class DNSResolver(object): if not hostname: callback(None, Exception('empty hostname')) elif is_ip(hostname): - callback(hostname, None) + callback((hostname, hostname), None) + elif self._hosts.__contains__(hostname): + logging.debug('hit hosts: %s', hostname) + ip = self._hosts[hostname] + callback((hostname, ip), None) elif self._cache.__contains__(hostname): logging.debug('hit cache: %s', hostname) ip = self._cache[hostname] - callback(ip, None) + callback((hostname, ip), None) else: arr = self._hostname_to_cb.get(hostname, None) if not arr: @@ -390,6 +418,8 @@ def test(): for hostname in ['www.google.com', '8.8.8.8', + 'localhost', + 'activate.adobe.com', 'www.twitter.com', 'ipv6.google.com', 'ipv6.l.google.com',