Browse Source

fix salsa20

1.4
clowwindy 11 years ago
parent
commit
af46629cd1
  1. 31
      shadowsocks/encrypt_salsa20.py

31
shadowsocks/encrypt_salsa20.py

@ -4,6 +4,7 @@ import time
import struct import struct
import logging import logging
import sys import sys
import encrypt
slow_xor = False slow_xor = False
imported = False imported = False
@ -72,14 +73,17 @@ class Salsa20Cipher(object):
cur_data = data[:remain] cur_data = data[:remain]
cur_data_len = len(cur_data) cur_data_len = len(cur_data)
cur_stream = self._stream[self._pos:self._pos + cur_data_len] cur_stream = self._stream[self._pos:self._pos + cur_data_len]
self._pos = (self._pos + cur_data_len) % BLOCK_SIZE self._pos = self._pos + cur_data_len
data = data[remain:] data = data[remain:]
results.append(numpy_xor(cur_data, cur_stream)) results.append(numpy_xor(cur_data, cur_stream))
if self._pos >= BLOCK_SIZE:
self._next_stream()
self._pos -= BLOCK_SIZE
assert self._pos == 0
if not data: if not data:
break break
self._next_stream()
return ''.join(results) return ''.join(results)
@ -87,8 +91,16 @@ def test():
from os import urandom from os import urandom
import random import random
rounds = 1 * 10 rounds = 1 * 1024
plain = urandom(BLOCK_SIZE * rounds) plain = urandom(BLOCK_SIZE * rounds)
import M2Crypto.EVP
cipher = M2Crypto.EVP.Cipher('aes_128_cfb', 'k' * 32, 'i' * 16, 1,
key_as_bytes=0, d='md5', salt=None, i=1,
padding=1)
decipher = M2Crypto.EVP.Cipher('aes_128_cfb', 'k' * 32, 'i' * 16, 0,
key_as_bytes=0, d='md5', salt=None, i=1,
padding=1)
cipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1) cipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1)
decipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1) decipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1)
results = [] results = []
@ -96,13 +108,20 @@ def test():
print 'start' print 'start'
start = time.time() start = time.time()
while pos < len(plain): while pos < len(plain):
l = random.randint(10000, 32768) l = random.randint(100, 16384)
c = cipher.update(plain[pos:pos + l]) c = cipher.update(plain[pos:pos + l])
results.append(decipher.update(c)) results.append(c)
pos += l
pos = 0
c = ''.join(results)
results = []
while pos < len(plain):
l = random.randint(100, 16384)
results.append(decipher.update(c[pos:pos + l]))
pos += l pos += l
assert ''.join(results) == plain
end = time.time() end = time.time()
print BLOCK_SIZE * rounds / (end - start) print BLOCK_SIZE * rounds / (end - start)
assert ''.join(results) == plain
if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save