From c29b4c476d814358fa284527f8d41af4ed7fa43d Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Tue, 25 Sep 2018 18:02:50 -0400 Subject: [PATCH] + tx.position, + tx.net_account_balance, + txo.is_my_account --- tests/unit/test_transaction.py | 46 ++++++++++++++++++++- torba/basedatabase.py | 74 ++++++++++++++++++++++------------ torba/baseledger.py | 21 +++++----- torba/basetransaction.py | 61 +++++++++++++++++++++------- 4 files changed, 151 insertions(+), 51 deletions(-) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 2ffb5b022..d007bc12c 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -54,6 +54,50 @@ class TestSizeAndFeeEstimation(unittest.TestCase): self.assertEqual(tx.get_base_fee(self.ledger), FEE_PER_BYTE * tx.base_size) +class TestAccountBalanceImpactFromTransaction(unittest.TestCase): + + def test_is_my_account_not_set(self): + tx = get_transaction() + with self.assertRaisesRegex(ValueError, "Cannot access net_account_balance"): + _ = tx.net_account_balance + tx.inputs[0].is_my_account = True + with self.assertRaisesRegex(ValueError, "Cannot access net_account_balance"): + _ = tx.net_account_balance + tx.outputs[0].is_my_account = True + # all inputs/outputs are set now so it should work + _ = tx.net_account_balance + + def test_paying_from_my_account_to_other_account(self): + tx = ledger_class.transaction_class() \ + .add_inputs([get_input(300*CENT)]) \ + .add_outputs([get_output(190*CENT, NULL_HASH), + get_output(100*CENT, NULL_HASH)]) + tx.inputs[0].is_my_account = True + tx.outputs[0].is_my_account = False + tx.outputs[1].is_my_account = True + self.assertEqual(tx.net_account_balance, -200*CENT) + + def test_paying_from_other_account_to_my_account(self): + tx = ledger_class.transaction_class() \ + .add_inputs([get_input(300*CENT)]) \ + .add_outputs([get_output(190*CENT, NULL_HASH), + get_output(100*CENT, NULL_HASH)]) + tx.inputs[0].is_my_account = False + tx.outputs[0].is_my_account = True + tx.outputs[1].is_my_account = False + self.assertEqual(tx.net_account_balance, 190*CENT) + + def test_paying_from_my_account_to_my_account(self): + tx = ledger_class.transaction_class() \ + .add_inputs([get_input(300*CENT)]) \ + .add_outputs([get_output(190*CENT, NULL_HASH), + get_output(100*CENT, NULL_HASH)]) + tx.inputs[0].is_my_account = True + tx.outputs[0].is_my_account = True + tx.outputs[1].is_my_account = True + self.assertEqual(tx.net_account_balance, -10*CENT) # lost to fee + + class TestTransactionSerialization(unittest.TestCase): def test_genesis_transaction(self): @@ -217,7 +261,7 @@ class TransactionIOBalancing(unittest.TestCase): save_tx = 'insert' for utxo in utxos: yield self.ledger.db.save_transaction_io( - save_tx, self.funding_tx, 1, True, + save_tx, self.funding_tx, True, self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']), utxo.script.values['pubkey_hash'], '' ) diff --git a/torba/basedatabase.py b/torba/basedatabase.py index 1730624e8..35cb4b087 100644 --- a/torba/basedatabase.py +++ b/torba/basedatabase.py @@ -143,6 +143,7 @@ class BaseDatabase(SQLiteMixin): txid text primary key, raw blob not null, height integer not null, + position integer not null, is_verified boolean not null default 0 ); """ @@ -185,19 +186,20 @@ class BaseDatabase(SQLiteMixin): 'script': sqlite3.Binary(txo.script.source) } - def save_transaction_io(self, save_tx, tx, height, is_verified, address, txhash, history): + def save_transaction_io(self, save_tx, tx, is_verified, address, txhash, history): def _steps(t): if save_tx == 'insert': self.execute(t, *self._insert_sql('tx', { 'txid': tx.id, 'raw': sqlite3.Binary(tx.raw), - 'height': height, + 'height': tx.height, + 'position': tx.position, 'is_verified': is_verified })) elif save_tx == 'update': self.execute(t, *self._update_sql("tx", { - 'height': height, 'is_verified': is_verified + 'height': tx.height, 'position': tx.position, 'is_verified': is_verified }, 'txid = ?', (tx.id,))) existing_txos = [r[0] for r in self.execute( @@ -260,19 +262,19 @@ class BaseDatabase(SQLiteMixin): @defer.inlineCallbacks def get_transaction(self, txid): result = yield self.run_query( - "SELECT raw, height, is_verified FROM tx WHERE txid = ?", (txid,) + "SELECT raw, height, position, is_verified FROM tx WHERE txid = ?", (txid,) ) if result: return result[0] else: - return None, None, False + return None, None, None, False @defer.inlineCallbacks def get_transactions(self, account, offset=0, limit=100): - offset, limit = min(offset, 0), max(limit, 100) + account_id = account.public_key.address tx_rows = yield self.run_query( """ - SELECT txid, raw, height FROM tx WHERE txid IN ( + SELECT txid, raw, height, position FROM tx WHERE txid IN ( SELECT txo.txid FROM txo JOIN pubkey_address USING (address) WHERE pubkey_address.account = :account @@ -281,47 +283,67 @@ class BaseDatabase(SQLiteMixin): JOIN txo USING (txoid) JOIN pubkey_address USING (address) WHERE pubkey_address.account = :account - ) ORDER BY height DESC LIMIT :offset, :limit - """, {'account': account.public_key.address, 'offset': offset, 'limit': limit} + ) ORDER BY height DESC, position DESC LIMIT :offset, :limit + """, { + 'account': account_id, + 'offset': min(offset, 0), + 'limit': max(limit, 100) + } ) txids, txs = [], [] for row in tx_rows: txids.append(row[0]) - txs.append(account.ledger.transaction_class(raw=row[1], height=row[2])) + txs.append(account.ledger.transaction_class( + raw=row[1], height=row[2], position=row[3] + )) txo_rows = yield self.run_query( """ - SELECT txoid, pubkey_address.chain + SELECT txoid, chain, account FROM txo JOIN pubkey_address USING (address) WHERE txid IN ({}) """.format(', '.join(['?']*len(txids))), txids ) - txos = dict(txo_rows) + txos = {} + for row in txo_rows: + txos[row[0]] = { + 'is_change': row[1] == 1, + 'is_my_account': row[2] == account_id + } - txi_rows = yield self.run_query( + referenced_txo_rows = yield self.run_query( """ - SELECT txoid, txo.amount, txo.script, txo.txid, txo.position - FROM txi JOIN txo USING (txoid) + SELECT txoid, txo.amount, txo.script, txo.txid, txo.position, chain, account + FROM txi + JOIN txo USING (txoid) + JOIN pubkey_address USING (address) WHERE txi.txid IN ({}) """.format(', '.join(['?']*len(txids))), txids ) - txis = {} + referenced_txos = {} output_class = account.ledger.transaction_class.output_class - for row in txi_rows: - txis[row[0]] = output_class( - row[1], - output_class.script_class(row[2]), - TXRefImmutable.from_id(row[3]), - position=row[4] + for row in referenced_txo_rows: + referenced_txos[row[0]] = output_class( + amount=row[1], + script=output_class.script_class(row[2]), + tx_ref=TXRefImmutable.from_id(row[3]), + position=row[4], + is_change=row[5] == 1, + is_my_account=row[6] == account_id ) for tx in txs: for txi in tx.inputs: - if txi.txo_ref.id in txis: - txi.txo_ref = TXORefResolvable(txis[txi.txo_ref.id]) + if txi.txo_ref.id in referenced_txos: + txi.txo_ref = TXORefResolvable(referenced_txos[txi.txo_ref.id]) for txo in tx.outputs: - if txo.id in txos: - txo.is_change = txos[txo.id] == 1 + txo_meta = txos.get(txo.id) + if txo_meta is not None: + txo.is_change = txo_meta['is_change'] + txo.is_my_account = txo_meta['is_my_account'] + else: + txo.is_change = False + txo.is_my_account = False return txs diff --git a/torba/baseledger.py b/torba/baseledger.py index 9cb1c7337..1e836a1d8 100644 --- a/torba/baseledger.py +++ b/torba/baseledger.py @@ -40,7 +40,7 @@ class LedgerRegistry(type): return mcs.ledgers[ledger_id] -class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx', 'height', 'is_verified'))): +class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx', 'is_verified'))): pass @@ -87,7 +87,7 @@ class BaseLedger(metaclass=LedgerRegistry): self.on_transaction.listen( lambda e: log.info( '(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s', - self.get_id(), e.address, e.height, e.is_verified, e.tx.id + self.get_id(), e.address, e.tx.height, e.is_verified, e.tx.id ) ) @@ -207,12 +207,13 @@ class BaseLedger(metaclass=LedgerRegistry): return hexlify(working_branch[::-1]) @defer.inlineCallbacks - def is_valid_transaction(self, tx, height): + def validate_transaction_and_set_position(self, tx, height): if not height <= len(self.headers): return False merkle = yield self.network.get_merkle(tx.id, height) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) header = self.headers[height] + tx.position = merkle['pos'] return merkle_root == header['merkle_root'] @defer.inlineCallbacks @@ -365,23 +366,23 @@ class BaseLedger(metaclass=LedgerRegistry): try: # see if we have a local copy of transaction, otherwise fetch it from server - raw, _, is_verified = yield self.db.get_transaction(hex_id) + raw, _, position, is_verified = yield self.db.get_transaction(hex_id) save_tx = None if raw is None: _raw = yield self.network.get_transaction(hex_id) - tx = self.transaction_class(unhexlify(_raw)) + tx = self.transaction_class(unhexlify(_raw), height=remote_height) save_tx = 'insert' else: - tx = self.transaction_class(raw) + tx = self.transaction_class(raw, height=remote_height) - if remote_height > 0 and not is_verified: - is_verified = yield self.is_valid_transaction(tx, remote_height) + if remote_height > 0 and (not is_verified or position is None): + is_verified = yield self.validate_transaction_and_set_position(tx, remote_height) is_verified = 1 if is_verified else 0 if save_tx is None: save_tx = 'update' yield self.db.save_transaction_io( - save_tx, tx, remote_height, is_verified, address, self.address_to_hash160(address), + save_tx, tx, is_verified, address, self.address_to_hash160(address), ''.join('{}:{}:'.format(tx_id, tx_height) for tx_id, tx_height in synced_history) ) @@ -390,7 +391,7 @@ class BaseLedger(metaclass=LedgerRegistry): self.get_id(), hex_id, address, remote_height, is_verified ) - self._on_transaction_controller.add(TransactionEvent(address, tx, remote_height, is_verified)) + self._on_transaction_controller.add(TransactionEvent(address, tx, is_verified)) except Exception: log.exception('Failed to synchronize transaction:') diff --git a/torba/basetransaction.py b/torba/basetransaction.py index 304bb517c..20935ba3c 100644 --- a/torba/basetransaction.py +++ b/torba/basetransaction.py @@ -137,6 +137,13 @@ class BaseInput(InputOutput): raise ValueError('Cannot resolve output to get amount.') return self.txo_ref.txo.amount + @property + def is_my_account(self) -> int: + """ True if the output this input spends is yours. """ + if self.txo_ref.txo is None: + raise ValueError('Cannot resolve output to determine ownership.') + return self.txo_ref.txo.is_my_account + @classmethod def deserialize_from(cls, stream): tx_ref = TXRefImmutable.from_hash(stream.read(32)) @@ -181,14 +188,17 @@ class BaseOutput(InputOutput): script_class = BaseOutputScript estimator_class = BaseOutputEffectiveAmountEstimator - __slots__ = 'amount', 'script', 'is_change' + __slots__ = 'amount', 'script', 'is_change', 'is_my_account' def __init__(self, amount: int, script: BaseOutputScript, - tx_ref: TXRef = None, position: int = None) -> None: + tx_ref: TXRef = None, position: int = None, + is_change: Optional[bool] = None, is_my_account: Optional[bool] = None + ) -> None: super().__init__(tx_ref, position) self.amount = amount self.script = script - self.is_change = None + self.is_change = is_change + self.is_my_account = is_my_account @property def ref(self): @@ -227,14 +237,16 @@ class BaseTransaction: input_class = BaseInput output_class = BaseOutput - def __init__(self, raw=None, version=1, locktime=0, height=None) -> None: + def __init__(self, raw=None, version: int=1, locktime: int=0, + height: int=-1, position: int=-1) -> None: self._raw = raw self.ref = TXRefMutable(self) - self.version = version # type: int - self.locktime = locktime # type: int - self._inputs = [] # type: List[BaseInput] - self._outputs = [] # type: List[BaseOutput] + self.version = version + self.locktime = locktime + self._inputs: List[BaseInput] = [] + self._outputs: List[BaseOutput] = [] self.height = height + self.position = position if raw is not None: self._deserialize() @@ -257,11 +269,11 @@ class BaseTransaction: self.ref.reset() @property - def inputs(self): # type: () -> ReadOnlyList[BaseInput] + def inputs(self) -> ReadOnlyList[BaseInput]: return ReadOnlyList(self._inputs) @property - def outputs(self): # type: () -> ReadOnlyList[BaseOutput] + def outputs(self) -> ReadOnlyList[BaseOutput]: return ReadOnlyList(self._outputs) def _add(self, new_ios: Iterable[InputOutput], existing_ios: List) -> 'BaseTransaction': @@ -301,18 +313,39 @@ class BaseTransaction: return sum(o.amount for o in self.outputs) @property - def fee(self): + def net_account_balance(self) -> int: + balance = 0 + for txi in self.inputs: + if txi.is_my_account is None: + raise ValueError( + "Cannot access net_account_balance if inputs/outputs do not " + "have is_my_account set properly." + ) + elif txi.is_my_account: + balance -= txi.amount + for txo in self.outputs: + if txo.is_my_account is None: + raise ValueError( + "Cannot access net_account_balance if inputs/outputs do not " + "have is_my_account set properly." + ) + elif txo.is_my_account: + balance += txo.amount + return balance + + @property + def fee(self) -> int: return self.input_sum - self.output_sum - def get_base_fee(self, ledger): + def get_base_fee(self, ledger) -> int: """ Fee for base tx excluding inputs and outputs. """ return self.base_size * ledger.fee_per_byte - def get_effective_input_sum(self, ledger): + def get_effective_input_sum(self, ledger) -> int: """ Sum of input values *minus* the cost involved to spend them. """ return sum(txi.amount - txi.get_fee(ledger) for txi in self._inputs) - def get_total_output_sum(self, ledger): + def get_total_output_sum(self, ledger) -> int: """ Sum of output values *plus* the cost involved to spend them. """ return sum(txo.amount + txo.get_fee(ledger) for txo in self._outputs)