handle limit being 0 and skip reordering if 0/none

This commit is contained in:
Victor Shyba 2021-05-19 02:35:11 -03:00
parent 6e8b8a5920
commit bfc15ea029
2 changed files with 17 additions and 8 deletions

View file

@ -259,7 +259,7 @@ class SearchIndex:
async def search_ahead(self, **kwargs): async def search_ahead(self, **kwargs):
# 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
per_channel_per_page = kwargs.pop('limit_claims_per_channel') per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0
page_size = kwargs.pop('limit', 10) page_size = kwargs.pop('limit', 10)
offset = kwargs.pop('offset', 0) offset = kwargs.pop('offset', 0)
kwargs['limit'] = 1000 kwargs['limit'] = 1000
@ -272,15 +272,18 @@ class SearchIndex:
reordered_hits = cache_item.result reordered_hits = cache_item.result
else: else:
query = expand_query(**kwargs) query = expand_query(**kwargs)
reordered_hits = await self.__search_ahead(query, page_size, per_channel_per_page) search_hits = deque((await self.search_client.search(
query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id']
))['hits']['hits'])
if per_channel_per_page > 0:
reordered_hits = await self.__search_ahead(search_hits, page_size, per_channel_per_page)
else:
reordered_hits = [(hit['_id'], hit['_source']['channel_id']) for hit in search_hits]
cache_item.result = reordered_hits cache_item.result = reordered_hits
result = list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)]))) result = list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
return result, 0, len(reordered_hits) return result, 0, len(reordered_hits)
async def __search_ahead(self, query: dict, page_size: int, per_channel_per_page: int): async def __search_ahead(self, search_hits: list, page_size: int, per_channel_per_page: int):
search_hits = deque((await self.search_client.search(
query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id']
))['hits']['hits'])
reordered_hits = [] reordered_hits = []
channel_counters = Counter() channel_counters = Counter()
next_page_hits_maybe_check_later = deque() next_page_hits_maybe_check_later = deque()
@ -293,7 +296,7 @@ class SearchIndex:
break # means last page was incomplete and we are left with bad replacements break # means last page was incomplete and we are left with bad replacements
for _ in range(len(next_page_hits_maybe_check_later)): for _ in range(len(next_page_hits_maybe_check_later)):
claim_id, channel_id = next_page_hits_maybe_check_later.popleft() claim_id, channel_id = next_page_hits_maybe_check_later.popleft()
if channel_counters[channel_id] < per_channel_per_page: if per_channel_per_page > 0 and channel_counters[channel_id] < per_channel_per_page:
reordered_hits.append((claim_id, channel_id)) reordered_hits.append((claim_id, channel_id))
channel_counters[channel_id] += 1 channel_counters[channel_id] += 1
else: else:
@ -301,7 +304,7 @@ class SearchIndex:
while search_hits: while search_hits:
hit = search_hits.popleft() hit = search_hits.popleft()
hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id'] hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id']
if hit_channel_id is None: if hit_channel_id is None or per_channel_per_page <= 0:
reordered_hits.append((hit_id, hit_channel_id)) reordered_hits.append((hit_id, hit_channel_id))
elif channel_counters[hit_channel_id] < per_channel_per_page: elif channel_counters[hit_channel_id] < per_channel_per_page:
reordered_hits.append((hit_id, hit_channel_id)) reordered_hits.append((hit_id, hit_channel_id))

View file

@ -429,6 +429,12 @@ class ClaimSearchCommand(ClaimTestCase):
[claims[6], claims[7], last], page_size=4, page=3, [claims[6], claims[7], last], page_size=4, page=3,
limit_claims_per_channel=1, claim_type='stream', order_by=['^height'] limit_claims_per_channel=1, claim_type='stream', order_by=['^height']
) )
# feature disabled on 0 or negative values
for limit in [None, 0, -1]:
await match(
[first, second] + claims + [last],
limit_claims_per_channel=limit, claim_type='stream', order_by=['^height']
)
async def test_claim_type_and_media_type_search(self): async def test_claim_type_and_media_type_search(self):
# create an invalid/unknown claim # create an invalid/unknown claim