From 8ed0791b265d037dc4e6d9fdc24959bcc7a80cb1 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Fri, 21 Sep 2018 22:18:30 -0400 Subject: [PATCH] improved db.get_transactions() --- tests/integration/test_transactions.py | 11 +++- torba/baseaccount.py | 3 - torba/basedatabase.py | 79 ++++++++++++++++---------- torba/basetransaction.py | 18 ++---- 4 files changed, 65 insertions(+), 46 deletions(-) diff --git a/tests/integration/test_transactions.py b/tests/integration/test_transactions.py index b9173b1c6..3cc2c8073 100644 --- a/tests/integration/test_transactions.py +++ b/tests/integration/test_transactions.py @@ -9,7 +9,7 @@ class BasicTransactionTests(IntegrationTestCase): async def test_sending_and_receiving(self): account1, account2 = self.account, self.wallet.generate_account(self.ledger) - yield self.ledger.update_account(account2) + await d2f(self.ledger.update_account(account2)) self.assertEqual(await self.get_balance(account1), 0) self.assertEqual(await self.get_balance(account2), 0) @@ -53,3 +53,12 @@ class BasicTransactionTests(IntegrationTestCase): await self.on_transaction(tx) # mempool await self.blockchain.generate(1) await self.on_transaction(tx) # confirmed + + txs = await d2f(account1.get_transactions()) + tx = txs[1] + self.assertEqual(round(tx.inputs[0].txo_ref.txo.amount/COIN, 1), 1.1) + self.assertEqual(round(tx.inputs[1].txo_ref.txo.amount/COIN, 1), 1.1) + self.assertEqual(round(tx.outputs[0].amount/COIN, 1), 2.0) + self.assertEqual(tx.outputs[0].get_address(self.ledger), address2) + self.assertEqual(tx.outputs[0].is_change, False) + self.assertEqual(tx.outputs[1].is_change, True) diff --git a/torba/baseaccount.py b/torba/baseaccount.py index f22b5ad24..9a90d8244 100644 --- a/torba/baseaccount.py +++ b/torba/baseaccount.py @@ -342,9 +342,6 @@ class BaseAccount: def get_unspent_outputs(self, **constraints): return self.ledger.db.get_utxos_for_account(self, **constraints) - def get_inputs_outputs(self, **constraints): - return self.ledger.db.get_txios_for_account(self, **constraints) - def get_transactions(self): return self.ledger.db.get_transactions(self) diff --git a/torba/basedatabase.py b/torba/basedatabase.py index d7b7c62ea..2912d49b9 100644 --- a/torba/basedatabase.py +++ b/torba/basedatabase.py @@ -6,6 +6,7 @@ from twisted.internet import defer from twisted.enterprise import adbapi from torba.hash import TXRefImmutable +from torba.basetransaction import TXORefResolvable log = logging.getLogger(__name__) @@ -267,24 +268,62 @@ class BaseDatabase(SQLiteMixin): return None, None, False @defer.inlineCallbacks - def get_transactions(self, account): - txs = self.run_query( + def get_transactions(self, account, offset=0, limit=100): + offset, limit = min(offset, 0), max(limit, 100) + tx_records = yield self.run_query( """ - SELECT raw FROM tx where txid in ( - SELECT txo.txid - FROM txo + SELECT txid, raw, height FROM tx WHERE txid IN ( + SELECT txo.txid FROM txo JOIN pubkey_address USING (address) WHERE pubkey_address.account = :account UNION - SELECT txo.txid - FROM txi + SELECT txo.txid FROM txi JOIN txo USING (txoid) JOIN pubkey_address USING (address) WHERE pubkey_address.account = :account - ) - """, {'account': account.public_key.address} + ) ORDER BY height DESC LIMIT :offset, :limit + """, {'account': account.public_key.address, 'offset': offset, 'limit': limit} ) - return [account.ledger.transaction_class(values[0]) for values in txs] + txids, txs = [], [] + for r in tx_records: + txids.append(r[0]) + txs.append(account.ledger.transaction_class(raw=r[1], height=r[2])) + + txo_records = yield self.run_query( + """ + SELECT txoid, pubkey_address.chain + FROM txo JOIN pubkey_address USING (address) + WHERE txid IN ({}) + """.format(', '.join(['?']*len(txids))), txids + ) + txos = dict(txo_records) + + txi_records = yield self.run_query( + """ + SELECT txoid, txo.amount, txo.script, txo.txid, txo.position + FROM txi JOIN txo USING (txoid) + WHERE txi.txid IN ({}) + """.format(', '.join(['?']*len(txids))), txids + ) + txis = {} + output_class = account.ledger.transaction_class.output_class + for r in txi_records: + txis[r[0]] = output_class( + r[1], + output_class.script_class(r[2]), + TXRefImmutable.from_id(r[3]), + position=r[4] + ) + + for tx in txs: + for txi in tx.inputs: + if txi.txo_ref.id in txis: + txi.txo_ref = TXORefResolvable(txis[txi.txo_ref.id]) + for txo in tx.outputs: + if txo.id in txos: + txo.is_change = txos[txo.id] == 1 + + return txs def get_balance_for_account(self, account, include_reserved=False, **constraints): if not include_reserved: @@ -323,26 +362,6 @@ class BaseDatabase(SQLiteMixin): ) for values in utxos ] - @defer.inlineCallbacks - def get_txios_for_account(self, account, **constraints): - constraints['account'] = account.public_key.address - utxos = yield self.run_query( - """ - SELECT amount, script, txid, txo.position - FROM txo JOIN pubkey_address ON pubkey_address.address=txo.address - WHERE account=:account AND txo.is_reserved=0 AND txoid NOT IN (SELECT txoid FROM txi) - """+constraints_to_sql(constraints), constraints - ) - output_class = account.ledger.transaction_class.output_class - return [ - output_class( - values[0], - output_class.script_class(values[1]), - TXRefImmutable.from_id(values[2]), - position=values[3] - ) for values in utxos - ] - def add_keys(self, account, chain, keys): sql = ( "insert into pubkey_address " diff --git a/torba/basetransaction.py b/torba/basetransaction.py index 7302cb3d9..304bb517c 100644 --- a/torba/basetransaction.py +++ b/torba/basetransaction.py @@ -181,13 +181,14 @@ class BaseOutput(InputOutput): script_class = BaseOutputScript estimator_class = BaseOutputEffectiveAmountEstimator - __slots__ = 'amount', 'script' + __slots__ = 'amount', 'script', 'is_change' def __init__(self, amount: int, script: BaseOutputScript, tx_ref: TXRef = None, position: int = None) -> None: super().__init__(tx_ref, position) self.amount = amount self.script = script + self.is_change = None @property def ref(self): @@ -226,13 +227,14 @@ class BaseTransaction: input_class = BaseInput output_class = BaseOutput - def __init__(self, raw=None, version=1, locktime=0) -> None: + def __init__(self, raw=None, version=1, locktime=0, height=None) -> None: self._raw = raw self.ref = TXRefMutable(self) self.version = version # type: int self.locktime = locktime # type: int self._inputs = [] # type: List[BaseInput] self._outputs = [] # type: List[BaseOutput] + self.height = height if raw is not None: self._deserialize() @@ -416,6 +418,8 @@ class BaseTransaction: change_address = yield change_account.change.get_or_create_usable_address() change_hash160 = change_account.ledger.address_to_hash160(change_address) change_amount = change - cost_of_change + change_output = cls.output_class.pay_pubkey_hash(change_amount, change_hash160) + change_output.is_change = True tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)]) yield tx.sign(funding_accounts) @@ -449,13 +453,3 @@ class BaseTransaction: else: raise NotImplementedError("Don't know how to spend this output.") self._reset() - - @defer.inlineCallbacks - def get_my_addresses(self, ledger): - addresses = set() - for txo in self.outputs: - address = ledger.hash160_to_address(txo.script.values['pubkey_hash']) - record = yield ledger.db.get_address(address) - if record is not None: - addresses.add(address) - defer.returnValue(list(addresses))