From fa34ff88bcf463670c954736afbdfb0f79c8fbe3 Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Thu, 10 Sep 2020 21:52:23 -0300 Subject: [PATCH] refactor db, make resolve censor right --- lbry/db/queries/resolve.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/lbry/db/queries/resolve.py b/lbry/db/queries/resolve.py index 7221bf15e..7b6eff9e9 100644 --- a/lbry/db/queries/resolve.py +++ b/lbry/db/queries/resolve.py @@ -6,16 +6,22 @@ from lbry.schema.url import URL from lbry.schema.result import Outputs as ResultOutput from lbry.error import ResolveCensoredError from lbry.blockchain.transaction import Output +from . import rows_to_txos from ..query_context import context -from .search import search_claims +from .search import select_claims log = logging.getLogger(__name__) +def resolve_claims(**constraints): + censor = context().get_resolve_censor() + rows = context().fetchall(select_claims(**constraints)) + return rows_to_txos(rows), censor + + def _get_referenced_rows(txo_rows: List[Output], censor_channels: List[bytes]): - # censor = context().get_resolve_censor() repost_hashes = set(txo.reposted_claim.claim_hash for txo in txo_rows if txo.reposted_claim) channel_hashes = set(itertools.chain( (txo.channel.claim_hash for txo in txo_rows if txo.channel), @@ -24,14 +30,14 @@ def _get_referenced_rows(txo_rows: List[Output], censor_channels: List[bytes]): reposted_txos = [] if repost_hashes: - reposted_txos = search_claims(**{'claim.claim_hash__in': repost_hashes}) + reposted_txos = resolve_claims(**{'claim.claim_hash__in': repost_hashes}) if reposted_txos: reposted_txos = reposted_txos[0] channel_hashes |= set(txo.channel.claim_hash for txo in reposted_txos if txo.channel) channel_txos = [] if channel_hashes: - channel_txos = search_claims(**{'claim.claim_hash__in': channel_hashes}) + channel_txos = resolve_claims(**{'claim.claim_hash__in': channel_hashes}) channel_txos = channel_txos[0] if channel_txos else [] # channels must come first for client side inflation to work properly @@ -52,8 +58,6 @@ def resolve(urls, **kwargs) -> Dict[str, Output]: def resolve_url(raw_url): - censor = context().get_resolve_censor() - try: url = URL.parse(raw_url) except ValueError as e: @@ -67,13 +71,12 @@ def resolve_url(raw_url): q['is_controlling'] = True else: q['order_by'] = ['^creation_height'] - #matches = search_claims(censor, **q, limit=1) - matches = search_claims(**q, limit=1)[0] + matches, censor = resolve_claims(**q, limit=1) if matches: channel = matches[0] elif censor.censored: return ResolveCensoredError(raw_url, next(iter(censor.censored))) - else: + elif not channel: return LookupError(f'Could not find channel in "{raw_url}".') if url.has_stream: @@ -84,12 +87,10 @@ def resolve_url(raw_url): q['is_signature_valid'] = True elif set(q) == {'name'}: q['is_controlling'] = True - # matches = search_claims(censor, **q, limit=1) - matches = search_claims(**q, limit=1)[0] + matches, censor = resolve_claims(**q, limit=1) if matches: stream = matches[0] - if channel: - stream.channel = channel + stream.channel = channel return stream elif censor.censored: return ResolveCensoredError(raw_url, next(iter(censor.censored)))