Browse Source

new UDP over TCP protocol, merge master

master
breakwa11 10 years ago
parent
commit
a38db829f2
  1. 60
      shadowsocks/asyncdns.py
  2. 15
      shadowsocks/common.py
  3. 120
      shadowsocks/eventloop.py
  4. 209
      shadowsocks/tcprelay.py
  5. 1010
      shadowsocks/udprelay.py

60
shadowsocks/asyncdns.py

@ -18,7 +18,6 @@
from __future__ import absolute_import, division, print_function, \ from __future__ import absolute_import, division, print_function, \
with_statement with_statement
import time
import os import os
import socket import socket
import struct import struct
@ -256,7 +255,6 @@ class DNSResolver(object):
self._hostname_to_cb = {} self._hostname_to_cb = {}
self._cb_to_hostname = {} self._cb_to_hostname = {}
self._cache = lru_cache.LRUCache(timeout=300) self._cache = lru_cache.LRUCache(timeout=300)
self._last_time = time.time()
self._sock = None self._sock = None
self._servers = None self._servers = None
self._parse_resolv() self._parse_resolv()
@ -304,7 +302,7 @@ class DNSResolver(object):
except IOError: except IOError:
self._hosts['localhost'] = '127.0.0.1' self._hosts['localhost'] = '127.0.0.1'
def add_to_loop(self, loop, ref=False): def add_to_loop(self, loop):
if self._loop: if self._loop:
raise Exception('already add to loop') raise Exception('already add to loop')
self._loop = loop self._loop = loop
@ -312,8 +310,8 @@ class DNSResolver(object):
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
socket.SOL_UDP) socket.SOL_UDP)
self._sock.setblocking(False) self._sock.setblocking(False)
loop.add(self._sock, eventloop.POLL_IN) loop.add(self._sock, eventloop.POLL_IN, self)
loop.add_handler(self.handle_events, ref=ref) loop.add_periodic(self.handle_periodic)
def _call_callback(self, hostname, ip, error=None): def _call_callback(self, hostname, ip, error=None):
callbacks = self._hostname_to_cb.get(hostname, []) callbacks = self._hostname_to_cb.get(hostname, [])
@ -354,30 +352,27 @@ class DNSResolver(object):
self._call_callback(hostname, None) self._call_callback(hostname, None)
break break
def handle_events(self, events): def handle_event(self, sock, fd, event):
for sock, fd, event in events: if sock != self._sock:
if sock != self._sock: return
continue if event & eventloop.POLL_ERR:
if event & eventloop.POLL_ERR: logging.error('dns socket err')
logging.error('dns socket err') self._loop.remove(self._sock)
self._loop.remove(self._sock) self._sock.close()
self._sock.close() # TODO when dns server is IPv6
# TODO when dns server is IPv6 self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.SOL_UDP)
socket.SOL_UDP) self._sock.setblocking(False)
self._sock.setblocking(False) self._loop.add(self._sock, eventloop.POLL_IN, self)
self._loop.add(self._sock, eventloop.POLL_IN) else:
else: data, addr = sock.recvfrom(1024)
data, addr = sock.recvfrom(1024) if addr[0] not in self._servers:
if addr[0] not in self._servers: logging.warn('received a packet other than our dns')
logging.warn('received a packet other than our dns') return
break self._handle_data(data)
self._handle_data(data)
break def handle_periodic(self):
now = time.time() self._cache.sweep()
if now - self._last_time > CACHE_SWEEP_INTERVAL:
self._cache.sweep()
self._last_time = now
def remove_callback(self, callback): def remove_callback(self, callback):
hostname = self._cb_to_hostname.get(callback) hostname = self._cb_to_hostname.get(callback)
@ -430,6 +425,9 @@ class DNSResolver(object):
def close(self): def close(self):
if self._sock: if self._sock:
if self._loop:
self._loop.remove_periodic(self.handle_periodic)
self._loop.remove(self._sock)
self._sock.close() self._sock.close()
self._sock = None self._sock = None
@ -437,7 +435,7 @@ class DNSResolver(object):
def test(): def test():
dns_resolver = DNSResolver() dns_resolver = DNSResolver()
loop = eventloop.EventLoop() loop = eventloop.EventLoop()
dns_resolver.add_to_loop(loop, ref=True) dns_resolver.add_to_loop(loop)
global counter global counter
counter = 0 counter = 0
@ -451,8 +449,8 @@ def test():
print(result, error) print(result, error)
counter += 1 counter += 1
if counter == 9: if counter == 9:
loop.remove_handler(dns_resolver.handle_events)
dns_resolver.close() dns_resolver.close()
loop.stop()
a_callback = callback a_callback = callback
return a_callback return a_callback

