Browse Source

impl auth_chain client

dev
破娃酱 8 years ago
parent
commit
caadb606ba
  1. 71
      shadowsocks/obfsplugin/auth_chain.py
  2. 57
      shadowsocks/tcprelay.py
  3. 1
      shadowsocks/udprelay.py

71
shadowsocks/obfsplugin/auth_chain.py

@ -336,34 +336,35 @@ class auth_chain_a(auth_base):
return data return data
def pack_auth_data(self, auth_data, buf): def pack_auth_data(self, auth_data, buf):
if len(buf) == 0:
return b''
if len(buf) > 400:
rnd_len = struct.unpack('<H', os.urandom(2))[0] % 512
else:
rnd_len = struct.unpack('<H', os.urandom(2))[0] % 1024
data = auth_data data = auth_data
data_len = 7 + 4 + 16 + 4 + len(buf) + rnd_len + 4 data_len = 12 + 4 + 16 + 4
data = data + struct.pack('<H', data_len) + struct.pack('<H', rnd_len) data = data + (struct.pack('<H', self.server_info.overhead) + struct.pack('<H', 0))
mac_key = self.server_info.iv + self.server_info.key mac_key = self.server_info.iv + self.server_info.key
uid = os.urandom(4)
check_head = os.urandom(4)
self.last_client_hash = hmac.new(mac_key, check_head, self.hashfunc).digest()
check_head += self.last_client_hash[:8]
if b':' in to_bytes(self.server_info.protocol_param): if b':' in to_bytes(self.server_info.protocol_param):
try: try:
items = to_bytes(self.server_info.protocol_param).split(b':') items = to_bytes(self.server_info.protocol_param).split(b':')
self.user_key = self.hashfunc(items[1]).digest() self.user_key = items[1]
uid = struct.pack('<I', int(items[0])) uid = struct.pack('<I', int(items[0]))
except: except:
pass uid = os.urandom(4)
else:
uid = os.urandom(4)
if self.user_key is None: if self.user_key is None:
self.user_key = self.server_info.key self.user_key = self.server_info.key
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc', b'\x00' * 16)
data = uid + encryptor.encrypt(data)[16:] encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key))+ to_bytes(base64.b64encode(self.last_client_hash)) + self.salt, 'aes-128-cbc', b'\x00' * 16)
data += hmac.new(mac_key, data, self.hashfunc).digest()[:4]
check_head = os.urandom(1) uid = struct.unpack('<I', uid)[0] ^ struct.unpack('<I', self.last_client_hash[8:12])[0]
check_head += hmac.new(mac_key, check_head, self.hashfunc).digest()[:6] uid = struct.pack('<I', uid)
data = check_head + data + os.urandom(rnd_len) + buf data = check_head + uid + encryptor.encrypt(data)[16:]
data += hmac.new(self.user_key, data, self.hashfunc).digest()[:4] self.last_server_hash = hmac.new(mac_key, data, self.hashfunc).digest()
return data data += self.last_server_hash[:4]
return data + self.pack_client_data(buf)
def auth_data(self): def auth_data(self):
utc_time = int(time.time()) & 0xFFFFFFFF utc_time = int(time.time()) & 0xFFFFFFFF
@ -400,30 +401,34 @@ class auth_chain_a(auth_base):
out_buf = b'' out_buf = b''
while len(self.recv_buf) > 4: while len(self.recv_buf) > 4:
mac_key = self.user_key + struct.pack('<I', self.recv_id) mac_key = self.user_key + struct.pack('<I', self.recv_id)
mac = hmac.new(mac_key, self.recv_buf[:2], self.hashfunc).digest()[:2] data_len = struct.unpack('<H', self.recv_buf[:2])[0] ^ struct.unpack('<H', self.last_server_hash[14:16])[0]
if mac != self.recv_buf[2:4]: rand_len = self.rnd_data_len(data_len, self.last_server_hash, self.random_server)
raise Exception('client_post_decrypt data uncorrect mac') length = data_len + rand_len
length = struct.unpack('<H', self.recv_buf[:2])[0] if length >= 4096:
if length >= 8192 or length < 7:
self.raw_trans = True self.raw_trans = True
self.recv_buf = b'' self.recv_buf = b''
raise Exception('client_post_decrypt data error') raise Exception('client_post_decrypt data error')
if length > len(self.recv_buf):
if length + 4 > len(self.recv_buf):
break break
if hmac.new(mac_key, self.recv_buf[:length - 4], self.hashfunc).digest()[:4] != self.recv_buf[length - 4:length]: server_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest()
if server_hash[:2] != self.recv_buf[length + 2 : length + 4]:
logging.info('%s: checksum error, data %s' % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length])))
self.raw_trans = True self.raw_trans = True
self.recv_buf = b'' self.recv_buf = b''
raise Exception('client_post_decrypt data uncorrect checksum') raise Exception('client_post_decrypt data uncorrect checksum')
pos = 2
if data_len > 0 and rand_len > 0:
pos = 2 + self.rnd_start_pos(rand_len, self.random_server)
out_buf += self.encryptor.decrypt(self.recv_buf[pos : data_len + pos])
self.last_server_hash = server_hash
if self.recv_id == 1:
self.server_info.tcp_mss = out_buf[:2]
out_buf = out_buf[2:]
self.recv_id = (self.recv_id + 1) & 0xFFFFFFFF self.recv_id = (self.recv_id + 1) & 0xFFFFFFFF
pos = common.ord(self.recv_buf[4]) self.recv_buf = self.recv_buf[length + 4:]
if pos < 255:
pos += 4
else:
pos = struct.unpack('<H', self.recv_buf[5:7])[0] + 4
out_buf += self.recv_buf[pos:length - 4]
self.recv_buf = self.recv_buf[length:]
return out_buf return out_buf

