Merge pull request #2371 from lbryio/basenetwork_refactor

refactor basenetwork so each session takes care of itself
This commit is contained in:
Alex Grin 2019-08-16 10:19:53 -04:00 committed by GitHub
commit ff73418fc1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 164 additions and 139 deletions

View file

@ -22,7 +22,6 @@ class TestSessionBloat(IntegrationTestCase):
await self.conductor.start_spv()
session = ClientSession(network=None, server=self.ledger.network.client.server, timeout=0.2)
await session.create_connection()
session.ping_task.cancel()
await session.send_request('server.banner', ())
self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1)
self.assertFalse(session.is_closing())

View file

@ -22,31 +22,40 @@ class ReconnectTests(IntegrationTestCase):
async def test_connection_drop_still_receives_events_after_reconnected(self):
address1 = await self.account.receiving.get_or_create_usable_address()
# disconnect and send a new tx, should reconnect and get it
self.ledger.network.client.connection_lost(Exception())
self.assertFalse(self.ledger.network.is_connected)
sendtxid = await self.blockchain.send_to_address(address1, 1.1337)
await self.on_transaction_id(sendtxid) # mempool
await asyncio.wait_for(self.on_transaction_id(sendtxid), 1.0) # mempool
await self.blockchain.generate(1)
await self.on_transaction_id(sendtxid) # confirmed
self.assertLess(self.ledger.network.client.response_time, 1) # response time properly set lower, we are fine
await self.assertBalance(self.account, '1.1337')
# is it real? are we rich!? let me see this tx...
d = self.ledger.network.get_transaction(sendtxid)
# what's that smoke on my ethernet cable? oh no!
self.ledger.network.client.connection_lost(Exception())
with self.assertRaises(asyncio.CancelledError):
with self.assertRaises(asyncio.TimeoutError):
await d
self.assertIsNone(self.ledger.network.client.response_time) # response time unknown as it failed
# rich but offline? no way, no water, let's retry
with self.assertRaisesRegex(ConnectionError, 'connection is not available'):
await self.ledger.network.get_transaction(sendtxid)
# * goes to pick some water outside... * time passes by and another donation comes in
sendtxid = await self.blockchain.send_to_address(address1, 42)
await self.blockchain.generate(1)
# (this is just so the test doesn't hang forever if it doesn't reconnect)
if not self.ledger.network.is_connected:
await asyncio.wait_for(self.ledger.network.on_connected.first, timeout=1.0)
# omg, the burned cable still works! torba is fire proof!
await self.ledger.network.get_transaction(sendtxid)
async def test_timeout_then_reconnect(self):
# tests that it connects back after some failed attempts
await self.conductor.spv_node.stop()
self.assertFalse(self.ledger.network.is_connected)
await asyncio.sleep(0.2) # let it retry and fail once
await self.conductor.spv_node.start(self.conductor.blockchain_node)
await self.ledger.network.on_connected.first
self.assertTrue(self.ledger.network.is_connected)
@ -79,9 +88,9 @@ class ServerPickingTestCase(AsyncioTestCase):
await self._make_bad_server(),
('localhost', 1),
('example.that.doesnt.resolve', 9000),
await self._make_fake_server(latency=1.2, port=1340),
await self._make_fake_server(latency=0.5, port=1337),
await self._make_fake_server(latency=0.7, port=1339),
await self._make_fake_server(latency=1.0, port=1340),
await self._make_fake_server(latency=0.1, port=1337),
await self._make_fake_server(latency=0.4, port=1339),
],
'connect_timeout': 3
})
@ -89,9 +98,10 @@ class ServerPickingTestCase(AsyncioTestCase):
network = BaseNetwork(ledger)
self.addCleanup(network.stop)
asyncio.ensure_future(network.start())
await asyncio.wait_for(network.on_connected.first, timeout=3)
await asyncio.wait_for(network.on_connected.first, timeout=1)
self.assertTrue(network.is_connected)
self.assertEqual(network.client.server, ('127.0.0.1', 1337))
# ensure we are connected to all of them
self.assertTrue(all([not session.is_closing() for session in network.session_pool.sessions]))
self.assertEqual(len(network.session_pool.sessions), 3)
self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions]))
# ensure we are connected to all of them after a while
await asyncio.sleep(1)
self.assertEqual(len(network.session_pool.available_sessions), 3)

