From b8897223ec781bb9098be3cf3042db0cbece9f02 Mon Sep 17 00:00:00 2001
From: Lex Berezhny <lex@damoti.com>
Date: Sun, 19 May 2019 15:57:39 -0400
Subject: [PATCH] added zscore trending in wallet server

---
 lbrynet/wallet/server/db.py            |  57 ++-----------
 lbrynet/wallet/server/trending.py      | 109 +++++++++++++++++++++++++
 tests/unit/wallet/server/test_sqldb.py |  29 +++++++
 3 files changed, 146 insertions(+), 49 deletions(-)
 create mode 100644 lbrynet/wallet/server/trending.py

diff --git a/lbrynet/wallet/server/db.py b/lbrynet/wallet/server/db.py
index 48730d892..027f327f3 100644
--- a/lbrynet/wallet/server/db.py
+++ b/lbrynet/wallet/server/db.py
@@ -9,6 +9,9 @@ from torba.client.basedatabase import query, constraints_to_sql
 
 from lbrynet.schema.url import URL, normalize_name
 from lbrynet.wallet.transaction import Transaction, Output
+from lbrynet.wallet.server.trending import (
+    CREATE_TREND_TABLE, calculate_trending, register_trending_functions
+)
 
 
 ATTRIBUTE_ARRAY_MAX_LENGTH = 100
@@ -56,8 +59,6 @@ def _apply_constraints_for_array_attributes(constraints, attr):
 
 class SQLDB:
 
-    DAY_BLOCKS = 720
-
     PRAGMAS = """
         pragma journal_mode=WAL;
     """
@@ -98,16 +99,6 @@ class SQLDB:
         create index if not exists claim_trending_global_idx on claim (trending_global);
     """
 
-    CREATE_TREND_TABLE = """
-        create table if not exists trend (
-            claim_hash bytes not null,
-            height integer not null,
-            amount integer not null,
-            primary key (claim_hash, height)
-        ) without rowid;
-        create index if not exists trend_claim_hash_idx on trend (claim_hash);
-    """
-
     CREATE_SUPPORT_TABLE = """
         create table if not exists support (
             txo_hash bytes primary key,
@@ -159,6 +150,7 @@ class SQLDB:
         self.db = sqlite3.connect(self._db_path, isolation_level=None, check_same_thread=False)
         self.db.row_factory = sqlite3.Row
         self.db.executescript(self.CREATE_TABLES_QUERY)
+        register_trending_functions(self.db)
 
     def close(self):
         self.db.close()
@@ -328,39 +320,6 @@ class SQLDB:
                 'support', {'txo_hash__in': [sqlite3.Binary(txo_hash) for txo_hash in txo_hashes]}
             ))
 
