diff --git a/shadowsocks/asyncdns.py b/shadowsocks/asyncdns.py index 4354b1d..46459c9 100644 --- a/shadowsocks/asyncdns.py +++ b/shadowsocks/asyncdns.py @@ -307,7 +307,7 @@ class DNSResolver(object): def _parse_hosts(self): etc_path = '/etc/hosts' - if os.environ.__contains__('WINDIR'): + if 'WINDIR' in os.environ: etc_path = os.environ['WINDIR'] + '/system32/drivers/etc/hosts' try: with open(etc_path, 'rb') as f: @@ -324,7 +324,7 @@ class DNSResolver(object): except IOError: self._hosts['localhost'] = '127.0.0.1' - def add_to_loop(self, loop): + def add_to_loop(self, loop, ref=False): if self._loop: raise Exception('already add to loop') self._loop = loop @@ -333,21 +333,21 @@ class DNSResolver(object): socket.SOL_UDP) self._sock.setblocking(False) loop.add(self._sock, eventloop.POLL_IN) - loop.add_handler(self.handle_events, ref=False) + loop.add_handler(self.handle_events, ref=ref) def _call_callback(self, hostname, ip, error=None): callbacks = self._hostname_to_cb.get(hostname, []) for callback in callbacks: - if self._cb_to_hostname.__contains__(callback): + if callback in self._cb_to_hostname: del self._cb_to_hostname[callback] if ip or error: callback((hostname, ip), error) else: callback((hostname, None), Exception('unknown hostname %s' % hostname)) - if self._hostname_to_cb.__contains__(hostname): + if hostname in self._hostname_to_cb: del self._hostname_to_cb[hostname] - if self._hostname_status.__contains__(hostname): + if hostname in self._hostname_status: del self._hostname_status[hostname] def _handle_data(self, data): @@ -408,7 +408,7 @@ class DNSResolver(object): arr.remove(callback) if not arr: del self._hostname_to_cb[hostname] - if self._hostname_status.__contains__(hostname): + if hostname in self._hostname_status: del self._hostname_status[hostname] def _send_req(self, hostname, qtype): @@ -422,15 +422,17 @@ class DNSResolver(object): self._sock.sendto(req, (server, 53)) def resolve(self, hostname, callback): + if type(hostname) != bytes: + hostname = hostname.encode('utf8') if not hostname: callback(None, Exception('empty hostname')) elif is_ip(hostname): callback((hostname, hostname), None) - elif self._hosts.__contains__(hostname): + elif hostname in self._hosts: logging.debug('hit hosts: %s', hostname) ip = self._hosts[hostname] callback((hostname, ip), None) - elif self._cache.__contains__(hostname): + elif hostname in self._cache: logging.debug('hit cache: %s', hostname) ip = self._cache[hostname] callback((hostname, ip), None) @@ -453,3 +455,53 @@ class DNSResolver(object): if self._sock: self._sock.close() self._sock = None + + +def test(): + dns_resolver = DNSResolver() + loop = eventloop.EventLoop() + dns_resolver.add_to_loop(loop, ref=True) + + global counter + counter = 0 + + def make_callback(): + global counter + + def callback(result, error): + global counter + # TODO: what can we assert? + print(result, error) + counter += 1 + if counter == 10: + loop.remove_handler(dns_resolver.handle_events) + dns_resolver.close() + a_callback = callback + return a_callback + + assert(make_callback() != make_callback()) + + dns_resolver.resolve(b'google.com', make_callback()) + dns_resolver.resolve('google.com', make_callback()) + dns_resolver.resolve('example.com', make_callback()) + dns_resolver.resolve('ipv6.google.com', make_callback()) + dns_resolver.resolve('www.facebook.com', make_callback()) + dns_resolver.resolve('ns2.google.com', make_callback()) + dns_resolver.resolve('not.existed.google.com', make_callback()) + dns_resolver.resolve('invalid.@!#$%^&$@.hostname', make_callback()) + dns_resolver.resolve('toooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'long.hostname', make_callback()) + dns_resolver.resolve('toooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'long.hostname', make_callback()) + + loop.run() + + +if __name__ == '__main__': + test() \ No newline at end of file