use a dict on set_reference

This commit is contained in:
Victor Shyba 2021-03-02 19:58:18 -03:00
parent eb6924277f
commit 5a9338a27f

View file

@ -13,14 +13,11 @@ NOT_FOUND = ErrorMessage.Code.Name(ErrorMessage.NOT_FOUND)
BLOCKED = ErrorMessage.Code.Name(ErrorMessage.BLOCKED) BLOCKED = ErrorMessage.Code.Name(ErrorMessage.BLOCKED)
def set_reference(reference, claim_hash, rows): def set_reference(reference, txo_row):
if claim_hash: if txo_row:
for txo in rows: reference.tx_hash = txo_row['txo_hash'][:32]
if claim_hash == txo['claim_hash']: reference.nout = struct.unpack('<I', txo_row['txo_hash'][32:])[0]
reference.tx_hash = txo['txo_hash'][:32] reference.height = txo_row['height']
reference.nout = struct.unpack('<I', txo['txo_hash'][32:])[0]
reference.height = txo['height']
return
class Censor: class Censor:
@ -45,11 +42,11 @@ class Censor:
self.censored[censoring_channel_hash].add(row['tx_hash']) self.censored[censoring_channel_hash].add(row['tx_hash'])
return was_censored return was_censored
def to_message(self, outputs: OutputsMessage, extra_txo_rows): def to_message(self, outputs: OutputsMessage, extra_txo_rows: dict):
for censoring_channel_hash, count in self.censored.items(): for censoring_channel_hash, count in self.censored.items():
blocked = outputs.blocked.add() blocked = outputs.blocked.add()
blocked.count = len(count) blocked.count = len(count)
set_reference(blocked.channel, censoring_channel_hash, extra_txo_rows) set_reference(blocked.channel, extra_txo_rows.get(censoring_channel_hash))
outputs.blocked_total += len(count) outputs.blocked_total += len(count)
@ -155,6 +152,7 @@ class Outputs:
@classmethod @classmethod
def to_bytes(cls, txo_rows, extra_txo_rows, offset=0, total=None, blocked: Censor = None) -> bytes: def to_bytes(cls, txo_rows, extra_txo_rows, offset=0, total=None, blocked: Censor = None) -> bytes:
extra_txo_rows = {row['claim_hash']: row for row in extra_txo_rows}
page = OutputsMessage() page = OutputsMessage()
page.offset = offset page.offset = offset
if total is not None: if total is not None:
@ -163,12 +161,12 @@ class Outputs:
blocked.to_message(page, extra_txo_rows) blocked.to_message(page, extra_txo_rows)
for row in txo_rows: for row in txo_rows:
cls.row_to_message(row, page.txos.add(), extra_txo_rows) cls.row_to_message(row, page.txos.add(), extra_txo_rows)
for row in extra_txo_rows: for row in extra_txo_rows.values():
cls.row_to_message(row, page.extra_txos.add(), extra_txo_rows) cls.row_to_message(row, page.extra_txos.add(), extra_txo_rows)
return page.SerializeToString() return page.SerializeToString()
@classmethod @classmethod
def row_to_message(cls, txo, txo_message, extra_txo_rows): def row_to_message(cls, txo, txo_message, extra_row_dict: dict):
if isinstance(txo, Exception): if isinstance(txo, Exception):
txo_message.error.text = txo.args[0] txo_message.error.text = txo.args[0]
if isinstance(txo, ValueError): if isinstance(txo, ValueError):
@ -177,7 +175,7 @@ class Outputs:
txo_message.error.code = ErrorMessage.NOT_FOUND txo_message.error.code = ErrorMessage.NOT_FOUND
elif isinstance(txo, ResolveCensoredError): elif isinstance(txo, ResolveCensoredError):
txo_message.error.code = ErrorMessage.BLOCKED txo_message.error.code = ErrorMessage.BLOCKED
set_reference(txo_message.error.blocked.channel, txo.censor_hash, extra_txo_rows) set_reference(txo_message.error.blocked.channel, extra_row_dict.get(txo.censor_hash))
return return
txo_message.tx_hash = txo['txo_hash'][:32] txo_message.tx_hash = txo['txo_hash'][:32]
txo_message.nout, = struct.unpack('<I', txo['txo_hash'][32:]) txo_message.nout, = struct.unpack('<I', txo['txo_hash'][32:])
@ -200,5 +198,5 @@ class Outputs:
txo_message.claim.trending_mixed = txo['trending_mixed'] txo_message.claim.trending_mixed = txo['trending_mixed']
txo_message.claim.trending_local = txo['trending_local'] txo_message.claim.trending_local = txo['trending_local']
txo_message.claim.trending_global = txo['trending_global'] txo_message.claim.trending_global = txo['trending_global']
set_reference(txo_message.claim.channel, txo['channel_hash'], extra_txo_rows) set_reference(txo_message.claim.channel, extra_row_dict.get(txo['channel_hash']))
set_reference(txo_message.claim.repost, txo['reposted_claim_hash'], extra_txo_rows) set_reference(txo_message.claim.repost, extra_row_dict.get(txo['reposted_claim_hash']))