From 6a33d86bfe71a5d3fc9ce10408e5322ce72c0223 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Fri, 1 May 2020 09:29:44 -0400 Subject: [PATCH] wip lbry.db --- lbry/db/__init__.py | 5 +- lbry/db/constants.py | 59 +++ lbry/db/database.py | 941 ++++++--------------------------- lbry/db/queries.py | 1201 ++++++++++++++++++++++++++++++++++++++++++ lbry/db/search.py | 471 +---------------- lbry/db/tables.py | 42 +- lbry/db/utils.py | 129 +++++ 7 files changed, 1582 insertions(+), 1266 deletions(-) create mode 100644 lbry/db/constants.py create mode 100644 lbry/db/queries.py create mode 100644 lbry/db/utils.py diff --git a/lbry/db/__init__.py b/lbry/db/__init__.py index 303d19265..d95063beb 100644 --- a/lbry/db/__init__.py +++ b/lbry/db/__init__.py @@ -1,6 +1,7 @@ -from .database import Database, in_account +from .database import Database +from .constants import TXO_TYPES, CLAIM_TYPE_CODES, CLAIM_TYPE_NAMES from .tables import ( Table, Version, metadata, AccountAddress, PubkeyAddress, - Block, TX, TXO, TXI + Block, TX, TXO, TXI, Claim, Tag, Claimtrie ) diff --git a/lbry/db/constants.py b/lbry/db/constants.py new file mode 100644 index 000000000..909bfb458 --- /dev/null +++ b/lbry/db/constants.py @@ -0,0 +1,59 @@ +TXO_TYPES = { + "other": 0, + "stream": 1, + "channel": 2, + "support": 3, + "purchase": 4, + "collection": 5, + "repost": 6, +} + +CLAIM_TYPE_NAMES = [ + 'stream', + 'channel', + 'collection', + 'repost', +] + +CLAIM_TYPE_CODES = [ + TXO_TYPES[name] for name in CLAIM_TYPE_NAMES +] + +STREAM_TYPES = { + 'video': 1, + 'audio': 2, + 'image': 3, + 'document': 4, + 'binary': 5, + 'model': 6 +} + +MATURE_TAGS = ( + 'nsfw', 'porn', 'xxx', 'mature', 'adult', 'sex' +) + +ATTRIBUTE_ARRAY_MAX_LENGTH = 100 + +SEARCH_INTEGER_PARAMS = { + 'height', 'creation_height', 'activation_height', 'expiration_height', + 'timestamp', 'creation_timestamp', 'duration', 'release_time', 'fee_amount', + 'tx_position', 'channel_join', 'reposted', + 'amount', 'effective_amount', 'support_amount', + 'trending_group', 'trending_mixed', + 'trending_local', 'trending_global', +} + +SEARCH_PARAMS = { + 'name', 'text', 'claim_id', 'claim_ids', 'txid', 'nout', 'channel', 'channel_ids', 'not_channel_ids', + 'public_key_id', 'claim_type', 'stream_types', 'media_types', 'fee_currency', + 'has_channel_signature', 'signature_valid', + 'any_tags', 'all_tags', 'not_tags', 'reposted_claim_id', + 'any_locations', 'all_locations', 'not_locations', + 'any_languages', 'all_languages', 'not_languages', + 'is_controlling', 'limit', 'offset', 'order_by', + 'no_totals', +} | SEARCH_INTEGER_PARAMS + +SEARCH_ORDER_FIELDS = { + 'name', 'claim_hash' +} | SEARCH_INTEGER_PARAMS diff --git a/lbry/db/database.py b/lbry/db/database.py index 21f0304a3..5d199c3a6 100644 --- a/lbry/db/database.py +++ b/lbry/db/database.py @@ -1,856 +1,219 @@ -# pylint: disable=singleton-comparison - -import logging +import os import asyncio -import sqlite3 -from concurrent.futures.thread import ThreadPoolExecutor -from typing import List, Union, Iterable, Optional -from datetime import date +from typing import List, Optional, Tuple, Iterable +from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor +from functools import partial -import sqlalchemy -from sqlalchemy.future import select -from sqlalchemy import text, and_, union, func, inspect -from sqlalchemy.sql.expression import Select -try: - from sqlalchemy.dialects.postgresql import insert as pg_insert -except ImportError: - pg_insert = None +from sqlalchemy import create_engine, text + 
+from lbry.crypto.bip32 import PubKey +from lbry.blockchain.ledger import Ledger +from lbry.blockchain.transaction import Transaction, Output +from .constants import TXO_TYPES, CLAIM_TYPE_CODES +from . import queries as q -from lbry.wallet import PubKey -from lbry.wallet.transaction import Transaction, Output, OutputScript, TXRefImmutable -from lbry.wallet.constants import TXO_TYPES, CLAIM_TYPES - -from .tables import ( - metadata, Version, - PubkeyAddress, AccountAddress, - TX, - TXO, txo_join_account, - TXI, txi_join_account, -) - - -log = logging.getLogger(__name__) -sqlite3.enable_callback_tracebacks(True) - - -def insert_or_ignore(conn, table): - if conn.dialect.name == 'sqlite': - return table.insert().prefix_with("OR IGNORE") - elif conn.dialect.name == 'postgresql': - return pg_insert(table).on_conflict_do_nothing() - else: - raise RuntimeError(f'Unknown database dialect: {conn.dialect.name}.') - - -def insert_or_replace(conn, table, replace): - if conn.dialect.name == 'sqlite': - return table.insert().prefix_with("OR REPLACE") - elif conn.dialect.name == 'postgresql': - insert = pg_insert(table) - return insert.on_conflict_do_update( - table.primary_key, set_={col: getattr(insert.excluded, col) for col in replace} - ) - else: - raise RuntimeError(f'Unknown database dialect: {conn.dialect.name}.') - - -def constrain_single_or_list(constraints, column, value, convert=lambda x: x): - if value is not None: - if isinstance(value, list): - value = [convert(v) for v in value] - if len(value) == 1: - constraints[column] = value[0] - elif len(value) > 1: - constraints[f"{column}__in"] = value - else: - constraints[column] = convert(value) - return constraints - - -def in_account(accounts: Union[List[PubKey], PubKey]): - if isinstance(accounts, list): - if len(accounts) > 1: - return AccountAddress.c.account.in_({a.public_key.address for a in accounts}) - accounts = accounts[0] - return AccountAddress.c.account == accounts.public_key.address - - -def query2(table, s: Select, **constraints) -> Select: - limit = constraints.pop('limit', None) - if limit is not None: - s = s.limit(limit) - - offset = constraints.pop('offset', None) - if offset is not None: - s = s.offset(offset) - - order_by = constraints.pop('order_by', None) - if order_by: - if isinstance(order_by, str): - s = s.order_by(text(order_by)) - elif isinstance(order_by, list): - s = s.order_by(text(', '.join(order_by))) - else: - raise ValueError("order_by must be string or list") - - group_by = constraints.pop('group_by', None) - if group_by is not None: - s = s.group_by(text(group_by)) - +def clean_wallet_account_ids(constraints): + wallet = constraints.pop('wallet', None) + account = constraints.pop('account', None) accounts = constraints.pop('accounts', []) + if account and not accounts: + accounts = [account] + if wallet: + constraints['wallet_account_ids'] = [account.id for account in wallet.accounts] + if not accounts: + accounts = wallet.accounts if accounts: - s = s.where(in_account(accounts)) - - if constraints: - s = s.where( - constraints_to_clause2(table, constraints) - ) - - return s + constraints['account_ids'] = [account.id for account in accounts] -def constraints_to_clause2(tables, constraints): - clause = [] - for key, constraint in constraints.items(): - if key.endswith('__not'): - col, op = key[:-len('__not')], '__ne__' - elif key.endswith('__is_null'): - col = key[:-len('__is_null')] - op = '__eq__' - constraint = None - elif key.endswith('__is_not_null'): - col = key[:-len('__is_not_null')] - op = '__ne__' - 
constraint = None - elif key.endswith('__lt'): - col, op = key[:-len('__lt')], '__lt__' - elif key.endswith('__lte'): - col, op = key[:-len('__lte')], '__le__' - elif key.endswith('__gt'): - col, op = key[:-len('__gt')], '__gt__' - elif key.endswith('__gte'): - col, op = key[:-len('__gte')], '__ge__' - elif key.endswith('__like'): - col, op = key[:-len('__like')], 'like' - elif key.endswith('__not_like'): - col, op = key[:-len('__not_like')], 'notlike' - elif key.endswith('__in') or key.endswith('__not_in'): - if key.endswith('__in'): - col, op, one_val_op = key[:-len('__in')], 'in_', '__eq__' - else: - col, op, one_val_op = key[:-len('__not_in')], 'notin_', '__ne__' - if isinstance(constraint, sqlalchemy.sql.expression.Select): - pass - elif constraint: - if isinstance(constraint, (list, set, tuple)): - if len(constraint) == 1: - op = one_val_op - constraint = next(iter(constraint)) - elif isinstance(constraint, str): - constraint = text(constraint) - else: - raise ValueError(f"{col} requires a list, set or string as constraint value.") - else: - continue - else: - col, op = key, '__eq__' - attr = None - for table in tables: - attr = getattr(table.c, col, None) - if attr is not None: - clause.append(getattr(attr, op)(constraint)) - break - if attr is None: - raise ValueError(f"Attribute '{col}' not found on tables: {', '.join([t.name for t in tables])}.") - return and_(*clause) +def add_channel_keys_to_txo_results(accounts: List, txos: Iterable[Output]): + sub_channels = set() + for txo in txos: + if txo.claim.is_channel: + for account in accounts: + private_key = account.get_channel_private_key( + txo.claim.channel.public_key_bytes + ) + if private_key: + txo.private_key = private_key + break + if txo.channel is not None: + sub_channels.add(txo.channel) + if sub_channels: + add_channel_keys_to_txo_results(accounts, sub_channels) class Database: - SCHEMA_VERSION = "1.3" - MAX_QUERY_VARIABLES = 900 - - def __init__(self, url): + def __init__(self, ledger: Ledger, url: str, multiprocess=False): self.url = url - self.ledger = None - self.executor = ThreadPoolExecutor(max_workers=1) - self.engine = None - self.db: Optional[sqlalchemy.engine.Connection] = None - - def sync_execute_fetchall(self, sql, params=None): - if params: - result = self.db.execute(sql, params) - else: - result = self.db.execute(sql) - if result.returns_rows: - return [dict(r._mapping) for r in result.fetchall()] - return [] - - async def execute_fetchall(self, sql, params=None) -> List[dict]: - return await asyncio.get_event_loop().run_in_executor( - self.executor, self.sync_execute_fetchall, sql, params - ) - - def sync_executemany(self, sql, parameters): - self.db.execute(sql, parameters) - - async def executemany(self, sql: str, parameters: Iterable = None): - return await asyncio.get_event_loop().run_in_executor( - self.executor, self.sync_executemany, sql, parameters - ) - - def sync_open(self): - log.info("connecting to database: %s", self.url) - self.engine = sqlalchemy.create_engine(self.url) - self.db = self.engine.connect() - if self.SCHEMA_VERSION: - if inspect(self.engine).has_table('version'): - version = self.db.execute(Version.select().limit(1)).fetchone() - if version and version.version == self.SCHEMA_VERSION: - return - metadata.drop_all(self.engine) - metadata.create_all(self.engine) - self.db.execute(Version.insert().values(version=self.SCHEMA_VERSION)) - else: - metadata.create_all(self.engine) - return self - - async def open(self): - return await asyncio.get_event_loop().run_in_executor( - 
self.executor, self.sync_open - ) - - def sync_close(self): - if self.engine is not None: - self.engine.dispose() - self.engine = None - self.db = None - - async def close(self): - await asyncio.get_event_loop().run_in_executor( - self.executor, self.sync_close - ) + self.ledger = ledger + self.multiprocess = multiprocess + self.executor: Optional[Executor] = None def sync_create(self, name): - engine = sqlalchemy.create_engine(self.url) + engine = create_engine(self.url) db = engine.connect() - db.execute(text('commit')) - db.execute(text(f'create database {name}')) + db.execute(text("COMMIT")) + db.execute(text(f"CREATE DATABASE {name}")) async def create(self, name): - await asyncio.get_event_loop().run_in_executor( - self.executor, self.sync_create, name - ) + return await asyncio.get_event_loop().run_in_executor(None, self.sync_create, name) def sync_drop(self, name): - engine = sqlalchemy.create_engine(self.url) + engine = create_engine(self.url) db = engine.connect() - db.execute(text('commit')) - db.execute(text(f'drop database if exists {name}')) + db.execute(text("COMMIT")) + db.execute(text(f"DROP DATABASE IF EXISTS {name}")) async def drop(self, name): - await asyncio.get_event_loop().run_in_executor( - self.executor, self.sync_drop, name + return await asyncio.get_event_loop().run_in_executor(None, self.sync_drop, name) + + async def open(self): + assert self.executor is None, "Database already open." + kwargs = dict( + initializer=q.initialize, + initargs=(self.url, self.ledger) + ) + if self.multiprocess: + self.executor = ProcessPoolExecutor( + max_workers=max(os.cpu_count()-1, 4), **kwargs + ) + else: + self.executor = ThreadPoolExecutor( + max_workers=1, **kwargs + ) + return await self.run_in_executor(q.check_version_and_create_tables) + + async def close(self): + if self.executor is not None: + self.executor.shutdown() + self.executor = None + + async def run_in_executor(self, func, *args, **kwargs): + if kwargs: + clean_wallet_account_ids(kwargs) + return await asyncio.get_event_loop().run_in_executor( + self.executor, partial(func, *args, **kwargs) ) - def txo_to_row(self, tx, txo): - row = { - 'tx_hash': tx.hash, - 'txo_hash': txo.hash, - 'address': txo.get_address(self.ledger), - 'position': txo.position, - 'amount': txo.amount, - 'script': txo.script.source - } - if txo.is_claim: - if txo.can_decode_claim: - claim = txo.claim - row['txo_type'] = TXO_TYPES.get(claim.claim_type, TXO_TYPES['stream']) - if claim.is_repost: - row['reposted_claim_hash'] = claim.repost.reference.claim_hash - if claim.is_signed: - row['channel_hash'] = claim.signing_channel_hash - else: - row['txo_type'] = TXO_TYPES['stream'] - elif txo.is_support: - row['txo_type'] = TXO_TYPES['support'] - elif txo.purchase is not None: - row['txo_type'] = TXO_TYPES['purchase'] - row['claim_id'] = txo.purchased_claim_id - row['claim_hash'] = txo.purchased_claim_hash - if txo.script.is_claim_involved: - row['claim_id'] = txo.claim_id - row['claim_hash'] = txo.claim_hash - row['claim_name'] = txo.claim_name - return row + async def execute_fetchall(self, sql): + return await self.run_in_executor(q.execute_fetchall, sql) - def tx_to_row(self, tx): - row = { - 'tx_hash': tx.hash, - 'raw': tx.raw, - 'height': tx.height, - 'position': tx.position, - 'is_verified': tx.is_verified, - 'day': tx.get_ordinal_day(self.ledger), - } - txos = tx.outputs - if len(txos) >= 2 and txos[1].can_decode_purchase_data: - txos[0].purchase = txos[1] - row['purchased_claim_hash'] = txos[1].purchase_data.claim_hash - return row + 
async def get_best_height(self): + return await self.run_in_executor(q.get_best_height) + + async def get_blocks_without_filters(self): + return await self.run_in_executor(q.get_blocks_without_filters) + + async def get_transactions_without_filters(self): + return await self.run_in_executor(q.get_transactions_without_filters) + + async def get_block_tx_addresses(self, block_hash=None, tx_hash=None): + return await self.run_in_executor(q.get_block_tx_addresses, block_hash, tx_hash) + + async def get_block_address_filters(self): + return await self.run_in_executor(q.get_block_address_filters) + + async def get_transaction_address_filters(self, block_hash): + return await self.run_in_executor(q.get_transaction_address_filters, block_hash) async def insert_transaction(self, tx): - await self.execute_fetchall(TX.insert().values(self.tx_to_row(tx))) + return await self.run_in_executor(q.insert_transaction, tx) - def _transaction_io(self, conn: sqlite3.Connection, tx: Transaction, address, txhash): - conn.execute( - insert_or_replace(conn, TX, ('block_hash', 'height', 'position', 'is_verified', 'day')).values( - self.tx_to_row(tx) - ) - ) - - is_my_input = False - - for txi in tx.inputs: - if txi.txo_ref.txo is not None: - txo = txi.txo_ref.txo - if txo.has_address and txo.get_address(self.ledger) == address: - is_my_input = True - conn.execute( - insert_or_ignore(conn, TXI).values({ - 'tx_hash': tx.hash, - 'txo_hash': txo.hash, - 'address': address, - 'position': txi.position - }) - ) - - for txo in tx.outputs: - if txo.script.is_pay_pubkey_hash and (txo.pubkey_hash == txhash or is_my_input): - conn.execute(insert_or_ignore(conn, TXO).values(self.txo_to_row(tx, txo))) - elif txo.script.is_pay_script_hash: - # TODO: implement script hash payments - log.warning('Database.save_transaction_io: pay script hash is not implemented!') - - def save_transaction_io(self, tx: Transaction, address, txhash, history): - return self.save_transaction_io_batch([tx], address, txhash, history) - - def save_transaction_io_batch(self, txs: Iterable[Transaction], address, txhash, history): - history_count = history.count(':') // 2 - - def __many(): - for tx in txs: - self._transaction_io(self.db, tx, address, txhash) - self.db.execute( - PubkeyAddress.update() - .values(history=history, used_times=history_count) - .where(PubkeyAddress.c.address == address) - ) - - return asyncio.get_event_loop().run_in_executor(self.executor, __many) + async def update_address_used_times(self, addresses): + return await self.run_in_executor(q.update_address_used_times, addresses) async def reserve_outputs(self, txos, is_reserved=True): txo_hashes = [txo.hash for txo in txos] if txo_hashes: - await self.execute_fetchall( - TXO.update().values(is_reserved=is_reserved).where(TXO.c.txo_hash.in_(txo_hashes)) + return await self.run_in_executor( + q.reserve_outputs, txo_hashes, is_reserved ) async def release_outputs(self, txos): - await self.reserve_outputs(txos, is_reserved=False) + return await self.reserve_outputs(txos, is_reserved=False) - async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use - # TODO: - # 1. delete transactions above_height - # 2. 
update address histories removing deleted TXs - return True + async def release_tx(self, tx): + return await self.release_outputs([txi.txo_ref.txo for txi in tx.inputs]) - async def select_transactions(self, cols, accounts=None, **constraints): - s: Select = select(*cols).select_from(TX) - if not {'tx_hash', 'tx_hash__in'}.intersection(constraints): - assert accounts, "'accounts' argument required when no 'tx_hash' constraint is present" - where = in_account(accounts) - tx_hashes = union( - select(TXO.c.tx_hash).select_from(txo_join_account).where(where), - select(TXI.c.tx_hash).select_from(txi_join_account).where(where) - ) - s = s.where(TX.c.tx_hash.in_(tx_hashes)) - return await self.execute_fetchall(query2([TX], s, **constraints)) + async def release_all_outputs(self, account): + return await self.run_in_executor(q.release_all_outputs, account.id) - TXO_NOT_MINE = Output(None, None, is_my_output=False) + async def get_balance(self, **constraints): + return await self.run_in_executor(q.get_balance, **constraints) - async def get_transactions(self, wallet=None, **constraints): - include_is_spent = constraints.pop('include_is_spent', False) - include_is_my_input = constraints.pop('include_is_my_input', False) - include_is_my_output = constraints.pop('include_is_my_output', False) + async def get_supports_summary(self, **constraints): + return await self.run_in_executor(self.get_supports_summary, **constraints) - tx_rows = await self.select_transactions( - [TX.c.tx_hash, TX.c.raw, TX.c.height, TX.c.position, TX.c.is_verified], - order_by=constraints.pop('order_by', ["height=0 DESC", "height DESC", "position DESC"]), - **constraints - ) - - if not tx_rows: - return [] - - txids, txs, txi_txoids = [], [], [] - for row in tx_rows: - txids.append(row['tx_hash']) - txs.append(Transaction( - raw=row['raw'], height=row['height'], position=row['position'], - is_verified=bool(row['is_verified']) - )) - for txi in txs[-1].inputs: - txi_txoids.append(txi.txo_ref.hash) - - step = self.MAX_QUERY_VARIABLES - annotated_txos = {} - for offset in range(0, len(txids), step): - annotated_txos.update({ - txo.id: txo for txo in - (await self.get_txos( - wallet=wallet, - tx_hash__in=txids[offset:offset+step], order_by='txo.tx_hash', - include_is_spent=include_is_spent, - include_is_my_input=include_is_my_input, - include_is_my_output=include_is_my_output, - )) - }) - - referenced_txos = {} - for offset in range(0, len(txi_txoids), step): - referenced_txos.update({ - txo.id: txo for txo in - (await self.get_txos( - wallet=wallet, - txo_hash__in=txi_txoids[offset:offset+step], order_by='txo.txo_hash', - include_is_my_output=include_is_my_output, - )) - }) - - for tx in txs: - for txi in tx.inputs: - txo = referenced_txos.get(txi.txo_ref.id) - if txo: - txi.txo_ref = txo.ref - for txo in tx.outputs: - _txo = annotated_txos.get(txo.id) - if _txo: - txo.update_annotations(_txo) - else: - txo.update_annotations(self.TXO_NOT_MINE) - - for tx in txs: - txos = tx.outputs - if len(txos) >= 2 and txos[1].can_decode_purchase_data: - txos[0].purchase = txos[1] - - return txs - - async def get_transaction_count(self, **constraints): - constraints.pop('wallet', None) - constraints.pop('offset', None) - constraints.pop('limit', None) - constraints.pop('order_by', None) - count = await self.select_transactions([func.count().label('total')], **constraints) - return count[0]['total'] or 0 - - async def get_transaction(self, **constraints): - txs = await self.get_transactions(limit=1, **constraints) - if txs: - return txs[0] - - 
async def select_txos( - self, cols, accounts=None, is_my_input=None, is_my_output=True, - is_my_input_or_output=None, exclude_internal_transfers=False, - include_is_spent=False, include_is_my_input=False, - is_spent=None, spent=None, **constraints): - s: Select = select(*cols) - if accounts: - my_addresses = select(AccountAddress.c.address).where(in_account(accounts)) - if is_my_input_or_output: - include_is_my_input = True - s = s.where( - TXO.c.address.in_(my_addresses) | ( - (TXI.c.address != None) & - (TXI.c.address.in_(my_addresses)) - ) - ) - else: - if is_my_output: - s = s.where(TXO.c.address.in_(my_addresses)) - elif is_my_output is False: - s = s.where(TXO.c.address.notin_(my_addresses)) - if is_my_input: - include_is_my_input = True - s = s.where( - (TXI.c.address != None) & - (TXI.c.address.in_(my_addresses)) - ) - elif is_my_input is False: - include_is_my_input = True - s = s.where( - (TXI.c.address == None) | - (TXI.c.address.notin_(my_addresses)) - ) - if exclude_internal_transfers: - include_is_my_input = True - s = s.where( - (TXO.c.txo_type != TXO_TYPES['other']) | - (TXI.c.address == None) | - (TXI.c.address.notin_(my_addresses)) - ) - joins = TXO.join(TX) - if spent is None: - spent = TXI.alias('spent') - if is_spent: - s = s.where(spent.c.txo_hash != None) - elif is_spent is False: - s = s.where((spent.c.txo_hash == None) & (TXO.c.is_reserved == False)) - if include_is_spent or is_spent is not None: - joins = joins.join(spent, spent.c.txo_hash == TXO.c.txo_hash, isouter=True) - if include_is_my_input: - joins = joins.join(TXI, (TXI.c.position == 0) & (TXI.c.tx_hash == TXO.c.tx_hash), isouter=True) - s = s.select_from(joins) - return await self.execute_fetchall(query2([TXO, TX], s, **constraints)) - - async def get_txos(self, wallet=None, no_tx=False, **constraints): - include_is_spent = constraints.get('include_is_spent', False) - include_is_my_input = constraints.get('include_is_my_input', False) - include_is_my_output = constraints.pop('include_is_my_output', False) - include_received_tips = constraints.pop('include_received_tips', False) - - select_columns = [ - TX.c.tx_hash, TX.c.raw, TX.c.height, TX.c.position.label('tx_position'), TX.c.is_verified, - TXO.c.txo_type, TXO.c.position.label('txo_position'), TXO.c.amount, TXO.c.script - ] - - my_accounts = None - if wallet is not None: - my_accounts = select(AccountAddress.c.address).where(in_account(wallet.accounts)) - - if include_is_my_output and my_accounts is not None: - if constraints.get('is_my_output', None) in (True, False): - select_columns.append(text(f"{1 if constraints['is_my_output'] else 0} AS is_my_output")) - else: - select_columns.append(TXO.c.address.in_(my_accounts).label('is_my_output')) - - if include_is_my_input and my_accounts is not None: - if constraints.get('is_my_input', None) in (True, False): - select_columns.append(text(f"{1 if constraints['is_my_input'] else 0} AS is_my_input")) - else: - select_columns.append(( - (TXI.c.address != None) & - (TXI.c.address.in_(my_accounts)) - ).label('is_my_input')) - - spent = TXI.alias('spent') - if include_is_spent: - select_columns.append((spent.c.txo_hash != None).label('is_spent')) - - if include_received_tips: - support = TXO.alias('support') - select_columns.append( - select(func.coalesce(func.sum(support.c.amount), 0)) - .select_from(support).where( - (support.c.claim_hash == TXO.c.claim_hash) & - (support.c.txo_type == TXO_TYPES['support']) & - (support.c.address.in_(my_accounts)) & - (support.c.txo_hash.notin_(select(TXI.c.txo_hash))) - 
).label('received_tips') - ) - - if 'order_by' not in constraints or constraints['order_by'] == 'height': - constraints['order_by'] = [ - "tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position" - ] - elif constraints.get('order_by', None) == 'none': - del constraints['order_by'] - - rows = await self.select_txos(select_columns, spent=spent, **constraints) - - txos = [] - txs = {} - for row in rows: - if no_tx: - txo = Output( - amount=row['amount'], - script=OutputScript(row['script']), - tx_ref=TXRefImmutable.from_hash(row['tx_hash'], row['height']), - position=row['txo_position'] - ) - else: - if row['tx_hash'] not in txs: - txs[row['tx_hash']] = Transaction( - row['raw'], height=row['height'], position=row['tx_position'], - is_verified=bool(row['is_verified']) - ) - txo = txs[row['tx_hash']].outputs[row['txo_position']] - if include_is_spent: - txo.is_spent = bool(row['is_spent']) - if include_is_my_input: - txo.is_my_input = bool(row['is_my_input']) - if include_is_my_output: - txo.is_my_output = bool(row['is_my_output']) - if include_is_my_input and include_is_my_output: - if txo.is_my_input and txo.is_my_output and row['txo_type'] == TXO_TYPES['other']: - txo.is_internal_transfer = True - else: - txo.is_internal_transfer = False - if include_received_tips: - txo.received_tips = row['received_tips'] - txos.append(txo) - - channel_hashes = set() - for txo in txos: - if txo.is_claim and txo.can_decode_claim: - if txo.claim.is_signed: - channel_hashes.add(txo.claim.signing_channel_hash) - if txo.claim.is_channel and wallet: - for account in wallet.accounts: - private_key = account.get_channel_private_key( - txo.claim.channel.public_key_bytes - ) - if private_key: - txo.private_key = private_key - break - - if channel_hashes: - channels = { - txo.claim_hash: txo for txo in - (await self.get_channels( - wallet=wallet, - claim_hash__in=channel_hashes, - )) - } - for txo in txos: - if txo.is_claim and txo.can_decode_claim: - txo.channel = channels.get(txo.claim.signing_channel_hash, None) - - return txos - - @staticmethod - def _clean_txo_constraints_for_aggregation(constraints): - constraints.pop('include_is_spent', None) - constraints.pop('include_is_my_input', None) - constraints.pop('include_is_my_output', None) - constraints.pop('include_received_tips', None) - constraints.pop('wallet', None) - constraints.pop('resolve', None) - constraints.pop('offset', None) - constraints.pop('limit', None) - constraints.pop('order_by', None) - - async def get_txo_count(self, **constraints): - self._clean_txo_constraints_for_aggregation(constraints) - count = await self.select_txos([func.count().label('total')], **constraints) - return count[0]['total'] or 0 - - async def get_txo_sum(self, **constraints): - self._clean_txo_constraints_for_aggregation(constraints) - result = await self.select_txos([func.sum(TXO.c.amount).label('total')], **constraints) - return result[0]['total'] or 0 - - async def get_txo_plot(self, start_day=None, days_back=0, end_day=None, days_after=None, **constraints): - self._clean_txo_constraints_for_aggregation(constraints) - if start_day is None: - current_ordinal = self.ledger.headers.estimated_date(self.ledger.headers.height).toordinal() - constraints['day__gte'] = current_ordinal - days_back - else: - constraints['day__gte'] = date.fromisoformat(start_day).toordinal() - if end_day is not None: - constraints['day__lte'] = date.fromisoformat(end_day).toordinal() - elif days_after is not None: - constraints['day__lte'] = constraints['day__gte'] + days_after - 
plot = await self.select_txos( - [TX.c.day, func.sum(TXO.c.amount).label('total')], - group_by='day', order_by='day', **constraints - ) - for row in plot: - row['day'] = date.fromordinal(row['day']) - return plot - - def get_utxos(self, **constraints): - return self.get_txos(is_spent=False, **constraints) - - def get_utxo_count(self, **constraints): - return self.get_txo_count(is_spent=False, **constraints) - - async def get_balance(self, wallet=None, accounts=None, **constraints): - assert wallet or accounts, \ - "'wallet' or 'accounts' constraints required to calculate balance" - constraints['accounts'] = accounts or wallet.accounts - balance = await self.select_txos( - [func.sum(TXO.c.amount).label('total')], is_spent=False, **constraints - ) - return balance[0]['total'] or 0 - - async def select_addresses(self, cols, **constraints): - return await self.execute_fetchall(query2( - [AccountAddress, PubkeyAddress], - select(*cols).select_from(PubkeyAddress.join(AccountAddress)), - **constraints - )) - - async def get_addresses(self, cols=None, **constraints): - if cols is None: - cols = ( - PubkeyAddress.c.address, - PubkeyAddress.c.history, - PubkeyAddress.c.used_times, - AccountAddress.c.account, - AccountAddress.c.chain, - AccountAddress.c.pubkey, - AccountAddress.c.chain_code, - AccountAddress.c.n, - AccountAddress.c.depth - ) - addresses = await self.select_addresses(cols, **constraints) - if AccountAddress.c.pubkey in cols: + async def get_addresses(self, **constraints) -> Tuple[List[dict], Optional[int]]: + addresses, count = await self.run_in_executor(q.get_addresses, **constraints) + if addresses and 'pubkey' in addresses[0]: for address in addresses: address['pubkey'] = PubKey( self.ledger, bytes(address.pop('pubkey')), bytes(address.pop('chain_code')), address.pop('n'), address.pop('depth') ) - return addresses + return addresses, count - async def get_address_count(self, cols=None, **constraints): - count = await self.select_addresses([func.count().label('total')], **constraints) - return count[0]['total'] or 0 + async def get_all_addresses(self): + return await self.run_in_executor(q.get_all_addresses) async def get_address(self, **constraints): - addresses = await self.get_addresses(limit=1, **constraints) + addresses, _ = await self.get_addresses(limit=1, **constraints) if addresses: return addresses[0] async def add_keys(self, account, chain, pubkeys): - await self.execute_fetchall( - insert_or_ignore(self.db, PubkeyAddress).values([{ - 'address': k.address - } for k in pubkeys]) - ) - await self.execute_fetchall( - insert_or_ignore(self.db, AccountAddress).values([{ - 'account': account.id, - 'address': k.address, - 'chain': chain, - 'pubkey': k.pubkey_bytes, - 'chain_code': k.chain_code, - 'n': k.n, - 'depth': k.depth - } for k in pubkeys]) - ) + return await self.run_in_executor(q.add_keys, account, chain, pubkeys) - async def _set_address_history(self, address, history): - await self.execute_fetchall( - PubkeyAddress.update() - .values(history=history, used_times=history.count(':')//2) - .where(PubkeyAddress.c.address == address) - ) + async def get_raw_transactions(self, tx_hashes): + return await self.run_in_executor(q.get_raw_transactions, tx_hashes) - async def set_address_history(self, address, history): - await self._set_address_history(address, history) + async def get_transactions(self, **constraints) -> Tuple[List[Transaction], Optional[int]]: + return await self.run_in_executor(q.get_transactions, **constraints) - @staticmethod - def 
constrain_purchases(constraints): - accounts = constraints.pop('accounts', None) - assert accounts, "'accounts' argument required to find purchases" - if not {'purchased_claim_hash', 'purchased_claim_hash__in'}.intersection(constraints): - constraints['purchased_claim_hash__is_not_null'] = True - constraints['tx_hash__in'] = ( - select(TXI.c.tx_hash).select_from(txi_join_account).where(in_account(accounts)) - ) + async def get_transaction(self, **constraints) -> Optional[Transaction]: + txs, _ = await self.get_transactions(limit=1, **constraints) + if txs: + return txs[0] - async def get_purchases(self, **constraints): - self.constrain_purchases(constraints) - return [tx.outputs[0] for tx in await self.get_transactions(**constraints)] + async def get_purchases(self, **constraints) -> Tuple[List[Output], Optional[int]]: + return await self.run_in_executor(q.get_purchases, **constraints) - def get_purchase_count(self, **constraints): - self.constrain_purchases(constraints) - return self.get_transaction_count(**constraints) + async def search_claims(self, **constraints): + return await self.run_in_executor(q.search, **constraints) - @staticmethod - def constrain_claims(constraints): - if {'txo_type', 'txo_type__in'}.intersection(constraints): - return - claim_types = constraints.pop('claim_type', None) - if claim_types: - constrain_single_or_list( - constraints, 'txo_type', claim_types, lambda x: TXO_TYPES[x] - ) - else: - constraints['txo_type__in'] = CLAIM_TYPES + async def get_txo_sum(self, **constraints): + return await self.run_in_executor(q.get_txo_sum, **constraints) - async def get_claims(self, **constraints) -> List[Output]: - self.constrain_claims(constraints) - return await self.get_utxos(**constraints) + async def get_txo_plot(self, **constraints): + return await self.run_in_executor(q.get_txo_plot, **constraints) - def get_claim_count(self, **constraints): - self.constrain_claims(constraints) - return self.get_utxo_count(**constraints) + async def get_txos(self, **constraints) -> Tuple[List[Output], Optional[int]]: + txos = await self.run_in_executor(q.get_txos, **constraints) + if 'wallet' in constraints: + add_channel_keys_to_txo_results(constraints['wallet'], txos) + return txos - @staticmethod - def constrain_streams(constraints): - constraints['txo_type'] = TXO_TYPES['stream'] + async def get_utxos(self, **constraints) -> Tuple[List[Output], Optional[int]]: + return await self.get_txos(is_spent=False, **constraints) - def get_streams(self, **constraints): - self.constrain_streams(constraints) - return self.get_claims(**constraints) + async def get_supports(self, **constraints) -> Tuple[List[Output], Optional[int]]: + return await self.get_utxos(txo_type=TXO_TYPES['support'], **constraints) - def get_stream_count(self, **constraints): - self.constrain_streams(constraints) - return self.get_claim_count(**constraints) + async def get_claims(self, **constraints) -> Tuple[List[Output], Optional[int]]: + txos, count = await self.run_in_executor(q.get_claims, **constraints) + if 'wallet' in constraints: + add_channel_keys_to_txo_results(constraints['wallet'].accounts, txos) + return txos, count - @staticmethod - def constrain_channels(constraints): - constraints['txo_type'] = TXO_TYPES['channel'] + async def get_streams(self, **constraints) -> Tuple[List[Output], Optional[int]]: + return await self.get_claims(txo_type=TXO_TYPES['stream'], **constraints) - def get_channels(self, **constraints): - self.constrain_channels(constraints) - return self.get_claims(**constraints) + async 
def get_channels(self, **constraints) -> Tuple[List[Output], Optional[int]]: + return await self.get_claims(txo_type=TXO_TYPES['channel'], **constraints) - def get_channel_count(self, **constraints): - self.constrain_channels(constraints) - return self.get_claim_count(**constraints) - - @staticmethod - def constrain_supports(constraints): - constraints['txo_type'] = TXO_TYPES['support'] - - def get_supports(self, **constraints): - self.constrain_supports(constraints) - return self.get_utxos(**constraints) - - def get_support_count(self, **constraints): - self.constrain_supports(constraints) - return self.get_utxo_count(**constraints) - - @staticmethod - def constrain_collections(constraints): - constraints['txo_type'] = TXO_TYPES['collection'] - - def get_collections(self, **constraints): - self.constrain_collections(constraints) - return self.get_utxos(**constraints) - - def get_collection_count(self, **constraints): - self.constrain_collections(constraints) - return self.get_utxo_count(**constraints) - - async def release_all_outputs(self, account): - await self.execute_fetchall( - TXO.update().values(is_reserved=False).where( - (TXO.c.is_reserved == True) & - (TXO.c.address.in_(select(AccountAddress.c.address).where(in_account(account)))) - ) - ) - - def get_supports_summary(self, **constraints): - return self.get_txos( - txo_type=TXO_TYPES['support'], - is_spent=False, is_my_output=True, - include_is_my_input=True, - no_tx=True, - **constraints - ) + async def get_collections(self, **constraints) -> Tuple[List[Output], Optional[int]]: + return await self.get_claims(txo_type=TXO_TYPES['collection'], **constraints) diff --git a/lbry/db/queries.py b/lbry/db/queries.py new file mode 100644 index 000000000..d6dcf8407 --- /dev/null +++ b/lbry/db/queries.py @@ -0,0 +1,1201 @@ +# pylint: disable=singleton-comparison +import struct +from datetime import date +from decimal import Decimal +from binascii import unhexlify +from operator import itemgetter +from contextvars import ContextVar +from itertools import chain +from typing import NamedTuple, Tuple, Dict, Callable, Optional + +from sqlalchemy import create_engine, union, func, inspect +from sqlalchemy.engine import Engine, Connection +from sqlalchemy.future import select + +from lbry.schema.tags import clean_tags +from lbry.schema.result import Censor, Outputs +from lbry.schema.url import URL, normalize_name +from lbry.error import ResolveCensoredError +from lbry.blockchain.ledger import Ledger +from lbry.blockchain.transaction import Transaction, Output, Input, OutputScript, TXRefImmutable + +from .utils import * +from .tables import * +from .constants import * + + +MAX_QUERY_VARIABLES = 900 + + +_context: ContextVar['QueryContext'] = ContextVar('_context') + + +def ctx(): + return _context.get() + + +def initialize(url: str, ledger: Ledger, track_metrics=False, block_and_filter=None): + engine = create_engine(url) + connection = engine.connect() + if block_and_filter is not None: + blocked_streams, blocked_channels, filtered_streams, filtered_channels = block_and_filter + else: + blocked_streams = blocked_channels = filtered_streams = filtered_channels = {} + _context.set( + QueryContext( + engine=engine, connection=connection, ledger=ledger, + stack=[], metrics={}, is_tracking_metrics=track_metrics, + blocked_streams=blocked_streams, blocked_channels=blocked_channels, + filtered_streams=filtered_streams, filtered_channels=filtered_channels, + ) + ) + + +def check_version_and_create_tables(): + context = ctx() + if SCHEMA_VERSION: + if 
context.has_table('version'): + version = context.fetchone(select(Version.c.version).limit(1)) + if version and version['version'] == SCHEMA_VERSION: + return + metadata.drop_all(context.engine) + metadata.create_all(context.engine) + context.execute(Version.insert().values(version=SCHEMA_VERSION)) + else: + metadata.create_all(context.engine) + + +class QueryContext(NamedTuple): + engine: Engine + connection: Connection + ledger: Ledger + stack: List[List] + metrics: Dict + is_tracking_metrics: bool + blocked_streams: Dict + blocked_channels: Dict + filtered_streams: Dict + filtered_channels: Dict + + @property + def is_postgres(self): + return self.connection.dialect.name == 'postgresql' + + @property + def is_sqlite(self): + return self.connection.dialect.name == 'sqlite' + + def raise_unsupported_dialect(self): + raise RuntimeError(f'Unsupported database dialect: {self.connection.dialect.name}.') + + def reset_metrics(self): + self.stack = [] + self.metrics = {} + + def get_resolve_censor(self) -> Censor: + return Censor(self.blocked_streams, self.blocked_channels) + + def get_search_censor(self) -> Censor: + return Censor(self.filtered_streams, self.filtered_channels) + + def execute(self, sql, *args): + return self.connection.execute(sql, *args) + + def fetchone(self, sql, *args): + row = self.connection.execute(sql, *args).fetchone() + return dict(row._mapping) if row else row + + def fetchall(self, sql, *args): + rows = self.connection.execute(sql, *args).fetchall() + return [dict(row._mapping) for row in rows] + + def insert_or_ignore(self, table): + if self.is_sqlite: + return table.insert().prefix_with("OR IGNORE") + elif self.is_postgres: + return pg_insert(table).on_conflict_do_nothing() + else: + self.raise_unsupported_dialect() + + def insert_or_replace(self, table, replace): + if self.is_sqlite: + return table.insert().prefix_with("OR REPLACE") + elif self.is_postgres: + insert = pg_insert(table) + return insert.on_conflict_do_update( + table.primary_key, set_={col: getattr(insert.excluded, col) for col in replace} + ) + else: + self.raise_unsupported_dialect() + + def has_table(self, table): + return inspect(self.engine).has_table(table) + + +class RowCollector: + + def __init__(self, context: QueryContext): + self.context = context + self.ledger = context.ledger + self.blocks = [] + self.txs = [] + self.txos = [] + self.txis = [] + + @staticmethod + def block_to_row(block): + return { + 'block_hash': block.block_hash, + 'previous_hash': block.prev_block_hash, + 'file_number': block.file_number, + 'height': 0 if block.is_first_block else None, + } + + @staticmethod + def tx_to_row(block_hash: bytes, tx: Transaction): + row = { + 'tx_hash': tx.hash, + 'block_hash': block_hash, + 'raw': tx.raw, + 'height': tx.height, + 'position': tx.position, + 'is_verified': tx.is_verified, + # TODO: fix + # 'day': tx.get_ordinal_day(self.db.ledger), + 'purchased_claim_hash': None, + } + txos = tx.outputs + if len(txos) >= 2 and txos[1].can_decode_purchase_data: + txos[0].purchase = txos[1] + row['purchased_claim_hash'] = txos[1].purchase_data.claim_hash + return row + + @staticmethod + def txi_to_row(tx: Transaction, txi: Input): + return { + 'tx_hash': tx.hash, + 'txo_hash': txi.txo_ref.hash, + 'position': txi.position, + } + + def txo_to_row(self, tx: Transaction, txo: Output): + row = { + 'tx_hash': tx.hash, + 'txo_hash': txo.hash, + 'address': txo.get_address(self.ledger) if txo.has_address else None, + 'position': txo.position, + 'amount': txo.amount, + 'script_offset': 
txo.script.offset, + 'script_length': txo.script.length, + 'txo_type': 0, + 'claim_id': None, + 'claim_hash': None, + 'claim_name': None, + 'reposted_claim_hash': None, + 'channel_hash': None, + } + if txo.is_claim: + if txo.can_decode_claim: + claim = txo.claim + row['txo_type'] = TXO_TYPES.get(claim.claim_type, TXO_TYPES['stream']) + if claim.is_repost: + row['reposted_claim_hash'] = claim.repost.reference.claim_hash + if claim.is_signed: + row['channel_hash'] = claim.signing_channel_hash + else: + row['txo_type'] = TXO_TYPES['stream'] + elif txo.is_support: + row['txo_type'] = TXO_TYPES['support'] + elif txo.purchase is not None: + row['txo_type'] = TXO_TYPES['purchase'] + row['claim_id'] = txo.purchased_claim_id + row['claim_hash'] = txo.purchased_claim_hash + if txo.script.is_claim_involved: + row['claim_id'] = txo.claim_id + row['claim_hash'] = txo.claim_hash + row['claim_name'] = txo.claim_name + return row + + def add_block(self, block): + self.blocks.append(self.block_to_row(block)) + for tx in block.txs: + self.add_transaction(block.block_hash, tx) + return self + + def add_transaction(self, block_hash: bytes, tx: Transaction): + self.txs.append(self.tx_to_row(block_hash, tx)) + for txi in tx.inputs: + if txi.coinbase is None: + self.txis.append(self.txi_to_row(tx, txi)) + for txo in tx.outputs: + self.txos.append(self.txo_to_row(tx, txo)) + return self + + def save(self, progress: Callable = None): + queries = ( + (Block.insert(), self.blocks), + (TX.insert(), self.txs), + (TXO.insert(), self.txos), + (TXI.insert(), self.txis), + ) + total_rows = sum(len(query[1]) for query in queries) + inserted_rows = 0 + if progress is not None: + progress(inserted_rows, total_rows) + execute = self.context.connection.execute + for sql, rows in queries: + for chunk_size, chunk_rows in chunk(rows, 10000): + execute(sql, list(chunk_rows)) + inserted_rows += chunk_size + if progress is not None: + progress(inserted_rows, total_rows) + + +def insert_transaction(block_hash, tx): + RowCollector(ctx()).add_transaction(block_hash, tx).save() + + +def process_claims_and_supports(block_range=None): + context = ctx() + if context.is_sqlite: + address_query = select(TXO.c.address).where(TXI.c.txo_hash == TXO.c.txo_hash) + sql = ( + TXI.update() + .values(address=address_query.scalar_subquery()) + .where(TXI.c.address == None) + ) + else: + sql = ( + TXI.update() + .values({TXI.c.address: TXO.c.address}) + .where((TXI.c.address == None) & (TXI.c.txo_hash == TXO.c.txo_hash)) + ) + context.execute(sql) + + context.execute(Claim.delete()) + for claim in get_txos(txo_type__in=CLAIM_TYPE_CODES, is_spent=False)[0]: + context.execute( + Claim.insert(), { + 'claim_hash': claim.claim_hash, + 'claim_name': claim.claim_name, + 'amount': claim.amount, + 'txo_hash': claim.hash + } + ) + + +def execute_fetchall(sql): + return ctx().fetchall(text(sql)) + + +def get_best_height(): + return ctx().fetchone( + select(func.coalesce(func.max(TX.c.height), 0).label('total')).select_from(TX) + )['total'] + + +def get_blocks_without_filters(): + return ctx().fetchall( + select(Block.c.block_hash) + .select_from(Block) + .where(Block.c.block_filter == None) + ) + + +def get_transactions_without_filters(): + return ctx().fetchall( + select(TX.c.tx_hash) + .select_from(TX) + .where(TX.c.tx_filter == None) + ) + + +def get_block_tx_addresses(block_hash=None, tx_hash=None): + if block_hash is not None: + constraint = (TX.c.block_hash == block_hash) + elif tx_hash is not None: + constraint = (TX.c.tx_hash == tx_hash) + else: + raise 
ValueError('block_hash or tx_hash must be provided.') + return ctx().fetchall( + union( + select(TXO.c.address).select_from(TXO.join(TX)).where((TXO.c.address != None) & constraint), + select(TXI.c.address).select_from(TXI.join(TX)).where((TXI.c.address != None) & constraint), + ) + ) + + +def get_block_address_filters(): + return ctx().fetchall( + select(Block.c.block_hash, Block.c.block_filter).select_from(Block) + ) + + +def get_transaction_address_filters(block_hash): + return ctx().fetchall( + select(TX.c.tx_hash, TX.c.tx_filter) + .select_from(TX) + .where(TX.c.block_hash == block_hash) + ) + + +def update_address_used_times(addresses): + ctx().execute( + PubkeyAddress.update() + .values(used_times=( + select(func.count(TXO.c.address)).where((TXO.c.address == PubkeyAddress.c.address)), + )) + .where(PubkeyAddress.c.address._in(addresses)) + ) + + +def reserve_outputs(txo_hashes, is_reserved=True): + ctx().execute( + TXO.update().values(is_reserved=is_reserved).where(TXO.c.txo_hash.in_(txo_hashes)) + ) + + +def release_all_outputs(account_id): + ctx().execute( + TXO.update().values(is_reserved=False).where( + (TXO.c.is_reserved == True) & + (TXO.c.address.in_(select(AccountAddress.c.address).where(in_account_ids(account_id)))) + ) + ) + + +def select_transactions(cols, account_ids=None, **constraints): + s: Select = select(*cols).select_from(TX) + if not {'tx_hash', 'tx_hash__in'}.intersection(constraints): + assert account_ids, "'accounts' argument required when no 'tx_hash' constraint is present" + where = in_account_ids(account_ids) + tx_hashes = union( + select(TXO.c.tx_hash).select_from(txo_join_account).where(where), + select(TXI.c.tx_hash).select_from(txi_join_account).where(where) + ) + s = s.where(TX.c.tx_hash.in_(tx_hashes)) + return ctx().fetchall(query([TX], s, **constraints)) + + +TXO_NOT_MINE = Output(None, None, is_my_output=False) + + +def get_raw_transactions(tx_hashes): + return ctx().fetchall( + select(TX.c.tx_hash, TX.c.raw).where(TX.c.tx_hash.in_(tx_hashes)) + ) + + +def get_transactions(wallet=None, include_total=False, **constraints) -> Tuple[List[Transaction], Optional[int]]: + include_is_spent = constraints.pop('include_is_spent', False) + include_is_my_input = constraints.pop('include_is_my_input', False) + include_is_my_output = constraints.pop('include_is_my_output', False) + + tx_rows = select_transactions( + [TX.c.tx_hash, TX.c.raw, TX.c.height, TX.c.position, TX.c.is_verified], + order_by=constraints.pop('order_by', ["height=0 DESC", "height DESC", "position DESC"]), + **constraints + ) + + txids, txs, txi_txoids = [], [], [] + for row in tx_rows: + txids.append(row['tx_hash']) + txs.append(Transaction( + raw=row['raw'], height=row['height'], position=row['position'], + is_verified=bool(row['is_verified']) + )) + for txi in txs[-1].inputs: + txi_txoids.append(txi.txo_ref.hash) + + annotated_txos = {} + for offset in range(0, len(txids), MAX_QUERY_VARIABLES): + annotated_txos.update({ + txo.id: txo for txo in + get_txos( + wallet=wallet, + tx_hash__in=txids[offset:offset + MAX_QUERY_VARIABLES], order_by='txo.tx_hash', + include_is_spent=include_is_spent, + include_is_my_input=include_is_my_input, + include_is_my_output=include_is_my_output, + )[0] + }) + + referenced_txos = {} + for offset in range(0, len(txi_txoids), MAX_QUERY_VARIABLES): + referenced_txos.update({ + txo.id: txo for txo in + get_txos( + wallet=wallet, + txo_hash__in=txi_txoids[offset:offset + MAX_QUERY_VARIABLES], order_by='txo.txo_hash', + include_is_my_output=include_is_my_output, + 
)[0] + }) + + for tx in txs: + for txi in tx.inputs: + txo = referenced_txos.get(txi.txo_ref.id) + if txo: + txi.txo_ref = txo.ref + for txo in tx.outputs: + _txo = annotated_txos.get(txo.id) + if _txo: + txo.update_annotations(_txo) + else: + txo.update_annotations(TXO_NOT_MINE) + + for tx in txs: + txos = tx.outputs + if len(txos) >= 2 and txos[1].can_decode_purchase_data: + txos[0].purchase = txos[1] + + return txs, get_transaction_count(**constraints) if include_total else None + + +def get_transaction_count(**constraints): + constraints.pop('wallet', None) + constraints.pop('offset', None) + constraints.pop('limit', None) + constraints.pop('order_by', None) + count = select_transactions([func.count().label('total')], **constraints) + return count[0]['total'] or 0 + + +def select_txos( + cols, account_ids=None, is_my_input=None, is_my_output=True, + is_my_input_or_output=None, exclude_internal_transfers=False, + include_is_spent=False, include_is_my_input=False, + is_spent=None, spent=None, is_claim_list=False, **constraints): + s: Select = select(*cols) + if account_ids: + my_addresses = select(AccountAddress.c.address).where(in_account_ids(account_ids)) + if is_my_input_or_output: + include_is_my_input = True + s = s.where( + TXO.c.address.in_(my_addresses) | ( + (TXI.c.address != None) & + (TXI.c.address.in_(my_addresses)) + ) + ) + else: + if is_my_output: + s = s.where(TXO.c.address.in_(my_addresses)) + elif is_my_output is False: + s = s.where(TXO.c.address.notin_(my_addresses)) + if is_my_input: + include_is_my_input = True + s = s.where( + (TXI.c.address != None) & + (TXI.c.address.in_(my_addresses)) + ) + elif is_my_input is False: + include_is_my_input = True + s = s.where( + (TXI.c.address == None) | + (TXI.c.address.notin_(my_addresses)) + ) + if exclude_internal_transfers: + include_is_my_input = True + s = s.where( + (TXO.c.txo_type != TXO_TYPES['other']) | + (TXO.c.address.notin_(my_addresses)) + (TXI.c.address == None) | + (TXI.c.address.notin_(my_addresses)) + ) + joins = TXO.join(TX) + tables = [TXO, TX] + if spent is None: + spent = TXI.alias('spent') + if is_spent: + s = s.where(spent.c.txo_hash != None) + elif is_spent is False: + s = s.where((spent.c.txo_hash == None) & (TXO.c.is_reserved == False)) + if include_is_spent or is_spent is not None: + joins = joins.join(spent, spent.c.txo_hash == TXO.c.txo_hash, isouter=True) + if include_is_my_input: + joins = joins.join(TXI, (TXI.c.position == 0) & (TXI.c.tx_hash == TXO.c.tx_hash), isouter=True) + if is_claim_list: + tables.append(Claim) + joins = joins.join(Claim) + s = s.select_from(joins) + return ctx().fetchall(query(tables, s, **constraints)) + + +def get_txos(no_tx=False, include_total=False, **constraints) -> Tuple[List[Output], Optional[int]]: + wallet_account_ids = constraints.pop('wallet_account_ids', []) + include_is_spent = constraints.get('include_is_spent', False) + include_is_my_input = constraints.get('include_is_my_input', False) + include_is_my_output = constraints.pop('include_is_my_output', False) + include_received_tips = constraints.pop('include_received_tips', False) + + select_columns = [ + TX.c.tx_hash, TX.c.raw, TX.c.height, TX.c.position.label('tx_position'), TX.c.is_verified, + TXO.c.txo_type, TXO.c.position.label('txo_position'), TXO.c.amount, + TXO.c.script_offset, TXO.c.script_length, + TXO.c.claim_name + + ] + + my_accounts = None + if wallet_account_ids: + my_accounts = select(AccountAddress.c.address).where(in_account_ids(wallet_account_ids)) + + if include_is_my_output and 
my_accounts is not None: + if constraints.get('is_my_output', None) in (True, False): + select_columns.append(text(f"{1 if constraints['is_my_output'] else 0} AS is_my_output")) + else: + select_columns.append(TXO.c.address.in_(my_accounts).label('is_my_output')) + + if include_is_my_input and my_accounts is not None: + if constraints.get('is_my_input', None) in (True, False): + select_columns.append(text(f"{1 if constraints['is_my_input'] else 0} AS is_my_input")) + else: + select_columns.append(( + (TXI.c.address != None) & + (TXI.c.address.in_(my_accounts)) + ).label('is_my_input')) + + spent = TXI.alias('spent') + if include_is_spent: + select_columns.append((spent.c.txo_hash != None).label('is_spent')) + + if include_received_tips: + support = TXO.alias('support') + select_columns.append( + select(func.coalesce(func.sum(support.c.amount), 0)) + .select_from(support).where( + (support.c.claim_hash == TXO.c.claim_hash) & + (support.c.txo_type == TXO_TYPES['support']) & + (support.c.address.in_(my_accounts)) & + (support.c.txo_hash.notin_(select(TXI.c.txo_hash))) + ).label('received_tips') + ) + + if 'order_by' not in constraints or constraints['order_by'] == 'height': + constraints['order_by'] = [ + "tx.height=0 DESC", "tx.height DESC", "tx.position DESC", "txo.position" + ] + elif constraints.get('order_by', None) == 'none': + del constraints['order_by'] + + rows = select_txos(select_columns, spent=spent, **constraints) + + txs = {} + txos = [] + for row in rows: + if no_tx: + source = row['raw'][row['script_offset']:row['script_offset']+row['script_length']] + txo = Output( + amount=row['amount'], + script=OutputScript(source), + tx_ref=TXRefImmutable.from_hash(row['tx_hash'], row['height']), + position=row['txo_position'] + ) + else: + if row['tx_hash'] not in txs: + txs[row['tx_hash']] = Transaction( + row['raw'], height=row['height'], position=row['tx_position'], + is_verified=bool(row['is_verified']) + ) + txo = txs[row['tx_hash']].outputs[row['txo_position']] + if include_is_spent: + txo.is_spent = bool(row['is_spent']) + if include_is_my_input: + txo.is_my_input = bool(row['is_my_input']) + if include_is_my_output: + txo.is_my_output = bool(row['is_my_output']) + if include_is_my_input and include_is_my_output: + if txo.is_my_input and txo.is_my_output and row['txo_type'] == TXO_TYPES['other']: + txo.is_internal_transfer = True + else: + txo.is_internal_transfer = False + if include_received_tips: + txo.received_tips = row['received_tips'] + txos.append(txo) + + channel_hashes = set() + for txo in txos: + if txo.is_claim and txo.can_decode_claim: + if txo.claim.is_signed: + channel_hashes.add(txo.claim.signing_channel_hash) + + if channel_hashes: + channels = { + txo.claim_hash: txo for txo in + get_txos( + txo_type=TXO_TYPES['channel'], is_spent=False, + wallet_account_ids=wallet_account_ids, claim_hash__in=channel_hashes + )[0] + } + for txo in txos: + if txo.is_claim and txo.can_decode_claim: + txo.channel = channels.get(txo.claim.signing_channel_hash, None) + + return txos, get_txo_count(**constraints) if include_total else None + + +def _clean_txo_constraints_for_aggregation(constraints): + constraints.pop('include_is_spent', None) + constraints.pop('include_is_my_input', None) + constraints.pop('include_is_my_output', None) + constraints.pop('include_received_tips', None) + constraints.pop('wallet_account_ids', None) + constraints.pop('offset', None) + constraints.pop('limit', None) + constraints.pop('order_by', None) + + +def get_txo_count(**constraints): + 
_clean_txo_constraints_for_aggregation(constraints) + count = select_txos([func.count().label('total')], **constraints) + return count[0]['total'] or 0 + + +def get_txo_sum(**constraints): + _clean_txo_constraints_for_aggregation(constraints) + result = select_txos([func.sum(TXO.c.amount).label('total')], **constraints) + return result[0]['total'] or 0 + + +def get_txo_plot(start_day=None, days_back=0, end_day=None, days_after=None, **constraints): + _clean_txo_constraints_for_aggregation(constraints) + if start_day is None: + # TODO: Fix + raise NotImplementedError + current_ordinal = 0 # self.ledger.headers.estimated_date(self.ledger.headers.height).toordinal() + constraints['day__gte'] = current_ordinal - days_back + else: + constraints['day__gte'] = date.fromisoformat(start_day).toordinal() + if end_day is not None: + constraints['day__lte'] = date.fromisoformat(end_day).toordinal() + elif days_after is not None: + constraints['day__lte'] = constraints['day__gte'] + days_after + plot = select_txos( + [TX.c.day, func.sum(TXO.c.amount).label('total')], + group_by='day', order_by='day', **constraints + ) + for row in plot: + row['day'] = date.fromordinal(row['day']) + return plot + + +def get_purchases(**constraints) -> Tuple[List[Output], Optional[int]]: + accounts = constraints.pop('accounts', None) + assert accounts, "'accounts' argument required to find purchases" + if not {'purchased_claim_hash', 'purchased_claim_hash__in'}.intersection(constraints): + constraints['purchased_claim_hash__is_not_null'] = True + constraints['tx_hash__in'] = ( + select(TXI.c.tx_hash).select_from(txi_join_account).where(in_account(accounts)) + ) + txs, count = get_transactions(**constraints) + return [tx.outputs[0] for tx in txs], count + + +def get_balance(**constraints): + balance = select_txos( + [func.sum(TXO.c.amount).label('total')], is_spent=False, **constraints + ) + return balance[0]['total'] or 0 + + +def select_addresses(cols, **constraints): + return ctx().fetchall(query( + [AccountAddress, PubkeyAddress], + select(*cols).select_from(PubkeyAddress.join(AccountAddress)), + **constraints + )) + + +def get_addresses(cols=None, include_total=False, **constraints) -> Tuple[List[dict], Optional[int]]: + if cols is None: + cols = ( + PubkeyAddress.c.address, + PubkeyAddress.c.used_times, + AccountAddress.c.account, + AccountAddress.c.chain, + AccountAddress.c.pubkey, + AccountAddress.c.chain_code, + AccountAddress.c.n, + AccountAddress.c.depth + ) + return ( + select_addresses(cols, **constraints), + get_address_count(**constraints) if include_total else None + ) + + +def get_address_count(**constraints): + count = select_addresses([func.count().label('total')], **constraints) + return count[0]['total'] or 0 + + +def get_all_addresses(self): + return ctx().execute(select(PubkeyAddress.c.address)) + + +def add_keys(account, chain, pubkeys): + c = ctx() + c.execute( + c.insert_or_ignore(PubkeyAddress) + .values([{'address': k.address} for k in pubkeys]) + ) + c.execute( + c.insert_or_ignore(AccountAddress) + .values([{ + 'account': account.id, + 'address': k.address, + 'chain': chain, + 'pubkey': k.pubkey_bytes, + 'chain_code': k.chain_code, + 'n': k.n, + 'depth': k.depth + } for k in pubkeys]) + ) + + +def get_supports_summary(self, **constraints): + return get_txos( + txo_type=TXO_TYPES['support'], + is_spent=False, is_my_output=True, + include_is_my_input=True, + no_tx=True, + **constraints + ) + + +def search_to_bytes(constraints) -> Union[bytes, Tuple[bytes, Dict]]: + return 
+def search_to_bytes(constraints) -> Union[bytes, Tuple[bytes, Dict]]:
+    return Outputs.to_bytes(*search(**constraints))
+
+
+def resolve_to_bytes(urls) -> Union[bytes, Tuple[bytes, Dict]]:
+    return Outputs.to_bytes(*resolve(urls))
+
+
+def execute_censored(sql, row_offset: int, row_limit: int, censor: Censor) -> List:
+    context = ctx()
+    return context.fetchall(sql)
+    # TODO: the apsw-based censoring/offset/limit row filtering below was carried over
+    # from the previous implementation and is currently unreachable.
+    c = context.db.cursor()
+    def row_filter(cursor, row):
+        nonlocal row_offset
+        #row = row_factory(cursor, row)
+        if len(row) > 1 and censor.censor(row):
+            return
+        if row_offset:
+            row_offset -= 1
+            return
+        return row
+    c.setrowtrace(row_filter)
+    i, rows = 0, []
+    for row in c.execute(sql):
+        i += 1
+        rows.append(row)
+        if i >= row_limit:
+            break
+    return rows
+
+
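+# Illustrative sketch of the constraint syntax understood by claims_query() below
+# (the values are made-up examples): integer params from SEARCH_INTEGER_PARAMS may
+# carry a comparison prefix ('<=', '>=', '<', '>'), and a leading '^' on an order_by
+# field sorts ascending:
+#
+#     claims_query(cols, height='>=800000', fee_amount='<=0.5',
+#                  channel_ids=['deadbeefdeadbeefdeadbeefdeadbeefdeadbeef'],
+#                  order_by=['^name', 'effective_amount'])
+#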
+def claims_query(cols, for_count=False, **constraints) -> Tuple[str, Dict]:
+    if 'order_by' in constraints:
+        order_by_parts = constraints['order_by']
+        if isinstance(order_by_parts, str):
+            order_by_parts = [order_by_parts]
+        sql_order_by = []
+        for order_by in order_by_parts:
+            is_asc = order_by.startswith('^')
+            column = order_by[1:] if is_asc else order_by
+            if column not in SEARCH_ORDER_FIELDS:
+                raise NameError(f'{column} is not a valid order_by field')
+            if column == 'name':
+                column = 'claim.claim_name'
+            sql_order_by.append(
+                f"{column} ASC" if is_asc else f"{column} DESC"
+            )
+        constraints['order_by'] = sql_order_by
+
+    ops = {'<=': '__lte', '>=': '__gte', '<': '__lt', '>': '__gt'}
+    for constraint in SEARCH_INTEGER_PARAMS:
+        if constraint in constraints:
+            value = constraints.pop(constraint)
+            postfix = ''
+            if isinstance(value, str):
+                if len(value) >= 2 and value[:2] in ops:
+                    postfix, value = ops[value[:2]], value[2:]
+                elif len(value) >= 1 and value[0] in ops:
+                    postfix, value = ops[value[0]], value[1:]
+            if constraint == 'fee_amount':
+                value = Decimal(value)*1000
+            constraints[f'claim.{constraint}{postfix}'] = int(value)
+
+    if constraints.pop('is_controlling', False):
+        if {'sequence', 'amount_order'}.isdisjoint(constraints):
+            for_count = False
+            constraints['claimtrie.claim_hash__is_not_null'] = ''
+    if 'sequence' in constraints:
+        constraints['order_by'] = 'claim.activation_height ASC'
+        constraints['offset'] = int(constraints.pop('sequence')) - 1
+        constraints['limit'] = 1
+    if 'amount_order' in constraints:
+        constraints['order_by'] = 'claim.effective_amount DESC'
+        constraints['offset'] = int(constraints.pop('amount_order')) - 1
+        constraints['limit'] = 1
+
+    if 'claim_id' in constraints:
+        claim_id = constraints.pop('claim_id')
+        if len(claim_id) == 40:
+            constraints['claim.claim_id'] = claim_id
+        else:
+            constraints['claim.claim_id__like'] = f'{claim_id[:40]}%'
+    elif 'claim_ids' in constraints:
+        constraints['claim.claim_id__in'] = set(constraints.pop('claim_ids'))
+
+    if 'reposted_claim_id' in constraints:
+        constraints['claim.reposted_claim_hash'] = unhexlify(constraints.pop('reposted_claim_id'))[::-1]
+
+    if 'name' in constraints:
+        constraints['claim_name'] = normalize_name(constraints.pop('name'))
+
+    if 'public_key_id' in constraints:
+        constraints['claim.public_key_hash'] = (
+            ctx().ledger.address_to_hash160(constraints.pop('public_key_id')))
+    if 'channel_hash' in constraints:
+        constraints['claim.channel_hash'] = constraints.pop('channel_hash')
+    if 'channel_ids' in constraints:
+        channel_ids = constraints.pop('channel_ids')
+        if channel_ids:
+            constraints['claim.channel_hash__in'] = {
+                unhexlify(cid)[::-1] for cid in channel_ids
+            }
+    if 'not_channel_ids' in constraints:
+        not_channel_ids = constraints.pop('not_channel_ids')
+        if not_channel_ids:
+            not_channel_ids_binary = {
+                unhexlify(ncid)[::-1] for ncid in not_channel_ids
+            }
+            constraints['claim.claim_hash__not_in#not_channel_ids'] = not_channel_ids_binary
+            if constraints.get('has_channel_signature', False):
+                constraints['claim.channel_hash__not_in'] = not_channel_ids_binary
+            else:
+                constraints['null_or_not_channel__or'] = {
+                    'claim.signature_valid__is_null': True,
+                    'claim.channel_hash__not_in': not_channel_ids_binary
+                }
+    if 'signature_valid' in constraints:
+        has_channel_signature = constraints.pop('has_channel_signature', False)
+        if has_channel_signature:
+            constraints['claim.signature_valid'] = constraints.pop('signature_valid')
+        else:
+            constraints['null_or_signature__or'] = {
+                'claim.signature_valid__is_null': True,
+                'claim.signature_valid': constraints.pop('signature_valid')
+            }
+    elif constraints.pop('has_channel_signature', False):
+        constraints['claim.signature_valid__is_not_null'] = True
+
+    if 'txid' in constraints:
+        tx_hash = unhexlify(constraints.pop('txid'))[::-1]
+        nout = constraints.pop('nout', 0)
+        constraints['claim.txo_hash'] = tx_hash + struct.pack('<I', nout)
+
+
+def select_claims(censor: Censor, cols, for_count=False, **constraints) -> List:
+    if 'channel' in constraints:
+        channel_url = constraints.pop('channel')
+        match = resolve_url(channel_url)
+        if isinstance(match, dict):
+            constraints['channel_hash'] = match['claim_hash']
+        else:
+            return [{'row_count': 0}] if for_count else []
+    row_offset = constraints.pop('offset', 0)
+    row_limit = constraints.pop('limit', 20)
+    return execute_censored(
+        claims_query(cols, for_count, **constraints),
+        row_offset, row_limit, censor
+    )
+
+
+def count_claims(**constraints) -> int:
+    constraints.pop('offset', None)
+    constraints.pop('limit', None)
+    constraints.pop('order_by', None)
+    count = select_claims(Censor(), [func.count().label('row_count')], for_count=True, **constraints)
+    return count[0]['row_count']
+
+
+def search_claims(censor: Censor, **constraints) -> List:
+    return select_claims(
+        censor, [
+            Claimtrie.c.claim_hash.label('is_controlling'),
+            Claimtrie.c.last_take_over_height,
+            Claim.c.claim_hash,
+            Claim.c.txo_hash,
+#            Claim.c.claims_in_channel,
+#            Claim.c.reposted,
+#            Claim.c.height,
+#            Claim.c.creation_height,
+#            Claim.c.activation_height,
+#            Claim.c.expiration_height,
+#            Claim.c.effective_amount,
+#            Claim.c.support_amount,
+#            Claim.c.trending_group,
+#            Claim.c.trending_mixed,
+#            Claim.c.trending_local,
+#            Claim.c.trending_global,
+#            Claim.c.short_url,
+#            Claim.c.canonical_url,
+            Claim.c.channel_hash,
+            Claim.c.reposted_claim_hash,
+#            Claim.c.signature_valid
+        ], **constraints
+    )
+
+
+def get_claims(**constraints) -> Tuple[List[Output], Optional[int]]:
+    return get_txos(no_tx=True, is_claim_list=True, **constraints)
+
+
+def _get_referenced_rows(txo_rows: List[dict], censor_channels: List[bytes]):
+    censor = ctx().get_resolve_censor()
+    repost_hashes = set(filter(None, map(itemgetter('reposted_claim_hash'), txo_rows)))
+    channel_hashes = set(chain(
+        filter(None, map(itemgetter('channel_hash'), txo_rows)),
+        censor_channels
+    ))
+
+    reposted_txos = []
+    if repost_hashes:
+        reposted_txos = search_claims(censor, **{'claim.claim_hash__in': repost_hashes})
+        channel_hashes |= set(filter(None, map(itemgetter('channel_hash'), reposted_txos)))
+
+    channel_txos = []
+    if channel_hashes:
+        channel_txos = search_claims(censor, **{'claim.claim_hash__in': channel_hashes})
+
+    # channels must come first for client side inflation to work properly
+    return channel_txos + reposted_txos
+
+
+def search(**constraints) -> Tuple[List, List, int, int, Censor]:
+    assert set(constraints).issubset(SEARCH_PARAMS), \
+        
f"Search query contains invalid arguments: {set(constraints).difference(SEARCH_PARAMS)}" + total = None + if not constraints.pop('no_totals', False): + total = count_claims(**constraints) + constraints['offset'] = abs(constraints.get('offset', 0)) + constraints['limit'] = min(abs(constraints.get('limit', 10)), 50) + context = ctx() + search_censor = context.get_search_censor() + txo_rows = search_claims(search_censor, **constraints) + extra_txo_rows = _get_referenced_rows(txo_rows, search_censor.censored.keys()) + return txo_rows, extra_txo_rows, constraints['offset'], total, search_censor + + +def resolve(urls) -> Tuple[List, List]: + txo_rows = [resolve_url(raw_url) for raw_url in urls] + extra_txo_rows = _get_referenced_rows( + [txo for txo in txo_rows if isinstance(txo, dict)], + [txo.censor_hash for txo in txo_rows if isinstance(txo, ResolveCensoredError)] + ) + return txo_rows, extra_txo_rows + + +def resolve_url(raw_url): + censor = ctx().get_resolve_censor() + + try: + url = URL.parse(raw_url) + except ValueError as e: + return e + + channel = None + + if url.has_channel: + query = url.channel.to_dict() + if set(query) == {'name'}: + query['is_controlling'] = True + else: + query['order_by'] = ['^creation_height'] + matches = search_claims(censor, **query, limit=1) + if matches: + channel = matches[0] + elif censor.censored: + return ResolveCensoredError(raw_url, next(iter(censor.censored))) + else: + return LookupError(f'Could not find channel in "{raw_url}".') + + if url.has_stream: + query = url.stream.to_dict() + if channel is not None: + if set(query) == {'name'}: + # temporarily emulate is_controlling for claims in channel + query['order_by'] = ['effective_amount', '^height'] + else: + query['order_by'] = ['^channel_join'] + query['channel_hash'] = channel['claim_hash'] + query['signature_valid'] = 1 + elif set(query) == {'name'}: + query['is_controlling'] = 1 + matches = search_claims(censor, **query, limit=1) + if matches: + return matches[0] + elif censor.censored: + return ResolveCensoredError(raw_url, next(iter(censor.censored))) + else: + return LookupError(f'Could not find claim at "{raw_url}".') + + return channel + + +CLAIM_HASH_OR_REPOST_HASH_SQL = f""" +CASE WHEN claim.claim_type = {TXO_TYPES['repost']} + THEN claim.reposted_claim_hash + ELSE claim.claim_hash +END +""" + + +def _apply_constraints_for_array_attributes(constraints, attr, cleaner, for_count=False): + any_items = set(cleaner(constraints.pop(f'any_{attr}s', []))[:ATTRIBUTE_ARRAY_MAX_LENGTH]) + all_items = set(cleaner(constraints.pop(f'all_{attr}s', []))[:ATTRIBUTE_ARRAY_MAX_LENGTH]) + not_items = set(cleaner(constraints.pop(f'not_{attr}s', []))[:ATTRIBUTE_ARRAY_MAX_LENGTH]) + + all_items = {item for item in all_items if item not in not_items} + any_items = {item for item in any_items if item not in not_items} + + any_queries = {} + +# if attr == 'tag': +# common_tags = any_items & COMMON_TAGS.keys() +# if common_tags: +# any_items -= common_tags +# if len(common_tags) < 5: +# for item in common_tags: +# index_name = COMMON_TAGS[item] +# any_queries[f'#_common_tag_{index_name}'] = f""" +# EXISTS( +# SELECT 1 FROM tag INDEXED BY tag_{index_name}_idx +# WHERE {CLAIM_HASH_OR_REPOST_HASH_SQL}=tag.claim_hash +# AND tag = '{item}' +# ) +# """ +# elif len(common_tags) >= 5: +# constraints.update({ +# f'$any_common_tag{i}': item for i, item in enumerate(common_tags) +# }) +# values = ', '.join( +# f':$any_common_tag{i}' for i in range(len(common_tags)) +# ) +# any_queries[f'#_any_common_tags'] = f""" +# EXISTS( 
+#                    SELECT 1 FROM tag WHERE {CLAIM_HASH_OR_REPOST_HASH_SQL}=tag.claim_hash
+#                    AND tag IN ({values})
+#                )
+#            """
+
+    if any_items:
+
+        constraints.update({
+            f'$any_{attr}{i}': item for i, item in enumerate(any_items)
+        })
+        values = ', '.join(
+            f':$any_{attr}{i}' for i in range(len(any_items))
+        )
+        if for_count or attr == 'tag':
+            any_queries[f'#_any_{attr}'] = f"""
+            {CLAIM_HASH_OR_REPOST_HASH_SQL} IN (
+                SELECT claim_hash FROM {attr} WHERE {attr} IN ({values})
+            )
+            """
+        else:
+            any_queries[f'#_any_{attr}'] = f"""
+            EXISTS(
+                SELECT 1 FROM {attr} WHERE
+                    {CLAIM_HASH_OR_REPOST_HASH_SQL}={attr}.claim_hash
+                AND {attr} IN ({values})
+            )
+            """
+
+    if len(any_queries) == 1:
+        constraints.update(any_queries)
+    elif len(any_queries) > 1:
+        constraints[f'ORed_{attr}_queries__any'] = any_queries
+
+    if all_items:
+        constraints[f'$all_{attr}_count'] = len(all_items)
+        constraints.update({
+            f'$all_{attr}{i}': item for i, item in enumerate(all_items)
+        })
+        values = ', '.join(
+            f':$all_{attr}{i}' for i in range(len(all_items))
+        )
+        if for_count:
+            constraints[f'#_all_{attr}'] = f"""
+            {CLAIM_HASH_OR_REPOST_HASH_SQL} IN (
+                SELECT claim_hash FROM {attr} WHERE {attr} IN ({values})
+                GROUP BY claim_hash HAVING COUNT({attr}) = :$all_{attr}_count
+            )
+            """
+        else:
+            constraints[f'#_all_{attr}'] = f"""
+            {len(all_items)}=(
+                SELECT count(*) FROM {attr} WHERE
+                    {CLAIM_HASH_OR_REPOST_HASH_SQL}={attr}.claim_hash
+                AND {attr} IN ({values})
+            )
+            """
+
+    if not_items:
+        constraints.update({
+            f'$not_{attr}{i}': item for i, item in enumerate(not_items)
+        })
+        values = ', '.join(
+            f':$not_{attr}{i}' for i in range(len(not_items))
+        )
+        if for_count:
+            constraints[f'#_not_{attr}'] = f"""
+            {CLAIM_HASH_OR_REPOST_HASH_SQL} NOT IN (
+                SELECT claim_hash FROM {attr} WHERE {attr} IN ({values})
+            )
+            """
+        else:
+            constraints[f'#_not_{attr}'] = f"""
+            NOT EXISTS(
+                SELECT 1 FROM {attr} WHERE
+                    {CLAIM_HASH_OR_REPOST_HASH_SQL}={attr}.claim_hash
+                AND {attr} IN ({values})
+            )
+            """
diff --git a/lbry/db/search.py b/lbry/db/search.py
index 86b24f009..fec9f75de 100644
--- a/lbry/db/search.py
+++ b/lbry/db/search.py
@@ -18,7 +18,7 @@ from lbry.schema.tags import clean_tags
 from lbry.schema.result import Outputs, Censor
 from lbry.blockchain.ledger import Ledger, RegTestLedger
 
-from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS
+from .constants import TXO_TYPES as CLAIM_TYPES, STREAM_TYPES
 from .full_text_search import FTS_ORDER_BY
 
 
@@ -34,32 +34,6 @@ class SQLiteInterruptedError(apsw.InterruptError):
         self.metrics = metrics
 
 
-ATTRIBUTE_ARRAY_MAX_LENGTH = 100
-
-INTEGER_PARAMS = {
-    'height', 'creation_height', 'activation_height', 'expiration_height',
-    'timestamp', 'creation_timestamp', 'duration', 'release_time', 'fee_amount',
-    'tx_position', 'channel_join', 'reposted',
-    'amount', 'effective_amount', 'support_amount',
-    'trending_group', 'trending_mixed',
-    'trending_local', 'trending_global',
-}
-
-SEARCH_PARAMS = {
-    'name', 'text', 'claim_id', 'claim_ids', 'txid', 'nout', 'channel', 'channel_ids', 'not_channel_ids',
-    'public_key_id', 'claim_type', 'stream_types', 'media_types', 'fee_currency',
-    'has_channel_signature', 'signature_valid',
-    'any_tags', 'all_tags', 'not_tags', 'reposted_claim_id',
-    'any_locations', 'all_locations', 'not_locations',
-    'any_languages', 'all_languages', 'not_languages',
-    'is_controlling', 'limit', 'offset', 'order_by',
-    'no_totals',
-} | INTEGER_PARAMS
-
-
-ORDER_FIELDS = {
-    'name', 'claim_hash'
-} | INTEGER_PARAMS
 
 
 @dataclass
@@ -166,446 +140,3 @@ def reports_metrics(func):
     return wrapper
 
 
-@reports_metrics -def search_to_bytes(constraints) -> Union[bytes, Tuple[bytes, Dict]]: - return encode_result(search(constraints)) - - -@reports_metrics -def resolve_to_bytes(urls) -> Union[bytes, Tuple[bytes, Dict]]: - return encode_result(resolve(urls)) - - -def encode_result(result): - return Outputs.to_bytes(*result) - - -@measure -def execute_query(sql, values, row_offset: int, row_limit: int, censor: Censor) -> List: - context = ctx.get() - context.set_query_timeout() - try: - c = context.db.cursor() - def row_filter(cursor, row): - nonlocal row_offset - row = row_factory(cursor, row) - if len(row) > 1 and censor.censor(row): - return - if row_offset: - row_offset -= 1 - return - return row - c.setrowtrace(row_filter) - i, rows = 0, [] - for row in c.execute(sql, values): - i += 1 - rows.append(row) - if i >= row_limit: - break - return rows - except apsw.Error as err: - plain_sql = interpolate(sql, values) - if context.is_tracking_metrics: - context.metrics['execute_query'][-1]['sql'] = plain_sql - if isinstance(err, apsw.InterruptError): - context.log.warning("interrupted slow sqlite query:\n%s", plain_sql) - raise SQLiteInterruptedError(context.metrics) - context.log.exception('failed running query', exc_info=err) - raise SQLiteOperationalError(context.metrics) - - -def claims_query(cols, for_count=False, **constraints) -> Tuple[str, Dict]: - if 'order_by' in constraints: - order_by_parts = constraints['order_by'] - if isinstance(order_by_parts, str): - order_by_parts = [order_by_parts] - sql_order_by = [] - for order_by in order_by_parts: - is_asc = order_by.startswith('^') - column = order_by[1:] if is_asc else order_by - if column not in ORDER_FIELDS: - raise NameError(f'{column} is not a valid order_by field') - if column == 'name': - column = 'normalized' - sql_order_by.append( - f"claim.{column} ASC" if is_asc else f"claim.{column} DESC" - ) - constraints['order_by'] = sql_order_by - - ops = {'<=': '__lte', '>=': '__gte', '<': '__lt', '>': '__gt'} - for constraint in INTEGER_PARAMS: - if constraint in constraints: - value = constraints.pop(constraint) - postfix = '' - if isinstance(value, str): - if len(value) >= 2 and value[:2] in ops: - postfix, value = ops[value[:2]], value[2:] - elif len(value) >= 1 and value[0] in ops: - postfix, value = ops[value[0]], value[1:] - if constraint == 'fee_amount': - value = Decimal(value)*1000 - constraints[f'claim.{constraint}{postfix}'] = int(value) - - if constraints.pop('is_controlling', False): - if {'sequence', 'amount_order'}.isdisjoint(constraints): - for_count = False - constraints['claimtrie.claim_hash__is_not_null'] = '' - if 'sequence' in constraints: - constraints['order_by'] = 'claim.activation_height ASC' - constraints['offset'] = int(constraints.pop('sequence')) - 1 - constraints['limit'] = 1 - if 'amount_order' in constraints: - constraints['order_by'] = 'claim.effective_amount DESC' - constraints['offset'] = int(constraints.pop('amount_order')) - 1 - constraints['limit'] = 1 - - if 'claim_id' in constraints: - claim_id = constraints.pop('claim_id') - if len(claim_id) == 40: - constraints['claim.claim_id'] = claim_id - else: - constraints['claim.claim_id__like'] = f'{claim_id[:40]}%' - elif 'claim_ids' in constraints: - constraints['claim.claim_id__in'] = set(constraints.pop('claim_ids')) - - if 'reposted_claim_id' in constraints: - constraints['claim.reposted_claim_hash'] = unhexlify(constraints.pop('reposted_claim_id'))[::-1] - - if 'name' in constraints: - constraints['claim.normalized'] = 
normalize_name(constraints.pop('name'))
-
-    if 'public_key_id' in constraints:
-        constraints['claim.public_key_hash'] = (
-            ctx.get().ledger.address_to_hash160(constraints.pop('public_key_id')))
-    if 'channel_hash' in constraints:
-        constraints['claim.channel_hash'] = constraints.pop('channel_hash')
-    if 'channel_ids' in constraints:
-        channel_ids = constraints.pop('channel_ids')
-        if channel_ids:
-            constraints['claim.channel_hash__in'] = {
-                unhexlify(cid)[::-1] for cid in channel_ids
-            }
-    if 'not_channel_ids' in constraints:
-        not_channel_ids = constraints.pop('not_channel_ids')
-        if not_channel_ids:
-            not_channel_ids_binary = {
-                unhexlify(ncid)[::-1] for ncid in not_channel_ids
-            }
-            constraints['claim.claim_hash__not_in#not_channel_ids'] = not_channel_ids_binary
-            if constraints.get('has_channel_signature', False):
-                constraints['claim.channel_hash__not_in'] = not_channel_ids_binary
-            else:
-                constraints['null_or_not_channel__or'] = {
-                    'claim.signature_valid__is_null': True,
-                    'claim.channel_hash__not_in': not_channel_ids_binary
-                }
-    if 'signature_valid' in constraints:
-        has_channel_signature = constraints.pop('has_channel_signature', False)
-        if has_channel_signature:
-            constraints['claim.signature_valid'] = constraints.pop('signature_valid')
-        else:
-            constraints['null_or_signature__or'] = {
-                'claim.signature_valid__is_null': True,
-                'claim.signature_valid': constraints.pop('signature_valid')
-            }
-    elif constraints.pop('has_channel_signature', False):
-        constraints['claim.signature_valid__is_not_null'] = True
-
-    if 'txid' in constraints:
-        tx_hash = unhexlify(constraints.pop('txid'))[::-1]
-        nout = constraints.pop('nout', 0)
-        constraints['claim.txo_hash'] = tx_hash + struct.pack('<I', nout)
-
-
-def select_claims(censor: Censor, cols, for_count=False, **constraints) -> List:
-    if 'channel' in constraints:
-        channel_url = constraints.pop('channel')
-        match = resolve_url(channel_url)
-        if isinstance(match, dict):
-            constraints['channel_hash'] = match['claim_hash']
-        else:
-            return [{'row_count': 0}] if cols == 'count(*) as row_count' else []
-    row_offset = constraints.pop('offset', 0)
-    row_limit = constraints.pop('limit', 20)
-    sql, values = claims_query(cols, for_count, **constraints)
-    return execute_query(sql, values, row_offset, row_limit, censor)
-
-
-@measure
-def count_claims(**constraints) -> int:
-    constraints.pop('offset', None)
-    constraints.pop('limit', None)
-    constraints.pop('order_by', None)
-    count = select_claims(Censor(), 'count(*) as row_count', for_count=True, **constraints)
-    return count[0]['row_count']
-
-
-def search_claims(censor: Censor, **constraints) -> List:
-    return select_claims(
-        censor,
-        """
-        claimtrie.claim_hash as is_controlling,
-        claimtrie.last_take_over_height,
-        claim.claim_hash, claim.txo_hash,
-        claim.claims_in_channel, claim.reposted,
-        claim.height, claim.creation_height,
-        claim.activation_height, claim.expiration_height,
-        claim.effective_amount, claim.support_amount,
-        claim.trending_group, claim.trending_mixed,
-        claim.trending_local, claim.trending_global,
-        claim.short_url, claim.canonical_url,
-        claim.channel_hash, claim.reposted_claim_hash,
-        claim.signature_valid
-        """, **constraints
-    )
-
-
-def _get_referenced_rows(txo_rows: List[dict], censor_channels: List[bytes]):
-    censor = ctx.get().get_resolve_censor()
-    repost_hashes = set(filter(None, map(itemgetter('reposted_claim_hash'), txo_rows)))
-    channel_hashes = set(chain(
-        filter(None, map(itemgetter('channel_hash'), txo_rows)),
-        censor_channels
-    ))
-
-    reposted_txos = []
-    if repost_hashes:
-        reposted_txos = search_claims(censor, 
**{'claim.claim_hash__in': repost_hashes}) - channel_hashes |= set(filter(None, map(itemgetter('channel_hash'), reposted_txos))) - - channel_txos = [] - if channel_hashes: - channel_txos = search_claims(censor, **{'claim.claim_hash__in': channel_hashes}) - - # channels must come first for client side inflation to work properly - return channel_txos + reposted_txos - -@measure -def search(constraints) -> Tuple[List, List, int, int, Censor]: - assert set(constraints).issubset(SEARCH_PARAMS), \ - f"Search query contains invalid arguments: {set(constraints).difference(SEARCH_PARAMS)}" - total = None - if not constraints.pop('no_totals', False): - total = count_claims(**constraints) - constraints['offset'] = abs(constraints.get('offset', 0)) - constraints['limit'] = min(abs(constraints.get('limit', 10)), 50) - context = ctx.get() - search_censor = context.get_search_censor() - txo_rows = search_claims(search_censor, **constraints) - extra_txo_rows = _get_referenced_rows(txo_rows, search_censor.censored.keys()) - return txo_rows, extra_txo_rows, constraints['offset'], total, search_censor - - -@measure -def resolve(urls) -> Tuple[List, List]: - txo_rows = [resolve_url(raw_url) for raw_url in urls] - extra_txo_rows = _get_referenced_rows( - [txo for txo in txo_rows if isinstance(txo, dict)], - [txo.censor_hash for txo in txo_rows if isinstance(txo, ResolveCensoredError)] - ) - return txo_rows, extra_txo_rows - - -@measure -def resolve_url(raw_url): - censor = ctx.get().get_resolve_censor() - - try: - url = URL.parse(raw_url) - except ValueError as e: - return e - - channel = None - - if url.has_channel: - query = url.channel.to_dict() - if set(query) == {'name'}: - query['is_controlling'] = True - else: - query['order_by'] = ['^creation_height'] - matches = search_claims(censor, **query, limit=1) - if matches: - channel = matches[0] - elif censor.censored: - return ResolveCensoredError(raw_url, next(iter(censor.censored))) - else: - return LookupError(f'Could not find channel in "{raw_url}".') - - if url.has_stream: - query = url.stream.to_dict() - if channel is not None: - if set(query) == {'name'}: - # temporarily emulate is_controlling for claims in channel - query['order_by'] = ['effective_amount', '^height'] - else: - query['order_by'] = ['^channel_join'] - query['channel_hash'] = channel['claim_hash'] - query['signature_valid'] = 1 - elif set(query) == {'name'}: - query['is_controlling'] = 1 - matches = search_claims(censor, **query, limit=1) - if matches: - return matches[0] - elif censor.censored: - return ResolveCensoredError(raw_url, next(iter(censor.censored))) - else: - return LookupError(f'Could not find claim at "{raw_url}".') - - return channel - - -CLAIM_HASH_OR_REPOST_HASH_SQL = f""" -CASE WHEN claim.claim_type = {CLAIM_TYPES['repost']} - THEN claim.reposted_claim_hash - ELSE claim.claim_hash -END -""" - - -def _apply_constraints_for_array_attributes(constraints, attr, cleaner, for_count=False): - any_items = set(cleaner(constraints.pop(f'any_{attr}s', []))[:ATTRIBUTE_ARRAY_MAX_LENGTH]) - all_items = set(cleaner(constraints.pop(f'all_{attr}s', []))[:ATTRIBUTE_ARRAY_MAX_LENGTH]) - not_items = set(cleaner(constraints.pop(f'not_{attr}s', []))[:ATTRIBUTE_ARRAY_MAX_LENGTH]) - - all_items = {item for item in all_items if item not in not_items} - any_items = {item for item in any_items if item not in not_items} - - any_queries = {} - - if attr == 'tag': - common_tags = any_items & COMMON_TAGS.keys() - if common_tags: - any_items -= common_tags - if len(common_tags) < 5: - for item in 
common_tags: - index_name = COMMON_TAGS[item] - any_queries[f'#_common_tag_{index_name}'] = f""" - EXISTS( - SELECT 1 FROM tag INDEXED BY tag_{index_name}_idx - WHERE {CLAIM_HASH_OR_REPOST_HASH_SQL}=tag.claim_hash - AND tag = '{item}' - ) - """ - elif len(common_tags) >= 5: - constraints.update({ - f'$any_common_tag{i}': item for i, item in enumerate(common_tags) - }) - values = ', '.join( - f':$any_common_tag{i}' for i in range(len(common_tags)) - ) - any_queries[f'#_any_common_tags'] = f""" - EXISTS( - SELECT 1 FROM tag WHERE {CLAIM_HASH_OR_REPOST_HASH_SQL}=tag.claim_hash - AND tag IN ({values}) - ) - """ - - if any_items: - - constraints.update({ - f'$any_{attr}{i}': item for i, item in enumerate(any_items) - }) - values = ', '.join( - f':$any_{attr}{i}' for i in range(len(any_items)) - ) - if for_count or attr == 'tag': - any_queries[f'#_any_{attr}'] = f""" - {CLAIM_HASH_OR_REPOST_HASH_SQL} IN ( - SELECT claim_hash FROM {attr} WHERE {attr} IN ({values}) - ) - """ - else: - any_queries[f'#_any_{attr}'] = f""" - EXISTS( - SELECT 1 FROM {attr} WHERE - {CLAIM_HASH_OR_REPOST_HASH_SQL}={attr}.claim_hash - AND {attr} IN ({values}) - ) - """ - - if len(any_queries) == 1: - constraints.update(any_queries) - elif len(any_queries) > 1: - constraints[f'ORed_{attr}_queries__any'] = any_queries - - if all_items: - constraints[f'$all_{attr}_count'] = len(all_items) - constraints.update({ - f'$all_{attr}{i}': item for i, item in enumerate(all_items) - }) - values = ', '.join( - f':$all_{attr}{i}' for i in range(len(all_items)) - ) - if for_count: - constraints[f'#_all_{attr}'] = f""" - {CLAIM_HASH_OR_REPOST_HASH_SQL} IN ( - SELECT claim_hash FROM {attr} WHERE {attr} IN ({values}) - GROUP BY claim_hash HAVING COUNT({attr}) = :$all_{attr}_count - ) - """ - else: - constraints[f'#_all_{attr}'] = f""" - {len(all_items)}=( - SELECT count(*) FROM {attr} WHERE - {CLAIM_HASH_OR_REPOST_HASH_SQL}={attr}.claim_hash - AND {attr} IN ({values}) - ) - """ - - if not_items: - constraints.update({ - f'$not_{attr}{i}': item for i, item in enumerate(not_items) - }) - values = ', '.join( - f':$not_{attr}{i}' for i in range(len(not_items)) - ) - if for_count: - constraints[f'#_not_{attr}'] = f""" - {CLAIM_HASH_OR_REPOST_HASH_SQL} NOT IN ( - SELECT claim_hash FROM {attr} WHERE {attr} IN ({values}) - ) - """ - else: - constraints[f'#_not_{attr}'] = f""" - NOT EXISTS( - SELECT 1 FROM {attr} WHERE - {CLAIM_HASH_OR_REPOST_HASH_SQL}={attr}.claim_hash - AND {attr} IN ({values}) - ) - """ diff --git a/lbry/db/tables.py b/lbry/db/tables.py index 7b084e1a2..845475e41 100644 --- a/lbry/db/tables.py +++ b/lbry/db/tables.py @@ -6,6 +6,9 @@ from sqlalchemy import ( ) +SCHEMA_VERSION = '1.4' + + metadata = MetaData() @@ -18,7 +21,6 @@ Version = Table( PubkeyAddress = Table( 'pubkey_address', metadata, Column('address', Text, primary_key=True), - Column('history', Text, nullable=True), Column('used_times', Integer, server_default='0'), ) @@ -41,6 +43,7 @@ Block = Table( Column('previous_hash', LargeBinary), Column('file_number', SmallInteger), Column('height', Integer), + Column('block_filter', LargeBinary, nullable=True) ) @@ -54,6 +57,7 @@ TX = Table( Column('is_verified', Boolean, server_default='FALSE'), Column('purchased_claim_hash', LargeBinary, nullable=True), Column('day', Integer, nullable=True), + Column('tx_filter', LargeBinary, nullable=True) ) @@ -64,14 +68,13 @@ TXO = Table( Column('address', Text), Column('position', Integer), Column('amount', BigInteger), - Column('script', LargeBinary), + Column('script_offset', 
BigInteger), + Column('script_length', BigInteger), Column('is_reserved', Boolean, server_default='0'), Column('txo_type', Integer, server_default='0'), Column('claim_id', Text, nullable=True), Column('claim_hash', LargeBinary, nullable=True), Column('claim_name', Text, nullable=True), - Column('channel_hash', LargeBinary, nullable=True), - Column('reposted_claim_hash', LargeBinary, nullable=True), ) txo_join_account = TXO.join(AccountAddress, TXO.columns.address == AccountAddress.columns.address) @@ -81,8 +84,37 @@ TXI = Table( 'txi', metadata, Column('tx_hash', LargeBinary, ForeignKey(TX.columns.tx_hash)), Column('txo_hash', LargeBinary, ForeignKey(TXO.columns.txo_hash), primary_key=True), - Column('address', Text), + Column('address', Text, nullable=True), Column('position', Integer), ) txi_join_account = TXI.join(AccountAddress, TXI.columns.address == AccountAddress.columns.address) + + +Claim = Table( + 'claim', metadata, + Column('claim_hash', LargeBinary, primary_key=True), + Column('claim_name', Text), + Column('txo_hash', LargeBinary, ForeignKey(TXO.columns.txo_hash)), + Column('amount', BigInteger), + Column('channel_hash', LargeBinary, nullable=True), + Column('effective_amount', BigInteger, server_default='0'), + Column('reposted_claim_hash', LargeBinary, nullable=True), + Column('activation_height', Integer, server_default='0'), + Column('expiration_height', Integer, server_default='0'), +) + + +Tag = Table( + 'tag', metadata, + Column('claim_hash', LargeBinary), + Column('tag', Text), +) + + +Claimtrie = Table( + 'claimtrie', metadata, + Column('normalized', Text, primary_key=True), + Column('claim_hash', LargeBinary, ForeignKey(Claim.columns.claim_hash)), + Column('last_take_over_height', Integer), +) diff --git a/lbry/db/utils.py b/lbry/db/utils.py new file mode 100644 index 000000000..a9cfff6b8 --- /dev/null +++ b/lbry/db/utils.py @@ -0,0 +1,129 @@ +from itertools import islice +from typing import List, Union + +from sqlalchemy import text, and_ +from sqlalchemy.sql.expression import Select +try: + from sqlalchemy.dialects.postgresql import insert as pg_insert +except ImportError: + pg_insert = None + +from .tables import AccountAddress + + +def chunk(rows, step): + it, total = iter(rows), len(rows) + for _ in range(0, total, step): + yield min(step, total), islice(it, step) + total -= step + + +def constrain_single_or_list(constraints, column, value, convert=lambda x: x): + if value is not None: + if isinstance(value, list): + value = [convert(v) for v in value] + if len(value) == 1: + constraints[column] = value[0] + elif len(value) > 1: + constraints[f"{column}__in"] = value + else: + constraints[column] = convert(value) + return constraints + + +def in_account_ids(account_ids: Union[List[str], str]): + if isinstance(account_ids, list): + if len(account_ids) > 1: + return AccountAddress.c.account.in_(account_ids) + account_ids = account_ids[0] + return AccountAddress.c.account == account_ids + + +def query(table, s: Select, **constraints) -> Select: + limit = constraints.pop('limit', None) + if limit is not None: + s = s.limit(limit) + + offset = constraints.pop('offset', None) + if offset is not None: + s = s.offset(offset) + + order_by = constraints.pop('order_by', None) + if order_by: + if isinstance(order_by, str): + s = s.order_by(text(order_by)) + elif isinstance(order_by, list): + s = s.order_by(text(', '.join(order_by))) + else: + raise ValueError("order_by must be string or list") + + group_by = constraints.pop('group_by', None) + if group_by is not None: + s = 
s.group_by(text(group_by)) + + account_ids = constraints.pop('account_ids', []) + if account_ids: + s = s.where(in_account_ids(account_ids)) + + if constraints: + s = s.where( + constraints_to_clause(table, constraints) + ) + + return s + + +def constraints_to_clause(tables, constraints): + clause = [] + for key, constraint in constraints.items(): + if key.endswith('__not'): + col, op = key[:-len('__not')], '__ne__' + elif key.endswith('__is_null'): + col = key[:-len('__is_null')] + op = '__eq__' + constraint = None + elif key.endswith('__is_not_null'): + col = key[:-len('__is_not_null')] + op = '__ne__' + constraint = None + elif key.endswith('__lt'): + col, op = key[:-len('__lt')], '__lt__' + elif key.endswith('__lte'): + col, op = key[:-len('__lte')], '__le__' + elif key.endswith('__gt'): + col, op = key[:-len('__gt')], '__gt__' + elif key.endswith('__gte'): + col, op = key[:-len('__gte')], '__ge__' + elif key.endswith('__like'): + col, op = key[:-len('__like')], 'like' + elif key.endswith('__not_like'): + col, op = key[:-len('__not_like')], 'notlike' + elif key.endswith('__in') or key.endswith('__not_in'): + if key.endswith('__in'): + col, op, one_val_op = key[:-len('__in')], 'in_', '__eq__' + else: + col, op, one_val_op = key[:-len('__not_in')], 'notin_', '__ne__' + if isinstance(constraint, Select): + pass + elif constraint: + if isinstance(constraint, (list, set, tuple)): + if len(constraint) == 1: + op = one_val_op + constraint = next(iter(constraint)) + elif isinstance(constraint, str): + constraint = text(constraint) + else: + raise ValueError(f"{col} requires a list, set or string as constraint value.") + else: + continue + else: + col, op = key, '__eq__' + attr = None + for table in tables: + attr = getattr(table.c, col, None) + if attr is not None: + clause.append(getattr(attr, op)(constraint)) + break + if attr is None: + raise ValueError(f"Attribute '{col}' not found on tables: {', '.join([t.name for t in tables])}.") + return and_(*clause)
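+
+
+# Illustrative sketch of the keyword mini-language understood by query() and
+# constraints_to_clause() above (assuming sqlalchemy's select and the TX/TXO tables
+# from .tables are imported; the values shown are made up):
+#
+#     s = query(
+#         [TX, TXO],
+#         select(TXO.c.amount).select_from(TXO.join(TX)),
+#         txo_type__in=[1, 2],           # column IN (...)
+#         height__lte=500_000,           # column <= value
+#         address__is_not_null=True,     # column IS NOT NULL (the value is ignored)
+#         order_by='height', limit=20, offset=40
+#     )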