Browse Source

add async dns to tcp relay

auth
clowwindy 11 years ago
parent
commit
bcdc1e9671
  1. 9
      shadowsocks/asyncdns.py
  2. 58
      shadowsocks/tcprelay.py

9
shadowsocks/asyncdns.py

@ -173,8 +173,8 @@ def parse_response(data):
res_tc = header[1] & 2 res_tc = header[1] & 2
# res_ra = header[2] & 128 # res_ra = header[2] & 128
res_rcode = header[2] & 15 res_rcode = header[2] & 15
assert res_tc == 0 # assert res_tc == 0
assert res_rcode in [0, 3] # assert res_rcode in [0, 3]
res_qdcount = header[3] res_qdcount = header[3]
res_ancount = header[4] res_ancount = header[4]
res_nscount = header[5] res_nscount = header[5]
@ -308,7 +308,11 @@ class DNSResolver(object):
for callback in callbacks: for callback in callbacks:
if self._cb_to_hostname.__contains__(callback): if self._cb_to_hostname.__contains__(callback):
del self._cb_to_hostname[callback] del self._cb_to_hostname[callback]
if ip:
callback((hostname, ip), None) callback((hostname, ip), None)
else:
callback((hostname, None),
Exception('unknown hostname %s' % hostname))
if self._hostname_to_cb.__contains__(hostname): if self._hostname_to_cb.__contains__(hostname):
del self._hostname_to_cb[hostname] del self._hostname_to_cb[hostname]
if self._hostname_status.__contains__(hostname): if self._hostname_status.__contains__(hostname):
@ -329,6 +333,7 @@ class DNSResolver(object):
self._hostname_status[hostname] = STATUS_IPV6 self._hostname_status[hostname] = STATUS_IPV6
self._send_req(hostname, QTYPE_AAAA) self._send_req(hostname, QTYPE_AAAA)
else: else:
if ip:
self._cache[hostname] = ip self._cache[hostname] = ip
self._call_callback(hostname, ip) self._call_callback(hostname, ip)

58
shadowsocks/tcprelay.py

