all test_claim_commands tests green

This commit is contained in:
Victor Shyba 2021-01-20 01:20:50 -03:00
parent 9924b7b438
commit 90106f5f08
3 changed files with 63 additions and 49 deletions

View file

@ -25,45 +25,32 @@ def set_reference(reference, claim_hash, rows):
class Censor:
__slots__ = 'streams', 'channels', 'limit_claims_per_channel', 'censored', 'claims_in_channel', 'total'
SEARCH = 1
RESOLVE = 2
def __init__(self, streams: dict = None, channels: dict = None, limit_claims_per_channel: int = None):
self.streams = streams or {}
self.channels = channels or {}
self.limit_claims_per_channel = limit_claims_per_channel # doesn't count as censored
__slots__ = 'censor_type', 'censored'
def __init__(self, censor_type):
self.censor_type = censor_type
self.censored = {}
self.claims_in_channel = {}
self.total = 0
def apply(self, rows):
return [row for row in rows if not self.censor(row)]
def censor(self, row) -> bool:
was_censored = False
for claim_hash, lookup in (
(row['claim_hash'], self.streams),
(row['claim_hash'], self.channels),
(row['channel_hash'], self.channels),
(row['reposted_claim_hash'], self.streams),
(row['reposted_claim_hash'], self.channels)):
censoring_channel_hash = lookup.get(claim_hash)
if censoring_channel_hash:
was_censored = True
self.censored.setdefault(censoring_channel_hash, 0)
self.censored[censoring_channel_hash] += 1
break
was_censored = (row['censor_type'] or 0) >= self.censor_type
if was_censored:
self.total += 1
if not was_censored and self.limit_claims_per_channel is not None and row['channel_hash']:
self.claims_in_channel.setdefault(row['channel_hash'], 0)
self.claims_in_channel[row['channel_hash']] += 1
if self.claims_in_channel[row['channel_hash']] > self.limit_claims_per_channel:
return True
censoring_channel_hash = row['censoring_channel_hash']
self.censored.setdefault(censoring_channel_hash, set())
self.censored[censoring_channel_hash].add(row['tx_hash'])
return was_censored
def to_message(self, outputs: OutputsMessage, extra_txo_rows):
outputs.blocked_total = self.total
for censoring_channel_hash, count in self.censored.items():
blocked = outputs.blocked.add()
blocked.count = count
blocked.count = len(count)
set_reference(blocked.channel, censoring_channel_hash, extra_txo_rows)
outputs.blocked_total += len(count)
class Outputs:

View file

