diff --git a/lbry/wallet/server/db/elasticsearch/search.py b/lbry/wallet/server/db/elasticsearch/search.py
index 758358d30..883b7877f 100644
--- a/lbry/wallet/server/db/elasticsearch/search.py
+++ b/lbry/wallet/server/db/elasticsearch/search.py
@@ -234,54 +234,56 @@ class SearchIndex:
     async def search_ahead(self, **kwargs):
         # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
         per_channel_per_page = kwargs.pop('limit_claims_per_channel')
-        limit = kwargs.pop('limit', 10)
+        page_size = kwargs.pop('limit', 10)
         offset = kwargs.pop('offset', 0)
         kwargs['limit'] = 1000
         cache_item = ResultCacheItem.from_cache(f"ahead{per_channel_per_page}{kwargs}", self.search_cache)
         if cache_item.result is not None:
-            to_inflate = cache_item.result
+            reordered_hits = cache_item.result
         else:
             async with cache_item.lock:
                 if cache_item.result:
-                    to_inflate = cache_item.result
+                    reordered_hits = cache_item.result
                 else:
                     query = expand_query(**kwargs)
-                    to_inflate = await self.__search_ahead(query, limit, per_channel_per_page)
-                    cache_item.result = to_inflate
-        return list(await self.get_many(*(claim_id for claim_id, _ in to_inflate[offset:(offset + limit)])))
+                    reordered_hits = await self.__search_ahead(query, page_size, per_channel_per_page)
+                    cache_item.result = reordered_hits
+        return list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
 
-    async def __search_ahead(self, query, limit, per_channel_per_page):
-        result = deque((await self.search_client.search(
+    async def __search_ahead(self, query: dict, page_size: int, per_channel_per_page: int):
+        search_hits = deque((await self.search_client.search(
             query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id']
         ))['hits']['hits'])
-        to_inflate = []
-        channel_counter = Counter()
-        delayed = deque()
-        while result or delayed:
-            if len(to_inflate) % limit == 0:
-                channel_counter.clear()
+        reordered_hits = []
+        channel_counters = Counter()
+        next_page_hits_maybe_check_later = deque()
+        while search_hits or next_page_hits_maybe_check_later:
+            if reordered_hits and len(reordered_hits) % page_size == 0:
+                channel_counters.clear()
+            elif not reordered_hits:
+                pass
             else:
                 break  # means last page was incomplete and we are left with bad replacements
-            for _ in range(len(delayed)):
-                claim_id, channel_id = delayed.popleft()
-                if channel_counter[channel_id] < per_channel_per_page:
-                    to_inflate.append((claim_id, channel_id))
-                    channel_counter[channel_id] += 1
+            for _ in range(len(next_page_hits_maybe_check_later)):
+                claim_id, channel_id = next_page_hits_maybe_check_later.popleft()
+                if channel_counters[channel_id] < per_channel_per_page:
+                    reordered_hits.append((claim_id, channel_id))
+                    channel_counters[channel_id] += 1
                 else:
-                    delayed.append((claim_id, channel_id))
-            while result:
-                hit = result.popleft()
+                    next_page_hits_maybe_check_later.append((claim_id, channel_id))
+            while search_hits:
+                hit = search_hits.popleft()
                 hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id']
                 if hit_channel_id is None:
-                    to_inflate.append((hit_id, hit_channel_id))
-                elif channel_counter[hit_channel_id] < per_channel_per_page:
-                    to_inflate.append((hit_id, hit_channel_id))
-                    channel_counter[hit_channel_id] += 1
-                    if len(to_inflate) % limit == 0:
+                    reordered_hits.append((hit_id, hit_channel_id))
+                elif channel_counters[hit_channel_id] < per_channel_per_page:
+                    reordered_hits.append((hit_id, hit_channel_id))
+                    channel_counters[hit_channel_id] += 1
+                    if len(reordered_hits) % page_size == 0:
                         break
                 else:
-                    delayed.append((hit_id, hit_channel_id))
-        return to_inflate
+                    next_page_hits_maybe_check_later.append((hit_id, hit_channel_id))
+        return reordered_hits
 
     async def resolve_url(self, raw_url):
         if raw_url not in self.resolution_cache:
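For reference outside the diff, the reordering pass that `__search_ahead` implements can be sketched as a standalone function. This is a minimal illustration, not code from the repository: `rerank_hits` and the plain `(claim_id, channel_id)` input tuples are hypothetical stand-ins for the Elasticsearch hit dicts, while the loop body mirrors the renamed logic above.

```python
from collections import Counter, deque


def rerank_hits(hits, page_size, per_channel_per_page):
    """Reorder relevance-sorted (claim_id, channel_id) pairs so that no
    channel occupies more than per_channel_per_page slots on any page."""
    search_hits = deque(hits)
    reordered_hits = []
    channel_counters = Counter()
    next_page_hits_maybe_check_later = deque()
    while search_hits or next_page_hits_maybe_check_later:
        if reordered_hits and len(reordered_hits) % page_size == 0:
            channel_counters.clear()  # page boundary: per-channel quotas reset
        elif not reordered_hits:
            pass  # first iteration: nothing placed yet, quotas already empty
        else:
            break  # last page came up short; only over-quota hits remain
        # First retry hits deferred from earlier pages, oldest first.
        for _ in range(len(next_page_hits_maybe_check_later)):
            claim_id, channel_id = next_page_hits_maybe_check_later.popleft()
            if channel_counters[channel_id] < per_channel_per_page:
                reordered_hits.append((claim_id, channel_id))
                channel_counters[channel_id] += 1
            else:
                next_page_hits_maybe_check_later.append((claim_id, channel_id))
        # Then fill the remainder of the page from the raw result stream.
        while search_hits:
            claim_id, channel_id = search_hits.popleft()
            if channel_id is None:
                # Claims without a channel are never limited.
                reordered_hits.append((claim_id, channel_id))
            elif channel_counters[channel_id] < per_channel_per_page:
                reordered_hits.append((claim_id, channel_id))
                channel_counters[channel_id] += 1
                if len(reordered_hits) % page_size == 0:
                    break  # page full; restart the outer loop to reset quotas
            else:
                next_page_hits_maybe_check_later.append((claim_id, channel_id))
    return reordered_hits
```

With `page_size=2` and `per_channel_per_page=1`, for example, `rerank_hits([('a1', 'ch1'), ('a2', 'ch1'), ('a3', 'ch1'), ('b1', 'ch2'), ('c1', None)], 2, 1)` returns `[('a1', 'ch1'), ('b1', 'ch2'), ('a2', 'ch1'), ('c1', None), ('a3', 'ch1')]`: over-quota hits from `ch1` are deferred to later pages rather than dropped, and the channel-less claim passes straight through.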