refactor pick_fastest_session

This commit is contained in:
Victor Shyba 2019-06-10 14:53:47 -03:00 committed by Lex Berezhny
parent 9e24f4ca54
commit b75139662b

View file

@ -2,7 +2,7 @@ import logging
import asyncio import asyncio
from asyncio import CancelledError from asyncio import CancelledError
from time import time from time import time
from typing import Iterable from typing import List
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
@ -83,7 +83,7 @@ class BaseNetwork:
self.session_pool.start(self.config['default_servers']) self.session_pool.start(self.config['default_servers'])
while True: while True:
try: try:
self.client = await self.session_pool.pick_fastest_server() self.client = await self.pick_fastest_session()
if self.is_connected: if self.is_connected:
await self.ensure_server_version() await self.ensure_server_version()
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server) log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
@ -120,6 +120,21 @@ class BaseNetwork:
else: else:
raise ConnectionError("Attempting to send rpc request when connection is not available.") raise ConnectionError("Attempting to send rpc request when connection is not available.")
async def pick_fastest_session(self):
sessions = await self.session_pool.get_online_sessions()
done, pending = await asyncio.wait([
self.probe_session(session)
for session in sessions if not session.is_closing()
], return_when='FIRST_COMPLETED')
for task in pending:
task.cancel()
for session in done:
return await session
async def probe_session(self, session: ClientSession):
await session.send_request('server.banner')
return session
def ensure_server_version(self, required='1.2'): def ensure_server_version(self, required='1.2'):
return self.rpc('server.version', [__version__, required]) return self.rpc('server.version', [__version__, required])
@ -152,8 +167,8 @@ class SessionPool:
def __init__(self, network: BaseNetwork, timeout: float): def __init__(self, network: BaseNetwork, timeout: float):
self.network = network self.network = network
self.sessions: Iterable[ClientSession] = [] self.sessions: List[ClientSession] = []
self._dead_servers: Iterable[ClientSession] = [] self._dead_servers: List[ClientSession] = []
self.maintain_connections_task = None self.maintain_connections_task = None
self.timeout = timeout self.timeout = timeout
# triggered when the master server is out, to speed up reconnect # triggered when the master server is out, to speed up reconnect
@ -210,20 +225,8 @@ class SessionPool:
self._dead_servers.append(session) self._dead_servers.append(session)
self.sessions.remove(session) self.sessions.remove(session)
async def pick_fastest_server(self): async def get_online_sessions(self):
self._lost_master.set() self._lost_master.set()
while not self.online: while not self.online:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
return self.sessions
async def _probe(session):
await session.send_request('server.banner')
return session
done, pending = await asyncio.wait([
_probe(session)
for session in self.sessions if not session.is_closing()
], return_when='FIRST_COMPLETED')
for task in pending:
task.cancel()
for session in done:
return await session