diff --git a/lbry/wallet/server/db/trending/__init__.py b/lbry/wallet/server/db/trending/__init__.py new file mode 100644 index 000000000..c9d9b616d --- /dev/null +++ b/lbry/wallet/server/db/trending/__init__.py @@ -0,0 +1,6 @@ +from . import zscore + + +TRENDING_ALGORITHMS = { + 'zscore': zscore, +} diff --git a/lbry/wallet/server/db/trending.py b/lbry/wallet/server/db/trending/zscore.py similarity index 96% rename from lbry/wallet/server/db/trending.py rename to lbry/wallet/server/db/trending/zscore.py index 0fc425a3b..bc0987d96 100644 --- a/lbry/wallet/server/db/trending.py +++ b/lbry/wallet/server/db/trending/zscore.py @@ -52,11 +52,12 @@ class ZScore: return cls(), cls.step, cls.finalize -def register_trending_functions(connection): +def install(connection): connection.createaggregatefunction("zscore", ZScore.factory, 1) + connection.cursor().execute(CREATE_TREND_TABLE) -def calculate_trending(db, height, final_height): +def run(db, height, final_height, affected_claims): # don't start tracking until we're at the end of initial sync if height < (final_height - (TRENDING_WINDOW * TRENDING_DATA_POINTS)): return diff --git a/lbry/wallet/server/db/writer.py b/lbry/wallet/server/db/writer.py index a9c8af458..b70810b07 100644 --- a/lbry/wallet/server/db/writer.py +++ b/lbry/wallet/server/db/writer.py @@ -17,9 +17,7 @@ from lbry.wallet import Ledger, RegTestLedger from lbry.wallet.transaction import Transaction, Output from lbry.wallet.server.db.canonical import register_canonical_functions from lbry.wallet.server.db.full_text_search import update_full_text_search, CREATE_FULL_TEXT_SEARCH, first_sync_finished -from lbry.wallet.server.db.trending import ( - CREATE_TREND_TABLE, calculate_trending, register_trending_functions -) +from lbry.wallet.server.db.trending import TRENDING_ALGORITHMS from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS @@ -165,14 +163,14 @@ class SQLDB: CREATE_TABLES_QUERY = ( CREATE_CLAIM_TABLE + - CREATE_TREND_TABLE + CREATE_FULL_TEXT_SEARCH + CREATE_SUPPORT_TABLE + CREATE_CLAIMTRIE_TABLE + CREATE_TAG_TABLE ) - def __init__(self, main, path: str, blocking_channels: list, filtering_channels: list): + def __init__( + self, main, path: str, blocking_channels: list, filtering_channels: list, trending: list): self.main = main self._db_path = path self.db = None @@ -190,6 +188,7 @@ class SQLDB: self.filtering_channel_hashes = { unhexlify(channel_id)[::-1] for channel_id in filtering_channels if channel_id } + self.trending = trending def open(self): self.db = apsw.Connection( @@ -208,13 +207,14 @@ class SQLDB: self.execute(self.PRAGMAS) self.execute(self.CREATE_TABLES_QUERY) register_canonical_functions(self.db) - register_trending_functions(self.db) self.state_manager = Manager() self.blocked_streams = self.state_manager.dict() self.blocked_channels = self.state_manager.dict() self.filtered_streams = self.state_manager.dict() self.filtered_channels = self.state_manager.dict() self.update_blocked_and_filtered_claims() + for algorithm in self.trending: + algorithm.install(self.db) def close(self): if self.db is not None: @@ -834,7 +834,8 @@ class SQLDB: update_claims, delete_claim_hashes, affected_channels, forward_timer=True) r(self.insert_supports, insert_supports) r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True) - r(calculate_trending, self.db.cursor(), height, daemon_height) + for algorithm in self.trending: + r(algorithm.run, self.db.cursor(), height, daemon_height, recalculate_claim_hashes) if not self._fts_synced and self.main.first_sync and height == daemon_height: r(first_sync_finished, self.db.cursor()) self._fts_synced = True @@ -845,11 +846,15 @@ class LBRYLevelDB(LevelDB): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) path = os.path.join(self.env.db_dir, 'claims.db') - # space separated list of channel URIs used for filtering bad content + trending = [] + for algorithm_name in set(self.env.default('TRENDING_ALGORITHMS', 'zscore').split(' ')): + if algorithm_name in TRENDING_ALGORITHMS: + trending.append(TRENDING_ALGORITHMS[algorithm_name]) self.sql = SQLDB( self, path, self.env.default('BLOCKING_CHANNEL_IDS', '').split(' '), self.env.default('FILTERING_CHANNEL_IDS', '').split(' '), + trending ) def close(self): diff --git a/tests/unit/wallet/server/test_sqldb.py b/tests/unit/wallet/server/test_sqldb.py index 52921514e..b46c4f11e 100644 --- a/tests/unit/wallet/server/test_sqldb.py +++ b/tests/unit/wallet/server/test_sqldb.py @@ -10,7 +10,7 @@ from lbry.schema.claim import Claim from lbry.schema.result import Censor from lbry.wallet.server.db import reader, writer from lbry.wallet.server.coin import LBCRegTest -from lbry.wallet.server.db.trending import TRENDING_WINDOW +from lbry.wallet.server.db.trending import zscore from lbry.wallet.server.db.canonical import FindShortestID from lbry.wallet.server.block_processor import Timer from lbry.wallet.transaction import Transaction, Input, Output @@ -47,7 +47,7 @@ class TestSQLDB(unittest.TestCase): self.daemon_height = 1 self.coin = LBCRegTest() db_url = 'file:test_sqldb?mode=memory&cache=shared' - self.sql = writer.SQLDB(self, db_url, [], []) + self.sql = writer.SQLDB(self, db_url, [], [], [zscore]) self.addCleanup(self.sql.close) self.sql.open() reader.initializer( @@ -533,7 +533,7 @@ class TestTrending(TestSQLDB): up_biggly = self.get_stream('Claim E', COIN) claims = advance(1, [up_biggly, up_medium, up_small, no_trend, downwards]) for window in range(1, 8): - advance(TRENDING_WINDOW * window, [ + advance(zscore.TRENDING_WINDOW * window, [ self.get_support(downwards, (20-window)*COIN), self.get_support(up_small, int(20+(window/10)*COIN)), self.get_support(up_medium, (20+(window*(2 if window == 7 else 1)))*COIN), @@ -549,8 +549,8 @@ class TestTrending(TestSQLDB): def test_edge(self): problematic = self.get_stream('Problem', COIN) self.advance(1, [problematic]) - self.advance(TRENDING_WINDOW, [self.get_support(problematic, 53000000000)]) - self.advance(TRENDING_WINDOW * 2, [self.get_support(problematic, 500000000)]) + self.advance(zscore.TRENDING_WINDOW, [self.get_support(problematic, 53000000000)]) + self.advance(zscore.TRENDING_WINDOW * 2, [self.get_support(problematic, 500000000)]) class TestContentBlocking(TestSQLDB):