diff --git a/lbry/tests/integration/test_wallet_server_sessions.py b/lbry/tests/integration/test_wallet_server_sessions.py index 0ba1914cf..21b25030a 100644 --- a/lbry/tests/integration/test_wallet_server_sessions.py +++ b/lbry/tests/integration/test_wallet_server_sessions.py @@ -1,132 +1,30 @@ import asyncio -import socket -import time -import logging -from unittest.mock import Mock -from torba.testcase import IntegrationTestCase, Conductor + +from torba.client.basenetwork import ClientSession +from torba.orchstr8 import Conductor +from torba.testcase import IntegrationTestCase import lbry.wallet -from lbry.schema.claim import Claim -from lbry.wallet.transaction import Transaction, Output -from lbry.wallet.dewies import dewies_to_lbc as d2l, lbc_to_dewies as l2d - - -log = logging.getLogger(__name__) -def wrap_callback_event(fn, callback): - def inner(*a, **kw): - callback() - return fn(*a, **kw) - return inner class TestSessionBloat(IntegrationTestCase): """ - ERROR:asyncio:Fatal read error on socket transport - protocol: - transport: <_SelectorSocketTransport fd=3236 read=polling write=> - Traceback (most recent call last): - File "/usr/lib/python3.7/asyncio/selector_events.py", line 801, in _read_ready__data_received - data = self._sock.recv(self.max_size) - TimeoutError: [Errno 110] Connection timed out + Tests that server cleans up stale connections after session timeout and client times out too. """ LEDGER = lbry.wallet - async def asyncSetUp(self): - self.conductor = Conductor( - ledger_module=self.LEDGER, manager_module=self.MANAGER, verbosity=self.VERBOSITY - ) - await self.conductor.start_blockchain() - self.addCleanup(self.conductor.stop_blockchain) - - await self.conductor.start_spv() - - self.session_manager = self.conductor.spv_node.server.session_mgr - self.session_manager.servers['TCP'].sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 64) - self.session_manager.servers['TCP'].sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 64) - - self.addCleanup(self.conductor.stop_spv) - await self.conductor.start_wallet() - self.addCleanup(self.conductor.stop_wallet) - - self.client_session = list(self.session_manager.sessions)[0] - self.client_session.transport.set_write_buffer_limits(0, 0) - - self.paused_session = asyncio.Event(loop=self.loop) - self.resumed_session = asyncio.Event(loop=self.loop) - - def paused(): - self.resumed_session.clear() - self.paused_session.set() - - def delayed_resume(): - self.paused_session.clear() - - time.sleep(1) - self.resumed_session.set() - - self.client_session.pause_writing = wrap_callback_event(self.client_session.pause_writing, paused) - self.client_session.resume_writing = wrap_callback_event(self.client_session.resume_writing, delayed_resume) - - self.blockchain = self.conductor.blockchain_node - self.wallet_node = self.conductor.wallet_node - self.manager = self.wallet_node.manager - self.ledger = self.wallet_node.ledger - self.wallet = self.wallet_node.wallet - self.account = self.wallet_node.wallet.default_account - async def test_session_bloat_from_socket_timeout(self): - await self.account.ensure_address_gap() - - address1, address2 = await self.account.receiving.get_addresses(limit=2, only_usable=True) - sendtxid1 = await self.blockchain.send_to_address(address1, 5) - sendtxid2 = await self.blockchain.send_to_address(address2, 5) - - await self.blockchain.generate(1) - await asyncio.wait([ - self.on_transaction_id(sendtxid1), - self.on_transaction_id(sendtxid2) - ]) - - self.assertEqual(d2l(await self.account.get_balance()), '10.0') - - channel = Claim() - channel_txo = Output.pay_claim_name_pubkey_hash( - l2d('1.0'), '@bar', channel, self.account.ledger.address_to_hash160(address1) - ) - channel_txo.generate_channel_private_key() - channel_txo.script.generate() - channel_tx = await Transaction.create([], [channel_txo], [self.account], self.account) - - stream = Claim() - stream.stream.description = "0" * 8000 - stream_txo = Output.pay_claim_name_pubkey_hash( - l2d('1.0'), 'foo', stream, self.account.ledger.address_to_hash160(address1) - ) - stream_tx = await Transaction.create([], [stream_txo], [self.account], self.account) - stream_txo.sign(channel_txo) - await stream_tx.sign([self.account]) - self.paused_session.clear() - self.resumed_session.clear() - - await self.broadcast(channel_tx) - await self.broadcast(stream_tx) - await asyncio.wait_for(self.paused_session.wait(), 2) - self.assertEqual(1, len(self.session_manager.sessions)) - - real_sock = self.client_session.transport._extra.pop('socket') - mock_sock = Mock(spec=socket.socket) - - for attr in dir(real_sock): - if not attr.startswith('__'): - setattr(mock_sock, attr, getattr(real_sock, attr)) - - def recv(*a, **kw): - raise TimeoutError("[Errno 110] Connection timed out") - - mock_sock.recv = recv - self.client_session.transport._sock = mock_sock - self.client_session.transport._extra['socket'] = mock_sock - self.assertFalse(self.resumed_session.is_set()) - self.assertFalse(self.session_manager.session_event.is_set()) - await self.session_manager.session_event.wait() - self.assertEqual(0, len(self.session_manager.sessions)) + await self.conductor.stop_spv() + self.conductor.spv_node.session_timeout = 1 + await self.conductor.start_spv() + session = ClientSession(network=None, server=self.ledger.network.client.server, timeout=0.2) + await session.create_connection() + session.ping_task.cancel() + await session.send_request('server.banner', ()) + self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1) + self.assertFalse(session.is_closing()) + await asyncio.sleep(1.1) + with self.assertRaises(asyncio.TimeoutError): + await session.send_request('server.banner', ()) + self.assertTrue(session.is_closing()) + self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 0) diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py index 1f5a122a8..ba014a246 100644 --- a/torba/torba/client/basenetwork.py +++ b/torba/torba/client/basenetwork.py @@ -15,22 +15,26 @@ log = logging.getLogger(__name__) class ClientSession(BaseClientSession): - def __init__(self, *args, network, server, **kwargs): + def __init__(self, *args, network, server, timeout=30, **kwargs): self.network = network self.server = server super().__init__(*args, **kwargs) self._on_disconnect_controller = StreamController() self.on_disconnected = self._on_disconnect_controller.stream self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32 - self.max_seconds_idle = 60 + self.timeout = timeout + self.max_seconds_idle = timeout * 2 self.ping_task = None async def send_request(self, method, args=()): try: - return await super().send_request(method, args) + return await asyncio.wait_for(super().send_request(method, args), timeout=self.timeout) except RPCError as e: log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args) raise e + except asyncio.TimeoutError: + self.abort() + raise async def ping_forever(self): # TODO: change to 'ping' on newer protocol (above 1.2) @@ -209,7 +213,10 @@ class SessionPool: self.ensure_connection(session) for session in self.sessions ], return_exceptions=True) - await asyncio.wait([asyncio.sleep(3), self._lost_master.wait()], return_when='FIRST_COMPLETED') + try: + await asyncio.wait_for(self._lost_master.wait(), timeout=3) + except asyncio.TimeoutError: + pass self._lost_master.clear() if not self.sessions: self.sessions.extend(self._dead_servers) diff --git a/torba/torba/orchstr8/node.py b/torba/torba/orchstr8/node.py index 9e5ecf225..6782ee619 100644 --- a/torba/torba/orchstr8/node.py +++ b/torba/torba/orchstr8/node.py @@ -195,6 +195,7 @@ class SPVNode: self.server = None self.hostname = 'localhost' self.port = 50001 + 1 # avoid conflict with default daemon + self.session_timeout = 600 async def start(self, blockchain_node: 'BlockchainNode'): self.data_path = tempfile.mkdtemp() @@ -204,6 +205,7 @@ class SPVNode: 'REORG_LIMIT': '100', 'HOST': self.hostname, 'TCP_PORT': str(self.port), + 'SESSION_TIMEOUT': str(self.session_timeout), 'MAX_QUERY_WORKERS': '0' } # TODO: don't use os.environ diff --git a/torba/torba/server/daemon.py b/torba/torba/server/daemon.py index 72a87675f..b66430fae 100644 --- a/torba/torba/server/daemon.py +++ b/torba/torba/server/daemon.py @@ -98,6 +98,8 @@ class Daemon: return aiohttp.ClientSession(connector=self.connector, connector_owner=False) async def _send_data(self, data): + if not self.connector: + raise asyncio.CancelledError('Tried to send request during shutdown.') async with self.workqueue_semaphore: async with self.client_session() as session: async with session.post(self.current_url(), data=data) as resp: diff --git a/torba/torba/server/server.py b/torba/torba/server/server.py index 391e1cb2f..691b84055 100644 --- a/torba/torba/server/server.py +++ b/torba/torba/server/server.py @@ -5,38 +5,6 @@ from concurrent.futures.thread import ThreadPoolExecutor import torba from torba.server.mempool import MemPool, MemPoolAPI -from torba.server.session import SessionManager, SessionBase - - -CONNECTION_TIMED_OUT = 110 -NO_ROUTE_TO_HOST = 113 - - -def handle_socket_errors(socket_errors=(CONNECTION_TIMED_OUT, NO_ROUTE_TO_HOST)): - err_msgs = tuple((f"[Errno {err_code}]" for err_code in socket_errors)) - log = logging.getLogger(__name__) - - def protocol_exception_handler(loop, context): - exception = context['exception'] - if 'protocol' not in context or 'transport' not in context: - raise exception - if not isinstance(context['protocol'], SessionBase): - raise exception - session: SessionBase = context['protocol'] - transport: asyncio.Transport = context['transport'] - message = context['message'] - if message not in ("Fatal read error on socket transport", "Fatal write error on socket transport"): - raise exception - - for err_msg in err_msgs: - if str(exception).startswith(err_msg): - log.debug("caught: '%s' for %s", str(exception), session) - transport.abort() - transport.close() - loop.create_task(session.close(force_after=1)) - return - raise exception - return protocol_exception_handler class Notifications: @@ -121,7 +89,6 @@ class Server: ) async def start(self): - asyncio.get_event_loop().set_exception_handler(handle_socket_errors()) env = self.env min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings() self.log.info(f'software version: {torba.__version__}') diff --git a/torba/torba/server/session.py b/torba/torba/server/session.py index e28b82af4..9c5981d22 100644 --- a/torba/torba/server/session.py +++ b/torba/torba/server/session.py @@ -255,9 +255,10 @@ class SessionManager: async def _clear_stale_sessions(self): """Cut off sessions that haven't done anything for 10 minutes.""" + session_timeout = self.env.session_timeout while True: - await sleep(60) - stale_cutoff = time.time() - self.env.session_timeout + await sleep(session_timeout // 10) + stale_cutoff = time.time() - session_timeout stale_sessions = [session for session in self.sessions if session.last_recv < stale_cutoff] if stale_sessions: @@ -267,7 +268,7 @@ class SessionManager: # Give the sockets some time to close gracefully if stale_sessions: await asyncio.wait([ - session.close() for session in stale_sessions + session.close(force_after=session_timeout // 10) for session in stale_sessions ]) # Consolidate small groups