Merge pull request #2750 from lbryio/configurable_trending
configurable trending algorithms
This commit is contained in:
commit
3e0a9180bc
4 changed files with 27 additions and 15 deletions
6
lbry/wallet/server/db/trending/__init__.py
Normal file
6
lbry/wallet/server/db/trending/__init__.py
Normal file
|
@@ -0,0 +1,6 @@
|
||||||
|
from . import zscore
|
||||||
|
|
||||||
|
|
||||||
|
TRENDING_ALGORITHMS = {
|
||||||
|
'zscore': zscore,
|
||||||
|
}
|
|
@@ -52,11 +52,12 @@ class ZScore:
|
||||||
return cls(), cls.step, cls.finalize
|
return cls(), cls.step, cls.finalize
|
||||||
|
|
||||||
|
|
||||||
def register_trending_functions(connection):
|
def install(connection):
|
||||||
connection.createaggregatefunction("zscore", ZScore.factory, 1)
|
connection.createaggregatefunction("zscore", ZScore.factory, 1)
|
||||||
|
connection.cursor().execute(CREATE_TREND_TABLE)
|
||||||
|
|
||||||
|
|
||||||
def calculate_trending(db, height, final_height):
|
def run(db, height, final_height, affected_claims):
|
||||||
# don't start tracking until we're at the end of initial sync
|
# don't start tracking until we're at the end of initial sync
|
||||||
if height < (final_height - (TRENDING_WINDOW * TRENDING_DATA_POINTS)):
|
if height < (final_height - (TRENDING_WINDOW * TRENDING_DATA_POINTS)):
|
||||||
return
|
return
|
|
@@ -17,9 +17,7 @@ from lbry.wallet import Ledger, RegTestLedger
|
||||||
from lbry.wallet.transaction import Transaction, Output
|
from lbry.wallet.transaction import Transaction, Output
|
||||||
from lbry.wallet.server.db.canonical import register_canonical_functions
|
from lbry.wallet.server.db.canonical import register_canonical_functions
|
||||||
from lbry.wallet.server.db.full_text_search import update_full_text_search, CREATE_FULL_TEXT_SEARCH, first_sync_finished
|
from lbry.wallet.server.db.full_text_search import update_full_text_search, CREATE_FULL_TEXT_SEARCH, first_sync_finished
|
||||||
from lbry.wallet.server.db.trending import (
|
from lbry.wallet.server.db.trending import TRENDING_ALGORITHMS
|
||||||
CREATE_TREND_TABLE, calculate_trending, register_trending_functions
|
|
||||||
)
|
|
||||||
|
|
||||||
from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS
|
from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS
|
||||||
|
|
||||||
|
@@ -165,14 +163,14 @@ class SQLDB:
|
||||||
|
|
||||||
CREATE_TABLES_QUERY = (
|
CREATE_TABLES_QUERY = (
|
||||||
CREATE_CLAIM_TABLE +
|
CREATE_CLAIM_TABLE +
|
||||||
CREATE_TREND_TABLE +
|
|
||||||
CREATE_FULL_TEXT_SEARCH +
|
CREATE_FULL_TEXT_SEARCH +
|
||||||
CREATE_SUPPORT_TABLE +
|
CREATE_SUPPORT_TABLE +
|
||||||
CREATE_CLAIMTRIE_TABLE +
|
CREATE_CLAIMTRIE_TABLE +
|
||||||
CREATE_TAG_TABLE
|
CREATE_TAG_TABLE
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, main, path: str, blocking_channels: list, filtering_channels: list):
|
def __init__(
|
||||||
|
self, main, path: str, blocking_channels: list, filtering_channels: list, trending: list):
|
||||||
self.main = main
|
self.main = main
|
||||||
self._db_path = path
|
self._db_path = path
|
||||||
self.db = None
|
self.db = None
|
||||||
|
@@ -190,6 +188,7 @@ class SQLDB:
|
||||||
self.filtering_channel_hashes = {
|
self.filtering_channel_hashes = {
|
||||||
unhexlify(channel_id)[::-1] for channel_id in filtering_channels if channel_id
|
unhexlify(channel_id)[::-1] for channel_id in filtering_channels if channel_id
|
||||||
}
|
}
|
||||||
|
self.trending = trending
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
self.db = apsw.Connection(
|
self.db = apsw.Connection(
|
||||||
|
@@ -208,13 +207,14 @@ class SQLDB:
|
||||||
self.execute(self.PRAGMAS)
|
self.execute(self.PRAGMAS)
|
||||||
self.execute(self.CREATE_TABLES_QUERY)
|
self.execute(self.CREATE_TABLES_QUERY)
|
||||||
register_canonical_functions(self.db)
|
register_canonical_functions(self.db)
|
||||||
register_trending_functions(self.db)
|
|
||||||
self.state_manager = Manager()
|
self.state_manager = Manager()
|
||||||
self.blocked_streams = self.state_manager.dict()
|
self.blocked_streams = self.state_manager.dict()
|
||||||
self.blocked_channels = self.state_manager.dict()
|
self.blocked_channels = self.state_manager.dict()
|
||||||
self.filtered_streams = self.state_manager.dict()
|
self.filtered_streams = self.state_manager.dict()
|
||||||
self.filtered_channels = self.state_manager.dict()
|
self.filtered_channels = self.state_manager.dict()
|
||||||
self.update_blocked_and_filtered_claims()
|
self.update_blocked_and_filtered_claims()
|
||||||
|
for algorithm in self.trending:
|
||||||
|
algorithm.install(self.db)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self.db is not None:
|
if self.db is not None:
|
||||||
|
@@ -834,7 +834,8 @@ class SQLDB:
|
||||||
update_claims, delete_claim_hashes, affected_channels, forward_timer=True)
|
update_claims, delete_claim_hashes, affected_channels, forward_timer=True)
|
||||||
r(self.insert_supports, insert_supports)
|
r(self.insert_supports, insert_supports)
|
||||||
r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True)
|
r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True)
|
||||||
r(calculate_trending, self.db.cursor(), height, daemon_height)
|
for algorithm in self.trending:
|
||||||
|
r(algorithm.run, self.db.cursor(), height, daemon_height, recalculate_claim_hashes)
|
||||||
if not self._fts_synced and self.main.first_sync and height == daemon_height:
|
if not self._fts_synced and self.main.first_sync and height == daemon_height:
|
||||||
r(first_sync_finished, self.db.cursor())
|
r(first_sync_finished, self.db.cursor())
|
||||||
self._fts_synced = True
|
self._fts_synced = True
|
||||||
|
@@ -845,11 +846,15 @@ class LBRYLevelDB(LevelDB):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
path = os.path.join(self.env.db_dir, 'claims.db')
|
path = os.path.join(self.env.db_dir, 'claims.db')
|
||||||
# space separated list of channel URIs used for filtering bad content
|
trending = []
|
||||||
|
for algorithm_name in set(self.env.default('TRENDING_ALGORITHMS', 'zscore').split(' ')):
|
||||||
|
if algorithm_name in TRENDING_ALGORITHMS:
|
||||||
|
trending.append(TRENDING_ALGORITHMS[algorithm_name])
|
||||||
self.sql = SQLDB(
|
self.sql = SQLDB(
|
||||||
self, path,
|
self, path,
|
||||||
self.env.default('BLOCKING_CHANNEL_IDS', '').split(' '),
|
self.env.default('BLOCKING_CHANNEL_IDS', '').split(' '),
|
||||||
self.env.default('FILTERING_CHANNEL_IDS', '').split(' '),
|
self.env.default('FILTERING_CHANNEL_IDS', '').split(' '),
|
||||||
|
trending
|
||||||
)
|
)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|
|
@@ -10,7 +10,7 @@ from lbry.schema.claim import Claim
|
||||||
from lbry.schema.result import Censor
|
from lbry.schema.result import Censor
|
||||||
from lbry.wallet.server.db import reader, writer
|
from lbry.wallet.server.db import reader, writer
|
||||||
from lbry.wallet.server.coin import LBCRegTest
|
from lbry.wallet.server.coin import LBCRegTest
|
||||||
from lbry.wallet.server.db.trending import TRENDING_WINDOW
|
from lbry.wallet.server.db.trending import zscore
|
||||||
from lbry.wallet.server.db.canonical import FindShortestID
|
from lbry.wallet.server.db.canonical import FindShortestID
|
||||||
from lbry.wallet.server.block_processor import Timer
|
from lbry.wallet.server.block_processor import Timer
|
||||||
from lbry.wallet.transaction import Transaction, Input, Output
|
from lbry.wallet.transaction import Transaction, Input, Output
|
||||||
|
@@ -47,7 +47,7 @@ class TestSQLDB(unittest.TestCase):
|
||||||
self.daemon_height = 1
|
self.daemon_height = 1
|
||||||
self.coin = LBCRegTest()
|
self.coin = LBCRegTest()
|
||||||
db_url = 'file:test_sqldb?mode=memory&cache=shared'
|
db_url = 'file:test_sqldb?mode=memory&cache=shared'
|
||||||
self.sql = writer.SQLDB(self, db_url, [], [])
|
self.sql = writer.SQLDB(self, db_url, [], [], [zscore])
|
||||||
self.addCleanup(self.sql.close)
|
self.addCleanup(self.sql.close)
|
||||||
self.sql.open()
|
self.sql.open()
|
||||||
reader.initializer(
|
reader.initializer(
|
||||||
|
@@ -533,7 +533,7 @@ class TestTrending(TestSQLDB):
|
||||||
up_biggly = self.get_stream('Claim E', COIN)
|
up_biggly = self.get_stream('Claim E', COIN)
|
||||||
claims = advance(1, [up_biggly, up_medium, up_small, no_trend, downwards])
|
claims = advance(1, [up_biggly, up_medium, up_small, no_trend, downwards])
|
||||||
for window in range(1, 8):
|
for window in range(1, 8):
|
||||||
advance(TRENDING_WINDOW * window, [
|
advance(zscore.TRENDING_WINDOW * window, [
|
||||||
self.get_support(downwards, (20-window)*COIN),
|
self.get_support(downwards, (20-window)*COIN),
|
||||||
self.get_support(up_small, int(20+(window/10)*COIN)),
|
self.get_support(up_small, int(20+(window/10)*COIN)),
|
||||||
self.get_support(up_medium, (20+(window*(2 if window == 7 else 1)))*COIN),
|
self.get_support(up_medium, (20+(window*(2 if window == 7 else 1)))*COIN),
|
||||||
|
@@ -549,8 +549,8 @@ class TestTrending(TestSQLDB):
|
||||||
def test_edge(self):
|
def test_edge(self):
|
||||||
problematic = self.get_stream('Problem', COIN)
|
problematic = self.get_stream('Problem', COIN)
|
||||||
self.advance(1, [problematic])
|
self.advance(1, [problematic])
|
||||||
self.advance(TRENDING_WINDOW, [self.get_support(problematic, 53000000000)])
|
self.advance(zscore.TRENDING_WINDOW, [self.get_support(problematic, 53000000000)])
|
||||||
self.advance(TRENDING_WINDOW * 2, [self.get_support(problematic, 500000000)])
|
self.advance(zscore.TRENDING_WINDOW * 2, [self.get_support(problematic, 500000000)])
|
||||||
|
|
||||||
|
|
||||||
class TestContentBlocking(TestSQLDB):
|
class TestContentBlocking(TestSQLDB):
|
||||||
|
|
Loading…
Reference in a new issue