disconnect from client sending message larger than MAX_RECEIVE

This commit is contained in:
Jack Robison 2020-02-20 16:08:21 -05:00
parent ad1e9ef086
commit 7fd0d6507f
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
3 changed files with 28 additions and 5 deletions

View file

@ -1,5 +1,6 @@
import logging
import asyncio
import json
from time import perf_counter
from operator import itemgetter
from typing import Dict, Optional, Tuple
@ -68,7 +69,14 @@ class ClientSession(BaseClientSession):
log.info("timeout sending %s to %s:%i", method, *self.server)
raise asyncio.TimeoutError
if done:
return request.result()
try:
return request.result()
except ConnectionResetError:
log.error(
"wallet server (%s) reset connection upon our %s request, json of %i args is %i bytes",
self.server[0], method, len(args), len(json.dumps(args))
)
raise
except (RPCError, ProtocolError) as e:
log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
*self.server, *e.args)

View file

@ -386,9 +386,11 @@ class RPCSession(SessionBase):
while not self.is_closing():
try:
message = await self.framer.receive_message()
except MemoryError as e:
self.logger.warning(f'{e!r}')
continue
except MemoryError:
self.logger.warning('received oversized message from %s:%s, dropping connection',
self._address[0], self._address[1])
self._close()
return
self.last_recv = time.perf_counter()
self.recv_count += 1

View file

@ -1,10 +1,10 @@
import asyncio
import os.path
import tempfile
import logging
from binascii import unhexlify
from urllib.request import urlopen
from lbry.error import InsufficientFundsError
from lbry.extras.daemon.daemon import DEFAULT_PAGE_SIZE
@ -73,6 +73,19 @@ class ClaimSearchCommand(ClaimTestCase):
(result['txid'], result['claim_id'])
)
async def test_disconnect_on_memory_error(self):
claim_ids = [
'0000000000000000000000000000000000000000',
] * 23828
self.assertListEqual([], await self.claim_search(claim_ids=claim_ids))
# 23829 claim ids makes the request just large enough
claim_ids = [
'0000000000000000000000000000000000000000',
] * 23829
with self.assertRaises(ConnectionResetError):
await self.claim_search(claim_ids=claim_ids)
async def test_basic_claim_search(self):
await self.create_channel()
channel_txo = self.channel['outputs'][0]