cache search_ahead

This commit is contained in:
Victor Shyba 2021-04-23 00:50:35 -03:00
parent 132ee1915f
commit fdb0e22656

View file

@ -1,7 +1,7 @@
import asyncio
import struct
from binascii import unhexlify
from collections import Counter
from collections import Counter, deque
from decimal import Decimal
from operator import itemgetter
from typing import Optional, List, Iterable, Union
@ -236,37 +236,51 @@ class SearchIndex:
limit = kwargs.pop('limit', 10)
offset = kwargs.pop('offset', 0)
kwargs['limit'] = 1000
cache_item = ResultCacheItem.from_cache(f"ahead{per_channel_per_page}{kwargs}", self.search_cache)
if cache_item.result is not None:
to_inflate = cache_item.result
else:
async with cache_item.lock:
if cache_item.result:
to_inflate = cache_item.result
else:
query = expand_query(**kwargs)
result = (await self.search_client.search(
to_inflate = await self.__search_ahead(query, limit, per_channel_per_page)
cache_item.result = to_inflate
return list(await self.get_many(*(claim_id for claim_id, _ in to_inflate[offset:(offset + limit)])))
async def __search_ahead(self, query, limit, per_channel_per_page):
result = deque((await self.search_client.search(
query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id']
))['hits']
))['hits']['hits'])
to_inflate = []
channel_counter = Counter()
delayed = []
while result['hits'] or delayed:
delayed = deque()
while result or delayed:
if len(to_inflate) % limit == 0:
channel_counter.clear()
else:
break # means last page was incomplete and we are left with bad replacements
new_delayed = []
for id, chann in delayed:
if channel_counter[chann] < per_channel_per_page:
to_inflate.append((id, chann))
channel_counter[hit_channel_id] += 1
for _ in range(len(delayed)):
claim_id, channel_id = delayed.popleft()
if channel_counter[channel_id] < per_channel_per_page:
to_inflate.append((claim_id, channel_id))
channel_counter[channel_id] += 1
else:
new_delayed.append((id, chann))
delayed = new_delayed
while result['hits']:
hit = result['hits'].pop(0)
delayed.append((claim_id, channel_id))
while result:
hit = result.popleft()
hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id']
if channel_counter[hit_channel_id] < per_channel_per_page:
if hit_channel_id is None:
to_inflate.append((hit_id, hit_channel_id))
elif channel_counter[hit_channel_id] < per_channel_per_page:
to_inflate.append((hit_id, hit_channel_id))
channel_counter[hit_channel_id] += 1
if len(to_inflate) % limit == 0:
break
else:
delayed.append((hit_id, hit_channel_id))
return list(await self.get_many(*(claim_id for claim_id, _ in to_inflate[offset:(offset + limit)])))
return to_inflate
async def resolve_url(self, raw_url):
if raw_url not in self.resolution_cache: