From 3ce6e6f71432997e27838f3791b4bf33cf2a98ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A0=B4=E5=A8=83=E9=85=B1?= Date: Sat, 13 Aug 2016 23:36:00 +0800 Subject: [PATCH] fix transfer update --- db_transfer.py | 64 +++++++++++++++++---------- shadowsocks/eventloop.py | 2 +- shadowsocks/obfsplugin/http_simple.py | 3 +- 3 files changed, 43 insertions(+), 26 deletions(-) diff --git a/db_transfer.py b/db_transfer.py index a34cd76..d08a442 100644 --- a/db_transfer.py +++ b/db_transfer.py @@ -18,12 +18,12 @@ class TransferBase(object): import threading self.event = threading.Event() self.key_list = ['port', 'u', 'd', 'transfer_enable', 'passwd', 'enable'] - self.last_get_transfer = {} - self.last_update_transfer = {} - self.user_pass = {} - self.port_uid_table = {} - self.onlineuser_cache = lru_cache.LRUCache(timeout=60*30) - self.pull_ok = False + self.last_get_transfer = {} #上一次的实际流量 + self.last_update_transfer = {} #上一次更新到的流量(小于等于实际流量) + self.force_update_transfer = set() #强制推入数据库的ID + self.port_uid_table = {} #端口到uid的映射(仅v3以上有用) + self.onlineuser_cache = lru_cache.LRUCache(timeout=60*30) #用户在线状态记录 + self.pull_ok = False #记录是否已经拉出过数据 def load_cfg(self): pass @@ -36,7 +36,19 @@ class TransferBase(object): curr_transfer = ServerPool.get_instance().get_servers_transfer() #上次和本次的增量 dt_transfer = {} + for id in self.force_update_transfer: #此表中的用户统计上次未计入的流量 + if id in self.last_get_transfer and id in last_transfer: + dt_transfer[id] = [self.last_get_transfer[id][0] - last_transfer[id][0], self.last_get_transfer[id][1] - last_transfer[id][1]] + for id in curr_transfer.keys(): + #有流量的,先记录在线状态 + if id in self.last_get_transfer: + if curr_transfer[id][0] + curr_transfer[id][1] > self.last_get_transfer[id][0] + self.last_get_transfer[id][1]: + self.onlineuser_cache[id] = curr_transfer[id][0] + curr_transfer[id][1] + else: + self.onlineuser_cache[id] = curr_transfer[id][0] + curr_transfer[id][1] + + #算出与上次记录的流量差值,保存于dt_transfer表 if id in last_transfer: if curr_transfer[id][0] + curr_transfer[id][1] - last_transfer[id][0] - last_transfer[id][1] <= 0: continue @@ -50,17 +62,18 @@ class TransferBase(object): if curr_transfer[id][0] + curr_transfer[id][1] <= 0: continue dt_transfer[id] = [curr_transfer[id][0], curr_transfer[id][1]] - if id in self.last_get_transfer: - if curr_transfer[id][0] + curr_transfer[id][1] > self.last_get_transfer[id][0] + self.last_get_transfer[id][1]: - self.onlineuser_cache[id] = curr_transfer[id][0] + curr_transfer[id][1] - else: - self.onlineuser_cache[id] = curr_transfer[id][0] + curr_transfer[id][1] + self.onlineuser_cache.sweep() - update_transfer = self.update_all_user(dt_transfer) - for id in update_transfer.keys(): - last = self.last_update_transfer.get(id, [0,0]) - self.last_update_transfer[id] = [last[0] + update_transfer[id][0], last[1] + update_transfer[id][1]] + update_transfer = self.update_all_user(dt_transfer) #返回有更新的表 + for id in update_transfer.keys(): #其增量加在此表 + if id in self.force_update_transfer: #但排除在force_update_transfer内的 + if id in self.last_update_transfer: + del self.last_update_transfer[id] + self.force_update_transfer.remove(id) + else: + last = self.last_update_transfer.get(id, [0,0]) + self.last_update_transfer[id] = [last[0] + update_transfer[id][0], last[1] + update_transfer[id][1]] self.last_get_transfer = curr_transfer def del_server_out_of_bound_safe(self, last_rows, rows): @@ -125,11 +138,7 @@ class TransferBase(object): new_servers[port] = (passwd, cfg) elif allow and ServerPool.get_instance().server_run_status(port) is False: - #new_servers[port] = passwd - protocol = cfg.get('protocol', ServerPool.get_instance().config.get('protocol', 'origin')) - obfs = cfg.get('obfs', ServerPool.get_instance().config.get('obfs', 'plain')) - logging.info('db start server at port [%s] pass [%s] protocol [%s] obfs [%s]' % (port, passwd, protocol, obfs)) - ServerPool.get_instance().new_server(port, cfg) + self.new_server(port, passwd, cfg) for row in last_rows: if row['port'] in cur_servers: @@ -145,10 +154,15 @@ class TransferBase(object): self.event.wait(eventloop.TIMEOUT_PRECISION + eventloop.TIMEOUT_PRECISION / 2) for port in new_servers.keys(): passwd, cfg = new_servers[port] - protocol = cfg.get('protocol', ServerPool.get_instance().config.get('protocol', 'origin')) - obfs = cfg.get('obfs', ServerPool.get_instance().config.get('obfs', 'plain')) - logging.info('db start server at port [%s] pass [%s] protocol [%s] obfs [%s]' % (port, passwd, protocol, obfs)) - ServerPool.get_instance().new_server(port, cfg) + self.new_server(port, passwd, cfg) + + def new_server(self, port, passwd, cfg): + protocol = cfg.get('protocol', ServerPool.get_instance().config.get('protocol', 'origin')) + method = cfg.get('method', ServerPool.get_instance().config.get('method', 'None')) + obfs = cfg.get('obfs', ServerPool.get_instance().config.get('obfs', 'plain')) + logging.info('db start server at port [%s] pass [%s] protocol [%s] method [%s] obfs [%s]' % (port, passwd, protocol, method, obfs)) + ServerPool.get_instance().new_server(port, cfg) + self.force_update_transfer.add(port) def cmp(self, val1, val2): if type(val1) is bytes: @@ -206,6 +220,7 @@ class TransferBase(object): class DbTransfer(TransferBase): def __init__(self): super(DbTransfer, self).__init__() + self.user_pass = {} #记录更新此用户流量时被跳过多少次 self.cfg = { "host": "127.0.0.1", "port": 3306, @@ -242,6 +257,7 @@ class DbTransfer(TransferBase): for id in dt_transfer.keys(): transfer = dt_transfer[id] + #小于最低更新流量的先不更新 update_trs = 1024 * max(2048 - self.user_pass.get(id, 0) * 64, 16) if transfer[0] + transfer[1] < update_trs: continue diff --git a/shadowsocks/eventloop.py b/shadowsocks/eventloop.py index ce9c11b..2d7e696 100644 --- a/shadowsocks/eventloop.py +++ b/shadowsocks/eventloop.py @@ -53,7 +53,7 @@ EVENT_NAMES = { } # we check timeouts every TIMEOUT_PRECISION seconds -TIMEOUT_PRECISION = 10 +TIMEOUT_PRECISION = 5 class KqueueLoop(object): diff --git a/shadowsocks/obfsplugin/http_simple.py b/shadowsocks/obfsplugin/http_simple.py index 103ce16..42715fc 100644 --- a/shadowsocks/obfsplugin/http_simple.py +++ b/shadowsocks/obfsplugin/http_simple.py @@ -100,7 +100,8 @@ class http_simple(plain.plain): hosts = (self.server_info.obfs_param or self.server_info.host) pos = hosts.find("#") if pos >= 0: - body = hosts[pos + 1:].replace("\\n", "\r\n") + body = hosts[pos + 1:].replace("\n", "\r\n") + body = body.replace("\\n", "\r\n") hosts = hosts[:pos] hosts = hosts.split(',') host = random.choice(hosts)