refactor protocol handler + misc

This commit is contained in:
Victor Shyba 2019-07-12 06:38:23 -03:00
parent 489da88e79
commit 0bd3c3f503
4 changed files with 24 additions and 59 deletions

View file

@ -48,10 +48,9 @@ class ReconnectTests(IntegrationTestCase):
# TODO: test rolling over to a working server when an rpc request fails before raising # TODO: test rolling over to a working server when an rpc request fails before raising
self.assertTrue(self.ledger.network.is_connected) self.assertTrue(self.ledger.network.is_connected)
await self.assertBalance(self.account, '0.0')
address1 = await self.account.receiving.get_or_create_usable_address() address1 = await self.account.receiving.get_or_create_usable_address()
txid = await self.blockchain.send_to_address(address1, 21)
await self.blockchain.generate(1)
real_sock = self.ledger.network.client.transport._extra.pop('socket') real_sock = self.ledger.network.client.transport._extra.pop('socket')
mock_sock = Mock(spec=socket.socket) mock_sock = Mock(spec=socket.socket)
@ -70,23 +69,21 @@ class ReconnectTests(IntegrationTestCase):
self.ledger.network.client.transport._sock = mock_sock self.ledger.network.client.transport._sock = mock_sock
self.ledger.network.client.transport._extra['socket'] = mock_sock self.ledger.network.client.transport._extra['socket'] = mock_sock
await self.blockchain.send_to_address(address1, 21)
await self.blockchain.generate(1)
self.assertFalse(raised.is_set()) self.assertFalse(raised.is_set())
with self.assertRaises(asyncio.CancelledError):
await self.ledger.network.get_transaction(txid)
await asyncio.wait_for(raised.wait(), 2) self.assertTrue(raised.is_set())
await self.assertBalance(self.account, '0.0')
self.assertFalse(self.ledger.network.is_connected) self.assertFalse(self.ledger.network.is_connected)
self.assertIsNone(self.ledger.network.client.transport) self.assertIsNone(self.ledger.network.client.transport)
await self.blockchain.send_to_address(address1, 21) txid = await self.blockchain.send_to_address(address1, 2)
await self.blockchain.generate(1) await self.blockchain.generate(1)
await self.ledger.network.on_connected.first await self.ledger.network.on_connected.first
self.assertTrue(self.ledger.network.is_connected) self.assertTrue(self.ledger.network.is_connected)
await asyncio.sleep(30, loop=self.loop)
self.assertIsNotNone(self.ledger.network.client.transport) self.assertIsNotNone(self.ledger.network.client.transport)
await self.assertBalance(self.account, '42.0') self.assertIsNotNone(await self.ledger.network.get_transaction(txid))
class ServerPickingTestCase(AsyncioTestCase): class ServerPickingTestCase(AsyncioTestCase):

View file

@ -8,6 +8,7 @@ import socket
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
from torba import __version__ from torba import __version__
from torba.rpc.util import protocol_exception_handler
from torba.stream import StreamController from torba.stream import StreamController
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -55,21 +56,6 @@ class ClientSession(BaseClientSession):
self.ping_task.cancel() self.ping_task.cancel()
def protocol_exception_handler(loop: asyncio.AbstractEventLoop, context):
exception = context['exception']
if 'protocol' not in context or 'transport' not in context:
raise exception
if not isinstance(context['protocol'], ClientSession):
raise exception
transport: asyncio.Transport = context['transport']
message = context['message']
if message not in ("Fatal read error on socket transport", "Fatal write error on socket transport"):
raise exception
log.warning("Disconnecting after error: %s", str(exception))
transport.abort()
transport.close()
class BaseNetwork: class BaseNetwork:
def __init__(self, ledger): def __init__(self, ledger):
@ -219,7 +205,10 @@ class SessionPool:
self.ensure_connection(session) self.ensure_connection(session)
for session in self.sessions for session in self.sessions
], return_exceptions=True) ], return_exceptions=True)
await asyncio.wait([asyncio.sleep(3), self._lost_master.wait()], return_when='FIRST_COMPLETED') try:
await asyncio.wait_for(self._lost_master.wait(), timeout=3)
except asyncio.TimeoutError:
pass
self._lost_master.clear() self._lost_master.clear()
if not self.sessions: if not self.sessions:
self.sessions.extend(self._dead_servers) self.sessions.extend(self._dead_servers)

View file

@ -93,3 +93,13 @@ class Concurrency(object):
else: else:
for _ in range(-diff): for _ in range(-diff):
await self.semaphore.acquire() await self.semaphore.acquire()
def protocol_exception_handler(loop: asyncio.AbstractEventLoop, context):
message = context['message']
transport = context.get('transport')
if transport and message in ("Fatal read error on socket transport", "Fatal write error on socket transport"):
transport.abort()
transport.close()
else:
return loop.default_exception_handler(context)

View file

@ -4,39 +4,8 @@ import asyncio
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
import torba import torba
from torba.rpc.util import protocol_exception_handler
from torba.server.mempool import MemPool, MemPoolAPI from torba.server.mempool import MemPool, MemPoolAPI
from torba.server.session import SessionManager, SessionBase
CONNECTION_TIMED_OUT = 110
NO_ROUTE_TO_HOST = 113
def handle_socket_errors(socket_errors=(CONNECTION_TIMED_OUT, NO_ROUTE_TO_HOST)):
err_msgs = tuple((f"[Errno {err_code}]" for err_code in socket_errors))
log = logging.getLogger(__name__)
def protocol_exception_handler(loop, context):
exception = context['exception']
if 'protocol' not in context or 'transport' not in context:
raise exception
if not isinstance(context['protocol'], SessionBase):
raise exception
session: SessionBase = context['protocol']
transport: asyncio.Transport = context['transport']
message = context['message']
if message not in ("Fatal read error on socket transport", "Fatal write error on socket transport"):
raise exception
for err_msg in err_msgs:
if str(exception).startswith(err_msg):
log.debug("caught: '%s' for %s", str(exception), session)
transport.abort()
transport.close()
loop.create_task(session.close(force_after=1))
return
raise exception
return protocol_exception_handler
class Notifications: class Notifications:
@ -121,7 +90,7 @@ class Server:
) )
async def start(self): async def start(self):
asyncio.get_event_loop().set_exception_handler(handle_socket_errors()) asyncio.get_event_loop().set_exception_handler(protocol_exception_handler)
env = self.env env = self.env
min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings() min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings()
self.log.info(f'software version: {torba.__version__}') self.log.info(f'software version: {torba.__version__}')