This commit is contained in:
Jack Robison 2019-03-31 13:42:27 -04:00
parent f125468ebf
commit 3a916a8e8e
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
14 changed files with 381 additions and 441 deletions

View file

@ -32,6 +32,12 @@ class DownloadDataTimeout(Exception):
self.download = download self.download = download
class ResolveTimeout(Exception):
def __init__(self, uri):
super().__init__(f'Failed to resolve "{uri}" within the timeout')
self.uri = uri
class RequestCanceledError(Exception): class RequestCanceledError(Exception):
pass pass

View file

@ -897,7 +897,7 @@ class Daemon(metaclass=JSONRPCServerType):
""" """
try: try:
stream = await self.stream_manager.download_stream_from_uri( stream = await self.stream_manager.download_stream_from_uri(
uri, timeout, self.exchange_rate_manager, file_name uri, self.exchange_rate_manager, timeout, file_name
) )
if not stream: if not stream:
raise DownloadSDTimeout(uri) raise DownloadSDTimeout(uri)

View file

@ -423,6 +423,17 @@ class SQLiteStorage(SQLiteMixin):
} }
return self.db.run(_sync_blobs) return self.db.run(_sync_blobs)
def sync_files_to_blobs(self):
def _sync_blobs(transaction: sqlite3.Connection) -> typing.Set[str]:
transaction.executemany(
"update file set status='stopped' where stream_hash=?",
transaction.execute(
"select distinct sb.stream_hash from stream_blob sb "
"inner join blob b on b.blob_hash=sb.blob_hash and b.status=='pending'"
).fetchall()
)
return self.db.run(_sync_blobs)
# # # # # # # # # stream functions # # # # # # # # # # # # # # # # # # stream functions # # # # # # # # #
async def stream_exists(self, sd_hash: str) -> bool: async def stream_exists(self, sd_hash: str) -> bool:

View file

@ -2,6 +2,7 @@ import asyncio
import typing import typing
import logging import logging
import binascii import binascii
from lbrynet.error import DownloadSDTimeout
from lbrynet.utils import resolve_host from lbrynet.utils import resolve_host
from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.blob_exchange.downloader import BlobDownloader from lbrynet.blob_exchange.downloader import BlobDownloader
@ -32,6 +33,8 @@ class StreamDownloader:
self.fixed_peers_handle: typing.Optional[asyncio.Handle] = None self.fixed_peers_handle: typing.Optional[asyncio.Handle] = None
self.fixed_peers_delay: typing.Optional[float] = None self.fixed_peers_delay: typing.Optional[float] = None
self.added_fixed_peers = False self.added_fixed_peers = False
self.time_to_descriptor: typing.Optional[float] = None
self.time_to_first_bytes: typing.Optional[float] = None
async def add_fixed_peers(self): async def add_fixed_peers(self):
def _delayed_add_fixed_peers(): def _delayed_add_fixed_peers():
@ -59,8 +62,16 @@ class StreamDownloader:
# download or get the sd blob # download or get the sd blob
sd_blob = self.blob_manager.get_blob(self.sd_hash) sd_blob = self.blob_manager.get_blob(self.sd_hash)
if not sd_blob.get_is_verified(): if not sd_blob.get_is_verified():
sd_blob = await self.blob_downloader.download_blob(self.sd_hash) try:
log.info("downloaded sd blob %s", self.sd_hash) now = self.loop.time()
sd_blob = await asyncio.wait_for(
self.blob_downloader.download_blob(self.sd_hash),
self.config.blob_download_timeout, loop=self.loop
)
log.info("downloaded sd blob %s", self.sd_hash)
self.time_to_descriptor = self.loop.time() - now
except asyncio.TimeoutError:
raise DownloadSDTimeout(self.sd_hash)
# parse the descriptor # parse the descriptor
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob( self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
@ -101,12 +112,18 @@ class StreamDownloader:
binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode()) binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode())
) )
async def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob'): async def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes:
return await self.loop.run_in_executor(None, self._decrypt_blob, blob_info, blob) return await self.loop.run_in_executor(None, self._decrypt_blob, blob_info, blob)
async def read_blob(self, blob_info: 'BlobInfo') -> bytes: async def read_blob(self, blob_info: 'BlobInfo') -> bytes:
start = None
if self.time_to_first_bytes is None:
start = self.loop.time()
blob = await self.download_stream_blob(blob_info) blob = await self.download_stream_blob(blob_info)
return await self.decrypt_blob(blob_info, blob) decrypted = await self.decrypt_blob(blob_info, blob)
if start:
self.time_to_first_bytes = self.loop.time() - start
return decrypted
def stop(self): def stop(self):
if self.accumulate_task: if self.accumulate_task:

View file

