diff --git a/lbry/wallet/server/db/elasticsearch/search.py b/lbry/wallet/server/db/elasticsearch/search.py
index 98cc467f2..883b7877f 100644
--- a/lbry/wallet/server/db/elasticsearch/search.py
+++ b/lbry/wallet/server/db/elasticsearch/search.py
@@ -1,7 +1,7 @@
 import asyncio
 import struct
 from binascii import unhexlify
-from collections import Counter
+from collections import Counter, deque
 from decimal import Decimal
 from operator import itemgetter
 from typing import Optional, List, Iterable, Union
@@ -42,7 +42,7 @@ class SearchIndex:
         self.index = index_prefix + 'claims'
         self.logger = class_logger(__name__, self.__class__.__name__)
         self.claim_cache = LRUCache(2 ** 15)
-        self.short_id_cache = LRUCache(2 ** 17)  # never invalidated, since short ids are forever
+        self.short_id_cache = LRUCache(2 ** 17)
         self.search_cache = LRUCache(2 ** 17)
         self.resolution_cache = LRUCache(2 ** 17)
         self._elastic_host = elastic_host
@@ -134,6 +134,7 @@ class SearchIndex:
             self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels, True), slices=4)
         await self.sync_client.indices.refresh(self.index)
         self.search_cache.clear()
+        self.short_id_cache.clear()
         self.claim_cache.clear()
         self.resolution_cache.clear()
 
@@ -198,7 +199,7 @@ class SearchIndex:
             self.claim_cache.set(result['claim_id'], result)
 
     async def full_id_from_short_id(self, name, short_id, channel_id=None):
-        key = (channel_id or '') + name + short_id
+        key = '#'.join((channel_id or '', name, short_id))
         if key not in self.short_id_cache:
             query = {'name': name, 'claim_id': short_id}
             if channel_id:
@@ -233,40 +234,56 @@ class SearchIndex:
     async def search_ahead(self, **kwargs):
         # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
         per_channel_per_page = kwargs.pop('limit_claims_per_channel')
-        limit = kwargs.pop('limit', 10)
+        page_size = kwargs.pop('limit', 10)
         offset = kwargs.pop('offset', 0)
         kwargs['limit'] = 1000
-        query = expand_query(**kwargs)
-        result = (await self.search_client.search(
+        cache_item = ResultCacheItem.from_cache(f"ahead{per_channel_per_page}{kwargs}", self.search_cache)
+        if cache_item.result is not None:
+            reordered_hits = cache_item.result
+        else:
+            async with cache_item.lock:
+                if cache_item.result:
+                    reordered_hits = cache_item.result
+                else:
+                    query = expand_query(**kwargs)
+                    reordered_hits = await self.__search_ahead(query, page_size, per_channel_per_page)
+                    cache_item.result = reordered_hits
+        return list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
+
+    async def __search_ahead(self, query: dict, page_size: int, per_channel_per_page: int):
+        search_hits = deque((await self.search_client.search(
             query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id']
-        ))['hits']
-        to_inflate = []
-        channel_counter = Counter()
-        delayed = []
-        while result['hits'] or delayed:
-            if len(to_inflate) % limit == 0:
-                channel_counter.clear()
+        ))['hits']['hits'])
+        reordered_hits = []
+        channel_counters = Counter()
+        next_page_hits_maybe_check_later = deque()
+        while search_hits or next_page_hits_maybe_check_later:
+            if reordered_hits and len(reordered_hits) % page_size == 0:
+                channel_counters.clear()
+            elif not reordered_hits:
+                pass
             else:
                 break  # means last page was incomplete and we are left with bad replacements
-            new_delayed = []
-            for id, chann in delayed:
-                if channel_counter[chann] < per_channel_per_page:
-                    to_inflate.append((id, chann))
-                    channel_counter[hit_channel_id] += 1
+            for _ in range(len(next_page_hits_maybe_check_later)):
+                claim_id, channel_id = next_page_hits_maybe_check_later.popleft()
+                if channel_counters[channel_id] < per_channel_per_page:
+                    reordered_hits.append((claim_id, channel_id))
+                    channel_counters[channel_id] += 1
                 else:
-                    new_delayed.append((id, chann))
-            delayed = new_delayed
-            while result['hits']:
-                hit = result['hits'].pop(0)
+                    next_page_hits_maybe_check_later.append((claim_id, channel_id))
+            while search_hits:
+                hit = search_hits.popleft()
                 hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id']
-                if channel_counter[hit_channel_id] < per_channel_per_page:
-                    to_inflate.append((hit_id, hit_channel_id))
-                    channel_counter[hit_channel_id] += 1
-                    if len(to_inflate) % limit == 0:
+                if hit_channel_id is None:
+                    reordered_hits.append((hit_id, hit_channel_id))
+                elif channel_counters[hit_channel_id] < per_channel_per_page:
+                    reordered_hits.append((hit_id, hit_channel_id))
+                    channel_counters[hit_channel_id] += 1
+                    if len(reordered_hits) % page_size == 0:
+                        break
                 else:
-                    delayed.append((hit_id, hit_channel_id))
-        return list(await self.get_many(*(claim_id for claim_id, _ in to_inflate[offset:(offset + limit)])))
+                    next_page_hits_maybe_check_later.append((hit_id, hit_channel_id))
+        return reordered_hits
 
     async def resolve_url(self, raw_url):
         if raw_url not in self.resolution_cache:
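
Below is a minimal, standalone sketch of the reordering pass that the new __search_ahead performs, for anyone who wants to exercise it outside the server. It assumes hits arrive as (claim_id, channel_id) tuples; the function name reorder_hits and the sample data are invented for illustration and are not part of this change.

# Sketch only: mirrors the deque/Counter reordering above with made-up inputs.
from collections import Counter, deque

def reorder_hits(hits, page_size, per_channel_per_page):
    search_hits = deque(hits)
    reordered, counters = [], Counter()
    deferred = deque()  # hits pushed past the current page boundary
    while search_hits or deferred:
        if reordered and len(reordered) % page_size == 0:
            counters.clear()  # a new page starts: per-channel quotas reset
        elif reordered:
            break  # last page came up short, only over-quota hits remain
        # first retry hits deferred from earlier pages
        for _ in range(len(deferred)):
            claim_id, channel_id = deferred.popleft()
            if counters[channel_id] < per_channel_per_page:
                reordered.append((claim_id, channel_id))
                counters[channel_id] += 1
            else:
                deferred.append((claim_id, channel_id))
        # then fill the rest of the page from the raw ranking
        while search_hits:
            claim_id, channel_id = search_hits.popleft()
            if channel_id is None:  # claims without a channel are never limited
                reordered.append((claim_id, channel_id))
            elif counters[channel_id] < per_channel_per_page:
                reordered.append((claim_id, channel_id))
                counters[channel_id] += 1
                if len(reordered) % page_size == 0:
                    break  # page is full, loop around to reset quotas
            else:
                deferred.append((claim_id, channel_id))
    return reordered

# invented data: four claims from chanA, one from chanB
hits = [(f'claim{i}', 'chanA') for i in range(4)] + [('claim4', 'chanB')]
print(reorder_hits(hits, page_size=3, per_channel_per_page=2))
# [('claim0', 'chanA'), ('claim1', 'chanA'), ('claim4', 'chanB'),
#  ('claim2', 'chanA'), ('claim3', 'chanA')]

Note the trade-off: a channel's surplus hits are not dropped, only pushed to later pages, which is why the break on an incomplete page is needed to avoid padding the tail with bad replacements. In search_ahead itself this whole pass sits behind a ResultCacheItem with an async lock, so concurrent identical queries compute the reordering once and each page is then sliced out of the shared result.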