diff --git a/lbrynet/schema/result.py b/lbrynet/schema/result.py index c2c316d03..7b9e3f680 100644 --- a/lbrynet/schema/result.py +++ b/lbrynet/schema/result.py @@ -2,6 +2,7 @@ import base64 import struct from typing import List from binascii import hexlify +from itertools import chain from google.protobuf.message import DecodeError @@ -10,48 +11,45 @@ from lbrynet.schema.types.v2.result_pb2 import Outputs as OutputsMessage class Outputs: - __slots__ = 'txos', 'txs', 'offset', 'total' + __slots__ = 'txos', 'extra_txos', 'txs', 'offset', 'total' - def __init__(self, txos: List, txs: List, offset: int, total: int): + def __init__(self, txos: List, extra_txos: List, txs: set, offset: int, total: int): self.txos = txos self.txs = txs + self.extra_txos = extra_txos self.offset = offset self.total = total - def _inflate_claim(self, txo, message): - txo.meta = { - 'canonical_url': message.canonical_url, - 'is_controlling': message.is_controlling, - 'activation_height': message.activation_height, - 'effective_amount': message.effective_amount, - 'support_amount': message.support_amount, - 'claims_in_channel': message.claims_in_channel, - 'trending_group': message.trending_group, - 'trending_mixed': message.trending_mixed, - 'trending_local': message.trending_local, - 'trending_global': message.trending_global, - } - try: - if txo.claim.is_channel: - txo.meta['claims_in_channel'] = message.claims_in_channel - except DecodeError: - pass - def inflate(self, txs): - tx_map, txos = {tx.hash: tx for tx in txs}, [] - for txo_message in self.txos: - if txo_message.WhichOneof('meta') == 'error': - txos.append(None) - continue - txo = tx_map[txo_message.tx_hash].outputs[txo_message.nout] - if txo_message.WhichOneof('meta') == 'claim': - self._inflate_claim(txo, txo_message.claim) - if txo_message.claim.HasField('channel'): - channel_message = txo_message.claim.channel - txo.channel = tx_map[channel_message.tx_hash].outputs[channel_message.nout] - self._inflate_claim(txo.channel, channel_message.claim) - txos.append(txo) - return txos + tx_map = {tx.hash: tx for tx in txs} + for txo_message in self.extra_txos: + self.message_to_txo(txo_message, tx_map) + return [self.message_to_txo(txo_message, tx_map) for txo_message in self.txos] + + def message_to_txo(self, txo_message, tx_map): + if txo_message.WhichOneof('meta') == 'error': + return None + txo = tx_map[txo_message.tx_hash].outputs[txo_message.nout] + if txo_message.WhichOneof('meta') == 'claim': + claim = txo_message.claim + txo.meta = { + 'short_url': claim.short_url, + 'canonical_url': claim.canonical_url or claim.short_url, + 'is_controlling': claim.is_controlling, + 'activation_height': claim.activation_height, + 'expiration_height': claim.expiration_height, + 'effective_amount': claim.effective_amount, + 'support_amount': claim.support_amount, + 'trending_group': claim.trending_group, + 'trending_mixed': claim.trending_mixed, + 'trending_local': claim.trending_local, + 'trending_global': claim.trending_global, + } + if claim.HasField('channel'): + txo.channel = tx_map[claim.channel.tx_hash].outputs[claim.channel.nout] + if claim.claims_in_channel is not None: + txo.meta['claims_in_channel'] = claim.claims_in_channel + return txo @classmethod def from_base64(cls, data: str) -> 'Outputs': @@ -61,50 +59,56 @@ class Outputs: def from_bytes(cls, data: bytes) -> 'Outputs': outputs = OutputsMessage() outputs.ParseFromString(data) - txs = {} - for txo_message in outputs.txos: + txs = set() + for txo_message in chain(outputs.txos, outputs.extra_txos): if txo_message.WhichOneof('meta') == 'error': continue - txs[txo_message.tx_hash] = (hexlify(txo_message.tx_hash[::-1]).decode(), txo_message.height) - if txo_message.WhichOneof('meta') == 'claim' and txo_message.claim.HasField('channel'): - channel = txo_message.claim.channel - txs[channel.tx_hash] = (hexlify(channel.tx_hash[::-1]).decode(), channel.height) - return cls(outputs.txos, list(txs.values()), outputs.offset, outputs.total) + txs.add((hexlify(txo_message.tx_hash[::-1]).decode(), txo_message.height)) + return cls(outputs.txos, outputs.extra_txos, txs, outputs.offset, outputs.total) @classmethod - def to_base64(cls, txo_rows, offset=0, total=None) -> str: - return base64.b64encode(cls.to_bytes(txo_rows, offset, total)).decode() + def to_base64(cls, txo_rows, extra_txo_rows, offset=0, total=None) -> str: + return base64.b64encode(cls.to_bytes(txo_rows, extra_txo_rows, offset, total)).decode() @classmethod - def to_bytes(cls, txo_rows, offset=0, total=None) -> bytes: + def to_bytes(cls, txo_rows, extra_txo_rows, offset=0, total=None) -> bytes: page = OutputsMessage() page.offset = offset page.total = total or len(txo_rows) - for txo in txo_rows: - txo_message = page.txos.add() - if isinstance(txo, Exception): - txo_message.error.text = txo.args[0] - if isinstance(txo, ValueError): - txo_message.error.code = txo_message.error.INVALID - elif isinstance(txo, LookupError): - txo_message.error.code = txo_message.error.NOT_FOUND - continue - txo_message.height = txo['height'] - txo_message.tx_hash = txo['txo_hash'][:32] - txo_message.nout, = struct.unpack('= 262974 THEN :height+2102400 ELSE :height+262974 END, :normalized||COALESCE( (SELECT shortest_id(claim_id, :claim_id) FROM claim WHERE normalized = :normalized), '#'||substr(:claim_id, 1, 1) - ), - CASE WHEN :is_channel_signature_valid = 1 THEN - (SELECT canonical FROM claim WHERE claim_hash=:channel_hash)||'/'|| - :normalized||COALESCE( - (SELECT shortest_id(claim_id, :claim_id) FROM claim - WHERE normalized = :normalized AND - channel_hash = :channel_hash AND - is_channel_signature_valid = 1), - '#'||substr(:claim_id, 1, 1) - ) - END + ) )""", claims) - def update_claims(self, txos: Set[Output], header, channels): - claims = self._upsertable_claims(txos, header, channels, clear_first=True) + def update_claims(self, txos: List[Output], header): + claims = self._upsertable_claims(txos, header, clear_first=True) if claims: self.db.executemany(""" - UPDATE claim SET - txo_hash=:txo_hash, tx_position=:tx_position, height=:height, amount=:amount, - public_key_bytes=:public_key_bytes, timestamp=:timestamp, - release_time=CASE WHEN :release_time IS NOT NULL THEN :release_time ELSE release_time END, - channel_hash=:channel_hash, is_channel_signature_valid=:is_channel_signature_valid, - channel_height=CASE - WHEN channel_hash = :channel_hash AND :is_channel_signature_valid THEN channel_height - WHEN :is_channel_signature_valid THEN :height - END, - channel_canonical=CASE - WHEN channel_hash = :channel_hash AND :is_channel_signature_valid THEN channel_canonical - WHEN :is_channel_signature_valid THEN - (SELECT canonical FROM claim WHERE claim_hash=:channel_hash)||'/'|| - :normalized||COALESCE( - (SELECT shortest_id(claim_id, :claim_id) FROM claim - WHERE normalized = :normalized AND - channel_hash = :channel_hash AND - is_channel_signature_valid = 1), - '#'||substr(:claim_id, 1, 1) - ) - END + UPDATE claim SET + txo_hash=:txo_hash, tx_position=:tx_position, amount=:amount, height=:height, timestamp=:timestamp, + release_time=CASE WHEN :release_time IS NOT NULL THEN :release_time ELSE release_time END WHERE claim_hash=:claim_hash; """, claims) @@ -346,13 +312,6 @@ class SQLDB: for table in ('tag',): # 'language', 'location', etc self.execute(*self._delete_sql(table, {'claim_hash__in': binary_claim_hashes})) - def invalidate_channel_signatures(self, channels): - self.execute(f""" - UPDATE claim SET - channel_height=NULL, channel_canonical=NULL, is_channel_signature_valid=0 - WHERE channel_hash IN ({','.join('?' for _ in channels)}) - """, [sqlite3.Binary(channel) for channel in channels]) - def split_inputs_into_claims_supports_and_other(self, txis): txo_hashes = {txi.txo_ref.hash for txi in txis} claims = self.execute(*query( @@ -369,21 +328,7 @@ class SQLDB: txo_hashes -= {r['txo_hash'] for r in supports} return claims, supports, txo_hashes - def get_channel_public_keys_for_outputs(self, txos): - channels = set() - for txo in txos: - try: - channel_hash = txo.claim.signing_channel_hash - if channel_hash: - channels.add(channel_hash) - except: - pass - return dict(self.execute(*query( - "SELECT claim_hash, public_key_bytes FROM claim", - claim_hash__in=[sqlite3.Binary(channel) for channel in channels] - )).fetchall()) - - def insert_supports(self, txos: Set[Output]): + def insert_supports(self, txos: List[Output]): supports = [] for txo in txos: tx = txo.tx_ref.tx @@ -405,6 +350,129 @@ class SQLDB: 'support', {'txo_hash__in': [sqlite3.Binary(txo_hash) for txo_hash in txo_hashes]} )) + def validate_channel_signatures(self, height, new_claims, updated_claims): + if not new_claims and not updated_claims: + return + + channels, new_channel_keys, signables = {}, {}, {} + for txo in chain(new_claims, updated_claims): + try: + claim = txo.claim + except: + continue + if claim.is_channel: + channels[txo.claim_hash] = txo + new_channel_keys[txo.claim_hash] = claim.channel.public_key_bytes + else: + signables[txo.claim_hash] = txo + + missing_channel_keys = set() + for txo in signables.values(): + claim = txo.claim + if claim.is_signed and claim.signing_channel_hash not in new_channel_keys: + missing_channel_keys.add(claim.signing_channel_hash) + + all_channel_keys = {} + if new_channel_keys or missing_channel_keys: + all_channel_keys = dict(self.execute(*query( + "SELECT claim_hash, public_key_bytes FROM claim", + claim_hash__in=[ + sqlite3.Binary(channel_hash) for channel_hash in + set(new_channel_keys) | missing_channel_keys + ] + ))) + + changed_channel_keys = {} + for claim_hash, new_key in new_channel_keys.items(): + if all_channel_keys[claim_hash] != new_key: + all_channel_keys[claim_hash] = new_key + changed_channel_keys[claim_hash] = new_key + + claim_updates = [] + + for claim_hash, txo in signables.items(): + claim = txo.claim + update = { + 'claim_hash': sqlite3.Binary(claim_hash), + 'channel_hash': None, + 'signature': None, + 'signature_digest': None, + 'is_channel_signature_valid': False + } + if claim.is_signed: + update.update({ + 'channel_hash': sqlite3.Binary(claim.signing_channel_hash), + 'signature': sqlite3.Binary(txo.get_encoded_signature()), + 'signature_digest': sqlite3.Binary(txo.get_signature_digest(self.ledger)) + }) + claim_updates.append(update) + + if changed_channel_keys: + sql = f""" + SELECT * FROM claim WHERE + channel_hash IN ({','.join('?' for _ in changed_channel_keys)}) AND + signature IS NOT NULL + """ + for affected_claim in self.execute(sql, [sqlite3.Binary(h) for h in changed_channel_keys]): + if affected_claim['claim_hash'] not in signables: + claim_updates.append({ + 'claim_hash': sqlite3.Binary(affected_claim['claim_hash']), + 'channel_hash': sqlite3.Binary(affected_claim['channel_hash']), + 'signature': sqlite3.Binary(affected_claim['signature']), + 'signature_digest': sqlite3.Binary(affected_claim['signature_digest']), + 'is_channel_signature_valid': False + }) + + for update in claim_updates: + channel_pub_key = all_channel_keys.get(update['channel_hash']) + if channel_pub_key and update['signature']: + update['is_channel_signature_valid'] = Output.is_signature_valid( + bytes(update['signature']), bytes(update['signature_digest']), channel_pub_key + ) + + if claim_updates: + self.db.executemany(f""" + UPDATE claim SET + channel_hash=:channel_hash, signature=:signature, signature_digest=:signature_digest, + is_channel_signature_valid=:is_channel_signature_valid, + channel_join=CASE + WHEN is_channel_signature_valid AND :is_channel_signature_valid THEN channel_join + WHEN :is_channel_signature_valid THEN {height} + END, + canonical_url=CASE + WHEN is_channel_signature_valid AND :is_channel_signature_valid THEN canonical_url + WHEN :is_channel_signature_valid THEN + (SELECT short_url FROM claim WHERE claim_hash=:channel_hash)||'/'|| + normalized||COALESCE( + (SELECT shortest_id(other_claim.claim_id, claim.claim_id) FROM claim AS other_claim + WHERE other_claim.normalized = claim.normalized AND + other_claim.channel_hash = :channel_hash AND + other_claim.is_channel_signature_valid = 1), + '#'||substr(claim_id, 1, 1) + ) + END + WHERE claim_hash=:claim_hash; + """, claim_updates) + + if channels: + self.db.executemany( + "UPDATE claim SET public_key_bytes=:public_key_bytes WHERE claim_hash=:claim_hash", [{ + 'claim_hash': sqlite3.Binary(claim_hash), + 'public_key_bytes': sqlite3.Binary(txo.claim.channel.public_key_bytes) + } for claim_hash, txo in channels.items()] + ) + + if all_channel_keys: + self.db.executemany(f""" + UPDATE claim SET + claims_in_channel=( + SELECT COUNT(*) FROM claim AS claim_in_channel + WHERE claim_in_channel.channel_hash=claim.claim_hash AND + claim_in_channel.is_channel_signature_valid + ) + WHERE claim_hash = ? + """, [(sqlite3.Binary(channel_hash),) for channel_hash in all_channel_keys.keys()]) + def _update_support_amount(self, claim_hashes): if claim_hashes: self.execute(f""" @@ -501,20 +569,21 @@ class SQLDB: r(self._perform_overtake, height, [], []) def advance_txs(self, height, all_txs, header, daemon_height, timer): - insert_claims = set() - update_claims = set() + insert_claims = [] + update_claims = [] delete_claim_hashes = set() - insert_supports = set() + insert_supports = [] delete_support_txo_hashes = set() recalculate_claim_hashes = set() # added/deleted supports, added/updated claim deleted_claim_names = set() + delete_others = set() body_timer = timer.add_timer('body') for position, (etx, txid) in enumerate(all_txs): tx = timer.run( Transaction, etx.serialize(), height=height, position=position ) # Inputs - spent_claims, spent_supports, spent_other = timer.run( + spent_claims, spent_supports, spent_others = timer.run( self.split_inputs_into_claims_supports_and_other, tx.inputs ) body_timer.start() @@ -522,28 +591,38 @@ class SQLDB: delete_support_txo_hashes.update({r['txo_hash'] for r in spent_supports}) deleted_claim_names.update({r['normalized'] for r in spent_claims}) recalculate_claim_hashes.update({r['claim_hash'] for r in spent_supports}) + delete_others.update(spent_others) # Outputs for output in tx.outputs: if output.is_support: - insert_supports.add(output) + insert_supports.append(output) recalculate_claim_hashes.add(output.claim_hash) elif output.script.is_claim_name: - insert_claims.add(output) + insert_claims.append(output) recalculate_claim_hashes.add(output.claim_hash) elif output.script.is_update_claim: claim_hash = output.claim_hash - if claim_hash in delete_claim_hashes: - delete_claim_hashes.remove(claim_hash) - update_claims.add(output) - recalculate_claim_hashes.add(output.claim_hash) + update_claims.append(output) + recalculate_claim_hashes.add(claim_hash) + delete_claim_hashes.discard(claim_hash) + delete_others.discard(output.ref.hash) # claim insertion and update occurring in the same block body_timer.stop() - channel_public_keys = self.get_channel_public_keys_for_outputs(insert_claims | update_claims) + skip_claim_timer = timer.add_timer('skip insertion of abandoned claims') + skip_claim_timer.start() + for new_claim in list(insert_claims): + if new_claim.ref.hash in delete_others: + insert_claims.remove(new_claim) + self.logger.info( + f"Skipping insertion of claim '{new_claim.id}' due to " + f"an abandon of it in the same block {height}." + ) + skip_claim_timer.stop() r = timer.run r(self.delete_claims, delete_claim_hashes) r(self.delete_supports, delete_support_txo_hashes) - r(self.invalidate_channel_signatures, recalculate_claim_hashes) - r(self.insert_claims, insert_claims, header, channel_public_keys) - r(self.update_claims, update_claims, header, channel_public_keys) + r(self.insert_claims, insert_claims, header) + r(self.update_claims, update_claims, header) + r(self.validate_channel_signatures, height, insert_claims, update_claims) r(self.insert_supports, insert_supports) r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True) r(calculate_trending, self.db, height, self.main.first_sync, daemon_height) @@ -649,23 +728,20 @@ class SQLDB: """ claimtrie.claim_hash as is_controlling, claim.claim_hash, claim.txo_hash, claim.height, - claim.activation_height, claim.effective_amount, claim.support_amount, + claim.is_channel, claim.claims_in_channel, + claim.activation_height, claim.expiration_height, + claim.effective_amount, claim.support_amount, claim.trending_group, claim.trending_mixed, claim.trending_local, claim.trending_global, - CASE WHEN claim.is_channel_signature_valid = 1 - THEN claim.channel_canonical - ELSE claim.canonical - END AS canonical, - CASE WHEN claim.is_channel=1 THEN ( - SELECT COUNT(*) FROM claim as claim_in_channel - WHERE claim_in_channel.channel_hash=claim.claim_hash - ) ELSE 0 END AS claims_in_channel, - channel.txo_hash AS channel_txo_hash, channel.height AS channel_height + claim.short_url, claim.canonical_url, + claim.channel_hash, channel.txo_hash AS channel_txo_hash, + channel.height AS channel_height, claim.is_channel_signature_valid """, **constraints ) INTEGER_PARAMS = { - 'height', 'activation_height', 'release_time', 'publish_time', + 'height', 'creation_height', 'activation_height', 'tx_position', + 'release_time', 'timestamp', 'amount', 'effective_amount', 'support_amount', 'trending_group', 'trending_mixed', 'trending_local', 'trending_global', @@ -684,7 +760,7 @@ class SQLDB: 'name', } | INTEGER_PARAMS - def search(self, constraints) -> Tuple[List, int, int]: + def search(self, constraints) -> Tuple[List, List, int, int]: assert set(constraints).issubset(self.SEARCH_PARAMS), \ f"Search query contains invalid arguments: {set(constraints).difference(self.SEARCH_PARAMS)}" total = self.get_claims_count(**constraints) @@ -693,10 +769,15 @@ class SQLDB: if 'order_by' not in constraints: constraints['order_by'] = ["height", "^name"] txo_rows = self._search(**constraints) - return txo_rows, constraints['offset'], total + channel_hashes = set(txo['channel_hash'] for txo in txo_rows if txo['channel_hash']) + extra_txo_rows = [] + if channel_hashes: + extra_txo_rows = self._search(**{'claim.claim_hash__in': [sqlite3.Binary(h) for h in channel_hashes]}) + return txo_rows, extra_txo_rows, constraints['offset'], total - def resolve(self, urls) -> List: + def resolve(self, urls) -> Tuple[List, List]: result = [] + channel_hashes = set() for raw_url in urls: try: url = URL.parse(raw_url) @@ -723,12 +804,17 @@ class SQLDB: matches = self._search(**query) if matches: result.append(matches[0]) + if matches[0]['channel_hash']: + channel_hashes.add(matches[0]['channel_hash']) else: result.append(LookupError(f'Could not find stream in "{raw_url}".')) continue else: result.append(channel) - return result + extra_txo_rows = [] + if channel_hashes: + extra_txo_rows = self._search(**{'claim.claim_hash__in': [sqlite3.Binary(h) for h in channel_hashes]}) + return result, extra_txo_rows class LBRYDB(DB): diff --git a/lbrynet/wallet/server/session.py b/lbrynet/wallet/server/session.py index 842a7232d..81a8174f5 100644 --- a/lbrynet/wallet/server/session.py +++ b/lbrynet/wallet/server/session.py @@ -51,7 +51,7 @@ class LBRYElectrumX(ElectrumX): return Outputs.to_base64(*self.db.sql.search(kwargs)) async def claimtrie_resolve(self, *urls): - return Outputs.to_base64(self.db.sql.resolve(urls)) + return Outputs.to_base64(*self.db.sql.resolve(urls)) async def get_server_height(self): return self.bp.height diff --git a/lbrynet/wallet/transaction.py b/lbrynet/wallet/transaction.py index cb0a7175b..7afc6b23f 100644 --- a/lbrynet/wallet/transaction.py +++ b/lbrynet/wallet/transaction.py @@ -97,8 +97,7 @@ class Output(BaseOutput): def has_private_key(self): return self.private_key is not None - def is_signed_by(self, channel: 'Output', ledger=None, public_key_bytes=None): - public_key_bytes = public_key_bytes or channel.claim.channel.public_key_bytes + def get_signature_digest(self, ledger): if self.claim.unsigned_payload: pieces = [ Base58.decode(self.get_address(ledger)), @@ -111,20 +110,31 @@ class Output(BaseOutput): self.claim.signing_channel_hash, self.claim.to_message_bytes() ] - digest = sha256(b''.join(pieces)) - public_key = load_der_public_key(public_key_bytes, default_backend()) - hash = hashes.SHA256() + return sha256(b''.join(pieces)) + + def get_encoded_signature(self): signature = hexlify(self.claim.signature) r = int(signature[:int(len(signature)/2)], 16) s = int(signature[int(len(signature)/2):], 16) - encoded_sig = ecdsa.util.sigencode_der(r, s, len(signature)*4) + return ecdsa.util.sigencode_der(r, s, len(signature)*4) + + @staticmethod + def is_signature_valid(encoded_signature, signature_digest, public_key_bytes): try: - public_key.verify(encoded_sig, digest, ec.ECDSA(Prehashed(hash))) + public_key = load_der_public_key(public_key_bytes, default_backend()) + public_key.verify(encoded_signature, signature_digest, ec.ECDSA(Prehashed(hashes.SHA256()))) return True except (ValueError, InvalidSignature): pass return False + def is_signed_by(self, channel: 'Output', ledger=None): + return self.is_signature_valid( + self.get_encoded_signature(), + self.get_signature_digest(ledger), + channel.claim.channel.public_key_bytes + ) + def sign(self, channel: 'Output', first_input_id=None): self.channel = channel self.claim.signing_channel_hash = channel.claim_hash diff --git a/tests/unit/wallet/server/test_sqldb.py b/tests/unit/wallet/server/test_sqldb.py index 12bf45d50..e7e3f72fe 100644 --- a/tests/unit/wallet/server/test_sqldb.py +++ b/tests/unit/wallet/server/test_sqldb.py @@ -1,7 +1,7 @@ import unittest import ecdsa import hashlib -from binascii import hexlify, unhexlify +from binascii import hexlify from torba.client.constants import COIN, NULL_HASH32 from lbrynet.schema.claim import Claim @@ -54,22 +54,21 @@ class TestSQLDB(unittest.TestCase): self._txos[output.ref.hash] = output return OldWalletServerTransaction(tx), tx.hash - def get_channel(self, title, amount, name='@foo'): - claim = Claim() - claim.channel.title = title - channel = Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc') - # deterministic private key - private_key = ecdsa.SigningKey.from_string(b'c'*32, curve=ecdsa.SECP256k1, hashfunc=hashlib.sha256) - channel.private_key = private_key.to_pem().decode() - channel.claim.channel.public_key_bytes = private_key.get_verifying_key().to_der() - channel.script.generate() - return self._make_tx(channel) - - def get_channel_update(self, channel, amount, key=b'd'): + def _set_channel_key(self, channel, key): private_key = ecdsa.SigningKey.from_string(key*32, curve=ecdsa.SECP256k1, hashfunc=hashlib.sha256) channel.private_key = private_key.to_pem().decode() channel.claim.channel.public_key_bytes = private_key.get_verifying_key().to_der() channel.script.generate() + + def get_channel(self, title, amount, name='@foo', key=b'a'): + claim = Claim() + claim.channel.title = title + channel = Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc') + self._set_channel_key(channel, key) + return self._make_tx(channel) + + def get_channel_update(self, channel, amount, key=b'a'): + self._set_channel_key(channel, key) return self._make_tx( Output.pay_update_claim_pubkey_hash( amount, channel.claim_name, channel.claim_id, channel.claim, b'abc' @@ -313,57 +312,81 @@ class TestSQLDB(unittest.TestCase): @staticmethod def _get_x_with_claim_id_prefix(getter, prefix, cached_iteration=None, **kwargs): - iterations = 100 + iterations = cached_iteration+1 if cached_iteration else 100 for i in range(cached_iteration or 1, iterations): stream = getter(f'claim #{i}', COIN, **kwargs) if stream[0].tx.outputs[0].claim_id.startswith(prefix): cached_iteration is None and print(f'Found "{prefix}" in {i} iterations.') return stream - raise ValueError(f'Failed to find "{prefix}" in {iterations} iterations.') + if cached_iteration: + raise ValueError(f'Failed to find "{prefix}" at cached iteration, run with None to find iteration.') + raise ValueError(f'Failed to find "{prefix}" in {iterations} iterations, try different values.') - def get_channel_with_claim_id_prefix(self, prefix, cached_iteration=None): - return self._get_x_with_claim_id_prefix(self.get_channel, prefix, cached_iteration) + def get_channel_with_claim_id_prefix(self, prefix, cached_iteration=None, **kwargs): + return self._get_x_with_claim_id_prefix(self.get_channel, prefix, cached_iteration, **kwargs) def get_stream_with_claim_id_prefix(self, prefix, cached_iteration=None, **kwargs): return self._get_x_with_claim_id_prefix(self.get_stream, prefix, cached_iteration, **kwargs) - def test_canonical_name(self): + def test_canonical_url_and_channel_validation(self): advance = self.advance - tx_chan_a = self.get_channel_with_claim_id_prefix('a', 1) - tx_chan_ab = self.get_channel_with_claim_id_prefix('ab', 72) + tx_chan_a = self.get_channel_with_claim_id_prefix('a', 1, key=b'c') + tx_chan_ab = self.get_channel_with_claim_id_prefix('ab', 72, key=b'c') txo_chan_a = tx_chan_a[0].tx.outputs[0] advance(1, [tx_chan_a]) advance(2, [tx_chan_ab]) - r_ab, r_a = self.sql._search(order_by=['height'], limit=2) - self.assertEqual("@foo#a", r_a['canonical']) - self.assertEqual("@foo#ab", r_ab['canonical']) + r_ab, r_a = self.sql._search(order_by=['creation_height'], limit=2) + self.assertEqual("@foo#a", r_a['short_url']) + self.assertEqual("@foo#ab", r_ab['short_url']) + self.assertIsNone(r_a['canonical_url']) + self.assertIsNone(r_ab['canonical_url']) + self.assertEqual(0, r_a['claims_in_channel']) + self.assertEqual(0, r_ab['claims_in_channel']) tx_a = self.get_stream_with_claim_id_prefix('a', 2) tx_ab = self.get_stream_with_claim_id_prefix('ab', 42) tx_abc = self.get_stream_with_claim_id_prefix('abc', 65) advance(3, [tx_a]) - advance(4, [tx_ab]) - advance(5, [tx_abc]) - r_abc, r_ab, r_a = self.sql._search(order_by=['height'], limit=3) - self.assertEqual("foo#a", r_a['canonical']) - self.assertEqual("foo#ab", r_ab['canonical']) - self.assertEqual("foo#abc", r_abc['canonical']) + advance(4, [tx_ab, tx_abc]) + r_abc, r_ab, r_a = self.sql._search(order_by=['creation_height', 'tx_position'], limit=3) + self.assertEqual("foo#a", r_a['short_url']) + self.assertEqual("foo#ab", r_ab['short_url']) + self.assertEqual("foo#abc", r_abc['short_url']) + self.assertIsNone(r_a['canonical_url']) + self.assertIsNone(r_ab['canonical_url']) + self.assertIsNone(r_abc['canonical_url']) tx_a2 = self.get_stream_with_claim_id_prefix('a', 7, channel=txo_chan_a) tx_ab2 = self.get_stream_with_claim_id_prefix('ab', 23, channel=txo_chan_a) + a2_claim_id = tx_a2[0].tx.outputs[0].claim_id + ab2_claim_id = tx_ab2[0].tx.outputs[0].claim_id advance(6, [tx_a2]) advance(7, [tx_ab2]) - r_ab2, r_a2 = self.sql._search(order_by=['height'], limit=2) - self.assertEqual("@foo#a/foo#a", r_a2['canonical']) - self.assertEqual("@foo#a/foo#ab", r_ab2['canonical']) + r_ab2, r_a2 = self.sql._search(order_by=['creation_height'], limit=2) + self.assertEqual(f"foo#{a2_claim_id[:2]}", r_a2['short_url']) + self.assertEqual(f"foo#{ab2_claim_id[:4]}", r_ab2['short_url']) + self.assertEqual("@foo#a/foo#a", r_a2['canonical_url']) + self.assertEqual("@foo#a/foo#ab", r_ab2['canonical_url']) + self.assertEqual(2, self.sql._search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel']) - advance(8, [self.get_channel_update(txo_chan_a, COIN)]) - _, r_ab2, r_a2 = self.sql._search(order_by=['height'], limit=3) - a2_claim_id = hexlify(r_a2['claim_hash'][::-1]).decode() - ab2_claim_id = hexlify(r_ab2['claim_hash'][::-1]).decode() - self.assertEqual(f"foo#{a2_claim_id[:2]}", r_a2['canonical']) - self.assertEqual(f"foo#{ab2_claim_id[:4]}", r_ab2['canonical']) + # invalidate channel signature + advance(8, [self.get_channel_update(txo_chan_a, COIN, key=b'a')]) + r_ab2, r_a2 = self.sql._search(order_by=['creation_height'], limit=2) + self.assertEqual(f"foo#{a2_claim_id[:2]}", r_a2['short_url']) + self.assertEqual(f"foo#{ab2_claim_id[:4]}", r_ab2['short_url']) + self.assertIsNone(r_a2['canonical_url']) + self.assertIsNone(r_ab2['canonical_url']) + self.assertEqual(0, self.sql._search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel']) + + # re-validate signature (reverts signature to original one) + advance(9, [self.get_channel_update(txo_chan_a, COIN, key=b'c')]) + r_ab2, r_a2 = self.sql._search(order_by=['creation_height'], limit=2) + self.assertEqual(f"foo#{a2_claim_id[:2]}", r_a2['short_url']) + self.assertEqual(f"foo#{ab2_claim_id[:4]}", r_ab2['short_url']) + self.assertEqual("@foo#a/foo#a", r_a2['canonical_url']) + self.assertEqual("@foo#a/foo#ab", r_ab2['canonical_url']) + self.assertEqual(2, self.sql._search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel']) def test_canonical_find_shortest_id(self): new_hash = 'abcdef0123456789beef'