connect to multiple servers

This commit is contained in:
Victor Shyba 2019-06-09 20:29:12 -03:00 committed by Lex Berezhny
parent be15601a67
commit 1b5281c23e
2 changed files with 90 additions and 28 deletions

View file

@ -2,7 +2,6 @@ import logging
import asyncio import asyncio
from unittest.mock import Mock from unittest.mock import Mock
from torba.client.baseledger import BaseLedger
from torba.client.basenetwork import BaseNetwork from torba.client.basenetwork import BaseNetwork
from torba.rpc import RPCSession from torba.rpc import RPCSession
from torba.testcase import IntegrationTestCase, AsyncioTestCase from torba.testcase import IntegrationTestCase, AsyncioTestCase
@ -70,3 +69,6 @@ class ServerPickingTestCase(AsyncioTestCase):
await asyncio.wait_for(network.on_connected.first, timeout=1) await asyncio.wait_for(network.on_connected.first, timeout=1)
self.assertTrue(network.is_connected) self.assertTrue(network.is_connected)
self.assertEqual(network.client.server, ('127.0.0.1', 1337)) self.assertEqual(network.client.server, ('127.0.0.1', 1337))
# ensure we are connected to all of them
self.assertEqual(len(network.session_pool.sessions), 4)
self.assertTrue(all([not session.is_closing() for session in network.session_pool.sessions]))

View file

@ -58,6 +58,7 @@ class BaseNetwork:
def __init__(self, ledger): def __init__(self, ledger):
self.config = ledger.config self.config = ledger.config
self.client: ClientSession = None self.client: ClientSession = None
self.session_pool: SessionPool = None
self.running = False self.running = False
self._on_connected_controller = StreamController() self._on_connected_controller = StreamController()
@ -74,40 +75,18 @@ class BaseNetwork:
'blockchain.address.subscribe': self._on_status_controller, 'blockchain.address.subscribe': self._on_status_controller,
} }
async def pick_fastest_server(self, timeout):
async def __probe(server):
client = ClientSession(network=self, server=server)
try:
await client.create_connection(timeout)
await client.send_request('server.banner')
return client
except (asyncio.TimeoutError, asyncio.CancelledError):
if not client.is_closing():
client.abort()
raise
except Exception: # pylint: disable=broad-except
log.exception("Connecting to %s:%d raised an exception:", *server)
futures = []
for server in self.config['default_servers']:
futures.append(__probe(server))
done, pending = await asyncio.wait(futures, return_when='FIRST_COMPLETED')
for task in pending:
task.cancel()
for client in done:
return await client
async def start(self): async def start(self):
self.running = True self.running = True
delay = 0.0
connect_timeout = self.config.get('connect_timeout', 6) connect_timeout = self.config.get('connect_timeout', 6)
self.session_pool = SessionPool(network=self, timeout=connect_timeout)
self.session_pool.start(self.config['default_servers'])
while True: while True:
try: try:
self.client = await self.pick_fastest_server(connect_timeout) self.client = await self.session_pool.pick_fastest_server()
if self.is_connected: if self.is_connected:
await self.ensure_server_version() await self.ensure_server_version()
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server) log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
self._on_connected_controller.add(True) self._on_connected_controller.add(True)
delay = 0.0
await self.client.on_disconnected.first await self.client.on_disconnected.first
except CancelledError: except CancelledError:
self.running = False self.running = False
@ -120,11 +99,11 @@ class BaseNetwork:
elif self.client: elif self.client:
await self.client.close() await self.client.close()
self.client.connection.cancel_pending_requests() self.client.connection.cancel_pending_requests()
await asyncio.sleep(delay)
delay = min(delay + 1.0, 10.0)
async def stop(self): async def stop(self):
self.running = False self.running = False
if self.session_pool:
self.session_pool.stop()
if self.is_connected: if self.is_connected:
disconnected = self.client.on_disconnected.first disconnected = self.client.on_disconnected.first
await self.client.close() await self.client.close()
@ -166,3 +145,84 @@ class BaseNetwork:
def subscribe_address(self, address): def subscribe_address(self, address):
return self.rpc('blockchain.address.subscribe', [address]) return self.rpc('blockchain.address.subscribe', [address])
class SessionPool:
def __init__(self, network: BaseNetwork, timeout: float):
self.network = network
self.sessions = []
self._dead_servers = []
self.maintain_connections_task = None
self.timeout = timeout
# triggered when the master server is out, to speed up reconnect
self._lost_master = asyncio.Event()
@property
def online(self):
for session in self.sessions:
if not session.is_closing():
return True
return False
def start(self, default_servers):
self.sessions = [
ClientSession(network=self.network, server=server)
for server in default_servers
]
self.maintain_connections_task = asyncio.create_task(self.ensure_connections())
def stop(self):
if self.maintain_connections_task:
self.maintain_connections_task.cancel()
for session in self.sessions:
if not session.is_closing():
session.abort()
self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None
async def ensure_connections(self):
while True:
await asyncio.gather(*[
self.ensure_connection(session)
for session in self.sessions
], return_exceptions=True)
await asyncio.wait([asyncio.sleep(3), self._lost_master.wait()], return_when='FIRST_COMPLETED')
self._lost_master.clear()
if not self.sessions:
self.sessions.extend(self._dead_servers)
self._dead_servers = []
async def ensure_connection(self, session):
if not session.is_closing():
return
try:
return await session.create_connection(self.timeout)
except asyncio.TimeoutError:
log.warning("Timeout connecting to %s:%d", *session.server)
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise
except Exception as err: # pylint: disable=broad-except
if 'Connect call failed' in str(err):
log.warning("Could not connect to %s:%d", *session.server)
else:
log.exception("Connecting to %s:%d raised an exception:", *session.server)
self._dead_servers.append(session)
self.sessions.remove(session)
async def pick_fastest_server(self):
self._lost_master.set()
while not self.online:
await asyncio.sleep(0.1)
async def _probe(session):
await session.send_request('server.banner')
return session
done, pending = await asyncio.wait([
_probe(session)
for session in self.sessions if not session.is_closing()
], return_when='FIRST_COMPLETED')
for task in pending:
task.cancel()
for session in done:
return await session