one spv connection per loaded wallet

This commit is contained in:
Jack Robison 2020-11-17 10:57:22 -05:00
parent f6b396ae64
commit 13b473403e
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
4 changed files with 166 additions and 69 deletions

View file

@ -1303,8 +1303,8 @@ class Daemon(metaclass=JSONRPCServerType):
'name': SingleKey.name if single_key else HierarchicalDeterministic.name 'name': SingleKey.name if single_key else HierarchicalDeterministic.name
} }
) )
if self.ledger.network.is_connected: await self.ledger.network.connect_wallets(wallet.id)
await self.ledger.subscribe_account(account) await self.ledger.subscribe_account(account)
wallet.save() wallet.save()
if not skip_on_startup: if not skip_on_startup:
with self.conf.update_config() as c: with self.conf.update_config() as c:
@ -1331,9 +1331,8 @@ class Daemon(metaclass=JSONRPCServerType):
if not os.path.exists(wallet_path): if not os.path.exists(wallet_path):
raise Exception(f"Wallet at path '{wallet_path}' was not found.") raise Exception(f"Wallet at path '{wallet_path}' was not found.")
wallet = self.wallet_manager.import_wallet(wallet_path) wallet = self.wallet_manager.import_wallet(wallet_path)
if self.ledger.network.is_connected: if not self.ledger.network.is_connected(wallet.id):
for account in wallet.accounts: await self.ledger.network.connect_wallets(wallet.id)
await self.ledger.subscribe_account(account)
return wallet return wallet
@requires("wallet") @requires("wallet")
@ -1619,7 +1618,7 @@ class Daemon(metaclass=JSONRPCServerType):
} }
) )
wallet.save() wallet.save()
if self.ledger.network.is_connected: if self.ledger.network.is_connected(wallet.id):
await self.ledger.subscribe_account(account) await self.ledger.subscribe_account(account)
return account return account
@ -1647,7 +1646,7 @@ class Daemon(metaclass=JSONRPCServerType):
} }
) )
wallet.save() wallet.save()
if self.ledger.network.is_connected: if self.ledger.network.is_connected(wallet.id):
await self.ledger.subscribe_account(account) await self.ledger.subscribe_account(account)
return account return account
@ -1863,7 +1862,7 @@ class Daemon(metaclass=JSONRPCServerType):
wallet_changed = False wallet_changed = False
if data is not None: if data is not None:
added_accounts = wallet.merge(self.wallet_manager, password, data) added_accounts = wallet.merge(self.wallet_manager, password, data)
if added_accounts and self.ledger.network.is_connected: if added_accounts and self.ledger.network.is_connected(wallet.id):
if blocking: if blocking:
await asyncio.wait([ await asyncio.wait([
a.ledger.subscribe_account(a) for a in added_accounts a.ledger.subscribe_account(a) for a in added_accounts
@ -2957,7 +2956,7 @@ class Daemon(metaclass=JSONRPCServerType):
'public_key': data['holding_public_key'], 'public_key': data['holding_public_key'],
'address_generator': {'name': 'single-address'} 'address_generator': {'name': 'single-address'}
}) })
if self.ledger.network.is_connected: if self.ledger.network.is_connected(wallet.id):
await self.ledger.subscribe_account(account) await self.ledger.subscribe_account(account)
await self.ledger._update_tasks.done.wait() await self.ledger._update_tasks.done.wait()
# Case 3: the holding address has changed and we can't create or find an account for it # Case 3: the holding address has changed and we can't create or find an account for it

View file

