Browse Source

merge manyuser branch

master
breakwa11 9 years ago
parent
commit
5b6dec5940
  1. 21
      shadowsocks/common.py
  2. 9
      shadowsocks/eventloop.py
  3. 286
      shadowsocks/manager.py
  4. 12
      shadowsocks/server.py
  5. 13
      shadowsocks/shell.py
  6. 27
      shadowsocks/tcprelay.py
  7. 145
      shadowsocks/udprelay.py

21
shadowsocks/common.py

@ -21,7 +21,7 @@ from __future__ import absolute_import, division, print_function, \
import socket import socket
import struct import struct
import logging import logging
import binascii
def compat_ord(s): def compat_ord(s):
if type(s) == int: if type(s) == int:
@ -140,7 +140,7 @@ def pack_addr(address):
def pre_parse_header(data): def pre_parse_header(data):
datatype = ord(data[0]) datatype = ord(data[0])
if datatype == 0x80 : if datatype == 0x80:
if len(data) <= 2: if len(data) <= 2:
return None return None
rand_data_size = ord(data[1]) rand_data_size = ord(data[1])
@ -151,7 +151,7 @@ def pre_parse_header(data):
data = data[rand_data_size + 2:] data = data[rand_data_size + 2:]
elif datatype == 0x81: elif datatype == 0x81:
data = data[1:] data = data[1:]
elif datatype == 0x82 : elif datatype == 0x82:
if len(data) <= 3: if len(data) <= 3:
return None return None
rand_data_size = struct.unpack('>H', data[1:3])[0] rand_data_size = struct.unpack('>H', data[1:3])[0]
@ -160,6 +160,21 @@ def pre_parse_header(data):
'encryption method') 'encryption method')
return None return None
data = data[rand_data_size + 3:] data = data[rand_data_size + 3:]
elif datatype == 0x88:
if len(data) <= 7 + 7:
return None
data_size = struct.unpack('>H', data[1:3])[0]
ogn_data = data
data = data[:data_size]
crc = binascii.crc32(data) & 0xffffffff
if crc != 0xffffffff:
logging.warn('uncorrect CRC32, maybe wrong password or '
'encryption method')
return None
start_pos = 3 + ord(data[3])
data = data[start_pos:-4]
if data_size < len(ogn_data):
data += ogn_data[data_size:]
return data return data
def parse_header(data): def parse_header(data):

9
shadowsocks/eventloop.py

@ -98,6 +98,9 @@ class KqueueLoop(object):
self.unregister(fd) self.unregister(fd)
self.register(fd, mode) self.register(fd, mode)
def close(self):
self.kqueue.close()
class SelectLoop(object): class SelectLoop(object):
@ -135,6 +138,9 @@ class SelectLoop(object):
self.unregister(fd) self.unregister(fd)
self.register(fd, mode) self.register(fd, mode)
def close(self):
pass
class EventLoop(object): class EventLoop(object):
def __init__(self): def __init__(self):
@ -216,6 +222,9 @@ class EventLoop(object):
callback() callback()
self._last_time = now self._last_time = now
def __del__(self):
self._impl.close()
# from tornado # from tornado
def errno_from_exception(e): def errno_from_exception(e):

286
shadowsocks/manager.py

