fix signed claim invalidation corner cases

This commit is contained in:
Jack Robison 2021-07-06 17:56:18 -04:00 committed by Victor Shyba
parent d74d06d97b
commit 0c85de7839
3 changed files with 156 additions and 55 deletions

View file

@ -260,6 +260,9 @@ class BlockProcessor:
self.amount_cache = {} self.amount_cache = {}
self.expired_claim_hashes: Set[bytes] = set() self.expired_claim_hashes: Set[bytes] = set()
self.doesnt_have_valid_signature: Set[bytes] = set()
self.claim_channels: Dict[bytes, bytes] = {}
def claim_producer(self): def claim_producer(self):
if self.db.db_height <= 1: if self.db.db_height <= 1:
return return
@ -491,9 +494,11 @@ class BlockProcessor:
if is_channel: if is_channel:
self.pending_channels[claim_hash] = txo.claim.channel.public_key_bytes self.pending_channels[claim_hash] = txo.claim.channel.public_key_bytes
self.doesnt_have_valid_signature.add(claim_hash)
raw_channel_tx = None raw_channel_tx = None
if signable and signable.signing_channel_hash: if signable and signable.signing_channel_hash:
signing_channel = self.db.get_claim_txo(signing_channel_hash) signing_channel = self.db.get_claim_txo(signing_channel_hash)
if signing_channel: if signing_channel:
raw_channel_tx = self.db.db.get( raw_channel_tx = self.db.db.get(
DB_PREFIXES.TX_PREFIX.value + self.db.total_transactions[signing_channel.tx_num] DB_PREFIXES.TX_PREFIX.value + self.db.total_transactions[signing_channel.tx_num]
@ -502,10 +507,9 @@ class BlockProcessor:
try: try:
if not signing_channel: if not signing_channel:
if txo.signable.signing_channel_hash[::-1] in self.pending_channels: if txo.signable.signing_channel_hash[::-1] in self.pending_channels:
channel_pub_key_bytes = self.pending_channels[txo.signable.signing_channel_hash[::-1]] channel_pub_key_bytes = self.pending_channels[signing_channel_hash]
elif raw_channel_tx: elif raw_channel_tx:
chan_output = self.coin.transaction(raw_channel_tx).outputs[signing_channel.position] chan_output = self.coin.transaction(raw_channel_tx).outputs[signing_channel.position]
chan_script = OutputScript(chan_output.pk_script) chan_script = OutputScript(chan_output.pk_script)
chan_script.parse() chan_script.parse()
channel_meta = Claim.from_bytes(chan_script.values['claim']) channel_meta = Claim.from_bytes(chan_script.values['claim'])
@ -517,6 +521,8 @@ class BlockProcessor:
) )
if channel_signature_is_valid: if channel_signature_is_valid:
self.pending_channel_counts[signing_channel_hash] += 1 self.pending_channel_counts[signing_channel_hash] += 1
self.doesnt_have_valid_signature.remove(claim_hash)
self.claim_channels[claim_hash] = signing_channel_hash
except: except:
self.logger.exception(f"error validating channel signature for %s:%i", tx_hash[::-1].hex(), nout) self.logger.exception(f"error validating channel signature for %s:%i", tx_hash[::-1].hex(), nout)
@ -532,16 +538,16 @@ class BlockProcessor:
previous_claim = self.txo_to_claim.pop((prev_tx_num, prev_idx)) previous_claim = self.txo_to_claim.pop((prev_tx_num, prev_idx))
root_tx_num, root_idx = previous_claim.root_tx_num, previous_claim.root_position root_tx_num, root_idx = previous_claim.root_tx_num, previous_claim.root_position
else: else:
v = self.db.get_claim_txo( previous_claim = self._make_pending_claim_txo(claim_hash)
claim_hash root_tx_num, root_idx = previous_claim.root_tx_num, previous_claim.root_position
)
root_tx_num, root_idx = v.root_tx_num, v.root_position
activation = self.db.get_activation(prev_tx_num, prev_idx) activation = self.db.get_activation(prev_tx_num, prev_idx)
self.db_op_stack.extend( self.db_op_stack.extend(
StagedActivation( StagedActivation(
ACTIVATED_CLAIM_TXO_TYPE, claim_hash, prev_tx_num, prev_idx, activation, claim_name, v.amount ACTIVATED_CLAIM_TXO_TYPE, claim_hash, prev_tx_num, prev_idx, activation, claim_name,
previous_claim.amount
).get_remove_activate_ops() ).get_remove_activate_ops()
) )
pending = StagedClaimtrieItem( pending = StagedClaimtrieItem(
claim_name, claim_hash, txo.amount, self.coin.get_expiration_height(height), tx_num, nout, root_tx_num, claim_name, claim_hash, txo.amount, self.coin.get_expiration_height(height), tx_num, nout, root_tx_num,
root_idx, channel_signature_is_valid, signing_channel_hash, reposted_claim_hash root_idx, channel_signature_is_valid, signing_channel_hash, reposted_claim_hash
@ -605,16 +611,7 @@ class BlockProcessor:
) )
if not spent_claim_hash_and_name: # txo is not a claim if not spent_claim_hash_and_name: # txo is not a claim
return False return False
claim_hash = spent_claim_hash_and_name.claim_hash spent = self._make_pending_claim_txo(spent_claim_hash_and_name.claim_hash)
signing_hash = self.db.get_channel_for_claim(claim_hash, txin_num, txin.prev_idx)
v = self.db.get_claim_txo(claim_hash)
reposted_claim_hash = self.db.get_repost(claim_hash)
spent = StagedClaimtrieItem(
v.name, claim_hash, v.amount,
self.coin.get_expiration_height(bisect_right(self.db.tx_counts, txin_num)),
txin_num, txin.prev_idx, v.root_tx_num, v.root_position, v.channel_signature_is_valid, signing_hash,
reposted_claim_hash
)
if spent.reposted_claim_hash: if spent.reposted_claim_hash:
self.pending_reposted.add(spent.reposted_claim_hash) self.pending_reposted.add(spent.reposted_claim_hash)
if spent.signing_hash and spent.channel_signature_is_valid: if spent.signing_hash and spent.channel_signature_is_valid:
@ -658,44 +655,55 @@ class BlockProcessor:
self.support_txos_by_claim[claim_hash].clear() self.support_txos_by_claim[claim_hash].clear()
self.support_txos_by_claim.pop(claim_hash) self.support_txos_by_claim.pop(claim_hash)
if staged.name.startswith('@'): # abandon a channel, invalidate signatures if name.startswith('@'): # abandon a channel, invalidate signatures
for k, claim_hash in self.db.db.iterator( self._invalidate_channel_signatures(claim_hash)
prefix=Prefixes.channel_to_claim.pack_partial_key(staged.claim_hash)):
if claim_hash in self.abandoned_claims or claim_hash in self.expired_claim_hashes: def _invalidate_channel_signatures(self, claim_hash: bytes):
for k, signed_claim_hash in self.db.db.iterator(
prefix=Prefixes.channel_to_claim.pack_partial_key(claim_hash)):
if signed_claim_hash in self.abandoned_claims or signed_claim_hash in self.expired_claim_hashes:
continue continue
self.signatures_changed.add(claim_hash) # there is no longer a signing channel for this claim as of this block
if claim_hash in self.claim_hash_to_txo: if signed_claim_hash in self.doesnt_have_valid_signature:
claim = self.txo_to_claim[self.claim_hash_to_txo[claim_hash]] continue
self.txo_to_claim[self.claim_hash_to_txo[claim_hash]] = StagedClaimtrieItem( # the signing channel changed in this block
claim.name, claim.claim_hash, claim.amount, claim.expiration_height, claim.tx_num, if signed_claim_hash in self.claim_channels and signed_claim_hash != self.claim_channels[signed_claim_hash]:
claim.position, claim.root_tx_num, claim.root_position, channel_signature_is_valid=False, continue
signing_hash=None, reposted_claim_hash=claim.reposted_claim_hash
) # if the claim with an invalidated signature is in this block, update the StagedClaimtrieItem
# so that if we later try to spend it in this block we won't try to delete the channel info twice
if signed_claim_hash in self.claim_hash_to_txo:
signed_claim_txo = self.claim_hash_to_txo[signed_claim_hash]
claim = self.txo_to_claim[signed_claim_txo]
if claim.signing_hash != claim_hash: # claim was already invalidated this block
continue
self.txo_to_claim[signed_claim_txo] = claim.invalidate_signature()
else: else:
claim = self._make_pending_claim_txo(signed_claim_hash)
self.signatures_changed.add(signed_claim_hash)
self.pending_channel_counts[claim_hash] -= 1
self.db_op_stack.extend(claim.get_invalidate_signature_ops())
for staged in list(self.txo_to_claim.values()):
if staged.signing_hash == claim_hash and staged.claim_hash not in self.doesnt_have_valid_signature:
self.db_op_stack.extend(staged.get_invalidate_signature_ops())
self.txo_to_claim[self.claim_hash_to_txo[staged.claim_hash]] = staged.invalidate_signature()
self.signatures_changed.add(staged.claim_hash)
self.pending_channel_counts[claim_hash] -= 1
def _make_pending_claim_txo(self, claim_hash: bytes):
claim = self.db.get_claim_txo(claim_hash) claim = self.db.get_claim_txo(claim_hash)
assert claim is not None if claim_hash in self.doesnt_have_valid_signature:
signing_hash = Prefixes.channel_to_claim.unpack_key(k).signing_hash signing_hash = None
self.db_op_stack.extend([ else:
# delete channel_to_claim/claim_to_channel signing_hash = self.db.get_channel_for_claim(claim_hash, claim.tx_num, claim.position)
RevertableDelete(k, claim_hash), reposted_claim_hash = self.db.get_repost(claim_hash)
RevertableDelete( return StagedClaimtrieItem(
*Prefixes.claim_to_channel.pack_item(claim_hash, claim.tx_num, claim.position, signing_hash) claim.name, claim_hash, claim.amount,
), self.coin.get_expiration_height(bisect_right(self.db.tx_counts, claim.tx_num)),
# update claim_to_txo with channel_signature_is_valid=False claim.tx_num, claim.position, claim.root_tx_num, claim.root_position,
RevertableDelete( claim.channel_signature_is_valid, signing_hash, reposted_claim_hash
*Prefixes.claim_to_txo.pack_item(
claim_hash, claim.tx_num, claim.position, claim.root_tx_num, claim.root_position,
claim.amount, claim.channel_signature_is_valid, claim.name
) )
),
RevertablePut(
*Prefixes.claim_to_txo.pack_item(
claim_hash, claim.tx_num, claim.position, claim.root_tx_num, claim.root_position,
claim.amount, False, claim.name
)
)
])
def _expire_claims(self, height: int): def _expire_claims(self, height: int):
expired = self.db.get_expired_by_height(height) expired = self.db.get_expired_by_height(height)
@ -1276,6 +1284,8 @@ class BlockProcessor:
self.amount_cache.clear() self.amount_cache.clear()
self.signatures_changed.clear() self.signatures_changed.clear()
self.expired_claim_hashes.clear() self.expired_claim_hashes.clear()
self.doesnt_have_valid_signature.clear()
self.claim_channels.clear()
# for cache in self.search_cache.values(): # for cache in self.search_cache.values():
# cache.clear() # cache.clear()

View file

@ -197,10 +197,12 @@ class StagedClaimtrieItem(typing.NamedTuple):
if self.reposted_claim_hash: if self.reposted_claim_hash:
ops.extend([ ops.extend([
op( op(
*RepostPrefixRow.pack_item(self.claim_hash, self.reposted_claim_hash) *Prefixes.repost.pack_item(self.claim_hash, self.reposted_claim_hash)
), ),
op( op(
*RepostedPrefixRow.pack_item(self.reposted_claim_hash, self.tx_num, self.position, self.claim_hash) *Prefixes.reposted_claim.pack_item(
self.reposted_claim_hash, self.tx_num, self.position, self.claim_hash
)
), ),
]) ])
@ -212,3 +214,42 @@ class StagedClaimtrieItem(typing.NamedTuple):
def get_spend_claim_txo_ops(self) -> typing.List[RevertableOp]: def get_spend_claim_txo_ops(self) -> typing.List[RevertableOp]:
return self._get_add_remove_claim_utxo_ops(add=False) return self._get_add_remove_claim_utxo_ops(add=False)
def get_invalidate_signature_ops(self):
if not self.signing_hash:
return []
ops = [
RevertableDelete(
*Prefixes.claim_to_channel.pack_item(
self.claim_hash, self.tx_num, self.position, self.signing_hash
)
)
]
if self.channel_signature_is_valid:
ops.extend([
# delete channel_to_claim/claim_to_channel
RevertableDelete(
*Prefixes.channel_to_claim.pack_item(
self.signing_hash, self.name, self.tx_num, self.position, self.claim_hash
)
),
# update claim_to_txo with channel_signature_is_valid=False
RevertableDelete(
*Prefixes.claim_to_txo.pack_item(
self.claim_hash, self.tx_num, self.position, self.root_tx_num, self.root_position,
self.amount, self.channel_signature_is_valid, self.name
)
),
RevertablePut(
*Prefixes.claim_to_txo.pack_item(
self.claim_hash, self.tx_num, self.position, self.root_tx_num, self.root_position,
self.amount, False, self.name
)
)
])
return ops
def invalidate_signature(self) -> 'StagedClaimtrieItem':
return StagedClaimtrieItem(
self.name, self.claim_hash, self.amount, self.expiration_height, self.tx_num, self.position,
self.root_tx_num, self.root_position, False, None, self.reposted_claim_hash
)

