Merge pull request #2716 from lbryio/batched-address-subscriptions

speed up wallet sync and startup by batching address history subscriptions
This commit is contained in:
Jack Robison 2020-01-10 21:32:12 -05:00 committed by GitHub
commit ea7056835f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 57 additions and 35 deletions

View file

@@ -24,7 +24,7 @@ class BlobAnnouncer:
else: else:
log.debug("failed to announce %s, could only find %d peers, retrying soon.", blob_hash[:8], peers) log.debug("failed to announce %s, could only find %d peers, retrying soon.", blob_hash[:8], peers)
except Exception as err: except Exception as err:
if isinstance(err, asyncio.CancelledError): if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise err raise err
log.warning("error announcing %s: %s", blob_hash[:8], str(err)) log.warning("error announcing %s: %s", blob_hash[:8], str(err))

View file

@ -376,7 +376,7 @@ class UPnPComponent(Component):
self.upnp = await UPnP.discover(loop=self.component_manager.loop) self.upnp = await UPnP.discover(loop=self.component_manager.loop)
log.info("found upnp gateway: %s", self.upnp.gateway.manufacturer_string) log.info("found upnp gateway: %s", self.upnp.gateway.manufacturer_string)
except Exception as err: except Exception as err:
if isinstance(err, asyncio.CancelledError): if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise raise
log.warning("upnp discovery failed: %s", err) log.warning("upnp discovery failed: %s", err)
self.upnp = None self.upnp = None

View file

@@ -371,7 +371,7 @@ class StreamManager:
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise ResolveTimeoutError(uri) raise ResolveTimeoutError(uri)
except Exception as err: except Exception as err:
if isinstance(err, asyncio.CancelledError): if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise raise
log.exception("Unexpected error resolving stream:") log.exception("Unexpected error resolving stream:")
raise ResolveError(f"Unexpected error resolving stream: {str(err)}") raise ResolveError(f"Unexpected error resolving stream: {str(err)}")

View file

@@ -416,7 +416,7 @@ class Database(SQLiteMixin):
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified 'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
}, 'txid = ?', (tx.id,))) }, 'txid = ?', (tx.id,)))
def _transaction_io(self, conn: sqlite3.Connection, tx: Transaction, address, txhash, history): def _transaction_io(self, conn: sqlite3.Connection, tx: Transaction, address, txhash):
conn.execute(*self._insert_sql('tx', self.tx_to_row(tx), replace=True)) conn.execute(*self._insert_sql('tx', self.tx_to_row(tx), replace=True))
for txo in tx.outputs: for txo in tx.outputs:
@@ -438,18 +438,20 @@ class Database(SQLiteMixin):
'address': address, 'address': address,
}, ignore_duplicate=True)).fetchall() }, ignore_duplicate=True)).fetchall()
conn.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history.count(':') // 2, address)
)
def save_transaction_io(self, tx: Transaction, address, txhash, history): def save_transaction_io(self, tx: Transaction, address, txhash, history):
return self.db.run(self._transaction_io, tx, address, txhash, history) return self.save_transaction_io_batch([tx], address, txhash, history)
def save_transaction_io_batch(self, txs: Iterable[Transaction], address, txhash, history): def save_transaction_io_batch(self, txs: Iterable[Transaction], address, txhash, history):
history_count = history.count(':') // 2
def __many(conn): def __many(conn):
for tx in txs: for tx in txs:
self._transaction_io(conn, tx, address, txhash, history) self._transaction_io(conn, tx, address, txhash)
conn.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history_count, address)
).fetchall()
return self.db.run(__many) return self.db.run(__many)
async def reserve_outputs(self, txos, is_reserved=True): async def reserve_outputs(self, txos, is_reserved=True):

View file

