improved account balance support

This commit is contained in:
Lex Berezhny 2018-07-16 23:58:29 -04:00
parent 686eb7b1f0
commit 10fd654edc
5 changed files with 55 additions and 30 deletions

View file

@ -1,2 +1,2 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__) __path__ = __import__('pkgutil').extend_path(__path__, __name__)
__version__ = '0.0.1' __version__ = '0.0.2'

View file

@ -260,13 +260,10 @@ class BaseAccount(object):
return self.private_key.child(chain).child(index) return self.private_key.child(chain).child(index)
def get_balance(self, confirmations=6, **constraints): def get_balance(self, confirmations=6, **constraints):
if confirmations == 0: if confirmations > 0:
return self.ledger.db.get_balance_for_account(self, **constraints)
else:
height = self.ledger.headers.height - (confirmations-1) height = self.ledger.headers.height - (confirmations-1)
return self.ledger.db.get_balance_for_account( constraints.update({'height__lte': height, 'height__not': -1})
self, height__lte=height, height__not=-1, **constraints return self.ledger.db.get_balance_for_account(self, **constraints)
)
def get_unspent_outputs(self, **constraints): def get_unspent_outputs(self, **constraints):
return self.ledger.db.get_utxos_for_account(self, **constraints) return self.ledger.db.get_utxos_for_account(self, **constraints)

View file

