7 changed files with 438 additions and 75 deletions
			
			
		@ -0,0 +1,286 @@ | 
				
			|||
#!/usr/bin/python | 
				
			|||
# -*- coding: utf-8 -*- | 
				
			|||
# | 
				
			|||
# Copyright 2015 clowwindy | 
				
			|||
# | 
				
			|||
# Licensed under the Apache License, Version 2.0 (the "License"); you may | 
				
			|||
# not use this file except in compliance with the License. You may obtain | 
				
			|||
# a copy of the License at | 
				
			|||
# | 
				
			|||
#     http://www.apache.org/licenses/LICENSE-2.0 | 
				
			|||
# | 
				
			|||
# Unless required by applicable law or agreed to in writing, software | 
				
			|||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | 
				
			|||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | 
				
			|||
# License for the specific language governing permissions and limitations | 
				
			|||
# under the License. | 
				
			|||
 | 
				
			|||
from __future__ import absolute_import, division, print_function, \ | 
				
			|||
    with_statement | 
				
			|||
 | 
				
			|||
import errno | 
				
			|||
import traceback | 
				
			|||
import socket | 
				
			|||
import logging | 
				
			|||
import json | 
				
			|||
import collections | 
				
			|||
 | 
				
			|||
from shadowsocks import common, eventloop, tcprelay, udprelay, asyncdns, shell | 
				
			|||
 | 
				
			|||
 | 
				
			|||
BUF_SIZE = 1506 | 
				
			|||
STAT_SEND_LIMIT = 100 | 
				
			|||
 | 
				
			|||
 | 
				
			|||
