refactored queries

This commit is contained in:
Lex Berezhny 2018-10-03 07:08:02 -04:00
parent 04aa559037
commit a6f97dfbde
7 changed files with 240 additions and 91 deletions

View file

@ -5,6 +5,7 @@ from .key_fixtures import expected_ids, expected_privkeys, expected_hardened_pri
from torba.bip32 import PubKey, PrivateKey, from_extended_key_string from torba.bip32 import PubKey, PrivateKey, from_extended_key_string
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
class BIP32Tests(unittest.TestCase): class BIP32Tests(unittest.TestCase):
def test_pubkey_validation(self): def test_pubkey_validation(self):
@ -81,7 +82,6 @@ class BIP32Tests(unittest.TestCase):
self.assertIsInstance(new_privkey, PrivateKey) self.assertIsInstance(new_privkey, PrivateKey)
self.assertEqual(hexlify(new_privkey.private_key_bytes), expected_hardened_privkeys[i - 1 - PrivateKey.HARDENED]) self.assertEqual(hexlify(new_privkey.private_key_bytes), expected_hardened_privkeys[i - 1 - PrivateKey.HARDENED])
def test_from_extended_keys(self): def test_from_extended_keys(self):
ledger = ledger_class({ ledger = ledger_class({
'db': ledger_class.database_class(':memory:'), 'db': ledger_class.database_class(':memory:'),

View file

@ -1,8 +1,15 @@
from unittest import TestCase from twisted.trial import unittest
from twisted.internet import defer
from torba.wallet import Wallet
from torba.constants import COIN
from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class
from torba.basedatabase import constraints_to_sql from torba.basedatabase import constraints_to_sql
from .test_transaction import get_output, NULL_HASH
class TestConstraintBuilder(TestCase):
class TestConstraintBuilder(unittest.TestCase):
def test_any(self): def test_any(self):
constraints = { constraints = {
@ -21,3 +28,125 @@ class TestConstraintBuilder(TestCase):
'ages__any_age__lt': 38 'ages__any_age__lt': 38
} }
) )
def test_in_list(self):
constraints = {'ages__in': [18, 38]}
self.assertEqual(
constraints_to_sql(constraints, prepend_sql=''),
'ages IN (:ages_1, :ages_2)'
)
self.assertEqual(
constraints, {
'ages_1': 18,
'ages_2': 38
}
)
def test_in_query(self):
constraints = {'ages__in': 'SELECT age from ages_table'}
self.assertEqual(
constraints_to_sql(constraints, prepend_sql=''),
'ages IN (SELECT age from ages_table)'
)
self.assertEqual(constraints, {})
def test_not_in_query(self):
constraints = {'ages__not_in': 'SELECT age from ages_table'}
self.assertEqual(
constraints_to_sql(constraints, prepend_sql=''),
'ages NOT IN (SELECT age from ages_table)'
)
self.assertEqual(constraints, {})
class TestQueries(unittest.TestCase):
def setUp(self):
self.ledger = ledger_class({
'db': ledger_class.database_class(':memory:'),
'headers': ledger_class.headers_class(':memory:'),
})
return self.ledger.db.open()
@defer.inlineCallbacks
def create_account(self):
account = self.ledger.account_class.generate(self.ledger, Wallet())
yield account.ensure_address_gap()
return account
@defer.inlineCallbacks
def create_tx_from_nothing(self, my_account, height):
to_address = yield my_account.receiving.get_or_create_usable_address()
to_hash = ledger_class.address_to_hash160(to_address)
tx = ledger_class.transaction_class(height=height, is_verified=True) \
.add_inputs([self.txi(self.txo(1, NULL_HASH))]) \
.add_outputs([self.txo(1, to_hash)])
yield self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '')
return tx
@defer.inlineCallbacks
def create_tx_from_txo(self, txo, to_account, height):
from_hash = txo.script.values['pubkey_hash']
from_address = self.ledger.hash160_to_address(from_hash)
to_address = yield to_account.receiving.get_or_create_usable_address()
to_hash = ledger_class.address_to_hash160(to_address)
tx = ledger_class.transaction_class(height=height, is_verified=True) \
.add_inputs([self.txi(txo)]) \
.add_outputs([self.txo(1, to_hash)])
yield self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
yield self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '')
return tx
@defer.inlineCallbacks
def create_tx_to_nowhere(self, txo, height):
from_hash = txo.script.values['pubkey_hash']
from_address = self.ledger.hash160_to_address(from_hash)
to_hash = NULL_HASH
tx = ledger_class.transaction_class(height=height, is_verified=True) \
.add_inputs([self.txi(txo)]) \
.add_outputs([self.txo(1, to_hash)])
yield self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
return tx
def txo(self, amount, address):
return get_output(int(amount*COIN), address)
def txi(self, txo):
return ledger_class.transaction_class.input_class.spend(txo)
@defer.inlineCallbacks
def test_get_transactions(self):
account1 = yield self.create_account()
account2 = yield self.create_account()
tx1 = yield self.create_tx_from_nothing(account1, 1)
tx2 = yield self.create_tx_from_txo(tx1.outputs[0], account2, 2)
tx3 = yield self.create_tx_to_nowhere(tx2.outputs[0], 3)
txs = yield self.ledger.db.get_transactions()
self.assertEqual([tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])
self.assertEqual([3, 2, 1], [tx.height for tx in txs])
txs = yield self.ledger.db.get_transactions(account1)
self.assertEqual([tx2.id, tx1.id], [tx.id for tx in txs])
self.assertEqual(txs[0].inputs[0].is_my_account, True)
self.assertEqual(txs[0].outputs[0].is_my_account, False)
self.assertEqual(txs[1].inputs[0].is_my_account, False)
self.assertEqual(txs[1].outputs[0].is_my_account, True)
txs = yield self.ledger.db.get_transactions(account2)
self.assertEqual([tx3.id, tx2.id], [tx.id for tx in txs])
self.assertEqual(txs[0].inputs[0].is_my_account, True)
self.assertEqual(txs[0].outputs[0].is_my_account, False)
self.assertEqual(txs[1].inputs[0].is_my_account, False)
self.assertEqual(txs[1].outputs[0].is_my_account, True)
tx = yield self.ledger.db.get_transaction(tx2.id)
self.assertEqual(tx.id, tx2.id)
self.assertEqual(tx.inputs[0].is_my_account, False)
self.assertEqual(tx.outputs[0].is_my_account, False)
tx = yield self.ledger.db.get_transaction(tx2.id, account1)
self.assertEqual(tx.inputs[0].is_my_account, True)
self.assertEqual(tx.outputs[0].is_my_account, False)
tx = yield self.ledger.db.get_transaction(tx2.id, account2)
self.assertEqual(tx.inputs[0].is_my_account, False)
self.assertEqual(tx.outputs[0].is_my_account, True)