@ -57,7 +57,7 @@ class SQLiteMixin(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def query_one_value(self, query, params=None, default=None): def query_one_value(self, query, params=None, default=None):
result = yield self.db.runQuery(query, params) result = yield self.run_query(query, params)
if result: if result:
defer.returnValue(result[0][0] or default) defer.returnValue(result[0][0] or default)
else: else:
@ -65,7 +65,7 @@ class SQLiteMixin(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def query_dict_value_list(self, query, fields, params=None): def query_dict_value_list(self, query, fields, params=None):
result = yield self.db.runQuery(query.format(', '.join(fields)), params) result = yield self.run_query(query.format(', '.join(fields)), params)
if result: if result:
defer.returnValue([dict(zip(fields, r)) for r in result]) defer.returnValue([dict(zip(fields, r)) for r in result])
else: else:
@ -79,6 +79,22 @@ class SQLiteMixin(object):
else: else:
defer.returnValue(default) defer.returnValue(default)
@staticmethod
def execute(t, sql, values):
log.debug(sql)
log.debug(values)
return t.execute(sql, values)
def run_operation(self, sql, values):
log.debug(sql)
log.debug(values)
return self.db.runOperation(sql, values)
def run_query(self, sql, values):
log.debug(sql)
log.debug(values)
return self.db.runQuery(sql, values)
class BaseDatabase(SQLiteMixin): class BaseDatabase(SQLiteMixin):
@ -144,39 +160,39 @@ class BaseDatabase(SQLiteMixin):
def _steps(t): def _steps(t):
if save_tx == 'insert': if save_tx == 'insert':
t.execute(*self._insert_sql('tx', { self.execute(t, *self._insert_sql('tx', {
'txid': tx.id, 'txid': tx.id,
'raw': sqlite3.Binary(tx.raw), 'raw': sqlite3.Binary(tx.raw),
'height': height, 'height': height,
'is_verified': is_verified 'is_verified': is_verified
})) }))
elif save_tx == 'update': elif save_tx == 'update':
t.execute(*self._update_sql("tx", { self.execute(t, *self._update_sql("tx", {
'height': height, 'is_verified': is_verified 'height': height, 'is_verified': is_verified
}, 'txid = ?', (tx.id,) }, 'txid = ?', (tx.id,)
)) ))
existing_txos = list(map(itemgetter(0), t.execute( existing_txos = list(map(itemgetter(0), self.execute(
"SELECT position FROM txo WHERE txid = ?", (tx.id,) t, "SELECT position FROM txo WHERE txid = ?", (tx.id,)
).fetchall())) ).fetchall()))
for txo in tx.outputs: for txo in tx.outputs:
if txo.position in existing_txos: if txo.position in existing_txos:
continue continue
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == hash: if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == hash:
t.execute(*self._insert_sql("txo", self.txo_to_row(tx, address, txo))) self.execute(t, *self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
elif txo.script.is_pay_script_hash: elif txo.script.is_pay_script_hash:
# TODO: implement script hash payments # TODO: implement script hash payments
print('Database.save_transaction_io: pay script hash is not implemented!') print('Database.save_transaction_io: pay script hash is not implemented!')
spent_txoids = [txi[0] for txi in t.execute( spent_txoids = [txi[0] for txi in self.execute(
"SELECT txoid FROM txi WHERE txid = ? AND address = ?", (tx.id, address) t, "SELECT txoid FROM txi WHERE txid = ? AND address = ?", (tx.id, address)
).fetchall()] ).fetchall()]
for txi in tx.inputs: for txi in tx.inputs:
txoid = txi.txo_ref.id txoid = txi.txo_ref.id
if txoid not in spent_txoids: if txoid not in spent_txoids:
t.execute(*self._insert_sql("txi", { self.execute(t, *self._insert_sql("txi", {
'txid': tx.id, 'txid': tx.id,
'txoid': txoid, 'txoid': txoid,
'address': address, 'address': address,
@ -187,7 +203,7 @@ class BaseDatabase(SQLiteMixin):
return self.db.runInteraction(_steps) return self.db.runInteraction(_steps)
def reserve_spent_outputs(self, txoids, is_reserved=True): def reserve_spent_outputs(self, txoids, is_reserved=True):
return self.db.runOperation( return self.run_operation(
"UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format( "UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format(
', '.join(['?']*len(txoids)) ', '.join(['?']*len(txoids))
), [is_reserved]+txoids ), [is_reserved]+txoids
@ -198,7 +214,7 @@ class BaseDatabase(SQLiteMixin):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_transaction(self, txid): def get_transaction(self, txid):
result = yield self.db.runQuery( result = yield self.run_query(
"SELECT raw, height, is_verified FROM tx WHERE txid = ?", (txid,) "SELECT raw, height, is_verified FROM tx WHERE txid = ?", (txid,)
) )
if result: if result:
@ -206,7 +222,7 @@ class BaseDatabase(SQLiteMixin):
else: else:
defer.returnValue((None, None, False)) defer.returnValue((None, None, False))
def get_balance_for_account(self, account, **constraints): def get_balance_for_account(self, account, include_reserved=False, **constraints):
extra_sql = "" extra_sql = ""
if constraints: if constraints:
extras = [] extras = []
@ -218,6 +234,8 @@ class BaseDatabase(SQLiteMixin):
col, op = key[:-len('__lte')], '<=' col, op = key[:-len('__lte')], '<='
extras.append('{} {} :{}'.format(col, op, key)) extras.append('{} {} :{}'.format(col, op, key))
extra_sql = ' AND ' + ' AND '.join(extras) extra_sql = ' AND ' + ' AND '.join(extras)
if not include_reserved:
extra_sql += ' AND is_reserved=0'
values = {'account': account.public_key.address} values = {'account': account.public_key.address}
values.update(constraints) values.update(constraints)
return self.query_one_value( return self.query_one_value(
@ -241,7 +259,7 @@ class BaseDatabase(SQLiteMixin):
) )
values = {'account': account.public_key.address} values = {'account': account.public_key.address}
values.update(constraints) values.update(constraints)
utxos = yield self.db.runQuery( utxos = yield self.run_query(
""" """
SELECT amount, script, txid, txo.position SELECT amount, script, txid, txo.position
FROM txo JOIN pubkey_address ON pubkey_address.address=txo.address FROM txo JOIN pubkey_address ON pubkey_address.address=txo.address
@ -271,12 +289,12 @@ class BaseDatabase(SQLiteMixin):
values.append(chain) values.append(chain)
values.append(position) values.append(position)
values.append(sqlite3.Binary(pubkey.pubkey_bytes)) values.append(sqlite3.Binary(pubkey.pubkey_bytes))
return self.db.runOperation(sql, values) return self.run_operation(sql, values)
@staticmethod @classmethod
def _set_address_history(t, address, history): def _set_address_history(cls, t, address, history):
t.execute( cls.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", t, "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history.count(':')//2, address) (history, history.count(':')//2, address)
) )

View file

@ -7,7 +7,7 @@ from typing import Dict, Type, Iterable, Generator
from operator import itemgetter from operator import itemgetter
from collections import namedtuple from collections import namedtuple
from twisted.internet import defer from twisted.internet import defer, reactor
from torba import baseaccount from torba import baseaccount
from torba import basedatabase from torba import basedatabase
@ -69,7 +69,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
self.network = self.config.get('network') or self.network_class(self) self.network = self.config.get('network') or self.network_class(self)
self.network.on_header.listen(self.process_header) self.network.on_header.listen(self.process_header)
self.network.on_status.listen(self.process_status) self.network.on_status.listen(self.process_status)
self.accounts = set() self.accounts = []
self.headers = self.config.get('headers') or self.headers_class(self) self.headers = self.config.get('headers') or self.headers_class(self)
self.fee_per_byte = self.config.get('fee_per_byte', self.default_fee_per_byte) self.fee_per_byte = self.config.get('fee_per_byte', self.default_fee_per_byte)
@ -123,7 +123,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
@defer.inlineCallbacks @defer.inlineCallbacks
def add_account(self, account): # type: (baseaccount.BaseAccount) -> None def add_account(self, account): # type: (baseaccount.BaseAccount) -> None
self.accounts.add(account) self.accounts.append(account)
if self.network.is_connected: if self.network.is_connected:
yield self.update_account(account) yield self.update_account(account)
@ -297,7 +297,11 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
log.debug("{}: sync'ed tx {} for address: {}, height: {}, verified: {}".format( log.debug("{}: sync'ed tx {} for address: {}, height: {}, verified: {}".format(
self.get_id(), hex_id, address, remote_height, is_verified self.get_id(), hex_id, address, remote_height, is_verified
)) ))
self._on_transaction_controller.add(TransactionEvent(address, tx, remote_height, is_verified))
reactor.callLater(
0.01, self._on_transaction_controller.add,
TransactionEvent(address, tx, remote_height, is_verified)
)
except Exception as e: except Exception as e:
log.exception('Failed to synchronize transaction:') log.exception('Failed to synchronize transaction:')

View file

@ -65,6 +65,12 @@ class WalletManager(object):
for wallet in self.wallets: for wallet in self.wallets:
return wallet.default_account return wallet.default_account
@property
def accounts(self):
for wallet in self.wallets:
for account in wallet.accounts:
yield account
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
self.running = True self.running = True