import logging import asyncio from asyncio import wrap_future from concurrent.futures.thread import ThreadPoolExecutor from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable import sqlite3 from torba.client.hash import TXRefImmutable from torba.client.basetransaction import BaseTransaction from torba.client.baseaccount import BaseAccount log = logging.getLogger(__name__) class AIOSQLite: def __init__(self): # has to be single threaded as there is no mapping of thread:connection self.executor = ThreadPoolExecutor(max_workers=1) self.connection: sqlite3.Connection = None @classmethod async def connect(cls, path: Union[bytes, str], *args, **kwargs): db = cls() db.connection = await wrap_future(db.executor.submit(sqlite3.connect, path, *args, **kwargs)) return db async def close(self): def __close(conn): self.executor.submit(conn.close) self.executor.shutdown(wait=True) conn = self.connection self.connection = None return asyncio.get_event_loop_policy().get_event_loop().call_later(0.01, __close, conn) def executescript(self, script: str) -> Awaitable: return wrap_future(self.executor.submit(self.connection.executescript, script)) def execute_fetchall(self, sql: str, parameters: Iterable = None) -> Awaitable[Iterable[sqlite3.Row]]: parameters = parameters if parameters is not None else [] def __fetchall(conn: sqlite3.Connection, *args, **kwargs): return conn.execute(*args, **kwargs).fetchall() return wrap_future(self.executor.submit(__fetchall, self.connection, sql, parameters)) def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]: parameters = parameters if parameters is not None else [] return self.run(lambda conn, sql, parameters: conn.execute(sql, parameters), sql, parameters) def run(self, fun, *args, **kwargs) -> Awaitable: return wrap_future(self.executor.submit(self.__run_transaction, fun, *args, **kwargs)) def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs): self.connection.execute('begin') try: fun(self.connection, *args, **kwargs) # type: ignore except (Exception, OSError): self.connection.rollback() raise else: self.connection.commit() def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''): sql, values = [], {} for key, constraint in constraints.items(): col, op, key = key, '=', key.replace('.', '_') if key.startswith('$'): values[key] = constraint continue elif key.endswith('__not'): col, op = col[:-len('__not')], '!=' elif key.endswith('__lt'): col, op = col[:-len('__lt')], '<' elif key.endswith('__lte'): col, op = col[:-len('__lte')], '<=' elif key.endswith('__gt'): col, op = col[:-len('__gt')], '>' elif key.endswith('__like'): col, op = col[:-len('__like')], 'LIKE' elif key.endswith('__in') or key.endswith('__not_in'): if key.endswith('__in'): col, op = col[:-len('__in')], 'IN' else: col, op = col[:-len('__not_in')], 'NOT IN' if isinstance(constraint, (list, set)): items = ', '.join( "'{}'".format(item) if isinstance(item, str) else str(item) for item in constraint ) elif isinstance(constraint, str): items = constraint else: raise ValueError("{} requires a list, set or string as constraint value.".format(col)) sql.append('{} {} ({})'.format(col, op, items)) continue elif key.endswith('__any'): where, subvalues = constraints_to_sql(constraint, ' OR ', key+'_') sql.append('({})'.format(where)) values.update(subvalues) continue sql.append('{} {} :{}'.format(col, op, prepend_key+key)) values[prepend_key+key] = constraint return joiner.join(sql) if sql else '', values def query(select, **constraints): sql = [select] limit = constraints.pop('limit', None) offset = constraints.pop('offset', None) order_by = constraints.pop('order_by', None) constraints.pop('my_account', None) account = constraints.pop('account', None) if account is not None: if not isinstance(account, list): account = [account] constraints['account__in'] = [ (a.public_key.address if isinstance(a, BaseAccount) else a) for a in account ] where, values = constraints_to_sql(constraints) if where: sql.append('WHERE') sql.append(where) if order_by is not None: sql.append('ORDER BY') if isinstance(order_by, str): sql.append(order_by) elif isinstance(order_by, list): sql.append(', '.join(order_by)) else: raise ValueError("order_by must be string or list") if limit is not None: sql.append('LIMIT {}'.format(limit)) if offset is not None: sql.append('OFFSET {}'.format(offset)) return ' '.join(sql), values def rows_to_dict(rows, fields): if rows: return [dict(zip(fields, r)) for r in rows] else: return [] class SQLiteMixin: CREATE_TABLES_QUERY: str def __init__(self, path): self._db_path = path self.db: AIOSQLite = None self.ledger = None async def open(self): log.info("connecting to database: %s", self._db_path) self.db = await AIOSQLite.connect(self._db_path) await self.db.executescript(self.CREATE_TABLES_QUERY) async def close(self): await self.db.close() @staticmethod def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False) -> Tuple[str, List]: columns, values = [], [] for column, value in data.items(): columns.append(column) values.append(value) or_ignore = "" if ignore_duplicate: or_ignore = " OR IGNORE" sql = "INSERT{} INTO {} ({}) VALUES ({})".format( or_ignore, table, ', '.join(columns), ', '.join(['?'] * len(values)) ) return sql, values @staticmethod def _update_sql(table: str, data: dict, where: str, constraints: Union[list, tuple]) -> Tuple[str, list]: columns, values = [], [] for column, value in data.items(): columns.append("{} = ?".format(column)) values.append(value) values.extend(constraints) sql = "UPDATE {} SET {} WHERE {}".format( table, ', '.join(columns), where ) return sql, values class BaseDatabase(SQLiteMixin): PRAGMAS = """ pragma journal_mode=WAL; """ CREATE_PUBKEY_ADDRESS_TABLE = """ create table if not exists pubkey_address ( address text primary key, account text not null, chain integer not null, position integer not null, pubkey blob not null, history text, used_times integer not null default 0 ); """ CREATE_PUBKEY_ADDRESS_INDEX = """ create index if not exists pubkey_address_account_idx on pubkey_address (account); """ CREATE_TX_TABLE = """ create table if not exists tx ( txid text primary key, raw blob not null, height integer not null, position integer not null, is_verified boolean not null default 0 ); """ CREATE_TXO_TABLE = """ create table if not exists txo ( txid text references tx, txoid text primary key, address text references pubkey_address, position integer not null, amount integer not null, script blob not null, is_reserved boolean not null default 0 ); """ CREATE_TXO_INDEX = """ create index if not exists txo_address_idx on txo (address); """ CREATE_TXI_TABLE = """ create table if not exists txi ( txid text references tx, txoid text references txo, address text references pubkey_address ); """ CREATE_TXI_INDEX = """ create index if not exists txi_address_idx on txi (address); create index if not exists txi_txoid_idx on txi (txoid); """ CREATE_TABLES_QUERY = ( PRAGMAS + CREATE_TX_TABLE + CREATE_PUBKEY_ADDRESS_TABLE + CREATE_PUBKEY_ADDRESS_INDEX + CREATE_TXO_TABLE + CREATE_TXO_INDEX + CREATE_TXI_TABLE + CREATE_TXI_INDEX ) @staticmethod def txo_to_row(tx, address, txo): return { 'txid': tx.id, 'txoid': txo.id, 'address': address, 'position': txo.position, 'amount': txo.amount, 'script': sqlite3.Binary(txo.script.source) } async def insert_transaction(self, tx): await self.db.execute(*self._insert_sql('tx', { 'txid': tx.id, 'raw': sqlite3.Binary(tx.raw), 'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified })) async def update_transaction(self, tx): await self.db.execute(*self._update_sql("tx", { 'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified }, 'txid = ?', (tx.id,))) def save_transaction_io(self, tx: BaseTransaction, address, txhash, history): def _transaction(conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history): for txo in tx.outputs: if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash: conn.execute(*self._insert_sql( "txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True )) elif txo.script.is_pay_script_hash: # TODO: implement script hash payments log.warning('Database.save_transaction_io: pay script hash is not implemented!') for txi in tx.inputs: if txi.txo_ref.txo is not None: txo = txi.txo_ref.txo if txo.get_address(self.ledger) == address: conn.execute(*self._insert_sql("txi", { 'txid': tx.id, 'txoid': txo.id, 'address': address, }, ignore_duplicate=True)) conn.execute( "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", (history, history.count(':')//2, address) ) return self.db.run(_transaction, tx, address, txhash, history) async def reserve_outputs(self, txos, is_reserved=True): txoids = [txo.id for txo in txos] await self.db.execute( "UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format( ', '.join(['?']*len(txoids)) ), [is_reserved]+txoids ) async def release_outputs(self, txos): await self.reserve_outputs(txos, is_reserved=False) async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use # TODO: # 1. delete transactions above_height # 2. update address histories removing deleted TXs return True async def select_transactions(self, cols, account=None, **constraints): if 'txid' not in constraints and account is not None: constraints['$account'] = account.public_key.address constraints['txid__in'] = """ SELECT txo.txid FROM txo JOIN pubkey_address USING (address) WHERE pubkey_address.account = :$account UNION SELECT txi.txid FROM txi JOIN pubkey_address USING (address) WHERE pubkey_address.account = :$account """ return await self.db.execute_fetchall( *query("SELECT {} FROM tx".format(cols), **constraints) ) async def get_transactions(self, my_account=None, **constraints): my_account = my_account or constraints.get('account', None) tx_rows = await self.select_transactions( 'txid, raw, height, position, is_verified', order_by=["height DESC", "position DESC"], **constraints ) if not tx_rows: return [] txids, txs = [], [] for row in tx_rows: txids.append(row[0]) txs.append(self.ledger.transaction_class( raw=row[1], height=row[2], position=row[3], is_verified=bool(row[4]) )) annotated_txos = { txo.id: txo for txo in (await self.get_txos( my_account=my_account, txid__in=txids )) } referenced_txos = { txo.id: txo for txo in (await self.get_txos( my_account=my_account, txoid__in=query("SELECT txoid FROM txi", **{'txid__in': txids})[0] )) } for tx in txs: for txi in tx.inputs: txo = referenced_txos.get(txi.txo_ref.id) if txo: txi.txo_ref = txo.ref for txo in tx.outputs: _txo = annotated_txos.get(txo.id) if _txo: txo.update_annotations(_txo) else: txo.update_annotations(None) return txs async def get_transaction_count(self, **constraints): constraints.pop('offset', None) constraints.pop('limit', None) constraints.pop('order_by', None) count = await self.select_transactions('count(*)', **constraints) return count[0][0] async def get_transaction(self, **constraints): txs = await self.get_transactions(limit=1, **constraints) if txs: return txs[0] async def select_txos(self, cols, **constraints): return await self.db.execute_fetchall(*query( "SELECT {} FROM txo" " JOIN pubkey_address USING (address)" " JOIN tx USING (txid)".format(cols), **constraints )) async def get_txos(self, my_account=None, **constraints): my_account = my_account or constraints.get('account', None) if isinstance(my_account, BaseAccount): my_account = my_account.public_key.address rows = await self.select_txos( "amount, script, txid, tx.height, txo.position, chain, account", **constraints ) output_class = self.ledger.transaction_class.output_class return [ output_class( amount=row[0], script=output_class.script_class(row[1]), tx_ref=TXRefImmutable.from_id(row[2], row[3]), position=row[4], is_change=row[5] == 1, is_my_account=row[6] == my_account ) for row in rows ] async def get_txo_count(self, **constraints): constraints.pop('offset', None) constraints.pop('limit', None) constraints.pop('order_by', None) count = await self.select_txos('count(*)', **constraints) return count[0][0] @staticmethod def constrain_utxo(constraints): constraints['is_reserved'] = False constraints['txoid__not_in'] = "SELECT txoid FROM txi" def get_utxos(self, **constraints): self.constrain_utxo(constraints) return self.get_txos(**constraints) def get_utxo_count(self, **constraints): self.constrain_utxo(constraints) return self.get_txo_count(**constraints) async def get_balance(self, **constraints): self.constrain_utxo(constraints) balance = await self.select_txos('SUM(amount)', **constraints) return balance[0][0] or 0 async def select_addresses(self, cols, **constraints): return await self.db.execute_fetchall(*query( "SELECT {} FROM pubkey_address".format(cols), **constraints )) async def get_addresses(self, cols=('address', 'account', 'chain', 'position', 'used_times'), **constraints): addresses = await self.select_addresses(', '.join(cols), **constraints) return rows_to_dict(addresses, cols) async def get_address_count(self, **constraints): count = await self.select_addresses('count(*)', **constraints) return count[0][0] async def get_address(self, **constraints): addresses = await self.get_addresses( cols=('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'), limit=1, **constraints ) if addresses: return addresses[0] async def add_keys(self, account, chain, keys): sql = ( "insert into pubkey_address " "(address, account, chain, position, pubkey) " "values " ) + ', '.join(['(?, ?, ?, ?, ?)'] * len(keys)) values = [] for position, pubkey in keys: values.extend(( pubkey.address, account.public_key.address, chain, position, sqlite3.Binary(pubkey.pubkey_bytes) )) await self.db.execute(sql, values) async def _set_address_history(self, address, history): await self.db.execute( "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", (history, history.count(':')//2, address) ) async def set_address_history(self, address, history): await self._set_address_history(address, history)