Browse Source

use db key

dev
breakwa11 10 years ago
parent
commit
7c2fe9fd56
  1. 273
      db_transfer.py
  2. 18
      server.py
  3. 340
      server_pool.py
  4. 3
      switchrule.py

273
db_transfer.py

@ -7,142 +7,145 @@ import time
import sys import sys
from server_pool import ServerPool from server_pool import ServerPool
import Config import Config
import traceback
class DbTransfer(object): class DbTransfer(object):
instance = None instance = None
def __init__(self):
self.last_get_transfer = {}
@staticmethod
def get_instance():
if DbTransfer.instance is None:
DbTransfer.instance = DbTransfer()
return DbTransfer.instance
def push_db_all_user(self):
#更新用户流量到数据库
last_transfer = self.last_get_transfer
curr_transfer = ServerPool.get_instance().get_servers_transfer()
#上次和本次的增量
dt_transfer = {}
for id in curr_transfer.keys():
if id in last_transfer:
if last_transfer[id][0] == curr_transfer[id][0] and last_transfer[id][1] == curr_transfer[id][1]:
continue
elif curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0:
continue
elif last_transfer[id][0] <= curr_transfer[id][0] and \
last_transfer[id][1] <= curr_transfer[id][1]:
dt_transfer[id] = [curr_transfer[id][0] - last_transfer[id][0],
curr_transfer[id][1] - last_transfer[id][1]]
else:
dt_transfer[id] = [curr_transfer[id][0], curr_transfer[id][1]]
else:
if curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0:
continue
dt_transfer[id] = [curr_transfer[id][0], curr_transfer[id][1]]
self.last_get_transfer = curr_transfer
query_head = 'UPDATE user'
query_sub_when = ''
query_sub_when2 = ''
query_sub_in = None
last_time = time.time()
for id in dt_transfer.keys():
query_sub_when += ' WHEN %s THEN u+%s' % (id, dt_transfer[id][0])
query_sub_when2 += ' WHEN %s THEN d+%s' % (id, dt_transfer[id][1])
if query_sub_in is not None:
query_sub_in += ',%s' % id
else:
query_sub_in = '%s' % id
if query_sub_when == '':
return
query_sql = query_head + ' SET u = CASE port' + query_sub_when + \
' END, d = CASE port' + query_sub_when2 + \
' END, t = ' + str(int(last_time)) + \
' WHERE port IN (%s)' % query_sub_in
#print query_sql
conn = cymysql.connect(host=Config.MYSQL_HOST, port=Config.MYSQL_PORT, user=Config.MYSQL_USER,
passwd=Config.MYSQL_PASS, db=Config.MYSQL_DB, charset='utf8')
cur = conn.cursor()
cur.execute(query_sql)
cur.close()
conn.commit()
conn.close()
@staticmethod
def pull_db_all_user():
#数据库所有用户信息
keys = ['port', 'u', 'd', 'transfer_enable', 'passwd', 'switch', 'enable', 'plan' ]
conn = cymysql.connect(host=Config.MYSQL_HOST, port=Config.MYSQL_PORT, user=Config.MYSQL_USER,
passwd=Config.MYSQL_PASS, db=Config.MYSQL_DB, charset='utf8')
cur = conn.cursor()
cur.execute("SELECT " + ','.join(keys) + " FROM user")
rows = []
for r in cur.fetchall():
d = {}
for column in xrange(len(keys)):
d[keys[column]] = r[column]
rows.append(d)
cur.close()
conn.close()
return rows
@staticmethod
def del_server_out_of_bound_safe(last_rows, rows):
#停止超流量的服务
#启动没超流量的服务
#需要动态载入switchrule,以便实时修改规则
cur_servers = {}
for row in rows:
try:
import switchrule
allow = switchrule.isTurnOn(row) and row['enable'] == 1 and row['u'] + row['d'] < row['transfer_enable']
except Exception, e:
allow = False
port = row['port']
passwd = row['passwd']
cur_servers[port] = passwd
if ServerPool.get_instance().server_is_run(port) > 0:
if not allow:
logging.info('db stop server at port [%s]' % (port,))
ServerPool.get_instance().del_server(port)
elif (port in ServerPool.get_instance().tcp_servers_pool and ServerPool.get_instance().tcp_servers_pool[port]._config['password'] != passwd) \
or (port in ServerPool.get_instance().tcp_ipv6_servers_pool and ServerPool.get_instance().tcp_ipv6_servers_pool[port]._config['password'] != passwd):
#password changed
logging.info('db stop server at port [%s] reason: password changed' % (port,))
ServerPool.get_instance().del_server(port)
if allow and ServerPool.get_instance().server_is_run(port) == 0:
logging.info('db start server at port [%s] pass [%s]' % (port, passwd))
ServerPool.get_instance().new_server(port, passwd)
for row in last_rows:
if row['port'] in cur_servers:
pass
else:
logging.info('db stop server at port [%s] reason: port not exist' % (row['port']))
ServerPool.get_instance().del_server(row['port'])
@staticmethod
def thread_db():
import socket
import time
timeout = 60
socket.setdefaulttimeout(timeout)
last_rows = []
while True:
try:
DbTransfer.get_instance().push_db_all_user()
rows = DbTransfer.get_instance().pull_db_all_user()
DbTransfer.del_server_out_of_bound_safe(last_rows, rows)
last_rows = rows
except Exception as e:
trace = traceback.format_exc()
logging.error(trace)
#logging.warn('db thread except:%s' % e)
finally:
time.sleep(15)
def __init__(self):
self.last_get_transfer = {}
@staticmethod
def get_instance():
if DbTransfer.instance is None:
DbTransfer.instance = DbTransfer()
return DbTransfer.instance
def push_db_all_user(self):
#更新用户流量到数据库
last_transfer = self.last_get_transfer
curr_transfer = ServerPool.get_instance().get_servers_transfer()
#上次和本次的增量
dt_transfer = {}
for id in curr_transfer.keys():
if id in last_transfer:
if last_transfer[id][0] == curr_transfer[id][0] and last_transfer[id][1] == curr_transfer[id][1]:
continue
elif curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0:
continue
elif last_transfer[id][0] <= curr_transfer[id][0] and \
last_transfer[id][1] <= curr_transfer[id][1]:
dt_transfer[id] = [curr_transfer[id][0] - last_transfer[id][0],
curr_transfer[id][1] - last_transfer[id][1]]
else:
dt_transfer[id] = [curr_transfer[id][0], curr_transfer[id][1]]
else:
if curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0:
continue
dt_transfer[id] = [curr_transfer[id][0], curr_transfer[id][1]]
self.last_get_transfer = curr_transfer
query_head = 'UPDATE user'
query_sub_when = ''
query_sub_when2 = ''
query_sub_in = None
last_time = time.time()
for id in dt_transfer.keys():
query_sub_when += ' WHEN %s THEN u+%s' % (id, dt_transfer[id][0])
query_sub_when2 += ' WHEN %s THEN d+%s' % (id, dt_transfer[id][1])
if query_sub_in is not None:
query_sub_in += ',%s' % id
else:
query_sub_in = '%s' % id
if query_sub_when == '':
return
query_sql = query_head + ' SET u = CASE port' + query_sub_when + \
' END, d = CASE port' + query_sub_when2 + \
' END, t = ' + str(int(last_time)) + \
' WHERE port IN (%s)' % query_sub_in
#print query_sql
conn = cymysql.connect(host=Config.MYSQL_HOST, port=Config.MYSQL_PORT, user=Config.MYSQL_USER,
passwd=Config.MYSQL_PASS, db=Config.MYSQL_DB, charset='utf8')
cur = conn.cursor()
cur.execute(query_sql)
cur.close()
conn.commit()
conn.close()
@staticmethod
def pull_db_all_user():
#数据库所有用户信息
conn = cymysql.connect(host=Config.MYSQL_HOST, port=Config.MYSQL_PORT, user=Config.MYSQL_USER,
passwd=Config.MYSQL_PASS, db=Config.MYSQL_DB, charset='utf8')
cur = conn.cursor()
cur.execute("SELECT port, u, d, transfer_enable, passwd, switch, enable, plan FROM user")
rows = []
for r in cur.fetchall():
rows.append(list(r))
cur.close()
conn.close()
return rows
@staticmethod
def del_server_out_of_bound_safe(last_rows, rows):
#停止超流量的服务
#启动没超流量的服务
#需要动态载入switchrule,以便实时修改规则
cur_servers = {}
for row in rows:
try:
import switchrule
allow = switchrule.isTurnOn(row[7], row[5]) and row[6] == 1 and row[1] + row[2] < row[3]
except Exception, e:
allow = False
cur_servers[row[0]] = row[4]
if ServerPool.get_instance().server_is_run(row[0]) > 0:
if not allow:
logging.info('db stop server at port [%s]' % (row[0]))
ServerPool.get_instance().del_server(row[0])
elif (row[0] in ServerPool.get_instance().tcp_servers_pool and ServerPool.get_instance().tcp_servers_pool[row[0]]._config['password'] != row[4]) \
or (row[0] in ServerPool.get_instance().tcp_ipv6_servers_pool and ServerPool.get_instance().tcp_ipv6_servers_pool[row[0]]._config['password'] != row[4]):
#password changed
logging.info('db stop server at port [%s] reason: password changed' % (row[0]))
ServerPool.get_instance().del_server(row[0])
elif ServerPool.get_instance().server_run_status(row[0]) is False:
if allow:
logging.info('db start server at port [%s] pass [%s]' % (row[0], row[4]))
ServerPool.get_instance().new_server(row[0], row[4])
for row in last_rows:
if row[0] in cur_servers:
if row[4] == cur_servers[row[0]]:
pass
else:
logging.info('db stop server at port [%s] reason: port not exist' % (row[0]))
ServerPool.get_instance().del_server(row[0])
@staticmethod
def thread_db():
import socket
import time
timeout = 60
socket.setdefaulttimeout(timeout)
last_rows = []
while True:
#logging.warn('db loop')
try:
DbTransfer.get_instance().push_db_all_user()
rows = DbTransfer.get_instance().pull_db_all_user()
DbTransfer.del_server_out_of_bound_safe(last_rows, rows)
last_rows = rows
except Exception as e:
logging.warn('db thread except:%s' % e)
finally:
time.sleep(15)
#SQLData.pull_db_all_user()
#print DbTransfer.get_instance().test()

