diff --git a/lbry/wallet/network.py b/lbry/wallet/network.py index 8ddd2f326..4e821499d 100644 --- a/lbry/wallet/network.py +++ b/lbry/wallet/network.py @@ -1,5 +1,6 @@ import logging import asyncio +import json from time import perf_counter from operator import itemgetter from typing import Dict, Optional, Tuple @@ -68,7 +69,14 @@ class ClientSession(BaseClientSession): log.info("timeout sending %s to %s:%i", method, *self.server) raise asyncio.TimeoutError if done: - return request.result() + try: + return request.result() + except ConnectionResetError: + log.error( + "wallet server (%s) reset connection upon our %s request, json of %i args is %i bytes", + self.server[0], method, len(args), len(json.dumps(args)) + ) + raise except (RPCError, ProtocolError) as e: log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s", *self.server, *e.args) diff --git a/lbry/wallet/rpc/session.py b/lbry/wallet/rpc/session.py index 425f264b1..e9c4c6925 100644 --- a/lbry/wallet/rpc/session.py +++ b/lbry/wallet/rpc/session.py @@ -386,9 +386,11 @@ class RPCSession(SessionBase): while not self.is_closing(): try: message = await self.framer.receive_message() - except MemoryError as e: - self.logger.warning(f'{e!r}') - continue + except MemoryError: + self.logger.warning('received oversized message from %s:%s, dropping connection', + self._address[0], self._address[1]) + self._close() + return self.last_recv = time.perf_counter() self.recv_count += 1 diff --git a/tests/integration/blockchain/test_claim_commands.py b/tests/integration/blockchain/test_claim_commands.py index 9c6afb2d7..5c6e5bebd 100644 --- a/tests/integration/blockchain/test_claim_commands.py +++ b/tests/integration/blockchain/test_claim_commands.py @@ -1,10 +1,10 @@ +import asyncio import os.path import tempfile import logging from binascii import unhexlify from urllib.request import urlopen - from lbry.error import InsufficientFundsError from lbry.extras.daemon.daemon import DEFAULT_PAGE_SIZE @@ -73,6 +73,19 @@ class ClaimSearchCommand(ClaimTestCase): (result['txid'], result['claim_id']) ) + async def test_disconnect_on_memory_error(self): + claim_ids = [ + '0000000000000000000000000000000000000000', + ] * 23828 + self.assertListEqual([], await self.claim_search(claim_ids=claim_ids)) + + # 23829 claim ids makes the request just large enough + claim_ids = [ + '0000000000000000000000000000000000000000', + ] * 23829 + with self.assertRaises(ConnectionResetError): + await self.claim_search(claim_ids=claim_ids) + async def test_basic_claim_search(self): await self.create_channel() channel_txo = self.channel['outputs'][0]