@ -0,0 +1,286 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright 2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, \
with_statement
import errno
import traceback
import socket
import logging
import json
import collections
from shadowsocks import common, eventloop, tcprelay, udprelay, asyncdns, shell
BUF_SIZE = 1506
STAT_SEND_LIMIT = 100
class Manager(object):
def __init__(self, config):
self._config = config
self._relays = {} # (tcprelay, udprelay)
self._loop = eventloop.EventLoop()
self._dns_resolver = asyncdns.DNSResolver()
self._dns_resolver.add_to_loop(self._loop)
self._statistics = collections.defaultdict(int)
self._control_client_addr = None
try:
manager_address = config['manager_address']
if ':' in manager_address:
addr = manager_address.rsplit(':', 1)
addr = addr[0], int(addr[1])
addrs = socket.getaddrinfo(addr[0], addr[1])
if addrs:
family = addrs[0][0]
else:
logging.error('invalid address: %s', manager_address)
exit(1)
else:
addr = manager_address
family = socket.AF_UNIX
self._control_socket = socket.socket(family,
socket.SOCK_DGRAM)
self._control_socket.bind(addr)
self._control_socket.setblocking(False)
except (OSError, IOError) as e:
logging.error(e)
logging.error('can not bind to manager address')
exit(1)
self._loop.add(self._control_socket,
eventloop.POLL_IN, self)
self._loop.add_periodic(self.handle_periodic)
port_password = config['port_password']
del config['port_password']
for port, password in port_password.items():
a_config = config.copy()
a_config['server_port'] = int(port)
a_config['password'] = password
self.add_port(a_config)
def add_port(self, config):
port = int(config['server_port'])
servers = self._relays.get(port, None)
if servers:
logging.error("server already exists at %s:%d" % (config['server'],
port))
return
logging.info("adding server at %s:%d" % (config['server'], port))
t = tcprelay.TCPRelay(config, self._dns_resolver, False,
self.stat_callback)
u = udprelay.UDPRelay(config, self._dns_resolver, False,
self.stat_callback)
t.add_to_loop(self._loop)
u.add_to_loop(self._loop)
self._relays[port] = (t, u)
def remove_port(self, config):
port = int(config['server_port'])
servers = self._relays.get(port, None)
if servers:
logging.info("removing server at %s:%d" % (config['server'], port))
t, u = servers
t.close(next_tick=False)
u.close(next_tick=False)
del self._relays[port]
else:
logging.error("server not exist at %s:%d" % (config['server'],
port))
def handle_event(self, sock, fd, event):
if sock == self._control_socket and event == eventloop.POLL_IN:
data, self._control_client_addr = sock.recvfrom(BUF_SIZE)
parsed = self._parse_command(data)
if parsed:
command, config = parsed
a_config = self._config.copy()
if config:
# let the command override the configuration file
a_config.update(config)
if 'server_port' not in a_config:
logging.error('can not find server_port in config')
else:
if command == 'add':
self.add_port(a_config)
self._send_control_data(b'ok')
elif command == 'remove':
self.remove_port(a_config)
self._send_control_data(b'ok')
elif command == 'ping':
self._send_control_data(b'pong')
else:
logging.error('unknown command %s', command)
def _parse_command(self, data):
# commands:
# add: {"server_port": 8000, "password": "foobar"}
# remove: {"server_port": 8000"}
data = common.to_str(data)
parts = data.split(':', 1)
if len(parts) < 2:
return data, None
command, config_json = parts
try:
config = shell.parse_json_in_str(config_json)
return command, config
except Exception as e:
logging.error(e)
return None
def stat_callback(self, port, data_len):
self._statistics[port] += data_len
def handle_periodic(self):
r = {}
i = 0
def send_data(data_dict):
if data_dict:
# use compact JSON format (without space)
data = common.to_bytes(json.dumps(data_dict,
separators=(',', ':')))
self._send_control_data(b'stat: ' + data)
for k, v in self._statistics.items():
r[k] = v
i += 1
# split the data into segments that fit in UDP packets
if i >= STAT_SEND_LIMIT:
send_data(r)
r.clear()
send_data(r)
self._statistics.clear()
def _send_control_data(self, data):
if self._control_client_addr:
try:
self._control_socket.sendto(data, self._control_client_addr)
except (socket.error, OSError, IOError) as e:
error_no = eventloop.errno_from_exception(e)
if error_no in (errno.EAGAIN, errno.EINPROGRESS,
errno.EWOULDBLOCK):
return
else:
shell.print_exception(e)
if self._config['verbose']:
traceback.print_exc()
def run(self):
self._loop.run()
def run(config):
Manager(config).run()
def test():
import time
import threading
import struct
from shadowsocks import encrypt
logging.basicConfig(level=5,
format='%(asctime)s %(levelname)-8s %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
enc = []
eventloop.TIMEOUT_PRECISION = 1
def run_server():
config = {
'server': '127.0.0.1',
'local_port': 1081,
'port_password': {
'8381': 'foobar1',
'8382': 'foobar2'
},
'method': 'aes-256-cfb',
'manager_address': '127.0.0.1:6001',
'timeout': 60,
'fast_open': False,
'verbose': 2
}
manager = Manager(config)
enc.append(manager)
manager.run()
t = threading.Thread(target=run_server)
t.start()
time.sleep(1)
manager = enc[0]
cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
cli.connect(('127.0.0.1', 6001))
# test add and remove
time.sleep(1)
cli.send(b'add: {"server_port":7001, "password":"asdfadsfasdf"}')
time.sleep(1)
assert 7001 in manager._relays
data, addr = cli.recvfrom(1506)
assert b'ok' in data
cli.send(b'remove: {"server_port":8381}')
time.sleep(1)
assert 8381 not in manager._relays
data, addr = cli.recvfrom(1506)
assert b'ok' in data
logging.info('add and remove test passed')
# test statistics for TCP
header = common.pack_addr(b'google.com') + struct.pack('>H', 80)
data = encrypt.encrypt_all(b'asdfadsfasdf', 'aes-256-cfb', 1,
header + b'GET /\r\n\r\n')
tcp_cli = socket.socket()
tcp_cli.connect(('127.0.0.1', 7001))
tcp_cli.send(data)
tcp_cli.recv(4096)
tcp_cli.close()
data, addr = cli.recvfrom(1506)
data = common.to_str(data)
assert data.startswith('stat: ')
data = data.split('stat:')[1]
stats = shell.parse_json_in_str(data)
assert '7001' in stats
logging.info('TCP statistics test passed')
# test statistics for UDP
header = common.pack_addr(b'127.0.0.1') + struct.pack('>H', 80)
data = encrypt.encrypt_all(b'foobar2', 'aes-256-cfb', 1,
header + b'test')
udp_cli = socket.socket(type=socket.SOCK_DGRAM)
udp_cli.sendto(data, ('127.0.0.1', 8382))
tcp_cli.close()
data, addr = cli.recvfrom(1506)
data = common.to_str(data)
assert data.startswith('stat: ')
data = data.split('stat:')[1]
stats = json.loads(data)
assert '8382' in stats
logging.info('UDP statistics test passed')
manager._loop.stop()
t.join()
if __name__ == '__main__':
test()

12
shadowsocks/server.py

@ -24,7 +24,8 @@ import logging
import signal import signal
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../'))
from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, asyncdns from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, \
asyncdns, manager
def main(): def main():
@ -48,10 +49,17 @@ def main():
else: else:
config['port_password'][str(server_port)] = config['password'] config['port_password'][str(server_port)] = config['password']
if config.get('manager_address', 0):
logging.info('entering manager mode')
manager.run(config)
return
tcp_servers = [] tcp_servers = []
udp_servers = [] udp_servers = []
dns_resolver = asyncdns.DNSResolver() dns_resolver = asyncdns.DNSResolver()
for port, password in config['port_password'].items(): port_password = config['port_password']
del config['port_password']
for port, password in port_password.items():
a_config = config.copy() a_config = config.copy()
ipv6_ok = False ipv6_ok = False
logging.info("server start with password [%s] method [%s]" % (password, a_config['method'])) logging.info("server start with password [%s] method [%s]" % (password, a_config['method']))

13
shadowsocks/shell.py

@ -136,7 +136,7 @@ def get_config(is_local):
else: else:
shortopts = 'hd:s:p:k:m:c:t:vq' shortopts = 'hd:s:p:k:m:c:t:vq'
longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'workers=', longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'workers=',
'forbidden-ip=', 'user=', 'version'] 'forbidden-ip=', 'user=', 'manager-address=', 'version']
try: try:
config_path = find_config() config_path = find_config()
optlist, args = getopt.getopt(sys.argv[1:], shortopts, longopts) optlist, args = getopt.getopt(sys.argv[1:], shortopts, longopts)
@ -148,8 +148,7 @@ def get_config(is_local):
logging.info('loading config from %s' % config_path) logging.info('loading config from %s' % config_path)
with open(config_path, 'rb') as f: with open(config_path, 'rb') as f:
try: try:
config = json.loads(f.read().decode('utf8'), config = parse_json_in_str(f.read().decode('utf8'))
object_hook=_decode_dict)
except ValueError as e: except ValueError as e:
logging.error('found an error in config.json: %s', logging.error('found an error in config.json: %s',
e.message) e.message)
@ -181,6 +180,8 @@ def get_config(is_local):
config['fast_open'] = True config['fast_open'] = True
elif key == '--workers': elif key == '--workers':
config['workers'] = int(value) config['workers'] = int(value)
elif key == '--manager-address':
config['manager_address'] = value
elif key == '--user': elif key == '--user':
config['user'] = to_str(value) config['user'] = to_str(value)
elif key == '--forbidden-ip': elif key == '--forbidden-ip':
@ -317,6 +318,7 @@ Proxy options:
--fast-open use TCP_FASTOPEN, requires Linux 3.7+ --fast-open use TCP_FASTOPEN, requires Linux 3.7+
--workers WORKERS number of workers, available on Unix/Linux --workers WORKERS number of workers, available on Unix/Linux
--forbidden-ip IPLIST comma seperated IP list forbidden to connect --forbidden-ip IPLIST comma seperated IP list forbidden to connect
--manager-address ADDR optional server manager UDP address, see wiki
General options: General options:
-h, --help show this help message and exit -h, --help show this help message and exit
@ -356,3 +358,8 @@ def _decode_dict(data):
value = _decode_dict(value) value = _decode_dict(value)
rv[key] = value rv[key] = value
return rv return rv
def parse_json_in_str(data):
# parse json and convert everything from unicode to str
return json.loads(data, object_hook=_decode_dict)

