From 0bd3c3f503d9f0e9fb1d30b8ccd19c6850ee2d3a Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Fri, 12 Jul 2019 06:38:23 -0300 Subject: [PATCH] refactor protocol handler + misc --- .../integration/test_reconnect.py | 17 ++++----- torba/torba/client/basenetwork.py | 21 +++-------- torba/torba/rpc/util.py | 10 ++++++ torba/torba/server/server.py | 35 ++----------------- 4 files changed, 24 insertions(+), 59 deletions(-) diff --git a/torba/tests/client_tests/integration/test_reconnect.py b/torba/tests/client_tests/integration/test_reconnect.py index 90b122176..fda2ea082 100644 --- a/torba/tests/client_tests/integration/test_reconnect.py +++ b/torba/tests/client_tests/integration/test_reconnect.py @@ -48,10 +48,9 @@ class ReconnectTests(IntegrationTestCase): # TODO: test rolling over to a working server when an rpc request fails before raising self.assertTrue(self.ledger.network.is_connected) - - await self.assertBalance(self.account, '0.0') - address1 = await self.account.receiving.get_or_create_usable_address() + txid = await self.blockchain.send_to_address(address1, 21) + await self.blockchain.generate(1) real_sock = self.ledger.network.client.transport._extra.pop('socket') mock_sock = Mock(spec=socket.socket) @@ -70,23 +69,21 @@ class ReconnectTests(IntegrationTestCase): self.ledger.network.client.transport._sock = mock_sock self.ledger.network.client.transport._extra['socket'] = mock_sock - await self.blockchain.send_to_address(address1, 21) - await self.blockchain.generate(1) self.assertFalse(raised.is_set()) + with self.assertRaises(asyncio.CancelledError): + await self.ledger.network.get_transaction(txid) - await asyncio.wait_for(raised.wait(), 2) - await self.assertBalance(self.account, '0.0') + self.assertTrue(raised.is_set()) self.assertFalse(self.ledger.network.is_connected) self.assertIsNone(self.ledger.network.client.transport) - await self.blockchain.send_to_address(address1, 21) + txid = await self.blockchain.send_to_address(address1, 2) await self.blockchain.generate(1) await self.ledger.network.on_connected.first self.assertTrue(self.ledger.network.is_connected) - await asyncio.sleep(30, loop=self.loop) self.assertIsNotNone(self.ledger.network.client.transport) - await self.assertBalance(self.account, '42.0') + self.assertIsNotNone(await self.ledger.network.get_transaction(txid)) class ServerPickingTestCase(AsyncioTestCase): diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py index d98222b05..f644a5ba8 100644 --- a/torba/torba/client/basenetwork.py +++ b/torba/torba/client/basenetwork.py @@ -8,6 +8,7 @@ import socket from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError from torba import __version__ +from torba.rpc.util import protocol_exception_handler from torba.stream import StreamController log = logging.getLogger(__name__) @@ -55,21 +56,6 @@ class ClientSession(BaseClientSession): self.ping_task.cancel() -def protocol_exception_handler(loop: asyncio.AbstractEventLoop, context): - exception = context['exception'] - if 'protocol' not in context or 'transport' not in context: - raise exception - if not isinstance(context['protocol'], ClientSession): - raise exception - transport: asyncio.Transport = context['transport'] - message = context['message'] - if message not in ("Fatal read error on socket transport", "Fatal write error on socket transport"): - raise exception - log.warning("Disconnecting after error: %s", str(exception)) - transport.abort() - transport.close() - - class BaseNetwork: def __init__(self, ledger): @@ -219,7 +205,10 @@ class SessionPool: self.ensure_connection(session) for session in self.sessions ], return_exceptions=True) - await asyncio.wait([asyncio.sleep(3), self._lost_master.wait()], return_when='FIRST_COMPLETED') + try: + await asyncio.wait_for(self._lost_master.wait(), timeout=3) + except asyncio.TimeoutError: + pass self._lost_master.clear() if not self.sessions: self.sessions.extend(self._dead_servers) diff --git a/torba/torba/rpc/util.py b/torba/torba/rpc/util.py index f62587b71..6b599ddf0 100644 --- a/torba/torba/rpc/util.py +++ b/torba/torba/rpc/util.py @@ -93,3 +93,13 @@ class Concurrency(object): else: for _ in range(-diff): await self.semaphore.acquire() + + +def protocol_exception_handler(loop: asyncio.AbstractEventLoop, context): + message = context['message'] + transport = context.get('transport') + if transport and message in ("Fatal read error on socket transport", "Fatal write error on socket transport"): + transport.abort() + transport.close() + else: + return loop.default_exception_handler(context) diff --git a/torba/torba/server/server.py b/torba/torba/server/server.py index 893860db5..eea16c081 100644 --- a/torba/torba/server/server.py +++ b/torba/torba/server/server.py @@ -4,39 +4,8 @@ import asyncio from concurrent.futures.thread import ThreadPoolExecutor import torba +from torba.rpc.util import protocol_exception_handler from torba.server.mempool import MemPool, MemPoolAPI -from torba.server.session import SessionManager, SessionBase - - -CONNECTION_TIMED_OUT = 110 -NO_ROUTE_TO_HOST = 113 - - -def handle_socket_errors(socket_errors=(CONNECTION_TIMED_OUT, NO_ROUTE_TO_HOST)): - err_msgs = tuple((f"[Errno {err_code}]" for err_code in socket_errors)) - log = logging.getLogger(__name__) - - def protocol_exception_handler(loop, context): - exception = context['exception'] - if 'protocol' not in context or 'transport' not in context: - raise exception - if not isinstance(context['protocol'], SessionBase): - raise exception - session: SessionBase = context['protocol'] - transport: asyncio.Transport = context['transport'] - message = context['message'] - if message not in ("Fatal read error on socket transport", "Fatal write error on socket transport"): - raise exception - - for err_msg in err_msgs: - if str(exception).startswith(err_msg): - log.debug("caught: '%s' for %s", str(exception), session) - transport.abort() - transport.close() - loop.create_task(session.close(force_after=1)) - return - raise exception - return protocol_exception_handler class Notifications: @@ -121,7 +90,7 @@ class Server: ) async def start(self): - asyncio.get_event_loop().set_exception_handler(handle_socket_errors()) + asyncio.get_event_loop().set_exception_handler(protocol_exception_handler) env = self.env min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings() self.log.info(f'software version: {torba.__version__}')