Browse Source

merge manyuser

master
breakwa11 9 years ago
parent
commit
7df3152608
  1. 38
      shadowsocks/asyncdns.py
  2. 10
      shadowsocks/common.py
  3. 18
      shadowsocks/encrypt.py
  4. 7
      shadowsocks/local.py
  5. 124
      shadowsocks/obfsplugin/http_simple.py
  6. 255
      shadowsocks/obfsplugin/verify_simple.py
  7. 12
      shadowsocks/server.py
  8. 21
      shadowsocks/tcprelay.py

38
shadowsocks/asyncdns.py

@ -24,6 +24,13 @@ import struct
import re import re
import logging import logging
if __name__ == '__main__':
import sys
import inspect
file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe())))
os.chdir(file_path)
sys.path.insert(0, os.path.join(file_path, '../'))
from shadowsocks import common, lru_cache, eventloop, shell from shadowsocks import common, lru_cache, eventloop, shell
@ -71,6 +78,17 @@ QTYPE_CNAME = 5
QTYPE_NS = 2 QTYPE_NS = 2
QCLASS_IN = 1 QCLASS_IN = 1
def detect_ipv6_supprot():
if 'has_ipv6' in dir(socket):
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
try:
s.connect(('ipv6.google.com', 0))
return True
except:
pass
return False
IPV6_CONNECTION_SUPPORT = detect_ipv6_supprot()
def build_address(address): def build_address(address):
address = address.strip(b'.') address = address.strip(b'.')
@ -338,17 +356,17 @@ class DNSResolver(object):
answer[2] == QCLASS_IN: answer[2] == QCLASS_IN:
ip = answer[0] ip = answer[0]
break break
if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \ if not ip and self._hostname_status.get(hostname, STATUS_IPV4) \
== STATUS_IPV4: == STATUS_IPV6:
self._hostname_status[hostname] = STATUS_IPV6 self._hostname_status[hostname] = STATUS_IPV4
self._send_req(hostname, QTYPE_AAAA) self._send_req(hostname, QTYPE_A)
else: else:
if ip: if ip:
self._cache[hostname] = ip self._cache[hostname] = ip
self._call_callback(hostname, ip) self._call_callback(hostname, ip)
elif self._hostname_status.get(hostname, None) == STATUS_IPV6: elif self._hostname_status.get(hostname, None) == STATUS_IPV4:
for question in response.questions: for question in response.questions:
if question[1] == QTYPE_AAAA: if question[1] == QTYPE_A:
self._call_callback(hostname, None) self._call_callback(hostname, None)
break break
@ -414,6 +432,10 @@ class DNSResolver(object):
return return
arr = self._hostname_to_cb.get(hostname, None) arr = self._hostname_to_cb.get(hostname, None)
if not arr: if not arr:
if IPV6_CONNECTION_SUPPORT:
self._hostname_status[hostname] = STATUS_IPV6
self._send_req(hostname, QTYPE_AAAA)
else:
self._hostname_status[hostname] = STATUS_IPV4 self._hostname_status[hostname] = STATUS_IPV4
self._send_req(hostname, QTYPE_A) self._send_req(hostname, QTYPE_A)
self._hostname_to_cb[hostname] = [callback] self._hostname_to_cb[hostname] = [callback]
@ -421,6 +443,9 @@ class DNSResolver(object):
else: else:
arr.append(callback) arr.append(callback)
# TODO send again only if waited too long # TODO send again only if waited too long
if IPV6_CONNECTION_SUPPORT:
self._send_req(hostname, QTYPE_AAAA)
else:
self._send_req(hostname, QTYPE_A) self._send_req(hostname, QTYPE_A)
def close(self): def close(self):
@ -479,3 +504,4 @@ def test():
if __name__ == '__main__': if __name__ == '__main__':
test() test()

10
shadowsocks/common.py

@ -54,6 +54,16 @@ def to_str(s):
return s.decode('utf-8') return s.decode('utf-8')
return s return s
def int32(x):
if x > 0xFFFFFFFF or x < 0:
x &= 0xFFFFFFFF
if x > 0x7FFFFFFF:
x = int(0x100000000 - x)
if x < 0x80000000:
return -x
else:
return -2147483648
return x
def inet_ntop(family, ipstr): def inet_ntop(family, ipstr):
if family == socket.AF_INET: if family == socket.AF_INET:

