lbry-sdk/torba/client/basenetwork.py
2019-06-10 23:25:38 -04:00

229 lines
8.4 KiB
Python

import logging
import asyncio
from asyncio import CancelledError
from time import time
from typing import Iterable
from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError
from torba import __version__
from torba.stream import StreamController
log = logging.getLogger(__name__)
class ClientSession(BaseClientSession):
def __init__(self, *args, network, server, **kwargs):
self.network = network
self.server = server
super().__init__(*args, **kwargs)
self._on_disconnect_controller = StreamController()
self.on_disconnected = self._on_disconnect_controller.stream
self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32
self.max_seconds_idle = 60
self.ping_task = None
async def send_request(self, method, args=()):
try:
return await super().send_request(method, args)
except RPCError as e:
log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args)
raise e
async def ping_forever(self):
# TODO: change to 'ping' on newer protocol (above 1.2)
while not self.is_closing():
if (time() - self.last_send) > self.max_seconds_idle:
await self.send_request('server.banner')
await asyncio.sleep(self.max_seconds_idle//3)
async def create_connection(self, timeout=6):
connector = Connector(lambda: self, *self.server)
await asyncio.wait_for(connector.create_connection(), timeout=timeout)
self.ping_task = asyncio.create_task(self.ping_forever())
async def handle_request(self, request):
controller = self.network.subscription_controllers[request.method]
controller.add(request.args)
def connection_lost(self, exc):
super().connection_lost(exc)
self._on_disconnect_controller.add(True)
if self.ping_task:
self.ping_task.cancel()
class BaseNetwork:
def __init__(self, ledger):
self.config = ledger.config
self.client: ClientSession = None
self.session_pool: SessionPool = None
self.running = False
self._on_connected_controller = StreamController()
self.on_connected = self._on_connected_controller.stream
self._on_header_controller = StreamController()
self.on_header = self._on_header_controller.stream
self._on_status_controller = StreamController()
self.on_status = self._on_status_controller.stream
self.subscription_controllers = {
'blockchain.headers.subscribe': self._on_header_controller,
'blockchain.address.subscribe': self._on_status_controller,
}
async def start(self):
self.running = True
connect_timeout = self.config.get('connect_timeout', 6)
self.session_pool = SessionPool(network=self, timeout=connect_timeout)
self.session_pool.start(self.config['default_servers'])
while True:
try:
self.client = await self.session_pool.pick_fastest_server()
if self.is_connected:
await self.ensure_server_version()
log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server)
self._on_connected_controller.add(True)
await self.client.on_disconnected.first
except CancelledError:
self.running = False
except asyncio.TimeoutError:
log.warning("Timed out while trying to find a server!")
except Exception: # pylint: disable=broad-except
log.exception("Exception while trying to find a server!")
if not self.running:
return
elif self.client:
await self.client.close()
self.client.connection.cancel_pending_requests()
async def stop(self):
self.running = False
if self.session_pool:
self.session_pool.stop()
if self.is_connected:
disconnected = self.client.on_disconnected.first
await self.client.close()
await disconnected
@property
def is_connected(self):
return self.client is not None and not self.client.is_closing()
def rpc(self, list_or_method, args):
if self.is_connected:
return self.client.send_request(list_or_method, args)
else:
raise ConnectionError("Attempting to send rpc request when connection is not available.")
def ensure_server_version(self, required='1.2'):
return self.rpc('server.version', [__version__, required])
def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction])
def get_history(self, address):
return self.rpc('blockchain.address.get_history', [address])
def get_transaction(self, tx_hash):
return self.rpc('blockchain.transaction.get', [tx_hash])
def get_transaction_height(self, tx_hash):
return self.rpc('blockchain.transaction.get_height', [tx_hash])
def get_merkle(self, tx_hash, height):
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height])
def get_headers(self, height, count=10000):
return self.rpc('blockchain.block.headers', [height, count])
def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True])
def subscribe_address(self, address):
return self.rpc('blockchain.address.subscribe', [address])
class SessionPool:
def __init__(self, network: BaseNetwork, timeout: float):
self.network = network
self.sessions: Iterable[ClientSession] = []
self._dead_servers: Iterable[ClientSession] = []
self.maintain_connections_task = None
self.timeout = timeout
# triggered when the master server is out, to speed up reconnect
self._lost_master = asyncio.Event()
@property
def online(self):
for session in self.sessions:
if not session.is_closing():
return True
return False
def start(self, default_servers):
self.sessions = [
ClientSession(network=self.network, server=server)
for server in default_servers
]
self.maintain_connections_task = asyncio.create_task(self.ensure_connections())
def stop(self):
if self.maintain_connections_task:
self.maintain_connections_task.cancel()
for session in self.sessions:
if not session.is_closing():
session.abort()
self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None
async def ensure_connections(self):
while True:
await asyncio.gather(*[
self.ensure_connection(session)
for session in self.sessions
], return_exceptions=True)
await asyncio.wait([asyncio.sleep(3), self._lost_master.wait()], return_when='FIRST_COMPLETED')
self._lost_master.clear()
if not self.sessions:
self.sessions.extend(self._dead_servers)
self._dead_servers = []
async def ensure_connection(self, session):
if not session.is_closing():
return
try:
return await session.create_connection(self.timeout)
except asyncio.TimeoutError:
log.warning("Timeout connecting to %s:%d", *session.server)
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise
except Exception as err: # pylint: disable=broad-except
if 'Connect call failed' in str(err):
log.warning("Could not connect to %s:%d", *session.server)
else:
log.exception("Connecting to %s:%d raised an exception:", *session.server)
self._dead_servers.append(session)
self.sessions.remove(session)
async def pick_fastest_server(self):
self._lost_master.set()
while not self.online:
await asyncio.sleep(0.1)
async def _probe(session):
await session.send_request('server.banner')
return session
done, pending = await asyncio.wait([
_probe(session)
for session in self.sessions if not session.is_closing()
], return_when='FIRST_COMPLETED')
for task in pending:
task.cancel()
for session in done:
return await session