Browse Source

improve LRUCache

add "tls1.0_session_auth"
fix "auth_sha1" in local mode
refine log
dev
BreakWa11 9 years ago
parent
commit
aff97d4ce8
  1. 67
      shadowsocks/lru_cache.py
  2. 3
      shadowsocks/obfs.py
  3. 16
      shadowsocks/obfsplugin/auth.py
  4. 78
      shadowsocks/obfsplugin/http_simple.py
  5. 276
      shadowsocks/obfsplugin/obfs_tls.py
  6. 6
      shadowsocks/obfsplugin/verify.py
  7. 12
      shadowsocks/tcprelay.py

67
shadowsocks/lru_cache.py

@ -22,14 +22,14 @@ import collections
import logging
import time
# this LRUCache is optimized for concurrency, not QPS
# n: concurrency, keys stored in the cache
# m: visits not timed out, proportional to QPS * timeout
# get & set is O(1), not O(n). thus we can support very large n
# TODO: if timeout or QPS is too large, then this cache is not very efficient,
# as sweep() causes long pause
# get & set is O(log(n)), not O(n). thus we can support very large n
# sweep is O((n - m)*log(n)) or O(1024*log(n)) at most,
# no metter how large the cache or timeout value is
SWEEP_MAX_ITEMS = 1024
class LRUCache(collections.MutableMapping):
"""This class is not thread safe"""
@ -38,32 +38,39 @@ class LRUCache(collections.MutableMapping):
self.timeout = timeout
self.close_callback = close_callback
self._store = {}
self._time_to_keys = collections.defaultdict(list)
self._time_to_keys = collections.OrderedDict()
self._keys_to_last_time = {}
self._last_visits = collections.deque()
self._closed_values = set()
self._visit_id = 0
self.update(dict(*args, **kwargs)) # use the free update to set keys
def __getitem__(self, key):
# O(1)
# O(log(n))
t = time.time()
self._keys_to_last_time[key] = t
self._time_to_keys[t].append(key)
self._last_visits.append(t)
last_t, vid = self._keys_to_last_time[key]
self._keys_to_last_time[key] = (t, vid)
if last_t != t:
del self._time_to_keys[(last_t, vid)]
self._time_to_keys[(t, vid)] = key
return self._store[key]
def __setitem__(self, key, value):
# O(1)
# O(log(n))
t = time.time()
self._keys_to_last_time[key] = t
if key in self._keys_to_last_time:
last_t, vid = self._keys_to_last_time[key]
del self._time_to_keys[(last_t, vid)]
vid = self._visit_id
self._visit_id += 1
self._keys_to_last_time[key] = (t, vid)
self._store[key] = value
self._time_to_keys[t].append(key)
self._last_visits.append(t)
self._time_to_keys[(t, vid)] = key
def __delitem__(self, key):
# O(1)
# O(log(n))
last_t, vid = self._keys_to_last_time[key]
del self._store[key]
del self._keys_to_last_time[key]
del self._time_to_keys[(last_t, vid)]
def __iter__(self):
return iter(self._store)
@ -72,39 +79,33 @@ class LRUCache(collections.MutableMapping):
return len(self._store)
def sweep(self):
# O(m)
# O(n - m)
now = time.time()
c = 0
while len(self._last_visits) > 0:
least = self._last_visits[0]
if now - least <= self.timeout:
while c < SWEEP_MAX_ITEMS:
if len(self._time_to_keys) == 0:
break
if self.close_callback is not None:
for key in self._time_to_keys[least]:
if key in self._store:
if now - self._keys_to_last_time[key] > self.timeout:
last_t, vid = iter(self._time_to_keys).next()
if now - last_t <= self.timeout:
break
key = self._time_to_keys[(last_t, vid)]
value = self._store[key]
if value not in self._closed_values:
if self.close_callback is not None:
self.close_callback(value)
self._closed_values.add(value)
for key in self._time_to_keys[least]:
if key in self._store:
if now - self._keys_to_last_time[key] > self.timeout:
del self._store[key]
del self._keys_to_last_time[key]
del self._time_to_keys[(last_t, vid)]
c += 1
self._last_visits.popleft()
del self._time_to_keys[least]
if c:
self._closed_values.clear()
logging.debug('%d keys swept' % c)
return c < SWEEP_MAX_ITEMS
def test():
c = LRUCache(timeout=0.3)
c['a'] = 1
assert c['a'] == 1
c['a'] = 1
time.sleep(0.5)
c.sweep()

