From 20efdc70b3d97d636062bc02335f2801c82a4b96 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Thu, 21 Jan 2021 16:15:30 -0500 Subject: [PATCH] use UDP ping for wallet server selection -only connect to one spv server at a time -remove session pool --- lbry/extras/daemon/components.py | 17 +- lbry/wallet/network.py | 349 ++++++++++++++++--------------- 2 files changed, 188 insertions(+), 178 deletions(-) diff --git a/lbry/extras/daemon/components.py b/lbry/extras/daemon/components.py index 50c45612a..aeff09f3d 100644 --- a/lbry/extras/daemon/components.py +++ b/lbry/extras/daemon/components.py @@ -119,13 +119,14 @@ class WalletComponent(Component): async def get_status(self): if self.wallet_manager is None: return - session_pool = self.wallet_manager.ledger.network.session_pool - sessions = session_pool.sessions + is_connected = self.wallet_manager.ledger.network.is_connected + sessions = [] connected = None - if self.wallet_manager.ledger.network.client: - addr_and_port = self.wallet_manager.ledger.network.client.server_address_and_port - if addr_and_port: - connected = f"{addr_and_port[0]}:{addr_and_port[1]}" + if is_connected: + addr, port = self.wallet_manager.ledger.network.client.server + connected = f"{addr}:{port}" + sessions.append(self.wallet_manager.ledger.network.client) + result = { 'connected': connected, 'connected_features': self.wallet_manager.ledger.network.server_features, @@ -137,8 +138,8 @@ class WalletComponent(Component): 'availability': session.available, } for session in sessions ], - 'known_servers': len(sessions), - 'available_servers': len(list(session_pool.available_sessions)) + 'known_servers': len(self.wallet_manager.ledger.network.config['default_servers']), + 'available_servers': 1 if is_connected else 0 } if self.wallet_manager.ledger.network.remote_height: diff --git a/lbry/wallet/network.py b/lbry/wallet/network.py index da23dc4c3..8155a300a 100644 --- a/lbry/wallet/network.py +++ b/lbry/wallet/network.py @@ -1,26 +1,27 @@ import logging import asyncio import json +import socket from time import perf_counter -from operator import itemgetter +from collections import defaultdict from typing import Dict, Optional, Tuple import aiohttp from lbry import __version__ +from lbry.utils import resolve_host from lbry.error import IncompatibleWalletServerError from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError from lbry.wallet.stream import StreamController +from lbry.wallet.server.udp import SPVStatusClientProtocol, SPVPong log = logging.getLogger(__name__) class ClientSession(BaseClientSession): - def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs): + def __init__(self, *args, network: 'Network', server, timeout=30, **kwargs): self.network = network self.server = server super().__init__(*args, **kwargs) - self._on_disconnect_controller = StreamController() - self.on_disconnected = self._on_disconnect_controller.stream self.framer.max_size = self.max_errors = 1 << 32 self.timeout = timeout self.max_seconds_idle = timeout * 2 @@ -28,8 +29,6 @@ class ClientSession(BaseClientSession): self.connection_latency: Optional[float] = None self._response_samples = 0 self.pending_amount = 0 - self._on_connect_cb = on_connect_callback or (lambda: None) - self.trigger_urgent_reconnect = asyncio.Event() @property def available(self): @@ -56,7 +55,7 @@ class ClientSession(BaseClientSession): async def send_request(self, method, args=()): self.pending_amount += 1 - log.debug("send %s%s to %s:%i", method, tuple(args), *self.server) + log.debug("send %s%s to %s:%i (%i timeout)", method, tuple(args), self.server[0], self.server[1], self.timeout) try: if method == 'server.version': return await self.send_timed_server_version_request(args, self.timeout) @@ -93,38 +92,6 @@ class ClientSession(BaseClientSession): finally: self.pending_amount -= 1 - async def ensure_session(self): - # Handles reconnecting and maintaining a session alive - # TODO: change to 'ping' on newer protocol (above 1.2) - retry_delay = default_delay = 1.0 - while True: - try: - if self.is_closing(): - await self.create_connection(self.timeout) - await self.ensure_server_version() - self._on_connect_cb() - if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None: - await self.ensure_server_version() - retry_delay = default_delay - except RPCError as e: - await self.close() - log.debug("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message) - retry_delay = 60 * 60 - except IncompatibleWalletServerError: - await self.close() - retry_delay = 60 * 60 - log.debug("Wallet server has an incompatible version, retrying in 1h: %s:%d", *self.server) - except (asyncio.TimeoutError, OSError): - await self.close() - retry_delay = min(60, retry_delay * 2) - log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server) - try: - await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay) - except asyncio.TimeoutError: - pass - finally: - self.trigger_urgent_reconnect.clear() - async def ensure_server_version(self, required=None, timeout=3): required = required or self.network.PROTOCOL_VERSION response = await asyncio.wait_for( @@ -134,6 +101,25 @@ class ClientSession(BaseClientSession): raise IncompatibleWalletServerError(*self.server) return response + async def keepalive_loop(self, timeout=3, max_idle=60): + try: + while True: + now = perf_counter() + if min(self.last_send, self.last_packet_received) + max_idle < now: + await asyncio.wait_for( + self.send_request('server.ping', []), timeout=timeout + ) + else: + await asyncio.sleep(max(0, max_idle - (now - self.last_send))) + except Exception as err: + if isinstance(err, asyncio.CancelledError): + log.warning("closing connection to %s:%i", *self.server) + else: + log.exception("lost connection to spv") + finally: + if not self.is_closing(): + self._close() + async def create_connection(self, timeout=6): connector = Connector(lambda: self, *self.server) start = perf_counter() @@ -150,7 +136,9 @@ class ClientSession(BaseClientSession): self.response_time = None self.connection_latency = None self._response_samples = 0 - self._on_disconnect_controller.add(True) + # self._on_disconnect_controller.add(True) + if self.network: + self.network.disconnect() class Network: @@ -160,10 +148,9 @@ class Network: def __init__(self, ledger): self.ledger = ledger - self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6)) self.client: Optional[ClientSession] = None self.server_features = None - self._switch_task: Optional[asyncio.Task] = None + # self._switch_task: Optional[asyncio.Task] = None self.running = False self.remote_height: int = 0 self._concurrency = asyncio.Semaphore(16) @@ -183,58 +170,170 @@ class Network: } self.aiohttp_session: Optional[aiohttp.ClientSession] = None + self._urgent_need_reconnect = asyncio.Event() + self._loop_task: Optional[asyncio.Task] = None + self._keepalive_task: Optional[asyncio.Task] = None @property def config(self): return self.ledger.config - async def switch_forever(self): - while self.running: - if self.is_connected: - await self.client.on_disconnected.first - self.server_features = None - self.client = None - continue - self.client = await self.session_pool.wait_for_fastest_session() - log.info("Switching to SPV wallet server: %s:%d", *self.client.server) - try: - self.server_features = await self.get_server_features() - self._update_remote_height((await self.subscribe_headers(),)) - self._on_connected_controller.add(True) - log.info("Subscribed to headers: %s:%d", *self.client.server) - except (asyncio.TimeoutError, ConnectionError): - log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server) - self.client.synchronous_close() - self.server_features = None - self.client = None + def disconnect(self): + if self._keepalive_task and not self._keepalive_task.done(): + self._keepalive_task.cancel() + self._keepalive_task = None async def start(self): - self.running = True - self.aiohttp_session = aiohttp.ClientSession() - self._switch_task = asyncio.ensure_future(self.switch_forever()) - # this may become unnecessary when there are no more bugs found, - # but for now it helps understanding log reports - self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped.")) - self.session_pool.start(self.config['default_servers']) - self.on_header.listen(self._update_remote_height) + if not self.running: + self.running = True + self.aiohttp_session = aiohttp.ClientSession() + self.on_header.listen(self._update_remote_height) + self._loop_task = asyncio.create_task(self.network_loop()) + self._urgent_need_reconnect.set() + + def loop_task_done_callback(f): + try: + f.result() + except Exception: + if self.running: + log.exception("wallet server connection loop crashed") + + self._loop_task.add_done_callback(loop_task_done_callback) + + async def resolve_spv_dns(self): + hostname_to_ip = {} + ip_to_hostnames = defaultdict(list) + + async def resolve_spv(server, port): + try: + server_addr = await resolve_host(server, port, 'udp') + hostname_to_ip[server] = (server_addr, port) + ip_to_hostnames[(server_addr, port)].append(server) + except socket.error: + log.warning("error looking up dns for spv server %s:%i", server, port) + except Exception: + log.exception("error looking up dns for spv server %s:%i", server, port) + + # accumulate the dns results + await asyncio.gather(*(resolve_spv(server, port) for (server, port) in self.config['default_servers'])) + return hostname_to_ip, ip_to_hostnames + + async def get_n_fastest_spvs(self, n=5, timeout=3.0) -> Dict[Tuple[str, int], SPVPong]: + loop = asyncio.get_event_loop() + pong_responses = asyncio.Queue() + connection = SPVStatusClientProtocol(pong_responses) + sent_ping_timestamps = {} + _, ip_to_hostnames = await self.resolve_spv_dns() + log.info("%i possible spv servers to try (%i urls in config)", len(ip_to_hostnames), + len(self.config['default_servers'])) + pongs = {} + try: + await loop.create_datagram_endpoint(lambda: connection, ('0.0.0.0', 0)) + # could raise OSError if it cant bind + start = perf_counter() + for server in ip_to_hostnames: + connection.ping(server) + sent_ping_timestamps[server] = perf_counter() + while len(pongs) < n: + (remote, ts), pong = await asyncio.wait_for(pong_responses.get(), timeout - (perf_counter() - start)) + latency = ts - start + log.info("%s:%i has latency of %sms (available: %s, height: %i)", + '/'.join(ip_to_hostnames[remote]), remote[1], round(latency * 1000, 2), + pong.available, pong.height) + + if pong.available: + pongs[remote] = pong + return pongs + except asyncio.TimeoutError: + if pongs: + log.info("%i/%i probed spv servers are accepting connections", len(pongs), len(ip_to_hostnames)) + else: + log.warning("%i spv status probes failed, retrying later. servers tried: %s", + len(sent_ping_timestamps), + ', '.join('/'.join(hosts) + f' ({ip})' for ip, hosts in ip_to_hostnames.items())) + return pongs + finally: + connection.close() + + async def connect_to_fastest(self) -> Optional[ClientSession]: + fastest_spvs = await self.get_n_fastest_spvs() + for (host, port) in fastest_spvs: + + client = ClientSession(network=self, server=(host, port)) + try: + await client.create_connection() + log.warning("Connected to spv server %s:%i", host, port) + await client.ensure_server_version() + return client + except (asyncio.TimeoutError, ConnectionError, OSError, IncompatibleWalletServerError, RPCError): + log.warning("Connecting to %s:%d failed", host, port) + client._close() + return + + async def network_loop(self): + sleep_delay = 30 + while self.running: + await asyncio.wait( + [asyncio.sleep(30), self._urgent_need_reconnect.wait()], return_when=asyncio.FIRST_COMPLETED + ) + if self._urgent_need_reconnect.is_set(): + sleep_delay = 30 + self._urgent_need_reconnect.clear() + if not self.is_connected: + client = await self.connect_to_fastest() + if not client: + log.warning("failed to connect to any spv servers, retrying later") + sleep_delay *= 2 + sleep_delay = min(sleep_delay, 300) + continue + log.debug("get spv server features %s:%i", *client.server) + features = await client.send_request('server.features', []) + self.client, self.server_features = client, features + log.info("subscribe to headers %s:%i", *client.server) + self._update_remote_height((await self.subscribe_headers(),)) + self._on_connected_controller.add(True) + server_str = "%s:%i" % client.server + log.info("maintaining connection to spv server %s", server_str) + self._keepalive_task = asyncio.create_task(self.client.keepalive_loop()) + try: + await asyncio.wait( + [self._keepalive_task, self._urgent_need_reconnect.wait()], + return_when=asyncio.FIRST_COMPLETED + ) + if self._urgent_need_reconnect.is_set(): + log.warning("urgent reconnect needed") + self._urgent_need_reconnect.clear() + if self._keepalive_task and not self._keepalive_task.done(): + self._keepalive_task.cancel() + except asyncio.CancelledError: + pass + finally: + self._keepalive_task = None + self.client = None + self.server_features = None + log.warning("connection lost to %s", server_str) + log.info("network loop finished") async def stop(self): - if self.running: - self.running = False + self.running = False + self.disconnect() + if self._loop_task and not self._loop_task.done(): + self._loop_task.cancel() + self._loop_task = None + if self.aiohttp_session: await self.aiohttp_session.close() - self._switch_task.cancel() - self.session_pool.stop() + self.aiohttp_session = None @property def is_connected(self): return self.client and not self.client.is_closing() - def rpc(self, list_or_method, args, restricted=True, session=None): - session = session or (self.client if restricted else self.session_pool.fastest_session) - if session and not session.is_closing(): + def rpc(self, list_or_method, args, restricted=True, session: Optional[ClientSession] = None): + if session or self.is_connected: + session = session or self.client return session.send_request(list_or_method, args) else: - self.session_pool.trigger_nodelay_connect() + self._urgent_need_reconnect.set() raise ConnectionError("Attempting to send rpc request when connection is not available.") async def retriable_call(self, function, *args, **kwargs): @@ -242,14 +341,15 @@ class Network: while self.running: if not self.is_connected: log.warning("Wallet server unavailable, waiting for it to come back and retry.") + self._urgent_need_reconnect.set() await self.on_connected.first - await self.session_pool.wait_for_fastest_session() try: return await function(*args, **kwargs) except asyncio.TimeoutError: log.warning("Wallet server call timed out, retrying.") except ConnectionError: - pass + log.warning("connection error") + raise asyncio.CancelledError() # if we got here, we are shutting down def _update_remote_height(self, header_args): @@ -340,94 +440,3 @@ class Network: result = await r.json() return result['result'] - -class SessionPool: - - def __init__(self, network: Network, timeout: float): - self.network = network - self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict() - self.timeout = timeout - self.new_connection_event = asyncio.Event() - - @property - def online(self): - return any(not session.is_closing() for session in self.sessions) - - @property - def available_sessions(self): - return (session for session in self.sessions if session.available) - - @property - def fastest_session(self): - if not self.online: - return None - return min( - [((session.response_time + session.connection_latency) * (session.pending_amount + 1), session) - for session in self.available_sessions] or [(0, None)], - key=itemgetter(0) - )[1] - - def _get_session_connect_callback(self, session: ClientSession): - loop = asyncio.get_event_loop() - - def callback(): - duplicate_connections = [ - s for s in self.sessions - if s is not session and s.server_address_and_port == session.server_address_and_port - ] - already_connected = None if not duplicate_connections else duplicate_connections[0] - if already_connected: - self.sessions.pop(session).cancel() - session.synchronous_close() - log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour", - session.server[0], already_connected.server[0]) - loop.call_later(3600, self._connect_session, session.server) - return - self.new_connection_event.set() - log.info("connected to %s:%i", *session.server) - - return callback - - def _connect_session(self, server: Tuple[str, int]): - session = None - for s in self.sessions: - if s.server == server: - session = s - break - if not session: - session = ClientSession( - network=self.network, server=server - ) - session._on_connect_cb = self._get_session_connect_callback(session) - task = self.sessions.get(session, None) - if not task or task.done(): - task = asyncio.create_task(session.ensure_session()) - task.add_done_callback(lambda _: self.ensure_connections()) - self.sessions[session] = task - - def start(self, default_servers): - for server in default_servers: - self._connect_session(server) - - def stop(self): - for session, task in self.sessions.items(): - task.cancel() - session.synchronous_close() - self.sessions.clear() - - def ensure_connections(self): - for session in self.sessions: - self._connect_session(session.server) - - def trigger_nodelay_connect(self): - # used when other parts of the system sees we might have internet back - # bypasses the retry interval - for session in self.sessions: - session.trigger_urgent_reconnect.set() - - async def wait_for_fastest_session(self): - while not self.fastest_session: - self.trigger_nodelay_connect() - self.new_connection_event.clear() - await self.new_connection_event.wait() - return self.fastest_session