Browse Source

Finish 3.2.2

many fix
akkariiin/master 3.2.2
Akkariiin 7 years ago
parent
commit
190bf5e79e
  1. 2
      configloader.py
  2. 30
      db_transfer.py
  3. 2
      importloader.py
  4. 10
      server_pool.py
  5. 36
      shadowsocks/common.py
  6. 4
      shadowsocks/crypto/util.py
  7. 19
      shadowsocks/encrypt.py
  8. 64
      shadowsocks/obfsplugin/auth_chain.py
  9. 6
      shadowsocks/server.py
  10. 7
      shadowsocks/tcprelay.py
  11. 2
      shadowsocks/udprelay.py
  12. 2
      shadowsocks/version.py
  13. 3
      switchrule.py

2
configloader.py

@ -1,5 +1,5 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# -*- coding: utf-8 -*-
import importloader
g_config = None

30
db_transfer.py

@ -1,5 +1,5 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# -*- coding: utf-8 -*-
import logging
import time
@ -9,6 +9,7 @@ import traceback
from shadowsocks import common, shell, lru_cache, obfs
from configloader import load_config, get_config
import importloader
import copy
switchrule = None
db_instance = None
@ -80,8 +81,10 @@ class TransferBase(object):
def del_server_out_of_bound_safe(self, last_rows, rows):
#停止超流量的服务
#启动没超流量的服务
keymap = {}
try:
switchrule = importloader.load('switchrule')
keymap = switchrule.getRowMap()
except Exception as e:
logging.error('load switchrule.py fail')
cur_servers = {}
@ -106,7 +109,10 @@ class TransferBase(object):
read_config_keys = ['method', 'obfs', 'obfs_param', 'protocol', 'protocol_param', 'forbidden_ip', 'forbidden_port', 'speed_limit_per_con', 'speed_limit_per_user']
for name in read_config_keys:
if name in row and row[name]:
cfg[name] = row[name]
if name in keymap:
cfg[keymap[name]] = row[name]
else:
cfg[name] = row[name]
merge_config_keys = ['password'] + read_config_keys
for name in cfg.keys():
@ -392,11 +398,17 @@ class DbTransfer(TransferBase):
return rows
def pull_db_users(self, conn):
keys = copy.copy(self.key_list)
try:
switchrule = importloader.load('switchrule')
keys = switchrule.getKeys(self.key_list)
keymap = switchrule.getRowMap()
for key in keymap:
if keymap[key] in keys:
keys.remove(keymap[key])
keys.append(key)
keys = switchrule.getKeys(keys)
except Exception as e:
keys = self.key_list
logging.error('load switchrule.py fail')
cur = conn.cursor()
cur.execute("SELECT " + ','.join(keys) + " FROM user")
@ -520,11 +532,17 @@ class Dbv3Transfer(DbTransfer):
return update_transfer
def pull_db_users(self, conn):
keys = copy.copy(self.key_list)
try:
switchrule = importloader.load('switchrule')
keys = switchrule.getKeys(self.key_list)
keymap = switchrule.getRowMap()
for key in keymap:
if keymap[key] in keys:
keys.remove(keymap[key])
keys.append(key)
keys = switchrule.getKeys(keys)
except Exception as e:
keys = self.key_list
logging.error('load switchrule.py fail')
cur = conn.cursor()

2
importloader.py

@ -1,5 +1,5 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# -*- coding: utf-8 -*-
def load(name):
try:

10
server_pool.py

@ -117,14 +117,14 @@ class ServerPool(object):
else:
a_config = self.config.copy()
a_config.update(user_config)
if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == "[" and a_config['server_ipv6'][-1] == "]":
if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == b"[" and a_config['server_ipv6'][-1] == b"]":
a_config['server_ipv6'] = a_config['server_ipv6'][1:-1]
a_config['server'] = a_config['server_ipv6']
a_config['server'] = common.to_str(a_config['server_ipv6'])
a_config['server_port'] = port
a_config['max_connect'] = 128
a_config['method'] = common.to_str(a_config['method'])
try:
logging.info("starting server at [%s]:%d" % (common.to_str(a_config['server']), port))
logging.info("starting server at [%s]:%d" % (a_config['server'], port))
tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False, stat_counter=self.stat_counter)
tcp_server.add_to_loop(self.loop)
@ -134,7 +134,7 @@ class ServerPool(object):
udp_server.add_to_loop(self.loop)
self.udp_ipv6_servers_pool.update({port: udp_server})
if common.to_str(a_config['server_ipv6']) == "::":
if a_config['server_ipv6'] == "::":
ipv6_ok = True
except Exception as e:
logging.warn("IPV6 %s " % (e,))
@ -150,7 +150,7 @@ class ServerPool(object):
a_config['max_connect'] = 128
a_config['method'] = common.to_str(a_config['method'])
try:
logging.info("starting server at %s:%d" % (common.to_str(a_config['server']), port))
logging.info("starting server at %s:%d" % (a_config['server'], port))
tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False)
tcp_server.add_to_loop(self.loop)

