diff --git a/shadowsocks/eventloop.py b/shadowsocks/eventloop.py index 732951f..f068532 100644 --- a/shadowsocks/eventloop.py +++ b/shadowsocks/eventloop.py @@ -28,6 +28,8 @@ import os import socket import select +import errno +import logging from collections import defaultdict @@ -154,25 +156,24 @@ class EventLoop(object): def __init__(self): if hasattr(select, 'epoll'): self._impl = EpollLoop() - self._model = 'epoll' + model = 'epoll' elif hasattr(select, 'kqueue'): self._impl = KqueueLoop() - self._model = 'kqueue' + model = 'kqueue' elif hasattr(select, 'select'): self._impl = SelectLoop() - self._model = 'select' + model = 'select' else: raise Exception('can not find any available functions in select ' 'package') self._fd_to_f = {} - - @property - def model(self): - return self._model + self._handlers = [] + self.stopping = False + logging.debug('using event model: %s', model) def poll(self, timeout=None): events = self._impl.poll(timeout) - return ((self._fd_to_f[fd], event) for fd, event in events) + return [(self._fd_to_f[fd], fd, event) for fd, event in events] def add(self, f, mode): fd = f.fileno() @@ -188,6 +189,26 @@ class EventLoop(object): fd = f.fileno() self._impl.modify_fd(fd, mode) + def add_handler(self, handler): + self._handlers.append(handler) + + def run(self): + while not self.stopping: + events = None + try: + events = self.poll(1) + except (OSError, IOError) as e: + if errno_from_exception(e) == errno.EPIPE: + # Happens when the client closes the connection + continue + else: + logging.error(e) + continue + for handler in self._handlers: + # no exceptions should be raised by users + # TODO when there are a lot of handlers + handler(events) + # from tornado def errno_from_exception(e): diff --git a/shadowsocks/local.py b/shadowsocks/local.py index 24d3645..0b4b17d 100755 --- a/shadowsocks/local.py +++ b/shadowsocks/local.py @@ -24,8 +24,9 @@ import sys import os import logging -import encrypt import utils +import encrypt +import eventloop import tcprelay import udprelay @@ -49,11 +50,12 @@ def main(): logging.info("starting local at %s:%d" % (config['local_address'], config['local_port'])) - # TODO combine the two threads into one loop on a single thread - udprelay.UDPRelay(config, True).start() - tcprelay.TCPRelay(config, True).start() - while sys.stdin.read(): - pass + tcp_server = tcprelay.TCPRelay(config, True) + udp_server = udprelay.UDPRelay(config, True) + loop = eventloop.EventLoop() + tcp_server.add_to_loop(loop) + udp_server.add_to_loop(loop) + loop.run() except (KeyboardInterrupt, IOError, OSError) as e: logging.error(e) os._exit(0) diff --git a/shadowsocks/server.py b/shadowsocks/server.py index 7991dd8..2e35698 100755 --- a/shadowsocks/server.py +++ b/shadowsocks/server.py @@ -22,11 +22,11 @@ # SOFTWARE. import sys -import socket -import logging -import encrypt import os +import logging import utils +import encrypt +import eventloop import tcprelay import udprelay @@ -56,19 +56,17 @@ def main(): a_config['password'] = password logging.info("starting server at %s:%d" % (a_config['server'], int(port))) - tcp_server = tcprelay.TCPRelay(a_config, False) - tcp_servers.append(tcp_server) - udp_server = udprelay.UDPRelay(a_config, False) - udp_servers.append(udp_server) + tcp_servers.append(tcprelay.TCPRelay(a_config, False)) + udp_servers.append(udprelay.UDPRelay(a_config, False)) def run_server(): try: + loop = eventloop.EventLoop() for tcp_server in tcp_servers: - tcp_server.start() + tcp_server.add_to_loop(loop) for udp_server in udp_servers: - udp_server.start() - while sys.stdin.read(): - pass + udp_server.add_to_loop(loop) + loop.run() except (KeyboardInterrupt, IOError, OSError) as e: logging.error(e) os._exit(0) @@ -96,10 +94,10 @@ def main(): signal.signal(signal.SIGTERM, handler) # master - for tcp_server in tcp_servers: - tcp_server.close() - for udp_server in udp_servers: - udp_server.close() + for a_tcp_server in tcp_servers: + a_tcp_server.close() + for a_udp_server in udp_servers: + a_udp_server.close() for child in children: os.waitpid(child, 0) @@ -111,7 +109,4 @@ def main(): if __name__ == '__main__': - try: - main() - except socket.error, e: - logging.error(e) + main() diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index e969d3e..6239629 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -26,7 +26,6 @@ import socket import logging import encrypt import errno -import threading import eventloop from common import parse_header @@ -303,6 +302,7 @@ class TCPRelayHandler(object): logging.warn('unknown socket') def destroy(self): + logging.debug('destroy') if self._remote_sock: self._loop.remove(self._remote_sock) del self._fd_to_handlers[self._remote_sock.fileno()] @@ -320,8 +320,9 @@ class TCPRelay(object): self._config = config self._is_local = is_local self._closed = False - self._thread = None + self._eventloop = None self._fd_to_handlers = {} + self._last_time = time.time() if is_local: listen_addr = config['local_address'] @@ -343,70 +344,48 @@ class TCPRelay(object): server_socket.listen(1024) self._server_socket = server_socket - def _run(self): - server_socket = self._server_socket - self._eventloop = eventloop.EventLoop() - logging.debug('using event model: %s', self._eventloop.model) - self._eventloop.add(server_socket, + def add_to_loop(self, loop): + if self._closed: + raise Exception('already closed') + self._eventloop = loop + loop.add_handler(self._handle_events) + + self._eventloop.add(self._server_socket, eventloop.POLL_IN | eventloop.POLL_ERR) - last_time = time.time() - while not self._closed: - try: - events = self._eventloop.poll(1) - except (OSError, IOError) as e: - if eventloop.errno_from_exception(e) == errno.EPIPE: - # Happens when the client closes the connection - continue - else: - logging.error(e) - continue - for sock, event in events: + + def _handle_events(self, events): + for sock, fd, event in events: + if sock: + logging.debug('fd %d %s', fd, + eventloop.EVENT_NAMES.get(event, event)) + if sock == self._server_socket: + if event & eventloop.POLL_ERR: + # TODO + raise Exception('server_socket error') + try: + logging.debug('accept') + conn = self._server_socket.accept() + TCPRelayHandler(self._fd_to_handlers, self._eventloop, + conn[0], self._config, self._is_local) + except (OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS): + continue + else: + logging.error(e) + else: if sock: - logging.debug('fd %d %s', sock.fileno(), - eventloop.EVENT_NAMES[event]) - if sock == self._server_socket: - if event & eventloop.POLL_ERR: - # TODO - raise Exception('server_socket error') - try: - conn = self._server_socket.accept() - TCPRelayHandler(self._fd_to_handlers, self._eventloop, - conn[0], self._config, self._is_local) - except (OSError, IOError) as e: - error_no = eventloop.errno_from_exception(e) - if error_no in (errno.EAGAIN, errno.EINPROGRESS): - continue - else: - logging.error(e) + handler = self._fd_to_handlers.get(fd, None) + if handler: + handler.handle_event(sock, event) else: - if sock: - handler = self._fd_to_handlers.get(sock.fileno(), None) - if handler: - handler.handle_event(sock, event) - else: - logging.warn('can not find handler for fd %d', - sock.fileno()) - self._eventloop.remove(sock) - else: - logging.warn('poll removed fd') + logging.warn('poll removed fd') + now = time.time() - if now - last_time > 5: + if now - self._last_time > 5: # TODO sweep timeouts - last_time = now - - def start(self): - # TODO combine loops on multiple ports into one single loop - if self._closed: - raise Exception('closed') - t = threading.Thread(target=self._run) - t.setName('TCPThread') - t.setDaemon(False) - t.start() - self._thread = t + self._last_time = now def close(self): self._closed = True self._server_socket.close() - - def thread(self): - return self._thread \ No newline at end of file diff --git a/shadowsocks/udprelay.py b/shadowsocks/udprelay.py index 68414e8..2aeb069 100644 --- a/shadowsocks/udprelay.py +++ b/shadowsocks/udprelay.py @@ -67,7 +67,6 @@ import time -import threading import socket import logging import struct @@ -105,8 +104,10 @@ class UDPRelay(object): close_callback=self._close_client) self._client_fd_to_server_addr = \ lru_cache.LRUCache(timeout=config['timeout']) + self._eventloop = None self._closed = False - self._thread = None + self._last_time = time.time() + self._sockets = set() addrs = socket.getaddrinfo(self._listen_addr, self._listen_port, 0, socket.SOCK_DGRAM, socket.SOL_UDP) @@ -121,6 +122,7 @@ class UDPRelay(object): def _close_client(self, client): if hasattr(client, 'close'): + self._sockets.remove(client.fileno()) self._eventloop.remove(client) client.close() else: @@ -167,6 +169,7 @@ class UDPRelay(object): else: # drop return + self._sockets.add(client.fileno()) self._eventloop.add(client, eventloop.POLL_IN) data = data[header_length:] @@ -216,45 +219,29 @@ class UDPRelay(object): # simply drop that packet pass - def _run(self): - server_socket = self._server_socket - self._eventloop = eventloop.EventLoop() - self._eventloop.add(server_socket, eventloop.POLL_IN) - last_time = time.time() - while not self._closed: - try: - events = self._eventloop.poll(10) - except (OSError, IOError) as e: - if eventloop.errno_from_exception(e) == errno.EPIPE: - # Happens when the client closes the connection - continue - else: - logging.error(e) - continue - for sock, event in events: - if sock == self._server_socket: - self._handle_server() - else: - self._handle_client(sock) - now = time.time() - if now - last_time > 3.5: - self._cache.sweep() - if now - last_time > 7: - self._client_fd_to_server_addr.sweep() - last_time = now - - def start(self): + def add_to_loop(self, loop): if self._closed: - raise Exception('closed') - t = threading.Thread(target=self._run) - t.setName('UDPThread') - t.setDaemon(False) - t.start() - self._thread = t + raise Exception('already closed') + self._eventloop = loop + loop.add_handler(self._handle_events) + + server_socket = self._server_socket + self._eventloop.add(server_socket, + eventloop.POLL_IN | eventloop.POLL_ERR) + + def _handle_events(self, events): + for sock, fd, event in events: + if sock == self._server_socket: + self._handle_server() + elif sock and (fd in self._sockets): + self._handle_client(sock) + now = time.time() + if now - self._last_time > 3.5: + self._cache.sweep() + if now - self._last_time > 7: + self._client_fd_to_server_addr.sweep() + self._last_time = now def close(self): self._closed = True self._server_socket.close() - - def thread(self): - return self._thread