diff --git a/lbry/lbry/wallet/server/db/reader.py b/lbry/lbry/wallet/server/db/reader.py index 032ca8355..66073c6c4 100644 --- a/lbry/lbry/wallet/server/db/reader.py +++ b/lbry/lbry/wallet/server/db/reader.py @@ -177,7 +177,7 @@ def execute_query(sql, values) -> List: raise SQLiteOperationalError(context.metrics) -def get_claims(cols, for_count=False, **constraints) -> List: +def _get_claims(cols, for_count=False, **constraints) -> Tuple[str, Dict]: if 'order_by' in constraints: sql_order_by = [] for order_by in constraints['order_by']: @@ -232,14 +232,6 @@ def get_claims(cols, for_count=False, **constraints) -> List: if 'public_key_id' in constraints: constraints['claim.public_key_hash'] = sqlite3.Binary( ctx.get().ledger.address_to_hash160(constraints.pop('public_key_id'))) - - if 'channel' in constraints: - channel_url = constraints.pop('channel') - match = resolve_url(channel_url) - if isinstance(match, sqlite3.Row): - constraints['channel_hash'] = match['claim_hash'] - else: - return [[0]] if cols == 'count(*)' else [] if 'channel_hash' in constraints: constraints['claim.channel_hash'] = sqlite3.Binary(constraints.pop('channel_hash')) if 'channel_ids' in constraints: @@ -308,7 +300,18 @@ def get_claims(cols, for_count=False, **constraints) -> List: LEFT JOIN claim as channel ON (claim.channel_hash=channel.claim_hash) """, **constraints ) + return sql, values + +def get_claims(cols, for_count=False, **constraints) -> List: + if 'channel' in constraints: + channel_url = constraints.pop('channel') + match = resolve_url(channel_url) + if isinstance(match, sqlite3.Row): + constraints['channel_hash'] = match['claim_hash'] + else: + return [[0]] if cols == 'count(*)' else [] + sql, values = _get_claims(cols, for_count, **constraints) return execute_query(sql, values) diff --git a/lbry/scripts/claim_search_performance.py b/lbry/scripts/claim_search_performance.py index 3fcbc88bd..88c03d373 100644 --- a/lbry/scripts/claim_search_performance.py +++ b/lbry/scripts/claim_search_performance.py @@ -1,10 +1,10 @@ -import sys import os import time +import argparse import asyncio import logging from concurrent.futures.process import ProcessPoolExecutor -from lbry.wallet.server.db.reader import search_to_bytes, initializer +from lbry.wallet.server.db.reader import search_to_bytes, initializer, _get_claims, interpolate from lbry.wallet.ledger import MainNetLedger log = logging.getLogger(__name__) @@ -90,13 +90,12 @@ def get_args(limit=20): def _search(kwargs): start = time.time() - msg = f"offset={kwargs['offset']}, limit={kwargs['limit']}, no_totals={kwargs['no_totals']}, not_tags={kwargs.get('not_tags')}, any_tags={kwargs.get('any_tags')}, order_by={kwargs['order_by']}" try: search_to_bytes(kwargs) t = time.time() - start - return t, f"{t} - {msg}" + return t, kwargs except Exception as err: - return -1, f"failed: error={str(type(err))}({str(err)}) - {msg}" + return -1, f"failed: error={str(type(err))}({str(err)})" async def search(executor, kwargs): @@ -116,21 +115,32 @@ async def main(db_path, max_query_time): tasks = [search(query_executor, constraints) for constraints in get_args()] try: results = await asyncio.gather(*tasks) - times = {msg: ts for ts, msg in results} - log.info("\n".join(sorted(filter(lambda msg: times[msg] > max_query_time, times.keys()), key=lambda msg: times[msg]))) + for ts, constraints in results: + if ts >= max_query_time: + sql = interpolate(*_get_claims(""" + claimtrie.claim_hash as is_controlling, + claimtrie.last_take_over_height, + claim.claim_hash, claim.txo_hash, + claim.claims_in_channel, + claim.height, claim.creation_height, + claim.activation_height, claim.expiration_height, + claim.effective_amount, claim.support_amount, + claim.trending_group, claim.trending_mixed, + claim.trending_local, claim.trending_global, + claim.short_url, claim.canonical_url, + claim.channel_hash, channel.txo_hash AS channel_txo_hash, + channel.height AS channel_height, claim.signature_valid + """, **constraints)) + print(f"Query took {int(ts * 1000)}ms\n{sql}") finally: query_executor.shutdown() if __name__ == "__main__": - args = sys.argv[1:] - if len(args) >= 1: - db_path = args[0] - else: - db_path = os.path.expanduser('~/claims.db') - if len(args) >= 2: - max_query_time = float(args[1]) - else: - max_query_time = -3 - + parser = argparse.ArgumentParser() + parser.add_argument('--db_path', dest='db_path', default=os.path.expanduser('~/claims.db'), type=str) + parser.add_argument('--max_time', dest='max_time', default=0.0, type=float) + args = parser.parse_args() + db_path = args.db_path + max_query_time = args.max_time asyncio.run(main(db_path, max_query_time))