canonical_url takes into account whether channel signature is valid

This commit is contained in:
Lex Berezhny 2019-05-24 22:40:39 -04:00
parent affa46e0f6
commit 370b34f860
4 changed files with 196 additions and 91 deletions

View file

@ -5,13 +5,11 @@ class FindShortestID:
__slots__ = 'short_id', 'new_id'
def __init__(self):
self.short_id = b''
self.short_id = ''
self.new_id = None
def step(self, other_hash, new_hash):
if self.new_id is None:
self.new_id = hexlify(new_hash[::-1])
other_id = hexlify(other_hash[::-1])
def step(self, other_id, new_id):
self.new_id = new_id
for i in range(len(self.new_id)):
if other_id[i] != self.new_id[i]:
if i > len(self.short_id)-1:
@ -19,9 +17,7 @@ class FindShortestID:
break
def finalize(self):
if self.short_id:
return '#'+self.short_id.decode()
return ''
return '#'+self.short_id
def register_canonical_functions(connection):

View file

@ -8,6 +8,7 @@ from torba.server.util import class_logger
from torba.client.basedatabase import query, constraints_to_sql
from lbrynet.schema.url import URL, normalize_name
from lbrynet.wallet.ledger import MainNetLedger, RegTestLedger
from lbrynet.wallet.transaction import Transaction, Output
from lbrynet.wallet.server.canonical import register_canonical_functions
from lbrynet.wallet.server.trending import (
@ -67,18 +68,27 @@ class SQLDB:
CREATE_CLAIM_TABLE = """
create table if not exists claim (
claim_hash bytes primary key,
claim_id text not null,
claim_name text not null,
normalized text not null,
canonical text not null,
is_channel bool not null,
txo_hash bytes not null,
tx_position integer not null,
height integer not null,
channel_hash bytes,
release_time integer,
publish_time integer,
activation_height integer,
amount integer not null,
is_channel bool not null,
public_key_bytes bytes,
timestamp integer not null,
height integer not null,
creation_height integer not null,
activation_height integer,
expiration_height integer not null,
release_time integer not null,
channel_hash bytes,
channel_height integer, -- height at which claim got valid signature
channel_canonical text, -- canonical URL \w channel
is_channel_signature_valid bool,
effective_amount integer not null default 0,
support_amount integer not null default 0,
trending_group integer not null default 0,
@ -91,7 +101,7 @@ class SQLDB:
create index if not exists claim_txo_hash_idx on claim (txo_hash);
create index if not exists claim_channel_hash_idx on claim (channel_hash);
create index if not exists claim_release_time_idx on claim (release_time);
create index if not exists claim_publish_time_idx on claim (publish_time);
create index if not exists claim_timestamp_idx on claim (timestamp);
create index if not exists claim_height_idx on claim (height);
create index if not exists claim_activation_height_idx on claim (activation_height);
@ -147,6 +157,7 @@ class SQLDB:
self._db_path = path
self.db = None
self.logger = class_logger(__name__, self.__class__.__name__)
self.ledger = MainNetLedger if self.main.coin.NET == 'mainnet' else RegTestLedger
def open(self):
self.db = sqlite3.connect(self._db_path, isolation_level=None, check_same_thread=False)
@ -194,7 +205,7 @@ class SQLDB:
def commit(self):
self.execute('commit;')
def _upsertable_claims(self, txos: Set[Output], header, clear_first=False):
def _upsertable_claims(self, txos: Set[Output], header, channels, clear_first=False):
claim_hashes, claims, tags = [], [], []
for txo in txos:
tx = txo.tx_ref.tx
@ -210,16 +221,20 @@ class SQLDB:
claim_hashes.append(claim_hash)
claim_record = {
'claim_hash': claim_hash,
'normalized': txo.normalized_name,
'claim_id': txo.claim_id,
'claim_name': txo.claim_name,
'is_channel': False,
'normalized': txo.normalized_name,
'txo_hash': sqlite3.Binary(txo.ref.hash),
'tx_position': tx.position,
'height': tx.height,
'amount': txo.amount,
'is_channel': False,
'public_key_bytes': None,
'timestamp': header['timestamp'],
'height': tx.height,
'release_time': None,
'channel_hash': None,
'publish_time': header['timestamp'],
'release_time': header['timestamp']
'channel_height': None,
'is_channel_signature_valid': None
}
claims.append(claim_record)
@ -229,11 +244,22 @@ class SQLDB:
#self.logger.exception(f"Could not parse claim protobuf for {tx.id}:{txo.position}.")
continue
if claim.is_stream and claim.stream.release_time:
claim_record['release_time'] = claim.stream.release_time
claim_record['is_channel'] = claim.is_channel
if claim.signing_channel_hash:
claim_record['channel_hash'] = sqlite3.Binary(claim.signing_channel_hash)
if claim.is_stream:
if claim.stream.release_time:
claim_record['release_time'] = claim.stream.release_time
if claim.signing_channel_hash:
claim_record['channel_hash'] = sqlite3.Binary(claim.signing_channel_hash)
channel_pub_key = channels.get(claim.signing_channel_hash)
if channel_pub_key:
claim_record['is_channel_signature_valid'] = txo.is_signed_by(
None, ledger=self.ledger, public_key_bytes=channel_pub_key
)
if claim_record['is_channel_signature_valid']:
claim_record['channel_height'] = tx.height
elif claim.is_channel:
claim_record['is_channel'] = True
claim_record['public_key_bytes'] = sqlite3.Binary(claim.channel.public_key_bytes)
for tag in claim.message.tags:
tags.append((tag, claim_hash, tx.height))
@ -247,38 +273,65 @@ class SQLDB:
return claims
def insert_claims(self, txos: Set[Output], header):
claims = self._upsertable_claims(txos, header)
def insert_claims(self, txos: Set[Output], header, channels):
claims = self._upsertable_claims(txos, header, channels)
if claims:
self.db.executemany("""
INSERT INTO claim (
claim_hash, normalized, claim_name, is_channel, txo_hash, tx_position,
height, amount, channel_hash, release_time, publish_time, activation_height,
canonical)
claim_hash, claim_id, claim_name, normalized, txo_hash, tx_position,
amount, is_channel, public_key_bytes, timestamp, height, creation_height,
channel_hash, channel_height, is_channel_signature_valid, release_time,
activation_height, expiration_height, canonical, channel_canonical)
VALUES (
:claim_hash, :normalized, :claim_name, :is_channel, :txo_hash, :tx_position,
:height, :amount, :channel_hash, :release_time, :publish_time,
:claim_hash, :claim_id, :claim_name, :normalized, :txo_hash, :tx_position,
:amount, :is_channel, :public_key_bytes, :timestamp, :height, :height,
:channel_hash, :channel_height, :is_channel_signature_valid,
CASE WHEN :release_time IS NOT NULL THEN :release_time ELSE :timestamp END,
CASE WHEN :normalized NOT IN (SELECT normalized FROM claimtrie) THEN :height END,
CASE WHEN :channel_hash IS NOT NULL
THEN (SELECT canonical FROM claim WHERE claim_hash=:channel_hash)||'/'||
:normalized||COALESCE((SELECT shortest_id(claim_hash, :claim_hash)
FROM claim WHERE normalized = :normalized), '')
ELSE :normalized||COALESCE((SELECT shortest_id(claim_hash, :claim_hash)
FROM claim WHERE normalized = :normalized), '')
CASE WHEN :height >= 262974 THEN :height+2102400 ELSE :height+262974 END,
:normalized||COALESCE(
(SELECT shortest_id(claim_id, :claim_id) FROM claim WHERE normalized = :normalized),
'#'||substr(:claim_id, 1, 1)
),
CASE WHEN :is_channel_signature_valid = 1 THEN
(SELECT canonical FROM claim WHERE claim_hash=:channel_hash)||'/'||
:normalized||COALESCE(
(SELECT shortest_id(claim_id, :claim_id) FROM claim
WHERE normalized = :normalized AND
channel_hash = :channel_hash AND
is_channel_signature_valid = 1),
'#'||substr(:claim_id, 1, 1)
)
END
)""", claims)
def update_claims(self, txos: Set[Output], header):
claims = self._upsertable_claims(txos, header, clear_first=True)
def update_claims(self, txos: Set[Output], header, channels):
claims = self._upsertable_claims(txos, header, channels, clear_first=True)
if claims:
self.db.executemany(
"UPDATE claim SET "
" is_channel=:is_channel, txo_hash=:txo_hash, tx_position=:tx_position,"
" height=:height, amount=:amount, channel_hash=:channel_hash,"
" release_time=:release_time, publish_time=:publish_time "
"WHERE claim_hash=:claim_hash;",
claims
)
self.db.executemany("""
UPDATE claim SET
txo_hash=:txo_hash, tx_position=:tx_position, height=:height, amount=:amount,
public_key_bytes=:public_key_bytes, timestamp=:timestamp,
release_time=CASE WHEN :release_time IS NOT NULL THEN :release_time ELSE release_time END,
channel_hash=:channel_hash, is_channel_signature_valid=:is_channel_signature_valid,
channel_height=CASE
WHEN channel_hash = :channel_hash AND :is_channel_signature_valid THEN channel_height
WHEN :is_channel_signature_valid THEN :height
END,
channel_canonical=CASE
WHEN channel_hash = :channel_hash AND :is_channel_signature_valid THEN channel_canonical
WHEN :is_channel_signature_valid THEN
(SELECT canonical FROM claim WHERE claim_hash=:channel_hash)||'/'||
:normalized||COALESCE(
(SELECT shortest_id(claim_id, :claim_id) FROM claim
WHERE normalized = :normalized AND
channel_hash = :channel_hash AND
is_channel_signature_valid = 1),
'#'||substr(:claim_id, 1, 1)
)
END
WHERE claim_hash=:claim_hash;
""", claims)
def delete_claims(self, claim_hashes: Set[bytes]):
""" Deletes claim supports and from claimtrie in case of an abandon. """
@ -293,6 +346,13 @@ class SQLDB:
for table in ('tag',): # 'language', 'location', etc
self.execute(*self._delete_sql(table, {'claim_hash__in': binary_claim_hashes}))
def invalidate_channel_signatures(self, channels):
self.execute(f"""
UPDATE claim SET
channel_height=NULL, channel_canonical=NULL, is_channel_signature_valid=0
WHERE channel_hash IN ({','.join('?' for _ in channels)})
""", [sqlite3.Binary(channel) for channel in channels])
def split_inputs_into_claims_supports_and_other(self, txis):
txo_hashes = {txi.txo_ref.hash for txi in txis}
claims = self.execute(*query(
@ -309,6 +369,20 @@ class SQLDB:
txo_hashes -= {r['txo_hash'] for r in supports}
return claims, supports, txo_hashes
def get_channel_public_keys_for_outputs(self, txos):
channels = set()
for txo in txos:
try:
channel_hash = txo.claim.signing_channel_hash
if channel_hash:
channels.add(channel_hash)
except:
pass
return dict(self.execute(*query(
"SELECT claim_hash, public_key_bytes FROM claim",
claim_hash__in=[sqlite3.Binary(channel) for channel in channels]
)).fetchall())
def insert_supports(self, txos: Set[Output]):
supports = []
for txo in txos:
@ -463,11 +537,13 @@ class SQLDB:
update_claims.add(output)
recalculate_claim_hashes.add(output.claim_hash)
body_timer.stop()
channel_public_keys = self.get_channel_public_keys_for_outputs(insert_claims | update_claims)
r = timer.run
r(self.delete_claims, delete_claim_hashes)
r(self.delete_supports, delete_support_txo_hashes)
r(self.insert_claims, insert_claims, header)
r(self.update_claims, update_claims, header)
r(self.invalidate_channel_signatures, recalculate_claim_hashes)
r(self.insert_claims, insert_claims, header, channel_public_keys)
r(self.update_claims, update_claims, header, channel_public_keys)
r(self.insert_supports, insert_supports)
r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True)
r(calculate_trending, self.db, height, self.main.first_sync, daemon_height)
@ -572,15 +648,19 @@ class SQLDB:
return self.get_claims(
"""
claimtrie.claim_hash as is_controlling,
claim.claim_hash, claim.txo_hash, claim.height, claim.canonical,
claim.claim_hash, claim.txo_hash, claim.height,
claim.activation_height, claim.effective_amount, claim.support_amount,
claim.trending_group, claim.trending_mixed,
claim.trending_local, claim.trending_global,
CASE WHEN claim.is_channel_signature_valid = 1
THEN claim.channel_canonical
ELSE claim.canonical
END AS canonical,
CASE WHEN claim.is_channel=1 THEN (
SELECT COUNT(*) FROM claim as claim_in_channel
WHERE claim_in_channel.channel_hash=claim.claim_hash
) ELSE 0 END AS claims_in_channel,
channel.txo_hash as channel_txo_hash, channel.height as channel_height
channel.txo_hash AS channel_txo_hash, channel.height AS channel_height
""", **constraints
)

View file

@ -97,7 +97,8 @@ class Output(BaseOutput):
def has_private_key(self):
return self.private_key is not None
def is_signed_by(self, channel: 'Output', ledger=None):
def is_signed_by(self, channel: 'Output', ledger=None, public_key_bytes=None):
public_key_bytes = public_key_bytes or channel.claim.channel.public_key_bytes
if self.claim.unsigned_payload:
pieces = [
Base58.decode(self.get_address(ledger)),
@ -111,7 +112,7 @@ class Output(BaseOutput):
self.claim.to_message_bytes()
]
digest = sha256(b''.join(pieces))
public_key = load_der_public_key(channel.claim.channel.public_key_bytes, default_backend())
public_key = load_der_public_key(public_key_bytes, default_backend())
hash = hashes.SHA256()
signature = hexlify(self.claim.signature)
r = int(signature[:int(len(signature)/2)], 16)

View file

@ -6,6 +6,7 @@ from torba.client.constants import COIN, NULL_HASH32
from lbrynet.schema.claim import Claim
from lbrynet.wallet.server.db import SQLDB
from lbrynet.wallet.server.coin import LBCRegTest
from lbrynet.wallet.server.trending import TRENDING_WINDOW
from lbrynet.wallet.server.canonical import FindShortestID
from lbrynet.wallet.server.block_processor import Timer
@ -39,6 +40,7 @@ class TestSQLDB(unittest.TestCase):
def setUp(self):
self.first_sync = False
self.daemon_height = 1
self.coin = LBCRegTest()
self.sql = SQLDB(self, ':memory:')
self.timer = Timer('BlockProcessor')
self.sql.open()
@ -56,16 +58,33 @@ class TestSQLDB(unittest.TestCase):
claim = Claim()
claim.channel.title = title
channel = Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc')
# deterministic private key
private_key = ecdsa.SigningKey.from_string(b'c'*32, curve=ecdsa.SECP256k1, hashfunc=hashlib.sha256)
channel.private_key = private_key.to_pem().decode()
channel.claim.channel.public_key_bytes = private_key.get_verifying_key().to_der()
channel.script.generate()
return self._make_tx(channel)
def get_stream(self, title, amount, name='foo'):
def get_channel_update(self, channel, amount, key=b'd'):
private_key = ecdsa.SigningKey.from_string(key*32, curve=ecdsa.SECP256k1, hashfunc=hashlib.sha256)
channel.private_key = private_key.to_pem().decode()
channel.claim.channel.public_key_bytes = private_key.get_verifying_key().to_der()
channel.script.generate()
return self._make_tx(
Output.pay_update_claim_pubkey_hash(
amount, channel.claim_name, channel.claim_id, channel.claim, b'abc'
),
Input.spend(channel)
)
def get_stream(self, title, amount, name='foo', channel=None):
claim = Claim()
claim.stream.title = title
return self._make_tx(Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc'))
result = self._make_tx(Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc'))
if channel:
result[0].tx.outputs[0].sign(channel)
result[0].tx._reset()
return result
def get_stream_update(self, tx, amount):
claim = Transaction(tx[0].serialize()).outputs[0]
@ -293,57 +312,66 @@ class TestSQLDB(unittest.TestCase):
self.assertEqual([53, 38, 2, 0, -6], [int(c['trending_mixed']) for c in results])
@staticmethod
def _get_x_with_claim_id_prefix(getter, prefix, cached_iteration=None):
def _get_x_with_claim_id_prefix(getter, prefix, cached_iteration=None, **kwargs):
iterations = 100
for i in range(cached_iteration or 1, iterations):
stream = getter(f'claim #{i}', COIN)
stream = getter(f'claim #{i}', COIN, **kwargs)
if stream[0].tx.outputs[0].claim_id.startswith(prefix):
print(f'Found "{prefix}" in {i} iterations.')
cached_iteration is None and print(f'Found "{prefix}" in {i} iterations.')
return stream
raise ValueError(f'Failed to find "{prefix}" in {iterations} iterations.')
def get_channel_with_claim_id_prefix(self, prefix, cached_iteration):
def get_channel_with_claim_id_prefix(self, prefix, cached_iteration=None):
return self._get_x_with_claim_id_prefix(self.get_channel, prefix, cached_iteration)
def get_stream_with_claim_id_prefix(self, prefix, cached_iteration):
return self._get_x_with_claim_id_prefix(self.get_stream, prefix, cached_iteration)
def get_stream_with_claim_id_prefix(self, prefix, cached_iteration=None, **kwargs):
return self._get_x_with_claim_id_prefix(self.get_stream, prefix, cached_iteration, **kwargs)
def test_canonical_name(self):
advance = self.advance
tx_abc = self.get_stream_with_claim_id_prefix('abc', 65)
tx_ab = self.get_stream_with_claim_id_prefix('ab', 42)
tx_chan_a = self.get_channel_with_claim_id_prefix('a', 1)
tx_chan_ab = self.get_channel_with_claim_id_prefix('ab', 72)
txo_chan_a = tx_chan_a[0].tx.outputs[0]
advance(1, [tx_chan_a])
advance(2, [tx_chan_ab])
r_ab, r_a = self.sql._search(order_by=['height'], limit=2)
self.assertEqual("@foo#a", r_a['canonical'])
self.assertEqual("@foo#ab", r_ab['canonical'])
tx_a = self.get_stream_with_claim_id_prefix('a', 2)
advance(1, [tx_a])
advance(2, [tx_ab])
advance(3, [tx_abc])
r_a, r_ab, r_abc = self.sql._search(order_by=['^height'])
self.assertEqual("foo", r_a['canonical'])
self.assertEqual(f"foo#ab", r_ab['canonical'])
self.assertEqual(f"foo#abc", r_abc['canonical'])
tx_ab = self.get_stream_with_claim_id_prefix('ab', 42)
tx_abc = self.get_stream_with_claim_id_prefix('abc', 65)
advance(3, [tx_a])
advance(4, [tx_ab])
advance(5, [tx_abc])
r_abc, r_ab, r_a = self.sql._search(order_by=['height'], limit=3)
self.assertEqual("foo#a", r_a['canonical'])
self.assertEqual("foo#ab", r_ab['canonical'])
self.assertEqual("foo#abc", r_abc['canonical'])
tx_ab = self.get_channel_with_claim_id_prefix('ab', 72)
tx_a = self.get_channel_with_claim_id_prefix('a', 1)
advance(4, [tx_a])
advance(5, [tx_ab])
tx_a2 = self.get_stream_with_claim_id_prefix('a', 7, channel=txo_chan_a)
tx_ab2 = self.get_stream_with_claim_id_prefix('ab', 23, channel=txo_chan_a)
advance(6, [tx_a2])
advance(7, [tx_ab2])
r_ab2, r_a2 = self.sql._search(order_by=['height'], limit=2)
self.assertEqual("@foo#a/foo#a", r_a2['canonical'])
self.assertEqual("@foo#a/foo#ab", r_ab2['canonical'])
tx_c = self.get_stream_with_claim_id_prefix('c', 2)
tx_cd = self.get_stream_with_claim_id_prefix('cd', 2)
advance(6, [tx_c])
advance(7, [tx_cd])
r_a, r_ab, r_abc = self.sql._search(order_by=['^height'])
self.assertEqual("foo", r_a['canonical'])
self.assertEqual(f"foo#ab", r_ab['canonical'])
self.assertEqual(f"foo#abc", r_abc['canonical'])
advance(8, [self.get_channel_update(txo_chan_a, COIN)])
_, r_ab2, r_a2 = self.sql._search(order_by=['height'], limit=3)
a2_claim_id = hexlify(r_a2['claim_hash'][::-1]).decode()
ab2_claim_id = hexlify(r_ab2['claim_hash'][::-1]).decode()
self.assertEqual(f"foo#{a2_claim_id[:2]}", r_a2['canonical'])
self.assertEqual(f"foo#{ab2_claim_id[:4]}", r_ab2['canonical'])
def test_canonical_find_shortest_id(self):
new_hash = unhexlify('abcdef0123456789beef')[::-1]
other0 = unhexlify('1bcdef0123456789beef')[::-1]
other1 = unhexlify('ab1def0123456789beef')[::-1]
other2 = unhexlify('abc1ef0123456789beef')[::-1]
other3 = unhexlify('abcdef0123456789bee1')[::-1]
new_hash = 'abcdef0123456789beef'
other0 = '1bcdef0123456789beef'
other1 = 'ab1def0123456789beef'
other2 = 'abc1ef0123456789beef'
other3 = 'abcdef0123456789bee1'
f = FindShortestID()
self.assertEqual('', f.finalize())
f.step(other0, new_hash)
self.assertEqual('#a', f.finalize())
f.step(other1, new_hash)