wallet server tracks claim signature

This commit is contained in:
Lex Berezhny 2019-05-25 23:06:22 -04:00
parent 9d97b90ff4
commit 29bd936181
6 changed files with 427 additions and 275 deletions

View file

@@ -2,6 +2,7 @@ import base64
import struct import struct
from typing import List from typing import List
from binascii import hexlify from binascii import hexlify
from itertools import chain
from google.protobuf.message import DecodeError from google.protobuf.message import DecodeError
@@ -10,48 +11,45 @@ from lbrynet.schema.types.v2.result_pb2 import Outputs as OutputsMessage
class Outputs: class Outputs:
__slots__ = 'txos', 'txs', 'offset', 'total' __slots__ = 'txos', 'extra_txos', 'txs', 'offset', 'total'
def __init__(self, txos: List, txs: List, offset: int, total: int): def __init__(self, txos: List, extra_txos: List, txs: set, offset: int, total: int):
self.txos = txos self.txos = txos
self.txs = txs self.txs = txs
self.extra_txos = extra_txos
self.offset = offset self.offset = offset
self.total = total self.total = total
def _inflate_claim(self, txo, message):
txo.meta = {
'canonical_url': message.canonical_url,
'is_controlling': message.is_controlling,
'activation_height': message.activation_height,
'effective_amount': message.effective_amount,
'support_amount': message.support_amount,
'claims_in_channel': message.claims_in_channel,
'trending_group': message.trending_group,
'trending_mixed': message.trending_mixed,
'trending_local': message.trending_local,
'trending_global': message.trending_global,
}
try:
if txo.claim.is_channel:
txo.meta['claims_in_channel'] = message.claims_in_channel
except DecodeError:
pass
def inflate(self, txs): def inflate(self, txs):
tx_map, txos = {tx.hash: tx for tx in txs}, [] tx_map = {tx.hash: tx for tx in txs}
for txo_message in self.txos: for txo_message in self.extra_txos:
if txo_message.WhichOneof('meta') == 'error': self.message_to_txo(txo_message, tx_map)
txos.append(None) return [self.message_to_txo(txo_message, tx_map) for txo_message in self.txos]
continue
txo = tx_map[txo_message.tx_hash].outputs[txo_message.nout] def message_to_txo(self, txo_message, tx_map):
if txo_message.WhichOneof('meta') == 'claim': if txo_message.WhichOneof('meta') == 'error':
self._inflate_claim(txo, txo_message.claim) return None
if txo_message.claim.HasField('channel'): txo = tx_map[txo_message.tx_hash].outputs[txo_message.nout]
channel_message = txo_message.claim.channel if txo_message.WhichOneof('meta') == 'claim':
txo.channel = tx_map[channel_message.tx_hash].outputs[channel_message.nout] claim = txo_message.claim
self._inflate_claim(txo.channel, channel_message.claim) txo.meta = {
txos.append(txo) 'short_url': claim.short_url,
return txos 'canonical_url': claim.canonical_url or claim.short_url,
'is_controlling': claim.is_controlling,
'activation_height': claim.activation_height,
'expiration_height': claim.expiration_height,
'effective_amount': claim.effective_amount,
'support_amount': claim.support_amount,
'trending_group': claim.trending_group,
'trending_mixed': claim.trending_mixed,
'trending_local': claim.trending_local,
'trending_global': claim.trending_global,
}
if claim.HasField('channel'):
txo.channel = tx_map[claim.channel.tx_hash].outputs[claim.channel.nout]
if claim.claims_in_channel is not None:
txo.meta['claims_in_channel'] = claim.claims_in_channel
return txo
@classmethod @classmethod
def from_base64(cls, data: str) -> 'Outputs': def from_base64(cls, data: str) -> 'Outputs':
@@ -61,50 +59,56 @@ class Outputs:
def from_bytes(cls, data: bytes) -> 'Outputs': def from_bytes(cls, data: bytes) -> 'Outputs':
outputs = OutputsMessage() outputs = OutputsMessage()
outputs.ParseFromString(data) outputs.ParseFromString(data)
txs = {} txs = set()
for txo_message in outputs.txos: for txo_message in chain(outputs.txos, outputs.extra_txos):
if txo_message.WhichOneof('meta') == 'error': if txo_message.WhichOneof('meta') == 'error':
continue continue
txs[txo_message.tx_hash] = (hexlify(txo_message.tx_hash[::-1]).decode(), txo_message.height) txs.add((hexlify(txo_message.tx_hash[::-1]).decode(), txo_message.height))
if txo_message.WhichOneof('meta') == 'claim' and txo_message.claim.HasField('channel'): return cls(outputs.txos, outputs.extra_txos, txs, outputs.offset, outputs.total)
channel = txo_message.claim.channel
txs[channel.tx_hash] = (hexlify(channel.tx_hash[::-1]).decode(), channel.height)
return cls(outputs.txos, list(txs.values()), outputs.offset, outputs.total)
@classmethod @classmethod
def to_base64(cls, txo_rows, offset=0, total=None) -> str: def to_base64(cls, txo_rows, extra_txo_rows, offset=0, total=None) -> str:
return base64.b64encode(cls.to_bytes(txo_rows, offset, total)).decode() return base64.b64encode(cls.to_bytes(txo_rows, extra_txo_rows, offset, total)).decode()
@classmethod @classmethod
def to_bytes(cls, txo_rows, offset=0, total=None) -> bytes: def to_bytes(cls, txo_rows, extra_txo_rows, offset=0, total=None) -> bytes:
page = OutputsMessage() page = OutputsMessage()
page.offset = offset page.offset = offset
page.total = total or len(txo_rows) page.total = total or len(txo_rows)
for txo in txo_rows: for row in txo_rows:
txo_message = page.txos.add() cls.row_to_message(row, page.txos.add())
if isinstance(txo, Exception): for row in extra_txo_rows:
txo_message.error.text = txo.args[0] cls.row_to_message(row, page.extra_txos.add())
if isinstance(txo, ValueError):
txo_message.error.code = txo_message.error.INVALID
elif isinstance(txo, LookupError):
txo_message.error.code = txo_message.error.NOT_FOUND
continue
txo_message.height = txo['height']
txo_message.tx_hash = txo['txo_hash'][:32]
txo_message.nout, = struct.unpack('<I', txo['txo_hash'][32:])
txo_message.claim.canonical_url = txo['canonical_url']
txo_message.claim.is_controlling = bool(txo['is_controlling'])
txo_message.claim.activation_height = txo['activation_height']
txo_message.claim.effective_amount = txo['effective_amount']
txo_message.claim.support_amount = txo['support_amount']
txo_message.claim.claims_in_channel = txo['claims_in_channel']
txo_message.claim.trending_group = txo['trending_group']
txo_message.claim.trending_mixed = txo['trending_mixed']
txo_message.claim.trending_local = txo['trending_local']
txo_message.claim.trending_global = txo['trending_global']
if txo['channel_txo_hash']:
channel = txo_message.claim.channel
channel.height = txo['channel_height']
channel.tx_hash = txo['channel_txo_hash'][:32]
channel.nout, = struct.unpack('<I', txo['channel_txo_hash'][32:])
return page.SerializeToString() return page.SerializeToString()
@classmethod
def row_to_message(cls, txo, txo_message):
if isinstance(txo, Exception):
txo_message.error.text = txo.args[0]
if isinstance(txo, ValueError):
txo_message.error.code = txo_message.error.INVALID
elif isinstance(txo, LookupError):
txo_message.error.code = txo_message.error.NOT_FOUND
return
txo_message.tx_hash = txo['txo_hash'][:32]
txo_message.nout, = struct.unpack('<I', txo['txo_hash'][32:])
txo_message.height = txo['height']
txo_message.claim.short_url = txo['short_url']
if txo['canonical_url'] is not None:
txo_message.claim.canonical_url = txo['canonical_url']
txo_message.claim.is_controlling = bool(txo['is_controlling'])
txo_message.claim.activation_height = txo['activation_height']
txo_message.claim.expiration_height = txo['expiration_height']
if txo['claims_in_channel'] is not None:
txo_message.claim.claims_in_channel = txo['claims_in_channel']
txo_message.claim.effective_amount = txo['effective_amount']
txo_message.claim.support_amount = txo['support_amount']
txo_message.claim.trending_group = txo['trending_group']
txo_message.claim.trending_mixed = txo['trending_mixed']
txo_message.claim.trending_local = txo['trending_local']
txo_message.claim.trending_global = txo['trending_global']
if txo['channel_txo_hash']:
channel = txo_message.claim.channel
channel.tx_hash = txo['channel_txo_hash'][:32]
channel.nout, = struct.unpack('<I', txo['channel_txo_hash'][32:])
channel.height = txo['channel_height']

View file

@@ -19,7 +19,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='pb', package='pb',
syntax='proto3', syntax='proto3',
serialized_options=None, serialized_options=None,
serialized_pb=_b('\n\x0cresult.proto\x12\x02pb\"B\n\x07Outputs\x12\x18\n\x04txos\x18\x01 \x03(\x0b\x32\n.pb.Output\x12\r\n\x05total\x18\x02 \x01(\r\x12\x0e\n\x06offset\x18\x03 \x01(\r\"{\n\x06Output\x12\x0f\n\x07tx_hash\x18\x01 \x01(\x0c\x12\x0c\n\x04nout\x18\x02 \x01(\r\x12\x0e\n\x06height\x18\x03 \x01(\r\x12\x1e\n\x05\x63laim\x18\x07 \x01(\x0b\x32\r.pb.ClaimMetaH\x00\x12\x1a\n\x05\x65rror\x18\x0f \x01(\x0b\x32\t.pb.ErrorH\x00\x42\x06\n\x04meta\"\x89\x02\n\tClaimMeta\x12\x1b\n\x07\x63hannel\x18\x01 \x01(\x0b\x32\n.pb.Output\x12\x16\n\x0eis_controlling\x18\x02 \x01(\x08\x12\x19\n\x11\x61\x63tivation_height\x18\x03 \x01(\r\x12\x18\n\x10\x65\x66\x66\x65\x63tive_amount\x18\x04 \x01(\x04\x12\x16\n\x0esupport_amount\x18\x05 \x01(\x04\x12\x19\n\x11\x63laims_in_channel\x18\x06 \x01(\r\x12\x16\n\x0etrending_group\x18\x07 \x01(\r\x12\x16\n\x0etrending_mixed\x18\x08 \x01(\x02\x12\x16\n\x0etrending_local\x18\t \x01(\x02\x12\x17\n\x0ftrending_global\x18\n \x01(\x02\"i\n\x05\x45rror\x12\x1c\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x0e.pb.Error.Code\x12\x0c\n\x04text\x18\x02 \x01(\t\"4\n\x04\x43ode\x12\x10\n\x0cUNKNOWN_CODE\x10\x00\x12\r\n\tNOT_FOUND\x10\x01\x12\x0b\n\x07INVALID\x10\x02\x62\x06proto3') serialized_pb=_b('\n\x0cresult.proto\x12\x02pb\"b\n\x07Outputs\x12\x18\n\x04txos\x18\x01 \x03(\x0b\x32\n.pb.Output\x12\x1e\n\nextra_txos\x18\x02 \x03(\x0b\x32\n.pb.Output\x12\r\n\x05total\x18\x03 \x01(\r\x12\x0e\n\x06offset\x18\x04 \x01(\r\"{\n\x06Output\x12\x0f\n\x07tx_hash\x18\x01 \x01(\x0c\x12\x0c\n\x04nout\x18\x02 \x01(\r\x12\x0e\n\x06height\x18\x03 \x01(\r\x12\x1e\n\x05\x63laim\x18\x07 \x01(\x0b\x32\r.pb.ClaimMetaH\x00\x12\x1a\n\x05\x65rror\x18\x0f \x01(\x0b\x32\t.pb.ErrorH\x00\x42\x06\n\x04meta\"\xce\x02\n\tClaimMeta\x12\x1b\n\x07\x63hannel\x18\x01 \x01(\x0b\x32\n.pb.Output\x12\x11\n\tshort_url\x18\x02 \x01(\t\x12\x15\n\rcanonical_url\x18\x03 \x01(\t\x12\x16\n\x0eis_controlling\x18\x04 \x01(\x08\x12\x19\n\x11\x61\x63tivation_height\x18\x05 
\x01(\r\x12\x19\n\x11\x65xpiration_height\x18\x06 \x01(\r\x12\x19\n\x11\x63laims_in_channel\x18\x07 \x01(\r\x12\x18\n\x10\x65\x66\x66\x65\x63tive_amount\x18\n \x01(\x04\x12\x16\n\x0esupport_amount\x18\x0b \x01(\x04\x12\x16\n\x0etrending_group\x18\x0c \x01(\r\x12\x16\n\x0etrending_mixed\x18\r \x01(\x02\x12\x16\n\x0etrending_local\x18\x0e \x01(\x02\x12\x17\n\x0ftrending_global\x18\x0f \x01(\x02\"i\n\x05\x45rror\x12\x1c\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x0e.pb.Error.Code\x12\x0c\n\x04text\x18\x02 \x01(\t\"4\n\x04\x43ode\x12\x10\n\x0cUNKNOWN_CODE\x10\x00\x12\r\n\tNOT_FOUND\x10\x01\x12\x0b\n\x07INVALID\x10\x02\x62\x06proto3')
) )
@@ -45,8 +45,8 @@ _ERROR_CODE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=534, serialized_start=635,
serialized_end=586, serialized_end=687,
) )
_sym_db.RegisterEnumDescriptor(_ERROR_CODE) _sym_db.RegisterEnumDescriptor(_ERROR_CODE)
@@ -66,15 +66,22 @@ _OUTPUTS = _descriptor.Descriptor(
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='total', full_name='pb.Outputs.total', index=1, name='extra_txos', full_name='pb.Outputs.extra_txos', index=1,
number=2, type=13, cpp_type=3, label=1, number=2, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='total', full_name='pb.Outputs.total', index=2,
number=3, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='offset', full_name='pb.Outputs.offset', index=2, name='offset', full_name='pb.Outputs.offset', index=3,
number=3, type=13, cpp_type=3, label=1, number=4, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
@@ -92,7 +99,7 @@ _OUTPUTS = _descriptor.Descriptor(
oneofs=[ oneofs=[
], ],
serialized_start=20, serialized_start=20,
serialized_end=86, serialized_end=118,
) )
@@ -153,8 +160,8 @@ _OUTPUT = _descriptor.Descriptor(
name='meta', full_name='pb.Output.meta', name='meta', full_name='pb.Output.meta',
index=0, containing_type=None, fields=[]), index=0, containing_type=None, fields=[]),
], ],
serialized_start=88, serialized_start=120,
serialized_end=211, serialized_end=243,
) )
@@ -173,64 +180,85 @@ _CLAIMMETA = _descriptor.Descriptor(
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='is_controlling', full_name='pb.ClaimMeta.is_controlling', index=1, name='short_url', full_name='pb.ClaimMeta.short_url', index=1,
number=2, type=8, cpp_type=7, label=1, number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='canonical_url', full_name='pb.ClaimMeta.canonical_url', index=2,
number=3, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='is_controlling', full_name='pb.ClaimMeta.is_controlling', index=3,
number=4, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False, has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='activation_height', full_name='pb.ClaimMeta.activation_height', index=2, name='activation_height', full_name='pb.ClaimMeta.activation_height', index=4,
number=3, type=13, cpp_type=3, label=1, number=5, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='effective_amount', full_name='pb.ClaimMeta.effective_amount', index=3, name='expiration_height', full_name='pb.ClaimMeta.expiration_height', index=5,
number=4, type=4, cpp_type=4, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='support_amount', full_name='pb.ClaimMeta.support_amount', index=4,
number=5, type=4, cpp_type=4, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='claims_in_channel', full_name='pb.ClaimMeta.claims_in_channel', index=5,
number=6, type=13, cpp_type=3, label=1, number=6, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='trending_group', full_name='pb.ClaimMeta.trending_group', index=6, name='claims_in_channel', full_name='pb.ClaimMeta.claims_in_channel', index=6,
number=7, type=13, cpp_type=3, label=1, number=7, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='trending_mixed', full_name='pb.ClaimMeta.trending_mixed', index=7, name='effective_amount', full_name='pb.ClaimMeta.effective_amount', index=7,
number=8, type=2, cpp_type=6, label=1, number=10, type=4, cpp_type=4, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='support_amount', full_name='pb.ClaimMeta.support_amount', index=8,
number=11, type=4, cpp_type=4, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='trending_group', full_name='pb.ClaimMeta.trending_group', index=9,
number=12, type=13, cpp_type=3, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='trending_mixed', full_name='pb.ClaimMeta.trending_mixed', index=10,
number=13, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0), has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='trending_local', full_name='pb.ClaimMeta.trending_local', index=8, name='trending_local', full_name='pb.ClaimMeta.trending_local', index=11,
number=9, type=2, cpp_type=6, label=1, number=14, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0), has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='trending_global', full_name='pb.ClaimMeta.trending_global', index=9, name='trending_global', full_name='pb.ClaimMeta.trending_global', index=12,
number=10, type=2, cpp_type=6, label=1, number=15, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0), has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
@@ -247,8 +275,8 @@ _CLAIMMETA = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=214, serialized_start=246,
serialized_end=479, serialized_end=580,
) )
@@ -286,11 +314,12 @@ _ERROR = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=481, serialized_start=582,
serialized_end=586, serialized_end=687,
) )
_OUTPUTS.fields_by_name['txos'].message_type = _OUTPUT _OUTPUTS.fields_by_name['txos'].message_type = _OUTPUT
_OUTPUTS.fields_by_name['extra_txos'].message_type = _OUTPUT
_OUTPUT.fields_by_name['claim'].message_type = _CLAIMMETA _OUTPUT.fields_by_name['claim'].message_type = _CLAIMMETA
_OUTPUT.fields_by_name['error'].message_type = _ERROR _OUTPUT.fields_by_name['error'].message_type = _ERROR
_OUTPUT.oneofs_by_name['meta'].fields.append( _OUTPUT.oneofs_by_name['meta'].fields.append(

View file

@@ -2,6 +2,8 @@ import sqlite3
import struct import struct
from typing import Union, Tuple, Set, List from typing import Union, Tuple, Set, List
from binascii import unhexlify from binascii import unhexlify
from itertools import chain
from torba.server.db import DB from torba.server.db import DB
from torba.server.util import class_logger from torba.server.util import class_logger
@@ -71,23 +73,31 @@ class SQLDB:
claim_id text not null, claim_id text not null,
claim_name text not null, claim_name text not null,
normalized text not null, normalized text not null,
canonical text not null,
txo_hash bytes not null, txo_hash bytes not null,
tx_position integer not null, tx_position integer not null,
amount integer not null, amount integer not null,
is_channel bool not null, timestamp integer not null, -- last updated timestamp
public_key_bytes bytes, creation_timestamp integer not null,
timestamp integer not null, height integer not null, -- last updated height
height integer not null,
creation_height integer not null, creation_height integer not null,
activation_height integer, activation_height integer,
expiration_height integer not null, expiration_height integer not null,
release_time integer not null, release_time integer not null,
short_url text not null, -- normalized#shortest-unique-claim_id
canonical_url text, -- channel's-short_url/normalized#shortest-unique-claim_id-within-channel
-- claims which are channels
is_channel bool not null,
public_key_bytes bytes,
claims_in_channel integer,
-- claims which are inside channels
channel_hash bytes, channel_hash bytes,
channel_height integer, -- height at which claim got valid signature channel_join integer, -- height at which claim got valid signature / joined channel
channel_canonical text, -- canonical URL \w channel signature bytes,
is_channel_signature_valid bool, signature_digest bytes,
is_channel_signature_valid bool not null default false,
effective_amount integer not null default 0, effective_amount integer not null default 0,
support_amount integer not null default 0, support_amount integer not null default 0,
@@ -205,7 +215,7 @@ class SQLDB:
def commit(self): def commit(self):
self.execute('commit;') self.execute('commit;')
def _upsertable_claims(self, txos: Set[Output], header, channels, clear_first=False): def _upsertable_claims(self, txos: List[Output], header, clear_first=False):
claim_hashes, claims, tags = [], [], [] claim_hashes, claims, tags = [], [], []
for txo in txos: for txo in txos:
tx = txo.tx_ref.tx tx = txo.tx_ref.tx
@@ -228,13 +238,9 @@ class SQLDB:
'tx_position': tx.position, 'tx_position': tx.position,
'amount': txo.amount, 'amount': txo.amount,
'is_channel': False, 'is_channel': False,
'public_key_bytes': None,
'timestamp': header['timestamp'], 'timestamp': header['timestamp'],
'height': tx.height, 'height': tx.height,
'release_time': None, 'release_time': None,
'channel_hash': None,
'channel_height': None,
'is_channel_signature_valid': None
} }
claims.append(claim_record) claims.append(claim_record)
@@ -247,18 +253,8 @@ class SQLDB:
if claim.is_stream: if claim.is_stream:
if claim.stream.release_time: if claim.stream.release_time:
claim_record['release_time'] = claim.stream.release_time claim_record['release_time'] = claim.stream.release_time
if claim.signing_channel_hash:
claim_record['channel_hash'] = sqlite3.Binary(claim.signing_channel_hash)
channel_pub_key = channels.get(claim.signing_channel_hash)
if channel_pub_key:
claim_record['is_channel_signature_valid'] = txo.is_signed_by(
None, ledger=self.ledger, public_key_bytes=channel_pub_key
)
if claim_record['is_channel_signature_valid']:
claim_record['channel_height'] = tx.height
elif claim.is_channel: elif claim.is_channel:
claim_record['is_channel'] = True claim_record['is_channel'] = True
claim_record['public_key_bytes'] = sqlite3.Binary(claim.channel.public_key_bytes)
for tag in claim.message.tags: for tag in claim.message.tags:
tags.append((tag, claim_hash, tx.height)) tags.append((tag, claim_hash, tx.height))
@@ -273,63 +269,33 @@ class SQLDB:
return claims return claims
def insert_claims(self, txos: Set[Output], header, channels): def insert_claims(self, txos: List[Output], header):
claims = self._upsertable_claims(txos, header, channels) claims = self._upsertable_claims(txos, header)
if claims: if claims:
self.db.executemany(""" self.db.executemany("""
INSERT INTO claim ( INSERT INTO claim (
claim_hash, claim_id, claim_name, normalized, txo_hash, tx_position, claim_hash, claim_id, claim_name, normalized, txo_hash, tx_position, amount,
amount, is_channel, public_key_bytes, timestamp, height, creation_height, is_channel, timestamp, creation_timestamp, height, creation_height,
channel_hash, channel_height, is_channel_signature_valid, release_time, release_time, activation_height, expiration_height, short_url)
activation_height, expiration_height, canonical, channel_canonical)
VALUES ( VALUES (
:claim_hash, :claim_id, :claim_name, :normalized, :txo_hash, :tx_position, :claim_hash, :claim_id, :claim_name, :normalized, :txo_hash, :tx_position, :amount,
:amount, :is_channel, :public_key_bytes, :timestamp, :height, :height, :is_channel, :timestamp, :timestamp, :height, :height,
:channel_hash, :channel_height, :is_channel_signature_valid,
CASE WHEN :release_time IS NOT NULL THEN :release_time ELSE :timestamp END, CASE WHEN :release_time IS NOT NULL THEN :release_time ELSE :timestamp END,
CASE WHEN :normalized NOT IN (SELECT normalized FROM claimtrie) THEN :height END, CASE WHEN :normalized NOT IN (SELECT normalized FROM claimtrie) THEN :height END,
CASE WHEN :height >= 262974 THEN :height+2102400 ELSE :height+262974 END, CASE WHEN :height >= 262974 THEN :height+2102400 ELSE :height+262974 END,
:normalized||COALESCE( :normalized||COALESCE(
(SELECT shortest_id(claim_id, :claim_id) FROM claim WHERE normalized = :normalized), (SELECT shortest_id(claim_id, :claim_id) FROM claim WHERE normalized = :normalized),
'#'||substr(:claim_id, 1, 1) '#'||substr(:claim_id, 1, 1)
), )
CASE WHEN :is_channel_signature_valid = 1 THEN
(SELECT canonical FROM claim WHERE claim_hash=:channel_hash)||'/'||
:normalized||COALESCE(
(SELECT shortest_id(claim_id, :claim_id) FROM claim
WHERE normalized = :normalized AND
channel_hash = :channel_hash AND
is_channel_signature_valid = 1),
'#'||substr(:claim_id, 1, 1)
)
END
)""", claims) )""", claims)
def update_claims(self, txos: Set[Output], header, channels): def update_claims(self, txos: List[Output], header):
claims = self._upsertable_claims(txos, header, channels, clear_first=True) claims = self._upsertable_claims(txos, header, clear_first=True)
if claims: if claims:
self.db.executemany(""" self.db.executemany("""
UPDATE claim SET UPDATE claim SET
txo_hash=:txo_hash, tx_position=:tx_position, height=:height, amount=:amount, txo_hash=:txo_hash, tx_position=:tx_position, amount=:amount, height=:height, timestamp=:timestamp,
public_key_bytes=:public_key_bytes, timestamp=:timestamp, release_time=CASE WHEN :release_time IS NOT NULL THEN :release_time ELSE release_time END
release_time=CASE WHEN :release_time IS NOT NULL THEN :release_time ELSE release_time END,
channel_hash=:channel_hash, is_channel_signature_valid=:is_channel_signature_valid,
channel_height=CASE
WHEN channel_hash = :channel_hash AND :is_channel_signature_valid THEN channel_height
WHEN :is_channel_signature_valid THEN :height
END,
channel_canonical=CASE
WHEN channel_hash = :channel_hash AND :is_channel_signature_valid THEN channel_canonical
WHEN :is_channel_signature_valid THEN
(SELECT canonical FROM claim WHERE claim_hash=:channel_hash)||'/'||
:normalized||COALESCE(
(SELECT shortest_id(claim_id, :claim_id) FROM claim
WHERE normalized = :normalized AND
channel_hash = :channel_hash AND
is_channel_signature_valid = 1),
'#'||substr(:claim_id, 1, 1)
)
END
WHERE claim_hash=:claim_hash; WHERE claim_hash=:claim_hash;
""", claims) """, claims)
@@ -346,13 +312,6 @@ class SQLDB:
for table in ('tag',): # 'language', 'location', etc for table in ('tag',): # 'language', 'location', etc
self.execute(*self._delete_sql(table, {'claim_hash__in': binary_claim_hashes})) self.execute(*self._delete_sql(table, {'claim_hash__in': binary_claim_hashes}))
def invalidate_channel_signatures(self, channels):
self.execute(f"""
UPDATE claim SET
channel_height=NULL, channel_canonical=NULL, is_channel_signature_valid=0
WHERE channel_hash IN ({','.join('?' for _ in channels)})
""", [sqlite3.Binary(channel) for channel in channels])
def split_inputs_into_claims_supports_and_other(self, txis): def split_inputs_into_claims_supports_and_other(self, txis):
txo_hashes = {txi.txo_ref.hash for txi in txis} txo_hashes = {txi.txo_ref.hash for txi in txis}
claims = self.execute(*query( claims = self.execute(*query(
@@ -369,21 +328,7 @@ class SQLDB:
txo_hashes -= {r['txo_hash'] for r in supports} txo_hashes -= {r['txo_hash'] for r in supports}
return claims, supports, txo_hashes return claims, supports, txo_hashes
def get_channel_public_keys_for_outputs(self, txos): def insert_supports(self, txos: List[Output]):
channels = set()
for txo in txos:
try:
channel_hash = txo.claim.signing_channel_hash
if channel_hash:
channels.add(channel_hash)
except:
pass
return dict(self.execute(*query(
"SELECT claim_hash, public_key_bytes FROM claim",
claim_hash__in=[sqlite3.Binary(channel) for channel in channels]
)).fetchall())
def insert_supports(self, txos: Set[Output]):
supports = [] supports = []
for txo in txos: for txo in txos:
tx = txo.tx_ref.tx tx = txo.tx_ref.tx
@@ -405,6 +350,129 @@ class SQLDB:
'support', {'txo_hash__in': [sqlite3.Binary(txo_hash) for txo_hash in txo_hashes]} 'support', {'txo_hash__in': [sqlite3.Binary(txo_hash) for txo_hash in txo_hashes]}
)) ))
def validate_channel_signatures(self, height, new_claims, updated_claims):
if not new_claims and not updated_claims:
return
channels, new_channel_keys, signables = {}, {}, {}
for txo in chain(new_claims, updated_claims):
try:
claim = txo.claim
except:
continue
if claim.is_channel:
channels[txo.claim_hash] = txo
new_channel_keys[txo.claim_hash] = claim.channel.public_key_bytes
else:
signables[txo.claim_hash] = txo
missing_channel_keys = set()
for txo in signables.values():
claim = txo.claim
if claim.is_signed and claim.signing_channel_hash not in new_channel_keys:
missing_channel_keys.add(claim.signing_channel_hash)
all_channel_keys = {}
if new_channel_keys or missing_channel_keys:
all_channel_keys = dict(self.execute(*query(
"SELECT claim_hash, public_key_bytes FROM claim",
claim_hash__in=[
sqlite3.Binary(channel_hash) for channel_hash in
set(new_channel_keys) | missing_channel_keys
]
)))
changed_channel_keys = {}
for claim_hash, new_key in new_channel_keys.items():
if all_channel_keys[claim_hash] != new_key:
all_channel_keys[claim_hash] = new_key
changed_channel_keys[claim_hash] = new_key
claim_updates = []
for claim_hash, txo in signables.items():
claim = txo.claim
update = {
'claim_hash': sqlite3.Binary(claim_hash),
'channel_hash': None,
'signature': None,
'signature_digest': None,
'is_channel_signature_valid': False
}
if claim.is_signed:
update.update({
'channel_hash': sqlite3.Binary(claim.signing_channel_hash),
'signature': sqlite3.Binary(txo.get_encoded_signature()),
'signature_digest': sqlite3.Binary(txo.get_signature_digest(self.ledger))
})
claim_updates.append(update)
if changed_channel_keys:
sql = f"""
SELECT * FROM claim WHERE
channel_hash IN ({','.join('?' for _ in changed_channel_keys)}) AND
signature IS NOT NULL
"""
for affected_claim in self.execute(sql, [sqlite3.Binary(h) for h in changed_channel_keys]):
if affected_claim['claim_hash'] not in signables:
claim_updates.append({
'claim_hash': sqlite3.Binary(affected_claim['claim_hash']),
'channel_hash': sqlite3.Binary(affected_claim['channel_hash']),
'signature': sqlite3.Binary(affected_claim['signature']),
'signature_digest': sqlite3.Binary(affected_claim['signature_digest']),
'is_channel_signature_valid': False
})
for update in claim_updates:
channel_pub_key = all_channel_keys.get(update['channel_hash'])
if channel_pub_key and update['signature']:
update['is_channel_signature_valid'] = Output.is_signature_valid(
bytes(update['signature']), bytes(update['signature_digest']), channel_pub_key
)
if claim_updates:
self.db.executemany(f"""
UPDATE claim SET
channel_hash=:channel_hash, signature=:signature, signature_digest=:signature_digest,
is_channel_signature_valid=:is_channel_signature_valid,
channel_join=CASE
WHEN is_channel_signature_valid AND :is_channel_signature_valid THEN channel_join
WHEN :is_channel_signature_valid THEN {height}
END,
canonical_url=CASE
WHEN is_channel_signature_valid AND :is_channel_signature_valid THEN canonical_url
WHEN :is_channel_signature_valid THEN
(SELECT short_url FROM claim WHERE claim_hash=:channel_hash)||'/'||
normalized||COALESCE(
(SELECT shortest_id(other_claim.claim_id, claim.claim_id) FROM claim AS other_claim
WHERE other_claim.normalized = claim.normalized AND
other_claim.channel_hash = :channel_hash AND
other_claim.is_channel_signature_valid = 1),
'#'||substr(claim_id, 1, 1)
)
END
WHERE claim_hash=:claim_hash;
""", claim_updates)
if channels:
self.db.executemany(
"UPDATE claim SET public_key_bytes=:public_key_bytes WHERE claim_hash=:claim_hash", [{
'claim_hash': sqlite3.Binary(claim_hash),
'public_key_bytes': sqlite3.Binary(txo.claim.channel.public_key_bytes)
} for claim_hash, txo in channels.items()]
)
if all_channel_keys:
self.db.executemany(f"""
UPDATE claim SET
claims_in_channel=(
SELECT COUNT(*) FROM claim AS claim_in_channel
WHERE claim_in_channel.channel_hash=claim.claim_hash AND
claim_in_channel.is_channel_signature_valid
)
WHERE claim_hash = ?
""", [(sqlite3.Binary(channel_hash),) for channel_hash in all_channel_keys.keys()])
def _update_support_amount(self, claim_hashes): def _update_support_amount(self, claim_hashes):
if claim_hashes: if claim_hashes:
self.execute(f""" self.execute(f"""
@ -501,20 +569,21 @@ class SQLDB:
r(self._perform_overtake, height, [], []) r(self._perform_overtake, height, [], [])
def advance_txs(self, height, all_txs, header, daemon_height, timer): def advance_txs(self, height, all_txs, header, daemon_height, timer):
insert_claims = set() insert_claims = []
update_claims = set() update_claims = []
delete_claim_hashes = set() delete_claim_hashes = set()
insert_supports = set() insert_supports = []
delete_support_txo_hashes = set() delete_support_txo_hashes = set()
recalculate_claim_hashes = set() # added/deleted supports, added/updated claim recalculate_claim_hashes = set() # added/deleted supports, added/updated claim
deleted_claim_names = set() deleted_claim_names = set()
delete_others = set()
body_timer = timer.add_timer('body') body_timer = timer.add_timer('body')
for position, (etx, txid) in enumerate(all_txs): for position, (etx, txid) in enumerate(all_txs):
tx = timer.run( tx = timer.run(
Transaction, etx.serialize(), height=height, position=position Transaction, etx.serialize(), height=height, position=position
) )
# Inputs # Inputs
spent_claims, spent_supports, spent_other = timer.run( spent_claims, spent_supports, spent_others = timer.run(
self.split_inputs_into_claims_supports_and_other, tx.inputs self.split_inputs_into_claims_supports_and_other, tx.inputs
) )
body_timer.start() body_timer.start()
@ -522,28 +591,38 @@ class SQLDB:
delete_support_txo_hashes.update({r['txo_hash'] for r in spent_supports}) delete_support_txo_hashes.update({r['txo_hash'] for r in spent_supports})
deleted_claim_names.update({r['normalized'] for r in spent_claims}) deleted_claim_names.update({r['normalized'] for r in spent_claims})
recalculate_claim_hashes.update({r['claim_hash'] for r in spent_supports}) recalculate_claim_hashes.update({r['claim_hash'] for r in spent_supports})
delete_others.update(spent_others)
# Outputs # Outputs
for output in tx.outputs: for output in tx.outputs:
if output.is_support: if output.is_support:
insert_supports.add(output) insert_supports.append(output)
recalculate_claim_hashes.add(output.claim_hash) recalculate_claim_hashes.add(output.claim_hash)
elif output.script.is_claim_name: elif output.script.is_claim_name:
insert_claims.add(output) insert_claims.append(output)
recalculate_claim_hashes.add(output.claim_hash) recalculate_claim_hashes.add(output.claim_hash)
elif output.script.is_update_claim: elif output.script.is_update_claim:
claim_hash = output.claim_hash claim_hash = output.claim_hash
if claim_hash in delete_claim_hashes: update_claims.append(output)
delete_claim_hashes.remove(claim_hash) recalculate_claim_hashes.add(claim_hash)
update_claims.add(output) delete_claim_hashes.discard(claim_hash)
recalculate_claim_hashes.add(output.claim_hash) delete_others.discard(output.ref.hash) # claim insertion and update occurring in the same block
body_timer.stop() body_timer.stop()
channel_public_keys = self.get_channel_public_keys_for_outputs(insert_claims | update_claims) skip_claim_timer = timer.add_timer('skip insertion of abandoned claims')
skip_claim_timer.start()
for new_claim in list(insert_claims):
if new_claim.ref.hash in delete_others:
insert_claims.remove(new_claim)
self.logger.info(
f"Skipping insertion of claim '{new_claim.id}' due to "
f"an abandon of it in the same block {height}."
)
skip_claim_timer.stop()
r = timer.run r = timer.run
r(self.delete_claims, delete_claim_hashes) r(self.delete_claims, delete_claim_hashes)
r(self.delete_supports, delete_support_txo_hashes) r(self.delete_supports, delete_support_txo_hashes)
r(self.invalidate_channel_signatures, recalculate_claim_hashes) r(self.insert_claims, insert_claims, header)
r(self.insert_claims, insert_claims, header, channel_public_keys) r(self.update_claims, update_claims, header)
r(self.update_claims, update_claims, header, channel_public_keys) r(self.validate_channel_signatures, height, insert_claims, update_claims)
r(self.insert_supports, insert_supports) r(self.insert_supports, insert_supports)
r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True) r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True)
r(calculate_trending, self.db, height, self.main.first_sync, daemon_height) r(calculate_trending, self.db, height, self.main.first_sync, daemon_height)
@ -649,23 +728,20 @@ class SQLDB:
""" """
claimtrie.claim_hash as is_controlling, claimtrie.claim_hash as is_controlling,
claim.claim_hash, claim.txo_hash, claim.height, claim.claim_hash, claim.txo_hash, claim.height,
claim.activation_height, claim.effective_amount, claim.support_amount, claim.is_channel, claim.claims_in_channel,
claim.activation_height, claim.expiration_height,
claim.effective_amount, claim.support_amount,
claim.trending_group, claim.trending_mixed, claim.trending_group, claim.trending_mixed,
claim.trending_local, claim.trending_global, claim.trending_local, claim.trending_global,
CASE WHEN claim.is_channel_signature_valid = 1 claim.short_url, claim.canonical_url,
THEN claim.channel_canonical claim.channel_hash, channel.txo_hash AS channel_txo_hash,
ELSE claim.canonical channel.height AS channel_height, claim.is_channel_signature_valid
END AS canonical,
CASE WHEN claim.is_channel=1 THEN (
SELECT COUNT(*) FROM claim as claim_in_channel
WHERE claim_in_channel.channel_hash=claim.claim_hash
) ELSE 0 END AS claims_in_channel,
channel.txo_hash AS channel_txo_hash, channel.height AS channel_height
""", **constraints """, **constraints
) )
INTEGER_PARAMS = { INTEGER_PARAMS = {
'height', 'activation_height', 'release_time', 'publish_time', 'height', 'creation_height', 'activation_height', 'tx_position',
'release_time', 'timestamp',
'amount', 'effective_amount', 'support_amount', 'amount', 'effective_amount', 'support_amount',
'trending_group', 'trending_mixed', 'trending_group', 'trending_mixed',
'trending_local', 'trending_global', 'trending_local', 'trending_global',
@ -684,7 +760,7 @@ class SQLDB:
'name', 'name',
} | INTEGER_PARAMS } | INTEGER_PARAMS
def search(self, constraints) -> Tuple[List, int, int]: def search(self, constraints) -> Tuple[List, List, int, int]:
assert set(constraints).issubset(self.SEARCH_PARAMS), \ assert set(constraints).issubset(self.SEARCH_PARAMS), \
f"Search query contains invalid arguments: {set(constraints).difference(self.SEARCH_PARAMS)}" f"Search query contains invalid arguments: {set(constraints).difference(self.SEARCH_PARAMS)}"
total = self.get_claims_count(**constraints) total = self.get_claims_count(**constraints)
@ -693,10 +769,15 @@ class SQLDB:
if 'order_by' not in constraints: if 'order_by' not in constraints:
constraints['order_by'] = ["height", "^name"] constraints['order_by'] = ["height", "^name"]
txo_rows = self._search(**constraints) txo_rows = self._search(**constraints)
return txo_rows, constraints['offset'], total channel_hashes = set(txo['channel_hash'] for txo in txo_rows if txo['channel_hash'])
extra_txo_rows = []
if channel_hashes:
extra_txo_rows = self._search(**{'claim.claim_hash__in': [sqlite3.Binary(h) for h in channel_hashes]})
return txo_rows, extra_txo_rows, constraints['offset'], total
def resolve(self, urls) -> List: def resolve(self, urls) -> Tuple[List, List]:
result = [] result = []
channel_hashes = set()
for raw_url in urls: for raw_url in urls:
try: try:
url = URL.parse(raw_url) url = URL.parse(raw_url)
@ -723,12 +804,17 @@ class SQLDB:
matches = self._search(**query) matches = self._search(**query)
if matches: if matches:
result.append(matches[0]) result.append(matches[0])
if matches[0]['channel_hash']:
channel_hashes.add(matches[0]['channel_hash'])
else: else:
result.append(LookupError(f'Could not find stream in "{raw_url}".')) result.append(LookupError(f'Could not find stream in "{raw_url}".'))
continue continue
else: else:
result.append(channel) result.append(channel)
return result extra_txo_rows = []
if channel_hashes:
extra_txo_rows = self._search(**{'claim.claim_hash__in': [sqlite3.Binary(h) for h in channel_hashes]})
return result, extra_txo_rows
class LBRYDB(DB): class LBRYDB(DB):