View file

@ -254,20 +254,20 @@ class TransactionIOBalancing(unittest.TestCase):
def create_utxos(self, amounts): def create_utxos(self, amounts):
utxos = [self.txo(amount) for amount in amounts] utxos = [self.txo(amount) for amount in amounts]
self.funding_tx = ledger_class.transaction_class() \ self.funding_tx = ledger_class.transaction_class(is_verified=True) \
.add_inputs([self.txi(self.txo(sum(amounts)+0.1))]) \ .add_inputs([self.txi(self.txo(sum(amounts)+0.1))]) \
.add_outputs(utxos) .add_outputs(utxos)
save_tx = 'insert' save_tx = 'insert'
for utxo in utxos: for utxo in utxos:
yield self.ledger.db.save_transaction_io( yield self.ledger.db.save_transaction_io(
save_tx, self.funding_tx, True, save_tx, self.funding_tx,
self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']), self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']),
utxo.script.values['pubkey_hash'], '' utxo.script.values['pubkey_hash'], ''
) )
save_tx = 'update' save_tx = 'update'
defer.returnValue(utxos) return utxos
@staticmethod @staticmethod
def inputs(tx): def inputs(tx):

View file

@ -357,10 +357,10 @@ class BaseAccount:
} }
def get_unspent_outputs(self, **constraints): def get_unspent_outputs(self, **constraints):
return self.ledger.db.get_utxos_for_account(self, **constraints) return self.ledger.db.get_utxos(account=self, **constraints)
def get_transactions(self) -> List['basetransaction.BaseTransaction']: def get_transactions(self) -> List['basetransaction.BaseTransaction']:
return self.ledger.db.get_transactions(self) return self.ledger.db.get_transactions(account=self)
@defer.inlineCallbacks @defer.inlineCallbacks
def fund(self, to_account, amount=None, everything=False, def fund(self, to_account, amount=None, everything=False,

View file

@ -27,6 +27,20 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_sql=' AND ', prepend
col, op = key[:-len('__gt')], '>' col, op = key[:-len('__gt')], '>'
elif key.endswith('__like'): elif key.endswith('__like'):
col, op = key[:-len('__like')], 'LIKE' col, op = key[:-len('__like')], 'LIKE'
elif key.endswith('__in') or key.endswith('__not_in'):
if key.endswith('__in'):
col, op = key[:-len('__in')], 'IN'
else:
col, op = key[:-len('__not_in')], 'NOT IN'
items = constraints.pop(key)
if isinstance(items, list):
placeholders = []
for item_no, item in enumerate(items, 1):
constraints['{}_{}'.format(col, item_no)] = item
placeholders.append(':{}_{}'.format(col, item_no))
items = ', '.join(placeholders)
extras.append('{} {} ({})'.format(col, op, items))
continue
elif key.endswith('__any'): elif key.endswith('__any'):
subconstraints = constraints.pop(key) subconstraints = constraints.pop(key)
extras.append('({})'.format( extras.append('({})'.format(
@ -46,6 +60,7 @@ class SQLiteMixin:
def __init__(self, path): def __init__(self, path):
self._db_path = path self._db_path = path
self.db: adbapi.ConnectionPool = None self.db: adbapi.ConnectionPool = None
self.ledger = None
def open(self): def open(self):
log.info("connecting to database: %s", self._db_path) log.info("connecting to database: %s", self._db_path)
@ -186,7 +201,7 @@ class BaseDatabase(SQLiteMixin):
'script': sqlite3.Binary(txo.script.source) 'script': sqlite3.Binary(txo.script.source)
} }
def save_transaction_io(self, save_tx, tx: BaseTransaction, is_verified, address, txhash, history): def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history):
def _steps(t): def _steps(t):
if save_tx == 'insert': if save_tx == 'insert':
@ -195,11 +210,11 @@ class BaseDatabase(SQLiteMixin):
'raw': sqlite3.Binary(tx.raw), 'raw': sqlite3.Binary(tx.raw),
'height': tx.height, 'height': tx.height,
'position': tx.position, 'position': tx.position,
'is_verified': is_verified 'is_verified': tx.is_verified
})) }))
elif save_tx == 'update': elif save_tx == 'update':
self.execute(t, *self._update_sql("tx", { self.execute(t, *self._update_sql("tx", {
'height': tx.height, 'position': tx.position, 'is_verified': is_verified 'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
}, 'txid = ?', (tx.id,))) }, 'txid = ?', (tx.id,)))
existing_txos = [r[0] for r in self.execute( existing_txos = [r[0] for r in self.execute(
@ -260,32 +275,40 @@ class BaseDatabase(SQLiteMixin):
return defer.succeed(True) return defer.succeed(True)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_transaction(self, txid): def get_transaction(self, txid, account=None):
result = yield self.run_query( txs = yield self.get_transactions(account=account, txid=txid)
"SELECT raw, height, position, is_verified FROM tx WHERE txid = ?", (txid,) if len(txs) == 1:
) return txs[0]
if result:
return result[0]
else:
return None, None, None, False
@defer.inlineCallbacks @defer.inlineCallbacks
def get_transactions(self, account, offset=0, limit=100): def get_transactions(self, account=None, txid=None, offset=0, limit=1000):
account_id = account.public_key.address
tx_where = ""
account_id = account.public_key.address if account is not None else None
if txid is not None:
tx_where = """
WHERE txid = :txid
"""
elif account is not None:
tx_where = """
WHERE txid IN (
SELECT txo.txid FROM txo
JOIN pubkey_address USING (address) WHERE pubkey_address.account = :account
UNION
SELECT txi.txid FROM txi
JOIN txo USING (txoid)
JOIN pubkey_address USING (address) WHERE pubkey_address.account = :account
)
"""
tx_rows = yield self.run_query( tx_rows = yield self.run_query(
""" """
SELECT txid, raw, height, position FROM tx WHERE txid IN ( SELECT txid, raw, height, position, is_verified FROM tx {}
SELECT txo.txid FROM txo ORDER BY height DESC, position DESC LIMIT :offset, :limit
JOIN pubkey_address USING (address) """.format(tx_where), {
WHERE pubkey_address.account = :account
UNION
SELECT txo.txid FROM txi
JOIN txo USING (txoid)
JOIN pubkey_address USING (address)
WHERE pubkey_address.account = :account
) ORDER BY height DESC, position DESC LIMIT :offset, :limit
""", {
'account': account_id, 'account': account_id,
'txid': txid,
'offset': min(offset, 0), 'offset': min(offset, 0),
'limit': max(limit, 100) 'limit': max(limit, 100)
} }
@ -293,8 +316,8 @@ class BaseDatabase(SQLiteMixin):
txids, txs = [], [] txids, txs = [], []
for row in tx_rows: for row in tx_rows:
txids.append(row[0]) txids.append(row[0])
txs.append(account.ledger.transaction_class( txs.append(self.ledger.transaction_class(
raw=row[1], height=row[2], position=row[3] raw=row[1], height=row[2], position=row[3], is_verified=row[4]
)) ))
txo_rows = yield self.run_query( txo_rows = yield self.run_query(
@ -311,31 +334,18 @@ class BaseDatabase(SQLiteMixin):
'is_my_account': row[2] == account_id 'is_my_account': row[2] == account_id
} }
referenced_txo_rows = yield self.run_query( referenced_txos = yield self.get_txos(
""" account=account,
SELECT txoid, txo.amount, txo.script, txo.txid, txo.position, chain, account txoid__in="SELECT txoid FROM txi WHERE txi.txid IN ({})".format(
FROM txi ','.join("'{}'".format(txid) for txid in txids)
JOIN txo USING (txoid)
JOIN pubkey_address USING (address)
WHERE txi.txid IN ({})
""".format(', '.join(['?']*len(txids))), txids
)
referenced_txos = {}
output_class = account.ledger.transaction_class.output_class
for row in referenced_txo_rows:
referenced_txos[row[0]] = output_class(
amount=row[1],
script=output_class.script_class(row[2]),
tx_ref=TXRefImmutable.from_id(row[3]),
position=row[4],
is_change=row[5] == 1,
is_my_account=row[6] == account_id
) )
)
referenced_txos_map = {txo.id: txo for txo in referenced_txos}
for tx in txs: for tx in txs:
for txi in tx.inputs: for txi in tx.inputs:
if txi.txo_ref.id in referenced_txos: if txi.txo_ref.id in referenced_txos_map:
txi.txo_ref = TXORefResolvable(referenced_txos[txi.txo_ref.id]) txi.txo_ref = TXORefResolvable(referenced_txos_map[txi.txo_ref.id])
for txo in tx.outputs: for txo in tx.outputs:
txo_meta = txos.get(txo.id) txo_meta = txos.get(txo.id)
if txo_meta is not None: if txo_meta is not None:
@ -347,6 +357,35 @@ class BaseDatabase(SQLiteMixin):
return txs return txs
@defer.inlineCallbacks
def get_txos(self, account=None, **constraints):
account_id = None
if account is not None:
account_id = account.public_key.address
constraints['account'] = account_id
rows = yield self.run_query(
"""
SELECT amount, script, txid, txo.position, chain, account
FROM txo JOIN pubkey_address USING (address)
"""+constraints_to_sql(constraints, prepend_sql='WHERE '), constraints
)
output_class = self.ledger.transaction_class.output_class
return [
output_class(
amount=row[0],
script=output_class.script_class(row[1]),
tx_ref=TXRefImmutable.from_id(row[2]),
position=row[3],
is_change=row[4] == 1,
is_my_account=row[5] == account_id
) for row in rows
]
def get_utxos(self, **constraints):
constraints['txoid__not_in'] = 'SELECT txoid FROM txi'
constraints['is_reserved'] = 0
return self.get_txos(**constraints)
def get_balance_for_account(self, account, include_reserved=False, **constraints): def get_balance_for_account(self, account, include_reserved=False, **constraints):
if not include_reserved: if not include_reserved:
constraints['is_reserved'] = 0 constraints['is_reserved'] = 0
@ -364,26 +403,6 @@ class BaseDatabase(SQLiteMixin):
"""+constraints_to_sql(constraints), values, 0 """+constraints_to_sql(constraints), values, 0
) )
@defer.inlineCallbacks
def get_utxos_for_account(self, account, **constraints):
constraints['account'] = account.public_key.address
utxos = yield self.run_query(
"""
SELECT amount, script, txid, txo.position
FROM txo JOIN pubkey_address ON pubkey_address.address=txo.address
WHERE account=:account AND txo.is_reserved=0 AND txoid NOT IN (SELECT txoid FROM txi)
"""+constraints_to_sql(constraints), constraints
)
output_class = account.ledger.transaction_class.output_class
return [
output_class(
values[0],
output_class.script_class(values[1]),
TXRefImmutable.from_id(values[2]),
position=values[3]
) for values in utxos
]
def add_keys(self, account, chain, keys): def add_keys(self, account, chain, keys):
sql = ( sql = (
"insert into pubkey_address " "insert into pubkey_address "

View file

@ -40,7 +40,7 @@ class LedgerRegistry(type):
return mcs.ledgers[ledger_id] return mcs.ledgers[ledger_id]
class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx', 'is_verified'))): class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx'))):
pass pass
@ -73,6 +73,7 @@ class BaseLedger(metaclass=LedgerRegistry):
self.db: BaseDatabase = self.config.get('db') or self.database_class( self.db: BaseDatabase = self.config.get('db') or self.database_class(
os.path.join(self.path, "blockchain.db") os.path.join(self.path, "blockchain.db")
) )
self.db.ledger = self
self.headers: BaseHeaders = self.config.get('headers') or self.headers_class( self.headers: BaseHeaders = self.config.get('headers') or self.headers_class(
os.path.join(self.path, "headers") os.path.join(self.path, "headers")
) )
@ -87,7 +88,7 @@ class BaseLedger(metaclass=LedgerRegistry):
self.on_transaction.listen( self.on_transaction.listen(
lambda e: log.info( lambda e: log.info(
'(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s', '(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s',
self.get_id(), e.address, e.tx.height, e.is_verified, e.tx.id self.get_id(), e.address, e.tx.height, e.tx.is_verified, e.tx.id
) )
) )
@ -214,7 +215,7 @@ class BaseLedger(metaclass=LedgerRegistry):
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = self.headers[height] header = self.headers[height]
tx.position = merkle['pos'] tx.position = merkle['pos']
return merkle_root == header['merkle_root'] tx.is_verified = merkle_root == header['merkle_root']
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
@ -366,32 +367,31 @@ class BaseLedger(metaclass=LedgerRegistry):
try: try:
# see if we have a local copy of transaction, otherwise fetch it from server # see if we have a local copy of transaction, otherwise fetch it from server
raw, _, position, is_verified = yield self.db.get_transaction(hex_id) tx = yield self.db.get_transaction(hex_id)
save_tx = None save_tx = None
if raw is None: if tx is None:
_raw = yield self.network.get_transaction(hex_id) _raw = yield self.network.get_transaction(hex_id)
tx = self.transaction_class(unhexlify(_raw), height=remote_height) tx = self.transaction_class(unhexlify(_raw))
save_tx = 'insert' save_tx = 'insert'
else:
tx = self.transaction_class(raw, height=remote_height)
if remote_height > 0 and (not is_verified or position is None): tx.height = remote_height
is_verified = yield self.validate_transaction_and_set_position(tx, remote_height)
is_verified = 1 if is_verified else 0 if remote_height > 0 and (not tx.is_verified or tx.position == -1):
yield self.validate_transaction_and_set_position(tx, remote_height)
if save_tx is None: if save_tx is None:
save_tx = 'update' save_tx = 'update'
yield self.db.save_transaction_io( yield self.db.save_transaction_io(
save_tx, tx, is_verified, address, self.address_to_hash160(address), save_tx, tx, address, self.address_to_hash160(address),
''.join('{}:{}:'.format(tx_id, tx_height) for tx_id, tx_height in synced_history) ''.join('{}:{}:'.format(tx_id, tx_height) for tx_id, tx_height in synced_history)
) )
log.debug( log.debug(
"%s: sync'ed tx %s for address: %s, height: %s, verified: %s", "%s: sync'ed tx %s for address: %s, height: %s, verified: %s",
self.get_id(), hex_id, address, remote_height, is_verified self.get_id(), hex_id, address, tx.height, tx.is_verified
) )
self._on_transaction_controller.add(TransactionEvent(address, tx, is_verified)) self._on_transaction_controller.add(TransactionEvent(address, tx))
except Exception: except Exception:
log.exception('Failed to synchronize transaction:') log.exception('Failed to synchronize transaction:')

View file

@ -141,7 +141,7 @@ class BaseInput(InputOutput):
def is_my_account(self) -> Optional[bool]: def is_my_account(self) -> Optional[bool]:
""" True if the output this input spends is yours. """ """ True if the output this input spends is yours. """
if self.txo_ref.txo is None: if self.txo_ref.txo is None:
raise ValueError('Cannot resolve output to determine ownership.') return False
return self.txo_ref.txo.is_my_account return self.txo_ref.txo.is_my_account
@classmethod @classmethod
@ -237,7 +237,7 @@ class BaseTransaction:
input_class = BaseInput input_class = BaseInput
output_class = BaseOutput output_class = BaseOutput
def __init__(self, raw=None, version: int = 1, locktime: int = 0, def __init__(self, raw=None, version: int = 1, locktime: int = 0, is_verified: bool = False,
height: int = -1, position: int = -1) -> None: height: int = -1, position: int = -1) -> None:
self._raw = raw self._raw = raw
self.ref = TXRefMutable(self) self.ref = TXRefMutable(self)
@ -245,6 +245,7 @@ class BaseTransaction:
self.locktime = locktime self.locktime = locktime
self._inputs: List[BaseInput] = [] self._inputs: List[BaseInput] = []
self._outputs: List[BaseOutput] = [] self._outputs: List[BaseOutput] = []
self.is_verified = is_verified
self.height = height self.height = height
self.position = position self.position = position
if raw is not None: if raw is not None: