reconnect torba ClientSession on fatal socket errors
This commit is contained in:
parent
c5aeac6898
commit
489da88e79
3 changed files with 63 additions and 1 deletions
|
@ -61,7 +61,7 @@ class TestSessionBloat(IntegrationTestCase):
|
|||
def delayed_resume():
|
||||
self.paused_session.clear()
|
||||
|
||||
time.sleep(1)
|
||||
time.sleep(0.9)
|
||||
self.resumed_session.set()
|
||||
|
||||
self.client_session.pause_writing = wrap_callback_event(self.client_session.pause_writing, paused)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
import asyncio
|
||||
import socket
|
||||
from unittest.mock import Mock
|
||||
|
||||
from torba.client.basenetwork import BaseNetwork
|
||||
|
@ -42,6 +43,51 @@ class ReconnectTests(IntegrationTestCase):
|
|||
await self.ledger.network.on_connected.first
|
||||
self.assertTrue(self.ledger.network.is_connected)
|
||||
|
||||
async def test_socket_timeout_then_reconnect(self):
|
||||
# TODO: test reconnecting on an rpc request
|
||||
# TODO: test rolling over to a working server when an rpc request fails before raising
|
||||
|
||||
self.assertTrue(self.ledger.network.is_connected)
|
||||
|
||||
await self.assertBalance(self.account, '0.0')
|
||||
|
||||
address1 = await self.account.receiving.get_or_create_usable_address()
|
||||
|
||||
real_sock = self.ledger.network.client.transport._extra.pop('socket')
|
||||
mock_sock = Mock(spec=socket.socket)
|
||||
|
||||
for attr in dir(real_sock):
|
||||
if not attr.startswith('__'):
|
||||
setattr(mock_sock, attr, getattr(real_sock, attr))
|
||||
|
||||
raised = asyncio.Event(loop=self.loop)
|
||||
|
||||
def recv(*a, **kw):
|
||||
raised.set()
|
||||
raise TimeoutError("[Errno 60] Operation timed out")
|
||||
|
||||
mock_sock.recv = recv
|
||||
self.ledger.network.client.transport._sock = mock_sock
|
||||
self.ledger.network.client.transport._extra['socket'] = mock_sock
|
||||
|
||||
await self.blockchain.send_to_address(address1, 21)
|
||||
await self.blockchain.generate(1)
|
||||
self.assertFalse(raised.is_set())
|
||||
|
||||
await asyncio.wait_for(raised.wait(), 2)
|
||||
await self.assertBalance(self.account, '0.0')
|
||||
self.assertFalse(self.ledger.network.is_connected)
|
||||
self.assertIsNone(self.ledger.network.client.transport)
|
||||
|
||||
await self.blockchain.send_to_address(address1, 21)
|
||||
await self.blockchain.generate(1)
|
||||
await self.ledger.network.on_connected.first
|
||||
self.assertTrue(self.ledger.network.is_connected)
|
||||
|
||||
await asyncio.sleep(30, loop=self.loop)
|
||||
self.assertIsNotNone(self.ledger.network.client.transport)
|
||||
await self.assertBalance(self.account, '42.0')
|
||||
|
||||
|
||||
class ServerPickingTestCase(AsyncioTestCase):
|
||||
async def _make_fake_server(self, latency=1.0, port=1337):
|
||||
|
|
|
@ -55,9 +55,25 @@ class ClientSession(BaseClientSession):
|
|||
self.ping_task.cancel()
|
||||
|
||||
|
||||
def protocol_exception_handler(loop: asyncio.AbstractEventLoop, context):
|
||||
exception = context['exception']
|
||||
if 'protocol' not in context or 'transport' not in context:
|
||||
raise exception
|
||||
if not isinstance(context['protocol'], ClientSession):
|
||||
raise exception
|
||||
transport: asyncio.Transport = context['transport']
|
||||
message = context['message']
|
||||
if message not in ("Fatal read error on socket transport", "Fatal write error on socket transport"):
|
||||
raise exception
|
||||
log.warning("Disconnecting after error: %s", str(exception))
|
||||
transport.abort()
|
||||
transport.close()
|
||||
|
||||
|
||||
class BaseNetwork:
|
||||
|
||||
def __init__(self, ledger):
|
||||
asyncio.get_event_loop().set_exception_handler(protocol_exception_handler)
|
||||
self.config = ledger.config
|
||||
self.client: ClientSession = None
|
||||
self.session_pool: SessionPool = None
|
||||
|
|
Loading…
Add table
Reference in a new issue