One SPV connection per loaded wallet

This commit is contained in:
Jack Robison 2020-11-17 10:57:22 -05:00
parent f6b396ae64
commit 13b473403e
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
4 changed files with 166 additions and 69 deletions

View file

@ -1303,8 +1303,8 @@ class Daemon(metaclass=JSONRPCServerType):
'name': SingleKey.name if single_key else HierarchicalDeterministic.name
}
)
if self.ledger.network.is_connected:
await self.ledger.subscribe_account(account)
await self.ledger.network.connect_wallets(wallet.id)
await self.ledger.subscribe_account(account)
wallet.save()
if not skip_on_startup:
with self.conf.update_config() as c:
@ -1331,9 +1331,8 @@ class Daemon(metaclass=JSONRPCServerType):
if not os.path.exists(wallet_path):
raise Exception(f"Wallet at path '{wallet_path}' was not found.")
wallet = self.wallet_manager.import_wallet(wallet_path)
if self.ledger.network.is_connected:
for account in wallet.accounts:
await self.ledger.subscribe_account(account)
if not self.ledger.network.is_connected(wallet.id):
await self.ledger.network.connect_wallets(wallet.id)
return wallet
@requires("wallet")
@ -1619,7 +1618,7 @@ class Daemon(metaclass=JSONRPCServerType):
}
)
wallet.save()
if self.ledger.network.is_connected:
if self.ledger.network.is_connected(wallet.id):
await self.ledger.subscribe_account(account)
return account
@ -1647,7 +1646,7 @@ class Daemon(metaclass=JSONRPCServerType):
}
)
wallet.save()
if self.ledger.network.is_connected:
if self.ledger.network.is_connected(wallet.id):
await self.ledger.subscribe_account(account)
return account
@ -1863,7 +1862,7 @@ class Daemon(metaclass=JSONRPCServerType):
wallet_changed = False
if data is not None:
added_accounts = wallet.merge(self.wallet_manager, password, data)
if added_accounts and self.ledger.network.is_connected:
if added_accounts and self.ledger.network.is_connected(wallet.id):
if blocking:
await asyncio.wait([
a.ledger.subscribe_account(a) for a in added_accounts
@ -2957,7 +2956,7 @@ class Daemon(metaclass=JSONRPCServerType):
'public_key': data['holding_public_key'],
'address_generator': {'name': 'single-address'}
})
if self.ledger.network.is_connected:
if self.ledger.network.is_connected(wallet.id):
await self.ledger.subscribe_account(account)
await self.ledger._update_tasks.done.wait()
# Case 3: the holding address has changed and we can't create or find an account for it

View file

