diff --git a/shadowsocks/eventloop.py b/shadowsocks/eventloop.py index 4b74825..89ffbcb 100644 --- a/shadowsocks/eventloop.py +++ b/shadowsocks/eventloop.py @@ -29,15 +29,15 @@ import select from collections import defaultdict -__all__ = ['EventLoop', 'MODE_NULL', 'MODE_IN', 'MODE_OUT', 'MODE_ERR', - 'MODE_HUP', 'MODE_NVAL'] +__all__ = ['EventLoop', 'POLL_NULL', 'POLL_IN', 'POLL_OUT', 'POLL_ERR', + 'POLL_HUP', 'POLL_NVAL'] -MODE_NULL = 0x00 -MODE_IN = 0x01 -MODE_OUT = 0x04 -MODE_ERR = 0x08 -MODE_HUP = 0x10 -MODE_NVAL = 0x20 +POLL_NULL = 0x00 +POLL_IN = 0x01 +POLL_OUT = 0x04 +POLL_ERR = 0x08 +POLL_HUP = 0x10 +POLL_NVAL = 0x20 class EpollLoop(object): @@ -68,9 +68,9 @@ class KqueueLoop(object): def _control(self, fd, mode, flags): events = [] - if mode & MODE_IN: + if mode & POLL_IN: events.append(select.kevent(fd, select.KQ_FILTER_READ, flags)) - if mode & MODE_OUT: + if mode & POLL_OUT: events.append(select.kevent(fd, select.KQ_FILTER_WRITE, flags)) for e in events: self._kqueue.control([e], 0) @@ -79,13 +79,13 @@ class KqueueLoop(object): if timeout < 0: timeout = None # kqueue behaviour events = self._kqueue.control(None, KqueueLoop.MAX_EVENTS, timeout) - results = defaultdict(lambda: MODE_NULL) + results = defaultdict(lambda: POLL_NULL) for e in events: fd = e.ident if e.filter == select.KQ_FILTER_READ: - results[fd] |= MODE_IN + results[fd] |= POLL_IN elif e.filter == select.KQ_FILTER_WRITE: - results[fd] |= MODE_OUT + results[fd] |= POLL_OUT return results.iteritems() def add_fd(self, fd, mode): @@ -111,18 +111,18 @@ class SelectLoop(object): def poll(self, timeout): r, w, x = select.select(self._r_list, self._w_list, self._x_list, timeout) - results = defaultdict(lambda: MODE_NULL) - for p in [(r, MODE_IN), (w, MODE_OUT), (x, MODE_ERR)]: + results = defaultdict(lambda: POLL_NULL) + for p in [(r, POLL_IN), (w, POLL_OUT), (x, POLL_ERR)]: for fd in p[0]: results[fd] |= p[1] return results.items() def add_fd(self, fd, mode): - if mode & MODE_IN: + if mode & POLL_IN: self._r_list.add(fd) - if mode & MODE_OUT: + if mode & POLL_OUT: self._w_list.add(fd) - if mode & MODE_ERR: + if mode & POLL_ERR: self._x_list.add(fd) def remove_fd(self, fd): @@ -168,3 +168,22 @@ class EventLoop(object): def modify(self, f, mode): fd = f.fileno() self._impl.modify_fd(fd, mode) + + +# from tornado +def errno_from_exception(e): + """Provides the errno from an Exception object. + + There are cases that the errno attribute was not set so we pull + the errno out of the args but if someone instatiates an Exception + without any args you will get a tuple error. So this function + abstracts all that behavior to give you a safe way to get the + errno. + """ + + if hasattr(e, 'errno'): + return e.errno + elif e.args: + return e.args[0] + else: + return None diff --git a/shadowsocks/lru_cache.py b/shadowsocks/lru_cache.py index 399c810..ce40d17 100644 --- a/shadowsocks/lru_cache.py +++ b/shadowsocks/lru_cache.py @@ -10,8 +10,9 @@ import time class LRUCache(collections.MutableMapping): """This class is not thread safe""" - def __init__(self, timeout=60, *args, **kwargs): + def __init__(self, timeout=60, close_callback=None, *args, **kwargs): self.timeout = timeout + self.close_callback = close_callback self.store = {} self.time_to_keys = collections.defaultdict(list) self.last_visits = [] @@ -53,8 +54,9 @@ class LRUCache(collections.MutableMapping): heapq.heappop(self.last_visits) if self.store.__contains__(key): value = self.store[key] - if hasattr(value, 'close'): - value.close() + if self.close_callback is not None: + self.close_callback(value) + del self.store[key] c += 1 del self.time_to_keys[least] diff --git a/shadowsocks/udprelay.py b/shadowsocks/udprelay.py index eb5c677..1e72742 100644 --- a/shadowsocks/udprelay.py +++ b/shadowsocks/udprelay.py @@ -74,6 +74,7 @@ import struct import encrypt import eventloop import lru_cache +import errno BUF_SIZE = 65536 @@ -137,6 +138,14 @@ class UDPRelay(object): self._cache = lru_cache.LRUCache(timeout=timeout) self._client_fd_to_server_addr = lru_cache.LRUCache(timeout=timeout) + def _close_client(self, client): + if hasattr(client, 'close'): + self._eventloop.remove(client) + client.close() + else: + # just an address + pass + def _handle_server(self): server = self._server_socket data, r_addr = server.recvfrom(BUF_SIZE) @@ -177,7 +186,7 @@ class UDPRelay(object): else: # drop return - self._eventloop.add(client, eventloop.MODE_IN) + self._eventloop.add(client, eventloop.POLL_IN) # prevent from recv other sources if self._is_local: @@ -225,10 +234,18 @@ class UDPRelay(object): def _run(self): server_socket = self._server_socket - self._eventloop.add(server_socket, eventloop.MODE_IN) + self._eventloop.add(server_socket, eventloop.POLL_IN) last_time = time.time() while True: - events = self._eventloop.poll(10) + try: + events = self._eventloop.poll(10) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) == errno.EPIPE: + # Happens when the client closes the connection + continue + else: + logging.error(e) + continue for sock, event in events: if sock == self._server_socket: self._handle_server()