diff --git a/lbry/wallet/server/db/elasticsearch/search.py b/lbry/wallet/server/db/elasticsearch/search.py
index 730e4f033..be1047184 100644
--- a/lbry/wallet/server/db/elasticsearch/search.py
+++ b/lbry/wallet/server/db/elasticsearch/search.py
@@ -259,7 +259,7 @@ class SearchIndex:
 
     async def search_ahead(self, **kwargs):
         # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
-        per_channel_per_page = kwargs.pop('limit_claims_per_channel')
+        per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0
         page_size = kwargs.pop('limit', 10)
         offset = kwargs.pop('offset', 0)
         kwargs['limit'] = 1000
@@ -272,15 +272,18 @@ class SearchIndex:
             reordered_hits = cache_item.result
         else:
             query = expand_query(**kwargs)
-            reordered_hits = await self.__search_ahead(query, page_size, per_channel_per_page)
+            search_hits = deque((await self.search_client.search(
+                query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id']
+            ))['hits']['hits'])
+            if per_channel_per_page > 0:
+                reordered_hits = await self.__search_ahead(search_hits, page_size, per_channel_per_page)
+            else:
+                reordered_hits = [(hit['_id'], hit['_source']['channel_id']) for hit in search_hits]
             cache_item.result = reordered_hits
         result = list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
         return result, 0, len(reordered_hits)
 
-    async def __search_ahead(self, query: dict, page_size: int, per_channel_per_page: int):
-        search_hits = deque((await self.search_client.search(
-            query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id']
-        ))['hits']['hits'])
+    async def __search_ahead(self, search_hits: list, page_size: int, per_channel_per_page: int):
         reordered_hits = []
         channel_counters = Counter()
         next_page_hits_maybe_check_later = deque()
@@ -293,7 +296,7 @@ class SearchIndex:
                 break  # means last page was incomplete and we are left with bad replacements
             for _ in range(len(next_page_hits_maybe_check_later)):
                 claim_id, channel_id = next_page_hits_maybe_check_later.popleft()
-                if channel_counters[channel_id] < per_channel_per_page:
+                if per_channel_per_page > 0 and channel_counters[channel_id] < per_channel_per_page:
                     reordered_hits.append((claim_id, channel_id))
                     channel_counters[channel_id] += 1
                 else:
@@ -301,7 +304,7 @@ class SearchIndex:
         while search_hits:
             hit = search_hits.popleft()
             hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id']
-            if hit_channel_id is None:
+            if hit_channel_id is None or per_channel_per_page <= 0:
                 reordered_hits.append((hit_id, hit_channel_id))
             elif channel_counters[hit_channel_id] < per_channel_per_page:
                 reordered_hits.append((hit_id, hit_channel_id))
diff --git a/tests/integration/blockchain/test_claim_commands.py b/tests/integration/blockchain/test_claim_commands.py
index a644b394d..45607ebe0 100644
--- a/tests/integration/blockchain/test_claim_commands.py
+++ b/tests/integration/blockchain/test_claim_commands.py
@@ -429,6 +429,12 @@ class ClaimSearchCommand(ClaimTestCase):
             [claims[6], claims[7], last], page_size=4, page=3, limit_claims_per_channel=1, claim_type='stream', order_by=['^height']
         )
+        # feature disabled on 0 or negative values
+        for limit in [None, 0, -1]:
+            await match(
+                [first, second] + claims + [last],
+                limit_claims_per_channel=limit, claim_type='stream', order_by=['^height']
+            )
 
     async def test_claim_type_and_media_type_search(self):
         # create an invalid/unknown claim