Add IPv6 support to StatusServer and related classes. #119

Open
moodyjon wants to merge 8 commits from moodyjon/ipv6 into master
3 changed files with 113 additions and 42 deletions
Showing only changes of commit f370e263b5 - Show all commits

View file

@ -590,6 +590,20 @@ def is_valid_public_ipv4(address, allow_localhost: bool = False, allow_lan: bool
except (ipaddress.AddressValueError, ValueError): except (ipaddress.AddressValueError, ValueError):
return False return False
def is_valid_public_ipv6(address, allow_localhost: bool = False, allow_lan: bool = False):
try:
parsed_ip = ipaddress.ip_address(address)
if parsed_ip.is_loopback and allow_localhost:
return True
if allow_lan and parsed_ip.is_private:
return True
return not any((parsed_ip.version != 6, parsed_ip.is_unspecified, parsed_ip.is_link_local, parsed_ip.is_loopback,
parsed_ip.is_multicast, parsed_ip.is_reserved, parsed_ip.is_private))
except (ipaddress.AddressValueError, ValueError):
return False
def is_valid_public_ip(address, **kwargs):
return is_valid_public_ipv6(address, **kwargs) or is_valid_public_ipv4(address, **kwargs)
def sha256(x): def sha256(x):
"""Simple wrapper of hashlib sha256.""" """Simple wrapper of hashlib sha256."""

View file

@ -271,7 +271,8 @@ class SessionManager:
f'{host}:{port:d} : {e!r}') f'{host}:{port:d} : {e!r}')
raise raise
else: else:
self.logger.info(f'{kind} server listening on {host}:{port:d}') for s in self.servers[kind].sockets:
self.logger.info(f'{kind} server listening on {s.getsockname()[:2]}')
async def _start_external_servers(self): async def _start_external_servers(self):
"""Start listening on TCP and SSL ports, but only if the respective """Start listening on TCP and SSL ports, but only if the respective

View file

