forked from LBRYCommunity/lbry-sdk
connect to multiple servers
This commit is contained in:
parent
be15601a67
commit
1b5281c23e
2 changed files with 90 additions and 28 deletions
|
@ -2,7 +2,6 @@ import logging
|
|||
import asyncio
|
||||
from unittest.mock import Mock
|
||||
|
||||
from torba.client.baseledger import BaseLedger
|
||||
from torba.client.basenetwork import BaseNetwork
|
||||
from torba.rpc import RPCSession
|
||||
from torba.testcase import IntegrationTestCase, AsyncioTestCase
|
||||
|
@ -70,3 +69,6 @@ class ServerPickingTestCase(AsyncioTestCase):
|
|||
await asyncio.wait_for(network.on_connected.first, timeout=1)
|
||||
self.assertTrue(network.is_connected)
|
||||
self.assertEqual(network.client.server, ('127.0.0.1', 1337))
|
||||
# ensure we are connected to all of them
|
||||
self.assertEqual(len(network.session_pool.sessions), 4)
|
||||
self.assertTrue(all([not session.is_closing() for session in network.session_pool.sessions]))
|
||||
|
|
|
@ -58,6 +58,7 @@ class BaseNetwork:
|
|||
def __init__(self, ledger):
|
||||
self.config = ledger.config
|
||||
self.client: ClientSession = None
|
||||
self.session_pool: SessionPool = None
|
||||
self.running = False
|
||||
|
||||
self._on_connected_controller = StreamController()
|
||||
|
@ -74,40 +75,18 @@ class BaseNetwork:
|
|||
'blockchain.address.subscribe': self._on_status_controller,
|
||||
}
|
||||
|
||||
async def pick_fastest_server(self, timeout):
|
||||
async def __probe(server):
|
||||
client = ClientSession(network=self, server=server)
|
||||
try:
|
||||
await client.create_connection(timeout)
|
||||
await client.send_request('server.banner')
|
||||
return client
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
if not client.is_closing():
|
||||
client.abort()
|
||||
raise
|
||||
except Exception: # pylint: disable=broad-except
|
||||
log.exception("Connecting to %s:%d raised an exception:", *server)
|
||||
futures = []
|
||||
for server in self.config['default_servers']:
|
||||
futures.append(__probe(server))
|
||||
done, pending = await asyncio.wait(futures, return_when='FIRST_COMPLETED')
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
for client in done:
|
||||
return await client
|
||||
|
||||
async def start(self):
|
||||
self.running = True
|
||||
delay = 0.0
|
||||
connect_timeout = self.config.get('connect_timeout', 6)
|
||||
self.session_pool = SessionPool(network=self, timeout=connect_timeout)
|
||||
self.session_pool.start(self.config['default_servers'])
|
||||
while True:
|
||||
try:
|
||||
self.client = await self.pick_fastest_server(connect_timeout)
|
||||
self.client = await self.session_pool.pick_fastest_server()
|
||||
if self.is_connected:
|
||||
await self.ensure_server_version()
|
||||
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
|
||||
self._on_connected_controller.add(True)
|
||||
delay = 0.0
|
||||
await self.client.on_disconnected.first
|
||||
except CancelledError:
|
||||
self.running = False
|
||||
|
@ -120,11 +99,11 @@ class BaseNetwork:
|
|||
elif self.client:
|
||||
await self.client.close()
|
||||
self.client.connection.cancel_pending_requests()
|
||||
await asyncio.sleep(delay)
|
||||
delay = min(delay + 1.0, 10.0)
|
||||
|
||||
async def stop(self):
|
||||
self.running = False
|
||||
if self.session_pool:
|
||||
self.session_pool.stop()
|
||||
if self.is_connected:
|
||||
disconnected = self.client.on_disconnected.first
|
||||
await self.client.close()
|
||||
|
@ -166,3 +145,84 @@ class BaseNetwork:
|
|||
|
||||
def subscribe_address(self, address):
|
||||
return self.rpc('blockchain.address.subscribe', [address])
|
||||
|
||||
|
||||
class SessionPool:
|
||||
|
||||
def __init__(self, network: BaseNetwork, timeout: float):
|
||||
self.network = network
|
||||
self.sessions = []
|
||||
self._dead_servers = []
|
||||
self.maintain_connections_task = None
|
||||
self.timeout = timeout
|
||||
# triggered when the master server is out, to speed up reconnect
|
||||
self._lost_master = asyncio.Event()
|
||||
|
||||
@property
|
||||
def online(self):
|
||||
for session in self.sessions:
|
||||
if not session.is_closing():
|
||||
return True
|
||||
return False
|
||||
|
||||
def start(self, default_servers):
|
||||
self.sessions = [
|
||||
ClientSession(network=self.network, server=server)
|
||||
for server in default_servers
|
||||
]
|
||||
self.maintain_connections_task = asyncio.create_task(self.ensure_connections())
|
||||
|
||||
def stop(self):
|
||||
if self.maintain_connections_task:
|
||||
self.maintain_connections_task.cancel()
|
||||
for session in self.sessions:
|
||||
if not session.is_closing():
|
||||
session.abort()
|
||||
self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None
|
||||
|
||||
async def ensure_connections(self):
|
||||
while True:
|
||||
await asyncio.gather(*[
|
||||
self.ensure_connection(session)
|
||||
for session in self.sessions
|
||||
], return_exceptions=True)
|
||||
await asyncio.wait([asyncio.sleep(3), self._lost_master.wait()], return_when='FIRST_COMPLETED')
|
||||
self._lost_master.clear()
|
||||
if not self.sessions:
|
||||
self.sessions.extend(self._dead_servers)
|
||||
self._dead_servers = []
|
||||
|
||||
async def ensure_connection(self, session):
|
||||
if not session.is_closing():
|
||||
return
|
||||
try:
|
||||
return await session.create_connection(self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
log.warning("Timeout connecting to %s:%d", *session.server)
|
||||
except asyncio.CancelledError: # pylint: disable=try-except-raise
|
||||
raise
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
if 'Connect call failed' in str(err):
|
||||
log.warning("Could not connect to %s:%d", *session.server)
|
||||
else:
|
||||
log.exception("Connecting to %s:%d raised an exception:", *session.server)
|
||||
self._dead_servers.append(session)
|
||||
self.sessions.remove(session)
|
||||
|
||||
async def pick_fastest_server(self):
|
||||
self._lost_master.set()
|
||||
while not self.online:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _probe(session):
|
||||
await session.send_request('server.banner')
|
||||
return session
|
||||
|
||||
done, pending = await asyncio.wait([
|
||||
_probe(session)
|
||||
for session in self.sessions if not session.is_closing()
|
||||
], return_when='FIRST_COMPLETED')
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
for session in done:
|
||||
return await session
|
||||
|
|
Loading…
Add table
Reference in a new issue