Browse Source

multiuser in single port protocol

dev
破娃酱 8 years ago
parent
commit
959aad3f41
  1. 31
      shadowsocks/obfsplugin/auth.py
  2. 5
      shadowsocks/obfsplugin/plain.py
  3. 8
      shadowsocks/obfsplugin/verify.py
  4. 48
      shadowsocks/tcprelay.py
  5. 84
      shadowsocks/udprelay.py

31
shadowsocks/obfsplugin/auth.py

@ -1103,8 +1103,8 @@ class auth_aes128(auth_base):
length = len(buf) length = len(buf)
data = buf[:-4] data = buf[:-4]
if struct.pack('<I', zlib.adler32(data) & 0xFFFFFFFF) != buf[length - 4:]: if struct.pack('<I', zlib.adler32(data) & 0xFFFFFFFF) != buf[length - 4:]:
return b'' return (b'', None)
return data return (data, None)
class auth_aes128_sha1(auth_base): class auth_aes128_sha1(auth_base):
def __init__(self, method, hashfunc): def __init__(self, method, hashfunc):
@ -1280,9 +1280,15 @@ class auth_aes128_sha1(auth_base):
return (b'', False) return (b'', False)
return self.not_match_return(self.recv_buf) return self.not_match_return(self.recv_buf)
user_key = self.recv_buf[7:11] uid = self.recv_buf[7:11]
#if user_key in user_map: self.user_key[user_key] else: # TODO if uid in self.server_info.users:
self.user_key = self.server_info.key self.user_key = self.server_info.users[uid]
self.server_info.update_user_func(uid)
else:
if not self.server_info.users:
self.user_key = self.server_info.key
else:
self.user_key = self.server_info.recv_iv
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc') encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc')
head = encryptor.decrypt(b'\x00' * 16 + self.recv_buf[11:27] + b'\x00') # need an extra byte or recv empty head = encryptor.decrypt(b'\x00' * 16 + self.recv_buf[11:27] + b'\x00') # need an extra byte or recv empty
length = struct.unpack('<H', head[12:14])[0] length = struct.unpack('<H', head[12:14])[0]
@ -1377,8 +1383,15 @@ class auth_aes128_sha1(auth_base):
def server_udp_post_decrypt(self, buf): def server_udp_post_decrypt(self, buf):
uid = buf[-8:-4] uid = buf[-8:-4]
user_key = self.server_info.key if uid in self.server_info.users:
if hmac.new(user_key, buf[:-4], self.hashfunc).digest()[:4] != buf[-4:]: self.user_key = self.server_info.users[uid]
return b'' else:
return buf[:-8] uid = None
if not self.server_info.users:
self.user_key = self.server_info.key
else:
self.user_key = self.server_info.recv_iv
if hmac.new(self.user_key, buf[:-4], self.hashfunc).digest()[:4] != buf[-4:]:
return (b'', None)
return (buf[:-8], uid)

5
shadowsocks/obfsplugin/plain.py

@ -40,6 +40,9 @@ class plain(object):
def init_data(self): def init_data(self):
return b'' return b''
def get_server_info(self):
return self.server_info
def set_server_info(self, server_info): def set_server_info(self, server_info):
self.server_info = server_info self.server_info = server_info
@ -79,7 +82,7 @@ class plain(object):
return buf return buf
def server_udp_post_decrypt(self, buf): def server_udp_post_decrypt(self, buf):
return buf return (buf, None)
def dispose(self): def dispose(self):
pass pass

8
shadowsocks/obfsplugin/verify.py

