diff --git a/shadowsocks/common.py b/shadowsocks/common.py index fdb2243..40e46cf 100644 --- a/shadowsocks/common.py +++ b/shadowsocks/common.py @@ -236,7 +236,7 @@ def parse_header(data): 'encryption method' % addrtype) if dest_addr is None: return None - return connecttype, to_bytes(dest_addr), dest_port, header_length + return connecttype, addrtype, to_bytes(dest_addr), dest_port, header_length class IPNetwork(object): diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 693dca2..f816052 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -123,6 +123,29 @@ class SpeedTester(object): return self.sum_len >= self.max_speed return False +class UDPAsyncDNSHandler(object): + def __init__(self, params): + self.params = params + self.remote_addr = None + self.call_back = None + + def resolve(self, dns_resolver, remote_addr, call_back): + self.call_back = call_back + self.remote_addr = remote_addr + dns_resolver.resolve(remote_addr[0], self._handle_dns_resolved) + + def _handle_dns_resolved(self, result, error): + if error: + logging.error("%s when resolve DNS" % (error,)) #drop + return + if result: + ip = result[1] + if ip: + if self.call_back: + self.call_back(self.params, self.remote_addr, ip) + return + logging.warning("can't resolve %s" % (self.remote_addr,)) + class TCPRelayHandler(object): def __init__(self, server, fd_to_handlers, loop, local_sock, config, dns_resolver, is_local): @@ -344,26 +367,12 @@ class TCPRelayHandler(object): header_result = parse_header(data) if header_result is None: continue - connecttype, dest_addr, dest_port, header_length = header_result - addrs = socket.getaddrinfo(dest_addr, dest_port, 0, - socket.SOCK_DGRAM, socket.SOL_UDP) - if addrs: - af, socktype, proto, canonname, server_addr = addrs[0] - data = data[header_length:] - if af == socket.AF_INET6: - self._remote_sock_v6.sendto(data, (server_addr[0], dest_port)) - if self._udpv6_send_pack_id == 0: - addr, port = self._remote_sock_v6.getsockname()[:2] - common.connect_log('UDPv6 sendto %s:%d from %s:%d by user %d' % - (server_addr[0], dest_port, addr, port, self._user_id)) - self._udpv6_send_pack_id += 1 - else: - sock.sendto(data, (server_addr[0], dest_port)) - if self._udp_send_pack_id == 0: - addr, port = sock.getsockname()[:2] - common.connect_log('UDP sendto %s:%d from %s:%d by user %d' % - (server_addr[0], dest_port, addr, port, self._user_id)) - self._udp_send_pack_id += 1 + connecttype, addrtype, dest_addr, dest_port, header_length = header_result + if (addrtype & 7) == 3: + handler = UDPAsyncDNSHandler(data[header_length:]) + handler.resolve(self._dns_resolver, (dest_addr, dest_port), self._handle_server_dns_resolved) + else: + return self._handle_server_dns_resolved(data[header_length:], (dest_addr, dest_port), dest_addr) except Exception as e: #trace = traceback.format_exc() @@ -426,6 +435,31 @@ class TCPRelayHandler(object): logging.error('write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1])) return True + def _handle_server_dns_resolved(self, data, remote_addr, server_addr): + try: + addrs = socket.getaddrinfo(server_addr, remote_addr[1], 0, socket.SOCK_DGRAM, socket.SOL_UDP) + if not addrs: # drop + return + af, socktype, proto, canonname, sa = addrs[0] + if af == socket.AF_INET6: + self._remote_sock_v6.sendto(data, (server_addr, remote_addr[1])) + if self._udpv6_send_pack_id == 0: + addr, port = self._remote_sock_v6.getsockname()[:2] + common.connect_log('UDPv6 sendto %s(%s):%d from %s:%d by user %d' % + (common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr, port, self._user_id)) + self._udpv6_send_pack_id += 1 + else: + self._remote_sock.sendto(data, (server_addr, remote_addr[1])) + if self._udp_send_pack_id == 0: + addr, port = self._remote_sock.getsockname()[:2] + common.connect_log('UDP sendto %s(%s):%d from %s:%d by user %d' % + (common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr, port, self._user_id)) + self._udp_send_pack_id += 1 + return True + except Exception as e: + shell.print_exception(e) + logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) + def _get_redirect_host(self, client_address, ogn_data): host_list = self._redir_list or ["*#0.0.0.0:0"] @@ -601,7 +635,7 @@ class TCPRelayHandler(object): header_result = parse_header(data) if header_result is not None: try: - common.to_str(header_result[1]) + common.to_str(header_result[2]) except Exception as e: header_result = None if header_result is None: @@ -613,7 +647,7 @@ class TCPRelayHandler(object): server_info.buffer_size = self._recv_buffer_size server_info = self._protocol.get_server_info() server_info.buffer_size = self._recv_buffer_size - connecttype, remote_addr, remote_port, header_length = header_result + connecttype, addrtype, remote_addr, remote_port, header_length = header_result if connecttype != 0: pass #common.connect_log('UDP over TCP by user %d' % @@ -771,7 +805,7 @@ class TCPRelayHandler(object): raise e addr, port = self._remote_sock.getsockname()[:2] common.connect_log('TCP connecting %s(%s):%d from %s:%d by user %d' % - (self._remote_address[0], remote_addr, remote_port, addr, port, self._user_id)) + (common.to_str(self._remote_address[0]), common.to_str(remote_addr), remote_port, addr, port, self._user_id)) self._loop.add(remote_sock, eventloop.POLL_ERR | eventloop.POLL_OUT, diff --git a/shadowsocks/udprelay.py b/shadowsocks/udprelay.py index c153979..8271d1e 100644 --- a/shadowsocks/udprelay.py +++ b/shadowsocks/udprelay.py @@ -123,11 +123,33 @@ RSP_STATE_ERROR = b"\x03" RSP_STATE_DISCONNECT = b"\x04" RSP_STATE_REDIRECT = b"\x05" +class UDPAsyncDNSHandler(object): + def __init__(self, params): + self.params = params + self.remote_addr = None + self.call_back = None + + def resolve(self, dns_resolver, remote_addr, call_back): + self.call_back = call_back + self.remote_addr = remote_addr + dns_resolver.resolve(remote_addr[0], self._handle_dns_resolved) + + def _handle_dns_resolved(self, result, error): + if error: + logging.error("%s when resolve DNS" % (error,)) #drop + return + if result: + ip = result[1] + if ip: + if self.call_back: + self.call_back(*self.params, self.remote_addr, None, ip, True) + return + logging.warning("can't resolve %s" % (self.remote_addr,)) + def client_key(source_addr, server_af): # notice this is server af, not dest af return '%s:%s:%d' % (source_addr[0], source_addr[1], server_af) - class UDPRelay(object): def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_counter=None): self._config = config @@ -154,7 +176,7 @@ class UDPRelay(object): self._cache_dns_client = lru_cache.LRUCache(timeout=10, close_callback=self._close_client_pair) self._client_fd_to_server_addr = {} - self._dns_cache = lru_cache.LRUCache(timeout=1800) + #self._dns_cache = lru_cache.LRUCache(timeout=1800) self._eventloop = None self._closed = False self.server_transfer_ul = 0 @@ -375,97 +397,98 @@ class UDPRelay(object): if header_result is None: self._handel_protocol_error(r_addr, ogn_data) return - connecttype, dest_addr, dest_port, header_length = header_result + connecttype, addrtype, dest_addr, dest_port, header_length = header_result if self._is_local: - connecttype = 3 + addrtype = 3 server_addr, server_port = self._get_a_server() else: server_addr, server_port = dest_addr, dest_port - if (connecttype & 7) == 3: - addrs = self._dns_cache.get(server_addr, None) + if (addrtype & 7) == 3: + handler = UDPAsyncDNSHandler((data, r_addr, uid, header_length)) + handler.resolve(self._dns_resolver, (server_addr, server_port), self._handle_server_dns_resolved) + else: + self._handle_server_dns_resolved(data, r_addr, uid, header_length, (server_addr, server_port), None, server_addr, False) + + def _handle_server_dns_resolved(self, data, r_addr, uid, header_length, remote_addr, addrs, server_addr, dns_resolved): + try: + server_port = remote_addr[1] if addrs is None: - # TODO async getaddrinfo addrs = socket.getaddrinfo(server_addr, server_port, 0, - socket.SOCK_DGRAM, socket.SOL_UDP) - if not addrs: - # drop - return - else: - self._dns_cache[server_addr] = addrs - else: - addrs = socket.getaddrinfo(server_addr, server_port, 0, - socket.SOCK_DGRAM, socket.SOL_UDP) - if not addrs: - # drop + socket.SOCK_DGRAM, socket.SOL_UDP) + if not addrs: # drop return + af, socktype, proto, canonname, sa = addrs[0] + server_addr = sa[0] + key = client_key(r_addr, af) + client_pair = self._cache.get(key, None) + if client_pair is None: + client_pair = self._cache_dns_client.get(key, None) + if client_pair is None: + if self._forbidden_iplist: + if common.to_str(sa[0]) in self._forbidden_iplist: + logging.debug('IP %s is in forbidden list, drop' % common.to_str(sa[0])) + # drop + return + if self._forbidden_portset: + if sa[1] in self._forbidden_portset: + logging.debug('Port %d is in forbidden list, reject' % sa[1]) + # drop + return + client = socket.socket(af, socktype, proto) + client_uid = uid + client.setblocking(False) + self._socket_bind_addr(client, af) + is_dns = False + if len(data) > header_length + 13 and data[header_length + 4 : header_length + 12] == b"\x00\x01\x00\x00\x00\x00\x00\x00": + is_dns = True + else: + pass + if sa[1] == 53 and is_dns: #DNS + logging.debug("DNS query %s from %s:%d" % (common.to_str(sa[0]), r_addr[0], r_addr[1])) + self._cache_dns_client[key] = (client, uid) + else: + self._cache[key] = (client, uid) + self._client_fd_to_server_addr[client.fileno()] = (r_addr, af) - af, socktype, proto, canonname, sa = addrs[0] - key = client_key(r_addr, af) - client_pair = self._cache.get(key, None) - if client_pair is None: - client_pair = self._cache_dns_client.get(key, None) - if client_pair is None: - if self._forbidden_iplist: - if common.to_str(sa[0]) in self._forbidden_iplist: - logging.debug('IP %s is in forbidden list, drop' % common.to_str(sa[0])) - # drop - return - if self._forbidden_portset: - if sa[1] in self._forbidden_portset: - logging.debug('Port %d is in forbidden list, reject' % sa[1]) - # drop - return - client = socket.socket(af, socktype, proto) - client_uid = uid - client.setblocking(False) - self._socket_bind_addr(client, af) - is_dns = False - if len(data) > 20 and data[11:19] == b"\x00\x01\x00\x00\x00\x00\x00\x00": - is_dns = True - else: - pass - if sa[1] == 53 and is_dns: #DNS - logging.debug("DNS query %s from %s:%d" % (common.to_str(sa[0]), r_addr[0], r_addr[1])) - self._cache_dns_client[key] = (client, uid) - else: - self._cache[key] = (client, uid) - self._client_fd_to_server_addr[client.fileno()] = (r_addr, af) - - self._sockets.add(client.fileno()) - self._eventloop.add(client, eventloop.POLL_IN, self) + self._sockets.add(client.fileno()) + self._eventloop.add(client, eventloop.POLL_IN, self) - logging.debug('UDP port %5d sockets %d' % (self._listen_port, len(self._sockets))) + logging.debug('UDP port %5d sockets %d' % (self._listen_port, len(self._sockets))) - if uid is None: - user_id = self._listen_port + if uid is None: + user_id = self._listen_port + else: + user_id = struct.unpack('