forked from LBRYCommunity/lbry-sdk
cache search_ahead
This commit is contained in:
parent
132ee1915f
commit
fdb0e22656
1 changed files with 31 additions and 17 deletions
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import struct
|
||||
from binascii import unhexlify
|
||||
from collections import Counter
|
||||
from collections import Counter, deque
|
||||
from decimal import Decimal
|
||||
from operator import itemgetter
|
||||
from typing import Optional, List, Iterable, Union
|
||||
|
@ -236,37 +236,51 @@ class SearchIndex:
|
|||
limit = kwargs.pop('limit', 10)
|
||||
offset = kwargs.pop('offset', 0)
|
||||
kwargs['limit'] = 1000
|
||||
cache_item = ResultCacheItem.from_cache(f"ahead{per_channel_per_page}{kwargs}", self.search_cache)
|
||||
if cache_item.result is not None:
|
||||
to_inflate = cache_item.result
|
||||
else:
|
||||
async with cache_item.lock:
|
||||
if cache_item.result:
|
||||
to_inflate = cache_item.result
|
||||
else:
|
||||
query = expand_query(**kwargs)
|
||||
result = (await self.search_client.search(
|
||||
to_inflate = await self.__search_ahead(query, limit, per_channel_per_page)
|
||||
cache_item.result = to_inflate
|
||||
return list(await self.get_many(*(claim_id for claim_id, _ in to_inflate[offset:(offset + limit)])))
|
||||
|
||||
async def __search_ahead(self, query, limit, per_channel_per_page):
|
||||
result = deque((await self.search_client.search(
|
||||
query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id']
|
||||
))['hits']
|
||||
))['hits']['hits'])
|
||||
to_inflate = []
|
||||
channel_counter = Counter()
|
||||
delayed = []
|
||||
while result['hits'] or delayed:
|
||||
delayed = deque()
|
||||
while result or delayed:
|
||||
if len(to_inflate) % limit == 0:
|
||||
channel_counter.clear()
|
||||
else:
|
||||
break # means last page was incomplete and we are left with bad replacements
|
||||
new_delayed = []
|
||||
for id, chann in delayed:
|
||||
if channel_counter[chann] < per_channel_per_page:
|
||||
to_inflate.append((id, chann))
|
||||
channel_counter[hit_channel_id] += 1
|
||||
for _ in range(len(delayed)):
|
||||
claim_id, channel_id = delayed.popleft()
|
||||
if channel_counter[channel_id] < per_channel_per_page:
|
||||
to_inflate.append((claim_id, channel_id))
|
||||
channel_counter[channel_id] += 1
|
||||
else:
|
||||
new_delayed.append((id, chann))
|
||||
delayed = new_delayed
|
||||
while result['hits']:
|
||||
hit = result['hits'].pop(0)
|
||||
delayed.append((claim_id, channel_id))
|
||||
while result:
|
||||
hit = result.popleft()
|
||||
hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id']
|
||||
if channel_counter[hit_channel_id] < per_channel_per_page:
|
||||
if hit_channel_id is None:
|
||||
to_inflate.append((hit_id, hit_channel_id))
|
||||
elif channel_counter[hit_channel_id] < per_channel_per_page:
|
||||
to_inflate.append((hit_id, hit_channel_id))
|
||||
channel_counter[hit_channel_id] += 1
|
||||
if len(to_inflate) % limit == 0:
|
||||
break
|
||||
else:
|
||||
delayed.append((hit_id, hit_channel_id))
|
||||
return list(await self.get_many(*(claim_id for claim_id, _ in to_inflate[offset:(offset + limit)])))
|
||||
return to_inflate
|
||||
|
||||
async def resolve_url(self, raw_url):
|
||||
if raw_url not in self.resolution_cache:
|
||||
|
|
Loading…
Reference in a new issue