recover streams with missing sd blobs, handle previous sd blob bugs

-test download and recover stream with old key sorting
This commit is contained in:
Jack Robison 2019-02-14 18:19:01 -05:00
parent a228d20137
commit dbb6ba6241
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
6 changed files with 208 additions and 37 deletions

View file

@ -98,7 +98,7 @@ class BlobFile:
t.add_done_callback(lambda *_: self.finished_writing.set())
return
if isinstance(error, (InvalidBlobHashError, InvalidDataError)):
log.error("writer error downloading %s: %s", self.blob_hash[:8], str(error))
log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error))
elif not isinstance(error, (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError)):
log.exception("something else")
raise error

View file

@ -1,3 +1,4 @@
import os
import logging
import sqlite3
import typing
@ -500,6 +501,31 @@ class SQLiteStorage(SQLiteMixin):
stream_hash
))
async def recover_streams(self, descriptors_and_sds: typing.List[typing.Tuple['StreamDescriptor', 'BlobFile']],
                          download_directory: str):
    """
    Re-create the database rows for streams whose sd blobs have been rebuilt.

    For every (descriptor, sd blob) pair the stream is deleted and stored again,
    a file row is stored with status 'stopped', and a content claim row that
    existed before the delete is inserted back. Afterwards every recovered file
    is marked 'stopped' and pointed at the configured download directory.

    The whole operation runs through run_with_foreign_keys_disabled.

    :param descriptors_and_sds: pairs of recovered stream descriptor and its sd blob
    :param download_directory: directory handed to store_file for each recovered stream
    """

    def _recover(transaction: sqlite3.Connection):
        for descriptor, sd_blob in descriptors_and_sds:
            # remember the claim row first, deleting the stream removes it
            content_claim = transaction.execute(
                "select * from content_claim where stream_hash=?", (descriptor.stream_hash, )
            ).fetchone()
            delete_stream(transaction, descriptor)  # this will also delete the content claim
            store_stream(transaction, sd_blob, descriptor)
            store_file(transaction, descriptor.stream_hash, os.path.basename(descriptor.suggested_file_name),
                       download_directory, 0.0, 'stopped')
            if content_claim:
                transaction.execute("insert or ignore into content_claim values (?, ?)", content_claim)
        # the download directory is written hex encoded; status and directory are
        # set in a single pass instead of two separate updates over the same rows
        download_dir = binascii.hexlify(self.conf.download_dir.encode()).decode()
        transaction.executemany(
            "update file set status='stopped', download_directory=? where stream_hash=?",
            [(download_dir, descriptor.stream_hash) for descriptor, _ in descriptors_and_sds]
        )

    await self.db.run_with_foreign_keys_disabled(_recover)
def get_all_stream_hashes(self):
    """Return the result of selecting every stream_hash from the stream table."""
    query = "select stream_hash from stream"
    return self.run_and_return_list(query)

View file

@ -4,6 +4,7 @@ import binascii
import logging
import typing
import asyncio
from collections import OrderedDict
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from lbrynet.blob import MAX_BLOB_SIZE
from lbrynet.blob.blob_info import BlobInfo
@ -82,10 +83,41 @@ class StreamDescriptor:
[blob_info.as_dict() for blob_info in self.blobs]), sort_keys=True
).encode()
async def make_sd_blob(self):
sd_hash = self.calculate_sd_hash()
sd_data = self.as_json()
sd_blob = BlobFile(self.loop, self.blob_dir, sd_hash, len(sd_data))
def old_sort_json(self) -> bytes:
    """
    Serialize the descriptor to JSON bytes using the old, fixed key order.

    Top level keys are written as stream_name, blobs, stream_type, key,
    suggested_file_name, stream_hash; the stream name and suggested file name
    are hex encoded. Blob entries are written as length, blob_num, blob_hash
    (only when the blob has one) and iv, and serialization of blobs stops
    after the first blob that has no blob_hash.
    """
    blob_entries = []
    for blob_info in self.blobs:
        fields = [('length', blob_info.length), ('blob_num', blob_info.blob_num)]
        if blob_info.blob_hash:
            fields.append(('blob_hash', blob_info.blob_hash))
        fields.append(('iv', blob_info.iv))
        blob_entries.append(OrderedDict(fields))
        if not blob_info.blob_hash:
            # a blob without a hash is the last entry that gets written
            break

    hex_stream_name = binascii.hexlify(self.stream_name.encode()).decode()
    hex_suggested_name = binascii.hexlify(self.suggested_file_name.encode()).decode()
    sd_fields = OrderedDict()
    sd_fields['stream_name'] = hex_stream_name
    sd_fields['blobs'] = blob_entries
    sd_fields['stream_type'] = 'lbryfile'
    sd_fields['key'] = self.key
    sd_fields['suggested_file_name'] = hex_suggested_name
    sd_fields['stream_hash'] = self.stream_hash
    return json.dumps(sd_fields).encode()
def calculate_old_sort_sd_hash(self):
    """Return the hex digest of the descriptor serialized with the old key order."""
    legacy_json = self.old_sort_json()
    hasher = get_lbry_hash_obj()
    hasher.update(legacy_json)
    return hasher.hexdigest()
async def make_sd_blob(self, blob_file_obj: typing.Optional[BlobFile] = None,
old_sort: typing.Optional[bool] = False):
sd_hash = self.calculate_sd_hash() if not old_sort else self.calculate_old_sort_sd_hash()
if not old_sort:
sd_data = self.as_json()
else:
sd_data = self.old_sort_json()
sd_blob = blob_file_obj or BlobFile(self.loop, self.blob_dir, sd_hash, len(sd_data))
if blob_file_obj:
blob_file_obj.set_length(len(sd_data))
if not sd_blob.get_is_verified():
writer = sd_blob.open_for_writing()
writer.write(sd_data)
@ -160,8 +192,8 @@ class StreamDescriptor:
@classmethod
async def create_stream(cls, loop: asyncio.BaseEventLoop, blob_dir: str,
file_path: str, key: typing.Optional[bytes] = None,
iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None
) -> 'StreamDescriptor':
iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None,
old_sort: bool = False) -> 'StreamDescriptor':
blobs: typing.List[BlobInfo] = []
@ -180,7 +212,7 @@ class StreamDescriptor:
loop, blob_dir, os.path.basename(file_path), binascii.hexlify(key).decode(), os.path.basename(file_path),
blobs
)
sd_blob = await descriptor.make_sd_blob()
sd_blob = await descriptor.make_sd_blob(old_sort=old_sort)
descriptor.sd_hash = sd_blob.blob_hash
return descriptor
@ -190,3 +222,19 @@ class StreamDescriptor:
def upper_bound_decrypted_length(self) -> int:
    """Return the lower bound of the decrypted length plus AES.block_size // 8."""
    # presumably the cipher block size converted from bits to bytes - confirm against AES
    block_allowance = AES.block_size // 8
    return self.lower_bound_decrypted_length() + block_allowance
@classmethod
async def recover(cls, blob_dir: str, sd_blob: 'BlobFile', stream_hash: str, stream_name: str,
                  suggested_file_name: str, key: str,
                  blobs: typing.List['BlobInfo']) -> typing.Optional['StreamDescriptor']:
    """
    Rebuild a stream descriptor from its parts and write its sd blob again.

    The descriptor is constructed with the hash of the given sd blob. If that
    hash matches the descriptor's current serialization, or failing that its
    old key order serialization, the sd blob is written via make_sd_blob and
    the descriptor is returned. If neither serialization matches, None is
    returned and nothing is written.
    """
    expected_sd_hash = sd_blob.blob_hash
    recovered = cls(asyncio.get_event_loop(), blob_dir, stream_name, key, suggested_file_name,
                    blobs, stream_hash, expected_sd_hash)
    # first check for a normal valid sd, only then check if old field sorting works
    needs_old_sort = recovered.calculate_sd_hash() != expected_sd_hash
    if needs_old_sort and recovered.calculate_old_sort_sd_hash() != expected_sd_hash:
        return None
    await recovered.make_sd_blob(sd_blob, needs_old_sort)
    return recovered

View file

@ -6,6 +6,7 @@ import logging
import random
from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError, \
DownloadDataTimeout, DownloadSDTimeout
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.downloader import StreamDownloader
from lbrynet.stream.managed_stream import ManagedStream
from lbrynet.schema.claim import ClaimDict
@ -16,13 +17,12 @@ if typing.TYPE_CHECKING:
from lbrynet.conf import Config
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.dht.node import Node
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.extras.daemon.storage import SQLiteStorage, StoredStreamClaim
from lbrynet.extras.wallet import LbryWalletManager
from lbrynet.extras.daemon.exchange_rate_manager import ExchangeRateManager
log = logging.getLogger(__name__)
filter_fields = [
'status',
'file_name',
@ -128,33 +128,75 @@ class StreamManager:
self.loop, self.config, self.blob_manager, sd_hash, download_directory, file_name
)
async def add_stream(self, sd_hash: str, file_name: str, download_directory: str, status: str, claim):
sd_blob = self.blob_manager.get_blob(sd_hash)
if sd_blob.get_is_verified():
try:
descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash)
except InvalidStreamDescriptorError as err:
log.warning("Failed to start stream for sd %s - %s", sd_hash, str(err))
return
async def recover_streams(self, file_infos: typing.List[typing.Dict]):
to_restore = []
downloader = self.make_downloader(descriptor.sd_hash, download_directory, file_name)
stream = ManagedStream(
self.loop, self.blob_manager, descriptor, download_directory, file_name, downloader, status, claim
async def recover_stream(sd_hash: str, stream_hash: str, stream_name: str,
suggested_file_name: str, key: str) -> typing.Optional[StreamDescriptor]:
sd_blob = self.blob_manager.get_blob(sd_hash)
blobs = await self.storage.get_blobs_for_stream(stream_hash)
descriptor = await StreamDescriptor.recover(
self.blob_manager.blob_dir, sd_blob, stream_hash, stream_name, suggested_file_name, key, blobs
)
self.streams.add(stream)
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
if not descriptor:
return
to_restore.append((descriptor, sd_blob))
await asyncio.gather(*[
recover_stream(
file_info['sd_hash'], file_info['stream_hash'], binascii.unhexlify(file_info['stream_name']).decode(),
binascii.unhexlify(file_info['suggested_file_name']).decode(), file_info['key']
) for file_info in file_infos
])
if to_restore:
await self.storage.recover_streams(to_restore, self.config.download_dir)
log.info("Recovered %i/%i attempted streams", len(to_restore), len(file_infos))
async def add_stream(self, sd_hash: str, file_name: str, download_directory: str, status: str,
                     claim: typing.Optional['StoredStreamClaim']):
    """
    Load a stream whose sd blob is verified and register it with the manager.

    Nothing is added when the sd blob is not verified or when reading the
    stream descriptor raises InvalidStreamDescriptorError (which is logged as
    a warning). A downloader is only created for streams whose status is
    ManagedStream.STATUS_RUNNING; otherwise the stream is created without one.
    """
    sd_blob = self.blob_manager.get_blob(sd_hash)
    if not sd_blob.get_is_verified():
        return

    try:
        stream_descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash)
    except InvalidStreamDescriptorError as err:
        log.warning("Failed to start stream for sd %s - %s", sd_hash, str(err))
        return

    stream_downloader = None
    if status == ManagedStream.STATUS_RUNNING:
        stream_downloader = self.make_downloader(stream_descriptor.sd_hash, download_directory, file_name)

    managed_stream = ManagedStream(
        self.loop, self.blob_manager, stream_descriptor, download_directory, file_name, stream_downloader,
        status, claim
    )
    self.streams.add(managed_stream)
    self.storage.content_claim_callbacks[managed_stream.stream_hash] = (
        lambda: self._update_content_claim(managed_stream)
    )