View file

@ -403,6 +403,56 @@ class ResolveCommand(BaseResolveTestCase):
class ResolveClaimTakeovers(BaseResolveTestCase): class ResolveClaimTakeovers(BaseResolveTestCase):
async def test_channel_invalidation(self):
channel_id = (await self.channel_create('@test', '0.1'))['outputs'][0]['claim_id']
initially_unsigned1 = (
await self.stream_create('initially_unsigned1', '0.1')
)['outputs'][0]['claim_id']
initially_unsigned2 = (
await self.stream_create('initially_unsigned2', '0.1')
)['outputs'][0]['claim_id']
initially_signed1 = (
await self.stream_create('signed1', '0.01', channel_id=channel_id)
)['outputs'][0]['claim_id']
await self.generate(1)
self.assertIn("error", await self.resolve('@test/initially_unsigned1'))
await self.assertMatchClaimIsWinning('initially_unsigned1', initially_unsigned1)
self.assertIn("error", await self.resolve('@test/initially_unsigned2'))
await self.assertMatchClaimIsWinning('initially_unsigned2', initially_unsigned2)
self.assertDictEqual(await self.resolve('@test/signed1'), await self.resolve('signed1'))
await self.assertMatchClaimIsWinning('signed1', initially_signed1)
# sign 'initially_unsigned1' and update it
await self.ledger.wait(await self.daemon.jsonrpc_stream_update(
initially_unsigned1, '0.09', channel_id=channel_id))
await self.ledger.wait(await self.daemon.jsonrpc_stream_update(initially_unsigned2, '0.09'))
# update the still unsigned 'initially_unsigned2'
await self.ledger.wait(await self.daemon.jsonrpc_stream_update(
initially_unsigned2, '0.09', channel_id=channel_id))
await self.ledger.wait(await self.daemon.jsonrpc_stream_update(
initially_signed1, '0.09', clear_channel=True))
await self.daemon.jsonrpc_txo_spend(type='channel', claim_id=channel_id)
signed2 = (
await self.stream_create('signed2', '0.01', channel_id=channel_id)
)['outputs'][0]['claim_id']
await self.generate(1)
self.assertIn("error", await self.resolve('@test'))
self.assertIn("error", await self.resolve('@test/signed1'))
self.assertIn("error", await self.resolve('@test/initially_unsigned2'))
self.assertIn("error", await self.resolve('@test/initially_unsigned1'))
self.assertIn("error", await self.resolve('@test/signed2'))
await self.assertMatchClaimIsWinning('signed1', initially_signed1)
await self.assertMatchClaimIsWinning('initially_unsigned1', initially_unsigned1)
await self.assertMatchClaimIsWinning('initially_unsigned2', initially_unsigned2)
await self.assertMatchClaimIsWinning('signed2', signed2)
async def _test_activation_delay(self): async def _test_activation_delay(self):
name = 'derp' name = 'derp'
# initially claim the name # initially claim the name