View file

@ -0,0 +1,20 @@
from torba.stream import StreamController
from torba.testcase import AsyncioTestCase
class StreamControllerTestCase(AsyncioTestCase):
    """Exercises StreamController event delivery with and without merging."""

    def test_non_unique_events(self):
        # By default every event added is delivered, repeats included.
        received = []
        stream_controller = StreamController()
        stream_controller.stream.listen(on_data=received.append)
        for _ in range(2):
            stream_controller.add("yo")
        self.assertEqual(received, ["yo", "yo"])

    def test_unique_events(self):
        # With merge_repeated_events=True, a consecutive duplicate is dropped.
        received = []
        stream_controller = StreamController(merge_repeated_events=True)
        stream_controller.stream.listen(on_data=received.append)
        for _ in range(2):
            stream_controller.add("yo")
        self.assertEqual(received, ["yo"])

View file

@ -1,9 +1,8 @@
import logging
import asyncio
from asyncio import CancelledError
from time import time
from typing import List
import socket
from operator import itemgetter
from typing import Dict, Optional
from time import time, perf_counter
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
@ -15,7 +14,7 @@ log = logging.getLogger(__name__)
class ClientSession(BaseClientSession):
def __init__(self, *args, network, server, timeout=30, **kwargs):
def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
self.network = network
self.server = server
super().__init__(*args, **kwargs)
@ -24,61 +23,88 @@ class ClientSession(BaseClientSession):
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32
self.timeout = timeout
self.max_seconds_idle = timeout * 2
self.ping_task = None
self.response_time: Optional[float] = None
self._on_connect_cb = on_connect_callback or (lambda: None)
self.trigger_urgent_reconnect = asyncio.Event()
@property
def available(self):
return not self.is_closing() and self._can_send.is_set() and self.response_time is not None
async def send_request(self, method, args=()):
try:
return await asyncio.wait_for(super().send_request(method, args), timeout=self.timeout)
start = perf_counter()
result = await asyncio.wait_for(
super().send_request(method, args), timeout=self.timeout
)
self.response_time = perf_counter() - start
return result
except RPCError as e:
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
raise e
except asyncio.TimeoutError:
self.abort()
except TimeoutError:
self.response_time = None
raise
async def ping_forever(self):
async def ensure_session(self):
# Handles reconnecting and maintaining a session alive
# TODO: change to 'ping' on newer protocol (above 1.2)
while not self.is_closing():
if (time() - self.last_send) > self.max_seconds_idle:
try:
retry_delay = default_delay = 0.1
while True:
try:
if self.is_closing():
await self.create_connection(self.timeout)
await self.ensure_server_version()
self._on_connect_cb()
if (time() - self.last_send) > self.max_seconds_idle or self.response_time is None:
await self.send_request('server.banner')
except:
self.abort()
raise
await asyncio.sleep(self.max_seconds_idle//3)
retry_delay = default_delay
except (asyncio.TimeoutError, OSError):
await self.close()
retry_delay = min(60, retry_delay * 2)
log.warning("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
try:
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
except asyncio.TimeoutError:
pass
finally:
self.trigger_urgent_reconnect.clear()
def ensure_server_version(self, required='1.2'):
return self.send_request('server.version', [__version__, required])
async def create_connection(self, timeout=6):
connector = Connector(lambda: self, *self.server)
await asyncio.wait_for(connector.create_connection(), timeout=timeout)
self.ping_task = asyncio.create_task(self.ping_forever())
async def handle_request(self, request):
controller = self.network.subscription_controllers[request.method]
controller.add(request.args)
def connection_lost(self, exc):
log.debug("Connection lost: %s:%d", *self.server)
super().connection_lost(exc)
self.response_time = None
self._on_disconnect_controller.add(True)
if self.ping_task:
self.ping_task.cancel()
class BaseNetwork:
def __init__(self, ledger):
self.switch_event = asyncio.Event()
self.config = ledger.config
self.client: ClientSession = None
self.session_pool: SessionPool = None
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
self.client: Optional[ClientSession] = None
self.running = False
self.remote_height: int = 0
self._on_connected_controller = StreamController()
self.on_connected = self._on_connected_controller.stream
self._on_header_controller = StreamController()
self._on_header_controller = StreamController(merge_repeated_events=True)
self.on_header = self._on_header_controller.stream
self._on_status_controller = StreamController()
self._on_status_controller = StreamController(merge_repeated_events=True)
self.on_status = self._on_status_controller.stream
self.subscription_controllers = {
@ -88,30 +114,22 @@ class BaseNetwork:
async def start(self):
self.running = True
connect_timeout = self.config.get('connect_timeout', 6)
self.session_pool = SessionPool(network=self, timeout=connect_timeout)
self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height)
while True:
while self.running:
try:
self.client = await self.pick_fastest_session()
if self.is_connected:
await self.ensure_server_version()
self._update_remote_height((await self.subscribe_headers(),))
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
self._on_connected_controller.add(True)
await self.client.on_disconnected.first
except CancelledError:
self.running = False
self.client = await self.session_pool.wait_for_fastest_session()
self._update_remote_height((await self.subscribe_headers(),))
log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
self._on_connected_controller.add(True)
self.client.on_disconnected.listen(lambda _: self.switch_event.set())
await self.switch_event.wait()
self.switch_event.clear()
except asyncio.CancelledError:
await self.stop()
raise
except asyncio.TimeoutError:
log.warning("Timed out while trying to find a server!")
except Exception: # pylint: disable=broad-except
log.exception("Exception while trying to find a server!")
if not self.running:
return
elif self.client:
await self.client.close()
self.client.connection.cancel_pending_requests()
pass
async def stop(self):
self.running = False
@ -124,35 +142,21 @@ class BaseNetwork:
@property
def is_connected(self):
return self.client is not None and not self.client.is_closing()
return self.client and not self.client.is_closing()
def rpc(self, list_or_method, args):
fastest = self.session_pool.fastest_session
if fastest is not None and self.client != fastest:
self.switch_event.set()
if self.is_connected:
return self.client.send_request(list_or_method, args)
else:
self.session_pool.trigger_nodelay_connect()
raise ConnectionError("Attempting to send rpc request when connection is not available.")
async def pick_fastest_session(self):
sessions = await self.session_pool.get_online_sessions()
done, pending = await asyncio.wait([
self.probe_session(session)
for session in sessions if not session.is_closing()
], return_when='FIRST_COMPLETED')
for task in pending:
task.cancel()
for session in done:
return await session
async def probe_session(self, session: ClientSession):
await session.send_request('server.banner')
return session
def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"]
def ensure_server_version(self, required='1.2'):
return self.rpc('server.version', [__version__, required])
def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
@ -182,73 +186,57 @@ class SessionPool:
def __init__(self, network: BaseNetwork, timeout: float):
self.network = network
self.sessions: List[ClientSession] = []
self._dead_servers: List[ClientSession] = []
self.maintain_connections_task = None
self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self.timeout = timeout
# triggered when the master server is out, to speed up reconnect
self._lost_master = asyncio.Event()
self.new_connection_event = asyncio.Event()
@property
def online(self):
for session in self.sessions:
if not session.is_closing():
return True
return False
return any(not session.is_closing() for session in self.sessions)
@property
def available_sessions(self):
return [session for session in self.sessions if session.available]
@property
def fastest_session(self):
if not self.available_sessions:
return None
return min(
[(session.response_time, session) for session in self.available_sessions], key=itemgetter(0)
)[1]
def start(self, default_servers):
self.sessions = [
ClientSession(network=self.network, server=server)
for server in default_servers
]
self.maintain_connections_task = asyncio.create_task(self.ensure_connections())
callback = self.new_connection_event.set
self.sessions = {
ClientSession(
network=self.network, server=server, on_connect_callback=callback
): None for server in default_servers
}
self.ensure_connections()
def stop(self):
if self.maintain_connections_task:
self.maintain_connections_task.cancel()
for session, task in self.sessions.items():
task.cancel()
session.abort()
self.sessions.clear()
def ensure_connections(self):
for session, task in list(self.sessions.items()):
if not task or task.done():
task = asyncio.create_task(session.ensure_session())
task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = task
def trigger_nodelay_connect(self):
# used when other parts of the system see that we might have internet back;
# bypasses the retry interval
for session in self.sessions:
if not session.is_closing():
session.abort()
self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None
session.trigger_urgent_reconnect.set()
async def ensure_connections(self):
while True:
await asyncio.gather(*[
self.ensure_connection(session)
for session in self.sessions
], return_exceptions=True)
try:
await asyncio.wait_for(self._lost_master.wait(), timeout=3)
except asyncio.TimeoutError:
pass
self._lost_master.clear()
if not self.sessions:
self.sessions.extend(self._dead_servers)
self._dead_servers = []
async def ensure_connection(self, session):
self._dead_servers.append(session)
self.sessions.remove(session)
try:
if session.is_closing():
await session.create_connection(self.timeout)
await asyncio.wait_for(session.send_request('server.banner'), timeout=self.timeout)
self.sessions.append(session)
self._dead_servers.remove(session)
except asyncio.TimeoutError:
log.warning("Timeout connecting to %s:%d", *session.server)
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise
except socket.gaierror:
log.warning("Could not resolve IP for %s", session.server[0])
except Exception as err: # pylint: disable=broad-except
if 'Connect call failed' in str(err):
log.warning("Could not connect to %s:%d", *session.server)
else:
log.exception("Connecting to %s:%d raised an exception:", *session.server)
async def get_online_sessions(self):
while not self.online:
self._lost_master.set()
await asyncio.sleep(0.5)
return self.sessions
async def wait_for_fastest_session(self):
while not self.fastest_session:
self.trigger_nodelay_connect()
self.new_connection_event.clear()
await self.new_connection_event.wait()
return self.fastest_session

View file

@ -33,6 +33,7 @@ __all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
import itertools
import json
import typing
import asyncio
from functools import partial
from numbers import Number
@ -745,9 +746,10 @@ class JSONRPCConnection(object):
self._protocol = item
return self.receive_message(message)
def cancel_pending_requests(self):
"""Cancel all pending requests."""
exception = CancelledError()
def time_out_pending_requests(self):
"""Times out all pending requests."""
# this used to be CancelledError, but that's ambiguous: it reads as if the whole SDK
# is shutting down rather than a single request timing out
exception = asyncio.TimeoutError()
for request, event in self._requests.values():
event.result = exception
event.set()

View file

@ -456,7 +456,7 @@ class RPCSession(SessionBase):
def connection_lost(self, exc):
# Cancel pending requests and message processing
self.connection.cancel_pending_requests()
self.connection.time_out_pending_requests()
super().connection_lost(exc)
# External API
@ -473,6 +473,8 @@ class RPCSession(SessionBase):
async def send_request(self, method, args=()):
"""Send an RPC request over the network."""
if self.is_closing():
raise asyncio.TimeoutError("Trying to send request on a recently dropped connection.")
message, event = self.connection.send_request(Request(method, args))
await self._send_message(message)
await event.wait()

View file

@ -45,10 +45,12 @@ class BroadcastSubscription:
class StreamController:
def __init__(self):
def __init__(self, merge_repeated_events=False):
self.stream = Stream(self)
self._first_subscription = None
self._last_subscription = None
self._last_event = None
self._merge_repeated = merge_repeated_events
@property
def has_listener(self):
@ -76,8 +78,10 @@ class StreamController:
return f
def add(self, event):
skip = self._merge_repeated and event == self._last_event
self._last_event = event
return self._notify_and_ensure_future(
lambda subscription: subscription._add(event)
lambda subscription: None if skip else subscription._add(event)
)
def add_error(self, exception):