From 10fd654edce307521858353907cafdf4d0debcb3 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Mon, 16 Jul 2018 23:58:29 -0400 Subject: [PATCH] improved account balance support --- torba/__init__.py | 2 +- torba/baseaccount.py | 9 +++---- torba/basedatabase.py | 56 ++++++++++++++++++++++++++++--------------- torba/baseledger.py | 12 ++++++---- torba/manager.py | 6 +++++ 5 files changed, 55 insertions(+), 30 deletions(-) diff --git a/torba/__init__.py b/torba/__init__.py index bc9d27d55..ebb33d435 100644 --- a/torba/__init__.py +++ b/torba/__init__.py @@ -1,2 +1,2 @@ __path__ = __import__('pkgutil').extend_path(__path__, __name__) -__version__ = '0.0.1' +__version__ = '0.0.2' diff --git a/torba/baseaccount.py b/torba/baseaccount.py index a10ae7257..92d64846c 100644 --- a/torba/baseaccount.py +++ b/torba/baseaccount.py @@ -260,13 +260,10 @@ class BaseAccount(object): return self.private_key.child(chain).child(index) def get_balance(self, confirmations=6, **constraints): - if confirmations == 0: - return self.ledger.db.get_balance_for_account(self, **constraints) - else: + if confirmations > 0: height = self.ledger.headers.height - (confirmations-1) - return self.ledger.db.get_balance_for_account( - self, height__lte=height, height__not=-1, **constraints - ) + constraints.update({'height__lte': height, 'height__not': -1}) + return self.ledger.db.get_balance_for_account(self, **constraints) def get_unspent_outputs(self, **constraints): return self.ledger.db.get_utxos_for_account(self, **constraints) diff --git a/torba/basedatabase.py b/torba/basedatabase.py index a850aa2b7..3d4df697f 100644 --- a/torba/basedatabase.py +++ b/torba/basedatabase.py @@ -57,7 +57,7 @@ class SQLiteMixin(object): @defer.inlineCallbacks def query_one_value(self, query, params=None, default=None): - result = yield self.db.runQuery(query, params) + result = yield self.run_query(query, params) if result: defer.returnValue(result[0][0] or default) else: @@ -65,7 +65,7 @@ class SQLiteMixin(object): @defer.inlineCallbacks def query_dict_value_list(self, query, fields, params=None): - result = yield self.db.runQuery(query.format(', '.join(fields)), params) + result = yield self.run_query(query.format(', '.join(fields)), params) if result: defer.returnValue([dict(zip(fields, r)) for r in result]) else: @@ -79,6 +79,22 @@ class SQLiteMixin(object): else: defer.returnValue(default) + @staticmethod + def execute(t, sql, values): + log.debug(sql) + log.debug(values) + return t.execute(sql, values) + + def run_operation(self, sql, values): + log.debug(sql) + log.debug(values) + return self.db.runOperation(sql, values) + + def run_query(self, sql, values): + log.debug(sql) + log.debug(values) + return self.db.runQuery(sql, values) + class BaseDatabase(SQLiteMixin): @@ -144,39 +160,39 @@ class BaseDatabase(SQLiteMixin): def _steps(t): if save_tx == 'insert': - t.execute(*self._insert_sql('tx', { + self.execute(t, *self._insert_sql('tx', { 'txid': tx.id, 'raw': sqlite3.Binary(tx.raw), 'height': height, 'is_verified': is_verified })) elif save_tx == 'update': - t.execute(*self._update_sql("tx", { + self.execute(t, *self._update_sql("tx", { 'height': height, 'is_verified': is_verified }, 'txid = ?', (tx.id,) )) - existing_txos = list(map(itemgetter(0), t.execute( - "SELECT position FROM txo WHERE txid = ?", (tx.id,) + existing_txos = list(map(itemgetter(0), self.execute( + t, "SELECT position FROM txo WHERE txid = ?", (tx.id,) ).fetchall())) for txo in tx.outputs: if txo.position in existing_txos: continue if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == hash: - t.execute(*self._insert_sql("txo", self.txo_to_row(tx, address, txo))) + self.execute(t, *self._insert_sql("txo", self.txo_to_row(tx, address, txo))) elif txo.script.is_pay_script_hash: # TODO: implement script hash payments print('Database.save_transaction_io: pay script hash is not implemented!') - spent_txoids = [txi[0] for txi in t.execute( - "SELECT txoid FROM txi WHERE txid = ? AND address = ?", (tx.id, address) + spent_txoids = [txi[0] for txi in self.execute( + t, "SELECT txoid FROM txi WHERE txid = ? AND address = ?", (tx.id, address) ).fetchall()] for txi in tx.inputs: txoid = txi.txo_ref.id if txoid not in spent_txoids: - t.execute(*self._insert_sql("txi", { + self.execute(t, *self._insert_sql("txi", { 'txid': tx.id, 'txoid': txoid, 'address': address, @@ -187,7 +203,7 @@ class BaseDatabase(SQLiteMixin): return self.db.runInteraction(_steps) def reserve_spent_outputs(self, txoids, is_reserved=True): - return self.db.runOperation( + return self.run_operation( "UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format( ', '.join(['?']*len(txoids)) ), [is_reserved]+txoids @@ -198,7 +214,7 @@ class BaseDatabase(SQLiteMixin): @defer.inlineCallbacks def get_transaction(self, txid): - result = yield self.db.runQuery( + result = yield self.run_query( "SELECT raw, height, is_verified FROM tx WHERE txid = ?", (txid,) ) if result: @@ -206,7 +222,7 @@ class BaseDatabase(SQLiteMixin): else: defer.returnValue((None, None, False)) - def get_balance_for_account(self, account, **constraints): + def get_balance_for_account(self, account, include_reserved=False, **constraints): extra_sql = "" if constraints: extras = [] @@ -218,6 +234,8 @@ class BaseDatabase(SQLiteMixin): col, op = key[:-len('__lte')], '<=' extras.append('{} {} :{}'.format(col, op, key)) extra_sql = ' AND ' + ' AND '.join(extras) + if not include_reserved: + extra_sql += ' AND is_reserved=0' values = {'account': account.public_key.address} values.update(constraints) return self.query_one_value( @@ -241,7 +259,7 @@ class BaseDatabase(SQLiteMixin): ) values = {'account': account.public_key.address} values.update(constraints) - utxos = yield self.db.runQuery( + utxos = yield self.run_query( """ SELECT amount, script, txid, txo.position FROM txo JOIN pubkey_address ON pubkey_address.address=txo.address @@ -271,12 +289,12 @@ class BaseDatabase(SQLiteMixin): values.append(chain) values.append(position) values.append(sqlite3.Binary(pubkey.pubkey_bytes)) - return self.db.runOperation(sql, values) + return self.run_operation(sql, values) - @staticmethod - def _set_address_history(t, address, history): - t.execute( - "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", + @classmethod + def _set_address_history(cls, t, address, history): + cls.execute( + t, "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", (history, history.count(':')//2, address) ) diff --git a/torba/baseledger.py b/torba/baseledger.py index 0f4a522c6..cfe80169e 100644 --- a/torba/baseledger.py +++ b/torba/baseledger.py @@ -7,7 +7,7 @@ from typing import Dict, Type, Iterable, Generator from operator import itemgetter from collections import namedtuple -from twisted.internet import defer +from twisted.internet import defer, reactor from torba import baseaccount from torba import basedatabase @@ -69,7 +69,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): self.network = self.config.get('network') or self.network_class(self) self.network.on_header.listen(self.process_header) self.network.on_status.listen(self.process_status) - self.accounts = set() + self.accounts = [] self.headers = self.config.get('headers') or self.headers_class(self) self.fee_per_byte = self.config.get('fee_per_byte', self.default_fee_per_byte) @@ -123,7 +123,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): @defer.inlineCallbacks def add_account(self, account): # type: (baseaccount.BaseAccount) -> None - self.accounts.add(account) + self.accounts.append(account) if self.network.is_connected: yield self.update_account(account) @@ -297,7 +297,11 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): log.debug("{}: sync'ed tx {} for address: {}, height: {}, verified: {}".format( self.get_id(), hex_id, address, remote_height, is_verified )) - self._on_transaction_controller.add(TransactionEvent(address, tx, remote_height, is_verified)) + + reactor.callLater( + 0.01, self._on_transaction_controller.add, + TransactionEvent(address, tx, remote_height, is_verified) + ) except Exception as e: log.exception('Failed to synchronize transaction:') diff --git a/torba/manager.py b/torba/manager.py index ac803e419..2cb998758 100644 --- a/torba/manager.py +++ b/torba/manager.py @@ -65,6 +65,12 @@ class WalletManager(object): for wallet in self.wallets: return wallet.default_account + @property + def accounts(self): + for wallet in self.wallets: + for account in wallet.accounts: + yield account + @defer.inlineCallbacks def start(self): self.running = True