fix reconnect and timeouts
This commit is contained in:
parent
d97d738880
commit
6997ea608d
6 changed files with 38 additions and 161 deletions
|
@ -1,132 +1,30 @@
|
|||
import asyncio
|
||||
import socket
|
||||
import time
|
||||
import logging
|
||||
from unittest.mock import Mock
|
||||
from torba.testcase import IntegrationTestCase, Conductor
|
||||
|
||||
from torba.client.basenetwork import ClientSession
|
||||
from torba.orchstr8 import Conductor
|
||||
from torba.testcase import IntegrationTestCase
|
||||
import lbry.wallet
|
||||
from lbry.schema.claim import Claim
|
||||
from lbry.wallet.transaction import Transaction, Output
|
||||
from lbry.wallet.dewies import dewies_to_lbc as d2l, lbc_to_dewies as l2d
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
def wrap_callback_event(fn, callback):
|
||||
def inner(*a, **kw):
|
||||
callback()
|
||||
return fn(*a, **kw)
|
||||
return inner
|
||||
|
||||
|
||||
class TestSessionBloat(IntegrationTestCase):
|
||||
"""
|
||||
ERROR:asyncio:Fatal read error on socket transport
|
||||
protocol: <lbrynet.wallet.server.session.LBRYElectrumX object at 0x7f7e3bfcaf60>
|
||||
transport: <_SelectorSocketTransport fd=3236 read=polling write=<idle, bufsize=0>>
|
||||
Traceback (most recent call last):
|
||||
File "/usr/lib/python3.7/asyncio/selector_events.py", line 801, in _read_ready__data_received
|
||||
data = self._sock.recv(self.max_size)
|
||||
TimeoutError: [Errno 110] Connection timed out
|
||||
Tests that server cleans up stale connections after session timeout and client times out too.
|
||||
"""
|
||||
|
||||
LEDGER = lbry.wallet
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.conductor = Conductor(
|
||||
ledger_module=self.LEDGER, manager_module=self.MANAGER, verbosity=self.VERBOSITY
|
||||
)
|
||||
await self.conductor.start_blockchain()
|
||||
self.addCleanup(self.conductor.stop_blockchain)
|
||||
|
||||
await self.conductor.start_spv()
|
||||
|
||||
self.session_manager = self.conductor.spv_node.server.session_mgr
|
||||
self.session_manager.servers['TCP'].sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 64)
|
||||
self.session_manager.servers['TCP'].sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 64)
|
||||
|
||||
self.addCleanup(self.conductor.stop_spv)
|
||||
await self.conductor.start_wallet()
|
||||
self.addCleanup(self.conductor.stop_wallet)
|
||||
|
||||
self.client_session = list(self.session_manager.sessions)[0]
|
||||
self.client_session.transport.set_write_buffer_limits(0, 0)
|
||||
|
||||
self.paused_session = asyncio.Event(loop=self.loop)
|
||||
self.resumed_session = asyncio.Event(loop=self.loop)
|
||||
|
||||
def paused():
|
||||
self.resumed_session.clear()
|
||||
self.paused_session.set()
|
||||
|
||||
def delayed_resume():
|
||||
self.paused_session.clear()
|
||||
|
||||
time.sleep(1)
|
||||
self.resumed_session.set()
|
||||
|
||||
self.client_session.pause_writing = wrap_callback_event(self.client_session.pause_writing, paused)
|
||||
self.client_session.resume_writing = wrap_callback_event(self.client_session.resume_writing, delayed_resume)
|
||||
|
||||
self.blockchain = self.conductor.blockchain_node
|
||||
self.wallet_node = self.conductor.wallet_node
|
||||
self.manager = self.wallet_node.manager
|
||||
self.ledger = self.wallet_node.ledger
|
||||
self.wallet = self.wallet_node.wallet
|
||||
self.account = self.wallet_node.wallet.default_account
|
||||
|
||||
async def test_session_bloat_from_socket_timeout(self):
|
||||
await self.account.ensure_address_gap()
|
||||
|
||||
address1, address2 = await self.account.receiving.get_addresses(limit=2, only_usable=True)
|
||||
sendtxid1 = await self.blockchain.send_to_address(address1, 5)
|
||||
sendtxid2 = await self.blockchain.send_to_address(address2, 5)
|
||||
|
||||
await self.blockchain.generate(1)
|
||||
await asyncio.wait([
|
||||
self.on_transaction_id(sendtxid1),
|
||||
self.on_transaction_id(sendtxid2)
|
||||
])
|
||||
|
||||
self.assertEqual(d2l(await self.account.get_balance()), '10.0')
|
||||
|
||||
channel = Claim()
|
||||
channel_txo = Output.pay_claim_name_pubkey_hash(
|
||||
l2d('1.0'), '@bar', channel, self.account.ledger.address_to_hash160(address1)
|
||||
)
|
||||
channel_txo.generate_channel_private_key()
|
||||
channel_txo.script.generate()
|
||||
channel_tx = await Transaction.create([], [channel_txo], [self.account], self.account)
|
||||
|
||||
stream = Claim()
|
||||
stream.stream.description = "0" * 8000
|
||||
stream_txo = Output.pay_claim_name_pubkey_hash(
|
||||
l2d('1.0'), 'foo', stream, self.account.ledger.address_to_hash160(address1)
|
||||
)
|
||||
stream_tx = await Transaction.create([], [stream_txo], [self.account], self.account)
|
||||
stream_txo.sign(channel_txo)
|
||||
await stream_tx.sign([self.account])
|
||||
self.paused_session.clear()
|
||||
self.resumed_session.clear()
|
||||
|
||||
await self.broadcast(channel_tx)
|
||||
await self.broadcast(stream_tx)
|
||||
await asyncio.wait_for(self.paused_session.wait(), 2)
|
||||
self.assertEqual(1, len(self.session_manager.sessions))
|
||||
|
||||
real_sock = self.client_session.transport._extra.pop('socket')
|
||||
mock_sock = Mock(spec=socket.socket)
|
||||
|
||||
for attr in dir(real_sock):
|
||||
if not attr.startswith('__'):
|
||||
setattr(mock_sock, attr, getattr(real_sock, attr))
|
||||
|
||||
def recv(*a, **kw):
|
||||
raise TimeoutError("[Errno 110] Connection timed out")
|
||||
|
||||
mock_sock.recv = recv
|
||||
self.client_session.transport._sock = mock_sock
|
||||
self.client_session.transport._extra['socket'] = mock_sock
|
||||
self.assertFalse(self.resumed_session.is_set())
|
||||
self.assertFalse(self.session_manager.session_event.is_set())
|
||||
await self.session_manager.session_event.wait()
|
||||
self.assertEqual(0, len(self.session_manager.sessions))
|
||||
await self.conductor.stop_spv()
|
||||
self.conductor.spv_node.session_timeout = 1
|
||||
await self.conductor.start_spv()
|
||||
session = ClientSession(network=None, server=self.ledger.network.client.server, timeout=0.2)
|
||||
await session.create_connection()
|
||||
session.ping_task.cancel()
|
||||
await session.send_request('server.banner', ())
|
||||
self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1)
|
||||
self.assertFalse(session.is_closing())
|
||||
await asyncio.sleep(1.1)
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
await session.send_request('server.banner', ())
|
||||
self.assertTrue(session.is_closing())
|
||||
self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 0)
|
||||
|
|
|
@ -15,22 +15,26 @@ log = logging.getLogger(__name__)
|
|||
|
||||
class ClientSession(BaseClientSession):
|
||||
|
||||
def __init__(self, *args, network, server, **kwargs):
|
||||
def __init__(self, *args, network, server, timeout=30, **kwargs):
|
||||
self.network = network
|
||||
self.server = server
|
||||
super().__init__(*args, **kwargs)
|
||||
self._on_disconnect_controller = StreamController()
|
||||
self.on_disconnected = self._on_disconnect_controller.stream
|
||||
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32
|
||||
self.max_seconds_idle = 60
|
||||
self.timeout = timeout
|
||||
self.max_seconds_idle = timeout * 2
|
||||
self.ping_task = None
|
||||
|
||||
async def send_request(self, method, args=()):
|
||||
try:
|
||||
return await super().send_request(method, args)
|
||||
return await asyncio.wait_for(super().send_request(method, args), timeout=self.timeout)
|
||||
except RPCError as e:
|
||||
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
|
||||
raise e
|
||||
except asyncio.TimeoutError:
|
||||
self.abort()
|
||||
raise
|
||||
|
||||
async def ping_forever(self):
|
||||
# TODO: change to 'ping' on newer protocol (above 1.2)
|
||||
|
@ -209,7 +213,10 @@ class SessionPool:
|
|||
self.ensure_connection(session)
|
||||
for session in self.sessions
|
||||
], return_exceptions=True)
|
||||
await asyncio.wait([asyncio.sleep(3), self._lost_master.wait()], return_when='FIRST_COMPLETED')
|
||||
try:
|
||||
await asyncio.wait_for(self._lost_master.wait(), timeout=3)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
self._lost_master.clear()
|
||||
if not self.sessions:
|
||||
self.sessions.extend(self._dead_servers)
|
||||
|
|
|
@ -195,6 +195,7 @@ class SPVNode:
|
|||
self.server = None
|
||||
self.hostname = 'localhost'
|
||||
self.port = 50001 + 1 # avoid conflict with default daemon
|
||||
self.session_timeout = 600
|
||||
|
||||
async def start(self, blockchain_node: 'BlockchainNode'):
|
||||
self.data_path = tempfile.mkdtemp()
|
||||
|
@ -204,6 +205,7 @@ class SPVNode:
|
|||
'REORG_LIMIT': '100',
|
||||
'HOST': self.hostname,
|
||||
'TCP_PORT': str(self.port),
|
||||
'SESSION_TIMEOUT': str(self.session_timeout),
|
||||
'MAX_QUERY_WORKERS': '0'
|
||||
}
|
||||
# TODO: don't use os.environ
|
||||
|
|
|
@ -98,6 +98,8 @@ class Daemon:
|
|||
return aiohttp.ClientSession(connector=self.connector, connector_owner=False)
|
||||
|
||||
async def _send_data(self, data):
|
||||
if not self.connector:
|
||||
raise asyncio.CancelledError('Tried to send request during shutdown.')
|
||||
async with self.workqueue_semaphore:
|
||||
async with self.client_session() as session:
|
||||
async with session.post(self.current_url(), data=data) as resp:
|
||||
|
|
|
@ -5,38 +5,6 @@ from concurrent.futures.thread import ThreadPoolExecutor
|
|||
|
||||
import torba
|
||||
from torba.server.mempool import MemPool, MemPoolAPI
|
||||
from torba.server.session import SessionManager, SessionBase
|
||||
|
||||
|
||||
CONNECTION_TIMED_OUT = 110
|
||||
NO_ROUTE_TO_HOST = 113
|
||||
|
||||
|
||||
def handle_socket_errors(socket_errors=(CONNECTION_TIMED_OUT, NO_ROUTE_TO_HOST)):
|
||||
err_msgs = tuple((f"[Errno {err_code}]" for err_code in socket_errors))
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
def protocol_exception_handler(loop, context):
|
||||
exception = context['exception']
|
||||
if 'protocol' not in context or 'transport' not in context:
|
||||
raise exception
|
||||
if not isinstance(context['protocol'], SessionBase):
|
||||
raise exception
|
||||
session: SessionBase = context['protocol']
|
||||
transport: asyncio.Transport = context['transport']
|
||||
message = context['message']
|
||||
if message not in ("Fatal read error on socket transport", "Fatal write error on socket transport"):
|
||||
raise exception
|
||||
|
||||
for err_msg in err_msgs:
|
||||
if str(exception).startswith(err_msg):
|
||||
log.debug("caught: '%s' for %s", str(exception), session)
|
||||
transport.abort()
|
||||
transport.close()
|
||||
loop.create_task(session.close(force_after=1))
|
||||
return
|
||||
raise exception
|
||||
return protocol_exception_handler
|
||||
|
||||
|
||||
class Notifications:
|
||||
|
@ -121,7 +89,6 @@ class Server:
|
|||
)
|
||||
|
||||
async def start(self):
|
||||
asyncio.get_event_loop().set_exception_handler(handle_socket_errors())
|
||||
env = self.env
|
||||
min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings()
|
||||
self.log.info(f'software version: {torba.__version__}')
|
||||
|
|
|
@ -255,9 +255,10 @@ class SessionManager:
|
|||
|
||||
async def _clear_stale_sessions(self):
|
||||
"""Cut off sessions that haven't done anything for 10 minutes."""
|
||||
session_timeout = self.env.session_timeout
|
||||
while True:
|
||||
await sleep(60)
|
||||
stale_cutoff = time.time() - self.env.session_timeout
|
||||
await sleep(session_timeout // 10)
|
||||
stale_cutoff = time.time() - session_timeout
|
||||
stale_sessions = [session for session in self.sessions
|
||||
if session.last_recv < stale_cutoff]
|
||||
if stale_sessions:
|
||||
|
@ -267,7 +268,7 @@ class SessionManager:
|
|||
# Give the sockets some time to close gracefully
|
||||
if stale_sessions:
|
||||
await asyncio.wait([
|
||||
session.close() for session in stale_sessions
|
||||
session.close(force_after=session_timeout // 10) for session in stale_sessions
|
||||
])
|
||||
|
||||
# Consolidate small groups
|
||||
|
|
Loading…
Add table
Reference in a new issue