Browse Source

add traffic control for each TCP connection

dev
破娃酱 8 years ago
parent
commit
72ca31dad9
  1. 3
      config.json
  2. 5
      shadowsocks/eventloop.py
  3. 4
      shadowsocks/lru_cache.py
  4. 54
      shadowsocks/tcprelay.py
  5. 19
      shadowsocks/udprelay.py

3
config.json

@ -4,6 +4,7 @@
"server_port": 8388, "server_port": 8388,
"local_address": "127.0.0.1", "local_address": "127.0.0.1",
"local_port": 1080, "local_port": 1080,
"password": "m", "password": "m",
"timeout": 120, "timeout": 120,
"udp_timeout": 60, "udp_timeout": 60,
@ -12,6 +13,8 @@
"protocol_param": "", "protocol_param": "",
"obfs": "tls1.2_ticket_auth_compatible", "obfs": "tls1.2_ticket_auth_compatible",
"obfs_param": "", "obfs_param": "",
"speed_limit_per_con": 0,
"dns_ipv6": false, "dns_ipv6": false,
"connect_verbose_info": 0, "connect_verbose_info": 0,
"redirect": "", "redirect": "",

5
shadowsocks/eventloop.py

@ -208,12 +208,13 @@ class EventLoop(object):
traceback.print_exc() traceback.print_exc()
continue continue
handle = False
for sock, fd, event in events: for sock, fd, event in events:
handler = self._fdmap.get(fd, None) handler = self._fdmap.get(fd, None)
if handler is not None: if handler is not None:
handler = handler[1] handler = handler[1]
try: try:
handler.handle_event(sock, fd, event) handle = handle or handler.handle_event(sock, fd, event)
except (OSError, IOError) as e: except (OSError, IOError) as e:
shell.print_exception(e) shell.print_exception(e)
now = time.time() now = time.time()
@ -221,6 +222,8 @@ class EventLoop(object):
for callback in self._periodic_callbacks: for callback in self._periodic_callbacks:
callback() callback()
self._last_time = now self._last_time = now
if events and (handle is False):
time.sleep(0.01)
def __del__(self): def __del__(self):
self._impl.close() self._impl.close()

4
shadowsocks/lru_cache.py

@ -88,11 +88,11 @@ class LRUCache(collections.MutableMapping):
for key in self._keys_to_last_time: for key in self._keys_to_last_time:
return key return key
def sweep(self): def sweep(self, sweep_item_cnt = SWEEP_MAX_ITEMS):
# O(n - m) # O(n - m)
now = time.time() now = time.time()
c = 0 c = 0
while c < SWEEP_MAX_ITEMS: while c < sweep_item_cnt:
if len(self._keys_to_last_time) == 0: if len(self._keys_to_last_time) == 0:
break break
for key in self._keys_to_last_time: for key in self._keys_to_last_time:

54
shadowsocks/tcprelay.py