async def load_streams_from_database(self):
log.info("Initializing stream manager from %s", self.storage._db_path)
file_infos = await self.storage.get_all_lbry_files()
log.info("Initializing %i files", len(file_infos))
to_recover = []
for file_info in await self.storage.get_all_lbry_files():
if not self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified():
to_recover.append(file_info)
if to_recover:
log.info("Attempting to recover %i streams", len(to_recover))
await self.recover_streams(to_recover)
to_start = []
for file_info in await self.storage.get_all_lbry_files():
if self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified():
to_start.append(file_info)
log.info("Initializing %i files", len(to_start))
await asyncio.gather(*[
self.add_stream(
file_info['sd_hash'], binascii.unhexlify(file_info['file_name']).decode(),
binascii.unhexlify(file_info['download_directory']).decode(), file_info['status'], file_info['claim']
) for file_info in file_infos
binascii.unhexlify(file_info['download_directory']).decode(), file_info['status'],
file_info['claim']
) for file_info in to_start
])
log.info("Started stream manager with %i files", len(file_infos))
log.info("Started stream manager with %i files", len(self.streams))
async def resume(self):
if self.node:
@ -337,6 +379,8 @@ class StreamManager:
raise ValueError(f"'{sort_by}' is not a valid field to sort by")
if comparison and comparison not in comparison_operators:
raise ValueError(f"'{comparison}' is not a valid comparison")
if 'full_status' in search_by:
del search_by['full_status']
for search in search_by.keys():
if search not in filter_fields:
raise ValueError(f"'{search}' is not a valid search operation")
@ -345,8 +389,6 @@ class StreamManager:
streams = []
for stream in self.streams:
for search, val in search_by.items():
if search == 'full_status':
continue
if comparison_operators[comparison](getattr(stream, search), val):
streams.append(stream)
break

View file

@ -75,3 +75,35 @@ class TestStreamDescriptor(AsyncioTestCase):
async def test_zero_length_blob(self):
    # zero out the length of the second to last blob entry, then run the shared invalid sd check
    second_to_last_blob = self.sd_dict['blobs'][-2]
    second_to_last_blob['length'] = 0
    await self._test_invalid_sd()
class TestRecoverOldStreamDescriptors(AsyncioTestCase):
    async def test_old_key_sort_sd_blob(self):
        # sd blob contents serialized with the old key order, and the hash they were stored under
        sd_bytes = b'{"stream_name": "4f62616d6120446f6e6b65792d322e73746c", "blobs": [{"length": 1153488, "blob_num' \
                   b'": 0, "blob_hash": "9fa32a249ce3f2d4e46b78599800f368b72f2a7f22b81df443c7f6bdbef496bd61b4c0079c7' \
                   b'3d79c8bb9be9a6bf86592", "iv": "0bf348867244019c9e22196339016ea6"}, {"length": 0, "blob_num": 1,' \
                   b' "iv": "9f36abae16955463919b07ed530a3d18"}], "stream_type": "lbryfile", "key": "a03742b87628aa7' \
                   b'228e48f1dcd207e48", "suggested_file_name": "4f62616d6120446f6e6b65792d322e73746c", "stream_hash' \
                   b'": "b43f4b1379780caf60d20aa06ac38fb144df61e514ebfa97537018ba73bce8fe37ae712f473ff0ba0be0eef44e1' \
                   b'60207"}'
        sd_hash = '9313d1807551186126acc3662e74d9de29cede78d4f133349ace846273ef116b9bb86be86c54509eb84840e4b032f6b2'
        stream_hash = 'b43f4b1379780caf60d20aa06ac38fb144df61e514ebfa97537018ba73bce8fe37ae712f473ff0ba0be0eef44e160207'

        event_loop = asyncio.get_event_loop()
        blob_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, blob_dir)
        storage = SQLiteStorage(Config(), ":memory:")
        await storage.open()
        blob_manager = BlobFileManager(event_loop, blob_dir, storage)

        # write the old style sd bytes into the blob registered under sd_hash
        sd_blob = blob_manager.get_blob(sd_hash)
        sd_blob.set_length(len(sd_bytes))
        sd_writer = sd_blob.open_for_writing()
        sd_writer.write(sd_bytes)
        await sd_blob.verified.wait()

        descriptor = await StreamDescriptor.from_stream_descriptor_blob(
            event_loop, blob_manager.blob_dir, sd_blob
        )
        self.assertEqual(stream_hash, descriptor.get_stream_hash())
        # only the old key order reproduces the stored sd hash
        self.assertEqual(sd_hash, descriptor.calculate_old_sort_sd_hash())
        self.assertNotEqual(sd_hash, descriptor.calculate_sd_hash())