@ -9,13 +9,13 @@ from lbrynet.stream.downloader import StreamDownloader
from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.reflector.client import StreamReflectorClient from lbrynet.stream.reflector.client import StreamReflectorClient
from lbrynet.extras.daemon.storage import StoredStreamClaim from lbrynet.extras.daemon.storage import StoredStreamClaim
from lbrynet.blob import MAX_BLOB_SIZE
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from lbrynet.conf import Config from lbrynet.conf import Config
from lbrynet.schema.claim import Claim from lbrynet.schema.claim import Claim
from lbrynet.blob.blob_manager import BlobManager from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob.blob_info import BlobInfo from lbrynet.blob.blob_info import BlobInfo
from lbrynet.dht.node import Node from lbrynet.dht.node import Node
from lbrynet.extras.daemon.analytics import AnalyticsManager
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -43,7 +43,8 @@ class ManagedStream:
sd_hash: str, download_directory: typing.Optional[str] = None, file_name: typing.Optional[str] = None, sd_hash: str, download_directory: typing.Optional[str] = None, file_name: typing.Optional[str] = None,
status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredStreamClaim] = None, status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredStreamClaim] = None,
download_id: typing.Optional[str] = None, rowid: typing.Optional[int] = None, download_id: typing.Optional[str] = None, rowid: typing.Optional[int] = None,
descriptor: typing.Optional[StreamDescriptor] = None): descriptor: typing.Optional[StreamDescriptor] = None,
analytics_manager: typing.Optional['AnalyticsManager'] = None):
self.loop = loop self.loop = loop
self.config = config self.config = config
self.blob_manager = blob_manager self.blob_manager = blob_manager
@ -56,11 +57,13 @@ class ManagedStream:
self.rowid = rowid self.rowid = rowid
self.written_bytes = 0 self.written_bytes = 0
self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor) self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor)
self.analytics_manager = analytics_manager
self.fully_reflected = asyncio.Event(loop=self.loop) self.fully_reflected = asyncio.Event(loop=self.loop)
self.file_output_task: typing.Optional[asyncio.Task] = None self.file_output_task: typing.Optional[asyncio.Task] = None
self.delayed_stop: typing.Optional[asyncio.Handle] = None self.delayed_stop: typing.Optional[asyncio.Handle] = None
self.saving = asyncio.Event(loop=self.loop) self.saving = asyncio.Event(loop=self.loop)
self.finished_writing = asyncio.Event(loop=self.loop) self.finished_writing = asyncio.Event(loop=self.loop)
self.started_writing = asyncio.Event(loop=self.loop)
@property @property
def descriptor(self) -> StreamDescriptor: def descriptor(self) -> StreamDescriptor:
@ -217,16 +220,18 @@ class ManagedStream:
return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path), return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path),
os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor) os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor)
async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True): async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True,
file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None):
await self.downloader.start(node) await self.downloader.start(node)
if not save_file: if not save_file and not file_name:
if not await self.blob_manager.storage.file_exists(self.sd_hash): if not await self.blob_manager.storage.file_exists(self.sd_hash):
self.rowid = self.blob_manager.storage.save_downloaded_file( self.rowid = self.blob_manager.storage.save_downloaded_file(
self.stream_hash, None, None, 0.0 self.stream_hash, None, None, 0.0
) )
self.update_delayed_stop() self.update_delayed_stop()
else: else:
await self.save_file() await self.save_file(file_name, download_directory)
await self.started_writing.wait()
self.update_status(ManagedStream.STATUS_RUNNING) self.update_status(ManagedStream.STATUS_RUNNING)
await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING) await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING)
@ -235,7 +240,6 @@ class ManagedStream:
log.info("Stopping inactive download for stream %s", self.sd_hash) log.info("Stopping inactive download for stream %s", self.sd_hash)
self.stop_download() self.stop_download()
log.info("update delayed stop")
if self.delayed_stop: if self.delayed_stop:
self.delayed_stop.cancel() self.delayed_stop.cancel()
self.delayed_stop = self.loop.call_later(60, _delayed_stop) self.delayed_stop = self.loop.call_later(60, _delayed_stop)
@ -259,6 +263,7 @@ class ManagedStream:
async def _save_file(self, output_path: str): async def _save_file(self, output_path: str):
self.saving.set() self.saving.set()
self.finished_writing.clear() self.finished_writing.clear()
self.started_writing.clear()
try: try:
with open(output_path, 'wb') as file_write_handle: with open(output_path, 'wb') as file_write_handle:
async for blob_info, decrypted in self.aiter_read_stream(): async for blob_info, decrypted in self.aiter_read_stream():
@ -266,14 +271,21 @@ class ManagedStream:
file_write_handle.write(decrypted) file_write_handle.write(decrypted)
file_write_handle.flush() file_write_handle.flush()
self.written_bytes += len(decrypted) self.written_bytes += len(decrypted)
if not self.started_writing.is_set():
self.started_writing.set()
self.update_status(ManagedStream.STATUS_FINISHED)
await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_FINISHED)
if self.analytics_manager:
self.loop.create_task(self.analytics_manager.send_download_finished(
self.download_id, self.claim_name, self.sd_hash
))
self.finished_writing.set() self.finished_writing.set()
except Exception as err: except Exception as err:
if os.path.isfile(output_path):
log.info("removing incomplete download %s for %s", output_path, self.sd_hash)
os.remove(output_path)
if not isinstance(err, asyncio.CancelledError): if not isinstance(err, asyncio.CancelledError):
log.exception("unexpected error encountered writing file for stream %s", self.sd_hash) log.exception("unexpected error encountered writing file for stream %s", self.sd_hash)
if os.path.isfile(output_path):
log.info("removing incomplete download %s", output_path)
os.remove(output_path)
raise err raise err
finally: finally:
self.saving.clear() self.saving.clear()
@ -282,10 +294,9 @@ class ManagedStream:
if self.file_output_task and not self.file_output_task.done(): if self.file_output_task and not self.file_output_task.done():
self.file_output_task.cancel() self.file_output_task.cancel()
if self.delayed_stop: if self.delayed_stop:
log.info('cancel delayed stop')
self.delayed_stop.cancel() self.delayed_stop.cancel()
self.delayed_stop = None self.delayed_stop = None
self.download_directory = download_directory or self.download_directory self.download_directory = download_directory or self.download_directory or self.config.download_dir
if not self.download_directory: if not self.download_directory:
raise ValueError("no directory to download to") raise ValueError("no directory to download to")
if not (file_name or self._file_name or self.descriptor.suggested_file_name): if not (file_name or self._file_name or self.descriptor.suggested_file_name):

View file

@ -6,8 +6,8 @@ import logging
import random import random
from decimal import Decimal from decimal import Decimal
from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError
# DownloadDataTimeout, DownloadSDTimeout from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout, ResolveTimeout
from lbrynet.utils import generate_id, cache_concurrent from lbrynet.utils import cache_concurrent
from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.managed_stream import ManagedStream from lbrynet.stream.managed_stream import ManagedStream
from lbrynet.schema.claim import Claim from lbrynet.schema.claim import Claim
@ -96,11 +96,10 @@ class StreamManager:
await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED) await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED)
async def start_stream(self, stream: ManagedStream): async def start_stream(self, stream: ManagedStream):
await stream.setup(self.node, save_file=not self.config.streaming_only)
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
stream.update_status(ManagedStream.STATUS_RUNNING) stream.update_status(ManagedStream.STATUS_RUNNING)
await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_RUNNING) await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_RUNNING)
self.wait_for_stream_finished(stream) await stream.setup(self.node, save_file=not self.config.streaming_only)
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
async def recover_streams(self, file_infos: typing.List[typing.Dict]): async def recover_streams(self, file_infos: typing.List[typing.Dict]):
to_restore = [] to_restore = []
@ -139,13 +138,14 @@ class StreamManager:
return return
stream = ManagedStream( stream = ManagedStream(
self.loop, self.config, self.blob_manager, descriptor.sd_hash, download_directory, file_name, status, self.loop, self.config, self.blob_manager, descriptor.sd_hash, download_directory, file_name, status,
claim, rowid=rowid, descriptor=descriptor claim, rowid=rowid, descriptor=descriptor, analytics_manager=self.analytics_manager
) )
self.streams[sd_hash] = stream self.streams[sd_hash] = stream
async def load_streams_from_database(self): async def load_streams_from_database(self):
to_recover = [] to_recover = []
to_start = [] to_start = []
await self.storage.sync_files_to_blobs()
for file_info in await self.storage.get_all_lbry_files(): for file_info in await self.storage.get_all_lbry_files():
if not self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified(): if not self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified():
to_recover.append(file_info) to_recover.append(file_info)
@ -181,10 +181,10 @@ class StreamManager:
while True: while True:
if self.config.reflect_streams and self.config.reflector_servers: if self.config.reflect_streams and self.config.reflector_servers:
sd_hashes = await self.storage.get_streams_to_re_reflect() sd_hashes = await self.storage.get_streams_to_re_reflect()
streams = list(filter(lambda s: s in sd_hashes, self.streams.keys())) sd_hashes = [sd for sd in sd_hashes if sd in self.streams]
batch = [] batch = []
while streams: while sd_hashes:
stream = streams.pop() stream = self.streams[sd_hashes.pop()]
if not stream.fully_reflected.is_set(): if not stream.fully_reflected.is_set():
host, port = random.choice(self.config.reflector_servers) host, port = random.choice(self.config.reflector_servers)
batch.append(stream.upload_to_reflector(host, port)) batch.append(stream.upload_to_reflector(host, port))
@ -198,7 +198,7 @@ class StreamManager:
async def start(self): async def start(self):
await self.load_streams_from_database() await self.load_streams_from_database()
self.resume_downloading_task = self.loop.create_task(self.resume()) self.resume_downloading_task = self.loop.create_task(self.resume())
# self.re_reflect_task = self.loop.create_task(self.reflect_streams()) self.re_reflect_task = self.loop.create_task(self.reflect_streams())
def stop(self): def stop(self):
if self.resume_downloading_task and not self.resume_downloading_task.done(): if self.resume_downloading_task and not self.resume_downloading_task.done():
@ -279,28 +279,11 @@ class StreamManager:
streams.reverse() streams.reverse()
return streams return streams
def wait_for_stream_finished(self, stream: ManagedStream):
async def _wait_for_stream_finished():
if stream.downloader and stream.running:
await stream.finished_writing.wait()
stream.update_status(ManagedStream.STATUS_FINISHED)
if self.analytics_manager:
self.loop.create_task(self.analytics_manager.send_download_finished(
stream.download_id, stream.claim_name, stream.sd_hash
))
task = self.loop.create_task(_wait_for_stream_finished())
self.update_stream_finished_futs.append(task)
task.add_done_callback(
lambda _: None if task not in self.update_stream_finished_futs else
self.update_stream_finished_futs.remove(task)
)
async def _check_update_or_replace(self, outpoint: str, claim_id: str, claim: Claim) -> typing.Tuple[ async def _check_update_or_replace(self, outpoint: str, claim_id: str, claim: Claim) -> typing.Tuple[
typing.Optional[ManagedStream], typing.Optional[ManagedStream]]: typing.Optional[ManagedStream], typing.Optional[ManagedStream]]:
existing = self.get_filtered_streams(outpoint=outpoint) existing = self.get_filtered_streams(outpoint=outpoint)
if existing: if existing:
if not existing[0].running: if existing[0].status == ManagedStream.STATUS_STOPPED:
await self.start_stream(existing[0]) await self.start_stream(existing[0])
return existing[0], None return existing[0], None
existing = self.get_filtered_streams(sd_hash=claim.stream.source.sd_hash) existing = self.get_filtered_streams(sd_hash=claim.stream.source.sd_hash)
@ -323,163 +306,112 @@ class StreamManager:
return None, existing_for_claim_id[0] return None, existing_for_claim_id[0]
return None, None return None, None
# async def start_downloader(self, got_descriptor_time: asyncio.Future, downloader: EncryptedStreamDownloader,
# download_id: str, outpoint: str, claim: Claim, resolved: typing.Dict,
# file_name: typing.Optional[str] = None) -> ManagedStream:
# start_time = self.loop.time()
# downloader.download(self.node)
# await downloader.got_descriptor.wait()
# got_descriptor_time.set_result(self.loop.time() - start_time)
# rowid = await self._store_stream(downloader)
# await self.storage.save_content_claim(
# downloader.descriptor.stream_hash, outpoint
# )
# stream = ManagedStream(self.loop, self.blob_manager, rowid, downloader.descriptor, self.config.download_dir,
# file_name, downloader, ManagedStream.STATUS_RUNNING, download_id=download_id)
# stream.set_claim(resolved, claim)
# await stream.downloader.wrote_bytes_event.wait()
# self.streams.add(stream)
# return stream
@cache_concurrent @cache_concurrent
async def download_stream_from_uri(self, uri, timeout: float, exchange_rate_manager: 'ExchangeRateManager', async def download_stream_from_uri(self, uri, exchange_rate_manager: 'ExchangeRateManager',
file_name: typing.Optional[str] = None) -> ManagedStream: timeout: typing.Optional[float] = None,
file_name: typing.Optional[str] = None,
download_directory: typing.Optional[str] = None,
save_file: bool = True, resolve_timeout: float = 3.0) -> ManagedStream:
timeout = timeout or self.config.download_timeout
start_time = self.loop.time() start_time = self.loop.time()
parsed_uri = parse_lbry_uri(uri) resolved_time = None
if parsed_uri.is_channel: stream = None
raise ResolveError("cannot download a channel claim, specify a /path") error = None
outpoint = None
# resolve the claim try:
resolved_result = await self.wallet.ledger.resolve(0, 10, uri) # resolve the claim
await self.storage.save_claims_for_resolve([ parsed_uri = parse_lbry_uri(uri)
value for value in resolved_result.values() if 'error' not in value if parsed_uri.is_channel:
]) raise ResolveError("cannot download a channel claim, specify a /path")
resolved = resolved_result.get(uri, {}) try:
resolved = resolved if 'value' in resolved else resolved.get('claim') resolved_result = await asyncio.wait_for(self.wallet.ledger.resolve(0, 1, uri), resolve_timeout)
if not resolved: except asyncio.TimeoutError:
raise ResolveError(f"Failed to resolve stream at '{uri}'") raise ResolveTimeout(uri)
if 'error' in resolved: await self.storage.save_claims_for_resolve([
raise ResolveError(f"error resolving stream: {resolved['error']}") value for value in resolved_result.values() if 'error' not in value
])
resolved = resolved_result.get(uri, {})
resolved = resolved if 'value' in resolved else resolved.get('claim')
if not resolved:
raise ResolveError(f"Failed to resolve stream at '{uri}'")
if 'error' in resolved:
raise ResolveError(f"error resolving stream: {resolved['error']}")
claim = Claim.from_bytes(binascii.unhexlify(resolved['protobuf'])) claim = Claim.from_bytes(binascii.unhexlify(resolved['protobuf']))
outpoint = f"{resolved['txid']}:{resolved['nout']}" outpoint = f"{resolved['txid']}:{resolved['nout']}"
resolved_time = self.loop.time() - start_time resolved_time = self.loop.time() - start_time
# resume or update an existing stream, if the stream changed download it and delete the old one after # resume or update an existing stream, if the stream changed download it and delete the old one after
updated_stream, to_replace = await self._check_update_or_replace(outpoint, resolved['claim_id'], claim) updated_stream, to_replace = await self._check_update_or_replace(outpoint, resolved['claim_id'], claim)
if updated_stream: if updated_stream:
return updated_stream return updated_stream
# check that the fee is payable # check that the fee is payable
fee_amount, fee_address = None, None if not to_replace and claim.stream.has_fee:
if claim.stream.has_fee: fee_amount = round(exchange_rate_manager.convert_currency(
fee_amount = round(exchange_rate_manager.convert_currency( claim.stream.fee.currency, "LBC", claim.stream.fee.amount
claim.stream.fee.currency, "LBC", claim.stream.fee.amount ), 5)
), 5) max_fee_amount = round(exchange_rate_manager.convert_currency(
max_fee_amount = round(exchange_rate_manager.convert_currency( self.config.max_key_fee['currency'], "LBC", Decimal(self.config.max_key_fee['amount'])
self.config.max_key_fee['currency'], "LBC", Decimal(self.config.max_key_fee['amount']) ), 5)
), 5) if fee_amount > max_fee_amount:
if fee_amount > max_fee_amount: msg = f"fee of {fee_amount} exceeds max configured to allow of {max_fee_amount}"
msg = f"fee of {fee_amount} exceeds max configured to allow of {max_fee_amount}" log.warning(msg)
log.warning(msg) raise KeyFeeAboveMaxAllowed(msg)
raise KeyFeeAboveMaxAllowed(msg) balance = await self.wallet.default_account.get_balance()
balance = await self.wallet.default_account.get_balance() if lbc_to_dewies(str(fee_amount)) > balance:
if lbc_to_dewies(str(fee_amount)) > balance: msg = f"fee of {fee_amount} exceeds max available balance"
msg = f"fee of {fee_amount} exceeds max available balance" log.warning(msg)
log.warning(msg) raise InsufficientFundsError(msg)
raise InsufficientFundsError(msg) fee_address = claim.stream.fee.address
fee_address = claim.stream.fee.address await self.wallet.send_amount_to_address(
# content_fee_tx = await self.wallet.send_amount_to_address( lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1')
# lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1') )
# ) log.info("paid fee of %s for %s", fee_amount, uri)
handled_fee_time = self.loop.time() - resolved_time - start_time download_directory = download_directory or self.config.download_dir
if not file_name and (self.config.streaming_only or not save_file):
# download the stream download_dir, file_name = None, None
download_id = binascii.hexlify(generate_id()).decode() stream = ManagedStream(
self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_directory,
download_dir = self.config.download_dir file_name, ManagedStream.STATUS_RUNNING, analytics_manager=self.analytics_manager
save_file = True )
if not file_name and self.config.streaming_only: try:
download_dir, file_name = None, None await asyncio.wait_for(stream.setup(
save_file = False self.node, save_file=save_file, file_name=file_name, download_directory=download_directory
stream = ManagedStream( ), timeout, loop=self.loop)
self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_dir, except asyncio.TimeoutError:
file_name, ManagedStream.STATUS_RUNNING, download_id=download_id if not stream.descriptor:
) raise DownloadSDTimeout(stream.sd_hash)
raise DownloadDataTimeout(stream.sd_hash)
await stream.setup(self.node, save_file=save_file) if to_replace: # delete old stream now that the replacement has started downloading
stream.set_claim(resolved, claim) await self.delete_stream(to_replace)
await self.storage.save_content_claim(stream.stream_hash, outpoint) stream.set_claim(resolved, claim)
self.streams[stream.sd_hash] = stream await self.storage.save_content_claim(stream.stream_hash, outpoint)
self.streams[stream.sd_hash] = stream
# stream = None return stream
# descriptor_time_fut = self.loop.create_future() except Exception as err:
# start_download_time = self.loop.time() error = err
# time_to_descriptor = None if stream and stream.descriptor:
# time_to_first_bytes = None await self.storage.delete_stream(stream.descriptor)
# error = None finally:
# try: if self.analytics_manager and (error or (stream and (stream.downloader.time_to_descriptor or
# stream = await asyncio.wait_for( stream.downloader.time_to_first_bytes))):
# asyncio.ensure_future( self.loop.create_task(
# self.start_downloader(descriptor_time_fut, downloader, download_id, outpoint, claim, resolved, self.analytics_manager.send_time_to_first_bytes(
# file_name) resolved_time, self.loop.time() - start_time, None if not stream else stream.download_id,
# ), timeout uri, outpoint,
# ) None if not stream else len(stream.downloader.blob_downloader.active_connections),
# time_to_descriptor = await descriptor_time_fut None if not stream else len(stream.downloader.blob_downloader.scores),
# time_to_first_bytes = self.loop.time() - start_download_time - time_to_descriptor False if not stream else stream.downloader.added_fixed_peers,
# self.wait_for_stream_finished(stream) self.config.fixed_peer_delay if not stream else stream.downloader.fixed_peers_delay,
# if fee_address and fee_amount and not to_replace: None if not stream else stream.sd_hash,
# None if not stream else stream.downloader.time_to_descriptor,
# elif to_replace: # delete old stream now that the replacement has started downloading None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].blob_hash,
# await self.delete_stream(to_replace) None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].length,
# except asyncio.TimeoutError: None if not stream else stream.downloader.time_to_first_bytes,
# if descriptor_time_fut.done(): None if not error else error.__class__.__name__
# time_to_descriptor = descriptor_time_fut.result() )
# error = DownloadDataTimeout(downloader.sd_hash) )
# self.blob_manager.delete_blob(downloader.sd_hash) if error:
# await self.storage.delete_stream(downloader.descriptor) raise error
# else:
# descriptor_time_fut.cancel()
# error = DownloadSDTimeout(downloader.sd_hash)
# if stream:
# await self.stop_stream(stream)
# else:
# downloader.stop()
# if error:
# log.warning(error)
# if self.analytics_manager:
# self.loop.create_task(
# self.analytics_manager.send_time_to_first_bytes(
# resolved_time, self.loop.time() - start_time, download_id, parse_lbry_uri(uri).name, outpoint,
# None if not stream else len(stream.downloader.blob_downloader.active_connections),
# None if not stream else len(stream.downloader.blob_downloader.scores),
# False if not downloader else downloader.added_fixed_peers,
# self.config.fixed_peer_delay if not downloader else downloader.fixed_peers_delay,
# claim.source_hash.decode(), time_to_descriptor,
# None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].blob_hash,
# None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].length,
# time_to_first_bytes, None if not error else error.__class__.__name__
# )
# )
# if error:
# raise error
return stream
# async def download_stream_from_uri(self, uri, exchange_rate_manager: 'ExchangeRateManager',
# file_name: typing.Optional[str] = None,
# timeout: typing.Optional[float] = None) -> ManagedStream:
# timeout = timeout or self.config.download_timeout
# if uri in self.starting_streams:
# return await self.starting_streams[uri]
# fut = asyncio.Future(loop=self.loop)
# self.starting_streams[uri] = fut
# try:
# stream = await self._download_stream_from_uri(uri, timeout, exchange_rate_manager, file_name)
# fut.set_result(stream)
# except Exception as err:
# fut.set_exception(err)
# try:
# return await fut
# finally:
# del self.starting_streams[uri]

View file

@ -28,9 +28,9 @@ class TestBlobfile(AsyncioTestCase):
self.assertEqual(blob.get_is_verified(), False) self.assertEqual(blob.get_is_verified(), False)
self.assertNotIn(blob_hash, blob_manager.completed_blob_hashes) self.assertNotIn(blob_hash, blob_manager.completed_blob_hashes)
writer = blob.open_for_writing() writer = blob.get_blob_writer()
writer.write(blob_bytes) writer.write(blob_bytes)
await blob.finished_writing.wait() await blob.verified.wait()
self.assertTrue(os.path.isfile(blob.file_path), True) self.assertTrue(os.path.isfile(blob.file_path), True)
self.assertEqual(blob.get_is_verified(), True) self.assertEqual(blob.get_is_verified(), True)
self.assertIn(blob_hash, blob_manager.completed_blob_hashes) self.assertIn(blob_hash, blob_manager.completed_blob_hashes)

View file

@ -11,7 +11,7 @@ from lbrynet.conf import Config
from lbrynet.extras.daemon.storage import SQLiteStorage from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_manager import BlobManager from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob_exchange.server import BlobServer, BlobServerProtocol from lbrynet.blob_exchange.server import BlobServer, BlobServerProtocol
from lbrynet.blob_exchange.client import BlobExchangeClientProtocol, request_blob from lbrynet.blob_exchange.client import request_blob
from lbrynet.dht.peer import KademliaPeer, PeerManager from lbrynet.dht.peer import KademliaPeer, PeerManager
# import logging # import logging
@ -58,9 +58,9 @@ class TestBlobExchange(BlobExchangeTestBase):
async def _add_blob_to_server(self, blob_hash: str, blob_bytes: bytes): async def _add_blob_to_server(self, blob_hash: str, blob_bytes: bytes):
# add the blob on the server # add the blob on the server
server_blob = self.server_blob_manager.get_blob(blob_hash, len(blob_bytes)) server_blob = self.server_blob_manager.get_blob(blob_hash, len(blob_bytes))
writer = server_blob.open_for_writing() writer = server_blob.get_blob_writer()
writer.write(blob_bytes) writer.write(blob_bytes)
await server_blob.finished_writing.wait() await server_blob.verified.wait()
self.assertTrue(os.path.isfile(server_blob.file_path)) self.assertTrue(os.path.isfile(server_blob.file_path))
self.assertEqual(server_blob.get_is_verified(), True) self.assertEqual(server_blob.get_is_verified(), True)
@ -68,11 +68,14 @@ class TestBlobExchange(BlobExchangeTestBase):
client_blob = self.client_blob_manager.get_blob(blob_hash) client_blob = self.client_blob_manager.get_blob(blob_hash)
# download the blob # download the blob
downloaded = await request_blob(self.loop, client_blob, self.server_from_client.address, downloaded, transport = await request_blob(self.loop, client_blob, self.server_from_client.address,
self.server_from_client.tcp_port, 2, 3) self.server_from_client.tcp_port, 2, 3)
await client_blob.finished_writing.wait() self.assertIsNotNone(transport)
self.addCleanup(transport.close)
await client_blob.verified.wait()
self.assertEqual(client_blob.get_is_verified(), True) self.assertEqual(client_blob.get_is_verified(), True)
self.assertTrue(downloaded) self.assertTrue(downloaded)
self.addCleanup(client_blob.close)
async def test_transfer_sd_blob(self): async def test_transfer_sd_blob(self):
sd_hash = "3e2706157a59aaa47ef52bc264fce488078b4026c0b9bab649a8f2fe1ecc5e5cad7182a2bb7722460f856831a1ac0f02" sd_hash = "3e2706157a59aaa47ef52bc264fce488078b4026c0b9bab649a8f2fe1ecc5e5cad7182a2bb7722460f856831a1ac0f02"
@ -112,7 +115,7 @@ class TestBlobExchange(BlobExchangeTestBase):
), ),
self._test_transfer_blob(blob_hash) self._test_transfer_blob(blob_hash)
) )
await second_client_blob.finished_writing.wait() await second_client_blob.verified.wait()
self.assertEqual(second_client_blob.get_is_verified(), True) self.assertEqual(second_client_blob.get_is_verified(), True)
async def test_host_different_blobs_to_multiple_peers_at_once(self): async def test_host_different_blobs_to_multiple_peers_at_once(self):
@ -143,7 +146,7 @@ class TestBlobExchange(BlobExchangeTestBase):
server_from_second_client.tcp_port, 2, 3 server_from_second_client.tcp_port, 2, 3
), ),
self._test_transfer_blob(sd_hash), self._test_transfer_blob(sd_hash),
second_client_blob.finished_writing.wait() second_client_blob.verified.wait()
) )
self.assertEqual(second_client_blob.get_is_verified(), True) self.assertEqual(second_client_blob.get_is_verified(), True)

View file

@ -1,12 +1,13 @@
import asyncio import asyncio
from torba.testcase import AsyncioTestCase from unittest import mock, TestCase
from lbrynet.dht.protocol.data_store import DictDataStore from lbrynet.dht.protocol.data_store import DictDataStore
from lbrynet.dht.peer import PeerManager from lbrynet.dht.peer import PeerManager
class DataStoreTests(AsyncioTestCase): class DataStoreTests(TestCase):
def setUp(self): def setUp(self):
self.loop = asyncio.get_event_loop() self.loop = mock.Mock(spec=asyncio.BaseEventLoop)
self.loop.time = lambda: 0.0
self.peer_manager = PeerManager(self.loop) self.peer_manager = PeerManager(self.loop)
self.data_store = DictDataStore(self.loop, self.peer_manager) self.data_store = DictDataStore(self.loop, self.peer_manager)

View file

@ -1,117 +0,0 @@
import os
import asyncio
import tempfile
import shutil
from torba.testcase import AsyncioTestCase
from lbrynet.conf import Config
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.stream.assembler import StreamAssembler
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.stream_manager import StreamManager
class TestStreamAssembler(AsyncioTestCase):
def setUp(self):
self.loop = asyncio.get_event_loop()
self.key = b'deadbeef' * 4
self.cleartext = b'test'
async def test_create_and_decrypt_one_blob_stream(self, corrupt=False):
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
self.storage = SQLiteStorage(Config(), ":memory:")
await self.storage.open()
self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage)
download_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(download_dir))
# create the stream
file_path = os.path.join(tmp_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(self.cleartext)
sd = await StreamDescriptor.create_stream(self.loop, tmp_dir, file_path, key=self.key)
# copy blob files
sd_hash = sd.calculate_sd_hash()
shutil.copy(os.path.join(tmp_dir, sd_hash), os.path.join(download_dir, sd_hash))
for blob_info in sd.blobs:
if blob_info.blob_hash:
shutil.copy(os.path.join(tmp_dir, blob_info.blob_hash), os.path.join(download_dir, blob_info.blob_hash))
if corrupt and blob_info.length == MAX_BLOB_SIZE:
with open(os.path.join(download_dir, blob_info.blob_hash), "rb+") as handle:
handle.truncate()
handle.flush()
downloader_storage = SQLiteStorage(Config(), os.path.join(download_dir, "lbrynet.sqlite"))
await downloader_storage.open()
# add the blobs to the blob table (this would happen upon a blob download finishing)
downloader_blob_manager = BlobManager(self.loop, download_dir, downloader_storage)
descriptor = await downloader_blob_manager.get_stream_descriptor(sd_hash)
# assemble the decrypted file
assembler = StreamAssembler(self.loop, downloader_blob_manager, descriptor.sd_hash)
await assembler.assemble_decrypted_stream(download_dir)
if corrupt:
return self.assertFalse(os.path.isfile(os.path.join(download_dir, "test_file")))
with open(os.path.join(download_dir, "test_file"), "rb") as f:
decrypted = f.read()
self.assertEqual(decrypted, self.cleartext)
self.assertEqual(True, self.blob_manager.get_blob(sd_hash).get_is_verified())
self.assertEqual(True, self.blob_manager.get_blob(descriptor.blobs[0].blob_hash).get_is_verified())
# its all blobs + sd blob - last blob, which is the same size as descriptor.blobs
self.assertEqual(len(descriptor.blobs), len(await downloader_storage.get_all_finished_blobs()))
self.assertEqual(
[descriptor.sd_hash, descriptor.blobs[0].blob_hash], await downloader_storage.get_blobs_to_announce()
)
await downloader_storage.close()
await self.storage.close()
async def test_create_and_decrypt_multi_blob_stream(self):
self.cleartext = b'test\n' * 20000000
await self.test_create_and_decrypt_one_blob_stream()
async def test_create_and_decrypt_padding(self):
for i in range(16):
self.cleartext = os.urandom((MAX_BLOB_SIZE*2) + i)
await self.test_create_and_decrypt_one_blob_stream()
for i in range(16):
self.cleartext = os.urandom((MAX_BLOB_SIZE*2) - i)
await self.test_create_and_decrypt_one_blob_stream()
async def test_create_and_decrypt_random(self):
self.cleartext = os.urandom(20000000)
await self.test_create_and_decrypt_one_blob_stream()
async def test_create_managed_stream_announces(self):
# setup a blob manager
storage = SQLiteStorage(Config(), ":memory:")
await storage.open()
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
blob_manager = BlobManager(self.loop, tmp_dir, storage)
stream_manager = StreamManager(self.loop, Config(), blob_manager, None, storage, None)
# create the stream
download_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(download_dir))
file_path = os.path.join(download_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(b'testtest')
stream = await stream_manager.create_stream(file_path)
self.assertEqual(
[stream.sd_hash, stream.descriptor.blobs[0].blob_hash],
await storage.get_blobs_to_announce())
async def test_create_truncate_and_handle_stream(self):
self.cleartext = b'potato' * 1337 * 5279
# The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated
await asyncio.wait_for(self.test_create_and_decrypt_one_blob_stream(corrupt=True), timeout=5)

View file

@ -1,102 +0,0 @@
import os
import time
import unittest
from unittest import mock
import asyncio
from lbrynet.blob_exchange.serialization import BlobResponse
from lbrynet.blob_exchange.server import BlobServerProtocol
from lbrynet.conf import Config
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.downloader import StreamDownloader
from lbrynet.dht.node import Node
from lbrynet.dht.peer import KademliaPeer
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase
class TestStreamDownloader(BlobExchangeTestBase):
async def setup_stream(self, blob_count: int = 10):
self.stream_bytes = b''
for _ in range(blob_count):
self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1))
# create the stream
file_path = os.path.join(self.server_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(self.stream_bytes)
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
self.sd_hash = descriptor.calculate_sd_hash()
conf = Config(data_dir=self.server_dir, wallet_dir=self.server_dir, download_dir=self.server_dir,
reflector_servers=[])
self.downloader = StreamDownloader(self.loop, conf, self.client_blob_manager, self.sd_hash)
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
await self.setup_stream(blob_count)
mock_node = mock.Mock(spec=Node)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
self.downloader.download(mock_node)
await self.downloader.stream_finished_event.wait()
self.assertTrue(self.downloader.stream_handle.closed)
self.assertTrue(os.path.isfile(self.downloader.output_path))
self.downloader.stop()
self.assertIs(self.downloader.stream_handle, None)
self.assertTrue(os.path.isfile(self.downloader.output_path))
with open(self.downloader.output_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
await asyncio.sleep(0.01)
async def test_transfer_stream(self):
await self._test_transfer_stream(10)
@unittest.SkipTest
async def test_transfer_hundred_blob_stream(self):
await self._test_transfer_stream(100)
async def test_transfer_stream_bad_first_peer_good_second(self):
await self.setup_stream(2)
mock_node = mock.Mock(spec=Node)
q = asyncio.Queue()
bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([bad_peer])
self.loop.call_later(1, q2.put_nowait, [self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = _mock_accumulate_peers
self.downloader.download(mock_node)
await self.downloader.stream_finished_event.wait()
self.assertTrue(os.path.isfile(self.downloader.output_path))
with open(self.downloader.output_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
# self.assertIs(self.server_from_client.tcp_last_down, None)
# self.assertIsNot(bad_peer.tcp_last_down, None)
async def test_client_chunked_response(self):
self.server.stop_server()
class ChunkedServerProtocol(BlobServerProtocol):
def send_response(self, responses):
to_send = []
while responses:
to_send.append(responses.pop())
for byte in BlobResponse(to_send).serialize():
self.transport.write(bytes([byte]))
self.server.server_protocol_class = ChunkedServerProtocol
self.server.start_server(33333, '127.0.0.1')
self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes))
await asyncio.wait_for(self._test_transfer_stream(10), timeout=2)
self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes))

View file

@ -0,0 +1,175 @@
import os
import shutil
import unittest
from unittest import mock
import asyncio
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
from lbrynet.blob_exchange.serialization import BlobResponse
from lbrynet.blob_exchange.server import BlobServerProtocol
from lbrynet.dht.node import Node
from lbrynet.dht.peer import KademliaPeer
from lbrynet.extras.daemon.storage import StoredStreamClaim
from lbrynet.stream.managed_stream import ManagedStream
from lbrynet.stream.descriptor import StreamDescriptor
from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase
def get_mock_node(loop):
mock_node = mock.Mock(spec=Node)
mock_node.joined = asyncio.Event(loop=loop)
mock_node.joined.set()
return mock_node
class TestManagedStream(BlobExchangeTestBase):
async def create_stream(self, blob_count: int = 10):
self.stream_bytes = b''
for _ in range(blob_count):
self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1))
# create the stream
file_path = os.path.join(self.server_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(self.stream_bytes)
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
self.sd_hash = descriptor.calculate_sd_hash()
return descriptor
async def setup_stream(self, blob_count: int = 10):
await self.create_stream(blob_count)
self.stream = ManagedStream(self.loop, self.client_config, self.client_blob_manager, self.sd_hash,
self.client_dir)
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
await self.setup_stream(blob_count)
mock_node = mock.Mock(spec=Node)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
await self.stream.setup(mock_node, save_file=True)
await self.stream.finished_writing.wait()
self.assertTrue(os.path.isfile(self.stream.full_path))
self.stream.stop_download()
self.assertTrue(os.path.isfile(self.stream.full_path))
with open(self.stream.full_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
await asyncio.sleep(0.01)
async def test_transfer_stream(self):
await self._test_transfer_stream(10)
@unittest.SkipTest
async def test_transfer_hundred_blob_stream(self):
await self._test_transfer_stream(100)
async def test_transfer_stream_bad_first_peer_good_second(self):
await self.setup_stream(2)
mock_node = mock.Mock(spec=Node)
q = asyncio.Queue()
bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([bad_peer])
self.loop.call_later(1, q2.put_nowait, [self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = _mock_accumulate_peers
await self.stream.setup(mock_node, save_file=True)
await self.stream.finished_writing.wait()
self.assertTrue(os.path.isfile(self.stream.full_path))
with open(self.stream.full_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
# self.assertIs(self.server_from_client.tcp_last_down, None)
# self.assertIsNot(bad_peer.tcp_last_down, None)
async def test_client_chunked_response(self):
self.server.stop_server()
class ChunkedServerProtocol(BlobServerProtocol):
def send_response(self, responses):
to_send = []
while responses:
to_send.append(responses.pop())
for byte in BlobResponse(to_send).serialize():
self.transport.write(bytes([byte]))
self.server.server_protocol_class = ChunkedServerProtocol
self.server.start_server(33333, '127.0.0.1')
self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes))
await asyncio.wait_for(self._test_transfer_stream(10), timeout=2)
self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes))
async def test_create_and_decrypt_one_blob_stream(self, blobs=1, corrupt=False):
descriptor = await self.create_stream(blobs)
# copy blob files
shutil.copy(os.path.join(self.server_blob_manager.blob_dir, self.sd_hash),
os.path.join(self.client_blob_manager.blob_dir, self.sd_hash))
self.stream = ManagedStream(self.loop, self.client_config, self.client_blob_manager, self.sd_hash,
self.client_dir)
for blob_info in descriptor.blobs[:-1]:
shutil.copy(os.path.join(self.server_blob_manager.blob_dir, blob_info.blob_hash),
os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash))
if corrupt and blob_info.length == MAX_BLOB_SIZE:
with open(os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash), "rb+") as handle:
handle.truncate()
handle.flush()
await self.stream.setup()
await self.stream.finished_writing.wait()
if corrupt:
return self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
with open(os.path.join(self.client_dir, "test_file"), "rb") as f:
decrypted = f.read()
self.assertEqual(decrypted, self.stream_bytes)
self.assertEqual(True, self.client_blob_manager.get_blob(self.sd_hash).get_is_verified())
self.assertEqual(
True, self.client_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified()
)
#
# # its all blobs + sd blob - last blob, which is the same size as descriptor.blobs
# self.assertEqual(len(descriptor.blobs), len(await downloader_storage.get_all_finished_blobs()))
# self.assertEqual(
# [descriptor.sd_hash, descriptor.blobs[0].blob_hash], await downloader_storage.get_blobs_to_announce()
# )
#
# await downloader_storage.close()
# await self.storage.close()
async def test_create_and_decrypt_multi_blob_stream(self):
await self.test_create_and_decrypt_one_blob_stream(10)
# async def test_create_managed_stream_announces(self):
# # setup a blob manager
# storage = SQLiteStorage(Config(), ":memory:")
# await storage.open()
# tmp_dir = tempfile.mkdtemp()
# self.addCleanup(lambda: shutil.rmtree(tmp_dir))
# blob_manager = BlobManager(self.loop, tmp_dir, storage)
# stream_manager = StreamManager(self.loop, Config(), blob_manager, None, storage, None)
# # create the stream
# download_dir = tempfile.mkdtemp()
# self.addCleanup(lambda: shutil.rmtree(download_dir))
# file_path = os.path.join(download_dir, "test_file")
# with open(file_path, 'wb') as f:
# f.write(b'testtest')
#
# stream = await stream_manager.create_stream(file_path)
# self.assertEqual(
# [stream.sd_hash, stream.descriptor.blobs[0].blob_hash],
# await storage.get_blobs_to_announce())
# async def test_create_truncate_and_handle_stream(self):
# # The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated
# await asyncio.wait_for(self.test_create_and_decrypt_one_blob_stream(corrupt=True), timeout=5)

View file

@ -99,7 +99,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase):
blob = blob_manager.get_blob(sd_hash) blob = blob_manager.get_blob(sd_hash)
blob.set_length(len(sd_bytes)) blob.set_length(len(sd_bytes))
writer = blob.open_for_writing() writer = blob.get_blob_writer()
writer.write(sd_bytes) writer.write(sd_bytes)
await blob.verified.wait() await blob.verified.wait()
descriptor = await StreamDescriptor.from_stream_descriptor_blob( descriptor = await StreamDescriptor.from_stream_descriptor_blob(
@ -116,7 +116,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase):
sd_hash = '9313d1807551186126acc3662e74d9de29cede78d4f133349ace846273ef116b9bb86be86c54509eb84840e4b032f6b2' sd_hash = '9313d1807551186126acc3662e74d9de29cede78d4f133349ace846273ef116b9bb86be86c54509eb84840e4b032f6b2'
with open(os.path.join(tmp_dir, sd_hash), 'wb') as handle: with open(os.path.join(tmp_dir, sd_hash), 'wb') as handle:
handle.write(b'doesnt work') handle.write(b'doesnt work')
blob = BlobFile(loop, tmp_dir, sd_hash) blob = BlobFile(loop, sd_hash, blob_directory=tmp_dir)
self.assertTrue(blob.file_exists) self.assertTrue(blob.file_exists)
self.assertIsNotNone(blob.length) self.assertIsNotNone(blob.length)
with self.assertRaises(InvalidStreamDescriptorError): with self.assertRaises(InvalidStreamDescriptorError):

View file

@ -192,8 +192,8 @@ class TestStreamManager(BlobExchangeTestBase):
def check_post(event): def check_post(event):
self.assertEqual(event['event'], 'Time To First Bytes') self.assertEqual(event['event'], 'Time To First Bytes')
self.assertEqual(event['properties']['error'], 'DownloadSDTimeout') self.assertEqual(event['properties']['error'], 'DownloadSDTimeout')
self.assertEqual(event['properties']['tried_peers_count'], None) self.assertEqual(event['properties']['tried_peers_count'], 0)
self.assertEqual(event['properties']['active_peer_count'], None) self.assertEqual(event['properties']['active_peer_count'], 0)
self.assertEqual(event['properties']['use_fixed_peers'], False) self.assertEqual(event['properties']['use_fixed_peers'], False)
self.assertEqual(event['properties']['added_fixed_peers'], False) self.assertEqual(event['properties']['added_fixed_peers'], False)
self.assertEqual(event['properties']['fixed_peer_delay'], None) self.assertEqual(event['properties']['fixed_peer_delay'], None)
@ -213,10 +213,10 @@ class TestStreamManager(BlobExchangeTestBase):
self.stream_manager.analytics_manager._post = check_post self.stream_manager.analytics_manager._post = check_post
self.assertSetEqual(self.stream_manager.streams, set()) self.assertDictEqual(self.stream_manager.streams, {})
stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
stream_hash = stream.stream_hash stream_hash = stream.stream_hash
self.assertSetEqual(self.stream_manager.streams, {stream}) self.assertDictEqual(self.stream_manager.streams, {stream.sd_hash: stream})
self.assertTrue(stream.running) self.assertTrue(stream.running)
self.assertFalse(stream.finished) self.assertFalse(stream.finished)
self.assertTrue(os.path.isfile(os.path.join(self.client_dir, "test_file"))) self.assertTrue(os.path.isfile(os.path.join(self.client_dir, "test_file")))
@ -236,7 +236,7 @@ class TestStreamManager(BlobExchangeTestBase):
self.assertEqual(stored_status, "stopped") self.assertEqual(stored_status, "stopped")
await self.stream_manager.start_stream(stream) await self.stream_manager.start_stream(stream)
await stream.downloader.stream_finished_event.wait() await stream.finished_writing.wait()
await asyncio.sleep(0, loop=self.loop) await asyncio.sleep(0, loop=self.loop)
self.assertTrue(stream.finished) self.assertTrue(stream.finished)
self.assertFalse(stream.running) self.assertFalse(stream.running)
@ -247,7 +247,7 @@ class TestStreamManager(BlobExchangeTestBase):
self.assertEqual(stored_status, "finished") self.assertEqual(stored_status, "finished")
await self.stream_manager.delete_stream(stream, True) await self.stream_manager.delete_stream(stream, True)
self.assertSetEqual(self.stream_manager.streams, set()) self.assertDictEqual(self.stream_manager.streams, {})
self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file"))) self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
stored_status = await self.client_storage.run_and_return_one_or_none( stored_status = await self.client_storage.run_and_return_one_or_none(
"select status from file where stream_hash=?", stream_hash "select status from file where stream_hash=?", stream_hash
@ -257,7 +257,7 @@ class TestStreamManager(BlobExchangeTestBase):
async def _test_download_error_on_start(self, expected_error, timeout=None): async def _test_download_error_on_start(self, expected_error, timeout=None):
with self.assertRaises(expected_error): with self.assertRaises(expected_error):
await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout=timeout) await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout)
async def _test_download_error_analytics_on_start(self, expected_error, timeout=None): async def _test_download_error_analytics_on_start(self, expected_error, timeout=None):
received = [] received = []
@ -321,9 +321,9 @@ class TestStreamManager(BlobExchangeTestBase):
await self.setup_stream_manager(old_sort=old_sort) await self.setup_stream_manager(old_sort=old_sort)
self.stream_manager.analytics_manager._post = check_post self.stream_manager.analytics_manager._post = check_post
self.assertSetEqual(self.stream_manager.streams, set()) self.assertDictEqual(self.stream_manager.streams, {})
stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
await stream.downloader.stream_finished_event.wait() await stream.finished_writing.wait()
await asyncio.sleep(0, loop=self.loop) await asyncio.sleep(0, loop=self.loop)
self.stream_manager.stop() self.stream_manager.stop()
self.client_blob_manager.stop() self.client_blob_manager.stop()
@ -333,8 +333,11 @@ class TestStreamManager(BlobExchangeTestBase):
await self.client_blob_manager.setup() await self.client_blob_manager.setup()
await self.stream_manager.start() await self.stream_manager.start()
self.assertEqual(1, len(self.stream_manager.streams)) self.assertEqual(1, len(self.stream_manager.streams))
self.assertEqual(stream.sd_hash, list(self.stream_manager.streams)[0].sd_hash) self.assertListEqual([self.sd_hash], list(self.stream_manager.streams.keys()))
self.assertEqual('stopped', list(self.stream_manager.streams)[0].status) for blob_hash in [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]]:
blob_status = await self.client_storage.get_blob_status(blob_hash)
self.assertEqual('pending', blob_status)
self.assertEqual('stopped', self.stream_manager.streams[self.sd_hash].status)
sd_blob = self.client_blob_manager.get_blob(stream.sd_hash) sd_blob = self.client_blob_manager.get_blob(stream.sd_hash)
self.assertTrue(sd_blob.file_exists) self.assertTrue(sd_blob.file_exists)