@ -122,7 +122,7 @@ class Ledger(metaclass=LedgerRegistry):
self.headers.checkpoints = self.checkpoints
self.network: Network = self.config.get('network') or Network(self)
self.network.on_header.listen(self.receive_header)
self.network.on_status.listen(self.process_status_update)
# self.network.on_status.listen(self.process_status_update)
self.network.on_connected.listen(self.join_network)
self.accounts = []
@ -324,12 +324,15 @@ class Ledger(metaclass=LedgerRegistry):
self.db.open(),
self.headers.open()
])
fully_synced = self.on_ready.first
asyncio.create_task(self.network.start())
await self.network.on_connected.first
async with self._header_processing_lock:
await self._update_tasks.add(self.initial_headers_sync())
await fully_synced
await self.db.release_all_outputs()
await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts))
await asyncio.gather(*(a.save_max_gap() for a in self.accounts))
@ -448,11 +451,13 @@ class Ledger(metaclass=LedgerRegistry):
)
async def subscribe_accounts(self):
if self.network.is_connected and self.accounts:
log.info("Subscribe to %i accounts", len(self.accounts))
await asyncio.wait([
self.subscribe_account(a) for a in self.accounts
])
wallet_ids = {account.wallet.id for account in self.accounts}
await self.network.connect_wallets(*wallet_ids)
# if self.network.is_connected and self.accounts:
log.info("Subscribe to %i accounts", len(self.accounts))
await asyncio.wait([
self.subscribe_account(a) for a in self.accounts
])
async def subscribe_account(self, account: Account):
for address_manager in account.address_managers.values():
@ -460,8 +465,10 @@ class Ledger(metaclass=LedgerRegistry):
await account.ensure_address_gap()
async def unsubscribe_account(self, account: Account):
session = self.network.get_wallet_session(account.wallet)
for address in await account.get_addresses():
await self.network.unsubscribe_address(address)
await self.network.unsubscribe_address(address, session=session)
await self.network.close_wallet_session(account.wallet)
async def announce_addresses(self, address_manager: AddressManager, addresses: List[str]):
await self.subscribe_addresses(address_manager, addresses)
@ -470,28 +477,29 @@ class Ledger(metaclass=LedgerRegistry):
)
async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
if self.network.is_connected and addresses:
if self.network.is_connected(address_manager.account.wallet.id) and addresses:
session = self.network.get_wallet_session(address_manager.account.wallet)
addresses_remaining = list(addresses)
while addresses_remaining:
batch = addresses_remaining[:batch_size]
results = await self.network.subscribe_address(*batch)
results = await self.network.subscribe_address(session, *batch)
for address, remote_status in zip(batch, results):
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
self._update_tasks.add(self.update_history(session, address, remote_status, address_manager))
addresses_remaining = addresses_remaining[batch_size:]
if self.network.client and self.network.client.server_address_and_port:
if session and session.server_address_and_port:
log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
len(addresses), *self.network.client.server_address_and_port)
if self.network.client and self.network.client.server_address_and_port:
len(addresses), *session.server_address_and_port)
if session and session.server_address_and_port:
log.info(
"finished subscribing to %i addresses on %s:%i", len(addresses),
*self.network.client.server_address_and_port
*session.server_address_and_port
)
def process_status_update(self, update):
def process_status_update(self, client, update):
address, remote_status = update
self._update_tasks.add(self.update_history(address, remote_status))
self._update_tasks.add(self.update_history(client, address, remote_status))
async def update_history(self, address, remote_status, address_manager: AddressManager = None,
async def update_history(self, client, address, remote_status, address_manager: AddressManager = None,
reattempt_update: bool = True):
async with self._address_update_locks[address]:
self._known_addresses_out_of_sync.discard(address)
@ -500,7 +508,9 @@ class Ledger(metaclass=LedgerRegistry):
if local_status == remote_status:
return True
remote_history = await self.network.retriable_call(self.network.get_history, address)
# remote_history = await self.network.retriable_call(self.network.get_history, address, client)
remote_history = await self.network.get_history(address, session=client)
remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
we_need = set(remote_history) - set(local_history)
if not we_need:
@ -563,7 +573,7 @@ class Ledger(metaclass=LedgerRegistry):
"request %i transactions, %i/%i for %s are already synced", len(to_request), len(synced_txs),
len(remote_history), address
)
requested_txes = await self._request_transaction_batch(to_request, len(remote_history), address)
requested_txes = await self._request_transaction_batch(to_request, len(remote_history), address, client)
for tx in requested_txes:
pending_synced_history[tx_indexes[tx.id]] = f"{tx.id}:{tx.height}:"
synced_txs.append(tx)
@ -603,7 +613,7 @@ class Ledger(metaclass=LedgerRegistry):
if self._tx_cache.get(txid) is not cache_item:
log.warning("tx cache corrupted while syncing %s, reattempt sync=%s", address, reattempt_update)
if reattempt_update:
return await self.update_history(address, remote_status, address_manager, False)
return await self.update_history(client, address, remote_status, address_manager, False)
return False
local_status, local_history = \
@ -741,7 +751,7 @@ class Ledger(metaclass=LedgerRegistry):
await _single_batch(batch)
return transactions
async def _request_transaction_batch(self, to_request, remote_history_size, address):
async def _request_transaction_batch(self, to_request, remote_history_size, address, session):
header_cache = {}
batches = [[]]
remote_heights = {}
@ -766,7 +776,8 @@ class Ledger(metaclass=LedgerRegistry):
async def _single_batch(batch):
this_batch_synced = []
batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch)
# batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch, session)
batch_result = await self.network.get_transaction_batch(batch, session)
for txid, (raw, merkle) in batch_result.items():
remote_height = remote_heights[txid]
merkle_height = merkle['block_height']

