Compare commits

...
Sign in to create a new pull request.

4 commits

Author SHA1 Message Date
Victor Shyba
a357cb6b26 unused import 2019-07-12 06:59:05 -03:00
Victor Shyba
5ba2190ca7 remove sleep from test server sessions 2019-07-12 06:58:43 -03:00
Victor Shyba
0bd3c3f503 refactor protocol handler + misc 2019-07-12 06:38:23 -03:00
Jack Robison
489da88e79 reconnect torba ClientSession on fatal socket errors 2019-07-12 00:29:17 -03:00
5 changed files with 61 additions and 36 deletions

View file

@ -1,6 +1,5 @@
import asyncio import asyncio
import socket import socket
import time
import logging import logging
from unittest.mock import Mock from unittest.mock import Mock
from torba.testcase import IntegrationTestCase, Conductor from torba.testcase import IntegrationTestCase, Conductor
@ -61,7 +60,6 @@ class TestSessionBloat(IntegrationTestCase):
def delayed_resume(): def delayed_resume():
self.paused_session.clear() self.paused_session.clear()
time.sleep(1)
self.resumed_session.set() self.resumed_session.set()
self.client_session.pause_writing = wrap_callback_event(self.client_session.pause_writing, paused) self.client_session.pause_writing = wrap_callback_event(self.client_session.pause_writing, paused)

View file

@ -1,5 +1,6 @@
import logging import logging
import asyncio import asyncio
import socket
from unittest.mock import Mock from unittest.mock import Mock
from torba.client.basenetwork import BaseNetwork from torba.client.basenetwork import BaseNetwork
@ -42,6 +43,48 @@ class ReconnectTests(IntegrationTestCase):
await self.ledger.network.on_connected.first await self.ledger.network.on_connected.first
self.assertTrue(self.ledger.network.is_connected) self.assertTrue(self.ledger.network.is_connected)
async def test_socket_timeout_then_reconnect(self):
# TODO: test reconnecting on an rpc request
# TODO: test rolling over to a working server when an rpc request fails before raising
self.assertTrue(self.ledger.network.is_connected)
address1 = await self.account.receiving.get_or_create_usable_address()
txid = await self.blockchain.send_to_address(address1, 21)
await self.blockchain.generate(1)
real_sock = self.ledger.network.client.transport._extra.pop('socket')
mock_sock = Mock(spec=socket.socket)
for attr in dir(real_sock):
if not attr.startswith('__'):
setattr(mock_sock, attr, getattr(real_sock, attr))
raised = asyncio.Event(loop=self.loop)
def recv(*a, **kw):
raised.set()
raise TimeoutError("[Errno 60] Operation timed out")
mock_sock.recv = recv
self.ledger.network.client.transport._sock = mock_sock
self.ledger.network.client.transport._extra['socket'] = mock_sock
self.assertFalse(raised.is_set())
with self.assertRaises(asyncio.CancelledError):
await self.ledger.network.get_transaction(txid)
self.assertTrue(raised.is_set())
self.assertFalse(self.ledger.network.is_connected)
self.assertIsNone(self.ledger.network.client.transport)
txid = await self.blockchain.send_to_address(address1, 2)
await self.blockchain.generate(1)
await self.ledger.network.on_connected.first
self.assertTrue(self.ledger.network.is_connected)
self.assertIsNotNone(self.ledger.network.client.transport)
self.assertIsNotNone(await self.ledger.network.get_transaction(txid))
class ServerPickingTestCase(AsyncioTestCase): class ServerPickingTestCase(AsyncioTestCase):
async def _make_fake_server(self, latency=1.0, port=1337): async def _make_fake_server(self, latency=1.0, port=1337):

View file

@ -8,6 +8,7 @@ import socket
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
from torba import __version__ from torba import __version__
from torba.rpc.util import protocol_exception_handler
from torba.stream import StreamController from torba.stream import StreamController
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -58,6 +59,7 @@ class ClientSession(BaseClientSession):
class BaseNetwork: class BaseNetwork:
def __init__(self, ledger): def __init__(self, ledger):
asyncio.get_event_loop().set_exception_handler(protocol_exception_handler)
self.config = ledger.config self.config = ledger.config
self.client: ClientSession = None self.client: ClientSession = None
self.session_pool: SessionPool = None self.session_pool: SessionPool = None
@ -203,7 +205,10 @@ class SessionPool:
self.ensure_connection(session) self.ensure_connection(session)
for session in self.sessions for session in self.sessions
], return_exceptions=True) ], return_exceptions=True)
await asyncio.wait([asyncio.sleep(3), self._lost_master.wait()], return_when='FIRST_COMPLETED') try:
await asyncio.wait_for(self._lost_master.wait(), timeout=3)
except asyncio.TimeoutError:
pass
self._lost_master.clear() self._lost_master.clear()
if not self.sessions: if not self.sessions:
self.sessions.extend(self._dead_servers) self.sessions.extend(self._dead_servers)

View file

@ -93,3 +93,13 @@ class Concurrency(object):
else: else:
for _ in range(-diff): for _ in range(-diff):
await self.semaphore.acquire() await self.semaphore.acquire()
def protocol_exception_handler(loop: asyncio.AbstractEventLoop, context):
message = context['message']
transport = context.get('transport')
if transport and message in ("Fatal read error on socket transport", "Fatal write error on socket transport"):
transport.abort()
transport.close()
else:
return loop.default_exception_handler(context)

View file

@ -4,39 +4,8 @@ import asyncio
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
import torba import torba
from torba.rpc.util import protocol_exception_handler
from torba.server.mempool import MemPool, MemPoolAPI from torba.server.mempool import MemPool, MemPoolAPI
from torba.server.session import SessionManager, SessionBase
CONNECTION_TIMED_OUT = 110
NO_ROUTE_TO_HOST = 113
def handle_socket_errors(socket_errors=(CONNECTION_TIMED_OUT, NO_ROUTE_TO_HOST)):
err_msgs = tuple((f"[Errno {err_code}]" for err_code in socket_errors))
log = logging.getLogger(__name__)
def protocol_exception_handler(loop, context):
exception = context['exception']
if 'protocol' not in context or 'transport' not in context:
raise exception
if not isinstance(context['protocol'], SessionBase):
raise exception
session: SessionBase = context['protocol']
transport: asyncio.Transport = context['transport']
message = context['message']
if message not in ("Fatal read error on socket transport", "Fatal write error on socket transport"):
raise exception
for err_msg in err_msgs:
if str(exception).startswith(err_msg):
log.debug("caught: '%s' for %s", str(exception), session)
transport.abort()
transport.close()
loop.create_task(session.close(force_after=1))
return
raise exception
return protocol_exception_handler
class Notifications: class Notifications:
@ -121,7 +90,7 @@ class Server:
) )
async def start(self): async def start(self):
asyncio.get_event_loop().set_exception_handler(handle_socket_errors()) asyncio.get_event_loop().set_exception_handler(protocol_exception_handler)
env = self.env env = self.env
min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings() min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings()
self.log.info(f'software version: {torba.__version__}') self.log.info(f'software version: {torba.__version__}')