forked from LBRYCommunity/lbry-sdk
Merge pull request #2273 from lbryio/fix-session-bloat
Fix wallet server session bloat from unhandled socket errors
This commit is contained in:
commit
abf2ca40a2
8 changed files with 236 additions and 37 deletions
132
lbry/tests/integration/test_wallet_server_sessions.py
Normal file
132
lbry/tests/integration/test_wallet_server_sessions.py
Normal file
|
@ -0,0 +1,132 @@
|
|||
import asyncio
|
||||
import socket
|
||||
import time
|
||||
import logging
|
||||
from unittest.mock import Mock
|
||||
from torba.testcase import IntegrationTestCase, Conductor
|
||||
import lbry.wallet
|
||||
from lbry.schema.claim import Claim
|
||||
from lbry.wallet.transaction import Transaction, Output
|
||||
from lbry.wallet.dewies import dewies_to_lbc as d2l, lbc_to_dewies as l2d
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
def wrap_callback_event(fn, callback):
|
||||
def inner(*a, **kw):
|
||||
callback()
|
||||
return fn(*a, **kw)
|
||||
return inner
|
||||
|
||||
|
||||
class TestSessionBloat(IntegrationTestCase):
|
||||
"""
|
||||
ERROR:asyncio:Fatal read error on socket transport
|
||||
protocol: <lbrynet.wallet.server.session.LBRYElectrumX object at 0x7f7e3bfcaf60>
|
||||
transport: <_SelectorSocketTransport fd=3236 read=polling write=<idle, bufsize=0>>
|
||||
Traceback (most recent call last):
|
||||
File "/usr/lib/python3.7/asyncio/selector_events.py", line 801, in _read_ready__data_received
|
||||
data = self._sock.recv(self.max_size)
|
||||
TimeoutError: [Errno 110] Connection timed out
|
||||
"""
|
||||
|
||||
LEDGER = lbry.wallet
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.conductor = Conductor(
|
||||
ledger_module=self.LEDGER, manager_module=self.MANAGER, verbosity=self.VERBOSITY
|
||||
)
|
||||
await self.conductor.start_blockchain()
|
||||
self.addCleanup(self.conductor.stop_blockchain)
|
||||
|
||||
await self.conductor.start_spv()
|
||||
|
||||
self.session_manager = self.conductor.spv_node.server.session_mgr
|
||||
self.session_manager.servers['TCP'].sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 64)
|
||||
self.session_manager.servers['TCP'].sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 64)
|
||||
|
||||
self.addCleanup(self.conductor.stop_spv)
|
||||
await self.conductor.start_wallet()
|
||||
self.addCleanup(self.conductor.stop_wallet)
|
||||
|
||||
self.client_session = list(self.session_manager.sessions)[0]
|
||||
self.client_session.transport.set_write_buffer_limits(0, 0)
|
||||
|
||||
self.paused_session = asyncio.Event(loop=self.loop)
|
||||
self.resumed_session = asyncio.Event(loop=self.loop)
|
||||
|
||||
def paused():
|
||||
self.resumed_session.clear()
|
||||
self.paused_session.set()
|
||||
|
||||
def delayed_resume():
|
||||
self.paused_session.clear()
|
||||
|
||||
time.sleep(1)
|
||||
self.resumed_session.set()
|
||||
|
||||
self.client_session.pause_writing = wrap_callback_event(self.client_session.pause_writing, paused)
|
||||
self.client_session.resume_writing = wrap_callback_event(self.client_session.resume_writing, delayed_resume)
|
||||
|
||||
self.blockchain = self.conductor.blockchain_node
|
||||
self.wallet_node = self.conductor.wallet_node
|
||||
self.manager = self.wallet_node.manager
|
||||
self.ledger = self.wallet_node.ledger
|
||||
self.wallet = self.wallet_node.wallet
|
||||
self.account = self.wallet_node.wallet.default_account
|
||||
|
||||
async def test_session_bloat_from_socket_timeout(self):
|
||||
await self.account.ensure_address_gap()
|
||||
|
||||
address1, address2 = await self.account.receiving.get_addresses(limit=2, only_usable=True)
|
||||
sendtxid1 = await self.blockchain.send_to_address(address1, 5)
|
||||
sendtxid2 = await self.blockchain.send_to_address(address2, 5)
|
||||
|
||||
await self.blockchain.generate(1)
|
||||
await asyncio.wait([
|
||||
self.on_transaction_id(sendtxid1),
|
||||
self.on_transaction_id(sendtxid2)
|
||||
])
|
||||
|
||||
self.assertEqual(d2l(await self.account.get_balance()), '10.0')
|
||||
|
||||
channel = Claim()
|
||||
channel_txo = Output.pay_claim_name_pubkey_hash(
|
||||
l2d('1.0'), '@bar', channel, self.account.ledger.address_to_hash160(address1)
|
||||
)
|
||||
channel_txo.generate_channel_private_key()
|
||||
channel_txo.script.generate()
|
||||
channel_tx = await Transaction.create([], [channel_txo], [self.account], self.account)
|
||||
|
||||
stream = Claim()
|
||||
stream.stream.description = "0" * 8000
|
||||
stream_txo = Output.pay_claim_name_pubkey_hash(
|
||||
l2d('1.0'), 'foo', stream, self.account.ledger.address_to_hash160(address1)
|
||||
)
|
||||
stream_tx = await Transaction.create([], [stream_txo], [self.account], self.account)
|
||||
stream_txo.sign(channel_txo)
|
||||
await stream_tx.sign([self.account])
|
||||
self.paused_session.clear()
|
||||
self.resumed_session.clear()
|
||||
|
||||
await self.broadcast(channel_tx)
|
||||
await self.broadcast(stream_tx)
|
||||
await asyncio.wait_for(self.paused_session.wait(), 2)
|
||||
self.assertEqual(1, len(self.session_manager.sessions))
|
||||
|
||||
real_sock = self.client_session.transport._extra.pop('socket')
|
||||
mock_sock = Mock(spec=socket.socket)
|
||||
|
||||
for attr in dir(real_sock):
|
||||
if not attr.startswith('__'):
|
||||
setattr(mock_sock, attr, getattr(real_sock, attr))
|
||||
|
||||
def recv(*a, **kw):
|
||||
raise TimeoutError("[Errno 110] Connection timed out")
|
||||
|
||||
mock_sock.recv = recv
|
||||
self.client_session.transport._sock = mock_sock
|
||||
self.client_session.transport._extra['socket'] = mock_sock
|
||||
self.assertFalse(self.resumed_session.is_set())
|
||||
self.assertFalse(self.session_manager.session_event.is_set())
|
||||
await self.session_manager.session_event.wait()
|
||||
self.assertEqual(0, len(self.session_manager.sessions))
|
|
@ -32,6 +32,7 @@ __all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
|
|||
|
||||
import itertools
|
||||
import json
|
||||
import typing
|
||||
from functools import partial
|
||||
from numbers import Number
|
||||
|
||||
|
@ -596,7 +597,7 @@ class JSONRPCConnection(object):
|
|||
# Sent Requests and Batches that have not received a response.
|
||||
# The key is its request ID; for a batch it is sorted tuple
|
||||
# of request IDs
|
||||
self._requests = {}
|
||||
self._requests: typing.Dict[str, typing.Tuple[Request, Event]] = {}
|
||||
# A public attribute intended to be settable dynamically
|
||||
self.max_response_size = 0
|
||||
|
||||
|
@ -683,7 +684,7 @@ class JSONRPCConnection(object):
|
|||
#
|
||||
# External API
|
||||
#
|
||||
def send_request(self, request):
|
||||
def send_request(self, request: Request) -> typing.Tuple[bytes, Event]:
|
||||
"""Send a Request. Return a (message, event) pair.
|
||||
|
||||
The message is an unframed message to send over the network.
|
||||
|
|
|
@ -55,6 +55,7 @@ class Daemon:
|
|||
self.max_retry = max_retry
|
||||
self._height = None
|
||||
self.available_rpcs = {}
|
||||
self.connector = aiohttp.TCPConnector()
|
||||
|
||||
def set_url(self, url):
|
||||
"""Set the URLS to the given list, and switch to the first one."""
|
||||
|
@ -89,7 +90,7 @@ class Daemon:
|
|||
|
||||
def client_session(self):
|
||||
"""An aiohttp client session."""
|
||||
return aiohttp.ClientSession()
|
||||
return aiohttp.ClientSession(connector=self.connector, connector_owner=False)
|
||||
|
||||
async def _send_data(self, data):
|
||||
async with self.workqueue_semaphore:
|
||||
|
|
|
@ -33,7 +33,7 @@ class Env:
|
|||
self.allow_root = self.boolean('ALLOW_ROOT', False)
|
||||
self.host = self.default('HOST', 'localhost')
|
||||
self.rpc_host = self.default('RPC_HOST', 'localhost')
|
||||
self.loop_policy = self.event_loop_policy()
|
||||
self.loop_policy = self.set_event_loop_policy()
|
||||
self.obsolete(['UTXO_MB', 'HIST_MB', 'NETWORK'])
|
||||
self.db_dir = self.required('DB_DIRECTORY')
|
||||
self.db_engine = self.default('DB_ENGINE', 'leveldb')
|
||||
|
@ -129,14 +129,18 @@ class Env:
|
|||
raise cls.Error('remove obsolete environment variables {}'
|
||||
.format(bad))
|
||||
|
||||
def event_loop_policy(self):
|
||||
policy = self.default('EVENT_LOOP_POLICY', None)
|
||||
if policy is None:
|
||||
return None
|
||||
if policy == 'uvloop':
|
||||
def set_event_loop_policy(self):
|
||||
policy_name = self.default('EVENT_LOOP_POLICY', None)
|
||||
if not policy_name:
|
||||
import asyncio
|
||||
return asyncio.get_event_loop_policy()
|
||||
elif policy_name == 'uvloop':
|
||||
import uvloop
|
||||
return uvloop.EventLoopPolicy()
|
||||
raise self.Error('unknown event loop policy "{}"'.format(policy))
|
||||
import asyncio
|
||||
loop_policy = uvloop.EventLoopPolicy()
|
||||
asyncio.set_event_loop_policy(loop_policy)
|
||||
return loop_policy
|
||||
raise self.Error('unknown event loop policy "{}"'.format(policy_name))
|
||||
|
||||
def cs_host(self, *, for_rpc):
|
||||
"""Returns the 'host' argument to pass to asyncio's create_server
|
||||
|
|
|
@ -70,6 +70,7 @@ class Peer:
|
|||
# Transient, non-persisted metadata
|
||||
self.bad = False
|
||||
self.other_port_pairs = set()
|
||||
self.status = 2
|
||||
|
||||
@classmethod
|
||||
def peers_from_features(cls, features, source):
|
||||
|
|
|
@ -12,6 +12,7 @@ import random
|
|||
import socket
|
||||
import ssl
|
||||
import time
|
||||
import typing
|
||||
from asyncio import Event, sleep
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
|
@ -72,7 +73,7 @@ class PeerManager:
|
|||
# ip_addr property is either None, an onion peer, or the
|
||||
# IP address that was connected to. Adding a peer will evict
|
||||
# any other peers with the same host name or IP address.
|
||||
self.peers = set()
|
||||
self.peers: typing.Set[Peer] = set()
|
||||
self.permit_onion_peer_time = time.time()
|
||||
self.proxy = None
|
||||
self.group = TaskGroup()
|
||||
|
@ -394,7 +395,7 @@ class PeerManager:
|
|||
self.group.add(self._detect_proxy())
|
||||
self.group.add(self._import_peers())
|
||||
|
||||
def info(self):
|
||||
def info(self) -> typing.Dict[str, int]:
|
||||
"""The number of peers."""
|
||||
self._set_peer_statuses()
|
||||
counter = Counter(peer.status for peer in self.peers)
|
||||
|
|
|
@ -5,7 +5,38 @@ from concurrent.futures.thread import ThreadPoolExecutor
|
|||
|
||||
import torba
|
||||
from torba.server.mempool import MemPool, MemPoolAPI
|
||||
from torba.server.session import SessionManager
|
||||
from torba.server.session import SessionManager, SessionBase
|
||||
|
||||
|
||||
CONNECTION_TIMED_OUT = 110
|
||||
NO_ROUTE_TO_HOST = 113
|
||||
|
||||
|
||||
def handle_socket_errors(socket_errors=(CONNECTION_TIMED_OUT, NO_ROUTE_TO_HOST)):
|
||||
err_msgs = tuple((f"[Errno {err_code}]" for err_code in socket_errors))
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
def protocol_exception_handler(loop, context):
|
||||
exception = context['exception']
|
||||
if 'protocol' not in context or 'transport' not in context:
|
||||
raise exception
|
||||
if not isinstance(context['protocol'], SessionBase):
|
||||
raise exception
|
||||
session: SessionBase = context['protocol']
|
||||
transport: asyncio.Transport = context['transport']
|
||||
message = context['message']
|
||||
if message not in ("Fatal read error on socket transport", "Fatal write error on socket transport"):
|
||||
raise exception
|
||||
|
||||
for err_msg in err_msgs:
|
||||
if str(exception).startswith(err_msg):
|
||||
log.debug("caught: '%s' for %s", str(exception), session)
|
||||
transport.abort()
|
||||
transport.close()
|
||||
loop.create_task(session.close(force_after=1))
|
||||
return
|
||||
raise exception
|
||||
return protocol_exception_handler
|
||||
|
||||
|
||||
class Notifications:
|
||||
|
@ -90,6 +121,7 @@ class Server:
|
|||
)
|
||||
|
||||
async def start(self):
|
||||
asyncio.get_event_loop().set_exception_handler(handle_socket_errors())
|
||||
env = self.env
|
||||
min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings()
|
||||
self.log.info(f'software version: {torba.__version__}')
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
# and warranty status of this software.
|
||||
|
||||
"""Classes for local RPC server and remote client TCP/SSL servers."""
|
||||
|
||||
import collections
|
||||
import asyncio
|
||||
import codecs
|
||||
import datetime
|
||||
|
@ -16,6 +16,7 @@ import os
|
|||
import pylru
|
||||
import ssl
|
||||
import time
|
||||
import typing
|
||||
from asyncio import Event, sleep
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
|
@ -31,13 +32,18 @@ from torba.server.hash import (sha256, hash_to_hex_str, hex_str_to_hash,
|
|||
HASHX_LEN, Base58Error)
|
||||
from torba.server.daemon import DaemonError
|
||||
from torba.server.peers import PeerManager
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from torba.server.env import Env
|
||||
from torba.server.db import DB
|
||||
from torba.server.block_processor import BlockProcessor
|
||||
from torba.server.mempool import MemPool
|
||||
from torba.server.daemon import Daemon
|
||||
|
||||
BAD_REQUEST = 1
|
||||
DAEMON_ERROR = 2
|
||||
|
||||
|
||||
def scripthash_to_hashX(scripthash):
|
||||
def scripthash_to_hashX(scripthash: str) -> bytes:
|
||||
try:
|
||||
bin_hash = hex_str_to_hash(scripthash)
|
||||
if len(bin_hash) == 32:
|
||||
|
@ -47,7 +53,7 @@ def scripthash_to_hashX(scripthash):
|
|||
raise RPCError(BAD_REQUEST, f'{scripthash} is not a valid script hash')
|
||||
|
||||
|
||||
def non_negative_integer(value):
|
||||
def non_negative_integer(value) -> int:
|
||||
"""Return param value it is or can be converted to a non-negative
|
||||
integer, otherwise raise an RPCError."""
|
||||
try:
|
||||
|
@ -60,14 +66,14 @@ def non_negative_integer(value):
|
|||
f'{value} should be a non-negative integer')
|
||||
|
||||
|
||||
def assert_boolean(value):
|
||||
def assert_boolean(value) -> bool:
|
||||
"""Return param value it is boolean otherwise raise an RPCError."""
|
||||
if value in (False, True):
|
||||
return value
|
||||
raise RPCError(BAD_REQUEST, f'{value} should be a boolean value')
|
||||
|
||||
|
||||
def assert_tx_hash(value):
|
||||
def assert_tx_hash(value: str) -> None:
|
||||
"""Raise an RPCError if the value is not a valid transaction
|
||||
hash."""
|
||||
try:
|
||||
|
@ -97,7 +103,7 @@ class Semaphores:
|
|||
|
||||
class SessionGroup:
|
||||
|
||||
def __init__(self, gid):
|
||||
def __init__(self, gid: int):
|
||||
self.gid = gid
|
||||
# Concurrency per group
|
||||
self.semaphore = asyncio.Semaphore(20)
|
||||
|
@ -106,7 +112,8 @@ class SessionGroup:
|
|||
class SessionManager:
|
||||
"""Holds global state about all sessions."""
|
||||
|
||||
def __init__(self, env, db, bp, daemon, mempool, shutdown_event):
|
||||
def __init__(self, env: 'Env', db: 'DB', bp: 'BlockProcessor', daemon: 'Daemon', mempool: 'MemPool',
|
||||
shutdown_event: asyncio.Event):
|
||||
env.max_send = max(350000, env.max_send)
|
||||
self.env = env
|
||||
self.db = db
|
||||
|
@ -116,28 +123,29 @@ class SessionManager:
|
|||
self.peer_mgr = PeerManager(env, db)
|
||||
self.shutdown_event = shutdown_event
|
||||
self.logger = util.class_logger(__name__, self.__class__.__name__)
|
||||
self.servers = {}
|
||||
self.sessions = set()
|
||||
self.servers: typing.Dict[str, asyncio.AbstractServer] = {}
|
||||
self.sessions: typing.Set['SessionBase'] = set()
|
||||
self.max_subs = env.max_subs
|
||||
self.cur_group = SessionGroup(0)
|
||||
self.txs_sent = 0
|
||||
self.start_time = time.time()
|
||||
self.history_cache = pylru.lrucache(256)
|
||||
self.notified_height = None
|
||||
self.notified_height: typing.Optional[int] = None
|
||||
# Cache some idea of room to avoid recounting on each subscription
|
||||
self.subs_room = 0
|
||||
# Masternode stuff only for such coins
|
||||
if issubclass(env.coin.SESSIONCLS, DashElectrumX):
|
||||
self.mn_cache_height = 0
|
||||
self.mn_cache = []
|
||||
self.mn_cache = [] # type: ignore
|
||||
|
||||
self.session_event = Event()
|
||||
|
||||
# Set up the RPC request handlers
|
||||
cmds = ('add_peer daemon_url disconnect getinfo groups log peers '
|
||||
'query reorg sessions stop'.split())
|
||||
LocalRPC.request_handlers = {cmd: getattr(self, 'rpc_' + cmd)
|
||||
for cmd in cmds}
|
||||
LocalRPC.request_handlers.update(
|
||||
{cmd: getattr(self, 'rpc_' + cmd) for cmd in cmds}
|
||||
)
|
||||
|
||||
async def _start_server(self, kind, *args, **kw_args):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
@ -147,11 +155,10 @@ class SessionManager:
|
|||
protocol_class = self.env.coin.SESSIONCLS
|
||||
protocol_factory = partial(protocol_class, self, self.db,
|
||||
self.mempool, self.peer_mgr, kind)
|
||||
server = loop.create_server(protocol_factory, *args, **kw_args)
|
||||
|
||||
host, port = args[:2]
|
||||
try:
|
||||
self.servers[kind] = await server
|
||||
self.servers[kind] = await loop.create_server(protocol_factory, *args, **kw_args)
|
||||
except OSError as e: # don't suppress CancelledError
|
||||
self.logger.error(f'{kind} server failed to listen on {host}:'
|
||||
f'{port:d} :{e!r}')
|
||||
|
@ -219,7 +226,7 @@ class SessionManager:
|
|||
group_map[session.group].append(session)
|
||||
return group_map
|
||||
|
||||
def _sub_count(self):
|
||||
def _sub_count(self) -> int:
|
||||
return sum(s.sub_count() for s in self.sessions)
|
||||
|
||||
def _lookup_session(self, session_id):
|
||||
|
@ -278,18 +285,37 @@ class SessionManager:
|
|||
def _get_info(self):
|
||||
"""A summary of server state."""
|
||||
group_map = self._group_map()
|
||||
method_counts = collections.defaultdict(0)
|
||||
error_count = 0
|
||||
logged = 0
|
||||
paused = 0
|
||||
pending_requests = 0
|
||||
closing = 0
|
||||
|
||||
for s in self.sessions:
|
||||
error_count += s.errors
|
||||
if s.log_me:
|
||||
logged += 1
|
||||
if not s._can_send.is_set():
|
||||
paused += 1
|
||||
pending_requests += s.count_pending_items()
|
||||
if s.is_closing():
|
||||
closing += 1
|
||||
for request, _ in s.connection._requests.values():
|
||||
method_counts[request.method] += 1
|
||||
return {
|
||||
'closing': len([s for s in self.sessions if s.is_closing()]),
|
||||
'closing': closing,
|
||||
'daemon': self.daemon.logged_url(),
|
||||
'daemon_height': self.daemon.cached_height(),
|
||||
'db_height': self.db.db_height,
|
||||
'errors': sum(s.errors for s in self.sessions),
|
||||
'errors': error_count,
|
||||
'groups': len(group_map),
|
||||
'logged': len([s for s in self.sessions if s.log_me]),
|
||||
'paused': sum(not s.can_send.is_set() for s in self.sessions),
|
||||
'logged': logged,
|
||||
'paused': paused,
|
||||
'pid': os.getpid(),
|
||||
'peers': self.peer_mgr.info(),
|
||||
'requests': sum(s.count_pending_items() for s in self.sessions),
|
||||
'requests': pending_requests,
|
||||
'method_counts': method_counts,
|
||||
'sessions': self.session_count(),
|
||||
'subs': self._sub_count(),
|
||||
'txs_sent': self.txs_sent,
|
||||
|
@ -514,7 +540,7 @@ class SessionManager:
|
|||
session.close(force_after=1) for session in self.sessions
|
||||
])
|
||||
|
||||
def session_count(self):
|
||||
def session_count(self) -> int:
|
||||
"""The number of connections that we've sent something to."""
|
||||
return len(self.sessions)
|
||||
|
||||
|
@ -601,6 +627,7 @@ class SessionBase(RPCSession):
|
|||
|
||||
MAX_CHUNK_SIZE = 2016
|
||||
session_counter = itertools.count()
|
||||
request_handlers: typing.Dict[str, typing.Callable] = {}
|
||||
|
||||
def __init__(self, session_mgr, db, mempool, peer_mgr, kind):
|
||||
connection = JSONRPCConnection(JSONRPCAutoDetect)
|
||||
|
|
Loading…
Add table
Reference in a new issue