|
|
@ -28,6 +28,7 @@ import traceback |
|
|
|
import random |
|
|
|
import platform |
|
|
|
import threading |
|
|
|
from collections import deque |
|
|
|
|
|
|
|
from shadowsocks import encrypt, obfs, eventloop, shell, common, lru_cache |
|
|
|
from shadowsocks.common import pre_parse_header, parse_header |
|
|
@ -93,6 +94,30 @@ WAIT_STATUS_READWRITING = WAIT_STATUS_READING | WAIT_STATUS_WRITING |
|
|
|
BUF_SIZE = 32 * 1024 |
|
|
|
UDP_MAX_BUF_SIZE = 65536 |
|
|
|
|
|
|
|
class SpeedTester(object): |
|
|
|
def __init__(self, max_speed = 0): |
|
|
|
self.max_speed = max_speed |
|
|
|
self.timeout = 3 |
|
|
|
self._cache = deque() |
|
|
|
self.sum_len = 0 |
|
|
|
|
|
|
|
def add(self, data_len): |
|
|
|
if self.max_speed > 0: |
|
|
|
self._cache.append((time.time(), data_len)) |
|
|
|
self.sum_len += data_len |
|
|
|
|
|
|
|
def isExceed(self): |
|
|
|
if self.max_speed > 0: |
|
|
|
if self.sum_len > 0: |
|
|
|
if self._cache[0][0] + self.timeout < time.time(): |
|
|
|
self.sum_len -= self._cache[0][1] |
|
|
|
self._cache.popleft() |
|
|
|
if self.sum_len > 0: |
|
|
|
t = max(time.time() - self._cache[0][0], 0.1) |
|
|
|
speed = (self.sum_len - self._cache[0][1]) / (time.time() - self._cache[0][0]) |
|
|
|
return speed >= self.max_speed |
|
|
|
return False |
|
|
|
|
|
|
|
class TCPRelayHandler(object): |
|
|
|
def __init__(self, server, fd_to_handlers, loop, local_sock, config, |
|
|
|
dns_resolver, is_local): |
|
|
@ -189,6 +214,8 @@ class TCPRelayHandler(object): |
|
|
|
self._update_activity() |
|
|
|
self._server.add_connection(1) |
|
|
|
self._server.stat_add(self._client_address[0], 1) |
|
|
|
self.speed_tester_u = SpeedTester(config.get("speed_limit_per_con", 0)) |
|
|
|
self.speed_tester_d = SpeedTester(config.get("speed_limit_per_con", 0)) |
|
|
|
|
|
|
|
def __hash__(self): |
|
|
|
# default __hash__ is id / 16 |
|
|
@ -725,6 +752,8 @@ class TCPRelayHandler(object): |
|
|
|
if not data: |
|
|
|
self.destroy() |
|
|
|
return |
|
|
|
|
|
|
|
self.speed_tester_u.add(len(data)) |
|
|
|
ogn_data = data |
|
|
|
if not is_local: |
|
|
|
if self._encryptor is not None: |
|
|
@ -819,6 +848,8 @@ class TCPRelayHandler(object): |
|
|
|
if not data: |
|
|
|
self.destroy() |
|
|
|
return |
|
|
|
|
|
|
|
self.speed_tester_d.add(len(data)) |
|
|
|
if self._encryptor is not None: |
|
|
|
if self._is_local: |
|
|
|
try: |
|
|
@ -900,38 +931,49 @@ class TCPRelayHandler(object): |
|
|
|
|
|
|
|
def handle_event(self, sock, event): |
|
|
|
# handle all events in this handler and dispatch them to methods |
|
|
|
handle = False |
|
|
|
if self._stage == STAGE_DESTROYED: |
|
|
|
logging.debug('ignore handle_event: destroyed') |
|
|
|
return |
|
|
|
return True |
|
|
|
if self._user is not None and self._user not in self._server.server_users: |
|
|
|
self.destroy() |
|
|
|
return |
|
|
|
return True |
|
|
|
# order is important |
|
|
|
if sock == self._remote_sock or sock == self._remote_sock_v6: |
|
|
|
if event & eventloop.POLL_ERR: |
|
|
|
handle = True |
|
|
|
self._on_remote_error() |
|
|
|
if self._stage == STAGE_DESTROYED: |
|
|
|
return |
|
|
|
return True |
|
|
|
if event & (eventloop.POLL_IN | eventloop.POLL_HUP): |
|
|
|
if not self.speed_tester_d.isExceed(): |
|
|
|
handle = True |
|
|
|
self._on_remote_read(sock == self._remote_sock) |
|
|
|
if self._stage == STAGE_DESTROYED: |
|
|
|
return |
|
|
|
return True |
|
|
|
if event & eventloop.POLL_OUT: |
|
|
|
handle = True |
|
|
|
self._on_remote_write() |
|
|
|
elif sock == self._local_sock: |
|
|
|
if event & eventloop.POLL_ERR: |
|
|
|
handle = True |
|
|
|
self._on_local_error() |
|
|
|
if self._stage == STAGE_DESTROYED: |
|
|
|
return |
|
|
|
return True |
|
|
|
if event & (eventloop.POLL_IN | eventloop.POLL_HUP): |
|
|
|
if not self.speed_tester_u.isExceed(): |
|
|
|
handle = True |
|
|
|
self._on_local_read() |
|
|
|
if self._stage == STAGE_DESTROYED: |
|
|
|
return |
|
|
|
return True |
|
|
|
if event & eventloop.POLL_OUT: |
|
|
|
handle = True |
|
|
|
self._on_local_write() |
|
|
|
else: |
|
|
|
logging.warn('unknown socket from %s:%d' % (self._client_address[0], self._client_address[1])) |
|
|
|
|
|
|
|
return handle |
|
|
|
|
|
|
|
def _log_error(self, e): |
|
|
|
logging.error('%s when handling connection from %s:%d' % |
|
|
|
(e, self._client_address[0], self._client_address[1])) |
|
|
|