@ -56,6 +56,7 @@ CMD_UDP_ASSOCIATE = 3
STAGE_INIT = 0 STAGE_INIT = 0
STAGE_HELLO = 1 STAGE_HELLO = 1
STAGE_UDP_ASSOC = 2 STAGE_UDP_ASSOC = 2
STAGE_DNS = 3
STAGE_REPLY = 4 STAGE_REPLY = 4
STAGE_STREAM = 5 STAGE_STREAM = 5
STAGE_DESTROYED = -1 STAGE_DESTROYED = -1
@ -75,13 +76,14 @@ BUF_SIZE = 8 * 1024
class TCPRelayHandler(object): class TCPRelayHandler(object):
def __init__(self, server, fd_to_handlers, loop, local_sock, config, def __init__(self, server, fd_to_handlers, loop, local_sock, config,
is_local): dns_resolver, is_local):
self._server = server self._server = server
self._fd_to_handlers = fd_to_handlers self._fd_to_handlers = fd_to_handlers
self._loop = loop self._loop = loop
self._local_sock = local_sock self._local_sock = local_sock
self._remote_sock = None self._remote_sock = None
self._config = config self._config = config
self._dns_resolver = dns_resolver
self._is_local = is_local self._is_local = is_local
self._stage = STAGE_INIT self._stage = STAGE_INIT
self._encryptor = encrypt.Encryptor(config['password'], self._encryptor = encrypt.Encryptor(config['password'],
@ -239,51 +241,81 @@ class TCPRelayHandler(object):
addrtype, remote_addr, remote_port, header_length = header_result addrtype, remote_addr, remote_port, header_length = header_result
logging.info('connecting %s:%d' % (remote_addr, remote_port)) logging.info('connecting %s:%d' % (remote_addr, remote_port))
self._remote_address = (remote_addr, remote_port) self._remote_address = (remote_addr, remote_port)
# pause reading
self._update_stream(STREAM_UP, WAIT_STATUS_WRITING)
self._stage = STAGE_DNS
if self._is_local: if self._is_local:
# forward address to remote # forward address to remote
self._write_to_sock('\x05\x00\x00\x01\x00\x00\x00\x00\x10\x10', self._write_to_sock('\x05\x00\x00\x01\x00\x00\x00\x00\x10\x10',
self._local_sock) self._local_sock)
data_to_send = self._encryptor.encrypt(data) data_to_send = self._encryptor.encrypt(data)
self._data_to_write_to_remote.append(data_to_send) self._data_to_write_to_remote.append(data_to_send)
remote_addr = self._config['server'] # notice here may go into _handle_dns_resolved directly
remote_port = self._config['server_port'] self._dns_resolver.resolve(self._config['server'],
self._handle_dns_resolved)
else: else:
if len(data) > header_length: if len(data) > header_length:
self._data_to_write_to_remote.append(data[header_length:]) self._data_to_write_to_remote.append(data[header_length:])
# notice here may go into _handle_dns_resolved directly
self._dns_resolver.resolve(remote_addr,
self._handle_dns_resolved)
except Exception as e:
logging.error(e)
traceback.print_exc()
# TODO use logging when debug completed
self.destroy()
# TODO async DNS def _handle_dns_resolved(self, result, error):
addrs = socket.getaddrinfo(remote_addr, remote_port, 0, if error:
socket.SOCK_STREAM, socket.SOL_TCP) logging.error(error)
self.destroy()
return
if result:
ip = result[1]
if ip:
try:
self._stage = STAGE_REPLY
remote_addr = self._remote_address[0]
remote_port = self._remote_address[1]
if self._is_local:
remote_addr = self._config['server']
remote_port = self._config['server_port']
addrs = socket.getaddrinfo(ip, remote_port, 0,
socket.SOCK_STREAM,
socket.SOL_TCP)
if len(addrs) == 0: if len(addrs) == 0:
raise Exception("can't get addrinfo for %s:%d" % raise Exception("getaddrinfo failed for %s:%d" %
(remote_addr, remote_port)) (remote_addr, remote_port))
af, socktype, proto, canonname, sa = addrs[0] af, socktype, proto, canonname, sa = addrs[0]
remote_sock = socket.socket(af, socktype, proto) remote_sock = socket.socket(af, socktype, proto)
self._remote_sock = remote_sock self._remote_sock = remote_sock
self._fd_to_handlers[remote_sock.fileno()] = self self._fd_to_handlers[remote_sock.fileno()] = self
remote_sock.setblocking(False) remote_sock.setblocking(False)
remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY,
1)
if self._is_local and self._config['fast_open']: if self._is_local and self._config['fast_open']:
# wait for more data to arrive and send them in one SYN # wait for more data to arrive and send them in one SYN
self._stage = STAGE_REPLY self._stage = STAGE_REPLY
self._loop.add(remote_sock, eventloop.POLL_ERR) self._loop.add(remote_sock, eventloop.POLL_ERR)
self._update_stream(STREAM_UP, WAIT_STATUS_READING)
# TODO when there is already data in this packet # TODO when there is already data in this packet
else: else:
try: try:
remote_sock.connect(sa) remote_sock.connect(sa)
except (OSError, IOError) as e: except (OSError, IOError) as e:
if eventloop.errno_from_exception(e) == errno.EINPROGRESS: if eventloop.errno_from_exception(e) == \
errno.EINPROGRESS:
pass pass
self._loop.add(remote_sock, self._loop.add(remote_sock,
eventloop.POLL_ERR | eventloop.POLL_OUT) eventloop.POLL_ERR | eventloop.POLL_OUT)
self._stage = STAGE_REPLY self._stage = STAGE_REPLY
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
except Exception as e: return
except (OSError, IOError) as e:
logging.error(e) logging.error(e)
traceback.print_exc() traceback.print_exc()
# TODO use logging when debug completed
self.destroy() self.destroy()
def _on_local_read(self): def _on_local_read(self):
@ -422,6 +454,7 @@ class TCPRelayHandler(object):
del self._fd_to_handlers[self._local_sock.fileno()] del self._fd_to_handlers[self._local_sock.fileno()]
self._local_sock.close() self._local_sock.close()
self._local_sock = None self._local_sock = None
self._dns_resolver.remove_callback(self._handle_dns_resolved)
self._server.remove_handler(self) self._server.remove_handler(self)
@ -545,7 +578,8 @@ class TCPRelay(object):
# logging.debug('accept') # logging.debug('accept')
conn = self._server_socket.accept() conn = self._server_socket.accept()
TCPRelayHandler(self, self._fd_to_handlers, self._eventloop, TCPRelayHandler(self, self._fd_to_handlers, self._eventloop,
conn[0], self._config, self._is_local) conn[0], self._config, self._dns_resolver,
self._is_local)
except (OSError, IOError) as e: except (OSError, IOError) as e:
error_no = eventloop.errno_from_exception(e) error_no = eventloop.errno_from_exception(e)
if error_no in (errno.EAGAIN, errno.EINPROGRESS): if error_no in (errno.EAGAIN, errno.EINPROGRESS):

Loading…
Cancel
Save