15
shadowsocks/common.py

@ -151,6 +151,15 @@ def pre_parse_header(data):
data = data[rand_data_size + 2:] data = data[rand_data_size + 2:]
elif datatype == 0x81: elif datatype == 0x81:
data = data[1:] data = data[1:]
elif datatype == 0x82 :
if len(data) <= 3:
return None
rand_data_size = struct.unpack('>H', data[1:3])[0]
if rand_data_size + 3 >= len(data):
logging.warn('header too short, maybe wrong password or '
'encryption method')
return None
data = data[rand_data_size + 3:]
return data return data
def parse_header(data): def parse_header(data):
@ -158,8 +167,8 @@ def parse_header(data):
dest_addr = None dest_addr = None
dest_port = None dest_port = None
header_length = 0 header_length = 0
connecttype = (addrtype & 8) and 1 or 0 connecttype = (addrtype & 0x10) and 1 or 0
addrtype &= ~8 addrtype &= ~0x10
if addrtype == ADDRTYPE_IPV4: if addrtype == ADDRTYPE_IPV4:
if len(data) >= 7: if len(data) >= 7:
dest_addr = socket.inet_ntoa(data[1:5]) dest_addr = socket.inet_ntoa(data[1:5])
@ -173,7 +182,7 @@ def parse_header(data):
if len(data) >= 2 + addrlen: if len(data) >= 2 + addrlen:
dest_addr = data[2:2 + addrlen] dest_addr = data[2:2 + addrlen]
dest_port = struct.unpack('>H', data[2 + addrlen:4 + dest_port = struct.unpack('>H', data[2 + addrlen:4 +
addrlen])[0] addrlen])[0]
header_length = 4 + addrlen header_length = 4 + addrlen
else: else:
logging.warn('header is too short') logging.warn('header is too short')

120
shadowsocks/eventloop.py

