import logging
from typing import Tuple, List, Sequence

import sqlite3
from twisted.internet import defer
from twisted.enterprise import adbapi

from torba.hash import TXRefImmutable
from torba.basetransaction import BaseTransaction
from torba.baseaccount import BaseAccount

log = logging.getLogger(__name__)

# Key suffixes that translate one-to-one into a binary SQL comparison.
# The suffixes are mutually exclusive under ``str.endswith`` so order is
# irrelevant ('__not_in' / '__in' / '__any' are handled separately).
_COMPARISON_SUFFIXES = (
    ('__not', '!='),
    ('__lt', '<'),
    ('__lte', '<='),
    ('__gt', '>'),
    ('__gte', '>='),
    ('__like', 'LIKE'),
)


def clean_arg_name(arg):
    """Return *arg* with every '.' replaced by '_'.

    Used to turn a (possibly table-qualified) constraint key into a name
    usable as a ``:named`` SQL parameter.
    """
    return arg.replace('.', '_')


def prepare_constraints(constraints):
    """Normalise the 'account' entry of *constraints* in place.

    If ``constraints['account']`` is a ``BaseAccount`` it is replaced by
    ``account.public_key.address``.

    Returns the value popped from ``constraints['my_account']``; when that
    key is absent, the (already normalised) ``constraints.get('account')``
    is returned instead.
    """
    if 'account' in constraints:
        if isinstance(constraints['account'], BaseAccount):
            constraints['account'] = constraints['account'].public_key.address
    return constraints.pop('my_account', constraints.get('account'))


def constraints_to_sql(constraints, joiner=' AND ', prepend_sql=' AND ', prepend_key=''):
    """Build a SQL condition string from a dict of keyword constraints.

    *constraints* is MUTATED: every original key is popped and replaced by
    the parameter name(s) the generated SQL refers to, so that the same
    dict can afterwards be handed to the database as named parameters.

    Supported key suffixes:

    * ``__not``, ``__lt``, ``__lte``, ``__gt``, ``__gte``, ``__like`` --
      binary comparison against a single bound value;
    * ``__in`` / ``__not_in`` -- value is a list/set (each item is bound
      as its own parameter) or a string (inserted verbatim as a subquery);
    * ``__any`` -- value is a dict of constraints that are OR-ed together
      inside parentheses;
    * no suffix -- equality.

    :param joiner: text placed between the individual conditions.
    :param prepend_sql: text placed in front of the result when at least
        one condition was produced.
    :param prepend_key: prefix added to parameter names *in the SQL text
        only*; the caller (the ``__any`` branch) is responsible for storing
        the values under the prefixed names.
    :return: the SQL fragment, or ``''`` if there is nothing to filter on.
    :raises ValueError: if an ``__in`` / ``__not_in`` value is neither a
        list, a set nor a string.
    """
    if not constraints:
        return ''
    extras = []
    # Iterate over a snapshot: the dict is modified inside the loop.
    for key in list(constraints):
        constraint = constraints.pop(key)

        if key.endswith(('__in', '__not_in')):
            if key.endswith('__not_in'):
                col, op = key[:-len('__not_in')], 'NOT IN'
            else:
                col, op = key[:-len('__in')], 'IN'
            if isinstance(constraint, (list, set)):
                arg = clean_arg_name(col)
                placeholders = []
                for item_no, item in enumerate(constraint, 1):
                    # Stored un-prefixed; an enclosing '__any' re-stores it
                    # under the prefixed name that the placeholder uses.
                    constraints['{}_{}'.format(arg, item_no)] = item
                    placeholders.append(':{}{}_{}'.format(prepend_key, arg, item_no))
                items = ', '.join(placeholders)
            elif isinstance(constraint, str):
                # A string is treated as raw SQL (typically a subquery).
                items = constraint
            else:
                raise ValueError("{} requires a list or string as constraint value.".format(key))
            extras.append('{} {} ({})'.format(col, op, items))
            continue

        if key.endswith('__any'):
            group = clean_arg_name(key)
            # The nested call rewrites `constraint` in place, leaving the
            # un-prefixed parameter names behind; the placeholders it emits
            # carry the full prefix (outer prefix + this group's name).
            extras.append('({})'.format(
                constraints_to_sql(constraint, ' OR ', '', prepend_key + group + '_')
            ))
            for subkey, val in constraint.items():
                constraints['{}_{}'.format(group, clean_arg_name(subkey))] = val
            continue

        col, op = key, '='
        for suffix, operator in _COMPARISON_SUFFIXES:
            if key.endswith(suffix):
                col, op = key[:-len(suffix)], operator
                break
        arg = clean_arg_name(key)
        constraints[arg] = constraint
        extras.append('{} {} :{}'.format(col, op, prepend_key + arg))
    return prepend_sql + joiner.join(extras) if extras else ''


