diff --git a/db_transfer.py b/db_transfer.py index 61a21dc..e1a404c 100644 --- a/db_transfer.py +++ b/db_transfer.py @@ -9,6 +9,7 @@ import traceback from shadowsocks import common, shell, lru_cache, obfs from configloader import load_config, get_config import importloader +import copy switchrule = None db_instance = None @@ -80,8 +81,10 @@ class TransferBase(object): def del_server_out_of_bound_safe(self, last_rows, rows): #停止超流量的服务 #启动没超流量的服务 + keymap = {} try: switchrule = importloader.load('switchrule') + keymap = switchrule.getRowMap() except Exception as e: logging.error('load switchrule.py fail') cur_servers = {} @@ -106,7 +109,10 @@ class TransferBase(object): read_config_keys = ['method', 'obfs', 'obfs_param', 'protocol', 'protocol_param', 'forbidden_ip', 'forbidden_port', 'speed_limit_per_con', 'speed_limit_per_user'] for name in read_config_keys: if name in row and row[name]: - cfg[name] = row[name] + if name in keymap: + cfg[keymap[name]] = row[name] + else: + cfg[name] = row[name] merge_config_keys = ['password'] + read_config_keys for name in cfg.keys(): @@ -392,11 +398,17 @@ class DbTransfer(TransferBase): return rows def pull_db_users(self, conn): + keys = copy.copy(self.key_list) try: switchrule = importloader.load('switchrule') - keys = switchrule.getKeys(self.key_list) + keymap = switchrule.getRowMap() + for key in keymap: + if keymap[key] in keys: + keys.remove(keymap[key]) + keys.append(key) + keys = switchrule.getKeys(keys) except Exception as e: - keys = self.key_list + logging.error('load switchrule.py fail') cur = conn.cursor() cur.execute("SELECT " + ','.join(keys) + " FROM user") @@ -520,11 +532,17 @@ class Dbv3Transfer(DbTransfer): return update_transfer def pull_db_users(self, conn): + keys = copy.copy(self.key_list) try: switchrule = importloader.load('switchrule') - keys = switchrule.getKeys(self.key_list) + keymap = switchrule.getRowMap() + for key in keymap: + if keymap[key] in keys: + keys.remove(keymap[key]) + keys.append(key) + keys = switchrule.getKeys(keys) except Exception as e: - keys = self.key_list + logging.error('load switchrule.py fail') cur = conn.cursor() diff --git a/switchrule.py b/switchrule.py index 6687e12..56ed995 100644 --- a/switchrule.py +++ b/switchrule.py @@ -1,3 +1,6 @@ +def getRowMap(): + return {} # if your db row "encrypt" means "method", write {"encrypt": "method"} + def getKeys(key_list): return key_list #return key_list + ['plan'] # append the column name 'plan'