Browse Source

Support redirect rule (#145)

Syntax:

match_hostname_regex:port(* means match all port)#redirect_dist_host:redirect_port

and in the config.json(user-config.json) redirect param should be a list,

for example

```
"redirect": ["*:8080#zhaoj.in:80","zhaojin97.cn:80#127.0.0.1:80","*#pku.edu.cn:80"]
```

This example means abnormal connection which connect to port 8080 will be redirected to zhaoj.in:80,and if there a http request with hostname zhaojin97.cn(you can set hosts file to test it) will be redirected to 127.0.0.1:80,and the rest of abnormal connection will be redirect to pku.edu.cn:80
dev
glzjin 8 years ago
committed by 破娃酱
parent
commit
c7815a0ee8
  1. 8
      shadowsocks/common.py
  2. 96
      shadowsocks/tcprelay.py

8
shadowsocks/common.py

@ -22,6 +22,7 @@ import socket
import struct
import logging
import binascii
import re
def compat_ord(s):
if type(s) == int:
@ -118,6 +119,13 @@ def is_ip(address):
return False
def match_regex(regex, text):
regex = re.compile(regex)
for item in regex.findall(text):
return True
return False
def patch_socket():
if not hasattr(socket, 'inet_pton'):
socket.inet_pton = inet_pton

96
shadowsocks/tcprelay.py

@ -151,7 +151,7 @@ class TCPRelayHandler(object):
server_info.tcp_mss = 1460
self._protocol.set_server_info(server_info)
self._redir_list = config.get('redirect', ["0.0.0.0:0"])
self._redir_list = config.get('redirect', ["*#0.0.0.0:0"])
self._bind = config.get('out_bind', '')
self._bindv6 = config.get('out_bindv6', '')
self._ignore_bind_list = config.get('ignore_bind', [])
@ -347,43 +347,77 @@ class TCPRelayHandler(object):
return True
def _get_redirect_host(self, client_address, ogn_data):
host_list = self._redir_list or ["0.0.0.0:0"]
hash_code = binascii.crc32(ogn_data)
addrs = socket.getaddrinfo(client_address[0], client_address[1], 0, socket.SOCK_STREAM, socket.SOL_TCP)
af, socktype, proto, canonname, sa = addrs[0]
address_bytes = common.inet_pton(af, sa[0])
if af == socket.AF_INET6:
addr = struct.unpack('>Q', address_bytes[8:])[0]
elif af == socket.AF_INET:
addr = struct.unpack('>I', address_bytes)[0]
else:
addr = 0
host_list = self._redir_list or ["*#0.0.0.0:0"]
host_port = []
match_port = False
if type(host_list) != list:
host_list = [host_list]
for host in host_list:
items = common.to_str(host).rsplit(':', 1)
if len(items) > 1:
try:
port = int(items[1])
if port == self._server._listen_port:
match_port = True
host_port.append((items[0], port))
except:
pass
items_sum = common.to_str(host_list[0]).rsplit('#', 1)
if len(items_sum) < 2:
hash_code = binascii.crc32(ogn_data)
addrs = socket.getaddrinfo(client_address[0], client_address[1], 0, socket.SOCK_STREAM, socket.SOL_TCP)
af, socktype, proto, canonname, sa = addrs[0]
address_bytes = common.inet_pton(af, sa[0])
if af == socket.AF_INET6:
addr = struct.unpack('>Q', address_bytes[8:])[0]
elif af == socket.AF_INET:
addr = struct.unpack('>I', address_bytes)[0]
else:
host_port.append((host, 80))
addr = 0
host_port = []
match_port = False
for host in host_list:
items = common.to_str(host).rsplit(':', 1)
if len(items) > 1:
try:
port = int(items[1])
if port == self._server._listen_port:
match_port = True
host_port.append((items[0], port))
except:
pass
else:
host_port.append((host, 80))
if match_port:
last_host_port = host_port
host_port = []
for host in last_host_port:
if host[1] == self._server._listen_port:
host_port.append(host)
if match_port:
last_host_port = host_port
return host_port[((hash_code & 0xffffffff) + addr) % len(host_port)]
else:
host_port = []
for host in last_host_port:
if host[1] == self._server._listen_port:
host_port.append(host)
for host in host_list:
items_sum = common.to_str(host).rsplit('#', 1)
items_match = common.to_str(items_sum[0]).rsplit(':', 1)
items = common.to_str(items_sum[1]).rsplit(':', 1)
if len(items_match) > 1:
if self._server._listen_port != int(items_match[1]):
continue
match_port = 0
if len(items_match) > 1:
if items_match[1] != "*":
try:
match_port = int(items_match[1])
except:
pass
if items_match[0] != "*" and common.match_regex(items_match[0], ogn_data) == False and \
not (match_port == self._server._listen_port or match_port == 0):
continue
if len(items) > 1:
try:
port = int(items[1])
return (items[0], port)
except:
pass
else:
return (items[0], 80)
return host_port[((hash_code & 0xffffffff) + addr) % len(host_port)]
return ("0.0.0.0", 0)
def _handel_protocol_error(self, client_address, ogn_data):
logging.warn("Protocol ERROR, TCP ogn data %s from %s:%d via port %d" % (binascii.hexlify(ogn_data), client_address[0], client_address[1], self._server._listen_port))

Loading…
Cancel
Save