diff --git a/shadowsocks/asyncdns.py b/shadowsocks/asyncdns.py index 5f9f91b..20d6ff0 100644 --- a/shadowsocks/asyncdns.py +++ b/shadowsocks/asyncdns.py @@ -173,8 +173,8 @@ def parse_response(data): res_tc = header[1] & 2 # res_ra = header[2] & 128 res_rcode = header[2] & 15 - assert res_tc == 0 - assert res_rcode in [0, 3] + # assert res_tc == 0 + # assert res_rcode in [0, 3] res_qdcount = header[3] res_ancount = header[4] res_nscount = header[5] @@ -308,7 +308,11 @@ class DNSResolver(object): for callback in callbacks: if self._cb_to_hostname.__contains__(callback): del self._cb_to_hostname[callback] - callback((hostname, ip), None) + if ip: + callback((hostname, ip), None) + else: + callback((hostname, None), + Exception('unknown hostname %s' % hostname)) if self._hostname_to_cb.__contains__(hostname): del self._hostname_to_cb[hostname] if self._hostname_status.__contains__(hostname): @@ -329,7 +333,8 @@ class DNSResolver(object): self._hostname_status[hostname] = STATUS_IPV6 self._send_req(hostname, QTYPE_AAAA) else: - self._cache[hostname] = ip + if ip: + self._cache[hostname] = ip self._call_callback(hostname, ip) def handle_events(self, events): diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 381a809..08fb2f5 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -56,6 +56,7 @@ CMD_UDP_ASSOCIATE = 3 STAGE_INIT = 0 STAGE_HELLO = 1 STAGE_UDP_ASSOC = 2 +STAGE_DNS = 3 STAGE_REPLY = 4 STAGE_STREAM = 5 STAGE_DESTROYED = -1 @@ -75,13 +76,14 @@ BUF_SIZE = 8 * 1024 class TCPRelayHandler(object): def __init__(self, server, fd_to_handlers, loop, local_sock, config, - is_local): + dns_resolver, is_local): self._server = server self._fd_to_handlers = fd_to_handlers self._loop = loop self._local_sock = local_sock self._remote_sock = None self._config = config + self._dns_resolver = dns_resolver self._is_local = is_local self._stage = STAGE_INIT self._encryptor = encrypt.Encryptor(config['password'], @@ -239,53 +241,83 @@ class TCPRelayHandler(object): addrtype, remote_addr, remote_port, header_length = header_result logging.info('connecting %s:%d' % (remote_addr, remote_port)) self._remote_address = (remote_addr, remote_port) + # pause reading + self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) + self._stage = STAGE_DNS if self._is_local: # forward address to remote self._write_to_sock('\x05\x00\x00\x01\x00\x00\x00\x00\x10\x10', self._local_sock) data_to_send = self._encryptor.encrypt(data) self._data_to_write_to_remote.append(data_to_send) - remote_addr = self._config['server'] - remote_port = self._config['server_port'] + # notice here may go into _handle_dns_resolved directly + self._dns_resolver.resolve(self._config['server'], + self._handle_dns_resolved) else: if len(data) > header_length: self._data_to_write_to_remote.append(data[header_length:]) - - # TODO async DNS - addrs = socket.getaddrinfo(remote_addr, remote_port, 0, - socket.SOCK_STREAM, socket.SOL_TCP) - if len(addrs) == 0: - raise Exception("can't get addrinfo for %s:%d" % - (remote_addr, remote_port)) - af, socktype, proto, canonname, sa = addrs[0] - remote_sock = socket.socket(af, socktype, proto) - self._remote_sock = remote_sock - self._fd_to_handlers[remote_sock.fileno()] = self - remote_sock.setblocking(False) - remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - - if self._is_local and self._config['fast_open']: - # wait for more data to arrive and send them in one SYN - self._stage = STAGE_REPLY - self._loop.add(remote_sock, eventloop.POLL_ERR) - # TODO when there is already data in this packet - else: - try: - remote_sock.connect(sa) - except (OSError, IOError) as e: - if eventloop.errno_from_exception(e) == errno.EINPROGRESS: - pass - self._loop.add(remote_sock, - eventloop.POLL_ERR | eventloop.POLL_OUT) - self._stage = STAGE_REPLY - self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) - self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + # notice here may go into _handle_dns_resolved directly + self._dns_resolver.resolve(remote_addr, + self._handle_dns_resolved) except Exception as e: logging.error(e) traceback.print_exc() # TODO use logging when debug completed self.destroy() + def _handle_dns_resolved(self, result, error): + if error: + logging.error(error) + self.destroy() + return + if result: + ip = result[1] + if ip: + try: + self._stage = STAGE_REPLY + remote_addr = self._remote_address[0] + remote_port = self._remote_address[1] + if self._is_local: + remote_addr = self._config['server'] + remote_port = self._config['server_port'] + addrs = socket.getaddrinfo(ip, remote_port, 0, + socket.SOCK_STREAM, + socket.SOL_TCP) + if len(addrs) == 0: + raise Exception("getaddrinfo failed for %s:%d" % + (remote_addr, remote_port)) + af, socktype, proto, canonname, sa = addrs[0] + remote_sock = socket.socket(af, socktype, proto) + self._remote_sock = remote_sock + self._fd_to_handlers[remote_sock.fileno()] = self + remote_sock.setblocking(False) + remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, + 1) + + if self._is_local and self._config['fast_open']: + # wait for more data to arrive and send them in one SYN + self._stage = STAGE_REPLY + self._loop.add(remote_sock, eventloop.POLL_ERR) + self._update_stream(STREAM_UP, WAIT_STATUS_READING) + # TODO when there is already data in this packet + else: + try: + remote_sock.connect(sa) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) == \ + errno.EINPROGRESS: + pass + self._loop.add(remote_sock, + eventloop.POLL_ERR | eventloop.POLL_OUT) + self._stage = STAGE_REPLY + self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + return + except (OSError, IOError) as e: + logging.error(e) + traceback.print_exc() + self.destroy() + def _on_local_read(self): self._update_activity() if not self._local_sock: @@ -422,6 +454,7 @@ class TCPRelayHandler(object): del self._fd_to_handlers[self._local_sock.fileno()] self._local_sock.close() self._local_sock = None + self._dns_resolver.remove_callback(self._handle_dns_resolved) self._server.remove_handler(self) @@ -545,7 +578,8 @@ class TCPRelay(object): # logging.debug('accept') conn = self._server_socket.accept() TCPRelayHandler(self, self._fd_to_handlers, self._eventloop, - conn[0], self._config, self._is_local) + conn[0], self._config, self._dns_resolver, + self._is_local) except (OSError, IOError) as e: error_no = eventloop.errno_from_exception(e) if error_no in (errno.EAGAIN, errno.EINPROGRESS):