Python port of ShadowsocksR
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

492 lines
18 KiB

11 years ago
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2014 clowwindy
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import absolute_import, division, print_function, \
with_statement
11 years ago
import time
11 years ago
import os
11 years ago
import socket
import struct
11 years ago
import re
11 years ago
import logging
from shadowsocks import common, lru_cache, eventloop
11 years ago
# Minimum number of seconds between two sweeps of the resolver's result
# cache; DNSResolver.handle_events compares it against time.time() deltas.
CACHE_SWEEP_INTERVAL = 30
11 years ago
# Matches ONE hostname label (bytes): 1-63 characters drawn from letters,
# digits and '-', case-insensitive, neither starting nor ending with '-'.
# Use with .match() so the pattern is anchored at both ends.
VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
11 years ago
11 years ago
# Import-time side effect. NOTE(review): presumably installs fallbacks for
# socket functions missing on some platforms (e.g. inet_ntop/inet_pton) --
# confirm against shadowsocks.common.patch_socket.
common.patch_socket()
11 years ago
# rfc1035
# format
# +---------------------+
# |        Header       |
# +---------------------+
# |       Question      | the question for the name server
# +---------------------+
# |        Answer       | RRs answering the question
# +---------------------+
# |      Authority      | RRs pointing toward an authority
# +---------------------+
# |      Additional     | RRs holding additional information
# +---------------------+
#
# header
#                                 1  1  1  1  1  1
#   0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                      ID                       |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                    QDCOUNT                    |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                    ANCOUNT                    |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                    NSCOUNT                    |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                    ARCOUNT                    |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
11 years ago
# DNS query/record TYPE codes used by this module (wire values).
QTYPE_ANY = 255  # "*": request for all records
QTYPE_A = 1  # IPv4 host address
QTYPE_AAAA = 28  # IPv6 host address
11 years ago
# Record types whose RDATA is itself a domain name (see parse_ip).
QTYPE_CNAME = 5  # canonical name for an alias
QTYPE_NS = 2  # authoritative name server
11 years ago
# DNS CLASS code for the Internet; the only class this module queries for.
QCLASS_IN = 1
11 years ago
def build_address(address):
    """Encode a dotted hostname as a DNS wire-format name.

    Leading and trailing dots are stripped, then every label is emitted as a
    one-byte length followed by the label bytes, and the name is closed with
    a zero byte.

    :param address: hostname as bytes, e.g. ``b'www.example.com'``
    :return: the encoded name as bytes, or ``None`` if any label is longer
        than 63 bytes (the maximum a length byte may describe)
    """
    encoded = []
    for part in address.strip(b'.').split(b'.'):
        size = len(part)
        if size > 63:
            return None
        encoded += [common.chr(size), part]
    # zero-length label terminates the name
    encoded.append(b'\0')
    return b''.join(encoded)
11 years ago
11 years ago
10 years ago
def build_request(address, qtype):
    """Build a single-question DNS query packet.

    :param address: hostname to look up, as bytes
    :param qtype: numeric query type (e.g. ``QTYPE_A``)
    :return: the complete request as bytes: a random 2-byte ID, the rest of
        the header, the encoded name, then QTYPE and QCLASS (always IN)
    :raises TypeError: if ``build_address`` rejects the name (returns None)
    """
    transaction_id = os.urandom(2)
    # Flag bytes 0x01 0x00 set only RD (recursion desired);
    # counts: QDCOUNT=1, ANCOUNT=NSCOUNT=ARCOUNT=0.
    flags_and_counts = struct.pack('!BBHHHH', 1, 0, 1, 0, 0, 0)
    qname = build_address(address)
    question_tail = struct.pack('!HH', qtype, QCLASS_IN)
    return b''.join((transaction_id, flags_and_counts)) + qname + \
        question_tail
11 years ago
11 years ago
def parse_ip(addrtype, data, length, offset):
    """Decode the RDATA of one resource record.

    :param addrtype: record TYPE code
    :param data: the whole DNS message (needed to follow compressed names)
    :param length: RDLENGTH of the record
    :param offset: index in ``data`` where the RDATA starts
    :return: a textual IP address for A/AAAA records, the decoded domain
        name (bytes) for CNAME/NS records, otherwise the raw RDATA bytes
    """
    rdata = data[offset:offset + length]
    if addrtype == QTYPE_A:
        return socket.inet_ntop(socket.AF_INET, rdata)
    if addrtype == QTYPE_AAAA:
        return socket.inet_ntop(socket.AF_INET6, rdata)
    if addrtype in (QTYPE_CNAME, QTYPE_NS):
        # names may use compression pointers, so parse against the message
        return parse_name(data, offset)[1]
    return rdata
