diff --git a/lbrynet/blob/blob_file.py b/lbrynet/blob/blob_file.py
index db4fd5f88..74ecc1e5a 100644
--- a/lbrynet/blob/blob_file.py
+++ b/lbrynet/blob/blob_file.py
@@ -98,7 +98,7 @@ class BlobFile:
                 t.add_done_callback(lambda *_: self.finished_writing.set())
                 return
             if isinstance(error, (InvalidBlobHashError, InvalidDataError)):
-                log.error("writer error downloading %s: %s", self.blob_hash[:8], str(error))
+                log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error))
             elif not isinstance(error, (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError)):
                 log.exception("something else")
                 raise error
diff --git a/lbrynet/extras/daemon/storage.py b/lbrynet/extras/daemon/storage.py
index e6643b62c..7830a4a77 100644
--- a/lbrynet/extras/daemon/storage.py
+++ b/lbrynet/extras/daemon/storage.py
@@ -1,3 +1,4 @@
+import os
 import logging
 import sqlite3
 import typing
@@ -500,6 +501,31 @@ class SQLiteStorage(SQLiteMixin):
                 stream_hash
             ))
 
+    async def recover_streams(self, descriptors_and_sds: typing.List[typing.Tuple['StreamDescriptor', 'BlobFile']],
+                              download_directory: str):
+        def _recover(transaction: sqlite3.Connection):
+            stream_hashes = [d.stream_hash for d, s in descriptors_and_sds]
+            for descriptor, sd_blob in descriptors_and_sds:
+                content_claim = transaction.execute(
+                    "select * from content_claim where stream_hash=?", (descriptor.stream_hash, )
+                ).fetchone()
+                delete_stream(transaction, descriptor)  # this will also delete the content claim
+                store_stream(transaction, sd_blob, descriptor)
+                store_file(transaction, descriptor.stream_hash, os.path.basename(descriptor.suggested_file_name),
+                           download_directory, 0.0, 'stopped')
+                if content_claim:
+                    transaction.execute("insert or ignore into content_claim values (?, ?)", content_claim)
+            transaction.executemany(
+                "update file set status='stopped' where stream_hash=?",
+                [(stream_hash, ) for stream_hash in stream_hashes]
+            )
+            download_dir = binascii.hexlify(self.conf.download_dir.encode()).decode()
+            transaction.executemany(
+                f"update file set download_directory=? where stream_hash=?",
where stream_hash=?", + [(download_dir, stream_hash) for stream_hash in stream_hashes] + ) + await self.db.run_with_foreign_keys_disabled(_recover) + def get_all_stream_hashes(self): return self.run_and_return_list("select stream_hash from stream") diff --git a/lbrynet/stream/descriptor.py b/lbrynet/stream/descriptor.py index 658e593f2..ad676a38e 100644 --- a/lbrynet/stream/descriptor.py +++ b/lbrynet/stream/descriptor.py @@ -4,6 +4,7 @@ import binascii import logging import typing import asyncio +from collections import OrderedDict from cryptography.hazmat.primitives.ciphers.algorithms import AES from lbrynet.blob import MAX_BLOB_SIZE from lbrynet.blob.blob_info import BlobInfo @@ -82,10 +83,41 @@ class StreamDescriptor: [blob_info.as_dict() for blob_info in self.blobs]), sort_keys=True ).encode() - async def make_sd_blob(self): - sd_hash = self.calculate_sd_hash() - sd_data = self.as_json() - sd_blob = BlobFile(self.loop, self.blob_dir, sd_hash, len(sd_data)) + def old_sort_json(self) -> bytes: + blobs = [] + for b in self.blobs: + blobs.append(OrderedDict( + [('length', b.length), ('blob_num', b.blob_num), ('iv', b.iv)] if not b.blob_hash else + [('length', b.length), ('blob_num', b.blob_num), ('blob_hash', b.blob_hash), ('iv', b.iv)] + )) + if not b.blob_hash: + break + return json.dumps( + OrderedDict([ + ('stream_name', binascii.hexlify(self.stream_name.encode()).decode()), + ('blobs', blobs), + ('stream_type', 'lbryfile'), + ('key', self.key), + ('suggested_file_name', binascii.hexlify(self.suggested_file_name.encode()).decode()), + ('stream_hash', self.stream_hash), + ]) + ).encode() + + def calculate_old_sort_sd_hash(self): + h = get_lbry_hash_obj() + h.update(self.old_sort_json()) + return h.hexdigest() + + async def make_sd_blob(self, blob_file_obj: typing.Optional[BlobFile] = None, + old_sort: typing.Optional[bool] = False): + sd_hash = self.calculate_sd_hash() if not old_sort else self.calculate_old_sort_sd_hash() + if not old_sort: + sd_data = self.as_json() + else: + sd_data = self.old_sort_json() + sd_blob = blob_file_obj or BlobFile(self.loop, self.blob_dir, sd_hash, len(sd_data)) + if blob_file_obj: + blob_file_obj.set_length(len(sd_data)) if not sd_blob.get_is_verified(): writer = sd_blob.open_for_writing() writer.write(sd_data) @@ -160,8 +192,8 @@ class StreamDescriptor: @classmethod async def create_stream(cls, loop: asyncio.BaseEventLoop, blob_dir: str, file_path: str, key: typing.Optional[bytes] = None, - iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None - ) -> 'StreamDescriptor': + iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None, + old_sort: bool = False) -> 'StreamDescriptor': blobs: typing.List[BlobInfo] = [] @@ -180,7 +212,7 @@ class StreamDescriptor: loop, blob_dir, os.path.basename(file_path), binascii.hexlify(key).decode(), os.path.basename(file_path), blobs ) - sd_blob = await descriptor.make_sd_blob() + sd_blob = await descriptor.make_sd_blob(old_sort=old_sort) descriptor.sd_hash = sd_blob.blob_hash return descriptor @@ -190,3 +222,19 @@ class StreamDescriptor: def upper_bound_decrypted_length(self) -> int: return self.lower_bound_decrypted_length() + (AES.block_size // 8) + + @classmethod + async def recover(cls, blob_dir: str, sd_blob: 'BlobFile', stream_hash: str, stream_name: str, + suggested_file_name: str, key: str, + blobs: typing.List['BlobInfo']) -> typing.Optional['StreamDescriptor']: + descriptor = cls(asyncio.get_event_loop(), blob_dir, stream_name, key, suggested_file_name, + blobs, stream_hash, 
+
+        if descriptor.calculate_sd_hash() == sd_blob.blob_hash:  # first check for a normal valid sd
+            old_sort = False
+        elif descriptor.calculate_old_sort_sd_hash() == sd_blob.blob_hash:  # check if old field sorting works
+            old_sort = True
+        else:
+            return
+        await descriptor.make_sd_blob(sd_blob, old_sort)
+        return descriptor
diff --git a/lbrynet/stream/stream_manager.py b/lbrynet/stream/stream_manager.py
index e911381e6..a69b57234 100644
--- a/lbrynet/stream/stream_manager.py
+++ b/lbrynet/stream/stream_manager.py
@@ -6,6 +6,7 @@ import logging
 import random
 from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError, \
     DownloadDataTimeout, DownloadSDTimeout
+from lbrynet.stream.descriptor import StreamDescriptor
 from lbrynet.stream.downloader import StreamDownloader
 from lbrynet.stream.managed_stream import ManagedStream
 from lbrynet.schema.claim import ClaimDict
@@ -16,13 +17,12 @@ if typing.TYPE_CHECKING:
     from lbrynet.conf import Config
     from lbrynet.blob.blob_manager import BlobFileManager
     from lbrynet.dht.node import Node
-    from lbrynet.extras.daemon.storage import SQLiteStorage
+    from lbrynet.extras.daemon.storage import SQLiteStorage, StoredStreamClaim
     from lbrynet.extras.wallet import LbryWalletManager
     from lbrynet.extras.daemon.exchange_rate_manager import ExchangeRateManager
 
 log = logging.getLogger(__name__)
 
-
 filter_fields = [
     'status',
     'file_name',
@@ -128,33 +128,75 @@ class StreamManager:
             self.loop, self.config, self.blob_manager, sd_hash, download_directory, file_name
         )
 
-    async def add_stream(self, sd_hash: str, file_name: str, download_directory: str, status: str, claim):
-        sd_blob = self.blob_manager.get_blob(sd_hash)
-        if sd_blob.get_is_verified():
-            try:
-                descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash)
-            except InvalidStreamDescriptorError as err:
-                log.warning("Failed to start stream for sd %s - %s", sd_hash, str(err))
-                return
+    async def recover_streams(self, file_infos: typing.List[typing.Dict]):
+        to_restore = []
 
-            downloader = self.make_downloader(descriptor.sd_hash, download_directory, file_name)
-            stream = ManagedStream(
-                self.loop, self.blob_manager, descriptor, download_directory, file_name, downloader, status, claim
+        async def recover_stream(sd_hash: str, stream_hash: str, stream_name: str,
+                                 suggested_file_name: str, key: str) -> typing.Optional[StreamDescriptor]:
+            sd_blob = self.blob_manager.get_blob(sd_hash)
+            blobs = await self.storage.get_blobs_for_stream(stream_hash)
+            descriptor = await StreamDescriptor.recover(
+                self.blob_manager.blob_dir, sd_blob, stream_hash, stream_name, suggested_file_name, key, blobs
             )
-            self.streams.add(stream)
-            self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
+            if not descriptor:
+                return
+            to_restore.append((descriptor, sd_blob))
+
+        await asyncio.gather(*[
+            recover_stream(
+                file_info['sd_hash'], file_info['stream_hash'], binascii.unhexlify(file_info['stream_name']).decode(),
+                binascii.unhexlify(file_info['suggested_file_name']).decode(), file_info['key']
+            ) for file_info in file_infos
+        ])
+
+        if to_restore:
+            await self.storage.recover_streams(to_restore, self.config.download_dir)
+        log.info("Recovered %i/%i attempted streams", len(to_restore), len(file_infos))
+
+    async def add_stream(self, sd_hash: str, file_name: str, download_directory: str, status: str,
+                         claim: typing.Optional['StoredStreamClaim']):
+        sd_blob = self.blob_manager.get_blob(sd_hash)
+        if not sd_blob.get_is_verified():
+            return
+        try:
+            descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash)
+        except InvalidStreamDescriptorError as err:
+            log.warning("Failed to start stream for sd %s - %s", sd_hash, str(err))
+            return
+        if status == ManagedStream.STATUS_RUNNING:
+            downloader = self.make_downloader(descriptor.sd_hash, download_directory, file_name)
+        else:
+            downloader = None
+        stream = ManagedStream(
+            self.loop, self.blob_manager, descriptor, download_directory, file_name, downloader, status, claim
+        )
+        self.streams.add(stream)
+        self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
 
     async def load_streams_from_database(self):
-        log.info("Initializing stream manager from %s", self.storage._db_path)
-        file_infos = await self.storage.get_all_lbry_files()
-        log.info("Initializing %i files", len(file_infos))
+        to_recover = []
+        for file_info in await self.storage.get_all_lbry_files():
+            if not self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified():
+                to_recover.append(file_info)
+
+        if to_recover:
+            log.info("Attempting to recover %i streams", len(to_recover))
+            await self.recover_streams(to_recover)
+
+        to_start = []
+        for file_info in await self.storage.get_all_lbry_files():
+            if self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified():
+                to_start.append(file_info)
+        log.info("Initializing %i files", len(to_start))
+
         await asyncio.gather(*[
             self.add_stream(
                 file_info['sd_hash'], binascii.unhexlify(file_info['file_name']).decode(),
-                binascii.unhexlify(file_info['download_directory']).decode(), file_info['status'], file_info['claim']
-            ) for file_info in file_infos
+                binascii.unhexlify(file_info['download_directory']).decode(), file_info['status'],
+                file_info['claim']
+            ) for file_info in to_start
         ])
-        log.info("Started stream manager with %i files", len(file_infos))
+        log.info("Started stream manager with %i files", len(self.streams))
 
     async def resume(self):
         if self.node:
@@ -337,6 +379,8 @@
             raise ValueError(f"'{sort_by}' is not a valid field to sort by")
         if comparison and comparison not in comparison_operators:
             raise ValueError(f"'{comparison}' is not a valid comparison")
+        if 'full_status' in search_by:
+            del search_by['full_status']
         for search in search_by.keys():
             if search not in filter_fields:
                 raise ValueError(f"'{search}' is not a valid search operation")
@@ -345,8 +389,6 @@
         streams = []
         for stream in self.streams:
             for search, val in search_by.items():
-                if search == 'full_status':
-                    continue
                 if comparison_operators[comparison](getattr(stream, search), val):
                     streams.append(stream)
                     break
diff --git a/tests/unit/stream/test_stream_descriptor.py b/tests/unit/stream/test_stream_descriptor.py
index f634e9972..a0ffdb764 100644
--- a/tests/unit/stream/test_stream_descriptor.py
+++ b/tests/unit/stream/test_stream_descriptor.py
@@ -75,3 +75,35 @@ class TestStreamDescriptor(AsyncioTestCase):
     async def test_zero_length_blob(self):
         self.sd_dict['blobs'][-2]['length'] = 0
         await self._test_invalid_sd()
+
+
+class TestRecoverOldStreamDescriptors(AsyncioTestCase):
+    async def test_old_key_sort_sd_blob(self):
+        loop = asyncio.get_event_loop()
+        tmp_dir = tempfile.mkdtemp()
+        self.addCleanup(lambda: shutil.rmtree(tmp_dir))
+        storage = SQLiteStorage(Config(), ":memory:")
+        await storage.open()
+        blob_manager = BlobFileManager(loop, tmp_dir, storage)
+
+        sd_bytes = b'{"stream_name": "4f62616d6120446f6e6b65792d322e73746c", "blobs": [{"length": 1153488, "blob_num' \
"blob_num' \ + b'": 0, "blob_hash": "9fa32a249ce3f2d4e46b78599800f368b72f2a7f22b81df443c7f6bdbef496bd61b4c0079c7' \ + b'3d79c8bb9be9a6bf86592", "iv": "0bf348867244019c9e22196339016ea6"}, {"length": 0, "blob_num": 1,' \ + b' "iv": "9f36abae16955463919b07ed530a3d18"}], "stream_type": "lbryfile", "key": "a03742b87628aa7' \ + b'228e48f1dcd207e48", "suggested_file_name": "4f62616d6120446f6e6b65792d322e73746c", "stream_hash' \ + b'": "b43f4b1379780caf60d20aa06ac38fb144df61e514ebfa97537018ba73bce8fe37ae712f473ff0ba0be0eef44e1' \ + b'60207"}' + sd_hash = '9313d1807551186126acc3662e74d9de29cede78d4f133349ace846273ef116b9bb86be86c54509eb84840e4b032f6b2' + stream_hash = 'b43f4b1379780caf60d20aa06ac38fb144df61e514ebfa97537018ba73bce8fe37ae712f473ff0ba0be0eef44e160207' + + blob = blob_manager.get_blob(sd_hash) + blob.set_length(len(sd_bytes)) + writer = blob.open_for_writing() + writer.write(sd_bytes) + await blob.verified.wait() + descriptor = await StreamDescriptor.from_stream_descriptor_blob( + loop, blob_manager.blob_dir, blob + ) + self.assertEqual(stream_hash, descriptor.get_stream_hash()) + self.assertEqual(sd_hash, descriptor.calculate_old_sort_sd_hash()) + self.assertNotEqual(sd_hash, descriptor.calculate_sd_hash()) diff --git a/tests/unit/stream/test_stream_manager.py b/tests/unit/stream/test_stream_manager.py index a2076c15d..2d564e760 100644 --- a/tests/unit/stream/test_stream_manager.py +++ b/tests/unit/stream/test_stream_manager.py @@ -23,6 +23,8 @@ def get_mock_node(peer): mock_node = mock.Mock(spec=Node) mock_node.accumulate_peers = mock_accumulate_peers + mock_node.joined = asyncio.Event() + mock_node.joined.set() return mock_node @@ -91,15 +93,13 @@ def get_mock_wallet(sd_hash, storage, balance=10.0, fee=None): class TestStreamManager(BlobExchangeTestBase): - async def asyncSetUp(self): - await super().asyncSetUp() + async def setup_stream_manager(self, balance=10.0, fee=None, old_sort=False): file_path = os.path.join(self.server_dir, "test_file") with open(file_path, 'wb') as f: f.write(os.urandom(20000000)) - descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path) - self.sd_hash = descriptor.calculate_sd_hash() - - async def setup_stream_manager(self, balance=10.0, fee=None): + descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path, + old_sort=old_sort) + self.sd_hash = descriptor.sd_hash self.mock_wallet, self.uri = get_mock_wallet(self.sd_hash, self.client_storage, balance, fee) self.stream_manager = StreamManager(self.loop, self.client_config, self.client_blob_manager, self.mock_wallet, self.client_storage, get_mock_node(self.server_from_client)) @@ -169,3 +169,26 @@ class TestStreamManager(BlobExchangeTestBase): await self.setup_stream_manager(1000000.0, fee) with self.assertRaises(KeyFeeAboveMaxAllowed): await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) + + async def test_download_then_recover_stream_on_startup(self, old_sort=False): + await self.setup_stream_manager(old_sort=old_sort) + self.assertSetEqual(self.stream_manager.streams, set()) + stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) + await stream.downloader.stream_finished_event.wait() + self.stream_manager.stop() + self.client_blob_manager.stop() + os.remove(os.path.join(self.client_blob_manager.blob_dir, stream.sd_hash)) + for blob in stream.descriptor.blobs[:-1]: + os.remove(os.path.join(self.client_blob_manager.blob_dir, 
+        await self.client_blob_manager.setup()
+        await self.stream_manager.start()
+        self.assertEqual(1, len(self.stream_manager.streams))
+        self.assertEqual(stream.sd_hash, list(self.stream_manager.streams)[0].sd_hash)
+        self.assertEqual('stopped', list(self.stream_manager.streams)[0].status)
+
+        sd_blob = self.client_blob_manager.get_blob(stream.sd_hash)
+        self.assertTrue(sd_blob.file_exists)
+        self.assertTrue(sd_blob.get_is_verified())
+
+    def test_download_then_recover_old_sort_stream_on_startup(self):
+        return self.test_download_then_recover_stream_on_startup(old_sort=True)
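
Note: StreamDescriptor.recover() has to try two serializations because older
clients hashed the descriptor JSON with its fields in insertion order, while
current clients hash it with sort_keys=True. A minimal standalone sketch of
that difference, separate from the patch itself (toy payload; assumes
get_lbry_hash_obj() is sha384, which the 96-hex-character hashes in the tests
above are consistent with):

    import hashlib
    import json

    # Toy payload for illustration only; a real sd blob carries stream_name,
    # blobs, stream_type, key, suggested_file_name and stream_hash.
    payload = {"stream_type": "lbryfile", "stream_name": "6d792d66696c65"}

    new_sd = json.dumps(payload, sort_keys=True).encode()  # keys sorted alphabetically
    old_sd = json.dumps(payload).encode()  # insertion order, as old clients wrote it

    # Same metadata, different bytes, therefore different sd hashes:
    assert hashlib.sha384(new_sd).hexdigest() != hashlib.sha384(old_sd).hexdigest()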