diff --git a/db_transfer.py b/db_transfer.py index 7ed02c4..502aef6 100644 --- a/db_transfer.py +++ b/db_transfer.py @@ -18,6 +18,8 @@ class DbTransfer(object): import threading self.last_get_transfer = {} self.event = threading.Event() + self.user_pass = {} + self.port_uid_table = {} def update_all_user(self, dt_transfer): import cymysql @@ -27,8 +29,6 @@ class DbTransfer(object): query_sub_in = None last_time = time.time() for id in dt_transfer.keys(): - if dt_transfer[id][0] == 0 and dt_transfer[id][1] == 0: - continue query_sub_when += ' WHEN %s THEN u+%s' % (id, dt_transfer[id][0]) query_sub_when2 += ' WHEN %s THEN d+%s' % (id, dt_transfer[id][1]) if query_sub_in is not None: @@ -41,7 +41,6 @@ class DbTransfer(object): ' END, d = CASE port' + query_sub_when2 + \ ' END, t = ' + str(int(last_time)) + \ ' WHERE port IN (%s)' % query_sub_in - #print query_sql conn = cymysql.connect(host=get_config().MYSQL_HOST, port=get_config().MYSQL_PORT, user=get_config().MYSQL_USER, passwd=get_config().MYSQL_PASS, db=get_config().MYSQL_DB, charset='utf8') cur = conn.cursor() @@ -57,23 +56,26 @@ class DbTransfer(object): #上次和本次的增量 dt_transfer = {} for id in curr_transfer.keys(): + update_trs = 1024 * max(2048 - self.user_pass.get(id, 0) * 64, 16) if id in last_transfer: - if last_transfer[id][0] == curr_transfer[id][0] and last_transfer[id][1] == curr_transfer[id][1]: - continue - elif curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0: + if curr_transfer[id][0] + curr_transfer[id][1] - last_transfer[id][0] - last_transfer[id][1] < update_trs: + self.user_pass[id] = self.user_pass.get(id, 0) + 1 continue - elif last_transfer[id][0] <= curr_transfer[id][0] and \ - last_transfer[id][1] <= curr_transfer[id][1]: + if last_transfer[id][0] <= curr_transfer[id][0] and \ + last_transfer[id][1] <= curr_transfer[id][1]: dt_transfer[id] = [int((curr_transfer[id][0] - last_transfer[id][0]) * get_config().TRANSFER_MUL), int((curr_transfer[id][1] - last_transfer[id][1]) * get_config().TRANSFER_MUL)] else: dt_transfer[id] = [int(curr_transfer[id][0] * get_config().TRANSFER_MUL), int(curr_transfer[id][1] * get_config().TRANSFER_MUL)] else: - if curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0: + if curr_transfer[id][0] + curr_transfer[id][1] < update_trs: + self.user_pass[id] = self.user_pass.get(id, 0) + 1 continue dt_transfer[id] = [int(curr_transfer[id][0] * get_config().TRANSFER_MUL), int(curr_transfer[id][1] * get_config().TRANSFER_MUL)] + if id in self.user_pass: + del self.user_pass[id] self.update_all_user(dt_transfer) self.last_get_transfer = curr_transfer @@ -126,8 +128,9 @@ class DbTransfer(object): port = row['port'] passwd = common.to_bytes(row['passwd']) cfg = {'password': passwd} + self.port_uid_table[row['port']] = row['id'] - read_config_keys = ['method', 'obfs', 'protocol', 'forbidden_ip', 'forbidden_port'] + read_config_keys = ['method', 'obfs', 'obfs_param', 'protocol', 'protocol_param', 'forbidden_ip', 'forbidden_port'] for name in read_config_keys: if name in row and row[name]: cfg[name] = row[name] @@ -178,6 +181,7 @@ class DbTransfer(object): else: logging.info('db stop server at port [%s] reason: port not exist' % (row['port'])) ServerPool.get_instance().cb_del_server(row['port']) + del self.port_uid_table[row['port']] if len(new_servers) > 0: from shadowsocks import eventloop @@ -230,6 +234,85 @@ class DbTransfer(object): global db_instance db_instance.event.set() +class Dbv3Transfer(DbTransfer): + def __init__(self): + super(Dbv3Transfer, self).__init__() + + def update_all_user(self, dt_transfer): + import cymysql + + query_head = 'UPDATE user' + query_sub_when = '' + query_sub_when2 = '' + query_sub_in = None + last_time = time.time() + + alive_user_count = 0 + bandwidth_thistime = 0 + + conn = cymysql.connect(host=get_config().MYSQL_HOST, port=get_config().MYSQL_PORT, user=get_config().MYSQL_USER, + passwd=get_config().MYSQL_PASS, db=get_config().MYSQL_DB, charset='utf8') + conn.autocommit(True) + + for id in dt_transfer.keys(): + transfer = dt_transfer[id] + query_sub_when += ' WHEN %s THEN u+%s' % (id, transfer[0]) + query_sub_when2 += ' WHEN %s THEN d+%s' % (id, transfer[1]) + + cur = conn.cursor() + cur.execute("INSERT INTO `user_traffic_log` (`id`, `user_id`, `u`, `d`, `Node_ID`, `rate`, `traffic`, `log_time`) VALUES (NULL, '" + \ + str(self.port_uid_table[id]) + "', '" + str(transfer[0]) + "', '" + str(transfer[1]) + "', '" + \ + str(get_config().NODE_ID) + "', '" + str(get_config().TRANSFER_MUL) + "', '" + \ + self.traffic_format(transfer[0] + transfer[1]) + "', unix_timestamp()); ") + cur.close() + + alive_user_count = alive_user_count + 1 + bandwidth_thistime = bandwidth_thistime + transfer[0] + transfer[1] + + if query_sub_in is not None: + query_sub_in += ',%s' % id + else: + query_sub_in = '%s' % id + + if query_sub_when != '': + query_sql = query_head + ' SET u = CASE port' + query_sub_when + \ + ' END, d = CASE port' + query_sub_when2 + \ + ' END, t = ' + str(int(last_time)) + \ + ' WHERE port IN (%s)' % query_sub_in + cur = conn.cursor() + cur.execute(query_sql) + cur.close() + + cur = conn.cursor() + cur.execute("INSERT INTO `ss_node_online_log` (`id`, `Node_ID`, `online_user`, `log_time`) VALUES (NULL, '" + \ + str(get_config().NODE_ID) + "', '" + str(alive_user_count) + "', unix_timestamp()); ") + cur.close() + + cur = conn.cursor() + cur.execute("INSERT INTO `ss_node_info_log` (`id`, `node_id`, `uptime`, `load`, `log_time`) VALUES (NULL, '" + \ + str(get_config().NODE_ID) + "', '" + str(self.uptime()) + "', '" + \ + str(self.load()) + "', unix_timestamp()); ") + cur.close() + + conn.close() + + def load(self): + import os + return os.popen("cat /proc/loadavg | awk '{ print $1\" \"$2\" \"$3 }'").readlines()[0] + + def uptime(self): + with open('/proc/uptime', 'r') as f: + return float(f.readline().split()[0]) + + def traffic_format(self, traffic): + if traffic < 1024 * 8: + return str(traffic) + "B"; + + if traffic < 1024 * 1024 * 8: + return str(round((traffic / 1024.0), 2)) + "KB"; + + return str(round((traffic / 1048576.0), 2)) + "MB"; + class MuJsonTransfer(DbTransfer): def __init__(self): super(MuJsonTransfer, self).__init__() diff --git a/server.py b/server.py index 75ce340..4636721 100644 --- a/server.py +++ b/server.py @@ -47,8 +47,10 @@ def main(): else: if get_config().API_INTERFACE == 'mudbjson': thread = MainThread(db_transfer.MuJsonTransfer) - else: + elif get_config().API_INTERFACE == 'sspanelv2': thread = MainThread(db_transfer.DbTransfer) + else: + thread = MainThread(db_transfer.Dbv3Transfer) thread.start() try: while thread.is_alive(): diff --git a/switchrule.py b/switchrule.py index 5b2d313..d67c6c6 100644 --- a/switchrule.py +++ b/switchrule.py @@ -3,9 +3,9 @@ from configloader import load_config, get_config def getKeys(): key_list = ['port', 'u', 'd', 'transfer_enable', 'passwd', 'enable' ] if get_config().API_INTERFACE == 'sspanelv3': - key_list += ['method'] + key_list += ['id', 'method'] elif get_config().API_INTERFACE == 'sspanelv3ssr': - key_list += ['method', 'obfs', 'protocol'] + key_list += ['id', 'method', 'obfs', 'protocol'] return key_list #return key_list + ['plan'] # append the column name 'plan'