@ -1,10 +1,17 @@
import asyncio import asyncio
import ipaddress
import socket
import struct import struct
from time import perf_counter from time import perf_counter
import logging import logging
from typing import Optional, Tuple, NamedTuple from typing import Optional, Tuple, NamedTuple, List, Union
from hub.schema.attrs import country_str_to_int, country_int_to_str from hub.schema.attrs import country_str_to_int, country_int_to_str
from hub.common import LRUCache, is_valid_public_ipv4 from hub.common import (
LRUCache,
is_valid_public_ip,
is_valid_public_ipv4,
is_valid_public_ipv6,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -36,48 +43,75 @@ class SPVPing(NamedTuple):
return decoded return decoded
PONG_ENCODING = b'!BBL32s4sH' PONG_ENCODING_PRE = b'!BBL32s'
PONG_ENCODING_POST = b'!H'
class SPVPong(NamedTuple): class SPVPong(NamedTuple):
protocol_version: int protocol_version: int
flags: int flags: int
height: int height: int
tip: bytes tip: bytes
source_address_raw: bytes ipaddr: Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
country: int country: int
FLAG_AVAILABLE = 0b00000001
FLAG_IPV6 = 0b00000010
def encode(self): def encode(self):
return struct.pack(PONG_ENCODING, *self) return (struct.pack(PONG_ENCODING_PRE, self.protocol_version, self.flags, self.height, self.tip) +
self.encode_address(self.ipaddr) +
struct.pack(PONG_ENCODING_POST, self.country))
@staticmethod @staticmethod
def encode_address(address: str): def encode_address(address: Union[str, ipaddress.IPv4Address, ipaddress.IPv6Address]):
return bytes(int(b) for b in address.split(".")) if not isinstance(address, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
address = ipaddress.ip_address(address)
return address.packed
@classmethod @classmethod
def make(cls, flags: int, height: int, tip: bytes, source_address: str, country: str) -> bytes: def make(cls, flags: int, height: int, tip: bytes, source_address: str, country: str) -> bytes:
ipaddr = ipaddress.ip_address(source_address)
flags = (flags | cls.FLAG_IPV6) if ipaddr.version == 6 else (flags & ~cls.FLAG_IPV6)
return SPVPong( return SPVPong(
PROTOCOL_VERSION, flags, height, tip, PROTOCOL_VERSION, flags, height, tip,
cls.encode_address(source_address), ipaddr,
country_str_to_int(country) country_str_to_int(country)
).encode() )
@classmethod @classmethod
def make_sans_source_address(cls, flags: int, height: int, tip: bytes, country: str) -> Tuple[bytes, bytes]: def make_sans_source_address(cls, flags: int, height: int, tip: bytes, country: str) -> Tuple[bytes, bytes]:
pong = cls.make(flags, height, tip, '0.0.0.0', country) pong = cls.make(flags, height, tip, '0.0.0.0', country)
return pong[:38], pong[42:] pong = pong.encode()
return pong[0:1], pong[2:38], pong[42:]
@classmethod @classmethod
def decode(cls, packet: bytes): def decode(cls, packet: bytes):
return cls(*struct.unpack(PONG_ENCODING, packet[:44])) offset = 0
protocol_version, flags, height, tip = struct.unpack(PONG_ENCODING_PRE, packet[offset:offset+38])
offset += 38
if flags & cls.FLAG_IPV6:
addr_len = ipaddress.IPV6LENGTH // 8
ipaddr = ipaddress.ip_address(packet[offset:offset+addr_len])
offset += addr_len
else:
addr_len = ipaddress.IPV4LENGTH // 8
ipaddr = ipaddress.ip_address(packet[offset:offset+addr_len])
offset += addr_len
country, = struct.unpack(PONG_ENCODING_POST, packet[offset:offset+2])
offset += 2
return cls(protocol_version, flags, height, tip, ipaddr, country)
@property @property
def available(self) -> bool: def available(self) -> bool:
return (self.flags & 0b00000001) > 0 return (self.flags & self.FLAG_AVAILABLE) > 0
@property
def ipv6(self) -> bool:
return (self.flags & self.FLAG_IPV6) > 0
@property @property
def ip_address(self) -> str: def ip_address(self) -> str:
return ".".join(map(str, self.source_address_raw)) return self.ipaddr.compressed
@property @property
def country_name(self): def country_name(self):
@ -94,7 +128,8 @@ class SPVServerStatusProtocol(asyncio.DatagramProtocol):
def __init__( def __init__(
self, height: int, tip: bytes, country: str, self, height: int, tip: bytes, country: str,
throttle_cache_size: int = 1024, throttle_reqs_per_sec: int = 10, throttle_cache_size: int = 1024, throttle_reqs_per_sec: int = 10,
allow_localhost: bool = False, allow_lan: bool = False allow_localhost: bool = False, allow_lan: bool = False,
is_valid_ip = is_valid_public_ip,
): ):
super().__init__() super().__init__()
self.transport: Optional[asyncio.transports.DatagramTransport] = None self.transport: Optional[asyncio.transports.DatagramTransport] = None
@ -102,26 +137,27 @@ class SPVServerStatusProtocol(asyncio.DatagramProtocol):
self._tip = tip self._tip = tip
self._flags = 0 self._flags = 0
self._country = country self._country = country
self._left_cache = self._right_cache = None self._cache0 = self._cache1 = self.cache2 = None
self.update_cached_response() self.update_cached_response()
self._throttle = LRUCache(throttle_cache_size) self._throttle = LRUCache(throttle_cache_size)
self._should_log = LRUCache(throttle_cache_size) self._should_log = LRUCache(throttle_cache_size)
self._min_delay = 1 / throttle_reqs_per_sec self._min_delay = 1 / throttle_reqs_per_sec
self._allow_localhost = allow_localhost self._allow_localhost = allow_localhost
self._allow_lan = allow_lan self._allow_lan = allow_lan
self._is_valid_ip = is_valid_ip
self.closed = asyncio.Event() self.closed = asyncio.Event()
def update_cached_response(self): def update_cached_response(self):
self._left_cache, self._right_cache = SPVPong.make_sans_source_address( self._cache0, self._cache1, self._cache2 = SPVPong.make_sans_source_address(
self._flags, max(0, self._height), self._tip, self._country self._flags, max(0, self._height), self._tip, self._country
) )
def set_unavailable(self): def set_unavailable(self):
self._flags &= 0b11111110 self._flags &= ~SPVPong.FLAG_AVAILABLE
self.update_cached_response() self.update_cached_response()
def set_available(self): def set_available(self):
self._flags |= 0b00000001 self._flags |= SPVPong.FLAG_AVAILABLE
self.update_cached_response() self.update_cached_response()
def set_height(self, height: int, tip: bytes): def set_height(self, height: int, tip: bytes):
@ -141,17 +177,25 @@ class SPVServerStatusProtocol(asyncio.DatagramProtocol):
return False return False
def make_pong(self, host): def make_pong(self, host):
return self._left_cache + SPVPong.encode_address(host) + self._right_cache ipaddr = ipaddress.ip_address(host)
if ipaddr.version == 6:
flags = self._flags | SPVPong.FLAG_IPV6
else:
flags = self._flags & ~SPVPong.FLAG_IPV6
return (self._cache0 + flags.to_bytes(1, 'big') +
self._cache1 + SPVPong.encode_address(ipaddr) +
self._cache2)
def datagram_received(self, data: bytes, addr: Tuple[str, int]): def datagram_received(self, data: bytes, addr: Union[Tuple[str, int], Tuple[str, int, int, int]]):
if self.should_throttle(addr[0]): if self.should_throttle(addr[0]):
# print(f"throttled: {addr}")
return return
try: try:
SPVPing.decode(data) SPVPing.decode(data)
except (ValueError, struct.error, AttributeError, TypeError): except (ValueError, struct.error, AttributeError, TypeError):
# log.exception("derp") # log.exception("derp")
return return
if addr[1] >= 1024 and is_valid_public_ipv4( if addr[1] >= 1024 and self._is_valid_ip(
addr[0], allow_localhost=self._allow_localhost, allow_lan=self._allow_lan): addr[0], allow_localhost=self._allow_localhost, allow_lan=self._allow_lan):
self.transport.sendto(self.make_pong(addr[0]), addr) self.transport.sendto(self.make_pong(addr[0]), addr)
else: else:
@ -174,39 +218,51 @@ class SPVServerStatusProtocol(asyncio.DatagramProtocol):
class StatusServer: class StatusServer:
def __init__(self): def __init__(self):
self._protocol: Optional[SPVServerStatusProtocol] = None self._protocols: List[SPVServerStatusProtocol] = []
async def start(self, height: int, tip: bytes, country: str, interface: str, port: int, allow_lan: bool = False): async def start(self, height: int, tip: bytes, country: str, interface: str, port: int, allow_lan: bool = False):
if self.is_running: if self.is_running:
return return
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
interface = interface if interface.lower() != 'localhost' else '127.0.0.1' addr = interface if interface.lower() != 'localhost' else '127.0.0.1'
self._protocol = SPVServerStatusProtocol( proto = SPVServerStatusProtocol(
height, tip, country, allow_localhost=interface == '127.0.0.1', allow_lan=allow_lan height, tip, country, allow_localhost=addr == '127.0.0.1', allow_lan=allow_lan,
is_valid_ip=is_valid_public_ipv4,
) )
await loop.create_datagram_endpoint(lambda: self._protocol, (interface, port)) await loop.create_datagram_endpoint(lambda: proto, (addr, port), family=socket.AF_INET)
log.info("started udp status server on %s:%i", interface, port) log.warning("started udp4 status server on %s", proto.transport.get_extra_info('sockname')[:2])
self._protocols.append(proto)
if not socket.has_ipv6:
return
addr = interface if interface.lower() != 'localhost' else '::1'
proto = SPVServerStatusProtocol(
height, tip, country, allow_localhost=addr == '::1', allow_lan=allow_lan,
is_valid_ip=is_valid_public_ipv6,
)
await loop.create_datagram_endpoint(lambda: proto, (addr, port), family=socket.AF_INET6)
log.warning("started udp6 status server on %s", proto.transport.get_extra_info('sockname')[:2])
self._protocols.append(proto)
async def stop(self): async def stop(self):
if self.is_running: for p in self._protocols:
await self._protocol.close() await p.close()
self._protocol = None self._protocols.clear()
@property @property
def is_running(self): def is_running(self):
return self._protocol is not None return self._protocols
def set_unavailable(self): def set_unavailable(self):
if self.is_running: for p in self._protocols:
self._protocol.set_unavailable() p.set_unavailable()
def set_available(self): def set_available(self):
if self.is_running: for p in self._protocols:
self._protocol.set_available() p.set_available()
def set_height(self, height: int, tip: bytes): def set_height(self, height: int, tip: bytes):
if self.is_running: for p in self._protocols:
self._protocol.set_height(height, tip) p.set_height(height, tip)
class SPVStatusClientProtocol(asyncio.DatagramProtocol): class SPVStatusClientProtocol(asyncio.DatagramProtocol):
@ -217,9 +273,9 @@ class SPVStatusClientProtocol(asyncio.DatagramProtocol):
self.responses = responses self.responses = responses
self._ping_packet = SPVPing.make() self._ping_packet = SPVPing.make()
def datagram_received(self, data: bytes, addr: Tuple[str, int]): def datagram_received(self, data: bytes, addr: Union[Tuple[str, int], Tuple[str, int, int, int]]):
try: try:
self.responses.put_nowait(((addr, perf_counter()), SPVPong.decode(data))) self.responses.put_nowait(((addr[:2], perf_counter()), SPVPong.decode(data)))
except (ValueError, struct.error, AttributeError, TypeError, RuntimeError): except (ValueError, struct.error, AttributeError, TypeError, RuntimeError):
return return
@ -230,7 +286,7 @@ class SPVStatusClientProtocol(asyncio.DatagramProtocol):
self.transport = None self.transport = None
log.info("closed udp spv server selection client") log.info("closed udp spv server selection client")
def ping(self, server: Tuple[str, int]): def ping(self, server: Union[Tuple[str, int], Tuple[str, int, int, int]]):
self.transport.sendto(self._ping_packet, server) self.transport.sendto(self._ping_packet, server)
def close(self): def close(self):