class Manager(object): | 
				
			|||
 | 
				
			|||
    def __init__(self, config): | 
				
			|||
        self._config = config | 
				
			|||
        self._relays = {}  # (tcprelay, udprelay) | 
				
			|||
        self._loop = eventloop.EventLoop() | 
				
			|||
        self._dns_resolver = asyncdns.DNSResolver() | 
				
			|||
        self._dns_resolver.add_to_loop(self._loop) | 
				
			|||
 | 
				
			|||
        self._statistics = collections.defaultdict(int) | 
				
			|||
        self._control_client_addr = None | 
				
			|||
        try: | 
				
			|||
            manager_address = config['manager_address'] | 
				
			|||
            if ':' in manager_address: | 
				
			|||
                addr = manager_address.rsplit(':', 1) | 
				
			|||
                addr = addr[0], int(addr[1]) | 
				
			|||
                addrs = socket.getaddrinfo(addr[0], addr[1]) | 
				
			|||
                if addrs: | 
				
			|||
                    family = addrs[0][0] | 
				
			|||
                else: | 
				
			|||
                    logging.error('invalid address: %s', manager_address) | 
				
			|||
                    exit(1) | 
				
			|||
            else: | 
				
			|||
                addr = manager_address | 
				
			|||
                family = socket.AF_UNIX | 
				
			|||
            self._control_socket = socket.socket(family, | 
				
			|||
                                                 socket.SOCK_DGRAM) | 
				
			|||
            self._control_socket.bind(addr) | 
				
			|||
            self._control_socket.setblocking(False) | 
				
			|||
        except (OSError, IOError) as e: | 
				
			|||
            logging.error(e) | 
				
			|||
            logging.error('can not bind to manager address') | 
				
			|||
            exit(1) | 
				
			|||
        self._loop.add(self._control_socket, | 
				
			|||
                       eventloop.POLL_IN, self) | 
				
			|||
        self._loop.add_periodic(self.handle_periodic) | 
				
			|||
 | 
				
			|||
        port_password = config['port_password'] | 
				
			|||
        del config['port_password'] | 
				
			|||
        for port, password in port_password.items(): | 
				
			|||
            a_config = config.copy() | 
				
			|||
            a_config['server_port'] = int(port) | 
				
			|||
            a_config['password'] = password | 
				
			|||
            self.add_port(a_config) | 
				
			|||
 | 
				
			|||
    def add_port(self, config): | 
				
			|||
        port = int(config['server_port']) | 
				
			|||
        servers = self._relays.get(port, None) | 
				
			|||
        if servers: | 
				
			|||
            logging.error("server already exists at %s:%d" % (config['server'], | 
				
			|||
                                                              port)) | 
				
			|||
            return | 
				
			|||
        logging.info("adding server at %s:%d" % (config['server'], port)) | 
				
			|||
        t = tcprelay.TCPRelay(config, self._dns_resolver, False, | 
				
			|||
                              self.stat_callback) | 
				
			|||
        u = udprelay.UDPRelay(config, self._dns_resolver, False, | 
				
			|||
                              self.stat_callback) | 
				
			|||
        t.add_to_loop(self._loop) | 
				
			|||
        u.add_to_loop(self._loop) | 
				
			|||
        self._relays[port] = (t, u) | 
				
			|||
 | 
				
			|||
    def remove_port(self, config): | 
				
			|||
        port = int(config['server_port']) | 
				
			|||
        servers = self._relays.get(port, None) | 
				
			|||
        if servers: | 
				
			|||
            logging.info("removing server at %s:%d" % (config['server'], port)) | 
				
			|||
            t, u = servers | 
				
			|||
            t.close(next_tick=False) | 
				
			|||
            u.close(next_tick=False) | 
				
			|||
            del self._relays[port] | 
				
			|||
        else: | 
				
			|||
            logging.error("server not exist at %s:%d" % (config['server'], | 
				
			|||
                                                         port)) | 
				
			|||
 | 
				
			|||
    def handle_event(self, sock, fd, event): | 
				
			|||
        if sock == self._control_socket and event == eventloop.POLL_IN: | 
				
			|||
            data, self._control_client_addr = sock.recvfrom(BUF_SIZE) | 
				
			|||
            parsed = self._parse_command(data) | 
				
			|||
            if parsed: | 
				
			|||
                command, config = parsed | 
				
			|||
                a_config = self._config.copy() | 
				
			|||
                if config: | 
				
			|||
                    # let the command override the configuration file | 
				
			|||
                    a_config.update(config) | 
				
			|||
                if 'server_port' not in a_config: | 
				
			|||
                    logging.error('can not find server_port in config') | 
				
			|||
                else: | 
				
			|||
                    if command == 'add': | 
				
			|||
                        self.add_port(a_config) | 
				
			|||
                        self._send_control_data(b'ok') | 
				
			|||
                    elif command == 'remove': | 
				
			|||
                        self.remove_port(a_config) | 
				
			|||
                        self._send_control_data(b'ok') | 
				
			|||
                    elif command == 'ping': | 
				
			|||
                        self._send_control_data(b'pong') | 
				
			|||
                    else: | 
				
			|||
                        logging.error('unknown command %s', command) | 
				
			|||
 | 
				
			|||
    def _parse_command(self, data): | 
				
			|||
        # commands: | 
				
			|||
        # add: {"server_port": 8000, "password": "foobar"} | 
				
			|||
        # remove: {"server_port": 8000"} | 
				
			|||
        data = common.to_str(data) | 
				
			|||
        parts = data.split(':', 1) | 
				
			|||
        if len(parts) < 2: | 
				
			|||
            return data, None | 
				
			|||
        command, config_json = parts | 
				
			|||
        try: | 
				
			|||
            config = shell.parse_json_in_str(config_json) | 
				
			|||
            return command, config | 
				
			|||
        except Exception as e: | 
				
			|||
            logging.error(e) | 
				
			|||
            return None | 
				
			|||
 | 
				
			|||
    def stat_callback(self, port, data_len): | 
				
			|||
        self._statistics[port] += data_len | 
				
			|||
 | 
				
			|||
    def handle_periodic(self): | 
				
			|||
        r = {} | 
				
			|||
        i = 0 | 
				
			|||
 | 
				
			|||
        def send_data(data_dict): | 
				
			|||
            if data_dict: | 
				
			|||
                # use compact JSON format (without space) | 
				
			|||
                data = common.to_bytes(json.dumps(data_dict, | 
				
			|||
                                                  separators=(',', ':'))) | 
				
			|||
                self._send_control_data(b'stat: ' + data) | 
				
			|||
 | 
				
			|||
        for k, v in self._statistics.items(): | 
				
			|||
            r[k] = v | 
				
			|||
            i += 1 | 
				
			|||
            # split the data into segments that fit in UDP packets | 
				
			|||
            if i >= STAT_SEND_LIMIT: | 
				
			|||
                send_data(r) | 
				
			|||
                r.clear() | 
				
			|||
        send_data(r) | 
				
			|||
        self._statistics.clear() | 
				
			|||
 | 
				
			|||
    def _send_control_data(self, data): | 
				
			|||
        if self._control_client_addr: | 
				
			|||
            try: | 
				
			|||
                self._control_socket.sendto(data, self._control_client_addr) | 
				
			|||
            except (socket.error, OSError, IOError) as e: | 
				
			|||
                error_no = eventloop.errno_from_exception(e) | 
				
			|||
                if error_no in (errno.EAGAIN, errno.EINPROGRESS, | 
				
			|||
                                errno.EWOULDBLOCK): | 
				
			|||
                    return | 
				
			|||
                else: | 
				
			|||
                    shell.print_exception(e) | 
				
			|||
                    if self._config['verbose']: | 
				
			|||
                        traceback.print_exc() | 
				
			|||
 | 
				
			|||
    def run(self): | 
				
			|||
        self._loop.run() | 
				
			|||
 | 
				
			|||
 | 
				
			|||
def run(config): | 
				
			|||
    Manager(config).run() | 
				
			|||
 | 
				
			|||
 | 
				
			|||
