Merge pull request #2820 from lbryio/disconnect-on-oversized

disconnect from client sending message larger than MAX_RECEIVE
This commit is contained in:
Jack Robison 2020-02-20 16:24:26 -05:00 committed by GitHub
commit 947017e334
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 5 deletions

View file

@ -1,5 +1,6 @@
import logging import logging
import asyncio import asyncio
import json
from time import perf_counter from time import perf_counter
from operator import itemgetter from operator import itemgetter
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
@ -68,7 +69,14 @@ class ClientSession(BaseClientSession):
log.info("timeout sending %s to %s:%i", method, *self.server) log.info("timeout sending %s to %s:%i", method, *self.server)
raise asyncio.TimeoutError raise asyncio.TimeoutError
if done: if done:
return request.result() try:
return request.result()
except ConnectionResetError:
log.error(
"wallet server (%s) reset connection upon our %s request, json of %i args is %i bytes",
self.server[0], method, len(args), len(json.dumps(args))
)
raise
except (RPCError, ProtocolError) as e: except (RPCError, ProtocolError) as e:
log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s", log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
*self.server, *e.args) *self.server, *e.args)

View file

@ -386,9 +386,11 @@ class RPCSession(SessionBase):
while not self.is_closing(): while not self.is_closing():
try: try:
message = await self.framer.receive_message() message = await self.framer.receive_message()
except MemoryError as e: except MemoryError:
self.logger.warning(f'{e!r}') self.logger.warning('received oversized message from %s:%s, dropping connection',
continue self._address[0], self._address[1])
self._close()
return
self.last_recv = time.perf_counter() self.last_recv = time.perf_counter()
self.recv_count += 1 self.recv_count += 1

View file

@ -1,10 +1,10 @@
import asyncio
import os.path import os.path
import tempfile import tempfile
import logging import logging
from binascii import unhexlify from binascii import unhexlify
from urllib.request import urlopen from urllib.request import urlopen
from lbry.error import InsufficientFundsError from lbry.error import InsufficientFundsError
from lbry.extras.daemon.daemon import DEFAULT_PAGE_SIZE from lbry.extras.daemon.daemon import DEFAULT_PAGE_SIZE
@ -73,6 +73,19 @@ class ClaimSearchCommand(ClaimTestCase):
(result['txid'], result['claim_id']) (result['txid'], result['claim_id'])
) )
async def test_disconnect_on_memory_error(self):
claim_ids = [
'0000000000000000000000000000000000000000',
] * 23828
self.assertListEqual([], await self.claim_search(claim_ids=claim_ids))
# 23829 claim ids makes the request just large enough
claim_ids = [
'0000000000000000000000000000000000000000',
] * 23829
with self.assertRaises(ConnectionResetError):
await self.claim_search(claim_ids=claim_ids)
async def test_basic_claim_search(self): async def test_basic_claim_search(self):
await self.create_channel() await self.create_channel()
channel_txo = self.channel['outputs'][0] channel_txo = self.channel['outputs'][0]