fix/refactor starting and stopping files

- move partial content handling into ManagedStream
- add delayed stop test
Jack Robison 2019-05-01 17:09:50 -04:00
parent b134e0c9c9
commit 9099ee2e8e
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
5 changed files with 222 additions and 165 deletions

lbrynet/extras/daemon/Daemon.py

@@ -18,7 +18,7 @@ from torba.client.baseaccount import SingleKey, HierarchicalDeterministic
 from lbrynet import utils
 from lbrynet.conf import Config, Setting
-from lbrynet.blob.blob_file import is_valid_blobhash
+from lbrynet.blob.blob_file import is_valid_blobhash, BlobBuffer
 from lbrynet.blob_exchange.downloader import download_blob
 from lbrynet.error import DownloadSDTimeout, ComponentsNotStarted
 from lbrynet.error import NullFundsError, NegativeFundsError, ComponentStartConditionNotMet
@@ -477,59 +477,11 @@ class Daemon(metaclass=JSONRPCServerType):
             raise web.HTTPServerError(text=stream['error'])
         raise web.HTTPFound(f"/stream/{stream.sd_hash}")
 
-    @staticmethod
-    def prepare_range_response_headers(get_range: str, stream: 'ManagedStream') -> typing.Tuple[typing.Dict[str, str],
-                                                                                                 int, int]:
-        if '=' in get_range:
-            get_range = get_range.split('=')[1]
-        start, end = get_range.split('-')
-        size = 0
-        for blob in stream.descriptor.blobs[:-1]:
-            size += blob.length - 1
-        start = int(start)
-        end = int(end) if end else size - 1
-        skip_blobs = start // 2097150
-        skip = skip_blobs * 2097151
-        start = skip
-        final_size = end - start + 1
-        headers = {
-            'Accept-Ranges': 'bytes',
-            'Content-Range': f'bytes {start}-{end}/{size}',
-            'Content-Length': str(final_size),
-            'Content-Type': stream.mime_type
-        }
-        return headers, size, skip_blobs
-
     async def handle_stream_range_request(self, request: web.Request):
         sd_hash = request.path.split("/stream/")[1]
         if sd_hash not in self.stream_manager.streams:
             return web.HTTPNotFound()
-        stream = self.stream_manager.streams[sd_hash]
-        if stream.status == 'stopped':
-            await self.stream_manager.start_stream(stream)
-        if stream.delayed_stop:
-            stream.delayed_stop.cancel()
-        headers, size, skip_blobs = self.prepare_range_response_headers(
-            request.headers.get('range', 'bytes=0-'), stream
-        )
-        response = web.StreamResponse(
-            status=206,
-            headers=headers
-        )
-        await response.prepare(request)
-        wrote = 0
-        async for blob_info, decrypted in stream.aiter_read_stream(skip_blobs):
-            log.info("streamed blob %i/%i", blob_info.blob_num + 1, len(stream.descriptor.blobs) - 1)
-            if (blob_info.blob_num == len(stream.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size):
-                decrypted += b'\x00' * (size - len(decrypted) - wrote)
-                await response.write_eof(decrypted)
-                break
-            else:
-                await response.write(decrypted)
-                wrote += len(decrypted)
-        response.force_close()
-        return response
+        return await self.stream_manager.stream_partial_content(request, sd_hash)
 
     async def _process_rpc_call(self, data):
         args = data.get('params', {})
@@ -924,7 +876,6 @@ class Daemon(metaclass=JSONRPCServerType):
         Returns: {File}
         """
 
-        save_file = save_file if save_file is not None else self.conf.save_files
         try:
             stream = await self.stream_manager.download_stream_from_uri(
                 uri, self.exchange_rate_manager, timeout, file_name, save_file=save_file
@@ -1554,7 +1505,7 @@ class Daemon(metaclass=JSONRPCServerType):
             await self.stream_manager.start_stream(stream)
             msg = "Resumed download"
         elif status == 'stop' and stream.running:
-            await self.stream_manager.stop_stream(stream)
+            await stream.stop()
             msg = "Stopped download"
         else:
             msg = (
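
The range handler above now reduces to a dictionary lookup plus a delegation to StreamManager.stream_partial_content. A minimal client-side sketch of exercising that endpoint; the host, port, and SD_HASH below are placeholder assumptions, not values taken from this commit:

```python
# Hypothetical client for the daemon's /stream/<sd_hash> range endpoint.
import asyncio
import aiohttp

SD_HASH = "aa04a949348f9f094d503e5816f0cfb57ee342a6bb3b7b307abc41eb"  # placeholder

async def fetch_first_range():
    async with aiohttp.ClientSession() as session:
        # 'Range: bytes=0-' asks for the stream starting at byte 0; the daemon
        # answers 206 Partial Content with Content-Range and Content-Length
        # computed from the decrypted sizes of the stream's blobs.
        async with session.get(f"http://localhost:5279/stream/{SD_HASH}",
                               headers={"Range": "bytes=0-"}) as resp:
            assert resp.status == 206
            print(resp.headers["Content-Range"], resp.headers["Content-Length"])
            return await resp.content.read(65536)  # first 64 KiB of payload

asyncio.get_event_loop().run_until_complete(fetch_first_range())
```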

lbrynet/stream/managed_stream.py

@@ -3,7 +3,9 @@ import asyncio
 import typing
 import logging
 import binascii
+from aiohttp.web import Request, StreamResponse
 from lbrynet.utils import generate_id
+from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout
 from lbrynet.schema.mime_types import guess_media_type
 from lbrynet.stream.downloader import StreamDownloader
 from lbrynet.stream.descriptor import StreamDescriptor
@@ -40,6 +42,33 @@ class ManagedStream:
     STATUS_STOPPED = "stopped"
     STATUS_FINISHED = "finished"
 
+    __slots__ = [
+        'loop',
+        'config',
+        'blob_manager',
+        'sd_hash',
+        'download_directory',
+        '_file_name',
+        '_status',
+        'stream_claim_info',
+        'download_id',
+        'rowid',
+        'written_bytes',
+        'content_fee',
+        'downloader',
+        'analytics_manager',
+        'fully_reflected',
+        'file_output_task',
+        'delayed_stop_task',
+        'streaming_responses',
+        'streaming',
+        '_running',
+        'saving',
+        'finished_writing',
+        'started_writing',
+    ]
+
     def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager',
                  sd_hash: str, download_directory: typing.Optional[str] = None, file_name: typing.Optional[str] = None,
                  status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredStreamClaim] = None,
@@ -61,9 +90,13 @@ class ManagedStream:
         self.content_fee = content_fee
         self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor)
         self.analytics_manager = analytics_manager
+
         self.fully_reflected = asyncio.Event(loop=self.loop)
         self.file_output_task: typing.Optional[asyncio.Task] = None
-        self.delayed_stop: typing.Optional[asyncio.Handle] = None
+        self.delayed_stop_task: typing.Optional[asyncio.Task] = None
+        self.streaming_responses: typing.List[StreamResponse] = []
+        self.streaming = asyncio.Event(loop=self.loop)
+        self._running = asyncio.Event(loop=self.loop)
         self.saving = asyncio.Event(loop=self.loop)
         self.finished_writing = asyncio.Event(loop=self.loop)
         self.started_writing = asyncio.Event(loop=self.loop)
@@ -84,9 +117,10 @@ class ManagedStream:
     def status(self) -> str:
         return self._status
 
-    def update_status(self, status: str):
+    async def update_status(self, status: str):
         assert status in [self.STATUS_RUNNING, self.STATUS_STOPPED, self.STATUS_FINISHED]
         self._status = status
+        await self.blob_manager.storage.change_file_status(self.stream_hash, status)
 
     @property
     def finished(self) -> bool:
@@ -216,47 +250,85 @@ class ManagedStream:
         return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path),
                    os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor)
 
-    async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True,
-                    file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None):
-        await self.downloader.start(node)
-        if not save_file and not file_name:
-            if not await self.blob_manager.storage.file_exists(self.sd_hash):
-                self.rowid = await self.blob_manager.storage.save_downloaded_file(
-                    self.stream_hash, None, None, 0.0
-                )
-            self.download_directory = None
-            self._file_name = None
-            self.update_status(ManagedStream.STATUS_RUNNING)
-            await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING)
-            self.update_delayed_stop()
-        else:
-            await self.save_file(file_name, download_directory)
-            await self.started_writing.wait()
-
-    def update_delayed_stop(self):
-        def _delayed_stop():
-            log.info("Stopping inactive download for stream %s", self.sd_hash)
-            self.stop_download()
+    async def start(self, node: typing.Optional['Node'] = None, timeout: typing.Optional[float] = None,
+                    save_now: bool = False):
+        timeout = timeout or self.config.download_timeout
+        if self._running.is_set():
+            return
+        self._running.set()
+        start_time = self.loop.time()
+        try:
+            await asyncio.wait_for(self.downloader.start(node), timeout, loop=self.loop)
+            if save_now:
+                await asyncio.wait_for(self.save_file(node=node), timeout - (self.loop.time() - start_time),
+                                       loop=self.loop)
+        except asyncio.TimeoutError:
+            self._running.clear()
+            if not self.descriptor:
+                raise DownloadSDTimeout(self.sd_hash)
+            raise DownloadDataTimeout(self.sd_hash)
+        if self.delayed_stop_task and not self.delayed_stop_task.done():
+            self.delayed_stop_task.cancel()
+        self.delayed_stop_task = self.loop.create_task(self._delayed_stop())
+        if not await self.blob_manager.storage.file_exists(self.sd_hash):
+            if save_now:
+                file_name, download_dir = self._file_name, self.download_directory
+            else:
+                file_name, download_dir = None, None
+            self.rowid = await self.blob_manager.storage.save_downloaded_file(
+                self.stream_hash, file_name, download_dir, 0.0
+            )
+        if self.status != self.STATUS_RUNNING:
+            await self.update_status(self.STATUS_RUNNING)
 
-        if self.delayed_stop:
-            self.delayed_stop.cancel()
-        self.delayed_stop = self.loop.call_later(60, _delayed_stop)
+    async def stop(self, finished: bool = False):
+        """
+        Stop any running save/stream tasks as well as the downloader and update the status in the database
+        """
+        self.stop_tasks()
+        if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING:
+            await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED)
 
-    async def aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0) -> typing.AsyncIterator[
-                                typing.Tuple['BlobInfo', bytes]]:
+    async def _aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0)\
+            -> typing.AsyncIterator[typing.Tuple['BlobInfo', bytes]]:
         if start_blob_num >= len(self.descriptor.blobs[:-1]):
             raise IndexError(start_blob_num)
         for i, blob_info in enumerate(self.descriptor.blobs[start_blob_num:-1]):
             assert i + start_blob_num == blob_info.blob_num
-            if self.delayed_stop:
-                self.delayed_stop.cancel()
-            try:
-                decrypted = await self.downloader.read_blob(blob_info)
-                yield (blob_info, decrypted)
-            except asyncio.CancelledError:
-                if not self.saving.is_set() and not self.finished_writing.is_set():
-                    self.update_delayed_stop()
-                raise
+            decrypted = await self.downloader.read_blob(blob_info)
+            yield (blob_info, decrypted)
+
+    async def stream_file(self, request: Request, node: typing.Optional['Node'] = None) -> StreamResponse:
+        await self.start(node)
+        headers, size, skip_blobs = self._prepare_range_response_headers(request.headers.get('range', 'bytes=0-'))
+        response = StreamResponse(
+            status=206,
+            headers=headers
+        )
+        await response.prepare(request)
+        self.streaming_responses.append(response)
+        self.streaming.set()
+        try:
+            wrote = 0
+            async for blob_info, decrypted in self._aiter_read_stream(skip_blobs):
+                if (blob_info.blob_num == len(self.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size):
+                    decrypted += b'\x00' * (size - len(decrypted) - wrote)
+                    await response.write_eof(decrypted)
+                else:
+                    await response.write(decrypted)
+                wrote += len(decrypted)
+                log.info("streamed %sblob %i/%i", "(closing stream) " if response._eof_sent else "",
+                         blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
+                if response._eof_sent:
+                    break
+            return response
+        finally:
+            response.force_close()
+            if response in self.streaming_responses:
+                self.streaming_responses.remove(response)
+            self.streaming.clear()
 
     async def _save_file(self, output_path: str):
         log.debug("save file %s -> %s", self.sd_hash, output_path)
@@ -265,15 +337,14 @@ class ManagedStream:
         self.started_writing.clear()
         try:
             with open(output_path, 'wb') as file_write_handle:
-                async for blob_info, decrypted in self.aiter_read_stream():
+                async for blob_info, decrypted in self._aiter_read_stream():
                     log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
                     file_write_handle.write(decrypted)
                     file_write_handle.flush()
                     self.written_bytes += len(decrypted)
                     if not self.started_writing.is_set():
                         self.started_writing.set()
-            self.update_status(ManagedStream.STATUS_FINISHED)
-            await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_FINISHED)
+            await self.update_status(ManagedStream.STATUS_FINISHED)
             if self.analytics_manager:
                 self.loop.create_task(self.analytics_manager.send_download_finished(
                     self.download_id, self.claim_name, self.sd_hash
@@ -289,12 +360,11 @@ class ManagedStream:
         finally:
             self.saving.clear()
 
-    async def save_file(self, file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None):
-        if self.file_output_task and not self.file_output_task.done():
+    async def save_file(self, file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None,
+                        node: typing.Optional['Node'] = None):
+        await self.start(node)
+        if self.file_output_task and not self.file_output_task.done():  # cancel an already running save task
             self.file_output_task.cancel()
-        if self.delayed_stop:
-            self.delayed_stop.cancel()
-            self.delayed_stop = None
         self.download_directory = download_directory or self.download_directory or self.config.download_dir
         if not self.download_directory:
             raise ValueError("no directory to download to")
@@ -303,28 +373,26 @@ class ManagedStream:
         if not os.path.isdir(self.download_directory):
             log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory)
             os.mkdir(self.download_directory)
-        if not await self.blob_manager.storage.file_exists(self.sd_hash):
-            self._file_name = await get_next_available_file_name(
-                self.loop, self.download_directory,
-                file_name or self._file_name or self.descriptor.suggested_file_name
-            )
-            self.rowid = self.blob_manager.storage.save_downloaded_file(
-                self.stream_hash, self.file_name, self.download_directory, 0.0
-            )
-        else:
-            await self.blob_manager.storage.change_file_download_dir_and_file_name(
-                self.stream_hash, self.download_directory, self.file_name
-            )
-        self.update_status(ManagedStream.STATUS_RUNNING)
-        await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING)
+        self._file_name = await get_next_available_file_name(
+            self.loop, self.download_directory,
+            file_name or self._file_name or self.descriptor.suggested_file_name
+        )
+        await self.blob_manager.storage.change_file_download_dir_and_file_name(
+            self.stream_hash, self.download_directory, self.file_name
+        )
+        await self.update_status(ManagedStream.STATUS_RUNNING)
         self.written_bytes = 0
         self.file_output_task = self.loop.create_task(self._save_file(self.full_path))
+        await self.started_writing.wait()
 
-    def stop_download(self):
+    def stop_tasks(self):
         if self.file_output_task and not self.file_output_task.done():
             self.file_output_task.cancel()
         self.file_output_task = None
+        while self.streaming_responses:
+            self.streaming_responses.pop().force_close()
         self.downloader.stop()
+        self._running.clear()
 
     async def upload_to_reflector(self, host: str, port: int) -> typing.List[str]:
         sent = []
@@ -365,3 +433,43 @@ class ManagedStream:
             binascii.hexlify(claim.to_bytes()).decode(), claim.signing_channel_id, claim_info['address'],
             claim_info['claim_sequence'], claim_info.get('channel_name')
         )
+
+    async def update_content_claim(self, claim_info: typing.Optional[typing.Dict] = None):
+        if not claim_info:
+            claim_info = await self.blob_manager.storage.get_content_claim(self.stream_hash)
+        self.set_claim(claim_info, claim_info['value'])
+
+    async def _delayed_stop(self):
+        stalled_count = 0
+        while self._running.is_set():
+            if self.saving.is_set() or self.streaming.is_set():
+                stalled_count = 0
+            else:
+                stalled_count += 1
+            if stalled_count > 1:
+                log.info("Stopping inactive download for stream %s", self.sd_hash)
+                await self.stop()
+                return
+            await asyncio.sleep(1, loop=self.loop)
+
+    def _prepare_range_response_headers(self, get_range: str) -> typing.Tuple[typing.Dict[str, str], int, int]:
+        if '=' in get_range:
+            get_range = get_range.split('=')[1]
+        start, end = get_range.split('-')
+        size = 0
+        for blob in self.descriptor.blobs[:-1]:
+            size += blob.length - 1
+        start = int(start)
+        end = int(end) if end else size - 1
+        skip_blobs = start // 2097150
+        skip = skip_blobs * 2097151
+        start = skip
+        final_size = end - start + 1
+        headers = {
+            'Accept-Ranges': 'bytes',
+            'Content-Range': f'bytes {start}-{end}/{size}',
+            'Content-Length': str(final_size),
+            'Content-Type': self.mime_type
+        }
+        return headers, size, skip_blobs
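
The header arithmetic that moved into _prepare_range_response_headers is easiest to check with concrete numbers. A standalone sketch of the same computation, where blob_lengths stands in for [b.length for b in self.descriptor.blobs[:-1]] and the inputs model hypothetical full-size 2 MiB stream blobs (the helper name and values are illustrative, not part of the commit):

```python
# Replicates the range math: the server reports the requested byte range but
# actually rewinds the start to a blob boundary, since decryption can only
# resume at the beginning of a blob.
import typing

def range_headers(get_range: str, blob_lengths: typing.List[int],
                  mime_type: str = "video/mp4") -> typing.Tuple[dict, int, int]:
    if '=' in get_range:                         # e.g. "bytes=5000000-"
        get_range = get_range.split('=')[1]
    start, end = get_range.split('-')
    size = sum(length - 1 for length in blob_lengths)  # decrypted size estimate
    start, end = int(start), int(end) if end else size - 1
    skip_blobs = start // 2097150                # whole blobs already behind the seek point
    start = skip_blobs * 2097151                 # byte offset where decryption restarts
    final_size = end - start + 1
    return ({'Accept-Ranges': 'bytes',
             'Content-Range': f'bytes {start}-{end}/{size}',
             'Content-Length': str(final_size),
             'Content-Type': mime_type}, size, skip_blobs)

# Three full-size blobs (~6 MiB stream); client seeks to byte 5,000,000:
headers, size, skip_blobs = range_headers("bytes=5000000-", [2097152] * 3)
print(skip_blobs)                  # 2 -> streaming resumes at blob #2
print(headers['Content-Range'])    # bytes 4194302-6291452/6291453
```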

lbrynet/stream/stream_manager.py

@@ -5,8 +5,9 @@ import binascii
 import logging
 import random
 from decimal import Decimal
+from aiohttp.web import Request
 from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError
-from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout, ResolveTimeout
+from lbrynet.error import ResolveTimeout, DownloadDataTimeout
 from lbrynet.utils import cache_concurrent
 from lbrynet.stream.descriptor import StreamDescriptor
 from lbrynet.stream.managed_stream import ManagedStream
@@ -56,6 +57,7 @@ comparison_operators = {
 def path_or_none(p) -> typing.Optional[str]:
     return None if p == '{stream}' else binascii.unhexlify(p).decode()
 
+
 class StreamManager:
     def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager',
                  wallet: 'LbryWalletManager', storage: 'SQLiteStorage', node: typing.Optional['Node'],
@@ -77,24 +79,6 @@ class StreamManager:
             claim_info = await self.storage.get_content_claim(stream.stream_hash)
             self.streams.setdefault(stream.sd_hash, stream).set_claim(claim_info, claim_info['value'])
 
-    async def stop_stream(self, stream: ManagedStream):
-        stream.stop_download()
-        if not stream.finished and stream.output_file_exists:
-            try:
-                os.remove(stream.full_path)
-            except OSError as err:
-                log.warning("Failed to delete partial download %s from downloads directory: %s", stream.full_path,
-                            str(err))
-        if stream.running:
-            stream.update_status(ManagedStream.STATUS_STOPPED)
-            await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED)
-
-    async def start_stream(self, stream: ManagedStream):
-        stream.update_status(ManagedStream.STATUS_RUNNING)
-        await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_RUNNING)
-        await stream.setup(self.node, save_file=self.config.save_files)
-        self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
-
     async def recover_streams(self, file_infos: typing.List[typing.Dict]):
         to_restore = []
@@ -150,6 +134,7 @@ class StreamManager:
         await self.recover_streams(to_recover)
 
         if not self.config.save_files:
+            # set files that have been deleted manually to streaming mode
            to_set_as_streaming = []
            for file_info in to_start:
                file_name = path_or_none(file_info['file_name'])
@@ -176,7 +161,7 @@ class StreamManager:
         if not self.node:
             log.warning("no DHT node given, resuming downloads trusting that we can contact reflector")
         t = [
-            self.loop.create_task(self.start_stream(stream)) for stream in self.streams.values()
+            self.loop.create_task(stream.start(node=self.node)) for stream in self.streams.values()
             if stream.running
         ]
         if t:
@@ -214,7 +199,7 @@ class StreamManager:
             self.re_reflect_task.cancel()
         while self.streams:
             _, stream = self.streams.popitem()
-            stream.stop_download()
+            stream.stop_tasks()
         while self.update_stream_finished_futs:
             self.update_stream_finished_futs.pop().cancel()
         while self.running_reflector_uploads:
@@ -236,7 +221,7 @@ class StreamManager:
         return stream
 
     async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False):
-        await self.stop_stream(stream)
+        stream.stop_tasks()
         if stream.sd_hash in self.streams:
             del self.streams[stream.sd_hash]
         blob_hashes = [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]]
@@ -290,21 +275,16 @@ class StreamManager:
                                                      typing.Optional[ManagedStream], typing.Optional[ManagedStream]]:
         existing = self.get_filtered_streams(outpoint=outpoint)
         if existing:
-            if existing[0].status == ManagedStream.STATUS_STOPPED:
-                await self.start_stream(existing[0])
             return existing[0], None
         existing = self.get_filtered_streams(sd_hash=claim.stream.source.sd_hash)
         if existing and existing[0].claim_id != claim_id:
-            raise ResolveError(f"stream for {existing[0].claim_id} collides with existing "
-                               f"download {claim_id}")
+            raise ResolveError(f"stream for {existing[0].claim_id} collides with existing download {claim_id}")
         if existing:
             log.info("claim contains a metadata only update to a stream we have")
             await self.storage.save_content_claim(
                 existing[0].stream_hash, outpoint
             )
             await self._update_content_claim(existing[0])
-            if not existing[0].running:
-                await self.start_stream(existing[0])
             return existing[0], None
         else:
             existing_for_claim_id = self.get_filtered_streams(claim_id=claim_id)
@@ -318,13 +298,23 @@ class StreamManager:
                                        timeout: typing.Optional[float] = None,
                                        file_name: typing.Optional[str] = None,
                                        download_directory: typing.Optional[str] = None,
-                                       save_file: bool = True, resolve_timeout: float = 3.0) -> ManagedStream:
+                                       save_file: typing.Optional[bool] = None,
+                                       resolve_timeout: float = 3.0) -> ManagedStream:
         timeout = timeout or self.config.download_timeout
         start_time = self.loop.time()
         resolved_time = None
         stream = None
         error = None
         outpoint = None
+        if save_file is None:
+            save_file = self.config.save_files
+        if file_name and not save_file:
+            save_file = True
+        if save_file:
+            download_directory = download_directory or self.config.download_dir
+        else:
+            download_directory = None
         try:
             # resolve the claim
             parsed_uri = parse_lbry_uri(uri)
@@ -352,6 +342,9 @@ class StreamManager:
             updated_stream, to_replace = await self._check_update_or_replace(outpoint, resolved['claim_id'], claim)
             if updated_stream:
                 log.info("already have stream for %s", uri)
+                if save_file and updated_stream.output_file_exists:
+                    save_file = False
+                await updated_stream.start(node=self.node, timeout=timeout, save_now=save_file)
                 return updated_stream
 
             content_fee = None
@@ -381,30 +374,18 @@ class StreamManager:
                     log.info("paid fee of %s for %s", fee_amount, uri)
 
-            download_directory = download_directory or self.config.download_dir
-            if not file_name and (not self.config.save_files or not save_file):
-                download_dir, file_name = None, None
             stream = ManagedStream(
                 self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_directory,
                 file_name, ManagedStream.STATUS_RUNNING, content_fee=content_fee,
                 analytics_manager=self.analytics_manager
             )
             log.info("starting download for %s", uri)
-            try:
-                await asyncio.wait_for(stream.setup(
-                    self.node, save_file=save_file, file_name=file_name, download_directory=download_directory
-                ), timeout, loop=self.loop)
-            except asyncio.TimeoutError:
-                if not stream.descriptor:
-                    raise DownloadSDTimeout(stream.sd_hash)
-                raise DownloadDataTimeout(stream.sd_hash)
-            finally:
-                if stream.descriptor:
-                    if to_replace:  # delete old stream now that the replacement has started downloading
-                        await self.delete_stream(to_replace)
-                    stream.set_claim(resolved, claim)
-                    await self.storage.save_content_claim(stream.stream_hash, outpoint)
-                    self.streams[stream.sd_hash] = stream
+            await stream.start(self.node, timeout, save_now=save_file)
+            if to_replace:  # delete old stream now that the replacement has started downloading
+                await self.delete_stream(to_replace)
+            self.streams[stream.sd_hash] = stream
+            stream.set_claim(resolved, claim)
+            await self.storage.save_content_claim(stream.stream_hash, outpoint)
             return stream
         except DownloadDataTimeout as err:  # forgive data timeout, dont delete stream
             error = err
@@ -435,3 +416,6 @@ class StreamManager:
         )
         if error:
             raise error
+
+    async def stream_partial_content(self, request: Request, sd_hash: str):
+        return await self.streams[sd_hash].stream_file(request, self.node)

tests/unit/stream/test_managed_stream.py

@@ -40,7 +40,7 @@ class TestManagedStream(BlobExchangeTestBase):
             self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir
         )
 
-    async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
+    async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True):
         await self.setup_stream(blob_count)
         mock_node = mock.Mock(spec=Node)
@@ -51,10 +51,11 @@ class TestManagedStream(BlobExchangeTestBase):
             return q2, self.loop.create_task(_task())
 
         mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
-        await self.stream.setup(mock_node, save_file=True)
+        await self.stream.save_file(node=mock_node)
         await self.stream.finished_writing.wait()
         self.assertTrue(os.path.isfile(self.stream.full_path))
-        self.stream.stop_download()
+        if stop_when_done:
+            await self.stream.stop()
         self.assertTrue(os.path.isfile(self.stream.full_path))
         with open(self.stream.full_path, 'rb') as f:
             self.assertEqual(f.read(), self.stream_bytes)
@@ -62,6 +63,18 @@ class TestManagedStream(BlobExchangeTestBase):
 
     async def test_transfer_stream(self):
         await self._test_transfer_stream(10)
+        self.assertEqual(self.stream.status, "finished")
+        self.assertFalse(self.stream._running.is_set())
+
+    async def test_delayed_stop(self):
+        await self._test_transfer_stream(10, stop_when_done=False)
+        self.assertEqual(self.stream.status, "finished")
+        self.assertTrue(self.stream._running.is_set())
+        await asyncio.sleep(0.5, loop=self.loop)
+        self.assertTrue(self.stream._running.is_set())
+        await asyncio.sleep(0.6, loop=self.loop)
+        self.assertEqual(self.stream.status, "finished")
+        self.assertFalse(self.stream._running.is_set())
 
     @unittest.SkipTest
     async def test_transfer_hundred_blob_stream(self):
@@ -85,11 +98,12 @@ class TestManagedStream(BlobExchangeTestBase):
         mock_node.accumulate_peers = _mock_accumulate_peers
 
-        await self.stream.setup(mock_node, save_file=True)
+        await self.stream.save_file(node=mock_node)
         await self.stream.finished_writing.wait()
         self.assertTrue(os.path.isfile(self.stream.full_path))
         with open(self.stream.full_path, 'rb') as f:
             self.assertEqual(f.read(), self.stream_bytes)
+        await self.stream.stop()
         # self.assertIs(self.server_from_client.tcp_last_down, None)
         # self.assertIsNot(bad_peer.tcp_last_down, None)
@@ -125,7 +139,7 @@ class TestManagedStream(BlobExchangeTestBase):
             with open(os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash), "rb+") as handle:
                 handle.truncate()
                 handle.flush()
-        await self.stream.setup()
+        await self.stream.save_file()
         await self.stream.finished_writing.wait()
         if corrupt:
             return self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
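
The new test_delayed_stop depends on _delayed_stop polling periodically and stopping only after two consecutive idle checks, with the saving/streaming events counting as activity. The same watchdog pattern in isolation, with a shortened interval so the sketch runs in under a second (all names here are illustrative, not part of the commit):

```python
# Inactivity watchdog: tolerate one idle check, stop on the second.
import asyncio

async def watchdog(running: asyncio.Event, busy: asyncio.Event, interval: float = 0.1):
    stalled = 0
    while running.is_set():
        stalled = 0 if busy.is_set() else stalled + 1
        if stalled > 1:               # two consecutive idle checks -> stop
            running.clear()
            return
        await asyncio.sleep(interval)

async def main():
    running, busy = asyncio.Event(), asyncio.Event()
    running.set()
    busy.set()                        # simulate an active save/stream
    task = asyncio.ensure_future(watchdog(running, busy))
    await asyncio.sleep(0.35)
    assert running.is_set()           # activity resets the stall counter
    busy.clear()                      # activity ends; idle checks start counting
    await asyncio.sleep(0.35)
    assert not running.is_set()       # stopped after two idle checks
    await task

asyncio.get_event_loop().run_until_complete(main())
```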

tests/unit/stream/test_stream_manager.py

@@ -225,7 +225,7 @@ class TestStreamManager(BlobExchangeTestBase):
         )
         self.assertEqual(stored_status, "running")
 
-        await self.stream_manager.stop_stream(stream)
+        await stream.stop()
         self.assertFalse(stream.finished)
         self.assertFalse(stream.running)
@@ -235,7 +235,7 @@ class TestStreamManager(BlobExchangeTestBase):
         )
         self.assertEqual(stored_status, "stopped")
 
-        await self.stream_manager.start_stream(stream)
+        await stream.save_file(node=self.stream_manager.node)
         await stream.finished_writing.wait()
         await asyncio.sleep(0, loop=self.loop)
         self.assertTrue(stream.finished)