3
shadowsocks/obfs.py

@ -23,12 +23,13 @@ import hashlib
import logging
from shadowsocks import common
from shadowsocks.obfsplugin import plain, http_simple, verify, auth
from shadowsocks.obfsplugin import plain, http_simple, obfs_tls, verify, auth
method_supported = {}
method_supported.update(plain.obfs_map)
method_supported.update(http_simple.obfs_map)
method_supported.update(obfs_tls.obfs_map)
method_supported.update(verify.obfs_map)
method_supported.update(auth.obfs_map)

16
shadowsocks/obfsplugin/auth.py

@ -250,7 +250,7 @@ class auth_simple(verify_base):
if self.decrypt_packet_num == 0:
return None
else:
raise Exception('server_post_decrype data error')
raise Exception('client_post_decrypt data error')
if length > len(self.recv_buf):
break
@ -260,7 +260,7 @@ class auth_simple(verify_base):
if self.decrypt_packet_num == 0:
return None
else:
raise Exception('server_post_decrype data uncorrect CRC32')
raise Exception('client_post_decrypt data uncorrect CRC32')
pos = common.ord(self.recv_buf[2]) + 2
out_buf += self.recv_buf[pos:length - 4]
@ -379,7 +379,7 @@ class auth_sha1(verify_base):
return b''
rnd_data = os.urandom(common.ord(os.urandom(1)[0]) % 128)
data = common.chr(len(rnd_data) + 1) + rnd_data + buf
data = struct.pack('>H', len(data) + 10) + data
data = struct.pack('>H', len(data) + 16) + data
crc = binascii.crc32(self.server_info.key)
data = struct.pack('<I', crc) + data
data += hmac.new(self.server_info.iv + self.server_info.key, data, hashlib.sha1).digest()[:10]
@ -425,17 +425,17 @@ class auth_sha1(verify_base):
if self.decrypt_packet_num == 0:
return None
else:
raise Exception('server_post_decrype data error')
raise Exception('client_post_decrypt data error')
if length > len(self.recv_buf):
break
if zlib.adler32(self.recv_buf[:length - 4]) != struct.unpack('<I', self.recv_buf[length - 4:length])[0]:
if struct.pack('<I', zlib.adler32(self.recv_buf[:length - 4]) & 0xFFFFFFFF) != self.recv_buf[length - 4:length]:
self.raw_trans = True
self.recv_buf = b''
if self.decrypt_packet_num == 0:
return None
else:
raise Exception('server_post_decrype data uncorrect checksum')
raise Exception('client_post_decrypt data uncorrect checksum')
pos = common.ord(self.recv_buf[2]) + 2
out_buf += self.recv_buf[pos:length - 4]
@ -475,7 +475,7 @@ class auth_sha1(verify_base):
return b''
sha1data = hmac.new(self.server_info.recv_iv + self.server_info.key, self.recv_buf[:length - 10], hashlib.sha1).digest()[:10]
if sha1data != self.recv_buf[length - 10:length]:
logging.error('server_post_decrype data uncorrect auth HMAC-SHA1')
logging.error('auth_sha1 data uncorrect auth HMAC-SHA1')
return b'E'
pos = common.ord(self.recv_buf[6]) + 6
out_buf = self.recv_buf[pos:length - 10]
@ -520,7 +520,7 @@ class auth_sha1(verify_base):
if length > len(self.recv_buf):
break
if zlib.adler32(self.recv_buf[:length - 4]) & 0xFFFFFFFF != struct.unpack('<I', self.recv_buf[length - 4:length])[0]:
if struct.pack('<I', zlib.adler32(self.recv_buf[:length - 4]) & 0xFFFFFFFF) != self.recv_buf[length - 4:length]:
logging.info('auth_sha1: checksum error, data %s' % (binascii.hexlify(self.recv_buf[:length]),))
self.raw_trans = True
self.recv_buf = b''