18
shadowsocks/encrypt.py

@ -77,6 +77,7 @@ class Encryptor(object):
self.iv = None self.iv = None
self.iv_sent = False self.iv_sent = False
self.cipher_iv = b'' self.cipher_iv = b''
self.iv_buf = b''
self.decipher = None self.decipher = None
method = method.lower() method = method.lower()
self._method_info = self.get_method_info(method) self._method_info = self.get_method_info(method)
@ -122,16 +123,21 @@ class Encryptor(object):
def decrypt(self, buf): def decrypt(self, buf):
if len(buf) == 0: if len(buf) == 0:
return buf return buf
if self.decipher is None: if self.decipher is not None: #optimize
return self.decipher.update(buf)
decipher_iv_len = self._method_info[1] decipher_iv_len = self._method_info[1]
decipher_iv = buf[:decipher_iv_len] if len(self.iv_buf) <= decipher_iv_len:
self.iv_buf += buf
if len(self.iv_buf) > decipher_iv_len:
decipher_iv = self.iv_buf[:decipher_iv_len]
self.decipher = self.get_cipher(self.key, self.method, 0, self.decipher = self.get_cipher(self.key, self.method, 0,
iv=decipher_iv) iv=decipher_iv)
buf = buf[decipher_iv_len:] buf = self.iv_buf[decipher_iv_len:]
if len(buf) == 0: del self.iv_buf
return buf
return self.decipher.update(buf) return self.decipher.update(buf)
else:
return b''
def encrypt_all(password, method, op, data): def encrypt_all(password, method, op, data):
result = [] result = []

7
shadowsocks/local.py

@ -23,7 +23,12 @@ import os
import logging import logging
import signal import signal
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../')) if __name__ == '__main__':
import inspect
file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe())))
os.chdir(file_path)
sys.path.insert(0, os.path.join(file_path, '../'))
from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, asyncdns from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, asyncdns

124
shadowsocks/obfsplugin/http_simple.py

