basenet refactor
This commit is contained in:
parent
5728493abb
commit
e78e2738ef
2 changed files with 48 additions and 46 deletions
|
@ -22,9 +22,11 @@ class ReconnectTests(IntegrationTestCase):
|
||||||
|
|
||||||
async def test_connection_drop_still_receives_events_after_reconnected(self):
|
async def test_connection_drop_still_receives_events_after_reconnected(self):
|
||||||
address1 = await self.account.receiving.get_or_create_usable_address()
|
address1 = await self.account.receiving.get_or_create_usable_address()
|
||||||
|
# disconnect and send a new tx, should reconnect and get it
|
||||||
self.ledger.network.client.connection_lost(Exception())
|
self.ledger.network.client.connection_lost(Exception())
|
||||||
|
self.assertFalse(self.ledger.network.is_connected)
|
||||||
sendtxid = await self.blockchain.send_to_address(address1, 1.1337)
|
sendtxid = await self.blockchain.send_to_address(address1, 1.1337)
|
||||||
await self.on_transaction_id(sendtxid) # mempool
|
await asyncio.wait_for(self.on_transaction_id(sendtxid), 1.0) # mempool
|
||||||
await self.blockchain.generate(1)
|
await self.blockchain.generate(1)
|
||||||
await self.on_transaction_id(sendtxid) # confirmed
|
await self.on_transaction_id(sendtxid) # confirmed
|
||||||
|
|
||||||
|
@ -45,8 +47,10 @@ class ReconnectTests(IntegrationTestCase):
|
||||||
await self.ledger.network.get_transaction(sendtxid)
|
await self.ledger.network.get_transaction(sendtxid)
|
||||||
|
|
||||||
async def test_timeout_then_reconnect(self):
|
async def test_timeout_then_reconnect(self):
|
||||||
|
# tests that it connects back after some failed attempts
|
||||||
await self.conductor.spv_node.stop()
|
await self.conductor.spv_node.stop()
|
||||||
self.assertFalse(self.ledger.network.is_connected)
|
self.assertFalse(self.ledger.network.is_connected)
|
||||||
|
await asyncio.sleep(0.2) # let it retry and fail once
|
||||||
await self.conductor.spv_node.start(self.conductor.blockchain_node)
|
await self.conductor.spv_node.start(self.conductor.blockchain_node)
|
||||||
await self.ledger.network.on_connected.first
|
await self.ledger.network.on_connected.first
|
||||||
self.assertTrue(self.ledger.network.is_connected)
|
self.assertTrue(self.ledger.network.is_connected)
|
||||||
|
@ -79,9 +83,9 @@ class ServerPickingTestCase(AsyncioTestCase):
|
||||||
await self._make_bad_server(),
|
await self._make_bad_server(),
|
||||||
('localhost', 1),
|
('localhost', 1),
|
||||||
('example.that.doesnt.resolve', 9000),
|
('example.that.doesnt.resolve', 9000),
|
||||||
await self._make_fake_server(latency=1.2, port=1340),
|
await self._make_fake_server(latency=1.0, port=1340),
|
||||||
await self._make_fake_server(latency=0.5, port=1337),
|
await self._make_fake_server(latency=0.1, port=1337),
|
||||||
await self._make_fake_server(latency=0.7, port=1339),
|
await self._make_fake_server(latency=0.4, port=1339),
|
||||||
],
|
],
|
||||||
'connect_timeout': 3
|
'connect_timeout': 3
|
||||||
})
|
})
|
||||||
|
@ -89,9 +93,10 @@ class ServerPickingTestCase(AsyncioTestCase):
|
||||||
network = BaseNetwork(ledger)
|
network = BaseNetwork(ledger)
|
||||||
self.addCleanup(network.stop)
|
self.addCleanup(network.stop)
|
||||||
asyncio.ensure_future(network.start())
|
asyncio.ensure_future(network.start())
|
||||||
await asyncio.wait_for(network.on_connected.first, timeout=3)
|
await asyncio.wait_for(network.on_connected.first, timeout=1)
|
||||||
self.assertTrue(network.is_connected)
|
self.assertTrue(network.is_connected)
|
||||||
self.assertEqual(network.client.server, ('127.0.0.1', 1337))
|
self.assertEqual(network.client.server, ('127.0.0.1', 1337))
|
||||||
# ensure we are connected to all of them
|
self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions]))
|
||||||
self.assertTrue(all([not session.is_closing() for session in network.session_pool.sessions]))
|
# ensure we are connected to all of them after a while
|
||||||
self.assertEqual(len(network.session_pool.sessions), 3)
|
await asyncio.sleep(1)
|
||||||
|
self.assertEqual(len(network.session_pool.available_sessions), 3)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict
|
from operator import itemgetter
|
||||||
|
from typing import Dict, Optional
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
|
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
|
||||||
|
@ -26,7 +27,7 @@ class ClientSession(BaseClientSession):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def available(self):
|
def available(self):
|
||||||
return not self.is_closing() and self._can_send.is_set()
|
return not self.is_closing() and self._can_send.is_set() and self.latency < 1 << 32
|
||||||
|
|
||||||
async def send_request(self, method, args=()):
|
async def send_request(self, method, args=()):
|
||||||
try:
|
try:
|
||||||
|
@ -39,14 +40,11 @@ class ClientSession(BaseClientSession):
|
||||||
except RPCError as e:
|
except RPCError as e:
|
||||||
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
|
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
|
||||||
raise e
|
raise e
|
||||||
except asyncio.TimeoutError:
|
|
||||||
self.abort()
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def ensure_session(self):
|
async def ensure_session(self):
|
||||||
# Handles reconnecting and maintaining a session alive
|
# Handles reconnecting and maintaining a session alive
|
||||||
# TODO: change to 'ping' on newer protocol (above 1.2)
|
# TODO: change to 'ping' on newer protocol (above 1.2)
|
||||||
retry_delay = 1.0
|
retry_delay = default_delay = 0.1
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
if self.is_closing():
|
if self.is_closing():
|
||||||
|
@ -54,10 +52,11 @@ class ClientSession(BaseClientSession):
|
||||||
await self.ensure_server_version()
|
await self.ensure_server_version()
|
||||||
if (time() - self.last_send) > self.max_seconds_idle:
|
if (time() - self.last_send) > self.max_seconds_idle:
|
||||||
await self.send_request('server.banner')
|
await self.send_request('server.banner')
|
||||||
retry_delay = 1.0
|
retry_delay = default_delay
|
||||||
except asyncio.TimeoutError:
|
except (asyncio.TimeoutError, OSError):
|
||||||
|
await self.close()
|
||||||
|
retry_delay = min(60, retry_delay * 2)
|
||||||
log.warning("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
|
log.warning("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
|
||||||
retry_delay = max(60, retry_delay * 2)
|
|
||||||
await asyncio.sleep(retry_delay)
|
await asyncio.sleep(retry_delay)
|
||||||
|
|
||||||
def ensure_server_version(self, required='1.2'):
|
def ensure_server_version(self, required='1.2'):
|
||||||
|
@ -80,9 +79,10 @@ class ClientSession(BaseClientSession):
|
||||||
class BaseNetwork:
|
class BaseNetwork:
|
||||||
|
|
||||||
def __init__(self, ledger):
|
def __init__(self, ledger):
|
||||||
|
self.switch_event = asyncio.Event()
|
||||||
self.config = ledger.config
|
self.config = ledger.config
|
||||||
self.client: ClientSession = None
|
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
|
||||||
self.session_pool: SessionPool = None
|
self.client: Optional[ClientSession] = None
|
||||||
self.running = False
|
self.running = False
|
||||||
self.remote_height: int = 0
|
self.remote_height: int = 0
|
||||||
|
|
||||||
|
@ -102,8 +102,6 @@ class BaseNetwork:
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
self.running = True
|
self.running = True
|
||||||
connect_timeout = self.config.get('connect_timeout', 6)
|
|
||||||
self.session_pool = SessionPool(network=self, timeout=connect_timeout)
|
|
||||||
self.session_pool.start(self.config['default_servers'])
|
self.session_pool.start(self.config['default_servers'])
|
||||||
self.on_header.listen(self._update_remote_height)
|
self.on_header.listen(self._update_remote_height)
|
||||||
while self.running:
|
while self.running:
|
||||||
|
@ -112,7 +110,9 @@ class BaseNetwork:
|
||||||
self._update_remote_height((await self.subscribe_headers(),))
|
self._update_remote_height((await self.subscribe_headers(),))
|
||||||
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
|
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
|
||||||
self._on_connected_controller.add(True)
|
self._on_connected_controller.add(True)
|
||||||
await self.client.on_disconnected.first
|
self.client.on_disconnected.listen(lambda _: self.switch_event.set())
|
||||||
|
await self.switch_event.wait()
|
||||||
|
self.switch_event.clear()
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
await self.stop()
|
await self.stop()
|
||||||
raise
|
raise
|
||||||
|
@ -132,12 +132,14 @@ class BaseNetwork:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self):
|
def is_connected(self):
|
||||||
return self.session_pool.online
|
return self.client and not self.client.is_closing()
|
||||||
|
|
||||||
async def rpc(self, list_or_method, args):
|
def rpc(self, list_or_method, args):
|
||||||
if self.is_connected:
|
if self.is_connected:
|
||||||
await self.session_pool.wait_for_fastest_session()
|
fastest = self.session_pool.fastest_session
|
||||||
return await self.session_pool.fastest_session.send_request(list_or_method, args)
|
if self.client != fastest:
|
||||||
|
self.switch_event.set()
|
||||||
|
return self.client.send_request(list_or_method, args)
|
||||||
else:
|
else:
|
||||||
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
||||||
|
|
||||||
|
@ -178,7 +180,6 @@ class SessionPool:
|
||||||
def __init__(self, network: BaseNetwork, timeout: float):
|
def __init__(self, network: BaseNetwork, timeout: float):
|
||||||
self.network = network
|
self.network = network
|
||||||
self.sessions: Dict[ClientSession, asyncio.Task] = dict()
|
self.sessions: Dict[ClientSession, asyncio.Task] = dict()
|
||||||
self.maintain_connections_task = None
|
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -193,31 +194,27 @@ class SessionPool:
|
||||||
def fastest_session(self):
|
def fastest_session(self):
|
||||||
if not self.available_sessions:
|
if not self.available_sessions:
|
||||||
return None
|
return None
|
||||||
return min([(session.latency, session) for session in self.available_sessions])[1]
|
return min([(session.latency, session) for session in self.available_sessions], key=itemgetter(0))[1]
|
||||||
|
|
||||||
def start(self, default_servers):
|
def start(self, default_servers):
|
||||||
for server in default_servers:
|
self.sessions = {
|
||||||
session = ClientSession(network=self.network, server=server)
|
ClientSession(network=self.network, server=server): None for server in default_servers
|
||||||
self.sessions[session] = asyncio.create_task(session.ensure_session())
|
}
|
||||||
self.maintain_connections_task = asyncio.create_task(self.ensure_connections())
|
self.ensure_connections()
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
if self.maintain_connections_task:
|
for session, task in self.sessions.items():
|
||||||
self.maintain_connections_task.cancel()
|
task.cancel()
|
||||||
self.maintain_connections_task = None
|
session.abort()
|
||||||
for session, maintenance_task in self.sessions.items():
|
|
||||||
maintenance_task.cancel()
|
|
||||||
if not session.is_closing():
|
|
||||||
session.abort()
|
|
||||||
self.sessions.clear()
|
self.sessions.clear()
|
||||||
|
|
||||||
async def ensure_connections(self):
|
def ensure_connections(self):
|
||||||
while True:
|
for session, task in list(self.sessions.items()):
|
||||||
log.info("Checking conns")
|
if not task or task.done():
|
||||||
for session, task in list(self.sessions.items()):
|
task = asyncio.create_task(session.ensure_session())
|
||||||
if task.done():
|
task.add_done_callback(lambda _: self.ensure_connections())
|
||||||
self.sessions[session] = asyncio.create_task(session.ensure_session())
|
self.sessions[session] = task
|
||||||
await asyncio.wait(self.sessions.items(), timeout=10)
|
|
||||||
|
|
||||||
async def wait_for_fastest_session(self):
|
async def wait_for_fastest_session(self):
|
||||||
while True:
|
while True:
|
||||||
|
|
Loading…
Reference in a new issue