disconnect from client sending message larger than MAX_RECEIVE

This commit is contained in:
Jack Robison 2020-02-20 16:08:21 -05:00
parent ad1e9ef086
commit 7fd0d6507f
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
3 changed files with 28 additions and 5 deletions

View file

@ -1,5 +1,6 @@
import logging import logging
import asyncio import asyncio
import json
from time import perf_counter from time import perf_counter
from operator import itemgetter from operator import itemgetter
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
@ -68,7 +69,14 @@ class ClientSession(BaseClientSession):
log.info("timeout sending %s to %s:%i", method, *self.server) log.info("timeout sending %s to %s:%i", method, *self.server)
raise asyncio.TimeoutError raise asyncio.TimeoutError
if done: if done:
return request.result() try:
return request.result()
except ConnectionResetError:
log.error(
"wallet server (%s) reset connection upon our %s request, json of %i args is %i bytes",
self.server[0], method, len(args), len(json.dumps(args))
)
raise
except (RPCError, ProtocolError) as e: except (RPCError, ProtocolError) as e:
log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s", log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
*self.server, *e.args) *self.server, *e.args)

View file

@ -386,9 +386,11 @@ class RPCSession(SessionBase):
while not self.is_closing(): while not self.is_closing():
try: try:
message = await self.framer.receive_message() message = await self.framer.receive_message()
except MemoryError as e: except MemoryError:
self.logger.warning(f'{e!r}') self.logger.warning('received oversized message from %s:%s, dropping connection',
continue self._address[0], self._address[1])
self._close()
return
self.last_recv = time.perf_counter() self.last_recv = time.perf_counter()
self.recv_count += 1 self.recv_count += 1

View file

@ -1,10 +1,10 @@
import asyncio
import os.path import os.path
import tempfile import tempfile
import logging import logging
from binascii import unhexlify from binascii import unhexlify
from urllib.request import urlopen from urllib.request import urlopen
from lbry.error import InsufficientFundsError from lbry.error import InsufficientFundsError
from lbry.extras.daemon.daemon import DEFAULT_PAGE_SIZE from lbry.extras.daemon.daemon import DEFAULT_PAGE_SIZE
@ -73,6 +73,19 @@ class ClaimSearchCommand(ClaimTestCase):
(result['txid'], result['claim_id']) (result['txid'], result['claim_id'])
) )
async def test_disconnect_on_memory_error(self):
claim_ids = [
'0000000000000000000000000000000000000000',
] * 23828
self.assertListEqual([], await self.claim_search(claim_ids=claim_ids))
# 23829 claim ids makes the request just large enough
claim_ids = [
'0000000000000000000000000000000000000000',
] * 23829
with self.assertRaises(ConnectionResetError):
await self.claim_search(claim_ids=claim_ids)
async def test_basic_claim_search(self): async def test_basic_claim_search(self):
await self.create_channel() await self.create_channel()
channel_txo = self.channel['outputs'][0] channel_txo = self.channel['outputs'][0]