clients can connect to wallet server even when they are not reachable by UDP

This commit is contained in:
Lex Berezhny 2021-03-12 11:00:30 -05:00
parent fe60d4be88
commit 4343073c00
6 changed files with 42 additions and 11 deletions

View file

@ -2,6 +2,7 @@ import logging
import asyncio
import json
import socket
import random
from time import perf_counter
from collections import defaultdict
from typing import Dict, Optional, Tuple
@ -218,14 +219,14 @@ class Network:
await asyncio.gather(*(resolve_spv(server, port) for (server, port) in self.config['default_servers']))
return hostname_to_ip, ip_to_hostnames
async def get_n_fastest_spvs(self, n=5, timeout=3.0) -> Dict[Tuple[str, int], SPVPong]:
async def get_n_fastest_spvs(self, timeout=3.0) -> Dict[Tuple[str, int], SPVPong]:
loop = asyncio.get_event_loop()
pong_responses = asyncio.Queue()
connection = SPVStatusClientProtocol(pong_responses)
sent_ping_timestamps = {}
_, ip_to_hostnames = await self.resolve_spv_dns()
log.info("%i possible spv servers to try (%i urls in config)", len(ip_to_hostnames),
len(self.config['default_servers']))
n = len(ip_to_hostnames)
log.info("%i possible spv servers to try (%i urls in config)", n, len(self.config['default_servers']))
pongs = {}
try:
await loop.create_datagram_endpoint(lambda: connection, ('0.0.0.0', 0))
@ -247,11 +248,12 @@ class Network:
except asyncio.TimeoutError:
if pongs:
log.info("%i/%i probed spv servers are accepting connections", len(pongs), len(ip_to_hostnames))
return pongs
else:
log.warning("%i spv status probes failed, retrying later. servers tried: %s",
len(sent_ping_timestamps),
', '.join('/'.join(hosts) + f' ({ip})' for ip, hosts in ip_to_hostnames.items()))
return pongs
return {random.choice(list(ip_to_hostnames)): None}
finally:
connection.close()

View file

@ -168,6 +168,7 @@ class SPVNode:
self.server = None
self.hostname = 'localhost'
self.port = 50001 + node_number # avoid conflict with default daemon
self.udp_port = self.port
self.session_timeout = 600
self.rpc_port = '0' # disabled by default
@ -182,6 +183,7 @@ class SPVNode:
'REORG_LIMIT': '100',
'HOST': self.hostname,
'TCP_PORT': str(self.port),
'UDP_PORT': str(self.udp_port),
'SESSION_TIMEOUT': str(self.session_timeout),
'MAX_QUERY_WORKERS': '0',
'INDIVIDUAL_TAG_INDEXES': '',

View file

@ -57,6 +57,7 @@ class Env:
self.reorg_limit = self.integer('REORG_LIMIT', self.coin.REORG_LIMIT)
# Server stuff
self.tcp_port = self.integer('TCP_PORT', None)
self.udp_port = self.integer('UDP_PORT', self.tcp_port)
self.ssl_port = self.integer('SSL_PORT', None)
if self.ssl_port:
self.ssl_certfile = self.required('SSL_CERTFILE')

View file

@ -111,8 +111,11 @@ class Server:
return _flag.wait()
await self.start_prometheus()
await self.bp.status_server.start(0, bytes.fromhex(self.bp.coin.GENESIS_HASH)[::-1]
, self.env.host, self.env.tcp_port)
if self.env.udp_port:
await self.bp.status_server.start(
0, bytes.fromhex(self.bp.coin.GENESIS_HASH)[::-1],
self.env.host, self.env.udp_port
)
await _start_cancellable(self.bp.fetch_and_process_blocks)
await self.db.populate_header_merkle_cache()

View file

@ -138,7 +138,7 @@ class StatusServer:
self._protocol: Optional[SPVServerStatusProtocol] = None
async def start(self, height: int, tip: bytes, interface: str, port: int):
if self._protocol:
if self.is_running:
return
loop = asyncio.get_event_loop()
self._protocol = SPVServerStatusProtocol(height, tip)
@ -147,18 +147,25 @@ class StatusServer:
log.info("started udp status server on %s:%i", interface, port)
def stop(self):
if self._protocol:
if self.is_running:
self._protocol.close()
self._protocol = None
@property
def is_running(self):
return self._protocol is not None
def set_unavailable(self):
self._protocol.set_unavailable()
if self.is_running:
self._protocol.set_unavailable()
def set_available(self):
self._protocol.set_available()
if self.is_running:
self._protocol.set_available()
def set_height(self, height: int, tip: bytes):
self._protocol.set_height(height, tip)
if self.is_running:
self._protocol.set_height(height, tip)
class SPVStatusClientProtocol(asyncio.DatagramProtocol):

View file

@ -4,6 +4,7 @@ import lbry
from unittest.mock import Mock
from lbry.wallet.network import Network
from lbry.wallet.orchstr8 import Conductor
from lbry.wallet.orchstr8.node import SPVNode
from lbry.wallet.rpc import RPCSession
from lbry.wallet.server.udp import StatusServer
@ -146,6 +147,21 @@ class ReconnectTests(IntegrationTestCase):
# self.assertIsNone(self.ledger.network.session_pool.fastest_session)
class UDPServerFailDiscoveryTest(AsyncioTestCase):
async def test_wallet_connects_despite_lack_of_udp(self):
conductor = Conductor()
conductor.spv_node.udp_port = '0'
await conductor.start_blockchain()
self.addCleanup(conductor.stop_blockchain)
await conductor.start_spv()
self.addCleanup(conductor.stop_spv)
self.assertFalse(conductor.spv_node.server.bp.status_server.is_running)
await asyncio.wait_for(conductor.start_wallet(), timeout=5)
self.addCleanup(conductor.stop_wallet)
self.assertFalse(conductor.wallet_node.ledger.network.is_connected)
class ServerPickingTestCase(AsyncioTestCase):
async def _make_udp_server(self, port):
s = StatusServer()