Browse Source

Finish 3.2.2

many fix
akkariiin/master 3.2.2
Akkariiin 7 years ago
parent
commit
190bf5e79e
  1. 2
      configloader.py
  2. 30
      db_transfer.py
  3. 2
      importloader.py
  4. 10
      server_pool.py
  5. 36
      shadowsocks/common.py
  6. 4
      shadowsocks/crypto/util.py
  7. 19
      shadowsocks/encrypt.py
  8. 64
      shadowsocks/obfsplugin/auth_chain.py
  9. 6
      shadowsocks/server.py
  10. 7
      shadowsocks/tcprelay.py
  11. 2
      shadowsocks/udprelay.py
  12. 2
      shadowsocks/version.py
  13. 3
      switchrule.py

2
configloader.py

@ -1,5 +1,5 @@
#!/usr/bin/python #!/usr/bin/python
# -*- coding: UTF-8 -*- # -*- coding: utf-8 -*-
import importloader import importloader
g_config = None g_config = None

30
db_transfer.py

@ -1,5 +1,5 @@
#!/usr/bin/python #!/usr/bin/python
# -*- coding: UTF-8 -*- # -*- coding: utf-8 -*-
import logging import logging
import time import time
@ -9,6 +9,7 @@ import traceback
from shadowsocks import common, shell, lru_cache, obfs from shadowsocks import common, shell, lru_cache, obfs
from configloader import load_config, get_config from configloader import load_config, get_config
import importloader import importloader
import copy
switchrule = None switchrule = None
db_instance = None db_instance = None
@ -80,8 +81,10 @@ class TransferBase(object):
def del_server_out_of_bound_safe(self, last_rows, rows): def del_server_out_of_bound_safe(self, last_rows, rows):
#停止超流量的服务 #停止超流量的服务
#启动没超流量的服务 #启动没超流量的服务
keymap = {}
try: try:
switchrule = importloader.load('switchrule') switchrule = importloader.load('switchrule')
keymap = switchrule.getRowMap()
except Exception as e: except Exception as e:
logging.error('load switchrule.py fail') logging.error('load switchrule.py fail')
cur_servers = {} cur_servers = {}
@ -106,7 +109,10 @@ class TransferBase(object):
read_config_keys = ['method', 'obfs', 'obfs_param', 'protocol', 'protocol_param', 'forbidden_ip', 'forbidden_port', 'speed_limit_per_con', 'speed_limit_per_user'] read_config_keys = ['method', 'obfs', 'obfs_param', 'protocol', 'protocol_param', 'forbidden_ip', 'forbidden_port', 'speed_limit_per_con', 'speed_limit_per_user']
for name in read_config_keys: for name in read_config_keys:
if name in row and row[name]: if name in row and row[name]:
cfg[name] = row[name] if name in keymap:
cfg[keymap[name]] = row[name]
else:
cfg[name] = row[name]
merge_config_keys = ['password'] + read_config_keys merge_config_keys = ['password'] + read_config_keys
for name in cfg.keys(): for name in cfg.keys():
@ -392,11 +398,17 @@ class DbTransfer(TransferBase):
return rows return rows
def pull_db_users(self, conn): def pull_db_users(self, conn):
keys = copy.copy(self.key_list)
try: try:
switchrule = importloader.load('switchrule') switchrule = importloader.load('switchrule')
keys = switchrule.getKeys(self.key_list) keymap = switchrule.getRowMap()
for key in keymap:
if keymap[key] in keys:
keys.remove(keymap[key])
keys.append(key)
keys = switchrule.getKeys(keys)
except Exception as e: except Exception as e:
keys = self.key_list logging.error('load switchrule.py fail')
cur = conn.cursor() cur = conn.cursor()
cur.execute("SELECT " + ','.join(keys) + " FROM user") cur.execute("SELECT " + ','.join(keys) + " FROM user")
@ -520,11 +532,17 @@ class Dbv3Transfer(DbTransfer):
return update_transfer return update_transfer
def pull_db_users(self, conn): def pull_db_users(self, conn):
keys = copy.copy(self.key_list)
try: try:
switchrule = importloader.load('switchrule') switchrule = importloader.load('switchrule')
keys = switchrule.getKeys(self.key_list) keymap = switchrule.getRowMap()
for key in keymap:
if keymap[key] in keys:
keys.remove(keymap[key])
keys.append(key)
keys = switchrule.getKeys(keys)
except Exception as e: except Exception as e:
keys = self.key_list logging.error('load switchrule.py fail')
cur = conn.cursor() cur = conn.cursor()

2
importloader.py

@ -1,5 +1,5 @@
#!/usr/bin/python #!/usr/bin/python
# -*- coding: UTF-8 -*- # -*- coding: utf-8 -*-
def load(name): def load(name):
try: try:

10
server_pool.py

@ -117,14 +117,14 @@ class ServerPool(object):
else: else:
a_config = self.config.copy() a_config = self.config.copy()
a_config.update(user_config) a_config.update(user_config)
if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == "[" and a_config['server_ipv6'][-1] == "]": if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == b"[" and a_config['server_ipv6'][-1] == b"]":
a_config['server_ipv6'] = a_config['server_ipv6'][1:-1] a_config['server_ipv6'] = a_config['server_ipv6'][1:-1]
a_config['server'] = a_config['server_ipv6'] a_config['server'] = common.to_str(a_config['server_ipv6'])
a_config['server_port'] = port a_config['server_port'] = port
a_config['max_connect'] = 128 a_config['max_connect'] = 128
a_config['method'] = common.to_str(a_config['method']) a_config['method'] = common.to_str(a_config['method'])
try: try:
logging.info("starting server at [%s]:%d" % (common.to_str(a_config['server']), port)) logging.info("starting server at [%s]:%d" % (a_config['server'], port))
tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False, stat_counter=self.stat_counter) tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False, stat_counter=self.stat_counter)
tcp_server.add_to_loop(self.loop) tcp_server.add_to_loop(self.loop)
@ -134,7 +134,7 @@ class ServerPool(object):
udp_server.add_to_loop(self.loop) udp_server.add_to_loop(self.loop)
self.udp_ipv6_servers_pool.update({port: udp_server}) self.udp_ipv6_servers_pool.update({port: udp_server})
if common.to_str(a_config['server_ipv6']) == "::": if a_config['server_ipv6'] == "::":
ipv6_ok = True ipv6_ok = True
except Exception as e: except Exception as e:
logging.warn("IPV6 %s " % (e,)) logging.warn("IPV6 %s " % (e,))
@ -150,7 +150,7 @@ class ServerPool(object):
a_config['max_connect'] = 128 a_config['max_connect'] = 128
a_config['method'] = common.to_str(a_config['method']) a_config['method'] = common.to_str(a_config['method'])
try: try:
logging.info("starting server at %s:%d" % (common.to_str(a_config['server']), port)) logging.info("starting server at %s:%d" % (a_config['server'], port))
tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False) tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False)
tcp_server.add_to_loop(self.loop) tcp_server.add_to_loop(self.loop)

36
shadowsocks/common.py