11 years ago
11 years ago
def parse_name(data, offset):
    """Decode a (possibly compressed) domain name from a DNS message.

    :param data: the whole DNS message as bytes
    :param offset: index in ``data`` where the name starts
    :return: ``(consumed, name)`` where ``consumed`` is the number of bytes
        the name occupies at ``offset`` (a compression pointer counts as 2)
        and ``name`` is the dotted name as bytes

    NOTE(review): compression pointers are followed recursively with no
    loop detection, so a message whose pointers form a cycle ends in a
    recursion error raised from here.
    """
    parts = []
    pos = offset
    length = common.ord(data[pos])
    while length > 0:
        if (length & 0xC0) == 0xC0:
            # two high bits set: 14-bit pointer to the rest of the name
            target = struct.unpack('!H', data[pos:pos + 2])[0] & 0x3FFF
            parts.append(parse_name(data, target)[1])
            # a pointer always ends the name
            return pos + 2 - offset, b'.'.join(parts)
        parts.append(data[pos + 1:pos + 1 + length])
        pos += 1 + length
        length = common.ord(data[pos])
    # +1 accounts for the terminating zero-length label
    return pos - offset + 1, b'.'.join(parts)
11 years ago
# rfc1035
# record
#                                 1  1  1  1  1  1
#   0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                                               |
# /                                               /
# /                      NAME                     /
# |                                               |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                      TYPE                     |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                     CLASS                     |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                      TTL                      |
# |                                               |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |                   RDLENGTH                    |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
# /                     RDATA                     /
# /                                               /
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
def parse_record(data, offset, question=False):
    """Parse one question entry or resource record.

    :param data: the whole DNS message as bytes
    :param offset: index in ``data`` where the entry starts
    :param question: true to parse a question-section entry (name, type,
        class only) instead of a full resource record
    :return: ``(consumed, fields)``. For a resource record ``fields`` is
        ``(name, rdata, type, class, ttl)`` with ``rdata`` decoded by
        ``parse_ip``; for a question it is
        ``(name, None, type, class, None, None)``.
    """
    nlen, name = parse_name(data, offset)
    fixed = offset + nlen  # start of the fixed-size fields after the name
    if question:
        record_type, record_class = struct.unpack(
            '!HH', data[fixed:fixed + 4])
        return nlen + 4, (name, None, record_type, record_class, None, None)

    record_type, record_class, record_ttl, record_rdlength = struct.unpack(
        '!HHiH', data[fixed:fixed + 10])
    ip = parse_ip(record_type, data, record_rdlength, fixed + 10)
    consumed = nlen + 10 + record_rdlength
    return consumed, (name, ip, record_type, record_class, record_ttl)
11 years ago
def parse_header(data):
    """Decode the fixed 12-byte DNS message header.

    :param data: the DNS message as bytes
    :return: ``None`` if ``data`` is shorter than 12 bytes, otherwise the
        tuple ``(id, qr, tc, ra, rcode, qdcount, ancount, nscount,
        arcount)``. ``qr``, ``tc`` and ``ra`` are the masked flag bits
        (0 or the bit's value, i.e. 128, 2 and 128), not booleans.
    """
    if len(data) < 12:
        return None
    (res_id, flags_hi, flags_lo,
     qdcount, ancount, nscount, arcount) = struct.unpack('!HBBHHHH',
                                                         data[:12])
    return (
        res_id,
        flags_hi & 128,  # QR: response flag
        flags_hi & 2,    # TC: truncated
        flags_lo & 128,  # RA: recursion available
        flags_lo & 15,   # RCODE
        qdcount,
        ancount,
        nscount,
        arcount,
    )
