recover streams with missing sd blobs, handle previous sd blob bugs
-test download and recover stream with old key sorting
This commit is contained in:
parent
a228d20137
commit
dbb6ba6241
6 changed files with 208 additions and 37 deletions
|
@ -98,7 +98,7 @@ class BlobFile:
|
|||
t.add_done_callback(lambda *_: self.finished_writing.set())
|
||||
return
|
||||
if isinstance(error, (InvalidBlobHashError, InvalidDataError)):
|
||||
log.error("writer error downloading %s: %s", self.blob_hash[:8], str(error))
|
||||
log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error))
|
||||
elif not isinstance(error, (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError)):
|
||||
log.exception("something else")
|
||||
raise error
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import logging
|
||||
import sqlite3
|
||||
import typing
|
||||
|
@ -500,6 +501,31 @@ class SQLiteStorage(SQLiteMixin):
|
|||
stream_hash
|
||||
))
|
||||
|
||||
async def recover_streams(self, descriptors_and_sds: typing.List[typing.Tuple['StreamDescriptor', 'BlobFile']],
|
||||
download_directory: str):
|
||||
def _recover(transaction: sqlite3.Connection):
|
||||
stream_hashes = [d.stream_hash for d, s in descriptors_and_sds]
|
||||
for descriptor, sd_blob in descriptors_and_sds:
|
||||
content_claim = transaction.execute(
|
||||
"select * from content_claim where stream_hash=?", (descriptor.stream_hash, )
|
||||
).fetchone()
|
||||
delete_stream(transaction, descriptor) # this will also delete the content claim
|
||||
store_stream(transaction, sd_blob, descriptor)
|
||||
store_file(transaction, descriptor.stream_hash, os.path.basename(descriptor.suggested_file_name),
|
||||
download_directory, 0.0, 'stopped')
|
||||
if content_claim:
|
||||
transaction.execute("insert or ignore into content_claim values (?, ?)", content_claim)
|
||||
transaction.executemany(
|
||||
"update file set status='stopped' where stream_hash=?",
|
||||
[(stream_hash, ) for stream_hash in stream_hashes]
|
||||
)
|
||||
download_dir = binascii.hexlify(self.conf.download_dir.encode()).decode()
|
||||
transaction.executemany(
|
||||
f"update file set download_directory=? where stream_hash=?",
|
||||
[(download_dir, stream_hash) for stream_hash in stream_hashes]
|
||||
)
|
||||
await self.db.run_with_foreign_keys_disabled(_recover)
|
||||
|
||||
def get_all_stream_hashes(self):
|
||||
return self.run_and_return_list("select stream_hash from stream")
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import binascii
|
|||
import logging
|
||||
import typing
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from cryptography.hazmat.primitives.ciphers.algorithms import AES
|
||||
from lbrynet.blob import MAX_BLOB_SIZE
|
||||
from lbrynet.blob.blob_info import BlobInfo
|
||||
|
@ -82,10 +83,41 @@ class StreamDescriptor:
|
|||
[blob_info.as_dict() for blob_info in self.blobs]), sort_keys=True
|
||||
).encode()
|
||||
|
||||
async def make_sd_blob(self):
|
||||
sd_hash = self.calculate_sd_hash()
|
||||
sd_data = self.as_json()
|
||||
sd_blob = BlobFile(self.loop, self.blob_dir, sd_hash, len(sd_data))
|
||||
def old_sort_json(self) -> bytes:
|
||||
blobs = []
|
||||
for b in self.blobs:
|
||||
blobs.append(OrderedDict(
|
||||
[('length', b.length), ('blob_num', b.blob_num), ('iv', b.iv)] if not b.blob_hash else
|
||||
[('length', b.length), ('blob_num', b.blob_num), ('blob_hash', b.blob_hash), ('iv', b.iv)]
|
||||
))
|
||||
if not b.blob_hash:
|
||||
break
|
||||
return json.dumps(
|
||||
OrderedDict([
|
||||
('stream_name', binascii.hexlify(self.stream_name.encode()).decode()),
|
||||
('blobs', blobs),
|
||||
('stream_type', 'lbryfile'),
|
||||
('key', self.key),
|
||||
('suggested_file_name', binascii.hexlify(self.suggested_file_name.encode()).decode()),
|
||||
('stream_hash', self.stream_hash),
|
||||
])
|
||||
).encode()
|
||||
|
||||
def calculate_old_sort_sd_hash(self):
|
||||
h = get_lbry_hash_obj()
|
||||
h.update(self.old_sort_json())
|
||||
return h.hexdigest()
|
||||
|
||||
async def make_sd_blob(self, blob_file_obj: typing.Optional[BlobFile] = None,
|
||||
old_sort: typing.Optional[bool] = False):
|
||||
sd_hash = self.calculate_sd_hash() if not old_sort else self.calculate_old_sort_sd_hash()
|
||||
if not old_sort:
|
||||
sd_data = self.as_json()
|
||||
else:
|
||||
sd_data = self.old_sort_json()
|
||||
sd_blob = blob_file_obj or BlobFile(self.loop, self.blob_dir, sd_hash, len(sd_data))
|
||||
if blob_file_obj:
|
||||
blob_file_obj.set_length(len(sd_data))
|
||||
if not sd_blob.get_is_verified():
|
||||
writer = sd_blob.open_for_writing()
|
||||
writer.write(sd_data)
|
||||
|
@ -160,8 +192,8 @@ class StreamDescriptor:
|
|||
@classmethod
|
||||
async def create_stream(cls, loop: asyncio.BaseEventLoop, blob_dir: str,
|
||||
file_path: str, key: typing.Optional[bytes] = None,
|
||||
iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None
|
||||
) -> 'StreamDescriptor':
|
||||
iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None,
|
||||
old_sort: bool = False) -> 'StreamDescriptor':
|
||||
|
||||
blobs: typing.List[BlobInfo] = []
|
||||
|
||||
|
@ -180,7 +212,7 @@ class StreamDescriptor:
|
|||
loop, blob_dir, os.path.basename(file_path), binascii.hexlify(key).decode(), os.path.basename(file_path),
|
||||
blobs
|
||||
)
|
||||
sd_blob = await descriptor.make_sd_blob()
|
||||
sd_blob = await descriptor.make_sd_blob(old_sort=old_sort)
|
||||
descriptor.sd_hash = sd_blob.blob_hash
|
||||
return descriptor
|
||||
|
||||
|
@ -190,3 +222,19 @@ class StreamDescriptor:
|
|||
|
||||
def upper_bound_decrypted_length(self) -> int:
|
||||
return self.lower_bound_decrypted_length() + (AES.block_size // 8)
|
||||
|
||||
@classmethod
|
||||
async def recover(cls, blob_dir: str, sd_blob: 'BlobFile', stream_hash: str, stream_name: str,
|
||||
suggested_file_name: str, key: str,
|
||||
blobs: typing.List['BlobInfo']) -> typing.Optional['StreamDescriptor']:
|
||||
descriptor = cls(asyncio.get_event_loop(), blob_dir, stream_name, key, suggested_file_name,
|
||||
blobs, stream_hash, sd_blob.blob_hash)
|
||||
|
||||
if descriptor.calculate_sd_hash() == sd_blob.blob_hash: # first check for a normal valid sd
|
||||
old_sort = False
|
||||
elif descriptor.calculate_old_sort_sd_hash() == sd_blob.blob_hash: # check if old field sorting works
|
||||
old_sort = True
|
||||
else:
|
||||
return
|
||||
await descriptor.make_sd_blob(sd_blob, old_sort)
|
||||
return descriptor
|
||||
|
|
|
@ -6,6 +6,7 @@ import logging
|
|||
import random
|
||||
from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError, \
|
||||
DownloadDataTimeout, DownloadSDTimeout
|
||||
from lbrynet.stream.descriptor import StreamDescriptor
|
||||
from lbrynet.stream.downloader import StreamDownloader
|
||||
from lbrynet.stream.managed_stream import ManagedStream
|
||||
from lbrynet.schema.claim import ClaimDict
|
||||
|
@ -16,13 +17,12 @@ if typing.TYPE_CHECKING:
|
|||
from lbrynet.conf import Config
|
||||
from lbrynet.blob.blob_manager import BlobFileManager
|
||||
from lbrynet.dht.node import Node
|
||||
from lbrynet.extras.daemon.storage import SQLiteStorage
|
||||
from lbrynet.extras.daemon.storage import SQLiteStorage, StoredStreamClaim
|
||||
from lbrynet.extras.wallet import LbryWalletManager
|
||||
from lbrynet.extras.daemon.exchange_rate_manager import ExchangeRateManager
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
filter_fields = [
|
||||
'status',
|
||||
'file_name',
|
||||
|
@ -128,33 +128,75 @@ class StreamManager:
|
|||
self.loop, self.config, self.blob_manager, sd_hash, download_directory, file_name
|
||||
)
|
||||
|
||||
async def add_stream(self, sd_hash: str, file_name: str, download_directory: str, status: str, claim):
|
||||
sd_blob = self.blob_manager.get_blob(sd_hash)
|
||||
if sd_blob.get_is_verified():
|
||||
try:
|
||||
descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash)
|
||||
except InvalidStreamDescriptorError as err:
|
||||
log.warning("Failed to start stream for sd %s - %s", sd_hash, str(err))
|
||||
return
|
||||
async def recover_streams(self, file_infos: typing.List[typing.Dict]):
|
||||
to_restore = []
|
||||
|
||||
downloader = self.make_downloader(descriptor.sd_hash, download_directory, file_name)
|
||||
stream = ManagedStream(
|
||||
self.loop, self.blob_manager, descriptor, download_directory, file_name, downloader, status, claim
|
||||
async def recover_stream(sd_hash: str, stream_hash: str, stream_name: str,
|
||||
suggested_file_name: str, key: str) -> typing.Optional[StreamDescriptor]:
|
||||
sd_blob = self.blob_manager.get_blob(sd_hash)
|
||||
blobs = await self.storage.get_blobs_for_stream(stream_hash)
|
||||
descriptor = await StreamDescriptor.recover(
|
||||
self.blob_manager.blob_dir, sd_blob, stream_hash, stream_name, suggested_file_name, key, blobs
|
||||
)
|
||||
self.streams.add(stream)
|
||||
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
|
||||
if not descriptor:
|
||||
return
|
||||
to_restore.append((descriptor, sd_blob))
|
||||
|
||||
await asyncio.gather(*[
|
||||
recover_stream(
|
||||
file_info['sd_hash'], file_info['stream_hash'], binascii.unhexlify(file_info['stream_name']).decode(),
|
||||
binascii.unhexlify(file_info['suggested_file_name']).decode(), file_info['key']
|
||||
) for file_info in file_infos
|
||||
])
|
||||
|
||||
if to_restore:
|
||||
await self.storage.recover_streams(to_restore, self.config.download_dir)
|
||||
log.info("Recovered %i/%i attempted streams", len(to_restore), len(file_infos))
|
||||
|
||||
async def add_stream(self, sd_hash: str, file_name: str, download_directory: str, status: str,
|
||||
claim: typing.Optional['StoredStreamClaim']):
|
||||
sd_blob = self.blob_manager.get_blob(sd_hash)
|
||||
if not sd_blob.get_is_verified():
|
||||
return
|
||||
try:
|
||||
descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash)
|
||||
except InvalidStreamDescriptorError as err:
|
||||
log.warning("Failed to start stream for sd %s - %s", sd_hash, str(err))
|
||||
return
|
||||
if status == ManagedStream.STATUS_RUNNING:
|
||||
downloader = self.make_downloader(descriptor.sd_hash, download_directory, file_name)
|
||||
else:
|
||||
downloader = None
|
||||
stream = ManagedStream(
|
||||
self.loop, self.blob_manager, descriptor, download_directory, file_name, downloader, status, claim
|
||||
)
|
||||
self.streams.add(stream)
|
||||
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
|
||||
|
||||
async def load_streams_from_database(self):
|
||||
log.info("Initializing stream manager from %s", self.storage._db_path)
|
||||
file_infos = await self.storage.get_all_lbry_files()
|
||||
log.info("Initializing %i files", len(file_infos))
|
||||
to_recover = []
|
||||
for file_info in await self.storage.get_all_lbry_files():
|
||||
if not self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified():
|
||||
to_recover.append(file_info)
|
||||
|
||||
if to_recover:
|
||||
log.info("Attempting to recover %i streams", len(to_recover))
|
||||
await self.recover_streams(to_recover)
|
||||
|
||||
to_start = []
|
||||
for file_info in await self.storage.get_all_lbry_files():
|
||||
if self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified():
|
||||
to_start.append(file_info)
|
||||
log.info("Initializing %i files", len(to_start))
|
||||
|
||||
await asyncio.gather(*[
|
||||
self.add_stream(
|
||||
file_info['sd_hash'], binascii.unhexlify(file_info['file_name']).decode(),
|
||||
binascii.unhexlify(file_info['download_directory']).decode(), file_info['status'], file_info['claim']
|
||||
) for file_info in file_infos
|
||||
binascii.unhexlify(file_info['download_directory']).decode(), file_info['status'],
|
||||
file_info['claim']
|
||||
) for file_info in to_start
|
||||
])
|
||||
log.info("Started stream manager with %i files", len(file_infos))
|
||||
log.info("Started stream manager with %i files", len(self.streams))
|
||||
|
||||
async def resume(self):
|
||||
if self.node:
|
||||
|
@ -337,6 +379,8 @@ class StreamManager:
|
|||
raise ValueError(f"'{sort_by}' is not a valid field to sort by")
|
||||
if comparison and comparison not in comparison_operators:
|
||||
raise ValueError(f"'{comparison}' is not a valid comparison")
|
||||
if 'full_status' in search_by:
|
||||
del search_by['full_status']
|
||||
for search in search_by.keys():
|
||||
if search not in filter_fields:
|
||||
raise ValueError(f"'{search}' is not a valid search operation")
|
||||
|
@ -345,8 +389,6 @@ class StreamManager:
|
|||
streams = []
|
||||
for stream in self.streams:
|
||||
for search, val in search_by.items():
|
||||
if search == 'full_status':
|
||||
continue
|
||||
if comparison_operators[comparison](getattr(stream, search), val):
|
||||
streams.append(stream)
|
||||
break
|
||||
|
|
|
@ -75,3 +75,35 @@ class TestStreamDescriptor(AsyncioTestCase):
|
|||
async def test_zero_length_blob(self):
|
||||
self.sd_dict['blobs'][-2]['length'] = 0
|
||||
await self._test_invalid_sd()
|
||||
|
||||
|
||||
class TestRecoverOldStreamDescriptors(AsyncioTestCase):
|
||||
async def test_old_key_sort_sd_blob(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
|
||||
storage = SQLiteStorage(Config(), ":memory:")
|
||||
await storage.open()
|
||||
blob_manager = BlobFileManager(loop, tmp_dir, storage)
|
||||
|
||||
sd_bytes = b'{"stream_name": "4f62616d6120446f6e6b65792d322e73746c", "blobs": [{"length": 1153488, "blob_num' \
|
||||
b'": 0, "blob_hash": "9fa32a249ce3f2d4e46b78599800f368b72f2a7f22b81df443c7f6bdbef496bd61b4c0079c7' \
|
||||
b'3d79c8bb9be9a6bf86592", "iv": "0bf348867244019c9e22196339016ea6"}, {"length": 0, "blob_num": 1,' \
|
||||
b' "iv": "9f36abae16955463919b07ed530a3d18"}], "stream_type": "lbryfile", "key": "a03742b87628aa7' \
|
||||
b'228e48f1dcd207e48", "suggested_file_name": "4f62616d6120446f6e6b65792d322e73746c", "stream_hash' \
|
||||
b'": "b43f4b1379780caf60d20aa06ac38fb144df61e514ebfa97537018ba73bce8fe37ae712f473ff0ba0be0eef44e1' \
|
||||
b'60207"}'
|
||||
sd_hash = '9313d1807551186126acc3662e74d9de29cede78d4f133349ace846273ef116b9bb86be86c54509eb84840e4b032f6b2'
|
||||
stream_hash = 'b43f4b1379780caf60d20aa06ac38fb144df61e514ebfa97537018ba73bce8fe37ae712f473ff0ba0be0eef44e160207'
|
||||
|
||||
blob = blob_manager.get_blob(sd_hash)
|
||||
blob.set_length(len(sd_bytes))
|
||||
writer = blob.open_for_writing()
|
||||
writer.write(sd_bytes)
|
||||
await blob.verified.wait()
|
||||
descriptor = await StreamDescriptor.from_stream_descriptor_blob(
|
||||
loop, blob_manager.blob_dir, blob
|
||||
)
|
||||
self.assertEqual(stream_hash, descriptor.get_stream_hash())
|
||||
self.assertEqual(sd_hash, descriptor.calculate_old_sort_sd_hash())
|
||||
self.assertNotEqual(sd_hash, descriptor.calculate_sd_hash())
|
||||
|
|
|
@ -23,6 +23,8 @@ def get_mock_node(peer):
|
|||
|
||||
mock_node = mock.Mock(spec=Node)
|
||||
mock_node.accumulate_peers = mock_accumulate_peers
|
||||
mock_node.joined = asyncio.Event()
|
||||
mock_node.joined.set()
|
||||
return mock_node
|
||||
|
||||
|
||||
|
@ -91,15 +93,13 @@ def get_mock_wallet(sd_hash, storage, balance=10.0, fee=None):
|
|||
|
||||
|
||||
class TestStreamManager(BlobExchangeTestBase):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
async def setup_stream_manager(self, balance=10.0, fee=None, old_sort=False):
|
||||
file_path = os.path.join(self.server_dir, "test_file")
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(os.urandom(20000000))
|
||||
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
|
||||
self.sd_hash = descriptor.calculate_sd_hash()
|
||||
|
||||
async def setup_stream_manager(self, balance=10.0, fee=None):
|
||||
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path,
|
||||
old_sort=old_sort)
|
||||
self.sd_hash = descriptor.sd_hash
|
||||
self.mock_wallet, self.uri = get_mock_wallet(self.sd_hash, self.client_storage, balance, fee)
|
||||
self.stream_manager = StreamManager(self.loop, self.client_config, self.client_blob_manager, self.mock_wallet,
|
||||
self.client_storage, get_mock_node(self.server_from_client))
|
||||
|
@ -169,3 +169,26 @@ class TestStreamManager(BlobExchangeTestBase):
|
|||
await self.setup_stream_manager(1000000.0, fee)
|
||||
with self.assertRaises(KeyFeeAboveMaxAllowed):
|
||||
await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
|
||||
|
||||
async def test_download_then_recover_stream_on_startup(self, old_sort=False):
|
||||
await self.setup_stream_manager(old_sort=old_sort)
|
||||
self.assertSetEqual(self.stream_manager.streams, set())
|
||||
stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
|
||||
await stream.downloader.stream_finished_event.wait()
|
||||
self.stream_manager.stop()
|
||||
self.client_blob_manager.stop()
|
||||
os.remove(os.path.join(self.client_blob_manager.blob_dir, stream.sd_hash))
|
||||
for blob in stream.descriptor.blobs[:-1]:
|
||||
os.remove(os.path.join(self.client_blob_manager.blob_dir, blob.blob_hash))
|
||||
await self.client_blob_manager.setup()
|
||||
await self.stream_manager.start()
|
||||
self.assertEqual(1, len(self.stream_manager.streams))
|
||||
self.assertEqual(stream.sd_hash, list(self.stream_manager.streams)[0].sd_hash)
|
||||
self.assertEqual('stopped', list(self.stream_manager.streams)[0].status)
|
||||
|
||||
sd_blob = self.client_blob_manager.get_blob(stream.sd_hash)
|
||||
self.assertTrue(sd_blob.file_exists)
|
||||
self.assertTrue(sd_blob.get_is_verified())
|
||||
|
||||
def test_download_then_recover_old_sort_stream_on_startup(self):
|
||||
return self.test_download_then_recover_stream_on_startup(old_sort=True)
|
||||
|
|
Loading…
Add table
Reference in a new issue