fix/refactor starting and stopping files
- move partial content handling into ManagedStream
- add delayed stop test
parent b134e0c9c9
commit 9099ee2e8e
5 changed files with 222 additions and 165 deletions
@@ -18,7 +18,7 @@ from torba.client.baseaccount import SingleKey, HierarchicalDeterministic
 from lbrynet import utils
 from lbrynet.conf import Config, Setting
-from lbrynet.blob.blob_file import is_valid_blobhash
+from lbrynet.blob.blob_file import is_valid_blobhash, BlobBuffer
 from lbrynet.blob_exchange.downloader import download_blob
-from lbrynet.error import DownloadSDTimeout, ComponentsNotStarted
+from lbrynet.error import NullFundsError, NegativeFundsError, ComponentStartConditionNotMet
@@ -477,59 +477,11 @@ class Daemon(metaclass=JSONRPCServerType):
             raise web.HTTPServerError(text=stream['error'])
         raise web.HTTPFound(f"/stream/{stream.sd_hash}")
 
-    @staticmethod
-    def prepare_range_response_headers(get_range: str, stream: 'ManagedStream') -> typing.Tuple[typing.Dict[str, str],
-                                                                                                int, int]:
-        if '=' in get_range:
-            get_range = get_range.split('=')[1]
-        start, end = get_range.split('-')
-        size = 0
-        for blob in stream.descriptor.blobs[:-1]:
-            size += blob.length - 1
-        start = int(start)
-        end = int(end) if end else size - 1
-        skip_blobs = start // 2097150
-        skip = skip_blobs * 2097151
-        start = skip
-        final_size = end - start + 1
-
-        headers = {
-            'Accept-Ranges': 'bytes',
-            'Content-Range': f'bytes {start}-{end}/{size}',
-            'Content-Length': str(final_size),
-            'Content-Type': stream.mime_type
-        }
-        return headers, size, skip_blobs
-
     async def handle_stream_range_request(self, request: web.Request):
         sd_hash = request.path.split("/stream/")[1]
         if sd_hash not in self.stream_manager.streams:
             return web.HTTPNotFound()
-        stream = self.stream_manager.streams[sd_hash]
-        if stream.status == 'stopped':
-            await self.stream_manager.start_stream(stream)
-        if stream.delayed_stop:
-            stream.delayed_stop.cancel()
-        headers, size, skip_blobs = self.prepare_range_response_headers(
-            request.headers.get('range', 'bytes=0-'), stream
-        )
-        response = web.StreamResponse(
-            status=206,
-            headers=headers
-        )
-        await response.prepare(request)
-        wrote = 0
-        async for blob_info, decrypted in stream.aiter_read_stream(skip_blobs):
-            log.info("streamed blob %i/%i", blob_info.blob_num + 1, len(stream.descriptor.blobs) - 1)
-            if (blob_info.blob_num == len(stream.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size):
-                decrypted += b'\x00' * (size - len(decrypted) - wrote)
-                await response.write_eof(decrypted)
-                break
-            else:
-                await response.write(decrypted)
-                wrote += len(decrypted)
-        response.force_close()
-        return response
+        return await self.stream_manager.stream_partial_content(request, sd_hash)
 
     async def _process_rpc_call(self, data):
         args = data.get('params', {})
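Note (not part of the diff): after this hunk the daemon's range handler only validates the sd_hash and delegates to StreamManager.stream_partial_content, so the open StreamResponse, the streaming event, and the delayed-stop watchdog all live next to the stream they belong to. A minimal client-side sketch of how the endpoint behaves; the host and port are hypothetical, adjust to wherever the daemon's web server listens:

import asyncio
import aiohttp

async def read_first_chunk(sd_hash: str) -> bytes:
    url = f"http://localhost:5279/stream/{sd_hash}"  # hypothetical local address
    async with aiohttp.ClientSession() as session:
        # 'bytes=0-' is also the default the server assumes when no Range header is sent
        async with session.get(url, headers={"Range": "bytes=0-"}) as response:
            assert response.status == 206  # partial content, as set in stream_file()
            return await response.content.read(64 * 1024)

# asyncio.get_event_loop().run_until_complete(read_first_chunk("<sd hash>"))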
@@ -924,7 +876,6 @@ class Daemon(metaclass=JSONRPCServerType):
 
         Returns: {File}
         """
-        save_file = save_file if save_file is not None else self.conf.save_files
         try:
             stream = await self.stream_manager.download_stream_from_uri(
                 uri, self.exchange_rate_manager, timeout, file_name, save_file=save_file
@@ -1554,7 +1505,7 @@ class Daemon(metaclass=JSONRPCServerType):
             await self.stream_manager.start_stream(stream)
             msg = "Resumed download"
         elif status == 'stop' and stream.running:
-            await self.stream_manager.stop_stream(stream)
+            await stream.stop()
             msg = "Stopped download"
         else:
             msg = (
@@ -3,7 +3,9 @@ import asyncio
 import typing
 import logging
 import binascii
+from aiohttp.web import Request, StreamResponse
 from lbrynet.utils import generate_id
+from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout
 from lbrynet.schema.mime_types import guess_media_type
 from lbrynet.stream.downloader import StreamDownloader
 from lbrynet.stream.descriptor import StreamDescriptor
@@ -40,6 +42,33 @@ class ManagedStream:
     STATUS_STOPPED = "stopped"
     STATUS_FINISHED = "finished"
 
+    __slots__ = [
+        'loop',
+        'config',
+        'blob_manager',
+        'sd_hash',
+        'download_directory',
+        '_file_name',
+        '_status',
+        'stream_claim_info',
+        'download_id',
+        'rowid',
+        'written_bytes',
+        'content_fee',
+        'downloader',
+        'analytics_manager',
+        'fully_reflected',
+        'file_output_task',
+        'delayed_stop_task',
+        'streaming_responses',
+        'streaming',
+        '_running',
+        'saving',
+        'finished_writing',
+        'started_writing',
+    ]
+
     def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager',
                  sd_hash: str, download_directory: typing.Optional[str] = None, file_name: typing.Optional[str] = None,
                  status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredStreamClaim] = None,
@@ -61,9 +90,13 @@ class ManagedStream:
         self.content_fee = content_fee
         self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor)
         self.analytics_manager = analytics_manager
+
         self.fully_reflected = asyncio.Event(loop=self.loop)
         self.file_output_task: typing.Optional[asyncio.Task] = None
-        self.delayed_stop: typing.Optional[asyncio.Handle] = None
+        self.delayed_stop_task: typing.Optional[asyncio.Task] = None
+        self.streaming_responses: typing.List[StreamResponse] = []
+        self.streaming = asyncio.Event(loop=self.loop)
+        self._running = asyncio.Event(loop=self.loop)
         self.saving = asyncio.Event(loop=self.loop)
         self.finished_writing = asyncio.Event(loop=self.loop)
         self.started_writing = asyncio.Event(loop=self.loop)
@@ -84,9 +117,10 @@ class ManagedStream:
     def status(self) -> str:
         return self._status
 
-    def update_status(self, status: str):
+    async def update_status(self, status: str):
         assert status in [self.STATUS_RUNNING, self.STATUS_STOPPED, self.STATUS_FINISHED]
         self._status = status
+        await self.blob_manager.storage.change_file_status(self.stream_hash, status)
 
     @property
     def finished(self) -> bool:
@@ -216,47 +250,85 @@ class ManagedStream:
         return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path),
                    os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor)
 
-    async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True,
-                    file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None):
-        await self.downloader.start(node)
-        if not save_file and not file_name:
-            if not await self.blob_manager.storage.file_exists(self.sd_hash):
-                self.rowid = await self.blob_manager.storage.save_downloaded_file(
-                    self.stream_hash, None, None, 0.0
-                )
-            self.download_directory = None
-            self._file_name = None
-            self.update_status(ManagedStream.STATUS_RUNNING)
-            await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING)
-            self.update_delayed_stop()
-        else:
-            await self.save_file(file_name, download_directory)
-            await self.started_writing.wait()
+    async def start(self, node: typing.Optional['Node'] = None, timeout: typing.Optional[float] = None,
+                    save_now: bool = False):
+        timeout = timeout or self.config.download_timeout
+        if self._running.is_set():
+            return
+        self._running.set()
+        start_time = self.loop.time()
+        try:
+            await asyncio.wait_for(self.downloader.start(node), timeout, loop=self.loop)
+            if save_now:
+                await asyncio.wait_for(self.save_file(node=node), timeout - (self.loop.time() - start_time),
+                                       loop=self.loop)
+        except asyncio.TimeoutError:
+            self._running.clear()
+            if not self.descriptor:
+                raise DownloadSDTimeout(self.sd_hash)
+            raise DownloadDataTimeout(self.sd_hash)
 
-    def update_delayed_stop(self):
-        def _delayed_stop():
-            log.info("Stopping inactive download for stream %s", self.sd_hash)
-            self.stop_download()
+        if self.delayed_stop_task and not self.delayed_stop_task.done():
+            self.delayed_stop_task.cancel()
+        self.delayed_stop_task = self.loop.create_task(self._delayed_stop())
+        if not await self.blob_manager.storage.file_exists(self.sd_hash):
+            if save_now:
+                file_name, download_dir = self._file_name, self.download_directory
+            else:
+                file_name, download_dir = None, None
+            self.rowid = await self.blob_manager.storage.save_downloaded_file(
+                self.stream_hash, file_name, download_dir, 0.0
+            )
+        if self.status != self.STATUS_RUNNING:
+            await self.update_status(self.STATUS_RUNNING)
 
-        if self.delayed_stop:
-            self.delayed_stop.cancel()
-        self.delayed_stop = self.loop.call_later(60, _delayed_stop)
+    async def stop(self, finished: bool = False):
+        """
+        Stop any running save/stream tasks as well as the downloader and update the status in the database
+        """
 
-    async def aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0) -> typing.AsyncIterator[
-            typing.Tuple['BlobInfo', bytes]]:
+        self.stop_tasks()
+        if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING:
+            await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED)
+
+    async def _aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0)\
+            -> typing.AsyncIterator[typing.Tuple['BlobInfo', bytes]]:
         if start_blob_num >= len(self.descriptor.blobs[:-1]):
             raise IndexError(start_blob_num)
         for i, blob_info in enumerate(self.descriptor.blobs[start_blob_num:-1]):
             assert i + start_blob_num == blob_info.blob_num
-            if self.delayed_stop:
-                self.delayed_stop.cancel()
-            try:
-                decrypted = await self.downloader.read_blob(blob_info)
-                yield (blob_info, decrypted)
-            except asyncio.CancelledError:
-                if not self.saving.is_set() and not self.finished_writing.is_set():
-                    self.update_delayed_stop()
-                raise
+            decrypted = await self.downloader.read_blob(blob_info)
+            yield (blob_info, decrypted)
+
+    async def stream_file(self, request: Request, node: typing.Optional['Node'] = None) -> StreamResponse:
+        await self.start(node)
+        headers, size, skip_blobs = self._prepare_range_response_headers(request.headers.get('range', 'bytes=0-'))
+        response = StreamResponse(
+            status=206,
+            headers=headers
+        )
+        await response.prepare(request)
+        self.streaming_responses.append(response)
+        self.streaming.set()
+        try:
+            wrote = 0
+            async for blob_info, decrypted in self._aiter_read_stream(skip_blobs):
+                if (blob_info.blob_num == len(self.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size):
+                    decrypted += b'\x00' * (size - len(decrypted) - wrote)
+                    await response.write_eof(decrypted)
+                else:
+                    await response.write(decrypted)
+                wrote += len(decrypted)
+                log.info("streamed %sblob %i/%i", "(closing stream) " if response._eof_sent else "",
+                         blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
+                if response._eof_sent:
+                    break
+            return response
+        finally:
+            response.force_close()
+            if response in self.streaming_responses:
+                self.streaming_responses.remove(response)
+            self.streaming.clear()
 
     async def _save_file(self, output_path: str):
         log.debug("save file %s -> %s", self.sd_hash, output_path)
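Note (not part of the diff): the new start() shares a single timeout budget across two sequential steps; whatever time downloader.start consumes is subtracted before save_file is awaited. A stripped-down sketch of that pattern, with illustrative names:

import asyncio

async def run_with_shared_budget(first, second, timeout: float):
    # `first` and `second` are argument-less coroutine functions
    loop = asyncio.get_event_loop()
    started = loop.time()
    await asyncio.wait_for(first(), timeout)
    # the remaining budget shrinks by however long `first` took
    remaining = timeout - (loop.time() - started)
    await asyncio.wait_for(second(), remaining)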
@@ -265,15 +337,14 @@ class ManagedStream:
         self.started_writing.clear()
         try:
             with open(output_path, 'wb') as file_write_handle:
-                async for blob_info, decrypted in self.aiter_read_stream():
+                async for blob_info, decrypted in self._aiter_read_stream():
                     log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
                     file_write_handle.write(decrypted)
                     file_write_handle.flush()
                     self.written_bytes += len(decrypted)
                     if not self.started_writing.is_set():
                         self.started_writing.set()
-            self.update_status(ManagedStream.STATUS_FINISHED)
-            await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_FINISHED)
+            await self.update_status(ManagedStream.STATUS_FINISHED)
             if self.analytics_manager:
                 self.loop.create_task(self.analytics_manager.send_download_finished(
                     self.download_id, self.claim_name, self.sd_hash
@@ -289,12 +360,11 @@ class ManagedStream:
         finally:
             self.saving.clear()
 
-    async def save_file(self, file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None):
-        if self.file_output_task and not self.file_output_task.done():
+    async def save_file(self, file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None,
+                        node: typing.Optional['Node'] = None):
+        await self.start(node)
+        if self.file_output_task and not self.file_output_task.done():  # cancel an already running save task
             self.file_output_task.cancel()
-        if self.delayed_stop:
-            self.delayed_stop.cancel()
-            self.delayed_stop = None
         self.download_directory = download_directory or self.download_directory or self.config.download_dir
         if not self.download_directory:
             raise ValueError("no directory to download to")
@@ -303,28 +373,26 @@ class ManagedStream:
         if not os.path.isdir(self.download_directory):
             log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory)
             os.mkdir(self.download_directory)
-        if not await self.blob_manager.storage.file_exists(self.sd_hash):
-            self._file_name = await get_next_available_file_name(
-                self.loop, self.download_directory,
-                file_name or self._file_name or self.descriptor.suggested_file_name
-            )
-            self.rowid = self.blob_manager.storage.save_downloaded_file(
-                self.stream_hash, self.file_name, self.download_directory, 0.0
-            )
-        else:
-            await self.blob_manager.storage.change_file_download_dir_and_file_name(
-                self.stream_hash, self.download_directory, self.file_name
-            )
-        self.update_status(ManagedStream.STATUS_RUNNING)
-        await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING)
+        self._file_name = await get_next_available_file_name(
+            self.loop, self.download_directory,
+            file_name or self._file_name or self.descriptor.suggested_file_name
+        )
+        await self.blob_manager.storage.change_file_download_dir_and_file_name(
+            self.stream_hash, self.download_directory, self.file_name
+        )
+        await self.update_status(ManagedStream.STATUS_RUNNING)
         self.written_bytes = 0
         self.file_output_task = self.loop.create_task(self._save_file(self.full_path))
         await self.started_writing.wait()
 
-    def stop_download(self):
+    def stop_tasks(self):
         if self.file_output_task and not self.file_output_task.done():
             self.file_output_task.cancel()
         self.file_output_task = None
+        while self.streaming_responses:
+            self.streaming_responses.pop().force_close()
         self.downloader.stop()
+        self._running.clear()
 
     async def upload_to_reflector(self, host: str, port: int) -> typing.List[str]:
         sent = []
@@ -365,3 +433,43 @@ class ManagedStream:
                 binascii.hexlify(claim.to_bytes()).decode(), claim.signing_channel_id, claim_info['address'],
                 claim_info['claim_sequence'], claim_info.get('channel_name')
             )
+
+    async def update_content_claim(self, claim_info: typing.Optional[typing.Dict] = None):
+        if not claim_info:
+            claim_info = await self.blob_manager.storage.get_content_claim(self.stream_hash)
+        self.set_claim(claim_info, claim_info['value'])
+
+    async def _delayed_stop(self):
+        stalled_count = 0
+        while self._running.is_set():
+            if self.saving.is_set() or self.streaming.is_set():
+                stalled_count = 0
+            else:
+                stalled_count += 1
+            if stalled_count > 1:
+                log.info("Stopping inactive download for stream %s", self.sd_hash)
+                await self.stop()
+                return
+            await asyncio.sleep(1, loop=self.loop)
+
+    def _prepare_range_response_headers(self, get_range: str) -> typing.Tuple[typing.Dict[str, str], int, int]:
+        if '=' in get_range:
+            get_range = get_range.split('=')[1]
+        start, end = get_range.split('-')
+        size = 0
+        for blob in self.descriptor.blobs[:-1]:
+            size += blob.length - 1
+        start = int(start)
+        end = int(end) if end else size - 1
+        skip_blobs = start // 2097150
+        skip = skip_blobs * 2097151
+        start = skip
+        final_size = end - start + 1
+
+        headers = {
+            'Accept-Ranges': 'bytes',
+            'Content-Range': f'bytes {start}-{end}/{size}',
+            'Content-Length': str(final_size),
+            'Content-Type': self.mime_type
+        }
+        return headers, size, skip_blobs
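Note (not part of the diff): the constants in _prepare_range_response_headers appear to derive from the 2 MiB (2097152-byte) maximum blob size, with each stored blob carrying at most 2097151 bytes and one trailing byte per blob discounted when summing the decrypted size, hence `blob.length - 1` and the 2097150 divisor. That reading is an assumption; the arithmetic itself, worked through:

# Worked example of the skip arithmetic (constants copied from the diff).
MAX_DECRYPTED_PER_BLOB = 2097150  # usable bytes per blob when summing the stream size
MAX_BLOB_LENGTH = 2097151         # stride used to compute the byte offset of a blob

def seek(start: int):
    skip_blobs = start // MAX_DECRYPTED_PER_BLOB  # whole blobs before the requested offset
    skip = skip_blobs * MAX_BLOB_LENGTH           # offset actually reported in Content-Range
    return skip_blobs, skip

print(seek(0))        # (0, 0): 'bytes=0-' serves from the first blob
print(seek(4194300))  # (2, 4194302): the response restarts at the third blob's boundary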
@@ -5,8 +5,9 @@ import binascii
 import logging
 import random
 from decimal import Decimal
+from aiohttp.web import Request
 from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError
-from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout, ResolveTimeout
+from lbrynet.error import ResolveTimeout, DownloadDataTimeout
 from lbrynet.utils import cache_concurrent
 from lbrynet.stream.descriptor import StreamDescriptor
 from lbrynet.stream.managed_stream import ManagedStream
@@ -56,6 +57,7 @@ comparison_operators = {
 def path_or_none(p) -> typing.Optional[str]:
     return None if p == '{stream}' else binascii.unhexlify(p).decode()
 
+
 class StreamManager:
     def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager',
                  wallet: 'LbryWalletManager', storage: 'SQLiteStorage', node: typing.Optional['Node'],
@@ -77,24 +79,6 @@ class StreamManager:
         claim_info = await self.storage.get_content_claim(stream.stream_hash)
         self.streams.setdefault(stream.sd_hash, stream).set_claim(claim_info, claim_info['value'])
 
-    async def stop_stream(self, stream: ManagedStream):
-        stream.stop_download()
-        if not stream.finished and stream.output_file_exists:
-            try:
-                os.remove(stream.full_path)
-            except OSError as err:
-                log.warning("Failed to delete partial download %s from downloads directory: %s", stream.full_path,
-                            str(err))
-        if stream.running:
-            stream.update_status(ManagedStream.STATUS_STOPPED)
-            await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED)
-
-    async def start_stream(self, stream: ManagedStream):
-        stream.update_status(ManagedStream.STATUS_RUNNING)
-        await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_RUNNING)
-        await stream.setup(self.node, save_file=self.config.save_files)
-        self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
-
     async def recover_streams(self, file_infos: typing.List[typing.Dict]):
         to_restore = []
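Note (not part of the diff): with start_stream/stop_stream removed from StreamManager, callers address the stream object directly, as the updated call sites in the tests below do. A hypothetical helper summarizing the replacement calls:

async def set_stream_running(stream, stream_manager, running: bool):
    # hypothetical helper, not part of the diff
    if running:
        # previously: await stream_manager.start_stream(stream)
        await stream.save_file(node=stream_manager.node)  # resume and write to disk
        # or, to serve without saving: await stream.start(node=stream_manager.node)
    else:
        # previously: await stream_manager.stop_stream(stream)
        await stream.stop()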
@@ -150,6 +134,7 @@ class StreamManager:
         await self.recover_streams(to_recover)
 
         if not self.config.save_files:
+            # set files that have been deleted manually to streaming mode
             to_set_as_streaming = []
             for file_info in to_start:
                 file_name = path_or_none(file_info['file_name'])
@@ -176,7 +161,7 @@ class StreamManager:
         if not self.node:
             log.warning("no DHT node given, resuming downloads trusting that we can contact reflector")
         t = [
-            self.loop.create_task(self.start_stream(stream)) for stream in self.streams.values()
+            self.loop.create_task(stream.start(node=self.node)) for stream in self.streams.values()
             if stream.running
         ]
         if t:
@@ -214,7 +199,7 @@ class StreamManager:
             self.re_reflect_task.cancel()
         while self.streams:
             _, stream = self.streams.popitem()
-            stream.stop_download()
+            stream.stop_tasks()
         while self.update_stream_finished_futs:
             self.update_stream_finished_futs.pop().cancel()
         while self.running_reflector_uploads:
@@ -236,7 +221,7 @@ class StreamManager:
         return stream
 
     async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False):
-        await self.stop_stream(stream)
+        stream.stop_tasks()
         if stream.sd_hash in self.streams:
             del self.streams[stream.sd_hash]
         blob_hashes = [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]]
@@ -290,21 +275,16 @@ class StreamManager:
             typing.Optional[ManagedStream], typing.Optional[ManagedStream]]:
         existing = self.get_filtered_streams(outpoint=outpoint)
         if existing:
-            if existing[0].status == ManagedStream.STATUS_STOPPED:
-                await self.start_stream(existing[0])
             return existing[0], None
         existing = self.get_filtered_streams(sd_hash=claim.stream.source.sd_hash)
         if existing and existing[0].claim_id != claim_id:
-            raise ResolveError(f"stream for {existing[0].claim_id} collides with existing "
-                               f"download {claim_id}")
+            raise ResolveError(f"stream for {existing[0].claim_id} collides with existing download {claim_id}")
         if existing:
             log.info("claim contains a metadata only update to a stream we have")
             await self.storage.save_content_claim(
                 existing[0].stream_hash, outpoint
             )
             await self._update_content_claim(existing[0])
-            if not existing[0].running:
-                await self.start_stream(existing[0])
             return existing[0], None
         else:
             existing_for_claim_id = self.get_filtered_streams(claim_id=claim_id)
@@ -318,13 +298,23 @@ class StreamManager:
                                        timeout: typing.Optional[float] = None,
                                        file_name: typing.Optional[str] = None,
                                        download_directory: typing.Optional[str] = None,
-                                       save_file: bool = True, resolve_timeout: float = 3.0) -> ManagedStream:
+                                       save_file: typing.Optional[bool] = None,
+                                       resolve_timeout: float = 3.0) -> ManagedStream:
         timeout = timeout or self.config.download_timeout
         start_time = self.loop.time()
         resolved_time = None
         stream = None
         error = None
         outpoint = None
+        if save_file is None:
+            save_file = self.config.save_files
+        if file_name and not save_file:
+            save_file = True
+        if save_file:
+            download_directory = download_directory or self.config.download_dir
+        else:
+            download_directory = None
+
         try:
             # resolve the claim
             parsed_uri = parse_lbry_uri(uri)
@@ -352,6 +342,9 @@ class StreamManager:
             updated_stream, to_replace = await self._check_update_or_replace(outpoint, resolved['claim_id'], claim)
             if updated_stream:
                 log.info("already have stream for %s", uri)
+                if save_file and updated_stream.output_file_exists:
+                    save_file = False
+                await updated_stream.start(node=self.node, timeout=timeout, save_now=save_file)
                 return updated_stream
 
             content_fee = None
@@ -381,30 +374,18 @@ class StreamManager:
 
             log.info("paid fee of %s for %s", fee_amount, uri)
 
-            download_directory = download_directory or self.config.download_dir
-            if not file_name and (not self.config.save_files or not save_file):
-                download_dir, file_name = None, None
             stream = ManagedStream(
                 self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_directory,
                 file_name, ManagedStream.STATUS_RUNNING, content_fee=content_fee,
                 analytics_manager=self.analytics_manager
             )
             log.info("starting download for %s", uri)
-            try:
-                await asyncio.wait_for(stream.setup(
-                    self.node, save_file=save_file, file_name=file_name, download_directory=download_directory
-                ), timeout, loop=self.loop)
-            except asyncio.TimeoutError:
-                if not stream.descriptor:
-                    raise DownloadSDTimeout(stream.sd_hash)
-                raise DownloadDataTimeout(stream.sd_hash)
-            finally:
-                if stream.descriptor:
-                    if to_replace:  # delete old stream now that the replacement has started downloading
-                        await self.delete_stream(to_replace)
-                    stream.set_claim(resolved, claim)
-                    await self.storage.save_content_claim(stream.stream_hash, outpoint)
-                    self.streams[stream.sd_hash] = stream
+            await stream.start(self.node, timeout, save_now=save_file)
+            if to_replace:  # delete old stream now that the replacement has started downloading
+                await self.delete_stream(to_replace)
+            self.streams[stream.sd_hash] = stream
+            stream.set_claim(resolved, claim)
+            await self.storage.save_content_claim(stream.stream_hash, outpoint)
             return stream
         except DownloadDataTimeout as err:  # forgive data timeout, dont delete stream
             error = err
@@ -435,3 +416,6 @@ class StreamManager:
             )
         if error:
             raise error
+
+    async def stream_partial_content(self, request: Request, sd_hash: str):
+        return await self.streams[sd_hash].stream_file(request, self.node)
@@ -40,7 +40,7 @@ class TestManagedStream(BlobExchangeTestBase):
             self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir
         )
 
-    async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
+    async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True):
         await self.setup_stream(blob_count)
         mock_node = mock.Mock(spec=Node)
 
@@ -51,10 +51,11 @@ class TestManagedStream(BlobExchangeTestBase):
             return q2, self.loop.create_task(_task())
 
         mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
-        await self.stream.setup(mock_node, save_file=True)
+        await self.stream.save_file(node=mock_node)
         await self.stream.finished_writing.wait()
-        self.assertTrue(os.path.isfile(self.stream.full_path))
-        self.stream.stop_download()
+        if stop_when_done:
+            await self.stream.stop()
+        self.assertTrue(os.path.isfile(self.stream.full_path))
         with open(self.stream.full_path, 'rb') as f:
             self.assertEqual(f.read(), self.stream_bytes)
@@ -62,6 +63,18 @@ class TestManagedStream(BlobExchangeTestBase):
 
     async def test_transfer_stream(self):
         await self._test_transfer_stream(10)
+        self.assertEqual(self.stream.status, "finished")
+        self.assertFalse(self.stream._running.is_set())
+
+    async def test_delayed_stop(self):
+        await self._test_transfer_stream(10, stop_when_done=False)
+        self.assertEqual(self.stream.status, "finished")
+        self.assertTrue(self.stream._running.is_set())
+        await asyncio.sleep(0.5, loop=self.loop)
+        self.assertTrue(self.stream._running.is_set())
+        await asyncio.sleep(0.6, loop=self.loop)
+        self.assertEqual(self.stream.status, "finished")
+        self.assertFalse(self.stream._running.is_set())
 
     @unittest.SkipTest
     async def test_transfer_hundred_blob_stream(self):
@@ -85,11 +98,12 @@ class TestManagedStream(BlobExchangeTestBase):
 
         mock_node.accumulate_peers = _mock_accumulate_peers
 
-        await self.stream.setup(mock_node, save_file=True)
+        await self.stream.save_file(node=mock_node)
         await self.stream.finished_writing.wait()
         self.assertTrue(os.path.isfile(self.stream.full_path))
         with open(self.stream.full_path, 'rb') as f:
             self.assertEqual(f.read(), self.stream_bytes)
+        await self.stream.stop()
         # self.assertIs(self.server_from_client.tcp_last_down, None)
         # self.assertIsNot(bad_peer.tcp_last_down, None)
 
@@ -125,7 +139,7 @@ class TestManagedStream(BlobExchangeTestBase):
             with open(os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash), "rb+") as handle:
                 handle.truncate()
                 handle.flush()
-        await self.stream.setup()
+        await self.stream.save_file()
         await self.stream.finished_writing.wait()
         if corrupt:
             return self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
@@ -225,7 +225,7 @@ class TestStreamManager(BlobExchangeTestBase):
         )
         self.assertEqual(stored_status, "running")
 
-        await self.stream_manager.stop_stream(stream)
+        await stream.stop()
 
         self.assertFalse(stream.finished)
         self.assertFalse(stream.running)
@@ -235,7 +235,7 @@ class TestStreamManager(BlobExchangeTestBase):
         )
         self.assertEqual(stored_status, "stopped")
 
-        await self.stream_manager.start_stream(stream)
+        await stream.save_file(node=self.stream_manager.node)
         await stream.finished_writing.wait()
         await asyncio.sleep(0, loop=self.loop)
         self.assertTrue(stream.finished)