57
shadowsocks/tcprelay.py

@ -135,35 +135,25 @@ class TCPRelayHandler(object):
self._remote_udp = False self._remote_udp = False
self._config = config self._config = config
self._dns_resolver = dns_resolver self._dns_resolver = dns_resolver
if not self._create_encryptor(config):
return
self._client_address = local_sock.getpeername()[:2] self._client_address = local_sock.getpeername()[:2]
self._accept_address = local_sock.getsockname()[:2] self._accept_address = local_sock.getsockname()[:2]
self._user = None self._user = None
self._user_id = server._listen_port self._user_id = server._listen_port
self._tcp_mss = TCP_MSS self._update_tcp_mss(local_sock)
# TCP Relay works as either sslocal or ssserver # TCP Relay works as either sslocal or ssserver
# if is_local, this is sslocal # if is_local, this is sslocal
self._is_local = is_local self._is_local = is_local
self._stage = STAGE_INIT self._stage = STAGE_INIT
try:
self._encryptor = encrypt.Encryptor(config['password'],
config['method'])
except Exception:
self._stage = STAGE_DESTROYED
logging.error('create encryptor fail at port %d', server._listen_port)
return
self._encrypt_correct = True self._encrypt_correct = True
self._obfs = obfs.obfs(config['obfs']) self._obfs = obfs.obfs(config['obfs'])
self._protocol = obfs.obfs(config['protocol']) self._protocol = obfs.obfs(config['protocol'])
self._overhead = self._obfs.get_overhead(self._is_local) + self._protocol.get_overhead(self._is_local) self._overhead = self._obfs.get_overhead(self._is_local) + self._protocol.get_overhead(self._is_local)
self._recv_buffer_size = BUF_SIZE - self._overhead self._recv_buffer_size = BUF_SIZE - self._overhead
try:
self._tcp_mss = local_sock.getsockopt(socket.SOL_TCP, socket.TCP_MAXSEG)
logging.debug("TCP MSS = %d" % (self._tcp_mss,))
except:
pass
server_info = obfs.server_info(server.obfs_data) server_info = obfs.server_info(server.obfs_data)
server_info.host = config['server'] server_info.host = config['server']
server_info.port = server._listen_port server_info.port = server._listen_port
@ -180,6 +170,7 @@ class TCPRelayHandler(object):
server_info.head_len = 30 server_info.head_len = 30
server_info.tcp_mss = self._tcp_mss server_info.tcp_mss = self._tcp_mss
server_info.buffer_size = self._recv_buffer_size server_info.buffer_size = self._recv_buffer_size
server_info.overhead = self._overhead
self._obfs.set_server_info(server_info) self._obfs.set_server_info(server_info)
server_info = obfs.server_info(server.protocol_data) server_info = obfs.server_info(server.protocol_data)
@ -198,6 +189,7 @@ class TCPRelayHandler(object):
server_info.head_len = 30 server_info.head_len = 30
server_info.tcp_mss = self._tcp_mss server_info.tcp_mss = self._tcp_mss
server_info.buffer_size = self._recv_buffer_size server_info.buffer_size = self._recv_buffer_size
server_info.overhead = self._overhead
self._protocol.set_server_info(server_info) self._protocol.set_server_info(server_info)
self._redir_list = config.get('redirect', ["*#0.0.0.0:0"]) self._redir_list = config.get('redirect', ["*#0.0.0.0:0"])
@ -213,27 +205,24 @@ class TCPRelayHandler(object):
self._upstream_status = WAIT_STATUS_READING self._upstream_status = WAIT_STATUS_READING
self._downstream_status = WAIT_STATUS_INIT self._downstream_status = WAIT_STATUS_INIT
self._remote_address = None self._remote_address = None
if 'forbidden_ip' in config:
self._forbidden_iplist = config['forbidden_ip'] self._forbidden_iplist = config.get('forbidden_ip', None)
else: self._forbidden_portset = config.get('forbidden_port', None)
self._forbidden_iplist = None
if 'forbidden_port' in config:
self._forbidden_portset = config['forbidden_port']
else:
self._forbidden_portset = None
if is_local: if is_local:
self._chosen_server = self._get_a_server() self._chosen_server = self._get_a_server()
fd_to_handlers[local_sock.fileno()] = self fd_to_handlers[local_sock.fileno()] = self
local_sock.setblocking(False) local_sock.setblocking(False)
local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR, loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR, self._server)
self._server)
self.last_activity = 0 self.last_activity = 0
self._update_activity() self._update_activity()
self._server.add_connection(1) self._server.add_connection(1)
self._server.stat_add(self._client_address[0], 1) self._server.stat_add(self._client_address[0], 1)
self.speed_tester_u = SpeedTester(config.get("speed_limit_per_con", 0)) self.speed_tester_u = SpeedTester(config.get("speed_limit_per_con", 0))
self.speed_tester_d = SpeedTester(config.get("speed_limit_per_con", 0)) self.speed_tester_d = SpeedTester(config.get("speed_limit_per_con", 0))
self._recv_pack_id = 0
def __hash__(self): def __hash__(self):
# default __hash__ is id / 16 # default __hash__ is id / 16
@ -254,6 +243,23 @@ class TCPRelayHandler(object):
logging.debug('chosen server: %s:%d', server, server_port) logging.debug('chosen server: %s:%d', server, server_port)
return server, server_port return server, server_port
def _update_tcp_mss(self, local_sock):
self._tcp_mss = TCP_MSS
try:
self._tcp_mss = local_sock.getsockopt(socket.SOL_TCP, socket.TCP_MAXSEG)
logging.debug("TCP MSS = %d" % (self._tcp_mss,))
except:
pass
def _create_encryptor(self, config):
try:
self._encryptor = encrypt.Encryptor(config['password'],
config['method'])
return True
except Exception:
self._stage = STAGE_DESTROYED
logging.error('create encryptor fail at port %d', self._server._listen_port)
def _update_user(self, user): def _update_user(self, user):
self._user = user self._user = user
self._user_id = struct.unpack('<I', user)[0] self._user_id = struct.unpack('<I', user)[0]
@ -884,6 +890,7 @@ class TCPRelayHandler(object):
else: else:
recv_buffer_size = self._get_read_size(self._remote_sock, self._recv_buffer_size) recv_buffer_size = self._get_read_size(self._remote_sock, self._recv_buffer_size)
data = self._remote_sock.recv(recv_buffer_size) data = self._remote_sock.recv(recv_buffer_size)
self._recv_pack_id += 1
except (OSError, IOError) as e: except (OSError, IOError) as e:
if eventloop.errno_from_exception(e) in \ if eventloop.errno_from_exception(e) in \
(errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): #errno.WSAEWOULDBLOCK (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): #errno.WSAEWOULDBLOCK
@ -912,6 +919,8 @@ class TCPRelayHandler(object):
data = self._encryptor.decrypt(obfs_decode[0]) data = self._encryptor.decrypt(obfs_decode[0])
try: try:
data = self._protocol.client_post_decrypt(data) data = self._protocol.client_post_decrypt(data)
if self._recv_pack_id == 1:
self._tcp_mss = self._protocol.get_server_info().tcp_mss
except Exception as e: except Exception as e:
shell.print_exception(e) shell.print_exception(e)
logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1]))

1
shadowsocks/udprelay.py

@ -181,6 +181,7 @@ class UDPRelay(object):
server_info.head_len = 30 server_info.head_len = 30
server_info.tcp_mss = 1452 server_info.tcp_mss = 1452
server_info.buffer_size = BUF_SIZE server_info.buffer_size = BUF_SIZE
server_info.overhead = 0
self._protocol.set_server_info(server_info) self._protocol.set_server_info(server_info)
self._sockets = set() self._sockets = set()

Loading…
Cancel
Save