#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""Synchronise per-port traffic counters and user configuration.

A background thread (see ``TransferBase.thread_db``) periodically pushes the
traffic measured by ``ServerPool`` to a backing store and pulls the user list
back, starting and stopping per-port servers accordingly.  The backing store
is MySQL (``DbTransfer`` / ``Dbv3Transfer``, via ``cymysql``) or a JSON file
(``MuJsonTransfer``).
"""

import json
import logging
import os
import socket
import sys
import threading
import time
import traceback

from server_pool import ServerPool
from shadowsocks import common, shell, lru_cache, obfs
from configloader import load_config, get_config
import importloader

switchrule = None
db_instance = None


class TransferBase(object):
    """Common bookkeeping shared by every backing store.

    Subclasses must provide ``update_all_user(dt_transfer)`` (returns the
    subset of ``dt_transfer`` that was actually stored) and
    ``pull_db_all_user()`` (returns a list of user row dicts).
    """

    def __init__(self):
        self.event = threading.Event()
        self.key_list = ['port', 'u', 'd', 'transfer_enable', 'passwd', 'enable']
        # Traffic per port as measured at the previous poll.
        self.last_get_transfer = {}
        # Traffic per port already pushed to the store (<= measured traffic).
        self.last_update_transfer = {}
        # Ports whose remaining traffic must be pushed regardless of size.
        self.force_update_transfer = set()
        # port -> user id mapping (filled only when rows carry an 'id').
        self.port_uid_table = {}
        # Ports that recently produced traffic; entries expire after 30 min.
        self.onlineuser_cache = lru_cache.LRUCache(timeout=60 * 30)
        # True once user rows have been pulled successfully at least once.
        self.pull_ok = False
        self.mu_ports = {}

    def load_cfg(self):
        """Reload backend configuration; the base class has none."""
        pass

    def push_db_all_user(self):
        """Compute the traffic delta per port and hand it to the store.

        Does nothing until the first successful pull (``pull_ok``).
        """
        if self.pull_ok is False:
            return

        last_transfer = self.last_update_transfer
        curr_transfer = ServerPool.get_instance().get_servers_transfer()

        # Delta between what was stored and what has been measured.
        dt_transfer = {}

        # Forced ports: account for traffic measured last time but not yet
        # stored (their servers have been stopped in the meantime).
        for port in self.force_update_transfer:
            if port in self.last_get_transfer and port in last_transfer:
                got = self.last_get_transfer[port]
                pushed = last_transfer[port]
                dt_transfer[port] = [got[0] - pushed[0], got[1] - pushed[1]]

        for port in curr_transfer.keys():
            if port in self.force_update_transfer or port in self.mu_ports:
                continue

            cur = curr_transfer[port]
            total = cur[0] + cur[1]

            # Difference against the last stored value.
            if port in last_transfer:
                pushed = last_transfer[port]
                if total - pushed[0] - pushed[1] <= 0:
                    continue
                dt_transfer[port] = [cur[0] - pushed[0], cur[1] - pushed[1]]
            else:
                if total <= 0:
                    continue
                dt_transfer[port] = [cur[0], cur[1]]

            # The port has pending traffic: record it as online if the
            # counter moved since the previous poll (or is new).
            if port in self.last_get_transfer:
                prev = self.last_get_transfer[port]
                if total > prev[0] + prev[1]:
                    self.onlineuser_cache[port] = total
            else:
                self.onlineuser_cache[port] = total

        self.onlineuser_cache.sweep()

        # update_all_user returns only the entries it really stored.
        update_transfer = self.update_all_user(dt_transfer)
        for port in update_transfer.keys():
            # Forced ports are dropped below, so do not accumulate them.
            if port not in self.force_update_transfer:
                last = self.last_update_transfer.get(port, [0, 0])
                self.last_update_transfer[port] = [
                    last[0] + update_transfer[port][0],
                    last[1] + update_transfer[port][1],
                ]
        self.last_get_transfer = curr_transfer

        for port in self.force_update_transfer:
            if port in self.last_update_transfer:
                del self.last_update_transfer[port]
            if port in self.last_get_transfer:
                del self.last_get_transfer[port]
        self.force_update_transfer = set()

    def _relay_config_changed(self, port, cfg, keys):
        """Return True if ``cfg`` differs from the running relay's config.

        The IPv4 pool is checked first, then the IPv6 pool; only the keys in
        ``keys`` that are present in ``cfg`` are compared.
        """
        pool = ServerPool.get_instance()
        for servers in (pool.tcp_servers_pool, pool.tcp_ipv6_servers_pool):
            if port not in servers:
                continue
            relay = servers[port]
            for name in keys:
                if name in cfg and not self.cmp(cfg[name], relay._config[name]):
                    return True
        return False

    def del_server_out_of_bound_safe(self, last_rows, rows):
        """Reconcile running servers with the freshly pulled user rows.

        Stops servers whose user is disabled, over quota or gone, restarts
        servers whose configuration changed, and starts servers for newly
        allowed users.

        :param last_rows: rows from the previous pull (used to detect ports
            that disappeared).
        :param rows: rows from the current pull.
        """
        try:
            rule = importloader.load('switchrule')
        except Exception:
            logging.error('load switchrule.py fail')
            # With no rule module every user evaluates to "not allowed"
            # below (the attribute access fails and is caught).
            rule = None

        cur_servers = {}
        new_servers = {}
        allow_users = {}
        mu_servers = {}
        config = shell.get_config(False)
        pool = ServerPool.get_instance()

        read_config_keys = ['method', 'obfs', 'obfs_param', 'protocol',
                            'protocol_param', 'forbidden_ip', 'forbidden_port',
                            'speed_limit_per_con', 'speed_limit_per_user']
        merge_config_keys = ['password'] + read_config_keys

        for row in rows:
            try:
                allow = (rule.isTurnOn(row) and row['enable'] == 1
                         and row['u'] + row['d'] < row['transfer_enable'])
            except Exception:
                allow = False

            port = row['port']
            passwd = common.to_bytes(row['passwd'])
            if hasattr(passwd, 'encode'):
                passwd = passwd.encode('utf-8')
            cfg = {'password': passwd}
            if 'id' in row:
                self.port_uid_table[row['port']] = row['id']

            for name in read_config_keys:
                if name in row and row[name]:
                    cfg[name] = row[name]

            # Encode every text value; keys are snapshotted because values
            # are reassigned while looping.
            for name in list(cfg):
                if hasattr(cfg[name], 'encode'):
                    try:
                        cfg[name] = cfg[name].encode('utf-8')
                    except Exception:
                        logging.warning('encode cfg key "%s" fail, val "%s"' % (name, cfg[name]))

            if port not in cur_servers:
                cur_servers[port] = passwd
            else:
                logging.error('more than one user use the same port [%s]' % (port,))
                continue

            if allow:
                allow_users[port] = passwd
                # A multi-user protocol whose param contains '#' marks this
                # port as a shared (mu) server rather than a plain user.
                if ('protocol' in cfg and 'protocol_param' in cfg
                        and common.to_str(cfg['protocol']) in obfs.mu_protocol()):
                    if '#' in common.to_str(cfg['protocol_param']):
                        mu_servers[port] = passwd
                        del allow_users[port]

            cfgchange = self._relay_config_changed(port, cfg, merge_config_keys)

            if port in mu_servers:
                if pool.server_is_run(port) > 0:
                    if cfgchange:
                        logging.info('db stop server at port [%s] reason: config changed: %s' % (port, cfg))
                        pool.cb_del_server(port)
                        self.force_update_transfer.add(port)
                        new_servers[port] = (passwd, cfg)
                else:
                    self.new_server(port, passwd, cfg)
            else:
                if pool.server_is_run(port) > 0:
                    if config['additional_ports_only'] or not allow:
                        logging.info('db stop server at port [%s]' % (port,))
                        pool.cb_del_server(port)
                        self.force_update_transfer.add(port)
                    elif cfgchange:
                        logging.info('db stop server at port [%s] reason: config changed: %s' % (port, cfg))
                        pool.cb_del_server(port)
                        self.force_update_transfer.add(port)
                        new_servers[port] = (passwd, cfg)
                elif (not config['additional_ports_only'] and allow
                        and 0 < port < 65536
                        and pool.server_run_status(port) is False):
                    self.new_server(port, passwd, cfg)

        # Ports present last time but missing now are shut down.
        for row in last_rows:
            if row['port'] in cur_servers:
                continue
            logging.info('db stop server at port [%s] reason: port not exist' % (row['port']))
            pool.cb_del_server(row['port'])
            self.clear_cache(row['port'])
            if row['port'] in self.port_uid_table:
                del self.port_uid_table[row['port']]

        if new_servers:
            from shadowsocks import eventloop
            # Wait before restarting the servers that were just stopped.
            self.event.wait(eventloop.TIMEOUT_PRECISION + eventloop.TIMEOUT_PRECISION / 2)
            for port in new_servers.keys():
                passwd, cfg = new_servers[port]
                self.new_server(port, passwd, cfg)

        logging.debug('db allow users %s \nmu_servers %s' % (allow_users, mu_servers))
        for port in mu_servers:
            pool.update_mu_users(port, allow_users)

        self.mu_ports = mu_servers

    def clear_cache(self, port):
        """Forget all traffic bookkeeping kept for ``port``."""
        # force_update_transfer is a set: `del s[port]` would raise TypeError.
        self.force_update_transfer.discard(port)
        if port in self.last_get_transfer:
            del self.last_get_transfer[port]
        if port in self.last_update_transfer:
            del self.last_update_transfer[port]

    def new_server(self, port, passwd, cfg):
        """Log and start a server on ``port`` with the per-user ``cfg``."""
        pool = ServerPool.get_instance()
        protocol = cfg.get('protocol', pool.config.get('protocol', 'origin'))
        method = cfg.get('method', pool.config.get('method', 'None'))
        obfs_name = cfg.get('obfs', pool.config.get('obfs', 'plain'))
        logging.info('db start server at port [%s] pass [%s] protocol [%s] method [%s] obfs [%s]' % (port, passwd, protocol, method, obfs_name))
        pool.new_server(port, cfg)

    def cmp(self, val1, val2):
        """Compare two config values, treating bytes and text as equal."""
        if isinstance(val1, bytes):
            val1 = common.to_str(val1)
        if isinstance(val2, bytes):
            val2 = common.to_str(val2)
        return val1 == val2

    @staticmethod
    def del_servers():
        """Stop every running server in the IPv4 and IPv6 pools."""
        pool = ServerPool.get_instance()
        # Keys are snapshotted because stopping a server may mutate the pool.
        for port in list(pool.tcp_servers_pool.keys()):
            if pool.server_is_run(port) > 0:
                pool.cb_del_server(port)
        for port in list(pool.tcp_ipv6_servers_pool.keys()):
            if pool.server_is_run(port) > 0:
                pool.cb_del_server(port)

    @staticmethod
    def thread_db(obj):
        """Main loop of the sync thread.

        :param obj: the transfer class to instantiate (called with no
            arguments).

        Loops until ``thread_db_stop`` sets the event, the ServerPool thread
        dies, or a KeyboardInterrupt arrives; then stops all servers.
        """
        global db_instance
        timeout = 60
        socket.setdefaulttimeout(timeout)
        last_rows = []
        db_instance = obj()
        ServerPool.get_instance()
        shell.log_shadowsocks_version()

        try:
            import resource
            logging.info('current process RLIMIT_NOFILE resource: soft %d hard %d' % resource.getrlimit(resource.RLIMIT_NOFILE))
        except Exception:
            # Purely informational; deliberately ignored if it fails.
            pass

        try:
            while True:
                load_config()
                db_instance.load_cfg()
                try:
                    db_instance.push_db_all_user()
                    rows = db_instance.pull_db_all_user()
                    if rows:
                        db_instance.pull_ok = True
                        config = shell.get_config(False)
                        # Statically configured ports are appended as
                        # always-enabled pseudo users.
                        for port, val in config['additional_ports'].items():
                            val['port'] = int(port)
                            val['enable'] = 1
                            val['transfer_enable'] = 1024 ** 7
                            val['u'] = 0
                            val['d'] = 0
                            if "password" in val:
                                val["passwd"] = val["password"]
                            rows.append(val)
                    db_instance.del_server_out_of_bound_safe(last_rows, rows)
                    last_rows = rows
                except Exception:
                    # One failed round must not kill the sync thread.
                    logging.error(traceback.format_exc())
                if (db_instance.event.wait(get_config().UPDATE_TIME)
                        or not ServerPool.get_instance().thread.is_alive()):
                    break
        except KeyboardInterrupt:
            pass

        db_instance.del_servers()
        ServerPool.get_instance().stop()
        db_instance = None

    @staticmethod
    def thread_db_stop():
        """Ask the sync thread to exit; a no-op when none is running."""
        global db_instance
        if db_instance is not None:
            db_instance.event.set()


class DbTransfer(TransferBase):
    """MySQL-backed store using the plain ``user`` table."""

    def __init__(self):
        super(DbTransfer, self).__init__()
        # port -> number of consecutive rounds its update was skipped.
        self.user_pass = {}
        self.cfg = {
            "host": "127.0.0.1",
            "port": 3306,
            "user": "ss",
            "password": "pass",
            "db": "shadowsocks",
            "node_id": 0,
            "transfer_mul": 1.0,
            "ssl_enable": 0,
            "ssl_ca": "",
            "ssl_cert": "",
            "ssl_key": "",
        }
        self.load_cfg()

    def load_cfg(self):
        """Merge the JSON file named by ``MYSQL_CONFIG`` into ``self.cfg``."""
        config_path = get_config().MYSQL_CONFIG
        with open(config_path, 'rb') as f:
            cfg = json.loads(f.read().decode('utf8'))
        if cfg:
            self.cfg.update(cfg)

    def _connect(self):
        """Open a cymysql connection from ``self.cfg`` (SSL if enabled)."""
        import cymysql
        params = {
            'host': self.cfg["host"],
            'port': self.cfg["port"],
            'user': self.cfg["user"],
            'passwd': self.cfg["password"],
            'db': self.cfg["db"],
            'charset': 'utf8',
        }
        if self.cfg["ssl_enable"] == 1:
            params['ssl'] = {
                'ca': self.cfg["ssl_ca"],
                'cert': self.cfg["ssl_cert"],
                'key': self.cfg["ssl_key"],
            }
        return cymysql.connect(**params)

    def _update_threshold(self, port):
        """Minimum delta (bytes) worth storing; shrinks by 64 KiB per skip."""
        return 1024 * (2048 - self.user_pass.get(port, 0) * 64)

    def update_all_user(self, dt_transfer):
        """Add the traffic deltas to the ``u``/``d`` columns of ``user``.

        :param dt_transfer: ``{port: [upload_delta, download_delta]}``.
        :return: the entries that were stored; empty if the query failed.
        """
        update_transfer = {}

        query_head = 'UPDATE user'
        query_sub_when = ''
        query_sub_when2 = ''
        ports_in = []
        last_time = time.time()

        for port in dt_transfer.keys():
            transfer = dt_transfer[port]
            # Small deltas are postponed unless the port is being forced.
            if (transfer[0] + transfer[1] < self._update_threshold(port)
                    and port not in self.force_update_transfer):
                self.user_pass[port] = self.user_pass.get(port, 0) + 1
                continue
            if port in self.user_pass:
                del self.user_pass[port]

            query_sub_when += ' WHEN %s THEN u+%s' % (port, int(transfer[0] * self.cfg["transfer_mul"]))
            query_sub_when2 += ' WHEN %s THEN d+%s' % (port, int(transfer[1] * self.cfg["transfer_mul"]))
            update_transfer[port] = transfer
            ports_in.append('%s' % port)

        if query_sub_when == '':
            return update_transfer

        query_sql = query_head + ' SET u = CASE port' + query_sub_when + \
            ' END, d = CASE port' + query_sub_when2 + \
            ' END, t = ' + str(int(last_time)) + \
            ' WHERE port IN (%s)' % ','.join(ports_in)

        conn = self._connect()
        try:
            cur = conn.cursor()
            try:
                cur.execute(query_sql)
            except Exception as e:
                logging.error(e)
                update_transfer = {}
            cur.close()
            conn.commit()
        except Exception as e:
            logging.error(e)
            update_transfer = {}
        finally:
            conn.close()
        return update_transfer

    def pull_db_all_user(self):
        """Return every user row from the database as a list of dicts."""
        conn = self._connect()
        try:
            rows = self.pull_db_users(conn)
        finally:
            conn.close()

        if not rows:
            logging.warning('no user in db')
        return rows

    def pull_db_users(self, conn):
        """Select ``key_list`` (as adjusted by switchrule) from ``user``."""
        try:
            rule = importloader.load('switchrule')
            keys = rule.getKeys(self.key_list)
        except Exception:
            keys = self.key_list

        cur = conn.cursor()
        try:
            cur.execute("SELECT " + ','.join(keys) + " FROM user")
            rows = [dict(zip(keys, r)) for r in cur.fetchall()]
        finally:
            cur.close()
        return rows


class Dbv3Transfer(DbTransfer):
    """MySQL store that additionally writes traffic and node-state logs.

    Node-state logging is disabled when ``API_INTERFACE`` is
    ``'legendsockssr'``.
    """

    def __init__(self):
        super(Dbv3Transfer, self).__init__()
        self.update_node_state = get_config().API_INTERFACE != 'legendsockssr'
        if self.update_node_state:
            self.key_list += ['id']
        self.key_list += ['method']
        if self.update_node_state:
            self.ss_node_info_name = 'ss_node_info_log'
            if get_config().API_INTERFACE == 'sspanelv3ssr':
                self.key_list += ['obfs', 'protocol']
            if get_config().API_INTERFACE == 'glzjinmod':
                self.key_list += ['obfs', 'protocol']
                self.ss_node_info_name = 'ss_node_info'
        else:
            self.key_list += ['obfs', 'protocol']
        self.start_time = time.time()

    def update_all_user(self, dt_transfer):
        """Store traffic deltas and, if enabled, per-user and node logs.

        :param dt_transfer: ``{port: [upload_delta, download_delta]}``.
        :return: the entries that were included in the UPDATE statement.
        """
        update_transfer = {}

        query_head = 'UPDATE user'
        query_sub_when = ''
        query_sub_when2 = ''
        ports_in = []
        last_time = time.time()

        alive_user_count = len(self.onlineuser_cache)

        conn = self._connect()
        try:
            conn.autocommit(True)

            for port in dt_transfer.keys():
                transfer = dt_transfer[port]

                # Forced ports must not be postponed: push_db_all_user drops
                # their bookkeeping afterwards, so a skipped delta is lost.
                if (transfer[0] + transfer[1] < self._update_threshold(port)
                        and port not in self.force_update_transfer):
                    self.user_pass[port] = self.user_pass.get(port, 0) + 1
                    continue
                if port in self.user_pass:
                    del self.user_pass[port]

                query_sub_when += ' WHEN %s THEN u+%s' % (port, int(transfer[0] * self.cfg["transfer_mul"]))
                query_sub_when2 += ' WHEN %s THEN d+%s' % (port, int(transfer[1] * self.cfg["transfer_mul"]))
                update_transfer[port] = transfer

                if self.update_node_state:
                    cur = conn.cursor()
                    try:
                        if port in self.port_uid_table:
                            cur.execute("INSERT INTO `user_traffic_log` (`id`, `user_id`, `u`, `d`, `node_id`, `rate`, `traffic`, `log_time`) VALUES (NULL, '" +
                                        str(self.port_uid_table[port]) + "', '" + str(transfer[0]) + "', '" + str(transfer[1]) + "', '" +
                                        str(self.cfg["node_id"]) + "', '" + str(self.cfg["transfer_mul"]) + "', '" +
                                        self.traffic_format((transfer[0] + transfer[1]) * self.cfg["transfer_mul"]) + "', unix_timestamp()); ")
                    except Exception:
                        logging.warning('no `user_traffic_log` in db')
                    cur.close()

                ports_in.append('%s' % port)

            if query_sub_when != '':
                query_sql = query_head + ' SET u = CASE port' + query_sub_when + \
                    ' END, d = CASE port' + query_sub_when2 + \
                    ' END, t = ' + str(int(last_time)) + \
                    ' WHERE port IN (%s)' % ','.join(ports_in)
                cur = conn.cursor()
                try:
                    cur.execute(query_sql)
                except Exception as e:
                    logging.error(e)
                cur.close()

            if self.update_node_state:
                try:
                    cur = conn.cursor()
                    try:
                        cur.execute("INSERT INTO `ss_node_online_log` (`id`, `node_id`, `online_user`, `log_time`) VALUES (NULL, '" +
                                    str(self.cfg["node_id"]) + "', '" + str(alive_user_count) + "', unix_timestamp()); ")
                    except Exception as e:
                        logging.error(e)
                    cur.close()

                    cur = conn.cursor()
                    try:
                        cur.execute("INSERT INTO `" + self.ss_node_info_name + "` (`id`, `node_id`, `uptime`, `load`, `log_time`) VALUES (NULL, '" +
                                    str(self.cfg["node_id"]) + "', '" + str(self.uptime()) + "', '" +
                                    str(self.load()) + "', unix_timestamp()); ")
                    except Exception as e:
                        logging.error(e)
                    cur.close()
                except Exception:
                    logging.warning('no `ss_node_online_log` or `%s` in db' % self.ss_node_info_name)
        finally:
            conn.close()
        return update_transfer

    def pull_db_users(self, conn):
        """Refresh ``transfer_mul`` from ``ss_node`` and select all users.

        Returns an empty list if node info is enabled but the node row
        cannot be read.
        """
        try:
            rule = importloader.load('switchrule')
            keys = rule.getKeys(self.key_list)
        except Exception:
            keys = self.key_list

        if self.update_node_state:
            node_info_keys = ['traffic_rate']
            cur = conn.cursor()
            try:
                cur.execute("SELECT " + ','.join(node_info_keys) + " FROM ss_node where `id`='" + str(self.cfg["node_id"]) + "'")
                nodeinfo = cur.fetchone()
            except Exception as e:
                logging.error(e)
                nodeinfo = None

            if nodeinfo is None:
                cur.close()
                conn.commit()
                logging.warning('None result when select node info from ss_node in db, maybe you set the incorrect node id')
                return []
            cur.close()

            node_info_dict = dict(zip(node_info_keys, nodeinfo))
            self.cfg['transfer_mul'] = float(node_info_dict['traffic_rate'])

        cur = conn.cursor()
        rows = []
        try:
            cur.execute("SELECT " + ','.join(keys) + " FROM user")
            for r in cur.fetchall():
                rows.append(dict(zip(keys, r)))
        except Exception as e:
            # Best effort: return whatever was collected before the failure.
            logging.error(e)
        cur.close()
        return rows

    def load(self):
        """Return the first three fields of /proc/loadavg as one line."""
        pipe = os.popen("cat /proc/loadavg | awk '{ print $1\" \"$2\" \"$3 }'")
        try:
            return pipe.readlines()[0]
        finally:
            pipe.close()

    def uptime(self):
        """Return seconds elapsed since this object was created."""
        return time.time() - self.start_time

    def traffic_format(self, traffic):
        """Format a byte count as ``B`` / ``KB`` / ``MB`` text."""
        if traffic < 1024 * 8:
            return str(int(traffic)) + "B"

        if traffic < 1024 * 1024 * 2:
            return str(round((traffic / 1024.0), 2)) + "KB"

        return str(round((traffic / 1048576.0), 2)) + "MB"


class MuJsonTransfer(TransferBase):
    """Store backed by the JSON file named by ``MUDB_FILE``."""

    def __init__(self):
        super(MuJsonTransfer, self).__init__()

    def update_all_user(self, dt_transfer):
        """Add the traffic deltas to the matching rows and rewrite the file.

        :param dt_transfer: ``{port: [upload_delta, download_delta]}``.
        :return: ``dt_transfer`` unchanged (every delta counts as stored).
        """
        config_path = get_config().MUDB_FILE
        with open(config_path, 'rb') as f:
            rows = json.loads(f.read().decode('utf8'))

        for row in rows:
            if "port" in row:
                port = row["port"]
                if port in dt_transfer:
                    row["u"] += dt_transfer[port][0]
                    row["d"] += dt_transfer[port][1]

        if rows:
            output = json.dumps(rows, sort_keys=True, indent=4, separators=(',', ': '))
            with open(config_path, 'r+') as f:
                f.write(output)
                # Drop leftover bytes if the new content is shorter.
                f.truncate()

        return dt_transfer

    def pull_db_all_user(self):
        """Load user rows from the JSON file.

        ``forbidden_ip`` / ``forbidden_port`` values are converted with
        ``common.IPNetwork`` / ``common.PortRange``; a value that fails to
        convert is logged and left as read.
        """
        config_path = get_config().MUDB_FILE
        with open(config_path, 'rb') as f:
            rows = json.loads(f.read().decode('utf8'))

        for row in rows:
            try:
                if 'forbidden_ip' in row:
                    row['forbidden_ip'] = common.IPNetwork(row['forbidden_ip'])
            except Exception as e:
                logging.error(e)
            try:
                if 'forbidden_port' in row:
                    row['forbidden_port'] = common.PortRange(row['forbidden_port'])
            except Exception as e:
                logging.error(e)

        if not rows:
            logging.warning('no user in json file')
        return rows