diff --git a/torba/client/basenetwork.py b/torba/client/basenetwork.py index 7435350d1..0b3c8aeca 100644 --- a/torba/client/basenetwork.py +++ b/torba/client/basenetwork.py @@ -2,7 +2,7 @@ import logging import asyncio from asyncio import CancelledError from time import time -from typing import Iterable +from typing import List from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError @@ -83,7 +83,7 @@ class BaseNetwork: self.session_pool.start(self.config['default_servers']) while True: try: - self.client = await self.session_pool.pick_fastest_server() + self.client = await self.pick_fastest_session() if self.is_connected: await self.ensure_server_version() log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server) @@ -120,6 +120,21 @@ class BaseNetwork: else: raise ConnectionError("Attempting to send rpc request when connection is not available.") + async def pick_fastest_session(self): + sessions = await self.session_pool.get_online_sessions() + done, pending = await asyncio.wait([ + self.probe_session(session) + for session in sessions if not session.is_closing() + ], return_when='FIRST_COMPLETED') + for task in pending: + task.cancel() + for session in done: + return await session + + async def probe_session(self, session: ClientSession): + await session.send_request('server.banner') + return session + def ensure_server_version(self, required='1.2'): return self.rpc('server.version', [__version__, required]) @@ -152,8 +167,8 @@ class SessionPool: def __init__(self, network: BaseNetwork, timeout: float): self.network = network - self.sessions: Iterable[ClientSession] = [] - self._dead_servers: Iterable[ClientSession] = [] + self.sessions: List[ClientSession] = [] + self._dead_servers: List[ClientSession] = [] self.maintain_connections_task = None self.timeout = timeout # triggered when the master server is out, to speed up reconnect @@ -210,20 +225,8 @@ class SessionPool: self._dead_servers.append(session) self.sessions.remove(session) - async def pick_fastest_server(self): + async def get_online_sessions(self): self._lost_master.set() while not self.online: await asyncio.sleep(0.1) - - async def _probe(session): - await session.send_request('server.banner') - return session - - done, pending = await asyncio.wait([ - _probe(session) - for session in self.sessions if not session.is_closing() - ], return_when='FIRST_COMPLETED') - for task in pending: - task.cancel() - for session in done: - return await session + return self.sessions