diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index d86c988..729aa04 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -94,8 +94,6 @@ BUF_SIZE = 32 * 1024 class TCPRelayHandler(object): - support_ipv6 = None - def __init__(self, server, fd_to_handlers, loop, local_sock, config, dns_resolver, is_local): self._server = server @@ -103,6 +101,7 @@ class TCPRelayHandler(object): self._loop = loop self._local_sock = local_sock self._remote_sock = None + self._remote_sock_v6 = None self._remote_udp = False self._config = config self._dns_resolver = dns_resolver @@ -194,7 +193,7 @@ class TCPRelayHandler(object): if not data or not sock: return False #logging.debug("_write_to_sock %s %s %s" % (self._remote_sock, sock, self._remote_udp)) - if self._remote_sock == sock and self._remote_udp: + if self._remote_udp and self._remote_sock == sock: try: frag = common.ord(data[2]) if frag != 0: @@ -211,11 +210,15 @@ class TCPRelayHandler(object): if addrs: af, socktype, proto, canonname, server_addr = addrs[0] data = data[header_length:] - sock.sendto(data, (server_addr[0], dest_port)) + if af == socket.AF_INET6: + self._remote_sock_v6.sendto(data, (server_addr[0], dest_port)) + else: + sock.sendto(data, (server_addr[0], dest_port)) except Exception as e: - trace = traceback.format_exc() - logging.error(trace) + #trace = traceback.format_exc() + #logging.error(trace) + logging.error(e) return True uncomplete = False @@ -362,18 +365,10 @@ class TCPRelayHandler(object): return True return False - def _is_support_ipv6(self): - if TCPRelayHandler.support_ipv6 is None: - local = socket.gethostbyaddr(socket.gethostname()) - TCPRelayHandler.support_ipv6 = self._has_ipv6_addr(local) - return TCPRelayHandler.support_ipv6 - def _create_remote_socket(self, ip, port): if self._remote_udp: - if self._is_support_ipv6(): - addrs = socket.getaddrinfo("::", 0, 0, socket.SOCK_DGRAM, socket.SOL_UDP) - else: - addrs = socket.getaddrinfo("0.0.0.0", 0, 0, socket.SOCK_DGRAM, socket.SOL_UDP) + addrs_v6 = socket.getaddrinfo("::", 0, 0, socket.SOCK_DGRAM, socket.SOL_UDP) + addrs = socket.getaddrinfo("0.0.0.0", 0, 0, socket.SOCK_DGRAM, socket.SOL_UDP) else: addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM, socket.SOL_TCP) if len(addrs) == 0: @@ -385,7 +380,15 @@ class TCPRelayHandler(object): common.to_str(sa[0])) remote_sock = socket.socket(af, socktype, proto) self._remote_sock = remote_sock + self._fd_to_handlers[remote_sock.fileno()] = self + + if self._remote_udp: + af, socktype, proto, canonname, sa = addrs_v6[0] + remote_sock_v6 = socket.socket(af, socktype, proto) + self._remote_sock_v6 = remote_sock_v6 + self._fd_to_handlers[remote_sock_v6.fileno()] = self + remote_sock.setblocking(False) if self._remote_udp: pass @@ -425,6 +428,9 @@ class TCPRelayHandler(object): if self._remote_udp: self._loop.add(remote_sock, eventloop.POLL_IN) + if self._remote_sock_v6: + self._loop.add(self._remote_sock_v6, + eventloop.POLL_IN) else: try: remote_sock.connect((remote_addr, remote_port)) @@ -482,13 +488,16 @@ class TCPRelayHandler(object): (not is_local and self._stage == STAGE_INIT): self._handle_stage_addr(data) - def _on_remote_read(self): + def _on_remote_read(self, is_remote_sock): # handle all remote read events self._update_activity() data = None try: if self._remote_udp: - data, addr = self._remote_sock.recvfrom(BUF_SIZE) + if is_remote_sock: + data, addr = self._remote_sock.recvfrom(BUF_SIZE) + else: + data, addr = self._remote_sock_v6.recvfrom(BUF_SIZE) port = struct.pack('>H', addr[1]) try: ip = socket.inet_aton(addr[0]) @@ -557,13 +566,13 @@ class TCPRelayHandler(object): logging.debug('ignore handle_event: destroyed') return # order is important - if sock == self._remote_sock: + if sock == self._remote_sock or sock == self._remote_sock_v6: if event & eventloop.POLL_ERR: self._on_remote_error() if self._stage == STAGE_DESTROYED: return if event & (eventloop.POLL_IN | eventloop.POLL_HUP): - self._on_remote_read() + self._on_remote_read(sock == self._remote_sock) if self._stage == STAGE_DESTROYED: return if event & eventloop.POLL_OUT: @@ -610,6 +619,12 @@ class TCPRelayHandler(object): del self._fd_to_handlers[self._remote_sock.fileno()] self._remote_sock.close() self._remote_sock = None + if self._remote_sock_v6: + logging.debug('destroying remote') + self._loop.remove(self._remote_sock_v6) + del self._fd_to_handlers[self._remote_sock_v6.fileno()] + self._remote_sock_v6.close() + self._remote_sock_v6 = None if self._local_sock: logging.debug('destroying local') self._loop.remove(self._local_sock)