forked from LBRYCommunity/lbry-sdk
test and implementation for remove_duplicates on post-search filtering
This commit is contained in:
parent
bfc15ea029
commit
ca28de02d8
2 changed files with 39 additions and 3 deletions
|
@ -260,6 +260,7 @@ class SearchIndex:
|
|||
async def search_ahead(self, **kwargs):
|
||||
# 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
|
||||
per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0
|
||||
remove_duplicates = kwargs.pop('remove_duplicates', False)
|
||||
page_size = kwargs.pop('limit', 10)
|
||||
offset = kwargs.pop('offset', 0)
|
||||
kwargs['limit'] = 1000
|
||||
|
@ -273,17 +274,36 @@ class SearchIndex:
|
|||
else:
|
||||
query = expand_query(**kwargs)
|
||||
search_hits = deque((await self.search_client.search(
|
||||
query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id']
|
||||
query, index=self.index, track_total_hits=False,
|
||||
_source_includes=['_id', 'channel_id', 'reposted_claim_id', 'creation_height']
|
||||
))['hits']['hits'])
|
||||
if remove_duplicates:
|
||||
search_hits = self.__remove_duplicates(search_hits)
|
||||
if per_channel_per_page > 0:
|
||||
reordered_hits = await self.__search_ahead(search_hits, page_size, per_channel_per_page)
|
||||
reordered_hits = self.__search_ahead(search_hits, page_size, per_channel_per_page)
|
||||
else:
|
||||
reordered_hits = [(hit['_id'], hit['_source']['channel_id']) for hit in search_hits]
|
||||
cache_item.result = reordered_hits
|
||||
result = list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
|
||||
return result, 0, len(reordered_hits)
|
||||
|
||||
async def __search_ahead(self, search_hits: list, page_size: int, per_channel_per_page: int):
|
||||
def __remove_duplicates(self, search_hits: list):
|
||||
known_ids = {} # claim_id -> (creation_height, hit_id), where hit_id is either reposted claim id or original
|
||||
dropped = set()
|
||||
for hit in search_hits:
|
||||
hit_height, hit_id = hit['_source']['creation_height'], hit['_source']['reposted_claim_id'] or hit['_id']
|
||||
if hit_id not in known_ids:
|
||||
known_ids[hit_id] = (hit_height, hit['_id'])
|
||||
else:
|
||||
previous_height, previous_id = known_ids[hit_id]
|
||||
if hit_height < previous_height:
|
||||
known_ids[hit_id] = (hit_height, hit['_id'])
|
||||
dropped.add(previous_id)
|
||||
else:
|
||||
dropped.add(hit['_id'])
|
||||
return [hit for hit in search_hits if hit['_id'] not in dropped]
|
||||
|
||||
def __search_ahead(self, search_hits: list, page_size: int, per_channel_per_page: int):
|
||||
reordered_hits = []
|
||||
channel_counters = Counter()
|
||||
next_page_hits_maybe_check_later = deque()
|
||||
|
|
|
@ -398,6 +398,22 @@ class ClaimSearchCommand(ClaimTestCase):
|
|||
limit_claims_per_channel=3, claim_type='stream'
|
||||
)
|
||||
|
||||
async def test_no_duplicates(self):
|
||||
await self.generate(10)
|
||||
match = self.assertFindsClaims
|
||||
claims = []
|
||||
channels = []
|
||||
first = await self.stream_create('original_claim0')
|
||||
second = await self.stream_create('original_claim1')
|
||||
for i in range(10):
|
||||
repost_id = self.get_claim_id(second if i % 2 == 0 else first)
|
||||
channel = await self.channel_create(f'@chan{i}', bid='0.001')
|
||||
channels.append(channel)
|
||||
claims.append(
|
||||
await self.stream_repost(repost_id, f'claim{i}', bid='0.001', channel_id=self.get_claim_id(channel)))
|
||||
await match([first, second] + channels,
|
||||
remove_duplicates=True, order_by=['^height'])
|
||||
|
||||
async def test_limit_claims_per_channel_across_sorted_pages(self):
|
||||
await self.generate(10)
|
||||
match = self.assertFindsClaims
|
||||
|
|
Loading…
Add table
Reference in a new issue