Browse Source

merge manyuser

master
breakwa11 9 years ago
parent
commit
7df3152608
  1. 38
      shadowsocks/asyncdns.py
  2. 10
      shadowsocks/common.py
  3. 18
      shadowsocks/encrypt.py
  4. 7
      shadowsocks/local.py
  5. 2
      shadowsocks/manager.py
  6. 124
      shadowsocks/obfsplugin/http_simple.py
  7. 255
      shadowsocks/obfsplugin/verify_simple.py
  8. 12
      shadowsocks/server.py
  9. 21
      shadowsocks/tcprelay.py

38
shadowsocks/asyncdns.py

@ -24,6 +24,13 @@ import struct
import re
import logging
if __name__ == '__main__':
import sys
import inspect
file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe())))
os.chdir(file_path)
sys.path.insert(0, os.path.join(file_path, '../'))
from shadowsocks import common, lru_cache, eventloop, shell
@ -71,6 +78,17 @@ QTYPE_CNAME = 5
QTYPE_NS = 2
QCLASS_IN = 1
def detect_ipv6_supprot():
if 'has_ipv6' in dir(socket):
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
try:
s.connect(('ipv6.google.com', 0))
return True
except:
pass
return False
IPV6_CONNECTION_SUPPORT = detect_ipv6_supprot()
def build_address(address):
address = address.strip(b'.')
@ -338,17 +356,17 @@ class DNSResolver(object):
answer[2] == QCLASS_IN:
ip = answer[0]
break
if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \
== STATUS_IPV4:
self._hostname_status[hostname] = STATUS_IPV6
self._send_req(hostname, QTYPE_AAAA)
if not ip and self._hostname_status.get(hostname, STATUS_IPV4) \
== STATUS_IPV6:
self._hostname_status[hostname] = STATUS_IPV4
self._send_req(hostname, QTYPE_A)
else:
if ip:
self._cache[hostname] = ip
self._call_callback(hostname, ip)
elif self._hostname_status.get(hostname, None) == STATUS_IPV6:
elif self._hostname_status.get(hostname, None) == STATUS_IPV4:
for question in response.questions:
if question[1] == QTYPE_AAAA:
if question[1] == QTYPE_A:
self._call_callback(hostname, None)
break
@ -414,6 +432,10 @@ class DNSResolver(object):
return
arr = self._hostname_to_cb.get(hostname, None)
if not arr:
if IPV6_CONNECTION_SUPPORT:
self._hostname_status[hostname] = STATUS_IPV6
self._send_req(hostname, QTYPE_AAAA)
else:
self._hostname_status[hostname] = STATUS_IPV4
self._send_req(hostname, QTYPE_A)
self._hostname_to_cb[hostname] = [callback]
@ -421,6 +443,9 @@ class DNSResolver(object):
else:
arr.append(callback)
# TODO send again only if waited too long
if IPV6_CONNECTION_SUPPORT:
self._send_req(hostname, QTYPE_AAAA)
else:
self._send_req(hostname, QTYPE_A)
def close(self):
@ -479,3 +504,4 @@ def test():
if __name__ == '__main__':
test()

10
shadowsocks/common.py

@ -54,6 +54,16 @@ def to_str(s):
return s.decode('utf-8')
return s
def int32(x):
if x > 0xFFFFFFFF or x < 0:
x &= 0xFFFFFFFF
if x > 0x7FFFFFFF:
x = int(0x100000000 - x)
if x < 0x80000000:
return -x
else:
return -2147483648
return x
def inet_ntop(family, ipstr):
if family == socket.AF_INET:

18
shadowsocks/encrypt.py

