From e39331442385dccdd6257dea5d92e241adf7ad2d Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Sat, 22 Jun 2019 19:25:32 -0400 Subject: [PATCH] fix for claims_in_channel not updating on abandoned claim --- lbry/lbry/wallet/server/db.py | 15 ++-- lbry/tests/unit/wallet/server/test_sqldb.py | 77 ++++++++++++++------- 2 files changed, 61 insertions(+), 31 deletions(-) diff --git a/lbry/lbry/wallet/server/db.py b/lbry/lbry/wallet/server/db.py index 0a3861cf9..422ad6aaa 100644 --- a/lbry/lbry/wallet/server/db.py +++ b/lbry/lbry/wallet/server/db.py @@ -342,9 +342,14 @@ class SQLDB: """ Deletes claim supports and from claimtrie in case of an abandon. """ if claim_hashes: binary_claim_hashes = [sqlite3.Binary(claim_hash) for claim_hash in claim_hashes] + affected_channels = self.execute(*query( + "SELECT channel_hash FROM claim", channel_hash__is_not_null=1, claim_hash__in=binary_claim_hashes + )).fetchall() for table in ('claim', 'support', 'claimtrie'): self.execute(*self._delete_sql(table, {'claim_hash__in': binary_claim_hashes})) self._clear_claim_metadata(binary_claim_hashes) + return set(r['channel_hash'] for r in affected_channels) + return set() def _clear_claim_metadata(self, binary_claim_hashes: List[sqlite3.Binary]): if binary_claim_hashes: @@ -389,7 +394,7 @@ class SQLDB: 'support', {'txo_hash__in': [sqlite3.Binary(txo_hash) for txo_hash in txo_hashes]} )) - def validate_channel_signatures(self, height, new_claims, updated_claims, spent_claims, timer): + def validate_channel_signatures(self, height, new_claims, updated_claims, spent_claims, affected_channels, timer): if not new_claims and not updated_claims and not spent_claims: return @@ -420,12 +425,12 @@ class SQLDB: sub_timer = timer.add_timer('lookup missing channels') sub_timer.start() all_channel_keys = {} - if new_channel_keys or missing_channel_keys: + if new_channel_keys or missing_channel_keys or affected_channels: all_channel_keys = dict(self.execute(*query( "SELECT claim_hash, public_key_bytes FROM claim", claim_hash__in=[ sqlite3.Binary(channel_hash) for channel_hash in - set(new_channel_keys) | missing_channel_keys + set(new_channel_keys) | missing_channel_keys | affected_channels ] ))) sub_timer.stop() @@ -715,12 +720,12 @@ class SQLDB: expire_timer.stop() r = timer.run - r(self.delete_claims, delete_claim_hashes) + affected_channels = r(self.delete_claims, delete_claim_hashes) r(self.delete_supports, delete_support_txo_hashes) r(self.insert_claims, insert_claims, header) r(self.update_claims, update_claims, header) r(self.validate_channel_signatures, height, insert_claims, - update_claims, delete_claim_hashes, forward_timer=True) + update_claims, delete_claim_hashes, affected_channels, forward_timer=True) r(self.insert_supports, insert_supports) r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True) r(calculate_trending, self.db, height, self.main.first_sync, daemon_height) diff --git a/lbry/tests/unit/wallet/server/test_sqldb.py b/lbry/tests/unit/wallet/server/test_sqldb.py index 7b6b59a2e..938616be1 100644 --- a/lbry/tests/unit/wallet/server/test_sqldb.py +++ b/lbry/tests/unit/wallet/server/test_sqldb.py @@ -144,6 +144,9 @@ class TestSQLDB(unittest.TestCase): self.assertEqual(active or [], self.get_active()) self.assertEqual(accepted or [], self.get_accepted()) + +class TestClaimtrie(TestSQLDB): + def test_example_from_spec(self): # https://spec.lbry.com/#claim-activation-example advance, state = self.advance, self.state @@ -310,28 +313,6 @@ class TestSQLDB(unittest.TestCase): accepted=[] ) - def test_trending(self): - advance, state = self.advance, self.state - no_trend = self.get_stream('Claim A', COIN) - downwards = self.get_stream('Claim B', COIN) - up_small = self.get_stream('Claim C', COIN) - up_medium = self.get_stream('Claim D', COIN) - up_biggly = self.get_stream('Claim E', COIN) - claims = advance(1, [up_biggly, up_medium, up_small, no_trend, downwards]) - for window in range(1, 8): - advance(TRENDING_WINDOW * window, [ - self.get_support(downwards, (20-window)*COIN), - self.get_support(up_small, int(20+(window/10)*COIN)), - self.get_support(up_medium, (20+(window*(2 if window == 7 else 1)))*COIN), - self.get_support(up_biggly, (20+(window*(3 if window == 7 else 1)))*COIN), - ]) - results = self.sql._search(order_by=['trending_local']) - self.assertEqual([c.claim_id for c in claims], [hexlify(c['claim_hash'][::-1]).decode() for c in results]) - self.assertEqual([10, 6, 2, 0, -2], [int(c['trending_local']) for c in results]) - self.assertEqual([53, 38, -32, 0, -6], [int(c['trending_global']) for c in results]) - self.assertEqual([4, 4, 2, 0, 1], [int(c['trending_group']) for c in results]) - self.assertEqual([53, 38, 2, 0, -6], [int(c['trending_mixed']) for c in results]) - @staticmethod def _get_x_with_claim_id_prefix(getter, prefix, cached_iteration=None, **kwargs): iterations = cached_iteration+1 if cached_iteration else 100 @@ -411,13 +392,15 @@ class TestSQLDB(unittest.TestCase): self.assertEqual("@foo#a/foo#ab", r_ab2['canonical_url']) self.assertEqual(2, self.sql._search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel']) + # claim abandon updates claims_in_channel + advance(10, [self.get_abandon(tx_ab2)]) + self.assertEqual(1, self.sql._search(claim_id=txo_chan_a.claim_id, limit=1)[0]['claims_in_channel']) + # delete channel, invaliding stream claim signatures - advance(10, [self.get_abandon(channel_update)]) - r_ab2, r_a2 = self.sql._search(order_by=['creation_height'], limit=2) + advance(11, [self.get_abandon(channel_update)]) + r_a2, = self.sql._search(order_by=['creation_height'], limit=1) self.assertEqual(f"foo#{a2_claim_id[:2]}", r_a2['short_url']) - self.assertEqual(f"foo#{ab2_claim_id[:4]}", r_ab2['short_url']) self.assertIsNone(r_a2['canonical_url']) - self.assertIsNone(r_ab2['canonical_url']) def test_canonical_find_shortest_id(self): new_hash = 'abcdef0123456789beef' @@ -434,3 +417,45 @@ class TestSQLDB(unittest.TestCase): self.assertEqual('#abcd', f.finalize()) f.step(other3, new_hash) self.assertEqual('#abcdef0123456789beef', f.finalize()) + + def test_claims_in_channel_gets_updated(self): + tx_chan_a = self.get_channel_with_claim_id_prefix('a', 1, key=b'c') + tx_chan_ab = self.get_channel_with_claim_id_prefix('ab', 72, key=b'c') + txo_chan_a = tx_chan_a[0].tx.outputs[0] + advance(1, [tx_chan_a]) + advance(2, [tx_chan_ab]) + r_ab, r_a = self.sql._search(order_by=['creation_height'], limit=2) + self.assertEqual("@foo#a", r_a['short_url']) + self.assertEqual("@foo#ab", r_ab['short_url']) + self.assertIsNone(r_a['canonical_url']) + self.assertIsNone(r_ab['canonical_url']) + self.assertEqual(0, r_a['claims_in_channel']) + self.assertEqual(0, r_ab['claims_in_channel']) + + tx_a2 = self.get_stream_with_claim_id_prefix('a', 7, channel=txo_chan_a) + tx_ab2 = self.get_stream_with_claim_id_prefix('ab', 23, channel=txo_chan_a) + + +class TestTrending(TestSQLDB): + + def test_trending(self): + advance, state = self.advance, self.state + no_trend = self.get_stream('Claim A', COIN) + downwards = self.get_stream('Claim B', COIN) + up_small = self.get_stream('Claim C', COIN) + up_medium = self.get_stream('Claim D', COIN) + up_biggly = self.get_stream('Claim E', COIN) + claims = advance(1, [up_biggly, up_medium, up_small, no_trend, downwards]) + for window in range(1, 8): + advance(TRENDING_WINDOW * window, [ + self.get_support(downwards, (20-window)*COIN), + self.get_support(up_small, int(20+(window/10)*COIN)), + self.get_support(up_medium, (20+(window*(2 if window == 7 else 1)))*COIN), + self.get_support(up_biggly, (20+(window*(3 if window == 7 else 1)))*COIN), + ]) + results = self.sql._search(order_by=['trending_local']) + self.assertEqual([c.claim_id for c in claims], [hexlify(c['claim_hash'][::-1]).decode() for c in results]) + self.assertEqual([10, 6, 2, 0, -2], [int(c['trending_local']) for c in results]) + self.assertEqual([53, 38, -32, 0, -6], [int(c['trending_global']) for c in results]) + self.assertEqual([4, 4, 2, 0, 1], [int(c['trending_group']) for c in results]) + self.assertEqual([53, 38, 2, 0, -6], [int(c['trending_mixed']) for c in results])