diff --git a/lbrynet/blob_exchange/client.py b/lbrynet/blob_exchange/client.py
index 65af4f4cc..1b4c46f01 100644
--- a/lbrynet/blob_exchange/client.py
+++ b/lbrynet/blob_exchange/client.py
@@ -24,9 +24,12 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
         self.blob: typing.Optional['AbstractBlob'] = None
         self._blob_bytes_received = 0
-        self._response_fut: asyncio.Future = None
+        self._response_fut: typing.Optional[asyncio.Future] = None
         self.buf = b''

+        # this is here to handle the race when the downloader is closed right as response_fut gets a result
+        self.closed = asyncio.Event(loop=self.loop)
+
     def data_received(self, data: bytes):
         log.debug("%s:%d -- got %s bytes -- %s bytes on buffer -- %s blob bytes received", self.peer_address,
                   self.peer_port, len(data), len(self.buf), self._blob_bytes_received)
@@ -90,6 +93,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
         :return: download success (bool), keep connection (bool)
         """
         request = BlobRequest.make_request_for_blob_hash(self.blob.blob_hash)
+        blob_hash = self.blob.blob_hash
         try:
             msg = request.serialize()
             log.debug("send request to %s:%i -> %s", self.peer_address, self.peer_port, msg.decode())
@@ -98,6 +102,10 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
             availability_response = response.get_availability_response()
             price_response = response.get_price_response()
             blob_response = response.get_blob_response()
+            if self.closed.is_set():
+                msg = f"cancelled blob request for {blob_hash} immediately after we got a response"
+                log.warning(msg)
+                raise asyncio.CancelledError(msg)
             if (not blob_response or blob_response.error) and\
                     (not availability_response or not availability_response.available_blobs):
                 log.warning("%s not in availability response from %s:%i", self.blob.blob_hash, self.peer_address,
@@ -136,6 +144,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
         return self._blob_bytes_received, self.close()

     def close(self):
+        self.closed.set()
         if self._response_fut and not self._response_fut.done():
             self._response_fut.cancel()
         if self.writer and not self.writer.closed():
@@ -149,6 +158,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
         self.buf = b''

     async def download_blob(self, blob: 'AbstractBlob') -> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
+        self.closed.clear()
         blob_hash = blob.blob_hash
         if blob.get_is_verified() or not blob.is_writeable():
             return 0, self.transport
diff --git a/lbrynet/extras/daemon/Components.py b/lbrynet/extras/daemon/Components.py
index afd32b078..359ed97f9 100644
--- a/lbrynet/extras/daemon/Components.py
+++ b/lbrynet/extras/daemon/Components.py
@@ -63,7 +63,7 @@ class DatabaseComponent(Component):

     @staticmethod
     def get_current_db_revision():
-        return 10
+        return 11

     @property
     def revision_filename(self):
diff --git a/lbrynet/extras/daemon/migrator/dbmigrator.py b/lbrynet/extras/daemon/migrator/dbmigrator.py
index 47fa080f9..76d2264df 100644
--- a/lbrynet/extras/daemon/migrator/dbmigrator.py
+++ b/lbrynet/extras/daemon/migrator/dbmigrator.py
@@ -24,6 +24,8 @@ def migrate_db(conf, start, end):
             from .migrate8to9 import do_migration
         elif current == 9:
             from .migrate9to10 import do_migration
+        elif current == 10:
+            from .migrate10to11 import do_migration
         else:
             raise Exception("DB migration of version {} to {} is not available".format(current, current+1))
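The `closed` event added to `BlobExchangeClientProtocol` covers a narrow race: `close()` can run in the gap between `_response_fut` receiving its result and the awaiting coroutine resuming, so the request coroutine re-checks the flag right after the await and bails out with `CancelledError` instead of processing a response for a torn-down connection. A runnable sketch of that pattern in isolation; every name here is illustrative, not a lbrynet API:

    import asyncio


    class Downloader:
        def __init__(self):
            self.loop = asyncio.get_event_loop()
            self.closed = asyncio.Event()
            self._response_fut = None

        def close(self):
            # can run in the same loop iteration that fulfills _response_fut
            self.closed.set()
            if self._response_fut and not self._response_fut.done():
                self._response_fut.cancel()

        async def request(self) -> str:
            self._response_fut = self.loop.create_future()
            self.loop.call_soon(self._response_fut.set_result, "response")
            self.loop.call_soon(self.close)  # lands after the result is already set
            result = await self._response_fut
            if self.closed.is_set():
                # without this check we would keep handling a response for a
                # connection the caller has already given up on
                raise asyncio.CancelledError("closed right after response arrived")
            return result


    async def main():
        try:
            await Downloader().request()
        except asyncio.CancelledError as err:
            print("cancelled:", err)

    asyncio.run(main())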
diff --git a/lbrynet/extras/daemon/migrator/migrate10to11.py b/lbrynet/extras/daemon/migrator/migrate10to11.py
new file mode 100644
index 000000000..9974c785c
--- /dev/null
+++ b/lbrynet/extras/daemon/migrator/migrate10to11.py
@@ -0,0 +1,54 @@
+import sqlite3
+import os
+import binascii
+
+
+def do_migration(conf):
+    db_path = os.path.join(conf.data_dir, "lbrynet.sqlite")
+    connection = sqlite3.connect(db_path)
+    cursor = connection.cursor()
+
+    current_columns = []
+    for col_info in cursor.execute("pragma table_info('file');").fetchall():
+        current_columns.append(col_info[1])
+    if 'content_fee' in current_columns or 'saved_file' in current_columns:
+        connection.close()
+        print("already migrated")
+        return
+
+    cursor.execute(
+        "pragma foreign_keys=off;"
+    )
+
+    cursor.execute("""
+        create table if not exists new_file (
+            stream_hash text primary key not null references stream,
+            file_name text,
+            download_directory text,
+            blob_data_rate real not null,
+            status text not null,
+            saved_file integer not null,
+            content_fee text
+        );
+    """)
+    for (stream_hash, file_name, download_dir, data_rate, status) in cursor.execute("select * from file").fetchall():
+        saved_file = 0
+        if download_dir != '{stream}' and file_name != '{stream}':
+            try:
+                if os.path.isfile(os.path.join(binascii.unhexlify(download_dir).decode(),
+                                               binascii.unhexlify(file_name).decode())):
+                    saved_file = 1
+                else:
+                    download_dir, file_name = None, None
+            except Exception:
+                download_dir, file_name = None, None
+        else:
+            download_dir, file_name = None, None
+        cursor.execute(
+            "insert into new_file values (?, ?, ?, ?, ?, ?, NULL)",
+            (stream_hash, file_name, download_dir, data_rate, status, saved_file)
+        )
+    cursor.execute("drop table file")
+    cursor.execute("alter table new_file rename to file")
+    connection.commit()
+    connection.close()
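Since sqlite's `ALTER TABLE` cannot relax an existing column's `not null` constraint, the migration uses the standard rebuild recipe: create `new_file` with the desired schema, copy rows across (rewriting the old `'{stream}'` sentinel to a real `NULL` and computing `saved_file` on the way), then drop and rename. The recipe in miniature, against a toy schema rather than the real one:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        create table file (file_name text not null, status text not null);
        insert into file values ('{stream}', 'running');
    """)
    cursor = conn.cursor()
    cursor.execute("pragma foreign_keys=off;")
    cursor.execute(
        "create table new_file (file_name text, status text not null, saved_file integer not null)"
    )
    for (file_name, status) in cursor.execute("select * from file").fetchall():
        # rewrite rows while copying: the old sentinel becomes a real NULL
        file_name = None if file_name == '{stream}' else file_name
        cursor.execute("insert into new_file values (?, ?, ?)", (file_name, status, 0))
    cursor.execute("drop table file")
    cursor.execute("alter table new_file rename to file")
    conn.commit()
    print(conn.execute("select * from file").fetchall())  # [(None, 'running', 0)]
    conn.close()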
diff --git a/lbrynet/extras/daemon/storage.py b/lbrynet/extras/daemon/storage.py
index 28ce23acf..9a6e2035c 100644
--- a/lbrynet/extras/daemon/storage.py
+++ b/lbrynet/extras/daemon/storage.py
@@ -8,6 +8,7 @@ import time
 from torba.client.basedatabase import SQLiteMixin
 from lbrynet.conf import Config
 from lbrynet.wallet.dewies import dewies_to_lbc, lbc_to_dewies
+from lbrynet.wallet.transaction import Transaction
 from lbrynet.schema.claim import Claim
 from lbrynet.dht.constants import data_expiration
 from lbrynet.blob.blob_info import BlobInfo
@@ -114,8 +115,8 @@ def _batched_select(transaction, query, parameters, batch_size=900):
 def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Dict]:
     files = []
     signed_claims = {}
-    for (rowid, stream_hash, file_name, download_dir, data_rate, status, _, sd_hash, stream_key,
-         stream_name, suggested_file_name, *claim_args) in _batched_select(
+    for (rowid, stream_hash, file_name, download_dir, data_rate, status, saved_file, raw_content_fee, _,
+         sd_hash, stream_key, stream_name, suggested_file_name, *claim_args) in _batched_select(
             transaction, "select file.rowid, file.*, stream.*, c.* "
                          "from file inner join stream on file.stream_hash=stream.stream_hash "
                          "inner join content_claim cc on file.stream_hash=cc.stream_hash "
@@ -141,7 +142,11 @@ def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Di
                 "key": stream_key,
                 "stream_name": stream_name,  # hex
                 "suggested_file_name": suggested_file_name,  # hex
-                "claim": claim
+                "claim": claim,
+                "saved_file": bool(saved_file),
+                "content_fee": None if not raw_content_fee else Transaction(
+                    binascii.unhexlify(raw_content_fee)
+                )
             }
         )
     for claim_name, claim_id in _batched_select(
@@ -188,16 +193,20 @@ def delete_stream(transaction: sqlite3.Connection, descriptor: 'StreamDescriptor

 def store_file(transaction: sqlite3.Connection, stream_hash: str, file_name: typing.Optional[str],
-               download_directory: typing.Optional[str], data_payment_rate: float, status: str) -> int:
+               download_directory: typing.Optional[str], data_payment_rate: float, status: str,
+               content_fee: typing.Optional[Transaction]) -> int:
     if not file_name and not download_directory:
-        encoded_file_name, encoded_download_dir = "{stream}", "{stream}"
+        encoded_file_name, encoded_download_dir = None, None
     else:
         encoded_file_name = binascii.hexlify(file_name.encode()).decode()
         encoded_download_dir = binascii.hexlify(download_directory.encode()).decode()
     transaction.execute(
-        "insert or replace into file values (?, ?, ?, ?, ?)",
-        (stream_hash, encoded_file_name, encoded_download_dir, data_payment_rate, status)
+        "insert or replace into file values (?, ?, ?, ?, ?, ?, ?)",
+        (stream_hash, encoded_file_name, encoded_download_dir, data_payment_rate, status,
+         1 if (file_name and download_directory and os.path.isfile(os.path.join(download_directory, file_name))) else 0,
+         None if not content_fee else binascii.hexlify(content_fee.raw).decode())
     )
+
     return transaction.execute("select rowid from file where stream_hash=?", (stream_hash, )).fetchone()[0]
@@ -246,10 +255,12 @@ class SQLiteStorage(SQLiteMixin):

             create table if not exists file (
                 stream_hash text primary key not null references stream,
-                file_name text not null,
-                download_directory text not null,
+                file_name text,
+                download_directory text,
                 blob_data_rate real not null,
-                status text not null
+                status text not null,
+                saved_file integer not null,
+                content_fee text
             );

             create table if not exists content_claim (
@@ -430,7 +441,7 @@ class SQLiteStorage(SQLiteMixin):
     def set_files_as_streaming(self, stream_hashes: typing.List[str]):
         def _set_streaming(transaction: sqlite3.Connection):
             transaction.executemany(
-                "update file set file_name='{stream}', download_directory='{stream}' where stream_hash=?",
+                "update file set file_name=null, download_directory=null where stream_hash=?",
                 [(stream_hash, ) for stream_hash in stream_hashes]
             )
@@ -509,16 +520,42 @@ class SQLiteStorage(SQLiteMixin):

     # # # # # # # # # file stuff # # # # # # # # #

-    def save_downloaded_file(self, stream_hash, file_name, download_directory,
-                             data_payment_rate) -> typing.Awaitable[int]:
+    def save_downloaded_file(self, stream_hash: str, file_name: typing.Optional[str],
+                             download_directory: typing.Optional[str], data_payment_rate: float,
+                             content_fee: typing.Optional[Transaction] = None) -> typing.Awaitable[int]:
         return self.save_published_file(
-            stream_hash, file_name, download_directory, data_payment_rate, status="running"
+            stream_hash, file_name, download_directory, data_payment_rate, status="running",
+            content_fee=content_fee
         )

     def save_published_file(self, stream_hash: str, file_name: typing.Optional[str],
                             download_directory: typing.Optional[str], data_payment_rate: float,
-                            status="finished") -> typing.Awaitable[int]:
-        return self.db.run(store_file, stream_hash, file_name, download_directory, data_payment_rate, status)
+                            status: str = "finished",
+                            content_fee: typing.Optional[Transaction] = None) -> typing.Awaitable[int]:
+        return self.db.run(store_file, stream_hash, file_name, download_directory, data_payment_rate, status,
+                           content_fee)
+
+    async def update_manually_removed_files_since_last_run(self):
+        """
+        Update files that have been removed from the downloads directory since the last run
+        """
+        def update_manually_removed_files(transaction: sqlite3.Connection):
+            removed = []
+            for (stream_hash, download_directory, file_name) in transaction.execute(
+                    "select stream_hash, download_directory, file_name from file where saved_file=1"
+            ).fetchall():
+                if download_directory and file_name and os.path.isfile(
+                        os.path.join(binascii.unhexlify(download_directory.encode()).decode(),
+                                     binascii.unhexlify(file_name.encode()).decode())):
+                    continue
+                else:
+                    removed.append((stream_hash,))
+            if removed:
+                transaction.executemany(
+                    "update file set file_name=null, download_directory=null, saved_file=0 where stream_hash=?",
+                    removed
+                )
+        return await self.db.run(update_manually_removed_files)

     def get_all_lbry_files(self) -> typing.Awaitable[typing.List[typing.Dict]]:
         return self.db.run(get_all_lbry_files)
@@ -530,7 +567,7 @@ class SQLiteStorage(SQLiteMixin):
     async def change_file_download_dir_and_file_name(self, stream_hash: str, download_dir: typing.Optional[str],
                                                      file_name: typing.Optional[str]):
         if not file_name or not download_dir:
-            encoded_file_name, encoded_download_dir = "{stream}", "{stream}"
+            encoded_file_name, encoded_download_dir = None, None
         else:
             encoded_file_name = binascii.hexlify(file_name.encode()).decode()
             encoded_download_dir = binascii.hexlify(download_dir.encode()).decode()
@@ -538,18 +575,34 @@ class SQLiteStorage(SQLiteMixin):
             encoded_download_dir, encoded_file_name, stream_hash,
         ))

-    async def recover_streams(self, descriptors_and_sds: typing.List[typing.Tuple['StreamDescriptor', 'BlobFile']],
+    async def save_content_fee(self, stream_hash: str, content_fee: Transaction):
+        return await self.db.execute("update file set content_fee=? where stream_hash=?", (
+            binascii.hexlify(content_fee.raw), stream_hash,
+        ))
+
+    async def set_saved_file(self, stream_hash: str):
+        return await self.db.execute("update file set saved_file=1 where stream_hash=?", (
+            stream_hash,
+        ))
+
+    async def clear_saved_file(self, stream_hash: str):
+        return await self.db.execute("update file set saved_file=0 where stream_hash=?", (
+            stream_hash,
+        ))
+
+    async def recover_streams(self, descriptors_and_sds: typing.List[typing.Tuple['StreamDescriptor', 'BlobFile',
+                                                                                  typing.Optional[Transaction]]],
                               download_directory: str):
         def _recover(transaction: sqlite3.Connection):
-            stream_hashes = [d.stream_hash for d, s in descriptors_and_sds]
-            for descriptor, sd_blob in descriptors_and_sds:
+            stream_hashes = [x[0].stream_hash for x in descriptors_and_sds]
+            for descriptor, sd_blob, content_fee in descriptors_and_sds:
                 content_claim = transaction.execute(
                     "select * from content_claim where stream_hash=?", (descriptor.stream_hash, )
                 ).fetchone()
                 delete_stream(transaction, descriptor)  # this will also delete the content claim
                 store_stream(transaction, sd_blob, descriptor)
                 store_file(transaction, descriptor.stream_hash, os.path.basename(descriptor.suggested_file_name),
-                           download_directory, 0.0, 'stopped')
+                           download_directory, 0.0, 'stopped', content_fee=content_fee)
                 if content_claim:
                     transaction.execute("insert or ignore into content_claim values (?, ?)", content_claim)
             transaction.executemany(
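A convention worth calling out in the storage changes: file names, download directories, and now the raw `content_fee` transaction all live in TEXT columns as hex. `store_file`/`save_content_fee` hexlify on the way in and `get_all_lbry_files` unhexlifies on the way out, which keeps arbitrary (possibly non-UTF-8) bytes safe inside sqlite. The round-trip in isolation, with a stand-in for `Transaction.raw`:

    import binascii

    raw_tx = b"\x01\x00\x00\x00 raw transaction bytes"  # stand-in for Transaction.raw
    download_dir, file_name = "/home/user/Downloads", "vidéo.mp4"

    # what store_file / save_content_fee write into the TEXT columns
    stored_fee = binascii.hexlify(raw_tx).decode()
    stored_dir = binascii.hexlify(download_dir.encode()).decode()
    stored_name = binascii.hexlify(file_name.encode()).decode()

    # what get_all_lbry_files / path_or_none recover when reading back
    assert binascii.unhexlify(stored_fee) == raw_tx
    assert binascii.unhexlify(stored_dir.encode()).decode() == download_dir
    assert binascii.unhexlify(stored_name.encode()).decode() == file_name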
diff --git a/lbrynet/schema/claim.py b/lbrynet/schema/claim.py
index c29699d67..d9732b408 100644
--- a/lbrynet/schema/claim.py
+++ b/lbrynet/schema/claim.py
@@ -215,8 +215,6 @@ class Stream(BaseClaim):

         if 'sd_hash' in kwargs:
             self.source.sd_hash = kwargs.pop('sd_hash')
-        if 'file_size' in kwargs:
-            self.source.size = kwargs.pop('file_size')
         if 'file_name' in kwargs:
             self.source.name = kwargs.pop('file_name')
         if 'file_hash' in kwargs:
@@ -230,6 +228,9 @@ class Stream(BaseClaim):
         elif self.source.media_type:
             stream_type = guess_stream_type(self.source.media_type)

+        if 'file_size' in kwargs:
+            self.source.size = kwargs.pop('file_size')
+
         if stream_type in ('image', 'video', 'audio'):
             media = getattr(self, stream_type)
             media_args = {'file_metadata': None}
diff --git a/lbrynet/stream/managed_stream.py b/lbrynet/stream/managed_stream.py
index 511b42738..42b8f29ce 100644
--- a/lbrynet/stream/managed_stream.py
+++ b/lbrynet/stream/managed_stream.py
@@ -198,21 +198,31 @@ class ManagedStream:
         return guess_media_type(os.path.basename(self.descriptor.suggested_file_name))[0]

     def as_dict(self) -> typing.Dict:
-        if not self.written_bytes and self.output_file_exists:
-            written_bytes = os.stat(self.full_path).st_size
+        full_path = self.full_path
+        file_name = self.file_name
+        download_directory = self.download_directory
+        if self.full_path and self.output_file_exists:
+            if self.written_bytes:
+                written_bytes = self.written_bytes
+            else:
+                written_bytes = os.stat(self.full_path).st_size
         else:
-            written_bytes = self.written_bytes
+            full_path = None
+            file_name = None
+            download_directory = None
+            written_bytes = None
         return {
-            'completed': self.output_file_exists and self.status in ('stopped', 'finished'),
-            'file_name': self.file_name,
-            'download_directory': self.download_directory,
+            'completed': (self.output_file_exists and self.status in ('stopped', 'finished')) or all(
+                self.blob_manager.is_blob_verified(b.blob_hash) for b in self.descriptor.blobs[:-1]),
+            'file_name': file_name,
+            'download_directory': download_directory,
             'points_paid': 0.0,
             'stopped': not self.running,
             'stream_hash': self.stream_hash,
             'stream_name': self.descriptor.stream_name,
             'suggested_file_name': self.descriptor.suggested_file_name,
             'sd_hash': self.descriptor.sd_hash,
-            'download_path': self.full_path,
+            'download_path': full_path,
             'mime_type': self.mime_type,
             'key': self.descriptor.key,
             'total_bytes_lower_bound': self.descriptor.lower_bound_decrypted_length(),
@@ -231,7 +241,7 @@ class ManagedStream:
             'channel_claim_id': self.channel_claim_id,
             'channel_name': self.channel_name,
             'claim_name': self.claim_name,
-            'content_fee': self.content_fee  # TODO: this isn't in the database
+            'content_fee': self.content_fee
         }

     @classmethod
@@ -328,6 +338,11 @@ class ManagedStream:
         if not self.streaming_responses:
             self.streaming.clear()

+    @staticmethod
+    def _write_decrypted_blob(handle: typing.IO, data: bytes):
+        handle.write(data)
+        handle.flush()
+
     async def _save_file(self, output_path: str):
         log.info("save file for lbry://%s#%s (sd hash %s...) -> %s", self.claim_name, self.claim_id,
                  self.sd_hash[:6], output_path)
@@ -338,8 +353,7 @@ class ManagedStream:
             with open(output_path, 'wb') as file_write_handle:
                 async for blob_info, decrypted in self._aiter_read_stream(connection_id=1):
                     log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
-                    file_write_handle.write(decrypted)
-                    file_write_handle.flush()
+                    await self.loop.run_in_executor(None, self._write_decrypted_blob, file_write_handle, decrypted)
                     self.written_bytes += len(decrypted)
                     if not self.started_writing.is_set():
                         self.started_writing.set()
@@ -351,6 +365,7 @@ class ManagedStream:
             self.finished_writing.set()
             log.info("finished saving file for lbry://%s#%s (sd hash %s...) -> %s", self.claim_name, self.claim_id,
                      self.sd_hash[:6], self.full_path)
+            await self.blob_manager.storage.set_saved_file(self.stream_hash)
         except Exception as err:
             if os.path.isfile(output_path):
                 log.warning("removing incomplete download %s for %s", output_path, self.sd_hash)
@@ -376,7 +391,7 @@ class ManagedStream:
                 os.mkdir(self.download_directory)
             self._file_name = await get_next_available_file_name(
                 self.loop, self.download_directory,
-                file_name or self._file_name or self.descriptor.suggested_file_name
+                file_name or self.descriptor.suggested_file_name
             )
             await self.blob_manager.storage.change_file_download_dir_and_file_name(
                 self.stream_hash, self.download_directory, self.file_name
@@ -461,8 +476,18 @@ class ManagedStream:
             get_range = get_range.split('=')[1]
         start, end = get_range.split('-')
         size = 0
+
         for blob in self.descriptor.blobs[:-1]:
             size += blob.length - 1
+        if self.stream_claim_info and self.stream_claim_info.claim.stream.source.size:
+            size_from_claim = int(self.stream_claim_info.claim.stream.source.size)
+            if not size_from_claim <= size <= size_from_claim + 16:
+                raise ValueError("claim contains implausible stream size")
+            log.debug("using stream size from claim")
+            size = size_from_claim
+        elif self.stream_claim_info:
+            log.debug("estimating stream size")
+
         start = int(start)
         end = int(end) if end else size - 1
         skip_blobs = start // 2097150
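The new range-request logic will serve a stream size taken from the claim, but only if it is plausible against the blobs we actually have: summing `blob.length - 1` over the data blobs over-counts the plaintext by at most one 16-byte AES padding block, so the claimed size must satisfy `claimed <= estimate <= claimed + 16`. A worked sketch of just that bounds check (the constant is assumed here for illustration):

    MAX_BLOB_SIZE = 2097152  # 2 MiB; value assumed for the sketch


    def size_for_range_response(blob_lengths, size_from_claim):
        # estimate from the encrypted blobs: each carries length - 1 payload bytes
        estimate = sum(length - 1 for length in blob_lengths)
        if not size_from_claim <= estimate <= size_from_claim + 16:
            raise ValueError("claim contains implausible stream size")
        return size_from_claim


    # one full blob plus a 1000-byte tail blob, whose claim declares 14 bytes
    # fewer than the blob estimate (i.e. 14 bytes of padding, no zero-fill)
    claimed = (MAX_BLOB_SIZE - 1) + 999 - 14
    print(size_for_range_response([MAX_BLOB_SIZE, 1000], claimed))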
diff --git a/lbrynet/stream/stream_manager.py b/lbrynet/stream/stream_manager.py
index aa2f93610..b64d452dd 100644
--- a/lbrynet/stream/stream_manager.py
+++ b/lbrynet/stream/stream_manager.py
@@ -21,6 +21,7 @@ if typing.TYPE_CHECKING:
     from lbrynet.extras.daemon.analytics import AnalyticsManager
     from lbrynet.extras.daemon.storage import SQLiteStorage, StoredStreamClaim
     from lbrynet.wallet import LbryWalletManager
+    from lbrynet.wallet.transaction import Transaction
     from lbrynet.extras.daemon.exchange_rate_manager import ExchangeRateManager

 log = logging.getLogger(__name__)
@@ -55,7 +56,9 @@ comparison_operators = {


 def path_or_none(p) -> typing.Optional[str]:
-    return None if p == '{stream}' else binascii.unhexlify(p).decode()
+    if not p:
+        return
+    return binascii.unhexlify(p).decode()


 class StreamManager:
@@ -70,8 +73,8 @@ class StreamManager:
         self.node = node
         self.analytics_manager = analytics_manager
         self.streams: typing.Dict[str, ManagedStream] = {}
-        self.resume_downloading_task: asyncio.Task = None
-        self.re_reflect_task: asyncio.Task = None
+        self.resume_saving_task: typing.Optional[asyncio.Task] = None
+        self.re_reflect_task: typing.Optional[asyncio.Task] = None
         self.update_stream_finished_futs: typing.List[asyncio.Future] = []
         self.running_reflector_uploads: typing.List[asyncio.Task] = []
         self.started = asyncio.Event(loop=self.loop)
@@ -84,7 +87,8 @@ class StreamManager:
         to_restore = []

         async def recover_stream(sd_hash: str, stream_hash: str, stream_name: str,
-                                 suggested_file_name: str, key: str) -> typing.Optional[StreamDescriptor]:
+                                 suggested_file_name: str, key: str,
+                                 content_fee: typing.Optional['Transaction']) -> typing.Optional[StreamDescriptor]:
             sd_blob = self.blob_manager.get_blob(sd_hash)
             blobs = await self.storage.get_blobs_for_stream(stream_hash)
             descriptor = await StreamDescriptor.recover(
@@ -92,12 +96,13 @@ class StreamManager:
             )
             if not descriptor:
                 return
-            to_restore.append((descriptor, sd_blob))
+            to_restore.append((descriptor, sd_blob, content_fee))

         await asyncio.gather(*[
             recover_stream(
                 file_info['sd_hash'], file_info['stream_hash'], binascii.unhexlify(file_info['stream_name']).decode(),
-                binascii.unhexlify(file_info['suggested_file_name']).decode(), file_info['key']
+                binascii.unhexlify(file_info['suggested_file_name']).decode(), file_info['key'],
+                file_info['content_fee']
             ) for file_info in file_infos
         ])
@@ -109,7 +114,7 @@ class StreamManager:

     async def add_stream(self, rowid: int, sd_hash: str, file_name: typing.Optional[str],
                          download_directory: typing.Optional[str], status: str,
-                         claim: typing.Optional['StoredStreamClaim']):
+                         claim: typing.Optional['StoredStreamClaim'], content_fee: typing.Optional['Transaction']):
         try:
             descriptor = await self.blob_manager.get_stream_descriptor(sd_hash)
         except InvalidStreamDescriptorError as err:
@@ -117,16 +122,18 @@ class StreamManager:
             return
         stream = ManagedStream(
             self.loop, self.config, self.blob_manager, descriptor.sd_hash, download_directory, file_name, status,
-            claim, rowid=rowid, descriptor=descriptor, analytics_manager=self.analytics_manager
+            claim, content_fee=content_fee, rowid=rowid, descriptor=descriptor,
+            analytics_manager=self.analytics_manager
         )
         self.streams[sd_hash] = stream
+        self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)

-    async def load_streams_from_database(self):
+    async def load_and_resume_streams_from_database(self):
         to_recover = []
         to_start = []
-        # this will set streams marked as finished and are missing blobs as being stopped
-        # await self.storage.sync_files_to_blobs()
+
+        await self.storage.update_manually_removed_files_since_last_run()
+
         for file_info in await self.storage.get_all_lbry_files():
             # if the sd blob is not verified, try to reconstruct it from the database
             # this could either be because the blob files were deleted manually or save_blobs was not true when
@@ -138,29 +145,33 @@ class StreamManager:
             await self.recover_streams(to_recover)

         log.info("Initializing %i files", len(to_start))
-        if to_start:
-            await asyncio.gather(*[
-                self.add_stream(
-                    file_info['rowid'], file_info['sd_hash'], path_or_none(file_info['file_name']),
-                    path_or_none(file_info['download_directory']), file_info['status'],
-                    file_info['claim']
-                ) for file_info in to_start
-            ])
+        to_resume_saving = []
+        add_stream_tasks = []
+        for file_info in to_start:
+            file_name = path_or_none(file_info['file_name'])
+            download_directory = path_or_none(file_info['download_directory'])
+            if file_name and download_directory and not file_info['saved_file'] and file_info['status'] == 'running':
+                to_resume_saving.append((file_name, download_directory, file_info['sd_hash']))
+            add_stream_tasks.append(self.loop.create_task(self.add_stream(
+                file_info['rowid'], file_info['sd_hash'], file_name,
+                download_directory, file_info['status'],
+                file_info['claim'], file_info['content_fee']
+            )))
+        if add_stream_tasks:
+            await asyncio.gather(*add_stream_tasks, loop=self.loop)
         log.info("Started stream manager with %i files", len(self.streams))
-
-    async def resume(self):
         if not self.node:
             log.warning("no DHT node given, resuming downloads trusting that we can contact reflector")
-        t = [
-            self.loop.create_task(
-                stream.start(node=self.node, save_now=(stream.full_path is not None))
-                if not stream.full_path else
-                stream.save_file(node=self.node)
-            ) for stream in self.streams.values() if stream.running
-        ]
-        if t:
-            log.info("resuming %i downloads", len(t))
-            await asyncio.gather(*t, loop=self.loop)
+        if to_resume_saving:
+            self.resume_saving_task = self.loop.create_task(self.resume(to_resume_saving))
+
+    async def resume(self, to_resume_saving):
+        log.info("Resuming saving %i files", len(to_resume_saving))
+        await asyncio.gather(
+            *(self.streams[sd_hash].save_file(file_name, download_directory, node=self.node)
+              for (file_name, download_directory, sd_hash) in to_resume_saving),
+            loop=self.loop
+        )

     async def reflect_streams(self):
         while True:
@@ -182,14 +193,13 @@ class StreamManager:
             await asyncio.sleep(300, loop=self.loop)

     async def start(self):
-        await self.load_streams_from_database()
-        self.resume_downloading_task = self.loop.create_task(self.resume())
+        await self.load_and_resume_streams_from_database()
         self.re_reflect_task = self.loop.create_task(self.reflect_streams())
         self.started.set()

     def stop(self):
-        if self.resume_downloading_task and not self.resume_downloading_task.done():
-            self.resume_downloading_task.cancel()
+        if self.resume_saving_task and not self.resume_saving_task.done():
+            self.resume_saving_task.cancel()
         if self.re_reflect_task and not self.re_reflect_task.done():
             self.re_reflect_task.cancel()
         while self.streams:
@@ -387,6 +397,7 @@ class StreamManager:
                 lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1')
             )
             log.info("paid fee of %s for %s", fee_amount, uri)
+            await self.storage.save_content_fee(stream.stream_hash, stream.content_fee)

         self.streams[stream.sd_hash] = stream
         self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
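`resume()` is no longer a blanket re-download pass: `load_and_resume_streams_from_database()` collects only streams that were mid-save (`status == 'running'` with a file name but `saved_file == 0`) and kicks them off in a single background task that `stop()` can cancel. That task bookkeeping, reduced to a sketch with made-up names:

    import asyncio


    class Manager:
        def __init__(self):
            self.loop = asyncio.get_event_loop()
            self.resume_saving_task = None

        async def _save(self, name):
            await asyncio.sleep(0.01)  # pretend to write blobs back out
            print("resumed saving", name)

        async def resume(self, to_resume_saving):
            await asyncio.gather(*(self._save(name) for name in to_resume_saving))

        async def start(self, to_resume_saving):
            if to_resume_saving:
                # hold a reference so stop() can cancel the whole batch
                self.resume_saving_task = self.loop.create_task(self.resume(to_resume_saving))

        def stop(self):
            if self.resume_saving_task and not self.resume_saving_task.done():
                self.resume_saving_task.cancel()


    async def main():
        manager = Manager()
        await manager.start(["a.mp4", "b.mp4"])
        await asyncio.sleep(0.05)  # let the background saves finish
        manager.stop()

    asyncio.run(main())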
diff --git a/lbrynet/testcase.py b/lbrynet/testcase.py
index 29a487cf2..cc508d3a8 100644
--- a/lbrynet/testcase.py
+++ b/lbrynet/testcase.py
@@ -172,27 +172,43 @@ class CommandTestCase(IntegrationTestCase):
         return json.loads(jsonrpc_dumps_pretty(value, ledger=self.ledger))['result']

     async def stream_create(self, name='hovercraft', bid='1.0', data=b'hi!', confirm=True, **kwargs):
-        with tempfile.NamedTemporaryFile() as file:
-            file.write(data)
-            file.flush()
-            claim = await self.out(
-                self.daemon.jsonrpc_stream_create(name, bid, file_path=file.name, **kwargs)
-            )
-            self.assertEqual(claim['outputs'][0]['name'], name)
-            if confirm:
-                await self.on_transaction_dict(claim)
-                await self.generate(1)
-                await self.on_transaction_dict(claim)
-            return claim
+        file = tempfile.NamedTemporaryFile()
+
+        def cleanup():
+            try:
+                file.close()
+            except FileNotFoundError:
+                pass
+
+        self.addCleanup(cleanup)
+        file.write(data)
+        file.flush()
+        claim = await self.out(
+            self.daemon.jsonrpc_stream_create(name, bid, file_path=file.name, **kwargs)
+        )
+        self.assertEqual(claim['outputs'][0]['name'], name)
+        if confirm:
+            await self.on_transaction_dict(claim)
+            await self.generate(1)
+            await self.on_transaction_dict(claim)
+        return claim

     async def stream_update(self, claim_id, data=None, confirm=True, **kwargs):
         if data:
-            with tempfile.NamedTemporaryFile() as file:
-                file.write(data)
-                file.flush()
-                claim = await self.out(
-                    self.daemon.jsonrpc_stream_update(claim_id, file_path=file.name, **kwargs)
-                )
+            file = tempfile.NamedTemporaryFile()
+            file.write(data)
+            file.flush()
+
+            def cleanup():
+                try:
+                    file.close()
+                except FileNotFoundError:
+                    pass
+
+            self.addCleanup(cleanup)
+            claim = await self.out(
+                self.daemon.jsonrpc_stream_update(claim_id, file_path=file.name, **kwargs)
+            )
         else:
             claim = await self.out(self.daemon.jsonrpc_stream_update(claim_id, **kwargs))
         self.assertIsNotNone(claim['outputs'][0]['name'])
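The test helpers trade `with tempfile.NamedTemporaryFile()` for an explicit `addCleanup`: the daemon may still be reading the uploaded file after the helper returns, so the file has to outlive the block, while `addCleanup` still guarantees it is closed (and deleted) whether the test passes or fails. The shape of that pattern as a standalone test:

    import tempfile
    import unittest


    class TempFileOutlivesHelper(unittest.TestCase):
        def test_read_back(self):
            file = tempfile.NamedTemporaryFile()

            def cleanup():
                try:
                    file.close()  # closing also deletes the file
                except FileNotFoundError:
                    pass  # something else already removed it

            self.addCleanup(cleanup)  # runs even if the assertions below fail
            file.write(b"hi!")
            file.flush()
            with open(file.name, "rb") as reader:
                self.assertEqual(reader.read(), b"hi!")


    if __name__ == "__main__":
        unittest.main()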
diff --git a/tests/integration/test_file_commands.py b/tests/integration/test_file_commands.py
index 78cbd2d07..4cfb4c319 100644
--- a/tests/integration/test_file_commands.py
+++ b/tests/integration/test_file_commands.py
@@ -239,6 +239,7 @@ class FileCommands(CommandTestCase):
         await self.daemon.jsonrpc_file_delete(claim_name='icanpay')
         await self.assertBalance(self.account, '9.925679')
         response = await self.daemon.jsonrpc_get('lbry://icanpay')
+        raw_content_fee = response.content_fee.raw
         await self.ledger.wait(response.content_fee)
         await self.assertBalance(self.account, '8.925555')
         self.assertEqual(len(self.daemon.jsonrpc_file_list()), 1)
@@ -252,3 +253,10 @@ class FileCommands(CommandTestCase):
         self.assertEqual(
             await self.blockchain.get_balance(), starting_balance + block_reward_and_claim_fee
         )
+
+        # restart the stream manager and make sure the fee is still there
+
+        self.daemon.stream_manager.stop()
+        await self.daemon.stream_manager.start()
+        self.assertEqual(len(self.daemon.jsonrpc_file_list()), 1)
+        self.assertEqual(self.daemon.jsonrpc_file_list()[0].content_fee.raw, raw_content_fee)
diff --git a/tests/integration/test_streaming.py b/tests/integration/test_streaming.py
index 832e8fed2..c36a5bdd8 100644
--- a/tests/integration/test_streaming.py
+++ b/tests/integration/test_streaming.py
@@ -24,11 +24,11 @@ class RangeRequests(CommandTestCase):
             await self.daemon.stream_manager.start()
         return

-    async def _setup_stream(self, data: bytes, save_blobs: bool = True, save_files: bool = False):
+    async def _setup_stream(self, data: bytes, save_blobs: bool = True, save_files: bool = False, file_size=0):
         self.daemon.conf.save_blobs = save_blobs
         self.daemon.conf.save_files = save_files
         self.data = data
-        await self.stream_create('foo', '0.01', data=self.data)
+        await self.stream_create('foo', '0.01', data=self.data, file_size=file_size)
         if save_blobs:
             self.assertTrue(len(os.listdir(self.daemon.blob_manager.blob_dir)) > 1)
         await self.daemon.jsonrpc_file_list()[0].fully_reflected.wait()
@@ -70,9 +70,10 @@ class RangeRequests(CommandTestCase):
         self.assertEqual('bytes 0-14/15', content_range)

     async def test_range_requests_0_padded_bytes(self, size: int = (MAX_BLOB_SIZE - 1) * 4,
-                                                 expected_range: str = 'bytes 0-8388603/8388604', padding=b''):
+                                                 expected_range: str = 'bytes 0-8388603/8388604', padding=b'',
+                                                 file_size=0):
         self.data = get_random_bytes(size)
-        await self._setup_stream(self.data)
+        await self._setup_stream(self.data, file_size=file_size)
         streamed, content_range, content_length = await self._test_range_requests()
         self.assertEqual(len(self.data + padding), content_length)
         self.assertEqual(streamed, self.data + padding)
@@ -93,6 +94,11 @@ class RangeRequests(CommandTestCase):
             ((MAX_BLOB_SIZE - 1) * 4) - 14, padding=b'\x00' * 14
         )

+    async def test_range_requests_no_padding_size_from_claim(self):
+        size = ((MAX_BLOB_SIZE - 1) * 4) - 14
+        await self.test_range_requests_0_padded_bytes(size, padding=b'', file_size=size,
+                                                      expected_range=f"bytes 0-{size-1}/{size}")
+
     async def test_range_requests_15_padded_bytes(self):
         await self.test_range_requests_0_padded_bytes(
             ((MAX_BLOB_SIZE - 1) * 4) - 15, padding=b'\x00' * 15