@ -350,11 +350,11 @@ class verify_sha1(verify_base):
def server_udp_post_decrypt(self, buf): def server_udp_post_decrypt(self, buf):
if buf and ((ord(buf[0]) & 0x10) == 0x10): if buf and ((ord(buf[0]) & 0x10) == 0x10):
if len(buf) <= 11: if len(buf) <= 11:
return b'' return (b'', None)
sha1data = hmac.new(self.server_info.recv_iv + self.server_info.key, buf[:-10], hashlib.sha1).digest()[:10] sha1data = hmac.new(self.server_info.recv_iv + self.server_info.key, buf[:-10], hashlib.sha1).digest()[:10]
if sha1data != buf[-10:]: if sha1data != buf[-10:]:
return b'' return (b'', None)
return to_bytes(chr(ord(buf[0]) & 0xEF)) + buf[1:-10] return (to_bytes(chr(ord(buf[0]) & 0xEF)) + buf[1:-10], None)
else: else:
return buf return (buf, None)

48
shadowsocks/tcprelay.py

@ -106,6 +106,7 @@ class TCPRelayHandler(object):
self._dns_resolver = dns_resolver self._dns_resolver = dns_resolver
self._client_address = local_sock.getpeername()[:2] self._client_address = local_sock.getpeername()[:2]
self._accept_address = local_sock.getsockname()[:2] self._accept_address = local_sock.getsockname()[:2]
self._user = None
# TCP Relay works as either sslocal or ssserver # TCP Relay works as either sslocal or ssserver
# if is_local, this is sslocal # if is_local, this is sslocal
@ -123,6 +124,8 @@ class TCPRelayHandler(object):
server_info = obfs.server_info(server.obfs_data) server_info = obfs.server_info(server.obfs_data)
server_info.host = config['server'] server_info.host = config['server']
server_info.port = server._listen_port server_info.port = server._listen_port
#server_info.users = server.server_users
#server_info.update_user_func = self._update_user
server_info.client = self._client_address[0] server_info.client = self._client_address[0]
server_info.client_port = self._client_address[1] server_info.client_port = self._client_address[1]
server_info.protocol_param = '' server_info.protocol_param = ''
@ -139,6 +142,8 @@ class TCPRelayHandler(object):
server_info = obfs.server_info(server.protocol_data) server_info = obfs.server_info(server.protocol_data)
server_info.host = config['server'] server_info.host = config['server']
server_info.port = server._listen_port server_info.port = server._listen_port
server_info.users = server.server_users
server_info.update_user_func = self._update_user
server_info.client = self._client_address[0] server_info.client = self._client_address[0]
server_info.client_port = self._client_address[1] server_info.client_port = self._client_address[1]
server_info.protocol_param = config['protocol_param'] server_info.protocol_param = config['protocol_param']
@ -203,6 +208,9 @@ class TCPRelayHandler(object):
logging.debug('chosen server: %s:%d', server, server_port) logging.debug('chosen server: %s:%d', server, server_port)
return server, server_port return server, server_port
def _update_user(self, user):
self._user = user
def _update_activity(self, data_len=0): def _update_activity(self, data_len=0):
# tell the TCP Relay we have activities recently # tell the TCP Relay we have activities recently
# else it will think we are inactive and timed out # else it will think we are inactive and timed out
@ -303,7 +311,7 @@ class TCPRelayHandler(object):
try: try:
if self._encrypt_correct: if self._encrypt_correct:
if sock == self._remote_sock: if sock == self._remote_sock:
self._server.server_transfer_ul += len(data) self._server.add_transfer_u(self._user, len(data))
self._update_activity(len(data)) self._update_activity(len(data))
if data: if data:
l = len(data) l = len(data)
@ -839,7 +847,7 @@ class TCPRelayHandler(object):
data = self._encryptor.encrypt(data) data = self._encryptor.encrypt(data)
data = self._obfs.server_encode(data) data = self._obfs.server_encode(data)
self._update_activity(len(data)) self._update_activity(len(data))
self._server.server_transfer_dl += len(data) self._server.add_transfer_d(self._user, len(data))
else: else:
return return
try: try:
@ -989,6 +997,9 @@ class TCPRelay(object):
self._fd_to_handlers = {} self._fd_to_handlers = {}
self.server_transfer_ul = 0 self.server_transfer_ul = 0
self.server_transfer_dl = 0 self.server_transfer_dl = 0
self.server_users = {}
self.server_user_transfer_ul = {}
self.server_user_transfer_dl = {}
self.server_connections = 0 self.server_connections = 0
self.protocol_data = obfs.obfs(config['protocol']).init_data() self.protocol_data = obfs.obfs(config['protocol']).init_data()
self.obfs_data = obfs.obfs(config['obfs']).init_data() self.obfs_data = obfs.obfs(config['obfs']).init_data()
@ -1008,6 +1019,16 @@ class TCPRelay(object):
listen_port = config['server_port'] listen_port = config['server_port']
self._listen_port = listen_port self._listen_port = listen_port
if config['protocol'] in ["auth_aes128_md5", "auth_aes128_sha1"]:
user_list = config['protocol_param'].split(',')
if user_list:
for user in user_list:
items = user.split(':')
if len(items) == 2:
uid = struct.pack('<I', int(items[0]))
passwd = items[1]
self.add_user(uid, passwd)
addrs = socket.getaddrinfo(listen_addr, listen_port, 0, addrs = socket.getaddrinfo(listen_addr, listen_port, 0,
socket.SOCK_STREAM, socket.SOL_TCP) socket.SOCK_STREAM, socket.SOL_TCP)
if len(addrs) == 0: if len(addrs) == 0:
@ -1047,6 +1068,29 @@ class TCPRelay(object):
self.server_connections += val self.server_connections += val
logging.debug('server port %5d connections = %d' % (self._listen_port, self.server_connections,)) logging.debug('server port %5d connections = %d' % (self._listen_port, self.server_connections,))
def add_user(self, user, passwd): # user: binstr[4], passwd: str
self.server_users[user] = common.to_bytes(passwd)
def del_user(self, user, passwd):
if user in self.server_users:
del self.server_users[user]
def add_transfer_u(self, user, transfer):
if user is None:
self.server_transfer_ul += transfer
else:
if user not in self.server_user_transfer_ul:
self.server_user_transfer_ul[user] = 0
self.server_user_transfer_ul[user] += transfer
def add_transfer_d(self, user, transfer):
if user is None:
self.server_transfer_dl += transfer
else:
if user not in self.server_user_transfer_dl:
self.server_user_transfer_dl[user] = 0
self.server_user_transfer_dl[user] += transfer
def update_stat(self, port, stat_dict, val): def update_stat(self, port, stat_dict, val):
newval = stat_dict.get(0, 0) + val newval = stat_dict.get(0, 0) + val
stat_dict[0] = newval stat_dict[0] = newval

