diff --git a/example_daemon_settings.yml b/example_daemon_settings.yml index 79ca48177..d3dd1224f 100644 --- a/example_daemon_settings.yml +++ b/example_daemon_settings.yml @@ -12,12 +12,13 @@ blockchain_name: lbrycrd_main data_dir: /home/lbry/.lbrynet download_directory: /home/lbry/downloads -delete_blobs_on_remove: True +save_blobs: true +save_files: false dht_node_port: 4444 peer_port: 3333 -use_upnp: True +use_upnp: true #components_to_skip: -# - peer_protocol_server # - hash_announcer +# - blob_server # - dht diff --git a/lbrynet/blob/blob_file.py b/lbrynet/blob/blob_file.py index 2206b886e..484e93122 100644 --- a/lbrynet/blob/blob_file.py +++ b/lbrynet/blob/blob_file.py @@ -4,6 +4,8 @@ import asyncio import binascii import logging import typing +import contextlib +from io import BytesIO from cryptography.hazmat.primitives.ciphers import Cipher, modes from cryptography.hazmat.primitives.ciphers.algorithms import AES from cryptography.hazmat.primitives.padding import PKCS7 @@ -21,10 +23,6 @@ log = logging.getLogger(__name__) _hexmatch = re.compile("^[a-f,0-9]+$") -def is_valid_hashcharacter(char: str) -> bool: - return len(char) == 1 and _hexmatch.match(char) - - def is_valid_blobhash(blobhash: str) -> bool: """Checks whether the blobhash is the correct length and contains only valid characters (0-9, a-f) @@ -46,143 +44,74 @@ def encrypt_blob_bytes(key: bytes, iv: bytes, unencrypted: bytes) -> typing.Tupl return encrypted, digest.hexdigest() -class BlobFile: - """ - A chunk of data available on the network which is specified by a hashsum +def decrypt_blob_bytes(data: bytes, length: int, key: bytes, iv: bytes) -> bytes: + if len(data) != length: + raise ValueError("unexpected length") + cipher = Cipher(AES(key), modes.CBC(iv), backend=backend) + unpadder = PKCS7(AES.block_size).unpadder() + decryptor = cipher.decryptor() + return unpadder.update(decryptor.update(data) + decryptor.finalize()) + unpadder.finalize() - This class is used to create blobs on the local filesystem - when we already know the blob hash before hand (i.e., when downloading blobs) - Also can be used for reading from blobs on the local filesystem - """ - def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, blob_hash: str, - length: typing.Optional[int] = None, - blob_completed_callback: typing.Optional[typing.Callable[['BlobFile'], typing.Awaitable]] = None): - if not is_valid_blobhash(blob_hash): - raise InvalidBlobHashError(blob_hash) +class AbstractBlob: + """ + A chunk of data (up to 2MB) available on the network which is specified by a sha384 hash + + This class is non-io specific + """ + __slots__ = [ + 'loop', + 'blob_hash', + 'length', + 'blob_completed_callback', + 'blob_directory', + 'writers', + 'verified', + 'writing', + 'readers' + ] + + def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None, + blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None, + blob_directory: typing.Optional[str] = None): self.loop = loop self.blob_hash = blob_hash self.length = length - self.blob_dir = blob_dir - self.file_path = os.path.join(blob_dir, self.blob_hash) - self.writers: typing.List[HashBlobWriter] = [] - - self.verified: asyncio.Event = asyncio.Event(loop=self.loop) - self.finished_writing = asyncio.Event(loop=loop) - self.blob_write_lock = asyncio.Lock(loop=loop) - if self.file_exists: - length = int(os.stat(os.path.join(blob_dir, blob_hash)).st_size) - self.length = length - self.verified.set() - 
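# Hedged sketch (not part of the patch): the encrypt_blob_bytes / decrypt_blob_bytes
# pair above implements AES-CBC with PKCS7 padding via the `cryptography` package.
# The round trip below exercises the same primitives standalone, assuming the
# module-level `backend = default_backend()` that blob_file.py already defines.
import os
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, modes
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.padding import PKCS7

backend = default_backend()

def _encrypt(key: bytes, iv: bytes, plaintext: bytes) -> bytes:
    padder = PKCS7(AES.block_size).padder()
    padded = padder.update(plaintext) + padder.finalize()
    encryptor = Cipher(AES(key), modes.CBC(iv), backend=backend).encryptor()
    return encryptor.update(padded) + encryptor.finalize()

def _decrypt(key: bytes, iv: bytes, ciphertext: bytes) -> bytes:
    decryptor = Cipher(AES(key), modes.CBC(iv), backend=backend).decryptor()
    unpadder = PKCS7(AES.block_size).unpadder()
    return unpadder.update(decryptor.update(ciphertext) + decryptor.finalize()) + unpadder.finalize()

if __name__ == '__main__':
    key, iv = os.urandom(32), os.urandom(16)
    assert _decrypt(key, iv, _encrypt(key, iv, b'hello blob')) == b'hello blob'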
self.finished_writing.set() - self.saved_verified_blob = False self.blob_completed_callback = blob_completed_callback + self.blob_directory = blob_directory + self.writers: typing.Dict[typing.Tuple[typing.Optional[str], typing.Optional[int]], HashBlobWriter] = {} + self.verified: asyncio.Event = asyncio.Event(loop=self.loop) + self.writing: asyncio.Event = asyncio.Event(loop=self.loop) + self.readers: typing.List[typing.BinaryIO] = [] - @property - def file_exists(self): - return os.path.isfile(self.file_path) + if not is_valid_blobhash(blob_hash): + raise InvalidBlobHashError(blob_hash) - def writer_finished(self, writer: HashBlobWriter): - def callback(finished: asyncio.Future): + def __del__(self): + if self.writers or self.readers: + log.warning("%s not closed before being garbage collected", self.blob_hash) + self.close() + + @contextlib.contextmanager + def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]: + raise NotImplementedError() + + @contextlib.contextmanager + def reader_context(self) -> typing.ContextManager[typing.BinaryIO]: + if not self.is_readable(): + raise OSError(f"{str(type(self))} not readable, {len(self.readers)} readers {len(self.writers)} writers") + with self._reader_context() as reader: try: - error = finished.exception() - except Exception as err: - error = err - if writer in self.writers: # remove this download attempt - self.writers.remove(writer) - if not error: # the blob downloaded, cancel all the other download attempts and set the result - while self.writers: - other = self.writers.pop() - other.finished.cancel() - t = self.loop.create_task(self.save_verified_blob(writer, finished.result())) - t.add_done_callback(lambda *_: self.finished_writing.set()) - return - if isinstance(error, (InvalidBlobHashError, InvalidDataError)): - log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error)) - elif not isinstance(error, (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError)): - log.exception("something else") - raise error - return callback + self.readers.append(reader) + yield reader + finally: + if reader in self.readers: + self.readers.remove(reader) - async def save_verified_blob(self, writer, verified_bytes: bytes): - def _save_verified(): - # log.debug(f"write blob file {self.blob_hash[:8]} from {writer.peer.address}") - if not self.saved_verified_blob and not os.path.isfile(self.file_path): - if self.get_length() == len(verified_bytes): - with open(self.file_path, 'wb') as write_handle: - write_handle.write(verified_bytes) - self.saved_verified_blob = True - else: - raise Exception("length mismatch") + def _write_blob(self, blob_bytes: bytes): + raise NotImplementedError() - async with self.blob_write_lock: - if self.verified.is_set(): - return - await self.loop.run_in_executor(None, _save_verified) - if self.blob_completed_callback: - await self.blob_completed_callback(self) - self.verified.set() - - def open_for_writing(self) -> HashBlobWriter: - if self.file_exists: - raise OSError(f"File already exists '{self.file_path}'") - fut = asyncio.Future(loop=self.loop) - writer = HashBlobWriter(self.blob_hash, self.get_length, fut) - self.writers.append(writer) - fut.add_done_callback(self.writer_finished(writer)) - return writer - - async def sendfile(self, writer: asyncio.StreamWriter) -> int: - """ - Read and send the file to the writer and return the number of bytes sent - """ - - with open(self.file_path, 'rb') as handle: - return await self.loop.sendfile(writer.transport, handle, 
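# Illustrative sketch: AbstractBlob.reader_context above tracks every open reader so
# close() can shut them all down. The same bookkeeping, reduced to a standalone class
# (TrackedReaders is a made-up name, not part of the patch):
import contextlib
import typing
from io import BytesIO

class TrackedReaders:
    def __init__(self, payload: bytes):
        self._payload = payload
        self.readers: typing.List[typing.BinaryIO] = []

    @contextlib.contextmanager
    def reader_context(self) -> typing.Iterator[typing.BinaryIO]:
        reader = BytesIO(self._payload)
        try:
            self.readers.append(reader)
            yield reader
        finally:
            if reader in self.readers:
                self.readers.remove(reader)
            reader.close()

tracked = TrackedReaders(b'blob bytes')
with tracked.reader_context() as r:
    assert r.read() == b'blob bytes'
assert not tracked.readers  # the reader is dropped as soon as the context exits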
count=self.get_length()) - - def close(self): - while self.writers: - self.writers.pop().finished.cancel() - - def delete(self): - self.close() - self.saved_verified_blob = False - if os.path.isfile(self.file_path): - os.remove(self.file_path) - self.verified.clear() - self.finished_writing.clear() - self.length = None - - def decrypt(self, key: bytes, iv: bytes) -> bytes: - """ - Decrypt a BlobFile to plaintext bytes - """ - - with open(self.file_path, "rb") as f: - buff = f.read() - if len(buff) != self.length: - raise ValueError("unexpected length") - cipher = Cipher(AES(key), modes.CBC(iv), backend=backend) - unpadder = PKCS7(AES.block_size).unpadder() - decryptor = cipher.decryptor() - return unpadder.update(decryptor.update(buff) + decryptor.finalize()) + unpadder.finalize() - - @classmethod - async def create_from_unencrypted(cls, loop: asyncio.BaseEventLoop, blob_dir: str, key: bytes, - iv: bytes, unencrypted: bytes, blob_num: int) -> BlobInfo: - """ - Create an encrypted BlobFile from plaintext bytes - """ - - blob_bytes, blob_hash = encrypt_blob_bytes(key, iv, unencrypted) - length = len(blob_bytes) - blob = cls(loop, blob_dir, blob_hash, length) - writer = blob.open_for_writing() - writer.write(blob_bytes) - await blob.verified.wait() - return BlobInfo(blob_num, length, binascii.hexlify(iv).decode(), blob_hash) - - def set_length(self, length): + def set_length(self, length) -> None: if self.length is not None and length == self.length: return if self.length is None and 0 <= length <= MAX_BLOB_SIZE: @@ -190,8 +119,217 @@ class BlobFile: return log.warning("Got an invalid length. Previous length: %s, Invalid length: %s", self.length, length) - def get_length(self): + def get_length(self) -> typing.Optional[int]: return self.length - def get_is_verified(self): + def get_is_verified(self) -> bool: return self.verified.is_set() + + def is_readable(self) -> bool: + return self.verified.is_set() + + def is_writeable(self) -> bool: + return not self.writing.is_set() + + def write_blob(self, blob_bytes: bytes): + if not self.is_writeable(): + raise OSError("cannot open blob for writing") + try: + self.writing.set() + self._write_blob(blob_bytes) + finally: + self.writing.clear() + + def close(self): + while self.writers: + peer, writer = self.writers.popitem() + if writer and writer.finished and not writer.finished.done() and not self.loop.is_closed(): + writer.finished.cancel() + while self.readers: + reader = self.readers.pop() + if reader: + reader.close() + + def delete(self): + self.close() + self.verified.clear() + self.length = None + + async def sendfile(self, writer: asyncio.StreamWriter) -> int: + """ + Read and send the file to the writer and return the number of bytes sent + """ + + if not self.is_readable(): + raise OSError('blob files cannot be read') + with self.reader_context() as handle: + return await self.loop.sendfile(writer.transport, handle, count=self.get_length()) + + def decrypt(self, key: bytes, iv: bytes) -> bytes: + """ + Decrypt a BlobFile to plaintext bytes + """ + + with self.reader_context() as reader: + return decrypt_blob_bytes(reader.read(), self.length, key, iv) + + @classmethod + async def create_from_unencrypted( + cls, loop: asyncio.BaseEventLoop, blob_dir: typing.Optional[str], key: bytes, iv: bytes, + unencrypted: bytes, blob_num: int, + blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], None]] = None) -> BlobInfo: + """ + Create an encrypted BlobFile from plaintext bytes + """ + + blob_bytes, blob_hash = 
encrypt_blob_bytes(key, iv, unencrypted) + length = len(blob_bytes) + blob = cls(loop, blob_hash, length, blob_completed_callback, blob_dir) + writer = blob.get_blob_writer() + writer.write(blob_bytes) + await blob.verified.wait() + return BlobInfo(blob_num, length, binascii.hexlify(iv).decode(), blob_hash) + + def save_verified_blob(self, verified_bytes: bytes): + if self.verified.is_set(): + return + if self.is_writeable(): + self._write_blob(verified_bytes) + self.verified.set() + if self.blob_completed_callback: + self.blob_completed_callback(self) + + def get_blob_writer(self, peer_address: typing.Optional[str] = None, + peer_port: typing.Optional[int] = None) -> HashBlobWriter: + if (peer_address, peer_port) in self.writers: + raise OSError(f"attempted to download blob twice from {peer_address}:{peer_port}") + fut = asyncio.Future(loop=self.loop) + writer = HashBlobWriter(self.blob_hash, self.get_length, fut) + self.writers[(peer_address, peer_port)] = writer + + def remove_writer(_): + if (peer_address, peer_port) in self.writers: + del self.writers[(peer_address, peer_port)] + + fut.add_done_callback(remove_writer) + + def writer_finished_callback(finished: asyncio.Future): + try: + err = finished.exception() + if err: + raise err + verified_bytes = finished.result() + while self.writers: + _, other = self.writers.popitem() + if other is not writer: + other.close_handle() + self.save_verified_blob(verified_bytes) + except (InvalidBlobHashError, InvalidDataError) as error: + log.warning("writer error downloading %s: %s", self.blob_hash[:8], str(error)) + except (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError): + pass + + fut.add_done_callback(writer_finished_callback) + return writer + + +class BlobBuffer(AbstractBlob): + """ + An in-memory only blob + """ + def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None, + blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None, + blob_directory: typing.Optional[str] = None): + self._verified_bytes: typing.Optional[BytesIO] = None + super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory) + + @contextlib.contextmanager + def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]: + if not self.is_readable(): + raise OSError("cannot open blob for reading") + try: + yield self._verified_bytes + finally: + if self._verified_bytes: + self._verified_bytes.close() + self._verified_bytes = None + self.verified.clear() + + def _write_blob(self, blob_bytes: bytes): + if self._verified_bytes: + raise OSError("already have bytes for blob") + self._verified_bytes = BytesIO(blob_bytes) + + def delete(self): + if self._verified_bytes: + self._verified_bytes.close() + self._verified_bytes = None + return super().delete() + + def __del__(self): + super().__del__() + if self._verified_bytes: + self.delete() + + +class BlobFile(AbstractBlob): + """ + A blob existing on the local file system + """ + def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None, + blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None, + blob_directory: typing.Optional[str] = None): + super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory) + if not blob_directory or not os.path.isdir(blob_directory): + raise OSError(f"invalid blob directory '{blob_directory}'") + self.file_path = os.path.join(self.blob_directory, self.blob_hash) + 
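# Simplified sketch: get_blob_writer above attaches two done-callbacks to each writer's
# `finished` future, one to drop the writer from self.writers and one to persist the
# verified bytes. The standalone flow below mirrors that ordering; the names are
# illustrative and the error handling is reduced to the cancelled case.
import asyncio

async def main():
    loop = asyncio.get_running_loop()
    finished = loop.create_future()
    events = []

    def remove_writer(_):
        events.append('writer removed')        # stands in for `del self.writers[...]`

    def writer_finished_callback(fut: asyncio.Future):
        try:
            err = fut.exception()
            if err:
                raise err
            events.append(fut.result())        # stands in for save_verified_blob(...)
        except asyncio.CancelledError:
            pass

    finished.add_done_callback(remove_writer)
    finished.add_done_callback(writer_finished_callback)
    finished.set_result(b'verified bytes')
    await asyncio.sleep(0)                     # let the loop run the callbacks
    assert events == ['writer removed', b'verified bytes']

asyncio.run(main())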
if self.file_exists: + file_size = int(os.stat(self.file_path).st_size) + if length and length != file_size: + log.warning("expected %s to be %s bytes, file has %s", self.blob_hash, length, file_size) + self.delete() + else: + self.length = file_size + self.verified.set() + + @property + def file_exists(self): + return os.path.isfile(self.file_path) + + def is_writeable(self) -> bool: + return super().is_writeable() and not os.path.isfile(self.file_path) + + def get_blob_writer(self, peer_address: typing.Optional[str] = None, + peer_port: typing.Optional[str] = None) -> HashBlobWriter: + if self.file_exists: + raise OSError(f"File already exists '{self.file_path}'") + return super().get_blob_writer(peer_address, peer_port) + + @contextlib.contextmanager + def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]: + handle = open(self.file_path, 'rb') + try: + yield handle + finally: + handle.close() + + def _write_blob(self, blob_bytes: bytes): + with open(self.file_path, 'wb') as f: + f.write(blob_bytes) + + def delete(self): + if os.path.isfile(self.file_path): + os.remove(self.file_path) + return super().delete() + + @classmethod + async def create_from_unencrypted( + cls, loop: asyncio.BaseEventLoop, blob_dir: typing.Optional[str], key: bytes, iv: bytes, + unencrypted: bytes, blob_num: int, + blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], + asyncio.Task]] = None) -> BlobInfo: + if not blob_dir or not os.path.isdir(blob_dir): + raise OSError(f"cannot create blob in directory: '{blob_dir}'") + return await super().create_from_unencrypted( + loop, blob_dir, key, iv, unencrypted, blob_num, blob_completed_callback + ) diff --git a/lbrynet/blob/blob_manager.py b/lbrynet/blob/blob_manager.py index 84379b67f..9ae0a0e0e 100644 --- a/lbrynet/blob/blob_manager.py +++ b/lbrynet/blob/blob_manager.py @@ -2,18 +2,19 @@ import os import typing import asyncio import logging -from lbrynet.extras.daemon.storage import SQLiteStorage -from lbrynet.blob.blob_file import BlobFile, is_valid_blobhash +from lbrynet.blob.blob_file import is_valid_blobhash, BlobFile, BlobBuffer, AbstractBlob from lbrynet.stream.descriptor import StreamDescriptor if typing.TYPE_CHECKING: + from lbrynet.conf import Config from lbrynet.dht.protocol.data_store import DictDataStore + from lbrynet.extras.daemon.storage import SQLiteStorage log = logging.getLogger(__name__) -class BlobFileManager: - def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, storage: SQLiteStorage, +class BlobManager: + def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, storage: 'SQLiteStorage', config: 'Config', node_data_store: typing.Optional['DictDataStore'] = None): """ This class stores blobs on the hard disk @@ -27,16 +28,59 @@ class BlobFileManager: self._node_data_store = node_data_store self.completed_blob_hashes: typing.Set[str] = set() if not self._node_data_store\ else self._node_data_store.completed_blobs - self.blobs: typing.Dict[str, BlobFile] = {} + self.blobs: typing.Dict[str, AbstractBlob] = {} + self.config = config + + def _get_blob(self, blob_hash: str, length: typing.Optional[int] = None): + if self.config.save_blobs: + return BlobFile( + self.loop, blob_hash, length, self.blob_completed, self.blob_dir + ) + else: + if is_valid_blobhash(blob_hash) and os.path.isfile(os.path.join(self.blob_dir, blob_hash)): + return BlobFile( + self.loop, blob_hash, length, self.blob_completed, self.blob_dir + ) + return BlobBuffer( + self.loop, blob_hash, length, self.blob_completed, 
self.blob_dir + ) + + def get_blob(self, blob_hash, length: typing.Optional[int] = None): + if blob_hash in self.blobs: + if self.config.save_blobs and isinstance(self.blobs[blob_hash], BlobBuffer): + buffer = self.blobs.pop(blob_hash) + if blob_hash in self.completed_blob_hashes: + self.completed_blob_hashes.remove(blob_hash) + self.blobs[blob_hash] = self._get_blob(blob_hash, length) + if buffer.is_readable(): + with buffer.reader_context() as reader: + self.blobs[blob_hash].write_blob(reader.read()) + if length and self.blobs[blob_hash].length is None: + self.blobs[blob_hash].set_length(length) + else: + self.blobs[blob_hash] = self._get_blob(blob_hash, length) + return self.blobs[blob_hash] + + def is_blob_verified(self, blob_hash: str, length: typing.Optional[int] = None) -> bool: + if not is_valid_blobhash(blob_hash): + raise ValueError(blob_hash) + if blob_hash in self.blobs: + return self.blobs[blob_hash].get_is_verified() + if not os.path.isfile(os.path.join(self.blob_dir, blob_hash)): + return False + return self._get_blob(blob_hash, length).get_is_verified() async def setup(self) -> bool: def get_files_in_blob_dir() -> typing.Set[str]: + if not self.blob_dir: + return set() return { item.name for item in os.scandir(self.blob_dir) if is_valid_blobhash(item.name) } - in_blobfiles_dir = await self.loop.run_in_executor(None, get_files_in_blob_dir) - self.completed_blob_hashes.update(await self.storage.sync_missing_blobs(in_blobfiles_dir)) + to_add = await self.storage.sync_missing_blobs(in_blobfiles_dir) + if to_add: + self.completed_blob_hashes.update(to_add) return True def stop(self): @@ -45,37 +89,31 @@ class BlobFileManager: blob.close() self.completed_blob_hashes.clear() - def get_blob(self, blob_hash, length: typing.Optional[int] = None): - if blob_hash in self.blobs: - if length and self.blobs[blob_hash].length is None: - self.blobs[blob_hash].set_length(length) - else: - self.blobs[blob_hash] = BlobFile(self.loop, self.blob_dir, blob_hash, length, self.blob_completed) - return self.blobs[blob_hash] - def get_stream_descriptor(self, sd_hash): return StreamDescriptor.from_stream_descriptor_blob(self.loop, self.blob_dir, self.get_blob(sd_hash)) - async def blob_completed(self, blob: BlobFile): + def blob_completed(self, blob: AbstractBlob) -> asyncio.Task: if blob.blob_hash is None: raise Exception("Blob hash is None") if not blob.length: raise Exception("Blob has a length of 0") - if blob.blob_hash not in self.completed_blob_hashes: - self.completed_blob_hashes.add(blob.blob_hash) - await self.storage.add_completed_blob(blob.blob_hash, blob.length) + if isinstance(blob, BlobFile): + if blob.blob_hash not in self.completed_blob_hashes: + self.completed_blob_hashes.add(blob.blob_hash) + return self.loop.create_task(self.storage.add_blobs((blob.blob_hash, blob.length), finished=True)) + else: + return self.loop.create_task(self.storage.add_blobs((blob.blob_hash, blob.length), finished=False)) def check_completed_blobs(self, blob_hashes: typing.List[str]) -> typing.List[str]: """Returns of the blobhashes_to_check, which are valid""" - blobs = [self.get_blob(b) for b in blob_hashes] - return [blob.blob_hash for blob in blobs if blob.get_is_verified()] + return [blob_hash for blob_hash in blob_hashes if self.is_blob_verified(blob_hash)] def delete_blob(self, blob_hash: str): if not is_valid_blobhash(blob_hash): raise Exception("invalid blob hash to delete") if blob_hash not in self.blobs: - if os.path.isfile(os.path.join(self.blob_dir, blob_hash)): + if self.blob_dir and 
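# Hedged sketch: the decision _get_blob applies above — use an on-disk blob when
# save_blobs is enabled (or the blob is already on disk), otherwise keep it in memory.
# DiskBlob/MemoryBlob are stand-ins for BlobFile/BlobBuffer and the blob-hash validity
# check is omitted.
import os

class DiskBlob:    # stands in for BlobFile
    pass

class MemoryBlob:  # stands in for BlobBuffer
    pass

def pick_blob_class(save_blobs: bool, blob_dir: str, blob_hash: str):
    if save_blobs or os.path.isfile(os.path.join(blob_dir, blob_hash)):
        return DiskBlob
    return MemoryBlob

assert pick_blob_class(True, '/tmp', 'aa' * 48) is DiskBlob
assert pick_blob_class(False, '/nonexistent', 'aa' * 48) is MemoryBlob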
os.path.isfile(os.path.join(self.blob_dir, blob_hash)): os.remove(os.path.join(self.blob_dir, blob_hash)) else: self.blobs.pop(blob_hash).delete() diff --git a/lbrynet/blob/writer.py b/lbrynet/blob/writer.py index 699c226d8..b4b2902a8 100644 --- a/lbrynet/blob/writer.py +++ b/lbrynet/blob/writer.py @@ -44,16 +44,15 @@ class HashBlobWriter: self._hashsum.update(data) self.len_so_far += len(data) if self.len_so_far > expected_length: - self.close_handle() self.finished.set_exception(InvalidDataError( f'Length so far is greater than the expected length. {self.len_so_far} to {expected_length}' )) + self.close_handle() return self.buffer.write(data) if self.len_so_far == expected_length: blob_hash = self.calculate_blob_hash() if blob_hash != self.expected_blob_hash: - self.close_handle() self.finished.set_exception(InvalidBlobHashError( f"blob hash is {blob_hash} vs expected {self.expected_blob_hash}" )) @@ -62,6 +61,8 @@ class HashBlobWriter: self.close_handle() def close_handle(self): + if not self.finished.done(): + self.finished.cancel() if self.buffer is not None: self.buffer.close() self.buffer = None diff --git a/lbrynet/blob_exchange/client.py b/lbrynet/blob_exchange/client.py index df0805e57..0f58efeb3 100644 --- a/lbrynet/blob_exchange/client.py +++ b/lbrynet/blob_exchange/client.py @@ -5,7 +5,7 @@ import binascii from lbrynet.error import InvalidBlobHashError, InvalidDataError from lbrynet.blob_exchange.serialization import BlobResponse, BlobRequest if typing.TYPE_CHECKING: - from lbrynet.blob.blob_file import BlobFile + from lbrynet.blob.blob_file import AbstractBlob from lbrynet.blob.writer import HashBlobWriter log = logging.getLogger(__name__) @@ -17,10 +17,10 @@ class BlobExchangeClientProtocol(asyncio.Protocol): self.peer_port: typing.Optional[int] = None self.peer_address: typing.Optional[str] = None self.peer_timeout = peer_timeout - self.transport: asyncio.Transport = None + self.transport: typing.Optional[asyncio.Transport] = None - self.writer: 'HashBlobWriter' = None - self.blob: 'BlobFile' = None + self.writer: typing.Optional['HashBlobWriter'] = None + self.blob: typing.Optional['AbstractBlob'] = None self._blob_bytes_received = 0 self._response_fut: asyncio.Future = None @@ -63,8 +63,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol): # write blob bytes if we're writing a blob and have blob bytes to write self._write(response.blob_data) - - def _write(self, data): + def _write(self, data: bytes): if len(data) > (self.blob.get_length() - self._blob_bytes_received): data = data[:(self.blob.get_length() - self._blob_bytes_received)] log.warning("got more than asked from %s:%d, probable sendfile bug", self.peer_address, self.peer_port) @@ -145,17 +144,18 @@ class BlobExchangeClientProtocol(asyncio.Protocol): self.transport = None self.buf = b'' - async def download_blob(self, blob: 'BlobFile') -> typing.Tuple[int, typing.Optional[asyncio.Transport]]: - if blob.get_is_verified() or blob.file_exists or blob.blob_write_lock.locked(): + async def download_blob(self, blob: 'AbstractBlob') -> typing.Tuple[int, typing.Optional[asyncio.Transport]]: + blob_hash = blob.blob_hash + if blob.get_is_verified() or not blob.is_writeable(): return 0, self.transport try: - self.blob, self.writer, self._blob_bytes_received = blob, blob.open_for_writing(), 0 + self._blob_bytes_received = 0 + self.blob, self.writer = blob, blob.get_blob_writer(self.peer_address, self.peer_port) self._response_fut = asyncio.Future(loop=self.loop) return await self._download_blob() except OSError as e: - 
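# Simplified sketch: the writer.py hunk above moves close_handle() after
# set_exception(), and close_handle() now cancels `finished` only when it has not
# already completed. The stand-in below shows why that ordering matters.
import asyncio
from io import BytesIO

class MiniWriter:                               # stands in for HashBlobWriter
    def __init__(self, loop: asyncio.AbstractEventLoop):
        self.finished = loop.create_future()
        self.buffer = BytesIO()

    def close_handle(self):
        if not self.finished.done():
            self.finished.cancel()
        if self.buffer is not None:
            self.buffer.close()
            self.buffer = None

async def main():
    loop = asyncio.get_running_loop()
    failed = MiniWriter(loop)
    failed.finished.set_exception(ValueError('bad data'))   # stands in for InvalidDataError
    failed.close_handle()                                   # already done: not cancelled
    assert isinstance(failed.finished.exception(), ValueError)

    idle = MiniWriter(loop)
    idle.close_handle()                                     # nothing set yet: cancelled
    assert idle.finished.cancelled()

asyncio.run(main())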
log.error("race happened downloading from %s:%i", self.peer_address, self.peer_port) # i'm not sure how to fix this race condition - jack - log.exception(e) + log.exception("race happened downloading %s from %s:%i", blob_hash, self.peer_address, self.peer_port) return self._blob_bytes_received, self.transport except asyncio.TimeoutError: if self._response_fut and not self._response_fut.done(): @@ -177,7 +177,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol): self.close() -async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: str, tcp_port: int, +async def request_blob(loop: asyncio.BaseEventLoop, blob: 'AbstractBlob', address: str, tcp_port: int, peer_connect_timeout: float, blob_download_timeout: float, connected_transport: asyncio.Transport = None)\ -> typing.Tuple[int, typing.Optional[asyncio.Transport]]: @@ -196,7 +196,7 @@ async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: s if not connected_transport: await asyncio.wait_for(loop.create_connection(lambda: protocol, address, tcp_port), peer_connect_timeout, loop=loop) - if blob.get_is_verified() or blob.file_exists: + if blob.get_is_verified() or not blob.is_writeable(): # file exists but not verified means someone is writing right now, give it time, come back later return 0, connected_transport return await protocol.download_blob(blob) diff --git a/lbrynet/blob_exchange/downloader.py b/lbrynet/blob_exchange/downloader.py index a2661e455..24142f955 100644 --- a/lbrynet/blob_exchange/downloader.py +++ b/lbrynet/blob_exchange/downloader.py @@ -1,21 +1,22 @@ import asyncio import typing import logging -from lbrynet.utils import drain_tasks +from lbrynet.utils import drain_tasks, cache_concurrent from lbrynet.blob_exchange.client import request_blob if typing.TYPE_CHECKING: from lbrynet.conf import Config from lbrynet.dht.node import Node from lbrynet.dht.peer import KademliaPeer - from lbrynet.blob.blob_manager import BlobFileManager - from lbrynet.blob.blob_file import BlobFile + from lbrynet.blob.blob_manager import BlobManager + from lbrynet.blob.blob_file import AbstractBlob log = logging.getLogger(__name__) class BlobDownloader: BAN_TIME = 10.0 # fixme: when connection manager gets implemented, move it out from here - def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobFileManager', + + def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager', peer_queue: asyncio.Queue): self.loop = loop self.config = config @@ -27,7 +28,7 @@ class BlobDownloader: self.connections: typing.Dict['KademliaPeer', asyncio.Transport] = {} self.time_since_last_blob = loop.time() - def should_race_continue(self, blob: 'BlobFile'): + def should_race_continue(self, blob: 'AbstractBlob'): if len(self.active_connections) >= self.config.max_connections_per_download: return False # if a peer won 3 or more blob races and is active as a downloader, stop the race so bandwidth improves @@ -36,9 +37,9 @@ class BlobDownloader: # for peer, task in self.active_connections.items(): # if self.scores.get(peer, 0) >= 0 and self.rounds_won.get(peer, 0) >= 3 and not task.done(): # return False - return not (blob.get_is_verified() or blob.file_exists) + return not (blob.get_is_verified() or not blob.is_writeable()) - async def request_blob_from_peer(self, blob: 'BlobFile', peer: 'KademliaPeer'): + async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer'): if blob.get_is_verified(): return self.scores[peer] = 
self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones @@ -61,7 +62,7 @@ class BlobDownloader: rough_speed = (bytes_received / (self.loop.time() - start)) if bytes_received else 0 self.scores[peer] = rough_speed - async def new_peer_or_finished(self, blob: 'BlobFile'): + async def new_peer_or_finished(self, blob: 'AbstractBlob'): async def get_and_re_add_peers(): try: new_peers = await asyncio.wait_for(self.peer_queue.get(), timeout=1.0) @@ -89,7 +90,8 @@ class BlobDownloader: for banned_peer in forgiven: self.ignored.pop(banned_peer) - async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile': + @cache_concurrent + async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'AbstractBlob': blob = self.blob_manager.get_blob(blob_hash, length) if blob.get_is_verified(): return blob @@ -98,7 +100,7 @@ class BlobDownloader: batch: typing.List['KademliaPeer'] = [] while not self.peer_queue.empty(): batch.extend(self.peer_queue.get_nowait()) - batch.sort(key=lambda peer: self.scores.get(peer, 0), reverse=True) + batch.sort(key=lambda p: self.scores.get(p, 0), reverse=True) log.debug( "running, %d peers, %d ignored, %d active", len(batch), len(self.ignored), len(self.active_connections) @@ -113,15 +115,26 @@ class BlobDownloader: await self.new_peer_or_finished(blob) self.cleanup_active() if batch: - self.peer_queue.put_nowait(set(batch).difference(self.ignored)) + to_re_add = list(set(batch).difference(self.ignored)) + if to_re_add: + self.peer_queue.put_nowait(to_re_add) + else: + self.clearbanned() else: self.clearbanned() blob.close() log.debug("downloaded %s", blob_hash[:8]) return blob finally: + re_add = set() while self.active_connections: - self.active_connections.popitem()[1].cancel() + peer, t = self.active_connections.popitem() + t.cancel() + re_add.add(peer) + re_add = re_add.difference(self.ignored) + if re_add: + self.peer_queue.put_nowait(list(re_add)) + blob.close() def close(self): self.scores.clear() @@ -130,8 +143,8 @@ class BlobDownloader: transport.close() -async def download_blob(loop, config: 'Config', blob_manager: 'BlobFileManager', node: 'Node', - blob_hash: str) -> 'BlobFile': +async def download_blob(loop, config: 'Config', blob_manager: 'BlobManager', node: 'Node', + blob_hash: str) -> 'AbstractBlob': search_queue = asyncio.Queue(loop=loop, maxsize=config.max_connections_per_download) search_queue.put_nowait(blob_hash) peer_queue, accumulate_task = node.accumulate_peers(search_queue) diff --git a/lbrynet/blob_exchange/server.py b/lbrynet/blob_exchange/server.py index 2f47718a8..da9bd3bf1 100644 --- a/lbrynet/blob_exchange/server.py +++ b/lbrynet/blob_exchange/server.py @@ -8,13 +8,13 @@ from lbrynet.blob_exchange.serialization import BlobAvailabilityResponse, BlobPr BlobPaymentAddressResponse if typing.TYPE_CHECKING: - from lbrynet.blob.blob_manager import BlobFileManager + from lbrynet.blob.blob_manager import BlobManager log = logging.getLogger(__name__) class BlobServerProtocol(asyncio.Protocol): - def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', lbrycrd_address: str): + def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobManager', lbrycrd_address: str): self.loop = loop self.blob_manager = blob_manager self.server_task: asyncio.Task = None @@ -94,7 +94,7 @@ class BlobServerProtocol(asyncio.Protocol): class BlobServer: - def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', lbrycrd_address: str): + 
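# Illustration only, NOT lbrynet's actual utils.cache_concurrent: the decorator name
# suggests that concurrent download_blob() calls for the same blob share one in-flight
# coroutine instead of racing separately. A minimal version of that idea:
import asyncio
import functools

def cache_concurrent(coro_fn):
    in_flight = {}

    @functools.wraps(coro_fn)
    async def wrapper(*args, **kwargs):
        key = (args, tuple(sorted(kwargs.items())))
        if key not in in_flight:
            task = asyncio.ensure_future(coro_fn(*args, **kwargs))
            task.add_done_callback(lambda _: in_flight.pop(key, None))
            in_flight[key] = task
        return await asyncio.shield(in_flight[key])
    return wrapper

calls = 0

@cache_concurrent
async def fetch(blob_hash: str) -> str:
    global calls
    calls += 1
    await asyncio.sleep(0.01)
    return blob_hash[:8]

async def main():
    results = await asyncio.gather(fetch('aa' * 48), fetch('aa' * 48))
    assert results == ['aa' * 4, 'aa' * 4] and calls == 1

asyncio.run(main())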
def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobManager', lbrycrd_address: str): self.loop = loop self.blob_manager = blob_manager self.server_task: asyncio.Task = None diff --git a/lbrynet/conf.py b/lbrynet/conf.py index 81fe6df71..0e80ed74a 100644 --- a/lbrynet/conf.py +++ b/lbrynet/conf.py @@ -484,6 +484,8 @@ class Config(CLIConfig): node_rpc_timeout = Float("Timeout when making a DHT request", constants.rpc_timeout) # blob announcement and download + save_blobs = Toggle("Save encrypted blob files for hosting, otherwise download blobs to memory only.", True) + announce_head_and_sd_only = Toggle( "Announce only the descriptor and first (rather than all) data blob for a stream to the DHT", True, previous_names=['announce_head_blobs_only'] @@ -537,6 +539,7 @@ class Config(CLIConfig): cache_time = Integer("Time to cache resolved claims", 150) # TODO: use this # daemon + save_files = Toggle("Save downloaded files when calling `get` by default", True) components_to_skip = Strings("components which will be skipped during start-up of daemon", []) share_usage_data = Toggle( "Whether to share usage stats and diagnostic info with LBRY.", True, diff --git a/lbrynet/dht/protocol/iterative_find.py b/lbrynet/dht/protocol/iterative_find.py index a899f210d..adbe202c2 100644 --- a/lbrynet/dht/protocol/iterative_find.py +++ b/lbrynet/dht/protocol/iterative_find.py @@ -5,7 +5,7 @@ from itertools import chain import typing import logging from lbrynet.dht import constants -from lbrynet.dht.error import RemoteException +from lbrynet.dht.error import RemoteException, TransportNotConnected from lbrynet.dht.protocol.distance import Distance from typing import TYPE_CHECKING @@ -169,7 +169,7 @@ class IterativeFinder: log.warning(str(err)) self.active.discard(peer) return - except RemoteException: + except (RemoteException, TransportNotConnected): return return await self._handle_probe_result(peer, response) @@ -215,7 +215,7 @@ class IterativeFinder: await self._search_round() if self.running: self.delayed_calls.append(self.loop.call_later(delay, self._search)) - except (asyncio.CancelledError, StopAsyncIteration): + except (asyncio.CancelledError, StopAsyncIteration, TransportNotConnected): if self.running: self.loop.call_soon(self.aclose) diff --git a/lbrynet/dht/protocol/protocol.py b/lbrynet/dht/protocol/protocol.py index c0bdf6a2d..f1aab10ad 100644 --- a/lbrynet/dht/protocol/protocol.py +++ b/lbrynet/dht/protocol/protocol.py @@ -524,6 +524,10 @@ class KademliaProtocol(DatagramProtocol): response = await asyncio.wait_for(response_fut, self.rpc_timeout) self.peer_manager.report_last_replied(peer.address, peer.udp_port) return response + except asyncio.CancelledError: + if not response_fut.done(): + response_fut.cancel() + raise except (asyncio.TimeoutError, RemoteException): self.peer_manager.report_failure(peer.address, peer.udp_port) if self.peer_manager.peer_is_good(peer) is False: @@ -540,7 +544,7 @@ class KademliaProtocol(DatagramProtocol): async def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]): - if not self.transport: + if not self.transport or self.transport.is_closing(): raise TransportNotConnected() data = message.bencode() diff --git a/lbrynet/error.py b/lbrynet/error.py index 234bb9229..a57252da0 100644 --- a/lbrynet/error.py +++ b/lbrynet/error.py @@ -32,6 +32,12 @@ class DownloadDataTimeout(Exception): self.download = download +class ResolveTimeout(Exception): + def __init__(self, uri): + super().__init__(f'Failed to 
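# Hedged sketch: the protocol.py hunk above adds explicit CancelledError handling around
# the response wait (cancel the pending response future before re-raising) plus a
# transport.is_closing() guard in _send(). The waiting part, reduced to a standalone
# helper with made-up names:
import asyncio

async def wait_for_response(response_fut: asyncio.Future, timeout: float):
    try:
        return await asyncio.wait_for(response_fut, timeout)
    except asyncio.CancelledError:
        if not response_fut.done():
            response_fut.cancel()
        raise

async def main():
    loop = asyncio.get_running_loop()
    response_fut = loop.create_future()
    waiter = asyncio.ensure_future(wait_for_response(response_fut, 5.0))
    await asyncio.sleep(0)
    waiter.cancel()
    try:
        await waiter
    except asyncio.CancelledError:
        pass
    assert response_fut.cancelled()        # no dangling future after cancellation

asyncio.run(main())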
resolve "{uri}" within the timeout') + self.uri = uri + + class RequestCanceledError(Exception): pass diff --git a/lbrynet/extras/daemon/Components.py b/lbrynet/extras/daemon/Components.py index 95a2eade9..fb3a6485f 100644 --- a/lbrynet/extras/daemon/Components.py +++ b/lbrynet/extras/daemon/Components.py @@ -16,7 +16,7 @@ from lbrynet import utils from lbrynet.conf import HEADERS_FILE_SHA256_CHECKSUM from lbrynet.dht.node import Node from lbrynet.dht.blob_announcer import BlobAnnouncer -from lbrynet.blob.blob_manager import BlobFileManager +from lbrynet.blob.blob_manager import BlobManager from lbrynet.blob_exchange.server import BlobServer from lbrynet.stream.stream_manager import StreamManager from lbrynet.extras.daemon.Component import Component @@ -278,10 +278,10 @@ class BlobComponent(Component): def __init__(self, component_manager): super().__init__(component_manager) - self.blob_manager: BlobFileManager = None + self.blob_manager: BlobManager = None @property - def component(self) -> typing.Optional[BlobFileManager]: + def component(self) -> typing.Optional[BlobManager]: return self.blob_manager async def start(self): @@ -291,8 +291,10 @@ class BlobComponent(Component): dht_node: Node = self.component_manager.get_component(DHT_COMPONENT) if dht_node: data_store = dht_node.protocol.data_store - self.blob_manager = BlobFileManager(asyncio.get_event_loop(), os.path.join(self.conf.data_dir, "blobfiles"), - storage, data_store) + blob_dir = os.path.join(self.conf.data_dir, 'blobfiles') + if not os.path.isdir(blob_dir): + os.mkdir(blob_dir) + self.blob_manager = BlobManager(asyncio.get_event_loop(), blob_dir, storage, self.conf, data_store) return await self.blob_manager.setup() async def stop(self): @@ -451,7 +453,7 @@ class PeerProtocolServerComponent(Component): async def start(self): log.info("start blob server") upnp = self.component_manager.get_component(UPNP_COMPONENT) - blob_manager: BlobFileManager = self.component_manager.get_component(BLOB_COMPONENT) + blob_manager: BlobManager = self.component_manager.get_component(BLOB_COMPONENT) wallet: LbryWalletManager = self.component_manager.get_component(WALLET_COMPONENT) peer_port = self.conf.tcp_port address = await wallet.get_unused_address() @@ -485,7 +487,7 @@ class UPnPComponent(Component): while True: if now: await self._maintain_redirects() - await asyncio.sleep(360) + await asyncio.sleep(360, loop=self.component_manager.loop) async def _maintain_redirects(self): # setup the gateway if necessary @@ -528,7 +530,7 @@ class UPnPComponent(Component): self.upnp_redirects.update(upnp_redirects) except (asyncio.TimeoutError, UPnPError): self.upnp = None - return self._maintain_redirects() + return elif self.upnp: # check existing redirects are still active found = set() mappings = await self.upnp.get_redirects() diff --git a/lbrynet/extras/daemon/Daemon.py b/lbrynet/extras/daemon/Daemon.py index e8e65d041..0a48e9dc3 100644 --- a/lbrynet/extras/daemon/Daemon.py +++ b/lbrynet/extras/daemon/Daemon.py @@ -40,7 +40,7 @@ from lbrynet.extras.daemon.comment_client import jsonrpc_batch, jsonrpc_post, rp if typing.TYPE_CHECKING: - from lbrynet.blob.blob_manager import BlobFileManager + from lbrynet.blob.blob_manager import BlobManager from lbrynet.dht.node import Node from lbrynet.extras.daemon.Components import UPnPComponent from lbrynet.extras.daemon.exchange_rate_manager import ExchangeRateManager @@ -48,6 +48,7 @@ if typing.TYPE_CHECKING: from lbrynet.wallet.manager import LbryWalletManager from lbrynet.wallet.ledger import MainNetLedger 
from lbrynet.stream.stream_manager import StreamManager + from lbrynet.stream.managed_stream import ManagedStream log = logging.getLogger(__name__) @@ -272,6 +273,9 @@ class Daemon(metaclass=JSONRPCServerType): app = web.Application() app.router.add_get('/lbryapi', self.handle_old_jsonrpc) app.router.add_post('/lbryapi', self.handle_old_jsonrpc) + app.router.add_get('/get/{claim_name}', self.handle_stream_get_request) + app.router.add_get('/get/{claim_name}/{claim_id}', self.handle_stream_get_request) + app.router.add_get('/stream/{sd_hash}', self.handle_stream_range_request) app.router.add_post('/', self.handle_old_jsonrpc) self.runner = web.AppRunner(app) @@ -296,7 +300,7 @@ class Daemon(metaclass=JSONRPCServerType): return self.component_manager.get_component(EXCHANGE_RATE_MANAGER_COMPONENT) @property - def blob_manager(self) -> typing.Optional['BlobFileManager']: + def blob_manager(self) -> typing.Optional['BlobManager']: return self.component_manager.get_component(BLOB_COMPONENT) @property @@ -452,6 +456,72 @@ class Daemon(metaclass=JSONRPCServerType): content_type='application/json' ) + async def handle_stream_get_request(self, request: web.Request): + name_and_claim_id = request.path.split("/get/")[1] + if "/" not in name_and_claim_id: + uri = f"lbry://{name_and_claim_id}" + else: + name, claim_id = name_and_claim_id.split("/") + uri = f"lbry://{name}#{claim_id}" + stream = await self.jsonrpc_get(uri) + if isinstance(stream, dict): + raise web.HTTPServerError(text=stream['error']) + raise web.HTTPFound(f"/stream/{stream.sd_hash}") + + @staticmethod + def prepare_range_response_headers(get_range: str, stream: 'ManagedStream') -> typing.Tuple[typing.Dict[str, str], + int, int]: + if '=' in get_range: + get_range = get_range.split('=')[1] + start, end = get_range.split('-') + size = 0 + for blob in stream.descriptor.blobs[:-1]: + size += blob.length - 1 + start = int(start) + end = int(end) if end else size - 1 + skip_blobs = start // 2097150 + skip = skip_blobs * 2097151 + start = skip + final_size = end - start + 1 + + headers = { + 'Accept-Ranges': 'bytes', + 'Content-Range': f'bytes {start}-{end}/{size}', + 'Content-Length': str(final_size), + 'Content-Type': stream.mime_type + } + return headers, size, skip_blobs + + async def handle_stream_range_request(self, request: web.Request): + sd_hash = request.path.split("/stream/")[1] + if sd_hash not in self.stream_manager.streams: + return web.HTTPNotFound() + stream = self.stream_manager.streams[sd_hash] + if stream.status == 'stopped': + await self.stream_manager.start_stream(stream) + if stream.delayed_stop: + stream.delayed_stop.cancel() + headers, size, skip_blobs = self.prepare_range_response_headers( + request.headers.get('range', 'bytes=0-'), stream + ) + response = web.StreamResponse( + status=206, + headers=headers + ) + await response.prepare(request) + wrote = 0 + async for blob_info, decrypted in stream.aiter_read_stream(skip_blobs): + log.info("streamed blob %i/%i", blob_info.blob_num + 1, len(stream.descriptor.blobs) - 1) + if (blob_info.blob_num == len(stream.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size): + decrypted += b'\x00' * (size - len(decrypted) - wrote) + await response.write_eof(decrypted) + break + else: + await response.write(decrypted) + wrote += len(decrypted) + response.force_close() + return response + async def _process_rpc_call(self, data): args = data.get('params', {}) @@ -827,24 +897,26 @@ class Daemon(metaclass=JSONRPCServerType): @requires(WALLET_COMPONENT, 
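# Hedged sketch: the byte-range arithmetic used by prepare_range_response_headers above,
# pulled out so it can be checked in isolation. It reuses the patch's own constants
# (2097150 plaintext bytes per blob for the skip count, 2097151 as the seek stride)
# rather than deriving them; the stream size is passed in instead of being summed from
# descriptor.blobs.
def range_math(get_range: str, stream_plaintext_size: int):
    if '=' in get_range:
        get_range = get_range.split('=')[1]
    start, end = get_range.split('-')
    start = int(start)
    end = int(end) if end else stream_plaintext_size - 1
    skip_blobs = start // 2097150          # whole blobs the requested range skips
    skip = skip_blobs * 2097151            # byte offset the response actually starts at
    final_size = end - skip + 1
    return skip, end, final_size

skip, end, final_size = range_math('bytes=0-', 3 * 2097150)
assert (skip, end, final_size) == (0, 3 * 2097150 - 1, 3 * 2097150)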
EXCHANGE_RATE_MANAGER_COMPONENT, BLOB_COMPONENT, DATABASE_COMPONENT, STREAM_MANAGER_COMPONENT, conditions=[WALLET_IS_UNLOCKED]) - async def jsonrpc_get(self, uri, file_name=None, timeout=None): + async def jsonrpc_get(self, uri, file_name=None, timeout=None, save_file=None): """ Download stream from a LBRY name. Usage: - get [ | --file_name=] [ | --timeout=] + get [ | --file_name=] [ | --timeout=] [--save_file] Options: --uri= : (str) uri of the content to download - --file_name= : (str) specified name for the downloaded file + --file_name= : (str) specified name for the downloaded file, overrides the stream file name --timeout= : (int) download timeout in number of seconds + --save_file : (bool) save the file to the downloads directory Returns: {File} """ + save_file = save_file if save_file is not None else self.conf.save_files try: stream = await self.stream_manager.download_stream_from_uri( - uri, self.exchange_rate_manager, file_name, timeout + uri, self.exchange_rate_manager, timeout, file_name, save_file=save_file ) if not stream: raise DownloadSDTimeout(uri) @@ -1534,6 +1606,45 @@ class Daemon(metaclass=JSONRPCServerType): result = True return result + @requires(STREAM_MANAGER_COMPONENT) + async def jsonrpc_file_save(self, file_name=None, download_directory=None, **kwargs): + """ + Output a download to a file + + Usage: + file_save [--file_name=] [--download_directory=] [--sd_hash=] + [--stream_hash=] [--rowid=] [--claim_id=] [--txid=] + [--nout=] [--claim_name=] [--channel_claim_id=] + [--channel_name=] + + Options: + --file_name= : (str) delete by file name in downloads folder + --download_directory= : (str) delete by file name in downloads folder + --sd_hash= : (str) delete by file sd hash + --stream_hash= : (str) delete by file stream hash + --rowid= : (int) delete by file row id + --claim_id= : (str) delete by file claim id + --txid= : (str) delete by file claim txid + --nout= : (int) delete by file claim nout + --claim_name= : (str) delete by file claim name + --channel_claim_id= : (str) delete by file channel claim id + --channel_name= : (str) delete by file channel claim name + + Returns: {File} + """ + + streams = self.stream_manager.get_filtered_streams(**kwargs) + + if len(streams) > 1: + log.warning("There are %i matching files, use narrower filters to select one", len(streams)) + return False + if not streams: + log.warning("There is no file to save") + return False + stream = streams[0] + await stream.save_file(file_name, download_directory) + return stream + CLAIM_DOC = """ List and search all types of claims. 
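# Usage sketch: with the new save_file flag on `get` and the new `file_save` method, a
# caller can stream to memory first and write the file out later. Assumes a daemon
# running on the default localhost:5279 JSON-RPC endpoint; the uri, claim name and
# directory below are examples only.
import json
import urllib.request

def daemon_call(method: str, **params):
    payload = json.dumps({"method": method, "params": params}).encode()
    request = urllib.request.Request("http://localhost:5279/", payload,
                                     headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())

if __name__ == '__main__':
    daemon_call("get", uri="lbry://example", save_file=False)       # download to memory only
    daemon_call("file_save", claim_name="example",
                download_directory="/home/lbry/downloads")          # materialize it on disk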
""" @@ -2210,9 +2321,7 @@ class Daemon(metaclass=JSONRPCServerType): await self.storage.save_claims([self._old_get_temp_claim_info( tx, new_txo, claim_address, claim, name, dewies_to_lbc(amount) )]) - stream_hash = await self.storage.get_stream_hash_for_sd_hash(claim.stream.source.sd_hash) - if stream_hash: - await self.storage.save_content_claim(stream_hash, new_txo.id) + await self.storage.save_content_claim(file_stream.stream_hash, new_txo.id) await self.analytics_manager.send_claim_action('publish') else: await account.ledger.release_tx(tx) @@ -2365,6 +2474,9 @@ class Daemon(metaclass=JSONRPCServerType): file_stream = await self.stream_manager.create_stream(file_path) new_txo.claim.stream.source.sd_hash = file_stream.sd_hash new_txo.script.generate() + stream_hash = file_stream.stream_hash + else: + stream_hash = await self.storage.get_stream_hash_for_sd_hash(old_txo.claim.stream.source.sd_hash) if channel: new_txo.sign(channel) await tx.sign([account]) @@ -2372,9 +2484,7 @@ class Daemon(metaclass=JSONRPCServerType): await self.storage.save_claims([self._old_get_temp_claim_info( tx, new_txo, claim_address, new_txo.claim, new_txo.claim_name, dewies_to_lbc(amount) )]) - stream_hash = await self.storage.get_stream_hash_for_sd_hash(new_txo.claim.stream.source.sd_hash) - if stream_hash: - await self.storage.save_content_claim(stream_hash, new_txo.id) + await self.storage.save_content_claim(stream_hash, new_txo.id) await self.analytics_manager.send_claim_action('publish') else: await account.ledger.release_tx(tx) @@ -2934,9 +3044,9 @@ class Daemon(metaclass=JSONRPCServerType): else: blobs = list(self.blob_manager.completed_blob_hashes) if needed: - blobs = [blob_hash for blob_hash in blobs if not self.blob_manager.get_blob(blob_hash).get_is_verified()] + blobs = [blob_hash for blob_hash in blobs if not self.blob_manager.is_blob_verified(blob_hash)] if finished: - blobs = [blob_hash for blob_hash in blobs if self.blob_manager.get_blob(blob_hash).get_is_verified()] + blobs = [blob_hash for blob_hash in blobs if self.blob_manager.is_blob_verified(blob_hash)] page_size = page_size or len(blobs) page = page or 0 start_index = page * page_size diff --git a/lbrynet/extras/daemon/storage.py b/lbrynet/extras/daemon/storage.py index 0a5c25ff8..28e8271ba 100644 --- a/lbrynet/extras/daemon/storage.py +++ b/lbrynet/extras/daemon/storage.py @@ -54,10 +54,6 @@ class StoredStreamClaim: def nout(self) -> typing.Optional[int]: return None if not self.outpoint else int(self.outpoint.split(":")[1]) - @property - def metadata(self) -> typing.Optional[typing.Dict]: - return None if not self.claim else self.claim.claim_dict['stream']['metadata'] - def as_dict(self) -> typing.Dict: return { "name": self.claim_name, @@ -195,12 +191,16 @@ def delete_stream(transaction: sqlite3.Connection, descriptor: 'StreamDescriptor transaction.executemany("delete from blob where blob_hash=?", blob_hashes) -def store_file(transaction: sqlite3.Connection, stream_hash: str, file_name: str, download_directory: str, - data_payment_rate: float, status: str) -> int: +def store_file(transaction: sqlite3.Connection, stream_hash: str, file_name: typing.Optional[str], + download_directory: typing.Optional[str], data_payment_rate: float, status: str) -> int: + if not file_name and not download_directory: + encoded_file_name, encoded_download_dir = "{stream}", "{stream}" + else: + encoded_file_name = binascii.hexlify(file_name.encode()).decode() + encoded_download_dir = binascii.hexlify(download_directory.encode()).decode() 
transaction.execute( "insert or replace into file values (?, ?, ?, ?, ?)", - (stream_hash, binascii.hexlify(file_name.encode()).decode(), - binascii.hexlify(download_directory.encode()).decode(), data_payment_rate, status) + (stream_hash, encoded_file_name, encoded_download_dir, data_payment_rate, status) ) return transaction.execute("select rowid from file where stream_hash=?", (stream_hash, )).fetchone()[0] @@ -296,27 +296,28 @@ class SQLiteStorage(SQLiteMixin): # # # # # # # # # blob functions # # # # # # # # # - def add_completed_blob(self, blob_hash: str, length: int): - def _add_blob(transaction: sqlite3.Connection): - transaction.execute( + async def add_blobs(self, *blob_hashes_and_lengths: typing.Tuple[str, int], finished=False): + def _add_blobs(transaction: sqlite3.Connection): + transaction.executemany( "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)", - (blob_hash, length, 0, 0, "pending", 0, 0) + [ + (blob_hash, length, 0, 0, "pending" if not finished else "finished", 0, 0) + for blob_hash, length in blob_hashes_and_lengths + ] ) - transaction.execute( - "update blob set status='finished' where blob.blob_hash=?", (blob_hash, ) - ) - return self.db.run(_add_blob) + if finished: + transaction.executemany( + "update blob set status='finished' where blob.blob_hash=?", [ + (blob_hash, ) for blob_hash, _ in blob_hashes_and_lengths + ] + ) + return await self.db.run(_add_blobs) def get_blob_status(self, blob_hash: str): return self.run_and_return_one_or_none( "select status from blob where blob_hash=?", blob_hash ) - def add_known_blob(self, blob_hash: str, length: int): - return self.db.execute( - "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)", (blob_hash, length, 0, 0, "pending", 0, 0) - ) - def should_announce(self, blob_hash: str): return self.run_and_return_one_or_none( "select should_announce from blob where blob_hash=?", blob_hash @@ -419,6 +420,26 @@ class SQLiteStorage(SQLiteMixin): } return self.db.run(_sync_blobs) + def sync_files_to_blobs(self): + def _sync_blobs(transaction: sqlite3.Connection): + transaction.executemany( + "update file set status='stopped' where stream_hash=?", + transaction.execute( + "select distinct sb.stream_hash from stream_blob sb " + "inner join blob b on b.blob_hash=sb.blob_hash and b.status=='pending'" + ).fetchall() + ) + return self.db.run(_sync_blobs) + + def set_files_as_streaming(self, stream_hashes: typing.List[str]): + def _set_streaming(transaction: sqlite3.Connection): + transaction.executemany( + "update file set file_name='{stream}', download_directory='{stream}' where stream_hash=?", + [(stream_hash, ) for stream_hash in stream_hashes] + ) + + return self.db.run(_set_streaming) + # # # # # # # # # stream functions # # # # # # # # # async def stream_exists(self, sd_hash: str) -> bool: @@ -481,6 +502,12 @@ class SQLiteStorage(SQLiteMixin): "select stream_hash from stream where sd_hash = ?", sd_blob_hash ) + def get_stream_info_for_sd_hash(self, sd_blob_hash): + return self.run_and_return_one_or_none( + "select stream_hash, stream_name, suggested_filename, stream_key from stream where sd_hash = ?", + sd_blob_hash + ) + def delete_stream(self, descriptor: 'StreamDescriptor'): return self.db.run_with_foreign_keys_disabled(delete_stream, descriptor) @@ -492,7 +519,8 @@ class SQLiteStorage(SQLiteMixin): stream_hash, file_name, download_directory, data_payment_rate, status="running" ) - def save_published_file(self, stream_hash: str, file_name: str, download_directory: str, data_payment_rate: float, + def 
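# Hedged sketch: add_blobs above replaces the one-row add_completed_blob /
# add_known_blob calls with a single variadic, batched write. The standalone sqlite3
# version below keeps only enough columns to show the insert-or-ignore plus bulk status
# update pattern; lbrynet's real blob table has more columns.
import sqlite3

def add_blobs(conn: sqlite3.Connection, *blob_hashes_and_lengths, finished=False):
    conn.executemany(
        "insert or ignore into blob values (?, ?, ?)",
        [(blob_hash, length, "pending" if not finished else "finished")
         for blob_hash, length in blob_hashes_and_lengths]
    )
    if finished:
        conn.executemany(
            "update blob set status='finished' where blob_hash=?",
            [(blob_hash,) for blob_hash, _ in blob_hashes_and_lengths]
        )

conn = sqlite3.connect(":memory:")
conn.execute("create table blob (blob_hash text primary key, blob_length integer, status text)")
add_blobs(conn, ("aa" * 48, 2097152), ("bb" * 48, 1024), finished=True)
assert {row[0] for row in conn.execute("select status from blob")} == {"finished"}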
save_published_file(self, stream_hash: str, file_name: typing.Optional[str], + download_directory: typing.Optional[str], data_payment_rate: float, status="finished") -> typing.Awaitable[int]: return self.db.run(store_file, stream_hash, file_name, download_directory, data_payment_rate, status) @@ -503,10 +531,15 @@ class SQLiteStorage(SQLiteMixin): log.debug("update file status %s -> %s", stream_hash, new_status) return self.db.execute("update file set status=? where stream_hash=?", (new_status, stream_hash)) - def change_file_download_dir_and_file_name(self, stream_hash: str, download_dir: str, file_name: str): - return self.db.execute("update file set download_directory=?, file_name=? where stream_hash=?", ( - binascii.hexlify(download_dir.encode()).decode(), binascii.hexlify(file_name.encode()).decode(), - stream_hash + async def change_file_download_dir_and_file_name(self, stream_hash: str, download_dir: typing.Optional[str], + file_name: typing.Optional[str]): + if not file_name or not download_dir: + encoded_file_name, encoded_download_dir = "{stream}", "{stream}" + else: + encoded_file_name = binascii.hexlify(file_name.encode()).decode() + encoded_download_dir = binascii.hexlify(download_dir.encode()).decode() + return await self.db.execute("update file set download_directory=?, file_name=? where stream_hash=?", ( + encoded_download_dir, encoded_file_name, stream_hash, )) async def recover_streams(self, descriptors_and_sds: typing.List[typing.Tuple['StreamDescriptor', 'BlobFile']], diff --git a/lbrynet/stream/assembler.py b/lbrynet/stream/assembler.py deleted file mode 100644 index a972b791f..000000000 --- a/lbrynet/stream/assembler.py +++ /dev/null @@ -1,136 +0,0 @@ -import os -import binascii -import logging -import typing -import asyncio -from lbrynet.blob import MAX_BLOB_SIZE -from lbrynet.stream.descriptor import StreamDescriptor -if typing.TYPE_CHECKING: - from lbrynet.blob.blob_manager import BlobFileManager - from lbrynet.blob.blob_info import BlobInfo - from lbrynet.blob.blob_file import BlobFile - - -log = logging.getLogger(__name__) - - -def _get_next_available_file_name(download_directory: str, file_name: str) -> str: - base_name, ext = os.path.splitext(os.path.basename(file_name)) - i = 0 - while os.path.isfile(os.path.join(download_directory, file_name)): - i += 1 - file_name = "%s_%i%s" % (base_name, i, ext) - - return file_name - - -async def get_next_available_file_name(loop: asyncio.BaseEventLoop, download_directory: str, file_name: str) -> str: - return await loop.run_in_executor(None, _get_next_available_file_name, download_directory, file_name) - - -class StreamAssembler: - def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', sd_hash: str, - output_file_name: typing.Optional[str] = None): - self.output_file_name = output_file_name - self.loop = loop - self.blob_manager = blob_manager - self.sd_hash = sd_hash - self.sd_blob: 'BlobFile' = None - self.descriptor: StreamDescriptor = None - self.got_descriptor = asyncio.Event(loop=self.loop) - self.wrote_bytes_event = asyncio.Event(loop=self.loop) - self.stream_finished_event = asyncio.Event(loop=self.loop) - self.output_path = '' - self.stream_handle = None - self.written_bytes: int = 0 - - async def _decrypt_blob(self, blob: 'BlobFile', blob_info: 'BlobInfo', key: str): - if not blob or not self.stream_handle or self.stream_handle.closed: - return False - - def _decrypt_and_write(): - offset = blob_info.blob_num * (MAX_BLOB_SIZE - 1) - self.stream_handle.seek(offset) - _decrypted = 
blob.decrypt( - binascii.unhexlify(key), binascii.unhexlify(blob_info.iv.encode()) - ) - self.stream_handle.write(_decrypted) - self.stream_handle.flush() - self.written_bytes += len(_decrypted) - log.debug("decrypted %s", blob.blob_hash[:8]) - - await self.loop.run_in_executor(None, _decrypt_and_write) - return True - - async def setup(self): - pass - - async def after_got_descriptor(self): - pass - - async def after_finished(self): - pass - - async def assemble_decrypted_stream(self, output_dir: str, output_file_name: typing.Optional[str] = None): - if not os.path.isdir(output_dir): - raise OSError(f"output directory does not exist: '{output_dir}' '{output_file_name}'") - await self.setup() - self.sd_blob = await self.get_blob(self.sd_hash) - self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(self.loop, self.blob_manager.blob_dir, - self.sd_blob) - await self.after_got_descriptor() - self.output_file_name = output_file_name or self.descriptor.suggested_file_name - self.output_file_name = await get_next_available_file_name(self.loop, output_dir, self.output_file_name) - self.output_path = os.path.join(output_dir, self.output_file_name) - if not self.got_descriptor.is_set(): - self.got_descriptor.set() - await self.blob_manager.storage.store_stream( - self.sd_blob, self.descriptor - ) - await self.blob_manager.blob_completed(self.sd_blob) - written_blobs = None - save_tasks = [] - try: - with open(self.output_path, 'wb') as stream_handle: - self.stream_handle = stream_handle - for i, blob_info in enumerate(self.descriptor.blobs[:-1]): - if blob_info.blob_num != i: - log.error("sd blob %s is invalid, cannot assemble stream", self.descriptor.sd_hash) - return - while self.stream_handle and not self.stream_handle.closed: - try: - blob = await self.get_blob(blob_info.blob_hash, blob_info.length) - if blob and blob.length != blob_info.length: - log.warning("Found incomplete, deleting: %s", blob_info.blob_hash) - await self.blob_manager.delete_blobs([blob_info.blob_hash]) - continue - if await self._decrypt_blob(blob, blob_info, self.descriptor.key): - save_tasks.append(asyncio.ensure_future(self.blob_manager.blob_completed(blob))) - written_blobs = i - if not self.wrote_bytes_event.is_set(): - self.wrote_bytes_event.set() - log.debug("written %i/%i", written_blobs, len(self.descriptor.blobs) - 2) - break - except FileNotFoundError: - log.debug("stream assembler stopped") - return - except (ValueError, IOError, OSError): - log.warning("failed to decrypt blob %s for stream %s", blob_info.blob_hash, - self.descriptor.sd_hash) - continue - finally: - if written_blobs == len(self.descriptor.blobs) - 2: - log.debug("finished decrypting and assembling stream") - if save_tasks: - await asyncio.wait(save_tasks) - await self.after_finished() - self.stream_finished_event.set() - else: - log.debug("stream decryption and assembly did not finish (%i/%i blobs are done)", written_blobs or 0, - len(self.descriptor.blobs) - 2) - if self.output_path and os.path.isfile(self.output_path): - log.debug("erasing incomplete file assembly: %s", self.output_path) - os.unlink(self.output_path) - - async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile': - return self.blob_manager.get_blob(blob_hash, length) diff --git a/lbrynet/stream/descriptor.py b/lbrynet/stream/descriptor.py index d73369110..a2f4d1da2 100644 --- a/lbrynet/stream/descriptor.py +++ b/lbrynet/stream/descriptor.py @@ -8,7 +8,7 @@ from collections import OrderedDict from 
cryptography.hazmat.primitives.ciphers.algorithms import AES from lbrynet.blob import MAX_BLOB_SIZE from lbrynet.blob.blob_info import BlobInfo -from lbrynet.blob.blob_file import BlobFile +from lbrynet.blob.blob_file import AbstractBlob, BlobFile from lbrynet.cryptoutils import get_lbry_hash_obj from lbrynet.error import InvalidStreamDescriptorError @@ -108,29 +108,30 @@ class StreamDescriptor: h.update(self.old_sort_json()) return h.hexdigest() - async def make_sd_blob(self, blob_file_obj: typing.Optional[BlobFile] = None, - old_sort: typing.Optional[bool] = False): + async def make_sd_blob(self, blob_file_obj: typing.Optional[AbstractBlob] = None, + old_sort: typing.Optional[bool] = False, + blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], None]] = None): sd_hash = self.calculate_sd_hash() if not old_sort else self.calculate_old_sort_sd_hash() if not old_sort: sd_data = self.as_json() else: sd_data = self.old_sort_json() - sd_blob = blob_file_obj or BlobFile(self.loop, self.blob_dir, sd_hash, len(sd_data)) + sd_blob = blob_file_obj or BlobFile(self.loop, sd_hash, len(sd_data), blob_completed_callback, self.blob_dir) if blob_file_obj: blob_file_obj.set_length(len(sd_data)) if not sd_blob.get_is_verified(): - writer = sd_blob.open_for_writing() + writer = sd_blob.get_blob_writer() writer.write(sd_data) + await sd_blob.verified.wait() sd_blob.close() return sd_blob @classmethod def _from_stream_descriptor_blob(cls, loop: asyncio.BaseEventLoop, blob_dir: str, - blob: BlobFile) -> 'StreamDescriptor': - assert os.path.isfile(blob.file_path) - with open(blob.file_path, 'rb') as f: - json_bytes = f.read() + blob: AbstractBlob) -> 'StreamDescriptor': + with blob.reader_context() as blob_reader: + json_bytes = blob_reader.read() try: decoded = json.loads(json_bytes.decode()) except json.JSONDecodeError: @@ -160,8 +161,10 @@ class StreamDescriptor: @classmethod async def from_stream_descriptor_blob(cls, loop: asyncio.BaseEventLoop, blob_dir: str, - blob: BlobFile) -> 'StreamDescriptor': - return await loop.run_in_executor(None, lambda: cls._from_stream_descriptor_blob(loop, blob_dir, blob)) + blob: AbstractBlob) -> 'StreamDescriptor': + if not blob.is_readable(): + raise InvalidStreamDescriptorError(f"unreadable/missing blob: {blob.blob_hash}") + return await loop.run_in_executor(None, cls._from_stream_descriptor_blob, loop, blob_dir, blob) @staticmethod def get_blob_hashsum(b: typing.Dict): @@ -194,11 +197,12 @@ class StreamDescriptor: return h.hexdigest() @classmethod - async def create_stream(cls, loop: asyncio.BaseEventLoop, blob_dir: str, - file_path: str, key: typing.Optional[bytes] = None, - iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None, - old_sort: bool = False) -> 'StreamDescriptor': - + async def create_stream( + cls, loop: asyncio.BaseEventLoop, blob_dir: str, file_path: str, key: typing.Optional[bytes] = None, + iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None, + old_sort: bool = False, + blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], + asyncio.Task]] = None) -> 'StreamDescriptor': blobs: typing.List[BlobInfo] = [] iv_generator = iv_generator or random_iv_generator() @@ -207,7 +211,7 @@ class StreamDescriptor: for blob_bytes in file_reader(file_path): blob_num += 1 blob_info = await BlobFile.create_from_unencrypted( - loop, blob_dir, key, next(iv_generator), blob_bytes, blob_num + loop, blob_dir, key, next(iv_generator), blob_bytes, blob_num, blob_completed_callback ) 
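# Illustrative sketch: with blob_completed_callback now threaded through create_stream()
# and make_sd_blob(), publishing code can hand every finished blob straight to the blob
# manager as it is written (ManagedStream.create further down in this change does exactly
# this). A minimal wrapper, assuming `blob_manager` is an already-constructed BlobManager:
import asyncio
from lbrynet.stream.descriptor import StreamDescriptor

async def publish_descriptor(blob_manager, file_path: str) -> StreamDescriptor:
    loop = asyncio.get_event_loop()
    return await StreamDescriptor.create_stream(
        loop, blob_manager.blob_dir, file_path,
        blob_completed_callback=blob_manager.blob_completed  # persist each blob as it completes
    )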
blobs.append(blob_info) blobs.append( @@ -216,7 +220,7 @@ class StreamDescriptor: loop, blob_dir, os.path.basename(file_path), binascii.hexlify(key).decode(), os.path.basename(file_path), blobs ) - sd_blob = await descriptor.make_sd_blob(old_sort=old_sort) + sd_blob = await descriptor.make_sd_blob(old_sort=old_sort, blob_completed_callback=blob_completed_callback) descriptor.sd_hash = sd_blob.blob_hash return descriptor @@ -228,7 +232,7 @@ class StreamDescriptor: return self.lower_bound_decrypted_length() + (AES.block_size // 8) @classmethod - async def recover(cls, blob_dir: str, sd_blob: 'BlobFile', stream_hash: str, stream_name: str, + async def recover(cls, blob_dir: str, sd_blob: 'AbstractBlob', stream_hash: str, stream_name: str, suggested_file_name: str, key: str, blobs: typing.List['BlobInfo']) -> typing.Optional['StreamDescriptor']: descriptor = cls(asyncio.get_event_loop(), blob_dir, stream_name, key, suggested_file_name, diff --git a/lbrynet/stream/downloader.py b/lbrynet/stream/downloader.py index 5aeead15c..8c0653d4d 100644 --- a/lbrynet/stream/downloader.py +++ b/lbrynet/stream/downloader.py @@ -1,108 +1,132 @@ -import os import asyncio import typing import logging +import binascii +from lbrynet.error import DownloadSDTimeout from lbrynet.utils import resolve_host -from lbrynet.stream.assembler import StreamAssembler from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.blob_exchange.downloader import BlobDownloader from lbrynet.dht.peer import KademliaPeer if typing.TYPE_CHECKING: from lbrynet.conf import Config from lbrynet.dht.node import Node - from lbrynet.blob.blob_manager import BlobFileManager - from lbrynet.blob.blob_file import BlobFile + from lbrynet.blob.blob_manager import BlobManager + from lbrynet.blob.blob_file import AbstractBlob + from lbrynet.blob.blob_info import BlobInfo log = logging.getLogger(__name__) -def drain_into(a: list, b: list): - while a: - b.append(a.pop()) - - -class StreamDownloader(StreamAssembler): - def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobFileManager', sd_hash: str, - output_dir: typing.Optional[str] = None, output_file_name: typing.Optional[str] = None): - super().__init__(loop, blob_manager, sd_hash, output_file_name) +class StreamDownloader: + def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager', sd_hash: str, + descriptor: typing.Optional[StreamDescriptor] = None): + self.loop = loop self.config = config - self.output_dir = output_dir or self.config.download_dir - self.output_file_name = output_file_name - self.blob_downloader: typing.Optional[BlobDownloader] = None - self.search_queue = asyncio.Queue(loop=loop) - self.peer_queue = asyncio.Queue(loop=loop) - self.accumulate_task: typing.Optional[asyncio.Task] = None - self.descriptor: typing.Optional[StreamDescriptor] + self.blob_manager = blob_manager + self.sd_hash = sd_hash + self.search_queue = asyncio.Queue(loop=loop) # blob hashes to feed into the iterative finder + self.peer_queue = asyncio.Queue(loop=loop) # new peers to try + self.blob_downloader = BlobDownloader(self.loop, self.config, self.blob_manager, self.peer_queue) + self.descriptor: typing.Optional[StreamDescriptor] = descriptor self.node: typing.Optional['Node'] = None - self.assemble_task: typing.Optional[asyncio.Task] = None + self.accumulate_task: typing.Optional[asyncio.Task] = None self.fixed_peers_handle: typing.Optional[asyncio.Handle] = None self.fixed_peers_delay: typing.Optional[float] = None 
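# Illustrative sketch: the refactored StreamDownloader no longer assembles files; callers
# drive it blob by blob with the start()/read_blob()/stop() methods added below
# (ManagedStream does this via aiter_read_stream()). A minimal consumer, assuming
# `downloader` is an already-constructed StreamDownloader and `node` an optional DHT node:
import typing

async def consume_stream(downloader, node=None,
                         handle_bytes: typing.Callable[[bytes], None] = print) -> None:
    await downloader.start(node)                            # find peers, fetch + parse the sd blob
    try:
        for blob_info in downloader.descriptor.blobs[:-1]:  # skip the terminating blob info, as elsewhere in this change
            decrypted = await downloader.read_blob(blob_info)  # download + decrypt one blob
            handle_bytes(decrypted)                             # e.g. write to a file handle
    finally:
        downloader.stop()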
self.added_fixed_peers = False + self.time_to_descriptor: typing.Optional[float] = None + self.time_to_first_bytes: typing.Optional[float] = None - async def setup(self): # start the peer accumulator and initialize the downloader - if self.blob_downloader: - raise Exception("downloader is already set up") - if self.node: - _, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue) - self.blob_downloader = BlobDownloader(self.loop, self.config, self.blob_manager, self.peer_queue) - self.search_queue.put_nowait(self.sd_hash) + async def add_fixed_peers(self): + def _delayed_add_fixed_peers(): + self.added_fixed_peers = True + self.peer_queue.put_nowait([ + KademliaPeer(self.loop, address=address, tcp_port=port + 1) + for address, port in addresses + ]) - async def after_got_descriptor(self): - self.search_queue.put_nowait(self.descriptor.blobs[0].blob_hash) - log.info("added head blob to search") - - async def after_finished(self): - log.info("downloaded stream %s -> %s", self.sd_hash, self.output_path) - await self.blob_manager.storage.change_file_status(self.descriptor.stream_hash, 'finished') - self.blob_downloader.close() - - def stop(self): - if self.accumulate_task: - self.accumulate_task.cancel() - self.accumulate_task = None - if self.assemble_task: - self.assemble_task.cancel() - self.assemble_task = None - if self.fixed_peers_handle: - self.fixed_peers_handle.cancel() - self.fixed_peers_handle = None - self.blob_downloader = None - if self.stream_handle: - if not self.stream_handle.closed: - self.stream_handle.close() - self.stream_handle = None - if not self.stream_finished_event.is_set() and self.wrote_bytes_event.is_set() and \ - self.output_path and os.path.isfile(self.output_path): - os.remove(self.output_path) - - async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile': - return await self.blob_downloader.download_blob(blob_hash, length) - - def add_fixed_peers(self): - async def _add_fixed_peers(): - addresses = [ - (await resolve_host(url, port + 1, proto='tcp'), port) - for url, port in self.config.reflector_servers - ] - - def _delayed_add_fixed_peers(): - self.added_fixed_peers = True - self.peer_queue.put_nowait([ - KademliaPeer(self.loop, address=address, tcp_port=port + 1) - for address, port in addresses - ]) - - self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _delayed_add_fixed_peers) if not self.config.reflector_servers: return + addresses = [ + (await resolve_host(url, port + 1, proto='tcp'), port) + for url, port in self.config.reflector_servers + ] if 'dht' in self.config.components_to_skip or not self.node or not \ len(self.node.protocol.routing_table.get_peers()): self.fixed_peers_delay = 0.0 else: self.fixed_peers_delay = self.config.fixed_peer_delay - self.loop.create_task(_add_fixed_peers()) - def download(self, node: typing.Optional['Node'] = None): - self.node = node - self.assemble_task = self.loop.create_task(self.assemble_decrypted_stream(self.config.download_dir)) - self.add_fixed_peers() + self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _delayed_add_fixed_peers) + + async def load_descriptor(self): + # download or get the sd blob + sd_blob = self.blob_manager.get_blob(self.sd_hash) + if not sd_blob.get_is_verified(): + try: + now = self.loop.time() + sd_blob = await asyncio.wait_for( + self.blob_downloader.download_blob(self.sd_hash), + self.config.blob_download_timeout, loop=self.loop + ) + log.info("downloaded sd blob %s", self.sd_hash) + 
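# Illustrative sketch: add_fixed_peers() above resolves each configured reflector server
# and feeds it into the peer queue as a KademliaPeer whose TCP blob port is the reflector
# port + 1, optionally after a delay when the DHT is usable. Condensed, assuming a
# populated config.reflector_servers:
from lbrynet.utils import resolve_host
from lbrynet.dht.peer import KademliaPeer

async def reflector_blob_peers(loop, config) -> list:
    return [
        KademliaPeer(loop, address=await resolve_host(url, port + 1, proto='tcp'), tcp_port=port + 1)
        for url, port in config.reflector_servers
    ]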
self.time_to_descriptor = self.loop.time() - now + except asyncio.TimeoutError: + raise DownloadSDTimeout(self.sd_hash) + + # parse the descriptor + self.descriptor = await StreamDescriptor.from_stream_descriptor_blob( + self.loop, self.blob_manager.blob_dir, sd_blob + ) + log.info("loaded stream manifest %s", self.sd_hash) + + async def start(self, node: typing.Optional['Node'] = None): + # set up peer accumulation + if node: + self.node = node + _, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue) + await self.add_fixed_peers() + # start searching for peers for the sd hash + self.search_queue.put_nowait(self.sd_hash) + log.info("searching for peers for stream %s", self.sd_hash) + + if not self.descriptor: + await self.load_descriptor() + + # add the head blob to the peer search + self.search_queue.put_nowait(self.descriptor.blobs[0].blob_hash) + log.info("added head blob to peer search for stream %s", self.sd_hash) + + if not await self.blob_manager.storage.stream_exists(self.sd_hash): + await self.blob_manager.storage.store_stream( + self.blob_manager.get_blob(self.sd_hash, length=self.descriptor.length), self.descriptor + ) + + async def download_stream_blob(self, blob_info: 'BlobInfo') -> 'AbstractBlob': + if not filter(lambda blob: blob.blob_hash == blob_info.blob_hash, self.descriptor.blobs[:-1]): + raise ValueError(f"blob {blob_info.blob_hash} is not part of stream with sd hash {self.sd_hash}") + blob = await self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length) + return blob + + def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes: + return blob.decrypt( + binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode()) + ) + + async def read_blob(self, blob_info: 'BlobInfo') -> bytes: + start = None + if self.time_to_first_bytes is None: + start = self.loop.time() + blob = await self.download_stream_blob(blob_info) + decrypted = self.decrypt_blob(blob_info, blob) + if start: + self.time_to_first_bytes = self.loop.time() - start + return decrypted + + def stop(self): + if self.accumulate_task: + self.accumulate_task.cancel() + self.accumulate_task = None + if self.fixed_peers_handle: + self.fixed_peers_handle.cancel() + self.fixed_peers_handle = None + self.blob_downloader.close() diff --git a/lbrynet/stream/managed_stream.py b/lbrynet/stream/managed_stream.py index c5d41d58b..b722ac69b 100644 --- a/lbrynet/stream/managed_stream.py +++ b/lbrynet/stream/managed_stream.py @@ -10,41 +10,75 @@ from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.stream.reflector.client import StreamReflectorClient from lbrynet.extras.daemon.storage import StoredStreamClaim if typing.TYPE_CHECKING: + from lbrynet.conf import Config from lbrynet.schema.claim import Claim - from lbrynet.blob.blob_manager import BlobFileManager + from lbrynet.blob.blob_manager import BlobManager + from lbrynet.blob.blob_info import BlobInfo from lbrynet.dht.node import Node + from lbrynet.extras.daemon.analytics import AnalyticsManager + from lbrynet.wallet.transaction import Transaction log = logging.getLogger(__name__) +def _get_next_available_file_name(download_directory: str, file_name: str) -> str: + base_name, ext = os.path.splitext(os.path.basename(file_name)) + i = 0 + while os.path.isfile(os.path.join(download_directory, file_name)): + i += 1 + file_name = "%s_%i%s" % (base_name, i, ext) + + return file_name + + +async def get_next_available_file_name(loop: asyncio.BaseEventLoop, 
download_directory: str, file_name: str) -> str: + return await loop.run_in_executor(None, _get_next_available_file_name, download_directory, file_name) + + class ManagedStream: STATUS_RUNNING = "running" STATUS_STOPPED = "stopped" STATUS_FINISHED = "finished" - def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', rowid: int, - descriptor: 'StreamDescriptor', download_directory: str, file_name: typing.Optional[str], - downloader: typing.Optional[StreamDownloader] = None, + def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager', + sd_hash: str, download_directory: typing.Optional[str] = None, file_name: typing.Optional[str] = None, status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredStreamClaim] = None, - download_id: typing.Optional[str] = None): + download_id: typing.Optional[str] = None, rowid: typing.Optional[int] = None, + descriptor: typing.Optional[StreamDescriptor] = None, + content_fee: typing.Optional['Transaction'] = None, + analytics_manager: typing.Optional['AnalyticsManager'] = None): self.loop = loop + self.config = config self.blob_manager = blob_manager - self.rowid = rowid + self.sd_hash = sd_hash self.download_directory = download_directory self._file_name = file_name - self.descriptor = descriptor - self.downloader = downloader - self.stream_hash = descriptor.stream_hash - self.stream_claim_info = claim self._status = status - - self.fully_reflected = asyncio.Event(loop=self.loop) - self.tx = None + self.stream_claim_info = claim self.download_id = download_id or binascii.hexlify(generate_id()).decode() + self.rowid = rowid + self.written_bytes = 0 + self.content_fee = content_fee + self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor) + self.analytics_manager = analytics_manager + self.fully_reflected = asyncio.Event(loop=self.loop) + self.file_output_task: typing.Optional[asyncio.Task] = None + self.delayed_stop: typing.Optional[asyncio.Handle] = None + self.saving = asyncio.Event(loop=self.loop) + self.finished_writing = asyncio.Event(loop=self.loop) + self.started_writing = asyncio.Event(loop=self.loop) + + @property + def descriptor(self) -> StreamDescriptor: + return self.downloader.descriptor + + @property + def stream_hash(self) -> str: + return self.descriptor.stream_hash @property def file_name(self) -> typing.Optional[str]: - return self.downloader.output_file_name if self.downloader else self._file_name + return self._file_name or (self.descriptor.suggested_file_name if self.descriptor else None) @property def status(self) -> str: @@ -95,7 +129,7 @@ class ManagedStream: return None if not self.stream_claim_info else self.stream_claim_info.claim_name @property - def metadata(self) ->typing.Optional[typing.Dict]: + def metadata(self) -> typing.Optional[typing.Dict]: return None if not self.stream_claim_info else self.stream_claim_info.claim.stream.to_dict() @property @@ -105,37 +139,35 @@ class ManagedStream: @property def blobs_completed(self) -> int: - return sum([1 if self.blob_manager.get_blob(b.blob_hash).get_is_verified() else 0 + return sum([1 if self.blob_manager.is_blob_verified(b.blob_hash) else 0 for b in self.descriptor.blobs[:-1]]) @property def blobs_in_stream(self) -> int: return len(self.descriptor.blobs) - 1 - @property - def sd_hash(self): - return self.descriptor.sd_hash - @property def blobs_remaining(self) -> int: return self.blobs_in_stream - self.blobs_completed @property def full_path(self) -> 
typing.Optional[str]: - return os.path.join(self.download_directory, os.path.basename(self.file_name)) if self.file_name else None + return os.path.join(self.download_directory, os.path.basename(self.file_name)) \ + if self.file_name and self.download_directory else None @property def output_file_exists(self): return os.path.isfile(self.full_path) if self.full_path else False - def as_dict(self) -> typing.Dict: - full_path = self.full_path if self.output_file_exists else None - mime_type = guess_media_type(os.path.basename(self.descriptor.suggested_file_name))[0] + @property + def mime_type(self): + return guess_media_type(os.path.basename(self.descriptor.suggested_file_name))[0] - if self.downloader and self.downloader.written_bytes: - written_bytes = self.downloader.written_bytes - elif full_path: - written_bytes = os.stat(full_path).st_size + def as_dict(self) -> typing.Dict: + if self.written_bytes: + written_bytes = self.written_bytes + elif self.output_file_exists: + written_bytes = os.stat(self.full_path).st_size else: written_bytes = None return { @@ -143,14 +175,13 @@ class ManagedStream: 'file_name': self.file_name, 'download_directory': self.download_directory, 'points_paid': 0.0, - 'tx': self.tx, 'stopped': not self.running, 'stream_hash': self.stream_hash, 'stream_name': self.descriptor.stream_name, 'suggested_file_name': self.descriptor.suggested_file_name, 'sd_hash': self.descriptor.sd_hash, - 'download_path': full_path, - 'mime_type': mime_type, + 'download_path': self.full_path, + 'mime_type': self.mime_type, 'key': self.descriptor.key, 'total_bytes_lower_bound': self.descriptor.lower_bound_decrypted_length(), 'total_bytes': self.descriptor.upper_bound_decrypted_length(), @@ -167,36 +198,135 @@ class ManagedStream: 'protobuf': self.metadata_protobuf, 'channel_claim_id': self.channel_claim_id, 'channel_name': self.channel_name, - 'claim_name': self.claim_name + 'claim_name': self.claim_name, + 'content_fee': self.content_fee # TODO: this isn't in the database } @classmethod - async def create(cls, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', + async def create(cls, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager', file_path: str, key: typing.Optional[bytes] = None, iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None) -> 'ManagedStream': descriptor = await StreamDescriptor.create_stream( - loop, blob_manager.blob_dir, file_path, key=key, iv_generator=iv_generator + loop, blob_manager.blob_dir, file_path, key=key, iv_generator=iv_generator, + blob_completed_callback=blob_manager.blob_completed ) - sd_blob = blob_manager.get_blob(descriptor.sd_hash) await blob_manager.storage.store_stream( blob_manager.get_blob(descriptor.sd_hash), descriptor ) - await blob_manager.blob_completed(sd_blob) - for blob in descriptor.blobs[:-1]: - await blob_manager.blob_completed(blob_manager.get_blob(blob.blob_hash, blob.length)) row_id = await blob_manager.storage.save_published_file(descriptor.stream_hash, os.path.basename(file_path), os.path.dirname(file_path), 0) - return cls(loop, blob_manager, row_id, descriptor, os.path.dirname(file_path), os.path.basename(file_path), - status=cls.STATUS_FINISHED) + return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path), + os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor) - def start_download(self, node: typing.Optional['Node']): - self.downloader.download(node) - self.update_status(self.STATUS_RUNNING) + async def 
setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True, + file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None): + await self.downloader.start(node) + if not save_file and not file_name: + if not await self.blob_manager.storage.file_exists(self.sd_hash): + self.rowid = await self.blob_manager.storage.save_downloaded_file( + self.stream_hash, None, None, 0.0 + ) + self.download_directory = None + self._file_name = None + self.update_status(ManagedStream.STATUS_RUNNING) + await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING) + self.update_delayed_stop() + elif not os.path.isfile(self.full_path): + await self.save_file(file_name, download_directory) + await self.started_writing.wait() + + def update_delayed_stop(self): + def _delayed_stop(): + log.info("Stopping inactive download for stream %s", self.sd_hash) + self.stop_download() + + if self.delayed_stop: + self.delayed_stop.cancel() + self.delayed_stop = self.loop.call_later(60, _delayed_stop) + + async def aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0) -> typing.AsyncIterator[ + typing.Tuple['BlobInfo', bytes]]: + if start_blob_num >= len(self.descriptor.blobs[:-1]): + raise IndexError(start_blob_num) + for i, blob_info in enumerate(self.descriptor.blobs[start_blob_num:-1]): + assert i + start_blob_num == blob_info.blob_num + if self.delayed_stop: + self.delayed_stop.cancel() + try: + decrypted = await self.downloader.read_blob(blob_info) + yield (blob_info, decrypted) + except asyncio.CancelledError: + if not self.saving.is_set() and not self.finished_writing.is_set(): + self.update_delayed_stop() + raise + + async def _save_file(self, output_path: str): + log.debug("save file %s -> %s", self.sd_hash, output_path) + self.saving.set() + self.finished_writing.clear() + self.started_writing.clear() + try: + with open(output_path, 'wb') as file_write_handle: + async for blob_info, decrypted in self.aiter_read_stream(): + log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1) + file_write_handle.write(decrypted) + file_write_handle.flush() + self.written_bytes += len(decrypted) + if not self.started_writing.is_set(): + self.started_writing.set() + self.update_status(ManagedStream.STATUS_FINISHED) + await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_FINISHED) + if self.analytics_manager: + self.loop.create_task(self.analytics_manager.send_download_finished( + self.download_id, self.claim_name, self.sd_hash + )) + self.finished_writing.set() + except Exception as err: + if os.path.isfile(output_path): + log.info("removing incomplete download %s for %s", output_path, self.sd_hash) + os.remove(output_path) + if not isinstance(err, asyncio.CancelledError): + log.exception("unexpected error encountered writing file for stream %s", self.sd_hash) + raise err + finally: + self.saving.clear() + + async def save_file(self, file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None): + if self.file_output_task and not self.file_output_task.done(): + self.file_output_task.cancel() + if self.delayed_stop: + self.delayed_stop.cancel() + self.delayed_stop = None + self.download_directory = download_directory or self.download_directory or self.config.download_dir + if not self.download_directory: + raise ValueError("no directory to download to") + if not (file_name or self._file_name or self.descriptor.suggested_file_name): + raise 
ValueError("no file name to download to") + if not os.path.isdir(self.download_directory): + log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory) + os.mkdir(self.download_directory) + self._file_name = await get_next_available_file_name( + self.loop, self.download_directory, + file_name or self._file_name or self.descriptor.suggested_file_name + ) + if not await self.blob_manager.storage.file_exists(self.sd_hash): + self.rowid = self.blob_manager.storage.save_downloaded_file( + self.stream_hash, self.file_name, self.download_directory, 0.0 + ) + else: + await self.blob_manager.storage.change_file_download_dir_and_file_name( + self.stream_hash, self.download_directory, self.file_name + ) + self.update_status(ManagedStream.STATUS_RUNNING) + await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING) + self.written_bytes = 0 + self.file_output_task = self.loop.create_task(self._save_file(self.full_path)) def stop_download(self): - if self.downloader: - self.downloader.stop() - self.downloader = None + if self.file_output_task and not self.file_output_task.done(): + self.file_output_task.cancel() + self.file_output_task = None + self.downloader.stop() async def upload_to_reflector(self, host: str, port: int) -> typing.List[str]: sent = [] @@ -212,7 +342,9 @@ class ManagedStream: self.fully_reflected.set() await self.blob_manager.storage.update_reflected_stream(self.sd_hash, f"{host}:{port}") return [] - we_have = [blob_hash for blob_hash in needed if blob_hash in self.blob_manager.completed_blob_hashes] + we_have = [ + blob_hash for blob_hash in needed if blob_hash in self.blob_manager.completed_blob_hashes + ] for blob_hash in we_have: await protocol.send_blob(blob_hash) sent.append(blob_hash) diff --git a/lbrynet/stream/reflector/client.py b/lbrynet/stream/reflector/client.py index 2c12b33d6..bd00ed412 100644 --- a/lbrynet/stream/reflector/client.py +++ b/lbrynet/stream/reflector/client.py @@ -4,7 +4,7 @@ import logging import typing if typing.TYPE_CHECKING: - from lbrynet.blob.blob_manager import BlobFileManager + from lbrynet.blob.blob_manager import BlobManager from lbrynet.stream.descriptor import StreamDescriptor REFLECTOR_V1 = 0 @@ -14,7 +14,7 @@ log = logging.getLogger(__name__) class StreamReflectorClient(asyncio.Protocol): - def __init__(self, blob_manager: 'BlobFileManager', descriptor: 'StreamDescriptor'): + def __init__(self, blob_manager: 'BlobManager', descriptor: 'StreamDescriptor'): self.loop = asyncio.get_event_loop() self.transport: asyncio.StreamWriter = None self.blob_manager = blob_manager @@ -64,7 +64,7 @@ class StreamReflectorClient(asyncio.Protocol): async def send_descriptor(self) -> typing.Tuple[bool, typing.List[str]]: # returns a list of needed blob hashes sd_blob = self.blob_manager.get_blob(self.descriptor.sd_hash) - assert sd_blob.get_is_verified(), "need to have a sd blob to send at this point" + assert self.blob_manager.is_blob_verified(self.descriptor.sd_hash), "need to have sd blob to send at this point" response = await self.send_request({ 'sd_blob_hash': sd_blob.blob_hash, 'sd_blob_size': sd_blob.length @@ -80,7 +80,7 @@ class StreamReflectorClient(asyncio.Protocol): sent_sd = True if not needed: for blob in self.descriptor.blobs[:-1]: - if self.blob_manager.get_blob(blob.blob_hash, blob.length).get_is_verified(): + if self.blob_manager.is_blob_verified(blob.blob_hash, blob.length): needed.append(blob.blob_hash) log.info("Sent reflector descriptor %s", 
sd_blob.blob_hash[:8]) self.reflected_blobs.append(sd_blob.blob_hash) @@ -91,8 +91,8 @@ class StreamReflectorClient(asyncio.Protocol): return sent_sd, needed async def send_blob(self, blob_hash: str): + assert self.blob_manager.is_blob_verified(blob_hash), "need to have a blob to send at this point" blob = self.blob_manager.get_blob(blob_hash) - assert blob.get_is_verified(), "need to have a blob to send at this point" response = await self.send_request({ 'blob_hash': blob.blob_hash, 'blob_size': blob.length diff --git a/lbrynet/stream/reflector/server.py b/lbrynet/stream/reflector/server.py index d5effd0a4..fd43f0c3a 100644 --- a/lbrynet/stream/reflector/server.py +++ b/lbrynet/stream/reflector/server.py @@ -7,7 +7,7 @@ from lbrynet.stream.descriptor import StreamDescriptor if typing.TYPE_CHECKING: from lbrynet.blob.blob_file import BlobFile - from lbrynet.blob.blob_manager import BlobFileManager + from lbrynet.blob.blob_manager import BlobManager from lbrynet.blob.writer import HashBlobWriter @@ -15,7 +15,7 @@ log = logging.getLogger(__name__) class ReflectorServerProtocol(asyncio.Protocol): - def __init__(self, blob_manager: 'BlobFileManager'): + def __init__(self, blob_manager: 'BlobManager'): self.loop = asyncio.get_event_loop() self.blob_manager = blob_manager self.server_task: asyncio.Task = None @@ -63,11 +63,11 @@ class ReflectorServerProtocol(asyncio.Protocol): return self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size']) if not self.sd_blob.get_is_verified(): - self.writer = self.sd_blob.open_for_writing() + self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername')) self.incoming.set() self.send_response({"send_sd_blob": True}) try: - await asyncio.wait_for(self.sd_blob.finished_writing.wait(), 30, loop=self.loop) + await asyncio.wait_for(self.sd_blob.verified.wait(), 30, loop=self.loop) self.descriptor = await StreamDescriptor.from_stream_descriptor_blob( self.loop, self.blob_manager.blob_dir, self.sd_blob ) @@ -102,11 +102,11 @@ class ReflectorServerProtocol(asyncio.Protocol): return blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size']) if not blob.get_is_verified(): - self.writer = blob.open_for_writing() + self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername')) self.incoming.set() self.send_response({"send_blob": True}) try: - await asyncio.wait_for(blob.finished_writing.wait(), 30, loop=self.loop) + await asyncio.wait_for(blob.verified.wait(), 30, loop=self.loop) self.send_response({"received_blob": True}) except asyncio.TimeoutError: self.send_response({"received_blob": False}) @@ -121,7 +121,7 @@ class ReflectorServerProtocol(asyncio.Protocol): class ReflectorServer: - def __init__(self, blob_manager: 'BlobFileManager'): + def __init__(self, blob_manager: 'BlobManager'): self.loop = asyncio.get_event_loop() self.blob_manager = blob_manager self.server_task: asyncio.Task = None diff --git a/lbrynet/stream/stream_manager.py b/lbrynet/stream/stream_manager.py index 565656a70..6aeb13f6b 100644 --- a/lbrynet/stream/stream_manager.py +++ b/lbrynet/stream/stream_manager.py @@ -5,18 +5,17 @@ import binascii import logging import random from decimal import Decimal -from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError, \ - DownloadDataTimeout, DownloadSDTimeout -from lbrynet.utils import generate_id +from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, 
InsufficientFundsError +from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout, ResolveTimeout +from lbrynet.utils import cache_concurrent from lbrynet.stream.descriptor import StreamDescriptor -from lbrynet.stream.downloader import StreamDownloader from lbrynet.stream.managed_stream import ManagedStream from lbrynet.schema.claim import Claim from lbrynet.schema.uri import parse_lbry_uri from lbrynet.extras.daemon.storage import lbc_to_dewies if typing.TYPE_CHECKING: from lbrynet.conf import Config - from lbrynet.blob.blob_manager import BlobFileManager + from lbrynet.blob.blob_manager import BlobManager from lbrynet.dht.node import Node from lbrynet.extras.daemon.analytics import AnalyticsManager from lbrynet.extras.daemon.storage import SQLiteStorage, StoredStreamClaim @@ -54,8 +53,11 @@ comparison_operators = { } +def path_or_none(p) -> typing.Optional[str]: + return None if p == '{stream}' else binascii.unhexlify(p).decode() + class StreamManager: - def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobFileManager', + def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager', wallet: 'LbryWalletManager', storage: 'SQLiteStorage', node: typing.Optional['Node'], analytics_manager: typing.Optional['AnalyticsManager'] = None): self.loop = loop @@ -65,8 +67,7 @@ class StreamManager: self.storage = storage self.node = node self.analytics_manager = analytics_manager - self.streams: typing.Set[ManagedStream] = set() - self.starting_streams: typing.Dict[str, asyncio.Future] = {} + self.streams: typing.Dict[str, ManagedStream] = {} self.resume_downloading_task: asyncio.Task = None self.re_reflect_task: asyncio.Task = None self.update_stream_finished_futs: typing.List[asyncio.Future] = [] @@ -76,46 +77,6 @@ class StreamManager: claim_info = await self.storage.get_content_claim(stream.stream_hash) stream.set_claim(claim_info, claim_info['value']) - async def start_stream(self, stream: ManagedStream) -> bool: - """ - Resume or rebuild a partial or completed stream - """ - if not stream.running and not stream.output_file_exists: - if stream.downloader: - stream.downloader.stop() - stream.downloader = None - - # the directory is gone, can happen when the folder that contains a published file is deleted - # reset the download directory to the default and update the file name - if not os.path.isdir(stream.download_directory): - stream.download_directory = self.config.download_dir - - stream.downloader = self.make_downloader( - stream.sd_hash, stream.download_directory, stream.descriptor.suggested_file_name - ) - if stream.status != ManagedStream.STATUS_FINISHED: - await self.storage.change_file_status(stream.stream_hash, 'running') - stream.update_status('running') - stream.start_download(self.node) - try: - await asyncio.wait_for(self.loop.create_task(stream.downloader.wrote_bytes_event.wait()), - self.config.download_timeout) - except asyncio.TimeoutError: - await self.stop_stream(stream) - if stream in self.streams: - self.streams.remove(stream) - return False - file_name = os.path.basename(stream.downloader.output_path) - output_dir = os.path.dirname(stream.downloader.output_path) - await self.storage.change_file_download_dir_and_file_name( - stream.stream_hash, output_dir, file_name - ) - stream._file_name = file_name - stream.download_directory = output_dir - self.wait_for_stream_finished(stream) - return True - return True - async def stop_stream(self, stream: ManagedStream): stream.stop_download() if not 
stream.finished and stream.output_file_exists: @@ -128,10 +89,11 @@ class StreamManager: stream.update_status(ManagedStream.STATUS_STOPPED) await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED) - def make_downloader(self, sd_hash: str, download_directory: str, file_name: str): - return StreamDownloader( - self.loop, self.config, self.blob_manager, sd_hash, download_directory, file_name - ) + async def start_stream(self, stream: ManagedStream): + stream.update_status(ManagedStream.STATUS_RUNNING) + await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_RUNNING) + await stream.setup(self.node, save_file=self.config.save_files) + self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream) async def recover_streams(self, file_infos: typing.List[typing.Dict]): to_restore = [] @@ -156,81 +118,88 @@ class StreamManager: if to_restore: await self.storage.recover_streams(to_restore, self.config.download_dir) - log.info("Recovered %i/%i attempted streams", len(to_restore), len(file_infos)) - async def add_stream(self, rowid: int, sd_hash: str, file_name: str, download_directory: str, status: str, + # if self.blob_manager._save_blobs: + # log.info("Recovered %i/%i attempted streams", len(to_restore), len(file_infos)) + + async def add_stream(self, rowid: int, sd_hash: str, file_name: typing.Optional[str], + download_directory: typing.Optional[str], status: str, claim: typing.Optional['StoredStreamClaim']): - sd_blob = self.blob_manager.get_blob(sd_hash) - if not sd_blob.get_is_verified(): - return try: - descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash) + descriptor = await self.blob_manager.get_stream_descriptor(sd_hash) except InvalidStreamDescriptorError as err: log.warning("Failed to start stream for sd %s - %s", sd_hash, str(err)) return - if status == ManagedStream.STATUS_RUNNING: - downloader = self.make_downloader(descriptor.sd_hash, download_directory, file_name) - else: - downloader = None stream = ManagedStream( - self.loop, self.blob_manager, rowid, descriptor, download_directory, file_name, downloader, status, claim + self.loop, self.config, self.blob_manager, descriptor.sd_hash, download_directory, file_name, status, + claim, rowid=rowid, descriptor=descriptor, analytics_manager=self.analytics_manager ) - self.streams.add(stream) - self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream) + self.streams[sd_hash] = stream async def load_streams_from_database(self): to_recover = [] + to_start = [] + await self.storage.sync_files_to_blobs() for file_info in await self.storage.get_all_lbry_files(): - if not self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified(): + if not self.blob_manager.is_blob_verified(file_info['sd_hash']): to_recover.append(file_info) - + to_start.append(file_info) if to_recover: - log.info("Attempting to recover %i streams", len(to_recover)) + # if self.blob_manager._save_blobs: + # log.info("Attempting to recover %i streams", len(to_recover)) await self.recover_streams(to_recover) - to_start = [] - for file_info in await self.storage.get_all_lbry_files(): - if self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified(): - to_start.append(file_info) - log.info("Initializing %i files", len(to_start)) + if not self.config.save_files: + to_set_as_streaming = [] + for file_info in to_start: + file_name = path_or_none(file_info['file_name']) + download_dir = 
path_or_none(file_info['download_directory']) + if file_name and download_dir and not os.path.isfile(os.path.join(file_name, download_dir)): + file_info['file_name'], file_info['download_directory'] = '{stream}', '{stream}' + to_set_as_streaming.append(file_info['stream_hash']) - await asyncio.gather(*[ - self.add_stream( - file_info['rowid'], file_info['sd_hash'], binascii.unhexlify(file_info['file_name']).decode(), - binascii.unhexlify(file_info['download_directory']).decode(), file_info['status'], - file_info['claim'] - ) for file_info in to_start - ]) + if to_set_as_streaming: + await self.storage.set_files_as_streaming(to_set_as_streaming) + + log.info("Initializing %i files", len(to_start)) + if to_start: + await asyncio.gather(*[ + self.add_stream( + file_info['rowid'], file_info['sd_hash'], path_or_none(file_info['file_name']), + path_or_none(file_info['download_directory']), file_info['status'], + file_info['claim'] + ) for file_info in to_start + ]) log.info("Started stream manager with %i files", len(self.streams)) async def resume(self): - if self.node: - await self.node.joined.wait() - else: + if not self.node: log.warning("no DHT node given, resuming downloads trusting that we can contact reflector") t = [ - (stream.start_download(self.node), self.wait_for_stream_finished(stream)) - for stream in self.streams if stream.status == ManagedStream.STATUS_RUNNING + self.loop.create_task(self.start_stream(stream)) for stream in self.streams.values() + if stream.running ] if t: log.info("resuming %i downloads", len(t)) + await asyncio.gather(*t, loop=self.loop) async def reflect_streams(self): while True: if self.config.reflect_streams and self.config.reflector_servers: sd_hashes = await self.storage.get_streams_to_re_reflect() - streams = list(filter(lambda s: s.sd_hash in sd_hashes, self.streams)) + sd_hashes = [sd for sd in sd_hashes if sd in self.streams] batch = [] - while streams: - stream = streams.pop() - if not stream.fully_reflected.is_set(): - host, port = random.choice(self.config.reflector_servers) - batch.append(stream.upload_to_reflector(host, port)) + while sd_hashes: + stream = self.streams[sd_hashes.pop()] + if self.blob_manager.is_blob_verified(stream.sd_hash) and stream.blobs_completed: + if not stream.fully_reflected.is_set(): + host, port = random.choice(self.config.reflector_servers) + batch.append(stream.upload_to_reflector(host, port)) if len(batch) >= self.config.concurrent_reflector_uploads: - await asyncio.gather(*batch) + await asyncio.gather(*batch, loop=self.loop) batch = [] if batch: - await asyncio.gather(*batch) + await asyncio.gather(*batch, loop=self.loop) await asyncio.sleep(300, loop=self.loop) async def start(self): @@ -244,7 +213,7 @@ class StreamManager: if self.re_reflect_task and not self.re_reflect_task.done(): self.re_reflect_task.cancel() while self.streams: - stream = self.streams.pop() + _, stream = self.streams.popitem() stream.stop_download() while self.update_stream_finished_futs: self.update_stream_finished_futs.pop().cancel() @@ -253,8 +222,8 @@ class StreamManager: async def create_stream(self, file_path: str, key: typing.Optional[bytes] = None, iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None) -> ManagedStream: - stream = await ManagedStream.create(self.loop, self.blob_manager, file_path, key, iv_generator) - self.streams.add(stream) + stream = await ManagedStream.create(self.loop, self.config, self.blob_manager, file_path, key, iv_generator) + self.streams[stream.sd_hash] = stream 
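# Illustrative sketch: file names and download directories are stored hex-encoded, with
# the literal '{stream}' marker meaning "no local file" (streaming-only downloads).
# path_or_none() above decodes that convention; the inverse, as used by
# SQLiteStorage.change_file_download_dir_and_file_name, looks roughly like this
# (the encode helper name here is hypothetical):
import binascii
import typing

def encode_path_or_stream_marker(p: typing.Optional[str]) -> str:
    return '{stream}' if not p else binascii.hexlify(p.encode()).decode()

def decode_path_or_none(p: str) -> typing.Optional[str]:
    return None if p == '{stream}' else binascii.unhexlify(p).decode()

assert decode_path_or_none(encode_path_or_stream_marker(None)) is None
assert decode_path_or_none(encode_path_or_stream_marker('video.mp4')) == 'video.mp4'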
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream) if self.config.reflect_streams and self.config.reflector_servers: host, port = random.choice(self.config.reflector_servers) @@ -268,8 +237,8 @@ class StreamManager: async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False): await self.stop_stream(stream) - if stream in self.streams: - self.streams.remove(stream) + if stream.sd_hash in self.streams: + del self.streams[stream.sd_hash] blob_hashes = [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]] await self.blob_manager.delete_blobs(blob_hashes, delete_from_db=False) await self.storage.delete_stream(stream.descriptor) @@ -277,7 +246,7 @@ class StreamManager: os.remove(stream.full_path) def get_stream_by_stream_hash(self, stream_hash: str) -> typing.Optional[ManagedStream]: - streams = tuple(filter(lambda stream: stream.stream_hash == stream_hash, self.streams)) + streams = tuple(filter(lambda stream: stream.stream_hash == stream_hash, self.streams.values())) if streams: return streams[0] @@ -304,54 +273,25 @@ class StreamManager: if search_by: comparison = comparison or 'eq' streams = [] - for stream in self.streams: + for stream in self.streams.values(): for search, val in search_by.items(): if comparison_operators[comparison](getattr(stream, search), val): streams.append(stream) break else: - streams = list(self.streams) + streams = list(self.streams.values()) if sort_by: streams.sort(key=lambda s: getattr(s, sort_by)) if reverse: streams.reverse() return streams - def wait_for_stream_finished(self, stream: ManagedStream): - async def _wait_for_stream_finished(): - if stream.downloader and stream.running: - await stream.downloader.stream_finished_event.wait() - stream.update_status(ManagedStream.STATUS_FINISHED) - if self.analytics_manager: - self.loop.create_task(self.analytics_manager.send_download_finished( - stream.download_id, stream.claim_name, stream.sd_hash - )) - - task = self.loop.create_task(_wait_for_stream_finished()) - self.update_stream_finished_futs.append(task) - task.add_done_callback( - lambda _: None if task not in self.update_stream_finished_futs else - self.update_stream_finished_futs.remove(task) - ) - - async def _store_stream(self, downloader: StreamDownloader) -> int: - file_name = os.path.basename(downloader.output_path) - download_directory = os.path.dirname(downloader.output_path) - if not await self.storage.stream_exists(downloader.sd_hash): - await self.storage.store_stream(downloader.sd_blob, downloader.descriptor) - if not await self.storage.file_exists(downloader.sd_hash): - return await self.storage.save_downloaded_file( - downloader.descriptor.stream_hash, file_name, download_directory, - 0.0 - ) - else: - return await self.storage.rowid_for_stream(downloader.descriptor.stream_hash) - async def _check_update_or_replace(self, outpoint: str, claim_id: str, claim: Claim) -> typing.Tuple[ typing.Optional[ManagedStream], typing.Optional[ManagedStream]]: existing = self.get_filtered_streams(outpoint=outpoint) if existing: - await self.start_stream(existing[0]) + if existing[0].status == ManagedStream.STATUS_STOPPED: + await self.start_stream(existing[0]) return existing[0], None existing = self.get_filtered_streams(sd_hash=claim.stream.source.sd_hash) if existing and existing[0].claim_id != claim_id: @@ -363,7 +303,8 @@ class StreamManager: existing[0].stream_hash, outpoint ) await self._update_content_claim(existing[0]) - await 
self.start_stream(existing[0]) + if not existing[0].running: + await self.start_stream(existing[0]) return existing[0], None else: existing_for_claim_id = self.get_filtered_streams(claim_id=claim_id) @@ -372,142 +313,119 @@ class StreamManager: return None, existing_for_claim_id[0] return None, None - async def start_downloader(self, got_descriptor_time: asyncio.Future, downloader: StreamDownloader, - download_id: str, outpoint: str, claim: Claim, resolved: typing.Dict, - file_name: typing.Optional[str] = None) -> ManagedStream: - start_time = self.loop.time() - downloader.download(self.node) - await downloader.got_descriptor.wait() - got_descriptor_time.set_result(self.loop.time() - start_time) - rowid = await self._store_stream(downloader) - await self.storage.save_content_claim( - downloader.descriptor.stream_hash, outpoint - ) - stream = ManagedStream(self.loop, self.blob_manager, rowid, downloader.descriptor, self.config.download_dir, - file_name, downloader, ManagedStream.STATUS_RUNNING, download_id=download_id) - stream.set_claim(resolved, claim) - await stream.downloader.wrote_bytes_event.wait() - self.streams.add(stream) - return stream - - async def _download_stream_from_uri(self, uri, timeout: float, exchange_rate_manager: 'ExchangeRateManager', - file_name: typing.Optional[str] = None) -> ManagedStream: - start_time = self.loop.time() - parsed_uri = parse_lbry_uri(uri) - if parsed_uri.is_channel: - raise ResolveError("cannot download a channel claim, specify a /path") - - # resolve the claim - resolved = (await self.wallet.ledger.resolve(0, 10, uri)).get(uri, {}) - resolved = resolved if 'value' in resolved else resolved.get('claim') - - if not resolved: - raise ResolveError(f"Failed to resolve stream at '{uri}'") - if 'error' in resolved: - raise ResolveError(f"error resolving stream: {resolved['error']}") - - claim = Claim.from_bytes(binascii.unhexlify(resolved['protobuf'])) - outpoint = f"{resolved['txid']}:{resolved['nout']}" - resolved_time = self.loop.time() - start_time - - # resume or update an existing stream, if the stream changed download it and delete the old one after - updated_stream, to_replace = await self._check_update_or_replace(outpoint, resolved['claim_id'], claim) - if updated_stream: - return updated_stream - - # check that the fee is payable - fee_amount, fee_address = None, None - if claim.stream.has_fee: - fee_amount = round(exchange_rate_manager.convert_currency( - claim.stream.fee.currency, "LBC", claim.stream.fee.amount - ), 5) - max_fee_amount = round(exchange_rate_manager.convert_currency( - self.config.max_key_fee['currency'], "LBC", Decimal(self.config.max_key_fee['amount']) - ), 5) - if fee_amount > max_fee_amount: - msg = f"fee of {fee_amount} exceeds max configured to allow of {max_fee_amount}" - log.warning(msg) - raise KeyFeeAboveMaxAllowed(msg) - balance = await self.wallet.default_account.get_balance() - if lbc_to_dewies(str(fee_amount)) > balance: - msg = f"fee of {fee_amount} exceeds max available balance" - log.warning(msg) - raise InsufficientFundsError(msg) - fee_address = claim.stream.fee.address - - # download the stream - download_id = binascii.hexlify(generate_id()).decode() - downloader = StreamDownloader(self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, - self.config.download_dir, file_name) - - stream = None - descriptor_time_fut = self.loop.create_future() - start_download_time = self.loop.time() - time_to_descriptor = None - time_to_first_bytes = None - error = None - try: - stream = await 
asyncio.wait_for( - asyncio.ensure_future( - self.start_downloader(descriptor_time_fut, downloader, download_id, outpoint, claim, resolved, - file_name) - ), timeout - ) - time_to_descriptor = await descriptor_time_fut - time_to_first_bytes = self.loop.time() - start_download_time - time_to_descriptor - self.wait_for_stream_finished(stream) - if fee_address and fee_amount and not to_replace: - stream.tx = await self.wallet.send_amount_to_address( - lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1')) - elif to_replace: # delete old stream now that the replacement has started downloading - await self.delete_stream(to_replace) - except asyncio.TimeoutError: - if descriptor_time_fut.done(): - time_to_descriptor = descriptor_time_fut.result() - error = DownloadDataTimeout(downloader.sd_hash) - self.blob_manager.delete_blob(downloader.sd_hash) - await self.storage.delete_stream(downloader.descriptor) - else: - descriptor_time_fut.cancel() - error = DownloadSDTimeout(downloader.sd_hash) - if stream: - await self.stop_stream(stream) - else: - downloader.stop() - if error: - log.warning(error) - if self.analytics_manager: - self.loop.create_task( - self.analytics_manager.send_time_to_first_bytes( - resolved_time, self.loop.time() - start_time, download_id, parse_lbry_uri(uri).name, outpoint, - None if not stream else len(stream.downloader.blob_downloader.active_connections), - None if not stream else len(stream.downloader.blob_downloader.scores), - False if not downloader else downloader.added_fixed_peers, - self.config.fixed_peer_delay if not downloader else downloader.fixed_peers_delay, - claim.stream.source.sd_hash, time_to_descriptor, - None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].blob_hash, - None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].length, - time_to_first_bytes, None if not error else error.__class__.__name__ - ) - ) - if error: - raise error - return stream - + @cache_concurrent async def download_stream_from_uri(self, uri, exchange_rate_manager: 'ExchangeRateManager', + timeout: typing.Optional[float] = None, file_name: typing.Optional[str] = None, - timeout: typing.Optional[float] = None) -> ManagedStream: + download_directory: typing.Optional[str] = None, + save_file: bool = True, resolve_timeout: float = 3.0) -> ManagedStream: timeout = timeout or self.config.download_timeout - if uri in self.starting_streams: - return await self.starting_streams[uri] - fut = asyncio.Future(loop=self.loop) - self.starting_streams[uri] = fut + start_time = self.loop.time() + resolved_time = None + stream = None + error = None + outpoint = None try: - stream = await self._download_stream_from_uri(uri, timeout, exchange_rate_manager, file_name) - fut.set_result(stream) + # resolve the claim + parsed_uri = parse_lbry_uri(uri) + if parsed_uri.is_channel: + raise ResolveError("cannot download a channel claim, specify a /path") + try: + resolved_result = await asyncio.wait_for(self.wallet.ledger.resolve(0, 1, uri), resolve_timeout) + except asyncio.TimeoutError: + raise ResolveTimeout(uri) + await self.storage.save_claims_for_resolve([ + value for value in resolved_result.values() if 'error' not in value + ]) + resolved = resolved_result.get(uri, {}) + resolved = resolved if 'value' in resolved else resolved.get('claim') + if not resolved: + raise ResolveError(f"Failed to resolve stream at '{uri}'") + if 'error' in resolved: + raise ResolveError(f"error resolving stream: {resolved['error']}") + + claim = 
Claim.from_bytes(binascii.unhexlify(resolved['protobuf'])) + outpoint = f"{resolved['txid']}:{resolved['nout']}" + resolved_time = self.loop.time() - start_time + + # resume or update an existing stream, if the stream changed download it and delete the old one after + updated_stream, to_replace = await self._check_update_or_replace(outpoint, resolved['claim_id'], claim) + if updated_stream: + return updated_stream + + content_fee = None + + # check that the fee is payable + if not to_replace and claim.stream.has_fee: + fee_amount = round(exchange_rate_manager.convert_currency( + claim.stream.fee.currency, "LBC", claim.stream.fee.amount + ), 5) + max_fee_amount = round(exchange_rate_manager.convert_currency( + self.config.max_key_fee['currency'], "LBC", Decimal(self.config.max_key_fee['amount']) + ), 5) + if fee_amount > max_fee_amount: + msg = f"fee of {fee_amount} exceeds max configured to allow of {max_fee_amount}" + log.warning(msg) + raise KeyFeeAboveMaxAllowed(msg) + balance = await self.wallet.default_account.get_balance() + if lbc_to_dewies(str(fee_amount)) > balance: + msg = f"fee of {fee_amount} exceeds max available balance" + log.warning(msg) + raise InsufficientFundsError(msg) + fee_address = claim.stream.fee.address + + content_fee = await self.wallet.send_amount_to_address( + lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1') + ) + + log.info("paid fee of %s for %s", fee_amount, uri) + + download_directory = download_directory or self.config.download_dir + if not file_name and (not self.config.save_files or not save_file): + download_dir, file_name = None, None + stream = ManagedStream( + self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_directory, + file_name, ManagedStream.STATUS_RUNNING, content_fee=content_fee, + analytics_manager=self.analytics_manager + ) + log.info("starting download for %s", uri) + try: + await asyncio.wait_for(stream.setup( + self.node, save_file=save_file, file_name=file_name, download_directory=download_directory + ), timeout, loop=self.loop) + except asyncio.TimeoutError: + if not stream.descriptor: + raise DownloadSDTimeout(stream.sd_hash) + raise DownloadDataTimeout(stream.sd_hash) + if to_replace: # delete old stream now that the replacement has started downloading + await self.delete_stream(to_replace) + stream.set_claim(resolved, claim) + await self.storage.save_content_claim(stream.stream_hash, outpoint) + self.streams[stream.sd_hash] = stream + return stream except Exception as err: - fut.set_exception(err) - try: - return await fut + error = err + if stream and stream.descriptor: + await self.storage.delete_stream(stream.descriptor) + await self.blob_manager.delete_blob(stream.sd_hash) finally: - del self.starting_streams[uri] + if self.analytics_manager and (error or (stream and (stream.downloader.time_to_descriptor or + stream.downloader.time_to_first_bytes))): + self.loop.create_task( + self.analytics_manager.send_time_to_first_bytes( + resolved_time, self.loop.time() - start_time, None if not stream else stream.download_id, + uri, outpoint, + None if not stream else len(stream.downloader.blob_downloader.active_connections), + None if not stream else len(stream.downloader.blob_downloader.scores), + False if not stream else stream.downloader.added_fixed_peers, + self.config.fixed_peer_delay if not stream else stream.downloader.fixed_peers_delay, + None if not stream else stream.sd_hash, + None if not stream else stream.downloader.time_to_descriptor, + None if not (stream and stream.descriptor) 
else stream.descriptor.blobs[0].blob_hash, + None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].length, + None if not stream else stream.downloader.time_to_first_bytes, + None if not error else error.__class__.__name__ + ) + ) + if error: + raise error diff --git a/lbrynet/testcase.py b/lbrynet/testcase.py index aaa273ec6..80e121e5a 100644 --- a/lbrynet/testcase.py +++ b/lbrynet/testcase.py @@ -18,7 +18,7 @@ from lbrynet.extras.daemon.Components import ( ) from lbrynet.extras.daemon.ComponentManager import ComponentManager from lbrynet.extras.daemon.storage import SQLiteStorage -from lbrynet.blob.blob_manager import BlobFileManager +from lbrynet.blob.blob_manager import BlobManager from lbrynet.stream.reflector.server import ReflectorServer from lbrynet.blob_exchange.server import BlobServer @@ -107,9 +107,11 @@ class CommandTestCase(IntegrationTestCase): server_tmp_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, server_tmp_dir) - self.server_storage = SQLiteStorage(Config(), ':memory:') + self.server_config = Config() + self.server_storage = SQLiteStorage(self.server_config, ':memory:') await self.server_storage.open() - self.server_blob_manager = BlobFileManager(self.loop, server_tmp_dir, self.server_storage) + + self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_config) self.server = BlobServer(self.loop, self.server_blob_manager, 'bQEaw42GXsgCAGio1nxFncJSyRmnztSCjP') self.server.start_server(5567, '127.0.0.1') await self.server.started_listening.wait() diff --git a/lbrynet/utils.py b/lbrynet/utils.py index c25d8633c..19498444e 100644 --- a/lbrynet/utils.py +++ b/lbrynet/utils.py @@ -14,6 +14,7 @@ import pkg_resources import contextlib import certifi import aiohttp +import functools from lbrynet.schema.claim import Claim from lbrynet.cryptoutils import get_lbry_hash_obj @@ -146,6 +147,44 @@ def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]): cancel_task(tasks.pop()) +def async_timed_cache(duration: int): + def wrapper(fn): + cache: typing.Dict[typing.Tuple, + typing.Tuple[typing.Any, float]] = {} + + @functools.wraps(fn) + async def _inner(*args, **kwargs) -> typing.Any: + loop = asyncio.get_running_loop() + now = loop.time() + key = tuple([args, tuple([tuple([k, kwargs[k]]) for k in kwargs])]) + if key in cache and (now - cache[key][1] < duration): + return cache[key][0] + to_cache = await fn(*args, **kwargs) + cache[key] = to_cache, now + return to_cache + return _inner + return wrapper + + +def cache_concurrent(async_fn): + """ + When the decorated function has concurrent calls made to it with the same arguments, only run it once + """ + cache: typing.Dict = {} + + @functools.wraps(async_fn) + async def wrapper(*args, **kwargs): + key = tuple([args, tuple([tuple([k, kwargs[k]]) for k in kwargs])]) + cache[key] = cache.get(key) or asyncio.create_task(async_fn(*args, **kwargs)) + try: + return await cache[key] + finally: + cache.pop(key, None) + + return wrapper + + +@async_timed_cache(300) async def resolve_host(url: str, port: int, proto: str) -> str: if proto not in ['udp', 'tcp']: raise Exception("invalid protocol") diff --git a/scripts/download_blob_from_peer.py b/scripts/download_blob_from_peer.py index 7c1ad8d07..a145de700 100644 --- a/scripts/download_blob_from_peer.py +++ b/scripts/download_blob_from_peer.py @@ -5,7 +5,7 @@ import socket import ipaddress from lbrynet.conf import Config from lbrynet.extras.daemon.storage import SQLiteStorage -from lbrynet.blob.blob_manager import 
BlobFileManager +from lbrynet.blob.blob_manager import BlobManager from lbrynet.blob_exchange.client import BlobExchangeClientProtocol, request_blob import logging @@ -32,7 +32,7 @@ async def main(blob_hash: str, url: str): host = host_info[0][4][0] storage = SQLiteStorage(conf, os.path.join(conf.data_dir, "lbrynet.sqlite")) - blob_manager = BlobFileManager(loop, os.path.join(conf.data_dir, "blobfiles"), storage) + blob_manager = BlobManager(loop, os.path.join(conf.data_dir, "blobfiles"), storage) await storage.open() await blob_manager.setup() diff --git a/scripts/generate_json_api.py b/scripts/generate_json_api.py index c0308b7b3..c6a2f1573 100644 --- a/scripts/generate_json_api.py +++ b/scripts/generate_json_api.py @@ -325,6 +325,11 @@ class Examples(CommandTestCase): 'get', file_uri ) + await r( + 'Save a file to the downloads directory', + 'file', 'save', f"--sd_hash=\"{file_list_result[0]['sd_hash']}\"" + ) + # blobs bloblist = await r( diff --git a/scripts/standalone_blob_server.py b/scripts/standalone_blob_server.py index 455aaf3c8..c056e6c53 100644 --- a/scripts/standalone_blob_server.py +++ b/scripts/standalone_blob_server.py @@ -1,7 +1,7 @@ import sys import os import asyncio -from lbrynet.blob.blob_manager import BlobFileManager +from lbrynet.blob.blob_manager import BlobManager from lbrynet.blob_exchange.server import BlobServer from lbrynet.schema.address import decode_address from lbrynet.extras.daemon.storage import SQLiteStorage @@ -17,7 +17,7 @@ async def main(address: str): storage = SQLiteStorage(os.path.expanduser("~/.lbrynet/lbrynet.sqlite")) await storage.open() - blob_manager = BlobFileManager(loop, os.path.expanduser("~/.lbrynet/blobfiles"), storage) + blob_manager = BlobManager(loop, os.path.expanduser("~/.lbrynet/blobfiles"), storage) await blob_manager.setup() server = await loop.create_server( diff --git a/tests/dht_mocks.py b/tests/dht_mocks.py index 3caa8eef7..801175edf 100644 --- a/tests/dht_mocks.py +++ b/tests/dht_mocks.py @@ -61,6 +61,7 @@ def mock_network_loop(loop: asyncio.BaseEventLoop): protocol = proto_lam() transport = asyncio.DatagramTransport(extra={'socket': mock_sock}) + transport.is_closing = lambda: False transport.close = lambda: mock_sock.close() mock_sock.sendto = sendto transport.sendto = mock_sock.sendto diff --git a/tests/integration/test_claim_commands.py b/tests/integration/test_claim_commands.py index 2224d8706..2f59dd7ec 100644 --- a/tests/integration/test_claim_commands.py +++ b/tests/integration/test_claim_commands.py @@ -518,8 +518,12 @@ class StreamCommands(CommandTestCase): file.flush() tx1 = await self.publish('foo', bid='1.0', file_path=file.name) + self.assertEqual(1, len(self.daemon.jsonrpc_file_list())) + # doesn't error on missing arguments when doing an update stream tx2 = await self.publish('foo', tags='updated') + + self.assertEqual(1, len(self.daemon.jsonrpc_file_list())) self.assertEqual( tx1['outputs'][0]['claim_id'], tx2['outputs'][0]['claim_id'] @@ -530,12 +534,14 @@ class StreamCommands(CommandTestCase): with self.assertRaisesRegex(Exception, "There are 2 claims for 'foo'"): await self.daemon.jsonrpc_publish('foo') + self.assertEqual(2, len(self.daemon.jsonrpc_file_list())) # abandon duplicate stream await self.stream_abandon(tx3['outputs'][0]['claim_id']) # publish to a channel await self.channel_create('@abc') tx3 = await self.publish('foo', channel_name='@abc') + self.assertEqual(2, len(self.daemon.jsonrpc_file_list())) r = await self.resolve('lbry://@abc/foo') self.assertEqual( 
r['lbry://@abc/foo']['claim']['claim_id'], @@ -544,6 +550,7 @@ class StreamCommands(CommandTestCase): # publishing again re-signs with the same channel tx4 = await self.publish('foo', languages='uk-UA') + self.assertEqual(2, len(self.daemon.jsonrpc_file_list())) r = await self.resolve('lbry://@abc/foo') claim = r['lbry://@abc/foo']['claim'] self.assertEqual(claim['txid'], tx4['outputs'][0]['txid']) diff --git a/tests/integration/test_file_commands.py b/tests/integration/test_file_commands.py index 8a6265eaf..23c487c44 100644 --- a/tests/integration/test_file_commands.py +++ b/tests/integration/test_file_commands.py @@ -36,12 +36,12 @@ class FileCommands(CommandTestCase): await self.server.blob_manager.delete_blobs(all_except_sd) resp = await self.daemon.jsonrpc_get('lbry://foo', timeout=2) self.assertIn('error', resp) - self.assertEquals('Failed to download data blobs for sd hash %s within timeout' % sd_hash, resp['error']) + self.assertEqual('Failed to download data blobs for sd hash %s within timeout' % sd_hash, resp['error']) await self.daemon.jsonrpc_file_delete(claim_name='foo') await self.server.blob_manager.delete_blobs([sd_hash]) resp = await self.daemon.jsonrpc_get('lbry://foo', timeout=2) self.assertIn('error', resp) - self.assertEquals('Failed to download sd blob %s within timeout' % sd_hash, resp['error']) + self.assertEqual('Failed to download sd blob %s within timeout' % sd_hash, resp['error']) async def wait_files_to_complete(self): while self.sout(self.daemon.jsonrpc_file_list(status='running')): @@ -59,17 +59,14 @@ class FileCommands(CommandTestCase): await self.daemon.stream_manager.start() await asyncio.wait_for(self.wait_files_to_complete(), timeout=5) # if this hangs, file didnt get set completed # check that internal state got through up to the file list API - downloader = self.daemon.stream_manager.get_stream_by_stream_hash(file_info['stream_hash']).downloader - file_info = self.sout(self.daemon.jsonrpc_file_list())[0] - self.assertEqual(downloader.output_file_name, file_info['file_name']) + stream = self.daemon.stream_manager.get_stream_by_stream_hash(file_info['stream_hash']) + file_info = self.sout(self.daemon.jsonrpc_file_list()[0]) + self.assertEqual(stream.file_name, file_info['file_name']) # checks if what the API shows is what he have at the very internal level. - self.assertEqual(downloader.output_path, file_info['download_path']) - # if you got here refactoring just change above, but ensure what gets set internally gets reflected externally! 
- self.assertTrue(downloader.output_path.endswith(downloader.output_file_name)) - # this used to be inconsistent, if it becomes again it would create weird bugs, so worth checking + self.assertEqual(stream.full_path, file_info['download_path']) async def test_incomplete_downloads_erases_output_file_on_stop(self): - tx = await self.stream_create('foo', '0.01') + tx = await self.stream_create('foo', '0.01', data=b'deadbeef' * 1000000) sd_hash = tx['outputs'][0]['value']['source']['sd_hash'] file_info = self.sout(self.daemon.jsonrpc_file_list())[0] await self.daemon.jsonrpc_file_delete(claim_name='foo') @@ -77,25 +74,27 @@ class FileCommands(CommandTestCase): await self.server_storage.get_stream_hash_for_sd_hash(sd_hash) ) all_except_sd_and_head = [ - blob.blob_hash for blob in blobs[1:] if blob.blob_hash + blob.blob_hash for blob in blobs[1:-1] ] await self.server.blob_manager.delete_blobs(all_except_sd_and_head) - self.assertFalse(os.path.isfile(os.path.join(self.daemon.conf.download_dir, file_info['file_name']))) + path = os.path.join(self.daemon.conf.download_dir, file_info['file_name']) + self.assertFalse(os.path.isfile(path)) resp = await self.out(self.daemon.jsonrpc_get('lbry://foo', timeout=2)) self.assertNotIn('error', resp) - self.assertTrue(os.path.isfile(os.path.join(self.daemon.conf.download_dir, file_info['file_name']))) + self.assertTrue(os.path.isfile(path)) self.daemon.stream_manager.stop() - self.assertFalse(os.path.isfile(os.path.join(self.daemon.conf.download_dir, file_info['file_name']))) + await asyncio.sleep(0.01, loop=self.loop) # FIXME: this sleep should not be needed + self.assertFalse(os.path.isfile(path)) async def test_incomplete_downloads_retry(self): - tx = await self.stream_create('foo', '0.01') + tx = await self.stream_create('foo', '0.01', data=b'deadbeef' * 1000000) sd_hash = tx['outputs'][0]['value']['source']['sd_hash'] await self.daemon.jsonrpc_file_delete(claim_name='foo') blobs = await self.server_storage.get_blobs_for_stream( await self.server_storage.get_stream_hash_for_sd_hash(sd_hash) ) all_except_sd_and_head = [ - blob.blob_hash for blob in blobs[1:] if blob.blob_hash + blob.blob_hash for blob in blobs[1:-1] ] # backup server blobs @@ -143,7 +142,7 @@ class FileCommands(CommandTestCase): os.rename(missing_blob.file_path + '__', missing_blob.file_path) self.server_blob_manager.blobs.clear() missing_blob = self.server_blob_manager.get_blob(missing_blob_hash) - await self.server_blob_manager.blob_completed(missing_blob) + self.server_blob_manager.blob_completed(missing_blob) await asyncio.wait_for(self.wait_files_to_complete(), timeout=1) async def test_paid_download(self): @@ -176,9 +175,8 @@ class FileCommands(CommandTestCase): ) await self.daemon.jsonrpc_file_delete(claim_name='icanpay') await self.assertBalance(self.account, '9.925679') - response = await self.out(self.daemon.jsonrpc_get('lbry://icanpay')) - self.assertNotIn('error', response) - await self.on_transaction_dict(response['tx']) + response = await self.daemon.jsonrpc_get('lbry://icanpay') + await self.ledger.wait(response.content_fee) await self.assertBalance(self.account, '8.925555') self.assertEqual(len(self.daemon.jsonrpc_file_list()), 1) diff --git a/tests/integration/test_streaming.py b/tests/integration/test_streaming.py new file mode 100644 index 000000000..1c4bdc8df --- /dev/null +++ b/tests/integration/test_streaming.py @@ -0,0 +1,358 @@ +import os +import hashlib +import aiohttp +import aiohttp.web + +from lbrynet.utils import aiohttp_request +from lbrynet.blob.blob_file 
import MAX_BLOB_SIZE +from lbrynet.testcase import CommandTestCase + + +def get_random_bytes(n: int) -> bytes: + result = b''.join(hashlib.sha256(os.urandom(4)).digest() for _ in range(n // 16)) + if len(result) < n: + result += os.urandom(n - len(result)) + elif len(result) > n: + result = result[:-(len(result) - n)] + assert len(result) == n, (n, len(result)) + return result + + +class RangeRequests(CommandTestCase): + async def _restart_stream_manager(self): + self.daemon.stream_manager.stop() + await self.daemon.stream_manager.start() + return + + async def _setup_stream(self, data: bytes, save_blobs: bool = True, save_files: bool = False): + self.daemon.conf.save_blobs = save_blobs + self.daemon.conf.save_files = save_files + self.data = data + await self.stream_create('foo', '0.01', data=self.data) + if save_blobs: + self.assertTrue(len(os.listdir(self.daemon.blob_manager.blob_dir)) > 1) + await self.daemon.jsonrpc_file_list()[0].fully_reflected.wait() + await self.daemon.jsonrpc_file_delete(delete_from_download_dir=True, claim_name='foo') + self.assertEqual(0, len(os.listdir(self.daemon.blob_manager.blob_dir))) + # await self._restart_stream_manager() + await self.daemon.runner.setup() + site = aiohttp.web.TCPSite(self.daemon.runner, self.daemon.conf.api_host, self.daemon.conf.api_port) + await site.start() + self.assertListEqual(self.daemon.jsonrpc_file_list(), []) + + async def _test_range_requests(self): + name = 'foo' + url = f'http://{self.daemon.conf.api_host}:{self.daemon.conf.api_port}/get/{name}' + + async with aiohttp_request('get', url) as req: + self.assertEqual(req.headers.get('Content-Type'), 'application/octet-stream') + content_range = req.headers.get('Content-Range') + content_length = int(req.headers.get('Content-Length')) + streamed_bytes = await req.content.read() + self.assertEqual(content_length, len(streamed_bytes)) + return streamed_bytes, content_range, content_length + + async def test_range_requests_2_byte(self): + self.data = b'hi' + await self._setup_stream(self.data) + streamed, content_range, content_length = await self._test_range_requests() + self.assertEqual(15, content_length) + self.assertEqual(b'hi\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', streamed) + self.assertEqual('bytes 0-14/15', content_range) + + async def test_range_requests_15_byte(self): + self.data = b'123456789abcdef' + await self._setup_stream(self.data) + streamed, content_range, content_length = await self._test_range_requests() + self.assertEqual(15, content_length) + self.assertEqual(15, len(streamed)) + self.assertEqual(self.data, streamed) + self.assertEqual('bytes 0-14/15', content_range) + + async def test_range_requests_0_padded_bytes(self, size: int = (MAX_BLOB_SIZE - 1) * 4, + expected_range: str = 'bytes 0-8388603/8388604', padding=b''): + self.data = get_random_bytes(size) + await self._setup_stream(self.data) + streamed, content_range, content_length = await self._test_range_requests() + self.assertEqual(len(self.data + padding), content_length) + self.assertEqual(streamed, self.data + padding) + self.assertEqual(expected_range, content_range) + + async def test_range_requests_1_padded_bytes(self): + await self.test_range_requests_0_padded_bytes( + ((MAX_BLOB_SIZE - 1) * 4) - 1, padding=b'\x00' + ) + + async def test_range_requests_2_padded_bytes(self): + await self.test_range_requests_0_padded_bytes( + ((MAX_BLOB_SIZE - 1) * 4) - 2, padding=b'\x00' * 2 + ) + + async def test_range_requests_14_padded_bytes(self): + await 
self.test_range_requests_0_padded_bytes( + ((MAX_BLOB_SIZE - 1) * 4) - 14, padding=b'\x00' * 14 + ) + + async def test_range_requests_15_padded_bytes(self): + await self.test_range_requests_0_padded_bytes( + ((MAX_BLOB_SIZE - 1) * 4) - 15, padding=b'\x00' * 15 + ) + + async def test_range_requests_last_block_of_last_blob_padding(self): + self.data = get_random_bytes(((MAX_BLOB_SIZE - 1) * 4) - 16) + await self._setup_stream(self.data) + streamed, content_range, content_length = await self._test_range_requests() + self.assertEqual(len(self.data), content_length) + self.assertEqual(streamed, self.data) + self.assertEqual('bytes 0-8388587/8388588', content_range) + + async def test_streaming_only_with_blobs(self): + self.data = get_random_bytes((MAX_BLOB_SIZE - 1) * 4) + await self._setup_stream(self.data) + + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + + # test that repeated range requests do not create duplicate files + for _ in range(3): + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + + # test that a range request after restart does not create a duplicate file + await self._restart_stream_manager() + + current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + stream = self.daemon.jsonrpc_file_list()[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + + async def test_streaming_only_without_blobs(self): + self.data = get_random_bytes((MAX_BLOB_SIZE - 1) * 4) + await self._setup_stream(self.data, save_blobs=False) + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + + # test that repeated range requests do not create duplicate files + for _ in range(3): + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + + # test 
that a range request after restart does not create a duplicate file + await self._restart_stream_manager() + + current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + stream = self.daemon.jsonrpc_file_list()[0] + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + + async def test_stream_and_save_file_with_blobs(self): + self.data = get_random_bytes((MAX_BLOB_SIZE - 1) * 4) + await self._setup_stream(self.data, save_files=True) + + await self._test_range_requests() + streams = self.daemon.jsonrpc_file_list() + self.assertEqual(1, len(streams)) + stream = streams[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + full_path = stream.full_path + files_in_download_dir = list(os.scandir(os.path.dirname(full_path))) + + for _ in range(3): + await self._test_range_requests() + streams = self.daemon.jsonrpc_file_list() + self.assertEqual(1, len(streams)) + stream = streams[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + + await self._restart_stream_manager() + + current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + streams = self.daemon.jsonrpc_file_list() + self.assertEqual(1, len(streams)) + stream = streams[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + + await self._test_range_requests() + streams = self.daemon.jsonrpc_file_list() + self.assertEqual(1, len(streams)) + stream = streams[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + with open(stream.full_path, 'rb') as f: + self.assertEqual(self.data, f.read()) + + async def test_stream_and_save_file_without_blobs(self): + self.data = get_random_bytes((MAX_BLOB_SIZE - 1) * 4) + await self._setup_stream(self.data, save_files=True) + self.daemon.conf.save_blobs = False + + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + full_path = stream.full_path + files_in_download_dir = 
list(os.scandir(os.path.dirname(full_path))) + + for _ in range(3): + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + + await self._restart_stream_manager() + current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + streams = self.daemon.jsonrpc_file_list() + self.assertEqual(1, len(streams)) + stream = streams[0] + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + + await self._test_range_requests() + streams = self.daemon.jsonrpc_file_list() + self.assertEqual(1, len(streams)) + stream = streams[0] + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + + with open(stream.full_path, 'rb') as f: + self.assertEqual(self.data, f.read()) + + async def test_switch_save_blobs_while_running(self): + await self.test_streaming_only_without_blobs() + self.daemon.conf.save_blobs = True + blobs_in_stream = self.daemon.jsonrpc_file_list()[0].blobs_in_stream + sd_hash = self.daemon.jsonrpc_file_list()[0].sd_hash + start_file_count = len(os.listdir(self.daemon.blob_manager.blob_dir)) + await self._test_range_requests() + self.assertEqual(start_file_count + blobs_in_stream, len(os.listdir(self.daemon.blob_manager.blob_dir))) + self.assertEqual(0, self.daemon.jsonrpc_file_list()[0].blobs_remaining) + + # switch back + self.daemon.conf.save_blobs = False + await self._test_range_requests() + self.assertEqual(start_file_count + blobs_in_stream, len(os.listdir(self.daemon.blob_manager.blob_dir))) + self.assertEqual(0, self.daemon.jsonrpc_file_list()[0].blobs_remaining) + await self.daemon.jsonrpc_file_delete(delete_from_download_dir=True, sd_hash=sd_hash) + self.assertEqual(start_file_count, len(os.listdir(self.daemon.blob_manager.blob_dir))) + await self._test_range_requests() + self.assertEqual(start_file_count, len(os.listdir(self.daemon.blob_manager.blob_dir))) + self.assertEqual(blobs_in_stream, self.daemon.jsonrpc_file_list()[0].blobs_remaining) + + async def test_file_save_streaming_only_save_blobs(self): + await self.test_streaming_only_with_blobs() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertIsNone(stream.full_path) + self.server.stop_server() + await self.daemon.jsonrpc_file_save('test', self.daemon.conf.data_dir) + stream = self.daemon.jsonrpc_file_list()[0] + self.assertIsNotNone(stream.full_path) + await stream.finished_writing.wait() + with open(stream.full_path, 'rb') as f: + self.assertEqual(self.data, f.read()) + await self.daemon.jsonrpc_file_delete(delete_from_download_dir=True, sd_hash=stream.sd_hash) + + async def test_file_save_stop_before_finished_streaming_only(self, wait_for_start_writing=False): + await self.test_streaming_only_with_blobs() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertIsNone(stream.full_path) + self.server.stop_server() + await self.daemon.jsonrpc_file_save('test', self.daemon.conf.data_dir) + stream = self.daemon.jsonrpc_file_list()[0] 
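# Client-side sketch of the streaming endpoint these tests hit: a plain HTTP GET
# against /get/<claim name> on the daemon's API port returns the decrypted stream.
# The URL shape and the Content-Type header come from _test_range_requests above;
# the host/port defaults here are placeholders.
import asyncio
import aiohttp

async def stream_claim(name: str, host: str = "localhost", port: int = 5279) -> bytes:
    url = f"http://{host}:{port}/get/{name}"
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            assert resp.headers.get("Content-Type") == "application/octet-stream"
            return await resp.content.read()

# data = asyncio.run(stream_claim("foo"))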
+ path = stream.full_path + self.assertIsNotNone(path) + if wait_for_start_writing: + await stream.started_writing.wait() + self.assertTrue(os.path.isfile(path)) + await self._restart_stream_manager() + stream = self.daemon.jsonrpc_file_list()[0] + + self.assertIsNone(stream.full_path) + self.assertFalse(os.path.isfile(path)) + + async def test_file_save_stop_before_finished_streaming_only_wait_for_start(self): + return await self.test_file_save_stop_before_finished_streaming_only(wait_for_start_writing=True) + + async def test_file_save_streaming_only_dont_save_blobs(self): + await self.test_streaming_only_without_blobs() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertIsNone(stream.full_path) + await self.daemon.jsonrpc_file_save('test', self.daemon.conf.data_dir) + stream = self.daemon.jsonrpc_file_list()[0] + await stream.finished_writing.wait() + with open(stream.full_path, 'rb') as f: + self.assertEqual(self.data, f.read()) diff --git a/tests/unit/blob/test_blob_file.py b/tests/unit/blob/test_blob_file.py index 4254e0720..8b519e27e 100644 --- a/tests/unit/blob/test_blob_file.py +++ b/tests/unit/blob/test_blob_file.py @@ -3,34 +3,207 @@ import tempfile import shutil import os from torba.testcase import AsyncioTestCase +from lbrynet.error import InvalidDataError, InvalidBlobHashError from lbrynet.conf import Config from lbrynet.extras.daemon.storage import SQLiteStorage -from lbrynet.blob.blob_manager import BlobFileManager +from lbrynet.blob.blob_manager import BlobManager +from lbrynet.blob.blob_file import BlobFile, BlobBuffer, AbstractBlob -class TestBlobfile(AsyncioTestCase): - async def test_create_blob(self): - blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed" - blob_bytes = b'1' * ((2 * 2 ** 20) - 1) +class TestBlob(AsyncioTestCase): + blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed" + blob_bytes = b'1' * ((2 * 2 ** 20) - 1) + + async def asyncSetUp(self): + self.tmp_dir = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(self.tmp_dir)) + self.loop = asyncio.get_running_loop() + self.config = Config() + self.storage = SQLiteStorage(self.config, ":memory:", self.loop) + self.blob_manager = BlobManager(self.loop, self.tmp_dir, self.storage, self.config) + await self.storage.open() + + def _get_blob(self, blob_class=AbstractBlob, blob_directory=None): + blob = blob_class(self.loop, self.blob_hash, len(self.blob_bytes), self.blob_manager.blob_completed, + blob_directory=blob_directory) + self.assertFalse(blob.get_is_verified()) + self.addCleanup(blob.close) + return blob + + async def _test_create_blob(self, blob_class=AbstractBlob, blob_directory=None): + blob = self._get_blob(blob_class, blob_directory) + writer = blob.get_blob_writer() + writer.write(self.blob_bytes) + await blob.verified.wait() + self.assertTrue(blob.get_is_verified()) + await asyncio.sleep(0, loop=self.loop) # wait for the db save task + return blob + + async def _test_close_writers_on_finished(self, blob_class=AbstractBlob, blob_directory=None): + blob = self._get_blob(blob_class, blob_directory=blob_directory) + writers = [blob.get_blob_writer('1.2.3.4', port) for port in range(5)] + self.assertEqual(5, len(blob.writers)) + + # test that writing too much causes the writer to fail with InvalidDataError and to be removed + with self.assertRaises(InvalidDataError): + writers[1].write(self.blob_bytes * 2) + await writers[1].finished + await asyncio.sleep(0, 
loop=self.loop) + self.assertEqual(4, len(blob.writers)) + + # write the blob + other = writers[2] + writers[3].write(self.blob_bytes) + await blob.verified.wait() + + self.assertTrue(blob.get_is_verified()) + self.assertEqual(0, len(blob.writers)) + with self.assertRaises(IOError): + other.write(self.blob_bytes) + + def _test_ioerror_if_length_not_set(self, blob_class=AbstractBlob, blob_directory=None): + blob = blob_class( + self.loop, self.blob_hash, blob_completed_callback=self.blob_manager.blob_completed, + blob_directory=blob_directory + ) + self.addCleanup(blob.close) + writer = blob.get_blob_writer() + with self.assertRaises(IOError): + writer.write(b'') + + async def _test_invalid_blob_bytes(self, blob_class=AbstractBlob, blob_directory=None): + blob = blob_class( + self.loop, self.blob_hash, len(self.blob_bytes), blob_completed_callback=self.blob_manager.blob_completed, + blob_directory=blob_directory + ) + self.addCleanup(blob.close) + writer = blob.get_blob_writer() + writer.write(self.blob_bytes[:-4] + b'fake') + with self.assertRaises(InvalidBlobHashError): + await writer.finished + + async def test_add_blob_buffer_to_db(self): + blob = await self._test_create_blob(BlobBuffer) + db_status = await self.storage.get_blob_status(blob.blob_hash) + self.assertEqual(db_status, 'pending') + + async def test_add_blob_file_to_db(self): + blob = await self._test_create_blob(BlobFile, self.tmp_dir) + db_status = await self.storage.get_blob_status(blob.blob_hash) + self.assertEqual(db_status, 'finished') + + async def test_invalid_blob_bytes(self): + await self._test_invalid_blob_bytes(BlobBuffer) + await self._test_invalid_blob_bytes(BlobFile, self.tmp_dir) + + def test_ioerror_if_length_not_set(self): + tmp_dir = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmp_dir)) + self._test_ioerror_if_length_not_set(BlobBuffer) + self._test_ioerror_if_length_not_set(BlobFile, tmp_dir) + + async def test_create_blob_file(self): + tmp_dir = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmp_dir)) + blob = await self._test_create_blob(BlobFile, tmp_dir) + self.assertIsInstance(blob, BlobFile) + self.assertTrue(os.path.isfile(blob.file_path)) + + for _ in range(2): + with blob.reader_context() as reader: + self.assertEqual(self.blob_bytes, reader.read()) + + async def test_create_blob_buffer(self): + blob = await self._test_create_blob(BlobBuffer) + self.assertIsInstance(blob, BlobBuffer) + self.assertIsNotNone(blob._verified_bytes) + + # check we can only read the bytes once, and that the buffer is torn down + with blob.reader_context() as reader: + self.assertEqual(self.blob_bytes, reader.read()) + self.assertIsNone(blob._verified_bytes) + with self.assertRaises(OSError): + with blob.reader_context() as reader: + self.assertEqual(self.blob_bytes, reader.read()) + self.assertIsNone(blob._verified_bytes) + + async def test_close_writers_on_finished(self): + tmp_dir = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmp_dir)) + await self._test_close_writers_on_finished(BlobBuffer) + await self._test_close_writers_on_finished(BlobFile, tmp_dir) + + async def test_delete(self): + blob_buffer = await self._test_create_blob(BlobBuffer) + self.assertIsInstance(blob_buffer, BlobBuffer) + self.assertIsNotNone(blob_buffer._verified_bytes) + self.assertTrue(blob_buffer.get_is_verified()) + blob_buffer.delete() + self.assertIsNone(blob_buffer._verified_bytes) + self.assertFalse(blob_buffer.get_is_verified()) - loop = asyncio.get_event_loop() tmp_dir = tempfile.mkdtemp() 
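# Sketch of the blob write/read API exercised by TestBlob above, assuming the blob
# hash is the sha384 hex digest of the blob's bytes (the 96-character hashes used in
# these tests are consistent with that) and that blob_completed_callback can be left
# at its default when no BlobManager is involved.
import asyncio
import hashlib
import shutil
import tempfile
from lbrynet.blob.blob_file import BlobFile

async def blob_round_trip(data: bytes) -> bytes:
    loop = asyncio.get_running_loop()
    blob_dir = tempfile.mkdtemp()
    try:
        blob = BlobFile(loop, hashlib.sha384(data).hexdigest(), len(data), blob_directory=blob_dir)
        writer = blob.get_blob_writer()
        writer.write(data)            # bytes are hashed as they are written
        await blob.verified.wait()    # set once the digest matches the blob hash
        with blob.reader_context() as reader:
            read_back = reader.read()
        blob.close()
        return read_back
    finally:
        shutil.rmtree(blob_dir)

assert asyncio.run(blob_round_trip(b'1' * 32)) == b'1' * 32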
self.addCleanup(lambda: shutil.rmtree(tmp_dir)) - storage = SQLiteStorage(Config(), os.path.join(tmp_dir, "lbrynet.sqlite")) - blob_manager = BlobFileManager(loop, tmp_dir, storage) + blob_file = await self._test_create_blob(BlobFile, tmp_dir) + self.assertIsInstance(blob_file, BlobFile) + self.assertTrue(os.path.isfile(blob_file.file_path)) + self.assertTrue(blob_file.get_is_verified()) + blob_file.delete() + self.assertFalse(os.path.isfile(blob_file.file_path)) + self.assertFalse(blob_file.get_is_verified()) - await storage.open() - await blob_manager.setup() + async def test_delete_corrupt(self): + tmp_dir = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmp_dir)) + blob = BlobFile( + self.loop, self.blob_hash, len(self.blob_bytes), blob_completed_callback=self.blob_manager.blob_completed, + blob_directory=tmp_dir + ) + writer = blob.get_blob_writer() + writer.write(self.blob_bytes) + await blob.verified.wait() + blob.close() + blob = BlobFile( + self.loop, self.blob_hash, len(self.blob_bytes), blob_completed_callback=self.blob_manager.blob_completed, + blob_directory=tmp_dir + ) + self.assertTrue(blob.get_is_verified()) - # add the blob on the server - blob = blob_manager.get_blob(blob_hash, len(blob_bytes)) - self.assertEqual(blob.get_is_verified(), False) - self.assertNotIn(blob_hash, blob_manager.completed_blob_hashes) + with open(blob.file_path, 'wb+') as f: + f.write(b'\x00') + blob = BlobFile( + self.loop, self.blob_hash, len(self.blob_bytes), blob_completed_callback=self.blob_manager.blob_completed, + blob_directory=tmp_dir + ) + self.assertFalse(blob.get_is_verified()) + self.assertFalse(os.path.isfile(blob.file_path)) - writer = blob.open_for_writing() - writer.write(blob_bytes) - await blob.finished_writing.wait() - self.assertTrue(os.path.isfile(blob.file_path), True) - self.assertEqual(blob.get_is_verified(), True) - self.assertIn(blob_hash, blob_manager.completed_blob_hashes) + def test_invalid_blob_hash(self): + self.assertRaises(InvalidBlobHashError, BlobBuffer, self.loop, '', len(self.blob_bytes)) + self.assertRaises(InvalidBlobHashError, BlobBuffer, self.loop, 'x' * 96, len(self.blob_bytes)) + self.assertRaises(InvalidBlobHashError, BlobBuffer, self.loop, 'a' * 97, len(self.blob_bytes)) + + async def _test_close_reader(self, blob_class=AbstractBlob, blob_directory=None): + blob = await self._test_create_blob(blob_class, blob_directory) + reader = blob.reader_context() + self.assertEqual(0, len(blob.readers)) + + async def read_blob_buffer(): + with reader as read_handle: + self.assertEqual(1, len(blob.readers)) + await asyncio.sleep(2, loop=self.loop) + self.assertEqual(0, len(blob.readers)) + return read_handle.read() + + self.loop.call_later(1, blob.close) + with self.assertRaises(ValueError) as err: + read_task = self.loop.create_task(read_blob_buffer()) + await read_task + self.assertEqual(err.exception, ValueError("I/O operation on closed file")) + + async def test_close_reader(self): + tmp_dir = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmp_dir)) + await self._test_close_reader(BlobBuffer) + await self._test_close_reader(BlobFile, tmp_dir) diff --git a/tests/unit/blob/test_blob_manager.py b/tests/unit/blob/test_blob_manager.py index a181c3874..6dd1885dd 100644 --- a/tests/unit/blob/test_blob_manager.py +++ b/tests/unit/blob/test_blob_manager.py @@ -5,50 +5,53 @@ import os from torba.testcase import AsyncioTestCase from lbrynet.conf import Config from lbrynet.extras.daemon.storage import SQLiteStorage -from lbrynet.blob.blob_manager 
import BlobFileManager +from lbrynet.blob.blob_manager import BlobManager class TestBlobManager(AsyncioTestCase): - async def test_sync_blob_manager_on_startup(self): - loop = asyncio.get_event_loop() + async def setup_blob_manager(self, save_blobs=True): tmp_dir = tempfile.mkdtemp() self.addCleanup(lambda: shutil.rmtree(tmp_dir)) + self.config = Config(save_blobs=save_blobs) + self.storage = SQLiteStorage(self.config, os.path.join(tmp_dir, "lbrynet.sqlite")) + self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.config) + await self.storage.open() - storage = SQLiteStorage(Config(), os.path.join(tmp_dir, "lbrynet.sqlite")) - blob_manager = BlobFileManager(loop, tmp_dir, storage) + async def test_sync_blob_file_manager_on_startup(self): + await self.setup_blob_manager(save_blobs=True) # add a blob file blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed" blob_bytes = b'1' * ((2 * 2 ** 20) - 1) - with open(os.path.join(blob_manager.blob_dir, blob_hash), 'wb') as f: + with open(os.path.join(self.blob_manager.blob_dir, blob_hash), 'wb') as f: f.write(blob_bytes) # it should not have been added automatically on startup - await storage.open() - await blob_manager.setup() - self.assertSetEqual(blob_manager.completed_blob_hashes, set()) + + await self.blob_manager.setup() + self.assertSetEqual(self.blob_manager.completed_blob_hashes, set()) # make sure we can add the blob - await blob_manager.blob_completed(blob_manager.get_blob(blob_hash, len(blob_bytes))) - self.assertSetEqual(blob_manager.completed_blob_hashes, {blob_hash}) + await self.blob_manager.blob_completed(self.blob_manager.get_blob(blob_hash, len(blob_bytes))) + self.assertSetEqual(self.blob_manager.completed_blob_hashes, {blob_hash}) # stop the blob manager and restart it, make sure the blob is there - blob_manager.stop() - self.assertSetEqual(blob_manager.completed_blob_hashes, set()) - await blob_manager.setup() - self.assertSetEqual(blob_manager.completed_blob_hashes, {blob_hash}) + self.blob_manager.stop() + self.assertSetEqual(self.blob_manager.completed_blob_hashes, set()) + await self.blob_manager.setup() + self.assertSetEqual(self.blob_manager.completed_blob_hashes, {blob_hash}) # test that the blob is removed upon the next startup after the file being manually deleted - blob_manager.stop() + self.blob_manager.stop() # manually delete the blob file and restart the blob manager - os.remove(os.path.join(blob_manager.blob_dir, blob_hash)) - await blob_manager.setup() - self.assertSetEqual(blob_manager.completed_blob_hashes, set()) + os.remove(os.path.join(self.blob_manager.blob_dir, blob_hash)) + await self.blob_manager.setup() + self.assertSetEqual(self.blob_manager.completed_blob_hashes, set()) # check that the deleted blob was updated in the database self.assertEqual( 'pending', ( - await storage.run_and_return_one_or_none('select status from blob where blob_hash=?', blob_hash) + await self.storage.run_and_return_one_or_none('select status from blob where blob_hash=?', blob_hash) ) ) diff --git a/tests/unit/blob_exchange/test_transfer_blob.py b/tests/unit/blob_exchange/test_transfer_blob.py index 482781ab2..2cc2c18c8 100644 --- a/tests/unit/blob_exchange/test_transfer_blob.py +++ b/tests/unit/blob_exchange/test_transfer_blob.py @@ -9,9 +9,9 @@ from lbrynet.blob_exchange.serialization import BlobRequest from torba.testcase import AsyncioTestCase from lbrynet.conf import Config from lbrynet.extras.daemon.storage import SQLiteStorage -from 
lbrynet.blob.blob_manager import BlobFileManager +from lbrynet.blob.blob_manager import BlobManager from lbrynet.blob_exchange.server import BlobServer, BlobServerProtocol -from lbrynet.blob_exchange.client import BlobExchangeClientProtocol, request_blob +from lbrynet.blob_exchange.client import request_blob from lbrynet.dht.peer import KademliaPeer, PeerManager # import logging @@ -35,13 +35,13 @@ class BlobExchangeTestBase(AsyncioTestCase): self.server_config = Config(data_dir=self.server_dir, download_dir=self.server_dir, wallet=self.server_dir, reflector_servers=[]) self.server_storage = SQLiteStorage(self.server_config, os.path.join(self.server_dir, "lbrynet.sqlite")) - self.server_blob_manager = BlobFileManager(self.loop, self.server_dir, self.server_storage) + self.server_blob_manager = BlobManager(self.loop, self.server_dir, self.server_storage, self.server_config) self.server = BlobServer(self.loop, self.server_blob_manager, 'bQEaw42GXsgCAGio1nxFncJSyRmnztSCjP') self.client_config = Config(data_dir=self.client_dir, download_dir=self.client_dir, wallet=self.client_dir, reflector_servers=[]) self.client_storage = SQLiteStorage(self.client_config, os.path.join(self.client_dir, "lbrynet.sqlite")) - self.client_blob_manager = BlobFileManager(self.loop, self.client_dir, self.client_storage) + self.client_blob_manager = BlobManager(self.loop, self.client_dir, self.client_storage, self.client_config) self.client_peer_manager = PeerManager(self.loop) self.server_from_client = KademliaPeer(self.loop, "127.0.0.1", b'1' * 48, tcp_port=33333) @@ -51,6 +51,7 @@ class BlobExchangeTestBase(AsyncioTestCase): await self.server_blob_manager.setup() self.server.start_server(33333, '127.0.0.1') + self.addCleanup(self.server.stop_server) await self.server.started_listening.wait() @@ -58,21 +59,25 @@ class TestBlobExchange(BlobExchangeTestBase): async def _add_blob_to_server(self, blob_hash: str, blob_bytes: bytes): # add the blob on the server server_blob = self.server_blob_manager.get_blob(blob_hash, len(blob_bytes)) - writer = server_blob.open_for_writing() + writer = server_blob.get_blob_writer() writer.write(blob_bytes) - await server_blob.finished_writing.wait() + await server_blob.verified.wait() self.assertTrue(os.path.isfile(server_blob.file_path)) self.assertEqual(server_blob.get_is_verified(), True) + self.assertTrue(writer.closed()) async def _test_transfer_blob(self, blob_hash: str): client_blob = self.client_blob_manager.get_blob(blob_hash) # download the blob - downloaded = await request_blob(self.loop, client_blob, self.server_from_client.address, - self.server_from_client.tcp_port, 2, 3) - await client_blob.finished_writing.wait() + downloaded, transport = await request_blob(self.loop, client_blob, self.server_from_client.address, + self.server_from_client.tcp_port, 2, 3) + self.assertIsNotNone(transport) + self.addCleanup(transport.close) + await client_blob.verified.wait() self.assertEqual(client_blob.get_is_verified(), True) self.assertTrue(downloaded) + client_blob.close() async def test_transfer_sd_blob(self): sd_hash = "3e2706157a59aaa47ef52bc264fce488078b4026c0b9bab649a8f2fe1ecc5e5cad7182a2bb7722460f856831a1ac0f02" @@ -92,9 +97,11 @@ class TestBlobExchange(BlobExchangeTestBase): second_client_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, second_client_dir) - - second_client_storage = SQLiteStorage(Config(), os.path.join(second_client_dir, "lbrynet.sqlite")) - second_client_blob_manager = BlobFileManager(self.loop, second_client_dir, second_client_storage) + 
second_client_conf = Config() + second_client_storage = SQLiteStorage(second_client_conf, os.path.join(second_client_dir, "lbrynet.sqlite")) + second_client_blob_manager = BlobManager( + self.loop, second_client_dir, second_client_storage, second_client_conf + ) server_from_second_client = KademliaPeer(self.loop, "127.0.0.1", b'1' * 48, tcp_port=33333) await second_client_storage.open() @@ -102,7 +109,7 @@ class TestBlobExchange(BlobExchangeTestBase): await self._add_blob_to_server(blob_hash, mock_blob_bytes) - second_client_blob = self.client_blob_manager.get_blob(blob_hash) + second_client_blob = second_client_blob_manager.get_blob(blob_hash) # download the blob await asyncio.gather( @@ -112,9 +119,65 @@ class TestBlobExchange(BlobExchangeTestBase): ), self._test_transfer_blob(blob_hash) ) - await second_client_blob.finished_writing.wait() + await second_client_blob.verified.wait() self.assertEqual(second_client_blob.get_is_verified(), True) + async def test_blob_writers_concurrency(self): + blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed" + mock_blob_bytes = b'1' * ((2 * 2 ** 20) - 1) + blob = self.server_blob_manager.get_blob(blob_hash) + write_blob = blob._write_blob + write_called_count = 0 + + def wrap_write_blob(blob_bytes): + nonlocal write_called_count + write_called_count += 1 + write_blob(blob_bytes) + blob._write_blob = wrap_write_blob + + writer1 = blob.get_blob_writer(peer_port=1) + writer2 = blob.get_blob_writer(peer_port=2) + reader1_ctx_before_write = blob.reader_context() + + with self.assertRaises(OSError): + blob.get_blob_writer(peer_port=2) + with self.assertRaises(OSError): + with blob.reader_context(): + pass + + blob.set_length(len(mock_blob_bytes)) + results = {} + + def check_finished_callback(writer, num): + def inner(writer_future: asyncio.Future): + results[num] = writer_future.result() + writer.finished.add_done_callback(inner) + + check_finished_callback(writer1, 1) + check_finished_callback(writer2, 2) + + def write_task(writer): + async def _inner(): + writer.write(mock_blob_bytes) + return self.loop.create_task(_inner()) + + await asyncio.gather(write_task(writer1), write_task(writer2), loop=self.loop) + + self.assertDictEqual({1: mock_blob_bytes, 2: mock_blob_bytes}, results) + self.assertEqual(1, write_called_count) + self.assertTrue(blob.get_is_verified()) + self.assertDictEqual({}, blob.writers) + + with reader1_ctx_before_write as f: + self.assertEqual(mock_blob_bytes, f.read()) + with blob.reader_context() as f: + self.assertEqual(mock_blob_bytes, f.read()) + with blob.reader_context() as f: + blob.close() + with self.assertRaises(ValueError): + f.read() + self.assertListEqual([], blob.readers) + async def test_host_different_blobs_to_multiple_peers_at_once(self): blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed" mock_blob_bytes = b'1' * ((2 * 2 ** 20) - 1) @@ -124,9 +187,12 @@ class TestBlobExchange(BlobExchangeTestBase): second_client_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, second_client_dir) + second_client_conf = Config() - second_client_storage = SQLiteStorage(Config(), os.path.join(second_client_dir, "lbrynet.sqlite")) - second_client_blob_manager = BlobFileManager(self.loop, second_client_dir, second_client_storage) + second_client_storage = SQLiteStorage(second_client_conf, os.path.join(second_client_dir, "lbrynet.sqlite")) + second_client_blob_manager = BlobManager( + self.loop, second_client_dir, 
second_client_storage, second_client_conf + ) server_from_second_client = KademliaPeer(self.loop, "127.0.0.1", b'1' * 48, tcp_port=33333) await second_client_storage.open() @@ -143,7 +209,7 @@ class TestBlobExchange(BlobExchangeTestBase): server_from_second_client.tcp_port, 2, 3 ), self._test_transfer_blob(sd_hash), - second_client_blob.finished_writing.wait() + second_client_blob.verified.wait() ) self.assertEqual(second_client_blob.get_is_verified(), True) diff --git a/tests/unit/core/test_utils.py b/tests/unit/core/test_utils.py index b4db86e8f..fb783628c 100644 --- a/tests/unit/core/test_utils.py +++ b/tests/unit/core/test_utils.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- -from lbrynet import utils - import unittest +import asyncio +from lbrynet import utils +from torba.testcase import AsyncioTestCase class CompareVersionTest(unittest.TestCase): @@ -61,3 +62,76 @@ class SdHashTests(unittest.TestCase): } } self.assertIsNone(utils.get_sd_hash(claim)) + + +class CacheConcurrentDecoratorTests(AsyncioTestCase): + def setUp(self): + self.called = [] + self.finished = [] + self.counter = 0 + + @utils.cache_concurrent + async def foo(self, arg1, arg2=None, delay=1): + self.called.append((arg1, arg2, delay)) + await asyncio.sleep(delay, loop=self.loop) + self.counter += 1 + self.finished.append((arg1, arg2, delay)) + return object() + + async def test_gather_duplicates(self): + result = await asyncio.gather( + self.loop.create_task(self.foo(1)), self.loop.create_task(self.foo(1)), loop=self.loop + ) + self.assertEqual(1, len(self.called)) + self.assertEqual(1, len(self.finished)) + self.assertEqual(1, self.counter) + self.assertIs(result[0], result[1]) + self.assertEqual(2, len(result)) + + async def test_one_cancelled_all_cancel(self): + t1 = self.loop.create_task(self.foo(1)) + self.loop.call_later(0.1, t1.cancel) + + with self.assertRaises(asyncio.CancelledError): + await asyncio.gather( + t1, self.loop.create_task(self.foo(1)), loop=self.loop + ) + self.assertEqual(1, len(self.called)) + self.assertEqual(0, len(self.finished)) + self.assertEqual(0, self.counter) + + async def test_error_after_success(self): + def cause_type_error(): + self.counter = "" + + self.loop.call_later(0.1, cause_type_error) + + t1 = self.loop.create_task(self.foo(1)) + t2 = self.loop.create_task(self.foo(1)) + + with self.assertRaises(TypeError): + await t2 + self.assertEqual(1, len(self.called)) + self.assertEqual(0, len(self.finished)) + self.assertTrue(t1.done()) + self.assertEqual("", self.counter) + + # test that the task is run fresh, it should not error + self.counter = 0 + t3 = self.loop.create_task(self.foo(1)) + self.assertTrue((await t3)) + self.assertEqual(1, self.counter) + + # the previously failed call should still raise if awaited + with self.assertRaises(TypeError): + await t1 + + self.assertEqual(1, self.counter) + + async def test_break_it(self): + t1 = self.loop.create_task(self.foo(1)) + t2 = self.loop.create_task(self.foo(1)) + t3 = self.loop.create_task(self.foo(2, delay=0)) + t3.add_done_callback(lambda _: t2.cancel()) + with self.assertRaises(asyncio.CancelledError): + await asyncio.gather(t1, t2, t3) diff --git a/tests/unit/database/test_SQLiteStorage.py b/tests/unit/database/test_SQLiteStorage.py index 0f380b17b..3a7564faa 100644 --- a/tests/unit/database/test_SQLiteStorage.py +++ b/tests/unit/database/test_SQLiteStorage.py @@ -7,7 +7,7 @@ from torba.testcase import AsyncioTestCase from lbrynet.conf import Config from lbrynet.extras.daemon.storage import SQLiteStorage from 
lbrynet.blob.blob_info import BlobInfo -from lbrynet.blob.blob_manager import BlobFileManager +from lbrynet.blob.blob_manager import BlobManager from lbrynet.stream.descriptor import StreamDescriptor from tests.test_utils import random_lbry_hash @@ -68,17 +68,18 @@ fake_claim_info = { class StorageTest(AsyncioTestCase): async def asyncSetUp(self): - self.storage = SQLiteStorage(Config(), ':memory:') + self.conf = Config() + self.storage = SQLiteStorage(self.conf, ':memory:') self.blob_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, self.blob_dir) - self.blob_manager = BlobFileManager(asyncio.get_event_loop(), self.blob_dir, self.storage) + self.blob_manager = BlobManager(asyncio.get_event_loop(), self.blob_dir, self.storage, self.conf) await self.storage.open() async def asyncTearDown(self): await self.storage.close() async def store_fake_blob(self, blob_hash, length=100): - await self.storage.add_completed_blob(blob_hash, length) + await self.storage.add_blobs((blob_hash, length), finished=True) async def store_fake_stream(self, stream_hash, blobs=None, file_name="fake_file", key="DEADBEEF"): blobs = blobs or [BlobInfo(1, 100, "DEADBEEF", random_lbry_hash())] diff --git a/tests/unit/dht/protocol/test_data_store.py b/tests/unit/dht/protocol/test_data_store.py index 54c58cce9..f8d264ffd 100644 --- a/tests/unit/dht/protocol/test_data_store.py +++ b/tests/unit/dht/protocol/test_data_store.py @@ -1,12 +1,13 @@ import asyncio -from torba.testcase import AsyncioTestCase +from unittest import mock, TestCase from lbrynet.dht.protocol.data_store import DictDataStore from lbrynet.dht.peer import PeerManager -class DataStoreTests(AsyncioTestCase): +class DataStoreTests(TestCase): def setUp(self): - self.loop = asyncio.get_event_loop() + self.loop = mock.Mock(spec=asyncio.BaseEventLoop) + self.loop.time = lambda: 0.0 self.peer_manager = PeerManager(self.loop) self.data_store = DictDataStore(self.loop, self.peer_manager) diff --git a/tests/unit/dht/test_blob_announcer.py b/tests/unit/dht/test_blob_announcer.py index 654484af7..85dcd0946 100644 --- a/tests/unit/dht/test_blob_announcer.py +++ b/tests/unit/dht/test_blob_announcer.py @@ -78,8 +78,7 @@ class TestBlobAnnouncer(AsyncioTestCase): blob2 = binascii.hexlify(b'2' * 48).decode() async with self._test_network_context(): - await self.storage.add_completed_blob(blob1, 1024) - await self.storage.add_completed_blob(blob2, 1024) + await self.storage.add_blobs((blob1, 1024), (blob2, 1024), finished=True) await self.storage.db.execute( "update blob set next_announce_time=0, should_announce=1 where blob_hash in (?, ?)", (blob1, blob2) diff --git a/tests/unit/stream/test_assembler.py b/tests/unit/stream/test_assembler.py deleted file mode 100644 index ef636890f..000000000 --- a/tests/unit/stream/test_assembler.py +++ /dev/null @@ -1,117 +0,0 @@ -import os -import asyncio -import tempfile -import shutil - -from torba.testcase import AsyncioTestCase -from lbrynet.conf import Config -from lbrynet.blob.blob_file import MAX_BLOB_SIZE -from lbrynet.extras.daemon.storage import SQLiteStorage -from lbrynet.blob.blob_manager import BlobFileManager -from lbrynet.stream.assembler import StreamAssembler -from lbrynet.stream.descriptor import StreamDescriptor -from lbrynet.stream.stream_manager import StreamManager - - -class TestStreamAssembler(AsyncioTestCase): - def setUp(self): - self.loop = asyncio.get_event_loop() - self.key = b'deadbeef' * 4 - self.cleartext = b'test' - - async def test_create_and_decrypt_one_blob_stream(self, corrupt=False): - tmp_dir = 
tempfile.mkdtemp() - self.addCleanup(lambda: shutil.rmtree(tmp_dir)) - self.storage = SQLiteStorage(Config(), ":memory:") - await self.storage.open() - self.blob_manager = BlobFileManager(self.loop, tmp_dir, self.storage) - - download_dir = tempfile.mkdtemp() - self.addCleanup(lambda: shutil.rmtree(download_dir)) - - # create the stream - file_path = os.path.join(tmp_dir, "test_file") - with open(file_path, 'wb') as f: - f.write(self.cleartext) - - sd = await StreamDescriptor.create_stream(self.loop, tmp_dir, file_path, key=self.key) - - # copy blob files - sd_hash = sd.calculate_sd_hash() - shutil.copy(os.path.join(tmp_dir, sd_hash), os.path.join(download_dir, sd_hash)) - for blob_info in sd.blobs: - if blob_info.blob_hash: - shutil.copy(os.path.join(tmp_dir, blob_info.blob_hash), os.path.join(download_dir, blob_info.blob_hash)) - if corrupt and blob_info.length == MAX_BLOB_SIZE: - with open(os.path.join(download_dir, blob_info.blob_hash), "rb+") as handle: - handle.truncate() - handle.flush() - - downloader_storage = SQLiteStorage(Config(), os.path.join(download_dir, "lbrynet.sqlite")) - await downloader_storage.open() - - # add the blobs to the blob table (this would happen upon a blob download finishing) - downloader_blob_manager = BlobFileManager(self.loop, download_dir, downloader_storage) - descriptor = await downloader_blob_manager.get_stream_descriptor(sd_hash) - - # assemble the decrypted file - assembler = StreamAssembler(self.loop, downloader_blob_manager, descriptor.sd_hash) - await assembler.assemble_decrypted_stream(download_dir) - if corrupt: - return self.assertFalse(os.path.isfile(os.path.join(download_dir, "test_file"))) - - with open(os.path.join(download_dir, "test_file"), "rb") as f: - decrypted = f.read() - self.assertEqual(decrypted, self.cleartext) - self.assertEqual(True, self.blob_manager.get_blob(sd_hash).get_is_verified()) - self.assertEqual(True, self.blob_manager.get_blob(descriptor.blobs[0].blob_hash).get_is_verified()) - # its all blobs + sd blob - last blob, which is the same size as descriptor.blobs - self.assertEqual(len(descriptor.blobs), len(await downloader_storage.get_all_finished_blobs())) - self.assertEqual( - [descriptor.sd_hash, descriptor.blobs[0].blob_hash], await downloader_storage.get_blobs_to_announce() - ) - - await downloader_storage.close() - await self.storage.close() - - async def test_create_and_decrypt_multi_blob_stream(self): - self.cleartext = b'test\n' * 20000000 - await self.test_create_and_decrypt_one_blob_stream() - - async def test_create_and_decrypt_padding(self): - for i in range(16): - self.cleartext = os.urandom((MAX_BLOB_SIZE*2) + i) - await self.test_create_and_decrypt_one_blob_stream() - - for i in range(16): - self.cleartext = os.urandom((MAX_BLOB_SIZE*2) - i) - await self.test_create_and_decrypt_one_blob_stream() - - async def test_create_and_decrypt_random(self): - self.cleartext = os.urandom(20000000) - await self.test_create_and_decrypt_one_blob_stream() - - async def test_create_managed_stream_announces(self): - # setup a blob manager - storage = SQLiteStorage(Config(), ":memory:") - await storage.open() - tmp_dir = tempfile.mkdtemp() - self.addCleanup(lambda: shutil.rmtree(tmp_dir)) - blob_manager = BlobFileManager(self.loop, tmp_dir, storage) - stream_manager = StreamManager(self.loop, Config(), blob_manager, None, storage, None) - # create the stream - download_dir = tempfile.mkdtemp() - self.addCleanup(lambda: shutil.rmtree(download_dir)) - file_path = os.path.join(download_dir, "test_file") - with 
open(file_path, 'wb') as f: - f.write(b'testtest') - - stream = await stream_manager.create_stream(file_path) - self.assertEqual( - [stream.sd_hash, stream.descriptor.blobs[0].blob_hash], - await storage.get_blobs_to_announce()) - - async def test_create_truncate_and_handle_stream(self): - self.cleartext = b'potato' * 1337 * 5279 - # The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated - await asyncio.wait_for(self.test_create_and_decrypt_one_blob_stream(corrupt=True), timeout=5) diff --git a/tests/unit/stream/test_downloader.py b/tests/unit/stream/test_downloader.py deleted file mode 100644 index d97444c0c..000000000 --- a/tests/unit/stream/test_downloader.py +++ /dev/null @@ -1,102 +0,0 @@ -import os -import time -import unittest -from unittest import mock -import asyncio - -from lbrynet.blob_exchange.serialization import BlobResponse -from lbrynet.blob_exchange.server import BlobServerProtocol -from lbrynet.conf import Config -from lbrynet.stream.descriptor import StreamDescriptor -from lbrynet.stream.downloader import StreamDownloader -from lbrynet.dht.node import Node -from lbrynet.dht.peer import KademliaPeer -from lbrynet.blob.blob_file import MAX_BLOB_SIZE -from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase - - -class TestStreamDownloader(BlobExchangeTestBase): - async def setup_stream(self, blob_count: int = 10): - self.stream_bytes = b'' - for _ in range(blob_count): - self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1)) - # create the stream - file_path = os.path.join(self.server_dir, "test_file") - with open(file_path, 'wb') as f: - f.write(self.stream_bytes) - descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path) - self.sd_hash = descriptor.calculate_sd_hash() - conf = Config(data_dir=self.server_dir, wallet_dir=self.server_dir, download_dir=self.server_dir, - reflector_servers=[]) - self.downloader = StreamDownloader(self.loop, conf, self.client_blob_manager, self.sd_hash) - - async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None): - await self.setup_stream(blob_count) - mock_node = mock.Mock(spec=Node) - - def _mock_accumulate_peers(q1, q2): - async def _task(): - pass - q2.put_nowait([self.server_from_client]) - return q2, self.loop.create_task(_task()) - - mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers - self.downloader.download(mock_node) - await self.downloader.stream_finished_event.wait() - self.assertTrue(self.downloader.stream_handle.closed) - self.assertTrue(os.path.isfile(self.downloader.output_path)) - self.downloader.stop() - self.assertIs(self.downloader.stream_handle, None) - self.assertTrue(os.path.isfile(self.downloader.output_path)) - with open(self.downloader.output_path, 'rb') as f: - self.assertEqual(f.read(), self.stream_bytes) - await asyncio.sleep(0.01) - - async def test_transfer_stream(self): - await self._test_transfer_stream(10) - - @unittest.SkipTest - async def test_transfer_hundred_blob_stream(self): - await self._test_transfer_stream(100) - - async def test_transfer_stream_bad_first_peer_good_second(self): - await self.setup_stream(2) - - mock_node = mock.Mock(spec=Node) - q = asyncio.Queue() - - bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334) - - def _mock_accumulate_peers(q1, q2): - async def _task(): - pass - - q2.put_nowait([bad_peer]) - self.loop.call_later(1, q2.put_nowait, [self.server_from_client]) - return q2, 
self.loop.create_task(_task()) - - mock_node.accumulate_peers = _mock_accumulate_peers - - self.downloader.download(mock_node) - await self.downloader.stream_finished_event.wait() - self.assertTrue(os.path.isfile(self.downloader.output_path)) - with open(self.downloader.output_path, 'rb') as f: - self.assertEqual(f.read(), self.stream_bytes) - # self.assertIs(self.server_from_client.tcp_last_down, None) - # self.assertIsNot(bad_peer.tcp_last_down, None) - - async def test_client_chunked_response(self): - self.server.stop_server() - class ChunkedServerProtocol(BlobServerProtocol): - - def send_response(self, responses): - to_send = [] - while responses: - to_send.append(responses.pop()) - for byte in BlobResponse(to_send).serialize(): - self.transport.write(bytes([byte])) - self.server.server_protocol_class = ChunkedServerProtocol - self.server.start_server(33333, '127.0.0.1') - self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes)) - await asyncio.wait_for(self._test_transfer_stream(10), timeout=2) - self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes)) diff --git a/tests/unit/stream/test_managed_stream.py b/tests/unit/stream/test_managed_stream.py new file mode 100644 index 000000000..7b8d92ff7 --- /dev/null +++ b/tests/unit/stream/test_managed_stream.py @@ -0,0 +1,176 @@ +import os +import shutil +import unittest +from unittest import mock +import asyncio +from lbrynet.blob.blob_file import MAX_BLOB_SIZE +from lbrynet.blob_exchange.serialization import BlobResponse +from lbrynet.blob_exchange.server import BlobServerProtocol +from lbrynet.dht.node import Node +from lbrynet.dht.peer import KademliaPeer +from lbrynet.extras.daemon.storage import StoredStreamClaim +from lbrynet.stream.managed_stream import ManagedStream +from lbrynet.stream.descriptor import StreamDescriptor +from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase + + +def get_mock_node(loop): + mock_node = mock.Mock(spec=Node) + mock_node.joined = asyncio.Event(loop=loop) + mock_node.joined.set() + return mock_node + + +class TestManagedStream(BlobExchangeTestBase): + async def create_stream(self, blob_count: int = 10): + self.stream_bytes = b'' + for _ in range(blob_count): + self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1)) + # create the stream + file_path = os.path.join(self.server_dir, "test_file") + with open(file_path, 'wb') as f: + f.write(self.stream_bytes) + descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path) + self.sd_hash = descriptor.calculate_sd_hash() + return descriptor + + async def setup_stream(self, blob_count: int = 10): + await self.create_stream(blob_count) + self.stream = ManagedStream( + self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir + ) + + async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None): + await self.setup_stream(blob_count) + mock_node = mock.Mock(spec=Node) + + def _mock_accumulate_peers(q1, q2): + async def _task(): + pass + q2.put_nowait([self.server_from_client]) + return q2, self.loop.create_task(_task()) + + mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers + await self.stream.setup(mock_node, save_file=True) + await self.stream.finished_writing.wait() + self.assertTrue(os.path.isfile(self.stream.full_path)) + self.stream.stop_download() + self.assertTrue(os.path.isfile(self.stream.full_path)) + with open(self.stream.full_path, 'rb') as f: + self.assertEqual(f.read(), 
self.stream_bytes) + await asyncio.sleep(0.01) + + async def test_transfer_stream(self): + await self._test_transfer_stream(10) + + @unittest.SkipTest + async def test_transfer_hundred_blob_stream(self): + await self._test_transfer_stream(100) + + async def test_transfer_stream_bad_first_peer_good_second(self): + await self.setup_stream(2) + + mock_node = mock.Mock(spec=Node) + q = asyncio.Queue() + + bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334) + + def _mock_accumulate_peers(q1, q2): + async def _task(): + pass + + q2.put_nowait([bad_peer]) + self.loop.call_later(1, q2.put_nowait, [self.server_from_client]) + return q2, self.loop.create_task(_task()) + + mock_node.accumulate_peers = _mock_accumulate_peers + + await self.stream.setup(mock_node, save_file=True) + await self.stream.finished_writing.wait() + self.assertTrue(os.path.isfile(self.stream.full_path)) + with open(self.stream.full_path, 'rb') as f: + self.assertEqual(f.read(), self.stream_bytes) + # self.assertIs(self.server_from_client.tcp_last_down, None) + # self.assertIsNot(bad_peer.tcp_last_down, None) + + async def test_client_chunked_response(self): + self.server.stop_server() + + class ChunkedServerProtocol(BlobServerProtocol): + def send_response(self, responses): + to_send = [] + while responses: + to_send.append(responses.pop()) + for byte in BlobResponse(to_send).serialize(): + self.transport.write(bytes([byte])) + self.server.server_protocol_class = ChunkedServerProtocol + self.server.start_server(33333, '127.0.0.1') + self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes)) + await asyncio.wait_for(self._test_transfer_stream(10), timeout=2) + self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes)) + + async def test_create_and_decrypt_one_blob_stream(self, blobs=1, corrupt=False): + descriptor = await self.create_stream(blobs) + + # copy blob files + shutil.copy(os.path.join(self.server_blob_manager.blob_dir, self.sd_hash), + os.path.join(self.client_blob_manager.blob_dir, self.sd_hash)) + self.stream = ManagedStream(self.loop, self.client_config, self.client_blob_manager, self.sd_hash, + self.client_dir) + + for blob_info in descriptor.blobs[:-1]: + shutil.copy(os.path.join(self.server_blob_manager.blob_dir, blob_info.blob_hash), + os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash)) + if corrupt and blob_info.length == MAX_BLOB_SIZE: + with open(os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash), "rb+") as handle: + handle.truncate() + handle.flush() + await self.stream.setup() + await self.stream.finished_writing.wait() + if corrupt: + return self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file"))) + + with open(os.path.join(self.client_dir, "test_file"), "rb") as f: + decrypted = f.read() + self.assertEqual(decrypted, self.stream_bytes) + + self.assertEqual(True, self.client_blob_manager.get_blob(self.sd_hash).get_is_verified()) + self.assertEqual( + True, self.client_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified() + ) + # + # # its all blobs + sd blob - last blob, which is the same size as descriptor.blobs + # self.assertEqual(len(descriptor.blobs), len(await downloader_storage.get_all_finished_blobs())) + # self.assertEqual( + # [descriptor.sd_hash, descriptor.blobs[0].blob_hash], await downloader_storage.get_blobs_to_announce() + # ) + # + # await downloader_storage.close() + # await self.storage.close() + + async def 
test_create_and_decrypt_multi_blob_stream(self): + await self.test_create_and_decrypt_one_blob_stream(10) + + # async def test_create_managed_stream_announces(self): + # # setup a blob manager + # storage = SQLiteStorage(Config(), ":memory:") + # await storage.open() + # tmp_dir = tempfile.mkdtemp() + # self.addCleanup(lambda: shutil.rmtree(tmp_dir)) + # blob_manager = BlobManager(self.loop, tmp_dir, storage) + # stream_manager = StreamManager(self.loop, Config(), blob_manager, None, storage, None) + # # create the stream + # download_dir = tempfile.mkdtemp() + # self.addCleanup(lambda: shutil.rmtree(download_dir)) + # file_path = os.path.join(download_dir, "test_file") + # with open(file_path, 'wb') as f: + # f.write(b'testtest') + # + # stream = await stream_manager.create_stream(file_path) + # self.assertEqual( + # [stream.sd_hash, stream.descriptor.blobs[0].blob_hash], + # await storage.get_blobs_to_announce()) + + # async def test_create_truncate_and_handle_stream(self): + # # The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated + # await asyncio.wait_for(self.test_create_and_decrypt_one_blob_stream(corrupt=True), timeout=5) diff --git a/tests/unit/stream/test_reflector.py b/tests/unit/stream/test_reflector.py index 95b87da08..b5cdf2960 100644 --- a/tests/unit/stream/test_reflector.py +++ b/tests/unit/stream/test_reflector.py @@ -5,7 +5,7 @@ import shutil from torba.testcase import AsyncioTestCase from lbrynet.conf import Config from lbrynet.extras.daemon.storage import SQLiteStorage -from lbrynet.blob.blob_manager import BlobFileManager +from lbrynet.blob.blob_manager import BlobManager from lbrynet.stream.stream_manager import StreamManager from lbrynet.stream.reflector.server import ReflectorServer @@ -18,16 +18,18 @@ class TestStreamAssembler(AsyncioTestCase): tmp_dir = tempfile.mkdtemp() self.addCleanup(lambda: shutil.rmtree(tmp_dir)) - self.storage = SQLiteStorage(Config(), os.path.join(tmp_dir, "lbrynet.sqlite")) + self.conf = Config() + self.storage = SQLiteStorage(self.conf, os.path.join(tmp_dir, "lbrynet.sqlite")) await self.storage.open() - self.blob_manager = BlobFileManager(self.loop, tmp_dir, self.storage) + self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.conf) self.stream_manager = StreamManager(self.loop, Config(), self.blob_manager, None, self.storage, None) server_tmp_dir = tempfile.mkdtemp() self.addCleanup(lambda: shutil.rmtree(server_tmp_dir)) - self.server_storage = SQLiteStorage(Config(), os.path.join(server_tmp_dir, "lbrynet.sqlite")) + self.server_conf = Config() + self.server_storage = SQLiteStorage(self.server_conf, os.path.join(server_tmp_dir, "lbrynet.sqlite")) await self.server_storage.open() - self.server_blob_manager = BlobFileManager(self.loop, server_tmp_dir, self.server_storage) + self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_conf) download_dir = tempfile.mkdtemp() self.addCleanup(lambda: shutil.rmtree(download_dir)) diff --git a/tests/unit/stream/test_stream_descriptor.py b/tests/unit/stream/test_stream_descriptor.py index 0f095e39d..479ab48ac 100644 --- a/tests/unit/stream/test_stream_descriptor.py +++ b/tests/unit/stream/test_stream_descriptor.py @@ -9,7 +9,7 @@ from torba.testcase import AsyncioTestCase from lbrynet.conf import Config from lbrynet.error import InvalidStreamDescriptorError from lbrynet.extras.daemon.storage import SQLiteStorage -from lbrynet.blob.blob_manager import BlobFileManager +from lbrynet.blob.blob_manager 
import BlobManager from lbrynet.stream.descriptor import StreamDescriptor @@ -20,9 +20,10 @@ class TestStreamDescriptor(AsyncioTestCase): self.cleartext = os.urandom(20000000) self.tmp_dir = tempfile.mkdtemp() self.addCleanup(lambda: shutil.rmtree(self.tmp_dir)) - self.storage = SQLiteStorage(Config(), ":memory:") + self.conf = Config() + self.storage = SQLiteStorage(self.conf, ":memory:") await self.storage.open() - self.blob_manager = BlobFileManager(self.loop, self.tmp_dir, self.storage) + self.blob_manager = BlobManager(self.loop, self.tmp_dir, self.storage, self.conf) self.file_path = os.path.join(self.tmp_dir, "test_file") with open(self.file_path, 'wb') as f: @@ -83,9 +84,10 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase): loop = asyncio.get_event_loop() tmp_dir = tempfile.mkdtemp() self.addCleanup(lambda: shutil.rmtree(tmp_dir)) - storage = SQLiteStorage(Config(), ":memory:") + self.conf = Config() + storage = SQLiteStorage(self.conf, ":memory:") await storage.open() - blob_manager = BlobFileManager(loop, tmp_dir, storage) + blob_manager = BlobManager(loop, tmp_dir, storage, self.conf) sd_bytes = b'{"stream_name": "4f62616d6120446f6e6b65792d322e73746c", "blobs": [{"length": 1153488, "blob_num' \ b'": 0, "blob_hash": "9fa32a249ce3f2d4e46b78599800f368b72f2a7f22b81df443c7f6bdbef496bd61b4c0079c7' \ @@ -99,7 +101,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase): blob = blob_manager.get_blob(sd_hash) blob.set_length(len(sd_bytes)) - writer = blob.open_for_writing() + writer = blob.get_blob_writer() writer.write(sd_bytes) await blob.verified.wait() descriptor = await StreamDescriptor.from_stream_descriptor_blob( @@ -116,7 +118,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase): sd_hash = '9313d1807551186126acc3662e74d9de29cede78d4f133349ace846273ef116b9bb86be86c54509eb84840e4b032f6b2' with open(os.path.join(tmp_dir, sd_hash), 'wb') as handle: handle.write(b'doesnt work') - blob = BlobFile(loop, tmp_dir, sd_hash) + blob = BlobFile(loop, sd_hash, blob_directory=tmp_dir) self.assertTrue(blob.file_exists) self.assertIsNotNone(blob.length) with self.assertRaises(InvalidStreamDescriptorError): diff --git a/tests/unit/stream/test_stream_manager.py b/tests/unit/stream/test_stream_manager.py index b2a57c652..9fd957df1 100644 --- a/tests/unit/stream/test_stream_manager.py +++ b/tests/unit/stream/test_stream_manager.py @@ -192,8 +192,8 @@ class TestStreamManager(BlobExchangeTestBase): def check_post(event): self.assertEqual(event['event'], 'Time To First Bytes') self.assertEqual(event['properties']['error'], 'DownloadSDTimeout') - self.assertEqual(event['properties']['tried_peers_count'], None) - self.assertEqual(event['properties']['active_peer_count'], None) + self.assertEqual(event['properties']['tried_peers_count'], 0) + self.assertEqual(event['properties']['active_peer_count'], 0) self.assertEqual(event['properties']['use_fixed_peers'], False) self.assertEqual(event['properties']['added_fixed_peers'], False) self.assertEqual(event['properties']['fixed_peer_delay'], None) @@ -213,10 +213,10 @@ class TestStreamManager(BlobExchangeTestBase): self.stream_manager.analytics_manager._post = check_post - self.assertSetEqual(self.stream_manager.streams, set()) + self.assertDictEqual(self.stream_manager.streams, {}) stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) stream_hash = stream.stream_hash - self.assertSetEqual(self.stream_manager.streams, {stream}) + self.assertDictEqual(self.stream_manager.streams, {stream.sd_hash: 
stream}) self.assertTrue(stream.running) self.assertFalse(stream.finished) self.assertTrue(os.path.isfile(os.path.join(self.client_dir, "test_file"))) @@ -236,7 +236,7 @@ class TestStreamManager(BlobExchangeTestBase): self.assertEqual(stored_status, "stopped") await self.stream_manager.start_stream(stream) - await stream.downloader.stream_finished_event.wait() + await stream.finished_writing.wait() await asyncio.sleep(0, loop=self.loop) self.assertTrue(stream.finished) self.assertFalse(stream.running) @@ -247,7 +247,7 @@ class TestStreamManager(BlobExchangeTestBase): self.assertEqual(stored_status, "finished") await self.stream_manager.delete_stream(stream, True) - self.assertSetEqual(self.stream_manager.streams, set()) + self.assertDictEqual(self.stream_manager.streams, {}) self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file"))) stored_status = await self.client_storage.run_and_return_one_or_none( "select status from file where stream_hash=?", stream_hash @@ -257,7 +257,7 @@ class TestStreamManager(BlobExchangeTestBase): async def _test_download_error_on_start(self, expected_error, timeout=None): with self.assertRaises(expected_error): - await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout=timeout) + await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout) async def _test_download_error_analytics_on_start(self, expected_error, timeout=None): received = [] @@ -321,9 +321,9 @@ class TestStreamManager(BlobExchangeTestBase): await self.setup_stream_manager(old_sort=old_sort) self.stream_manager.analytics_manager._post = check_post - self.assertSetEqual(self.stream_manager.streams, set()) + self.assertDictEqual(self.stream_manager.streams, {}) stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) - await stream.downloader.stream_finished_event.wait() + await stream.finished_writing.wait() await asyncio.sleep(0, loop=self.loop) self.stream_manager.stop() self.client_blob_manager.stop() @@ -333,8 +333,11 @@ class TestStreamManager(BlobExchangeTestBase): await self.client_blob_manager.setup() await self.stream_manager.start() self.assertEqual(1, len(self.stream_manager.streams)) - self.assertEqual(stream.sd_hash, list(self.stream_manager.streams)[0].sd_hash) - self.assertEqual('stopped', list(self.stream_manager.streams)[0].status) + self.assertListEqual([self.sd_hash], list(self.stream_manager.streams.keys())) + for blob_hash in [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]]: + blob_status = await self.client_storage.get_blob_status(blob_hash) + self.assertEqual('pending', blob_status) + self.assertEqual('stopped', self.stream_manager.streams[self.sd_hash].status) sd_blob = self.client_blob_manager.get_blob(stream.sd_hash) self.assertTrue(sd_blob.file_exists)
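
The tests updated above exercise the migrated blob API: BlobFileManager becomes BlobManager (now constructed with a Config), blob.open_for_writing() becomes blob.get_blob_writer(), and StreamManager.streams becomes a dict keyed by sd_hash. The snippet below is a minimal, illustrative sketch of the write-and-verify flow, reconstructed only from the calls shown in the test_stream_descriptor.py hunk of this patch; the helper name, the temporary-directory setup, and the example arguments are assumptions for illustration and are not part of the patch itself.

    import asyncio
    import tempfile

    from lbrynet.conf import Config
    from lbrynet.extras.daemon.storage import SQLiteStorage
    from lbrynet.blob.blob_manager import BlobManager


    async def write_and_verify_blob(blob_hash: str, data: bytes):
        # hypothetical helper mirroring the test flow in this patch
        loop = asyncio.get_event_loop()
        conf = Config()
        blob_dir = tempfile.mkdtemp()
        storage = SQLiteStorage(conf, ":memory:")
        await storage.open()

        # BlobManager now takes the Config as an extra argument (was BlobFileManager)
        blob_manager = BlobManager(loop, blob_dir, storage, conf)

        blob = blob_manager.get_blob(blob_hash)
        blob.set_length(len(data))
        writer = blob.get_blob_writer()   # replaces the old blob.open_for_writing()
        writer.write(data)
        await blob.verified.wait()        # set once the written bytes hash to blob_hash
        await storage.close()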