refactor basenetwork so each session takes care of itself
This commit is contained in:
parent
af1c8ec35c
commit
5728493abb
2 changed files with 76 additions and 103 deletions
|
@ -22,7 +22,6 @@ class TestSessionBloat(IntegrationTestCase):
|
||||||
await self.conductor.start_spv()
|
await self.conductor.start_spv()
|
||||||
session = ClientSession(network=None, server=self.ledger.network.client.server, timeout=0.2)
|
session = ClientSession(network=None, server=self.ledger.network.client.server, timeout=0.2)
|
||||||
await session.create_connection()
|
await session.create_connection()
|
||||||
session.ping_task.cancel()
|
|
||||||
await session.send_request('server.banner', ())
|
await session.send_request('server.banner', ())
|
||||||
self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1)
|
self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1)
|
||||||
self.assertFalse(session.is_closing())
|
self.assertFalse(session.is_closing())
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import CancelledError
|
from typing import Dict
|
||||||
from time import time
|
from time import time
|
||||||
from typing import List
|
|
||||||
import socket
|
|
||||||
|
|
||||||
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
|
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
|
||||||
|
|
||||||
|
@ -24,11 +22,20 @@ class ClientSession(BaseClientSession):
|
||||||
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32
|
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.max_seconds_idle = timeout * 2
|
self.max_seconds_idle = timeout * 2
|
||||||
self.ping_task = None
|
self.latency = 1 << 32
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available(self):
|
||||||
|
return not self.is_closing() and self._can_send.is_set()
|
||||||
|
|
||||||
async def send_request(self, method, args=()):
|
async def send_request(self, method, args=()):
|
||||||
try:
|
try:
|
||||||
return await asyncio.wait_for(super().send_request(method, args), timeout=self.timeout)
|
start = time()
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
super().send_request(method, args), timeout=self.timeout
|
||||||
|
)
|
||||||
|
self.latency = time() - start
|
||||||
|
return result
|
||||||
except RPCError as e:
|
except RPCError as e:
|
||||||
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
|
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
|
||||||
raise e
|
raise e
|
||||||
|
@ -36,21 +43,29 @@ class ClientSession(BaseClientSession):
|
||||||
self.abort()
|
self.abort()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def ping_forever(self):
|
async def ensure_session(self):
|
||||||
|
# Handles reconnecting and maintaining a session alive
|
||||||
# TODO: change to 'ping' on newer protocol (above 1.2)
|
# TODO: change to 'ping' on newer protocol (above 1.2)
|
||||||
while not self.is_closing():
|
retry_delay = 1.0
|
||||||
if (time() - self.last_send) > self.max_seconds_idle:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
if self.is_closing():
|
||||||
|
await self.create_connection(self.timeout)
|
||||||
|
await self.ensure_server_version()
|
||||||
|
if (time() - self.last_send) > self.max_seconds_idle:
|
||||||
await self.send_request('server.banner')
|
await self.send_request('server.banner')
|
||||||
except:
|
retry_delay = 1.0
|
||||||
self.abort()
|
except asyncio.TimeoutError:
|
||||||
raise
|
log.warning("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
|
||||||
await asyncio.sleep(self.max_seconds_idle//3)
|
retry_delay = max(60, retry_delay * 2)
|
||||||
|
await asyncio.sleep(retry_delay)
|
||||||
|
|
||||||
|
def ensure_server_version(self, required='1.2'):
|
||||||
|
return self.send_request('server.version', [__version__, required])
|
||||||
|
|
||||||
async def create_connection(self, timeout=6):
|
async def create_connection(self, timeout=6):
|
||||||
connector = Connector(lambda: self, *self.server)
|
connector = Connector(lambda: self, *self.server)
|
||||||
await asyncio.wait_for(connector.create_connection(), timeout=timeout)
|
await asyncio.wait_for(connector.create_connection(), timeout=timeout)
|
||||||
self.ping_task = asyncio.create_task(self.ping_forever())
|
|
||||||
|
|
||||||
async def handle_request(self, request):
|
async def handle_request(self, request):
|
||||||
controller = self.network.subscription_controllers[request.method]
|
controller = self.network.subscription_controllers[request.method]
|
||||||
|
@ -58,9 +73,8 @@ class ClientSession(BaseClientSession):
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
def connection_lost(self, exc):
|
||||||
super().connection_lost(exc)
|
super().connection_lost(exc)
|
||||||
|
self.latency = 1 << 32
|
||||||
self._on_disconnect_controller.add(True)
|
self._on_disconnect_controller.add(True)
|
||||||
if self.ping_task:
|
|
||||||
self.ping_task.cancel()
|
|
||||||
|
|
||||||
|
|
||||||
class BaseNetwork:
|
class BaseNetwork:
|
||||||
|
@ -92,26 +106,20 @@ class BaseNetwork:
|
||||||
self.session_pool = SessionPool(network=self, timeout=connect_timeout)
|
self.session_pool = SessionPool(network=self, timeout=connect_timeout)
|
||||||
self.session_pool.start(self.config['default_servers'])
|
self.session_pool.start(self.config['default_servers'])
|
||||||
self.on_header.listen(self._update_remote_height)
|
self.on_header.listen(self._update_remote_height)
|
||||||
while True:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
self.client = await self.pick_fastest_session()
|
self.client = await self.session_pool.wait_for_fastest_session()
|
||||||
if self.is_connected:
|
self._update_remote_height((await self.subscribe_headers(),))
|
||||||
await self.ensure_server_version()
|
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
|
||||||
self._update_remote_height((await self.subscribe_headers(),))
|
self._on_connected_controller.add(True)
|
||||||
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
|
await self.client.on_disconnected.first
|
||||||
self._on_connected_controller.add(True)
|
except asyncio.CancelledError:
|
||||||
await self.client.on_disconnected.first
|
await self.stop()
|
||||||
except CancelledError:
|
raise
|
||||||
self.running = False
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
log.warning("Timed out while trying to find a server!")
|
pass
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
log.exception("Exception while trying to find a server!")
|
log.exception("Exception while trying to find a server!")
|
||||||
if not self.running:
|
|
||||||
return
|
|
||||||
elif self.client:
|
|
||||||
await self.client.close()
|
|
||||||
self.client.connection.cancel_pending_requests()
|
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
self.running = False
|
self.running = False
|
||||||
|
@ -124,25 +132,15 @@ class BaseNetwork:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self):
|
def is_connected(self):
|
||||||
return self.client is not None and not self.client.is_closing()
|
return self.session_pool.online
|
||||||
|
|
||||||
def rpc(self, list_or_method, args):
|
async def rpc(self, list_or_method, args):
|
||||||
if self.is_connected:
|
if self.is_connected:
|
||||||
return self.client.send_request(list_or_method, args)
|
await self.session_pool.wait_for_fastest_session()
|
||||||
|
return await self.session_pool.fastest_session.send_request(list_or_method, args)
|
||||||
else:
|
else:
|
||||||
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
raise ConnectionError("Attempting to send rpc request when connection is not available.")
|
||||||
|
|
||||||
async def pick_fastest_session(self):
|
|
||||||
sessions = await self.session_pool.get_online_sessions()
|
|
||||||
done, pending = await asyncio.wait([
|
|
||||||
self.probe_session(session)
|
|
||||||
for session in sessions if not session.is_closing()
|
|
||||||
], return_when='FIRST_COMPLETED')
|
|
||||||
for task in pending:
|
|
||||||
task.cancel()
|
|
||||||
for session in done:
|
|
||||||
return await session
|
|
||||||
|
|
||||||
async def probe_session(self, session: ClientSession):
|
async def probe_session(self, session: ClientSession):
|
||||||
await session.send_request('server.banner')
|
await session.send_request('server.banner')
|
||||||
return session
|
return session
|
||||||
|
@ -150,9 +148,6 @@ class BaseNetwork:
|
||||||
def _update_remote_height(self, header_args):
|
def _update_remote_height(self, header_args):
|
||||||
self.remote_height = header_args[0]["height"]
|
self.remote_height = header_args[0]["height"]
|
||||||
|
|
||||||
def ensure_server_version(self, required='1.2'):
|
|
||||||
return self.rpc('server.version', [__version__, required])
|
|
||||||
|
|
||||||
def broadcast(self, raw_transaction):
|
def broadcast(self, raw_transaction):
|
||||||
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
|
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
|
||||||
|
|
||||||
|
@ -172,83 +167,62 @@ class BaseNetwork:
|
||||||
return self.rpc('blockchain.block.headers', [height, count])
|
return self.rpc('blockchain.block.headers', [height, count])
|
||||||
|
|
||||||
def subscribe_headers(self):
|
def subscribe_headers(self):
|
||||||
return self.rpc('blockchain.headers.subscribe', [True])
|
return self.client.send_request('blockchain.headers.subscribe', [True])
|
||||||
|
|
||||||
def subscribe_address(self, address):
|
def subscribe_address(self, address):
|
||||||
return self.rpc('blockchain.address.subscribe', [address])
|
return self.client.send_request('blockchain.address.subscribe', [address])
|
||||||
|
|
||||||
|
|
||||||
class SessionPool:
|
class SessionPool:
|
||||||
|
|
||||||
def __init__(self, network: BaseNetwork, timeout: float):
|
def __init__(self, network: BaseNetwork, timeout: float):
|
||||||
self.network = network
|
self.network = network
|
||||||
self.sessions: List[ClientSession] = []
|
self.sessions: Dict[ClientSession, asyncio.Task] = dict()
|
||||||
self._dead_servers: List[ClientSession] = []
|
|
||||||
self.maintain_connections_task = None
|
self.maintain_connections_task = None
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
# triggered when the master server is out, to speed up reconnect
|
|
||||||
self._lost_master = asyncio.Event()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def online(self):
|
def online(self):
|
||||||
for session in self.sessions:
|
return any(not session.is_closing() for session in self.sessions)
|
||||||
if not session.is_closing():
|
|
||||||
return True
|
@property
|
||||||
return False
|
def available_sessions(self):
|
||||||
|
return [session for session in self.sessions if session.available]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fastest_session(self):
|
||||||
|
if not self.available_sessions:
|
||||||
|
return None
|
||||||
|
return min([(session.latency, session) for session in self.available_sessions])[1]
|
||||||
|
|
||||||
def start(self, default_servers):
|
def start(self, default_servers):
|
||||||
self.sessions = [
|
for server in default_servers:
|
||||||
ClientSession(network=self.network, server=server)
|
session = ClientSession(network=self.network, server=server)
|
||||||
for server in default_servers
|
self.sessions[session] = asyncio.create_task(session.ensure_session())
|
||||||
]
|
|
||||||
self.maintain_connections_task = asyncio.create_task(self.ensure_connections())
|
self.maintain_connections_task = asyncio.create_task(self.ensure_connections())
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
if self.maintain_connections_task:
|
if self.maintain_connections_task:
|
||||||
self.maintain_connections_task.cancel()
|
self.maintain_connections_task.cancel()
|
||||||
for session in self.sessions:
|
self.maintain_connections_task = None
|
||||||
|
for session, maintenance_task in self.sessions.items():
|
||||||
|
maintenance_task.cancel()
|
||||||
if not session.is_closing():
|
if not session.is_closing():
|
||||||
session.abort()
|
session.abort()
|
||||||
self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None
|
self.sessions.clear()
|
||||||
|
|
||||||
async def ensure_connections(self):
|
async def ensure_connections(self):
|
||||||
while True:
|
while True:
|
||||||
await asyncio.gather(*[
|
log.info("Checking conns")
|
||||||
self.ensure_connection(session)
|
for session, task in list(self.sessions.items()):
|
||||||
for session in self.sessions
|
if task.done():
|
||||||
], return_exceptions=True)
|
self.sessions[session] = asyncio.create_task(session.ensure_session())
|
||||||
try:
|
await asyncio.wait(self.sessions.items(), timeout=10)
|
||||||
await asyncio.wait_for(self._lost_master.wait(), timeout=3)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
pass
|
|
||||||
self._lost_master.clear()
|
|
||||||
if not self.sessions:
|
|
||||||
self.sessions.extend(self._dead_servers)
|
|
||||||
self._dead_servers = []
|
|
||||||
|
|
||||||
async def ensure_connection(self, session):
|
async def wait_for_fastest_session(self):
|
||||||
self._dead_servers.append(session)
|
while True:
|
||||||
self.sessions.remove(session)
|
fastest = self.fastest_session
|
||||||
try:
|
if fastest:
|
||||||
if session.is_closing():
|
return fastest
|
||||||
await session.create_connection(self.timeout)
|
|
||||||
await asyncio.wait_for(session.send_request('server.banner'), timeout=self.timeout)
|
|
||||||
self.sessions.append(session)
|
|
||||||
self._dead_servers.remove(session)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
log.warning("Timeout connecting to %s:%d", *session.server)
|
|
||||||
except asyncio.CancelledError: # pylint: disable=try-except-raise
|
|
||||||
raise
|
|
||||||
except socket.gaierror:
|
|
||||||
log.warning("Could not resolve IP for %s", session.server[0])
|
|
||||||
except Exception as err: # pylint: disable=broad-except
|
|
||||||
if 'Connect call failed' in str(err):
|
|
||||||
log.warning("Could not connect to %s:%d", *session.server)
|
|
||||||
else:
|
else:
|
||||||
log.exception("Connecting to %s:%d raised an exception:", *session.server)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
async def get_online_sessions(self):
|
|
||||||
while not self.online:
|
|
||||||
self._lost_master.set()
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
return self.sessions
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue