From 2180e24bc13d87dda9b247fb0c18592a12f429d8 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Sat, 17 Jul 2021 22:22:14 -0400 Subject: [PATCH] fix resolve by short id --- lbry/wallet/server/db/claimtrie.py | 10 ++-- lbry/wallet/server/db/prefixes.py | 33 +++++++------ lbry/wallet/server/leveldb.py | 48 +++++++++++++------ .../blockchain/test_resolve_command.py | 30 ++++++++++++ 4 files changed, 89 insertions(+), 32 deletions(-) diff --git a/lbry/wallet/server/db/claimtrie.py b/lbry/wallet/server/db/claimtrie.py index f18812d88..4c688ada8 100644 --- a/lbry/wallet/server/db/claimtrie.py +++ b/lbry/wallet/server/db/claimtrie.py @@ -171,13 +171,15 @@ class StagedClaimtrieItem(typing.NamedTuple): ) ), # short url resolution + ] + ops.extend([ op( *Prefixes.claim_short_id.pack_item( - self.name, self.claim_hash, self.root_tx_num, self.root_position, self.tx_num, - self.position + self.name, self.claim_hash.hex()[:prefix_len + 1], self.root_tx_num, self.root_position, + self.tx_num, self.position ) - ) - ] + ) for prefix_len in range(10) + ]) if self.signing_hash and self.channel_signature_is_valid: ops.extend([ diff --git a/lbry/wallet/server/db/prefixes.py b/lbry/wallet/server/db/prefixes.py index 86254f394..53e17596c 100644 --- a/lbry/wallet/server/db/prefixes.py +++ b/lbry/wallet/server/db/prefixes.py @@ -15,6 +15,10 @@ def length_encoded_name(name: str) -> bytes: return len(encoded).to_bytes(2, byteorder='big') + encoded +def length_prefix(key: str) -> bytes: + return len(key).to_bytes(1, byteorder='big') + key.encode() + + class PrefixRow: prefix: bytes key_struct: struct.Struct @@ -187,12 +191,12 @@ class TXOToClaimValue(typing.NamedTuple): class ClaimShortIDKey(typing.NamedTuple): name: str - claim_hash: bytes + partial_claim_id: str root_tx_num: int root_position: int def __str__(self): - return f"{self.__class__.__name__}(name={self.name}, claim_hash={self.claim_hash.hex()}, " \ + return f"{self.__class__.__name__}(name={self.name}, partial_claim_id={self.partial_claim_id}, " \ f"root_tx_num={self.root_tx_num}, root_position={self.root_position})" @@ -517,26 +521,25 @@ def shortid_key_helper(struct_fmt): return wrapper -def shortid_key_partial_claim_helper(name: str, partial_claim_hash: bytes): - assert len(partial_claim_hash) <= 20 - return length_encoded_name(name) + partial_claim_hash +def shortid_key_partial_claim_helper(name: str, partial_claim_id: str): + assert len(partial_claim_id) < 40 + return length_encoded_name(name) + length_prefix(partial_claim_id) class ClaimShortIDPrefixRow(PrefixRow): prefix = DB_PREFIXES.claim_short_id_prefix.value - key_struct = struct.Struct(b'>20sLH') + key_struct = struct.Struct(b'>LH') value_struct = struct.Struct(b'>LH') key_part_lambdas = [ lambda: b'', length_encoded_name, - shortid_key_partial_claim_helper, - shortid_key_helper(b'>20sL'), - shortid_key_helper(b'>20sLH'), + shortid_key_partial_claim_helper ] @classmethod - def pack_key(cls, name: str, claim_hash: bytes, root_tx_num: int, root_position: int): - return cls.prefix + length_encoded_name(name) + cls.key_struct.pack(claim_hash, root_tx_num, root_position) + def pack_key(cls, name: str, short_claim_id: str, root_tx_num: int, root_position: int): + return cls.prefix + length_encoded_name(name) + length_prefix(short_claim_id) +\ + cls.key_struct.pack(root_tx_num, root_position) @classmethod def pack_value(cls, tx_num: int, position: int): @@ -547,16 +550,18 @@ class ClaimShortIDPrefixRow(PrefixRow): assert key[:1] == cls.prefix name_len = int.from_bytes(key[1:3], byteorder='big') name = key[3:3 + name_len].decode() - return ClaimShortIDKey(name, *cls.key_struct.unpack(key[3 + name_len:])) + claim_id_len = int.from_bytes(key[3+name_len:4+name_len], byteorder='big') + partial_claim_id = key[4+name_len:4+name_len+claim_id_len].decode() + return ClaimShortIDKey(name, partial_claim_id, *cls.key_struct.unpack(key[4 + name_len + claim_id_len:])) @classmethod def unpack_value(cls, data: bytes) -> ClaimShortIDValue: return ClaimShortIDValue(*super().unpack_value(data)) @classmethod - def pack_item(cls, name: str, claim_hash: bytes, root_tx_num: int, root_position: int, + def pack_item(cls, name: str, partial_claim_id: str, root_tx_num: int, root_position: int, tx_num: int, position: int): - return cls.pack_key(name, claim_hash, root_tx_num, root_position), \ + return cls.pack_key(name, partial_claim_id, root_tx_num, root_position), \ cls.pack_value(tx_num, position) diff --git a/lbry/wallet/server/leveldb.py b/lbry/wallet/server/leveldb.py index 61dec5697..3f6f55a12 100644 --- a/lbry/wallet/server/leveldb.py +++ b/lbry/wallet/server/leveldb.py @@ -209,6 +209,17 @@ class LevelDB: supports.append((unpacked_k.tx_num, unpacked_k.position, unpacked_v.amount)) return supports + def get_short_claim_id_url(self, name: str, claim_hash: bytes, root_tx_num: int, root_position: int) -> str: + claim_id = claim_hash.hex() + for prefix_len in range(10): + prefix = Prefixes.claim_short_id.pack_partial_key(name, claim_id[:prefix_len+1]) + for _k in self.db.iterator(prefix=prefix, include_value=False): + k = Prefixes.claim_short_id.unpack_key(_k) + if k.root_tx_num == root_tx_num and k.root_position == root_position: + return f'{name}#{k.partial_claim_id}' + break + raise Exception('wat') + def _prepare_resolve_result(self, tx_num: int, position: int, claim_hash: bytes, name: str, root_tx_num: int, root_position: int, activation_height: int, signature_valid: bool) -> ResolveResult: controlling_claim = self.get_controlling_claim(name) @@ -225,8 +236,7 @@ class LevelDB: effective_amount = support_amount + claim_amount channel_hash = self.get_channel_for_claim(claim_hash, tx_num, position) reposted_claim_hash = self.get_repost(claim_hash) - - short_url = f'{name}#{claim_hash.hex()}' + short_url = self.get_short_claim_id_url(name, claim_hash, root_tx_num, root_position) canonical_url = short_url claims_in_channel = self.get_claims_in_channel_count(claim_hash) if channel_hash: @@ -264,15 +274,24 @@ class LevelDB: amount_order = max(int(amount_order or 1), 1) if claim_id: + if len(claim_id) == 40: # a full claim id + claim_txo = self.get_claim_txo(bytes.fromhex(claim_id)) + if normalized_name != claim_txo.name: + return + return self._prepare_resolve_result( + claim_txo.tx_num, claim_txo.position, bytes.fromhex(claim_id), claim_txo.name, + claim_txo.root_tx_num, claim_txo.root_position, + self.get_activation(claim_txo.tx_num, claim_txo.position), claim_txo.channel_signature_is_valid + ) # resolve by partial/complete claim id - short_claim_hash = bytes.fromhex(claim_id) - prefix = Prefixes.claim_short_id.pack_partial_key(normalized_name, short_claim_hash) + prefix = Prefixes.claim_short_id.pack_partial_key(normalized_name, claim_id[:10]) for k, v in self.db.iterator(prefix=prefix): key = Prefixes.claim_short_id.unpack_key(k) claim_txo = Prefixes.claim_short_id.unpack_value(v) - signature_is_valid = self.get_claim_txo(key.claim_hash).channel_signature_is_valid + claim_hash = self.get_claim_from_txo(claim_txo.tx_num, claim_txo.position).claim_hash + signature_is_valid = self.get_claim_txo(claim_hash).channel_signature_is_valid return self._prepare_resolve_result( - claim_txo.tx_num, claim_txo.position, key.claim_hash, key.name, key.root_tx_num, + claim_txo.tx_num, claim_txo.position, claim_hash, key.name, key.root_tx_num, key.root_position, self.get_activation(claim_txo.tx_num, claim_txo.position), signature_is_valid ) @@ -396,11 +415,12 @@ class LevelDB: def get_claims_for_name(self, name): claims = [] - for _k, _v in self.db.iterator(prefix=Prefixes.claim_short_id.pack_partial_key(name)): - k, v = Prefixes.claim_short_id.unpack_key(_k), Prefixes.claim_short_id.unpack_value(_v) - # claims[v.claim_hash] = (k, v) - if k.claim_hash not in claims: - claims.append(k.claim_hash) + prefix = Prefixes.claim_short_id.pack_partial_key(name) + int(1).to_bytes(1, byteorder='big') + for _k, _v in self.db.iterator(prefix=prefix): + v = Prefixes.claim_short_id.unpack_value(_v) + claim_hash = self.get_claim_from_txo(v.tx_num, v.position).claim_hash + if claim_hash not in claims: + claims.append(claim_hash) return claims def get_claims_in_channel_count(self, channel_hash) -> int: @@ -435,10 +455,10 @@ class LevelDB: def get_claim_txos_for_name(self, name: str): txos = {} - for k, v in self.db.iterator(prefix=Prefixes.claim_short_id.pack_partial_key(name)): - claim_hash = Prefixes.claim_short_id.unpack_key(k).claim_hash + prefix = Prefixes.claim_short_id.pack_partial_key(name) + int(1).to_bytes(1, byteorder='big') + for k, v in self.db.iterator(prefix=prefix): tx_num, nout = Prefixes.claim_short_id.unpack_value(v) - txos[claim_hash] = tx_num, nout + txos[self.get_claim_from_txo(tx_num, nout).claim_hash] = tx_num, nout return txos def get_claim_metadata(self, tx_hash, nout): diff --git a/tests/integration/blockchain/test_resolve_command.py b/tests/integration/blockchain/test_resolve_command.py index 55b710323..2c443944e 100644 --- a/tests/integration/blockchain/test_resolve_command.py +++ b/tests/integration/blockchain/test_resolve_command.py @@ -3,6 +3,7 @@ import json import hashlib from bisect import bisect_right from binascii import hexlify, unhexlify +from collections import defaultdict from lbry.testcase import CommandTestCase from lbry.wallet.transaction import Transaction, Output from lbry.schema.compat import OldClaimMessage @@ -100,6 +101,35 @@ class BaseResolveTestCase(CommandTestCase): class ResolveCommand(BaseResolveTestCase): + async def test_colliding_short_id(self): + prefixes = defaultdict(list) + + colliding_claim_ids = [] + first_claims_one_char_shortid = {} + + while True: + chan = self.get_claim_id( + await self.channel_create('@abc', '0.01', allow_duplicate_name=True) + ) + if chan[:1] not in first_claims_one_char_shortid: + first_claims_one_char_shortid[chan[:1]] = chan + prefixes[chan[:2]].append(chan) + if len(prefixes[chan[:2]]) > 1: + colliding_claim_ids.extend(prefixes[chan[:2]]) + break + first_claim = first_claims_one_char_shortid[colliding_claim_ids[0][:1]] + await self.assertResolvesToClaimId( + f'@abc#{colliding_claim_ids[0][:1]}', first_claim + ) + await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[0][:2]}', colliding_claim_ids[0]) + await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[0][:7]}', colliding_claim_ids[0]) + await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[0][:17]}', colliding_claim_ids[0]) + await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[0]}', colliding_claim_ids[0]) + await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[1][:3]}', colliding_claim_ids[1]) + await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[1][:7]}', colliding_claim_ids[1]) + await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[1][:17]}', colliding_claim_ids[1]) + await self.assertResolvesToClaimId(f'@abc#{colliding_claim_ids[1]}', colliding_claim_ids[1]) + async def test_resolve_response(self): channel_id = self.get_claim_id( await self.channel_create('@abc', '0.01')