refactor handling blocked claims in serialization

This commit is contained in:
Jack Robison 2022-03-16 15:30:44 -04:00
parent 9d6d9ff68f
commit 50b3acb4e6
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
3 changed files with 76 additions and 103 deletions

View file

@ -327,7 +327,7 @@ class HubDB:
if blocker_hash: if blocker_hash:
reason_row = self._fs_get_claim_by_hash(blocker_hash) reason_row = self._fs_get_claim_by_hash(blocker_hash)
return ExpandedResolveResult( return ExpandedResolveResult(
None, ResolveCensoredError(url, blocker_hash, censor_row=reason_row), None, None None, ResolveCensoredError(url, blocker_hash.hex(), censor_row=reason_row), None, None
) )
if claim.reposted_claim_hash: if claim.reposted_claim_hash:
repost = self._fs_get_claim_by_hash(claim.reposted_claim_hash) repost = self._fs_get_claim_by_hash(claim.reposted_claim_hash)

View file

@ -10,7 +10,7 @@ from typing import Optional, List, Iterable
from elasticsearch import AsyncElasticsearch, NotFoundError, ConnectionError from elasticsearch import AsyncElasticsearch, NotFoundError, ConnectionError
from elasticsearch.helpers import async_streaming_bulk from elasticsearch.helpers import async_streaming_bulk
from scribe.schema.result import Outputs, Censor from scribe.schema.result import Censor, Outputs
from scribe.schema.tags import clean_tags from scribe.schema.tags import clean_tags
from scribe.schema.url import normalize_name from scribe.schema.url import normalize_name
from scribe.error import TooManyClaimSearchParametersError from scribe.error import TooManyClaimSearchParametersError
@ -285,18 +285,21 @@ class SearchIndex:
async with cache_item.lock: async with cache_item.lock:
if cache_item.result: if cache_item.result:
return cache_item.result return cache_item.result
censor = Censor(Censor.SEARCH)
response, offset, total = await self.search(**kwargs) response, offset, total = await self.search(**kwargs)
censor.apply(response) censored = {}
for row in response:
if (row.get('censor_type') or 0) >= Censor.SEARCH:
censoring_channel_hash = bytes.fromhex(row['censoring_channel_id'])[::-1]
censored.setdefault(censoring_channel_hash, set())
censored[censoring_channel_hash].add(row['tx_hash'])
total_referenced.extend(response) total_referenced.extend(response)
if censored:
if censor.censored:
response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED) response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
total_referenced.extend(response) total_referenced.extend(response)
response = [self._make_resolve_result(r) for r in response] response = [self._make_resolve_result(r) for r in response]
extra = [self._make_resolve_result(r) for r in await self._get_referenced_rows(total_referenced)] extra = [self._make_resolve_result(r) for r in await self._get_referenced_rows(total_referenced)]
result = Outputs.to_base64( result = Outputs.to_base64(
response, extra, offset, total, censor response, extra, offset, total, censored
) )
cache_item.result = result cache_item.result = result
return result return result
@ -314,7 +317,6 @@ class SearchIndex:
for result in expand_result(filter(lambda doc: doc['found'], results["docs"])): for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
self.claim_cache.set(result['claim_id'], result) self.claim_cache.set(result['claim_id'], result)
async def search(self, **kwargs): async def search(self, **kwargs):
try: try:
return await self.search_ahead(**kwargs) return await self.search_ahead(**kwargs)

View file

@ -1,5 +1,5 @@
import base64 import base64
from typing import List, TYPE_CHECKING, Union, Optional, NamedTuple from typing import List, TYPE_CHECKING, Union, Optional, Dict, Set, Tuple
from itertools import chain from itertools import chain
from scribe.error import ResolveCensoredError from scribe.error import ResolveCensoredError
@ -12,53 +12,46 @@ NOT_FOUND = ErrorMessage.Code.Name(ErrorMessage.NOT_FOUND)
BLOCKED = ErrorMessage.Code.Name(ErrorMessage.BLOCKED) BLOCKED = ErrorMessage.Code.Name(ErrorMessage.BLOCKED)
def set_reference(reference, claim_hash, rows):
if claim_hash:
for txo in rows:
if claim_hash == txo.claim_hash:
reference.tx_hash = txo.tx_hash
reference.nout = txo.position
reference.height = txo.height
return
class Censor: class Censor:
NOT_CENSORED = 0 NOT_CENSORED = 0
SEARCH = 1 SEARCH = 1
RESOLVE = 2 RESOLVE = 2
__slots__ = 'censor_type', 'censored'
def __init__(self, censor_type): def encode_txo(txo_message: OutputsMessage, resolve_result: Union['ResolveResult', Exception]):
self.censor_type = censor_type if isinstance(resolve_result, Exception):
self.censored = {} txo_message.error.text = resolve_result.args[0]
if isinstance(resolve_result, ValueError):
txo_message.error.code = ErrorMessage.INVALID
elif isinstance(resolve_result, LookupError):
txo_message.error.code = ErrorMessage.NOT_FOUND
return
txo_message.tx_hash = resolve_result.tx_hash
txo_message.nout = resolve_result.position
txo_message.height = resolve_result.height
txo_message.claim.short_url = resolve_result.short_url
txo_message.claim.reposted = resolve_result.reposted
txo_message.claim.is_controlling = resolve_result.is_controlling
txo_message.claim.creation_height = resolve_result.creation_height
txo_message.claim.activation_height = resolve_result.activation_height
txo_message.claim.expiration_height = resolve_result.expiration_height
txo_message.claim.effective_amount = resolve_result.effective_amount
txo_message.claim.support_amount = resolve_result.support_amount
def is_censored(self, row): if resolve_result.canonical_url is not None:
return (row.get('censor_type') or self.NOT_CENSORED) >= self.censor_type txo_message.claim.canonical_url = resolve_result.canonical_url
if resolve_result.last_takeover_height is not None:
def apply(self, rows): txo_message.claim.take_over_height = resolve_result.last_takeover_height
return [row for row in rows if not self.censor(row)] if resolve_result.claims_in_channel is not None:
txo_message.claim.claims_in_channel = resolve_result.claims_in_channel
def censor(self, row) -> Optional[bytes]: if resolve_result.reposted_claim_hash and resolve_result.reposted_tx_hash is not None:
if self.is_censored(row): txo_message.claim.repost.tx_hash = resolve_result.reposted_tx_hash
censoring_channel_hash = bytes.fromhex(row['censoring_channel_id'])[::-1] txo_message.claim.repost.nout = resolve_result.reposted_tx_position
self.censored.setdefault(censoring_channel_hash, set()) txo_message.claim.repost.height = resolve_result.reposted_height
self.censored[censoring_channel_hash].add(row['tx_hash']) if resolve_result.channel_hash and resolve_result.channel_tx_hash is not None:
return censoring_channel_hash txo_message.claim.channel.tx_hash = resolve_result.channel_tx_hash
return None txo_message.claim.channel.nout = resolve_result.channel_tx_position
txo_message.claim.channel.height = resolve_result.channel_height
def to_message(self, outputs: OutputsMessage, extra_txo_rows: List['ResolveResult']):
for censoring_channel_hash, count in self.censored.items():
outputs.blocked_total += len(count)
blocked = outputs.blocked.add()
blocked.count = len(count)
for resolve_result in extra_txo_rows:
if resolve_result.claim_hash == censoring_channel_hash:
blocked.channel.tx_hash = resolve_result.tx_hash
blocked.channel.nout = resolve_result.position
blocked.channel.height = resolve_result.height
return
class Outputs: class Outputs:
@ -170,67 +163,45 @@ class Outputs:
outputs.blocked, outputs.blocked_total outputs.blocked, outputs.blocked_total
) )
@classmethod @staticmethod
def to_base64(cls, txo_rows, extra_txo_rows, offset=0, total=None, blocked: Censor = None) -> str: def to_base64(txo_rows: List[Union[Exception, 'ResolveResult']], extra_txo_rows: List['ResolveResult'],
return base64.b64encode(cls.to_bytes(txo_rows, extra_txo_rows, offset, total, blocked)).decode() offset: int = 0, total: Optional[int] = None,
censored: Optional[Dict[bytes, Set[bytes]]] = None) -> str:
return base64.b64encode(Outputs.to_bytes(txo_rows, extra_txo_rows, offset, total, censored)).decode()
@classmethod @staticmethod
def to_bytes(cls, txo_rows, extra_txo_rows, offset=0, total=None, blocked: Censor = None) -> bytes: def to_bytes(txo_rows: List[Union[Exception, 'ResolveResult']], extra_txo_rows: List['ResolveResult'],
offset: int = 0, total: Optional[int] = None,
censored: Optional[Dict[bytes, Set[bytes]]] = None) -> bytes:
page = OutputsMessage() page = OutputsMessage()
page.offset = offset page.offset = offset
if total is not None: if total is not None:
page.total = total page.total = total
if blocked is not None: censored = censored or {}
blocked.to_message(page, extra_txo_rows) censored_txos: Dict[bytes, List[Tuple[str, 'ResolveResult']]] = {}
for row in extra_txo_rows:
cls.encode_txo(page.extra_txos.add(), row)
for row in txo_rows: for row in txo_rows:
txo_message = page.txos.add() txo_message = page.txos.add()
if isinstance(row, ResolveCensoredError): if isinstance(row, ResolveCensoredError):
for resolve_result in extra_txo_rows: censored_hash = bytes.fromhex(row.censor_id)
if resolve_result.claim_hash == row.censor_id: if censored_hash not in censored_txos:
txo_message.error.code = ErrorMessage.BLOCKED censored_txos[censored_hash] = []
txo_message.error.text = str(row) censored_txos[censored_hash].append((str(row), txo_message))
txo_message.error.blocked.channel.tx_hash = resolve_result.tx_hash
txo_message.error.blocked.channel.nout = resolve_result.position
txo_message.error.blocked.channel.height = resolve_result.height
break
else: else:
cls.encode_txo(txo_message, row) encode_txo(txo_message, row)
for row in extra_txo_rows:
if row.claim_hash in censored:
page.blocked_total += len(censored[row.claim_hash])
blocked = page.blocked.add()
blocked.count = len(censored[row.claim_hash])
blocked.channel.tx_hash = row.tx_hash
blocked.channel.nout = row.position
blocked.channel.height = row.height
if row.claim_hash in censored_txos:
for (text, txo_message) in censored_txos[row.claim_hash]:
txo_message.error.code = ErrorMessage.BLOCKED
txo_message.error.text = text
txo_message.error.blocked.channel.tx_hash = row.tx_hash
txo_message.error.blocked.channel.nout = row.position
txo_message.error.blocked.channel.height = row.height
encode_txo(page.extra_txos.add(), row)
return page.SerializeToString() return page.SerializeToString()
@classmethod
def encode_txo(cls, txo_message: OutputsMessage, resolve_result: Union['ResolveResult', Exception]):
if isinstance(resolve_result, Exception):
txo_message.error.text = resolve_result.args[0]
if isinstance(resolve_result, ValueError):
txo_message.error.code = ErrorMessage.INVALID
elif isinstance(resolve_result, LookupError):
txo_message.error.code = ErrorMessage.NOT_FOUND
return
txo_message.tx_hash = resolve_result.tx_hash
txo_message.nout = resolve_result.position
txo_message.height = resolve_result.height
txo_message.claim.short_url = resolve_result.short_url
txo_message.claim.reposted = resolve_result.reposted
txo_message.claim.is_controlling = resolve_result.is_controlling
txo_message.claim.creation_height = resolve_result.creation_height
txo_message.claim.activation_height = resolve_result.activation_height
txo_message.claim.expiration_height = resolve_result.expiration_height
txo_message.claim.effective_amount = resolve_result.effective_amount
txo_message.claim.support_amount = resolve_result.support_amount
if resolve_result.canonical_url is not None:
txo_message.claim.canonical_url = resolve_result.canonical_url
if resolve_result.last_takeover_height is not None:
txo_message.claim.take_over_height = resolve_result.last_takeover_height
if resolve_result.claims_in_channel is not None:
txo_message.claim.claims_in_channel = resolve_result.claims_in_channel
if resolve_result.reposted_claim_hash and resolve_result.reposted_tx_hash is not None:
txo_message.claim.repost.tx_hash = resolve_result.reposted_tx_hash
txo_message.claim.repost.nout = resolve_result.reposted_tx_position
txo_message.claim.repost.height = resolve_result.reposted_height
if resolve_result.channel_hash and resolve_result.channel_tx_hash is not None:
txo_message.claim.channel.tx_hash = resolve_result.channel_tx_hash
txo_message.claim.channel.nout = resolve_result.channel_tx_position
txo_message.claim.channel.height = resolve_result.channel_height