Browse Source

add cache

auth
clowwindy 11 years ago
parent
commit
fbc4906445
  1. 32
      shadowsocks/asyncdns.py

32
shadowsocks/asyncdns.py

@ -25,6 +25,7 @@ import socket
import struct import struct
import logging import logging
import common import common
import lru_cache
import eventloop import eventloop
@ -242,8 +243,7 @@ class DNSResolver(object):
self._hostname_status = {} self._hostname_status = {}
self._hostname_to_cb = {} self._hostname_to_cb = {}
self._cb_to_hostname = {} self._cb_to_hostname = {}
# TODO add caching self._cache = lru_cache.LRUCache(timeout=300)
# TODO try ipv4 and ipv6 sequencely
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
socket.SOL_UDP) socket.SOL_UDP)
self._sock.setblocking(False) self._sock.setblocking(False)
@ -276,11 +276,21 @@ class DNSResolver(object):
loop.add(self._sock, eventloop.POLL_IN) loop.add(self._sock, eventloop.POLL_IN)
loop.add_handler(self.handle_events) loop.add_handler(self.handle_events)
def _call_callback(self, hostname, ip):
callbacks = self._hostname_to_cb.get(hostname, [])
for callback in callbacks:
if self._cb_to_hostname.__contains__(callback):
del self._cb_to_hostname[callback]
callback((hostname, ip), None)
if self._hostname_to_cb.__contains__(hostname):
del self._hostname_to_cb[hostname]
if self._hostname_status.__contains__(hostname):
del self._hostname_status[hostname]
def _handle_data(self, data): def _handle_data(self, data):
response = parse_response(data) response = parse_response(data)
if response and response.hostname: if response and response.hostname:
hostname = response.hostname hostname = response.hostname
callbacks = self._hostname_to_cb.get(hostname, [])
ip = None ip = None
for answer in response.answers: for answer in response.answers:
if answer[1] in (QTYPE_A, QTYPE_AAAA) and \ if answer[1] in (QTYPE_A, QTYPE_AAAA) and \
@ -291,15 +301,9 @@ class DNSResolver(object):
== STATUS_IPV4: == STATUS_IPV4:
self._hostname_status[hostname] = STATUS_IPV6 self._hostname_status[hostname] = STATUS_IPV6
self._send_req(hostname, QTYPE_AAAA) self._send_req(hostname, QTYPE_AAAA)
return else:
for callback in callbacks: self._cache[hostname] = ip
if self._cb_to_hostname.__contains__(callback): self._call_callback(hostname, ip)
del self._cb_to_hostname[callback]
callback((hostname, ip), None)
if self._hostname_to_cb.__contains__(hostname):
del self._hostname_to_cb[hostname]
if self._hostname_status.__contains__(hostname):
del self._hostname_status[hostname]
def handle_events(self, events): def handle_events(self, events):
for sock, fd, event in events: for sock, fd, event in events:
@ -344,6 +348,10 @@ class DNSResolver(object):
callback(None, Exception('empty hostname')) callback(None, Exception('empty hostname'))
elif is_ip(hostname): elif is_ip(hostname):
callback(hostname, None) callback(hostname, None)
elif self._cache.__contains__(hostname):
logging.debug('hit cache: %s', hostname)
ip = self._cache[hostname]
callback(ip, None)
else: else:
arr = self._hostname_to_cb.get(hostname, None) arr = self._hostname_to_cb.get(hostname, None)
if not arr: if not arr:

Loading…
Cancel
Save