Merge pull request #2716 from lbryio/batched-address-subscriptions

speed up wallet sync and startup by batching address history subscriptions
This commit is contained in:
Jack Robison 2020-01-10 21:32:12 -05:00 committed by GitHub
commit ea7056835f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 57 additions and 35 deletions

View file

@ -24,7 +24,7 @@ class BlobAnnouncer:
else:
log.debug("failed to announce %s, could only find %d peers, retrying soon.", blob_hash[:8], peers)
except Exception as err:
if isinstance(err, asyncio.CancelledError):
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise err
log.warning("error announcing %s: %s", blob_hash[:8], str(err))

View file

@ -376,7 +376,7 @@ class UPnPComponent(Component):
self.upnp = await UPnP.discover(loop=self.component_manager.loop)
log.info("found upnp gateway: %s", self.upnp.gateway.manufacturer_string)
except Exception as err:
if isinstance(err, asyncio.CancelledError):
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.warning("upnp discovery failed: %s", err)
self.upnp = None

View file

@ -371,7 +371,7 @@ class StreamManager:
except asyncio.TimeoutError:
raise ResolveTimeoutError(uri)
except Exception as err:
if isinstance(err, asyncio.CancelledError):
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.exception("Unexpected error resolving stream:")
raise ResolveError(f"Unexpected error resolving stream: {str(err)}")

View file

@ -416,7 +416,7 @@ class Database(SQLiteMixin):
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
}, 'txid = ?', (tx.id,)))
def _transaction_io(self, conn: sqlite3.Connection, tx: Transaction, address, txhash, history):
def _transaction_io(self, conn: sqlite3.Connection, tx: Transaction, address, txhash):
conn.execute(*self._insert_sql('tx', self.tx_to_row(tx), replace=True))
for txo in tx.outputs:
@ -438,18 +438,20 @@ class Database(SQLiteMixin):
'address': address,
}, ignore_duplicate=True)).fetchall()
conn.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history.count(':') // 2, address)
)
def save_transaction_io(self, tx: Transaction, address, txhash, history):
return self.db.run(self._transaction_io, tx, address, txhash, history)
return self.save_transaction_io_batch([tx], address, txhash, history)
def save_transaction_io_batch(self, txs: Iterable[Transaction], address, txhash, history):
history_count = history.count(':') // 2
def __many(conn):
for tx in txs:
self._transaction_io(conn, tx, address, txhash, history)
self._transaction_io(conn, tx, address, txhash)
conn.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history_count, address)
).fetchall()
return self.db.run(__many)
async def reserve_outputs(self, txos, is_reserved=True):

View file

