diff --git a/lbrynet/extras/daemon/Daemon.py b/lbrynet/extras/daemon/Daemon.py index fb2d5b79d..36c39b5ca 100644 --- a/lbrynet/extras/daemon/Daemon.py +++ b/lbrynet/extras/daemon/Daemon.py @@ -18,7 +18,7 @@ from torba.client.baseaccount import SingleKey, HierarchicalDeterministic from lbrynet import utils from lbrynet.conf import Config, Setting -from lbrynet.blob.blob_file import is_valid_blobhash +from lbrynet.blob.blob_file import is_valid_blobhash, BlobBuffer from lbrynet.blob_exchange.downloader import download_blob from lbrynet.error import DownloadSDTimeout, ComponentsNotStarted from lbrynet.error import NullFundsError, NegativeFundsError, ComponentStartConditionNotMet @@ -477,59 +477,11 @@ class Daemon(metaclass=JSONRPCServerType): raise web.HTTPServerError(text=stream['error']) raise web.HTTPFound(f"/stream/{stream.sd_hash}") - @staticmethod - def prepare_range_response_headers(get_range: str, stream: 'ManagedStream') -> typing.Tuple[typing.Dict[str, str], - int, int]: - if '=' in get_range: - get_range = get_range.split('=')[1] - start, end = get_range.split('-') - size = 0 - for blob in stream.descriptor.blobs[:-1]: - size += blob.length - 1 - start = int(start) - end = int(end) if end else size - 1 - skip_blobs = start // 2097150 - skip = skip_blobs * 2097151 - start = skip - final_size = end - start + 1 - - headers = { - 'Accept-Ranges': 'bytes', - 'Content-Range': f'bytes {start}-{end}/{size}', - 'Content-Length': str(final_size), - 'Content-Type': stream.mime_type - } - return headers, size, skip_blobs - async def handle_stream_range_request(self, request: web.Request): sd_hash = request.path.split("/stream/")[1] if sd_hash not in self.stream_manager.streams: return web.HTTPNotFound() - stream = self.stream_manager.streams[sd_hash] - if stream.status == 'stopped': - await self.stream_manager.start_stream(stream) - if stream.delayed_stop: - stream.delayed_stop.cancel() - headers, size, skip_blobs = self.prepare_range_response_headers( - request.headers.get('range', 'bytes=0-'), stream - ) - response = web.StreamResponse( - status=206, - headers=headers - ) - await response.prepare(request) - wrote = 0 - async for blob_info, decrypted in stream.aiter_read_stream(skip_blobs): - log.info("streamed blob %i/%i", blob_info.blob_num + 1, len(stream.descriptor.blobs) - 1) - if (blob_info.blob_num == len(stream.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size): - decrypted += b'\x00' * (size - len(decrypted) - wrote) - await response.write_eof(decrypted) - break - else: - await response.write(decrypted) - wrote += len(decrypted) - response.force_close() - return response + return await self.stream_manager.stream_partial_content(request, sd_hash) async def _process_rpc_call(self, data): args = data.get('params', {}) @@ -924,7 +876,6 @@ class Daemon(metaclass=JSONRPCServerType): Returns: {File} """ - save_file = save_file if save_file is not None else self.conf.save_files try: stream = await self.stream_manager.download_stream_from_uri( uri, self.exchange_rate_manager, timeout, file_name, save_file=save_file @@ -1554,7 +1505,7 @@ class Daemon(metaclass=JSONRPCServerType): await self.stream_manager.start_stream(stream) msg = "Resumed download" elif status == 'stop' and stream.running: - await self.stream_manager.stop_stream(stream) + await stream.stop() msg = "Stopped download" else: msg = ( diff --git a/lbrynet/stream/managed_stream.py b/lbrynet/stream/managed_stream.py index a0fdf4250..6535c665c 100644 --- a/lbrynet/stream/managed_stream.py +++ b/lbrynet/stream/managed_stream.py @@ -3,7 +3,9 @@ import asyncio import typing import logging import binascii +from aiohttp.web import Request, StreamResponse from lbrynet.utils import generate_id +from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout from lbrynet.schema.mime_types import guess_media_type from lbrynet.stream.downloader import StreamDownloader from lbrynet.stream.descriptor import StreamDescriptor @@ -40,6 +42,33 @@ class ManagedStream: STATUS_STOPPED = "stopped" STATUS_FINISHED = "finished" + __slots__ = [ + 'loop', + 'config', + 'blob_manager', + 'sd_hash', + 'download_directory', + '_file_name', + '_status', + 'stream_claim_info', + 'download_id', + 'rowid', + 'written_bytes', + 'content_fee', + 'downloader', + 'analytics_manager', + 'fully_reflected', + 'file_output_task', + 'delayed_stop_task', + 'streaming_responses', + 'streaming', + '_running', + 'saving', + 'finished_writing', + 'started_writing', + + ] + def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager', sd_hash: str, download_directory: typing.Optional[str] = None, file_name: typing.Optional[str] = None, status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredStreamClaim] = None, @@ -61,9 +90,13 @@ class ManagedStream: self.content_fee = content_fee self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor) self.analytics_manager = analytics_manager + self.fully_reflected = asyncio.Event(loop=self.loop) self.file_output_task: typing.Optional[asyncio.Task] = None - self.delayed_stop: typing.Optional[asyncio.Handle] = None + self.delayed_stop_task: typing.Optional[asyncio.Task] = None + self.streaming_responses: typing.List[StreamResponse] = [] + self.streaming = asyncio.Event(loop=self.loop) + self._running = asyncio.Event(loop=self.loop) self.saving = asyncio.Event(loop=self.loop) self.finished_writing = asyncio.Event(loop=self.loop) self.started_writing = asyncio.Event(loop=self.loop) @@ -84,9 +117,10 @@ class ManagedStream: def status(self) -> str: return self._status - def update_status(self, status: str): + async def update_status(self, status: str): assert status in [self.STATUS_RUNNING, self.STATUS_STOPPED, self.STATUS_FINISHED] self._status = status + await self.blob_manager.storage.change_file_status(self.stream_hash, status) @property def finished(self) -> bool: @@ -216,47 +250,85 @@ class ManagedStream: return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path), os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor) - async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True, - file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None): - await self.downloader.start(node) - if not save_file and not file_name: - if not await self.blob_manager.storage.file_exists(self.sd_hash): - self.rowid = await self.blob_manager.storage.save_downloaded_file( - self.stream_hash, None, None, 0.0 - ) - self.download_directory = None - self._file_name = None - self.update_status(ManagedStream.STATUS_RUNNING) - await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING) - self.update_delayed_stop() - else: - await self.save_file(file_name, download_directory) - await self.started_writing.wait() + async def start(self, node: typing.Optional['Node'] = None, timeout: typing.Optional[float] = None, + save_now: bool = False): + timeout = timeout or self.config.download_timeout + if self._running.is_set(): + return + self._running.set() + start_time = self.loop.time() + try: + await asyncio.wait_for(self.downloader.start(node), timeout, loop=self.loop) + if save_now: + await asyncio.wait_for(self.save_file(node=node), timeout - (self.loop.time() - start_time), + loop=self.loop) + except asyncio.TimeoutError: + self._running.clear() + if not self.descriptor: + raise DownloadSDTimeout(self.sd_hash) + raise DownloadDataTimeout(self.sd_hash) - def update_delayed_stop(self): - def _delayed_stop(): - log.info("Stopping inactive download for stream %s", self.sd_hash) - self.stop_download() + if self.delayed_stop_task and not self.delayed_stop_task.done(): + self.delayed_stop_task.cancel() + self.delayed_stop_task = self.loop.create_task(self._delayed_stop()) + if not await self.blob_manager.storage.file_exists(self.sd_hash): + if save_now: + file_name, download_dir = self._file_name, self.download_directory + else: + file_name, download_dir = None, None + self.rowid = await self.blob_manager.storage.save_downloaded_file( + self.stream_hash, file_name, download_dir, 0.0 + ) + if self.status != self.STATUS_RUNNING: + await self.update_status(self.STATUS_RUNNING) - if self.delayed_stop: - self.delayed_stop.cancel() - self.delayed_stop = self.loop.call_later(60, _delayed_stop) + async def stop(self, finished: bool = False): + """ + Stop any running save/stream tasks as well as the downloader and update the status in the database + """ - async def aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0) -> typing.AsyncIterator[ - typing.Tuple['BlobInfo', bytes]]: + self.stop_tasks() + if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING: + await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED) + + async def _aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0)\ + -> typing.AsyncIterator[typing.Tuple['BlobInfo', bytes]]: if start_blob_num >= len(self.descriptor.blobs[:-1]): raise IndexError(start_blob_num) for i, blob_info in enumerate(self.descriptor.blobs[start_blob_num:-1]): assert i + start_blob_num == blob_info.blob_num - if self.delayed_stop: - self.delayed_stop.cancel() - try: - decrypted = await self.downloader.read_blob(blob_info) - yield (blob_info, decrypted) - except asyncio.CancelledError: - if not self.saving.is_set() and not self.finished_writing.is_set(): - self.update_delayed_stop() - raise + decrypted = await self.downloader.read_blob(blob_info) + yield (blob_info, decrypted) + + async def stream_file(self, request: Request, node: typing.Optional['Node'] = None) -> StreamResponse: + await self.start(node) + headers, size, skip_blobs = self._prepare_range_response_headers(request.headers.get('range', 'bytes=0-')) + response = StreamResponse( + status=206, + headers=headers + ) + await response.prepare(request) + self.streaming_responses.append(response) + self.streaming.set() + try: + wrote = 0 + async for blob_info, decrypted in self._aiter_read_stream(skip_blobs): + if (blob_info.blob_num == len(self.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size): + decrypted += b'\x00' * (size - len(decrypted) - wrote) + await response.write_eof(decrypted) + else: + await response.write(decrypted) + wrote += len(decrypted) + log.info("streamed %sblob %i/%i", "(closing stream) " if response._eof_sent else "", + blob_info.blob_num + 1, len(self.descriptor.blobs) - 1) + if response._eof_sent: + break + return response + finally: + response.force_close() + if response in self.streaming_responses: + self.streaming_responses.remove(response) + self.streaming.clear() async def _save_file(self, output_path: str): log.debug("save file %s -> %s", self.sd_hash, output_path) @@ -265,15 +337,14 @@ class ManagedStream: self.started_writing.clear() try: with open(output_path, 'wb') as file_write_handle: - async for blob_info, decrypted in self.aiter_read_stream(): + async for blob_info, decrypted in self._aiter_read_stream(): log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1) file_write_handle.write(decrypted) file_write_handle.flush() self.written_bytes += len(decrypted) if not self.started_writing.is_set(): self.started_writing.set() - self.update_status(ManagedStream.STATUS_FINISHED) - await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_FINISHED) + await self.update_status(ManagedStream.STATUS_FINISHED) if self.analytics_manager: self.loop.create_task(self.analytics_manager.send_download_finished( self.download_id, self.claim_name, self.sd_hash @@ -289,12 +360,11 @@ class ManagedStream: finally: self.saving.clear() - async def save_file(self, file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None): - if self.file_output_task and not self.file_output_task.done(): + async def save_file(self, file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None, + node: typing.Optional['Node'] = None): + await self.start(node) + if self.file_output_task and not self.file_output_task.done(): # cancel an already running save task self.file_output_task.cancel() - if self.delayed_stop: - self.delayed_stop.cancel() - self.delayed_stop = None self.download_directory = download_directory or self.download_directory or self.config.download_dir if not self.download_directory: raise ValueError("no directory to download to") @@ -303,28 +373,26 @@ class ManagedStream: if not os.path.isdir(self.download_directory): log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory) os.mkdir(self.download_directory) - if not await self.blob_manager.storage.file_exists(self.sd_hash): - self._file_name = await get_next_available_file_name( - self.loop, self.download_directory, - file_name or self._file_name or self.descriptor.suggested_file_name - ) - self.rowid = self.blob_manager.storage.save_downloaded_file( - self.stream_hash, self.file_name, self.download_directory, 0.0 - ) - else: - await self.blob_manager.storage.change_file_download_dir_and_file_name( - self.stream_hash, self.download_directory, self.file_name - ) - self.update_status(ManagedStream.STATUS_RUNNING) - await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING) + self._file_name = await get_next_available_file_name( + self.loop, self.download_directory, + file_name or self._file_name or self.descriptor.suggested_file_name + ) + await self.blob_manager.storage.change_file_download_dir_and_file_name( + self.stream_hash, self.download_directory, self.file_name + ) + await self.update_status(ManagedStream.STATUS_RUNNING) self.written_bytes = 0 self.file_output_task = self.loop.create_task(self._save_file(self.full_path)) + await self.started_writing.wait() - def stop_download(self): + def stop_tasks(self): if self.file_output_task and not self.file_output_task.done(): self.file_output_task.cancel() self.file_output_task = None + while self.streaming_responses: + self.streaming_responses.pop().force_close() self.downloader.stop() + self._running.clear() async def upload_to_reflector(self, host: str, port: int) -> typing.List[str]: sent = [] @@ -365,3 +433,43 @@ class ManagedStream: binascii.hexlify(claim.to_bytes()).decode(), claim.signing_channel_id, claim_info['address'], claim_info['claim_sequence'], claim_info.get('channel_name') ) + + async def update_content_claim(self, claim_info: typing.Optional[typing.Dict] = None): + if not claim_info: + claim_info = await self.blob_manager.storage.get_content_claim(self.stream_hash) + self.set_claim(claim_info, claim_info['value']) + + async def _delayed_stop(self): + stalled_count = 0 + while self._running.is_set(): + if self.saving.is_set() or self.streaming.is_set(): + stalled_count = 0 + else: + stalled_count += 1 + if stalled_count > 1: + log.info("Stopping inactive download for stream %s", self.sd_hash) + await self.stop() + return + await asyncio.sleep(1, loop=self.loop) + + def _prepare_range_response_headers(self, get_range: str) -> typing.Tuple[typing.Dict[str, str], int, int]: + if '=' in get_range: + get_range = get_range.split('=')[1] + start, end = get_range.split('-') + size = 0 + for blob in self.descriptor.blobs[:-1]: + size += blob.length - 1 + start = int(start) + end = int(end) if end else size - 1 + skip_blobs = start // 2097150 + skip = skip_blobs * 2097151 + start = skip + final_size = end - start + 1 + + headers = { + 'Accept-Ranges': 'bytes', + 'Content-Range': f'bytes {start}-{end}/{size}', + 'Content-Length': str(final_size), + 'Content-Type': self.mime_type + } + return headers, size, skip_blobs diff --git a/lbrynet/stream/stream_manager.py b/lbrynet/stream/stream_manager.py index 9a4c47581..2f5470c63 100644 --- a/lbrynet/stream/stream_manager.py +++ b/lbrynet/stream/stream_manager.py @@ -5,8 +5,9 @@ import binascii import logging import random from decimal import Decimal +from aiohttp.web import Request from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError -from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout, ResolveTimeout +from lbrynet.error import ResolveTimeout, DownloadDataTimeout from lbrynet.utils import cache_concurrent from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.stream.managed_stream import ManagedStream @@ -56,6 +57,7 @@ comparison_operators = { def path_or_none(p) -> typing.Optional[str]: return None if p == '{stream}' else binascii.unhexlify(p).decode() + class StreamManager: def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager', wallet: 'LbryWalletManager', storage: 'SQLiteStorage', node: typing.Optional['Node'], @@ -77,24 +79,6 @@ class StreamManager: claim_info = await self.storage.get_content_claim(stream.stream_hash) self.streams.setdefault(stream.sd_hash, stream).set_claim(claim_info, claim_info['value']) - async def stop_stream(self, stream: ManagedStream): - stream.stop_download() - if not stream.finished and stream.output_file_exists: - try: - os.remove(stream.full_path) - except OSError as err: - log.warning("Failed to delete partial download %s from downloads directory: %s", stream.full_path, - str(err)) - if stream.running: - stream.update_status(ManagedStream.STATUS_STOPPED) - await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED) - - async def start_stream(self, stream: ManagedStream): - stream.update_status(ManagedStream.STATUS_RUNNING) - await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_RUNNING) - await stream.setup(self.node, save_file=self.config.save_files) - self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream) - async def recover_streams(self, file_infos: typing.List[typing.Dict]): to_restore = [] @@ -150,6 +134,7 @@ class StreamManager: await self.recover_streams(to_recover) if not self.config.save_files: + # set files that have been deleted manually to streaming mode to_set_as_streaming = [] for file_info in to_start: file_name = path_or_none(file_info['file_name']) @@ -176,7 +161,7 @@ class StreamManager: if not self.node: log.warning("no DHT node given, resuming downloads trusting that we can contact reflector") t = [ - self.loop.create_task(self.start_stream(stream)) for stream in self.streams.values() + self.loop.create_task(stream.start(node=self.node)) for stream in self.streams.values() if stream.running ] if t: @@ -214,7 +199,7 @@ class StreamManager: self.re_reflect_task.cancel() while self.streams: _, stream = self.streams.popitem() - stream.stop_download() + stream.stop_tasks() while self.update_stream_finished_futs: self.update_stream_finished_futs.pop().cancel() while self.running_reflector_uploads: @@ -236,7 +221,7 @@ class StreamManager: return stream async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False): - await self.stop_stream(stream) + stream.stop_tasks() if stream.sd_hash in self.streams: del self.streams[stream.sd_hash] blob_hashes = [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]] @@ -290,21 +275,16 @@ class StreamManager: typing.Optional[ManagedStream], typing.Optional[ManagedStream]]: existing = self.get_filtered_streams(outpoint=outpoint) if existing: - if existing[0].status == ManagedStream.STATUS_STOPPED: - await self.start_stream(existing[0]) return existing[0], None existing = self.get_filtered_streams(sd_hash=claim.stream.source.sd_hash) if existing and existing[0].claim_id != claim_id: - raise ResolveError(f"stream for {existing[0].claim_id} collides with existing " - f"download {claim_id}") + raise ResolveError(f"stream for {existing[0].claim_id} collides with existing download {claim_id}") if existing: log.info("claim contains a metadata only update to a stream we have") await self.storage.save_content_claim( existing[0].stream_hash, outpoint ) await self._update_content_claim(existing[0]) - if not existing[0].running: - await self.start_stream(existing[0]) return existing[0], None else: existing_for_claim_id = self.get_filtered_streams(claim_id=claim_id) @@ -318,13 +298,23 @@ class StreamManager: timeout: typing.Optional[float] = None, file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None, - save_file: bool = True, resolve_timeout: float = 3.0) -> ManagedStream: + save_file: typing.Optional[bool] = None, + resolve_timeout: float = 3.0) -> ManagedStream: timeout = timeout or self.config.download_timeout start_time = self.loop.time() resolved_time = None stream = None error = None outpoint = None + if save_file is None: + save_file = self.config.save_files + if file_name and not save_file: + save_file = True + if save_file: + download_directory = download_directory or self.config.download_dir + else: + download_directory = None + try: # resolve the claim parsed_uri = parse_lbry_uri(uri) @@ -352,6 +342,9 @@ class StreamManager: updated_stream, to_replace = await self._check_update_or_replace(outpoint, resolved['claim_id'], claim) if updated_stream: log.info("already have stream for %s", uri) + if save_file and updated_stream.output_file_exists: + save_file = False + await updated_stream.start(node=self.node, timeout=timeout, save_now=save_file) return updated_stream content_fee = None @@ -381,30 +374,18 @@ class StreamManager: log.info("paid fee of %s for %s", fee_amount, uri) - download_directory = download_directory or self.config.download_dir - if not file_name and (not self.config.save_files or not save_file): - download_dir, file_name = None, None stream = ManagedStream( self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_directory, file_name, ManagedStream.STATUS_RUNNING, content_fee=content_fee, analytics_manager=self.analytics_manager ) log.info("starting download for %s", uri) - try: - await asyncio.wait_for(stream.setup( - self.node, save_file=save_file, file_name=file_name, download_directory=download_directory - ), timeout, loop=self.loop) - except asyncio.TimeoutError: - if not stream.descriptor: - raise DownloadSDTimeout(stream.sd_hash) - raise DownloadDataTimeout(stream.sd_hash) - finally: - if stream.descriptor: - if to_replace: # delete old stream now that the replacement has started downloading - await self.delete_stream(to_replace) - stream.set_claim(resolved, claim) - await self.storage.save_content_claim(stream.stream_hash, outpoint) - self.streams[stream.sd_hash] = stream + await stream.start(self.node, timeout, save_now=save_file) + if to_replace: # delete old stream now that the replacement has started downloading + await self.delete_stream(to_replace) + self.streams[stream.sd_hash] = stream + stream.set_claim(resolved, claim) + await self.storage.save_content_claim(stream.stream_hash, outpoint) return stream except DownloadDataTimeout as err: # forgive data timeout, dont delete stream error = err @@ -435,3 +416,6 @@ class StreamManager: ) if error: raise error + + async def stream_partial_content(self, request: Request, sd_hash: str): + return await self.streams[sd_hash].stream_file(request, self.node) diff --git a/tests/unit/stream/test_managed_stream.py b/tests/unit/stream/test_managed_stream.py index 372e5766d..0db4af1fe 100644 --- a/tests/unit/stream/test_managed_stream.py +++ b/tests/unit/stream/test_managed_stream.py @@ -40,7 +40,7 @@ class TestManagedStream(BlobExchangeTestBase): self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir ) - async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None): + async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True): await self.setup_stream(blob_count) mock_node = mock.Mock(spec=Node) @@ -51,10 +51,11 @@ class TestManagedStream(BlobExchangeTestBase): return q2, self.loop.create_task(_task()) mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers - await self.stream.setup(mock_node, save_file=True) + await self.stream.save_file(node=mock_node) await self.stream.finished_writing.wait() self.assertTrue(os.path.isfile(self.stream.full_path)) - self.stream.stop_download() + if stop_when_done: + await self.stream.stop() self.assertTrue(os.path.isfile(self.stream.full_path)) with open(self.stream.full_path, 'rb') as f: self.assertEqual(f.read(), self.stream_bytes) @@ -62,6 +63,18 @@ class TestManagedStream(BlobExchangeTestBase): async def test_transfer_stream(self): await self._test_transfer_stream(10) + self.assertEqual(self.stream.status, "finished") + self.assertFalse(self.stream._running.is_set()) + + async def test_delayed_stop(self): + await self._test_transfer_stream(10, stop_when_done=False) + self.assertEqual(self.stream.status, "finished") + self.assertTrue(self.stream._running.is_set()) + await asyncio.sleep(0.5, loop=self.loop) + self.assertTrue(self.stream._running.is_set()) + await asyncio.sleep(0.6, loop=self.loop) + self.assertEqual(self.stream.status, "finished") + self.assertFalse(self.stream._running.is_set()) @unittest.SkipTest async def test_transfer_hundred_blob_stream(self): @@ -85,11 +98,12 @@ class TestManagedStream(BlobExchangeTestBase): mock_node.accumulate_peers = _mock_accumulate_peers - await self.stream.setup(mock_node, save_file=True) + await self.stream.save_file(node=mock_node) await self.stream.finished_writing.wait() self.assertTrue(os.path.isfile(self.stream.full_path)) with open(self.stream.full_path, 'rb') as f: self.assertEqual(f.read(), self.stream_bytes) + await self.stream.stop() # self.assertIs(self.server_from_client.tcp_last_down, None) # self.assertIsNot(bad_peer.tcp_last_down, None) @@ -125,7 +139,7 @@ class TestManagedStream(BlobExchangeTestBase): with open(os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash), "rb+") as handle: handle.truncate() handle.flush() - await self.stream.setup() + await self.stream.save_file() await self.stream.finished_writing.wait() if corrupt: return self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file"))) diff --git a/tests/unit/stream/test_stream_manager.py b/tests/unit/stream/test_stream_manager.py index 761154b42..6ca453501 100644 --- a/tests/unit/stream/test_stream_manager.py +++ b/tests/unit/stream/test_stream_manager.py @@ -225,7 +225,7 @@ class TestStreamManager(BlobExchangeTestBase): ) self.assertEqual(stored_status, "running") - await self.stream_manager.stop_stream(stream) + await stream.stop() self.assertFalse(stream.finished) self.assertFalse(stream.running) @@ -235,7 +235,7 @@ class TestStreamManager(BlobExchangeTestBase): ) self.assertEqual(stored_status, "stopped") - await self.stream_manager.start_stream(stream) + await stream.save_file(node=self.stream_manager.node) await stream.finished_writing.wait() await asyncio.sleep(0, loop=self.loop) self.assertTrue(stream.finished)