From 34500e07c6c27250108dedf31c877fdebd2ba19e Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Fri, 19 Jul 2019 20:03:18 -0300 Subject: [PATCH] change picking logic to probe the session before considering it a good session --- .../client_tests/integration/test_network.py | 40 ++++++++++++------- torba/torba/client/basenetwork.py | 18 +++++---- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/torba/tests/client_tests/integration/test_network.py b/torba/tests/client_tests/integration/test_network.py index c74102fe5..b5b4643f1 100644 --- a/torba/tests/client_tests/integration/test_network.py +++ b/torba/tests/client_tests/integration/test_network.py @@ -53,35 +53,45 @@ class ReconnectTests(IntegrationTestCase): class ServerPickingTestCase(AsyncioTestCase): - async def _make_fake_server(self, latency=1.0, port=1337): - # local fake server with artificial latency - proto = RPCSession() - async def __handler(_): - await asyncio.sleep(latency) - return {'height': 1} - proto.handle_request = __handler - server = await self.loop.create_server(lambda: proto, host='127.0.0.1', port=port) + async def _make_fake_server(self, latency=1.0, port=1): + # local fake server with artificial latency + class FakeSession(RPCSession): + async def handle_request(self, request): + await asyncio.sleep(latency) + return {"height": 1} + server = await self.loop.create_server(lambda: FakeSession(), host='127.0.0.1', port=port) + self.addCleanup(server.close) + return '127.0.0.1', port + + async def _make_bad_server(self, port=42420): + async def echo(reader, writer): + while True: + writer.write(await reader.read()) + server = await asyncio.start_server(echo, host='127.0.0.1', port=port) self.addCleanup(server.close) return '127.0.0.1', port async def test_pick_fastest(self): ledger = Mock(config={ 'default_servers': [ - await self._make_fake_server(latency=1.5, port=1340), - await self._make_fake_server(latency=0.1, port=1337), - await self._make_fake_server(latency=1.0, port=1339), - await self._make_fake_server(latency=0.5, port=1338), + # fast but unhealthy, should be discarded + await self._make_bad_server(), + ('localhost', 1), + ('example.that.doesnt.resolve', 9000), + await self._make_fake_server(latency=1.2, port=1340), + await self._make_fake_server(latency=0.5, port=1337), + await self._make_fake_server(latency=0.7, port=1339), ], - 'connect_timeout': 30 + 'connect_timeout': 3 }) network = BaseNetwork(ledger) self.addCleanup(network.stop) asyncio.ensure_future(network.start()) - await asyncio.wait_for(network.on_connected.first, timeout=1) + await asyncio.wait_for(network.on_connected.first, timeout=3) self.assertTrue(network.is_connected) self.assertEqual(network.client.server, ('127.0.0.1', 1337)) # ensure we are connected to all of them - self.assertEqual(len(network.session_pool.sessions), 4) self.assertTrue(all([not session.is_closing() for session in network.session_pool.sessions])) + self.assertEqual(len(network.session_pool.sessions), 3) diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py index ba014a246..b40d171e8 100644 --- a/torba/torba/client/basenetwork.py +++ b/torba/torba/client/basenetwork.py @@ -47,6 +47,8 @@ class ClientSession(BaseClientSession): connector = Connector(lambda: self, *self.server) await asyncio.wait_for(connector.create_connection(), timeout=timeout) self.ping_task = asyncio.create_task(self.ping_forever()) + # tie the ping task to this connection: if the task dies for any unexpected error, abort + self.ping_task.add_done_callback(lambda _: self.abort()) async def handle_request(self, request): controller = self.network.subscription_controllers[request.method] @@ -223,10 +225,14 @@ class SessionPool: self._dead_servers = [] async def ensure_connection(self, session): - if not session.is_closing(): - return + self._dead_servers.append(session) + self.sessions.remove(session) try: - return await session.create_connection(self.timeout) + if session.is_closing(): + await session.create_connection(self.timeout) + await asyncio.wait_for(session.send_request('server.banner'), timeout=self.timeout) + self.sessions.append(session) + self._dead_servers.remove(session) except asyncio.TimeoutError: log.warning("Timeout connecting to %s:%d", *session.server) except asyncio.CancelledError: # pylint: disable=try-except-raise @@ -238,11 +244,9 @@ class SessionPool: log.warning("Could not connect to %s:%d", *session.server) else: log.exception("Connecting to %s:%d raised an exception:", *session.server) - self._dead_servers.append(session) - self.sessions.remove(session) async def get_online_sessions(self): - self._lost_master.set() while not self.online: - await asyncio.sleep(0.1) + self._lost_master.set() + await asyncio.sleep(0.5) return self.sessions