View file

@ -51,7 +51,7 @@ class LBRYElectrumX(ElectrumX):
return Outputs.to_base64(*self.db.sql.search(kwargs)) return Outputs.to_base64(*self.db.sql.search(kwargs))
async def claimtrie_resolve(self, *urls): async def claimtrie_resolve(self, *urls):
return Outputs.to_base64(self.db.sql.resolve(urls)) return Outputs.to_base64(*self.db.sql.resolve(urls))
async def get_server_height(self): async def get_server_height(self):
return self.bp.height return self.bp.height

View file

@ -97,8 +97,7 @@ class Output(BaseOutput):
def has_private_key(self): def has_private_key(self):
return self.private_key is not None return self.private_key is not None
def is_signed_by(self, channel: 'Output', ledger=None, public_key_bytes=None): def get_signature_digest(self, ledger):
public_key_bytes = public_key_bytes or channel.claim.channel.public_key_bytes
if self.claim.unsigned_payload: if self.claim.unsigned_payload:
pieces = [ pieces = [
Base58.decode(self.get_address(ledger)), Base58.decode(self.get_address(ledger)),
@ -111,20 +110,31 @@ class Output(BaseOutput):
self.claim.signing_channel_hash, self.claim.signing_channel_hash,
self.claim.to_message_bytes() self.claim.to_message_bytes()
] ]
digest = sha256(b''.join(pieces)) return sha256(b''.join(pieces))
public_key = load_der_public_key(public_key_bytes, default_backend())
hash = hashes.SHA256() def get_encoded_signature(self):
signature = hexlify(self.claim.signature) signature = hexlify(self.claim.signature)
r = int(signature[:int(len(signature)/2)], 16) r = int(signature[:int(len(signature)/2)], 16)
s = int(signature[int(len(signature)/2):], 16) s = int(signature[int(len(signature)/2):], 16)
encoded_sig = ecdsa.util.sigencode_der(r, s, len(signature)*4) return ecdsa.util.sigencode_der(r, s, len(signature)*4)
@staticmethod
def is_signature_valid(encoded_signature, signature_digest, public_key_bytes):
try: try:
public_key.verify(encoded_sig, digest, ec.ECDSA(Prehashed(hash))) public_key = load_der_public_key(public_key_bytes, default_backend())
public_key.verify(encoded_signature, signature_digest, ec.ECDSA(Prehashed(hashes.SHA256())))
return True return True
except (ValueError, InvalidSignature): except (ValueError, InvalidSignature):
pass pass
return False return False
def is_signed_by(self, channel: 'Output', ledger=None):
return self.is_signature_valid(
self.get_encoded_signature(),
self.get_signature_digest(ledger),
channel.claim.channel.public_key_bytes
)
def sign(self, channel: 'Output', first_input_id=None): def sign(self, channel: 'Output', first_input_id=None):
self.channel = channel self.channel = channel
self.claim.signing_channel_hash = channel.claim_hash self.claim.signing_channel_hash = channel.claim_hash

