This commit is contained in:
Jack Robison 2019-03-31 13:42:27 -04:00
parent f125468ebf
commit 3a916a8e8e
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
14 changed files with 381 additions and 441 deletions

View file

@ -32,6 +32,12 @@ class DownloadDataTimeout(Exception):
self.download = download self.download = download
class ResolveTimeout(Exception):
def __init__(self, uri):
super().__init__(f'Failed to resolve "{uri}" within the timeout')
self.uri = uri
class RequestCanceledError(Exception): class RequestCanceledError(Exception):
pass pass

View file

@ -897,7 +897,7 @@ class Daemon(metaclass=JSONRPCServerType):
""" """
try: try:
stream = await self.stream_manager.download_stream_from_uri( stream = await self.stream_manager.download_stream_from_uri(
uri, timeout, self.exchange_rate_manager, file_name uri, self.exchange_rate_manager, timeout, file_name
) )
if not stream: if not stream:
raise DownloadSDTimeout(uri) raise DownloadSDTimeout(uri)

View file

@ -423,6 +423,17 @@ class SQLiteStorage(SQLiteMixin):
} }
return self.db.run(_sync_blobs) return self.db.run(_sync_blobs)
def sync_files_to_blobs(self):
def _sync_blobs(transaction: sqlite3.Connection) -> typing.Set[str]:
transaction.executemany(
"update file set status='stopped' where stream_hash=?",
transaction.execute(
"select distinct sb.stream_hash from stream_blob sb "
"inner join blob b on b.blob_hash=sb.blob_hash and b.status=='pending'"
).fetchall()
)
return self.db.run(_sync_blobs)
# # # # # # # # # stream functions # # # # # # # # # # # # # # # # # # stream functions # # # # # # # # #
async def stream_exists(self, sd_hash: str) -> bool: async def stream_exists(self, sd_hash: str) -> bool:

View file

@ -2,6 +2,7 @@ import asyncio
import typing import typing
import logging import logging
import binascii import binascii
from lbrynet.error import DownloadSDTimeout
from lbrynet.utils import resolve_host from lbrynet.utils import resolve_host
from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.blob_exchange.downloader import BlobDownloader from lbrynet.blob_exchange.downloader import BlobDownloader
@ -32,6 +33,8 @@ class StreamDownloader:
self.fixed_peers_handle: typing.Optional[asyncio.Handle] = None self.fixed_peers_handle: typing.Optional[asyncio.Handle] = None
self.fixed_peers_delay: typing.Optional[float] = None self.fixed_peers_delay: typing.Optional[float] = None
self.added_fixed_peers = False self.added_fixed_peers = False
self.time_to_descriptor: typing.Optional[float] = None
self.time_to_first_bytes: typing.Optional[float] = None
async def add_fixed_peers(self): async def add_fixed_peers(self):
def _delayed_add_fixed_peers(): def _delayed_add_fixed_peers():
@ -59,8 +62,16 @@ class StreamDownloader:
# download or get the sd blob # download or get the sd blob
sd_blob = self.blob_manager.get_blob(self.sd_hash) sd_blob = self.blob_manager.get_blob(self.sd_hash)
if not sd_blob.get_is_verified(): if not sd_blob.get_is_verified():
sd_blob = await self.blob_downloader.download_blob(self.sd_hash) try:
now = self.loop.time()
sd_blob = await asyncio.wait_for(
self.blob_downloader.download_blob(self.sd_hash),
self.config.blob_download_timeout, loop=self.loop
)
log.info("downloaded sd blob %s", self.sd_hash) log.info("downloaded sd blob %s", self.sd_hash)
self.time_to_descriptor = self.loop.time() - now
except asyncio.TimeoutError:
raise DownloadSDTimeout(self.sd_hash)
# parse the descriptor # parse the descriptor
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob( self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
@ -101,12 +112,18 @@ class StreamDownloader:
binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode()) binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode())
) )
async def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob'): async def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes:
return await self.loop.run_in_executor(None, self._decrypt_blob, blob_info, blob) return await self.loop.run_in_executor(None, self._decrypt_blob, blob_info, blob)
async def read_blob(self, blob_info: 'BlobInfo') -> bytes: async def read_blob(self, blob_info: 'BlobInfo') -> bytes:
start = None
if self.time_to_first_bytes is None:
start = self.loop.time()
blob = await self.download_stream_blob(blob_info) blob = await self.download_stream_blob(blob_info)
return await self.decrypt_blob(blob_info, blob) decrypted = await self.decrypt_blob(blob_info, blob)
if start:
self.time_to_first_bytes = self.loop.time() - start
return decrypted
def stop(self): def stop(self):
if self.accumulate_task: if self.accumulate_task:

View file

@ -9,13 +9,13 @@ from lbrynet.stream.downloader import StreamDownloader
from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.reflector.client import StreamReflectorClient from lbrynet.stream.reflector.client import StreamReflectorClient
from lbrynet.extras.daemon.storage import StoredStreamClaim from lbrynet.extras.daemon.storage import StoredStreamClaim
from lbrynet.blob import MAX_BLOB_SIZE
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from lbrynet.conf import Config from lbrynet.conf import Config
from lbrynet.schema.claim import Claim from lbrynet.schema.claim import Claim
from lbrynet.blob.blob_manager import BlobManager from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob.blob_info import BlobInfo from lbrynet.blob.blob_info import BlobInfo
from lbrynet.dht.node import Node from lbrynet.dht.node import Node
from lbrynet.extras.daemon.analytics import AnalyticsManager
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -43,7 +43,8 @@ class ManagedStream:
sd_hash: str, download_directory: typing.Optional[str] = None, file_name: typing.Optional[str] = None, sd_hash: str, download_directory: typing.Optional[str] = None, file_name: typing.Optional[str] = None,
status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredStreamClaim] = None, status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredStreamClaim] = None,
download_id: typing.Optional[str] = None, rowid: typing.Optional[int] = None, download_id: typing.Optional[str] = None, rowid: typing.Optional[int] = None,
descriptor: typing.Optional[StreamDescriptor] = None): descriptor: typing.Optional[StreamDescriptor] = None,
analytics_manager: typing.Optional['AnalyticsManager'] = None):
self.loop = loop self.loop = loop
self.config = config self.config = config
self.blob_manager = blob_manager self.blob_manager = blob_manager
@ -56,11 +57,13 @@ class ManagedStream:
self.rowid = rowid self.rowid = rowid
self.written_bytes = 0 self.written_bytes = 0
self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor) self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor)
self.analytics_manager = analytics_manager
self.fully_reflected = asyncio.Event(loop=self.loop) self.fully_reflected = asyncio.Event(loop=self.loop)
self.file_output_task: typing.Optional[asyncio.Task] = None self.file_output_task: typing.Optional[asyncio.Task] = None
self.delayed_stop: typing.Optional[asyncio.Handle] = None self.delayed_stop: typing.Optional[asyncio.Handle] = None
self.saving = asyncio.Event(loop=self.loop) self.saving = asyncio.Event(loop=self.loop)
self.finished_writing = asyncio.Event(loop=self.loop) self.finished_writing = asyncio.Event(loop=self.loop)
self.started_writing = asyncio.Event(loop=self.loop)
@property @property
def descriptor(self) -> StreamDescriptor: def descriptor(self) -> StreamDescriptor:
@ -217,16 +220,18 @@ class ManagedStream:
return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path), return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path),
os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor) os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor)
async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True): async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True,
file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None):
await self.downloader.start(node) await self.downloader.start(node)
if not save_file: if not save_file and not file_name:
if not await self.blob_manager.storage.file_exists(self.sd_hash): if not await self.blob_manager.storage.file_exists(self.sd_hash):
self.rowid = self.blob_manager.storage.save_downloaded_file( self.rowid = self.blob_manager.storage.save_downloaded_file(
self.stream_hash, None, None, 0.0 self.stream_hash, None, None, 0.0
) )
self.update_delayed_stop() self.update_delayed_stop()
else: else:
await self.save_file() await self.save_file(file_name, download_directory)
await self.started_writing.wait()
self.update_status(ManagedStream.STATUS_RUNNING) self.update_status(ManagedStream.STATUS_RUNNING)
await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING) await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING)
@ -235,7 +240,6 @@ class ManagedStream:
log.info("Stopping inactive download for stream %s", self.sd_hash) log.info("Stopping inactive download for stream %s", self.sd_hash)
self.stop_download() self.stop_download()
log.info("update delayed stop")
if self.delayed_stop: if self.delayed_stop:
self.delayed_stop.cancel() self.delayed_stop.cancel()
self.delayed_stop = self.loop.call_later(60, _delayed_stop) self.delayed_stop = self.loop.call_later(60, _delayed_stop)
@ -259,6 +263,7 @@ class ManagedStream:
async def _save_file(self, output_path: str): async def _save_file(self, output_path: str):
self.saving.set() self.saving.set()
self.finished_writing.clear() self.finished_writing.clear()
self.started_writing.clear()
try: try:
with open(output_path, 'wb') as file_write_handle: with open(output_path, 'wb') as file_write_handle:
async for blob_info, decrypted in self.aiter_read_stream(): async for blob_info, decrypted in self.aiter_read_stream():
@ -266,14 +271,21 @@ class ManagedStream:
file_write_handle.write(decrypted) file_write_handle.write(decrypted)
file_write_handle.flush() file_write_handle.flush()
self.written_bytes += len(decrypted) self.written_bytes += len(decrypted)
if not self.started_writing.is_set():
self.started_writing.set()
self.update_status(ManagedStream.STATUS_FINISHED)
await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_FINISHED)
if self.analytics_manager:
self.loop.create_task(self.analytics_manager.send_download_finished(
self.download_id, self.claim_name, self.sd_hash
))
self.finished_writing.set() self.finished_writing.set()
except Exception as err: except Exception as err:
if os.path.isfile(output_path):
log.info("removing incomplete download %s for %s", output_path, self.sd_hash)
os.remove(output_path)
if not isinstance(err, asyncio.CancelledError): if not isinstance(err, asyncio.CancelledError):
log.exception("unexpected error encountered writing file for stream %s", self.sd_hash) log.exception("unexpected error encountered writing file for stream %s", self.sd_hash)
if os.path.isfile(output_path):
log.info("removing incomplete download %s", output_path)
os.remove(output_path)
raise err raise err
finally: finally:
self.saving.clear() self.saving.clear()
@ -282,10 +294,9 @@ class ManagedStream:
if self.file_output_task and not self.file_output_task.done(): if self.file_output_task and not self.file_output_task.done():
self.file_output_task.cancel() self.file_output_task.cancel()
if self.delayed_stop: if self.delayed_stop:
log.info('cancel delayed stop')
self.delayed_stop.cancel() self.delayed_stop.cancel()
self.delayed_stop = None self.delayed_stop = None
self.download_directory = download_directory or self.download_directory self.download_directory = download_directory or self.download_directory or self.config.download_dir
if not self.download_directory: if not self.download_directory:
raise ValueError("no directory to download to") raise ValueError("no directory to download to")
if not (file_name or self._file_name or self.descriptor.suggested_file_name): if not (file_name or self._file_name or self.descriptor.suggested_file_name):

View file

@ -6,8 +6,8 @@ import logging
import random import random
from decimal import Decimal from decimal import Decimal
from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError
# DownloadDataTimeout, DownloadSDTimeout from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout, ResolveTimeout
from lbrynet.utils import generate_id, cache_concurrent from lbrynet.utils import cache_concurrent
from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.managed_stream import ManagedStream from lbrynet.stream.managed_stream import ManagedStream
from lbrynet.schema.claim import Claim from lbrynet.schema.claim import Claim
@ -96,11 +96,10 @@ class StreamManager:
await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED) await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED)
async def start_stream(self, stream: ManagedStream): async def start_stream(self, stream: ManagedStream):
await stream.setup(self.node, save_file=not self.config.streaming_only)
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
stream.update_status(ManagedStream.STATUS_RUNNING) stream.update_status(ManagedStream.STATUS_RUNNING)
await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_RUNNING) await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_RUNNING)
self.wait_for_stream_finished(stream) await stream.setup(self.node, save_file=not self.config.streaming_only)
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
async def recover_streams(self, file_infos: typing.List[typing.Dict]): async def recover_streams(self, file_infos: typing.List[typing.Dict]):
to_restore = [] to_restore = []
@ -139,13 +138,14 @@ class StreamManager:
return return
stream = ManagedStream( stream = ManagedStream(
self.loop, self.config, self.blob_manager, descriptor.sd_hash, download_directory, file_name, status, self.loop, self.config, self.blob_manager, descriptor.sd_hash, download_directory, file_name, status,
claim, rowid=rowid, descriptor=descriptor claim, rowid=rowid, descriptor=descriptor, analytics_manager=self.analytics_manager
) )
self.streams[sd_hash] = stream self.streams[sd_hash] = stream
async def load_streams_from_database(self): async def load_streams_from_database(self):
to_recover = [] to_recover = []
to_start = [] to_start = []
await self.storage.sync_files_to_blobs()
for file_info in await self.storage.get_all_lbry_files(): for file_info in await self.storage.get_all_lbry_files():
if not self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified(): if not self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified():
to_recover.append(file_info) to_recover.append(file_info)
@ -181,10 +181,10 @@ class StreamManager:
while True: while True:
if self.config.reflect_streams and self.config.reflector_servers: if self.config.reflect_streams and self.config.reflector_servers:
sd_hashes = await self.storage.get_streams_to_re_reflect() sd_hashes = await self.storage.get_streams_to_re_reflect()
streams = list(filter(lambda s: s in sd_hashes, self.streams.keys())) sd_hashes = [sd for sd in sd_hashes if sd in self.streams]
batch = [] batch = []
while streams: while sd_hashes:
stream = streams.pop() stream = self.streams[sd_hashes.pop()]
if not stream.fully_reflected.is_set(): if not stream.fully_reflected.is_set():
host, port = random.choice(self.config.reflector_servers) host, port = random.choice(self.config.reflector_servers)
batch.append(stream.upload_to_reflector(host, port)) batch.append(stream.upload_to_reflector(host, port))
@ -198,7 +198,7 @@ class StreamManager:
async def start(self): async def start(self):
await self.load_streams_from_database() await self.load_streams_from_database()
self.resume_downloading_task = self.loop.create_task(self.resume()) self.resume_downloading_task = self.loop.create_task(self.resume())
# self.re_reflect_task = self.loop.create_task(self.reflect_streams()) self.re_reflect_task = self.loop.create_task(self.reflect_streams())
def stop(self): def stop(self):
if self.resume_downloading_task and not self.resume_downloading_task.done(): if self.resume_downloading_task and not self.resume_downloading_task.done():
@ -279,28 +279,11 @@ class StreamManager:
streams.reverse() streams.reverse()
return streams return streams
def wait_for_stream_finished(self, stream: ManagedStream):
async def _wait_for_stream_finished():
if stream.downloader and stream.running:
await stream.finished_writing.wait()
stream.update_status(ManagedStream.STATUS_FINISHED)
if self.analytics_manager:
self.loop.create_task(self.analytics_manager.send_download_finished(
stream.download_id, stream.claim_name, stream.sd_hash
))
task = self.loop.create_task(_wait_for_stream_finished())
self.update_stream_finished_futs.append(task)
task.add_done_callback(
lambda _: None if task not in self.update_stream_finished_futs else
self.update_stream_finished_futs.remove(task)
)
async def _check_update_or_replace(self, outpoint: str, claim_id: str, claim: Claim) -> typing.Tuple[ async def _check_update_or_replace(self, outpoint: str, claim_id: str, claim: Claim) -> typing.Tuple[
typing.Optional[ManagedStream], typing.Optional[ManagedStream]]: typing.Optional[ManagedStream], typing.Optional[ManagedStream]]:
existing = self.get_filtered_streams(outpoint=outpoint) existing = self.get_filtered_streams(outpoint=outpoint)
if existing: if existing:
if not existing[0].running: if existing[0].status == ManagedStream.STATUS_STOPPED:
await self.start_stream(existing[0]) await self.start_stream(existing[0])
return existing[0], None return existing[0], None
existing = self.get_filtered_streams(sd_hash=claim.stream.source.sd_hash) existing = self.get_filtered_streams(sd_hash=claim.stream.source.sd_hash)
@ -323,34 +306,27 @@ class StreamManager:
return None, existing_for_claim_id[0] return None, existing_for_claim_id[0]
return None, None return None, None
# async def start_downloader(self, got_descriptor_time: asyncio.Future, downloader: EncryptedStreamDownloader,
# download_id: str, outpoint: str, claim: Claim, resolved: typing.Dict,
# file_name: typing.Optional[str] = None) -> ManagedStream:
# start_time = self.loop.time()
# downloader.download(self.node)
# await downloader.got_descriptor.wait()
# got_descriptor_time.set_result(self.loop.time() - start_time)
# rowid = await self._store_stream(downloader)
# await self.storage.save_content_claim(
# downloader.descriptor.stream_hash, outpoint
# )
# stream = ManagedStream(self.loop, self.blob_manager, rowid, downloader.descriptor, self.config.download_dir,
# file_name, downloader, ManagedStream.STATUS_RUNNING, download_id=download_id)
# stream.set_claim(resolved, claim)
# await stream.downloader.wrote_bytes_event.wait()
# self.streams.add(stream)
# return stream
@cache_concurrent @cache_concurrent
async def download_stream_from_uri(self, uri, timeout: float, exchange_rate_manager: 'ExchangeRateManager', async def download_stream_from_uri(self, uri, exchange_rate_manager: 'ExchangeRateManager',
file_name: typing.Optional[str] = None) -> ManagedStream: timeout: typing.Optional[float] = None,
file_name: typing.Optional[str] = None,
download_directory: typing.Optional[str] = None,
save_file: bool = True, resolve_timeout: float = 3.0) -> ManagedStream:
timeout = timeout or self.config.download_timeout
start_time = self.loop.time() start_time = self.loop.time()
resolved_time = None
stream = None
error = None
outpoint = None
try:
# resolve the claim
parsed_uri = parse_lbry_uri(uri) parsed_uri = parse_lbry_uri(uri)
if parsed_uri.is_channel: if parsed_uri.is_channel:
raise ResolveError("cannot download a channel claim, specify a /path") raise ResolveError("cannot download a channel claim, specify a /path")
try:
# resolve the claim resolved_result = await asyncio.wait_for(self.wallet.ledger.resolve(0, 1, uri), resolve_timeout)
resolved_result = await self.wallet.ledger.resolve(0, 10, uri) except asyncio.TimeoutError:
raise ResolveTimeout(uri)
await self.storage.save_claims_for_resolve([ await self.storage.save_claims_for_resolve([
value for value in resolved_result.values() if 'error' not in value value for value in resolved_result.values() if 'error' not in value
]) ])
@ -371,8 +347,7 @@ class StreamManager:
return updated_stream return updated_stream
# check that the fee is payable # check that the fee is payable
fee_amount, fee_address = None, None if not to_replace and claim.stream.has_fee:
if claim.stream.has_fee:
fee_amount = round(exchange_rate_manager.convert_currency( fee_amount = round(exchange_rate_manager.convert_currency(
claim.stream.fee.currency, "LBC", claim.stream.fee.amount claim.stream.fee.currency, "LBC", claim.stream.fee.amount
), 5) ), 5)
@ -389,97 +364,54 @@ class StreamManager:
log.warning(msg) log.warning(msg)
raise InsufficientFundsError(msg) raise InsufficientFundsError(msg)
fee_address = claim.stream.fee.address fee_address = claim.stream.fee.address
# content_fee_tx = await self.wallet.send_amount_to_address( await self.wallet.send_amount_to_address(
# lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1') lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1')
# )
handled_fee_time = self.loop.time() - resolved_time - start_time
# download the stream
download_id = binascii.hexlify(generate_id()).decode()
download_dir = self.config.download_dir
save_file = True
if not file_name and self.config.streaming_only:
download_dir, file_name = None, None
save_file = False
stream = ManagedStream(
self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_dir,
file_name, ManagedStream.STATUS_RUNNING, download_id=download_id
) )
log.info("paid fee of %s for %s", fee_amount, uri)
await stream.setup(self.node, save_file=save_file) download_directory = download_directory or self.config.download_dir
if not file_name and (self.config.streaming_only or not save_file):
download_dir, file_name = None, None
stream = ManagedStream(
self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_directory,
file_name, ManagedStream.STATUS_RUNNING, analytics_manager=self.analytics_manager
)
try:
await asyncio.wait_for(stream.setup(
self.node, save_file=save_file, file_name=file_name, download_directory=download_directory
), timeout, loop=self.loop)
except asyncio.TimeoutError:
if not stream.descriptor:
raise DownloadSDTimeout(stream.sd_hash)
raise DownloadDataTimeout(stream.sd_hash)
if to_replace: # delete old stream now that the replacement has started downloading
await self.delete_stream(to_replace)
stream.set_claim(resolved, claim) stream.set_claim(resolved, claim)
await self.storage.save_content_claim(stream.stream_hash, outpoint) await self.storage.save_content_claim(stream.stream_hash, outpoint)
self.streams[stream.sd_hash] = stream self.streams[stream.sd_hash] = stream
# stream = None
# descriptor_time_fut = self.loop.create_future()
# start_download_time = self.loop.time()
# time_to_descriptor = None
# time_to_first_bytes = None
# error = None
# try:
# stream = await asyncio.wait_for(
# asyncio.ensure_future(
# self.start_downloader(descriptor_time_fut, downloader, download_id, outpoint, claim, resolved,
# file_name)
# ), timeout
# )
# time_to_descriptor = await descriptor_time_fut
# time_to_first_bytes = self.loop.time() - start_download_time - time_to_descriptor
# self.wait_for_stream_finished(stream)
# if fee_address and fee_amount and not to_replace:
#
# elif to_replace: # delete old stream now that the replacement has started downloading
# await self.delete_stream(to_replace)
# except asyncio.TimeoutError:
# if descriptor_time_fut.done():
# time_to_descriptor = descriptor_time_fut.result()
# error = DownloadDataTimeout(downloader.sd_hash)
# self.blob_manager.delete_blob(downloader.sd_hash)
# await self.storage.delete_stream(downloader.descriptor)
# else:
# descriptor_time_fut.cancel()
# error = DownloadSDTimeout(downloader.sd_hash)
# if stream:
# await self.stop_stream(stream)
# else:
# downloader.stop()
# if error:
# log.warning(error)
# if self.analytics_manager:
# self.loop.create_task(
# self.analytics_manager.send_time_to_first_bytes(
# resolved_time, self.loop.time() - start_time, download_id, parse_lbry_uri(uri).name, outpoint,
# None if not stream else len(stream.downloader.blob_downloader.active_connections),
# None if not stream else len(stream.downloader.blob_downloader.scores),
# False if not downloader else downloader.added_fixed_peers,
# self.config.fixed_peer_delay if not downloader else downloader.fixed_peers_delay,
# claim.source_hash.decode(), time_to_descriptor,
# None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].blob_hash,
# None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].length,
# time_to_first_bytes, None if not error else error.__class__.__name__
# )
# )
# if error:
# raise error
return stream return stream
except Exception as err:
# async def download_stream_from_uri(self, uri, exchange_rate_manager: 'ExchangeRateManager', error = err
# file_name: typing.Optional[str] = None, if stream and stream.descriptor:
# timeout: typing.Optional[float] = None) -> ManagedStream: await self.storage.delete_stream(stream.descriptor)
# timeout = timeout or self.config.download_timeout finally:
# if uri in self.starting_streams: if self.analytics_manager and (error or (stream and (stream.downloader.time_to_descriptor or
# return await self.starting_streams[uri] stream.downloader.time_to_first_bytes))):
# fut = asyncio.Future(loop=self.loop) self.loop.create_task(
# self.starting_streams[uri] = fut self.analytics_manager.send_time_to_first_bytes(
# try: resolved_time, self.loop.time() - start_time, None if not stream else stream.download_id,
# stream = await self._download_stream_from_uri(uri, timeout, exchange_rate_manager, file_name) uri, outpoint,
# fut.set_result(stream) None if not stream else len(stream.downloader.blob_downloader.active_connections),
# except Exception as err: None if not stream else len(stream.downloader.blob_downloader.scores),
# fut.set_exception(err) False if not stream else stream.downloader.added_fixed_peers,
# try: self.config.fixed_peer_delay if not stream else stream.downloader.fixed_peers_delay,
# return await fut None if not stream else stream.sd_hash,
# finally: None if not stream else stream.downloader.time_to_descriptor,
# del self.starting_streams[uri] None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].blob_hash,
None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].length,
None if not stream else stream.downloader.time_to_first_bytes,
None if not error else error.__class__.__name__
)
)
if error:
raise error

View file

@ -28,9 +28,9 @@ class TestBlobfile(AsyncioTestCase):
self.assertEqual(blob.get_is_verified(), False) self.assertEqual(blob.get_is_verified(), False)
self.assertNotIn(blob_hash, blob_manager.completed_blob_hashes) self.assertNotIn(blob_hash, blob_manager.completed_blob_hashes)
writer = blob.open_for_writing() writer = blob.get_blob_writer()
writer.write(blob_bytes) writer.write(blob_bytes)
await blob.finished_writing.wait() await blob.verified.wait()
self.assertTrue(os.path.isfile(blob.file_path), True) self.assertTrue(os.path.isfile(blob.file_path), True)
self.assertEqual(blob.get_is_verified(), True) self.assertEqual(blob.get_is_verified(), True)
self.assertIn(blob_hash, blob_manager.completed_blob_hashes) self.assertIn(blob_hash, blob_manager.completed_blob_hashes)

View file

@ -11,7 +11,7 @@ from lbrynet.conf import Config
from lbrynet.extras.daemon.storage import SQLiteStorage from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_manager import BlobManager from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob_exchange.server import BlobServer, BlobServerProtocol from lbrynet.blob_exchange.server import BlobServer, BlobServerProtocol
from lbrynet.blob_exchange.client import BlobExchangeClientProtocol, request_blob from lbrynet.blob_exchange.client import request_blob
from lbrynet.dht.peer import KademliaPeer, PeerManager from lbrynet.dht.peer import KademliaPeer, PeerManager
# import logging # import logging
@ -58,9 +58,9 @@ class TestBlobExchange(BlobExchangeTestBase):
async def _add_blob_to_server(self, blob_hash: str, blob_bytes: bytes): async def _add_blob_to_server(self, blob_hash: str, blob_bytes: bytes):
# add the blob on the server # add the blob on the server
server_blob = self.server_blob_manager.get_blob(blob_hash, len(blob_bytes)) server_blob = self.server_blob_manager.get_blob(blob_hash, len(blob_bytes))
writer = server_blob.open_for_writing() writer = server_blob.get_blob_writer()
writer.write(blob_bytes) writer.write(blob_bytes)
await server_blob.finished_writing.wait() await server_blob.verified.wait()
self.assertTrue(os.path.isfile(server_blob.file_path)) self.assertTrue(os.path.isfile(server_blob.file_path))
self.assertEqual(server_blob.get_is_verified(), True) self.assertEqual(server_blob.get_is_verified(), True)
@ -68,11 +68,14 @@ class TestBlobExchange(BlobExchangeTestBase):
client_blob = self.client_blob_manager.get_blob(blob_hash) client_blob = self.client_blob_manager.get_blob(blob_hash)
# download the blob # download the blob
downloaded = await request_blob(self.loop, client_blob, self.server_from_client.address, downloaded, transport = await request_blob(self.loop, client_blob, self.server_from_client.address,
self.server_from_client.tcp_port, 2, 3) self.server_from_client.tcp_port, 2, 3)
await client_blob.finished_writing.wait() self.assertIsNotNone(transport)
self.addCleanup(transport.close)
await client_blob.verified.wait()
self.assertEqual(client_blob.get_is_verified(), True) self.assertEqual(client_blob.get_is_verified(), True)
self.assertTrue(downloaded) self.assertTrue(downloaded)
self.addCleanup(client_blob.close)
async def test_transfer_sd_blob(self): async def test_transfer_sd_blob(self):
sd_hash = "3e2706157a59aaa47ef52bc264fce488078b4026c0b9bab649a8f2fe1ecc5e5cad7182a2bb7722460f856831a1ac0f02" sd_hash = "3e2706157a59aaa47ef52bc264fce488078b4026c0b9bab649a8f2fe1ecc5e5cad7182a2bb7722460f856831a1ac0f02"
@ -112,7 +115,7 @@ class TestBlobExchange(BlobExchangeTestBase):
), ),
self._test_transfer_blob(blob_hash) self._test_transfer_blob(blob_hash)
) )
await second_client_blob.finished_writing.wait() await second_client_blob.verified.wait()
self.assertEqual(second_client_blob.get_is_verified(), True) self.assertEqual(second_client_blob.get_is_verified(), True)
async def test_host_different_blobs_to_multiple_peers_at_once(self): async def test_host_different_blobs_to_multiple_peers_at_once(self):
@ -143,7 +146,7 @@ class TestBlobExchange(BlobExchangeTestBase):
server_from_second_client.tcp_port, 2, 3 server_from_second_client.tcp_port, 2, 3
), ),
self._test_transfer_blob(sd_hash), self._test_transfer_blob(sd_hash),
second_client_blob.finished_writing.wait() second_client_blob.verified.wait()
) )
self.assertEqual(second_client_blob.get_is_verified(), True) self.assertEqual(second_client_blob.get_is_verified(), True)

View file

@ -1,12 +1,13 @@
import asyncio import asyncio
from torba.testcase import AsyncioTestCase from unittest import mock, TestCase
from lbrynet.dht.protocol.data_store import DictDataStore from lbrynet.dht.protocol.data_store import DictDataStore
from lbrynet.dht.peer import PeerManager from lbrynet.dht.peer import PeerManager
class DataStoreTests(AsyncioTestCase): class DataStoreTests(TestCase):
def setUp(self): def setUp(self):
self.loop = asyncio.get_event_loop() self.loop = mock.Mock(spec=asyncio.BaseEventLoop)
self.loop.time = lambda: 0.0
self.peer_manager = PeerManager(self.loop) self.peer_manager = PeerManager(self.loop)
self.data_store = DictDataStore(self.loop, self.peer_manager) self.data_store = DictDataStore(self.loop, self.peer_manager)

View file

@ -1,117 +0,0 @@
import os
import asyncio
import tempfile
import shutil
from torba.testcase import AsyncioTestCase
from lbrynet.conf import Config
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.stream.assembler import StreamAssembler
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.stream_manager import StreamManager
class TestStreamAssembler(AsyncioTestCase):
def setUp(self):
self.loop = asyncio.get_event_loop()
self.key = b'deadbeef' * 4
self.cleartext = b'test'
async def test_create_and_decrypt_one_blob_stream(self, corrupt=False):
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
self.storage = SQLiteStorage(Config(), ":memory:")
await self.storage.open()
self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage)
download_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(download_dir))
# create the stream
file_path = os.path.join(tmp_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(self.cleartext)
sd = await StreamDescriptor.create_stream(self.loop, tmp_dir, file_path, key=self.key)
# copy blob files
sd_hash = sd.calculate_sd_hash()
shutil.copy(os.path.join(tmp_dir, sd_hash), os.path.join(download_dir, sd_hash))
for blob_info in sd.blobs:
if blob_info.blob_hash:
shutil.copy(os.path.join(tmp_dir, blob_info.blob_hash), os.path.join(download_dir, blob_info.blob_hash))
if corrupt and blob_info.length == MAX_BLOB_SIZE:
with open(os.path.join(download_dir, blob_info.blob_hash), "rb+") as handle:
handle.truncate()
handle.flush()
downloader_storage = SQLiteStorage(Config(), os.path.join(download_dir, "lbrynet.sqlite"))
await downloader_storage.open()
# add the blobs to the blob table (this would happen upon a blob download finishing)
downloader_blob_manager = BlobManager(self.loop, download_dir, downloader_storage)
descriptor = await downloader_blob_manager.get_stream_descriptor(sd_hash)
# assemble the decrypted file
assembler = StreamAssembler(self.loop, downloader_blob_manager, descriptor.sd_hash)
await assembler.assemble_decrypted_stream(download_dir)
if corrupt:
return self.assertFalse(os.path.isfile(os.path.join(download_dir, "test_file")))
with open(os.path.join(download_dir, "test_file"), "rb") as f:
decrypted = f.read()
self.assertEqual(decrypted, self.cleartext)
self.assertEqual(True, self.blob_manager.get_blob(sd_hash).get_is_verified())
self.assertEqual(True, self.blob_manager.get_blob(descriptor.blobs[0].blob_hash).get_is_verified())
# its all blobs + sd blob - last blob, which is the same size as descriptor.blobs
self.assertEqual(len(descriptor.blobs), len(await downloader_storage.get_all_finished_blobs()))
self.assertEqual(
[descriptor.sd_hash, descriptor.blobs[0].blob_hash], await downloader_storage.get_blobs_to_announce()
)
await downloader_storage.close()
await self.storage.close()
async def test_create_and_decrypt_multi_blob_stream(self):
self.cleartext = b'test\n' * 20000000
await self.test_create_and_decrypt_one_blob_stream()
async def test_create_and_decrypt_padding(self):
for i in range(16):
self.cleartext = os.urandom((MAX_BLOB_SIZE*2) + i)
await self.test_create_and_decrypt_one_blob_stream()
for i in range(16):
self.cleartext = os.urandom((MAX_BLOB_SIZE*2) - i)
await self.test_create_and_decrypt_one_blob_stream()
async def test_create_and_decrypt_random(self):
self.cleartext = os.urandom(20000000)
await self.test_create_and_decrypt_one_blob_stream()
async def test_create_managed_stream_announces(self):
# setup a blob manager
storage = SQLiteStorage(Config(), ":memory:")
await storage.open()
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
blob_manager = BlobManager(self.loop, tmp_dir, storage)
stream_manager = StreamManager(self.loop, Config(), blob_manager, None, storage, None)
# create the stream
download_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(download_dir))
file_path = os.path.join(download_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(b'testtest')
stream = await stream_manager.create_stream(file_path)
self.assertEqual(
[stream.sd_hash, stream.descriptor.blobs[0].blob_hash],
await storage.get_blobs_to_announce())
async def test_create_truncate_and_handle_stream(self):
self.cleartext = b'potato' * 1337 * 5279
# The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated
await asyncio.wait_for(self.test_create_and_decrypt_one_blob_stream(corrupt=True), timeout=5)

View file

@ -1,102 +0,0 @@
import os
import time
import unittest
from unittest import mock
import asyncio
from lbrynet.blob_exchange.serialization import BlobResponse
from lbrynet.blob_exchange.server import BlobServerProtocol
from lbrynet.conf import Config
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.downloader import StreamDownloader
from lbrynet.dht.node import Node
from lbrynet.dht.peer import KademliaPeer
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase
class TestStreamDownloader(BlobExchangeTestBase):
async def setup_stream(self, blob_count: int = 10):
self.stream_bytes = b''
for _ in range(blob_count):
self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1))
# create the stream
file_path = os.path.join(self.server_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(self.stream_bytes)
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
self.sd_hash = descriptor.calculate_sd_hash()
conf = Config(data_dir=self.server_dir, wallet_dir=self.server_dir, download_dir=self.server_dir,
reflector_servers=[])
self.downloader = StreamDownloader(self.loop, conf, self.client_blob_manager, self.sd_hash)
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
await self.setup_stream(blob_count)
mock_node = mock.Mock(spec=Node)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
self.downloader.download(mock_node)
await self.downloader.stream_finished_event.wait()
self.assertTrue(self.downloader.stream_handle.closed)
self.assertTrue(os.path.isfile(self.downloader.output_path))
self.downloader.stop()
self.assertIs(self.downloader.stream_handle, None)
self.assertTrue(os.path.isfile(self.downloader.output_path))
with open(self.downloader.output_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
await asyncio.sleep(0.01)
async def test_transfer_stream(self):
await self._test_transfer_stream(10)
@unittest.SkipTest
async def test_transfer_hundred_blob_stream(self):
await self._test_transfer_stream(100)
async def test_transfer_stream_bad_first_peer_good_second(self):
await self.setup_stream(2)
mock_node = mock.Mock(spec=Node)
q = asyncio.Queue()
bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([bad_peer])
self.loop.call_later(1, q2.put_nowait, [self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = _mock_accumulate_peers
self.downloader.download(mock_node)
await self.downloader.stream_finished_event.wait()
self.assertTrue(os.path.isfile(self.downloader.output_path))
with open(self.downloader.output_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
# self.assertIs(self.server_from_client.tcp_last_down, None)
# self.assertIsNot(bad_peer.tcp_last_down, None)
async def test_client_chunked_response(self):
self.server.stop_server()
class ChunkedServerProtocol(BlobServerProtocol):
def send_response(self, responses):
to_send = []
while responses:
to_send.append(responses.pop())
for byte in BlobResponse(to_send).serialize():
self.transport.write(bytes([byte]))
self.server.server_protocol_class = ChunkedServerProtocol
self.server.start_server(33333, '127.0.0.1')
self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes))
await asyncio.wait_for(self._test_transfer_stream(10), timeout=2)
self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes))

View file

@ -0,0 +1,175 @@
import os
import shutil
import unittest
from unittest import mock
import asyncio
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
from lbrynet.blob_exchange.serialization import BlobResponse
from lbrynet.blob_exchange.server import BlobServerProtocol
from lbrynet.dht.node import Node
from lbrynet.dht.peer import KademliaPeer
from lbrynet.extras.daemon.storage import StoredStreamClaim
from lbrynet.stream.managed_stream import ManagedStream
from lbrynet.stream.descriptor import StreamDescriptor
from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase
def get_mock_node(loop):
mock_node = mock.Mock(spec=Node)
mock_node.joined = asyncio.Event(loop=loop)
mock_node.joined.set()
return mock_node
class TestManagedStream(BlobExchangeTestBase):
async def create_stream(self, blob_count: int = 10):
self.stream_bytes = b''
for _ in range(blob_count):
self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1))
# create the stream
file_path = os.path.join(self.server_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(self.stream_bytes)
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
self.sd_hash = descriptor.calculate_sd_hash()
return descriptor
async def setup_stream(self, blob_count: int = 10):
await self.create_stream(blob_count)
self.stream = ManagedStream(self.loop, self.client_config, self.client_blob_manager, self.sd_hash,
self.client_dir)
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
await self.setup_stream(blob_count)
mock_node = mock.Mock(spec=Node)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
await self.stream.setup(mock_node, save_file=True)
await self.stream.finished_writing.wait()
self.assertTrue(os.path.isfile(self.stream.full_path))
self.stream.stop_download()
self.assertTrue(os.path.isfile(self.stream.full_path))
with open(self.stream.full_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
await asyncio.sleep(0.01)
async def test_transfer_stream(self):
await self._test_transfer_stream(10)
@unittest.SkipTest
async def test_transfer_hundred_blob_stream(self):
await self._test_transfer_stream(100)
async def test_transfer_stream_bad_first_peer_good_second(self):
await self.setup_stream(2)
mock_node = mock.Mock(spec=Node)
q = asyncio.Queue()
bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([bad_peer])
self.loop.call_later(1, q2.put_nowait, [self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = _mock_accumulate_peers
await self.stream.setup(mock_node, save_file=True)
await self.stream.finished_writing.wait()
self.assertTrue(os.path.isfile(self.stream.full_path))
with open(self.stream.full_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
# self.assertIs(self.server_from_client.tcp_last_down, None)
# self.assertIsNot(bad_peer.tcp_last_down, None)
async def test_client_chunked_response(self):
self.server.stop_server()
class ChunkedServerProtocol(BlobServerProtocol):
def send_response(self, responses):
to_send = []
while responses:
to_send.append(responses.pop())
for byte in BlobResponse(to_send).serialize():
self.transport.write(bytes([byte]))
self.server.server_protocol_class = ChunkedServerProtocol
self.server.start_server(33333, '127.0.0.1')
self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes))
await asyncio.wait_for(self._test_transfer_stream(10), timeout=2)
self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes))
async def test_create_and_decrypt_one_blob_stream(self, blobs=1, corrupt=False):
descriptor = await self.create_stream(blobs)
# copy blob files
shutil.copy(os.path.join(self.server_blob_manager.blob_dir, self.sd_hash),
os.path.join(self.client_blob_manager.blob_dir, self.sd_hash))
self.stream = ManagedStream(self.loop, self.client_config, self.client_blob_manager, self.sd_hash,
self.client_dir)
for blob_info in descriptor.blobs[:-1]:
shutil.copy(os.path.join(self.server_blob_manager.blob_dir, blob_info.blob_hash),
os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash))
if corrupt and blob_info.length == MAX_BLOB_SIZE:
with open(os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash), "rb+") as handle:
handle.truncate()
handle.flush()
await self.stream.setup()
await self.stream.finished_writing.wait()
if corrupt:
return self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
with open(os.path.join(self.client_dir, "test_file"), "rb") as f:
decrypted = f.read()
self.assertEqual(decrypted, self.stream_bytes)
self.assertEqual(True, self.client_blob_manager.get_blob(self.sd_hash).get_is_verified())
self.assertEqual(
True, self.client_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified()
)
#
# # its all blobs + sd blob - last blob, which is the same size as descriptor.blobs
# self.assertEqual(len(descriptor.blobs), len(await downloader_storage.get_all_finished_blobs()))
# self.assertEqual(
# [descriptor.sd_hash, descriptor.blobs[0].blob_hash], await downloader_storage.get_blobs_to_announce()
# )
#
# await downloader_storage.close()
# await self.storage.close()
async def test_create_and_decrypt_multi_blob_stream(self):
await self.test_create_and_decrypt_one_blob_stream(10)
# async def test_create_managed_stream_announces(self):
# # setup a blob manager
# storage = SQLiteStorage(Config(), ":memory:")
# await storage.open()
# tmp_dir = tempfile.mkdtemp()
# self.addCleanup(lambda: shutil.rmtree(tmp_dir))
# blob_manager = BlobManager(self.loop, tmp_dir, storage)
# stream_manager = StreamManager(self.loop, Config(), blob_manager, None, storage, None)
# # create the stream
# download_dir = tempfile.mkdtemp()
# self.addCleanup(lambda: shutil.rmtree(download_dir))
# file_path = os.path.join(download_dir, "test_file")
# with open(file_path, 'wb') as f:
# f.write(b'testtest')
#
# stream = await stream_manager.create_stream(file_path)
# self.assertEqual(
# [stream.sd_hash, stream.descriptor.blobs[0].blob_hash],
# await storage.get_blobs_to_announce())
# async def test_create_truncate_and_handle_stream(self):
# # The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated
# await asyncio.wait_for(self.test_create_and_decrypt_one_blob_stream(corrupt=True), timeout=5)