36
shadowsocks/common.py

@ -121,7 +121,19 @@ def is_ip(address):
return False
def sync_str_bytes(obj, target_example):
"""sync (obj)'s type to (target_example)'s type"""
if type(obj) != type(target_example):
if type(target_example) == str:
obj = to_str(obj)
if type(target_example) == bytes:
obj = to_bytes(obj)
return obj
def match_regex(regex, text):
# avoid 'cannot use a string pattern on a bytes-like object'
regex = sync_str_bytes(regex, text)
regex = re.compile(regex)
for item in regex.findall(text):
return True
@ -381,12 +393,12 @@ def test_inet_conv():
def test_parse_header():
assert parse_header(b'\x03\x0ewww.google.com\x00\x50') == \
(0, b'www.google.com', 80, 18)
(0, ADDRTYPE_HOST, b'www.google.com', 80, 18)
assert parse_header(b'\x01\x08\x08\x08\x08\x00\x35') == \
(0, b'8.8.8.8', 53, 7)
(0, ADDRTYPE_IPV4, b'8.8.8.8', 53, 7)
assert parse_header((b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00'
b'\x00\x10\x11\x00\x50')) == \
(0, b'2404:6800:4005:805::1011', 80, 19)
(0, ADDRTYPE_IPV6, b'2404:6800:4005:805::1011', 80, 19)
def test_pack_header():
@ -411,7 +423,25 @@ def test_ip_network():
assert 'www.google.com' not in ip_network
def test_sync_str_bytes():
assert sync_str_bytes(b'a\.b', b'a\.b') == b'a\.b'
assert sync_str_bytes('a\.b', b'a\.b') == b'a\.b'
assert sync_str_bytes(b'a\.b', 'a\.b') == 'a\.b'
assert sync_str_bytes('a\.b', 'a\.b') == 'a\.b'
pass
def test_match_regex():
assert match_regex(br'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b')
assert match_regex(r'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b')
assert match_regex(br'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b')
assert match_regex(r'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b')
assert match_regex(r'\bgoogle\.com\b', b' google.com ')
pass
if __name__ == '__main__':
test_sync_str_bytes()
test_match_regex()
test_inet_conv()
test_parse_header()
test_pack_header()

4
shadowsocks/crypto/util.py

@ -68,7 +68,9 @@ def find_library(possible_lib_names, search_symbol, library_name):
if path:
paths.append(path)
if not paths:
# always find lib on extend path that to avoid ```CDLL()``` failed on some strange linux environment
# in that case ```ctypes.util.find_library()``` have different find path from ```CDLL()```
if True:
# We may get here when find_library fails because, for example,
# the user does not have sufficient privileges to access those
# tools underlying find_library on linux.

19
shadowsocks/encrypt.py

