Merge pull request #1919 from lbryio/sync-blobs-on-startup

Sync blobs on startup
Jack Robison 2019-02-15 18:40:33 -05:00 committed by GitHub
commit cfbf0dbe0f
13 changed files with 392 additions and 104 deletions
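The change reconciles the blob directory on disk with the blob table in SQLite every time the blob manager starts: hashes the database marks 'finished' but that no longer exist on disk are reset to 'pending', and only blobs present in both places are loaded into the in-memory completed set. A minimal, self-contained sketch of that reconciliation idea (plain sqlite3, table reduced to blob_hash and status; hypothetical names, not the lbrynet API):

import os
import sqlite3


def sync_completed_blobs(db_path: str, blob_dir: str) -> set:
    # blob files are named by their hash, so the directory listing is the set of blobs on disk
    on_disk = set(os.listdir(blob_dir))
    db = sqlite3.connect(db_path)
    try:
        finished = {row[0] for row in db.execute("select blob_hash from blob where status='finished'")}
        # anything the database thinks is finished but is gone from disk goes back to 'pending'
        db.executemany("update blob set status='pending' where blob_hash=?",
                       [(blob_hash,) for blob_hash in finished - on_disk])
        db.commit()
        # only blobs that are both on disk and marked finished count as completed
        return finished & on_disk
    finally:
        db.close()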


@@ -98,7 +98,7 @@ class BlobFile:
             t.add_done_callback(lambda *_: self.finished_writing.set())
             return
         if isinstance(error, (InvalidBlobHashError, InvalidDataError)):
-            log.error("writer error downloading %s: %s", self.blob_hash[:8], str(error))
+            log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error))
         elif not isinstance(error, (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError)):
             log.exception("something else")
         raise error


@@ -30,13 +30,21 @@ class BlobFileManager:
         self.blobs: typing.Dict[str, BlobFile] = {}

     async def setup(self) -> bool:
-        def initialize_blob_hashes():
-            self.completed_blob_hashes.update(
+        def get_files_in_blob_dir() -> typing.Set[str]:
+            return {
                 item.name for item in os.scandir(self.blob_dir) if is_valid_blobhash(item.name)
-            )
-        await self.loop.run_in_executor(None, initialize_blob_hashes)
+            }
+
+        in_blobfiles_dir = await self.loop.run_in_executor(None, get_files_in_blob_dir)
+        self.completed_blob_hashes.update(await self.storage.sync_missing_blobs(in_blobfiles_dir))
         return True

+    def stop(self):
+        while self.blobs:
+            _, blob = self.blobs.popitem()
+            blob.close()
+        self.completed_blob_hashes.clear()
+
     def get_blob(self, blob_hash, length: typing.Optional[int] = None):
         if blob_hash in self.blobs:
             if length and self.blobs[blob_hash].length is None:
@@ -55,7 +63,7 @@ class BlobFileManager:
             raise Exception("Blob has a length of 0")
         if blob.blob_hash not in self.completed_blob_hashes:
             self.completed_blob_hashes.add(blob.blob_hash)
-        await self.storage.add_completed_blob(blob.blob_hash)
+        await self.storage.add_completed_blob(blob.blob_hash, blob.length)

     def check_completed_blobs(self, blob_hashes: typing.List[str]) -> typing.List[str]:
         """Returns of the blobhashes_to_check, which are valid"""


@@ -423,7 +423,15 @@ class KademliaProtocol(DatagramProtocol):
                 self.ping_queue.enqueue_maybe_ping(peer)
             elif is_good is True:
                 await self.add_peer(peer)
+        except ValueError as err:
+            log.debug("error raised handling %s request from %s:%i - %s(%s)",
+                      request_datagram.method, peer.address, peer.udp_port, str(type(err)),
+                      str(err))
+            await self.send_error(
+                peer,
+                ErrorDatagram(ERROR_TYPE, request_datagram.rpc_id, self.node_id, str(type(err)).encode(),
+                              str(err).encode())
+            )
         except Exception as err:
             log.warning("error raised handling %s request from %s:%i - %s(%s)",
                         request_datagram.method, peer.address, peer.udp_port, str(type(err)),


@@ -321,9 +321,7 @@ class BlobComponent(Component):
         return await self.blob_manager.setup()

     async def stop(self):
-        while self.blob_manager and self.blob_manager.blobs:
-            _, blob = self.blob_manager.blobs.popitem()
-            blob.close()
+        self.blob_manager.stop()

     async def get_status(self):
         count = 0


@@ -1382,7 +1382,7 @@ class Daemon(metaclass=JSONRPCServerType):
             ]
         }
         """
-        sort = sort or 'status'
+        sort = sort or 'rowid'
        comparison = comparison or 'eq'
        return [
            stream.as_dict() for stream in self.stream_manager.get_filtered_streams(
@@ -2023,8 +2023,6 @@ class Daemon(metaclass=JSONRPCServerType):
         if file_path:
             stream = await self.stream_manager.create_stream(file_path)
-            await self.storage.save_published_file(stream.stream_hash, os.path.basename(file_path),
-                                                   os.path.dirname(file_path), 0)
             claim_dict['stream']['source']['source'] = stream.sd_hash
             claim_dict['stream']['source']['contentType'] = guess_media_type(file_path)


@@ -1,3 +1,4 @@
+import os
 import logging
 import sqlite3
 import typing
@@ -133,7 +134,7 @@ def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Dict]:
             signed_claims[claim.channel_claim_id].append(claim)
         files.append(
             {
-                "row_id": rowid,
+                "rowid": rowid,
                 "stream_hash": stream_hash,
                 "file_name": file_name,  # hex
                 "download_directory": download_dir,  # hex
@@ -154,6 +155,35 @@
     return files


+def store_stream(transaction: sqlite3.Connection, sd_blob: 'BlobFile', descriptor: 'StreamDescriptor'):
+    # add the head blob and set it to be announced
+    transaction.execute(
+        "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?), (?, ?, ?, ?, ?, ?, ?)",
+        (
+            sd_blob.blob_hash, sd_blob.length, 0, 1, "pending", 0, 0,
+            descriptor.blobs[0].blob_hash, descriptor.blobs[0].length, 0, 1, "pending", 0, 0
+        )
+    )
+    # add the rest of the blobs with announcement off
+    if len(descriptor.blobs) > 2:
+        transaction.executemany(
+            "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)",
+            [(blob.blob_hash, blob.length, 0, 0, "pending", 0, 0)
+             for blob in descriptor.blobs[1:-1]]
+        )
+    # associate the blobs to the stream
+    transaction.execute("insert or ignore into stream values (?, ?, ?, ?, ?)",
+                        (descriptor.stream_hash, sd_blob.blob_hash, descriptor.key,
+                         binascii.hexlify(descriptor.stream_name.encode()).decode(),
+                         binascii.hexlify(descriptor.suggested_file_name.encode()).decode()))
+    # add the stream
+    transaction.executemany(
+        "insert or ignore into stream_blob values (?, ?, ?, ?)",
+        [(descriptor.stream_hash, blob.blob_hash, blob.blob_num, blob.iv)
+         for blob in descriptor.blobs]
+    )
+
+
 def delete_stream(transaction: sqlite3.Connection, descriptor: 'StreamDescriptor'):
     blob_hashes = [(blob.blob_hash, ) for blob in descriptor.blobs[:-1]]
     blob_hashes.append((descriptor.sd_hash, ))
@@ -164,6 +194,16 @@ def delete_stream(transaction: sqlite3.Connection, descriptor: 'StreamDescriptor'):
     transaction.executemany("delete from blob where blob_hash=?", blob_hashes)


+def store_file(transaction: sqlite3.Connection, stream_hash: str, file_name: str, download_directory: str,
+               data_payment_rate: float, status: str) -> int:
+    transaction.execute(
+        "insert or replace into file values (?, ?, ?, ?, ?)",
+        (stream_hash, binascii.hexlify(file_name.encode()).decode(),
+         binascii.hexlify(download_directory.encode()).decode(), data_payment_rate, status)
+    )
+    return transaction.execute("select rowid from file where stream_hash=?", (stream_hash, )).fetchone()[0]
+
+
 class SQLiteStorage(SQLiteMixin):
     CREATE_TABLES_QUERY = """
         pragma foreign_keys=on;
@@ -255,9 +295,16 @@ class SQLiteStorage(SQLiteMixin):
     # # # # # # # # # blob functions # # # # # # # # #

-    def add_completed_blob(self, blob_hash: str):
-        log.debug("Adding a completed blob. blob_hash=%s", blob_hash)
-        return self.db.execute("update blob set status='finished' where blob.blob_hash=?", (blob_hash, ))
+    def add_completed_blob(self, blob_hash: str, length: int):
+        def _add_blob(transaction: sqlite3.Connection):
+            transaction.execute(
+                "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)",
+                (blob_hash, length, 0, 0, "pending", 0, 0)
+            )
+            transaction.execute(
+                "update blob set status='finished' where blob.blob_hash=?", (blob_hash, )
+            )
+        return self.db.run(_add_blob)

     def get_blob_status(self, blob_hash: str):
         return self.run_and_return_one_or_none(
@@ -351,6 +398,26 @@ class SQLiteStorage(SQLiteMixin):
     def get_all_blob_hashes(self):
         return self.run_and_return_list("select blob_hash from blob")

+    def sync_missing_blobs(self, blob_files: typing.Set[str]) -> typing.Awaitable[typing.Set[str]]:
+        def _sync_blobs(transaction: sqlite3.Connection) -> typing.Set[str]:
+            to_update = [
+                (blob_hash, )
+                for (blob_hash, ) in transaction.execute("select blob_hash from blob where status='finished'")
+                if blob_hash not in blob_files
+            ]
+            transaction.executemany(
+                "update blob set status='pending' where blob_hash=?",
+                to_update
+            )
+            return {
+                blob_hash
+                for blob_hash, in _batched_select(
+                    transaction, "select blob_hash from blob where status='finished' and blob_hash in {}",
+                    list(blob_files)
+                )
+            }
+        return self.db.run(_sync_blobs)
+
     # # # # # # # # # stream functions # # # # # # # # #

     async def stream_exists(self, sd_hash: str) -> bool:
@@ -363,42 +430,20 @@ class SQLiteStorage(SQLiteMixin):
                                                      "s.stream_hash=f.stream_hash and s.sd_hash=?", sd_hash)
         return streams is not None

+    def rowid_for_stream(self, stream_hash: str) -> typing.Awaitable[typing.Optional[int]]:
+        return self.run_and_return_one_or_none(
+            "select rowid from file where stream_hash=?", stream_hash
+        )
+
     def store_stream(self, sd_blob: 'BlobFile', descriptor: 'StreamDescriptor'):
-        def _store_stream(transaction: sqlite3.Connection):
-            # add the head blob and set it to be announced
-            transaction.execute(
-                "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?), (?, ?, ?, ?, ?, ?, ?)",
-                (
-                    sd_blob.blob_hash, sd_blob.length, 0, 1, "pending", 0, 0,
-                    descriptor.blobs[0].blob_hash, descriptor.blobs[0].length, 0, 1, "pending", 0, 0
-                )
-            )
-            # add the rest of the blobs with announcement off
-            if len(descriptor.blobs) > 2:
-                transaction.executemany(
-                    "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)",
-                    [(blob.blob_hash, blob.length, 0, 0, "pending", 0, 0)
-                     for blob in descriptor.blobs[1:-1]]
-                )
-            # associate the blobs to the stream
-            transaction.execute("insert or ignore into stream values (?, ?, ?, ?, ?)",
-                                (descriptor.stream_hash, sd_blob.blob_hash, descriptor.key,
-                                 binascii.hexlify(descriptor.stream_name.encode()).decode(),
-                                 binascii.hexlify(descriptor.suggested_file_name.encode()).decode()))
-            # add the stream
-            transaction.executemany(
-                "insert or ignore into stream_blob values (?, ?, ?, ?)",
-                [(descriptor.stream_hash, blob.blob_hash, blob.blob_num, blob.iv)
-                 for blob in descriptor.blobs]
-            )
-        return self.db.run(_store_stream)
+        return self.db.run(store_stream, sd_blob, descriptor)

-    def get_blobs_for_stream(self, stream_hash, only_completed=False):
+    def get_blobs_for_stream(self, stream_hash, only_completed=False) -> typing.Awaitable[typing.List[BlobInfo]]:
         def _get_blobs_for_stream(transaction):
             crypt_blob_infos = []
             stream_blobs = transaction.execute(
-                "select blob_hash, position, iv from stream_blob where stream_hash=?", (stream_hash, )
+                "select blob_hash, position, iv from stream_blob where stream_hash=? "
+                "order by position asc", (stream_hash, )
             ).fetchall()
             if only_completed:
                 lengths = transaction.execute(
@@ -420,7 +465,8 @@ class SQLiteStorage(SQLiteMixin):
             for blob_hash, position, iv in stream_blobs:
                 blob_length = blob_length_dict.get(blob_hash, 0)
                 crypt_blob_infos.append(BlobInfo(position, blob_length, iv, blob_hash))
-            crypt_blob_infos = sorted(crypt_blob_infos, key=lambda info: info.blob_num)
+                if not blob_hash:
+                    break
             return crypt_blob_infos
         return self.db.run(_get_blobs_for_stream)
@@ -439,24 +485,21 @@ class SQLiteStorage(SQLiteMixin):
     # # # # # # # # # file stuff # # # # # # # # #

-    def save_downloaded_file(self, stream_hash, file_name, download_directory, data_payment_rate):
+    def save_downloaded_file(self, stream_hash, file_name, download_directory,
+                             data_payment_rate) -> typing.Awaitable[int]:
         return self.save_published_file(
             stream_hash, file_name, download_directory, data_payment_rate, status="running"
         )

     def save_published_file(self, stream_hash: str, file_name: str, download_directory: str, data_payment_rate: float,
-                            status="finished"):
-        return self.db.execute(
-            "insert into file values (?, ?, ?, ?, ?)",
-            (stream_hash, binascii.hexlify(file_name.encode()).decode(),
-             binascii.hexlify(download_directory.encode()).decode(), data_payment_rate, status)
-        )
+                            status="finished") -> typing.Awaitable[int]:
+        return self.db.run(store_file, stream_hash, file_name, download_directory, data_payment_rate, status)

     def get_all_lbry_files(self) -> typing.Awaitable[typing.List[typing.Dict]]:
         return self.db.run(get_all_lbry_files)

     def change_file_status(self, stream_hash: str, new_status: str):
-        log.info("update file status %s -> %s", stream_hash, new_status)
+        log.debug("update file status %s -> %s", stream_hash, new_status)
         return self.db.execute("update file set status=? where stream_hash=?", (new_status, stream_hash))

     def change_file_download_dir_and_file_name(self, stream_hash: str, download_dir: str, file_name: str):
@@ -465,6 +508,31 @@ class SQLiteStorage(SQLiteMixin):
             stream_hash
         ))

+    async def recover_streams(self, descriptors_and_sds: typing.List[typing.Tuple['StreamDescriptor', 'BlobFile']],
+                              download_directory: str):
+        def _recover(transaction: sqlite3.Connection):
+            stream_hashes = [d.stream_hash for d, s in descriptors_and_sds]
+            for descriptor, sd_blob in descriptors_and_sds:
+                content_claim = transaction.execute(
+                    "select * from content_claim where stream_hash=?", (descriptor.stream_hash, )
+                ).fetchone()
+                delete_stream(transaction, descriptor)  # this will also delete the content claim
+                store_stream(transaction, sd_blob, descriptor)
+                store_file(transaction, descriptor.stream_hash, os.path.basename(descriptor.suggested_file_name),
+                           download_directory, 0.0, 'stopped')
+                if content_claim:
+                    transaction.execute("insert or ignore into content_claim values (?, ?)", content_claim)
+            transaction.executemany(
+                "update file set status='stopped' where stream_hash=?",
+                [(stream_hash, ) for stream_hash in stream_hashes]
+            )
+            download_dir = binascii.hexlify(self.conf.download_dir.encode()).decode()
+            transaction.executemany(
+                f"update file set download_directory=? where stream_hash=?",
+                [(download_dir, stream_hash) for stream_hash in stream_hashes]
+            )
+        await self.db.run_with_foreign_keys_disabled(_recover)
+
     def get_all_stream_hashes(self):
         return self.run_and_return_list("select stream_hash from stream")
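The new sync_missing_blobs helper above selects the finished blobs in batches via _batched_select, since SQLite caps the number of bound parameters per statement. The diff only shows the call site, so here is a hedged sketch of what a helper like that presumably does; the signature, chunk size, and name are assumptions, not the lbrynet implementation:

import sqlite3
import typing


def batched_select(transaction: sqlite3.Connection, query: str,
                   parameters: typing.List, batch_size: int = 900) -> typing.Iterator:
    # replace the "{}" placeholder in the query with a (?, ?, ...) tuple per chunk,
    # keeping each statement under SQLite's bound-parameter limit
    for start in range(0, len(parameters), batch_size):
        batch = parameters[start:start + batch_size]
        placeholders = "(" + ", ".join("?" for _ in batch) + ")"
        yield from transaction.execute(query.format(placeholders), batch)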


@@ -4,6 +4,7 @@ import binascii
 import logging
 import typing
 import asyncio
+from collections import OrderedDict
 from cryptography.hazmat.primitives.ciphers.algorithms import AES
 from lbrynet.blob import MAX_BLOB_SIZE
 from lbrynet.blob.blob_info import BlobInfo
@@ -82,10 +83,41 @@ class StreamDescriptor:
             [blob_info.as_dict() for blob_info in self.blobs]), sort_keys=True
         ).encode()

-    async def make_sd_blob(self):
-        sd_hash = self.calculate_sd_hash()
-        sd_data = self.as_json()
-        sd_blob = BlobFile(self.loop, self.blob_dir, sd_hash, len(sd_data))
+    def old_sort_json(self) -> bytes:
+        blobs = []
+        for b in self.blobs:
+            blobs.append(OrderedDict(
+                [('length', b.length), ('blob_num', b.blob_num), ('iv', b.iv)] if not b.blob_hash else
+                [('length', b.length), ('blob_num', b.blob_num), ('blob_hash', b.blob_hash), ('iv', b.iv)]
+            ))
+            if not b.blob_hash:
+                break
+        return json.dumps(
+            OrderedDict([
+                ('stream_name', binascii.hexlify(self.stream_name.encode()).decode()),
+                ('blobs', blobs),
+                ('stream_type', 'lbryfile'),
+                ('key', self.key),
+                ('suggested_file_name', binascii.hexlify(self.suggested_file_name.encode()).decode()),
+                ('stream_hash', self.stream_hash),
+            ])
+        ).encode()
+
+    def calculate_old_sort_sd_hash(self):
+        h = get_lbry_hash_obj()
+        h.update(self.old_sort_json())
+        return h.hexdigest()
+
+    async def make_sd_blob(self, blob_file_obj: typing.Optional[BlobFile] = None,
+                           old_sort: typing.Optional[bool] = False):
+        sd_hash = self.calculate_sd_hash() if not old_sort else self.calculate_old_sort_sd_hash()
+        if not old_sort:
+            sd_data = self.as_json()
+        else:
+            sd_data = self.old_sort_json()
+        sd_blob = blob_file_obj or BlobFile(self.loop, self.blob_dir, sd_hash, len(sd_data))
+        if blob_file_obj:
+            blob_file_obj.set_length(len(sd_data))
         if not sd_blob.get_is_verified():
             writer = sd_blob.open_for_writing()
             writer.write(sd_data)
@@ -160,8 +192,8 @@ class StreamDescriptor:
     @classmethod
     async def create_stream(cls, loop: asyncio.BaseEventLoop, blob_dir: str,
                             file_path: str, key: typing.Optional[bytes] = None,
-                            iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None
-                            ) -> 'StreamDescriptor':
+                            iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None,
+                            old_sort: bool = False) -> 'StreamDescriptor':

         blobs: typing.List[BlobInfo] = []
@@ -180,7 +212,7 @@ class StreamDescriptor:
             loop, blob_dir, os.path.basename(file_path), binascii.hexlify(key).decode(), os.path.basename(file_path),
             blobs
         )
-        sd_blob = await descriptor.make_sd_blob()
+        sd_blob = await descriptor.make_sd_blob(old_sort=old_sort)
         descriptor.sd_hash = sd_blob.blob_hash
         return descriptor
@@ -190,3 +222,19 @@ class StreamDescriptor:

     def upper_bound_decrypted_length(self) -> int:
         return self.lower_bound_decrypted_length() + (AES.block_size // 8)
+
+    @classmethod
+    async def recover(cls, blob_dir: str, sd_blob: 'BlobFile', stream_hash: str, stream_name: str,
+                      suggested_file_name: str, key: str,
+                      blobs: typing.List['BlobInfo']) -> typing.Optional['StreamDescriptor']:
+        descriptor = cls(asyncio.get_event_loop(), blob_dir, stream_name, key, suggested_file_name,
+                         blobs, stream_hash, sd_blob.blob_hash)
+        if descriptor.calculate_sd_hash() == sd_blob.blob_hash:  # first check for a normal valid sd
+            old_sort = False
+        elif descriptor.calculate_old_sort_sd_hash() == sd_blob.blob_hash:  # check if old field sorting works
+            old_sort = True
+        else:
+            return
+        await descriptor.make_sd_blob(sd_blob, old_sort)
+        return descriptor
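The old_sort path exists because older sd blobs were serialized with a fixed, insertion-ordered key layout instead of json.dumps(..., sort_keys=True), so the same descriptor can hash to a different sd hash depending on how the JSON keys were ordered; recover() tries both hashes before rewriting the blob. A small illustration of why ordering matters (SHA-384 here stands in for lbry's blob hash, which get_lbry_hash_obj is assumed to return):

import hashlib
import json

fields = {"stream_name": "00", "stream_type": "lbryfile", "key": "aa"}
sorted_json = json.dumps(fields, sort_keys=True).encode()  # keys alphabetical
insertion_json = json.dumps(fields).encode()               # keys in insertion order
# different byte strings, therefore different hashes
print(hashlib.sha384(sorted_json).hexdigest() == hashlib.sha384(insertion_json).hexdigest())  # False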


@@ -21,11 +21,13 @@ class ManagedStream:
     STATUS_STOPPED = "stopped"
     STATUS_FINISHED = "finished"

-    def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', descriptor: 'StreamDescriptor',
-                 download_directory: str, file_name: str, downloader: typing.Optional[StreamDownloader] = None,
+    def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', rowid: int,
+                 descriptor: 'StreamDescriptor', download_directory: str, file_name: str,
+                 downloader: typing.Optional[StreamDownloader] = None,
                  status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredStreamClaim] = None):
         self.loop = loop
         self.blob_manager = blob_manager
+        self.rowid = rowid
         self.download_directory = download_directory
         self._file_name = file_name
         self.descriptor = descriptor
@@ -168,7 +170,9 @@ class ManagedStream:
         await blob_manager.blob_completed(sd_blob)
         for blob in descriptor.blobs[:-1]:
             await blob_manager.blob_completed(blob_manager.get_blob(blob.blob_hash, blob.length))
-        return cls(loop, blob_manager, descriptor, os.path.dirname(file_path), os.path.basename(file_path),
+        row_id = await blob_manager.storage.save_published_file(descriptor.stream_hash, os.path.basename(file_path),
+                                                                os.path.dirname(file_path), 0)
+        return cls(loop, blob_manager, row_id, descriptor, os.path.dirname(file_path), os.path.basename(file_path),
                    status=cls.STATUS_FINISHED)

     def start_download(self, node: typing.Optional['Node']):


@@ -6,6 +6,7 @@ import logging
 import random
 from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError, \
     DownloadDataTimeout, DownloadSDTimeout
+from lbrynet.stream.descriptor import StreamDescriptor
 from lbrynet.stream.downloader import StreamDownloader
 from lbrynet.stream.managed_stream import ManagedStream
 from lbrynet.schema.claim import ClaimDict
@@ -16,14 +17,14 @@ if typing.TYPE_CHECKING:
     from lbrynet.conf import Config
     from lbrynet.blob.blob_manager import BlobFileManager
     from lbrynet.dht.node import Node
-    from lbrynet.extras.daemon.storage import SQLiteStorage
+    from lbrynet.extras.daemon.storage import SQLiteStorage, StoredStreamClaim
     from lbrynet.extras.wallet import LbryWalletManager
     from lbrynet.extras.daemon.exchange_rate_manager import ExchangeRateManager

 log = logging.getLogger(__name__)

 filter_fields = [
+    'rowid',
     'status',
     'file_name',
     'sd_hash',
@@ -96,7 +97,7 @@ class StreamManager:
             stream.update_status('running')
             stream.start_download(self.node)
             try:
-                await asyncio.wait_for(self.loop.create_task(stream.downloader.got_descriptor.wait()),
+                await asyncio.wait_for(self.loop.create_task(stream.downloader.wrote_bytes_event.wait()),
                                        self.config.download_timeout)
             except asyncio.TimeoutError:
                 await self.stop_stream(stream)
@@ -104,9 +105,12 @@ class StreamManager:
                     self.streams.remove(stream)
                 return False
             file_name = os.path.basename(stream.downloader.output_path)
+            output_dir = os.path.dirname(stream.downloader.output_path)
             await self.storage.change_file_download_dir_and_file_name(
-                stream.stream_hash, self.config.download_dir, file_name
+                stream.stream_hash, output_dir, file_name
             )
+            stream._file_name = file_name
+            stream.download_directory = output_dir
             self.wait_for_stream_finished(stream)
             return True
         return True
@@ -128,33 +132,75 @@ class StreamManager:
             self.loop, self.config, self.blob_manager, sd_hash, download_directory, file_name
         )

-    async def add_stream(self, sd_hash: str, file_name: str, download_directory: str, status: str, claim):
-        sd_blob = self.blob_manager.get_blob(sd_hash)
-        if sd_blob.get_is_verified():
-            try:
-                descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash)
-            except InvalidStreamDescriptorError as err:
-                log.warning("Failed to start stream for sd %s - %s", sd_hash, str(err))
-                return
-
-            downloader = self.make_downloader(descriptor.sd_hash, download_directory, file_name)
-            stream = ManagedStream(
-                self.loop, self.blob_manager, descriptor, download_directory, file_name, downloader, status, claim
-            )
-            self.streams.add(stream)
-            self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
-
-    async def load_streams_from_database(self):
-        log.info("Initializing stream manager from %s", self.storage._db_path)
-        file_infos = await self.storage.get_all_lbry_files()
-        log.info("Initializing %i files", len(file_infos))
-        await asyncio.gather(*[
-            self.add_stream(
-                file_info['sd_hash'], binascii.unhexlify(file_info['file_name']).decode(),
-                binascii.unhexlify(file_info['download_directory']).decode(), file_info['status'], file_info['claim']
-            ) for file_info in file_infos
-        ])
-        log.info("Started stream manager with %i files", len(file_infos))
+    async def recover_streams(self, file_infos: typing.List[typing.Dict]):
+        to_restore = []
+
+        async def recover_stream(sd_hash: str, stream_hash: str, stream_name: str,
+                                 suggested_file_name: str, key: str) -> typing.Optional[StreamDescriptor]:
+            sd_blob = self.blob_manager.get_blob(sd_hash)
+            blobs = await self.storage.get_blobs_for_stream(stream_hash)
+            descriptor = await StreamDescriptor.recover(
+                self.blob_manager.blob_dir, sd_blob, stream_hash, stream_name, suggested_file_name, key, blobs
+            )
+            if not descriptor:
+                return
+            to_restore.append((descriptor, sd_blob))
+
+        await asyncio.gather(*[
+            recover_stream(
+                file_info['sd_hash'], file_info['stream_hash'], binascii.unhexlify(file_info['stream_name']).decode(),
+                binascii.unhexlify(file_info['suggested_file_name']).decode(), file_info['key']
+            ) for file_info in file_infos
+        ])
+
+        if to_restore:
+            await self.storage.recover_streams(to_restore, self.config.download_dir)
+        log.info("Recovered %i/%i attempted streams", len(to_restore), len(file_infos))
+
+    async def add_stream(self, rowid: int, sd_hash: str, file_name: str, download_directory: str, status: str,
+                         claim: typing.Optional['StoredStreamClaim']):
+        sd_blob = self.blob_manager.get_blob(sd_hash)
+        if not sd_blob.get_is_verified():
+            return
+        try:
+            descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash)
+        except InvalidStreamDescriptorError as err:
+            log.warning("Failed to start stream for sd %s - %s", sd_hash, str(err))
+            return
+        if status == ManagedStream.STATUS_RUNNING:
+            downloader = self.make_downloader(descriptor.sd_hash, download_directory, file_name)
+        else:
+            downloader = None
+        stream = ManagedStream(
+            self.loop, self.blob_manager, rowid, descriptor, download_directory, file_name, downloader, status, claim
+        )
+        self.streams.add(stream)
+        self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
+
+    async def load_streams_from_database(self):
+        to_recover = []
+        for file_info in await self.storage.get_all_lbry_files():
+            if not self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified():
+                to_recover.append(file_info)
+
+        if to_recover:
+            log.info("Attempting to recover %i streams", len(to_recover))
+            await self.recover_streams(to_recover)
+
+        to_start = []
+        for file_info in await self.storage.get_all_lbry_files():
+            if self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified():
+                to_start.append(file_info)
+        log.info("Initializing %i files", len(to_start))
+        await asyncio.gather(*[
+            self.add_stream(
+                file_info['rowid'], file_info['sd_hash'], binascii.unhexlify(file_info['file_name']).decode(),
+                binascii.unhexlify(file_info['download_directory']).decode(), file_info['status'],
+                file_info['claim']
+            ) for file_info in to_start
+        ])
+        log.info("Started stream manager with %i files", len(self.streams))

     async def resume(self):
         if self.node:
@@ -264,14 +310,16 @@ class StreamManager:
         if not await self.blob_manager.storage.stream_exists(downloader.sd_hash):
             await self.blob_manager.storage.store_stream(downloader.sd_blob, downloader.descriptor)
         if not await self.blob_manager.storage.file_exists(downloader.sd_hash):
-            await self.blob_manager.storage.save_downloaded_file(
+            rowid = await self.blob_manager.storage.save_downloaded_file(
                 downloader.descriptor.stream_hash, file_name, download_directory,
                 0.0
             )
+        else:
+            rowid = self.blob_manager.storage.rowid_for_stream(downloader.descriptor.stream_hash)
         await self.blob_manager.storage.save_content_claim(
             downloader.descriptor.stream_hash, f"{claim_info['txid']}:{claim_info['nout']}"
         )
-        stream = ManagedStream(self.loop, self.blob_manager, downloader.descriptor, download_directory,
+        stream = ManagedStream(self.loop, self.blob_manager, rowid, downloader.descriptor, download_directory,
                                file_name, downloader, ManagedStream.STATUS_RUNNING)
         stream.set_claim(claim_info, claim)
         self.streams.add(stream)
@@ -337,6 +385,8 @@ class StreamManager:
             raise ValueError(f"'{sort_by}' is not a valid field to sort by")
         if comparison and comparison not in comparison_operators:
             raise ValueError(f"'{comparison}' is not a valid comparison")
+        if 'full_status' in search_by:
+            del search_by['full_status']
         for search in search_by.keys():
             if search not in filter_fields:
                 raise ValueError(f"'{search}' is not a valid search operation")
@@ -345,8 +395,6 @@ class StreamManager:
         streams = []
         for stream in self.streams:
             for search, val in search_by.items():
-                if search == 'full_status':
-                    continue
                 if comparison_operators[comparison](getattr(stream, search), val):
                     streams.append(stream)
                     break


@@ -0,0 +1,54 @@
+import asyncio
+import tempfile
+import shutil
+import os
+from torba.testcase import AsyncioTestCase
+from lbrynet.conf import Config
+from lbrynet.extras.daemon.storage import SQLiteStorage
+from lbrynet.blob.blob_manager import BlobFileManager
+
+
+class TestBlobManager(AsyncioTestCase):
+    async def test_sync_blob_manager_on_startup(self):
+        loop = asyncio.get_event_loop()
+        tmp_dir = tempfile.mkdtemp()
+        self.addCleanup(lambda: shutil.rmtree(tmp_dir))
+
+        storage = SQLiteStorage(Config(), os.path.join(tmp_dir, "lbrynet.sqlite"))
+        blob_manager = BlobFileManager(loop, tmp_dir, storage)
+
+        # add a blob file
+        blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed"
+        blob_bytes = b'1' * ((2 * 2 ** 20) - 1)
+        with open(os.path.join(blob_manager.blob_dir, blob_hash), 'wb') as f:
+            f.write(blob_bytes)
+
+        # it should not have been added automatically on startup
+        await storage.open()
+        await blob_manager.setup()
+        self.assertSetEqual(blob_manager.completed_blob_hashes, set())
+
+        # make sure we can add the blob
+        await blob_manager.blob_completed(blob_manager.get_blob(blob_hash, len(blob_bytes)))
+        self.assertSetEqual(blob_manager.completed_blob_hashes, {blob_hash})
+
+        # stop the blob manager and restart it, make sure the blob is there
+        blob_manager.stop()
+        self.assertSetEqual(blob_manager.completed_blob_hashes, set())
+        await blob_manager.setup()
+        self.assertSetEqual(blob_manager.completed_blob_hashes, {blob_hash})
+
+        # test that the blob is removed upon the next startup after the file being manually deleted
+        blob_manager.stop()
+
+        # manually delete the blob file and restart the blob manager
+        os.remove(os.path.join(blob_manager.blob_dir, blob_hash))
+        await blob_manager.setup()
+        self.assertSetEqual(blob_manager.completed_blob_hashes, set())
+
+        # check that the deleted blob was updated in the database
+        self.assertEqual(
+            'pending', (
+                await storage.run_and_return_one_or_none('select status from blob where blob_hash=?', blob_hash)
+            )
+        )


@@ -78,8 +78,7 @@ class StorageTest(AsyncioTestCase):
         await self.storage.close()

     async def store_fake_blob(self, blob_hash, length=100):
-        await self.storage.add_known_blob(blob_hash, length)
-        await self.storage.add_completed_blob(blob_hash)
+        await self.storage.add_completed_blob(blob_hash, length)

     async def store_fake_stream(self, stream_hash, blobs=None, file_name="fake_file", key="DEADBEEF"):
         blobs = blobs or [BlobInfo(1, 100, "DEADBEEF", random_lbry_hash())]


@@ -75,3 +75,35 @@ class TestStreamDescriptor(AsyncioTestCase):
     async def test_zero_length_blob(self):
         self.sd_dict['blobs'][-2]['length'] = 0
         await self._test_invalid_sd()
+
+
+class TestRecoverOldStreamDescriptors(AsyncioTestCase):
+    async def test_old_key_sort_sd_blob(self):
+        loop = asyncio.get_event_loop()
+        tmp_dir = tempfile.mkdtemp()
+        self.addCleanup(lambda: shutil.rmtree(tmp_dir))
+        storage = SQLiteStorage(Config(), ":memory:")
+        await storage.open()
+        blob_manager = BlobFileManager(loop, tmp_dir, storage)
+
+        sd_bytes = b'{"stream_name": "4f62616d6120446f6e6b65792d322e73746c", "blobs": [{"length": 1153488, "blob_num' \
+                   b'": 0, "blob_hash": "9fa32a249ce3f2d4e46b78599800f368b72f2a7f22b81df443c7f6bdbef496bd61b4c0079c7' \
+                   b'3d79c8bb9be9a6bf86592", "iv": "0bf348867244019c9e22196339016ea6"}, {"length": 0, "blob_num": 1,' \
+                   b' "iv": "9f36abae16955463919b07ed530a3d18"}], "stream_type": "lbryfile", "key": "a03742b87628aa7' \
+                   b'228e48f1dcd207e48", "suggested_file_name": "4f62616d6120446f6e6b65792d322e73746c", "stream_hash' \
+                   b'": "b43f4b1379780caf60d20aa06ac38fb144df61e514ebfa97537018ba73bce8fe37ae712f473ff0ba0be0eef44e1' \
+                   b'60207"}'
+        sd_hash = '9313d1807551186126acc3662e74d9de29cede78d4f133349ace846273ef116b9bb86be86c54509eb84840e4b032f6b2'
+        stream_hash = 'b43f4b1379780caf60d20aa06ac38fb144df61e514ebfa97537018ba73bce8fe37ae712f473ff0ba0be0eef44e160207'
+
+        blob = blob_manager.get_blob(sd_hash)
+        blob.set_length(len(sd_bytes))
+        writer = blob.open_for_writing()
+        writer.write(sd_bytes)
+        await blob.verified.wait()
+        descriptor = await StreamDescriptor.from_stream_descriptor_blob(
+            loop, blob_manager.blob_dir, blob
+        )
+        self.assertEqual(stream_hash, descriptor.get_stream_hash())
+        self.assertEqual(sd_hash, descriptor.calculate_old_sort_sd_hash())
+        self.assertNotEqual(sd_hash, descriptor.calculate_sd_hash())


@@ -23,6 +23,8 @@ def get_mock_node(peer):
     mock_node = mock.Mock(spec=Node)
     mock_node.accumulate_peers = mock_accumulate_peers
+    mock_node.joined = asyncio.Event()
+    mock_node.joined.set()

     return mock_node
@@ -91,15 +93,13 @@ def get_mock_wallet(sd_hash, storage, balance=10.0, fee=None):
 class TestStreamManager(BlobExchangeTestBase):
-    async def asyncSetUp(self):
-        await super().asyncSetUp()
+    async def setup_stream_manager(self, balance=10.0, fee=None, old_sort=False):
         file_path = os.path.join(self.server_dir, "test_file")
         with open(file_path, 'wb') as f:
             f.write(os.urandom(20000000))
-        descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
-        self.sd_hash = descriptor.calculate_sd_hash()
-
-    async def setup_stream_manager(self, balance=10.0, fee=None):
+        descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path,
+                                                          old_sort=old_sort)
+        self.sd_hash = descriptor.sd_hash
         self.mock_wallet, self.uri = get_mock_wallet(self.sd_hash, self.client_storage, balance, fee)
         self.stream_manager = StreamManager(self.loop, self.client_config, self.client_blob_manager, self.mock_wallet,
                                             self.client_storage, get_mock_node(self.server_from_client))
@@ -169,3 +169,26 @@ class TestStreamManager(BlobExchangeTestBase):
         await self.setup_stream_manager(1000000.0, fee)
         with self.assertRaises(KeyFeeAboveMaxAllowed):
             await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
+
+    async def test_download_then_recover_stream_on_startup(self, old_sort=False):
+        await self.setup_stream_manager(old_sort=old_sort)
+        self.assertSetEqual(self.stream_manager.streams, set())
+        stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
+        await stream.downloader.stream_finished_event.wait()
+        self.stream_manager.stop()
+        self.client_blob_manager.stop()
+        os.remove(os.path.join(self.client_blob_manager.blob_dir, stream.sd_hash))
+        for blob in stream.descriptor.blobs[:-1]:
+            os.remove(os.path.join(self.client_blob_manager.blob_dir, blob.blob_hash))
+        await self.client_blob_manager.setup()
+        await self.stream_manager.start()
+        self.assertEqual(1, len(self.stream_manager.streams))
+        self.assertEqual(stream.sd_hash, list(self.stream_manager.streams)[0].sd_hash)
+        self.assertEqual('stopped', list(self.stream_manager.streams)[0].status)
+
+        sd_blob = self.client_blob_manager.get_blob(stream.sd_hash)
+        self.assertTrue(sd_blob.file_exists)
+        self.assertTrue(sd_blob.get_is_verified())
+
+    def test_download_then_recover_old_sort_stream_on_startup(self):
+        return self.test_download_then_recover_stream_on_startup(old_sort=True)