diff --git a/hub/common.py b/hub/common.py index fc147a2..31c9c1a 100644 --- a/hub/common.py +++ b/hub/common.py @@ -5,6 +5,7 @@ import hmac import ipaddress import logging import logging.handlers +import socket import typing import collections from bisect import insort_right @@ -153,6 +154,38 @@ def protocol_version(client_req, min_tuple, max_tuple): return result, client_min +async def resolve_host(url: str, port: int, proto: str, + family: int = socket.AF_INET, all_results: bool = False) \ + -> typing.Union[str, typing.List[str]]: + if proto not in ['udp', 'tcp']: + raise Exception("invalid protocol") + try: + if ipaddress.ip_address(url): + return [url] if all_results else url + except ValueError: + pass + loop = asyncio.get_running_loop() + records = await loop.getaddrinfo( + url, port, + proto=socket.IPPROTO_TCP if proto == 'tcp' else socket.IPPROTO_UDP, + type=socket.SOCK_STREAM if proto == 'tcp' else socket.SOCK_DGRAM, + family=family, + ) + def addr_not_ipv4_mapped(rec): + _, _, _, _, sockaddr = rec + ipaddr = ipaddress.ip_address(sockaddr[0]) + return ipaddr.version != 6 or not ipaddr.ipv4_mapped + records = filter(addr_not_ipv4_mapped, records) + results = [sockaddr[0] for fam, type, prot, canonname, sockaddr in records] + if not results and not all_results: + raise socket.gaierror( + socket.EAI_ADDRFAMILY, + 'The specified network host does not have any network ' + 'addresses in the requested address family' + ) + return results if all_results else results[0] + + class LRUCacheWithMetrics: __slots__ = [ 'capacity', @@ -577,10 +610,10 @@ IPV4_TO_6_RELAY_SUBNET = ipaddress.ip_network('192.88.99.0/24') def is_valid_public_ipv4(address, allow_localhost: bool = False, allow_lan: bool = False): try: parsed_ip = ipaddress.ip_address(address) - if parsed_ip.is_loopback and allow_localhost: - return True - if allow_lan and parsed_ip.is_private: - return True + if parsed_ip.is_loopback: + return allow_localhost + if parsed_ip.is_private: + return allow_lan if any((parsed_ip.version != 4, parsed_ip.is_unspecified, parsed_ip.is_link_local, parsed_ip.is_loopback, parsed_ip.is_multicast, parsed_ip.is_reserved, parsed_ip.is_private)): return False @@ -593,12 +626,14 @@ def is_valid_public_ipv4(address, allow_localhost: bool = False, allow_lan: bool def is_valid_public_ipv6(address, allow_localhost: bool = False, allow_lan: bool = False): try: parsed_ip = ipaddress.ip_address(address) - if parsed_ip.is_loopback and allow_localhost: - return True - if allow_lan and parsed_ip.is_private: - return True - return not any((parsed_ip.version != 6, parsed_ip.is_unspecified, parsed_ip.is_link_local, parsed_ip.is_loopback, - parsed_ip.is_multicast, parsed_ip.is_reserved, parsed_ip.is_private)) + if parsed_ip.is_loopback: + return allow_localhost + if parsed_ip.is_private: + return allow_lan + return not any((parsed_ip.version != 6, parsed_ip.is_unspecified, + parsed_ip.is_link_local, parsed_ip.is_loopback, + parsed_ip.is_multicast, parsed_ip.is_reserved, + parsed_ip.is_private, parsed_ip.ipv4_mapped)) except (ipaddress.AddressValueError, ValueError): return False