diff --git a/shadowsocks/encrypt.py b/shadowsocks/encrypt.py index 610e5b3..5fdaeff 100644 --- a/shadowsocks/encrypt.py +++ b/shadowsocks/encrypt.py @@ -32,6 +32,9 @@ def random_string(length): return M2Crypto.Rand.rand_bytes(length) +cached_tables = {} + + def get_table(key): m = hashlib.md5() m.update(key) @@ -42,12 +45,9 @@ def get_table(key): table.sort(lambda x, y: int(a % (ord(x) + i) - a % (ord(y) + i))) return table -encrypt_table = None -decrypt_table = None - def init_table(key, method=None): - if method == 'table': + if method is not None and method == 'table': method = None if method: try: @@ -57,10 +57,12 @@ def init_table(key, method=None): 'default method') sys.exit(1) if not method: - global encrypt_table, decrypt_table + if key in cached_tables: + return cached_tables[key] encrypt_table = ''.join(get_table(key)) decrypt_table = string.maketrans(encrypt_table, string.maketrans('', '')) + cached_tables[key] = [encrypt_table, decrypt_table] else: try: Encryptor(key, method) # test if the settings if OK @@ -116,9 +118,10 @@ class Encryptor(object): self.iv_sent = False self.cipher_iv = '' self.decipher = None - if method is not None: + if method: self.cipher = self.get_cipher(key, method, 1, iv=random_string(32)) else: + self.encrypt_table, self.decrypt_table = init_table(key) self.cipher = None def get_cipher_len(self, method): @@ -150,8 +153,8 @@ class Encryptor(object): def encrypt(self, buf): if len(buf) == 0: return buf - if self.method is None: - return string.translate(buf, encrypt_table) + if not self.method: + return string.translate(buf, self.encrypt_table) else: if self.iv_sent: return self.cipher.update(buf) @@ -162,8 +165,8 @@ class Encryptor(object): def decrypt(self, buf): if len(buf) == 0: return buf - if self.method is None: - return string.translate(buf, decrypt_table) + if not self.method: + return string.translate(buf, self.decrypt_table) else: if self.decipher is None: decipher_iv_len = self.get_cipher_len(self.method)[1] @@ -174,3 +177,34 @@ class Encryptor(object): if len(buf) == 0: return buf return self.decipher.update(buf) + + +def encrypt_all(password, method, op, data): + if method is not None and method.lower() == 'table': + method = None + if not method: + [encrypt_table, decrypt_table] = init_table(password) + if op: + return string.translate(encrypt_table, data) + else: + return string.translate(decrypt_table, data) + else: + import M2Crypto.EVP + result = [] + method = method.lower() + (key_len, iv_len) = method_supported[method] + (key, _) = EVP_BytesToKey(password, key_len, iv_len) + if op: + iv = random_string(iv_len) + result.append(iv) + else: + iv = data[:iv_len] + data = data[iv_len:] + cipher = M2Crypto.EVP.Cipher(method.replace('-', '_'), key, iv, op, + key_as_bytes=0, d='md5', salt=None, i=1, + padding=1) + result.append(cipher.update(data)) + f = cipher.final() + if f: + result.append(f) + return ''.join(result) diff --git a/shadowsocks/event.py b/shadowsocks/eventloop.py similarity index 100% rename from shadowsocks/event.py rename to shadowsocks/eventloop.py diff --git a/shadowsocks/server.py b/shadowsocks/server.py index bc0570a..eb60595 100755 --- a/shadowsocks/server.py +++ b/shadowsocks/server.py @@ -215,7 +215,7 @@ def main(): tuple(server.server_address[:2])) threading.Thread(target=server.serve_forever).start() udprelay.UDPRelay(SERVER, int(port), None, None, key, METHOD, - int(TIMEOUT), False) + int(TIMEOUT), False).start() if __name__ == '__main__': diff --git a/shadowsocks/udprelay.py b/shadowsocks/udprelay.py index 819c874..caebe12 100644 --- a/shadowsocks/udprelay.py +++ b/shadowsocks/udprelay.py @@ -68,7 +68,48 @@ import threading import socket -import event +import logging +import struct +import eventloop + +BUF_SIZE = 65536 + + +def parse_header(data): + addrtype = ord(data[0]) + dest_addr = None + dest_port = None + header_length = 0 + if addrtype == 1: + if len(data) >= 7: + dest_addr = socket.inet_ntoa(data[1:5]) + dest_port = struct.unpack('>H', data[5:7])[0] + header_length = 7 + else: + logging.warn('[udp] header is too short') + elif addrtype == 3: + if len(data) > 2: + addrlen = ord(data[1]) + if len(data) >= 2 + addrlen: + dest_addr = data[2:2 + addrlen] + dest_port = struct.unpack('>H', data[2 + addrlen:4 + addrlen])[0] + header_length = 4 + addrlen + else: + logging.warn('[udp] header is too short') + else: + logging.warn('[udp] header is too short') + elif addrtype == 4: + if len(data) >= 19: + dest_addr = socket.inet_ntop(socket.AF_INET6, data[1:17]) + dest_port = struct.unpack('>H', data[17:19])[0] + header_length = 19 + else: + logging.warn('[udp] header is too short') + else: + logging.warn('unsupported addrtype %d' % addrtype) + if dest_addr is None: + return None + return (addrtype, dest_addr, dest_port, header_length) class UDPRelay(object): @@ -83,25 +124,37 @@ class UDPRelay(object): self._method = method self._timeout = timeout self._is_local = is_local - self._eventloop = event.EventLoop() + self._eventloop = eventloop.EventLoop() self._cache = {} # TODO replace this dictionary with an LRU cache - def _handle_server(self, addr, sock, data): - # TODO - pass - - def _handle_client(self, addr, sock, data): + def _handle_server(self): + server = self._server_socket + data = server.recvfrom(BUF_SIZE) + if self._is_local: + frag = ord(data[2]) + if frag != 0: + logging.warn('drop a message since frag is not 0') + else: + data = data[3:] + else: + decrypt + + + def _handle_client(self, sock): # TODO pass def _run(self): - eventloop = self._eventloop server_socket = self._server_socket - eventloop.add(server_socket, event.MODE_IN) + self._eventloop.add(server_socket, eventloop.MODE_IN) is_local = self._is_local while True: - r = eventloop.poll() - # TODO + events = self._eventloop.poll() + for sock, event in events: + if sock == self._server_socket: + self._handle_server() + else: + self._handle_client(sock) def start(self): addrs = socket.getaddrinfo(self._listen_addr, self._listen_port, 0, @@ -115,4 +168,6 @@ class UDPRelay(object): server_socket.setblocking(False) self._server_socket = server_socket - threading.Thread(target=self._run).start() + t = threading.Thread(target=self._run) + t.setDaemon(True) + t.start()