Browse Source

fix remove()

1.4
clowwindy 11 years ago
parent
commit
e6a225513e
  1. 55
      shadowsocks/eventloop.py
  2. 8
      shadowsocks/lru_cache.py
  3. 21
      shadowsocks/udprelay.py

55
shadowsocks/eventloop.py

@ -29,15 +29,15 @@ import select
from collections import defaultdict from collections import defaultdict
__all__ = ['EventLoop', 'MODE_NULL', 'MODE_IN', 'MODE_OUT', 'MODE_ERR', __all__ = ['EventLoop', 'POLL_NULL', 'POLL_IN', 'POLL_OUT', 'POLL_ERR',
'MODE_HUP', 'MODE_NVAL'] 'POLL_HUP', 'POLL_NVAL']
MODE_NULL = 0x00 POLL_NULL = 0x00
MODE_IN = 0x01 POLL_IN = 0x01
MODE_OUT = 0x04 POLL_OUT = 0x04
MODE_ERR = 0x08 POLL_ERR = 0x08
MODE_HUP = 0x10 POLL_HUP = 0x10
MODE_NVAL = 0x20 POLL_NVAL = 0x20
class EpollLoop(object): class EpollLoop(object):
@ -68,9 +68,9 @@ class KqueueLoop(object):
def _control(self, fd, mode, flags): def _control(self, fd, mode, flags):
events = [] events = []
if mode & MODE_IN: if mode & POLL_IN:
events.append(select.kevent(fd, select.KQ_FILTER_READ, flags)) events.append(select.kevent(fd, select.KQ_FILTER_READ, flags))
if mode & MODE_OUT: if mode & POLL_OUT:
events.append(select.kevent(fd, select.KQ_FILTER_WRITE, flags)) events.append(select.kevent(fd, select.KQ_FILTER_WRITE, flags))
for e in events: for e in events:
self._kqueue.control([e], 0) self._kqueue.control([e], 0)
@ -79,13 +79,13 @@ class KqueueLoop(object):
if timeout < 0: if timeout < 0:
timeout = None # kqueue behaviour timeout = None # kqueue behaviour
events = self._kqueue.control(None, KqueueLoop.MAX_EVENTS, timeout) events = self._kqueue.control(None, KqueueLoop.MAX_EVENTS, timeout)
results = defaultdict(lambda: MODE_NULL) results = defaultdict(lambda: POLL_NULL)
for e in events: for e in events:
fd = e.ident fd = e.ident
if e.filter == select.KQ_FILTER_READ: if e.filter == select.KQ_FILTER_READ:
results[fd] |= MODE_IN results[fd] |= POLL_IN
elif e.filter == select.KQ_FILTER_WRITE: elif e.filter == select.KQ_FILTER_WRITE:
results[fd] |= MODE_OUT results[fd] |= POLL_OUT
return results.iteritems() return results.iteritems()
def add_fd(self, fd, mode): def add_fd(self, fd, mode):
@ -111,18 +111,18 @@ class SelectLoop(object):
def poll(self, timeout): def poll(self, timeout):
r, w, x = select.select(self._r_list, self._w_list, self._x_list, r, w, x = select.select(self._r_list, self._w_list, self._x_list,
timeout) timeout)
results = defaultdict(lambda: MODE_NULL) results = defaultdict(lambda: POLL_NULL)
for p in [(r, MODE_IN), (w, MODE_OUT), (x, MODE_ERR)]: for p in [(r, POLL_IN), (w, POLL_OUT), (x, POLL_ERR)]:
for fd in p[0]: for fd in p[0]:
results[fd] |= p[1] results[fd] |= p[1]
return results.items() return results.items()
def add_fd(self, fd, mode): def add_fd(self, fd, mode):
if mode & MODE_IN: if mode & POLL_IN:
self._r_list.add(fd) self._r_list.add(fd)
if mode & MODE_OUT: if mode & POLL_OUT:
self._w_list.add(fd) self._w_list.add(fd)
if mode & MODE_ERR: if mode & POLL_ERR:
self._x_list.add(fd) self._x_list.add(fd)
def remove_fd(self, fd): def remove_fd(self, fd):
@ -168,3 +168,22 @@ class EventLoop(object):
def modify(self, f, mode): def modify(self, f, mode):
fd = f.fileno() fd = f.fileno()
self._impl.modify_fd(fd, mode) self._impl.modify_fd(fd, mode)
# from tornado
def errno_from_exception(e):
"""Provides the errno from an Exception object.
There are cases that the errno attribute was not set so we pull
the errno out of the args but if someone instatiates an Exception
without any args you will get a tuple error. So this function
abstracts all that behavior to give you a safe way to get the
errno.
"""
if hasattr(e, 'errno'):
return e.errno
elif e.args:
return e.args[0]
else:
return None

8
shadowsocks/lru_cache.py

@ -10,8 +10,9 @@ import time
class LRUCache(collections.MutableMapping): class LRUCache(collections.MutableMapping):
"""This class is not thread safe""" """This class is not thread safe"""
def __init__(self, timeout=60, *args, **kwargs): def __init__(self, timeout=60, close_callback=None, *args, **kwargs):
self.timeout = timeout self.timeout = timeout
self.close_callback = close_callback
self.store = {} self.store = {}
self.time_to_keys = collections.defaultdict(list) self.time_to_keys = collections.defaultdict(list)
self.last_visits = [] self.last_visits = []
@ -53,8 +54,9 @@ class LRUCache(collections.MutableMapping):
heapq.heappop(self.last_visits) heapq.heappop(self.last_visits)
if self.store.__contains__(key): if self.store.__contains__(key):
value = self.store[key] value = self.store[key]
if hasattr(value, 'close'): if self.close_callback is not None:
value.close() self.close_callback(value)
del self.store[key] del self.store[key]
c += 1 c += 1
del self.time_to_keys[least] del self.time_to_keys[least]

21
shadowsocks/udprelay.py

@ -74,6 +74,7 @@ import struct
import encrypt import encrypt
import eventloop import eventloop
import lru_cache import lru_cache
import errno
BUF_SIZE = 65536 BUF_SIZE = 65536
@ -137,6 +138,14 @@ class UDPRelay(object):
self._cache = lru_cache.LRUCache(timeout=timeout) self._cache = lru_cache.LRUCache(timeout=timeout)
self._client_fd_to_server_addr = lru_cache.LRUCache(timeout=timeout) self._client_fd_to_server_addr = lru_cache.LRUCache(timeout=timeout)
def _close_client(self, client):
if hasattr(client, 'close'):
self._eventloop.remove(client)
client.close()
else:
# just an address
pass
def _handle_server(self): def _handle_server(self):
server = self._server_socket server = self._server_socket
data, r_addr = server.recvfrom(BUF_SIZE) data, r_addr = server.recvfrom(BUF_SIZE)
@ -177,7 +186,7 @@ class UDPRelay(object):
else: else:
# drop # drop
return return
self._eventloop.add(client, eventloop.MODE_IN) self._eventloop.add(client, eventloop.POLL_IN)
# prevent from recv other sources # prevent from recv other sources
if self._is_local: if self._is_local:
@ -225,10 +234,18 @@ class UDPRelay(object):
def _run(self): def _run(self):
server_socket = self._server_socket server_socket = self._server_socket
self._eventloop.add(server_socket, eventloop.MODE_IN) self._eventloop.add(server_socket, eventloop.POLL_IN)
last_time = time.time() last_time = time.time()
while True: while True:
try:
events = self._eventloop.poll(10) events = self._eventloop.poll(10)
except (OSError, IOError) as e:
if eventloop.errno_from_exception(e) == errno.EPIPE:
# Happens when the client closes the connection
continue
else:
logging.error(e)
continue
for sock, event in events: for sock, event in events:
if sock == self._server_socket: if sock == self._server_socket:
self._handle_server() self._handle_server()

Loading…
Cancel
Save