@ -46,7 +46,7 @@ def try_cipher(key, method=None):
Encryptor(key, method)
def EVP_BytesToKey(password, key_len, iv_len):
def EVP_BytesToKey(password, key_len, iv_len, cache):
# equivalent to OpenSSL's EVP_BytesToKey() with count 1
# so that we make the same key and iv as nodejs version
cached_key = '%s-%d-%d' % (password, key_len, iv_len)
@ -66,13 +66,14 @@ def EVP_BytesToKey(password, key_len, iv_len):
ms = b''.join(m)
key = ms[:key_len]
iv = ms[key_len:key_len + iv_len]
cached_keys[cached_key] = (key, iv)
cached_keys.sweep()
if cache:
cached_keys[cached_key] = (key, iv)
cached_keys.sweep()
return key, iv
class Encryptor(object):
def __init__(self, key, method, iv = None):
def __init__(self, key, method, iv = None, cache = False):
self.key = key
self.method = method
self.iv = None
@ -81,6 +82,7 @@ class Encryptor(object):
self.iv_buf = b''
self.cipher_key = b''
self.decipher = None
self.cache = cache
method = method.lower()
self._method_info = self.get_method_info(method)
if self._method_info:
@ -105,7 +107,7 @@ class Encryptor(object):
password = common.to_bytes(password)
m = self._method_info
if m[0] > 0:
key, iv_ = EVP_BytesToKey(password, m[0], m[1])
key, iv_ = EVP_BytesToKey(password, m[0], m[1], self.cache)
else:
# key_length == 0 indicates we should use the key directly
key, iv = password, b''
@ -119,6 +121,9 @@ class Encryptor(object):
def encrypt(self, buf):
if len(buf) == 0:
if not self.iv_sent:
self.iv_sent = True
return self.cipher_iv
return buf
if self.iv_sent:
return self.cipher.update(buf)
@ -155,7 +160,7 @@ def encrypt_all(password, method, op, data):
method = method.lower()
(key_len, iv_len, m) = method_supported[method]
if key_len > 0:
key, _ = EVP_BytesToKey(password, key_len, iv_len)
key, _ = EVP_BytesToKey(password, key_len, iv_len, True)
else:
key = password
if op:
@ -172,7 +177,7 @@ def encrypt_key(password, method):
method = method.lower()
(key_len, iv_len, m) = method_supported[method]
if key_len > 0:
key, _ = EVP_BytesToKey(password, key_len, iv_len)
key, _ = EVP_BytesToKey(password, key_len, iv_len, True)
else:
key = password
return key

64
shadowsocks/obfsplugin/auth_chain.py

