diff --git a/lbry/wallet/server/db/elasticsearch/search.py b/lbry/wallet/server/db/elasticsearch/search.py index fbbf2cb48..98cc467f2 100644 --- a/lbry/wallet/server/db/elasticsearch/search.py +++ b/lbry/wallet/server/db/elasticsearch/search.py @@ -1,6 +1,7 @@ import asyncio import struct from binascii import unhexlify +from collections import Counter from decimal import Decimal from operator import itemgetter from typing import Optional, List, Iterable, Union @@ -218,13 +219,55 @@ class SearchIndex: if not kwargs['channel_id'] or not isinstance(kwargs['channel_id'], str): return [], 0, 0 try: - result = (await self.search_client.search( - expand_query(**kwargs), index=self.index, track_total_hits=False if kwargs.get('no_totals') else 10_000 - ))['hits'] + if 'limit_claims_per_channel' in kwargs: + return await self.search_ahead(**kwargs), 0, 0 + else: + result = (await self.search_client.search( + expand_query(**kwargs), index=self.index, + track_total_hits=False if kwargs.get('no_totals') else 10_000 + ))['hits'] except NotFoundError: return [], 0, 0 return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0) + async def search_ahead(self, **kwargs): + # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return + per_channel_per_page = kwargs.pop('limit_claims_per_channel') + limit = kwargs.pop('limit', 10) + offset = kwargs.pop('offset', 0) + kwargs['limit'] = 1000 + query = expand_query(**kwargs) + result = (await self.search_client.search( + query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id'] + ))['hits'] + to_inflate = [] + channel_counter = Counter() + delayed = [] + while result['hits'] or delayed: + if len(to_inflate) % limit == 0: + channel_counter.clear() + else: + break # means last page was incomplete and we are left with bad replacements + new_delayed = [] + for id, chann in delayed: + if channel_counter[chann] < per_channel_per_page: + to_inflate.append((id, chann)) + channel_counter[hit_channel_id] += 1 + else: + new_delayed.append((id, chann)) + delayed = new_delayed + while result['hits']: + hit = result['hits'].pop(0) + hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id'] + if channel_counter[hit_channel_id] < per_channel_per_page: + to_inflate.append((hit_id, hit_channel_id)) + channel_counter[hit_channel_id] += 1 + if len(to_inflate) % limit == 0: + break + else: + delayed.append((hit_id, hit_channel_id)) + return list(await self.get_many(*(claim_id for claim_id, _ in to_inflate[offset:(offset + limit)]))) + async def resolve_url(self, raw_url): if raw_url not in self.resolution_cache: self.resolution_cache[raw_url] = await self._resolve_url(raw_url) diff --git a/tests/integration/blockchain/test_claim_commands.py b/tests/integration/blockchain/test_claim_commands.py index 975b20ce7..22af4190a 100644 --- a/tests/integration/blockchain/test_claim_commands.py +++ b/tests/integration/blockchain/test_claim_commands.py @@ -75,7 +75,7 @@ class ClaimSearchCommand(ClaimTestCase): self.assertEqual( (claim['txid'], self.get_claim_id(claim)), (result['txid'], result['claim_id']), - f"{claim['outputs'][0]['name']} != {result['name']}" + f"(expected {claim['outputs'][0]['name']}) != (got {result['name']})" ) @skip("doesnt happen on ES...?") @@ -383,6 +383,38 @@ class ClaimSearchCommand(ClaimTestCase): limit_claims_per_channel=3, claim_type='stream' ) + async def test_limit_claims_per_channel_across_sorted_pages(self): + await self.generate(10) + match = self.assertFindsClaims + channel_id = self.get_claim_id(await self.channel_create('@chan0')) + claims = [] + first = await self.stream_create('claim0', channel_id=channel_id) + second = await self.stream_create('claim1', channel_id=channel_id) + for i in range(2, 10): + some_chan = self.get_claim_id(await self.channel_create(f'@chan{i}', bid='0.001')) + claims.append(await self.stream_create(f'claim{i}', bid='0.001', channel_id=some_chan)) + last = await self.stream_create('claim10', channel_id=channel_id) + + await match( + [first, second, claims[0], claims[1]], page_size=4, + limit_claims_per_channel=3, claim_type='stream', order_by=['^height'] + ) + # second goes out + await match( + [first, claims[0], claims[1], claims[2]], page_size=4, + limit_claims_per_channel=1, claim_type='stream', order_by=['^height'] + ) + # second appears, from replacement queue + await match( + [second, claims[3], claims[4], claims[5]], page_size=4, page=2, + limit_claims_per_channel=1, claim_type='stream', order_by=['^height'] + ) + # last is unaffected, as the limit applies per page + await match( + [claims[6], claims[7], last], page_size=4, page=3, + limit_claims_per_channel=1, claim_type='stream', order_by=['^height'] + ) + async def test_claim_type_and_media_type_search(self): # create an invalid/unknown claim address = await self.account.receiving.get_or_create_usable_address()