Browse Source

add traffic control for each TCP connection

dev
破娃酱 8 years ago
parent
commit
72ca31dad9
  1. 3
      config.json
  2. 5
      shadowsocks/eventloop.py
  3. 4
      shadowsocks/lru_cache.py
  4. 62
      shadowsocks/tcprelay.py
  5. 19
      shadowsocks/udprelay.py

3
config.json

@ -4,6 +4,7 @@
"server_port": 8388,
"local_address": "127.0.0.1",
"local_port": 1080,
"password": "m",
"timeout": 120,
"udp_timeout": 60,
@ -12,6 +13,8 @@
"protocol_param": "",
"obfs": "tls1.2_ticket_auth_compatible",
"obfs_param": "",
"speed_limit_per_con": 0,
"dns_ipv6": false,
"connect_verbose_info": 0,
"redirect": "",

5
shadowsocks/eventloop.py

@ -208,12 +208,13 @@ class EventLoop(object):
traceback.print_exc()
continue
handle = False
for sock, fd, event in events:
handler = self._fdmap.get(fd, None)
if handler is not None:
handler = handler[1]
try:
handler.handle_event(sock, fd, event)
handle = handle or handler.handle_event(sock, fd, event)
except (OSError, IOError) as e:
shell.print_exception(e)
now = time.time()
@ -221,6 +222,8 @@ class EventLoop(object):
for callback in self._periodic_callbacks:
callback()
self._last_time = now
if events and (handle is False):
time.sleep(0.01)
def __del__(self):
self._impl.close()

4
shadowsocks/lru_cache.py

@ -88,11 +88,11 @@ class LRUCache(collections.MutableMapping):
for key in self._keys_to_last_time:
return key
def sweep(self):
def sweep(self, sweep_item_cnt = SWEEP_MAX_ITEMS):
# O(n - m)
now = time.time()
c = 0
while c < SWEEP_MAX_ITEMS:
while c < sweep_item_cnt:
if len(self._keys_to_last_time) == 0:
break
for key in self._keys_to_last_time:

62
shadowsocks/tcprelay.py

@ -28,6 +28,7 @@ import traceback
import random
import platform
import threading
from collections import deque
from shadowsocks import encrypt, obfs, eventloop, shell, common, lru_cache
from shadowsocks.common import pre_parse_header, parse_header
@ -93,6 +94,30 @@ WAIT_STATUS_READWRITING = WAIT_STATUS_READING | WAIT_STATUS_WRITING
BUF_SIZE = 32 * 1024
UDP_MAX_BUF_SIZE = 65536
class SpeedTester(object):
def __init__(self, max_speed = 0):
self.max_speed = max_speed
self.timeout = 3
self._cache = deque()
self.sum_len = 0
def add(self, data_len):
if self.max_speed > 0:
self._cache.append((time.time(), data_len))
self.sum_len += data_len
def isExceed(self):
if self.max_speed > 0:
if self.sum_len > 0:
if self._cache[0][0] + self.timeout < time.time():
self.sum_len -= self._cache[0][1]
self._cache.popleft()
if self.sum_len > 0:
t = max(time.time() - self._cache[0][0], 0.1)
speed = (self.sum_len - self._cache[0][1]) / (time.time() - self._cache[0][0])
return speed >= self.max_speed
return False
class TCPRelayHandler(object):
def __init__(self, server, fd_to_handlers, loop, local_sock, config,
dns_resolver, is_local):
@ -189,6 +214,8 @@ class TCPRelayHandler(object):
self._update_activity()
self._server.add_connection(1)
self._server.stat_add(self._client_address[0], 1)
self.speed_tester_u = SpeedTester(config.get("speed_limit_per_con", 0))
self.speed_tester_d = SpeedTester(config.get("speed_limit_per_con", 0))
def __hash__(self):
# default __hash__ is id / 16
@ -725,6 +752,8 @@ class TCPRelayHandler(object):
if not data:
self.destroy()
return
self.speed_tester_u.add(len(data))
ogn_data = data
if not is_local:
if self._encryptor is not None:
@ -819,6 +848,8 @@ class TCPRelayHandler(object):
if not data:
self.destroy()
return
self.speed_tester_d.add(len(data))
if self._encryptor is not None:
if self._is_local:
try:
@ -900,38 +931,49 @@ class TCPRelayHandler(object):
def handle_event(self, sock, event):
# handle all events in this handler and dispatch them to methods
handle = False
if self._stage == STAGE_DESTROYED:
logging.debug('ignore handle_event: destroyed')
return
return True
if self._user is not None and self._user not in self._server.server_users:
self.destroy()
return
return True
# order is important
if sock == self._remote_sock or sock == self._remote_sock_v6:
if event & eventloop.POLL_ERR:
handle = True
self._on_remote_error()
if self._stage == STAGE_DESTROYED:
return
return True
if event & (eventloop.POLL_IN | eventloop.POLL_HUP):
self._on_remote_read(sock == self._remote_sock)
if self._stage == STAGE_DESTROYED:
return
if not self.speed_tester_d.isExceed():
handle = True
self._on_remote_read(sock == self._remote_sock)
if self._stage == STAGE_DESTROYED:
return True
if event & eventloop.POLL_OUT:
handle = True
self._on_remote_write()
elif sock == self._local_sock:
if event & eventloop.POLL_ERR:
handle = True
self._on_local_error()
if self._stage == STAGE_DESTROYED:
return
return True
if event & (eventloop.POLL_IN | eventloop.POLL_HUP):
self._on_local_read()
if self._stage == STAGE_DESTROYED:
return
if not self.speed_tester_u.isExceed():
handle = True
self._on_local_read()
if self._stage == STAGE_DESTROYED:
return True
if event & eventloop.POLL_OUT:
handle = True
self._on_local_write()
else:
logging.warn('unknown socket from %s:%d' % (self._client_address[0], self._client_address[1]))
return handle
def _log_error(self, e):
logging.error('%s when handling connection from %s:%d' %
(e, self._client_address[0], self._client_address[1]))

19
shadowsocks/udprelay.py

@ -782,35 +782,44 @@ class TCPRelayHandler(object):
def handle_event(self, sock, event):
# handle all events in this handler and dispatch them to methods
handle = False
if self._stage == STAGE_DESTROYED:
logging.debug('ignore handle_event: destroyed')
return
return True
# order is important
if sock == self._remote_sock:
if event & eventloop.POLL_ERR:
handle = True
self._on_remote_error()
if self._stage == STAGE_DESTROYED:
return
return True
if event & (eventloop.POLL_IN | eventloop.POLL_HUP):
handle = True
self._on_remote_read()
if self._stage == STAGE_DESTROYED:
return
return True
if event & eventloop.POLL_OUT:
handle = True
self._on_remote_write()
elif sock == self._local_sock:
if event & eventloop.POLL_ERR:
handle = True
self._on_local_error()
if self._stage == STAGE_DESTROYED:
return
return True
if event & (eventloop.POLL_IN | eventloop.POLL_HUP):
handle = True
self._on_local_read()
if self._stage == STAGE_DESTROYED:
return
return True
if event & eventloop.POLL_OUT:
handle = True
self._on_local_write()
else:
logging.warn('unknown socket')
return handle
def _log_error(self, e):
logging.error('%s when handling connection from %s' %
(e, self._client_address.keys()))

Loading…
Cancel
Save