27
shadowsocks/tcprelay.py

@ -23,6 +23,7 @@ import socket
import errno import errno
import struct import struct
import logging import logging
import binascii
import traceback import traceback
import random import random
@ -32,9 +33,6 @@ from shadowsocks.common import pre_parse_header, parse_header
# we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time # we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time
TIMEOUTS_CLEAN_SIZE = 512 TIMEOUTS_CLEAN_SIZE = 512
# we check timeouts every TIMEOUT_PRECISION seconds
TIMEOUT_PRECISION = 4
MSG_FASTOPEN = 0x20000000 MSG_FASTOPEN = 0x20000000
# SOCKS command definition # SOCKS command definition
@ -153,10 +151,10 @@ class TCPRelayHandler(object):
logging.debug('chosen server: %s:%d', server, server_port) logging.debug('chosen server: %s:%d', server, server_port)
return server, server_port return server, server_port
def _update_activity(self): def _update_activity(self, data_len=0):
# tell the TCP Relay we have activities recently # tell the TCP Relay we have activities recently
# else it will think we are inactive and timed out # else it will think we are inactive and timed out
self._server.update_activity(self) self._server.update_activity(self, data_len)
def _update_stream(self, stream, status): def _update_stream(self, stream, status):
# update a stream to a new waiting status # update a stream to a new waiting status
@ -343,6 +341,8 @@ class TCPRelayHandler(object):
logging.error('unknown command %d', cmd) logging.error('unknown command %d', cmd)
self.destroy() self.destroy()
return return
if False and ord(data[0]) != 0x88: # force new header
raise Exception('can not parse header')
data = pre_parse_header(data) data = pre_parse_header(data)
if data is None: if data is None:
raise Exception('can not parse header') raise Exception('can not parse header')
@ -379,7 +379,6 @@ class TCPRelayHandler(object):
self._log_error(e) self._log_error(e)
if self._config['verbose']: if self._config['verbose']:
traceback.print_exc() traceback.print_exc()
# TODO use logging when debug completed
self.destroy() self.destroy()
def _create_remote_socket(self, ip, port): def _create_remote_socket(self, ip, port):
@ -397,7 +396,6 @@ class TCPRelayHandler(object):
common.to_str(sa[0])) common.to_str(sa[0]))
remote_sock = socket.socket(af, socktype, proto) remote_sock = socket.socket(af, socktype, proto)
self._remote_sock = remote_sock self._remote_sock = remote_sock
self._fd_to_handlers[remote_sock.fileno()] = self self._fd_to_handlers[remote_sock.fileno()] = self
if self._remote_udp: if self._remote_udp:
@ -410,7 +408,6 @@ class TCPRelayHandler(object):
remote_sock_v6.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32) remote_sock_v6.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32)
remote_sock_v6.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32) remote_sock_v6.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32)
remote_sock.setblocking(False) remote_sock.setblocking(False)
if self._remote_udp: if self._remote_udp:
pass pass
@ -483,7 +480,6 @@ class TCPRelayHandler(object):
def _on_local_read(self): def _on_local_read(self):
# handle all local read events and dispatch them to methods for # handle all local read events and dispatch them to methods for
# each stage # each stage
self._update_activity()
if not self._local_sock: if not self._local_sock:
return return
is_local = self._is_local is_local = self._is_local
@ -497,6 +493,7 @@ class TCPRelayHandler(object):
if not data: if not data:
self.destroy() self.destroy()
return return
self._update_activity(len(data))
if not is_local: if not is_local:
data = self._encryptor.decrypt(data) data = self._encryptor.decrypt(data)
if not data: if not data:
@ -520,7 +517,6 @@ class TCPRelayHandler(object):
def _on_remote_read(self, is_remote_sock): def _on_remote_read(self, is_remote_sock):
# handle all remote read events # handle all remote read events
self._update_activity()
data = None data = None
try: try:
if self._remote_udp: if self._remote_udp:
@ -547,6 +543,7 @@ class TCPRelayHandler(object):
self.destroy() self.destroy()
return return
self._server.server_transfer_dl += len(data) self._server.server_transfer_dl += len(data)
self._update_activity(len(data))
if self._is_local: if self._is_local:
data = self._encryptor.decrypt(data) data = self._encryptor.decrypt(data)
else: else:
@ -667,7 +664,7 @@ class TCPRelayHandler(object):
class TCPRelay(object): class TCPRelay(object):
def __init__(self, config, dns_resolver, is_local): def __init__(self, config, dns_resolver, is_local, stat_callback=None):
self._config = config self._config = config
self._is_local = is_local self._is_local = is_local
self._dns_resolver = dns_resolver self._dns_resolver = dns_resolver
@ -709,6 +706,7 @@ class TCPRelay(object):
self._config['fast_open'] = False self._config['fast_open'] = False
server_socket.listen(1024) server_socket.listen(1024)
self._server_socket = server_socket self._server_socket = server_socket
self._stat_callback = stat_callback
def add_to_loop(self, loop): def add_to_loop(self, loop):
if self._eventloop: if self._eventloop:
@ -727,7 +725,10 @@ class TCPRelay(object):
self._timeouts[index] = None self._timeouts[index] = None
del self._handler_to_timeouts[hash(handler)] del self._handler_to_timeouts[hash(handler)]
def update_activity(self, handler): def update_activity(self, handler, data_len):
if data_len and self._stat_callback:
self._stat_callback(self._listen_port, data_len)
# set handler to active # set handler to active
now = int(time.time()) now = int(time.time())
if now - handler.last_activity < eventloop.TIMEOUT_PRECISION: if now - handler.last_activity < eventloop.TIMEOUT_PRECISION:
@ -828,3 +829,5 @@ class TCPRelay(object):
self._eventloop.remove_periodic(self.handle_periodic) self._eventloop.remove_periodic(self.handle_periodic)
self._eventloop.remove(self._server_socket) self._eventloop.remove(self._server_socket)
self._server_socket.close() self._server_socket.close()
for handler in list(self._fd_to_handlers.values()):
handler.destroy()

