test and implementation for remove_duplicates on post-search filtering
This commit is contained in:
parent
bfc15ea029
commit
ca28de02d8
2 changed files with 39 additions and 3 deletions
|
@ -260,6 +260,7 @@ class SearchIndex:
|
||||||
async def search_ahead(self, **kwargs):
|
async def search_ahead(self, **kwargs):
|
||||||
# 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
|
# 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
|
||||||
per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0
|
per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0
|
||||||
|
remove_duplicates = kwargs.pop('remove_duplicates', False)
|
||||||
page_size = kwargs.pop('limit', 10)
|
page_size = kwargs.pop('limit', 10)
|
||||||
offset = kwargs.pop('offset', 0)
|
offset = kwargs.pop('offset', 0)
|
||||||
kwargs['limit'] = 1000
|
kwargs['limit'] = 1000
|
||||||
|
@ -273,17 +274,36 @@ class SearchIndex:
|
||||||
else:
|
else:
|
||||||
query = expand_query(**kwargs)
|
query = expand_query(**kwargs)
|
||||||
search_hits = deque((await self.search_client.search(
|
search_hits = deque((await self.search_client.search(
|
||||||
query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id']
|
query, index=self.index, track_total_hits=False,
|
||||||
|
_source_includes=['_id', 'channel_id', 'reposted_claim_id', 'creation_height']
|
||||||
))['hits']['hits'])
|
))['hits']['hits'])
|
||||||
|
if remove_duplicates:
|
||||||
|
search_hits = self.__remove_duplicates(search_hits)
|
||||||
if per_channel_per_page > 0:
|
if per_channel_per_page > 0:
|
||||||
reordered_hits = await self.__search_ahead(search_hits, page_size, per_channel_per_page)
|
reordered_hits = self.__search_ahead(search_hits, page_size, per_channel_per_page)
|
||||||
else:
|
else:
|
||||||
reordered_hits = [(hit['_id'], hit['_source']['channel_id']) for hit in search_hits]
|
reordered_hits = [(hit['_id'], hit['_source']['channel_id']) for hit in search_hits]
|
||||||
cache_item.result = reordered_hits
|
cache_item.result = reordered_hits
|
||||||
result = list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
|
result = list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
|
||||||
return result, 0, len(reordered_hits)
|
return result, 0, len(reordered_hits)
|
||||||
|
|
||||||
async def __search_ahead(self, search_hits: list, page_size: int, per_channel_per_page: int):
|
def __remove_duplicates(self, search_hits: list):
|
||||||
|
known_ids = {} # claim_id -> (creation_height, hit_id), where hit_id is either reposted claim id or original
|
||||||
|
dropped = set()
|
||||||
|
for hit in search_hits:
|
||||||
|
hit_height, hit_id = hit['_source']['creation_height'], hit['_source']['reposted_claim_id'] or hit['_id']
|
||||||
|
if hit_id not in known_ids:
|
||||||
|
known_ids[hit_id] = (hit_height, hit['_id'])
|
||||||
|
else:
|
||||||
|
previous_height, previous_id = known_ids[hit_id]
|
||||||
|
if hit_height < previous_height:
|
||||||
|
known_ids[hit_id] = (hit_height, hit['_id'])
|
||||||
|
dropped.add(previous_id)
|
||||||
|
else:
|
||||||
|
dropped.add(hit['_id'])
|
||||||
|
return [hit for hit in search_hits if hit['_id'] not in dropped]
|
||||||
|
|
||||||
|
def __search_ahead(self, search_hits: list, page_size: int, per_channel_per_page: int):
|
||||||
reordered_hits = []
|
reordered_hits = []
|
||||||
channel_counters = Counter()
|
channel_counters = Counter()
|
||||||
next_page_hits_maybe_check_later = deque()
|
next_page_hits_maybe_check_later = deque()
|
||||||
|
|
|
@ -398,6 +398,22 @@ class ClaimSearchCommand(ClaimTestCase):
|
||||||
limit_claims_per_channel=3, claim_type='stream'
|
limit_claims_per_channel=3, claim_type='stream'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def test_no_duplicates(self):
|
||||||
|
await self.generate(10)
|
||||||
|
match = self.assertFindsClaims
|
||||||
|
claims = []
|
||||||
|
channels = []
|
||||||
|
first = await self.stream_create('original_claim0')
|
||||||
|
second = await self.stream_create('original_claim1')
|
||||||
|
for i in range(10):
|
||||||
|
repost_id = self.get_claim_id(second if i % 2 == 0 else first)
|
||||||
|
channel = await self.channel_create(f'@chan{i}', bid='0.001')
|
||||||
|
channels.append(channel)
|
||||||
|
claims.append(
|
||||||
|
await self.stream_repost(repost_id, f'claim{i}', bid='0.001', channel_id=self.get_claim_id(channel)))
|
||||||
|
await match([first, second] + channels,
|
||||||
|
remove_duplicates=True, order_by=['^height'])
|
||||||
|
|
||||||
async def test_limit_claims_per_channel_across_sorted_pages(self):
|
async def test_limit_claims_per_channel_across_sorted_pages(self):
|
||||||
await self.generate(10)
|
await self.generate(10)
|
||||||
match = self.assertFindsClaims
|
match = self.assertFindsClaims
|
||||||
|
|
Loading…
Add table
Reference in a new issue