diff --git a/local.py b/local.py index 2941468..2df9878 100755 --- a/local.py +++ b/local.py @@ -34,6 +34,32 @@ import threading import time import SocketServer +def socket_create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, + source_address=None): + """python 2.7 socket.create_connection""" + host, port = address + err = None + for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + sock = None + try: + sock = socket.socket(af, socktype, proto) + if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: + sock.settimeout(timeout) + if source_address: + sock.bind(source_address) + sock.connect(sa) + return sock + + except socket.error as _: + err = _ + if sock is not None: + sock.close() + if err is not None: + raise err + else: + raise error("getaddrinfo returns an empty list") + def get_table(key): m = hashlib.md5() m.update(key) @@ -95,11 +121,10 @@ class Socks5Server(SocketServer.StreamRequestHandler): def handle(self): try: sock = self.connection - remote = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - remote.connect((SERVER, REMOTE_PORT)) + remote = socket_create_connection((SERVER, REMOTE_PORT)) self.handle_tcp(sock, remote) - except socket.error as e: - lock_print('socket error: ' + str(e)) + except socket.error: + lock_print('socket error') def main(): diff --git a/server.py b/server.py index 94f776c..a37293e 100755 --- a/server.py +++ b/server.py @@ -23,12 +23,45 @@ PORT = 8499 KEY = "foobar!" +try: + import gevent, gevent.monkey + gevent.monkey.patch_all(dns=gevent.version_info[0]>=1) +except ImportError: + gevent = None + import socket import select import SocketServer import struct import string import hashlib +import sys + +def socket_create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, + source_address=None): + """python 2.7 socket.create_connection""" + host, port = address + err = None + for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + sock = None + try: + sock = socket.socket(af, socktype, proto) + if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: + sock.settimeout(timeout) + if source_address: + sock.bind(source_address) + sock.connect(sa) + return sock + + except error as _: + err = _ + if sock is not None: + sock.close() + if err is not None: + raise err + else: + raise error("getaddrinfo returns an empty list") def get_table(key): m = hashlib.md5() @@ -42,7 +75,7 @@ def get_table(key): class ThreadingTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer): - address_family = socket.AF_INET6 + pass class Socks5Server(SocketServer.StreamRequestHandler): @@ -90,8 +123,7 @@ class Socks5Server(SocketServer.StreamRequestHandler): reply = "\x05\x00\x00\x01" try: if mode == 1: - remote = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - remote.connect((addr, port[0])) + remote = socket_create_connection((addr, port[0])) local = remote.getsockname() reply += socket.inet_aton(local[0]) + struct.pack(">H", local[1]) @@ -106,11 +138,13 @@ class Socks5Server(SocketServer.StreamRequestHandler): if reply[1] == '\x00': if mode == 1: self.handle_tcp(sock, remote) - except socket.error: + except socket.error as e: print 'socket error' def main(): + if '-6' in sys.argv[1:]: + ThreadingTCPServer.address_family = socket.AF_INET6 server = ThreadingTCPServer(('', PORT), Socks5Server) server.allow_reuse_address = True print "starting server at port %d ..." % PORT