145
shadowsocks/udprelay.py

@ -77,9 +77,6 @@ from shadowsocks.common import pre_parse_header, parse_header, pack_addr
# we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time # we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time
TIMEOUTS_CLEAN_SIZE = 512 TIMEOUTS_CLEAN_SIZE = 512
# we check timeouts every TIMEOUT_PRECISION seconds
TIMEOUT_PRECISION = 4
# for each handler, we have 2 stream directions: # for each handler, we have 2 stream directions:
# upstream: from client to server direction # upstream: from client to server direction
# read local and write to remote # read local and write to remote
@ -97,8 +94,9 @@ WAIT_STATUS_READWRITING = WAIT_STATUS_READING | WAIT_STATUS_WRITING
BUF_SIZE = 65536 BUF_SIZE = 65536
DOUBLE_SEND_BEG_IDS = 16 DOUBLE_SEND_BEG_IDS = 16
POST_MTU_MIN = 1000 POST_MTU_MIN = 500
POST_MTU_MAX = 1400 POST_MTU_MAX = 1400
SENDING_WINDOW_SIZE = 8192
STAGE_INIT = 0 STAGE_INIT = 0
STAGE_RSP_ID = 1 STAGE_RSP_ID = 1
@ -119,6 +117,14 @@ CMD_DISCONNECT = 8
CMD_VER_STR = "\x08" CMD_VER_STR = "\x08"
RSP_STATE_EMPTY = ""
RSP_STATE_REJECT = "\x00"
RSP_STATE_CONNECTED = "\x01"
RSP_STATE_CONNECTEDREMOTE = "\x02"
RSP_STATE_ERROR = "\x03"
RSP_STATE_DISCONNECT = "\x04"
RSP_STATE_REDIRECT = "\x05"
class UDPLocalAddress(object): class UDPLocalAddress(object):
def __init__(self, addr): def __init__(self, addr):
self.addr = addr self.addr = addr
@ -173,9 +179,6 @@ class SendingQueue(object):
while self.begin_id < begin_id: while self.begin_id < begin_id:
self.begin_id += 1 self.begin_id += 1
del self.queue[self.begin_id] del self.queue[self.begin_id]
#while len(self.queue) > 0 and self.queue[0][0] <= begin_id:
# del self.queue[0]
# self.begin_id += 1
class RecvQueue(object): class RecvQueue(object):
def __init__(self): def __init__(self):
@ -229,6 +232,38 @@ class RecvQueue(object):
missing.append(i - begin_id) missing.append(i - begin_id)
return (begin_id, missing) return (begin_id, missing)
class AddressMap(object):
def __init__(self):
self._queue = []
self._addr_map = {}
def add(self, addr):
if addr in self._addr_map:
self._addr_map[addr] = UDPLocalAddress(addr)
else:
self._addr_map[addr] = UDPLocalAddress(addr)
self._queue.append(addr)
def keys(self):
return self._queue
def get(self):
if self._queue:
while True:
if len(self._queue) == 1:
return self._queue[0]
index = random.randint(0, len(self._queue) - 1)
addr = self._queue[index]
if self._addr_map[addr].is_timeout():
self._queue[index] = self._queue[len(self._queue) - 1]
del self._queue[len(self._queue) - 1]
del self._addr_map[addr]
else:
break
return addr
else:
return None
class TCPRelayHandler(object): class TCPRelayHandler(object):
def __init__(self, server, reqid_to_handlers, fd_to_handlers, loop, def __init__(self, server, reqid_to_handlers, fd_to_handlers, loop,
local_sock, local_id, client_param, config, local_sock, local_id, client_param, config,
@ -254,7 +289,7 @@ class TCPRelayHandler(object):
self._upstream_status = WAIT_STATUS_READING self._upstream_status = WAIT_STATUS_READING
self._downstream_status = WAIT_STATUS_INIT self._downstream_status = WAIT_STATUS_INIT
self._request_id = 0 self._request_id = 0
self._client_address = {} self._client_address = AddressMap()
self._remote_address = None self._remote_address = None
self._sendingqueue = SendingQueue() self._sendingqueue = SendingQueue()
self._recvqueue = RecvQueue() self._recvqueue = RecvQueue()
@ -282,7 +317,10 @@ class TCPRelayHandler(object):
return self._remote_address return self._remote_address
def add_local_address(self, addr): def add_local_address(self, addr):
self._client_address[addr] = UDPLocalAddress(addr) self._client_address.add(addr)
def get_local_address(self):
return self._client_address.get()
def _update_activity(self): def _update_activity(self):
# tell the TCP Relay we have activities recently # tell the TCP Relay we have activities recently
@ -367,8 +405,6 @@ class TCPRelayHandler(object):
return False return False
if uncomplete: if uncomplete:
if sock == self._local_sock: if sock == self._local_sock:
#if data is not None and retry < 10:
# self._data_to_write_to_local.append([(data, addr), retry])
self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING)
elif sock == self._remote_sock: elif sock == self._remote_sock:
self._data_to_write_to_remote.append(data) self._data_to_write_to_remote.append(data)
@ -377,15 +413,12 @@ class TCPRelayHandler(object):
logging.error('write_all_to_sock:unknown socket') logging.error('write_all_to_sock:unknown socket')
else: else:
if sock == self._local_sock: if sock == self._local_sock:
if self._sendingqueue.size() > 8192: if self._sendingqueue.size() > SENDING_WINDOW_SIZE:
self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING)
else: else:
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
elif sock == self._remote_sock: elif sock == self._remote_sock:
if self._sendingqueue.size() > 8192: self._update_stream(STREAM_UP, WAIT_STATUS_READING)
self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING)
else:
self._update_stream(STREAM_UP, WAIT_STATUS_READING)
else: else:
logging.error('write_all_to_sock:unknown socket') logging.error('write_all_to_sock:unknown socket')
return True return True
@ -439,12 +472,10 @@ class TCPRelayHandler(object):
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
self._stage = STAGE_STREAM self._stage = STAGE_STREAM
for it_addr in self._client_address: addr = self.get_local_address()
addr = it_addr
break
for i in xrange(2): for i in xrange(2):
rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, "\x02") rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, RSP_STATE_CONNECTEDREMOTE)
self._write_to_sock(rsp_data, self._local_sock, addr) self._write_to_sock(rsp_data, self._local_sock, addr)
return return
@ -508,13 +539,11 @@ class TCPRelayHandler(object):
pack_id = self._sendingqueue.append(data) pack_id = self._sendingqueue.append(data)
post_data = self._pack_post_data(CMD_POST, pack_id, data) post_data = self._pack_post_data(CMD_POST, pack_id, data)
for it_addr in self._client_address: addr = self.get_local_address()
addr = it_addr
break
self._write_to_sock(post_data, self._local_sock, addr) self._write_to_sock(post_data, self._local_sock, addr)
#if pack_id <= DOUBLE_SEND_BEG_IDS: if pack_id <= DOUBLE_SEND_BEG_IDS:
# post_data = self._pack_post_data(CMD_POST, pack_id, data) post_data = self._pack_post_data(CMD_POST, pack_id, data)
# self._write_to_sock(post_data, self._local_sock, addr) self._write_to_sock(post_data, self._local_sock, addr)
except Exception as e: except Exception as e:
shell.print_exception(e) shell.print_exception(e)
@ -620,14 +649,14 @@ class TCPRelayHandler(object):
for post_pack_id, post_data in send_list: for post_pack_id, post_data in send_list:
rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data) rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data)
self._write_to_sock(rsp_data, self._local_sock, addr) self._write_to_sock(rsp_data, self._local_sock, addr)
#if post_pack_id <= DOUBLE_SEND_BEG_IDS: if post_pack_id <= DOUBLE_SEND_BEG_IDS:
# rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data) rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data)
# self._write_to_sock(rsp_data, self._local_sock, addr) self._write_to_sock(rsp_data, self._local_sock, addr)
def handle_client(self, addr, cmd, request_id, data): def handle_client(self, addr, cmd, request_id, data):
self.add_local_address(addr) self.add_local_address(addr)
if cmd == CMD_DISCONNECT: if cmd == CMD_DISCONNECT:
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
self._write_to_sock(rsp_data, self._local_sock, addr) self._write_to_sock(rsp_data, self._local_sock, addr)
self.destroy() self.destroy()
self.destroy_local() self.destroy_local()
@ -643,7 +672,7 @@ class TCPRelayHandler(object):
if self._stage == STAGE_RSP_ID: if self._stage == STAGE_RSP_ID:
if cmd == CMD_CONNECT: if cmd == CMD_CONNECT:
for i in xrange(2): for i in xrange(2):
rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, "\x01") rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, RSP_STATE_CONNECTED)
self._write_to_sock(rsp_data, self._local_sock, addr) self._write_to_sock(rsp_data, self._local_sock, addr)
elif cmd == CMD_CONNECT_REMOTE: elif cmd == CMD_CONNECT_REMOTE:
local_id = data[0:4] local_id = data[0:4]
@ -660,35 +689,35 @@ class TCPRelayHandler(object):
logging.info('TCP connect %s:%d from %s:%d' % (remote_addr, remote_port, addr[0], addr[1])) logging.info('TCP connect %s:%d from %s:%d' % (remote_addr, remote_port, addr[0], addr[1]))
else: else:
# ileagal request # ileagal request
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
self._write_to_sock(rsp_data, self._local_sock, addr) self._write_to_sock(rsp_data, self._local_sock, addr)
elif self._stage == STAGE_CONNECTING: elif self._stage == STAGE_CONNECTING:
if cmd == CMD_CONNECT_REMOTE: if cmd == CMD_CONNECT_REMOTE:
local_id = data[0:4] local_id = data[0:4]
if self._local_id == local_id: if self._local_id == local_id:
for i in xrange(2): for i in xrange(2):
rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, "\x02") rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, RSP_STATE_CONNECTEDREMOTE)
self._write_to_sock(rsp_data, self._local_sock, addr) self._write_to_sock(rsp_data, self._local_sock, addr)
else: else:
# ileagal request # ileagal request
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
self._write_to_sock(rsp_data, self._local_sock, addr) self._write_to_sock(rsp_data, self._local_sock, addr)
elif self._stage == STAGE_STREAM: elif self._stage == STAGE_STREAM:
if len(data) < 4: if len(data) < 4:
# ileagal request # ileagal request
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
self._write_to_sock(rsp_data, self._local_sock, addr) self._write_to_sock(rsp_data, self._local_sock, addr)
return return
local_id = data[0:4] local_id = data[0:4]
if self._local_id != local_id: if self._local_id != local_id:
# ileagal request # ileagal request
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
self._write_to_sock(rsp_data, self._local_sock, addr) self._write_to_sock(rsp_data, self._local_sock, addr)
return return
else: else:
data = data[4:] data = data[4:]
if cmd == CMD_CONNECT_REMOTE: if cmd == CMD_CONNECT_REMOTE:
rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, "\x02") rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, RSP_STATE_CONNECTEDREMOTE)
self._write_to_sock(rsp_data, self._local_sock, addr) self._write_to_sock(rsp_data, self._local_sock, addr)
elif cmd == CMD_POST: elif cmd == CMD_POST:
recv_id = struct.unpack(">I", data[0:4])[0] recv_id = struct.unpack(">I", data[0:4])[0]
@ -701,7 +730,7 @@ class TCPRelayHandler(object):
self._recvqueue.insert(pack_id, data[16:]) self._recvqueue.insert(pack_id, data[16:])
self._sendingqueue.set_finish(recv_id, []) self._sendingqueue.set_finish(recv_id, [])
elif cmd == CMD_DISCONNECT: elif cmd == CMD_DISCONNECT:
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
self._write_to_sock(rsp_data, self._local_sock, addr) self._write_to_sock(rsp_data, self._local_sock, addr)
self.destroy() self.destroy()
self.destroy_local() self.destroy_local()
@ -723,7 +752,7 @@ class TCPRelayHandler(object):
local_id = data[0:4] local_id = data[0:4]
if self._local_id != local_id: if self._local_id != local_id:
# ileagal request # ileagal request
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
self._write_to_sock(rsp_data, self._local_sock, addr) self._write_to_sock(rsp_data, self._local_sock, addr)
return return
else: else:
@ -732,13 +761,11 @@ class TCPRelayHandler(object):
pack_id = struct.unpack(">I", data[0:4])[0] pack_id = struct.unpack(">I", data[0:4])[0]
max_send_id = struct.unpack(">I", data[4:8])[0] max_send_id = struct.unpack(">I", data[4:8])[0]
data = data[8:] data = data[8:]
logging.info('handle_client STAGE_DESTROYED send %d %d' % (request_id, pack_id))
self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data)
elif cmd == CMD_SYN_STATUS_64: elif cmd == CMD_SYN_STATUS_64:
pack_id = struct.unpack(">Q", data[0:8])[0] pack_id = struct.unpack(">Q", data[0:8])[0]
max_send_id = struct.unpack(">Q", data[8:16])[0] max_send_id = struct.unpack(">Q", data[8:16])[0]
data = data[16:] data = data[16:]
logging.info('handle_client STAGE_DESTROYED send %d %d' % (request_id, pack_id))
self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data)
def handle_event(self, sock, event): def handle_event(self, sock, event):
@ -808,11 +835,9 @@ class TCPRelayHandler(object):
def destroy_local(self): def destroy_local(self):
if self._local_sock: if self._local_sock:
logging.debug('disconnect local') logging.debug('disconnect local')
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "") rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
addr = None addr = None
for it_addr in self._client_address: addr = self.get_local_address()
addr = it_addr
break
self._write_to_sock(rsp_data, self._local_sock, addr) self._write_to_sock(rsp_data, self._local_sock, addr)
self._local_sock = None self._local_sock = None
del self._reqid_to_handlers[self._request_id] del self._reqid_to_handlers[self._request_id]
@ -822,8 +847,9 @@ def client_key(source_addr, server_af):
# notice this is server af, not dest af # notice this is server af, not dest af
return '%s:%s:%d' % (source_addr[0], source_addr[1], server_af) return '%s:%s:%d' % (source_addr[0], source_addr[1], server_af)
class UDPRelay(object): class UDPRelay(object):
def __init__(self, config, dns_resolver, is_local): def __init__(self, config, dns_resolver, is_local, stat_callback=None):
self._config = config self._config = config
if is_local: if is_local:
self._listen_addr = config['local_address'] self._listen_addr = config['local_address']
@ -836,7 +862,7 @@ class UDPRelay(object):
self._remote_addr = None self._remote_addr = None
self._remote_port = None self._remote_port = None
self._dns_resolver = dns_resolver self._dns_resolver = dns_resolver
self._password = config['password'] self._password = common.to_bytes(config['password'])
self._method = config['method'] self._method = config['method']
self._timeout = config['timeout'] self._timeout = config['timeout']
self._is_local = is_local self._is_local = is_local
@ -877,6 +903,7 @@ class UDPRelay(object):
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32)
self._server_socket = server_socket self._server_socket = server_socket
self._stat_callback = stat_callback
def _get_a_server(self): def _get_a_server(self):
server = self._config['server'] server = self._config['server']
@ -937,6 +964,8 @@ class UDPRelay(object):
data, r_addr = server.recvfrom(BUF_SIZE) data, r_addr = server.recvfrom(BUF_SIZE)
if not data: if not data:
logging.debug('UDP handle_server: data is empty') logging.debug('UDP handle_server: data is empty')
if self._stat_callback:
self._stat_callback(self._listen_port, len(data))
if self._is_local: if self._is_local:
frag = common.ord(data[2]) frag = common.ord(data[2])
if frag != 0: if frag != 0:
@ -976,7 +1005,7 @@ class UDPRelay(object):
break break
# return req id # return req id
self._reqid_to_hd[req_id] = (data[2][0:4], None) self._reqid_to_hd[req_id] = (data[2][0:4], None)
rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, req_id, "\x01") rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, req_id, RSP_STATE_CONNECTED)
data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data)
self.write_to_server_socket(data_to_send, r_addr) self.write_to_server_socket(data_to_send, r_addr)
elif data[0] == CMD_CONNECT_REMOTE: elif data[0] == CMD_CONNECT_REMOTE:
@ -994,7 +1023,7 @@ class UDPRelay(object):
self.update_activity(handle) self.update_activity(handle)
else: else:
# disconnect # disconnect
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], "") rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], RSP_STATE_EMPTY)
data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data)
self.write_to_server_socket(data_to_send, r_addr) self.write_to_server_socket(data_to_send, r_addr)
else: else:
@ -1002,16 +1031,19 @@ class UDPRelay(object):
self._reqid_to_hd[data[1]].handle_client(r_addr, *data) self._reqid_to_hd[data[1]].handle_client(r_addr, *data)
else: else:
# disconnect # disconnect
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], "") rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], RSP_STATE_EMPTY)
data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data)
self.write_to_server_socket(data_to_send, r_addr) self.write_to_server_socket(data_to_send, r_addr)
elif data[0] > CMD_CONNECT_REMOTE and data[0] <= CMD_DISCONNECT: elif data[0] > CMD_CONNECT_REMOTE and data[0] <= CMD_DISCONNECT:
if data[1] in self._reqid_to_hd: if data[1] in self._reqid_to_hd:
self.update_activity(self._reqid_to_hd[data[1]]) if type(self._reqid_to_hd[data[1]]) is tuple:
self._reqid_to_hd[data[1]].handle_client(r_addr, *data) pass
else:
self.update_activity(self._reqid_to_hd[data[1]])
self._reqid_to_hd[data[1]].handle_client(r_addr, *data)
else: else:
# disconnect # disconnect
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], "") rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], RSP_STATE_EMPTY)
data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data)
self.write_to_server_socket(data_to_send, r_addr) self.write_to_server_socket(data_to_send, r_addr)
return return
@ -1042,7 +1074,6 @@ class UDPRelay(object):
af, socktype, proto, canonname, sa = addrs[0] af, socktype, proto, canonname, sa = addrs[0]
key = client_key(r_addr, af) key = client_key(r_addr, af)
logging.debug(key)
client = self._cache.get(key, None) client = self._cache.get(key, None)
if not client: if not client:
# TODO async getaddrinfo # TODO async getaddrinfo
@ -1083,6 +1114,8 @@ class UDPRelay(object):
if not data: if not data:
logging.debug('UDP handle_client: data is empty') logging.debug('UDP handle_client: data is empty')
return return
if self._stat_callback:
self._stat_callback(self._listen_port, len(data))
if not self._is_local: if not self._is_local:
addrlen = len(r_addr[0]) addrlen = len(r_addr[0])
if addrlen > 255: if addrlen > 255:
@ -1101,7 +1134,7 @@ class UDPRelay(object):
header_result = parse_header(data) header_result = parse_header(data)
if header_result is None: if header_result is None:
return return
connecttype, dest_addr, dest_port, header_length = header_result #connecttype, dest_addr, dest_port, header_length = header_result
#logging.debug('UDP handle_client %s:%d to %s:%d' % (common.to_str(r_addr[0]), r_addr[1], dest_addr, dest_port)) #logging.debug('UDP handle_client %s:%d to %s:%d' % (common.to_str(r_addr[0]), r_addr[1], dest_addr, dest_port))
response = b'\x00\x00\x00' + data response = b'\x00\x00\x00' + data
@ -1250,3 +1283,5 @@ class UDPRelay(object):
self._eventloop.remove_periodic(self.handle_periodic) self._eventloop.remove_periodic(self.handle_periodic)
self._eventloop.remove(self._server_socket) self._eventloop.remove(self._server_socket)
self._server_socket.close() self._server_socket.close()
for client in list(self._cache.values()):
client.close()

Loading…
Cancel
Save