78
shadowsocks/obfsplugin/http_simple.py

@ -37,9 +37,6 @@ def create_http_obfs(method):
def create_http2_obfs(method):
return http2_simple(method)
def create_tls_obfs(method):
return tls_simple(method)
def create_random_head_obfs(method):
return random_head(method)
@ -48,8 +45,6 @@ obfs_map = {
'http_simple_compatible': (create_http_obfs,),
'http2_simple': (create_http2_obfs,),
'http2_simple_compatible': (create_http2_obfs,),
'tls_simple': (create_tls_obfs,),
'tls_simple_compatible': (create_tls_obfs,),
'random_head': (create_random_head_obfs,),
'random_head_compatible': (create_random_head_obfs,),
}
@ -271,79 +266,6 @@ class http2_simple(plain.plain):
return (b'', True, False)
return self.not_match_return(buf)
class tls_simple(plain.plain):
def __init__(self, method):
self.method = method
self.has_sent_header = False
self.has_recv_header = False
self.raw_trans_sent = False
self.send_buffer = b''
def client_encode(self, buf):
if self.raw_trans_sent:
return buf
self.send_buffer += buf
if not self.has_sent_header:
self.has_sent_header = True
data = b"\x03\x03" + os.urandom(32) + binascii.unhexlify(b"000016c02bc02fc00ac009c013c01400330039002f0035000a0100006fff01000100000a00080006001700180019000b0002010000230000337400000010002900270568322d31360568322d31350568322d313402683208737064792f332e3108687474702f312e31000500050100000000000d001600140401050106010201040305030603020304020202")
data = b"\x01\x00" + struct.pack('>H', len(data)) + data
data = b"\x16\x03\x01" + struct.pack('>H', len(data)) + data
return data
if self.has_recv_header:
ret = self.send_buffer
self.send_buffer = b''
self.raw_trans_sent = True
return ret
return b''
def client_decode(self, buf):
if self.has_recv_header:
return (buf, False)
self.has_recv_header = True
return (b'', True)
def server_encode(self, buf):
if self.has_sent_header:
return buf
self.has_sent_header = True
# TODO
data = b"\x03\x03" + os.urandom(32)
data = b"\x02\x00" + struct.pack('>H', len(data)) + data
data = b"\x16\x03\x01" + struct.pack('>H', len(data)) + data
return data
def decode_error_return(self, buf):
self.has_sent_header = True
if self.method == 'tls_simple':
return (b'E', False, False)
return (buf, True, False)
def server_decode(self, buf):
if self.has_recv_header:
return (buf, True, False)
self.has_recv_header = True
if not match_begin(buf, b'\x16\x03\x01'):
return self.decode_error_return(buf);
buf = buf[3:]
if struct.unpack('>H', buf[:2])[0] != len(buf) - 2:
return self.decode_error_return(buf);
buf = buf[2:]
if not match_begin(buf, b'\x01\x00'): #client hello
return self.decode_error_return(buf);
buf = buf[2:]
if struct.unpack('>H', buf[:2])[0] != len(buf) - 2:
return self.decode_error_return(buf);
buf = buf[2:]
if not match_begin(buf, b'\x03\x03'):
return self.decode_error_return(buf);
buf = buf[2:]
verifyid = buf[:32]
buf = buf[32:]
sessionid = buf[:4]
# (buffer_to_recv, is_need_decrypt, is_need_to_encode_and_send_back)
return (b'', False, True)
class random_head(plain.plain):
def __init__(self, method):
self.method = method

276
shadowsocks/obfsplugin/obfs_tls.py

