fix torba tests

This commit is contained in:
Lex Berezhny 2019-08-12 01:16:15 -04:00
parent 98d4d00f96
commit 0d6248fbeb
5 changed files with 51 additions and 47 deletions

View file

@ -110,15 +110,8 @@ class MainNetLedger(BaseLedger):
'Failed to display wallet state, please file issue ' 'Failed to display wallet state, please file issue '
'for this bug along with the traceback you see below:') 'for this bug along with the traceback you see below:')
def constraint_account_or_all(self, constraints): @staticmethod
account = constraints.pop('account', None) def constraint_spending_utxos(constraints):
if account:
constraints['accounts'] = [account]
else:
constraints['accounts'] = self.accounts
def constraint_spending_utxos(self, constraints):
self.constraint_account_or_all(constraints)
constraints.update({'is_claim': 0, 'is_update': 0, 'is_support': 0}) constraints.update({'is_claim': 0, 'is_update': 0, 'is_support': 0})
def get_utxos(self, **constraints): def get_utxos(self, **constraints):

View file

@ -259,7 +259,7 @@ class TestQueries(AsyncioTestCase):
tx = await self.create_tx_from_txo(tx.outputs[0], account, height=height) tx = await self.create_tx_from_txo(tx.outputs[0], account, height=height)
variable_limit = self.ledger.db.MAX_QUERY_VARIABLES variable_limit = self.ledger.db.MAX_QUERY_VARIABLES
for limit in range(variable_limit-2, variable_limit+2): for limit in range(variable_limit-2, variable_limit+2):
txs = await self.ledger.db.get_transactions(limit=limit, order_by='height asc') txs = await self.ledger.get_transactions(limit=limit, order_by='height asc')
self.assertEqual(len(txs), limit) self.assertEqual(len(txs), limit)
inputs, outputs, last_tx = set(), set(), txs[0] inputs, outputs, last_tx = set(), set(), txs[0]
for tx in txs[1:]: for tx in txs[1:]:
@ -276,73 +276,73 @@ class TestQueries(AsyncioTestCase):
account2 = await self.create_account() account2 = await self.create_account()
self.assertEqual(52, await self.ledger.db.get_address_count()) self.assertEqual(52, await self.ledger.db.get_address_count())
self.assertEqual(0, await self.ledger.db.get_transaction_count()) self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account1, account2]))
self.assertEqual(0, await self.ledger.db.get_utxo_count()) self.assertEqual(0, await self.ledger.db.get_utxo_count())
self.assertEqual([], await self.ledger.db.get_utxos()) self.assertEqual([], await self.ledger.db.get_utxos())
self.assertEqual(0, await self.ledger.db.get_txo_count()) self.assertEqual(0, await self.ledger.db.get_txo_count())
self.assertEqual(0, await self.ledger.db.get_balance()) self.assertEqual(0, await self.ledger.db.get_balance())
self.assertEqual(0, await self.ledger.db.get_balance(account=account1)) self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_balance(account=account2)) self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
tx1 = await self.create_tx_from_nothing(account1, 1) tx1 = await self.create_tx_from_nothing(account1, 1)
self.assertEqual(1, await self.ledger.db.get_transaction_count(account=account1)) self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_transaction_count(account=account2)) self.assertEqual(0, await self.ledger.db.get_transaction_count(accounts=[account2]))
self.assertEqual(1, await self.ledger.db.get_utxo_count(account=account1)) self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_txo_count(account=account1)) self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_txo_count(account=account2)) self.assertEqual(0, await self.ledger.db.get_txo_count(accounts=[account2]))
self.assertEqual(10**8, await self.ledger.db.get_balance()) self.assertEqual(10**8, await self.ledger.db.get_balance())
self.assertEqual(10**8, await self.ledger.db.get_balance(account=account1)) self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_balance(account=account2)) self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2) tx2 = await self.create_tx_from_txo(tx1.outputs[0], account2, 2)
self.assertEqual(2, await self.ledger.db.get_transaction_count(account=account1)) self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_transaction_count(account=account2)) self.assertEqual(1, await self.ledger.db.get_transaction_count(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_utxo_count(account=account1)) self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_txo_count(account=account1)) self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_utxo_count(account=account2)) self.assertEqual(1, await self.ledger.db.get_utxo_count(accounts=[account2]))
self.assertEqual(1, await self.ledger.db.get_txo_count(account=account2)) self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2]))
self.assertEqual(10**8, await self.ledger.db.get_balance()) self.assertEqual(10**8, await self.ledger.db.get_balance())
self.assertEqual(0, await self.ledger.db.get_balance(account=account1)) self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(10**8, await self.ledger.db.get_balance(account=account2)) self.assertEqual(10**8, await self.ledger.db.get_balance(accounts=[account2]))
tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3) tx3 = await self.create_tx_to_nowhere(tx2.outputs[0], 3)
self.assertEqual(2, await self.ledger.db.get_transaction_count(account=account1)) self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account1]))
self.assertEqual(2, await self.ledger.db.get_transaction_count(account=account2)) self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_utxo_count(account=account1)) self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account1]))
self.assertEqual(1, await self.ledger.db.get_txo_count(account=account1)) self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_utxo_count(account=account2)) self.assertEqual(0, await self.ledger.db.get_utxo_count(accounts=[account2]))
self.assertEqual(1, await self.ledger.db.get_txo_count(account=account2)) self.assertEqual(1, await self.ledger.db.get_txo_count(accounts=[account2]))
self.assertEqual(0, await self.ledger.db.get_balance()) self.assertEqual(0, await self.ledger.db.get_balance())
self.assertEqual(0, await self.ledger.db.get_balance(account=account1)) self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account1]))
self.assertEqual(0, await self.ledger.db.get_balance(account=account2)) self.assertEqual(0, await self.ledger.db.get_balance(accounts=[account2]))
txs = await self.ledger.db.get_transactions() txs = await self.ledger.db.get_transactions(accounts=[account1, account2])
self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs]) self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
self.assertEqual([3, 2, 1], [tx.height for tx in txs]) self.assertEqual([3, 2, 1], [tx.height for tx in txs])
txs = await self.ledger.db.get_transactions(account=account1) txs = await self.ledger.db.get_transactions(accounts=[account1])
self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs]) self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs])
self.assertEqual(txs[0].inputs[0].is_my_account, True) self.assertEqual(txs[0].inputs[0].is_my_account, True)
self.assertEqual(txs[0].outputs[0].is_my_account, False) self.assertEqual(txs[0].outputs[0].is_my_account, False)
self.assertEqual(txs[1].inputs[0].is_my_account, False) self.assertEqual(txs[1].inputs[0].is_my_account, False)
self.assertEqual(txs[1].outputs[0].is_my_account, True) self.assertEqual(txs[1].outputs[0].is_my_account, True)
txs = await self.ledger.db.get_transactions(account=account2) txs = await self.ledger.db.get_transactions(accounts=[account2])
self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs]) self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs])
self.assertEqual(txs[0].inputs[0].is_my_account, True) self.assertEqual(txs[0].inputs[0].is_my_account, True)
self.assertEqual(txs[0].outputs[0].is_my_account, False) self.assertEqual(txs[0].outputs[0].is_my_account, False)
self.assertEqual(txs[1].inputs[0].is_my_account, False) self.assertEqual(txs[1].inputs[0].is_my_account, False)
self.assertEqual(txs[1].outputs[0].is_my_account, True) self.assertEqual(txs[1].outputs[0].is_my_account, True)
self.assertEqual(2, await self.ledger.db.get_transaction_count(account=account2)) self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account2]))
tx = await self.ledger.db.get_transaction(txid=tx2.id) tx = await self.ledger.db.get_transaction(txid=tx2.id)
self.assertEqual(tx.id, tx2.id) self.assertEqual(tx.id, tx2.id)
self.assertEqual(tx.inputs[0].is_my_account, False) self.assertEqual(tx.inputs[0].is_my_account, False)
self.assertEqual(tx.outputs[0].is_my_account, False) self.assertEqual(tx.outputs[0].is_my_account, False)
tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account1) tx = await self.ledger.db.get_transaction(txid=tx2.id, accounts=[account1])
self.assertEqual(tx.inputs[0].is_my_account, True) self.assertEqual(tx.inputs[0].is_my_account, True)
self.assertEqual(tx.outputs[0].is_my_account, False) self.assertEqual(tx.outputs[0].is_my_account, False)
tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account2) tx = await self.ledger.db.get_transaction(txid=tx2.id, accounts=[account2])
self.assertEqual(tx.inputs[0].is_my_account, False) self.assertEqual(tx.inputs[0].is_my_account, False)
self.assertEqual(tx.outputs[0].is_my_account, True) self.assertEqual(tx.outputs[0].is_my_account, True)
@ -351,6 +351,6 @@ class TestQueries(AsyncioTestCase):
txos = await self.ledger.db.get_txos() txos = await self.ledger.db.get_txos()
self.assertEqual([0, 2, 1], [txo.tx_ref.height for txo in txos]) self.assertEqual([0, 2, 1], [txo.tx_ref.height for txo in txos])
self.assertEqual([tx4.id, tx2.id, tx1.id], [txo.tx_ref.id for txo in txos]) self.assertEqual([tx4.id, tx2.id, tx1.id], [txo.tx_ref.id for txo in txos])
txs = await self.ledger.db.get_transactions() txs = await self.ledger.db.get_transactions(accounts=[account1, account2])
self.assertEqual([0, 3, 2, 1], [tx.height for tx in txs]) self.assertEqual([0, 3, 2, 1], [tx.height for tx in txs])
self.assertEqual([tx4.id, tx3.id, tx2.id, tx1.id], [tx.id for tx in txs]) self.assertEqual([tx4.id, tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])

