diff --git a/lbry/wallet/server/db/trending/zscore.py b/lbry/wallet/server/db/trending.py similarity index 96% rename from lbry/wallet/server/db/trending/zscore.py rename to lbry/wallet/server/db/trending.py index bc0987d96..0fc425a3b 100644 --- a/lbry/wallet/server/db/trending/zscore.py +++ b/lbry/wallet/server/db/trending.py @@ -52,12 +52,11 @@ class ZScore: return cls(), cls.step, cls.finalize -def install(connection): +def register_trending_functions(connection): connection.createaggregatefunction("zscore", ZScore.factory, 1) - connection.cursor().execute(CREATE_TREND_TABLE) -def run(db, height, final_height, affected_claims): +def calculate_trending(db, height, final_height): # don't start tracking until we're at the end of initial sync if height < (final_height - (TRENDING_WINDOW * TRENDING_DATA_POINTS)): return diff --git a/lbry/wallet/server/db/trending/__init__.py b/lbry/wallet/server/db/trending/__init__.py deleted file mode 100644 index c9d9b616d..000000000 --- a/lbry/wallet/server/db/trending/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from . import zscore - - -TRENDING_ALGORITHMS = { - 'zscore': zscore, -} diff --git a/lbry/wallet/server/db/writer.py b/lbry/wallet/server/db/writer.py index b70810b07..a9c8af458 100644 --- a/lbry/wallet/server/db/writer.py +++ b/lbry/wallet/server/db/writer.py @@ -17,7 +17,9 @@ from lbry.wallet import Ledger, RegTestLedger from lbry.wallet.transaction import Transaction, Output from lbry.wallet.server.db.canonical import register_canonical_functions from lbry.wallet.server.db.full_text_search import update_full_text_search, CREATE_FULL_TEXT_SEARCH, first_sync_finished -from lbry.wallet.server.db.trending import TRENDING_ALGORITHMS +from lbry.wallet.server.db.trending import ( + CREATE_TREND_TABLE, calculate_trending, register_trending_functions +) from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS @@ -163,14 +165,14 @@ class SQLDB: CREATE_TABLES_QUERY = ( CREATE_CLAIM_TABLE + + CREATE_TREND_TABLE + CREATE_FULL_TEXT_SEARCH + CREATE_SUPPORT_TABLE + CREATE_CLAIMTRIE_TABLE + CREATE_TAG_TABLE ) - def __init__( - self, main, path: str, blocking_channels: list, filtering_channels: list, trending: list): + def __init__(self, main, path: str, blocking_channels: list, filtering_channels: list): self.main = main self._db_path = path self.db = None @@ -188,7 +190,6 @@ class SQLDB: self.filtering_channel_hashes = { unhexlify(channel_id)[::-1] for channel_id in filtering_channels if channel_id } - self.trending = trending def open(self): self.db = apsw.Connection( @@ -207,14 +208,13 @@ class SQLDB: self.execute(self.PRAGMAS) self.execute(self.CREATE_TABLES_QUERY) register_canonical_functions(self.db) + register_trending_functions(self.db) self.state_manager = Manager() self.blocked_streams = self.state_manager.dict() self.blocked_channels = self.state_manager.dict() self.filtered_streams = self.state_manager.dict() self.filtered_channels = self.state_manager.dict() self.update_blocked_and_filtered_claims() - for algorithm in self.trending: - algorithm.install(self.db) def close(self): if self.db is not None: @@ -834,8 +834,7 @@ class SQLDB: update_claims, delete_claim_hashes, affected_channels, forward_timer=True) r(self.insert_supports, insert_supports) r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True) - for algorithm in self.trending: - r(algorithm.run, self.db.cursor(), height, daemon_height, recalculate_claim_hashes) + r(calculate_trending, self.db.cursor(), height, daemon_height) if not self._fts_synced and self.main.first_sync and height == daemon_height: r(first_sync_finished, self.db.cursor()) self._fts_synced = True @@ -846,15 +845,11 @@ class LBRYLevelDB(LevelDB): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) path = os.path.join(self.env.db_dir, 'claims.db') - trending = [] - for algorithm_name in set(self.env.default('TRENDING_ALGORITHMS', 'zscore').split(' ')): - if algorithm_name in TRENDING_ALGORITHMS: - trending.append(TRENDING_ALGORITHMS[algorithm_name]) + # space separated list of channel URIs used for filtering bad content self.sql = SQLDB( self, path, self.env.default('BLOCKING_CHANNEL_IDS', '').split(' '), self.env.default('FILTERING_CHANNEL_IDS', '').split(' '), - trending ) def close(self): diff --git a/tests/unit/wallet/server/test_sqldb.py b/tests/unit/wallet/server/test_sqldb.py index b46c4f11e..52921514e 100644 --- a/tests/unit/wallet/server/test_sqldb.py +++ b/tests/unit/wallet/server/test_sqldb.py @@ -10,7 +10,7 @@ from lbry.schema.claim import Claim from lbry.schema.result import Censor from lbry.wallet.server.db import reader, writer from lbry.wallet.server.coin import LBCRegTest -from lbry.wallet.server.db.trending import zscore +from lbry.wallet.server.db.trending import TRENDING_WINDOW from lbry.wallet.server.db.canonical import FindShortestID from lbry.wallet.server.block_processor import Timer from lbry.wallet.transaction import Transaction, Input, Output @@ -47,7 +47,7 @@ class TestSQLDB(unittest.TestCase): self.daemon_height = 1 self.coin = LBCRegTest() db_url = 'file:test_sqldb?mode=memory&cache=shared' - self.sql = writer.SQLDB(self, db_url, [], [], [zscore]) + self.sql = writer.SQLDB(self, db_url, [], []) self.addCleanup(self.sql.close) self.sql.open() reader.initializer( @@ -533,7 +533,7 @@ class TestTrending(TestSQLDB): up_biggly = self.get_stream('Claim E', COIN) claims = advance(1, [up_biggly, up_medium, up_small, no_trend, downwards]) for window in range(1, 8): - advance(zscore.TRENDING_WINDOW * window, [ + advance(TRENDING_WINDOW * window, [ self.get_support(downwards, (20-window)*COIN), self.get_support(up_small, int(20+(window/10)*COIN)), self.get_support(up_medium, (20+(window*(2 if window == 7 else 1)))*COIN), @@ -549,8 +549,8 @@ class TestTrending(TestSQLDB): def test_edge(self): problematic = self.get_stream('Problem', COIN) self.advance(1, [problematic]) - self.advance(zscore.TRENDING_WINDOW, [self.get_support(problematic, 53000000000)]) - self.advance(zscore.TRENDING_WINDOW * 2, [self.get_support(problematic, 500000000)]) + self.advance(TRENDING_WINDOW, [self.get_support(problematic, 53000000000)]) + self.advance(TRENDING_WINDOW * 2, [self.get_support(problematic, 500000000)]) class TestContentBlocking(TestSQLDB):