Python port of ShadowsocksR
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

484 lines
17 KiB

10 years ago
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2014-2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
10 years ago
#
# http://www.apache.org/licenses/LICENSE-2.0
10 years ago
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
10 years ago
from __future__ import absolute_import, division, print_function, \
with_statement
10 years ago
import time
10 years ago
import os
10 years ago
import socket
import struct
10 years ago
import re
10 years ago
import logging
from shadowsocks import common, lru_cache, eventloop, utils
10 years ago
# Seconds between sweeps of expired entries from the resolver cache; the
# elapsed time is measured with time.time() in DNSResolver.handle_events.
CACHE_SWEEP_INTERVAL = 30

# One hostname label: 1-63 letters, digits or hyphens, neither starting nor
# ending with a hyphen.  ``\Z`` is used instead of ``$`` because ``$`` also
# matches just before a trailing newline, which let b'com\n' through.
VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d-]{1,63}(?<!-)\Z", re.IGNORECASE)
10 years ago
10 years ago
# Apply shadowsocks.common's socket patch at import time, before this module
# creates or uses any sockets.  NOTE(review): presumably this supplies helpers
# such as socket.inet_ntop on platforms that lack them; confirm in common.py.
common.patch_socket()
10 years ago
# RFC 1035, section 4.1: overall message format
#     +---------------------+
#     |        Header       |
#     +---------------------+
#     |       Question      | the question for the name server
#     +---------------------+
#     |        Answer       | RRs answering the question
#     +---------------------+
#     |      Authority      | RRs pointing toward an authority
#     +---------------------+
#     |      Additional     | RRs holding additional information
#     +---------------------+
#
# RFC 1035, section 4.1.1: header format
#                                     1  1  1  1  1  1
#       0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
#     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#     |                      ID                       |
#     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#     |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
#     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#     |                    QDCOUNT                    |
#     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#     |                    ANCOUNT                    |
#     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#     |                    NSCOUNT                    |
#     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#     |                    ARCOUNT                    |
#     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
10 years ago
# Resource-record TYPE / query QTYPE codes used by this resolver
# (RFC 1035 section 3.2.2-3.2.3; AAAA is defined in RFC 3596).
QTYPE_A = 1
QTYPE_NS = 2
QTYPE_CNAME = 5
QTYPE_AAAA = 28
QTYPE_ANY = 255

# CLASS code for the Internet class (RFC 1035 section 3.2.4).
QCLASS_IN = 1
10 years ago
def build_address(address):
    """Encode a dotted hostname as a DNS wire-format name (RFC 1035 3.1).

    Each label is written as a one-byte length followed by the label bytes,
    and the name is terminated by a single zero byte (the root label).
    Leading and trailing dots are ignored, so b'example.com.' and
    b'example.com' encode identically, and an empty address encodes as the
    bare root.

    :param address: hostname as bytes.
    :returns: the encoded name as bytes, or None if any label is empty
        (e.g. b'a..b') or longer than 63 bytes.
    """
    address = address.strip(b'.')
    # b''.split(b'.') yields [b''], which would otherwise emit a spurious
    # zero-length label in front of the root terminator.
    labels = address.split(b'.') if address else []
    results = []
    for label in labels:
        length = len(label)
        # A zero-length label would end the name early and silently query a
        # different (truncated) name; 63 bytes is the per-label limit.
        if length == 0 or length > 63:
            return None
        results.append(struct.pack('!B', length))
        results.append(label)
    results.append(b'\0')
    return b''.join(results)
10 years ago
10 years ago
10 years ago
def build_request(address, qtype):
    """Build a single-question DNS query packet.

    The packet has a random 2-byte ID, the RD (recursion desired) flag set,
    QDCOUNT=1 and zero answer/authority/additional counts, followed by one
    question for *address* with type *qtype* and class IN.

    :param address: hostname as bytes.
    :param qtype: query type code, e.g. QTYPE_A or QTYPE_AAAA.
    :returns: the request packet as bytes.
    :raises ValueError: if *address* cannot be encoded as a DNS name.
    """
    request_id = os.urandom(2)
    # Flags byte 1 = 0x01 sets RD; flags byte 2 = 0; then QDCOUNT=1 and
    # ANCOUNT, NSCOUNT, ARCOUNT all zero.
    header = struct.pack('!BBHHHH', 1, 0, 1, 0, 0, 0)
    addr = build_address(address)
    if addr is None:
        # Report the bad name explicitly instead of failing with an opaque
        # TypeError when concatenating None below.
        raise ValueError('cannot encode hostname %r' % (address,))
    qtype_qclass = struct.pack('!HH', qtype, QCLASS_IN)
    return request_id + header + addr + qtype_qclass
