Merge pull request #2371 from lbryio/basenetwork_refactor

refactor basenetwork so each session takes care of itself
This commit is contained in:
Alex Grin 2019-08-16 10:19:53 -04:00 committed by GitHub
commit ff73418fc1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 164 additions and 139 deletions

View file

@ -22,7 +22,6 @@ class TestSessionBloat(IntegrationTestCase):
await self.conductor.start_spv() await self.conductor.start_spv()
session = ClientSession(network=None, server=self.ledger.network.client.server, timeout=0.2) session = ClientSession(network=None, server=self.ledger.network.client.server, timeout=0.2)
await session.create_connection() await session.create_connection()
session.ping_task.cancel()
await session.send_request('server.banner', ()) await session.send_request('server.banner', ())
self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1) self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1)
self.assertFalse(session.is_closing()) self.assertFalse(session.is_closing())

View file

@ -22,31 +22,40 @@ class ReconnectTests(IntegrationTestCase):
async def test_connection_drop_still_receives_events_after_reconnected(self): async def test_connection_drop_still_receives_events_after_reconnected(self):
address1 = await self.account.receiving.get_or_create_usable_address() address1 = await self.account.receiving.get_or_create_usable_address()
# disconnect and send a new tx, should reconnect and get it
self.ledger.network.client.connection_lost(Exception()) self.ledger.network.client.connection_lost(Exception())
self.assertFalse(self.ledger.network.is_connected)
sendtxid = await self.blockchain.send_to_address(address1, 1.1337) sendtxid = await self.blockchain.send_to_address(address1, 1.1337)
await self.on_transaction_id(sendtxid) # mempool await asyncio.wait_for(self.on_transaction_id(sendtxid), 1.0) # mempool
await self.blockchain.generate(1) await self.blockchain.generate(1)
await self.on_transaction_id(sendtxid) # confirmed await self.on_transaction_id(sendtxid) # confirmed
self.assertLess(self.ledger.network.client.response_time, 1) # response time properly set lower, we are fine
await self.assertBalance(self.account, '1.1337') await self.assertBalance(self.account, '1.1337')
# is it real? are we rich!? let me see this tx... # is it real? are we rich!? let me see this tx...
d = self.ledger.network.get_transaction(sendtxid) d = self.ledger.network.get_transaction(sendtxid)
# what's that smoke on my ethernet cable? oh no! # what's that smoke on my ethernet cable? oh no!
self.ledger.network.client.connection_lost(Exception()) self.ledger.network.client.connection_lost(Exception())
with self.assertRaises(asyncio.CancelledError): with self.assertRaises(asyncio.TimeoutError):
await d await d
self.assertIsNone(self.ledger.network.client.response_time) # response time unknown as it failed
# rich but offline? no way, no water, let's retry # rich but offline? no way, no water, let's retry
with self.assertRaisesRegex(ConnectionError, 'connection is not available'): with self.assertRaisesRegex(ConnectionError, 'connection is not available'):
await self.ledger.network.get_transaction(sendtxid) await self.ledger.network.get_transaction(sendtxid)
# * goes to pick some water outside... * time passes by and another donation comes in # * goes to pick some water outside... * time passes by and another donation comes in
sendtxid = await self.blockchain.send_to_address(address1, 42) sendtxid = await self.blockchain.send_to_address(address1, 42)
await self.blockchain.generate(1) await self.blockchain.generate(1)
# (this is just so the test doesn't hang forever if it doesn't reconnect)
if not self.ledger.network.is_connected:
await asyncio.wait_for(self.ledger.network.on_connected.first, timeout=1.0)
# omg, the burned cable still works! torba is fire proof! # omg, the burned cable still works! torba is fire proof!
await self.ledger.network.get_transaction(sendtxid) await self.ledger.network.get_transaction(sendtxid)
async def test_timeout_then_reconnect(self): async def test_timeout_then_reconnect(self):
# tests that it connects back after some failed attempts
await self.conductor.spv_node.stop() await self.conductor.spv_node.stop()
self.assertFalse(self.ledger.network.is_connected) self.assertFalse(self.ledger.network.is_connected)
await asyncio.sleep(0.2) # let it retry and fail once
await self.conductor.spv_node.start(self.conductor.blockchain_node) await self.conductor.spv_node.start(self.conductor.blockchain_node)
await self.ledger.network.on_connected.first await self.ledger.network.on_connected.first
self.assertTrue(self.ledger.network.is_connected) self.assertTrue(self.ledger.network.is_connected)
@ -79,9 +88,9 @@ class ServerPickingTestCase(AsyncioTestCase):
await self._make_bad_server(), await self._make_bad_server(),
('localhost', 1), ('localhost', 1),
('example.that.doesnt.resolve', 9000), ('example.that.doesnt.resolve', 9000),
await self._make_fake_server(latency=1.2, port=1340), await self._make_fake_server(latency=1.0, port=1340),
await self._make_fake_server(latency=0.5, port=1337), await self._make_fake_server(latency=0.1, port=1337),
await self._make_fake_server(latency=0.7, port=1339), await self._make_fake_server(latency=0.4, port=1339),
], ],
'connect_timeout': 3 'connect_timeout': 3
}) })
@ -89,9 +98,10 @@ class ServerPickingTestCase(AsyncioTestCase):
network = BaseNetwork(ledger) network = BaseNetwork(ledger)
self.addCleanup(network.stop) self.addCleanup(network.stop)
asyncio.ensure_future(network.start()) asyncio.ensure_future(network.start())
await asyncio.wait_for(network.on_connected.first, timeout=3) await asyncio.wait_for(network.on_connected.first, timeout=1)
self.assertTrue(network.is_connected) self.assertTrue(network.is_connected)
self.assertEqual(network.client.server, ('127.0.0.1', 1337)) self.assertEqual(network.client.server, ('127.0.0.1', 1337))
# ensure we are connected to all of them self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions]))
self.assertTrue(all([not session.is_closing() for session in network.session_pool.sessions])) # ensure we are connected to all of them after a while
self.assertEqual(len(network.session_pool.sessions), 3) await asyncio.sleep(1)
self.assertEqual(len(network.session_pool.available_sessions), 3)

