diff --git a/tests/unit/wallet/server/reader.py b/tests/unit/wallet/server/reader.py new file mode 100644 index 000000000..aef0a2369 --- /dev/null +++ b/tests/unit/wallet/server/reader.py @@ -0,0 +1,634 @@ +import time +import struct +import apsw +import logging +from operator import itemgetter +from typing import Tuple, List, Dict, Union, Type, Optional +from binascii import unhexlify +from decimal import Decimal +from contextvars import ContextVar +from functools import wraps +from itertools import chain +from dataclasses import dataclass + +from lbry.wallet.database import query, interpolate +from lbry.error import ResolveCensoredError +from lbry.schema.url import URL, normalize_name +from lbry.schema.tags import clean_tags +from lbry.schema.result import Outputs, Censor +from lbry.wallet import Ledger, RegTestLedger + +from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS, INDEXED_LANGUAGES + + +class SQLiteOperationalError(apsw.Error): + def __init__(self, metrics): + super().__init__('sqlite query errored') + self.metrics = metrics + + +class SQLiteInterruptedError(apsw.InterruptError): + def __init__(self, metrics): + super().__init__('sqlite query interrupted') + self.metrics = metrics + + +ATTRIBUTE_ARRAY_MAX_LENGTH = 100 + +INTEGER_PARAMS = { + 'height', 'creation_height', 'activation_height', 'expiration_height', + 'timestamp', 'creation_timestamp', 'duration', 'release_time', 'fee_amount', + 'tx_position', 'channel_join', 'reposted', 'limit_claims_per_channel', + 'amount', 'effective_amount', 'support_amount', + 'trending_group', 'trending_mixed', + 'trending_local', 'trending_global', +} + +SEARCH_PARAMS = { + 'name', 'text', 'claim_id', 'claim_ids', 'txid', 'nout', 'channel', 'channel_ids', 'not_channel_ids', + 'public_key_id', 'claim_type', 'stream_types', 'media_types', 'fee_currency', + 'has_channel_signature', 'signature_valid', + 'any_tags', 'all_tags', 'not_tags', 'reposted_claim_id', + 'any_locations', 'all_locations', 'not_locations', + 'any_languages', 'all_languages', 'not_languages', + 'is_controlling', 'limit', 'offset', 'order_by', + 'no_totals', 'has_source' +} | INTEGER_PARAMS + + +ORDER_FIELDS = { + 'name', 'claim_hash' +} | INTEGER_PARAMS + + +@dataclass +class ReaderState: + db: apsw.Connection + stack: List[List] + metrics: Dict + is_tracking_metrics: bool + ledger: Type[Ledger] + query_timeout: float + log: logging.Logger + blocked_streams: Dict + blocked_channels: Dict + filtered_streams: Dict + filtered_channels: Dict + + def close(self): + self.db.close() + + def reset_metrics(self): + self.stack = [] + self.metrics = {} + + def set_query_timeout(self): + stop_at = time.perf_counter() + self.query_timeout + + def interruptor(): + if time.perf_counter() >= stop_at: + self.db.interrupt() + return + + self.db.setprogresshandler(interruptor, 100) + + def get_resolve_censor(self) -> Censor: + return Censor(Censor.RESOLVE) + + def get_search_censor(self, limit_claims_per_channel: int) -> Censor: + return Censor(Censor.SEARCH) + + +ctx: ContextVar[Optional[ReaderState]] = ContextVar('ctx') + + +def row_factory(cursor, row): + return { + k[0]: (set(row[i].split(',')) if k[0] == 'tags' else row[i]) + for i, k in enumerate(cursor.getdescription()) + } + + +def initializer(log, _path, _ledger_name, query_timeout, _measure=False, block_and_filter=None): + db = apsw.Connection(_path, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI) + db.setrowtrace(row_factory) + if block_and_filter: + blocked_streams, blocked_channels, filtered_streams, filtered_channels = block_and_filter + else: + blocked_streams = blocked_channels = filtered_streams = filtered_channels = {} + ctx.set( + ReaderState( + db=db, stack=[], metrics={}, is_tracking_metrics=_measure, + ledger=Ledger if _ledger_name == 'mainnet' else RegTestLedger, + query_timeout=query_timeout, log=log, + blocked_streams=blocked_streams, blocked_channels=blocked_channels, + filtered_streams=filtered_streams, filtered_channels=filtered_channels, + ) + ) + + +def cleanup(): + ctx.get().close() + ctx.set(None) + + +def measure(func): + @wraps(func) + def wrapper(*args, **kwargs): + state = ctx.get() + if not state.is_tracking_metrics: + return func(*args, **kwargs) + metric = {} + state.metrics.setdefault(func.__name__, []).append(metric) + state.stack.append([]) + start = time.perf_counter() + try: + return func(*args, **kwargs) + finally: + elapsed = int((time.perf_counter()-start)*1000) + metric['total'] = elapsed + metric['isolated'] = (elapsed-sum(state.stack.pop())) + if state.stack: + state.stack[-1].append(elapsed) + return wrapper + + +def reports_metrics(func): + @wraps(func) + def wrapper(*args, **kwargs): + state = ctx.get() + if not state.is_tracking_metrics: + return func(*args, **kwargs) + state.reset_metrics() + r = func(*args, **kwargs) + return r, state.metrics + return wrapper + + +@reports_metrics +def search_to_bytes(constraints) -> Union[bytes, Tuple[bytes, Dict]]: + return encode_result(search(constraints)) + + +@reports_metrics +def resolve_to_bytes(urls) -> Union[bytes, Tuple[bytes, Dict]]: + return encode_result(resolve(urls)) + + +def encode_result(result): + return Outputs.to_bytes(*result) + + +@measure +def execute_query(sql, values, row_offset: int, row_limit: int, censor: Censor) -> List: + context = ctx.get() + context.set_query_timeout() + try: + c = context.db.cursor() + def row_filter(cursor, row): + nonlocal row_offset + row = row_factory(cursor, row) + if len(row) > 1 and censor.censor(row): + return + if row_offset: + row_offset -= 1 + return + return row + c.setrowtrace(row_filter) + i, rows = 0, [] + for row in c.execute(sql, values): + i += 1 + rows.append(row) + if i >= row_limit: + break + return rows + except apsw.Error as err: + plain_sql = interpolate(sql, values) + if context.is_tracking_metrics: + context.metrics['execute_query'][-1]['sql'] = plain_sql + if isinstance(err, apsw.InterruptError): + context.log.warning("interrupted slow sqlite query:\n%s", plain_sql) + raise SQLiteInterruptedError(context.metrics) + context.log.exception('failed running query', exc_info=err) + raise SQLiteOperationalError(context.metrics) + + +def claims_query(cols, for_count=False, **constraints) -> Tuple[str, Dict]: + if 'order_by' in constraints: + order_by_parts = constraints['order_by'] + if isinstance(order_by_parts, str): + order_by_parts = [order_by_parts] + sql_order_by = [] + for order_by in order_by_parts: + is_asc = order_by.startswith('^') + column = order_by[1:] if is_asc else order_by + if column not in ORDER_FIELDS: + raise NameError(f'{column} is not a valid order_by field') + if column == 'name': + column = 'normalized' + sql_order_by.append( + f"claim.{column} ASC" if is_asc else f"claim.{column} DESC" + ) + constraints['order_by'] = sql_order_by + + ops = {'<=': '__lte', '>=': '__gte', '<': '__lt', '>': '__gt'} + for constraint in INTEGER_PARAMS: + if constraint in constraints: + value = constraints.pop(constraint) + postfix = '' + if isinstance(value, str): + if len(value) >= 2 and value[:2] in ops: + postfix, value = ops[value[:2]], value[2:] + elif len(value) >= 1 and value[0] in ops: + postfix, value = ops[value[0]], value[1:] + if constraint == 'fee_amount': + value = Decimal(value)*1000 + constraints[f'claim.{constraint}{postfix}'] = int(value) + + if constraints.pop('is_controlling', False): + if {'sequence', 'amount_order'}.isdisjoint(constraints): + for_count = False + constraints['claimtrie.claim_hash__is_not_null'] = '' + if 'sequence' in constraints: + constraints['order_by'] = 'claim.activation_height ASC' + constraints['offset'] = int(constraints.pop('sequence')) - 1 + constraints['limit'] = 1 + if 'amount_order' in constraints: + constraints['order_by'] = 'claim.effective_amount DESC' + constraints['offset'] = int(constraints.pop('amount_order')) - 1 + constraints['limit'] = 1 + + if 'claim_id' in constraints: + claim_id = constraints.pop('claim_id') + if len(claim_id) == 40: + constraints['claim.claim_id'] = claim_id + else: + constraints['claim.claim_id__like'] = f'{claim_id[:40]}%' + elif 'claim_ids' in constraints: + constraints['claim.claim_id__in'] = set(constraints.pop('claim_ids')) + + if 'reposted_claim_id' in constraints: + constraints['claim.reposted_claim_hash'] = unhexlify(constraints.pop('reposted_claim_id'))[::-1] + + if 'name' in constraints: + constraints['claim.normalized'] = normalize_name(constraints.pop('name')) + + if 'public_key_id' in constraints: + constraints['claim.public_key_hash'] = ( + ctx.get().ledger.address_to_hash160(constraints.pop('public_key_id'))) + if 'channel_hash' in constraints: + constraints['claim.channel_hash'] = constraints.pop('channel_hash') + if 'channel_ids' in constraints: + channel_ids = constraints.pop('channel_ids') + if channel_ids: + constraints['claim.channel_hash__in'] = { + unhexlify(cid)[::-1] for cid in channel_ids if cid + } + if 'not_channel_ids' in constraints: + not_channel_ids = constraints.pop('not_channel_ids') + if not_channel_ids: + not_channel_ids_binary = { + unhexlify(ncid)[::-1] for ncid in not_channel_ids + } + constraints['claim.claim_hash__not_in#not_channel_ids'] = not_channel_ids_binary + if constraints.get('has_channel_signature', False): + constraints['claim.channel_hash__not_in'] = not_channel_ids_binary + else: + constraints['null_or_not_channel__or'] = { + 'claim.signature_valid__is_null': True, + 'claim.channel_hash__not_in': not_channel_ids_binary + } + if 'signature_valid' in constraints: + has_channel_signature = constraints.pop('has_channel_signature', False) + if has_channel_signature: + constraints['claim.signature_valid'] = constraints.pop('signature_valid') + else: + constraints['null_or_signature__or'] = { + 'claim.signature_valid__is_null': True, + 'claim.signature_valid': constraints.pop('signature_valid') + } + elif constraints.pop('has_channel_signature', False): + constraints['claim.signature_valid__is_not_null'] = True + + if 'txid' in constraints: + tx_hash = unhexlify(constraints.pop('txid'))[::-1] + nout = constraints.pop('nout', 0) + constraints['claim.txo_hash'] = tx_hash + struct.pack(' List: + if 'channel' in constraints: + channel_url = constraints.pop('channel') + match = resolve_url(channel_url) + if isinstance(match, dict): + constraints['channel_hash'] = match['claim_hash'] + else: + return [{'row_count': 0}] if cols == 'count(*) as row_count' else [] + row_offset = constraints.pop('offset', 0) + row_limit = constraints.pop('limit', 20) + sql, values = claims_query(cols, for_count, **constraints) + return execute_query(sql, values, row_offset, row_limit, censor) + + +@measure +def count_claims(**constraints) -> int: + constraints.pop('offset', None) + constraints.pop('limit', None) + constraints.pop('order_by', None) + count = select_claims(Censor(Censor.SEARCH), 'count(*) as row_count', for_count=True, **constraints) + return count[0]['row_count'] + + +def search_claims(censor: Censor, **constraints) -> List: + return select_claims( + censor, + """ + claimtrie.claim_hash as is_controlling, + claimtrie.last_take_over_height, + claim.claim_hash, claim.txo_hash, + claim.claims_in_channel, claim.reposted, + claim.height, claim.creation_height, + claim.activation_height, claim.expiration_height, + claim.effective_amount, claim.support_amount, + claim.trending_group, claim.trending_mixed, + claim.trending_local, claim.trending_global, + claim.short_url, claim.canonical_url, + claim.channel_hash, claim.reposted_claim_hash, + claim.signature_valid + """, **constraints + ) + + +def _get_referenced_rows(txo_rows: List[dict], censor_channels: List[bytes]): + censor = ctx.get().get_resolve_censor() + repost_hashes = set(filter(None, map(itemgetter('reposted_claim_hash'), txo_rows))) + channel_hashes = set(chain( + filter(None, map(itemgetter('channel_hash'), txo_rows)), + censor_channels + )) + + reposted_txos = [] + if repost_hashes: + reposted_txos = search_claims(censor, **{'claim.claim_hash__in': repost_hashes}) + channel_hashes |= set(filter(None, map(itemgetter('channel_hash'), reposted_txos))) + + channel_txos = [] + if channel_hashes: + channel_txos = search_claims(censor, **{'claim.claim_hash__in': channel_hashes}) + + # channels must come first for client side inflation to work properly + return channel_txos + reposted_txos + +@measure +def search(constraints) -> Tuple[List, List, int, int, Censor]: + assert set(constraints).issubset(SEARCH_PARAMS), \ + f"Search query contains invalid arguments: {set(constraints).difference(SEARCH_PARAMS)}" + total = None + limit_claims_per_channel = constraints.pop('limit_claims_per_channel', None) + if not constraints.pop('no_totals', False): + total = count_claims(**constraints) + constraints['offset'] = abs(constraints.get('offset', 0)) + constraints['limit'] = min(abs(constraints.get('limit', 10)), 50) + context = ctx.get() + search_censor = context.get_search_censor(limit_claims_per_channel) + txo_rows = search_claims(search_censor, **constraints) + extra_txo_rows = _get_referenced_rows(txo_rows, search_censor.censored.keys()) + return txo_rows, extra_txo_rows, constraints['offset'], total, search_censor + + +@measure +def resolve(urls) -> Tuple[List, List]: + txo_rows = [resolve_url(raw_url) for raw_url in urls] + extra_txo_rows = _get_referenced_rows( + [txo for txo in txo_rows if isinstance(txo, dict)], + [txo.censor_hash for txo in txo_rows if isinstance(txo, ResolveCensoredError)] + ) + return txo_rows, extra_txo_rows + + +@measure +def resolve_url(raw_url): + censor = ctx.get().get_resolve_censor() + + try: + url = URL.parse(raw_url) + except ValueError as e: + return e + + channel = None + + if url.has_channel: + query = url.channel.to_dict() + if set(query) == {'name'}: + query['is_controlling'] = True + else: + query['order_by'] = ['^creation_height'] + matches = search_claims(censor, **query, limit=1) + if matches: + channel = matches[0] + elif censor.censored: + return ResolveCensoredError(raw_url, next(iter(censor.censored))) + else: + return LookupError(f'Could not find channel in "{raw_url}".') + + if url.has_stream: + query = url.stream.to_dict() + if channel is not None: + if set(query) == {'name'}: + # temporarily emulate is_controlling for claims in channel + query['order_by'] = ['effective_amount', '^height'] + else: + query['order_by'] = ['^channel_join'] + query['channel_hash'] = channel['claim_hash'] + query['signature_valid'] = 1 + elif set(query) == {'name'}: + query['is_controlling'] = 1 + matches = search_claims(censor, **query, limit=1) + if matches: + return matches[0] + elif censor.censored: + return ResolveCensoredError(raw_url, next(iter(censor.censored))) + else: + return LookupError(f'Could not find claim at "{raw_url}".') + + return channel + + +CLAIM_HASH_OR_REPOST_HASH_SQL = f""" +CASE WHEN claim.claim_type = {CLAIM_TYPES['repost']} + THEN claim.reposted_claim_hash + ELSE claim.claim_hash +END +""" + + +def _apply_constraints_for_array_attributes(constraints, attr, cleaner, for_count=False): + any_items = set(cleaner(constraints.pop(f'any_{attr}s', []))[:ATTRIBUTE_ARRAY_MAX_LENGTH]) + all_items = set(cleaner(constraints.pop(f'all_{attr}s', []))[:ATTRIBUTE_ARRAY_MAX_LENGTH]) + not_items = set(cleaner(constraints.pop(f'not_{attr}s', []))[:ATTRIBUTE_ARRAY_MAX_LENGTH]) + + all_items = {item for item in all_items if item not in not_items} + any_items = {item for item in any_items if item not in not_items} + + any_queries = {} + + if attr == 'tag': + common_tags = any_items & COMMON_TAGS.keys() + if common_tags: + any_items -= common_tags + if len(common_tags) < 5: + for item in common_tags: + index_name = COMMON_TAGS[item] + any_queries[f'#_common_tag_{index_name}'] = f""" + EXISTS( + SELECT 1 FROM tag INDEXED BY tag_{index_name}_idx + WHERE {CLAIM_HASH_OR_REPOST_HASH_SQL}=tag.claim_hash + AND tag = '{item}' + ) + """ + elif len(common_tags) >= 5: + constraints.update({ + f'$any_common_tag{i}': item for i, item in enumerate(common_tags) + }) + values = ', '.join( + f':$any_common_tag{i}' for i in range(len(common_tags)) + ) + any_queries[f'#_any_common_tags'] = f""" + EXISTS( + SELECT 1 FROM tag WHERE {CLAIM_HASH_OR_REPOST_HASH_SQL}=tag.claim_hash + AND tag IN ({values}) + ) + """ + elif attr == 'language': + indexed_languages = any_items & set(INDEXED_LANGUAGES) + if indexed_languages: + any_items -= indexed_languages + for language in indexed_languages: + any_queries[f'#_any_common_languages_{language}'] = f""" + EXISTS( + SELECT 1 FROM language INDEXED BY language_{language}_idx + WHERE {CLAIM_HASH_OR_REPOST_HASH_SQL}=language.claim_hash + AND language = '{language}' + ) + """ + + if any_items: + + constraints.update({ + f'$any_{attr}{i}': item for i, item in enumerate(any_items) + }) + values = ', '.join( + f':$any_{attr}{i}' for i in range(len(any_items)) + ) + if for_count or attr == 'tag': + if attr == 'tag': + any_queries[f'#_any_{attr}'] = f""" + ((claim.claim_type != {CLAIM_TYPES['repost']} + AND claim.claim_hash IN (SELECT claim_hash FROM tag WHERE tag IN ({values}))) OR + (claim.claim_type == {CLAIM_TYPES['repost']} AND + claim.reposted_claim_hash IN (SELECT claim_hash FROM tag WHERE tag IN ({values})))) + """ + else: + any_queries[f'#_any_{attr}'] = f""" + {CLAIM_HASH_OR_REPOST_HASH_SQL} IN ( + SELECT claim_hash FROM {attr} WHERE {attr} IN ({values}) + ) + """ + else: + any_queries[f'#_any_{attr}'] = f""" + EXISTS( + SELECT 1 FROM {attr} WHERE + {CLAIM_HASH_OR_REPOST_HASH_SQL}={attr}.claim_hash + AND {attr} IN ({values}) + ) + """ + + if len(any_queries) == 1: + constraints.update(any_queries) + elif len(any_queries) > 1: + constraints[f'ORed_{attr}_queries__any'] = any_queries + + if all_items: + constraints[f'$all_{attr}_count'] = len(all_items) + constraints.update({ + f'$all_{attr}{i}': item for i, item in enumerate(all_items) + }) + values = ', '.join( + f':$all_{attr}{i}' for i in range(len(all_items)) + ) + if for_count: + constraints[f'#_all_{attr}'] = f""" + {CLAIM_HASH_OR_REPOST_HASH_SQL} IN ( + SELECT claim_hash FROM {attr} WHERE {attr} IN ({values}) + GROUP BY claim_hash HAVING COUNT({attr}) = :$all_{attr}_count + ) + """ + else: + constraints[f'#_all_{attr}'] = f""" + {len(all_items)}=( + SELECT count(*) FROM {attr} WHERE + {CLAIM_HASH_OR_REPOST_HASH_SQL}={attr}.claim_hash + AND {attr} IN ({values}) + ) + """ + + if not_items: + constraints.update({ + f'$not_{attr}{i}': item for i, item in enumerate(not_items) + }) + values = ', '.join( + f':$not_{attr}{i}' for i in range(len(not_items)) + ) + if for_count: + if attr == 'tag': + constraints[f'#_not_{attr}'] = f""" + ((claim.claim_type != {CLAIM_TYPES['repost']} + AND claim.claim_hash NOT IN (SELECT claim_hash FROM tag WHERE tag IN ({values}))) OR + (claim.claim_type == {CLAIM_TYPES['repost']} AND + claim.reposted_claim_hash NOT IN (SELECT claim_hash FROM tag WHERE tag IN ({values})))) + """ + else: + constraints[f'#_not_{attr}'] = f""" + {CLAIM_HASH_OR_REPOST_HASH_SQL} NOT IN ( + SELECT claim_hash FROM {attr} WHERE {attr} IN ({values}) + ) + """ + else: + constraints[f'#_not_{attr}'] = f""" + NOT EXISTS( + SELECT 1 FROM {attr} WHERE + {CLAIM_HASH_OR_REPOST_HASH_SQL}={attr}.claim_hash + AND {attr} IN ({values}) + ) + """ diff --git a/tests/unit/wallet/server/test_sqldb.py b/tests/unit/wallet/server/test_sqldb.py new file mode 100644 index 000000000..52753ad99 --- /dev/null +++ b/tests/unit/wallet/server/test_sqldb.py @@ -0,0 +1,765 @@ +import unittest +import ecdsa +import hashlib +import logging +from binascii import hexlify +from typing import List, Tuple + +from lbry.wallet.constants import COIN, NULL_HASH32 +from lbry.schema.claim import Claim +from lbry.schema.result import Censor +from lbry.wallet.server.db import writer +from lbry.wallet.server.coin import LBCRegTest +from lbry.wallet.server.db.trending import zscore +from lbry.wallet.server.db.canonical import FindShortestID +from lbry.wallet.server.block_processor import Timer +from lbry.wallet.transaction import Transaction, Input, Output +try: + import reader +except: + from . import reader + + +def get_output(amount=COIN, pubkey_hash=NULL_HASH32): + return Transaction() \ + .add_outputs([Output.pay_pubkey_hash(amount, pubkey_hash)]) \ + .outputs[0] + + +def get_input(): + return Input.spend(get_output()) + + +def get_tx(): + return Transaction().add_inputs([get_input()]) + + +def search(**constraints) -> List: + return reader.search_claims(Censor(Censor.SEARCH), **constraints) + + +def censored_search(**constraints) -> Tuple[List, Censor]: + rows, _, _, _, censor = reader.search(constraints) + return rows, censor + + +class TestSQLDB(unittest.TestCase): + query_timeout = 0.25 + + def setUp(self): + self.first_sync = False + self.daemon_height = 1 + self.coin = LBCRegTest() + db_url = 'file:test_sqldb?mode=memory&cache=shared' + self.sql = writer.SQLDB(self, db_url, [], [], [zscore]) + self.addCleanup(self.sql.close) + self.sql.open() + reader.initializer( + logging.getLogger(__name__), db_url, 'regtest', + self.query_timeout, block_and_filter=( + self.sql.blocked_streams, self.sql.blocked_channels, + self.sql.filtered_streams, self.sql.filtered_channels + ) + ) + self.addCleanup(reader.cleanup) + self.timer = Timer('BlockProcessor') + self._current_height = 0 + self._txos = {} + + def _make_tx(self, output, txi=None): + tx = get_tx().add_outputs([output]) + if txi is not None: + tx.add_inputs([txi]) + self._txos[output.ref.hash] = output + return tx, tx.hash + + def _set_channel_key(self, channel, key): + private_key = ecdsa.SigningKey.from_string(key*32, curve=ecdsa.SECP256k1, hashfunc=hashlib.sha256) + channel.private_key = private_key + channel.claim.channel.public_key_bytes = private_key.get_verifying_key().to_der() + channel.script.generate() + + def get_channel(self, title, amount, name='@foo', key=b'a'): + claim = Claim() + claim.channel.title = title + channel = Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc') + self._set_channel_key(channel, key) + return self._make_tx(channel) + + def get_channel_update(self, channel, amount, key=b'a'): + self._set_channel_key(channel, key) + return self._make_tx( + Output.pay_update_claim_pubkey_hash( + amount, channel.claim_name, channel.claim_id, channel.claim, b'abc' + ), + Input.spend(channel) + ) + + def get_stream(self, title, amount, name='foo', channel=None, **kwargs): + claim = Claim() + claim.stream.update(title=title, **kwargs) + result = self._make_tx(Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc')) + if channel: + result[0].outputs[0].sign(channel) + result[0]._reset() + return result + + def get_stream_update(self, tx, amount, channel=None): + stream = Transaction(tx[0].raw).outputs[0] + result = self._make_tx( + Output.pay_update_claim_pubkey_hash( + amount, stream.claim_name, stream.claim_id, stream.claim, b'abc' + ), + Input.spend(stream) + ) + if channel: + result[0].outputs[0].sign(channel) + result[0]._reset() + return result + + def get_repost(self, claim_id, amount, channel): + claim = Claim() + claim.repost.reference.claim_id = claim_id + result = self._make_tx(Output.pay_claim_name_pubkey_hash(amount, 'repost', claim, b'abc')) + result[0].outputs[0].sign(channel) + result[0]._reset() + return result + + def get_abandon(self, tx): + claim = Transaction(tx[0].raw).outputs[0] + return self._make_tx( + Output.pay_pubkey_hash(claim.amount, b'abc'), + Input.spend(claim) + ) + + def get_support(self, tx, amount): + claim = Transaction(tx[0].raw).outputs[0] + return self._make_tx( + Output.pay_support_pubkey_hash( + amount, claim.claim_name, claim.claim_id, b'abc' + ) + ) + + def get_controlling(self): + for claim in self.sql.execute("select claim.* from claimtrie natural join claim"): + txo = self._txos[claim.txo_hash] + controlling = txo.claim.stream.title, claim.amount, claim.effective_amount, claim.activation_height + return controlling + + def get_active(self): + controlling = self.get_controlling() + active = [] + for claim in self.sql.execute( + f"select * from claim where activation_height <= {self._current_height}"): + txo = self._txos[claim.txo_hash] + if controlling and controlling[0] == txo.claim.stream.title: + continue + active.append((txo.claim.stream.title, claim.amount, claim.effective_amount, claim.activation_height)) + return active + + def get_accepted(self): + accepted = [] + for claim in self.sql.execute( + f"select * from claim where activation_height > {self._current_height}"): + txo = self._txos[claim.txo_hash] + accepted.append((txo.claim.stream.title, claim.amount, claim.effective_amount, claim.activation_height)) + return accepted + + def advance(self, height, txs): + self._current_height = height + self.sql.advance_txs(height, txs, {'timestamp': 1}, self.daemon_height, self.timer) + return [otx[0].outputs[0] for otx in txs] + + def state(self, controlling=None, active=None, accepted=None): + self.assertEqual(controlling, self.get_controlling()) + self.assertEqual(active or [], self.get_active()) + self.assertEqual(accepted or [], self.get_accepted()) + + +class TestClaimtrie(TestSQLDB): + + def test_example_from_spec(self): + # https://spec.lbry.com/#claim-activation-example + advance, state = self.advance, self.state + stream = self.get_stream('Claim A', 10*COIN) + advance(13, [stream]) + state( + controlling=('Claim A', 10*COIN, 10*COIN, 13), + active=[], + accepted=[] + ) + advance(1001, [self.get_stream('Claim B', 20*COIN)]) + state( + controlling=('Claim A', 10*COIN, 10*COIN, 13), + active=[], + accepted=[('Claim B', 20*COIN, 0, 1031)] + ) + advance(1010, [self.get_support(stream, 14*COIN)]) + state( + controlling=('Claim A', 10*COIN, 24*COIN, 13), + active=[], + accepted=[('Claim B', 20*COIN, 0, 1031)] + ) + advance(1020, [self.get_stream('Claim C', 50*COIN)]) + state( + controlling=('Claim A', 10*COIN, 24*COIN, 13), + active=[], + accepted=[ + ('Claim B', 20*COIN, 0, 1031), + ('Claim C', 50*COIN, 0, 1051)] + ) + advance(1031, []) + state( + controlling=('Claim A', 10*COIN, 24*COIN, 13), + active=[('Claim B', 20*COIN, 20*COIN, 1031)], + accepted=[('Claim C', 50*COIN, 0, 1051)] + ) + advance(1040, [self.get_stream('Claim D', 300*COIN)]) + state( + controlling=('Claim A', 10*COIN, 24*COIN, 13), + active=[('Claim B', 20*COIN, 20*COIN, 1031)], + accepted=[ + ('Claim C', 50*COIN, 0, 1051), + ('Claim D', 300*COIN, 0, 1072)] + ) + advance(1051, []) + state( + controlling=('Claim D', 300*COIN, 300*COIN, 1051), + active=[ + ('Claim A', 10*COIN, 24*COIN, 13), + ('Claim B', 20*COIN, 20*COIN, 1031), + ('Claim C', 50*COIN, 50*COIN, 1051)], + accepted=[] + ) + # beyond example + advance(1052, [self.get_stream_update(stream, 290*COIN)]) + state( + controlling=('Claim A', 290*COIN, 304*COIN, 13), + active=[ + ('Claim B', 20*COIN, 20*COIN, 1031), + ('Claim C', 50*COIN, 50*COIN, 1051), + ('Claim D', 300*COIN, 300*COIN, 1051), + ], + accepted=[] + ) + + def test_competing_claims_subsequent_blocks_height_wins(self): + advance, state = self.advance, self.state + advance(13, [self.get_stream('Claim A', 10*COIN)]) + state( + controlling=('Claim A', 10*COIN, 10*COIN, 13), + active=[], + accepted=[] + ) + advance(14, [self.get_stream('Claim B', 10*COIN)]) + state( + controlling=('Claim A', 10*COIN, 10*COIN, 13), + active=[('Claim B', 10*COIN, 10*COIN, 14)], + accepted=[] + ) + advance(15, [self.get_stream('Claim C', 10*COIN)]) + state( + controlling=('Claim A', 10*COIN, 10*COIN, 13), + active=[ + ('Claim B', 10*COIN, 10*COIN, 14), + ('Claim C', 10*COIN, 10*COIN, 15)], + accepted=[] + ) + + def test_competing_claims_in_single_block_position_wins(self): + advance, state = self.advance, self.state + stream = self.get_stream('Claim A', 10*COIN) + stream2 = self.get_stream('Claim B', 10*COIN) + advance(13, [stream, stream2]) + state( + controlling=('Claim A', 10*COIN, 10*COIN, 13), + active=[('Claim B', 10*COIN, 10*COIN, 13)], + accepted=[] + ) + + def test_competing_claims_in_single_block_effective_amount_wins(self): + advance, state = self.advance, self.state + stream = self.get_stream('Claim A', 10*COIN) + stream2 = self.get_stream('Claim B', 11*COIN) + advance(13, [stream, stream2]) + state( + controlling=('Claim B', 11*COIN, 11*COIN, 13), + active=[('Claim A', 10*COIN, 10*COIN, 13)], + accepted=[] + ) + + def test_winning_claim_deleted(self): + advance, state = self.advance, self.state + stream = self.get_stream('Claim A', 10*COIN) + stream2 = self.get_stream('Claim B', 11*COIN) + advance(13, [stream, stream2]) + state( + controlling=('Claim B', 11*COIN, 11*COIN, 13), + active=[('Claim A', 10*COIN, 10*COIN, 13)], + accepted=[] + ) + advance(14, [self.get_abandon(stream2)]) + state( + controlling=('Claim A', 10*COIN, 10*COIN, 13), + active=[], + accepted=[] + ) + + def test_winning_claim_deleted_and_new_claim_becomes_winner(self): + advance, state = self.advance, self.state + stream = self.get_stream('Claim A', 10*COIN) + stream2 = self.get_stream('Claim B', 11*COIN) + advance(13, [stream, stream2]) + state( + controlling=('Claim B', 11*COIN, 11*COIN, 13), + active=[('Claim A', 10*COIN, 10*COIN, 13)], + accepted=[] + ) + advance(15, [self.get_abandon(stream2), self.get_stream('Claim C', 12*COIN)]) + state( + controlling=('Claim C', 12*COIN, 12*COIN, 15), + active=[('Claim A', 10*COIN, 10*COIN, 13)], + accepted=[] + ) + + def test_winning_claim_expires_and_another_takes_over(self): + advance, state = self.advance, self.state + advance(10, [self.get_stream('Claim A', 11*COIN)]) + advance(20, [self.get_stream('Claim B', 10*COIN)]) + state( + controlling=('Claim A', 11*COIN, 11*COIN, 10), + active=[('Claim B', 10*COIN, 10*COIN, 20)], + accepted=[] + ) + advance(262984, []) + state( + controlling=('Claim B', 10*COIN, 10*COIN, 20), + active=[], + accepted=[] + ) + advance(262994, []) + state( + controlling=None, + active=[], + accepted=[] + ) + + def test_create_and_update_in_same_block(self): + advance, state = self.advance, self.state + stream = self.get_stream('Claim A', 10*COIN) + advance(10, [stream, self.get_stream_update(stream, 11*COIN)]) + self.assertTrue(search()[0]) + + def test_double_updates_in_same_block(self): + advance, state = self.advance, self.state + stream = self.get_stream('Claim A', 10*COIN) + advance(10, [stream]) + update = self.get_stream_update(stream, 11*COIN) + advance(20, [update, self.get_stream_update(update, 9*COIN)]) + self.assertTrue(search()[0]) + + def test_create_and_abandon_in_same_block(self): + advance, state = self.advance, self.state + stream = self.get_stream('Claim A', 10*COIN) + advance(10, [stream, self.get_abandon(stream)]) + self.assertFalse(search()) + + def test_update_and_abandon_in_same_block(self): + advance, state = self.advance, self.state + stream = self.get_stream('Claim A', 10*COIN) + advance(10, [stream]) + update = self.get_stream_update(stream, 11*COIN) + advance(20, [update, self.get_abandon(update)]) + self.assertFalse(search()) + + def test_create_update_and_delete_in_same_block(self): + advance, state = self.advance, self.state + stream = self.get_stream('Claim A', 10*COIN) + update = self.get_stream_update(stream, 11*COIN) + advance(10, [stream, update, self.get_abandon(update)]) + self.assertFalse(search()) + + def test_support_added_and_removed_in_same_block(self): + advance, state = self.advance, self.state + stream = self.get_stream('Claim A', 10*COIN) + advance(10, [stream]) + support = self.get_support(stream, COIN) + advance(20, [support, self.get_abandon(support)]) + self.assertEqual(search()[0]['support_amount'], 0) + + @staticmethod + def _get_x_with_claim_id_prefix(getter, prefix, cached_iteration=None, **kwargs): + iterations = cached_iteration+1 if cached_iteration else 100 + for i in range(cached_iteration or 1, iterations): + stream = getter(f'claim #{i}', COIN, **kwargs) + if stream[0].outputs[0].claim_id.startswith(prefix): + cached_iteration is None and print(f'Found "{prefix}" in {i} iterations.') + return stream + if cached_iteration: + raise ValueError(f'Failed to find "{prefix}" at cached iteration, run with None to find iteration.') + raise ValueError(f'Failed to find "{prefix}" in {iterations} iterations, try different values.') + + def get_channel_with_claim_id_prefix(self, prefix, cached_iteration=None, **kwargs): + return self._get_x_with_claim_id_prefix(self.get_channel, prefix, cached_iteration, **kwargs) + + def get_stream_with_claim_id_prefix(self, prefix, cached_iteration=None, **kwargs): + return self._get_x_with_claim_id_prefix(self.get_stream, prefix, cached_iteration, **kwargs) + + def test_canonical_url_and_channel_validation(self): + advance = self.advance + + tx_chan_a = self.get_channel_with_claim_id_prefix('a', 1, key=b'c') + tx_chan_ab = self.get_channel_with_claim_id_prefix('ab', 72, key=b'c') + txo_chan_a = tx_chan_a[0].outputs[0] + txo_chan_ab = tx_chan_ab[0].outputs[0] + advance(1, [tx_chan_a]) + advance(2, [tx_chan_ab]) + (r_ab, r_a) = search(order_by=['creation_height'], limit=2) + self.assertEqual("@foo#a", r_a['short_url']) + self.assertEqual("@foo#ab", r_ab['short_url']) + self.assertIsNone(r_a['canonical_url']) + self.assertIsNone(r_ab['canonical_url']) + self.assertEqual(0, r_a['claims_in_channel']) + self.assertEqual(0, r_ab['claims_in_channel']) + + tx_a = self.get_stream_with_claim_id_prefix('a', 2) + tx_ab = self.get_stream_with_claim_id_prefix('ab', 42) + tx_abc = self.get_stream_with_claim_id_prefix('abc', 65) + advance(3, [tx_a]) + advance(4, [tx_ab, tx_abc]) + (r_abc, r_ab, r_a) = search(order_by=['creation_height', 'tx_position'], limit=3) + self.assertEqual("foo#a", r_a['short_url']) + self.assertEqual("foo#ab", r_ab['short_url']) + self.assertEqual("foo#abc", r_abc['short_url']) + self.assertIsNone(r_a['canonical_url']) + self.assertIsNone(r_ab['canonical_url']) + self.assertIsNone(r_abc['canonical_url']) + + tx_a2 = self.get_stream_with_claim_id_prefix('a', 7, channel=txo_chan_a) + tx_ab2 = self.get_stream_with_claim_id_prefix('ab', 23, channel=txo_chan_a) + a2_claim = tx_a2[0].outputs[0] + ab2_claim = tx_ab2[0].outputs[0] + advance(6, [tx_a2]) + advance(7, [tx_ab2]) + (r_ab2, r_a2) = search(order_by=['creation_height'], limit=2) + self.assertEqual(f"foo#{a2_claim.claim_id[:2]}", r_a2['short_url']) + self.assertEqual(f"foo#{ab2_claim.claim_id[:4]}", r_ab2['short_url']) + self.assertEqual("@foo#a/foo#a", r_a2['canonical_url']) + self.assertEqual("@foo#a/foo#ab", r_ab2['canonical_url']) + self.assertEqual(2, search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel']) + + # change channel public key, invaliding stream claim signatures + advance(8, [self.get_channel_update(txo_chan_a, COIN, key=b'a')]) + (r_ab2, r_a2) = search(order_by=['creation_height'], limit=2) + self.assertEqual(f"foo#{a2_claim.claim_id[:2]}", r_a2['short_url']) + self.assertEqual(f"foo#{ab2_claim.claim_id[:4]}", r_ab2['short_url']) + self.assertIsNone(r_a2['canonical_url']) + self.assertIsNone(r_ab2['canonical_url']) + self.assertEqual(0, search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel']) + + # reinstate previous channel public key (previous stream claim signatures become valid again) + channel_update = self.get_channel_update(txo_chan_a, COIN, key=b'c') + advance(9, [channel_update]) + (r_ab2, r_a2) = search(order_by=['creation_height'], limit=2) + self.assertEqual(f"foo#{a2_claim.claim_id[:2]}", r_a2['short_url']) + self.assertEqual(f"foo#{ab2_claim.claim_id[:4]}", r_ab2['short_url']) + self.assertEqual("@foo#a/foo#a", r_a2['canonical_url']) + self.assertEqual("@foo#a/foo#ab", r_ab2['canonical_url']) + self.assertEqual(2, search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel']) + self.assertEqual(0, search(claim_id=txo_chan_ab.claim_id, limit=1)[0]['claims_in_channel']) + + # change channel of stream + self.assertEqual("@foo#a/foo#ab", search(claim_id=ab2_claim.claim_id, limit=1)[0]['canonical_url']) + tx_ab2 = self.get_stream_update(tx_ab2, COIN, txo_chan_ab) + advance(10, [tx_ab2]) + self.assertEqual("@foo#ab/foo#a", search(claim_id=ab2_claim.claim_id, limit=1)[0]['canonical_url']) + # TODO: currently there is a bug where stream leaving a channel does not update that channels claims count + self.assertEqual(2, search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel']) + # TODO: after bug is fixed remove test above and add test below + #self.assertEqual(1, search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel']) + self.assertEqual(1, search(claim_id=txo_chan_ab.claim_id, limit=1)[0]['claims_in_channel']) + + # claim abandon updates claims_in_channel + advance(11, [self.get_abandon(tx_ab2)]) + self.assertEqual(0, search(claim_id=txo_chan_ab.claim_id, limit=1)[0]['claims_in_channel']) + + # delete channel, invaliding stream claim signatures + advance(12, [self.get_abandon(channel_update)]) + (r_a2,) = search(order_by=['creation_height'], limit=1) + self.assertEqual(f"foo#{a2_claim.claim_id[:2]}", r_a2['short_url']) + self.assertIsNone(r_a2['canonical_url']) + + def test_resolve_issue_2448(self): + advance = self.advance + + tx_chan_a = self.get_channel_with_claim_id_prefix('a', 1, key=b'c') + tx_chan_ab = self.get_channel_with_claim_id_prefix('ab', 72, key=b'c') + txo_chan_a = tx_chan_a[0].outputs[0] + txo_chan_ab = tx_chan_ab[0].outputs[0] + advance(1, [tx_chan_a]) + advance(2, [tx_chan_ab]) + + self.assertEqual(reader.resolve_url("@foo#a")['claim_hash'], txo_chan_a.claim_hash) + self.assertEqual(reader.resolve_url("@foo#ab")['claim_hash'], txo_chan_ab.claim_hash) + + # update increase last height change of channel + advance(9, [self.get_channel_update(txo_chan_a, COIN, key=b'c')]) + + # make sure that activation_height is used instead of height (issue #2448) + self.assertEqual(reader.resolve_url("@foo#a")['claim_hash'], txo_chan_a.claim_hash) + self.assertEqual(reader.resolve_url("@foo#ab")['claim_hash'], txo_chan_ab.claim_hash) + + def test_canonical_find_shortest_id(self): + new_hash = 'abcdef0123456789beef' + other0 = '1bcdef0123456789beef' + other1 = 'ab1def0123456789beef' + other2 = 'abc1ef0123456789beef' + other3 = 'abcdef0123456789bee1' + f = FindShortestID() + f.step(other0, new_hash) + self.assertEqual('#a', f.finalize()) + f.step(other1, new_hash) + self.assertEqual('#abc', f.finalize()) + f.step(other2, new_hash) + self.assertEqual('#abcd', f.finalize()) + f.step(other3, new_hash) + self.assertEqual('#abcdef0123456789beef', f.finalize()) + + +class TestTrending(TestSQLDB): + + def test_trending(self): + advance, state = self.advance, self.state + no_trend = self.get_stream('Claim A', COIN) + downwards = self.get_stream('Claim B', COIN) + up_small = self.get_stream('Claim C', COIN) + up_medium = self.get_stream('Claim D', COIN) + up_biggly = self.get_stream('Claim E', COIN) + claims = advance(1, [up_biggly, up_medium, up_small, no_trend, downwards]) + for window in range(1, 8): + advance(zscore.TRENDING_WINDOW * window, [ + self.get_support(downwards, (20-window)*COIN), + self.get_support(up_small, int(20+(window/10)*COIN)), + self.get_support(up_medium, (20+(window*(2 if window == 7 else 1)))*COIN), + self.get_support(up_biggly, (20+(window*(3 if window == 7 else 1)))*COIN), + ]) + results = search(order_by=['trending_local']) + self.assertEqual([c.claim_id for c in claims], [hexlify(c['claim_hash'][::-1]).decode() for c in results]) + self.assertEqual([10, 6, 2, 0, -2], [int(c['trending_local']) for c in results]) + self.assertEqual([53, 38, -32, 0, -6], [int(c['trending_global']) for c in results]) + self.assertEqual([4, 4, 2, 0, 1], [int(c['trending_group']) for c in results]) + self.assertEqual([53, 38, 2, 0, -6], [int(c['trending_mixed']) for c in results]) + + def test_edge(self): + problematic = self.get_stream('Problem', COIN) + self.advance(1, [problematic]) + self.advance(zscore.TRENDING_WINDOW, [self.get_support(problematic, 53000000000)]) + self.advance(zscore.TRENDING_WINDOW * 2, [self.get_support(problematic, 500000000)]) + + +@unittest.skip("filtering/blocking is applied during ES sync, this needs to be ported to integration test") +class TestContentBlocking(TestSQLDB): + + def test_blocking_and_filtering(self): + # content claims and channels + tx0 = self.get_channel('A Channel', COIN, '@channel1') + regular_channel = tx0[0].outputs[0] + tx1 = self.get_stream('Claim One', COIN, 'claim1') + tx2 = self.get_stream('Claim Two', COIN, 'claim2', regular_channel) + tx3 = self.get_stream('Claim Three', COIN, 'claim3') + self.advance(1, [tx0, tx1, tx2, tx3]) + claim1, claim2, claim3 = tx1[0].outputs[0], tx2[0].outputs[0], tx3[0].outputs[0] + + # block and filter channels + tx0 = self.get_channel('Blocking Channel', COIN, '@block') + tx1 = self.get_channel('Filtering Channel', COIN, '@filter') + blocking_channel = tx0[0].outputs[0] + filtering_channel = tx1[0].outputs[0] + self.sql.blocking_channel_hashes.add(blocking_channel.claim_hash) + self.sql.filtering_channel_hashes.add(filtering_channel.claim_hash) + self.advance(2, [tx0, tx1]) + self.assertEqual({}, dict(self.sql.blocked_streams)) + self.assertEqual({}, dict(self.sql.blocked_channels)) + self.assertEqual({}, dict(self.sql.filtered_streams)) + self.assertEqual({}, dict(self.sql.filtered_channels)) + + # nothing blocked + results, _ = reader.resolve([ + claim1.claim_name, claim2.claim_name, + claim3.claim_name, regular_channel.claim_name + ]) + self.assertEqual(claim1.claim_hash, results[0]['claim_hash']) + self.assertEqual(claim2.claim_hash, results[1]['claim_hash']) + self.assertEqual(claim3.claim_hash, results[2]['claim_hash']) + self.assertEqual(regular_channel.claim_hash, results[3]['claim_hash']) + + # nothing filtered + results, censor = censored_search() + self.assertEqual(6, len(results)) + self.assertEqual(0, censor.total) + self.assertEqual({}, censor.censored) + + # block claim reposted to blocking channel, also gets filtered + repost_tx1 = self.get_repost(claim1.claim_id, COIN, blocking_channel) + repost1 = repost_tx1[0].outputs[0] + self.advance(3, [repost_tx1]) + self.assertEqual( + {repost1.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.blocked_streams) + ) + self.assertEqual({}, dict(self.sql.blocked_channels)) + self.assertEqual( + {repost1.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.filtered_streams) + ) + self.assertEqual({}, dict(self.sql.filtered_channels)) + + # claim is blocked from results by direct repost + results, censor = censored_search(text='Claim') + self.assertEqual(2, len(results)) + self.assertEqual(claim2.claim_hash, results[0]['claim_hash']) + self.assertEqual(claim3.claim_hash, results[1]['claim_hash']) + self.assertEqual(1, censor.total) + self.assertEqual({blocking_channel.claim_hash: 1}, censor.censored) + results, _ = reader.resolve([claim1.claim_name]) + self.assertEqual( + f"Resolve of 'claim1' was censored by channel with claim id '{blocking_channel.claim_id}'.", + results[0].args[0] + ) + results, _ = reader.resolve([ + claim2.claim_name, regular_channel.claim_name # claim2 and channel still resolved + ]) + self.assertEqual(claim2.claim_hash, results[0]['claim_hash']) + self.assertEqual(regular_channel.claim_hash, results[1]['claim_hash']) + + # block claim indirectly by blocking its parent channel + repost_tx2 = self.get_repost(regular_channel.claim_id, COIN, blocking_channel) + repost2 = repost_tx2[0].outputs[0] + self.advance(4, [repost_tx2]) + self.assertEqual( + {repost1.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.blocked_streams) + ) + self.assertEqual( + {repost2.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.blocked_channels) + ) + self.assertEqual( + {repost1.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.filtered_streams) + ) + self.assertEqual( + {repost2.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.filtered_channels) + ) + + # claim in blocked channel is filtered from search and can't resolve + results, censor = censored_search(text='Claim') + self.assertEqual(1, len(results)) + self.assertEqual(claim3.claim_hash, results[0]['claim_hash']) + self.assertEqual(2, censor.total) + self.assertEqual({blocking_channel.claim_hash: 2}, censor.censored) + results, _ = reader.resolve([ + claim2.claim_name, regular_channel.claim_name # claim2 and channel don't resolve + ]) + self.assertEqual( + f"Resolve of 'claim2' was censored by channel with claim id '{blocking_channel.claim_id}'.", + results[0].args[0] + ) + self.assertEqual( + f"Resolve of '@channel1' was censored by channel with claim id '{blocking_channel.claim_id}'.", + results[1].args[0] + ) + results, _ = reader.resolve([claim3.claim_name]) # claim3 still resolved + self.assertEqual(claim3.claim_hash, results[0]['claim_hash']) + + # filtered claim is only filtered and not blocked + repost_tx3 = self.get_repost(claim3.claim_id, COIN, filtering_channel) + repost3 = repost_tx3[0].outputs[0] + self.advance(5, [repost_tx3]) + self.assertEqual( + {repost1.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.blocked_streams) + ) + self.assertEqual( + {repost2.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.blocked_channels) + ) + self.assertEqual( + {repost1.claim.repost.reference.claim_hash: blocking_channel.claim_hash, + repost3.claim.repost.reference.claim_hash: filtering_channel.claim_hash}, + dict(self.sql.filtered_streams) + ) + self.assertEqual( + {repost2.claim.repost.reference.claim_hash: blocking_channel.claim_hash}, + dict(self.sql.filtered_channels) + ) + + # filtered claim doesn't return in search but is resolveable + results, censor = censored_search(text='Claim') + self.assertEqual(0, len(results)) + self.assertEqual(3, censor.total) + self.assertEqual({blocking_channel.claim_hash: 2, filtering_channel.claim_hash: 1}, censor.censored) + results, _ = reader.resolve([claim3.claim_name]) # claim3 still resolved + self.assertEqual(claim3.claim_hash, results[0]['claim_hash']) + + # abandon unblocks content + self.advance(6, [ + self.get_abandon(repost_tx1), + self.get_abandon(repost_tx2), + self.get_abandon(repost_tx3) + ]) + self.assertEqual({}, dict(self.sql.blocked_streams)) + self.assertEqual({}, dict(self.sql.blocked_channels)) + self.assertEqual({}, dict(self.sql.filtered_streams)) + self.assertEqual({}, dict(self.sql.filtered_channels)) + results, censor = censored_search(text='Claim') + self.assertEqual(3, len(results)) + self.assertEqual(0, censor.total) + results, censor = censored_search() + self.assertEqual(6, len(results)) + self.assertEqual(0, censor.total) + results, _ = reader.resolve([ + claim1.claim_name, claim2.claim_name, + claim3.claim_name, regular_channel.claim_name + ]) + self.assertEqual(claim1.claim_hash, results[0]['claim_hash']) + self.assertEqual(claim2.claim_hash, results[1]['claim_hash']) + self.assertEqual(claim3.claim_hash, results[2]['claim_hash']) + self.assertEqual(regular_channel.claim_hash, results[3]['claim_hash']) + + def test_pagination(self): + one, two, three, four, five, six, seven, filter_channel = self.advance(1, [ + self.get_stream('One', COIN), + self.get_stream('Two', COIN), + self.get_stream('Three', COIN), + self.get_stream('Four', COIN), + self.get_stream('Five', COIN), + self.get_stream('Six', COIN), + self.get_stream('Seven', COIN), + self.get_channel('Filtering Channel', COIN, '@filter'), + ]) + self.sql.filtering_channel_hashes.add(filter_channel.claim_hash) + + # nothing filtered + results, censor = censored_search(order_by='^height', offset=1, limit=3) + self.assertEqual(3, len(results)) + self.assertEqual( + [two.claim_hash, three.claim_hash, four.claim_hash], + [r['claim_hash'] for r in results] + ) + self.assertEqual(0, censor.total) + + # content filtered + repost1, repost2 = self.advance(2, [ + self.get_repost(one.claim_id, COIN, filter_channel), + self.get_repost(two.claim_id, COIN, filter_channel), + ]) + results, censor = censored_search(order_by='^height', offset=1, limit=3) + self.assertEqual(3, len(results)) + self.assertEqual( + [four.claim_hash, five.claim_hash, six.claim_hash], + [r['claim_hash'] for r in results] + ) + self.assertEqual(2, censor.total) + self.assertEqual({filter_channel.claim_hash: 2}, censor.censored)