10 years ago
10 years ago
def parse_ip(addrtype, data, length, offset):
    """Decode the RDATA of one resource record into a usable value.

    A and AAAA records become textual IPv4/IPv6 addresses, CNAME and NS
    records become the target name, and every other type is returned as
    the raw RDATA bytes.

    :param addrtype: the record TYPE code.
    :param data: the whole DNS message (needed to follow name pointers).
    :param length: the record's RDLENGTH.
    :param offset: position of the RDATA within *data*.
    """
    rdata = data[offset:offset + length]
    if addrtype == QTYPE_A:
        return socket.inet_ntop(socket.AF_INET, rdata)
    if addrtype == QTYPE_AAAA:
        return socket.inet_ntop(socket.AF_INET6, rdata)
    if addrtype in [QTYPE_CNAME, QTYPE_NS]:
        # Names may contain compression pointers into the rest of the
        # message, so decode against the full buffer, not the RDATA slice.
        return parse_name(data, offset)[1]
    return rdata
10 years ago
10 years ago
def parse_name(data, offset):
    """Decode a possibly-compressed domain name from a DNS message.

    Handles RFC 1035 section 4.1.4 message compression.  A length byte with
    both top bits set starts a two-byte pointer, and the low 14 bits of that
    pointer give the offset where the rest of the name continues.

    :param data: the whole DNS message as bytes.
    :param offset: position of the name within *data*.
    :returns: ``(consumed, name)``.  *consumed* is the number of bytes the
        name occupies at *offset*, up to and including the first pointer or
        the terminating zero byte.  *name* is the dot-joined labels as bytes.
    :raises ValueError: if compression pointers form a loop.
    :raises struct.error: if the name runs past the end of *data*.
    """
    labels = []
    p = offset
    consumed = None   # fixed by the first pointer, or by the root label
    visited = set()   # pointer targets already followed (loop detection)
    while True:
        length = struct.unpack('!B', data[p:p + 1])[0]
        if (length & 0xC0) == 0xC0:
            if consumed is None:
                consumed = p + 2 - offset
            target = struct.unpack('!H', data[p:p + 2])[0] & 0x3FFF
            # A corrupt or hostile message can point a name back at itself;
            # following it blindly would never terminate.
            if target in visited:
                raise ValueError(
                    'DNS name compression loop at offset %d' % target)
            visited.add(target)
            p = target
        elif length == 0:
            if consumed is None:
                consumed = p + 1 - offset
            return consumed, b'.'.join(labels)
        else:
            labels.append(data[p + 1:p + 1 + length])
            p += 1 + length
10 years ago
# RFC 1035, section 4.1.3: resource record format
#                                     1  1  1  1  1  1
#       0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
#     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#     |                                               |
#     /                                               /
#     /                      NAME                     /
#     |                                               |
#     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#     |                      TYPE                     |
#     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#     |                     CLASS                     |
#     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#     |                      TTL                      |
#     |                                               |
#     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#     |                   RDLENGTH                    |
#     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#     /                     RDATA                     /
#     /                                               /
#     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
def parse_record(data, offset, question=False):
    """Parse one question entry or resource record starting at *offset*.

    :param data: the whole DNS message.
    :param offset: position of the record within *data*.
    :param question: when true, parse a question entry (NAME, TYPE, CLASS)
        instead of a full resource record.
    :returns: ``(consumed, record)``.  For a resource record, *record* is
        ``(name, value, type, class, ttl)`` with *value* decoded by
        parse_ip.  For a question it is
        ``(name, None, type, class, None, None)``.
    """
    name_len, name = parse_name(data, offset)
    fixed = offset + name_len  # start of the fixed-size fields after NAME
    if question:
        qtype, qclass = struct.unpack('!HH', data[fixed:fixed + 4])
        return name_len + 4, (name, None, qtype, qclass, None, None)
    rtype, rclass, ttl, rdlength = struct.unpack(
        '!HHiH', data[fixed:fixed + 10])
    value = parse_ip(rtype, data, rdlength, fixed + 10)
    return name_len + 10 + rdlength, (name, value, rtype, rclass, ttl)
