@ -123,751 +123,6 @@ RSP_STATE_ERROR = b"\x03"
RSP_STATE_DISCONNECT = b " \x04 "
RSP_STATE_DISCONNECT = b " \x04 "
RSP_STATE_REDIRECT = b " \x05 "
RSP_STATE_REDIRECT = b " \x05 "
class UDPLocalAddress ( object ) :
def __init__ ( self , addr ) :
self . addr = addr
self . last_activity = time . time ( )
def is_timeout ( self ) :
return time . time ( ) - self . last_activity > 30
class PacketInfo ( object ) :
def __init__ ( self , data ) :
self . data = data
self . time = time . time ( )
class SendingQueue ( object ) :
def __init__ ( self ) :
self . queue = { }
self . begin_id = 0
self . end_id = 1
self . interval = 0.5
def append ( self , data ) :
self . queue [ self . end_id ] = PacketInfo ( data )
self . end_id + = 1
return self . end_id - 1
def empty ( self ) :
return self . begin_id + 1 == self . end_id
def size ( self ) :
return self . end_id - self . begin_id - 1
def get_begin_id ( self ) :
return self . begin_id
def get_end_id ( self ) :
return self . end_id
def get_data_list ( self , pack_id_base , pack_id_list ) :
ret_list = [ ]
curtime = time . time ( )
for pack_id in pack_id_list :
offset = pack_id_base + pack_id
if offset < = self . begin_id or self . end_id < = offset :
continue
ret_data = self . queue [ offset ]
if curtime - ret_data . time > self . interval :
ret_data . time = curtime
ret_list . append ( ( offset , ret_data . data ) )
return ret_list
def set_finish ( self , begin_id , done_list ) :
while self . begin_id < begin_id :
self . begin_id + = 1
del self . queue [ self . begin_id ]
class RecvQueue ( object ) :
def __init__ ( self ) :
self . queue = { }
self . miss_queue = set ( )
self . begin_id = 0
self . end_id = 1
def empty ( self ) :
return self . begin_id + 1 == self . end_id
def insert ( self , pack_id , data ) :
if ( pack_id not in self . queue ) and pack_id > self . begin_id :
self . queue [ pack_id ] = PacketInfo ( data )
if self . end_id == pack_id :
self . end_id = pack_id + 1
elif self . end_id < pack_id :
eid = self . end_id
while eid < pack_id :
self . miss_queue . add ( eid )
eid + = 1
self . end_id = pack_id + 1
else :
self . miss_queue . remove ( pack_id )
def set_end ( self , end_id ) :
if end_id > self . end_id :
eid = self . end_id
while eid < end_id :
self . miss_queue . add ( eid )
eid + = 1
self . end_id = end_id
def get_begin_id ( self ) :
return self . begin_id
def has_data ( self ) :
return ( self . begin_id + 1 ) in self . queue
def get_data ( self ) :
if ( self . begin_id + 1 ) in self . queue :
self . begin_id + = 1
pack_id = self . begin_id
ret_data = self . queue [ pack_id ]
del self . queue [ pack_id ]
return ( pack_id , ret_data . data )
def get_missing_id ( self , begin_id ) :
missing = [ ]
if begin_id == 0 :
begin_id = self . begin_id
for i in self . miss_queue :
if i - begin_id > 32768 :
break
missing . append ( i - begin_id )
return ( begin_id , missing )
class AddressMap ( object ) :
def __init__ ( self ) :
self . _queue = [ ]
self . _addr_map = { }
def add ( self , addr ) :
if addr in self . _addr_map :
self . _addr_map [ addr ] = UDPLocalAddress ( addr )
else :
self . _addr_map [ addr ] = UDPLocalAddress ( addr )
self . _queue . append ( addr )
def keys ( self ) :
return self . _queue
def get ( self ) :
if self . _queue :
while True :
if len ( self . _queue ) == 1 :
return self . _queue [ 0 ]
index = random . randint ( 0 , len ( self . _queue ) - 1 )
addr = self . _queue [ index ]
if self . _addr_map [ addr ] . is_timeout ( ) :
self . _queue [ index ] = self . _queue [ len ( self . _queue ) - 1 ]
del self . _queue [ len ( self . _queue ) - 1 ]
del self . _addr_map [ addr ]
else :
break
return addr
else :
return None
class TCPRelayHandler ( object ) :
def __init__ ( self , server , reqid_to_handlers , fd_to_handlers , loop ,
local_sock , local_id , client_param , config ,
dns_resolver , is_local ) :
self . _server = server
self . _reqid_to_handlers = reqid_to_handlers
self . _fd_to_handlers = fd_to_handlers
self . _loop = loop
self . _local_sock = local_sock
self . _remote_sock = None
self . _remote_udp = False
self . _config = config
self . _dns_resolver = dns_resolver
self . _local_id = local_id
self . _is_local = is_local
self . _stage = STAGE_INIT
self . _password = config [ ' password ' ]
self . _method = config [ ' method ' ]
self . _fastopen_connected = False
self . _data_to_write_to_local = [ ]
self . _data_to_write_to_remote = [ ]
self . _upstream_status = WAIT_STATUS_READING
self . _downstream_status = WAIT_STATUS_INIT
self . _request_id = 0
self . _client_address = AddressMap ( )
self . _remote_address = None
self . _sendingqueue = SendingQueue ( )
self . _recvqueue = RecvQueue ( )
if ' forbidden_ip ' in config :
self . _forbidden_iplist = config [ ' forbidden_ip ' ]
else :
self . _forbidden_iplist = None
if ' forbidden_port ' in config :
self . _forbidden_portset = config [ ' forbidden_port ' ]
else :
self . _forbidden_portset = None
#fd_to_handlers[local_sock.fileno()] = self
#local_sock.setblocking(False)
#loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR)
self . last_activity = 0
self . _update_activity ( )
self . _random_mtu_size = [ random . randint ( POST_MTU_MIN , POST_MTU_MAX ) for i in range ( 1024 ) ]
self . _random_mtu_index = 0
self . _rand_data = b " \x01 \x02 \x03 \x04 \x05 \x06 \x07 \x08 \x09 \x0a \x0b \x0c \x0d \x0e \x0f \x10 " * 4
def __hash__ ( self ) :
# default __hash__ is id / 16
# we want to eliminate collisions
return id ( self )
@property
def remote_address ( self ) :
return self . _remote_address
def add_local_address ( self , addr ) :
self . _client_address . add ( addr )
def get_local_address ( self ) :
return self . _client_address . get ( )
def _update_activity ( self ) :
# tell the TCP Relay we have activities recently
# else it will think we are inactive and timed out
self . _server . update_activity ( self )
def _update_stream ( self , stream , status ) :
# update a stream to a new waiting status
# check if status is changed
# only update if dirty
dirty = False
if stream == STREAM_DOWN :
if self . _downstream_status != status :
self . _downstream_status = status
dirty = True
elif stream == STREAM_UP :
if self . _upstream_status != status :
self . _upstream_status = status
dirty = True
if dirty :
'''
if self . _local_sock :
event = eventloop . POLL_ERR
if self . _downstream_status & WAIT_STATUS_WRITING :
event | = eventloop . POLL_OUT
if self . _upstream_status & WAIT_STATUS_READING :
event | = eventloop . POLL_IN
self . _loop . modify ( self . _local_sock , event )
'''
if self . _remote_sock :
event = eventloop . POLL_ERR
if self . _downstream_status & WAIT_STATUS_READING :
event | = eventloop . POLL_IN
if self . _upstream_status & WAIT_STATUS_WRITING :
event | = eventloop . POLL_OUT
self . _loop . modify ( self . _remote_sock , event )
def _write_to_sock ( self , data , sock , addr = None ) :
# write data to sock
# if only some of the data are written, put remaining in the buffer
# and update the stream to wait for writing
if not data or not sock :
return False
uncomplete = False
retry = 0
if sock == self . _local_sock :
data = encrypt . encrypt_all ( self . _password , self . _method , 1 , data )
if addr is None :
return False
try :
self . _server . write_to_server_socket ( data , addr )
except ( OSError , IOError ) as e :
error_no = eventloop . errno_from_exception ( e )
uncomplete = True
if error_no in ( errno . EAGAIN , errno . EINPROGRESS ,
errno . EWOULDBLOCK ) :
pass
else :
#traceback.print_exc()
shell . print_exception ( e )
self . destroy ( )
return False
else :
try :
l = len ( data )
s = sock . send ( data )
if s < l :
data = data [ s : ]
uncomplete = True
except ( OSError , IOError ) as e :
error_no = eventloop . errno_from_exception ( e )
if error_no in ( errno . EAGAIN , errno . EINPROGRESS ,
errno . EWOULDBLOCK ) :
uncomplete = True
else :
#logging.error(traceback.extract_stack())
#traceback.print_exc()
shell . print_exception ( e )
self . destroy ( )
return False
if uncomplete :
if sock == self . _local_sock :
self . _update_stream ( STREAM_DOWN , WAIT_STATUS_WRITING )
elif sock == self . _remote_sock :
self . _data_to_write_to_remote . append ( data )
self . _update_stream ( STREAM_UP , WAIT_STATUS_WRITING )
else :
logging . error ( ' write_all_to_sock:unknown socket ' )
else :
if sock == self . _local_sock :
if self . _sendingqueue . size ( ) > SENDING_WINDOW_SIZE :
self . _update_stream ( STREAM_DOWN , WAIT_STATUS_WRITING )
else :
self . _update_stream ( STREAM_DOWN , WAIT_STATUS_READING )
elif sock == self . _remote_sock :
self . _update_stream ( STREAM_UP , WAIT_STATUS_READING )
else :
logging . error ( ' write_all_to_sock:unknown socket ' )
return True
def _create_remote_socket ( self , ip , port ) :
addrs = socket . getaddrinfo ( ip , port , 0 , socket . SOCK_STREAM , socket . SOL_TCP )
if len ( addrs ) == 0 :
raise Exception ( " getaddrinfo failed for %s : %d " % ( ip , port ) )
af , socktype , proto , canonname , sa = addrs [ 0 ]
if self . _forbidden_iplist :
if common . to_str ( sa [ 0 ] ) in self . _forbidden_iplist :
raise Exception ( ' IP %s is in forbidden list, reject ' %
common . to_str ( sa [ 0 ] ) )
remote_sock = socket . socket ( af , socktype , proto )
self . _remote_sock = remote_sock
self . _fd_to_handlers [ remote_sock . fileno ( ) ] = self
remote_sock . setblocking ( False )
remote_sock . setsockopt ( socket . SOL_TCP , socket . TCP_NODELAY , 1 )
return remote_sock
def _handle_dns_resolved ( self , result , error ) :
if error :
self . _log_error ( error )
self . destroy ( )
return
if result :
ip = result [ 1 ]
if ip :
try :
self . _stage = STAGE_CONNECTING
remote_addr = ip
remote_port = self . _remote_address [ 1 ]
logging . info ( " connect to %s : %d " % ( remote_addr , remote_port ) )
remote_sock = self . _create_remote_socket ( remote_addr ,
remote_port )
try :
remote_sock . connect ( ( remote_addr , remote_port ) )
except ( OSError , IOError ) as e :
if eventloop . errno_from_exception ( e ) in ( errno . EINPROGRESS ,
errno . EWOULDBLOCK ) :
pass # always goto here
else :
raise e
self . _loop . add ( remote_sock ,
eventloop . POLL_ERR | eventloop . POLL_OUT ,
self . _server )
self . _update_stream ( STREAM_UP , WAIT_STATUS_READWRITING )
self . _update_stream ( STREAM_DOWN , WAIT_STATUS_READING )
self . _stage = STAGE_STREAM
addr = self . get_local_address ( )
for i in range ( 2 ) :
rsp_data = self . _pack_rsp_data ( CMD_RSP_CONNECT_REMOTE , RSP_STATE_CONNECTEDREMOTE )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
return
except Exception as e :
shell . print_exception ( e )
if self . _config [ ' verbose ' ] :
traceback . print_exc ( )
self . destroy ( )
def _on_local_read ( self ) :
# handle all local read events and dispatch them to methods for
# each stage
self . _update_activity ( )
if not self . _local_sock :
return
data = None
try :
data = self . _local_sock . recv ( BUF_SIZE )
except ( OSError , IOError ) as e :
if eventloop . errno_from_exception ( e ) in \
( errno . ETIMEDOUT , errno . EAGAIN , errno . EWOULDBLOCK ) :
return
if not data :
self . destroy ( )
return
if not data :
return
self . _server . server_transfer_ul + = len ( data )
#TODO ============================================================
if self . _stage == STAGE_STREAM :
self . _write_to_sock ( data , self . _remote_sock )
return
def _on_remote_read ( self ) :
# handle all remote read events
self . _update_activity ( )
data = None
try :
data = self . _remote_sock . recv ( BUF_SIZE )
except ( OSError , IOError ) as e :
if eventloop . errno_from_exception ( e ) in \
( errno . ETIMEDOUT , errno . EAGAIN , errno . EWOULDBLOCK , 10035 ) : #errno.WSAEWOULDBLOCK
return
if not data :
self . destroy ( )
return
try :
self . _server . server_transfer_dl + = len ( data )
recv_data = data
beg_pos = 0
max_len = len ( recv_data )
while beg_pos < max_len :
if beg_pos + POST_MTU_MAX > = max_len :
split_pos = max_len
else :
split_pos = beg_pos + self . _random_mtu_size [ self . _random_mtu_index ]
self . _random_mtu_index = ( self . _random_mtu_index + 1 ) & 0x3ff
#split_pos = beg_pos + random.randint(POST_MTU_MIN, POST_MTU_MAX)
data = recv_data [ beg_pos : split_pos ]
beg_pos = split_pos
pack_id = self . _sendingqueue . append ( data )
post_data = self . _pack_post_data ( CMD_POST , pack_id , data )
addr = self . get_local_address ( )
self . _write_to_sock ( post_data , self . _local_sock , addr )
if pack_id < = DOUBLE_SEND_BEG_IDS :
post_data = self . _pack_post_data ( CMD_POST , pack_id , data )
self . _write_to_sock ( post_data , self . _local_sock , addr )
except Exception as e :
shell . print_exception ( e )
if self . _config [ ' verbose ' ] :
traceback . print_exc ( )
# TODO use logging when debug completed
self . destroy ( )
def _on_local_write ( self ) :
# handle local writable event
if self . _data_to_write_to_local :
data = b ' ' . join ( self . _data_to_write_to_local )
self . _data_to_write_to_local = [ ]
self . _write_to_sock ( data , self . _local_sock )
else :
self . _update_stream ( STREAM_DOWN , WAIT_STATUS_READING )
def _on_remote_write ( self ) :
# handle remote writable event
self . _stage = STAGE_STREAM
if self . _data_to_write_to_remote :
data = b ' ' . join ( self . _data_to_write_to_remote )
self . _data_to_write_to_remote = [ ]
self . _write_to_sock ( data , self . _remote_sock )
else :
self . _update_stream ( STREAM_UP , WAIT_STATUS_READING )
def _on_local_error ( self ) :
logging . debug ( ' got local error ' )
if self . _local_sock :
logging . error ( eventloop . get_sock_error ( self . _local_sock ) )
self . destroy ( )
def _on_remote_error ( self ) :
logging . debug ( ' got remote error ' )
if self . _remote_sock :
logging . error ( eventloop . get_sock_error ( self . _remote_sock ) )
self . destroy ( )
def _pack_rsp_data ( self , cmd , data ) :
reqid_str = struct . pack ( " >H " , self . _request_id )
return b ' ' . join ( [ CMD_VER_STR , common . chr ( cmd ) , reqid_str , data , self . _rand_data [ : random . randint ( 0 , len ( self . _rand_data ) ) ] , reqid_str ] )
def _pack_rnd_data ( self , data ) :
length = random . randint ( 0 , len ( self . _rand_data ) )
if length == 0 :
return data
elif length == 1 :
return b " \x81 " + data
elif length < 256 :
return b " \x80 " + common . chr ( length ) + self . _rand_data [ : length - 2 ] + data
else :
return b " \x82 " + struct . pack ( " >H " , length ) + self . _rand_data [ : length - 3 ] + data
def _pack_post_data ( self , cmd , pack_id , data ) :
reqid_str = struct . pack ( " >H " , self . _request_id )
recv_id = self . _recvqueue . get_begin_id ( )
rsp_data = b ' ' . join ( [ CMD_VER_STR , common . chr ( cmd ) , reqid_str , struct . pack ( " >I " , recv_id ) , struct . pack ( " >I " , pack_id ) , data , reqid_str ] )
return rsp_data
def _pack_post_data_64 ( self , cmd , send_id , pack_id , data ) :
reqid_str = struct . pack ( " >H " , self . _request_id )
recv_id = self . _recvqueue . get_begin_id ( )
rsp_data = b ' ' . join ( [ CMD_VER_STR , common . chr ( cmd ) , reqid_str , struct . pack ( " >Q " , recv_id ) , struct . pack ( " >Q " , pack_id ) , data , reqid_str ] )
return rsp_data
def sweep_timeout ( self ) :
logging . info ( " sweep_timeout " )
if self . _stage == STAGE_STREAM :
pack_id , missing = self . _recvqueue . get_missing_id ( 0 )
logging . info ( " sweep_timeout %s %s " % ( pack_id , missing ) )
data = b ' '
for pid in missing :
data + = struct . pack ( " >H " , pid )
rsp_data = self . _pack_post_data ( CMD_SYN_STATUS , pack_id , data )
addr = self . get_local_address ( )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
def handle_stream_sync_status ( self , addr , cmd , request_id , pack_id , max_send_id , data ) :
missing_list = [ ]
while len ( data ) > = 2 :
pid = struct . unpack ( " >H " , data [ 0 : 2 ] ) [ 0 ]
data = data [ 2 : ]
missing_list . append ( pid )
done_list = [ ]
self . _recvqueue . set_end ( max_send_id )
self . _sendingqueue . set_finish ( pack_id , done_list )
if self . _stage == STAGE_DESTROYED and self . _sendingqueue . empty ( ) :
self . destroy_local ( )
return
# post CMD_SYN_STATUS
send_id = self . _sendingqueue . get_end_id ( )
post_pack_id , missing = self . _recvqueue . get_missing_id ( 0 )
pack_ids_data = b ' '
for pid in missing :
pack_ids_data + = struct . pack ( " >H " , pid )
rsp_data = self . _pack_rnd_data ( self . _pack_post_data ( CMD_SYN_STATUS , send_id , pack_ids_data ) )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
send_list = self . _sendingqueue . get_data_list ( pack_id , missing_list )
for post_pack_id , post_data in send_list :
rsp_data = self . _pack_post_data ( CMD_POST , post_pack_id , post_data )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
if post_pack_id < = DOUBLE_SEND_BEG_IDS :
rsp_data = self . _pack_post_data ( CMD_POST , post_pack_id , post_data )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
def handle_client ( self , addr , cmd , request_id , data ) :
self . add_local_address ( addr )
if cmd == CMD_DISCONNECT :
rsp_data = self . _pack_rsp_data ( CMD_DISCONNECT , RSP_STATE_EMPTY )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
self . destroy ( )
self . destroy_local ( )
return
if self . _stage == STAGE_INIT :
if cmd == CMD_CONNECT :
self . _request_id = request_id
self . _stage = STAGE_RSP_ID
return
if self . _request_id != request_id :
return
if self . _stage == STAGE_RSP_ID :
if cmd == CMD_CONNECT :
for i in range ( 2 ) :
rsp_data = self . _pack_rsp_data ( CMD_RSP_CONNECT , RSP_STATE_CONNECTED )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
elif cmd == CMD_CONNECT_REMOTE :
local_id = data [ 0 : 4 ]
if self . _local_id == local_id :
data = data [ 4 : ]
header_result = parse_header ( data )
if header_result is None :
return
connecttype , remote_addr , remote_port , header_length = header_result
self . _remote_address = ( common . to_str ( remote_addr ) , remote_port )
self . _stage = STAGE_DNS
self . _dns_resolver . resolve ( remote_addr ,
self . _handle_dns_resolved )
common . connect_log ( ' TCPonUDP connect %s : %d from %s : %d ' % ( remote_addr , remote_port , addr [ 0 ] , addr [ 1 ] ) )
else :
# ileagal request
rsp_data = self . _pack_rsp_data ( CMD_DISCONNECT , RSP_STATE_EMPTY )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
elif self . _stage == STAGE_CONNECTING :
if cmd == CMD_CONNECT_REMOTE :
local_id = data [ 0 : 4 ]
if self . _local_id == local_id :
for i in range ( 2 ) :
rsp_data = self . _pack_rsp_data ( CMD_RSP_CONNECT_REMOTE , RSP_STATE_CONNECTEDREMOTE )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
else :
# ileagal request
rsp_data = self . _pack_rsp_data ( CMD_DISCONNECT , RSP_STATE_EMPTY )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
elif self . _stage == STAGE_STREAM :
if len ( data ) < 4 :
# ileagal request
rsp_data = self . _pack_rsp_data ( CMD_DISCONNECT , RSP_STATE_EMPTY )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
return
local_id = data [ 0 : 4 ]
if self . _local_id != local_id :
# ileagal request
rsp_data = self . _pack_rsp_data ( CMD_DISCONNECT , RSP_STATE_EMPTY )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
return
else :
data = data [ 4 : ]
if cmd == CMD_CONNECT_REMOTE :
rsp_data = self . _pack_rsp_data ( CMD_RSP_CONNECT_REMOTE , RSP_STATE_CONNECTEDREMOTE )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
elif cmd == CMD_POST :
recv_id = struct . unpack ( " >I " , data [ 0 : 4 ] ) [ 0 ]
pack_id = struct . unpack ( " >I " , data [ 4 : 8 ] ) [ 0 ]
self . _recvqueue . insert ( pack_id , data [ 8 : ] )
self . _sendingqueue . set_finish ( recv_id , [ ] )
elif cmd == CMD_POST_64 :
recv_id = struct . unpack ( " >Q " , data [ 0 : 8 ] ) [ 0 ]
pack_id = struct . unpack ( " >Q " , data [ 8 : 16 ] ) [ 0 ]
self . _recvqueue . insert ( pack_id , data [ 16 : ] )
self . _sendingqueue . set_finish ( recv_id , [ ] )
elif cmd == CMD_DISCONNECT :
rsp_data = self . _pack_rsp_data ( CMD_DISCONNECT , RSP_STATE_EMPTY )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
self . destroy ( )
self . destroy_local ( )
return
elif cmd == CMD_SYN_STATUS :
pack_id = struct . unpack ( " >I " , data [ 0 : 4 ] ) [ 0 ]
max_send_id = struct . unpack ( " >I " , data [ 4 : 8 ] ) [ 0 ]
data = data [ 8 : ]
self . handle_stream_sync_status ( addr , cmd , request_id , pack_id , max_send_id , data )
elif cmd == CMD_SYN_STATUS_64 :
pack_id = struct . unpack ( " >Q " , data [ 0 : 8 ] ) [ 0 ]
max_send_id = struct . unpack ( " >Q " , data [ 8 : 16 ] ) [ 0 ]
data = data [ 16 : ]
self . handle_stream_sync_status ( addr , cmd , request_id , pack_id , max_send_id , data )
while self . _recvqueue . has_data ( ) :
pack_id , post_data = self . _recvqueue . get_data ( )
self . _write_to_sock ( post_data , self . _remote_sock )
elif self . _stage == STAGE_DESTROYED :
local_id = data [ 0 : 4 ]
if self . _local_id != local_id :
# ileagal request
rsp_data = self . _pack_rsp_data ( CMD_DISCONNECT , RSP_STATE_EMPTY )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
return
else :
data = data [ 4 : ]
if cmd == CMD_SYN_STATUS :
pack_id = struct . unpack ( " >I " , data [ 0 : 4 ] ) [ 0 ]
max_send_id = struct . unpack ( " >I " , data [ 4 : 8 ] ) [ 0 ]
data = data [ 8 : ]
self . handle_stream_sync_status ( addr , cmd , request_id , pack_id , max_send_id , data )
elif cmd == CMD_SYN_STATUS_64 :
pack_id = struct . unpack ( " >Q " , data [ 0 : 8 ] ) [ 0 ]
max_send_id = struct . unpack ( " >Q " , data [ 8 : 16 ] ) [ 0 ]
data = data [ 16 : ]
self . handle_stream_sync_status ( addr , cmd , request_id , pack_id , max_send_id , data )
def handle_event ( self , sock , event ) :
# handle all events in this handler and dispatch them to methods
handle = False
if self . _stage == STAGE_DESTROYED :
logging . debug ( ' ignore handle_event: destroyed ' )
return True
# order is important
if sock == self . _remote_sock :
if event & eventloop . POLL_ERR :
handle = True
self . _on_remote_error ( )
if self . _stage == STAGE_DESTROYED :
return True
if event & ( eventloop . POLL_IN | eventloop . POLL_HUP ) :
handle = True
self . _on_remote_read ( )
if self . _stage == STAGE_DESTROYED :
return True
if event & eventloop . POLL_OUT :
handle = True
self . _on_remote_write ( )
elif sock == self . _local_sock :
if event & eventloop . POLL_ERR :
handle = True
self . _on_local_error ( )
if self . _stage == STAGE_DESTROYED :
return True
if event & ( eventloop . POLL_IN | eventloop . POLL_HUP ) :
handle = True
self . _on_local_read ( )
if self . _stage == STAGE_DESTROYED :
return True
if event & eventloop . POLL_OUT :
handle = True
self . _on_local_write ( )
else :
logging . warn ( ' unknown socket ' )
return handle
def _log_error ( self , e ) :
logging . error ( ' %s when handling connection from %s ' %
( e , self . _client_address . keys ( ) ) )
def destroy ( self ) :
# destroy the handler and release any resources
# promises:
# 1. destroy won't make another destroy() call inside
# 2. destroy releases resources so it prevents future call to destroy
# 3. destroy won't raise any exceptions
# if any of the promises are broken, it indicates a bug has been
# introduced! mostly likely memory leaks, etc
#logging.info('tcp destroy called')
if self . _stage == STAGE_DESTROYED :
# this couldn't happen
logging . debug ( ' already destroyed ' )
return
self . _stage = STAGE_DESTROYED
if self . _remote_address :
logging . debug ( ' destroy: %s : %d ' %
self . _remote_address )
else :
logging . debug ( ' destroy ' )
if self . _remote_sock :
logging . debug ( ' destroying remote ' )
self . _loop . remove ( self . _remote_sock )
try :
del self . _fd_to_handlers [ self . _remote_sock . fileno ( ) ]
except Exception as e :
pass
self . _remote_sock . close ( )
self . _remote_sock = None
if self . _sendingqueue . empty ( ) :
self . destroy_local ( )
self . _dns_resolver . remove_callback ( self . _handle_dns_resolved )
def destroy_local ( self ) :
if self . _local_sock :
logging . debug ( ' disconnect local ' )
rsp_data = self . _pack_rsp_data ( CMD_DISCONNECT , RSP_STATE_EMPTY )
addr = None
addr = self . get_local_address ( )
self . _write_to_sock ( rsp_data , self . _local_sock , addr )
self . _local_sock = None
try :
del self . _reqid_to_handlers [ self . _request_id ]
except Exception as e :
pass
self . _server . remove_handler ( self )
def client_key ( source_addr , server_af ) :
def client_key ( source_addr , server_af ) :
# notice this is server af, not dest af
# notice this is server af, not dest af
return ' %s : %s : %d ' % ( source_addr [ 0 ] , source_addr [ 1 ] , server_af )
return ' %s : %s : %d ' % ( source_addr [ 0 ] , source_addr [ 1 ] , server_af )
@ -908,7 +163,7 @@ class UDPRelay(object):
self . server_user_transfer_ul = { }
self . server_user_transfer_ul = { }
self . server_user_transfer_dl = { }
self . server_user_transfer_dl = { }
if common . to_bytes ( config [ ' protocol ' ] ) in [ b " auth_aes128_md5 " , b " auth_aes128_sha1 " ] :
if common . to_bytes ( config [ ' protocol ' ] ) in obfs . mu_protocol ( ) :
self . _update_users ( None , None )
self . _update_users ( None , None )
self . protocol_data = obfs . obfs ( config [ ' protocol ' ] ) . init_data ( )
self . protocol_data = obfs . obfs ( config [ ' protocol ' ] ) . init_data ( )
@ -1261,71 +516,6 @@ class UDPRelay(object):
else :
else :
shell . print_exception ( e )
shell . print_exception ( e )
def _handle_tcp_over_udp ( self , data , r_addr ) :
#(cmd, request_id, data)
#logging.info("UDP data %d %d %s" % (data[0], data[1], binascii.hexlify(data[2])))
try :
self . server_transfer_ul + = len ( data [ 2 ] )
if data [ 0 ] == 0 :
if len ( data [ 2 ] ) > = 4 :
for i in range ( 64 ) :
req_id = random . randint ( 1 , 65535 )
if req_id not in self . _reqid_to_hd :
break
if req_id in self . _reqid_to_hd :
for i in range ( 64 ) :
req_id = random . randint ( 1 , 65535 )
if type ( self . _reqid_to_hd [ req_id ] ) is tuple :
break
# return req id
self . _reqid_to_hd [ req_id ] = ( data [ 2 ] [ 0 : 4 ] , None )
rsp_data = self . _pack_rsp_data ( CMD_RSP_CONNECT , req_id , RSP_STATE_CONNECTED )
data_to_send = encrypt . encrypt_all ( self . _password , self . _method , 1 , rsp_data )
self . write_to_server_socket ( data_to_send , r_addr )
elif data [ 0 ] == CMD_CONNECT_REMOTE :
if len ( data [ 2 ] ) > 4 and data [ 1 ] in self . _reqid_to_hd :
# create
if type ( self . _reqid_to_hd [ data [ 1 ] ] ) is tuple :
if data [ 2 ] [ 0 : 4 ] == self . _reqid_to_hd [ data [ 1 ] ] [ 0 ] :
handle = TCPRelayHandler ( self , self . _reqid_to_hd , self . _fd_to_handlers ,
self . _eventloop , self . _server_socket ,
self . _reqid_to_hd [ data [ 1 ] ] [ 0 ] , self . _reqid_to_hd [ data [ 1 ] ] [ 1 ] ,
self . _config , self . _dns_resolver , self . _is_local )
self . _reqid_to_hd [ data [ 1 ] ] = handle
handle . handle_client ( r_addr , CMD_CONNECT , data [ 1 ] , data [ 2 ] )
handle . handle_client ( r_addr , * data )
self . update_activity ( handle )
else :
# disconnect
rsp_data = self . _pack_rsp_data ( CMD_DISCONNECT , data [ 1 ] , RSP_STATE_EMPTY )
data_to_send = encrypt . encrypt_all ( self . _password , self . _method , 1 , rsp_data )
self . write_to_server_socket ( data_to_send , r_addr )
else :
self . update_activity ( self . _reqid_to_hd [ data [ 1 ] ] )
self . _reqid_to_hd [ data [ 1 ] ] . handle_client ( r_addr , * data )
else :
# disconnect
rsp_data = self . _pack_rsp_data ( CMD_DISCONNECT , data [ 1 ] , RSP_STATE_EMPTY )
data_to_send = encrypt . encrypt_all ( self . _password , self . _method , 1 , rsp_data )
self . write_to_server_socket ( data_to_send , r_addr )
elif data [ 0 ] > CMD_CONNECT_REMOTE and data [ 0 ] < = CMD_DISCONNECT :
if data [ 1 ] in self . _reqid_to_hd :
if type ( self . _reqid_to_hd [ data [ 1 ] ] ) is tuple :
pass
else :
self . update_activity ( self . _reqid_to_hd [ data [ 1 ] ] )
self . _reqid_to_hd [ data [ 1 ] ] . handle_client ( r_addr , * data )
else :
# disconnect
rsp_data = self . _pack_rsp_data ( CMD_DISCONNECT , data [ 1 ] , RSP_STATE_EMPTY )
data_to_send = encrypt . encrypt_all ( self . _password , self . _method , 1 , rsp_data )
self . write_to_server_socket ( data_to_send , r_addr )
return
except Exception as e :
trace = traceback . format_exc ( )
logging . error ( trace )
return
def _handle_client ( self , sock ) :
def _handle_client ( self , sock ) :
data , r_addr = sock . recvfrom ( BUF_SIZE )
data , r_addr = sock . recvfrom ( BUF_SIZE )
if not data :
if not data :
@ -1333,6 +523,18 @@ class UDPRelay(object):
return
return
if self . _stat_callback :
if self . _stat_callback :
self . _stat_callback ( self . _listen_port , len ( data ) )
self . _stat_callback ( self . _listen_port , len ( data ) )
client_addr = self . _client_fd_to_server_addr . get ( sock . fileno ( ) )
client_uid = None
if client_addr :
key = client_key ( client_addr [ 0 ] , client_addr [ 1 ] )
client_pair = self . _cache . get ( key , None )
client_dns_pair = self . _cache_dns_client . get ( key , None )
if client_pair :
client , client_uid = client_pair
elif client_dns_pair :
client , client_uid = client_dns_pair
if not self . _is_local :
if not self . _is_local :
addrlen = len ( r_addr [ 0 ] )
addrlen = len ( r_addr [ 0 ] )
if addrlen > 255 :
if addrlen > 255 :
@ -1341,7 +543,7 @@ class UDPRelay(object):
data = pack_addr ( r_addr [ 0 ] ) + struct . pack ( ' >H ' , r_addr [ 1 ] ) + data
data = pack_addr ( r_addr [ 0 ] ) + struct . pack ( ' >H ' , r_addr [ 1 ] ) + data
ref_iv = [ encrypt . encrypt_new_iv ( self . _method ) ]
ref_iv = [ encrypt . encrypt_new_iv ( self . _method ) ]
self . _protocol . obfs . server_info . iv = ref_iv [ 0 ]
self . _protocol . obfs . server_info . iv = ref_iv [ 0 ]
data = self . _protocol . server_udp_pre_encrypt ( data )
data = self . _protocol . server_udp_pre_encrypt ( data , client_uid )
response = encrypt . encrypt_all_iv ( self . _protocol . obfs . server_info . key , self . _method , 1 ,
response = encrypt . encrypt_all_iv ( self . _protocol . obfs . server_info . key , self . _method , 1 ,
data , ref_iv )
data , ref_iv )
if not response :
if not response :
@ -1361,16 +563,9 @@ class UDPRelay(object):
#logging.debug('UDP handle_client %s:%d to %s:%d' % (common.to_str(r_addr[0]), r_addr[1], dest_addr, dest_port))
#logging.debug('UDP handle_client %s:%d to %s:%d' % (common.to_str(r_addr[0]), r_addr[1], dest_addr, dest_port))
response = b ' \x00 \x00 \x00 ' + data
response = b ' \x00 \x00 \x00 ' + data
client_addr = self . _client_fd_to_server_addr . get ( sock . fileno ( ) )
if client_addr :
if client_addr :
key = client_key ( client_addr [ 0 ] , client_addr [ 1 ] )
if client_uid :
client_pair = self . _cache . get ( key , None )
client_dns_pair = self . _cache_dns_client . get ( key , None )
if client_pair :
client , client_uid = client_pair
self . add_transfer_d ( client_uid , len ( response ) )
elif client_dns_pair :
client , client_uid = client_dns_pair
self . add_transfer_d ( client_uid , len ( response ) )
self . add_transfer_d ( client_uid , len ( response ) )
else :
else :
self . server_transfer_dl + = len ( response )
self . server_transfer_dl + = len ( response )