Python port of ShadowsocksR
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

870 lines
36 KiB

10 years ago
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright 2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
10 years ago
#
# http://www.apache.org/licenses/LICENSE-2.0
10 years ago
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
10 years ago
from __future__ import absolute_import, division, print_function, \
with_statement
10 years ago
import time
import socket
10 years ago
import errno
import struct
10 years ago
import logging
import binascii
10 years ago
import traceback
10 years ago
import random
from shadowsocks import encrypt, eventloop, shell, common
from shadowsocks.common import pre_parse_header, parse_header
10 years ago
# set it 'False' to use both new protocol and the original shadowsocks protocol
# set it 'True' to use new protocol ONLY, to avoid GFW detecting
# (when True, an address header whose first byte is not 0x88 is treated as a
# protocol error)
FORCE_NEW_PROTOCOL = False

# the timeout queue is compacted once more than TIMEOUTS_CLEAN_SIZE leading
# entries (and more than half of the queue) have been swept
TIMEOUTS_CLEAN_SIZE = 512

# sendto() flag that sends data together with the SYN (TCP Fast Open);
# this is Linux's MSG_FASTOPEN value
MSG_FASTOPEN = 0x20000000

# SOCKS command definition
CMD_CONNECT = 1
CMD_BIND = 2
CMD_UDP_ASSOCIATE = 3

# for each opening port, we have a TCP Relay

# for each connection, we have a TCP Relay Handler to handle the connection
# for each handler, we have 2 sockets:
#    local:   connected to the client
#    remote:  connected to remote server
# for each handler, it could be at one of several stages:

# as sslocal:
# stage 0 SOCKS hello received from local, send hello to local
# stage 1 addr received from local, query DNS for remote
# stage 2 UDP assoc
# stage 3 DNS resolved, connect to remote
# stage 4 still connecting, more data from local received
# stage 5 remote connected, piping local and remote

# as ssserver:
# stage 0 just jump to stage 1
# stage 1 addr received from local, query DNS for remote
# stage 3 DNS resolved, connect to remote
# stage 4 still connecting, more data from local received
# stage 5 remote connected, piping local and remote

STAGE_INIT = 0
STAGE_ADDR = 1
STAGE_UDP_ASSOC = 2
STAGE_DNS = 3
STAGE_CONNECTING = 4
STAGE_STREAM = 5
STAGE_DESTROYED = -1

# for each handler, we have 2 stream directions:
#    upstream:    from client to server direction
#                 read local and write to remote
#    downstream:  from server to client direction
#                 read remote and write to local
STREAM_UP = 0
STREAM_DOWN = 1

# for each stream, it's waiting for reading, or writing, or both
WAIT_STATUS_INIT = 0
WAIT_STATUS_READING = 1
WAIT_STATUS_WRITING = 2
WAIT_STATUS_READWRITING = WAIT_STATUS_READING | WAIT_STATUS_WRITING

