basenet refactor

This commit is contained in:
Victor Shyba 2019-08-06 02:17:39 -03:00
parent 5728493abb
commit e78e2738ef
2 changed files with 48 additions and 46 deletions

View file

@ -22,9 +22,11 @@ class ReconnectTests(IntegrationTestCase):
async def test_connection_drop_still_receives_events_after_reconnected(self): async def test_connection_drop_still_receives_events_after_reconnected(self):
address1 = await self.account.receiving.get_or_create_usable_address() address1 = await self.account.receiving.get_or_create_usable_address()
# disconnect and send a new tx, should reconnect and get it
self.ledger.network.client.connection_lost(Exception()) self.ledger.network.client.connection_lost(Exception())
self.assertFalse(self.ledger.network.is_connected)
sendtxid = await self.blockchain.send_to_address(address1, 1.1337) sendtxid = await self.blockchain.send_to_address(address1, 1.1337)
await self.on_transaction_id(sendtxid) # mempool await asyncio.wait_for(self.on_transaction_id(sendtxid), 1.0) # mempool
await self.blockchain.generate(1) await self.blockchain.generate(1)
await self.on_transaction_id(sendtxid) # confirmed await self.on_transaction_id(sendtxid) # confirmed
@ -45,8 +47,10 @@ class ReconnectTests(IntegrationTestCase):
await self.ledger.network.get_transaction(sendtxid) await self.ledger.network.get_transaction(sendtxid)
async def test_timeout_then_reconnect(self): async def test_timeout_then_reconnect(self):
# tests that it connects back after some failed attempts
await self.conductor.spv_node.stop() await self.conductor.spv_node.stop()
self.assertFalse(self.ledger.network.is_connected) self.assertFalse(self.ledger.network.is_connected)
await asyncio.sleep(0.2) # let it retry and fail once
await self.conductor.spv_node.start(self.conductor.blockchain_node) await self.conductor.spv_node.start(self.conductor.blockchain_node)
await self.ledger.network.on_connected.first await self.ledger.network.on_connected.first
self.assertTrue(self.ledger.network.is_connected) self.assertTrue(self.ledger.network.is_connected)
@ -79,9 +83,9 @@ class ServerPickingTestCase(AsyncioTestCase):
await self._make_bad_server(), await self._make_bad_server(),
('localhost', 1), ('localhost', 1),
('example.that.doesnt.resolve', 9000), ('example.that.doesnt.resolve', 9000),
await self._make_fake_server(latency=1.2, port=1340), await self._make_fake_server(latency=1.0, port=1340),
await self._make_fake_server(latency=0.5, port=1337), await self._make_fake_server(latency=0.1, port=1337),
await self._make_fake_server(latency=0.7, port=1339), await self._make_fake_server(latency=0.4, port=1339),
], ],
'connect_timeout': 3 'connect_timeout': 3
}) })
@ -89,9 +93,10 @@ class ServerPickingTestCase(AsyncioTestCase):
network = BaseNetwork(ledger) network = BaseNetwork(ledger)
self.addCleanup(network.stop) self.addCleanup(network.stop)
asyncio.ensure_future(network.start()) asyncio.ensure_future(network.start())
await asyncio.wait_for(network.on_connected.first, timeout=3) await asyncio.wait_for(network.on_connected.first, timeout=1)
self.assertTrue(network.is_connected) self.assertTrue(network.is_connected)
self.assertEqual(network.client.server, ('127.0.0.1', 1337)) self.assertEqual(network.client.server, ('127.0.0.1', 1337))
# ensure we are connected to all of them self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions]))
self.assertTrue(all([not session.is_closing() for session in network.session_pool.sessions])) # ensure we are connected to all of them after a while
self.assertEqual(len(network.session_pool.sessions), 3) await asyncio.sleep(1)
self.assertEqual(len(network.session_pool.available_sessions), 3)

View file

@ -1,6 +1,7 @@
import logging import logging
import asyncio import asyncio
from typing import Dict from operator import itemgetter
from typing import Dict, Optional
from time import time from time import time
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
@ -26,7 +27,7 @@ class ClientSession(BaseClientSession):
@property @property
def available(self): def available(self):
return not self.is_closing() and self._can_send.is_set() return not self.is_closing() and self._can_send.is_set() and self.latency < 1 << 32
async def send_request(self, method, args=()): async def send_request(self, method, args=()):
try: try:
@ -39,14 +40,11 @@ class ClientSession(BaseClientSession):
except RPCError as e: except RPCError as e:
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args) log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
raise e raise e
except asyncio.TimeoutError:
self.abort()
raise
async def ensure_session(self): async def ensure_session(self):
# Handles reconnecting and maintaining a session alive # Handles reconnecting and maintaining a session alive
# TODO: change to 'ping' on newer protocol (above 1.2) # TODO: change to 'ping' on newer protocol (above 1.2)
retry_delay = 1.0 retry_delay = default_delay = 0.1
while True: while True:
try: try:
if self.is_closing(): if self.is_closing():
@ -54,10 +52,11 @@ class ClientSession(BaseClientSession):
await self.ensure_server_version() await self.ensure_server_version()
if (time() - self.last_send) > self.max_seconds_idle: if (time() - self.last_send) > self.max_seconds_idle:
await self.send_request('server.banner') await self.send_request('server.banner')
retry_delay = 1.0 retry_delay = default_delay
except asyncio.TimeoutError: except (asyncio.TimeoutError, OSError):
await self.close()
retry_delay = min(60, retry_delay * 2)
log.warning("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server) log.warning("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
retry_delay = max(60, retry_delay * 2)
await asyncio.sleep(retry_delay) await asyncio.sleep(retry_delay)
def ensure_server_version(self, required='1.2'): def ensure_server_version(self, required='1.2'):
@ -80,9 +79,10 @@ class ClientSession(BaseClientSession):
class BaseNetwork: class BaseNetwork:
def __init__(self, ledger): def __init__(self, ledger):
self.switch_event = asyncio.Event()
self.config = ledger.config self.config = ledger.config
self.client: ClientSession = None self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
self.session_pool: SessionPool = None self.client: Optional[ClientSession] = None
self.running = False self.running = False
self.remote_height: int = 0 self.remote_height: int = 0
@ -102,8 +102,6 @@ class BaseNetwork:
async def start(self): async def start(self):
self.running = True self.running = True
connect_timeout = self.config.get('connect_timeout', 6)
self.session_pool = SessionPool(network=self, timeout=connect_timeout)
self.session_pool.start(self.config['default_servers']) self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height) self.on_header.listen(self._update_remote_height)
while self.running: while self.running:
@ -112,7 +110,9 @@ class BaseNetwork:
self._update_remote_height((await self.subscribe_headers(),)) self._update_remote_height((await self.subscribe_headers(),))
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server) log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
self._on_connected_controller.add(True) self._on_connected_controller.add(True)
await self.client.on_disconnected.first self.client.on_disconnected.listen(lambda _: self.switch_event.set())
await self.switch_event.wait()
self.switch_event.clear()
except asyncio.CancelledError: except asyncio.CancelledError:
await self.stop() await self.stop()
raise raise
@ -132,12 +132,14 @@ class BaseNetwork:
@property @property
def is_connected(self): def is_connected(self):
return self.session_pool.online return self.client and not self.client.is_closing()
async def rpc(self, list_or_method, args): def rpc(self, list_or_method, args):
if self.is_connected: if self.is_connected:
await self.session_pool.wait_for_fastest_session() fastest = self.session_pool.fastest_session
return await self.session_pool.fastest_session.send_request(list_or_method, args) if self.client != fastest:
self.switch_event.set()
return self.client.send_request(list_or_method, args)
else: else:
raise ConnectionError("Attempting to send rpc request when connection is not available.") raise ConnectionError("Attempting to send rpc request when connection is not available.")
@ -178,7 +180,6 @@ class SessionPool:
def __init__(self, network: BaseNetwork, timeout: float): def __init__(self, network: BaseNetwork, timeout: float):
self.network = network self.network = network
self.sessions: Dict[ClientSession, asyncio.Task] = dict() self.sessions: Dict[ClientSession, asyncio.Task] = dict()
self.maintain_connections_task = None
self.timeout = timeout self.timeout = timeout
@property @property
@ -193,31 +194,27 @@ class SessionPool:
def fastest_session(self): def fastest_session(self):
if not self.available_sessions: if not self.available_sessions:
return None return None
return min([(session.latency, session) for session in self.available_sessions])[1] return min([(session.latency, session) for session in self.available_sessions], key=itemgetter(0))[1]
def start(self, default_servers): def start(self, default_servers):
for server in default_servers: self.sessions = {
session = ClientSession(network=self.network, server=server) ClientSession(network=self.network, server=server): None for server in default_servers
self.sessions[session] = asyncio.create_task(session.ensure_session()) }
self.maintain_connections_task = asyncio.create_task(self.ensure_connections()) self.ensure_connections()
def stop(self): def stop(self):
if self.maintain_connections_task: for session, task in self.sessions.items():
self.maintain_connections_task.cancel() task.cancel()
self.maintain_connections_task = None session.abort()
for session, maintenance_task in self.sessions.items():
maintenance_task.cancel()
if not session.is_closing():
session.abort()
self.sessions.clear() self.sessions.clear()
async def ensure_connections(self): def ensure_connections(self):
while True: for session, task in list(self.sessions.items()):
log.info("Checking conns") if not task or task.done():
for session, task in list(self.sessions.items()): task = asyncio.create_task(session.ensure_session())
if task.done(): task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = asyncio.create_task(session.ensure_session()) self.sessions[session] = task
await asyncio.wait(self.sessions.items(), timeout=10)
async def wait_for_fastest_session(self): async def wait_for_fastest_session(self):
while True: while True: