From 13b473403e1e0b88eddd17d8fc1c7c27d6926782 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Tue, 17 Nov 2020 10:57:22 -0500 Subject: [PATCH] one spv connection per loaded wallet --- lbry/extras/daemon/daemon.py | 17 ++-- lbry/wallet/ledger.py | 55 +++++++----- lbry/wallet/network.py | 158 ++++++++++++++++++++++++++-------- lbry/wallet/server/session.py | 5 +- 4 files changed, 166 insertions(+), 69 deletions(-) diff --git a/lbry/extras/daemon/daemon.py b/lbry/extras/daemon/daemon.py index 8f0ae21ff..1ad87af4a 100644 --- a/lbry/extras/daemon/daemon.py +++ b/lbry/extras/daemon/daemon.py @@ -1303,8 +1303,8 @@ class Daemon(metaclass=JSONRPCServerType): 'name': SingleKey.name if single_key else HierarchicalDeterministic.name } ) - if self.ledger.network.is_connected: - await self.ledger.subscribe_account(account) + await self.ledger.network.connect_wallets(wallet.id) + await self.ledger.subscribe_account(account) wallet.save() if not skip_on_startup: with self.conf.update_config() as c: @@ -1331,9 +1331,8 @@ class Daemon(metaclass=JSONRPCServerType): if not os.path.exists(wallet_path): raise Exception(f"Wallet at path '{wallet_path}' was not found.") wallet = self.wallet_manager.import_wallet(wallet_path) - if self.ledger.network.is_connected: - for account in wallet.accounts: - await self.ledger.subscribe_account(account) + if not self.ledger.network.is_connected(wallet.id): + await self.ledger.network.connect_wallets(wallet.id) return wallet @requires("wallet") @@ -1619,7 +1618,7 @@ class Daemon(metaclass=JSONRPCServerType): } ) wallet.save() - if self.ledger.network.is_connected: + if self.ledger.network.is_connected(wallet.id): await self.ledger.subscribe_account(account) return account @@ -1647,7 +1646,7 @@ class Daemon(metaclass=JSONRPCServerType): } ) wallet.save() - if self.ledger.network.is_connected: + if self.ledger.network.is_connected(wallet.id): await self.ledger.subscribe_account(account) return account @@ -1863,7 +1862,7 @@ class Daemon(metaclass=JSONRPCServerType): wallet_changed = False if data is not None: added_accounts = wallet.merge(self.wallet_manager, password, data) - if added_accounts and self.ledger.network.is_connected: + if added_accounts and self.ledger.network.is_connected(wallet.id): if blocking: await asyncio.wait([ a.ledger.subscribe_account(a) for a in added_accounts @@ -2957,7 +2956,7 @@ class Daemon(metaclass=JSONRPCServerType): 'public_key': data['holding_public_key'], 'address_generator': {'name': 'single-address'} }) - if self.ledger.network.is_connected: + if self.ledger.network.is_connected(wallet.id): await self.ledger.subscribe_account(account) await self.ledger._update_tasks.done.wait() # Case 3: the holding address has changed and we can't create or find an account for it diff --git a/lbry/wallet/ledger.py b/lbry/wallet/ledger.py index 1676b01f1..d2d1ff645 100644 --- a/lbry/wallet/ledger.py +++ b/lbry/wallet/ledger.py @@ -122,7 +122,7 @@ class Ledger(metaclass=LedgerRegistry): self.headers.checkpoints = self.checkpoints self.network: Network = self.config.get('network') or Network(self) self.network.on_header.listen(self.receive_header) - self.network.on_status.listen(self.process_status_update) + # self.network.on_status.listen(self.process_status_update) self.network.on_connected.listen(self.join_network) self.accounts = [] @@ -324,12 +324,15 @@ class Ledger(metaclass=LedgerRegistry): self.db.open(), self.headers.open() ]) + fully_synced = self.on_ready.first + asyncio.create_task(self.network.start()) await self.network.on_connected.first async with self._header_processing_lock: await self._update_tasks.add(self.initial_headers_sync()) await fully_synced + await self.db.release_all_outputs() await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts)) await asyncio.gather(*(a.save_max_gap() for a in self.accounts)) @@ -448,11 +451,13 @@ class Ledger(metaclass=LedgerRegistry): ) async def subscribe_accounts(self): - if self.network.is_connected and self.accounts: - log.info("Subscribe to %i accounts", len(self.accounts)) - await asyncio.wait([ - self.subscribe_account(a) for a in self.accounts - ]) + wallet_ids = {account.wallet.id for account in self.accounts} + await self.network.connect_wallets(*wallet_ids) + # if self.network.is_connected and self.accounts: + log.info("Subscribe to %i accounts", len(self.accounts)) + await asyncio.wait([ + self.subscribe_account(a) for a in self.accounts + ]) async def subscribe_account(self, account: Account): for address_manager in account.address_managers.values(): @@ -460,8 +465,10 @@ class Ledger(metaclass=LedgerRegistry): await account.ensure_address_gap() async def unsubscribe_account(self, account: Account): + session = self.network.get_wallet_session(account.wallet) for address in await account.get_addresses(): - await self.network.unsubscribe_address(address) + await self.network.unsubscribe_address(address, session=session) + await self.network.close_wallet_session(account.wallet) async def announce_addresses(self, address_manager: AddressManager, addresses: List[str]): await self.subscribe_addresses(address_manager, addresses) @@ -470,28 +477,29 @@ class Ledger(metaclass=LedgerRegistry): ) async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000): - if self.network.is_connected and addresses: + if self.network.is_connected(address_manager.account.wallet.id) and addresses: + session = self.network.get_wallet_session(address_manager.account.wallet) addresses_remaining = list(addresses) while addresses_remaining: batch = addresses_remaining[:batch_size] - results = await self.network.subscribe_address(*batch) + results = await self.network.subscribe_address(session, *batch) for address, remote_status in zip(batch, results): - self._update_tasks.add(self.update_history(address, remote_status, address_manager)) + self._update_tasks.add(self.update_history(session, address, remote_status, address_manager)) addresses_remaining = addresses_remaining[batch_size:] - if self.network.client and self.network.client.server_address_and_port: + if session and session.server_address_and_port: log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining), - len(addresses), *self.network.client.server_address_and_port) - if self.network.client and self.network.client.server_address_and_port: + len(addresses), *session.server_address_and_port) + if session and session.server_address_and_port: log.info( "finished subscribing to %i addresses on %s:%i", len(addresses), - *self.network.client.server_address_and_port + *session.server_address_and_port ) - def process_status_update(self, update): + def process_status_update(self, client, update): address, remote_status = update - self._update_tasks.add(self.update_history(address, remote_status)) + self._update_tasks.add(self.update_history(client, address, remote_status)) - async def update_history(self, address, remote_status, address_manager: AddressManager = None, + async def update_history(self, client, address, remote_status, address_manager: AddressManager = None, reattempt_update: bool = True): async with self._address_update_locks[address]: self._known_addresses_out_of_sync.discard(address) @@ -500,7 +508,9 @@ class Ledger(metaclass=LedgerRegistry): if local_status == remote_status: return True - remote_history = await self.network.retriable_call(self.network.get_history, address) + # remote_history = await self.network.retriable_call(self.network.get_history, address, client) + remote_history = await self.network.get_history(address, session=client) + remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history)) we_need = set(remote_history) - set(local_history) if not we_need: @@ -563,7 +573,7 @@ class Ledger(metaclass=LedgerRegistry): "request %i transactions, %i/%i for %s are already synced", len(to_request), len(synced_txs), len(remote_history), address ) - requested_txes = await self._request_transaction_batch(to_request, len(remote_history), address) + requested_txes = await self._request_transaction_batch(to_request, len(remote_history), address, client) for tx in requested_txes: pending_synced_history[tx_indexes[tx.id]] = f"{tx.id}:{tx.height}:" synced_txs.append(tx) @@ -603,7 +613,7 @@ class Ledger(metaclass=LedgerRegistry): if self._tx_cache.get(txid) is not cache_item: log.warning("tx cache corrupted while syncing %s, reattempt sync=%s", address, reattempt_update) if reattempt_update: - return await self.update_history(address, remote_status, address_manager, False) + return await self.update_history(client, address, remote_status, address_manager, False) return False local_status, local_history = \ @@ -741,7 +751,7 @@ class Ledger(metaclass=LedgerRegistry): await _single_batch(batch) return transactions - async def _request_transaction_batch(self, to_request, remote_history_size, address): + async def _request_transaction_batch(self, to_request, remote_history_size, address, session): header_cache = {} batches = [[]] remote_heights = {} @@ -766,7 +776,8 @@ class Ledger(metaclass=LedgerRegistry): async def _single_batch(batch): this_batch_synced = [] - batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch) + # batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch, session) + batch_result = await self.network.get_transaction_batch(batch, session) for txid, (raw, merkle) in batch_result.items(): remote_height = remote_heights[txid] merkle_height = merkle['block_height'] diff --git a/lbry/wallet/network.py b/lbry/wallet/network.py index d4b460600..28d7503fc 100644 --- a/lbry/wallet/network.py +++ b/lbry/wallet/network.py @@ -6,7 +6,7 @@ from operator import itemgetter from contextlib import asynccontextmanager from functools import partial from typing import Dict, Optional, Tuple - +import typing import aiohttp from lbry import __version__ @@ -14,6 +14,7 @@ from lbry.error import IncompatibleWalletServerError from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError from lbry.wallet.stream import StreamController + log = logging.getLogger(__name__) @@ -33,11 +34,26 @@ class ClientSession(BaseClientSession): self.pending_amount = 0 self._on_connect_cb = on_connect_callback or (lambda: None) self.trigger_urgent_reconnect = asyncio.Event() + self._connected = asyncio.Event() + self._disconnected = asyncio.Event() + self._disconnected.set() + + self._on_status_controller = StreamController(merge_repeated_events=True) + self.on_status = self._on_status_controller.stream + self.on_status.listen(partial(self.network.ledger.process_status_update, self)) + self.subscription_controllers = { + 'blockchain.headers.subscribe': self.network._on_header_controller, + 'blockchain.address.subscribe': self._on_status_controller, + } @property def available(self): return not self.is_closing() and self.response_time is not None + @property + def is_connected(self) -> bool: + return self._connected.is_set() + @property def server_address_and_port(self) -> Optional[Tuple[str, int]]: if not self.transport: @@ -144,7 +160,7 @@ class ClientSession(BaseClientSession): self.connection_latency = perf_counter() - start async def handle_request(self, request): - controller = self.network.subscription_controllers[request.method] + controller = self.subscription_controllers[request.method] controller.add(request.args) def connection_lost(self, exc): @@ -154,6 +170,13 @@ class ClientSession(BaseClientSession): self.connection_latency = None self._response_samples = 0 self._on_disconnect_controller.add(True) + self._connected.clear() + self._disconnected.set() + + def connection_made(self, transport): + super(ClientSession, self).connection_made(transport) + self._disconnected.clear() + self._connected.set() class Network: @@ -165,6 +188,7 @@ class Network: self.ledger = ledger self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6)) self.client: Optional[ClientSession] = None + self.clients: Dict[str, ClientSession] = {} self.server_features = None self._switch_task: Optional[asyncio.Task] = None self.running = False @@ -177,27 +201,31 @@ class Network: self._on_header_controller = StreamController(merge_repeated_events=True) self.on_header = self._on_header_controller.stream - self._on_status_controller = StreamController(merge_repeated_events=True) - self.on_status = self._on_status_controller.stream self.subscription_controllers = { 'blockchain.headers.subscribe': self._on_header_controller, - 'blockchain.address.subscribe': self._on_status_controller, } self.aiohttp_session: Optional[aiohttp.ClientSession] = None + def get_wallet_session(self, wallet): + return self.clients[wallet.id] + + async def close_wallet_session(self, wallet): + await self.clients.pop(wallet.id).close() + @property def config(self): return self.ledger.config async def switch_forever(self): while self.running: - if self.is_connected: + if self.is_connected(): await self.client.on_disconnected.first self.server_features = None self.client = None continue + self.client = await self.session_pool.wait_for_fastest_session() log.info("Switching to SPV wallet server: %s:%d", *self.client.server) try: @@ -220,6 +248,11 @@ class Network: self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped.")) self.session_pool.start(self.config['default_servers']) self.on_header.listen(self._update_remote_height) + await self.connect_wallets(*{a.wallet.id for a in self.ledger.accounts}) + await self.ledger.subscribe_accounts() + + async def connect_wallets(self, *wallet_ids): + await asyncio.wait([self.session_pool.connect_wallet(wallet_id) for wallet_id in wallet_ids]) async def stop(self): if self.running: @@ -228,9 +261,19 @@ class Network: self._switch_task.cancel() self.session_pool.stop() - @property - def is_connected(self): - return self.client and not self.client.is_closing() + def is_connected(self, wallet_id: str = None): + if wallet_id is None: + if not self.client: + return False + return self.client.is_connected + if wallet_id not in self.clients: + return False + return self.clients[wallet_id].is_connected + + def connect_wallet_client(self, wallet_id: str, server: Tuple[str, int]): + client = ClientSession(network=self, server=server) + self.clients[wallet_id] = client + return client def rpc(self, list_or_method, args, restricted=True, session=None): session = session or (self.client if restricted else self.session_pool.fastest_session) @@ -256,7 +299,8 @@ class Network: raise asyncio.CancelledError() # if we got here, we are shutting down @asynccontextmanager - async def single_call_context(self, function, *args, **kwargs): + async def fastest_connection_context(self): + if not self.is_connected: log.warning("Wallet server unavailable, waiting for it to come back and retry.") await self.on_connected.first @@ -264,6 +308,38 @@ class Network: server = self.session_pool.fastest_session.server session = ClientSession(network=self, server=server) + async def call_with_reconnect(function, *args, **kwargs): + nonlocal session + + while self.running: + if not session.is_connected: + try: + await session.create_connection() + except asyncio.TimeoutError: + if not session.is_connected: + log.warning("Wallet server unavailable, waiting for it to come back and retry.") + await self.on_connected.first + await self.session_pool.wait_for_fastest_session() + server = self.session_pool.fastest_session.server + session = ClientSession(network=self, server=server) + try: + return await partial(function, *args, session_override=session, **kwargs)(*args, **kwargs) + except asyncio.TimeoutError: + log.warning("'%s' failed, retrying", function.__name__) + try: + yield call_with_reconnect + finally: + await session.close() + + @asynccontextmanager + async def single_call_context(self, function, *args, **kwargs): + if not self.is_connected(): + log.warning("Wallet server unavailable, waiting for it to come back and retry.") + await self.on_connected.first + await self.session_pool.wait_for_fastest_session() + server = self.session_pool.fastest_session.server + session = ClientSession(network=self, server=server) + async def call_with_reconnect(*a, **kw): while self.running: if not session.available: @@ -280,71 +356,71 @@ class Network: def _update_remote_height(self, header_args): self.remote_height = header_args[0]["height"] - def get_transaction(self, tx_hash, known_height=None): + def get_transaction(self, tx_hash, known_height=None, session=None): # use any server if its old, otherwise restrict to who gave us the history restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10 - return self.rpc('blockchain.transaction.get', [tx_hash], restricted) + return self.rpc('blockchain.transaction.get', [tx_hash], restricted, session=session) def get_transaction_batch(self, txids, restricted=True, session=None): # use any server if its old, otherwise restrict to who gave us the history - return self.rpc('blockchain.transaction.get_batch', txids, restricted, session) + return self.rpc('blockchain.transaction.get_batch', txids, restricted, session=session) - def get_transaction_and_merkle(self, tx_hash, known_height=None): + def get_transaction_and_merkle(self, tx_hash, known_height=None, session=None): # use any server if its old, otherwise restrict to who gave us the history restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10 - return self.rpc('blockchain.transaction.info', [tx_hash], restricted) + return self.rpc('blockchain.transaction.info', [tx_hash], restricted, session=session) - def get_transaction_height(self, tx_hash, known_height=None): + def get_transaction_height(self, tx_hash, known_height=None, session=None): restricted = not known_height or 0 > known_height > self.remote_height - 10 - return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted) + return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted, session=session) - def get_merkle(self, tx_hash, height): + def get_merkle(self, tx_hash, height, session=None): restricted = 0 > height > self.remote_height - 10 - return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted) + return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted, session=session) def get_headers(self, height, count=10000, b64=False): restricted = height >= self.remote_height - 100 return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted) # --- Subscribes, history and broadcasts are always aimed towards the master client directly - def get_history(self, address): - return self.rpc('blockchain.address.get_history', [address], True) + def get_history(self, address, session=None): + return self.rpc('blockchain.address.get_history', [address], True, session=session) - def broadcast(self, raw_transaction): - return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True) + def broadcast(self, raw_transaction, session=None): + return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True, session=session) def subscribe_headers(self): return self.rpc('blockchain.headers.subscribe', [True], True) - async def subscribe_address(self, address, *addresses): + async def subscribe_address(self, session, address, *addresses): addresses = list((address, ) + addresses) - server_addr_and_port = self.client.server_address_and_port # on disconnect client will be None + server_addr_and_port = session.server_address_and_port # on disconnect client will be None try: - return await self.rpc('blockchain.address.subscribe', addresses, True) + return await self.rpc('blockchain.address.subscribe', addresses, True, session=session) except asyncio.TimeoutError: log.warning( "timed out subscribing to addresses from %s:%i", *server_addr_and_port ) # abort and cancel, we can't lose a subscription, it will happen again on reconnect - if self.client: - self.client.abort() + if session: + session.abort() raise asyncio.CancelledError() - def unsubscribe_address(self, address): - return self.rpc('blockchain.address.unsubscribe', [address], True) + def unsubscribe_address(self, address, session=None): + return self.rpc('blockchain.address.unsubscribe', [address], True, session=session) - def get_server_features(self): - return self.rpc('server.features', (), restricted=True) + def get_server_features(self, session=None): + return self.rpc('server.features', (), restricted=True, session=session) def get_claims_by_ids(self, claim_ids): return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids) def resolve(self, urls, session_override=None): - return self.rpc('blockchain.claimtrie.resolve', urls, False, session_override) + return self.rpc('blockchain.claimtrie.resolve', urls, False, session=session_override) def claim_search(self, session_override=None, **kwargs): - return self.rpc('blockchain.claimtrie.search', kwargs, False, session_override) + return self.rpc('blockchain.claimtrie.search', kwargs, False, session=session_override) async def new_resolve(self, server, urls): message = {"method": "resolve", "params": {"urls": urls, "protobuf": True}} @@ -371,6 +447,7 @@ class SessionPool: def __init__(self, network: Network, timeout: float): self.network = network self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict() + self.wallet_session_tasks: Dict[ClientSession, Optional[asyncio.Task]] = dict() self.timeout = timeout self.new_connection_event = asyncio.Event() @@ -430,6 +507,19 @@ class SessionPool: task.add_done_callback(lambda _: self.ensure_connections()) self.sessions[session] = task + async def connect_wallet(self, wallet_id: str): + fastest = await self.wait_for_fastest_session() + if not self.network.is_connected(wallet_id): + session = self.network.connect_wallet_client(wallet_id, fastest.server) + # session._on_connect_cb = self._get_session_connect_callback(session) + else: + session = self.network.clients[wallet_id] + task = self.wallet_session_tasks.get(session, None) + if not task or task.done(): + task = asyncio.create_task(session.ensure_session()) + # task.add_done_callback(lambda _: self.ensure_connections()) + self.wallet_session_tasks[session] = task + def start(self, default_servers): for server in default_servers: self._connect_session(server) diff --git a/lbry/wallet/server/session.py b/lbry/wallet/server/session.py index c5f79b116..79fbda84f 100644 --- a/lbry/wallet/server/session.py +++ b/lbry/wallet/server/session.py @@ -1227,10 +1227,7 @@ class LBRYElectrumX(SessionBase): return await self.address_status(hashX) async def hashX_unsubscribe(self, hashX, alias): - try: - del self.hashX_subs[hashX] - except ValueError: - pass + self.hashX_subs.pop(hashX, None) def address_to_hashX(self, address): try: