Merge pull request #2371 from lbryio/basenetwork_refactor
refactor basenetwork so each session takes care of itself
Commit ff73418fc1
7 changed files with 164 additions and 139 deletions
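Before the file-by-file diff, it helps to have the overall shape in mind: each ClientSession now keeps itself connected from its own ensure_session() task, the SessionPool only owns those tasks and exposes the fastest available session, and BaseNetwork waits on the pool instead of probing servers itself. The sketch below is illustrative only; FakeSession, FakePool, the simulated dial/ping and the hostnames are invented stand-ins rather than torba code, but they follow the same keep-alive loop, backoff, urgent-reconnect event and "wait for the fastest session" flow that the diff introduces.

import asyncio
import random

# Rough, self-contained sketch of the pattern in this PR (invented names,
# simulated connect/ping): every session owns its own keep-alive loop with
# exponential backoff, and the pool only tracks the tasks.
class FakeSession:
    def __init__(self, server, on_connect_callback):
        self.server = server
        self.response_time = None                    # None means "not usable yet"
        self.trigger_urgent_reconnect = asyncio.Event()
        self._on_connect_cb = on_connect_callback
        self._connected = False

    @property
    def available(self):
        return self._connected and self.response_time is not None

    async def ensure_session(self):
        # The session reconnects and pings itself; nobody else manages it.
        retry_delay = 0.1
        while True:
            try:
                if not self._connected:
                    await asyncio.sleep(0.05)                       # pretend to dial the server
                    self._connected = True
                self.response_time = random.uniform(0.01, 0.5)      # pretend a ping round trip
                self._on_connect_cb()                               # tell the pool we are usable
                retry_delay = 0.1
            except OSError:                                         # where real network errors would land
                self._connected = False
                self.response_time = None
                retry_delay = min(60, retry_delay * 2)              # exponential backoff
            try:                                                    # sleep, unless a reconnect is urgent
                await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), retry_delay)
            except asyncio.TimeoutError:
                pass
            finally:
                self.trigger_urgent_reconnect.clear()


class FakePool:
    def __init__(self, servers):
        self.new_connection_event = asyncio.Event()
        self.sessions = {
            FakeSession(server, self.new_connection_event.set): None
            for server in servers
        }

    def start(self):
        # The pool only owns the tasks; the sessions take care of themselves.
        for session in self.sessions:
            self.sessions[session] = asyncio.create_task(session.ensure_session())

    @property
    def fastest_session(self):
        usable = [s for s in self.sessions if s.available]
        return min(usable, key=lambda s: s.response_time) if usable else None

    async def wait_for_fastest_session(self):
        while True:
            self.new_connection_event.clear()
            fastest = self.fastest_session
            if fastest is not None:
                return fastest
            await self.new_connection_event.wait()


async def main():
    pool = FakePool(['spv1.example.invalid:50001', 'spv2.example.invalid:50001'])
    pool.start()
    best = await pool.wait_for_fastest_session()
    print('fastest server:', best.server, 'response time:', best.response_time)
    for task in pool.sessions.values():
        task.cancel()
    await asyncio.gather(*pool.sessions.values(), return_exceptions=True)


if __name__ == '__main__':
    asyncio.run(main())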
@@ -22,7 +22,6 @@ class TestSessionBloat(IntegrationTestCase):
         await self.conductor.start_spv()
         session = ClientSession(network=None, server=self.ledger.network.client.server, timeout=0.2)
         await session.create_connection()
-        session.ping_task.cancel()
         await session.send_request('server.banner', ())
         self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1)
         self.assertFalse(session.is_closing())
@@ -22,31 +22,40 @@ class ReconnectTests(IntegrationTestCase):
 
     async def test_connection_drop_still_receives_events_after_reconnected(self):
         address1 = await self.account.receiving.get_or_create_usable_address()
         # disconnect and send a new tx, should reconnect and get it
         self.ledger.network.client.connection_lost(Exception())
         self.assertFalse(self.ledger.network.is_connected)
         sendtxid = await self.blockchain.send_to_address(address1, 1.1337)
-        await self.on_transaction_id(sendtxid)  # mempool
+        await asyncio.wait_for(self.on_transaction_id(sendtxid), 1.0)  # mempool
         await self.blockchain.generate(1)
         await self.on_transaction_id(sendtxid)  # confirmed
+        self.assertLess(self.ledger.network.client.response_time, 1)  # response time properly set lower, we are fine
+
         await self.assertBalance(self.account, '1.1337')
         # is it real? are we rich!? let me see this tx...
         d = self.ledger.network.get_transaction(sendtxid)
         # what's that smoke on my ethernet cable? oh no!
         self.ledger.network.client.connection_lost(Exception())
-        with self.assertRaises(asyncio.CancelledError):
+        with self.assertRaises(asyncio.TimeoutError):
             await d
+        self.assertIsNone(self.ledger.network.client.response_time)  # response time unknown as it failed
         # rich but offline? no way, no water, let's retry
         with self.assertRaisesRegex(ConnectionError, 'connection is not available'):
             await self.ledger.network.get_transaction(sendtxid)
         # * goes to pick some water outside... * time passes by and another donation comes in
         sendtxid = await self.blockchain.send_to_address(address1, 42)
         await self.blockchain.generate(1)
+        # (this is just so the test doesnt hang forever if it doesnt reconnect)
+        if not self.ledger.network.is_connected:
+            await asyncio.wait_for(self.ledger.network.on_connected.first, timeout=1.0)
         # omg, the burned cable still works! torba is fire proof!
         await self.ledger.network.get_transaction(sendtxid)
 
     async def test_timeout_then_reconnect(self):
         # tests that it connects back after some failed attempts
         await self.conductor.spv_node.stop()
         self.assertFalse(self.ledger.network.is_connected)
         await asyncio.sleep(0.2)  # let it retry and fail once
         await self.conductor.spv_node.start(self.conductor.blockchain_node)
         await self.ledger.network.on_connected.first
         self.assertTrue(self.ledger.network.is_connected)
@@ -79,9 +88,9 @@ class ServerPickingTestCase(AsyncioTestCase):
                 await self._make_bad_server(),
                 ('localhost', 1),
                 ('example.that.doesnt.resolve', 9000),
-                await self._make_fake_server(latency=1.2, port=1340),
-                await self._make_fake_server(latency=0.5, port=1337),
-                await self._make_fake_server(latency=0.7, port=1339),
+                await self._make_fake_server(latency=1.0, port=1340),
+                await self._make_fake_server(latency=0.1, port=1337),
+                await self._make_fake_server(latency=0.4, port=1339),
             ],
             'connect_timeout': 3
         })
@@ -89,9 +98,10 @@ class ServerPickingTestCase(AsyncioTestCase):
         network = BaseNetwork(ledger)
         self.addCleanup(network.stop)
         asyncio.ensure_future(network.start())
-        await asyncio.wait_for(network.on_connected.first, timeout=3)
+        await asyncio.wait_for(network.on_connected.first, timeout=1)
         self.assertTrue(network.is_connected)
         self.assertEqual(network.client.server, ('127.0.0.1', 1337))
-        # ensure we are connected to all of them
-        self.assertTrue(all([not session.is_closing() for session in network.session_pool.sessions]))
-        self.assertEqual(len(network.session_pool.sessions), 3)
+        self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions]))
+        # ensure we are connected to all of them after a while
+        await asyncio.sleep(1)
+        self.assertEqual(len(network.session_pool.available_sessions), 3)
torba/tests/client_tests/unit/test_stream_controller.py  (new file, 20 lines)
@@ -0,0 +1,20 @@
+from torba.stream import StreamController
+from torba.testcase import AsyncioTestCase
+
+
+class StreamControllerTestCase(AsyncioTestCase):
+    def test_non_unique_events(self):
+        events = []
+        controller = StreamController()
+        controller.stream.listen(on_data=events.append)
+        controller.add("yo")
+        controller.add("yo")
+        self.assertEqual(events, ["yo", "yo"])
+
+    def test_unique_events(self):
+        events = []
+        controller = StreamController(merge_repeated_events=True)
+        controller.stream.listen(on_data=events.append)
+        controller.add("yo")
+        controller.add("yo")
+        self.assertEqual(events, ["yo"])
@@ -1,9 +1,8 @@
 import logging
 import asyncio
-from asyncio import CancelledError
-from time import time
-from typing import List
-import socket
+from operator import itemgetter
+from typing import Dict, Optional
+from time import time, perf_counter
 
 from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
 
@@ -15,7 +14,7 @@ log = logging.getLogger(__name__)
 
 class ClientSession(BaseClientSession):
 
-    def __init__(self, *args, network, server, timeout=30, **kwargs):
+    def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
         self.network = network
         self.server = server
         super().__init__(*args, **kwargs)
@@ -24,61 +23,88 @@ class ClientSession(BaseClientSession):
         self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32
         self.timeout = timeout
         self.max_seconds_idle = timeout * 2
-        self.ping_task = None
+        self.response_time: Optional[float] = None
+        self._on_connect_cb = on_connect_callback or (lambda: None)
+        self.trigger_urgent_reconnect = asyncio.Event()
+
+    @property
+    def available(self):
+        return not self.is_closing() and self._can_send.is_set() and self.response_time is not None
 
     async def send_request(self, method, args=()):
         try:
-            return await asyncio.wait_for(super().send_request(method, args), timeout=self.timeout)
+            start = perf_counter()
+            result = await asyncio.wait_for(
+                super().send_request(method, args), timeout=self.timeout
+            )
+            self.response_time = perf_counter() - start
+            return result
         except RPCError as e:
             log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
             raise e
-        except asyncio.TimeoutError:
-            self.abort()
+        except TimeoutError:
+            self.response_time = None
+            raise
 
-    async def ping_forever(self):
+    async def ensure_session(self):
+        # Handles reconnecting and maintaining a session alive
         # TODO: change to 'ping' on newer protocol (above 1.2)
-        while not self.is_closing():
-            if (time() - self.last_send) > self.max_seconds_idle:
-                try:
+        retry_delay = default_delay = 0.1
+        while True:
+            try:
+                if self.is_closing():
+                    await self.create_connection(self.timeout)
+                    await self.ensure_server_version()
+                    self._on_connect_cb()
+                if (time() - self.last_send) > self.max_seconds_idle or self.response_time is None:
                     await self.send_request('server.banner')
-                except:
-                    self.abort()
-                    raise
-            await asyncio.sleep(self.max_seconds_idle//3)
+                retry_delay = default_delay
+            except (asyncio.TimeoutError, OSError):
+                await self.close()
+                retry_delay = min(60, retry_delay * 2)
+                log.warning("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
+            try:
+                await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
+            except asyncio.TimeoutError:
+                pass
+            finally:
+                self.trigger_urgent_reconnect.clear()
+
+    def ensure_server_version(self, required='1.2'):
+        return self.send_request('server.version', [__version__, required])
 
     async def create_connection(self, timeout=6):
         connector = Connector(lambda: self, *self.server)
         await asyncio.wait_for(connector.create_connection(), timeout=timeout)
-        self.ping_task = asyncio.create_task(self.ping_forever())
 
     async def handle_request(self, request):
        controller = self.network.subscription_controllers[request.method]
         controller.add(request.args)
 
     def connection_lost(self, exc):
         log.debug("Connection lost: %s:%d", *self.server)
         super().connection_lost(exc)
+        self.response_time = None
         self._on_disconnect_controller.add(True)
-        if self.ping_task:
-            self.ping_task.cancel()
 
 
 class BaseNetwork:
 
     def __init__(self, ledger):
+        self.switch_event = asyncio.Event()
         self.config = ledger.config
-        self.client: ClientSession = None
-        self.session_pool: SessionPool = None
+        self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
+        self.client: Optional[ClientSession] = None
         self.running = False
         self.remote_height: int = 0
 
         self._on_connected_controller = StreamController()
         self.on_connected = self._on_connected_controller.stream
 
-        self._on_header_controller = StreamController()
+        self._on_header_controller = StreamController(merge_repeated_events=True)
         self.on_header = self._on_header_controller.stream
 
-        self._on_status_controller = StreamController()
+        self._on_status_controller = StreamController(merge_repeated_events=True)
         self.on_status = self._on_status_controller.stream
 
         self.subscription_controllers = {
@@ -88,30 +114,22 @@ class BaseNetwork:
 
     async def start(self):
         self.running = True
-        connect_timeout = self.config.get('connect_timeout', 6)
-        self.session_pool = SessionPool(network=self, timeout=connect_timeout)
         self.session_pool.start(self.config['default_servers'])
         self.on_header.listen(self._update_remote_height)
-        while True:
+        while self.running:
             try:
-                self.client = await self.pick_fastest_session()
-                if self.is_connected:
-                    await self.ensure_server_version()
-                    self._update_remote_height((await self.subscribe_headers(),))
-                    log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
-                    self._on_connected_controller.add(True)
-                    await self.client.on_disconnected.first
-            except CancelledError:
-                self.running = False
+                self.client = await self.session_pool.wait_for_fastest_session()
+                self._update_remote_height((await self.subscribe_headers(),))
+                log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
+                self._on_connected_controller.add(True)
+                self.client.on_disconnected.listen(lambda _: self.switch_event.set())
+                await self.switch_event.wait()
+                self.switch_event.clear()
+            except asyncio.CancelledError:
+                await self.stop()
+                raise
             except asyncio.TimeoutError:
-                log.warning("Timed out while trying to find a server!")
-            except Exception:  # pylint: disable=broad-except
-                log.exception("Exception while trying to find a server!")
-            if not self.running:
-                return
-            elif self.client:
-                await self.client.close()
-                self.client.connection.cancel_pending_requests()
+                pass
 
     async def stop(self):
         self.running = False
@@ -124,35 +142,21 @@ class BaseNetwork:
 
     @property
     def is_connected(self):
-        return self.client is not None and not self.client.is_closing()
+        return self.client and not self.client.is_closing()
 
     def rpc(self, list_or_method, args):
+        fastest = self.session_pool.fastest_session
+        if fastest is not None and self.client != fastest:
+            self.switch_event.set()
         if self.is_connected:
             return self.client.send_request(list_or_method, args)
         else:
             self.session_pool.trigger_nodelay_connect()
             raise ConnectionError("Attempting to send rpc request when connection is not available.")
 
-    async def pick_fastest_session(self):
-        sessions = await self.session_pool.get_online_sessions()
-        done, pending = await asyncio.wait([
-            self.probe_session(session)
-            for session in sessions if not session.is_closing()
-        ], return_when='FIRST_COMPLETED')
-        for task in pending:
-            task.cancel()
-        for session in done:
-            return await session
-
-    async def probe_session(self, session: ClientSession):
-        await session.send_request('server.banner')
-        return session
-
     def _update_remote_height(self, header_args):
         self.remote_height = header_args[0]["height"]
 
-    def ensure_server_version(self, required='1.2'):
-        return self.rpc('server.version', [__version__, required])
-
     def broadcast(self, raw_transaction):
         return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
@@ -182,73 +186,57 @@ class SessionPool:
 
     def __init__(self, network: BaseNetwork, timeout: float):
         self.network = network
-        self.sessions: List[ClientSession] = []
-        self._dead_servers: List[ClientSession] = []
-        self.maintain_connections_task = None
+        self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
         self.timeout = timeout
-        # triggered when the master server is out, to speed up reconnect
-        self._lost_master = asyncio.Event()
+        self.new_connection_event = asyncio.Event()
 
     @property
     def online(self):
-        for session in self.sessions:
-            if not session.is_closing():
-                return True
-        return False
+        return any(not session.is_closing() for session in self.sessions)
 
+    @property
+    def available_sessions(self):
+        return [session for session in self.sessions if session.available]
+
+    @property
+    def fastest_session(self):
+        if not self.available_sessions:
+            return None
+        return min(
+            [(session.response_time, session) for session in self.available_sessions], key=itemgetter(0)
+        )[1]
+
     def start(self, default_servers):
-        self.sessions = [
-            ClientSession(network=self.network, server=server)
-            for server in default_servers
-        ]
-        self.maintain_connections_task = asyncio.create_task(self.ensure_connections())
+        callback = self.new_connection_event.set
+        self.sessions = {
+            ClientSession(
+                network=self.network, server=server, on_connect_callback=callback
+            ): None for server in default_servers
+        }
+        self.ensure_connections()
 
     def stop(self):
-        if self.maintain_connections_task:
-            self.maintain_connections_task.cancel()
+        for session, task in self.sessions.items():
+            task.cancel()
+            session.abort()
+        self.sessions.clear()
 
+    def ensure_connections(self):
+        for session, task in list(self.sessions.items()):
+            if not task or task.done():
+                task = asyncio.create_task(session.ensure_session())
+                task.add_done_callback(lambda _: self.ensure_connections())
+                self.sessions[session] = task
+
     def trigger_nodelay_connect(self):
         # used when other parts of the system sees we might have internet back
         # bypasses the retry interval
         for session in self.sessions:
-            if not session.is_closing():
-                session.abort()
-        self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None
+            session.trigger_urgent_reconnect.set()
 
-    async def ensure_connections(self):
-        while True:
-            await asyncio.gather(*[
-                self.ensure_connection(session)
-                for session in self.sessions
-            ], return_exceptions=True)
-            try:
-                await asyncio.wait_for(self._lost_master.wait(), timeout=3)
-            except asyncio.TimeoutError:
-                pass
-            self._lost_master.clear()
-            if not self.sessions:
-                self.sessions.extend(self._dead_servers)
-                self._dead_servers = []
-
-    async def ensure_connection(self, session):
-        self._dead_servers.append(session)
-        self.sessions.remove(session)
-        try:
-            if session.is_closing():
-                await session.create_connection(self.timeout)
-            await asyncio.wait_for(session.send_request('server.banner'), timeout=self.timeout)
-            self.sessions.append(session)
-            self._dead_servers.remove(session)
-        except asyncio.TimeoutError:
-            log.warning("Timeout connecting to %s:%d", *session.server)
-        except asyncio.CancelledError:  # pylint: disable=try-except-raise
-            raise
-        except socket.gaierror:
-            log.warning("Could not resolve IP for %s", session.server[0])
-        except Exception as err:  # pylint: disable=broad-except
-            if 'Connect call failed' in str(err):
-                log.warning("Could not connect to %s:%d", *session.server)
-            else:
-                log.exception("Connecting to %s:%d raised an exception:", *session.server)
-
-    async def get_online_sessions(self):
-        while not self.online:
-            self._lost_master.set()
-            await asyncio.sleep(0.5)
-        return self.sessions
+    async def wait_for_fastest_session(self):
+        while not self.fastest_session:
+            self.trigger_nodelay_connect()
+            self.new_connection_event.clear()
+            await self.new_connection_event.wait()
+        return self.fastest_session
@@ -33,6 +33,7 @@ __all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
 import itertools
 import json
 import typing
+import asyncio
 from functools import partial
 from numbers import Number
 
@@ -745,9 +746,10 @@ class JSONRPCConnection(object):
             self._protocol = item
         return self.receive_message(message)
 
-    def cancel_pending_requests(self):
-        """Cancel all pending requests."""
-        exception = CancelledError()
+    def time_out_pending_requests(self):
+        """Times out all pending requests."""
+        # this used to be CancelledError, but thats confusing as in are we closing the whole sdk or failing?
+        exception = asyncio.TimeoutError()
         for request, event in self._requests.values():
             event.result = exception
             event.set()
@@ -456,7 +456,7 @@ class RPCSession(SessionBase):
 
     def connection_lost(self, exc):
         # Cancel pending requests and message processing
-        self.connection.cancel_pending_requests()
+        self.connection.time_out_pending_requests()
         super().connection_lost(exc)
 
     # External API
 
@@ -473,6 +473,8 @@ class RPCSession(SessionBase):
 
     async def send_request(self, method, args=()):
         """Send an RPC request over the network."""
+        if self.is_closing():
+            raise asyncio.TimeoutError("Trying to send request on a recently dropped connection.")
         message, event = self.connection.send_request(Request(method, args))
         await self._send_message(message)
         await event.wait()
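The two rpc changes above alter what in-flight callers observe when the transport drops: pending requests are resolved with asyncio.TimeoutError instead of being cancelled, and new requests on a closed connection fail the same way, which is what the updated ReconnectTests assert. Below is a minimal, standalone sketch of that technique; TinyConnection and its methods are invented for illustration, not the torba rpc API.

import asyncio

class TinyConnection:
    """Toy stand-in for a JSON-RPC connection tracking in-flight requests."""

    def __init__(self):
        self._pending = {}
        self._next_id = 0

    def send_request(self):
        # Each request waits on a future until a response (or failure) arrives.
        fut = asyncio.get_running_loop().create_future()
        self._pending[self._next_id] = fut
        self._next_id += 1
        return fut

    def time_out_pending_requests(self):
        # On connection loss, fail every pending request with TimeoutError so
        # callers see a retryable error rather than a cancellation.
        for fut in self._pending.values():
            if not fut.done():
                fut.set_exception(asyncio.TimeoutError())
        self._pending.clear()


async def main():
    conn = TinyConnection()
    request = conn.send_request()
    conn.time_out_pending_requests()      # simulate connection_lost()
    try:
        await request
    except asyncio.TimeoutError:
        print("request timed out; safe to retry on another session")


if __name__ == '__main__':
    asyncio.run(main())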
@@ -45,10 +45,12 @@ class BroadcastSubscription:
 
 class StreamController:
 
-    def __init__(self):
+    def __init__(self, merge_repeated_events=False):
         self.stream = Stream(self)
         self._first_subscription = None
         self._last_subscription = None
+        self._last_event = None
+        self._merge_repeated = merge_repeated_events
 
     @property
     def has_listener(self):
@@ -76,8 +78,10 @@ class StreamController:
         return f
 
     def add(self, event):
+        skip = self._merge_repeated and event == self._last_event
+        self._last_event = event
         return self._notify_and_ensure_future(
-            lambda subscription: subscription._add(event)
+            lambda subscription: None if skip else subscription._add(event)
        )
 
     def add_error(self, exception):
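This is the feature behind BaseNetwork switching its header and status controllers to merge_repeated_events=True: an event equal to the previous one is dropped, presumably so that repeated identical notifications (for example the same header or status relayed more than once) only reach subscribers a single time. A simplified, standalone illustration of the same idea follows; DedupController is an invented stand-in, not the torba StreamController.

class DedupController:
    """Toy event controller that can drop an event equal to the previous one."""

    def __init__(self, merge_repeated_events=False):
        self._merge_repeated = merge_repeated_events
        self._last_event = None
        self._listeners = []

    def listen(self, on_data):
        self._listeners.append(on_data)

    def add(self, event):
        skip = self._merge_repeated and event == self._last_event
        self._last_event = event
        if not skip:
            for on_data in self._listeners:
                on_data(event)


events = []
controller = DedupController(merge_repeated_events=True)
controller.listen(events.append)
controller.add({"height": 100})
controller.add({"height": 100})   # duplicate notification is merged away
controller.add({"height": 101})
assert events == [{"height": 100}, {"height": 101}]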