Browse Source

add tests for common.py

auth
clowwindy 10 years ago
parent
commit
70dae91e7c
  1. 72
      shadowsocks/common.py
  2. 2
      shadowsocks/tcprelay.py

72
shadowsocks/common.py

@ -48,31 +48,44 @@ chr = compat_chr
def to_bytes(s): def to_bytes(s):
return s.encode('utf-8') if bytes != str:
if type(s) == str:
return s.encode('utf-8')
return s
def to_str(s):
if bytes != str:
if type(s) == bytes:
return s.decode('utf-8')
return s
def inet_ntop(family, ipstr): def inet_ntop(family, ipstr):
if family == socket.AF_INET: if family == socket.AF_INET:
return socket.inet_ntoa(ipstr) return to_bytes(socket.inet_ntoa(ipstr))
elif family == socket.AF_INET6: elif family == socket.AF_INET6:
v6addr = b':'.join((b'%02X%02X' % (ord(i), ord(j))) import re
for i, j in zip(ipstr[::2], ipstr[1::2])) v6addr = ':'.join(('%02X%02X' % (ord(i), ord(j))).lstrip('0')
return v6addr for i, j in zip(ipstr[::2], ipstr[1::2]))
v6addr = re.sub('::+', '::', v6addr, count=1)
return to_bytes(v6addr)
def inet_pton(family, addr): def inet_pton(family, addr):
addr = to_str(addr)
if family == socket.AF_INET: if family == socket.AF_INET:
return socket.inet_aton(addr) return socket.inet_aton(addr)
elif family == socket.AF_INET6: elif family == socket.AF_INET6:
if b'.' in addr: # a v4 addr if '.' in addr: # a v4 addr
v4addr = addr[addr.rindex(b':') + 1:] v4addr = addr[addr.rindex(':') + 1:]
v4addr = socket.inet_aton(v4addr) v4addr = socket.inet_aton(v4addr)
v4addr = map(lambda x: (b'%02X' % ord(x)), v4addr) v4addr = map(lambda x: ('%02X' % ord(x)), v4addr)
v4addr.insert(2, b':') v4addr.insert(2, ':')
newaddr = addr[:addr.rindex(b':') + 1] + b''.join(v4addr) newaddr = addr[:addr.rindex(':') + 1] + ''.join(v4addr)
return inet_pton(family, newaddr) return inet_pton(family, newaddr)
dbyts = [0] * 8 # 8 groups dbyts = [0] * 8 # 8 groups
grps = addr.split(b':') grps = addr.split(':')
for i, v in enumerate(grps): for i, v in enumerate(grps):
if v: if v:
dbyts[i] = int(v, 16) dbyts[i] = int(v, 16)
@ -105,9 +118,10 @@ ADDRTYPE_HOST = 3
def pack_addr(address): def pack_addr(address):
address_str = to_str(address)
for family in (socket.AF_INET, socket.AF_INET6): for family in (socket.AF_INET, socket.AF_INET6):
try: try:
r = socket.inet_pton(family, address) r = socket.inet_pton(family, address_str)
if family == socket.AF_INET6: if family == socket.AF_INET6:
return b'\x04' + r return b'\x04' + r
else: else:
@ -155,4 +169,36 @@ def parse_header(data):
addrtype) addrtype)
if dest_addr is None: if dest_addr is None:
return None return None
return addrtype, dest_addr, dest_port, header_length return addrtype, to_bytes(dest_addr), dest_port, header_length
def test_inet_conv():
ipv4 = b'8.8.4.4'
b = inet_pton(socket.AF_INET, ipv4)
assert inet_ntop(socket.AF_INET, b) == ipv4
ipv6 = b'2404:6800:4005:805::1011'
b = inet_pton(socket.AF_INET6, ipv6)
assert inet_ntop(socket.AF_INET6, b) == ipv6
def test_parse_header():
assert parse_header(b'\x03\x0ewww.google.com\x00\x50') == \
(3, b'www.google.com', 80, 18)
assert parse_header(b'\x01\x08\x08\x08\x08\x00\x35') == \
(1, b'8.8.8.8', 53, 7)
assert parse_header((b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00'
b'\x00\x10\x11\x00\x50')) == \
(4, b'2404:6800:4005:805::1011', 80, 19)
def test_pack_header():
assert pack_addr(b'8.8.8.8') == b'\x01\x08\x08\x08\x08'
assert pack_addr(b'2404:6800:4005:805::1011') == \
b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00\x00\x10\x11'
assert pack_addr(b'www.google.com') == b'\x03\x0ewww.google.com'
if __name__ == '__main__':
test_inet_conv()
test_parse_header()
test_pack_header()

2
shadowsocks/tcprelay.py

@ -261,7 +261,7 @@ class TCPRelayHandler(object):
if header_result is None: if header_result is None:
raise Exception('can not parse header') raise Exception('can not parse header')
addrtype, remote_addr, remote_port, header_length = header_result addrtype, remote_addr, remote_port, header_length = header_result
logging.info('connecting %s:%d' % (remote_addr.decode('utf-8'), logging.info('connecting %s:%d' % (common.to_str(remote_addr),
remote_port)) remote_port))
self._remote_address = (remote_addr, remote_port) self._remote_address = (remote_addr, remote_port)
# pause reading # pause reading

Loading…
Cancel
Save