def test(): | 
				
			|||
    import time | 
				
			|||
    import threading | 
				
			|||
    import struct | 
				
			|||
    from shadowsocks import encrypt | 
				
			|||
 | 
				
			|||
    logging.basicConfig(level=5, | 
				
			|||
                        format='%(asctime)s %(levelname)-8s %(message)s', | 
				
			|||
                        datefmt='%Y-%m-%d %H:%M:%S') | 
				
			|||
    enc = [] | 
				
			|||
    eventloop.TIMEOUT_PRECISION = 1 | 
				
			|||
 | 
				
			|||
    def run_server(): | 
				
			|||
        config = { | 
				
			|||
            'server': '127.0.0.1', | 
				
			|||
            'local_port': 1081, | 
				
			|||
            'port_password': { | 
				
			|||
                '8381': 'foobar1', | 
				
			|||
                '8382': 'foobar2' | 
				
			|||
            }, | 
				
			|||
            'method': 'aes-256-cfb', | 
				
			|||
            'manager_address': '127.0.0.1:6001', | 
				
			|||
            'timeout': 60, | 
				
			|||
            'fast_open': False, | 
				
			|||
            'verbose': 2 | 
				
			|||
        } | 
				
			|||
        manager = Manager(config) | 
				
			|||
        enc.append(manager) | 
				
			|||
        manager.run() | 
				
			|||
 | 
				
			|||
    t = threading.Thread(target=run_server) | 
				
			|||
    t.start() | 
				
			|||
    time.sleep(1) | 
				
			|||
    manager = enc[0] | 
				
			|||
    cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | 
				
			|||
    cli.connect(('127.0.0.1', 6001)) | 
				
			|||
 | 
				
			|||
    # test add and remove | 
				
			|||
    time.sleep(1) | 
				
			|||
    cli.send(b'add: {"server_port":7001, "password":"asdfadsfasdf"}') | 
				
			|||
    time.sleep(1) | 
				
			|||
    assert 7001 in manager._relays | 
				
			|||
    data, addr = cli.recvfrom(1506) | 
				
			|||
    assert b'ok' in data | 
				
			|||
 | 
				
			|||
    cli.send(b'remove: {"server_port":8381}') | 
				
			|||
    time.sleep(1) | 
				
			|||
    assert 8381 not in manager._relays | 
				
			|||
    data, addr = cli.recvfrom(1506) | 
				
			|||
    assert b'ok' in data | 
				
			|||
    logging.info('add and remove test passed') | 
				
			|||
 | 
				
			|||
    # test statistics for TCP | 
				
			|||
    header = common.pack_addr(b'google.com') + struct.pack('>H', 80) | 
				
			|||
    data = encrypt.encrypt_all(b'asdfadsfasdf', 'aes-256-cfb', 1, | 
				
			|||
                               header + b'GET /\r\n\r\n') | 
				
			|||
    tcp_cli = socket.socket() | 
				
			|||
    tcp_cli.connect(('127.0.0.1', 7001)) | 
				
			|||
    tcp_cli.send(data) | 
				
			|||
    tcp_cli.recv(4096) | 
				
			|||
    tcp_cli.close() | 
				
			|||
 | 
				
			|||
    data, addr = cli.recvfrom(1506) | 
				
			|||
    data = common.to_str(data) | 
				
			|||
    assert data.startswith('stat: ') | 
				
			|||
    data = data.split('stat:')[1] | 
				
			|||
    stats = shell.parse_json_in_str(data) | 
				
			|||
    assert '7001' in stats | 
				
			|||
    logging.info('TCP statistics test passed') | 
				
			|||
 | 
				
			|||
    # test statistics for UDP | 
				
			|||
    header = common.pack_addr(b'127.0.0.1') + struct.pack('>H', 80) | 
				
			|||
    data = encrypt.encrypt_all(b'foobar2', 'aes-256-cfb', 1, | 
				
			|||
                               header + b'test') | 
				
			|||
    udp_cli = socket.socket(type=socket.SOCK_DGRAM) | 
				
			|||
    udp_cli.sendto(data, ('127.0.0.1', 8382)) | 
				
			|||
    tcp_cli.close() | 
				
			|||
 | 
				
			|||
    data, addr = cli.recvfrom(1506) | 
				
			|||
    data = common.to_str(data) | 
				
			|||
    assert data.startswith('stat: ') | 
				
			|||
    data = data.split('stat:')[1] | 
				
			|||
    stats = json.loads(data) | 
				
			|||
    assert '8382' in stats | 
				
			|||
    logging.info('UDP statistics test passed') | 
				
			|||
 | 
				
			|||
    manager._loop.stop() | 
				
			|||
    t.join() | 
				
			|||
 | 
				
			|||
 | 
				
			|||
if __name__ == '__main__': | 
				
			|||
    test() | 
				
			|||
					Loading…
					
					
				
		Reference in new issue