@ -77,6 +77,7 @@ class Encryptor(object):
self.iv = None
self.iv_sent = False
self.cipher_iv = b''
self.iv_buf = b''
self.decipher = None
method = method.lower()
self._method_info = self.get_method_info(method)
@ -122,16 +123,21 @@ class Encryptor(object):
def decrypt(self, buf):
if len(buf) == 0:
return buf
if self.decipher is None:
if self.decipher is not None: #optimize
return self.decipher.update(buf)
decipher_iv_len = self._method_info[1]
decipher_iv = buf[:decipher_iv_len]
if len(self.iv_buf) <= decipher_iv_len:
self.iv_buf += buf
if len(self.iv_buf) > decipher_iv_len:
decipher_iv = self.iv_buf[:decipher_iv_len]
self.decipher = self.get_cipher(self.key, self.method, 0,
iv=decipher_iv)
buf = buf[decipher_iv_len:]
if len(buf) == 0:
return buf
buf = self.iv_buf[decipher_iv_len:]
del self.iv_buf
return self.decipher.update(buf)
else:
return b''
def encrypt_all(password, method, op, data):
result = []

7
shadowsocks/local.py

@ -23,7 +23,12 @@ import os
import logging
import signal
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../'))
if __name__ == '__main__':
import inspect
file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe())))
os.chdir(file_path)
sys.path.insert(0, os.path.join(file_path, '../'))
from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, asyncdns

2
shadowsocks/manager.py

@ -168,7 +168,7 @@ class Manager(object):
send_data(r)
r.clear()
i = 0
if len(r) > 0:
if len(r) > 0 :
send_data(r)
self._statistics.clear()

124
shadowsocks/obfsplugin/http_simple.py