11 years ago
def parse_response(data):
    """Parse a raw DNS message into a ``DNSResponse``.

    All four record sections are walked; only the question and answer
    sections are kept. The hostname is taken from the first question.

    :param data: the DNS message as bytes
    :return: a ``DNSResponse``, or ``None`` when the message is shorter
        than a header or any parsing step raises (the exception is printed
        via ``traceback`` and logged, never propagated)
    """
    try:
        if len(data) >= 12:
            header = parse_header(data)
            if not header:
                return None
            qdcount, ancount, nscount, arcount = header[5:9]

            questions = []
            answers = []
            cursor = 12  # first byte after the fixed header
            for _ in range(qdcount):
                used, record = parse_record(data, cursor, True)
                cursor += used
                if record:
                    questions.append(record)
            for _ in range(ancount):
                used, record = parse_record(data, cursor)
                cursor += used
                if record:
                    answers.append(record)
            # Authority and additional records are parsed (so malformed
            # data still raises) but their contents are discarded.
            for _ in range(nscount):
                used, record = parse_record(data, cursor)
                cursor += used
            for _ in range(arcount):
                used, record = parse_record(data, cursor)
                cursor += used

            response = DNSResponse()
            if questions:
                response.hostname = questions[0][0]
            response.questions.extend(
                (entry[1], entry[2], entry[3]) for entry in questions)
            response.answers.extend(
                (entry[1], entry[2], entry[3]) for entry in answers)
            return response
    except Exception as e:
        import traceback
        traceback.print_exc()
        logging.error(e)
        return None
11 years ago
11 years ago
def is_valid_hostname(hostname):
    """Check that a hostname is syntactically valid for a DNS query.

    :param hostname: hostname as bytes; one trailing dot (fully qualified
        form) is allowed and ignored
    :return: ``True`` if the name is at most 255 bytes long and every
        dot-separated label matches ``VALID_HOSTNAME``, else ``False``
        (including for an empty hostname)
    """
    if len(hostname) > 255:
        return False
    # Use endswith() rather than hostname[-1] == b'.': indexing bytes on
    # Python 3 yields an int, so that comparison was always False there,
    # and it raised IndexError for an empty hostname.
    if hostname.endswith(b'.'):
        hostname = hostname[:-1]
    return all(VALID_HOSTNAME.match(x) for x in hostname.split(b'.'))
11 years ago
11 years ago
class DNSResponse(object):
    """Parsed DNS reply: the queried hostname plus its record lists."""

    def __init__(self):
        # Entries of both lists are (addr, type, class) tuples.
        self.questions = []
        self.answers = []
        # Name taken from the first question; None until it is filled in.
        self.hostname = None

    def __str__(self):
        fields = (self.hostname, str(self.answers))
        return '%s: %s' % fields
