Merge pull request #2273 from lbryio/fix-session-bloat

Fix wallet server session bloat from unhandled socket errors
Commit abf2ca40a2 by Jack Robison, 2019-07-01 12:18:57 -04:00 (committed by GitHub)
8 changed files with 236 additions and 37 deletions
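
Summary of the change: when a client socket died, asyncio's selector transport reported errors such as "TimeoutError: [Errno 110] Connection timed out" to the event loop's default exception handler, which only logs them; the server-side session object stayed registered with the SessionManager, so dead sessions piled up. The fix installs a loop-level exception handler that recognizes these fatal socket errors, aborts the transport, and force-closes the session. In outline (a condensed sketch of the pattern; the real version is handle_socket_errors in the torba server diff below, and session.close(force_after=1) is torba's session API):

    import asyncio

    def make_handler(errnos=(110, 113)):  # ETIMEDOUT, EHOSTUNREACH
        prefixes = tuple(f"[Errno {e}]" for e in errnos)

        def handler(loop, context):
            exc = context.get('exception')
            transport = context.get('transport')
            session = context.get('protocol')  # asyncio reports the protocol here
            if transport is None or session is None or not str(exc).startswith(prefixes):
                loop.default_exception_handler(context)  # not ours: log as usual
                return
            transport.abort()  # drop the dead connection immediately
            loop.create_task(session.close(force_after=1))  # unregister the session

        return handler

    asyncio.get_event_loop().set_exception_handler(make_handler())

(The committed handler re-raises unrecognized exceptions rather than delegating to the default handler; the delegation above is just the more conventional fallback.)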

View file

@@ -0,0 +1,132 @@
import asyncio
import socket
import time
import logging
from unittest.mock import Mock

from torba.testcase import IntegrationTestCase, Conductor

import lbry.wallet
from lbry.schema.claim import Claim
from lbry.wallet.transaction import Transaction, Output
from lbry.wallet.dewies import dewies_to_lbc as d2l, lbc_to_dewies as l2d

log = logging.getLogger(__name__)


def wrap_callback_event(fn, callback):
    def inner(*a, **kw):
        callback()
        return fn(*a, **kw)
    return inner


class TestSessionBloat(IntegrationTestCase):
    """
    ERROR:asyncio:Fatal read error on socket transport
    protocol: <lbrynet.wallet.server.session.LBRYElectrumX object at 0x7f7e3bfcaf60>
    transport: <_SelectorSocketTransport fd=3236 read=polling write=<idle, bufsize=0>>
    Traceback (most recent call last):
      File "/usr/lib/python3.7/asyncio/selector_events.py", line 801, in _read_ready__data_received
        data = self._sock.recv(self.max_size)
    TimeoutError: [Errno 110] Connection timed out
    """
    LEDGER = lbry.wallet

    async def asyncSetUp(self):
        self.conductor = Conductor(
            ledger_module=self.LEDGER, manager_module=self.MANAGER, verbosity=self.VERBOSITY
        )
        await self.conductor.start_blockchain()
        self.addCleanup(self.conductor.stop_blockchain)
        await self.conductor.start_spv()
        self.session_manager = self.conductor.spv_node.server.session_mgr
        self.session_manager.servers['TCP'].sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 64)
        self.session_manager.servers['TCP'].sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 64)
        self.addCleanup(self.conductor.stop_spv)
        await self.conductor.start_wallet()
        self.addCleanup(self.conductor.stop_wallet)

        self.client_session = list(self.session_manager.sessions)[0]
        self.client_session.transport.set_write_buffer_limits(0, 0)
        self.paused_session = asyncio.Event(loop=self.loop)
        self.resumed_session = asyncio.Event(loop=self.loop)

        def paused():
            self.resumed_session.clear()
            self.paused_session.set()

        def delayed_resume():
            self.paused_session.clear()
            time.sleep(1)
            self.resumed_session.set()

        self.client_session.pause_writing = wrap_callback_event(self.client_session.pause_writing, paused)
        self.client_session.resume_writing = wrap_callback_event(self.client_session.resume_writing, delayed_resume)

        self.blockchain = self.conductor.blockchain_node
        self.wallet_node = self.conductor.wallet_node
        self.manager = self.wallet_node.manager
        self.ledger = self.wallet_node.ledger
        self.wallet = self.wallet_node.wallet
        self.account = self.wallet_node.wallet.default_account

    async def test_session_bloat_from_socket_timeout(self):
        await self.account.ensure_address_gap()
        address1, address2 = await self.account.receiving.get_addresses(limit=2, only_usable=True)
        sendtxid1 = await self.blockchain.send_to_address(address1, 5)
        sendtxid2 = await self.blockchain.send_to_address(address2, 5)
        await self.blockchain.generate(1)
        await asyncio.wait([
            self.on_transaction_id(sendtxid1),
            self.on_transaction_id(sendtxid2)
        ])
        self.assertEqual(d2l(await self.account.get_balance()), '10.0')

        channel = Claim()
        channel_txo = Output.pay_claim_name_pubkey_hash(
            l2d('1.0'), '@bar', channel, self.account.ledger.address_to_hash160(address1)
        )
        channel_txo.generate_channel_private_key()
        channel_txo.script.generate()
        channel_tx = await Transaction.create([], [channel_txo], [self.account], self.account)

        stream = Claim()
        stream.stream.description = "0" * 8000
        stream_txo = Output.pay_claim_name_pubkey_hash(
            l2d('1.0'), 'foo', stream, self.account.ledger.address_to_hash160(address1)
        )
        stream_tx = await Transaction.create([], [stream_txo], [self.account], self.account)
        stream_txo.sign(channel_txo)
        await stream_tx.sign([self.account])

        self.paused_session.clear()
        self.resumed_session.clear()
        await self.broadcast(channel_tx)
        await self.broadcast(stream_tx)
        await asyncio.wait_for(self.paused_session.wait(), 2)
        self.assertEqual(1, len(self.session_manager.sessions))

        real_sock = self.client_session.transport._extra.pop('socket')
        mock_sock = Mock(spec=socket.socket)
        for attr in dir(real_sock):
            if not attr.startswith('__'):
                setattr(mock_sock, attr, getattr(real_sock, attr))

        def recv(*a, **kw):
            raise TimeoutError("[Errno 110] Connection timed out")

        mock_sock.recv = recv
        self.client_session.transport._sock = mock_sock
        self.client_session.transport._extra['socket'] = mock_sock

        self.assertFalse(self.resumed_session.is_set())
        self.assertFalse(self.session_manager.session_event.is_set())
        await self.session_manager.session_event.wait()
        self.assertEqual(0, len(self.session_manager.sessions))
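
For context on what the test exercises: when the mocked sock.recv raises, CPython 3.7's _SelectorSocketTransport catches the exception in _read_ready__data_received and reports it via loop.call_exception_handler, which is what produced the "ERROR:asyncio:Fatal read error on socket transport" log quoted in the class docstring. Roughly, asyncio hands the installed handler a context of this shape (transport and protocol stand for the live objects):

    context = {
        'message': 'Fatal read error on socket transport',
        'exception': TimeoutError(110, 'Connection timed out'),
        'transport': transport,  # the _SelectorSocketTransport
        'protocol': protocol,    # here, the LBRYElectrumX session
    }
    loop.call_exception_handler(context)  # dispatches to the handler set on the loop

With the new handler installed, the session is force-closed, session_event fires, and the final assertion sees an empty session set.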

View file

@@ -32,6 +32,7 @@ __all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
 import itertools
 import json
+import typing
 from functools import partial
 from numbers import Number

@@ -596,7 +597,7 @@ class JSONRPCConnection(object):
         # Sent Requests and Batches that have not received a response.
         # The key is its request ID; for a batch it is sorted tuple
         # of request IDs
-        self._requests = {}
+        self._requests: typing.Dict[str, typing.Tuple[Request, Event]] = {}
         # A public attribute intended to be settable dynamically
         self.max_response_size = 0

@@ -683,7 +684,7 @@ class JSONRPCConnection(object):
     #
     # External API
     #
-    def send_request(self, request):
+    def send_request(self, request: Request) -> typing.Tuple[bytes, Event]:
         """Send a Request. Return a (message, event) pair.

         The message is an unframed message to send over the network.

View file

@@ -55,6 +55,7 @@ class Daemon:
         self.max_retry = max_retry
         self._height = None
         self.available_rpcs = {}
+        self.connector = aiohttp.TCPConnector()

     def set_url(self, url):
         """Set the URLS to the given list, and switch to the first one."""

@@ -89,7 +90,7 @@ class Daemon:
     def client_session(self):
         """An aiohttp client session."""
-        return aiohttp.ClientSession()
+        return aiohttp.ClientSession(connector=self.connector, connector_owner=False)

     async def _send_data(self, data):
         async with self.workqueue_semaphore:
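
The Daemon now shares one long-lived TCPConnector across the many short-lived ClientSession objects it creates; connector_owner=False keeps a closing session from tearing down the shared connection pool, so connections to the daemon are reused instead of re-established per request. The same pattern standalone (the URL is a placeholder):

    import asyncio
    import aiohttp

    async def main():
        connector = aiohttp.TCPConnector()  # one shared connection pool
        try:
            for _ in range(3):
                # Each session borrows the pool; closing the session leaves
                # the pooled connections alive for the next session.
                async with aiohttp.ClientSession(
                        connector=connector, connector_owner=False) as session:
                    async with session.get('http://localhost:8332/') as resp:
                        await resp.read()
        finally:
            await connector.close()  # only the owner closes the pool

    asyncio.run(main())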

View file

@@ -33,7 +33,7 @@ class Env:
         self.allow_root = self.boolean('ALLOW_ROOT', False)
         self.host = self.default('HOST', 'localhost')
         self.rpc_host = self.default('RPC_HOST', 'localhost')
-        self.loop_policy = self.event_loop_policy()
+        self.loop_policy = self.set_event_loop_policy()
         self.obsolete(['UTXO_MB', 'HIST_MB', 'NETWORK'])
         self.db_dir = self.required('DB_DIRECTORY')
         self.db_engine = self.default('DB_ENGINE', 'leveldb')

@@ -129,14 +129,18 @@ class Env:
             raise cls.Error('remove obsolete environment variables {}'
                             .format(bad))

-    def event_loop_policy(self):
-        policy = self.default('EVENT_LOOP_POLICY', None)
-        if policy is None:
-            return None
-        if policy == 'uvloop':
-            import uvloop
-            return uvloop.EventLoopPolicy()
-        raise self.Error('unknown event loop policy "{}"'.format(policy))
+    def set_event_loop_policy(self):
+        policy_name = self.default('EVENT_LOOP_POLICY', None)
+        if not policy_name:
+            import asyncio
+            return asyncio.get_event_loop_policy()
+        elif policy_name == 'uvloop':
+            import uvloop
+            import asyncio
+            loop_policy = uvloop.EventLoopPolicy()
+            asyncio.set_event_loop_policy(loop_policy)
+            return loop_policy
+        raise self.Error('unknown event loop policy "{}"'.format(policy_name))

     def cs_host(self, *, for_rpc):
         """Returns the 'host' argument to pass to asyncio's create_server

View file

@@ -70,6 +70,7 @@ class Peer:
         # Transient, non-persisted metadata
         self.bad = False
         self.other_port_pairs = set()
+        self.status = 2

     @classmethod
     def peers_from_features(cls, features, source):
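
Peer objects previously had no status attribute until the PeerManager assigned one, so aggregating statuses over freshly created peers could hit an AttributeError. The default of 2 lines up with the "never connected" slot if torba keeps electrumx's status ordering (an assumption worth checking against torba.server.peers):

    # electrumx-style peer status constants (assumed ordering)
    PEER_GOOD, PEER_STALE, PEER_NEVER, PEER_BAD = range(4)  # 2 == PEER_NEVER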

View file

@@ -12,6 +12,7 @@ import random
 import socket
 import ssl
 import time
+import typing
 from asyncio import Event, sleep
 from collections import defaultdict, Counter

@@ -72,7 +73,7 @@ class PeerManager:
         # ip_addr property is either None, an onion peer, or the
         # IP address that was connected to. Adding a peer will evict
         # any other peers with the same host name or IP address.
-        self.peers = set()
+        self.peers: typing.Set[Peer] = set()
         self.permit_onion_peer_time = time.time()
         self.proxy = None
         self.group = TaskGroup()

@@ -394,7 +395,7 @@ class PeerManager:
         self.group.add(self._detect_proxy())
         self.group.add(self._import_peers())

-    def info(self):
+    def info(self) -> typing.Dict[str, int]:
         """The number of peers."""
         self._set_peer_statuses()
         counter = Counter(peer.status for peer in self.peers)

View file

@@ -5,7 +5,38 @@ from concurrent.futures.thread import ThreadPoolExecutor
 import torba
 from torba.server.mempool import MemPool, MemPoolAPI
-from torba.server.session import SessionManager
+from torba.server.session import SessionManager, SessionBase
+
+CONNECTION_TIMED_OUT = 110
+NO_ROUTE_TO_HOST = 113
+
+
+def handle_socket_errors(socket_errors=(CONNECTION_TIMED_OUT, NO_ROUTE_TO_HOST)):
+    err_msgs = tuple((f"[Errno {err_code}]" for err_code in socket_errors))
+    log = logging.getLogger(__name__)
+
+    def protocol_exception_handler(loop, context):
+        exception = context['exception']
+        if 'protocol' not in context or 'transport' not in context:
+            raise exception
+        if not isinstance(context['protocol'], SessionBase):
+            raise exception
+        session: SessionBase = context['protocol']
+        transport: asyncio.Transport = context['transport']
+        message = context['message']
+        if message not in ("Fatal read error on socket transport", "Fatal write error on socket transport"):
+            raise exception
+        for err_msg in err_msgs:
+            if str(exception).startswith(err_msg):
+                log.debug("caught: '%s' for %s", str(exception), session)
+                transport.abort()
+                transport.close()
+                loop.create_task(session.close(force_after=1))
+                return
+        raise exception
+    return protocol_exception_handler


 class Notifications:

@@ -90,6 +121,7 @@ class Server:
         )

     async def start(self):
+        asyncio.get_event_loop().set_exception_handler(handle_socket_errors())
         env = self.env
         min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings()
         self.log.info(f'software version: {torba.__version__}')
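
Because handle_socket_errors() is a factory over errno codes, the set of swallowed errors is chosen at install time. A hypothetical install that also treats ECONNRESET as a plain disconnect (not part of this commit):

    ECONNRESET = 104  # hypothetical addition
    asyncio.get_event_loop().set_exception_handler(
        handle_socket_errors((ECONNRESET, CONNECTION_TIMED_OUT, NO_ROUTE_TO_HOST))
    )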

View file

@@ -6,7 +6,7 @@
 # and warranty status of this software.

 """Classes for local RPC server and remote client TCP/SSL servers."""

+import collections
 import asyncio
 import codecs
 import datetime

@@ -16,6 +16,7 @@ import os
 import pylru
 import ssl
 import time
+import typing
 from asyncio import Event, sleep
 from collections import defaultdict
 from functools import partial

@@ -31,13 +32,18 @@ from torba.server.hash import (sha256, hash_to_hex_str, hex_str_to_hash,
                                HASHX_LEN, Base58Error)
 from torba.server.daemon import DaemonError
 from torba.server.peers import PeerManager

+if typing.TYPE_CHECKING:
+    from torba.server.env import Env
+    from torba.server.db import DB
+    from torba.server.block_processor import BlockProcessor
+    from torba.server.mempool import MemPool
+    from torba.server.daemon import Daemon
+
 BAD_REQUEST = 1
 DAEMON_ERROR = 2


-def scripthash_to_hashX(scripthash):
+def scripthash_to_hashX(scripthash: str) -> bytes:
     try:
         bin_hash = hex_str_to_hash(scripthash)
         if len(bin_hash) == 32:

@@ -47,7 +53,7 @@ def scripthash_to_hashX(scripthash):
     raise RPCError(BAD_REQUEST, f'{scripthash} is not a valid script hash')


-def non_negative_integer(value):
+def non_negative_integer(value) -> int:
     """Return param value it is or can be converted to a non-negative
     integer, otherwise raise an RPCError."""
     try:

@@ -60,14 +66,14 @@ def non_negative_integer(value):
                    f'{value} should be a non-negative integer')


-def assert_boolean(value):
+def assert_boolean(value) -> bool:
     """Return param value it is boolean otherwise raise an RPCError."""
     if value in (False, True):
         return value
     raise RPCError(BAD_REQUEST, f'{value} should be a boolean value')


-def assert_tx_hash(value):
+def assert_tx_hash(value: str) -> None:
     """Raise an RPCError if the value is not a valid transaction
     hash."""
     try:

@@ -97,7 +103,7 @@ class Semaphores:

 class SessionGroup:

-    def __init__(self, gid):
+    def __init__(self, gid: int):
         self.gid = gid
         # Concurrency per group
         self.semaphore = asyncio.Semaphore(20)

@@ -106,7 +112,8 @@ class SessionGroup:
 class SessionManager:
     """Holds global state about all sessions."""

-    def __init__(self, env, db, bp, daemon, mempool, shutdown_event):
+    def __init__(self, env: 'Env', db: 'DB', bp: 'BlockProcessor', daemon: 'Daemon', mempool: 'MemPool',
+                 shutdown_event: asyncio.Event):
         env.max_send = max(350000, env.max_send)
         self.env = env
         self.db = db

@@ -116,28 +123,29 @@ class SessionManager:
         self.peer_mgr = PeerManager(env, db)
         self.shutdown_event = shutdown_event
         self.logger = util.class_logger(__name__, self.__class__.__name__)
-        self.servers = {}
-        self.sessions = set()
+        self.servers: typing.Dict[str, asyncio.AbstractServer] = {}
+        self.sessions: typing.Set['SessionBase'] = set()
         self.max_subs = env.max_subs
         self.cur_group = SessionGroup(0)
         self.txs_sent = 0
         self.start_time = time.time()
         self.history_cache = pylru.lrucache(256)
-        self.notified_height = None
+        self.notified_height: typing.Optional[int] = None
         # Cache some idea of room to avoid recounting on each subscription
         self.subs_room = 0
         # Masternode stuff only for such coins
         if issubclass(env.coin.SESSIONCLS, DashElectrumX):
             self.mn_cache_height = 0
-            self.mn_cache = []
+            self.mn_cache = []  # type: ignore
         self.session_event = Event()

         # Set up the RPC request handlers
         cmds = ('add_peer daemon_url disconnect getinfo groups log peers '
                 'query reorg sessions stop'.split())
-        LocalRPC.request_handlers = {cmd: getattr(self, 'rpc_' + cmd)
-                                     for cmd in cmds}
+        LocalRPC.request_handlers.update(
+            {cmd: getattr(self, 'rpc_' + cmd) for cmd in cmds}
+        )

@@ -147,11 +155,10 @@ class SessionManager:
     async def _start_server(self, kind, *args, **kw_args):
         loop = asyncio.get_event_loop()
         protocol_class = self.env.coin.SESSIONCLS
         protocol_factory = partial(protocol_class, self, self.db,
                                    self.mempool, self.peer_mgr, kind)
-        server = loop.create_server(protocol_factory, *args, **kw_args)

         host, port = args[:2]
         try:
-            self.servers[kind] = await server
+            self.servers[kind] = await loop.create_server(protocol_factory, *args, **kw_args)
         except OSError as e:    # don't suppress CancelledError
             self.logger.error(f'{kind} server failed to listen on {host}:'
                               f'{port:d} :{e!r}')

@@ -219,7 +226,7 @@ class SessionManager:
             group_map[session.group].append(session)
         return group_map

-    def _sub_count(self):
+    def _sub_count(self) -> int:
         return sum(s.sub_count() for s in self.sessions)

     def _lookup_session(self, session_id):
@ -278,18 +285,37 @@ class SessionManager:
def _get_info(self): def _get_info(self):
"""A summary of server state.""" """A summary of server state."""
group_map = self._group_map() group_map = self._group_map()
method_counts = collections.defaultdict(0)
error_count = 0
logged = 0
paused = 0
pending_requests = 0
closing = 0
for s in self.sessions:
error_count += s.errors
if s.log_me:
logged += 1
if not s._can_send.is_set():
paused += 1
pending_requests += s.count_pending_items()
if s.is_closing():
closing += 1
for request, _ in s.connection._requests.values():
method_counts[request.method] += 1
return { return {
'closing': len([s for s in self.sessions if s.is_closing()]), 'closing': closing,
'daemon': self.daemon.logged_url(), 'daemon': self.daemon.logged_url(),
'daemon_height': self.daemon.cached_height(), 'daemon_height': self.daemon.cached_height(),
'db_height': self.db.db_height, 'db_height': self.db.db_height,
'errors': sum(s.errors for s in self.sessions), 'errors': error_count,
'groups': len(group_map), 'groups': len(group_map),
'logged': len([s for s in self.sessions if s.log_me]), 'logged': logged,
'paused': sum(not s.can_send.is_set() for s in self.sessions), 'paused': paused,
'pid': os.getpid(), 'pid': os.getpid(),
'peers': self.peer_mgr.info(), 'peers': self.peer_mgr.info(),
'requests': sum(s.count_pending_items() for s in self.sessions), 'requests': pending_requests,
'method_counts': method_counts,
'sessions': self.session_count(), 'sessions': self.session_count(),
'subs': self._sub_count(), 'subs': self._sub_count(),
'txs_sent': self.txs_sent, 'txs_sent': self.txs_sent,
@@ -514,7 +540,7 @@ class SessionManager:
             session.close(force_after=1) for session in self.sessions
         ])

-    def session_count(self):
+    def session_count(self) -> int:
         """The number of connections that we've sent something to."""
         return len(self.sessions)

@@ -601,6 +627,7 @@ class SessionBase(RPCSession):

     MAX_CHUNK_SIZE = 2016
     session_counter = itertools.count()
+    request_handlers: typing.Dict[str, typing.Callable] = {}

     def __init__(self, session_mgr, db, mempool, peer_mgr, kind):
         connection = JSONRPCConnection(JSONRPCAutoDetect)