@ -121,7 +121,19 @@ def is_ip(address):
return False return False
def sync_str_bytes(obj, target_example):
"""sync (obj)'s type to (target_example)'s type"""
if type(obj) != type(target_example):
if type(target_example) == str:
obj = to_str(obj)
if type(target_example) == bytes:
obj = to_bytes(obj)
return obj
def match_regex(regex, text): def match_regex(regex, text):
# avoid 'cannot use a string pattern on a bytes-like object'
regex = sync_str_bytes(regex, text)
regex = re.compile(regex) regex = re.compile(regex)
for item in regex.findall(text): for item in regex.findall(text):
return True return True
@ -381,12 +393,12 @@ def test_inet_conv():
def test_parse_header(): def test_parse_header():
assert parse_header(b'\x03\x0ewww.google.com\x00\x50') == \ assert parse_header(b'\x03\x0ewww.google.com\x00\x50') == \
(0, b'www.google.com', 80, 18) (0, ADDRTYPE_HOST, b'www.google.com', 80, 18)
assert parse_header(b'\x01\x08\x08\x08\x08\x00\x35') == \ assert parse_header(b'\x01\x08\x08\x08\x08\x00\x35') == \
(0, b'8.8.8.8', 53, 7) (0, ADDRTYPE_IPV4, b'8.8.8.8', 53, 7)
assert parse_header((b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00' assert parse_header((b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00'
b'\x00\x10\x11\x00\x50')) == \ b'\x00\x10\x11\x00\x50')) == \
(0, b'2404:6800:4005:805::1011', 80, 19) (0, ADDRTYPE_IPV6, b'2404:6800:4005:805::1011', 80, 19)
def test_pack_header(): def test_pack_header():
@ -411,7 +423,25 @@ def test_ip_network():
assert 'www.google.com' not in ip_network assert 'www.google.com' not in ip_network
def test_sync_str_bytes():
assert sync_str_bytes(b'a\.b', b'a\.b') == b'a\.b'
assert sync_str_bytes('a\.b', b'a\.b') == b'a\.b'
assert sync_str_bytes(b'a\.b', 'a\.b') == 'a\.b'
assert sync_str_bytes('a\.b', 'a\.b') == 'a\.b'
pass
def test_match_regex():
assert match_regex(br'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b')
assert match_regex(r'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b')
assert match_regex(br'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b')
assert match_regex(r'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b')
assert match_regex(r'\bgoogle\.com\b', b' google.com ')
pass
if __name__ == '__main__': if __name__ == '__main__':
test_sync_str_bytes()
test_match_regex()
test_inet_conv() test_inet_conv()
test_parse_header() test_parse_header()
test_pack_header() test_pack_header()

4
shadowsocks/crypto/util.py

@ -68,7 +68,9 @@ def find_library(possible_lib_names, search_symbol, library_name):
if path: if path:
paths.append(path) paths.append(path)
if not paths: # always find lib on extend path that to avoid ```CDLL()``` failed on some strange linux environment
# in that case ```ctypes.util.find_library()``` have different find path from ```CDLL()```
if True:
# We may get here when find_library fails because, for example, # We may get here when find_library fails because, for example,
# the user does not have sufficient privileges to access those # the user does not have sufficient privileges to access those
# tools underlying find_library on linux. # tools underlying find_library on linux.

19
shadowsocks/encrypt.py

