From 0d6248fbeb82d10354659cc1e826614e0f3602cd Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Mon, 12 Aug 2019 01:16:15 -0400 Subject: [PATCH] fix torba tests --- lbry/lbry/wallet/ledger.py | 11 +-- .../tests/client_tests/unit/test_database.py | 68 +++++++++---------- torba/torba/client/baseaccount.py | 6 +- torba/torba/client/basedatabase.py | 2 +- torba/torba/client/baseledger.py | 11 +++ 5 files changed, 51 insertions(+), 47 deletions(-) diff --git a/lbry/lbry/wallet/ledger.py b/lbry/lbry/wallet/ledger.py index 4c3537452..ff2d878ed 100644 --- a/lbry/lbry/wallet/ledger.py +++ b/lbry/lbry/wallet/ledger.py @@ -110,15 +110,8 @@ class MainNetLedger(BaseLedger): 'Failed to display wallet state, please file issue ' 'for this bug along with the traceback you see below:') - def constraint_account_or_all(self, constraints): - account = constraints.pop('account', None) - if account: - constraints['accounts'] = [account] - else: - constraints['accounts'] = self.accounts - - def constraint_spending_utxos(self, constraints): - self.constraint_account_or_all(constraints) + @staticmethod + def constraint_spending_utxos(constraints): constraints.update({'is_claim': 0, 'is_update': 0, 'is_support': 0}) def get_utxos(self, **constraints): diff --git a/torba/tests/client_tests/unit/test_database.py b/torba/tests/client_tests/unit/test_database.py index f5e018aed..e5217d12c 100644 --- a/torba/tests/client_tests/unit/test_database.py +++ b/torba/tests/client_tests/unit/test_database.py @@ -259,7 +259,7 @@ class TestQueries(AsyncioTestCase): tx = await self.create_tx_from_txo(tx.outputs[0], account, height=height) variable_limit = self.ledger.db.MAX_QUERY_VARIABLES for limit in range(variable_limit-2, variable_limit+2): - txs = await self.ledger.db.get_transactions(limit=limit, order_by='height asc') + txs = await self.ledger.get_transactions(limit=limit, order_by='height asc') self.assertEqual(len(txs), limit) inputs, outputs, last_tx = set(), set(), txs[0] for tx in txs[1:]: @@ -276,73 +276,73 @@ class TestQueries(AsyncioTestCase): account2 = await self.create_account() self.assertEqual(52, await self.ledger.db.get_address_count()) - self.assertEqual(0, await self.ledger.db.get_transaction_count()) + self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account1, account2])) self.assertEqual(0, await self.ledger.db.get_utxo_count()) self.assertEqual([], await self.ledger.db.get_utxos()) self.assertEqual(0, await self.ledger.db.get_txo_count()) self.assertEqual(0, await self.ledger.db.get_balance()) - self.assertEqual(0, await self.ledger.db.get_balance(account=account1)) - self.assertEqual(0, await self.ledger.db.get_balance(account=account2)) + self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1])) + self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2])) tx1 = await self.create_tx_from_nothing(account1, 1) - self.assertEqual(1, await self.ledger.db.get_transaction_count(account=account1)) - self.assertEqual(0, await self.ledger.db.get_transaction_count(account=account2)) - self.assertEqual(1, await self.ledger.db.get_utxo_count(account=account1)) - self.assertEqual(1, await self.ledger.db.get_txo_count(account=account1)) - self.assertEqual(0, await self.ledger.db.get_txo_count(account=account2)) + self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account1])) + self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account2])) + self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account1])) + self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1])) + self.assertEqual(0, await self.ledger.db.get_txo_count(accounts=[account2])) self.assertEqual(10**8, await self.ledger.db.get_balance()) - self.assertEqual(10**8, await self.ledger.db.get_balance(account=account1)) - self.assertEqual(0, await self.ledger.db.get_balance(account=account2)) + self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account1])) + self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2])) tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2) - self.assertEqual(2, await self.ledger.db.get_transaction_count(account=account1)) - self.assertEqual(1, await self.ledger.db.get_transaction_count(account=account2)) - self.assertEqual(0, await self.ledger.db.get_utxo_count(account=account1)) - self.assertEqual(1, await self.ledger.db.get_txo_count(account=account1)) - self.assertEqual(1, await self.ledger.db.get_utxo_count(account=account2)) - self.assertEqual(1, await self.ledger.db.get_txo_count(account=account2)) + self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1])) + self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account2])) + self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account1])) + self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1])) + self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account2])) + self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2])) self.assertEqual(10**8, await self.ledger.db.get_balance()) - self.assertEqual(0, await self.ledger.db.get_balance(account=account1)) - self.assertEqual(10**8, await self.ledger.db.get_balance(account=account2)) + self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1])) + self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account2])) tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3) - self.assertEqual(2, await self.ledger.db.get_transaction_count(account=account1)) - self.assertEqual(2, await self.ledger.db.get_transaction_count(account=account2)) - self.assertEqual(0, await self.ledger.db.get_utxo_count(account=account1)) - self.assertEqual(1, await self.ledger.db.get_txo_count(account=account1)) - self.assertEqual(0, await self.ledger.db.get_utxo_count(account=account2)) - self.assertEqual(1, await self.ledger.db.get_txo_count(account=account2)) + self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1])) + self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account2])) + self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account1])) + self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1])) + self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account2])) + self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2])) self.assertEqual(0, await self.ledger.db.get_balance()) - self.assertEqual(0, await self.ledger.db.get_balance(account=account1)) - self.assertEqual(0, await self.ledger.db.get_balance(account=account2)) + self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1])) + self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2])) - txs = await self.ledger.db.get_transactions() + txs = await self.ledger.db.get_transactions(accounts=[account1, account2]) self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs]) self.assertEqual([3, 2, 1], [tx.height for tx in txs]) - txs = await self.ledger.db.get_transactions(account=account1) + txs = await self.ledger.db.get_transactions(accounts=[account1]) self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs]) self.assertEqual(txs[0].inputs[0].is_my_account, True) self.assertEqual(txs[0].outputs[0].is_my_account, False) self.assertEqual(txs[1].inputs[0].is_my_account, False) self.assertEqual(txs[1].outputs[0].is_my_account, True) - txs = await self.ledger.db.get_transactions(account=account2) + txs = await self.ledger.db.get_transactions(accounts=[account2]) self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs]) self.assertEqual(txs[0].inputs[0].is_my_account, True) self.assertEqual(txs[0].outputs[0].is_my_account, False) self.assertEqual(txs[1].inputs[0].is_my_account, False) self.assertEqual(txs[1].outputs[0].is_my_account, True) - self.assertEqual(2, await self.ledger.db.get_transaction_count(account=account2)) + self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account2])) tx = await self.ledger.db.get_transaction(txid=tx2.id) self.assertEqual(tx.id, tx2.id) self.assertEqual(tx.inputs[0].is_my_account, False) self.assertEqual(tx.outputs[0].is_my_account, False) - tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account1) + tx = await self.ledger.db.get_transaction(txid=tx2.id, accounts=[account1]) self.assertEqual(tx.inputs[0].is_my_account, True) self.assertEqual(tx.outputs[0].is_my_account, False) - tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account2) + tx = await self.ledger.db.get_transaction(txid=tx2.id, accounts=[account2]) self.assertEqual(tx.inputs[0].is_my_account, False) self.assertEqual(tx.outputs[0].is_my_account, True) @@ -351,6 +351,6 @@ class TestQueries(AsyncioTestCase): txos = await self.ledger.db.get_txos() self.assertEqual([0, 2, 1], [txo.tx_ref.height for txo in txos]) self.assertEqual([tx4.id, tx2.id, tx1.id], [txo.tx_ref.id for txo in txos]) - txs = await self.ledger.db.get_transactions() + txs = await self.ledger.db.get_transactions(accounts=[account1, account2]) self.assertEqual([0, 3, 2, 1], [tx.height for tx in txs]) self.assertEqual([tx4.id, tx3.id, tx2.id, tx1.id], [tx.id for tx in txs]) diff --git a/torba/torba/client/baseaccount.py b/torba/torba/client/baseaccount.py index 9b8328117..b1b43a060 100644 --- a/torba/torba/client/baseaccount.py +++ b/torba/torba/client/baseaccount.py @@ -386,14 +386,14 @@ class BaseAccount: return addresses async def get_addresses(self, **constraints) -> List[str]: - rows = await self.ledger.db.select_addresses('address', account=self, **constraints) + rows = await self.ledger.db.select_addresses('address', accounts=[self], **constraints) return [r[0] for r in rows] def get_address_records(self, **constraints): - return self.ledger.db.get_addresses(account=self, **constraints) + return self.ledger.db.get_addresses(accounts=[self], **constraints) def get_address_count(self, **constraints): - return self.ledger.db.get_address_count(account=self, **constraints) + return self.ledger.db.get_address_count(accounts=[self], **constraints) def get_private_key(self, chain: int, index: int) -> PrivateKey: assert not self.encrypted, "Cannot get private key on encrypted wallet account." diff --git a/torba/torba/client/basedatabase.py b/torba/torba/client/basedatabase.py index ec48b6237..d0f37ccaf 100644 --- a/torba/torba/client/basedatabase.py +++ b/torba/torba/client/basedatabase.py @@ -392,7 +392,7 @@ class BaseDatabase(SQLiteMixin): return True async def select_transactions(self, cols, accounts=None, **constraints): - if 'txid' not in constraints: + if not set(constraints) & {'txid', 'txid__in'}: assert accounts is not None, "'accounts' argument required when no 'txid' constraint" constraints.update({ f'$account{i}': a.public_key.address for i, a in enumerate(accounts) diff --git a/torba/torba/client/baseledger.py b/torba/torba/client/baseledger.py index 8b0c6ad7d..8f16455d9 100644 --- a/torba/torba/client/baseledger.py +++ b/torba/torba/client/baseledger.py @@ -227,16 +227,27 @@ class BaseLedger(metaclass=LedgerRegistry): def release_tx(self, tx): return self.release_outputs([txi.txo_ref.txo for txi in tx.inputs]) + def constraint_account_or_all(self, constraints): + account = constraints.pop('account', None) + if account: + constraints['accounts'] = [account] + else: + constraints['accounts'] = self.accounts + def get_utxos(self, **constraints): + self.constraint_account_or_all(constraints) return self.db.get_utxos(**constraints) def get_utxo_count(self, **constraints): + self.constraint_account_or_all(constraints) return self.db.get_utxo_count(**constraints) def get_transactions(self, **constraints): + self.constraint_account_or_all(constraints) return self.db.get_transactions(**constraints) def get_transaction_count(self, **constraints): + self.constraint_account_or_all(constraints) return self.db.get_transaction_count(**constraints) async def get_local_status_and_history(self, address):