From e78e2738ef7d97c15b233674210c91f2e16f725e Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Tue, 6 Aug 2019 02:17:39 -0300 Subject: [PATCH] basenet refactor --- .../client_tests/integration/test_network.py | 21 ++++-- torba/torba/client/basenetwork.py | 73 +++++++++---------- 2 files changed, 48 insertions(+), 46 deletions(-) diff --git a/torba/tests/client_tests/integration/test_network.py b/torba/tests/client_tests/integration/test_network.py index b5b4643f1..5be044cef 100644 --- a/torba/tests/client_tests/integration/test_network.py +++ b/torba/tests/client_tests/integration/test_network.py @@ -22,9 +22,11 @@ class ReconnectTests(IntegrationTestCase): async def test_connection_drop_still_receives_events_after_reconnected(self): address1 = await self.account.receiving.get_or_create_usable_address() + # disconnect and send a new tx, should reconnect and get it self.ledger.network.client.connection_lost(Exception()) + self.assertFalse(self.ledger.network.is_connected) sendtxid = await self.blockchain.send_to_address(address1, 1.1337) - await self.on_transaction_id(sendtxid) # mempool + await asyncio.wait_for(self.on_transaction_id(sendtxid), 1.0) # mempool await self.blockchain.generate(1) await self.on_transaction_id(sendtxid) # confirmed @@ -45,8 +47,10 @@ class ReconnectTests(IntegrationTestCase): await self.ledger.network.get_transaction(sendtxid) async def test_timeout_then_reconnect(self): + # tests that it connects back after some failed attempts await self.conductor.spv_node.stop() self.assertFalse(self.ledger.network.is_connected) + await asyncio.sleep(0.2) # let it retry and fail once await self.conductor.spv_node.start(self.conductor.blockchain_node) await self.ledger.network.on_connected.first self.assertTrue(self.ledger.network.is_connected) @@ -79,9 +83,9 @@ class ServerPickingTestCase(AsyncioTestCase): await self._make_bad_server(), ('localhost', 1), ('example.that.doesnt.resolve', 9000), - await self._make_fake_server(latency=1.2, port=1340), - await self._make_fake_server(latency=0.5, port=1337), - await self._make_fake_server(latency=0.7, port=1339), + await self._make_fake_server(latency=1.0, port=1340), + await self._make_fake_server(latency=0.1, port=1337), + await self._make_fake_server(latency=0.4, port=1339), ], 'connect_timeout': 3 }) @@ -89,9 +93,10 @@ class ServerPickingTestCase(AsyncioTestCase): network = BaseNetwork(ledger) self.addCleanup(network.stop) asyncio.ensure_future(network.start()) - await asyncio.wait_for(network.on_connected.first, timeout=3) + await asyncio.wait_for(network.on_connected.first, timeout=1) self.assertTrue(network.is_connected) self.assertEqual(network.client.server, ('127.0.0.1', 1337)) - # ensure we are connected to all of them - self.assertTrue(all([not session.is_closing() for session in network.session_pool.sessions])) - self.assertEqual(len(network.session_pool.sessions), 3) + self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions])) + # ensure we are connected to all of them after a while + await asyncio.sleep(1) + self.assertEqual(len(network.session_pool.available_sessions), 3) diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py index 7ba0fb215..69e3eedb6 100644 --- a/torba/torba/client/basenetwork.py +++ b/torba/torba/client/basenetwork.py @@ -1,6 +1,7 @@ import logging import asyncio -from typing import Dict +from operator import itemgetter +from typing import Dict, Optional from time import time from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError @@ -26,7 +27,7 @@ class ClientSession(BaseClientSession): @property def available(self): - return not self.is_closing() and self._can_send.is_set() + return not self.is_closing() and self._can_send.is_set() and self.latency < 1 << 32 async def send_request(self, method, args=()): try: @@ -39,14 +40,11 @@ class ClientSession(BaseClientSession): except RPCError as e: log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args) raise e - except asyncio.TimeoutError: - self.abort() - raise async def ensure_session(self): # Handles reconnecting and maintaining a session alive # TODO: change to 'ping' on newer protocol (above 1.2) - retry_delay = 1.0 + retry_delay = default_delay = 0.1 while True: try: if self.is_closing(): @@ -54,10 +52,11 @@ class ClientSession(BaseClientSession): await self.ensure_server_version() if (time() - self.last_send) > self.max_seconds_idle: await self.send_request('server.banner') - retry_delay = 1.0 - except asyncio.TimeoutError: + retry_delay = default_delay + except (asyncio.TimeoutError, OSError): + await self.close() + retry_delay = min(60, retry_delay * 2) log.warning("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server) - retry_delay = max(60, retry_delay * 2) await asyncio.sleep(retry_delay) def ensure_server_version(self, required='1.2'): @@ -80,9 +79,10 @@ class ClientSession(BaseClientSession): class BaseNetwork: def __init__(self, ledger): + self.switch_event = asyncio.Event() self.config = ledger.config - self.client: ClientSession = None - self.session_pool: SessionPool = None + self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6)) + self.client: Optional[ClientSession] = None self.running = False self.remote_height: int = 0 @@ -102,8 +102,6 @@ class BaseNetwork: async def start(self): self.running = True - connect_timeout = self.config.get('connect_timeout', 6) - self.session_pool = SessionPool(network=self, timeout=connect_timeout) self.session_pool.start(self.config['default_servers']) self.on_header.listen(self._update_remote_height) while self.running: @@ -112,7 +110,9 @@ class BaseNetwork: self._update_remote_height((await self.subscribe_headers(),)) log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server) self._on_connected_controller.add(True) - await self.client.on_disconnected.first + self.client.on_disconnected.listen(lambda _: self.switch_event.set()) + await self.switch_event.wait() + self.switch_event.clear() except asyncio.CancelledError: await self.stop() raise @@ -132,12 +132,14 @@ class BaseNetwork: @property def is_connected(self): - return self.session_pool.online + return self.client and not self.client.is_closing() - async def rpc(self, list_or_method, args): + def rpc(self, list_or_method, args): if self.is_connected: - await self.session_pool.wait_for_fastest_session() - return await self.session_pool.fastest_session.send_request(list_or_method, args) + fastest = self.session_pool.fastest_session + if self.client != fastest: + self.switch_event.set() + return self.client.send_request(list_or_method, args) else: raise ConnectionError("Attempting to send rpc request when connection is not available.") @@ -178,7 +180,6 @@ class SessionPool: def __init__(self, network: BaseNetwork, timeout: float): self.network = network self.sessions: Dict[ClientSession, asyncio.Task] = dict() - self.maintain_connections_task = None self.timeout = timeout @property @@ -193,31 +194,27 @@ class SessionPool: def fastest_session(self): if not self.available_sessions: return None - return min([(session.latency, session) for session in self.available_sessions])[1] + return min([(session.latency, session) for session in self.available_sessions], key=itemgetter(0))[1] def start(self, default_servers): - for server in default_servers: - session = ClientSession(network=self.network, server=server) - self.sessions[session] = asyncio.create_task(session.ensure_session()) - self.maintain_connections_task = asyncio.create_task(self.ensure_connections()) + self.sessions = { + ClientSession(network=self.network, server=server): None for server in default_servers + } + self.ensure_connections() def stop(self): - if self.maintain_connections_task: - self.maintain_connections_task.cancel() - self.maintain_connections_task = None - for session, maintenance_task in self.sessions.items(): - maintenance_task.cancel() - if not session.is_closing(): - session.abort() + for session, task in self.sessions.items(): + task.cancel() + session.abort() self.sessions.clear() - async def ensure_connections(self): - while True: - log.info("Checking conns") - for session, task in list(self.sessions.items()): - if task.done(): - self.sessions[session] = asyncio.create_task(session.ensure_session()) - await asyncio.wait(self.sessions.items(), timeout=10) + def ensure_connections(self): + for session, task in list(self.sessions.items()): + if not task or task.done(): + task = asyncio.create_task(session.ensure_session()) + task.add_done_callback(lambda _: self.ensure_connections()) + self.sessions[session] = task + async def wait_for_fastest_session(self): while True: