fix resolve by short id

This commit is contained in:
Jack Robison 2021-07-17 22:22:14 -04:00 committed by Victor Shyba
parent 749e64b101
commit 2180e24bc1
4 changed files with 89 additions and 32 deletions

View file

@ -171,13 +171,15 @@ class StagedClaimtrieItem(typing.NamedTuple):
) )
), ),
# short url resolution # short url resolution
]
ops.extend([
op( op(
*Prefixes.claim_short_id.pack_item( *Prefixes.claim_short_id.pack_item(
self.name, self.claim_hash, self.root_tx_num, self.root_position, self.tx_num, self.name, self.claim_hash.hex()[:prefix_len + 1], self.root_tx_num, self.root_position,
self.position self.tx_num, self.position
) )
) ) for prefix_len in range(10)
] ])
if self.signing_hash and self.channel_signature_is_valid: if self.signing_hash and self.channel_signature_is_valid:
ops.extend([ ops.extend([

View file

@ -15,6 +15,10 @@ def length_encoded_name(name: str) -> bytes:
return len(encoded).to_bytes(2, byteorder='big') + encoded return len(encoded).to_bytes(2, byteorder='big') + encoded
def length_prefix(key: str) -> bytes:
return len(key).to_bytes(1, byteorder='big') + key.encode()
class PrefixRow: class PrefixRow:
prefix: bytes prefix: bytes
key_struct: struct.Struct key_struct: struct.Struct
@ -187,12 +191,12 @@ class TXOToClaimValue(typing.NamedTuple):
class ClaimShortIDKey(typing.NamedTuple): class ClaimShortIDKey(typing.NamedTuple):
name: str name: str
claim_hash: bytes partial_claim_id: str
root_tx_num: int root_tx_num: int
root_position: int root_position: int
def __str__(self): def __str__(self):
return f"{self.__class__.__name__}(name={self.name}, claim_hash={self.claim_hash.hex()}, " \ return f"{self.__class__.__name__}(name={self.name}, partial_claim_id={self.partial_claim_id}, " \
f"root_tx_num={self.root_tx_num}, root_position={self.root_position})" f"root_tx_num={self.root_tx_num}, root_position={self.root_position})"
@ -517,26 +521,25 @@ def shortid_key_helper(struct_fmt):
return wrapper return wrapper
def shortid_key_partial_claim_helper(name: str, partial_claim_hash: bytes): def shortid_key_partial_claim_helper(name: str, partial_claim_id: str):
assert len(partial_claim_hash) <= 20 assert len(partial_claim_id) < 40
return length_encoded_name(name) + partial_claim_hash return length_encoded_name(name) + length_prefix(partial_claim_id)
class ClaimShortIDPrefixRow(PrefixRow): class ClaimShortIDPrefixRow(PrefixRow):
prefix = DB_PREFIXES.claim_short_id_prefix.value prefix = DB_PREFIXES.claim_short_id_prefix.value
key_struct = struct.Struct(b'>20sLH') key_struct = struct.Struct(b'>LH')
value_struct = struct.Struct(b'>LH') value_struct = struct.Struct(b'>LH')
key_part_lambdas = [ key_part_lambdas = [
lambda: b'', lambda: b'',
length_encoded_name, length_encoded_name,
shortid_key_partial_claim_helper, shortid_key_partial_claim_helper
shortid_key_helper(b'>20sL'),
shortid_key_helper(b'>20sLH'),
] ]
@classmethod @classmethod
def pack_key(cls, name: str, claim_hash: bytes, root_tx_num: int, root_position: int): def pack_key(cls, name: str, short_claim_id: str, root_tx_num: int, root_position: int):
return cls.prefix + length_encoded_name(name) + cls.key_struct.pack(claim_hash, root_tx_num, root_position) return cls.prefix + length_encoded_name(name) + length_prefix(short_claim_id) +\
cls.key_struct.pack(root_tx_num, root_position)
@classmethod @classmethod
def pack_value(cls, tx_num: int, position: int): def pack_value(cls, tx_num: int, position: int):
@ -547,16 +550,18 @@ class ClaimShortIDPrefixRow(PrefixRow):
assert key[:1] == cls.prefix assert key[:1] == cls.prefix
name_len = int.from_bytes(key[1:3], byteorder='big') name_len = int.from_bytes(key[1:3], byteorder='big')
name = key[3:3 + name_len].decode() name = key[3:3 + name_len].decode()
return ClaimShortIDKey(name, *cls.key_struct.unpack(key[3 + name_len:])) claim_id_len = int.from_bytes(key[3+name_len:4+name_len], byteorder='big')
partial_claim_id = key[4+name_len:4+name_len+claim_id_len].decode()
return ClaimShortIDKey(name, partial_claim_id, *cls.key_struct.unpack(key[4 + name_len + claim_id_len:]))
@classmethod @classmethod
def unpack_value(cls, data: bytes) -> ClaimShortIDValue: def unpack_value(cls, data: bytes) -> ClaimShortIDValue:
return ClaimShortIDValue(*super().unpack_value(data)) return ClaimShortIDValue(*super().unpack_value(data))
@classmethod @classmethod
def pack_item(cls, name: str, claim_hash: bytes, root_tx_num: int, root_position: int, def pack_item(cls, name: str, partial_claim_id: str, root_tx_num: int, root_position: int,
tx_num: int, position: int): tx_num: int, position: int):
return cls.pack_key(name, claim_hash, root_tx_num, root_position), \ return cls.pack_key(name, partial_claim_id, root_tx_num, root_position), \
cls.pack_value(tx_num, position) cls.pack_value(tx_num, position)

View file

@ -209,6 +209,17 @@ class LevelDB:
supports.append((unpacked_k.tx_num, unpacked_k.position, unpacked_v.amount)) supports.append((unpacked_k.tx_num, unpacked_k.position, unpacked_v.amount))
return supports return supports
def get_short_claim_id_url(self, name: str, claim_hash: bytes, root_tx_num: int, root_position: int) -> str:
claim_id = claim_hash.hex()
for prefix_len in range(10):
prefix = Prefixes.claim_short_id.pack_partial_key(name, claim_id[:prefix_len+1])
for _k in self.db.iterator(prefix=prefix, include_value=False):
k = Prefixes.claim_short_id.unpack_key(_k)
if k.root_tx_num == root_tx_num and k.root_position == root_position:
return f'{name}#{k.partial_claim_id}'
break
raise Exception('wat')
def _prepare_resolve_result(self, tx_num: int, position: int, claim_hash: bytes, name: str, root_tx_num: int, def _prepare_resolve_result(self, tx_num: int, position: int, claim_hash: bytes, name: str, root_tx_num: int,
root_position: int, activation_height: int, signature_valid: bool) -> ResolveResult: root_position: int, activation_height: int, signature_valid: bool) -> ResolveResult:
controlling_claim = self.get_controlling_claim(name) controlling_claim = self.get_controlling_claim(name)
@ -225,8 +236,7 @@ class LevelDB:
effective_amount = support_amount + claim_amount effective_amount = support_amount + claim_amount
channel_hash = self.get_channel_for_claim(claim_hash, tx_num, position) channel_hash = self.get_channel_for_claim(claim_hash, tx_num, position)
reposted_claim_hash = self.get_repost(claim_hash) reposted_claim_hash = self.get_repost(claim_hash)
short_url = self.get_short_claim_id_url(name, claim_hash, root_tx_num, root_position)
short_url = f'{name}#{claim_hash.hex()}'
canonical_url = short_url canonical_url = short_url
claims_in_channel = self.get_claims_in_channel_count(claim_hash) claims_in_channel = self.get_claims_in_channel_count(claim_hash)
if channel_hash: if channel_hash:
@ -264,15 +274,24 @@ class LevelDB:
amount_order = max(int(amount_order or 1), 1) amount_order = max(int(amount_order or 1), 1)
if claim_id: if claim_id:
if len(claim_id) == 40: # a full claim id
claim_txo = self.get_claim_txo(bytes.fromhex(claim_id))
if normalized_name != claim_txo.name:
return
return self._prepare_resolve_result(
claim_txo.tx_num, claim_txo.position, bytes.fromhex(claim_id), claim_txo.name,
claim_txo.root_tx_num, claim_txo.root_position,
self.get_activation(claim_txo.tx_num, claim_txo.position), claim_txo.channel_signature_is_valid
)
# resolve by partial/complete claim id # resolve by partial/complete claim id
short_claim_hash = bytes.fromhex(claim_id) prefix = Prefixes.claim_short_id.pack_partial_key(normalized_name, claim_id[:10])
prefix = Prefixes.claim_short_id.pack_partial_key(normalized_name, short_claim_hash)
for k, v in self.db.iterator(prefix=prefix): for k, v in self.db.iterator(prefix=prefix):
key = Prefixes.claim_short_id.unpack_key(k) key = Prefixes.claim_short_id.unpack_key(k)
claim_txo = Prefixes.claim_short_id.unpack_value(v) claim_txo = Prefixes.claim_short_id.unpack_value(v)
signature_is_valid = self.get_claim_txo(key.claim_hash).channel_signature_is_valid claim_hash = self.get_claim_from_txo(claim_txo.tx_num, claim_txo.position).claim_hash
signature_is_valid = self.get_claim_txo(claim_hash).channel_signature_is_valid
return self._prepare_resolve_result( return self._prepare_resolve_result(
claim_txo.tx_num, claim_txo.position, key.claim_hash, key.name, key.root_tx_num, claim_txo.tx_num, claim_txo.position, claim_hash, key.name, key.root_tx_num,
key.root_position, self.get_activation(claim_txo.tx_num, claim_txo.position), key.root_position, self.get_activation(claim_txo.tx_num, claim_txo.position),
signature_is_valid signature_is_valid
) )
@ -396,11 +415,12 @@ class LevelDB:
def get_claims_for_name(self, name): def get_claims_for_name(self, name):
claims = [] claims = []
for _k, _v in self.db.iterator(prefix=Prefixes.claim_short_id.pack_partial_key(name)): prefix = Prefixes.claim_short_id.pack_partial_key(name) + int(1).to_bytes(1, byteorder='big')
k, v = Prefixes.claim_short_id.unpack_key(_k), Prefixes.claim_short_id.unpack_value(_v) for _k, _v in self.db.iterator(prefix=prefix):
# claims[v.claim_hash] = (k, v) v = Prefixes.claim_short_id.unpack_value(_v)
if k.claim_hash not in claims: claim_hash = self.get_claim_from_txo(v.tx_num, v.position).claim_hash
claims.append(k.claim_hash) if claim_hash not in claims:
claims.append(claim_hash)
return claims return claims
def get_claims_in_channel_count(self, channel_hash) -> int: def get_claims_in_channel_count(self, channel_hash) -> int:
@ -435,10 +455,10 @@ class LevelDB:
def get_claim_txos_for_name(self, name: str): def get_claim_txos_for_name(self, name: str):
txos = {} txos = {}
for k, v in self.db.iterator(prefix=Prefixes.claim_short_id.pack_partial_key(name)): prefix = Prefixes.claim_short_id.pack_partial_key(name) + int(1).to_bytes(1, byteorder='big')
claim_hash = Prefixes.claim_short_id.unpack_key(k).claim_hash for k, v in self.db.iterator(prefix=prefix):
tx_num, nout = Prefixes.claim_short_id.unpack_value(v) tx_num, nout = Prefixes.claim_short_id.unpack_value(v)
txos[claim_hash] = tx_num, nout txos[self.get_claim_from_txo(tx_num, nout).claim_hash] = tx_num, nout
return txos return txos
def get_claim_metadata(self, tx_hash, nout): def get_claim_metadata(self, tx_hash, nout):

View file

@ -3,6 +3,7 @@ import json
import hashlib import hashlib
from bisect import bisect_right from bisect import bisect_right
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from collections import defaultdict
from lbry.testcase import CommandTestCase from lbry.testcase import CommandTestCase
from lbry.wallet.transaction import Transaction, Output from lbry.wallet.transaction import Transaction, Output
from lbry.schema.compat import OldClaimMessage from lbry.schema.compat import OldClaimMessage
@ -100,6 +101,35 @@ class BaseResolveTestCase(CommandTestCase):
class ResolveCommand(BaseResolveTestCase): class ResolveCommand(BaseResolveTestCase):
async def test_colliding_short_id(self):
prefixes = defaultdict(list)
colliding_claim_ids = []
first_claims_one_char_shortid = {}
while True:
chan = self.get_claim_id(
await self.channel_create('@abc', '0.01', allow_duplicate_name=True)
)
if chan[:1] not in first_claims_one_char_shortid:
first_claims_one_char_shortid[chan[:1]] = chan
prefixes[chan[:2]].append(chan)
if len(prefixes[chan[:2]]) > 1:
colliding_claim_ids.extend(prefixes[chan[:2]])
break
first_claim = first_claims_one_char_shortid[colliding_claim_ids[0][:1]]
await self.assertResolvesToClaimId(
f'@abc#{colliding_claim_ids[0][:1]}', first_claim
)
await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[0][:2]}', colliding_claim_ids[0])
await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[0][:7]}', colliding_claim_ids[0])
await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[0][:17]}', colliding_claim_ids[0])
await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[0]}', colliding_claim_ids[0])
await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[1][:3]}', colliding_claim_ids[1])
await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[1][:7]}', colliding_claim_ids[1])
await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[1][:17]}', colliding_claim_ids[1])
await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[1]}', colliding_claim_ids[1])
async def test_resolve_response(self): async def test_resolve_response(self):
channel_id = self.get_claim_id( channel_id = self.get_claim_id(
await self.channel_create('@abc', '0.01') await self.channel_create('@abc', '0.01')