View file

@ -6,7 +6,7 @@ from operator import itemgetter
from contextlib import asynccontextmanager
from functools import partial
from typing import Dict, Optional, Tuple
import typing
import aiohttp
from lbry import __version__
@ -14,6 +14,7 @@ from lbry.error import IncompatibleWalletServerError
from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
from lbry.wallet.stream import StreamController
log = logging.getLogger(__name__)
@ -33,11 +34,26 @@ class ClientSession(BaseClientSession):
self.pending_amount = 0
self._on_connect_cb = on_connect_callback or (lambda: None)
self.trigger_urgent_reconnect = asyncio.Event()
self._connected = asyncio.Event()
self._disconnected = asyncio.Event()
self._disconnected.set()
self._on_status_controller = StreamController(merge_repeated_events=True)
self.on_status = self._on_status_controller.stream
self.on_status.listen(partial(self.network.ledger.process_status_update, self))
self.subscription_controllers = {
'blockchain.headers.subscribe': self.network._on_header_controller,
'blockchain.address.subscribe': self._on_status_controller,
}
@property
def available(self):
return not self.is_closing() and self.response_time is not None
@property
def is_connected(self) -> bool:
return self._connected.is_set()
@property
def server_address_and_port(self) -> Optional[Tuple[str, int]]:
if not self.transport:
@ -144,7 +160,7 @@ class ClientSession(BaseClientSession):
self.connection_latency = perf_counter() - start
async def handle_request(self, request):
controller = self.network.subscription_controllers[request.method]
controller = self.subscription_controllers[request.method]
controller.add(request.args)
def connection_lost(self, exc):
@ -154,6 +170,13 @@ class ClientSession(BaseClientSession):
self.connection_latency = None
self._response_samples = 0
self._on_disconnect_controller.add(True)
self._connected.clear()
self._disconnected.set()
def connection_made(self, transport):
super(ClientSession, self).connection_made(transport)
self._disconnected.clear()
self._connected.set()
class Network:
@ -165,6 +188,7 @@ class Network:
self.ledger = ledger
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
self.client: Optional[ClientSession] = None
self.clients: Dict[str, ClientSession] = {}
self.server_features = None
self._switch_task: Optional[asyncio.Task] = None
self.running = False
@ -177,27 +201,31 @@ class Network:
self._on_header_controller = StreamController(merge_repeated_events=True)
self.on_header = self._on_header_controller.stream
self._on_status_controller = StreamController(merge_repeated_events=True)
self.on_status = self._on_status_controller.stream
self.subscription_controllers = {
'blockchain.headers.subscribe': self._on_header_controller,
'blockchain.address.subscribe': self._on_status_controller,
}
self.aiohttp_session: Optional[aiohttp.ClientSession] = None
def get_wallet_session(self, wallet):
    """Return the SPV client session registered for *wallet* (KeyError if none)."""
    sessions = self.clients
    return sessions[wallet.id]
async def close_wallet_session(self, wallet):
    """Close and discard the SPV session associated with *wallet*.

    A no-op when the wallet has no registered session, so closing twice
    (or closing a wallet that never connected) does not raise KeyError.
    """
    session = self.clients.pop(wallet.id, None)
    if session is not None:
        await session.close()
@property
def config(self):
    # Network holds no configuration of its own; delegate to the owning ledger.
    return self.ledger.config
