|
|
@ -56,6 +56,7 @@ CMD_UDP_ASSOCIATE = 3 |
|
|
|
STAGE_INIT = 0 |
|
|
|
STAGE_HELLO = 1 |
|
|
|
STAGE_UDP_ASSOC = 2 |
|
|
|
STAGE_DNS = 3 |
|
|
|
STAGE_REPLY = 4 |
|
|
|
STAGE_STREAM = 5 |
|
|
|
STAGE_DESTROYED = -1 |
|
|
@ -75,13 +76,14 @@ BUF_SIZE = 8 * 1024 |
|
|
|
|
|
|
|
class TCPRelayHandler(object): |
|
|
|
def __init__(self, server, fd_to_handlers, loop, local_sock, config, |
|
|
|
is_local): |
|
|
|
dns_resolver, is_local): |
|
|
|
self._server = server |
|
|
|
self._fd_to_handlers = fd_to_handlers |
|
|
|
self._loop = loop |
|
|
|
self._local_sock = local_sock |
|
|
|
self._remote_sock = None |
|
|
|
self._config = config |
|
|
|
self._dns_resolver = dns_resolver |
|
|
|
self._is_local = is_local |
|
|
|
self._stage = STAGE_INIT |
|
|
|
self._encryptor = encrypt.Encryptor(config['password'], |
|
|
@ -239,53 +241,83 @@ class TCPRelayHandler(object): |
|
|
|
addrtype, remote_addr, remote_port, header_length = header_result |
|
|
|
logging.info('connecting %s:%d' % (remote_addr, remote_port)) |
|
|
|
self._remote_address = (remote_addr, remote_port) |
|
|
|
# pause reading |
|
|
|
self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) |
|
|
|
self._stage = STAGE_DNS |
|
|
|
if self._is_local: |
|
|
|
# forward address to remote |
|
|
|
self._write_to_sock('\x05\x00\x00\x01\x00\x00\x00\x00\x10\x10', |
|
|
|
self._local_sock) |
|
|
|
data_to_send = self._encryptor.encrypt(data) |
|
|
|
self._data_to_write_to_remote.append(data_to_send) |
|
|
|
remote_addr = self._config['server'] |
|
|
|
remote_port = self._config['server_port'] |
|
|
|
# notice here may go into _handle_dns_resolved directly |
|
|
|
self._dns_resolver.resolve(self._config['server'], |
|
|
|
self._handle_dns_resolved) |
|
|
|
else: |
|
|
|
if len(data) > header_length: |
|
|
|
self._data_to_write_to_remote.append(data[header_length:]) |
|
|
|
|
|
|
|
# TODO async DNS |
|
|
|
addrs = socket.getaddrinfo(remote_addr, remote_port, 0, |
|
|
|
socket.SOCK_STREAM, socket.SOL_TCP) |
|
|
|
if len(addrs) == 0: |
|
|
|
raise Exception("can't get addrinfo for %s:%d" % |
|
|
|
(remote_addr, remote_port)) |
|
|
|
af, socktype, proto, canonname, sa = addrs[0] |
|
|
|
remote_sock = socket.socket(af, socktype, proto) |
|
|
|
self._remote_sock = remote_sock |
|
|
|
self._fd_to_handlers[remote_sock.fileno()] = self |
|
|
|
remote_sock.setblocking(False) |
|
|
|
remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) |
|
|
|
|
|
|
|
if self._is_local and self._config['fast_open']: |
|
|
|
# wait for more data to arrive and send them in one SYN |
|
|
|
self._stage = STAGE_REPLY |
|
|
|
self._loop.add(remote_sock, eventloop.POLL_ERR) |
|
|
|
# TODO when there is already data in this packet |
|
|
|
else: |
|
|
|
try: |
|
|
|
remote_sock.connect(sa) |
|
|
|
except (OSError, IOError) as e: |
|
|
|
if eventloop.errno_from_exception(e) == errno.EINPROGRESS: |
|
|
|
pass |
|
|
|
self._loop.add(remote_sock, |
|
|
|
eventloop.POLL_ERR | eventloop.POLL_OUT) |
|
|
|
self._stage = STAGE_REPLY |
|
|
|
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) |
|
|
|
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) |
|
|
|
# notice here may go into _handle_dns_resolved directly |
|
|
|
self._dns_resolver.resolve(remote_addr, |
|
|
|
self._handle_dns_resolved) |
|
|
|
except Exception as e: |
|
|
|
logging.error(e) |
|
|
|
traceback.print_exc() |
|
|
|
# TODO use logging when debug completed |
|
|
|
self.destroy() |
|
|
|
|
|
|
|
def _handle_dns_resolved(self, result, error): |
|
|
|
if error: |
|
|
|
logging.error(error) |
|
|
|
self.destroy() |
|
|
|
return |
|
|
|
if result: |
|
|
|
ip = result[1] |
|
|
|
if ip: |
|
|
|
try: |
|
|
|
self._stage = STAGE_REPLY |
|
|
|
remote_addr = self._remote_address[0] |
|
|
|
remote_port = self._remote_address[1] |
|
|
|
if self._is_local: |
|
|
|
remote_addr = self._config['server'] |
|
|
|
remote_port = self._config['server_port'] |
|
|
|
addrs = socket.getaddrinfo(ip, remote_port, 0, |
|
|
|
socket.SOCK_STREAM, |
|
|
|
socket.SOL_TCP) |
|
|
|
if len(addrs) == 0: |
|
|
|
raise Exception("getaddrinfo failed for %s:%d" % |
|
|
|
(remote_addr, remote_port)) |
|
|
|
af, socktype, proto, canonname, sa = addrs[0] |
|
|
|
remote_sock = socket.socket(af, socktype, proto) |
|
|
|
self._remote_sock = remote_sock |
|
|
|
self._fd_to_handlers[remote_sock.fileno()] = self |
|
|
|
remote_sock.setblocking(False) |
|
|
|
remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, |
|
|
|
1) |
|
|
|
|
|
|
|
if self._is_local and self._config['fast_open']: |
|
|
|
# wait for more data to arrive and send them in one SYN |
|
|
|
self._stage = STAGE_REPLY |
|
|
|
self._loop.add(remote_sock, eventloop.POLL_ERR) |
|
|
|
self._update_stream(STREAM_UP, WAIT_STATUS_READING) |
|
|
|
# TODO when there is already data in this packet |
|
|
|
else: |
|
|
|
try: |
|
|
|
remote_sock.connect(sa) |
|
|
|
except (OSError, IOError) as e: |
|
|
|
if eventloop.errno_from_exception(e) == \ |
|
|
|
errno.EINPROGRESS: |
|
|
|
pass |
|
|
|
self._loop.add(remote_sock, |
|
|
|
eventloop.POLL_ERR | eventloop.POLL_OUT) |
|
|
|
self._stage = STAGE_REPLY |
|
|
|
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) |
|
|
|
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) |
|
|
|
return |
|
|
|
except (OSError, IOError) as e: |
|
|
|
logging.error(e) |
|
|
|
traceback.print_exc() |
|
|
|
self.destroy() |
|
|
|
|
|
|
|
def _on_local_read(self): |
|
|
|
self._update_activity() |
|
|
|
if not self._local_sock: |
|
|
@ -422,6 +454,7 @@ class TCPRelayHandler(object): |
|
|
|
del self._fd_to_handlers[self._local_sock.fileno()] |
|
|
|
self._local_sock.close() |
|
|
|
self._local_sock = None |
|
|
|
self._dns_resolver.remove_callback(self._handle_dns_resolved) |
|
|
|
self._server.remove_handler(self) |
|
|
|
|
|
|
|
|
|
|
@ -545,7 +578,8 @@ class TCPRelay(object): |
|
|
|
# logging.debug('accept') |
|
|
|
conn = self._server_socket.accept() |
|
|
|
TCPRelayHandler(self, self._fd_to_handlers, self._eventloop, |
|
|
|
conn[0], self._config, self._is_local) |
|
|
|
conn[0], self._config, self._dns_resolver, |
|
|
|
self._is_local) |
|
|
|
except (OSError, IOError) as e: |
|
|
|
error_no = eventloop.errno_from_exception(e) |
|
|
|
if error_no in (errno.EAGAIN, errno.EINPROGRESS): |
|
|
|