@ -22,8 +22,10 @@ import sys
import hashlib
import logging
import binascii
import struct
import base64
import datetime
import random
from shadowsocks import common
from shadowsocks.obfsplugin import plain
@ -66,14 +68,53 @@ class http_simple(plain.plain):
self.host = None
self.port = 0
self.recv_buffer = b''
self.user_agent = [b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/40.0",
b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/44.0",
b"Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36",
b"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/535.11 (KHTML, like Gecko) Ubuntu/11.10 Chromium/27.0.1453.93 Chrome/27.0.1453.93 Safari/537.36",
b"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:35.0) Gecko/20100101 Firefox/35.0",
b"Mozilla/5.0 (compatible; WOW64; MSIE 10.0; Windows NT 6.2)",
b"Mozilla/5.0 (Windows; U; Windows NT 6.1; en-US) AppleWebKit/533.20.25 (KHTML, like Gecko) Version/5.0.4 Safari/533.20.27",
b"Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.3; Trident/7.0; .NET4.0E; .NET4.0C)",
b"Mozilla/5.0 (Windows NT 6.3; Trident/7.0; rv:11.0) like Gecko",
b"Mozilla/5.0 (Linux; Android 4.4; Nexus 5 Build/BuildID) AppleWebKit/537.36 (KHTML, like Gecko) Version/4.0 Chrome/30.0.0.0 Mobile Safari/537.36",
b"Mozilla/5.0 (iPad; CPU OS 5_0 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9A334 Safari/7534.48.3",
b"Mozilla/5.0 (iPhone; CPU iPhone OS 5_0 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9A334 Safari/7534.48.3"]
def encode_head(self, buf):
ret = b''
for ch in buf:
ret += '%' + binascii.hexlify(ch)
return ret
def client_encode(self, buf):
# TODO
if self.has_sent_header:
return buf
if len(buf) > 64:
headlen = random.randint(1, 64)
else:
headlen = len(buf)
headdata = buf[:headlen]
buf = buf[headlen:]
port = b''
if self.server_info.port != 80:
port = b':' + common.to_bytes(str(self.server_info.port))
http_head = b"GET /" + self.encode_head(headdata) + b" HTTP/1.1\r\n"
http_head += b"Host: " + (self.server_info.param or self.server_info.host) + port + b"\r\n"
http_head += b"User-Agent: " + random.choice(self.user_agent) + b"\r\n"
http_head += b"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\nAccept-Language: en-US,en;q=0.8\r\nAccept-Encoding: gzip, deflate\r\nDNT: 1\r\nConnection: keep-alive\r\n\r\n"
self.has_sent_header = True
return http_head + buf
def client_decode(self, buf):
# TODO
if self.has_recv_header:
return (buf, False)
pos = buf.find(b'\r\n\r\n')
if pos >= 0:
self.has_recv_header = True
return (buf[pos + 4:], False)
else:
return (b'', False)
def server_encode(self, buf):
if self.has_sent_header:
@ -110,17 +151,19 @@ class http_simple(plain.plain):
if self.has_recv_header:
return (buf, True, False)
buf = self.recv_buffer + buf
self.recv_buffer += buf
buf = self.recv_buffer
if len(buf) > 10:
if match_begin(buf, b'GET /') or match_begin(buf, b'POST /'):
if len(buf) > 65536:
self.recv_buffer = None
logging.warn('http_simple: over size')
return self.not_match_return(buf)
else: #not http header, run on original protocol
self.recv_buffer = None
logging.debug('http_simple: not match begin')
return self.not_match_return(buf)
else:
self.recv_buffer = buf
return (b'', True, False)
datas = buf.split(b'\r\n\r\n', 1)
@ -130,29 +173,47 @@ class http_simple(plain.plain):
if len(ret_buf) >= 15:
self.has_recv_header = True
return (ret_buf, True, False)
self.recv_buffer = buf
return (b'', True, False)
else:
self.recv_buffer = buf
return (b'', True, False)
return self.not_match_return(buf)
class http2_simple(plain.plain):
def __init__(self, method):
self.method = method
self.has_sent_header = False
self.has_recv_header = False
self.raw_trans_sent = False
self.host = None
self.port = 0
self.recv_buffer = b''
def client_encode(self, buf):
# TODO
if self.raw_trans_sent:
return buf
self.send_buffer += buf
if not self.has_sent_header:
self.has_sent_header = True
http_head = b"GET / HTTP/1.1\r\n"
http_head += b"Host: " + (self.server_info.param or self.server_info.host) + port + b"\r\n"
http_head += b"Connection: Upgrade, HTTP2-Settings\r\nUpgrade: h2c\r\n"
http_head += b"HTTP2-Settings: " + base64.urlsafe_b64encode(buf) + b"\r\n"
return http_head + b"\r\n"
if self.has_recv_header:
ret = self.send_buffer
self.send_buffer = b''
self.raw_trans_sent = True
return ret
return b''
def client_decode(self, buf):
# TODO
if self.has_recv_header:
return (buf, False)
pos = buf.find(b'\r\n\r\n')
if pos >= 0:
self.has_recv_header = True
return (buf[pos + 4:], False)
else:
return (b'', False)
def server_encode(self, buf):
if self.has_sent_header:
@ -173,7 +234,8 @@ class http2_simple(plain.plain):
if self.has_recv_header:
return (buf, True, False)
buf = self.recv_buffer + buf
self.recv_buffer += buf
buf = self.recv_buffer
if len(buf) > 10:
if match_begin(buf, b'GET /'):
pass
@ -181,7 +243,6 @@ class http2_simple(plain.plain):
self.recv_buffer = None
return self.not_match_return(buf)
else:
self.recv_buffer = buf
return (b'', True, False)
datas = buf.split(b'\r\n\r\n', 1)
@ -193,10 +254,8 @@ class http2_simple(plain.plain):
ret_buf += datas[1]
self.has_recv_header = True
return (ret_buf, True, False)
self.recv_buffer = buf
return (b'', True, False)
else:
self.recv_buffer = buf
return (b'', True, False)
return self.not_match_return(buf)
@ -205,13 +264,30 @@ class tls_simple(plain.plain):
self.method = method
self.has_sent_header = False
self.has_recv_header = False
self.raw_trans_sent = False
def client_encode(self, buf):
if self.raw_trans_sent:
return buf
self.send_buffer += buf
if not self.has_sent_header:
self.has_sent_header = True
data = b"\x03\x03" + os.urandom(32) + binascii.unhexlify(b"000016c02bc02fc00ac009c013c01400330039002f0035000a0100006fff01000100000a00080006001700180019000b0002010000230000337400000010002900270568322d31360568322d31350568322d313402683208737064792f332e3108687474702f312e31000500050100000000000d001600140401050106010201040305030603020304020202")
data = b"\x01\x00" + struct.pack('>H', len(data)) + data
data = b"\x16\x03\x01" + struct.pack('>H', len(data)) + data
return data
if self.has_recv_header:
ret = self.send_buffer
self.send_buffer = b''
self.raw_trans_sent = True
return ret
return b''
def client_decode(self, buf):
# (buffer_to_recv, is_need_to_encode_and_send_back)
if self.has_recv_header:
return (buf, False)
self.has_recv_header = True
return (b'', True)
def server_encode(self, buf):
if self.has_sent_header:
@ -239,13 +315,31 @@ class random_head(plain.plain):
self.method = method
self.has_sent_header = False
self.has_recv_header = False
self.raw_trans_sent = False
self.raw_trans_recv = False
self.send_buffer = b''
def client_encode(self, buf):
if self.raw_trans_sent:
return buf
self.send_buffer += buf
if not self.has_sent_header:
self.has_sent_header = True
data = os.urandom(common.ord(os.urandom(1)[0]) % 96 + 4)
crc = (0xffffffff - binascii.crc32(data)) & 0xffffffff
return data + struct.pack('<I', crc)
if self.raw_trans_recv:
ret = self.send_buffer
self.send_buffer = b''
self.raw_trans_sent = True
return ret
return b''
def client_decode(self, buf):
# (buffer_to_recv, is_need_to_encode_and_send_back)
if self.raw_trans_recv:
return (buf, False)
self.raw_trans_recv = True
return (b'', True)
def server_encode(self, buf):
if self.has_sent_header:

