invalidate channel signatures upon channel abandon

This commit is contained in:
Jack Robison 2021-06-17 21:20:57 -04:00 committed by Victor Shyba
parent da4e4ecd23
commit 89e7c8582e
5 changed files with 50 additions and 26 deletions

View file

@ -251,6 +251,7 @@ class BlockProcessor:
self.removed_claims_to_send_es = set() self.removed_claims_to_send_es = set()
self.touched_claims_to_send_es = set() self.touched_claims_to_send_es = set()
self.signatures_changed = set()
self.pending_reposted = set() self.pending_reposted = set()
self.pending_channel_counts = defaultdict(lambda: 0) self.pending_channel_counts = defaultdict(lambda: 0)
@ -662,7 +663,36 @@ class BlockProcessor:
self.pending_supports[claim_hash].clear() self.pending_supports[claim_hash].clear()
self.pending_supports.pop(claim_hash) self.pending_supports.pop(claim_hash)
return staged.get_abandon_ops(self.db.db) ops = []
if staged.name.startswith('@'): # abandon a channel, invalidate signatures
for k, claim_hash in self.db.db.iterator(prefix=Prefixes.channel_to_claim.pack_partial_key(staged.claim_hash)):
if claim_hash in self.staged_pending_abandoned:
continue
self.signatures_changed.add(claim_hash)
if claim_hash in self.pending_claims:
claim = self.pending_claims[claim_hash]
else:
claim = self.db.get_claim_txo(claim_hash)
assert claim is not None
ops.extend([
RevertableDelete(k, claim_hash),
RevertableDelete(
*Prefixes.claim_to_txo.pack_item(
claim_hash, claim.tx_num, claim.position, claim.root_tx_num, claim.root_position,
claim.amount, claim.channel_signature_is_valid, claim.name
)
),
RevertablePut(
*Prefixes.claim_to_txo.pack_item(
claim_hash, claim.tx_num, claim.position, claim.root_tx_num, claim.root_position,
claim.amount, False, claim.name
)
)
])
if staged.signing_hash:
ops.append(RevertableDelete(*Prefixes.claim_to_channel.pack_item(staged.claim_hash, staged.signing_hash)))
return ops
def _abandon(self, spent_claims) -> List['RevertableOp']: def _abandon(self, spent_claims) -> List['RevertableOp']:
# Handle abandoned claims # Handle abandoned claims
@ -1100,7 +1130,7 @@ class BlockProcessor:
self.touched_claims_to_send_es.update( self.touched_claims_to_send_es.update(
set(self.staged_activated_support.keys()).union( set(self.staged_activated_support.keys()).union(
set(claim_hash for (_, claim_hash) in self.staged_activated_claim.keys()) set(claim_hash for (_, claim_hash) in self.staged_activated_claim.keys())
).difference(self.removed_claims_to_send_es) ).union(self.signatures_changed).difference(self.removed_claims_to_send_es)
) )
# use the cumulative changes to update bid ordered resolve # use the cumulative changes to update bid ordered resolve
@ -1256,6 +1286,7 @@ class BlockProcessor:
self.possible_future_support_txos.clear() self.possible_future_support_txos.clear()
self.pending_channels.clear() self.pending_channels.clear()
self.amount_cache.clear() self.amount_cache.clear()
self.signatures_changed.clear()
# for cache in self.search_cache.values(): # for cache in self.search_cache.values():
# cache.clear() # cache.clear()

View file

@ -214,17 +214,3 @@ class StagedClaimtrieItem(typing.NamedTuple):
def get_spend_claim_txo_ops(self) -> typing.List[RevertableOp]: def get_spend_claim_txo_ops(self) -> typing.List[RevertableOp]:
return self._get_add_remove_claim_utxo_ops(add=False) return self._get_add_remove_claim_utxo_ops(add=False)
def get_invalidate_channel_ops(self, db) -> typing.List[RevertableOp]:
if not self.signing_hash:
return []
return [
RevertableDelete(*Prefixes.claim_to_channel.pack_item(self.claim_hash, self.signing_hash))
] + delete_prefix(db, DB_PREFIXES.channel_to_claim.value + self.signing_hash)
def get_abandon_ops(self, db) -> typing.List[RevertableOp]:
delete_short_id_ops = delete_prefix(
db, Prefixes.claim_short_id.pack_partial_key(self.name, self.claim_hash)
)
delete_claim_ops = delete_prefix(db, DB_PREFIXES.claim_to_txo.value + self.claim_hash)
delete_supports_ops = delete_prefix(db, DB_PREFIXES.claim_to_support.value + self.claim_hash)
return delete_short_id_ops + delete_claim_ops + delete_supports_ops + self.get_invalidate_channel_ops(db)

View file

@ -443,3 +443,4 @@ class ResolveResult(typing.NamedTuple):
claims_in_channel: typing.Optional[int] claims_in_channel: typing.Optional[int]
channel_hash: typing.Optional[bytes] channel_hash: typing.Optional[bytes]
reposted_claim_hash: typing.Optional[bytes] reposted_claim_hash: typing.Optional[bytes]
signature_valid: typing.Optional[bool]

View file

@ -193,7 +193,6 @@ class SearchIndex:
if censor.censored: if censor.censored:
response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED) response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
total_referenced.extend(response) total_referenced.extend(response)
response = [ response = [
ResolveResult( ResolveResult(
name=r['claim_name'], name=r['claim_name'],
@ -215,7 +214,8 @@ class SearchIndex:
claims_in_channel=r['claims_in_channel'], claims_in_channel=r['claims_in_channel'],
channel_hash=r['channel_hash'], channel_hash=r['channel_hash'],
reposted_claim_hash=r['reposted_claim_hash'], reposted_claim_hash=r['reposted_claim_hash'],
reposted=r['reposted'] reposted=r['reposted'],
signature_valid=r['signature_valid']
) for r in response ) for r in response
] ]
extra = [ extra = [
@ -239,7 +239,8 @@ class SearchIndex:
claims_in_channel=r['claims_in_channel'], claims_in_channel=r['claims_in_channel'],
channel_hash=r['channel_hash'], channel_hash=r['channel_hash'],
reposted_claim_hash=r['reposted_claim_hash'], reposted_claim_hash=r['reposted_claim_hash'],
reposted=r['reposted'] reposted=r['reposted'],
signature_valid=r['signature_valid']
) for r in await self._get_referenced_rows(total_referenced) ) for r in await self._get_referenced_rows(total_referenced)
] ]
result = Outputs.to_base64( result = Outputs.to_base64(
@ -304,7 +305,7 @@ class SearchIndex:
return await self.search_ahead(**kwargs) return await self.search_ahead(**kwargs)
except NotFoundError: except NotFoundError:
return [], 0, 0 return [], 0, 0
return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0) # return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0)
async def search_ahead(self, **kwargs): async def search_ahead(self, **kwargs):
# 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
@ -489,7 +490,7 @@ def extract_doc(doc, index):
doc['repost_count'] = doc.pop('reposted') doc['repost_count'] = doc.pop('reposted')
doc['is_controlling'] = bool(doc['is_controlling']) doc['is_controlling'] = bool(doc['is_controlling'])
doc['signature'] = (doc.pop('signature') or b'').hex() or None doc['signature'] = (doc.pop('signature') or b'').hex() or None
doc['signature_digest'] = (doc.pop('signature_digest') or b'').hex() or None doc['signature_digest'] = doc['signature']
doc['public_key_bytes'] = (doc.pop('public_key_bytes') or b'').hex() or None doc['public_key_bytes'] = (doc.pop('public_key_bytes') or b'').hex() or None
doc['public_key_id'] = (doc.pop('public_key_hash') or b'').hex() or None doc['public_key_id'] = (doc.pop('public_key_hash') or b'').hex() or None
doc['is_signature_valid'] = bool(doc['signature_valid']) doc['is_signature_valid'] = bool(doc['signature_valid'])
@ -512,6 +513,8 @@ def expand_query(**kwargs):
kwargs.pop('is_controlling') kwargs.pop('is_controlling')
query = {'must': [], 'must_not': []} query = {'must': [], 'must_not': []}
collapse = None collapse = None
if 'fee_currency' in kwargs and kwargs['fee_currency'] is not None:
kwargs['fee_currency'] = kwargs['fee_currency'].upper()
for key, value in kwargs.items(): for key, value in kwargs.items():
key = key.replace('claim.', '') key = key.replace('claim.', '')
many = key.endswith('__in') or isinstance(value, list) many = key.endswith('__in') or isinstance(value, list)

View file

@ -216,7 +216,7 @@ class LevelDB:
return supports return supports
def _prepare_resolve_result(self, tx_num: int, position: int, claim_hash: bytes, name: str, root_tx_num: int, def _prepare_resolve_result(self, tx_num: int, position: int, claim_hash: bytes, name: str, root_tx_num: int,
root_position: int, activation_height: int) -> ResolveResult: root_position: int, activation_height: int, signature_valid: bool) -> ResolveResult:
controlling_claim = self.get_controlling_claim(name) controlling_claim = self.get_controlling_claim(name)
tx_hash = self.total_transactions[tx_num] tx_hash = self.total_transactions[tx_num]
@ -247,7 +247,8 @@ class LevelDB:
creation_height=created_height, activation_height=activation_height, creation_height=created_height, activation_height=activation_height,
expiration_height=expiration_height, effective_amount=effective_amount, support_amount=support_amount, expiration_height=expiration_height, effective_amount=effective_amount, support_amount=support_amount,
channel_hash=channel_hash, reposted_claim_hash=reposted_claim_hash, channel_hash=channel_hash, reposted_claim_hash=reposted_claim_hash,
reposted=self.get_reposted_count(claim_hash) reposted=self.get_reposted_count(claim_hash),
signature_valid=None if not channel_hash else signature_valid
) )
def _resolve(self, normalized_name: str, claim_id: Optional[str] = None, def _resolve(self, normalized_name: str, claim_id: Optional[str] = None,
@ -275,9 +276,11 @@ class LevelDB:
for k, v in self.db.iterator(prefix=prefix): for k, v in self.db.iterator(prefix=prefix):
key = Prefixes.claim_short_id.unpack_key(k) key = Prefixes.claim_short_id.unpack_key(k)
claim_txo = Prefixes.claim_short_id.unpack_value(v) claim_txo = Prefixes.claim_short_id.unpack_value(v)
signature_is_valid = self.get_claim_txo(key.claim_hash).channel_signature_is_valid
return self._prepare_resolve_result( return self._prepare_resolve_result(
claim_txo.tx_num, claim_txo.position, key.claim_hash, key.name, key.root_tx_num, claim_txo.tx_num, claim_txo.position, key.claim_hash, key.name, key.root_tx_num,
key.root_position, self.get_activation(claim_txo.tx_num, claim_txo.position) key.root_position, self.get_activation(claim_txo.tx_num, claim_txo.position),
signature_is_valid
) )
return return
@ -292,7 +295,7 @@ class LevelDB:
activation = self.get_activation(key.tx_num, key.position) activation = self.get_activation(key.tx_num, key.position)
return self._prepare_resolve_result( return self._prepare_resolve_result(
key.tx_num, key.position, claim_val.claim_hash, key.name, claim_txo.root_tx_num, key.tx_num, key.position, claim_val.claim_hash, key.name, claim_txo.root_tx_num,
claim_txo.root_position, activation claim_txo.root_position, activation, claim_txo.channel_signature_is_valid
) )
return return
@ -354,7 +357,7 @@ class LevelDB:
activation_height = self.get_activation(v.tx_num, v.position) activation_height = self.get_activation(v.tx_num, v.position)
return self._prepare_resolve_result( return self._prepare_resolve_result(
v.tx_num, v.position, claim_hash, v.name, v.tx_num, v.position, claim_hash, v.name,
v.root_tx_num, v.root_position, activation_height v.root_tx_num, v.root_position, activation_height, v.channel_signature_is_valid
) )
async def fs_getclaimbyid(self, claim_id): async def fs_getclaimbyid(self, claim_id):