From 3a916a8e8e7acbba0d4e957b906abf0e6f494c2a Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Sun, 31 Mar 2019 13:42:27 -0400 Subject: [PATCH] tests --- lbrynet/error.py | 6 + lbrynet/extras/daemon/Daemon.py | 2 +- lbrynet/extras/daemon/storage.py | 11 + lbrynet/stream/downloader.py | 25 +- lbrynet/stream/managed_stream.py | 35 ++- lbrynet/stream/stream_manager.py | 290 +++++++----------- tests/unit/blob/test_blob_file.py | 4 +- .../unit/blob_exchange/test_transfer_blob.py | 19 +- tests/unit/dht/protocol/test_data_store.py | 7 +- tests/unit/stream/test_assembler.py | 117 ------- tests/unit/stream/test_downloader.py | 102 ------ tests/unit/stream/test_managed_stream.py | 175 +++++++++++ tests/unit/stream/test_stream_descriptor.py | 4 +- tests/unit/stream/test_stream_manager.py | 25 +- 14 files changed, 381 insertions(+), 441 deletions(-) delete mode 100644 tests/unit/stream/test_assembler.py delete mode 100644 tests/unit/stream/test_downloader.py create mode 100644 tests/unit/stream/test_managed_stream.py diff --git a/lbrynet/error.py b/lbrynet/error.py index 234bb9229..a57252da0 100644 --- a/lbrynet/error.py +++ b/lbrynet/error.py @@ -32,6 +32,12 @@ class DownloadDataTimeout(Exception): self.download = download +class ResolveTimeout(Exception): + def __init__(self, uri): + super().__init__(f'Failed to resolve "{uri}" within the timeout') + self.uri = uri + + class RequestCanceledError(Exception): pass diff --git a/lbrynet/extras/daemon/Daemon.py b/lbrynet/extras/daemon/Daemon.py index 7af1ad21f..35492ecf9 100644 --- a/lbrynet/extras/daemon/Daemon.py +++ b/lbrynet/extras/daemon/Daemon.py @@ -897,7 +897,7 @@ class Daemon(metaclass=JSONRPCServerType): """ try: stream = await self.stream_manager.download_stream_from_uri( - uri, timeout, self.exchange_rate_manager, file_name + uri, self.exchange_rate_manager, timeout, file_name ) if not stream: raise DownloadSDTimeout(uri) diff --git a/lbrynet/extras/daemon/storage.py b/lbrynet/extras/daemon/storage.py index f829a0212..93830eb5b 100644 --- a/lbrynet/extras/daemon/storage.py +++ b/lbrynet/extras/daemon/storage.py @@ -423,6 +423,17 @@ class SQLiteStorage(SQLiteMixin): } return self.db.run(_sync_blobs) + def sync_files_to_blobs(self): + def _sync_blobs(transaction: sqlite3.Connection) -> typing.Set[str]: + transaction.executemany( + "update file set status='stopped' where stream_hash=?", + transaction.execute( + "select distinct sb.stream_hash from stream_blob sb " + "inner join blob b on b.blob_hash=sb.blob_hash and b.status=='pending'" + ).fetchall() + ) + return self.db.run(_sync_blobs) + # # # # # # # # # stream functions # # # # # # # # # async def stream_exists(self, sd_hash: str) -> bool: diff --git a/lbrynet/stream/downloader.py b/lbrynet/stream/downloader.py index 9f8007677..529e49c80 100644 --- a/lbrynet/stream/downloader.py +++ b/lbrynet/stream/downloader.py @@ -2,6 +2,7 @@ import asyncio import typing import logging import binascii +from lbrynet.error import DownloadSDTimeout from lbrynet.utils import resolve_host from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.blob_exchange.downloader import BlobDownloader @@ -32,6 +33,8 @@ class StreamDownloader: self.fixed_peers_handle: typing.Optional[asyncio.Handle] = None self.fixed_peers_delay: typing.Optional[float] = None self.added_fixed_peers = False + self.time_to_descriptor: typing.Optional[float] = None + self.time_to_first_bytes: typing.Optional[float] = None async def add_fixed_peers(self): def _delayed_add_fixed_peers(): @@ -59,8 +62,16 @@ class StreamDownloader: # download or get the sd blob sd_blob = self.blob_manager.get_blob(self.sd_hash) if not sd_blob.get_is_verified(): - sd_blob = await self.blob_downloader.download_blob(self.sd_hash) - log.info("downloaded sd blob %s", self.sd_hash) + try: + now = self.loop.time() + sd_blob = await asyncio.wait_for( + self.blob_downloader.download_blob(self.sd_hash), + self.config.blob_download_timeout, loop=self.loop + ) + log.info("downloaded sd blob %s", self.sd_hash) + self.time_to_descriptor = self.loop.time() - now + except asyncio.TimeoutError: + raise DownloadSDTimeout(self.sd_hash) # parse the descriptor self.descriptor = await StreamDescriptor.from_stream_descriptor_blob( @@ -101,12 +112,18 @@ class StreamDownloader: binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode()) ) - async def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob'): + async def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes: return await self.loop.run_in_executor(None, self._decrypt_blob, blob_info, blob) async def read_blob(self, blob_info: 'BlobInfo') -> bytes: + start = None + if self.time_to_first_bytes is None: + start = self.loop.time() blob = await self.download_stream_blob(blob_info) - return await self.decrypt_blob(blob_info, blob) + decrypted = await self.decrypt_blob(blob_info, blob) + if start: + self.time_to_first_bytes = self.loop.time() - start + return decrypted def stop(self): if self.accumulate_task: diff --git a/lbrynet/stream/managed_stream.py b/lbrynet/stream/managed_stream.py index 4d5f03ccf..ca0a30532 100644 --- a/lbrynet/stream/managed_stream.py +++ b/lbrynet/stream/managed_stream.py @@ -9,13 +9,13 @@ from lbrynet.stream.downloader import StreamDownloader from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.stream.reflector.client import StreamReflectorClient from lbrynet.extras.daemon.storage import StoredStreamClaim -from lbrynet.blob import MAX_BLOB_SIZE if typing.TYPE_CHECKING: from lbrynet.conf import Config from lbrynet.schema.claim import Claim from lbrynet.blob.blob_manager import BlobManager from lbrynet.blob.blob_info import BlobInfo from lbrynet.dht.node import Node + from lbrynet.extras.daemon.analytics import AnalyticsManager log = logging.getLogger(__name__) @@ -43,7 +43,8 @@ class ManagedStream: sd_hash: str, download_directory: typing.Optional[str] = None, file_name: typing.Optional[str] = None, status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredStreamClaim] = None, download_id: typing.Optional[str] = None, rowid: typing.Optional[int] = None, - descriptor: typing.Optional[StreamDescriptor] = None): + descriptor: typing.Optional[StreamDescriptor] = None, + analytics_manager: typing.Optional['AnalyticsManager'] = None): self.loop = loop self.config = config self.blob_manager = blob_manager @@ -56,11 +57,13 @@ class ManagedStream: self.rowid = rowid self.written_bytes = 0 self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor) + self.analytics_manager = analytics_manager self.fully_reflected = asyncio.Event(loop=self.loop) self.file_output_task: typing.Optional[asyncio.Task] = None self.delayed_stop: typing.Optional[asyncio.Handle] = None self.saving = asyncio.Event(loop=self.loop) self.finished_writing = asyncio.Event(loop=self.loop) + self.started_writing = asyncio.Event(loop=self.loop) @property def descriptor(self) -> StreamDescriptor: @@ -217,16 +220,18 @@ class ManagedStream: return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path), os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor) - async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True): + async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True, + file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None): await self.downloader.start(node) - if not save_file: + if not save_file and not file_name: if not await self.blob_manager.storage.file_exists(self.sd_hash): self.rowid = self.blob_manager.storage.save_downloaded_file( self.stream_hash, None, None, 0.0 ) self.update_delayed_stop() else: - await self.save_file() + await self.save_file(file_name, download_directory) + await self.started_writing.wait() self.update_status(ManagedStream.STATUS_RUNNING) await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING) @@ -235,7 +240,6 @@ class ManagedStream: log.info("Stopping inactive download for stream %s", self.sd_hash) self.stop_download() - log.info("update delayed stop") if self.delayed_stop: self.delayed_stop.cancel() self.delayed_stop = self.loop.call_later(60, _delayed_stop) @@ -259,6 +263,7 @@ class ManagedStream: async def _save_file(self, output_path: str): self.saving.set() self.finished_writing.clear() + self.started_writing.clear() try: with open(output_path, 'wb') as file_write_handle: async for blob_info, decrypted in self.aiter_read_stream(): @@ -266,14 +271,21 @@ class ManagedStream: file_write_handle.write(decrypted) file_write_handle.flush() self.written_bytes += len(decrypted) - + if not self.started_writing.is_set(): + self.started_writing.set() + self.update_status(ManagedStream.STATUS_FINISHED) + await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_FINISHED) + if self.analytics_manager: + self.loop.create_task(self.analytics_manager.send_download_finished( + self.download_id, self.claim_name, self.sd_hash + )) self.finished_writing.set() except Exception as err: + if os.path.isfile(output_path): + log.info("removing incomplete download %s for %s", output_path, self.sd_hash) + os.remove(output_path) if not isinstance(err, asyncio.CancelledError): log.exception("unexpected error encountered writing file for stream %s", self.sd_hash) - if os.path.isfile(output_path): - log.info("removing incomplete download %s", output_path) - os.remove(output_path) raise err finally: self.saving.clear() @@ -282,10 +294,9 @@ class ManagedStream: if self.file_output_task and not self.file_output_task.done(): self.file_output_task.cancel() if self.delayed_stop: - log.info('cancel delayed stop') self.delayed_stop.cancel() self.delayed_stop = None - self.download_directory = download_directory or self.download_directory + self.download_directory = download_directory or self.download_directory or self.config.download_dir if not self.download_directory: raise ValueError("no directory to download to") if not (file_name or self._file_name or self.descriptor.suggested_file_name): diff --git a/lbrynet/stream/stream_manager.py b/lbrynet/stream/stream_manager.py index b99141266..3019d897f 100644 --- a/lbrynet/stream/stream_manager.py +++ b/lbrynet/stream/stream_manager.py @@ -6,8 +6,8 @@ import logging import random from decimal import Decimal from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError -# DownloadDataTimeout, DownloadSDTimeout -from lbrynet.utils import generate_id, cache_concurrent +from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout, ResolveTimeout +from lbrynet.utils import cache_concurrent from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.stream.managed_stream import ManagedStream from lbrynet.schema.claim import Claim @@ -96,11 +96,10 @@ class StreamManager: await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED) async def start_stream(self, stream: ManagedStream): - await stream.setup(self.node, save_file=not self.config.streaming_only) - self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream) stream.update_status(ManagedStream.STATUS_RUNNING) await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_RUNNING) - self.wait_for_stream_finished(stream) + await stream.setup(self.node, save_file=not self.config.streaming_only) + self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream) async def recover_streams(self, file_infos: typing.List[typing.Dict]): to_restore = [] @@ -139,13 +138,14 @@ class StreamManager: return stream = ManagedStream( self.loop, self.config, self.blob_manager, descriptor.sd_hash, download_directory, file_name, status, - claim, rowid=rowid, descriptor=descriptor + claim, rowid=rowid, descriptor=descriptor, analytics_manager=self.analytics_manager ) self.streams[sd_hash] = stream async def load_streams_from_database(self): to_recover = [] to_start = [] + await self.storage.sync_files_to_blobs() for file_info in await self.storage.get_all_lbry_files(): if not self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified(): to_recover.append(file_info) @@ -181,10 +181,10 @@ class StreamManager: while True: if self.config.reflect_streams and self.config.reflector_servers: sd_hashes = await self.storage.get_streams_to_re_reflect() - streams = list(filter(lambda s: s in sd_hashes, self.streams.keys())) + sd_hashes = [sd for sd in sd_hashes if sd in self.streams] batch = [] - while streams: - stream = streams.pop() + while sd_hashes: + stream = self.streams[sd_hashes.pop()] if not stream.fully_reflected.is_set(): host, port = random.choice(self.config.reflector_servers) batch.append(stream.upload_to_reflector(host, port)) @@ -198,7 +198,7 @@ class StreamManager: async def start(self): await self.load_streams_from_database() self.resume_downloading_task = self.loop.create_task(self.resume()) - # self.re_reflect_task = self.loop.create_task(self.reflect_streams()) + self.re_reflect_task = self.loop.create_task(self.reflect_streams()) def stop(self): if self.resume_downloading_task and not self.resume_downloading_task.done(): @@ -279,28 +279,11 @@ class StreamManager: streams.reverse() return streams - def wait_for_stream_finished(self, stream: ManagedStream): - async def _wait_for_stream_finished(): - if stream.downloader and stream.running: - await stream.finished_writing.wait() - stream.update_status(ManagedStream.STATUS_FINISHED) - if self.analytics_manager: - self.loop.create_task(self.analytics_manager.send_download_finished( - stream.download_id, stream.claim_name, stream.sd_hash - )) - - task = self.loop.create_task(_wait_for_stream_finished()) - self.update_stream_finished_futs.append(task) - task.add_done_callback( - lambda _: None if task not in self.update_stream_finished_futs else - self.update_stream_finished_futs.remove(task) - ) - async def _check_update_or_replace(self, outpoint: str, claim_id: str, claim: Claim) -> typing.Tuple[ typing.Optional[ManagedStream], typing.Optional[ManagedStream]]: existing = self.get_filtered_streams(outpoint=outpoint) if existing: - if not existing[0].running: + if existing[0].status == ManagedStream.STATUS_STOPPED: await self.start_stream(existing[0]) return existing[0], None existing = self.get_filtered_streams(sd_hash=claim.stream.source.sd_hash) @@ -323,163 +306,112 @@ class StreamManager: return None, existing_for_claim_id[0] return None, None - # async def start_downloader(self, got_descriptor_time: asyncio.Future, downloader: EncryptedStreamDownloader, - # download_id: str, outpoint: str, claim: Claim, resolved: typing.Dict, - # file_name: typing.Optional[str] = None) -> ManagedStream: - # start_time = self.loop.time() - # downloader.download(self.node) - # await downloader.got_descriptor.wait() - # got_descriptor_time.set_result(self.loop.time() - start_time) - # rowid = await self._store_stream(downloader) - # await self.storage.save_content_claim( - # downloader.descriptor.stream_hash, outpoint - # ) - # stream = ManagedStream(self.loop, self.blob_manager, rowid, downloader.descriptor, self.config.download_dir, - # file_name, downloader, ManagedStream.STATUS_RUNNING, download_id=download_id) - # stream.set_claim(resolved, claim) - # await stream.downloader.wrote_bytes_event.wait() - # self.streams.add(stream) - # return stream - @cache_concurrent - async def download_stream_from_uri(self, uri, timeout: float, exchange_rate_manager: 'ExchangeRateManager', - file_name: typing.Optional[str] = None) -> ManagedStream: + async def download_stream_from_uri(self, uri, exchange_rate_manager: 'ExchangeRateManager', + timeout: typing.Optional[float] = None, + file_name: typing.Optional[str] = None, + download_directory: typing.Optional[str] = None, + save_file: bool = True, resolve_timeout: float = 3.0) -> ManagedStream: + timeout = timeout or self.config.download_timeout start_time = self.loop.time() - parsed_uri = parse_lbry_uri(uri) - if parsed_uri.is_channel: - raise ResolveError("cannot download a channel claim, specify a /path") - - # resolve the claim - resolved_result = await self.wallet.ledger.resolve(0, 10, uri) - await self.storage.save_claims_for_resolve([ - value for value in resolved_result.values() if 'error' not in value - ]) - resolved = resolved_result.get(uri, {}) - resolved = resolved if 'value' in resolved else resolved.get('claim') - if not resolved: - raise ResolveError(f"Failed to resolve stream at '{uri}'") - if 'error' in resolved: - raise ResolveError(f"error resolving stream: {resolved['error']}") + resolved_time = None + stream = None + error = None + outpoint = None + try: + # resolve the claim + parsed_uri = parse_lbry_uri(uri) + if parsed_uri.is_channel: + raise ResolveError("cannot download a channel claim, specify a /path") + try: + resolved_result = await asyncio.wait_for(self.wallet.ledger.resolve(0, 1, uri), resolve_timeout) + except asyncio.TimeoutError: + raise ResolveTimeout(uri) + await self.storage.save_claims_for_resolve([ + value for value in resolved_result.values() if 'error' not in value + ]) + resolved = resolved_result.get(uri, {}) + resolved = resolved if 'value' in resolved else resolved.get('claim') + if not resolved: + raise ResolveError(f"Failed to resolve stream at '{uri}'") + if 'error' in resolved: + raise ResolveError(f"error resolving stream: {resolved['error']}") claim = Claim.from_bytes(binascii.unhexlify(resolved['protobuf'])) outpoint = f"{resolved['txid']}:{resolved['nout']}" resolved_time = self.loop.time() - start_time - # resume or update an existing stream, if the stream changed download it and delete the old one after - updated_stream, to_replace = await self._check_update_or_replace(outpoint, resolved['claim_id'], claim) - if updated_stream: - return updated_stream + # resume or update an existing stream, if the stream changed download it and delete the old one after + updated_stream, to_replace = await self._check_update_or_replace(outpoint, resolved['claim_id'], claim) + if updated_stream: + return updated_stream - # check that the fee is payable - fee_amount, fee_address = None, None - if claim.stream.has_fee: - fee_amount = round(exchange_rate_manager.convert_currency( - claim.stream.fee.currency, "LBC", claim.stream.fee.amount - ), 5) - max_fee_amount = round(exchange_rate_manager.convert_currency( - self.config.max_key_fee['currency'], "LBC", Decimal(self.config.max_key_fee['amount']) - ), 5) - if fee_amount > max_fee_amount: - msg = f"fee of {fee_amount} exceeds max configured to allow of {max_fee_amount}" - log.warning(msg) - raise KeyFeeAboveMaxAllowed(msg) - balance = await self.wallet.default_account.get_balance() - if lbc_to_dewies(str(fee_amount)) > balance: - msg = f"fee of {fee_amount} exceeds max available balance" - log.warning(msg) - raise InsufficientFundsError(msg) - fee_address = claim.stream.fee.address - # content_fee_tx = await self.wallet.send_amount_to_address( - # lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1') - # ) + # check that the fee is payable + if not to_replace and claim.stream.has_fee: + fee_amount = round(exchange_rate_manager.convert_currency( + claim.stream.fee.currency, "LBC", claim.stream.fee.amount + ), 5) + max_fee_amount = round(exchange_rate_manager.convert_currency( + self.config.max_key_fee['currency'], "LBC", Decimal(self.config.max_key_fee['amount']) + ), 5) + if fee_amount > max_fee_amount: + msg = f"fee of {fee_amount} exceeds max configured to allow of {max_fee_amount}" + log.warning(msg) + raise KeyFeeAboveMaxAllowed(msg) + balance = await self.wallet.default_account.get_balance() + if lbc_to_dewies(str(fee_amount)) > balance: + msg = f"fee of {fee_amount} exceeds max available balance" + log.warning(msg) + raise InsufficientFundsError(msg) + fee_address = claim.stream.fee.address + await self.wallet.send_amount_to_address( + lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1') + ) + log.info("paid fee of %s for %s", fee_amount, uri) - handled_fee_time = self.loop.time() - resolved_time - start_time - - # download the stream - download_id = binascii.hexlify(generate_id()).decode() - - download_dir = self.config.download_dir - save_file = True - if not file_name and self.config.streaming_only: - download_dir, file_name = None, None - save_file = False - stream = ManagedStream( - self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_dir, - file_name, ManagedStream.STATUS_RUNNING, download_id=download_id - ) - - await stream.setup(self.node, save_file=save_file) - stream.set_claim(resolved, claim) - await self.storage.save_content_claim(stream.stream_hash, outpoint) - self.streams[stream.sd_hash] = stream - - # stream = None - # descriptor_time_fut = self.loop.create_future() - # start_download_time = self.loop.time() - # time_to_descriptor = None - # time_to_first_bytes = None - # error = None - # try: - # stream = await asyncio.wait_for( - # asyncio.ensure_future( - # self.start_downloader(descriptor_time_fut, downloader, download_id, outpoint, claim, resolved, - # file_name) - # ), timeout - # ) - # time_to_descriptor = await descriptor_time_fut - # time_to_first_bytes = self.loop.time() - start_download_time - time_to_descriptor - # self.wait_for_stream_finished(stream) - # if fee_address and fee_amount and not to_replace: - # - # elif to_replace: # delete old stream now that the replacement has started downloading - # await self.delete_stream(to_replace) - # except asyncio.TimeoutError: - # if descriptor_time_fut.done(): - # time_to_descriptor = descriptor_time_fut.result() - # error = DownloadDataTimeout(downloader.sd_hash) - # self.blob_manager.delete_blob(downloader.sd_hash) - # await self.storage.delete_stream(downloader.descriptor) - # else: - # descriptor_time_fut.cancel() - # error = DownloadSDTimeout(downloader.sd_hash) - # if stream: - # await self.stop_stream(stream) - # else: - # downloader.stop() - # if error: - # log.warning(error) - # if self.analytics_manager: - # self.loop.create_task( - # self.analytics_manager.send_time_to_first_bytes( - # resolved_time, self.loop.time() - start_time, download_id, parse_lbry_uri(uri).name, outpoint, - # None if not stream else len(stream.downloader.blob_downloader.active_connections), - # None if not stream else len(stream.downloader.blob_downloader.scores), - # False if not downloader else downloader.added_fixed_peers, - # self.config.fixed_peer_delay if not downloader else downloader.fixed_peers_delay, - # claim.source_hash.decode(), time_to_descriptor, - # None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].blob_hash, - # None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].length, - # time_to_first_bytes, None if not error else error.__class__.__name__ - # ) - # ) - # if error: - # raise error - return stream - - # async def download_stream_from_uri(self, uri, exchange_rate_manager: 'ExchangeRateManager', - # file_name: typing.Optional[str] = None, - # timeout: typing.Optional[float] = None) -> ManagedStream: - # timeout = timeout or self.config.download_timeout - # if uri in self.starting_streams: - # return await self.starting_streams[uri] - # fut = asyncio.Future(loop=self.loop) - # self.starting_streams[uri] = fut - # try: - # stream = await self._download_stream_from_uri(uri, timeout, exchange_rate_manager, file_name) - # fut.set_result(stream) - # except Exception as err: - # fut.set_exception(err) - # try: - # return await fut - # finally: - # del self.starting_streams[uri] + download_directory = download_directory or self.config.download_dir + if not file_name and (self.config.streaming_only or not save_file): + download_dir, file_name = None, None + stream = ManagedStream( + self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_directory, + file_name, ManagedStream.STATUS_RUNNING, analytics_manager=self.analytics_manager + ) + try: + await asyncio.wait_for(stream.setup( + self.node, save_file=save_file, file_name=file_name, download_directory=download_directory + ), timeout, loop=self.loop) + except asyncio.TimeoutError: + if not stream.descriptor: + raise DownloadSDTimeout(stream.sd_hash) + raise DownloadDataTimeout(stream.sd_hash) + if to_replace: # delete old stream now that the replacement has started downloading + await self.delete_stream(to_replace) + stream.set_claim(resolved, claim) + await self.storage.save_content_claim(stream.stream_hash, outpoint) + self.streams[stream.sd_hash] = stream + return stream + except Exception as err: + error = err + if stream and stream.descriptor: + await self.storage.delete_stream(stream.descriptor) + finally: + if self.analytics_manager and (error or (stream and (stream.downloader.time_to_descriptor or + stream.downloader.time_to_first_bytes))): + self.loop.create_task( + self.analytics_manager.send_time_to_first_bytes( + resolved_time, self.loop.time() - start_time, None if not stream else stream.download_id, + uri, outpoint, + None if not stream else len(stream.downloader.blob_downloader.active_connections), + None if not stream else len(stream.downloader.blob_downloader.scores), + False if not stream else stream.downloader.added_fixed_peers, + self.config.fixed_peer_delay if not stream else stream.downloader.fixed_peers_delay, + None if not stream else stream.sd_hash, + None if not stream else stream.downloader.time_to_descriptor, + None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].blob_hash, + None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].length, + None if not stream else stream.downloader.time_to_first_bytes, + None if not error else error.__class__.__name__ + ) + ) + if error: + raise error diff --git a/tests/unit/blob/test_blob_file.py b/tests/unit/blob/test_blob_file.py index 173681082..3ac9ee9a6 100644 --- a/tests/unit/blob/test_blob_file.py +++ b/tests/unit/blob/test_blob_file.py @@ -28,9 +28,9 @@ class TestBlobfile(AsyncioTestCase): self.assertEqual(blob.get_is_verified(), False) self.assertNotIn(blob_hash, blob_manager.completed_blob_hashes) - writer = blob.open_for_writing() + writer = blob.get_blob_writer() writer.write(blob_bytes) - await blob.finished_writing.wait() + await blob.verified.wait() self.assertTrue(os.path.isfile(blob.file_path), True) self.assertEqual(blob.get_is_verified(), True) self.assertIn(blob_hash, blob_manager.completed_blob_hashes) diff --git a/tests/unit/blob_exchange/test_transfer_blob.py b/tests/unit/blob_exchange/test_transfer_blob.py index e30717a43..5d9c1d9a2 100644 --- a/tests/unit/blob_exchange/test_transfer_blob.py +++ b/tests/unit/blob_exchange/test_transfer_blob.py @@ -11,7 +11,7 @@ from lbrynet.conf import Config from lbrynet.extras.daemon.storage import SQLiteStorage from lbrynet.blob.blob_manager import BlobManager from lbrynet.blob_exchange.server import BlobServer, BlobServerProtocol -from lbrynet.blob_exchange.client import BlobExchangeClientProtocol, request_blob +from lbrynet.blob_exchange.client import request_blob from lbrynet.dht.peer import KademliaPeer, PeerManager # import logging @@ -58,9 +58,9 @@ class TestBlobExchange(BlobExchangeTestBase): async def _add_blob_to_server(self, blob_hash: str, blob_bytes: bytes): # add the blob on the server server_blob = self.server_blob_manager.get_blob(blob_hash, len(blob_bytes)) - writer = server_blob.open_for_writing() + writer = server_blob.get_blob_writer() writer.write(blob_bytes) - await server_blob.finished_writing.wait() + await server_blob.verified.wait() self.assertTrue(os.path.isfile(server_blob.file_path)) self.assertEqual(server_blob.get_is_verified(), True) @@ -68,11 +68,14 @@ class TestBlobExchange(BlobExchangeTestBase): client_blob = self.client_blob_manager.get_blob(blob_hash) # download the blob - downloaded = await request_blob(self.loop, client_blob, self.server_from_client.address, - self.server_from_client.tcp_port, 2, 3) - await client_blob.finished_writing.wait() + downloaded, transport = await request_blob(self.loop, client_blob, self.server_from_client.address, + self.server_from_client.tcp_port, 2, 3) + self.assertIsNotNone(transport) + self.addCleanup(transport.close) + await client_blob.verified.wait() self.assertEqual(client_blob.get_is_verified(), True) self.assertTrue(downloaded) + self.addCleanup(client_blob.close) async def test_transfer_sd_blob(self): sd_hash = "3e2706157a59aaa47ef52bc264fce488078b4026c0b9bab649a8f2fe1ecc5e5cad7182a2bb7722460f856831a1ac0f02" @@ -112,7 +115,7 @@ class TestBlobExchange(BlobExchangeTestBase): ), self._test_transfer_blob(blob_hash) ) - await second_client_blob.finished_writing.wait() + await second_client_blob.verified.wait() self.assertEqual(second_client_blob.get_is_verified(), True) async def test_host_different_blobs_to_multiple_peers_at_once(self): @@ -143,7 +146,7 @@ class TestBlobExchange(BlobExchangeTestBase): server_from_second_client.tcp_port, 2, 3 ), self._test_transfer_blob(sd_hash), - second_client_blob.finished_writing.wait() + second_client_blob.verified.wait() ) self.assertEqual(second_client_blob.get_is_verified(), True) diff --git a/tests/unit/dht/protocol/test_data_store.py b/tests/unit/dht/protocol/test_data_store.py index 54c58cce9..f8d264ffd 100644 --- a/tests/unit/dht/protocol/test_data_store.py +++ b/tests/unit/dht/protocol/test_data_store.py @@ -1,12 +1,13 @@ import asyncio -from torba.testcase import AsyncioTestCase +from unittest import mock, TestCase from lbrynet.dht.protocol.data_store import DictDataStore from lbrynet.dht.peer import PeerManager -class DataStoreTests(AsyncioTestCase): +class DataStoreTests(TestCase): def setUp(self): - self.loop = asyncio.get_event_loop() + self.loop = mock.Mock(spec=asyncio.BaseEventLoop) + self.loop.time = lambda: 0.0 self.peer_manager = PeerManager(self.loop) self.data_store = DictDataStore(self.loop, self.peer_manager) diff --git a/tests/unit/stream/test_assembler.py b/tests/unit/stream/test_assembler.py deleted file mode 100644 index cc2e8ab3f..000000000 --- a/tests/unit/stream/test_assembler.py +++ /dev/null @@ -1,117 +0,0 @@ -import os -import asyncio -import tempfile -import shutil - -from torba.testcase import AsyncioTestCase -from lbrynet.conf import Config -from lbrynet.blob.blob_file import MAX_BLOB_SIZE -from lbrynet.extras.daemon.storage import SQLiteStorage -from lbrynet.blob.blob_manager import BlobManager -from lbrynet.stream.assembler import StreamAssembler -from lbrynet.stream.descriptor import StreamDescriptor -from lbrynet.stream.stream_manager import StreamManager - - -class TestStreamAssembler(AsyncioTestCase): - def setUp(self): - self.loop = asyncio.get_event_loop() - self.key = b'deadbeef' * 4 - self.cleartext = b'test' - - async def test_create_and_decrypt_one_blob_stream(self, corrupt=False): - tmp_dir = tempfile.mkdtemp() - self.addCleanup(lambda: shutil.rmtree(tmp_dir)) - self.storage = SQLiteStorage(Config(), ":memory:") - await self.storage.open() - self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage) - - download_dir = tempfile.mkdtemp() - self.addCleanup(lambda: shutil.rmtree(download_dir)) - - # create the stream - file_path = os.path.join(tmp_dir, "test_file") - with open(file_path, 'wb') as f: - f.write(self.cleartext) - - sd = await StreamDescriptor.create_stream(self.loop, tmp_dir, file_path, key=self.key) - - # copy blob files - sd_hash = sd.calculate_sd_hash() - shutil.copy(os.path.join(tmp_dir, sd_hash), os.path.join(download_dir, sd_hash)) - for blob_info in sd.blobs: - if blob_info.blob_hash: - shutil.copy(os.path.join(tmp_dir, blob_info.blob_hash), os.path.join(download_dir, blob_info.blob_hash)) - if corrupt and blob_info.length == MAX_BLOB_SIZE: - with open(os.path.join(download_dir, blob_info.blob_hash), "rb+") as handle: - handle.truncate() - handle.flush() - - downloader_storage = SQLiteStorage(Config(), os.path.join(download_dir, "lbrynet.sqlite")) - await downloader_storage.open() - - # add the blobs to the blob table (this would happen upon a blob download finishing) - downloader_blob_manager = BlobManager(self.loop, download_dir, downloader_storage) - descriptor = await downloader_blob_manager.get_stream_descriptor(sd_hash) - - # assemble the decrypted file - assembler = StreamAssembler(self.loop, downloader_blob_manager, descriptor.sd_hash) - await assembler.assemble_decrypted_stream(download_dir) - if corrupt: - return self.assertFalse(os.path.isfile(os.path.join(download_dir, "test_file"))) - - with open(os.path.join(download_dir, "test_file"), "rb") as f: - decrypted = f.read() - self.assertEqual(decrypted, self.cleartext) - self.assertEqual(True, self.blob_manager.get_blob(sd_hash).get_is_verified()) - self.assertEqual(True, self.blob_manager.get_blob(descriptor.blobs[0].blob_hash).get_is_verified()) - # its all blobs + sd blob - last blob, which is the same size as descriptor.blobs - self.assertEqual(len(descriptor.blobs), len(await downloader_storage.get_all_finished_blobs())) - self.assertEqual( - [descriptor.sd_hash, descriptor.blobs[0].blob_hash], await downloader_storage.get_blobs_to_announce() - ) - - await downloader_storage.close() - await self.storage.close() - - async def test_create_and_decrypt_multi_blob_stream(self): - self.cleartext = b'test\n' * 20000000 - await self.test_create_and_decrypt_one_blob_stream() - - async def test_create_and_decrypt_padding(self): - for i in range(16): - self.cleartext = os.urandom((MAX_BLOB_SIZE*2) + i) - await self.test_create_and_decrypt_one_blob_stream() - - for i in range(16): - self.cleartext = os.urandom((MAX_BLOB_SIZE*2) - i) - await self.test_create_and_decrypt_one_blob_stream() - - async def test_create_and_decrypt_random(self): - self.cleartext = os.urandom(20000000) - await self.test_create_and_decrypt_one_blob_stream() - - async def test_create_managed_stream_announces(self): - # setup a blob manager - storage = SQLiteStorage(Config(), ":memory:") - await storage.open() - tmp_dir = tempfile.mkdtemp() - self.addCleanup(lambda: shutil.rmtree(tmp_dir)) - blob_manager = BlobManager(self.loop, tmp_dir, storage) - stream_manager = StreamManager(self.loop, Config(), blob_manager, None, storage, None) - # create the stream - download_dir = tempfile.mkdtemp() - self.addCleanup(lambda: shutil.rmtree(download_dir)) - file_path = os.path.join(download_dir, "test_file") - with open(file_path, 'wb') as f: - f.write(b'testtest') - - stream = await stream_manager.create_stream(file_path) - self.assertEqual( - [stream.sd_hash, stream.descriptor.blobs[0].blob_hash], - await storage.get_blobs_to_announce()) - - async def test_create_truncate_and_handle_stream(self): - self.cleartext = b'potato' * 1337 * 5279 - # The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated - await asyncio.wait_for(self.test_create_and_decrypt_one_blob_stream(corrupt=True), timeout=5) diff --git a/tests/unit/stream/test_downloader.py b/tests/unit/stream/test_downloader.py deleted file mode 100644 index d97444c0c..000000000 --- a/tests/unit/stream/test_downloader.py +++ /dev/null @@ -1,102 +0,0 @@ -import os -import time -import unittest -from unittest import mock -import asyncio - -from lbrynet.blob_exchange.serialization import BlobResponse -from lbrynet.blob_exchange.server import BlobServerProtocol -from lbrynet.conf import Config -from lbrynet.stream.descriptor import StreamDescriptor -from lbrynet.stream.downloader import StreamDownloader -from lbrynet.dht.node import Node -from lbrynet.dht.peer import KademliaPeer -from lbrynet.blob.blob_file import MAX_BLOB_SIZE -from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase - - -class TestStreamDownloader(BlobExchangeTestBase): - async def setup_stream(self, blob_count: int = 10): - self.stream_bytes = b'' - for _ in range(blob_count): - self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1)) - # create the stream - file_path = os.path.join(self.server_dir, "test_file") - with open(file_path, 'wb') as f: - f.write(self.stream_bytes) - descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path) - self.sd_hash = descriptor.calculate_sd_hash() - conf = Config(data_dir=self.server_dir, wallet_dir=self.server_dir, download_dir=self.server_dir, - reflector_servers=[]) - self.downloader = StreamDownloader(self.loop, conf, self.client_blob_manager, self.sd_hash) - - async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None): - await self.setup_stream(blob_count) - mock_node = mock.Mock(spec=Node) - - def _mock_accumulate_peers(q1, q2): - async def _task(): - pass - q2.put_nowait([self.server_from_client]) - return q2, self.loop.create_task(_task()) - - mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers - self.downloader.download(mock_node) - await self.downloader.stream_finished_event.wait() - self.assertTrue(self.downloader.stream_handle.closed) - self.assertTrue(os.path.isfile(self.downloader.output_path)) - self.downloader.stop() - self.assertIs(self.downloader.stream_handle, None) - self.assertTrue(os.path.isfile(self.downloader.output_path)) - with open(self.downloader.output_path, 'rb') as f: - self.assertEqual(f.read(), self.stream_bytes) - await asyncio.sleep(0.01) - - async def test_transfer_stream(self): - await self._test_transfer_stream(10) - - @unittest.SkipTest - async def test_transfer_hundred_blob_stream(self): - await self._test_transfer_stream(100) - - async def test_transfer_stream_bad_first_peer_good_second(self): - await self.setup_stream(2) - - mock_node = mock.Mock(spec=Node) - q = asyncio.Queue() - - bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334) - - def _mock_accumulate_peers(q1, q2): - async def _task(): - pass - - q2.put_nowait([bad_peer]) - self.loop.call_later(1, q2.put_nowait, [self.server_from_client]) - return q2, self.loop.create_task(_task()) - - mock_node.accumulate_peers = _mock_accumulate_peers - - self.downloader.download(mock_node) - await self.downloader.stream_finished_event.wait() - self.assertTrue(os.path.isfile(self.downloader.output_path)) - with open(self.downloader.output_path, 'rb') as f: - self.assertEqual(f.read(), self.stream_bytes) - # self.assertIs(self.server_from_client.tcp_last_down, None) - # self.assertIsNot(bad_peer.tcp_last_down, None) - - async def test_client_chunked_response(self): - self.server.stop_server() - class ChunkedServerProtocol(BlobServerProtocol): - - def send_response(self, responses): - to_send = [] - while responses: - to_send.append(responses.pop()) - for byte in BlobResponse(to_send).serialize(): - self.transport.write(bytes([byte])) - self.server.server_protocol_class = ChunkedServerProtocol - self.server.start_server(33333, '127.0.0.1') - self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes)) - await asyncio.wait_for(self._test_transfer_stream(10), timeout=2) - self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes)) diff --git a/tests/unit/stream/test_managed_stream.py b/tests/unit/stream/test_managed_stream.py new file mode 100644 index 000000000..2039f1cbd --- /dev/null +++ b/tests/unit/stream/test_managed_stream.py @@ -0,0 +1,175 @@ +import os +import shutil +import unittest +from unittest import mock +import asyncio +from lbrynet.blob.blob_file import MAX_BLOB_SIZE +from lbrynet.blob_exchange.serialization import BlobResponse +from lbrynet.blob_exchange.server import BlobServerProtocol +from lbrynet.dht.node import Node +from lbrynet.dht.peer import KademliaPeer +from lbrynet.extras.daemon.storage import StoredStreamClaim +from lbrynet.stream.managed_stream import ManagedStream +from lbrynet.stream.descriptor import StreamDescriptor +from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase + + +def get_mock_node(loop): + mock_node = mock.Mock(spec=Node) + mock_node.joined = asyncio.Event(loop=loop) + mock_node.joined.set() + return mock_node + + +class TestManagedStream(BlobExchangeTestBase): + async def create_stream(self, blob_count: int = 10): + self.stream_bytes = b'' + for _ in range(blob_count): + self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1)) + # create the stream + file_path = os.path.join(self.server_dir, "test_file") + with open(file_path, 'wb') as f: + f.write(self.stream_bytes) + descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path) + self.sd_hash = descriptor.calculate_sd_hash() + return descriptor + + async def setup_stream(self, blob_count: int = 10): + await self.create_stream(blob_count) + self.stream = ManagedStream(self.loop, self.client_config, self.client_blob_manager, self.sd_hash, + self.client_dir) + + async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None): + await self.setup_stream(blob_count) + mock_node = mock.Mock(spec=Node) + + def _mock_accumulate_peers(q1, q2): + async def _task(): + pass + q2.put_nowait([self.server_from_client]) + return q2, self.loop.create_task(_task()) + + mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers + await self.stream.setup(mock_node, save_file=True) + await self.stream.finished_writing.wait() + self.assertTrue(os.path.isfile(self.stream.full_path)) + self.stream.stop_download() + self.assertTrue(os.path.isfile(self.stream.full_path)) + with open(self.stream.full_path, 'rb') as f: + self.assertEqual(f.read(), self.stream_bytes) + await asyncio.sleep(0.01) + + async def test_transfer_stream(self): + await self._test_transfer_stream(10) + + @unittest.SkipTest + async def test_transfer_hundred_blob_stream(self): + await self._test_transfer_stream(100) + + async def test_transfer_stream_bad_first_peer_good_second(self): + await self.setup_stream(2) + + mock_node = mock.Mock(spec=Node) + q = asyncio.Queue() + + bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334) + + def _mock_accumulate_peers(q1, q2): + async def _task(): + pass + + q2.put_nowait([bad_peer]) + self.loop.call_later(1, q2.put_nowait, [self.server_from_client]) + return q2, self.loop.create_task(_task()) + + mock_node.accumulate_peers = _mock_accumulate_peers + + await self.stream.setup(mock_node, save_file=True) + await self.stream.finished_writing.wait() + self.assertTrue(os.path.isfile(self.stream.full_path)) + with open(self.stream.full_path, 'rb') as f: + self.assertEqual(f.read(), self.stream_bytes) + # self.assertIs(self.server_from_client.tcp_last_down, None) + # self.assertIsNot(bad_peer.tcp_last_down, None) + + async def test_client_chunked_response(self): + self.server.stop_server() + + class ChunkedServerProtocol(BlobServerProtocol): + def send_response(self, responses): + to_send = [] + while responses: + to_send.append(responses.pop()) + for byte in BlobResponse(to_send).serialize(): + self.transport.write(bytes([byte])) + self.server.server_protocol_class = ChunkedServerProtocol + self.server.start_server(33333, '127.0.0.1') + self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes)) + await asyncio.wait_for(self._test_transfer_stream(10), timeout=2) + self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes)) + + async def test_create_and_decrypt_one_blob_stream(self, blobs=1, corrupt=False): + descriptor = await self.create_stream(blobs) + + # copy blob files + shutil.copy(os.path.join(self.server_blob_manager.blob_dir, self.sd_hash), + os.path.join(self.client_blob_manager.blob_dir, self.sd_hash)) + self.stream = ManagedStream(self.loop, self.client_config, self.client_blob_manager, self.sd_hash, + self.client_dir) + + for blob_info in descriptor.blobs[:-1]: + shutil.copy(os.path.join(self.server_blob_manager.blob_dir, blob_info.blob_hash), + os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash)) + if corrupt and blob_info.length == MAX_BLOB_SIZE: + with open(os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash), "rb+") as handle: + handle.truncate() + handle.flush() + await self.stream.setup() + await self.stream.finished_writing.wait() + if corrupt: + return self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file"))) + + with open(os.path.join(self.client_dir, "test_file"), "rb") as f: + decrypted = f.read() + self.assertEqual(decrypted, self.stream_bytes) + + self.assertEqual(True, self.client_blob_manager.get_blob(self.sd_hash).get_is_verified()) + self.assertEqual( + True, self.client_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified() + ) + # + # # its all blobs + sd blob - last blob, which is the same size as descriptor.blobs + # self.assertEqual(len(descriptor.blobs), len(await downloader_storage.get_all_finished_blobs())) + # self.assertEqual( + # [descriptor.sd_hash, descriptor.blobs[0].blob_hash], await downloader_storage.get_blobs_to_announce() + # ) + # + # await downloader_storage.close() + # await self.storage.close() + + async def test_create_and_decrypt_multi_blob_stream(self): + await self.test_create_and_decrypt_one_blob_stream(10) + + # async def test_create_managed_stream_announces(self): + # # setup a blob manager + # storage = SQLiteStorage(Config(), ":memory:") + # await storage.open() + # tmp_dir = tempfile.mkdtemp() + # self.addCleanup(lambda: shutil.rmtree(tmp_dir)) + # blob_manager = BlobManager(self.loop, tmp_dir, storage) + # stream_manager = StreamManager(self.loop, Config(), blob_manager, None, storage, None) + # # create the stream + # download_dir = tempfile.mkdtemp() + # self.addCleanup(lambda: shutil.rmtree(download_dir)) + # file_path = os.path.join(download_dir, "test_file") + # with open(file_path, 'wb') as f: + # f.write(b'testtest') + # + # stream = await stream_manager.create_stream(file_path) + # self.assertEqual( + # [stream.sd_hash, stream.descriptor.blobs[0].blob_hash], + # await storage.get_blobs_to_announce()) + + # async def test_create_truncate_and_handle_stream(self): + # # The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated + # await asyncio.wait_for(self.test_create_and_decrypt_one_blob_stream(corrupt=True), timeout=5) diff --git a/tests/unit/stream/test_stream_descriptor.py b/tests/unit/stream/test_stream_descriptor.py index 912fd3d21..7cadf3ffa 100644 --- a/tests/unit/stream/test_stream_descriptor.py +++ b/tests/unit/stream/test_stream_descriptor.py @@ -99,7 +99,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase): blob = blob_manager.get_blob(sd_hash) blob.set_length(len(sd_bytes)) - writer = blob.open_for_writing() + writer = blob.get_blob_writer() writer.write(sd_bytes) await blob.verified.wait() descriptor = await StreamDescriptor.from_stream_descriptor_blob( @@ -116,7 +116,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase): sd_hash = '9313d1807551186126acc3662e74d9de29cede78d4f133349ace846273ef116b9bb86be86c54509eb84840e4b032f6b2' with open(os.path.join(tmp_dir, sd_hash), 'wb') as handle: handle.write(b'doesnt work') - blob = BlobFile(loop, tmp_dir, sd_hash) + blob = BlobFile(loop, sd_hash, blob_directory=tmp_dir) self.assertTrue(blob.file_exists) self.assertIsNotNone(blob.length) with self.assertRaises(InvalidStreamDescriptorError): diff --git a/tests/unit/stream/test_stream_manager.py b/tests/unit/stream/test_stream_manager.py index b2a57c652..9fd957df1 100644 --- a/tests/unit/stream/test_stream_manager.py +++ b/tests/unit/stream/test_stream_manager.py @@ -192,8 +192,8 @@ class TestStreamManager(BlobExchangeTestBase): def check_post(event): self.assertEqual(event['event'], 'Time To First Bytes') self.assertEqual(event['properties']['error'], 'DownloadSDTimeout') - self.assertEqual(event['properties']['tried_peers_count'], None) - self.assertEqual(event['properties']['active_peer_count'], None) + self.assertEqual(event['properties']['tried_peers_count'], 0) + self.assertEqual(event['properties']['active_peer_count'], 0) self.assertEqual(event['properties']['use_fixed_peers'], False) self.assertEqual(event['properties']['added_fixed_peers'], False) self.assertEqual(event['properties']['fixed_peer_delay'], None) @@ -213,10 +213,10 @@ class TestStreamManager(BlobExchangeTestBase): self.stream_manager.analytics_manager._post = check_post - self.assertSetEqual(self.stream_manager.streams, set()) + self.assertDictEqual(self.stream_manager.streams, {}) stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) stream_hash = stream.stream_hash - self.assertSetEqual(self.stream_manager.streams, {stream}) + self.assertDictEqual(self.stream_manager.streams, {stream.sd_hash: stream}) self.assertTrue(stream.running) self.assertFalse(stream.finished) self.assertTrue(os.path.isfile(os.path.join(self.client_dir, "test_file"))) @@ -236,7 +236,7 @@ class TestStreamManager(BlobExchangeTestBase): self.assertEqual(stored_status, "stopped") await self.stream_manager.start_stream(stream) - await stream.downloader.stream_finished_event.wait() + await stream.finished_writing.wait() await asyncio.sleep(0, loop=self.loop) self.assertTrue(stream.finished) self.assertFalse(stream.running) @@ -247,7 +247,7 @@ class TestStreamManager(BlobExchangeTestBase): self.assertEqual(stored_status, "finished") await self.stream_manager.delete_stream(stream, True) - self.assertSetEqual(self.stream_manager.streams, set()) + self.assertDictEqual(self.stream_manager.streams, {}) self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file"))) stored_status = await self.client_storage.run_and_return_one_or_none( "select status from file where stream_hash=?", stream_hash @@ -257,7 +257,7 @@ class TestStreamManager(BlobExchangeTestBase): async def _test_download_error_on_start(self, expected_error, timeout=None): with self.assertRaises(expected_error): - await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout=timeout) + await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout) async def _test_download_error_analytics_on_start(self, expected_error, timeout=None): received = [] @@ -321,9 +321,9 @@ class TestStreamManager(BlobExchangeTestBase): await self.setup_stream_manager(old_sort=old_sort) self.stream_manager.analytics_manager._post = check_post - self.assertSetEqual(self.stream_manager.streams, set()) + self.assertDictEqual(self.stream_manager.streams, {}) stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) - await stream.downloader.stream_finished_event.wait() + await stream.finished_writing.wait() await asyncio.sleep(0, loop=self.loop) self.stream_manager.stop() self.client_blob_manager.stop() @@ -333,8 +333,11 @@ class TestStreamManager(BlobExchangeTestBase): await self.client_blob_manager.setup() await self.stream_manager.start() self.assertEqual(1, len(self.stream_manager.streams)) - self.assertEqual(stream.sd_hash, list(self.stream_manager.streams)[0].sd_hash) - self.assertEqual('stopped', list(self.stream_manager.streams)[0].status) + self.assertListEqual([self.sd_hash], list(self.stream_manager.streams.keys())) + for blob_hash in [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]]: + blob_status = await self.client_storage.get_blob_status(blob_hash) + self.assertEqual('pending', blob_status) + self.assertEqual('stopped', self.stream_manager.streams[self.sd_hash].status) sd_blob = self.client_blob_manager.get_blob(stream.sd_hash) self.assertTrue(sd_blob.file_exists)