async def switch_forever(self):
while self.running:
if self.is_connected:
if self.is_connected():
await self.client.on_disconnected.first
self.server_features = None
self.client = None
continue
self.client = await self.session_pool.wait_for_fastest_session()
log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
try:
@ -220,6 +248,11 @@ class Network:
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height)
await self.connect_wallets(*{a.wallet.id for a in self.ledger.accounts})
await self.ledger.subscribe_accounts()
async def connect_wallets(self, *wallet_ids):
    """Ensure each given wallet id has its own SPV server session.

    Connections are attempted concurrently via the session pool. Called
    with no ids this is a no-op (``asyncio.wait`` raises ValueError on an
    empty set). Coroutines are wrapped in tasks explicitly, since passing
    bare coroutines to ``asyncio.wait`` is deprecated.
    """
    if not wallet_ids:
        return
    await asyncio.wait([
        asyncio.create_task(self.session_pool.connect_wallet(wallet_id))
        for wallet_id in wallet_ids
    ])
async def stop(self):
if self.running:
@ -228,9 +261,19 @@ class Network:
self._switch_task.cancel()
self.session_pool.stop()
@property
def is_connected(self):
return self.client and not self.client.is_closing()
def is_connected(self, wallet_id: str = None):
if wallet_id is None:
if not self.client:
return False
return self.client.is_connected
if wallet_id not in self.clients:
return False
return self.clients[wallet_id].is_connected
def connect_wallet_client(self, wallet_id: str, server: Tuple[str, int]):
    """Create a new SPV session for *wallet_id* against *server*, register and return it."""
    session = ClientSession(network=self, server=server)
    self.clients[wallet_id] = session
    return session
