added zscore trending in wallet server

This commit is contained in:
Lex Berezhny 2019-05-19 15:57:39 -04:00
parent c2c184b4ef
commit b8897223ec
3 changed files with 146 additions and 49 deletions

View file

@ -9,6 +9,9 @@ from torba.client.basedatabase import query, constraints_to_sql
from lbrynet.schema.url import URL, normalize_name
from lbrynet.wallet.transaction import Transaction, Output
from lbrynet.wallet.server.trending import (
CREATE_TREND_TABLE, calculate_trending, register_trending_functions
)
ATTRIBUTE_ARRAY_MAX_LENGTH = 100
@ -56,8 +59,6 @@ def _apply_constraints_for_array_attributes(constraints, attr):
class SQLDB:
DAY_BLOCKS = 720
PRAGMAS = """
pragma journal_mode=WAL;
"""
@ -98,16 +99,6 @@ class SQLDB:
create index if not exists claim_trending_global_idx on claim (trending_global);
"""
CREATE_TREND_TABLE = """
create table if not exists trend (
claim_hash bytes not null,
height integer not null,
amount integer not null,
primary key (claim_hash, height)
) without rowid;
create index if not exists trend_claim_hash_idx on trend (claim_hash);
"""
CREATE_SUPPORT_TABLE = """
create table if not exists support (
txo_hash bytes primary key,
@ -159,6 +150,7 @@ class SQLDB:
self.db = sqlite3.connect(self._db_path, isolation_level=None, check_same_thread=False)
self.db.row_factory = sqlite3.Row
self.db.executescript(self.CREATE_TABLES_QUERY)
register_trending_functions(self.db)
def close(self):
    """Close the database connection held in ``self.db``."""
    connection = self.db
    connection.close()
@ -328,39 +320,6 @@ class SQLDB:
'support', {'txo_hash__in': [sqlite3.Binary(txo_hash) for txo_hash in txo_hashes]}
))
def _update_trending_amount(self, height):
    """Recompute the trending columns of every row in ``claim``.

    Support amounts are summed over four height ranges measured back from
    ``height`` (the latest and the preceding "day" range, the latest and the
    preceding "week" range), then the daily/weekly trend is stored as the
    difference between the latest range and the one before it.

    The range lengths come from ``self.TRENDING_24_HOURS`` and
    ``self.TRENDING_WEEK``, which are defined outside this view
    (presumably block counts -- confirm against the class definition).

    :param height: block height the ranges are measured back from
    """
    day_span = self.TRENDING_24_HOURS
    week_span = self.TRENDING_WEEK
    day_cutoff, prior_day_cutoff = height - day_span, height - day_span * 2
    week_cutoff, prior_week_cutoff = height - week_span, height - week_span * 2

    # Step 1: per-range support totals (0 when a claim has no supports in range).
    range_totals_sql = f"""
UPDATE claim SET
trending_day_one = COALESCE(
(SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash
AND height >= {day_cutoff}), 0
),
trending_day_two = COALESCE(
(SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash
AND {day_cutoff} > height and height >= {prior_day_cutoff}
), 0
),
trending_week_one = COALESCE(
(SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash
AND height >= {week_cutoff}
), 0
),
trending_week_two = COALESCE(
(SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash
AND {week_cutoff} > height and height >= {prior_week_cutoff}
), 0
)
"""
    self.execute(range_totals_sql)

    # Step 2: the trend is the change between consecutive ranges.
    range_deltas_sql = """
UPDATE claim SET
trending_daily = trending_day_one - trending_day_two,
trending_weekly = trending_week_one - trending_week_two
"""
    self.execute(range_deltas_sql)
