fix for claims_in_channel not updating on abandoned claim

Lex Berezhny 2019-06-22 19:25:32 -04:00
parent c54373bd9c
commit e393314423
2 changed files with 61 additions and 31 deletions
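The gist of the fix: when a block abandons claims, note which channels those claims belonged to before the rows are deleted, and hand that set to the pass that refreshes per-channel state, so the cached claims_in_channel count gets recomputed. A minimal, self-contained sketch of the idea with a toy sqlite3 schema (not the project's real tables; only the column names used in the diff are kept):

    import sqlite3

    # Toy illustration of the fix (simplified schema; only the column names
    # channel_hash / claim_hash / claims_in_channel come from the diff).
    db = sqlite3.connect(":memory:")
    db.row_factory = sqlite3.Row
    db.execute("""
        CREATE TABLE claim (
            claim_hash        BLOB PRIMARY KEY,
            channel_hash      BLOB,     -- NULL for claims that are not in a channel
            claims_in_channel INTEGER NOT NULL DEFAULT 0
        )
    """)
    db.execute("INSERT INTO claim VALUES (?, NULL, 2)", (b'chan',))       # stale cached count: 2
    db.execute("INSERT INTO claim VALUES (?, ?, 0)", (b'claim1', b'chan'))
    db.execute("INSERT INTO claim VALUES (?, ?, 0)", (b'claim2', b'chan'))

    def delete_claims(claim_hashes):
        placeholders = ','.join('?' * len(claim_hashes))
        # 1. remember which channels the abandoned claims belonged to, before deleting
        affected = {row['channel_hash'] for row in db.execute(
            f"SELECT channel_hash FROM claim "
            f"WHERE channel_hash IS NOT NULL AND claim_hash IN ({placeholders})",
            claim_hashes)}
        db.execute(f"DELETE FROM claim WHERE claim_hash IN ({placeholders})", claim_hashes)
        return affected

    def refresh_claims_in_channel(channel_hashes):
        # 2. recount claims for just the channels that lost something
        for channel_hash in channel_hashes:
            db.execute(
                "UPDATE claim SET claims_in_channel = ("
                "  SELECT COUNT(*) FROM claim other WHERE other.channel_hash = claim.claim_hash"
                ") WHERE claim_hash = ?", (channel_hash,))

    refresh_claims_in_channel(delete_claims([b'claim2']))
    assert db.execute(
        "SELECT claims_in_channel FROM claim WHERE claim_hash = ?", (b'chan',)
    ).fetchone()[0] == 1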

View file

@@ -342,9 +342,14 @@ class SQLDB:
         """ Deletes claim supports and from claimtrie in case of an abandon. """
         if claim_hashes:
             binary_claim_hashes = [sqlite3.Binary(claim_hash) for claim_hash in claim_hashes]
+            affected_channels = self.execute(*query(
+                "SELECT channel_hash FROM claim", channel_hash__is_not_null=1, claim_hash__in=binary_claim_hashes
+            )).fetchall()
             for table in ('claim', 'support', 'claimtrie'):
                 self.execute(*self._delete_sql(table, {'claim_hash__in': binary_claim_hashes}))
             self._clear_claim_metadata(binary_claim_hashes)
+            return set(r['channel_hash'] for r in affected_channels)
+        return set()

     def _clear_claim_metadata(self, binary_claim_hashes: List[sqlite3.Binary]):
         if binary_claim_hashes:
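A side note on the `query()` helper used above: assuming its usual constraint mapping (a `__is_not_null` suffix becomes `IS NOT NULL`, `__in` becomes an `IN (...)` clause with one bound parameter per value), the new lookup expands to roughly the following SQL. This is an illustration of the intent, not captured output:

    # Roughly the SQL the query() call above is expected to produce
    # (placeholder style is illustrative; the helper names its own parameters):
    sql = (
        "SELECT channel_hash FROM claim"
        " WHERE channel_hash IS NOT NULL"
        " AND claim_hash IN (?, ?, ?)"  # one placeholder per abandoned claim hash
    )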
@@ -389,7 +394,7 @@ class SQLDB:
             'support', {'txo_hash__in': [sqlite3.Binary(txo_hash) for txo_hash in txo_hashes]}
         ))

-    def validate_channel_signatures(self, height, new_claims, updated_claims, spent_claims, timer):
+    def validate_channel_signatures(self, height, new_claims, updated_claims, spent_claims, affected_channels, timer):
         if not new_claims and not updated_claims and not spent_claims:
             return
@@ -420,12 +425,12 @@ class SQLDB:
         sub_timer = timer.add_timer('lookup missing channels')
         sub_timer.start()
         all_channel_keys = {}
-        if new_channel_keys or missing_channel_keys:
+        if new_channel_keys or missing_channel_keys or affected_channels:
             all_channel_keys = dict(self.execute(*query(
                 "SELECT claim_hash, public_key_bytes FROM claim",
                 claim_hash__in=[
                     sqlite3.Binary(channel_hash) for channel_hash in
-                    set(new_channel_keys) | missing_channel_keys
+                    set(new_channel_keys) | missing_channel_keys | affected_channels
                 ]
             )))
         sub_timer.stop()
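Why the extra `or affected_channels` matters: the per-channel bookkeeping later in `validate_channel_signatures`, including the `claims_in_channel` recount this commit is after, only runs for channels present in `all_channel_keys`. A block that merely abandons claims leaves `new_channel_keys` and `missing_channel_keys` empty, so those channels were previously skipped. The recount itself is outside this diff; assuming it is a correlated count over the `claim` table, its shape would be roughly:

    # Assumed shape of the per-channel recount (not shown in this diff); table and
    # column names follow the queries above, the exact statement may differ:
    self.execute(
        "UPDATE claim SET claims_in_channel = ("
        "  SELECT COUNT(*) FROM claim other WHERE other.channel_hash = claim.claim_hash"
        ") WHERE claim_hash = ?",
        (sqlite3.Binary(channel_hash),)
    )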
@@ -715,12 +720,12 @@ class SQLDB:
         expire_timer.stop()

         r = timer.run
-        r(self.delete_claims, delete_claim_hashes)
+        affected_channels = r(self.delete_claims, delete_claim_hashes)
         r(self.delete_supports, delete_support_txo_hashes)
         r(self.insert_claims, insert_claims, header)
         r(self.update_claims, update_claims, header)
         r(self.validate_channel_signatures, height, insert_claims,
-          update_claims, delete_claim_hashes, forward_timer=True)
+          update_claims, delete_claim_hashes, affected_channels, forward_timer=True)
         r(self.insert_supports, insert_supports)
         r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True)
         r(calculate_trending, self.db, height, self.main.first_sync, daemon_height)

View file

@@ -144,6 +144,9 @@ class TestSQLDB(unittest.TestCase):
         self.assertEqual(active or [], self.get_active())
         self.assertEqual(accepted or [], self.get_accepted())

+
+class TestClaimtrie(TestSQLDB):
+
     def test_example_from_spec(self):
         # https://spec.lbry.com/#claim-activation-example
         advance, state = self.advance, self.state
@@ -310,28 +313,6 @@ class TestSQLDB(unittest.TestCase):
             accepted=[]
         )

-    def test_trending(self):
-        advance, state = self.advance, self.state
-        no_trend = self.get_stream('Claim A', COIN)
-        downwards = self.get_stream('Claim B', COIN)
-        up_small = self.get_stream('Claim C', COIN)
-        up_medium = self.get_stream('Claim D', COIN)
-        up_biggly = self.get_stream('Claim E', COIN)
-        claims = advance(1, [up_biggly, up_medium, up_small, no_trend, downwards])
-        for window in range(1, 8):
-            advance(TRENDING_WINDOW * window, [
-                self.get_support(downwards, (20-window)*COIN),
-                self.get_support(up_small, int(20+(window/10)*COIN)),
-                self.get_support(up_medium, (20+(window*(2 if window == 7 else 1)))*COIN),
-                self.get_support(up_biggly, (20+(window*(3 if window == 7 else 1)))*COIN),
-            ])
-        results = self.sql._search(order_by=['trending_local'])
-        self.assertEqual([c.claim_id for c in claims], [hexlify(c['claim_hash'][::-1]).decode() for c in results])
-        self.assertEqual([10, 6, 2, 0, -2], [int(c['trending_local']) for c in results])
-        self.assertEqual([53, 38, -32, 0, -6], [int(c['trending_global']) for c in results])
-        self.assertEqual([4, 4, 2, 0, 1], [int(c['trending_group']) for c in results])
-        self.assertEqual([53, 38, 2, 0, -6], [int(c['trending_mixed']) for c in results])
-
     @staticmethod
     def _get_x_with_claim_id_prefix(getter, prefix, cached_iteration=None, **kwargs):
         iterations = cached_iteration+1 if cached_iteration else 100
@@ -411,13 +392,15 @@ class TestSQLDB(unittest.TestCase):
         self.assertEqual("@foo#a/foo#ab", r_ab2['canonical_url'])
         self.assertEqual(2, self.sql._search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel'])

+        # claim abandon updates claims_in_channel
+        advance(10, [self.get_abandon(tx_ab2)])
+        self.assertEqual(1, self.sql._search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel'])
+
         # delete channel, invaliding stream claim signatures
-        advance(10, [self.get_abandon(channel_update)])
-        r_ab2, r_a2 = self.sql._search(order_by=['creation_height'], limit=2)
+        advance(11, [self.get_abandon(channel_update)])
+        r_a2, = self.sql._search(order_by=['creation_height'], limit=1)
         self.assertEqual(f"foo#{a2_claim_id[:2]}", r_a2['short_url'])
-        self.assertEqual(f"foo#{ab2_claim_id[:4]}", r_ab2['short_url'])
         self.assertIsNone(r_a2['canonical_url'])
-        self.assertIsNone(r_ab2['canonical_url'])

     def test_canonical_find_shortest_id(self):
         new_hash = 'abcdef0123456789beef'
@@ -434,3 +417,45 @@ class TestSQLDB(unittest.TestCase):
         self.assertEqual('#abcd', f.finalize())
         f.step(other3, new_hash)
         self.assertEqual('#abcdef0123456789beef', f.finalize())
+
+    def test_claims_in_channel_gets_updated(self):
+        advance, state = self.advance, self.state
+        tx_chan_a = self.get_channel_with_claim_id_prefix('a', 1, key=b'c')
+        tx_chan_ab = self.get_channel_with_claim_id_prefix('ab', 72, key=b'c')
+        txo_chan_a = tx_chan_a[0].tx.outputs[0]
+        advance(1, [tx_chan_a])
+        advance(2, [tx_chan_ab])
+        r_ab, r_a = self.sql._search(order_by=['creation_height'], limit=2)
+        self.assertEqual("@foo#a", r_a['short_url'])
+        self.assertEqual("@foo#ab", r_ab['short_url'])
+        self.assertIsNone(r_a['canonical_url'])
+        self.assertIsNone(r_ab['canonical_url'])
+        self.assertEqual(0, r_a['claims_in_channel'])
+        self.assertEqual(0, r_ab['claims_in_channel'])
+        tx_a2 = self.get_stream_with_claim_id_prefix('a', 7, channel=txo_chan_a)
+        tx_ab2 = self.get_stream_with_claim_id_prefix('ab', 23, channel=txo_chan_a)
+
+
+class TestTrending(TestSQLDB):
+
+    def test_trending(self):
+        advance, state = self.advance, self.state
+        no_trend = self.get_stream('Claim A', COIN)
+        downwards = self.get_stream('Claim B', COIN)
+        up_small = self.get_stream('Claim C', COIN)
+        up_medium = self.get_stream('Claim D', COIN)
+        up_biggly = self.get_stream('Claim E', COIN)
+        claims = advance(1, [up_biggly, up_medium, up_small, no_trend, downwards])
+        for window in range(1, 8):
+            advance(TRENDING_WINDOW * window, [
+                self.get_support(downwards, (20-window)*COIN),
+                self.get_support(up_small, int(20+(window/10)*COIN)),
+                self.get_support(up_medium, (20+(window*(2 if window == 7 else 1)))*COIN),
+                self.get_support(up_biggly, (20+(window*(3 if window == 7 else 1)))*COIN),
+            ])
+        results = self.sql._search(order_by=['trending_local'])
+        self.assertEqual([c.claim_id for c in claims], [hexlify(c['claim_hash'][::-1]).decode() for c in results])
+        self.assertEqual([10, 6, 2, 0, -2], [int(c['trending_local']) for c in results])
+        self.assertEqual([53, 38, -32, 0, -6], [int(c['trending_global']) for c in results])
+        self.assertEqual([4, 4, 2, 0, 1], [int(c['trending_group']) for c in results])
+        self.assertEqual([53, 38, 2, 0, -6], [int(c['trending_mixed']) for c in results])
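As committed, `test_claims_in_channel_gets_updated` stops after setting up the two in-channel streams and only pins down the initial zero counts; the abandon path itself is exercised by the `claims_in_channel` assertions added to the canonical-url test above. For an end-to-end check in one place, a hypothetical companion test (not part of this commit, built only from helpers already used above) could look like:

    # Hypothetical companion test (not in this diff): the channel's cached count
    # should rise when streams are published into it and drop again on abandon.
    def test_claims_in_channel_drops_on_abandon(self):
        advance = self.advance
        tx_chan_a = self.get_channel_with_claim_id_prefix('a', 1, key=b'c')
        txo_chan_a = tx_chan_a[0].tx.outputs[0]
        advance(1, [tx_chan_a])
        tx_a2 = self.get_stream_with_claim_id_prefix('a', 7, channel=txo_chan_a)
        tx_ab2 = self.get_stream_with_claim_id_prefix('ab', 23, channel=txo_chan_a)
        advance(2, [tx_a2, tx_ab2])
        self.assertEqual(2, self.sql._search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel'])
        advance(3, [self.get_abandon(tx_ab2)])
        self.assertEqual(1, self.sql._search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel'])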