diff --git a/lbrynet/dht/blob_announcer.py b/lbrynet/dht/blob_announcer.py index 136499752..130a06626 100644 --- a/lbrynet/dht/blob_announcer.py +++ b/lbrynet/dht/blob_announcer.py @@ -1,7 +1,6 @@ import asyncio import typing import logging -import time if typing.TYPE_CHECKING: from lbrynet.dht.node import Node from lbrynet.extras.daemon.storage import SQLiteStorage @@ -10,8 +9,7 @@ log = logging.getLogger(__name__) class BlobAnnouncer: - def __init__(self, loop: asyncio.BaseEventLoop, node: 'Node', storage: 'SQLiteStorage', - time_getter: typing.Callable[[], float] = time.time): + def __init__(self, loop: asyncio.BaseEventLoop, node: 'Node', storage: 'SQLiteStorage'): self.loop = loop self.node = node self.storage = storage @@ -19,7 +17,6 @@ class BlobAnnouncer: self.announce_task: asyncio.Task = None self.running = False self.announce_queue: typing.List[str] = [] - self.time_getter = time_getter async def _announce(self, batch_size: typing.Optional[int] = 10): if not batch_size: @@ -44,7 +41,7 @@ class BlobAnnouncer: to_await.append(batch.pop()) if to_await: await asyncio.gather(*tuple(to_await), loop=self.loop) - await self.storage.update_last_announced_blobs(announced, self.time_getter()) + await self.storage.update_last_announced_blobs(announced) log.info("announced %i blobs", len(announced)) if self.running: self.pending_call = self.loop.call_later(60, self.announce, batch_size) diff --git a/lbrynet/extras/daemon/storage.py b/lbrynet/extras/daemon/storage.py index fbac8f09e..4475cd19e 100644 --- a/lbrynet/extras/daemon/storage.py +++ b/lbrynet/extras/daemon/storage.py @@ -185,11 +185,12 @@ class SQLiteStorage(SQLiteMixin): ); """ - def __init__(self, conf: Config, path, loop=None): + def __init__(self, conf: Config, path, loop=None, time_getter: typing.Optional[typing.Callable[[], float]] = None): super().__init__(path) self.conf = conf self.content_claim_callbacks = {} self.loop = loop or asyncio.get_event_loop() + self.time_getter = time_getter or time.time async def run_and_return_one_or_none(self, query, *args): for row in await self.db.execute_fetchall(query, args): @@ -248,8 +249,9 @@ class SQLiteStorage(SQLiteMixin): "select count(*) from blob where status='finished'" ) - def update_last_announced_blobs(self, blob_hashes: typing.List[str], last_announced: float): + def update_last_announced_blobs(self, blob_hashes: typing.List[str]): def _update_last_announced_blobs(transaction: sqlite3.Connection): + last_announced = self.time_getter() return transaction.executemany( "update blob set next_announce_time=?, last_announced_time=?, single_announce=0 " "where blob_hash=?", @@ -260,7 +262,7 @@ class SQLiteStorage(SQLiteMixin): def should_single_announce_blobs(self, blob_hashes, immediate=False): def set_single_announce(transaction): - now = self.loop.time() + now = int(self.time_getter()) for blob_hash in blob_hashes: if immediate: transaction.execute( @@ -275,7 +277,7 @@ class SQLiteStorage(SQLiteMixin): def get_blobs_to_announce(self): def get_and_update(transaction): - timestamp = int(self.loop.time()) + timestamp = int(self.time_getter()) if self.conf.announce_head_and_sd_only: r = transaction.execute( "select blob_hash from blob " @@ -700,5 +702,5 @@ class SQLiteStorage(SQLiteMixin): "select s.sd_hash from stream s " "left outer join reflected_stream r on s.sd_hash=r.sd_hash " "where r.timestamp is null or r.timestamp < ?", - self.loop.time() - 86400 + int(self.time_getter()) - 86400 )