10 years ago
def parse_header(data):
    """Unpack the fixed 12-byte DNS header.

    :returns: ``(id, qr, tc, ra, rcode, qdcount, ancount, nscount,
        arcount)``, or None if *data* is shorter than 12 bytes.  *qr*,
        *tc* and *ra* are the raw masked flag bits (non-zero when set),
        not booleans.
    """
    if len(data) < 12:
        return None
    (msg_id, flags_hi, flags_lo,
     qdcount, ancount, nscount, arcount) = struct.unpack('!HBBHHHH',
                                                         data[:12])
    # assert tc == 0
    # assert rcode in [0, 3]
    return (msg_id,
            flags_hi & 0x80,   # QR
            flags_hi & 0x02,   # TC
            flags_lo & 0x80,   # RA
            flags_lo & 0x0F,   # RCODE
            qdcount, ancount, nscount, arcount)
10 years ago
def parse_response(data):
    """Parse a DNS response datagram into a DNSResponse.

    Only the question and answer sections are decoded.  The authority and
    additional sections are never used by the resolver, so they are not
    parsed; a malformed trailing record therefore cannot cause an otherwise
    usable answer to be thrown away.

    :param data: raw datagram received from a nameserver.
    :returns: a DNSResponse, or None if *data* is shorter than a header or
        fails to parse (the error is reported via utils.print_exception).
    """
    try:
        header = parse_header(data)
        if not header:
            return None
        qdcount, ancount = header[5], header[6]
        offset = 12

        questions = []
        for _ in range(qdcount):
            consumed, record = parse_record(data, offset, True)
            offset += consumed
            if record:
                questions.append(record)

        answers = []
        for _ in range(ancount):
            consumed, record = parse_record(data, offset)
            offset += consumed
            if record:
                answers.append(record)

        response = DNSResponse()
        if questions:
            response.hostname = questions[0][0]
        for q in questions:
            response.questions.append((q[1], q[2], q[3]))
        for an in answers:
            response.answers.append((an[1], an[2], an[3]))
        return response
    except Exception as e:
        # Best effort: a bad packet from the network must not propagate into
        # the event loop; report it and drop the packet.
        utils.print_exception(e)
        return None
10 years ago
10 years ago
def is_valid_hostname(hostname):
    """Return True if *hostname* (bytes) is a syntactically valid DNS name.

    The name must be non-empty and at most 255 bytes.  A single trailing
    dot (fully-qualified form) is accepted.  Every dot-separated label must
    match VALID_HOSTNAME.
    """
    if not hostname or len(hostname) > 255:
        return False
    # endswith() works for bytes on both Python 2 and 3.  Indexing bytes
    # yields an int on Python 3, so ``hostname[-1] == b'.'`` never matched
    # and names such as b'example.com.' were rejected.
    if hostname.endswith(b'.'):
        hostname = hostname[:-1]
    return all(VALID_HOSTNAME.match(x) for x in hostname.split(b'.'))
10 years ago
10 years ago
class DNSResponse(object):
    """Decoded view of a DNS response.

    Attributes are plain public fields:
      hostname  -- name from the first question, or None.
      questions -- list of (addr, type, class) tuples.
      answers   -- list of (addr, type, class) tuples.
    """

    def __init__(self):
        self.hostname = None
        self.questions = []  # each: (addr, type, class)
        self.answers = []  # each: (addr, type, class)

    def __str__(self):
        answers_text = str(self.answers)
        return '%s: %s' % (self.hostname, answers_text)