# maximum number of bytes read from a socket in one recv()/recvfrom()
BUF_SIZE = 32 * 1024
class TCPRelayHandler(object):
    """Relay one accepted TCP connection between a local and a remote socket.

    An instance owns the accepted ``local_sock`` and, once the destination
    is known, a ``remote_sock`` (plus ``remote_sock_v6`` when relaying UDP
    over TCP).  It walks through the STAGE_* states and registers every
    socket it owns in ``fd_to_handlers`` so that the owning TCPRelay can
    dispatch poll events to it.
    """

    def __init__(self, server, fd_to_handlers, loop, local_sock, config,
                 dns_resolver, is_local):
        self._server = server
        self._fd_to_handlers = fd_to_handlers
        self._loop = loop
        self._local_sock = local_sock
        self._remote_sock = None
        # second UDP socket, used for IPv6 destinations in UDP-over-TCP mode
        self._remote_sock_v6 = None
        self._remote_udp = False
        self._config = config
        self._dns_resolver = dns_resolver

        # TCP Relay works as either sslocal or ssserver
        # if is_local, this is sslocal
        self._is_local = is_local
        self._stage = STAGE_INIT
        self._encryptor = encrypt.Encryptor(config['password'],
                                            config['method'])
        # cleared by _handel_protocol_error; afterwards data is relayed
        # without decryption/encryption
        self._encrypt_correct = True
        self._fastopen_connected = False
        self._data_to_write_to_local = []
        self._data_to_write_to_remote = []
        # reassembly buffer for length-prefixed UDP-over-TCP frames
        self._udp_data_send_buffer = b''
        self._upstream_status = WAIT_STATUS_READING
        self._downstream_status = WAIT_STATUS_INIT
        self._client_address = local_sock.getpeername()[:2]
        self._remote_address = None
        if 'forbidden_ip' in config:
            self._forbidden_iplist = config['forbidden_ip']
        else:
            self._forbidden_iplist = None
        if is_local:
            self._chosen_server = self._get_a_server()
        fd_to_handlers[local_sock.fileno()] = self
        local_sock.setblocking(False)
        local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
        loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR,
                 self._server)
        self.last_activity = 0
        self._update_activity()

    def __hash__(self):
        # default __hash__ is id / 16
        # we want to eliminate collisions
        return id(self)

    @property
    def remote_address(self):
        """(host, port) of the requested destination, or None if unknown."""
        return self._remote_address

    def _get_a_server(self):
        """Pick the upstream (server, port) used by sslocal.

        ``config['server']`` and ``config['server_port']`` may each be a
        single value or a list; for a list one entry is chosen at random.
        """
        server = self._config['server']
        server_port = self._config['server_port']
        if isinstance(server_port, list):
            server_port = random.choice(server_port)
        if isinstance(server, list):
            server = random.choice(server)
        logging.debug('chosen server: %s:%d', server, server_port)
        return server, server_port

    def _update_activity(self, data_len=0):
        # tell the TCP Relay we have activities recently
        # else it will think we are inactive and timed out
        self._server.update_activity(self, data_len)

    def _update_stream(self, stream, status):
        """Set ``stream``'s wait status; on change, re-register poll masks.

        Local socket: POLL_OUT while downstream is writing, POLL_IN while
        upstream is reading.  Remote socket(s): POLL_IN while downstream is
        reading, POLL_OUT while upstream is writing.  POLL_ERR is always set.
        """
        dirty = False
        if stream == STREAM_DOWN:
            if self._downstream_status != status:
                self._downstream_status = status
                dirty = True
        elif stream == STREAM_UP:
            if self._upstream_status != status:
                self._upstream_status = status
                dirty = True
        if not dirty:
            return
        if self._local_sock:
            event = eventloop.POLL_ERR
            if self._downstream_status & WAIT_STATUS_WRITING:
                event |= eventloop.POLL_OUT
            if self._upstream_status & WAIT_STATUS_READING:
                event |= eventloop.POLL_IN
            self._loop.modify(self._local_sock, event)
        if self._remote_sock:
            event = eventloop.POLL_ERR
            if self._downstream_status & WAIT_STATUS_READING:
                event |= eventloop.POLL_IN
            if self._upstream_status & WAIT_STATUS_WRITING:
                event |= eventloop.POLL_OUT
            self._loop.modify(self._remote_sock, event)
            if self._remote_sock_v6:
                self._loop.modify(self._remote_sock_v6, event)

    def _write_to_sock(self, data, sock):
        """Write ``data`` to ``sock``, buffering whatever cannot be sent now.

        If only part of the data is written, the rest is queued and the
        matching stream switches to waiting for writability.  When the
        remote side is a UDP relay and ``sock`` is the remote socket, the
        data is treated as UDP-over-TCP frames (see _send_udp_over_tcp).

        Returns False when ``data`` or ``sock`` is empty or the handler was
        destroyed because of a write error, True otherwise.
        """
        if not data or not sock:
            return False
        if self._remote_udp and sock == self._remote_sock:
            return self._send_udp_over_tcp(data, sock)
        uncomplete = False
        try:
            total = len(data)
            sent = sock.send(data)
            if sent < total:
                data = data[sent:]
                uncomplete = True
        except (OSError, IOError) as e:
            error_no = eventloop.errno_from_exception(e)
            if error_no in (errno.EAGAIN, errno.EINPROGRESS,
                            errno.EWOULDBLOCK):
                uncomplete = True
            else:
                shell.print_exception(e)
                self.destroy()
                return False
        if uncomplete:
            if sock == self._local_sock:
                self._data_to_write_to_local.append(data)
                self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING)
            elif sock == self._remote_sock:
                self._data_to_write_to_remote.append(data)
                self._update_stream(STREAM_UP, WAIT_STATUS_WRITING)
            else:
                logging.error('write_all_to_sock:unknown socket')
        else:
            if sock == self._local_sock:
                self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
            elif sock == self._remote_sock:
                self._update_stream(STREAM_UP, WAIT_STATUS_READING)
            else:
                logging.error('write_all_to_sock:unknown socket')
        return True

    def _send_udp_over_tcp(self, data, sock):
        """Reassemble UDP-over-TCP frames and send each one as a datagram.

        Frame layout as parsed here: a 2-byte big-endian length counting
        the whole frame (including the length itself), one FRAG byte, then
        an address header understood by parse_header and the payload.
        Incomplete frames stay buffered until more data arrives.  IPv6
        destinations are sent from ``self._remote_sock_v6``, everything
        else from ``sock``.  Fragmented frames, unparseable headers and
        forbidden destinations are dropped.

        Returns False if the handler was destroyed, True otherwise
        (a datagram hitting EAGAIN is simply dropped, as UDP would).
        """
        try:
            self._udp_data_send_buffer += data
            while len(self._udp_data_send_buffer) > 6:
                length = struct.unpack('>H',
                                       self._udp_data_send_buffer[:2])[0]
                if length < 3:
                    # A frame must at least hold its length and FRAG bytes.
                    # A shorter declared length (notably 0) would never
                    # consume input and would spin this loop forever.
                    logging.error('UDP over TCP: invalid frame length %d' %
                                  (length,))
                    self.destroy()
                    return False
                if length > len(self._udp_data_send_buffer):
                    break  # wait for the rest of the frame
                frame = self._udp_data_send_buffer[:length]
                self._udp_data_send_buffer = \
                    self._udp_data_send_buffer[length:]
                frag = common.ord(frame[2])
                if frag != 0:
                    logging.warning('drop a message since frag is %d' %
                                    (frag,))
                    continue
                frame = frame[3:]
                header_result = parse_header(frame)
                if header_result is None:
                    continue
                connecttype, dest_addr, dest_port, header_length = \
                    header_result
                addrs = socket.getaddrinfo(dest_addr, dest_port, 0,
                                           socket.SOCK_DGRAM, socket.SOL_UDP)
                if not addrs:
                    continue
                af, socktype, proto, canonname, server_addr = addrs[0]
                # apply the same forbidden_ip policy as TCP connections
                if self._forbidden_iplist and \
                        common.to_str(server_addr[0]) in \
                        self._forbidden_iplist:
                    logging.warning('IP %s is in forbidden list, drop' %
                                    common.to_str(server_addr[0]))
                    continue
                payload = frame[header_length:]
                if af == socket.AF_INET6:
                    self._remote_sock_v6.sendto(payload,
                                                (server_addr[0], dest_port))
                else:
                    sock.sendto(payload, (server_addr[0], dest_port))
        except Exception as e:
            error_no = eventloop.errno_from_exception(e)
            if error_no not in (errno.EAGAIN, errno.EINPROGRESS,
                                errno.EWOULDBLOCK):
                shell.print_exception(e)
                self.destroy()
                return False
        return True

    def _get_redirect_host(self, client_address, ogn_data):
        """Deterministically choose a decoy (host, port) for a bad client.

        The index mixes the CRC32 of ``ogn_data`` with the client address:
        its low 64 bits for IPv6, the whole 32 bits for IPv4.
        """
        # hard-coded decoy destinations (marked as test data upstream)
        host_list = [("www.bing.com", 80), ("www.microsoft.com", 80),
                     ("www.baidu.com", 443), ("www.qq.com", 80),
                     ("www.csdn.net", 80), ("1.2.3.4", 1000)]
        hash_code = binascii.crc32(ogn_data)
        addrs = socket.getaddrinfo(client_address[0], client_address[1], 0,
                                   socket.SOCK_STREAM, socket.SOL_TCP)
        af, socktype, proto, canonname, sa = addrs[0]
        address_bytes = common.inet_pton(af, sa[0])
        if len(address_bytes) == 16:
            addr = struct.unpack('>Q', address_bytes[8:])[0]
        elif len(address_bytes) == 4:
            addr = struct.unpack('>I', address_bytes)[0]
        else:
            addr = 0
        index = ((hash_code & 0xffffffff) + addr + 3) % len(host_list)
        return host_list[index]

    def _handel_protocol_error(self, client_address, ogn_data):
        """Handle a request whose header cannot be parsed.

        Logs the raw data and clears ``_encrypt_correct``, so later traffic
        on this connection is relayed without en/decryption.  Returns a
        synthetic address header for a decoy host (type byte 3, presumably
        the domain-name form parsed by parse_header) followed by
        ``ogn_data``, so the raw data is relayed there instead of the
        connection being dropped.
        """
        logging.warning("Protocol ERROR, TCP ogn data %s" %
                        (binascii.hexlify(ogn_data), ))
        self._encrypt_correct = False
        # create redirect or disconnect by hash code
        host, port = self._get_redirect_host(client_address, ogn_data)
        data = (b"\x03" + struct.pack('>B', len(host)) +
                common.to_bytes(host) + struct.pack('>H', port))
        logging.warning("TCP data redir %s:%d %s" %
                        (host, port, binascii.hexlify(data)))
        return data + ogn_data

    def _handle_stage_connecting(self, data):
        """Queue data that arrives while the remote side is being set up.

        For sslocal with fast_open, the first call creates the remote socket
        and sends everything queued so far with sendto(MSG_FASTOPEN).
        """
        if self._is_local:
            data = self._encryptor.encrypt(data)
        self._data_to_write_to_remote.append(data)
        if self._is_local and not self._fastopen_connected and \
                self._config['fast_open']:
            # for sslocal and fastopen, we basically wait for data and use
            # sendto to connect
            try:
                # only connect once
                self._fastopen_connected = True
                remote_sock = \
                    self._create_remote_socket(self._chosen_server[0],
                                               self._chosen_server[1])
                self._loop.add(remote_sock, eventloop.POLL_ERR, self._server)
                data = b''.join(self._data_to_write_to_remote)
                total = len(data)
                sent = remote_sock.sendto(data, MSG_FASTOPEN,
                                          self._chosen_server)
                if sent < total:
                    self._data_to_write_to_remote = [data[sent:]]
                else:
                    self._data_to_write_to_remote = []
                self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
            except (OSError, IOError) as e:
                error_no = eventloop.errno_from_exception(e)
                if error_no == errno.EINPROGRESS:
                    # in this case data is not sent at all
                    self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
                elif error_no == errno.ENOTCONN:
                    logging.error('fast open not supported on this OS')
                    self._config['fast_open'] = False
                    self.destroy()
                else:
                    shell.print_exception(e)
                    if self._config['verbose']:
                        traceback.print_exc()
                    self.destroy()

    def _handle_stage_addr(self, ogn_data, data):
        """Parse the destination header and start resolving the next hop.

        On sslocal the SOCKS5 VER/CMD/RSV bytes are handled first (UDP
        ASSOCIATE is answered directly).  Unparseable input is replaced by
        _handel_protocol_error's decoy header.  sslocal then resolves the
        chosen upstream server, ssserver resolves the requested destination.
        Any exception destroys the handler.
        """
        try:
            if self._is_local:
                cmd = common.ord(data[1])
                if cmd == CMD_UDP_ASSOCIATE:
                    logging.debug('UDP associate')
                    if self._local_sock.family == socket.AF_INET6:
                        header = b'\x05\x00\x00\x04'
                    else:
                        header = b'\x05\x00\x00\x01'
                    addr, port = self._local_sock.getsockname()[:2]
                    addr_to_send = socket.inet_pton(self._local_sock.family,
                                                    addr)
                    port_to_send = struct.pack('>H', port)
                    self._write_to_sock(header + addr_to_send + port_to_send,
                                        self._local_sock)
                    self._stage = STAGE_UDP_ASSOC
                    # just wait for the client to disconnect
                    return
                elif cmd == CMD_CONNECT:
                    # just trim VER CMD RSV
                    data = data[3:]
                else:
                    logging.error('unknown command %d', cmd)
                    self.destroy()
                    return
            # common.ord works for both Python 2 str and Python 3 bytes
            if FORCE_NEW_PROTOCOL and common.ord(data[0]) != 0x88:
                data = self._handel_protocol_error(self._client_address,
                                                   ogn_data)
            data = pre_parse_header(data)
            if data is None:
                data = self._handel_protocol_error(self._client_address,
                                                   ogn_data)
            header_result = parse_header(data)
            if header_result is None:
                data = self._handel_protocol_error(self._client_address,
                                                   ogn_data)
                header_result = parse_header(data)
            connecttype, remote_addr, remote_port, header_length = \
                header_result
            logging.info('%s connecting %s:%d from %s:%d' %
                         ('TCP' if connecttype == 0 else 'UDP',
                          common.to_str(remote_addr), remote_port,
                          self._client_address[0], self._client_address[1]))
            self._remote_address = (common.to_str(remote_addr), remote_port)
            # connecttype 0 is plain TCP; anything else is UDP over TCP
            self._remote_udp = (connecttype != 0)
            # pause reading
            self._update_stream(STREAM_UP, WAIT_STATUS_WRITING)
            self._stage = STAGE_DNS
            if self._is_local:
                # forward address to remote
                self._write_to_sock((b'\x05\x00\x00\x01'
                                     b'\x00\x00\x00\x00\x10\x10'),
                                    self._local_sock)
                data_to_send = self._encryptor.encrypt(data)
                self._data_to_write_to_remote.append(data_to_send)
                # notice here may go into _handle_dns_resolved directly
                self._dns_resolver.resolve(self._chosen_server[0],
                                           self._handle_dns_resolved)
            else:
                if len(data) > header_length:
                    self._data_to_write_to_remote.append(data[header_length:])
                # notice here may go into _handle_dns_resolved directly
                self._dns_resolver.resolve(remote_addr,
                                           self._handle_dns_resolved)
        except Exception as e:
            self._log_error(e)
            if self._config['verbose']:
                traceback.print_exc()
            self.destroy()

    def _create_remote_socket(self, ip, port):
        """Create and register the non-blocking remote socket(s).

        In UDP-over-TCP mode this creates an unbound IPv4 UDP socket plus an
        IPv6 UDP socket (``self._remote_sock_v6``); otherwise a TCP socket
        for (ip, port) with TCP_NODELAY.  Raises if getaddrinfo returns
        nothing or the address is in the forbidden list.
        """
        if self._remote_udp:
            addrs_v6 = socket.getaddrinfo("::", 0, 0, socket.SOCK_DGRAM,
                                          socket.SOL_UDP)
            addrs = socket.getaddrinfo("0.0.0.0", 0, 0, socket.SOCK_DGRAM,
                                       socket.SOL_UDP)
        else:
            addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM,
                                       socket.SOL_TCP)
        if not addrs:
            raise Exception("getaddrinfo failed for %s:%d" % (ip, port))
        af, socktype, proto, canonname, sa = addrs[0]
        if self._forbidden_iplist:
            if common.to_str(sa[0]) in self._forbidden_iplist:
                raise Exception('IP %s is in forbidden list, reject' %
                                common.to_str(sa[0]))
        remote_sock = socket.socket(af, socktype, proto)
        self._remote_sock = remote_sock
        self._fd_to_handlers[remote_sock.fileno()] = self
        if self._remote_udp:
            af, socktype, proto, canonname, sa = addrs_v6[0]
            remote_sock_v6 = socket.socket(af, socktype, proto)
            self._remote_sock_v6 = remote_sock_v6
            self._fd_to_handlers[remote_sock_v6.fileno()] = self
            for udp_sock in (remote_sock, remote_sock_v6):
                udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF,
                                    1024 * 32)
                udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF,
                                    1024 * 32)
            # the IPv6 socket is polled by the event loop as well, so it
            # must be non-blocking just like remote_sock
            remote_sock_v6.setblocking(False)
        remote_sock.setblocking(False)
        if not self._remote_udp:
            remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
        return remote_sock

    def _handle_dns_resolved(self, result, error):
        """DNS callback: create/connect the remote socket for ``result[1]``.

        Destroys the handler on a resolver error, an empty result, or any
        exception while setting up the remote side.
        """
        if error:
            self._log_error(error)
            self.destroy()
            return
        if result:
            ip = result[1]
            if ip:
                try:
                    self._stage = STAGE_CONNECTING
                    remote_addr = ip
                    if self._is_local:
                        remote_port = self._chosen_server[1]
                    else:
                        remote_port = self._remote_address[1]
                    if self._is_local and self._config['fast_open']:
                        # for fastopen:
                        # wait for more data to arrive and send them in one
                        # SYN
                        self._stage = STAGE_CONNECTING
                        # we don't have to wait for remote since it's not
                        # created
                        self._update_stream(STREAM_UP, WAIT_STATUS_READING)
                        # TODO when there is already data in this packet
                    else:
                        # else do connect
                        remote_sock = self._create_remote_socket(remote_addr,
                                                                 remote_port)
                        if self._remote_udp:
                            self._loop.add(remote_sock, eventloop.POLL_IN,
                                           self._server)
                            if self._remote_sock_v6:
                                self._loop.add(self._remote_sock_v6,
                                               eventloop.POLL_IN,
                                               self._server)
                        else:
                            try:
                                remote_sock.connect((remote_addr,
                                                     remote_port))
                            except (OSError, IOError) as e:
                                # EINPROGRESS is expected for a non-blocking
                                # connect; other errors surface later as
                                # POLL_ERR on the socket
                                if eventloop.errno_from_exception(e) == \
                                        errno.EINPROGRESS:
                                    pass
                            self._loop.add(remote_sock,
                                           eventloop.POLL_ERR |
                                           eventloop.POLL_OUT,
                                           self._server)
                        self._stage = STAGE_CONNECTING
                        self._update_stream(STREAM_UP,
                                            WAIT_STATUS_READWRITING)
                        self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
                        if self._remote_udp:
                            # UDP needs no handshake: flush queued frames now
                            while self._data_to_write_to_remote:
                                data = self._data_to_write_to_remote.pop(0)
                                self._write_to_sock(data, self._remote_sock)
                    return
                except Exception as e:
                    shell.print_exception(e)
                    if self._config['verbose']:
                        traceback.print_exc()
        self.destroy()

    def _on_local_read(self):
        """Read from the client socket and dispatch the data by stage.

        ssserver decrypts first (unless _encrypt_correct was cleared);
        sslocal encrypts when forwarding in STAGE_STREAM/STAGE_CONNECTING.
        """
        if not self._local_sock:
            return
        is_local = self._is_local
        data = None
        try:
            data = self._local_sock.recv(BUF_SIZE)
        except (OSError, IOError) as e:
            if eventloop.errno_from_exception(e) in \
                    (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK):
                return
        if not data:
            self.destroy()
            return
        ogn_data = data
        self._update_activity(len(data))
        if not is_local:
            if self._encrypt_correct:
                data = self._encryptor.decrypt(data)
            if not data:
                return
        self._server.server_transfer_ul += len(data)
        if self._stage == STAGE_STREAM:
            if self._is_local:
                data = self._encryptor.encrypt(data)
            self._write_to_sock(data, self._remote_sock)
            return
        elif is_local and self._stage == STAGE_INIT:
            # TODO check auth method
            self._write_to_sock(b'\x05\00', self._local_sock)
            self._stage = STAGE_ADDR
            return
        elif self._stage == STAGE_CONNECTING:
            self._handle_stage_connecting(data)
        elif (is_local and self._stage == STAGE_ADDR) or \
                (not is_local and self._stage == STAGE_INIT):
            self._handle_stage_addr(ogn_data, data)

    def _on_remote_read(self, is_remote_sock):
        """Read from the remote side and forward it to the client.

        In UDP-over-TCP mode the datagram is wrapped as
        [length][0x00][atyp 0x01/0x04][ip][port][payload], where length
        counts the whole frame including its own 2 bytes.
        ``is_remote_sock`` selects the IPv4 socket (True) or IPv6 socket.
        """
        data = None
        try:
            if self._remote_udp:
                if is_remote_sock:
                    data, addr = self._remote_sock.recvfrom(BUF_SIZE)
                else:
                    data, addr = self._remote_sock_v6.recvfrom(BUF_SIZE)
                port = struct.pack('>H', addr[1])
                try:
                    ip = socket.inet_aton(addr[0])
                    data = b'\x00\x01' + ip + port + data
                except Exception:
                    ip = socket.inet_pton(socket.AF_INET6, addr[0])
                    data = b'\x00\x04' + ip + port + data
                data = struct.pack('>H', len(data) + 2) + data
            else:
                data = self._remote_sock.recv(BUF_SIZE)
        except (OSError, IOError) as e:
            # 10035 is WSAEWOULDBLOCK on Windows
            if eventloop.errno_from_exception(e) in \
                    (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK,
                     10035):
                return
        if not data:
            self.destroy()
            return
        self._server.server_transfer_dl += len(data)
        self._update_activity(len(data))
        if self._is_local:
            data = self._encryptor.decrypt(data)
        elif self._encrypt_correct:
            data = self._encryptor.encrypt(data)
        try:
            self._write_to_sock(data, self._local_sock)
        except Exception as e:
            shell.print_exception(e)
            if self._config['verbose']:
                traceback.print_exc()
            # TODO use logging when debug completed
            self.destroy()

    def _on_local_write(self):
        """Flush data queued for the client, or resume reading if none."""
        if self._data_to_write_to_local:
            data = b''.join(self._data_to_write_to_local)
            self._data_to_write_to_local = []
            self._write_to_sock(data, self._local_sock)
        else:
            self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)

    def _on_remote_write(self):
        """Enter STAGE_STREAM and flush data queued for the remote side."""
        self._stage = STAGE_STREAM
        if self._data_to_write_to_remote:
            data = b''.join(self._data_to_write_to_remote)
            self._data_to_write_to_remote = []
            self._write_to_sock(data, self._remote_sock)
        else:
            self._update_stream(STREAM_UP, WAIT_STATUS_READING)

    def _on_local_error(self):
        logging.debug('got local error')
        if self._local_sock:
            logging.error(eventloop.get_sock_error(self._local_sock))
        self.destroy()

    def _on_remote_error(self):
        logging.debug('got remote error')
        if self._remote_sock:
            logging.error(eventloop.get_sock_error(self._remote_sock))
        self.destroy()

    def handle_event(self, sock, event):
        """Dispatch a poll ``event`` on one of this handler's sockets.

        Errors are handled before reads and reads before writes; processing
        stops as soon as the handler has been destroyed.
        """
        if self._stage == STAGE_DESTROYED:
            logging.debug('ignore handle_event: destroyed')
            return
        # order is important
        if sock == self._remote_sock or sock == self._remote_sock_v6:
            if event & eventloop.POLL_ERR:
                self._on_remote_error()
                if self._stage == STAGE_DESTROYED:
                    return
            if event & (eventloop.POLL_IN | eventloop.POLL_HUP):
                self._on_remote_read(sock == self._remote_sock)
                if self._stage == STAGE_DESTROYED:
                    return
            if event & eventloop.POLL_OUT:
                self._on_remote_write()
        elif sock == self._local_sock:
            if event & eventloop.POLL_ERR:
                self._on_local_error()
                if self._stage == STAGE_DESTROYED:
                    return
            if event & (eventloop.POLL_IN | eventloop.POLL_HUP):
                self._on_local_read()
                if self._stage == STAGE_DESTROYED:
                    return
            if event & eventloop.POLL_OUT:
                self._on_local_write()
        else:
            logging.warning('unknown socket')

    def _log_error(self, e):
        logging.error('%s when handling connection from %s:%d' %
                      (e, self._client_address[0], self._client_address[1]))

    def destroy(self):
        """Close every socket of this handler and unregister it.

        promises:
        1. destroy won't make another destroy() call inside
        2. destroy releases resources so it prevents future call to destroy
        3. destroy won't raise any exceptions
        if any of the promises are broken, it indicates a bug has been
        introduced! mostly likely memory leaks, etc
        """
        if self._stage == STAGE_DESTROYED:
            # this couldn't happen
            logging.debug('already destroyed')
            return
        self._stage = STAGE_DESTROYED
        if self._remote_address:
            logging.debug('destroy: %s:%d' %
                          self._remote_address)
        else:
            logging.debug('destroy')
        if self._remote_sock:
            logging.debug('destroying remote')
            self._loop.remove(self._remote_sock)
            del self._fd_to_handlers[self._remote_sock.fileno()]
            self._remote_sock.close()
            self._remote_sock = None
        if self._remote_sock_v6:
            logging.debug('destroying remote')
            self._loop.remove(self._remote_sock_v6)
            del self._fd_to_handlers[self._remote_sock_v6.fileno()]
            self._remote_sock_v6.close()
            self._remote_sock_v6 = None
        if self._local_sock:
            logging.debug('destroying local')
            self._loop.remove(self._local_sock)
            del self._fd_to_handlers[self._local_sock.fileno()]
            self._local_sock.close()
            self._local_sock = None
        self._dns_resolver.remove_callback(self._handle_dns_resolved)
        self._server.remove_handler(self)
class TCPRelay(object):
    """A listening TCP port that spawns one TCPRelayHandler per connection.

    It also tracks per-handler inactivity timeouts and the byte counters
    ``server_transfer_ul``/``server_transfer_dl`` that handlers increment.
    """

    def __init__(self, config, dns_resolver, is_local, stat_callback=None):
        self._config = config
        self._is_local = is_local
        self._dns_resolver = dns_resolver
        self._closed = False
        self._eventloop = None
        self._fd_to_handlers = {}
        # plain ints instead of the Python-2-only ``0L`` literal; Python 2
        # promotes int to long automatically, so behavior is unchanged
        self.server_transfer_ul = 0
        self.server_transfer_dl = 0

        self._timeout = config['timeout']
        self._timeouts = []  # a list for all the handlers
        # we trim the timeouts once a while
        self._timeout_offset = 0  # last checked position for timeout
        self._handler_to_timeouts = {}  # key: handler value: index in timeouts

        if is_local:
            listen_addr = config['local_address']
            listen_port = config['local_port']
        else:
            listen_addr = config['server']
            listen_port = config['server_port']
        self._listen_port = listen_port

        addrs = socket.getaddrinfo(listen_addr, listen_port, 0,
                                   socket.SOCK_STREAM, socket.SOL_TCP)
        if len(addrs) == 0:
            raise Exception("can't get addrinfo for %s:%d" %
                            (listen_addr, listen_port))
        af, socktype, proto, canonname, sa = addrs[0]
        server_socket = socket.socket(af, socktype, proto)
        try:
            server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
                                     1)
            server_socket.bind(sa)
        except (OSError, IOError):
            # don't leak the descriptor when the port cannot be bound
            server_socket.close()
            raise
        server_socket.setblocking(False)
        if config['fast_open']:
            try:
                # 23 is TCP_FASTOPEN on Linux; the value 5 is the queue
                # length for pending fast-open requests
                server_socket.setsockopt(socket.SOL_TCP, 23, 5)
            except socket.error:
                logging.error('warning: fast open is not available')
                self._config['fast_open'] = False
        server_socket.listen(1024)
        self._server_socket = server_socket
        self._stat_callback = stat_callback

    def add_to_loop(self, loop):
        """Register the listening socket and the periodic sweep with loop."""
        if self._eventloop:
            raise Exception('already add to loop')
        if self._closed:
            raise Exception('already closed')
        self._eventloop = loop
        self._eventloop.add(self._server_socket,
                            eventloop.POLL_IN | eventloop.POLL_ERR, self)
        self._eventloop.add_periodic(self.handle_periodic)

    def remove_handler(self, handler):
        """Forget ``handler``'s timeout slot (no-op if it has none)."""
        index = self._handler_to_timeouts.get(hash(handler), -1)
        if index >= 0:
            # delete is O(n), so we just set it to None
            self._timeouts[index] = None
            del self._handler_to_timeouts[hash(handler)]

    def update_activity(self, handler, data_len):
        """Report traffic and move ``handler`` to the end of the queue.

        The queue is only touched when at least TIMEOUT_PRECISION seconds
        have passed since the handler's previous recorded activity.
        """
        if data_len and self._stat_callback:
            self._stat_callback(self._listen_port, data_len)

        # set handler to active
        now = int(time.time())
        if now - handler.last_activity < eventloop.TIMEOUT_PRECISION:
            # thus we can lower timeout modification frequency
            return
        handler.last_activity = now
        index = self._handler_to_timeouts.get(hash(handler), -1)
        if index >= 0:
            # delete is O(n), so we just set it to None
            self._timeouts[index] = None
        length = len(self._timeouts)
        self._timeouts.append(handler)
        self._handler_to_timeouts[hash(handler)] = length

    def _sweep_timeout(self):
        """Destroy handlers idle for longer than the configured timeout.

        _timeouts is ordered by last activity, so the sweep stops at the
        first live handler that has not yet timed out.
        """
        # tornado's timeout memory management is more flexible than we need
        # we just need a sorted last_activity queue and it's faster than heapq
        # in fact we can do O(1) insertion/remove so we invent our own
        if self._timeouts:
            logging.log(shell.VERBOSE_LEVEL, 'sweeping timeouts')
            now = time.time()
            length = len(self._timeouts)
            pos = self._timeout_offset
            while pos < length:
                handler = self._timeouts[pos]
                if handler:
                    if now - handler.last_activity < self._timeout:
                        break
                    if handler.remote_address:
                        logging.warning('timed out: %s:%d' %
                                        handler.remote_address)
                    else:
                        logging.warning('timed out')
                    handler.destroy()
                    self._timeouts[pos] = None  # free memory
                pos += 1
            if pos > TIMEOUTS_CLEAN_SIZE and pos > length >> 1:
                # clean up the timeout queue when it gets larger than half
                # of the queue
                self._timeouts = self._timeouts[pos:]
                for key in self._handler_to_timeouts:
                    self._handler_to_timeouts[key] -= pos
                pos = 0
            self._timeout_offset = pos

    def handle_event(self, sock, fd, event):
        """Accept new connections or forward the event to its handler."""
        # handle events and dispatch to handlers
        if sock:
            logging.log(shell.VERBOSE_LEVEL, 'fd %d %s', fd,
                        eventloop.EVENT_NAMES.get(event, event))
        if sock == self._server_socket:
            if event & eventloop.POLL_ERR:
                # TODO
                raise Exception('server_socket error')
            try:
                logging.debug('accept')
                conn = self._server_socket.accept()
                # the handler registers itself in self._fd_to_handlers
                TCPRelayHandler(self, self._fd_to_handlers,
                                self._eventloop, conn[0], self._config,
                                self._dns_resolver, self._is_local)
            except (OSError, IOError) as e:
                error_no = eventloop.errno_from_exception(e)
                if error_no in (errno.EAGAIN, errno.EINPROGRESS,
                                errno.EWOULDBLOCK):
                    return
                else:
                    shell.print_exception(e)
                    if self._config['verbose']:
                        traceback.print_exc()
        else:
            if sock:
                handler = self._fd_to_handlers.get(fd, None)
                if handler:
                    handler.handle_event(sock, event)
            else:
                logging.warning('poll removed fd')

    def handle_periodic(self):
        """Finish a deferred close, then sweep timed-out handlers."""
        if self._closed:
            if self._server_socket:
                self._eventloop.remove(self._server_socket)
                self._server_socket.close()
                self._server_socket = None
                logging.info('closed TCP port %d', self._listen_port)
            if not self._fd_to_handlers:
                logging.info('stopping')
                self._eventloop.stop()
        self._sweep_timeout()

    def close(self, next_tick=False):
        """Stop accepting connections.

        With ``next_tick`` the listening socket is closed later by
        handle_periodic; otherwise it is closed now and every handler is
        destroyed.  Safe to call again after the socket is already closed.
        """
        logging.debug('TCP close')
        self._closed = True
        if not next_tick:
            if self._eventloop:
                self._eventloop.remove_periodic(self.handle_periodic)
                if self._server_socket:
                    self._eventloop.remove(self._server_socket)
            # handle_periodic (or an earlier close) may already have closed
            # the socket and set it to None
            if self._server_socket:
                self._server_socket.close()
                self._server_socket = None
            for handler in list(self._fd_to_handlers.values()):
                handler.destroy()