View file

@ -23,6 +23,8 @@ def get_mock_node(peer):
mock_node = mock.Mock(spec=Node)
mock_node.accumulate_peers = mock_accumulate_peers
mock_node.joined = asyncio.Event()
mock_node.joined.set()
return mock_node
@ -91,15 +93,13 @@ def get_mock_wallet(sd_hash, storage, balance=10.0, fee=None):
class TestStreamManager(BlobExchangeTestBase):
async def asyncSetUp(self):
await super().asyncSetUp()
async def setup_stream_manager(self, balance=10.0, fee=None, old_sort=False):
file_path = os.path.join(self.server_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(os.urandom(20000000))
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
self.sd_hash = descriptor.calculate_sd_hash()
async def setup_stream_manager(self, balance=10.0, fee=None):
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path,
old_sort=old_sort)
self.sd_hash = descriptor.sd_hash
self.mock_wallet, self.uri = get_mock_wallet(self.sd_hash, self.client_storage, balance, fee)
self.stream_manager = StreamManager(self.loop, self.client_config, self.client_blob_manager, self.mock_wallet,
self.client_storage, get_mock_node(self.server_from_client))
@ -169,3 +169,26 @@ class TestStreamManager(BlobExchangeTestBase):
await self.setup_stream_manager(1000000.0, fee)
with self.assertRaises(KeyFeeAboveMaxAllowed):
await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
async def test_download_then_recover_stream_on_startup(self, old_sort=False):
    await self.setup_stream_manager(old_sort=old_sort)
    self.assertSetEqual(self.stream_manager.streams, set())
    downloaded = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
    await downloaded.downloader.stream_finished_event.wait()
    self.stream_manager.stop()
    self.client_blob_manager.stop()

    # delete the sd blob and every blob except the last descriptor entry from disk
    blob_dir = self.client_blob_manager.blob_dir
    hashes_to_delete = [downloaded.sd_hash]
    hashes_to_delete.extend(blob_info.blob_hash for blob_info in downloaded.descriptor.blobs[:-1])
    for blob_hash in hashes_to_delete:
        os.remove(os.path.join(blob_dir, blob_hash))

    # restarting should bring the stream back in the stopped state with its sd blob rewritten
    await self.client_blob_manager.setup()
    await self.stream_manager.start()
    self.assertEqual(1, len(self.stream_manager.streams))
    recovered = list(self.stream_manager.streams)[0]
    self.assertEqual(downloaded.sd_hash, recovered.sd_hash)
    self.assertEqual('stopped', recovered.status)

    recovered_sd_blob = self.client_blob_manager.get_blob(downloaded.sd_hash)
    self.assertTrue(recovered_sd_blob.file_exists)
    self.assertTrue(recovered_sd_blob.get_is_verified())
def test_download_then_recover_old_sort_stream_on_startup(self):
    # same scenario as the regular recovery test, with the stream created using old_sort=True
    recovery_test = self.test_download_then_recover_stream_on_startup
    return recovery_test(old_sort=True)