Merge pull request #2020 from lbryio/download-range-requests

Support HTTP 206 partial content requests for streaming downloads
Jack Robison 2019-04-24 14:21:41 -04:00 committed by GitHub
commit 6f52d36b22
43 changed files with 2119 additions and 1137 deletions
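
The headline change is a /stream/{sd_hash} endpoint on the daemon's API server that answers HTTP Range requests with 206 Partial Content. A minimal client sketch (the daemon address assumes the default localhost:5279 and the sd hash is a placeholder; neither comes from this diff):

import requests  # assumes the requests library is available on the client side

DAEMON = "http://localhost:5279"                          # assumption: default API address
SD_HASH = "<sd hash of a stream the daemon knows about>"  # hypothetical placeholder

# Ask for the first megabyte; the new handler replies with 206 Partial Content,
# Accept-Ranges/Content-Range headers and the decrypted bytes for that range.
resp = requests.get(f"{DAEMON}/stream/{SD_HASH}", headers={"Range": "bytes=0-1048575"})
print(resp.status_code, resp.headers.get("Content-Range"), len(resp.content))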


@ -12,12 +12,13 @@ blockchain_name: lbrycrd_main
data_dir: /home/lbry/.lbrynet
download_directory: /home/lbry/downloads
delete_blobs_on_remove: True
save_blobs: true
save_files: false
dht_node_port: 4444
peer_port: 3333
use_upnp: True
use_upnp: true
#components_to_skip:
# - peer_protocol_server
# - hash_announcer
# - blob_server
# - dht


@ -4,6 +4,8 @@ import asyncio
import binascii
import logging
import typing
import contextlib
from io import BytesIO
from cryptography.hazmat.primitives.ciphers import Cipher, modes
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.padding import PKCS7
@ -21,10 +23,6 @@ log = logging.getLogger(__name__)
_hexmatch = re.compile("^[a-f,0-9]+$")
def is_valid_hashcharacter(char: str) -> bool:
return len(char) == 1 and _hexmatch.match(char)
def is_valid_blobhash(blobhash: str) -> bool:
"""Checks whether the blobhash is the correct length and contains only
valid characters (0-9, a-f)
@ -46,143 +44,74 @@ def encrypt_blob_bytes(key: bytes, iv: bytes, unencrypted: bytes) -> typing.Tupl
return encrypted, digest.hexdigest()
class BlobFile:
"""
A chunk of data available on the network which is specified by a hashsum
This class is used to create blobs on the local filesystem
when we already know the blob hash beforehand (i.e., when downloading blobs)
Also can be used for reading from blobs on the local filesystem
"""
def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, blob_hash: str,
length: typing.Optional[int] = None,
blob_completed_callback: typing.Optional[typing.Callable[['BlobFile'], typing.Awaitable]] = None):
if not is_valid_blobhash(blob_hash):
raise InvalidBlobHashError(blob_hash)
self.loop = loop
self.blob_hash = blob_hash
self.length = length
self.blob_dir = blob_dir
self.file_path = os.path.join(blob_dir, self.blob_hash)
self.writers: typing.List[HashBlobWriter] = []
self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
self.finished_writing = asyncio.Event(loop=loop)
self.blob_write_lock = asyncio.Lock(loop=loop)
if self.file_exists:
length = int(os.stat(os.path.join(blob_dir, blob_hash)).st_size)
self.length = length
self.verified.set()
self.finished_writing.set()
self.saved_verified_blob = False
self.blob_completed_callback = blob_completed_callback
@property
def file_exists(self):
return os.path.isfile(self.file_path)
def writer_finished(self, writer: HashBlobWriter):
def callback(finished: asyncio.Future):
try:
error = finished.exception()
except Exception as err:
error = err
if writer in self.writers: # remove this download attempt
self.writers.remove(writer)
if not error: # the blob downloaded, cancel all the other download attempts and set the result
while self.writers:
other = self.writers.pop()
other.finished.cancel()
t = self.loop.create_task(self.save_verified_blob(writer, finished.result()))
t.add_done_callback(lambda *_: self.finished_writing.set())
return
if isinstance(error, (InvalidBlobHashError, InvalidDataError)):
log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error))
elif not isinstance(error, (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError)):
log.exception("something else")
raise error
return callback
async def save_verified_blob(self, writer, verified_bytes: bytes):
def _save_verified():
# log.debug(f"write blob file {self.blob_hash[:8]} from {writer.peer.address}")
if not self.saved_verified_blob and not os.path.isfile(self.file_path):
if self.get_length() == len(verified_bytes):
with open(self.file_path, 'wb') as write_handle:
write_handle.write(verified_bytes)
self.saved_verified_blob = True
else:
raise Exception("length mismatch")
async with self.blob_write_lock:
if self.verified.is_set():
return
await self.loop.run_in_executor(None, _save_verified)
if self.blob_completed_callback:
await self.blob_completed_callback(self)
self.verified.set()
def open_for_writing(self) -> HashBlobWriter:
if self.file_exists:
raise OSError(f"File already exists '{self.file_path}'")
fut = asyncio.Future(loop=self.loop)
writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
self.writers.append(writer)
fut.add_done_callback(self.writer_finished(writer))
return writer
async def sendfile(self, writer: asyncio.StreamWriter) -> int:
"""
Read and send the file to the writer and return the number of bytes sent
"""
with open(self.file_path, 'rb') as handle:
return await self.loop.sendfile(writer.transport, handle, count=self.get_length())
def close(self):
while self.writers:
self.writers.pop().finished.cancel()
def delete(self):
self.close()
self.saved_verified_blob = False
if os.path.isfile(self.file_path):
os.remove(self.file_path)
self.verified.clear()
self.finished_writing.clear()
self.length = None
def decrypt(self, key: bytes, iv: bytes) -> bytes:
"""
Decrypt a BlobFile to plaintext bytes
"""
with open(self.file_path, "rb") as f:
buff = f.read()
if len(buff) != self.length:
def decrypt_blob_bytes(data: bytes, length: int, key: bytes, iv: bytes) -> bytes:
if len(data) != length:
raise ValueError("unexpected length")
cipher = Cipher(AES(key), modes.CBC(iv), backend=backend)
unpadder = PKCS7(AES.block_size).unpadder()
decryptor = cipher.decryptor()
return unpadder.update(decryptor.update(buff) + decryptor.finalize()) + unpadder.finalize()
return unpadder.update(decryptor.update(data) + decryptor.finalize()) + unpadder.finalize()
@classmethod
async def create_from_unencrypted(cls, loop: asyncio.BaseEventLoop, blob_dir: str, key: bytes,
iv: bytes, unencrypted: bytes, blob_num: int) -> BlobInfo:
class AbstractBlob:
"""
Create an encrypted BlobFile from plaintext bytes
A chunk of data (up to 2MB) available on the network which is specified by a sha384 hash
This class is non-io specific
"""
__slots__ = [
'loop',
'blob_hash',
'length',
'blob_completed_callback',
'blob_directory',
'writers',
'verified',
'writing',
'readers'
]
blob_bytes, blob_hash = encrypt_blob_bytes(key, iv, unencrypted)
length = len(blob_bytes)
blob = cls(loop, blob_dir, blob_hash, length)
writer = blob.open_for_writing()
writer.write(blob_bytes)
await blob.verified.wait()
return BlobInfo(blob_num, length, binascii.hexlify(iv).decode(), blob_hash)
def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None,
blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None,
blob_directory: typing.Optional[str] = None):
self.loop = loop
self.blob_hash = blob_hash
self.length = length
self.blob_completed_callback = blob_completed_callback
self.blob_directory = blob_directory
self.writers: typing.Dict[typing.Tuple[typing.Optional[str], typing.Optional[int]], HashBlobWriter] = {}
self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
self.writing: asyncio.Event = asyncio.Event(loop=self.loop)
self.readers: typing.List[typing.BinaryIO] = []
def set_length(self, length):
if not is_valid_blobhash(blob_hash):
raise InvalidBlobHashError(blob_hash)
def __del__(self):
if self.writers or self.readers:
log.warning("%s not closed before being garbage collected", self.blob_hash)
self.close()
@contextlib.contextmanager
def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
raise NotImplementedError()
@contextlib.contextmanager
def reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
if not self.is_readable():
raise OSError(f"{str(type(self))} not readable, {len(self.readers)} readers {len(self.writers)} writers")
with self._reader_context() as reader:
try:
self.readers.append(reader)
yield reader
finally:
if reader in self.readers:
self.readers.remove(reader)
def _write_blob(self, blob_bytes: bytes):
raise NotImplementedError()
def set_length(self, length) -> None:
if self.length is not None and length == self.length:
return
if self.length is None and 0 <= length <= MAX_BLOB_SIZE:
@ -190,8 +119,217 @@ class BlobFile:
return
log.warning("Got an invalid length. Previous length: %s, Invalid length: %s", self.length, length)
def get_length(self):
def get_length(self) -> typing.Optional[int]:
return self.length
def get_is_verified(self):
def get_is_verified(self) -> bool:
return self.verified.is_set()
def is_readable(self) -> bool:
return self.verified.is_set()
def is_writeable(self) -> bool:
return not self.writing.is_set()
def write_blob(self, blob_bytes: bytes):
if not self.is_writeable():
raise OSError("cannot open blob for writing")
try:
self.writing.set()
self._write_blob(blob_bytes)
finally:
self.writing.clear()
def close(self):
while self.writers:
peer, writer = self.writers.popitem()
if writer and writer.finished and not writer.finished.done() and not self.loop.is_closed():
writer.finished.cancel()
while self.readers:
reader = self.readers.pop()
if reader:
reader.close()
def delete(self):
self.close()
self.verified.clear()
self.length = None
async def sendfile(self, writer: asyncio.StreamWriter) -> int:
"""
Read and send the file to the writer and return the number of bytes sent
"""
if not self.is_readable():
raise OSError('blob files cannot be read')
with self.reader_context() as handle:
return await self.loop.sendfile(writer.transport, handle, count=self.get_length())
def decrypt(self, key: bytes, iv: bytes) -> bytes:
"""
Decrypt a BlobFile to plaintext bytes
"""
with self.reader_context() as reader:
return decrypt_blob_bytes(reader.read(), self.length, key, iv)
@classmethod
async def create_from_unencrypted(
cls, loop: asyncio.BaseEventLoop, blob_dir: typing.Optional[str], key: bytes, iv: bytes,
unencrypted: bytes, blob_num: int,
blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], None]] = None) -> BlobInfo:
"""
Create an encrypted BlobFile from plaintext bytes
"""
blob_bytes, blob_hash = encrypt_blob_bytes(key, iv, unencrypted)
length = len(blob_bytes)
blob = cls(loop, blob_hash, length, blob_completed_callback, blob_dir)
writer = blob.get_blob_writer()
writer.write(blob_bytes)
await blob.verified.wait()
return BlobInfo(blob_num, length, binascii.hexlify(iv).decode(), blob_hash)
def save_verified_blob(self, verified_bytes: bytes):
if self.verified.is_set():
return
if self.is_writeable():
self._write_blob(verified_bytes)
self.verified.set()
if self.blob_completed_callback:
self.blob_completed_callback(self)
def get_blob_writer(self, peer_address: typing.Optional[str] = None,
peer_port: typing.Optional[int] = None) -> HashBlobWriter:
if (peer_address, peer_port) in self.writers:
raise OSError(f"attempted to download blob twice from {peer_address}:{peer_port}")
fut = asyncio.Future(loop=self.loop)
writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
self.writers[(peer_address, peer_port)] = writer
def remove_writer(_):
if (peer_address, peer_port) in self.writers:
del self.writers[(peer_address, peer_port)]
fut.add_done_callback(remove_writer)
def writer_finished_callback(finished: asyncio.Future):
try:
err = finished.exception()
if err:
raise err
verified_bytes = finished.result()
while self.writers:
_, other = self.writers.popitem()
if other is not writer:
other.close_handle()
self.save_verified_blob(verified_bytes)
except (InvalidBlobHashError, InvalidDataError) as error:
log.warning("writer error downloading %s: %s", self.blob_hash[:8], str(error))
except (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError):
pass
fut.add_done_callback(writer_finished_callback)
return writer
class BlobBuffer(AbstractBlob):
"""
An in-memory only blob
"""
def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None,
blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None,
blob_directory: typing.Optional[str] = None):
self._verified_bytes: typing.Optional[BytesIO] = None
super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory)
@contextlib.contextmanager
def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
if not self.is_readable():
raise OSError("cannot open blob for reading")
try:
yield self._verified_bytes
finally:
if self._verified_bytes:
self._verified_bytes.close()
self._verified_bytes = None
self.verified.clear()
def _write_blob(self, blob_bytes: bytes):
if self._verified_bytes:
raise OSError("already have bytes for blob")
self._verified_bytes = BytesIO(blob_bytes)
def delete(self):
if self._verified_bytes:
self._verified_bytes.close()
self._verified_bytes = None
return super().delete()
def __del__(self):
super().__del__()
if self._verified_bytes:
self.delete()
class BlobFile(AbstractBlob):
"""
A blob existing on the local file system
"""
def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None,
blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None,
blob_directory: typing.Optional[str] = None):
super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory)
if not blob_directory or not os.path.isdir(blob_directory):
raise OSError(f"invalid blob directory '{blob_directory}'")
self.file_path = os.path.join(self.blob_directory, self.blob_hash)
if self.file_exists:
file_size = int(os.stat(self.file_path).st_size)
if length and length != file_size:
log.warning("expected %s to be %s bytes, file has %s", self.blob_hash, length, file_size)
self.delete()
else:
self.length = file_size
self.verified.set()
@property
def file_exists(self):
return os.path.isfile(self.file_path)
def is_writeable(self) -> bool:
return super().is_writeable() and not os.path.isfile(self.file_path)
def get_blob_writer(self, peer_address: typing.Optional[str] = None,
peer_port: typing.Optional[str] = None) -> HashBlobWriter:
if self.file_exists:
raise OSError(f"File already exists '{self.file_path}'")
return super().get_blob_writer(peer_address, peer_port)
@contextlib.contextmanager
def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
handle = open(self.file_path, 'rb')
try:
yield handle
finally:
handle.close()
def _write_blob(self, blob_bytes: bytes):
with open(self.file_path, 'wb') as f:
f.write(blob_bytes)
def delete(self):
if os.path.isfile(self.file_path):
os.remove(self.file_path)
return super().delete()
@classmethod
async def create_from_unencrypted(
cls, loop: asyncio.BaseEventLoop, blob_dir: typing.Optional[str], key: bytes, iv: bytes,
unencrypted: bytes, blob_num: int,
blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'],
asyncio.Task]] = None) -> BlobInfo:
if not blob_dir or not os.path.isdir(blob_dir):
raise OSError(f"cannot create blob in directory: '{blob_dir}'")
return await super().create_from_unencrypted(
loop, blob_dir, key, iv, unencrypted, blob_num, blob_completed_callback
)
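
To make the new reader/writer flow concrete, here is a small round-trip sketch with the in-memory BlobBuffer: the key, IV and payload come from the caller, encrypt_blob_bytes is used only to derive a matching hash, and the peer address/port passed to get_blob_writer are arbitrary labels (this illustrates the API, it is not code from the diff):

import asyncio
from lbrynet.blob.blob_file import BlobBuffer, encrypt_blob_bytes

async def round_trip(key: bytes, iv: bytes, payload: bytes):
    loop = asyncio.get_event_loop()
    blob_bytes, blob_hash = encrypt_blob_bytes(key, iv, payload)
    blob = BlobBuffer(loop, blob_hash, len(blob_bytes))
    writer = blob.get_blob_writer("demo-peer", 3333)  # labels only, no networking here
    writer.write(blob_bytes)            # length and hash are verified as bytes arrive
    await blob.verified.wait()          # set once the complete, matching blob is written
    with blob.reader_context() as reader:
        assert reader.read() == blob_bytes
    blob.close()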


@ -2,18 +2,19 @@ import os
import typing
import asyncio
import logging
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_file import BlobFile, is_valid_blobhash
from lbrynet.blob.blob_file import is_valid_blobhash, BlobFile, BlobBuffer, AbstractBlob
from lbrynet.stream.descriptor import StreamDescriptor
if typing.TYPE_CHECKING:
from lbrynet.conf import Config
from lbrynet.dht.protocol.data_store import DictDataStore
from lbrynet.extras.daemon.storage import SQLiteStorage
log = logging.getLogger(__name__)
class BlobFileManager:
def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, storage: SQLiteStorage,
class BlobManager:
def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, storage: 'SQLiteStorage', config: 'Config',
node_data_store: typing.Optional['DictDataStore'] = None):
"""
This class stores blobs on the hard disk
@ -27,16 +28,59 @@ class BlobFileManager:
self._node_data_store = node_data_store
self.completed_blob_hashes: typing.Set[str] = set() if not self._node_data_store\
else self._node_data_store.completed_blobs
self.blobs: typing.Dict[str, BlobFile] = {}
self.blobs: typing.Dict[str, AbstractBlob] = {}
self.config = config
def _get_blob(self, blob_hash: str, length: typing.Optional[int] = None):
if self.config.save_blobs:
return BlobFile(
self.loop, blob_hash, length, self.blob_completed, self.blob_dir
)
else:
if is_valid_blobhash(blob_hash) and os.path.isfile(os.path.join(self.blob_dir, blob_hash)):
return BlobFile(
self.loop, blob_hash, length, self.blob_completed, self.blob_dir
)
return BlobBuffer(
self.loop, blob_hash, length, self.blob_completed, self.blob_dir
)
def get_blob(self, blob_hash, length: typing.Optional[int] = None):
if blob_hash in self.blobs:
if self.config.save_blobs and isinstance(self.blobs[blob_hash], BlobBuffer):
buffer = self.blobs.pop(blob_hash)
if blob_hash in self.completed_blob_hashes:
self.completed_blob_hashes.remove(blob_hash)
self.blobs[blob_hash] = self._get_blob(blob_hash, length)
if buffer.is_readable():
with buffer.reader_context() as reader:
self.blobs[blob_hash].write_blob(reader.read())
if length and self.blobs[blob_hash].length is None:
self.blobs[blob_hash].set_length(length)
else:
self.blobs[blob_hash] = self._get_blob(blob_hash, length)
return self.blobs[blob_hash]
def is_blob_verified(self, blob_hash: str, length: typing.Optional[int] = None) -> bool:
if not is_valid_blobhash(blob_hash):
raise ValueError(blob_hash)
if blob_hash in self.blobs:
return self.blobs[blob_hash].get_is_verified()
if not os.path.isfile(os.path.join(self.blob_dir, blob_hash)):
return False
return self._get_blob(blob_hash, length).get_is_verified()
async def setup(self) -> bool:
def get_files_in_blob_dir() -> typing.Set[str]:
if not self.blob_dir:
return set()
return {
item.name for item in os.scandir(self.blob_dir) if is_valid_blobhash(item.name)
}
in_blobfiles_dir = await self.loop.run_in_executor(None, get_files_in_blob_dir)
self.completed_blob_hashes.update(await self.storage.sync_missing_blobs(in_blobfiles_dir))
to_add = await self.storage.sync_missing_blobs(in_blobfiles_dir)
if to_add:
self.completed_blob_hashes.update(to_add)
return True
def stop(self):
@ -45,37 +89,31 @@ class BlobFileManager:
blob.close()
self.completed_blob_hashes.clear()
def get_blob(self, blob_hash, length: typing.Optional[int] = None):
if blob_hash in self.blobs:
if length and self.blobs[blob_hash].length is None:
self.blobs[blob_hash].set_length(length)
else:
self.blobs[blob_hash] = BlobFile(self.loop, self.blob_dir, blob_hash, length, self.blob_completed)
return self.blobs[blob_hash]
def get_stream_descriptor(self, sd_hash):
return StreamDescriptor.from_stream_descriptor_blob(self.loop, self.blob_dir, self.get_blob(sd_hash))
async def blob_completed(self, blob: BlobFile):
def blob_completed(self, blob: AbstractBlob) -> asyncio.Task:
if blob.blob_hash is None:
raise Exception("Blob hash is None")
if not blob.length:
raise Exception("Blob has a length of 0")
if isinstance(blob, BlobFile):
if blob.blob_hash not in self.completed_blob_hashes:
self.completed_blob_hashes.add(blob.blob_hash)
await self.storage.add_completed_blob(blob.blob_hash, blob.length)
return self.loop.create_task(self.storage.add_blobs((blob.blob_hash, blob.length), finished=True))
else:
return self.loop.create_task(self.storage.add_blobs((blob.blob_hash, blob.length), finished=False))
def check_completed_blobs(self, blob_hashes: typing.List[str]) -> typing.List[str]:
"""Returns of the blobhashes_to_check, which are valid"""
blobs = [self.get_blob(b) for b in blob_hashes]
return [blob.blob_hash for blob in blobs if blob.get_is_verified()]
return [blob_hash for blob_hash in blob_hashes if self.is_blob_verified(blob_hash)]
def delete_blob(self, blob_hash: str):
if not is_valid_blobhash(blob_hash):
raise Exception("invalid blob hash to delete")
if blob_hash not in self.blobs:
if os.path.isfile(os.path.join(self.blob_dir, blob_hash)):
if self.blob_dir and os.path.isfile(os.path.join(self.blob_dir, blob_hash)):
os.remove(os.path.join(self.blob_dir, blob_hash))
else:
self.blobs.pop(blob_hash).delete()
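
A sketch of how the new save_blobs setting steers BlobManager between file-backed and in-memory blobs; it assumes an initialized SQLiteStorage, an existing blob directory, and that Config accepts the setting as a keyword argument:

import asyncio
from lbrynet.conf import Config
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob.blob_file import BlobBuffer, BlobFile

async def inspect_blob_class(blob_dir: str, storage, blob_hash: str):
    loop = asyncio.get_event_loop()
    manager = BlobManager(loop, blob_dir, storage, Config(save_blobs=False))
    blob = manager.get_blob(blob_hash)
    # With save_blobs=False the manager hands back an in-memory BlobBuffer, unless the
    # blob file already exists on disk, in which case it keeps the file-backed BlobFile.
    print(type(blob).__name__, isinstance(blob, (BlobBuffer, BlobFile)))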


@ -44,16 +44,15 @@ class HashBlobWriter:
self._hashsum.update(data)
self.len_so_far += len(data)
if self.len_so_far > expected_length:
self.close_handle()
self.finished.set_exception(InvalidDataError(
f'Length so far is greater than the expected length. {self.len_so_far} to {expected_length}'
))
self.close_handle()
return
self.buffer.write(data)
if self.len_so_far == expected_length:
blob_hash = self.calculate_blob_hash()
if blob_hash != self.expected_blob_hash:
self.close_handle()
self.finished.set_exception(InvalidBlobHashError(
f"blob hash is {blob_hash} vs expected {self.expected_blob_hash}"
))
@ -62,6 +61,8 @@ class HashBlobWriter:
self.close_handle()
def close_handle(self):
if not self.finished.done():
self.finished.cancel()
if self.buffer is not None:
self.buffer.close()
self.buffer = None


@ -5,7 +5,7 @@ import binascii
from lbrynet.error import InvalidBlobHashError, InvalidDataError
from lbrynet.blob_exchange.serialization import BlobResponse, BlobRequest
if typing.TYPE_CHECKING:
from lbrynet.blob.blob_file import BlobFile
from lbrynet.blob.blob_file import AbstractBlob
from lbrynet.blob.writer import HashBlobWriter
log = logging.getLogger(__name__)
@ -17,10 +17,10 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
self.peer_port: typing.Optional[int] = None
self.peer_address: typing.Optional[str] = None
self.peer_timeout = peer_timeout
self.transport: asyncio.Transport = None
self.transport: typing.Optional[asyncio.Transport] = None
self.writer: 'HashBlobWriter' = None
self.blob: 'BlobFile' = None
self.writer: typing.Optional['HashBlobWriter'] = None
self.blob: typing.Optional['AbstractBlob'] = None
self._blob_bytes_received = 0
self._response_fut: asyncio.Future = None
@ -63,8 +63,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
# write blob bytes if we're writing a blob and have blob bytes to write
self._write(response.blob_data)
def _write(self, data):
def _write(self, data: bytes):
if len(data) > (self.blob.get_length() - self._blob_bytes_received):
data = data[:(self.blob.get_length() - self._blob_bytes_received)]
log.warning("got more than asked from %s:%d, probable sendfile bug", self.peer_address, self.peer_port)
@ -145,17 +144,18 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
self.transport = None
self.buf = b''
async def download_blob(self, blob: 'BlobFile') -> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
if blob.get_is_verified() or blob.file_exists or blob.blob_write_lock.locked():
async def download_blob(self, blob: 'AbstractBlob') -> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
blob_hash = blob.blob_hash
if blob.get_is_verified() or not blob.is_writeable():
return 0, self.transport
try:
self.blob, self.writer, self._blob_bytes_received = blob, blob.open_for_writing(), 0
self._blob_bytes_received = 0
self.blob, self.writer = blob, blob.get_blob_writer(self.peer_address, self.peer_port)
self._response_fut = asyncio.Future(loop=self.loop)
return await self._download_blob()
except OSError as e:
log.error("race happened downloading from %s:%i", self.peer_address, self.peer_port)
# i'm not sure how to fix this race condition - jack
log.exception(e)
log.exception("race happened downloading %s from %s:%i", blob_hash, self.peer_address, self.peer_port)
return self._blob_bytes_received, self.transport
except asyncio.TimeoutError:
if self._response_fut and not self._response_fut.done():
@ -177,7 +177,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
self.close()
async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: str, tcp_port: int,
async def request_blob(loop: asyncio.BaseEventLoop, blob: 'AbstractBlob', address: str, tcp_port: int,
peer_connect_timeout: float, blob_download_timeout: float,
connected_transport: asyncio.Transport = None)\
-> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
@ -196,7 +196,7 @@ async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: s
if not connected_transport:
await asyncio.wait_for(loop.create_connection(lambda: protocol, address, tcp_port),
peer_connect_timeout, loop=loop)
if blob.get_is_verified() or blob.file_exists:
if blob.get_is_verified() or not blob.is_writeable():
# file exists but not verified means someone is writing right now, give it time, come back later
return 0, connected_transport
return await protocol.download_blob(blob)


@ -1,21 +1,22 @@
import asyncio
import typing
import logging
from lbrynet.utils import drain_tasks
from lbrynet.utils import drain_tasks, cache_concurrent
from lbrynet.blob_exchange.client import request_blob
if typing.TYPE_CHECKING:
from lbrynet.conf import Config
from lbrynet.dht.node import Node
from lbrynet.dht.peer import KademliaPeer
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_file import BlobFile
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob.blob_file import AbstractBlob
log = logging.getLogger(__name__)
class BlobDownloader:
BAN_TIME = 10.0 # fixme: when connection manager gets implemented, move it out from here
def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobFileManager',
def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager',
peer_queue: asyncio.Queue):
self.loop = loop
self.config = config
@ -27,7 +28,7 @@ class BlobDownloader:
self.connections: typing.Dict['KademliaPeer', asyncio.Transport] = {}
self.time_since_last_blob = loop.time()
def should_race_continue(self, blob: 'BlobFile'):
def should_race_continue(self, blob: 'AbstractBlob'):
if len(self.active_connections) >= self.config.max_connections_per_download:
return False
# if a peer won 3 or more blob races and is active as a downloader, stop the race so bandwidth improves
@ -36,9 +37,9 @@ class BlobDownloader:
# for peer, task in self.active_connections.items():
# if self.scores.get(peer, 0) >= 0 and self.rounds_won.get(peer, 0) >= 3 and not task.done():
# return False
return not (blob.get_is_verified() or blob.file_exists)
return not (blob.get_is_verified() or not blob.is_writeable())
async def request_blob_from_peer(self, blob: 'BlobFile', peer: 'KademliaPeer'):
async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer'):
if blob.get_is_verified():
return
self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones
@ -61,7 +62,7 @@ class BlobDownloader:
rough_speed = (bytes_received / (self.loop.time() - start)) if bytes_received else 0
self.scores[peer] = rough_speed
async def new_peer_or_finished(self, blob: 'BlobFile'):
async def new_peer_or_finished(self, blob: 'AbstractBlob'):
async def get_and_re_add_peers():
try:
new_peers = await asyncio.wait_for(self.peer_queue.get(), timeout=1.0)
@ -89,7 +90,8 @@ class BlobDownloader:
for banned_peer in forgiven:
self.ignored.pop(banned_peer)
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
@cache_concurrent
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'AbstractBlob':
blob = self.blob_manager.get_blob(blob_hash, length)
if blob.get_is_verified():
return blob
@ -98,7 +100,7 @@ class BlobDownloader:
batch: typing.List['KademliaPeer'] = []
while not self.peer_queue.empty():
batch.extend(self.peer_queue.get_nowait())
batch.sort(key=lambda peer: self.scores.get(peer, 0), reverse=True)
batch.sort(key=lambda p: self.scores.get(p, 0), reverse=True)
log.debug(
"running, %d peers, %d ignored, %d active",
len(batch), len(self.ignored), len(self.active_connections)
@ -113,15 +115,26 @@ class BlobDownloader:
await self.new_peer_or_finished(blob)
self.cleanup_active()
if batch:
self.peer_queue.put_nowait(set(batch).difference(self.ignored))
to_re_add = list(set(batch).difference(self.ignored))
if to_re_add:
self.peer_queue.put_nowait(to_re_add)
else:
self.clearbanned()
else:
self.clearbanned()
blob.close()
log.debug("downloaded %s", blob_hash[:8])
return blob
finally:
re_add = set()
while self.active_connections:
self.active_connections.popitem()[1].cancel()
peer, t = self.active_connections.popitem()
t.cancel()
re_add.add(peer)
re_add = re_add.difference(self.ignored)
if re_add:
self.peer_queue.put_nowait(list(re_add))
blob.close()
def close(self):
self.scores.clear()
@ -130,8 +143,8 @@ class BlobDownloader:
transport.close()
async def download_blob(loop, config: 'Config', blob_manager: 'BlobFileManager', node: 'Node',
blob_hash: str) -> 'BlobFile':
async def download_blob(loop, config: 'Config', blob_manager: 'BlobManager', node: 'Node',
blob_hash: str) -> 'AbstractBlob':
search_queue = asyncio.Queue(loop=loop, maxsize=config.max_connections_per_download)
search_queue.put_nowait(blob_hash)
peer_queue, accumulate_task = node.accumulate_peers(search_queue)
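
The module-level download_blob helper keeps its call shape but now returns the AbstractBlob type; a usage sketch where every argument is assumed to come from a running daemon's components and the hash is a placeholder:

import asyncio
from lbrynet.blob_exchange.downloader import download_blob

async def fetch_one(config, blob_manager, node, blob_hash: str):
    blob = await download_blob(asyncio.get_event_loop(), config, blob_manager, node, blob_hash)
    # Depending on conf.save_blobs this is a BlobFile on disk or a BlobBuffer in memory.
    print(blob.blob_hash[:8], blob.get_is_verified(), blob.length)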


@ -8,13 +8,13 @@ from lbrynet.blob_exchange.serialization import BlobAvailabilityResponse, BlobPr
BlobPaymentAddressResponse
if typing.TYPE_CHECKING:
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
log = logging.getLogger(__name__)
class BlobServerProtocol(asyncio.Protocol):
def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', lbrycrd_address: str):
def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobManager', lbrycrd_address: str):
self.loop = loop
self.blob_manager = blob_manager
self.server_task: asyncio.Task = None
@ -94,7 +94,7 @@ class BlobServerProtocol(asyncio.Protocol):
class BlobServer:
def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', lbrycrd_address: str):
def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobManager', lbrycrd_address: str):
self.loop = loop
self.blob_manager = blob_manager
self.server_task: asyncio.Task = None


@ -484,6 +484,8 @@ class Config(CLIConfig):
node_rpc_timeout = Float("Timeout when making a DHT request", constants.rpc_timeout)
# blob announcement and download
save_blobs = Toggle("Save encrypted blob files for hosting, otherwise download blobs to memory only.", True)
announce_head_and_sd_only = Toggle(
"Announce only the descriptor and first (rather than all) data blob for a stream to the DHT", True,
previous_names=['announce_head_blobs_only']
@ -537,6 +539,7 @@ class Config(CLIConfig):
cache_time = Integer("Time to cache resolved claims", 150) # TODO: use this
# daemon
save_files = Toggle("Save downloaded files when calling `get` by default", True)
components_to_skip = Strings("components which will be skipped during start-up of daemon", [])
share_usage_data = Toggle(
"Whether to share usage stats and diagnostic info with LBRY.", True,


@ -5,7 +5,7 @@ from itertools import chain
import typing
import logging
from lbrynet.dht import constants
from lbrynet.dht.error import RemoteException
from lbrynet.dht.error import RemoteException, TransportNotConnected
from lbrynet.dht.protocol.distance import Distance
from typing import TYPE_CHECKING
@ -169,7 +169,7 @@ class IterativeFinder:
log.warning(str(err))
self.active.discard(peer)
return
except RemoteException:
except (RemoteException, TransportNotConnected):
return
return await self._handle_probe_result(peer, response)
@ -215,7 +215,7 @@ class IterativeFinder:
await self._search_round()
if self.running:
self.delayed_calls.append(self.loop.call_later(delay, self._search))
except (asyncio.CancelledError, StopAsyncIteration):
except (asyncio.CancelledError, StopAsyncIteration, TransportNotConnected):
if self.running:
self.loop.call_soon(self.aclose)


@ -524,6 +524,10 @@ class KademliaProtocol(DatagramProtocol):
response = await asyncio.wait_for(response_fut, self.rpc_timeout)
self.peer_manager.report_last_replied(peer.address, peer.udp_port)
return response
except asyncio.CancelledError:
if not response_fut.done():
response_fut.cancel()
raise
except (asyncio.TimeoutError, RemoteException):
self.peer_manager.report_failure(peer.address, peer.udp_port)
if self.peer_manager.peer_is_good(peer) is False:
@ -540,7 +544,7 @@ class KademliaProtocol(DatagramProtocol):
async def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram,
ErrorDatagram]):
if not self.transport:
if not self.transport or self.transport.is_closing():
raise TransportNotConnected()
data = message.bencode()


@ -32,6 +32,12 @@ class DownloadDataTimeout(Exception):
self.download = download
class ResolveTimeout(Exception):
def __init__(self, uri):
super().__init__(f'Failed to resolve "{uri}" within the timeout')
self.uri = uri
class RequestCanceledError(Exception):
pass


@ -16,7 +16,7 @@ from lbrynet import utils
from lbrynet.conf import HEADERS_FILE_SHA256_CHECKSUM
from lbrynet.dht.node import Node
from lbrynet.dht.blob_announcer import BlobAnnouncer
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob_exchange.server import BlobServer
from lbrynet.stream.stream_manager import StreamManager
from lbrynet.extras.daemon.Component import Component
@ -278,10 +278,10 @@ class BlobComponent(Component):
def __init__(self, component_manager):
super().__init__(component_manager)
self.blob_manager: BlobFileManager = None
self.blob_manager: BlobManager = None
@property
def component(self) -> typing.Optional[BlobFileManager]:
def component(self) -> typing.Optional[BlobManager]:
return self.blob_manager
async def start(self):
@ -291,8 +291,10 @@ class BlobComponent(Component):
dht_node: Node = self.component_manager.get_component(DHT_COMPONENT)
if dht_node:
data_store = dht_node.protocol.data_store
self.blob_manager = BlobFileManager(asyncio.get_event_loop(), os.path.join(self.conf.data_dir, "blobfiles"),
storage, data_store)
blob_dir = os.path.join(self.conf.data_dir, 'blobfiles')
if not os.path.isdir(blob_dir):
os.mkdir(blob_dir)
self.blob_manager = BlobManager(asyncio.get_event_loop(), blob_dir, storage, self.conf, data_store)
return await self.blob_manager.setup()
async def stop(self):
@ -451,7 +453,7 @@ class PeerProtocolServerComponent(Component):
async def start(self):
log.info("start blob server")
upnp = self.component_manager.get_component(UPNP_COMPONENT)
blob_manager: BlobFileManager = self.component_manager.get_component(BLOB_COMPONENT)
blob_manager: BlobManager = self.component_manager.get_component(BLOB_COMPONENT)
wallet: LbryWalletManager = self.component_manager.get_component(WALLET_COMPONENT)
peer_port = self.conf.tcp_port
address = await wallet.get_unused_address()
@ -485,7 +487,7 @@ class UPnPComponent(Component):
while True:
if now:
await self._maintain_redirects()
await asyncio.sleep(360)
await asyncio.sleep(360, loop=self.component_manager.loop)
async def _maintain_redirects(self):
# setup the gateway if necessary
@ -528,7 +530,7 @@ class UPnPComponent(Component):
self.upnp_redirects.update(upnp_redirects)
except (asyncio.TimeoutError, UPnPError):
self.upnp = None
return self._maintain_redirects()
return
elif self.upnp: # check existing redirects are still active
found = set()
mappings = await self.upnp.get_redirects()


@ -40,7 +40,7 @@ from lbrynet.extras.daemon.comment_client import jsonrpc_batch, jsonrpc_post, rp
if typing.TYPE_CHECKING:
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.dht.node import Node
from lbrynet.extras.daemon.Components import UPnPComponent
from lbrynet.extras.daemon.exchange_rate_manager import ExchangeRateManager
@ -48,6 +48,7 @@ if typing.TYPE_CHECKING:
from lbrynet.wallet.manager import LbryWalletManager
from lbrynet.wallet.ledger import MainNetLedger
from lbrynet.stream.stream_manager import StreamManager
from lbrynet.stream.managed_stream import ManagedStream
log = logging.getLogger(__name__)
@ -272,6 +273,9 @@ class Daemon(metaclass=JSONRPCServerType):
app = web.Application()
app.router.add_get('/lbryapi', self.handle_old_jsonrpc)
app.router.add_post('/lbryapi', self.handle_old_jsonrpc)
app.router.add_get('/get/{claim_name}', self.handle_stream_get_request)
app.router.add_get('/get/{claim_name}/{claim_id}', self.handle_stream_get_request)
app.router.add_get('/stream/{sd_hash}', self.handle_stream_range_request)
app.router.add_post('/', self.handle_old_jsonrpc)
self.runner = web.AppRunner(app)
@ -296,7 +300,7 @@ class Daemon(metaclass=JSONRPCServerType):
return self.component_manager.get_component(EXCHANGE_RATE_MANAGER_COMPONENT)
@property
def blob_manager(self) -> typing.Optional['BlobFileManager']:
def blob_manager(self) -> typing.Optional['BlobManager']:
return self.component_manager.get_component(BLOB_COMPONENT)
@property
@ -452,6 +456,72 @@ class Daemon(metaclass=JSONRPCServerType):
content_type='application/json'
)
async def handle_stream_get_request(self, request: web.Request):
name_and_claim_id = request.path.split("/get/")[1]
if "/" not in name_and_claim_id:
uri = f"lbry://{name_and_claim_id}"
else:
name, claim_id = name_and_claim_id.split("/")
uri = f"lbry://{name}#{claim_id}"
stream = await self.jsonrpc_get(uri)
if isinstance(stream, dict):
raise web.HTTPServerError(text=stream['error'])
raise web.HTTPFound(f"/stream/{stream.sd_hash}")
@staticmethod
def prepare_range_response_headers(get_range: str, stream: 'ManagedStream') -> typing.Tuple[typing.Dict[str, str],
int, int]:
if '=' in get_range:
get_range = get_range.split('=')[1]
start, end = get_range.split('-')
size = 0
for blob in stream.descriptor.blobs[:-1]:
size += blob.length - 1
start = int(start)
end = int(end) if end else size - 1
skip_blobs = start // 2097150
skip = skip_blobs * 2097151
start = skip
final_size = end - start + 1
headers = {
'Accept-Ranges': 'bytes',
'Content-Range': f'bytes {start}-{end}/{size}',
'Content-Length': str(final_size),
'Content-Type': stream.mime_type
}
return headers, size, skip_blobs
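
A worked example of the header math above, with made-up numbers (2097150 and 2097151 are the constants the method uses to snap the response start back to a blob boundary):

# Range "bytes=3000000-" against a stream whose data blobs sum to size = 8388600 bytes.
get_range = "bytes=3000000-".split('=')[1]
size = 8388600
start, end = get_range.split('-')
start = int(start)
end = int(end) if end else size - 1   # open-ended range runs to the end: 8388599
skip_blobs = start // 2097150         # -> 1 whole data blob can be skipped outright
start = skip_blobs * 2097151          # -> 2097151, the response restarts at a blob boundary
final_size = end - start + 1          # -> 6291449 bytes advertised as Content-Length
print(skip_blobs, start, end, final_size)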
async def handle_stream_range_request(self, request: web.Request):
sd_hash = request.path.split("/stream/")[1]
if sd_hash not in self.stream_manager.streams:
return web.HTTPNotFound()
stream = self.stream_manager.streams[sd_hash]
if stream.status == 'stopped':
await self.stream_manager.start_stream(stream)
if stream.delayed_stop:
stream.delayed_stop.cancel()
headers, size, skip_blobs = self.prepare_range_response_headers(
request.headers.get('range', 'bytes=0-'), stream
)
response = web.StreamResponse(
status=206,
headers=headers
)
await response.prepare(request)
wrote = 0
async for blob_info, decrypted in stream.aiter_read_stream(skip_blobs):
log.info("streamed blob %i/%i", blob_info.blob_num + 1, len(stream.descriptor.blobs) - 1)
if (blob_info.blob_num == len(stream.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size):
decrypted += b'\x00' * (size - len(decrypted) - wrote)
await response.write_eof(decrypted)
break
else:
await response.write(decrypted)
wrote += len(decrypted)
response.force_close()
return response
async def _process_rpc_call(self, data):
args = data.get('params', {})
@ -827,24 +897,26 @@ class Daemon(metaclass=JSONRPCServerType):
@requires(WALLET_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT, BLOB_COMPONENT, DATABASE_COMPONENT,
STREAM_MANAGER_COMPONENT,
conditions=[WALLET_IS_UNLOCKED])
async def jsonrpc_get(self, uri, file_name=None, timeout=None):
async def jsonrpc_get(self, uri, file_name=None, timeout=None, save_file=None):
"""
Download stream from a LBRY name.
Usage:
get <uri> [<file_name> | --file_name=<file_name>] [<timeout> | --timeout=<timeout>]
get <uri> [<file_name> | --file_name=<file_name>] [<timeout> | --timeout=<timeout>] [--save_file]
Options:
--uri=<uri> : (str) uri of the content to download
--file_name=<file_name> : (str) specified name for the downloaded file
--file_name=<file_name> : (str) specified name for the downloaded file, overrides the stream file name
--timeout=<timeout> : (int) download timeout in number of seconds
--save_file : (bool) save the file to the downloads directory
Returns: {File}
"""
save_file = save_file if save_file is not None else self.conf.save_files
try:
stream = await self.stream_manager.download_stream_from_uri(
uri, self.exchange_rate_manager, file_name, timeout
uri, self.exchange_rate_manager, timeout, file_name, save_file=save_file
)
if not stream:
raise DownloadSDTimeout(uri)
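
For reference, a sketch of calling the updated get method over JSON-RPC with the new save_file flag; the claim name is a placeholder and localhost:5279 assumes the default API address:

import requests

payload = {"method": "get", "params": {"uri": "lbry://some-claim", "save_file": False}}
# POST to '/' reaches handle_old_jsonrpc; with save_file=False the stream is fetched
# without writing a decrypted copy into the downloads directory.
print(requests.post("http://localhost:5279", json=payload).json())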
@ -1534,6 +1606,45 @@ class Daemon(metaclass=JSONRPCServerType):
result = True
return result
@requires(STREAM_MANAGER_COMPONENT)
async def jsonrpc_file_save(self, file_name=None, download_directory=None, **kwargs):
"""
Output a download to a file
Usage:
file_save [--file_name=<file_name>] [--download_directory=<download_directory>] [--sd_hash=<sd_hash>]
[--stream_hash=<stream_hash>] [--rowid=<rowid>] [--claim_id=<claim_id>] [--txid=<txid>]
[--nout=<nout>] [--claim_name=<claim_name>] [--channel_claim_id=<channel_claim_id>]
[--channel_name=<channel_name>]
Options:
--file_name=<file_name> : (str) file name to save to
--download_directory=<download_directory> : (str) directory to save into
--sd_hash=<sd_hash> : (str) save file with matching sd hash
--stream_hash=<stream_hash> : (str) save file with matching stream hash
--rowid=<rowid> : (int) save file with matching row id
--claim_id=<claim_id> : (str) save file with matching claim id
--txid=<txid> : (str) save file with matching claim txid
--nout=<nout> : (int) save file with matching claim nout
--claim_name=<claim_name> : (str) save file with matching claim name
--channel_claim_id=<channel_claim_id> : (str) save file with matching channel claim id
--channel_name=<channel_name> : (str) save file with matching channel claim name
Returns: {File}
"""
streams = self.stream_manager.get_filtered_streams(**kwargs)
if len(streams) > 1:
log.warning("There are %i matching files, use narrower filters to select one", len(streams))
return False
if not streams:
log.warning("There is no file to save")
return False
stream = streams[0]
await stream.save_file(file_name, download_directory)
return stream
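
And the complementary file_save call, which writes an already-downloaded stream out to disk; the sd hash filter is a placeholder and the daemon address is the same assumption as above:

import requests

payload = {"method": "file_save", "params": {"sd_hash": "<sd hash of a downloaded stream>"}}
print(requests.post("http://localhost:5279", json=payload).json())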
CLAIM_DOC = """
List and search all types of claims.
"""
@ -2210,9 +2321,7 @@ class Daemon(metaclass=JSONRPCServerType):
await self.storage.save_claims([self._old_get_temp_claim_info(
tx, new_txo, claim_address, claim, name, dewies_to_lbc(amount)
)])
stream_hash = await self.storage.get_stream_hash_for_sd_hash(claim.stream.source.sd_hash)
if stream_hash:
await self.storage.save_content_claim(stream_hash, new_txo.id)
await self.storage.save_content_claim(file_stream.stream_hash, new_txo.id)
await self.analytics_manager.send_claim_action('publish')
else:
await account.ledger.release_tx(tx)
@ -2365,6 +2474,9 @@ class Daemon(metaclass=JSONRPCServerType):
file_stream = await self.stream_manager.create_stream(file_path)
new_txo.claim.stream.source.sd_hash = file_stream.sd_hash
new_txo.script.generate()
stream_hash = file_stream.stream_hash
else:
stream_hash = await self.storage.get_stream_hash_for_sd_hash(old_txo.claim.stream.source.sd_hash)
if channel:
new_txo.sign(channel)
await tx.sign([account])
@ -2372,8 +2484,6 @@ class Daemon(metaclass=JSONRPCServerType):
await self.storage.save_claims([self._old_get_temp_claim_info(
tx, new_txo, claim_address, new_txo.claim, new_txo.claim_name, dewies_to_lbc(amount)
)])
stream_hash = await self.storage.get_stream_hash_for_sd_hash(new_txo.claim.stream.source.sd_hash)
if stream_hash:
await self.storage.save_content_claim(stream_hash, new_txo.id)
await self.analytics_manager.send_claim_action('publish')
else:
@ -2934,9 +3044,9 @@ class Daemon(metaclass=JSONRPCServerType):
else:
blobs = list(self.blob_manager.completed_blob_hashes)
if needed:
blobs = [blob_hash for blob_hash in blobs if not self.blob_manager.get_blob(blob_hash).get_is_verified()]
blobs = [blob_hash for blob_hash in blobs if not self.blob_manager.is_blob_verified(blob_hash)]
if finished:
blobs = [blob_hash for blob_hash in blobs if self.blob_manager.get_blob(blob_hash).get_is_verified()]
blobs = [blob_hash for blob_hash in blobs if self.blob_manager.is_blob_verified(blob_hash)]
page_size = page_size or len(blobs)
page = page or 0
start_index = page * page_size


@ -54,10 +54,6 @@ class StoredStreamClaim:
def nout(self) -> typing.Optional[int]:
return None if not self.outpoint else int(self.outpoint.split(":")[1])
@property
def metadata(self) -> typing.Optional[typing.Dict]:
return None if not self.claim else self.claim.claim_dict['stream']['metadata']
def as_dict(self) -> typing.Dict:
return {
"name": self.claim_name,
@ -195,12 +191,16 @@ def delete_stream(transaction: sqlite3.Connection, descriptor: 'StreamDescriptor
transaction.executemany("delete from blob where blob_hash=?", blob_hashes)
def store_file(transaction: sqlite3.Connection, stream_hash: str, file_name: str, download_directory: str,
data_payment_rate: float, status: str) -> int:
def store_file(transaction: sqlite3.Connection, stream_hash: str, file_name: typing.Optional[str],
download_directory: typing.Optional[str], data_payment_rate: float, status: str) -> int:
if not file_name and not download_directory:
encoded_file_name, encoded_download_dir = "{stream}", "{stream}"
else:
encoded_file_name = binascii.hexlify(file_name.encode()).decode()
encoded_download_dir = binascii.hexlify(download_directory.encode()).decode()
transaction.execute(
"insert or replace into file values (?, ?, ?, ?, ?)",
(stream_hash, binascii.hexlify(file_name.encode()).decode(),
binascii.hexlify(download_directory.encode()).decode(), data_payment_rate, status)
(stream_hash, encoded_file_name, encoded_download_dir, data_payment_rate, status)
)
return transaction.execute("select rowid from file where stream_hash=?", (stream_hash, )).fetchone()[0]
@ -296,27 +296,28 @@ class SQLiteStorage(SQLiteMixin):
# # # # # # # # # blob functions # # # # # # # # #
def add_completed_blob(self, blob_hash: str, length: int):
def _add_blob(transaction: sqlite3.Connection):
transaction.execute(
async def add_blobs(self, *blob_hashes_and_lengths: typing.Tuple[str, int], finished=False):
def _add_blobs(transaction: sqlite3.Connection):
transaction.executemany(
"insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)",
(blob_hash, length, 0, 0, "pending", 0, 0)
[
(blob_hash, length, 0, 0, "pending" if not finished else "finished", 0, 0)
for blob_hash, length in blob_hashes_and_lengths
]
)
transaction.execute(
"update blob set status='finished' where blob.blob_hash=?", (blob_hash, )
if finished:
transaction.executemany(
"update blob set status='finished' where blob.blob_hash=?", [
(blob_hash, ) for blob_hash, _ in blob_hashes_and_lengths
]
)
return self.db.run(_add_blob)
return await self.db.run(_add_blobs)
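
From a caller's perspective the new batched insert looks like the sketch below; storage is assumed to be an initialized SQLiteStorage and the hashes are placeholders padded to the 96-character sha384 hex length:

async def record_blobs(storage):
    # One executemany() round trip inserts both rows and, because finished=True,
    # a second pass flips their status to 'finished'.
    await storage.add_blobs(
        ("ab" * 48, 2097152),   # placeholder hash / length pairs
        ("cd" * 48, 1048576),
        finished=True,
    )
    print(await storage.get_blob_status("ab" * 48))   # -> 'finished'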
def get_blob_status(self, blob_hash: str):
return self.run_and_return_one_or_none(
"select status from blob where blob_hash=?", blob_hash
)
def add_known_blob(self, blob_hash: str, length: int):
return self.db.execute(
"insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)", (blob_hash, length, 0, 0, "pending", 0, 0)
)
def should_announce(self, blob_hash: str):
return self.run_and_return_one_or_none(
"select should_announce from blob where blob_hash=?", blob_hash
@ -419,6 +420,26 @@ class SQLiteStorage(SQLiteMixin):
}
return self.db.run(_sync_blobs)
def sync_files_to_blobs(self):
def _sync_blobs(transaction: sqlite3.Connection):
transaction.executemany(
"update file set status='stopped' where stream_hash=?",
transaction.execute(
"select distinct sb.stream_hash from stream_blob sb "
"inner join blob b on b.blob_hash=sb.blob_hash and b.status=='pending'"
).fetchall()
)
return self.db.run(_sync_blobs)
def set_files_as_streaming(self, stream_hashes: typing.List[str]):
def _set_streaming(transaction: sqlite3.Connection):
transaction.executemany(
"update file set file_name='{stream}', download_directory='{stream}' where stream_hash=?",
[(stream_hash, ) for stream_hash in stream_hashes]
)
return self.db.run(_set_streaming)
# # # # # # # # # stream functions # # # # # # # # #
async def stream_exists(self, sd_hash: str) -> bool:
@ -481,6 +502,12 @@ class SQLiteStorage(SQLiteMixin):
"select stream_hash from stream where sd_hash = ?", sd_blob_hash
)
def get_stream_info_for_sd_hash(self, sd_blob_hash):
return self.run_and_return_one_or_none(
"select stream_hash, stream_name, suggested_filename, stream_key from stream where sd_hash = ?",
sd_blob_hash
)
def delete_stream(self, descriptor: 'StreamDescriptor'):
return self.db.run_with_foreign_keys_disabled(delete_stream, descriptor)
@ -492,7 +519,8 @@ class SQLiteStorage(SQLiteMixin):
stream_hash, file_name, download_directory, data_payment_rate, status="running"
)
def save_published_file(self, stream_hash: str, file_name: str, download_directory: str, data_payment_rate: float,
def save_published_file(self, stream_hash: str, file_name: typing.Optional[str],
download_directory: typing.Optional[str], data_payment_rate: float,
status="finished") -> typing.Awaitable[int]:
return self.db.run(store_file, stream_hash, file_name, download_directory, data_payment_rate, status)
@ -503,10 +531,15 @@ class SQLiteStorage(SQLiteMixin):
log.debug("update file status %s -> %s", stream_hash, new_status)
return self.db.execute("update file set status=? where stream_hash=?", (new_status, stream_hash))
def change_file_download_dir_and_file_name(self, stream_hash: str, download_dir: str, file_name: str):
return self.db.execute("update file set download_directory=?, file_name=? where stream_hash=?", (
binascii.hexlify(download_dir.encode()).decode(), binascii.hexlify(file_name.encode()).decode(),
stream_hash
async def change_file_download_dir_and_file_name(self, stream_hash: str, download_dir: typing.Optional[str],
file_name: typing.Optional[str]):
if not file_name or not download_dir:
encoded_file_name, encoded_download_dir = "{stream}", "{stream}"
else:
encoded_file_name = binascii.hexlify(file_name.encode()).decode()
encoded_download_dir = binascii.hexlify(download_dir.encode()).decode()
return await self.db.execute("update file set download_directory=?, file_name=? where stream_hash=?", (
encoded_download_dir, encoded_file_name, stream_hash,
))
async def recover_streams(self, descriptors_and_sds: typing.List[typing.Tuple['StreamDescriptor', 'BlobFile']],


@ -1,136 +0,0 @@
import os
import binascii
import logging
import typing
import asyncio
from lbrynet.blob import MAX_BLOB_SIZE
from lbrynet.stream.descriptor import StreamDescriptor
if typing.TYPE_CHECKING:
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_info import BlobInfo
from lbrynet.blob.blob_file import BlobFile
log = logging.getLogger(__name__)
def _get_next_available_file_name(download_directory: str, file_name: str) -> str:
base_name, ext = os.path.splitext(os.path.basename(file_name))
i = 0
while os.path.isfile(os.path.join(download_directory, file_name)):
i += 1
file_name = "%s_%i%s" % (base_name, i, ext)
return file_name
async def get_next_available_file_name(loop: asyncio.BaseEventLoop, download_directory: str, file_name: str) -> str:
return await loop.run_in_executor(None, _get_next_available_file_name, download_directory, file_name)
class StreamAssembler:
def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', sd_hash: str,
output_file_name: typing.Optional[str] = None):
self.output_file_name = output_file_name
self.loop = loop
self.blob_manager = blob_manager
self.sd_hash = sd_hash
self.sd_blob: 'BlobFile' = None
self.descriptor: StreamDescriptor = None
self.got_descriptor = asyncio.Event(loop=self.loop)
self.wrote_bytes_event = asyncio.Event(loop=self.loop)
self.stream_finished_event = asyncio.Event(loop=self.loop)
self.output_path = ''
self.stream_handle = None
self.written_bytes: int = 0
async def _decrypt_blob(self, blob: 'BlobFile', blob_info: 'BlobInfo', key: str):
if not blob or not self.stream_handle or self.stream_handle.closed:
return False
def _decrypt_and_write():
offset = blob_info.blob_num * (MAX_BLOB_SIZE - 1)
self.stream_handle.seek(offset)
_decrypted = blob.decrypt(
binascii.unhexlify(key), binascii.unhexlify(blob_info.iv.encode())
)
self.stream_handle.write(_decrypted)
self.stream_handle.flush()
self.written_bytes += len(_decrypted)
log.debug("decrypted %s", blob.blob_hash[:8])
await self.loop.run_in_executor(None, _decrypt_and_write)
return True
async def setup(self):
pass
async def after_got_descriptor(self):
pass
async def after_finished(self):
pass
async def assemble_decrypted_stream(self, output_dir: str, output_file_name: typing.Optional[str] = None):
if not os.path.isdir(output_dir):
raise OSError(f"output directory does not exist: '{output_dir}' '{output_file_name}'")
await self.setup()
self.sd_blob = await self.get_blob(self.sd_hash)
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(self.loop, self.blob_manager.blob_dir,
self.sd_blob)
await self.after_got_descriptor()
self.output_file_name = output_file_name or self.descriptor.suggested_file_name
self.output_file_name = await get_next_available_file_name(self.loop, output_dir, self.output_file_name)
self.output_path = os.path.join(output_dir, self.output_file_name)
if not self.got_descriptor.is_set():
self.got_descriptor.set()
await self.blob_manager.storage.store_stream(
self.sd_blob, self.descriptor
)
await self.blob_manager.blob_completed(self.sd_blob)
written_blobs = None
save_tasks = []
try:
with open(self.output_path, 'wb') as stream_handle:
self.stream_handle = stream_handle
for i, blob_info in enumerate(self.descriptor.blobs[:-1]):
if blob_info.blob_num != i:
log.error("sd blob %s is invalid, cannot assemble stream", self.descriptor.sd_hash)
return
while self.stream_handle and not self.stream_handle.closed:
try:
blob = await self.get_blob(blob_info.blob_hash, blob_info.length)
if blob and blob.length != blob_info.length:
log.warning("Found incomplete, deleting: %s", blob_info.blob_hash)
await self.blob_manager.delete_blobs([blob_info.blob_hash])
continue
if await self._decrypt_blob(blob, blob_info, self.descriptor.key):
save_tasks.append(asyncio.ensure_future(self.blob_manager.blob_completed(blob)))
written_blobs = i
if not self.wrote_bytes_event.is_set():
self.wrote_bytes_event.set()
log.debug("written %i/%i", written_blobs, len(self.descriptor.blobs) - 2)
break
except FileNotFoundError:
log.debug("stream assembler stopped")
return
except (ValueError, IOError, OSError):
log.warning("failed to decrypt blob %s for stream %s", blob_info.blob_hash,
self.descriptor.sd_hash)
continue
finally:
if written_blobs == len(self.descriptor.blobs) - 2:
log.debug("finished decrypting and assembling stream")
if save_tasks:
await asyncio.wait(save_tasks)
await self.after_finished()
self.stream_finished_event.set()
else:
log.debug("stream decryption and assembly did not finish (%i/%i blobs are done)", written_blobs or 0,
len(self.descriptor.blobs) - 2)
if self.output_path and os.path.isfile(self.output_path):
log.debug("erasing incomplete file assembly: %s", self.output_path)
os.unlink(self.output_path)
async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
return self.blob_manager.get_blob(blob_hash, length)


@ -8,7 +8,7 @@ from collections import OrderedDict
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from lbrynet.blob import MAX_BLOB_SIZE
from lbrynet.blob.blob_info import BlobInfo
from lbrynet.blob.blob_file import BlobFile
from lbrynet.blob.blob_file import AbstractBlob, BlobFile
from lbrynet.cryptoutils import get_lbry_hash_obj
from lbrynet.error import InvalidStreamDescriptorError
@ -108,29 +108,30 @@ class StreamDescriptor:
h.update(self.old_sort_json())
return h.hexdigest()
async def make_sd_blob(self, blob_file_obj: typing.Optional[BlobFile] = None,
old_sort: typing.Optional[bool] = False):
async def make_sd_blob(self, blob_file_obj: typing.Optional[AbstractBlob] = None,
old_sort: typing.Optional[bool] = False,
blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], None]] = None):
sd_hash = self.calculate_sd_hash() if not old_sort else self.calculate_old_sort_sd_hash()
if not old_sort:
sd_data = self.as_json()
else:
sd_data = self.old_sort_json()
sd_blob = blob_file_obj or BlobFile(self.loop, self.blob_dir, sd_hash, len(sd_data))
sd_blob = blob_file_obj or BlobFile(self.loop, sd_hash, len(sd_data), blob_completed_callback, self.blob_dir)
if blob_file_obj:
blob_file_obj.set_length(len(sd_data))
if not sd_blob.get_is_verified():
writer = sd_blob.open_for_writing()
writer = sd_blob.get_blob_writer()
writer.write(sd_data)
await sd_blob.verified.wait()
sd_blob.close()
return sd_blob
@classmethod
def _from_stream_descriptor_blob(cls, loop: asyncio.BaseEventLoop, blob_dir: str,
blob: BlobFile) -> 'StreamDescriptor':
assert os.path.isfile(blob.file_path)
with open(blob.file_path, 'rb') as f:
json_bytes = f.read()
blob: AbstractBlob) -> 'StreamDescriptor':
with blob.reader_context() as blob_reader:
json_bytes = blob_reader.read()
try:
decoded = json.loads(json_bytes.decode())
except json.JSONDecodeError:
@ -160,8 +161,10 @@ class StreamDescriptor:
@classmethod
async def from_stream_descriptor_blob(cls, loop: asyncio.BaseEventLoop, blob_dir: str,
blob: BlobFile) -> 'StreamDescriptor':
return await loop.run_in_executor(None, lambda: cls._from_stream_descriptor_blob(loop, blob_dir, blob))
blob: AbstractBlob) -> 'StreamDescriptor':
if not blob.is_readable():
raise InvalidStreamDescriptorError(f"unreadable/missing blob: {blob.blob_hash}")
return await loop.run_in_executor(None, cls._from_stream_descriptor_blob, loop, blob_dir, blob)
@staticmethod
def get_blob_hashsum(b: typing.Dict):
@ -194,11 +197,12 @@ class StreamDescriptor:
return h.hexdigest()
@classmethod
async def create_stream(cls, loop: asyncio.BaseEventLoop, blob_dir: str,
file_path: str, key: typing.Optional[bytes] = None,
async def create_stream(
cls, loop: asyncio.BaseEventLoop, blob_dir: str, file_path: str, key: typing.Optional[bytes] = None,
iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None,
old_sort: bool = False) -> 'StreamDescriptor':
old_sort: bool = False,
blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'],
asyncio.Task]] = None) -> 'StreamDescriptor':
blobs: typing.List[BlobInfo] = []
iv_generator = iv_generator or random_iv_generator()
@ -207,7 +211,7 @@ class StreamDescriptor:
for blob_bytes in file_reader(file_path):
blob_num += 1
blob_info = await BlobFile.create_from_unencrypted(
loop, blob_dir, key, next(iv_generator), blob_bytes, blob_num
loop, blob_dir, key, next(iv_generator), blob_bytes, blob_num, blob_completed_callback
)
blobs.append(blob_info)
blobs.append(
@ -216,7 +220,7 @@ class StreamDescriptor:
loop, blob_dir, os.path.basename(file_path), binascii.hexlify(key).decode(), os.path.basename(file_path),
blobs
)
sd_blob = await descriptor.make_sd_blob(old_sort=old_sort)
sd_blob = await descriptor.make_sd_blob(old_sort=old_sort, blob_completed_callback=blob_completed_callback)
descriptor.sd_hash = sd_blob.blob_hash
return descriptor
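A hedged sketch of how the reworked create_stream might be invoked; the loop, blob_manager, blob_dir names and the file path are assumptions for illustration, not taken from this diff:

async def publish_file(loop, blob_manager, blob_dir):
    # chunk a local file into encrypted blobs and build its manifest; the
    # blob_completed_callback added in this PR lets the manager record each blob as it lands
    descriptor = await StreamDescriptor.create_stream(
        loop, blob_dir, "/tmp/example.bin",
        blob_completed_callback=blob_manager.blob_completed
    )
    return descriptor.sd_hash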
@ -228,7 +232,7 @@ class StreamDescriptor:
return self.lower_bound_decrypted_length() + (AES.block_size // 8)
@classmethod
async def recover(cls, blob_dir: str, sd_blob: 'BlobFile', stream_hash: str, stream_name: str,
async def recover(cls, blob_dir: str, sd_blob: 'AbstractBlob', stream_hash: str, stream_name: str,
suggested_file_name: str, key: str,
blobs: typing.List['BlobInfo']) -> typing.Optional['StreamDescriptor']:
descriptor = cls(asyncio.get_event_loop(), blob_dir, stream_name, key, suggested_file_name,

View file

@ -1,90 +1,42 @@
import os
import asyncio
import typing
import logging
import binascii
from lbrynet.error import DownloadSDTimeout
from lbrynet.utils import resolve_host
from lbrynet.stream.assembler import StreamAssembler
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.blob_exchange.downloader import BlobDownloader
from lbrynet.dht.peer import KademliaPeer
if typing.TYPE_CHECKING:
from lbrynet.conf import Config
from lbrynet.dht.node import Node
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_file import BlobFile
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob.blob_file import AbstractBlob
from lbrynet.blob.blob_info import BlobInfo
log = logging.getLogger(__name__)
def drain_into(a: list, b: list):
while a:
b.append(a.pop())
class StreamDownloader(StreamAssembler):
def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobFileManager', sd_hash: str,
output_dir: typing.Optional[str] = None, output_file_name: typing.Optional[str] = None):
super().__init__(loop, blob_manager, sd_hash, output_file_name)
class StreamDownloader:
def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager', sd_hash: str,
descriptor: typing.Optional[StreamDescriptor] = None):
self.loop = loop
self.config = config
self.output_dir = output_dir or self.config.download_dir
self.output_file_name = output_file_name
self.blob_downloader: typing.Optional[BlobDownloader] = None
self.search_queue = asyncio.Queue(loop=loop)
self.peer_queue = asyncio.Queue(loop=loop)
self.accumulate_task: typing.Optional[asyncio.Task] = None
self.descriptor: typing.Optional[StreamDescriptor]
self.blob_manager = blob_manager
self.sd_hash = sd_hash
self.search_queue = asyncio.Queue(loop=loop) # blob hashes to feed into the iterative finder
self.peer_queue = asyncio.Queue(loop=loop) # new peers to try
self.blob_downloader = BlobDownloader(self.loop, self.config, self.blob_manager, self.peer_queue)
self.descriptor: typing.Optional[StreamDescriptor] = descriptor
self.node: typing.Optional['Node'] = None
self.assemble_task: typing.Optional[asyncio.Task] = None
self.accumulate_task: typing.Optional[asyncio.Task] = None
self.fixed_peers_handle: typing.Optional[asyncio.Handle] = None
self.fixed_peers_delay: typing.Optional[float] = None
self.added_fixed_peers = False
self.time_to_descriptor: typing.Optional[float] = None
self.time_to_first_bytes: typing.Optional[float] = None
async def setup(self): # start the peer accumulator and initialize the downloader
if self.blob_downloader:
raise Exception("downloader is already set up")
if self.node:
_, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue)
self.blob_downloader = BlobDownloader(self.loop, self.config, self.blob_manager, self.peer_queue)
self.search_queue.put_nowait(self.sd_hash)
async def after_got_descriptor(self):
self.search_queue.put_nowait(self.descriptor.blobs[0].blob_hash)
log.info("added head blob to search")
async def after_finished(self):
log.info("downloaded stream %s -> %s", self.sd_hash, self.output_path)
await self.blob_manager.storage.change_file_status(self.descriptor.stream_hash, 'finished')
self.blob_downloader.close()
def stop(self):
if self.accumulate_task:
self.accumulate_task.cancel()
self.accumulate_task = None
if self.assemble_task:
self.assemble_task.cancel()
self.assemble_task = None
if self.fixed_peers_handle:
self.fixed_peers_handle.cancel()
self.fixed_peers_handle = None
self.blob_downloader = None
if self.stream_handle:
if not self.stream_handle.closed:
self.stream_handle.close()
self.stream_handle = None
if not self.stream_finished_event.is_set() and self.wrote_bytes_event.is_set() and \
self.output_path and os.path.isfile(self.output_path):
os.remove(self.output_path)
async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
return await self.blob_downloader.download_blob(blob_hash, length)
def add_fixed_peers(self):
async def _add_fixed_peers():
addresses = [
(await resolve_host(url, port + 1, proto='tcp'), port)
for url, port in self.config.reflector_servers
]
async def add_fixed_peers(self):
def _delayed_add_fixed_peers():
self.added_fixed_peers = True
self.peer_queue.put_nowait([
@ -92,17 +44,89 @@ class StreamDownloader(StreamAssembler):
for address, port in addresses
])
self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _delayed_add_fixed_peers)
if not self.config.reflector_servers:
return
addresses = [
(await resolve_host(url, port + 1, proto='tcp'), port)
for url, port in self.config.reflector_servers
]
if 'dht' in self.config.components_to_skip or not self.node or not \
len(self.node.protocol.routing_table.get_peers()):
self.fixed_peers_delay = 0.0
else:
self.fixed_peers_delay = self.config.fixed_peer_delay
self.loop.create_task(_add_fixed_peers())
def download(self, node: typing.Optional['Node'] = None):
self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _delayed_add_fixed_peers)
async def load_descriptor(self):
# download or get the sd blob
sd_blob = self.blob_manager.get_blob(self.sd_hash)
if not sd_blob.get_is_verified():
try:
now = self.loop.time()
sd_blob = await asyncio.wait_for(
self.blob_downloader.download_blob(self.sd_hash),
self.config.blob_download_timeout, loop=self.loop
)
log.info("downloaded sd blob %s", self.sd_hash)
self.time_to_descriptor = self.loop.time() - now
except asyncio.TimeoutError:
raise DownloadSDTimeout(self.sd_hash)
# parse the descriptor
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
self.loop, self.blob_manager.blob_dir, sd_blob
)
log.info("loaded stream manifest %s", self.sd_hash)
async def start(self, node: typing.Optional['Node'] = None):
# set up peer accumulation
if node:
self.node = node
self.assemble_task = self.loop.create_task(self.assemble_decrypted_stream(self.config.download_dir))
self.add_fixed_peers()
_, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue)
await self.add_fixed_peers()
# start searching for peers for the sd hash
self.search_queue.put_nowait(self.sd_hash)
log.info("searching for peers for stream %s", self.sd_hash)
if not self.descriptor:
await self.load_descriptor()
# add the head blob to the peer search
self.search_queue.put_nowait(self.descriptor.blobs[0].blob_hash)
log.info("added head blob to peer search for stream %s", self.sd_hash)
if not await self.blob_manager.storage.stream_exists(self.sd_hash):
await self.blob_manager.storage.store_stream(
self.blob_manager.get_blob(self.sd_hash, length=self.descriptor.length), self.descriptor
)
async def download_stream_blob(self, blob_info: 'BlobInfo') -> 'AbstractBlob':
if not any(blob.blob_hash == blob_info.blob_hash for blob in self.descriptor.blobs[:-1]):
raise ValueError(f"blob {blob_info.blob_hash} is not part of stream with sd hash {self.sd_hash}")
blob = await self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length)
return blob
def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes:
return blob.decrypt(
binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode())
)
async def read_blob(self, blob_info: 'BlobInfo') -> bytes:
start = None
if self.time_to_first_bytes is None:
start = self.loop.time()
blob = await self.download_stream_blob(blob_info)
decrypted = self.decrypt_blob(blob_info, blob)
if start:
self.time_to_first_bytes = self.loop.time() - start
return decrypted
def stop(self):
if self.accumulate_task:
self.accumulate_task.cancel()
self.accumulate_task = None
if self.fixed_peers_handle:
self.fixed_peers_handle.cancel()
self.fixed_peers_handle = None
self.blob_downloader.close()
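Taken together, the new downloader exposes a start/read_blob/stop surface instead of assembling a file itself. A minimal, hypothetical driver (loop, config, blob_manager, node, sd_hash and consumer are assumed to exist) could look like:

async def stream_to_consumer(loop, config, blob_manager, node, sd_hash, consumer):
    downloader = StreamDownloader(loop, config, blob_manager, sd_hash)
    await downloader.start(node)   # fetches the sd blob and starts the peer search
    try:
        for blob_info in downloader.descriptor.blobs[:-1]:      # last entry is the stream terminator
            consumer(await downloader.read_blob(blob_info))     # decrypted bytes, one blob at a time
    finally:
        downloader.stop()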

View file

@ -10,41 +10,75 @@ from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.reflector.client import StreamReflectorClient
from lbrynet.extras.daemon.storage import StoredStreamClaim
if typing.TYPE_CHECKING:
from lbrynet.conf import Config
from lbrynet.schema.claim import Claim
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob.blob_info import BlobInfo
from lbrynet.dht.node import Node
from lbrynet.extras.daemon.analytics import AnalyticsManager
from lbrynet.wallet.transaction import Transaction
log = logging.getLogger(__name__)
def _get_next_available_file_name(download_directory: str, file_name: str) -> str:
base_name, ext = os.path.splitext(os.path.basename(file_name))
i = 0
while os.path.isfile(os.path.join(download_directory, file_name)):
i += 1
file_name = "%s_%i%s" % (base_name, i, ext)
return file_name
async def get_next_available_file_name(loop: asyncio.BaseEventLoop, download_directory: str, file_name: str) -> str:
return await loop.run_in_executor(None, _get_next_available_file_name, download_directory, file_name)
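For example (directory and file names are hypothetical), if video.mp4 and video_1.mp4 already exist in the download directory, the helper returns video_2.mp4; inside a coroutine:

name = await get_next_available_file_name(loop, "/home/lbry/downloads", "video.mp4")
# -> "video_2.mp4" when "video.mp4" and "video_1.mp4" are already present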
class ManagedStream:
STATUS_RUNNING = "running"
STATUS_STOPPED = "stopped"
STATUS_FINISHED = "finished"
def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', rowid: int,
descriptor: 'StreamDescriptor', download_directory: str, file_name: typing.Optional[str],
downloader: typing.Optional[StreamDownloader] = None,
def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager',
sd_hash: str, download_directory: typing.Optional[str] = None, file_name: typing.Optional[str] = None,
status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredStreamClaim] = None,
download_id: typing.Optional[str] = None):
download_id: typing.Optional[str] = None, rowid: typing.Optional[int] = None,
descriptor: typing.Optional[StreamDescriptor] = None,
content_fee: typing.Optional['Transaction'] = None,
analytics_manager: typing.Optional['AnalyticsManager'] = None):
self.loop = loop
self.config = config
self.blob_manager = blob_manager
self.rowid = rowid
self.sd_hash = sd_hash
self.download_directory = download_directory
self._file_name = file_name
self.descriptor = descriptor
self.downloader = downloader
self.stream_hash = descriptor.stream_hash
self.stream_claim_info = claim
self._status = status
self.fully_reflected = asyncio.Event(loop=self.loop)
self.tx = None
self.stream_claim_info = claim
self.download_id = download_id or binascii.hexlify(generate_id()).decode()
self.rowid = rowid
self.written_bytes = 0
self.content_fee = content_fee
self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor)
self.analytics_manager = analytics_manager
self.fully_reflected = asyncio.Event(loop=self.loop)
self.file_output_task: typing.Optional[asyncio.Task] = None
self.delayed_stop: typing.Optional[asyncio.Handle] = None
self.saving = asyncio.Event(loop=self.loop)
self.finished_writing = asyncio.Event(loop=self.loop)
self.started_writing = asyncio.Event(loop=self.loop)
@property
def descriptor(self) -> StreamDescriptor:
return self.downloader.descriptor
@property
def stream_hash(self) -> str:
return self.descriptor.stream_hash
@property
def file_name(self) -> typing.Optional[str]:
return self.downloader.output_file_name if self.downloader else self._file_name
return self._file_name or (self.descriptor.suggested_file_name if self.descriptor else None)
@property
def status(self) -> str:
@ -105,37 +139,35 @@ class ManagedStream:
@property
def blobs_completed(self) -> int:
return sum([1 if self.blob_manager.get_blob(b.blob_hash).get_is_verified() else 0
return sum([1 if self.blob_manager.is_blob_verified(b.blob_hash) else 0
for b in self.descriptor.blobs[:-1]])
@property
def blobs_in_stream(self) -> int:
return len(self.descriptor.blobs) - 1
@property
def sd_hash(self):
return self.descriptor.sd_hash
@property
def blobs_remaining(self) -> int:
return self.blobs_in_stream - self.blobs_completed
@property
def full_path(self) -> typing.Optional[str]:
return os.path.join(self.download_directory, os.path.basename(self.file_name)) if self.file_name else None
return os.path.join(self.download_directory, os.path.basename(self.file_name)) \
if self.file_name and self.download_directory else None
@property
def output_file_exists(self):
return os.path.isfile(self.full_path) if self.full_path else False
def as_dict(self) -> typing.Dict:
full_path = self.full_path if self.output_file_exists else None
mime_type = guess_media_type(os.path.basename(self.descriptor.suggested_file_name))[0]
@property
def mime_type(self):
return guess_media_type(os.path.basename(self.descriptor.suggested_file_name))[0]
if self.downloader and self.downloader.written_bytes:
written_bytes = self.downloader.written_bytes
elif full_path:
written_bytes = os.stat(full_path).st_size
def as_dict(self) -> typing.Dict:
if self.written_bytes:
written_bytes = self.written_bytes
elif self.output_file_exists:
written_bytes = os.stat(self.full_path).st_size
else:
written_bytes = None
return {
@ -143,14 +175,13 @@ class ManagedStream:
'file_name': self.file_name,
'download_directory': self.download_directory,
'points_paid': 0.0,
'tx': self.tx,
'stopped': not self.running,
'stream_hash': self.stream_hash,
'stream_name': self.descriptor.stream_name,
'suggested_file_name': self.descriptor.suggested_file_name,
'sd_hash': self.descriptor.sd_hash,
'download_path': full_path,
'mime_type': mime_type,
'download_path': self.full_path,
'mime_type': self.mime_type,
'key': self.descriptor.key,
'total_bytes_lower_bound': self.descriptor.lower_bound_decrypted_length(),
'total_bytes': self.descriptor.upper_bound_decrypted_length(),
@ -167,36 +198,135 @@ class ManagedStream:
'protobuf': self.metadata_protobuf,
'channel_claim_id': self.channel_claim_id,
'channel_name': self.channel_name,
'claim_name': self.claim_name
'claim_name': self.claim_name,
'content_fee': self.content_fee # TODO: this isn't in the database
}
@classmethod
async def create(cls, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager',
async def create(cls, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager',
file_path: str, key: typing.Optional[bytes] = None,
iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None) -> 'ManagedStream':
descriptor = await StreamDescriptor.create_stream(
loop, blob_manager.blob_dir, file_path, key=key, iv_generator=iv_generator
loop, blob_manager.blob_dir, file_path, key=key, iv_generator=iv_generator,
blob_completed_callback=blob_manager.blob_completed
)
sd_blob = blob_manager.get_blob(descriptor.sd_hash)
await blob_manager.storage.store_stream(
blob_manager.get_blob(descriptor.sd_hash), descriptor
)
await blob_manager.blob_completed(sd_blob)
for blob in descriptor.blobs[:-1]:
await blob_manager.blob_completed(blob_manager.get_blob(blob.blob_hash, blob.length))
row_id = await blob_manager.storage.save_published_file(descriptor.stream_hash, os.path.basename(file_path),
os.path.dirname(file_path), 0)
return cls(loop, blob_manager, row_id, descriptor, os.path.dirname(file_path), os.path.basename(file_path),
status=cls.STATUS_FINISHED)
return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path),
os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor)
def start_download(self, node: typing.Optional['Node']):
self.downloader.download(node)
self.update_status(self.STATUS_RUNNING)
async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True,
file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None):
await self.downloader.start(node)
if not save_file and not file_name:
if not await self.blob_manager.storage.file_exists(self.sd_hash):
self.rowid = await self.blob_manager.storage.save_downloaded_file(
self.stream_hash, None, None, 0.0
)
self.download_directory = None
self._file_name = None
self.update_status(ManagedStream.STATUS_RUNNING)
await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING)
self.update_delayed_stop()
elif not os.path.isfile(self.full_path):
await self.save_file(file_name, download_directory)
await self.started_writing.wait()
def update_delayed_stop(self):
def _delayed_stop():
log.info("Stopping inactive download for stream %s", self.sd_hash)
self.stop_download()
if self.delayed_stop:
self.delayed_stop.cancel()
self.delayed_stop = self.loop.call_later(60, _delayed_stop)
async def aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0) -> typing.AsyncIterator[
typing.Tuple['BlobInfo', bytes]]:
if start_blob_num >= len(self.descriptor.blobs[:-1]):
raise IndexError(start_blob_num)
for i, blob_info in enumerate(self.descriptor.blobs[start_blob_num:-1]):
assert i + start_blob_num == blob_info.blob_num
if self.delayed_stop:
self.delayed_stop.cancel()
try:
decrypted = await self.downloader.read_blob(blob_info)
yield (blob_info, decrypted)
except asyncio.CancelledError:
if not self.saving.is_set() and not self.finished_writing.is_set():
self.update_delayed_stop()
raise
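This async iterator is what makes range requests practical: a caller can start at an arbitrary blob and stream decrypted bytes as they arrive. A hedged sketch of a consumer (the response object is assumed, not part of this diff):

async def write_range(stream: ManagedStream, response, start_blob_num: int):
    async for blob_info, decrypted in stream.aiter_read_stream(start_blob_num=start_blob_num):
        await response.write(decrypted)   # e.g. an aiohttp.web.StreamResponse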
async def _save_file(self, output_path: str):
log.debug("save file %s -> %s", self.sd_hash, output_path)
self.saving.set()
self.finished_writing.clear()
self.started_writing.clear()
try:
with open(output_path, 'wb') as file_write_handle:
async for blob_info, decrypted in self.aiter_read_stream():
log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
file_write_handle.write(decrypted)
file_write_handle.flush()
self.written_bytes += len(decrypted)
if not self.started_writing.is_set():
self.started_writing.set()
self.update_status(ManagedStream.STATUS_FINISHED)
await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_FINISHED)
if self.analytics_manager:
self.loop.create_task(self.analytics_manager.send_download_finished(
self.download_id, self.claim_name, self.sd_hash
))
self.finished_writing.set()
except Exception as err:
if os.path.isfile(output_path):
log.info("removing incomplete download %s for %s", output_path, self.sd_hash)
os.remove(output_path)
if not isinstance(err, asyncio.CancelledError):
log.exception("unexpected error encountered writing file for stream %s", self.sd_hash)
raise err
finally:
self.saving.clear()
async def save_file(self, file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None):
if self.file_output_task and not self.file_output_task.done():
self.file_output_task.cancel()
if self.delayed_stop:
self.delayed_stop.cancel()
self.delayed_stop = None
self.download_directory = download_directory or self.download_directory or self.config.download_dir
if not self.download_directory:
raise ValueError("no directory to download to")
if not (file_name or self._file_name or self.descriptor.suggested_file_name):
raise ValueError("no file name to download to")
if not os.path.isdir(self.download_directory):
log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory)
os.mkdir(self.download_directory)
self._file_name = await get_next_available_file_name(
self.loop, self.download_directory,
file_name or self._file_name or self.descriptor.suggested_file_name
)
if not await self.blob_manager.storage.file_exists(self.sd_hash):
self.rowid = await self.blob_manager.storage.save_downloaded_file(
self.stream_hash, self.file_name, self.download_directory, 0.0
)
else:
await self.blob_manager.storage.change_file_download_dir_and_file_name(
self.stream_hash, self.download_directory, self.file_name
)
self.update_status(ManagedStream.STATUS_RUNNING)
await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING)
self.written_bytes = 0
self.file_output_task = self.loop.create_task(self._save_file(self.full_path))
def stop_download(self):
if self.downloader:
if self.file_output_task and not self.file_output_task.done():
self.file_output_task.cancel()
self.file_output_task = None
self.downloader.stop()
self.downloader = None
async def upload_to_reflector(self, host: str, port: int) -> typing.List[str]:
sent = []
@ -212,7 +342,9 @@ class ManagedStream:
self.fully_reflected.set()
await self.blob_manager.storage.update_reflected_stream(self.sd_hash, f"{host}:{port}")
return []
we_have = [blob_hash for blob_hash in needed if blob_hash in self.blob_manager.completed_blob_hashes]
we_have = [
blob_hash for blob_hash in needed if blob_hash in self.blob_manager.completed_blob_hashes
]
for blob_hash in we_have:
await protocol.send_blob(blob_hash)
sent.append(blob_hash)

View file

@ -4,7 +4,7 @@ import logging
import typing
if typing.TYPE_CHECKING:
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.stream.descriptor import StreamDescriptor
REFLECTOR_V1 = 0
@ -14,7 +14,7 @@ log = logging.getLogger(__name__)
class StreamReflectorClient(asyncio.Protocol):
def __init__(self, blob_manager: 'BlobFileManager', descriptor: 'StreamDescriptor'):
def __init__(self, blob_manager: 'BlobManager', descriptor: 'StreamDescriptor'):
self.loop = asyncio.get_event_loop()
self.transport: asyncio.StreamWriter = None
self.blob_manager = blob_manager
@ -64,7 +64,7 @@ class StreamReflectorClient(asyncio.Protocol):
async def send_descriptor(self) -> typing.Tuple[bool, typing.List[str]]: # returns a list of needed blob hashes
sd_blob = self.blob_manager.get_blob(self.descriptor.sd_hash)
assert sd_blob.get_is_verified(), "need to have a sd blob to send at this point"
assert self.blob_manager.is_blob_verified(self.descriptor.sd_hash), "need to have sd blob to send at this point"
response = await self.send_request({
'sd_blob_hash': sd_blob.blob_hash,
'sd_blob_size': sd_blob.length
@ -80,7 +80,7 @@ class StreamReflectorClient(asyncio.Protocol):
sent_sd = True
if not needed:
for blob in self.descriptor.blobs[:-1]:
if self.blob_manager.get_blob(blob.blob_hash, blob.length).get_is_verified():
if self.blob_manager.is_blob_verified(blob.blob_hash, blob.length):
needed.append(blob.blob_hash)
log.info("Sent reflector descriptor %s", sd_blob.blob_hash[:8])
self.reflected_blobs.append(sd_blob.blob_hash)
@ -91,8 +91,8 @@ class StreamReflectorClient(asyncio.Protocol):
return sent_sd, needed
async def send_blob(self, blob_hash: str):
assert self.blob_manager.is_blob_verified(blob_hash), "need to have a blob to send at this point"
blob = self.blob_manager.get_blob(blob_hash)
assert blob.get_is_verified(), "need to have a blob to send at this point"
response = await self.send_request({
'blob_hash': blob.blob_hash,
'blob_size': blob.length

View file

@ -7,7 +7,7 @@ from lbrynet.stream.descriptor import StreamDescriptor
if typing.TYPE_CHECKING:
from lbrynet.blob.blob_file import BlobFile
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob.writer import HashBlobWriter
@ -15,7 +15,7 @@ log = logging.getLogger(__name__)
class ReflectorServerProtocol(asyncio.Protocol):
def __init__(self, blob_manager: 'BlobFileManager'):
def __init__(self, blob_manager: 'BlobManager'):
self.loop = asyncio.get_event_loop()
self.blob_manager = blob_manager
self.server_task: asyncio.Task = None
@ -63,11 +63,11 @@ class ReflectorServerProtocol(asyncio.Protocol):
return
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
if not self.sd_blob.get_is_verified():
self.writer = self.sd_blob.open_for_writing()
self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername'))
self.incoming.set()
self.send_response({"send_sd_blob": True})
try:
await asyncio.wait_for(self.sd_blob.finished_writing.wait(), 30, loop=self.loop)
await asyncio.wait_for(self.sd_blob.verified.wait(), 30, loop=self.loop)
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
self.loop, self.blob_manager.blob_dir, self.sd_blob
)
@ -102,11 +102,11 @@ class ReflectorServerProtocol(asyncio.Protocol):
return
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
if not blob.get_is_verified():
self.writer = blob.open_for_writing()
self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername'))
self.incoming.set()
self.send_response({"send_blob": True})
try:
await asyncio.wait_for(blob.finished_writing.wait(), 30, loop=self.loop)
await asyncio.wait_for(blob.verified.wait(), 30, loop=self.loop)
self.send_response({"received_blob": True})
except asyncio.TimeoutError:
self.send_response({"received_blob": False})
@ -121,7 +121,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
class ReflectorServer:
def __init__(self, blob_manager: 'BlobFileManager'):
def __init__(self, blob_manager: 'BlobManager'):
self.loop = asyncio.get_event_loop()
self.blob_manager = blob_manager
self.server_task: asyncio.Task = None

View file

@ -5,18 +5,17 @@ import binascii
import logging
import random
from decimal import Decimal
from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError, \
DownloadDataTimeout, DownloadSDTimeout
from lbrynet.utils import generate_id
from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError
from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout, ResolveTimeout
from lbrynet.utils import cache_concurrent
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.downloader import StreamDownloader
from lbrynet.stream.managed_stream import ManagedStream
from lbrynet.schema.claim import Claim
from lbrynet.schema.uri import parse_lbry_uri
from lbrynet.extras.daemon.storage import lbc_to_dewies
if typing.TYPE_CHECKING:
from lbrynet.conf import Config
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.dht.node import Node
from lbrynet.extras.daemon.analytics import AnalyticsManager
from lbrynet.extras.daemon.storage import SQLiteStorage, StoredStreamClaim
@ -54,8 +53,11 @@ comparison_operators = {
}
def path_or_none(p) -> typing.Optional[str]:
return None if p == '{stream}' else binascii.unhexlify(p).decode()
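The '{stream}' sentinel marks database rows for streams that keep no file on disk; otherwise path_or_none decodes the hex-encoded path. Illustrative values (the file name is hypothetical):

path_or_none('{stream}')                                 # -> None
path_or_none(binascii.hexlify(b'video.mp4').decode())    # -> 'video.mp4'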
class StreamManager:
def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobFileManager',
def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager',
wallet: 'LbryWalletManager', storage: 'SQLiteStorage', node: typing.Optional['Node'],
analytics_manager: typing.Optional['AnalyticsManager'] = None):
self.loop = loop
@ -65,8 +67,7 @@ class StreamManager:
self.storage = storage
self.node = node
self.analytics_manager = analytics_manager
self.streams: typing.Set[ManagedStream] = set()
self.starting_streams: typing.Dict[str, asyncio.Future] = {}
self.streams: typing.Dict[str, ManagedStream] = {}
self.resume_downloading_task: asyncio.Task = None
self.re_reflect_task: asyncio.Task = None
self.update_stream_finished_futs: typing.List[asyncio.Future] = []
@ -76,46 +77,6 @@ class StreamManager:
claim_info = await self.storage.get_content_claim(stream.stream_hash)
stream.set_claim(claim_info, claim_info['value'])
async def start_stream(self, stream: ManagedStream) -> bool:
"""
Resume or rebuild a partial or completed stream
"""
if not stream.running and not stream.output_file_exists:
if stream.downloader:
stream.downloader.stop()
stream.downloader = None
# the directory is gone, can happen when the folder that contains a published file is deleted
# reset the download directory to the default and update the file name
if not os.path.isdir(stream.download_directory):
stream.download_directory = self.config.download_dir
stream.downloader = self.make_downloader(
stream.sd_hash, stream.download_directory, stream.descriptor.suggested_file_name
)
if stream.status != ManagedStream.STATUS_FINISHED:
await self.storage.change_file_status(stream.stream_hash, 'running')
stream.update_status('running')
stream.start_download(self.node)
try:
await asyncio.wait_for(self.loop.create_task(stream.downloader.wrote_bytes_event.wait()),
self.config.download_timeout)
except asyncio.TimeoutError:
await self.stop_stream(stream)
if stream in self.streams:
self.streams.remove(stream)
return False
file_name = os.path.basename(stream.downloader.output_path)
output_dir = os.path.dirname(stream.downloader.output_path)
await self.storage.change_file_download_dir_and_file_name(
stream.stream_hash, output_dir, file_name
)
stream._file_name = file_name
stream.download_directory = output_dir
self.wait_for_stream_finished(stream)
return True
return True
async def stop_stream(self, stream: ManagedStream):
stream.stop_download()
if not stream.finished and stream.output_file_exists:
@ -128,10 +89,11 @@ class StreamManager:
stream.update_status(ManagedStream.STATUS_STOPPED)
await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED)
def make_downloader(self, sd_hash: str, download_directory: str, file_name: str):
return StreamDownloader(
self.loop, self.config, self.blob_manager, sd_hash, download_directory, file_name
)
async def start_stream(self, stream: ManagedStream):
stream.update_status(ManagedStream.STATUS_RUNNING)
await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_RUNNING)
await stream.setup(self.node, save_file=self.config.save_files)
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
async def recover_streams(self, file_infos: typing.List[typing.Dict]):
to_restore = []
@ -156,81 +118,88 @@ class StreamManager:
if to_restore:
await self.storage.recover_streams(to_restore, self.config.download_dir)
log.info("Recovered %i/%i attempted streams", len(to_restore), len(file_infos))
async def add_stream(self, rowid: int, sd_hash: str, file_name: str, download_directory: str, status: str,
# if self.blob_manager._save_blobs:
# log.info("Recovered %i/%i attempted streams", len(to_restore), len(file_infos))
async def add_stream(self, rowid: int, sd_hash: str, file_name: typing.Optional[str],
download_directory: typing.Optional[str], status: str,
claim: typing.Optional['StoredStreamClaim']):
sd_blob = self.blob_manager.get_blob(sd_hash)
if not sd_blob.get_is_verified():
return
try:
descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash)
descriptor = await self.blob_manager.get_stream_descriptor(sd_hash)
except InvalidStreamDescriptorError as err:
log.warning("Failed to start stream for sd %s - %s", sd_hash, str(err))
return
if status == ManagedStream.STATUS_RUNNING:
downloader = self.make_downloader(descriptor.sd_hash, download_directory, file_name)
else:
downloader = None
stream = ManagedStream(
self.loop, self.blob_manager, rowid, descriptor, download_directory, file_name, downloader, status, claim
self.loop, self.config, self.blob_manager, descriptor.sd_hash, download_directory, file_name, status,
claim, rowid=rowid, descriptor=descriptor, analytics_manager=self.analytics_manager
)
self.streams.add(stream)
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
self.streams[sd_hash] = stream
async def load_streams_from_database(self):
to_recover = []
to_start = []
await self.storage.sync_files_to_blobs()
for file_info in await self.storage.get_all_lbry_files():
if not self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified():
if not self.blob_manager.is_blob_verified(file_info['sd_hash']):
to_recover.append(file_info)
to_start.append(file_info)
if to_recover:
log.info("Attempting to recover %i streams", len(to_recover))
# if self.blob_manager._save_blobs:
# log.info("Attempting to recover %i streams", len(to_recover))
await self.recover_streams(to_recover)
to_start = []
for file_info in await self.storage.get_all_lbry_files():
if self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified():
to_start.append(file_info)
log.info("Initializing %i files", len(to_start))
if not self.config.save_files:
to_set_as_streaming = []
for file_info in to_start:
file_name = path_or_none(file_info['file_name'])
download_dir = path_or_none(file_info['download_directory'])
if file_name and download_dir and not os.path.isfile(os.path.join(download_dir, file_name)):
file_info['file_name'], file_info['download_directory'] = '{stream}', '{stream}'
to_set_as_streaming.append(file_info['stream_hash'])
if to_set_as_streaming:
await self.storage.set_files_as_streaming(to_set_as_streaming)
log.info("Initializing %i files", len(to_start))
if to_start:
await asyncio.gather(*[
self.add_stream(
file_info['rowid'], file_info['sd_hash'], binascii.unhexlify(file_info['file_name']).decode(),
binascii.unhexlify(file_info['download_directory']).decode(), file_info['status'],
file_info['rowid'], file_info['sd_hash'], path_or_none(file_info['file_name']),
path_or_none(file_info['download_directory']), file_info['status'],
file_info['claim']
) for file_info in to_start
])
log.info("Started stream manager with %i files", len(self.streams))
async def resume(self):
if self.node:
await self.node.joined.wait()
else:
if not self.node:
log.warning("no DHT node given, resuming downloads trusting that we can contact reflector")
t = [
(stream.start_download(self.node), self.wait_for_stream_finished(stream))
for stream in self.streams if stream.status == ManagedStream.STATUS_RUNNING
self.loop.create_task(self.start_stream(stream)) for stream in self.streams.values()
if stream.running
]
if t:
log.info("resuming %i downloads", len(t))
await asyncio.gather(*t, loop=self.loop)
async def reflect_streams(self):
while True:
if self.config.reflect_streams and self.config.reflector_servers:
sd_hashes = await self.storage.get_streams_to_re_reflect()
streams = list(filter(lambda s: s.sd_hash in sd_hashes, self.streams))
sd_hashes = [sd for sd in sd_hashes if sd in self.streams]
batch = []
while streams:
stream = streams.pop()
while sd_hashes:
stream = self.streams[sd_hashes.pop()]
if self.blob_manager.is_blob_verified(stream.sd_hash) and stream.blobs_completed:
if not stream.fully_reflected.is_set():
host, port = random.choice(self.config.reflector_servers)
batch.append(stream.upload_to_reflector(host, port))
if len(batch) >= self.config.concurrent_reflector_uploads:
await asyncio.gather(*batch)
await asyncio.gather(*batch, loop=self.loop)
batch = []
if batch:
await asyncio.gather(*batch)
await asyncio.gather(*batch, loop=self.loop)
await asyncio.sleep(300, loop=self.loop)
async def start(self):
@ -244,7 +213,7 @@ class StreamManager:
if self.re_reflect_task and not self.re_reflect_task.done():
self.re_reflect_task.cancel()
while self.streams:
stream = self.streams.pop()
_, stream = self.streams.popitem()
stream.stop_download()
while self.update_stream_finished_futs:
self.update_stream_finished_futs.pop().cancel()
@ -253,8 +222,8 @@ class StreamManager:
async def create_stream(self, file_path: str, key: typing.Optional[bytes] = None,
iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None) -> ManagedStream:
stream = await ManagedStream.create(self.loop, self.blob_manager, file_path, key, iv_generator)
self.streams.add(stream)
stream = await ManagedStream.create(self.loop, self.config, self.blob_manager, file_path, key, iv_generator)
self.streams[stream.sd_hash] = stream
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
if self.config.reflect_streams and self.config.reflector_servers:
host, port = random.choice(self.config.reflector_servers)
@ -268,8 +237,8 @@ class StreamManager:
async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False):
await self.stop_stream(stream)
if stream in self.streams:
self.streams.remove(stream)
if stream.sd_hash in self.streams:
del self.streams[stream.sd_hash]
blob_hashes = [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]]
await self.blob_manager.delete_blobs(blob_hashes, delete_from_db=False)
await self.storage.delete_stream(stream.descriptor)
@ -277,7 +246,7 @@ class StreamManager:
os.remove(stream.full_path)
def get_stream_by_stream_hash(self, stream_hash: str) -> typing.Optional[ManagedStream]:
streams = tuple(filter(lambda stream: stream.stream_hash == stream_hash, self.streams))
streams = tuple(filter(lambda stream: stream.stream_hash == stream_hash, self.streams.values()))
if streams:
return streams[0]
@ -304,53 +273,24 @@ class StreamManager:
if search_by:
comparison = comparison or 'eq'
streams = []
for stream in self.streams:
for stream in self.streams.values():
for search, val in search_by.items():
if comparison_operators[comparison](getattr(stream, search), val):
streams.append(stream)
break
else:
streams = list(self.streams)
streams = list(self.streams.values())
if sort_by:
streams.sort(key=lambda s: getattr(s, sort_by))
if reverse:
streams.reverse()
return streams
def wait_for_stream_finished(self, stream: ManagedStream):
async def _wait_for_stream_finished():
if stream.downloader and stream.running:
await stream.downloader.stream_finished_event.wait()
stream.update_status(ManagedStream.STATUS_FINISHED)
if self.analytics_manager:
self.loop.create_task(self.analytics_manager.send_download_finished(
stream.download_id, stream.claim_name, stream.sd_hash
))
task = self.loop.create_task(_wait_for_stream_finished())
self.update_stream_finished_futs.append(task)
task.add_done_callback(
lambda _: None if task not in self.update_stream_finished_futs else
self.update_stream_finished_futs.remove(task)
)
async def _store_stream(self, downloader: StreamDownloader) -> int:
file_name = os.path.basename(downloader.output_path)
download_directory = os.path.dirname(downloader.output_path)
if not await self.storage.stream_exists(downloader.sd_hash):
await self.storage.store_stream(downloader.sd_blob, downloader.descriptor)
if not await self.storage.file_exists(downloader.sd_hash):
return await self.storage.save_downloaded_file(
downloader.descriptor.stream_hash, file_name, download_directory,
0.0
)
else:
return await self.storage.rowid_for_stream(downloader.descriptor.stream_hash)
async def _check_update_or_replace(self, outpoint: str, claim_id: str, claim: Claim) -> typing.Tuple[
typing.Optional[ManagedStream], typing.Optional[ManagedStream]]:
existing = self.get_filtered_streams(outpoint=outpoint)
if existing:
if existing[0].status == ManagedStream.STATUS_STOPPED:
await self.start_stream(existing[0])
return existing[0], None
existing = self.get_filtered_streams(sd_hash=claim.stream.source.sd_hash)
@ -363,6 +303,7 @@ class StreamManager:
existing[0].stream_hash, outpoint
)
await self._update_content_claim(existing[0])
if not existing[0].running:
await self.start_stream(existing[0])
return existing[0], None
else:
@ -372,35 +313,32 @@ class StreamManager:
return None, existing_for_claim_id[0]
return None, None
async def start_downloader(self, got_descriptor_time: asyncio.Future, downloader: StreamDownloader,
download_id: str, outpoint: str, claim: Claim, resolved: typing.Dict,
file_name: typing.Optional[str] = None) -> ManagedStream:
start_time = self.loop.time()
downloader.download(self.node)
await downloader.got_descriptor.wait()
got_descriptor_time.set_result(self.loop.time() - start_time)
rowid = await self._store_stream(downloader)
await self.storage.save_content_claim(
downloader.descriptor.stream_hash, outpoint
)
stream = ManagedStream(self.loop, self.blob_manager, rowid, downloader.descriptor, self.config.download_dir,
file_name, downloader, ManagedStream.STATUS_RUNNING, download_id=download_id)
stream.set_claim(resolved, claim)
await stream.downloader.wrote_bytes_event.wait()
self.streams.add(stream)
return stream
async def _download_stream_from_uri(self, uri, timeout: float, exchange_rate_manager: 'ExchangeRateManager',
file_name: typing.Optional[str] = None) -> ManagedStream:
@cache_concurrent
async def download_stream_from_uri(self, uri, exchange_rate_manager: 'ExchangeRateManager',
timeout: typing.Optional[float] = None,
file_name: typing.Optional[str] = None,
download_directory: typing.Optional[str] = None,
save_file: bool = True, resolve_timeout: float = 3.0) -> ManagedStream:
timeout = timeout or self.config.download_timeout
start_time = self.loop.time()
resolved_time = None
stream = None
error = None
outpoint = None
try:
# resolve the claim
parsed_uri = parse_lbry_uri(uri)
if parsed_uri.is_channel:
raise ResolveError("cannot download a channel claim, specify a /path")
# resolve the claim
resolved = (await self.wallet.ledger.resolve(0, 10, uri)).get(uri, {})
try:
resolved_result = await asyncio.wait_for(self.wallet.ledger.resolve(0, 1, uri), resolve_timeout)
except asyncio.TimeoutError:
raise ResolveTimeout(uri)
await self.storage.save_claims_for_resolve([
value for value in resolved_result.values() if 'error' not in value
])
resolved = resolved_result.get(uri, {})
resolved = resolved if 'value' in resolved else resolved.get('claim')
if not resolved:
raise ResolveError(f"Failed to resolve stream at '{uri}'")
if 'error' in resolved:
@ -415,9 +353,10 @@ class StreamManager:
if updated_stream:
return updated_stream
content_fee = None
# check that the fee is payable
fee_amount, fee_address = None, None
if claim.stream.has_fee:
if not to_replace and claim.stream.has_fee:
fee_amount = round(exchange_rate_manager.convert_currency(
claim.stream.fee.currency, "LBC", claim.stream.fee.amount
), 5)
@ -435,79 +374,58 @@ class StreamManager:
raise InsufficientFundsError(msg)
fee_address = claim.stream.fee.address
# download the stream
download_id = binascii.hexlify(generate_id()).decode()
downloader = StreamDownloader(self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash,
self.config.download_dir, file_name)
stream = None
descriptor_time_fut = self.loop.create_future()
start_download_time = self.loop.time()
time_to_descriptor = None
time_to_first_bytes = None
error = None
try:
stream = await asyncio.wait_for(
asyncio.ensure_future(
self.start_downloader(descriptor_time_fut, downloader, download_id, outpoint, claim, resolved,
file_name)
), timeout
content_fee = await self.wallet.send_amount_to_address(
lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1')
)
time_to_descriptor = await descriptor_time_fut
time_to_first_bytes = self.loop.time() - start_download_time - time_to_descriptor
self.wait_for_stream_finished(stream)
if fee_address and fee_amount and not to_replace:
stream.tx = await self.wallet.send_amount_to_address(
lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1'))
elif to_replace: # delete old stream now that the replacement has started downloading
await self.delete_stream(to_replace)
log.info("paid fee of %s for %s", fee_amount, uri)
download_directory = download_directory or self.config.download_dir
if not file_name and (not self.config.save_files or not save_file):
download_directory, file_name = None, None
stream = ManagedStream(
self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_directory,
file_name, ManagedStream.STATUS_RUNNING, content_fee=content_fee,
analytics_manager=self.analytics_manager
)
log.info("starting download for %s", uri)
try:
await asyncio.wait_for(stream.setup(
self.node, save_file=save_file, file_name=file_name, download_directory=download_directory
), timeout, loop=self.loop)
except asyncio.TimeoutError:
if descriptor_time_fut.done():
time_to_descriptor = descriptor_time_fut.result()
error = DownloadDataTimeout(downloader.sd_hash)
self.blob_manager.delete_blob(downloader.sd_hash)
await self.storage.delete_stream(downloader.descriptor)
else:
descriptor_time_fut.cancel()
error = DownloadSDTimeout(downloader.sd_hash)
if stream:
await self.stop_stream(stream)
else:
downloader.stop()
if error:
log.warning(error)
if self.analytics_manager:
if not stream.descriptor:
raise DownloadSDTimeout(stream.sd_hash)
raise DownloadDataTimeout(stream.sd_hash)
if to_replace: # delete old stream now that the replacement has started downloading
await self.delete_stream(to_replace)
stream.set_claim(resolved, claim)
await self.storage.save_content_claim(stream.stream_hash, outpoint)
self.streams[stream.sd_hash] = stream
return stream
except Exception as err:
error = err
if stream and stream.descriptor:
await self.storage.delete_stream(stream.descriptor)
await self.blob_manager.delete_blob(stream.sd_hash)
finally:
if self.analytics_manager and (error or (stream and (stream.downloader.time_to_descriptor or
stream.downloader.time_to_first_bytes))):
self.loop.create_task(
self.analytics_manager.send_time_to_first_bytes(
resolved_time, self.loop.time() - start_time, download_id, parse_lbry_uri(uri).name, outpoint,
resolved_time, self.loop.time() - start_time, None if not stream else stream.download_id,
uri, outpoint,
None if not stream else len(stream.downloader.blob_downloader.active_connections),
None if not stream else len(stream.downloader.blob_downloader.scores),
False if not downloader else downloader.added_fixed_peers,
self.config.fixed_peer_delay if not downloader else downloader.fixed_peers_delay,
claim.stream.source.sd_hash, time_to_descriptor,
False if not stream else stream.downloader.added_fixed_peers,
self.config.fixed_peer_delay if not stream else stream.downloader.fixed_peers_delay,
None if not stream else stream.sd_hash,
None if not stream else stream.downloader.time_to_descriptor,
None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].blob_hash,
None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].length,
time_to_first_bytes, None if not error else error.__class__.__name__
None if not stream else stream.downloader.time_to_first_bytes,
None if not error else error.__class__.__name__
)
)
if error:
raise error
return stream
async def download_stream_from_uri(self, uri, exchange_rate_manager: 'ExchangeRateManager',
file_name: typing.Optional[str] = None,
timeout: typing.Optional[float] = None) -> ManagedStream:
timeout = timeout or self.config.download_timeout
if uri in self.starting_streams:
return await self.starting_streams[uri]
fut = asyncio.Future(loop=self.loop)
self.starting_streams[uri] = fut
try:
stream = await self._download_stream_from_uri(uri, timeout, exchange_rate_manager, file_name)
fut.set_result(stream)
except Exception as err:
fut.set_exception(err)
try:
return await fut
finally:
del self.starting_streams[uri]

View file

@ -18,7 +18,7 @@ from lbrynet.extras.daemon.Components import (
)
from lbrynet.extras.daemon.ComponentManager import ComponentManager
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.stream.reflector.server import ReflectorServer
from lbrynet.blob_exchange.server import BlobServer
@ -107,9 +107,11 @@ class CommandTestCase(IntegrationTestCase):
server_tmp_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, server_tmp_dir)
self.server_storage = SQLiteStorage(Config(), ':memory:')
self.server_config = Config()
self.server_storage = SQLiteStorage(self.server_config, ':memory:')
await self.server_storage.open()
self.server_blob_manager = BlobFileManager(self.loop, server_tmp_dir, self.server_storage)
self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_config)
self.server = BlobServer(self.loop, self.server_blob_manager, 'bQEaw42GXsgCAGio1nxFncJSyRmnztSCjP')
self.server.start_server(5567, '127.0.0.1')
await self.server.started_listening.wait()

View file

@ -14,6 +14,7 @@ import pkg_resources
import contextlib
import certifi
import aiohttp
import functools
from lbrynet.schema.claim import Claim
from lbrynet.cryptoutils import get_lbry_hash_obj
@ -146,6 +147,44 @@ def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
cancel_task(tasks.pop())
def async_timed_cache(duration: int):
def wrapper(fn):
cache: typing.Dict[typing.Tuple,
typing.Tuple[typing.Any, float]] = {}
@functools.wraps(fn)
async def _inner(*args, **kwargs) -> typing.Any:
loop = asyncio.get_running_loop()
now = loop.time()
key = tuple([args, tuple([tuple([k, kwargs[k]]) for k in kwargs])])
if key in cache and (now - cache[key][1] < duration):
return cache[key][0]
to_cache = await fn(*args, **kwargs)
cache[key] = to_cache, now
return to_cache
return _inner
return wrapper
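A hedged usage sketch: results are memoized per argument tuple for duration seconds, so repeated lookups within that window skip the awaited call entirely (the resolver below is a hypothetical stand-in):

@async_timed_cache(300)
async def cached_lookup(host: str, port: int) -> str:
    return await hypothetical_resolver(host, port)   # stand-in for real network I/O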
def cache_concurrent(async_fn):
"""
When the decorated function has concurrent calls made to it with the same arguments, only run it once
"""
cache: typing.Dict = {}
@functools.wraps(async_fn)
async def wrapper(*args, **kwargs):
key = tuple([args, tuple([tuple([k, kwargs[k]]) for k in kwargs])])
cache[key] = cache.get(key) or asyncio.create_task(async_fn(*args, **kwargs))
try:
return await cache[key]
finally:
cache.pop(key, None)
return wrapper
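By contrast, cache_concurrent only deduplicates calls that are in flight at the same time; once the shared task finishes, the cache entry is dropped. A small illustration (the fetch body is a stand-in):

@cache_concurrent
async def fetch(uri: str) -> str:
    await asyncio.sleep(1)          # stand-in for resolving/downloading
    return uri.upper()

async def demo():
    a, b = await asyncio.gather(fetch("lbry://foo"), fetch("lbry://foo"))
    # the decorated body ran once; both awaits received the same result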
@async_timed_cache(300)
async def resolve_host(url: str, port: int, proto: str) -> str:
if proto not in ['udp', 'tcp']:
raise Exception("invalid protocol")

View file

@ -5,7 +5,7 @@ import socket
import ipaddress
from lbrynet.conf import Config
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob_exchange.client import BlobExchangeClientProtocol, request_blob
import logging
@ -32,7 +32,7 @@ async def main(blob_hash: str, url: str):
host = host_info[0][4][0]
storage = SQLiteStorage(conf, os.path.join(conf.data_dir, "lbrynet.sqlite"))
blob_manager = BlobFileManager(loop, os.path.join(conf.data_dir, "blobfiles"), storage)
blob_manager = BlobManager(loop, os.path.join(conf.data_dir, "blobfiles"), storage)
await storage.open()
await blob_manager.setup()

View file

@ -325,6 +325,11 @@ class Examples(CommandTestCase):
'get', file_uri
)
await r(
'Save a file to the downloads directory',
'file', 'save', f"--sd_hash=\"{file_list_result[0]['sd_hash']}\""
)
# blobs
bloblist = await r(

View file

@ -1,7 +1,7 @@
import sys
import os
import asyncio
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob_exchange.server import BlobServer
from lbrynet.schema.address import decode_address
from lbrynet.extras.daemon.storage import SQLiteStorage
@ -17,7 +17,7 @@ async def main(address: str):
storage = SQLiteStorage(os.path.expanduser("~/.lbrynet/lbrynet.sqlite"))
await storage.open()
blob_manager = BlobFileManager(loop, os.path.expanduser("~/.lbrynet/blobfiles"), storage)
blob_manager = BlobManager(loop, os.path.expanduser("~/.lbrynet/blobfiles"), storage)
await blob_manager.setup()
server = await loop.create_server(

View file

@ -61,6 +61,7 @@ def mock_network_loop(loop: asyncio.BaseEventLoop):
protocol = proto_lam()
transport = asyncio.DatagramTransport(extra={'socket': mock_sock})
transport.is_closing = lambda: False
transport.close = lambda: mock_sock.close()
mock_sock.sendto = sendto
transport.sendto = mock_sock.sendto

View file

@ -518,8 +518,12 @@ class StreamCommands(CommandTestCase):
file.flush()
tx1 = await self.publish('foo', bid='1.0', file_path=file.name)
self.assertEqual(1, len(self.daemon.jsonrpc_file_list()))
# doesn't error on missing arguments when doing an update stream
tx2 = await self.publish('foo', tags='updated')
self.assertEqual(1, len(self.daemon.jsonrpc_file_list()))
self.assertEqual(
tx1['outputs'][0]['claim_id'],
tx2['outputs'][0]['claim_id']
@ -530,12 +534,14 @@ class StreamCommands(CommandTestCase):
with self.assertRaisesRegex(Exception, "There are 2 claims for 'foo'"):
await self.daemon.jsonrpc_publish('foo')
self.assertEqual(2, len(self.daemon.jsonrpc_file_list()))
# abandon duplicate stream
await self.stream_abandon(tx3['outputs'][0]['claim_id'])
# publish to a channel
await self.channel_create('@abc')
tx3 = await self.publish('foo', channel_name='@abc')
self.assertEqual(2, len(self.daemon.jsonrpc_file_list()))
r = await self.resolve('lbry://@abc/foo')
self.assertEqual(
r['lbry://@abc/foo']['claim']['claim_id'],
@ -544,6 +550,7 @@ class StreamCommands(CommandTestCase):
# publishing again re-signs with the same channel
tx4 = await self.publish('foo', languages='uk-UA')
self.assertEqual(2, len(self.daemon.jsonrpc_file_list()))
r = await self.resolve('lbry://@abc/foo')
claim = r['lbry://@abc/foo']['claim']
self.assertEqual(claim['txid'], tx4['outputs'][0]['txid'])

View file

@ -36,12 +36,12 @@ class FileCommands(CommandTestCase):
await self.server.blob_manager.delete_blobs(all_except_sd)
resp = await self.daemon.jsonrpc_get('lbry://foo', timeout=2)
self.assertIn('error', resp)
self.assertEquals('Failed to download data blobs for sd hash %s within timeout' % sd_hash, resp['error'])
self.assertEqual('Failed to download data blobs for sd hash %s within timeout' % sd_hash, resp['error'])
await self.daemon.jsonrpc_file_delete(claim_name='foo')
await self.server.blob_manager.delete_blobs([sd_hash])
resp = await self.daemon.jsonrpc_get('lbry://foo', timeout=2)
self.assertIn('error', resp)
self.assertEquals('Failed to download sd blob %s within timeout' % sd_hash, resp['error'])
self.assertEqual('Failed to download sd blob %s within timeout' % sd_hash, resp['error'])
async def wait_files_to_complete(self):
while self.sout(self.daemon.jsonrpc_file_list(status='running')):
@ -59,17 +59,14 @@ class FileCommands(CommandTestCase):
await self.daemon.stream_manager.start()
await asyncio.wait_for(self.wait_files_to_complete(), timeout=5)  # if this hangs, the file didn't get marked completed
# check that internal state got through up to the file list API
downloader = self.daemon.stream_manager.get_stream_by_stream_hash(file_info['stream_hash']).downloader
file_info = self.sout(self.daemon.jsonrpc_file_list())[0]
self.assertEqual(downloader.output_file_name, file_info['file_name'])
stream = self.daemon.stream_manager.get_stream_by_stream_hash(file_info['stream_hash'])
file_info = self.sout(self.daemon.jsonrpc_file_list()[0])
self.assertEqual(stream.file_name, file_info['file_name'])
# check that what the API shows matches what we have at the internal level
self.assertEqual(downloader.output_path, file_info['download_path'])
# if a refactor breaks this, adjust the assertion above, but make sure what gets set internally is still reflected externally
self.assertTrue(downloader.output_path.endswith(downloader.output_file_name))
# this used to be inconsistent; if it becomes inconsistent again it would create weird bugs, so it's worth checking
self.assertEqual(stream.full_path, file_info['download_path'])
async def test_incomplete_downloads_erases_output_file_on_stop(self):
tx = await self.stream_create('foo', '0.01')
tx = await self.stream_create('foo', '0.01', data=b'deadbeef' * 1000000)
sd_hash = tx['outputs'][0]['value']['source']['sd_hash']
file_info = self.sout(self.daemon.jsonrpc_file_list())[0]
await self.daemon.jsonrpc_file_delete(claim_name='foo')
@ -77,25 +74,27 @@ class FileCommands(CommandTestCase):
await self.server_storage.get_stream_hash_for_sd_hash(sd_hash)
)
all_except_sd_and_head = [
blob.blob_hash for blob in blobs[1:] if blob.blob_hash
blob.blob_hash for blob in blobs[1:-1]
]
await self.server.blob_manager.delete_blobs(all_except_sd_and_head)
self.assertFalse(os.path.isfile(os.path.join(self.daemon.conf.download_dir, file_info['file_name'])))
path = os.path.join(self.daemon.conf.download_dir, file_info['file_name'])
self.assertFalse(os.path.isfile(path))
resp = await self.out(self.daemon.jsonrpc_get('lbry://foo', timeout=2))
self.assertNotIn('error', resp)
self.assertTrue(os.path.isfile(os.path.join(self.daemon.conf.download_dir, file_info['file_name'])))
self.assertTrue(os.path.isfile(path))
self.daemon.stream_manager.stop()
self.assertFalse(os.path.isfile(os.path.join(self.daemon.conf.download_dir, file_info['file_name'])))
await asyncio.sleep(0.01, loop=self.loop) # FIXME: this sleep should not be needed
self.assertFalse(os.path.isfile(path))
async def test_incomplete_downloads_retry(self):
tx = await self.stream_create('foo', '0.01')
tx = await self.stream_create('foo', '0.01', data=b'deadbeef' * 1000000)
sd_hash = tx['outputs'][0]['value']['source']['sd_hash']
await self.daemon.jsonrpc_file_delete(claim_name='foo')
blobs = await self.server_storage.get_blobs_for_stream(
await self.server_storage.get_stream_hash_for_sd_hash(sd_hash)
)
all_except_sd_and_head = [
blob.blob_hash for blob in blobs[1:] if blob.blob_hash
blob.blob_hash for blob in blobs[1:-1]
]
# back up the server blobs
@ -143,7 +142,7 @@ class FileCommands(CommandTestCase):
os.rename(missing_blob.file_path + '__', missing_blob.file_path)
self.server_blob_manager.blobs.clear()
missing_blob = self.server_blob_manager.get_blob(missing_blob_hash)
await self.server_blob_manager.blob_completed(missing_blob)
self.server_blob_manager.blob_completed(missing_blob)
await asyncio.wait_for(self.wait_files_to_complete(), timeout=1)
async def test_paid_download(self):
@ -176,9 +175,8 @@ class FileCommands(CommandTestCase):
)
await self.daemon.jsonrpc_file_delete(claim_name='icanpay')
await self.assertBalance(self.account, '9.925679')
response = await self.out(self.daemon.jsonrpc_get('lbry://icanpay'))
self.assertNotIn('error', response)
await self.on_transaction_dict(response['tx'])
response = await self.daemon.jsonrpc_get('lbry://icanpay')
await self.ledger.wait(response.content_fee)
await self.assertBalance(self.account, '8.925555')
self.assertEqual(len(self.daemon.jsonrpc_file_list()), 1)

View file

@ -0,0 +1,358 @@
import os
import hashlib
import aiohttp
import aiohttp.web
from lbrynet.utils import aiohttp_request
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
from lbrynet.testcase import CommandTestCase
def get_random_bytes(n: int) -> bytes:
result = b''.join(hashlib.sha256(os.urandom(4)).digest() for _ in range(n // 16))
if len(result) < n:
result += os.urandom(n - len(result))
elif len(result) > n:
result = result[:-(len(result) - n)]
assert len(result) == n, (n, len(result))
return result
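Editorial note on the helper above: each sha256 digest contributes 32 bytes, so n // 16 digests yield roughly 2 * n bytes, which are then trimmed (or topped up with os.urandom) to exactly n. A quick hedged sanity check of that invariant, usable alongside the helper:

for n in (2, 15, 16, 100, 4096):
    assert len(get_random_bytes(n)) == n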
class RangeRequests(CommandTestCase):
async def _restart_stream_manager(self):
self.daemon.stream_manager.stop()
await self.daemon.stream_manager.start()
return
async def _setup_stream(self, data: bytes, save_blobs: bool = True, save_files: bool = False):
self.daemon.conf.save_blobs = save_blobs
self.daemon.conf.save_files = save_files
self.data = data
await self.stream_create('foo', '0.01', data=self.data)
if save_blobs:
self.assertTrue(len(os.listdir(self.daemon.blob_manager.blob_dir)) > 1)
await self.daemon.jsonrpc_file_list()[0].fully_reflected.wait()
await self.daemon.jsonrpc_file_delete(delete_from_download_dir=True, claim_name='foo')
self.assertEqual(0, len(os.listdir(self.daemon.blob_manager.blob_dir)))
# await self._restart_stream_manager()
await self.daemon.runner.setup()
site = aiohttp.web.TCPSite(self.daemon.runner, self.daemon.conf.api_host, self.daemon.conf.api_port)
await site.start()
self.assertListEqual(self.daemon.jsonrpc_file_list(), [])
async def _test_range_requests(self):
name = 'foo'
url = f'http://{self.daemon.conf.api_host}:{self.daemon.conf.api_port}/get/{name}'
async with aiohttp_request('get', url) as req:
self.assertEqual(req.headers.get('Content-Type'), 'application/octet-stream')
content_range = req.headers.get('Content-Range')
content_length = int(req.headers.get('Content-Length'))
streamed_bytes = await req.content.read()
self.assertEqual(content_length, len(streamed_bytes))
return streamed_bytes, content_range, content_length
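The helper above issues a plain GET against the daemon's /get endpoint and reads the full body. Since this changeset adds HTTP 206 partial-content support, a client could presumably also request just a byte range; a minimal hedged sketch, assuming the daemon is running on the default api_host/api_port (5279) and that a claim named 'foo' resolves:

import aiohttp

async def fetch_first_100_bytes():
    # hedged sketch, not part of this test module: ask the daemon's /get endpoint
    # for only the first 100 bytes of the stream named 'foo'
    async with aiohttp.ClientSession() as session:
        async with session.get('http://localhost:5279/get/foo',
                               headers={'Range': 'bytes=0-99'}) as resp:
            # a server honoring the range should answer 206 and set Content-Range
            print(resp.status, resp.headers.get('Content-Range'))
            return await resp.read()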
async def test_range_requests_2_byte(self):
self.data = b'hi'
await self._setup_stream(self.data)
streamed, content_range, content_length = await self._test_range_requests()
self.assertEqual(15, content_length)
self.assertEqual(b'hi\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', streamed)
self.assertEqual('bytes 0-14/15', content_range)
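The expected values in this test are worth spelling out. The asserted Content-Length of 15 and the trailing null bytes suggest that the announced stream size is derived from the encrypted blob sizes (each blob decrypts to at most its length minus one byte, since PKCS7 guarantees at least one byte of padding) and that any shortfall is zero-filled. A hedged arithmetic sketch, not daemon code:

# a 2-byte cleartext is PKCS7-padded to a single 16-byte AES block
encrypted_size = 16
estimated_length = encrypted_size - 1          # 15: the Content-Length asserted above
zero_padding = estimated_length - len(b'hi')   # 13 trailing null bytes
assert (estimated_length, zero_padding) == (15, 13)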
async def test_range_requests_15_byte(self):
self.data = b'123456789abcdef'
await self._setup_stream(self.data)
streamed, content_range, content_length = await self._test_range_requests()
self.assertEqual(15, content_length)
self.assertEqual(15, len(streamed))
self.assertEqual(self.data, streamed)
self.assertEqual('bytes 0-14/15', content_range)
async def test_range_requests_0_padded_bytes(self, size: int = (MAX_BLOB_SIZE - 1) * 4,
expected_range: str = 'bytes 0-8388603/8388604', padding=b''):
self.data = get_random_bytes(size)
await self._setup_stream(self.data)
streamed, content_range, content_length = await self._test_range_requests()
self.assertEqual(len(self.data + padding), content_length)
self.assertEqual(streamed, self.data + padding)
self.assertEqual(expected_range, content_range)
async def test_range_requests_1_padded_bytes(self):
await self.test_range_requests_0_padded_bytes(
((MAX_BLOB_SIZE - 1) * 4) - 1, padding=b'\x00'
)
async def test_range_requests_2_padded_bytes(self):
await self.test_range_requests_0_padded_bytes(
((MAX_BLOB_SIZE - 1) * 4) - 2, padding=b'\x00' * 2
)
async def test_range_requests_14_padded_bytes(self):
await self.test_range_requests_0_padded_bytes(
((MAX_BLOB_SIZE - 1) * 4) - 14, padding=b'\x00' * 14
)
async def test_range_requests_15_padded_bytes(self):
await self.test_range_requests_0_padded_bytes(
((MAX_BLOB_SIZE - 1) * 4) - 15, padding=b'\x00' * 15
)
async def test_range_requests_last_block_of_last_blob_padding(self):
self.data = get_random_bytes(((MAX_BLOB_SIZE - 1) * 4) - 16)
await self._setup_stream(self.data)
streamed, content_range, content_length = await self._test_range_requests()
self.assertEqual(len(self.data), content_length)
self.assertEqual(streamed, self.data)
self.assertEqual('bytes 0-8388587/8388588', content_range)
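The magic numbers in the padded-bytes tests above follow the same pattern. A hedged sketch of the arithmetic, assuming MAX_BLOB_SIZE == 2 * 2 ** 20 as defined in lbrynet.blob.blob_file:

MAX_BLOB_SIZE = 2 * 2 ** 20
cleartext_per_blob = MAX_BLOB_SIZE - 1                    # 2097151 plaintext bytes per blob
encrypted_per_blob = (cleartext_per_blob // 16 + 1) * 16  # PKCS7 rounds up to 2097152
estimated_per_blob = encrypted_per_blob - 1               # at least one padding byte is assumed
assert 4 * estimated_per_blob == 8388604                  # hence 'bytes 0-8388603/8388604'
# Removing 1-15 bytes of data does not shrink the padded size of the last AES block,
# so the announced length stays 8388604 and null bytes make up the difference; removing
# a full 16 bytes shrinks it to 8388588, matching the Content-Range asserted just above.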
async def test_streaming_only_with_blobs(self):
self.data = get_random_bytes((MAX_BLOB_SIZE - 1) * 4)
await self._setup_stream(self.data)
await self._test_range_requests()
stream = self.daemon.jsonrpc_file_list()[0]
self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path))
self.assertIsNone(stream.download_directory)
self.assertIsNone(stream.full_path)
files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir)))
# test that repeated range requests do not create duplicate files
for _ in range(3):
await self._test_range_requests()
stream = self.daemon.jsonrpc_file_list()[0]
self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path))
self.assertIsNone(stream.download_directory)
self.assertIsNone(stream.full_path)
current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir)))
self.assertEqual(
len(files_in_download_dir), len(current_files_in_download_dir)
)
# test that a range request after restart does not create a duplicate file
await self._restart_stream_manager()
current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir)))
self.assertEqual(
len(files_in_download_dir), len(current_files_in_download_dir)
)
stream = self.daemon.jsonrpc_file_list()[0]
self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path))
self.assertIsNone(stream.download_directory)
self.assertIsNone(stream.full_path)
await self._test_range_requests()
stream = self.daemon.jsonrpc_file_list()[0]
self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path))
self.assertIsNone(stream.download_directory)
self.assertIsNone(stream.full_path)
current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir)))
self.assertEqual(
len(files_in_download_dir), len(current_files_in_download_dir)
)
async def test_streaming_only_without_blobs(self):
self.data = get_random_bytes((MAX_BLOB_SIZE - 1) * 4)
await self._setup_stream(self.data, save_blobs=False)
await self._test_range_requests()
stream = self.daemon.jsonrpc_file_list()[0]
self.assertIsNone(stream.download_directory)
self.assertIsNone(stream.full_path)
files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir)))
# test that repeated range requests do not create duplicate files
for _ in range(3):
await self._test_range_requests()
stream = self.daemon.jsonrpc_file_list()[0]
self.assertIsNone(stream.download_directory)
self.assertIsNone(stream.full_path)
current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir)))
self.assertEqual(
len(files_in_download_dir), len(current_files_in_download_dir)
)
# test that a range request after restart does not create a duplicate file
await self._restart_stream_manager()
current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir)))
self.assertEqual(
len(files_in_download_dir), len(current_files_in_download_dir)
)
stream = self.daemon.jsonrpc_file_list()[0]
self.assertIsNone(stream.download_directory)
self.assertIsNone(stream.full_path)
await self._test_range_requests()
stream = self.daemon.jsonrpc_file_list()[0]
self.assertIsNone(stream.download_directory)
self.assertIsNone(stream.full_path)
current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir)))
self.assertEqual(
len(files_in_download_dir), len(current_files_in_download_dir)
)
async def test_stream_and_save_file_with_blobs(self):
self.data = get_random_bytes((MAX_BLOB_SIZE - 1) * 4)
await self._setup_stream(self.data, save_files=True)
await self._test_range_requests()
streams = self.daemon.jsonrpc_file_list()
self.assertEqual(1, len(streams))
stream = streams[0]
self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path))
self.assertTrue(os.path.isdir(stream.download_directory))
self.assertTrue(os.path.isfile(stream.full_path))
full_path = stream.full_path
files_in_download_dir = list(os.scandir(os.path.dirname(full_path)))
for _ in range(3):
await self._test_range_requests()
streams = self.daemon.jsonrpc_file_list()
self.assertEqual(1, len(streams))
stream = streams[0]
self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path))
self.assertTrue(os.path.isdir(stream.download_directory))
self.assertTrue(os.path.isfile(stream.full_path))
current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path)))
self.assertEqual(
len(files_in_download_dir), len(current_files_in_download_dir)
)
await self._restart_stream_manager()
current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path)))
self.assertEqual(
len(files_in_download_dir), len(current_files_in_download_dir)
)
streams = self.daemon.jsonrpc_file_list()
self.assertEqual(1, len(streams))
stream = streams[0]
self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path))
self.assertTrue(os.path.isdir(stream.download_directory))
self.assertTrue(os.path.isfile(stream.full_path))
await self._test_range_requests()
streams = self.daemon.jsonrpc_file_list()
self.assertEqual(1, len(streams))
stream = streams[0]
self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path))
self.assertTrue(os.path.isdir(stream.download_directory))
self.assertTrue(os.path.isfile(stream.full_path))
current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path)))
self.assertEqual(
len(files_in_download_dir), len(current_files_in_download_dir)
)
with open(stream.full_path, 'rb') as f:
self.assertEqual(self.data, f.read())
async def test_stream_and_save_file_without_blobs(self):
self.data = get_random_bytes((MAX_BLOB_SIZE - 1) * 4)
await self._setup_stream(self.data, save_files=True)
self.daemon.conf.save_blobs = False
await self._test_range_requests()
stream = self.daemon.jsonrpc_file_list()[0]
self.assertTrue(os.path.isdir(stream.download_directory))
self.assertTrue(os.path.isfile(stream.full_path))
full_path = stream.full_path
files_in_download_dir = list(os.scandir(os.path.dirname(full_path)))
for _ in range(3):
await self._test_range_requests()
stream = self.daemon.jsonrpc_file_list()[0]
self.assertTrue(os.path.isdir(stream.download_directory))
self.assertTrue(os.path.isfile(stream.full_path))
current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path)))
self.assertEqual(
len(files_in_download_dir), len(current_files_in_download_dir)
)
await self._restart_stream_manager()
current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path)))
self.assertEqual(
len(files_in_download_dir), len(current_files_in_download_dir)
)
streams = self.daemon.jsonrpc_file_list()
self.assertEqual(1, len(streams))
stream = streams[0]
self.assertTrue(os.path.isdir(stream.download_directory))
self.assertTrue(os.path.isfile(stream.full_path))
await self._test_range_requests()
streams = self.daemon.jsonrpc_file_list()
self.assertEqual(1, len(streams))
stream = streams[0]
self.assertTrue(os.path.isdir(stream.download_directory))
self.assertTrue(os.path.isfile(stream.full_path))
current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path)))
self.assertEqual(
len(files_in_download_dir), len(current_files_in_download_dir)
)
with open(stream.full_path, 'rb') as f:
self.assertEqual(self.data, f.read())
async def test_switch_save_blobs_while_running(self):
await self.test_streaming_only_without_blobs()
self.daemon.conf.save_blobs = True
blobs_in_stream = self.daemon.jsonrpc_file_list()[0].blobs_in_stream
sd_hash = self.daemon.jsonrpc_file_list()[0].sd_hash
start_file_count = len(os.listdir(self.daemon.blob_manager.blob_dir))
await self._test_range_requests()
self.assertEqual(start_file_count + blobs_in_stream, len(os.listdir(self.daemon.blob_manager.blob_dir)))
self.assertEqual(0, self.daemon.jsonrpc_file_list()[0].blobs_remaining)
# switch back
self.daemon.conf.save_blobs = False
await self._test_range_requests()
self.assertEqual(start_file_count + blobs_in_stream, len(os.listdir(self.daemon.blob_manager.blob_dir)))
self.assertEqual(0, self.daemon.jsonrpc_file_list()[0].blobs_remaining)
await self.daemon.jsonrpc_file_delete(delete_from_download_dir=True, sd_hash=sd_hash)
self.assertEqual(start_file_count, len(os.listdir(self.daemon.blob_manager.blob_dir)))
await self._test_range_requests()
self.assertEqual(start_file_count, len(os.listdir(self.daemon.blob_manager.blob_dir)))
self.assertEqual(blobs_in_stream, self.daemon.jsonrpc_file_list()[0].blobs_remaining)
async def test_file_save_streaming_only_save_blobs(self):
await self.test_streaming_only_with_blobs()
stream = self.daemon.jsonrpc_file_list()[0]
self.assertIsNone(stream.full_path)
self.server.stop_server()
await self.daemon.jsonrpc_file_save('test', self.daemon.conf.data_dir)
stream = self.daemon.jsonrpc_file_list()[0]
self.assertIsNotNone(stream.full_path)
await stream.finished_writing.wait()
with open(stream.full_path, 'rb') as f:
self.assertEqual(self.data, f.read())
await self.daemon.jsonrpc_file_delete(delete_from_download_dir=True, sd_hash=stream.sd_hash)
async def test_file_save_stop_before_finished_streaming_only(self, wait_for_start_writing=False):
await self.test_streaming_only_with_blobs()
stream = self.daemon.jsonrpc_file_list()[0]
self.assertIsNone(stream.full_path)
self.server.stop_server()
await self.daemon.jsonrpc_file_save('test', self.daemon.conf.data_dir)
stream = self.daemon.jsonrpc_file_list()[0]
path = stream.full_path
self.assertIsNotNone(path)
if wait_for_start_writing:
await stream.started_writing.wait()
self.assertTrue(os.path.isfile(path))
await self._restart_stream_manager()
stream = self.daemon.jsonrpc_file_list()[0]
self.assertIsNone(stream.full_path)
self.assertFalse(os.path.isfile(path))
async def test_file_save_stop_before_finished_streaming_only_wait_for_start(self):
return await self.test_file_save_stop_before_finished_streaming_only(wait_for_start_writing=True)
async def test_file_save_streaming_only_dont_save_blobs(self):
await self.test_streaming_only_without_blobs()
stream = self.daemon.jsonrpc_file_list()[0]
self.assertIsNone(stream.full_path)
await self.daemon.jsonrpc_file_save('test', self.daemon.conf.data_dir)
stream = self.daemon.jsonrpc_file_list()[0]
await stream.finished_writing.wait()
with open(stream.full_path, 'rb') as f:
self.assertEqual(self.data, f.read())

View file

@ -3,34 +3,207 @@ import tempfile
import shutil
import os
from torba.testcase import AsyncioTestCase
from lbrynet.error import InvalidDataError, InvalidBlobHashError
from lbrynet.conf import Config
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob.blob_file import BlobFile, BlobBuffer, AbstractBlob
class TestBlobfile(AsyncioTestCase):
async def test_create_blob(self):
class TestBlob(AsyncioTestCase):
blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed"
blob_bytes = b'1' * ((2 * 2 ** 20) - 1)
loop = asyncio.get_event_loop()
async def asyncSetUp(self):
self.tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(self.tmp_dir))
self.loop = asyncio.get_running_loop()
self.config = Config()
self.storage = SQLiteStorage(self.config, ":memory:", self.loop)
self.blob_manager = BlobManager(self.loop, self.tmp_dir, self.storage, self.config)
await self.storage.open()
def _get_blob(self, blob_class=AbstractBlob, blob_directory=None):
blob = blob_class(self.loop, self.blob_hash, len(self.blob_bytes), self.blob_manager.blob_completed,
blob_directory=blob_directory)
self.assertFalse(blob.get_is_verified())
self.addCleanup(blob.close)
return blob
async def _test_create_blob(self, blob_class=AbstractBlob, blob_directory=None):
blob = self._get_blob(blob_class, blob_directory)
writer = blob.get_blob_writer()
writer.write(self.blob_bytes)
await blob.verified.wait()
self.assertTrue(blob.get_is_verified())
await asyncio.sleep(0, loop=self.loop) # wait for the db save task
return blob
async def _test_close_writers_on_finished(self, blob_class=AbstractBlob, blob_directory=None):
blob = self._get_blob(blob_class, blob_directory=blob_directory)
writers = [blob.get_blob_writer('1.2.3.4', port) for port in range(5)]
self.assertEqual(5, len(blob.writers))
# test that writing too much causes the writer to fail with InvalidDataError and to be removed
with self.assertRaises(InvalidDataError):
writers[1].write(self.blob_bytes * 2)
await writers[1].finished
await asyncio.sleep(0, loop=self.loop)
self.assertEqual(4, len(blob.writers))
# write the blob
other = writers[2]
writers[3].write(self.blob_bytes)
await blob.verified.wait()
self.assertTrue(blob.get_is_verified())
self.assertEqual(0, len(blob.writers))
with self.assertRaises(IOError):
other.write(self.blob_bytes)
def _test_ioerror_if_length_not_set(self, blob_class=AbstractBlob, blob_directory=None):
blob = blob_class(
self.loop, self.blob_hash, blob_completed_callback=self.blob_manager.blob_completed,
blob_directory=blob_directory
)
self.addCleanup(blob.close)
writer = blob.get_blob_writer()
with self.assertRaises(IOError):
writer.write(b'')
async def _test_invalid_blob_bytes(self, blob_class=AbstractBlob, blob_directory=None):
blob = blob_class(
self.loop, self.blob_hash, len(self.blob_bytes), blob_completed_callback=self.blob_manager.blob_completed,
blob_directory=blob_directory
)
self.addCleanup(blob.close)
writer = blob.get_blob_writer()
writer.write(self.blob_bytes[:-4] + b'fake')
with self.assertRaises(InvalidBlobHashError):
await writer.finished
async def test_add_blob_buffer_to_db(self):
blob = await self._test_create_blob(BlobBuffer)
db_status = await self.storage.get_blob_status(blob.blob_hash)
self.assertEqual(db_status, 'pending')
async def test_add_blob_file_to_db(self):
blob = await self._test_create_blob(BlobFile, self.tmp_dir)
db_status = await self.storage.get_blob_status(blob.blob_hash)
self.assertEqual(db_status, 'finished')
async def test_invalid_blob_bytes(self):
await self._test_invalid_blob_bytes(BlobBuffer)
await self._test_invalid_blob_bytes(BlobFile, self.tmp_dir)
def test_ioerror_if_length_not_set(self):
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
self._test_ioerror_if_length_not_set(BlobBuffer)
self._test_ioerror_if_length_not_set(BlobFile, tmp_dir)
async def test_create_blob_file(self):
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
blob = await self._test_create_blob(BlobFile, tmp_dir)
self.assertIsInstance(blob, BlobFile)
self.assertTrue(os.path.isfile(blob.file_path))
for _ in range(2):
with blob.reader_context() as reader:
self.assertEqual(self.blob_bytes, reader.read())
async def test_create_blob_buffer(self):
blob = await self._test_create_blob(BlobBuffer)
self.assertIsInstance(blob, BlobBuffer)
self.assertIsNotNone(blob._verified_bytes)
# check we can only read the bytes once, and that the buffer is torn down
with blob.reader_context() as reader:
self.assertEqual(self.blob_bytes, reader.read())
self.assertIsNone(blob._verified_bytes)
with self.assertRaises(OSError):
with blob.reader_context() as reader:
self.assertEqual(self.blob_bytes, reader.read())
self.assertIsNone(blob._verified_bytes)
async def test_close_writers_on_finished(self):
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
await self._test_close_writers_on_finished(BlobBuffer)
await self._test_close_writers_on_finished(BlobFile, tmp_dir)
async def test_delete(self):
blob_buffer = await self._test_create_blob(BlobBuffer)
self.assertIsInstance(blob_buffer, BlobBuffer)
self.assertIsNotNone(blob_buffer._verified_bytes)
self.assertTrue(blob_buffer.get_is_verified())
blob_buffer.delete()
self.assertIsNone(blob_buffer._verified_bytes)
self.assertFalse(blob_buffer.get_is_verified())
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
storage = SQLiteStorage(Config(), os.path.join(tmp_dir, "lbrynet.sqlite"))
blob_manager = BlobFileManager(loop, tmp_dir, storage)
blob_file = await self._test_create_blob(BlobFile, tmp_dir)
self.assertIsInstance(blob_file, BlobFile)
self.assertTrue(os.path.isfile(blob_file.file_path))
self.assertTrue(blob_file.get_is_verified())
blob_file.delete()
self.assertFalse(os.path.isfile(blob_file.file_path))
self.assertFalse(blob_file.get_is_verified())
await storage.open()
await blob_manager.setup()
async def test_delete_corrupt(self):
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
blob = BlobFile(
self.loop, self.blob_hash, len(self.blob_bytes), blob_completed_callback=self.blob_manager.blob_completed,
blob_directory=tmp_dir
)
writer = blob.get_blob_writer()
writer.write(self.blob_bytes)
await blob.verified.wait()
blob.close()
blob = BlobFile(
self.loop, self.blob_hash, len(self.blob_bytes), blob_completed_callback=self.blob_manager.blob_completed,
blob_directory=tmp_dir
)
self.assertTrue(blob.get_is_verified())
# add the blob on the server
blob = blob_manager.get_blob(blob_hash, len(blob_bytes))
self.assertEqual(blob.get_is_verified(), False)
self.assertNotIn(blob_hash, blob_manager.completed_blob_hashes)
with open(blob.file_path, 'wb+') as f:
f.write(b'\x00')
blob = BlobFile(
self.loop, self.blob_hash, len(self.blob_bytes), blob_completed_callback=self.blob_manager.blob_completed,
blob_directory=tmp_dir
)
self.assertFalse(blob.get_is_verified())
self.assertFalse(os.path.isfile(blob.file_path))
writer = blob.open_for_writing()
writer.write(blob_bytes)
await blob.finished_writing.wait()
self.assertTrue(os.path.isfile(blob.file_path))
self.assertEqual(blob.get_is_verified(), True)
self.assertIn(blob_hash, blob_manager.completed_blob_hashes)
def test_invalid_blob_hash(self):
self.assertRaises(InvalidBlobHashError, BlobBuffer, self.loop, '', len(self.blob_bytes))
self.assertRaises(InvalidBlobHashError, BlobBuffer, self.loop, 'x' * 96, len(self.blob_bytes))
self.assertRaises(InvalidBlobHashError, BlobBuffer, self.loop, 'a' * 97, len(self.blob_bytes))
async def _test_close_reader(self, blob_class=AbstractBlob, blob_directory=None):
blob = await self._test_create_blob(blob_class, blob_directory)
reader = blob.reader_context()
self.assertEqual(0, len(blob.readers))
async def read_blob_buffer():
with reader as read_handle:
self.assertEqual(1, len(blob.readers))
await asyncio.sleep(2, loop=self.loop)
self.assertEqual(0, len(blob.readers))
return read_handle.read()
self.loop.call_later(1, blob.close)
with self.assertRaises(ValueError) as err:
read_task = self.loop.create_task(read_blob_buffer())
await read_task
self.assertEqual(err.exception, ValueError("I/O operation on closed file"))
async def test_close_reader(self):
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
await self._test_close_reader(BlobBuffer)
await self._test_close_reader(BlobFile, tmp_dir)

View file

@ -5,50 +5,53 @@ import os
from torba.testcase import AsyncioTestCase
from lbrynet.conf import Config
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
class TestBlobManager(AsyncioTestCase):
async def test_sync_blob_manager_on_startup(self):
loop = asyncio.get_event_loop()
async def setup_blob_manager(self, save_blobs=True):
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
self.config = Config(save_blobs=save_blobs)
self.storage = SQLiteStorage(self.config, os.path.join(tmp_dir, "lbrynet.sqlite"))
self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.config)
await self.storage.open()
storage = SQLiteStorage(Config(), os.path.join(tmp_dir, "lbrynet.sqlite"))
blob_manager = BlobFileManager(loop, tmp_dir, storage)
async def test_sync_blob_file_manager_on_startup(self):
await self.setup_blob_manager(save_blobs=True)
# add a blob file
blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed"
blob_bytes = b'1' * ((2 * 2 ** 20) - 1)
with open(os.path.join(blob_manager.blob_dir, blob_hash), 'wb') as f:
with open(os.path.join(self.blob_manager.blob_dir, blob_hash), 'wb') as f:
f.write(blob_bytes)
# it should not have been added automatically on startup
await storage.open()
await blob_manager.setup()
self.assertSetEqual(blob_manager.completed_blob_hashes, set())
await self.blob_manager.setup()
self.assertSetEqual(self.blob_manager.completed_blob_hashes, set())
# make sure we can add the blob
await blob_manager.blob_completed(blob_manager.get_blob(blob_hash, len(blob_bytes)))
self.assertSetEqual(blob_manager.completed_blob_hashes, {blob_hash})
await self.blob_manager.blob_completed(self.blob_manager.get_blob(blob_hash, len(blob_bytes)))
self.assertSetEqual(self.blob_manager.completed_blob_hashes, {blob_hash})
# stop the blob manager and restart it, then make sure the blob is there
blob_manager.stop()
self.assertSetEqual(blob_manager.completed_blob_hashes, set())
await blob_manager.setup()
self.assertSetEqual(blob_manager.completed_blob_hashes, {blob_hash})
self.blob_manager.stop()
self.assertSetEqual(self.blob_manager.completed_blob_hashes, set())
await self.blob_manager.setup()
self.assertSetEqual(self.blob_manager.completed_blob_hashes, {blob_hash})
# test that the blob is removed on the next startup after the file has been manually deleted
blob_manager.stop()
self.blob_manager.stop()
# manually delete the blob file and restart the blob manager
os.remove(os.path.join(blob_manager.blob_dir, blob_hash))
await blob_manager.setup()
self.assertSetEqual(blob_manager.completed_blob_hashes, set())
os.remove(os.path.join(self.blob_manager.blob_dir, blob_hash))
await self.blob_manager.setup()
self.assertSetEqual(self.blob_manager.completed_blob_hashes, set())
# check that the deleted blob was updated in the database
self.assertEqual(
'pending', (
await storage.run_and_return_one_or_none('select status from blob where blob_hash=?', blob_hash)
await self.storage.run_and_return_one_or_none('select status from blob where blob_hash=?', blob_hash)
)
)

View file

@ -9,9 +9,9 @@ from lbrynet.blob_exchange.serialization import BlobRequest
from torba.testcase import AsyncioTestCase
from lbrynet.conf import Config
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob_exchange.server import BlobServer, BlobServerProtocol
from lbrynet.blob_exchange.client import BlobExchangeClientProtocol, request_blob
from lbrynet.blob_exchange.client import request_blob
from lbrynet.dht.peer import KademliaPeer, PeerManager
# import logging
@ -35,13 +35,13 @@ class BlobExchangeTestBase(AsyncioTestCase):
self.server_config = Config(data_dir=self.server_dir, download_dir=self.server_dir, wallet=self.server_dir,
reflector_servers=[])
self.server_storage = SQLiteStorage(self.server_config, os.path.join(self.server_dir, "lbrynet.sqlite"))
self.server_blob_manager = BlobFileManager(self.loop, self.server_dir, self.server_storage)
self.server_blob_manager = BlobManager(self.loop, self.server_dir, self.server_storage, self.server_config)
self.server = BlobServer(self.loop, self.server_blob_manager, 'bQEaw42GXsgCAGio1nxFncJSyRmnztSCjP')
self.client_config = Config(data_dir=self.client_dir, download_dir=self.client_dir, wallet=self.client_dir,
reflector_servers=[])
self.client_storage = SQLiteStorage(self.client_config, os.path.join(self.client_dir, "lbrynet.sqlite"))
self.client_blob_manager = BlobFileManager(self.loop, self.client_dir, self.client_storage)
self.client_blob_manager = BlobManager(self.loop, self.client_dir, self.client_storage, self.client_config)
self.client_peer_manager = PeerManager(self.loop)
self.server_from_client = KademliaPeer(self.loop, "127.0.0.1", b'1' * 48, tcp_port=33333)
@ -51,6 +51,7 @@ class BlobExchangeTestBase(AsyncioTestCase):
await self.server_blob_manager.setup()
self.server.start_server(33333, '127.0.0.1')
self.addCleanup(self.server.stop_server)
await self.server.started_listening.wait()
@ -58,21 +59,25 @@ class TestBlobExchange(BlobExchangeTestBase):
async def _add_blob_to_server(self, blob_hash: str, blob_bytes: bytes):
# add the blob on the server
server_blob = self.server_blob_manager.get_blob(blob_hash, len(blob_bytes))
writer = server_blob.open_for_writing()
writer = server_blob.get_blob_writer()
writer.write(blob_bytes)
await server_blob.finished_writing.wait()
await server_blob.verified.wait()
self.assertTrue(os.path.isfile(server_blob.file_path))
self.assertEqual(server_blob.get_is_verified(), True)
self.assertTrue(writer.closed())
async def _test_transfer_blob(self, blob_hash: str):
client_blob = self.client_blob_manager.get_blob(blob_hash)
# download the blob
downloaded = await request_blob(self.loop, client_blob, self.server_from_client.address,
downloaded, transport = await request_blob(self.loop, client_blob, self.server_from_client.address,
self.server_from_client.tcp_port, 2, 3)
await client_blob.finished_writing.wait()
self.assertIsNotNone(transport)
self.addCleanup(transport.close)
await client_blob.verified.wait()
self.assertEqual(client_blob.get_is_verified(), True)
self.assertTrue(downloaded)
client_blob.close()
async def test_transfer_sd_blob(self):
sd_hash = "3e2706157a59aaa47ef52bc264fce488078b4026c0b9bab649a8f2fe1ecc5e5cad7182a2bb7722460f856831a1ac0f02"
@ -92,9 +97,11 @@ class TestBlobExchange(BlobExchangeTestBase):
second_client_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, second_client_dir)
second_client_storage = SQLiteStorage(Config(), os.path.join(second_client_dir, "lbrynet.sqlite"))
second_client_blob_manager = BlobFileManager(self.loop, second_client_dir, second_client_storage)
second_client_conf = Config()
second_client_storage = SQLiteStorage(second_client_conf, os.path.join(second_client_dir, "lbrynet.sqlite"))
second_client_blob_manager = BlobManager(
self.loop, second_client_dir, second_client_storage, second_client_conf
)
server_from_second_client = KademliaPeer(self.loop, "127.0.0.1", b'1' * 48, tcp_port=33333)
await second_client_storage.open()
@ -102,7 +109,7 @@ class TestBlobExchange(BlobExchangeTestBase):
await self._add_blob_to_server(blob_hash, mock_blob_bytes)
second_client_blob = self.client_blob_manager.get_blob(blob_hash)
second_client_blob = second_client_blob_manager.get_blob(blob_hash)
# download the blob
await asyncio.gather(
@ -112,9 +119,65 @@ class TestBlobExchange(BlobExchangeTestBase):
),
self._test_transfer_blob(blob_hash)
)
await second_client_blob.finished_writing.wait()
await second_client_blob.verified.wait()
self.assertEqual(second_client_blob.get_is_verified(), True)
async def test_blob_writers_concurrency(self):
blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed"
mock_blob_bytes = b'1' * ((2 * 2 ** 20) - 1)
blob = self.server_blob_manager.get_blob(blob_hash)
write_blob = blob._write_blob
write_called_count = 0
def wrap_write_blob(blob_bytes):
nonlocal write_called_count
write_called_count += 1
write_blob(blob_bytes)
blob._write_blob = wrap_write_blob
writer1 = blob.get_blob_writer(peer_port=1)
writer2 = blob.get_blob_writer(peer_port=2)
reader1_ctx_before_write = blob.reader_context()
with self.assertRaises(OSError):
blob.get_blob_writer(peer_port=2)
with self.assertRaises(OSError):
with blob.reader_context():
pass
blob.set_length(len(mock_blob_bytes))
results = {}
def check_finished_callback(writer, num):
def inner(writer_future: asyncio.Future):
results[num] = writer_future.result()
writer.finished.add_done_callback(inner)
check_finished_callback(writer1, 1)
check_finished_callback(writer2, 2)
def write_task(writer):
async def _inner():
writer.write(mock_blob_bytes)
return self.loop.create_task(_inner())
await asyncio.gather(write_task(writer1), write_task(writer2), loop=self.loop)
self.assertDictEqual({1: mock_blob_bytes, 2: mock_blob_bytes}, results)
self.assertEqual(1, write_called_count)
self.assertTrue(blob.get_is_verified())
self.assertDictEqual({}, blob.writers)
with reader1_ctx_before_write as f:
self.assertEqual(mock_blob_bytes, f.read())
with blob.reader_context() as f:
self.assertEqual(mock_blob_bytes, f.read())
with blob.reader_context() as f:
blob.close()
with self.assertRaises(ValueError):
f.read()
self.assertListEqual([], blob.readers)
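The test above pins down the behaviour for concurrent writers of the same blob: both writer futures resolve with the blob bytes while the bytes are written only once. A minimal hedged sketch of that fan-out pattern in isolation (illustrative only, not the BlobFile implementation):

import asyncio

async def demo_single_write_fanout():
    loop = asyncio.get_running_loop()
    shared_result: asyncio.Future = loop.create_future()

    def submit(payload: bytes) -> asyncio.Future:
        # only the first submission "performs the write"; later ones reuse the result
        if not shared_result.done():
            shared_result.set_result(payload)
        return shared_result

    futures = [submit(b'data') for _ in range(3)]
    assert [await f for f in futures] == [b'data'] * 3

# asyncio.run(demo_single_write_fanout())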
async def test_host_different_blobs_to_multiple_peers_at_once(self):
blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed"
mock_blob_bytes = b'1' * ((2 * 2 ** 20) - 1)
@ -124,9 +187,12 @@ class TestBlobExchange(BlobExchangeTestBase):
second_client_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, second_client_dir)
second_client_conf = Config()
second_client_storage = SQLiteStorage(Config(), os.path.join(second_client_dir, "lbrynet.sqlite"))
second_client_blob_manager = BlobFileManager(self.loop, second_client_dir, second_client_storage)
second_client_storage = SQLiteStorage(second_client_conf, os.path.join(second_client_dir, "lbrynet.sqlite"))
second_client_blob_manager = BlobManager(
self.loop, second_client_dir, second_client_storage, second_client_conf
)
server_from_second_client = KademliaPeer(self.loop, "127.0.0.1", b'1' * 48, tcp_port=33333)
await second_client_storage.open()
@ -143,7 +209,7 @@ class TestBlobExchange(BlobExchangeTestBase):
server_from_second_client.tcp_port, 2, 3
),
self._test_transfer_blob(sd_hash),
second_client_blob.finished_writing.wait()
second_client_blob.verified.wait()
)
self.assertEqual(second_client_blob.get_is_verified(), True)

View file

@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
from lbrynet import utils
import unittest
import asyncio
from lbrynet import utils
from torba.testcase import AsyncioTestCase
class CompareVersionTest(unittest.TestCase):
@ -61,3 +62,76 @@ class SdHashTests(unittest.TestCase):
}
}
self.assertIsNone(utils.get_sd_hash(claim))
class CacheConcurrentDecoratorTests(AsyncioTestCase):
def setUp(self):
self.called = []
self.finished = []
self.counter = 0
@utils.cache_concurrent
async def foo(self, arg1, arg2=None, delay=1):
self.called.append((arg1, arg2, delay))
await asyncio.sleep(delay, loop=self.loop)
self.counter += 1
self.finished.append((arg1, arg2, delay))
return object()
async def test_gather_duplicates(self):
result = await asyncio.gather(
self.loop.create_task(self.foo(1)), self.loop.create_task(self.foo(1)), loop=self.loop
)
self.assertEqual(1, len(self.called))
self.assertEqual(1, len(self.finished))
self.assertEqual(1, self.counter)
self.assertIs(result[0], result[1])
self.assertEqual(2, len(result))
async def test_one_cancelled_all_cancel(self):
t1 = self.loop.create_task(self.foo(1))
self.loop.call_later(0.1, t1.cancel)
with self.assertRaises(asyncio.CancelledError):
await asyncio.gather(
t1, self.loop.create_task(self.foo(1)), loop=self.loop
)
self.assertEqual(1, len(self.called))
self.assertEqual(0, len(self.finished))
self.assertEqual(0, self.counter)
async def test_error_after_success(self):
def cause_type_error():
self.counter = ""
self.loop.call_later(0.1, cause_type_error)
t1 = self.loop.create_task(self.foo(1))
t2 = self.loop.create_task(self.foo(1))
with self.assertRaises(TypeError):
await t2
self.assertEqual(1, len(self.called))
self.assertEqual(0, len(self.finished))
self.assertTrue(t1.done())
self.assertEqual("", self.counter)
# test that the task runs fresh; it should not error
self.counter = 0
t3 = self.loop.create_task(self.foo(1))
self.assertTrue((await t3))
self.assertEqual(1, self.counter)
# the previously failed call should still raise if awaited
with self.assertRaises(TypeError):
await t1
self.assertEqual(1, self.counter)
async def test_break_it(self):
t1 = self.loop.create_task(self.foo(1))
t2 = self.loop.create_task(self.foo(1))
t3 = self.loop.create_task(self.foo(2, delay=0))
t3.add_done_callback(lambda _: t2.cancel())
with self.assertRaises(asyncio.CancelledError):
await asyncio.gather(t1, t2, t3)
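Taken together, these tests describe the contract of utils.cache_concurrent: concurrent calls with identical arguments share a single in-flight execution and a single result, and the cache entry is dropped once that execution completes so later calls run fresh. A rough, hedged sketch of the deduplication part of that contract (it does not reproduce the cancellation-propagation behaviour tested above):

import asyncio
import functools

def cache_concurrent_sketch(coro_fn):
    # illustrative only: identical concurrent calls share one in-flight task; the
    # entry is dropped as soon as that task completes, so later calls run fresh
    in_flight = {}

    @functools.wraps(coro_fn)
    async def wrapper(*args, **kwargs):
        key = (args, tuple(sorted(kwargs.items())))
        if key not in in_flight:
            task = asyncio.ensure_future(coro_fn(*args, **kwargs))
            task.add_done_callback(lambda _: in_flight.pop(key, None))
            in_flight[key] = task
        return await in_flight[key]

    return wrapper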

View file

@ -7,7 +7,7 @@ from torba.testcase import AsyncioTestCase
from lbrynet.conf import Config
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_info import BlobInfo
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.stream.descriptor import StreamDescriptor
from tests.test_utils import random_lbry_hash
@ -68,17 +68,18 @@ fake_claim_info = {
class StorageTest(AsyncioTestCase):
async def asyncSetUp(self):
self.storage = SQLiteStorage(Config(), ':memory:')
self.conf = Config()
self.storage = SQLiteStorage(self.conf, ':memory:')
self.blob_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, self.blob_dir)
self.blob_manager = BlobFileManager(asyncio.get_event_loop(), self.blob_dir, self.storage)
self.blob_manager = BlobManager(asyncio.get_event_loop(), self.blob_dir, self.storage, self.conf)
await self.storage.open()
async def asyncTearDown(self):
await self.storage.close()
async def store_fake_blob(self, blob_hash, length=100):
await self.storage.add_completed_blob(blob_hash, length)
await self.storage.add_blobs((blob_hash, length), finished=True)
async def store_fake_stream(self, stream_hash, blobs=None, file_name="fake_file", key="DEADBEEF"):
blobs = blobs or [BlobInfo(1, 100, "DEADBEEF", random_lbry_hash())]

View file

@ -1,12 +1,13 @@
import asyncio
from torba.testcase import AsyncioTestCase
from unittest import mock, TestCase
from lbrynet.dht.protocol.data_store import DictDataStore
from lbrynet.dht.peer import PeerManager
class DataStoreTests(AsyncioTestCase):
class DataStoreTests(TestCase):
def setUp(self):
self.loop = asyncio.get_event_loop()
self.loop = mock.Mock(spec=asyncio.BaseEventLoop)
self.loop.time = lambda: 0.0
self.peer_manager = PeerManager(self.loop)
self.data_store = DictDataStore(self.loop, self.peer_manager)

View file

@ -78,8 +78,7 @@ class TestBlobAnnouncer(AsyncioTestCase):
blob2 = binascii.hexlify(b'2' * 48).decode()
async with self._test_network_context():
await self.storage.add_completed_blob(blob1, 1024)
await self.storage.add_completed_blob(blob2, 1024)
await self.storage.add_blobs((blob1, 1024), (blob2, 1024), finished=True)
await self.storage.db.execute(
"update blob set next_announce_time=0, should_announce=1 where blob_hash in (?, ?)",
(blob1, blob2)

View file

@ -1,117 +0,0 @@
import os
import asyncio
import tempfile
import shutil
from torba.testcase import AsyncioTestCase
from lbrynet.conf import Config
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.stream.assembler import StreamAssembler
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.stream_manager import StreamManager
class TestStreamAssembler(AsyncioTestCase):
def setUp(self):
self.loop = asyncio.get_event_loop()
self.key = b'deadbeef' * 4
self.cleartext = b'test'
async def test_create_and_decrypt_one_blob_stream(self, corrupt=False):
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
self.storage = SQLiteStorage(Config(), ":memory:")
await self.storage.open()
self.blob_manager = BlobFileManager(self.loop, tmp_dir, self.storage)
download_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(download_dir))
# create the stream
file_path = os.path.join(tmp_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(self.cleartext)
sd = await StreamDescriptor.create_stream(self.loop, tmp_dir, file_path, key=self.key)
# copy blob files
sd_hash = sd.calculate_sd_hash()
shutil.copy(os.path.join(tmp_dir, sd_hash), os.path.join(download_dir, sd_hash))
for blob_info in sd.blobs:
if blob_info.blob_hash:
shutil.copy(os.path.join(tmp_dir, blob_info.blob_hash), os.path.join(download_dir, blob_info.blob_hash))
if corrupt and blob_info.length == MAX_BLOB_SIZE:
with open(os.path.join(download_dir, blob_info.blob_hash), "rb+") as handle:
handle.truncate()
handle.flush()
downloader_storage = SQLiteStorage(Config(), os.path.join(download_dir, "lbrynet.sqlite"))
await downloader_storage.open()
# add the blobs to the blob table (this would happen upon a blob download finishing)
downloader_blob_manager = BlobFileManager(self.loop, download_dir, downloader_storage)
descriptor = await downloader_blob_manager.get_stream_descriptor(sd_hash)
# assemble the decrypted file
assembler = StreamAssembler(self.loop, downloader_blob_manager, descriptor.sd_hash)
await assembler.assemble_decrypted_stream(download_dir)
if corrupt:
return self.assertFalse(os.path.isfile(os.path.join(download_dir, "test_file")))
with open(os.path.join(download_dir, "test_file"), "rb") as f:
decrypted = f.read()
self.assertEqual(decrypted, self.cleartext)
self.assertEqual(True, self.blob_manager.get_blob(sd_hash).get_is_verified())
self.assertEqual(True, self.blob_manager.get_blob(descriptor.blobs[0].blob_hash).get_is_verified())
# it's all blobs + the sd blob - the last blob, which comes out to the same count as descriptor.blobs
self.assertEqual(len(descriptor.blobs), len(await downloader_storage.get_all_finished_blobs()))
self.assertEqual(
[descriptor.sd_hash, descriptor.blobs[0].blob_hash], await downloader_storage.get_blobs_to_announce()
)
await downloader_storage.close()
await self.storage.close()
async def test_create_and_decrypt_multi_blob_stream(self):
self.cleartext = b'test\n' * 20000000
await self.test_create_and_decrypt_one_blob_stream()
async def test_create_and_decrypt_padding(self):
for i in range(16):
self.cleartext = os.urandom((MAX_BLOB_SIZE*2) + i)
await self.test_create_and_decrypt_one_blob_stream()
for i in range(16):
self.cleartext = os.urandom((MAX_BLOB_SIZE*2) - i)
await self.test_create_and_decrypt_one_blob_stream()
async def test_create_and_decrypt_random(self):
self.cleartext = os.urandom(20000000)
await self.test_create_and_decrypt_one_blob_stream()
async def test_create_managed_stream_announces(self):
# setup a blob manager
storage = SQLiteStorage(Config(), ":memory:")
await storage.open()
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
blob_manager = BlobFileManager(self.loop, tmp_dir, storage)
stream_manager = StreamManager(self.loop, Config(), blob_manager, None, storage, None)
# create the stream
download_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(download_dir))
file_path = os.path.join(download_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(b'testtest')
stream = await stream_manager.create_stream(file_path)
self.assertEqual(
[stream.sd_hash, stream.descriptor.blobs[0].blob_hash],
await storage.get_blobs_to_announce())
async def test_create_truncate_and_handle_stream(self):
self.cleartext = b'potato' * 1337 * 5279
# The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated
await asyncio.wait_for(self.test_create_and_decrypt_one_blob_stream(corrupt=True), timeout=5)

View file

@ -1,102 +0,0 @@
import os
import time
import unittest
from unittest import mock
import asyncio
from lbrynet.blob_exchange.serialization import BlobResponse
from lbrynet.blob_exchange.server import BlobServerProtocol
from lbrynet.conf import Config
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.downloader import StreamDownloader
from lbrynet.dht.node import Node
from lbrynet.dht.peer import KademliaPeer
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase
class TestStreamDownloader(BlobExchangeTestBase):
async def setup_stream(self, blob_count: int = 10):
self.stream_bytes = b''
for _ in range(blob_count):
self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1))
# create the stream
file_path = os.path.join(self.server_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(self.stream_bytes)
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
self.sd_hash = descriptor.calculate_sd_hash()
conf = Config(data_dir=self.server_dir, wallet_dir=self.server_dir, download_dir=self.server_dir,
reflector_servers=[])
self.downloader = StreamDownloader(self.loop, conf, self.client_blob_manager, self.sd_hash)
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
await self.setup_stream(blob_count)
mock_node = mock.Mock(spec=Node)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
self.downloader.download(mock_node)
await self.downloader.stream_finished_event.wait()
self.assertTrue(self.downloader.stream_handle.closed)
self.assertTrue(os.path.isfile(self.downloader.output_path))
self.downloader.stop()
self.assertIs(self.downloader.stream_handle, None)
self.assertTrue(os.path.isfile(self.downloader.output_path))
with open(self.downloader.output_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
await asyncio.sleep(0.01)
async def test_transfer_stream(self):
await self._test_transfer_stream(10)
@unittest.SkipTest
async def test_transfer_hundred_blob_stream(self):
await self._test_transfer_stream(100)
async def test_transfer_stream_bad_first_peer_good_second(self):
await self.setup_stream(2)
mock_node = mock.Mock(spec=Node)
q = asyncio.Queue()
bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([bad_peer])
self.loop.call_later(1, q2.put_nowait, [self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = _mock_accumulate_peers
self.downloader.download(mock_node)
await self.downloader.stream_finished_event.wait()
self.assertTrue(os.path.isfile(self.downloader.output_path))
with open(self.downloader.output_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
# self.assertIs(self.server_from_client.tcp_last_down, None)
# self.assertIsNot(bad_peer.tcp_last_down, None)
async def test_client_chunked_response(self):
self.server.stop_server()
class ChunkedServerProtocol(BlobServerProtocol):
def send_response(self, responses):
to_send = []
while responses:
to_send.append(responses.pop())
for byte in BlobResponse(to_send).serialize():
self.transport.write(bytes([byte]))
self.server.server_protocol_class = ChunkedServerProtocol
self.server.start_server(33333, '127.0.0.1')
self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes))
await asyncio.wait_for(self._test_transfer_stream(10), timeout=2)
self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes))

View file

@ -0,0 +1,176 @@
import os
import shutil
import unittest
from unittest import mock
import asyncio
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
from lbrynet.blob_exchange.serialization import BlobResponse
from lbrynet.blob_exchange.server import BlobServerProtocol
from lbrynet.dht.node import Node
from lbrynet.dht.peer import KademliaPeer
from lbrynet.extras.daemon.storage import StoredStreamClaim
from lbrynet.stream.managed_stream import ManagedStream
from lbrynet.stream.descriptor import StreamDescriptor
from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase
def get_mock_node(loop):
mock_node = mock.Mock(spec=Node)
mock_node.joined = asyncio.Event(loop=loop)
mock_node.joined.set()
return mock_node
class TestManagedStream(BlobExchangeTestBase):
async def create_stream(self, blob_count: int = 10):
self.stream_bytes = b''
for _ in range(blob_count):
self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1))
# create the stream
file_path = os.path.join(self.server_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(self.stream_bytes)
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
self.sd_hash = descriptor.calculate_sd_hash()
return descriptor
async def setup_stream(self, blob_count: int = 10):
await self.create_stream(blob_count)
self.stream = ManagedStream(
self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir
)
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
await self.setup_stream(blob_count)
mock_node = mock.Mock(spec=Node)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
await self.stream.setup(mock_node, save_file=True)
await self.stream.finished_writing.wait()
self.assertTrue(os.path.isfile(self.stream.full_path))
self.stream.stop_download()
self.assertTrue(os.path.isfile(self.stream.full_path))
with open(self.stream.full_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
await asyncio.sleep(0.01)
async def test_transfer_stream(self):
await self._test_transfer_stream(10)
@unittest.SkipTest
async def test_transfer_hundred_blob_stream(self):
await self._test_transfer_stream(100)
async def test_transfer_stream_bad_first_peer_good_second(self):
await self.setup_stream(2)
mock_node = mock.Mock(spec=Node)
q = asyncio.Queue()
bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([bad_peer])
self.loop.call_later(1, q2.put_nowait, [self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = _mock_accumulate_peers
await self.stream.setup(mock_node, save_file=True)
await self.stream.finished_writing.wait()
self.assertTrue(os.path.isfile(self.stream.full_path))
with open(self.stream.full_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
# self.assertIs(self.server_from_client.tcp_last_down, None)
# self.assertIsNot(bad_peer.tcp_last_down, None)
async def test_client_chunked_response(self):
self.server.stop_server()
class ChunkedServerProtocol(BlobServerProtocol):
def send_response(self, responses):
to_send = []
while responses:
to_send.append(responses.pop())
for byte in BlobResponse(to_send).serialize():
self.transport.write(bytes([byte]))
self.server.server_protocol_class = ChunkedServerProtocol
self.server.start_server(33333, '127.0.0.1')
self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes))
await asyncio.wait_for(self._test_transfer_stream(10), timeout=2)
self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes))
async def test_create_and_decrypt_one_blob_stream(self, blobs=1, corrupt=False):
descriptor = await self.create_stream(blobs)
# copy blob files
shutil.copy(os.path.join(self.server_blob_manager.blob_dir, self.sd_hash),
os.path.join(self.client_blob_manager.blob_dir, self.sd_hash))
self.stream = ManagedStream(self.loop, self.client_config, self.client_blob_manager, self.sd_hash,
self.client_dir)
for blob_info in descriptor.blobs[:-1]:
shutil.copy(os.path.join(self.server_blob_manager.blob_dir, blob_info.blob_hash),
os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash))
if corrupt and blob_info.length == MAX_BLOB_SIZE:
with open(os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash), "rb+") as handle:
handle.truncate()
handle.flush()
await self.stream.setup()
await self.stream.finished_writing.wait()
if corrupt:
return self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
with open(os.path.join(self.client_dir, "test_file"), "rb") as f:
decrypted = f.read()
self.assertEqual(decrypted, self.stream_bytes)
self.assertEqual(True, self.client_blob_manager.get_blob(self.sd_hash).get_is_verified())
self.assertEqual(
True, self.client_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified()
)
#
# # it's all blobs + the sd blob - the last blob, which comes out to the same count as descriptor.blobs
# self.assertEqual(len(descriptor.blobs), len(await downloader_storage.get_all_finished_blobs()))
# self.assertEqual(
# [descriptor.sd_hash, descriptor.blobs[0].blob_hash], await downloader_storage.get_blobs_to_announce()
# )
#
# await downloader_storage.close()
# await self.storage.close()
async def test_create_and_decrypt_multi_blob_stream(self):
await self.test_create_and_decrypt_one_blob_stream(10)
# async def test_create_managed_stream_announces(self):
# # setup a blob manager
# storage = SQLiteStorage(Config(), ":memory:")
# await storage.open()
# tmp_dir = tempfile.mkdtemp()
# self.addCleanup(lambda: shutil.rmtree(tmp_dir))
# blob_manager = BlobManager(self.loop, tmp_dir, storage)
# stream_manager = StreamManager(self.loop, Config(), blob_manager, None, storage, None)
# # create the stream
# download_dir = tempfile.mkdtemp()
# self.addCleanup(lambda: shutil.rmtree(download_dir))
# file_path = os.path.join(download_dir, "test_file")
# with open(file_path, 'wb') as f:
# f.write(b'testtest')
#
# stream = await stream_manager.create_stream(file_path)
# self.assertEqual(
# [stream.sd_hash, stream.descriptor.blobs[0].blob_hash],
# await storage.get_blobs_to_announce())
# async def test_create_truncate_and_handle_stream(self):
# # The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated
# await asyncio.wait_for(self.test_create_and_decrypt_one_blob_stream(corrupt=True), timeout=5)

View file

@ -5,7 +5,7 @@ import shutil
from torba.testcase import AsyncioTestCase
from lbrynet.conf import Config
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.stream.stream_manager import StreamManager
from lbrynet.stream.reflector.server import ReflectorServer
@ -18,16 +18,18 @@ class TestStreamAssembler(AsyncioTestCase):
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
self.storage = SQLiteStorage(Config(), os.path.join(tmp_dir, "lbrynet.sqlite"))
self.conf = Config()
self.storage = SQLiteStorage(self.conf, os.path.join(tmp_dir, "lbrynet.sqlite"))
await self.storage.open()
self.blob_manager = BlobFileManager(self.loop, tmp_dir, self.storage)
self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.conf)
self.stream_manager = StreamManager(self.loop, Config(), self.blob_manager, None, self.storage, None)
server_tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(server_tmp_dir))
self.server_storage = SQLiteStorage(Config(), os.path.join(server_tmp_dir, "lbrynet.sqlite"))
self.server_conf = Config()
self.server_storage = SQLiteStorage(self.server_conf, os.path.join(server_tmp_dir, "lbrynet.sqlite"))
await self.server_storage.open()
self.server_blob_manager = BlobFileManager(self.loop, server_tmp_dir, self.server_storage)
self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_conf)
download_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(download_dir))

View file

@@ -9,7 +9,7 @@ from torba.testcase import AsyncioTestCase
from lbrynet.conf import Config
from lbrynet.error import InvalidStreamDescriptorError
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.stream.descriptor import StreamDescriptor
@@ -20,9 +20,10 @@ class TestStreamDescriptor(AsyncioTestCase):
self.cleartext = os.urandom(20000000)
self.tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(self.tmp_dir))
self.storage = SQLiteStorage(Config(), ":memory:")
self.conf = Config()
self.storage = SQLiteStorage(self.conf, ":memory:")
await self.storage.open()
self.blob_manager = BlobFileManager(self.loop, self.tmp_dir, self.storage)
self.blob_manager = BlobManager(self.loop, self.tmp_dir, self.storage, self.conf)
self.file_path = os.path.join(self.tmp_dir, "test_file")
with open(self.file_path, 'wb') as f:
@@ -83,9 +84,10 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase):
loop = asyncio.get_event_loop()
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
storage = SQLiteStorage(Config(), ":memory:")
self.conf = Config()
storage = SQLiteStorage(self.conf, ":memory:")
await storage.open()
blob_manager = BlobFileManager(loop, tmp_dir, storage)
blob_manager = BlobManager(loop, tmp_dir, storage, self.conf)
sd_bytes = b'{"stream_name": "4f62616d6120446f6e6b65792d322e73746c", "blobs": [{"length": 1153488, "blob_num' \
b'": 0, "blob_hash": "9fa32a249ce3f2d4e46b78599800f368b72f2a7f22b81df443c7f6bdbef496bd61b4c0079c7' \
@@ -99,7 +101,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase):
blob = blob_manager.get_blob(sd_hash)
blob.set_length(len(sd_bytes))
writer = blob.open_for_writing()
writer = blob.get_blob_writer()
writer.write(sd_bytes)
await blob.verified.wait()
descriptor = await StreamDescriptor.from_stream_descriptor_blob(
@@ -116,7 +118,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase):
sd_hash = '9313d1807551186126acc3662e74d9de29cede78d4f133349ace846273ef116b9bb86be86c54509eb84840e4b032f6b2'
with open(os.path.join(tmp_dir, sd_hash), 'wb') as handle:
handle.write(b'doesnt work')
blob = BlobFile(loop, tmp_dir, sd_hash)
blob = BlobFile(loop, sd_hash, blob_directory=tmp_dir)
self.assertTrue(blob.file_exists)
self.assertIsNotNone(blob.length)
with self.assertRaises(InvalidStreamDescriptorError):

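Taken together, the hunks in this file move the descriptor tests to the renamed BlobManager and its new writer API. A minimal sketch of that write/read pattern, assuming only the call signatures visible above (loop, tmp_dir, storage and conf mirror the test setup; editor's illustration, not part of this diff):

    # writing a stream descriptor blob with the renamed manager
    blob_manager = BlobManager(loop, tmp_dir, storage, conf)
    blob = blob_manager.get_blob(sd_hash)
    blob.set_length(len(sd_bytes))
    writer = blob.get_blob_writer()   # replaces the old open_for_writing()
    writer.write(sd_bytes)
    await blob.verified.wait()        # set once the written bytes verify against sd_hash

    # reading an existing blob file now takes the directory as a keyword argument
    blob = BlobFile(loop, sd_hash, blob_directory=tmp_dir)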
View file

@@ -192,8 +192,8 @@ class TestStreamManager(BlobExchangeTestBase):
def check_post(event):
self.assertEqual(event['event'], 'Time To First Bytes')
self.assertEqual(event['properties']['error'], 'DownloadSDTimeout')
self.assertEqual(event['properties']['tried_peers_count'], None)
self.assertEqual(event['properties']['active_peer_count'], None)
self.assertEqual(event['properties']['tried_peers_count'], 0)
self.assertEqual(event['properties']['active_peer_count'], 0)
self.assertEqual(event['properties']['use_fixed_peers'], False)
self.assertEqual(event['properties']['added_fixed_peers'], False)
self.assertEqual(event['properties']['fixed_peer_delay'], None)
@@ -213,10 +213,10 @@ class TestStreamManager(BlobExchangeTestBase):
self.stream_manager.analytics_manager._post = check_post
self.assertSetEqual(self.stream_manager.streams, set())
self.assertDictEqual(self.stream_manager.streams, {})
stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
stream_hash = stream.stream_hash
self.assertSetEqual(self.stream_manager.streams, {stream})
self.assertDictEqual(self.stream_manager.streams, {stream.sd_hash: stream})
self.assertTrue(stream.running)
self.assertFalse(stream.finished)
self.assertTrue(os.path.isfile(os.path.join(self.client_dir, "test_file")))
@@ -236,7 +236,7 @@ class TestStreamManager(BlobExchangeTestBase):
self.assertEqual(stored_status, "stopped")
await self.stream_manager.start_stream(stream)
await stream.downloader.stream_finished_event.wait()
await stream.finished_writing.wait()
await asyncio.sleep(0, loop=self.loop)
self.assertTrue(stream.finished)
self.assertFalse(stream.running)
@@ -247,7 +247,7 @@ class TestStreamManager(BlobExchangeTestBase):
self.assertEqual(stored_status, "finished")
await self.stream_manager.delete_stream(stream, True)
self.assertSetEqual(self.stream_manager.streams, set())
self.assertDictEqual(self.stream_manager.streams, {})
self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
stored_status = await self.client_storage.run_and_return_one_or_none(
"select status from file where stream_hash=?", stream_hash
@@ -257,7 +257,7 @@ class TestStreamManager(BlobExchangeTestBase):
async def _test_download_error_on_start(self, expected_error, timeout=None):
with self.assertRaises(expected_error):
await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout=timeout)
await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout)
async def _test_download_error_analytics_on_start(self, expected_error, timeout=None):
received = []
@@ -321,9 +321,9 @@ class TestStreamManager(BlobExchangeTestBase):
await self.setup_stream_manager(old_sort=old_sort)
self.stream_manager.analytics_manager._post = check_post
self.assertSetEqual(self.stream_manager.streams, set())
self.assertDictEqual(self.stream_manager.streams, {})
stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
await stream.downloader.stream_finished_event.wait()
await stream.finished_writing.wait()
await asyncio.sleep(0, loop=self.loop)
self.stream_manager.stop()
self.client_blob_manager.stop()
@@ -333,8 +333,11 @@ class TestStreamManager(BlobExchangeTestBase):
await self.client_blob_manager.setup()
await self.stream_manager.start()
self.assertEqual(1, len(self.stream_manager.streams))
self.assertEqual(stream.sd_hash, list(self.stream_manager.streams)[0].sd_hash)
self.assertEqual('stopped', list(self.stream_manager.streams)[0].status)
self.assertListEqual([self.sd_hash], list(self.stream_manager.streams.keys()))
for blob_hash in [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]]:
blob_status = await self.client_storage.get_blob_status(blob_hash)
self.assertEqual('pending', blob_status)
self.assertEqual('stopped', self.stream_manager.streams[self.sd_hash].status)
sd_blob = self.client_blob_manager.get_blob(stream.sd_hash)
self.assertTrue(sd_blob.file_exists)
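The remaining hunks track stream_manager.streams changing from a set of ManagedStream objects to a dict keyed by sd_hash. A short sketch of the direct lookup this enables (editor's illustration under that assumption; not part of this diff):

    # streams is now a dict keyed by sd_hash, so no scan over a set is needed
    stream = stream_manager.streams.get(sd_hash)
    if stream is not None and stream.status == 'stopped':
        await stream_manager.start_stream(stream)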