Browse Source

double UDP sockets for IPv4 and IPv6

master
breakwa11 10 years ago
parent
commit
e5b4a804e9
  1. 55
      shadowsocks/tcprelay.py

55
shadowsocks/tcprelay.py

@ -94,8 +94,6 @@ BUF_SIZE = 32 * 1024
class TCPRelayHandler(object): class TCPRelayHandler(object):
support_ipv6 = None
def __init__(self, server, fd_to_handlers, loop, local_sock, config, def __init__(self, server, fd_to_handlers, loop, local_sock, config,
dns_resolver, is_local): dns_resolver, is_local):
self._server = server self._server = server
@ -103,6 +101,7 @@ class TCPRelayHandler(object):
self._loop = loop self._loop = loop
self._local_sock = local_sock self._local_sock = local_sock
self._remote_sock = None self._remote_sock = None
self._remote_sock_v6 = None
self._remote_udp = False self._remote_udp = False
self._config = config self._config = config
self._dns_resolver = dns_resolver self._dns_resolver = dns_resolver
@ -194,7 +193,7 @@ class TCPRelayHandler(object):
if not data or not sock: if not data or not sock:
return False return False
#logging.debug("_write_to_sock %s %s %s" % (self._remote_sock, sock, self._remote_udp)) #logging.debug("_write_to_sock %s %s %s" % (self._remote_sock, sock, self._remote_udp))
if self._remote_sock == sock and self._remote_udp: if self._remote_udp and self._remote_sock == sock:
try: try:
frag = common.ord(data[2]) frag = common.ord(data[2])
if frag != 0: if frag != 0:
@ -211,11 +210,15 @@ class TCPRelayHandler(object):
if addrs: if addrs:
af, socktype, proto, canonname, server_addr = addrs[0] af, socktype, proto, canonname, server_addr = addrs[0]
data = data[header_length:] data = data[header_length:]
sock.sendto(data, (server_addr[0], dest_port)) if af == socket.AF_INET6:
self._remote_sock_v6.sendto(data, (server_addr[0], dest_port))
else:
sock.sendto(data, (server_addr[0], dest_port))
except Exception as e: except Exception as e:
trace = traceback.format_exc() #trace = traceback.format_exc()
logging.error(trace) #logging.error(trace)
logging.error(e)
return True return True
uncomplete = False uncomplete = False
@ -362,18 +365,10 @@ class TCPRelayHandler(object):
return True return True
return False return False
def _is_support_ipv6(self):
if TCPRelayHandler.support_ipv6 is None:
local = socket.gethostbyaddr(socket.gethostname())
TCPRelayHandler.support_ipv6 = self._has_ipv6_addr(local)
return TCPRelayHandler.support_ipv6
def _create_remote_socket(self, ip, port): def _create_remote_socket(self, ip, port):
if self._remote_udp: if self._remote_udp:
if self._is_support_ipv6(): addrs_v6 = socket.getaddrinfo("::", 0, 0, socket.SOCK_DGRAM, socket.SOL_UDP)
addrs = socket.getaddrinfo("::", 0, 0, socket.SOCK_DGRAM, socket.SOL_UDP) addrs = socket.getaddrinfo("0.0.0.0", 0, 0, socket.SOCK_DGRAM, socket.SOL_UDP)
else:
addrs = socket.getaddrinfo("0.0.0.0", 0, 0, socket.SOCK_DGRAM, socket.SOL_UDP)
else: else:
addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM, socket.SOL_TCP) addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM, socket.SOL_TCP)
if len(addrs) == 0: if len(addrs) == 0:
@ -385,7 +380,15 @@ class TCPRelayHandler(object):
common.to_str(sa[0])) common.to_str(sa[0]))
remote_sock = socket.socket(af, socktype, proto) remote_sock = socket.socket(af, socktype, proto)
self._remote_sock = remote_sock self._remote_sock = remote_sock
self._fd_to_handlers[remote_sock.fileno()] = self self._fd_to_handlers[remote_sock.fileno()] = self
if self._remote_udp:
af, socktype, proto, canonname, sa = addrs_v6[0]
remote_sock_v6 = socket.socket(af, socktype, proto)
self._remote_sock_v6 = remote_sock_v6
self._fd_to_handlers[remote_sock_v6.fileno()] = self
remote_sock.setblocking(False) remote_sock.setblocking(False)
if self._remote_udp: if self._remote_udp:
pass pass
@ -425,6 +428,9 @@ class TCPRelayHandler(object):
if self._remote_udp: if self._remote_udp:
self._loop.add(remote_sock, self._loop.add(remote_sock,
eventloop.POLL_IN) eventloop.POLL_IN)
if self._remote_sock_v6:
self._loop.add(self._remote_sock_v6,
eventloop.POLL_IN)
else: else:
try: try:
remote_sock.connect((remote_addr, remote_port)) remote_sock.connect((remote_addr, remote_port))
@ -482,13 +488,16 @@ class TCPRelayHandler(object):
(not is_local and self._stage == STAGE_INIT): (not is_local and self._stage == STAGE_INIT):
self._handle_stage_addr(data) self._handle_stage_addr(data)
def _on_remote_read(self): def _on_remote_read(self, is_remote_sock):
# handle all remote read events # handle all remote read events
self._update_activity() self._update_activity()
data = None data = None
try: try:
if self._remote_udp: if self._remote_udp:
data, addr = self._remote_sock.recvfrom(BUF_SIZE) if is_remote_sock:
data, addr = self._remote_sock.recvfrom(BUF_SIZE)
else:
data, addr = self._remote_sock_v6.recvfrom(BUF_SIZE)
port = struct.pack('>H', addr[1]) port = struct.pack('>H', addr[1])
try: try:
ip = socket.inet_aton(addr[0]) ip = socket.inet_aton(addr[0])
@ -557,13 +566,13 @@ class TCPRelayHandler(object):
logging.debug('ignore handle_event: destroyed') logging.debug('ignore handle_event: destroyed')
return return
# order is important # order is important
if sock == self._remote_sock: if sock == self._remote_sock or sock == self._remote_sock_v6:
if event & eventloop.POLL_ERR: if event & eventloop.POLL_ERR:
self._on_remote_error() self._on_remote_error()
if self._stage == STAGE_DESTROYED: if self._stage == STAGE_DESTROYED:
return return
if event & (eventloop.POLL_IN | eventloop.POLL_HUP): if event & (eventloop.POLL_IN | eventloop.POLL_HUP):
self._on_remote_read() self._on_remote_read(sock == self._remote_sock)
if self._stage == STAGE_DESTROYED: if self._stage == STAGE_DESTROYED:
return return
if event & eventloop.POLL_OUT: if event & eventloop.POLL_OUT:
@ -610,6 +619,12 @@ class TCPRelayHandler(object):
del self._fd_to_handlers[self._remote_sock.fileno()] del self._fd_to_handlers[self._remote_sock.fileno()]
self._remote_sock.close() self._remote_sock.close()
self._remote_sock = None self._remote_sock = None
if self._remote_sock_v6:
logging.debug('destroying remote')
self._loop.remove(self._remote_sock_v6)
del self._fd_to_handlers[self._remote_sock_v6.fileno()]
self._remote_sock_v6.close()
self._remote_sock_v6 = None
if self._local_sock: if self._local_sock:
logging.debug('destroying local') logging.debug('destroying local')
self._loop.remove(self._local_sock) self._loop.remove(self._local_sock)

Loading…
Cancel
Save