def _update_support_amount(self, claim_hashes):
if claim_hashes:
self.execute(f"""
@ -440,10 +399,11 @@ class SQLDB:
self.execute(f"CREATE TABLE claimtrie{height} AS SELECT * FROM claimtrie")
def update_claimtrie(self, height, changed_claim_hashes, deleted_names, timer):
r = timer.run
binary_claim_hashes = [
sqlite3.Binary(claim_hash) for claim_hash in changed_claim_hashes
]
r = timer.run
r(self._calculate_activation_height, height)
r(self._update_support_amount, binary_claim_hashes)
@ -453,9 +413,6 @@ class SQLDB:
r(self._update_effective_amount, height)
r(self._perform_overtake, height, [], [])
#if not self.main.first_sync:
# r(self._update_trending_amount, height)
def advance_txs(self, height, all_txs, header, timer):
insert_claims = set()
update_claims = set()
@ -500,6 +457,8 @@ class SQLDB:
r(self.update_claims, update_claims, header)
r(self.insert_supports, insert_supports)
r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True)
if not self.main.first_sync:
r(calculate_trending, self.db, height)
def get_claims(self, cols, **constraints):
if 'order_by' in constraints:

View file

@ -0,0 +1,109 @@
from math import sqrt
# Number of blocks in one trending window (~24hr period). calculate_trending()
# only does work at heights that are an exact multiple of this value.
TRENDING_WINDOW = 650
# Number of windows of history retained in the trend table:
# WINDOW * DATA_POINTS = ~1 week worth of trending data. Rows older than that
# are deleted each time calculate_trending() runs.
TRENDING_DATA_POINTS = 7
# Schema of the trend history table: one row per (claim_hash, height) holding
# the support amount recorded for that claim at that window's height.
CREATE_TREND_TABLE = """
create table if not exists trend (
claim_hash bytes not null,
height integer not null,
amount integer not null,
primary key (claim_hash, height)
) without rowid;
"""
class ZScore:
    """Aggregate that scores the most recent value against all earlier ones.

    Values are fed one at a time through :meth:`step`. The latest value is
    held back in ``last``; only the values before it contribute to the
    running ``count``, ``total`` (sum) and ``power`` (sum of squares).
    :meth:`finalize` then returns how many standard deviations ``last`` lies
    from the mean of the earlier values.

    The ``step``/``finalize`` pair is the interface expected by
    ``sqlite3.Connection.create_aggregate``.
    """

    __slots__ = 'count', 'total', 'power', 'last'

    def __init__(self):
        self.count = 0      # number of values seen before ``last``
        self.total = 0      # sum of those values
        self.power = 0      # sum of their squares
        self.last = None    # most recent value, not yet folded into the stats

    def step(self, value):
        """Fold the previously held value into the stats and hold ``value``."""
        if self.last is not None:
            self.count += 1
            self.total += self.last
            self.power += self.last ** 2
        self.last = value

    @property
    def mean(self):
        """Mean of the values seen before ``last`` (requires ``count > 0``)."""
        return self.total / self.count

    @property
    def standard_deviation(self):
        """Population standard deviation of the values seen before ``last``.

        Requires ``count > 0``.
        """
        variance = (self.power / self.count) - self.mean ** 2
        # E[x^2] - mean^2 can come out a hair below zero through float
        # rounding when the values are (nearly) constant; clamp it so sqrt()
        # does not raise "math domain error".
        return sqrt(max(variance, 0))

    def finalize(self):
        """Return the z-score of ``last``.

        With fewer than two values there is nothing to compare against, so
        ``last`` itself is returned (``None`` if no value was ever stepped).
        A zero deviation is replaced by 1 to avoid dividing by zero.
        """
        if self.count == 0:
            return self.last
        return (self.last - self.mean) / (self.standard_deviation or 1)
def register_trending_functions(connection):
    """Make the ``zscore`` SQL aggregate available on ``connection``.

    :param connection: object exposing ``create_aggregate(name, n_args, cls)``,
        e.g. a ``sqlite3`` connection
    """
    aggregate_name = "zscore"
    argument_count = 1  # used in SQL as zscore(<one column>)
    connection.create_aggregate(aggregate_name, argument_count, ZScore)
def calculate_trending(db, height):
    """Refresh the trending columns on ``claim`` at the end of a trending window.

    Nothing happens unless ``height`` is an exact multiple of
    ``TRENDING_WINDOW``. Otherwise the function:

    1. deletes ``trend`` rows older than ``TRENDING_WINDOW * TRENDING_DATA_POINTS``
       blocks,
    2. records one new ``trend`` row per claim that received supports during
       the window that just ended,
    3. computes the mean/deviation of the per-window average amounts seen
       before the newest window,
    4. writes ``trending_local``, ``trending_global``, ``trending_group`` and
       ``trending_mixed`` on every claim.

    :param db: sqlite3 connection on which the ``zscore`` aggregate has been
        registered (see ``register_trending_functions``)
    :param height: height of the block that was just processed
    """
    if height % TRENDING_WINDOW != 0:
        return

    # Values are bound as parameters rather than spliced into the SQL text.
    oldest_kept = height - (TRENDING_WINDOW * TRENDING_DATA_POINTS)
    db.execute("DELETE FROM trend WHERE height < ?", (oldest_kept,))

    # First block of the window that just ended; also used as the key under
    # which this window's data point is stored.
    start = (height - TRENDING_WINDOW) + 1
    db.execute("""
        INSERT INTO trend (claim_hash, height, amount)
            SELECT claim_hash, ?, COALESCE(
                (SELECT SUM(amount) FROM support WHERE claim_hash=claim.claim_hash
                 AND height >= ?), 0
            ) AS support_sum
            FROM claim WHERE support_sum > 0
    """, (start, start))

    # ZScore holds back the final value it is given, so the windows must be
    # fed oldest-first for the newest window to be the one excluded from the
    # mean/deviation. GROUP BY alone does not guarantee output order.
    zscore = ZScore()
    window_averages = db.execute(
        "SELECT AVG(amount) FROM trend GROUP BY height ORDER BY height"
    )
    for (window_average,) in window_averages:
        zscore.step(window_average)
    global_mean, global_deviation = 0, 1
    if zscore.count > 0:
        global_mean = zscore.mean
        # Same zero guard as ZScore.finalize(): dividing by a zero deviation
        # in SQLite yields NULL, which would zero out trending_global.
        global_deviation = zscore.standard_deviation or 1

    # NOTE(review): the ORDER BY in the zscore() subquery applies to the single
    # aggregated result row, not to the order rows are fed to the aggregate --
    # confirm the aggregate receives rows in ascending height.
    db.execute("""
        UPDATE claim SET
            trending_local = COALESCE((
                SELECT zscore(amount) FROM trend
                WHERE claim_hash=claim.claim_hash ORDER BY height DESC
            ), 0),
            trending_global = COALESCE((
                SELECT (amount - ?) / ? FROM trend
                WHERE claim_hash=claim.claim_hash AND height = ?
            ), 0),
            trending_group = 0,
            trending_mixed = 0
    """, (global_mean, global_deviation, start))

    # trending_group and trending_mixed determine how trending will show in query results
    # normally the SQL will be: "ORDER BY trending_group, trending_mixed"
    # changing the trending_group will have significant impact on trending results
    # changing the value used for trending_mixed will only impact trending within a trending_group
    db.execute("""
        UPDATE claim SET
            trending_group = CASE
                WHEN trending_local > 0 AND trending_global > 0 THEN 4
                WHEN trending_local <= 0 AND trending_global > 0 THEN 3
                WHEN trending_local > 0 AND trending_global <= 0 THEN 2
                WHEN trending_local <= 0 AND trending_global <= 0 THEN 1
            END,
            trending_mixed = CASE
                WHEN trending_local > 0 AND trending_global > 0 THEN trending_global
                WHEN trending_local <= 0 AND trending_global > 0 THEN trending_local
                WHEN trending_local > 0 AND trending_global <= 0 THEN trending_local
                WHEN trending_local <= 0 AND trending_global <= 0 THEN trending_global
            END
        WHERE trending_local <> 0 OR trending_global <> 0
    """)

View file

@ -1,8 +1,10 @@
import unittest
from binascii import hexlify
from torba.client.constants import COIN, NULL_HASH32
from lbrynet.schema.claim import Claim
from lbrynet.wallet.server.db import SQLDB
from lbrynet.wallet.server.trending import TRENDING_WINDOW
from lbrynet.wallet.server.block_processor import Timer
from lbrynet.wallet.transaction import Transaction, Input, Output
@ -21,6 +23,10 @@ def get_tx():
return Transaction().add_inputs([get_input()])
def claim_id(claim_hash):
    """Return the hex string of ``claim_hash`` with its bytes in reverse order."""
    reversed_hash = claim_hash[::-1]
    hex_bytes = hexlify(reversed_hash)
    return hex_bytes.decode()
class OldWalletServerTransaction:
def __init__(self, tx):
self.tx = tx
@ -110,6 +116,7 @@ class TestSQLDB(unittest.TestCase):
def advance(self, height, txs):
    """Process ``txs`` at ``height`` and return the first output of each one.

    Records ``height`` as the current height, forwards the transactions to
    ``self.sql.advance_txs`` with a fixed header, then collects
    ``otx[0].tx.outputs[0]`` for every entry of ``txs``.
    """
    self._current_height = height
    header = {'timestamp': 1}
    self.sql.advance_txs(height, txs, header, self.timer)
    first_outputs = []
    for otx in txs:
        first_outputs.append(otx[0].tx.outputs[0])
    return first_outputs
def state(self, controlling=None, active=None, accepted=None):
self.assertEqual(controlling or [], self.get_controlling())
@ -259,3 +266,25 @@ class TestSQLDB(unittest.TestCase):
active=[('Claim A', 10*COIN, 10*COIN, 13)],
accepted=[]
)
def test_trending(self):
    """Exercise trending over seven windows and check the resulting columns.

    Five streams are created, then four of them receive one support per
    trending window with amounts that fall, barely rise, or rise with a
    larger jump in the final window. Results ordered by ``trending_local``
    must match the order the claims were submitted in, and each trending
    column must hold the expected (integer-truncated) values.
    """
    flat = self.get_stream('Claim A', COIN)
    falling = self.get_stream('Claim B', COIN)
    rising_slow = self.get_stream('Claim C', COIN)
    rising_mid = self.get_stream('Claim D', COIN)
    rising_fast = self.get_stream('Claim E', COIN)
    submitted = self.advance(1, [rising_fast, rising_mid, rising_slow, flat, falling])

    final_window = 7
    for window in range(1, final_window + 1):
        # The two strongest claims get an extra jump in the final window only.
        is_final = window == final_window
        mid_step = window * (2 if is_final else 1)
        fast_step = window * (3 if is_final else 1)
        supports = [
            self.get_support(falling, (20 - window) * COIN),
            self.get_support(rising_slow, int(20 + (window / 10) * COIN)),
            self.get_support(rising_mid, (20 + mid_step) * COIN),
            self.get_support(rising_fast, (20 + fast_step) * COIN),
        ]
        self.advance(TRENDING_WINDOW * window, supports)

    rows = self.sql._search(order_by=['trending_local'])
    self.assertEqual(
        [txo.claim_id for txo in submitted],
        [claim_id(row['claim_hash']) for row in rows]
    )
    expected_by_column = (
        ('trending_local', [10, 6, 2, 0, -2]),
        ('trending_global', [53, 38, -32, 0, -6]),
        ('trending_group', [4, 4, 2, 0, 1]),
        ('trending_mixed', [53, 38, 2, 0, -6]),
    )
    for column, expected in expected_by_column:
        self.assertEqual(expected, [int(row[column]) for row in rows])