diff --git a/shadowsocks/common.py b/shadowsocks/common.py index e1f574c..fdb2243 100644 --- a/shadowsocks/common.py +++ b/shadowsocks/common.py @@ -22,6 +22,7 @@ import socket import struct import logging import binascii +import re def compat_ord(s): if type(s) == int: @@ -118,6 +119,13 @@ def is_ip(address): return False +def match_regex(regex, text): + regex = re.compile(regex) + for item in regex.findall(text): + return True + return False + + def patch_socket(): if not hasattr(socket, 'inet_pton'): socket.inet_pton = inet_pton diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 6678301..d8c5ca1 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -151,7 +151,7 @@ class TCPRelayHandler(object): server_info.tcp_mss = 1460 self._protocol.set_server_info(server_info) - self._redir_list = config.get('redirect', ["0.0.0.0:0"]) + self._redir_list = config.get('redirect', ["*#0.0.0.0:0"]) self._bind = config.get('out_bind', '') self._bindv6 = config.get('out_bindv6', '') self._ignore_bind_list = config.get('ignore_bind', []) @@ -347,43 +347,77 @@ class TCPRelayHandler(object): return True def _get_redirect_host(self, client_address, ogn_data): - host_list = self._redir_list or ["0.0.0.0:0"] - hash_code = binascii.crc32(ogn_data) - addrs = socket.getaddrinfo(client_address[0], client_address[1], 0, socket.SOCK_STREAM, socket.SOL_TCP) - af, socktype, proto, canonname, sa = addrs[0] - address_bytes = common.inet_pton(af, sa[0]) - if af == socket.AF_INET6: - addr = struct.unpack('>Q', address_bytes[8:])[0] - elif af == socket.AF_INET: - addr = struct.unpack('>I', address_bytes)[0] - else: - addr = 0 + host_list = self._redir_list or ["*#0.0.0.0:0"] - host_port = [] - match_port = False if type(host_list) != list: host_list = [host_list] - for host in host_list: - items = common.to_str(host).rsplit(':', 1) - if len(items) > 1: - try: - port = int(items[1]) - if port == self._server._listen_port: - match_port = True - host_port.append((items[0], port)) - except: - pass + + items_sum = common.to_str(host_list[0]).rsplit('#', 1) + if len(items_sum) < 2: + hash_code = binascii.crc32(ogn_data) + addrs = socket.getaddrinfo(client_address[0], client_address[1], 0, socket.SOCK_STREAM, socket.SOL_TCP) + af, socktype, proto, canonname, sa = addrs[0] + address_bytes = common.inet_pton(af, sa[0]) + if af == socket.AF_INET6: + addr = struct.unpack('>Q', address_bytes[8:])[0] + elif af == socket.AF_INET: + addr = struct.unpack('>I', address_bytes)[0] else: - host_port.append((host, 80)) + addr = 0 + + host_port = [] + match_port = False + for host in host_list: + items = common.to_str(host).rsplit(':', 1) + if len(items) > 1: + try: + port = int(items[1]) + if port == self._server._listen_port: + match_port = True + host_port.append((items[0], port)) + except: + pass + else: + host_port.append((host, 80)) + + if match_port: + last_host_port = host_port + host_port = [] + for host in last_host_port: + if host[1] == self._server._listen_port: + host_port.append(host) - if match_port: - last_host_port = host_port + return host_port[((hash_code & 0xffffffff) + addr) % len(host_port)] + + else: host_port = [] - for host in last_host_port: - if host[1] == self._server._listen_port: - host_port.append(host) + for host in host_list: + items_sum = common.to_str(host).rsplit('#', 1) + items_match = common.to_str(items_sum[0]).rsplit(':', 1) + items = common.to_str(items_sum[1]).rsplit(':', 1) + if len(items_match) > 1: + if self._server._listen_port != int(items_match[1]): + continue + match_port = 0 + if len(items_match) > 1: + if items_match[1] != "*": + try: + match_port = int(items_match[1]) + except: + pass + if items_match[0] != "*" and common.match_regex(items_match[0], ogn_data) == False and \ + not (match_port == self._server._listen_port or match_port == 0): + continue + if len(items) > 1: + try: + port = int(items[1]) + return (items[0], port) + except: + pass + else: + return (items[0], 80) - return host_port[((hash_code & 0xffffffff) + addr) % len(host_port)] + return ("0.0.0.0", 0) def _handel_protocol_error(self, client_address, ogn_data): logging.warn("Protocol ERROR, TCP ogn data %s from %s:%d via port %d" % (binascii.hexlify(ogn_data), client_address[0], client_address[1], self._server._listen_port))