From a38db829f2aeaf22ef4f23e62aa88b610d88ea9a Mon Sep 17 00:00:00 2001 From: breakwa11 Date: Wed, 5 Aug 2015 16:26:21 +0800 Subject: [PATCH] new UDP over TCP protocol, merge master --- shadowsocks/asyncdns.py | 60 ++- shadowsocks/common.py | 15 +- shadowsocks/eventloop.py | 120 ++--- shadowsocks/tcprelay.py | 209 ++++---- shadowsocks/udprelay.py | 1010 +++++++++++++++++++++++++++++++++++++- 5 files changed, 1199 insertions(+), 215 deletions(-) diff --git a/shadowsocks/asyncdns.py b/shadowsocks/asyncdns.py index 7e4a4ed..c5fc99d 100644 --- a/shadowsocks/asyncdns.py +++ b/shadowsocks/asyncdns.py @@ -18,7 +18,6 @@ from __future__ import absolute_import, division, print_function, \ with_statement -import time import os import socket import struct @@ -256,7 +255,6 @@ class DNSResolver(object): self._hostname_to_cb = {} self._cb_to_hostname = {} self._cache = lru_cache.LRUCache(timeout=300) - self._last_time = time.time() self._sock = None self._servers = None self._parse_resolv() @@ -304,7 +302,7 @@ class DNSResolver(object): except IOError: self._hosts['localhost'] = '127.0.0.1' - def add_to_loop(self, loop, ref=False): + def add_to_loop(self, loop): if self._loop: raise Exception('already add to loop') self._loop = loop @@ -312,8 +310,8 @@ class DNSResolver(object): self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.SOL_UDP) self._sock.setblocking(False) - loop.add(self._sock, eventloop.POLL_IN) - loop.add_handler(self.handle_events, ref=ref) + loop.add(self._sock, eventloop.POLL_IN, self) + loop.add_periodic(self.handle_periodic) def _call_callback(self, hostname, ip, error=None): callbacks = self._hostname_to_cb.get(hostname, []) @@ -354,30 +352,27 @@ class DNSResolver(object): self._call_callback(hostname, None) break - def handle_events(self, events): - for sock, fd, event in events: - if sock != self._sock: - continue - if event & eventloop.POLL_ERR: - logging.error('dns socket err') - self._loop.remove(self._sock) - self._sock.close() - # TODO when dns server is IPv6 - self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, - socket.SOL_UDP) - self._sock.setblocking(False) - self._loop.add(self._sock, eventloop.POLL_IN) - else: - data, addr = sock.recvfrom(1024) - if addr[0] not in self._servers: - logging.warn('received a packet other than our dns') - break - self._handle_data(data) - break - now = time.time() - if now - self._last_time > CACHE_SWEEP_INTERVAL: - self._cache.sweep() - self._last_time = now + def handle_event(self, sock, fd, event): + if sock != self._sock: + return + if event & eventloop.POLL_ERR: + logging.error('dns socket err') + self._loop.remove(self._sock) + self._sock.close() + # TODO when dns server is IPv6 + self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, + socket.SOL_UDP) + self._sock.setblocking(False) + self._loop.add(self._sock, eventloop.POLL_IN, self) + else: + data, addr = sock.recvfrom(1024) + if addr[0] not in self._servers: + logging.warn('received a packet other than our dns') + return + self._handle_data(data) + + def handle_periodic(self): + self._cache.sweep() def remove_callback(self, callback): hostname = self._cb_to_hostname.get(callback) @@ -430,6 +425,9 @@ class DNSResolver(object): def close(self): if self._sock: + if self._loop: + self._loop.remove_periodic(self.handle_periodic) + self._loop.remove(self._sock) self._sock.close() self._sock = None @@ -437,7 +435,7 @@ class DNSResolver(object): def test(): dns_resolver = DNSResolver() loop = eventloop.EventLoop() - dns_resolver.add_to_loop(loop, ref=True) + dns_resolver.add_to_loop(loop) global counter counter = 0 @@ -451,8 +449,8 @@ def test(): print(result, error) counter += 1 if counter == 9: - loop.remove_handler(dns_resolver.handle_events) dns_resolver.close() + loop.stop() a_callback = callback return a_callback diff --git a/shadowsocks/common.py b/shadowsocks/common.py index 605fbfa..11b0622 100644 --- a/shadowsocks/common.py +++ b/shadowsocks/common.py @@ -151,6 +151,15 @@ def pre_parse_header(data): data = data[rand_data_size + 2:] elif datatype == 0x81: data = data[1:] + elif datatype == 0x82 : + if len(data) <= 3: + return None + rand_data_size = struct.unpack('>H', data[1:3])[0] + if rand_data_size + 3 >= len(data): + logging.warn('header too short, maybe wrong password or ' + 'encryption method') + return None + data = data[rand_data_size + 3:] return data def parse_header(data): @@ -158,8 +167,8 @@ def parse_header(data): dest_addr = None dest_port = None header_length = 0 - connecttype = (addrtype & 8) and 1 or 0 - addrtype &= ~8 + connecttype = (addrtype & 0x10) and 1 or 0 + addrtype &= ~0x10 if addrtype == ADDRTYPE_IPV4: if len(data) >= 7: dest_addr = socket.inet_ntoa(data[1:5]) @@ -173,7 +182,7 @@ def parse_header(data): if len(data) >= 2 + addrlen: dest_addr = data[2:2 + addrlen] dest_port = struct.unpack('>H', data[2 + addrlen:4 + - addrlen])[0] + addrlen])[0] header_length = 4 + addrlen else: logging.warn('header is too short') diff --git a/shadowsocks/eventloop.py b/shadowsocks/eventloop.py index 42f9205..b27afe3 100644 --- a/shadowsocks/eventloop.py +++ b/shadowsocks/eventloop.py @@ -22,6 +22,7 @@ from __future__ import absolute_import, division, print_function, \ with_statement import os +import time import socket import select import errno @@ -51,23 +52,8 @@ EVENT_NAMES = { POLL_NVAL: 'POLL_NVAL', } - -class EpollLoop(object): - - def __init__(self): - self._epoll = select.epoll() - - def poll(self, timeout): - return self._epoll.poll(timeout) - - def add_fd(self, fd, mode): - self._epoll.register(fd, mode) - - def remove_fd(self, fd): - self._epoll.unregister(fd) - - def modify_fd(self, fd, mode): - self._epoll.modify(fd, mode) +# we check timeouts every TIMEOUT_PRECISION seconds +TIMEOUT_PRECISION = 10 class KqueueLoop(object): @@ -100,17 +86,17 @@ class KqueueLoop(object): results[fd] |= POLL_OUT return results.items() - def add_fd(self, fd, mode): + def register(self, fd, mode): self._fds[fd] = mode self._control(fd, mode, select.KQ_EV_ADD) - def remove_fd(self, fd): + def unregister(self, fd): self._control(fd, self._fds[fd], select.KQ_EV_DELETE) del self._fds[fd] - def modify_fd(self, fd, mode): - self.remove_fd(fd) - self.add_fd(fd, mode) + def modify(self, fd, mode): + self.unregister(fd) + self.register(fd, mode) class SelectLoop(object): @@ -129,7 +115,7 @@ class SelectLoop(object): results[fd] |= p[1] return results.items() - def add_fd(self, fd, mode): + def register(self, fd, mode): if mode & POLL_IN: self._r_list.add(fd) if mode & POLL_OUT: @@ -137,7 +123,7 @@ class SelectLoop(object): if mode & POLL_ERR: self._x_list.add(fd) - def remove_fd(self, fd): + def unregister(self, fd): if fd in self._r_list: self._r_list.remove(fd) if fd in self._w_list: @@ -145,16 +131,15 @@ class SelectLoop(object): if fd in self._x_list: self._x_list.remove(fd) - def modify_fd(self, fd, mode): - self.remove_fd(fd) - self.add_fd(fd, mode) + def modify(self, fd, mode): + self.unregister(fd) + self.register(fd, mode) class EventLoop(object): def __init__(self): - self._iterating = False if hasattr(select, 'epoll'): - self._impl = EpollLoop() + self._impl = select.epoll() model = 'epoll' elif hasattr(select, 'kqueue'): self._impl = KqueueLoop() @@ -165,72 +150,71 @@ class EventLoop(object): else: raise Exception('can not find any available functions in select ' 'package') - self._fd_to_f = {} - self._handlers = [] - self._ref_handlers = [] - self._handlers_to_remove = [] + self._fdmap = {} # (f, handler) + self._last_time = time.time() + self._periodic_callbacks = [] + self._stopping = False logging.debug('using event model: %s', model) def poll(self, timeout=None): events = self._impl.poll(timeout) - return [(self._fd_to_f[fd], fd, event) for fd, event in events] + return [(self._fdmap[fd][0], fd, event) for fd, event in events] - def add(self, f, mode): + def add(self, f, mode, handler): fd = f.fileno() - self._fd_to_f[fd] = f - self._impl.add_fd(fd, mode) + self._fdmap[fd] = (f, handler) + self._impl.register(fd, mode) def remove(self, f): fd = f.fileno() - del self._fd_to_f[fd] - self._impl.remove_fd(fd) + del self._fdmap[fd] + self._impl.unregister(fd) + + def add_periodic(self, callback): + self._periodic_callbacks.append(callback) + + def remove_periodic(self, callback): + self._periodic_callbacks.remove(callback) def modify(self, f, mode): fd = f.fileno() - self._impl.modify_fd(fd, mode) - - def add_handler(self, handler, ref=True): - self._handlers.append(handler) - if ref: - # when all ref handlers are removed, loop stops - self._ref_handlers.append(handler) - - def remove_handler(self, handler): - if handler in self._ref_handlers: - self._ref_handlers.remove(handler) - if self._iterating: - self._handlers_to_remove.append(handler) - else: - self._handlers.remove(handler) + self._impl.modify(fd, mode) + + def stop(self): + self._stopping = True def run(self): events = [] - while self._ref_handlers: + while not self._stopping: + asap = False try: - events = self.poll(1) + events = self.poll(TIMEOUT_PRECISION) except (OSError, IOError) as e: if errno_from_exception(e) in (errno.EPIPE, errno.EINTR): # EPIPE: Happens when the client closes the connection # EINTR: Happens when received a signal # handles them as soon as possible + asap = True logging.debug('poll:%s', e) else: logging.error('poll:%s', e) import traceback traceback.print_exc() continue - self._iterating = True - for handler in self._handlers: - # TODO when there are a lot of handlers - try: - handler(events) - except (OSError, IOError) as e: - shell.print_exception(e) - if self._handlers_to_remove: - for handler in self._handlers_to_remove: - self._handlers.remove(handler) - self._handlers_to_remove = [] - self._iterating = False + + for sock, fd, event in events: + handler = self._fdmap.get(fd, None) + if handler is not None: + handler = handler[1] + try: + handler.handle_event(sock, fd, event) + except (OSError, IOError) as e: + shell.print_exception(e) + now = time.time() + if asap or now - self._last_time >= TIMEOUT_PRECISION: + for callback in self._periodic_callbacks: + callback() + self._last_time = now # from tornado diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 6bb6bbc..8188a00 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -115,6 +115,7 @@ class TCPRelayHandler(object): self._fastopen_connected = False self._data_to_write_to_local = [] self._data_to_write_to_remote = [] + self._udp_data_send_buffer = '' self._upstream_status = WAIT_STATUS_READING self._downstream_status = WAIT_STATUS_INIT self._client_address = local_sock.getpeername()[:2] @@ -128,7 +129,8 @@ class TCPRelayHandler(object): fd_to_handlers[local_sock.fileno()] = self local_sock.setblocking(False) local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR) + loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR, + self._server) self.last_activity = 0 self._update_activity() @@ -185,6 +187,8 @@ class TCPRelayHandler(object): if self._upstream_status & WAIT_STATUS_WRITING: event |= eventloop.POLL_OUT self._loop.modify(self._remote_sock, event) + if self._remote_sock_v6: + self._loop.modify(self._remote_sock_v6, event) def _write_to_sock(self, data, sock): # write data to sock @@ -193,51 +197,70 @@ class TCPRelayHandler(object): if not data or not sock: return False #logging.debug("_write_to_sock %s %s %s" % (self._remote_sock, sock, self._remote_udp)) - if self._remote_udp and self._remote_sock == sock: + uncomplete = False + if self._remote_udp and sock == self._remote_sock: try: - frag = common.ord(data[2]) - if frag != 0: - logging.warn('drop a message since frag is %d' % (frag,)) - return False - else: - data = data[3:] - header_result = parse_header(data) - if header_result is None: - return False - connecttype, dest_addr, dest_port, header_length = header_result - addrs = socket.getaddrinfo(dest_addr, dest_port, 0, - socket.SOCK_DGRAM, socket.SOL_UDP) - if addrs: - af, socktype, proto, canonname, server_addr = addrs[0] - data = data[header_length:] - if af == socket.AF_INET6: - self._remote_sock_v6.sendto(data, (server_addr[0], dest_port)) + self._udp_data_send_buffer += data + #logging.info('UDP over TCP sendto %d %s' % (len(data), binascii.hexlify(data))) + while len(self._udp_data_send_buffer) > 6: + length = struct.unpack('>H', self._udp_data_send_buffer[:2])[0] + + if length > len(self._udp_data_send_buffer): + break + + data = self._udp_data_send_buffer[:length] + self._udp_data_send_buffer = self._udp_data_send_buffer[length:] + + frag = common.ord(data[2]) + if frag != 0: + logging.warn('drop a message since frag is %d' % (frag,)) + continue else: - sock.sendto(data, (server_addr[0], dest_port)) + data = data[3:] + header_result = parse_header(data) + if header_result is None: + continue + connecttype, dest_addr, dest_port, header_length = header_result + addrs = socket.getaddrinfo(dest_addr, dest_port, 0, + socket.SOCK_DGRAM, socket.SOL_UDP) + #logging.info('UDP over TCP sendto %s:%d %d bytes from %s:%d' % (dest_addr, dest_port, len(data), self._client_address[0], self._client_address[1])) + if addrs: + af, socktype, proto, canonname, server_addr = addrs[0] + data = data[header_length:] + if af == socket.AF_INET6: + self._remote_sock_v6.sendto(data, (server_addr[0], dest_port)) + else: + sock.sendto(data, (server_addr[0], dest_port)) except Exception as e: #trace = traceback.format_exc() #logging.error(trace) - logging.error(e) + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + uncomplete = True + else: + shell.print_exception(e) + self.destroy() + return False return True - - uncomplete = False - try: - l = len(data) - s = sock.send(data) - if s < l: - data = data[s:] - uncomplete = True - except (OSError, IOError) as e: - error_no = eventloop.errno_from_exception(e) - if error_no in (errno.EAGAIN, errno.EINPROGRESS, - errno.EWOULDBLOCK): - uncomplete = True - else: - #traceback.print_exc() - shell.print_exception(e) - self.destroy() - return False + else: + try: + l = len(data) + s = sock.send(data) + if s < l: + data = data[s:] + uncomplete = True + except (OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + uncomplete = True + else: + #traceback.print_exc() + shell.print_exception(e) + self.destroy() + return False if uncomplete: if sock == self._local_sock: self._data_to_write_to_local.append(data) @@ -270,7 +293,7 @@ class TCPRelayHandler(object): remote_sock = \ self._create_remote_socket(self._chosen_server[0], self._chosen_server[1]) - self._loop.add(remote_sock, eventloop.POLL_ERR) + self._loop.add(remote_sock, eventloop.POLL_ERR, self._server) data = b''.join(self._data_to_write_to_remote) l = len(data) s = remote_sock.sendto(data, MSG_FASTOPEN, self._chosen_server) @@ -382,6 +405,11 @@ class TCPRelayHandler(object): remote_sock_v6 = socket.socket(af, socktype, proto) self._remote_sock_v6 = remote_sock_v6 self._fd_to_handlers[remote_sock_v6.fileno()] = self + remote_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32) + remote_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32) + remote_sock_v6.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32) + remote_sock_v6.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32) + remote_sock.setblocking(False) if self._remote_udp: @@ -421,10 +449,12 @@ class TCPRelayHandler(object): remote_port) if self._remote_udp: self._loop.add(remote_sock, - eventloop.POLL_IN) + eventloop.POLL_IN, + self._server) if self._remote_sock_v6: self._loop.add(self._remote_sock_v6, - eventloop.POLL_IN) + eventloop.POLL_IN, + self._server) else: try: remote_sock.connect((remote_addr, remote_port)) @@ -433,10 +463,16 @@ class TCPRelayHandler(object): errno.EINPROGRESS: pass self._loop.add(remote_sock, - eventloop.POLL_ERR | eventloop.POLL_OUT) + eventloop.POLL_ERR | eventloop.POLL_OUT, + self._server) self._stage = STAGE_CONNECTING self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + if self._remote_udp: + while self._data_to_write_to_remote: + data = self._data_to_write_to_remote[0] + del self._data_to_write_to_remote[0] + self._write_to_sock(data, self._remote_sock) return except Exception as e: shell.print_exception(e) @@ -495,11 +531,12 @@ class TCPRelayHandler(object): port = struct.pack('>H', addr[1]) try: ip = socket.inet_aton(addr[0]) - data = '\x00\x00\x00\x01' + ip + port + data + data = '\x00\x01' + ip + port + data except Exception as e: ip = socket.inet_pton(socket.AF_INET6, addr[0]) - data = '\x00\x00\x00\x04' + ip + port + data - logging.info('UDP recvfrom %s:%d %d bytes to %s:%d' % (addr[0], addr[1], len(data), self._client_address[0], self._client_address[1])) + data = '\x00\x04' + ip + port + data + data = struct.pack('>H', len(data) + 2) + data + #logging.info('UDP over TCP recvfrom %s:%d %d bytes to %s:%d' % (addr[0], addr[1], len(data), self._client_address[0], self._client_address[1])) else: data = self._remote_sock.recv(BUF_SIZE) except (OSError, IOError) as e: @@ -637,7 +674,6 @@ class TCPRelay(object): self._closed = False self._eventloop = None self._fd_to_handlers = {} - self._last_time = time.time() self.server_transfer_ul = 0L self.server_transfer_dl = 0L @@ -680,10 +716,9 @@ class TCPRelay(object): if self._closed: raise Exception('already closed') self._eventloop = loop - loop.add_handler(self._handle_events) - self._eventloop.add(self._server_socket, - eventloop.POLL_IN | eventloop.POLL_ERR) + eventloop.POLL_IN | eventloop.POLL_ERR, self) + self._eventloop.add_periodic(self.handle_periodic) def remove_handler(self, handler): index = self._handler_to_timeouts.get(hash(handler), -1) @@ -695,7 +730,7 @@ class TCPRelay(object): def update_activity(self, handler): # set handler to active now = int(time.time()) - if now - handler.last_activity < TIMEOUT_PRECISION: + if now - handler.last_activity < eventloop.TIMEOUT_PRECISION: # thus we can lower timeout modification frequency return handler.last_activity = now @@ -741,53 +776,55 @@ class TCPRelay(object): pos = 0 self._timeout_offset = pos - def _handle_events(self, events): + def handle_event(self, sock, fd, event): # handle events and dispatch to handlers - for sock, fd, event in events: + if sock: + logging.log(shell.VERBOSE_LEVEL, 'fd %d %s', fd, + eventloop.EVENT_NAMES.get(event, event)) + if sock == self._server_socket: + if event & eventloop.POLL_ERR: + # TODO + raise Exception('server_socket error') + try: + logging.debug('accept') + conn = self._server_socket.accept() + TCPRelayHandler(self, self._fd_to_handlers, + self._eventloop, conn[0], self._config, + self._dns_resolver, self._is_local) + except (OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + return + else: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + else: if sock: - logging.log(shell.VERBOSE_LEVEL, 'fd %d %s', fd, - eventloop.EVENT_NAMES.get(event, event)) - if sock == self._server_socket: - if event & eventloop.POLL_ERR: - # TODO - raise Exception('server_socket error') - try: - logging.debug('accept') - conn = self._server_socket.accept() - TCPRelayHandler(self, self._fd_to_handlers, - self._eventloop, conn[0], self._config, - self._dns_resolver, self._is_local) - except (OSError, IOError) as e: - error_no = eventloop.errno_from_exception(e) - if error_no in (errno.EAGAIN, errno.EINPROGRESS, - errno.EWOULDBLOCK): - continue - else: - shell.print_exception(e) - if self._config['verbose']: - traceback.print_exc() + handler = self._fd_to_handlers.get(fd, None) + if handler: + handler.handle_event(sock, event) else: - if sock: - handler = self._fd_to_handlers.get(fd, None) - if handler: - handler.handle_event(sock, event) - else: - logging.warn('poll removed fd') + logging.warn('poll removed fd') - now = time.time() - if now - self._last_time > TIMEOUT_PRECISION: - self._sweep_timeout() - self._last_time = now + def handle_periodic(self): if self._closed: if self._server_socket: self._eventloop.remove(self._server_socket) self._server_socket.close() self._server_socket = None - logging.info('closed listen port %d', self._listen_port) + logging.info('closed TCP port %d', self._listen_port) if not self._fd_to_handlers: - self._eventloop.remove_handler(self._handle_events) + logging.info('stopping') + self._eventloop.stop() + self._sweep_timeout() def close(self, next_tick=False): + logging.debug('TCP close') self._closed = True if not next_tick: + if self._eventloop: + self._eventloop.remove_periodic(self.handle_periodic) + self._eventloop.remove(self._server_socket) self._server_socket.close() diff --git a/shadowsocks/udprelay.py b/shadowsocks/udprelay.py index 1e142bf..018a6a6 100644 --- a/shadowsocks/udprelay.py +++ b/shadowsocks/udprelay.py @@ -68,19 +68,760 @@ import logging import struct import errno import random +import binascii +import traceback from shadowsocks import encrypt, eventloop, lru_cache, common, shell from shadowsocks.common import pre_parse_header, parse_header, pack_addr +# we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time +TIMEOUTS_CLEAN_SIZE = 512 + +# we check timeouts every TIMEOUT_PRECISION seconds +TIMEOUT_PRECISION = 4 + +# for each handler, we have 2 stream directions: +# upstream: from client to server direction +# read local and write to remote +# downstream: from server to client direction +# read remote and write to local + +STREAM_UP = 0 +STREAM_DOWN = 1 + +# for each stream, it's waiting for reading, or writing, or both +WAIT_STATUS_INIT = 0 +WAIT_STATUS_READING = 1 +WAIT_STATUS_WRITING = 2 +WAIT_STATUS_READWRITING = WAIT_STATUS_READING | WAIT_STATUS_WRITING BUF_SIZE = 65536 +DOUBLE_SEND_BEG_IDS = 16 +POST_MTU_MIN = 1000 +POST_MTU_MAX = 1400 + +STAGE_INIT = 0 +STAGE_RSP_ID = 1 +STAGE_DNS = 2 +STAGE_CONNECTING = 3 +STAGE_STREAM = 4 +STAGE_DESTROYED = -1 + +CMD_CONNECT = 0 +CMD_RSP_CONNECT = 1 +CMD_CONNECT_REMOTE = 2 +CMD_RSP_CONNECT_REMOTE = 3 +CMD_POST = 4 +CMD_SYN_STATUS = 5 +CMD_POST_64 = 6 +CMD_SYN_STATUS_64 = 7 +CMD_DISCONNECT = 8 + +CMD_VER_STR = "\x08" + +class UDPLocalAddress(object): + def __init__(self, addr): + self.addr = addr + self.last_activity = time.time() + + def is_timeout(self): + return time.time() - self.last_activity > 30 + +class PacketInfo(object): + def __init__(self, data): + self.data = data + self.time = time.time() + +class SendingQueue(object): + def __init__(self): + self.queue = {} + self.begin_id = 0 + self.end_id = 1 + self.interval = 0.5 + + def append(self, data): + self.queue[self.end_id] = PacketInfo(data) + self.end_id += 1 + return self.end_id - 1 + + def empty(self): + return self.begin_id + 1 == self.end_id + + def size(self): + return self.end_id - self.begin_id - 1 + + def get_begin_id(self): + return self.begin_id + + def get_end_id(self): + return self.end_id + + def get_data_list(self, pack_id_base, pack_id_list): + ret_list = [] + curtime = time.time() + for pack_id in pack_id_list: + offset = pack_id_base + pack_id + if offset <= self.begin_id or self.end_id <= offset: + continue + ret_data = self.queue[offset] + if curtime - ret_data.time > self.interval: + ret_data.time = curtime + ret_list.append( (offset, ret_data.data) ) + return ret_list + + def set_finish(self, begin_id, done_list): + while self.begin_id < begin_id: + self.begin_id += 1 + del self.queue[self.begin_id] + #while len(self.queue) > 0 and self.queue[0][0] <= begin_id: + # del self.queue[0] + # self.begin_id += 1 + +class RecvQueue(object): + def __init__(self): + self.queue = {} + self.miss_queue = set() + self.begin_id = 0 + self.end_id = 1 + + def empty(self): + return self.begin_id + 1 == self.end_id + + def insert(self, pack_id, data): + if (pack_id not in self.queue) and pack_id > self.begin_id: + self.queue[pack_id] = PacketInfo(data) + if self.end_id == pack_id: + self.end_id = pack_id + 1 + elif self.end_id < pack_id: + for eid in xrange(self.end_id, pack_id): + self.miss_queue.add(eid) + self.end_id = pack_id + 1 + else: + self.miss_queue.remove(pack_id) + + def set_end(self, end_id): + if end_id > self.end_id: + for eid in xrange(self.end_id, end_id): + self.miss_queue.add(eid) + self.end_id = end_id + + def get_begin_id(self): + return self.begin_id + + def has_data(self): + return (self.begin_id + 1) in self.queue + + def get_data(self): + if (self.begin_id + 1) in self.queue: + self.begin_id += 1 + pack_id = self.begin_id + ret_data = self.queue[pack_id] + del self.queue[pack_id] + return (pack_id, ret_data.data) + + def get_missing_id(self, begin_id): + missing = [] + if begin_id == 0: + begin_id = self.begin_id + for i in self.miss_queue: + if i - begin_id > 32768: + break + missing.append(i - begin_id) + return (begin_id, missing) + +class TCPRelayHandler(object): + def __init__(self, server, reqid_to_handlers, fd_to_handlers, loop, + local_sock, local_id, client_param, config, + dns_resolver, is_local): + self._server = server + self._reqid_to_handlers = reqid_to_handlers + self._fd_to_handlers = fd_to_handlers + self._loop = loop + self._local_sock = local_sock + self._remote_sock = None + self._remote_udp = False + self._config = config + self._dns_resolver = dns_resolver + self._local_id = local_id + + self._is_local = is_local + self._stage = STAGE_INIT + self._password = config['password'] + self._method = config['method'] + self._fastopen_connected = False + self._data_to_write_to_local = [] + self._data_to_write_to_remote = [] + self._upstream_status = WAIT_STATUS_READING + self._downstream_status = WAIT_STATUS_INIT + self._request_id = 0 + self._client_address = {} + self._remote_address = None + self._sendingqueue = SendingQueue() + self._recvqueue = RecvQueue() + if 'forbidden_ip' in config: + self._forbidden_iplist = config['forbidden_ip'] + else: + self._forbidden_iplist = None + #fd_to_handlers[local_sock.fileno()] = self + #local_sock.setblocking(False) + #loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR) + self.last_activity = 0 + self._update_activity() + self._random_mtu_size = [random.randint(POST_MTU_MIN, POST_MTU_MAX) for i in xrange(1024)] + self._random_mtu_index = 0 + + self._rand_data = "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10" * 4 + + def __hash__(self): + # default __hash__ is id / 16 + # we want to eliminate collisions + return id(self) + + @property + def remote_address(self): + return self._remote_address + + def add_local_address(self, addr): + self._client_address[addr] = UDPLocalAddress(addr) + + def _update_activity(self): + # tell the TCP Relay we have activities recently + # else it will think we are inactive and timed out + self._server.update_activity(self) + + def _update_stream(self, stream, status): + # update a stream to a new waiting status + + # check if status is changed + # only update if dirty + dirty = False + if stream == STREAM_DOWN: + if self._downstream_status != status: + self._downstream_status = status + dirty = True + elif stream == STREAM_UP: + if self._upstream_status != status: + self._upstream_status = status + dirty = True + if dirty: + ''' + if self._local_sock: + event = eventloop.POLL_ERR + if self._downstream_status & WAIT_STATUS_WRITING: + event |= eventloop.POLL_OUT + if self._upstream_status & WAIT_STATUS_READING: + event |= eventloop.POLL_IN + self._loop.modify(self._local_sock, event) + ''' + if self._remote_sock: + event = eventloop.POLL_ERR + if self._downstream_status & WAIT_STATUS_READING: + event |= eventloop.POLL_IN + if self._upstream_status & WAIT_STATUS_WRITING: + event |= eventloop.POLL_OUT + self._loop.modify(self._remote_sock, event) + + def _write_to_sock(self, data, sock, addr = None): + # write data to sock + # if only some of the data are written, put remaining in the buffer + # and update the stream to wait for writing + if not data or not sock: + return False + + uncomplete = False + retry = 0 + if sock == self._local_sock: + data = encrypt.encrypt_all(self._password, self._method, 1, data) + if addr is None: + return False + try: + self._server.write_to_server_socket(data, addr) + except (OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + uncomplete = True + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + pass + else: + #traceback.print_exc() + shell.print_exception(e) + self.destroy() + return False + else: + try: + l = len(data) + s = sock.send(data) + if s < l: + data = data[s:] + uncomplete = True + except (OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + uncomplete = True + else: + #logging.error(traceback.extract_stack()) + #traceback.print_exc() + shell.print_exception(e) + self.destroy() + return False + if uncomplete: + if sock == self._local_sock: + #if data is not None and retry < 10: + # self._data_to_write_to_local.append([(data, addr), retry]) + self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) + elif sock == self._remote_sock: + self._data_to_write_to_remote.append(data) + self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) + else: + logging.error('write_all_to_sock:unknown socket') + else: + if sock == self._local_sock: + if self._sendingqueue.size() > 8192: + self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) + else: + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + elif sock == self._remote_sock: + if self._sendingqueue.size() > 8192: + self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) + else: + self._update_stream(STREAM_UP, WAIT_STATUS_READING) + else: + logging.error('write_all_to_sock:unknown socket') + return True + + def _create_remote_socket(self, ip, port): + addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM, socket.SOL_TCP) + if len(addrs) == 0: + raise Exception("getaddrinfo failed for %s:%d" % (ip, port)) + af, socktype, proto, canonname, sa = addrs[0] + if self._forbidden_iplist: + if common.to_str(sa[0]) in self._forbidden_iplist: + raise Exception('IP %s is in forbidden list, reject' % + common.to_str(sa[0])) + remote_sock = socket.socket(af, socktype, proto) + self._remote_sock = remote_sock + + self._fd_to_handlers[remote_sock.fileno()] = self + + remote_sock.setblocking(False) + remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) + return remote_sock + + def _handle_dns_resolved(self, result, error): + if error: + self._log_error(error) + self.destroy() + return + if result: + ip = result[1] + if ip: + + try: + self._stage = STAGE_CONNECTING + remote_addr = ip + remote_port = self._remote_address[1] + logging.info("connect to %s : %d" % (remote_addr, remote_port)) + + remote_sock = self._create_remote_socket(remote_addr, + remote_port) + try: + remote_sock.connect((remote_addr, remote_port)) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) == \ + errno.EINPROGRESS: + pass + + self._loop.add(remote_sock, + eventloop.POLL_ERR | eventloop.POLL_OUT, + self._server) + self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + self._stage = STAGE_STREAM + + for it_addr in self._client_address: + addr = it_addr + break + + for i in xrange(2): + rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, "\x02") + self._write_to_sock(rsp_data, self._local_sock, addr) + + return + except Exception as e: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + self.destroy() + + def _on_local_read(self): + # handle all local read events and dispatch them to methods for + # each stage + self._update_activity() + if not self._local_sock: + return + data = None + try: + data = self._local_sock.recv(BUF_SIZE) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) in \ + (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK): + return + if not data: + self.destroy() + return + if not data: + return + self._server.server_transfer_ul += len(data) + #TODO ============================================================ + if self._stage == STAGE_STREAM: + self._write_to_sock(data, self._remote_sock) + return + def _on_remote_read(self): + # handle all remote read events + self._update_activity() + data = None + try: + data = self._remote_sock.recv(BUF_SIZE) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) in \ + (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): #errno.WSAEWOULDBLOCK + return + if not data: + self.destroy() + return + self._server.server_transfer_dl += len(data) + try: + recv_data = data + beg_pos = 0 + max_len = len(recv_data) + while beg_pos < max_len: + if beg_pos + POST_MTU_MAX >= max_len: + split_pos = max_len + else: + split_pos = beg_pos + self._random_mtu_size[self._random_mtu_index] + self._random_mtu_index = (self._random_mtu_index + 1) & 0x3ff + #split_pos = beg_pos + random.randint(POST_MTU_MIN, POST_MTU_MAX) + data = recv_data[beg_pos:split_pos] + beg_pos = split_pos + + pack_id = self._sendingqueue.append(data) + post_data = self._pack_post_data(CMD_POST, pack_id, data) + for it_addr in self._client_address: + addr = it_addr + break + self._write_to_sock(post_data, self._local_sock, addr) + #if pack_id <= DOUBLE_SEND_BEG_IDS: + # post_data = self._pack_post_data(CMD_POST, pack_id, data) + # self._write_to_sock(post_data, self._local_sock, addr) + + except Exception as e: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + # TODO use logging when debug completed + self.destroy() + + def _on_local_write(self): + # handle local writable event + if self._data_to_write_to_local: + data = b''.join(self._data_to_write_to_local) + self._data_to_write_to_local = [] + self._write_to_sock(data, self._local_sock) + else: + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + + def _on_remote_write(self): + # handle remote writable event + self._stage = STAGE_STREAM + if self._data_to_write_to_remote: + data = b''.join(self._data_to_write_to_remote) + self._data_to_write_to_remote = [] + self._write_to_sock(data, self._remote_sock) + else: + self._update_stream(STREAM_UP, WAIT_STATUS_READING) + + def _on_local_error(self): + logging.debug('got local error') + if self._local_sock: + logging.error(eventloop.get_sock_error(self._local_sock)) + self.destroy() + + def _on_remote_error(self): + logging.debug('got remote error') + if self._remote_sock: + logging.error(eventloop.get_sock_error(self._remote_sock)) + self.destroy() + + def _pack_rsp_data(self, cmd, data): + reqid_str = struct.pack(">H", self._request_id) + return ''.join([CMD_VER_STR, chr(cmd), reqid_str, data, self._rand_data[:random.randint(0, len(self._rand_data))], reqid_str]) + + def _pack_rnd_data(self, data): + length = random.randint(0, len(self._rand_data)) + if length == 0: + return data + elif length == 1: + return "\x81" + data + elif length < 256: + return "\x80" + chr(length) + self._rand_data[:length - 2] + data + else: + return "\x82" + struct.pack(">H", length) + self._rand_data[:length - 3] + data + + def _pack_post_data(self, cmd, pack_id, data): + reqid_str = struct.pack(">H", self._request_id) + recv_id = self._recvqueue.get_begin_id() + rsp_data = ''.join([CMD_VER_STR, chr(cmd), reqid_str, struct.pack(">I", recv_id), struct.pack(">I", pack_id), data, reqid_str]) + return rsp_data + + def _pack_post_data_64(self, cmd, send_id, pack_id, data): + reqid_str = struct.pack(">H", self._request_id) + recv_id = self._recvqueue.get_begin_id() + rsp_data = ''.join([CMD_VER_STR, chr(cmd), reqid_str, struct.pack(">Q", recv_id), struct.pack(">Q", pack_id), data, reqid_str]) + return rsp_data + + def sweep_timeout(self): + logging.info("sweep_timeout") + if self._stage == STAGE_STREAM: + pack_id, missing = self._recvqueue.get_missing_id(0) + logging.info("sweep_timeout %s %s" % (pack_id, missing)) + data = '' + for pid in missing: + data += struct.pack(">H", pid) + rsp_data = self._pack_post_data(CMD_SYN_STATUS, pack_id, data) + self._write_to_sock(rsp_data, self._local_sock, addr) + + def handle_stream_sync_status(self, addr, cmd, request_id, pack_id, max_send_id, data): + missing_list = [] + while len(data) >= 2: + pid = struct.unpack(">H", data[0:2])[0] + data = data[2:] + missing_list.append(pid) + done_list = [] + self._recvqueue.set_end(max_send_id) + self._sendingqueue.set_finish(pack_id, done_list) + + if self._stage == STAGE_DESTROYED and self._sendingqueue.empty(): + self.destroy_local() + return + + # post CMD_SYN_STATUS + send_id = self._sendingqueue.get_end_id() + post_pack_id, missing = self._recvqueue.get_missing_id(0) + pack_ids_data = '' + for pid in missing: + pack_ids_data += struct.pack(">H", pid) + + rsp_data = self._pack_rnd_data(self._pack_post_data(CMD_SYN_STATUS, send_id, pack_ids_data)) + self._write_to_sock(rsp_data, self._local_sock, addr) + + send_list = self._sendingqueue.get_data_list(pack_id, missing_list) + for post_pack_id, post_data in send_list: + rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data) + self._write_to_sock(rsp_data, self._local_sock, addr) + #if post_pack_id <= DOUBLE_SEND_BEG_IDS: + # rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data) + # self._write_to_sock(rsp_data, self._local_sock, addr) + + def handle_client(self, addr, cmd, request_id, data): + self.add_local_address(addr) + if cmd == CMD_DISCONNECT: + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + self._write_to_sock(rsp_data, self._local_sock, addr) + self.destroy() + self.destroy_local() + return + if self._stage == STAGE_INIT: + if cmd == CMD_CONNECT: + self._request_id = request_id + self._stage = STAGE_RSP_ID + return + if self._request_id != request_id: + return + + if self._stage == STAGE_RSP_ID: + if cmd == CMD_CONNECT: + for i in xrange(2): + rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, "\x01") + self._write_to_sock(rsp_data, self._local_sock, addr) + elif cmd == CMD_CONNECT_REMOTE: + local_id = data[0:4] + if self._local_id == local_id: + data = data[4:] + header_result = parse_header(data) + if header_result is None: + return + connecttype, remote_addr, remote_port, header_length = header_result + self._remote_address = (common.to_str(remote_addr), remote_port) + self._stage = STAGE_DNS + self._dns_resolver.resolve(remote_addr, + self._handle_dns_resolved) + logging.info('TCP connect %s:%d from %s:%d' % (remote_addr, remote_port, addr[0], addr[1])) + else: + # ileagal request + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + self._write_to_sock(rsp_data, self._local_sock, addr) + elif self._stage == STAGE_CONNECTING: + if cmd == CMD_CONNECT_REMOTE: + local_id = data[0:4] + if self._local_id == local_id: + for i in xrange(2): + rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, "\x02") + self._write_to_sock(rsp_data, self._local_sock, addr) + else: + # ileagal request + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + self._write_to_sock(rsp_data, self._local_sock, addr) + elif self._stage == STAGE_STREAM: + if len(data) < 4: + # ileagal request + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + self._write_to_sock(rsp_data, self._local_sock, addr) + return + local_id = data[0:4] + if self._local_id != local_id: + # ileagal request + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + self._write_to_sock(rsp_data, self._local_sock, addr) + return + else: + data = data[4:] + if cmd == CMD_CONNECT_REMOTE: + rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, "\x02") + self._write_to_sock(rsp_data, self._local_sock, addr) + elif cmd == CMD_POST: + recv_id = struct.unpack(">I", data[0:4])[0] + pack_id = struct.unpack(">I", data[4:8])[0] + self._recvqueue.insert(pack_id, data[8:]) + self._sendingqueue.set_finish(recv_id, []) + elif cmd == CMD_POST_64: + recv_id = struct.unpack(">Q", data[0:8])[0] + pack_id = struct.unpack(">Q", data[8:16])[0] + self._recvqueue.insert(pack_id, data[16:]) + self._sendingqueue.set_finish(recv_id, []) + elif cmd == CMD_DISCONNECT: + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + self._write_to_sock(rsp_data, self._local_sock, addr) + self.destroy() + self.destroy_local() + return + elif cmd == CMD_SYN_STATUS: + pack_id = struct.unpack(">I", data[0:4])[0] + max_send_id = struct.unpack(">I", data[4:8])[0] + data = data[8:] + self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) + elif cmd == CMD_SYN_STATUS_64: + pack_id = struct.unpack(">Q", data[0:8])[0] + max_send_id = struct.unpack(">Q", data[8:16])[0] + data = data[16:] + self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) + while self._recvqueue.has_data(): + pack_id, post_data = self._recvqueue.get_data() + self._write_to_sock(post_data, self._remote_sock) + elif self._stage == STAGE_DESTROYED: + local_id = data[0:4] + if self._local_id != local_id: + # ileagal request + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + self._write_to_sock(rsp_data, self._local_sock, addr) + return + else: + data = data[4:] + if cmd == CMD_SYN_STATUS: + pack_id = struct.unpack(">I", data[0:4])[0] + max_send_id = struct.unpack(">I", data[4:8])[0] + data = data[8:] + logging.info('handle_client STAGE_DESTROYED send %d %d' % (request_id, pack_id)) + self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) + elif cmd == CMD_SYN_STATUS_64: + pack_id = struct.unpack(">Q", data[0:8])[0] + max_send_id = struct.unpack(">Q", data[8:16])[0] + data = data[16:] + logging.info('handle_client STAGE_DESTROYED send %d %d' % (request_id, pack_id)) + self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) + + def handle_event(self, sock, event): + # handle all events in this handler and dispatch them to methods + if self._stage == STAGE_DESTROYED: + logging.debug('ignore handle_event: destroyed') + return + # order is important + if sock == self._remote_sock: + if event & eventloop.POLL_ERR: + self._on_remote_error() + if self._stage == STAGE_DESTROYED: + return + if event & (eventloop.POLL_IN | eventloop.POLL_HUP): + self._on_remote_read() + if self._stage == STAGE_DESTROYED: + return + if event & eventloop.POLL_OUT: + self._on_remote_write() + elif sock == self._local_sock: + if event & eventloop.POLL_ERR: + self._on_local_error() + if self._stage == STAGE_DESTROYED: + return + if event & (eventloop.POLL_IN | eventloop.POLL_HUP): + self._on_local_read() + if self._stage == STAGE_DESTROYED: + return + if event & eventloop.POLL_OUT: + self._on_local_write() + else: + logging.warn('unknown socket') + + def _log_error(self, e): + logging.error('%s when handling connection from %s' % + (e, self._client_address.keys())) + + def destroy(self): + # destroy the handler and release any resources + # promises: + # 1. destroy won't make another destroy() call inside + # 2. destroy releases resources so it prevents future call to destroy + # 3. destroy won't raise any exceptions + # if any of the promises are broken, it indicates a bug has been + # introduced! mostly likely memory leaks, etc + #logging.info('tcp destroy called') + if self._stage == STAGE_DESTROYED: + # this couldn't happen + logging.debug('already destroyed') + return + self._stage = STAGE_DESTROYED + if self._remote_address: + logging.debug('destroy: %s:%d' % + self._remote_address) + else: + logging.debug('destroy') + if self._remote_sock: + logging.debug('destroying remote') + self._loop.remove(self._remote_sock) + del self._fd_to_handlers[self._remote_sock.fileno()] + self._remote_sock.close() + self._remote_sock = None + if self._sendingqueue.empty(): + self.destroy_local() + self._dns_resolver.remove_callback(self._handle_dns_resolved) + + def destroy_local(self): + if self._local_sock: + logging.debug('disconnect local') + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + addr = None + for it_addr in self._client_address: + addr = it_addr + break + self._write_to_sock(rsp_data, self._local_sock, addr) + self._local_sock = None + del self._reqid_to_handlers[self._request_id] + self._server.remove_handler(self) def client_key(source_addr, server_af): # notice this is server af, not dest af return '%s:%s:%d' % (source_addr[0], source_addr[1], server_af) - class UDPRelay(object): def __init__(self, config, dns_resolver, is_local): self._config = config @@ -106,8 +847,19 @@ class UDPRelay(object): self._dns_cache = lru_cache.LRUCache(timeout=300) self._eventloop = None self._closed = False - self._last_time = time.time() + self.server_transfer_ul = 0L + self.server_transfer_dl = 0L + self._sockets = set() + self._fd_to_handlers = {} + self._reqid_to_hd = {} + self._data_to_write_to_server_socket = [] + + self._timeouts = [] # a list for all the handlers + # we trim the timeouts once a while + self._timeout_offset = 0 # last checked position for timeout + self._handler_to_timeouts = {} # key: handler value: index in timeouts + if 'forbidden_ip' in config: self._forbidden_iplist = config['forbidden_ip'] else: @@ -122,6 +874,8 @@ class UDPRelay(object): server_socket = socket.socket(af, socktype, proto) server_socket.bind((self._listen_addr, self._listen_port)) server_socket.setblocking(False) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32) self._server_socket = server_socket def _get_a_server(self): @@ -143,6 +897,41 @@ class UDPRelay(object): # just an address pass + def _pre_parse_udp_header(self, data): + if data is None: + return + datatype = ord(data[0]) + if datatype == 0x8: + if len(data) >= 8: + crc = binascii.crc32(data) & 0xffffffff + if crc != 0xffffffff: + logging.warn('uncorrect CRC32, maybe wrong password or ' + 'encryption method') + return None + cmd = ord(data[1]) + request_id = struct.unpack('>H', data[2:4])[0] + data = data[4:-4] + return (cmd, request_id, data) + elif len(data) >= 6 and ord(data[1]) == 0x0: + crc = binascii.crc32(data) & 0xffffffff + if crc != 0xffffffff: + logging.warn('uncorrect CRC32, maybe wrong password or ' + 'encryption method') + return None + cmd = ord(data[1]) + data = data[2:-4] + return (cmd, 0, data) + else: + logging.warn('header too short, maybe wrong password or ' + 'encryption method') + return None + return data + + def _pack_rsp_data(self, cmd, request_id, data): + _rand_data = "123456789abcdefghijklmnopqrstuvwxyz" * 2 + reqid_str = struct.pack(">H", request_id) + return ''.join([CMD_VER_STR, chr(cmd), reqid_str, data, _rand_data[:random.randint(0, len(_rand_data))], reqid_str]) + def _handle_server(self): server = self._server_socket data, r_addr = server.recvfrom(BUF_SIZE) @@ -162,11 +951,75 @@ class UDPRelay(object): logging.debug('UDP handle_server: data is empty after decrypt') return + #logging.info("UDP data %s" % (binascii.hexlify(data),)) if not self._is_local: data = pre_parse_header(data) + + data = self._pre_parse_udp_header(data) if data is None: return + if type(data) is tuple: + #(cmd, request_id, data) + #logging.info("UDP data %d %d %s" % (data[0], data[1], binascii.hexlify(data[2]))) + try: + if data[0] == 0: + if len(data[2]) >= 4: + for i in xrange(64): + req_id = random.randint(1, 65535) + if req_id not in self._reqid_to_hd: + break + if req_id in self._reqid_to_hd: + for i in xrange(64): + req_id = random.randint(1, 65535) + if type(self._reqid_to_hd[req_id]) is tuple: + break + # return req id + self._reqid_to_hd[req_id] = (data[2][0:4], None) + rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, req_id, "\x01") + data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) + self.write_to_server_socket(data_to_send, r_addr) + elif data[0] == CMD_CONNECT_REMOTE: + if len(data[2]) > 4 and data[1] in self._reqid_to_hd: + # create + if type(self._reqid_to_hd[data[1]]) is tuple: + if data[2][0:4] == self._reqid_to_hd[data[1]][0]: + handle = TCPRelayHandler(self, self._reqid_to_hd, self._fd_to_handlers, + self._eventloop, self._server_socket, + self._reqid_to_hd[data[1]][0], self._reqid_to_hd[data[1]][1], + self._config, self._dns_resolver, self._is_local) + self._reqid_to_hd[data[1]] = handle + handle.handle_client(r_addr, CMD_CONNECT, data[1], data[2]) + handle.handle_client(r_addr, *data) + self.update_activity(handle) + else: + # disconnect + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], "") + data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) + self.write_to_server_socket(data_to_send, r_addr) + else: + self.update_activity(self._reqid_to_hd[data[1]]) + self._reqid_to_hd[data[1]].handle_client(r_addr, *data) + else: + # disconnect + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], "") + data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) + self.write_to_server_socket(data_to_send, r_addr) + elif data[0] > CMD_CONNECT_REMOTE and data[0] <= CMD_DISCONNECT: + if data[1] in self._reqid_to_hd: + self.update_activity(self._reqid_to_hd[data[1]]) + self._reqid_to_hd[data[1]].handle_client(r_addr, *data) + else: + # disconnect + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], "") + data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) + self.write_to_server_socket(data_to_send, r_addr) + return + except Exception as e: + trace = traceback.format_exc() + logging.error(trace) + return + header_result = parse_header(data) if header_result is None: return @@ -205,7 +1058,7 @@ class UDPRelay(object): self._client_fd_to_server_addr[client.fileno()] = r_addr self._sockets.add(client.fileno()) - self._eventloop.add(client, eventloop.POLL_IN) + self._eventloop.add(client, eventloop.POLL_IN, self) if self._is_local: data = encrypt.encrypt_all(self._password, self._method, 1, data) @@ -216,6 +1069,7 @@ class UDPRelay(object): if not data: return try: + #logging.info('UDP handle_server sendto %s:%d %d bytes' % (common.to_str(server_addr), server_port, len(data))) client.sendto(data, (server_addr, server_port)) except IOError as e: err = eventloop.errno_from_exception(e) @@ -247,50 +1101,152 @@ class UDPRelay(object): header_result = parse_header(data) if header_result is None: return - # connecttype, dest_addr, dest_port, header_length = header_result + connecttype, dest_addr, dest_port, header_length = header_result + #logging.debug('UDP handle_client %s:%d to %s:%d' % (common.to_str(r_addr[0]), r_addr[1], dest_addr, dest_port)) + response = b'\x00\x00\x00' + data client_addr = self._client_fd_to_server_addr.get(sock.fileno()) if client_addr: - self._server_socket.sendto(response, client_addr) + self.write_to_server_socket(response, client_addr) else: # this packet is from somewhere else we know # simply drop that packet pass + def write_to_server_socket(self, data, addr): + #self._server_socket.sendto(data, addr) + #''' + uncomplete = False + retry = 0 + try: + #""" + #if self._data_to_write_to_server_socket: + # self._data_to_write_to_server_socket.append([(data, addr), 0]) + #else: + self._server_socket.sendto(data, addr) + data = None + while self._data_to_write_to_server_socket: + data_buf = self._data_to_write_to_server_socket[0] + retry = data_buf[1] + 1 + del self._data_to_write_to_server_socket[0] + data, addr = data_buf[0] + self._server_socket.sendto(data, addr) + #""" + except (OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + uncomplete = True + if error_no in (errno.EWOULDBLOCK,): + pass + else: + shell.print_exception(e) + return False + #if uncomplete and data is not None and retry < 3: + # self._data_to_write_to_server_socket.append([(data, addr), retry]) + #''' + def add_to_loop(self, loop): if self._eventloop: raise Exception('already add to loop') if self._closed: raise Exception('already closed') self._eventloop = loop - loop.add_handler(self._handle_events) server_socket = self._server_socket self._eventloop.add(server_socket, - eventloop.POLL_IN | eventloop.POLL_ERR) - - def _handle_events(self, events): - for sock, fd, event in events: - if sock == self._server_socket: - if event & eventloop.POLL_ERR: - logging.error('UDP server_socket err') - self._handle_server() - elif sock and (fd in self._sockets): - if event & eventloop.POLL_ERR: - logging.error('UDP client_socket err') - self._handle_client(sock) - now = time.time() - if now - self._last_time > 3: - self._cache.sweep() - self._client_fd_to_server_addr.sweep() - self._last_time = now + eventloop.POLL_IN | eventloop.POLL_ERR, self) + loop.add_periodic(self.handle_periodic) + + def remove_handler(self, handler): + index = self._handler_to_timeouts.get(hash(handler), -1) + if index >= 0: + # delete is O(n), so we just set it to None + self._timeouts[index] = None + del self._handler_to_timeouts[hash(handler)] + + def update_activity(self, handler): + # set handler to active + now = int(time.time()) + if now - handler.last_activity < eventloop.TIMEOUT_PRECISION: + # thus we can lower timeout modification frequency + return + handler.last_activity = now + index = self._handler_to_timeouts.get(hash(handler), -1) + if index >= 0: + # delete is O(n), so we just set it to None + self._timeouts[index] = None + length = len(self._timeouts) + self._timeouts.append(handler) + self._handler_to_timeouts[hash(handler)] = length + + def _sweep_timeout(self): + # tornado's timeout memory management is more flexible than we need + # we just need a sorted last_activity queue and it's faster than heapq + # in fact we can do O(1) insertion/remove so we invent our own + if self._timeouts: + logging.log(shell.VERBOSE_LEVEL, 'sweeping timeouts') + now = time.time() + length = len(self._timeouts) + pos = self._timeout_offset + while pos < length: + handler = self._timeouts[pos] + if handler: + if now - handler.last_activity < self._timeout: + break + else: + if handler.remote_address: + logging.warn('timed out: %s:%d' % + handler.remote_address) + else: + logging.warn('timed out') + handler.destroy() + handler.destroy_local() + self._timeouts[pos] = None # free memory + pos += 1 + else: + pos += 1 + if pos > TIMEOUTS_CLEAN_SIZE and pos > length >> 1: + # clean up the timeout queue when it gets larger than half + # of the queue + self._timeouts = self._timeouts[pos:] + for key in self._handler_to_timeouts: + self._handler_to_timeouts[key] -= pos + pos = 0 + self._timeout_offset = pos + + def handle_event(self, sock, fd, event): + if sock == self._server_socket: + if event & eventloop.POLL_ERR: + logging.error('UDP server_socket err') + self._handle_server() + elif sock and (fd in self._sockets): + if event & eventloop.POLL_ERR: + logging.error('UDP client_socket err') + self._handle_client(sock) + else: + if sock: + handler = self._fd_to_handlers.get(fd, None) + if handler: + handler.handle_event(sock, event) + else: + logging.warn('poll removed fd') + + def handle_periodic(self): if self._closed: - self._server_socket.close() - for sock in self._sockets: - sock.close() - self._eventloop.remove_handler(self._handle_events) + if self._server_socket: + self._server_socket.close() + self._server_socket = None + for sock in self._sockets: + sock.close() + logging.info('closed UDP port %d', self._listen_port) + self._cache.sweep() + self._client_fd_to_server_addr.sweep() + self._sweep_timeout() def close(self, next_tick=False): + logging.debug('UDP close') self._closed = True if not next_tick: + if self._eventloop: + self._eventloop.remove_periodic(self.handle_periodic) + self._eventloop.remove(self._server_socket) self._server_socket.close()