# Per-hostname lookup stage kept in DNSResolver._hostname_status:
# STATUS_IPV4 while the initial A query is outstanding, STATUS_IPV6 once the
# resolver has fallen back to an AAAA query.
STATUS_IPV4 = 0
STATUS_IPV6 = 1
class DNSResolver(object):
    """Callback-based DNS resolver driven by an event loop.

    Lookups are answered, in order, from: literal IP addresses, the hosts
    file, an LRU cache of earlier answers, and finally UDP queries sent to
    every configured name server (A first, then AAAA if A gave no address).
    Callbacks are invoked as ``callback(result, error)`` where ``result`` is
    ``(hostname, ip)`` or ``None``.
    """

    def __init__(self):
        self._loop = None
        # hostname (bytes when read from the hosts file) -> ip
        self._hosts = {}
        # hostname -> STATUS_IPV4 / STATUS_IPV6 for in-flight lookups
        self._hostname_status = {}
        # hostname -> list of callbacks waiting for it
        self._hostname_to_cb = {}
        # callback -> hostname it waits for (reverse index for removal)
        self._cb_to_hostname = {}
        self._cache = lru_cache.LRUCache(timeout=300)
        self._last_time = time.time()
        self._sock = None
        self._servers = None
        self._parse_resolv()
        self._parse_hosts()
        # TODO monitor hosts change and reload hosts
        # TODO parse /etc/gai.conf and follow its rules

    def _parse_resolv(self):
        """Load IPv4 name servers from /etc/resolv.conf.

        Falls back to 8.8.4.4 and 8.8.8.8 when the file is unreadable or
        lists no usable IPv4 server.
        """
        self._servers = []
        try:
            with open('/etc/resolv.conf', 'rb') as f:
                for line in f:
                    line = line.strip()
                    if not line.startswith(b'nameserver'):
                        continue
                    parts = line.split()
                    if len(parts) < 2:
                        continue
                    server = parts[1]
                    if common.is_ip(server) == socket.AF_INET:
                        # Stored as native str: handle_events compares
                        # these against the address from recvfrom().
                        if not isinstance(server, str):
                            server = server.decode('utf8')
                        self._servers.append(server)
        except IOError:
            pass
        if not self._servers:
            self._servers = ['8.8.4.4', '8.8.8.8']

    def _parse_hosts(self):
        """Load static hostname -> ip mappings from the system hosts file.

        Uses ``%WINDIR%/system32/drivers/etc/hosts`` when ``WINDIR`` is set,
        otherwise ``/etc/hosts``. If the file cannot be read, only
        ``localhost`` is registered.
        """
        etc_path = '/etc/hosts'
        if 'WINDIR' in os.environ:
            etc_path = os.environ['WINDIR'] + '/system32/drivers/etc/hosts'
        try:
            with open(etc_path, 'rb') as f:
                for line in f:
                    # Everything after '#' is a comment in hosts files;
                    # without this, comment words became hostnames.
                    parts = line.split(b'#', 1)[0].split()
                    if len(parts) >= 2:
                        ip = parts[0]
                        if common.is_ip(ip):
                            for hostname in parts[1:]:
                                self._hosts[hostname] = ip
        except IOError:
            # Key must be bytes: resolve() encodes hostnames to bytes
            # before the lookup, so a str key never matched on Python 3.
            self._hosts[b'localhost'] = '127.0.0.1'

    def add_to_loop(self, loop, ref=False):
        """Create the UDP socket and register it and the handler on ``loop``.

        :param loop: event loop providing ``add`` and ``add_handler``
        :param ref: forwarded to ``loop.add_handler``
        :raises Exception: if the resolver is already attached to a loop
        """
        if self._loop:
            raise Exception('already add to loop')
        self._loop = loop
        # TODO when dns server is IPv6
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
                                   socket.SOL_UDP)
        self._sock.setblocking(False)
        loop.add(self._sock, eventloop.POLL_IN)
        loop.add_handler(self.handle_events, ref=ref)

    def _call_callback(self, hostname, ip, error=None):
        """Deliver a result to every callback waiting on ``hostname``.

        Callbacks receive ``((hostname, ip), error)``; when neither ``ip``
        nor ``error`` is set they receive an 'unknown hostname' exception.
        All bookkeeping for the hostname is dropped afterwards.
        """
        callbacks = self._hostname_to_cb.get(hostname, [])
        for callback in callbacks:
            self._cb_to_hostname.pop(callback, None)
            if ip or error:
                callback((hostname, ip), error)
            else:
                callback((hostname, None),
                         Exception('unknown hostname %s' % hostname))
        self._hostname_to_cb.pop(hostname, None)
        self._hostname_status.pop(hostname, None)

    def _handle_data(self, data):
        """Process one datagram received from a name server."""
        response = parse_response(data)
        if response and response.hostname:
            hostname = response.hostname
            ip = None
            # first A/AAAA answer of class IN wins
            for answer in response.answers:
                if answer[1] in (QTYPE_A, QTYPE_AAAA) and \
                        answer[2] == QCLASS_IN:
                    ip = answer[0]
                    break
            if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \
                    == STATUS_IPV4:
                # the A query yielded nothing: retry with AAAA
                self._hostname_status[hostname] = STATUS_IPV6
                self._send_req(hostname, QTYPE_AAAA)
            else:
                if ip:
                    self._cache[hostname] = ip
                    self._call_callback(hostname, ip)
                elif self._hostname_status.get(hostname, None) == STATUS_IPV6:
                    # Fail only on the reply to the AAAA query itself, not
                    # on a late or duplicate reply to the earlier A query.
                    for question in response.questions:
                        if question[1] == QTYPE_AAAA:
                            self._call_callback(hostname, None)
                            break

    def handle_events(self, events):
        """Event-loop handler: read DNS replies and sweep the cache.

        :param events: iterable of ``(sock, fd, event)`` tuples; entries for
            sockets other than the resolver's own are ignored
        """
        for sock, fd, event in events:
            if sock != self._sock:
                continue
            if event & eventloop.POLL_ERR:
                logging.error('dns socket err')
                self._loop.remove(self._sock)
                self._sock.close()
                # TODO when dns server is IPv6
                self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
                                           socket.SOL_UDP)
                self._sock.setblocking(False)
                self._loop.add(self._sock, eventloop.POLL_IN)
            else:
                data, addr = sock.recvfrom(1024)
                if addr[0] not in self._servers:
                    # logging.warn is a deprecated alias of warning
                    logging.warning('received a packet other than our dns')
                    break
                self._handle_data(data)
            break
        now = time.time()
        if now - self._last_time > CACHE_SWEEP_INTERVAL:
            self._cache.sweep()
            self._last_time = now

    def remove_callback(self, callback):
        """Cancel a pending callback; unknown callbacks are ignored."""
        hostname = self._cb_to_hostname.get(callback)
        if hostname:
            del self._cb_to_hostname[callback]
            arr = self._hostname_to_cb.get(hostname, None)
            if arr:
                arr.remove(callback)
                if not arr:
                    # last waiter gone: forget the lookup entirely
                    del self._hostname_to_cb[hostname]
                    self._hostname_status.pop(hostname, None)

    def _send_req(self, hostname, qtype):
        """Send one query for ``hostname`` to every configured server."""
        req = build_request(hostname, qtype)
        for server in self._servers:
            logging.debug('resolving %s with type %d using server %s',
                          hostname, qtype, server)
            self._sock.sendto(req, (server, 53))

    def resolve(self, hostname, callback):
        """Resolve ``hostname`` and report through ``callback``.

        :param hostname: str or bytes; str is encoded as UTF-8
        :param callback: called as ``callback((hostname, ip), None)`` on
            success or ``callback(None, exception)`` for an empty or
            invalid hostname. It is called immediately for IP literals,
            hosts-file and cache hits, otherwise when a reply arrives.
        """
        if not isinstance(hostname, bytes):
            hostname = hostname.encode('utf8')
        if not hostname:
            callback(None, Exception('empty hostname'))
        elif common.is_ip(hostname):
            callback((hostname, hostname), None)
        elif hostname in self._hosts:
            logging.debug('hit hosts: %s', hostname)
            ip = self._hosts[hostname]
            callback((hostname, ip), None)
        elif hostname in self._cache:
            logging.debug('hit cache: %s', hostname)
            ip = self._cache[hostname]
            callback((hostname, ip), None)
        else:
            if not is_valid_hostname(hostname):
                callback(None, Exception('invalid hostname: %s' % hostname))
                return
            arr = self._hostname_to_cb.get(hostname, None)
            if not arr:
                self._hostname_status[hostname] = STATUS_IPV4
                self._send_req(hostname, QTYPE_A)
                self._hostname_to_cb[hostname] = [callback]
                self._cb_to_hostname[callback] = hostname
            else:
                arr.append(callback)
                # Record the reverse mapping here too; without it
                # remove_callback() could not cancel callbacks that joined
                # an already in-flight lookup.
                self._cb_to_hostname[callback] = hostname
                # TODO send again only if waited too long
                self._send_req(hostname, QTYPE_A)

    def close(self):
        """Close the UDP socket, if one is open."""
        if self._sock:
            self._sock.close()
            self._sock = None
