refactor basenetwork so each session takes care of itself

This commit is contained in:
Victor Shyba 2019-08-04 19:09:40 -03:00
parent af1c8ec35c
commit 5728493abb
2 changed files with 76 additions and 103 deletions

View file

@ -22,7 +22,6 @@ class TestSessionBloat(IntegrationTestCase):
await self.conductor.start_spv() await self.conductor.start_spv()
session = ClientSession(network=None, server=self.ledger.network.client.server, timeout=0.2) session = ClientSession(network=None, server=self.ledger.network.client.server, timeout=0.2)
await session.create_connection() await session.create_connection()
session.ping_task.cancel()
await session.send_request('server.banner', ()) await session.send_request('server.banner', ())
self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1) self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1)
self.assertFalse(session.is_closing()) self.assertFalse(session.is_closing())

View file

@ -1,9 +1,7 @@
import logging import logging
import asyncio import asyncio
from asyncio import CancelledError from typing import Dict
from time import time from time import time
from typing import List
import socket
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
@ -24,11 +22,20 @@ class ClientSession(BaseClientSession):
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32 self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32
self.timeout = timeout self.timeout = timeout
self.max_seconds_idle = timeout * 2 self.max_seconds_idle = timeout * 2
self.ping_task = None self.latency = 1 << 32
@property
def available(self):
return not self.is_closing() and self._can_send.is_set()
async def send_request(self, method, args=()): async def send_request(self, method, args=()):
try: try:
return await asyncio.wait_for(super().send_request(method, args), timeout=self.timeout) start = time()
result = await asyncio.wait_for(
super().send_request(method, args), timeout=self.timeout
)
self.latency = time() - start
return result
except RPCError as e: except RPCError as e:
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args) log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
raise e raise e
@ -36,21 +43,29 @@ class ClientSession(BaseClientSession):
self.abort() self.abort()
raise raise
async def ensure_session(self):
    """Keep this session connected and alive until the task is cancelled.

    Each pass of the loop:

    * reconnects when the session is closing, via ``create_connection``
      followed by ``ensure_server_version``;
    * sends a ``server.banner`` request as a keep-alive when nothing has
      been sent for more than ``max_seconds_idle`` seconds;
    * sleeps ``retry_delay`` seconds before the next pass.

    A pass that completes without timing out resets the delay to one
    second. An ``asyncio.TimeoutError`` doubles the delay, capped at 60
    seconds. Any other exception propagates and ends the coroutine.
    """
    # Handles reconnecting and maintaining a session alive
    # TODO: change to 'ping' on newer protocol (above 1.2)
    retry_delay = 1.0
    while True:
        try:
            if self.is_closing():
                await self.create_connection(self.timeout)
                await self.ensure_server_version()
            if (time() - self.last_send) > self.max_seconds_idle:
                await self.send_request('server.banner')
            retry_delay = 1.0
        except asyncio.TimeoutError:
            # Exponential backoff with a 60 second ceiling. This must be
            # min(): max() would jump to 60s on the first timeout and then
            # keep doubling with no upper bound.
            retry_delay = min(60, retry_delay * 2)
            # Log after updating so the message matches the sleep below.
            log.warning("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
        await asyncio.sleep(retry_delay)
def ensure_server_version(self, required='1.2'):
return self.send_request('server.version', [__version__, required])
async def create_connection(self, timeout=6): async def create_connection(self, timeout=6):
connector = Connector(lambda: self, *self.server) connector = Connector(lambda: self, *self.server)
await asyncio.wait_for(connector.create_connection(), timeout=timeout) await asyncio.wait_for(connector.create_connection(), timeout=timeout)
self.ping_task = asyncio.create_task(self.ping_forever())
async def handle_request(self, request): async def handle_request(self, request):
controller = self.network.subscription_controllers[request.method] controller = self.network.subscription_controllers[request.method]
@ -58,9 +73,8 @@ class ClientSession(BaseClientSession):
def connection_lost(self, exc): def connection_lost(self, exc):
super().connection_lost(exc) super().connection_lost(exc)
self.latency = 1 << 32
self._on_disconnect_controller.add(True) self._on_disconnect_controller.add(True)
if self.ping_task:
self.ping_task.cancel()
class BaseNetwork: class BaseNetwork:
@ -92,26 +106,20 @@ class BaseNetwork:
self.session_pool = SessionPool(network=self, timeout=connect_timeout) self.session_pool = SessionPool(network=self, timeout=connect_timeout)
self.session_pool.start(self.config['default_servers']) self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height) self.on_header.listen(self._update_remote_height)
while True: while self.running:
try: try:
self.client = await self.pick_fastest_session() self.client = await self.session_pool.wait_for_fastest_session()
if self.is_connected:
await self.ensure_server_version()
self._update_remote_height((await self.subscribe_headers(),)) self._update_remote_height((await self.subscribe_headers(),))
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server) log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
self._on_connected_controller.add(True) self._on_connected_controller.add(True)
await self.client.on_disconnected.first await self.client.on_disconnected.first
except CancelledError: except asyncio.CancelledError:
self.running = False await self.stop()
raise
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.warning("Timed out while trying to find a server!") pass
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
log.exception("Exception while trying to find a server!") log.exception("Exception while trying to find a server!")
if not self.running:
return
elif self.client:
await self.client.close()
self.client.connection.cancel_pending_requests()
async def stop(self): async def stop(self):
self.running = False self.running = False
@ -124,25 +132,15 @@ class BaseNetwork:
@property @property
def is_connected(self): def is_connected(self):
return self.client is not None and not self.client.is_closing() return self.session_pool.online
async def rpc(self, list_or_method, args):
    """Send a request through the fastest available session.

    :param list_or_method: forwarded unchanged as the first argument of
        the session's ``send_request``.
    :param args: forwarded unchanged as the second argument.
    :return: whatever the session's ``send_request`` returns.
    :raises ConnectionError: when ``is_connected`` is false.
    """
    if not self.is_connected:
        raise ConnectionError("Attempting to send rpc request when connection is not available.")
    # Use the session that was actually awaited instead of looking up
    # ``fastest_session`` a second time, which may be ``None``.
    session = await self.session_pool.wait_for_fastest_session()
    return await session.send_request(list_or_method, args)
