diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index cfab5f2..225fb28 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -200,10 +200,12 @@ class TCPRelayHandler(object): self._config['fast_open']: try: self._fastopen_connected = True + remote_sock = self._create_remote_socket(self._chosen_server[0], + self._chosen_server[1]) + self._loop.add(remote_sock, eventloop.POLL_ERR) data = ''.join(self._data_to_write_to_local) l = len(data) - s = self._remote_sock.sendto(data, MSG_FASTOPEN, - self._chosen_server) + s = remote_sock.sendto(data, MSG_FASTOPEN, self._chosen_server) if s < l: data = data[s:] self._data_to_write_to_local = [data] @@ -282,6 +284,19 @@ class TCPRelayHandler(object): # TODO use logging when debug completed self.destroy() + def _create_remote_socket(self, ip, port): + addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM, + socket.SOL_TCP) + if len(addrs) == 0: + raise Exception("getaddrinfo failed for %s:%d" % (ip, port)) + af, socktype, proto, canonname, sa = addrs[0] + remote_sock = socket.socket(af, socktype, proto) + self._remote_sock = remote_sock + self._fd_to_handlers[remote_sock.fileno()] = self + remote_sock.setblocking(False) + remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) + return remote_sock + def _handle_dns_resolved(self, result, error): if error: logging.error(error) @@ -292,33 +307,22 @@ class TCPRelayHandler(object): if ip: try: self._stage = STAGE_REPLY - remote_addr = self._remote_address[0] - remote_port = self._remote_address[1] + remote_addr = ip if self._is_local: - remote_addr, remote_port = self._chosen_server - addrs = socket.getaddrinfo(ip, remote_port, 0, - socket.SOCK_STREAM, - socket.SOL_TCP) - if len(addrs) == 0: - raise Exception("getaddrinfo failed for %s:%d" % - (remote_addr, remote_port)) - af, socktype, proto, canonname, sa = addrs[0] - remote_sock = socket.socket(af, socktype, proto) - self._remote_sock = remote_sock - self._fd_to_handlers[remote_sock.fileno()] = self - remote_sock.setblocking(False) - remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, - 1) + remote_port = self._chosen_server[1] + else: + remote_port = self._remote_address[1] if self._is_local and self._config['fast_open']: # wait for more data to arrive and send them in one SYN self._stage = STAGE_REPLY - self._loop.add(remote_sock, eventloop.POLL_ERR) self._update_stream(STREAM_UP, WAIT_STATUS_READING) # TODO when there is already data in this packet else: + remote_sock = self._create_remote_socket(remote_addr, + remote_port) try: - remote_sock.connect(sa) + remote_sock.connect((remote_addr, remote_port)) except (OSError, IOError) as e: if eventloop.errno_from_exception(e) == \ errno.EINPROGRESS: @@ -432,23 +436,23 @@ class TCPRelayHandler(object): if sock == self._remote_sock: if event & eventloop.POLL_ERR: self._on_remote_error() - if self._stage == STAGE_DESTROYED: - return + if self._stage == STAGE_DESTROYED: + return if event & (eventloop.POLL_IN | eventloop.POLL_HUP): self._on_remote_read() - if self._stage == STAGE_DESTROYED: - return + if self._stage == STAGE_DESTROYED: + return if event & eventloop.POLL_OUT: self._on_remote_write() elif sock == self._local_sock: if event & eventloop.POLL_ERR: self._on_local_error() - if self._stage == STAGE_DESTROYED: - return + if self._stage == STAGE_DESTROYED: + return if event & (eventloop.POLL_IN | eventloop.POLL_HUP): self._on_local_read() - if self._stage == STAGE_DESTROYED: - return + if self._stage == STAGE_DESTROYED: + return if event & eventloop.POLL_OUT: self._on_local_write() else: