diff --git a/tests/client_tests/integration/test_reconnect.py b/tests/client_tests/integration/test_reconnect.py index 0442cd4cd..0be16e012 100644 --- a/tests/client_tests/integration/test_reconnect.py +++ b/tests/client_tests/integration/test_reconnect.py @@ -2,7 +2,6 @@ import logging import asyncio from unittest.mock import Mock -from torba.client.baseledger import BaseLedger from torba.client.basenetwork import BaseNetwork from torba.rpc import RPCSession from torba.testcase import IntegrationTestCase, AsyncioTestCase @@ -70,3 +69,6 @@ class ServerPickingTestCase(AsyncioTestCase): await asyncio.wait_for(network.on_connected.first, timeout=1) self.assertTrue(network.is_connected) self.assertEqual(network.client.server, ('127.0.0.1', 1337)) + # ensure we are connected to all of them + self.assertEqual(len(network.session_pool.sessions), 4) + self.assertTrue(all([not session.is_closing() for session in network.session_pool.sessions])) diff --git a/torba/client/basenetwork.py b/torba/client/basenetwork.py index fa1f934c3..5b4a1203a 100644 --- a/torba/client/basenetwork.py +++ b/torba/client/basenetwork.py @@ -58,6 +58,7 @@ class BaseNetwork: def __init__(self, ledger): self.config = ledger.config self.client: ClientSession = None + self.session_pool: SessionPool = None self.running = False self._on_connected_controller = StreamController() @@ -74,40 +75,18 @@ class BaseNetwork: 'blockchain.address.subscribe': self._on_status_controller, } - async def pick_fastest_server(self, timeout): - async def __probe(server): - client = ClientSession(network=self, server=server) - try: - await client.create_connection(timeout) - await client.send_request('server.banner') - return client - except (asyncio.TimeoutError, asyncio.CancelledError): - if not client.is_closing(): - client.abort() - raise - except Exception: # pylint: disable=broad-except - log.exception("Connecting to %s:%d raised an exception:", *server) - futures = [] - for server in self.config['default_servers']: - futures.append(__probe(server)) - done, pending = await asyncio.wait(futures, return_when='FIRST_COMPLETED') - for task in pending: - task.cancel() - for client in done: - return await client - async def start(self): self.running = True - delay = 0.0 connect_timeout = self.config.get('connect_timeout', 6) + self.session_pool = SessionPool(network=self, timeout=connect_timeout) + self.session_pool.start(self.config['default_servers']) while True: try: - self.client = await self.pick_fastest_server(connect_timeout) + self.client = await self.session_pool.pick_fastest_server() if self.is_connected: await self.ensure_server_version() log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server) self._on_connected_controller.add(True) - delay = 0.0 await self.client.on_disconnected.first except CancelledError: self.running = False @@ -120,11 +99,11 @@ class BaseNetwork: elif self.client: await self.client.close() self.client.connection.cancel_pending_requests() - await asyncio.sleep(delay) - delay = min(delay + 1.0, 10.0) async def stop(self): self.running = False + if self.session_pool: + self.session_pool.stop() if self.is_connected: disconnected = self.client.on_disconnected.first await self.client.close() @@ -166,3 +145,84 @@ class BaseNetwork: def subscribe_address(self, address): return self.rpc('blockchain.address.subscribe', [address]) + + +class SessionPool: + + def __init__(self, network: BaseNetwork, timeout: float): + self.network = network + self.sessions = [] + self._dead_servers = [] + self.maintain_connections_task = None + self.timeout = timeout + # triggered when the master server is out, to speed up reconnect + self._lost_master = asyncio.Event() + + @property + def online(self): + for session in self.sessions: + if not session.is_closing(): + return True + return False + + def start(self, default_servers): + self.sessions = [ + ClientSession(network=self.network, server=server) + for server in default_servers + ] + self.maintain_connections_task = asyncio.create_task(self.ensure_connections()) + + def stop(self): + if self.maintain_connections_task: + self.maintain_connections_task.cancel() + for session in self.sessions: + if not session.is_closing(): + session.abort() + self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None + + async def ensure_connections(self): + while True: + await asyncio.gather(*[ + self.ensure_connection(session) + for session in self.sessions + ], return_exceptions=True) + await asyncio.wait([asyncio.sleep(3), self._lost_master.wait()], return_when='FIRST_COMPLETED') + self._lost_master.clear() + if not self.sessions: + self.sessions.extend(self._dead_servers) + self._dead_servers = [] + + async def ensure_connection(self, session): + if not session.is_closing(): + return + try: + return await session.create_connection(self.timeout) + except asyncio.TimeoutError: + log.warning("Timeout connecting to %s:%d", *session.server) + except asyncio.CancelledError: # pylint: disable=try-except-raise + raise + except Exception as err: # pylint: disable=broad-except + if 'Connect call failed' in str(err): + log.warning("Could not connect to %s:%d", *session.server) + else: + log.exception("Connecting to %s:%d raised an exception:", *session.server) + self._dead_servers.append(session) + self.sessions.remove(session) + + async def pick_fastest_server(self): + self._lost_master.set() + while not self.online: + await asyncio.sleep(0.1) + + async def _probe(session): + await session.send_request('server.banner') + return session + + done, pending = await asyncio.wait([ + _probe(session) + for session in self.sessions if not session.is_closing() + ], return_when='FIRST_COMPLETED') + for task in pending: + task.cancel() + for session in done: + return await session