async def pick_fastest_session(self):
sessions = await self.session_pool.get_online_sessions()
done, pending = await asyncio.wait([
self.probe_session(session)
for session in sessions if not session.is_closing()
], return_when='FIRST_COMPLETED')
for task in pending:
task.cancel()
for session in done:
return await session
async def probe_session(self, session: ClientSession): async def probe_session(self, session: ClientSession):
await session.send_request('server.banner') await session.send_request('server.banner')
return session return session
@ -150,9 +148,6 @@ class BaseNetwork:
def _update_remote_height(self, header_args): def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"] self.remote_height = header_args[0]["height"]
def ensure_server_version(self, required='1.2'):
return self.rpc('server.version', [__version__, required])
def broadcast(self, raw_transaction): def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction]) return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
@ -172,83 +167,62 @@ class BaseNetwork:
return self.rpc('blockchain.block.headers', [height, count]) return self.rpc('blockchain.block.headers', [height, count])
def subscribe_headers(self): def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True]) return self.client.send_request('blockchain.headers.subscribe', [True])
def subscribe_address(self, address): def subscribe_address(self, address):
return self.rpc('blockchain.address.subscribe', [address]) return self.client.send_request('blockchain.address.subscribe', [address])
class SessionPool:
    """Holds one ``ClientSession`` per server together with its keep-alive task.

    Each session runs its own ``ensure_session`` coroutine as a task; the
    pool only restarts tasks that have finished and answers queries about
    which sessions are usable.
    """

    def __init__(self, network: 'BaseNetwork', timeout: float):
        self.network = network
        # Maps each session to the task running its ``ensure_session`` loop.
        self.sessions: Dict[ClientSession, asyncio.Task] = dict()
        self.maintain_connections_task = None
        self.timeout = timeout

    @property
    def online(self):
        """True when at least one session is not closing."""
        return any(not session.is_closing() for session in self.sessions)

    @property
    def available_sessions(self):
        """List of the sessions whose ``available`` attribute is true."""
        return [session for session in self.sessions if session.available]

    @property
    def fastest_session(self):
        """The available session with the lowest ``latency``, or ``None``.

        Ordering uses ``latency`` alone. Comparing ``(latency, session)``
        tuples would fall through to comparing the session objects whenever
        two latencies are equal. On a tie the first session in iteration
        order is returned.
        """
        return min(
            self.available_sessions,
            key=lambda session: session.latency,
            default=None,
        )

    def start(self, default_servers):
        """Create a session and keep-alive task per server, then start the watchdog."""
        for server in default_servers:
            session = ClientSession(network=self.network, server=server)
            self.sessions[session] = asyncio.create_task(session.ensure_session())
        self.maintain_connections_task = asyncio.create_task(self.ensure_connections())

    def stop(self):
        """Cancel every task, abort sessions that are still open and empty the pool."""
        if self.maintain_connections_task:
            self.maintain_connections_task.cancel()
            self.maintain_connections_task = None
        for session, maintenance_task in self.sessions.items():
            maintenance_task.cancel()
            if not session.is_closing():
                session.abort()
        self.sessions.clear()

    async def ensure_connections(self):
        """Restart any session task that has finished, checking about every 10 seconds."""
        while True:
            log.info("Checking conns")
            for session, task in list(self.sessions.items()):
                if task.done():
                    self.sessions[session] = asyncio.create_task(session.ensure_session())
            # asyncio.wait needs the tasks themselves; the dict items are
            # (session, task) tuples and are not awaitable.
            tasks = list(self.sessions.values())
            if not tasks:
                # asyncio.wait rejects an empty collection.
                await asyncio.sleep(10)
                continue
            _, pending = await asyncio.wait(tasks, timeout=10)
            if not pending:
                # Every task ended before the timeout; pause briefly so
                # sessions that fail immediately are not restarted in a
                # tight loop.
                await asyncio.sleep(1)

    async def wait_for_fastest_session(self):
        """Poll every half second until a session is available, then return the fastest."""
        while True:
            fastest = self.fastest_session
            if fastest is not None:
                return fastest
            await asyncio.sleep(0.5)