@@ -7,9 +7,9 @@ from io import StringIO
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from operator import itemgetter from operator import itemgetter
from collections import namedtuple from collections import namedtuple, defaultdict
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from typing import Dict, Tuple, Type, Iterable, List, Optional from typing import Dict, Tuple, Type, Iterable, List, Optional, DefaultDict
import pylru import pylru
from lbry.schema.result import Outputs from lbry.schema.result import Outputs
@@ -154,7 +154,7 @@ class Ledger(metaclass=LedgerRegistry):
self._update_tasks = TaskGroup() self._update_tasks = TaskGroup()
self._utxo_reservation_lock = asyncio.Lock() self._utxo_reservation_lock = asyncio.Lock()
self._header_processing_lock = asyncio.Lock() self._header_processing_lock = asyncio.Lock()
self._address_update_locks: Dict[str, asyncio.Lock] = {} self._address_update_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self.coin_selection_strategy = None self.coin_selection_strategy = None
self._known_addresses_out_of_sync = set() self._known_addresses_out_of_sync = set()
@@ -425,6 +425,7 @@ class Ledger(metaclass=LedgerRegistry):
async def subscribe_accounts(self): async def subscribe_accounts(self):
if self.network.is_connected and self.accounts: if self.network.is_connected and self.accounts:
log.info("Subscribe to %i accounts", len(self.accounts))
await asyncio.wait([ await asyncio.wait([
self.subscribe_account(a) for a in self.accounts self.subscribe_account(a) for a in self.accounts
]) ])
@@ -444,24 +445,28 @@ class Ledger(metaclass=LedgerRegistry):
AddressesGeneratedEvent(address_manager, addresses) AddressesGeneratedEvent(address_manager, addresses)
) )
async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str]): async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
if self.network.is_connected and addresses: if self.network.is_connected and addresses:
await asyncio.wait([ addresses_remaining = list(addresses)
self.subscribe_address(address_manager, address) for address in addresses while addresses_remaining:
]) batch = addresses_remaining[:batch_size]
results = await self.network.subscribe_address(*batch)
async def subscribe_address(self, address_manager: AddressManager, address: str): for address, remote_status in zip(batch, results):
remote_status = await self.network.subscribe_address(address) self._update_tasks.add(self.update_history(address, remote_status, address_manager))
self._update_tasks.add(self.update_history(address, remote_status, address_manager)) addresses_remaining = addresses_remaining[batch_size:]
log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
len(addresses), *self.network.client.server_address_and_port)
log.info(
"finished subscribing to %i addresses on %s:%i", len(addresses),
*self.network.client.server_address_and_port
)
def process_status_update(self, update): def process_status_update(self, update):
address, remote_status = update address, remote_status = update
self._update_tasks.add(self.update_history(address, remote_status)) self._update_tasks.add(self.update_history(address, remote_status))
async def update_history(self, address, remote_status, async def update_history(self, address, remote_status, address_manager: AddressManager = None):
address_manager: AddressManager = None): async with self._address_update_locks[address]:
async with self._address_update_locks.setdefault(address, asyncio.Lock()):
self._known_addresses_out_of_sync.discard(address) self._known_addresses_out_of_sync.discard(address)
local_status, local_history = await self.get_local_status_and_history(address) local_status, local_history = await self.get_local_status_and_history(address)
@@ -685,7 +690,9 @@ class Ledger(metaclass=LedgerRegistry):
"%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ", "%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ",
account.id, balance, total_receiving, account.receiving.gap, total_change, account.id, balance, total_receiving, account.receiving.gap, total_change,
account.change.gap, channel_count, len(account.channel_keys), claim_count) account.change.gap, channel_count, len(account.channel_keys), claim_count)
except: # pylint: disable=bare-except except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.exception( log.exception(
'Failed to display wallet state, please file issue ' 'Failed to display wallet state, please file issue '
'for this bug along with the traceback you see below:') 'for this bug along with the traceback you see below:')
@@ -708,7 +715,9 @@ class Ledger(metaclass=LedgerRegistry):
claim_ids = [p.purchased_claim_id for p in purchases] claim_ids = [p.purchased_claim_id for p in purchases]
try: try:
resolved, _, _ = await self.claim_search([], claim_ids=claim_ids) resolved, _, _ = await self.claim_search([], claim_ids=claim_ids)
except: # pylint: disable=bare-except except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.exception("Resolve failed while looking up purchased claim ids:") log.exception("Resolve failed while looking up purchased claim ids:")
resolved = [] resolved = []
lookup = {claim.claim_id: claim for claim in resolved} lookup = {claim.claim_id: claim for claim in resolved}
@@ -741,7 +750,9 @@ class Ledger(metaclass=LedgerRegistry):
claim_ids = collection.claim.collection.claims.ids[offset:page_size+offset] claim_ids = collection.claim.collection.claims.ids[offset:page_size+offset]
try: try:
resolve_results, _, _ = await self.claim_search([], claim_ids=claim_ids) resolve_results, _, _ = await self.claim_search([], claim_ids=claim_ids)
except: # pylint: disable=bare-except except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.exception("Resolve failed while looking up collection claim ids:") log.exception("Resolve failed while looking up collection claim ids:")
return [] return []
claims = [] claims = []

View file

@@ -256,10 +256,15 @@ class Network:
def subscribe_headers(self): def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True], True) return self.rpc('blockchain.headers.subscribe', [True], True)
async def subscribe_address(self, address): async def subscribe_address(self, address, *addresses):
addresses = list((address, ) + addresses)
try: try:
return await self.rpc('blockchain.address.subscribe', [address], True) return await self.rpc('blockchain.address.subscribe', addresses, True)
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.warning(
"timed out subscribing to addresses from %s:%i",
*self.client.server_address_and_port
)
# abort and cancel, we can't lose a subscription, it will happen again on reconnect # abort and cancel, we can't lose a subscription, it will happen again on reconnect
if self.client: if self.client:
self.client.abort() self.client.abort()

View file

@@ -1141,12 +1141,16 @@ class LBRYElectrumX(SessionBase):
hashX = self.address_to_hashX(address) hashX = self.address_to_hashX(address)
return await self.hashX_listunspent(hashX) return await self.hashX_listunspent(hashX)
async def address_subscribe(self, address): async def address_subscribe(self, *addresses):
"""Subscribe to an address. """Subscribe to an address.
address: the address to subscribe to""" address: the address to subscribe to"""
hashX = self.address_to_hashX(address) if len(addresses) > 1000:
return await self.hashX_subscribe(hashX, address) raise RPCError(BAD_REQUEST, f'too many addresses in subscription request: {len(addresses)}')
hashXes = [
(self.address_to_hashX(address), address) for address in addresses
]
return await asyncio.gather(*(self.hashX_subscribe(*args) for args in hashXes))
async def address_unsubscribe(self, address): async def address_unsubscribe(self, address):
"""Unsubscribe an address. """Unsubscribe an address.

View file

@@ -320,7 +320,7 @@ class TestStreamManager(BlobExchangeTestBase):
try: try:
await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout) await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout)
except Exception as err: except Exception as err:
if isinstance(err, asyncio.CancelledError): if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise raise
error = err error = err
self.assertEqual(expected_error, type(error)) self.assertEqual(expected_error, type(error))