diff --git a/lbry/schema/result.py b/lbry/schema/result.py index 9ecca5888..5eb892c4f 100644 --- a/lbry/schema/result.py +++ b/lbry/schema/result.py @@ -25,45 +25,32 @@ def set_reference(reference, claim_hash, rows): class Censor: - __slots__ = 'streams', 'channels', 'limit_claims_per_channel', 'censored', 'claims_in_channel', 'total' + SEARCH = 1 + RESOLVE = 2 - def __init__(self, streams: dict = None, channels: dict = None, limit_claims_per_channel: int = None): - self.streams = streams or {} - self.channels = channels or {} - self.limit_claims_per_channel = limit_claims_per_channel # doesn't count as censored + __slots__ = 'censor_type', 'censored' + + def __init__(self, censor_type): + self.censor_type = censor_type self.censored = {} - self.claims_in_channel = {} - self.total = 0 + + def apply(self, rows): + return [row for row in rows if not self.censor(row)] def censor(self, row) -> bool: - was_censored = False - for claim_hash, lookup in ( - (row['claim_hash'], self.streams), - (row['claim_hash'], self.channels), - (row['channel_hash'], self.channels), - (row['reposted_claim_hash'], self.streams), - (row['reposted_claim_hash'], self.channels)): - censoring_channel_hash = lookup.get(claim_hash) - if censoring_channel_hash: - was_censored = True - self.censored.setdefault(censoring_channel_hash, 0) - self.censored[censoring_channel_hash] += 1 - break + was_censored = (row['censor_type'] or 0) >= self.censor_type if was_censored: - self.total += 1 - if not was_censored and self.limit_claims_per_channel is not None and row['channel_hash']: - self.claims_in_channel.setdefault(row['channel_hash'], 0) - self.claims_in_channel[row['channel_hash']] += 1 - if self.claims_in_channel[row['channel_hash']] > self.limit_claims_per_channel: - return True + censoring_channel_hash = row['censoring_channel_hash'] + self.censored.setdefault(censoring_channel_hash, set()) + self.censored[censoring_channel_hash].add(row['tx_hash']) return was_censored def to_message(self, outputs: OutputsMessage, extra_txo_rows): - outputs.blocked_total = self.total for censoring_channel_hash, count in self.censored.items(): blocked = outputs.blocked.add() - blocked.count = count + blocked.count = len(count) set_reference(blocked.channel, censoring_channel_hash, extra_txo_rows) + outputs.blocked_total += len(count) class Outputs: diff --git a/lbry/wallet/server/db/elastic_search.py b/lbry/wallet/server/db/elastic_search.py index 15e0bcce4..17f5a0fdd 100644 --- a/lbry/wallet/server/db/elastic_search.py +++ b/lbry/wallet/server/db/elastic_search.py @@ -9,7 +9,8 @@ from elasticsearch import AsyncElasticsearch, NotFoundError from elasticsearch.helpers import async_bulk from lbry.crypto.base58 import Base58 -from lbry.schema.result import Outputs +from lbry.error import ResolveCensoredError +from lbry.schema.result import Outputs, Censor from lbry.schema.tags import clean_tags from lbry.schema.url import URL from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES @@ -37,8 +38,9 @@ class SearchIndex: raise def stop(self): - asyncio.ensure_future(self.client.close()) + client = self.client self.client = None + return asyncio.ensure_future(client.close()) def delete_index(self): return self.client.indices.delete(self.index) @@ -78,14 +80,22 @@ class SearchIndex: async def session_query(self, query_name, function, kwargs): offset, total = kwargs.get('offset', 0) if isinstance(kwargs, dict) else 0, 0 if query_name == 'resolve': - response = await self.resolve(*kwargs) + response, censored, censor = await self.resolve(*kwargs) else: + censor = Censor(Censor.SEARCH) response, offset, total = await self.search(**kwargs) - return Outputs.to_base64(response, await self._get_referenced_rows(response), offset, total) + censored = censor.apply(response) + return Outputs.to_base64(censored, await self._get_referenced_rows(response), offset, total, censor) async def resolve(self, *urls): + censor = Censor(Censor.RESOLVE) results = await asyncio.gather(*(self.resolve_url(url) for url in urls)) - return results + censored = [ + result if not isinstance(result, dict) or not censor.censor(result) + else ResolveCensoredError(url, result['censoring_channel_hash']) + for url, result in zip(urls, results) + ] + return results, censored, censor async def search(self, **kwargs): if 'channel' in kwargs: @@ -94,7 +104,7 @@ class SearchIndex: return [], 0, 0 kwargs['channel_id'] = result['_id'] try: - result = await self.client.search(expand_query(**kwargs), self.index) + result = await self.client.search(expand_query(**kwargs), index=self.index) except NotFoundError: # index has no docs, fixme: log something return [], 0, 0 @@ -144,6 +154,7 @@ class SearchIndex: txo_rows = [row for row in txo_rows if isinstance(row, dict)] repost_hashes = set(filter(None, map(itemgetter('reposted_claim_hash'), txo_rows))) channel_hashes = set(filter(None, (row['channel_hash'] for row in txo_rows))) + channel_hashes |= set(filter(None, (row['censoring_channel_hash'] for row in txo_rows))) reposted_txos = [] if repost_hashes: @@ -166,6 +177,8 @@ def extract_doc(doc, index): doc['reposted_claim_id'] = None channel_hash = doc.pop('channel_hash') doc['channel_id'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash + channel_hash = doc.pop('censoring_channel_hash') + doc['censoring_channel_hash'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash txo_hash = doc.pop('txo_hash') doc['tx_id'] = hexlify(txo_hash[:32][::-1]).decode() doc['tx_nout'] = struct.unpack('