@ -9,7 +9,8 @@ from elasticsearch import AsyncElasticsearch, NotFoundError
from elasticsearch.helpers import async_bulk
from lbry.crypto.base58 import Base58
from lbry.schema.result import Outputs
from lbry.error import ResolveCensoredError
from lbry.schema.result import Outputs, Censor
from lbry.schema.tags import clean_tags
from lbry.schema.url import URL
from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES
@ -37,8 +38,9 @@ class SearchIndex:
raise
def stop(self):
asyncio.ensure_future(self.client.close())
client = self.client
self.client = None
return asyncio.ensure_future(client.close())
def delete_index(self):
return self.client.indices.delete(self.index)
@ -78,14 +80,22 @@ class SearchIndex:
async def session_query(self, query_name, function, kwargs):
offset, total = kwargs.get('offset', 0) if isinstance(kwargs, dict) else 0, 0
if query_name == 'resolve':
response = await self.resolve(*kwargs)
response, censored, censor = await self.resolve(*kwargs)
else:
censor = Censor(Censor.SEARCH)
response, offset, total = await self.search(**kwargs)
return Outputs.to_base64(response, await self._get_referenced_rows(response), offset, total)
censored = censor.apply(response)
return Outputs.to_base64(censored, await self._get_referenced_rows(response), offset, total, censor)
async def resolve(self, *urls):
censor = Censor(Censor.RESOLVE)
results = await asyncio.gather(*(self.resolve_url(url) for url in urls))
return results
censored = [
result if not isinstance(result, dict) or not censor.censor(result)
else ResolveCensoredError(url, result['censoring_channel_hash'])
for url, result in zip(urls, results)
]
return results, censored, censor
async def search(self, **kwargs):
if 'channel' in kwargs:
@ -94,7 +104,7 @@ class SearchIndex:
return [], 0, 0
kwargs['channel_id'] = result['_id']
try:
result = await self.client.search(expand_query(**kwargs), self.index)
result = await self.client.search(expand_query(**kwargs), index=self.index)
except NotFoundError:
# index has no docs, fixme: log something
return [], 0, 0
@ -144,6 +154,7 @@ class SearchIndex:
txo_rows = [row for row in txo_rows if isinstance(row, dict)]
repost_hashes = set(filter(None, map(itemgetter('reposted_claim_hash'), txo_rows)))
channel_hashes = set(filter(None, (row['channel_hash'] for row in txo_rows)))
channel_hashes |= set(filter(None, (row['censoring_channel_hash'] for row in txo_rows)))
reposted_txos = []
if repost_hashes:
@ -166,6 +177,8 @@ def extract_doc(doc, index):
doc['reposted_claim_id'] = None
channel_hash = doc.pop('channel_hash')
doc['channel_id'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash
channel_hash = doc.pop('censoring_channel_hash')
doc['censoring_channel_hash'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash
txo_hash = doc.pop('txo_hash')
doc['tx_id'] = hexlify(txo_hash[:32][::-1]).decode()
doc['tx_nout'] = struct.unpack('<I', txo_hash[32:])[0]
@ -322,6 +335,9 @@ def expand_result(results):
result['reposted_claim_hash'] = None
result['channel_hash'] = unhexlify(result['channel_id'])[::-1] if result['channel_id'] else None
result['txo_hash'] = unhexlify(result['tx_id'])[::-1] + struct.pack('<I', result['tx_nout'])
result['tx_hash'] = unhexlify(result['tx_id'])[::-1]
if result['censoring_channel_hash']:
result['censoring_channel_hash'] = unhexlify(result['censoring_channel_hash'])[::-1]
if inner_hits:
return expand_result(inner_hits)
return results

View file

@ -809,28 +809,39 @@ class SQLDB:
def enqueue_changes(self, changed_claim_hashes, deleted_claims):
if not changed_claim_hashes and not deleted_claims:
return
tags = {}
langs = {}
for claim_hash, tag in self.execute(
f"select claim_hash, tag from tag "
f"WHERE claim_hash IN ({','.join('?' for _ in changed_claim_hashes)})", changed_claim_hashes):
tags.setdefault(claim_hash, [])
tags[claim_hash].append(tag)
for claim_hash, lang in self.execute(
f"select claim_hash, language from language "
f"WHERE claim_hash IN ({','.join('?' for _ in changed_claim_hashes)})", changed_claim_hashes):
langs.setdefault(claim_hash, [])
langs[claim_hash].append(lang)
blocklist = set(self.blocked_streams.keys()) | set(self.filtered_streams.keys())
blocked_channels = set(self.blocked_channels.keys()) | set(self.filtered_channels.keys())
changed_claim_hashes |= blocklist | blocked_channels
for claim in self.execute(f"""
SELECT claimtrie.claim_hash as is_controlling,
claimtrie.last_take_over_height,
(select group_concat(tag, ' ') from tag where tag.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as tags,
(select group_concat(language, ' ') from language where language.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as languages,
claim.*
FROM claim LEFT JOIN claimtrie USING (claim_hash)
WHERE claim_hash IN ({','.join('?' for _ in changed_claim_hashes)})
""", changed_claim_hashes):
OR channel_hash IN ({','.join('?' for _ in blocked_channels)})
""", list(changed_claim_hashes) + list(blocked_channels)):
claim = dict(claim._asdict())
claim['tags'] = tags.get(claim['claim_hash']) or tags.get(claim['reposted_claim_hash'])
claim['languages'] = langs.get(claim['claim_hash'], [])
id_set = set(filter(None, (claim['claim_hash'], claim['channel_hash'], claim['reposted_claim_hash'])))
claim['censor_type'] = 0
claim['censoring_channel_hash'] = None
for reason_id in id_set.intersection(blocklist | blocked_channels):
if reason_id in self.blocked_streams:
claim['censor_type'] = 2
claim['censoring_channel_hash'] = self.blocked_streams.get(reason_id)
elif reason_id in self.blocked_channels:
claim['censor_type'] = 2
claim['censoring_channel_hash'] = self.blocked_channels.get(reason_id)
elif reason_id in self.filtered_streams:
claim['censor_type'] = 1
claim['censoring_channel_hash'] = self.filtered_streams.get(reason_id)
elif reason_id in self.filtered_channels:
claim['censor_type'] = 1
claim['censoring_channel_hash'] = self.filtered_channels.get(reason_id)
claim['tags'] = claim['tags'].split(' ') if claim['tags'] else []
claim['languages'] = claim['languages'].split(' ') if claim['languages'] else []
if not self.claim_queue.full():
self.claim_queue.put_nowait(('update', claim))
for claim_hash in deleted_claims: