improved account balance support

This commit is contained in:
Lex Berezhny 2018-07-16 23:58:29 -04:00
parent 686eb7b1f0
commit 10fd654edc
5 changed files with 55 additions and 30 deletions

View file

@ -1,2 +1,2 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
__version__ = '0.0.1'
__version__ = '0.0.2'

View file

@ -260,13 +260,10 @@ class BaseAccount(object):
return self.private_key.child(chain).child(index)
def get_balance(self, confirmations=6, **constraints):
if confirmations == 0:
return self.ledger.db.get_balance_for_account(self, **constraints)
else:
if confirmations > 0:
height = self.ledger.headers.height - (confirmations-1)
return self.ledger.db.get_balance_for_account(
self, height__lte=height, height__not=-1, **constraints
)
constraints.update({'height__lte': height, 'height__not': -1})
return self.ledger.db.get_balance_for_account(self, **constraints)
def get_unspent_outputs(self, **constraints):
return self.ledger.db.get_utxos_for_account(self, **constraints)

View file

@ -57,7 +57,7 @@ class SQLiteMixin(object):
@defer.inlineCallbacks
def query_one_value(self, query, params=None, default=None):
result = yield self.db.runQuery(query, params)
result = yield self.run_query(query, params)
if result:
defer.returnValue(result[0][0] or default)
else:
@ -65,7 +65,7 @@ class SQLiteMixin(object):
@defer.inlineCallbacks
def query_dict_value_list(self, query, fields, params=None):
result = yield self.db.runQuery(query.format(', '.join(fields)), params)
result = yield self.run_query(query.format(', '.join(fields)), params)
if result:
defer.returnValue([dict(zip(fields, r)) for r in result])
else:
@ -79,6 +79,22 @@ class SQLiteMixin(object):
else:
defer.returnValue(default)
@staticmethod
def execute(t, sql, values):
log.debug(sql)
log.debug(values)
return t.execute(sql, values)
def run_operation(self, sql, values):
log.debug(sql)
log.debug(values)
return self.db.runOperation(sql, values)
def run_query(self, sql, values):
log.debug(sql)
log.debug(values)
return self.db.runQuery(sql, values)
class BaseDatabase(SQLiteMixin):
@ -144,39 +160,39 @@ class BaseDatabase(SQLiteMixin):
def _steps(t):
if save_tx == 'insert':
t.execute(*self._insert_sql('tx', {
self.execute(t, *self._insert_sql('tx', {
'txid': tx.id,
'raw': sqlite3.Binary(tx.raw),
'height': height,
'is_verified': is_verified
}))
elif save_tx == 'update':
t.execute(*self._update_sql("tx", {
self.execute(t, *self._update_sql("tx", {
'height': height, 'is_verified': is_verified
}, 'txid = ?', (tx.id,)
))
existing_txos = list(map(itemgetter(0), t.execute(
"SELECT position FROM txo WHERE txid = ?", (tx.id,)
existing_txos = list(map(itemgetter(0), self.execute(
t, "SELECT position FROM txo WHERE txid = ?", (tx.id,)
).fetchall()))
for txo in tx.outputs:
if txo.position in existing_txos:
continue
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == hash:
t.execute(*self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
self.execute(t, *self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
elif txo.script.is_pay_script_hash:
# TODO: implement script hash payments
print('Database.save_transaction_io: pay script hash is not implemented!')
spent_txoids = [txi[0] for txi in t.execute(
"SELECT txoid FROM txi WHERE txid = ? AND address = ?", (tx.id, address)
spent_txoids = [txi[0] for txi in self.execute(
t, "SELECT txoid FROM txi WHERE txid = ? AND address = ?", (tx.id, address)
).fetchall()]
for txi in tx.inputs:
txoid = txi.txo_ref.id
if txoid not in spent_txoids:
t.execute(*self._insert_sql("txi", {
self.execute(t, *self._insert_sql("txi", {
'txid': tx.id,
'txoid': txoid,
'address': address,
@ -187,7 +203,7 @@ class BaseDatabase(SQLiteMixin):
return self.db.runInteraction(_steps)
def reserve_spent_outputs(self, txoids, is_reserved=True):
return self.db.runOperation(
return self.run_operation(
"UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format(
', '.join(['?']*len(txoids))
), [is_reserved]+txoids
@ -198,7 +214,7 @@ class BaseDatabase(SQLiteMixin):
@defer.inlineCallbacks
def get_transaction(self, txid):
result = yield self.db.runQuery(
result = yield self.run_query(
"SELECT raw, height, is_verified FROM tx WHERE txid = ?", (txid,)
)
if result:
@ -206,7 +222,7 @@ class BaseDatabase(SQLiteMixin):
else:
defer.returnValue((None, None, False))
def get_balance_for_account(self, account, **constraints):
def get_balance_for_account(self, account, include_reserved=False, **constraints):
extra_sql = ""
if constraints:
extras = []
@ -218,6 +234,8 @@ class BaseDatabase(SQLiteMixin):
col, op = key[:-len('__lte')], '<='
extras.append('{} {} :{}'.format(col, op, key))
extra_sql = ' AND ' + ' AND '.join(extras)
if not include_reserved:
extra_sql += ' AND is_reserved=0'
values = {'account': account.public_key.address}
values.update(constraints)
return self.query_one_value(
@ -241,7 +259,7 @@ class BaseDatabase(SQLiteMixin):
)
values = {'account': account.public_key.address}
values.update(constraints)
utxos = yield self.db.runQuery(
utxos = yield self.run_query(
"""
SELECT amount, script, txid, txo.position
FROM txo JOIN pubkey_address ON pubkey_address.address=txo.address
@ -271,12 +289,12 @@ class BaseDatabase(SQLiteMixin):
values.append(chain)
values.append(position)
values.append(sqlite3.Binary(pubkey.pubkey_bytes))
return self.db.runOperation(sql, values)
return self.run_operation(sql, values)
@staticmethod
def _set_address_history(t, address, history):
t.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
@classmethod
def _set_address_history(cls, t, address, history):
cls.execute(
t, "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history.count(':')//2, address)
)

View file

@ -7,7 +7,7 @@ from typing import Dict, Type, Iterable, Generator
from operator import itemgetter
from collections import namedtuple
from twisted.internet import defer
from twisted.internet import defer, reactor
from torba import baseaccount
from torba import basedatabase
@ -69,7 +69,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
self.network = self.config.get('network') or self.network_class(self)
self.network.on_header.listen(self.process_header)
self.network.on_status.listen(self.process_status)
self.accounts = set()
self.accounts = []
self.headers = self.config.get('headers') or self.headers_class(self)
self.fee_per_byte = self.config.get('fee_per_byte', self.default_fee_per_byte)
@ -123,7 +123,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
@defer.inlineCallbacks
def add_account(self, account): # type: (baseaccount.BaseAccount) -> None
self.accounts.add(account)
self.accounts.append(account)
if self.network.is_connected:
yield self.update_account(account)
@ -297,7 +297,11 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
log.debug("{}: sync'ed tx {} for address: {}, height: {}, verified: {}".format(
self.get_id(), hex_id, address, remote_height, is_verified
))
self._on_transaction_controller.add(TransactionEvent(address, tx, remote_height, is_verified))
reactor.callLater(
0.01, self._on_transaction_controller.add,
TransactionEvent(address, tx, remote_height, is_verified)
)
except Exception as e:
log.exception('Failed to synchronize transaction:')

View file

@ -65,6 +65,12 @@ class WalletManager(object):
for wallet in self.wallets:
return wallet.default_account
@property
def accounts(self):
for wallet in self.wallets:
for account in wallet.accounts:
yield account
@defer.inlineCallbacks
def start(self):
self.running = True