Merge pull request #2065 from lbryio/streaming-bug-fixes

Streaming bug fixes

Commit dd665a758d by Jack Robison, 2019-05-07 10:10:39 -04:00, committed by GitHub
GPG key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
15 changed files with 348 additions and 233 deletions

View file

@@ -2,6 +2,13 @@ import typing
 class BlobInfo:
+    __slots__ = [
+        'blob_hash',
+        'blob_num',
+        'length',
+        'iv',
+    ]
+
     def __init__(self, blob_num: int, length: int, iv: str, blob_hash: typing.Optional[str] = None):
         self.blob_hash = blob_hash
         self.blob_num = blob_num
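
The `__slots__` additions in this PR (here, and on `KademliaPeer`, `StreamDescriptor`, and `ManagedStream` below) trade per-instance `__dict__` storage for a fixed attribute layout, which pays off for objects created once per blob or peer. A minimal standalone sketch of the effect (illustrative only, not lbrynet code):

```python
import sys

class PlainBlobInfo:
    def __init__(self, blob_num):
        self.blob_num = blob_num

class SlottedBlobInfo:
    __slots__ = ['blob_num']
    def __init__(self, blob_num):
        self.blob_num = blob_num

# a plain instance allocates a per-instance dict; a slotted one does not
print(sys.getsizeof(PlainBlobInfo(0).__dict__))   # dict overhead paid per instance
print(hasattr(SlottedBlobInfo(0), '__dict__'))    # False: attributes live in fixed slots

# slots also reject misspelled attributes instead of silently creating them
try:
    SlottedBlobInfo(0).blob_nun = 1
except AttributeError as err:
    print(err)
```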

View file

@@ -4,6 +4,7 @@ import typing
 import binascii
 from lbrynet.error import InvalidBlobHashError, InvalidDataError
 from lbrynet.blob_exchange.serialization import BlobResponse, BlobRequest
+from lbrynet.utils import cache_concurrent
 if typing.TYPE_CHECKING:
     from lbrynet.blob.blob_file import AbstractBlob
     from lbrynet.blob.writer import HashBlobWriter
@@ -158,8 +159,9 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
             return await self._download_blob()
         except OSError as e:
             # i'm not sure how to fix this race condition - jack
-            log.exception("race happened downloading %s from %s:%i", blob_hash, self.peer_address, self.peer_port)
-            return self._blob_bytes_received, self.transport
+            log.warning("race happened downloading %s from %s:%i", blob_hash, self.peer_address, self.peer_port)
+            # return self._blob_bytes_received, self.transport
+            raise
         except asyncio.TimeoutError:
             if self._response_fut and not self._response_fut.done():
                 self._response_fut.cancel()
@@ -184,9 +186,10 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
         self.close()

+@cache_concurrent
 async def request_blob(loop: asyncio.BaseEventLoop, blob: 'AbstractBlob', address: str, tcp_port: int,
                        peer_connect_timeout: float, blob_download_timeout: float,
-                       connected_transport: asyncio.Transport = None)\
+                       connected_transport: asyncio.Transport = None, connection_id: int = 0)\
         -> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
     """
     Returns [<downloaded blob>, <keep connection>]
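
`request_blob` is now wrapped in `cache_concurrent` and takes a `connection_id`. The decorator's job is to collapse concurrent calls with identical arguments onto a single in-flight task; since `connection_id` participates in the arguments, callers that must not share a download race stay distinct (the file writer uses id 1 and the browser streamer id 2, as seen in the managed stream changes below). A sketch of how such a decorator can work; the real `lbrynet.utils.cache_concurrent` may differ in detail:

```python
import asyncio
import functools

def cache_concurrent(fn):
    """Deduplicate concurrent coroutine calls that share the same arguments.

    Sketch only: assumes all positional/keyword arguments are hashable.
    """
    in_flight = {}

    @functools.wraps(fn)
    async def wrapper(*args, **kwargs):
        key = (args, tuple(sorted(kwargs.items())))
        if key not in in_flight:
            task = asyncio.ensure_future(fn(*args, **kwargs))
            # drop the cache entry as soon as the shared task finishes
            task.add_done_callback(lambda _: in_flight.pop(key, None))
            in_flight[key] = task
        # shield so one caller being cancelled doesn't kill the shared task
        return await asyncio.shield(in_flight[key])

    return wrapper
```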

View file

@@ -27,13 +27,14 @@ class BlobDownloader:
         self.scores: typing.Dict['KademliaPeer', int] = {}
         self.failures: typing.Dict['KademliaPeer', int] = {}
         self.connections: typing.Dict['KademliaPeer', asyncio.Transport] = {}
+        self.is_running = asyncio.Event(loop=self.loop)

     def should_race_continue(self, blob: 'AbstractBlob'):
         if len(self.active_connections) >= self.config.max_connections_per_download:
             return False
         return not (blob.get_is_verified() or not blob.is_writeable())

-    async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer'):
+    async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer', connection_id: int = 0):
         if blob.get_is_verified():
             return
         self.scores[peer] = self.scores.get(peer, 0) - 1  # starts losing score, to account for cancelled ones
@@ -41,7 +42,7 @@ class BlobDownloader:
             start = self.loop.time()
             bytes_received, transport = await request_blob(
                 self.loop, blob, peer.address, peer.tcp_port, self.config.peer_connect_timeout,
-                self.config.blob_download_timeout, connected_transport=transport
+                self.config.blob_download_timeout, connected_transport=transport, connection_id=connection_id
             )
             if not transport and peer not in self.ignored:
                 self.ignored[peer] = self.loop.time()
@@ -74,12 +75,14 @@ class BlobDownloader:
         ))

     @cache_concurrent
-    async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'AbstractBlob':
+    async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None,
+                            connection_id: int = 0) -> 'AbstractBlob':
         blob = self.blob_manager.get_blob(blob_hash, length)
         if blob.get_is_verified():
             return blob
+        self.is_running.set()
         try:
-            while not blob.get_is_verified():
+            while not blob.get_is_verified() and self.is_running.is_set():
                 batch: typing.Set['KademliaPeer'] = set()
                 while not self.peer_queue.empty():
                     batch.update(self.peer_queue.get_nowait())
@@ -94,7 +97,7 @@ class BlobDownloader:
                     break
                 if peer not in self.active_connections and peer not in self.ignored:
                     log.debug("request %s from %s:%i", blob_hash[:8], peer.address, peer.tcp_port)
-                    t = self.loop.create_task(self.request_blob_from_peer(blob, peer))
+                    t = self.loop.create_task(self.request_blob_from_peer(blob, peer, connection_id))
                     self.active_connections[peer] = t
             await self.new_peer_or_finished()
             self.cleanup_active()
@@ -106,6 +109,7 @@ class BlobDownloader:
     def close(self):
         self.scores.clear()
         self.ignored.clear()
+        self.is_running.clear()
         for transport in self.connections.values():
             transport.close()
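
The new `is_running` event gives `close()` a cooperative way to end `download_blob`'s retry loop: the `while` condition re-checks the event on every pass, so clearing it lets in-flight downloads wind down without relying solely on task cancellation. The pattern in isolation (standalone sketch, not the lbrynet class):

```python
import asyncio

class LoopingWorker:
    def __init__(self):
        self.is_running = asyncio.Event()

    async def run(self, max_passes: int = 100) -> int:
        self.is_running.set()
        passes = 0
        # the loop condition re-checks the event, so close() ends it cleanly
        while passes < max_passes and self.is_running.is_set():
            passes += 1
            await asyncio.sleep(0.05)
        return passes

    def close(self):
        self.is_running.clear()

async def main():
    worker = LoopingWorker()
    task = asyncio.ensure_future(worker.run())
    await asyncio.sleep(0.2)
    worker.close()                 # unblocks the loop instead of cancelling it
    print(await task)              # only a handful of passes ran

asyncio.get_event_loop().run_until_complete(main())
```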

View file

@@ -135,6 +135,15 @@ class PeerManager:

 class KademliaPeer:
+    __slots__ = [
+        'loop',
+        '_node_id',
+        'address',
+        'udp_port',
+        'tcp_port',
+        'protocol_version',
+    ]
+
     def __init__(self, loop: asyncio.BaseEventLoop, address: str, node_id: typing.Optional[bytes] = None,
                  udp_port: typing.Optional[int] = None, tcp_port: typing.Optional[int] = None):
         if node_id is not None:

View file

@@ -18,7 +18,7 @@ from torba.client.baseaccount import SingleKey, HierarchicalDeterministic
 from lbrynet import utils
 from lbrynet.conf import Config, Setting
-from lbrynet.blob.blob_file import is_valid_blobhash
+from lbrynet.blob.blob_file import is_valid_blobhash, BlobBuffer
 from lbrynet.blob_exchange.downloader import download_blob
 from lbrynet.error import DownloadSDTimeout, ComponentsNotStarted
 from lbrynet.error import NullFundsError, NegativeFundsError, ComponentStartConditionNotMet
@@ -433,15 +433,20 @@ class Daemon(metaclass=JSONRPCServerType):
         self.component_startup_task = asyncio.create_task(self.component_manager.start())
         await self.component_startup_task

-    async def stop(self):
+    async def stop(self, shutdown_runner=True):
         if self.component_startup_task is not None:
             if self.component_startup_task.done():
                 await self.component_manager.stop()
             else:
                 self.component_startup_task.cancel()
+        log.info("stopped api components")
+        if shutdown_runner:
+            await self.runner.shutdown()
         await self.runner.cleanup()
+        log.info("stopped api server")
         if self.analytics_manager.is_started:
             self.analytics_manager.stop()
+        log.info("finished shutting down")

     async def handle_old_jsonrpc(self, request):
         data = await request.json()
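
The `shutdown_runner` flag exists for callers that own the aiohttp runner's lifecycle themselves; the integration test teardowns later in this PR pass `shutdown_runner=False` so only `runner.cleanup()` runs. Usage, in brief:

```python
# normal daemon shutdown: close the aiohttp runner, then clean up
await daemon.stop()

# test harnesses that tear the server down themselves skip the runner shutdown
await daemon.stop(shutdown_runner=False)
```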
@@ -472,64 +477,20 @@ class Daemon(metaclass=JSONRPCServerType):
         else:
             name, claim_id = name_and_claim_id.split("/")
             uri = f"lbry://{name}#{claim_id}"
+        if not self.stream_manager.started.is_set():
+            await self.stream_manager.started.wait()
         stream = await self.jsonrpc_get(uri)
         if isinstance(stream, dict):
             raise web.HTTPServerError(text=stream['error'])
         raise web.HTTPFound(f"/stream/{stream.sd_hash}")

-    @staticmethod
-    def prepare_range_response_headers(get_range: str, stream: 'ManagedStream') -> typing.Tuple[typing.Dict[str, str],
-                                                                                                 int, int]:
-        if '=' in get_range:
-            get_range = get_range.split('=')[1]
-        start, end = get_range.split('-')
-        size = 0
-        for blob in stream.descriptor.blobs[:-1]:
-            size += blob.length - 1
-        start = int(start)
-        end = int(end) if end else size - 1
-        skip_blobs = start // 2097150
-        skip = skip_blobs * 2097151
-        start = skip
-        final_size = end - start + 1
-        headers = {
-            'Accept-Ranges': 'bytes',
-            'Content-Range': f'bytes {start}-{end}/{size}',
-            'Content-Length': str(final_size),
-            'Content-Type': stream.mime_type
-        }
-        return headers, size, skip_blobs
-
     async def handle_stream_range_request(self, request: web.Request):
         sd_hash = request.path.split("/stream/")[1]
+        if not self.stream_manager.started.is_set():
+            await self.stream_manager.started.wait()
         if sd_hash not in self.stream_manager.streams:
             return web.HTTPNotFound()
-        stream = self.stream_manager.streams[sd_hash]
-        if stream.status == 'stopped':
-            await self.stream_manager.start_stream(stream)
-        if stream.delayed_stop:
-            stream.delayed_stop.cancel()
-        headers, size, skip_blobs = self.prepare_range_response_headers(
-            request.headers.get('range', 'bytes=0-'), stream
-        )
-        response = web.StreamResponse(
-            status=206,
-            headers=headers
-        )
-        await response.prepare(request)
-        wrote = 0
-        async for blob_info, decrypted in stream.aiter_read_stream(skip_blobs):
-            log.info("streamed blob %i/%i", blob_info.blob_num + 1, len(stream.descriptor.blobs) - 1)
-            if (blob_info.blob_num == len(stream.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size):
-                decrypted += b'\x00' * (size - len(decrypted) - wrote)
-                await response.write_eof(decrypted)
-                break
-            else:
-                await response.write(decrypted)
-                wrote += len(decrypted)
-        response.force_close()
-        return response
+        return await self.stream_manager.stream_partial_content(request, sd_hash)

     async def _process_rpc_call(self, data):
         args = data.get('params', {})
@@ -907,27 +868,30 @@ class Daemon(metaclass=JSONRPCServerType):
     @requires(WALLET_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT, BLOB_COMPONENT, DATABASE_COMPONENT,
               STREAM_MANAGER_COMPONENT,
               conditions=[WALLET_IS_UNLOCKED])
-    async def jsonrpc_get(self, uri, file_name=None, timeout=None, save_file=None):
+    async def jsonrpc_get(self, uri, file_name=None, download_directory=None, timeout=None, save_file=None):
         """
         Download stream from a LBRY name.

         Usage:
-            get <uri> [<file_name> | --file_name=<file_name>] [<timeout> | --timeout=<timeout>]
+            get <uri> [<file_name> | --file_name=<file_name>]
+             [<download_directory> | --download_directory=<download_directory>] [<timeout> | --timeout=<timeout>]
              [--save_file=<save_file>]

         Options:
             --uri=<uri>              : (str) uri of the content to download
             --file_name=<file_name>  : (str) specified name for the downloaded file, overrides the stream file name
+            --download_directory=<download_directory>  : (str) full path to the directory to download into
             --timeout=<timeout>      : (int) download timeout in number of seconds
             --save_file=<save_file>  : (bool) save the file to the downloads directory

         Returns: {File}
         """
-        save_file = save_file if save_file is not None else self.conf.save_files
+        if download_directory and not os.path.isdir(download_directory):
+            return {"error": f"specified download directory \"{download_directory}\" does not exist"}
         try:
             stream = await self.stream_manager.download_stream_from_uri(
-                uri, self.exchange_rate_manager, timeout, file_name, save_file=save_file
+                uri, self.exchange_rate_manager, timeout, file_name, download_directory, save_file=save_file
             )
             if not stream:
                 raise DownloadSDTimeout(uri)
@@ -1551,10 +1515,10 @@ class Daemon(metaclass=JSONRPCServerType):
             raise Exception(f'Unable to find a file for {kwargs}')
         stream = streams[0]
         if status == 'start' and not stream.running:
-            await self.stream_manager.start_stream(stream)
+            await stream.save_file(node=self.stream_manager.node)
             msg = "Resumed download"
         elif status == 'stop' and stream.running:
-            await self.stream_manager.stop_stream(stream)
+            await stream.stop()
             msg = "Stopped download"
         else:
             msg = (
@@ -2927,9 +2891,11 @@ class Daemon(metaclass=JSONRPCServerType):
         blob = await download_blob(asyncio.get_event_loop(), self.conf, self.blob_manager, self.dht_node, blob_hash)
         if read:
-            with open(blob.file_path, 'rb') as handle:
+            with blob.reader_context() as handle:
                 return handle.read().decode()
-        else:
+        elif isinstance(blob, BlobBuffer):
+            log.warning("manually downloaded blob buffer could have missed garbage collection, clearing it")
+            blob.delete()
         return "Downloaded blob %s" % blob_hash

     @requires(BLOB_COMPONENT, DATABASE_COMPONENT)
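
`blob_get` previously assumed a downloaded blob always landed on disk (`open(blob.file_path)`), which breaks when blobs are not being saved and the download yields an in-memory `BlobBuffer`; `reader_context()` reads either kind, and buffer blobs are deleted explicitly since no manager will garbage-collect them. A hedged sketch of that polymorphism with simplified stand-in classes (not the real `AbstractBlob` API):

```python
import contextlib
from io import BytesIO

class DiskBlob:
    """Stand-in for a blob persisted to the blob directory."""
    def __init__(self, path: str):
        self.file_path = path

    @contextlib.contextmanager
    def reader_context(self):
        with open(self.file_path, 'rb') as handle:
            yield handle

class BufferBlob:
    """Stand-in for a blob held only in memory (blob saving disabled)."""
    def __init__(self, data: bytes):
        self._data = data

    @contextlib.contextmanager
    def reader_context(self):
        yield BytesIO(self._data)

    def delete(self):
        self._data = None   # nothing on disk: dropping the bytes is the cleanup

def read_any_blob(blob) -> bytes:
    # one call site handles both disk-backed and memory-backed blobs
    with blob.reader_context() as handle:
        return handle.read()

print(read_any_blob(BufferBlob(b'decrypted bytes')))
```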

View file

@@ -47,6 +47,17 @@ def file_reader(file_path: str):

 class StreamDescriptor:
+    __slots__ = [
+        'loop',
+        'blob_dir',
+        'stream_name',
+        'key',
+        'suggested_file_name',
+        'blobs',
+        'stream_hash',
+        'sd_hash'
+    ]
+
     def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, stream_name: str, key: str,
                  suggested_file_name: str, blobs: typing.List[BlobInfo], stream_hash: typing.Optional[str] = None,
                  sd_hash: typing.Optional[str] = None):

View file

@@ -58,14 +58,14 @@ class StreamDownloader:
             self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _delayed_add_fixed_peers)

-    async def load_descriptor(self):
+    async def load_descriptor(self, connection_id: int = 0):
         # download or get the sd blob
         sd_blob = self.blob_manager.get_blob(self.sd_hash)
         if not sd_blob.get_is_verified():
             try:
                 now = self.loop.time()
                 sd_blob = await asyncio.wait_for(
-                    self.blob_downloader.download_blob(self.sd_hash),
+                    self.blob_downloader.download_blob(self.sd_hash, connection_id),
                     self.config.blob_download_timeout, loop=self.loop
                 )
                 log.info("downloaded sd blob %s", self.sd_hash)
@@ -79,7 +79,7 @@ class StreamDownloader:
         )
         log.info("loaded stream manifest %s", self.sd_hash)

-    async def start(self, node: typing.Optional['Node'] = None):
+    async def start(self, node: typing.Optional['Node'] = None, connection_id: int = 0):
         # set up peer accumulation
         if node:
             self.node = node
@@ -90,7 +90,7 @@ class StreamDownloader:
         log.info("searching for peers for stream %s", self.sd_hash)

         if not self.descriptor:
-            await self.load_descriptor()
+            await self.load_descriptor(connection_id)

         # add the head blob to the peer search
         self.search_queue.put_nowait(self.descriptor.blobs[0].blob_hash)
@@ -101,10 +101,10 @@ class StreamDownloader:
             self.blob_manager.get_blob(self.sd_hash, length=self.descriptor.length), self.descriptor
         )

-    async def download_stream_blob(self, blob_info: 'BlobInfo') -> 'AbstractBlob':
+    async def download_stream_blob(self, blob_info: 'BlobInfo', connection_id: int = 0) -> 'AbstractBlob':
         if not filter(lambda blob: blob.blob_hash == blob_info.blob_hash, self.descriptor.blobs[:-1]):
             raise ValueError(f"blob {blob_info.blob_hash} is not part of stream with sd hash {self.sd_hash}")
-        blob = await self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length)
+        blob = await self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length, connection_id)
         return blob

     def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes:
@@ -112,11 +112,11 @@ class StreamDownloader:
             binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode())
         )

-    async def read_blob(self, blob_info: 'BlobInfo') -> bytes:
+    async def read_blob(self, blob_info: 'BlobInfo', connection_id: int = 0) -> bytes:
         start = None
         if self.time_to_first_bytes is None:
             start = self.loop.time()
-        blob = await self.download_stream_blob(blob_info)
+        blob = await self.download_stream_blob(blob_info, connection_id)
         decrypted = self.decrypt_blob(blob_info, blob)
         if start:
             self.time_to_first_bytes = self.loop.time() - start
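
One pre-existing wart carried through unchanged in `download_stream_blob`: in Python 3, `filter()` returns a lazy iterator, and iterator objects are always truthy, so `if not filter(...)` can never fire and the `ValueError` guard is dead code. A quick demonstration, plus the membership test that would actually work:

```python
blobs = []  # pretend no blob in the stream matches the requested hash

# filter() yields a lazy iterator in Python 3, and objects are truthy:
print(bool(filter(lambda b: False, blobs)))   # True -> `if not filter(...)` never triggers

# a guard that really checks membership would use any():
print(any(b == 'deadbeef' for b in blobs))    # False -> `if not any(...)` raises as intended
```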

View file

@@ -3,7 +3,9 @@ import asyncio
 import typing
 import logging
 import binascii
+from aiohttp.web import Request, StreamResponse
 from lbrynet.utils import generate_id
+from lbrynet.error import DownloadSDTimeout
 from lbrynet.schema.mime_types import guess_media_type
 from lbrynet.stream.downloader import StreamDownloader
 from lbrynet.stream.descriptor import StreamDescriptor
@@ -40,6 +42,33 @@ class ManagedStream:
     STATUS_STOPPED = "stopped"
     STATUS_FINISHED = "finished"

+    __slots__ = [
+        'loop',
+        'config',
+        'blob_manager',
+        'sd_hash',
+        'download_directory',
+        '_file_name',
+        '_status',
+        'stream_claim_info',
+        'download_id',
+        'rowid',
+        'written_bytes',
+        'content_fee',
+        'downloader',
+        'analytics_manager',
+        'fully_reflected',
+        'file_output_task',
+        'delayed_stop_task',
+        'streaming_responses',
+        'streaming',
+        '_running',
+        'saving',
+        'finished_writing',
+        'started_writing',
+    ]
+
     def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager',
                  sd_hash: str, download_directory: typing.Optional[str] = None, file_name: typing.Optional[str] = None,
                  status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredStreamClaim] = None,
@@ -61,9 +90,13 @@ class ManagedStream:
         self.content_fee = content_fee
         self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor)
         self.analytics_manager = analytics_manager
+
         self.fully_reflected = asyncio.Event(loop=self.loop)
         self.file_output_task: typing.Optional[asyncio.Task] = None
-        self.delayed_stop: typing.Optional[asyncio.Handle] = None
+        self.delayed_stop_task: typing.Optional[asyncio.Task] = None
+        self.streaming_responses: typing.List[typing.Tuple[Request, StreamResponse]] = []
+        self.streaming = asyncio.Event(loop=self.loop)
+        self._running = asyncio.Event(loop=self.loop)
         self.saving = asyncio.Event(loop=self.loop)
         self.finished_writing = asyncio.Event(loop=self.loop)
         self.started_writing = asyncio.Event(loop=self.loop)
@@ -84,9 +117,10 @@ class ManagedStream:
     def status(self) -> str:
         return self._status

-    def update_status(self, status: str):
+    async def update_status(self, status: str):
         assert status in [self.STATUS_RUNNING, self.STATUS_STOPPED, self.STATUS_FINISHED]
         self._status = status
+        await self.blob_manager.storage.change_file_status(self.stream_hash, status)

     @property
     def finished(self) -> bool:
@@ -164,14 +198,12 @@ class ManagedStream:
         return guess_media_type(os.path.basename(self.descriptor.suggested_file_name))[0]

     def as_dict(self) -> typing.Dict:
-        if self.written_bytes:
-            written_bytes = self.written_bytes
-        elif self.output_file_exists:
+        if not self.written_bytes and self.output_file_exists:
             written_bytes = os.stat(self.full_path).st_size
         else:
-            written_bytes = None
+            written_bytes = self.written_bytes
         return {
-            'completed': self.finished,
+            'completed': self.output_file_exists and self.status in ('stopped', 'finished'),
             'file_name': self.file_name,
             'download_directory': self.download_directory,
             'points_paid': 0.0,
@@ -218,72 +250,110 @@ class ManagedStream:
         return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path),
                    os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor)

-    async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True,
-                    file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None):
-        await self.downloader.start(node)
-        if not save_file and not file_name:
-            if not await self.blob_manager.storage.file_exists(self.sd_hash):
-                self.rowid = await self.blob_manager.storage.save_downloaded_file(
-                    self.stream_hash, None, None, 0.0
-                )
-                self.download_directory = None
-                self._file_name = None
-            self.update_status(ManagedStream.STATUS_RUNNING)
-            await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING)
-            self.update_delayed_stop()
-        else:
-            await self.save_file(file_name, download_directory)
-            await self.started_writing.wait()
-
-    def update_delayed_stop(self):
-        def _delayed_stop():
-            log.info("Stopping inactive download for stream %s", self.sd_hash)
-            self.stop_download()
-        if self.delayed_stop:
-            self.delayed_stop.cancel()
-        self.delayed_stop = self.loop.call_later(60, _delayed_stop)
+    async def start(self, node: typing.Optional['Node'] = None, timeout: typing.Optional[float] = None,
+                    save_now: bool = False):
+        timeout = timeout or self.config.download_timeout
+        if self._running.is_set():
+            return
+        log.info("start downloader for stream (sd hash: %s)", self.sd_hash)
+        self._running.set()
+        try:
+            await asyncio.wait_for(self.downloader.start(node), timeout, loop=self.loop)
+        except asyncio.TimeoutError:
+            self._running.clear()
+            raise DownloadSDTimeout(self.sd_hash)
+        if self.delayed_stop_task and not self.delayed_stop_task.done():
+            self.delayed_stop_task.cancel()
+        self.delayed_stop_task = self.loop.create_task(self._delayed_stop())
+        if not await self.blob_manager.storage.file_exists(self.sd_hash):
+            if save_now:
+                file_name, download_dir = self._file_name, self.download_directory
+            else:
+                file_name, download_dir = None, None
+            self.rowid = await self.blob_manager.storage.save_downloaded_file(
+                self.stream_hash, file_name, download_dir, 0.0
+            )
+        if self.status != self.STATUS_RUNNING:
+            await self.update_status(self.STATUS_RUNNING)
+
+    async def stop(self, finished: bool = False):
+        """
+        Stop any running save/stream tasks as well as the downloader and update the status in the database
+        """
+        self.stop_tasks()
+        if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING:
+            await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED)

-    async def aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0) -> typing.AsyncIterator[
-            typing.Tuple['BlobInfo', bytes]]:
+    async def _aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0, connection_id: int = 0)\
+            -> typing.AsyncIterator[typing.Tuple['BlobInfo', bytes]]:
         if start_blob_num >= len(self.descriptor.blobs[:-1]):
             raise IndexError(start_blob_num)
         for i, blob_info in enumerate(self.descriptor.blobs[start_blob_num:-1]):
             assert i + start_blob_num == blob_info.blob_num
-            if self.delayed_stop:
-                self.delayed_stop.cancel()
-            try:
-                decrypted = await self.downloader.read_blob(blob_info)
-                yield (blob_info, decrypted)
-            except asyncio.CancelledError:
-                if not self.saving.is_set() and not self.finished_writing.is_set():
-                    self.update_delayed_stop()
-                raise
+            decrypted = await self.downloader.read_blob(blob_info, connection_id)
+            yield (blob_info, decrypted)
+
+    async def stream_file(self, request: Request, node: typing.Optional['Node'] = None) -> StreamResponse:
+        log.info("stream file to browser for lbry://%s#%s (sd hash %s...)", self.claim_name, self.claim_id,
+                 self.sd_hash[:6])
+        await self.start(node)
+        headers, size, skip_blobs = self._prepare_range_response_headers(request.headers.get('range', 'bytes=0-'))
+        response = StreamResponse(
+            status=206,
+            headers=headers
+        )
+        await response.prepare(request)
+        self.streaming_responses.append((request, response))
+        self.streaming.set()
+        try:
+            wrote = 0
+            async for blob_info, decrypted in self._aiter_read_stream(skip_blobs, connection_id=2):
+                if (blob_info.blob_num == len(self.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size):
+                    decrypted += (b'\x00' * (size - len(decrypted) - wrote - (skip_blobs * 2097151)))
+                    await response.write_eof(decrypted)
+                else:
+                    await response.write(decrypted)
+                wrote += len(decrypted)
+                log.info("sent browser %sblob %i/%i", "(final) " if response._eof_sent else "",
+                         blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
+                if response._eof_sent:
+                    break
+            return response
+        finally:
+            response.force_close()
+            if (request, response) in self.streaming_responses:
+                self.streaming_responses.remove((request, response))
+            if not self.streaming_responses:
+                self.streaming.clear()

     async def _save_file(self, output_path: str):
-        log.debug("save file %s -> %s", self.sd_hash, output_path)
+        log.info("save file for lbry://%s#%s (sd hash %s...) -> %s", self.claim_name, self.claim_id, self.sd_hash[:6],
+                 output_path)
         self.saving.set()
         self.finished_writing.clear()
         self.started_writing.clear()
         try:
             with open(output_path, 'wb') as file_write_handle:
-                async for blob_info, decrypted in self.aiter_read_stream():
+                async for blob_info, decrypted in self._aiter_read_stream(connection_id=1):
                     log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
                     file_write_handle.write(decrypted)
                     file_write_handle.flush()
                     self.written_bytes += len(decrypted)
                     if not self.started_writing.is_set():
                         self.started_writing.set()
-            self.update_status(ManagedStream.STATUS_FINISHED)
-            await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_FINISHED)
+            await self.update_status(ManagedStream.STATUS_FINISHED)
             if self.analytics_manager:
                 self.loop.create_task(self.analytics_manager.send_download_finished(
                     self.download_id, self.claim_name, self.sd_hash
                 ))
             self.finished_writing.set()
+            log.info("finished saving file for lbry://%s#%s (sd hash %s...) -> %s", self.claim_name, self.claim_id,
+                     self.sd_hash[:6], self.full_path)
         except Exception as err:
             if os.path.isfile(output_path):
-                log.info("removing incomplete download %s for %s", output_path, self.sd_hash)
+                log.warning("removing incomplete download %s for %s", output_path, self.sd_hash)
                 os.remove(output_path)
             if not isinstance(err, asyncio.CancelledError):
                 log.exception("unexpected error encountered writing file for stream %s", self.sd_hash)
@@ -291,12 +361,11 @@ class ManagedStream:
         finally:
             self.saving.clear()

-    async def save_file(self, file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None):
-        if self.file_output_task and not self.file_output_task.done():
+    async def save_file(self, file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None,
+                        node: typing.Optional['Node'] = None):
+        await self.start(node)
+        if self.file_output_task and not self.file_output_task.done():  # cancel an already running save task
             self.file_output_task.cancel()
-        if self.delayed_stop:
-            self.delayed_stop.cancel()
-            self.delayed_stop = None
         self.download_directory = download_directory or self.download_directory or self.config.download_dir
         if not self.download_directory:
             raise ValueError("no directory to download to")
@@ -305,28 +374,28 @@ class ManagedStream:
         if not os.path.isdir(self.download_directory):
             log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory)
             os.mkdir(self.download_directory)
-        if not await self.blob_manager.storage.file_exists(self.sd_hash):
-            self._file_name = await get_next_available_file_name(
-                self.loop, self.download_directory,
-                file_name or self._file_name or self.descriptor.suggested_file_name
-            )
-            self.rowid = self.blob_manager.storage.save_downloaded_file(
-                self.stream_hash, self.file_name, self.download_directory, 0.0
-            )
-        else:
-            await self.blob_manager.storage.change_file_download_dir_and_file_name(
-                self.stream_hash, self.download_directory, self.file_name
-            )
-        self.update_status(ManagedStream.STATUS_RUNNING)
-        await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING)
+        self._file_name = await get_next_available_file_name(
+            self.loop, self.download_directory,
+            file_name or self._file_name or self.descriptor.suggested_file_name
+        )
+        await self.blob_manager.storage.change_file_download_dir_and_file_name(
+            self.stream_hash, self.download_directory, self.file_name
+        )
+        await self.update_status(ManagedStream.STATUS_RUNNING)
         self.written_bytes = 0
         self.file_output_task = self.loop.create_task(self._save_file(self.full_path))
+        await self.started_writing.wait()

-    def stop_download(self):
+    def stop_tasks(self):
         if self.file_output_task and not self.file_output_task.done():
             self.file_output_task.cancel()
         self.file_output_task = None
+        while self.streaming_responses:
+            req, response = self.streaming_responses.pop()
+            response.force_close()
+            req.transport.close()
         self.downloader.stop()
+        self._running.clear()
async def upload_to_reflector(self, host: str, port: int) -> typing.List[str]: async def upload_to_reflector(self, host: str, port: int) -> typing.List[str]:
sent = [] sent = []
@ -367,3 +436,44 @@ class ManagedStream:
binascii.hexlify(claim.to_bytes()).decode(), claim.signing_channel_id, claim_info['address'], binascii.hexlify(claim.to_bytes()).decode(), claim.signing_channel_id, claim_info['address'],
claim_info['claim_sequence'], claim_info.get('channel_name') claim_info['claim_sequence'], claim_info.get('channel_name')
) )
async def update_content_claim(self, claim_info: typing.Optional[typing.Dict] = None):
if not claim_info:
claim_info = await self.blob_manager.storage.get_content_claim(self.stream_hash)
self.set_claim(claim_info, claim_info['value'])
async def _delayed_stop(self):
stalled_count = 0
while self._running.is_set():
if self.saving.is_set() or self.streaming.is_set():
stalled_count = 0
else:
stalled_count += 1
if stalled_count > 1:
log.info("stopping inactive download for lbry://%s#%s (%s...)", self.claim_name, self.claim_id,
self.sd_hash[:6])
await self.stop()
return
await asyncio.sleep(1, loop=self.loop)
def _prepare_range_response_headers(self, get_range: str) -> typing.Tuple[typing.Dict[str, str], int, int]:
if '=' in get_range:
get_range = get_range.split('=')[1]
start, end = get_range.split('-')
size = 0
for blob in self.descriptor.blobs[:-1]:
size += blob.length - 1
start = int(start)
end = int(end) if end else size - 1
skip_blobs = start // 2097150
skip = skip_blobs * 2097151
start = skip
final_size = end - start + 1
headers = {
'Accept-Ranges': 'bytes',
'Content-Range': f'bytes {start}-{end}/{size}',
'Content-Length': str(final_size),
'Content-Type': self.mime_type
}
return headers, size, skip_blobs
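
The constants in `_prepare_range_response_headers` (moved here from the daemon) encode the blob layout: the advertised plaintext size sums `blob.length - 1` over the data blobs, and seeking snaps back to a whole-blob boundary, with the divisor and multiplier deliberately differing by one to mirror that per-blob one-byte accounting. A worked example of the arithmetic, assuming a stream of five max-size blobs where each `blob.length` is 2097152:

```python
# seek arithmetic from _prepare_range_response_headers, on assumed inputs
blob_lengths = [2097152] * 5                 # five max-size data blobs (assumption)

size = sum(length - 1 for length in blob_lengths)   # 10485755 bytes advertised

start, end = 4500000, 0                      # e.g. header "Range: bytes=4500000-"
end = end if end else size - 1               # open-ended range runs to the last byte

skip_blobs = start // 2097150                # 2 whole blobs precede the offset
start = skip_blobs * 2097151                 # 4194302: snapped back to a blob boundary
final_size = end - start + 1                 # 6291453 bytes in the 206 response

print(size, skip_blobs, start, final_size)   # 10485755 2 4194302 6291453
```

Note that the snap can move `start` earlier than the byte the client asked for; the `Content-Range` header reports the adjusted offset rather than the requested one.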

View file

@@ -5,8 +5,9 @@ import binascii
 import logging
 import random
 from decimal import Decimal
+from aiohttp.web import Request
 from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError
-from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout, ResolveTimeout
+from lbrynet.error import ResolveTimeout, DownloadDataTimeout
 from lbrynet.utils import cache_concurrent
 from lbrynet.stream.descriptor import StreamDescriptor
 from lbrynet.stream.managed_stream import ManagedStream
@@ -56,6 +57,7 @@ comparison_operators = {
 def path_or_none(p) -> typing.Optional[str]:
     return None if p == '{stream}' else binascii.unhexlify(p).decode()

+
 class StreamManager:
     def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager',
                  wallet: 'LbryWalletManager', storage: 'SQLiteStorage', node: typing.Optional['Node'],
@@ -72,29 +74,12 @@ class StreamManager:
         self.re_reflect_task: asyncio.Task = None
         self.update_stream_finished_futs: typing.List[asyncio.Future] = []
         self.running_reflector_uploads: typing.List[asyncio.Task] = []
+        self.started = asyncio.Event(loop=self.loop)

     async def _update_content_claim(self, stream: ManagedStream):
         claim_info = await self.storage.get_content_claim(stream.stream_hash)
         self.streams.setdefault(stream.sd_hash, stream).set_claim(claim_info, claim_info['value'])

-    async def stop_stream(self, stream: ManagedStream):
-        stream.stop_download()
-        if not stream.finished and stream.output_file_exists:
-            try:
-                os.remove(stream.full_path)
-            except OSError as err:
-                log.warning("Failed to delete partial download %s from downloads directory: %s", stream.full_path,
-                            str(err))
-        if stream.running:
-            stream.update_status(ManagedStream.STATUS_STOPPED)
-            await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED)
-
-    async def start_stream(self, stream: ManagedStream):
-        stream.update_status(ManagedStream.STATUS_RUNNING)
-        await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_RUNNING)
-        await stream.setup(self.node, save_file=self.config.save_files)
-        self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
-
     async def recover_streams(self, file_infos: typing.List[typing.Dict]):
         to_restore = []
@@ -139,28 +124,19 @@ class StreamManager:
     async def load_streams_from_database(self):
         to_recover = []
         to_start = []
-        await self.storage.sync_files_to_blobs()
+        # this will set streams marked as finished and are missing blobs as being stopped
+        # await self.storage.sync_files_to_blobs()
         for file_info in await self.storage.get_all_lbry_files():
+            # if the sd blob is not verified, try to reconstruct it from the database
+            # this could either be because the blob files were deleted manually or save_blobs was not true when
+            # the stream was downloaded
             if not self.blob_manager.is_blob_verified(file_info['sd_hash']):
                 to_recover.append(file_info)
             to_start.append(file_info)
         if to_recover:
+            # if self.blob_manager._save_blobs:
+            #     log.info("Attempting to recover %i streams", len(to_recover))
             await self.recover_streams(to_recover)
-        if not self.config.save_files:
-            to_set_as_streaming = []
-            for file_info in to_start:
-                file_name = path_or_none(file_info['file_name'])
-                download_dir = path_or_none(file_info['download_directory'])
-                if file_name and download_dir and not os.path.isfile(os.path.join(file_name, download_dir)):
-                    file_info['file_name'], file_info['download_directory'] = '{stream}', '{stream}'
-                    to_set_as_streaming.append(file_info['stream_hash'])
-            if to_set_as_streaming:
-                await self.storage.set_files_as_streaming(to_set_as_streaming)

         log.info("Initializing %i files", len(to_start))
         if to_start:
             await asyncio.gather(*[
@@ -176,8 +152,11 @@ class StreamManager:
         if not self.node:
             log.warning("no DHT node given, resuming downloads trusting that we can contact reflector")
         t = [
-            self.loop.create_task(self.start_stream(stream)) for stream in self.streams.values()
-            if stream.running
+            self.loop.create_task(
+                stream.start(node=self.node, save_now=(stream.full_path is not None))
+                if not stream.full_path else
+                stream.save_file(node=self.node)
+            ) for stream in self.streams.values() if stream.running
         ]
         if t:
             log.info("resuming %i downloads", len(t))
@@ -206,6 +185,7 @@ class StreamManager:
         await self.load_streams_from_database()
         self.resume_downloading_task = self.loop.create_task(self.resume())
         self.re_reflect_task = self.loop.create_task(self.reflect_streams())
+        self.started.set()

     def stop(self):
         if self.resume_downloading_task and not self.resume_downloading_task.done():
@@ -214,11 +194,13 @@ class StreamManager:
             self.re_reflect_task.cancel()
         while self.streams:
             _, stream = self.streams.popitem()
-            stream.stop_download()
+            stream.stop_tasks()
         while self.update_stream_finished_futs:
             self.update_stream_finished_futs.pop().cancel()
         while self.running_reflector_uploads:
             self.running_reflector_uploads.pop().cancel()
+        self.started.clear()
+        log.info("finished stopping the stream manager")

     async def create_stream(self, file_path: str, key: typing.Optional[bytes] = None,
                             iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None) -> ManagedStream:
@@ -236,7 +218,7 @@ class StreamManager:
         return stream

     async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False):
-        await self.stop_stream(stream)
+        stream.stop_tasks()
         if stream.sd_hash in self.streams:
             del self.streams[stream.sd_hash]
         blob_hashes = [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]]
@@ -290,21 +272,16 @@ class StreamManager:
                                                        typing.Optional[ManagedStream], typing.Optional[ManagedStream]]:
         existing = self.get_filtered_streams(outpoint=outpoint)
         if existing:
-            if existing[0].status == ManagedStream.STATUS_STOPPED:
-                await self.start_stream(existing[0])
             return existing[0], None
         existing = self.get_filtered_streams(sd_hash=claim.stream.source.sd_hash)
         if existing and existing[0].claim_id != claim_id:
-            raise ResolveError(f"stream for {existing[0].claim_id} collides with existing "
-                               f"download {claim_id}")
+            raise ResolveError(f"stream for {existing[0].claim_id} collides with existing download {claim_id}")
         if existing:
             log.info("claim contains a metadata only update to a stream we have")
             await self.storage.save_content_claim(
                 existing[0].stream_hash, outpoint
             )
             await self._update_content_claim(existing[0])
-            if not existing[0].running:
-                await self.start_stream(existing[0])
             return existing[0], None
         else:
             existing_for_claim_id = self.get_filtered_streams(claim_id=claim_id)
@@ -318,13 +295,23 @@ class StreamManager:
                                        timeout: typing.Optional[float] = None,
                                        file_name: typing.Optional[str] = None,
                                        download_directory: typing.Optional[str] = None,
-                                       save_file: bool = True, resolve_timeout: float = 3.0) -> ManagedStream:
+                                       save_file: typing.Optional[bool] = None,
+                                       resolve_timeout: float = 3.0) -> ManagedStream:
         timeout = timeout or self.config.download_timeout
         start_time = self.loop.time()
         resolved_time = None
         stream = None
         error = None
         outpoint = None
+        if save_file is None:
+            save_file = self.config.save_files
+        if file_name and not save_file:
+            save_file = True
+        if save_file:
+            download_directory = download_directory or self.config.download_dir
+        else:
+            download_directory = None
         try:
             # resolve the claim
             parsed_uri = parse_lbry_uri(uri)
@@ -351,9 +338,18 @@ class StreamManager:
             # resume or update an existing stream, if the stream changed download it and delete the old one after
             updated_stream, to_replace = await self._check_update_or_replace(outpoint, resolved['claim_id'], claim)
             if updated_stream:
+                log.info("already have stream for %s", uri)
+                if save_file and updated_stream.output_file_exists:
+                    save_file = False
+                await updated_stream.start(node=self.node, timeout=timeout, save_now=save_file)
+                if not updated_stream.output_file_exists and (save_file or file_name or download_directory):
+                    await updated_stream.save_file(
+                        file_name=file_name, download_directory=download_directory, node=self.node
+                    )
                 return updated_stream

             content_fee = None
+            fee_amount, fee_address = None, None

             # check that the fee is payable
             if not to_replace and claim.stream.has_fee:
@@ -374,45 +370,37 @@ class StreamManager:
                     raise InsufficientFundsError(msg)
                 fee_address = claim.stream.fee.address
-                content_fee = await self.wallet.send_amount_to_address(
-                    lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1')
-                )
-
-                log.info("paid fee of %s for %s", fee_amount, uri)
-
-            download_directory = download_directory or self.config.download_dir
-            if not file_name and (not self.config.save_files or not save_file):
-                download_dir, file_name = None, None
             stream = ManagedStream(
                 self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_directory,
                 file_name, ManagedStream.STATUS_RUNNING, content_fee=content_fee,
                 analytics_manager=self.analytics_manager
             )
             log.info("starting download for %s", uri)
-            try:
-                await asyncio.wait_for(stream.setup(
-                    self.node, save_file=save_file, file_name=file_name, download_directory=download_directory
-                ), timeout, loop=self.loop)
-            except asyncio.TimeoutError:
-                if not stream.descriptor:
-                    raise DownloadSDTimeout(stream.sd_hash)
-                raise DownloadDataTimeout(stream.sd_hash)
-            finally:
-                if stream.descriptor:
-                    if to_replace:  # delete old stream now that the replacement has started downloading
-                        await self.delete_stream(to_replace)
-                    stream.set_claim(resolved, claim)
-                    await self.storage.save_content_claim(stream.stream_hash, outpoint)
-                    self.streams[stream.sd_hash] = stream
+
+            before_download = self.loop.time()
+            await stream.start(self.node, timeout)
+            stream.set_claim(resolved, claim)
+            if to_replace:  # delete old stream now that the replacement has started downloading
+                await self.delete_stream(to_replace)
+            elif fee_address:
+                stream.content_fee = await self.wallet.send_amount_to_address(
+                    lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1')
+                )
+                log.info("paid fee of %s for %s", fee_amount, uri)
+
+            self.streams[stream.sd_hash] = stream
+            self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
+            await self.storage.save_content_claim(stream.stream_hash, outpoint)
+            if save_file:
+                await asyncio.wait_for(stream.save_file(node=self.node),
+                                       timeout - (self.loop.time() - before_download), loop=self.loop)
             return stream
-        except DownloadDataTimeout as err:  # forgive data timeout, dont delete stream
-            error = err
-            raise
-        except Exception as err:
+        except asyncio.TimeoutError:
+            error = DownloadDataTimeout(stream.sd_hash)
+            raise error
+        except Exception as err:  # forgive data timeout, dont delete stream
             error = err
-            if stream and stream.descriptor:
-                await self.storage.delete_stream(stream.descriptor)
-                await self.blob_manager.delete_blob(stream.sd_hash)
+            raise
         finally:
             if self.analytics_manager and (error or (stream and (stream.downloader.time_to_descriptor or
                                                                  stream.downloader.time_to_first_bytes))):
@@ -432,5 +420,6 @@ class StreamManager:
                     None if not error else error.__class__.__name__
                 )
             )
-        if error:
-            raise error
+
+    async def stream_partial_content(self, request: Request, sd_hash: str):
+        return await self.streams[sd_hash].stream_file(request, self.node)
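
`save_file` switching from `bool = True` to `Optional[bool] = None` lets the config decide by default, with two overrides layered on top: an explicit `file_name` forces saving, and streaming-only downloads get no download directory at all. The branching restated as a pure function (hypothetical helper, for illustration):

```python
def resolve_save_options(save_file, file_name, download_directory,
                         conf_save_files, conf_download_dir):
    """Mirrors the defaulting block at the top of download_stream_from_uri."""
    if save_file is None:
        save_file = conf_save_files              # config decides when caller doesn't
    if file_name and not save_file:
        save_file = True                         # explicit file name implies saving
    if save_file:
        download_directory = download_directory or conf_download_dir
    else:
        download_directory = None                # streaming-only: no on-disk target
    return save_file, download_directory

assert resolve_save_options(None, None, None, False, '/dl') == (False, None)
assert resolve_save_options(None, 'a.mp4', None, False, '/dl') == (True, '/dl')
assert resolve_save_options(True, None, '/tmp', True, '/dl') == (True, '/tmp')
```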

View file

@@ -124,7 +124,7 @@ class CommandTestCase(IntegrationTestCase):
     async def asyncTearDown(self):
         await super().asyncTearDown()
         self.wallet_component._running = False
-        await self.daemon.stop()
+        await self.daemon.stop(shutdown_runner=False)

     async def confirm_tx(self, txid):
         """ Wait for tx to be in mempool, then generate a block, wait for tx to be in a block. """

View file

@@ -29,7 +29,7 @@ class CLIIntegrationTest(AsyncioTestCase):
         await self.daemon.start()

     async def asyncTearDown(self):
-        await self.daemon.stop()
+        await self.daemon.stop(shutdown_runner=False)

     def test_cli_status_command_with_auth(self):
         actual_output = StringIO()

View file

@@ -340,9 +340,11 @@ class RangeRequests(CommandTestCase):
         self.assertTrue(os.path.isfile(path))
         await self._restart_stream_manager()
         stream = self.daemon.jsonrpc_file_list()[0]
-        self.assertIsNotNone(stream.full_path)
+        self.assertIsNone(stream.full_path)
         self.assertFalse(os.path.isfile(path))
+        if wait_for_start_writing:
+            await stream.started_writing.wait()
+            self.assertTrue(os.path.isfile(path))

     async def test_file_save_stop_before_finished_streaming_only_wait_for_start(self):
         return await self.test_file_save_stop_before_finished_streaming_only(wait_for_start_writing=True)

View file

@@ -54,7 +54,7 @@ class AccountSynchronization(AsyncioTestCase):
     async def asyncTearDown(self):
         self.wallet_component._running = False
-        await self.daemon.stop()
+        await self.daemon.stop(shutdown_runner=False)

     @mock.patch('time.time', mock.Mock(return_value=12345))
     def test_sync(self):

View file

@@ -40,7 +40,7 @@ class TestManagedStream(BlobExchangeTestBase):
             self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir
         )

-    async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
+    async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True):
         await self.setup_stream(blob_count)
         mock_node = mock.Mock(spec=Node)
@@ -51,10 +51,11 @@ class TestManagedStream(BlobExchangeTestBase):
             return q2, self.loop.create_task(_task())

         mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
-        await self.stream.setup(mock_node, save_file=True)
+        await self.stream.save_file(node=mock_node)
         await self.stream.finished_writing.wait()
         self.assertTrue(os.path.isfile(self.stream.full_path))
-        self.stream.stop_download()
+        if stop_when_done:
+            await self.stream.stop()
         self.assertTrue(os.path.isfile(self.stream.full_path))
         with open(self.stream.full_path, 'rb') as f:
             self.assertEqual(f.read(), self.stream_bytes)
@@ -62,6 +63,18 @@ class TestManagedStream(BlobExchangeTestBase):

     async def test_transfer_stream(self):
         await self._test_transfer_stream(10)
+        self.assertEqual(self.stream.status, "finished")
+        self.assertFalse(self.stream._running.is_set())
+
+    async def test_delayed_stop(self):
+        await self._test_transfer_stream(10, stop_when_done=False)
+        self.assertEqual(self.stream.status, "finished")
+        self.assertTrue(self.stream._running.is_set())
+        await asyncio.sleep(0.5, loop=self.loop)
+        self.assertTrue(self.stream._running.is_set())
+        await asyncio.sleep(0.6, loop=self.loop)
+        self.assertEqual(self.stream.status, "finished")
+        self.assertFalse(self.stream._running.is_set())

     @unittest.SkipTest
     async def test_transfer_hundred_blob_stream(self):
@@ -85,11 +98,12 @@ class TestManagedStream(BlobExchangeTestBase):
         mock_node.accumulate_peers = _mock_accumulate_peers

-        await self.stream.setup(mock_node, save_file=True)
+        await self.stream.save_file(node=mock_node)
         await self.stream.finished_writing.wait()
         self.assertTrue(os.path.isfile(self.stream.full_path))
         with open(self.stream.full_path, 'rb') as f:
             self.assertEqual(f.read(), self.stream_bytes)
+        await self.stream.stop()
         # self.assertIs(self.server_from_client.tcp_last_down, None)
         # self.assertIsNot(bad_peer.tcp_last_down, None)
@@ -125,7 +139,7 @@ class TestManagedStream(BlobExchangeTestBase):
         with open(os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash), "rb+") as handle:
             handle.truncate()
             handle.flush()
-        await self.stream.setup()
+        await self.stream.save_file()
         await self.stream.finished_writing.wait()
         if corrupt:
             return self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))

View file

@@ -225,7 +225,7 @@ class TestStreamManager(BlobExchangeTestBase):
         )
         self.assertEqual(stored_status, "running")

-        await self.stream_manager.stop_stream(stream)
+        await stream.stop()
         self.assertFalse(stream.finished)
         self.assertFalse(stream.running)
@@ -235,7 +235,7 @@ class TestStreamManager(BlobExchangeTestBase):
         )
         self.assertEqual(stored_status, "stopped")

-        await self.stream_manager.start_stream(stream)
+        await stream.save_file(node=self.stream_manager.node)
         await stream.finished_writing.wait()
         await asyncio.sleep(0, loop=self.loop)
         self.assertTrue(stream.finished)
@@ -337,7 +337,7 @@ class TestStreamManager(BlobExchangeTestBase):
         for blob_hash in [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]]:
             blob_status = await self.client_storage.get_blob_status(blob_hash)
             self.assertEqual('pending', blob_status)
-        self.assertEqual('stopped', self.stream_manager.streams[self.sd_hash].status)
+        self.assertEqual('finished', self.stream_manager.streams[self.sd_hash].status)

         sd_blob = self.client_blob_manager.get_blob(stream.sd_hash)
         self.assertTrue(sd_blob.file_exists)