diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index fab46fcf4..99710726b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -36,6 +36,9 @@ jobs: - datanetwork - blockchain - other + db: + - postgres + - sqlite steps: - uses: actions/checkout@v1 - uses: actions/setup-python@v1 @@ -44,7 +47,9 @@ jobs: - if: matrix.test == 'other' run: sudo apt install -y --no-install-recommends ffmpeg - run: pip install tox-travis - - run: tox -e ${{ matrix.test }} + - env: + TEST_DB: ${{ matrix.db }} + run: tox -e ${{ matrix.test }} build: needs: ["lint", "tests-unit", "tests-integration"] diff --git a/lbry/blockchain/db.py b/lbry/blockchain/db.py deleted file mode 100644 index a09543c1f..000000000 --- a/lbry/blockchain/db.py +++ /dev/null @@ -1,172 +0,0 @@ -import os -import asyncio -from concurrent import futures -from collections import namedtuple, deque - -import sqlite3 -import apsw - - -DDL = """ -pragma journal_mode=WAL; - -create table if not exists block ( - block_hash bytes not null primary key, - previous_hash bytes not null, - file_number integer not null, - height int -); -create table if not exists tx ( - block_hash integer not null, - position integer not null, - tx_hash bytes not null -); -create table if not exists txi ( - block_hash bytes not null, - tx_hash bytes not null, - txo_hash bytes not null -); -create table if not exists claim ( - txo_hash bytes not null, - claim_hash bytes not null, - claim_name text not null, - amount integer not null, - height integer -); -create table if not exists claim_history ( - block_hash bytes not null, - tx_hash bytes not null, - tx_position integer not null, - txo_hash bytes not null, - claim_hash bytes not null, - claim_name text not null, - action integer not null, - amount integer not null, - height integer, - is_spent bool -); -create table if not exists support ( - block_hash bytes not null, - tx_hash bytes not null, - txo_hash bytes not null, - claim_hash bytes not null, - amount integer not null -); -""" - - -class BlockchainDB: - - __slots__ = 'db', 'directory' - - def __init__(self, path: str): - self.db = None - self.directory = path - - @property - def db_file_path(self): - return os.path.join(self.directory, 'blockchain.db') - - def open(self): - self.db = sqlite3.connect(self.db_file_path, isolation_level=None, uri=True, timeout=60.0 * 5) - self.db.executescript(""" - pragma journal_mode=wal; - """) -# self.db = apsw.Connection( -# self.db_file_path, -# flags=( -# apsw.SQLITE_OPEN_READWRITE | -# apsw.SQLITE_OPEN_CREATE | -# apsw.SQLITE_OPEN_URI -# ) -# ) - self.execute_ddl(DDL) - self.execute(f"ATTACH ? 
AS block_index", ('file:'+os.path.join(self.directory, 'block_index.sqlite')+'?mode=ro',)) - #def exec_factory(cursor, statement, bindings): - # tpl = namedtuple('row', (d[0] for d in cursor.getdescription())) - # cursor.setrowtrace(lambda cursor, row: tpl(*row)) - # return True - #self.db.setexectrace(exec_factory) - def row_factory(cursor, row): - tpl = namedtuple('row', (d[0] for d in cursor.description)) - return tpl(*row) - self.db.row_factory = row_factory - return self - - def close(self): - if self.db is not None: - self.db.close() - - def execute(self, *args): - return self.db.cursor().execute(*args) - - def execute_many(self, *args): - return self.db.cursor().executemany(*args) - - def execute_many_tx(self, *args): - cursor = self.db.cursor() - cursor.execute('begin;') - result = cursor.executemany(*args) - cursor.execute('commit;') - return result - - def execute_ddl(self, *args): - self.db.executescript(*args) - #deque(self.execute(*args), maxlen=0) - - def begin(self): - self.execute('begin;') - - def commit(self): - self.execute('commit;') - - def get_block_file_path_from_number(self, block_file_number): - return os.path.join(self.directory, 'blocks', f'blk{block_file_number:05}.dat') - - def get_block_files_not_synced(self): - return list(self.execute( - """ - SELECT file as file_number, COUNT(hash) as blocks, SUM(txcount) as txs - FROM block_index.block_info - WHERE hash NOT IN (SELECT block_hash FROM block) - GROUP BY file ORDER BY file ASC; - """ - )) - - def get_blocks_not_synced(self, block_file): - return self.execute( - """ - SELECT datapos as data_offset, height, hash as block_hash, txCount as txs - FROM block_index.block_info - WHERE file = ? AND hash NOT IN (SELECT block_hash FROM block) - ORDER BY datapos ASC; - """, (block_file,) - ) - - -class AsyncBlockchainDB: - - def __init__(self, db: BlockchainDB): - self.sync_db = db - self.executor = futures.ThreadPoolExecutor(max_workers=1) - - @classmethod - def from_path(cls, path: str) -> 'AsyncBlockchainDB': - return cls(BlockchainDB(path)) - - def get_block_file_path_from_number(self, block_file_number): - return self.sync_db.get_block_file_path_from_number(block_file_number) - - async def run_in_executor(self, func, *args): - return await asyncio.get_running_loop().run_in_executor( - self.executor, func, *args - ) - - async def open(self): - return await self.run_in_executor(self.sync_db.open) - - async def close(self): - return await self.run_in_executor(self.sync_db.close) - - async def get_block_files_not_synced(self): - return await self.run_in_executor(self.sync_db.get_block_files_not_synced) diff --git a/lbry/db/__init__.py b/lbry/db/__init__.py new file mode 100644 index 000000000..303d19265 --- /dev/null +++ b/lbry/db/__init__.py @@ -0,0 +1,6 @@ +from .database import Database, in_account +from .tables import ( + Table, Version, metadata, + AccountAddress, PubkeyAddress, + Block, TX, TXO, TXI +) diff --git a/lbry/db/database.py b/lbry/db/database.py new file mode 100644 index 000000000..ba64742a2 --- /dev/null +++ b/lbry/db/database.py @@ -0,0 +1,1004 @@ +import logging +import asyncio +import sqlite3 +from binascii import hexlify +from concurrent.futures.thread import ThreadPoolExecutor +from typing import Tuple, List, Union, Any, Iterable, Dict, Optional +from datetime import date + +import sqlalchemy +from sqlalchemy import select, text, and_, union, func +from sqlalchemy.sql.expression import Select +try: + from sqlalchemy.dialects.postgresql import insert as pg_insert +except ImportError: + pg_insert = 
None + + +from lbry.wallet import PubKey +from lbry.wallet.transaction import Transaction, Output, OutputScript, TXRefImmutable +from lbry.wallet.constants import TXO_TYPES, CLAIM_TYPES + +from .tables import metadata, Version, TX, TXI, TXO, PubkeyAddress, AccountAddress + + +log = logging.getLogger(__name__) +sqlite3.enable_callback_tracebacks(True) + + +def insert_or_ignore(conn, table): + if conn.dialect.name == 'sqlite': + return table.insert(prefixes=("OR IGNORE",)) + elif conn.dialect.name == 'postgresql': + return pg_insert(table).on_conflict_do_nothing() + else: + raise RuntimeError(f'Unknown database dialect: {conn.dialect.name}.') + + +def insert_or_replace(conn, table, replace): + if conn.dialect.name == 'sqlite': + return table.insert(prefixes=("OR REPLACE",)) + elif conn.dialect.name == 'postgresql': + insert = pg_insert(table) + return insert.on_conflict_do_update( + table.primary_key, set_={col: getattr(insert.excluded, col) for col in replace} + ) + else: + raise RuntimeError(f'Unknown database dialect: {conn.dialect.name}.') + + +def dict_to_clause(t, d): + clauses = [] + for key, value in d.items(): + clauses.append(getattr(t.c, key) == value) + return and_(*clauses) + + +def constrain_single_or_list(constraints, column, value, convert=lambda x: x): + if value is not None: + if isinstance(value, list): + value = [convert(v) for v in value] + if len(value) == 1: + constraints[column] = value[0] + elif len(value) > 1: + constraints[f"{column}__in"] = value + else: + constraints[column] = convert(value) + return constraints + + +def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''): + sql, values = [], {} + for key, constraint in constraints.items(): + tag = '0' + if '#' in key: + key, tag = key[:key.index('#')], key[key.index('#')+1:] + col, op, key = key, '=', key.replace('.', '_') + if not key: + sql.append(constraint) + continue + if key.startswith('$$'): + col, key = col[2:], key[1:] + elif key.startswith('$'): + values[key] = constraint + continue + if key.endswith('__not'): + col, op = col[:-len('__not')], '!=' + elif key.endswith('__is_null'): + col = col[:-len('__is_null')] + sql.append(f'{col} IS NULL') + continue + if key.endswith('__is_not_null'): + col = col[:-len('__is_not_null')] + sql.append(f'{col} IS NOT NULL') + continue + if key.endswith('__lt'): + col, op = col[:-len('__lt')], '<' + elif key.endswith('__lte'): + col, op = col[:-len('__lte')], '<=' + elif key.endswith('__gt'): + col, op = col[:-len('__gt')], '>' + elif key.endswith('__gte'): + col, op = col[:-len('__gte')], '>=' + elif key.endswith('__like'): + col, op = col[:-len('__like')], 'LIKE' + elif key.endswith('__not_like'): + col, op = col[:-len('__not_like')], 'NOT LIKE' + elif key.endswith('__in') or key.endswith('__not_in'): + if key.endswith('__in'): + col, op, one_val_op = col[:-len('__in')], 'IN', '=' + else: + col, op, one_val_op = col[:-len('__not_in')], 'NOT IN', '!=' + if constraint: + if isinstance(constraint, (list, set, tuple)): + if len(constraint) == 1: + values[f'{key}{tag}'] = next(iter(constraint)) + sql.append(f'{col} {one_val_op} :{key}{tag}') + else: + keys = [] + for i, val in enumerate(constraint): + keys.append(f':{key}{tag}_{i}') + values[f'{key}{tag}_{i}'] = val + sql.append(f'{col} {op} ({", ".join(keys)})') + elif isinstance(constraint, str): + sql.append(f'{col} {op} ({constraint})') + else: + raise ValueError(f"{col} requires a list, set or string as constraint value.") + continue + elif key.endswith('__any') or key.endswith('__or'): + where, subvalues 
= constraints_to_sql(constraint, ' OR ', key+tag+'_') + sql.append(f'({where})') + values.update(subvalues) + continue + if key.endswith('__and'): + where, subvalues = constraints_to_sql(constraint, ' AND ', key+tag+'_') + sql.append(f'({where})') + values.update(subvalues) + continue + sql.append(f'{col} {op} :{prepend_key}{key}{tag}') + values[prepend_key+key+tag] = constraint + return joiner.join(sql) if sql else '', values + + +def query(select, **constraints) -> Tuple[str, Dict[str, Any]]: + sql = [select] + limit = constraints.pop('limit', None) + offset = constraints.pop('offset', None) + order_by = constraints.pop('order_by', None) + group_by = constraints.pop('group_by', None) + + accounts = constraints.pop('accounts', []) + if accounts: + constraints['account__in'] = [a.public_key.address for a in accounts] + + where, values = constraints_to_sql(constraints) + if where: + sql.append('WHERE') + sql.append(where) + + if group_by is not None: + sql.append(f'GROUP BY {group_by}') + + if order_by: + sql.append('ORDER BY') + if isinstance(order_by, str): + sql.append(order_by) + elif isinstance(order_by, list): + sql.append(', '.join(order_by)) + else: + raise ValueError("order_by must be string or list") + + if limit is not None: + sql.append(f'LIMIT {limit}') + + if offset is not None: + sql.append(f'OFFSET {offset}') + + return ' '.join(sql), values + + +def in_account(accounts: Union[List[PubKey], PubKey]): + if isinstance(accounts, list): + if len(accounts) > 1: + return AccountAddress.c.account.in_({a.public_key.address for a in accounts}) + accounts = accounts[0] + return AccountAddress.c.account == accounts.public_key.address + + +def query2(table, select: Select, **constraints) -> Select: + limit = constraints.pop('limit', None) + if limit is not None: + select = select.limit(limit) + + offset = constraints.pop('offset', None) + if offset is not None: + select = select.offset(offset) + + order_by = constraints.pop('order_by', None) + if order_by: + if isinstance(order_by, str): + select = select.order_by(text(order_by)) + elif isinstance(order_by, list): + select = select.order_by(text(', '.join(order_by))) + else: + raise ValueError("order_by must be string or list") + + group_by = constraints.pop('group_by', None) + if group_by is not None: + select = select.group_by(text(group_by)) + + accounts = constraints.pop('accounts', []) + if accounts: + select.append_whereclause(in_account(accounts)) + + if constraints: + select.append_whereclause( + constraints_to_clause2(table, constraints) + ) + + return select + + +def constraints_to_clause2(tables, constraints): + clause = [] + for key, constraint in constraints.items(): + if key.endswith('__not'): + col, op = key[:-len('__not')], '__ne__' + elif key.endswith('__is_null'): + col = key[:-len('__is_null')] + op = '__eq__' + constraint = None + elif key.endswith('__is_not_null'): + col = key[:-len('__is_not_null')] + op = '__ne__' + constraint = None + elif key.endswith('__lt'): + col, op = key[:-len('__lt')], '__lt__' + elif key.endswith('__lte'): + col, op = key[:-len('__lte')], '__le__' + elif key.endswith('__gt'): + col, op = key[:-len('__gt')], '__gt__' + elif key.endswith('__gte'): + col, op = key[:-len('__gte')], '__ge__' + elif key.endswith('__like'): + col, op = key[:-len('__like')], 'like' + elif key.endswith('__not_like'): + col, op = key[:-len('__not_like')], 'notlike' + elif key.endswith('__in') or key.endswith('__not_in'): + if key.endswith('__in'): + col, op, one_val_op = key[:-len('__in')], 'in_', '__eq__' + else: 
+ col, op, one_val_op = key[:-len('__not_in')], 'notin_', '__ne__' + if isinstance(constraint, sqlalchemy.sql.expression.Select): + pass + elif constraint: + if isinstance(constraint, (list, set, tuple)): + if len(constraint) == 1: + op = one_val_op + constraint = next(iter(constraint)) + elif isinstance(constraint, str): + constraint = text(constraint) + else: + raise ValueError(f"{col} requires a list, set or string as constraint value.") + else: + continue + else: + col, op = key, '__eq__' + attr = None + for table in tables: + attr = getattr(table.c, col, None) + if attr is not None: + clause.append(getattr(attr, op)(constraint)) + break + if attr is None: + raise ValueError(f"Attribute '{col}' not found on tables: {', '.join([t.name for t in tables])}.") + return and_(*clause) + + +class Database: + + SCHEMA_VERSION = "1.3" + MAX_QUERY_VARIABLES = 900 + + def __init__(self, url): + self.url = url + self.ledger = None + self.executor = ThreadPoolExecutor(max_workers=1) + self.engine = None + self.db: Optional[sqlalchemy.engine.Connection] = None + + async def execute_fetchall(self, sql, params=None) -> List[dict]: + def foo(): + if params: + result = self.db.execute(sql, params) + else: + result = self.db.execute(sql) + if result.returns_rows: + return [dict(r) for r in result.fetchall()] + else: + try: + self.db.commit() + except: + pass + return [] + return await asyncio.get_event_loop().run_in_executor(self.executor, foo) + + def sync_executemany(self, sql, parameters): + self.db.execute(sql, parameters) + + async def executemany(self, sql: str, parameters: Iterable = None): + return await asyncio.get_event_loop().run_in_executor( + self.executor, self.sync_executemany, sql, parameters + ) + + def sync_open(self): + log.info("connecting to database: %s", self.url) + self.engine = sqlalchemy.create_engine(self.url) + self.db = self.engine.connect() + if self.SCHEMA_VERSION: + if self.engine.has_table('version'): + version = self.db.execute(Version.select().limit(1)).fetchone() + if version and version.version == self.SCHEMA_VERSION: + return + metadata.drop_all(self.engine) + metadata.create_all(self.engine) + self.db.execute(Version.insert().values(version=self.SCHEMA_VERSION)) + else: + metadata.create_all(self.engine) + return self + + async def open(self): + return await asyncio.get_event_loop().run_in_executor( + self.executor, self.sync_open + ) + + def sync_close(self): + if self.engine is not None: + self.engine.dispose() + self.engine = None + self.db = None + + async def close(self): + await asyncio.get_event_loop().run_in_executor( + self.executor, self.sync_close + ) + + def sync_create(self, name): + engine = sqlalchemy.create_engine(self.url) + db = engine.connect() + db.execute('commit') + db.execute(f'create database {name}') + + async def create(self, name): + await asyncio.get_event_loop().run_in_executor( + self.executor, self.sync_create, name + ) + + def sync_drop(self, name): + engine = sqlalchemy.create_engine(self.url) + db = engine.connect() + db.execute('commit') + db.execute(f'drop database if exists {name}') + + async def drop(self, name): + await asyncio.get_event_loop().run_in_executor( + self.executor, self.sync_drop, name + ) + + def txo_to_row(self, tx, txo): + row = { + 'tx_hash': tx.hash, + 'txo_hash': txo.hash, + 'address': txo.get_address(self.ledger), + 'position': txo.position, + 'amount': txo.amount, + 'script': txo.script.source + } + if txo.is_claim: + if txo.can_decode_claim: + claim = txo.claim + row['txo_type'] = 
TXO_TYPES.get(claim.claim_type, TXO_TYPES['stream']) + if claim.is_repost: + row['reposted_claim_hash'] = claim.repost.reference.claim_hash + if claim.is_signed: + row['channel_hash'] = claim.signing_channel_hash + else: + row['txo_type'] = TXO_TYPES['stream'] + elif txo.is_support: + row['txo_type'] = TXO_TYPES['support'] + elif txo.purchase is not None: + row['txo_type'] = TXO_TYPES['purchase'] + row['claim_id'] = txo.purchased_claim_id + row['claim_hash'] = txo.purchased_claim_hash + if txo.script.is_claim_involved: + row['claim_id'] = txo.claim_id + row['claim_hash'] = txo.claim_hash + row['claim_name'] = txo.claim_name + return row + + def tx_to_row(self, tx): + row = { + 'tx_hash': tx.hash, + 'raw': tx.raw, + 'height': tx.height, + 'position': tx.position, + 'is_verified': tx.is_verified, + 'day': tx.get_ordinal_day(self.ledger), + } + txos = tx.outputs + if len(txos) >= 2 and txos[1].can_decode_purchase_data: + txos[0].purchase = txos[1] + row['purchased_claim_hash'] = txos[1].purchase_data.claim_hash + return row + + async def insert_transaction(self, tx): + await self.execute_fetchall(TX.insert().values(self.tx_to_row(tx))) + + def _transaction_io(self, conn: sqlite3.Connection, tx: Transaction, address, txhash): + conn.execute( + insert_or_replace(conn, TX, ('block_hash', 'height', 'position', 'is_verified', 'day')).values( + self.tx_to_row(tx) + ) + ) + + is_my_input = False + + for txi in tx.inputs: + if txi.txo_ref.txo is not None: + txo = txi.txo_ref.txo + if txo.has_address and txo.get_address(self.ledger) == address: + is_my_input = True + conn.execute( + insert_or_ignore(conn, TXI).values({ + 'tx_hash': tx.hash, + 'txo_hash': txo.hash, + 'address': address, + 'position': txi.position + }) + ) + + for txo in tx.outputs: + if txo.script.is_pay_pubkey_hash and (txo.pubkey_hash == txhash or is_my_input): + conn.execute(insert_or_ignore(conn, TXO).values(self.txo_to_row(tx, txo))) + elif txo.script.is_pay_script_hash: + # TODO: implement script hash payments + log.warning('Database.save_transaction_io: pay script hash is not implemented!') + + def save_transaction_io(self, tx: Transaction, address, txhash, history): + return self.save_transaction_io_batch([tx], address, txhash, history) + + def save_transaction_io_batch(self, txs: Iterable[Transaction], address, txhash, history): + history_count = history.count(':') // 2 + + def __many(): + for tx in txs: + self._transaction_io(self.db, tx, address, txhash) + self.db.execute( + PubkeyAddress.update() + .values(history=history, used_times=history_count) + .where(PubkeyAddress.c.address == address) + ) + + return asyncio.get_event_loop().run_in_executor(self.executor, __many) + + async def reserve_outputs(self, txos, is_reserved=True): + txo_hashes = [txo.hash for txo in txos] + if txo_hashes: + await self.execute_fetchall( + TXO.update().values(is_reserved=is_reserved).where(TXO.c.txo_hash.in_(txo_hashes)) + ) + + async def release_outputs(self, txos): + await self.reserve_outputs(txos, is_reserved=False) + + async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use + # TODO: + # 1. delete transactions above_height + # 2. 
update address histories removing deleted TXs + return True + + async def select_transactions(self, cols, accounts=None, **constraints): + s: Select = select(cols).select_from(TX) + if not {'tx_hash', 'tx_hash__in'}.intersection(constraints): + assert accounts, "'accounts' argument required when no 'tx_hash' constraint is present" + where = in_account(accounts) + tx_hashes = union( + select([TXO.c.tx_hash], where, TXO.join(AccountAddress, TXO.c.address == AccountAddress.c.address)), + select([TXI.c.tx_hash], where, TXI.join(AccountAddress, TXI.c.address == AccountAddress.c.address)) + ) + s.append_whereclause(TX.c.tx_hash.in_(tx_hashes)) + return await self.execute_fetchall(query2([TX], s, **constraints)) + + TXO_NOT_MINE = Output(None, None, is_my_output=False) + + async def get_transactions(self, wallet=None, **constraints): + include_is_spent = constraints.pop('include_is_spent', False) + include_is_my_input = constraints.pop('include_is_my_input', False) + include_is_my_output = constraints.pop('include_is_my_output', False) + + tx_rows = await self.select_transactions( + [TX.c.tx_hash, TX.c.raw, TX.c.height, TX.c.position, TX.c.is_verified], + order_by=constraints.pop('order_by', ["height=0 DESC", "height DESC", "position DESC"]), + **constraints + ) + + if not tx_rows: + return [] + + txids, txs, txi_txoids = [], [], [] + for row in tx_rows: + txids.append(row['tx_hash']) + txs.append(Transaction( + raw=row['raw'], height=row['height'], position=row['position'], + is_verified=bool(row['is_verified']) + )) + for txi in txs[-1].inputs: + txi_txoids.append(txi.txo_ref.hash) + + step = self.MAX_QUERY_VARIABLES + annotated_txos = {} + for offset in range(0, len(txids), step): + annotated_txos.update({ + txo.id: txo for txo in + (await self.get_txos( + wallet=wallet, + tx_hash__in=txids[offset:offset+step], order_by='txo.tx_hash', + include_is_spent=include_is_spent, + include_is_my_input=include_is_my_input, + include_is_my_output=include_is_my_output, + )) + }) + + referenced_txos = {} + for offset in range(0, len(txi_txoids), step): + referenced_txos.update({ + txo.id: txo for txo in + (await self.get_txos( + wallet=wallet, + txo_hash__in=txi_txoids[offset:offset+step], order_by='txo.txo_hash', + include_is_my_output=include_is_my_output, + )) + }) + + for tx in txs: + for txi in tx.inputs: + txo = referenced_txos.get(txi.txo_ref.id) + if txo: + txi.txo_ref = txo.ref + for txo in tx.outputs: + _txo = annotated_txos.get(txo.id) + if _txo: + txo.update_annotations(_txo) + else: + txo.update_annotations(self.TXO_NOT_MINE) + + for tx in txs: + txos = tx.outputs + if len(txos) >= 2 and txos[1].can_decode_purchase_data: + txos[0].purchase = txos[1] + + return txs + + async def get_transaction_count(self, **constraints): + constraints.pop('wallet', None) + constraints.pop('offset', None) + constraints.pop('limit', None) + constraints.pop('order_by', None) + count = await self.select_transactions([func.count().label('total')], **constraints) + return count[0]['total'] or 0 + + async def get_transaction(self, **constraints): + txs = await self.get_transactions(limit=1, **constraints) + if txs: + return txs[0] + + async def select_txos( + self, cols, accounts=None, is_my_input=None, is_my_output=True, + is_my_input_or_output=None, exclude_internal_transfers=False, + include_is_spent=False, include_is_my_input=False, + is_spent=None, spent=None, **constraints): + s: Select = select(cols) + if accounts: + #account_in_sql, values = constraints_to_sql({ + # '$$account__in': [a.public_key.address for 
a in accounts] + #}) + my_addresses = select([AccountAddress.c.address]).where(in_account(accounts)) + #f"SELECT address FROM account_address WHERE {account_in_sql}" + if is_my_input_or_output: + include_is_my_input = True + s.append_whereclause( + TXO.c.address.in_(my_addresses) | ( + (TXI.c.address != None) & + (TXI.c.address.in_(my_addresses)) + ) + ) + #constraints['received_or_sent__or'] = { + # 'txo.address__in': my_addresses, + # 'sent__and': { + # 'txi.address__is_not_null': True, + # 'txi.address__in': my_addresses + # } + #} + else: + if is_my_output: + s.append_whereclause(TXO.c.address.in_(my_addresses)) + #constraints['txo.address__in'] = my_addresses + elif is_my_output is False: + s.append_whereclause(TXO.c.address.notin_(my_addresses)) + #constraints['txo.address__not_in'] = my_addresses + if is_my_input: + include_is_my_input = True + s.append_whereclause( + (TXI.c.address != None) & + (TXI.c.address.in_(my_addresses)) + ) + #constraints['txi.address__is_not_null'] = True + #constraints['txi.address__in'] = my_addresses + elif is_my_input is False: + include_is_my_input = True + s.append_whereclause( + (TXI.c.address == None) | + (TXI.c.address.notin_(my_addresses)) + ) + #constraints['is_my_input_false__or'] = { + # 'txi.address__is_null': True, + # 'txi.address__not_in': my_addresses + #} + if exclude_internal_transfers: + include_is_my_input = True + s.append_whereclause( + (TXO.c.txo_type != TXO_TYPES['other']) | + (TXI.c.address == None) | + (TXI.c.address.notin_(my_addresses)) + ) + #constraints['exclude_internal_payments__or'] = { + # 'txo.txo_type__not': TXO_TYPES['other'], + # 'txi.address__is_null': True, + # 'txi.address__not_in': my_addresses + #} + joins = TXO.join(TX) + if spent is None: + spent = TXI.alias('spent') + #sql = [f"SELECT {cols} FROM txo JOIN tx ON (tx.tx_hash=txo.tx_hash)"] + if is_spent: + s.append_whereclause(spent.c.txo_hash != None) + #constraints['spent.txo_hash__is_not_null'] = True + elif is_spent is False: + s.append_whereclause((spent.c.txo_hash == None) & (TXO.c.is_reserved == False)) + #constraints['is_reserved'] = False + #constraints['spent.txo_hash__is_null'] = True + if include_is_spent or is_spent is not None: + joins = joins.join(spent, spent.c.txo_hash == TXO.c.txo_hash, isouter=True) + #sql.append("LEFT JOIN txi AS spent ON (spent.txo_hash=txo.txo_hash)") + if include_is_my_input: + joins = joins.join(TXI, (TXI.c.position == 0) & (TXI.c.tx_hash == TXO.c.tx_hash), isouter=True) + #sql.append("LEFT JOIN txi ON (txi.position=0 AND txi.tx_hash=txo.tx_hash)") + s.append_from(joins) + return await self.execute_fetchall(query2([TXO, TX], s, **constraints)) + + async def get_txos(self, wallet=None, no_tx=False, **constraints): + include_is_spent = constraints.get('include_is_spent', False) + include_is_my_input = constraints.get('include_is_my_input', False) + include_is_my_output = constraints.pop('include_is_my_output', False) + include_received_tips = constraints.pop('include_received_tips', False) + + select_columns = [ + TX.c.tx_hash, TX.c.raw, TX.c.height, TX.c.position.label('tx_position'), TX.c.is_verified, + TXO.c.txo_type, TXO.c.position.label('txo_position'), TXO.c.amount, TXO.c.script + ] + + my_accounts = None + if wallet is not None: + my_accounts = select([AccountAddress.c.address], in_account(wallet.accounts)) + + if include_is_my_output and my_accounts is not None: + if constraints.get('is_my_output', None) in (True, False): + select_columns.append(text(f"{1 if constraints['is_my_output'] else 0} AS is_my_output")) + 
else: + select_columns.append(TXO.c.address.in_(my_accounts).label('is_my_output')) + + if include_is_my_input and my_accounts is not None: + if constraints.get('is_my_input', None) in (True, False): + select_columns.append(text(f"{1 if constraints['is_my_input'] else 0} AS is_my_input")) + else: + select_columns.append(( + (TXI.c.address != None) & + (TXI.c.address.in_(my_accounts)) + ).label('is_my_input')) + + spent = TXI.alias('spent') + if include_is_spent: + select_columns.append((spent.c.txo_hash != None).label('is_spent')) + + if include_received_tips: + support = TXO.alias('support') + select_columns.append(select( + [func.coalesce(func.sum(support.c.amount), 0)], + (support.c.claim_hash == TXO.c.claim_hash) & + (support.c.txo_type == TXO_TYPES['support']) & + (support.c.address.in_(my_accounts)) & + (support.c.txo_hash.notin_(select([TXI.c.txo_hash]))), + support + ).label('received_tips')) + #select_columns.append(f"""( + #SELECT COALESCE(SUM(support.amount), 0) FROM txo AS support WHERE + # support.claim_hash = txo.claim_hash AND + # support.txo_type = {TXO_TYPES['support']} AND + # support.address IN (SELECT address FROM account_address WHERE {my_accounts_sql}) AND + # support.txo_hash NOT IN (SELECT txo_hash FROM txi) + #) AS received_tips""") + + if 'order_by' not in constraints or constraints['order_by'] == 'height': + constraints['order_by'] = [ + "tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position" + ] + elif constraints.get('order_by', None) == 'none': + del constraints['order_by'] + + rows = await self.select_txos(select_columns, spent=spent, **constraints) + + txos = [] + txs = {} + for row in rows: + if no_tx: + txo = Output( + amount=row['amount'], + script=OutputScript(row['script']), + tx_ref=TXRefImmutable.from_hash(row['tx_hash'], row['height']), + position=row['txo_position'] + ) + else: + if row['tx_hash'] not in txs: + txs[row['tx_hash']] = Transaction( + row['raw'], height=row['height'], position=row['tx_position'], + is_verified=bool(row['is_verified']) + ) + txo = txs[row['tx_hash']].outputs[row['txo_position']] + if include_is_spent: + txo.is_spent = bool(row['is_spent']) + if include_is_my_input: + txo.is_my_input = bool(row['is_my_input']) + if include_is_my_output: + txo.is_my_output = bool(row['is_my_output']) + if include_is_my_input and include_is_my_output: + if txo.is_my_input and txo.is_my_output and row['txo_type'] == TXO_TYPES['other']: + txo.is_internal_transfer = True + else: + txo.is_internal_transfer = False + if include_received_tips: + txo.received_tips = row['received_tips'] + txos.append(txo) + + channel_hashes = set() + for txo in txos: + if txo.is_claim and txo.can_decode_claim: + if txo.claim.is_signed: + channel_hashes.add(txo.claim.signing_channel_hash) + if txo.claim.is_channel and wallet: + for account in wallet.accounts: + private_key = account.get_channel_private_key( + txo.claim.channel.public_key_bytes + ) + if private_key: + txo.private_key = private_key + break + + if channel_hashes: + channels = { + txo.claim_hash: txo for txo in + (await self.get_channels( + wallet=wallet, + claim_hash__in=channel_hashes, + )) + } + for txo in txos: + if txo.is_claim and txo.can_decode_claim: + txo.channel = channels.get(txo.claim.signing_channel_hash, None) + + return txos + + @staticmethod + def _clean_txo_constraints_for_aggregation(constraints): + constraints.pop('include_is_spent', None) + constraints.pop('include_is_my_input', None) + constraints.pop('include_is_my_output', None) + 
constraints.pop('include_received_tips', None) + constraints.pop('wallet', None) + constraints.pop('resolve', None) + constraints.pop('offset', None) + constraints.pop('limit', None) + constraints.pop('order_by', None) + + async def get_txo_count(self, **constraints): + self._clean_txo_constraints_for_aggregation(constraints) + count = await self.select_txos([func.count().label('total')], **constraints) + return count[0]['total'] or 0 + + async def get_txo_sum(self, **constraints): + self._clean_txo_constraints_for_aggregation(constraints) + result = await self.select_txos([func.sum(TXO.c.amount).label('total')], **constraints) + return result[0]['total'] or 0 + + async def get_txo_plot(self, start_day=None, days_back=0, end_day=None, days_after=None, **constraints): + self._clean_txo_constraints_for_aggregation(constraints) + if start_day is None: + current_ordinal = self.ledger.headers.estimated_date(self.ledger.headers.height).toordinal() + constraints['day__gte'] = current_ordinal - days_back + else: + constraints['day__gte'] = date.fromisoformat(start_day).toordinal() + if end_day is not None: + constraints['day__lte'] = date.fromisoformat(end_day).toordinal() + elif days_after is not None: + constraints['day__lte'] = constraints['day__gte'] + days_after + plot = await self.select_txos( + [TX.c.day, func.sum(TXO.c.amount).label('total')], + group_by='day', order_by='day', **constraints + ) + for row in plot: + row['day'] = date.fromordinal(row['day']) + return plot + + def get_utxos(self, **constraints): + return self.get_txos(is_spent=False, **constraints) + + def get_utxo_count(self, **constraints): + return self.get_txo_count(is_spent=False, **constraints) + + async def get_balance(self, wallet=None, accounts=None, **constraints): + assert wallet or accounts, \ + "'wallet' or 'accounts' constraints required to calculate balance" + constraints['accounts'] = accounts or wallet.accounts + balance = await self.select_txos( + [func.sum(TXO.c.amount).label('total')], is_spent=False, **constraints + ) + return balance[0]['total'] or 0 + + async def select_addresses(self, cols, **constraints): + return await self.execute_fetchall(query2( + [AccountAddress, PubkeyAddress], + select(cols).select_from(PubkeyAddress.join(AccountAddress)), + **constraints + )) + + async def get_addresses(self, cols=None, **constraints): + if cols is None: + cols = ( + PubkeyAddress.c.address, + PubkeyAddress.c.history, + PubkeyAddress.c.used_times, + AccountAddress.c.account, + AccountAddress.c.chain, + AccountAddress.c.pubkey, + AccountAddress.c.chain_code, + AccountAddress.c.n, + AccountAddress.c.depth + ) + addresses = await self.select_addresses(cols, **constraints) + if AccountAddress.c.pubkey in cols: + for address in addresses: + address['pubkey'] = PubKey( + self.ledger, bytes(address.pop('pubkey')), bytes(address.pop('chain_code')), + address.pop('n'), address.pop('depth') + ) + return addresses + + async def get_address_count(self, cols=None, **constraints): + count = await self.select_addresses([func.count().label('total')], **constraints) + return count[0]['total'] or 0 + + async def get_address(self, **constraints): + addresses = await self.get_addresses(limit=1, **constraints) + if addresses: + return addresses[0] + + async def add_keys(self, account, chain, pubkeys): + await self.execute_fetchall( + insert_or_ignore(self.db, PubkeyAddress).values([{ + 'address': k.address + } for k in pubkeys]) + ) + await self.execute_fetchall( + insert_or_ignore(self.db, AccountAddress).values([{ + 'account': 
account.id, + 'address': k.address, + 'chain': chain, + 'pubkey': k.pubkey_bytes, + 'chain_code': k.chain_code, + 'n': k.n, + 'depth': k.depth + } for k in pubkeys]) + ) + + async def _set_address_history(self, address, history): + await self.execute_fetchall( + PubkeyAddress.update() + .values(history=history, used_times=history.count(':')//2) + .where(PubkeyAddress.c.address == address) + ) + + async def set_address_history(self, address, history): + await self._set_address_history(address, history) + + @staticmethod + def constrain_purchases(constraints): + accounts = constraints.pop('accounts', None) + assert accounts, "'accounts' argument required to find purchases" + if not {'purchased_claim_hash', 'purchased_claim_hash__in'}.intersection(constraints): + constraints['purchased_claim_hash__is_not_null'] = True + constraints['tx_hash__in'] = select( + [TXI.c.tx_hash], in_account(accounts), + TXI.join(AccountAddress, TXI.c.address == AccountAddress.c.address) + ) + + async def get_purchases(self, **constraints): + self.constrain_purchases(constraints) + return [tx.outputs[0] for tx in await self.get_transactions(**constraints)] + + def get_purchase_count(self, **constraints): + self.constrain_purchases(constraints) + return self.get_transaction_count(**constraints) + + @staticmethod + def constrain_claims(constraints): + if {'txo_type', 'txo_type__in'}.intersection(constraints): + return + claim_types = constraints.pop('claim_type', None) + if claim_types: + constrain_single_or_list( + constraints, 'txo_type', claim_types, lambda x: TXO_TYPES[x] + ) + else: + constraints['txo_type__in'] = CLAIM_TYPES + + async def get_claims(self, **constraints) -> List[Output]: + self.constrain_claims(constraints) + return await self.get_utxos(**constraints) + + def get_claim_count(self, **constraints): + self.constrain_claims(constraints) + return self.get_utxo_count(**constraints) + + @staticmethod + def constrain_streams(constraints): + constraints['txo_type'] = TXO_TYPES['stream'] + + def get_streams(self, **constraints): + self.constrain_streams(constraints) + return self.get_claims(**constraints) + + def get_stream_count(self, **constraints): + self.constrain_streams(constraints) + return self.get_claim_count(**constraints) + + @staticmethod + def constrain_channels(constraints): + constraints['txo_type'] = TXO_TYPES['channel'] + + def get_channels(self, **constraints): + self.constrain_channels(constraints) + return self.get_claims(**constraints) + + def get_channel_count(self, **constraints): + self.constrain_channels(constraints) + return self.get_claim_count(**constraints) + + @staticmethod + def constrain_supports(constraints): + constraints['txo_type'] = TXO_TYPES['support'] + + def get_supports(self, **constraints): + self.constrain_supports(constraints) + return self.get_utxos(**constraints) + + def get_support_count(self, **constraints): + self.constrain_supports(constraints) + return self.get_utxo_count(**constraints) + + @staticmethod + def constrain_collections(constraints): + constraints['txo_type'] = TXO_TYPES['collection'] + + def get_collections(self, **constraints): + self.constrain_collections(constraints) + return self.get_utxos(**constraints) + + def get_collection_count(self, **constraints): + self.constrain_collections(constraints) + return self.get_utxo_count(**constraints) + + async def release_all_outputs(self, account): + await self.execute_fetchall( + "UPDATE txo SET is_reserved = 0 WHERE" + " is_reserved = 1 AND txo.address IN (" + " SELECT address from account_address 
WHERE account = ?" + " )", (account.public_key.address, ) + ) + + def get_supports_summary(self, **constraints): + return self.get_txos( + txo_type=TXO_TYPES['support'], + is_spent=False, is_my_output=True, + include_is_my_input=True, + no_tx=True, + **constraints + ) diff --git a/lbry/db/tables.py b/lbry/db/tables.py new file mode 100644 index 000000000..b17a2533f --- /dev/null +++ b/lbry/db/tables.py @@ -0,0 +1,82 @@ +from sqlalchemy import ( + MetaData, Table, Column, ForeignKey, + Binary, Text, SmallInteger, Integer, Boolean +) + + +metadata = MetaData() + + +Version = Table( + 'version', metadata, + Column('version', Text, primary_key=True), +) + + +PubkeyAddress = Table( + 'pubkey_address', metadata, + Column('address', Text, primary_key=True), + Column('history', Text, nullable=True), + Column('used_times', Integer, server_default='0'), +) + + +AccountAddress = Table( + 'account_address', metadata, + Column('account', Text, primary_key=True), + Column('address', Text, ForeignKey(PubkeyAddress.columns.address), primary_key=True), + Column('chain', Integer), + Column('pubkey', Binary), + Column('chain_code', Binary), + Column('n', Integer), + Column('depth', Integer), +) + + +Block = Table( + 'block', metadata, + Column('block_hash', Binary, primary_key=True), + Column('previous_hash', Binary), + Column('file_number', SmallInteger), + Column('height', Integer), +) + + +TX = Table( + 'tx', metadata, + Column('block_hash', Binary, nullable=True), + Column('tx_hash', Binary, primary_key=True), + Column('raw', Binary), + Column('height', Integer), + Column('position', SmallInteger), + Column('is_verified', Boolean, server_default='FALSE'), + Column('purchased_claim_hash', Binary, nullable=True), + Column('day', Integer, nullable=True), +) + + +TXO = Table( + 'txo', metadata, + Column('tx_hash', Binary, ForeignKey(TX.columns.tx_hash)), + Column('txo_hash', Binary, primary_key=True), + Column('address', Text), + Column('position', Integer), + Column('amount', Integer), + Column('script', Binary), + Column('is_reserved', Boolean, server_default='0'), + Column('txo_type', Integer, server_default='0'), + Column('claim_id', Text, nullable=True), + Column('claim_hash', Binary, nullable=True), + Column('claim_name', Text, nullable=True), + Column('channel_hash', Binary, nullable=True), + Column('reposted_claim_hash', Binary, nullable=True), +) + + +TXI = Table( + 'txi', metadata, + Column('tx_hash', Binary, ForeignKey(TX.columns.tx_hash)), + Column('txo_hash', Binary, ForeignKey(TXO.columns.txo_hash), primary_key=True), + Column('address', Text), + Column('position', Integer), +) diff --git a/lbry/extras/daemon/daemon.py b/lbry/extras/daemon/daemon.py index cd65af76c..032d51d97 100644 --- a/lbry/extras/daemon/daemon.py +++ b/lbry/extras/daemon/daemon.py @@ -18,6 +18,7 @@ from functools import wraps, partial import ecdsa import base58 +from sqlalchemy import text from aiohttp import web from prometheus_client import generate_latest as prom_generate_latest from google.protobuf.message import DecodeError @@ -1530,7 +1531,7 @@ class Daemon(metaclass=JSONRPCServerType): wallet = self.wallet_manager.get_wallet_or_default(wallet_id) account = wallet.get_account_or_default(account_id) balance = await account.get_detailed_balance( - confirmations=confirmations, reserved_subtotals=True, read_only=True + confirmations=confirmations, reserved_subtotals=True, ) return dict_values_to_lbc(balance) @@ -1855,7 +1856,7 @@ class Daemon(metaclass=JSONRPCServerType): """ wallet = 
self.wallet_manager.get_wallet_or_default(wallet_id) account = wallet.get_account_or_default(account_id) - match = await self.ledger.db.get_address(read_only=True, address=address, accounts=[account]) + match = await self.ledger.db.get_address(address=address, accounts=[account]) if match is not None: return True return False @@ -1879,9 +1880,7 @@ class Daemon(metaclass=JSONRPCServerType): Returns: {Paginated[Address]} """ wallet = self.wallet_manager.get_wallet_or_default(wallet_id) - constraints = { - 'cols': ('address', 'account', 'used_times', 'pubkey', 'chain_code', 'n', 'depth') - } + constraints = {} if address: constraints['address'] = address if account_id: @@ -1891,7 +1890,7 @@ class Daemon(metaclass=JSONRPCServerType): return paginate_rows( self.ledger.get_addresses, self.ledger.get_address_count, - page, page_size, read_only=True, **constraints + page, page_size, **constraints ) @requires(WALLET_COMPONENT) @@ -1968,7 +1967,7 @@ class Daemon(metaclass=JSONRPCServerType): txo.purchased_claim_id: txo for txo in await self.ledger.db.get_purchases( accounts=wallet.accounts, - purchased_claim_id__in=[s.claim_id for s in paginated['items']] + purchased_claim_hash__in=[unhexlify(s.claim_id)[::-1] for s in paginated['items']] ) } for stream in paginated['items']: @@ -2630,7 +2629,7 @@ class Daemon(metaclass=JSONRPCServerType): accounts = wallet.accounts existing_channels = await self.ledger.get_claims( - wallet=wallet, accounts=accounts, claim_id=claim_id + wallet=wallet, accounts=accounts, claim_hash=unhexlify(claim_id)[::-1] ) if len(existing_channels) != 1: account_ids = ', '.join(f"'{account.id}'" for account in accounts) @@ -2721,7 +2720,7 @@ class Daemon(metaclass=JSONRPCServerType): if txid is not None and nout is not None: claims = await self.ledger.get_claims( - wallet=wallet, accounts=accounts, **{'txo.txid': txid, 'txo.position': nout} + wallet=wallet, accounts=accounts, tx_hash=unhexlify(txid)[::-1], position=nout ) elif claim_id is not None: claims = await self.ledger.get_claims( @@ -3477,7 +3476,7 @@ class Daemon(metaclass=JSONRPCServerType): if txid is not None and nout is not None: claims = await self.ledger.get_claims( - wallet=wallet, accounts=accounts, **{'txo.txid': txid, 'txo.position': nout} + wallet=wallet, accounts=accounts, tx_hash=unhexlify(txid)[::-1], position=nout ) elif claim_id is not None: claims = await self.ledger.get_claims( @@ -4053,7 +4052,7 @@ class Daemon(metaclass=JSONRPCServerType): if txid is not None and nout is not None: supports = await self.ledger.get_supports( - wallet=wallet, accounts=accounts, **{'txo.txid': txid, 'txo.position': nout} + wallet=wallet, accounts=accounts, tx_hash=unhexlify(txid)[::-1], position=nout ) elif claim_id is not None: supports = await self.ledger.get_supports( @@ -4165,7 +4164,7 @@ class Daemon(metaclass=JSONRPCServerType): self.ledger.get_transaction_history, wallet=wallet, accounts=wallet.accounts) transaction_count = partial( self.ledger.get_transaction_history_count, wallet=wallet, accounts=wallet.accounts) - return paginate_rows(transactions, transaction_count, page, page_size, read_only=True) + return paginate_rows(transactions, transaction_count, page, page_size) @requires(WALLET_COMPONENT) def jsonrpc_transaction_show(self, txid): @@ -4180,7 +4179,7 @@ class Daemon(metaclass=JSONRPCServerType): Returns: {Transaction} """ - return self.wallet_manager.get_transaction(txid) + return self.wallet_manager.get_transaction(unhexlify(txid)[::-1]) TXO_DOC = """ List and sum transaction outputs. 
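(Reviewer note, between hunks: the daemon.py changes above and below replace user-facing hex ids — `txid`, `claim_id`, `channel_id`, `reposted_claim_id` — with raw byte hashes at the query layer, using `unhexlify(x)[::-1]`. A minimal sketch of the convention these call sites rely on; the helper names below are illustrative and not part of the patch. Bitcoin-style ids are displayed as big-endian hex while the stored hash bytes are little-endian, hence the `[::-1]` reversal.)

```python
from binascii import hexlify, unhexlify

# Illustrative helpers (not in this patch): round-trip between the
# display form (hex txid) and the storage form (raw tx_hash bytes).

def txid_to_hash(txid: str) -> bytes:
    """Hex txid as shown to users -> raw tx_hash bytes stored in the new tables."""
    return unhexlify(txid)[::-1]

def hash_to_txid(tx_hash: bytes) -> str:
    """Raw tx_hash bytes -> user-facing hex txid."""
    return hexlify(tx_hash[::-1]).decode()

assert hash_to_txid(txid_to_hash('00ff' * 16)) == '00ff' * 16
```

The same reversal is visible where schema/result.py drops `hexlify(txo_message.tx_hash[::-1]).decode()` in favor of passing the raw bytes through.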
@@ -4210,12 +4209,13 @@ class Daemon(metaclass=JSONRPCServerType): constraints['is_my_output'] = True elif is_not_my_output is True: constraints['is_my_output'] = False + to_hash = lambda x: unhexlify(x)[::-1] database.constrain_single_or_list(constraints, 'txo_type', type, lambda x: TXO_TYPES[x]) - database.constrain_single_or_list(constraints, 'channel_id', channel_id) - database.constrain_single_or_list(constraints, 'claim_id', claim_id) + database.constrain_single_or_list(constraints, 'channel_hash', channel_id, to_hash) + database.constrain_single_or_list(constraints, 'claim_hash', claim_id, to_hash) database.constrain_single_or_list(constraints, 'claim_name', name) - database.constrain_single_or_list(constraints, 'txid', txid) - database.constrain_single_or_list(constraints, 'reposted_claim_id', reposted_claim_id) + database.constrain_single_or_list(constraints, 'tx_hash', txid, to_hash) + database.constrain_single_or_list(constraints, 'reposted_claim_hash', reposted_claim_id, to_hash) return constraints @requires(WALLET_COMPONENT) @@ -4274,8 +4274,8 @@ class Daemon(metaclass=JSONRPCServerType): claims = account.get_txos claim_count = account.get_txo_count else: - claims = partial(self.ledger.get_txos, wallet=wallet, accounts=wallet.accounts, read_only=True) - claim_count = partial(self.ledger.get_txo_count, wallet=wallet, accounts=wallet.accounts, read_only=True) + claims = partial(self.ledger.get_txos, wallet=wallet, accounts=wallet.accounts) + claim_count = partial(self.ledger.get_txo_count, wallet=wallet, accounts=wallet.accounts) constraints = { 'resolve': resolve, 'include_is_spent': True, @@ -4332,7 +4332,7 @@ class Daemon(metaclass=JSONRPCServerType): wallet = self.wallet_manager.get_wallet_or_default(wallet_id) accounts = [wallet.get_account_or_error(account_id)] if account_id else wallet.accounts txos = await self.ledger.get_txos( - wallet=wallet, accounts=accounts, read_only=True, + wallet=wallet, accounts=accounts, **self._constrain_txo_from_kwargs({}, is_not_spent=True, is_my_output=True, **kwargs) ) txs = [] @@ -4391,7 +4391,7 @@ class Daemon(metaclass=JSONRPCServerType): wallet = self.wallet_manager.get_wallet_or_default(wallet_id) return self.ledger.get_txo_sum( wallet=wallet, accounts=[wallet.get_account_or_error(account_id)] if account_id else wallet.accounts, - read_only=True, **self._constrain_txo_from_kwargs({}, **kwargs) + **self._constrain_txo_from_kwargs({}, **kwargs) ) @requires(WALLET_COMPONENT) @@ -4447,7 +4447,7 @@ class Daemon(metaclass=JSONRPCServerType): wallet = self.wallet_manager.get_wallet_or_default(wallet_id) plot = await self.ledger.get_txo_plot( wallet=wallet, accounts=[wallet.get_account_or_error(account_id)] if account_id else wallet.accounts, - read_only=True, days_back=days_back, start_day=start_day, days_after=days_after, end_day=end_day, + days_back=days_back, start_day=start_day, days_after=days_after, end_day=end_day, **self._constrain_txo_from_kwargs({}, **kwargs) ) for row in plot: diff --git a/lbry/extras/daemon/json_response_encoder.py b/lbry/extras/daemon/json_response_encoder.py index d2080db8c..adf70c425 100644 --- a/lbry/extras/daemon/json_response_encoder.py +++ b/lbry/extras/daemon/json_response_encoder.py @@ -1,7 +1,7 @@ import logging from decimal import Decimal from binascii import hexlify, unhexlify -from datetime import datetime +from datetime import datetime, date from json import JSONEncoder from google.protobuf.message import DecodeError @@ -134,6 +134,8 @@ class JSONResponseEncoder(JSONEncoder): return 
self.encode_claim(obj) if isinstance(obj, PubKey): return obj.extended_key_string() + if isinstance(obj, date): + return obj.isoformat() if isinstance(obj, datetime): return obj.strftime("%Y%m%dT%H:%M:%S") if isinstance(obj, Decimal): diff --git a/lbry/extras/daemon/storage.py b/lbry/extras/daemon/storage.py index 11a61e45e..59ea115a8 100644 --- a/lbry/extras/daemon/storage.py +++ b/lbry/extras/daemon/storage.py @@ -6,7 +6,7 @@ import asyncio import binascii import time from typing import Optional -from lbry.wallet import SQLiteMixin +from lbry.wallet.database import SQLiteMixin from lbry.conf import Config from lbry.wallet.dewies import dewies_to_lbc, lbc_to_dewies from lbry.wallet.transaction import Transaction diff --git a/lbry/schema/result.py b/lbry/schema/result.py index d0889898e..82ba9504f 100644 --- a/lbry/schema/result.py +++ b/lbry/schema/result.py @@ -148,7 +148,7 @@ class Outputs: for txo_message in chain(outputs.txos, outputs.extra_txos): if txo_message.WhichOneof('meta') == 'error': continue - txs.add((hexlify(txo_message.tx_hash[::-1]).decode(), txo_message.height)) + txs.add((txo_message.tx_hash, txo_message.height)) return cls( outputs.txos, outputs.extra_txos, txs, outputs.offset, outputs.total, diff --git a/lbry/testcase.py b/lbry/testcase.py index dcdaa83e5..d1d359a0e 100644 --- a/lbry/testcase.py +++ b/lbry/testcase.py @@ -253,6 +253,11 @@ class IntegrationTestCase(AsyncioTestCase): lambda e: e.tx.id == txid ) + def on_transaction_hash(self, tx_hash, ledger=None): + return (ledger or self.ledger).on_transaction.where( + lambda e: e.tx.hash == tx_hash + ) + def on_address_update(self, address): return self.ledger.on_transaction.where( lambda e: e.address == address @@ -316,7 +321,7 @@ class CommandTestCase(IntegrationTestCase): self.server_config = None self.server_storage = None self.extra_wallet_nodes = [] - self.extra_wallet_node_port = 5280 + self.extra_wallet_node_port = 5281 self.server_blob_manager = None self.server = None self.reflector = None diff --git a/lbry/wallet/__init__.py b/lbry/wallet/__init__.py index 7ed88527b..31c2ba60c 100644 --- a/lbry/wallet/__init__.py +++ b/lbry/wallet/__init__.py @@ -6,6 +6,7 @@ __node_url__ = ( ) __spvserver__ = 'lbry.wallet.server.coin.LBCRegTest' +from .bip32 import PubKey from .wallet import Wallet, WalletStorage, TimestampedPreferences, ENCRYPT_ON_DISK from .manager import WalletManager from .network import Network @@ -13,5 +14,4 @@ from .ledger import Ledger, RegTestLedger, TestNetLedger, BlockHeightEvent from .account import Account, AddressManager, SingleKey, HierarchicalDeterministic from .transaction import Transaction, Output, Input from .script import OutputScript, InputScript -from .database import SQLiteMixin, Database from .header import Headers diff --git a/lbry/wallet/account.py b/lbry/wallet/account.py index 3a3d4c3f3..2c9c32b8a 100644 --- a/lbry/wallet/account.py +++ b/lbry/wallet/account.py @@ -10,6 +10,8 @@ from hashlib import sha256 from string import hexdigits from typing import Type, Dict, Tuple, Optional, Any, List +from sqlalchemy import text + import ecdsa from lbry.error import InvalidPasswordError from lbry.crypto.crypt import aes_encrypt, aes_decrypt @@ -71,7 +73,6 @@ class AddressManager: def _query_addresses(self, **constraints): return self.account.ledger.db.get_addresses( - read_only=constraints.pop("read_only", False), accounts=[self.account], chain=self.chain_number, **constraints @@ -435,8 +436,8 @@ class Account: addresses.extend(new_addresses) return addresses - async def 
get_addresses(self, read_only=False, **constraints) -> List[str]: - rows = await self.ledger.db.select_addresses('address', read_only=read_only, accounts=[self], **constraints) + async def get_addresses(self, **constraints) -> List[str]: + rows = await self.ledger.db.select_addresses([text('account_address.address')], accounts=[self], **constraints) return [r['address'] for r in rows] def get_address_records(self, **constraints): @@ -452,13 +453,13 @@ class Account: def get_public_key(self, chain: int, index: int) -> PubKey: return self.address_managers[chain].get_public_key(index) - def get_balance(self, confirmations=0, include_claims=False, read_only=False, **constraints): + def get_balance(self, confirmations=0, include_claims=False, **constraints): if not include_claims: constraints.update({'txo_type__in': (TXO_TYPES['other'], TXO_TYPES['purchase'])}) if confirmations > 0: height = self.ledger.headers.height - (confirmations-1) constraints.update({'height__lte': height, 'height__gt': 0}) - return self.ledger.db.get_balance(accounts=[self], read_only=read_only, **constraints) + return self.ledger.db.get_balance(accounts=[self], **constraints) async def get_max_gap(self): change_gap = await self.change.get_max_gap() @@ -564,9 +565,9 @@ class Account: if gap_changed: self.wallet.save() - async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False, read_only=False): + async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False): tips_balance, supports_balance, claims_balance = 0, 0, 0 - get_total_balance = partial(self.get_balance, read_only=read_only, confirmations=confirmations, + get_total_balance = partial(self.get_balance, confirmations=confirmations, include_claims=True) total = await get_total_balance() if reserved_subtotals: @@ -594,14 +595,14 @@ class Account: } if reserved_subtotals else None } - def get_transaction_history(self, read_only=False, **constraints): + def get_transaction_history(self, **constraints): return self.ledger.get_transaction_history( - read_only=read_only, wallet=self.wallet, accounts=[self], **constraints + wallet=self.wallet, accounts=[self], **constraints ) - def get_transaction_history_count(self, read_only=False, **constraints): + def get_transaction_history_count(self, **constraints): return self.ledger.get_transaction_history_count( - read_only=read_only, wallet=self.wallet, accounts=[self], **constraints + wallet=self.wallet, accounts=[self], **constraints ) def get_claims(self, **constraints): diff --git a/lbry/wallet/database.py b/lbry/wallet/database.py index d3590d015..9f1371bd2 100644 --- a/lbry/wallet/database.py +++ b/lbry/wallet/database.py @@ -9,12 +9,6 @@ from contextvars import ContextVar from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.process import ProcessPoolExecutor from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional -from datetime import date - -from .bip32 import PubKey -from .transaction import Transaction, Output, OutputScript, TXRefImmutable -from .constants import TXO_TYPES, CLAIM_TYPES -from .util import date_to_julian_day log = logging.getLogger(__name__) @@ -389,697 +383,3 @@ def dict_row_factory(cursor, row): for idx, col in enumerate(cursor.description): d[col[0]] = row[idx] return d - - -class Database(SQLiteMixin): - - SCHEMA_VERSION = "1.3" - - PRAGMAS = """ - pragma journal_mode=WAL; - """ - - CREATE_ACCOUNT_TABLE = """ - create table if not exists account_address ( - account text not null, - address text not 
null, - chain integer not null, - pubkey blob not null, - chain_code blob not null, - n integer not null, - depth integer not null, - primary key (account, address) - ); - create index if not exists address_account_idx on account_address (address, account); - """ - - CREATE_PUBKEY_ADDRESS_TABLE = """ - create table if not exists pubkey_address ( - address text primary key, - history text, - used_times integer not null default 0 - ); - """ - - CREATE_TX_TABLE = """ - create table if not exists tx ( - txid text primary key, - raw blob not null, - height integer not null, - position integer not null, - is_verified boolean not null default 0, - purchased_claim_id text, - day integer - ); - create index if not exists tx_purchased_claim_id_idx on tx (purchased_claim_id); - """ - - CREATE_TXO_TABLE = """ - create table if not exists txo ( - txid text references tx, - txoid text primary key, - address text references pubkey_address, - position integer not null, - amount integer not null, - script blob not null, - is_reserved boolean not null default 0, - - txo_type integer not null default 0, - claim_id text, - claim_name text, - - channel_id text, - reposted_claim_id text - ); - create index if not exists txo_txid_idx on txo (txid); - create index if not exists txo_address_idx on txo (address); - create index if not exists txo_claim_id_idx on txo (claim_id, txo_type); - create index if not exists txo_claim_name_idx on txo (claim_name); - create index if not exists txo_txo_type_idx on txo (txo_type); - create index if not exists txo_channel_id_idx on txo (channel_id); - create index if not exists txo_reposted_claim_idx on txo (reposted_claim_id); - """ - - CREATE_TXI_TABLE = """ - create table if not exists txi ( - txid text references tx, - txoid text references txo primary key, - address text references pubkey_address, - position integer not null - ); - create index if not exists txi_address_idx on txi (address); - create index if not exists first_input_idx on txi (txid, address) where position=0; - """ - - CREATE_TABLES_QUERY = ( - PRAGMAS + - CREATE_ACCOUNT_TABLE + - CREATE_PUBKEY_ADDRESS_TABLE + - CREATE_TX_TABLE + - CREATE_TXO_TABLE + - CREATE_TXI_TABLE - ) - - async def open(self): - await super().open() - self.db.writer_connection.row_factory = dict_row_factory - - def txo_to_row(self, tx, txo): - row = { - 'txid': tx.id, - 'txoid': txo.id, - 'address': txo.get_address(self.ledger), - 'position': txo.position, - 'amount': txo.amount, - 'script': sqlite3.Binary(txo.script.source) - } - if txo.is_claim: - if txo.can_decode_claim: - claim = txo.claim - row['txo_type'] = TXO_TYPES.get(claim.claim_type, TXO_TYPES['stream']) - if claim.is_repost: - row['reposted_claim_id'] = claim.repost.reference.claim_id - if claim.is_signed: - row['channel_id'] = claim.signing_channel_id - else: - row['txo_type'] = TXO_TYPES['stream'] - elif txo.is_support: - row['txo_type'] = TXO_TYPES['support'] - elif txo.purchase is not None: - row['txo_type'] = TXO_TYPES['purchase'] - row['claim_id'] = txo.purchased_claim_id - if txo.script.is_claim_involved: - row['claim_id'] = txo.claim_id - row['claim_name'] = txo.claim_name - return row - - def tx_to_row(self, tx): - row = { - 'txid': tx.id, - 'raw': sqlite3.Binary(tx.raw), - 'height': tx.height, - 'position': tx.position, - 'is_verified': tx.is_verified, - 'day': tx.get_julian_day(self.ledger), - } - txos = tx.outputs - if len(txos) >= 2 and txos[1].can_decode_purchase_data: - txos[0].purchase = txos[1] - row['purchased_claim_id'] = txos[1].purchase_data.claim_id - 
-    async def open(self):
-        await super().open()
-        self.db.writer_connection.row_factory = dict_row_factory
-
-    def txo_to_row(self, tx, txo):
-        row = {
-            'txid': tx.id,
-            'txoid': txo.id,
-            'address': txo.get_address(self.ledger),
-            'position': txo.position,
-            'amount': txo.amount,
-            'script': sqlite3.Binary(txo.script.source)
-        }
-        if txo.is_claim:
-            if txo.can_decode_claim:
-                claim = txo.claim
-                row['txo_type'] = TXO_TYPES.get(claim.claim_type, TXO_TYPES['stream'])
-                if claim.is_repost:
-                    row['reposted_claim_id'] = claim.repost.reference.claim_id
-                if claim.is_signed:
-                    row['channel_id'] = claim.signing_channel_id
-            else:
-                row['txo_type'] = TXO_TYPES['stream']
-        elif txo.is_support:
-            row['txo_type'] = TXO_TYPES['support']
-        elif txo.purchase is not None:
-            row['txo_type'] = TXO_TYPES['purchase']
-            row['claim_id'] = txo.purchased_claim_id
-        if txo.script.is_claim_involved:
-            row['claim_id'] = txo.claim_id
-            row['claim_name'] = txo.claim_name
-        return row
-
-    def tx_to_row(self, tx):
-        row = {
-            'txid': tx.id,
-            'raw': sqlite3.Binary(tx.raw),
-            'height': tx.height,
-            'position': tx.position,
-            'is_verified': tx.is_verified,
-            'day': tx.get_julian_day(self.ledger),
-        }
-        txos = tx.outputs
-        if len(txos) >= 2 and txos[1].can_decode_purchase_data:
-            txos[0].purchase = txos[1]
-            row['purchased_claim_id'] = txos[1].purchase_data.claim_id
-        return row
-
-    async def insert_transaction(self, tx):
-        await self.db.execute_fetchall(*self._insert_sql('tx', self.tx_to_row(tx)))
-
-    async def update_transaction(self, tx):
-        await self.db.execute_fetchall(*self._update_sql("tx", {
-            'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
-        }, 'txid = ?', (tx.id,)))
-
-    def _transaction_io(self, conn: sqlite3.Connection, tx: Transaction, address, txhash):
-        conn.execute(*self._insert_sql('tx', self.tx_to_row(tx), replace=True)).fetchall()
-
-        is_my_input = False
-
-        for txi in tx.inputs:
-            if txi.txo_ref.txo is not None:
-                txo = txi.txo_ref.txo
-                if txo.has_address and txo.get_address(self.ledger) == address:
-                    is_my_input = True
-                    conn.execute(*self._insert_sql("txi", {
-                        'txid': tx.id,
-                        'txoid': txo.id,
-                        'address': address,
-                        'position': txi.position
-                    }, ignore_duplicate=True)).fetchall()
-
-        for txo in tx.outputs:
-            if txo.script.is_pay_pubkey_hash and (txo.pubkey_hash == txhash or is_my_input):
-                conn.execute(*self._insert_sql(
-                    "txo", self.txo_to_row(tx, txo), ignore_duplicate=True
-                )).fetchall()
-            elif txo.script.is_pay_script_hash:
-                # TODO: implement script hash payments
-                log.warning('Database.save_transaction_io: pay script hash is not implemented!')
-
-    def save_transaction_io(self, tx: Transaction, address, txhash, history):
-        return self.save_transaction_io_batch([tx], address, txhash, history)
-
-    def save_transaction_io_batch(self, txs: Iterable[Transaction], address, txhash, history):
-        history_count = history.count(':') // 2
-
-        def __many(conn):
-            for tx in txs:
-                self._transaction_io(conn, tx, address, txhash)
-            conn.execute(
-                "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
-                (history, history_count, address)
-            ).fetchall()
-
-        return self.db.run(__many)
-
-    async def reserve_outputs(self, txos, is_reserved=True):
-        txoids = ((is_reserved, txo.id) for txo in txos)
-        await self.db.executemany("UPDATE txo SET is_reserved = ? WHERE txoid = ?", txoids)
-
-    async def release_outputs(self, txos):
-        await self.reserve_outputs(txos, is_reserved=False)
-
-    async def rewind_blockchain(self, above_height):  # pylint: disable=no-self-use
-        # TODO:
-        # 1. delete transactions above_height
-        # 2. update address histories removing deleted TXs
-        return True
-
-    async def select_transactions(self, cols, accounts=None, read_only=False, **constraints):
-        if not {'txid', 'txid__in'}.intersection(constraints):
-            assert accounts, "'accounts' argument required when no 'txid' constraint is present"
-            where, values = constraints_to_sql({
-                '$$account_address.account__in': [a.public_key.address for a in accounts]
-            })
-            constraints['txid__in'] = f"""
-                SELECT txo.txid FROM txo JOIN account_address USING (address) WHERE {where}
-              UNION
-                SELECT txi.txid FROM txi JOIN account_address USING (address) WHERE {where}
-            """
-            constraints.update(values)
-        return await self.db.execute_fetchall(
-            *query(f"SELECT {cols} FROM tx", **constraints), read_only=read_only
-        )
-
-    TXO_NOT_MINE = Output(None, None, is_my_output=False)
-
-    async def get_transactions(self, wallet=None, **constraints):
-        include_is_spent = constraints.pop('include_is_spent', False)
-        include_is_my_input = constraints.pop('include_is_my_input', False)
-        include_is_my_output = constraints.pop('include_is_my_output', False)
-
-        tx_rows = await self.select_transactions(
-            'txid, raw, height, position, is_verified',
-            order_by=constraints.pop('order_by', ["height=0 DESC", "height DESC", "position DESC"]),
-            **constraints
-        )
-
-        if not tx_rows:
-            return []
-
-        txids, txs, txi_txoids = [], [], []
-        for row in tx_rows:
-            txids.append(row['txid'])
-            txs.append(Transaction(
-                raw=row['raw'], height=row['height'], position=row['position'],
-                is_verified=bool(row['is_verified'])
-            ))
-            for txi in txs[-1].inputs:
-                txi_txoids.append(txi.txo_ref.id)
-
-        step = self.MAX_QUERY_VARIABLES
-        annotated_txos = {}
-        for offset in range(0, len(txids), step):
-            annotated_txos.update({
-                txo.id: txo for txo in
-                (await self.get_txos(
-                    wallet=wallet,
-                    txid__in=txids[offset:offset+step], order_by='txo.txid',
-                    include_is_spent=include_is_spent,
-                    include_is_my_input=include_is_my_input,
-                    include_is_my_output=include_is_my_output,
-                ))
-            })
-
-        referenced_txos = {}
-        for offset in range(0, len(txi_txoids), step):
-            referenced_txos.update({
-                txo.id: txo for txo in
-                (await self.get_txos(
-                    wallet=wallet,
-                    txoid__in=txi_txoids[offset:offset+step], order_by='txo.txoid',
-                    include_is_my_output=include_is_my_output,
-                ))
-            })
-
-        for tx in txs:
-            for txi in tx.inputs:
-                txo = referenced_txos.get(txi.txo_ref.id)
-                if txo:
-                    txi.txo_ref = txo.ref
-            for txo in tx.outputs:
-                _txo = annotated_txos.get(txo.id)
-                if _txo:
-                    txo.update_annotations(_txo)
-                else:
-                    txo.update_annotations(self.TXO_NOT_MINE)
-
-        for tx in txs:
-            txos = tx.outputs
-            if len(txos) >= 2 and txos[1].can_decode_purchase_data:
-                txos[0].purchase = txos[1]
-
-        return txs
-
-    async def get_transaction_count(self, **constraints):
-        constraints.pop('wallet', None)
-        constraints.pop('offset', None)
-        constraints.pop('limit', None)
-        constraints.pop('order_by', None)
-        count = await self.select_transactions('COUNT(*) as total', **constraints)
-        return count[0]['total'] or 0
-
-    async def get_transaction(self, **constraints):
-        txs = await self.get_transactions(limit=1, **constraints)
-        if txs:
-            return txs[0]
-
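`get_transactions` above pages its `txid__in` lookups in steps of `MAX_QUERY_VARIABLES` because SQLite is commonly built with a 999 bound-variable ceiling, the same limit that `test_large_tx_doesnt_hit_variable_limits` guards further down. The pattern in isolation, with illustrative names:

```python
# Sketch of the paging pattern used for txid__in lookups, keeping each query
# under SQLite's default 999 bound-variable limit. `fetch_page` stands in for
# the real per-page query.
MAX_QUERY_VARIABLES = 900  # illustrative headroom below the 999 ceiling

async def fetch_in_batches(ids, fetch_page):
    results = []
    for offset in range(0, len(ids), MAX_QUERY_VARIABLES):
        results.extend(await fetch_page(ids[offset:offset + MAX_QUERY_VARIABLES]))
    return results
```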
-    async def select_txos(
-            self, cols, accounts=None, is_my_input=None, is_my_output=True,
-            is_my_input_or_output=None, exclude_internal_transfers=False,
-            include_is_spent=False, include_is_my_input=False,
-            is_spent=None, read_only=False, **constraints):
-        for rename_col in ('txid', 'txoid'):
-            for rename_constraint in (rename_col, rename_col+'__in', rename_col+'__not_in'):
-                if rename_constraint in constraints:
-                    constraints['txo.'+rename_constraint] = constraints.pop(rename_constraint)
-        if accounts:
-            account_in_sql, values = constraints_to_sql({
-                '$$account__in': [a.public_key.address for a in accounts]
-            })
-            my_addresses = f"SELECT address FROM account_address WHERE {account_in_sql}"
-            constraints.update(values)
-            if is_my_input_or_output:
-                include_is_my_input = True
-                constraints['received_or_sent__or'] = {
-                    'txo.address__in': my_addresses,
-                    'sent__and': {
-                        'txi.address__is_not_null': True,
-                        'txi.address__in': my_addresses
-                    }
-                }
-            else:
-                if is_my_output:
-                    constraints['txo.address__in'] = my_addresses
-                elif is_my_output is False:
-                    constraints['txo.address__not_in'] = my_addresses
-                if is_my_input:
-                    include_is_my_input = True
-                    constraints['txi.address__is_not_null'] = True
-                    constraints['txi.address__in'] = my_addresses
-                elif is_my_input is False:
-                    include_is_my_input = True
-                    constraints['is_my_input_false__or'] = {
-                        'txi.address__is_null': True,
-                        'txi.address__not_in': my_addresses
-                    }
-            if exclude_internal_transfers:
-                include_is_my_input = True
-                constraints['exclude_internal_payments__or'] = {
-                    'txo.txo_type__not': TXO_TYPES['other'],
-                    'txi.address__is_null': True,
-                    'txi.address__not_in': my_addresses
-                }
-        sql = [f"SELECT {cols} FROM txo JOIN tx ON (tx.txid=txo.txid)"]
-        if is_spent:
-            constraints['spent.txoid__is_not_null'] = True
-        elif is_spent is False:
-            constraints['is_reserved'] = False
-            constraints['spent.txoid__is_null'] = True
-        if include_is_spent or is_spent is not None:
-            sql.append("LEFT JOIN txi AS spent ON (spent.txoid=txo.txoid)")
-        if include_is_my_input:
-            sql.append("LEFT JOIN txi ON (txi.position=0 AND txi.txid=txo.txid)")
-        return await self.db.execute_fetchall(*query(' '.join(sql), **constraints), read_only=read_only)
-
-    async def get_txos(self, wallet=None, no_tx=False, read_only=False, **constraints):
-        include_is_spent = constraints.get('include_is_spent', False)
-        include_is_my_input = constraints.get('include_is_my_input', False)
-        include_is_my_output = constraints.pop('include_is_my_output', False)
-        include_received_tips = constraints.pop('include_received_tips', False)
-
-        select_columns = [
-            "tx.txid, raw, tx.height, tx.position as tx_position, tx.is_verified, "
-            "txo_type, txo.position as txo_position, amount, script"
-        ]
-
-        my_accounts = {a.public_key.address for a in wallet.accounts} if wallet else set()
-        my_accounts_sql = ""
-        if include_is_my_output or include_is_my_input:
-            my_accounts_sql, values = constraints_to_sql({'$$account__in#_wallet': my_accounts})
-            constraints.update(values)
-
-        if include_is_my_output and my_accounts:
-            if constraints.get('is_my_output', None) in (True, False):
-                select_columns.append(f"{1 if constraints['is_my_output'] else 0} AS is_my_output")
-            else:
-                select_columns.append(f"""(
-                    txo.address IN (SELECT address FROM account_address WHERE {my_accounts_sql})
-                ) AS is_my_output""")
-
-        if include_is_my_input and my_accounts:
-            if constraints.get('is_my_input', None) in (True, False):
-                select_columns.append(f"{1 if constraints['is_my_input'] else 0} AS is_my_input")
-            else:
-                select_columns.append(f"""(
-                    txi.address IS NOT NULL AND
-                    txi.address IN (SELECT address FROM account_address WHERE {my_accounts_sql})
-                ) AS is_my_input""")
-
-        if include_is_spent:
-            select_columns.append("spent.txoid IS NOT NULL AS is_spent")
-
-        if include_received_tips:
-            select_columns.append(f"""(
-            SELECT COALESCE(SUM(support.amount), 0) FROM txo AS support WHERE
-                support.claim_id = txo.claim_id AND
-                support.txo_type = {TXO_TYPES['support']} AND
-                support.address IN (SELECT address FROM account_address WHERE {my_accounts_sql}) AND
-                support.txoid NOT IN (SELECT txoid FROM txi)
-            ) AS received_tips""")
-
-        if 'order_by' not in constraints or constraints['order_by'] == 'height':
-            constraints['order_by'] = [
-                "tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position"
-            ]
-        elif constraints.get('order_by', None) == 'none':
-            del constraints['order_by']
-
-        rows = await self.select_txos(', '.join(select_columns), read_only=read_only, **constraints)
-
-        txos = []
-        txs = {}
-        for row in rows:
-            if no_tx:
-                txo = Output(
-                    amount=row['amount'],
-                    script=OutputScript(row['script']),
-                    tx_ref=TXRefImmutable.from_id(row['txid'], row['height']),
-                    position=row['txo_position']
-                )
-            else:
-                if row['txid'] not in txs:
-                    txs[row['txid']] = Transaction(
-                        row['raw'], height=row['height'], position=row['tx_position'],
-                        is_verified=bool(row['is_verified'])
-                    )
-                txo = txs[row['txid']].outputs[row['txo_position']]
-            if include_is_spent:
-                txo.is_spent = bool(row['is_spent'])
-            if include_is_my_input:
-                txo.is_my_input = bool(row['is_my_input'])
-            if include_is_my_output:
-                txo.is_my_output = bool(row['is_my_output'])
-            if include_is_my_input and include_is_my_output:
-                if txo.is_my_input and txo.is_my_output and row['txo_type'] == TXO_TYPES['other']:
-                    txo.is_internal_transfer = True
-                else:
-                    txo.is_internal_transfer = False
-            if include_received_tips:
-                txo.received_tips = row['received_tips']
-            txos.append(txo)
-
-        channel_ids = set()
-        for txo in txos:
-            if txo.is_claim and txo.can_decode_claim:
-                if txo.claim.is_signed:
-                    channel_ids.add(txo.claim.signing_channel_id)
-                if txo.claim.is_channel and wallet:
-                    for account in wallet.accounts:
-                        private_key = account.get_channel_private_key(
-                            txo.claim.channel.public_key_bytes
-                        )
-                        if private_key:
-                            txo.private_key = private_key
-                            break
-
-        if channel_ids:
-            channels = {
-                txo.claim_id: txo for txo in
-                (await self.get_channels(
-                    wallet=wallet,
-                    claim_id__in=channel_ids,
-                    read_only=read_only
-                ))
-            }
-            for txo in txos:
-                if txo.is_claim and txo.can_decode_claim:
-                    txo.channel = channels.get(txo.claim.signing_channel_id, None)
-
-        return txos
-
-    @staticmethod
-    def _clean_txo_constraints_for_aggregation(constraints):
-        constraints.pop('include_is_spent', None)
-        constraints.pop('include_is_my_input', None)
-        constraints.pop('include_is_my_output', None)
-        constraints.pop('include_received_tips', None)
-        constraints.pop('wallet', None)
-        constraints.pop('resolve', None)
-        constraints.pop('offset', None)
-        constraints.pop('limit', None)
-        constraints.pop('order_by', None)
-
-    async def get_txo_count(self, **constraints):
-        self._clean_txo_constraints_for_aggregation(constraints)
-        count = await self.select_txos('COUNT(*) AS total', **constraints)
-        return count[0]['total'] or 0
-
-    async def get_txo_sum(self, **constraints):
-        self._clean_txo_constraints_for_aggregation(constraints)
-        result = await self.select_txos('SUM(amount) AS total', **constraints)
-        return result[0]['total'] or 0
-
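The `select_*` helpers build their WHERE clauses from keyword constraints: a double-underscore suffix picks the SQL operator (`__in`, `__gte`, `__is_not_null`, ...) and an `__or` key groups a nested dict into a parenthesized OR. A rough sketch of a call, assuming the semantics of the old `query`/`constraints_to_sql` helpers (imported by `test_database.py` below); the generated parameter names are an internal detail and not shown:

```python
# Rough sketch of the keyword-constraint convention; the comments describe
# the approximate SQL each key produces.
from lbry.wallet.database import query

sql, params = query(
    "SELECT txo.txoid FROM txo JOIN tx ON (tx.txid=txo.txid)",
    **{
        'tx.height__gte': 100,            # tx.height >= 100
        'txo.address__in': ['a1', 'a2'],  # txo.address IN ('a1', 'a2')
    }
)
```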
-    async def get_txo_plot(self, start_day=None, days_back=0, end_day=None, days_after=None, **constraints):
-        self._clean_txo_constraints_for_aggregation(constraints)
-        if start_day is None:
-            constraints['day__gte'] = self.ledger.headers.estimated_julian_day(
-                self.ledger.headers.height
-            ) - days_back
-        else:
-            constraints['day__gte'] = date_to_julian_day(
-                date.fromisoformat(start_day)
-            )
-        if end_day is not None:
-            constraints['day__lte'] = date_to_julian_day(
-                date.fromisoformat(end_day)
-            )
-        elif days_after is not None:
-            constraints['day__lte'] = constraints['day__gte'] + days_after
-        return await self.select_txos(
-            "DATE(day) AS day, SUM(amount) AS total",
-            group_by='day', order_by='day', **constraints
-        )
-
-    def get_utxos(self, read_only=False, **constraints):
-        return self.get_txos(is_spent=False, read_only=read_only, **constraints)
-
-    def get_utxo_count(self, **constraints):
-        return self.get_txo_count(is_spent=False, **constraints)
-
-    async def get_balance(self, wallet=None, accounts=None, read_only=False, **constraints):
-        assert wallet or accounts, \
-            "'wallet' or 'accounts' constraints required to calculate balance"
-        constraints['accounts'] = accounts or wallet.accounts
-        balance = await self.select_txos(
-            'SUM(amount) as total', is_spent=False, read_only=read_only, **constraints
-        )
-        return balance[0]['total'] or 0
-
-    async def select_addresses(self, cols, read_only=False, **constraints):
-        return await self.db.execute_fetchall(*query(
-            f"SELECT {cols} FROM pubkey_address JOIN account_address USING (address)",
-            **constraints
-        ), read_only=read_only)
-
-    async def get_addresses(self, cols=None, read_only=False, **constraints):
-        cols = cols or (
-            'address', 'account', 'chain', 'history', 'used_times',
-            'pubkey', 'chain_code', 'n', 'depth'
-        )
-        addresses = await self.select_addresses(', '.join(cols), read_only=read_only, **constraints)
-        if 'pubkey' in cols:
-            for address in addresses:
-                address['pubkey'] = PubKey(
-                    self.ledger, address.pop('pubkey'), address.pop('chain_code'),
-                    address.pop('n'), address.pop('depth')
-                )
-        return addresses
-
-    async def get_address_count(self, cols=None, read_only=False, **constraints):
-        count = await self.select_addresses('COUNT(*) as total', read_only=read_only, **constraints)
-        return count[0]['total'] or 0
-
-    async def get_address(self, read_only=False, **constraints):
-        addresses = await self.get_addresses(read_only=read_only, limit=1, **constraints)
-        if addresses:
-            return addresses[0]
-
-    async def add_keys(self, account, chain, pubkeys):
-        await self.db.executemany(
-            "insert or ignore into account_address "
-            "(account, address, chain, pubkey, chain_code, n, depth) values "
-            "(?, ?, ?, ?, ?, ?, ?)", ((
-                account.id, k.address, chain,
-                sqlite3.Binary(k.pubkey_bytes),
-                sqlite3.Binary(k.chain_code),
-                k.n, k.depth
-            ) for k in pubkeys)
-        )
-        await self.db.executemany(
-            "insert or ignore into pubkey_address (address) values (?)",
-            ((pubkey.address,) for pubkey in pubkeys)
-        )
-
-    async def _set_address_history(self, address, history):
-        await self.db.execute_fetchall(
-            "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
-            (history, history.count(':')//2, address)
-        )
-
-    async def set_address_history(self, address, history):
-        await self._set_address_history(address, history)
-
-    @staticmethod
-    def constrain_purchases(constraints):
-        accounts = constraints.pop('accounts', None)
-        assert accounts, "'accounts' argument required to find purchases"
-        if not {'purchased_claim_id', 'purchased_claim_id__in'}.intersection(constraints):
-            constraints['purchased_claim_id__is_not_null'] = True
-        constraints.update({
-            f'$account{i}': a.public_key.address for i, a in enumerate(accounts)
-        })
-        account_values = ', '.join([f':$account{i}' for i in range(len(accounts))])
-        constraints['txid__in'] = f"""
-            SELECT txid FROM txi JOIN account_address USING (address)
-            WHERE account_address.account IN ({account_values})
-        """
-
-    async def get_purchases(self, **constraints):
-        self.constrain_purchases(constraints)
-        return [tx.outputs[0] for tx in await self.get_transactions(**constraints)]
-
-    def get_purchase_count(self, **constraints):
-        self.constrain_purchases(constraints)
-        return self.get_transaction_count(**constraints)
-
-    @staticmethod
-    def constrain_claims(constraints):
-        if {'txo_type', 'txo_type__in'}.intersection(constraints):
-            return
-        claim_types = constraints.pop('claim_type', None)
-        if claim_types:
-            constrain_single_or_list(
-                constraints, 'txo_type', claim_types, lambda x: TXO_TYPES[x]
-            )
-        else:
-            constraints['txo_type__in'] = CLAIM_TYPES
-
-    async def get_claims(self, read_only=False, **constraints) -> List[Output]:
-        self.constrain_claims(constraints)
-        return await self.get_utxos(read_only=read_only, **constraints)
-
-    def get_claim_count(self, **constraints):
-        self.constrain_claims(constraints)
-        return self.get_utxo_count(**constraints)
-
-    @staticmethod
-    def constrain_streams(constraints):
-        constraints['txo_type'] = TXO_TYPES['stream']
-
-    def get_streams(self, read_only=False, **constraints):
-        self.constrain_streams(constraints)
-        return self.get_claims(read_only=read_only, **constraints)
-
-    def get_stream_count(self, **constraints):
-        self.constrain_streams(constraints)
-        return self.get_claim_count(**constraints)
-
-    @staticmethod
-    def constrain_channels(constraints):
-        constraints['txo_type'] = TXO_TYPES['channel']
-
-    def get_channels(self, **constraints):
-        self.constrain_channels(constraints)
-        return self.get_claims(**constraints)
-
-    def get_channel_count(self, **constraints):
-        self.constrain_channels(constraints)
-        return self.get_claim_count(**constraints)
-
-    @staticmethod
-    def constrain_supports(constraints):
-        constraints['txo_type'] = TXO_TYPES['support']
-
-    def get_supports(self, **constraints):
-        self.constrain_supports(constraints)
-        return self.get_utxos(**constraints)
-
-    def get_support_count(self, **constraints):
-        self.constrain_supports(constraints)
-        return self.get_utxo_count(**constraints)
-
-    @staticmethod
-    def constrain_collections(constraints):
-        constraints['txo_type'] = TXO_TYPES['collection']
-
-    def get_collections(self, **constraints):
-        self.constrain_collections(constraints)
-        return self.get_utxos(**constraints)
-
-    def get_collection_count(self, **constraints):
-        self.constrain_collections(constraints)
-        return self.get_utxo_count(**constraints)
-
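The `constrain_*` helpers reduce friendly filters to integer `txo_type` constraints before delegating to the UTXO queries. A sketch of the `claim_type` mapping step, with made-up type codes standing in for the real `TXO_TYPES` values:

```python
# Sketch: how a human-friendly claim_type filter becomes a txo_type constraint.
# Integer codes are illustrative, not the real TXO_TYPES values.
TXO_TYPES = {'stream': 1, 'channel': 2, 'support': 3, 'purchase': 4, 'collection': 5}

def constrain_claim_type(constraints, claim_type):
    if isinstance(claim_type, list):
        constraints['txo_type__in'] = [TXO_TYPES[t] for t in claim_type]
    else:
        constraints['txo_type'] = TXO_TYPES[claim_type]

constraints = {'is_spent': False}
constrain_claim_type(constraints, ['stream', 'channel'])
# -> {'is_spent': False, 'txo_type__in': [1, 2]}
```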
-    async def release_all_outputs(self, account):
-        await self.db.execute_fetchall(
-            "UPDATE txo SET is_reserved = 0 WHERE"
-            "  is_reserved = 1 AND txo.address IN ("
-            "    SELECT address from account_address WHERE account = ?"
-            "  )", (account.public_key.address, )
-        )
-
-    def get_supports_summary(self, read_only=False, **constraints):
-        return self.get_txos(
-            txo_type=TXO_TYPES['support'],
-            is_spent=False, is_my_output=True,
-            include_is_my_input=True,
-            no_tx=True, read_only=read_only,
-            **constraints
-        )
diff --git a/lbry/wallet/header.py b/lbry/wallet/header.py
index fd5a85c61..0e51257ca 100644
--- a/lbry/wallet/header.py
+++ b/lbry/wallet/header.py
@@ -12,7 +12,7 @@
 from typing import Optional, Iterator, Tuple, Callable
 from binascii import hexlify, unhexlify
 
 from lbry.crypto.hash import sha512, double_sha256, ripemd160
-from lbry.wallet.util import ArithUint256, date_to_julian_day
+from lbry.wallet.util import ArithUint256
 
 from .checkpoints import HASHES
@@ -140,8 +140,8 @@ class Headers:
             return
         return int(self.first_block_timestamp + (height * self.timestamp_average_offset))
 
-    def estimated_julian_day(self, height):
-        return date_to_julian_day(date.fromtimestamp(self.estimated_timestamp(height)))
+    def estimated_date(self, height):
+        return date.fromtimestamp(self.estimated_timestamp(height))
 
     async def get_raw_header(self, height) -> bytes:
         if self.chunk_getter:
diff --git a/lbry/wallet/ledger.py b/lbry/wallet/ledger.py
index 79e7ef6b2..0de292e7b 100644
--- a/lbry/wallet/ledger.py
+++ b/lbry/wallet/ledger.py
@@ -16,8 +16,9 @@
 from lbry.schema.url import URL
 from lbry.crypto.hash import hash160, double_sha256, sha256
 from lbry.crypto.base58 import Base58
 
+from lbry.db import Database, AccountAddress
+
 from .tasks import TaskGroup
-from .database import Database
 from .stream import StreamController
 from .dewies import dewies_to_lbc
 from .account import Account, AddressManager, SingleKey
@@ -508,7 +509,7 @@ class Ledger(metaclass=LedgerRegistry):
             else:
                 check_local = (txid, remote_height) not in we_need
             cache_tasks.append(loop.create_task(
-                self.cache_transaction(txid, remote_height, check_local=check_local)
+                self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local)
             ))
 
         synced_txs = []
@@ -519,18 +520,18 @@
             for txi in tx.inputs:
                 if txi.txo_ref.txo is not None:
                     continue
-                cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
+                cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash)
                 if cache_item is not None:
                     if cache_item.tx is None:
                         await cache_item.has_tx.wait()
                     assert cache_item.tx is not None
                     txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
                 else:
-                    check_db_for_txos.append(txi.txo_ref.id)
+                    check_db_for_txos.append(txi.txo_ref.hash)
 
         referenced_txos = {} if not check_db_for_txos else {
             txo.id: txo for txo in await self.db.get_txos(
-                txoid__in=check_db_for_txos, order_by='txo.txoid', no_tx=True
+                txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True
             )
         }
@@ -574,10 +575,10 @@
         else:
             return True
 
-    async def cache_transaction(self, txid, remote_height, check_local=True):
-        cache_item = self._tx_cache.get(txid)
+    async def cache_transaction(self, tx_hash, remote_height, check_local=True):
+        cache_item = self._tx_cache.get(tx_hash)
         if cache_item is None:
-            cache_item = self._tx_cache[txid] = TransactionCacheItem()
+            cache_item = self._tx_cache[tx_hash] = TransactionCacheItem()
         elif cache_item.tx is not None and \
                 cache_item.tx.height >= remote_height and \
                 (cache_item.tx.is_verified or remote_height < 1):
@@ -585,11 +586,11 @@
 
         try:
             cache_item.pending_verifications += 1
-            return await self._update_cache_item(cache_item, txid, remote_height, check_local)
+            return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local)
         finally:
            cache_item.pending_verifications -= 1
 
-    async def _update_cache_item(self, cache_item, txid, remote_height, check_local=True):
+    async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True):
 
         async with cache_item.lock:
@@ -597,13 +598,13 @@
 
             if tx is None and check_local:
                 # check local db
-                tx = cache_item.tx = await self.db.get_transaction(txid=txid)
+                tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash)
 
             merkle = None
             if tx is None:
                 # fetch from network
                 _raw, merkle = await self.network.retriable_call(
-                    self.network.get_transaction_and_merkle, txid, remote_height
+                    self.network.get_transaction_and_merkle, tx_hash, remote_height
                 )
                 tx = Transaction(unhexlify(_raw), height=merkle.get('block_height'))
                 cache_item.tx = tx  # make sure it's saved before caching it
@@ -612,16 +613,16 @@
 
     async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
         tx.height = remote_height
-        cached = self._tx_cache.get(tx.id)
+        cached = self._tx_cache.get(tx.hash)
         if not cached:
             # cache txs looked up by transaction_show too
             cached = TransactionCacheItem()
             cached.tx = tx
-            self._tx_cache[tx.id] = cached
+            self._tx_cache[tx.hash] = cached
         if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
             # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
             if not merkle:
-                merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height)
+                merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height)
             merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
             header = await self.headers.get(remote_height)
             tx.position = merkle['pos']
@@ -703,7 +704,7 @@
                 txo.purchased_claim_id: txo for txo in
                 await self.db.get_purchases(
                     accounts=accounts,
-                    purchased_claim_id__in=[c.claim_id for c in priced_claims]
+                    purchased_claim_hash__in=[c.claim_hash for c in priced_claims]
                 )
             }
             for txo in txos:
@@ -808,7 +809,7 @@
 
     async def _reset_balance_cache(self, e: TransactionEvent):
         account_ids = [
-            r['account'] for r in await self.db.get_addresses(('account',), address=e.address)
+            r['account'] for r in await self.db.get_addresses([AccountAddress.c.account], address=e.address)
         ]
         for account_id in account_ids:
             if account_id in self._balance_cache:
@@ -917,10 +918,10 @@
     def get_support_count(self, **constraints):
         return self.db.get_support_count(**constraints)
 
-    async def get_transaction_history(self, read_only=False, **constraints):
+    async def get_transaction_history(self, **constraints):
         txs: List[Transaction] = await self.db.get_transactions(
             include_is_my_output=True, include_is_spent=True,
-            read_only=read_only, **constraints
+            **constraints
         )
         headers = self.headers
         history = []
@@ -1030,8 +1031,8 @@
             history.append(item)
         return history
 
-    def get_transaction_history_count(self, read_only=False, **constraints):
-        return self.db.get_transaction_count(read_only=read_only, **constraints)
+    def get_transaction_history_count(self, **constraints):
+        return self.db.get_transaction_count(**constraints)
 
     async def get_detailed_balance(self, accounts, confirmations=0):
         result = {
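The theme of these ledger hunks is that caches and queries are now keyed by raw little-endian transaction hashes instead of hex txid strings. The two representations differ only by byte order, so they convert by reversing:

```python
from binascii import hexlify, unhexlify

txid = 'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9'
tx_hash = unhexlify(txid)[::-1]                  # hex txid -> raw hash bytes
assert hexlify(tx_hash[::-1]).decode() == txid   # raw hash -> hex txid
```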
diff --git a/lbry/wallet/manager.py b/lbry/wallet/manager.py
index 20658a2e5..40f14b762 100644
--- a/lbry/wallet/manager.py
+++ b/lbry/wallet/manager.py
@@ -14,11 +14,11 @@
 from .dewies import dewies_to_lbc
 from .account import Account
 from .ledger import Ledger, LedgerRegistry
 from .transaction import Transaction, Output
-from .database import Database
 from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK
 from .rpc.jsonrpc import CodeMessageError
 
 if typing.TYPE_CHECKING:
+    from lbry.db import Database
     from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
@@ -109,7 +109,7 @@
         return self.default_account.ledger
 
     @property
-    def db(self) -> Database:
+    def db(self) -> 'Database':
         return self.ledger.db
 
     def check_locked(self):
@@ -256,12 +256,12 @@
     def get_unused_address(self):
         return self.default_account.receiving.get_or_create_usable_address()
 
-    async def get_transaction(self, txid: str):
-        tx = await self.db.get_transaction(txid=txid)
+    async def get_transaction(self, tx_hash: bytes):
+        tx = await self.db.get_transaction(tx_hash=tx_hash)
         if tx:
             return tx
         try:
-            raw, merkle = await self.ledger.network.get_transaction_and_merkle(txid)
+            raw, merkle = await self.ledger.network.get_transaction_and_merkle(tx_hash)
         except CodeMessageError as e:
             if 'No such mempool or blockchain transaction.' in e.message:
                 return {'success': False, 'code': 404, 'message': 'transaction not found'}
diff --git a/lbry/wallet/network.py b/lbry/wallet/network.py
index b117a0164..e99e56567 100644
--- a/lbry/wallet/network.py
+++ b/lbry/wallet/network.py
@@ -4,6 +4,7 @@
 import json
 from time import perf_counter
 from operator import itemgetter
 from typing import Dict, Optional, Tuple
+from binascii import hexlify
 
 from lbry import __version__
 from lbry.error import IncompatibleWalletServerError
@@ -254,20 +255,20 @@
     def get_transaction(self, tx_hash, known_height=None):
         # use any server if its old, otherwise restrict to who gave us the history
         restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
-        return self.rpc('blockchain.transaction.get', [tx_hash], restricted)
+        return self.rpc('blockchain.transaction.get', [hexlify(tx_hash[::-1]).decode()], restricted)
 
     def get_transaction_and_merkle(self, tx_hash, known_height=None):
         # use any server if its old, otherwise restrict to who gave us the history
         restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
-        return self.rpc('blockchain.transaction.info', [tx_hash], restricted)
+        return self.rpc('blockchain.transaction.info', [hexlify(tx_hash[::-1]).decode()], restricted)
 
     def get_transaction_height(self, tx_hash, known_height=None):
         restricted = not known_height or 0 > known_height > self.remote_height - 10
-        return self.rpc('blockchain.transaction.get_height', [tx_hash], restricted)
+        return self.rpc('blockchain.transaction.get_height', [hexlify(tx_hash[::-1]).decode()], restricted)
 
     def get_merkle(self, tx_hash, height):
         restricted = 0 > height > self.remote_height - 10
-        return self.rpc('blockchain.transaction.get_merkle', [tx_hash, height], restricted)
+        return self.rpc('blockchain.transaction.get_merkle', [hexlify(tx_hash[::-1]).decode(), height], restricted)
 
     def get_headers(self, height, count=10000, b64=False):
         restricted = height >= self.remote_height - 100
diff --git a/lbry/wallet/orchstr8/node.py b/lbry/wallet/orchstr8/node.py
index 83436eb2d..be8ff6116 100644
--- a/lbry/wallet/orchstr8/node.py
+++ b/lbry/wallet/orchstr8/node.py
@@ -13,6 +13,7 @@
 from typing import Type, Optional
 import urllib.request
 
 import lbry
+from lbry.db import Database
 from lbry.wallet.server.server import Server
 from lbry.wallet.server.env import Env
 from lbry.wallet import Wallet, Ledger, RegTestLedger, WalletManager, Account, BlockHeightEvent
@@ -125,12 +126,24 @@
         wallet_file_name = os.path.join(wallets_dir, 'my_wallet.json')
         with open(wallet_file_name, 'w') as wallet_file:
             wallet_file.write('{"version": 1, "accounts": []}\n')
+        db_driver = os.environ.get('TEST_DB', 'sqlite')
+        if db_driver == 'sqlite':
+            db = 'sqlite:///'+os.path.join(self.data_path, self.ledger_class.get_id(), 'blockchain.db')
+        elif db_driver == 'postgres':
+            db_name = f'lbry_test_{self.port}'
+            meta_db = Database('postgres:///postgres')
+            await meta_db.drop(db_name)
+            await meta_db.create(db_name)
+            db = f'postgres:///{db_name}'
+        else:
+            raise RuntimeError(f"Unsupported database driver: {db_driver}")
         self.manager = self.manager_class.from_config({
             'ledgers': {
                 self.ledger_class.get_id(): {
                     'api_port': self.port,
                     'default_servers': [(spv_node.hostname, spv_node.port)],
-                    'data_path': self.data_path
+                    'data_path': self.data_path,
+                    'db': Database(db)
                 }
             },
             'wallets': [wallet_file_name]
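The `WalletNode` change above derives the test database backend from the `TEST_DB` environment variable that the CI matrix and tox now pass through. Extracted as a standalone sketch (the database name is illustrative):

```python
# Sketch: choosing a connection URL per backend, mirroring the WalletNode
# logic above; 'lbry_test' is an illustrative database name.
import os

def database_url(data_path):
    driver = os.environ.get('TEST_DB', 'sqlite')
    if driver == 'sqlite':
        return 'sqlite:///' + os.path.join(data_path, 'blockchain.db')
    if driver == 'postgres':
        return 'postgres:///lbry_test'
    raise RuntimeError(f"Unsupported database driver: {driver}")
```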
diff --git a/lbry/wallet/transaction.py b/lbry/wallet/transaction.py
index 8536ddad2..fa3802049 100644
--- a/lbry/wallet/transaction.py
+++ b/lbry/wallet/transaction.py
@@ -268,6 +268,10 @@ class Output(InputOutput):
     def id(self):
         return self.ref.id
 
+    @property
+    def hash(self):
+        return self.ref.hash
+
     @property
     def pubkey_hash(self):
         return self.script.values['pubkey_hash']
@@ -477,6 +481,13 @@ class Output(InputOutput):
         if self.purchased_claim is not None:
             return self.purchased_claim.claim_id
 
+    @property
+    def purchased_claim_hash(self):
+        if self.purchase is not None:
+            return self.purchase.purchase_data.claim_hash
+        if self.purchased_claim is not None:
+            return self.purchased_claim.claim_hash
+
     @property
     def has_price(self):
         if self.can_decode_claim:
@@ -536,9 +547,9 @@ class Transaction:
     def hash(self):
         return self.ref.hash
 
-    def get_julian_day(self, ledger):
+    def get_ordinal_day(self, ledger):
         if self._day is None and self.height > 0:
-            self._day = ledger.headers.estimated_julian_day(self.height)
+            self._day = ledger.headers.estimated_date(self.height).toordinal()
         return self._day
 
     @property
diff --git a/lbry/wallet/util.py b/lbry/wallet/util.py
index cb57bc694..a9504bff1 100644
--- a/lbry/wallet/util.py
+++ b/lbry/wallet/util.py
@@ -3,10 +3,6 @@
 from typing import TypeVar, Sequence, Optional
 
 from .constants import COIN
 
 
-def date_to_julian_day(d):
-    return d.toordinal() + 1721424.5
-
-
 def coins_to_satoshis(coins):
     if not isinstance(coins, str):
         raise ValueError("{coins} must be a string")
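`date_to_julian_day` disappears because the tx `day` column only needs a monotonically increasing integer, which `date.toordinal()` already provides; the Julian offset added nothing but a float. The relationship between the old and new values:

```python
from datetime import date

d = date(2020, 1, 1)
ordinal = d.toordinal()        # 737425: what get_ordinal_day now stores
julian = ordinal + 1721424.5   # 2458849.5: what date_to_julian_day returned
```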
diff --git a/tests/integration/blockchain/test_blockchain.py b/tests/integration/blockchain/test_blockchain.py
new file mode 100644
index 000000000..a5ff17a5d
--- /dev/null
+++ b/tests/integration/blockchain/test_blockchain.py
@@ -0,0 +1,87 @@
+import os
+import time
+import asyncio
+import logging
+from binascii import unhexlify, hexlify
+from random import choice
+
+from lbry.testcase import AsyncioTestCase
+from lbry.crypto.base58 import Base58
+from lbry.blockchain import Lbrycrd, BlockchainSync
+from lbry.db import Database
+from lbry.blockchain.block import Block
+from lbry.schema.claim import Stream
+from lbry.wallet.transaction import Transaction, Output
+from lbry.wallet.constants import CENT
+from lbry.wallet.bcd_data_stream import BCDataStream
+
+#logging.getLogger('lbry.blockchain').setLevel(logging.DEBUG)
+log = logging.getLogger(__name__)
+
+
+class TestBlockchain(AsyncioTestCase):
+
+    async def asyncSetUp(self):
+        await super().asyncSetUp()
+        self.chain = Lbrycrd.temp_regtest()
+        await self.chain.ensure()
+        await self.chain.start('-maxblockfilesize=8', '-rpcworkqueue=128')
+        self.addCleanup(self.chain.stop, False)
+
+    async def test_block_event(self):
+        msgs = []
+
+        self.chain.subscribe()
+        self.chain.on_block.listen(lambda e: msgs.append(e['msg']))
+        res = await self.chain.generate(5)
+        await self.chain.on_block.where(lambda e: e['msg'] == 4)
+        self.assertEqual([0, 1, 2, 3, 4], msgs)
+        self.assertEqual(5, len(res))
+
+        self.chain.unsubscribe()
+        res = await self.chain.generate(2)
+        self.assertEqual(2, len(res))
+        await asyncio.sleep(0.1)  # give some time to "miss" the new block events
+
+        self.chain.subscribe()
+        res = await self.chain.generate(3)
+        await self.chain.on_block.where(lambda e: e['msg'] == 9)
+        self.assertEqual(3, len(res))
+        self.assertEqual([0, 1, 2, 3, 4, 7, 8, 9], msgs)
+
+    async def test_sync(self):
+        names = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']
+        await self.chain.generate(101)
+        address = Base58.decode(await self.chain.get_new_address())
+        for _ in range(190):
+            tx = Transaction().add_outputs([
+                Output.pay_claim_name_pubkey_hash(
+                    CENT, f'{choice(names)}{i}',
+                    Stream().update(
+                        title='a claim title',
+                        description='Lorem ipsum '*400,
+                        tags=['crypto', 'health', 'space'],
+                    ).claim,
+                    address)
+                for i in range(1, 20)
+            ])
+            funded = await self.chain.fund_raw_transaction(hexlify(tx.raw).decode())
+            signed = await self.chain.sign_raw_transaction_with_wallet(funded['hex'])
+            await self.chain.send_raw_transaction(signed['hex'])
+            await self.chain.generate(1)
+
+        self.assertEqual(
+            [(0, 191, 280), (1, 89, 178), (2, 12, 24)],
+            [(file['file_number'], file['blocks'], file['txs'])
+             for file in await self.chain.get_block_files()]
+        )
+        self.assertEqual(191, len(await self.chain.get_file_details(0)))
+
+        db = Database('sqlite:///'+os.path.join(self.chain.actual_data_dir, 'lbry.db'))
+        self.addCleanup(db.close)
+        await db.open()
+
+        sync = BlockchainSync(self.chain, use_process_pool=False)
+        await sync.load_blocks()
diff --git a/tests/integration/blockchain/test_network.py b/tests/integration/blockchain/test_network.py
index eacd0d0e6..0c0b47c79 100644
--- a/tests/integration/blockchain/test_network.py
+++ b/tests/integration/blockchain/test_network.py
@@ -1,8 +1,9 @@
 import asyncio
-import lbry
 from unittest.mock import Mock
+from binascii import unhexlify
 
+import lbry
 from lbry.wallet.network import Network
 from lbry.wallet.orchstr8.node import SPVNode
 from lbry.wallet.rpc import RPCSession
@@ -100,15 +101,15 @@ class ReconnectTests(IntegrationTestCase):
         # disconnect and send a new tx, should reconnect and get it
         self.ledger.network.client.connection_lost(Exception())
         self.assertFalse(self.ledger.network.is_connected)
-        sendtxid = await self.blockchain.send_to_address(address1, 1.1337)
-        await asyncio.wait_for(self.on_transaction_id(sendtxid), 1.0)  # mempool
+        tx_hash = unhexlify((await self.blockchain.send_to_address(address1, 1.1337)))[::-1]
+        await asyncio.wait_for(self.on_transaction_hash(tx_hash), 2.0)  # mempool
         await self.blockchain.generate(1)
-        await self.on_transaction_id(sendtxid)  # confirmed
+        await self.on_transaction_hash(tx_hash)  # confirmed
         self.assertLess(self.ledger.network.client.response_time, 1)  # response time properly set lower, we are fine
 
         await self.assertBalance(self.account, '1.1337')
         # is it real? are we rich!? let me see this tx...
-        d = self.ledger.network.get_transaction(sendtxid)
+        d = self.ledger.network.get_transaction(tx_hash)
         # what's that smoke on my ethernet cable? oh no!
         master_client = self.ledger.network.client
         self.ledger.network.client.connection_lost(Exception())
@@ -117,15 +118,15 @@
         self.assertIsNone(master_client.response_time)  # response time unknown as it failed
         # rich but offline? no way, no water, let's retry
         with self.assertRaisesRegex(ConnectionError, 'connection is not available'):
-            await self.ledger.network.get_transaction(sendtxid)
+            await self.ledger.network.get_transaction(tx_hash)
         # * goes to pick some water outside... * time passes by and another donation comes in
-        sendtxid = await self.blockchain.send_to_address(address1, 42)
+        tx_hash = unhexlify((await self.blockchain.send_to_address(address1, 42)))[::-1]
         await self.blockchain.generate(1)
         # (this is just so the test doesn't hang forever if it doesn't reconnect)
         if not self.ledger.network.is_connected:
             await asyncio.wait_for(self.ledger.network.on_connected.first, timeout=1.0)
         # omg, the burned cable still works! torba is fire proof!
-        await self.ledger.network.get_transaction(sendtxid)
+        await self.ledger.network.get_transaction(tx_hash)
 
     async def test_timeout_then_reconnect(self):
         # tests that it connects back after some failed attempts
diff --git a/tests/integration/blockchain/test_transactions.py b/tests/integration/blockchain/test_transactions.py
index 6a0ace201..f399d327b 100644
--- a/tests/integration/blockchain/test_transactions.py
+++ b/tests/integration/blockchain/test_transactions.py
@@ -46,13 +46,13 @@ class BasicTransactionTests(IntegrationTestCase):
                 [self.account], self.account
             ))
         await asyncio.wait([self.broadcast(tx) for tx in txs])
-        await asyncio.wait([self.ledger.wait(tx) for tx in txs])
+        await asyncio.wait([self.ledger.wait(tx, timeout=2) for tx in txs])
 
         # verify that a previous bug which failed to save TXIs doesn't come back
         # this check must happen before generating a new block
         self.assertTrue(all([
             tx.inputs[0].txo_ref.txo is not None
-            for tx in await self.ledger.db.get_transactions(txid__in=[tx.id for tx in txs])
+            for tx in await self.ledger.db.get_transactions(tx_hash__in=[tx.hash for tx in txs])
         ]))
 
         await self.blockchain.generate(1)
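The `test_transactions` hunk above captures the caller-facing side of the txid-to-hash migration in a single line. A sketch of the before and after, assuming a `ledger` and `txs` as in the test:

```python
# Before/after of the query-constraint rename exercised above.
async def check_txis(ledger, txs):
    # before: hex txid strings
    # rows = await ledger.db.get_transactions(txid__in=[tx.id for tx in txs])
    # after: raw byte hashes, matching the new tx_hash column
    rows = await ledger.db.get_transactions(tx_hash__in=[tx.hash for tx in txs])
    return all(tx.inputs[0].txo_ref.txo is not None for tx in rows)
```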
diff --git a/tests/integration/blockchain/test_wallet_commands.py b/tests/integration/blockchain/test_wallet_commands.py
index 4f8f0d6ed..30415bfd7 100644
--- a/tests/integration/blockchain/test_wallet_commands.py
+++ b/tests/integration/blockchain/test_wallet_commands.py
@@ -1,6 +1,8 @@
 import asyncio
 import json
 
+from sqlalchemy import event
+
 from lbry.wallet import ENCRYPT_ON_DISK
 from lbry.error import InvalidPasswordError
 from lbry.testcase import CommandTestCase
@@ -64,7 +66,14 @@ class WalletCommands(CommandTestCase):
         wallet_balance = self.daemon.jsonrpc_wallet_balance
         ledger = self.ledger
-        query_count = self.ledger.db.db.query_count
+
+        query_count = 0
+
+        def catch_queries(*args, **kwargs):
+            nonlocal query_count
+            query_count += 1
+
+        event.listen(self.ledger.db.engine, "before_cursor_execute", catch_queries)
 
         expected = {
             'total': '20.0',
@@ -74,15 +83,14 @@
         }
         self.assertIsNone(ledger._balance_cache.get(self.account.id))
-        query_count += 6
         self.assertEqual(await wallet_balance(), expected)
-        self.assertEqual(self.ledger.db.db.query_count, query_count)
+        self.assertEqual(query_count, 6)
         self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(self.account.id))['total'], '10.0')
         self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(account2.id))['total'], '10.0')
 
         # calling again uses cache
         self.assertEqual(await wallet_balance(), expected)
-        self.assertEqual(self.ledger.db.db.query_count, query_count)
+        self.assertEqual(query_count, 6)
         self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(self.account.id))['total'], '10.0')
         self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(account2.id))['total'], '10.0')
 
@@ -96,12 +104,11 @@
             'reserved_subtotals': {'claims': '1.0', 'supports': '0.0', 'tips': '0.0'}
         }
         # on_transaction event reset balance cache
-        query_count = self.ledger.db.db.query_count
+        query_count = 0
         self.assertEqual(await wallet_balance(), expected)
-        query_count += 3  # only one of the accounts changed
         self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(self.account.id))['total'], '9.979893')
         self.assertEqual(dict_values_to_lbc(ledger._balance_cache.get(account2.id))['total'], '10.0')
-        self.assertEqual(self.ledger.db.db.query_count, query_count)
+        self.assertEqual(query_count, 3)  # only one of the accounts changed
 
     async def test_granular_balances(self):
         account2 = await self.daemon.jsonrpc_account_create("Tip-er")
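With the hand-rolled `query_count` attribute gone, the test above counts statements through SQLAlchemy's engine events instead. The same trick in a self-contained form:

```python
# Generic version of the query-counting trick above: SQLAlchemy fires a
# "before_cursor_execute" event for every statement an engine executes.
from sqlalchemy import create_engine, event, text

engine = create_engine('sqlite:///:memory:')
query_count = 0

@event.listens_for(engine, 'before_cursor_execute')
def count_queries(conn, cursor, statement, parameters, context, executemany):
    global query_count
    query_count += 1

with engine.connect() as connection:
    connection.execute(text('SELECT 1'))
assert query_count == 1
```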
diff --git a/tests/unit/stream/test_stream_manager.py b/tests/unit/stream/test_stream_manager.py
index e33064503..aeda3ec90 100644
--- a/tests/unit/stream/test_stream_manager.py
+++ b/tests/unit/stream/test_stream_manager.py
@@ -10,9 +10,10 @@
 from lbry.testcase import get_fake_exchange_rate_manager
 from lbry.utils import generate_id
 from lbry.error import InsufficientFundsError
 from lbry.error import KeyFeeAboveMaxAllowedError, ResolveError, DownloadSDTimeoutError, DownloadDataTimeoutError
-from lbry.wallet import WalletManager, Wallet, Ledger, Transaction, Input, Output, Database
+from lbry.wallet import WalletManager, Wallet, Ledger, Transaction, Input, Output
 from lbry.wallet.constants import CENT, NULL_HASH32
 from lbry.wallet.network import ClientSession
+from lbry.db import Database
 from lbry.conf import Config
 from lbry.extras.daemon.analytics import AnalyticsManager
 from lbry.stream.stream_manager import StreamManager
@@ -95,7 +96,7 @@ async def get_mock_wallet(sd_hash, storage, balance=10.0, fee=None):
     wallet = Wallet()
     ledger = Ledger({
-        'db': Database(':memory:'),
+        'db': Database('sqlite:///:memory:'),
         'headers': FakeHeaders(514082)
     })
     await ledger.db.open()
diff --git a/tests/unit/wallet/test_account.py b/tests/unit/wallet/test_account.py
index a54962ed6..a494cfe33 100644
--- a/tests/unit/wallet/test_account.py
+++ b/tests/unit/wallet/test_account.py
@@ -1,13 +1,14 @@
 from binascii import hexlify
 
 from lbry.testcase import AsyncioTestCase
-from lbry.wallet import Wallet, Ledger, Database, Headers, Account, SingleKey, HierarchicalDeterministic
+from lbry.wallet import Wallet, Ledger, Headers, Account, SingleKey, HierarchicalDeterministic
+from lbry.db import Database
 
 
 class TestAccount(AsyncioTestCase):
 
     async def asyncSetUp(self):
         self.ledger = Ledger({
-            'db': Database(':memory:'),
+            'db': Database('sqlite:///:memory:'),
             'headers': Headers(':memory:')
         })
         await self.ledger.db.open()
@@ -233,7 +234,7 @@
 
     async def asyncSetUp(self):
         self.ledger = Ledger({
-            'db': Database(':memory:'),
+            'db': Database('sqlite:///:memory:'),
             'headers': Headers(':memory:')
         })
         await self.ledger.db.open()
diff --git a/tests/unit/wallet/test_bip32.py b/tests/unit/wallet/test_bip32.py
index 64e72c907..97d33addc 100644
--- a/tests/unit/wallet/test_bip32.py
+++ b/tests/unit/wallet/test_bip32.py
@@ -2,7 +2,8 @@
 from binascii import unhexlify, hexlify
 
 from lbry.testcase import AsyncioTestCase
 from lbry.wallet.bip32 import PubKey, PrivateKey, from_extended_key_string
-from lbry.wallet import Ledger, Database, Headers
+from lbry.wallet import Ledger, Headers
+from lbry.db import Database
 
 from tests.unit.wallet.key_fixtures import expected_ids, expected_privkeys, expected_hardened_privkeys
@@ -47,7 +48,7 @@ class BIP32Tests(AsyncioTestCase):
             PrivateKey(None, b'abcd', b'abcd'*8, 0, 255)
         private_key = PrivateKey(
             Ledger({
-                'db': Database(':memory:'),
+                'db': Database('sqlite:///:memory:'),
                 'headers': Headers(':memory:'),
             }),
             unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'),
@@ -68,7 +69,7 @@
     async def test_private_key_derivation(self):
         private_key = PrivateKey(
             Ledger({
-                'db': Database(':memory:'),
+                'db': Database('sqlite:///:memory:'),
                 'headers': Headers(':memory:'),
             }),
             unhexlify('2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c'),
@@ -85,7 +86,7 @@
 
     async def test_from_extended_keys(self):
         ledger = Ledger({
-            'db': Database(':memory:'),
+            'db': Database('sqlite:///:memory:'),
             'headers': Headers(':memory:'),
         })
         self.assertIsInstance(
diff --git a/tests/unit/wallet/test_coinselection.py b/tests/unit/wallet/test_coinselection.py
index e47d7d6e3..2faceda89 100644
--- a/tests/unit/wallet/test_coinselection.py
+++ b/tests/unit/wallet/test_coinselection.py
@@ -2,7 +2,8 @@
 from types import GeneratorType
 
 from lbry.testcase import AsyncioTestCase
-from lbry.wallet import Ledger, Database, Headers
+from lbry.wallet import Ledger, Headers
+from lbry.db import Database
 from lbry.wallet.coinselection import CoinSelector, MAXIMUM_TRIES
 from lbry.constants import CENT
@@ -21,7 +22,7 @@ class BaseSelectionTestCase(AsyncioTestCase):
 
     async def asyncSetUp(self):
         self.ledger = Ledger({
-            'db': Database(':memory:'),
+            'db': Database('sqlite:///:memory:'),
             'headers': Headers(':memory:'),
         })
         await self.ledger.db.open()
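Most of the unit-test churn above is one mechanical change: `Database` now takes an SQLAlchemy-style connection URL rather than a bare SQLite path. The URL forms used across this diff (the relative path is illustrative):

```python
from lbry.db import Database

Database('sqlite:///:memory:')           # in-memory SQLite
Database('sqlite:///relative/path.db')   # on-disk SQLite
Database('postgres:///lbry_test')        # local PostgreSQL database
```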
diff --git a/tests/unit/wallet/test_database.py b/tests/unit/wallet/test_database.py
index 7ceed23d7..a46522153 100644
--- a/tests/unit/wallet/test_database.py
+++ b/tests/unit/wallet/test_database.py
@@ -6,9 +6,12 @@
 import tempfile
 import asyncio
 from concurrent.futures.thread import ThreadPoolExecutor
 
+from sqlalchemy import Column, Text
+
 from lbry.wallet import (
-    Wallet, Account, Ledger, Database, Headers, Transaction, Input
+    Wallet, Account, Ledger, Headers, Transaction, Input
 )
+from lbry.db import Table, Version, Database, metadata
 from lbry.wallet.constants import COIN
 from lbry.wallet.database import query, interpolate, constraints_to_sql, AIOSQLite
 from lbry.crypto.hash import sha256
@@ -208,7 +211,7 @@ class TestQueries(AsyncioTestCase):
 
     async def asyncSetUp(self):
         self.ledger = Ledger({
-            'db': Database(':memory:'),
+            'db': Database('sqlite:///:memory:'),
             'headers': Headers(':memory:')
         })
         self.wallet = Wallet()
@@ -265,13 +268,13 @@
     async def test_large_tx_doesnt_hit_variable_limits(self):
         # SQLite is usually compiled with 999 variables limit: https://www.sqlite.org/limits.html
         # This can be removed when there is a better way. See: https://github.com/lbryio/lbry-sdk/issues/2281
-        fetchall = self.ledger.db.db.execute_fetchall
+        fetchall = self.ledger.db.execute_fetchall
 
-        def check_parameters_length(sql, parameters, read_only=False):
+        def check_parameters_length(sql, parameters=None):
             self.assertLess(len(parameters or []), 999)
-            return fetchall(sql, parameters, read_only)
+            return fetchall(sql, parameters)
 
-        self.ledger.db.db.execute_fetchall = check_parameters_length
+        self.ledger.db.execute_fetchall = check_parameters_length
         account = await self.create_account()
         tx = await self.create_tx_from_nothing(account, 0)
         for height in range(1, 1200):
@@ -368,14 +371,14 @@
         self.assertEqual(txs[1].outputs[0].is_my_output, True)
         self.assertEqual(2, await self.ledger.db.get_transaction_count(accounts=[account2]))
 
-        tx = await self.ledger.db.get_transaction(txid=tx2.id)
+        tx = await self.ledger.db.get_transaction(tx_hash=tx2.hash)
         self.assertEqual(tx.id, tx2.id)
         self.assertIsNone(tx.inputs[0].is_my_input)
         self.assertIsNone(tx.outputs[0].is_my_output)
-        tx = await self.ledger.db.get_transaction(wallet=wallet1, txid=tx2.id, include_is_my_output=True)
+        tx = await self.ledger.db.get_transaction(wallet=wallet1, tx_hash=tx2.hash, include_is_my_output=True)
         self.assertTrue(tx.inputs[0].is_my_input)
         self.assertFalse(tx.outputs[0].is_my_output)
-        tx = await self.ledger.db.get_transaction(wallet=wallet2, txid=tx2.id, include_is_my_output=True)
+        tx = await self.ledger.db.get_transaction(wallet=wallet2, tx_hash=tx2.hash, include_is_my_output=True)
         self.assertFalse(tx.inputs[0].is_my_input)
         self.assertTrue(tx.outputs[0].is_my_output)
 
@@ -425,7 +428,7 @@ class TestUpgrade(AsyncioTestCase):
 
     async def test_reset_on_version_change(self):
         self.ledger = Ledger({
-            'db': Database(self.path),
+            'db': Database('sqlite:///'+self.path),
             'headers': Headers(':memory:')
         })
 
@@ -433,7 +436,8 @@
         self.ledger.db.SCHEMA_VERSION = None
         self.assertListEqual(self.get_tables(), [])
         await self.ledger.db.open()
-        self.assertEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo'])
+        metadata.drop_all(self.ledger.db.engine, [Version])  # simulate pre-version table db
+        self.assertEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo'])
         self.assertListEqual(self.get_addresses(), [])
         self.add_address('address1')
         await self.ledger.db.close()
@@ -442,28 +446,27 @@
         self.ledger.db.SCHEMA_VERSION = '1.0'
         await self.ledger.db.open()
         self.assertEqual(self.get_version(), '1.0')
-        self.assertListEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
+        self.assertListEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
         self.assertListEqual(self.get_addresses(), [])  # address1 deleted during version upgrade
         self.add_address('address2')
         await self.ledger.db.close()
 
         # nothing changes
         self.assertEqual(self.get_version(), '1.0')
-        self.assertListEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
+        self.assertListEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
         await self.ledger.db.open()
         self.assertEqual(self.get_version(), '1.0')
-        self.assertListEqual(self.get_tables(), ['account_address', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
+        self.assertListEqual(self.get_tables(), ['account_address', 'block', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
         self.assertListEqual(self.get_addresses(), ['address2'])
         await self.ledger.db.close()
 
         # upgrade version, database reset
+        foo = Table('foo', metadata, Column('bar', Text, primary_key=True))
+        self.addCleanup(metadata.remove, foo)
         self.ledger.db.SCHEMA_VERSION = '1.1'
-        self.ledger.db.CREATE_TABLES_QUERY += """
-            create table if not exists foo (bar text);
-        """
         await self.ledger.db.open()
         self.assertEqual(self.get_version(), '1.1')
-        self.assertListEqual(self.get_tables(), ['account_address', 'foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
+        self.assertListEqual(self.get_tables(), ['account_address', 'block', 'foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
        self.assertListEqual(self.get_addresses(), [])  # all tables got reset
         await self.ledger.db.close()
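The upgrade test above now grows its schema by registering a table on the shared `metadata` object instead of appending raw DDL. A sketch of that lifecycle against a scratch engine:

```python
# Sketch of the metadata-driven table lifecycle used above: tables are
# registered on a shared MetaData object, then created and dropped as a group.
from sqlalchemy import MetaData, Table, Column, Text, create_engine

metadata = MetaData()
foo = Table('foo', metadata, Column('bar', Text, primary_key=True))

engine = create_engine('sqlite:///:memory:')
metadata.create_all(engine)        # CREATE TABLE foo (...)
metadata.drop_all(engine, [foo])   # DROP TABLE foo
metadata.remove(foo)               # unregister, as the test cleanup does
```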
diff --git a/tests/unit/wallet/test_ledger.py b/tests/unit/wallet/test_ledger.py
index 0244de987..e6e4aaf14 100644
--- a/tests/unit/wallet/test_ledger.py
+++ b/tests/unit/wallet/test_ledger.py
@@ -1,8 +1,9 @@
 import os
-from binascii import hexlify
+from binascii import hexlify, unhexlify
 
 from lbry.testcase import AsyncioTestCase
-from lbry.wallet import Wallet, Account, Transaction, Output, Input, Ledger, Database, Headers
+from lbry.wallet import Wallet, Account, Transaction, Output, Input, Ledger, Headers
+from lbry.db import Database
 
 from tests.unit.wallet.test_transaction import get_transaction, get_output
 from tests.unit.wallet.test_headers import HEADERS, block_bytes
@@ -45,7 +46,7 @@ class LedgerTestCase(AsyncioTestCase):
 
     async def asyncSetUp(self):
         self.ledger = Ledger({
-            'db': Database(':memory:'),
+            'db': Database('sqlite:///:memory:'),
             'headers': Headers(':memory:')
         })
         self.account = Account.generate(self.ledger, Wallet(), "lbryum")
@@ -84,6 +85,10 @@ class TestSynchronization(LedgerTestCase):
         txid2 = 'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9'
         txid3 = 'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0'
         txid4 = '047cf1d53ef68f0fd586d46f90c09ff8e57a4180f67e7f4b8dd0135c3741e828'
+        txhash1 = unhexlify(txid1)[::-1]
+        txhash2 = unhexlify(txid2)[::-1]
+        txhash3 = unhexlify(txid3)[::-1]
+        txhash4 = unhexlify(txid4)[::-1]
 
         account = Account.generate(self.ledger, Wallet(), "torba")
         address = await account.receiving.get_or_create_usable_address()
@@ -99,13 +104,13 @@
             {'tx_hash': txid2, 'height': 1},
             {'tx_hash': txid3, 'height': 2},
         ], {
-            txid1: hexlify(get_transaction(get_output(1)).raw),
-            txid2: hexlify(get_transaction(get_output(2)).raw),
-            txid3: hexlify(get_transaction(get_output(3)).raw),
+            txhash1: hexlify(get_transaction(get_output(1)).raw),
+            txhash2: hexlify(get_transaction(get_output(2)).raw),
+            txhash3: hexlify(get_transaction(get_output(3)).raw),
         })
         await self.ledger.update_history(address, '')
         self.assertListEqual(self.ledger.network.get_history_called, [address])
-        self.assertListEqual(self.ledger.network.get_transaction_called, [txid1, txid2, txid3])
+        self.assertListEqual(self.ledger.network.get_transaction_called, [txhash1, txhash2, txhash3])
 
         address_details = await self.ledger.db.get_address(address=address)
@@ -125,12 +130,12 @@
         self.assertListEqual(self.ledger.network.get_transaction_called, [])
 
         self.ledger.network.history.append({'tx_hash': txid4, 'height': 3})
-        self.ledger.network.transaction[txid4] = hexlify(get_transaction(get_output(4)).raw)
+        self.ledger.network.transaction[txhash4] = hexlify(get_transaction(get_output(4)).raw)
         self.ledger.network.get_history_called = []
         self.ledger.network.get_transaction_called = []
         await self.ledger.update_history(address, '')
         self.assertListEqual(self.ledger.network.get_history_called, [address])
-        self.assertListEqual(self.ledger.network.get_transaction_called, [txid4])
+        self.assertListEqual(self.ledger.network.get_transaction_called, [txhash4])
         address_details = await self.ledger.db.get_address(address=address)
         self.assertEqual(
             address_details['history'],
diff --git a/tests/unit/wallet/test_schema_signing.py b/tests/unit/wallet/test_schema_signing.py
index dbe31943e..4d4dccf7f 100644
--- a/tests/unit/wallet/test_schema_signing.py
+++ b/tests/unit/wallet/test_schema_signing.py
@@ -3,7 +3,8 @@
 from binascii import unhexlify
 
 from lbry.testcase import AsyncioTestCase
 from lbry.wallet.constants import CENT, NULL_HASH32
-from lbry.wallet import Ledger, Database, Headers, Transaction, Input, Output
+from lbry.wallet import Ledger, Headers, Transaction, Input, Output
+from lbry.db import Database
 
 from lbry.schema.claim import Claim
diff --git a/tests/unit/wallet/test_transaction.py b/tests/unit/wallet/test_transaction.py
index 7c0942cd0..d32d92d4a 100644
--- a/tests/unit/wallet/test_transaction.py
+++ b/tests/unit/wallet/test_transaction.py
@@ -4,7 +4,8 @@
 from itertools import cycle
 
 from lbry.testcase import AsyncioTestCase
 from lbry.wallet.constants import CENT, COIN, NULL_HASH32
-from lbry.wallet import Wallet, Account, Ledger, Database, Headers, Transaction, Output, Input
+from lbry.wallet import Wallet, Account, Ledger, Headers, Transaction, Output, Input
+from lbry.db import Database
 
 NULL_HASH = b'\x00'*32
@@ -38,7 +39,7 @@ class TestSizeAndFeeEstimation(AsyncioTestCase):
 
     async def asyncSetUp(self):
         self.ledger = Ledger({
-            'db': Database(':memory:'),
+            'db': Database('sqlite:///:memory:'),
             'headers': Headers(':memory:')
         })
         await self.ledger.db.open()
@@ -264,7 +265,7 @@ class TestTransactionSigning(AsyncioTestCase):
 
     async def asyncSetUp(self):
         self.ledger = Ledger({
-            'db': Database(':memory:'),
+            'db': Database('sqlite:///:memory:'),
             'headers': Headers(':memory:')
         })
         await self.ledger.db.open()
@@ -303,7 +304,7 @@ class TransactionIOBalancing(AsyncioTestCase):
 
     async def asyncSetUp(self):
         self.ledger = Ledger({
-            'db': Database(':memory:'),
+            'db': Database('sqlite:///:memory:'),
             'headers': Headers(':memory:')
         })
         await self.ledger.db.open()
diff --git a/tox.ini b/tox.ini
index 3b446a241..600fdfad5 100644
--- a/tox.ini
+++ b/tox.ini
@@ -6,6 +6,7 @@
 extras = test
 changedir = {toxinidir}/tests
 setenv =
     HOME=/tmp
+passenv = TEST_DB
 commands =
     pip install https://github.com/rogerbinns/apsw/releases/download/3.30.1-r1/apsw-3.30.1-r1.zip \
     --global-option=fetch \