@ -46,7 +46,7 @@ def try_cipher(key, method=None):
Encryptor(key, method) Encryptor(key, method)
def EVP_BytesToKey(password, key_len, iv_len): def EVP_BytesToKey(password, key_len, iv_len, cache):
# equivalent to OpenSSL's EVP_BytesToKey() with count 1 # equivalent to OpenSSL's EVP_BytesToKey() with count 1
# so that we make the same key and iv as nodejs version # so that we make the same key and iv as nodejs version
cached_key = '%s-%d-%d' % (password, key_len, iv_len) cached_key = '%s-%d-%d' % (password, key_len, iv_len)
@ -66,13 +66,14 @@ def EVP_BytesToKey(password, key_len, iv_len):
ms = b''.join(m) ms = b''.join(m)
key = ms[:key_len] key = ms[:key_len]
iv = ms[key_len:key_len + iv_len] iv = ms[key_len:key_len + iv_len]
cached_keys[cached_key] = (key, iv) if cache:
cached_keys.sweep() cached_keys[cached_key] = (key, iv)
cached_keys.sweep()
return key, iv return key, iv
class Encryptor(object): class Encryptor(object):
def __init__(self, key, method, iv = None): def __init__(self, key, method, iv = None, cache = False):
self.key = key self.key = key
self.method = method self.method = method
self.iv = None self.iv = None
@ -81,6 +82,7 @@ class Encryptor(object):
self.iv_buf = b'' self.iv_buf = b''
self.cipher_key = b'' self.cipher_key = b''
self.decipher = None self.decipher = None
self.cache = cache
method = method.lower() method = method.lower()
self._method_info = self.get_method_info(method) self._method_info = self.get_method_info(method)
if self._method_info: if self._method_info:
@ -105,7 +107,7 @@ class Encryptor(object):
password = common.to_bytes(password) password = common.to_bytes(password)
m = self._method_info m = self._method_info
if m[0] > 0: if m[0] > 0:
key, iv_ = EVP_BytesToKey(password, m[0], m[1]) key, iv_ = EVP_BytesToKey(password, m[0], m[1], self.cache)
else: else:
# key_length == 0 indicates we should use the key directly # key_length == 0 indicates we should use the key directly
key, iv = password, b'' key, iv = password, b''
@ -119,6 +121,9 @@ class Encryptor(object):
def encrypt(self, buf): def encrypt(self, buf):
if len(buf) == 0: if len(buf) == 0:
if not self.iv_sent:
self.iv_sent = True
return self.cipher_iv
return buf return buf
if self.iv_sent: if self.iv_sent:
return self.cipher.update(buf) return self.cipher.update(buf)
@ -155,7 +160,7 @@ def encrypt_all(password, method, op, data):
method = method.lower() method = method.lower()
(key_len, iv_len, m) = method_supported[method] (key_len, iv_len, m) = method_supported[method]
if key_len > 0: if key_len > 0:
key, _ = EVP_BytesToKey(password, key_len, iv_len) key, _ = EVP_BytesToKey(password, key_len, iv_len, True)
else: else:
key = password key = password
if op: if op:
@ -172,7 +177,7 @@ def encrypt_key(password, method):
method = method.lower() method = method.lower()
(key_len, iv_len, m) = method_supported[method] (key_len, iv_len, m) = method_supported[method]
if key_len > 0: if key_len > 0:
key, _ = EVP_BytesToKey(password, key_len, iv_len) key, _ = EVP_BytesToKey(password, key_len, iv_len, True)
else: else:
key = password key = password
return key return key

64
shadowsocks/obfsplugin/auth_chain.py

