diff --git a/shadowsocks/asyncdns.py b/shadowsocks/asyncdns.py index 98db3bb..899fec9 100644 --- a/shadowsocks/asyncdns.py +++ b/shadowsocks/asyncdns.py @@ -280,7 +280,11 @@ class DNSResolver(object): if type(black_hostname_list) != list: self._black_hostname_list = [] else: - self._black_hostname_list = black_hostname_list + self._black_hostname_list = list(map( + (lambda t: t if type(t) == bytes else t.encode('utf8')), + black_hostname_list + )) + print('black_hostname_list init as : ' + str(self._black_hostname_list)) self._sock = None self._servers = None self._parse_resolv() @@ -474,6 +478,7 @@ class DNSResolver(object): callback((hostname, ip), None) elif hostname in self._black_hostname_list: callback(None, Exception('hostname <%s> in the black hostname list' % hostname)) + return else: if not is_valid_hostname(hostname): callback(None, Exception('invalid hostname: %s' % hostname)) @@ -515,7 +520,10 @@ class DNSResolver(object): def test(): - dns_resolver = DNSResolver() + black_hostname_list = [ + 'baidu.com' + ] + dns_resolver = DNSResolver(black_hostname_list=black_hostname_list) loop = eventloop.EventLoop() dns_resolver.add_to_loop(loop) @@ -541,6 +549,7 @@ def test(): dns_resolver.resolve(b'google.com', make_callback()) dns_resolver.resolve('google.com', make_callback()) + dns_resolver.resolve('baidu.com', make_callback()) dns_resolver.resolve('example.com', make_callback()) dns_resolver.resolve('ipv6.google.com', make_callback()) dns_resolver.resolve('www.facebook.com', make_callback()) @@ -556,8 +565,24 @@ def test(): 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' 'long.hostname', make_callback()) - loop.run() + # test black_hostname_list + dns_resolver = DNSResolver(black_hostname_list=[]) + assert type(dns_resolver._black_hostname_list) == list + assert dns_resolver._black_hostname_list.__len__() == 0 + dns_resolver.close() + dns_resolver = DNSResolver(black_hostname_list=123) + assert type(dns_resolver._black_hostname_list) == list + assert dns_resolver._black_hostname_list.__len__() == 0 + dns_resolver.close() + dns_resolver = DNSResolver(black_hostname_list=None) + assert type(dns_resolver._black_hostname_list) == list + assert dns_resolver._black_hostname_list.__len__() == 0 + dns_resolver.close() + dns_resolver = DNSResolver() + assert type(dns_resolver._black_hostname_list) == list + assert dns_resolver._black_hostname_list.__len__() == 0 + dns_resolver.close() if __name__ == '__main__':