on resolve, get all claims at once

This commit is contained in:
Victor Shyba 2021-03-11 01:41:55 -03:00
parent 60a59407d8
commit 5dff02e8bc

View file

@ -19,6 +19,14 @@ from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES
from lbry.wallet.server.util import class_logger
class ChannelResolution(str):
pass
class StreamResolution(str):
pass
class SearchIndex:
def __init__(self, index_prefix: str, search_timeout=3.0):
self.search_timeout = search_timeout
@ -183,6 +191,17 @@ class SearchIndex:
async def resolve(self, *urls):
censor = Censor(Censor.RESOLVE)
results = [await self.resolve_url(url) for url in urls]
missing = await self.get_many(*filter(lambda x: isinstance(x, str), results))
for index in range(len(results)):
result = results[index]
url = urls[index]
if missing.get(result):
results[index] = missing[result]
elif isinstance(result, StreamResolution):
results[index] = LookupError(f'Could not find claim at "{url}".')
elif isinstance(result, ChannelResolution):
results[index] = LookupError(f'Could not find channel in "{url}".')
censored = [
result if not isinstance(result, dict) or not censor.censor(result)
else ResolveCensoredError(url, result['censoring_channel_hash'])
@ -199,7 +218,7 @@ class SearchIndex:
results = expand_result(filter(lambda doc: doc['found'], results["docs"]))
for result in results:
self.claim_cache.set(result['claim_id'], result)
return list(filter(None, map(self.claim_cache.get, claim_ids)))
return {claim_id: self.claim_cache[claim_id] for claim_id in claim_ids if claim_id in self.claim_cache}
async def full_id_from_short_id(self, name, short_id, channel_id=None):
key = (channel_id or '') + name + short_id
@ -246,15 +265,9 @@ class SearchIndex:
return channel_id
stream = (await self.resolve_stream(url, channel_id if isinstance(channel_id, str) else None)) or stream
if url.has_stream:
result = stream
return StreamResolution(stream)
else:
if isinstance(channel_id, str):
result = (await self.get_many(channel_id))
result = result[0] if len(result) else LookupError(f'Could not find channel in "{url}".')
else:
result = channel_id
return result
return ChannelResolution(channel_id)
async def resolve_channel_id(self, url: URL):
if not url.has_channel:
@ -290,10 +303,7 @@ class SearchIndex:
claim_id = url.stream.claim_id
else:
claim_id = await self.full_id_from_short_id(query['name'], query['claim_id'], channel_id)
if claim_id:
stream = await self.get_many(claim_id)
return stream[0] if len(stream) else None
return None
return claim_id
if channel_id is not None:
if set(query) == {'name'}:
@ -307,7 +317,7 @@ class SearchIndex:
query['is_controlling'] = True
matches, _, _ = await self.search(**query, limit=1)
if matches:
return matches[0]
return matches[0]['claim_id']
async def _get_referenced_rows(self, txo_rows: List[dict]):
txo_rows = [row for row in txo_rows if isinstance(row, dict)]
@ -317,12 +327,12 @@ class SearchIndex:
reposted_txos = []
if repost_hashes:
reposted_txos = await self.get_many(*repost_hashes)
reposted_txos = list((await self.get_many(*repost_hashes)).values())
channel_hashes |= set(filter(None, (row['channel_id'] for row in reposted_txos)))
channel_txos = []
if channel_hashes:
channel_txos = await self.get_many(*channel_hashes)
channel_txos = list((await self.get_many(*channel_hashes)).values())
# channels must come first for client side inflation to work properly
return channel_txos + reposted_txos