all test_claim_commands tests green
This commit is contained in:
parent
9924b7b438
commit
90106f5f08
3 changed files with 63 additions and 49 deletions
@@ -25,45 +25,32 @@ def set_reference(reference, claim_hash, rows):
 class Censor:
 
-    __slots__ = 'streams', 'channels', 'limit_claims_per_channel', 'censored', 'claims_in_channel', 'total'
+    SEARCH = 1
+    RESOLVE = 2
 
-    def __init__(self, streams: dict = None, channels: dict = None, limit_claims_per_channel: int = None):
-        self.streams = streams or {}
-        self.channels = channels or {}
-        self.limit_claims_per_channel = limit_claims_per_channel  # doesn't count as censored
+    __slots__ = 'censor_type', 'censored'
+
+    def __init__(self, censor_type):
+        self.censor_type = censor_type
         self.censored = {}
-        self.claims_in_channel = {}
-        self.total = 0
 
     def apply(self, rows):
         return [row for row in rows if not self.censor(row)]
 
     def censor(self, row) -> bool:
-        was_censored = False
-        for claim_hash, lookup in (
-                (row['claim_hash'], self.streams),
-                (row['claim_hash'], self.channels),
-                (row['channel_hash'], self.channels),
-                (row['reposted_claim_hash'], self.streams),
-                (row['reposted_claim_hash'], self.channels)):
-            censoring_channel_hash = lookup.get(claim_hash)
-            if censoring_channel_hash:
-                was_censored = True
-                self.censored.setdefault(censoring_channel_hash, 0)
-                self.censored[censoring_channel_hash] += 1
-                break
+        was_censored = (row['censor_type'] or 0) >= self.censor_type
         if was_censored:
-            self.total += 1
-        if not was_censored and self.limit_claims_per_channel is not None and row['channel_hash']:
-            self.claims_in_channel.setdefault(row['channel_hash'], 0)
-            self.claims_in_channel[row['channel_hash']] += 1
-            if self.claims_in_channel[row['channel_hash']] > self.limit_claims_per_channel:
-                return True
+            censoring_channel_hash = row['censoring_channel_hash']
+            self.censored.setdefault(censoring_channel_hash, set())
+            self.censored[censoring_channel_hash].add(row['tx_hash'])
         return was_censored
 
     def to_message(self, outputs: OutputsMessage, extra_txo_rows):
-        outputs.blocked_total = self.total
         for censoring_channel_hash, count in self.censored.items():
             blocked = outputs.blocked.add()
-            blocked.count = count
+            blocked.count = len(count)
             set_reference(blocked.channel, censoring_channel_hash, extra_txo_rows)
+            outputs.blocked_total += len(count)
 
 
 class Outputs:
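The rewritten Censor drops its own stream/channel block maps and the per-channel claim limit; rows now arrive pre-tagged with a censor_type computed in the SQL layer (see the last hunk below), and the two class constants pick how aggressively a Censor filters. A minimal sketch of the new semantics, with hypothetical row values:

    from lbry.schema.result import Censor

    censor = Censor(Censor.SEARCH)  # level 1: hides filtered (1) and blocked (2) rows
    rows = [
        {'censor_type': 0, 'censoring_channel_hash': None, 'tx_hash': b'\x01'},    # visible
        {'censor_type': 2, 'censoring_channel_hash': b'\xaa', 'tx_hash': b'\x02'},  # hidden
    ]
    visible = censor.apply(rows)  # keeps only the censor_type == 0 row
    # censor.censored is now {b'\xaa': {b'\x02'}}: tx hashes grouped by censoring
    # channel, which is why to_message() reports blocked.count = len(count).
    # A Censor(Censor.RESOLVE) (level 2) would hide only fully blocked rows.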
@@ -9,7 +9,8 @@ from elasticsearch import AsyncElasticsearch, NotFoundError
 from elasticsearch.helpers import async_bulk
 
 from lbry.crypto.base58 import Base58
-from lbry.schema.result import Outputs
+from lbry.error import ResolveCensoredError
+from lbry.schema.result import Outputs, Censor
 from lbry.schema.tags import clean_tags
 from lbry.schema.url import URL
 from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES
@@ -37,8 +38,9 @@ class SearchIndex:
             raise
 
     def stop(self):
-        asyncio.ensure_future(self.client.close())
+        client = self.client
+        self.client = None
+        return asyncio.ensure_future(client.close())
 
     def delete_index(self):
         return self.client.indices.delete(self.index)
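stop() now detaches the client before scheduling the close and returns the resulting future, so shutdown can be awaited and later calls fail fast on a None client instead of racing a half-closed one. A hedged usage sketch, assuming a started SearchIndex named index:

    future = index.stop()  # self.client is already None at this point
    await future           # asyncio.ensure_future(...) returns an awaitable Task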
@@ -78,14 +80,22 @@ class SearchIndex:
     async def session_query(self, query_name, function, kwargs):
         offset, total = kwargs.get('offset', 0) if isinstance(kwargs, dict) else 0, 0
         if query_name == 'resolve':
-            response = await self.resolve(*kwargs)
+            response, censored, censor = await self.resolve(*kwargs)
         else:
+            censor = Censor(Censor.SEARCH)
             response, offset, total = await self.search(**kwargs)
-        return Outputs.to_base64(response, await self._get_referenced_rows(response), offset, total)
+            censored = censor.apply(response)
+        return Outputs.to_base64(censored, await self._get_referenced_rows(response), offset, total, censor)
 
     async def resolve(self, *urls):
+        censor = Censor(Censor.RESOLVE)
         results = await asyncio.gather(*(self.resolve_url(url) for url in urls))
-        return results
+        censored = [
+            result if not isinstance(result, dict) or not censor.censor(result)
+            else ResolveCensoredError(url, result['censoring_channel_hash'])
+            for url, result in zip(urls, results)
+        ]
+        return results, censored, censor
 
     async def search(self, **kwargs):
         if 'channel' in kwargs:
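resolve() now returns three values: the raw rows (still needed by _get_referenced_rows()), the censored output list in which each blocked row is replaced by a ResolveCensoredError naming the censoring channel, and the Censor that kept the tallies. A sketch of how a caller might consume this, with a hypothetical SearchIndex instance and URL:

    from lbry.error import ResolveCensoredError

    results, censored, censor = await index.resolve('lbry://some-blocked-claim')
    for url_result in censored:
        if isinstance(url_result, ResolveCensoredError):
            ...  # surface the error to the client instead of the claim
    # session_query() then serializes `censored` together with the censor's
    # per-channel counts via Outputs.to_base64(censored, ..., censor).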
@@ -94,7 +104,7 @@ class SearchIndex:
                 return [], 0, 0
             kwargs['channel_id'] = result['_id']
         try:
-            result = await self.client.search(expand_query(**kwargs), self.index)
+            result = await self.client.search(expand_query(**kwargs), index=self.index)
         except NotFoundError:
             # index has no docs, fixme: log something
             return [], 0, 0
@@ -144,6 +154,7 @@ class SearchIndex:
         txo_rows = [row for row in txo_rows if isinstance(row, dict)]
         repost_hashes = set(filter(None, map(itemgetter('reposted_claim_hash'), txo_rows)))
         channel_hashes = set(filter(None, (row['channel_hash'] for row in txo_rows)))
+        channel_hashes |= set(filter(None, (row['censoring_channel_hash'] for row in txo_rows)))
 
         reposted_txos = []
         if repost_hashes:
@@ -166,6 +177,8 @@ def extract_doc(doc, index):
         doc['reposted_claim_id'] = None
     channel_hash = doc.pop('channel_hash')
     doc['channel_id'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash
+    channel_hash = doc.pop('censoring_channel_hash')
+    doc['censoring_channel_hash'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash
     txo_hash = doc.pop('txo_hash')
     doc['tx_id'] = hexlify(txo_hash[:32][::-1]).decode()
     doc['tx_nout'] = struct.unpack('<I', txo_hash[32:])[0]
@@ -322,6 +335,9 @@ def expand_result(results):
             result['reposted_claim_hash'] = None
         result['channel_hash'] = unhexlify(result['channel_id'])[::-1] if result['channel_id'] else None
         result['txo_hash'] = unhexlify(result['tx_id'])[::-1] + struct.pack('<I', result['tx_nout'])
+        result['tx_hash'] = unhexlify(result['tx_id'])[::-1]
+        if result['censoring_channel_hash']:
+            result['censoring_channel_hash'] = unhexlify(result['censoring_channel_hash'])[::-1]
     if inner_hits:
         return expand_result(inner_hits)
     return results
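extract_doc() and expand_result() are mirror images: hashes are little-endian bytes inside the wallet server but are indexed as big-endian hex ids in Elasticsearch, and the commit extends that round trip to censoring_channel_hash and tx_hash. A worked example with a hypothetical 4-byte hash:

    from binascii import hexlify, unhexlify

    channel_hash = bytes.fromhex('deadbeef')[::-1]      # little-endian bytes, as stored
    channel_id = hexlify(channel_hash[::-1]).decode()   # 'deadbeef', as indexed in ES
    assert unhexlify(channel_id)[::-1] == channel_hash  # expand_result() inverts it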
@@ -809,28 +809,39 @@ class SQLDB:
     def enqueue_changes(self, changed_claim_hashes, deleted_claims):
         if not changed_claim_hashes and not deleted_claims:
             return
-        tags = {}
-        langs = {}
-        for claim_hash, tag in self.execute(
-                f"select claim_hash, tag from tag "
-                f"WHERE claim_hash IN ({','.join('?' for _ in changed_claim_hashes)})", changed_claim_hashes):
-            tags.setdefault(claim_hash, [])
-            tags[claim_hash].append(tag)
-        for claim_hash, lang in self.execute(
-                f"select claim_hash, language from language "
-                f"WHERE claim_hash IN ({','.join('?' for _ in changed_claim_hashes)})", changed_claim_hashes):
-            langs.setdefault(claim_hash, [])
-            langs[claim_hash].append(lang)
+        blocklist = set(self.blocked_streams.keys()) | set(self.filtered_streams.keys())
+        blocked_channels = set(self.blocked_channels.keys()) | set(self.filtered_channels.keys())
+        changed_claim_hashes |= blocklist | blocked_channels
         for claim in self.execute(f"""
             SELECT claimtrie.claim_hash as is_controlling,
                    claimtrie.last_take_over_height,
+                   (select group_concat(tag, ' ') from tag where tag.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as tags,
+                   (select group_concat(language, ' ') from language where language.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as languages,
                    claim.*
             FROM claim LEFT JOIN claimtrie USING (claim_hash)
             WHERE claim_hash IN ({','.join('?' for _ in changed_claim_hashes)})
-            """, changed_claim_hashes):
+            OR channel_hash IN ({','.join('?' for _ in blocked_channels)})
+            """, list(changed_claim_hashes) + list(blocked_channels)):
             claim = dict(claim._asdict())
-            claim['tags'] = tags.get(claim['claim_hash']) or tags.get(claim['reposted_claim_hash'])
-            claim['languages'] = langs.get(claim['claim_hash'], [])
+            id_set = set(filter(None, (claim['claim_hash'], claim['channel_hash'], claim['reposted_claim_hash'])))
+            claim['censor_type'] = 0
+            claim['censoring_channel_hash'] = None
+            for reason_id in id_set.intersection(blocklist | blocked_channels):
+                if reason_id in self.blocked_streams:
+                    claim['censor_type'] = 2
+                    claim['censoring_channel_hash'] = self.blocked_streams.get(reason_id)
+                elif reason_id in self.blocked_channels:
+                    claim['censor_type'] = 2
+                    claim['censoring_channel_hash'] = self.blocked_channels.get(reason_id)
+                elif reason_id in self.filtered_streams:
+                    claim['censor_type'] = 1
+                    claim['censoring_channel_hash'] = self.filtered_streams.get(reason_id)
+                elif reason_id in self.filtered_channels:
+                    claim['censor_type'] = 1
+                    claim['censoring_channel_hash'] = self.filtered_channels.get(reason_id)
+
+            claim['tags'] = claim['tags'].split(' ') if claim['tags'] else []
+            claim['languages'] = claim['languages'].split(' ') if claim['languages'] else []
             if not self.claim_queue.full():
                 self.claim_queue.put_nowait(('update', claim))
         for claim_hash in deleted_claims:
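enqueue_changes() replaces the two follow-up tag/language queries with group_concat subselects and splits the space-joined columns back into lists after the query; it also stamps every outgoing claim with censor_type (2 for blocked, 1 for filtered, 0 otherwise) and the responsible censoring_channel_hash, which is exactly what Censor.censor() compares against later. A small illustration of the splitting, with a hypothetical result row:

    row = {'tags': 'funny cats', 'languages': 'en'}  # as returned by the group_concat query
    tags = row['tags'].split(' ') if row['tags'] else []                 # ['funny', 'cats']
    languages = row['languages'].split(' ') if row['languages'] else []  # ['en']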