From aff97d4ce8cba4dc8835037981512e22686fe0b3 Mon Sep 17 00:00:00 2001 From: BreakWa11 Date: Mon, 14 Dec 2015 00:55:44 +0800 Subject: [PATCH] improve LRUCache add "tls1.0_session_auth" fix "auth_sha1" in local mode refine log --- shadowsocks/lru_cache.py | 75 +++---- shadowsocks/obfs.py | 3 +- shadowsocks/obfsplugin/auth.py | 16 +- shadowsocks/obfsplugin/http_simple.py | 78 -------- shadowsocks/obfsplugin/obfs_tls.py | 276 ++++++++++++++++++++++++++ shadowsocks/obfsplugin/verify.py | 6 +- shadowsocks/tcprelay.py | 18 +- 7 files changed, 342 insertions(+), 130 deletions(-) create mode 100644 shadowsocks/obfsplugin/obfs_tls.py diff --git a/shadowsocks/lru_cache.py b/shadowsocks/lru_cache.py index e67fdff..3b3c264 100644 --- a/shadowsocks/lru_cache.py +++ b/shadowsocks/lru_cache.py @@ -22,14 +22,14 @@ import collections import logging import time - # this LRUCache is optimized for concurrency, not QPS # n: concurrency, keys stored in the cache # m: visits not timed out, proportional to QPS * timeout -# get & set is O(1), not O(n). thus we can support very large n -# TODO: if timeout or QPS is too large, then this cache is not very efficient, -# as sweep() causes long pause +# get & set is O(log(n)), not O(n). thus we can support very large n +# sweep is O((n - m)*log(n)) or O(1024*log(n)) at most, +# no metter how large the cache or timeout value is +SWEEP_MAX_ITEMS = 1024 class LRUCache(collections.MutableMapping): """This class is not thread safe""" @@ -38,32 +38,39 @@ class LRUCache(collections.MutableMapping): self.timeout = timeout self.close_callback = close_callback self._store = {} - self._time_to_keys = collections.defaultdict(list) + self._time_to_keys = collections.OrderedDict() self._keys_to_last_time = {} - self._last_visits = collections.deque() - self._closed_values = set() + self._visit_id = 0 self.update(dict(*args, **kwargs)) # use the free update to set keys def __getitem__(self, key): - # O(1) + # O(log(n)) t = time.time() - self._keys_to_last_time[key] = t - self._time_to_keys[t].append(key) - self._last_visits.append(t) + last_t, vid = self._keys_to_last_time[key] + self._keys_to_last_time[key] = (t, vid) + if last_t != t: + del self._time_to_keys[(last_t, vid)] + self._time_to_keys[(t, vid)] = key return self._store[key] def __setitem__(self, key, value): - # O(1) + # O(log(n)) t = time.time() - self._keys_to_last_time[key] = t + if key in self._keys_to_last_time: + last_t, vid = self._keys_to_last_time[key] + del self._time_to_keys[(last_t, vid)] + vid = self._visit_id + self._visit_id += 1 + self._keys_to_last_time[key] = (t, vid) self._store[key] = value - self._time_to_keys[t].append(key) - self._last_visits.append(t) + self._time_to_keys[(t, vid)] = key def __delitem__(self, key): - # O(1) + # O(log(n)) + last_t, vid = self._keys_to_last_time[key] del self._store[key] del self._keys_to_last_time[key] + del self._time_to_keys[(last_t, vid)] def __iter__(self): return iter(self._store) @@ -72,39 +79,33 @@ class LRUCache(collections.MutableMapping): return len(self._store) def sweep(self): - # O(m) + # O(n - m) now = time.time() c = 0 - while len(self._last_visits) > 0: - least = self._last_visits[0] - if now - least <= self.timeout: + while c < SWEEP_MAX_ITEMS: + if len(self._time_to_keys) == 0: + break + last_t, vid = iter(self._time_to_keys).next() + if now - last_t <= self.timeout: break + key = self._time_to_keys[(last_t, vid)] + value = self._store[key] if self.close_callback is not None: - for key in self._time_to_keys[least]: - if key in self._store: - if now - self._keys_to_last_time[key] > self.timeout: - value = self._store[key] - if value not in self._closed_values: - self.close_callback(value) - self._closed_values.add(value) - for key in self._time_to_keys[least]: - if key in self._store: - if now - self._keys_to_last_time[key] > self.timeout: - del self._store[key] - del self._keys_to_last_time[key] - c += 1 - self._last_visits.popleft() - del self._time_to_keys[least] + self.close_callback(value) + del self._store[key] + del self._keys_to_last_time[key] + del self._time_to_keys[(last_t, vid)] + c += 1 if c: - self._closed_values.clear() logging.debug('%d keys swept' % c) - + return c < SWEEP_MAX_ITEMS def test(): c = LRUCache(timeout=0.3) c['a'] = 1 assert c['a'] == 1 + c['a'] = 1 time.sleep(0.5) c.sweep() diff --git a/shadowsocks/obfs.py b/shadowsocks/obfs.py index 1bfaf67..94a14f2 100644 --- a/shadowsocks/obfs.py +++ b/shadowsocks/obfs.py @@ -23,12 +23,13 @@ import hashlib import logging from shadowsocks import common -from shadowsocks.obfsplugin import plain, http_simple, verify, auth +from shadowsocks.obfsplugin import plain, http_simple, obfs_tls, verify, auth method_supported = {} method_supported.update(plain.obfs_map) method_supported.update(http_simple.obfs_map) +method_supported.update(obfs_tls.obfs_map) method_supported.update(verify.obfs_map) method_supported.update(auth.obfs_map) diff --git a/shadowsocks/obfsplugin/auth.py b/shadowsocks/obfsplugin/auth.py index 2260bb4..0f639aa 100644 --- a/shadowsocks/obfsplugin/auth.py +++ b/shadowsocks/obfsplugin/auth.py @@ -250,7 +250,7 @@ class auth_simple(verify_base): if self.decrypt_packet_num == 0: return None else: - raise Exception('server_post_decrype data error') + raise Exception('client_post_decrypt data error') if length > len(self.recv_buf): break @@ -260,7 +260,7 @@ class auth_simple(verify_base): if self.decrypt_packet_num == 0: return None else: - raise Exception('server_post_decrype data uncorrect CRC32') + raise Exception('client_post_decrypt data uncorrect CRC32') pos = common.ord(self.recv_buf[2]) + 2 out_buf += self.recv_buf[pos:length - 4] @@ -379,7 +379,7 @@ class auth_sha1(verify_base): return b'' rnd_data = os.urandom(common.ord(os.urandom(1)[0]) % 128) data = common.chr(len(rnd_data) + 1) + rnd_data + buf - data = struct.pack('>H', len(data) + 10) + data + data = struct.pack('>H', len(data) + 16) + data crc = binascii.crc32(self.server_info.key) data = struct.pack(' len(self.recv_buf): break - if zlib.adler32(self.recv_buf[:length - 4]) != struct.unpack(' len(self.recv_buf): break - if zlib.adler32(self.recv_buf[:length - 4]) & 0xFFFFFFFF != struct.unpack('H', len(data)) + data - data = b"\x16\x03\x01" + struct.pack('>H', len(data)) + data - return data - if self.has_recv_header: - ret = self.send_buffer - self.send_buffer = b'' - self.raw_trans_sent = True - return ret - return b'' - - def client_decode(self, buf): - if self.has_recv_header: - return (buf, False) - self.has_recv_header = True - return (b'', True) - - def server_encode(self, buf): - if self.has_sent_header: - return buf - self.has_sent_header = True - # TODO - data = b"\x03\x03" + os.urandom(32) - data = b"\x02\x00" + struct.pack('>H', len(data)) + data - data = b"\x16\x03\x01" + struct.pack('>H', len(data)) + data - return data - - def decode_error_return(self, buf): - self.has_sent_header = True - if self.method == 'tls_simple': - return (b'E', False, False) - return (buf, True, False) - - def server_decode(self, buf): - if self.has_recv_header: - return (buf, True, False) - - self.has_recv_header = True - if not match_begin(buf, b'\x16\x03\x01'): - return self.decode_error_return(buf); - buf = buf[3:] - if struct.unpack('>H', buf[:2])[0] != len(buf) - 2: - return self.decode_error_return(buf); - buf = buf[2:] - if not match_begin(buf, b'\x01\x00'): #client hello - return self.decode_error_return(buf); - buf = buf[2:] - if struct.unpack('>H', buf[:2])[0] != len(buf) - 2: - return self.decode_error_return(buf); - buf = buf[2:] - if not match_begin(buf, b'\x03\x03'): - return self.decode_error_return(buf); - buf = buf[2:] - verifyid = buf[:32] - buf = buf[32:] - sessionid = buf[:4] - # (buffer_to_recv, is_need_decrypt, is_need_to_encode_and_send_back) - return (b'', False, True) - class random_head(plain.plain): def __init__(self, method): self.method = method diff --git a/shadowsocks/obfsplugin/obfs_tls.py b/shadowsocks/obfsplugin/obfs_tls.py new file mode 100644 index 0000000..82f25d0 --- /dev/null +++ b/shadowsocks/obfsplugin/obfs_tls.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python +# +# Copyright 2015-2015 breakwa11 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import sys +import hashlib +import logging +import binascii +import struct +import base64 +import time +import random +import hmac +import hashlib + +from shadowsocks import common +from shadowsocks.obfsplugin import plain +from shadowsocks.common import to_bytes, to_str, ord +from shadowsocks import lru_cache + +def create_tls_obfs(method): + return tls_simple(method) + +def create_tls_auth_obfs(method): + return tls_auth(method) + +obfs_map = { + 'tls_simple': (create_tls_obfs,), + 'tls_simple_compatible': (create_tls_obfs,), + 'tls1.0_session_auth': (create_tls_auth_obfs,), + 'tls1.0_session_auth_compatible': (create_tls_auth_obfs,), +} + +def match_begin(str1, str2): + if len(str1) >= len(str2): + if str1[:len(str2)] == str2: + return True + return False + +class tls_simple(plain.plain): + def __init__(self, method): + self.method = method + self.has_sent_header = False + self.has_recv_header = False + self.raw_trans_sent = False + self.send_buffer = b'' + self.tls_version = b'\x03\x01' + + def client_encode(self, buf): + if self.raw_trans_sent: + return buf + self.send_buffer += buf + if not self.has_sent_header: + self.has_sent_header = True + data = self.tls_version + os.urandom(32) + binascii.unhexlify(b"000016c02bc02fc00ac009c013c01400330039002f0035000a0100006fff01000100000a00080006001700180019000b0002010000230000337400000010002900270568322d31360568322d31350568322d313402683208737064792f332e3108687474702f312e31000500050100000000000d001600140401050106010201040305030603020304020202") + data = b"\x01\x00" + struct.pack('>H', len(data)) + data + data = b"\x16" + self.tls_version + struct.pack('>H', len(data)) + data + return data + if self.has_recv_header: + ret = self.send_buffer + self.send_buffer = b'' + self.raw_trans_sent = True + return ret + return b'' + + def client_decode(self, buf): + if self.has_recv_header: + return (buf, False) + self.has_recv_header = True + return (b'', True) + + def server_encode(self, buf): + if self.has_sent_header: + return buf + self.has_sent_header = True + # TODO + data = self.tls_version + os.urandom(32) + data = b"\x02\x00" + struct.pack('>H', len(data)) + data + data = b"\x16" + self.tls_version + struct.pack('>H', len(data)) + data + return data + + def decode_error_return(self, buf): + self.has_sent_header = True + if self.method == 'tls_simple': + return (b'E', False, False) + return (buf, True, False) + + def server_decode(self, buf): + if self.has_recv_header: + return (buf, True, False) + + self.has_recv_header = True + if not match_begin(buf, b'\x16' + self.tls_version): + return self.decode_error_return(buf) + buf = buf[3:] + if struct.unpack('>H', buf[:2])[0] != len(buf) - 2: + return self.decode_error_return(buf) + buf = buf[2:] + if not match_begin(buf, b'\x01\x00'): #client hello + return self.decode_error_return(buf) + buf = buf[2:] + if struct.unpack('>H', buf[:2])[0] != len(buf) - 2: + return self.decode_error_return(buf) + buf = buf[2:] + if not match_begin(buf, self.tls_version): + return self.decode_error_return(buf) + buf = buf[2:] + verifyid = buf[:32] + buf = buf[32:] + sessionid_len = ord(buf[1]) + sessionid = buf[1:sessionid_len + 1] + buf = buf[sessionid_len+1:] + # (buffer_to_recv, is_need_decrypt, is_need_to_encode_and_send_back) + return (b'', False, True) + +class obfs_client_data(object): + def __init__(self, cid): + self.client_id = cid + self.auth_code = {} + +class obfs_auth_data(object): + def __init__(self): + self.client_data = lru_cache.LRUCache(60 * 5) + self.client_id = os.urandom(32) + self.startup_time = int(time.time() - 60 * 30) & 0xFFFFFFFF + +class tls_auth(plain.plain): + def __init__(self, method): + self.method = method + self.has_sent_header = False + self.has_recv_header = False + self.raw_trans_sent = False + self.raw_trans_recv = False + self.send_buffer = b'' + self.client_id = b'' + self.max_time_dif = 60 * 60 # time dif (second) setting + self.tls_version = b'\x03\x01' + + def init_data(self): + return obfs_auth_data() + + def pack_auth_data(self, client_id): + utc_time = int(time.time()) & 0xFFFFFFFF + data = struct.pack('>I', utc_time) + os.urandom(18) + data += hmac.new(self.server_info.key + client_id, data, hashlib.sha1).digest()[:10] + return data + + def client_encode(self, buf): + if self.raw_trans_sent: + return buf + self.send_buffer += buf + if not self.has_sent_header: + self.has_sent_header = True + data = self.tls_version + self.pack_auth_data(self.server_info.data.client_id) + b"\x20" + self.server_info.data.client_id + binascii.unhexlify(b"0016c02bc02fc00ac009c013c01400330039002f0035000a0100006fff01000100000a00080006001700180019000b0002010000230000337400000010002900270568322d31360568322d31350568322d313402683208737064792f332e3108687474702f312e31000500050100000000000d001600140401050106010201040305030603020304020202") + data = b"\x01\x00" + struct.pack('>H', len(data)) + data + data = b"\x16" + self.tls_version + struct.pack('>H', len(data)) + data + return data + if self.has_recv_header: + data = b"\x14" + self.tls_version + "\x00\x01\x01" #ChangeCipherSpec + data += b"\x16" + self.tls_version + "\x00\x01\x20" + os.urandom(22) #Finished + data += hmac.new(self.server_info.key + self.server_info.data.client_id, data, hashlib.sha1).digest()[:10] + ret = data + self.send_buffer + self.send_buffer = b'' + self.raw_trans_sent = True + return ret + return b'' + + def client_decode(self, buf): + if self.has_recv_header: + return (buf, False) + self.has_recv_header = True + return (b'', True) + + def server_encode(self, buf): + if self.has_sent_header: + return buf + self.has_sent_header = True + data = self.tls_version + self.pack_auth_data(self.client_id) + b"\x20" + self.client_id + binascii.unhexlify(b"0016c02bc02fc00ac009c013c01400330039002f0035000a0100006fff01000100000a00080006001700180019000b0002010000230000337400000010002900270568322d31360568322d31350568322d313402683208737064792f332e3108687474702f312e31000500050100000000000d001600140401050106010201040305030603020304020202") + data = b"\x02\x00" + struct.pack('>H', len(data)) + data #server hello + data = b"\x16" + self.tls_version + struct.pack('>H', len(data)) + data + data += b"\x14" + self.tls_version + "\x00\x01\x01" #ChangeCipherSpec + data += b"\x16" + self.tls_version + "\x00\x01\x20" + os.urandom(22) #Finished + data += hmac.new(self.server_info.key + self.client_id, data, hashlib.sha1).digest()[:10] + return data + + def decode_error_return(self, buf): + self.raw_trans_recv = True + if self.method == 'tls_simple': + return (b'E', False, False) + return (buf, True, False) + + def server_decode(self, buf): + if self.raw_trans_recv: + return (buf, True, False) + + if self.has_recv_header: + verify = buf + verify_len = 44 - 10 + if len(buf) < 44: + logging.error('server_decode data error') + return decode_error_return(b'') + if not match_begin(buf, b"\x14" + self.tls_version + "\x00\x01\x01"): #ChangeCipherSpec + logging.error('server_decode data error') + return decode_error_return(b'') + buf = buf[6:] + if not match_begin(buf, b"\x16" + self.tls_version + "\x00\x01\x20"): #Finished + logging.error('server_decode data error') + return decode_error_return(b'') + if hmac.new(self.server_info.key + self.client_id, verify[:verify_len], hashlib.sha1).digest()[:10] != verify[verify_len:verify_len+10]: + logging.error('server_decode data error') + return decode_error_return(b'') + if len(buf) < 38: + logging.error('server_decode data error') + return decode_error_return(b'') + buf = buf[38:] + self.raw_trans_recv = True + return (buf, True, False) + + self.has_recv_header = True + ogn_buf = buf + if not match_begin(buf, b'\x16' + self.tls_version): + return self.decode_error_return(ogn_buf) + buf = buf[3:] + if struct.unpack('>H', buf[:2])[0] != len(buf) - 2: + return self.decode_error_return(ogn_buf) + buf = buf[2:] + if not match_begin(buf, b'\x01\x00'): #client hello + return self.decode_error_return(ogn_buf) + buf = buf[2:] + if struct.unpack('>H', buf[:2])[0] != len(buf) - 2: + return self.decode_error_return(ogn_buf) + buf = buf[2:] + if not match_begin(buf, self.tls_version): + return self.decode_error_return(ogn_buf) + buf = buf[2:] + verifyid = buf[:32] + buf = buf[32:] + sessionid_len = ord(buf[0]) + if sessionid_len < 32: + logging.error("tls_auth wrong sessionid_len") + return self.decode_error_return(ogn_buf) + sessionid = buf[1:sessionid_len + 1] + buf = buf[sessionid_len+1:] + self.client_id = sessionid + sha1 = hmac.new(self.server_info.key + sessionid, verifyid[:22], hashlib.sha1).digest()[:10] + utc_time = struct.unpack('>I', verifyid[:4])[0] + time_dif = common.int32((int(time.time()) & 0xffffffff) - utc_time) + if time_dif < -self.max_time_dif or time_dif > self.max_time_dif \ + or common.int32(utc_time - self.server_info.data.startup_time) < -self.max_time_dif / 2: + logging.debug("tls_auth wrong time") + return self.decode_error_return(ogn_buf) + if sha1 != verifyid[22:]: + logging.debug("tls_auth wrong sha1") + return self.decode_error_return(ogn_buf) + if verifyid[4:22] in self.server_info.data.client_data: + logging.error("replay attack detect, id = %s" % (binascii.hexlify(verifyid))) + return self.decode_error_return(ogn_buf) + # (buffer_to_recv, is_need_decrypt, is_need_to_encode_and_send_back) + return (b'', False, True) + diff --git a/shadowsocks/obfsplugin/verify.py b/shadowsocks/obfsplugin/verify.py index 5c54c78..8b17345 100644 --- a/shadowsocks/obfsplugin/verify.py +++ b/shadowsocks/obfsplugin/verify.py @@ -127,7 +127,7 @@ class verify_simple(verify_base): if self.decrypt_packet_num == 0: return None else: - raise Exception('server_post_decrype data error') + raise Exception('client_post_decrypt data error') if length > len(self.recv_buf): break @@ -137,7 +137,7 @@ class verify_simple(verify_base): if self.decrypt_packet_num == 0: return None else: - raise Exception('server_post_decrype data uncorrect CRC32') + raise Exception('client_post_decrypt data uncorrect CRC32') pos = common.ord(self.recv_buf[2]) + 2 out_buf += self.recv_buf[pos:length - 4] @@ -224,7 +224,7 @@ class verify_deflate(verify_base): if self.decrypt_packet_num == 0: return None else: - raise Exception('server_post_decrype data error') + raise Exception('client_post_decrypt data error') if length > len(self.recv_buf): break diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 1f301aa..2773d8b 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -598,7 +598,11 @@ class TCPRelayHandler(object): if not is_local: if self._encryptor is not None: if self._encrypt_correct: - obfs_decode = self._obfs.server_decode(data) + try: + obfs_decode = self._obfs.server_decode(data) + except Exception as e: + shell.print_exception(e) + self.destroy() if obfs_decode[2]: self._write_to_sock(b'', self._local_sock) if obfs_decode[1]: @@ -665,7 +669,11 @@ class TCPRelayHandler(object): return if self._encryptor is not None: if self._is_local: - obfs_decode = self._obfs.client_decode(data) + try: + obfs_decode = self._obfs.client_decode(data) + except Exception as e: + shell.print_exception(e) + self.destroy() if obfs_decode[1]: send_back = self._obfs.client_encode(b'') self._write_to_sock(send_back, self._remote_sock) @@ -673,7 +681,11 @@ class TCPRelayHandler(object): iv_len = len(self._protocol.obfs.server_info.iv) self._protocol.obfs.server_info.recv_iv = obfs_decode[0][:iv_len] data = self._encryptor.decrypt(obfs_decode[0]) - data = self._protocol.client_post_decrypt(data) + try: + data = self._protocol.client_post_decrypt(data) + except Exception as e: + shell.print_exception(e) + self.destroy() else: if self._encrypt_correct: data = self._protocol.server_pre_encrypt(data)