@ -0,0 +1,276 @@
#!/usr/bin/env python
#
# Copyright 2015-2015 breakwa11
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, \
with_statement
import os
import sys
import hashlib
import logging
import binascii
import struct
import base64
import time
import random
import hmac
import hashlib
from shadowsocks import common
from shadowsocks.obfsplugin import plain
from shadowsocks.common import to_bytes, to_str, ord
from shadowsocks import lru_cache
def create_tls_obfs(method):
return tls_simple(method)
def create_tls_auth_obfs(method):
return tls_auth(method)
obfs_map = {
'tls_simple': (create_tls_obfs,),
'tls_simple_compatible': (create_tls_obfs,),
'tls1.0_session_auth': (create_tls_auth_obfs,),
'tls1.0_session_auth_compatible': (create_tls_auth_obfs,),
}
def match_begin(str1, str2):
if len(str1) >= len(str2):
if str1[:len(str2)] == str2:
return True
return False
class tls_simple(plain.plain):
def __init__(self, method):
self.method = method
self.has_sent_header = False
self.has_recv_header = False
self.raw_trans_sent = False
self.send_buffer = b''
self.tls_version = b'\x03\x01'
def client_encode(self, buf):
if self.raw_trans_sent:
return buf
self.send_buffer += buf
if not self.has_sent_header:
self.has_sent_header = True
data = self.tls_version + os.urandom(32) + binascii.unhexlify(b"000016c02bc02fc00ac009c013c01400330039002f0035000a0100006fff01000100000a00080006001700180019000b0002010000230000337400000010002900270568322d31360568322d31350568322d313402683208737064792f332e3108687474702f312e31000500050100000000000d001600140401050106010201040305030603020304020202")
data = b"\x01\x00" + struct.pack('>H', len(data)) + data
data = b"\x16" + self.tls_version + struct.pack('>H', len(data)) + data
return data
if self.has_recv_header:
ret = self.send_buffer
self.send_buffer = b''
self.raw_trans_sent = True
return ret
return b''
def client_decode(self, buf):
if self.has_recv_header:
return (buf, False)
self.has_recv_header = True
return (b'', True)
def server_encode(self, buf):
if self.has_sent_header:
return buf
self.has_sent_header = True
# TODO
data = self.tls_version + os.urandom(32)
data = b"\x02\x00" + struct.pack('>H', len(data)) + data
data = b"\x16" + self.tls_version + struct.pack('>H', len(data)) + data
return data
def decode_error_return(self, buf):
self.has_sent_header = True
if self.method == 'tls_simple':
return (b'E', False, False)
return (buf, True, False)
def server_decode(self, buf):
if self.has_recv_header:
return (buf, True, False)
self.has_recv_header = True
if not match_begin(buf, b'\x16' + self.tls_version):
return self.decode_error_return(buf)
buf = buf[3:]
if struct.unpack('>H', buf[:2])[0] != len(buf) - 2:
return self.decode_error_return(buf)
buf = buf[2:]
if not match_begin(buf, b'\x01\x00'): #client hello
return self.decode_error_return(buf)
buf = buf[2:]
if struct.unpack('>H', buf[:2])[0] != len(buf) - 2:
return self.decode_error_return(buf)
buf = buf[2:]
if not match_begin(buf, self.tls_version):
return self.decode_error_return(buf)
buf = buf[2:]
verifyid = buf[:32]
buf = buf[32:]
sessionid_len = ord(buf[1])
sessionid = buf[1:sessionid_len + 1]
buf = buf[sessionid_len+1:]
# (buffer_to_recv, is_need_decrypt, is_need_to_encode_and_send_back)
return (b'', False, True)
class obfs_client_data(object):
def __init__(self, cid):
self.client_id = cid
self.auth_code = {}
class obfs_auth_data(object):
def __init__(self):
self.client_data = lru_cache.LRUCache(60 * 5)
self.client_id = os.urandom(32)
self.startup_time = int(time.time() - 60 * 30) & 0xFFFFFFFF
class tls_auth(plain.plain):
def __init__(self, method):
self.method = method
self.has_sent_header = False
self.has_recv_header = False
self.raw_trans_sent = False
self.raw_trans_recv = False
self.send_buffer = b''
self.client_id = b''
self.max_time_dif = 60 * 60 # time dif (second) setting
self.tls_version = b'\x03\x01'
def init_data(self):
return obfs_auth_data()
def pack_auth_data(self, client_id):
utc_time = int(time.time()) & 0xFFFFFFFF
data = struct.pack('>I', utc_time) + os.urandom(18)
data += hmac.new(self.server_info.key + client_id, data, hashlib.sha1).digest()[:10]
return data
def client_encode(self, buf):
if self.raw_trans_sent:
return buf
self.send_buffer += buf
if not self.has_sent_header:
self.has_sent_header = True
data = self.tls_version + self.pack_auth_data(self.server_info.data.client_id) + b"\x20" + self.server_info.data.client_id + binascii.unhexlify(b"0016c02bc02fc00ac009c013c01400330039002f0035000a0100006fff01000100000a00080006001700180019000b0002010000230000337400000010002900270568322d31360568322d31350568322d313402683208737064792f332e3108687474702f312e31000500050100000000000d001600140401050106010201040305030603020304020202")
data = b"\x01\x00" + struct.pack('>H', len(data)) + data
data = b"\x16" + self.tls_version + struct.pack('>H', len(data)) + data
return data
if self.has_recv_header:
data = b"\x14" + self.tls_version + "\x00\x01\x01" #ChangeCipherSpec
data += b"\x16" + self.tls_version + "\x00\x01\x20" + os.urandom(22) #Finished
data += hmac.new(self.server_info.key + self.server_info.data.client_id, data, hashlib.sha1).digest()[:10]
ret = data + self.send_buffer
self.send_buffer = b''
self.raw_trans_sent = True
return ret
return b''
def client_decode(self, buf):
if self.has_recv_header:
return (buf, False)
self.has_recv_header = True
return (b'', True)
def server_encode(self, buf):
if self.has_sent_header:
return buf
self.has_sent_header = True
data = self.tls_version + self.pack_auth_data(self.client_id) + b"\x20" + self.client_id + binascii.unhexlify(b"0016c02bc02fc00ac009c013c01400330039002f0035000a0100006fff01000100000a00080006001700180019000b0002010000230000337400000010002900270568322d31360568322d31350568322d313402683208737064792f332e3108687474702f312e31000500050100000000000d001600140401050106010201040305030603020304020202")
data = b"\x02\x00" + struct.pack('>H', len(data)) + data #server hello
data = b"\x16" + self.tls_version + struct.pack('>H', len(data)) + data
data += b"\x14" + self.tls_version + "\x00\x01\x01" #ChangeCipherSpec
data += b"\x16" + self.tls_version + "\x00\x01\x20" + os.urandom(22) #Finished
data += hmac.new(self.server_info.key + self.client_id, data, hashlib.sha1).digest()[:10]
return data
def decode_error_return(self, buf):
self.raw_trans_recv = True
if self.method == 'tls_simple':
return (b'E', False, False)
return (buf, True, False)
def server_decode(self, buf):
if self.raw_trans_recv:
return (buf, True, False)
if self.has_recv_header:
verify = buf
verify_len = 44 - 10
if len(buf) < 44:
logging.error('server_decode data error')
return decode_error_return(b'')
if not match_begin(buf, b"\x14" + self.tls_version + "\x00\x01\x01"): #ChangeCipherSpec
logging.error('server_decode data error')
return decode_error_return(b'')
buf = buf[6:]
if not match_begin(buf, b"\x16" + self.tls_version + "\x00\x01\x20"): #Finished
logging.error('server_decode data error')
return decode_error_return(b'')
if hmac.new(self.server_info.key + self.client_id, verify[:verify_len], hashlib.sha1).digest()[:10] != verify[verify_len:verify_len+10]:
logging.error('server_decode data error')
return decode_error_return(b'')
if len(buf) < 38:
logging.error('server_decode data error')
return decode_error_return(b'')
buf = buf[38:]
self.raw_trans_recv = True
return (buf, True, False)
self.has_recv_header = True
ogn_buf = buf
if not match_begin(buf, b'\x16' + self.tls_version):
return self.decode_error_return(ogn_buf)
buf = buf[3:]
if struct.unpack('>H', buf[:2])[0] != len(buf) - 2:
return self.decode_error_return(ogn_buf)
buf = buf[2:]
if not match_begin(buf, b'\x01\x00'): #client hello
return self.decode_error_return(ogn_buf)
buf = buf[2:]
if struct.unpack('>H', buf[:2])[0] != len(buf) - 2:
return self.decode_error_return(ogn_buf)
buf = buf[2:]
if not match_begin(buf, self.tls_version):
return self.decode_error_return(ogn_buf)
buf = buf[2:]
verifyid = buf[:32]
buf = buf[32:]
sessionid_len = ord(buf[0])
if sessionid_len < 32:
logging.error("tls_auth wrong sessionid_len")
return self.decode_error_return(ogn_buf)
sessionid = buf[1:sessionid_len + 1]
buf = buf[sessionid_len+1:]
self.client_id = sessionid
sha1 = hmac.new(self.server_info.key + sessionid, verifyid[:22], hashlib.sha1).digest()[:10]
utc_time = struct.unpack('>I', verifyid[:4])[0]
time_dif = common.int32((int(time.time()) & 0xffffffff) - utc_time)
if time_dif < -self.max_time_dif or time_dif > self.max_time_dif \
or common.int32(utc_time - self.server_info.data.startup_time) < -self.max_time_dif / 2:
logging.debug("tls_auth wrong time")
return self.decode_error_return(ogn_buf)
if sha1 != verifyid[22:]:
logging.debug("tls_auth wrong sha1")
return self.decode_error_return(ogn_buf)
if verifyid[4:22] in self.server_info.data.client_data:
logging.error("replay attack detect, id = %s" % (binascii.hexlify(verifyid)))
return self.decode_error_return(ogn_buf)
# (buffer_to_recv, is_need_decrypt, is_need_to_encode_and_send_back)
return (b'', False, True)

