diff --git a/lbry/extras/daemon/daemon.py b/lbry/extras/daemon/daemon.py index 8c6dbda84..26f8b7f32 100644 --- a/lbry/extras/daemon/daemon.py +++ b/lbry/extras/daemon/daemon.py @@ -2321,7 +2321,12 @@ class Daemon(metaclass=JSONRPCServerType): page_num, page_size = abs(kwargs.pop('page', 1)), min(abs(kwargs.pop('page_size', DEFAULT_PAGE_SIZE)), 50) kwargs.update({'offset': page_size * (page_num - 1), 'limit': page_size}) txos, blocked, _, total = await self.ledger.claim_search(wallet.accounts, **kwargs) - result = {"items": txos, "page": page_num, "page_size": page_size} + result = { + "items": txos, + "blocked": blocked, + "page": page_num, + "page_size": page_size + } if not kwargs.pop('no_totals', False): result['total_pages'] = int((total + (page_size - 1)) / page_size) result['total_items'] = total @@ -2756,7 +2761,7 @@ class Daemon(metaclass=JSONRPCServerType): # check that the holding_address hasn't changed since the export was made holding_address = data['holding_address'] - channels, _, _ = await self.ledger.claim_search( + channels, _, _, _ = await self.ledger.claim_search( wallet.accounts, public_key_id=self.ledger.public_key_to_address(public_key_der) ) if channels and channels[0].get_address(self.ledger) != holding_address: diff --git a/lbry/wallet/ledger.py b/lbry/wallet/ledger.py index e229ca14f..0b6771a8b 100644 --- a/lbry/wallet/ledger.py +++ b/lbry/wallet/ledger.py @@ -753,7 +753,7 @@ class Ledger(metaclass=LedgerRegistry): async def resolve_collection(self, collection, offset=0, page_size=1): claim_ids = collection.claim.collection.claims.ids[offset:page_size+offset] try: - resolve_results, _, _ = await self.claim_search([], claim_ids=claim_ids) + resolve_results, _, _, _ = await self.claim_search([], claim_ids=claim_ids) except Exception as err: if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8 raise diff --git a/lbry/wallet/server/db/reader.py b/lbry/wallet/server/db/reader.py index a6f734fbf..728b4021f 100644 --- a/lbry/wallet/server/db/reader.py +++ b/lbry/wallet/server/db/reader.py @@ -166,15 +166,19 @@ def encode_result(result): @measure -def execute_query(sql, values, row_limit, censor) -> List: +def execute_query(sql, values, row_offset: int, row_limit: int, censor: Censor) -> List: context = ctx.get() context.set_query_timeout() try: c = context.db.cursor() def row_filter(cursor, row): + nonlocal row_offset row = row_factory(cursor, row) if len(row) > 1 and censor.censor(row): return + if row_offset: + row_offset -= 1 + return return row c.setrowtrace(row_filter) i, rows = 0, [] @@ -197,8 +201,11 @@ def execute_query(sql, values, row_limit, censor) -> List: def _get_claims(cols, for_count=False, **constraints) -> Tuple[str, Dict]: if 'order_by' in constraints: + order_by_parts = constraints['order_by'] + if isinstance(order_by_parts, str): + order_by_parts = [order_by_parts] sql_order_by = [] - for order_by in constraints['order_by']: + for order_by in order_by_parts: is_asc = order_by.startswith('^') column = order_by[1:] if is_asc else order_by if column not in ORDER_FIELDS: @@ -322,12 +329,12 @@ def get_claims(cols, for_count=False, **constraints) -> Tuple[List, Censor]: censor = Censor( ctx.get().blocked_claims, {unhexlify(ncid)[::-1] for ncid in constraints.pop('not_channel_ids', [])}, - set(constraints.pop('not_tags', {})) + set(clean_tags(constraints.pop('not_tags', {}))) ) + row_offset = constraints.pop('offset', 0) row_limit = constraints.pop('limit', 20) - constraints['limit'] = 1000 sql, values = _get_claims(cols, for_count, **constraints) - return execute_query(sql, values, row_limit, censor), censor + return execute_query(sql, values, row_offset, row_limit, censor), censor @measure @@ -354,10 +361,7 @@ def _search(**constraints) -> Tuple[List, Censor]: claim.short_url, claim.canonical_url, claim.channel_hash, claim.reposted_claim_hash, claim.signature_valid, - COALESCE( - (SELECT group_concat(tag) FROM tag WHERE tag.claim_hash = claim.claim_hash), - "" - ) as tags + COALESCE((SELECT group_concat(tag) FROM tag WHERE tag.claim_hash = claim.claim_hash), "") as tags """, **constraints ) diff --git a/tests/integration/blockchain/test_claim_commands.py b/tests/integration/blockchain/test_claim_commands.py index 3ca9fa3dc..78eb7d1fd 100644 --- a/tests/integration/blockchain/test_claim_commands.py +++ b/tests/integration/blockchain/test_claim_commands.py @@ -751,27 +751,37 @@ class StreamCommands(ClaimTestCase): self.assertEqual(resolved['newstuff-again']['reposted_claim']['name'], 'newstuff') async def test_filtering_channels_for_removing_content(self): - await self.channel_create('@badstuff', '1.0') - await self.stream_create('not_bad', '1.1', channel_name='@badstuff') - tx = await self.stream_create('too_bad', '1.1', channel_name='@badstuff') - claim_id = self.get_claim_id(tx) - await self.channel_create('@reposts', '1.0') - await self.stream_repost(claim_id, 'normal_repost', '1.2', channel_name='@reposts') - filtering1 = await self.channel_create('@filtering1', '1.0') - filtering1 = self.get_claim_id(filtering1) - await self.stream_repost(claim_id, 'filter1', '1.1', channel_name='@filtering1') - await self.conductor.spv_node.stop() - await self.conductor.spv_node.start( - self.conductor.blockchain_node, extraconf={'FILTERING_CHANNELS_IDS': filtering1} + await self.channel_create('@some_channel', '1.0') + await self.stream_create('good_content', '1.1', channel_name='@some_channel', tags=['good']) + bad_content_id = self.get_claim_id( + await self.stream_create('bad_content', '1.1', channel_name='@some_channel', tags=['bad']) ) - await self.ledger.stop() - await self.ledger.start() - filtered_claim_search = await self.out(self.daemon.jsonrpc_claim_search(name='too_bad')) - self.assertEqual(filtered_claim_search, []) - filtered_claim_search = await self.claim_search(name='not_bad') - self.assertEqual(len(filtered_claim_search), 1) - filtered_claim_search = await self.claim_search(name='normal_repost') - self.assertEqual(len(filtered_claim_search), 1) + blocking_channel_id = self.get_claim_id( + await self.channel_create('@filtering', '1.0') + ) + self.conductor.spv_node.server.db.sql.filtering_channel_hashes.add( + unhexlify(blocking_channel_id)[::-1] + ) + await self.stream_repost(bad_content_id, 'filter1', '1.1', channel_name='@filtering') + + # search for blocked content directly + result = await self.out(self.daemon.jsonrpc_claim_search(name='bad_content')) + self.assertEqual([], result['items']) + self.assertEqual({"reposted_in_channel": {blocking_channel_id: 1}, "total": 1}, result['blocked']) + + # search channel containing blocked content + result = await self.out(self.daemon.jsonrpc_claim_search(channel='@some_channel')) + self.assertEqual(1, len(result['items'])) + self.assertEqual({"reposted_in_channel": {blocking_channel_id: 1}, "total": 1}, result['blocked']) + + # search channel containing blocked content, also block tag + result = await self.out(self.daemon.jsonrpc_claim_search(channel='@some_channel', not_tags=["good", "bad"])) + self.assertEqual(0, len(result['items'])) + self.assertEqual({ + "reposted_in_channel": {blocking_channel_id: 1}, + "has_tag": {"good": 1, "bad": 1}, + "total": 2 + }, result['blocked']) async def test_publish_updates_file_list(self): tx = await self.stream_create(title='created') diff --git a/tests/unit/wallet/server/test_sqldb.py b/tests/unit/wallet/server/test_sqldb.py index 995d14c36..6054111fe 100644 --- a/tests/unit/wallet/server/test_sqldb.py +++ b/tests/unit/wallet/server/test_sqldb.py @@ -594,3 +594,33 @@ class TestContentBlocking(TestSQLDB): self.assertEqual({channel.claim_hash: 1}, censor.blocked_claims) self.assertEqual({a_channel.claim_hash: 1}, censor.blocked_channels) self.assertEqual({}, censor.blocked_tags) + + def test_pagination(self): + one, two, three, four, five, six, seven = ( + self.advance(1, [self.get_stream('One', COIN, tags=["mature"])])[0], + self.advance(2, [self.get_stream('Two', COIN, tags=["mature"])])[0], + self.advance(3, [self.get_stream('Three', COIN)])[0], + self.advance(4, [self.get_stream('Four', COIN)])[0], + self.advance(5, [self.get_stream('Five', COIN)])[0], + self.advance(6, [self.get_stream('Six', COIN)])[0], + self.advance(7, [self.get_stream('Seven', COIN)])[0], + ) + + # nothing blocked + results, censor = reader._search(order_by='^height', offset=1, limit=3) + self.assertEqual(3, len(results)) + self.assertEqual( + [two.claim_hash, three.claim_hash, four.claim_hash], + [r['claim_hash'] for r in results] + ) + self.assertEqual(0, censor.total) + + # tags blocked + results, censor = reader._search(order_by='^height', not_tags=('mature',), offset=1, limit=3) + self.assertEqual(3, len(results)) + self.assertEqual( + [four.claim_hash, five.claim_hash, six.claim_hash], + [r['claim_hash'] for r in results] + ) + self.assertEqual(2, censor.total) + self.assertEqual({"mature": 2}, censor.blocked_tags)