Compare commits
4 commits
master
...
torba-reco
Author | SHA1 | Date | |
---|---|---|---|
|
a357cb6b26 | ||
|
5ba2190ca7 | ||
|
0bd3c3f503 | ||
|
489da88e79 |
5 changed files with 61 additions and 36 deletions
|
@ -1,6 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import socket
|
import socket
|
||||||
import time
|
|
||||||
import logging
|
import logging
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
from torba.testcase import IntegrationTestCase, Conductor
|
from torba.testcase import IntegrationTestCase, Conductor
|
||||||
|
@ -61,7 +60,6 @@ class TestSessionBloat(IntegrationTestCase):
|
||||||
def delayed_resume():
|
def delayed_resume():
|
||||||
self.paused_session.clear()
|
self.paused_session.clear()
|
||||||
|
|
||||||
time.sleep(1)
|
|
||||||
self.resumed_session.set()
|
self.resumed_session.set()
|
||||||
|
|
||||||
self.client_session.pause_writing = wrap_callback_event(self.client_session.pause_writing, paused)
|
self.client_session.pause_writing = wrap_callback_event(self.client_session.pause_writing, paused)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import socket
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from torba.client.basenetwork import BaseNetwork
|
from torba.client.basenetwork import BaseNetwork
|
||||||
|
@ -42,6 +43,48 @@ class ReconnectTests(IntegrationTestCase):
|
||||||
await self.ledger.network.on_connected.first
|
await self.ledger.network.on_connected.first
|
||||||
self.assertTrue(self.ledger.network.is_connected)
|
self.assertTrue(self.ledger.network.is_connected)
|
||||||
|
|
||||||
|
async def test_socket_timeout_then_reconnect(self):
|
||||||
|
# TODO: test reconnecting on an rpc request
|
||||||
|
# TODO: test rolling over to a working server when an rpc request fails before raising
|
||||||
|
|
||||||
|
self.assertTrue(self.ledger.network.is_connected)
|
||||||
|
address1 = await self.account.receiving.get_or_create_usable_address()
|
||||||
|
txid = await self.blockchain.send_to_address(address1, 21)
|
||||||
|
await self.blockchain.generate(1)
|
||||||
|
|
||||||
|
real_sock = self.ledger.network.client.transport._extra.pop('socket')
|
||||||
|
mock_sock = Mock(spec=socket.socket)
|
||||||
|
|
||||||
|
for attr in dir(real_sock):
|
||||||
|
if not attr.startswith('__'):
|
||||||
|
setattr(mock_sock, attr, getattr(real_sock, attr))
|
||||||
|
|
||||||
|
raised = asyncio.Event(loop=self.loop)
|
||||||
|
|
||||||
|
def recv(*a, **kw):
|
||||||
|
raised.set()
|
||||||
|
raise TimeoutError("[Errno 60] Operation timed out")
|
||||||
|
|
||||||
|
mock_sock.recv = recv
|
||||||
|
self.ledger.network.client.transport._sock = mock_sock
|
||||||
|
self.ledger.network.client.transport._extra['socket'] = mock_sock
|
||||||
|
|
||||||
|
self.assertFalse(raised.is_set())
|
||||||
|
with self.assertRaises(asyncio.CancelledError):
|
||||||
|
await self.ledger.network.get_transaction(txid)
|
||||||
|
|
||||||
|
self.assertTrue(raised.is_set())
|
||||||
|
self.assertFalse(self.ledger.network.is_connected)
|
||||||
|
self.assertIsNone(self.ledger.network.client.transport)
|
||||||
|
|
||||||
|
txid = await self.blockchain.send_to_address(address1, 2)
|
||||||
|
await self.blockchain.generate(1)
|
||||||
|
await self.ledger.network.on_connected.first
|
||||||
|
self.assertTrue(self.ledger.network.is_connected)
|
||||||
|
|
||||||
|
self.assertIsNotNone(self.ledger.network.client.transport)
|
||||||
|
self.assertIsNotNone(await self.ledger.network.get_transaction(txid))
|
||||||
|
|
||||||
|
|
||||||
class ServerPickingTestCase(AsyncioTestCase):
|
class ServerPickingTestCase(AsyncioTestCase):
|
||||||
async def _make_fake_server(self, latency=1.0, port=1337):
|
async def _make_fake_server(self, latency=1.0, port=1337):
|
||||||
|
|
|
@ -8,6 +8,7 @@ import socket
|
||||||
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
|
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
|
||||||
|
|
||||||
from torba import __version__
|
from torba import __version__
|
||||||
|
from torba.rpc.util import protocol_exception_handler
|
||||||
from torba.stream import StreamController
|
from torba.stream import StreamController
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -58,6 +59,7 @@ class ClientSession(BaseClientSession):
|
||||||
class BaseNetwork:
|
class BaseNetwork:
|
||||||
|
|
||||||
def __init__(self, ledger):
|
def __init__(self, ledger):
|
||||||
|
asyncio.get_event_loop().set_exception_handler(protocol_exception_handler)
|
||||||
self.config = ledger.config
|
self.config = ledger.config
|
||||||
self.client: ClientSession = None
|
self.client: ClientSession = None
|
||||||
self.session_pool: SessionPool = None
|
self.session_pool: SessionPool = None
|
||||||
|
@ -203,7 +205,10 @@ class SessionPool:
|
||||||
self.ensure_connection(session)
|
self.ensure_connection(session)
|
||||||
for session in self.sessions
|
for session in self.sessions
|
||||||
], return_exceptions=True)
|
], return_exceptions=True)
|
||||||
await asyncio.wait([asyncio.sleep(3), self._lost_master.wait()], return_when='FIRST_COMPLETED')
|
try:
|
||||||
|
await asyncio.wait_for(self._lost_master.wait(), timeout=3)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
self._lost_master.clear()
|
self._lost_master.clear()
|
||||||
if not self.sessions:
|
if not self.sessions:
|
||||||
self.sessions.extend(self._dead_servers)
|
self.sessions.extend(self._dead_servers)
|
||||||
|
|
|
@ -93,3 +93,13 @@ class Concurrency(object):
|
||||||
else:
|
else:
|
||||||
for _ in range(-diff):
|
for _ in range(-diff):
|
||||||
await self.semaphore.acquire()
|
await self.semaphore.acquire()
|
||||||
|
|
||||||
|
|
||||||
|
def protocol_exception_handler(loop: asyncio.AbstractEventLoop, context):
|
||||||
|
message = context['message']
|
||||||
|
transport = context.get('transport')
|
||||||
|
if transport and message in ("Fatal read error on socket transport", "Fatal write error on socket transport"):
|
||||||
|
transport.abort()
|
||||||
|
transport.close()
|
||||||
|
else:
|
||||||
|
return loop.default_exception_handler(context)
|
||||||
|
|
|
@ -4,39 +4,8 @@ import asyncio
|
||||||
from concurrent.futures.thread import ThreadPoolExecutor
|
from concurrent.futures.thread import ThreadPoolExecutor
|
||||||
|
|
||||||
import torba
|
import torba
|
||||||
|
from torba.rpc.util import protocol_exception_handler
|
||||||
from torba.server.mempool import MemPool, MemPoolAPI
|
from torba.server.mempool import MemPool, MemPoolAPI
|
||||||
from torba.server.session import SessionManager, SessionBase
|
|
||||||
|
|
||||||
|
|
||||||
CONNECTION_TIMED_OUT = 110
|
|
||||||
NO_ROUTE_TO_HOST = 113
|
|
||||||
|
|
||||||
|
|
||||||
def handle_socket_errors(socket_errors=(CONNECTION_TIMED_OUT, NO_ROUTE_TO_HOST)):
|
|
||||||
err_msgs = tuple((f"[Errno {err_code}]" for err_code in socket_errors))
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def protocol_exception_handler(loop, context):
|
|
||||||
exception = context['exception']
|
|
||||||
if 'protocol' not in context or 'transport' not in context:
|
|
||||||
raise exception
|
|
||||||
if not isinstance(context['protocol'], SessionBase):
|
|
||||||
raise exception
|
|
||||||
session: SessionBase = context['protocol']
|
|
||||||
transport: asyncio.Transport = context['transport']
|
|
||||||
message = context['message']
|
|
||||||
if message not in ("Fatal read error on socket transport", "Fatal write error on socket transport"):
|
|
||||||
raise exception
|
|
||||||
|
|
||||||
for err_msg in err_msgs:
|
|
||||||
if str(exception).startswith(err_msg):
|
|
||||||
log.debug("caught: '%s' for %s", str(exception), session)
|
|
||||||
transport.abort()
|
|
||||||
transport.close()
|
|
||||||
loop.create_task(session.close(force_after=1))
|
|
||||||
return
|
|
||||||
raise exception
|
|
||||||
return protocol_exception_handler
|
|
||||||
|
|
||||||
|
|
||||||
class Notifications:
|
class Notifications:
|
||||||
|
@ -121,7 +90,7 @@ class Server:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
asyncio.get_event_loop().set_exception_handler(handle_socket_errors())
|
asyncio.get_event_loop().set_exception_handler(protocol_exception_handler)
|
||||||
env = self.env
|
env = self.env
|
||||||
min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings()
|
min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings()
|
||||||
self.log.info(f'software version: {torba.__version__}')
|
self.log.info(f'software version: {torba.__version__}')
|
||||||
|
|
Loading…
Add table
Reference in a new issue