6
shadowsocks/obfsplugin/verify.py

@ -127,7 +127,7 @@ class verify_simple(verify_base):
if self.decrypt_packet_num == 0:
return None
else:
raise Exception('server_post_decrype data error')
raise Exception('client_post_decrypt data error')
if length > len(self.recv_buf):
break
@ -137,7 +137,7 @@ class verify_simple(verify_base):
if self.decrypt_packet_num == 0:
return None
else:
raise Exception('server_post_decrype data uncorrect CRC32')
raise Exception('client_post_decrypt data uncorrect CRC32')
pos = common.ord(self.recv_buf[2]) + 2
out_buf += self.recv_buf[pos:length - 4]
@ -224,7 +224,7 @@ class verify_deflate(verify_base):
if self.decrypt_packet_num == 0:
return None
else:
raise Exception('server_post_decrype data error')
raise Exception('client_post_decrypt data error')
if length > len(self.recv_buf):
break

12
shadowsocks/tcprelay.py

@ -598,7 +598,11 @@ class TCPRelayHandler(object):
if not is_local:
if self._encryptor is not None:
if self._encrypt_correct:
try:
obfs_decode = self._obfs.server_decode(data)
except Exception as e:
shell.print_exception(e)
self.destroy()
if obfs_decode[2]:
self._write_to_sock(b'', self._local_sock)
if obfs_decode[1]:
@ -665,7 +669,11 @@ class TCPRelayHandler(object):
return
if self._encryptor is not None:
if self._is_local:
try:
obfs_decode = self._obfs.client_decode(data)
except Exception as e:
shell.print_exception(e)
self.destroy()
if obfs_decode[1]:
send_back = self._obfs.client_encode(b'')
self._write_to_sock(send_back, self._remote_sock)
@ -673,7 +681,11 @@ class TCPRelayHandler(object):
iv_len = len(self._protocol.obfs.server_info.iv)
self._protocol.obfs.server_info.recv_iv = obfs_decode[0][:iv_len]
data = self._encryptor.decrypt(obfs_decode[0])
try:
data = self._protocol.client_post_decrypt(data)
except Exception as e:
shell.print_exception(e)
self.destroy()
else:
if self._encrypt_correct:
data = self._protocol.server_pre_encrypt(data)

Loading…
Cancel
Save