84
shadowsocks/udprelay.py

@ -888,21 +888,35 @@ class UDPRelay(object):
self._is_local = is_local self._is_local = is_local
self._udp_cache_size = config['udp_cache'] self._udp_cache_size = config['udp_cache']
self._cache = lru_cache.LRUCache(timeout=config['udp_timeout'], self._cache = lru_cache.LRUCache(timeout=config['udp_timeout'],
close_callback=self._close_client) close_callback=self._close_client_pair)
self._cache_dns_client = lru_cache.LRUCache(timeout=10, self._cache_dns_client = lru_cache.LRUCache(timeout=10,
close_callback=self._close_client) close_callback=self._close_client_pair)
self._client_fd_to_server_addr = {} self._client_fd_to_server_addr = {}
self._dns_cache = lru_cache.LRUCache(timeout=300) self._dns_cache = lru_cache.LRUCache(timeout=300)
self._eventloop = None self._eventloop = None
self._closed = False self._closed = False
self.server_transfer_ul = 0 self.server_transfer_ul = 0
self.server_transfer_dl = 0 self.server_transfer_dl = 0
self.server_users = {}
self.server_user_transfer_ul = {}
self.server_user_transfer_dl = {}
if config['protocol'] in ["auth_aes128_md5", "auth_aes128_sha1"]:
user_list = config['protocol_param'].split(',')
if user_list:
for user in user_list:
items = user.split(':')
if len(items) == 2:
uid = struct.pack('<I', int(items[0]))
passwd = items[1]
self.add_user(uid, passwd)
self.protocol_data = obfs.obfs(config['protocol']).init_data() self.protocol_data = obfs.obfs(config['protocol']).init_data()
self._protocol = obfs.obfs(config['protocol']) self._protocol = obfs.obfs(config['protocol'])
server_info = obfs.server_info(self.protocol_data) server_info = obfs.server_info(self.protocol_data)
server_info.host = self._listen_addr server_info.host = self._listen_addr
server_info.port = self._listen_port server_info.port = self._listen_port
server_info.users = self.server_users
server_info.protocol_param = config['protocol_param'] server_info.protocol_param = config['protocol_param']
server_info.obfs_param = '' server_info.obfs_param = ''
server_info.iv = b'' server_info.iv = b''
@ -956,6 +970,33 @@ class UDPRelay(object):
logging.debug('chosen server: %s:%d', server, server_port) logging.debug('chosen server: %s:%d', server, server_port)
return server, server_port return server, server_port
def add_user(self, user, passwd): # user: binstr[4], passwd: str
self.server_users[user] = common.to_bytes(passwd)
def del_user(self, user, passwd):
if user in self.server_users:
del self.server_users[user]
def add_transfer_u(self, user, transfer):
if user is None:
self.server_transfer_ul += transfer
else:
if user not in self.server_user_transfer_ul:
self.server_user_transfer_ul[user] = 0
self.server_user_transfer_ul[user] += transfer
def add_transfer_d(self, user, transfer):
if user is None:
self.server_transfer_dl += transfer
else:
if user not in self.server_user_transfer_dl:
self.server_user_transfer_dl[user] = 0
self.server_user_transfer_dl[user] += transfer
def _close_client_pair(self, client_pair):
client, uid = client_pair
self._close_client(client)
def _close_client(self, client): def _close_client(self, client):
if hasattr(client, 'close'): if hasattr(client, 'close'):
if not self._is_local: if not self._is_local:
@ -1039,6 +1080,7 @@ class UDPRelay(object):
logging.debug('UDP handle_server: data is empty') logging.debug('UDP handle_server: data is empty')
if self._stat_callback: if self._stat_callback:
self._stat_callback(self._listen_port, len(data)) self._stat_callback(self._listen_port, len(data))
uid = None
if self._is_local: if self._is_local:
frag = common.ord(data[2]) frag = common.ord(data[2])
if frag != 0: if frag != 0:
@ -1054,7 +1096,7 @@ class UDPRelay(object):
logging.debug('UDP handle_server: data is empty after decrypt') logging.debug('UDP handle_server: data is empty after decrypt')
return return
self._protocol.obfs.server_info.recv_iv = ref_iv[0] self._protocol.obfs.server_info.recv_iv = ref_iv[0]
data = self._protocol.server_udp_post_decrypt(data) data, uid = self._protocol.server_udp_post_decrypt(data)
#logging.info("UDP data %s" % (binascii.hexlify(data),)) #logging.info("UDP data %s" % (binascii.hexlify(data),))
if not self._is_local: if not self._is_local:
@ -1097,10 +1139,10 @@ class UDPRelay(object):
af, socktype, proto, canonname, sa = addrs[0] af, socktype, proto, canonname, sa = addrs[0]
key = client_key(r_addr, af) key = client_key(r_addr, af)
client = self._cache.get(key, None) client_pair = self._cache.get(key, None)
if not client: if not client_pair:
client = self._cache_dns_client.get(key, None) client_pair = self._cache_dns_client.get(key, None)
if not client: if not client_pair:
if self._forbidden_iplist: if self._forbidden_iplist:
if common.to_str(sa[0]) in self._forbidden_iplist: if common.to_str(sa[0]) in self._forbidden_iplist:
logging.debug('IP %s is in forbidden list, drop' % logging.debug('IP %s is in forbidden list, drop' %
@ -1114,6 +1156,7 @@ class UDPRelay(object):
# drop # drop
return return
client = socket.socket(af, socktype, proto) client = socket.socket(af, socktype, proto)
client_uid = uid
client.setblocking(False) client.setblocking(False)
self._socket_bind_addr(client, af) self._socket_bind_addr(client, af)
is_dns = False is_dns = False
@ -1124,9 +1167,9 @@ class UDPRelay(object):
#logging.info("unknown data %s" % (binascii.hexlify(data),)) #logging.info("unknown data %s" % (binascii.hexlify(data),))
if sa[1] == 53 and is_dns: #DNS if sa[1] == 53 and is_dns: #DNS
logging.debug("DNS query %s from %s:%d" % (common.to_str(sa[0]), r_addr[0], r_addr[1])) logging.debug("DNS query %s from %s:%d" % (common.to_str(sa[0]), r_addr[0], r_addr[1]))
self._cache_dns_client[key] = client self._cache_dns_client[key] = (client, uid)
else: else:
self._cache[key] = client self._cache[key] = (client, uid)
self._client_fd_to_server_addr[client.fileno()] = (r_addr, af) self._client_fd_to_server_addr[client.fileno()] = (r_addr, af)
self._sockets.add(client.fileno()) self._sockets.add(client.fileno())
@ -1137,7 +1180,8 @@ class UDPRelay(object):
common.connect_log('UDP data to %s:%d via port %d' % common.connect_log('UDP data to %s:%d via port %d' %
(common.to_str(server_addr), server_port, (common.to_str(server_addr), server_port,
self._listen_port)) self._listen_port))
else:
client, client_uid = client_pair
self._cache.clear(self._udp_cache_size) self._cache.clear(self._udp_cache_size)
self._cache_dns_client.clear(16) self._cache_dns_client.clear(16)
@ -1156,7 +1200,7 @@ class UDPRelay(object):
try: try:
#logging.info('UDP handle_server sendto %s:%d %d bytes' % (common.to_str(server_addr), server_port, len(data))) #logging.info('UDP handle_server sendto %s:%d %d bytes' % (common.to_str(server_addr), server_port, len(data)))
client.sendto(data, (server_addr, server_port)) client.sendto(data, (server_addr, server_port))
self.server_transfer_ul += len(data) self.add_transfer_u(client_uid, len(data))
except IOError as e: except IOError as e:
err = eventloop.errno_from_exception(e) err = eventloop.errno_from_exception(e)
if err in (errno.EINPROGRESS, errno.EAGAIN): if err in (errno.EINPROGRESS, errno.EAGAIN):
@ -1266,14 +1310,22 @@ class UDPRelay(object):
response = b'\x00\x00\x00' + data response = b'\x00\x00\x00' + data
client_addr = self._client_fd_to_server_addr.get(sock.fileno()) client_addr = self._client_fd_to_server_addr.get(sock.fileno())
if client_addr: if client_addr:
self.server_transfer_dl += len(response)
self.write_to_server_socket(response, client_addr[0])
key = client_key(client_addr[0], client_addr[1]) key = client_key(client_addr[0], client_addr[1])
client = self._cache_dns_client.get(key, None) client_pair = self._cache.get(key, None)
if client: client_dns_pair = self._cache_dns_client.get(key, None)
if client_pair:
client, client_uid = client_pair
self.add_transfer_d(client_uid, len(response))
elif client_dns_pair:
client, client_uid = client_dns_pair
self.add_transfer_d(client_uid, len(response))
else:
self.server_transfer_dl += len(response)
self.write_to_server_socket(response, client_addr[0])
if client_dns_pair:
logging.debug("remove dns client %s:%d" % (client_addr[0][0], client_addr[0][1])) logging.debug("remove dns client %s:%d" % (client_addr[0][0], client_addr[0][1]))
del self._cache_dns_client[key] del self._cache_dns_client[key]
self._close_client(client) self._close_client(client_dns_pair[0])
else: else:
# this packet is from somewhere else we know # this packet is from somewhere else we know
# simply drop that packet # simply drop that packet

Loading…
Cancel
Save