diff --git a/lbry/extras/daemon/daemon.py b/lbry/extras/daemon/daemon.py index 270dde189..070c77bd6 100644 --- a/lbry/extras/daemon/daemon.py +++ b/lbry/extras/daemon/daemon.py @@ -2155,11 +2155,11 @@ class Daemon(metaclass=JSONRPCServerType): accounts = wallet.get_accounts_or_all(funding_account_ids) txo = None if claim_id: - txo = await self.ledger.get_claim_by_claim_id(accounts, claim_id) + txo = await self.ledger.get_claim_by_claim_id(accounts, claim_id, include_purchase_receipt=True) if not isinstance(txo, Output) or not txo.is_claim: raise Exception(f"Could not find claim with claim_id '{claim_id}'. ") elif url: - txo = (await self.ledger.resolve(accounts, [url]))[url] + txo = (await self.ledger.resolve(accounts, [url], include_purchase_receipt=True))[url] if not isinstance(txo, Output) or not txo.is_claim: raise Exception(f"Could not find claim with url '{url}'. ") else: diff --git a/lbry/wallet/ledger.py b/lbry/wallet/ledger.py index c9fefd76d..b6870d636 100644 --- a/lbry/wallet/ledger.py +++ b/lbry/wallet/ledger.py @@ -1,5 +1,6 @@ import os import zlib +import copy import base64 import asyncio import logging @@ -660,12 +661,21 @@ class Ledger(metaclass=LedgerRegistry): txs: List[Transaction] = await asyncio.gather(*( self.cache_transaction(*tx) for tx in outputs.txs )) - if include_purchase_receipt and accounts: + + txos, blocked = outputs.inflate(txs) + + includes = ( + include_purchase_receipt, include_is_my_output, + include_sent_supports, include_sent_tips + ) + if accounts and any(includes): + copies = [] + receipts = {} + if include_purchase_receipt: priced_claims = [] - for tx in txs: - for txo in tx.outputs: - if txo.has_price: - priced_claims.append(txo) + for txo in txos: + if isinstance(txo, Output) and txo.has_price: + priced_claims.append(txo) if priced_claims: receipts = { txo.purchased_claim_id: txo for txo in @@ -674,46 +684,48 @@ class Ledger(metaclass=LedgerRegistry): purchased_claim_id__in=[c.claim_id for c in priced_claims] ) } - for txo in priced_claims: - txo.purchase_receipt = receipts.get(txo.claim_id) - txos, blocked = outputs.inflate(txs) - if any((include_is_my_output, include_sent_supports, include_sent_tips)): for txo in txos: if isinstance(txo, Output) and txo.can_decode_claim: + # transactions and outputs are cached and shared between wallets + # we don't want to leak informaion between wallet so we add the + # wallet specific metadata on throw away copies of the txos + txo_copy = copy.copy(txo) + copies.append(txo_copy) + if include_purchase_receipt: + txo_copy.purchase_receipt = receipts.get(txo.claim_id) if include_is_my_output: mine = await self.db.get_txo_count( claim_id=txo.claim_id, txo_type__in=CLAIM_TYPES, is_my_output=True, unspent=True, accounts=accounts ) if mine: - txo.is_my_output = True + txo_copy.is_my_output = True else: - txo.is_my_output = False + txo_copy.is_my_output = False if include_sent_supports: supports = await self.db.get_txo_sum( claim_id=txo.claim_id, txo_type=TXO_TYPES['support'], is_my_input=True, is_my_output=True, unspent=True, accounts=accounts ) - txo.sent_supports = supports + txo_copy.sent_supports = supports if include_sent_tips: tips = await self.db.get_txo_sum( claim_id=txo.claim_id, txo_type=TXO_TYPES['support'], is_my_input=True, is_my_output=False, accounts=accounts ) - txo.sent_tips = tips + txo_copy.sent_tips = tips if include_received_tips: tips = await self.db.get_txo_sum( claim_id=txo.claim_id, txo_type=TXO_TYPES['support'], is_my_input=False, is_my_output=True, accounts=accounts ) - txo.received_tips = tips - if not include_purchase_receipt: - # txo's are cached across wallets, this prevents - # leaking receipts between wallets - txo.purchase_receipt = None + txo_copy.received_tips = tips + else: + copies.append(txo) + txos = copies return txos, blocked, outputs.offset, outputs.total async def resolve(self, accounts, urls, **kwargs): @@ -740,8 +752,8 @@ class Ledger(metaclass=LedgerRegistry): include_is_my_output=include_is_my_output ) - async def get_claim_by_claim_id(self, accounts, claim_id) -> Output: - for claim in (await self.claim_search(accounts, claim_id=claim_id))[0]: + async def get_claim_by_claim_id(self, accounts, claim_id, **kwargs) -> Output: + for claim in (await self.claim_search(accounts, claim_id=claim_id, **kwargs))[0]: return claim async def _report_state(self): diff --git a/tests/integration/blockchain/test_purchase_command.py b/tests/integration/blockchain/test_purchase_command.py index ae68dd7f9..26aae8626 100644 --- a/tests/integration/blockchain/test_purchase_command.py +++ b/tests/integration/blockchain/test_purchase_command.py @@ -147,7 +147,7 @@ class PurchaseCommandTests(CommandTestCase): self.assertEqual(result[1]['claim_id'], result[1]['purchase_receipt']['claim_id']) url = result[0]['canonical_url'] - resolve = await self.resolve(url) + resolve = await self.resolve(url, include_purchase_receipt=True) self.assertEqual(result[0]['claim_id'], resolve['purchase_receipt']['claim_id']) self.assertItemCount(await self.daemon.jsonrpc_file_list(), 0) diff --git a/tests/integration/blockchain/test_resolve_command.py b/tests/integration/blockchain/test_resolve_command.py index 7b9b56196..2681e3826 100644 --- a/tests/integration/blockchain/test_resolve_command.py +++ b/tests/integration/blockchain/test_resolve_command.py @@ -321,7 +321,14 @@ class ResolveCommand(BaseResolveTestCase): self.assertEqual('0.0', resolve['sent_tips']) self.assertEqual('0.9', resolve['received_tips']) self.assertEqual('1.4', resolve['meta']['support_amount']) - self.assertNotIn('purchase_receipt', resolve) # prevent leaking cached receipts + + # make sure nothing is leaked between wallets through cached tx/txos + resolve = await self.resolve('priced') + self.assertNotIn('is_my_output', resolve) + self.assertNotIn('purchase_receipt', resolve) + self.assertNotIn('sent_supports', resolve) + self.assertNotIn('sent_tips', resolve) + self.assertNotIn('received_tips', resolve) class ResolveAfterReorg(BaseResolveTestCase):