-    def _update_trending_amount(self, height):
-        day_ago = height-self.TRENDING_24_HOURS
-        two_day_ago = height-self.TRENDING_24_HOURS*2
-        week_ago = height-self.TRENDING_WEEK
-        two_week_ago = height-self.TRENDING_WEEK*2
-        self.execute(f"""
-            UPDATE claim SET
-                trending_day_one = COALESCE(
-                    (SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash
-                     AND height >= {day_ago}), 0
-                ),
-                trending_day_two = COALESCE(
-                    (SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash
-                     AND {day_ago} > height and height >= {two_day_ago}
-                     ), 0
-                ),
-                trending_week_one = COALESCE(
-                    (SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash
-                     AND height >= {week_ago}
-                     ), 0
-                ),
-                trending_week_two = COALESCE(
-                    (SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash
-                     AND {week_ago} > height and height >= {two_week_ago}
-                     ), 0
-                )
-        """)
-        self.execute(f"""
-            UPDATE claim SET
-                trending_daily = trending_day_one - trending_day_two,
-                trending_weekly = trending_week_one - trending_week_two
-        """)
-
     def _update_support_amount(self, claim_hashes):
         if claim_hashes:
             self.execute(f"""
@@ -440,10 +399,11 @@ class SQLDB:
         self.execute(f"CREATE TABLE claimtrie{height} AS SELECT * FROM claimtrie")
 
     def update_claimtrie(self, height, changed_claim_hashes, deleted_names, timer):
+        r = timer.run
         binary_claim_hashes = [
             sqlite3.Binary(claim_hash) for claim_hash in changed_claim_hashes
         ]
-        r = timer.run
+
         r(self._calculate_activation_height, height)
         r(self._update_support_amount, binary_claim_hashes)
 
@@ -453,9 +413,6 @@ class SQLDB:
         r(self._update_effective_amount, height)
         r(self._perform_overtake, height, [], [])
 
-        #if not self.main.first_sync:
-        #    r(self._update_trending_amount, height)
-
     def advance_txs(self, height, all_txs, header, timer):
         insert_claims = set()
         update_claims = set()
@@ -500,6 +457,8 @@ class SQLDB:
         r(self.update_claims, update_claims, header)
         r(self.insert_supports, insert_supports)
         r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True)
+        if not self.main.first_sync:
+            r(calculate_trending, self.db, height)
 
     def get_claims(self, cols, **constraints):
         if 'order_by' in constraints:
diff --git a/lbrynet/wallet/server/trending.py b/lbrynet/wallet/server/trending.py
new file mode 100644
index 000000000..87f245963
--- /dev/null
+++ b/lbrynet/wallet/server/trending.py
@@ -0,0 +1,156 @@
+from math import sqrt
+
+TRENDING_WINDOW = 650  # number of blocks, ~24hr period
+TRENDING_DATA_POINTS = 7  # WINDOW * DATA_POINTS = ~1 week worth of trending data
+
+# One row per (claim_hash, height) pair; calculate_trending() stores the first
+# block of a window in `height` and that window's support total in `amount`.
+CREATE_TREND_TABLE = """
+    create table if not exists trend (
+        claim_hash bytes not null,
+        height integer not null,
+        amount integer not null,
+        primary key (claim_hash, height)
+    ) without rowid;
+"""
+
+
+class ZScore:
+    """Z-score accumulator, used as an SQLite aggregate and directly in Python.
+
+    Values are fed in through step(). The most recently stepped value is held
+    back in `last`; every earlier value contributes to the running count, sum
+    and sum of squares. finalize() then scores `last` against the mean and
+    standard deviation of those earlier values.
+    """
+
+    __slots__ = 'count', 'total', 'power', 'last'
+
+    def __init__(self):
+        self.count = 0  # number of values stepped before `last`
+        self.total = 0  # sum of those values
+        self.power = 0  # sum of their squares
+        self.last = None  # latest value stepped, not yet folded into the sums
+
+    def step(self, value):
+        """Fold the previously held value into the sums and hold `value` instead."""
+        if self.last is not None:
+            self.count += 1
+            self.total += self.last
+            self.power += self.last**2
+        self.last = value
+
+    @property
+    def mean(self):
+        """Mean of the values before `last`; ZeroDivisionError if there are none."""
+        return self.total / self.count
+
+    @property
+    def standard_deviation(self):
+        """Population standard deviation of the values before `last`.
+
+        Raises ZeroDivisionError if there are no such values.
+        """
+        variance = (self.power / self.count) - self.mean**2
+        # Floating point rounding can push a near-zero variance slightly below
+        # zero, which would make sqrt() raise "math domain error".
+        return sqrt(max(variance, 0))
+
+    def finalize(self):
+        """Return the z-score of `last`.
+
+        With fewer than two stepped values there is no history to compare with,
+        so `last` itself is returned (None if nothing was stepped). A zero
+        standard deviation is replaced by 1 to avoid dividing by zero.
+        """
+        if self.count == 0:
+            return self.last
+        return (self.last - self.mean) / (self.standard_deviation or 1)
+
+
+def register_trending_functions(connection):
+    """Register ZScore on `connection` as the one-argument SQL aggregate `zscore`."""
+    connection.create_aggregate("zscore", 1, ZScore)
+
+
+def calculate_trending(db, height):
+    """Recompute the trending columns of every claim at a window boundary.
+
+    Does nothing unless `height` is a multiple of TRENDING_WINDOW. Otherwise it
+    deletes trend rows more than TRENDING_DATA_POINTS windows below `height`,
+    inserts one trend row per claim whose supports at or above the window's
+    first block sum to more than zero, and rewrites trending_local,
+    trending_global, trending_group and trending_mixed on the claim table.
+
+    `db` must provide execute() and have the `zscore` aggregate registered
+    (see register_trending_functions).
+    """
+    if height % TRENDING_WINDOW != 0:
+        return
+
+    db.execute(f"""
+    DELETE FROM trend WHERE height < {height-(TRENDING_WINDOW*TRENDING_DATA_POINTS)}
+    """)
+
+    # first block of the window that ends at `height`
+    start = (height-TRENDING_WINDOW)+1
+    db.execute(f"""
+    INSERT INTO trend (claim_hash, height, amount)
+    SELECT claim_hash, {start}, COALESCE(
+            (SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash
+             AND height >= {start}), 0
+        ) AS support_sum
+    FROM claim WHERE support_sum > 0
+    """)
+
+    # Statistics over the per-window averages. ZScore keeps the last value it
+    # was given out of its sums, so stepping in height order leaves the newest
+    # window out of the mean and deviation.
+    zscore = ZScore()
+    averages = db.execute("SELECT AVG(amount) FROM trend GROUP BY height ORDER BY height")
+    for (window_average,) in averages:
+        zscore.step(window_average)
+    global_mean, global_deviation = 0, 1
+    if zscore.count > 0:
+        global_mean = zscore.mean
+        # same zero-deviation fallback as ZScore.finalize(); dividing by 0 in
+        # SQLite yields NULL, which would turn every trending_global below into 0
+        global_deviation = zscore.standard_deviation or 1
+
+    # NOTE(review): the ORDER BY in the zscore() subquery applies to its single
+    # result row, not to the order rows are fed to the aggregate -- confirm the
+    # scan order is what ZScore expects (oldest first, newest last).
+    db.execute(f"""
+    UPDATE claim SET
+        trending_local = COALESCE((
+            SELECT zscore(amount) FROM trend
+            WHERE claim_hash=claim.claim_hash ORDER BY height DESC
+        ), 0),
+        trending_global = COALESCE((
+            SELECT (amount - {global_mean}) / {global_deviation} FROM trend
+            WHERE claim_hash=claim.claim_hash AND height = {start}
+        ), 0),
+        trending_group = 0,
+        trending_mixed = 0
+    """)
+
+    # trending_group and trending_mixed determine how trending will show in query results
+    # normally the SQL will be: "ORDER BY trending_group, trending_mixed"
+    # changing the trending_group will have significant impact on trending results
+    # changing the value used for trending_mixed will only impact trending within a trending_group
+    db.execute(f"""
+    UPDATE claim SET
+        trending_group = CASE 
+        WHEN trending_local > 0 AND trending_global > 0 THEN 4
+        WHEN trending_local <= 0 AND trending_global > 0 THEN 3
+        WHEN trending_local > 0 AND trending_global <= 0 THEN 2
+        WHEN trending_local <= 0 AND trending_global <= 0 THEN 1
+        END,
+        trending_mixed = CASE 
+        WHEN trending_local > 0 AND trending_global > 0 THEN trending_global
+        WHEN trending_local <= 0 AND trending_global > 0 THEN trending_local
+        WHEN trending_local > 0 AND trending_global <= 0 THEN trending_local
+        WHEN trending_local <= 0 AND trending_global <= 0 THEN trending_global
+        END
+    WHERE trending_local <> 0 OR trending_global <> 0
+    """)
diff --git a/tests/unit/wallet/server/test_sqldb.py b/tests/unit/wallet/server/test_sqldb.py
index 96223b763..550b25767 100644
--- a/tests/unit/wallet/server/test_sqldb.py
+++ b/tests/unit/wallet/server/test_sqldb.py
@@ -1,8 +1,10 @@
 import unittest
+from binascii import hexlify
 from torba.client.constants import COIN, NULL_HASH32
 
 from lbrynet.schema.claim import Claim
 from lbrynet.wallet.server.db import SQLDB
+from lbrynet.wallet.server.trending import TRENDING_WINDOW
 from lbrynet.wallet.server.block_processor import Timer
 from lbrynet.wallet.transaction import Transaction, Input, Output
 
@@ -21,6 +23,10 @@ def get_tx():
     return Transaction().add_inputs([get_input()])
 
 
+def claim_id(claim_hash):  # hex string of the hash bytes in reverse order
+    return hexlify(claim_hash[::-1]).decode()
+
+
 class OldWalletServerTransaction:
     def __init__(self, tx):
         self.tx = tx
@@ -110,6 +116,7 @@ class TestSQLDB(unittest.TestCase):
     def advance(self, height, txs):
         self._current_height = height
         self.sql.advance_txs(height, txs, {'timestamp': 1}, self.timer)
+        return [otx[0].tx.outputs[0] for otx in txs]
 
     def state(self, controlling=None, active=None, accepted=None):
         self.assertEqual(controlling or [], self.get_controlling())
@@ -259,3 +266,25 @@ class TestSQLDB(unittest.TestCase):
             active=[('Claim A', 10*COIN, 10*COIN, 13)],
             accepted=[]
         )
+
+    def test_trending(self):  # checks the trending columns after seven windows of supports
+        advance, state = self.advance, self.state
+        no_trend = self.get_stream('Claim A', COIN)
+        downwards = self.get_stream('Claim B', COIN)
+        up_small = self.get_stream('Claim C', COIN)
+        up_medium = self.get_stream('Claim D', COIN)
+        up_biggly = self.get_stream('Claim E', COIN)
+        claims = advance(1, [up_biggly, up_medium, up_small, no_trend, downwards])  # order asserted below
+        for window in range(1, 8):  # windows 1..7, one batch of supports at each multiple of TRENDING_WINDOW
+            advance(TRENDING_WINDOW * window, [
+                self.get_support(downwards, (20-window)*COIN),
+                self.get_support(up_small, int(20+(window/10)*COIN)),  # NOTE(review): only window/10 is scaled by COIN -- intended?
+                self.get_support(up_medium, (20+(window*(2 if window == 7 else 1)))*COIN),
+                self.get_support(up_biggly, (20+(window*(3 if window == 7 else 1)))*COIN),
+            ])
+        results = self.sql._search(order_by=['trending_local'])
+        self.assertEqual([c.claim_id for c in claims], [claim_id(c['claim_hash']) for c in results])
+        self.assertEqual([10, 6, 2, 0, -2], [int(c['trending_local']) for c in results])
+        self.assertEqual([53, 38, -32, 0, -6], [int(c['trending_global']) for c in results])
+        self.assertEqual([4, 4, 2, 0, 1], [int(c['trending_group']) for c in results])
+        self.assertEqual([53, 38, 2, 0, -6], [int(c['trending_mixed']) for c in results])