Browse Source

Async DNS query under UDP

dev
破娃酱 8 years ago
parent
commit
abe4ac3b90
  1. 2
      shadowsocks/common.py
  2. 80
      shadowsocks/tcprelay.py
  3. 173
      shadowsocks/udprelay.py

2
shadowsocks/common.py

@ -236,7 +236,7 @@ def parse_header(data):
'encryption method' % addrtype) 'encryption method' % addrtype)
if dest_addr is None: if dest_addr is None:
return None return None
return connecttype, to_bytes(dest_addr), dest_port, header_length return connecttype, addrtype, to_bytes(dest_addr), dest_port, header_length
class IPNetwork(object): class IPNetwork(object):

80
shadowsocks/tcprelay.py

@ -123,6 +123,29 @@ class SpeedTester(object):
return self.sum_len >= self.max_speed return self.sum_len >= self.max_speed
return False return False
class UDPAsyncDNSHandler(object):
def __init__(self, params):
self.params = params
self.remote_addr = None
self.call_back = None
def resolve(self, dns_resolver, remote_addr, call_back):
self.call_back = call_back
self.remote_addr = remote_addr
dns_resolver.resolve(remote_addr[0], self._handle_dns_resolved)
def _handle_dns_resolved(self, result, error):
if error:
logging.error("%s when resolve DNS" % (error,)) #drop
return
if result:
ip = result[1]
if ip:
if self.call_back:
self.call_back(self.params, self.remote_addr, ip)
return
logging.warning("can't resolve %s" % (self.remote_addr,))
class TCPRelayHandler(object): class TCPRelayHandler(object):
def __init__(self, server, fd_to_handlers, loop, local_sock, config, def __init__(self, server, fd_to_handlers, loop, local_sock, config,
dns_resolver, is_local): dns_resolver, is_local):
@ -344,26 +367,12 @@ class TCPRelayHandler(object):
header_result = parse_header(data) header_result = parse_header(data)
if header_result is None: if header_result is None:
continue continue
connecttype, dest_addr, dest_port, header_length = header_result connecttype, addrtype, dest_addr, dest_port, header_length = header_result
addrs = socket.getaddrinfo(dest_addr, dest_port, 0, if (addrtype & 7) == 3:
socket.SOCK_DGRAM, socket.SOL_UDP) handler = UDPAsyncDNSHandler(data[header_length:])
if addrs: handler.resolve(self._dns_resolver, (dest_addr, dest_port), self._handle_server_dns_resolved)
af, socktype, proto, canonname, server_addr = addrs[0] else:
data = data[header_length:] return self._handle_server_dns_resolved(data[header_length:], (dest_addr, dest_port), dest_addr)
if af == socket.AF_INET6:
self._remote_sock_v6.sendto(data, (server_addr[0], dest_port))
if self._udpv6_send_pack_id == 0:
addr, port = self._remote_sock_v6.getsockname()[:2]
common.connect_log('UDPv6 sendto %s:%d from %s:%d by user %d' %
(server_addr[0], dest_port, addr, port, self._user_id))
self._udpv6_send_pack_id += 1
else:
sock.sendto(data, (server_addr[0], dest_port))
if self._udp_send_pack_id == 0:
addr, port = sock.getsockname()[:2]
common.connect_log('UDP sendto %s:%d from %s:%d by user %d' %
(server_addr[0], dest_port, addr, port, self._user_id))
self._udp_send_pack_id += 1
except Exception as e: except Exception as e:
#trace = traceback.format_exc() #trace = traceback.format_exc()
@ -426,6 +435,31 @@ class TCPRelayHandler(object):
logging.error('write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1])) logging.error('write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1]))
return True return True
def _handle_server_dns_resolved(self, data, remote_addr, server_addr):
try:
addrs = socket.getaddrinfo(server_addr, remote_addr[1], 0, socket.SOCK_DGRAM, socket.SOL_UDP)
if not addrs: # drop
return
af, socktype, proto, canonname, sa = addrs[0]
if af == socket.AF_INET6:
self._remote_sock_v6.sendto(data, (server_addr, remote_addr[1]))
if self._udpv6_send_pack_id == 0:
addr, port = self._remote_sock_v6.getsockname()[:2]
common.connect_log('UDPv6 sendto %s(%s):%d from %s:%d by user %d' %
(common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr, port, self._user_id))
self._udpv6_send_pack_id += 1
else:
self._remote_sock.sendto(data, (server_addr, remote_addr[1]))
if self._udp_send_pack_id == 0:
addr, port = self._remote_sock.getsockname()[:2]
common.connect_log('UDP sendto %s(%s):%d from %s:%d by user %d' %
(common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr, port, self._user_id))
self._udp_send_pack_id += 1
return True
except Exception as e:
shell.print_exception(e)
logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1]))
def _get_redirect_host(self, client_address, ogn_data): def _get_redirect_host(self, client_address, ogn_data):
host_list = self._redir_list or ["*#0.0.0.0:0"] host_list = self._redir_list or ["*#0.0.0.0:0"]
@ -601,7 +635,7 @@ class TCPRelayHandler(object):
header_result = parse_header(data) header_result = parse_header(data)
if header_result is not None: if header_result is not None:
try: try:
common.to_str(header_result[1]) common.to_str(header_result[2])
except Exception as e: except Exception as e:
header_result = None header_result = None
if header_result is None: if header_result is None:
@ -613,7 +647,7 @@ class TCPRelayHandler(object):
server_info.buffer_size = self._recv_buffer_size server_info.buffer_size = self._recv_buffer_size
server_info = self._protocol.get_server_info() server_info = self._protocol.get_server_info()
server_info.buffer_size = self._recv_buffer_size server_info.buffer_size = self._recv_buffer_size
connecttype, remote_addr, remote_port, header_length = header_result connecttype, addrtype, remote_addr, remote_port, header_length = header_result
if connecttype != 0: if connecttype != 0:
pass pass
#common.connect_log('UDP over TCP by user %d' % #common.connect_log('UDP over TCP by user %d' %
@ -771,7 +805,7 @@ class TCPRelayHandler(object):
raise e raise e
addr, port = self._remote_sock.getsockname()[:2] addr, port = self._remote_sock.getsockname()[:2]
common.connect_log('TCP connecting %s(%s):%d from %s:%d by user %d' % common.connect_log('TCP connecting %s(%s):%d from %s:%d by user %d' %
(self._remote_address[0], remote_addr, remote_port, addr, port, self._user_id)) (common.to_str(self._remote_address[0]), common.to_str(remote_addr), remote_port, addr, port, self._user_id))
self._loop.add(remote_sock, self._loop.add(remote_sock,
eventloop.POLL_ERR | eventloop.POLL_OUT, eventloop.POLL_ERR | eventloop.POLL_OUT,

173
shadowsocks/udprelay.py

@ -123,11 +123,33 @@ RSP_STATE_ERROR = b"\x03"
RSP_STATE_DISCONNECT = b"\x04" RSP_STATE_DISCONNECT = b"\x04"
RSP_STATE_REDIRECT = b"\x05" RSP_STATE_REDIRECT = b"\x05"
class UDPAsyncDNSHandler(object):
def __init__(self, params):
self.params = params
self.remote_addr = None
self.call_back = None
def resolve(self, dns_resolver, remote_addr, call_back):
self.call_back = call_back
self.remote_addr = remote_addr
dns_resolver.resolve(remote_addr[0], self._handle_dns_resolved)
def _handle_dns_resolved(self, result, error):
if error:
logging.error("%s when resolve DNS" % (error,)) #drop
return
if result:
ip = result[1]
if ip:
if self.call_back:
self.call_back(*self.params, self.remote_addr, None, ip, True)
return
logging.warning("can't resolve %s" % (self.remote_addr,))
def client_key(source_addr, server_af): def client_key(source_addr, server_af):
# notice this is server af, not dest af # notice this is server af, not dest af
return '%s:%s:%d' % (source_addr[0], source_addr[1], server_af) return '%s:%s:%d' % (source_addr[0], source_addr[1], server_af)
class UDPRelay(object): class UDPRelay(object):
def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_counter=None): def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_counter=None):
self._config = config self._config = config
@ -154,7 +176,7 @@ class UDPRelay(object):
self._cache_dns_client = lru_cache.LRUCache(timeout=10, self._cache_dns_client = lru_cache.LRUCache(timeout=10,
close_callback=self._close_client_pair) close_callback=self._close_client_pair)
self._client_fd_to_server_addr = {} self._client_fd_to_server_addr = {}
self._dns_cache = lru_cache.LRUCache(timeout=1800) #self._dns_cache = lru_cache.LRUCache(timeout=1800)
self._eventloop = None self._eventloop = None
self._closed = False self._closed = False
self.server_transfer_ul = 0 self.server_transfer_ul = 0
@ -375,97 +397,98 @@ class UDPRelay(object):
if header_result is None: if header_result is None:
self._handel_protocol_error(r_addr, ogn_data) self._handel_protocol_error(r_addr, ogn_data)
return return
connecttype, dest_addr, dest_port, header_length = header_result connecttype, addrtype, dest_addr, dest_port, header_length = header_result
if self._is_local: if self._is_local:
connecttype = 3 addrtype = 3
server_addr, server_port = self._get_a_server() server_addr, server_port = self._get_a_server()
else: else:
server_addr, server_port = dest_addr, dest_port server_addr, server_port = dest_addr, dest_port
if (connecttype & 7) == 3: if (addrtype & 7) == 3:
addrs = self._dns_cache.get(server_addr, None) handler = UDPAsyncDNSHandler((data, r_addr, uid, header_length))
handler.resolve(self._dns_resolver, (server_addr, server_port), self._handle_server_dns_resolved)
else:
self._handle_server_dns_resolved(data, r_addr, uid, header_length, (server_addr, server_port), None, server_addr, False)
def _handle_server_dns_resolved(self, data, r_addr, uid, header_length, remote_addr, addrs, server_addr, dns_resolved):
try:
server_port = remote_addr[1]
if addrs is None: if addrs is None:
# TODO async getaddrinfo
addrs = socket.getaddrinfo(server_addr, server_port, 0, addrs = socket.getaddrinfo(server_addr, server_port, 0,
socket.SOCK_DGRAM, socket.SOL_UDP) socket.SOCK_DGRAM, socket.SOL_UDP)
if not addrs: if not addrs: # drop
# drop
return
else:
self._dns_cache[server_addr] = addrs
else:
addrs = socket.getaddrinfo(server_addr, server_port, 0,
socket.SOCK_DGRAM, socket.SOL_UDP)
if not addrs:
# drop
return return
af, socktype, proto, canonname, sa = addrs[0]
server_addr = sa[0]
key = client_key(r_addr, af)
client_pair = self._cache.get(key, None)
if client_pair is None:
client_pair = self._cache_dns_client.get(key, None)
if client_pair is None:
if self._forbidden_iplist:
if common.to_str(sa[0]) in self._forbidden_iplist:
logging.debug('IP %s is in forbidden list, drop' % common.to_str(sa[0]))
# drop
return
if self._forbidden_portset:
if sa[1] in self._forbidden_portset:
logging.debug('Port %d is in forbidden list, reject' % sa[1])
# drop
return
client = socket.socket(af, socktype, proto)
client_uid = uid
client.setblocking(False)
self._socket_bind_addr(client, af)
is_dns = False
if len(data) > header_length + 13 and data[header_length + 4 : header_length + 12] == b"\x00\x01\x00\x00\x00\x00\x00\x00":
is_dns = True
else:
pass
if sa[1] == 53 and is_dns: #DNS
logging.debug("DNS query %s from %s:%d" % (common.to_str(sa[0]), r_addr[0], r_addr[1]))
self._cache_dns_client[key] = (client, uid)
else:
self._cache[key] = (client, uid)
self._client_fd_to_server_addr[client.fileno()] = (r_addr, af)
af, socktype, proto, canonname, sa = addrs[0] self._sockets.add(client.fileno())
key = client_key(r_addr, af) self._eventloop.add(client, eventloop.POLL_IN, self)
client_pair = self._cache.get(key, None)
if client_pair is None:
client_pair = self._cache_dns_client.get(key, None)
if client_pair is None:
if self._forbidden_iplist:
if common.to_str(sa[0]) in self._forbidden_iplist:
logging.debug('IP %s is in forbidden list, drop' % common.to_str(sa[0]))
# drop
return
if self._forbidden_portset:
if sa[1] in self._forbidden_portset:
logging.debug('Port %d is in forbidden list, reject' % sa[1])
# drop
return
client = socket.socket(af, socktype, proto)
client_uid = uid
client.setblocking(False)
self._socket_bind_addr(client, af)
is_dns = False
if len(data) > 20 and data[11:19] == b"\x00\x01\x00\x00\x00\x00\x00\x00":
is_dns = True
else:
pass
if sa[1] == 53 and is_dns: #DNS
logging.debug("DNS query %s from %s:%d" % (common.to_str(sa[0]), r_addr[0], r_addr[1]))
self._cache_dns_client[key] = (client, uid)
else:
self._cache[key] = (client, uid)
self._client_fd_to_server_addr[client.fileno()] = (r_addr, af)
self._sockets.add(client.fileno())
self._eventloop.add(client, eventloop.POLL_IN, self)
logging.debug('UDP port %5d sockets %d' % (self._listen_port, len(self._sockets))) logging.debug('UDP port %5d sockets %d' % (self._listen_port, len(self._sockets)))
if uid is None: if uid is None:
user_id = self._listen_port user_id = self._listen_port
else:
user_id = struct.unpack('<I', client_uid)[0]
else: else:
user_id = struct.unpack('<I', client_uid)[0] client, client_uid = client_pair
else: self._cache.clear(self._udp_cache_size)
client, client_uid = client_pair self._cache_dns_client.clear(16)
self._cache.clear(self._udp_cache_size)
self._cache_dns_client.clear(16) if self._is_local:
ref_iv = [encrypt.encrypt_new_iv(self._method)]
if self._is_local: self._protocol.obfs.server_info.iv = ref_iv[0]
ref_iv = [encrypt.encrypt_new_iv(self._method)] data = self._protocol.client_udp_pre_encrypt(data)
self._protocol.obfs.server_info.iv = ref_iv[0] #logging.debug("%s" % (binascii.hexlify(data),))
data = self._protocol.client_udp_pre_encrypt(data) data = encrypt.encrypt_all_iv(self._protocol.obfs.server_info.key, self._method, 1, data, ref_iv)
#logging.debug("%s" % (binascii.hexlify(data),)) if not data:
data = encrypt.encrypt_all_iv(self._protocol.obfs.server_info.key, self._method, 1, data, ref_iv) return
else:
data = data[header_length:]
if not data: if not data:
return return
else: except Exception as e:
data = data[header_length:] shell.print_exception(e)
if not data: logging.error("exception from user %d" % (user_id,))
return
try: try:
client.sendto(data, (server_addr, server_port)) client.sendto(data, (server_addr, server_port))
self.add_transfer_u(client_uid, len(data)) self.add_transfer_u(client_uid, len(data))
if client_pair is None: # new request if client_pair is None: # new request
addr, port = client.getsockname()[:2] addr, port = client.getsockname()[:2]
common.connect_log('UDP data to %s:%d from %s:%d by UID %d' % common.connect_log('UDP data to %s(%s):%d from %s:%d by user %d' %
(common.to_str(server_addr), server_port, addr, port, user_id)) (common.to_str(remote_addr[0]), common.to_str(server_addr), server_port, addr, port, user_id))
except IOError as e: except IOError as e:
err = eventloop.errno_from_exception(e) err = eventloop.errno_from_exception(e)
logging.warning('IOError sendto %s:%d by user %d' % (server_addr, server_port, user_id)) logging.warning('IOError sendto %s:%d by user %d' % (server_addr, server_port, user_id))
@ -623,7 +646,7 @@ class UDPRelay(object):
if self._closed: if self._closed:
self._cache.clear(0) self._cache.clear(0)
self._cache_dns_client.clear(0) self._cache_dns_client.clear(0)
self._dns_cache.sweep() #self._dns_cache.sweep()
if self._eventloop: if self._eventloop:
self._eventloop.remove_periodic(self.handle_periodic) self._eventloop.remove_periodic(self.handle_periodic)
self._eventloop.remove(self._server_socket) self._eventloop.remove(self._server_socket)
@ -635,7 +658,7 @@ class UDPRelay(object):
before_sweep_size = len(self._sockets) before_sweep_size = len(self._sockets)
self._cache.sweep() self._cache.sweep()
self._cache_dns_client.sweep() self._cache_dns_client.sweep()
self._dns_cache.sweep() #self._dns_cache.sweep()
if before_sweep_size != len(self._sockets): if before_sweep_size != len(self._sockets):
logging.debug('UDP port %5d sockets %d' % (self._listen_port, len(self._sockets))) logging.debug('UDP port %5d sockets %d' % (self._listen_port, len(self._sockets)))
self._sweep_timeout() self._sweep_timeout()

Loading…
Cancel
Save