use UDP ping for wallet server selection

-only connect to one spv server at a time
-remove session pool
This commit is contained in:
Jack Robison 2021-01-21 16:15:30 -05:00
parent f0d8fb8f1a
commit 20efdc70b3
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
2 changed files with 188 additions and 178 deletions

View file

@ -119,13 +119,14 @@ class WalletComponent(Component):
async def get_status(self): async def get_status(self):
if self.wallet_manager is None: if self.wallet_manager is None:
return return
session_pool = self.wallet_manager.ledger.network.session_pool is_connected = self.wallet_manager.ledger.network.is_connected
sessions = session_pool.sessions sessions = []
connected = None connected = None
if self.wallet_manager.ledger.network.client: if is_connected:
addr_and_port = self.wallet_manager.ledger.network.client.server_address_and_port addr, port = self.wallet_manager.ledger.network.client.server
if addr_and_port: connected = f"{addr}:{port}"
connected = f"{addr_and_port[0]}:{addr_and_port[1]}" sessions.append(self.wallet_manager.ledger.network.client)
result = { result = {
'connected': connected, 'connected': connected,
'connected_features': self.wallet_manager.ledger.network.server_features, 'connected_features': self.wallet_manager.ledger.network.server_features,
@ -137,8 +138,8 @@ class WalletComponent(Component):
'availability': session.available, 'availability': session.available,
} for session in sessions } for session in sessions
], ],
'known_servers': len(sessions), 'known_servers': len(self.wallet_manager.ledger.network.config['default_servers']),
'available_servers': len(list(session_pool.available_sessions)) 'available_servers': 1 if is_connected else 0
} }
if self.wallet_manager.ledger.network.remote_height: if self.wallet_manager.ledger.network.remote_height:

View file

@ -1,26 +1,27 @@
import logging import logging
import asyncio import asyncio
import json import json
import socket
from time import perf_counter from time import perf_counter
from operator import itemgetter from collections import defaultdict
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
import aiohttp import aiohttp
from lbry import __version__ from lbry import __version__
from lbry.utils import resolve_host
from lbry.error import IncompatibleWalletServerError from lbry.error import IncompatibleWalletServerError
from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
from lbry.wallet.stream import StreamController from lbry.wallet.stream import StreamController
from lbry.wallet.server.udp import SPVStatusClientProtocol, SPVPong
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class ClientSession(BaseClientSession): class ClientSession(BaseClientSession):
def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs): def __init__(self, *args, network: 'Network', server, timeout=30, **kwargs):
self.network = network self.network = network
self.server = server self.server = server
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._on_disconnect_controller = StreamController()
self.on_disconnected = self._on_disconnect_controller.stream
self.framer.max_size = self.max_errors = 1 << 32 self.framer.max_size = self.max_errors = 1 << 32
self.timeout = timeout self.timeout = timeout
self.max_seconds_idle = timeout * 2 self.max_seconds_idle = timeout * 2
@ -28,8 +29,6 @@ class ClientSession(BaseClientSession):
self.connection_latency: Optional[float] = None self.connection_latency: Optional[float] = None
self._response_samples = 0 self._response_samples = 0
self.pending_amount = 0 self.pending_amount = 0
self._on_connect_cb = on_connect_callback or (lambda: None)
self.trigger_urgent_reconnect = asyncio.Event()
@property @property
def available(self): def available(self):
@ -56,7 +55,7 @@ class ClientSession(BaseClientSession):
async def send_request(self, method, args=()): async def send_request(self, method, args=()):
self.pending_amount += 1 self.pending_amount += 1
log.debug("send %s%s to %s:%i", method, tuple(args), *self.server) log.debug("send %s%s to %s:%i (%i timeout)", method, tuple(args), self.server[0], self.server[1], self.timeout)
try: try:
if method == 'server.version': if method == 'server.version':
return await self.send_timed_server_version_request(args, self.timeout) return await self.send_timed_server_version_request(args, self.timeout)
@ -93,38 +92,6 @@ class ClientSession(BaseClientSession):
finally: finally:
self.pending_amount -= 1 self.pending_amount -= 1
async def ensure_session(self):
# Handles reconnecting and maintaining a session alive
# TODO: change to 'ping' on newer protocol (above 1.2)
retry_delay = default_delay = 1.0
while True:
try:
if self.is_closing():
await self.create_connection(self.timeout)
await self.ensure_server_version()
self._on_connect_cb()
if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
await self.ensure_server_version()
retry_delay = default_delay
except RPCError as e:
await self.close()
log.debug("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message)
retry_delay = 60 * 60
except IncompatibleWalletServerError:
await self.close()
retry_delay = 60 * 60
log.debug("Wallet server has an incompatible version, retrying in 1h: %s:%d", *self.server)
except (asyncio.TimeoutError, OSError):
await self.close()
retry_delay = min(60, retry_delay * 2)
log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
try:
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
except asyncio.TimeoutError:
pass
finally:
self.trigger_urgent_reconnect.clear()
async def ensure_server_version(self, required=None, timeout=3): async def ensure_server_version(self, required=None, timeout=3):
required = required or self.network.PROTOCOL_VERSION required = required or self.network.PROTOCOL_VERSION
response = await asyncio.wait_for( response = await asyncio.wait_for(
@ -134,6 +101,25 @@ class ClientSession(BaseClientSession):
raise IncompatibleWalletServerError(*self.server) raise IncompatibleWalletServerError(*self.server)
return response return response
async def keepalive_loop(self, timeout=3, max_idle=60):
try:
while True:
now = perf_counter()
if min(self.last_send, self.last_packet_received) + max_idle < now:
await asyncio.wait_for(
self.send_request('server.ping', []), timeout=timeout
)
else:
await asyncio.sleep(max(0, max_idle - (now - self.last_send)))
except Exception as err:
if isinstance(err, asyncio.CancelledError):
log.warning("closing connection to %s:%i", *self.server)
else:
log.exception("lost connection to spv")
finally:
if not self.is_closing():
self._close()
async def create_connection(self, timeout=6): async def create_connection(self, timeout=6):
connector = Connector(lambda: self, *self.server) connector = Connector(lambda: self, *self.server)
start = perf_counter() start = perf_counter()
@ -150,7 +136,9 @@ class ClientSession(BaseClientSession):
self.response_time = None self.response_time = None
self.connection_latency = None self.connection_latency = None
self._response_samples = 0 self._response_samples = 0
self._on_disconnect_controller.add(True) # self._on_disconnect_controller.add(True)
if self.network:
self.network.disconnect()
class Network: class Network:
@ -160,10 +148,9 @@ class Network:
def __init__(self, ledger): def __init__(self, ledger):
self.ledger = ledger self.ledger = ledger
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
self.client: Optional[ClientSession] = None self.client: Optional[ClientSession] = None
self.server_features = None self.server_features = None
self._switch_task: Optional[asyncio.Task] = None # self._switch_task: Optional[asyncio.Task] = None
self.running = False self.running = False
self.remote_height: int = 0 self.remote_height: int = 0
self._concurrency = asyncio.Semaphore(16) self._concurrency = asyncio.Semaphore(16)
@ -183,58 +170,170 @@ class Network:
} }
self.aiohttp_session: Optional[aiohttp.ClientSession] = None self.aiohttp_session: Optional[aiohttp.ClientSession] = None
self._urgent_need_reconnect = asyncio.Event()
self._loop_task: Optional[asyncio.Task] = None
self._keepalive_task: Optional[asyncio.Task] = None
@property @property
def config(self): def config(self):
return self.ledger.config return self.ledger.config
async def switch_forever(self): def disconnect(self):
while self.running: if self._keepalive_task and not self._keepalive_task.done():
if self.is_connected: self._keepalive_task.cancel()
await self.client.on_disconnected.first self._keepalive_task = None
self.server_features = None
self.client = None
continue
self.client = await self.session_pool.wait_for_fastest_session()
log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
try:
self.server_features = await self.get_server_features()
self._update_remote_height((await self.subscribe_headers(),))
self._on_connected_controller.add(True)
log.info("Subscribed to headers: %s:%d", *self.client.server)
except (asyncio.TimeoutError, ConnectionError):
log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server)
self.client.synchronous_close()
self.server_features = None
self.client = None
async def start(self): async def start(self):
self.running = True if not self.running:
self.aiohttp_session = aiohttp.ClientSession() self.running = True
self._switch_task = asyncio.ensure_future(self.switch_forever()) self.aiohttp_session = aiohttp.ClientSession()
# this may become unnecessary when there are no more bugs found, self.on_header.listen(self._update_remote_height)
# but for now it helps understanding log reports self._loop_task = asyncio.create_task(self.network_loop())
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped.")) self._urgent_need_reconnect.set()
self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height) def loop_task_done_callback(f):
try:
f.result()
except Exception:
if self.running:
log.exception("wallet server connection loop crashed")
self._loop_task.add_done_callback(loop_task_done_callback)
async def resolve_spv_dns(self):
hostname_to_ip = {}
ip_to_hostnames = defaultdict(list)
async def resolve_spv(server, port):
try:
server_addr = await resolve_host(server, port, 'udp')
hostname_to_ip[server] = (server_addr, port)
ip_to_hostnames[(server_addr, port)].append(server)
except socket.error:
log.warning("error looking up dns for spv server %s:%i", server, port)
except Exception:
log.exception("error looking up dns for spv server %s:%i", server, port)
# accumulate the dns results
await asyncio.gather(*(resolve_spv(server, port) for (server, port) in self.config['default_servers']))
return hostname_to_ip, ip_to_hostnames
async def get_n_fastest_spvs(self, n=5, timeout=3.0) -> Dict[Tuple[str, int], SPVPong]:
loop = asyncio.get_event_loop()
pong_responses = asyncio.Queue()
connection = SPVStatusClientProtocol(pong_responses)
sent_ping_timestamps = {}
_, ip_to_hostnames = await self.resolve_spv_dns()
log.info("%i possible spv servers to try (%i urls in config)", len(ip_to_hostnames),
len(self.config['default_servers']))
pongs = {}
try:
await loop.create_datagram_endpoint(lambda: connection, ('0.0.0.0', 0))
# could raise OSError if it cant bind
start = perf_counter()
for server in ip_to_hostnames:
connection.ping(server)
sent_ping_timestamps[server] = perf_counter()
while len(pongs) < n:
(remote, ts), pong = await asyncio.wait_for(pong_responses.get(), timeout - (perf_counter() - start))
latency = ts - start
log.info("%s:%i has latency of %sms (available: %s, height: %i)",
'/'.join(ip_to_hostnames[remote]), remote[1], round(latency * 1000, 2),
pong.available, pong.height)
if pong.available:
pongs[remote] = pong
return pongs
except asyncio.TimeoutError:
if pongs:
log.info("%i/%i probed spv servers are accepting connections", len(pongs), len(ip_to_hostnames))
else:
log.warning("%i spv status probes failed, retrying later. servers tried: %s",
len(sent_ping_timestamps),
', '.join('/'.join(hosts) + f' ({ip})' for ip, hosts in ip_to_hostnames.items()))
return pongs
finally:
connection.close()
async def connect_to_fastest(self) -> Optional[ClientSession]:
fastest_spvs = await self.get_n_fastest_spvs()
for (host, port) in fastest_spvs:
client = ClientSession(network=self, server=(host, port))
try:
await client.create_connection()
log.warning("Connected to spv server %s:%i", host, port)
await client.ensure_server_version()
return client
except (asyncio.TimeoutError, ConnectionError, OSError, IncompatibleWalletServerError, RPCError):
log.warning("Connecting to %s:%d failed", host, port)
client._close()
return
async def network_loop(self):
sleep_delay = 30
while self.running:
await asyncio.wait(
[asyncio.sleep(30), self._urgent_need_reconnect.wait()], return_when=asyncio.FIRST_COMPLETED
)
if self._urgent_need_reconnect.is_set():
sleep_delay = 30
self._urgent_need_reconnect.clear()
if not self.is_connected:
client = await self.connect_to_fastest()
if not client:
log.warning("failed to connect to any spv servers, retrying later")
sleep_delay *= 2
sleep_delay = min(sleep_delay, 300)
continue
log.debug("get spv server features %s:%i", *client.server)
features = await client.send_request('server.features', [])
self.client, self.server_features = client, features
log.info("subscribe to headers %s:%i", *client.server)
self._update_remote_height((await self.subscribe_headers(),))
self._on_connected_controller.add(True)
server_str = "%s:%i" % client.server
log.info("maintaining connection to spv server %s", server_str)
self._keepalive_task = asyncio.create_task(self.client.keepalive_loop())
try:
await asyncio.wait(
[self._keepalive_task, self._urgent_need_reconnect.wait()],
return_when=asyncio.FIRST_COMPLETED
)
if self._urgent_need_reconnect.is_set():
log.warning("urgent reconnect needed")
self._urgent_need_reconnect.clear()
if self._keepalive_task and not self._keepalive_task.done():
self._keepalive_task.cancel()
except asyncio.CancelledError:
pass
finally:
self._keepalive_task = None
self.client = None
self.server_features = None
log.warning("connection lost to %s", server_str)
log.info("network loop finished")
async def stop(self): async def stop(self):
if self.running: self.running = False
self.running = False self.disconnect()
if self._loop_task and not self._loop_task.done():
self._loop_task.cancel()
self._loop_task = None
if self.aiohttp_session:
await self.aiohttp_session.close() await self.aiohttp_session.close()
self._switch_task.cancel() self.aiohttp_session = None
self.session_pool.stop()
@property @property
def is_connected(self): def is_connected(self):
return self.client and not self.client.is_closing() return self.client and not self.client.is_closing()
def rpc(self, list_or_method, args, restricted=True, session=None): def rpc(self, list_or_method, args, restricted=True, session: Optional[ClientSession] = None):
session = session or (self.client if restricted else self.session_pool.fastest_session) if session or self.is_connected:
if session and not session.is_closing(): session = session or self.client
return session.send_request(list_or_method, args) return session.send_request(list_or_method, args)
else: else:
self.session_pool.trigger_nodelay_connect() self._urgent_need_reconnect.set()
raise ConnectionError("Attempting to send rpc request when connection is not available.") raise ConnectionError("Attempting to send rpc request when connection is not available.")
async def retriable_call(self, function, *args, **kwargs): async def retriable_call(self, function, *args, **kwargs):
@ -242,14 +341,15 @@ class Network:
while self.running: while self.running:
if not self.is_connected: if not self.is_connected:
log.warning("Wallet server unavailable, waiting for it to come back and retry.") log.warning("Wallet server unavailable, waiting for it to come back and retry.")
self._urgent_need_reconnect.set()
await self.on_connected.first await self.on_connected.first
await self.session_pool.wait_for_fastest_session()
try: try:
return await function(*args, **kwargs) return await function(*args, **kwargs)
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.warning("Wallet server call timed out, retrying.") log.warning("Wallet server call timed out, retrying.")
except ConnectionError: except ConnectionError:
pass log.warning("connection error")
raise asyncio.CancelledError() # if we got here, we are shutting down raise asyncio.CancelledError() # if we got here, we are shutting down
def _update_remote_height(self, header_args): def _update_remote_height(self, header_args):
@ -340,94 +440,3 @@ class Network:
result = await r.json() result = await r.json()
return result['result'] return result['result']
class SessionPool:
def __init__(self, network: Network, timeout: float):
self.network = network
self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self.timeout = timeout
self.new_connection_event = asyncio.Event()
@property
def online(self):
return any(not session.is_closing() for session in self.sessions)
@property
def available_sessions(self):
return (session for session in self.sessions if session.available)
@property
def fastest_session(self):
if not self.online:
return None
return min(
[((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
for session in self.available_sessions] or [(0, None)],
key=itemgetter(0)
)[1]
def _get_session_connect_callback(self, session: ClientSession):
loop = asyncio.get_event_loop()
def callback():
duplicate_connections = [
s for s in self.sessions
if s is not session and s.server_address_and_port == session.server_address_and_port
]
already_connected = None if not duplicate_connections else duplicate_connections[0]
if already_connected:
self.sessions.pop(session).cancel()
session.synchronous_close()
log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour",
session.server[0], already_connected.server[0])
loop.call_later(3600, self._connect_session, session.server)
return
self.new_connection_event.set()
log.info("connected to %s:%i", *session.server)
return callback
def _connect_session(self, server: Tuple[str, int]):
session = None
for s in self.sessions:
if s.server == server:
session = s
break
if not session:
session = ClientSession(
network=self.network, server=server
)
session._on_connect_cb = self._get_session_connect_callback(session)
task = self.sessions.get(session, None)
if not task or task.done():
task = asyncio.create_task(session.ensure_session())
task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = task
def start(self, default_servers):
for server in default_servers:
self._connect_session(server)
def stop(self):
for session, task in self.sessions.items():
task.cancel()
session.synchronous_close()
self.sessions.clear()
def ensure_connections(self):
for session in self.sessions:
self._connect_session(session.server)
def trigger_nodelay_connect(self):
# used when other parts of the system sees we might have internet back
# bypasses the retry interval
for session in self.sessions:
session.trigger_urgent_reconnect.set()
async def wait_for_fastest_session(self):
while not self.fastest_session:
self.trigger_nodelay_connect()
self.new_connection_event.clear()
await self.new_connection_event.wait()
return self.fastest_session