diff --git a/lbrynet/wallet/server/canonical.py b/lbrynet/wallet/server/canonical.py index 1d14ee047..20447a71a 100644 --- a/lbrynet/wallet/server/canonical.py +++ b/lbrynet/wallet/server/canonical.py @@ -5,13 +5,11 @@ class FindShortestID: __slots__ = 'short_id', 'new_id' def __init__(self): - self.short_id = b'' + self.short_id = '' self.new_id = None - def step(self, other_hash, new_hash): - if self.new_id is None: - self.new_id = hexlify(new_hash[::-1]) - other_id = hexlify(other_hash[::-1]) + def step(self, other_id, new_id): + self.new_id = new_id for i in range(len(self.new_id)): if other_id[i] != self.new_id[i]: if i > len(self.short_id)-1: @@ -19,9 +17,7 @@ class FindShortestID: break def finalize(self): - if self.short_id: - return '#'+self.short_id.decode() - return '' + return '#'+self.short_id def register_canonical_functions(connection): diff --git a/lbrynet/wallet/server/db.py b/lbrynet/wallet/server/db.py index 8b198ba89..1d492731d 100644 --- a/lbrynet/wallet/server/db.py +++ b/lbrynet/wallet/server/db.py @@ -8,6 +8,7 @@ from torba.server.util import class_logger from torba.client.basedatabase import query, constraints_to_sql from lbrynet.schema.url import URL, normalize_name +from lbrynet.wallet.ledger import MainNetLedger, RegTestLedger from lbrynet.wallet.transaction import Transaction, Output from lbrynet.wallet.server.canonical import register_canonical_functions from lbrynet.wallet.server.trending import ( @@ -67,18 +68,27 @@ class SQLDB: CREATE_CLAIM_TABLE = """ create table if not exists claim ( claim_hash bytes primary key, + claim_id text not null, claim_name text not null, normalized text not null, canonical text not null, - is_channel bool not null, txo_hash bytes not null, tx_position integer not null, - height integer not null, - channel_hash bytes, - release_time integer, - publish_time integer, - activation_height integer, amount integer not null, + is_channel bool not null, + public_key_bytes bytes, + timestamp integer not null, + height integer not null, + creation_height integer not null, + activation_height integer, + expiration_height integer not null, + release_time integer not null, + + channel_hash bytes, + channel_height integer, -- height at which claim got valid signature + channel_canonical text, -- canonical URL \w channel + is_channel_signature_valid bool, + effective_amount integer not null default 0, support_amount integer not null default 0, trending_group integer not null default 0, @@ -91,7 +101,7 @@ class SQLDB: create index if not exists claim_txo_hash_idx on claim (txo_hash); create index if not exists claim_channel_hash_idx on claim (channel_hash); create index if not exists claim_release_time_idx on claim (release_time); - create index if not exists claim_publish_time_idx on claim (publish_time); + create index if not exists claim_timestamp_idx on claim (timestamp); create index if not exists claim_height_idx on claim (height); create index if not exists claim_activation_height_idx on claim (activation_height); @@ -147,6 +157,7 @@ class SQLDB: self._db_path = path self.db = None self.logger = class_logger(__name__, self.__class__.__name__) + self.ledger = MainNetLedger if self.main.coin.NET == 'mainnet' else RegTestLedger def open(self): self.db = sqlite3.connect(self._db_path, isolation_level=None, check_same_thread=False) @@ -194,7 +205,7 @@ class SQLDB: def commit(self): self.execute('commit;') - def _upsertable_claims(self, txos: Set[Output], header, clear_first=False): + def _upsertable_claims(self, txos: Set[Output], header, channels, clear_first=False): claim_hashes, claims, tags = [], [], [] for txo in txos: tx = txo.tx_ref.tx @@ -210,16 +221,20 @@ class SQLDB: claim_hashes.append(claim_hash) claim_record = { 'claim_hash': claim_hash, - 'normalized': txo.normalized_name, + 'claim_id': txo.claim_id, 'claim_name': txo.claim_name, - 'is_channel': False, + 'normalized': txo.normalized_name, 'txo_hash': sqlite3.Binary(txo.ref.hash), 'tx_position': tx.position, - 'height': tx.height, 'amount': txo.amount, + 'is_channel': False, + 'public_key_bytes': None, + 'timestamp': header['timestamp'], + 'height': tx.height, + 'release_time': None, 'channel_hash': None, - 'publish_time': header['timestamp'], - 'release_time': header['timestamp'] + 'channel_height': None, + 'is_channel_signature_valid': None } claims.append(claim_record) @@ -229,11 +244,22 @@ class SQLDB: #self.logger.exception(f"Could not parse claim protobuf for {tx.id}:{txo.position}.") continue - if claim.is_stream and claim.stream.release_time: - claim_record['release_time'] = claim.stream.release_time - claim_record['is_channel'] = claim.is_channel - if claim.signing_channel_hash: - claim_record['channel_hash'] = sqlite3.Binary(claim.signing_channel_hash) + if claim.is_stream: + if claim.stream.release_time: + claim_record['release_time'] = claim.stream.release_time + if claim.signing_channel_hash: + claim_record['channel_hash'] = sqlite3.Binary(claim.signing_channel_hash) + channel_pub_key = channels.get(claim.signing_channel_hash) + if channel_pub_key: + claim_record['is_channel_signature_valid'] = txo.is_signed_by( + None, ledger=self.ledger, public_key_bytes=channel_pub_key + ) + if claim_record['is_channel_signature_valid']: + claim_record['channel_height'] = tx.height + elif claim.is_channel: + claim_record['is_channel'] = True + claim_record['public_key_bytes'] = sqlite3.Binary(claim.channel.public_key_bytes) + for tag in claim.message.tags: tags.append((tag, claim_hash, tx.height)) @@ -247,38 +273,65 @@ class SQLDB: return claims - def insert_claims(self, txos: Set[Output], header): - claims = self._upsertable_claims(txos, header) + def insert_claims(self, txos: Set[Output], header, channels): + claims = self._upsertable_claims(txos, header, channels) if claims: self.db.executemany(""" INSERT INTO claim ( - claim_hash, normalized, claim_name, is_channel, txo_hash, tx_position, - height, amount, channel_hash, release_time, publish_time, activation_height, - canonical) + claim_hash, claim_id, claim_name, normalized, txo_hash, tx_position, + amount, is_channel, public_key_bytes, timestamp, height, creation_height, + channel_hash, channel_height, is_channel_signature_valid, release_time, + activation_height, expiration_height, canonical, channel_canonical) VALUES ( - :claim_hash, :normalized, :claim_name, :is_channel, :txo_hash, :tx_position, - :height, :amount, :channel_hash, :release_time, :publish_time, + :claim_hash, :claim_id, :claim_name, :normalized, :txo_hash, :tx_position, + :amount, :is_channel, :public_key_bytes, :timestamp, :height, :height, + :channel_hash, :channel_height, :is_channel_signature_valid, + CASE WHEN :release_time IS NOT NULL THEN :release_time ELSE :timestamp END, CASE WHEN :normalized NOT IN (SELECT normalized FROM claimtrie) THEN :height END, - CASE WHEN :channel_hash IS NOT NULL - THEN (SELECT canonical FROM claim WHERE claim_hash=:channel_hash)||'/'|| - :normalized||COALESCE((SELECT shortest_id(claim_hash, :claim_hash) - FROM claim WHERE normalized = :normalized), '') - ELSE :normalized||COALESCE((SELECT shortest_id(claim_hash, :claim_hash) - FROM claim WHERE normalized = :normalized), '') + CASE WHEN :height >= 262974 THEN :height+2102400 ELSE :height+262974 END, + :normalized||COALESCE( + (SELECT shortest_id(claim_id, :claim_id) FROM claim WHERE normalized = :normalized), + '#'||substr(:claim_id, 1, 1) + ), + CASE WHEN :is_channel_signature_valid = 1 THEN + (SELECT canonical FROM claim WHERE claim_hash=:channel_hash)||'/'|| + :normalized||COALESCE( + (SELECT shortest_id(claim_id, :claim_id) FROM claim + WHERE normalized = :normalized AND + channel_hash = :channel_hash AND + is_channel_signature_valid = 1), + '#'||substr(:claim_id, 1, 1) + ) END )""", claims) - def update_claims(self, txos: Set[Output], header): - claims = self._upsertable_claims(txos, header, clear_first=True) + def update_claims(self, txos: Set[Output], header, channels): + claims = self._upsertable_claims(txos, header, channels, clear_first=True) if claims: - self.db.executemany( - "UPDATE claim SET " - " is_channel=:is_channel, txo_hash=:txo_hash, tx_position=:tx_position," - " height=:height, amount=:amount, channel_hash=:channel_hash," - " release_time=:release_time, publish_time=:publish_time " - "WHERE claim_hash=:claim_hash;", - claims - ) + self.db.executemany(""" + UPDATE claim SET + txo_hash=:txo_hash, tx_position=:tx_position, height=:height, amount=:amount, + public_key_bytes=:public_key_bytes, timestamp=:timestamp, + release_time=CASE WHEN :release_time IS NOT NULL THEN :release_time ELSE release_time END, + channel_hash=:channel_hash, is_channel_signature_valid=:is_channel_signature_valid, + channel_height=CASE + WHEN channel_hash = :channel_hash AND :is_channel_signature_valid THEN channel_height + WHEN :is_channel_signature_valid THEN :height + END, + channel_canonical=CASE + WHEN channel_hash = :channel_hash AND :is_channel_signature_valid THEN channel_canonical + WHEN :is_channel_signature_valid THEN + (SELECT canonical FROM claim WHERE claim_hash=:channel_hash)||'/'|| + :normalized||COALESCE( + (SELECT shortest_id(claim_id, :claim_id) FROM claim + WHERE normalized = :normalized AND + channel_hash = :channel_hash AND + is_channel_signature_valid = 1), + '#'||substr(:claim_id, 1, 1) + ) + END + WHERE claim_hash=:claim_hash; + """, claims) def delete_claims(self, claim_hashes: Set[bytes]): """ Deletes claim supports and from claimtrie in case of an abandon. """ @@ -293,6 +346,13 @@ class SQLDB: for table in ('tag',): # 'language', 'location', etc self.execute(*self._delete_sql(table, {'claim_hash__in': binary_claim_hashes})) + def invalidate_channel_signatures(self, channels): + self.execute(f""" + UPDATE claim SET + channel_height=NULL, channel_canonical=NULL, is_channel_signature_valid=0 + WHERE channel_hash IN ({','.join('?' for _ in channels)}) + """, [sqlite3.Binary(channel) for channel in channels]) + def split_inputs_into_claims_supports_and_other(self, txis): txo_hashes = {txi.txo_ref.hash for txi in txis} claims = self.execute(*query( @@ -309,6 +369,20 @@ class SQLDB: txo_hashes -= {r['txo_hash'] for r in supports} return claims, supports, txo_hashes + def get_channel_public_keys_for_outputs(self, txos): + channels = set() + for txo in txos: + try: + channel_hash = txo.claim.signing_channel_hash + if channel_hash: + channels.add(channel_hash) + except: + pass + return dict(self.execute(*query( + "SELECT claim_hash, public_key_bytes FROM claim", + claim_hash__in=[sqlite3.Binary(channel) for channel in channels] + )).fetchall()) + def insert_supports(self, txos: Set[Output]): supports = [] for txo in txos: @@ -463,11 +537,13 @@ class SQLDB: update_claims.add(output) recalculate_claim_hashes.add(output.claim_hash) body_timer.stop() + channel_public_keys = self.get_channel_public_keys_for_outputs(insert_claims | update_claims) r = timer.run r(self.delete_claims, delete_claim_hashes) r(self.delete_supports, delete_support_txo_hashes) - r(self.insert_claims, insert_claims, header) - r(self.update_claims, update_claims, header) + r(self.invalidate_channel_signatures, recalculate_claim_hashes) + r(self.insert_claims, insert_claims, header, channel_public_keys) + r(self.update_claims, update_claims, header, channel_public_keys) r(self.insert_supports, insert_supports) r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True) r(calculate_trending, self.db, height, self.main.first_sync, daemon_height) @@ -572,15 +648,19 @@ class SQLDB: return self.get_claims( """ claimtrie.claim_hash as is_controlling, - claim.claim_hash, claim.txo_hash, claim.height, claim.canonical, + claim.claim_hash, claim.txo_hash, claim.height, claim.activation_height, claim.effective_amount, claim.support_amount, claim.trending_group, claim.trending_mixed, claim.trending_local, claim.trending_global, + CASE WHEN claim.is_channel_signature_valid = 1 + THEN claim.channel_canonical + ELSE claim.canonical + END AS canonical, CASE WHEN claim.is_channel=1 THEN ( SELECT COUNT(*) FROM claim as claim_in_channel WHERE claim_in_channel.channel_hash=claim.claim_hash ) ELSE 0 END AS claims_in_channel, - channel.txo_hash as channel_txo_hash, channel.height as channel_height + channel.txo_hash AS channel_txo_hash, channel.height AS channel_height """, **constraints ) diff --git a/lbrynet/wallet/transaction.py b/lbrynet/wallet/transaction.py index 277c14001..cb0a7175b 100644 --- a/lbrynet/wallet/transaction.py +++ b/lbrynet/wallet/transaction.py @@ -97,7 +97,8 @@ class Output(BaseOutput): def has_private_key(self): return self.private_key is not None - def is_signed_by(self, channel: 'Output', ledger=None): + def is_signed_by(self, channel: 'Output', ledger=None, public_key_bytes=None): + public_key_bytes = public_key_bytes or channel.claim.channel.public_key_bytes if self.claim.unsigned_payload: pieces = [ Base58.decode(self.get_address(ledger)), @@ -111,7 +112,7 @@ class Output(BaseOutput): self.claim.to_message_bytes() ] digest = sha256(b''.join(pieces)) - public_key = load_der_public_key(channel.claim.channel.public_key_bytes, default_backend()) + public_key = load_der_public_key(public_key_bytes, default_backend()) hash = hashes.SHA256() signature = hexlify(self.claim.signature) r = int(signature[:int(len(signature)/2)], 16) diff --git a/tests/unit/wallet/server/test_sqldb.py b/tests/unit/wallet/server/test_sqldb.py index c8bc41b92..12bf45d50 100644 --- a/tests/unit/wallet/server/test_sqldb.py +++ b/tests/unit/wallet/server/test_sqldb.py @@ -6,6 +6,7 @@ from torba.client.constants import COIN, NULL_HASH32 from lbrynet.schema.claim import Claim from lbrynet.wallet.server.db import SQLDB +from lbrynet.wallet.server.coin import LBCRegTest from lbrynet.wallet.server.trending import TRENDING_WINDOW from lbrynet.wallet.server.canonical import FindShortestID from lbrynet.wallet.server.block_processor import Timer @@ -39,6 +40,7 @@ class TestSQLDB(unittest.TestCase): def setUp(self): self.first_sync = False self.daemon_height = 1 + self.coin = LBCRegTest() self.sql = SQLDB(self, ':memory:') self.timer = Timer('BlockProcessor') self.sql.open() @@ -56,16 +58,33 @@ class TestSQLDB(unittest.TestCase): claim = Claim() claim.channel.title = title channel = Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc') + # deterministic private key private_key = ecdsa.SigningKey.from_string(b'c'*32, curve=ecdsa.SECP256k1, hashfunc=hashlib.sha256) channel.private_key = private_key.to_pem().decode() channel.claim.channel.public_key_bytes = private_key.get_verifying_key().to_der() channel.script.generate() return self._make_tx(channel) - def get_stream(self, title, amount, name='foo'): + def get_channel_update(self, channel, amount, key=b'd'): + private_key = ecdsa.SigningKey.from_string(key*32, curve=ecdsa.SECP256k1, hashfunc=hashlib.sha256) + channel.private_key = private_key.to_pem().decode() + channel.claim.channel.public_key_bytes = private_key.get_verifying_key().to_der() + channel.script.generate() + return self._make_tx( + Output.pay_update_claim_pubkey_hash( + amount, channel.claim_name, channel.claim_id, channel.claim, b'abc' + ), + Input.spend(channel) + ) + + def get_stream(self, title, amount, name='foo', channel=None): claim = Claim() claim.stream.title = title - return self._make_tx(Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc')) + result = self._make_tx(Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc')) + if channel: + result[0].tx.outputs[0].sign(channel) + result[0].tx._reset() + return result def get_stream_update(self, tx, amount): claim = Transaction(tx[0].serialize()).outputs[0] @@ -293,57 +312,66 @@ class TestSQLDB(unittest.TestCase): self.assertEqual([53, 38, 2, 0, -6], [int(c['trending_mixed']) for c in results]) @staticmethod - def _get_x_with_claim_id_prefix(getter, prefix, cached_iteration=None): + def _get_x_with_claim_id_prefix(getter, prefix, cached_iteration=None, **kwargs): iterations = 100 for i in range(cached_iteration or 1, iterations): - stream = getter(f'claim #{i}', COIN) + stream = getter(f'claim #{i}', COIN, **kwargs) if stream[0].tx.outputs[0].claim_id.startswith(prefix): - print(f'Found "{prefix}" in {i} iterations.') + cached_iteration is None and print(f'Found "{prefix}" in {i} iterations.') return stream raise ValueError(f'Failed to find "{prefix}" in {iterations} iterations.') - def get_channel_with_claim_id_prefix(self, prefix, cached_iteration): + def get_channel_with_claim_id_prefix(self, prefix, cached_iteration=None): return self._get_x_with_claim_id_prefix(self.get_channel, prefix, cached_iteration) - def get_stream_with_claim_id_prefix(self, prefix, cached_iteration): - return self._get_x_with_claim_id_prefix(self.get_stream, prefix, cached_iteration) + def get_stream_with_claim_id_prefix(self, prefix, cached_iteration=None, **kwargs): + return self._get_x_with_claim_id_prefix(self.get_stream, prefix, cached_iteration, **kwargs) def test_canonical_name(self): advance = self.advance - tx_abc = self.get_stream_with_claim_id_prefix('abc', 65) - tx_ab = self.get_stream_with_claim_id_prefix('ab', 42) + + tx_chan_a = self.get_channel_with_claim_id_prefix('a', 1) + tx_chan_ab = self.get_channel_with_claim_id_prefix('ab', 72) + txo_chan_a = tx_chan_a[0].tx.outputs[0] + advance(1, [tx_chan_a]) + advance(2, [tx_chan_ab]) + r_ab, r_a = self.sql._search(order_by=['height'], limit=2) + self.assertEqual("@foo#a", r_a['canonical']) + self.assertEqual("@foo#ab", r_ab['canonical']) + tx_a = self.get_stream_with_claim_id_prefix('a', 2) - advance(1, [tx_a]) - advance(2, [tx_ab]) - advance(3, [tx_abc]) - r_a, r_ab, r_abc = self.sql._search(order_by=['^height']) - self.assertEqual("foo", r_a['canonical']) - self.assertEqual(f"foo#ab", r_ab['canonical']) - self.assertEqual(f"foo#abc", r_abc['canonical']) + tx_ab = self.get_stream_with_claim_id_prefix('ab', 42) + tx_abc = self.get_stream_with_claim_id_prefix('abc', 65) + advance(3, [tx_a]) + advance(4, [tx_ab]) + advance(5, [tx_abc]) + r_abc, r_ab, r_a = self.sql._search(order_by=['height'], limit=3) + self.assertEqual("foo#a", r_a['canonical']) + self.assertEqual("foo#ab", r_ab['canonical']) + self.assertEqual("foo#abc", r_abc['canonical']) - tx_ab = self.get_channel_with_claim_id_prefix('ab', 72) - tx_a = self.get_channel_with_claim_id_prefix('a', 1) - advance(4, [tx_a]) - advance(5, [tx_ab]) + tx_a2 = self.get_stream_with_claim_id_prefix('a', 7, channel=txo_chan_a) + tx_ab2 = self.get_stream_with_claim_id_prefix('ab', 23, channel=txo_chan_a) + advance(6, [tx_a2]) + advance(7, [tx_ab2]) + r_ab2, r_a2 = self.sql._search(order_by=['height'], limit=2) + self.assertEqual("@foo#a/foo#a", r_a2['canonical']) + self.assertEqual("@foo#a/foo#ab", r_ab2['canonical']) - tx_c = self.get_stream_with_claim_id_prefix('c', 2) - tx_cd = self.get_stream_with_claim_id_prefix('cd', 2) - advance(6, [tx_c]) - advance(7, [tx_cd]) - - r_a, r_ab, r_abc = self.sql._search(order_by=['^height']) - self.assertEqual("foo", r_a['canonical']) - self.assertEqual(f"foo#ab", r_ab['canonical']) - self.assertEqual(f"foo#abc", r_abc['canonical']) + advance(8, [self.get_channel_update(txo_chan_a, COIN)]) + _, r_ab2, r_a2 = self.sql._search(order_by=['height'], limit=3) + a2_claim_id = hexlify(r_a2['claim_hash'][::-1]).decode() + ab2_claim_id = hexlify(r_ab2['claim_hash'][::-1]).decode() + self.assertEqual(f"foo#{a2_claim_id[:2]}", r_a2['canonical']) + self.assertEqual(f"foo#{ab2_claim_id[:4]}", r_ab2['canonical']) def test_canonical_find_shortest_id(self): - new_hash = unhexlify('abcdef0123456789beef')[::-1] - other0 = unhexlify('1bcdef0123456789beef')[::-1] - other1 = unhexlify('ab1def0123456789beef')[::-1] - other2 = unhexlify('abc1ef0123456789beef')[::-1] - other3 = unhexlify('abcdef0123456789bee1')[::-1] + new_hash = 'abcdef0123456789beef' + other0 = '1bcdef0123456789beef' + other1 = 'ab1def0123456789beef' + other2 = 'abc1ef0123456789beef' + other3 = 'abcdef0123456789bee1' f = FindShortestID() - self.assertEqual('', f.finalize()) f.step(other0, new_hash) self.assertEqual('#a', f.finalize()) f.step(other1, new_hash)