From 95cbdcb1313b92802fd5bbbf5c1efa9f0956181f Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Sun, 7 Oct 2018 14:53:44 -0400 Subject: [PATCH] refactored sql queries, added full pagination support: limit, offset and count --- tests/integration/test_transactions.py | 2 +- tests/unit/test_account.py | 5 +- tests/unit/test_database.py | 106 +++++--- tests/unit/test_transaction.py | 2 +- torba/baseaccount.py | 51 ++-- torba/basedatabase.py | 327 +++++++++++++------------ torba/baseledger.py | 6 +- 7 files changed, 270 insertions(+), 229 deletions(-) diff --git a/tests/integration/test_transactions.py b/tests/integration/test_transactions.py index 3cc2c8073..697397ba1 100644 --- a/tests/integration/test_transactions.py +++ b/tests/integration/test_transactions.py @@ -43,7 +43,7 @@ class BasicTransactionTests(IntegrationTestCase): self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5) self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0) - utxos = await d2f(self.account.get_unspent_outputs()) + utxos = await d2f(self.account.get_utxos()) tx = await d2f(self.ledger.transaction_class.create( [self.ledger.transaction_class.input_class.spend(utxos[0])], [], diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index f23f29c0d..912c3bbce 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -223,7 +223,10 @@ class TestSingleKeyAccount(unittest.TestCase): self.assertEqual(new_keys[0], account.public_key.address) records = yield account.receiving.get_address_records() self.assertEqual(records, [{ - 'position': 0, 'address': account.public_key.address, 'used_times': 0 + 'position': 0, 'chain': 0, + 'account': account.public_key.address, + 'address': account.public_key.address, + 'used_times': 0 }]) # case #1: no new addresses needed diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index bb3e284f7..231480610 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -4,72 +4,100 @@ from twisted.internet import defer from torba.wallet import Wallet from torba.constants import COIN from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class -from torba.basedatabase import constraints_to_sql +from torba.basedatabase import query, constraints_to_sql from .test_transaction import get_output, NULL_HASH -class TestConstraintBuilder(unittest.TestCase): +class TestQueryBuilder(unittest.TestCase): def test_dot(self): - constraints = {'txo.position': 18} self.assertEqual( - constraints_to_sql(constraints, prepend_sql=''), - 'txo.position = :txo_position' + constraints_to_sql({'txo.position': 18}), + ('txo.position = :txo_position', {'txo_position': 18}) ) - self.assertEqual(constraints, {'txo_position': 18}) def test_any(self): - constraints = { - 'ages__any': { - 'txo.age__gt': 18, - 'txo.age__lt': 38 - } - } self.assertEqual( - constraints_to_sql(constraints, prepend_sql=''), - '(txo.age > :ages__any_txo_age__gt OR txo.age < :ages__any_txo_age__lt)' - ) - self.assertEqual( - constraints, { + constraints_to_sql({ + 'ages__any': { + 'txo.age__gt': 18, + 'txo.age__lt': 38 + } + }), + ('(txo.age > :ages__any_txo_age__gt OR txo.age < :ages__any_txo_age__lt)', { 'ages__any_txo_age__gt': 18, 'ages__any_txo_age__lt': 38 - } + }) ) def test_in_list(self): - constraints = {'txo.age__in': [18, 38]} self.assertEqual( - constraints_to_sql(constraints, prepend_sql=''), - 'txo.age IN (:txo_age_1, :txo_age_2)' + constraints_to_sql({'txo.age__in': [18, 38]}), + ('txo.age IN (18, 38)', {}) ) self.assertEqual( - constraints, { - 'txo_age_1': 18, - 'txo_age_2': 38 - } + constraints_to_sql({'txo.age__in': ['abc123', 'def456']}), + ("txo.age IN ('abc123', 'def456')", {}) ) def test_in_query(self): - constraints = {'txo.age__in': 'SELECT age from ages_table'} self.assertEqual( - constraints_to_sql(constraints, prepend_sql=''), - 'txo.age IN (SELECT age from ages_table)' + constraints_to_sql({'txo.age__in': 'SELECT age from ages_table'}), + ('txo.age IN (SELECT age from ages_table)', {}) ) - self.assertEqual(constraints, {}) def test_not_in_query(self): - constraints = {'txo.age__not_in': 'SELECT age from ages_table'} self.assertEqual( - constraints_to_sql(constraints, prepend_sql=''), - 'txo.age NOT IN (SELECT age from ages_table)' + constraints_to_sql({'txo.age__not_in': 'SELECT age from ages_table'}), + ('txo.age NOT IN (SELECT age from ages_table)', {}) ) - self.assertEqual(constraints, {}) def test_in_invalid(self): - constraints = {'ages__in': 9} - with self.assertRaisesRegex(ValueError, 'list or string'): - constraints_to_sql(constraints, prepend_sql='') + with self.assertRaisesRegex(ValueError, 'list, set or string'): + constraints_to_sql({'ages__in': 9}) + + def test_query(self): + self.assertEqual( + query("select * from foo"), + ("select * from foo", {}) + ) + self.assertEqual( + query( + "select * from foo", + a='b', b__in='select * from blah where c=:$c', + d__any={'one': 1, 'two': 2}, limit=10, order_by='b', **{'$c': 3}), + ( + "select * from foo WHERE a = :a AND " + "b IN (select * from blah where c=:$c) AND " + "(one = :d__any_one OR two = :d__any_two) ORDER BY b LIMIT 10", + {'a': 'b', 'd__any_one': 1, 'd__any_two': 2, '$c': 3} + ) + ) + + def test_query_order_by(self): + self.assertEqual( + query("select * from foo", order_by='foo'), + ("select * from foo ORDER BY foo", {}) + ) + self.assertEqual( + query("select * from foo", order_by=['foo', 'bar']), + ("select * from foo ORDER BY foo, bar", {}) + ) + + def test_query_limit_offset(self): + self.assertEqual( + query("select * from foo", limit=10), + ("select * from foo LIMIT 10", {}) + ) + self.assertEqual( + query("select * from foo", offset=10), + ("select * from foo OFFSET 10", {}) + ) + self.assertEqual( + query("select * from foo", limit=20, offset=10), + ("select * from foo OFFSET 10 LIMIT 20", {}) + ) class TestQueries(unittest.TestCase): @@ -153,13 +181,13 @@ class TestQueries(unittest.TestCase): self.assertEqual(txs[1].inputs[0].is_my_account, False) self.assertEqual(txs[1].outputs[0].is_my_account, True) - tx = yield self.ledger.db.get_transaction(tx2.id) + tx = yield self.ledger.db.get_transaction(txid=tx2.id) self.assertEqual(tx.id, tx2.id) self.assertEqual(tx.inputs[0].is_my_account, False) self.assertEqual(tx.outputs[0].is_my_account, False) - tx = yield self.ledger.db.get_transaction(tx2.id, account1) + tx = yield self.ledger.db.get_transaction(txid=tx2.id, account=account1) self.assertEqual(tx.inputs[0].is_my_account, True) self.assertEqual(tx.outputs[0].is_my_account, False) - tx = yield self.ledger.db.get_transaction(tx2.id, account2) + tx = yield self.ledger.db.get_transaction(txid=tx2.id, account=account2) self.assertEqual(tx.inputs[0].is_my_account, False) self.assertEqual(tx.outputs[0].is_my_account, True) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index dca00f7e7..26c5f6d8d 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -201,7 +201,7 @@ class TestTransactionSigning(unittest.TestCase): ) yield account.ensure_address_gap() - address1, address2 = yield account.receiving.get_addresses(2) + address1, address2 = yield account.receiving.get_addresses(limit=2) pubkey_hash1 = self.ledger.address_to_hash160(address1) pubkey_hash2 = self.ledger.address_to_hash160(address2) diff --git a/torba/baseaccount.py b/torba/baseaccount.py index 108147963..b2c5fb8d7 100644 --- a/torba/baseaccount.py +++ b/torba/baseaccount.py @@ -47,9 +47,11 @@ class AddressManager: def db(self): return self.account.ledger.db - def _query_addresses(self, limit: int = None, max_used_times: int = None, order_by=None): + def _query_addresses(self, **constraints): return self.db.get_addresses( - self.account, self.chain_number, limit, max_used_times, order_by + account=self.account, + chain=self.chain_number, + **constraints ) def get_private_key(self, index: int) -> PrivateKey: @@ -61,17 +63,17 @@ class AddressManager: def ensure_address_gap(self) -> defer.Deferred: raise NotImplementedError - def get_address_records(self, limit: int = None, only_usable: bool = False) -> defer.Deferred: + def get_address_records(self, only_usable: bool = False, **constraints) -> defer.Deferred: raise NotImplementedError @defer.inlineCallbacks - def get_addresses(self, limit: int = None, only_usable: bool = False) -> defer.Deferred: - records = yield self.get_address_records(limit=limit, only_usable=only_usable) + def get_addresses(self, only_usable: bool = False, **constraints) -> defer.Deferred: + records = yield self.get_address_records(only_usable=only_usable, **constraints) return [r['address'] for r in records] @defer.inlineCallbacks def get_or_create_usable_address(self) -> defer.Deferred: - addresses = yield self.get_addresses(limit=1, only_usable=True) + addresses = yield self.get_addresses(only_usable=True, limit=1) if addresses: return addresses[0] addresses = yield self.ensure_address_gap() @@ -128,7 +130,7 @@ class HierarchicalDeterministic(AddressManager): @defer.inlineCallbacks def ensure_address_gap(self) -> defer.Deferred: - addresses = yield self._query_addresses(self.gap, None, "position DESC") + addresses = yield self._query_addresses(limit=self.gap, order_by="position DESC") existing_gap = 0 for address in addresses: @@ -145,11 +147,10 @@ class HierarchicalDeterministic(AddressManager): new_keys = yield self.generate_keys(start, end-1) return new_keys - def get_address_records(self, limit: int = None, only_usable: bool = False): - return self._query_addresses( - limit, self.maximum_uses_per_address if only_usable else None, - "used_times ASC, position ASC" - ) + def get_address_records(self, only_usable: bool = False, **constraints): + if only_usable: + constraints['used_times__lte'] = self.maximum_uses_per_address + return self._query_addresses(order_by="used_times ASC, position ASC", **constraints) class SingleKey(AddressManager): @@ -184,8 +185,8 @@ class SingleKey(AddressManager): return [self.public_key.address] return [] - def get_address_records(self, limit: int = None, only_usable: bool = False) -> defer.Deferred: - return self._query_addresses() + def get_address_records(self, only_usable: bool = False, **constraints) -> defer.Deferred: + return self._query_addresses(**constraints) class BaseAccount: @@ -329,23 +330,23 @@ class BaseAccount: return addresses @defer.inlineCallbacks - def get_addresses(self, limit: int = None, max_used_times: int = None) -> defer.Deferred: - records = yield self.get_address_records(limit, max_used_times) - return [r['address'] for r in records] + def get_addresses(self, **constraints) -> defer.Deferred: + rows = yield self.ledger.db.select_addresses('address', **constraints) + return [r[0] for r in rows] - def get_address_records(self, limit: int = None, max_used_times: int = None) -> defer.Deferred: - return self.ledger.db.get_addresses(self, None, limit, max_used_times) + def get_address_records(self, **constraints) -> defer.Deferred: + return self.ledger.db.get_addresses(account=self, **constraints) def get_private_key(self, chain: int, index: int) -> PrivateKey: assert not self.encrypted, "Cannot get private key on encrypted wallet account." address_manager = {0: self.receiving, 1: self.change}[chain] return address_manager.get_private_key(index) - def get_balance(self, confirmations: int = 6, **constraints): + def get_balance(self, confirmations: int = 0, **constraints): if confirmations > 0: height = self.ledger.headers.height - (confirmations-1) constraints.update({'height__lte': height, 'height__gt': 0}) - return self.ledger.db.get_balance_for_account(self, **constraints) + return self.ledger.db.get_balance(account=self, **constraints) @defer.inlineCallbacks def get_max_gap(self): @@ -356,11 +357,11 @@ class BaseAccount: 'max_receiving_gap': receiving_gap, } - def get_unspent_outputs(self, **constraints): + def get_utxos(self, **constraints): return self.ledger.db.get_utxos(account=self, **constraints) - def get_transactions(self) -> List['basetransaction.BaseTransaction']: - return self.ledger.db.get_transactions(account=self) + def get_transactions(self, **constraints) -> List['basetransaction.BaseTransaction']: + return self.ledger.db.get_transactions(account=self, **constraints) @defer.inlineCallbacks def fund(self, to_account, amount=None, everything=False, @@ -368,7 +369,7 @@ class BaseAccount: assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.' tx_class = self.ledger.transaction_class if everything: - utxos = yield self.get_unspent_outputs(**constraints) + utxos = yield self.get_utxos(**constraints) yield self.ledger.reserve_outputs(utxos) tx = yield tx_class.create( inputs=[tx_class.input_class.spend(txo) for txo in utxos], diff --git a/torba/basedatabase.py b/torba/basedatabase.py index b70be1d73..a6846b6ae 100644 --- a/torba/basedatabase.py +++ b/torba/basedatabase.py @@ -12,60 +12,97 @@ from torba.baseaccount import BaseAccount log = logging.getLogger(__name__) -def clean_arg_name(arg): - return arg.replace('.', '_') - - -def prepare_constraints(constraints): - if 'account' in constraints: - if isinstance(constraints['account'], BaseAccount): - constraints['account'] = constraints['account'].public_key.address - return constraints.pop('my_account', constraints.get('account')) - - -def constraints_to_sql(constraints, joiner=' AND ', prepend_sql=' AND ', prepend_key=''): - if not constraints: - return '' - extras = [] - for key in list(constraints): - col, op, constraint = key, '=', constraints.pop(key) - if key.endswith('__not'): - col, op = key[:-len('__not')], '!=' +def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''): + sql, values = [], {} + for key, constraint in constraints.items(): + col, op, key = key, '=', key.replace('.', '_') + if key.startswith('$'): + values[key] = constraint + continue + elif key.endswith('__not'): + col, op = col[:-len('__not')], '!=' elif key.endswith('__lt'): - col, op = key[:-len('__lt')], '<' + col, op = col[:-len('__lt')], '<' elif key.endswith('__lte'): - col, op = key[:-len('__lte')], '<=' + col, op = col[:-len('__lte')], '<=' elif key.endswith('__gt'): - col, op = key[:-len('__gt')], '>' + col, op = col[:-len('__gt')], '>' elif key.endswith('__like'): - col, op = key[:-len('__like')], 'LIKE' + col, op = col[:-len('__like')], 'LIKE' elif key.endswith('__in') or key.endswith('__not_in'): if key.endswith('__in'): - col, op = key[:-len('__in')], 'IN' + col, op = col[:-len('__in')], 'IN' else: - col, op = key[:-len('__not_in')], 'NOT IN' + col, op = col[:-len('__not_in')], 'NOT IN' if isinstance(constraint, (list, set)): - placeholders = [] - for item_no, item in enumerate(constraint, 1): - constraints['{}_{}'.format(clean_arg_name(col), item_no)] = item - placeholders.append(':{}_{}'.format(clean_arg_name(col), item_no)) - items = ', '.join(placeholders) + items = ', '.join( + "'{}'".format(item) if isinstance(item, str) else str(item) + for item in constraint + ) elif isinstance(constraint, str): items = constraint else: - raise ValueError("{} requires a list or string as constraint value.".format(key)) - extras.append('{} {} ({})'.format(col, op, items)) + raise ValueError("{} requires a list, set or string as constraint value.".format(col)) + sql.append('{} {} ({})'.format(col, op, items)) continue elif key.endswith('__any'): - extras.append('({})'.format( - constraints_to_sql(constraint, ' OR ', '', key+'_') - )) - for subkey, val in constraint.items(): - constraints['{}_{}'.format(clean_arg_name(key), clean_arg_name(subkey))] = val + where, subvalues = constraints_to_sql(constraint, ' OR ', key+'_') + sql.append('({})'.format(where)) + values.update(subvalues) continue - constraints[clean_arg_name(key)] = constraint - extras.append('{} {} :{}'.format(col, op, prepend_key+clean_arg_name(key))) - return prepend_sql + joiner.join(extras) if extras else '' + sql.append('{} {} :{}'.format(col, op, prepend_key+key)) + values[prepend_key+key] = constraint + return joiner.join(sql) if sql else '', values + + +def query(select, **constraints): + sql = [select] + limit = constraints.pop('limit', None) + offset = constraints.pop('offset', None) + order_by = constraints.pop('order_by', None) + + constraints.pop('my_account', None) + account = constraints.pop('account', None) + if account is not None: + if not isinstance(account, list): + account = [account] + constraints['account__in'] = [ + (a.public_key.address if isinstance(a, BaseAccount) else a) for a in account + ] + + where, values = constraints_to_sql(constraints) + if where: + sql.append('WHERE') + sql.append(where) + + if order_by is not None: + sql.append('ORDER BY') + if isinstance(order_by, str): + sql.append(order_by) + elif isinstance(order_by, list): + sql.append(', '.join(order_by)) + else: + raise ValueError("order_by must be string or list") + + if offset is not None: + sql.append('OFFSET {}'.format(offset)) + + if limit is not None: + sql.append('LIMIT {}'.format(limit)) + + return ' '.join(sql), values + + +def rows_to_dict(rows, fields): + if rows: + return [dict(zip(fields, r)) for r in rows] + else: + return [] + + +def row_dict_or_default(rows, fields, default=None): + dicts = rows_to_dict(rows, fields) + return dicts[0] if dicts else default class SQLiteMixin: @@ -121,22 +158,6 @@ class SQLiteMixin: else: return default - @defer.inlineCallbacks - def query_dict_value_list(self, query, fields, params=None): - result = yield self.run_query(query.format(', '.join(fields)), params) - if result: - return [dict(zip(fields, r)) for r in result] - else: - return [] - - @defer.inlineCallbacks - def query_dict_value(self, query, fields, params=None, default=None): - result = yield self.query_dict_value_list(query, fields, params) - if result: - return result[0] - else: - return default - @staticmethod def execute(t, sql, values): log.debug(sql) @@ -232,9 +253,9 @@ class BaseDatabase(SQLiteMixin): 'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified }, 'txid = ?', (tx.id,))) - existing_txos = [r[0] for r in self.execute( - t, "SELECT position FROM txo WHERE txid = ?", (tx.id,) - ).fetchall()] + existing_txos = [r[0] for r in self.execute(t, *query( + "SELECT position FROM txo", txid=tx.id + )).fetchall()] for txo in tx.outputs: if txo.position in existing_txos: @@ -246,16 +267,14 @@ class BaseDatabase(SQLiteMixin): print('Database.save_transaction_io: pay script hash is not implemented!') # lookup the address associated with each TXI (via its TXO) - txoids = [txi.txo_ref.id for txi in tx.inputs] - txoid_place_holders = ','.join(['?']*len(txoids)) - txoid_to_address = {r[0]: r[1] for r in self.execute( - t, "SELECT txoid, address FROM txo WHERE txoid in ({})".format(txoid_place_holders), txoids - ).fetchall()} + txoid_to_address = {r[0]: r[1] for r in self.execute(t, *query( + "SELECT txoid, address FROM txo", txoid__in=[txi.txo_ref.id for txi in tx.inputs] + )).fetchall()} # list of TXIs that have already been added - existing_txis = [r[0] for r in self.execute( - t, "SELECT txoid FROM txi WHERE txid = ?", (tx.id,) - ).fetchall()] + existing_txis = [r[0] for r in self.execute(t, *query( + "SELECT txoid FROM txi", txid=tx.id + )).fetchall()] for txi in tx.inputs: txoid = txi.txo_ref.id @@ -289,39 +308,27 @@ class BaseDatabase(SQLiteMixin): # 2. update address histories removing deleted TXs return defer.succeed(True) - @defer.inlineCallbacks - def get_transaction(self, txid, account=None): - txs = yield self.get_transactions(account=account, txid=txid) - if len(txs) == 1: - return txs[0] - - @defer.inlineCallbacks - def get_transactions(self, offset=0, limit=1000000, **constraints): - my_account = prepare_constraints(constraints) - account = constraints.pop('account', None) - + def select_transactions(self, cols, account=None, **constraints): if 'txid' not in constraints and account is not None: + constraints['$account'] = account.public_key.address constraints['txid__in'] = """ SELECT txo.txid FROM txo - JOIN pubkey_address USING (address) WHERE pubkey_address.account = :account + JOIN pubkey_address USING (address) WHERE pubkey_address.account = :$account UNION SELECT txi.txid FROM txi JOIN txo USING (txoid) - JOIN pubkey_address USING (address) WHERE pubkey_address.account = :account + JOIN pubkey_address USING (address) WHERE pubkey_address.account = :$account """ + return self.run_query(*query("SELECT {} FROM tx".format(cols), **constraints)) - where = constraints_to_sql(constraints, prepend_sql='WHERE ') + @defer.inlineCallbacks + def get_transactions(self, my_account=None, **constraints): + my_account = my_account or constraints.get('account', None) - tx_rows = yield self.run_query( - """ - SELECT txid, raw, height, position, is_verified FROM tx {} - ORDER BY height DESC, position DESC LIMIT :offset, :limit - """.format(where), { - **constraints, - 'account': account, - 'offset': max(offset, 0), - 'limit': max(limit, 1) - } + tx_rows = yield self.select_transactions( + 'txid, raw, height, position, is_verified', + order_by=["height DESC", "position DESC"], + **constraints ) txids, txs = [], [] @@ -343,9 +350,7 @@ class BaseDatabase(SQLiteMixin): txo.id: txo for txo in (yield self.get_txos( my_account=my_account, - txoid__in="SELECT txoid FROM txi WHERE txi.txid IN ({})".format( - ','.join("'{}'".format(txid) for txid in txids) - ) + txoid__in=query("SELECT txoid FROM txi", **{'txid__in': txids})[0] )) } @@ -364,13 +369,28 @@ class BaseDatabase(SQLiteMixin): return txs @defer.inlineCallbacks - def get_txos(self, **constraints): - my_account = prepare_constraints(constraints) - rows = yield self.run_query( - """ - SELECT amount, script, txid, txo.position, chain, account - FROM txo JOIN pubkey_address USING (address) - """+constraints_to_sql(constraints, prepend_sql='WHERE '), constraints + def get_transaction_count(self, **constraints): + count = yield self.select_transactions('count(*)', **constraints) + return count[0][0] + + @defer.inlineCallbacks + def get_transaction(self, **constraints): + txs = yield self.get_transactions(limit=1, **constraints) + if txs: + return txs[0] + + def select_txos(self, cols, **constraints): + return self.run_query(*query( + "SELECT {} FROM txo JOIN pubkey_address USING (address)".format(cols), **constraints + )) + + @defer.inlineCallbacks + def get_txos(self, my_account=None, **constraints): + my_account = my_account or constraints.get('account', None) + if isinstance(my_account, BaseAccount): + my_account = my_account.public_key.address + rows = yield self.select_txos( + "amount, script, txid, txo.position, chain, account", **constraints ) output_class = self.ledger.transaction_class.output_class return [ @@ -384,34 +404,60 @@ class BaseDatabase(SQLiteMixin): ) for row in rows ] - def get_utxos(self, **constraints): - constraints['txoid__not_in'] = 'SELECT txoid FROM txi' + @defer.inlineCallbacks + def get_txo_count(self, **constraints): + count = yield self.select_txos('count(*)', **constraints) + return count[0][0] + + @staticmethod + def constrain_utxo(constraints): constraints['is_reserved'] = False + constraints['txoid__not_in'] = "SELECT txoid FROM txi" + + def get_utxos(self, **constraints): + self.constrain_utxo(constraints) return self.get_txos(**constraints) - def get_balance_for_account(self, account, include_reserved=False, **constraints): - if not include_reserved: - constraints['is_reserved'] = False - values = {'account': account.public_key.address} - values.update(constraints) - return self.query_one_value( - """ - SELECT SUM(amount) - FROM txo - JOIN tx ON tx.txid=txo.txid - JOIN pubkey_address ON pubkey_address.address=txo.address - WHERE - pubkey_address.account=:account AND - txoid NOT IN (SELECT txoid FROM txi) - """+constraints_to_sql(constraints), values, 0 + def get_utxos_count(self, **constraints): + self.constrain_utxo(constraints) + return self.get_txo_count(**constraints) + + @defer.inlineCallbacks + def get_balance(self, **constraints): + self.constrain_utxo(constraints) + balance = yield self.select_txos('SUM(amount)', **constraints) + return balance[0][0] or 0 + + def select_addresses(self, cols, **constraints): + return self.run_query(*query( + "SELECT {} FROM pubkey_address".format(cols), **constraints + )) + + @defer.inlineCallbacks + def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'), **constraints): + addresses = yield self.select_addresses(', '.join(cols), **constraints) + return rows_to_dict(addresses, cols) + + @defer.inlineCallbacks + def get_address_count(self, **constraints): + count = yield self.select_addresses('count(*)', **constraints) + return count[0][0] + + @defer.inlineCallbacks + def get_address(self, address): + addresses = yield self.get_addresses( + cols=('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'), + address=address, limit=1 ) + if addresses: + return addresses[0] def add_keys(self, account, chain, keys): sql = ( - "insert into pubkey_address " - "(address, account, chain, position, pubkey) " - "values " - ) + ', '.join(['(?, ?, ?, ?, ?)'] * len(keys)) + "insert into pubkey_address " + "(address, account, chain, position, pubkey) " + "values " + ) + ', '.join(['(?, ?, ?, ?, ?)'] * len(keys)) values = [] for position, pubkey in keys: values.append(pubkey.address) @@ -430,40 +476,3 @@ class BaseDatabase(SQLiteMixin): def set_address_history(self, address, history): return self.db.runInteraction(lambda t: self._set_address_history(t, address, history)) - - def get_addresses(self, account, chain, limit=None, max_used_times=None, order_by=None): - columns = ['account', 'chain', 'position', 'address', 'used_times'] - sql = ["SELECT {} FROM pubkey_address"] - - where = [] - params = {} - if account is not None: - params["account"] = account.public_key.address - where.append("account = :account") - columns.remove("account") - if chain is not None: - params["chain"] = chain - where.append("chain = :chain") - columns.remove("chain") - if max_used_times is not None: - params["used_times"] = max_used_times - where.append("used_times <= :used_times") - - if where: - sql.append("WHERE") - sql.append(" AND ".join(where)) - - if order_by: - sql.append("ORDER BY {}".format(order_by)) - - if limit is not None: - sql.append("LIMIT {}".format(limit)) - - return self.query_dict_value_list(" ".join(sql), columns, params) - - def get_address(self, address): - return self.query_dict_value( - "SELECT {} FROM pubkey_address WHERE address = :address", - ('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'), - {'address': address} - ) diff --git a/torba/baseledger.py b/torba/baseledger.py index 50c00a607..a19347518 100644 --- a/torba/baseledger.py +++ b/torba/baseledger.py @@ -148,7 +148,7 @@ class BaseLedger(metaclass=LedgerRegistry): def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]): estimators = [] for account in funding_accounts: - utxos = yield account.get_unspent_outputs() + utxos = yield account.get_utxos() for utxo in utxos: estimators.append(utxo.get_estimator(self)) return estimators @@ -330,7 +330,7 @@ class BaseLedger(metaclass=LedgerRegistry): # need to update anyways. Continue to get history and create more addresses until # all missing addresses are created and history for them is fully restored. yield account.ensure_address_gap() - addresses = yield account.get_addresses(max_used_times=0) + addresses = yield account.get_addresses(used_times=0) while addresses: yield defer.DeferredList([ self.update_history(a) for a in addresses @@ -364,7 +364,7 @@ class BaseLedger(metaclass=LedgerRegistry): try: # see if we have a local copy of transaction, otherwise fetch it from server - tx = yield self.db.get_transaction(hex_id) + tx = yield self.db.get_transaction(txid=hex_id) save_tx = None if tx is None: _raw = yield self.network.get_transaction(hex_id)