Browse Source

using mysql.json to config mysql

dev
BreakWa11 8 years ago
parent
commit
0a7d71a393
  1. 30
      apiconfig.py
  2. 244
      db_transfer.py
  3. 13
      mysql.json
  4. 7
      shadowsocks/crypto/table.py
  5. 9
      switchrule.py

30
apiconfig.py

@ -1,35 +1,15 @@
# Config # Config
TRANSFER_MUL = 1.0
NODE_ID = 1
SERVER_PUB_ADDR = '127.0.0.1' # mujson_mgr need this to generate ssr link
API_INTERFACE = 'sspanelv2' #mudbjson, sspanelv2, sspanelv3, sspanelv3ssr, muapiv2(not support) API_INTERFACE = 'sspanelv2' #mudbjson, sspanelv2, sspanelv3, sspanelv3ssr, muapiv2(not support)
UPDATE_TIME = 60
SERVER_PUB_ADDR = '127.0.0.1' # mujson_mgr need this to generate ssr link
#mudb #mudb
MUDB_FILE = 'mudb.json' MUDB_FILE = 'mudb.json'
# Mysql # Mysql
MYSQL_HOST = '127.0.0.1' MYSQL_CONFIG = 'usermysql.json'
MYSQL_PORT = 3306
MYSQL_USER = 'ss'
MYSQL_PASS = 'ss'
MYSQL_DB = 'shadowsocks'
MYSQL_UPDATE_TIME = 60
MYSQL_SSL_ENABLE = 0
MYSQL_SSL_CA = ''
MYSQL_SSL_CERT = ''
MYSQL_SSL_KEY = ''
# API # API
API_HOST = '127.0.0.1' MUAPI_CONFIG = 'usermuapi.json'
API_PORT = 80
API_PATH = '/mu/v2/'
API_TOKEN = 'abcdef'
API_UPDATE_TIME = 60
# Manager (ignore this)
MANAGE_PASS = 'ss233333333'
#if you want manage in other server you should set this value to global ip
MANAGE_BIND_IP = '127.0.0.1'
#make sure this port is idle
MANAGE_PORT = 23333

244
db_transfer.py

