|
|
@ -307,7 +307,7 @@ class DNSResolver(object): |
|
|
|
|
|
|
|
def _parse_hosts(self): |
|
|
|
etc_path = '/etc/hosts' |
|
|
|
if os.environ.__contains__('WINDIR'): |
|
|
|
if 'WINDIR' in os.environ: |
|
|
|
etc_path = os.environ['WINDIR'] + '/system32/drivers/etc/hosts' |
|
|
|
try: |
|
|
|
with open(etc_path, 'rb') as f: |
|
|
@ -324,7 +324,7 @@ class DNSResolver(object): |
|
|
|
except IOError: |
|
|
|
self._hosts['localhost'] = '127.0.0.1' |
|
|
|
|
|
|
|
def add_to_loop(self, loop): |
|
|
|
def add_to_loop(self, loop, ref=False): |
|
|
|
if self._loop: |
|
|
|
raise Exception('already add to loop') |
|
|
|
self._loop = loop |
|
|
@ -333,21 +333,21 @@ class DNSResolver(object): |
|
|
|
socket.SOL_UDP) |
|
|
|
self._sock.setblocking(False) |
|
|
|
loop.add(self._sock, eventloop.POLL_IN) |
|
|
|
loop.add_handler(self.handle_events, ref=False) |
|
|
|
loop.add_handler(self.handle_events, ref=ref) |
|
|
|
|
|
|
|
def _call_callback(self, hostname, ip, error=None): |
|
|
|
callbacks = self._hostname_to_cb.get(hostname, []) |
|
|
|
for callback in callbacks: |
|
|
|
if self._cb_to_hostname.__contains__(callback): |
|
|
|
if callback in self._cb_to_hostname: |
|
|
|
del self._cb_to_hostname[callback] |
|
|
|
if ip or error: |
|
|
|
callback((hostname, ip), error) |
|
|
|
else: |
|
|
|
callback((hostname, None), |
|
|
|
Exception('unknown hostname %s' % hostname)) |
|
|
|
if self._hostname_to_cb.__contains__(hostname): |
|
|
|
if hostname in self._hostname_to_cb: |
|
|
|
del self._hostname_to_cb[hostname] |
|
|
|
if self._hostname_status.__contains__(hostname): |
|
|
|
if hostname in self._hostname_status: |
|
|
|
del self._hostname_status[hostname] |
|
|
|
|
|
|
|
def _handle_data(self, data): |
|
|
@ -408,7 +408,7 @@ class DNSResolver(object): |
|
|
|
arr.remove(callback) |
|
|
|
if not arr: |
|
|
|
del self._hostname_to_cb[hostname] |
|
|
|
if self._hostname_status.__contains__(hostname): |
|
|
|
if hostname in self._hostname_status: |
|
|
|
del self._hostname_status[hostname] |
|
|
|
|
|
|
|
def _send_req(self, hostname, qtype): |
|
|
@ -422,15 +422,17 @@ class DNSResolver(object): |
|
|
|
self._sock.sendto(req, (server, 53)) |
|
|
|
|
|
|
|
def resolve(self, hostname, callback): |
|
|
|
if type(hostname) != bytes: |
|
|
|
hostname = hostname.encode('utf8') |
|
|
|
if not hostname: |
|
|
|
callback(None, Exception('empty hostname')) |
|
|
|
elif is_ip(hostname): |
|
|
|
callback((hostname, hostname), None) |
|
|
|
elif self._hosts.__contains__(hostname): |
|
|
|
elif hostname in self._hosts: |
|
|
|
logging.debug('hit hosts: %s', hostname) |
|
|
|
ip = self._hosts[hostname] |
|
|
|
callback((hostname, ip), None) |
|
|
|
elif self._cache.__contains__(hostname): |
|
|
|
elif hostname in self._cache: |
|
|
|
logging.debug('hit cache: %s', hostname) |
|
|
|
ip = self._cache[hostname] |
|
|
|
callback((hostname, ip), None) |
|
|
@ -453,3 +455,53 @@ class DNSResolver(object): |
|
|
|
if self._sock: |
|
|
|
self._sock.close() |
|
|
|
self._sock = None |
|
|
|
|
|
|
|
|
|
|
|
def test(): |
|
|
|
dns_resolver = DNSResolver() |
|
|
|
loop = eventloop.EventLoop() |
|
|
|
dns_resolver.add_to_loop(loop, ref=True) |
|
|
|
|
|
|
|
global counter |
|
|
|
counter = 0 |
|
|
|
|
|
|
|
def make_callback(): |
|
|
|
global counter |
|
|
|
|
|
|
|
def callback(result, error): |
|
|
|
global counter |
|
|
|
# TODO: what can we assert? |
|
|
|
print(result, error) |
|
|
|
counter += 1 |
|
|
|
if counter == 10: |
|
|
|
loop.remove_handler(dns_resolver.handle_events) |
|
|
|
dns_resolver.close() |
|
|
|
a_callback = callback |
|
|
|
return a_callback |
|
|
|
|
|
|
|
assert(make_callback() != make_callback()) |
|
|
|
|
|
|
|
dns_resolver.resolve(b'google.com', make_callback()) |
|
|
|
dns_resolver.resolve('google.com', make_callback()) |
|
|
|
dns_resolver.resolve('example.com', make_callback()) |
|
|
|
dns_resolver.resolve('ipv6.google.com', make_callback()) |
|
|
|
dns_resolver.resolve('www.facebook.com', make_callback()) |
|
|
|
dns_resolver.resolve('ns2.google.com', make_callback()) |
|
|
|
dns_resolver.resolve('not.existed.google.com', make_callback()) |
|
|
|
dns_resolver.resolve('invalid.@!#$%^&$@.hostname', make_callback()) |
|
|
|
dns_resolver.resolve('toooooooooooooooooooooooooooooooooooooooooooooooooo' |
|
|
|
'ooooooooooooooooooooooooooooooooooooooooooooooooooo' |
|
|
|
'long.hostname', make_callback()) |
|
|
|
dns_resolver.resolve('toooooooooooooooooooooooooooooooooooooooooooooooooo' |
|
|
|
'ooooooooooooooooooooooooooooooooooooooooooooooooooo' |
|
|
|
'ooooooooooooooooooooooooooooooooooooooooooooooooooo' |
|
|
|
'ooooooooooooooooooooooooooooooooooooooooooooooooooo' |
|
|
|
'ooooooooooooooooooooooooooooooooooooooooooooooooooo' |
|
|
|
'ooooooooooooooooooooooooooooooooooooooooooooooooooo' |
|
|
|
'long.hostname', make_callback()) |
|
|
|
|
|
|
|
loop.run() |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test() |