Exclude mapped IPv4 addresses. Add resolve_host() code from client.

This commit is contained in:
Jonathan Moody 2023-01-16 14:04:14 -06:00
parent 9c43c811a1
commit 14f2f3b55b

View file

@ -5,6 +5,7 @@ import hmac
import ipaddress
import logging
import logging.handlers
import socket
import typing
import collections
from bisect import insort_right
@ -153,6 +154,38 @@ def protocol_version(client_req, min_tuple, max_tuple):
return result, client_min
async def resolve_host(url: str, port: int, proto: str,
family: int = socket.AF_INET, all_results: bool = False) \
-> typing.Union[str, typing.List[str]]:
if proto not in ['udp', 'tcp']:
raise Exception("invalid protocol")
try:
if ipaddress.ip_address(url):
return [url] if all_results else url
except ValueError:
pass
loop = asyncio.get_running_loop()
records = await loop.getaddrinfo(
url, port,
proto=socket.IPPROTO_TCP if proto == 'tcp' else socket.IPPROTO_UDP,
type=socket.SOCK_STREAM if proto == 'tcp' else socket.SOCK_DGRAM,
family=family,
)
def addr_not_ipv4_mapped(rec):
_, _, _, _, sockaddr = rec
ipaddr = ipaddress.ip_address(sockaddr[0])
return ipaddr.version != 6 or not ipaddr.ipv4_mapped
records = filter(addr_not_ipv4_mapped, records)
results = [sockaddr[0] for fam, type, prot, canonname, sockaddr in records]
if not results and not all_results:
raise socket.gaierror(
socket.EAI_ADDRFAMILY,
'The specified network host does not have any network '
'addresses in the requested address family'
)
return results if all_results else results[0]
class LRUCacheWithMetrics:
__slots__ = [
'capacity',
@ -577,10 +610,10 @@ IPV4_TO_6_RELAY_SUBNET = ipaddress.ip_network('192.88.99.0/24')
def is_valid_public_ipv4(address, allow_localhost: bool = False, allow_lan: bool = False):
try:
parsed_ip = ipaddress.ip_address(address)
if parsed_ip.is_loopback and allow_localhost:
return True
if allow_lan and parsed_ip.is_private:
return True
if parsed_ip.is_loopback:
return allow_localhost
if parsed_ip.is_private:
return allow_lan
if any((parsed_ip.version != 4, parsed_ip.is_unspecified, parsed_ip.is_link_local, parsed_ip.is_loopback,
parsed_ip.is_multicast, parsed_ip.is_reserved, parsed_ip.is_private)):
return False
@ -593,12 +626,14 @@ def is_valid_public_ipv4(address, allow_localhost: bool = False, allow_lan: bool
def is_valid_public_ipv6(address, allow_localhost: bool = False, allow_lan: bool = False):
try:
parsed_ip = ipaddress.ip_address(address)
if parsed_ip.is_loopback and allow_localhost:
return True
if allow_lan and parsed_ip.is_private:
return True
return not any((parsed_ip.version != 6, parsed_ip.is_unspecified, parsed_ip.is_link_local, parsed_ip.is_loopback,
parsed_ip.is_multicast, parsed_ip.is_reserved, parsed_ip.is_private))
if parsed_ip.is_loopback:
return allow_localhost
if parsed_ip.is_private:
return allow_lan
return not any((parsed_ip.version != 6, parsed_ip.is_unspecified,
parsed_ip.is_link_local, parsed_ip.is_loopback,
parsed_ip.is_multicast, parsed_ip.is_reserved,
parsed_ip.is_private, parsed_ip.ipv4_mapped))
except (ipaddress.AddressValueError, ValueError):
return False