clients can connect to wallet server even when they are not reachable by UDP

This commit is contained in:
Lex Berezhny 2021-03-12 11:00:30 -05:00
parent fe60d4be88
commit 5da2ccbafc
6 changed files with 42 additions and 11 deletions

View file

@ -2,6 +2,7 @@ import logging
import asyncio import asyncio
import json import json
import socket import socket
import random
from time import perf_counter from time import perf_counter
from collections import defaultdict from collections import defaultdict
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
@ -218,14 +219,14 @@ class Network:
await asyncio.gather(*(resolve_spv(server, port) for (server, port) in self.config['default_servers'])) await asyncio.gather(*(resolve_spv(server, port) for (server, port) in self.config['default_servers']))
return hostname_to_ip, ip_to_hostnames return hostname_to_ip, ip_to_hostnames
async def get_n_fastest_spvs(self, n=5, timeout=3.0) -> Dict[Tuple[str, int], SPVPong]: async def get_n_fastest_spvs(self, timeout=3.0) -> Dict[Tuple[str, int], SPVPong]:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
pong_responses = asyncio.Queue() pong_responses = asyncio.Queue()
connection = SPVStatusClientProtocol(pong_responses) connection = SPVStatusClientProtocol(pong_responses)
sent_ping_timestamps = {} sent_ping_timestamps = {}
_, ip_to_hostnames = await self.resolve_spv_dns() _, ip_to_hostnames = await self.resolve_spv_dns()
log.info("%i possible spv servers to try (%i urls in config)", len(ip_to_hostnames), n = len(ip_to_hostnames)
len(self.config['default_servers'])) log.info("%i possible spv servers to try (%i urls in config)", n, len(self.config['default_servers']))
pongs = {} pongs = {}
try: try:
await loop.create_datagram_endpoint(lambda: connection, ('0.0.0.0', 0)) await loop.create_datagram_endpoint(lambda: connection, ('0.0.0.0', 0))
@ -247,11 +248,12 @@ class Network:
except asyncio.TimeoutError: except asyncio.TimeoutError:
if pongs: if pongs:
log.info("%i/%i probed spv servers are accepting connections", len(pongs), len(ip_to_hostnames)) log.info("%i/%i probed spv servers are accepting connections", len(pongs), len(ip_to_hostnames))
return pongs
else: else:
log.warning("%i spv status probes failed, retrying later. servers tried: %s", log.warning("%i spv status probes failed, retrying later. servers tried: %s",
len(sent_ping_timestamps), len(sent_ping_timestamps),
', '.join('/'.join(hosts) + f' ({ip})' for ip, hosts in ip_to_hostnames.items())) ', '.join('/'.join(hosts) + f' ({ip})' for ip, hosts in ip_to_hostnames.items()))
return pongs return {random.choice(list(ip_to_hostnames)): None}
finally: finally:
connection.close() connection.close()

View file

@ -168,6 +168,7 @@ class SPVNode:
self.server = None self.server = None
self.hostname = 'localhost' self.hostname = 'localhost'
self.port = 50001 + node_number # avoid conflict with default daemon self.port = 50001 + node_number # avoid conflict with default daemon
self.udp_port = self.port
self.session_timeout = 600 self.session_timeout = 600
self.rpc_port = '0' # disabled by default self.rpc_port = '0' # disabled by default
@ -182,6 +183,7 @@ class SPVNode:
'REORG_LIMIT': '100', 'REORG_LIMIT': '100',
'HOST': self.hostname, 'HOST': self.hostname,
'TCP_PORT': str(self.port), 'TCP_PORT': str(self.port),
'UDP_PORT': str(self.udp_port),
'SESSION_TIMEOUT': str(self.session_timeout), 'SESSION_TIMEOUT': str(self.session_timeout),
'MAX_QUERY_WORKERS': '0', 'MAX_QUERY_WORKERS': '0',
'INDIVIDUAL_TAG_INDEXES': '', 'INDIVIDUAL_TAG_INDEXES': '',

View file

@ -57,6 +57,7 @@ class Env:
self.reorg_limit = self.integer('REORG_LIMIT', self.coin.REORG_LIMIT) self.reorg_limit = self.integer('REORG_LIMIT', self.coin.REORG_LIMIT)
# Server stuff # Server stuff
self.tcp_port = self.integer('TCP_PORT', None) self.tcp_port = self.integer('TCP_PORT', None)
self.udp_port = self.integer('UDP_PORT', self.tcp_port)
self.ssl_port = self.integer('SSL_PORT', None) self.ssl_port = self.integer('SSL_PORT', None)
if self.ssl_port: if self.ssl_port:
self.ssl_certfile = self.required('SSL_CERTFILE') self.ssl_certfile = self.required('SSL_CERTFILE')

View file

@ -111,8 +111,11 @@ class Server:
return _flag.wait() return _flag.wait()
await self.start_prometheus() await self.start_prometheus()
await self.bp.status_server.start(0, bytes.fromhex(self.bp.coin.GENESIS_HASH)[::-1] if self.env.udp_port:
, self.env.host, self.env.tcp_port) await self.bp.status_server.start(
0, bytes.fromhex(self.bp.coin.GENESIS_HASH)[::-1],
self.env.host, self.env.udp_port
)
await _start_cancellable(self.bp.fetch_and_process_blocks) await _start_cancellable(self.bp.fetch_and_process_blocks)
await self.db.populate_header_merkle_cache() await self.db.populate_header_merkle_cache()

View file

@ -138,7 +138,7 @@ class StatusServer:
self._protocol: Optional[SPVServerStatusProtocol] = None self._protocol: Optional[SPVServerStatusProtocol] = None
async def start(self, height: int, tip: bytes, interface: str, port: int): async def start(self, height: int, tip: bytes, interface: str, port: int):
if self._protocol: if self.is_running:
return return
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
self._protocol = SPVServerStatusProtocol(height, tip) self._protocol = SPVServerStatusProtocol(height, tip)
@ -147,18 +147,25 @@ class StatusServer:
log.info("started udp status server on %s:%i", interface, port) log.info("started udp status server on %s:%i", interface, port)
def stop(self): def stop(self):
if self._protocol: if self.is_running:
self._protocol.close() self._protocol.close()
self._protocol = None self._protocol = None
@property
def is_running(self):
return self._protocol is not None
def set_unavailable(self): def set_unavailable(self):
self._protocol.set_unavailable() if self.is_running:
self._protocol.set_unavailable()
def set_available(self): def set_available(self):
self._protocol.set_available() if self.is_running:
self._protocol.set_available()
def set_height(self, height: int, tip: bytes): def set_height(self, height: int, tip: bytes):
self._protocol.set_height(height, tip) if self.is_running:
self._protocol.set_height(height, tip)
class SPVStatusClientProtocol(asyncio.DatagramProtocol): class SPVStatusClientProtocol(asyncio.DatagramProtocol):

View file

@ -4,6 +4,7 @@ import lbry
from unittest.mock import Mock from unittest.mock import Mock
from lbry.wallet.network import Network from lbry.wallet.network import Network
from lbry.wallet.orchstr8 import Conductor
from lbry.wallet.orchstr8.node import SPVNode from lbry.wallet.orchstr8.node import SPVNode
from lbry.wallet.rpc import RPCSession from lbry.wallet.rpc import RPCSession
from lbry.wallet.server.udp import StatusServer from lbry.wallet.server.udp import StatusServer
@ -146,6 +147,21 @@ class ReconnectTests(IntegrationTestCase):
# self.assertIsNone(self.ledger.network.session_pool.fastest_session) # self.assertIsNone(self.ledger.network.session_pool.fastest_session)
class UDPServerFailDiscoveryTest(AsyncioTestCase):
async def test_wallet_connects_despite_lack_of_udp(self):
conductor = Conductor()
conductor.spv_node.udp_port = '0'
await conductor.start_blockchain()
self.addCleanup(conductor.stop_blockchain)
await conductor.start_spv()
self.addCleanup(conductor.stop_spv)
self.assertFalse(conductor.spv_node.server.bp.status_server.is_running)
await asyncio.wait_for(conductor.start_wallet(), timeout=5)
self.addCleanup(conductor.stop_wallet)
self.assertFalse(conductor.wallet_node.ledger.network.is_connected)
class ServerPickingTestCase(AsyncioTestCase): class ServerPickingTestCase(AsyncioTestCase):
async def _make_udp_server(self, port): async def _make_udp_server(self, port):
s = StatusServer() s = StatusServer()