add channel information on resolve protobuf and test for its presence

This commit is contained in:
Victor Shyba 2020-09-07 15:51:49 -03:00 committed by Lex Berezhny
parent 8cde120928
commit 4fcfa0b193
3 changed files with 38 additions and 8 deletions

View file

@ -1,6 +1,5 @@
import logging import logging
import itertools import itertools
from operator import itemgetter
from typing import List, Dict from typing import List, Dict
from lbry.schema.url import URL from lbry.schema.url import URL
@ -15,29 +14,37 @@ from .search import search_claims
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def _get_referenced_rows(txo_rows: List[dict], censor_channels: List[bytes]): def _get_referenced_rows(txo_rows: List[Output], censor_channels: List[bytes]):
# censor = context().get_resolve_censor() # censor = context().get_resolve_censor()
repost_hashes = set(filter(None, map(itemgetter('reposted_claim_hash'), txo_rows))) repost_hashes = set(txo.reposted_claim.claim_hash for txo in txo_rows if txo.reposted_claim)
channel_hashes = set(itertools.chain( channel_hashes = set(itertools.chain(
filter(None, map(itemgetter('channel_hash'), txo_rows)), (txo.channel.claim_hash for txo in txo_rows if txo.channel),
censor_channels censor_channels
)) ))
reposted_txos = [] reposted_txos = []
if repost_hashes: if repost_hashes:
reposted_txos = search_claims(**{'claim.claim_hash__in': repost_hashes}) reposted_txos = search_claims(**{'claim.claim_hash__in': repost_hashes})
channel_hashes |= set(filter(None, map(itemgetter('channel_hash'), reposted_txos))) if reposted_txos:
reposted_txos = reposted_txos[0]
channel_hashes |= set(txo.channel.claim_hash for txo in reposted_txos if txo.channel)
channel_txos = [] channel_txos = []
if channel_hashes: if channel_hashes:
channel_txos = search_claims(**{'claim.claim_hash__in': channel_hashes}) channel_txos = search_claims(**{'claim.claim_hash__in': channel_hashes})
channel_txos = channel_txos[0] if channel_txos else []
# channels must come first for client side inflation to work properly # channels must come first for client side inflation to work properly
return channel_txos + reposted_txos return channel_txos + reposted_txos
def protobuf_resolve(urls, **kwargs) -> str: def protobuf_resolve(urls, **kwargs) -> str:
return ResultOutput.to_base64([resolve_url(raw_url) for raw_url in urls], []) txo_rows = [resolve_url(raw_url) for raw_url in urls]
extra_txo_rows = _get_referenced_rows(
txo_rows,
[txo.censor_hash for txo in txo_rows if isinstance(txo, ResolveCensoredError)]
)
return ResultOutput.to_base64(txo_rows, extra_txo_rows)
def resolve(urls, **kwargs) -> Dict[str, Output]: def resolve(urls, **kwargs) -> Dict[str, Output]:
@ -86,7 +93,10 @@ def resolve_url(raw_url):
# matches = search_claims(censor, **q, limit=1) # matches = search_claims(censor, **q, limit=1)
matches = search_claims(**q, limit=1)[0] matches = search_claims(**q, limit=1)[0]
if matches: if matches:
return matches[0] stream = matches[0]
if channel:
stream.channel = channel
return stream
elif censor.censored: elif censor.censored:
return ResolveCensoredError(raw_url, next(iter(censor.censored))) return ResolveCensoredError(raw_url, next(iter(censor.censored)))
else: else:

View file

@ -206,5 +206,10 @@ class Outputs:
#txo_message.claim.trending_mixed = txo['trending_mixed'] #txo_message.claim.trending_mixed = txo['trending_mixed']
#txo_message.claim.trending_local = txo['trending_local'] #txo_message.claim.trending_local = txo['trending_local']
#txo_message.claim.trending_global = txo['trending_global'] #txo_message.claim.trending_global = txo['trending_global']
#set_reference(txo_message.claim.channel, txo['channel_hash'], extra_txo_rows) if txo.channel:
reference = txo_message.claim.channel
hash = txo.channel.hash
reference.tx_hash = hash[:32]
reference.nout = struct.unpack('<I', hash[32:])[0]
reference.height = txo.channel.spent_height
#set_reference(txo_message.claim.repost, txo['reposted_claim_hash'], extra_txo_rows) #set_reference(txo_message.claim.repost, txo['reposted_claim_hash'], extra_txo_rows)

View file

@ -10,6 +10,7 @@ from distutils.dir_util import copy_tree, remove_tree
from lbry import Config, Database, RegTestLedger, Transaction, Output, Input from lbry import Config, Database, RegTestLedger, Transaction, Output, Input
from lbry.crypto.base58 import Base58 from lbry.crypto.base58 import Base58
from lbry.schema.claim import Stream, Channel from lbry.schema.claim import Stream, Channel
from lbry.schema.result import Outputs
from lbry.schema.support import Support from lbry.schema.support import Support
from lbry.error import LbrycrdEventSubscriptionError, LbrycrdUnauthorizedError from lbry.error import LbrycrdEventSubscriptionError, LbrycrdUnauthorizedError
from lbry.blockchain.lbrycrd import Lbrycrd from lbry.blockchain.lbrycrd import Lbrycrd
@ -897,6 +898,20 @@ class TestGeneralBlockchainSync(SyncingBlockchainTestCase):
self.assertEqual(stream_c.claim_id, await self.resolve_to_claim_id("@foo#a/foo#c")) self.assertEqual(stream_c.claim_id, await self.resolve_to_claim_id("@foo#a/foo#c"))
self.assertEqual(stream_cd.claim_id, await self.resolve_to_claim_id("@foo#ab/foo#cd")) self.assertEqual(stream_cd.claim_id, await self.resolve_to_claim_id("@foo#ab/foo#cd"))
async def test_resolve_protobuf_includes_enough_information_for_signature_validation(self):
# important for old sdk
chan_ab = await self.get_claim(
await self.create_claim(claim_id_startswith='ab', is_channel=True))
await self.create_claim(claim_id_startswith='cd', sign=chan_ab)
await self.generate(1)
resolutions = await self.db.protobuf_resolve(["@foo#ab/foo#cd"])
resolutions = Outputs.from_base64(resolutions)
txs = await self.db.get_transactions(tx_hash__in=[tx[0] for tx in resolutions.txs])
self.assertEqual(len(txs), 2)
resolutions = resolutions.inflate(txs)
claim = resolutions[0][0]
self.assertTrue(claim.is_signed_by(claim.channel, self.chain.ledger))
class TestClaimtrieSync(SyncingBlockchainTestCase): class TestClaimtrieSync(SyncingBlockchainTestCase):