diff --git a/lbry/wallet/ledger.py b/lbry/wallet/ledger.py index 60ad645e7..34bcc0b85 100644 --- a/lbry/wallet/ledger.py +++ b/lbry/wallet/ledger.py @@ -125,7 +125,6 @@ class Ledger(metaclass=LedgerRegistry): self.network.on_status.listen(self.process_status_update) self.accounts = [] - self.pending = 0 self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte) self._on_transaction_controller = StreamController() @@ -557,7 +556,7 @@ class Ledger(metaclass=LedgerRegistry): len(remote_history), address ) remote_history_txids = set(txid for txid, _ in remote_history) - requested_txes = await self._request_transaction_batch(to_request, remote_history_txids, address) + requested_txes = await self.request_synced_transactions(to_request, remote_history_txids, address) for tx in requested_txes: pending_synced_history[tx_indexes[tx.id]] = f"{tx.id}:{tx.height}:" synced_txs.append(tx) @@ -682,7 +681,7 @@ class Ledger(metaclass=LedgerRegistry): tx.position = merkle['pos'] tx.is_verified = merkle_root == header['merkle_root'] - async def request_transactions_for_inflate(self, to_request: Tuple[Tuple[str, int], ...], session_override=None): + async def request_transactions(self, to_request: Tuple[Tuple[str, int], ...], session_override=None): header_cache = {} batches = [[]] remote_heights = {} @@ -702,139 +701,69 @@ class Ledger(metaclass=LedgerRegistry): if not batches[-1]: batches.pop() - async def _single_batch(batch): - if session_override: - batch_result = await self.network.get_transaction_batch( - batch, restricted=False, session=session_override - ) - else: - batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch) - for txid, (raw, merkle) in batch_result.items(): - remote_height = remote_heights[txid] - merkle_height = merkle['block_height'] - cache_item = self._tx_cache.get(txid) - if cache_item is None: - cache_item = TransactionCacheItem() - self._tx_cache[txid] = cache_item - tx = cache_item.tx or Transaction(unhexlify(raw), height=remote_height) - tx.height = remote_height - cache_item.tx = tx - if 'merkle' in merkle and remote_heights[txid] > 0: - merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) - try: - header = header_cache.get(remote_heights[txid]) or (await self.headers.get(merkle_height)) - except IndexError: - log.warning("failed to verify %s at height %i", tx.id, merkle_height) - else: - header_cache[remote_heights[txid]] = header - tx.position = merkle['pos'] - tx.is_verified = merkle_root == header['merkle_root'] - transactions.append(tx) - for batch in batches: - await _single_batch(batch) + transactions.extend(await self._single_batch(batch, remote_heights, header_cache)) return transactions - async def _request_transaction_batch(self, to_request, remote_history, address): - header_cache = {} - batches = [[]] - remote_heights = {} - synced_txs = [] - pending_sync = [] - heights_in_batch = 0 - last_height = 0 - for idx in sorted(to_request): - txid = to_request[idx][0] - height = to_request[idx][1] - remote_heights[txid] = height - if txid not in self._tx_cache: - self._tx_cache[txid] = TransactionCacheItem() - elif self._tx_cache[txid].tx is not None and self._tx_cache[txid].tx.is_verified: - log.warning("has: %s", txid) - pending_sync.append(self._tx_cache[txid].tx) - continue - if height != last_height: - heights_in_batch += 1 - last_height = height - if len(batches[-1]) == 100 or heights_in_batch == 20: - batches.append([]) - heights_in_batch = 1 - batches[-1].append(txid) - if not batches[-1]: - batches.pop() + async def request_synced_transactions(self, to_request, remote_history, address): + pending_sync = await self.request_transactions(((txid, height) for txid, height in to_request.values())) + await asyncio.gather(*(self._sync(tx, remote_history) for tx in pending_sync)) + return pending_sync - last_showed_synced_count = 0 - - async def _single_batch(batch): - batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch) - for txid, (raw, merkle) in batch_result.items(): - log.warning("arrived batch %s", txid) - remote_height = remote_heights[txid] - merkle_height = merkle['block_height'] - tx = Transaction(unhexlify(raw), height=remote_height) - tx.height = remote_height - if 'merkle' in merkle and remote_heights[txid] > 0: - merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) - try: - header = header_cache.get(remote_heights[txid]) or (await self.headers.get(merkle_height)) - except IndexError: - log.warning("failed to verify %s at height %i", tx.id, merkle_height) - else: - header_cache[remote_heights[txid]] = header - tx.position = merkle['pos'] - tx.is_verified = merkle_root == header['merkle_root'] - self._tx_cache[txid].tx = tx - pending_sync.append(tx) - - async def __sync(tx): - check_db_for_txos = [] - log.warning("%s", tx.id) - for txi in tx.inputs: - if txi.txo_ref.txo is not None: - continue - if txi.txo_ref.tx_ref.id not in remote_history: - continue - if txi.txo_ref.tx_ref.id in self._tx_cache: - continue + async def _single_batch(self, batch, remote_heights, header_cache): + batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch) + transactions = [] + for txid, (raw, merkle) in batch_result.items(): + remote_height = remote_heights[txid] + merkle_height = merkle['block_height'] + tx = Transaction(unhexlify(raw), height=remote_height) + tx.height = remote_height + if 'merkle' in merkle and remote_heights[txid] > 0: + merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) + try: + header = header_cache.get(remote_heights[txid]) or (await self.headers.get(merkle_height)) + except IndexError: + log.warning("failed to verify %s at height %i", tx.id, merkle_height) else: - check_db_for_txos.append(txi.txo_ref.id) + header_cache[remote_heights[txid]] = header + tx.position = merkle['pos'] + tx.is_verified = merkle_root == header['merkle_root'] + self._tx_cache[txid].tx = tx + transactions.append(tx) + return transactions - referenced_txos = {} if not check_db_for_txos else { - txo.id: txo for txo in await self.db.get_txos( - txoid__in=check_db_for_txos, order_by='txo.txoid', no_tx=True - ) - } + async def _sync(self, tx, remote_history): + check_db_for_txos = [] + for txi in tx.inputs: + if txi.txo_ref.txo is not None: + continue + if txi.txo_ref.tx_ref.id not in remote_history: + continue + if txi.txo_ref.tx_ref.id in self._tx_cache: + continue + else: + check_db_for_txos.append(txi.txo_ref.id) - for txi in tx.inputs: - if txi.txo_ref.txo is not None: - continue - if txi.txo_ref.tx_ref.id not in remote_history: - continue - referenced_txo = referenced_txos.get(txi.txo_ref.id) - if referenced_txo is not None: - txi.txo_ref = referenced_txo.ref - continue - wanted_txid = txi.txo_ref.tx_ref.id - if wanted_txid in self._tx_cache: - log.warning("waiting on %s", wanted_txid) - self.pending += 1 - log.warning("total pending %s", self.pending) - await self._tx_cache[wanted_txid].has_tx.wait() - log.warning("got %s", wanted_txid) - self.pending -= 1 - log.warning("total pending %s", self.pending) - txi.txo_ref = self._tx_cache[wanted_txid].tx.outputs[txi.txo_ref.position].ref + referenced_txos = {} if not check_db_for_txos else { + txo.id: txo for txo in await self.db.get_txos( + txoid__in=check_db_for_txos, order_by='txo.txoid', no_tx=True + ) + } - synced_txs.append(tx) - nonlocal last_showed_synced_count - if last_showed_synced_count + 100 < len(synced_txs): - log.info("synced %i/%i transactions for %s", len(synced_txs), len(remote_history), address) - last_showed_synced_count = len(synced_txs) - - for batch in batches: - await _single_batch(batch) - await asyncio.gather(*(__sync(tx) for tx in pending_sync)) - return synced_txs + for txi in tx.inputs: + if txi.txo_ref.txo is not None: + continue + if txi.txo_ref.tx_ref.id not in remote_history: + continue + referenced_txo = referenced_txos.get(txi.txo_ref.id) + if referenced_txo is not None: + txi.txo_ref = referenced_txo.ref + continue + wanted_txid = txi.txo_ref.tx_ref.id + if wanted_txid in self._tx_cache: + await self._tx_cache[wanted_txid].has_tx.wait() + txi.txo_ref = self._tx_cache[wanted_txid].tx.outputs[txi.txo_ref.position].ref + return tx async def get_address_manager_for_address(self, address) -> Optional[AddressManager]: details = await self.db.get_address(address=address) @@ -907,7 +836,7 @@ class Ledger(metaclass=LedgerRegistry): if len(outputs.txs) > 0: txs: List[Transaction] = [] if session_override: - txs.extend((await self.request_transactions_for_inflate(tuple(outputs.txs), session_override))) + txs.extend((await self.request_transactions(tuple(outputs.txs), session_override))) else: txs.extend((await asyncio.gather(*(self.cache_transaction(*tx) for tx in outputs.txs))))