@ -18,8 +18,6 @@
from __future__ import absolute_import, division, print_function, \ from __future__ import absolute_import, division, print_function, \
with_statement with_statement
import os
import sys
import hashlib import hashlib
import logging import logging
import binascii import binascii
@ -29,16 +27,16 @@ import datetime
import random import random
import math import math
import struct import struct
import zlib
import hmac import hmac
import hashlib
import bisect import bisect
import shadowsocks import shadowsocks
from shadowsocks import common, lru_cache, encrypt from shadowsocks import common, lru_cache, encrypt
from shadowsocks.obfsplugin import plain from shadowsocks.obfsplugin import plain
from shadowsocks.common import to_bytes, to_str, ord, chr from shadowsocks.common import to_bytes, to_str, ord, chr
from shadowsocks.crypto import openssl
rand_bytes = openssl.rand_bytes
def create_auth_chain_a(method): def create_auth_chain_a(method):
return auth_chain_a(method) return auth_chain_a(method)
@ -87,25 +85,25 @@ class xorshift128plus(object):
y = self.v1 y = self.v1
self.v0 = y self.v0 = y
x ^= ((x & xorshift128plus.mov_mask) << 23) x ^= ((x & xorshift128plus.mov_mask) << 23)
x ^= (y ^ (x >> 17) ^ (y >> 26)) & xorshift128plus.max_int x ^= (y ^ (x >> 17) ^ (y >> 26))
self.v1 = x self.v1 = x
return (x + y) & xorshift128plus.max_int return (x + y) & xorshift128plus.max_int
def init_from_bin(self, bin): def init_from_bin(self, bin):
bin += b'\0' * 16 if len(bin) < 16:
bin += b'\0' * 16
self.v0 = struct.unpack('<Q', bin[:8])[0] self.v0 = struct.unpack('<Q', bin[:8])[0]
self.v1 = struct.unpack('<Q', bin[8:16])[0] self.v1 = struct.unpack('<Q', bin[8:16])[0]
def init_from_bin_len(self, bin, length): def init_from_bin_len(self, bin, length):
bin += b'\0' * 16 if len(bin) < 16:
bin = struct.pack('<H', length) + bin[2:] bin += b'\0' * 16
self.v0 = struct.unpack('<Q', bin[:8])[0] self.v0 = struct.unpack('<Q', struct.pack('<H', length) + bin[2:8])[0]
self.v1 = struct.unpack('<Q', bin[8:16])[0] self.v1 = struct.unpack('<Q', bin[8:16])[0]
for i in range(4): for i in range(4):
self.next() self.next()
def match_begin(str1, str2): def match_begin(str1, str2):
if len(str1) >= len(str2): if len(str1) >= len(str2):
if str1[:len(str2)] == str2: if str1[:len(str2)] == str2:
@ -335,7 +333,7 @@ class auth_chain_a(auth_base):
def rnd_data(self, buf_size, buf, last_hash, random): def rnd_data(self, buf_size, buf, last_hash, random):
rand_len = self.rnd_data_len(buf_size, last_hash, random) rand_len = self.rnd_data_len(buf_size, last_hash, random)
rnd_data_buf = os.urandom(rand_len) rnd_data_buf = rand_bytes(rand_len)
if buf_size == 0: if buf_size == 0:
return rnd_data_buf return rnd_data_buf
@ -349,7 +347,6 @@ class auth_chain_a(auth_base):
def pack_client_data(self, buf): def pack_client_data(self, buf):
buf = self.encryptor.encrypt(buf) buf = self.encryptor.encrypt(buf)
data = self.rnd_data(len(buf), buf, self.last_client_hash, self.random_client) data = self.rnd_data(len(buf), buf, self.last_client_hash, self.random_client)
data_len = len(data) + 8
mac_key = self.user_key + struct.pack('<I', self.pack_id) mac_key = self.user_key + struct.pack('<I', self.pack_id)
length = len(buf) ^ struct.unpack('<H', self.last_client_hash[14:])[0] length = len(buf) ^ struct.unpack('<H', self.last_client_hash[14:])[0]
data = struct.pack('<H', length) + data data = struct.pack('<H', length) + data
@ -361,7 +358,6 @@ class auth_chain_a(auth_base):
def pack_server_data(self, buf): def pack_server_data(self, buf):
buf = self.encryptor.encrypt(buf) buf = self.encryptor.encrypt(buf)
data = self.rnd_data(len(buf), buf, self.last_server_hash, self.random_server) data = self.rnd_data(len(buf), buf, self.last_server_hash, self.random_server)
data_len = len(data) + 8
mac_key = self.user_key + struct.pack('<I', self.pack_id) mac_key = self.user_key + struct.pack('<I', self.pack_id)
length = len(buf) ^ struct.unpack('<H', self.last_server_hash[14:])[0] length = len(buf) ^ struct.unpack('<H', self.last_server_hash[14:])[0]
data = struct.pack('<H', length) + data data = struct.pack('<H', length) + data
@ -372,11 +368,10 @@ class auth_chain_a(auth_base):
def pack_auth_data(self, auth_data, buf): def pack_auth_data(self, auth_data, buf):
data = auth_data data = auth_data
data_len = 12 + 4 + 16 + 4
data = data + (struct.pack('<H', self.server_info.overhead) + struct.pack('<H', 0)) data = data + (struct.pack('<H', self.server_info.overhead) + struct.pack('<H', 0))
mac_key = self.server_info.iv + self.server_info.key mac_key = self.server_info.iv + self.server_info.key
check_head = os.urandom(4) check_head = rand_bytes(4)
self.last_client_hash = hmac.new(mac_key, check_head, self.hashfunc).digest() self.last_client_hash = hmac.new(mac_key, check_head, self.hashfunc).digest()
check_head += self.last_client_hash[:8] check_head += self.last_client_hash[:8]
@ -386,9 +381,9 @@ class auth_chain_a(auth_base):
self.user_key = items[1] self.user_key = items[1]
uid = struct.pack('<I', int(items[0])) uid = struct.pack('<I', int(items[0]))
except: except:
uid = os.urandom(4) uid = rand_bytes(4)
else: else:
uid = os.urandom(4) uid = rand_bytes(4)
if self.user_key is None: if self.user_key is None:
self.user_key = self.server_info.key self.user_key = self.server_info.key
@ -409,14 +404,17 @@ class auth_chain_a(auth_base):
if self.server_info.data.connection_id > 0xFF000000: if self.server_info.data.connection_id > 0xFF000000:
self.server_info.data.local_client_id = b'' self.server_info.data.local_client_id = b''
if not self.server_info.data.local_client_id: if not self.server_info.data.local_client_id:
self.server_info.data.local_client_id = os.urandom(4) self.server_info.data.local_client_id = rand_bytes(4)
logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),)) logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),))
self.server_info.data.connection_id = struct.unpack('<I', os.urandom(4))[0] & 0xFFFFFF self.server_info.data.connection_id = struct.unpack('<I', rand_bytes(4))[0] & 0xFFFFFF
self.server_info.data.connection_id += 1 self.server_info.data.connection_id += 1
return b''.join([struct.pack('<I', utc_time), return b''.join([struct.pack('<I', utc_time),
self.server_info.data.local_client_id, self.server_info.data.local_client_id,
struct.pack('<I', self.server_info.data.connection_id)]) struct.pack('<I', self.server_info.data.connection_id)])
def on_recv_auth_data(self, utc_time):
pass
def client_pre_encrypt(self, buf): def client_pre_encrypt(self, buf):
ret = b'' ret = b''
ogn_data_len = len(buf) ogn_data_len = len(buf)
@ -551,6 +549,7 @@ class auth_chain_a(auth_base):
logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf))) logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf)))
return self.not_match_return(self.recv_buf) return self.not_match_return(self.recv_buf)
self.on_recv_auth_data(utc_time)
self.encryptor = encrypt.Encryptor( self.encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4') to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4')
self.recv_buf = self.recv_buf[36:] self.recv_buf = self.recv_buf[36:]
@ -565,7 +564,7 @@ class auth_chain_a(auth_base):
if length >= 4096: if length >= 4096:
self.raw_trans = True self.raw_trans = True
self.recv_buf = b'' self.recv_buf = b''
if self.recv_id == 0: if self.recv_id == 1:
logging.info(self.no_compatible_method + ': over size') logging.info(self.no_compatible_method + ': over size')
return (b'E' * 2048, False) return (b'E' * 2048, False)
else: else:
@ -581,7 +580,7 @@ class auth_chain_a(auth_base):
)) ))
self.raw_trans = True self.raw_trans = True
self.recv_buf = b'' self.recv_buf = b''
if self.recv_id == 0: if self.recv_id == 1:
return (b'E' * 2048, False) return (b'E' * 2048, False)
else: else:
raise Exception('server_post_decrype data uncorrect checksum') raise Exception('server_post_decrype data uncorrect checksum')
@ -610,9 +609,9 @@ class auth_chain_a(auth_base):
except: except:
pass pass
if self.user_key is None: if self.user_key is None:
self.user_id = os.urandom(4) self.user_id = rand_bytes(4)
self.user_key = self.server_info.key self.user_key = self.server_info.key
authdata = os.urandom(3) authdata = rand_bytes(3)
mac_key = self.server_info.key mac_key = self.server_info.key
md5data = hmac.new(mac_key, authdata, self.hashfunc).digest() md5data = hmac.new(mac_key, authdata, self.hashfunc).digest()
uid = struct.unpack('<I', self.user_id)[0] ^ struct.unpack('<I', md5data[:4])[0] uid = struct.unpack('<I', self.user_id)[0] ^ struct.unpack('<I', md5data[:4])[0]
@ -621,7 +620,7 @@ class auth_chain_a(auth_base):
encryptor = encrypt.Encryptor( encryptor = encrypt.Encryptor(
to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4') to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4')
out_buf = encryptor.encrypt(buf) out_buf = encryptor.encrypt(buf)
buf = out_buf + os.urandom(rand_len) + authdata + uid buf = out_buf + rand_bytes(rand_len) + authdata + uid
return buf + hmac.new(self.user_key, buf, self.hashfunc).digest()[:1] return buf + hmac.new(self.user_key, buf, self.hashfunc).digest()[:1]
def client_udp_post_decrypt(self, buf): def client_udp_post_decrypt(self, buf):
@ -645,13 +644,13 @@ class auth_chain_a(auth_base):
user_key = self.server_info.key user_key = self.server_info.key
else: else:
user_key = self.server_info.recv_iv user_key = self.server_info.recv_iv
authdata = os.urandom(7) authdata = rand_bytes(7)
mac_key = self.server_info.key mac_key = self.server_info.key
md5data = hmac.new(mac_key, authdata, self.hashfunc).digest() md5data = hmac.new(mac_key, authdata, self.hashfunc).digest()
rand_len = self.udp_rnd_data_len(md5data, self.random_server) rand_len = self.udp_rnd_data_len(md5data, self.random_server)
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4') encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(user_key)) + to_bytes(base64.b64encode(md5data)), 'rc4')
out_buf = encryptor.encrypt(buf) out_buf = encryptor.encrypt(buf)
buf = out_buf + os.urandom(rand_len) + authdata buf = out_buf + rand_bytes(rand_len) + authdata
return buf + hmac.new(user_key, buf, self.hashfunc).digest()[:1] return buf + hmac.new(user_key, buf, self.hashfunc).digest()[:1]
def server_udp_post_decrypt(self, buf): def server_udp_post_decrypt(self, buf):
@ -855,6 +854,7 @@ class auth_chain_e(auth_chain_d):
self.no_compatible_method = 'auth_chain_e' self.no_compatible_method = 'auth_chain_e'
def rnd_data_len(self, buf_size, last_hash, random): def rnd_data_len(self, buf_size, last_hash, random):
random.init_from_bin_len(last_hash, buf_size)
other_data_size = buf_size + self.server_info.overhead other_data_size = buf_size + self.server_info.overhead
# if other_data_size > the bigest item in data_size_list0, not padding any data # if other_data_size > the bigest item in data_size_list0, not padding any data
if other_data_size >= self.data_size_list0[-1]: if other_data_size >= self.data_size_list0[-1]:
@ -879,15 +879,17 @@ class auth_chain_f(auth_chain_e):
max_client = int(server_info.protocol_param.split('#')[0]) max_client = int(server_info.protocol_param.split('#')[0])
except: except:
max_client = 64 max_client = 64
self.server_info.data.set_max_client(max_client)
try: try:
self.key_change_interval = int(server_info.protocol_param.split('#')[1]) # config are in second self.key_change_interval = int(server_info.protocol_param.split('#')[1]) # config are in second
except: except:
self.key_change_interval = 60 * 60 * 24 # a day by second self.key_change_interval = 60 * 60 * 24 # a day by second
self.key_change_datetime_key = int(int(time.time()) / self.key_change_interval)
def on_recv_auth_data(self, utc_time):
self.key_change_datetime_key = int(utc_time / self.key_change_interval)
self.key_change_datetime_key_bytes = [] # big bit first list self.key_change_datetime_key_bytes = [] # big bit first list
for i in range(7, -1, -1): # big-ending compare to c for i in range(7, -1, -1): # big-ending compare to c
self.key_change_datetime_key_bytes.append((self.key_change_datetime_key >> (8 * i)) & 0xFF) self.key_change_datetime_key_bytes.append((self.key_change_datetime_key >> (8 * i)) & 0xFF)
self.server_info.data.set_max_client(max_client)
self.init_data_size(self.server_info.key) self.init_data_size(self.server_info.key)
def init_data_size(self, key): def init_data_size(self, key):
@ -896,9 +898,13 @@ class auth_chain_f(auth_chain_e):
random = xorshift128plus() random = xorshift128plus()
# key xor with key_change_datetime_key # key xor with key_change_datetime_key
new_key = bytearray(key) new_key = bytearray(key)
new_key_str = ''
for i in range(0, 8): for i in range(0, 8):
new_key[i] ^= self.key_change_datetime_key_bytes[i] new_key[i] ^= self.key_change_datetime_key_bytes[i]
random.init_from_bin(new_key) new_key_str += chr(new_key[i])
for i in range(8, len(new_key)):
new_key_str += chr(new_key[i])
random.init_from_bin(to_bytes(new_key_str))
# 补全数组长为12~24-1 # 补全数组长为12~24-1
list_len = random.next() % (8 + 16) + (4 + 8) list_len = random.next() % (8 + 16) + (4 + 8)
for i in range(0, list_len): for i in range(0, list_len):

6
shadowsocks/server.py

@ -108,8 +108,8 @@ def main():
(protocol, password, method, obfs, obfs_param)) (protocol, password, method, obfs, obfs_param))
if 'server_ipv6' in a_config: if 'server_ipv6' in a_config:
try: try:
if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == "[" and a_config['server_ipv6'][ if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == b"[" and a_config['server_ipv6'][
-1] == "]": -1] == b"]":
a_config['server_ipv6'] = a_config['server_ipv6'][1:-1] a_config['server_ipv6'] = a_config['server_ipv6'][1:-1]
a_config['server_port'] = int(port) a_config['server_port'] = int(port)
a_config['password'] = password a_config['password'] = password
@ -120,7 +120,7 @@ def main():
a_config['obfs_param'] = obfs_param a_config['obfs_param'] = obfs_param
a_config['out_bind'] = bind a_config['out_bind'] = bind
a_config['out_bindv6'] = bindv6 a_config['out_bindv6'] = bindv6
a_config['server'] = a_config['server_ipv6'] a_config['server'] = common.to_str(a_config['server_ipv6'])
logging.info("starting server at [%s]:%d" % logging.info("starting server at [%s]:%d" %
(a_config['server'], int(port))) (a_config['server'], int(port)))
tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False, stat_counter=stat_counter_dict)) tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False, stat_counter=stat_counter_dict))