View file

@ -386,14 +386,14 @@ class BaseAccount:
return addresses return addresses
async def get_addresses(self, **constraints) -> List[str]: async def get_addresses(self, **constraints) -> List[str]:
rows = await self.ledger.db.select_addresses('address', account=self, **constraints) rows = await self.ledger.db.select_addresses('address', accounts=[self], **constraints)
return [r[0] for r in rows] return [r[0] for r in rows]
def get_address_records(self, **constraints): def get_address_records(self, **constraints):
return self.ledger.db.get_addresses(account=self, **constraints) return self.ledger.db.get_addresses(accounts=[self], **constraints)
def get_address_count(self, **constraints): def get_address_count(self, **constraints):
return self.ledger.db.get_address_count(account=self, **constraints) return self.ledger.db.get_address_count(accounts=[self], **constraints)
def get_private_key(self, chain: int, index: int) -> PrivateKey: def get_private_key(self, chain: int, index: int) -> PrivateKey:
assert not self.encrypted, "Cannot get private key on encrypted wallet account." assert not self.encrypted, "Cannot get private key on encrypted wallet account."

View file

@ -392,7 +392,7 @@ class BaseDatabase(SQLiteMixin):
return True return True
async def select_transactions(self, cols, accounts=None, **constraints): async def select_transactions(self, cols, accounts=None, **constraints):
if 'txid' not in constraints: if not set(constraints) & {'txid', 'txid__in'}:
assert accounts is not None, "'accounts' argument required when no 'txid' constraint" assert accounts is not None, "'accounts' argument required when no 'txid' constraint"
constraints.update({ constraints.update({
f'$account{i}': a.public_key.address for i, a in enumerate(accounts) f'$account{i}': a.public_key.address for i, a in enumerate(accounts)

View file

@ -227,16 +227,27 @@ class BaseLedger(metaclass=LedgerRegistry):
def release_tx(self, tx): def release_tx(self, tx):
return self.release_outputs([txi.txo_ref.txo for txi in tx.inputs]) return self.release_outputs([txi.txo_ref.txo for txi in tx.inputs])
def constraint_account_or_all(self, constraints):
account = constraints.pop('account', None)
if account:
constraints['accounts'] = [account]
else:
constraints['accounts'] = self.accounts
def get_utxos(self, **constraints): def get_utxos(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_utxos(**constraints) return self.db.get_utxos(**constraints)
def get_utxo_count(self, **constraints): def get_utxo_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_utxo_count(**constraints) return self.db.get_utxo_count(**constraints)
def get_transactions(self, **constraints): def get_transactions(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_transactions(**constraints) return self.db.get_transactions(**constraints)
def get_transaction_count(self, **constraints): def get_transaction_count(self, **constraints):
self.constraint_account_or_all(constraints)
return self.db.get_transaction_count(**constraints) return self.db.get_transaction_count(**constraints)
async def get_local_status_and_history(self, address): async def get_local_status_and_history(self, address):