diff --git a/lbry/tests/integration/test_wallet_server_sessions.py b/lbry/tests/integration/test_wallet_server_sessions.py index 0ba1914cf..5302d75f0 100644 --- a/lbry/tests/integration/test_wallet_server_sessions.py +++ b/lbry/tests/integration/test_wallet_server_sessions.py @@ -61,7 +61,7 @@ class TestSessionBloat(IntegrationTestCase): def delayed_resume(): self.paused_session.clear() - time.sleep(1) + time.sleep(0.9) self.resumed_session.set() self.client_session.pause_writing = wrap_callback_event(self.client_session.pause_writing, paused) diff --git a/torba/tests/client_tests/integration/test_reconnect.py b/torba/tests/client_tests/integration/test_reconnect.py index 0be16e012..90b122176 100644 --- a/torba/tests/client_tests/integration/test_reconnect.py +++ b/torba/tests/client_tests/integration/test_reconnect.py @@ -1,5 +1,6 @@ import logging import asyncio +import socket from unittest.mock import Mock from torba.client.basenetwork import BaseNetwork @@ -42,6 +43,51 @@ class ReconnectTests(IntegrationTestCase): await self.ledger.network.on_connected.first self.assertTrue(self.ledger.network.is_connected) + async def test_socket_timeout_then_reconnect(self): + # TODO: test reconnecting on an rpc request + # TODO: test rolling over to a working server when an rpc request fails before raising + + self.assertTrue(self.ledger.network.is_connected) + + await self.assertBalance(self.account, '0.0') + + address1 = await self.account.receiving.get_or_create_usable_address() + + real_sock = self.ledger.network.client.transport._extra.pop('socket') + mock_sock = Mock(spec=socket.socket) + + for attr in dir(real_sock): + if not attr.startswith('__'): + setattr(mock_sock, attr, getattr(real_sock, attr)) + + raised = asyncio.Event(loop=self.loop) + + def recv(*a, **kw): + raised.set() + raise TimeoutError("[Errno 60] Operation timed out") + + mock_sock.recv = recv + self.ledger.network.client.transport._sock = mock_sock + self.ledger.network.client.transport._extra['socket'] = mock_sock + + await self.blockchain.send_to_address(address1, 21) + await self.blockchain.generate(1) + self.assertFalse(raised.is_set()) + + await asyncio.wait_for(raised.wait(), 2) + await self.assertBalance(self.account, '0.0') + self.assertFalse(self.ledger.network.is_connected) + self.assertIsNone(self.ledger.network.client.transport) + + await self.blockchain.send_to_address(address1, 21) + await self.blockchain.generate(1) + await self.ledger.network.on_connected.first + self.assertTrue(self.ledger.network.is_connected) + + await asyncio.sleep(30, loop=self.loop) + self.assertIsNotNone(self.ledger.network.client.transport) + await self.assertBalance(self.account, '42.0') + class ServerPickingTestCase(AsyncioTestCase): async def _make_fake_server(self, latency=1.0, port=1337): diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py index 1ed55a51f..d98222b05 100644 --- a/torba/torba/client/basenetwork.py +++ b/torba/torba/client/basenetwork.py @@ -55,9 +55,25 @@ class ClientSession(BaseClientSession): self.ping_task.cancel() +def protocol_exception_handler(loop: asyncio.AbstractEventLoop, context): + exception = context['exception'] + if 'protocol' not in context or 'transport' not in context: + raise exception + if not isinstance(context['protocol'], ClientSession): + raise exception + transport: asyncio.Transport = context['transport'] + message = context['message'] + if message not in ("Fatal read error on socket transport", "Fatal write error on socket transport"): + raise exception + log.warning("Disconnecting after error: %s", str(exception)) + transport.abort() + transport.close() + + class BaseNetwork: def __init__(self, ledger): + asyncio.get_event_loop().set_exception_handler(protocol_exception_handler) self.config = ledger.config self.client: ClientSession = None self.session_pool: SessionPool = None