cache search_ahead

Author: Victor Shyba
Date: 2021-04-23 00:50:35 -03:00
parent 132ee1915f
commit fdb0e22656

@@ -1,7 +1,7 @@
 import asyncio
 import struct
 from binascii import unhexlify
-from collections import Counter
+from collections import Counter, deque
 from decimal import Decimal
 from operator import itemgetter
 from typing import Optional, List, Iterable, Union
@@ -236,37 +236,51 @@ class SearchIndex:
         limit = kwargs.pop('limit', 10)
         offset = kwargs.pop('offset', 0)
         kwargs['limit'] = 1000
-        query = expand_query(**kwargs)
-        result = (await self.search_client.search(
+        cache_item = ResultCacheItem.from_cache(f"ahead{per_channel_per_page}{kwargs}", self.search_cache)
+        if cache_item.result is not None:
+            to_inflate = cache_item.result
+        else:
+            async with cache_item.lock:
+                if cache_item.result:
+                    to_inflate = cache_item.result
+                else:
+                    query = expand_query(**kwargs)
+                    to_inflate = await self.__search_ahead(query, limit, per_channel_per_page)
+                    cache_item.result = to_inflate
+        return list(await self.get_many(*(claim_id for claim_id, _ in to_inflate[offset:(offset + limit)])))
+
+    async def __search_ahead(self, query, limit, per_channel_per_page):
+        result = deque((await self.search_client.search(
             query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id']
-        ))['hits']
+        ))['hits']['hits'])
         to_inflate = []
         channel_counter = Counter()
-        delayed = []
-        while result['hits'] or delayed:
+        delayed = deque()
+        while result or delayed:
             if len(to_inflate) % limit == 0:
                 channel_counter.clear()
             else:
                 break  # means last page was incomplete and we are left with bad replacements
-            new_delayed = []
-            for id, chann in delayed:
-                if channel_counter[chann] < per_channel_per_page:
-                    to_inflate.append((id, chann))
-                    channel_counter[hit_channel_id] += 1
-                else:
-                    new_delayed.append((id, chann))
-            delayed = new_delayed
-            while result['hits']:
-                hit = result['hits'].pop(0)
+            for _ in range(len(delayed)):
+                claim_id, channel_id = delayed.popleft()
+                if channel_counter[channel_id] < per_channel_per_page:
+                    to_inflate.append((claim_id, channel_id))
+                    channel_counter[channel_id] += 1
+                else:
+                    delayed.append((claim_id, channel_id))
+            while result:
+                hit = result.popleft()
                 hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id']
-                if channel_counter[hit_channel_id] < per_channel_per_page:
+                if hit_channel_id is None:
+                    to_inflate.append((hit_id, hit_channel_id))
+                elif channel_counter[hit_channel_id] < per_channel_per_page:
                     to_inflate.append((hit_id, hit_channel_id))
                     channel_counter[hit_channel_id] += 1
                     if len(to_inflate) % limit == 0:
                         break
                 else:
                     delayed.append((hit_id, hit_channel_id))
-        return list(await self.get_many(*(claim_id for claim_id, _ in to_inflate[offset:(offset + limit)])))
+        return to_inflate
 
     async def resolve_url(self, raw_url):
         if raw_url not in self.resolution_cache:
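
For context: the new search_ahead path above depends on a ResultCacheItem helper and a self.search_cache mapping, neither of which is shown in this diff. A minimal sketch of what such a helper could look like, assuming one asyncio.Lock per cache entry and a plain dict (or an LRU mapping) keyed by the formatted query string:

# Hypothetical sketch of the cache slot used above; the real class is
# not part of this commit.
import asyncio

class ResultCacheItem:
    """One cache slot: stores a computed result plus a lock so that,
    on a miss, only one coroutine runs the query while the rest wait."""
    __slots__ = '_result', 'lock', 'has_result'

    def __init__(self):
        self.has_result = asyncio.Event()
        self.lock = asyncio.Lock()
        self._result = None

    @property
    def result(self):
        return self._result

    @result.setter
    def result(self, result):
        self._result = result
        if result is not None:
            self.has_result.set()

    @classmethod
    def from_cache(cls, cache_key, cache):
        # Fetch or create the slot for this key; creating an empty slot
        # is cheap, so no lock is needed for the lookup itself.
        cache_item = cache.get(cache_key)
        if cache_item is None:
            cache_item = cache[cache_key] = cls()
        return cache_item

The double check in search_ahead, once without the lock and again inside `async with cache_item.lock`, is the usual check-lock-check pattern: cache hits cost a single dict lookup, while concurrent misses for the same key serialize on the lock so the Elasticsearch round trip runs only once.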
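
The page-building loop that moved into __search_ahead is easier to follow in isolation. Below is a standalone sketch of the same algorithm (hypothetical names, plain tuples instead of Elasticsearch hits): it builds pages of `limit` results with at most `per_channel_per_page` entries per channel, parking overflow hits on a `delayed` deque so they can lead the next page.

# Standalone sketch of the __search_ahead page-building loop.
from collections import Counter, deque

def search_ahead(hits, limit, per_channel_per_page):
    # hits: ordered (claim_id, channel_id) pairs; channel_id may be None.
    result = deque(hits)
    to_inflate = []
    channel_counter = Counter()
    delayed = deque()
    while result or delayed:
        if len(to_inflate) % limit == 0:
            channel_counter.clear()  # a new page starts: reset per-page counts
        else:
            break  # last page is incomplete and only capped hits remain
        # Retry hits pushed off earlier pages, oldest first.
        for _ in range(len(delayed)):
            claim_id, channel_id = delayed.popleft()
            if channel_counter[channel_id] < per_channel_per_page:
                to_inflate.append((claim_id, channel_id))
                channel_counter[channel_id] += 1
            else:
                delayed.append((claim_id, channel_id))
        # Then consume fresh hits until the current page fills up.
        while result:
            claim_id, channel_id = result.popleft()
            if channel_id is None:  # claims outside any channel are never capped
                to_inflate.append((claim_id, channel_id))
            elif channel_counter[channel_id] < per_channel_per_page:
                to_inflate.append((claim_id, channel_id))
                channel_counter[channel_id] += 1
                if len(to_inflate) % limit == 0:
                    break  # page full; restart the outer loop for the next page
            else:
                delayed.append((claim_id, channel_id))
    return to_inflate

hits = [('c0', 'A'), ('c1', 'A'), ('c2', 'A'), ('c3', 'A'),
        ('c4', 'B'), ('c5', 'B'), ('c6', 'B'), ('c7', 'B')]
print(search_ahead(hits, limit=4, per_channel_per_page=2))
# page 1: c0 c1 c4 c5 / page 2: c2 c3 c6 c7 -- never more than 2 per channel per page

Two details worth noting from the diff itself: switching `result` and `delayed` from lists to deques replaces the O(n) `list.pop(0)` with O(1) `popleft()`, and the rewrite fixes the old delayed loop, which incremented `channel_counter[hit_channel_id]` instead of the delayed hit's own channel. If the trailing page cannot be completed, the loop breaks and any remaining delayed hits are dropped, matching the "bad replacements" comment.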