diff --git a/lbrynet/wallet/server/db.py b/lbrynet/wallet/server/db.py index 48730d892..027f327f3 100644 --- a/lbrynet/wallet/server/db.py +++ b/lbrynet/wallet/server/db.py @@ -9,6 +9,9 @@ from torba.client.basedatabase import query, constraints_to_sql from lbrynet.schema.url import URL, normalize_name from lbrynet.wallet.transaction import Transaction, Output +from lbrynet.wallet.server.trending import ( + CREATE_TREND_TABLE, calculate_trending, register_trending_functions +) ATTRIBUTE_ARRAY_MAX_LENGTH = 100 @@ -56,8 +59,6 @@ def _apply_constraints_for_array_attributes(constraints, attr): class SQLDB: - DAY_BLOCKS = 720 - PRAGMAS = """ pragma journal_mode=WAL; """ @@ -98,16 +99,6 @@ class SQLDB: create index if not exists claim_trending_global_idx on claim (trending_global); """ - CREATE_TREND_TABLE = """ - create table if not exists trend ( - claim_hash bytes not null, - height integer not null, - amount integer not null, - primary key (claim_hash, height) - ) without rowid; - create index if not exists trend_claim_hash_idx on trend (claim_hash); - """ - CREATE_SUPPORT_TABLE = """ create table if not exists support ( txo_hash bytes primary key, @@ -159,6 +150,7 @@ class SQLDB: self.db = sqlite3.connect(self._db_path, isolation_level=None, check_same_thread=False) self.db.row_factory = sqlite3.Row self.db.executescript(self.CREATE_TABLES_QUERY) + register_trending_functions(self.db) def close(self): self.db.close() @@ -328,39 +320,6 @@ class SQLDB: 'support', {'txo_hash__in': [sqlite3.Binary(txo_hash) for txo_hash in txo_hashes]} )) - def _update_trending_amount(self, height): - day_ago = height-self.TRENDING_24_HOURS - two_day_ago = height-self.TRENDING_24_HOURS*2 - week_ago = height-self.TRENDING_WEEK - two_week_ago = height-self.TRENDING_WEEK*2 - self.execute(f""" - UPDATE claim SET - trending_day_one = COALESCE( - (SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash - AND height >= {day_ago}), 0 - ), - trending_day_two = COALESCE( - (SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash - AND {day_ago} > height and height >= {two_day_ago} - ), 0 - ), - trending_week_one = COALESCE( - (SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash - AND height >= {week_ago} - ), 0 - ), - trending_week_two = COALESCE( - (SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash - AND {week_ago} > height and height >= {two_week_ago} - ), 0 - ) - """) - self.execute(f""" - UPDATE claim SET - trending_daily = trending_day_one - trending_day_two, - trending_weekly = trending_week_one - trending_week_two - """) - def _update_support_amount(self, claim_hashes): if claim_hashes: self.execute(f""" @@ -440,10 +399,11 @@ class SQLDB: self.execute(f"CREATE TABLE claimtrie{height} AS SELECT * FROM claimtrie") def update_claimtrie(self, height, changed_claim_hashes, deleted_names, timer): + r = timer.run binary_claim_hashes = [ sqlite3.Binary(claim_hash) for claim_hash in changed_claim_hashes ] - r = timer.run + r(self._calculate_activation_height, height) r(self._update_support_amount, binary_claim_hashes) @@ -453,9 +413,6 @@ class SQLDB: r(self._update_effective_amount, height) r(self._perform_overtake, height, [], []) - #if not self.main.first_sync: - # r(self._update_trending_amount, height) - def advance_txs(self, height, all_txs, header, timer): insert_claims = set() update_claims = set() @@ -500,6 +457,8 @@ class SQLDB: r(self.update_claims, update_claims, header) r(self.insert_supports, insert_supports) r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True) + if not self.main.first_sync: + r(calculate_trending, self.db, height) def get_claims(self, cols, **constraints): if 'order_by' in constraints: diff --git a/lbrynet/wallet/server/trending.py b/lbrynet/wallet/server/trending.py new file mode 100644 index 000000000..87f245963 --- /dev/null +++ b/lbrynet/wallet/server/trending.py @@ -0,0 +1,109 @@ +from math import sqrt + +TRENDING_WINDOW = 650 # number of blocks, ~24hr period +TRENDING_DATA_POINTS = 7 # WINDOW * DATA_POINTS = ~1 week worth of trending data + +CREATE_TREND_TABLE = """ + create table if not exists trend ( + claim_hash bytes not null, + height integer not null, + amount integer not null, + primary key (claim_hash, height) + ) without rowid; +""" + + +class ZScore: + __slots__ = 'count', 'total', 'power', 'last' + + def __init__(self): + self.count = 0 + self.total = 0 + self.power = 0 + self.last = None + + def step(self, value): + if self.last is not None: + self.count += 1 + self.total += self.last + self.power += self.last**2 + self.last = value + + @property + def mean(self): + return self.total / self.count + + @property + def standard_deviation(self): + return sqrt((self.power / self.count) - self.mean**2) + + def finalize(self): + if self.count == 0: + return self.last + return (self.last - self.mean) / (self.standard_deviation or 1) + + +def register_trending_functions(connection): + connection.create_aggregate("zscore", 1, ZScore) + + +def calculate_trending(db, height): + if height % TRENDING_WINDOW != 0: + return + + db.execute(f""" + DELETE FROM trend WHERE height < {height-(TRENDING_WINDOW*TRENDING_DATA_POINTS)} + """) + + start = (height-TRENDING_WINDOW)+1 + db.execute(f""" + INSERT INTO trend (claim_hash, height, amount) + SELECT claim_hash, {start}, COALESCE( + (SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash + AND height >= {start}), 0 + ) AS support_sum + FROM claim WHERE support_sum > 0 + """) + + zscore = ZScore() + for (global_sum,) in db.execute("SELECT AVG(amount) FROM trend GROUP BY height"): + zscore.step(global_sum) + global_mean, global_deviation = 0, 1 + if zscore.count > 0: + global_mean = zscore.mean + global_deviation = zscore.standard_deviation + + db.execute(f""" + UPDATE claim SET + trending_local = COALESCE(( + SELECT zscore(amount) FROM trend + WHERE claim_hash=claim.claim_hash ORDER BY height DESC + ), 0), + trending_global = COALESCE(( + SELECT (amount - {global_mean}) / {global_deviation} FROM trend + WHERE claim_hash=claim.claim_hash AND height = {start} + ), 0), + trending_group = 0, + trending_mixed = 0 + """) + + # trending_group and trending_mixed determine how trending will show in query results + # normally the SQL will be: "ORDER BY trending_group, trending_mixed" + # changing the trending_group will have significant impact on trending results + # changing the value used for trending_mixed will only impact trending within a trending_group + db.execute(f""" + UPDATE claim SET + trending_group = CASE + WHEN trending_local > 0 AND trending_global > 0 THEN 4 + WHEN trending_local <= 0 AND trending_global > 0 THEN 3 + WHEN trending_local > 0 AND trending_global <= 0 THEN 2 + WHEN trending_local <= 0 AND trending_global <= 0 THEN 1 + END, + trending_mixed = CASE + WHEN trending_local > 0 AND trending_global > 0 THEN trending_global + WHEN trending_local <= 0 AND trending_global > 0 THEN trending_local + WHEN trending_local > 0 AND trending_global <= 0 THEN trending_local + WHEN trending_local <= 0 AND trending_global <= 0 THEN trending_global + END + WHERE trending_local <> 0 OR trending_global <> 0 + """) diff --git a/tests/unit/wallet/server/test_sqldb.py b/tests/unit/wallet/server/test_sqldb.py index 96223b763..550b25767 100644 --- a/tests/unit/wallet/server/test_sqldb.py +++ b/tests/unit/wallet/server/test_sqldb.py @@ -1,8 +1,10 @@ import unittest +from binascii import hexlify from torba.client.constants import COIN, NULL_HASH32 from lbrynet.schema.claim import Claim from lbrynet.wallet.server.db import SQLDB +from lbrynet.wallet.server.trending import TRENDING_WINDOW from lbrynet.wallet.server.block_processor import Timer from lbrynet.wallet.transaction import Transaction, Input, Output @@ -21,6 +23,10 @@ def get_tx(): return Transaction().add_inputs([get_input()]) +def claim_id(claim_hash): + return hexlify(claim_hash[::-1]).decode() + + class OldWalletServerTransaction: def __init__(self, tx): self.tx = tx @@ -110,6 +116,7 @@ class TestSQLDB(unittest.TestCase): def advance(self, height, txs): self._current_height = height self.sql.advance_txs(height, txs, {'timestamp': 1}, self.timer) + return [otx[0].tx.outputs[0] for otx in txs] def state(self, controlling=None, active=None, accepted=None): self.assertEqual(controlling or [], self.get_controlling()) @@ -259,3 +266,25 @@ class TestSQLDB(unittest.TestCase): active=[('Claim A', 10*COIN, 10*COIN, 13)], accepted=[] ) + + def test_trending(self): + advance, state = self.advance, self.state + no_trend = self.get_stream('Claim A', COIN) + downwards = self.get_stream('Claim B', COIN) + up_small = self.get_stream('Claim C', COIN) + up_medium = self.get_stream('Claim D', COIN) + up_biggly = self.get_stream('Claim E', COIN) + claims = advance(1, [up_biggly, up_medium, up_small, no_trend, downwards]) + for window in range(1, 8): + advance(TRENDING_WINDOW * window, [ + self.get_support(downwards, (20-window)*COIN), + self.get_support(up_small, int(20+(window/10)*COIN)), + self.get_support(up_medium, (20+(window*(2 if window == 7 else 1)))*COIN), + self.get_support(up_biggly, (20+(window*(3 if window == 7 else 1)))*COIN), + ]) + results = self.sql._search(order_by=['trending_local']) + self.assertEqual([c.claim_id for c in claims], [claim_id(c['claim_hash']) for c in results]) + self.assertEqual([10, 6, 2, 0, -2], [int(c['trending_local']) for c in results]) + self.assertEqual([53, 38, -32, 0, -6], [int(c['trending_global']) for c in results]) + self.assertEqual([4, 4, 2, 0, 1], [int(c['trending_group']) for c in results]) + self.assertEqual([53, 38, 2, 0, -6], [int(c['trending_mixed']) for c in results])