7
shadowsocks/tcprelay.py

@ -266,7 +266,7 @@ class TCPRelayHandler(object):
def _create_encryptor(self, config): def _create_encryptor(self, config):
try: try:
self._encryptor = encrypt.Encryptor(config['password'], self._encryptor = encrypt.Encryptor(config['password'],
config['method']) config['method'], None, True)
return True return True
except Exception: except Exception:
self._stage = STAGE_DESTROYED self._stage = STAGE_DESTROYED
@ -1163,8 +1163,9 @@ class TCPRelayHandler(object):
self._protocol.dispose() self._protocol.dispose()
self._protocol = None self._protocol = None
self._encryptor.dispose() if self._encryptor:
self._encryptor = None self._encryptor.dispose()
self._encryptor = None
self._dns_resolver.remove_callback(self._handle_dns_resolved) self._dns_resolver.remove_callback(self._handle_dns_resolved)
self._server.remove_handler(self) self._server.remove_handler(self)
if self._add_ref > 0: if self._add_ref > 0:

2
shadowsocks/udprelay.py

@ -213,6 +213,8 @@ class UDPRelay(object):
server_socket = socket.socket(af, socktype, proto) server_socket = socket.socket(af, socktype, proto)
server_socket.bind((self._listen_addr, self._listen_port)) server_socket.bind((self._listen_addr, self._listen_port))
server_socket.setblocking(False) server_socket.setblocking(False)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 1024)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 1024)
self._server_socket = server_socket self._server_socket = server_socket
self._stat_callback = stat_callback self._stat_callback = stat_callback

2
shadowsocks/version.py

@ -16,5 +16,5 @@
# under the License. # under the License.
def version(): def version():
return 'SSRR 3.2.1 2017-10-15' return 'SSRR 3.2.2 2018-05-22'

3
switchrule.py

@ -1,3 +1,6 @@
def getRowMap():
return {} # if your db row "encrypt" means "method", write {"encrypt": "method"}
def getKeys(key_list): def getKeys(key_list):
return key_list return key_list
#return key_list + ['plan'] # append the column name 'plan' #return key_list + ['plan'] # append the column name 'plan'

Loading…
Cancel
Save