View file

@ -0,0 +1,20 @@
from torba.stream import StreamController
from torba.testcase import AsyncioTestCase
class StreamControllerTestCase(AsyncioTestCase):
def test_non_unique_events(self):
events = []
controller = StreamController()
controller.stream.listen(on_data=events.append)
controller.add("yo")
controller.add("yo")
self.assertEqual(events, ["yo", "yo"])
def test_unique_events(self):
events = []
controller = StreamController(merge_repeated_events=True)
controller.stream.listen(on_data=events.append)
controller.add("yo")
controller.add("yo")
self.assertEqual(events, ["yo"])

View file

@ -1,9 +1,8 @@
import logging import logging
import asyncio import asyncio
from asyncio import CancelledError from operator import itemgetter
from time import time from typing import Dict, Optional
from typing import List from time import time, perf_counter
import socket
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
@ -15,7 +14,7 @@ log = logging.getLogger(__name__)
class ClientSession(BaseClientSession): class ClientSession(BaseClientSession):
def __init__(self, *args, network, server, timeout=30, **kwargs): def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
self.network = network self.network = network
self.server = server self.server = server
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -24,61 +23,88 @@ class ClientSession(BaseClientSession):
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32 self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32
self.timeout = timeout self.timeout = timeout
self.max_seconds_idle = timeout * 2 self.max_seconds_idle = timeout * 2
self.ping_task = None self.response_time: Optional[float] = None
self._on_connect_cb = on_connect_callback or (lambda: None)
self.trigger_urgent_reconnect = asyncio.Event()
@property
def available(self):
return not self.is_closing() and self._can_send.is_set() and self.response_time is not None
async def send_request(self, method, args=()): async def send_request(self, method, args=()):
try: try:
return await asyncio.wait_for(super().send_request(method, args), timeout=self.timeout) start = perf_counter()
result = await asyncio.wait_for(
super().send_request(method, args), timeout=self.timeout
)
self.response_time = perf_counter() - start
return result
except RPCError as e: except RPCError as e:
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args) log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
raise e raise e
except asyncio.TimeoutError: except TimeoutError:
self.abort() self.response_time = None
raise raise
async def ping_forever(self): async def ensure_session(self):
# Handles reconnecting and keeping the session alive
# TODO: change to 'ping' on newer protocol (above 1.2) # TODO: change to 'ping' on newer protocol (above 1.2)
while not self.is_closing(): retry_delay = default_delay = 0.1
if (time() - self.last_send) > self.max_seconds_idle: while True:
try: try:
if self.is_closing():
await self.create_connection(self.timeout)
await self.ensure_server_version()
self._on_connect_cb()
if (time() - self.last_send) > self.max_seconds_idle or self.response_time is None:
await self.send_request('server.banner') await self.send_request('server.banner')
except: retry_delay = default_delay
self.abort() except (asyncio.TimeoutError, OSError):
raise await self.close()
await asyncio.sleep(self.max_seconds_idle//3) retry_delay = min(60, retry_delay * 2)
log.warning("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
try:
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
except asyncio.TimeoutError:
pass
finally:
self.trigger_urgent_reconnect.clear()
def ensure_server_version(self, required='1.2'):
return self.send_request('server.version', [__version__, required])
async def create_connection(self, timeout=6): async def create_connection(self, timeout=6):
connector = Connector(lambda: self, *self.server) connector = Connector(lambda: self, *self.server)
await asyncio.wait_for(connector.create_connection(), timeout=timeout) await asyncio.wait_for(connector.create_connection(), timeout=timeout)
self.ping_task = asyncio.create_task(self.ping_forever())
async def handle_request(self, request): async def handle_request(self, request):
controller = self.network.subscription_controllers[request.method] controller = self.network.subscription_controllers[request.method]
controller.add(request.args) controller.add(request.args)
def connection_lost(self, exc): def connection_lost(self, exc):
log.debug("Connection lost: %s:%d", *self.server)
super().connection_lost(exc) super().connection_lost(exc)
self.response_time = None
self._on_disconnect_controller.add(True) self._on_disconnect_controller.add(True)
if self.ping_task:
self.ping_task.cancel()
class BaseNetwork: class BaseNetwork:
def __init__(self, ledger): def __init__(self, ledger):
self.switch_event = asyncio.Event()
self.config = ledger.config self.config = ledger.config
self.client: ClientSession = None self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
self.session_pool: SessionPool = None self.client: Optional[ClientSession] = None
self.running = False self.running = False
self.remote_height: int = 0 self.remote_height: int = 0
self._on_connected_controller = StreamController() self._on_connected_controller = StreamController()
self.on_connected = self._on_connected_controller.stream self.on_connected = self._on_connected_controller.stream
self._on_header_controller = StreamController() self._on_header_controller = StreamController(merge_repeated_events=True)
self.on_header = self._on_header_controller.stream self.on_header = self._on_header_controller.stream
self._on_status_controller = StreamController() self._on_status_controller = StreamController(merge_repeated_events=True)
self.on_status = self._on_status_controller.stream self.on_status = self._on_status_controller.stream
self.subscription_controllers = { self.subscription_controllers = {
@ -88,30 +114,22 @@ class BaseNetwork:
async def start(self): async def start(self):
self.running = True self.running = True
connect_timeout = self.config.get('connect_timeout', 6)
self.session_pool = SessionPool(network=self, timeout=connect_timeout)
self.session_pool.start(self.config['default_servers']) self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height) self.on_header.listen(self._update_remote_height)
while True: while self.running:
try: try:
self.client = await self.pick_fastest_session() self.client = await self.session_pool.wait_for_fastest_session()
if self.is_connected: self._update_remote_height((await self.subscribe_headers(),))
await self.ensure_server_version() log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
self._update_remote_height((await self.subscribe_headers(),)) self._on_connected_controller.add(True)
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server) self.client.on_disconnected.listen(lambda _: self.switch_event.set())
self._on_connected_controller.add(True) await self.switch_event.wait()
await self.client.on_disconnected.first self.switch_event.clear()
except CancelledError: except asyncio.CancelledError:
self.running = False await self.stop()
raise
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.warning("Timed out while trying to find a server!") pass
except Exception: # pylint: disable=broad-except
log.exception("Exception while trying to find a server!")
if not self.running:
return
elif self.client:
await self.client.close()
self.client.connection.cancel_pending_requests()
async def stop(self): async def stop(self):
self.running = False self.running = False
@ -124,35 +142,21 @@ class BaseNetwork:
@property @property
def is_connected(self): def is_connected(self):
return self.client is not None and not self.client.is_closing() return self.client and not self.client.is_closing()
def rpc(self, list_or_method, args): def rpc(self, list_or_method, args):
fastest = self.session_pool.fastest_session
if fastest is not None and self.client != fastest:
self.switch_event.set()
if self.is_connected: if self.is_connected:
return self.client.send_request(list_or_method, args) return self.client.send_request(list_or_method, args)
else: else:
self.session_pool.trigger_nodelay_connect()
raise ConnectionError("Attempting to send rpc request when connection is not available.") raise ConnectionError("Attempting to send rpc request when connection is not available.")
async def pick_fastest_session(self):
sessions = await self.session_pool.get_online_sessions()
done, pending = await asyncio.wait([
self.probe_session(session)
for session in sessions if not session.is_closing()
], return_when='FIRST_COMPLETED')
for task in pending:
task.cancel()
for session in done:
return await session
async def probe_session(self, session: ClientSession):
await session.send_request('server.banner')
return session
def _update_remote_height(self, header_args): def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"] self.remote_height = header_args[0]["height"]
def ensure_server_version(self, required='1.2'):
return self.rpc('server.version', [__version__, required])
def broadcast(self, raw_transaction): def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction]) return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
@ -182,73 +186,57 @@ class SessionPool:
def __init__(self, network: BaseNetwork, timeout: float): def __init__(self, network: BaseNetwork, timeout: float):
self.network = network self.network = network
self.sessions: List[ClientSession] = [] self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self._dead_servers: List[ClientSession] = []
self.maintain_connections_task = None
self.timeout = timeout self.timeout = timeout
# triggered when the master server is out, to speed up reconnect self.new_connection_event = asyncio.Event()
self._lost_master = asyncio.Event()
@property @property
def online(self): def online(self):
for session in self.sessions: return any(not session.is_closing() for session in self.sessions)
if not session.is_closing():
return True @property
return False def available_sessions(self):
return [session for session in self.sessions if session.available]
@property
def fastest_session(self):
if not self.available_sessions:
return None
return min(
[(session.response_time, session) for session in self.available_sessions], key=itemgetter(0)
)[1]
def start(self, default_servers): def start(self, default_servers):
self.sessions = [ callback = self.new_connection_event.set
ClientSession(network=self.network, server=server) self.sessions = {
for server in default_servers ClientSession(
] network=self.network, server=server, on_connect_callback=callback
self.maintain_connections_task = asyncio.create_task(self.ensure_connections()) ): None for server in default_servers
}
self.ensure_connections()
def stop(self): def stop(self):
if self.maintain_connections_task: for session, task in self.sessions.items():
self.maintain_connections_task.cancel() task.cancel()
session.abort()
self.sessions.clear()
def ensure_connections(self):
for session, task in list(self.sessions.items()):
if not task or task.done():
task = asyncio.create_task(session.ensure_session())
task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = task
def trigger_nodelay_connect(self):
# used when other parts of the system see we might have internet back
# bypasses the retry interval
for session in self.sessions: for session in self.sessions:
if not session.is_closing(): session.trigger_urgent_reconnect.set()
session.abort()
self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None
async def ensure_connections(self): async def wait_for_fastest_session(self):
while True: while not self.fastest_session:
await asyncio.gather(*[ self.trigger_nodelay_connect()
self.ensure_connection(session) self.new_connection_event.clear()
for session in self.sessions await self.new_connection_event.wait()
], return_exceptions=True) return self.fastest_session
try:
await asyncio.wait_for(self._lost_master.wait(), timeout=3)
except asyncio.TimeoutError:
pass
self._lost_master.clear()
if not self.sessions:
self.sessions.extend(self._dead_servers)
self._dead_servers = []
async def ensure_connection(self, session):
self._dead_servers.append(session)
self.sessions.remove(session)
try:
if session.is_closing():
await session.create_connection(self.timeout)
await asyncio.wait_for(session.send_request('server.banner'), timeout=self.timeout)
self.sessions.append(session)
self._dead_servers.remove(session)
except asyncio.TimeoutError:
log.warning("Timeout connecting to %s:%d", *session.server)
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise
except socket.gaierror:
log.warning("Could not resolve IP for %s", session.server[0])
except Exception as err: # pylint: disable=broad-except
if 'Connect call failed' in str(err):
log.warning("Could not connect to %s:%d", *session.server)
else:
log.exception("Connecting to %s:%d raised an exception:", *session.server)
async def get_online_sessions(self):
while not self.online:
self._lost_master.set()
await asyncio.sleep(0.5)
return self.sessions

View file

@ -33,6 +33,7 @@ __all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
import itertools import itertools
import json import json
import typing import typing
import asyncio
from functools import partial from functools import partial
from numbers import Number from numbers import Number
@ -745,9 +746,10 @@ class JSONRPCConnection(object):
self._protocol = item self._protocol = item
return self.receive_message(message) return self.receive_message(message)
def cancel_pending_requests(self): def time_out_pending_requests(self):
"""Cancel all pending requests.""" """Times out all pending requests."""
exception = CancelledError() # this used to be CancelledError, but that's confusing — are we closing the whole SDK, or did the request merely fail?
exception = asyncio.TimeoutError()
for request, event in self._requests.values(): for request, event in self._requests.values():
event.result = exception event.result = exception
event.set() event.set()

View file

@ -456,7 +456,7 @@ class RPCSession(SessionBase):
def connection_lost(self, exc): def connection_lost(self, exc):
# Cancel pending requests and message processing # Cancel pending requests and message processing
self.connection.cancel_pending_requests() self.connection.time_out_pending_requests()
super().connection_lost(exc) super().connection_lost(exc)
# External API # External API
@ -473,6 +473,8 @@ class RPCSession(SessionBase):
async def send_request(self, method, args=()): async def send_request(self, method, args=()):
"""Send an RPC request over the network.""" """Send an RPC request over the network."""
if self.is_closing():
raise asyncio.TimeoutError("Trying to send request on a recently dropped connection.")
message, event = self.connection.send_request(Request(method, args)) message, event = self.connection.send_request(Request(method, args))
await self._send_message(message) await self._send_message(message)
await event.wait() await event.wait()

View file

@ -45,10 +45,12 @@ class BroadcastSubscription:
class StreamController: class StreamController:
def __init__(self): def __init__(self, merge_repeated_events=False):
self.stream = Stream(self) self.stream = Stream(self)
self._first_subscription = None self._first_subscription = None
self._last_subscription = None self._last_subscription = None
self._last_event = None
self._merge_repeated = merge_repeated_events
@property @property
def has_listener(self): def has_listener(self):
@ -76,8 +78,10 @@ class StreamController:
return f return f
def add(self, event): def add(self, event):
skip = self._merge_repeated and event == self._last_event
self._last_event = event
return self._notify_and_ensure_future( return self._notify_and_ensure_future(
lambda subscription: subscription._add(event) lambda subscription: None if skip else subscription._add(event)
) )
def add_error(self, exception): def add_error(self, exception):