diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 0fedd80..4a456a3 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -1008,6 +1008,9 @@ class TCPRelay(object): if error_no in (errno.EAGAIN, errno.EINPROGRESS, errno.EWOULDBLOCK): return + elif error_no == errno.ECONNRESET: + shell.print_exception(e) + logging.info("recv RST, ignore") else: shell.print_exception(e) if self._config['verbose']: diff --git a/shadowsocks/udprelay.py b/shadowsocks/udprelay.py index 4c66ac8..1b119ea 100644 --- a/shadowsocks/udprelay.py +++ b/shadowsocks/udprelay.py @@ -885,6 +885,8 @@ class UDPRelay(object): self._udp_cache_size = config['udp_cache'] self._cache = lru_cache.LRUCache(timeout=config['udp_timeout'], close_callback=self._close_client) + self._cache_dns_client = lru_cache.LRUCache(timeout=10, + close_callback=self._close_client) self._client_fd_to_server_addr = {} self._dns_cache = lru_cache.LRUCache(timeout=300) self._eventloop = None @@ -1066,6 +1068,8 @@ class UDPRelay(object): af, socktype, proto, canonname, sa = addrs[0] key = client_key(r_addr, af) client = self._cache.get(key, None) + if not client: + client = self._cache_dns_client.get(key, None) if not client: if self._forbidden_iplist: if common.to_str(sa[0]) in self._forbidden_iplist: @@ -1075,8 +1079,18 @@ class UDPRelay(object): return client = socket.socket(af, socktype, proto) client.setblocking(False) - self._cache[key] = client - self._client_fd_to_server_addr[client.fileno()] = r_addr + is_dns = False + if len(data) > 12 and data[11:19] == b"\x00\x01\x00\x00\x00\x00\x00\x00": + is_dns = True + else: + pass + #logging.info("unknown data %s" % (binascii.hexlify(data),)) + if sa[1] == 53 and is_dns: #DNS + logging.debug("DNS query %s from %s:%d" % (common.to_str(sa[0]), r_addr[0], r_addr[1])) + self._cache_dns_client[key] = client + else: + self._cache[key] = client + self._client_fd_to_server_addr[client.fileno()] = (r_addr, af) self._sockets.add(client.fileno()) self._eventloop.add(client, eventloop.POLL_IN, self) @@ -1088,6 +1102,7 @@ class UDPRelay(object): r_addr[0], r_addr[1])) self._cache.clear(self._udp_cache_size) + self._cache_dns_client.clear(16) if self._is_local: ref_iv = [encrypt.encrypt_new_iv(self._method)] @@ -1215,7 +1230,13 @@ class UDPRelay(object): client_addr = self._client_fd_to_server_addr.get(sock.fileno()) if client_addr: self.server_transfer_dl += len(response) - self.write_to_server_socket(response, client_addr) + self.write_to_server_socket(response, client_addr[0]) + key = client_key(client_addr[0], client_addr[1]) + client = self._cache_dns_client.get(key, None) + if client: + logging.debug("remove dns client %s:%d" % (client_addr[0][0], client_addr[0][1])) + del self._cache_dns_client[key] + self._close_client(client) else: # this packet is from somewhere else we know # simply drop that packet @@ -1344,6 +1365,7 @@ class UDPRelay(object): def handle_periodic(self): if self._closed: self._cache.clear(0) + self._cache_dns_client.clear(0) self._dns_cache.sweep() if self._eventloop: self._eventloop.remove_periodic(self.handle_periodic) @@ -1355,6 +1377,7 @@ class UDPRelay(object): else: before_sweep_size = len(self._sockets) self._cache.sweep() + self._cache_dns_client.sweep() self._dns_cache.sweep() if before_sweep_size != len(self._sockets): logging.debug('UDP port %5d sockets %d' % (self._listen_port, len(self._sockets))) @@ -1369,3 +1392,4 @@ class UDPRelay(object): self._eventloop.remove(self._server_socket) self._server_socket.close() self._cache.clear(0) + self._cache_dns_client.clear(0)