diff --git a/shadowsocks/eventloop.py b/shadowsocks/eventloop.py index 6e2f22f..d2bbeb6 100644 --- a/shadowsocks/eventloop.py +++ b/shadowsocks/eventloop.py @@ -214,7 +214,7 @@ class EventLoop(object): if handler is not None: handler = handler[1] try: - handle = handle or handler.handle_event(sock, fd, event) + handle = handler.handle_event(sock, fd, event) or handle except (OSError, IOError) as e: shell.print_exception(e) now = time.time() diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 4ffbc0d..235505f 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -28,7 +28,6 @@ import traceback import random import platform import threading -from collections import deque from shadowsocks import encrypt, obfs, eventloop, shell, common, lru_cache from shadowsocks.common import pre_parse_header, parse_header @@ -91,7 +90,7 @@ WAIT_STATUS_READING = 1 WAIT_STATUS_WRITING = 2 WAIT_STATUS_READWRITING = WAIT_STATUS_READING | WAIT_STATUS_WRITING -NETWORK_MTU = 1492 +NETWORK_MTU = 1500 TCP_MSS = NETWORK_MTU - 40 BUF_SIZE = 32 * 1024 UDP_MAX_BUF_SIZE = 65536 @@ -99,8 +98,7 @@ UDP_MAX_BUF_SIZE = 65536 class SpeedTester(object): def __init__(self, max_speed = 0): self.max_speed = max_speed * 1024 - self.timeout = 1 - self._cache = deque() + self.last_time = time.time() self.sum_len = 0 def update_limit(self, max_speed): @@ -108,19 +106,21 @@ class SpeedTester(object): def add(self, data_len): if self.max_speed > 0: - self._cache.append((time.time(), data_len)) + cut_t = time.time() + self.sum_len -= (cut_t - self.last_time) * self.max_speed + if self.sum_len < 0: + self.sum_len = 0 + self.last_time = cut_t self.sum_len += data_len def isExceed(self): if self.max_speed > 0: - if self.sum_len > 0: - cut_t = time.time() - t = max(cut_t - self._cache[0][0], 0.01) - speed = self.sum_len / t - if self._cache[0][0] + self.timeout < cut_t: - self.sum_len -= self._cache[0][1] - self._cache.popleft() - return speed >= self.max_speed + cut_t = time.time() + self.sum_len -= (cut_t - self.last_time) * self.max_speed + if self.sum_len < 0: + self.sum_len = 0 + self.last_time = cut_t + return self.sum_len >= self.max_speed return False class TCPRelayHandler(object): @@ -132,6 +132,9 @@ class TCPRelayHandler(object): self._local_sock = local_sock self._remote_sock = None self._remote_sock_v6 = None + self._local_sock_fd = None + self._remote_sock_fd = None + self._remotev6_sock_fd = None self._remote_udp = False self._config = config self._dns_resolver = dns_resolver @@ -216,13 +219,16 @@ class TCPRelayHandler(object): self._server.stat_add(self._client_address[0], 1) self.speed_tester_u = SpeedTester(config.get("speed_limit_per_con", 0)) self.speed_tester_d = SpeedTester(config.get("speed_limit_per_con", 0)) + self._recv_u_max_size = BUF_SIZE + self._recv_d_max_size = BUF_SIZE self._recv_pack_id = 0 self._udp_send_pack_id = 0 self._udpv6_send_pack_id = 0 local_sock.setblocking(False) local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - fd_to_handlers[local_sock.fileno()] = self + self._local_sock_fd = local_sock.fileno() + fd_to_handlers[self._local_sock_fd] = self loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR, self._server) self._stage = STAGE_INIT @@ -695,13 +701,15 @@ class TCPRelayHandler(object): raise Exception('Port %d is in forbidden list, reject' % sa[1]) remote_sock = socket.socket(af, socktype, proto) self._remote_sock = remote_sock - self._fd_to_handlers[remote_sock.fileno()] = self + self._remote_sock_fd = remote_sock.fileno() + self._fd_to_handlers[self._remote_sock_fd] = self if self._remote_udp: af, socktype, proto, canonname, sa = addrs_v6[0] remote_sock_v6 = socket.socket(af, socktype, proto) self._remote_sock_v6 = remote_sock_v6 - self._fd_to_handlers[remote_sock_v6.fileno()] = self + self._remotev6_sock_fd = remote_sock_v6.fileno() + self._fd_to_handlers[self._remotev6_sock_fd] = self remote_sock.setblocking(False) if self._remote_udp: @@ -784,10 +792,16 @@ class TCPRelayHandler(object): logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) self.destroy() - def _get_read_size(self, sock, recv_buffer_size): + def _get_read_size(self, sock, recv_buffer_size, up): if self._overhead == 0: return recv_buffer_size buffer_size = len(sock.recv(recv_buffer_size, socket.MSG_PEEK)) + if up: + buffer_size = min(buffer_size, self._recv_u_max_size) + self._recv_u_max_size = min(self._recv_u_max_size + TCP_MSS, BUF_SIZE) + else: + buffer_size = min(buffer_size, self._recv_d_max_size) + self._recv_d_max_size = min(self._recv_d_max_size + TCP_MSS, BUF_SIZE) if buffer_size == recv_buffer_size: return buffer_size s = buffer_size % self._tcp_mss + self._overhead @@ -802,7 +816,7 @@ class TCPRelayHandler(object): return is_local = self._is_local if is_local: - recv_buffer_size = self._get_read_size(self._local_sock, self._recv_buffer_size) + recv_buffer_size = self._get_read_size(self._local_sock, self._recv_buffer_size, True) else: recv_buffer_size = BUF_SIZE data = None @@ -914,7 +928,7 @@ class TCPRelayHandler(object): if self._is_local: recv_buffer_size = BUF_SIZE else: - recv_buffer_size = self._get_read_size(self._remote_sock, self._recv_buffer_size) + recv_buffer_size = self._get_read_size(self._remote_sock, self._recv_buffer_size, False) data = self._remote_sock.recv(recv_buffer_size) self._recv_pack_id += 1 except (OSError, IOError) as e: @@ -1008,7 +1022,7 @@ class TCPRelayHandler(object): logging.error("remote error, exception from %s:%d" % (self._client_address[0], self._client_address[1])) self.destroy() - def handle_event(self, sock, event): + def handle_event(self, sock, fd, event): # handle all events in this handler and dispatch them to methods handle = False if self._stage == STAGE_DESTROYED: @@ -1023,10 +1037,11 @@ class TCPRelayHandler(object): handle = True self._on_remote_error() elif event & (eventloop.POLL_IN | eventloop.POLL_HUP): - if not self.speed_tester_d.isExceed(): - if not self._server.speed_tester_d(self._user_id).isExceed(): - handle = True - self._on_remote_read(sock == self._remote_sock) + if not self.speed_tester_d.isExceed() and not self._server.speed_tester_d(self._user_id).isExceed(): + handle = True + self._on_remote_read(sock == self._remote_sock) + else: + self._recv_d_max_size = TCP_MSS elif event & eventloop.POLL_OUT: handle = True self._on_remote_write() @@ -1035,10 +1050,11 @@ class TCPRelayHandler(object): handle = True self._on_local_error() elif event & (eventloop.POLL_IN | eventloop.POLL_HUP): - if not self.speed_tester_u.isExceed(): - if not self._server.speed_tester_u(self._user_id).isExceed(): - handle = True - self._on_local_read() + if not self.speed_tester_u.isExceed() and not self._server.speed_tester_u(self._user_id).isExceed(): + handle = True + self._on_local_read() + else: + self._recv_u_max_size = TCP_MSS elif event & eventloop.POLL_OUT: handle = True self._on_local_write() @@ -1049,7 +1065,7 @@ class TCPRelayHandler(object): except Exception as e: shell.print_exception(e) try: - del self._fd_to_handlers[sock.fileno()] + del self._fd_to_handlers[fd] except Exception as e: shell.print_exception(e) sock.close() @@ -1088,7 +1104,8 @@ class TCPRelayHandler(object): except Exception as e: shell.print_exception(e) try: - del self._fd_to_handlers[self._remote_sock.fileno()] + if self._remote_sock_fd is not None: + del self._fd_to_handlers[self._remote_sock_fd] except Exception as e: shell.print_exception(e) self._remote_sock.close() @@ -1100,7 +1117,8 @@ class TCPRelayHandler(object): except Exception as e: shell.print_exception(e) try: - del self._fd_to_handlers[self._remote_sock_v6.fileno()] + if self._remotev6_sock_fd is not None: + del self._fd_to_handlers[self._remotev6_sock_fd] except Exception as e: shell.print_exception(e) self._remote_sock_v6.close() @@ -1112,7 +1130,8 @@ class TCPRelayHandler(object): except Exception as e: shell.print_exception(e) try: - del self._fd_to_handlers[self._local_sock.fileno()] + if self._local_sock_fd is not None: + del self._fd_to_handlers[self._local_sock_fd] except Exception as e: shell.print_exception(e) self._local_sock.close() @@ -1346,6 +1365,7 @@ class TCPRelay(object): def handle_event(self, sock, fd, event): # handle events and dispatch to handlers + handle = False if sock: logging.log(shell.VERBOSE_LEVEL, 'fd %d %s', fd, eventloop.EVENT_NAMES.get(event, event)) @@ -1354,6 +1374,7 @@ class TCPRelay(object): # TODO raise Exception('server_socket error') handler = None + handle = True try: logging.debug('accept') conn = self._server_socket.accept() @@ -1377,9 +1398,10 @@ class TCPRelay(object): if sock: handler = self._fd_to_handlers.get(fd, None) if handler: - handler.handle_event(sock, event) + handle = handler.handle_event(sock, fd, event) else: logging.warn('unknown fd') + handle = True try: self._eventloop.remove(sock) except Exception as e: @@ -1387,6 +1409,13 @@ class TCPRelay(object): sock.close() else: logging.warn('poll removed fd') + handle = True + if fd in self._fd_to_handlers: + try: + del self._fd_to_handlers[fd] + except Exception as e: + shell.print_exception(e) + return handle def handle_periodic(self): if self._closed: