From 5dff02e8bc1a856565e659aef84110e73ddc692a Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Thu, 11 Mar 2021 01:41:55 -0300 Subject: [PATCH] on resolve, get all claims at once --- lbry/wallet/server/db/elastic_search.py | 42 +++++++++++++++---------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/lbry/wallet/server/db/elastic_search.py b/lbry/wallet/server/db/elastic_search.py index 776af9827..af569a10d 100644 --- a/lbry/wallet/server/db/elastic_search.py +++ b/lbry/wallet/server/db/elastic_search.py @@ -19,6 +19,14 @@ from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES from lbry.wallet.server.util import class_logger +class ChannelResolution(str): + pass + + +class StreamResolution(str): + pass + + class SearchIndex: def __init__(self, index_prefix: str, search_timeout=3.0): self.search_timeout = search_timeout @@ -183,6 +191,17 @@ class SearchIndex: async def resolve(self, *urls): censor = Censor(Censor.RESOLVE) results = [await self.resolve_url(url) for url in urls] + missing = await self.get_many(*filter(lambda x: isinstance(x, str), results)) + for index in range(len(results)): + result = results[index] + url = urls[index] + if missing.get(result): + results[index] = missing[result] + elif isinstance(result, StreamResolution): + results[index] = LookupError(f'Could not find claim at "{url}".') + elif isinstance(result, ChannelResolution): + results[index] = LookupError(f'Could not find channel in "{url}".') + censored = [ result if not isinstance(result, dict) or not censor.censor(result) else ResolveCensoredError(url, result['censoring_channel_hash']) @@ -199,7 +218,7 @@ class SearchIndex: results = expand_result(filter(lambda doc: doc['found'], results["docs"])) for result in results: self.claim_cache.set(result['claim_id'], result) - return list(filter(None, map(self.claim_cache.get, claim_ids))) + return {claim_id: self.claim_cache[claim_id] for claim_id in claim_ids if claim_id in self.claim_cache} async def full_id_from_short_id(self, name, short_id, channel_id=None): key = (channel_id or '') + name + short_id @@ -246,15 +265,9 @@ class SearchIndex: return channel_id stream = (await self.resolve_stream(url, channel_id if isinstance(channel_id, str) else None)) or stream if url.has_stream: - result = stream + return StreamResolution(stream) else: - if isinstance(channel_id, str): - result = (await self.get_many(channel_id)) - result = result[0] if len(result) else LookupError(f'Could not find channel in "{url}".') - else: - result = channel_id - - return result + return ChannelResolution(channel_id) async def resolve_channel_id(self, url: URL): if not url.has_channel: @@ -290,10 +303,7 @@ class SearchIndex: claim_id = url.stream.claim_id else: claim_id = await self.full_id_from_short_id(query['name'], query['claim_id'], channel_id) - if claim_id: - stream = await self.get_many(claim_id) - return stream[0] if len(stream) else None - return None + return claim_id if channel_id is not None: if set(query) == {'name'}: @@ -307,7 +317,7 @@ class SearchIndex: query['is_controlling'] = True matches, _, _ = await self.search(**query, limit=1) if matches: - return matches[0] + return matches[0]['claim_id'] async def _get_referenced_rows(self, txo_rows: List[dict]): txo_rows = [row for row in txo_rows if isinstance(row, dict)] @@ -317,12 +327,12 @@ class SearchIndex: reposted_txos = [] if repost_hashes: - reposted_txos = await self.get_many(*repost_hashes) + reposted_txos = list((await self.get_many(*repost_hashes)).values()) channel_hashes |= set(filter(None, (row['channel_id'] for row in reposted_txos))) channel_txos = [] if channel_hashes: - channel_txos = await self.get_many(*channel_hashes) + channel_txos = list((await self.get_many(*channel_hashes)).values()) # channels must come first for client side inflation to work properly return channel_txos + reposted_txos