connect to multiple servers

This commit is contained in:
Victor Shyba 2019-06-09 20:29:12 -03:00 committed by Lex Berezhny
parent be15601a67
commit 1b5281c23e
2 changed files with 90 additions and 28 deletions

View file

@ -2,7 +2,6 @@ import logging
import asyncio
from unittest.mock import Mock
from torba.client.baseledger import BaseLedger
from torba.client.basenetwork import BaseNetwork
from torba.rpc import RPCSession
from torba.testcase import IntegrationTestCase, AsyncioTestCase
@ -70,3 +69,6 @@ class ServerPickingTestCase(AsyncioTestCase):
await asyncio.wait_for(network.on_connected.first, timeout=1)
self.assertTrue(network.is_connected)
self.assertEqual(network.client.server, ('127.0.0.1', 1337))
# ensure we are connected to all of them
self.assertEqual(len(network.session_pool.sessions), 4)
self.assertTrue(all([not session.is_closing() for session in network.session_pool.sessions]))

View file

@ -58,6 +58,7 @@ class BaseNetwork:
def __init__(self, ledger):
self.config = ledger.config
self.client: ClientSession = None
self.session_pool: SessionPool = None
self.running = False
self._on_connected_controller = StreamController()
@ -74,40 +75,18 @@ class BaseNetwork:
'blockchain.address.subscribe': self._on_status_controller,
}
async def pick_fastest_server(self, timeout):
async def __probe(server):
client = ClientSession(network=self, server=server)
try:
await client.create_connection(timeout)
await client.send_request('server.banner')
return client
except (asyncio.TimeoutError, asyncio.CancelledError):
if not client.is_closing():
client.abort()
raise
except Exception: # pylint: disable=broad-except
log.exception("Connecting to %s:%d raised an exception:", *server)
futures = []
for server in self.config['default_servers']:
futures.append(__probe(server))
done, pending = await asyncio.wait(futures, return_when='FIRST_COMPLETED')
for task in pending:
task.cancel()
for client in done:
return await client
async def start(self):
self.running = True
delay = 0.0
connect_timeout = self.config.get('connect_timeout', 6)
self.session_pool = SessionPool(network=self, timeout=connect_timeout)
self.session_pool.start(self.config['default_servers'])
while True:
try:
self.client = await self.pick_fastest_server(connect_timeout)
self.client = await self.session_pool.pick_fastest_server()
if self.is_connected:
await self.ensure_server_version()
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
self._on_connected_controller.add(True)
delay = 0.0
await self.client.on_disconnected.first
except CancelledError:
self.running = False
@ -120,11 +99,11 @@ class BaseNetwork:
elif self.client:
await self.client.close()
self.client.connection.cancel_pending_requests()
await asyncio.sleep(delay)
delay = min(delay + 1.0, 10.0)
async def stop(self):
self.running = False
if self.session_pool:
self.session_pool.stop()
if self.is_connected:
disconnected = self.client.on_disconnected.first
await self.client.close()
@ -166,3 +145,84 @@ class BaseNetwork:
def subscribe_address(self, address):
return self.rpc('blockchain.address.subscribe', [address])
class SessionPool:
def __init__(self, network: BaseNetwork, timeout: float):
self.network = network
self.sessions = []
self._dead_servers = []
self.maintain_connections_task = None
self.timeout = timeout
# triggered when the master server is out, to speed up reconnect
self._lost_master = asyncio.Event()
@property
def online(self):
for session in self.sessions:
if not session.is_closing():
return True
return False
def start(self, default_servers):
self.sessions = [
ClientSession(network=self.network, server=server)
for server in default_servers
]
self.maintain_connections_task = asyncio.create_task(self.ensure_connections())
def stop(self):
if self.maintain_connections_task:
self.maintain_connections_task.cancel()
for session in self.sessions:
if not session.is_closing():
session.abort()
self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None
async def ensure_connections(self):
while True:
await asyncio.gather(*[
self.ensure_connection(session)
for session in self.sessions
], return_exceptions=True)
await asyncio.wait([asyncio.sleep(3), self._lost_master.wait()], return_when='FIRST_COMPLETED')
self._lost_master.clear()
if not self.sessions:
self.sessions.extend(self._dead_servers)
self._dead_servers = []
async def ensure_connection(self, session):
if not session.is_closing():
return
try:
return await session.create_connection(self.timeout)
except asyncio.TimeoutError:
log.warning("Timeout connecting to %s:%d", *session.server)
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise
except Exception as err: # pylint: disable=broad-except
if 'Connect call failed' in str(err):
log.warning("Could not connect to %s:%d", *session.server)
else:
log.exception("Connecting to %s:%d raised an exception:", *session.server)
self._dead_servers.append(session)
self.sessions.remove(session)
async def pick_fastest_server(self):
self._lost_master.set()
while not self.online:
await asyncio.sleep(0.1)
async def _probe(session):
await session.send_request('server.banner')
return session
done, pending = await asyncio.wait([
_probe(session)
for session in self.sessions if not session.is_closing()
], return_when='FIRST_COMPLETED')
for task in pending:
task.cancel()
for session in done:
return await session