diff --git a/shadowsocks/server.py b/shadowsocks/server.py index bb31bf2..d383826 100755 --- a/shadowsocks/server.py +++ b/shadowsocks/server.py @@ -243,26 +243,49 @@ def main(): logging.error('cant resolve listen address') sys.exit(1) ThreadingTCPServer.address_family = addrs[0][0] + tcp_servers = [] + udp_servers = [] for port, key in config_port_password.items(): - server = ThreadingTCPServer((config_server, int(port)), Socks5Server) - server.key, server.method, server.timeout = key, config_method,\ - int(config_timeout) + tcp_server = ThreadingTCPServer((config_server, int(port)), + Socks5Server) + tcp_server.key = key + tcp_server.method = config_method + tcp_server.timeout = int(config_timeout) logging.info("starting server at %s:%d" % - tuple(server.server_address[:2])) - threading.Thread(target=server.serve_forever).start() - udprelay.UDPRelay(config_server, int(port), None, None, key, - config_method, int(config_timeout), False).start() + tuple(tcp_server.server_address[:2])) + tcp_servers.append(tcp_server) + udp_server = udprelay.UDPRelay(config_server, int(port), None, None, + key, config_method, int(config_timeout), + False) + udp_servers.append(udp_server) + + def run_server(): + for tcp_server in tcp_servers: + threading.Thread(target=tcp_server.serve_forever).start() + for udp_server in udp_servers: + udp_server.start() if int(config_workers) > 1: if os.name == 'posix': - # TODO only serve in workers, not in master - for i in xrange(0, int(config_workers) - 1): + children = [] + is_child = False + for i in xrange(0, int(config_workers)): r = os.fork() if r == 0: + logging.info('worker started') + is_child = True + run_server() break + else: + children.append(r) + if not is_child: + # master + for child in children: + os.waitpid(child, 0) else: logging.warn('worker is only available on Unix/Linux') - + else: + run_server() if __name__ == '__main__': diff --git a/shadowsocks/udprelay.py b/shadowsocks/udprelay.py index d926e30..5888918 100644 --- a/shadowsocks/udprelay.py +++ b/shadowsocks/udprelay.py @@ -134,11 +134,21 @@ class UDPRelay(object): self._method = method self._timeout = timeout self._is_local = is_local - self._eventloop = eventloop.EventLoop() self._cache = lru_cache.LRUCache(timeout=timeout, close_callback=self._close_client) self._client_fd_to_server_addr = lru_cache.LRUCache(timeout=timeout) + addrs = socket.getaddrinfo(self._listen_addr, self._listen_port, 0, + socket.SOCK_DGRAM, socket.SOL_UDP) + if len(addrs) == 0: + raise Exception("can't get addrinfo for %s:%d" % + (self._listen_addr, self._listen_port)) + af, socktype, proto, canonname, sa = addrs[0] + server_socket = socket.socket(af, socktype, proto) + server_socket.bind((self._listen_addr, self._listen_port)) + server_socket.setblocking(False) + self._server_socket = server_socket + def _close_client(self, client): if hasattr(client, 'close'): self._eventloop.remove(client) @@ -238,6 +248,7 @@ class UDPRelay(object): def _run(self): server_socket = self._server_socket + self._eventloop = eventloop.EventLoop() self._eventloop.add(server_socket, eventloop.POLL_IN) last_time = time.time() while True: @@ -263,19 +274,11 @@ class UDPRelay(object): last_time = now def start(self): - addrs = socket.getaddrinfo(self._listen_addr, self._listen_port, 0, - socket.SOCK_DGRAM, socket.SOL_UDP) - if len(addrs) == 0: - raise Exception("can't get addrinfo for %s:%d" % - (self._listen_addr, self._listen_port)) - af, socktype, proto, canonname, sa = addrs[0] - server_socket = socket.socket(af, socktype, proto) - server_socket.bind((self._listen_addr, self._listen_port)) - server_socket.setblocking(False) - self._server_socket = server_socket - t = threading.Thread(target=self._run) t.setName('UDPThread') - t.setDaemon(True) + t.setDaemon(False) t.start() self._thread = t + + def thread(self): + return self._thread