diff --git a/db_transfer.py b/db_transfer.py index 8126d0a..fc1c80d 100644 --- a/db_transfer.py +++ b/db_transfer.py @@ -7,142 +7,145 @@ import time import sys from server_pool import ServerPool import Config +import traceback class DbTransfer(object): - instance = None + instance = None + + def __init__(self): + self.last_get_transfer = {} + + @staticmethod + def get_instance(): + if DbTransfer.instance is None: + DbTransfer.instance = DbTransfer() + return DbTransfer.instance + + def push_db_all_user(self): + #更新用户流量到数据库 + last_transfer = self.last_get_transfer + curr_transfer = ServerPool.get_instance().get_servers_transfer() + #上次和本次的增量 + dt_transfer = {} + for id in curr_transfer.keys(): + if id in last_transfer: + if last_transfer[id][0] == curr_transfer[id][0] and last_transfer[id][1] == curr_transfer[id][1]: + continue + elif curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0: + continue + elif last_transfer[id][0] <= curr_transfer[id][0] and \ + last_transfer[id][1] <= curr_transfer[id][1]: + dt_transfer[id] = [curr_transfer[id][0] - last_transfer[id][0], + curr_transfer[id][1] - last_transfer[id][1]] + else: + dt_transfer[id] = [curr_transfer[id][0], curr_transfer[id][1]] + else: + if curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0: + continue + dt_transfer[id] = [curr_transfer[id][0], curr_transfer[id][1]] + + self.last_get_transfer = curr_transfer + query_head = 'UPDATE user' + query_sub_when = '' + query_sub_when2 = '' + query_sub_in = None + last_time = time.time() + for id in dt_transfer.keys(): + query_sub_when += ' WHEN %s THEN u+%s' % (id, dt_transfer[id][0]) + query_sub_when2 += ' WHEN %s THEN d+%s' % (id, dt_transfer[id][1]) + if query_sub_in is not None: + query_sub_in += ',%s' % id + else: + query_sub_in = '%s' % id + if query_sub_when == '': + return + query_sql = query_head + ' SET u = CASE port' + query_sub_when + \ + ' END, d = CASE port' + query_sub_when2 + \ + ' END, t = ' + str(int(last_time)) + \ + ' WHERE port IN (%s)' % query_sub_in + #print query_sql + conn = cymysql.connect(host=Config.MYSQL_HOST, port=Config.MYSQL_PORT, user=Config.MYSQL_USER, + passwd=Config.MYSQL_PASS, db=Config.MYSQL_DB, charset='utf8') + cur = conn.cursor() + cur.execute(query_sql) + cur.close() + conn.commit() + conn.close() + + @staticmethod + def pull_db_all_user(): + #数据库所有用户信息 + keys = ['port', 'u', 'd', 'transfer_enable', 'passwd', 'switch', 'enable', 'plan' ] + conn = cymysql.connect(host=Config.MYSQL_HOST, port=Config.MYSQL_PORT, user=Config.MYSQL_USER, + passwd=Config.MYSQL_PASS, db=Config.MYSQL_DB, charset='utf8') + cur = conn.cursor() + cur.execute("SELECT " + ','.join(keys) + " FROM user") + rows = [] + for r in cur.fetchall(): + d = {} + for column in xrange(len(keys)): + d[keys[column]] = r[column] + rows.append(d) + cur.close() + conn.close() + return rows + + @staticmethod + def del_server_out_of_bound_safe(last_rows, rows): + #停止超流量的服务 + #启动没超流量的服务 + #需要动态载入switchrule,以便实时修改规则 + cur_servers = {} + for row in rows: + try: + import switchrule + allow = switchrule.isTurnOn(row) and row['enable'] == 1 and row['u'] + row['d'] < row['transfer_enable'] + except Exception, e: + allow = False + + port = row['port'] + passwd = row['passwd'] + cur_servers[port] = passwd + + if ServerPool.get_instance().server_is_run(port) > 0: + if not allow: + logging.info('db stop server at port [%s]' % (port,)) + ServerPool.get_instance().del_server(port) + elif (port in ServerPool.get_instance().tcp_servers_pool and ServerPool.get_instance().tcp_servers_pool[port]._config['password'] != passwd) \ + or (port in ServerPool.get_instance().tcp_ipv6_servers_pool and ServerPool.get_instance().tcp_ipv6_servers_pool[port]._config['password'] != passwd): + #password changed + logging.info('db stop server at port [%s] reason: password changed' % (port,)) + ServerPool.get_instance().del_server(port) + + if allow and ServerPool.get_instance().server_is_run(port) == 0: + logging.info('db start server at port [%s] pass [%s]' % (port, passwd)) + ServerPool.get_instance().new_server(port, passwd) + + for row in last_rows: + if row['port'] in cur_servers: + pass + else: + logging.info('db stop server at port [%s] reason: port not exist' % (row['port'])) + ServerPool.get_instance().del_server(row['port']) + + @staticmethod + def thread_db(): + import socket + import time + timeout = 60 + socket.setdefaulttimeout(timeout) + last_rows = [] + while True: + try: + DbTransfer.get_instance().push_db_all_user() + rows = DbTransfer.get_instance().pull_db_all_user() + DbTransfer.del_server_out_of_bound_safe(last_rows, rows) + last_rows = rows + except Exception as e: + trace = traceback.format_exc() + logging.error(trace) + #logging.warn('db thread except:%s' % e) + finally: + time.sleep(15) - def __init__(self): - self.last_get_transfer = {} - - @staticmethod - def get_instance(): - if DbTransfer.instance is None: - DbTransfer.instance = DbTransfer() - return DbTransfer.instance - - def push_db_all_user(self): - #更新用户流量到数据库 - last_transfer = self.last_get_transfer - curr_transfer = ServerPool.get_instance().get_servers_transfer() - #上次和本次的增量 - dt_transfer = {} - for id in curr_transfer.keys(): - if id in last_transfer: - if last_transfer[id][0] == curr_transfer[id][0] and last_transfer[id][1] == curr_transfer[id][1]: - continue - elif curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0: - continue - elif last_transfer[id][0] <= curr_transfer[id][0] and \ - last_transfer[id][1] <= curr_transfer[id][1]: - dt_transfer[id] = [curr_transfer[id][0] - last_transfer[id][0], - curr_transfer[id][1] - last_transfer[id][1]] - else: - dt_transfer[id] = [curr_transfer[id][0], curr_transfer[id][1]] - else: - if curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0: - continue - dt_transfer[id] = [curr_transfer[id][0], curr_transfer[id][1]] - - self.last_get_transfer = curr_transfer - query_head = 'UPDATE user' - query_sub_when = '' - query_sub_when2 = '' - query_sub_in = None - last_time = time.time() - for id in dt_transfer.keys(): - query_sub_when += ' WHEN %s THEN u+%s' % (id, dt_transfer[id][0]) - query_sub_when2 += ' WHEN %s THEN d+%s' % (id, dt_transfer[id][1]) - if query_sub_in is not None: - query_sub_in += ',%s' % id - else: - query_sub_in = '%s' % id - if query_sub_when == '': - return - query_sql = query_head + ' SET u = CASE port' + query_sub_when + \ - ' END, d = CASE port' + query_sub_when2 + \ - ' END, t = ' + str(int(last_time)) + \ - ' WHERE port IN (%s)' % query_sub_in - #print query_sql - conn = cymysql.connect(host=Config.MYSQL_HOST, port=Config.MYSQL_PORT, user=Config.MYSQL_USER, - passwd=Config.MYSQL_PASS, db=Config.MYSQL_DB, charset='utf8') - cur = conn.cursor() - cur.execute(query_sql) - cur.close() - conn.commit() - conn.close() - - @staticmethod - def pull_db_all_user(): - #数据库所有用户信息 - conn = cymysql.connect(host=Config.MYSQL_HOST, port=Config.MYSQL_PORT, user=Config.MYSQL_USER, - passwd=Config.MYSQL_PASS, db=Config.MYSQL_DB, charset='utf8') - cur = conn.cursor() - cur.execute("SELECT port, u, d, transfer_enable, passwd, switch, enable, plan FROM user") - rows = [] - for r in cur.fetchall(): - rows.append(list(r)) - cur.close() - conn.close() - return rows - - @staticmethod - def del_server_out_of_bound_safe(last_rows, rows): - #停止超流量的服务 - #启动没超流量的服务 - #需要动态载入switchrule,以便实时修改规则 - cur_servers = {} - for row in rows: - try: - import switchrule - allow = switchrule.isTurnOn(row[7], row[5]) and row[6] == 1 and row[1] + row[2] < row[3] - except Exception, e: - allow = False - - cur_servers[row[0]] = row[4] - - if ServerPool.get_instance().server_is_run(row[0]) > 0: - if not allow: - logging.info('db stop server at port [%s]' % (row[0])) - ServerPool.get_instance().del_server(row[0]) - elif (row[0] in ServerPool.get_instance().tcp_servers_pool and ServerPool.get_instance().tcp_servers_pool[row[0]]._config['password'] != row[4]) \ - or (row[0] in ServerPool.get_instance().tcp_ipv6_servers_pool and ServerPool.get_instance().tcp_ipv6_servers_pool[row[0]]._config['password'] != row[4]): - #password changed - logging.info('db stop server at port [%s] reason: password changed' % (row[0])) - ServerPool.get_instance().del_server(row[0]) - elif ServerPool.get_instance().server_run_status(row[0]) is False: - if allow: - logging.info('db start server at port [%s] pass [%s]' % (row[0], row[4])) - ServerPool.get_instance().new_server(row[0], row[4]) - - for row in last_rows: - if row[0] in cur_servers: - if row[4] == cur_servers[row[0]]: - pass - else: - logging.info('db stop server at port [%s] reason: port not exist' % (row[0])) - ServerPool.get_instance().del_server(row[0]) - - @staticmethod - def thread_db(): - import socket - import time - timeout = 60 - socket.setdefaulttimeout(timeout) - last_rows = [] - while True: - #logging.warn('db loop') - - try: - DbTransfer.get_instance().push_db_all_user() - rows = DbTransfer.get_instance().pull_db_all_user() - DbTransfer.del_server_out_of_bound_safe(last_rows, rows) - last_rows = rows - except Exception as e: - logging.warn('db thread except:%s' % e) - finally: - time.sleep(15) - - -#SQLData.pull_db_all_user() -#print DbTransfer.get_instance().test() diff --git a/server.py b/server.py index ad1e558..de03e47 100644 --- a/server.py +++ b/server.py @@ -1,6 +1,4 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - +#!/usr/bin/python import time import sys import thread @@ -11,12 +9,12 @@ import server_pool import db_transfer #def test(): -# thread.start_new_thread(DbTransfer.thread_db, ()) -# Api.web_server() +# thread.start_new_thread(DbTransfer.thread_db, ()) +# Api.web_server() if __name__ == '__main__': - #server_pool.ServerPool.get_instance() - #server_pool.ServerPool.get_instance().new_server(2333, '2333') - thread.start_new_thread(db_transfer.DbTransfer.thread_db, ()) - while True: - time.sleep(99999) + #server_pool.ServerPool.get_instance() + #server_pool.ServerPool.get_instance().new_server(2333, '2333') + thread.start_new_thread(db_transfer.DbTransfer.thread_db, ()) + while True: + time.sleep(99999) diff --git a/server_pool.py b/server_pool.py index c123271..857afdf 100644 --- a/server_pool.py +++ b/server_pool.py @@ -24,7 +24,7 @@ import os import logging import time -from shadowsocks import utils +from shadowsocks import shell from shadowsocks import eventloop from shadowsocks import tcprelay from shadowsocks import udprelay @@ -38,174 +38,172 @@ from socket import * class ServerPool(object): - instance = None - - def __init__(self): - utils.check_python() - self.config = utils.get_config(False) - utils.print_shadowsocks() - self.dns_resolver = asyncdns.DNSResolver() - self.mgr = asyncmgr.ServerMgr() - self.udp_on = True ### UDP switch ===================================== - - self.tcp_servers_pool = {} - self.tcp_ipv6_servers_pool = {} - self.udp_servers_pool = {} - self.udp_ipv6_servers_pool = {} - - self.loop = eventloop.EventLoop() - thread.start_new_thread(ServerPool._loop, (self.loop, self.dns_resolver, self.mgr)) - - @staticmethod - def get_instance(): - if ServerPool.instance is None: - ServerPool.instance = ServerPool() - return ServerPool.instance - - @staticmethod - def _loop(loop, dns_resolver, mgr): - try: - mgr.add_to_loop(loop) - dns_resolver.add_to_loop(loop) - loop.run() - except (KeyboardInterrupt, IOError, OSError) as e: - logging.error(e) - import traceback - traceback.print_exc() - os.exit(0) - - def server_is_run(self, port): - port = int(port) - ret = 0 - if port in self.tcp_servers_pool: - ret = 1 - if port in self.tcp_ipv6_servers_pool: - ret |= 2 - return ret - - def server_run_status(self, port): - if 'server' in self.config: - if port not in self.tcp_servers_pool: - return False - if 'server_ipv6' in self.config: - if port not in self.tcp_ipv6_servers_pool: - return False - return True - - def new_server(self, port, password): - ret = True - port = int(port) - - if 'server_ipv6' in self.config: - if port in self.tcp_ipv6_servers_pool: - logging.info("server already at %s:%d" % (self.config['server_ipv6'], port)) - return 'this port server is already running' - else: - a_config = self.config.copy() - a_config['server'] = a_config['server_ipv6'] - a_config['server_port'] = port - a_config['password'] = password - try: - logging.info("starting server at %s:%d" % (a_config['server'], port)) - tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False) - tcp_server.add_to_loop(self.loop) - self.tcp_ipv6_servers_pool.update({port: tcp_server}) - if self.udp_on: - udp_server = udprelay.UDPRelay(a_config, self.dns_resolver, False) - udp_server.add_to_loop(self.loop) - self.udp_ipv6_servers_pool.update({port: udp_server}) - except Exception, e: - logging.warn("IPV6 exception") - logging.warn(e) - - if 'server' in self.config: - if port in self.tcp_servers_pool: - logging.info("server already at %s:%d" % (self.config['server'], port)) - return 'this port server is already running' - else: - a_config = self.config.copy() - a_config['server_port'] = port - a_config['password'] = password - try: - logging.info("starting server at %s:%d" % (a_config['server'], port)) - tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False) - tcp_server.add_to_loop(self.loop) - self.tcp_servers_pool.update({port: tcp_server}) - if self.udp_on: - udp_server = udprelay.UDPRelay(a_config, self.dns_resolver, False) - udp_server.add_to_loop(self.loop) - self.udp_servers_pool.update({port: udp_server}) - except Exception, e: - logging.warn("IPV4 exception") - logging.warn(e) - - return True - - def del_server(self, port): - port = int(port) - logging.info("del server at %d" % port) - try: - udpsock = socket(AF_INET, SOCK_DGRAM) - udpsock.sendto('%s:%s:0:0' % (Config.MANAGE_PASS, port), (Config.MANAGE_BIND_IP, Config.MANAGE_PORT)) - udpsock.close() - except Exception, e: - logging.warn(e) - return True - - def cb_del_server(self, port): - port = int(port) - - if port not in self.tcp_servers_pool: - logging.info("stopped server at %s:%d already stop" % (self.config['server'], port)) - else: - logging.info("stopped server at %s:%d" % (self.config['server'], port)) - try: - self.tcp_servers_pool[port].destroy() - del self.tcp_servers_pool[port] - except Exception, e: - logging.warn(e) - if self.udp_on: - try: - self.udp_servers_pool[port].destroy() - del self.udp_servers_pool[port] - except Exception, e: - logging.warn(e) - - if 'server_ipv6' in self.config: - if port not in self.tcp_ipv6_servers_pool: - logging.info("stopped server at %s:%d already stop" % (self.config['server_ipv6'], port)) - else: - logging.info("stopped server at %s:%d" % (self.config['server_ipv6'], port)) - try: - self.tcp_ipv6_servers_pool[port].destroy() - del self.tcp_ipv6_servers_pool[port] - except Exception, e: - logging.warn(e) - if self.udp_on: - try: - self.udp_ipv6_servers_pool[port].destroy() - del self.udp_ipv6_servers_pool[port] - except Exception, e: - logging.warn(e) - - return True - - def get_server_transfer(self, port): - port = int(port) - ret = [0, 0] - if port in self.tcp_servers_pool: - ret[0] = self.tcp_servers_pool[port].server_transfer_ul - ret[1] = self.tcp_servers_pool[port].server_transfer_dl - if port in self.tcp_ipv6_servers_pool: - ret[0] += self.tcp_ipv6_servers_pool[port].server_transfer_ul - ret[1] += self.tcp_ipv6_servers_pool[port].server_transfer_dl - return ret - - def get_servers_transfer(self): - servers = self.tcp_servers_pool.copy() - servers.update(self.tcp_ipv6_servers_pool) - ret = {} - for port in servers.keys(): - ret[port] = self.get_server_transfer(port) - return ret + instance = None + + def __init__(self): + shell.check_python() + self.config = shell.get_config(False) + shell.print_shadowsocks() + self.dns_resolver = asyncdns.DNSResolver() + self.mgr = asyncmgr.ServerMgr() + self.udp_on = True ### UDP switch ===================================== + + self.tcp_servers_pool = {} + self.tcp_ipv6_servers_pool = {} + self.udp_servers_pool = {} + self.udp_ipv6_servers_pool = {} + + self.loop = eventloop.EventLoop() + thread.start_new_thread(ServerPool._loop, (self.loop, self.dns_resolver, self.mgr)) + + @staticmethod + def get_instance(): + if ServerPool.instance is None: + ServerPool.instance = ServerPool() + return ServerPool.instance + + @staticmethod + def _loop(loop, dns_resolver, mgr): + try: + mgr.add_to_loop(loop) + dns_resolver.add_to_loop(loop) + loop.run() + except (KeyboardInterrupt, IOError, OSError) as e: + logging.error(e) + import traceback + traceback.print_exc() + os.exit(0) + + def server_is_run(self, port): + port = int(port) + ret = 0 + if port in self.tcp_servers_pool: + ret = 1 + if port in self.tcp_ipv6_servers_pool: + ret |= 2 + return ret + + def server_run_status(self, port): + if 'server' in self.config: + if port not in self.tcp_servers_pool: + return False + if 'server_ipv6' in self.config: + if port not in self.tcp_ipv6_servers_pool: + return False + return True + + def new_server(self, port, password): + ret = True + port = int(port) + + if 'server_ipv6' in self.config: + if port in self.tcp_ipv6_servers_pool: + logging.info("server already at %s:%d" % (self.config['server_ipv6'], port)) + return 'this port server is already running' + else: + a_config = self.config.copy() + a_config['server'] = a_config['server_ipv6'] + a_config['server_port'] = port + a_config['password'] = password + try: + logging.info("starting server at %s:%d" % (a_config['server'], port)) + tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False) + tcp_server.add_to_loop(self.loop) + self.tcp_ipv6_servers_pool.update({port: tcp_server}) + if self.udp_on: + udp_server = udprelay.UDPRelay(a_config, self.dns_resolver, False) + udp_server.add_to_loop(self.loop) + self.udp_ipv6_servers_pool.update({port: udp_server}) + except Exception, e: + logging.warn("IPV6 %s " % (e,)) + + if 'server' in self.config: + if port in self.tcp_servers_pool: + logging.info("server already at %s:%d" % (self.config['server'], port)) + return 'this port server is already running' + else: + a_config = self.config.copy() + a_config['server_port'] = port + a_config['password'] = password + try: + logging.info("starting server at %s:%d" % (a_config['server'], port)) + tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False) + tcp_server.add_to_loop(self.loop) + self.tcp_servers_pool.update({port: tcp_server}) + if self.udp_on: + udp_server = udprelay.UDPRelay(a_config, self.dns_resolver, False) + udp_server.add_to_loop(self.loop) + self.udp_servers_pool.update({port: udp_server}) + except Exception, e: + logging.warn("IPV4 %s " % (e,)) + + return True + + def del_server(self, port): + port = int(port) + logging.info("del server at %d" % port) + try: + udpsock = socket(AF_INET, SOCK_DGRAM) + udpsock.sendto('%s:%s:0:0' % (Config.MANAGE_PASS, port), (Config.MANAGE_BIND_IP, Config.MANAGE_PORT)) + udpsock.close() + except Exception, e: + logging.warn(e) + return True + + def cb_del_server(self, port): + port = int(port) + + if port not in self.tcp_servers_pool: + logging.info("stopped server at %s:%d already stop" % (self.config['server'], port)) + else: + logging.info("stopped server at %s:%d" % (self.config['server'], port)) + try: + self.tcp_servers_pool[port].destroy() + del self.tcp_servers_pool[port] + except Exception, e: + logging.warn(e) + if self.udp_on: + try: + self.udp_servers_pool[port].destroy() + del self.udp_servers_pool[port] + except Exception, e: + logging.warn(e) + + if 'server_ipv6' in self.config: + if port not in self.tcp_ipv6_servers_pool: + logging.info("stopped server at %s:%d already stop" % (self.config['server_ipv6'], port)) + else: + logging.info("stopped server at %s:%d" % (self.config['server_ipv6'], port)) + try: + self.tcp_ipv6_servers_pool[port].destroy() + del self.tcp_ipv6_servers_pool[port] + except Exception, e: + logging.warn(e) + if self.udp_on: + try: + self.udp_ipv6_servers_pool[port].destroy() + del self.udp_ipv6_servers_pool[port] + except Exception, e: + logging.warn(e) + + return True + + def get_server_transfer(self, port): + port = int(port) + ret = [0, 0] + if port in self.tcp_servers_pool: + ret[0] = self.tcp_servers_pool[port].server_transfer_ul + ret[1] = self.tcp_servers_pool[port].server_transfer_dl + if port in self.tcp_ipv6_servers_pool: + ret[0] += self.tcp_ipv6_servers_pool[port].server_transfer_ul + ret[1] += self.tcp_ipv6_servers_pool[port].server_transfer_dl + return ret + + def get_servers_transfer(self): + servers = self.tcp_servers_pool.copy() + servers.update(self.tcp_ipv6_servers_pool) + ret = {} + for port in servers.keys(): + ret[port] = self.get_server_transfer(port) + return ret diff --git a/switchrule.py b/switchrule.py index ce6b76e..2a0b324 100644 --- a/switchrule.py +++ b/switchrule.py @@ -1,2 +1,3 @@ -def isTurnOn(plan, switch): +def isTurnOn(row): return True +