def rpc(self, list_or_method, args, restricted=True, session=None):
session = session or (self.client if restricted else self.session_pool.fastest_session)
@ -256,7 +299,8 @@ class Network:
raise asyncio.CancelledError() # if we got here, we are shutting down
@asynccontextmanager
async def single_call_context(self, function, *args, **kwargs):
async def fastest_connection_context(self):
if not self.is_connected:
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first
@ -264,6 +308,38 @@ class Network:
server = self.session_pool.fastest_session.server
session = ClientSession(network=self, server=server)
async def call_with_reconnect(function, *args, **kwargs):
nonlocal session
while self.running:
if not session.is_connected:
try:
await session.create_connection()
except asyncio.TimeoutError:
if not session.is_connected:
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first
await self.session_pool.wait_for_fastest_session()
server = self.session_pool.fastest_session.server
session = ClientSession(network=self, server=server)
try:
return await partial(function, *args, session_override=session, **kwargs)(*args, **kwargs)
except asyncio.TimeoutError:
log.warning("'%s' failed, retrying", function.__name__)
try:
yield call_with_reconnect
finally:
await session.close()
@asynccontextmanager
async def single_call_context(self, function, *args, **kwargs):
if not self.is_connected():
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first
await self.session_pool.wait_for_fastest_session()
server = self.session_pool.fastest_session.server
session = ClientSession(network=self, server=server)
async def call_with_reconnect(*a, **kw):
while self.running:
if not session.available:
@ -280,71 +356,71 @@ class Network:
def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"]
def get_transaction(self, tx_hash, known_height=None):
def get_transaction(self, tx_hash, known_height=None, session=None):
# use any server if it's old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get', [tx_hash], restricted)
return self.rpc('blockchain.transaction.get', [tx_hash], restricted, session=session)
def get_transaction_batch(self, txids, restricted=True, session=None):
# use any server if it's old, otherwise restrict to who gave us the history
return self.rpc('blockchain.transaction.get_batch', txids, restricted, session)
return self.rpc('blockchain.transaction.get_batch', txids, restricted, session=session)
def get_transaction_and_merkle(self, tx_hash, known_height=None):
def get_transaction_and_merkle(self, tx_hash, known_height=None, session=None):
# use any server if it's old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.info', [tx_hash], restricted)
return self.rpc('blockchain.transaction.info', [tx_hash], restricted, session=session)
def get_transaction_height(self, tx_hash, known_height=None):
def get_transaction_height(self, tx_hash, known_height=None, session=None):
restricted = not known_height or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted)
return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted, session=session)
def get_merkle(self, tx_hash, height):
def get_merkle(self, tx_hash, height, session=None):
restricted = 0 > height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted)
return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted, session=session)
def get_headers(self, height, count=10000, b64=False):
restricted = height >= self.remote_height - 100
return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted)
# --- Subscribes, history and broadcasts are always aimed towards the master client directly
def get_history(self, address):
return self.rpc('blockchain.address.get_history', [address], True)
def get_history(self, address, session=None):
return self.rpc('blockchain.address.get_history', [address], True, session=session)
def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True)
def broadcast(self, raw_transaction, session=None):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True, session=session)
def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True], True)
async def subscribe_address(self, address, *addresses):
async def subscribe_address(self, session, address, *addresses):
addresses = list((address, ) + addresses)
server_addr_and_port = self.client.server_address_and_port # on disconnect client will be None
server_addr_and_port = session.server_address_and_port # on disconnect client will be None
try:
return await self.rpc('blockchain.address.subscribe', addresses, True)
return await self.rpc('blockchain.address.subscribe', addresses, True, session=session)
except asyncio.TimeoutError:
log.warning(
"timed out subscribing to addresses from %s:%i",
*server_addr_and_port
)
# abort and cancel, we can't lose a subscription, it will happen again on reconnect
if self.client:
self.client.abort()
if session:
session.abort()
raise asyncio.CancelledError()
def unsubscribe_address(self, address):
return self.rpc('blockchain.address.unsubscribe', [address], True)
def unsubscribe_address(self, address, session=None):
return self.rpc('blockchain.address.unsubscribe', [address], True, session=session)
def get_server_features(self):
return self.rpc('server.features', (), restricted=True)
def get_server_features(self, session=None):
return self.rpc('server.features', (), restricted=True, session=session)
def get_claims_by_ids(self, claim_ids):
return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids)
def resolve(self, urls, session_override=None):
return self.rpc('blockchain.claimtrie.resolve', urls, False, session_override)
return self.rpc('blockchain.claimtrie.resolve', urls, False, session=session_override)
def claim_search(self, session_override=None, **kwargs):
return self.rpc('blockchain.claimtrie.search', kwargs, False, session_override)
return self.rpc('blockchain.claimtrie.search', kwargs, False, session=session_override)
async def new_resolve(self, server, urls):
message = {"method": "resolve", "params": {"urls": urls, "protobuf": True}}
@ -371,6 +447,7 @@ class SessionPool:
def __init__(self, network: Network, timeout: float):
self.network = network
self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self.wallet_session_tasks: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self.timeout = timeout
self.new_connection_event = asyncio.Event()
@ -430,6 +507,19 @@ class SessionPool:
task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = task
async def connect_wallet(self, wallet_id: str):
    # Give this wallet its own SPV session, pointed at the fastest server
    # currently known to the pool.
    fastest = await self.wait_for_fastest_session()
    if not self.network.is_connected(wallet_id):
        # No session registered yet for this wallet: create and register one.
        session = self.network.connect_wallet_client(wallet_id, fastest.server)
        # session._on_connect_cb = self._get_session_connect_callback(session)
    else:
        # Reuse the session already registered for this wallet id.
        session = self.network.clients[wallet_id]
    # (Re)start the keep-alive task for the session if it is missing or done.
    task = self.wallet_session_tasks.get(session, None)
    if not task or task.done():
        task = asyncio.create_task(session.ensure_session())
        # task.add_done_callback(lambda _: self.ensure_connections())
        self.wallet_session_tasks[session] = task
def start(self, default_servers):
for server in default_servers:
self._connect_session(server)

View file

@ -1227,10 +1227,7 @@ class LBRYElectrumX(SessionBase):
return await self.address_status(hashX)
async def hashX_unsubscribe(self, hashX, alias):
try:
del self.hashX_subs[hashX]
except ValueError:
pass
self.hashX_subs.pop(hashX, None)
def address_to_hashX(self, address):
try: