diff --git a/shadowsocks/common.py b/shadowsocks/common.py index 11b0622..cc88d5d 100644 --- a/shadowsocks/common.py +++ b/shadowsocks/common.py @@ -21,7 +21,7 @@ from __future__ import absolute_import, division, print_function, \ import socket import struct import logging - +import binascii def compat_ord(s): if type(s) == int: @@ -140,7 +140,7 @@ def pack_addr(address): def pre_parse_header(data): datatype = ord(data[0]) - if datatype == 0x80 : + if datatype == 0x80: if len(data) <= 2: return None rand_data_size = ord(data[1]) @@ -151,7 +151,7 @@ def pre_parse_header(data): data = data[rand_data_size + 2:] elif datatype == 0x81: data = data[1:] - elif datatype == 0x82 : + elif datatype == 0x82: if len(data) <= 3: return None rand_data_size = struct.unpack('>H', data[1:3])[0] @@ -160,6 +160,21 @@ def pre_parse_header(data): 'encryption method') return None data = data[rand_data_size + 3:] + elif datatype == 0x88: + if len(data) <= 7 + 7: + return None + data_size = struct.unpack('>H', data[1:3])[0] + ogn_data = data + data = data[:data_size] + crc = binascii.crc32(data) & 0xffffffff + if crc != 0xffffffff: + logging.warn('uncorrect CRC32, maybe wrong password or ' + 'encryption method') + return None + start_pos = 3 + ord(data[3]) + data = data[start_pos:-4] + if data_size < len(ogn_data): + data += ogn_data[data_size:] return data def parse_header(data): diff --git a/shadowsocks/eventloop.py b/shadowsocks/eventloop.py index b27afe3..ce9c11b 100644 --- a/shadowsocks/eventloop.py +++ b/shadowsocks/eventloop.py @@ -98,6 +98,9 @@ class KqueueLoop(object): self.unregister(fd) self.register(fd, mode) + def close(self): + self.kqueue.close() + class SelectLoop(object): @@ -135,6 +138,9 @@ class SelectLoop(object): self.unregister(fd) self.register(fd, mode) + def close(self): + pass + class EventLoop(object): def __init__(self): @@ -216,6 +222,9 @@ class EventLoop(object): callback() self._last_time = now + def __del__(self): + self._impl.close() + # from tornado def errno_from_exception(e): diff --git a/shadowsocks/manager.py b/shadowsocks/manager.py new file mode 100644 index 0000000..e8009b4 --- /dev/null +++ b/shadowsocks/manager.py @@ -0,0 +1,286 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import errno +import traceback +import socket +import logging +import json +import collections + +from shadowsocks import common, eventloop, tcprelay, udprelay, asyncdns, shell + + +BUF_SIZE = 1506 +STAT_SEND_LIMIT = 100 + + +class Manager(object): + + def __init__(self, config): + self._config = config + self._relays = {} # (tcprelay, udprelay) + self._loop = eventloop.EventLoop() + self._dns_resolver = asyncdns.DNSResolver() + self._dns_resolver.add_to_loop(self._loop) + + self._statistics = collections.defaultdict(int) + self._control_client_addr = None + try: + manager_address = config['manager_address'] + if ':' in manager_address: + addr = manager_address.rsplit(':', 1) + addr = addr[0], int(addr[1]) + addrs = socket.getaddrinfo(addr[0], addr[1]) + if addrs: + family = addrs[0][0] + else: + logging.error('invalid address: %s', manager_address) + exit(1) + else: + addr = manager_address + family = socket.AF_UNIX + self._control_socket = socket.socket(family, + socket.SOCK_DGRAM) + self._control_socket.bind(addr) + self._control_socket.setblocking(False) + except (OSError, IOError) as e: + logging.error(e) + logging.error('can not bind to manager address') + exit(1) + self._loop.add(self._control_socket, + eventloop.POLL_IN, self) + self._loop.add_periodic(self.handle_periodic) + + port_password = config['port_password'] + del config['port_password'] + for port, password in port_password.items(): + a_config = config.copy() + a_config['server_port'] = int(port) + a_config['password'] = password + self.add_port(a_config) + + def add_port(self, config): + port = int(config['server_port']) + servers = self._relays.get(port, None) + if servers: + logging.error("server already exists at %s:%d" % (config['server'], + port)) + return + logging.info("adding server at %s:%d" % (config['server'], port)) + t = tcprelay.TCPRelay(config, self._dns_resolver, False, + self.stat_callback) + u = udprelay.UDPRelay(config, self._dns_resolver, False, + self.stat_callback) + t.add_to_loop(self._loop) + u.add_to_loop(self._loop) + self._relays[port] = (t, u) + + def remove_port(self, config): + port = int(config['server_port']) + servers = self._relays.get(port, None) + if servers: + logging.info("removing server at %s:%d" % (config['server'], port)) + t, u = servers + t.close(next_tick=False) + u.close(next_tick=False) + del self._relays[port] + else: + logging.error("server not exist at %s:%d" % (config['server'], + port)) + + def handle_event(self, sock, fd, event): + if sock == self._control_socket and event == eventloop.POLL_IN: + data, self._control_client_addr = sock.recvfrom(BUF_SIZE) + parsed = self._parse_command(data) + if parsed: + command, config = parsed + a_config = self._config.copy() + if config: + # let the command override the configuration file + a_config.update(config) + if 'server_port' not in a_config: + logging.error('can not find server_port in config') + else: + if command == 'add': + self.add_port(a_config) + self._send_control_data(b'ok') + elif command == 'remove': + self.remove_port(a_config) + self._send_control_data(b'ok') + elif command == 'ping': + self._send_control_data(b'pong') + else: + logging.error('unknown command %s', command) + + def _parse_command(self, data): + # commands: + # add: {"server_port": 8000, "password": "foobar"} + # remove: {"server_port": 8000"} + data = common.to_str(data) + parts = data.split(':', 1) + if len(parts) < 2: + return data, None + command, config_json = parts + try: + config = shell.parse_json_in_str(config_json) + return command, config + except Exception as e: + logging.error(e) + return None + + def stat_callback(self, port, data_len): + self._statistics[port] += data_len + + def handle_periodic(self): + r = {} + i = 0 + + def send_data(data_dict): + if data_dict: + # use compact JSON format (without space) + data = common.to_bytes(json.dumps(data_dict, + separators=(',', ':'))) + self._send_control_data(b'stat: ' + data) + + for k, v in self._statistics.items(): + r[k] = v + i += 1 + # split the data into segments that fit in UDP packets + if i >= STAT_SEND_LIMIT: + send_data(r) + r.clear() + send_data(r) + self._statistics.clear() + + def _send_control_data(self, data): + if self._control_client_addr: + try: + self._control_socket.sendto(data, self._control_client_addr) + except (socket.error, OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + return + else: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + + def run(self): + self._loop.run() + + +def run(config): + Manager(config).run() + + +def test(): + import time + import threading + import struct + from shadowsocks import encrypt + + logging.basicConfig(level=5, + format='%(asctime)s %(levelname)-8s %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + enc = [] + eventloop.TIMEOUT_PRECISION = 1 + + def run_server(): + config = { + 'server': '127.0.0.1', + 'local_port': 1081, + 'port_password': { + '8381': 'foobar1', + '8382': 'foobar2' + }, + 'method': 'aes-256-cfb', + 'manager_address': '127.0.0.1:6001', + 'timeout': 60, + 'fast_open': False, + 'verbose': 2 + } + manager = Manager(config) + enc.append(manager) + manager.run() + + t = threading.Thread(target=run_server) + t.start() + time.sleep(1) + manager = enc[0] + cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + cli.connect(('127.0.0.1', 6001)) + + # test add and remove + time.sleep(1) + cli.send(b'add: {"server_port":7001, "password":"asdfadsfasdf"}') + time.sleep(1) + assert 7001 in manager._relays + data, addr = cli.recvfrom(1506) + assert b'ok' in data + + cli.send(b'remove: {"server_port":8381}') + time.sleep(1) + assert 8381 not in manager._relays + data, addr = cli.recvfrom(1506) + assert b'ok' in data + logging.info('add and remove test passed') + + # test statistics for TCP + header = common.pack_addr(b'google.com') + struct.pack('>H', 80) + data = encrypt.encrypt_all(b'asdfadsfasdf', 'aes-256-cfb', 1, + header + b'GET /\r\n\r\n') + tcp_cli = socket.socket() + tcp_cli.connect(('127.0.0.1', 7001)) + tcp_cli.send(data) + tcp_cli.recv(4096) + tcp_cli.close() + + data, addr = cli.recvfrom(1506) + data = common.to_str(data) + assert data.startswith('stat: ') + data = data.split('stat:')[1] + stats = shell.parse_json_in_str(data) + assert '7001' in stats + logging.info('TCP statistics test passed') + + # test statistics for UDP + header = common.pack_addr(b'127.0.0.1') + struct.pack('>H', 80) + data = encrypt.encrypt_all(b'foobar2', 'aes-256-cfb', 1, + header + b'test') + udp_cli = socket.socket(type=socket.SOCK_DGRAM) + udp_cli.sendto(data, ('127.0.0.1', 8382)) + tcp_cli.close() + + data, addr = cli.recvfrom(1506) + data = common.to_str(data) + assert data.startswith('stat: ') + data = data.split('stat:')[1] + stats = json.loads(data) + assert '8382' in stats + logging.info('UDP statistics test passed') + + manager._loop.stop() + t.join() + + +if __name__ == '__main__': + test() diff --git a/shadowsocks/server.py b/shadowsocks/server.py index d919092..68a6716 100755 --- a/shadowsocks/server.py +++ b/shadowsocks/server.py @@ -24,7 +24,8 @@ import logging import signal sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../')) -from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, asyncdns +from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, \ + asyncdns, manager def main(): @@ -48,10 +49,17 @@ def main(): else: config['port_password'][str(server_port)] = config['password'] + if config.get('manager_address', 0): + logging.info('entering manager mode') + manager.run(config) + return + tcp_servers = [] udp_servers = [] dns_resolver = asyncdns.DNSResolver() - for port, password in config['port_password'].items(): + port_password = config['port_password'] + del config['port_password'] + for port, password in port_password.items(): a_config = config.copy() ipv6_ok = False logging.info("server start with password [%s] method [%s]" % (password, a_config['method'])) diff --git a/shadowsocks/shell.py b/shadowsocks/shell.py index f8ae81f..c91fc22 100644 --- a/shadowsocks/shell.py +++ b/shadowsocks/shell.py @@ -136,7 +136,7 @@ def get_config(is_local): else: shortopts = 'hd:s:p:k:m:c:t:vq' longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'workers=', - 'forbidden-ip=', 'user=', 'version'] + 'forbidden-ip=', 'user=', 'manager-address=', 'version'] try: config_path = find_config() optlist, args = getopt.getopt(sys.argv[1:], shortopts, longopts) @@ -148,8 +148,7 @@ def get_config(is_local): logging.info('loading config from %s' % config_path) with open(config_path, 'rb') as f: try: - config = json.loads(f.read().decode('utf8'), - object_hook=_decode_dict) + config = parse_json_in_str(f.read().decode('utf8')) except ValueError as e: logging.error('found an error in config.json: %s', e.message) @@ -181,6 +180,8 @@ def get_config(is_local): config['fast_open'] = True elif key == '--workers': config['workers'] = int(value) + elif key == '--manager-address': + config['manager_address'] = value elif key == '--user': config['user'] = to_str(value) elif key == '--forbidden-ip': @@ -317,6 +318,7 @@ Proxy options: --fast-open use TCP_FASTOPEN, requires Linux 3.7+ --workers WORKERS number of workers, available on Unix/Linux --forbidden-ip IPLIST comma seperated IP list forbidden to connect + --manager-address ADDR optional server manager UDP address, see wiki General options: -h, --help show this help message and exit @@ -356,3 +358,8 @@ def _decode_dict(data): value = _decode_dict(value) rv[key] = value return rv + + +def parse_json_in_str(data): + # parse json and convert everything from unicode to str + return json.loads(data, object_hook=_decode_dict) diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 8188a00..8c53113 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -23,6 +23,7 @@ import socket import errno import struct import logging +import binascii import traceback import random @@ -32,9 +33,6 @@ from shadowsocks.common import pre_parse_header, parse_header # we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time TIMEOUTS_CLEAN_SIZE = 512 -# we check timeouts every TIMEOUT_PRECISION seconds -TIMEOUT_PRECISION = 4 - MSG_FASTOPEN = 0x20000000 # SOCKS command definition @@ -153,10 +151,10 @@ class TCPRelayHandler(object): logging.debug('chosen server: %s:%d', server, server_port) return server, server_port - def _update_activity(self): + def _update_activity(self, data_len=0): # tell the TCP Relay we have activities recently # else it will think we are inactive and timed out - self._server.update_activity(self) + self._server.update_activity(self, data_len) def _update_stream(self, stream, status): # update a stream to a new waiting status @@ -343,6 +341,8 @@ class TCPRelayHandler(object): logging.error('unknown command %d', cmd) self.destroy() return + if False and ord(data[0]) != 0x88: # force new header + raise Exception('can not parse header') data = pre_parse_header(data) if data is None: raise Exception('can not parse header') @@ -379,7 +379,6 @@ class TCPRelayHandler(object): self._log_error(e) if self._config['verbose']: traceback.print_exc() - # TODO use logging when debug completed self.destroy() def _create_remote_socket(self, ip, port): @@ -397,7 +396,6 @@ class TCPRelayHandler(object): common.to_str(sa[0])) remote_sock = socket.socket(af, socktype, proto) self._remote_sock = remote_sock - self._fd_to_handlers[remote_sock.fileno()] = self if self._remote_udp: @@ -410,7 +408,6 @@ class TCPRelayHandler(object): remote_sock_v6.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32) remote_sock_v6.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32) - remote_sock.setblocking(False) if self._remote_udp: pass @@ -483,7 +480,6 @@ class TCPRelayHandler(object): def _on_local_read(self): # handle all local read events and dispatch them to methods for # each stage - self._update_activity() if not self._local_sock: return is_local = self._is_local @@ -497,6 +493,7 @@ class TCPRelayHandler(object): if not data: self.destroy() return + self._update_activity(len(data)) if not is_local: data = self._encryptor.decrypt(data) if not data: @@ -520,7 +517,6 @@ class TCPRelayHandler(object): def _on_remote_read(self, is_remote_sock): # handle all remote read events - self._update_activity() data = None try: if self._remote_udp: @@ -547,6 +543,7 @@ class TCPRelayHandler(object): self.destroy() return self._server.server_transfer_dl += len(data) + self._update_activity(len(data)) if self._is_local: data = self._encryptor.decrypt(data) else: @@ -667,7 +664,7 @@ class TCPRelayHandler(object): class TCPRelay(object): - def __init__(self, config, dns_resolver, is_local): + def __init__(self, config, dns_resolver, is_local, stat_callback=None): self._config = config self._is_local = is_local self._dns_resolver = dns_resolver @@ -709,6 +706,7 @@ class TCPRelay(object): self._config['fast_open'] = False server_socket.listen(1024) self._server_socket = server_socket + self._stat_callback = stat_callback def add_to_loop(self, loop): if self._eventloop: @@ -727,7 +725,10 @@ class TCPRelay(object): self._timeouts[index] = None del self._handler_to_timeouts[hash(handler)] - def update_activity(self, handler): + def update_activity(self, handler, data_len): + if data_len and self._stat_callback: + self._stat_callback(self._listen_port, data_len) + # set handler to active now = int(time.time()) if now - handler.last_activity < eventloop.TIMEOUT_PRECISION: @@ -828,3 +829,5 @@ class TCPRelay(object): self._eventloop.remove_periodic(self.handle_periodic) self._eventloop.remove(self._server_socket) self._server_socket.close() + for handler in list(self._fd_to_handlers.values()): + handler.destroy() diff --git a/shadowsocks/udprelay.py b/shadowsocks/udprelay.py index 018a6a6..3ea9b6d 100644 --- a/shadowsocks/udprelay.py +++ b/shadowsocks/udprelay.py @@ -77,9 +77,6 @@ from shadowsocks.common import pre_parse_header, parse_header, pack_addr # we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time TIMEOUTS_CLEAN_SIZE = 512 -# we check timeouts every TIMEOUT_PRECISION seconds -TIMEOUT_PRECISION = 4 - # for each handler, we have 2 stream directions: # upstream: from client to server direction # read local and write to remote @@ -97,8 +94,9 @@ WAIT_STATUS_READWRITING = WAIT_STATUS_READING | WAIT_STATUS_WRITING BUF_SIZE = 65536 DOUBLE_SEND_BEG_IDS = 16 -POST_MTU_MIN = 1000 +POST_MTU_MIN = 500 POST_MTU_MAX = 1400 +SENDING_WINDOW_SIZE = 8192 STAGE_INIT = 0 STAGE_RSP_ID = 1 @@ -119,6 +117,14 @@ CMD_DISCONNECT = 8 CMD_VER_STR = "\x08" +RSP_STATE_EMPTY = "" +RSP_STATE_REJECT = "\x00" +RSP_STATE_CONNECTED = "\x01" +RSP_STATE_CONNECTEDREMOTE = "\x02" +RSP_STATE_ERROR = "\x03" +RSP_STATE_DISCONNECT = "\x04" +RSP_STATE_REDIRECT = "\x05" + class UDPLocalAddress(object): def __init__(self, addr): self.addr = addr @@ -173,9 +179,6 @@ class SendingQueue(object): while self.begin_id < begin_id: self.begin_id += 1 del self.queue[self.begin_id] - #while len(self.queue) > 0 and self.queue[0][0] <= begin_id: - # del self.queue[0] - # self.begin_id += 1 class RecvQueue(object): def __init__(self): @@ -229,6 +232,38 @@ class RecvQueue(object): missing.append(i - begin_id) return (begin_id, missing) +class AddressMap(object): + def __init__(self): + self._queue = [] + self._addr_map = {} + + def add(self, addr): + if addr in self._addr_map: + self._addr_map[addr] = UDPLocalAddress(addr) + else: + self._addr_map[addr] = UDPLocalAddress(addr) + self._queue.append(addr) + + def keys(self): + return self._queue + + def get(self): + if self._queue: + while True: + if len(self._queue) == 1: + return self._queue[0] + index = random.randint(0, len(self._queue) - 1) + addr = self._queue[index] + if self._addr_map[addr].is_timeout(): + self._queue[index] = self._queue[len(self._queue) - 1] + del self._queue[len(self._queue) - 1] + del self._addr_map[addr] + else: + break + return addr + else: + return None + class TCPRelayHandler(object): def __init__(self, server, reqid_to_handlers, fd_to_handlers, loop, local_sock, local_id, client_param, config, @@ -254,7 +289,7 @@ class TCPRelayHandler(object): self._upstream_status = WAIT_STATUS_READING self._downstream_status = WAIT_STATUS_INIT self._request_id = 0 - self._client_address = {} + self._client_address = AddressMap() self._remote_address = None self._sendingqueue = SendingQueue() self._recvqueue = RecvQueue() @@ -282,7 +317,10 @@ class TCPRelayHandler(object): return self._remote_address def add_local_address(self, addr): - self._client_address[addr] = UDPLocalAddress(addr) + self._client_address.add(addr) + + def get_local_address(self): + return self._client_address.get() def _update_activity(self): # tell the TCP Relay we have activities recently @@ -367,8 +405,6 @@ class TCPRelayHandler(object): return False if uncomplete: if sock == self._local_sock: - #if data is not None and retry < 10: - # self._data_to_write_to_local.append([(data, addr), retry]) self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) elif sock == self._remote_sock: self._data_to_write_to_remote.append(data) @@ -377,15 +413,12 @@ class TCPRelayHandler(object): logging.error('write_all_to_sock:unknown socket') else: if sock == self._local_sock: - if self._sendingqueue.size() > 8192: + if self._sendingqueue.size() > SENDING_WINDOW_SIZE: self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) else: self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) elif sock == self._remote_sock: - if self._sendingqueue.size() > 8192: - self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) - else: - self._update_stream(STREAM_UP, WAIT_STATUS_READING) + self._update_stream(STREAM_UP, WAIT_STATUS_READING) else: logging.error('write_all_to_sock:unknown socket') return True @@ -439,12 +472,10 @@ class TCPRelayHandler(object): self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) self._stage = STAGE_STREAM - for it_addr in self._client_address: - addr = it_addr - break + addr = self.get_local_address() for i in xrange(2): - rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, "\x02") + rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, RSP_STATE_CONNECTEDREMOTE) self._write_to_sock(rsp_data, self._local_sock, addr) return @@ -508,13 +539,11 @@ class TCPRelayHandler(object): pack_id = self._sendingqueue.append(data) post_data = self._pack_post_data(CMD_POST, pack_id, data) - for it_addr in self._client_address: - addr = it_addr - break + addr = self.get_local_address() self._write_to_sock(post_data, self._local_sock, addr) - #if pack_id <= DOUBLE_SEND_BEG_IDS: - # post_data = self._pack_post_data(CMD_POST, pack_id, data) - # self._write_to_sock(post_data, self._local_sock, addr) + if pack_id <= DOUBLE_SEND_BEG_IDS: + post_data = self._pack_post_data(CMD_POST, pack_id, data) + self._write_to_sock(post_data, self._local_sock, addr) except Exception as e: shell.print_exception(e) @@ -620,14 +649,14 @@ class TCPRelayHandler(object): for post_pack_id, post_data in send_list: rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data) self._write_to_sock(rsp_data, self._local_sock, addr) - #if post_pack_id <= DOUBLE_SEND_BEG_IDS: - # rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data) - # self._write_to_sock(rsp_data, self._local_sock, addr) + if post_pack_id <= DOUBLE_SEND_BEG_IDS: + rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data) + self._write_to_sock(rsp_data, self._local_sock, addr) def handle_client(self, addr, cmd, request_id, data): self.add_local_address(addr) if cmd == CMD_DISCONNECT: - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) self._write_to_sock(rsp_data, self._local_sock, addr) self.destroy() self.destroy_local() @@ -643,7 +672,7 @@ class TCPRelayHandler(object): if self._stage == STAGE_RSP_ID: if cmd == CMD_CONNECT: for i in xrange(2): - rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, "\x01") + rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, RSP_STATE_CONNECTED) self._write_to_sock(rsp_data, self._local_sock, addr) elif cmd == CMD_CONNECT_REMOTE: local_id = data[0:4] @@ -660,35 +689,35 @@ class TCPRelayHandler(object): logging.info('TCP connect %s:%d from %s:%d' % (remote_addr, remote_port, addr[0], addr[1])) else: # ileagal request - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) self._write_to_sock(rsp_data, self._local_sock, addr) elif self._stage == STAGE_CONNECTING: if cmd == CMD_CONNECT_REMOTE: local_id = data[0:4] if self._local_id == local_id: for i in xrange(2): - rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, "\x02") + rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, RSP_STATE_CONNECTEDREMOTE) self._write_to_sock(rsp_data, self._local_sock, addr) else: # ileagal request - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) self._write_to_sock(rsp_data, self._local_sock, addr) elif self._stage == STAGE_STREAM: if len(data) < 4: # ileagal request - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) self._write_to_sock(rsp_data, self._local_sock, addr) return local_id = data[0:4] if self._local_id != local_id: # ileagal request - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) self._write_to_sock(rsp_data, self._local_sock, addr) return else: data = data[4:] if cmd == CMD_CONNECT_REMOTE: - rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, "\x02") + rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, RSP_STATE_CONNECTEDREMOTE) self._write_to_sock(rsp_data, self._local_sock, addr) elif cmd == CMD_POST: recv_id = struct.unpack(">I", data[0:4])[0] @@ -701,7 +730,7 @@ class TCPRelayHandler(object): self._recvqueue.insert(pack_id, data[16:]) self._sendingqueue.set_finish(recv_id, []) elif cmd == CMD_DISCONNECT: - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) self._write_to_sock(rsp_data, self._local_sock, addr) self.destroy() self.destroy_local() @@ -723,7 +752,7 @@ class TCPRelayHandler(object): local_id = data[0:4] if self._local_id != local_id: # ileagal request - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) self._write_to_sock(rsp_data, self._local_sock, addr) return else: @@ -732,13 +761,11 @@ class TCPRelayHandler(object): pack_id = struct.unpack(">I", data[0:4])[0] max_send_id = struct.unpack(">I", data[4:8])[0] data = data[8:] - logging.info('handle_client STAGE_DESTROYED send %d %d' % (request_id, pack_id)) self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) elif cmd == CMD_SYN_STATUS_64: pack_id = struct.unpack(">Q", data[0:8])[0] max_send_id = struct.unpack(">Q", data[8:16])[0] data = data[16:] - logging.info('handle_client STAGE_DESTROYED send %d %d' % (request_id, pack_id)) self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) def handle_event(self, sock, event): @@ -808,11 +835,9 @@ class TCPRelayHandler(object): def destroy_local(self): if self._local_sock: logging.debug('disconnect local') - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) addr = None - for it_addr in self._client_address: - addr = it_addr - break + addr = self.get_local_address() self._write_to_sock(rsp_data, self._local_sock, addr) self._local_sock = None del self._reqid_to_handlers[self._request_id] @@ -822,8 +847,9 @@ def client_key(source_addr, server_af): # notice this is server af, not dest af return '%s:%s:%d' % (source_addr[0], source_addr[1], server_af) + class UDPRelay(object): - def __init__(self, config, dns_resolver, is_local): + def __init__(self, config, dns_resolver, is_local, stat_callback=None): self._config = config if is_local: self._listen_addr = config['local_address'] @@ -836,7 +862,7 @@ class UDPRelay(object): self._remote_addr = None self._remote_port = None self._dns_resolver = dns_resolver - self._password = config['password'] + self._password = common.to_bytes(config['password']) self._method = config['method'] self._timeout = config['timeout'] self._is_local = is_local @@ -877,6 +903,7 @@ class UDPRelay(object): server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32) self._server_socket = server_socket + self._stat_callback = stat_callback def _get_a_server(self): server = self._config['server'] @@ -937,6 +964,8 @@ class UDPRelay(object): data, r_addr = server.recvfrom(BUF_SIZE) if not data: logging.debug('UDP handle_server: data is empty') + if self._stat_callback: + self._stat_callback(self._listen_port, len(data)) if self._is_local: frag = common.ord(data[2]) if frag != 0: @@ -976,7 +1005,7 @@ class UDPRelay(object): break # return req id self._reqid_to_hd[req_id] = (data[2][0:4], None) - rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, req_id, "\x01") + rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, req_id, RSP_STATE_CONNECTED) data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) self.write_to_server_socket(data_to_send, r_addr) elif data[0] == CMD_CONNECT_REMOTE: @@ -994,7 +1023,7 @@ class UDPRelay(object): self.update_activity(handle) else: # disconnect - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], "") + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], RSP_STATE_EMPTY) data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) self.write_to_server_socket(data_to_send, r_addr) else: @@ -1002,16 +1031,19 @@ class UDPRelay(object): self._reqid_to_hd[data[1]].handle_client(r_addr, *data) else: # disconnect - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], "") + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], RSP_STATE_EMPTY) data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) self.write_to_server_socket(data_to_send, r_addr) elif data[0] > CMD_CONNECT_REMOTE and data[0] <= CMD_DISCONNECT: if data[1] in self._reqid_to_hd: - self.update_activity(self._reqid_to_hd[data[1]]) - self._reqid_to_hd[data[1]].handle_client(r_addr, *data) + if type(self._reqid_to_hd[data[1]]) is tuple: + pass + else: + self.update_activity(self._reqid_to_hd[data[1]]) + self._reqid_to_hd[data[1]].handle_client(r_addr, *data) else: # disconnect - rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], "") + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], RSP_STATE_EMPTY) data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) self.write_to_server_socket(data_to_send, r_addr) return @@ -1042,7 +1074,6 @@ class UDPRelay(object): af, socktype, proto, canonname, sa = addrs[0] key = client_key(r_addr, af) - logging.debug(key) client = self._cache.get(key, None) if not client: # TODO async getaddrinfo @@ -1083,6 +1114,8 @@ class UDPRelay(object): if not data: logging.debug('UDP handle_client: data is empty') return + if self._stat_callback: + self._stat_callback(self._listen_port, len(data)) if not self._is_local: addrlen = len(r_addr[0]) if addrlen > 255: @@ -1101,7 +1134,7 @@ class UDPRelay(object): header_result = parse_header(data) if header_result is None: return - connecttype, dest_addr, dest_port, header_length = header_result + #connecttype, dest_addr, dest_port, header_length = header_result #logging.debug('UDP handle_client %s:%d to %s:%d' % (common.to_str(r_addr[0]), r_addr[1], dest_addr, dest_port)) response = b'\x00\x00\x00' + data @@ -1250,3 +1283,5 @@ class UDPRelay(object): self._eventloop.remove_periodic(self.handle_periodic) self._eventloop.remove(self._server_socket) self._server_socket.close() + for client in list(self._cache.values()): + client.close()