class SQLiteMixin:
    """Thin helper layer over a ``twisted.enterprise.adbapi`` SQLite pool."""

    # SQL script passed to ``executescript()`` by ``open()``. Subclasses are
    # expected to override it with a script string (``BaseDatabase`` does).
    CREATE_TABLES_QUERY: Sequence[str] = ()

    def __init__(self, path):
        self._db_path = path
        # Created by open(); None until then.
        self.db: adbapi.ConnectionPool = None
        self.ledger = None

    def open(self):
        """Create the connection pool and run ``CREATE_TABLES_QUERY``.

        Returns the Deferred of the table-creation interaction.
        """
        log.info("connecting to database: %s", self._db_path)
        # The pool is limited to exactly one connection (cp_min=cp_max=1).
        self.db = adbapi.ConnectionPool(
            'sqlite3', self._db_path, cp_min=1, cp_max=1, check_same_thread=False
        )
        return self.db.runInteraction(
            lambda t: t.executescript(self.CREATE_TABLES_QUERY)
        )

    def close(self):
        """Close the connection pool; returns an already-fired Deferred."""
        self.db.close()
        return defer.succeed(True)

    @staticmethod
    def _insert_sql(table: str, data: dict) -> Tuple[str, List]:
        """Return ``(sql, values)`` for inserting *data* (column -> value)
        into *table* using positional ``?`` placeholders."""
        columns, values = [], []
        for column, value in data.items():
            columns.append(column)
            values.append(value)
        sql = "INSERT INTO {} ({}) VALUES ({})".format(
            table, ', '.join(columns), ', '.join(['?'] * len(values))
        )
        return sql, values

    @staticmethod
    def _update_sql(table: str, data: dict, where: str, constraints: list) -> Tuple[str, list]:
        """Return ``(sql, values)`` for updating *table*.

        *data* maps column -> new value, *where* is the raw WHERE clause
        (with ``?`` placeholders) and *constraints* holds the values for
        those placeholders; they are appended after the SET values.
        """
        columns, values = [], []
        for column, value in data.items():
            columns.append("{} = ?".format(column))
            values.append(value)
        values.extend(constraints)
        sql = "UPDATE {} SET {} WHERE {}".format(
            table, ', '.join(columns), where
        )
        return sql, values

    @defer.inlineCallbacks
    def query_one_value(self, query, params=None, default=None):
        """Run *query* and return the first column of the first row.

        *default* is returned when there are no rows and also when the
        value is falsy (``None``, ``0``, ``''`` ...).
        """
        result = yield self.run_query(query, params)
        if result:
            return result[0][0] or default
        return default

    @defer.inlineCallbacks
    def query_dict_value_list(self, query, fields, params=None):
        """Run *query* and return the rows as a list of dicts.

        *query* must contain one ``{}`` which is filled with the
        comma-joined *fields*; each row is zipped with *fields*.
        Returns ``[]`` when nothing matches.
        """
        result = yield self.run_query(query.format(', '.join(fields)), params)
        if result:
            return [dict(zip(fields, r)) for r in result]
        return []

    @defer.inlineCallbacks
    def query_dict_value(self, query, fields, params=None, default=None):
        """Like ``query_dict_value_list`` but return only the first row,
        or *default* when there is none."""
        result = yield self.query_dict_value_list(query, fields, params)
        if result:
            return result[0]
        return default

    @staticmethod
    def execute(t, sql, values):
        """Log and execute *sql* on the transaction/cursor *t*."""
        log.debug(sql)
        log.debug(values)
        return t.execute(sql, values)

    def run_operation(self, sql, values):
        """Log and run a statement whose result is not needed."""
        log.debug(sql)
        log.debug(values)
        # sqlite3 rejects None as the parameter container; use an empty one.
        return self.db.runOperation(sql, () if values is None else values)

    def run_query(self, sql, values):
        """Log and run a query; the Deferred fires with the fetched rows."""
        log.debug(sql)
        log.debug(values)
        # The query_* helpers default `params` to None, which sqlite3 does
        # not accept; substitute an empty parameter tuple.
        return self.db.runQuery(sql, () if values is None else values)


class BaseDatabase(SQLiteMixin):
    """Wallet storage: addresses, transactions and their inputs/outputs."""

    CREATE_PUBKEY_ADDRESS_TABLE = """
        create table if not exists pubkey_address (
            address text primary key,
            account text not null,
            chain integer not null,
            position integer not null,
            pubkey blob not null,
            history text,
            used_times integer not null default 0
        );
    """

    CREATE_TX_TABLE = """
        create table if not exists tx (
            txid text primary key,
            raw blob not null,
            height integer not null,
            position integer not null,
            is_verified boolean not null default 0
        );
    """

    CREATE_TXO_TABLE = """
        create table if not exists txo (
            txid text references tx,
            txoid text primary key,
            address text references pubkey_address,
            position integer not null,
            amount integer not null,
            script blob not null,
            is_reserved boolean not null default 0
        );
    """

    CREATE_TXI_TABLE = """
        create table if not exists txi (
            txid text references tx,
            txoid text references txo,
            address text references pubkey_address
        );
    """

    CREATE_TABLES_QUERY = (
        CREATE_TX_TABLE +
        CREATE_PUBKEY_ADDRESS_TABLE +
        CREATE_TXO_TABLE +
        CREATE_TXI_TABLE
    )

    @staticmethod
    def txo_to_row(tx, address, txo):
        """Return the ``txo`` table row (as a dict) for output *txo* of *tx*."""
        return {
            'txid': tx.id,
            'txoid': txo.id,
            'address': address,
            'position': txo.position,
            'amount': txo.amount,
            'script': sqlite3.Binary(txo.script.source)
        }

    def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history):
        """Store *tx* and the inputs/outputs that belong to *address*.

        All steps run inside a single ``runInteraction``.

        :param save_tx: ``'insert'`` to add the tx row, ``'update'`` to
            refresh height/position/is_verified; any other value leaves
            the ``tx`` table untouched.
        :param address: address whose outputs/inputs are being recorded.
        :param txhash: compared with the ``pubkey_hash`` value of each
            pay-to-pubkey-hash output script to decide whether that output
            belongs to *address*.
        :param history: new history string stored for *address*.
        :return: the Deferred of the interaction.
        """

        def _steps(t):
            if save_tx == 'insert':
                self.execute(t, *self._insert_sql('tx', {
                    'txid': tx.id,
                    'raw': sqlite3.Binary(tx.raw),
                    'height': tx.height,
                    'position': tx.position,
                    'is_verified': tx.is_verified
                }))
            elif save_tx == 'update':
                self.execute(t, *self._update_sql("tx", {
                    'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
                }, 'txid = ?', (tx.id,)))

            # positions of this transaction's outputs that are already stored
            existing_txos = {r[0] for r in self.execute(
                t, "SELECT position FROM txo WHERE txid = ?", (tx.id,)
            ).fetchall()}

            for txo in tx.outputs:
                if txo.position in existing_txos:
                    continue
                if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
                    self.execute(t, *self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
                elif txo.script.is_pay_script_hash:
                    # TODO: implement script hash payments
                    log.warning('Database.save_transaction_io: pay script hash is not implemented!')

            # lookup the address associated with each TXI (via its TXO)
            txoids = [txi.txo_ref.id for txi in tx.inputs]
            txoid_place_holders = ','.join(['?']*len(txoids))
            txoid_to_address = {r[0]: r[1] for r in self.execute(
                t, "SELECT txoid, address FROM txo WHERE txoid in ({})".format(txoid_place_holders), txoids
            ).fetchall()}

            # TXIs of this transaction that have already been added
            existing_txis = {r[0] for r in self.execute(
                t, "SELECT txoid FROM txi WHERE txid = ?", (tx.id,)
            ).fetchall()}

            for txi in tx.inputs:
                txoid = txi.txo_ref.id
                new_txi = txoid not in existing_txis
                address_matches = txoid_to_address.get(txoid) == address
                if new_txi and address_matches:
                    self.execute(t, *self._insert_sql("txi", {
                        'txid': tx.id,
                        'txoid': txoid,
                        'address': address,
                    }))

            self._set_address_history(t, address, history)

        return self.db.runInteraction(_steps)

    def reserve_outputs(self, txos, is_reserved=True):
        """Set the ``is_reserved`` flag of every output in *txos*."""
        txoids = [txo.id for txo in txos]
        return self.run_operation(
            "UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format(
                ', '.join(['?']*len(txoids))
            ), [is_reserved]+txoids
        )

    def release_outputs(self, txos):
        """Clear the ``is_reserved`` flag of every output in *txos*."""
        return self.reserve_outputs(txos, is_reserved=False)

    def rewind_blockchain(self, above_height):  # pylint: disable=no-self-use
        """Not implemented yet; currently a no-op that succeeds immediately."""
        # TODO:
        # 1. delete transactions above_height
        # 2. update address histories removing deleted TXs
        return defer.succeed(True)

    @defer.inlineCallbacks
    def get_transaction(self, txid, account=None):
        """Return the transaction with *txid*, or ``None`` unless exactly
        one transaction matches."""
        txs = yield self.get_transactions(account=account, txid=txid)
        if len(txs) == 1:
            return txs[0]

    @defer.inlineCallbacks
    def get_transactions(self, offset=0, limit=1000000, **constraints):
        """Return transactions matching *constraints*, newest first.

        Rows are ordered by ``height DESC, position DESC``. When an
        ``account`` constraint is given (and no explicit ``txid``), only
        transactions with an output or input on one of that account's
        addresses are returned. Inputs and outputs of the returned
        transactions are annotated from the ``txo`` table where a matching
        row exists.

        :param offset: rows to skip (negative values are treated as 0).
        :param limit: maximum rows to return (values below 1 become 1).
        """
        my_account = prepare_constraints(constraints)
        account = constraints.pop('account', None)

        if 'txid' not in constraints and account is not None:
            constraints['txid__in'] = """
                SELECT txo.txid FROM txo
                JOIN pubkey_address USING (address) WHERE pubkey_address.account = :account
              UNION
                SELECT txi.txid FROM txi
                JOIN txo USING (txoid)
                JOIN pubkey_address USING (address) WHERE pubkey_address.account = :account
            """

        where = constraints_to_sql(constraints, prepend_sql='WHERE ')

        tx_rows = yield self.run_query(
            """
            SELECT txid, raw, height, position, is_verified FROM tx {}
            ORDER BY height DESC, position DESC LIMIT :offset, :limit
            """.format(where), {
                **constraints,
                'account': account,
                'offset': max(offset, 0),
                'limit': max(limit, 1)
            }
        )
        if not tx_rows:
            # Nothing to annotate; skip the two follow-up queries.
            return []

        txids, txs = [], []
        for row in tx_rows:
            txids.append(row[0])
            txs.append(self.ledger.transaction_class(
                raw=row[1], height=row[2], position=row[3], is_verified=row[4]
            ))

        # outputs created by the selected transactions
        own_txos = yield self.get_txos(my_account=my_account, txid__in=txids)
        annotated_txos = {txo.id: txo for txo in own_txos}

        # outputs spent by the selected transactions
        # NOTE(review): the txids are interpolated into the SQL text rather
        # than bound as parameters. They are read from the `tx` table just
        # above, not taken from the caller, but binding would be safer.
        spent_txos = yield self.get_txos(
            my_account=my_account,
            txoid__in="SELECT txoid FROM txi WHERE txi.txid IN ({})".format(
                ','.join("'{}'".format(txid) for txid in txids)
            )
        )
        referenced_txos = {txo.id: txo for txo in spent_txos}

        for tx in txs:
            for txi in tx.inputs:
                txo = referenced_txos.get(txi.txo_ref.id)
                if txo:
                    txi.txo_ref = txo.ref
            for txo in tx.outputs:
                _txo = annotated_txos.get(txo.id)
                if _txo:
                    txo.update_annotations(_txo)

        return txs

    @defer.inlineCallbacks
    def get_txos(self, **constraints):
        """Return output objects for ``txo`` rows matching *constraints*.

        Each output is built with ``is_change`` set when the address chain
        equals 1 and ``is_my_account`` set when the address's account
        equals the value returned by ``prepare_constraints``.
        """
        my_account = prepare_constraints(constraints)
        rows = yield self.run_query(
            """
            SELECT amount, script, txid, txo.position, chain, account
            FROM txo JOIN pubkey_address USING (address)
            """+constraints_to_sql(constraints, prepend_sql='WHERE '), constraints
        )
        output_class = self.ledger.transaction_class.output_class
        return [
            output_class(
                amount=row[0],
                script=output_class.script_class(row[1]),
                tx_ref=TXRefImmutable.from_id(row[2]),
                position=row[3],
                is_change=row[4] == 1,
                is_my_account=row[5] == my_account
            ) for row in rows
        ]

    def get_utxos(self, **constraints):
        """Like ``get_txos`` but limited to outputs that are not referenced
        by any ``txi`` row and are not reserved."""
        constraints['txoid__not_in'] = 'SELECT txoid FROM txi'
        constraints['is_reserved'] = False
        return self.get_txos(**constraints)

    def get_balance_for_account(self, account, include_reserved=False, **constraints):
        """Return ``SUM(amount)`` of *account*'s outputs that no ``txi`` row
        references (0 when there are none).

        Reserved outputs are excluded unless *include_reserved* is true;
        extra *constraints* are AND-ed onto the query.
        """
        if not include_reserved:
            constraints['is_reserved'] = False
        # Build the SQL first: constraints_to_sql() rewrites `constraints`
        # into the exact parameter names the SQL refers to, and only that
        # rewritten form is valid to bind.
        extra_where = constraints_to_sql(constraints)
        values = {'account': account.public_key.address}
        values.update(constraints)
        return self.query_one_value(
            """
            SELECT SUM(amount)
            FROM txo
                JOIN tx ON tx.txid=txo.txid
                JOIN pubkey_address ON pubkey_address.address=txo.address
            WHERE
              pubkey_address.account=:account AND
              txoid NOT IN (SELECT txoid FROM txi)
            """+extra_where, values, 0
        )

    def add_keys(self, account, chain, keys):
        """Insert one ``pubkey_address`` row per ``(position, pubkey)`` in *keys*.

        With no keys nothing is executed and an already-fired Deferred is
        returned (an INSERT without any value tuples is not valid SQL).
        """
        if not keys:
            return defer.succeed(None)
        sql = (
            "insert into pubkey_address "
            "(address, account, chain, position, pubkey) "
            "values "
        ) + ', '.join(['(?, ?, ?, ?, ?)'] * len(keys))
        values = []
        for position, pubkey in keys:
            values.append(pubkey.address)
            values.append(account.public_key.address)
            values.append(chain)
            values.append(position)
            values.append(sqlite3.Binary(pubkey.pubkey_bytes))
        return self.run_operation(sql, values)

    @classmethod
    def _set_address_history(cls, t, address, history):
        """Store *history* for *address* using transaction *t*.

        ``used_times`` is set to half the number of ':' in *history*
        (presumably each history entry contributes two colons -- confirm
        against the code producing the history string).
        """
        cls.execute(
            t, "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
            (history, history.count(':')//2, address)
        )

    def set_address_history(self, address, history):
        """Store *history* for *address* in its own interaction."""
        return self.db.runInteraction(lambda t: self._set_address_history(t, address, history))

    def get_addresses(self, account, chain, limit=None, max_used_times=None, order_by=None):
        """Return address rows as dicts, optionally filtered.

        Columns used as an equality filter (*account*, *chain*) are left
        out of the returned dicts.

        NOTE: *order_by* and *limit* are inserted into the SQL text as-is;
        callers must not pass untrusted values.
        """
        columns = ['account', 'chain', 'position', 'address', 'used_times']
        sql = ["SELECT {} FROM pubkey_address"]

        where = []
        params = {}
        if account is not None:
            params["account"] = account.public_key.address
            where.append("account = :account")
            columns.remove("account")
        if chain is not None:
            params["chain"] = chain
            where.append("chain = :chain")
            columns.remove("chain")
        if max_used_times is not None:
            params["used_times"] = max_used_times
            where.append("used_times <= :used_times")

        if where:
            sql.append("WHERE")
            sql.append(" AND ".join(where))

        if order_by:
            sql.append("ORDER BY {}".format(order_by))

        if limit is not None:
            sql.append("LIMIT {}".format(limit))

        return self.query_dict_value_list(" ".join(sql), columns, params)

    def get_address(self, address):
        """Return the full ``pubkey_address`` row for *address* as a dict,
        or ``None`` if it is unknown."""
        return self.query_dict_value(
            "SELECT {} FROM pubkey_address WHERE address = :address",
            ('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'),
            {'address': address}
        )