diff --git a/shadowsocks/event.py b/shadowsocks/event.py index 737d039..2cdd065 100644 --- a/shadowsocks/event.py +++ b/shadowsocks/event.py @@ -29,6 +29,9 @@ import select from collections import defaultdict +__all__ = ['EventLoop', 'MODE_NULL', 'MODE_IN', 'MODE_OUT', 'MODE_ERR', + 'MODE_HUP', 'MODE_NVAL'] + MODE_NULL = 0x00 MODE_IN = 0x01 MODE_OUT = 0x04 @@ -135,13 +138,37 @@ class SelectLoop(object): self.add_fd(fd, mode) -EventLoop = None - -if hasattr(select, 'epoll'): - EventLoop = EpollLoop -elif hasattr(select, 'kqueue'): - EventLoop = KqueueLoop -elif hasattr(select, 'select'): - EventLoop = SelectLoop -else: - raise Exception('can not find any available functions in select package') +class EventLoop(object): + def __init__(self): + if hasattr(select, 'epoll'): + self._impl = EpollLoop() + elif hasattr(select, 'kqueue'): + self._impl = KqueueLoop() + elif hasattr(select, 'select'): + self._impl = SelectLoop() + else: + raise Exception('can not find any available functions in select ' + 'package') + self._fd_to_f = defaultdict(list) + + def poll(self, timeout=None): + events = self._impl.poll(timeout) + return ((self._fd_to_f[fd], event) for fd, event in events) + + def add(self, f, mode): + fd = f.fileno() + self._fd_to_f[fd].append(f) + self._impl.add_fd(fd, mode) + + def remove(self, f): + fd = f.fileno() + a = self._fd_to_f[fd] + if len(a) <= 1: + self._fd_to_f[fd] = None + else: + a.remove(f) + self._impl.remove_fd(fd) + + def modify(self, f, mode): + fd = f.fileno() + self._impl.modify_fd(fd, mode) diff --git a/shadowsocks/udprelay.py b/shadowsocks/udprelay.py index 841c035..819c874 100644 --- a/shadowsocks/udprelay.py +++ b/shadowsocks/udprelay.py @@ -84,6 +84,7 @@ class UDPRelay(object): self._timeout = timeout self._is_local = is_local self._eventloop = event.EventLoop() + self._cache = {} # TODO replace this dictionary with an LRU cache def _handle_server(self, addr, sock, data): # TODO @@ -96,7 +97,7 @@ class UDPRelay(object): def _run(self): eventloop = self._eventloop server_socket = self._server_socket - eventloop.add_fd(server_socket, event.MODE_IN) + eventloop.add(server_socket, event.MODE_IN) is_local = self._is_local while True: r = eventloop.poll()