Browse Source

add asyncdns test

auth
clowwindy 10 years ago
parent
commit
030cdbcec0
  1. 70
      shadowsocks/asyncdns.py

70
shadowsocks/asyncdns.py

@ -307,7 +307,7 @@ class DNSResolver(object):
def _parse_hosts(self): def _parse_hosts(self):
etc_path = '/etc/hosts' etc_path = '/etc/hosts'
if os.environ.__contains__('WINDIR'): if 'WINDIR' in os.environ:
etc_path = os.environ['WINDIR'] + '/system32/drivers/etc/hosts' etc_path = os.environ['WINDIR'] + '/system32/drivers/etc/hosts'
try: try:
with open(etc_path, 'rb') as f: with open(etc_path, 'rb') as f:
@ -324,7 +324,7 @@ class DNSResolver(object):
except IOError: except IOError:
self._hosts['localhost'] = '127.0.0.1' self._hosts['localhost'] = '127.0.0.1'
def add_to_loop(self, loop): def add_to_loop(self, loop, ref=False):
if self._loop: if self._loop:
raise Exception('already add to loop') raise Exception('already add to loop')
self._loop = loop self._loop = loop
@ -333,21 +333,21 @@ class DNSResolver(object):
socket.SOL_UDP) socket.SOL_UDP)
self._sock.setblocking(False) self._sock.setblocking(False)
loop.add(self._sock, eventloop.POLL_IN) loop.add(self._sock, eventloop.POLL_IN)
loop.add_handler(self.handle_events, ref=False) loop.add_handler(self.handle_events, ref=ref)
def _call_callback(self, hostname, ip, error=None): def _call_callback(self, hostname, ip, error=None):
callbacks = self._hostname_to_cb.get(hostname, []) callbacks = self._hostname_to_cb.get(hostname, [])
for callback in callbacks: for callback in callbacks:
if self._cb_to_hostname.__contains__(callback): if callback in self._cb_to_hostname:
del self._cb_to_hostname[callback] del self._cb_to_hostname[callback]
if ip or error: if ip or error:
callback((hostname, ip), error) callback((hostname, ip), error)
else: else:
callback((hostname, None), callback((hostname, None),
Exception('unknown hostname %s' % hostname)) Exception('unknown hostname %s' % hostname))
if self._hostname_to_cb.__contains__(hostname): if hostname in self._hostname_to_cb:
del self._hostname_to_cb[hostname] del self._hostname_to_cb[hostname]
if self._hostname_status.__contains__(hostname): if hostname in self._hostname_status:
del self._hostname_status[hostname] del self._hostname_status[hostname]
def _handle_data(self, data): def _handle_data(self, data):
@ -408,7 +408,7 @@ class DNSResolver(object):
arr.remove(callback) arr.remove(callback)
if not arr: if not arr:
del self._hostname_to_cb[hostname] del self._hostname_to_cb[hostname]
if self._hostname_status.__contains__(hostname): if hostname in self._hostname_status:
del self._hostname_status[hostname] del self._hostname_status[hostname]
def _send_req(self, hostname, qtype): def _send_req(self, hostname, qtype):
@ -422,15 +422,17 @@ class DNSResolver(object):
self._sock.sendto(req, (server, 53)) self._sock.sendto(req, (server, 53))
def resolve(self, hostname, callback): def resolve(self, hostname, callback):
if type(hostname) != bytes:
hostname = hostname.encode('utf8')
if not hostname: if not hostname:
callback(None, Exception('empty hostname')) callback(None, Exception('empty hostname'))
elif is_ip(hostname): elif is_ip(hostname):
callback((hostname, hostname), None) callback((hostname, hostname), None)
elif self._hosts.__contains__(hostname): elif hostname in self._hosts:
logging.debug('hit hosts: %s', hostname) logging.debug('hit hosts: %s', hostname)
ip = self._hosts[hostname] ip = self._hosts[hostname]
callback((hostname, ip), None) callback((hostname, ip), None)
elif self._cache.__contains__(hostname): elif hostname in self._cache:
logging.debug('hit cache: %s', hostname) logging.debug('hit cache: %s', hostname)
ip = self._cache[hostname] ip = self._cache[hostname]
callback((hostname, ip), None) callback((hostname, ip), None)
@ -453,3 +455,53 @@ class DNSResolver(object):
if self._sock: if self._sock:
self._sock.close() self._sock.close()
self._sock = None self._sock = None
def test():
dns_resolver = DNSResolver()
loop = eventloop.EventLoop()
dns_resolver.add_to_loop(loop, ref=True)
global counter
counter = 0
def make_callback():
global counter
def callback(result, error):
global counter
# TODO: what can we assert?
print(result, error)
counter += 1
if counter == 10:
loop.remove_handler(dns_resolver.handle_events)
dns_resolver.close()
a_callback = callback
return a_callback
assert(make_callback() != make_callback())
dns_resolver.resolve(b'google.com', make_callback())
dns_resolver.resolve('google.com', make_callback())
dns_resolver.resolve('example.com', make_callback())
dns_resolver.resolve('ipv6.google.com', make_callback())
dns_resolver.resolve('www.facebook.com', make_callback())
dns_resolver.resolve('ns2.google.com', make_callback())
dns_resolver.resolve('not.existed.google.com', make_callback())
dns_resolver.resolve('invalid.@!#$%^&$@.hostname', make_callback())
dns_resolver.resolve('toooooooooooooooooooooooooooooooooooooooooooooooooo'
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
'long.hostname', make_callback())
dns_resolver.resolve('toooooooooooooooooooooooooooooooooooooooooooooooooo'
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
'long.hostname', make_callback())
loop.run()
if __name__ == '__main__':
test()
Loading…
Cancel
Save