255
shadowsocks/obfsplugin/verify_simple.py

@ -23,7 +23,9 @@ import hashlib
import logging
import binascii
import base64
import time
import datetime
import random
import struct
import zlib
@ -38,9 +40,13 @@ def create_verify_obfs(method):
def create_verify_deflate(method):
return verify_deflate(method)
def create_auth_obfs(method):
return auth_simple(method)
obfs_map = {
'verify_simple': (create_verify_obfs,),
'verify_deflate': (create_verify_deflate,),
'auth_simple': (create_auth_obfs,),
}
def match_begin(str1, str2):
@ -49,7 +55,7 @@ def match_begin(str1, str2):
return True
return False
class sub_encode_obfs(object):
class obfs_verify_data(object):
def __init__(self):
self.sub_obfs = None
@ -60,7 +66,7 @@ class verify_base(plain.plain):
self.sub_obfs = None
def init_data(self):
return sub_encode_obfs()
return obfs_verify_data()
def set_server_info(self, server_info):
try:
@ -283,3 +289,248 @@ class verify_deflate(verify_base):
self.decrypt_packet_num += 1
return out_buf
class client_queue(object):
def __init__(self, begin_id):
self.front = begin_id
self.back = begin_id
self.alloc = {}
self.enable = True
self.last_update = time.time()
def update(self):
self.last_update = time.time()
def is_active(self):
return time.time() - self.last_update < 60 * 3
def re_enable(self, connection_id):
self.enable = True
self.alloc = {}
self.front = connection_id
self.back = connection_id
def insert(self, connection_id):
self.update()
if not self.enable:
logging.warn('auth_simple: not enable')
return False
if connection_id < self.front:
logging.warn('auth_simple: duplicate id')
return False
if not self.is_active():
self.re_enable(connection_id)
if connection_id > self.front + 0x4000:
logging.warn('auth_simple: wrong id')
return False
if connection_id in self.alloc:
logging.warn('auth_simple: duplicate id 2')
return False
if self.back <= connection_id:
self.back = connection_id + 1
self.alloc[connection_id] = 1
while (self.front in self.alloc) or self.front + 0x1000 < self.back:
if self.front in self.alloc:
del self.alloc[self.front]
self.front += 1
return True
class obfs_auth_data(object):
def __init__(self):
self.sub_obfs = None
self.client_id = {}
self.startup_time = int(time.time() - 30) & 0xFFFFFFFF
self.local_client_id = b''
self.connection_id = 0
def update(self, client_id, connection_id):
if client_id in self.client_id:
self.client_id[client_id].update()
def insert(self, client_id, connection_id):
max_client = 16
if client_id not in self.client_id or not self.client_id[client_id].enable:
active = 0
for c_id in self.client_id:
if self.client_id[c_id].is_active():
active += 1
if active >= max_client:
logging.warn('auth_simple: max active clients exceeded')
return False
if len(self.client_id) < max_client:
if client_id not in self.client_id:
self.client_id[client_id] = client_queue(connection_id)
else:
self.client_id[client_id].re_enable(connection_id)
return self.client_id[client_id].insert(connection_id)
keys = self.client_id.keys()
random.shuffle(keys)
for c_id in keys:
if not self.client_id[c_id].is_active() and self.client_id[c_id].enable:
if len(self.client_id) >= 256:
del self.client_id[c_id]
else:
self.client_id[c_id].enable = False
if client_id not in self.client_id:
self.client_id[client_id] = client_queue(connection_id)
else:
self.client_id[client_id].re_enable(connection_id)
return self.client_id[client_id].insert(connection_id)
logging.warn('auth_simple: no inactive client [assert]')
return False
else:
return self.client_id[client_id].insert(connection_id)
class auth_simple(verify_base):
def __init__(self, method):
super(auth_simple, self).__init__(method)
self.recv_buf = b''
self.unit_len = 8100
self.decrypt_packet_num = 0
self.raw_trans = False
self.has_sent_header = False
self.has_recv_header = False
self.client_id = 0
self.connection_id = 0
def init_data(self):
return obfs_auth_data()
def pack_data(self, buf):
if len(buf) == 0:
return b''
rnd_data = os.urandom(common.ord(os.urandom(1)[0]) % 16)
data = common.chr(len(rnd_data) + 1) + rnd_data + buf
data = struct.pack('>H', len(data) + 6) + data
crc = (0xffffffff - binascii.crc32(data)) & 0xffffffff
data += struct.pack('<I', crc)
return data
def auth_data(self):
utc_time = int(time.time()) & 0xFFFFFFFF
if self.server_info.data.connection_id > 0xFF000000:
self.server_info.data.local_client_id = b''
if not self.server_info.data.local_client_id:
self.server_info.data.local_client_id = os.urandom(4)
logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),))
self.server_info.data.connection_id = struct.unpack('<I', os.urandom(4))[0] & 0xFFFFFF
self.server_info.data.connection_id += 1
return b''.join([struct.pack('<I', utc_time),
self.server_info.data.local_client_id,
struct.pack('<I', self.server_info.data.connection_id)])
def client_pre_encrypt(self, buf):
ret = b''
if not self.has_sent_header:
datalen = max(len(buf), common.ord(os.urandom(1)[0]) % 32 + 4)
ret += self.pack_data(self.auth_data() + buf[:datalen])
buf = buf[datalen:]
self.has_sent_header = True
while len(buf) > self.unit_len:
ret += self.pack_data(buf[:self.unit_len])
buf = buf[self.unit_len:]
ret += self.pack_data(buf)
return ret
def client_post_decrypt(self, buf):
if self.raw_trans:
return buf
self.recv_buf += buf
out_buf = b''
while len(self.recv_buf) > 2:
length = struct.unpack('>H', self.recv_buf[:2])[0]
if length >= 8192:
self.raw_trans = True
self.recv_buf = b''
if self.decrypt_packet_num == 0:
return None
else:
raise Exception('server_post_decrype data error')
if length > len(self.recv_buf):
break
if (binascii.crc32(self.recv_buf[:length]) & 0xffffffff) != 0xffffffff:
self.raw_trans = True
self.recv_buf = b''
if self.decrypt_packet_num == 0:
return None
else:
raise Exception('server_post_decrype data uncorrect CRC32')
pos = common.ord(self.recv_buf[2]) + 2
out_buf += self.recv_buf[pos:length - 4]
self.recv_buf = self.recv_buf[length:]
if out_buf:
self.decrypt_packet_num += 1
return out_buf
def server_pre_encrypt(self, buf):
ret = b''
while len(buf) > self.unit_len:
ret += self.pack_data(buf[:self.unit_len])
buf = buf[self.unit_len:]
ret += self.pack_data(buf)
return ret
def server_post_decrypt(self, buf):
if self.raw_trans:
return buf
self.recv_buf += buf
out_buf = b''
while len(self.recv_buf) > 2:
length = struct.unpack('>H', self.recv_buf[:2])[0]
if length >= 8192:
self.raw_trans = True
self.recv_buf = b''
if self.decrypt_packet_num == 0:
logging.info('auth_simple: over size')
return b'E'
else:
raise Exception('server_post_decrype data error')
if length > len(self.recv_buf):
break
if (binascii.crc32(self.recv_buf[:length]) & 0xffffffff) != 0xffffffff:
logging.info('auth_simple: crc32 error, data %s' % (binascii.hexlify(self.recv_buf[:length]),))
self.raw_trans = True
self.recv_buf = b''
if self.decrypt_packet_num == 0:
return b'E'
else:
raise Exception('server_post_decrype data uncorrect CRC32')
pos = common.ord(self.recv_buf[2]) + 2
out_buf += self.recv_buf[pos:length - 4]
if not self.has_recv_header:
if len(out_buf) < 12:
self.raw_trans = True
self.recv_buf = b''
logging.info('auth_simple: too short')
return b'E'
utc_time = struct.unpack('<I', out_buf[:4])[0]
client_id = struct.unpack('<I', out_buf[4:8])[0]
connection_id = struct.unpack('<I', out_buf[8:12])[0]
time_dif = common.int32((int(time.time()) & 0xffffffff) - utc_time)
if time_dif < 60 * -3 or time_dif > 60 * 3 or common.int32(utc_time - self.server_info.data.startup_time) < 0:
self.raw_trans = True
self.recv_buf = b''
logging.info('auth_simple: wrong timestamp, time_dif %d, data %s' % (time_dif, binascii.hexlify(out_buf),))
return b'E'
elif self.server_info.data.insert(client_id, connection_id):
self.has_recv_header = True
out_buf = out_buf[12:]
self.client_id = client_id
self.connection_id = connection_id
else:
self.raw_trans = True
self.recv_buf = b''
logging.info('auth_simple: auth fail, data %s' % (binascii.hexlify(out_buf),))
return b'E'
self.recv_buf = self.recv_buf[length:]
if out_buf:
self.server_info.data.update(self.client_id, self.connection_id)
self.decrypt_packet_num += 1
return out_buf