View file

@ -1,7 +1,7 @@
import unittest import unittest
import ecdsa import ecdsa
import hashlib import hashlib
from binascii import hexlify, unhexlify from binascii import hexlify
from torba.client.constants import COIN, NULL_HASH32 from torba.client.constants import COIN, NULL_HASH32
from lbrynet.schema.claim import Claim from lbrynet.schema.claim import Claim
@ -54,22 +54,21 @@ class TestSQLDB(unittest.TestCase):
self._txos[output.ref.hash] = output self._txos[output.ref.hash] = output
return OldWalletServerTransaction(tx), tx.hash return OldWalletServerTransaction(tx), tx.hash
def get_channel(self, title, amount, name='@foo'): def _set_channel_key(self, channel, key):
claim = Claim()
claim.channel.title = title
channel = Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc')
# deterministic private key
private_key = ecdsa.SigningKey.from_string(b'c'*32, curve=ecdsa.SECP256k1, hashfunc=hashlib.sha256)
channel.private_key = private_key.to_pem().decode()
channel.claim.channel.public_key_bytes = private_key.get_verifying_key().to_der()
channel.script.generate()
return self._make_tx(channel)
def get_channel_update(self, channel, amount, key=b'd'):
private_key = ecdsa.SigningKey.from_string(key*32, curve=ecdsa.SECP256k1, hashfunc=hashlib.sha256) private_key = ecdsa.SigningKey.from_string(key*32, curve=ecdsa.SECP256k1, hashfunc=hashlib.sha256)
channel.private_key = private_key.to_pem().decode() channel.private_key = private_key.to_pem().decode()
channel.claim.channel.public_key_bytes = private_key.get_verifying_key().to_der() channel.claim.channel.public_key_bytes = private_key.get_verifying_key().to_der()
channel.script.generate() channel.script.generate()
def get_channel(self, title, amount, name='@foo', key=b'a'):
claim = Claim()
claim.channel.title = title
channel = Output.pay_claim_name_pubkey_hash(amount, name, claim, b'abc')
self._set_channel_key(channel, key)
return self._make_tx(channel)
def get_channel_update(self, channel, amount, key=b'a'):
self._set_channel_key(channel, key)
return self._make_tx( return self._make_tx(
Output.pay_update_claim_pubkey_hash( Output.pay_update_claim_pubkey_hash(
amount, channel.claim_name, channel.claim_id, channel.claim, b'abc' amount, channel.claim_name, channel.claim_id, channel.claim, b'abc'
@ -313,57 +312,81 @@ class TestSQLDB(unittest.TestCase):
@staticmethod @staticmethod
def _get_x_with_claim_id_prefix(getter, prefix, cached_iteration=None, **kwargs): def _get_x_with_claim_id_prefix(getter, prefix, cached_iteration=None, **kwargs):
iterations = 100 iterations = cached_iteration+1 if cached_iteration else 100
for i in range(cached_iteration or 1, iterations): for i in range(cached_iteration or 1, iterations):
stream = getter(f'claim #{i}', COIN, **kwargs) stream = getter(f'claim #{i}', COIN, **kwargs)
if stream[0].tx.outputs[0].claim_id.startswith(prefix): if stream[0].tx.outputs[0].claim_id.startswith(prefix):
cached_iteration is None and print(f'Found "{prefix}" in {i} iterations.') cached_iteration is None and print(f'Found "{prefix}" in {i} iterations.')
return stream return stream
raise ValueError(f'Failed to find "{prefix}" in {iterations} iterations.') if cached_iteration:
raise ValueError(f'Failed to find "{prefix}" at cached iteration, run with None to find iteration.')
raise ValueError(f'Failed to find "{prefix}" in {iterations} iterations, try different values.')
def get_channel_with_claim_id_prefix(self, prefix, cached_iteration=None): def get_channel_with_claim_id_prefix(self, prefix, cached_iteration=None, **kwargs):
return self._get_x_with_claim_id_prefix(self.get_channel, prefix, cached_iteration) return self._get_x_with_claim_id_prefix(self.get_channel, prefix, cached_iteration, **kwargs)
def get_stream_with_claim_id_prefix(self, prefix, cached_iteration=None, **kwargs): def get_stream_with_claim_id_prefix(self, prefix, cached_iteration=None, **kwargs):
return self._get_x_with_claim_id_prefix(self.get_stream, prefix, cached_iteration, **kwargs) return self._get_x_with_claim_id_prefix(self.get_stream, prefix, cached_iteration, **kwargs)
def test_canonical_name(self): def test_canonical_url_and_channel_validation(self):
advance = self.advance advance = self.advance
tx_chan_a = self.get_channel_with_claim_id_prefix('a', 1) tx_chan_a = self.get_channel_with_claim_id_prefix('a', 1, key=b'c')
tx_chan_ab = self.get_channel_with_claim_id_prefix('ab', 72) tx_chan_ab = self.get_channel_with_claim_id_prefix('ab', 72, key=b'c')
txo_chan_a = tx_chan_a[0].tx.outputs[0] txo_chan_a = tx_chan_a[0].tx.outputs[0]
advance(1, [tx_chan_a]) advance(1, [tx_chan_a])
advance(2, [tx_chan_ab]) advance(2, [tx_chan_ab])
r_ab, r_a = self.sql._search(order_by=['height'], limit=2) r_ab, r_a = self.sql._search(order_by=['creation_height'], limit=2)
self.assertEqual("@foo#a", r_a['canonical']) self.assertEqual("@foo#a", r_a['short_url'])
self.assertEqual("@foo#ab", r_ab['canonical']) self.assertEqual("@foo#ab", r_ab['short_url'])
self.assertIsNone(r_a['canonical_url'])
self.assertIsNone(r_ab['canonical_url'])
self.assertEqual(0, r_a['claims_in_channel'])
self.assertEqual(0, r_ab['claims_in_channel'])
tx_a = self.get_stream_with_claim_id_prefix('a', 2) tx_a = self.get_stream_with_claim_id_prefix('a', 2)
tx_ab = self.get_stream_with_claim_id_prefix('ab', 42) tx_ab = self.get_stream_with_claim_id_prefix('ab', 42)
tx_abc = self.get_stream_with_claim_id_prefix('abc', 65) tx_abc = self.get_stream_with_claim_id_prefix('abc', 65)
advance(3, [tx_a]) advance(3, [tx_a])
advance(4, [tx_ab]) advance(4, [tx_ab, tx_abc])
advance(5, [tx_abc]) r_abc, r_ab, r_a = self.sql._search(order_by=['creation_height', 'tx_position'], limit=3)
r_abc, r_ab, r_a = self.sql._search(order_by=['height'], limit=3) self.assertEqual("foo#a", r_a['short_url'])
self.assertEqual("foo#a", r_a['canonical']) self.assertEqual("foo#ab", r_ab['short_url'])
self.assertEqual("foo#ab", r_ab['canonical']) self.assertEqual("foo#abc", r_abc['short_url'])
self.assertEqual("foo#abc", r_abc['canonical']) self.assertIsNone(r_a['canonical_url'])
self.assertIsNone(r_ab['canonical_url'])
self.assertIsNone(r_abc['canonical_url'])
tx_a2 = self.get_stream_with_claim_id_prefix('a', 7, channel=txo_chan_a) tx_a2 = self.get_stream_with_claim_id_prefix('a', 7, channel=txo_chan_a)
tx_ab2 = self.get_stream_with_claim_id_prefix('ab', 23, channel=txo_chan_a) tx_ab2 = self.get_stream_with_claim_id_prefix('ab', 23, channel=txo_chan_a)
a2_claim_id = tx_a2[0].tx.outputs[0].claim_id
ab2_claim_id = tx_ab2[0].tx.outputs[0].claim_id
advance(6, [tx_a2]) advance(6, [tx_a2])
advance(7, [tx_ab2]) advance(7, [tx_ab2])
r_ab2, r_a2 = self.sql._search(order_by=['height'], limit=2) r_ab2, r_a2 = self.sql._search(order_by=['creation_height'], limit=2)
self.assertEqual("@foo#a/foo#a", r_a2['canonical']) self.assertEqual(f"foo#{a2_claim_id[:2]}", r_a2['short_url'])
self.assertEqual("@foo#a/foo#ab", r_ab2['canonical']) self.assertEqual(f"foo#{ab2_claim_id[:4]}", r_ab2['short_url'])
self.assertEqual("@foo#a/foo#a", r_a2['canonical_url'])
self.assertEqual("@foo#a/foo#ab", r_ab2['canonical_url'])
self.assertEqual(2, self.sql._search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel'])
advance(8, [self.get_channel_update(txo_chan_a, COIN)]) # invalidate channel signature
_, r_ab2, r_a2 = self.sql._search(order_by=['height'], limit=3) advance(8, [self.get_channel_update(txo_chan_a, COIN, key=b'a')])
a2_claim_id = hexlify(r_a2['claim_hash'][::-1]).decode() r_ab2, r_a2 = self.sql._search(order_by=['creation_height'], limit=2)
ab2_claim_id = hexlify(r_ab2['claim_hash'][::-1]).decode() self.assertEqual(f"foo#{a2_claim_id[:2]}", r_a2['short_url'])
self.assertEqual(f"foo#{a2_claim_id[:2]}", r_a2['canonical']) self.assertEqual(f"foo#{ab2_claim_id[:4]}", r_ab2['short_url'])
self.assertEqual(f"foo#{ab2_claim_id[:4]}", r_ab2['canonical']) self.assertIsNone(r_a2['canonical_url'])
self.assertIsNone(r_ab2['canonical_url'])
self.assertEqual(0, self.sql._search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel'])
# re-validate signature (reverts signature to original one)
advance(9, [self.get_channel_update(txo_chan_a, COIN, key=b'c')])
r_ab2, r_a2 = self.sql._search(order_by=['creation_height'], limit=2)
self.assertEqual(f"foo#{a2_claim_id[:2]}", r_a2['short_url'])
self.assertEqual(f"foo#{ab2_claim_id[:4]}", r_ab2['short_url'])
self.assertEqual("@foo#a/foo#a", r_a2['canonical_url'])
self.assertEqual("@foo#a/foo#ab", r_ab2['canonical_url'])
self.assertEqual(2, self.sql._search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel'])
def test_canonical_find_shortest_id(self): def test_canonical_find_shortest_id(self):
new_hash = 'abcdef0123456789beef' new_hash = 'abcdef0123456789beef'