@ -18,8 +18,6 @@
from __future__ import absolute_import, division, print_function, \
with_statement
import os
import sys
import hashlib
import logging
import binascii
@ -29,16 +27,16 @@ import datetime
import random
import math
import struct
import zlib
import hmac
import hashlib
import bisect
import shadowsocks
from shadowsocks import common, lru_cache, encrypt
from shadowsocks.obfsplugin import plain
from shadowsocks.common import to_bytes, to_str, ord, chr
from shadowsocks.crypto import openssl
rand_bytes = openssl.rand_bytes
def create_auth_chain_a(method):
return auth_chain_a(method)
@ -87,25 +85,25 @@ class xorshift128plus(object):
y = self.v1
self.v0 = y
x ^= ((x & xorshift128plus.mov_mask) << 23)
x ^= (y ^ (x >> 17) ^ (y >> 26)) & xorshift128plus.max_int
x ^= (y ^ (x >> 17) ^ (y >> 26))
self.v1 = x
return (x + y) & xorshift128plus.max_int
def init_from_bin(self, bin):
bin += b'\0' * 16
if len(bin) < 16:
bin += b'\0' * 16
self.v0 = struct.unpack('<Q', bin[:8])[0]
self.v1 = struct.unpack('<Q', bin[8:16])[0]
def init_from_bin_len(self, bin, length):
bin += b'\0' * 16
bin = struct.pack('<H', length) + bin[2:]
self.v0 = struct.unpack('<Q', bin[:8])[0]
if len(bin) < 16:
bin += b'\0' * 16
self.v0 = struct.unpack('<Q', struct.pack('<H', length) + bin[2:8])[0]
self.v1 = struct.unpack('<Q', bin[8:16])[0]
for i in range(4):
self.next()
def match_begin(str1, str2):
if len(str1) >= len(str2):
if str1[:len(str2)] == str2:
@ -335,7 +333,7 @@ class auth_chain_a(auth_base):
def rnd_data(self, buf_size, buf, last_hash, random):
rand_len = self.rnd_data_len(buf_size, last_hash, random)
rnd_data_buf = os.urandom(rand_len)
rnd_data_buf = rand_bytes(rand_len)
if buf_size == 0:
return rnd_data_buf
@ -349,7 +347,6 @@ class auth_chain_a(auth_base):
def pack_client_data(self, buf):
buf = self.encryptor.encrypt(buf)
data = self.rnd_data(len(buf), buf, self.last_client_hash, self.random_client)
data_len = len(data) + 8
mac_key = self.user_key + struct.pack('<I', self.pack_id)
length = len(buf) ^ struct.unpack('<H', self.last_client_hash[14:])[0]
data = struct.pack('<H', length) + data
@ -361,7 +358,6 @@ class auth_chain_a(auth_base):
def pack_server_data(self, buf):
buf = self.encryptor.encrypt(buf)
data = self.rnd_data(len(buf), buf, self.last_server_hash, self.random_server)
data_len = len(data) + 8
mac_key = self.user_key + struct.pack('<I', self.pack_id)
length = len(buf) ^ struct.unpack('<H', self.last_server_hash[14:])[0]
data = struct.pack('<H', length) + data
@ -372,11 +368,10 @@ class auth_chain_a(auth_base):
def pack_auth_data(self, auth_data, buf):
data = auth_data
data_len = 12 + 4 + 16 + 4
data = data + (struct.pack('<H', self.server_info.overhead) + struct.pack('<H', 0))
mac_key = self.server_info.iv + self.server_info.key
check_head = os.urandom(4)
check_head = rand_bytes(4)
self.last_client_hash = hmac.new(mac_key, check_head, self.hashfunc).digest()
check_head += self.last_client_hash[:8]
@ -386,9 +381,9 @@ class auth_chain_a(auth_base):
self.user_key = items[1]
uid = struct.pack('<I', int(items[0]))
except:
uid = os.urandom(4)
uid = rand_bytes(4)
else:
uid = os.urandom(4)
uid = rand_bytes(4)
if self.user_key is None:
self.user_key = self.server_info.key
@ -409,14 +404,17 @@ class auth_chain_a(auth_base):
if self.server_info.data.connection_id > 0xFF000000:
self.server_info.data.local_client_id = b''
if not self.server_info.data.local_client_id:
self.server_info.data.local_client_id = os.urandom(4)
self.server_info.data.local_client_id = rand_bytes(4)
logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),))
self.server_info.data.connection_id = struct.unpack('<I', os.urandom(4))[0] & 0xFFFFFF
self.server_info.data.connection_id = struct.unpack('<I', rand_bytes(4))[0] & 0xFFFFFF
self.server_info.data.connection_id += 1
return b''.join([struct.pack('<I', utc_time),
self.server_info.data.local_client_id,
struct.pack('<I', self.server_info.data.connection_id)])
def on_recv_auth_data(self, utc_time):
pass
def client_pre_encrypt(self, buf):
ret = b''
ogn_data_len = len(buf)
@ -551,6 +549,7 @@ class auth_chain_a(auth_base):
logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf)))
return self.not_match_return(self.recv_buf)
self.on_recv_auth_data(utc_time)
self.encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4')
self.recv_buf = self.recv_buf[36:]
@ -565,7 +564,7 @@ class auth_chain_a(auth_base):
if length >= 4096:
self.raw_trans = True
self.recv_buf = b''
if self.recv_id == 0:
if self.recv_id == 1:
logging.info(self.no_compatible_method + ': over size')
return (b'E' * 2048, False)
else:
@ -581,7 +580,7 @@ class auth_chain_a(auth_base):
))
self.raw_trans = True
self.recv_buf = b''
if self.recv_id == 0:
if self.recv_id == 1:
return (b'E' * 2048, False)
else:
raise Exception('server_post_decrype data uncorrect checksum')
@ -610,9 +609,9 @@ class auth_chain_a(auth_base):
except:
pass
if self.user_key is None:
self.user_id = os.urandom(4)
self.user_id = rand_bytes(4)
self.user_key = self.server_info.key
authdata = os.urandom(3)
authdata = rand_bytes(3)
mac_key = self.server_info.key
md5data = hmac.new(mac_key, authdata, self.hashfunc).digest()
uid = struct.unpack('<I', self.user_id)[0] ^ struct.unpack('<I', md5data[:4])[0]
@ -621,7 +620,7 @@ class auth_chain_a(auth_base):
encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4')
out_buf = encryptor.encrypt(buf)
buf = out_buf + os.urandom(rand_len) + authdata + uid
buf = out_buf + rand_bytes(rand_len) + authdata + uid
return buf + hmac.new(self.user_key, buf, self.hashfunc).digest()[:1]
def client_udp_post_decrypt(self, buf):
@ -645,13 +644,13 @@ class auth_chain_a(auth_base):
user_key = self.server_info.key
else:
user_key = self.server_info.recv_iv
authdata = os.urandom(7)
authdata = rand_bytes(7)
mac_key = self.server_info.key
md5data = hmac.new(mac_key, authdata, self.hashfunc).digest()
rand_len = self.udp_rnd_data_len(md5data, self.random_server)
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4')
out_buf = encryptor.encrypt(buf)
buf = out_buf + os.urandom(rand_len) + authdata
buf = out_buf + rand_bytes(rand_len) + authdata
return buf + hmac.new(user_key, buf, self.hashfunc).digest()[:1]
def server_udp_post_decrypt(self, buf):
@ -855,6 +854,7 @@ class auth_chain_e(auth_chain_d):
self.no_compatible_method = 'auth_chain_e'
def rnd_data_len(self, buf_size, last_hash, random):
random.init_from_bin_len(last_hash, buf_size)
other_data_size = buf_size + self.server_info.overhead
# if other_data_size > the bigest item in data_size_list0, not padding any data
if other_data_size >= self.data_size_list0[-1]:
@ -879,15 +879,17 @@ class auth_chain_f(auth_chain_e):
max_client = int(server_info.protocol_param.split('#')[0])
except:
max_client = 64
self.server_info.data.set_max_client(max_client)
try:
self.key_change_interval = int(server_info.protocol_param.split('#')[1]) # config are in second
except:
self.key_change_interval = 60 * 60 * 24 # a day by second
self.key_change_datetime_key = int(int(time.time()) / self.key_change_interval)
def on_recv_auth_data(self, utc_time):
self.key_change_datetime_key = int(utc_time / self.key_change_interval)
self.key_change_datetime_key_bytes = [] # big bit first list
for i in range(7, -1, -1): # big-ending compare to c
self.key_change_datetime_key_bytes.append((self.key_change_datetime_key >> (8 * i)) & 0xFF)
self.server_info.data.set_max_client(max_client)
self.init_data_size(self.server_info.key)
def init_data_size(self, key):
@ -896,9 +898,13 @@ class auth_chain_f(auth_chain_e):
random = xorshift128plus()
# key xor with key_change_datetime_key
new_key = bytearray(key)
new_key_str = ''
for i in range(0, 8):
new_key[i] ^= self.key_change_datetime_key_bytes[i]
random.init_from_bin(new_key)
new_key_str += chr(new_key[i])
for i in range(8, len(new_key)):
new_key_str += chr(new_key[i])
random.init_from_bin(to_bytes(new_key_str))
# 补全数组长为12~24-1
list_len = random.next() % (8 + 16) + (4 + 8)
for i in range(0, list_len):