@ -28,6 +28,7 @@ import traceback
import random import random
import platform import platform
import threading import threading
from collections import deque
from shadowsocks import encrypt, obfs, eventloop, shell, common, lru_cache from shadowsocks import encrypt, obfs, eventloop, shell, common, lru_cache
from shadowsocks.common import pre_parse_header, parse_header from shadowsocks.common import pre_parse_header, parse_header
@ -93,6 +94,30 @@ WAIT_STATUS_READWRITING = WAIT_STATUS_READING | WAIT_STATUS_WRITING
BUF_SIZE = 32 * 1024 BUF_SIZE = 32 * 1024
UDP_MAX_BUF_SIZE = 65536 UDP_MAX_BUF_SIZE = 65536
class SpeedTester(object):
def __init__(self, max_speed = 0):
self.max_speed = max_speed
self.timeout = 3
self._cache = deque()
self.sum_len = 0
def add(self, data_len):
if self.max_speed > 0:
self._cache.append((time.time(), data_len))
self.sum_len += data_len
def isExceed(self):
if self.max_speed > 0:
if self.sum_len > 0:
if self._cache[0][0] + self.timeout < time.time():
self.sum_len -= self._cache[0][1]
self._cache.popleft()
if self.sum_len > 0:
t = max(time.time() - self._cache[0][0], 0.1)
speed = (self.sum_len - self._cache[0][1]) / (time.time() - self._cache[0][0])
return speed >= self.max_speed
return False
class TCPRelayHandler(object): class TCPRelayHandler(object):
def __init__(self, server, fd_to_handlers, loop, local_sock, config, def __init__(self, server, fd_to_handlers, loop, local_sock, config,
dns_resolver, is_local): dns_resolver, is_local):
@ -189,6 +214,8 @@ class TCPRelayHandler(object):
self._update_activity() self._update_activity()
self._server.add_connection(1) self._server.add_connection(1)
self._server.stat_add(self._client_address[0], 1) self._server.stat_add(self._client_address[0], 1)
self.speed_tester_u = SpeedTester(config.get("speed_limit_per_con", 0))
self.speed_tester_d = SpeedTester(config.get("speed_limit_per_con", 0))
def __hash__(self): def __hash__(self):
# default __hash__ is id / 16 # default __hash__ is id / 16
@ -725,6 +752,8 @@ class TCPRelayHandler(object):
if not data: if not data:
self.destroy() self.destroy()
return return
self.speed_tester_u.add(len(data))
ogn_data = data ogn_data = data
if not is_local: if not is_local:
if self._encryptor is not None: if self._encryptor is not None:
@ -819,6 +848,8 @@ class TCPRelayHandler(object):
if not data: if not data:
self.destroy() self.destroy()
return return
self.speed_tester_d.add(len(data))
if self._encryptor is not None: if self._encryptor is not None:
if self._is_local: if self._is_local:
try: try:
@ -900,38 +931,49 @@ class TCPRelayHandler(object):
def handle_event(self, sock, event): def handle_event(self, sock, event):
# handle all events in this handler and dispatch them to methods # handle all events in this handler and dispatch them to methods
handle = False
if self._stage == STAGE_DESTROYED: if self._stage == STAGE_DESTROYED:
logging.debug('ignore handle_event: destroyed') logging.debug('ignore handle_event: destroyed')
return return True
if self._user is not None and self._user not in self._server.server_users: if self._user is not None and self._user not in self._server.server_users:
self.destroy() self.destroy()
return return True
# order is important # order is important
if sock == self._remote_sock or sock == self._remote_sock_v6: if sock == self._remote_sock or sock == self._remote_sock_v6:
if event & eventloop.POLL_ERR: if event & eventloop.POLL_ERR:
handle = True
self._on_remote_error() self._on_remote_error()
if self._stage == STAGE_DESTROYED: if self._stage == STAGE_DESTROYED:
return return True
if event & (eventloop.POLL_IN | eventloop.POLL_HUP): if event & (eventloop.POLL_IN | eventloop.POLL_HUP):
if not self.speed_tester_d.isExceed():
handle = True
self._on_remote_read(sock == self._remote_sock) self._on_remote_read(sock == self._remote_sock)
if self._stage == STAGE_DESTROYED: if self._stage == STAGE_DESTROYED:
return return True
if event & eventloop.POLL_OUT: if event & eventloop.POLL_OUT:
handle = True
self._on_remote_write() self._on_remote_write()
elif sock == self._local_sock: elif sock == self._local_sock:
if event & eventloop.POLL_ERR: if event & eventloop.POLL_ERR:
handle = True
self._on_local_error() self._on_local_error()
if self._stage == STAGE_DESTROYED: if self._stage == STAGE_DESTROYED:
return return True
if event & (eventloop.POLL_IN | eventloop.POLL_HUP): if event & (eventloop.POLL_IN | eventloop.POLL_HUP):
if not self.speed_tester_u.isExceed():
handle = True
self._on_local_read() self._on_local_read()
if self._stage == STAGE_DESTROYED: if self._stage == STAGE_DESTROYED:
return return True
if event & eventloop.POLL_OUT: if event & eventloop.POLL_OUT:
handle = True
self._on_local_write() self._on_local_write()
else: else:
logging.warn('unknown socket from %s:%d' % (self._client_address[0], self._client_address[1])) logging.warn('unknown socket from %s:%d' % (self._client_address[0], self._client_address[1]))
return handle
def _log_error(self, e): def _log_error(self, e):
logging.error('%s when handling connection from %s:%d' % logging.error('%s when handling connection from %s:%d' %
(e, self._client_address[0], self._client_address[1])) (e, self._client_address[0], self._client_address[1]))

19
shadowsocks/udprelay.py

@ -782,35 +782,44 @@ class TCPRelayHandler(object):
def handle_event(self, sock, event): def handle_event(self, sock, event):
# handle all events in this handler and dispatch them to methods # handle all events in this handler and dispatch them to methods
handle = False
if self._stage == STAGE_DESTROYED: if self._stage == STAGE_DESTROYED:
logging.debug('ignore handle_event: destroyed') logging.debug('ignore handle_event: destroyed')
return return True
# order is important # order is important
if sock == self._remote_sock: if sock == self._remote_sock:
if event & eventloop.POLL_ERR: if event & eventloop.POLL_ERR:
handle = True
self._on_remote_error() self._on_remote_error()
if self._stage == STAGE_DESTROYED: if self._stage == STAGE_DESTROYED:
return return True
if event & (eventloop.POLL_IN | eventloop.POLL_HUP): if event & (eventloop.POLL_IN | eventloop.POLL_HUP):
handle = True
self._on_remote_read() self._on_remote_read()
if self._stage == STAGE_DESTROYED: if self._stage == STAGE_DESTROYED:
return return True
if event & eventloop.POLL_OUT: if event & eventloop.POLL_OUT:
handle = True
self._on_remote_write() self._on_remote_write()
elif sock == self._local_sock: elif sock == self._local_sock:
if event & eventloop.POLL_ERR: if event & eventloop.POLL_ERR:
handle = True
self._on_local_error() self._on_local_error()
if self._stage == STAGE_DESTROYED: if self._stage == STAGE_DESTROYED:
return return True
if event & (eventloop.POLL_IN | eventloop.POLL_HUP): if event & (eventloop.POLL_IN | eventloop.POLL_HUP):
handle = True
self._on_local_read() self._on_local_read()
if self._stage == STAGE_DESTROYED: if self._stage == STAGE_DESTROYED:
return return True
if event & eventloop.POLL_OUT: if event & eventloop.POLL_OUT:
handle = True
self._on_local_write() self._on_local_write()
else: else:
logging.warn('unknown socket') logging.warn('unknown socket')
return handle
def _log_error(self, e): def _log_error(self, e):
logging.error('%s when handling connection from %s' % logging.error('%s when handling connection from %s' %
(e, self._client_address.keys())) (e, self._client_address.keys()))

Loading…
Cancel
Save