# Which query is outstanding for a hostname: an A query is sent first, and
# an AAAA query only if the A reply carried no address.
STATUS_IPV4, STATUS_IPV6 = 0, 1
class DNSResolver(object):
    """Non-blocking DNS resolver driven by an event loop.

    Queries are sent over one UDP socket to the IPv4 nameservers listed in
    /etc/resolv.conf, falling back to 8.8.4.4 and 8.8.8.8.  Each hostname
    is looked up as an A record first and, if the reply carries no
    address, as an AAAA record.  Literal IPs, hosts-file entries and cached
    answers are returned synchronously.  The cache is an
    lru_cache.LRUCache(timeout=300); presumably the timeout is in seconds,
    so confirm in lru_cache.

    Callbacks are invoked as ``callback((hostname, ip), error)``, or as
    ``callback(None, error)`` for names rejected before any query is sent.
    """

    def __init__(self):
        self._loop = None
        self._hosts = {}            # hostname -> ip, from the hosts file
        self._hostname_status = {}  # hostname -> STATUS_IPV4 / STATUS_IPV6
        self._hostname_to_cb = {}   # hostname -> [callback, ...]
        self._cb_to_hostname = {}   # callback -> hostname
        self._cache = lru_cache.LRUCache(timeout=300)
        self._last_time = time.time()
        self._sock = None
        self._servers = None
        self._parse_resolv()
        self._parse_hosts()
        # TODO monitor hosts change and reload hosts
        # TODO parse /etc/gai.conf and follow its rules

    def _parse_resolv(self):
        """Load IPv4 nameserver addresses from /etc/resolv.conf."""
        self._servers = []
        try:
            with open('/etc/resolv.conf', 'rb') as f:
                content = f.readlines()
                for line in content:
                    line = line.strip()
                    if not line or not line.startswith(b'nameserver'):
                        continue
                    parts = line.split()
                    if len(parts) < 2:
                        continue
                    server = parts[1]
                    # Only IPv4 servers: the query socket is AF_INET.
                    if common.is_ip(server) == socket.AF_INET:
                        if type(server) != str:
                            server = server.decode('utf8')
                        self._servers.append(server)
        except IOError:
            pass
        if not self._servers:
            self._servers = ['8.8.4.4', '8.8.8.8']

    def _parse_hosts(self):
        """Load static hostname -> ip mappings from the system hosts file."""
        etc_path = '/etc/hosts'
        if 'WINDIR' in os.environ:
            etc_path = os.environ['WINDIR'] + '/system32/drivers/etc/hosts'
        try:
            with open(etc_path, 'rb') as f:
                for line in f.readlines():
                    # Drop trailing "# comment" text so its words are not
                    # registered as hostnames.
                    line = line.split(b'#', 1)[0].strip()
                    parts = line.split()
                    if len(parts) >= 2:
                        ip = parts[0]
                        if common.is_ip(ip):
                            # split() never yields empty strings.
                            for hostname in parts[1:]:
                                self._hosts[hostname] = ip
        except IOError:
            # resolve() looks names up by bytes key (and entries read from
            # the file are bytes), so the fallback must be bytes too.
            self._hosts[b'localhost'] = b'127.0.0.1'

    def add_to_loop(self, loop, ref=False):
        """Create the UDP query socket and register this resolver on *loop*.

        :raises Exception: if the resolver was already added to a loop.
        """
        if self._loop:
            raise Exception('already add to loop')
        self._loop = loop
        # TODO when dns server is IPv6
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
                                   socket.SOL_UDP)
        self._sock.setblocking(False)
        loop.add(self._sock, eventloop.POLL_IN)
        loop.add_handler(self.handle_events, ref=ref)

    def _call_callback(self, hostname, ip, error=None):
        """Deliver the result for *hostname* to every waiting callback.

        A falsy *ip* with no *error* is reported as an "unknown hostname"
        exception.  All bookkeeping for *hostname* is cleared afterwards.
        """
        callbacks = self._hostname_to_cb.get(hostname, [])
        for callback in callbacks:
            if callback in self._cb_to_hostname:
                del self._cb_to_hostname[callback]
            if ip or error:
                callback((hostname, ip), error)
            else:
                callback((hostname, None),
                         Exception('unknown hostname %s' % hostname))
        if hostname in self._hostname_to_cb:
            del self._hostname_to_cb[hostname]
        if hostname in self._hostname_status:
            del self._hostname_status[hostname]

    def _handle_data(self, data):
        """Process one response datagram received from a nameserver."""
        response = parse_response(data)
        if not (response and response.hostname):
            return
        hostname = response.hostname
        ip = None
        for answer in response.answers:
            if answer[1] in (QTYPE_A, QTYPE_AAAA) and \
                    answer[2] == QCLASS_IN:
                ip = answer[0]
                break
        if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \
                == STATUS_IPV4:
            # The A lookup produced no address: retry as AAAA.
            self._hostname_status[hostname] = STATUS_IPV6
            self._send_req(hostname, QTYPE_AAAA)
        elif ip:
            self._cache[hostname] = ip
            self._call_callback(hostname, ip)
        elif self._hostname_status.get(hostname, None) == STATUS_IPV6:
            # Give up only once the AAAA reply itself came back empty.
            for question in response.questions:
                if question[1] == QTYPE_AAAA:
                    self._call_callback(hostname, None)
                    break

    def handle_events(self, events):
        """Event-loop handler: read one reply or recover from a socket error.

        Also sweeps the cache every CACHE_SWEEP_INTERVAL seconds.
        """
        for sock, fd, event in events:
            if sock != self._sock:
                continue
            if event & eventloop.POLL_ERR:
                logging.error('dns socket err')
                self._loop.remove(self._sock)
                self._sock.close()
                # TODO when dns server is IPv6
                self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
                                           socket.SOL_UDP)
                self._sock.setblocking(False)
                self._loop.add(self._sock, eventloop.POLL_IN)
            else:
                data, addr = sock.recvfrom(1024)
                if addr[0] not in self._servers:
                    # Ignore datagrams from anyone but our nameservers.
                    logging.warning('received a packet other than our dns')
                    break
                self._handle_data(data)
            break
        now = time.time()
        if now - self._last_time > CACHE_SWEEP_INTERVAL:
            self._cache.sweep()
            self._last_time = now

    def remove_callback(self, callback):
        """Cancel a callback previously passed to resolve(); no-op if unknown."""
        hostname = self._cb_to_hostname.get(callback)
        if hostname:
            del self._cb_to_hostname[callback]
            arr = self._hostname_to_cb.get(hostname, None)
            if arr:
                arr.remove(callback)
                if not arr:
                    del self._hostname_to_cb[hostname]
                    if hostname in self._hostname_status:
                        del self._hostname_status[hostname]

    def _send_req(self, hostname, qtype):
        """Send one query for *hostname* to every configured nameserver."""
        req = build_request(hostname, qtype)
        for server in self._servers:
            logging.debug('resolving %s with type %d using server %s',
                          hostname, qtype, server)
            self._sock.sendto(req, (server, 53))

    def resolve(self, hostname, callback):
        """Resolve *hostname* and report through *callback*.

        Literal IPs, hosts-file entries and cached results are answered
        synchronously.  Otherwise an A query is sent and *callback* fires
        when a reply arrives.  Empty or invalid names are rejected with
        ``callback(None, error)``.
        """
        if type(hostname) != bytes:
            hostname = hostname.encode('utf8')
        if not hostname:
            callback(None, Exception('empty hostname'))
        elif common.is_ip(hostname):
            callback((hostname, hostname), None)
        elif hostname in self._hosts:
            logging.debug('hit hosts: %s', hostname)
            ip = self._hosts[hostname]
            callback((hostname, ip), None)
        elif hostname in self._cache:
            logging.debug('hit cache: %s', hostname)
            ip = self._cache[hostname]
            callback((hostname, ip), None)
        else:
            if not is_valid_hostname(hostname):
                callback(None, Exception('invalid hostname: %s' % hostname))
                return
            arr = self._hostname_to_cb.get(hostname, None)
            if not arr:
                self._hostname_status[hostname] = STATUS_IPV4
                self._send_req(hostname, QTYPE_A)
                self._hostname_to_cb[hostname] = [callback]
            else:
                arr.append(callback)
                # TODO send again only if waited too long
                self._send_req(hostname, QTYPE_A)
            # Register every waiter, not just the first, so remove_callback()
            # can cancel any of them; otherwise a removed callback would
            # still be invoked when the reply arrives.
            self._cb_to_hostname[callback] = hostname

    def close(self):
        """Close the query socket, if it is open."""
        if self._sock:
            self._sock.close()
            self._sock = None
