fix reconnect and timeouts

This commit is contained in:
Victor Shyba 2019-07-16 06:23:44 -03:00
parent d97d738880
commit 6997ea608d
6 changed files with 38 additions and 161 deletions

View file

@ -1,132 +1,30 @@
import asyncio
import socket
import time
import logging
from unittest.mock import Mock
from torba.testcase import IntegrationTestCase, Conductor
from torba.client.basenetwork import ClientSession
from torba.orchstr8 import Conductor
from torba.testcase import IntegrationTestCase
import lbry.wallet
from lbry.schema.claim import Claim
from lbry.wallet.transaction import Transaction, Output
from lbry.wallet.dewies import dewies_to_lbc as d2l, lbc_to_dewies as l2d
log = logging.getLogger(__name__)
def wrap_callback_event(fn, callback):
def inner(*a, **kw):
callback()
return fn(*a, **kw)
return inner
class TestSessionBloat(IntegrationTestCase):
"""
ERROR:asyncio:Fatal read error on socket transport
protocol: <lbrynet.wallet.server.session.LBRYElectrumX object at 0x7f7e3bfcaf60>
transport: <_SelectorSocketTransport fd=3236 read=polling write=<idle, bufsize=0>>
Traceback (most recent call last):
File "/usr/lib/python3.7/asyncio/selector_events.py", line 801, in _read_ready__data_received
data = self._sock.recv(self.max_size)
TimeoutError: [Errno 110] Connection timed out
Tests that server cleans up stale connections after session timeout and client times out too.
"""
LEDGER = lbry.wallet
async def asyncSetUp(self):
self.conductor = Conductor(
ledger_module=self.LEDGER, manager_module=self.MANAGER, verbosity=self.VERBOSITY
)
await self.conductor.start_blockchain()
self.addCleanup(self.conductor.stop_blockchain)
await self.conductor.start_spv()
self.session_manager = self.conductor.spv_node.server.session_mgr
self.session_manager.servers['TCP'].sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 64)
self.session_manager.servers['TCP'].sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 64)
self.addCleanup(self.conductor.stop_spv)
await self.conductor.start_wallet()
self.addCleanup(self.conductor.stop_wallet)
self.client_session = list(self.session_manager.sessions)[0]
self.client_session.transport.set_write_buffer_limits(0, 0)
self.paused_session = asyncio.Event(loop=self.loop)
self.resumed_session = asyncio.Event(loop=self.loop)
def paused():
self.resumed_session.clear()
self.paused_session.set()
def delayed_resume():
self.paused_session.clear()
time.sleep(1)
self.resumed_session.set()
self.client_session.pause_writing = wrap_callback_event(self.client_session.pause_writing, paused)
self.client_session.resume_writing = wrap_callback_event(self.client_session.resume_writing, delayed_resume)
self.blockchain = self.conductor.blockchain_node
self.wallet_node = self.conductor.wallet_node
self.manager = self.wallet_node.manager
self.ledger = self.wallet_node.ledger
self.wallet = self.wallet_node.wallet
self.account = self.wallet_node.wallet.default_account
async def test_session_bloat_from_socket_timeout(self):
await self.account.ensure_address_gap()
address1, address2 = await self.account.receiving.get_addresses(limit=2, only_usable=True)
sendtxid1 = await self.blockchain.send_to_address(address1, 5)
sendtxid2 = await self.blockchain.send_to_address(address2, 5)
await self.blockchain.generate(1)
await asyncio.wait([
self.on_transaction_id(sendtxid1),
self.on_transaction_id(sendtxid2)
])
self.assertEqual(d2l(await self.account.get_balance()), '10.0')
channel = Claim()
channel_txo = Output.pay_claim_name_pubkey_hash(
l2d('1.0'), '@bar', channel, self.account.ledger.address_to_hash160(address1)
)
channel_txo.generate_channel_private_key()
channel_txo.script.generate()
channel_tx = await Transaction.create([], [channel_txo], [self.account], self.account)
stream = Claim()
stream.stream.description = "0" * 8000
stream_txo = Output.pay_claim_name_pubkey_hash(
l2d('1.0'), 'foo', stream, self.account.ledger.address_to_hash160(address1)
)
stream_tx = await Transaction.create([], [stream_txo], [self.account], self.account)
stream_txo.sign(channel_txo)
await stream_tx.sign([self.account])
self.paused_session.clear()
self.resumed_session.clear()
await self.broadcast(channel_tx)
await self.broadcast(stream_tx)
await asyncio.wait_for(self.paused_session.wait(), 2)
self.assertEqual(1, len(self.session_manager.sessions))
real_sock = self.client_session.transport._extra.pop('socket')
mock_sock = Mock(spec=socket.socket)
for attr in dir(real_sock):
if not attr.startswith('__'):
setattr(mock_sock, attr, getattr(real_sock, attr))
def recv(*a, **kw):
raise TimeoutError("[Errno 110] Connection timed out")
mock_sock.recv = recv
self.client_session.transport._sock = mock_sock
self.client_session.transport._extra['socket'] = mock_sock
self.assertFalse(self.resumed_session.is_set())
self.assertFalse(self.session_manager.session_event.is_set())
await self.session_manager.session_event.wait()
self.assertEqual(0, len(self.session_manager.sessions))
await self.conductor.stop_spv()
self.conductor.spv_node.session_timeout = 1
await self.conductor.start_spv()
session = ClientSession(network=None, server=self.ledger.network.client.server, timeout=0.2)
await session.create_connection()
session.ping_task.cancel()
await session.send_request('server.banner', ())
self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1)
self.assertFalse(session.is_closing())
await asyncio.sleep(1.1)
with self.assertRaises(asyncio.TimeoutError):
await session.send_request('server.banner', ())
self.assertTrue(session.is_closing())
self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 0)

View file

@ -15,22 +15,26 @@ log = logging.getLogger(__name__)
class ClientSession(BaseClientSession):
def __init__(self, *args, network, server, **kwargs):
def __init__(self, *args, network, server, timeout=30, **kwargs):
self.network = network
self.server = server
super().__init__(*args, **kwargs)
self._on_disconnect_controller = StreamController()
self.on_disconnected = self._on_disconnect_controller.stream
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32
self.max_seconds_idle = 60
self.timeout = timeout
self.max_seconds_idle = timeout * 2
self.ping_task = None
async def send_request(self, method, args=()):
try:
return await super().send_request(method, args)
return await asyncio.wait_for(super().send_request(method, args), timeout=self.timeout)
except RPCError as e:
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
raise e
except asyncio.TimeoutError:
self.abort()
raise
async def ping_forever(self):
# TODO: change to 'ping' on newer protocol (above 1.2)
@ -209,7 +213,10 @@ class SessionPool:
self.ensure_connection(session)
for session in self.sessions
], return_exceptions=True)
await asyncio.wait([asyncio.sleep(3), self._lost_master.wait()], return_when='FIRST_COMPLETED')
try:
await asyncio.wait_for(self._lost_master.wait(), timeout=3)
except asyncio.TimeoutError:
pass
self._lost_master.clear()
if not self.sessions:
self.sessions.extend(self._dead_servers)

View file

@ -195,6 +195,7 @@ class SPVNode:
self.server = None
self.hostname = 'localhost'
self.port = 50001 + 1 # avoid conflict with default daemon
self.session_timeout = 600
async def start(self, blockchain_node: 'BlockchainNode'):
self.data_path = tempfile.mkdtemp()
@ -204,6 +205,7 @@ class SPVNode:
'REORG_LIMIT': '100',
'HOST': self.hostname,
'TCP_PORT': str(self.port),
'SESSION_TIMEOUT': str(self.session_timeout),
'MAX_QUERY_WORKERS': '0'
}
# TODO: don't use os.environ

View file

@ -98,6 +98,8 @@ class Daemon:
return aiohttp.ClientSession(connector=self.connector, connector_owner=False)
async def _send_data(self, data):
if not self.connector:
raise asyncio.CancelledError('Tried to send request during shutdown.')
async with self.workqueue_semaphore:
async with self.client_session() as session:
async with session.post(self.current_url(), data=data) as resp:

View file

@ -5,38 +5,6 @@ from concurrent.futures.thread import ThreadPoolExecutor
import torba
from torba.server.mempool import MemPool, MemPoolAPI
from torba.server.session import SessionManager, SessionBase
CONNECTION_TIMED_OUT = 110
NO_ROUTE_TO_HOST = 113
def handle_socket_errors(socket_errors=(CONNECTION_TIMED_OUT, NO_ROUTE_TO_HOST)):
err_msgs = tuple((f"[Errno {err_code}]" for err_code in socket_errors))
log = logging.getLogger(__name__)
def protocol_exception_handler(loop, context):
exception = context['exception']
if 'protocol' not in context or 'transport' not in context:
raise exception
if not isinstance(context['protocol'], SessionBase):
raise exception
session: SessionBase = context['protocol']
transport: asyncio.Transport = context['transport']
message = context['message']
if message not in ("Fatal read error on socket transport", "Fatal write error on socket transport"):
raise exception
for err_msg in err_msgs:
if str(exception).startswith(err_msg):
log.debug("caught: '%s' for %s", str(exception), session)
transport.abort()
transport.close()
loop.create_task(session.close(force_after=1))
return
raise exception
return protocol_exception_handler
class Notifications:
@ -121,7 +89,6 @@ class Server:
)
async def start(self):
asyncio.get_event_loop().set_exception_handler(handle_socket_errors())
env = self.env
min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings()
self.log.info(f'software version: {torba.__version__}')

View file

@ -255,9 +255,10 @@ class SessionManager:
async def _clear_stale_sessions(self):
"""Cut off sessions that haven't done anything for 10 minutes."""
session_timeout = self.env.session_timeout
while True:
await sleep(60)
stale_cutoff = time.time() - self.env.session_timeout
await sleep(session_timeout // 10)
stale_cutoff = time.time() - session_timeout
stale_sessions = [session for session in self.sessions
if session.last_recv < stale_cutoff]
if stale_sessions:
@ -267,7 +268,7 @@ class SessionManager:
# Give the sockets some time to close gracefully
if stale_sessions:
await asyncio.wait([
session.close() for session in stale_sessions
session.close(force_after=session_timeout // 10) for session in stale_sessions
])
# Consolidate small groups