@ -122,7 +122,7 @@ class Ledger(metaclass=LedgerRegistry):
self.headers.checkpoints = self.checkpoints self.headers.checkpoints = self.checkpoints
self.network: Network = self.config.get('network') or Network(self) self.network: Network = self.config.get('network') or Network(self)
self.network.on_header.listen(self.receive_header) self.network.on_header.listen(self.receive_header)
self.network.on_status.listen(self.process_status_update) # self.network.on_status.listen(self.process_status_update)
self.network.on_connected.listen(self.join_network) self.network.on_connected.listen(self.join_network)
self.accounts = [] self.accounts = []
@ -324,12 +324,15 @@ class Ledger(metaclass=LedgerRegistry):
self.db.open(), self.db.open(),
self.headers.open() self.headers.open()
]) ])
fully_synced = self.on_ready.first fully_synced = self.on_ready.first
asyncio.create_task(self.network.start()) asyncio.create_task(self.network.start())
await self.network.on_connected.first await self.network.on_connected.first
async with self._header_processing_lock: async with self._header_processing_lock:
await self._update_tasks.add(self.initial_headers_sync()) await self._update_tasks.add(self.initial_headers_sync())
await fully_synced await fully_synced
await self.db.release_all_outputs() await self.db.release_all_outputs()
await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts)) await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts))
await asyncio.gather(*(a.save_max_gap() for a in self.accounts)) await asyncio.gather(*(a.save_max_gap() for a in self.accounts))
@ -448,11 +451,13 @@ class Ledger(metaclass=LedgerRegistry):
) )
async def subscribe_accounts(self): async def subscribe_accounts(self):
if self.network.is_connected and self.accounts: wallet_ids = {account.wallet.id for account in self.accounts}
log.info("Subscribe to %i accounts", len(self.accounts)) await self.network.connect_wallets(*wallet_ids)
await asyncio.wait([ # if self.network.is_connected and self.accounts:
self.subscribe_account(a) for a in self.accounts log.info("Subscribe to %i accounts", len(self.accounts))
]) await asyncio.wait([
self.subscribe_account(a) for a in self.accounts
])
async def subscribe_account(self, account: Account): async def subscribe_account(self, account: Account):
for address_manager in account.address_managers.values(): for address_manager in account.address_managers.values():
@ -460,8 +465,10 @@ class Ledger(metaclass=LedgerRegistry):
await account.ensure_address_gap() await account.ensure_address_gap()
async def unsubscribe_account(self, account: Account): async def unsubscribe_account(self, account: Account):
session = self.network.get_wallet_session(account.wallet)
for address in await account.get_addresses(): for address in await account.get_addresses():
await self.network.unsubscribe_address(address) await self.network.unsubscribe_address(address, session=session)
await self.network.close_wallet_session(account.wallet)
async def announce_addresses(self, address_manager: AddressManager, addresses: List[str]): async def announce_addresses(self, address_manager: AddressManager, addresses: List[str]):
await self.subscribe_addresses(address_manager, addresses) await self.subscribe_addresses(address_manager, addresses)
@ -470,28 +477,29 @@ class Ledger(metaclass=LedgerRegistry):
) )
async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000): async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
if self.network.is_connected and addresses: if self.network.is_connected(address_manager.account.wallet.id) and addresses:
session = self.network.get_wallet_session(address_manager.account.wallet)
addresses_remaining = list(addresses) addresses_remaining = list(addresses)
while addresses_remaining: while addresses_remaining:
batch = addresses_remaining[:batch_size] batch = addresses_remaining[:batch_size]
results = await self.network.subscribe_address(*batch) results = await self.network.subscribe_address(session, *batch)
for address, remote_status in zip(batch, results): for address, remote_status in zip(batch, results):
self._update_tasks.add(self.update_history(address, remote_status, address_manager)) self._update_tasks.add(self.update_history(session, address, remote_status, address_manager))
addresses_remaining = addresses_remaining[batch_size:] addresses_remaining = addresses_remaining[batch_size:]
if self.network.client and self.network.client.server_address_and_port: if session and session.server_address_and_port:
log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining), log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
len(addresses), *self.network.client.server_address_and_port) len(addresses), *session.server_address_and_port)
if self.network.client and self.network.client.server_address_and_port: if session and session.server_address_and_port:
log.info( log.info(
"finished subscribing to %i addresses on %s:%i", len(addresses), "finished subscribing to %i addresses on %s:%i", len(addresses),
*self.network.client.server_address_and_port *session.server_address_and_port
) )
def process_status_update(self, update): def process_status_update(self, client, update):
address, remote_status = update address, remote_status = update
self._update_tasks.add(self.update_history(address, remote_status)) self._update_tasks.add(self.update_history(client, address, remote_status))
async def update_history(self, address, remote_status, address_manager: AddressManager = None, async def update_history(self, client, address, remote_status, address_manager: AddressManager = None,
reattempt_update: bool = True): reattempt_update: bool = True):
async with self._address_update_locks[address]: async with self._address_update_locks[address]:
self._known_addresses_out_of_sync.discard(address) self._known_addresses_out_of_sync.discard(address)
@ -500,7 +508,9 @@ class Ledger(metaclass=LedgerRegistry):
if local_status == remote_status: if local_status == remote_status:
return True return True
remote_history = await self.network.retriable_call(self.network.get_history, address) # remote_history = await self.network.retriable_call(self.network.get_history, address, client)
remote_history = await self.network.get_history(address, session=client)
remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history)) remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
we_need = set(remote_history) - set(local_history) we_need = set(remote_history) - set(local_history)
if not we_need: if not we_need:
@ -563,7 +573,7 @@ class Ledger(metaclass=LedgerRegistry):
"request %i transactions, %i/%i for %s are already synced", len(to_request), len(synced_txs), "request %i transactions, %i/%i for %s are already synced", len(to_request), len(synced_txs),
len(remote_history), address len(remote_history), address
) )
requested_txes = await self._request_transaction_batch(to_request, len(remote_history), address) requested_txes = await self._request_transaction_batch(to_request, len(remote_history), address, client)
for tx in requested_txes: for tx in requested_txes:
pending_synced_history[tx_indexes[tx.id]] = f"{tx.id}:{tx.height}:" pending_synced_history[tx_indexes[tx.id]] = f"{tx.id}:{tx.height}:"
synced_txs.append(tx) synced_txs.append(tx)
@ -603,7 +613,7 @@ class Ledger(metaclass=LedgerRegistry):
if self._tx_cache.get(txid) is not cache_item: if self._tx_cache.get(txid) is not cache_item:
log.warning("tx cache corrupted while syncing %s, reattempt sync=%s", address, reattempt_update) log.warning("tx cache corrupted while syncing %s, reattempt sync=%s", address, reattempt_update)
if reattempt_update: if reattempt_update:
return await self.update_history(address, remote_status, address_manager, False) return await self.update_history(client, address, remote_status, address_manager, False)
return False return False
local_status, local_history = \ local_status, local_history = \
@ -741,7 +751,7 @@ class Ledger(metaclass=LedgerRegistry):
await _single_batch(batch) await _single_batch(batch)
return transactions return transactions
async def _request_transaction_batch(self, to_request, remote_history_size, address): async def _request_transaction_batch(self, to_request, remote_history_size, address, session):
header_cache = {} header_cache = {}
batches = [[]] batches = [[]]
remote_heights = {} remote_heights = {}
@ -766,7 +776,8 @@ class Ledger(metaclass=LedgerRegistry):
async def _single_batch(batch): async def _single_batch(batch):
this_batch_synced = [] this_batch_synced = []
batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch) # batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch, session)
batch_result = await self.network.get_transaction_batch(batch, session)
for txid, (raw, merkle) in batch_result.items(): for txid, (raw, merkle) in batch_result.items():
remote_height = remote_heights[txid] remote_height = remote_heights[txid]
merkle_height = merkle['block_height'] merkle_height = merkle['block_height']

