diff --git a/shadowsocks/asyncdns.py b/shadowsocks/asyncdns.py index 5b5d6b2..426cbf5 100644 --- a/shadowsocks/asyncdns.py +++ b/shadowsocks/asyncdns.py @@ -28,6 +28,8 @@ import common import eventloop +common.patch_socket() + _request_count = 1 # rfc1035 @@ -99,40 +101,30 @@ def parse_ip(addrtype, data, length, offset): elif addrtype == QTYPE_AAAA: return socket.inet_ntop(socket.AF_INET6, data[offset:offset + length]) elif addrtype == QTYPE_CNAME: - return parse_name(data, offset, length)[1] + return parse_name(data, offset)[1] else: return data -def parse_name(data, offset, length=512): +def parse_name(data, offset): p = offset - if (ord(data[offset]) & (128 + 64)) == (128 + 64): - # pointer - pointer = struct.unpack('!H', data[offset:offset + 2])[0] - pointer = pointer & 0x3FFF - if pointer == offset: - return (0, None) - return (2, parse_name(data, pointer)[1]) - else: - labels = [] + labels = [] + l = ord(data[p]) + while l > 0: + if (l & (128 + 64)) == (128 + 64): + # pointer + pointer = struct.unpack('!H', data[p:p + 2])[0] + pointer &= 0x3FFF + r = parse_name(data, pointer) + labels.append(r[1]) + p += 2 + # pointer is the end + return p - offset, '.'.join(labels) + else: + labels.append(data[p + 1:p + 1 + l]) + p += 1 + l l = ord(data[p]) - while l > 0 and p < offset + length: - if (l & (128 + 64)) == (128 + 64): - # pointer - pointer = struct.unpack('!H', data[p:p + 2])[0] - pointer = pointer & 0x3FFF - # if pointer == offset: - # return (0, None) - r = parse_name(data, pointer) - labels.append(r[1]) - p += 2 - # pointer is the end - return (p - offset + 1, '.'.join(labels)) - else: - labels.append(data[p + 1:p + 1 + l]) - p += 1 + l - l = ord(data[p]) - return (p - offset + 1, '.'.join(labels)) + return p - offset + 1, '.'.join(labels) # rfc1035 @@ -158,33 +150,30 @@ def parse_name(data, offset, length=512): # / / # +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ def parse_record(data, offset, question=False): - len, name = parse_name(data, offset) - # TODO - assert len + nlen, name = parse_name(data, offset) if not question: record_type, record_class, record_ttl, record_rdlength = struct.unpack( - '!HHiH', data[offset + len:offset + len + 10] + '!HHiH', data[offset + nlen:offset + nlen + 10] ) - ip = parse_ip(record_type, data, record_rdlength, offset + len + 10) - return len + 10 + record_rdlength, \ + ip = parse_ip(record_type, data, record_rdlength, offset + nlen + 10) + return nlen + 10 + record_rdlength, \ (name, ip, record_type, record_class, record_ttl) else: record_type, record_class = struct.unpack( - '!HH', data[offset + len:offset + len + 4] + '!HH', data[offset + nlen:offset + nlen + 4] ) - return len + 4, (name, None, record_type, record_class, None, None) + return nlen + 4, (name, None, record_type, record_class, None, None) def parse_response(data): try: if len(data) >= 12: header = struct.unpack('!HBBHHHH', data[:12]) - res_id = header[0] - res_qr = header[1] & 128 + # res_id = header[0] + # res_qr = header[1] & 128 res_tc = header[1] & 2 - res_ra = header[2] & 128 + # res_ra = header[2] & 128 res_rcode = header[2] & 15 - # TODO check tc and rcode assert res_tc == 0 assert res_rcode == 0 res_qdcount = header[3] @@ -193,8 +182,6 @@ def parse_response(data): res_arcount = header[6] qds = [] ans = [] - nss = [] - ars = [] offset = 12 for i in xrange(0, res_qdcount): l, r = parse_record(data, offset, True) @@ -209,13 +196,9 @@ def parse_response(data): for i in xrange(0, res_nscount): l, r = parse_record(data, offset) offset += l - if r: - nss.append(r) for i in xrange(0, res_arcount): l, r = parse_record(data, offset) offset += l - if r: - ars.append(r) response = DNSResponse() if qds: response.hostname = qds[0][0] @@ -225,6 +208,7 @@ def parse_response(data): except Exception as e: import traceback traceback.print_exc() + logging.error(e) return None @@ -380,9 +364,9 @@ def test(): resolver = DNSResolver() resolver.add_to_loop(loop) + resolver.resolve('www.google.com', _callback) resolver.resolve('8.8.8.8', _callback) resolver.resolve('www.twitter.com', _callback) - resolver.resolve('www.google.com', _callback) resolver.resolve('ipv6.google.com', _callback) resolver.resolve('ipv6.l.google.com', _callback) resolver.resolve('www.gmail.com', _callback) diff --git a/shadowsocks/common.py b/shadowsocks/common.py index 4f7aac6..6104478 100644 --- a/shadowsocks/common.py +++ b/shadowsocks/common.py @@ -63,11 +63,15 @@ def inet_pton(family, addr): raise RuntimeError("What family?") -if not hasattr(socket, 'inet_pton'): - socket.inet_pton = inet_pton +def patch_socket(): + if not hasattr(socket, 'inet_pton'): + socket.inet_pton = inet_pton -if not hasattr(socket, 'inet_ntop'): - socket.inet_ntop = inet_ntop + if not hasattr(socket, 'inet_ntop'): + socket.inet_ntop = inet_ntop + + +patch_socket() ADDRTYPE_IPV4 = 1