12
shadowsocks/server.py

@ -23,7 +23,12 @@ import os
import logging
import signal
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../'))
if __name__ == '__main__':
import inspect
file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe())))
os.chdir(file_path)
sys.path.insert(0, os.path.join(file_path, '../'))
from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, \
asyncdns, manager
@ -68,7 +73,8 @@ def main():
obfs = config["obfs"]
a_config = config.copy()
ipv6_ok = False
logging.info("server start with password [%s] obfs [%s] method [%s]" % (password, obfs, a_config['method']))
logging.info("server start with password [%s] method [%s] obfs [%s] obfs_param [%s]" %
(password, a_config['method'], obfs, a_config['obfs_param']))
if 'server_ipv6' in a_config:
try:
if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == "[" and a_config['server_ipv6'][-1] == "]":
@ -77,7 +83,7 @@ def main():
a_config['password'] = password
a_config['obfs'] = obfs
a_config['server'] = a_config['server_ipv6']
logging.info("starting server at %s:%d" %
logging.info("starting server at [%s]:%d" %
(a_config['server'], int(port)))
tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False))
udp_servers.append(udprelay.UDPRelay(a_config, dns_resolver, False))

21
shadowsocks/tcprelay.py

@ -30,11 +30,8 @@ import random
from shadowsocks import encrypt, obfs, eventloop, shell, common
from shadowsocks.common import pre_parse_header, parse_header
# set it 'False' to use both new protocol and the original shadowsocks protocal
# set it 'True' to use new protocol ONLY, to avoid GFW detecting
FORCE_NEW_PROTOCOL = False
# set it 'True' if run as a local client and connect to a server which support new protocol
CLIENT_NEW_PROTOCOL = False
CLIENT_NEW_PROTOCOL = False #deprecated
# we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time
TIMEOUTS_CLEAN_SIZE = 512
@ -118,8 +115,6 @@ class TCPRelayHandler(object):
config['method'])
self._encrypt_correct = True
self._obfs = obfs.obfs(config['obfs'])
if server.obfs_data is None:
server.obfs_data = self._obfs.init_data()
server_info = obfs.server_info(server.obfs_data)
server_info.host = config['server']
server_info.port = server._listen_port
@ -268,8 +263,8 @@ class TCPRelayHandler(object):
if sock == self._local_sock and self._encrypt_correct:
obfs_encode = self._obfs.server_encode(data)
data = obfs_encode
if data:
l = len(data)
if l > 0:
s = sock.send(data)
if s < l:
data = data[s:]
@ -310,7 +305,7 @@ class TCPRelayHandler(object):
def _get_redirect_host(self, client_address, ogn_data):
# test
host_list = [(b"www.bing.com", 80), (b"www.microsoft.com", 80), (b"www.baidu.com", 443), (b"www.qq.com", 80), (b"www.csdn.net", 80), (b"1.2.3.4", 1000)]
host_list = [(b"www.bing.com", 80), (b"www.microsoft.com", 80), (b"cloudfront.com", 80), (b"cloudflare.com", 80), (b"1.2.3.4", 1000), (b"0.0.0.0", 0)]
hash_code = binascii.crc32(ogn_data)
addrs = socket.getaddrinfo(client_address[0], client_address[1], 0, socket.SOCK_STREAM, socket.SOL_TCP)
af, socktype, proto, canonname, sa = addrs[0]
@ -338,6 +333,7 @@ class TCPRelayHandler(object):
data = self._obfs.client_pre_encrypt(data)
data = self._encryptor.encrypt(data)
data = self._obfs.client_encode(data)
if data:
self._data_to_write_to_remote.append(data)
if self._is_local and not self._fastopen_connected and \
self._config['fast_open']:
@ -404,8 +400,6 @@ class TCPRelayHandler(object):
if self._is_local:
header_result = parse_header(data)
else:
if data is None or FORCE_NEW_PROTOCOL and common.ord(data[0]) != 0x88 and (~common.ord(data[0]) & 0xff) != 0x88:
data = self._handel_protocol_error(self._client_address, ogn_data)
data = pre_parse_header(data)
if data is None:
data = self._handel_protocol_error(self._client_address, ogn_data)
@ -436,6 +430,8 @@ class TCPRelayHandler(object):
data += struct.pack('<I', crc)
data = self._obfs.client_pre_encrypt(data)
data_to_send = self._encryptor.encrypt(data)
data_to_send = self._obfs.client_encode(data_to_send)
if data_to_send:
self._data_to_write_to_remote.append(data_to_send)
# notice here may go into _handle_dns_resolved directly
self._dns_resolver.resolve(self._chosen_server[0],
@ -635,7 +631,8 @@ class TCPRelayHandler(object):
if self._is_local:
obfs_decode = self._obfs.client_decode(data)
if obfs_decode[1]:
self._write_to_sock(b'', self._remote_sock)
send_back = self._obfs.client_encode(b'')
self._write_to_sock(send_back, self._remote_sock)
data = self._encryptor.decrypt(obfs_decode[0])
data = self._obfs.client_post_decrypt(data)
else:
@ -774,7 +771,7 @@ class TCPRelay(object):
self.server_transfer_ul = 0
self.server_transfer_dl = 0
self.server_connections = 0
self.obfs_data = None
self.obfs_data = obfs.obfs(config['obfs']).init_data()
self._timeout = config['timeout']
self._timeouts = [] # a list for all the handlers

Loading…
Cancel
Save