diff --git a/lbry/wallet/server/db/elasticsearch/search.py b/lbry/wallet/server/db/elasticsearch/search.py index be1047184..c562ec684 100644 --- a/lbry/wallet/server/db/elasticsearch/search.py +++ b/lbry/wallet/server/db/elasticsearch/search.py @@ -260,6 +260,7 @@ class SearchIndex: async def search_ahead(self, **kwargs): # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0 + remove_duplicates = kwargs.pop('remove_duplicates', False) page_size = kwargs.pop('limit', 10) offset = kwargs.pop('offset', 0) kwargs['limit'] = 1000 @@ -273,17 +274,36 @@ class SearchIndex: else: query = expand_query(**kwargs) search_hits = deque((await self.search_client.search( - query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id'] + query, index=self.index, track_total_hits=False, + _source_includes=['_id', 'channel_id', 'reposted_claim_id', 'creation_height'] ))['hits']['hits']) + if remove_duplicates: + search_hits = self.__remove_duplicates(search_hits) if per_channel_per_page > 0: - reordered_hits = await self.__search_ahead(search_hits, page_size, per_channel_per_page) + reordered_hits = self.__search_ahead(search_hits, page_size, per_channel_per_page) else: reordered_hits = [(hit['_id'], hit['_source']['channel_id']) for hit in search_hits] cache_item.result = reordered_hits result = list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)]))) return result, 0, len(reordered_hits) - async def __search_ahead(self, search_hits: list, page_size: int, per_channel_per_page: int): + def __remove_duplicates(self, search_hits: list): + known_ids = {} # claim_id -> (creation_height, hit_id), where hit_id is either reposted claim id or original + dropped = set() + for hit in search_hits: + hit_height, hit_id = hit['_source']['creation_height'], hit['_source']['reposted_claim_id'] or hit['_id'] + if hit_id not in known_ids: + known_ids[hit_id] = (hit_height, hit['_id']) + else: + previous_height, previous_id = known_ids[hit_id] + if hit_height < previous_height: + known_ids[hit_id] = (hit_height, hit['_id']) + dropped.add(previous_id) + else: + dropped.add(hit['_id']) + return [hit for hit in search_hits if hit['_id'] not in dropped] + + def __search_ahead(self, search_hits: list, page_size: int, per_channel_per_page: int): reordered_hits = [] channel_counters = Counter() next_page_hits_maybe_check_later = deque() diff --git a/tests/integration/blockchain/test_claim_commands.py b/tests/integration/blockchain/test_claim_commands.py index 45607ebe0..22e510c1e 100644 --- a/tests/integration/blockchain/test_claim_commands.py +++ b/tests/integration/blockchain/test_claim_commands.py @@ -398,6 +398,22 @@ class ClaimSearchCommand(ClaimTestCase): limit_claims_per_channel=3, claim_type='stream' ) + async def test_no_duplicates(self): + await self.generate(10) + match = self.assertFindsClaims + claims = [] + channels = [] + first = await self.stream_create('original_claim0') + second = await self.stream_create('original_claim1') + for i in range(10): + repost_id = self.get_claim_id(second if i % 2 == 0 else first) + channel = await self.channel_create(f'@chan{i}', bid='0.001') + channels.append(channel) + claims.append( + await self.stream_repost(repost_id, f'claim{i}', bid='0.001', channel_id=self.get_claim_id(channel))) + await match([first, second] + channels, + remove_duplicates=True, order_by=['^height']) + async def test_limit_claims_per_channel_across_sorted_pages(self): await self.generate(10) match = self.assertFindsClaims