Merge pull request #2273 from lbryio/fix-session-bloat

Fix wallet server session bloat from unhandled socket errors
Jack Robison 2019-07-01 12:18:57 -04:00 committed by GitHub
commit abf2ca40a2
8 changed files with 236 additions and 37 deletions


@@ -0,0 +1,132 @@
import asyncio
import socket
import time
import logging
from unittest.mock import Mock
from torba.testcase import IntegrationTestCase, Conductor
import lbry.wallet
from lbry.schema.claim import Claim
from lbry.wallet.transaction import Transaction, Output
from lbry.wallet.dewies import dewies_to_lbc as d2l, lbc_to_dewies as l2d
log = logging.getLogger(__name__)
def wrap_callback_event(fn, callback):
def inner(*a, **kw):
callback()
return fn(*a, **kw)
return inner
class TestSessionBloat(IntegrationTestCase):
"""
ERROR:asyncio:Fatal read error on socket transport
protocol: <lbrynet.wallet.server.session.LBRYElectrumX object at 0x7f7e3bfcaf60>
transport: <_SelectorSocketTransport fd=3236 read=polling write=<idle, bufsize=0>>
Traceback (most recent call last):
File "/usr/lib/python3.7/asyncio/selector_events.py", line 801, in _read_ready__data_received
data = self._sock.recv(self.max_size)
TimeoutError: [Errno 110] Connection timed out
"""
LEDGER = lbry.wallet
async def asyncSetUp(self):
self.conductor = Conductor(
ledger_module=self.LEDGER, manager_module=self.MANAGER, verbosity=self.VERBOSITY
)
await self.conductor.start_blockchain()
self.addCleanup(self.conductor.stop_blockchain)
await self.conductor.start_spv()
self.session_manager = self.conductor.spv_node.server.session_mgr
        # shrink the TCP buffers so the server's write side backs up almost immediately
        self.session_manager.servers['TCP'].sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 64)
        self.session_manager.servers['TCP'].sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 64)
self.addCleanup(self.conductor.stop_spv)
await self.conductor.start_wallet()
self.addCleanup(self.conductor.stop_wallet)
self.client_session = list(self.session_manager.sessions)[0]
self.client_session.transport.set_write_buffer_limits(0, 0)
self.paused_session = asyncio.Event(loop=self.loop)
self.resumed_session = asyncio.Event(loop=self.loop)
def paused():
self.resumed_session.clear()
self.paused_session.set()
        def delayed_resume():
            self.paused_session.clear()
            # deliberately block the event loop so the transport stays backed up
            time.sleep(1)
            self.resumed_session.set()
self.client_session.pause_writing = wrap_callback_event(self.client_session.pause_writing, paused)
self.client_session.resume_writing = wrap_callback_event(self.client_session.resume_writing, delayed_resume)
self.blockchain = self.conductor.blockchain_node
self.wallet_node = self.conductor.wallet_node
self.manager = self.wallet_node.manager
self.ledger = self.wallet_node.ledger
self.wallet = self.wallet_node.wallet
self.account = self.wallet_node.wallet.default_account
async def test_session_bloat_from_socket_timeout(self):
await self.account.ensure_address_gap()
address1, address2 = await self.account.receiving.get_addresses(limit=2, only_usable=True)
sendtxid1 = await self.blockchain.send_to_address(address1, 5)
sendtxid2 = await self.blockchain.send_to_address(address2, 5)
await self.blockchain.generate(1)
await asyncio.wait([
self.on_transaction_id(sendtxid1),
self.on_transaction_id(sendtxid2)
])
self.assertEqual(d2l(await self.account.get_balance()), '10.0')
channel = Claim()
channel_txo = Output.pay_claim_name_pubkey_hash(
l2d('1.0'), '@bar', channel, self.account.ledger.address_to_hash160(address1)
)
channel_txo.generate_channel_private_key()
channel_txo.script.generate()
channel_tx = await Transaction.create([], [channel_txo], [self.account], self.account)
stream = Claim()
stream.stream.description = "0" * 8000
stream_txo = Output.pay_claim_name_pubkey_hash(
l2d('1.0'), 'foo', stream, self.account.ledger.address_to_hash160(address1)
)
stream_tx = await Transaction.create([], [stream_txo], [self.account], self.account)
stream_txo.sign(channel_txo)
await stream_tx.sign([self.account])
self.paused_session.clear()
self.resumed_session.clear()
await self.broadcast(channel_tx)
await self.broadcast(stream_tx)
await asyncio.wait_for(self.paused_session.wait(), 2)
self.assertEqual(1, len(self.session_manager.sessions))
        # swap the transport's real socket for a clone whose recv() always fails,
        # reproducing the "Fatal read error on socket transport" in the docstring
        real_sock = self.client_session.transport._extra.pop('socket')
        mock_sock = Mock(spec=socket.socket)
        for attr in dir(real_sock):
            if not attr.startswith('__'):
                setattr(mock_sock, attr, getattr(real_sock, attr))

        def recv(*a, **kw):
            raise TimeoutError("[Errno 110] Connection timed out")

        mock_sock.recv = recv
        self.client_session.transport._sock = mock_sock
        self.client_session.transport._extra['socket'] = mock_sock
        self.assertFalse(self.resumed_session.is_set())
        self.assertFalse(self.session_manager.session_event.is_set())
        # with the fix, the dead session is torn down instead of lingering forever
        await self.session_manager.session_event.wait()
        self.assertEqual(0, len(self.session_manager.sessions))
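For reference, the fault injected by the mocked recv() above matches errno 110, which is ETIMEDOUT on Linux; the second code the server-side handler later in this diff looks for is 113, EHOSTUNREACH. A quick standalone check of that mapping, assuming a Linux host:

import errno
import os

# Linux errno values; other platforms differ (ETIMEDOUT is 60 on macOS, for example)
assert errno.ETIMEDOUT == 110
assert errno.EHOSTUNREACH == 113
print(os.strerror(errno.ETIMEDOUT))     # Connection timed out
print(os.strerror(errno.EHOSTUNREACH))  # No route to host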


@@ -32,6 +32,7 @@ __all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
import itertools
import json
import typing
from functools import partial
from numbers import Number
@@ -596,7 +597,7 @@ class JSONRPCConnection(object):
# Sent Requests and Batches that have not received a response.
# The key is its request ID; for a batch it is sorted tuple
# of request IDs
-        self._requests = {}
+        self._requests: typing.Dict[str, typing.Tuple[Request, Event]] = {}
# A public attribute intended to be settable dynamically
self.max_response_size = 0
@@ -683,7 +684,7 @@ class JSONRPCConnection(object):
#
# External API
#
-    def send_request(self, request):
+    def send_request(self, request: Request) -> typing.Tuple[bytes, Event]:
"""Send a Request. Return a (message, event) pair.
The message is an unframed message to send over the network.
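The new annotation documents the shape of the pending-request table: each outstanding request id maps to the Request plus the Event that gets set when its response arrives, and send_request hands that event back to the caller. A minimal sketch of the pattern, with illustrative names rather than aiorpcX's actual internals:

import asyncio

class PendingRequests:
    def __init__(self):
        self._requests = {}  # request id -> (request, event)

    def send_request(self, request_id, request):
        event = asyncio.Event()
        self._requests[request_id] = (request, event)
        return event  # the caller awaits this until the response lands

    def receive_response(self, request_id):
        request, event = self._requests.pop(request_id)  # forget the entry
        event.set()  # wake the waiting caller
        return request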


@@ -55,6 +55,7 @@ class Daemon:
self.max_retry = max_retry
self._height = None
self.available_rpcs = {}
self.connector = aiohttp.TCPConnector()
def set_url(self, url):
"""Set the URLS to the given list, and switch to the first one."""
@@ -89,7 +90,7 @@
def client_session(self):
"""An aiohttp client session."""
-        return aiohttp.ClientSession()
+        return aiohttp.ClientSession(connector=self.connector, connector_owner=False)
async def _send_data(self, data):
async with self.workqueue_semaphore:
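Before this change every client_session() call built an aiohttp.ClientSession with its own connection pool; sharing one TCPConnector and passing connector_owner=False means closing a session no longer tears the shared pool down. A minimal sketch of the pattern (the URL is illustrative):

import asyncio
import aiohttp

async def main():
    connector = aiohttp.TCPConnector()  # one shared pool, owned here
    for _ in range(3):
        # each session borrows the pool; closing the session leaves it alive
        async with aiohttp.ClientSession(connector=connector, connector_owner=False) as session:
            async with session.get('http://example.com/') as response:
                await response.read()
    await connector.close()  # close the pool exactly once, when done

asyncio.run(main())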


@@ -33,7 +33,7 @@ class Env:
self.allow_root = self.boolean('ALLOW_ROOT', False)
self.host = self.default('HOST', 'localhost')
self.rpc_host = self.default('RPC_HOST', 'localhost')
-        self.loop_policy = self.event_loop_policy()
+        self.loop_policy = self.set_event_loop_policy()
self.obsolete(['UTXO_MB', 'HIST_MB', 'NETWORK'])
self.db_dir = self.required('DB_DIRECTORY')
self.db_engine = self.default('DB_ENGINE', 'leveldb')
@@ -129,14 +129,18 @@ class Env:
raise cls.Error('remove obsolete environment variables {}'
.format(bad))
-    def event_loop_policy(self):
-        policy = self.default('EVENT_LOOP_POLICY', None)
-        if policy is None:
-            return None
-        if policy == 'uvloop':
+    def set_event_loop_policy(self):
+        policy_name = self.default('EVENT_LOOP_POLICY', None)
+        if not policy_name:
+            import asyncio
+            return asyncio.get_event_loop_policy()
+        elif policy_name == 'uvloop':
            import uvloop
-            return uvloop.EventLoopPolicy()
-        raise self.Error('unknown event loop policy "{}"'.format(policy))
+            import asyncio
+            loop_policy = uvloop.EventLoopPolicy()
+            asyncio.set_event_loop_policy(loop_policy)
+            return loop_policy
+        raise self.Error('unknown event loop policy "{}"'.format(policy_name))
def cs_host(self, *, for_rpc):
"""Returns the 'host' argument to pass to asyncio's create_server


@@ -70,6 +70,7 @@ class Peer:
# Transient, non-persisted metadata
self.bad = False
self.other_port_pairs = set()
self.status = 2
@classmethod
def peers_from_features(cls, features, source):


@@ -12,6 +12,7 @@ import random
import socket
import ssl
import time
import typing
from asyncio import Event, sleep
from collections import defaultdict, Counter
@@ -72,7 +73,7 @@ class PeerManager:
# ip_addr property is either None, an onion peer, or the
# IP address that was connected to. Adding a peer will evict
# any other peers with the same host name or IP address.
-        self.peers = set()
+        self.peers: typing.Set[Peer] = set()
self.permit_onion_peer_time = time.time()
self.proxy = None
self.group = TaskGroup()
@@ -394,7 +395,7 @@ class PeerManager:
self.group.add(self._detect_proxy())
self.group.add(self._import_peers())
-    def info(self):
+    def info(self) -> typing.Dict[str, int]:
"""The number of peers."""
self._set_peer_statuses()
counter = Counter(peer.status for peer in self.peers)


@@ -5,7 +5,38 @@ from concurrent.futures.thread import ThreadPoolExecutor
import torba
from torba.server.mempool import MemPool, MemPoolAPI
-from torba.server.session import SessionManager
+from torba.server.session import SessionManager, SessionBase
CONNECTION_TIMED_OUT = 110
NO_ROUTE_TO_HOST = 113
def handle_socket_errors(socket_errors=(CONNECTION_TIMED_OUT, NO_ROUTE_TO_HOST)):
err_msgs = tuple((f"[Errno {err_code}]" for err_code in socket_errors))
log = logging.getLogger(__name__)
    def protocol_exception_handler(loop, context):
        exception = context['exception']
        # only swallow fatal transport errors raised for one of our sessions;
        # anything else is re-raised and reaches asyncio's default reporting
        if 'protocol' not in context or 'transport' not in context:
            raise exception
        if not isinstance(context['protocol'], SessionBase):
            raise exception
        session: SessionBase = context['protocol']
        transport: asyncio.Transport = context['transport']
        message = context['message']
        if message not in ("Fatal read error on socket transport", "Fatal write error on socket transport"):
            raise exception
        for err_msg in err_msgs:
            if str(exception).startswith(err_msg):
                log.debug("caught: '%s' for %s", str(exception), session)
                # drop the dead connection and release the session object
                transport.abort()
                transport.close()
                loop.create_task(session.close(force_after=1))
                return
        raise exception
return protocol_exception_handler
class Notifications:
@@ -90,6 +121,7 @@ class Server:
)
async def start(self):
asyncio.get_event_loop().set_exception_handler(handle_socket_errors())
env = self.env
min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings()
self.log.info(f'software version: {torba.__version__}')
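The fix hooks asyncio's loop-level exception handler, which is exactly where "Fatal read error on socket transport" ends up: errors that cannot be delivered to user code arrive as a plain dict carrying 'message' and, when available, 'exception', 'protocol' and 'transport'. A runnable sketch of that mechanism on its own, independent of torba (the production handler above additionally aborts the transport and schedules session.close()):

import asyncio

def handler(loop, context):
    # 'message' is always present; the other keys appear only when applicable
    print('message:  ', context['message'])
    print('exception:', repr(context.get('exception')))

def fail():
    raise TimeoutError('[Errno 110] Connection timed out')

loop = asyncio.new_event_loop()
loop.set_exception_handler(handler)
loop.call_soon(fail)  # exceptions escaping loop callbacks reach the handler
loop.run_until_complete(asyncio.sleep(0.1))
loop.close()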


@@ -6,7 +6,7 @@
# and warranty status of this software.
"""Classes for local RPC server and remote client TCP/SSL servers."""
import collections
import asyncio
import codecs
import datetime
@@ -16,6 +16,7 @@ import os
import pylru
import ssl
import time
import typing
from asyncio import Event, sleep
from collections import defaultdict
from functools import partial
@@ -31,13 +32,18 @@ from torba.server.hash import (sha256, hash_to_hex_str, hex_str_to_hash,
HASHX_LEN, Base58Error)
from torba.server.daemon import DaemonError
from torba.server.peers import PeerManager
if typing.TYPE_CHECKING:
from torba.server.env import Env
from torba.server.db import DB
from torba.server.block_processor import BlockProcessor
from torba.server.mempool import MemPool
from torba.server.daemon import Daemon
BAD_REQUEST = 1
DAEMON_ERROR = 2
-def scripthash_to_hashX(scripthash):
+def scripthash_to_hashX(scripthash: str) -> bytes:
try:
bin_hash = hex_str_to_hash(scripthash)
if len(bin_hash) == 32:
@@ -47,7 +53,7 @@ def scripthash_to_hashX(scripthash):
raise RPCError(BAD_REQUEST, f'{scripthash} is not a valid script hash')
-def non_negative_integer(value):
+def non_negative_integer(value) -> int:
"""Return param value it is or can be converted to a non-negative
integer, otherwise raise an RPCError."""
try:
@@ -60,14 +66,14 @@ def non_negative_integer(value):
f'{value} should be a non-negative integer')
-def assert_boolean(value):
+def assert_boolean(value) -> bool:
    """Return the param value if it is a boolean, otherwise raise an RPCError."""
if value in (False, True):
return value
raise RPCError(BAD_REQUEST, f'{value} should be a boolean value')
-def assert_tx_hash(value):
+def assert_tx_hash(value: str) -> None:
"""Raise an RPCError if the value is not a valid transaction
hash."""
try:
@@ -97,7 +103,7 @@ class Semaphores:
class SessionGroup:
-    def __init__(self, gid):
+    def __init__(self, gid: int):
self.gid = gid
# Concurrency per group
self.semaphore = asyncio.Semaphore(20)
@@ -106,7 +112,8 @@ class SessionGroup:
class SessionManager:
"""Holds global state about all sessions."""
-    def __init__(self, env, db, bp, daemon, mempool, shutdown_event):
+    def __init__(self, env: 'Env', db: 'DB', bp: 'BlockProcessor', daemon: 'Daemon', mempool: 'MemPool',
+                 shutdown_event: asyncio.Event):
env.max_send = max(350000, env.max_send)
self.env = env
self.db = db
@@ -116,28 +123,29 @@ class SessionManager:
self.peer_mgr = PeerManager(env, db)
self.shutdown_event = shutdown_event
self.logger = util.class_logger(__name__, self.__class__.__name__)
-        self.servers = {}
-        self.sessions = set()
+        self.servers: typing.Dict[str, asyncio.AbstractServer] = {}
+        self.sessions: typing.Set['SessionBase'] = set()
self.max_subs = env.max_subs
self.cur_group = SessionGroup(0)
self.txs_sent = 0
self.start_time = time.time()
self.history_cache = pylru.lrucache(256)
-        self.notified_height = None
+        self.notified_height: typing.Optional[int] = None
# Cache some idea of room to avoid recounting on each subscription
self.subs_room = 0
# Masternode stuff only for such coins
if issubclass(env.coin.SESSIONCLS, DashElectrumX):
self.mn_cache_height = 0
-            self.mn_cache = []
+            self.mn_cache = []  # type: ignore
self.session_event = Event()
# Set up the RPC request handlers
cmds = ('add_peer daemon_url disconnect getinfo groups log peers '
'query reorg sessions stop'.split())
-        LocalRPC.request_handlers = {cmd: getattr(self, 'rpc_' + cmd)
-                                     for cmd in cmds}
+        LocalRPC.request_handlers.update(
+            {cmd: getattr(self, 'rpc_' + cmd) for cmd in cmds}
+        )
async def _start_server(self, kind, *args, **kw_args):
loop = asyncio.get_event_loop()
@@ -147,11 +155,10 @@ class SessionManager:
protocol_class = self.env.coin.SESSIONCLS
protocol_factory = partial(protocol_class, self, self.db,
self.mempool, self.peer_mgr, kind)
-        server = loop.create_server(protocol_factory, *args, **kw_args)
        host, port = args[:2]
        try:
-            self.servers[kind] = await server
+            self.servers[kind] = await loop.create_server(protocol_factory, *args, **kw_args)
except OSError as e: # don't suppress CancelledError
self.logger.error(f'{kind} server failed to listen on {host}:'
f'{port:d} :{e!r}')
@@ -219,7 +226,7 @@ class SessionManager:
group_map[session.group].append(session)
return group_map
-    def _sub_count(self):
+    def _sub_count(self) -> int:
return sum(s.sub_count() for s in self.sessions)
def _lookup_session(self, session_id):
@@ -278,18 +285,37 @@ class SessionManager:
def _get_info(self):
"""A summary of server state."""
group_map = self._group_map()
        method_counts = collections.defaultdict(int)
error_count = 0
logged = 0
paused = 0
pending_requests = 0
closing = 0
for s in self.sessions:
error_count += s.errors
if s.log_me:
logged += 1
if not s._can_send.is_set():
paused += 1
pending_requests += s.count_pending_items()
if s.is_closing():
closing += 1
for request, _ in s.connection._requests.values():
method_counts[request.method] += 1
return {
-            'closing': len([s for s in self.sessions if s.is_closing()]),
+            'closing': closing,
'daemon': self.daemon.logged_url(),
'daemon_height': self.daemon.cached_height(),
'db_height': self.db.db_height,
-            'errors': sum(s.errors for s in self.sessions),
+            'errors': error_count,
'groups': len(group_map),
-            'logged': len([s for s in self.sessions if s.log_me]),
-            'paused': sum(not s.can_send.is_set() for s in self.sessions),
+            'logged': logged,
+            'paused': paused,
'pid': os.getpid(),
'peers': self.peer_mgr.info(),
-            'requests': sum(s.count_pending_items() for s in self.sessions),
+            'requests': pending_requests,
+            'method_counts': method_counts,
'sessions': self.session_count(),
'subs': self._sub_count(),
'txs_sent': self.txs_sent,
@@ -514,7 +540,7 @@ class SessionManager:
session.close(force_after=1) for session in self.sessions
])
-    def session_count(self):
+    def session_count(self) -> int:
"""The number of connections that we've sent something to."""
return len(self.sessions)
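A side note on the method_counts table built in _get_info above: defaultdict needs a callable default factory, so collections.defaultdict(int) starts every unseen key at 0, while defaultdict(0) raises TypeError at construction time. In miniature:

from collections import defaultdict

method_counts = defaultdict(int)  # int() == 0, so unseen methods start at zero
method_counts['blockchain.transaction.get'] += 1
print(dict(method_counts))  # {'blockchain.transaction.get': 1}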
@@ -601,6 +627,7 @@ class SessionBase(RPCSession):
MAX_CHUNK_SIZE = 2016
session_counter = itertools.count()
request_handlers: typing.Dict[str, typing.Callable] = {}
def __init__(self, session_mgr, db, mempool, peer_mgr, kind):
connection = JSONRPCConnection(JSONRPCAutoDetect)