test and implementation for remove_duplicates on post-search filtering

Victor Shyba 2021-05-19 03:05:51 -03:00
parent bfc15ea029
commit ca28de02d8
2 changed files with 39 additions and 3 deletions

@@ -260,6 +260,7 @@ class SearchIndex:
     async def search_ahead(self, **kwargs):
         # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
         per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0
+        remove_duplicates = kwargs.pop('remove_duplicates', False)
         page_size = kwargs.pop('limit', 10)
         offset = kwargs.pop('offset', 0)
         kwargs['limit'] = 1000
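
The comment at the top of the hunk summarizes the strategy: over-fetch a window of up to 1000 hits, post-process it (deduplication, per-channel limits), then slice the requested page out of the reordered list and inflate only those ids. A toy sketch of the slice-and-count step, with the offset/page_size semantics assumed from the kwargs above:

    def slice_page(reordered_hits: list, offset: int = 0, page_size: int = 10):
        # the requested page, plus the total size of the post-filtered window
        return reordered_hits[offset:offset + page_size], len(reordered_hits)

    window = [f'claim{i}' for i in range(37)]  # stand-in for up to 1000 post-filtered hits
    page, total = slice_page(window, offset=10, page_size=10)
    assert page == [f'claim{i}' for i in range(10, 20)] and total == 37

slice_page is a hypothetical helper for illustration only; in the diff the same slice happens inline in the get_many(...) call, and the window size is returned as the last element of the result tuple.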
@@ -273,17 +274,36 @@
                 else:
                     query = expand_query(**kwargs)
                     search_hits = deque((await self.search_client.search(
-                        query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id']
+                        query, index=self.index, track_total_hits=False,
+                        _source_includes=['_id', 'channel_id', 'reposted_claim_id', 'creation_height']
                     ))['hits']['hits'])
+                    if remove_duplicates:
+                        search_hits = self.__remove_duplicates(search_hits)
                     if per_channel_per_page > 0:
-                        reordered_hits = await self.__search_ahead(search_hits, page_size, per_channel_per_page)
+                        reordered_hits = self.__search_ahead(search_hits, page_size, per_channel_per_page)
                     else:
                         reordered_hits = [(hit['_id'], hit['_source']['channel_id']) for hit in search_hits]
                     cache_item.result = reordered_hits
         result = list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
         return result, 0, len(reordered_hits)
 
-    async def __search_ahead(self, search_hits: list, page_size: int, per_channel_per_page: int):
+    def __remove_duplicates(self, search_hits: list):
+        known_ids = {}  # claim_id -> (creation_height, hit_id), where hit_id is either reposted claim id or original
+        dropped = set()
+        for hit in search_hits:
+            hit_height, hit_id = hit['_source']['creation_height'], hit['_source']['reposted_claim_id'] or hit['_id']
+            if hit_id not in known_ids:
+                known_ids[hit_id] = (hit_height, hit['_id'])
+            else:
+                previous_height, previous_id = known_ids[hit_id]
+                if hit_height < previous_height:
+                    known_ids[hit_id] = (hit_height, hit['_id'])
+                    dropped.add(previous_id)
+                else:
+                    dropped.add(hit['_id'])
+        return [hit for hit in search_hits if hit['_id'] not in dropped]
+
+    def __search_ahead(self, search_hits: list, page_size: int, per_channel_per_page: int):
         reordered_hits = []
         channel_counters = Counter()
         next_page_hits_maybe_check_later = deque()
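
In plain terms, __remove_duplicates collapses hits that resolve to the same underlying claim (a repost counts as its reposted_claim_id, anything else as its own _id) and keeps only the earliest-created one, i.e. the lowest creation_height, while preserving the original hit order. Since a repost can only be created after the claim it reposts, the original always survives. A minimal standalone sketch of the same rule, using hand-built hit dicts in place of real Elasticsearch results:

    from typing import Dict, List, Tuple

    def remove_duplicates(search_hits: List[dict]) -> List[dict]:
        # mirrors SearchIndex.__remove_duplicates: hits resolving to the same
        # claim collapse into the one with the lowest creation_height
        known_ids: Dict[str, Tuple[int, str]] = {}  # claim_id -> (height, hit_id)
        dropped = set()
        for hit in search_hits:
            height = hit['_source']['creation_height']
            claim_id = hit['_source']['reposted_claim_id'] or hit['_id']
            if claim_id not in known_ids:
                known_ids[claim_id] = (height, hit['_id'])
            elif height < known_ids[claim_id][0]:
                dropped.add(known_ids[claim_id][1])  # the later version loses
                known_ids[claim_id] = (height, hit['_id'])
            else:
                dropped.add(hit['_id'])
        return [hit for hit in search_hits if hit['_id'] not in dropped]

    # a repost of claim 'a' loses to the original, which was created earlier
    hits = [
        {'_id': 'repost1', '_source': {'creation_height': 120, 'reposted_claim_id': 'a'}},
        {'_id': 'a', '_source': {'creation_height': 100, 'reposted_claim_id': None}},
        {'_id': 'b', '_source': {'creation_height': 110, 'reposted_claim_id': None}},
    ]
    assert [h['_id'] for h in remove_duplicates(hits)] == ['a', 'b']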

@@ -398,6 +398,22 @@ class ClaimSearchCommand(ClaimTestCase):
             limit_claims_per_channel=3, claim_type='stream'
         )
 
+    async def test_no_duplicates(self):
+        await self.generate(10)
+        match = self.assertFindsClaims
+        claims = []
+        channels = []
+        first = await self.stream_create('original_claim0')
+        second = await self.stream_create('original_claim1')
+        for i in range(10):
+            repost_id = self.get_claim_id(second if i % 2 == 0 else first)
+            channel = await self.channel_create(f'@chan{i}', bid='0.001')
+            channels.append(channel)
+            claims.append(
+                await self.stream_repost(repost_id, f'claim{i}', bid='0.001', channel_id=self.get_claim_id(channel)))
+        await match([first, second] + channels,
+                    remove_duplicates=True, order_by=['^height'])
+
     async def test_limit_claims_per_channel_across_sorted_pages(self):
         await self.generate(10)
         match = self.assertFindsClaims
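
The test sets up two original streams and ten reposts of them (alternating between the two), each repost published from its own channel; with remove_duplicates=True and ascending height order, every repost should collapse into its original, leaving only the two originals plus the ten channel claims in the results. The same expectation, checked against the standalone remove_duplicates sketch above (channel claims omitted from the toy hits, heights chosen arbitrarily):

    hits = (
        [{'_id': f'original{i}', '_source': {'creation_height': 100 + i, 'reposted_claim_id': None}}
         for i in range(2)] +
        [{'_id': f'repost{i}', '_source': {'creation_height': 102 + i, 'reposted_claim_id': f'original{i % 2}'}}
         for i in range(10)]
    )
    assert [h['_id'] for h in remove_duplicates(hits)] == ['original0', 'original1']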