From fdb0e226560650e5a48802fc02256eade369951e Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Fri, 23 Apr 2021 00:50:35 -0300 Subject: [PATCH] cache search_ahead --- lbry/wallet/server/db/elasticsearch/search.py | 48 ++++++++++++------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/lbry/wallet/server/db/elasticsearch/search.py b/lbry/wallet/server/db/elasticsearch/search.py index 98cc467f2..7b6daa76d 100644 --- a/lbry/wallet/server/db/elasticsearch/search.py +++ b/lbry/wallet/server/db/elasticsearch/search.py @@ -1,7 +1,7 @@ import asyncio import struct from binascii import unhexlify -from collections import Counter +from collections import Counter, deque from decimal import Decimal from operator import itemgetter from typing import Optional, List, Iterable, Union @@ -236,37 +236,51 @@ class SearchIndex: limit = kwargs.pop('limit', 10) offset = kwargs.pop('offset', 0) kwargs['limit'] = 1000 - query = expand_query(**kwargs) - result = (await self.search_client.search( + cache_item = ResultCacheItem.from_cache(f"ahead{per_channel_per_page}{kwargs}", self.search_cache) + if cache_item.result is not None: + to_inflate = cache_item.result + else: + async with cache_item.lock: + if cache_item.result: + to_inflate = cache_item.result + else: + query = expand_query(**kwargs) + to_inflate = await self.__search_ahead(query, limit, per_channel_per_page) + cache_item.result = to_inflate + return list(await self.get_many(*(claim_id for claim_id, _ in to_inflate[offset:(offset + limit)]))) + + async def __search_ahead(self, query, limit, per_channel_per_page): + result = deque((await self.search_client.search( query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id'] - ))['hits'] + ))['hits']['hits']) to_inflate = [] channel_counter = Counter() - delayed = [] - while result['hits'] or delayed: + delayed = deque() + while result or delayed: if len(to_inflate) % limit == 0: channel_counter.clear() else: break # means last page was incomplete and we are left with bad replacements - new_delayed = [] - for id, chann in delayed: - if channel_counter[chann] < per_channel_per_page: - to_inflate.append((id, chann)) - channel_counter[hit_channel_id] += 1 + for _ in range(len(delayed)): + claim_id, channel_id = delayed.popleft() + if channel_counter[channel_id] < per_channel_per_page: + to_inflate.append((claim_id, channel_id)) + channel_counter[channel_id] += 1 else: - new_delayed.append((id, chann)) - delayed = new_delayed - while result['hits']: - hit = result['hits'].pop(0) + delayed.append((claim_id, channel_id)) + while result: + hit = result.popleft() hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id'] - if channel_counter[hit_channel_id] < per_channel_per_page: + if hit_channel_id is None: + to_inflate.append((hit_id, hit_channel_id)) + elif channel_counter[hit_channel_id] < per_channel_per_page: to_inflate.append((hit_id, hit_channel_id)) channel_counter[hit_channel_id] += 1 if len(to_inflate) % limit == 0: break else: delayed.append((hit_id, hit_channel_id)) - return list(await self.get_many(*(claim_id for claim_id, _ in to_inflate[offset:(offset + limit)]))) + return to_inflate async def resolve_url(self, raw_url): if raw_url not in self.resolution_cache: