Python port of ShadowsocksR
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

484 lines
17 KiB

10 years ago
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2014 clowwindy
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
10 years ago
import time
10 years ago
import os
10 years ago
import socket
import struct
10 years ago
import re
10 years ago
import logging
10 years ago
import common
10 years ago
import lru_cache
10 years ago
import eventloop
10 years ago
# Minimum number of seconds between two sweeps of the resolver's result cache.
CACHE_SWEEP_INTERVAL = 30

# Matches a single hostname label: 1-63 letters, digits or dashes, neither
# starting nor ending with a dash.  Written as a raw string so that ``\d`` is
# passed to the regex engine instead of being an invalid string escape.
VALID_HOSTNAME = re.compile(r"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
10 years ago
# Module-level side effect executed at import time.
# NOTE(review): presumably installs missing helpers (e.g. inet_pton/inet_ntop)
# on the socket module for platforms that lack them -- confirm in common.py.
common.patch_socket()
10 years ago
# rfc1035
# format
# +---------------------+
# |        Header       |
# +---------------------+
# |       Question      | the question for the name server
# +---------------------+
# |        Answer       | RRs answering the question
# +---------------------+
# |      Authority      | RRs pointing toward an authority
# +---------------------+
# |      Additional     | RRs holding additional information
# +---------------------+
#
# header
#                                 1  1  1  1  1  1
#   0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                      ID                       |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                    QDCOUNT                    |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                    ANCOUNT                    |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                    NSCOUNT                    |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                    ARCOUNT                    |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
10 years ago
# DNS record TYPE / QTYPE codes used by this module (RFC 1035, RFC 3596).
QTYPE_A = 1        # IPv4 host address
QTYPE_NS = 2       # name server
QTYPE_CNAME = 5    # canonical name (alias)
QTYPE_AAAA = 28    # IPv6 host address
QTYPE_ANY = 255    # wildcard query type

# DNS CLASS code for the Internet class.
QCLASS_IN = 1
10 years ago
def build_address(address):
    """Encode a dotted hostname as a DNS name (length-prefixed labels).

    Leading and trailing dots are stripped, then every label is emitted as
    one length character followed by the label text; a NUL character
    terminates the name.

    Returns the encoded string, or None if any label is longer than 63
    characters.
    """
    encoded = []
    for label in address.strip('.').split('.'):
        size = len(label)
        if size > 63:
            # A label length must fit in six bits.
            return None
        encoded.extend((chr(size), label))
    encoded.append('\0')
    return ''.join(encoded)
10 years ago
10 years ago
def build_request(address, qtype, request_id):
    """Build a DNS query packet containing a single question.

    address    -- hostname to look up (encoded with build_address)
    qtype      -- record type to ask for (one of the QTYPE_* values)
    request_id -- 16-bit identifier placed in the header ID field

    The header is packed with the first flag byte set to 1 (the RD bit in
    the header layout), QDCOUNT = 1 and every other count zero.
    """
    packet = struct.pack('!HBBHHHH', request_id, 1, 0, 1, 0, 0, 0)
    packet = packet + build_address(address)
    # The question section ends with QTYPE and QCLASS.
    return packet + struct.pack('!HH', qtype, QCLASS_IN)
10 years ago
10 years ago
def parse_ip(addrtype, data, length, offset):
    """Decode the RDATA of a resource record according to its type.

    addrtype -- record type taken from the record header
    data     -- the complete DNS packet
    length   -- RDLENGTH of the record
    offset   -- position of the RDATA inside data

    A and AAAA records are returned as textual IP addresses, CNAME and NS
    records as the (possibly compressed) domain name they contain, and any
    other type as the raw RDATA slice.
    """
    if addrtype in (QTYPE_CNAME, QTYPE_NS):
        # Names may use compression pointers, so parse against the packet.
        return parse_name(data, offset)[1]
    rdata = data[offset:offset + length]
    if addrtype == QTYPE_A:
        return socket.inet_ntop(socket.AF_INET, rdata)
    if addrtype == QTYPE_AAAA:
        return socket.inet_ntop(socket.AF_INET6, rdata)
    return rdata
10 years ago
10 years ago
def parse_name(data, offset):
    """Decode the domain name that starts at data[offset].

    Returns a tuple (consumed, name): the number of bytes the name occupies
    at this position and the dotted name itself.  A compression pointer
    (top two bits of the length byte set) is followed recursively and always
    ends the name.

    NOTE: pointer targets are not validated, so a pointer loop recurses until
    the interpreter's recursion limit raises.
    """
    pos = offset
    parts = []
    while True:
        size = ord(data[pos])
        if not size:
            # Zero-length label: end of name; count the terminator byte.
            return pos - offset + 1, '.'.join(parts)
        if (size & 0xC0) == 0xC0:
            # The low 14 bits of the two-byte field are the target offset.
            target = struct.unpack('!H', data[pos:pos + 2])[0] & 0x3FFF
            parts.append(parse_name(data, target)[1])
            pos += 2
            return pos - offset, '.'.join(parts)
        parts.append(data[pos + 1:pos + 1 + size])
        pos += 1 + size
10 years ago
# rfc1035
# record
#                                 1  1  1  1  1  1
#   0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                                               |
# /                                               /
# /                      NAME                     /
# |                                               |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                      TYPE                     |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                     CLASS                     |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                      TTL                      |
# |                                               |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                   RDLENGTH                    |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
# /                     RDATA                     /
# /                                               /
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
def parse_record(data, offset, question=False):
    """Parse one question entry or resource record starting at data[offset].

    Returns (consumed, record) where consumed is the number of bytes the
    entry occupies.

    For a question entry the record is
    (name, None, type, class, None, None) -- note it has six elements.
    For a resource record it is (name, rdata, type, class, ttl) with rdata
    decoded by parse_ip.
    """
    name_len, name = parse_name(data, offset)
    fixed = offset + name_len
    if question:
        # Questions carry only QTYPE and QCLASS after the name.
        rtype, rclass = struct.unpack('!HH', data[fixed:fixed + 4])
        return name_len + 4, (name, None, rtype, rclass, None, None)
    # TYPE, CLASS, TTL and RDLENGTH take 10 bytes after the name.
    rtype, rclass, ttl, rdlength = struct.unpack(
        '!HHiH', data[fixed:fixed + 10]
    )
    value = parse_ip(rtype, data, rdlength, fixed + 10)
    return name_len + 10 + rdlength, (name, value, rtype, rclass, ttl)
10 years ago
def parse_header(data):
    """Unpack the 12-byte DNS header at the start of data.

    Returns the tuple
    (id, qr, tc, ra, rcode, qdcount, ancount, nscount, arcount),
    or None when data is shorter than 12 bytes.  The qr, tc and ra values
    are the masked flag bits (0 or the bit's value), not booleans.
    """
    if len(data) < 12:
        return None
    (res_id, flags_hi, flags_lo, res_qdcount, res_ancount, res_nscount,
     res_arcount) = struct.unpack('!HBBHHHH', data[:12])
    res_qr = flags_hi & 0x80
    res_tc = flags_hi & 0x02
    res_ra = flags_lo & 0x80
    res_rcode = flags_lo & 0x0F
    return (res_id, res_qr, res_tc, res_ra, res_rcode, res_qdcount,
            res_ancount, res_nscount, res_arcount)
10 years ago
def parse_response(data):
    """Parse a raw DNS response packet into a DNSResponse.

    All four sections are walked; question and answer entries are kept,
    authority and additional records are parsed only to advance the offset.
    The response hostname is the name of the first question.

    Returns a DNSResponse, or None when the packet is shorter than a header
    or when parsing raises (the traceback is printed and the error logged).
    """
    try:
        if len(data) < 12:
            return None
        header = parse_header(data)
        if not header:
            return None
        res_id, res_qr, res_tc, res_ra, res_rcode, res_qdcount, \
            res_ancount, res_nscount, res_arcount = header

        questions = []
        answers = []
        # (entry count, is question section, list that collects the entries)
        sections = (
            (res_qdcount, True, questions),
            (res_ancount, False, answers),
            (res_nscount, False, None),
            (res_arcount, False, None),
        )
        offset = 12
        for count, is_question, sink in sections:
            for _ in range(count):
                consumed, record = parse_record(data, offset, is_question)
                offset += consumed
                if sink is not None and record:
                    sink.append(record)

        response = DNSResponse()
        if questions:
            response.hostname = questions[0][0]
        for entry in questions:
            response.questions.append((entry[1], entry[2], entry[3]))
        for entry in answers:
            response.answers.append((entry[1], entry[2], entry[3]))
        return response
    except Exception as e:
        import traceback
        traceback.print_exc()
        logging.error(e)
        return None
10 years ago
10 years ago
def is_ip(address):
    """Return the address family of a textual IP address, or False.

    The address is tried as IPv4 first, then IPv6; the result is
    socket.AF_INET or socket.AF_INET6 for a valid literal and False for
    anything inet_pton rejects.
    """
    for family in (socket.AF_INET, socket.AF_INET6):
        try:
            socket.inet_pton(family, address)
        except (TypeError, ValueError, OSError, IOError):
            continue
        return family
    return False
10 years ago
def is_valid_hostname(hostname):
    """Return True if hostname is a syntactically valid DNS hostname.

    A hostname is valid when it is at most 255 characters long and every
    dot-separated label matches VALID_HOSTNAME.  A single trailing dot is
    allowed.  Empty input is rejected instead of raising IndexError.
    """
    if not hostname or len(hostname) > 255:
        return False
    if hostname.endswith("."):
        hostname = hostname[:-1]
    return all(VALID_HOSTNAME.match(x) for x in hostname.split("."))
10 years ago
class DNSResponse(object):
    """Holds the parts of a parsed DNS response that the resolver uses."""

    def __init__(self):
        # Name taken from the first question, or None if there was none.
        self.hostname = None
        # Both lists hold (addr, type, class) tuples.
        self.questions = []
        self.answers = []

    def __str__(self):
        """Render as '<hostname>: <answers list>'."""
        return '%s: %s' % (self.hostname, self.answers)
# Per-hostname lookup stage tracked by DNSResolver: the A query is tried
# first (STATUS_IPV4), then the AAAA query (STATUS_IPV6).
STATUS_IPV4, STATUS_IPV6 = range(2)
class DNSResolver(object):
    """Asynchronous DNS resolver driven by the project's event loop.

    Hostnames are looked up, in order, in the hosts file, an in-memory cache
    and finally over UDP against every configured IPv4 name server.  An A
    query is sent first; if it yields no address an AAAA query follows.
    Results are delivered through callbacks of the form
    ``callback((hostname, ip), error)``.
    """

    def __init__(self):
        self._loop = None
        self._request_id = 1
        self._hosts = {}               # hostname -> ip from the hosts file
        self._hostname_status = {}     # hostname -> STATUS_IPV4 / STATUS_IPV6
        self._hostname_to_cb = {}      # hostname -> list of waiting callbacks
        self._cb_to_hostname = {}      # callback -> hostname it waits for
        self._cache = lru_cache.LRUCache(timeout=300)
        self._last_time = time.time()  # time of the last cache sweep
        self._sock = None
        self._servers = None
        self._parse_resolv()
        self._parse_hosts()
        # TODO monitor hosts change and reload hosts
        # TODO parse /etc/gai.conf and follow its rules

    def _parse_resolv(self):
        """Load IPv4 name servers from /etc/resolv.conf.

        Falls back to 8.8.4.4 and 8.8.8.8 when the file cannot be read or
        lists no usable IPv4 server.
        """
        self._servers = []
        try:
            with open('/etc/resolv.conf', 'rb') as f:
                for line in f:
                    line = line.strip()
                    if not line.startswith('nameserver'):
                        continue
                    parts = line.split()
                    if len(parts) >= 2:
                        server = parts[1]
                        # Only IPv4 servers are supported by the UDP socket.
                        if is_ip(server) == socket.AF_INET:
                            self._servers.append(server)
        except IOError:
            pass
        if not self._servers:
            self._servers = ['8.8.4.4', '8.8.8.8']

    def _parse_hosts(self):
        """Load static hostname -> ip mappings from the system hosts file.

        Uses %WINDIR%/system32/drivers/etc/hosts when WINDIR is set and
        /etc/hosts otherwise.  If the file cannot be read, only
        localhost -> 127.0.0.1 is registered.
        """
        etc_path = '/etc/hosts'
        if 'WINDIR' in os.environ:
            etc_path = os.environ['WINDIR'] + '/system32/drivers/etc/hosts'
        try:
            with open(etc_path, 'rb') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) < 2:
                        continue
                    ip = parts[0]
                    if not is_ip(ip):
                        continue
                    for hostname in parts[1:]:
                        if hostname.startswith('#'):
                            # The rest of the line is an inline comment,
                            # not a list of aliases.
                            break
                        self._hosts[hostname] = ip
        except IOError:
            self._hosts['localhost'] = '127.0.0.1'

    def add_to_loop(self, loop):
        """Create the UDP socket and register it and the handler with loop.

        Raises Exception if the resolver was already added to a loop.
        """
        if self._loop:
            raise Exception('already add to loop')
        self._loop = loop
        # TODO when dns server is IPv6
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
                                   socket.SOL_UDP)
        self._sock.setblocking(False)
        loop.add(self._sock, eventloop.POLL_IN)
        loop.add_handler(self.handle_events, ref=False)

    def _call_callback(self, hostname, ip, error=None):
        """Deliver a result to every callback waiting for hostname.

        Callbacks receive ((hostname, ip), error).  When neither ip nor error
        is given, an 'unknown hostname' exception is passed as the error.
        All bookkeeping for the hostname is cleared afterwards.
        """
        callbacks = self._hostname_to_cb.get(hostname, [])
        for callback in callbacks:
            if callback in self._cb_to_hostname:
                del self._cb_to_hostname[callback]
            if ip or error:
                callback((hostname, ip), error)
            else:
                callback((hostname, None),
                         Exception('unknown hostname %s' % hostname))
        if hostname in self._hostname_to_cb:
            del self._hostname_to_cb[hostname]
        if hostname in self._hostname_status:
            del self._hostname_status[hostname]

    def _handle_data(self, data):
        """Process one datagram received from a name server."""
        response = parse_response(data)
        if response and response.hostname:
            hostname = response.hostname
            ip = None
            for answer in response.answers:
                if answer[1] in (QTYPE_A, QTYPE_AAAA) and \
                        answer[2] == QCLASS_IN:
                    ip = answer[0]
                    break
            if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \
                    == STATUS_IPV4:
                # The A lookup produced nothing: fall back to AAAA.
                self._hostname_status[hostname] = STATUS_IPV6
                self._send_req(hostname, QTYPE_AAAA)
            elif ip:
                self._cache[hostname] = ip
                self._call_callback(hostname, ip)
            elif self._hostname_status.get(hostname, None) == STATUS_IPV6:
                # Requests go to every server, so empty A replies can still
                # arrive after the AAAA query was sent.  Only an empty reply
                # to the AAAA question means the lookup has really failed.
                for question in response.questions:
                    if question[1] == QTYPE_AAAA:
                        self._call_callback(hostname, None)
                        break

    def handle_events(self, events):
        """Event-loop handler: read DNS replies and sweep the cache.

        events is an iterable of (sock, fd, event) tuples.  Only the first
        event belonging to the resolver's own socket is handled per call.
        """
        for sock, fd, event in events:
            if sock != self._sock:
                continue
            if event & eventloop.POLL_ERR:
                logging.error('dns socket err')
                # Replace the broken socket with a fresh one.
                self._loop.remove(self._sock)
                self._sock.close()
                # TODO when dns server is IPv6
                self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
                                           socket.SOL_UDP)
                self._sock.setblocking(False)
                self._loop.add(self._sock, eventloop.POLL_IN)
            else:
                data, addr = sock.recvfrom(1024)
                if addr[0] not in self._servers:
                    # Ignore datagrams that do not come from our servers.
                    logging.warning('received a packet other than our dns')
                    break
                self._handle_data(data)
            break
        now = time.time()
        if now - self._last_time > CACHE_SWEEP_INTERVAL:
            self._cache.sweep()
            self._last_time = now

    def remove_callback(self, callback):
        """Stop delivering results to callback.

        When the callback was the last one waiting for its hostname, the
        pending state for that hostname is dropped as well.
        """
        hostname = self._cb_to_hostname.get(callback)
        if hostname:
            del self._cb_to_hostname[callback]
            arr = self._hostname_to_cb.get(hostname, None)
            if arr:
                arr.remove(callback)
                if not arr:
                    del self._hostname_to_cb[hostname]
                    if hostname in self._hostname_status:
                        del self._hostname_status[hostname]

    def _send_req(self, hostname, qtype):
        """Send a query of type qtype for hostname to every name server."""
        self._request_id += 1
        if self._request_id > 32768:
            self._request_id = 1
        req = build_request(hostname, qtype, self._request_id)
        for server in self._servers:
            logging.debug('resolving %s with type %d using server %s',
                          hostname, qtype, server)
            self._sock.sendto(req, (server, 53))

    def resolve(self, hostname, callback):
        """Resolve hostname and report the result through callback.

        callback is invoked as callback((hostname, ip), None) on success and
        as callback(None, exception) for an empty or invalid hostname.  IP
        literals, hosts-file entries and cached results are answered
        immediately; otherwise a query is sent and the callback is invoked
        once a reply has been processed.
        """
        if not hostname:
            callback(None, Exception('empty hostname'))
        elif is_ip(hostname):
            callback((hostname, hostname), None)
        elif hostname in self._hosts:
            logging.debug('hit hosts: %s', hostname)
            ip = self._hosts[hostname]
            callback((hostname, ip), None)
        elif hostname in self._cache:
            logging.debug('hit cache: %s', hostname)
            ip = self._cache[hostname]
            callback((hostname, ip), None)
        else:
            if not is_valid_hostname(hostname):
                callback(None, Exception('invalid hostname: %s' % hostname))
                return
            arr = self._hostname_to_cb.get(hostname, None)
            if not arr:
                self._hostname_status[hostname] = STATUS_IPV4
                self._send_req(hostname, QTYPE_A)
                self._hostname_to_cb[hostname] = [callback]
                self._cb_to_hostname[callback] = hostname
            else:
                arr.append(callback)
                # Register the reverse mapping for additional waiters too;
                # without it remove_callback() cannot cancel them.
                self._cb_to_hostname[callback] = hostname
                # TODO send again only if waited too long
                self._send_req(hostname, QTYPE_A)

    def close(self):
        """Close the UDP socket, if one is open."""
        if self._sock:
            self._sock.close()
            self._sock = None
10 years ago
10 years ago
def test():
    """Manual smoke test: resolve sample hostnames and print each result.

    Resets the root logger to DEBUG output, wires a DNSResolver into a new
    event loop, queues the lookups and then runs the loop.
    """
    logging.getLogger('').handlers = []
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)-8s %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S', filemode='a+')

    def _callback(address, error):
        # Same output as the statement form: error, a space, then address.
        print('%s %s' % (error, address))

    hostnames = (
        'www.google.com',
        '8.8.8.8',
        'localhost',
        'activate.adobe.com',
        'www.twitter.com',
        'ipv6.google.com',
        'ipv6.l.google.com',
        'www.gmail.com',
        'r4---sn-3qqp-ioql.googlevideo.com',
        'www.baidu.com',
        'www.a.shifen.com',
        'm.baidu.jp',
        'www.youku.com',
        'www.twitter.com',
        'ipv6.google.com',
    )

    loop = eventloop.EventLoop()
    resolver = DNSResolver()
    resolver.add_to_loop(loop)

    for hostname in hostnames:
        resolver.resolve(hostname, _callback)

    loop.run()
10 years ago
# Run the manual resolver smoke test when this file is executed as a script.
if __name__ == '__main__':
    test()