diff --git a/torba/torba/server/server.py b/torba/torba/server/server.py index 384287d20..322e68a30 100644 --- a/torba/torba/server/server.py +++ b/torba/torba/server/server.py @@ -5,7 +5,38 @@ from concurrent.futures.thread import ThreadPoolExecutor import torba from torba.server.mempool import MemPool, MemPoolAPI -from torba.server.session import SessionManager +from torba.server.session import SessionManager, SessionBase + + +CONNECTION_TIMED_OUT = 110 +NO_ROUTE_TO_HOST = 113 + + +def handle_socket_errors(socket_errors=(CONNECTION_TIMED_OUT, NO_ROUTE_TO_HOST)): + err_msgs = tuple((f"[Errno {err_code}]" for err_code in socket_errors)) + log = logging.getLogger(__name__) + + def protocol_exception_handler(loop, context): + exception = context['exception'] + if 'protocol' not in context or 'transport' not in context: + raise exception + if not isinstance(context['protocol'], SessionBase): + raise exception + session: SessionBase = context['protocol'] + transport: asyncio.Transport = context['transport'] + message = context['message'] + if message not in ("Fatal read error on socket transport", "Fatal write error on socket transport"): + raise exception + + for err_msg in err_msgs: + if str(exception).startswith(err_msg): + log.debug("caught: '%s' for %s", str(exception), session) + transport.abort() + transport.close() + loop.create_task(session.close(force_after=1)) + return + raise exception + return protocol_exception_handler class Notifications: @@ -90,6 +121,7 @@ class Server: ) async def start(self): + asyncio.get_event_loop().set_exception_handler(handle_socket_errors()) env = self.env min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings() self.log.info(f'software version: {torba.__version__}')