18
server.py

@ -1,6 +1,4 @@
#!/usr/bin/env python #!/usr/bin/python
# -*- coding: utf-8 -*-
import time import time
import sys import sys
import thread import thread
@ -11,12 +9,12 @@ import server_pool
import db_transfer import db_transfer
#def test(): #def test():
# thread.start_new_thread(DbTransfer.thread_db, ()) # thread.start_new_thread(DbTransfer.thread_db, ())
# Api.web_server() # Api.web_server()
if __name__ == '__main__': if __name__ == '__main__':
#server_pool.ServerPool.get_instance() #server_pool.ServerPool.get_instance()
#server_pool.ServerPool.get_instance().new_server(2333, '2333') #server_pool.ServerPool.get_instance().new_server(2333, '2333')
thread.start_new_thread(db_transfer.DbTransfer.thread_db, ()) thread.start_new_thread(db_transfer.DbTransfer.thread_db, ())
while True: while True:
time.sleep(99999) time.sleep(99999)

340
server_pool.py

@ -24,7 +24,7 @@
import os import os
import logging import logging
import time import time
from shadowsocks import utils from shadowsocks import shell
from shadowsocks import eventloop from shadowsocks import eventloop
from shadowsocks import tcprelay from shadowsocks import tcprelay
from shadowsocks import udprelay from shadowsocks import udprelay
@ -38,174 +38,172 @@ from socket import *
class ServerPool(object): class ServerPool(object):
instance = None instance = None
def __init__(self): def __init__(self):
utils.check_python() shell.check_python()
self.config = utils.get_config(False) self.config = shell.get_config(False)
utils.print_shadowsocks() shell.print_shadowsocks()
self.dns_resolver = asyncdns.DNSResolver() self.dns_resolver = asyncdns.DNSResolver()
self.mgr = asyncmgr.ServerMgr() self.mgr = asyncmgr.ServerMgr()
self.udp_on = True ### UDP switch ===================================== self.udp_on = True ### UDP switch =====================================
self.tcp_servers_pool = {} self.tcp_servers_pool = {}
self.tcp_ipv6_servers_pool = {} self.tcp_ipv6_servers_pool = {}
self.udp_servers_pool = {} self.udp_servers_pool = {}
self.udp_ipv6_servers_pool = {} self.udp_ipv6_servers_pool = {}
self.loop = eventloop.EventLoop() self.loop = eventloop.EventLoop()
thread.start_new_thread(ServerPool._loop, (self.loop, self.dns_resolver, self.mgr)) thread.start_new_thread(ServerPool._loop, (self.loop, self.dns_resolver, self.mgr))
@staticmethod @staticmethod
def get_instance(): def get_instance():
if ServerPool.instance is None: if ServerPool.instance is None:
ServerPool.instance = ServerPool() ServerPool.instance = ServerPool()
return ServerPool.instance return ServerPool.instance
@staticmethod @staticmethod
def _loop(loop, dns_resolver, mgr): def _loop(loop, dns_resolver, mgr):
try: try:
mgr.add_to_loop(loop) mgr.add_to_loop(loop)
dns_resolver.add_to_loop(loop) dns_resolver.add_to_loop(loop)
loop.run() loop.run()
except (KeyboardInterrupt, IOError, OSError) as e: except (KeyboardInterrupt, IOError, OSError) as e:
logging.error(e) logging.error(e)
import traceback import traceback
traceback.print_exc() traceback.print_exc()
os.exit(0) os.exit(0)
def server_is_run(self, port): def server_is_run(self, port):
port = int(port) port = int(port)
ret = 0 ret = 0
if port in self.tcp_servers_pool: if port in self.tcp_servers_pool:
ret = 1 ret = 1
if port in self.tcp_ipv6_servers_pool: if port in self.tcp_ipv6_servers_pool:
ret |= 2 ret |= 2
return ret return ret
def server_run_status(self, port): def server_run_status(self, port):
if 'server' in self.config: if 'server' in self.config:
if port not in self.tcp_servers_pool: if port not in self.tcp_servers_pool:
return False return False
if 'server_ipv6' in self.config: if 'server_ipv6' in self.config:
if port not in self.tcp_ipv6_servers_pool: if port not in self.tcp_ipv6_servers_pool:
return False return False
return True return True
def new_server(self, port, password): def new_server(self, port, password):
ret = True ret = True
port = int(port) port = int(port)
if 'server_ipv6' in self.config: if 'server_ipv6' in self.config:
if port in self.tcp_ipv6_servers_pool: if port in self.tcp_ipv6_servers_pool:
logging.info("server already at %s:%d" % (self.config['server_ipv6'], port)) logging.info("server already at %s:%d" % (self.config['server_ipv6'], port))
return 'this port server is already running' return 'this port server is already running'
else: else:
a_config = self.config.copy() a_config = self.config.copy()
a_config['server'] = a_config['server_ipv6'] a_config['server'] = a_config['server_ipv6']
a_config['server_port'] = port a_config['server_port'] = port
a_config['password'] = password a_config['password'] = password
try: try:
logging.info("starting server at %s:%d" % (a_config['server'], port)) logging.info("starting server at %s:%d" % (a_config['server'], port))
tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False) tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False)
tcp_server.add_to_loop(self.loop) tcp_server.add_to_loop(self.loop)
self.tcp_ipv6_servers_pool.update({port: tcp_server}) self.tcp_ipv6_servers_pool.update({port: tcp_server})
if self.udp_on: if self.udp_on:
udp_server = udprelay.UDPRelay(a_config, self.dns_resolver, False) udp_server = udprelay.UDPRelay(a_config, self.dns_resolver, False)
udp_server.add_to_loop(self.loop) udp_server.add_to_loop(self.loop)
self.udp_ipv6_servers_pool.update({port: udp_server}) self.udp_ipv6_servers_pool.update({port: udp_server})
except Exception, e: except Exception, e:
logging.warn("IPV6 exception") logging.warn("IPV6 %s " % (e,))
logging.warn(e)
if 'server' in self.config:
if 'server' in self.config: if port in self.tcp_servers_pool:
if port in self.tcp_servers_pool: logging.info("server already at %s:%d" % (self.config['server'], port))
logging.info("server already at %s:%d" % (self.config['server'], port)) return 'this port server is already running'
return 'this port server is already running' else:
else: a_config = self.config.copy()
a_config = self.config.copy() a_config['server_port'] = port
a_config['server_port'] = port a_config['password'] = password
a_config['password'] = password try:
try: logging.info("starting server at %s:%d" % (a_config['server'], port))
logging.info("starting server at %s:%d" % (a_config['server'], port)) tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False)
tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False) tcp_server.add_to_loop(self.loop)
tcp_server.add_to_loop(self.loop) self.tcp_servers_pool.update({port: tcp_server})
self.tcp_servers_pool.update({port: tcp_server}) if self.udp_on:
if self.udp_on: udp_server = udprelay.UDPRelay(a_config, self.dns_resolver, False)
udp_server = udprelay.UDPRelay(a_config, self.dns_resolver, False) udp_server.add_to_loop(self.loop)
udp_server.add_to_loop(self.loop) self.udp_servers_pool.update({port: udp_server})
self.udp_servers_pool.update({port: udp_server}) except Exception, e:
except Exception, e: logging.warn("IPV4 %s " % (e,))
logging.warn("IPV4 exception")
logging.warn(e) return True
return True def del_server(self, port):
port = int(port)
def del_server(self, port): logging.info("del server at %d" % port)
port = int(port) try:
logging.info("del server at %d" % port) udpsock = socket(AF_INET, SOCK_DGRAM)
try: udpsock.sendto('%s:%s:0:0' % (Config.MANAGE_PASS, port), (Config.MANAGE_BIND_IP, Config.MANAGE_PORT))
udpsock = socket(AF_INET, SOCK_DGRAM) udpsock.close()
udpsock.sendto('%s:%s:0:0' % (Config.MANAGE_PASS, port), (Config.MANAGE_BIND_IP, Config.MANAGE_PORT)) except Exception, e:
udpsock.close() logging.warn(e)
except Exception, e: return True
logging.warn(e)
return True def cb_del_server(self, port):
port = int(port)
def cb_del_server(self, port):
port = int(port) if port not in self.tcp_servers_pool:
logging.info("stopped server at %s:%d already stop" % (self.config['server'], port))
if port not in self.tcp_servers_pool: else:
logging.info("stopped server at %s:%d already stop" % (self.config['server'], port)) logging.info("stopped server at %s:%d" % (self.config['server'], port))
else: try:
logging.info("stopped server at %s:%d" % (self.config['server'], port)) self.tcp_servers_pool[port].destroy()
try: del self.tcp_servers_pool[port]
self.tcp_servers_pool[port].destroy() except Exception, e:
del self.tcp_servers_pool[port] logging.warn(e)
except Exception, e: if self.udp_on:
logging.warn(e) try:
if self.udp_on: self.udp_servers_pool[port].destroy()
try: del self.udp_servers_pool[port]
self.udp_servers_pool[port].destroy() except Exception, e:
del self.udp_servers_pool[port] logging.warn(e)
except Exception, e:
logging.warn(e) if 'server_ipv6' in self.config:
if port not in self.tcp_ipv6_servers_pool:
if 'server_ipv6' in self.config: logging.info("stopped server at %s:%d already stop" % (self.config['server_ipv6'], port))
if port not in self.tcp_ipv6_servers_pool: else:
logging.info("stopped server at %s:%d already stop" % (self.config['server_ipv6'], port)) logging.info("stopped server at %s:%d" % (self.config['server_ipv6'], port))
else: try:
logging.info("stopped server at %s:%d" % (self.config['server_ipv6'], port)) self.tcp_ipv6_servers_pool[port].destroy()
try: del self.tcp_ipv6_servers_pool[port]
self.tcp_ipv6_servers_pool[port].destroy() except Exception, e:
del self.tcp_ipv6_servers_pool[port] logging.warn(e)
except Exception, e: if self.udp_on:
logging.warn(e) try:
if self.udp_on: self.udp_ipv6_servers_pool[port].destroy()
try: del self.udp_ipv6_servers_pool[port]
self.udp_ipv6_servers_pool[port].destroy() except Exception, e:
del self.udp_ipv6_servers_pool[port] logging.warn(e)
except Exception, e:
logging.warn(e) return True
return True def get_server_transfer(self, port):
port = int(port)
def get_server_transfer(self, port): ret = [0, 0]
port = int(port) if port in self.tcp_servers_pool:
ret = [0, 0] ret[0] = self.tcp_servers_pool[port].server_transfer_ul
if port in self.tcp_servers_pool: ret[1] = self.tcp_servers_pool[port].server_transfer_dl
ret[0] = self.tcp_servers_pool[port].server_transfer_ul if port in self.tcp_ipv6_servers_pool:
ret[1] = self.tcp_servers_pool[port].server_transfer_dl ret[0] += self.tcp_ipv6_servers_pool[port].server_transfer_ul
if port in self.tcp_ipv6_servers_pool: ret[1] += self.tcp_ipv6_servers_pool[port].server_transfer_dl
ret[0] += self.tcp_ipv6_servers_pool[port].server_transfer_ul return ret
ret[1] += self.tcp_ipv6_servers_pool[port].server_transfer_dl
return ret def get_servers_transfer(self):
servers = self.tcp_servers_pool.copy()
def get_servers_transfer(self): servers.update(self.tcp_ipv6_servers_pool)
servers = self.tcp_servers_pool.copy() ret = {}
servers.update(self.tcp_ipv6_servers_pool) for port in servers.keys():
ret = {} ret[port] = self.get_server_transfer(port)
for port in servers.keys(): return ret
ret[port] = self.get_server_transfer(port)
return ret

3
switchrule.py

@ -1,2 +1,3 @@
def isTurnOn(plan, switch): def isTurnOn(row):
return True return True

Loading…
Cancel
Save