@ -22,6 +22,7 @@ from __future__ import absolute_import, division, print_function, \
with_statement with_statement
import os import os
import time
import socket import socket
import select import select
import errno import errno
@ -51,23 +52,8 @@ EVENT_NAMES = {
POLL_NVAL: 'POLL_NVAL', POLL_NVAL: 'POLL_NVAL',
} }
# we check timeouts every TIMEOUT_PRECISION seconds
class EpollLoop(object): TIMEOUT_PRECISION = 10
def __init__(self):
self._epoll = select.epoll()
def poll(self, timeout):
return self._epoll.poll(timeout)
def add_fd(self, fd, mode):
self._epoll.register(fd, mode)
def remove_fd(self, fd):
self._epoll.unregister(fd)
def modify_fd(self, fd, mode):
self._epoll.modify(fd, mode)
class KqueueLoop(object): class KqueueLoop(object):
@ -100,17 +86,17 @@ class KqueueLoop(object):
results[fd] |= POLL_OUT results[fd] |= POLL_OUT
return results.items() return results.items()
def add_fd(self, fd, mode): def register(self, fd, mode):
self._fds[fd] = mode self._fds[fd] = mode
self._control(fd, mode, select.KQ_EV_ADD) self._control(fd, mode, select.KQ_EV_ADD)
def remove_fd(self, fd): def unregister(self, fd):
self._control(fd, self._fds[fd], select.KQ_EV_DELETE) self._control(fd, self._fds[fd], select.KQ_EV_DELETE)
del self._fds[fd] del self._fds[fd]
def modify_fd(self, fd, mode): def modify(self, fd, mode):
self.remove_fd(fd) self.unregister(fd)
self.add_fd(fd, mode) self.register(fd, mode)
class SelectLoop(object): class SelectLoop(object):
@ -129,7 +115,7 @@ class SelectLoop(object):
results[fd] |= p[1] results[fd] |= p[1]
return results.items() return results.items()
def add_fd(self, fd, mode): def register(self, fd, mode):
if mode & POLL_IN: if mode & POLL_IN:
self._r_list.add(fd) self._r_list.add(fd)
if mode & POLL_OUT: if mode & POLL_OUT:
@ -137,7 +123,7 @@ class SelectLoop(object):
if mode & POLL_ERR: if mode & POLL_ERR:
self._x_list.add(fd) self._x_list.add(fd)
def remove_fd(self, fd): def unregister(self, fd):
if fd in self._r_list: if fd in self._r_list:
self._r_list.remove(fd) self._r_list.remove(fd)
if fd in self._w_list: if fd in self._w_list:
@ -145,16 +131,15 @@ class SelectLoop(object):
if fd in self._x_list: if fd in self._x_list:
self._x_list.remove(fd) self._x_list.remove(fd)
def modify_fd(self, fd, mode): def modify(self, fd, mode):
self.remove_fd(fd) self.unregister(fd)
self.add_fd(fd, mode) self.register(fd, mode)
class EventLoop(object): class EventLoop(object):
def __init__(self): def __init__(self):
self._iterating = False
if hasattr(select, 'epoll'): if hasattr(select, 'epoll'):
self._impl = EpollLoop() self._impl = select.epoll()
model = 'epoll' model = 'epoll'
elif hasattr(select, 'kqueue'): elif hasattr(select, 'kqueue'):
self._impl = KqueueLoop() self._impl = KqueueLoop()
@ -165,72 +150,71 @@ class EventLoop(object):
else: else:
raise Exception('can not find any available functions in select ' raise Exception('can not find any available functions in select '
'package') 'package')
self._fd_to_f = {} self._fdmap = {} # (f, handler)
self._handlers = [] self._last_time = time.time()
self._ref_handlers = [] self._periodic_callbacks = []
self._handlers_to_remove = [] self._stopping = False
logging.debug('using event model: %s', model) logging.debug('using event model: %s', model)
def poll(self, timeout=None): def poll(self, timeout=None):
events = self._impl.poll(timeout) events = self._impl.poll(timeout)
return [(self._fd_to_f[fd], fd, event) for fd, event in events] return [(self._fdmap[fd][0], fd, event) for fd, event in events]
def add(self, f, mode): def add(self, f, mode, handler):
fd = f.fileno() fd = f.fileno()
self._fd_to_f[fd] = f self._fdmap[fd] = (f, handler)
self._impl.add_fd(fd, mode) self._impl.register(fd, mode)
def remove(self, f): def remove(self, f):
fd = f.fileno() fd = f.fileno()
del self._fd_to_f[fd] del self._fdmap[fd]
self._impl.remove_fd(fd) self._impl.unregister(fd)
def add_periodic(self, callback):
self._periodic_callbacks.append(callback)
def remove_periodic(self, callback):
self._periodic_callbacks.remove(callback)
def modify(self, f, mode): def modify(self, f, mode):
fd = f.fileno() fd = f.fileno()
self._impl.modify_fd(fd, mode) self._impl.modify(fd, mode)
def add_handler(self, handler, ref=True): def stop(self):
self._handlers.append(handler) self._stopping = True
if ref:
# when all ref handlers are removed, loop stops
self._ref_handlers.append(handler)
def remove_handler(self, handler):
if handler in self._ref_handlers:
self._ref_handlers.remove(handler)
if self._iterating:
self._handlers_to_remove.append(handler)
else:
self._handlers.remove(handler)
def run(self): def run(self):
events = [] events = []
while self._ref_handlers: while not self._stopping:
asap = False
try: try:
events = self.poll(1) events = self.poll(TIMEOUT_PRECISION)
except (OSError, IOError) as e: except (OSError, IOError) as e:
if errno_from_exception(e) in (errno.EPIPE, errno.EINTR): if errno_from_exception(e) in (errno.EPIPE, errno.EINTR):
# EPIPE: Happens when the client closes the connection # EPIPE: Happens when the client closes the connection
# EINTR: Happens when received a signal # EINTR: Happens when received a signal
# handles them as soon as possible # handles them as soon as possible
asap = True
logging.debug('poll:%s', e) logging.debug('poll:%s', e)
else: else:
logging.error('poll:%s', e) logging.error('poll:%s', e)
import traceback import traceback
traceback.print_exc() traceback.print_exc()
continue continue
self._iterating = True
for handler in self._handlers: for sock, fd, event in events:
# TODO when there are a lot of handlers handler = self._fdmap.get(fd, None)
try: if handler is not None:
handler(events) handler = handler[1]
except (OSError, IOError) as e: try:
shell.print_exception(e) handler.handle_event(sock, fd, event)
if self._handlers_to_remove: except (OSError, IOError) as e:
for handler in self._handlers_to_remove: shell.print_exception(e)
self._handlers.remove(handler) now = time.time()
self._handlers_to_remove = [] if asap or now - self._last_time >= TIMEOUT_PRECISION:
self._iterating = False for callback in self._periodic_callbacks:
callback()
self._last_time = now
# from tornado # from tornado

209
shadowsocks/tcprelay.py

@ -115,6 +115,7 @@ class TCPRelayHandler(object):
self._fastopen_connected = False self._fastopen_connected = False
self._data_to_write_to_local = [] self._data_to_write_to_local = []
self._data_to_write_to_remote = [] self._data_to_write_to_remote = []
self._udp_data_send_buffer = ''
self._upstream_status = WAIT_STATUS_READING self._upstream_status = WAIT_STATUS_READING
self._downstream_status = WAIT_STATUS_INIT self._downstream_status = WAIT_STATUS_INIT
self._client_address = local_sock.getpeername()[:2] self._client_address = local_sock.getpeername()[:2]
@ -128,7 +129,8 @@ class TCPRelayHandler(object):
fd_to_handlers[local_sock.fileno()] = self fd_to_handlers[local_sock.fileno()] = self
local_sock.setblocking(False) local_sock.setblocking(False)
local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR) loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR,
self._server)
self.last_activity = 0 self.last_activity = 0
self._update_activity() self._update_activity()
@ -185,6 +187,8 @@ class TCPRelayHandler(object):
if self._upstream_status & WAIT_STATUS_WRITING: if self._upstream_status & WAIT_STATUS_WRITING:
event |= eventloop.POLL_OUT event |= eventloop.POLL_OUT
self._loop.modify(self._remote_sock, event) self._loop.modify(self._remote_sock, event)
if self._remote_sock_v6:
self._loop.modify(self._remote_sock_v6, event)
def _write_to_sock(self, data, sock): def _write_to_sock(self, data, sock):
# write data to sock # write data to sock
@ -193,51 +197,70 @@ class TCPRelayHandler(object):
if not data or not sock: if not data or not sock:
return False return False
#logging.debug("_write_to_sock %s %s %s" % (self._remote_sock, sock, self._remote_udp)) #logging.debug("_write_to_sock %s %s %s" % (self._remote_sock, sock, self._remote_udp))
if self._remote_udp and self._remote_sock == sock: uncomplete = False
if self._remote_udp and sock == self._remote_sock:
try: try:
frag = common.ord(data[2]) self._udp_data_send_buffer += data
if frag != 0: #logging.info('UDP over TCP sendto %d %s' % (len(data), binascii.hexlify(data)))
logging.warn('drop a message since frag is %d' % (frag,)) while len(self._udp_data_send_buffer) > 6:
return False length = struct.unpack('>H', self._udp_data_send_buffer[:2])[0]
else:
data = data[3:] if length > len(self._udp_data_send_buffer):
header_result = parse_header(data) break
if header_result is None:
return False data = self._udp_data_send_buffer[:length]
connecttype, dest_addr, dest_port, header_length = header_result self._udp_data_send_buffer = self._udp_data_send_buffer[length:]
addrs = socket.getaddrinfo(dest_addr, dest_port, 0,
socket.SOCK_DGRAM, socket.SOL_UDP) frag = common.ord(data[2])
if addrs: if frag != 0:
af, socktype, proto, canonname, server_addr = addrs[0] logging.warn('drop a message since frag is %d' % (frag,))
data = data[header_length:] continue
if af == socket.AF_INET6:
self._remote_sock_v6.sendto(data, (server_addr[0], dest_port))
else: else:
sock.sendto(data, (server_addr[0], dest_port)) data = data[3:]
header_result = parse_header(data)
if header_result is None:
continue
connecttype, dest_addr, dest_port, header_length = header_result
addrs = socket.getaddrinfo(dest_addr, dest_port, 0,
socket.SOCK_DGRAM, socket.SOL_UDP)
#logging.info('UDP over TCP sendto %s:%d %d bytes from %s:%d' % (dest_addr, dest_port, len(data), self._client_address[0], self._client_address[1]))
if addrs:
af, socktype, proto, canonname, server_addr = addrs[0]
data = data[header_length:]
if af == socket.AF_INET6:
self._remote_sock_v6.sendto(data, (server_addr[0], dest_port))
else:
sock.sendto(data, (server_addr[0], dest_port))
except Exception as e: except Exception as e:
#trace = traceback.format_exc() #trace = traceback.format_exc()
#logging.error(trace) #logging.error(trace)
logging.error(e) error_no = eventloop.errno_from_exception(e)
if error_no in (errno.EAGAIN, errno.EINPROGRESS,
errno.EWOULDBLOCK):
uncomplete = True
else:
shell.print_exception(e)
self.destroy()
return False
return True return True
else:
uncomplete = False try:
try: l = len(data)
l = len(data) s = sock.send(data)
s = sock.send(data) if s < l:
if s < l: data = data[s:]
data = data[s:] uncomplete = True
uncomplete = True except (OSError, IOError) as e:
except (OSError, IOError) as e: error_no = eventloop.errno_from_exception(e)
error_no = eventloop.errno_from_exception(e) if error_no in (errno.EAGAIN, errno.EINPROGRESS,
if error_no in (errno.EAGAIN, errno.EINPROGRESS, errno.EWOULDBLOCK):
errno.EWOULDBLOCK): uncomplete = True
uncomplete = True else:
else: #traceback.print_exc()
#traceback.print_exc() shell.print_exception(e)
shell.print_exception(e) self.destroy()
self.destroy() return False
return False
if uncomplete: if uncomplete:
if sock == self._local_sock: if sock == self._local_sock:
self._data_to_write_to_local.append(data) self._data_to_write_to_local.append(data)
@ -270,7 +293,7 @@ class TCPRelayHandler(object):
remote_sock = \ remote_sock = \
self._create_remote_socket(self._chosen_server[0], self._create_remote_socket(self._chosen_server[0],
self._chosen_server[1]) self._chosen_server[1])
self._loop.add(remote_sock, eventloop.POLL_ERR) self._loop.add(remote_sock, eventloop.POLL_ERR, self._server)
data = b''.join(self._data_to_write_to_remote) data = b''.join(self._data_to_write_to_remote)
l = len(data) l = len(data)
s = remote_sock.sendto(data, MSG_FASTOPEN, self._chosen_server) s = remote_sock.sendto(data, MSG_FASTOPEN, self._chosen_server)
@ -382,6 +405,11 @@ class TCPRelayHandler(object):
remote_sock_v6 = socket.socket(af, socktype, proto) remote_sock_v6 = socket.socket(af, socktype, proto)
self._remote_sock_v6 = remote_sock_v6 self._remote_sock_v6 = remote_sock_v6
self._fd_to_handlers[remote_sock_v6.fileno()] = self self._fd_to_handlers[remote_sock_v6.fileno()] = self
remote_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32)
remote_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32)
remote_sock_v6.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32)
remote_sock_v6.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32)
remote_sock.setblocking(False) remote_sock.setblocking(False)
if self._remote_udp: if self._remote_udp:
@ -421,10 +449,12 @@ class TCPRelayHandler(object):
remote_port) remote_port)
if self._remote_udp: if self._remote_udp:
self._loop.add(remote_sock, self._loop.add(remote_sock,
eventloop.POLL_IN) eventloop.POLL_IN,
self._server)
if self._remote_sock_v6: if self._remote_sock_v6:
self._loop.add(self._remote_sock_v6, self._loop.add(self._remote_sock_v6,
eventloop.POLL_IN) eventloop.POLL_IN,
self._server)
else: else:
try: try:
remote_sock.connect((remote_addr, remote_port)) remote_sock.connect((remote_addr, remote_port))
@ -433,10 +463,16 @@ class TCPRelayHandler(object):
errno.EINPROGRESS: errno.EINPROGRESS:
pass pass
self._loop.add(remote_sock, self._loop.add(remote_sock,
eventloop.POLL_ERR | eventloop.POLL_OUT) eventloop.POLL_ERR | eventloop.POLL_OUT,
self._server)
self._stage = STAGE_CONNECTING self._stage = STAGE_CONNECTING
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
if self._remote_udp:
while self._data_to_write_to_remote:
data = self._data_to_write_to_remote[0]
del self._data_to_write_to_remote[0]
self._write_to_sock(data, self._remote_sock)
return return
except Exception as e: except Exception as e:
shell.print_exception(e) shell.print_exception(e)
@ -495,11 +531,12 @@ class TCPRelayHandler(object):
port = struct.pack('>H', addr[1]) port = struct.pack('>H', addr[1])
try: try:
ip = socket.inet_aton(addr[0]) ip = socket.inet_aton(addr[0])
data = '\x00\x00\x00\x01' + ip + port + data data = '\x00\x01' + ip + port + data
except Exception as e: except Exception as e:
ip = socket.inet_pton(socket.AF_INET6, addr[0]) ip = socket.inet_pton(socket.AF_INET6, addr[0])
data = '\x00\x00\x00\x04' + ip + port + data data = '\x00\x04' + ip + port + data
logging.info('UDP recvfrom %s:%d %d bytes to %s:%d' % (addr[0], addr[1], len(data), self._client_address[0], self._client_address[1])) data = struct.pack('>H', len(data) + 2) + data
#logging.info('UDP over TCP recvfrom %s:%d %d bytes to %s:%d' % (addr[0], addr[1], len(data), self._client_address[0], self._client_address[1]))
else: else:
data = self._remote_sock.recv(BUF_SIZE) data = self._remote_sock.recv(BUF_SIZE)
except (OSError, IOError) as e: except (OSError, IOError) as e:
@ -637,7 +674,6 @@ class TCPRelay(object):
self._closed = False self._closed = False
self._eventloop = None self._eventloop = None
self._fd_to_handlers = {} self._fd_to_handlers = {}
self._last_time = time.time()
self.server_transfer_ul = 0L self.server_transfer_ul = 0L
self.server_transfer_dl = 0L self.server_transfer_dl = 0L
@ -680,10 +716,9 @@ class TCPRelay(object):
if self._closed: if self._closed:
raise Exception('already closed') raise Exception('already closed')
self._eventloop = loop self._eventloop = loop
loop.add_handler(self._handle_events)
self._eventloop.add(self._server_socket, self._eventloop.add(self._server_socket,
eventloop.POLL_IN | eventloop.POLL_ERR) eventloop.POLL_IN | eventloop.POLL_ERR, self)
self._eventloop.add_periodic(self.handle_periodic)
def remove_handler(self, handler): def remove_handler(self, handler):
index = self._handler_to_timeouts.get(hash(handler), -1) index = self._handler_to_timeouts.get(hash(handler), -1)
@ -695,7 +730,7 @@ class TCPRelay(object):
def update_activity(self, handler): def update_activity(self, handler):
# set handler to active # set handler to active
now = int(time.time()) now = int(time.time())
if now - handler.last_activity < TIMEOUT_PRECISION: if now - handler.last_activity < eventloop.TIMEOUT_PRECISION:
# thus we can lower timeout modification frequency # thus we can lower timeout modification frequency
return return
handler.last_activity = now handler.last_activity = now
@ -741,53 +776,55 @@ class TCPRelay(object):
pos = 0 pos = 0
self._timeout_offset = pos self._timeout_offset = pos
def _handle_events(self, events): def handle_event(self, sock, fd, event):
# handle events and dispatch to handlers # handle events and dispatch to handlers
for sock, fd, event in events: if sock:
logging.log(shell.VERBOSE_LEVEL, 'fd %d %s', fd,
eventloop.EVENT_NAMES.get(event, event))
if sock == self._server_socket:
if event & eventloop.POLL_ERR:
# TODO
raise Exception('server_socket error')
try:
logging.debug('accept')
conn = self._server_socket.accept()
TCPRelayHandler(self, self._fd_to_handlers,
self._eventloop, conn[0], self._config,
self._dns_resolver, self._is_local)
except (OSError, IOError) as e:
error_no = eventloop.errno_from_exception(e)
if error_no in (errno.EAGAIN, errno.EINPROGRESS,
errno.EWOULDBLOCK):
return
else:
shell.print_exception(e)
if self._config['verbose']:
traceback.print_exc()
else:
if sock: if sock:
logging.log(shell.VERBOSE_LEVEL, 'fd %d %s', fd, handler = self._fd_to_handlers.get(fd, None)
eventloop.EVENT_NAMES.get(event, event)) if handler:
if sock == self._server_socket: handler.handle_event(sock, event)
if event & eventloop.POLL_ERR:
# TODO
raise Exception('server_socket error')
try:
logging.debug('accept')
conn = self._server_socket.accept()
TCPRelayHandler(self, self._fd_to_handlers,
self._eventloop, conn[0], self._config,
self._dns_resolver, self._is_local)
except (OSError, IOError) as e:
error_no = eventloop.errno_from_exception(e)
if error_no in (errno.EAGAIN, errno.EINPROGRESS,
errno.EWOULDBLOCK):
continue
else:
shell.print_exception(e)
if self._config['verbose']:
traceback.print_exc()
else: else:
if sock: logging.warn('poll removed fd')
handler = self._fd_to_handlers.get(fd, None)
if handler:
handler.handle_event(sock, event)
else:
logging.warn('poll removed fd')
now = time.time() def handle_periodic(self):
if now - self._last_time > TIMEOUT_PRECISION:
self._sweep_timeout()
self._last_time = now
if self._closed: if self._closed:
if self._server_socket: if self._server_socket:
self._eventloop.remove(self._server_socket) self._eventloop.remove(self._server_socket)
self._server_socket.close() self._server_socket.close()
self._server_socket = None self._server_socket = None
logging.info('closed listen port %d', self._listen_port) logging.info('closed TCP port %d', self._listen_port)
if not self._fd_to_handlers: if not self._fd_to_handlers:
self._eventloop.remove_handler(self._handle_events) logging.info('stopping')
self._eventloop.stop()
self._sweep_timeout()
def close(self, next_tick=False): def close(self, next_tick=False):
logging.debug('TCP close')
self._closed = True self._closed = True
if not next_tick: if not next_tick:
if self._eventloop:
self._eventloop.remove_periodic(self.handle_periodic)
self._eventloop.remove(self._server_socket)
self._server_socket.close() self._server_socket.close()

1010
shadowsocks/udprelay.py

File diff suppressed because it is too large
Loading…
Cancel
Save