@ -13,65 +13,19 @@ import importloader
switchrule = None switchrule = None
db_instance = None db_instance = None
class DbTransfer(object): class TransferBase(object):
def __init__(self): def __init__(self):
import threading import threading
self.event = threading.Event()
self.key_list = ['port', 'u', 'd', 'transfer_enable', 'passwd', 'enable']
self.last_get_transfer = {} self.last_get_transfer = {}
self.last_update_transfer = {} self.last_update_transfer = {}
self.event = threading.Event()
self.user_pass = {} self.user_pass = {}
self.port_uid_table = {} self.port_uid_table = {}
self.onlineuser_cache = lru_cache.LRUCache(timeout=60*30) self.onlineuser_cache = lru_cache.LRUCache(timeout=60*30)
self.start_time = time.time()
def update_all_user(self, dt_transfer):
import cymysql
update_transfer = {}
query_head = 'UPDATE user'
query_sub_when = ''
query_sub_when2 = ''
query_sub_in = None
last_time = time.time()
for id in dt_transfer.keys(): def load_cfg(self):
transfer = dt_transfer[id] pass
update_trs = 1024 * max(2048 - self.user_pass.get(id, 0) * 64, 16)
if transfer[0] + transfer[1] < update_trs:
continue
if id in self.user_pass:
del self.user_pass[id]
query_sub_when += ' WHEN %s THEN u+%s' % (id, int(transfer[0] * get_config().TRANSFER_MUL))
query_sub_when2 += ' WHEN %s THEN d+%s' % (id, int(transfer[1] * get_config().TRANSFER_MUL))
update_transfer[id] = transfer
if query_sub_in is not None:
query_sub_in += ',%s' % id
else:
query_sub_in = '%s' % id
if query_sub_when == '':
return update_transfer
query_sql = query_head + ' SET u = CASE port' + query_sub_when + \
' END, d = CASE port' + query_sub_when2 + \
' END, t = ' + str(int(last_time)) + \
' WHERE port IN (%s)' % query_sub_in
if get_config().MYSQL_SSL_ENABLE == 1:
conn = cymysql.connect(host=get_config().MYSQL_HOST, port=get_config().MYSQL_PORT,
user=get_config().MYSQL_USER, passwd=get_config().MYSQL_PASS,
db=get_config().MYSQL_DB, charset='utf8',
ssl={'ca':get_config().MYSQL_SSL_CA,'cert':get_config().MYSQL_SSL_CERT,'key':get_config().MYSQL_SSL_KEY})
else:
conn = cymysql.connect(host=get_config().MYSQL_HOST, port=get_config().MYSQL_PORT,
user=get_config().MYSQL_USER, passwd=get_config().MYSQL_PASS,
db=get_config().MYSQL_DB, charset='utf8')
cur = conn.cursor()
cur.execute(query_sql)
cur.close()
conn.commit()
conn.close()
return update_transfer
def push_db_all_user(self): def push_db_all_user(self):
#更新用户流量到数据库 #更新用户流量到数据库
@ -106,46 +60,9 @@ class DbTransfer(object):
self.last_update_transfer[id] = [last[0] + update_transfer[id][0], last[1] + update_transfer[id][1]] self.last_update_transfer[id] = [last[0] + update_transfer[id][0], last[1] + update_transfer[id][1]]
self.last_get_transfer = curr_transfer self.last_get_transfer = curr_transfer
def pull_db_all_user(self):
import cymysql
#数据库所有用户信息
try:
switchrule = importloader.load('switchrule')
keys = switchrule.getKeys()
except Exception as e:
keys = ['port', 'u', 'd', 'transfer_enable', 'passwd', 'enable' ]
if get_config().MYSQL_SSL_ENABLE == 1:
conn = cymysql.connect(host=get_config().MYSQL_HOST, port=get_config().MYSQL_PORT,
user=get_config().MYSQL_USER, passwd=get_config().MYSQL_PASS,
db=get_config().MYSQL_DB, charset='utf8',
ssl={'ca':get_config().MYSQL_SSL_CA,'cert':get_config().MYSQL_SSL_CERT,'key':get_config().MYSQL_SSL_KEY})
else:
conn = cymysql.connect(host=get_config().MYSQL_HOST, port=get_config().MYSQL_PORT,
user=get_config().MYSQL_USER, passwd=get_config().MYSQL_PASS,
db=get_config().MYSQL_DB, charset='utf8')
cur = conn.cursor()
cur.execute("SELECT " + ','.join(keys) + " FROM user")
rows = []
for r in cur.fetchall():
d = {}
for column in range(len(keys)):
d[keys[column]] = r[column]
rows.append(d)
cur.close()
conn.close()
return rows
def cmp(self, val1, val2):
if type(val1) is bytes:
val1 = common.to_str(val1)
if type(val2) is bytes:
val2 = common.to_str(val2)
return val1 == val2
def del_server_out_of_bound_safe(self, last_rows, rows): def del_server_out_of_bound_safe(self, last_rows, rows):
#停止超流量的服务 #停止超流量的服务
#启动没超流量的服务 #启动没超流量的服务
#需要动态载入switchrule,以便实时修改规则
try: try:
switchrule = importloader.load('switchrule') switchrule = importloader.load('switchrule')
except Exception as e: except Exception as e:
@ -230,6 +147,13 @@ class DbTransfer(object):
logging.info('db start server at port [%s] pass [%s] protocol [%s] obfs [%s]' % (port, passwd, protocol, obfs)) logging.info('db start server at port [%s] pass [%s] protocol [%s] obfs [%s]' % (port, passwd, protocol, obfs))
ServerPool.get_instance().new_server(port, cfg) ServerPool.get_instance().new_server(port, cfg)
def cmp(self, val1, val2):
if type(val1) is bytes:
val1 = common.to_str(val1)
if type(val2) is bytes:
val2 = common.to_str(val2)
return val1 == val2
@staticmethod @staticmethod
def del_servers(): def del_servers():
for port in [v for v in ServerPool.get_instance().tcp_servers_pool.keys()]: for port in [v for v in ServerPool.get_instance().tcp_servers_pool.keys()]:
@ -251,6 +175,7 @@ class DbTransfer(object):
try: try:
while True: while True:
load_config() load_config()
db_instance.load_cfg()
try: try:
db_instance.push_db_all_user() db_instance.push_db_all_user()
rows = db_instance.pull_db_all_user() rows = db_instance.pull_db_all_user()
@ -260,7 +185,7 @@ class DbTransfer(object):
trace = traceback.format_exc() trace = traceback.format_exc()
logging.error(trace) logging.error(trace)
#logging.warn('db thread except:%s' % e) #logging.warn('db thread except:%s' % e)
if db_instance.event.wait(get_config().MYSQL_UPDATE_TIME) or not ServerPool.get_instance().thread.is_alive(): if db_instance.event.wait(get_config().UPDATE_TIME) or not ServerPool.get_instance().thread.is_alive():
break break
except KeyboardInterrupt as e: except KeyboardInterrupt as e:
pass pass
@ -273,9 +198,120 @@ class DbTransfer(object):
global db_instance global db_instance
db_instance.event.set() db_instance.event.set()
class DbTransfer(TransferBase):
def __init__(self):
super(DbTransfer, self).__init__()
self.cfg = {
"host": "127.0.0.1",
"port": 3306,
"user": "ss",
"password": "pass",
"db": "shadowsocks",
"node_id": 1,
"transfer_mul": 1.0,
"ssl_enable": 0,
"ssl_ca": "",
"ssl_cert": "",
"ssl_key": ""}
self.load_cfg()
def load_cfg(self):
import json
config_path = get_config().MYSQL_CONFIG
cfg = None
with open(config_path, 'r+') as f:
cfg = json.loads(f.read().decode('utf8'))
if cfg:
self.cfg.update(cfg)
def update_all_user(self, dt_transfer):
import cymysql
update_transfer = {}
query_head = 'UPDATE user'
query_sub_when = ''
query_sub_when2 = ''
query_sub_in = None
last_time = time.time()
for id in dt_transfer.keys():
transfer = dt_transfer[id]
update_trs = 1024 * max(2048 - self.user_pass.get(id, 0) * 64, 16)
if transfer[0] + transfer[1] < update_trs:
continue
if id in self.user_pass:
del self.user_pass[id]
query_sub_when += ' WHEN %s THEN u+%s' % (id, int(transfer[0] * self.cfg["transfer_mul"]))
query_sub_when2 += ' WHEN %s THEN d+%s' % (id, int(transfer[1] * self.cfg["transfer_mul"]))
update_transfer[id] = transfer
if query_sub_in is not None:
query_sub_in += ',%s' % id
else:
query_sub_in = '%s' % id
if query_sub_when == '':
return update_transfer
query_sql = query_head + ' SET u = CASE port' + query_sub_when + \
' END, d = CASE port' + query_sub_when2 + \
' END, t = ' + str(int(last_time)) + \
' WHERE port IN (%s)' % query_sub_in
if self.cfg["ssl_enable"] == 1:
conn = cymysql.connect(host=self.cfg["host"], port=self.cfg["port"],
user=self.cfg["user"], passwd=self.cfg["password"],
db=self.cfg["db"], charset='utf8',
ssl={'ca':self.cfg["ssl_enable"],'cert':self.cfg["ssl_enable"],'key':self.cfg["ssl_enable"]})
else:
conn = cymysql.connect(host=self.cfg["host"], port=self.cfg["port"],
user=self.cfg["user"], passwd=self.cfg["password"],
db=self.cfg["db"], charset='utf8')
cur = conn.cursor()
cur.execute(query_sql)
cur.close()
conn.commit()
conn.close()
return update_transfer
def pull_db_all_user(self):
import cymysql
#数据库所有用户信息
try:
switchrule = importloader.load('switchrule')
keys = switchrule.getKeys(self.key_list)
except Exception as e:
keys = self.key_list
if self.cfg["ssl_enable"] == 1:
conn = cymysql.connect(host=self.cfg["host"], port=self.cfg["port"],
user=self.cfg["user"], passwd=self.cfg["password"],
db=self.cfg["db"], charset='utf8',
ssl={'ca':self.cfg["ssl_enable"],'cert':self.cfg["ssl_enable"],'key':self.cfg["ssl_enable"]})
else:
conn = cymysql.connect(host=self.cfg["host"], port=self.cfg["port"],
user=self.cfg["user"], passwd=self.cfg["password"],
db=self.cfg["db"], charset='utf8')
cur = conn.cursor()
cur.execute("SELECT " + ','.join(keys) + " FROM user")
rows = []
for r in cur.fetchall():
d = {}
for column in range(len(keys)):
d[keys[column]] = r[column]
rows.append(d)
cur.close()
conn.close()
return rows
class Dbv3Transfer(DbTransfer): class Dbv3Transfer(DbTransfer):
def __init__(self): def __init__(self):
super(Dbv3Transfer, self).__init__() super(Dbv3Transfer, self).__init__()
self.key_list += ['id', 'method']
if get_config().API_INTERFACE == 'sspanelv3ssr':
self.key_list += ['obfs', 'protocol', 'obfs_param', 'protocol_param']
self.start_time = time.time()
def update_all_user(self, dt_transfer): def update_all_user(self, dt_transfer):
import cymysql import cymysql
@ -290,15 +326,15 @@ class Dbv3Transfer(DbTransfer):
alive_user_count = len(self.onlineuser_cache) alive_user_count = len(self.onlineuser_cache)
bandwidth_thistime = 0 bandwidth_thistime = 0
if get_config().MYSQL_SSL_ENABLE == 1: if self.cfg["ssl_enable"] == 1:
conn = cymysql.connect(host=get_config().MYSQL_HOST, port=get_config().MYSQL_PORT, conn = cymysql.connect(host=self.cfg["host"], port=self.cfg["port"],
user=get_config().MYSQL_USER, passwd=get_config().MYSQL_PASS, user=self.cfg["user"], passwd=self.cfg["password"],
db=get_config().MYSQL_DB, charset='utf8', db=self.cfg["db"], charset='utf8',
ssl={'ca':get_config().MYSQL_SSL_CA,'cert':get_config().MYSQL_SSL_CERT,'key':get_config().MYSQL_SSL_KEY}) ssl={'ca':self.cfg["ssl_enable"],'cert':self.cfg["ssl_enable"],'key':self.cfg["ssl_enable"]})
else: else:
conn = cymysql.connect(host=get_config().MYSQL_HOST, port=get_config().MYSQL_PORT, conn = cymysql.connect(host=self.cfg["host"], port=self.cfg["port"],
user=get_config().MYSQL_USER, passwd=get_config().MYSQL_PASS, user=self.cfg["user"], passwd=self.cfg["password"],
db=get_config().MYSQL_DB, charset='utf8') db=self.cfg["db"], charset='utf8')
conn.autocommit(True) conn.autocommit(True)
for id in dt_transfer.keys(): for id in dt_transfer.keys():
@ -312,8 +348,8 @@ class Dbv3Transfer(DbTransfer):
if id in self.user_pass: if id in self.user_pass:
del self.user_pass[id] del self.user_pass[id]
query_sub_when += ' WHEN %s THEN u+%s' % (id, int(transfer[0] * get_config().TRANSFER_MUL)) query_sub_when += ' WHEN %s THEN u+%s' % (id, int(transfer[0] * self.cfg["transfer_mul"]))
query_sub_when2 += ' WHEN %s THEN d+%s' % (id, int(transfer[1] * get_config().TRANSFER_MUL)) query_sub_when2 += ' WHEN %s THEN d+%s' % (id, int(transfer[1] * self.cfg["transfer_mul"]))
update_transfer[id] = transfer update_transfer[id] = transfer
cur = conn.cursor() cur = conn.cursor()
@ -321,8 +357,8 @@ class Dbv3Transfer(DbTransfer):
if id in self.port_uid_table: if id in self.port_uid_table:
cur.execute("INSERT INTO `user_traffic_log` (`id`, `user_id`, `u`, `d`, `node_id`, `rate`, `traffic`, `log_time`) VALUES (NULL, '" + \ cur.execute("INSERT INTO `user_traffic_log` (`id`, `user_id`, `u`, `d`, `node_id`, `rate`, `traffic`, `log_time`) VALUES (NULL, '" + \
str(self.port_uid_table[id]) + "', '" + str(transfer[0]) + "', '" + str(transfer[1]) + "', '" + \ str(self.port_uid_table[id]) + "', '" + str(transfer[0]) + "', '" + str(transfer[1]) + "', '" + \
str(get_config().NODE_ID) + "', '" + str(get_config().TRANSFER_MUL) + "', '" + \ str(self.cfg["node_id"]) + "', '" + str(self.cfg["transfer_mul"]) + "', '" + \
self.traffic_format((transfer[0] + transfer[1]) * get_config().TRANSFER_MUL) + "', unix_timestamp()); ") self.traffic_format((transfer[0] + transfer[1]) * self.cfg["transfer_mul"]) + "', unix_timestamp()); ")
except: except:
logging.warn('no `user_traffic_log` in db') logging.warn('no `user_traffic_log` in db')
cur.close() cur.close()
@ -344,12 +380,12 @@ class Dbv3Transfer(DbTransfer):
try: try:
cur = conn.cursor() cur = conn.cursor()
cur.execute("INSERT INTO `ss_node_online_log` (`id`, `node_id`, `online_user`, `log_time`) VALUES (NULL, '" + \ cur.execute("INSERT INTO `ss_node_online_log` (`id`, `node_id`, `online_user`, `log_time`) VALUES (NULL, '" + \
str(get_config().NODE_ID) + "', '" + str(alive_user_count) + "', unix_timestamp()); ") str(self.cfg["node_id"]) + "', '" + str(alive_user_count) + "', unix_timestamp()); ")
cur.close() cur.close()
cur = conn.cursor() cur = conn.cursor()
cur.execute("INSERT INTO `ss_node_info_log` (`id`, `node_id`, `uptime`, `load`, `log_time`) VALUES (NULL, '" + \ cur.execute("INSERT INTO `ss_node_info_log` (`id`, `node_id`, `uptime`, `load`, `log_time`) VALUES (NULL, '" + \
str(get_config().NODE_ID) + "', '" + str(self.uptime()) + "', '" + \ str(self.cfg["node_id"]) + "', '" + str(self.uptime()) + "', '" + \
str(self.load()) + "', unix_timestamp()); ") str(self.load()) + "', unix_timestamp()); ")
cur.close() cur.close()
except: except:
@ -374,7 +410,7 @@ class Dbv3Transfer(DbTransfer):
return str(round((traffic / 1048576.0), 2)) + "MB"; return str(round((traffic / 1048576.0), 2)) + "MB";
class MuJsonTransfer(DbTransfer): class MuJsonTransfer(TransferBase):
def __init__(self): def __init__(self):
super(MuJsonTransfer, self).__init__() super(MuJsonTransfer, self).__init__()

