import time import struct import apsw import logging from operator import itemgetter from typing import Tuple, List, Dict, Union, Type, Optional from binascii import unhexlify from decimal import Decimal from contextvars import ContextVar from functools import wraps from itertools import chain from dataclasses import dataclass from lbry.wallet.database import query, interpolate from lbry.error import ResolveCensoredError from lbry.schema.url import URL, normalize_name from lbry.schema.tags import clean_tags from lbry.schema.result import Outputs, Censor from lbry.wallet import Ledger, RegTestLedger from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS from .full_text_search import FTS_ORDER_BY class SQLiteOperationalError(apsw.Error): def __init__(self, metrics): super().__init__('sqlite query errored') self.metrics = metrics class SQLiteInterruptedError(apsw.InterruptError): def __init__(self, metrics): super().__init__('sqlite query interrupted') self.metrics = metrics ATTRIBUTE_ARRAY_MAX_LENGTH = 100 INTEGER_PARAMS = { 'height', 'creation_height', 'activation_height', 'expiration_height', 'timestamp', 'creation_timestamp', 'duration', 'release_time', 'fee_amount', 'tx_position', 'channel_join', 'reposted', 'amount', 'effective_amount', 'support_amount', 'trending_group', 'trending_mixed', 'trending_local', 'trending_global', } SEARCH_PARAMS = { 'name', 'text', 'claim_id', 'claim_ids', 'txid', 'nout', 'channel', 'channel_ids', 'not_channel_ids', 'public_key_id', 'claim_type', 'stream_types', 'media_types', 'fee_currency', 'has_channel_signature', 'signature_valid', 'any_tags', 'all_tags', 'not_tags', 'reposted_claim_id', 'any_locations', 'all_locations', 'not_locations', 'any_languages', 'all_languages', 'not_languages', 'is_controlling', 'limit', 'offset', 'order_by', 'no_totals', } | INTEGER_PARAMS ORDER_FIELDS = { 'name', 'claim_hash' } | INTEGER_PARAMS @dataclass class ReaderState: db: apsw.Connection stack: List[List] metrics: Dict is_tracking_metrics: 
bool ledger: Type[Ledger] query_timeout: float log: logging.Logger blocked_streams: Dict blocked_channels: Dict filtered_streams: Dict filtered_channels: Dict def close(self): self.db.close() def reset_metrics(self): self.stack = [] self.metrics = {} def set_query_timeout(self): stop_at = time.perf_counter() + self.query_timeout def interruptor(): if time.perf_counter() >= stop_at: self.db.interrupt() return self.db.setprogresshandler(interruptor, 100) def get_resolve_censor(self) -> Censor: return Censor(self.blocked_streams, self.blocked_channels) def get_search_censor(self) -> Censor: return Censor(self.filtered_streams, self.filtered_channels) ctx: ContextVar[Optional[ReaderState]] = ContextVar('ctx') def row_factory(cursor, row): return { k[0]: (set(row[i].split(',')) if k[0] == 'tags' else row[i]) for i, k in enumerate(cursor.getdescription()) } def initializer(log, _path, _ledger_name, query_timeout, _measure=False, block_and_filter=None): db = apsw.Connection(_path, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI) db.setrowtrace(row_factory) if block_and_filter: blocked_streams, blocked_channels, filtered_streams, filtered_channels = block_and_filter else: blocked_streams = blocked_channels = filtered_streams = filtered_channels = {} ctx.set( ReaderState( db=db, stack=[], metrics={}, is_tracking_metrics=_measure, ledger=Ledger if _ledger_name == 'mainnet' else RegTestLedger, query_timeout=query_timeout, log=log, blocked_streams=blocked_streams, blocked_channels=blocked_channels, filtered_streams=filtered_streams, filtered_channels=filtered_channels, ) ) def cleanup(): ctx.get().close() ctx.set(None) def measure(func): @wraps(func) def wrapper(*args, **kwargs): state = ctx.get() if not state.is_tracking_metrics: return func(*args, **kwargs) metric = {} state.metrics.setdefault(func.__name__, []).append(metric) state.stack.append([]) start = time.perf_counter() try: return func(*args, **kwargs) finally: elapsed = int((time.perf_counter()-start)*1000) 
metric['total'] = elapsed metric['isolated'] = (elapsed-sum(state.stack.pop())) if state.stack: state.stack[-1].append(elapsed) return wrapper def reports_metrics(func): @wraps(func) def wrapper(*args, **kwargs): state = ctx.get() if not state.is_tracking_metrics: return func(*args, **kwargs) state.reset_metrics() r = func(*args, **kwargs) return r, state.metrics return wrapper @reports_metrics def search_to_bytes(constraints) -> Union[bytes, Tuple[bytes, Dict]]: return encode_result(search(constraints)) @reports_metrics def resolve_to_bytes(urls) -> Union[bytes, Tuple[bytes, Dict]]: return encode_result(resolve(urls)) def encode_result(result): return Outputs.to_bytes(*result) @measure def execute_query(sql, values, row_offset: int, row_limit: int, censor: Censor) -> List: context = ctx.get() context.set_query_timeout() try: c = context.db.cursor() def row_filter(cursor, row): nonlocal row_offset row = row_factory(cursor, row) if len(row) > 1 and censor.censor(row): return if row_offset: row_offset -= 1 return return row c.setrowtrace(row_filter) i, rows = 0, [] for row in c.execute(sql, values): i += 1 rows.append(row) if i >= row_limit: break return rows except apsw.Error as err: plain_sql = interpolate(sql, values) if context.is_tracking_metrics: context.metrics['execute_query'][-1]['sql'] = plain_sql if isinstance(err, apsw.InterruptError): context.log.warning("interrupted slow sqlite query:\n%s", plain_sql) raise SQLiteInterruptedError(context.metrics) context.log.exception('failed running query', exc_info=err) raise SQLiteOperationalError(context.metrics) def claims_query(cols, for_count=False, **constraints) -> Tuple[str, Dict]: if 'order_by' in constraints: order_by_parts = constraints['order_by'] if isinstance(order_by_parts, str): order_by_parts = [order_by_parts] sql_order_by = [] for order_by in order_by_parts: is_asc = order_by.startswith('^') column = order_by[1:] if is_asc else order_by if column not in ORDER_FIELDS: raise NameError(f'{column} is 
not a valid order_by field') if column == 'name': column = 'normalized' sql_order_by.append( f"claim.{column} ASC" if is_asc else f"claim.{column} DESC" ) constraints['order_by'] = sql_order_by ops = {'<=': '__lte', '>=': '__gte', '<': '__lt', '>': '__gt'} for constraint in INTEGER_PARAMS: if constraint in constraints: value = constraints.pop(constraint) postfix = '' if isinstance(value, str): if len(value) >= 2 and value[:2] in ops: postfix, value = ops[value[:2]], value[2:] elif len(value) >= 1 and value[0] in ops: postfix, value = ops[value[0]], value[1:] if constraint == 'fee_amount': value = Decimal(value)*1000 constraints[f'claim.{constraint}{postfix}'] = int(value) if constraints.pop('is_controlling', False): if {'sequence', 'amount_order'}.isdisjoint(constraints): for_count = False constraints['claimtrie.claim_hash__is_not_null'] = '' if 'sequence' in constraints: constraints['order_by'] = 'claim.activation_height ASC' constraints['offset'] = int(constraints.pop('sequence')) - 1 constraints['limit'] = 1 if 'amount_order' in constraints: constraints['order_by'] = 'claim.effective_amount DESC' constraints['offset'] = int(constraints.pop('amount_order')) - 1 constraints['limit'] = 1 if 'claim_id' in constraints: claim_id = constraints.pop('claim_id') if len(claim_id) == 40: constraints['claim.claim_id'] = claim_id else: constraints['claim.claim_id__like'] = f'{claim_id[:40]}%' elif 'claim_ids' in constraints: constraints['claim.claim_id__in'] = set(constraints.pop('claim_ids')) if 'reposted_claim_id' in constraints: constraints['claim.reposted_claim_hash'] = unhexlify(constraints.pop('reposted_claim_id'))[::-1] if 'name' in constraints: constraints['claim.normalized'] = normalize_name(constraints.pop('name')) if 'public_key_id' in constraints: constraints['claim.public_key_hash'] = ( ctx.get().ledger.address_to_hash160(constraints.pop('public_key_id'))) if 'channel_hash' in constraints: constraints['claim.channel_hash'] = constraints.pop('channel_hash') if 
'channel_ids' in constraints: channel_ids = constraints.pop('channel_ids') if channel_ids: constraints['claim.channel_hash__in'] = { unhexlify(cid)[::-1] for cid in channel_ids } if 'not_channel_ids' in constraints: not_channel_ids = constraints.pop('not_channel_ids') if not_channel_ids: not_channel_ids_binary = { unhexlify(ncid)[::-1] for ncid in not_channel_ids } if constraints.get('has_channel_signature', False): constraints['claim.channel_hash__not_in'] = not_channel_ids_binary else: constraints['null_or_not_channel__or'] = { 'claim.signature_valid__is_null': True, 'claim.channel_hash__not_in': not_channel_ids_binary } if 'signature_valid' in constraints: has_channel_signature = constraints.pop('has_channel_signature', False) if has_channel_signature: constraints['claim.signature_valid'] = constraints.pop('signature_valid') else: constraints['null_or_signature__or'] = { 'claim.signature_valid__is_null': True, 'claim.signature_valid': constraints.pop('signature_valid') } elif constraints.pop('has_channel_signature', False): constraints['claim.signature_valid__is_not_null'] = True if 'txid' in constraints: tx_hash = unhexlify(constraints.pop('txid'))[::-1] nout = constraints.pop('nout', 0) constraints['claim.txo_hash'] = tx_hash + struct.pack(' List: if 'channel' in constraints: channel_url = constraints.pop('channel') match = resolve_url(channel_url) if isinstance(match, dict): constraints['channel_hash'] = match['claim_hash'] else: return [{'row_count': 0}] if cols == 'count(*) as row_count' else [] row_offset = constraints.pop('offset', 0) row_limit = constraints.pop('limit', 20) sql, values = claims_query(cols, for_count, **constraints) return execute_query(sql, values, row_offset, row_limit, censor) @measure def count_claims(**constraints) -> int: constraints.pop('offset', None) constraints.pop('limit', None) constraints.pop('order_by', None) count = select_claims(Censor(), 'count(*) as row_count', for_count=True, **constraints) return count[0]['row_count'] 
def search_claims(censor: Censor, **constraints) -> List:
    """Select claim rows using the column set needed to build result outputs.

    Thin wrapper over ``select_claims`` that always requests the same columns:
    controlling status, take-over height, hashes, amounts, trending values,
    short/canonical URLs, channel/repost hashes and signature validity.

    :param censor: censor handed through to ``execute_query``'s row filter.
    :param constraints: query constraints forwarded to ``select_claims``.
    :returns: list of row dicts.
    """
    return select_claims(
        censor,
        """
        claimtrie.claim_hash as is_controlling,
        claimtrie.last_take_over_height,
        claim.claim_hash, claim.txo_hash,
        claim.claims_in_channel, claim.reposted,
        claim.height, claim.creation_height,
        claim.activation_height, claim.expiration_height,
        claim.effective_amount, claim.support_amount,
        claim.trending_group, claim.trending_mixed,
        claim.trending_local, claim.trending_global,
        claim.short_url, claim.canonical_url,
        claim.channel_hash, claim.reposted_claim_hash,
        claim.signature_valid
        """, **constraints
    )


def _get_referenced_rows(txo_rows: List[dict], censor_channels: List[bytes]):
    """Fetch the channel and repost claims referenced by ``txo_rows``.

    Reposted claims are fetched first so that their channels can be added to
    the set of channels to fetch. Both lookups use the *resolve* censor, even
    when called from ``search``.

    :param txo_rows: primary result rows; each must have ``reposted_claim_hash``
        and ``channel_hash`` keys (falsy values are ignored).
    :param censor_channels: extra claim hashes to fetch together with the
        channels (callers pass the hashes of censoring channels). Any iterable
        works; ``search`` passes a dict keys view.
    :returns: channel rows followed by reposted-claim rows.
    """
    censor = ctx.get().get_resolve_censor()
    repost_hashes = set(filter(None, map(itemgetter('reposted_claim_hash'), txo_rows)))
    channel_hashes = set(chain(
        filter(None, map(itemgetter('channel_hash'), txo_rows)),
        censor_channels
    ))
    reposted_txos = []
    if repost_hashes:
        reposted_txos = search_claims(censor, **{'claim.claim_hash__in': repost_hashes})
        # channels of the reposted claims must be available to the client too
        channel_hashes |= set(filter(None, map(itemgetter('channel_hash'), reposted_txos)))
    channel_txos = []
    if channel_hashes:
        channel_txos = search_claims(censor, **{'claim.claim_hash__in': channel_hashes})
    # channels must come first for client side inflation to work properly
    return channel_txos + reposted_txos


@measure
def search(constraints) -> Tuple[List, List, int, Optional[int], Censor]:
    """Run a claim search and collect the rows it references.

    ``constraints`` is mutated in place: ``no_totals`` is removed, ``offset``
    is made non-negative (default 0) and ``limit`` is made non-negative and
    capped at 50 (default 10).

    :param constraints: search parameters; every key must be in ``SEARCH_PARAMS``.
    :raises AssertionError: if ``constraints`` contains keys outside
        ``SEARCH_PARAMS``.
    :returns: ``(txo_rows, extra_txo_rows, offset, total, search_censor)``;
        ``total`` is ``None`` when ``no_totals`` was truthy.
    """
    invalid = set(constraints).difference(SEARCH_PARAMS)
    if invalid:
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, and this whitelist must always run because unrecognized
        # keys would otherwise be passed straight on to ``claims_query``.
        # AssertionError is kept so existing callers see the same exception type.
        raise AssertionError(f"Search query contains invalid arguments: {invalid}")

    total = None
    if not constraints.pop('no_totals', False):
        # count_claims works on a copy of the kwargs and drops offset/limit/order_by
        total = count_claims(**constraints)

    constraints['offset'] = abs(constraints.get('offset', 0))
    constraints['limit'] = min(abs(constraints.get('limit', 10)), 50)

    search_censor = ctx.get().get_search_censor()
    txo_rows = search_claims(search_censor, **constraints)
    extra_txo_rows = _get_referenced_rows(txo_rows, search_censor.censored.keys())
    return txo_rows, extra_txo_rows, constraints['offset'], total, search_censor


@measure
def resolve(urls) -> Tuple[List, List]:
    """Resolve each URL and fetch the channels/reposts the results reference.

    :param urls: iterable of raw URL strings.
    :returns: ``(txo_rows, extra_txo_rows)``. ``txo_rows`` has one entry per
        URL: a row dict on success, otherwise whatever ``resolve_url`` returned
        (an exception instance, or ``None``).
    """
    txo_rows = [resolve_url(raw_url) for raw_url in urls]
    extra_txo_rows = _get_referenced_rows(
        [txo for txo in txo_rows if isinstance(txo, dict)],
        [txo.censor_hash for txo in txo_rows if isinstance(txo, ResolveCensoredError)]
    )
    return txo_rows, extra_txo_rows


@measure
def resolve_url(raw_url):
    """Resolve a single URL to a claim row.

    Failures are *returned*, not raised:

    * the ``ValueError`` from ``URL.parse`` if the URL cannot be parsed;
    * ``ResolveCensoredError`` when nothing matched and the censor recorded
      censoring;
    * ``LookupError`` when nothing matched otherwise.

    :returns: the stream row if the URL has a stream part, otherwise the
        channel row (``None`` if the URL has neither part).
    """
    censor = ctx.get().get_resolve_censor()

    try:
        url = URL.parse(raw_url)
    except ValueError as e:
        return e

    channel = None

    if url.has_channel:
        channel_query = url.channel.to_dict()
        if set(channel_query) == {'name'}:
            # bare name: take the controlling claim for that name
            channel_query['is_controlling'] = True
        else:
            # with modifiers: oldest match first ('^' = ascending in claims_query)
            channel_query['order_by'] = ['^creation_height']
        matches = search_claims(censor, **channel_query, limit=1)
        if matches:
            channel = matches[0]
        elif censor.censored:
            return ResolveCensoredError(raw_url, next(iter(censor.censored)))
        else:
            return LookupError(f'Could not find channel in "{raw_url}".')

    if url.has_stream:
        stream_query = url.stream.to_dict()
        if channel is not None:
            if set(stream_query) == {'name'}:
                # temporarily emulate is_controlling for claims in channel
                stream_query['order_by'] = ['effective_amount', '^height']
            else:
                stream_query['order_by'] = ['^channel_join']
            stream_query['channel_hash'] = channel['claim_hash']
            stream_query['signature_valid'] = 1
        elif set(stream_query) == {'name'}:
            stream_query['is_controlling'] = 1
        matches = search_claims(censor, **stream_query, limit=1)
        if matches:
            return matches[0]
        elif censor.censored:
            return ResolveCensoredError(raw_url, next(iter(censor.censored)))
        else:
            return LookupError(f'Could not find claim at "{raw_url}".')

    return channel


def _pop_array_param(constraints, key, cleaner):
    """Pop ``key`` from ``constraints``, clean it, truncate it and return a set."""
    return set(cleaner(constraints.pop(key, []))[:ATTRIBUTE_ARRAY_MAX_LENGTH])


def _bind_array_items(constraints, prefix, items):
    """Store ``items`` in ``constraints`` as ``$<prefix><i>`` bind values.

    :returns: the matching comma-separated ``:$<prefix><i>`` placeholder list.
    """
    constraints.update({
        f'${prefix}{i}': item for i, item in enumerate(items)
    })
    return ', '.join(
        f':${prefix}{i}' for i in range(len(items))
    )


def _apply_constraints_for_array_attributes(constraints, attr, cleaner, for_count=False):
    """Translate ``any_/all_/not_<attr>s`` parameters into SQL constraints.

    Each parameter is popped from ``constraints``, passed through ``cleaner``,
    truncated to ``ATTRIBUTE_ARRAY_MAX_LENGTH`` items and de-duplicated. Items
    listed in ``not_<attr>s`` are removed from the ``any``/``all`` sets. The
    resulting SQL fragments and their ``$``-prefixed bind values are written
    back into ``constraints``. ``attr`` is used as both the table name and the
    column name in the generated SQL.

    When ``for_count`` is true, ``IN (subquery)`` forms are used instead of
    correlated ``EXISTS``/``count(*)`` subqueries. ``any`` tag queries always
    use the ``IN`` form.

    :param constraints: constraint dict, mutated in place.
    :param attr: singular attribute name, e.g. ``'tag'``.
    :param cleaner: callable applied to each raw parameter value.
    :param for_count: whether the query is a count query.
    """
    any_items = _pop_array_param(constraints, f'any_{attr}s', cleaner)
    all_items = _pop_array_param(constraints, f'all_{attr}s', cleaner)
    not_items = _pop_array_param(constraints, f'not_{attr}s', cleaner)

    # exclusions win over inclusions
    all_items -= not_items
    any_items -= not_items

    any_queries = {}

    if attr == 'tag':
        common_tags = any_items & COMMON_TAGS.keys()
        if common_tags:
            any_items -= common_tags
            if len(common_tags) < 5:
                for item in common_tags:
                    index_name = COMMON_TAGS[item]
                    # ``item`` is always a COMMON_TAGS key (set intersection
                    # above), never arbitrary user text, so inlining it is safe.
                    # NOTE(review): it is presumably inlined rather than bound so
                    # that SQLite can satisfy the forced INDEXED BY (e.g. a
                    # partial index on that tag value); verify before
                    # parameterizing.
                    any_queries[f'#_common_tag_{index_name}'] = f"""
                    EXISTS(
                        SELECT 1 FROM tag INDEXED BY tag_{index_name}_idx
                        WHERE claim.claim_hash=tag.claim_hash
                        AND tag = '{item}'
                    )
                    """
            else:
                values = _bind_array_items(constraints, 'any_common_tag', common_tags)
                any_queries['#_any_common_tags'] = f"""
                EXISTS(
                    SELECT 1 FROM tag WHERE claim.claim_hash=tag.claim_hash
                    AND tag IN ({values})
                )
                """

    if any_items:
        values = _bind_array_items(constraints, f'any_{attr}', any_items)
        if for_count or attr == 'tag':
            any_queries[f'claim.claim_hash__in#_any_{attr}'] = f"""
            SELECT claim_hash FROM {attr} WHERE {attr} IN ({values})
            """
        else:
            any_queries[f'#_any_{attr}'] = f"""
            EXISTS(
                SELECT 1 FROM {attr} WHERE
                    claim.claim_hash={attr}.claim_hash
                AND {attr} IN ({values})
            )
            """

    # a single "any" query is a plain AND-ed constraint; several are OR-ed together
    if len(any_queries) == 1:
        constraints.update(any_queries)
    elif len(any_queries) > 1:
        constraints[f'ORed_{attr}_queries__any'] = any_queries

    if all_items:
        constraints[f'$all_{attr}_count'] = len(all_items)
        values = _bind_array_items(constraints, f'all_{attr}', all_items)
        if for_count:
            constraints[f'claim.claim_hash__in#_all_{attr}'] = f"""
            SELECT claim_hash FROM {attr} WHERE {attr} IN ({values})
            GROUP BY claim_hash HAVING COUNT({attr}) = :$all_{attr}_count
            """
        else:
            constraints[f'#_all_{attr}'] = f"""
            {len(all_items)}=(
                SELECT count(*) FROM {attr} WHERE
                    claim.claim_hash={attr}.claim_hash
                AND {attr} IN ({values})
            )
            """

    if not_items:
        values = _bind_array_items(constraints, f'not_{attr}', not_items)
        if for_count:
            constraints[f'claim.claim_hash__not_in#_not_{attr}'] = f"""
            SELECT claim_hash FROM {attr} WHERE {attr} IN ({values})
            """
        else:
            constraints[f'#_not_{attr}'] = f"""
            NOT EXISTS(
                SELECT 1 FROM {attr} WHERE
                    claim.claim_hash={attr}.claim_hash
                AND {attr} IN ({values})
            )
            """