forked from LBRYCommunity/lbry-sdk
use UDP ping for wallet server selection
-only connect to one spv server at a time -remove session pool
This commit is contained in:
parent
f0d8fb8f1a
commit
20efdc70b3
2 changed files with 188 additions and 178 deletions
|
@ -119,13 +119,14 @@ class WalletComponent(Component):
|
|||
async def get_status(self):
|
||||
if self.wallet_manager is None:
|
||||
return
|
||||
session_pool = self.wallet_manager.ledger.network.session_pool
|
||||
sessions = session_pool.sessions
|
||||
is_connected = self.wallet_manager.ledger.network.is_connected
|
||||
sessions = []
|
||||
connected = None
|
||||
if self.wallet_manager.ledger.network.client:
|
||||
addr_and_port = self.wallet_manager.ledger.network.client.server_address_and_port
|
||||
if addr_and_port:
|
||||
connected = f"{addr_and_port[0]}:{addr_and_port[1]}"
|
||||
if is_connected:
|
||||
addr, port = self.wallet_manager.ledger.network.client.server
|
||||
connected = f"{addr}:{port}"
|
||||
sessions.append(self.wallet_manager.ledger.network.client)
|
||||
|
||||
result = {
|
||||
'connected': connected,
|
||||
'connected_features': self.wallet_manager.ledger.network.server_features,
|
||||
|
@ -137,8 +138,8 @@ class WalletComponent(Component):
|
|||
'availability': session.available,
|
||||
} for session in sessions
|
||||
],
|
||||
'known_servers': len(sessions),
|
||||
'available_servers': len(list(session_pool.available_sessions))
|
||||
'known_servers': len(self.wallet_manager.ledger.network.config['default_servers']),
|
||||
'available_servers': 1 if is_connected else 0
|
||||
}
|
||||
|
||||
if self.wallet_manager.ledger.network.remote_height:
|
||||
|
|
|
@ -1,26 +1,27 @@
|
|||
import logging
|
||||
import asyncio
|
||||
import json
|
||||
import socket
|
||||
from time import perf_counter
|
||||
from operator import itemgetter
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Optional, Tuple
|
||||
import aiohttp
|
||||
|
||||
from lbry import __version__
|
||||
from lbry.utils import resolve_host
|
||||
from lbry.error import IncompatibleWalletServerError
|
||||
from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
|
||||
from lbry.wallet.stream import StreamController
|
||||
from lbry.wallet.server.udp import SPVStatusClientProtocol, SPVPong
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClientSession(BaseClientSession):
|
||||
def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
|
||||
def __init__(self, *args, network: 'Network', server, timeout=30, **kwargs):
|
||||
self.network = network
|
||||
self.server = server
|
||||
super().__init__(*args, **kwargs)
|
||||
self._on_disconnect_controller = StreamController()
|
||||
self.on_disconnected = self._on_disconnect_controller.stream
|
||||
self.framer.max_size = self.max_errors = 1 << 32
|
||||
self.timeout = timeout
|
||||
self.max_seconds_idle = timeout * 2
|
||||
|
@ -28,8 +29,6 @@ class ClientSession(BaseClientSession):
|
|||
self.connection_latency: Optional[float] = None
|
||||
self._response_samples = 0
|
||||
self.pending_amount = 0
|
||||
self._on_connect_cb = on_connect_callback or (lambda: None)
|
||||
self.trigger_urgent_reconnect = asyncio.Event()
|
||||
|
||||
@property
|
||||
def available(self):
|
||||
|
@ -56,7 +55,7 @@ class ClientSession(BaseClientSession):
|
|||
|
||||
async def send_request(self, method, args=()):
|
||||
self.pending_amount += 1
|
||||
log.debug("send %s%s to %s:%i", method, tuple(args), *self.server)
|
||||
log.debug("send %s%s to %s:%i (%i timeout)", method, tuple(args), self.server[0], self.server[1], self.timeout)
|
||||
try:
|
||||
if method == 'server.version':
|
||||
return await self.send_timed_server_version_request(args, self.timeout)
|
||||
|
@ -93,38 +92,6 @@ class ClientSession(BaseClientSession):
|
|||
finally:
|
||||
self.pending_amount -= 1
|
||||
|
||||
async def ensure_session(self):
|
||||
# Handles reconnecting and maintaining a session alive
|
||||
# TODO: change to 'ping' on newer protocol (above 1.2)
|
||||
retry_delay = default_delay = 1.0
|
||||
while True:
|
||||
try:
|
||||
if self.is_closing():
|
||||
await self.create_connection(self.timeout)
|
||||
await self.ensure_server_version()
|
||||
self._on_connect_cb()
|
||||
if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
|
||||
await self.ensure_server_version()
|
||||
retry_delay = default_delay
|
||||
except RPCError as e:
|
||||
await self.close()
|
||||
log.debug("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message)
|
||||
retry_delay = 60 * 60
|
||||
except IncompatibleWalletServerError:
|
||||
await self.close()
|
||||
retry_delay = 60 * 60
|
||||
log.debug("Wallet server has an incompatible version, retrying in 1h: %s:%d", *self.server)
|
||||
except (asyncio.TimeoutError, OSError):
|
||||
await self.close()
|
||||
retry_delay = min(60, retry_delay * 2)
|
||||
log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
|
||||
try:
|
||||
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
self.trigger_urgent_reconnect.clear()
|
||||
|
||||
async def ensure_server_version(self, required=None, timeout=3):
|
||||
required = required or self.network.PROTOCOL_VERSION
|
||||
response = await asyncio.wait_for(
|
||||
|
@ -134,6 +101,25 @@ class ClientSession(BaseClientSession):
|
|||
raise IncompatibleWalletServerError(*self.server)
|
||||
return response
|
||||
|
||||
async def keepalive_loop(self, timeout=3, max_idle=60):
|
||||
try:
|
||||
while True:
|
||||
now = perf_counter()
|
||||
if min(self.last_send, self.last_packet_received) + max_idle < now:
|
||||
await asyncio.wait_for(
|
||||
self.send_request('server.ping', []), timeout=timeout
|
||||
)
|
||||
else:
|
||||
await asyncio.sleep(max(0, max_idle - (now - self.last_send)))
|
||||
except Exception as err:
|
||||
if isinstance(err, asyncio.CancelledError):
|
||||
log.warning("closing connection to %s:%i", *self.server)
|
||||
else:
|
||||
log.exception("lost connection to spv")
|
||||
finally:
|
||||
if not self.is_closing():
|
||||
self._close()
|
||||
|
||||
async def create_connection(self, timeout=6):
|
||||
connector = Connector(lambda: self, *self.server)
|
||||
start = perf_counter()
|
||||
|
@ -150,7 +136,9 @@ class ClientSession(BaseClientSession):
|
|||
self.response_time = None
|
||||
self.connection_latency = None
|
||||
self._response_samples = 0
|
||||
self._on_disconnect_controller.add(True)
|
||||
# self._on_disconnect_controller.add(True)
|
||||
if self.network:
|
||||
self.network.disconnect()
|
||||
|
||||
|
||||
class Network:
|
||||
|
@ -160,10 +148,9 @@ class Network:
|
|||
|
||||
def __init__(self, ledger):
|
||||
self.ledger = ledger
|
||||
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
|
||||
self.client: Optional[ClientSession] = None
|
||||
self.server_features = None
|
||||
self._switch_task: Optional[asyncio.Task] = None
|
||||
# self._switch_task: Optional[asyncio.Task] = None
|
||||
self.running = False
|
||||
self.remote_height: int = 0
|
||||
self._concurrency = asyncio.Semaphore(16)
|
||||
|
@ -183,58 +170,170 @@ class Network:
|
|||
}
|
||||
|
||||
self.aiohttp_session: Optional[aiohttp.ClientSession] = None
|
||||
self._urgent_need_reconnect = asyncio.Event()
|
||||
self._loop_task: Optional[asyncio.Task] = None
|
||||
self._keepalive_task: Optional[asyncio.Task] = None
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self.ledger.config
|
||||
|
||||
async def switch_forever(self):
|
||||
while self.running:
|
||||
if self.is_connected:
|
||||
await self.client.on_disconnected.first
|
||||
self.server_features = None
|
||||
self.client = None
|
||||
continue
|
||||
self.client = await self.session_pool.wait_for_fastest_session()
|
||||
log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
|
||||
try:
|
||||
self.server_features = await self.get_server_features()
|
||||
self._update_remote_height((await self.subscribe_headers(),))
|
||||
self._on_connected_controller.add(True)
|
||||
log.info("Subscribed to headers: %s:%d", *self.client.server)
|
||||
except (asyncio.TimeoutError, ConnectionError):
|
||||
log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server)
|
||||
self.client.synchronous_close()
|
||||
self.server_features = None
|
||||
self.client = None
|
||||
def disconnect(self):
|
||||
if self._keepalive_task and not self._keepalive_task.done():
|
||||
self._keepalive_task.cancel()
|
||||
self._keepalive_task = None
|
||||
|
||||
async def start(self):
|
||||
self.running = True
|
||||
self.aiohttp_session = aiohttp.ClientSession()
|
||||
self._switch_task = asyncio.ensure_future(self.switch_forever())
|
||||
# this may become unnecessary when there are no more bugs found,
|
||||
# but for now it helps understanding log reports
|
||||
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
|
||||
self.session_pool.start(self.config['default_servers'])
|
||||
self.on_header.listen(self._update_remote_height)
|
||||
if not self.running:
|
||||
self.running = True
|
||||
self.aiohttp_session = aiohttp.ClientSession()
|
||||
self.on_header.listen(self._update_remote_height)
|
||||
self._loop_task = asyncio.create_task(self.network_loop())
|
||||
self._urgent_need_reconnect.set()
|
||||
|
||||
def loop_task_done_callback(f):
|
||||
try:
|
||||
f.result()
|
||||
except Exception:
|
||||
if self.running:
|
||||
log.exception("wallet server connection loop crashed")
|
||||
|
||||
self._loop_task.add_done_callback(loop_task_done_callback)
|
||||
|
||||
async def resolve_spv_dns(self):
|
||||
hostname_to_ip = {}
|
||||
ip_to_hostnames = defaultdict(list)
|
||||
|
||||
async def resolve_spv(server, port):
|
||||
try:
|
||||
server_addr = await resolve_host(server, port, 'udp')
|
||||
hostname_to_ip[server] = (server_addr, port)
|
||||
ip_to_hostnames[(server_addr, port)].append(server)
|
||||
except socket.error:
|
||||
log.warning("error looking up dns for spv server %s:%i", server, port)
|
||||
except Exception:
|
||||
log.exception("error looking up dns for spv server %s:%i", server, port)
|
||||
|
||||
# accumulate the dns results
|
||||
await asyncio.gather(*(resolve_spv(server, port) for (server, port) in self.config['default_servers']))
|
||||
return hostname_to_ip, ip_to_hostnames
|
||||
|
||||
async def get_n_fastest_spvs(self, n=5, timeout=3.0) -> Dict[Tuple[str, int], SPVPong]:
|
||||
loop = asyncio.get_event_loop()
|
||||
pong_responses = asyncio.Queue()
|
||||
connection = SPVStatusClientProtocol(pong_responses)
|
||||
sent_ping_timestamps = {}
|
||||
_, ip_to_hostnames = await self.resolve_spv_dns()
|
||||
log.info("%i possible spv servers to try (%i urls in config)", len(ip_to_hostnames),
|
||||
len(self.config['default_servers']))
|
||||
pongs = {}
|
||||
try:
|
||||
await loop.create_datagram_endpoint(lambda: connection, ('0.0.0.0', 0))
|
||||
# could raise OSError if it cant bind
|
||||
start = perf_counter()
|
||||
for server in ip_to_hostnames:
|
||||
connection.ping(server)
|
||||
sent_ping_timestamps[server] = perf_counter()
|
||||
while len(pongs) < n:
|
||||
(remote, ts), pong = await asyncio.wait_for(pong_responses.get(), timeout - (perf_counter() - start))
|
||||
latency = ts - start
|
||||
log.info("%s:%i has latency of %sms (available: %s, height: %i)",
|
||||
'/'.join(ip_to_hostnames[remote]), remote[1], round(latency * 1000, 2),
|
||||
pong.available, pong.height)
|
||||
|
||||
if pong.available:
|
||||
pongs[remote] = pong
|
||||
return pongs
|
||||
except asyncio.TimeoutError:
|
||||
if pongs:
|
||||
log.info("%i/%i probed spv servers are accepting connections", len(pongs), len(ip_to_hostnames))
|
||||
else:
|
||||
log.warning("%i spv status probes failed, retrying later. servers tried: %s",
|
||||
len(sent_ping_timestamps),
|
||||
', '.join('/'.join(hosts) + f' ({ip})' for ip, hosts in ip_to_hostnames.items()))
|
||||
return pongs
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
async def connect_to_fastest(self) -> Optional[ClientSession]:
|
||||
fastest_spvs = await self.get_n_fastest_spvs()
|
||||
for (host, port) in fastest_spvs:
|
||||
|
||||
client = ClientSession(network=self, server=(host, port))
|
||||
try:
|
||||
await client.create_connection()
|
||||
log.warning("Connected to spv server %s:%i", host, port)
|
||||
await client.ensure_server_version()
|
||||
return client
|
||||
except (asyncio.TimeoutError, ConnectionError, OSError, IncompatibleWalletServerError, RPCError):
|
||||
log.warning("Connecting to %s:%d failed", host, port)
|
||||
client._close()
|
||||
return
|
||||
|
||||
async def network_loop(self):
|
||||
sleep_delay = 30
|
||||
while self.running:
|
||||
await asyncio.wait(
|
||||
[asyncio.sleep(30), self._urgent_need_reconnect.wait()], return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
if self._urgent_need_reconnect.is_set():
|
||||
sleep_delay = 30
|
||||
self._urgent_need_reconnect.clear()
|
||||
if not self.is_connected:
|
||||
client = await self.connect_to_fastest()
|
||||
if not client:
|
||||
log.warning("failed to connect to any spv servers, retrying later")
|
||||
sleep_delay *= 2
|
||||
sleep_delay = min(sleep_delay, 300)
|
||||
continue
|
||||
log.debug("get spv server features %s:%i", *client.server)
|
||||
features = await client.send_request('server.features', [])
|
||||
self.client, self.server_features = client, features
|
||||
log.info("subscribe to headers %s:%i", *client.server)
|
||||
self._update_remote_height((await self.subscribe_headers(),))
|
||||
self._on_connected_controller.add(True)
|
||||
server_str = "%s:%i" % client.server
|
||||
log.info("maintaining connection to spv server %s", server_str)
|
||||
self._keepalive_task = asyncio.create_task(self.client.keepalive_loop())
|
||||
try:
|
||||
await asyncio.wait(
|
||||
[self._keepalive_task, self._urgent_need_reconnect.wait()],
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
if self._urgent_need_reconnect.is_set():
|
||||
log.warning("urgent reconnect needed")
|
||||
self._urgent_need_reconnect.clear()
|
||||
if self._keepalive_task and not self._keepalive_task.done():
|
||||
self._keepalive_task.cancel()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
self._keepalive_task = None
|
||||
self.client = None
|
||||
self.server_features = None
|
||||
log.warning("connection lost to %s", server_str)
|
||||
log.info("network loop finished")
|
||||
|
||||
async def stop(self):
|
||||
if self.running:
|
||||
self.running = False
|
||||
self.running = False
|
||||
self.disconnect()
|
||||
if self._loop_task and not self._loop_task.done():
|
||||
self._loop_task.cancel()
|
||||
self._loop_task = None
|
||||
if self.aiohttp_session:
|
||||
await self.aiohttp_session.close()
|
||||
self._switch_task.cancel()
|
||||
self.session_pool.stop()
|
||||
self.aiohttp_session = None
|
||||
|
||||
@property
|
||||
def is_connected(self):
|
||||
return self.client and not self.client.is_closing()
|
||||
|
||||
def rpc(self, list_or_method, args, restricted=True, session=None):
|
||||
session = session or (self.client if restricted else self.session_pool.fastest_session)
|
||||
if session and not session.is_closing():
|
||||
def rpc(self, list_or_method, args, restricted=True, session: Optional[ClientSession] = None):
|
||||
if session or self.is_connected:
|
||||
session = session or self.client
|
||||
return session.send_request(list_or_method, args)
|
||||
else:
|
||||
self.session_pool.trigger_nodelay_connect()
|
||||
self._urgent_need_reconnect.set()
|
||||
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
||||
|
||||
async def retriable_call(self, function, *args, **kwargs):
|
||||
|
@ -242,14 +341,15 @@ class Network:
|
|||
while self.running:
|
||||
if not self.is_connected:
|
||||
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
|
||||
self._urgent_need_reconnect.set()
|
||||
await self.on_connected.first
|
||||
await self.session_pool.wait_for_fastest_session()
|
||||
try:
|
||||
return await function(*args, **kwargs)
|
||||
except asyncio.TimeoutError:
|
||||
log.warning("Wallet server call timed out, retrying.")
|
||||
except ConnectionError:
|
||||
pass
|
||||
log.warning("connection error")
|
||||
|
||||
raise asyncio.CancelledError() # if we got here, we are shutting down
|
||||
|
||||
def _update_remote_height(self, header_args):
|
||||
|
@ -340,94 +440,3 @@ class Network:
|
|||
result = await r.json()
|
||||
return result['result']
|
||||
|
||||
|
||||
class SessionPool:
|
||||
|
||||
def __init__(self, network: Network, timeout: float):
|
||||
self.network = network
|
||||
self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
|
||||
self.timeout = timeout
|
||||
self.new_connection_event = asyncio.Event()
|
||||
|
||||
@property
|
||||
def online(self):
|
||||
return any(not session.is_closing() for session in self.sessions)
|
||||
|
||||
@property
|
||||
def available_sessions(self):
|
||||
return (session for session in self.sessions if session.available)
|
||||
|
||||
@property
|
||||
def fastest_session(self):
|
||||
if not self.online:
|
||||
return None
|
||||
return min(
|
||||
[((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
|
||||
for session in self.available_sessions] or [(0, None)],
|
||||
key=itemgetter(0)
|
||||
)[1]
|
||||
|
||||
def _get_session_connect_callback(self, session: ClientSession):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def callback():
|
||||
duplicate_connections = [
|
||||
s for s in self.sessions
|
||||
if s is not session and s.server_address_and_port == session.server_address_and_port
|
||||
]
|
||||
already_connected = None if not duplicate_connections else duplicate_connections[0]
|
||||
if already_connected:
|
||||
self.sessions.pop(session).cancel()
|
||||
session.synchronous_close()
|
||||
log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour",
|
||||
session.server[0], already_connected.server[0])
|
||||
loop.call_later(3600, self._connect_session, session.server)
|
||||
return
|
||||
self.new_connection_event.set()
|
||||
log.info("connected to %s:%i", *session.server)
|
||||
|
||||
return callback
|
||||
|
||||
def _connect_session(self, server: Tuple[str, int]):
|
||||
session = None
|
||||
for s in self.sessions:
|
||||
if s.server == server:
|
||||
session = s
|
||||
break
|
||||
if not session:
|
||||
session = ClientSession(
|
||||
network=self.network, server=server
|
||||
)
|
||||
session._on_connect_cb = self._get_session_connect_callback(session)
|
||||
task = self.sessions.get(session, None)
|
||||
if not task or task.done():
|
||||
task = asyncio.create_task(session.ensure_session())
|
||||
task.add_done_callback(lambda _: self.ensure_connections())
|
||||
self.sessions[session] = task
|
||||
|
||||
def start(self, default_servers):
|
||||
for server in default_servers:
|
||||
self._connect_session(server)
|
||||
|
||||
def stop(self):
|
||||
for session, task in self.sessions.items():
|
||||
task.cancel()
|
||||
session.synchronous_close()
|
||||
self.sessions.clear()
|
||||
|
||||
def ensure_connections(self):
|
||||
for session in self.sessions:
|
||||
self._connect_session(session.server)
|
||||
|
||||
def trigger_nodelay_connect(self):
|
||||
# used when other parts of the system sees we might have internet back
|
||||
# bypasses the retry interval
|
||||
for session in self.sessions:
|
||||
session.trigger_urgent_reconnect.set()
|
||||
|
||||
async def wait_for_fastest_session(self):
|
||||
while not self.fastest_session:
|
||||
self.trigger_nodelay_connect()
|
||||
self.new_connection_event.clear()
|
||||
await self.new_connection_event.wait()
|
||||
return self.fastest_session
|
||||
|
|
Loading…
Reference in a new issue