new implementation for limit_claims_per_channel

This commit is contained in:
Victor Shyba 2021-04-14 12:16:49 -03:00 committed by Lex Berezhny
parent 467637a9eb
commit cc2852cd48
2 changed files with 79 additions and 4 deletions

View file

@@ -1,6 +1,7 @@
import asyncio import asyncio
import struct import struct
from binascii import unhexlify from binascii import unhexlify
from collections import Counter
from decimal import Decimal from decimal import Decimal
from operator import itemgetter from operator import itemgetter
from typing import Optional, List, Iterable, Union from typing import Optional, List, Iterable, Union
@@ -218,13 +219,55 @@ class SearchIndex:
if not kwargs['channel_id'] or not isinstance(kwargs['channel_id'], str): if not kwargs['channel_id'] or not isinstance(kwargs['channel_id'], str):
return [], 0, 0 return [], 0, 0
try: try:
if 'limit_claims_per_channel' in kwargs:
return await self.search_ahead(**kwargs), 0, 0
else:
result = (await self.search_client.search( result = (await self.search_client.search(
expand_query(**kwargs), index=self.index, track_total_hits=False if kwargs.get('no_totals') else 10_000 expand_query(**kwargs), index=self.index,
track_total_hits=False if kwargs.get('no_totals') else 10_000
))['hits'] ))['hits']
except NotFoundError: except NotFoundError:
return [], 0, 0 return [], 0, 0
return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0) return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0)
async def search_ahead(self, **kwargs):
    """Search with a per-page cap on claims from any single channel.

    Handles the 'limit_claims_per_channel' case: fetch up to 1000 hit ids
    (id + channel_id only), reorder them so no channel exceeds the cap on
    any page, slice the requested page, inflate via ``self.get_many`` and
    return the claims.

    Kwargs consumed here (the rest are passed to ``expand_query``):
        limit_claims_per_channel: max claims per channel per page (required).
        limit: page size (default 10).
        offset: page offset (default 0).
    """
    per_channel_per_page = kwargs.pop('limit_claims_per_channel')
    limit = kwargs.pop('limit', 10)
    offset = kwargs.pop('offset', 0)
    # Look ahead well past the requested page so over-limit claims can be
    # pushed onto later pages instead of being dropped.
    kwargs['limit'] = 1000
    query = expand_query(**kwargs)
    result = (await self.search_client.search(
        query, index=self.index, track_total_hits=False,
        _source_includes=['_id', 'channel_id']
    ))['hits']
    to_inflate = []           # (claim_id, channel_id) in final page order
    channel_counter = Counter()  # per-page count of claims per channel
    delayed = []              # claims deferred past their natural page
    while result['hits'] or delayed:
        if len(to_inflate) % limit == 0:
            # Page boundary: the cap applies per page, so reset the counts.
            channel_counter.clear()
        else:
            break  # last page was incomplete and we are left with bad replacements
        # First replay previously deferred claims that now fit on this page.
        new_delayed = []
        for claim_id, channel in delayed:
            if channel_counter[channel] < per_channel_per_page:
                to_inflate.append((claim_id, channel))
                # FIX: was channel_counter[hit_channel_id] += 1, which charged
                # a stale channel from the previous inner loop and let the
                # replayed channel exceed its per-page cap.
                channel_counter[channel] += 1
            else:
                new_delayed.append((claim_id, channel))
        delayed = new_delayed
        # Then consume fresh hits until the page fills up.
        while result['hits']:
            hit = result['hits'].pop(0)
            hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id']
            if channel_counter[hit_channel_id] < per_channel_per_page:
                to_inflate.append((hit_id, hit_channel_id))
                channel_counter[hit_channel_id] += 1
                if len(to_inflate) % limit == 0:
                    break  # page complete
            else:
                delayed.append((hit_id, hit_channel_id))
    return list(await self.get_many(*(claim_id for claim_id, _ in to_inflate[offset:(offset + limit)])))
async def resolve_url(self, raw_url): async def resolve_url(self, raw_url):
if raw_url not in self.resolution_cache: if raw_url not in self.resolution_cache:
self.resolution_cache[raw_url] = await self._resolve_url(raw_url) self.resolution_cache[raw_url] = await self._resolve_url(raw_url)

View file

@@ -75,7 +75,7 @@ class ClaimSearchCommand(ClaimTestCase):
self.assertEqual( self.assertEqual(
(claim['txid'], self.get_claim_id(claim)), (claim['txid'], self.get_claim_id(claim)),
(result['txid'], result['claim_id']), (result['txid'], result['claim_id']),
f"{claim['outputs'][0]['name']} != {result['name']}" f"(expected {claim['outputs'][0]['name']}) != (got {result['name']})"
) )
@skip("doesnt happen on ES...?") @skip("doesnt happen on ES...?")
@@ -383,6 +383,38 @@ class ClaimSearchCommand(ClaimTestCase):
limit_claims_per_channel=3, claim_type='stream' limit_claims_per_channel=3, claim_type='stream'
) )
async def test_limit_claims_per_channel_across_sorted_pages(self):
    """limit_claims_per_channel applies per page when results span pages.

    Fixture ordering matters: claims are created sequentially, so with
    order_by=['^height'] the result order is
    first, second, claims[0..7], last.
    """
    await self.generate(10)
    match = self.assertFindsClaims
    channel_id = self.get_claim_id(await self.channel_create('@chan0'))
    claims = []
    # Two early claims in @chan0, then eight claims in distinct channels,
    # then one more @chan0 claim at the end of the height ordering.
    first = await self.stream_create('claim0', channel_id=channel_id)
    second = await self.stream_create('claim1', channel_id=channel_id)
    for i in range(2, 10):
        some_chan = self.get_claim_id(await self.channel_create(f'@chan{i}', bid='0.001'))
        claims.append(await self.stream_create(f'claim{i}', bid='0.001', channel_id=some_chan))
    last = await self.stream_create('claim10', channel_id=channel_id)
    # Cap of 3 doesn't bite: both @chan0 claims fit on page 1.
    await match(
        [first, second, claims[0], claims[1]], page_size=4,
        limit_claims_per_channel=3, claim_type='stream', order_by=['^height']
    )
    # second goes out
    await match(
        [first, claims[0], claims[1], claims[2]], page_size=4,
        limit_claims_per_channel=1, claim_type='stream', order_by=['^height']
    )
    # second appears, from replacement queue
    await match(
        [second, claims[3], claims[4], claims[5]], page_size=4, page=2,
        limit_claims_per_channel=1, claim_type='stream', order_by=['^height']
    )
    # last is unaffected, as the limit applies per page
    await match(
        [claims[6], claims[7], last], page_size=4, page=3,
        limit_claims_per_channel=1, claim_type='stream', order_by=['^height']
    )
async def test_claim_type_and_media_type_search(self): async def test_claim_type_and_media_type_search(self):
# create an invalid/unknown claim # create an invalid/unknown claim
address = await self.account.receiving.get_or_create_usable_address() address = await self.account.receiving.get_or_create_usable_address()