fix/refactor starting and stopping files

-move partial content handling into ManagedStream
-add delayed stop test
This commit is contained in:
Jack Robison 2019-05-01 17:09:50 -04:00
parent b134e0c9c9
commit 9099ee2e8e
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
5 changed files with 222 additions and 165 deletions

View file

@ -18,7 +18,7 @@ from torba.client.baseaccount import SingleKey, HierarchicalDeterministic
from lbrynet import utils
from lbrynet.conf import Config, Setting
from lbrynet.blob.blob_file import is_valid_blobhash
from lbrynet.blob.blob_file import is_valid_blobhash, BlobBuffer
from lbrynet.blob_exchange.downloader import download_blob
from lbrynet.error import DownloadSDTimeout, ComponentsNotStarted
from lbrynet.error import NullFundsError, NegativeFundsError, ComponentStartConditionNotMet
@ -477,59 +477,11 @@ class Daemon(metaclass=JSONRPCServerType):
raise web.HTTPServerError(text=stream['error'])
raise web.HTTPFound(f"/stream/{stream.sd_hash}")
@staticmethod
def prepare_range_response_headers(get_range: str, stream: 'ManagedStream') -> typing.Tuple[typing.Dict[str, str],
int, int]:
if '=' in get_range:
get_range = get_range.split('=')[1]
start, end = get_range.split('-')
size = 0
for blob in stream.descriptor.blobs[:-1]:
size += blob.length - 1
start = int(start)
end = int(end) if end else size - 1
skip_blobs = start // 2097150
skip = skip_blobs * 2097151
start = skip
final_size = end - start + 1
headers = {
'Accept-Ranges': 'bytes',
'Content-Range': f'bytes {start}-{end}/{size}',
'Content-Length': str(final_size),
'Content-Type': stream.mime_type
}
return headers, size, skip_blobs
async def handle_stream_range_request(self, request: web.Request):
sd_hash = request.path.split("/stream/")[1]
if sd_hash not in self.stream_manager.streams:
return web.HTTPNotFound()
stream = self.stream_manager.streams[sd_hash]
if stream.status == 'stopped':
await self.stream_manager.start_stream(stream)
if stream.delayed_stop:
stream.delayed_stop.cancel()
headers, size, skip_blobs = self.prepare_range_response_headers(
request.headers.get('range', 'bytes=0-'), stream
)
response = web.StreamResponse(
status=206,
headers=headers
)
await response.prepare(request)
wrote = 0
async for blob_info, decrypted in stream.aiter_read_stream(skip_blobs):
log.info("streamed blob %i/%i", blob_info.blob_num + 1, len(stream.descriptor.blobs) - 1)
if (blob_info.blob_num == len(stream.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size):
decrypted += b'\x00' * (size - len(decrypted) - wrote)
await response.write_eof(decrypted)
break
else:
await response.write(decrypted)
wrote += len(decrypted)
response.force_close()
return response
return await self.stream_manager.stream_partial_content(request, sd_hash)
async def _process_rpc_call(self, data):
args = data.get('params', {})
@ -924,7 +876,6 @@ class Daemon(metaclass=JSONRPCServerType):
Returns: {File}
"""
save_file = save_file if save_file is not None else self.conf.save_files
try:
stream = await self.stream_manager.download_stream_from_uri(
uri, self.exchange_rate_manager, timeout, file_name, save_file=save_file
@ -1554,7 +1505,7 @@ class Daemon(metaclass=JSONRPCServerType):
await self.stream_manager.start_stream(stream)
msg = "Resumed download"
elif status == 'stop' and stream.running:
await self.stream_manager.stop_stream(stream)
await stream.stop()
msg = "Stopped download"
else:
msg = (

View file

@ -3,7 +3,9 @@ import asyncio
import typing
import logging
import binascii
from aiohttp.web import Request, StreamResponse
from lbrynet.utils import generate_id
from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout
from lbrynet.schema.mime_types import guess_media_type
from lbrynet.stream.downloader import StreamDownloader
from lbrynet.stream.descriptor import StreamDescriptor
@ -40,6 +42,33 @@ class ManagedStream:
STATUS_STOPPED = "stopped"
STATUS_FINISHED = "finished"
__slots__ = [
'loop',
'config',
'blob_manager',
'sd_hash',
'download_directory',
'_file_name',
'_status',
'stream_claim_info',
'download_id',
'rowid',
'written_bytes',
'content_fee',
'downloader',
'analytics_manager',
'fully_reflected',
'file_output_task',
'delayed_stop_task',
'streaming_responses',
'streaming',
'_running',
'saving',
'finished_writing',
'started_writing',
]
def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager',
sd_hash: str, download_directory: typing.Optional[str] = None, file_name: typing.Optional[str] = None,
status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredStreamClaim] = None,
@ -61,9 +90,13 @@ class ManagedStream:
self.content_fee = content_fee
self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor)
self.analytics_manager = analytics_manager
self.fully_reflected = asyncio.Event(loop=self.loop)
self.file_output_task: typing.Optional[asyncio.Task] = None
self.delayed_stop: typing.Optional[asyncio.Handle] = None
self.delayed_stop_task: typing.Optional[asyncio.Task] = None
self.streaming_responses: typing.List[StreamResponse] = []
self.streaming = asyncio.Event(loop=self.loop)
self._running = asyncio.Event(loop=self.loop)
self.saving = asyncio.Event(loop=self.loop)
self.finished_writing = asyncio.Event(loop=self.loop)
self.started_writing = asyncio.Event(loop=self.loop)
@ -84,9 +117,10 @@ class ManagedStream:
def status(self) -> str:
return self._status
def update_status(self, status: str):
async def update_status(self, status: str):
assert status in [self.STATUS_RUNNING, self.STATUS_STOPPED, self.STATUS_FINISHED]
self._status = status
await self.blob_manager.storage.change_file_status(self.stream_hash, status)
@property
def finished(self) -> bool:
@ -216,47 +250,85 @@ class ManagedStream:
return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path),
os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor)
async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True,
file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None):
await self.downloader.start(node)
if not save_file and not file_name:
if not await self.blob_manager.storage.file_exists(self.sd_hash):
self.rowid = await self.blob_manager.storage.save_downloaded_file(
self.stream_hash, None, None, 0.0
)
self.download_directory = None
self._file_name = None
self.update_status(ManagedStream.STATUS_RUNNING)
await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING)
self.update_delayed_stop()
else:
await self.save_file(file_name, download_directory)
await self.started_writing.wait()
async def start(self, node: typing.Optional['Node'] = None, timeout: typing.Optional[float] = None,
save_now: bool = False):
timeout = timeout or self.config.download_timeout
if self._running.is_set():
return
self._running.set()
start_time = self.loop.time()
try:
await asyncio.wait_for(self.downloader.start(node), timeout, loop=self.loop)
if save_now:
await asyncio.wait_for(self.save_file(node=node), timeout - (self.loop.time() - start_time),
loop=self.loop)
except asyncio.TimeoutError:
self._running.clear()
if not self.descriptor:
raise DownloadSDTimeout(self.sd_hash)
raise DownloadDataTimeout(self.sd_hash)
def update_delayed_stop(self):
def _delayed_stop():
log.info("Stopping inactive download for stream %s", self.sd_hash)
self.stop_download()
if self.delayed_stop_task and not self.delayed_stop_task.done():
self.delayed_stop_task.cancel()
self.delayed_stop_task = self.loop.create_task(self._delayed_stop())
if not await self.blob_manager.storage.file_exists(self.sd_hash):
if save_now:
file_name, download_dir = self._file_name, self.download_directory
else:
file_name, download_dir = None, None
self.rowid = await self.blob_manager.storage.save_downloaded_file(
self.stream_hash, file_name, download_dir, 0.0
)
if self.status != self.STATUS_RUNNING:
await self.update_status(self.STATUS_RUNNING)
if self.delayed_stop:
self.delayed_stop.cancel()
self.delayed_stop = self.loop.call_later(60, _delayed_stop)
async def stop(self, finished: bool = False):
"""
Stop any running save/stream tasks as well as the downloader and update the status in the database
"""
async def aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0) -> typing.AsyncIterator[
typing.Tuple['BlobInfo', bytes]]:
self.stop_tasks()
if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING:
await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED)
async def _aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0)\
-> typing.AsyncIterator[typing.Tuple['BlobInfo', bytes]]:
if start_blob_num >= len(self.descriptor.blobs[:-1]):
raise IndexError(start_blob_num)
for i, blob_info in enumerate(self.descriptor.blobs[start_blob_num:-1]):
assert i + start_blob_num == blob_info.blob_num
if self.delayed_stop:
self.delayed_stop.cancel()
try:
decrypted = await self.downloader.read_blob(blob_info)
yield (blob_info, decrypted)
except asyncio.CancelledError:
if not self.saving.is_set() and not self.finished_writing.is_set():
self.update_delayed_stop()
raise
decrypted = await self.downloader.read_blob(blob_info)
yield (blob_info, decrypted)
async def stream_file(self, request: Request, node: typing.Optional['Node'] = None) -> StreamResponse:
await self.start(node)
headers, size, skip_blobs = self._prepare_range_response_headers(request.headers.get('range', 'bytes=0-'))
response = StreamResponse(
status=206,
headers=headers
)
await response.prepare(request)
self.streaming_responses.append(response)
self.streaming.set()
try:
wrote = 0
async for blob_info, decrypted in self._aiter_read_stream(skip_blobs):
if (blob_info.blob_num == len(self.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size):
decrypted += b'\x00' * (size - len(decrypted) - wrote)
await response.write_eof(decrypted)
else:
await response.write(decrypted)
wrote += len(decrypted)
log.info("streamed %sblob %i/%i", "(closing stream) " if response._eof_sent else "",
blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
if response._eof_sent:
break
return response
finally:
response.force_close()
if response in self.streaming_responses:
self.streaming_responses.remove(response)
self.streaming.clear()
async def _save_file(self, output_path: str):
log.debug("save file %s -> %s", self.sd_hash, output_path)
@ -265,15 +337,14 @@ class ManagedStream:
self.started_writing.clear()
try:
with open(output_path, 'wb') as file_write_handle:
async for blob_info, decrypted in self.aiter_read_stream():
async for blob_info, decrypted in self._aiter_read_stream():
log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
file_write_handle.write(decrypted)
file_write_handle.flush()
self.written_bytes += len(decrypted)
if not self.started_writing.is_set():
self.started_writing.set()
self.update_status(ManagedStream.STATUS_FINISHED)
await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_FINISHED)
await self.update_status(ManagedStream.STATUS_FINISHED)
if self.analytics_manager:
self.loop.create_task(self.analytics_manager.send_download_finished(
self.download_id, self.claim_name, self.sd_hash
@ -289,12 +360,11 @@ class ManagedStream:
finally:
self.saving.clear()
async def save_file(self, file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None):
if self.file_output_task and not self.file_output_task.done():
async def save_file(self, file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None,
node: typing.Optional['Node'] = None):
await self.start(node)
if self.file_output_task and not self.file_output_task.done(): # cancel an already running save task
self.file_output_task.cancel()
if self.delayed_stop:
self.delayed_stop.cancel()
self.delayed_stop = None
self.download_directory = download_directory or self.download_directory or self.config.download_dir
if not self.download_directory:
raise ValueError("no directory to download to")
@ -303,28 +373,26 @@ class ManagedStream:
if not os.path.isdir(self.download_directory):
log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory)
os.mkdir(self.download_directory)
if not await self.blob_manager.storage.file_exists(self.sd_hash):
self._file_name = await get_next_available_file_name(
self.loop, self.download_directory,
file_name or self._file_name or self.descriptor.suggested_file_name
)
self.rowid = self.blob_manager.storage.save_downloaded_file(
self.stream_hash, self.file_name, self.download_directory, 0.0
)
else:
await self.blob_manager.storage.change_file_download_dir_and_file_name(
self.stream_hash, self.download_directory, self.file_name
)
self.update_status(ManagedStream.STATUS_RUNNING)
await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING)
self._file_name = await get_next_available_file_name(
self.loop, self.download_directory,
file_name or self._file_name or self.descriptor.suggested_file_name
)
await self.blob_manager.storage.change_file_download_dir_and_file_name(
self.stream_hash, self.download_directory, self.file_name
)
await self.update_status(ManagedStream.STATUS_RUNNING)
self.written_bytes = 0
self.file_output_task = self.loop.create_task(self._save_file(self.full_path))
await self.started_writing.wait()
def stop_download(self):
def stop_tasks(self):
if self.file_output_task and not self.file_output_task.done():
self.file_output_task.cancel()
self.file_output_task = None
while self.streaming_responses:
self.streaming_responses.pop().force_close()
self.downloader.stop()
self._running.clear()
async def upload_to_reflector(self, host: str, port: int) -> typing.List[str]:
sent = []
@ -365,3 +433,43 @@ class ManagedStream:
binascii.hexlify(claim.to_bytes()).decode(), claim.signing_channel_id, claim_info['address'],
claim_info['claim_sequence'], claim_info.get('channel_name')
)
async def update_content_claim(self, claim_info: typing.Optional[typing.Dict] = None):
if not claim_info:
claim_info = await self.blob_manager.storage.get_content_claim(self.stream_hash)
self.set_claim(claim_info, claim_info['value'])
async def _delayed_stop(self):
stalled_count = 0
while self._running.is_set():
if self.saving.is_set() or self.streaming.is_set():
stalled_count = 0
else:
stalled_count += 1
if stalled_count > 1:
log.info("Stopping inactive download for stream %s", self.sd_hash)
await self.stop()
return
await asyncio.sleep(1, loop=self.loop)
def _prepare_range_response_headers(self, get_range: str) -> typing.Tuple[typing.Dict[str, str], int, int]:
if '=' in get_range:
get_range = get_range.split('=')[1]
start, end = get_range.split('-')
size = 0
for blob in self.descriptor.blobs[:-1]:
size += blob.length - 1
start = int(start)
end = int(end) if end else size - 1
skip_blobs = start // 2097150
skip = skip_blobs * 2097151
start = skip
final_size = end - start + 1
headers = {
'Accept-Ranges': 'bytes',
'Content-Range': f'bytes {start}-{end}/{size}',
'Content-Length': str(final_size),
'Content-Type': self.mime_type
}
return headers, size, skip_blobs

View file

@ -5,8 +5,9 @@ import binascii
import logging
import random
from decimal import Decimal
from aiohttp.web import Request
from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError
from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout, ResolveTimeout
from lbrynet.error import ResolveTimeout, DownloadDataTimeout
from lbrynet.utils import cache_concurrent
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.managed_stream import ManagedStream
@ -56,6 +57,7 @@ comparison_operators = {
def path_or_none(p) -> typing.Optional[str]:
return None if p == '{stream}' else binascii.unhexlify(p).decode()
class StreamManager:
def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager',
wallet: 'LbryWalletManager', storage: 'SQLiteStorage', node: typing.Optional['Node'],
@ -77,24 +79,6 @@ class StreamManager:
claim_info = await self.storage.get_content_claim(stream.stream_hash)
self.streams.setdefault(stream.sd_hash, stream).set_claim(claim_info, claim_info['value'])
async def stop_stream(self, stream: ManagedStream):
stream.stop_download()
if not stream.finished and stream.output_file_exists:
try:
os.remove(stream.full_path)
except OSError as err:
log.warning("Failed to delete partial download %s from downloads directory: %s", stream.full_path,
str(err))
if stream.running:
stream.update_status(ManagedStream.STATUS_STOPPED)
await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED)
async def start_stream(self, stream: ManagedStream):
stream.update_status(ManagedStream.STATUS_RUNNING)
await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_RUNNING)
await stream.setup(self.node, save_file=self.config.save_files)
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
async def recover_streams(self, file_infos: typing.List[typing.Dict]):
to_restore = []
@ -150,6 +134,7 @@ class StreamManager:
await self.recover_streams(to_recover)
if not self.config.save_files:
# set files that have been deleted manually to streaming mode
to_set_as_streaming = []
for file_info in to_start:
file_name = path_or_none(file_info['file_name'])
@ -176,7 +161,7 @@ class StreamManager:
if not self.node:
log.warning("no DHT node given, resuming downloads trusting that we can contact reflector")
t = [
self.loop.create_task(self.start_stream(stream)) for stream in self.streams.values()
self.loop.create_task(stream.start(node=self.node)) for stream in self.streams.values()
if stream.running
]
if t:
@ -214,7 +199,7 @@ class StreamManager:
self.re_reflect_task.cancel()
while self.streams:
_, stream = self.streams.popitem()
stream.stop_download()
stream.stop_tasks()
while self.update_stream_finished_futs:
self.update_stream_finished_futs.pop().cancel()
while self.running_reflector_uploads:
@ -236,7 +221,7 @@ class StreamManager:
return stream
async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False):
await self.stop_stream(stream)
stream.stop_tasks()
if stream.sd_hash in self.streams:
del self.streams[stream.sd_hash]
blob_hashes = [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]]
@ -290,21 +275,16 @@ class StreamManager:
typing.Optional[ManagedStream], typing.Optional[ManagedStream]]:
existing = self.get_filtered_streams(outpoint=outpoint)
if existing:
if existing[0].status == ManagedStream.STATUS_STOPPED:
await self.start_stream(existing[0])
return existing[0], None
existing = self.get_filtered_streams(sd_hash=claim.stream.source.sd_hash)
if existing and existing[0].claim_id != claim_id:
raise ResolveError(f"stream for {existing[0].claim_id} collides with existing "
f"download {claim_id}")
raise ResolveError(f"stream for {existing[0].claim_id} collides with existing download {claim_id}")
if existing:
log.info("claim contains a metadata only update to a stream we have")
await self.storage.save_content_claim(
existing[0].stream_hash, outpoint
)
await self._update_content_claim(existing[0])
if not existing[0].running:
await self.start_stream(existing[0])
return existing[0], None
else:
existing_for_claim_id = self.get_filtered_streams(claim_id=claim_id)
@ -318,13 +298,23 @@ class StreamManager:
timeout: typing.Optional[float] = None,
file_name: typing.Optional[str] = None,
download_directory: typing.Optional[str] = None,
save_file: bool = True, resolve_timeout: float = 3.0) -> ManagedStream:
save_file: typing.Optional[bool] = None,
resolve_timeout: float = 3.0) -> ManagedStream:
timeout = timeout or self.config.download_timeout
start_time = self.loop.time()
resolved_time = None
stream = None
error = None
outpoint = None
if save_file is None:
save_file = self.config.save_files
if file_name and not save_file:
save_file = True
if save_file:
download_directory = download_directory or self.config.download_dir
else:
download_directory = None
try:
# resolve the claim
parsed_uri = parse_lbry_uri(uri)
@ -352,6 +342,9 @@ class StreamManager:
updated_stream, to_replace = await self._check_update_or_replace(outpoint, resolved['claim_id'], claim)
if updated_stream:
log.info("already have stream for %s", uri)
if save_file and updated_stream.output_file_exists:
save_file = False
await updated_stream.start(node=self.node, timeout=timeout, save_now=save_file)
return updated_stream
content_fee = None
@ -381,30 +374,18 @@ class StreamManager:
log.info("paid fee of %s for %s", fee_amount, uri)
download_directory = download_directory or self.config.download_dir
if not file_name and (not self.config.save_files or not save_file):
download_dir, file_name = None, None
stream = ManagedStream(
self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_directory,
file_name, ManagedStream.STATUS_RUNNING, content_fee=content_fee,
analytics_manager=self.analytics_manager
)
log.info("starting download for %s", uri)
try:
await asyncio.wait_for(stream.setup(
self.node, save_file=save_file, file_name=file_name, download_directory=download_directory
), timeout, loop=self.loop)
except asyncio.TimeoutError:
if not stream.descriptor:
raise DownloadSDTimeout(stream.sd_hash)
raise DownloadDataTimeout(stream.sd_hash)
finally:
if stream.descriptor:
if to_replace: # delete old stream now that the replacement has started downloading
await self.delete_stream(to_replace)
stream.set_claim(resolved, claim)
await self.storage.save_content_claim(stream.stream_hash, outpoint)
self.streams[stream.sd_hash] = stream
await stream.start(self.node, timeout, save_now=save_file)
if to_replace: # delete old stream now that the replacement has started downloading
await self.delete_stream(to_replace)
self.streams[stream.sd_hash] = stream
stream.set_claim(resolved, claim)
await self.storage.save_content_claim(stream.stream_hash, outpoint)
return stream
except DownloadDataTimeout as err: # forgive data timeout, dont delete stream
error = err
@ -435,3 +416,6 @@ class StreamManager:
)
if error:
raise error
async def stream_partial_content(self, request: Request, sd_hash: str):
return await self.streams[sd_hash].stream_file(request, self.node)

View file

@ -40,7 +40,7 @@ class TestManagedStream(BlobExchangeTestBase):
self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir
)
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True):
await self.setup_stream(blob_count)
mock_node = mock.Mock(spec=Node)
@ -51,10 +51,11 @@ class TestManagedStream(BlobExchangeTestBase):
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
await self.stream.setup(mock_node, save_file=True)
await self.stream.save_file(node=mock_node)
await self.stream.finished_writing.wait()
self.assertTrue(os.path.isfile(self.stream.full_path))
self.stream.stop_download()
if stop_when_done:
await self.stream.stop()
self.assertTrue(os.path.isfile(self.stream.full_path))
with open(self.stream.full_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
@ -62,6 +63,18 @@ class TestManagedStream(BlobExchangeTestBase):
async def test_transfer_stream(self):
await self._test_transfer_stream(10)
self.assertEqual(self.stream.status, "finished")
self.assertFalse(self.stream._running.is_set())
async def test_delayed_stop(self):
await self._test_transfer_stream(10, stop_when_done=False)
self.assertEqual(self.stream.status, "finished")
self.assertTrue(self.stream._running.is_set())
await asyncio.sleep(0.5, loop=self.loop)
self.assertTrue(self.stream._running.is_set())
await asyncio.sleep(0.6, loop=self.loop)
self.assertEqual(self.stream.status, "finished")
self.assertFalse(self.stream._running.is_set())
@unittest.SkipTest
async def test_transfer_hundred_blob_stream(self):
@ -85,11 +98,12 @@ class TestManagedStream(BlobExchangeTestBase):
mock_node.accumulate_peers = _mock_accumulate_peers
await self.stream.setup(mock_node, save_file=True)
await self.stream.save_file(node=mock_node)
await self.stream.finished_writing.wait()
self.assertTrue(os.path.isfile(self.stream.full_path))
with open(self.stream.full_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
await self.stream.stop()
# self.assertIs(self.server_from_client.tcp_last_down, None)
# self.assertIsNot(bad_peer.tcp_last_down, None)
@ -125,7 +139,7 @@ class TestManagedStream(BlobExchangeTestBase):
with open(os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash), "rb+") as handle:
handle.truncate()
handle.flush()
await self.stream.setup()
await self.stream.save_file()
await self.stream.finished_writing.wait()
if corrupt:
return self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))

View file

@ -225,7 +225,7 @@ class TestStreamManager(BlobExchangeTestBase):
)
self.assertEqual(stored_status, "running")
await self.stream_manager.stop_stream(stream)
await stream.stop()
self.assertFalse(stream.finished)
self.assertFalse(stream.running)
@ -235,7 +235,7 @@ class TestStreamManager(BlobExchangeTestBase):
)
self.assertEqual(stored_status, "stopped")
await self.stream_manager.start_stream(stream)
await stream.save_file(node=self.stream_manager.node)
await stream.finished_writing.wait()
await asyncio.sleep(0, loop=self.loop)
self.assertTrue(stream.finished)