Rework StatusServer start() to handle lists of addresses, hostnames.
Handle and retry EADDRINUSE errors.
This commit is contained in:
parent
14f2f3b55b
commit
fa0d03fe95
2 changed files with 64 additions and 25 deletions
|
@ -1,3 +1,4 @@
|
|||
import errno
|
||||
import time
|
||||
import typing
|
||||
import asyncio
|
||||
|
@ -170,10 +171,20 @@ class HubServerService(BlockchainReaderService):
|
|||
|
||||
async def start_status_server(self):
|
||||
if self.env.udp_port and int(self.env.udp_port):
|
||||
await self.status_server.start(
|
||||
0, bytes.fromhex(self.env.coin.GENESIS_HASH)[::-1], self.env.country,
|
||||
self.env.host, self.env.udp_port, self.env.allow_lan_udp
|
||||
)
|
||||
hosts = self.env.cs_host()
|
||||
started = False
|
||||
while not started:
|
||||
try:
|
||||
await self.status_server.start(
|
||||
0, bytes.fromhex(self.env.coin.GENESIS_HASH)[::-1], self.env.country,
|
||||
hosts, self.env.udp_port, self.env.allow_lan_udp
|
||||
)
|
||||
started = True
|
||||
except OSError as e:
|
||||
if e.errno is errno.EADDRINUSE:
|
||||
await asyncio.sleep(3)
|
||||
continue
|
||||
raise
|
||||
|
||||
def _iter_start_tasks(self):
|
||||
yield self.start_status_server()
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Optional, Tuple, NamedTuple, List, Union
|
|||
from hub.schema.attrs import country_str_to_int, country_int_to_str
|
||||
from hub.common import (
|
||||
LRUCache,
|
||||
resolve_host,
|
||||
is_valid_public_ip,
|
||||
is_valid_public_ipv4,
|
||||
is_valid_public_ipv6,
|
||||
|
@ -220,29 +221,56 @@ class StatusServer:
|
|||
def __init__(self):
|
||||
self._protocols: List[SPVServerStatusProtocol] = []
|
||||
|
||||
async def start(self, height: int, tip: bytes, country: str, interface: str, port: int, allow_lan: bool = False):
|
||||
if self.is_running:
|
||||
return
|
||||
loop = asyncio.get_event_loop()
|
||||
addr = interface if interface.lower() != 'localhost' else '127.0.0.1'
|
||||
proto = SPVServerStatusProtocol(
|
||||
height, tip, country, allow_localhost=addr == '127.0.0.1', allow_lan=allow_lan,
|
||||
is_valid_ip=is_valid_public_ipv4,
|
||||
)
|
||||
await loop.create_datagram_endpoint(lambda: proto, (addr, port), family=socket.AF_INET)
|
||||
log.warning("started udp4 status server on %s", proto.transport.get_extra_info('sockname')[:2])
|
||||
self._protocols.append(proto)
|
||||
if not socket.has_ipv6:
|
||||
return
|
||||
addr = interface if interface.lower() != 'localhost' else '::1'
|
||||
proto = SPVServerStatusProtocol(
|
||||
height, tip, country, allow_localhost=addr == '::1', allow_lan=allow_lan,
|
||||
is_valid_ip=is_valid_public_ipv6,
|
||||
)
|
||||
await loop.create_datagram_endpoint(lambda: proto, (addr, port), family=socket.AF_INET6)
|
||||
log.warning("started udp6 status server on %s", proto.transport.get_extra_info('sockname')[:2])
|
||||
async def _start(self, height: int, tip: bytes, country: str, addr: str, port: int, allow_lan: bool = False):
|
||||
ipaddr = ipaddress.ip_address(addr)
|
||||
if ipaddr.version == 4:
|
||||
proto = SPVServerStatusProtocol(
|
||||
height, tip, country,
|
||||
allow_localhost=ipaddr.is_loopback or ipaddr.is_unspecified,
|
||||
allow_lan=allow_lan,
|
||||
is_valid_ip=is_valid_public_ipv4,
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.create_datagram_endpoint(lambda: proto, (ipaddr.compressed, port), family=socket.AF_INET)
|
||||
elif ipaddr.version == 6:
|
||||
proto = SPVServerStatusProtocol(
|
||||
height, tip, country,
|
||||
allow_localhost=ipaddr.is_loopback or ipaddr.is_unspecified,
|
||||
allow_lan=allow_lan,
|
||||
is_valid_ip=is_valid_public_ipv6,
|
||||
)
|
||||
# Because dualstack / IPv4 mapped address behavior on an IPv6 socket
|
||||
# differs based on system config, create the socket with IPV6_V6ONLY.
|
||||
# This disables the IPv4 mapped feature, so we don't need to consider
|
||||
# when an IPv6 socket may interfere with IPv4 binding / traffic.
|
||||
sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
|
||||
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
|
||||
sock.bind((ipaddr.compressed, port))
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.create_datagram_endpoint(lambda: proto, sock=sock)
|
||||
else:
|
||||
raise ValueError(f'unexpected IP address version {ipaddr.version}')
|
||||
log.info("started udp%i status server on %s", ipaddr.version, proto.transport.get_extra_info('sockname')[:2])
|
||||
self._protocols.append(proto)
|
||||
|
||||
async def start(self, height: int, tip: bytes, country: str, hosts: List[str], port: int, allow_lan: bool = False):
|
||||
if not isinstance(hosts, list):
|
||||
hosts = [hosts]
|
||||
try:
|
||||
for host in hosts:
|
||||
addr = None
|
||||
if not host:
|
||||
resolved = ['::', '0.0.0.0'] # unspecified address
|
||||
else:
|
||||
resolved = await resolve_host(host, port, 'udp', family=socket.AF_UNSPEC, all_results=True)
|
||||
for addr in resolved:
|
||||
await self._start(height, tip, country, addr, port, allow_lan)
|
||||
except Exception as e:
|
||||
if not isinstance(e, asyncio.CancelledError):
|
||||
log.error("UDP status server failed to listen on (%s:%i) : %s", addr or host, port, e)
|
||||
await self.stop()
|
||||
raise
|
||||
|
||||
async def stop(self):
|
||||
for proto in self._protocols:
|
||||
await proto.close()
|
||||
|
|
Loading…
Reference in a new issue