From 8dac9faa286c68704b44baf28073094d0d7ace35 Mon Sep 17 00:00:00 2001 From: BreakWa11 Date: Tue, 14 Jun 2016 14:06:52 +0800 Subject: [PATCH] 'obfs', 'protocol', 'method' leave blank that use the config of config.json --- db_transfer.py | 59 ++++++++++++++++++---------------- shadowsocks/obfsplugin/auth.py | 6 ---- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/db_transfer.py b/db_transfer.py index 18423e1..ddb4bcc 100644 --- a/db_transfer.py +++ b/db_transfer.py @@ -2,7 +2,6 @@ # -*- coding: UTF-8 -*- import logging -import cymysql import time import sys from server_pool import ServerPool @@ -25,31 +24,8 @@ class DbTransfer(object): DbTransfer.instance = DbTransfer() return DbTransfer.instance - def push_db_all_user(self): - #更新用户流量到数据库 - last_transfer = self.last_get_transfer - curr_transfer = ServerPool.get_instance().get_servers_transfer() - #上次和本次的增量 - dt_transfer = {} - for id in curr_transfer.keys(): - if id in last_transfer: - if last_transfer[id][0] == curr_transfer[id][0] and last_transfer[id][1] == curr_transfer[id][1]: - continue - elif curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0: - continue - elif last_transfer[id][0] <= curr_transfer[id][0] and \ - last_transfer[id][1] <= curr_transfer[id][1]: - dt_transfer[id] = [int((curr_transfer[id][0] - last_transfer[id][0]) * get_config().TRANSFER_MUL), - int((curr_transfer[id][1] - last_transfer[id][1]) * get_config().TRANSFER_MUL)] - else: - dt_transfer[id] = [int(curr_transfer[id][0] * get_config().TRANSFER_MUL), - int(curr_transfer[id][1] * get_config().TRANSFER_MUL)] - else: - if curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0: - continue - dt_transfer[id] = [int(curr_transfer[id][0] * get_config().TRANSFER_MUL), - int(curr_transfer[id][1] * get_config().TRANSFER_MUL)] - + def update_all_user(self, dt_transfer): + import cymysql query_head = 'UPDATE user' query_sub_when = '' query_sub_when2 = '' @@ -78,10 +54,38 @@ class DbTransfer(object): cur.close() conn.commit() conn.close() + + def push_db_all_user(self): + #更新用户流量到数据库 + last_transfer = self.last_get_transfer + curr_transfer = ServerPool.get_instance().get_servers_transfer() + #上次和本次的增量 + dt_transfer = {} + for id in curr_transfer.keys(): + if id in last_transfer: + if last_transfer[id][0] == curr_transfer[id][0] and last_transfer[id][1] == curr_transfer[id][1]: + continue + elif curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0: + continue + elif last_transfer[id][0] <= curr_transfer[id][0] and \ + last_transfer[id][1] <= curr_transfer[id][1]: + dt_transfer[id] = [int((curr_transfer[id][0] - last_transfer[id][0]) * get_config().TRANSFER_MUL), + int((curr_transfer[id][1] - last_transfer[id][1]) * get_config().TRANSFER_MUL)] + else: + dt_transfer[id] = [int(curr_transfer[id][0] * get_config().TRANSFER_MUL), + int(curr_transfer[id][1] * get_config().TRANSFER_MUL)] + else: + if curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0: + continue + dt_transfer[id] = [int(curr_transfer[id][0] * get_config().TRANSFER_MUL), + int(curr_transfer[id][1] * get_config().TRANSFER_MUL)] + + self.update_all_user(dt_transfer) self.last_get_transfer = curr_transfer @staticmethod def pull_db_all_user(): + import cymysql #数据库所有用户信息 try: import switchrule @@ -89,7 +93,6 @@ class DbTransfer(object): keys = switchrule.getKeys() except Exception as e: keys = ['port', 'u', 'd', 'transfer_enable', 'passwd', 'enable' ] - reload(cymysql) conn = cymysql.connect(host=get_config().MYSQL_HOST, port=get_config().MYSQL_PORT, user=get_config().MYSQL_USER, passwd=get_config().MYSQL_PASS, db=get_config().MYSQL_DB, charset='utf8') cur = conn.cursor() @@ -126,7 +129,7 @@ class DbTransfer(object): passwd = common.to_bytes(row['passwd']) cfg = {'password': passwd} for name in ['method', 'obfs', 'protocol']: - if name in row: + if name in row and row[name]: cfg[name] = row[name] for name in cfg.keys(): diff --git a/shadowsocks/obfsplugin/auth.py b/shadowsocks/obfsplugin/auth.py index 1217cc4..d64c98b 100644 --- a/shadowsocks/obfsplugin/auth.py +++ b/shadowsocks/obfsplugin/auth.py @@ -206,8 +206,6 @@ class auth_simple(verify_base): self.server_info.data.set_max_client(max_client) def pack_data(self, buf): - if len(buf) == 0: - return b'' rnd_data = os.urandom(common.ord(os.urandom(1)[0]) % 16) data = common.chr(len(rnd_data) + 1) + rnd_data + buf data = struct.pack('>H', len(data) + 6) + data @@ -364,8 +362,6 @@ class auth_sha1(verify_base): self.server_info.data.set_max_client(max_client) def pack_data(self, buf): - if len(buf) == 0: - return b'' rnd_data = os.urandom(common.ord(os.urandom(1)[0]) % 16) data = common.chr(len(rnd_data) + 1) + rnd_data + buf data = struct.pack('>H', len(data) + 6) + data @@ -606,8 +602,6 @@ class auth_sha1_v2(verify_base): return common.chr(255) + struct.pack('>H', len(rnd_data) + 3) + rnd_data def pack_data(self, buf): - if len(buf) == 0: - return b'' data = self.rnd_data(len(buf)) + buf data = struct.pack('>H', len(data) + 6) + data adler32 = zlib.adler32(data) & 0xFFFFFFFF