6
shadowsocks/server.py

@ -108,8 +108,8 @@ def main():
(protocol, password, method, obfs, obfs_param))
if 'server_ipv6' in a_config:
try:
if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == "[" and a_config['server_ipv6'][
-1] == "]":
if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == b"[" and a_config['server_ipv6'][
-1] == b"]":
a_config['server_ipv6'] = a_config['server_ipv6'][1:-1]
a_config['server_port'] = int(port)
a_config['password'] = password
@ -120,7 +120,7 @@ def main():
a_config['obfs_param'] = obfs_param
a_config['out_bind'] = bind
a_config['out_bindv6'] = bindv6
a_config['server'] = a_config['server_ipv6']
a_config['server'] = common.to_str(a_config['server_ipv6'])
logging.info("starting server at [%s]:%d" %
(a_config['server'], int(port)))
tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False, stat_counter=stat_counter_dict))

7
shadowsocks/tcprelay.py

@ -266,7 +266,7 @@ class TCPRelayHandler(object):
def _create_encryptor(self, config):
try:
self._encryptor = encrypt.Encryptor(config['password'],
config['method'])
config['method'], None, True)
return True
except Exception:
self._stage = STAGE_DESTROYED
@ -1163,8 +1163,9 @@ class TCPRelayHandler(object):
self._protocol.dispose()
self._protocol = None
self._encryptor.dispose()
self._encryptor = None
if self._encryptor:
self._encryptor.dispose()
self._encryptor = None
self._dns_resolver.remove_callback(self._handle_dns_resolved)
self._server.remove_handler(self)
if self._add_ref > 0:

2
shadowsocks/udprelay.py

@ -213,6 +213,8 @@ class UDPRelay(object):
server_socket = socket.socket(af, socktype, proto)
server_socket.bind((self._listen_addr, self._listen_port))
server_socket.setblocking(False)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 1024)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 1024)
self._server_socket = server_socket
self._stat_callback = stat_callback

2
shadowsocks/version.py

@ -16,5 +16,5 @@
# under the License.
def version():
return 'SSRR 3.2.1 2017-10-15'
return 'SSRR 3.2.2 2018-05-22'

3
switchrule.py

@ -1,3 +1,6 @@
def getRowMap():
return {} # if your db row "encrypt" means "method", write {"encrypt": "method"}
def getKeys(key_list):
return key_list
#return key_list + ['plan'] # append the column name 'plan'

Loading…
Cancel
Save