on resolve, get all claims at once

This commit is contained in:
Victor Shyba 2021-03-11 01:41:55 -03:00
parent 60a59407d8
commit 5dff02e8bc

View file

@ -19,6 +19,14 @@ from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES
from lbry.wallet.server.util import class_logger from lbry.wallet.server.util import class_logger
class ChannelResolution(str):
pass
class StreamResolution(str):
pass
class SearchIndex: class SearchIndex:
def __init__(self, index_prefix: str, search_timeout=3.0): def __init__(self, index_prefix: str, search_timeout=3.0):
self.search_timeout = search_timeout self.search_timeout = search_timeout
@ -183,6 +191,17 @@ class SearchIndex:
async def resolve(self, *urls): async def resolve(self, *urls):
censor = Censor(Censor.RESOLVE) censor = Censor(Censor.RESOLVE)
results = [await self.resolve_url(url) for url in urls] results = [await self.resolve_url(url) for url in urls]
missing = await self.get_many(*filter(lambda x: isinstance(x, str), results))
for index in range(len(results)):
result = results[index]
url = urls[index]
if missing.get(result):
results[index] = missing[result]
elif isinstance(result, StreamResolution):
results[index] = LookupError(f'Could not find claim at "{url}".')
elif isinstance(result, ChannelResolution):
results[index] = LookupError(f'Could not find channel in "{url}".')
censored = [ censored = [
result if not isinstance(result, dict) or not censor.censor(result) result if not isinstance(result, dict) or not censor.censor(result)
else ResolveCensoredError(url, result['censoring_channel_hash']) else ResolveCensoredError(url, result['censoring_channel_hash'])
@ -199,7 +218,7 @@ class SearchIndex:
results = expand_result(filter(lambda doc: doc['found'], results["docs"])) results = expand_result(filter(lambda doc: doc['found'], results["docs"]))
for result in results: for result in results:
self.claim_cache.set(result['claim_id'], result) self.claim_cache.set(result['claim_id'], result)
return list(filter(None, map(self.claim_cache.get, claim_ids))) return {claim_id: self.claim_cache[claim_id] for claim_id in claim_ids if claim_id in self.claim_cache}
async def full_id_from_short_id(self, name, short_id, channel_id=None): async def full_id_from_short_id(self, name, short_id, channel_id=None):
key = (channel_id or '') + name + short_id key = (channel_id or '') + name + short_id
@ -246,15 +265,9 @@ class SearchIndex:
return channel_id return channel_id
stream = (await self.resolve_stream(url, channel_id if isinstance(channel_id, str) else None)) or stream stream = (await self.resolve_stream(url, channel_id if isinstance(channel_id, str) else None)) or stream
if url.has_stream: if url.has_stream:
result = stream return StreamResolution(stream)
else: else:
if isinstance(channel_id, str): return ChannelResolution(channel_id)
result = (await self.get_many(channel_id))
result = result[0] if len(result) else LookupError(f'Could not find channel in "{url}".')
else:
result = channel_id
return result
async def resolve_channel_id(self, url: URL): async def resolve_channel_id(self, url: URL):
if not url.has_channel: if not url.has_channel:
@ -290,10 +303,7 @@ class SearchIndex:
claim_id = url.stream.claim_id claim_id = url.stream.claim_id
else: else:
claim_id = await self.full_id_from_short_id(query['name'], query['claim_id'], channel_id) claim_id = await self.full_id_from_short_id(query['name'], query['claim_id'], channel_id)
if claim_id: return claim_id
stream = await self.get_many(claim_id)
return stream[0] if len(stream) else None
return None
if channel_id is not None: if channel_id is not None:
if set(query) == {'name'}: if set(query) == {'name'}:
@ -307,7 +317,7 @@ class SearchIndex:
query['is_controlling'] = True query['is_controlling'] = True
matches, _, _ = await self.search(**query, limit=1) matches, _, _ = await self.search(**query, limit=1)
if matches: if matches:
return matches[0] return matches[0]['claim_id']
async def _get_referenced_rows(self, txo_rows: List[dict]): async def _get_referenced_rows(self, txo_rows: List[dict]):
txo_rows = [row for row in txo_rows if isinstance(row, dict)] txo_rows = [row for row in txo_rows if isinstance(row, dict)]
@ -317,12 +327,12 @@ class SearchIndex:
reposted_txos = [] reposted_txos = []
if repost_hashes: if repost_hashes:
reposted_txos = await self.get_many(*repost_hashes) reposted_txos = list((await self.get_many(*repost_hashes)).values())
channel_hashes |= set(filter(None, (row['channel_id'] for row in reposted_txos))) channel_hashes |= set(filter(None, (row['channel_id'] for row in reposted_txos)))
channel_txos = [] channel_txos = []
if channel_hashes: if channel_hashes:
channel_txos = await self.get_many(*channel_hashes) channel_txos = list((await self.get_many(*channel_hashes)).values())
# channels must come first for client side inflation to work properly # channels must come first for client side inflation to work properly
return channel_txos + reposted_txos return channel_txos + reposted_txos