@ -7,9 +7,9 @@ from io import StringIO
from datetime import datetime
from functools import partial
from operator import itemgetter
from collections import namedtuple
from collections import namedtuple, defaultdict
from binascii import hexlify, unhexlify
from typing import Dict, Tuple, Type, Iterable, List, Optional
from typing import Dict, Tuple, Type, Iterable, List, Optional, DefaultDict
import pylru
from lbry.schema.result import Outputs
@ -154,7 +154,7 @@ class Ledger(metaclass=LedgerRegistry):
self._update_tasks = TaskGroup()
self._utxo_reservation_lock = asyncio.Lock()
self._header_processing_lock = asyncio.Lock()
self._address_update_locks: Dict[str, asyncio.Lock] = {}
self._address_update_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self.coin_selection_strategy = None
self._known_addresses_out_of_sync = set()
@ -425,6 +425,7 @@ class Ledger(metaclass=LedgerRegistry):
async def subscribe_accounts(self):
if self.network.is_connected and self.accounts:
log.info("Subscribe to %i accounts", len(self.accounts))
await asyncio.wait([
self.subscribe_account(a) for a in self.accounts
])
@ -444,24 +445,28 @@ class Ledger(metaclass=LedgerRegistry):
AddressesGeneratedEvent(address_manager, addresses)
)
async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str]):
async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
if self.network.is_connected and addresses:
await asyncio.wait([
self.subscribe_address(address_manager, address) for address in addresses
])
async def subscribe_address(self, address_manager: AddressManager, address: str):
remote_status = await self.network.subscribe_address(address)
addresses_remaining = list(addresses)
while addresses_remaining:
batch = addresses_remaining[:batch_size]
results = await self.network.subscribe_address(*batch)
for address, remote_status in zip(batch, results):
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
addresses_remaining = addresses_remaining[batch_size:]
log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
len(addresses), *self.network.client.server_address_and_port)
log.info(
"finished subscribing to %i addresses on %s:%i", len(addresses),
*self.network.client.server_address_and_port
)
def process_status_update(self, update):
address, remote_status = update
self._update_tasks.add(self.update_history(address, remote_status))
async def update_history(self, address, remote_status,
address_manager: AddressManager = None):
async with self._address_update_locks.setdefault(address, asyncio.Lock()):
async def update_history(self, address, remote_status, address_manager: AddressManager = None):
async with self._address_update_locks[address]:
self._known_addresses_out_of_sync.discard(address)
local_status, local_history = await self.get_local_status_and_history(address)
@ -685,7 +690,9 @@ class Ledger(metaclass=LedgerRegistry):
"%d change addresses (gap: %d), %d channels, %d certificates and %d claims. ",
account.id, balance, total_receiving, account.receiving.gap, total_change,
account.change.gap, channel_count, len(account.channel_keys), claim_count)
except: # pylint: disable=bare-except
except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.exception(
'Failed to display wallet state, please file issue '
'for this bug along with the traceback you see below:')
@ -708,7 +715,9 @@ class Ledger(metaclass=LedgerRegistry):
claim_ids = [p.purchased_claim_id for p in purchases]
try:
resolved, _, _ = await self.claim_search([], claim_ids=claim_ids)
except: # pylint: disable=bare-except
except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.exception("Resolve failed while looking up purchased claim ids:")
resolved = []
lookup = {claim.claim_id: claim for claim in resolved}
@ -741,7 +750,9 @@ class Ledger(metaclass=LedgerRegistry):
claim_ids = collection.claim.collection.claims.ids[offset:page_size+offset]
try:
resolve_results, _, _ = await self.claim_search([], claim_ids=claim_ids)
except: # pylint: disable=bare-except
except Exception as err:
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
log.exception("Resolve failed while looking up collection claim ids:")
return []
claims = []

View file

@ -256,10 +256,15 @@ class Network:
def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True], True)
async def subscribe_address(self, address):
async def subscribe_address(self, address, *addresses):
addresses = list((address, ) + addresses)
try:
return await self.rpc('blockchain.address.subscribe', [address], True)
return await self.rpc('blockchain.address.subscribe', addresses, True)
except asyncio.TimeoutError:
log.warning(
"timed out subscribing to addresses from %s:%i",
*self.client.server_address_and_port
)
# abort and cancel, we can't lose a subscription, it will happen again on reconnect
if self.client:
self.client.abort()

View file

@ -1141,12 +1141,16 @@ class LBRYElectrumX(SessionBase):
hashX = self.address_to_hashX(address)
return await self.hashX_listunspent(hashX)
async def address_subscribe(self, address):
async def address_subscribe(self, *addresses):
"""Subscribe to an address.
address: the address to subscribe to"""
hashX = self.address_to_hashX(address)
return await self.hashX_subscribe(hashX, address)
if len(addresses) > 1000:
raise RPCError(BAD_REQUEST, f'too many addresses in subscription request: {len(addresses)}')
hashXes = [
(self.address_to_hashX(address), address) for address in addresses
]
return await asyncio.gather(*(self.hashX_subscribe(*args) for args in hashXes))
async def address_unsubscribe(self, address):
"""Unsubscribe an address.

View file

@ -320,7 +320,7 @@ class TestStreamManager(BlobExchangeTestBase):
try:
await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout)
except Exception as err:
if isinstance(err, asyncio.CancelledError):
if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8
raise
error = err
self.assertEqual(expected_error, type(error))