10 years ago
def test():
    """Manual smoke test: resolve a mix of valid and invalid names live.

    Requires network access to the configured nameservers.  Every callback
    prints its result; after the ninth one the handler is removed from the
    loop and the resolver is closed.
    """
    dns_resolver = DNSResolver()
    loop = eventloop.EventLoop()
    dns_resolver.add_to_loop(loop, ref=True)

    global counter
    counter = 0

    def make_callback():
        def callback(result, error):
            global counter
            # TODO: what can we assert?
            print(result, error)
            counter += 1
            if counter == 9:
                # All nine requests below have answered.
                loop.remove_handler(dns_resolver.handle_events)
                dns_resolver.close()
        return callback

    # Each call must produce a distinct callback object.
    assert(make_callback() != make_callback())

    hostnames = [
        b'google.com',
        'google.com',
        'example.com',
        'ipv6.google.com',
        'www.facebook.com',
        'ns2.google.com',
        'invalid.@!#$%^&$@.hostname',
        'toooooooooooooooooooooooooooooooooooooooooooooooooo'
        'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
        'long.hostname',
        'toooooooooooooooooooooooooooooooooooooooooooooooooo'
        'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
        'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
        'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
        'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
        'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
        'long.hostname',
    ]
    for name in hostnames:
        dns_resolver.resolve(name, make_callback())
    loop.run()
if __name__ == '__main__':
    # Running the module directly performs the live resolution smoke test.
    test()