10 years ago
def test():
    """Manual smoke test: resolve a fixed set of names and print results.

    Needs network access. Runs the event loop until all nine callbacks have
    fired, then detaches and closes the resolver. Valid, invalid and
    over-long hostnames are all exercised; nothing is asserted about the
    answers themselves.
    """
    global counter
    resolver = DNSResolver()
    loop = eventloop.EventLoop()
    resolver.add_to_loop(loop, ref=True)
    counter = 0

    def make_callback():
        def callback(result, error):
            global counter
            # TODO: what can we assert?
            print(result, error)
            counter += 1
            if counter == 9:
                loop.remove_handler(resolver.handle_events)
                resolver.close()
        return callback

    # every call must yield a distinct function object
    assert(make_callback() != make_callback())

    hostnames = [
        b'google.com',
        'google.com',
        'example.com',
        'ipv6.google.com',
        'www.facebook.com',
        'ns2.google.com',
        'invalid.@!#$%^&$@.hostname',
        'toooooooooooooooooooooooooooooooooooooooooooooooooo'
        'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
        'long.hostname',
        'toooooooooooooooooooooooooooooooooooooooooooooooooo'
        'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
        'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
        'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
        'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
        'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
        'long.hostname',
    ]
    for name in hostnames:
        resolver.resolve(name, make_callback())
    loop.run()
# Run the network smoke test when this module is executed as a script.
if __name__ == '__main__':
    test()