View file

@ -99,7 +99,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase):
blob = blob_manager.get_blob(sd_hash) blob = blob_manager.get_blob(sd_hash)
blob.set_length(len(sd_bytes)) blob.set_length(len(sd_bytes))
writer = blob.open_for_writing() writer = blob.get_blob_writer()
writer.write(sd_bytes) writer.write(sd_bytes)
await blob.verified.wait() await blob.verified.wait()
descriptor = await StreamDescriptor.from_stream_descriptor_blob( descriptor = await StreamDescriptor.from_stream_descriptor_blob(
@ -116,7 +116,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase):
sd_hash = '9313d1807551186126acc3662e74d9de29cede78d4f133349ace846273ef116b9bb86be86c54509eb84840e4b032f6b2' sd_hash = '9313d1807551186126acc3662e74d9de29cede78d4f133349ace846273ef116b9bb86be86c54509eb84840e4b032f6b2'
with open(os.path.join(tmp_dir, sd_hash), 'wb') as handle: with open(os.path.join(tmp_dir, sd_hash), 'wb') as handle:
handle.write(b'doesnt work') handle.write(b'doesnt work')
blob = BlobFile(loop, tmp_dir, sd_hash) blob = BlobFile(loop, sd_hash, blob_directory=tmp_dir)
self.assertTrue(blob.file_exists) self.assertTrue(blob.file_exists)
self.assertIsNotNone(blob.length) self.assertIsNotNone(blob.length)
with self.assertRaises(InvalidStreamDescriptorError): with self.assertRaises(InvalidStreamDescriptorError):

View file

@ -192,8 +192,8 @@ class TestStreamManager(BlobExchangeTestBase):
def check_post(event): def check_post(event):
self.assertEqual(event['event'], 'Time To First Bytes') self.assertEqual(event['event'], 'Time To First Bytes')
self.assertEqual(event['properties']['error'], 'DownloadSDTimeout') self.assertEqual(event['properties']['error'], 'DownloadSDTimeout')
self.assertEqual(event['properties']['tried_peers_count'], None) self.assertEqual(event['properties']['tried_peers_count'], 0)
self.assertEqual(event['properties']['active_peer_count'], None) self.assertEqual(event['properties']['active_peer_count'], 0)
self.assertEqual(event['properties']['use_fixed_peers'], False) self.assertEqual(event['properties']['use_fixed_peers'], False)
self.assertEqual(event['properties']['added_fixed_peers'], False) self.assertEqual(event['properties']['added_fixed_peers'], False)
self.assertEqual(event['properties']['fixed_peer_delay'], None) self.assertEqual(event['properties']['fixed_peer_delay'], None)
@ -213,10 +213,10 @@ class TestStreamManager(BlobExchangeTestBase):
self.stream_manager.analytics_manager._post = check_post self.stream_manager.analytics_manager._post = check_post
self.assertSetEqual(self.stream_manager.streams, set()) self.assertDictEqual(self.stream_manager.streams, {})
stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
stream_hash = stream.stream_hash stream_hash = stream.stream_hash
self.assertSetEqual(self.stream_manager.streams, {stream}) self.assertDictEqual(self.stream_manager.streams, {stream.sd_hash: stream})
self.assertTrue(stream.running) self.assertTrue(stream.running)
self.assertFalse(stream.finished) self.assertFalse(stream.finished)
self.assertTrue(os.path.isfile(os.path.join(self.client_dir, "test_file"))) self.assertTrue(os.path.isfile(os.path.join(self.client_dir, "test_file")))
@ -236,7 +236,7 @@ class TestStreamManager(BlobExchangeTestBase):
self.assertEqual(stored_status, "stopped") self.assertEqual(stored_status, "stopped")
await self.stream_manager.start_stream(stream) await self.stream_manager.start_stream(stream)
await stream.downloader.stream_finished_event.wait() await stream.finished_writing.wait()
await asyncio.sleep(0, loop=self.loop) await asyncio.sleep(0, loop=self.loop)
self.assertTrue(stream.finished) self.assertTrue(stream.finished)
self.assertFalse(stream.running) self.assertFalse(stream.running)
@ -247,7 +247,7 @@ class TestStreamManager(BlobExchangeTestBase):
self.assertEqual(stored_status, "finished") self.assertEqual(stored_status, "finished")
await self.stream_manager.delete_stream(stream, True) await self.stream_manager.delete_stream(stream, True)
self.assertSetEqual(self.stream_manager.streams, set()) self.assertDictEqual(self.stream_manager.streams, {})
self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file"))) self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
stored_status = await self.client_storage.run_and_return_one_or_none( stored_status = await self.client_storage.run_and_return_one_or_none(
"select status from file where stream_hash=?", stream_hash "select status from file where stream_hash=?", stream_hash
@ -257,7 +257,7 @@ class TestStreamManager(BlobExchangeTestBase):
async def _test_download_error_on_start(self, expected_error, timeout=None): async def _test_download_error_on_start(self, expected_error, timeout=None):
with self.assertRaises(expected_error): with self.assertRaises(expected_error):
await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout=timeout) await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout)
async def _test_download_error_analytics_on_start(self, expected_error, timeout=None): async def _test_download_error_analytics_on_start(self, expected_error, timeout=None):
received = [] received = []
@ -321,9 +321,9 @@ class TestStreamManager(BlobExchangeTestBase):
await self.setup_stream_manager(old_sort=old_sort) await self.setup_stream_manager(old_sort=old_sort)
self.stream_manager.analytics_manager._post = check_post self.stream_manager.analytics_manager._post = check_post
self.assertSetEqual(self.stream_manager.streams, set()) self.assertDictEqual(self.stream_manager.streams, {})
stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
await stream.downloader.stream_finished_event.wait() await stream.finished_writing.wait()
await asyncio.sleep(0, loop=self.loop) await asyncio.sleep(0, loop=self.loop)
self.stream_manager.stop() self.stream_manager.stop()
self.client_blob_manager.stop() self.client_blob_manager.stop()
@ -333,8 +333,11 @@ class TestStreamManager(BlobExchangeTestBase):
await self.client_blob_manager.setup() await self.client_blob_manager.setup()
await self.stream_manager.start() await self.stream_manager.start()
self.assertEqual(1, len(self.stream_manager.streams)) self.assertEqual(1, len(self.stream_manager.streams))
self.assertEqual(stream.sd_hash, list(self.stream_manager.streams)[0].sd_hash) self.assertListEqual([self.sd_hash], list(self.stream_manager.streams.keys()))
self.assertEqual('stopped', list(self.stream_manager.streams)[0].status) for blob_hash in [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]]:
blob_status = await self.client_storage.get_blob_status(blob_hash)
self.assertEqual('pending', blob_status)
self.assertEqual('stopped', self.stream_manager.streams[self.sd_hash].status)
sd_blob = self.client_blob_manager.get_blob(stream.sd_hash) sd_blob = self.client_blob_manager.get_blob(stream.sd_hash)
self.assertTrue(sd_blob.file_exists) self.assertTrue(sd_blob.file_exists)