@ -22,8 +22,10 @@ import sys
import hashlib import hashlib
import logging import logging
import binascii import binascii
import struct
import base64 import base64
import datetime import datetime
import random
from shadowsocks import common from shadowsocks import common
from shadowsocks.obfsplugin import plain from shadowsocks.obfsplugin import plain
@ -66,14 +68,53 @@ class http_simple(plain.plain):
self.host = None self.host = None
self.port = 0 self.port = 0
self.recv_buffer = b'' self.recv_buffer = b''
self.user_agent = [b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/40.0",
b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/44.0",
b"Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36",
b"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/535.11 (KHTML, like Gecko) Ubuntu/11.10 Chromium/27.0.1453.93 Chrome/27.0.1453.93 Safari/537.36",
b"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:35.0) Gecko/20100101 Firefox/35.0",
b"Mozilla/5.0 (compatible; WOW64; MSIE 10.0; Windows NT 6.2)",
b"Mozilla/5.0 (Windows; U; Windows NT 6.1; en-US) AppleWebKit/533.20.25 (KHTML, like Gecko) Version/5.0.4 Safari/533.20.27",
b"Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.3; Trident/7.0; .NET4.0E; .NET4.0C)",
b"Mozilla/5.0 (Windows NT 6.3; Trident/7.0; rv:11.0) like Gecko",
b"Mozilla/5.0 (Linux; Android 4.4; Nexus 5 Build/BuildID) AppleWebKit/537.36 (KHTML, like Gecko) Version/4.0 Chrome/30.0.0.0 Mobile Safari/537.36",
b"Mozilla/5.0 (iPad; CPU OS 5_0 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9A334 Safari/7534.48.3",
b"Mozilla/5.0 (iPhone; CPU iPhone OS 5_0 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9A334 Safari/7534.48.3"]
def encode_head(self, buf):
ret = b''
for ch in buf:
ret += '%' + binascii.hexlify(ch)
return ret
def client_encode(self, buf): def client_encode(self, buf):
# TODO if self.has_sent_header:
return buf return buf
if len(buf) > 64:
headlen = random.randint(1, 64)
else:
headlen = len(buf)
headdata = buf[:headlen]
buf = buf[headlen:]
port = b''
if self.server_info.port != 80:
port = b':' + common.to_bytes(str(self.server_info.port))
http_head = b"GET /" + self.encode_head(headdata) + b" HTTP/1.1\r\n"
http_head += b"Host: " + (self.server_info.param or self.server_info.host) + port + b"\r\n"
http_head += b"User-Agent: " + random.choice(self.user_agent) + b"\r\n"
http_head += b"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\nAccept-Language: en-US,en;q=0.8\r\nAccept-Encoding: gzip, deflate\r\nDNT: 1\r\nConnection: keep-alive\r\n\r\n"
self.has_sent_header = True
return http_head + buf
def client_decode(self, buf): def client_decode(self, buf):
# TODO if self.has_recv_header:
return (buf, False) return (buf, False)
pos = buf.find(b'\r\n\r\n')
if pos >= 0:
self.has_recv_header = True
return (buf[pos + 4:], False)
else:
return (b'', False)
def server_encode(self, buf): def server_encode(self, buf):
if self.has_sent_header: if self.has_sent_header:
@ -110,17 +151,19 @@ class http_simple(plain.plain):
if self.has_recv_header: if self.has_recv_header:
return (buf, True, False) return (buf, True, False)
buf = self.recv_buffer + buf self.recv_buffer += buf
buf = self.recv_buffer
if len(buf) > 10: if len(buf) > 10:
if match_begin(buf, b'GET /') or match_begin(buf, b'POST /'): if match_begin(buf, b'GET /') or match_begin(buf, b'POST /'):
if len(buf) > 65536: if len(buf) > 65536:
self.recv_buffer = None self.recv_buffer = None
logging.warn('http_simple: over size')
return self.not_match_return(buf) return self.not_match_return(buf)
else: #not http header, run on original protocol else: #not http header, run on original protocol
self.recv_buffer = None self.recv_buffer = None
logging.debug('http_simple: not match begin')
return self.not_match_return(buf) return self.not_match_return(buf)
else: else:
self.recv_buffer = buf
return (b'', True, False) return (b'', True, False)
datas = buf.split(b'\r\n\r\n', 1) datas = buf.split(b'\r\n\r\n', 1)
@ -130,29 +173,47 @@ class http_simple(plain.plain):
if len(ret_buf) >= 15: if len(ret_buf) >= 15:
self.has_recv_header = True self.has_recv_header = True
return (ret_buf, True, False) return (ret_buf, True, False)
self.recv_buffer = buf
return (b'', True, False) return (b'', True, False)
else: else:
self.recv_buffer = buf
return (b'', True, False) return (b'', True, False)
return self.not_match_return(buf)
class http2_simple(plain.plain): class http2_simple(plain.plain):
def __init__(self, method): def __init__(self, method):
self.method = method self.method = method
self.has_sent_header = False self.has_sent_header = False
self.has_recv_header = False self.has_recv_header = False
self.raw_trans_sent = False
self.host = None self.host = None
self.port = 0 self.port = 0
self.recv_buffer = b'' self.recv_buffer = b''
def client_encode(self, buf): def client_encode(self, buf):
# TODO if self.raw_trans_sent:
return buf return buf
self.send_buffer += buf
if not self.has_sent_header:
self.has_sent_header = True
http_head = b"GET / HTTP/1.1\r\n"
http_head += b"Host: " + (self.server_info.param or self.server_info.host) + port + b"\r\n"
http_head += b"Connection: Upgrade, HTTP2-Settings\r\nUpgrade: h2c\r\n"
http_head += b"HTTP2-Settings: " + base64.urlsafe_b64encode(buf) + b"\r\n"
return http_head + b"\r\n"
if self.has_recv_header:
ret = self.send_buffer
self.send_buffer = b''
self.raw_trans_sent = True
return ret
return b''
def client_decode(self, buf): def client_decode(self, buf):
# TODO if self.has_recv_header:
return (buf, False) return (buf, False)
pos = buf.find(b'\r\n\r\n')
if pos >= 0:
self.has_recv_header = True
return (buf[pos + 4:], False)
else:
return (b'', False)
def server_encode(self, buf): def server_encode(self, buf):
if self.has_sent_header: if self.has_sent_header:
@ -173,7 +234,8 @@ class http2_simple(plain.plain):
if self.has_recv_header: if self.has_recv_header:
return (buf, True, False) return (buf, True, False)
buf = self.recv_buffer + buf self.recv_buffer += buf
buf = self.recv_buffer
if len(buf) > 10: if len(buf) > 10:
if match_begin(buf, b'GET /'): if match_begin(buf, b'GET /'):
pass pass
@ -181,7 +243,6 @@ class http2_simple(plain.plain):
self.recv_buffer = None self.recv_buffer = None
return self.not_match_return(buf) return self.not_match_return(buf)
else: else:
self.recv_buffer = buf
return (b'', True, False) return (b'', True, False)
datas = buf.split(b'\r\n\r\n', 1) datas = buf.split(b'\r\n\r\n', 1)
@ -193,10 +254,8 @@ class http2_simple(plain.plain):
ret_buf += datas[1] ret_buf += datas[1]
self.has_recv_header = True self.has_recv_header = True
return (ret_buf, True, False) return (ret_buf, True, False)
self.recv_buffer = buf
return (b'', True, False) return (b'', True, False)
else: else:
self.recv_buffer = buf
return (b'', True, False) return (b'', True, False)
return self.not_match_return(buf) return self.not_match_return(buf)
@ -205,13 +264,30 @@ class tls_simple(plain.plain):
self.method = method self.method = method
self.has_sent_header = False self.has_sent_header = False
self.has_recv_header = False self.has_recv_header = False
self.raw_trans_sent = False
def client_encode(self, buf): def client_encode(self, buf):
if self.raw_trans_sent:
return buf return buf
self.send_buffer += buf
if not self.has_sent_header:
self.has_sent_header = True
data = b"\x03\x03" + os.urandom(32) + binascii.unhexlify(b"000016c02bc02fc00ac009c013c01400330039002f0035000a0100006fff01000100000a00080006001700180019000b0002010000230000337400000010002900270568322d31360568322d31350568322d313402683208737064792f332e3108687474702f312e31000500050100000000000d001600140401050106010201040305030603020304020202")
data = b"\x01\x00" + struct.pack('>H', len(data)) + data
data = b"\x16\x03\x01" + struct.pack('>H', len(data)) + data
return data
if self.has_recv_header:
ret = self.send_buffer
self.send_buffer = b''
self.raw_trans_sent = True
return ret
return b''
def client_decode(self, buf): def client_decode(self, buf):
# (buffer_to_recv, is_need_to_encode_and_send_back) if self.has_recv_header:
return (buf, False) return (buf, False)
self.has_recv_header = True
return (b'', True)
def server_encode(self, buf): def server_encode(self, buf):
if self.has_sent_header: if self.has_sent_header:
@ -239,13 +315,31 @@ class random_head(plain.plain):
self.method = method self.method = method
self.has_sent_header = False self.has_sent_header = False
self.has_recv_header = False self.has_recv_header = False
self.raw_trans_sent = False
self.raw_trans_recv = False
self.send_buffer = b''
def client_encode(self, buf): def client_encode(self, buf):
if self.raw_trans_sent:
return buf return buf
self.send_buffer += buf
if not self.has_sent_header:
self.has_sent_header = True
data = os.urandom(common.ord(os.urandom(1)[0]) % 96 + 4)
crc = (0xffffffff - binascii.crc32(data)) & 0xffffffff
return data + struct.pack('<I', crc)
if self.raw_trans_recv:
ret = self.send_buffer
self.send_buffer = b''
self.raw_trans_sent = True
return ret
return b''
def client_decode(self, buf): def client_decode(self, buf):
# (buffer_to_recv, is_need_to_encode_and_send_back) if self.raw_trans_recv:
return (buf, False) return (buf, False)
self.raw_trans_recv = True
return (b'', True)
def server_encode(self, buf): def server_encode(self, buf):
if self.has_sent_header: if self.has_sent_header:

