From 676f0015aa430048fff30b80ea00145fc05003d9 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Sat, 30 Mar 2019 20:17:42 -0400 Subject: [PATCH] refactor BlobFile into AbstractBlob, BlobFile, and BlobBuffer classes --- lbrynet/blob/blob_file.py | 348 +++++++++++++++++----------- lbrynet/blob/blob_manager.py | 28 ++- lbrynet/blob_exchange/client.py | 22 +- lbrynet/blob_exchange/downloader.py | 34 ++- lbrynet/extras/daemon/Components.py | 2 +- lbrynet/stream/descriptor.py | 22 +- lbrynet/stream/reflector/server.py | 8 +- 7 files changed, 290 insertions(+), 174 deletions(-) diff --git a/lbrynet/blob/blob_file.py b/lbrynet/blob/blob_file.py index 2206b886e..4890e2ea2 100644 --- a/lbrynet/blob/blob_file.py +++ b/lbrynet/blob/blob_file.py @@ -4,6 +4,8 @@ import asyncio import binascii import logging import typing +import contextlib +from io import BytesIO from cryptography.hazmat.primitives.ciphers import Cipher, modes from cryptography.hazmat.primitives.ciphers.algorithms import AES from cryptography.hazmat.primitives.padding import PKCS7 @@ -21,10 +23,6 @@ log = logging.getLogger(__name__) _hexmatch = re.compile("^[a-f,0-9]+$") -def is_valid_hashcharacter(char: str) -> bool: - return len(char) == 1 and _hexmatch.match(char) - - def is_valid_blobhash(blobhash: str) -> bool: """Checks whether the blobhash is the correct length and contains only valid characters (0-9, a-f) @@ -46,143 +44,51 @@ def encrypt_blob_bytes(key: bytes, iv: bytes, unencrypted: bytes) -> typing.Tupl return encrypted, digest.hexdigest() -class BlobFile: - """ - A chunk of data available on the network which is specified by a hashsum +def decrypt_blob_bytes(read_handle: typing.BinaryIO, length: int, key: bytes, iv: bytes) -> bytes: + buff = read_handle.read() + if len(buff) != length: + raise ValueError("unexpected length") + cipher = Cipher(AES(key), modes.CBC(iv), backend=backend) + unpadder = PKCS7(AES.block_size).unpadder() + decryptor = cipher.decryptor() + return unpadder.update(decryptor.update(buff) + decryptor.finalize()) + unpadder.finalize() - This class is used to create blobs on the local filesystem - when we already know the blob hash before hand (i.e., when downloading blobs) - Also can be used for reading from blobs on the local filesystem - """ - def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, blob_hash: str, - length: typing.Optional[int] = None, - blob_completed_callback: typing.Optional[typing.Callable[['BlobFile'], typing.Awaitable]] = None): +class AbstractBlob: + """ + A chunk of data (up to 2MB) available on the network which is specified by a sha384 hash + + This class is non-io specific + """ + def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None, + blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], typing.Awaitable]] = None, + blob_directory: typing.Optional[str] = None): if not is_valid_blobhash(blob_hash): raise InvalidBlobHashError(blob_hash) + self.loop = loop self.blob_hash = blob_hash self.length = length - self.blob_dir = blob_dir - self.file_path = os.path.join(blob_dir, self.blob_hash) - self.writers: typing.List[HashBlobWriter] = [] - - self.verified: asyncio.Event = asyncio.Event(loop=self.loop) - self.finished_writing = asyncio.Event(loop=loop) - self.blob_write_lock = asyncio.Lock(loop=loop) - if self.file_exists: - length = int(os.stat(os.path.join(blob_dir, blob_hash)).st_size) - self.length = length - self.verified.set() - self.finished_writing.set() - self.saved_verified_blob = False self.blob_completed_callback = blob_completed_callback + self.blob_directory = blob_directory - @property - def file_exists(self): - return os.path.isfile(self.file_path) + self.writers: typing.List[HashBlobWriter] = [] + self.verified: asyncio.Event = asyncio.Event(loop=self.loop) + self.writing: asyncio.Event = asyncio.Event(loop=self.loop) - def writer_finished(self, writer: HashBlobWriter): - def callback(finished: asyncio.Future): - try: - error = finished.exception() - except Exception as err: - error = err - if writer in self.writers: # remove this download attempt - self.writers.remove(writer) - if not error: # the blob downloaded, cancel all the other download attempts and set the result - while self.writers: - other = self.writers.pop() - other.finished.cancel() - t = self.loop.create_task(self.save_verified_blob(writer, finished.result())) - t.add_done_callback(lambda *_: self.finished_writing.set()) - return - if isinstance(error, (InvalidBlobHashError, InvalidDataError)): - log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error)) - elif not isinstance(error, (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError)): - log.exception("something else") - raise error - return callback - - async def save_verified_blob(self, writer, verified_bytes: bytes): - def _save_verified(): - # log.debug(f"write blob file {self.blob_hash[:8]} from {writer.peer.address}") - if not self.saved_verified_blob and not os.path.isfile(self.file_path): - if self.get_length() == len(verified_bytes): - with open(self.file_path, 'wb') as write_handle: - write_handle.write(verified_bytes) - self.saved_verified_blob = True - else: - raise Exception("length mismatch") - - async with self.blob_write_lock: - if self.verified.is_set(): - return - await self.loop.run_in_executor(None, _save_verified) - if self.blob_completed_callback: - await self.blob_completed_callback(self) - self.verified.set() - - def open_for_writing(self) -> HashBlobWriter: - if self.file_exists: - raise OSError(f"File already exists '{self.file_path}'") - fut = asyncio.Future(loop=self.loop) - writer = HashBlobWriter(self.blob_hash, self.get_length, fut) - self.writers.append(writer) - fut.add_done_callback(self.writer_finished(writer)) - return writer - - async def sendfile(self, writer: asyncio.StreamWriter) -> int: - """ - Read and send the file to the writer and return the number of bytes sent - """ - - with open(self.file_path, 'rb') as handle: - return await self.loop.sendfile(writer.transport, handle, count=self.get_length()) - - def close(self): - while self.writers: - self.writers.pop().finished.cancel() - - def delete(self): + def __del__(self): + if self.writers: + log.warning("%s not closed before being garbage collected", self.blob_hash) self.close() - self.saved_verified_blob = False - if os.path.isfile(self.file_path): - os.remove(self.file_path) - self.verified.clear() - self.finished_writing.clear() - self.length = None - def decrypt(self, key: bytes, iv: bytes) -> bytes: - """ - Decrypt a BlobFile to plaintext bytes - """ + @contextlib.contextmanager + def reader_context(self) -> typing.ContextManager[typing.BinaryIO]: + raise NotImplementedError() - with open(self.file_path, "rb") as f: - buff = f.read() - if len(buff) != self.length: - raise ValueError("unexpected length") - cipher = Cipher(AES(key), modes.CBC(iv), backend=backend) - unpadder = PKCS7(AES.block_size).unpadder() - decryptor = cipher.decryptor() - return unpadder.update(decryptor.update(buff) + decryptor.finalize()) + unpadder.finalize() + def _write_blob(self, blob_bytes: bytes): + raise NotImplementedError() - @classmethod - async def create_from_unencrypted(cls, loop: asyncio.BaseEventLoop, blob_dir: str, key: bytes, - iv: bytes, unencrypted: bytes, blob_num: int) -> BlobInfo: - """ - Create an encrypted BlobFile from plaintext bytes - """ - - blob_bytes, blob_hash = encrypt_blob_bytes(key, iv, unencrypted) - length = len(blob_bytes) - blob = cls(loop, blob_dir, blob_hash, length) - writer = blob.open_for_writing() - writer.write(blob_bytes) - await blob.verified.wait() - return BlobInfo(blob_num, length, binascii.hexlify(iv).decode(), blob_hash) - - def set_length(self, length): + def set_length(self, length) -> None: if self.length is not None and length == self.length: return if self.length is None and 0 <= length <= MAX_BLOB_SIZE: @@ -190,8 +96,192 @@ class BlobFile: return log.warning("Got an invalid length. Previous length: %s, Invalid length: %s", self.length, length) - def get_length(self): + def get_length(self) -> typing.Optional[int]: return self.length - def get_is_verified(self): + def get_is_verified(self) -> bool: return self.verified.is_set() + + def is_readable(self) -> bool: + return self.verified.is_set() + + def is_writeable(self) -> bool: + return not self.writing.is_set() + + def write_blob(self, blob_bytes: bytes): + if not self.is_writeable(): + raise OSError("cannot open blob for writing") + try: + self.writing.set() + self._write_blob(blob_bytes) + finally: + self.writing.clear() + + def close(self): + while self.writers: + self.writers.pop().finished.cancel() + + def delete(self): + self.close() + self.verified.clear() + self.length = None + + async def sendfile(self, writer: asyncio.StreamWriter) -> int: + """ + Read and send the file to the writer and return the number of bytes sent + """ + + if not self.is_readable(): + raise OSError('blob files cannot be read') + with self.reader_context() as handle: + return await self.loop.sendfile(writer.transport, handle, count=self.get_length()) + + def decrypt(self, key: bytes, iv: bytes) -> bytes: + """ + Decrypt a BlobFile to plaintext bytes + """ + + with self.reader_context() as reader: + return decrypt_blob_bytes(reader, self.length, key, iv) + + @classmethod + async def create_from_unencrypted(cls, loop: asyncio.BaseEventLoop, blob_dir: typing.Optional[str], key: bytes, + iv: bytes, unencrypted: bytes, blob_num: int) -> BlobInfo: + """ + Create an encrypted BlobFile from plaintext bytes + """ + + blob_bytes, blob_hash = encrypt_blob_bytes(key, iv, unencrypted) + length = len(blob_bytes) + blob = cls(loop, blob_hash, length, blob_directory=blob_dir) + writer = blob.get_blob_writer() + writer.write(blob_bytes) + await blob.verified.wait() + return BlobInfo(blob_num, length, binascii.hexlify(iv).decode(), blob_hash) + + async def save_verified_blob(self, verified_bytes: bytes): + if self.verified.is_set(): + return + if self.is_writeable(): + if self.get_length() == len(verified_bytes): + self._write_blob(verified_bytes) + self.verified.set() + if self.blob_completed_callback: + await self.blob_completed_callback(self) + else: + raise Exception("length mismatch") + + def get_blob_writer(self) -> HashBlobWriter: + fut = asyncio.Future(loop=self.loop) + writer = HashBlobWriter(self.blob_hash, self.get_length, fut) + self.writers.append(writer) + + def writer_finished_callback(finished: asyncio.Future): + try: + err = finished.exception() + if err: + raise err + verified_bytes = finished.result() + while self.writers: + other = self.writers.pop() + if other is not writer: + other.finished.cancel() + self.loop.create_task(self.save_verified_blob(verified_bytes)) + return + except (InvalidBlobHashError, InvalidDataError) as error: + log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error)) + except (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError) as error: + # log.exception("something else") + pass + finally: + if writer in self.writers: + self.writers.remove(writer) + + fut.add_done_callback(writer_finished_callback) + return writer + + +class BlobBuffer(AbstractBlob): + """ + An in-memory only blob + """ + def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None, + blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], typing.Awaitable]] = None, + blob_directory: typing.Optional[str] = None): + super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory) + self._verified_bytes: typing.Optional[BytesIO] = None + + @contextlib.contextmanager + def reader_context(self) -> typing.ContextManager[typing.BinaryIO]: + if not self.is_readable(): + raise OSError("cannot open blob for reading") + try: + yield self._verified_bytes + finally: + self._verified_bytes.close() + self._verified_bytes = None + self.verified.clear() + + def _write_blob(self, blob_bytes: bytes): + if self._verified_bytes: + raise OSError("already have bytes for blob") + self._verified_bytes = BytesIO(blob_bytes) + + +class BlobFile(AbstractBlob): + """ + A blob existing on the local file system + """ + def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None, + blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], typing.Awaitable]] = None, + blob_directory: typing.Optional[str] = None): + if not blob_directory or not os.path.isdir(blob_directory): + raise OSError(f"invalid blob directory '{blob_directory}'") + super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory) + if not is_valid_blobhash(blob_hash): + raise InvalidBlobHashError(blob_hash) + self.file_path = os.path.join(self.blob_directory, self.blob_hash) + if self.file_exists: + file_size = int(os.stat(self.file_path).st_size) + if length and length != file_size: + log.warning("expected %s to be %s bytes, file has %s", self.blob_hash, length, file_size) + self.delete() + else: + self.length = file_size + self.verified.set() + + @property + def file_exists(self): + return os.path.isfile(self.file_path) + + def is_writeable(self) -> bool: + return super().is_writeable() and not os.path.isfile(self.file_path) + + def get_blob_writer(self) -> HashBlobWriter: + if self.file_exists: + raise OSError(f"File already exists '{self.file_path}'") + return super().get_blob_writer() + + @contextlib.contextmanager + def reader_context(self) -> typing.ContextManager[typing.BinaryIO]: + handle = open(self.file_path, 'rb') + try: + yield handle + finally: + handle.close() + + def _write_blob(self, blob_bytes: bytes): + with open(self.file_path, 'wb') as f: + f.write(blob_bytes) + + def delete(self): + if os.path.isfile(self.file_path): + os.remove(self.file_path) + return super().delete() + + @classmethod + async def create_from_unencrypted(cls, loop: asyncio.BaseEventLoop, blob_dir: str, key: bytes, + iv: bytes, unencrypted: bytes, blob_num: int) -> BlobInfo: + if not blob_dir or not os.path.isdir(blob_dir): + raise OSError(f"cannot create blob in directory: '{blob_dir}'") + return await super().create_from_unencrypted(loop, blob_dir, key, iv, unencrypted, blob_num) diff --git a/lbrynet/blob/blob_manager.py b/lbrynet/blob/blob_manager.py index 1c3b09979..c45ec5c58 100644 --- a/lbrynet/blob/blob_manager.py +++ b/lbrynet/blob/blob_manager.py @@ -2,7 +2,7 @@ import os import typing import asyncio import logging -from lbrynet.blob.blob_file import BlobFile, is_valid_blobhash +from lbrynet.blob.blob_file import is_valid_blobhash, BlobFile, BlobBuffer, AbstractBlob from lbrynet.stream.descriptor import StreamDescriptor if typing.TYPE_CHECKING: @@ -14,7 +14,7 @@ log = logging.getLogger(__name__) class BlobManager: def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, storage: 'SQLiteStorage', - node_data_store: typing.Optional['DictDataStore'] = None): + node_data_store: typing.Optional['DictDataStore'] = None, save_blobs: bool = True): """ This class stores blobs on the hard disk @@ -27,16 +27,25 @@ class BlobManager: self._node_data_store = node_data_store self.completed_blob_hashes: typing.Set[str] = set() if not self._node_data_store\ else self._node_data_store.completed_blobs - self.blobs: typing.Dict[str, BlobFile] = {} + self.blobs: typing.Dict[str, AbstractBlob] = {} + self._save_blobs = save_blobs + + def get_blob_class(self): + if not self._save_blobs: + return BlobBuffer + return BlobFile async def setup(self) -> bool: def get_files_in_blob_dir() -> typing.Set[str]: + if not self.blob_dir: + return set() return { item.name for item in os.scandir(self.blob_dir) if is_valid_blobhash(item.name) } - in_blobfiles_dir = await self.loop.run_in_executor(None, get_files_in_blob_dir) - self.completed_blob_hashes.update(await self.storage.sync_missing_blobs(in_blobfiles_dir)) + to_add = await self.storage.sync_missing_blobs(in_blobfiles_dir) + if to_add: + self.completed_blob_hashes.update(to_add) return True def stop(self): @@ -50,17 +59,20 @@ class BlobManager: if length and self.blobs[blob_hash].length is None: self.blobs[blob_hash].set_length(length) else: - self.blobs[blob_hash] = BlobFile(self.loop, self.blob_dir, blob_hash, length, self.blob_completed) + self.blobs[blob_hash] = self.get_blob_class()(self.loop, blob_hash, length, self.blob_completed, + self.blob_dir) return self.blobs[blob_hash] def get_stream_descriptor(self, sd_hash): return StreamDescriptor.from_stream_descriptor_blob(self.loop, self.blob_dir, self.get_blob(sd_hash)) - async def blob_completed(self, blob: BlobFile): + async def blob_completed(self, blob: AbstractBlob): if blob.blob_hash is None: raise Exception("Blob hash is None") if not blob.length: raise Exception("Blob has a length of 0") + if isinstance(blob, BlobBuffer): # don't save blob buffers to the db / dont announce them + return if blob.blob_hash not in self.completed_blob_hashes: self.completed_blob_hashes.add(blob.blob_hash) await self.storage.add_completed_blob(blob.blob_hash, blob.length) @@ -75,7 +87,7 @@ class BlobManager: raise Exception("invalid blob hash to delete") if blob_hash not in self.blobs: - if os.path.isfile(os.path.join(self.blob_dir, blob_hash)): + if self.blob_dir and os.path.isfile(os.path.join(self.blob_dir, blob_hash)): os.remove(os.path.join(self.blob_dir, blob_hash)) else: self.blobs.pop(blob_hash).delete() diff --git a/lbrynet/blob_exchange/client.py b/lbrynet/blob_exchange/client.py index df0805e57..10c5d10df 100644 --- a/lbrynet/blob_exchange/client.py +++ b/lbrynet/blob_exchange/client.py @@ -5,7 +5,7 @@ import binascii from lbrynet.error import InvalidBlobHashError, InvalidDataError from lbrynet.blob_exchange.serialization import BlobResponse, BlobRequest if typing.TYPE_CHECKING: - from lbrynet.blob.blob_file import BlobFile + from lbrynet.blob.blob_file import AbstractBlob from lbrynet.blob.writer import HashBlobWriter log = logging.getLogger(__name__) @@ -17,10 +17,10 @@ class BlobExchangeClientProtocol(asyncio.Protocol): self.peer_port: typing.Optional[int] = None self.peer_address: typing.Optional[str] = None self.peer_timeout = peer_timeout - self.transport: asyncio.Transport = None + self.transport: typing.Optional[asyncio.Transport] = None - self.writer: 'HashBlobWriter' = None - self.blob: 'BlobFile' = None + self.writer: typing.Optional['HashBlobWriter'] = None + self.blob: typing.Optional['AbstractBlob'] = None self._blob_bytes_received = 0 self._response_fut: asyncio.Future = None @@ -63,8 +63,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol): # write blob bytes if we're writing a blob and have blob bytes to write self._write(response.blob_data) - - def _write(self, data): + def _write(self, data: bytes): if len(data) > (self.blob.get_length() - self._blob_bytes_received): data = data[:(self.blob.get_length() - self._blob_bytes_received)] log.warning("got more than asked from %s:%d, probable sendfile bug", self.peer_address, self.peer_port) @@ -145,11 +144,12 @@ class BlobExchangeClientProtocol(asyncio.Protocol): self.transport = None self.buf = b'' - async def download_blob(self, blob: 'BlobFile') -> typing.Tuple[int, typing.Optional[asyncio.Transport]]: - if blob.get_is_verified() or blob.file_exists or blob.blob_write_lock.locked(): + async def download_blob(self, blob: 'AbstractBlob') -> typing.Tuple[int, typing.Optional[asyncio.Transport]]: + if blob.get_is_verified() or not blob.is_writeable(): return 0, self.transport try: - self.blob, self.writer, self._blob_bytes_received = blob, blob.open_for_writing(), 0 + blob.get_blob_writer() + self.blob, self.writer, self._blob_bytes_received = blob, blob.get_blob_writer(), 0 self._response_fut = asyncio.Future(loop=self.loop) return await self._download_blob() except OSError as e: @@ -177,7 +177,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol): self.close() -async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: str, tcp_port: int, +async def request_blob(loop: asyncio.BaseEventLoop, blob: 'AbstractBlob', address: str, tcp_port: int, peer_connect_timeout: float, blob_download_timeout: float, connected_transport: asyncio.Transport = None)\ -> typing.Tuple[int, typing.Optional[asyncio.Transport]]: @@ -196,7 +196,7 @@ async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: s if not connected_transport: await asyncio.wait_for(loop.create_connection(lambda: protocol, address, tcp_port), peer_connect_timeout, loop=loop) - if blob.get_is_verified() or blob.file_exists: + if blob.get_is_verified() or not blob.is_writeable(): # file exists but not verified means someone is writing right now, give it time, come back later return 0, connected_transport return await protocol.download_blob(blob) diff --git a/lbrynet/blob_exchange/downloader.py b/lbrynet/blob_exchange/downloader.py index 0678a1626..0d9381f50 100644 --- a/lbrynet/blob_exchange/downloader.py +++ b/lbrynet/blob_exchange/downloader.py @@ -8,7 +8,7 @@ if typing.TYPE_CHECKING: from lbrynet.dht.node import Node from lbrynet.dht.peer import KademliaPeer from lbrynet.blob.blob_manager import BlobManager - from lbrynet.blob.blob_file import BlobFile + from lbrynet.blob.blob_file import AbstractBlob log = logging.getLogger(__name__) @@ -28,7 +28,7 @@ class BlobDownloader: self.connections: typing.Dict['KademliaPeer', asyncio.Transport] = {} self.time_since_last_blob = loop.time() - def should_race_continue(self, blob: 'BlobFile'): + def should_race_continue(self, blob: 'AbstractBlob'): if len(self.active_connections) >= self.config.max_connections_per_download: return False # if a peer won 3 or more blob races and is active as a downloader, stop the race so bandwidth improves @@ -37,9 +37,9 @@ class BlobDownloader: # for peer, task in self.active_connections.items(): # if self.scores.get(peer, 0) >= 0 and self.rounds_won.get(peer, 0) >= 3 and not task.done(): # return False - return not (blob.get_is_verified() or blob.file_exists) + return not (blob.get_is_verified() or not blob.is_writeable()) - async def request_blob_from_peer(self, blob: 'BlobFile', peer: 'KademliaPeer'): + async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer'): if blob.get_is_verified(): return self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones @@ -62,7 +62,7 @@ class BlobDownloader: rough_speed = (bytes_received / (self.loop.time() - start)) if bytes_received else 0 self.scores[peer] = rough_speed - async def new_peer_or_finished(self, blob: 'BlobFile'): + async def new_peer_or_finished(self, blob: 'AbstractBlob'): async def get_and_re_add_peers(): try: new_peers = await asyncio.wait_for(self.peer_queue.get(), timeout=1.0) @@ -90,7 +90,7 @@ class BlobDownloader: for banned_peer in forgiven: self.ignored.pop(banned_peer) - async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile': + async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'AbstractBlob': blob = self.blob_manager.get_blob(blob_hash, length) if blob.get_is_verified(): return blob @@ -99,7 +99,7 @@ class BlobDownloader: batch: typing.List['KademliaPeer'] = [] while not self.peer_queue.empty(): batch.extend(self.peer_queue.get_nowait()) - batch.sort(key=lambda peer: self.scores.get(peer, 0), reverse=True) + batch.sort(key=lambda p: self.scores.get(p, 0), reverse=True) log.debug( "running, %d peers, %d ignored, %d active", len(batch), len(self.ignored), len(self.active_connections) @@ -114,15 +114,29 @@ class BlobDownloader: await self.new_peer_or_finished(blob) self.cleanup_active() if batch: - self.peer_queue.put_nowait(set(batch).difference(self.ignored)) + to_re_add = list(set(batch).difference(self.ignored)) + if to_re_add: + self.peer_queue.put_nowait(to_re_add) + else: + self.clearbanned() else: self.clearbanned() blob.close() log.debug("downloaded %s", blob_hash[:8]) return blob + except asyncio.CancelledError as err: + error = err finally: + re_add = set() while self.active_connections: - self.active_connections.popitem()[1].cancel() + peer, t = self.active_connections.popitem() + t.cancel() + re_add.add(peer) + re_add = re_add.difference(self.ignored) + if re_add: + self.peer_queue.put_nowait(list(re_add)) + blob.close() + raise error def close(self): self.scores.clear() @@ -132,7 +146,7 @@ class BlobDownloader: async def download_blob(loop, config: 'Config', blob_manager: 'BlobManager', node: 'Node', - blob_hash: str) -> 'BlobFile': + blob_hash: str) -> 'AbstractBlob': search_queue = asyncio.Queue(loop=loop, maxsize=config.max_connections_per_download) search_queue.put_nowait(blob_hash) peer_queue, accumulate_task = node.accumulate_peers(search_queue) diff --git a/lbrynet/extras/daemon/Components.py b/lbrynet/extras/daemon/Components.py index 920f0b189..657b53efc 100644 --- a/lbrynet/extras/daemon/Components.py +++ b/lbrynet/extras/daemon/Components.py @@ -294,7 +294,7 @@ class BlobComponent(Component): blob_dir = os.path.join(self.conf.data_dir, 'blobfiles') if not os.path.isdir(blob_dir): os.mkdir(blob_dir) - self.blob_manager = BlobManager(asyncio.get_event_loop(), blob_dir, storage, data_store) + self.blob_manager = BlobManager(asyncio.get_event_loop(), blob_dir, storage, data_store, self.conf.save_blobs) return await self.blob_manager.setup() async def stop(self): diff --git a/lbrynet/stream/descriptor.py b/lbrynet/stream/descriptor.py index d73369110..e770a3bcd 100644 --- a/lbrynet/stream/descriptor.py +++ b/lbrynet/stream/descriptor.py @@ -8,7 +8,7 @@ from collections import OrderedDict from cryptography.hazmat.primitives.ciphers.algorithms import AES from lbrynet.blob import MAX_BLOB_SIZE from lbrynet.blob.blob_info import BlobInfo -from lbrynet.blob.blob_file import BlobFile +from lbrynet.blob.blob_file import AbstractBlob, BlobFile from lbrynet.cryptoutils import get_lbry_hash_obj from lbrynet.error import InvalidStreamDescriptorError @@ -108,29 +108,29 @@ class StreamDescriptor: h.update(self.old_sort_json()) return h.hexdigest() - async def make_sd_blob(self, blob_file_obj: typing.Optional[BlobFile] = None, + async def make_sd_blob(self, blob_file_obj: typing.Optional[AbstractBlob] = None, old_sort: typing.Optional[bool] = False): sd_hash = self.calculate_sd_hash() if not old_sort else self.calculate_old_sort_sd_hash() if not old_sort: sd_data = self.as_json() else: sd_data = self.old_sort_json() - sd_blob = blob_file_obj or BlobFile(self.loop, self.blob_dir, sd_hash, len(sd_data)) + sd_blob = blob_file_obj or BlobFile(self.loop, sd_hash, len(sd_data), blob_directory=self.blob_dir) if blob_file_obj: blob_file_obj.set_length(len(sd_data)) if not sd_blob.get_is_verified(): - writer = sd_blob.open_for_writing() + writer = sd_blob.get_blob_writer() writer.write(sd_data) + await sd_blob.verified.wait() sd_blob.close() return sd_blob @classmethod def _from_stream_descriptor_blob(cls, loop: asyncio.BaseEventLoop, blob_dir: str, - blob: BlobFile) -> 'StreamDescriptor': - assert os.path.isfile(blob.file_path) - with open(blob.file_path, 'rb') as f: - json_bytes = f.read() + blob: AbstractBlob) -> 'StreamDescriptor': + with blob.reader_context() as blob_reader: + json_bytes = blob_reader.read() try: decoded = json.loads(json_bytes.decode()) except json.JSONDecodeError: @@ -160,8 +160,8 @@ class StreamDescriptor: @classmethod async def from_stream_descriptor_blob(cls, loop: asyncio.BaseEventLoop, blob_dir: str, - blob: BlobFile) -> 'StreamDescriptor': - return await loop.run_in_executor(None, lambda: cls._from_stream_descriptor_blob(loop, blob_dir, blob)) + blob: AbstractBlob) -> 'StreamDescriptor': + return await loop.run_in_executor(None, cls._from_stream_descriptor_blob, loop, blob_dir, blob) @staticmethod def get_blob_hashsum(b: typing.Dict): @@ -228,7 +228,7 @@ class StreamDescriptor: return self.lower_bound_decrypted_length() + (AES.block_size // 8) @classmethod - async def recover(cls, blob_dir: str, sd_blob: 'BlobFile', stream_hash: str, stream_name: str, + async def recover(cls, blob_dir: str, sd_blob: 'AbstractBlob', stream_hash: str, stream_name: str, suggested_file_name: str, key: str, blobs: typing.List['BlobInfo']) -> typing.Optional['StreamDescriptor']: descriptor = cls(asyncio.get_event_loop(), blob_dir, stream_name, key, suggested_file_name, diff --git a/lbrynet/stream/reflector/server.py b/lbrynet/stream/reflector/server.py index bcfdeb14a..8d4c0289e 100644 --- a/lbrynet/stream/reflector/server.py +++ b/lbrynet/stream/reflector/server.py @@ -63,11 +63,11 @@ class ReflectorServerProtocol(asyncio.Protocol): return self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size']) if not self.sd_blob.get_is_verified(): - self.writer = self.sd_blob.open_for_writing() + self.writer = self.sd_blob.get_blob_writer() self.incoming.set() self.send_response({"send_sd_blob": True}) try: - await asyncio.wait_for(self.sd_blob.finished_writing.wait(), 30, loop=self.loop) + await asyncio.wait_for(self.sd_blob.verified.wait(), 30, loop=self.loop) self.descriptor = await StreamDescriptor.from_stream_descriptor_blob( self.loop, self.blob_manager.blob_dir, self.sd_blob ) @@ -102,11 +102,11 @@ class ReflectorServerProtocol(asyncio.Protocol): return blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size']) if not blob.get_is_verified(): - self.writer = blob.open_for_writing() + self.writer = blob.get_blob_writer() self.incoming.set() self.send_response({"send_blob": True}) try: - await asyncio.wait_for(blob.finished_writing.wait(), 30, loop=self.loop) + await asyncio.wait_for(blob.verified.wait(), 30, loop=self.loop) self.send_response({"received_blob": True}) except asyncio.TimeoutError: self.send_response({"received_blob": False})