13
mysql.json

@ -0,0 +1,13 @@
{
"host": "127.0.0.1",
"port": 3306,
"user": "ss",
"password": "pass",
"db": "shadowsocks",
"node_id": 1,
"transfer_mul": 1.0,
"ssl_enable": 0,
"ssl_ca": "",
"ssl_cert": "",
"ssl_key": ""
}

7
shadowsocks/crypto/table.py

@ -65,8 +65,15 @@ class TableCipher(object):
else: else:
return translate(data, self._decrypt_table) return translate(data, self._decrypt_table)
class NoneCipher(object):
def __init__(self, cipher_name, key, iv, op):
pass
def update(self, data):
return data
ciphers = { ciphers = {
'none': (0, 0, NoneCipher),
'table': (0, 0, TableCipher) 'table': (0, 0, TableCipher)
} }

9
switchrule.py

@ -1,11 +1,4 @@
from configloader import load_config, get_config def getKeys(key_list):
def getKeys():
key_list = ['port', 'u', 'd', 'transfer_enable', 'passwd', 'enable' ]
if get_config().API_INTERFACE == 'sspanelv3':
key_list += ['id', 'method']
elif get_config().API_INTERFACE == 'sspanelv3ssr':
key_list += ['id', 'method', 'obfs', 'protocol', 'obfs_param', 'protocol_param']
return key_list return key_list
#return key_list + ['plan'] # append the column name 'plan' #return key_list + ['plan'] # append the column name 'plan'

Loading…
Cancel
Save