255
shadowsocks/obfsplugin/verify_simple.py

@ -23,7 +23,9 @@ import hashlib
import logging import logging
import binascii import binascii
import base64 import base64
import time
import datetime import datetime
import random
import struct import struct
import zlib import zlib
@ -38,9 +40,13 @@ def create_verify_obfs(method):
def create_verify_deflate(method): def create_verify_deflate(method):
return verify_deflate(method) return verify_deflate(method)
def create_auth_obfs(method):
return auth_simple(method)
obfs_map = { obfs_map = {
'verify_simple': (create_verify_obfs,), 'verify_simple': (create_verify_obfs,),
'verify_deflate': (create_verify_deflate,), 'verify_deflate': (create_verify_deflate,),
'auth_simple': (create_auth_obfs,),
} }
def match_begin(str1, str2): def match_begin(str1, str2):
@ -49,7 +55,7 @@ def match_begin(str1, str2):
return True return True
return False return False
class sub_encode_obfs(object): class obfs_verify_data(object):
def __init__(self): def __init__(self):
self.sub_obfs = None self.sub_obfs = None
@ -60,7 +66,7 @@ class verify_base(plain.plain):
self.sub_obfs = None self.sub_obfs = None
def init_data(self): def init_data(self):
return sub_encode_obfs() return obfs_verify_data()
def set_server_info(self, server_info): def set_server_info(self, server_info):
try: try:
@ -283,3 +289,248 @@ class verify_deflate(verify_base):
self.decrypt_packet_num += 1 self.decrypt_packet_num += 1
return out_buf return out_buf
class client_queue(object):
def __init__(self, begin_id):
self.front = begin_id
self.back = begin_id
self.alloc = {}
self.enable = True
self.last_update = time.time()
def update(self):
self.last_update = time.time()
def is_active(self):
return time.time() - self.last_update < 60 * 3
def re_enable(self, connection_id):
self.enable = True
self.alloc = {}
self.front = connection_id
self.back = connection_id
def insert(self, connection_id):
self.update()
if not self.enable:
logging.warn('auth_simple: not enable')
return False
if connection_id < self.front:
logging.warn('auth_simple: duplicate id')
return False
if not self.is_active():
self.re_enable(connection_id)
if connection_id > self.front + 0x4000:
logging.warn('auth_simple: wrong id')
return False
if connection_id in self.alloc:
logging.warn('auth_simple: duplicate id 2')
return False
if self.back <= connection_id:
self.back = connection_id + 1
self.alloc[connection_id] = 1
while (self.front in self.alloc) or self.front + 0x1000 < self.back:
if self.front in self.alloc:
del self.alloc[self.front]
self.front += 1
return True
class obfs_auth_data(object):
def __init__(self):
self.sub_obfs = None
self.client_id = {}
self.startup_time = int(time.time() - 30) & 0xFFFFFFFF
self.local_client_id = b''
self.connection_id = 0
def update(self, client_id, connection_id):
if client_id in self.client_id:
self.client_id[client_id].update()
def insert(self, client_id, connection_id):
max_client = 16
if client_id not in self.client_id or not self.client_id[client_id].enable:
active = 0
for c_id in self.client_id:
if self.client_id[c_id].is_active():
active += 1
if active >= max_client:
logging.warn('auth_simple: max active clients exceeded')
return False
if len(self.client_id) < max_client:
if client_id not in self.client_id:
self.client_id[client_id] = client_queue(connection_id)
else:
self.client_id[client_id].re_enable(connection_id)
return self.client_id[client_id].insert(connection_id)
keys = self.client_id.keys()
random.shuffle(keys)
for c_id in keys:
if not self.client_id[c_id].is_active() and self.client_id[c_id].enable:
if len(self.client_id) >= 256:
del self.client_id[c_id]
else:
self.client_id[c_id].enable = False
if client_id not in self.client_id:
self.client_id[client_id] = client_queue(connection_id)
else:
self.client_id[client_id].re_enable(connection_id)
return self.client_id[client_id].insert(connection_id)
logging.warn('auth_simple: no inactive client [assert]')
return False
else:
return self.client_id[client_id].insert(connection_id)
class auth_simple(verify_base):
def __init__(self, method):
super(auth_simple, self).__init__(method)
self.recv_buf = b''
self.unit_len = 8100
self.decrypt_packet_num = 0
self.raw_trans = False
self.has_sent_header = False
self.has_recv_header = False
self.client_id = 0
self.connection_id = 0
def init_data(self):
return obfs_auth_data()
def pack_data(self, buf):
if len(buf) == 0:
return b''
rnd_data = os.urandom(common.ord(os.urandom(1)[0]) % 16)
data = common.chr(len(rnd_data) + 1) + rnd_data + buf
data = struct.pack('>H', len(data) + 6) + data
crc = (0xffffffff - binascii.crc32(data)) & 0xffffffff
data += struct.pack('<I', crc)
return data
def auth_data(self):
utc_time = int(time.time()) & 0xFFFFFFFF
if self.server_info.data.connection_id > 0xFF000000:
self.server_info.data.local_client_id = b''
if not self.server_info.data.local_client_id:
self.server_info.data.local_client_id = os.urandom(4)
logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),))
self.server_info.data.connection_id = struct.unpack('<I', os.urandom(4))[0] & 0xFFFFFF
self.server_info.data.connection_id += 1
return b''.join([struct.pack('<I', utc_time),
self.server_info.data.local_client_id,
struct.pack('<I', self.server_info.data.connection_id)])
def client_pre_encrypt(self, buf):
ret = b''
if not self.has_sent_header:
datalen = max(len(buf), common.ord(os.urandom(1)[0]) % 32 + 4)
ret += self.pack_data(self.auth_data() + buf[:datalen])
buf = buf[datalen:]
self.has_sent_header = True
while len(buf) > self.unit_len:
ret += self.pack_data(buf[:self.unit_len])
buf = buf[self.unit_len:]
ret += self.pack_data(buf)
return ret
def client_post_decrypt(self, buf):
if self.raw_trans:
return buf
self.recv_buf += buf
out_buf = b''
while len(self.recv_buf) > 2:
length = struct.unpack('>H', self.recv_buf[:2])[0]
if length >= 8192:
self.raw_trans = True
self.recv_buf = b''
if self.decrypt_packet_num == 0:
return None
else:
raise Exception('server_post_decrype data error')
if length > len(self.recv_buf):
break
if (binascii.crc32(self.recv_buf[:length]) & 0xffffffff) != 0xffffffff:
self.raw_trans = True
self.recv_buf = b''
if self.decrypt_packet_num == 0:
return None
else:
raise Exception('server_post_decrype data uncorrect CRC32')
pos = common.ord(self.recv_buf[2]) + 2
out_buf += self.recv_buf[pos:length - 4]
self.recv_buf = self.recv_buf[length:]
if out_buf:
self.decrypt_packet_num += 1
return out_buf
def server_pre_encrypt(self, buf):
ret = b''
while len(buf) > self.unit_len:
ret += self.pack_data(buf[:self.unit_len])
buf = buf[self.unit_len:]
ret += self.pack_data(buf)
return ret
def server_post_decrypt(self, buf):
if self.raw_trans:
return buf
self.recv_buf += buf
out_buf = b''
while len(self.recv_buf) > 2:
length = struct.unpack('>H', self.recv_buf[:2])[0]
if length >= 8192:
self.raw_trans = True
self.recv_buf = b''
if self.decrypt_packet_num == 0:
logging.info('auth_simple: over size')
return b'E'
else:
raise Exception('server_post_decrype data error')
if length > len(self.recv_buf):
break
if (binascii.crc32(self.recv_buf[:length]) & 0xffffffff) != 0xffffffff:
logging.info('auth_simple: crc32 error, data %s' % (binascii.hexlify(self.recv_buf[:length]),))
self.raw_trans = True
self.recv_buf = b''
if self.decrypt_packet_num == 0:
return b'E'
else:
raise Exception('server_post_decrype data uncorrect CRC32')
pos = common.ord(self.recv_buf[2]) + 2
out_buf += self.recv_buf[pos:length - 4]
if not self.has_recv_header:
if len(out_buf) < 12:
self.raw_trans = True
self.recv_buf = b''
logging.info('auth_simple: too short')
return b'E'
utc_time = struct.unpack('<I', out_buf[:4])[0]
client_id = struct.unpack('<I', out_buf[4:8])[0]
connection_id = struct.unpack('<I', out_buf[8:12])[0]
time_dif = common.int32((int(time.time()) & 0xffffffff) - utc_time)
if time_dif < 60 * -3 or time_dif > 60 * 3 or common.int32(utc_time - self.server_info.data.startup_time) < 0:
self.raw_trans = True
self.recv_buf = b''
logging.info('auth_simple: wrong timestamp, time_dif %d, data %s' % (time_dif, binascii.hexlify(out_buf),))
return b'E'
elif self.server_info.data.insert(client_id, connection_id):
self.has_recv_header = True
out_buf = out_buf[12:]
self.client_id = client_id
self.connection_id = connection_id
else:
self.raw_trans = True
self.recv_buf = b''
logging.info('auth_simple: auth fail, data %s' % (binascii.hexlify(out_buf),))
return b'E'
self.recv_buf = self.recv_buf[length:]
if out_buf:
self.server_info.data.update(self.client_id, self.connection_id)
self.decrypt_packet_num += 1
return out_buf

12
shadowsocks/server.py

@ -23,7 +23,12 @@ import os
import logging import logging
import signal import signal
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../')) if __name__ == '__main__':
import inspect
file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe())))
os.chdir(file_path)
sys.path.insert(0, os.path.join(file_path, '../'))
from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, \ from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, \
asyncdns, manager asyncdns, manager
@ -68,7 +73,8 @@ def main():
obfs = config["obfs"] obfs = config["obfs"]
a_config = config.copy() a_config = config.copy()
ipv6_ok = False ipv6_ok = False
logging.info("server start with password [%s] obfs [%s] method [%s]" % (password, obfs, a_config['method'])) logging.info("server start with password [%s] method [%s] obfs [%s] obfs_param [%s]" %
(password, a_config['method'], obfs, a_config['obfs_param']))
if 'server_ipv6' in a_config: if 'server_ipv6' in a_config:
try: try:
if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == "[" and a_config['server_ipv6'][-1] == "]": if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == "[" and a_config['server_ipv6'][-1] == "]":
@ -77,7 +83,7 @@ def main():
a_config['password'] = password a_config['password'] = password
a_config['obfs'] = obfs a_config['obfs'] = obfs
a_config['server'] = a_config['server_ipv6'] a_config['server'] = a_config['server_ipv6']
logging.info("starting server at %s:%d" % logging.info("starting server at [%s]:%d" %
(a_config['server'], int(port))) (a_config['server'], int(port)))
tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False)) tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False))
udp_servers.append(udprelay.UDPRelay(a_config, dns_resolver, False)) udp_servers.append(udprelay.UDPRelay(a_config, dns_resolver, False))

21
shadowsocks/tcprelay.py

@ -30,11 +30,8 @@ import random
from shadowsocks import encrypt, obfs, eventloop, shell, common from shadowsocks import encrypt, obfs, eventloop, shell, common
from shadowsocks.common import pre_parse_header, parse_header from shadowsocks.common import pre_parse_header, parse_header
# set it 'False' to use both new protocol and the original shadowsocks protocal
# set it 'True' to use new protocol ONLY, to avoid GFW detecting
FORCE_NEW_PROTOCOL = False
# set it 'True' if run as a local client and connect to a server which support new protocol # set it 'True' if run as a local client and connect to a server which support new protocol
CLIENT_NEW_PROTOCOL = False CLIENT_NEW_PROTOCOL = False #deprecated
# we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time # we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time
TIMEOUTS_CLEAN_SIZE = 512 TIMEOUTS_CLEAN_SIZE = 512
@ -118,8 +115,6 @@ class TCPRelayHandler(object):
config['method']) config['method'])
self._encrypt_correct = True self._encrypt_correct = True
self._obfs = obfs.obfs(config['obfs']) self._obfs = obfs.obfs(config['obfs'])
if server.obfs_data is None:
server.obfs_data = self._obfs.init_data()
server_info = obfs.server_info(server.obfs_data) server_info = obfs.server_info(server.obfs_data)
server_info.host = config['server'] server_info.host = config['server']
server_info.port = server._listen_port server_info.port = server._listen_port
@ -268,8 +263,8 @@ class TCPRelayHandler(object):
if sock == self._local_sock and self._encrypt_correct: if sock == self._local_sock and self._encrypt_correct:
obfs_encode = self._obfs.server_encode(data) obfs_encode = self._obfs.server_encode(data)
data = obfs_encode data = obfs_encode
if data:
l = len(data) l = len(data)
if l > 0:
s = sock.send(data) s = sock.send(data)
if s < l: if s < l:
data = data[s:] data = data[s:]
@ -310,7 +305,7 @@ class TCPRelayHandler(object):
def _get_redirect_host(self, client_address, ogn_data): def _get_redirect_host(self, client_address, ogn_data):
# test # test
host_list = [(b"www.bing.com", 80), (b"www.microsoft.com", 80), (b"www.baidu.com", 443), (b"www.qq.com", 80), (b"www.csdn.net", 80), (b"1.2.3.4", 1000)] host_list = [(b"www.bing.com", 80), (b"www.microsoft.com", 80), (b"cloudfront.com", 80), (b"cloudflare.com", 80), (b"1.2.3.4", 1000), (b"0.0.0.0", 0)]
hash_code = binascii.crc32(ogn_data) hash_code = binascii.crc32(ogn_data)
addrs = socket.getaddrinfo(client_address[0], client_address[1], 0, socket.SOCK_STREAM, socket.SOL_TCP) addrs = socket.getaddrinfo(client_address[0], client_address[1], 0, socket.SOCK_STREAM, socket.SOL_TCP)
af, socktype, proto, canonname, sa = addrs[0] af, socktype, proto, canonname, sa = addrs[0]
@ -338,6 +333,7 @@ class TCPRelayHandler(object):
data = self._obfs.client_pre_encrypt(data) data = self._obfs.client_pre_encrypt(data)
data = self._encryptor.encrypt(data) data = self._encryptor.encrypt(data)
data = self._obfs.client_encode(data) data = self._obfs.client_encode(data)
if data:
self._data_to_write_to_remote.append(data) self._data_to_write_to_remote.append(data)
if self._is_local and not self._fastopen_connected and \ if self._is_local and not self._fastopen_connected and \
self._config['fast_open']: self._config['fast_open']:
@ -404,8 +400,6 @@ class TCPRelayHandler(object):
if self._is_local: if self._is_local:
header_result = parse_header(data) header_result = parse_header(data)
else: else:
if data is None or FORCE_NEW_PROTOCOL and common.ord(data[0]) != 0x88 and (~common.ord(data[0]) & 0xff) != 0x88:
data = self._handel_protocol_error(self._client_address, ogn_data)
data = pre_parse_header(data) data = pre_parse_header(data)
if data is None: if data is None:
data = self._handel_protocol_error(self._client_address, ogn_data) data = self._handel_protocol_error(self._client_address, ogn_data)
@ -436,6 +430,8 @@ class TCPRelayHandler(object):
data += struct.pack('<I', crc) data += struct.pack('<I', crc)
data = self._obfs.client_pre_encrypt(data) data = self._obfs.client_pre_encrypt(data)
data_to_send = self._encryptor.encrypt(data) data_to_send = self._encryptor.encrypt(data)
data_to_send = self._obfs.client_encode(data_to_send)
if data_to_send:
self._data_to_write_to_remote.append(data_to_send) self._data_to_write_to_remote.append(data_to_send)
# notice here may go into _handle_dns_resolved directly # notice here may go into _handle_dns_resolved directly
self._dns_resolver.resolve(self._chosen_server[0], self._dns_resolver.resolve(self._chosen_server[0],
@ -635,7 +631,8 @@ class TCPRelayHandler(object):
if self._is_local: if self._is_local:
obfs_decode = self._obfs.client_decode(data) obfs_decode = self._obfs.client_decode(data)
if obfs_decode[1]: if obfs_decode[1]:
self._write_to_sock(b'', self._remote_sock) send_back = self._obfs.client_encode(b'')
self._write_to_sock(send_back, self._remote_sock)
data = self._encryptor.decrypt(obfs_decode[0]) data = self._encryptor.decrypt(obfs_decode[0])
data = self._obfs.client_post_decrypt(data) data = self._obfs.client_post_decrypt(data)
else: else:
@ -774,7 +771,7 @@ class TCPRelay(object):
self.server_transfer_ul = 0 self.server_transfer_ul = 0
self.server_transfer_dl = 0 self.server_transfer_dl = 0
self.server_connections = 0 self.server_connections = 0
self.obfs_data = None self.obfs_data = obfs.obfs(config['obfs']).init_data()
self._timeout = config['timeout'] self._timeout = config['timeout']
self._timeouts = [] # a list for all the handlers self._timeouts = [] # a list for all the handlers

Loading…
Cancel
Save