diff --git a/tests/unit/test_bip32.py b/tests/unit/test_bip32.py index 59e7756f2..12d69089e 100644 --- a/tests/unit/test_bip32.py +++ b/tests/unit/test_bip32.py @@ -5,6 +5,7 @@ from .key_fixtures import expected_ids, expected_privkeys, expected_hardened_pri from torba.bip32 import PubKey, PrivateKey, from_extended_key_string from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class + class BIP32Tests(unittest.TestCase): def test_pubkey_validation(self): @@ -81,7 +82,6 @@ class BIP32Tests(unittest.TestCase): self.assertIsInstance(new_privkey, PrivateKey) self.assertEqual(hexlify(new_privkey.private_key_bytes), expected_hardened_privkeys[i - 1 - PrivateKey.HARDENED]) - def test_from_extended_keys(self): ledger = ledger_class({ 'db': ledger_class.database_class(':memory:'), diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 7831f7916..f22be4769 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -1,8 +1,15 @@ -from unittest import TestCase +from twisted.trial import unittest +from twisted.internet import defer + +from torba.wallet import Wallet +from torba.constants import COIN +from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class from torba.basedatabase import constraints_to_sql +from .test_transaction import get_output, NULL_HASH -class TestConstraintBuilder(TestCase): + +class TestConstraintBuilder(unittest.TestCase): def test_any(self): constraints = { @@ -21,3 +28,125 @@ class TestConstraintBuilder(TestCase): 'ages__any_age__lt': 38 } ) + + def test_in_list(self): + constraints = {'ages__in': [18, 38]} + self.assertEqual( + constraints_to_sql(constraints, prepend_sql=''), + 'ages IN (:ages_1, :ages_2)' + ) + self.assertEqual( + constraints, { + 'ages_1': 18, + 'ages_2': 38 + } + ) + + def test_in_query(self): + constraints = {'ages__in': 'SELECT age from ages_table'} + self.assertEqual( + constraints_to_sql(constraints, prepend_sql=''), + 'ages IN (SELECT age from ages_table)' + ) + self.assertEqual(constraints, {}) + + def test_not_in_query(self): + constraints = {'ages__not_in': 'SELECT age from ages_table'} + self.assertEqual( + constraints_to_sql(constraints, prepend_sql=''), + 'ages NOT IN (SELECT age from ages_table)' + ) + self.assertEqual(constraints, {}) + + +class TestQueries(unittest.TestCase): + + def setUp(self): + self.ledger = ledger_class({ + 'db': ledger_class.database_class(':memory:'), + 'headers': ledger_class.headers_class(':memory:'), + }) + return self.ledger.db.open() + + @defer.inlineCallbacks + def create_account(self): + account = self.ledger.account_class.generate(self.ledger, Wallet()) + yield account.ensure_address_gap() + return account + + @defer.inlineCallbacks + def create_tx_from_nothing(self, my_account, height): + to_address = yield my_account.receiving.get_or_create_usable_address() + to_hash = ledger_class.address_to_hash160(to_address) + tx = ledger_class.transaction_class(height=height, is_verified=True) \ + .add_inputs([self.txi(self.txo(1, NULL_HASH))]) \ + .add_outputs([self.txo(1, to_hash)]) + yield self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '') + return tx + + @defer.inlineCallbacks + def create_tx_from_txo(self, txo, to_account, height): + from_hash = txo.script.values['pubkey_hash'] + from_address = self.ledger.hash160_to_address(from_hash) + to_address = yield to_account.receiving.get_or_create_usable_address() + to_hash = ledger_class.address_to_hash160(to_address) + tx = ledger_class.transaction_class(height=height, is_verified=True) \ + .add_inputs([self.txi(txo)]) \ + .add_outputs([self.txo(1, to_hash)]) + yield self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '') + yield self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '') + return tx + + @defer.inlineCallbacks + def create_tx_to_nowhere(self, txo, height): + from_hash = txo.script.values['pubkey_hash'] + from_address = self.ledger.hash160_to_address(from_hash) + to_hash = NULL_HASH + tx = ledger_class.transaction_class(height=height, is_verified=True) \ + .add_inputs([self.txi(txo)]) \ + .add_outputs([self.txo(1, to_hash)]) + yield self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '') + return tx + + def txo(self, amount, address): + return get_output(int(amount*COIN), address) + + def txi(self, txo): + return ledger_class.transaction_class.input_class.spend(txo) + + @defer.inlineCallbacks + def test_get_transactions(self): + account1 = yield self.create_account() + account2 = yield self.create_account() + tx1 = yield self.create_tx_from_nothing(account1, 1) + tx2 = yield self.create_tx_from_txo(tx1.outputs[0], account2, 2) + tx3 = yield self.create_tx_to_nowhere(tx2.outputs[0], 3) + + txs = yield self.ledger.db.get_transactions() + self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs]) + self.assertEqual([3, 2, 1], [tx.height for tx in txs]) + + txs = yield self.ledger.db.get_transactions(account1) + self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs]) + self.assertEqual(txs[0].inputs[0].is_my_account, True) + self.assertEqual(txs[0].outputs[0].is_my_account, False) + self.assertEqual(txs[1].inputs[0].is_my_account, False) + self.assertEqual(txs[1].outputs[0].is_my_account, True) + + txs = yield self.ledger.db.get_transactions(account2) + self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs]) + self.assertEqual(txs[0].inputs[0].is_my_account, True) + self.assertEqual(txs[0].outputs[0].is_my_account, False) + self.assertEqual(txs[1].inputs[0].is_my_account, False) + self.assertEqual(txs[1].outputs[0].is_my_account, True) + + tx = yield self.ledger.db.get_transaction(tx2.id) + self.assertEqual(tx.id, tx2.id) + self.assertEqual(tx.inputs[0].is_my_account, False) + self.assertEqual(tx.outputs[0].is_my_account, False) + tx = yield self.ledger.db.get_transaction(tx2.id, account1) + self.assertEqual(tx.inputs[0].is_my_account, True) + self.assertEqual(tx.outputs[0].is_my_account, False) + tx = yield self.ledger.db.get_transaction(tx2.id, account2) + self.assertEqual(tx.inputs[0].is_my_account, False) + self.assertEqual(tx.outputs[0].is_my_account, True) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 640f82346..dca00f7e7 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -254,20 +254,20 @@ class TransactionIOBalancing(unittest.TestCase): def create_utxos(self, amounts): utxos = [self.txo(amount) for amount in amounts] - self.funding_tx = ledger_class.transaction_class() \ + self.funding_tx = ledger_class.transaction_class(is_verified=True) \ .add_inputs([self.txi(self.txo(sum(amounts)+0.1))]) \ .add_outputs(utxos) save_tx = 'insert' for utxo in utxos: yield self.ledger.db.save_transaction_io( - save_tx, self.funding_tx, True, + save_tx, self.funding_tx, self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']), utxo.script.values['pubkey_hash'], '' ) save_tx = 'update' - defer.returnValue(utxos) + return utxos @staticmethod def inputs(tx): diff --git a/torba/baseaccount.py b/torba/baseaccount.py index a53712465..108147963 100644 --- a/torba/baseaccount.py +++ b/torba/baseaccount.py @@ -357,10 +357,10 @@ class BaseAccount: } def get_unspent_outputs(self, **constraints): - return self.ledger.db.get_utxos_for_account(self, **constraints) + return self.ledger.db.get_utxos(account=self, **constraints) def get_transactions(self) -> List['basetransaction.BaseTransaction']: - return self.ledger.db.get_transactions(self) + return self.ledger.db.get_transactions(account=self) @defer.inlineCallbacks def fund(self, to_account, amount=None, everything=False, diff --git a/torba/basedatabase.py b/torba/basedatabase.py index 898c7e62f..8dd65d8f6 100644 --- a/torba/basedatabase.py +++ b/torba/basedatabase.py @@ -27,6 +27,20 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_sql=' AND ', prepend col, op = key[:-len('__gt')], '>' elif key.endswith('__like'): col, op = key[:-len('__like')], 'LIKE' + elif key.endswith('__in') or key.endswith('__not_in'): + if key.endswith('__in'): + col, op = key[:-len('__in')], 'IN' + else: + col, op = key[:-len('__not_in')], 'NOT IN' + items = constraints.pop(key) + if isinstance(items, list): + placeholders = [] + for item_no, item in enumerate(items, 1): + constraints['{}_{}'.format(col, item_no)] = item + placeholders.append(':{}_{}'.format(col, item_no)) + items = ', '.join(placeholders) + extras.append('{} {} ({})'.format(col, op, items)) + continue elif key.endswith('__any'): subconstraints = constraints.pop(key) extras.append('({})'.format( @@ -46,6 +60,7 @@ class SQLiteMixin: def __init__(self, path): self._db_path = path self.db: adbapi.ConnectionPool = None + self.ledger = None def open(self): log.info("connecting to database: %s", self._db_path) @@ -186,7 +201,7 @@ class BaseDatabase(SQLiteMixin): 'script': sqlite3.Binary(txo.script.source) } - def save_transaction_io(self, save_tx, tx: BaseTransaction, is_verified, address, txhash, history): + def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history): def _steps(t): if save_tx == 'insert': @@ -195,11 +210,11 @@ class BaseDatabase(SQLiteMixin): 'raw': sqlite3.Binary(tx.raw), 'height': tx.height, 'position': tx.position, - 'is_verified': is_verified + 'is_verified': tx.is_verified })) elif save_tx == 'update': self.execute(t, *self._update_sql("tx", { - 'height': tx.height, 'position': tx.position, 'is_verified': is_verified + 'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified }, 'txid = ?', (tx.id,))) existing_txos = [r[0] for r in self.execute( @@ -260,32 +275,40 @@ class BaseDatabase(SQLiteMixin): return defer.succeed(True) @defer.inlineCallbacks - def get_transaction(self, txid): - result = yield self.run_query( - "SELECT raw, height, position, is_verified FROM tx WHERE txid = ?", (txid,) - ) - if result: - return result[0] - else: - return None, None, None, False + def get_transaction(self, txid, account=None): + txs = yield self.get_transactions(account=account, txid=txid) + if len(txs) == 1: + return txs[0] @defer.inlineCallbacks - def get_transactions(self, account, offset=0, limit=100): - account_id = account.public_key.address + def get_transactions(self, account=None, txid=None, offset=0, limit=1000): + + tx_where = "" + account_id = account.public_key.address if account is not None else None + + if txid is not None: + tx_where = """ + WHERE txid = :txid + """ + elif account is not None: + tx_where = """ + WHERE txid IN ( + SELECT txo.txid FROM txo + JOIN pubkey_address USING (address) WHERE pubkey_address.account = :account + UNION + SELECT txi.txid FROM txi + JOIN txo USING (txoid) + JOIN pubkey_address USING (address) WHERE pubkey_address.account = :account + ) + """ + tx_rows = yield self.run_query( """ - SELECT txid, raw, height, position FROM tx WHERE txid IN ( - SELECT txo.txid FROM txo - JOIN pubkey_address USING (address) - WHERE pubkey_address.account = :account - UNION - SELECT txo.txid FROM txi - JOIN txo USING (txoid) - JOIN pubkey_address USING (address) - WHERE pubkey_address.account = :account - ) ORDER BY height DESC, position DESC LIMIT :offset, :limit - """, { + SELECT txid, raw, height, position, is_verified FROM tx {} + ORDER BY height DESC, position DESC LIMIT :offset, :limit + """.format(tx_where), { 'account': account_id, + 'txid': txid, 'offset': min(offset, 0), 'limit': max(limit, 100) } @@ -293,8 +316,8 @@ class BaseDatabase(SQLiteMixin): txids, txs = [], [] for row in tx_rows: txids.append(row[0]) - txs.append(account.ledger.transaction_class( - raw=row[1], height=row[2], position=row[3] + txs.append(self.ledger.transaction_class( + raw=row[1], height=row[2], position=row[3], is_verified=row[4] )) txo_rows = yield self.run_query( @@ -311,31 +334,18 @@ class BaseDatabase(SQLiteMixin): 'is_my_account': row[2] == account_id } - referenced_txo_rows = yield self.run_query( - """ - SELECT txoid, txo.amount, txo.script, txo.txid, txo.position, chain, account - FROM txi - JOIN txo USING (txoid) - JOIN pubkey_address USING (address) - WHERE txi.txid IN ({}) - """.format(', '.join(['?']*len(txids))), txids - ) - referenced_txos = {} - output_class = account.ledger.transaction_class.output_class - for row in referenced_txo_rows: - referenced_txos[row[0]] = output_class( - amount=row[1], - script=output_class.script_class(row[2]), - tx_ref=TXRefImmutable.from_id(row[3]), - position=row[4], - is_change=row[5] == 1, - is_my_account=row[6] == account_id + referenced_txos = yield self.get_txos( + account=account, + txoid__in="SELECT txoid FROM txi WHERE txi.txid IN ({})".format( + ','.join("'{}'".format(txid) for txid in txids) ) + ) + referenced_txos_map = {txo.id: txo for txo in referenced_txos} for tx in txs: for txi in tx.inputs: - if txi.txo_ref.id in referenced_txos: - txi.txo_ref = TXORefResolvable(referenced_txos[txi.txo_ref.id]) + if txi.txo_ref.id in referenced_txos_map: + txi.txo_ref = TXORefResolvable(referenced_txos_map[txi.txo_ref.id]) for txo in tx.outputs: txo_meta = txos.get(txo.id) if txo_meta is not None: @@ -347,6 +357,35 @@ class BaseDatabase(SQLiteMixin): return txs + @defer.inlineCallbacks + def get_txos(self, account=None, **constraints): + account_id = None + if account is not None: + account_id = account.public_key.address + constraints['account'] = account_id + rows = yield self.run_query( + """ + SELECT amount, script, txid, txo.position, chain, account + FROM txo JOIN pubkey_address USING (address) + """+constraints_to_sql(constraints, prepend_sql='WHERE '), constraints + ) + output_class = self.ledger.transaction_class.output_class + return [ + output_class( + amount=row[0], + script=output_class.script_class(row[1]), + tx_ref=TXRefImmutable.from_id(row[2]), + position=row[3], + is_change=row[4] == 1, + is_my_account=row[5] == account_id + ) for row in rows + ] + + def get_utxos(self, **constraints): + constraints['txoid__not_in'] = 'SELECT txoid FROM txi' + constraints['is_reserved'] = 0 + return self.get_txos(**constraints) + def get_balance_for_account(self, account, include_reserved=False, **constraints): if not include_reserved: constraints['is_reserved'] = 0 @@ -364,26 +403,6 @@ class BaseDatabase(SQLiteMixin): """+constraints_to_sql(constraints), values, 0 ) - @defer.inlineCallbacks - def get_utxos_for_account(self, account, **constraints): - constraints['account'] = account.public_key.address - utxos = yield self.run_query( - """ - SELECT amount, script, txid, txo.position - FROM txo JOIN pubkey_address ON pubkey_address.address=txo.address - WHERE account=:account AND txo.is_reserved=0 AND txoid NOT IN (SELECT txoid FROM txi) - """+constraints_to_sql(constraints), constraints - ) - output_class = account.ledger.transaction_class.output_class - return [ - output_class( - values[0], - output_class.script_class(values[1]), - TXRefImmutable.from_id(values[2]), - position=values[3] - ) for values in utxos - ] - def add_keys(self, account, chain, keys): sql = ( "insert into pubkey_address " diff --git a/torba/baseledger.py b/torba/baseledger.py index 8a7ca4de8..a9b186f90 100644 --- a/torba/baseledger.py +++ b/torba/baseledger.py @@ -40,7 +40,7 @@ class LedgerRegistry(type): return mcs.ledgers[ledger_id] -class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx', 'is_verified'))): +class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx'))): pass @@ -73,6 +73,7 @@ class BaseLedger(metaclass=LedgerRegistry): self.db: BaseDatabase = self.config.get('db') or self.database_class( os.path.join(self.path, "blockchain.db") ) + self.db.ledger = self self.headers: BaseHeaders = self.config.get('headers') or self.headers_class( os.path.join(self.path, "headers") ) @@ -87,7 +88,7 @@ class BaseLedger(metaclass=LedgerRegistry): self.on_transaction.listen( lambda e: log.info( '(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s', - self.get_id(), e.address, e.tx.height, e.is_verified, e.tx.id + self.get_id(), e.address, e.tx.height, e.tx.is_verified, e.tx.id ) ) @@ -214,7 +215,7 @@ class BaseLedger(metaclass=LedgerRegistry): merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) header = self.headers[height] tx.position = merkle['pos'] - return merkle_root == header['merkle_root'] + tx.is_verified = merkle_root == header['merkle_root'] @defer.inlineCallbacks def start(self): @@ -366,32 +367,31 @@ class BaseLedger(metaclass=LedgerRegistry): try: # see if we have a local copy of transaction, otherwise fetch it from server - raw, _, position, is_verified = yield self.db.get_transaction(hex_id) + tx = yield self.db.get_transaction(hex_id) save_tx = None - if raw is None: + if tx is None: _raw = yield self.network.get_transaction(hex_id) - tx = self.transaction_class(unhexlify(_raw), height=remote_height) + tx = self.transaction_class(unhexlify(_raw)) save_tx = 'insert' - else: - tx = self.transaction_class(raw, height=remote_height) - if remote_height > 0 and (not is_verified or position is None): - is_verified = yield self.validate_transaction_and_set_position(tx, remote_height) - is_verified = 1 if is_verified else 0 + tx.height = remote_height + + if remote_height > 0 and (not tx.is_verified or tx.position == -1): + yield self.validate_transaction_and_set_position(tx, remote_height) if save_tx is None: save_tx = 'update' yield self.db.save_transaction_io( - save_tx, tx, is_verified, address, self.address_to_hash160(address), + save_tx, tx, address, self.address_to_hash160(address), ''.join('{}:{}:'.format(tx_id, tx_height) for tx_id, tx_height in synced_history) ) log.debug( "%s: sync'ed tx %s for address: %s, height: %s, verified: %s", - self.get_id(), hex_id, address, remote_height, is_verified + self.get_id(), hex_id, address, tx.height, tx.is_verified ) - self._on_transaction_controller.add(TransactionEvent(address, tx, is_verified)) + self._on_transaction_controller.add(TransactionEvent(address, tx)) except Exception: log.exception('Failed to synchronize transaction:') diff --git a/torba/basetransaction.py b/torba/basetransaction.py index d19a855b0..42986e94d 100644 --- a/torba/basetransaction.py +++ b/torba/basetransaction.py @@ -141,7 +141,7 @@ class BaseInput(InputOutput): def is_my_account(self) -> Optional[bool]: """ True if the output this input spends is yours. """ if self.txo_ref.txo is None: - raise ValueError('Cannot resolve output to determine ownership.') + return False return self.txo_ref.txo.is_my_account @classmethod @@ -237,7 +237,7 @@ class BaseTransaction: input_class = BaseInput output_class = BaseOutput - def __init__(self, raw=None, version: int = 1, locktime: int = 0, + def __init__(self, raw=None, version: int = 1, locktime: int = 0, is_verified: bool = False, height: int = -1, position: int = -1) -> None: self._raw = raw self.ref = TXRefMutable(self) @@ -245,6 +245,7 @@ class BaseTransaction: self.locktime = locktime self._inputs: List[BaseInput] = [] self._outputs: List[BaseOutput] = [] + self.is_verified = is_verified self.height = height self.position = position if raw is not None: