From 92c9ba00097379fa1830f8e5bb926f8eea5c5c6c Mon Sep 17 00:00:00 2001 From: clowwindy Date: Sun, 4 May 2014 00:52:08 +0800 Subject: [PATCH] optimize --- shadowsocks/encrypt_salsa20.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/shadowsocks/encrypt_salsa20.py b/shadowsocks/encrypt_salsa20.py index 85b30ff..d69010f 100644 --- a/shadowsocks/encrypt_salsa20.py +++ b/shadowsocks/encrypt_salsa20.py @@ -4,7 +4,6 @@ import time import struct import logging import sys -import encrypt slow_xor = False imported = False @@ -32,8 +31,14 @@ def run_imports(): def numpy_xor(a, b): if slow_xor: return py_xor_str(a, b) - ab = numpy.frombuffer(a, dtype=numpy.byte) - bb = numpy.frombuffer(b, dtype=numpy.byte) + dtype = numpy.byte + if len(a) % 4 == 0: + dtype = numpy.uint32 + elif len(a) % 2 == 0: + dtype = numpy.uint16 + + ab = numpy.frombuffer(a, dtype=dtype) + bb = numpy.frombuffer(b, dtype=dtype) c = numpy.bitwise_xor(ab, bb) r = c.tostring() return r @@ -80,8 +85,7 @@ class Salsa20Cipher(object): if self._pos >= BLOCK_SIZE: self._next_stream() - self._pos -= BLOCK_SIZE - assert self._pos == 0 + self._pos = 0 if not data: break return ''.join(results) @@ -94,12 +98,12 @@ def test(): rounds = 1 * 1024 plain = urandom(BLOCK_SIZE * rounds) import M2Crypto.EVP - cipher = M2Crypto.EVP.Cipher('aes_128_cfb', 'k' * 32, 'i' * 16, 1, - key_as_bytes=0, d='md5', salt=None, i=1, - padding=1) - decipher = M2Crypto.EVP.Cipher('aes_128_cfb', 'k' * 32, 'i' * 16, 0, - key_as_bytes=0, d='md5', salt=None, i=1, - padding=1) + # cipher = M2Crypto.EVP.Cipher('aes_128_cfb', 'k' * 32, 'i' * 16, 1, + # key_as_bytes=0, d='md5', salt=None, i=1, + # padding=1) + # decipher = M2Crypto.EVP.Cipher('aes_128_cfb', 'k' * 32, 'i' * 16, 0, + # key_as_bytes=0, d='md5', salt=None, i=1, + # padding=1) cipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1) decipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1) @@ -108,7 +112,7 @@ def test(): print 'start' start = time.time() while pos < len(plain): - l = random.randint(100, 16384) + l = random.randint(100, 32768) c = cipher.update(plain[pos:pos + l]) results.append(c) pos += l @@ -116,7 +120,7 @@ def test(): c = ''.join(results) results = [] while pos < len(plain): - l = random.randint(100, 16384) + l = random.randint(100, 32768) results.append(decipher.update(c[pos:pos + l])) pos += l end = time.time()