View file

@ -6,7 +6,7 @@ from operator import itemgetter
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from functools import partial from functools import partial
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
import typing
import aiohttp import aiohttp
from lbry import __version__ from lbry import __version__
@ -14,6 +14,7 @@ from lbry.error import IncompatibleWalletServerError
from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
from lbry.wallet.stream import StreamController from lbry.wallet.stream import StreamController
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -33,11 +34,26 @@ class ClientSession(BaseClientSession):
self.pending_amount = 0 self.pending_amount = 0
self._on_connect_cb = on_connect_callback or (lambda: None) self._on_connect_cb = on_connect_callback or (lambda: None)
self.trigger_urgent_reconnect = asyncio.Event() self.trigger_urgent_reconnect = asyncio.Event()
self._connected = asyncio.Event()
self._disconnected = asyncio.Event()
self._disconnected.set()
self._on_status_controller = StreamController(merge_repeated_events=True)
self.on_status = self._on_status_controller.stream
self.on_status.listen(partial(self.network.ledger.process_status_update, self))
self.subscription_controllers = {
'blockchain.headers.subscribe': self.network._on_header_controller,
'blockchain.address.subscribe': self._on_status_controller,
}
@property @property
def available(self): def available(self):
return not self.is_closing() and self.response_time is not None return not self.is_closing() and self.response_time is not None
@property
def is_connected(self) -> bool:
return self._connected.is_set()
@property @property
def server_address_and_port(self) -> Optional[Tuple[str, int]]: def server_address_and_port(self) -> Optional[Tuple[str, int]]:
if not self.transport: if not self.transport:
@ -144,7 +160,7 @@ class ClientSession(BaseClientSession):
self.connection_latency = perf_counter() - start self.connection_latency = perf_counter() - start
async def handle_request(self, request): async def handle_request(self, request):
controller = self.network.subscription_controllers[request.method] controller = self.subscription_controllers[request.method]
controller.add(request.args) controller.add(request.args)
def connection_lost(self, exc): def connection_lost(self, exc):
@ -154,6 +170,13 @@ class ClientSession(BaseClientSession):
self.connection_latency = None self.connection_latency = None
self._response_samples = 0 self._response_samples = 0
self._on_disconnect_controller.add(True) self._on_disconnect_controller.add(True)
self._connected.clear()
self._disconnected.set()
def connection_made(self, transport):
super(ClientSession, self).connection_made(transport)
self._disconnected.clear()
self._connected.set()
class Network: class Network:
@ -165,6 +188,7 @@ class Network:
self.ledger = ledger self.ledger = ledger
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6)) self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
self.client: Optional[ClientSession] = None self.client: Optional[ClientSession] = None
self.clients: Dict[str, ClientSession] = {}
self.server_features = None self.server_features = None
self._switch_task: Optional[asyncio.Task] = None self._switch_task: Optional[asyncio.Task] = None
self.running = False self.running = False
@ -177,27 +201,31 @@ class Network:
self._on_header_controller = StreamController(merge_repeated_events=True) self._on_header_controller = StreamController(merge_repeated_events=True)
self.on_header = self._on_header_controller.stream self.on_header = self._on_header_controller.stream
self._on_status_controller = StreamController(merge_repeated_events=True)
self.on_status = self._on_status_controller.stream
self.subscription_controllers = { self.subscription_controllers = {
'blockchain.headers.subscribe': self._on_header_controller, 'blockchain.headers.subscribe': self._on_header_controller,
'blockchain.address.subscribe': self._on_status_controller,
} }
self.aiohttp_session: Optional[aiohttp.ClientSession] = None self.aiohttp_session: Optional[aiohttp.ClientSession] = None
def get_wallet_session(self, wallet):
return self.clients[wallet.id]
async def close_wallet_session(self, wallet):
await self.clients.pop(wallet.id).close()
@property @property
def config(self): def config(self):
return self.ledger.config return self.ledger.config
async def switch_forever(self): async def switch_forever(self):
while self.running: while self.running:
if self.is_connected: if self.is_connected():
await self.client.on_disconnected.first await self.client.on_disconnected.first
self.server_features = None self.server_features = None
self.client = None self.client = None
continue continue
self.client = await self.session_pool.wait_for_fastest_session() self.client = await self.session_pool.wait_for_fastest_session()
log.info("Switching to SPV wallet server: %s:%d", *self.client.server) log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
try: try:
@ -220,6 +248,11 @@ class Network:
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped.")) self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
self.session_pool.start(self.config['default_servers']) self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height) self.on_header.listen(self._update_remote_height)
await self.connect_wallets(*{a.wallet.id for a in self.ledger.accounts})
await self.ledger.subscribe_accounts()
async def connect_wallets(self, *wallet_ids):
await asyncio.wait([self.session_pool.connect_wallet(wallet_id) for wallet_id in wallet_ids])
async def stop(self): async def stop(self):
if self.running: if self.running:
@ -228,9 +261,19 @@ class Network:
self._switch_task.cancel() self._switch_task.cancel()
self.session_pool.stop() self.session_pool.stop()
@property def is_connected(self, wallet_id: str = None):
def is_connected(self): if wallet_id is None:
return self.client and not self.client.is_closing() if not self.client:
return False
return self.client.is_connected
if wallet_id not in self.clients:
return False
return self.clients[wallet_id].is_connected
def connect_wallet_client(self, wallet_id: str, server: Tuple[str, int]):
client = ClientSession(network=self, server=server)
self.clients[wallet_id] = client
return client
def rpc(self, list_or_method, args, restricted=True, session=None): def rpc(self, list_or_method, args, restricted=True, session=None):
session = session or (self.client if restricted else self.session_pool.fastest_session) session = session or (self.client if restricted else self.session_pool.fastest_session)
@ -256,7 +299,8 @@ class Network:
raise asyncio.CancelledError() # if we got here, we are shutting down raise asyncio.CancelledError() # if we got here, we are shutting down
@asynccontextmanager @asynccontextmanager
async def single_call_context(self, function, *args, **kwargs): async def fastest_connection_context(self):
if not self.is_connected: if not self.is_connected:
log.warning("Wallet server unavailable, waiting for it to come back and retry.") log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first await self.on_connected.first
@ -264,6 +308,38 @@ class Network:
server = self.session_pool.fastest_session.server server = self.session_pool.fastest_session.server
session = ClientSession(network=self, server=server) session = ClientSession(network=self, server=server)
async def call_with_reconnect(function, *args, **kwargs):
nonlocal session
while self.running:
if not session.is_connected:
try:
await session.create_connection()
except asyncio.TimeoutError:
if not session.is_connected:
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first
await self.session_pool.wait_for_fastest_session()
server = self.session_pool.fastest_session.server
session = ClientSession(network=self, server=server)
try:
return await partial(function, *args, session_override=session, **kwargs)(*args, **kwargs)
except asyncio.TimeoutError:
log.warning("'%s' failed, retrying", function.__name__)
try:
yield call_with_reconnect
finally:
await session.close()
@asynccontextmanager
async def single_call_context(self, function, *args, **kwargs):
if not self.is_connected():
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first
await self.session_pool.wait_for_fastest_session()
server = self.session_pool.fastest_session.server
session = ClientSession(network=self, server=server)
async def call_with_reconnect(*a, **kw): async def call_with_reconnect(*a, **kw):
while self.running: while self.running:
if not session.available: if not session.available:
@ -280,71 +356,71 @@ class Network:
def _update_remote_height(self, header_args): def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"] self.remote_height = header_args[0]["height"]
def get_transaction(self, tx_hash, known_height=None): def get_transaction(self, tx_hash, known_height=None, session=None):
# use any server if its old, otherwise restrict to who gave us the history # use any server if its old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10 restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get', [tx_hash], restricted) return self.rpc('blockchain.transaction.get', [tx_hash], restricted, session=session)
def get_transaction_batch(self, txids, restricted=True, session=None): def get_transaction_batch(self, txids, restricted=True, session=None):
# use any server if its old, otherwise restrict to who gave us the history # use any server if its old, otherwise restrict to who gave us the history
return self.rpc('blockchain.transaction.get_batch', txids, restricted, session) return self.rpc('blockchain.transaction.get_batch', txids, restricted, session=session)
def get_transaction_and_merkle(self, tx_hash, known_height=None): def get_transaction_and_merkle(self, tx_hash, known_height=None, session=None):
# use any server if its old, otherwise restrict to who gave us the history # use any server if its old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10 restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.info', [tx_hash], restricted) return self.rpc('blockchain.transaction.info', [tx_hash], restricted, session=session)
def get_transaction_height(self, tx_hash, known_height=None): def get_transaction_height(self, tx_hash, known_height=None, session=None):
restricted = not known_height or 0 > known_height > self.remote_height - 10 restricted = not known_height or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted) return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted, session=session)
def get_merkle(self, tx_hash, height): def get_merkle(self, tx_hash, height, session=None):
restricted = 0 > height > self.remote_height - 10 restricted = 0 > height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted) return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted, session=session)
def get_headers(self, height, count=10000, b64=False): def get_headers(self, height, count=10000, b64=False):
restricted = height >= self.remote_height - 100 restricted = height >= self.remote_height - 100
return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted) return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted)
# --- Subscribes, history and broadcasts are always aimed towards the master client directly # --- Subscribes, history and broadcasts are always aimed towards the master client directly
def get_history(self, address): def get_history(self, address, session=None):
return self.rpc('blockchain.address.get_history', [address], True) return self.rpc('blockchain.address.get_history', [address], True, session=session)
def broadcast(self, raw_transaction): def broadcast(self, raw_transaction, session=None):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True) return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True, session=session)
def subscribe_headers(self): def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True], True) return self.rpc('blockchain.headers.subscribe', [True], True)
async def subscribe_address(self, address, *addresses): async def subscribe_address(self, session, address, *addresses):
addresses = list((address, ) + addresses) addresses = list((address, ) + addresses)
server_addr_and_port = self.client.server_address_and_port # on disconnect client will be None server_addr_and_port = session.server_address_and_port # on disconnect client will be None
try: try:
return await self.rpc('blockchain.address.subscribe', addresses, True) return await self.rpc('blockchain.address.subscribe', addresses, True, session=session)
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.warning( log.warning(
"timed out subscribing to addresses from %s:%i", "timed out subscribing to addresses from %s:%i",
*server_addr_and_port *server_addr_and_port
) )
# abort and cancel, we can't lose a subscription, it will happen again on reconnect # abort and cancel, we can't lose a subscription, it will happen again on reconnect
if self.client: if session:
self.client.abort() session.abort()
raise asyncio.CancelledError() raise asyncio.CancelledError()
def unsubscribe_address(self, address): def unsubscribe_address(self, address, session=None):
return self.rpc('blockchain.address.unsubscribe', [address], True) return self.rpc('blockchain.address.unsubscribe', [address], True, session=session)
def get_server_features(self): def get_server_features(self, session=None):
return self.rpc('server.features', (), restricted=True) return self.rpc('server.features', (), restricted=True, session=session)
def get_claims_by_ids(self, claim_ids): def get_claims_by_ids(self, claim_ids):
return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids) return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids)
def resolve(self, urls, session_override=None): def resolve(self, urls, session_override=None):
return self.rpc('blockchain.claimtrie.resolve', urls, False, session_override) return self.rpc('blockchain.claimtrie.resolve', urls, False, session=session_override)
def claim_search(self, session_override=None, **kwargs): def claim_search(self, session_override=None, **kwargs):
return self.rpc('blockchain.claimtrie.search', kwargs, False, session_override) return self.rpc('blockchain.claimtrie.search', kwargs, False, session=session_override)
async def new_resolve(self, server, urls): async def new_resolve(self, server, urls):
message = {"method": "resolve", "params": {"urls": urls, "protobuf": True}} message = {"method": "resolve", "params": {"urls": urls, "protobuf": True}}
@ -371,6 +447,7 @@ class SessionPool:
def __init__(self, network: Network, timeout: float): def __init__(self, network: Network, timeout: float):
self.network = network self.network = network
self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict() self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self.wallet_session_tasks: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self.timeout = timeout self.timeout = timeout
self.new_connection_event = asyncio.Event() self.new_connection_event = asyncio.Event()
@ -430,6 +507,19 @@ class SessionPool:
task.add_done_callback(lambda _: self.ensure_connections()) task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = task self.sessions[session] = task
async def connect_wallet(self, wallet_id: str):
fastest = await self.wait_for_fastest_session()
if not self.network.is_connected(wallet_id):
session = self.network.connect_wallet_client(wallet_id, fastest.server)
# session._on_connect_cb = self._get_session_connect_callback(session)
else:
session = self.network.clients[wallet_id]
task = self.wallet_session_tasks.get(session, None)
if not task or task.done():
task = asyncio.create_task(session.ensure_session())
# task.add_done_callback(lambda _: self.ensure_connections())
self.wallet_session_tasks[session] = task
def start(self, default_servers): def start(self, default_servers):
for server in default_servers: for server in default_servers:
self._connect_session(server) self._connect_session(server)

View file

@ -1227,10 +1227,7 @@ class LBRYElectrumX(SessionBase):
return await self.address_status(hashX) return await self.address_status(hashX)
async def hashX_unsubscribe(self, hashX, alias): async def hashX_unsubscribe(self, hashX, alias):
try: self.hashX_subs.pop(hashX, None)
del self.hashX_subs[hashX]
except ValueError:
pass
def address_to_hashX(self, address): def address_to_hashX(self, address):
try: try: