import logging import asyncio from time import perf_counter from operator import itemgetter from typing import Dict, Optional, Tuple from lbry import __version__ from lbry.error import IncompatibleWalletServerError from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError from lbry.wallet.stream import StreamController log = logging.getLogger(__name__) class ClientSession(BaseClientSession): def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs): self.network = network self.server = server super().__init__(*args, **kwargs) self._on_disconnect_controller = StreamController() self.on_disconnected = self._on_disconnect_controller.stream self.framer.max_size = self.max_errors = 1 << 32 self.timeout = timeout self.max_seconds_idle = timeout * 2 self.response_time: Optional[float] = None self.connection_latency: Optional[float] = None self._response_samples = 0 self.pending_amount = 0 self._on_connect_cb = on_connect_callback or (lambda: None) self.trigger_urgent_reconnect = asyncio.Event() @property def available(self): return not self.is_closing() and self.response_time is not None @property def server_address_and_port(self) -> Optional[Tuple[str, int]]: if not self.transport: return None return self.transport.get_extra_info('peername') async def send_timed_server_version_request(self, args=(), timeout=None): timeout = timeout or self.timeout log.debug("send version request to %s:%i", *self.server) start = perf_counter() result = await asyncio.wait_for( super().send_request('server.version', args), timeout=timeout ) current_response_time = perf_counter() - start response_sum = (self.response_time or 0) * self._response_samples + current_response_time self.response_time = response_sum / (self._response_samples + 1) self._response_samples += 1 return result async def send_request(self, method, args=()): self.pending_amount += 1 log.debug("send %s to %s:%i", method, *self.server) try: if method == 'server.version': return await self.send_timed_server_version_request(args, self.timeout) request = asyncio.ensure_future(super().send_request(method, args)) while not request.done(): done, pending = await asyncio.wait([request], timeout=self.timeout) if pending: log.debug("Time since last packet: %s", perf_counter() - self.last_packet_received) if (perf_counter() - self.last_packet_received) < self.timeout: continue log.info("timeout sending %s to %s:%i", method, *self.server) raise asyncio.TimeoutError if done: return request.result() except (RPCError, ProtocolError) as e: log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s", *self.server, *e.args) raise e except ConnectionError: log.warning("connection to %s:%i lost", *self.server) self.synchronous_close() raise except asyncio.CancelledError: log.info("cancelled sending %s to %s:%i", method, *self.server) self.synchronous_close() raise finally: self.pending_amount -= 1 async def ensure_session(self): # Handles reconnecting and maintaining a session alive # TODO: change to 'ping' on newer protocol (above 1.2) retry_delay = default_delay = 1.0 while True: try: if self.is_closing(): await self.create_connection(self.timeout) await self.ensure_server_version() self._on_connect_cb() if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None: await self.ensure_server_version() retry_delay = default_delay except RPCError as e: await self.close() log.debug("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message) retry_delay = 60 * 60 except IncompatibleWalletServerError: await self.close() retry_delay = 60 * 60 log.debug("Wallet server has an incompatible version, retrying in 1h: %s:%d", *self.server) except (asyncio.TimeoutError, OSError): await self.close() retry_delay = min(60, retry_delay * 2) log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server) try: await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay) except asyncio.TimeoutError: pass finally: self.trigger_urgent_reconnect.clear() async def ensure_server_version(self, required=None, timeout=3): required = required or self.network.PROTOCOL_VERSION response = await asyncio.wait_for( self.send_request('server.version', [__version__, required]), timeout=timeout ) if tuple(int(piece) for piece in response[0].split(".")) < self.network.MINIMUM_REQUIRED: raise IncompatibleWalletServerError(*self.server) return response async def create_connection(self, timeout=6): connector = Connector(lambda: self, *self.server) start = perf_counter() await asyncio.wait_for(connector.create_connection(), timeout=timeout) self.connection_latency = perf_counter() - start async def handle_request(self, request): controller = self.network.subscription_controllers[request.method] controller.add(request.args) def connection_lost(self, exc): log.debug("Connection lost: %s:%d", *self.server) super().connection_lost(exc) self.response_time = None self.connection_latency = None self._response_samples = 0 self.pending_amount = 0 self._on_disconnect_controller.add(True) class Network: PROTOCOL_VERSION = __version__ MINIMUM_REQUIRED = (0, 53, 2) def __init__(self, ledger): self.ledger = ledger self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6)) self.client: Optional[ClientSession] = None self._switch_task: Optional[asyncio.Task] = None self.running = False self.remote_height: int = 0 self._concurrency = asyncio.Semaphore(16) self._on_connected_controller = StreamController() self.on_connected = self._on_connected_controller.stream self._on_header_controller = StreamController(merge_repeated_events=True) self.on_header = self._on_header_controller.stream self._on_status_controller = StreamController(merge_repeated_events=True) self.on_status = self._on_status_controller.stream self.subscription_controllers = { 'blockchain.headers.subscribe': self._on_header_controller, 'blockchain.address.subscribe': self._on_status_controller, } @property def config(self): return self.ledger.config async def switch_forever(self): while self.running: if self.is_connected: await self.client.on_disconnected.first self.client = None continue self.client = await self.session_pool.wait_for_fastest_session() log.info("Switching to SPV wallet server: %s:%d", *self.client.server) try: self._update_remote_height((await self.subscribe_headers(),)) self._on_connected_controller.add(True) log.info("Subscribed to headers: %s:%d", *self.client.server) except (asyncio.TimeoutError, ConnectionError): log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server) self.client.synchronous_close() self.client = None async def start(self): self.running = True self._switch_task = asyncio.ensure_future(self.switch_forever()) # this may become unnecessary when there are no more bugs found, # but for now it helps understanding log reports self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped.")) self.session_pool.start(self.config['default_servers']) self.on_header.listen(self._update_remote_height) async def stop(self): if self.running: self.running = False self._switch_task.cancel() self.session_pool.stop() @property def is_connected(self): return self.client and not self.client.is_closing() def rpc(self, list_or_method, args, restricted=True): session = self.client if restricted else self.session_pool.fastest_session if session and not session.is_closing(): return session.send_request(list_or_method, args) else: self.session_pool.trigger_nodelay_connect() raise ConnectionError("Attempting to send rpc request when connection is not available.") async def retriable_call(self, function, *args, **kwargs): async with self._concurrency: while self.running: if not self.is_connected: log.warning("Wallet server unavailable, waiting for it to come back and retry.") await self.on_connected.first await self.session_pool.wait_for_fastest_session() try: return await function(*args, **kwargs) except asyncio.TimeoutError: log.warning("Wallet server call timed out, retrying.") except ConnectionError: pass raise asyncio.CancelledError() # if we got here, we are shutting down def _update_remote_height(self, header_args): self.remote_height = header_args[0]["height"] def get_transaction(self, tx_hash, known_height=None): # use any server if its old, otherwise restrict to who gave us the history restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10 return self.rpc('blockchain.transaction.get', [tx_hash], restricted) def get_transaction_height(self, tx_hash, known_height=None): restricted = not known_height or 0 > known_height > self.remote_height - 10 return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted) def get_merkle(self, tx_hash, height): restricted = 0 > height > self.remote_height - 10 return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted) def get_headers(self, height, count=10000, b64=False): restricted = height >= self.remote_height - 100 return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted) # --- Subscribes, history and broadcasts are always aimed towards the master client directly def get_history(self, address): return self.rpc('blockchain.address.get_history', [address], True) def broadcast(self, raw_transaction): return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True) def subscribe_headers(self): return self.rpc('blockchain.headers.subscribe', [True], True) async def subscribe_address(self, address, *addresses): addresses = list((address, ) + addresses) try: return await self.rpc('blockchain.address.subscribe', addresses, True) except asyncio.TimeoutError: log.warning( "timed out subscribing to addresses from %s:%i", *self.client.server_address_and_port ) # abort and cancel, we can't lose a subscription, it will happen again on reconnect if self.client: self.client.abort() raise asyncio.CancelledError() def unsubscribe_address(self, address): return self.rpc('blockchain.address.unsubscribe', [address], True) def get_server_features(self): return self.rpc('server.features', (), restricted=True) def get_claims_by_ids(self, claim_ids): return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids) def resolve(self, urls): return self.rpc('blockchain.claimtrie.resolve', urls) def claim_search(self, **kwargs): return self.rpc('blockchain.claimtrie.search', kwargs) class SessionPool: def __init__(self, network: Network, timeout: float): self.network = network self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict() self.timeout = timeout self.new_connection_event = asyncio.Event() @property def online(self): return any(not session.is_closing() for session in self.sessions) @property def available_sessions(self): return (session for session in self.sessions if session.available) @property def fastest_session(self): if not self.online: return None return min( [((session.response_time + session.connection_latency) * (session.pending_amount + 1), session) for session in self.available_sessions] or [(0, None)], key=itemgetter(0) )[1] def _get_session_connect_callback(self, session: ClientSession): loop = asyncio.get_event_loop() def callback(): duplicate_connections = [ s for s in self.sessions if s is not session and s.server_address_and_port == session.server_address_and_port ] already_connected = None if not duplicate_connections else duplicate_connections[0] if already_connected: self.sessions.pop(session).cancel() session.synchronous_close() log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour", session.server[0], already_connected.server[0]) loop.call_later(3600, self._connect_session, session.server) return self.new_connection_event.set() log.info("connected to %s:%i", *session.server) return callback def _connect_session(self, server: Tuple[str, int]): session = None for s in self.sessions: if s.server == server: session = s break if not session: session = ClientSession( network=self.network, server=server ) session._on_connect_cb = self._get_session_connect_callback(session) task = self.sessions.get(session, None) if not task or task.done(): task = asyncio.create_task(session.ensure_session()) task.add_done_callback(lambda _: self.ensure_connections()) self.sessions[session] = task def start(self, default_servers): for server in default_servers: self._connect_session(server) def stop(self): for session, task in self.sessions.items(): task.cancel() session.synchronous_close() self.sessions.clear() def ensure_connections(self): for session in self.sessions: self._connect_session(session.server) def trigger_nodelay_connect(self): # used when other parts of the system sees we might have internet back # bypasses the retry interval for session in self.sessions: session.trigger_urgent_reconnect.set() async def wait_for_fastest_session(self): while not self.fastest_session: self.trigger_nodelay_connect() self.new_connection_event.clear() await self.new_connection_event.wait() return self.fastest_session