forked from LBRYCommunity/lbry-sdk
tests
This commit is contained in:
parent
f125468ebf
commit
3a916a8e8e
14 changed files with 381 additions and 441 deletions
|
@ -32,6 +32,12 @@ class DownloadDataTimeout(Exception):
|
|||
self.download = download
|
||||
|
||||
|
||||
class ResolveTimeout(Exception):
|
||||
def __init__(self, uri):
|
||||
super().__init__(f'Failed to resolve "{uri}" within the timeout')
|
||||
self.uri = uri
|
||||
|
||||
|
||||
class RequestCanceledError(Exception):
|
||||
pass
|
||||
|
||||
|
|
|
@ -897,7 +897,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
"""
|
||||
try:
|
||||
stream = await self.stream_manager.download_stream_from_uri(
|
||||
uri, timeout, self.exchange_rate_manager, file_name
|
||||
uri, self.exchange_rate_manager, timeout, file_name
|
||||
)
|
||||
if not stream:
|
||||
raise DownloadSDTimeout(uri)
|
||||
|
|
|
@ -423,6 +423,17 @@ class SQLiteStorage(SQLiteMixin):
|
|||
}
|
||||
return self.db.run(_sync_blobs)
|
||||
|
||||
def sync_files_to_blobs(self):
|
||||
def _sync_blobs(transaction: sqlite3.Connection) -> typing.Set[str]:
|
||||
transaction.executemany(
|
||||
"update file set status='stopped' where stream_hash=?",
|
||||
transaction.execute(
|
||||
"select distinct sb.stream_hash from stream_blob sb "
|
||||
"inner join blob b on b.blob_hash=sb.blob_hash and b.status=='pending'"
|
||||
).fetchall()
|
||||
)
|
||||
return self.db.run(_sync_blobs)
|
||||
|
||||
# # # # # # # # # stream functions # # # # # # # # #
|
||||
|
||||
async def stream_exists(self, sd_hash: str) -> bool:
|
||||
|
|
|
@ -2,6 +2,7 @@ import asyncio
|
|||
import typing
|
||||
import logging
|
||||
import binascii
|
||||
from lbrynet.error import DownloadSDTimeout
|
||||
from lbrynet.utils import resolve_host
|
||||
from lbrynet.stream.descriptor import StreamDescriptor
|
||||
from lbrynet.blob_exchange.downloader import BlobDownloader
|
||||
|
@ -32,6 +33,8 @@ class StreamDownloader:
|
|||
self.fixed_peers_handle: typing.Optional[asyncio.Handle] = None
|
||||
self.fixed_peers_delay: typing.Optional[float] = None
|
||||
self.added_fixed_peers = False
|
||||
self.time_to_descriptor: typing.Optional[float] = None
|
||||
self.time_to_first_bytes: typing.Optional[float] = None
|
||||
|
||||
async def add_fixed_peers(self):
|
||||
def _delayed_add_fixed_peers():
|
||||
|
@ -59,8 +62,16 @@ class StreamDownloader:
|
|||
# download or get the sd blob
|
||||
sd_blob = self.blob_manager.get_blob(self.sd_hash)
|
||||
if not sd_blob.get_is_verified():
|
||||
sd_blob = await self.blob_downloader.download_blob(self.sd_hash)
|
||||
log.info("downloaded sd blob %s", self.sd_hash)
|
||||
try:
|
||||
now = self.loop.time()
|
||||
sd_blob = await asyncio.wait_for(
|
||||
self.blob_downloader.download_blob(self.sd_hash),
|
||||
self.config.blob_download_timeout, loop=self.loop
|
||||
)
|
||||
log.info("downloaded sd blob %s", self.sd_hash)
|
||||
self.time_to_descriptor = self.loop.time() - now
|
||||
except asyncio.TimeoutError:
|
||||
raise DownloadSDTimeout(self.sd_hash)
|
||||
|
||||
# parse the descriptor
|
||||
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
|
||||
|
@ -101,12 +112,18 @@ class StreamDownloader:
|
|||
binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode())
|
||||
)
|
||||
|
||||
async def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob'):
|
||||
async def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes:
|
||||
return await self.loop.run_in_executor(None, self._decrypt_blob, blob_info, blob)
|
||||
|
||||
async def read_blob(self, blob_info: 'BlobInfo') -> bytes:
|
||||
start = None
|
||||
if self.time_to_first_bytes is None:
|
||||
start = self.loop.time()
|
||||
blob = await self.download_stream_blob(blob_info)
|
||||
return await self.decrypt_blob(blob_info, blob)
|
||||
decrypted = await self.decrypt_blob(blob_info, blob)
|
||||
if start:
|
||||
self.time_to_first_bytes = self.loop.time() - start
|
||||
return decrypted
|
||||
|
||||
def stop(self):
|
||||
if self.accumulate_task:
|
||||
|
|
|
@ -9,13 +9,13 @@ from lbrynet.stream.downloader import StreamDownloader
|
|||
from lbrynet.stream.descriptor import StreamDescriptor
|
||||
from lbrynet.stream.reflector.client import StreamReflectorClient
|
||||
from lbrynet.extras.daemon.storage import StoredStreamClaim
|
||||
from lbrynet.blob import MAX_BLOB_SIZE
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbrynet.conf import Config
|
||||
from lbrynet.schema.claim import Claim
|
||||
from lbrynet.blob.blob_manager import BlobManager
|
||||
from lbrynet.blob.blob_info import BlobInfo
|
||||
from lbrynet.dht.node import Node
|
||||
from lbrynet.extras.daemon.analytics import AnalyticsManager
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -43,7 +43,8 @@ class ManagedStream:
|
|||
sd_hash: str, download_directory: typing.Optional[str] = None, file_name: typing.Optional[str] = None,
|
||||
status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredStreamClaim] = None,
|
||||
download_id: typing.Optional[str] = None, rowid: typing.Optional[int] = None,
|
||||
descriptor: typing.Optional[StreamDescriptor] = None):
|
||||
descriptor: typing.Optional[StreamDescriptor] = None,
|
||||
analytics_manager: typing.Optional['AnalyticsManager'] = None):
|
||||
self.loop = loop
|
||||
self.config = config
|
||||
self.blob_manager = blob_manager
|
||||
|
@ -56,11 +57,13 @@ class ManagedStream:
|
|||
self.rowid = rowid
|
||||
self.written_bytes = 0
|
||||
self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor)
|
||||
self.analytics_manager = analytics_manager
|
||||
self.fully_reflected = asyncio.Event(loop=self.loop)
|
||||
self.file_output_task: typing.Optional[asyncio.Task] = None
|
||||
self.delayed_stop: typing.Optional[asyncio.Handle] = None
|
||||
self.saving = asyncio.Event(loop=self.loop)
|
||||
self.finished_writing = asyncio.Event(loop=self.loop)
|
||||
self.started_writing = asyncio.Event(loop=self.loop)
|
||||
|
||||
@property
|
||||
def descriptor(self) -> StreamDescriptor:
|
||||
|
@ -217,16 +220,18 @@ class ManagedStream:
|
|||
return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path),
|
||||
os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor)
|
||||
|
||||
async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True):
|
||||
async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True,
|
||||
file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None):
|
||||
await self.downloader.start(node)
|
||||
if not save_file:
|
||||
if not save_file and not file_name:
|
||||
if not await self.blob_manager.storage.file_exists(self.sd_hash):
|
||||
self.rowid = self.blob_manager.storage.save_downloaded_file(
|
||||
self.stream_hash, None, None, 0.0
|
||||
)
|
||||
self.update_delayed_stop()
|
||||
else:
|
||||
await self.save_file()
|
||||
await self.save_file(file_name, download_directory)
|
||||
await self.started_writing.wait()
|
||||
self.update_status(ManagedStream.STATUS_RUNNING)
|
||||
await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING)
|
||||
|
||||
|
@ -235,7 +240,6 @@ class ManagedStream:
|
|||
log.info("Stopping inactive download for stream %s", self.sd_hash)
|
||||
self.stop_download()
|
||||
|
||||
log.info("update delayed stop")
|
||||
if self.delayed_stop:
|
||||
self.delayed_stop.cancel()
|
||||
self.delayed_stop = self.loop.call_later(60, _delayed_stop)
|
||||
|
@ -259,6 +263,7 @@ class ManagedStream:
|
|||
async def _save_file(self, output_path: str):
|
||||
self.saving.set()
|
||||
self.finished_writing.clear()
|
||||
self.started_writing.clear()
|
||||
try:
|
||||
with open(output_path, 'wb') as file_write_handle:
|
||||
async for blob_info, decrypted in self.aiter_read_stream():
|
||||
|
@ -266,14 +271,21 @@ class ManagedStream:
|
|||
file_write_handle.write(decrypted)
|
||||
file_write_handle.flush()
|
||||
self.written_bytes += len(decrypted)
|
||||
|
||||
if not self.started_writing.is_set():
|
||||
self.started_writing.set()
|
||||
self.update_status(ManagedStream.STATUS_FINISHED)
|
||||
await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_FINISHED)
|
||||
if self.analytics_manager:
|
||||
self.loop.create_task(self.analytics_manager.send_download_finished(
|
||||
self.download_id, self.claim_name, self.sd_hash
|
||||
))
|
||||
self.finished_writing.set()
|
||||
except Exception as err:
|
||||
if os.path.isfile(output_path):
|
||||
log.info("removing incomplete download %s for %s", output_path, self.sd_hash)
|
||||
os.remove(output_path)
|
||||
if not isinstance(err, asyncio.CancelledError):
|
||||
log.exception("unexpected error encountered writing file for stream %s", self.sd_hash)
|
||||
if os.path.isfile(output_path):
|
||||
log.info("removing incomplete download %s", output_path)
|
||||
os.remove(output_path)
|
||||
raise err
|
||||
finally:
|
||||
self.saving.clear()
|
||||
|
@ -282,10 +294,9 @@ class ManagedStream:
|
|||
if self.file_output_task and not self.file_output_task.done():
|
||||
self.file_output_task.cancel()
|
||||
if self.delayed_stop:
|
||||
log.info('cancel delayed stop')
|
||||
self.delayed_stop.cancel()
|
||||
self.delayed_stop = None
|
||||
self.download_directory = download_directory or self.download_directory
|
||||
self.download_directory = download_directory or self.download_directory or self.config.download_dir
|
||||
if not self.download_directory:
|
||||
raise ValueError("no directory to download to")
|
||||
if not (file_name or self._file_name or self.descriptor.suggested_file_name):
|
||||
|
|
|
@ -6,8 +6,8 @@ import logging
|
|||
import random
|
||||
from decimal import Decimal
|
||||
from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError
|
||||
# DownloadDataTimeout, DownloadSDTimeout
|
||||
from lbrynet.utils import generate_id, cache_concurrent
|
||||
from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout, ResolveTimeout
|
||||
from lbrynet.utils import cache_concurrent
|
||||
from lbrynet.stream.descriptor import StreamDescriptor
|
||||
from lbrynet.stream.managed_stream import ManagedStream
|
||||
from lbrynet.schema.claim import Claim
|
||||
|
@ -96,11 +96,10 @@ class StreamManager:
|
|||
await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED)
|
||||
|
||||
async def start_stream(self, stream: ManagedStream):
|
||||
await stream.setup(self.node, save_file=not self.config.streaming_only)
|
||||
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
|
||||
stream.update_status(ManagedStream.STATUS_RUNNING)
|
||||
await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_RUNNING)
|
||||
self.wait_for_stream_finished(stream)
|
||||
await stream.setup(self.node, save_file=not self.config.streaming_only)
|
||||
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
|
||||
|
||||
async def recover_streams(self, file_infos: typing.List[typing.Dict]):
|
||||
to_restore = []
|
||||
|
@ -139,13 +138,14 @@ class StreamManager:
|
|||
return
|
||||
stream = ManagedStream(
|
||||
self.loop, self.config, self.blob_manager, descriptor.sd_hash, download_directory, file_name, status,
|
||||
claim, rowid=rowid, descriptor=descriptor
|
||||
claim, rowid=rowid, descriptor=descriptor, analytics_manager=self.analytics_manager
|
||||
)
|
||||
self.streams[sd_hash] = stream
|
||||
|
||||
async def load_streams_from_database(self):
|
||||
to_recover = []
|
||||
to_start = []
|
||||
await self.storage.sync_files_to_blobs()
|
||||
for file_info in await self.storage.get_all_lbry_files():
|
||||
if not self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified():
|
||||
to_recover.append(file_info)
|
||||
|
@ -181,10 +181,10 @@ class StreamManager:
|
|||
while True:
|
||||
if self.config.reflect_streams and self.config.reflector_servers:
|
||||
sd_hashes = await self.storage.get_streams_to_re_reflect()
|
||||
streams = list(filter(lambda s: s in sd_hashes, self.streams.keys()))
|
||||
sd_hashes = [sd for sd in sd_hashes if sd in self.streams]
|
||||
batch = []
|
||||
while streams:
|
||||
stream = streams.pop()
|
||||
while sd_hashes:
|
||||
stream = self.streams[sd_hashes.pop()]
|
||||
if not stream.fully_reflected.is_set():
|
||||
host, port = random.choice(self.config.reflector_servers)
|
||||
batch.append(stream.upload_to_reflector(host, port))
|
||||
|
@ -198,7 +198,7 @@ class StreamManager:
|
|||
async def start(self):
|
||||
await self.load_streams_from_database()
|
||||
self.resume_downloading_task = self.loop.create_task(self.resume())
|
||||
# self.re_reflect_task = self.loop.create_task(self.reflect_streams())
|
||||
self.re_reflect_task = self.loop.create_task(self.reflect_streams())
|
||||
|
||||
def stop(self):
|
||||
if self.resume_downloading_task and not self.resume_downloading_task.done():
|
||||
|
@ -279,28 +279,11 @@ class StreamManager:
|
|||
streams.reverse()
|
||||
return streams
|
||||
|
||||
def wait_for_stream_finished(self, stream: ManagedStream):
|
||||
async def _wait_for_stream_finished():
|
||||
if stream.downloader and stream.running:
|
||||
await stream.finished_writing.wait()
|
||||
stream.update_status(ManagedStream.STATUS_FINISHED)
|
||||
if self.analytics_manager:
|
||||
self.loop.create_task(self.analytics_manager.send_download_finished(
|
||||
stream.download_id, stream.claim_name, stream.sd_hash
|
||||
))
|
||||
|
||||
task = self.loop.create_task(_wait_for_stream_finished())
|
||||
self.update_stream_finished_futs.append(task)
|
||||
task.add_done_callback(
|
||||
lambda _: None if task not in self.update_stream_finished_futs else
|
||||
self.update_stream_finished_futs.remove(task)
|
||||
)
|
||||
|
||||
async def _check_update_or_replace(self, outpoint: str, claim_id: str, claim: Claim) -> typing.Tuple[
|
||||
typing.Optional[ManagedStream], typing.Optional[ManagedStream]]:
|
||||
existing = self.get_filtered_streams(outpoint=outpoint)
|
||||
if existing:
|
||||
if not existing[0].running:
|
||||
if existing[0].status == ManagedStream.STATUS_STOPPED:
|
||||
await self.start_stream(existing[0])
|
||||
return existing[0], None
|
||||
existing = self.get_filtered_streams(sd_hash=claim.stream.source.sd_hash)
|
||||
|
@ -323,163 +306,112 @@ class StreamManager:
|
|||
return None, existing_for_claim_id[0]
|
||||
return None, None
|
||||
|
||||
# async def start_downloader(self, got_descriptor_time: asyncio.Future, downloader: EncryptedStreamDownloader,
|
||||
# download_id: str, outpoint: str, claim: Claim, resolved: typing.Dict,
|
||||
# file_name: typing.Optional[str] = None) -> ManagedStream:
|
||||
# start_time = self.loop.time()
|
||||
# downloader.download(self.node)
|
||||
# await downloader.got_descriptor.wait()
|
||||
# got_descriptor_time.set_result(self.loop.time() - start_time)
|
||||
# rowid = await self._store_stream(downloader)
|
||||
# await self.storage.save_content_claim(
|
||||
# downloader.descriptor.stream_hash, outpoint
|
||||
# )
|
||||
# stream = ManagedStream(self.loop, self.blob_manager, rowid, downloader.descriptor, self.config.download_dir,
|
||||
# file_name, downloader, ManagedStream.STATUS_RUNNING, download_id=download_id)
|
||||
# stream.set_claim(resolved, claim)
|
||||
# await stream.downloader.wrote_bytes_event.wait()
|
||||
# self.streams.add(stream)
|
||||
# return stream
|
||||
|
||||
@cache_concurrent
|
||||
async def download_stream_from_uri(self, uri, timeout: float, exchange_rate_manager: 'ExchangeRateManager',
|
||||
file_name: typing.Optional[str] = None) -> ManagedStream:
|
||||
async def download_stream_from_uri(self, uri, exchange_rate_manager: 'ExchangeRateManager',
|
||||
timeout: typing.Optional[float] = None,
|
||||
file_name: typing.Optional[str] = None,
|
||||
download_directory: typing.Optional[str] = None,
|
||||
save_file: bool = True, resolve_timeout: float = 3.0) -> ManagedStream:
|
||||
timeout = timeout or self.config.download_timeout
|
||||
start_time = self.loop.time()
|
||||
parsed_uri = parse_lbry_uri(uri)
|
||||
if parsed_uri.is_channel:
|
||||
raise ResolveError("cannot download a channel claim, specify a /path")
|
||||
|
||||
# resolve the claim
|
||||
resolved_result = await self.wallet.ledger.resolve(0, 10, uri)
|
||||
await self.storage.save_claims_for_resolve([
|
||||
value for value in resolved_result.values() if 'error' not in value
|
||||
])
|
||||
resolved = resolved_result.get(uri, {})
|
||||
resolved = resolved if 'value' in resolved else resolved.get('claim')
|
||||
if not resolved:
|
||||
raise ResolveError(f"Failed to resolve stream at '{uri}'")
|
||||
if 'error' in resolved:
|
||||
raise ResolveError(f"error resolving stream: {resolved['error']}")
|
||||
resolved_time = None
|
||||
stream = None
|
||||
error = None
|
||||
outpoint = None
|
||||
try:
|
||||
# resolve the claim
|
||||
parsed_uri = parse_lbry_uri(uri)
|
||||
if parsed_uri.is_channel:
|
||||
raise ResolveError("cannot download a channel claim, specify a /path")
|
||||
try:
|
||||
resolved_result = await asyncio.wait_for(self.wallet.ledger.resolve(0, 1, uri), resolve_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise ResolveTimeout(uri)
|
||||
await self.storage.save_claims_for_resolve([
|
||||
value for value in resolved_result.values() if 'error' not in value
|
||||
])
|
||||
resolved = resolved_result.get(uri, {})
|
||||
resolved = resolved if 'value' in resolved else resolved.get('claim')
|
||||
if not resolved:
|
||||
raise ResolveError(f"Failed to resolve stream at '{uri}'")
|
||||
if 'error' in resolved:
|
||||
raise ResolveError(f"error resolving stream: {resolved['error']}")
|
||||
|
||||
claim = Claim.from_bytes(binascii.unhexlify(resolved['protobuf']))
|
||||
outpoint = f"{resolved['txid']}:{resolved['nout']}"
|
||||
resolved_time = self.loop.time() - start_time
|
||||
|
||||
# resume or update an existing stream, if the stream changed download it and delete the old one after
|
||||
updated_stream, to_replace = await self._check_update_or_replace(outpoint, resolved['claim_id'], claim)
|
||||
if updated_stream:
|
||||
return updated_stream
|
||||
# resume or update an existing stream, if the stream changed download it and delete the old one after
|
||||
updated_stream, to_replace = await self._check_update_or_replace(outpoint, resolved['claim_id'], claim)
|
||||
if updated_stream:
|
||||
return updated_stream
|
||||
|
||||
# check that the fee is payable
|
||||
fee_amount, fee_address = None, None
|
||||
if claim.stream.has_fee:
|
||||
fee_amount = round(exchange_rate_manager.convert_currency(
|
||||
claim.stream.fee.currency, "LBC", claim.stream.fee.amount
|
||||
), 5)
|
||||
max_fee_amount = round(exchange_rate_manager.convert_currency(
|
||||
self.config.max_key_fee['currency'], "LBC", Decimal(self.config.max_key_fee['amount'])
|
||||
), 5)
|
||||
if fee_amount > max_fee_amount:
|
||||
msg = f"fee of {fee_amount} exceeds max configured to allow of {max_fee_amount}"
|
||||
log.warning(msg)
|
||||
raise KeyFeeAboveMaxAllowed(msg)
|
||||
balance = await self.wallet.default_account.get_balance()
|
||||
if lbc_to_dewies(str(fee_amount)) > balance:
|
||||
msg = f"fee of {fee_amount} exceeds max available balance"
|
||||
log.warning(msg)
|
||||
raise InsufficientFundsError(msg)
|
||||
fee_address = claim.stream.fee.address
|
||||
# content_fee_tx = await self.wallet.send_amount_to_address(
|
||||
# lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1')
|
||||
# )
|
||||
# check that the fee is payable
|
||||
if not to_replace and claim.stream.has_fee:
|
||||
fee_amount = round(exchange_rate_manager.convert_currency(
|
||||
claim.stream.fee.currency, "LBC", claim.stream.fee.amount
|
||||
), 5)
|
||||
max_fee_amount = round(exchange_rate_manager.convert_currency(
|
||||
self.config.max_key_fee['currency'], "LBC", Decimal(self.config.max_key_fee['amount'])
|
||||
), 5)
|
||||
if fee_amount > max_fee_amount:
|
||||
msg = f"fee of {fee_amount} exceeds max configured to allow of {max_fee_amount}"
|
||||
log.warning(msg)
|
||||
raise KeyFeeAboveMaxAllowed(msg)
|
||||
balance = await self.wallet.default_account.get_balance()
|
||||
if lbc_to_dewies(str(fee_amount)) > balance:
|
||||
msg = f"fee of {fee_amount} exceeds max available balance"
|
||||
log.warning(msg)
|
||||
raise InsufficientFundsError(msg)
|
||||
fee_address = claim.stream.fee.address
|
||||
await self.wallet.send_amount_to_address(
|
||||
lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1')
|
||||
)
|
||||
log.info("paid fee of %s for %s", fee_amount, uri)
|
||||
|
||||
handled_fee_time = self.loop.time() - resolved_time - start_time
|
||||
|
||||
# download the stream
|
||||
download_id = binascii.hexlify(generate_id()).decode()
|
||||
|
||||
download_dir = self.config.download_dir
|
||||
save_file = True
|
||||
if not file_name and self.config.streaming_only:
|
||||
download_dir, file_name = None, None
|
||||
save_file = False
|
||||
stream = ManagedStream(
|
||||
self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_dir,
|
||||
file_name, ManagedStream.STATUS_RUNNING, download_id=download_id
|
||||
)
|
||||
|
||||
await stream.setup(self.node, save_file=save_file)
|
||||
stream.set_claim(resolved, claim)
|
||||
await self.storage.save_content_claim(stream.stream_hash, outpoint)
|
||||
self.streams[stream.sd_hash] = stream
|
||||
|
||||
# stream = None
|
||||
# descriptor_time_fut = self.loop.create_future()
|
||||
# start_download_time = self.loop.time()
|
||||
# time_to_descriptor = None
|
||||
# time_to_first_bytes = None
|
||||
# error = None
|
||||
# try:
|
||||
# stream = await asyncio.wait_for(
|
||||
# asyncio.ensure_future(
|
||||
# self.start_downloader(descriptor_time_fut, downloader, download_id, outpoint, claim, resolved,
|
||||
# file_name)
|
||||
# ), timeout
|
||||
# )
|
||||
# time_to_descriptor = await descriptor_time_fut
|
||||
# time_to_first_bytes = self.loop.time() - start_download_time - time_to_descriptor
|
||||
# self.wait_for_stream_finished(stream)
|
||||
# if fee_address and fee_amount and not to_replace:
|
||||
#
|
||||
# elif to_replace: # delete old stream now that the replacement has started downloading
|
||||
# await self.delete_stream(to_replace)
|
||||
# except asyncio.TimeoutError:
|
||||
# if descriptor_time_fut.done():
|
||||
# time_to_descriptor = descriptor_time_fut.result()
|
||||
# error = DownloadDataTimeout(downloader.sd_hash)
|
||||
# self.blob_manager.delete_blob(downloader.sd_hash)
|
||||
# await self.storage.delete_stream(downloader.descriptor)
|
||||
# else:
|
||||
# descriptor_time_fut.cancel()
|
||||
# error = DownloadSDTimeout(downloader.sd_hash)
|
||||
# if stream:
|
||||
# await self.stop_stream(stream)
|
||||
# else:
|
||||
# downloader.stop()
|
||||
# if error:
|
||||
# log.warning(error)
|
||||
# if self.analytics_manager:
|
||||
# self.loop.create_task(
|
||||
# self.analytics_manager.send_time_to_first_bytes(
|
||||
# resolved_time, self.loop.time() - start_time, download_id, parse_lbry_uri(uri).name, outpoint,
|
||||
# None if not stream else len(stream.downloader.blob_downloader.active_connections),
|
||||
# None if not stream else len(stream.downloader.blob_downloader.scores),
|
||||
# False if not downloader else downloader.added_fixed_peers,
|
||||
# self.config.fixed_peer_delay if not downloader else downloader.fixed_peers_delay,
|
||||
# claim.source_hash.decode(), time_to_descriptor,
|
||||
# None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].blob_hash,
|
||||
# None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].length,
|
||||
# time_to_first_bytes, None if not error else error.__class__.__name__
|
||||
# )
|
||||
# )
|
||||
# if error:
|
||||
# raise error
|
||||
return stream
|
||||
|
||||
# async def download_stream_from_uri(self, uri, exchange_rate_manager: 'ExchangeRateManager',
|
||||
# file_name: typing.Optional[str] = None,
|
||||
# timeout: typing.Optional[float] = None) -> ManagedStream:
|
||||
# timeout = timeout or self.config.download_timeout
|
||||
# if uri in self.starting_streams:
|
||||
# return await self.starting_streams[uri]
|
||||
# fut = asyncio.Future(loop=self.loop)
|
||||
# self.starting_streams[uri] = fut
|
||||
# try:
|
||||
# stream = await self._download_stream_from_uri(uri, timeout, exchange_rate_manager, file_name)
|
||||
# fut.set_result(stream)
|
||||
# except Exception as err:
|
||||
# fut.set_exception(err)
|
||||
# try:
|
||||
# return await fut
|
||||
# finally:
|
||||
# del self.starting_streams[uri]
|
||||
download_directory = download_directory or self.config.download_dir
|
||||
if not file_name and (self.config.streaming_only or not save_file):
|
||||
download_dir, file_name = None, None
|
||||
stream = ManagedStream(
|
||||
self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_directory,
|
||||
file_name, ManagedStream.STATUS_RUNNING, analytics_manager=self.analytics_manager
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(stream.setup(
|
||||
self.node, save_file=save_file, file_name=file_name, download_directory=download_directory
|
||||
), timeout, loop=self.loop)
|
||||
except asyncio.TimeoutError:
|
||||
if not stream.descriptor:
|
||||
raise DownloadSDTimeout(stream.sd_hash)
|
||||
raise DownloadDataTimeout(stream.sd_hash)
|
||||
if to_replace: # delete old stream now that the replacement has started downloading
|
||||
await self.delete_stream(to_replace)
|
||||
stream.set_claim(resolved, claim)
|
||||
await self.storage.save_content_claim(stream.stream_hash, outpoint)
|
||||
self.streams[stream.sd_hash] = stream
|
||||
return stream
|
||||
except Exception as err:
|
||||
error = err
|
||||
if stream and stream.descriptor:
|
||||
await self.storage.delete_stream(stream.descriptor)
|
||||
finally:
|
||||
if self.analytics_manager and (error or (stream and (stream.downloader.time_to_descriptor or
|
||||
stream.downloader.time_to_first_bytes))):
|
||||
self.loop.create_task(
|
||||
self.analytics_manager.send_time_to_first_bytes(
|
||||
resolved_time, self.loop.time() - start_time, None if not stream else stream.download_id,
|
||||
uri, outpoint,
|
||||
None if not stream else len(stream.downloader.blob_downloader.active_connections),
|
||||
None if not stream else len(stream.downloader.blob_downloader.scores),
|
||||
False if not stream else stream.downloader.added_fixed_peers,
|
||||
self.config.fixed_peer_delay if not stream else stream.downloader.fixed_peers_delay,
|
||||
None if not stream else stream.sd_hash,
|
||||
None if not stream else stream.downloader.time_to_descriptor,
|
||||
None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].blob_hash,
|
||||
None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].length,
|
||||
None if not stream else stream.downloader.time_to_first_bytes,
|
||||
None if not error else error.__class__.__name__
|
||||
)
|
||||
)
|
||||
if error:
|
||||
raise error
|
||||
|
|
|
@ -28,9 +28,9 @@ class TestBlobfile(AsyncioTestCase):
|
|||
self.assertEqual(blob.get_is_verified(), False)
|
||||
self.assertNotIn(blob_hash, blob_manager.completed_blob_hashes)
|
||||
|
||||
writer = blob.open_for_writing()
|
||||
writer = blob.get_blob_writer()
|
||||
writer.write(blob_bytes)
|
||||
await blob.finished_writing.wait()
|
||||
await blob.verified.wait()
|
||||
self.assertTrue(os.path.isfile(blob.file_path), True)
|
||||
self.assertEqual(blob.get_is_verified(), True)
|
||||
self.assertIn(blob_hash, blob_manager.completed_blob_hashes)
|
||||
|
|
|
@ -11,7 +11,7 @@ from lbrynet.conf import Config
|
|||
from lbrynet.extras.daemon.storage import SQLiteStorage
|
||||
from lbrynet.blob.blob_manager import BlobManager
|
||||
from lbrynet.blob_exchange.server import BlobServer, BlobServerProtocol
|
||||
from lbrynet.blob_exchange.client import BlobExchangeClientProtocol, request_blob
|
||||
from lbrynet.blob_exchange.client import request_blob
|
||||
from lbrynet.dht.peer import KademliaPeer, PeerManager
|
||||
|
||||
# import logging
|
||||
|
@ -58,9 +58,9 @@ class TestBlobExchange(BlobExchangeTestBase):
|
|||
async def _add_blob_to_server(self, blob_hash: str, blob_bytes: bytes):
|
||||
# add the blob on the server
|
||||
server_blob = self.server_blob_manager.get_blob(blob_hash, len(blob_bytes))
|
||||
writer = server_blob.open_for_writing()
|
||||
writer = server_blob.get_blob_writer()
|
||||
writer.write(blob_bytes)
|
||||
await server_blob.finished_writing.wait()
|
||||
await server_blob.verified.wait()
|
||||
self.assertTrue(os.path.isfile(server_blob.file_path))
|
||||
self.assertEqual(server_blob.get_is_verified(), True)
|
||||
|
||||
|
@ -68,11 +68,14 @@ class TestBlobExchange(BlobExchangeTestBase):
|
|||
client_blob = self.client_blob_manager.get_blob(blob_hash)
|
||||
|
||||
# download the blob
|
||||
downloaded = await request_blob(self.loop, client_blob, self.server_from_client.address,
|
||||
self.server_from_client.tcp_port, 2, 3)
|
||||
await client_blob.finished_writing.wait()
|
||||
downloaded, transport = await request_blob(self.loop, client_blob, self.server_from_client.address,
|
||||
self.server_from_client.tcp_port, 2, 3)
|
||||
self.assertIsNotNone(transport)
|
||||
self.addCleanup(transport.close)
|
||||
await client_blob.verified.wait()
|
||||
self.assertEqual(client_blob.get_is_verified(), True)
|
||||
self.assertTrue(downloaded)
|
||||
self.addCleanup(client_blob.close)
|
||||
|
||||
async def test_transfer_sd_blob(self):
|
||||
sd_hash = "3e2706157a59aaa47ef52bc264fce488078b4026c0b9bab649a8f2fe1ecc5e5cad7182a2bb7722460f856831a1ac0f02"
|
||||
|
@ -112,7 +115,7 @@ class TestBlobExchange(BlobExchangeTestBase):
|
|||
),
|
||||
self._test_transfer_blob(blob_hash)
|
||||
)
|
||||
await second_client_blob.finished_writing.wait()
|
||||
await second_client_blob.verified.wait()
|
||||
self.assertEqual(second_client_blob.get_is_verified(), True)
|
||||
|
||||
async def test_host_different_blobs_to_multiple_peers_at_once(self):
|
||||
|
@ -143,7 +146,7 @@ class TestBlobExchange(BlobExchangeTestBase):
|
|||
server_from_second_client.tcp_port, 2, 3
|
||||
),
|
||||
self._test_transfer_blob(sd_hash),
|
||||
second_client_blob.finished_writing.wait()
|
||||
second_client_blob.verified.wait()
|
||||
)
|
||||
self.assertEqual(second_client_blob.get_is_verified(), True)
|
||||
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
import asyncio
|
||||
from torba.testcase import AsyncioTestCase
|
||||
from unittest import mock, TestCase
|
||||
from lbrynet.dht.protocol.data_store import DictDataStore
|
||||
from lbrynet.dht.peer import PeerManager
|
||||
|
||||
|
||||
class DataStoreTests(AsyncioTestCase):
|
||||
class DataStoreTests(TestCase):
|
||||
def setUp(self):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.loop = mock.Mock(spec=asyncio.BaseEventLoop)
|
||||
self.loop.time = lambda: 0.0
|
||||
self.peer_manager = PeerManager(self.loop)
|
||||
self.data_store = DictDataStore(self.loop, self.peer_manager)
|
||||
|
||||
|
|
|
@ -1,117 +0,0 @@
|
|||
import os
|
||||
import asyncio
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
from torba.testcase import AsyncioTestCase
|
||||
from lbrynet.conf import Config
|
||||
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
|
||||
from lbrynet.extras.daemon.storage import SQLiteStorage
|
||||
from lbrynet.blob.blob_manager import BlobManager
|
||||
from lbrynet.stream.assembler import StreamAssembler
|
||||
from lbrynet.stream.descriptor import StreamDescriptor
|
||||
from lbrynet.stream.stream_manager import StreamManager
|
||||
|
||||
|
||||
class TestStreamAssembler(AsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.key = b'deadbeef' * 4
|
||||
self.cleartext = b'test'
|
||||
|
||||
async def test_create_and_decrypt_one_blob_stream(self, corrupt=False):
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
|
||||
self.storage = SQLiteStorage(Config(), ":memory:")
|
||||
await self.storage.open()
|
||||
self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage)
|
||||
|
||||
download_dir = tempfile.mkdtemp()
|
||||
self.addCleanup(lambda: shutil.rmtree(download_dir))
|
||||
|
||||
# create the stream
|
||||
file_path = os.path.join(tmp_dir, "test_file")
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(self.cleartext)
|
||||
|
||||
sd = await StreamDescriptor.create_stream(self.loop, tmp_dir, file_path, key=self.key)
|
||||
|
||||
# copy blob files
|
||||
sd_hash = sd.calculate_sd_hash()
|
||||
shutil.copy(os.path.join(tmp_dir, sd_hash), os.path.join(download_dir, sd_hash))
|
||||
for blob_info in sd.blobs:
|
||||
if blob_info.blob_hash:
|
||||
shutil.copy(os.path.join(tmp_dir, blob_info.blob_hash), os.path.join(download_dir, blob_info.blob_hash))
|
||||
if corrupt and blob_info.length == MAX_BLOB_SIZE:
|
||||
with open(os.path.join(download_dir, blob_info.blob_hash), "rb+") as handle:
|
||||
handle.truncate()
|
||||
handle.flush()
|
||||
|
||||
downloader_storage = SQLiteStorage(Config(), os.path.join(download_dir, "lbrynet.sqlite"))
|
||||
await downloader_storage.open()
|
||||
|
||||
# add the blobs to the blob table (this would happen upon a blob download finishing)
|
||||
downloader_blob_manager = BlobManager(self.loop, download_dir, downloader_storage)
|
||||
descriptor = await downloader_blob_manager.get_stream_descriptor(sd_hash)
|
||||
|
||||
# assemble the decrypted file
|
||||
assembler = StreamAssembler(self.loop, downloader_blob_manager, descriptor.sd_hash)
|
||||
await assembler.assemble_decrypted_stream(download_dir)
|
||||
if corrupt:
|
||||
return self.assertFalse(os.path.isfile(os.path.join(download_dir, "test_file")))
|
||||
|
||||
with open(os.path.join(download_dir, "test_file"), "rb") as f:
|
||||
decrypted = f.read()
|
||||
self.assertEqual(decrypted, self.cleartext)
|
||||
self.assertEqual(True, self.blob_manager.get_blob(sd_hash).get_is_verified())
|
||||
self.assertEqual(True, self.blob_manager.get_blob(descriptor.blobs[0].blob_hash).get_is_verified())
|
||||
# its all blobs + sd blob - last blob, which is the same size as descriptor.blobs
|
||||
self.assertEqual(len(descriptor.blobs), len(await downloader_storage.get_all_finished_blobs()))
|
||||
self.assertEqual(
|
||||
[descriptor.sd_hash, descriptor.blobs[0].blob_hash], await downloader_storage.get_blobs_to_announce()
|
||||
)
|
||||
|
||||
await downloader_storage.close()
|
||||
await self.storage.close()
|
||||
|
||||
async def test_create_and_decrypt_multi_blob_stream(self):
|
||||
self.cleartext = b'test\n' * 20000000
|
||||
await self.test_create_and_decrypt_one_blob_stream()
|
||||
|
||||
async def test_create_and_decrypt_padding(self):
|
||||
for i in range(16):
|
||||
self.cleartext = os.urandom((MAX_BLOB_SIZE*2) + i)
|
||||
await self.test_create_and_decrypt_one_blob_stream()
|
||||
|
||||
for i in range(16):
|
||||
self.cleartext = os.urandom((MAX_BLOB_SIZE*2) - i)
|
||||
await self.test_create_and_decrypt_one_blob_stream()
|
||||
|
||||
async def test_create_and_decrypt_random(self):
|
||||
self.cleartext = os.urandom(20000000)
|
||||
await self.test_create_and_decrypt_one_blob_stream()
|
||||
|
||||
async def test_create_managed_stream_announces(self):
|
||||
# setup a blob manager
|
||||
storage = SQLiteStorage(Config(), ":memory:")
|
||||
await storage.open()
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
|
||||
blob_manager = BlobManager(self.loop, tmp_dir, storage)
|
||||
stream_manager = StreamManager(self.loop, Config(), blob_manager, None, storage, None)
|
||||
# create the stream
|
||||
download_dir = tempfile.mkdtemp()
|
||||
self.addCleanup(lambda: shutil.rmtree(download_dir))
|
||||
file_path = os.path.join(download_dir, "test_file")
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(b'testtest')
|
||||
|
||||
stream = await stream_manager.create_stream(file_path)
|
||||
self.assertEqual(
|
||||
[stream.sd_hash, stream.descriptor.blobs[0].blob_hash],
|
||||
await storage.get_blobs_to_announce())
|
||||
|
||||
async def test_create_truncate_and_handle_stream(self):
|
||||
self.cleartext = b'potato' * 1337 * 5279
|
||||
# The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated
|
||||
await asyncio.wait_for(self.test_create_and_decrypt_one_blob_stream(corrupt=True), timeout=5)
|
|
@ -1,102 +0,0 @@
|
|||
import os
|
||||
import time
|
||||
import unittest
|
||||
from unittest import mock
|
||||
import asyncio
|
||||
|
||||
from lbrynet.blob_exchange.serialization import BlobResponse
|
||||
from lbrynet.blob_exchange.server import BlobServerProtocol
|
||||
from lbrynet.conf import Config
|
||||
from lbrynet.stream.descriptor import StreamDescriptor
|
||||
from lbrynet.stream.downloader import StreamDownloader
|
||||
from lbrynet.dht.node import Node
|
||||
from lbrynet.dht.peer import KademliaPeer
|
||||
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
|
||||
from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase
|
||||
|
||||
|
||||
class TestStreamDownloader(BlobExchangeTestBase):
|
||||
async def setup_stream(self, blob_count: int = 10):
|
||||
self.stream_bytes = b''
|
||||
for _ in range(blob_count):
|
||||
self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1))
|
||||
# create the stream
|
||||
file_path = os.path.join(self.server_dir, "test_file")
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(self.stream_bytes)
|
||||
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
|
||||
self.sd_hash = descriptor.calculate_sd_hash()
|
||||
conf = Config(data_dir=self.server_dir, wallet_dir=self.server_dir, download_dir=self.server_dir,
|
||||
reflector_servers=[])
|
||||
self.downloader = StreamDownloader(self.loop, conf, self.client_blob_manager, self.sd_hash)
|
||||
|
||||
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
|
||||
await self.setup_stream(blob_count)
|
||||
mock_node = mock.Mock(spec=Node)
|
||||
|
||||
def _mock_accumulate_peers(q1, q2):
|
||||
async def _task():
|
||||
pass
|
||||
q2.put_nowait([self.server_from_client])
|
||||
return q2, self.loop.create_task(_task())
|
||||
|
||||
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
|
||||
self.downloader.download(mock_node)
|
||||
await self.downloader.stream_finished_event.wait()
|
||||
self.assertTrue(self.downloader.stream_handle.closed)
|
||||
self.assertTrue(os.path.isfile(self.downloader.output_path))
|
||||
self.downloader.stop()
|
||||
self.assertIs(self.downloader.stream_handle, None)
|
||||
self.assertTrue(os.path.isfile(self.downloader.output_path))
|
||||
with open(self.downloader.output_path, 'rb') as f:
|
||||
self.assertEqual(f.read(), self.stream_bytes)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
async def test_transfer_stream(self):
|
||||
await self._test_transfer_stream(10)
|
||||
|
||||
@unittest.SkipTest
|
||||
async def test_transfer_hundred_blob_stream(self):
|
||||
await self._test_transfer_stream(100)
|
||||
|
||||
async def test_transfer_stream_bad_first_peer_good_second(self):
|
||||
await self.setup_stream(2)
|
||||
|
||||
mock_node = mock.Mock(spec=Node)
|
||||
q = asyncio.Queue()
|
||||
|
||||
bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334)
|
||||
|
||||
def _mock_accumulate_peers(q1, q2):
|
||||
async def _task():
|
||||
pass
|
||||
|
||||
q2.put_nowait([bad_peer])
|
||||
self.loop.call_later(1, q2.put_nowait, [self.server_from_client])
|
||||
return q2, self.loop.create_task(_task())
|
||||
|
||||
mock_node.accumulate_peers = _mock_accumulate_peers
|
||||
|
||||
self.downloader.download(mock_node)
|
||||
await self.downloader.stream_finished_event.wait()
|
||||
self.assertTrue(os.path.isfile(self.downloader.output_path))
|
||||
with open(self.downloader.output_path, 'rb') as f:
|
||||
self.assertEqual(f.read(), self.stream_bytes)
|
||||
# self.assertIs(self.server_from_client.tcp_last_down, None)
|
||||
# self.assertIsNot(bad_peer.tcp_last_down, None)
|
||||
|
||||
async def test_client_chunked_response(self):
|
||||
self.server.stop_server()
|
||||
class ChunkedServerProtocol(BlobServerProtocol):
|
||||
|
||||
def send_response(self, responses):
|
||||
to_send = []
|
||||
while responses:
|
||||
to_send.append(responses.pop())
|
||||
for byte in BlobResponse(to_send).serialize():
|
||||
self.transport.write(bytes([byte]))
|
||||
self.server.server_protocol_class = ChunkedServerProtocol
|
||||
self.server.start_server(33333, '127.0.0.1')
|
||||
self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes))
|
||||
await asyncio.wait_for(self._test_transfer_stream(10), timeout=2)
|
||||
self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes))
|
175
tests/unit/stream/test_managed_stream.py
Normal file
175
tests/unit/stream/test_managed_stream.py
Normal file
|
@ -0,0 +1,175 @@
|
|||
import os
|
||||
import shutil
|
||||
import unittest
|
||||
from unittest import mock
|
||||
import asyncio
|
||||
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
|
||||
from lbrynet.blob_exchange.serialization import BlobResponse
|
||||
from lbrynet.blob_exchange.server import BlobServerProtocol
|
||||
from lbrynet.dht.node import Node
|
||||
from lbrynet.dht.peer import KademliaPeer
|
||||
from lbrynet.extras.daemon.storage import StoredStreamClaim
|
||||
from lbrynet.stream.managed_stream import ManagedStream
|
||||
from lbrynet.stream.descriptor import StreamDescriptor
|
||||
from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase
|
||||
|
||||
|
||||
def get_mock_node(loop):
|
||||
mock_node = mock.Mock(spec=Node)
|
||||
mock_node.joined = asyncio.Event(loop=loop)
|
||||
mock_node.joined.set()
|
||||
return mock_node
|
||||
|
||||
|
||||
class TestManagedStream(BlobExchangeTestBase):
|
||||
async def create_stream(self, blob_count: int = 10):
|
||||
self.stream_bytes = b''
|
||||
for _ in range(blob_count):
|
||||
self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1))
|
||||
# create the stream
|
||||
file_path = os.path.join(self.server_dir, "test_file")
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(self.stream_bytes)
|
||||
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
|
||||
self.sd_hash = descriptor.calculate_sd_hash()
|
||||
return descriptor
|
||||
|
||||
async def setup_stream(self, blob_count: int = 10):
|
||||
await self.create_stream(blob_count)
|
||||
self.stream = ManagedStream(self.loop, self.client_config, self.client_blob_manager, self.sd_hash,
|
||||
self.client_dir)
|
||||
|
||||
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
|
||||
await self.setup_stream(blob_count)
|
||||
mock_node = mock.Mock(spec=Node)
|
||||
|
||||
def _mock_accumulate_peers(q1, q2):
|
||||
async def _task():
|
||||
pass
|
||||
q2.put_nowait([self.server_from_client])
|
||||
return q2, self.loop.create_task(_task())
|
||||
|
||||
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
|
||||
await self.stream.setup(mock_node, save_file=True)
|
||||
await self.stream.finished_writing.wait()
|
||||
self.assertTrue(os.path.isfile(self.stream.full_path))
|
||||
self.stream.stop_download()
|
||||
self.assertTrue(os.path.isfile(self.stream.full_path))
|
||||
with open(self.stream.full_path, 'rb') as f:
|
||||
self.assertEqual(f.read(), self.stream_bytes)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
async def test_transfer_stream(self):
|
||||
await self._test_transfer_stream(10)
|
||||
|
||||
@unittest.SkipTest
|
||||
async def test_transfer_hundred_blob_stream(self):
|
||||
await self._test_transfer_stream(100)
|
||||
|
||||
async def test_transfer_stream_bad_first_peer_good_second(self):
|
||||
await self.setup_stream(2)
|
||||
|
||||
mock_node = mock.Mock(spec=Node)
|
||||
q = asyncio.Queue()
|
||||
|
||||
bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334)
|
||||
|
||||
def _mock_accumulate_peers(q1, q2):
|
||||
async def _task():
|
||||
pass
|
||||
|
||||
q2.put_nowait([bad_peer])
|
||||
self.loop.call_later(1, q2.put_nowait, [self.server_from_client])
|
||||
return q2, self.loop.create_task(_task())
|
||||
|
||||
mock_node.accumulate_peers = _mock_accumulate_peers
|
||||
|
||||
await self.stream.setup(mock_node, save_file=True)
|
||||
await self.stream.finished_writing.wait()
|
||||
self.assertTrue(os.path.isfile(self.stream.full_path))
|
||||
with open(self.stream.full_path, 'rb') as f:
|
||||
self.assertEqual(f.read(), self.stream_bytes)
|
||||
# self.assertIs(self.server_from_client.tcp_last_down, None)
|
||||
# self.assertIsNot(bad_peer.tcp_last_down, None)
|
||||
|
||||
async def test_client_chunked_response(self):
|
||||
self.server.stop_server()
|
||||
|
||||
class ChunkedServerProtocol(BlobServerProtocol):
|
||||
def send_response(self, responses):
|
||||
to_send = []
|
||||
while responses:
|
||||
to_send.append(responses.pop())
|
||||
for byte in BlobResponse(to_send).serialize():
|
||||
self.transport.write(bytes([byte]))
|
||||
self.server.server_protocol_class = ChunkedServerProtocol
|
||||
self.server.start_server(33333, '127.0.0.1')
|
||||
self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes))
|
||||
await asyncio.wait_for(self._test_transfer_stream(10), timeout=2)
|
||||
self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes))
|
||||
|
||||
async def test_create_and_decrypt_one_blob_stream(self, blobs=1, corrupt=False):
|
||||
descriptor = await self.create_stream(blobs)
|
||||
|
||||
# copy blob files
|
||||
shutil.copy(os.path.join(self.server_blob_manager.blob_dir, self.sd_hash),
|
||||
os.path.join(self.client_blob_manager.blob_dir, self.sd_hash))
|
||||
self.stream = ManagedStream(self.loop, self.client_config, self.client_blob_manager, self.sd_hash,
|
||||
self.client_dir)
|
||||
|
||||
for blob_info in descriptor.blobs[:-1]:
|
||||
shutil.copy(os.path.join(self.server_blob_manager.blob_dir, blob_info.blob_hash),
|
||||
os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash))
|
||||
if corrupt and blob_info.length == MAX_BLOB_SIZE:
|
||||
with open(os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash), "rb+") as handle:
|
||||
handle.truncate()
|
||||
handle.flush()
|
||||
await self.stream.setup()
|
||||
await self.stream.finished_writing.wait()
|
||||
if corrupt:
|
||||
return self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
|
||||
|
||||
with open(os.path.join(self.client_dir, "test_file"), "rb") as f:
|
||||
decrypted = f.read()
|
||||
self.assertEqual(decrypted, self.stream_bytes)
|
||||
|
||||
self.assertEqual(True, self.client_blob_manager.get_blob(self.sd_hash).get_is_verified())
|
||||
self.assertEqual(
|
||||
True, self.client_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified()
|
||||
)
|
||||
#
|
||||
# # its all blobs + sd blob - last blob, which is the same size as descriptor.blobs
|
||||
# self.assertEqual(len(descriptor.blobs), len(await downloader_storage.get_all_finished_blobs()))
|
||||
# self.assertEqual(
|
||||
# [descriptor.sd_hash, descriptor.blobs[0].blob_hash], await downloader_storage.get_blobs_to_announce()
|
||||
# )
|
||||
#
|
||||
# await downloader_storage.close()
|
||||
# await self.storage.close()
|
||||
|
||||
async def test_create_and_decrypt_multi_blob_stream(self):
|
||||
await self.test_create_and_decrypt_one_blob_stream(10)
|
||||
|
||||
# async def test_create_managed_stream_announces(self):
|
||||
# # setup a blob manager
|
||||
# storage = SQLiteStorage(Config(), ":memory:")
|
||||
# await storage.open()
|
||||
# tmp_dir = tempfile.mkdtemp()
|
||||
# self.addCleanup(lambda: shutil.rmtree(tmp_dir))
|
||||
# blob_manager = BlobManager(self.loop, tmp_dir, storage)
|
||||
# stream_manager = StreamManager(self.loop, Config(), blob_manager, None, storage, None)
|
||||
# # create the stream
|
||||
# download_dir = tempfile.mkdtemp()
|
||||
# self.addCleanup(lambda: shutil.rmtree(download_dir))
|
||||
# file_path = os.path.join(download_dir, "test_file")
|
||||
# with open(file_path, 'wb') as f:
|
||||
# f.write(b'testtest')
|
||||
#
|
||||
# stream = await stream_manager.create_stream(file_path)
|
||||
# self.assertEqual(
|
||||
# [stream.sd_hash, stream.descriptor.blobs[0].blob_hash],
|
||||
# await storage.get_blobs_to_announce())
|
||||
|
||||
# async def test_create_truncate_and_handle_stream(self):
|
||||
# # The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated
|
||||
# await asyncio.wait_for(self.test_create_and_decrypt_one_blob_stream(corrupt=True), timeout=5)
|
|
@ -99,7 +99,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase):
|
|||
|
||||
blob = blob_manager.get_blob(sd_hash)
|
||||
blob.set_length(len(sd_bytes))
|
||||
writer = blob.open_for_writing()
|
||||
writer = blob.get_blob_writer()
|
||||
writer.write(sd_bytes)
|
||||
await blob.verified.wait()
|
||||
descriptor = await StreamDescriptor.from_stream_descriptor_blob(
|
||||
|
@ -116,7 +116,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase):
|
|||
sd_hash = '9313d1807551186126acc3662e74d9de29cede78d4f133349ace846273ef116b9bb86be86c54509eb84840e4b032f6b2'
|
||||
with open(os.path.join(tmp_dir, sd_hash), 'wb') as handle:
|
||||
handle.write(b'doesnt work')
|
||||
blob = BlobFile(loop, tmp_dir, sd_hash)
|
||||
blob = BlobFile(loop, sd_hash, blob_directory=tmp_dir)
|
||||
self.assertTrue(blob.file_exists)
|
||||
self.assertIsNotNone(blob.length)
|
||||
with self.assertRaises(InvalidStreamDescriptorError):
|
||||
|
|
|
@ -192,8 +192,8 @@ class TestStreamManager(BlobExchangeTestBase):
|
|||
def check_post(event):
|
||||
self.assertEqual(event['event'], 'Time To First Bytes')
|
||||
self.assertEqual(event['properties']['error'], 'DownloadSDTimeout')
|
||||
self.assertEqual(event['properties']['tried_peers_count'], None)
|
||||
self.assertEqual(event['properties']['active_peer_count'], None)
|
||||
self.assertEqual(event['properties']['tried_peers_count'], 0)
|
||||
self.assertEqual(event['properties']['active_peer_count'], 0)
|
||||
self.assertEqual(event['properties']['use_fixed_peers'], False)
|
||||
self.assertEqual(event['properties']['added_fixed_peers'], False)
|
||||
self.assertEqual(event['properties']['fixed_peer_delay'], None)
|
||||
|
@ -213,10 +213,10 @@ class TestStreamManager(BlobExchangeTestBase):
|
|||
|
||||
self.stream_manager.analytics_manager._post = check_post
|
||||
|
||||
self.assertSetEqual(self.stream_manager.streams, set())
|
||||
self.assertDictEqual(self.stream_manager.streams, {})
|
||||
stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
|
||||
stream_hash = stream.stream_hash
|
||||
self.assertSetEqual(self.stream_manager.streams, {stream})
|
||||
self.assertDictEqual(self.stream_manager.streams, {stream.sd_hash: stream})
|
||||
self.assertTrue(stream.running)
|
||||
self.assertFalse(stream.finished)
|
||||
self.assertTrue(os.path.isfile(os.path.join(self.client_dir, "test_file")))
|
||||
|
@ -236,7 +236,7 @@ class TestStreamManager(BlobExchangeTestBase):
|
|||
self.assertEqual(stored_status, "stopped")
|
||||
|
||||
await self.stream_manager.start_stream(stream)
|
||||
await stream.downloader.stream_finished_event.wait()
|
||||
await stream.finished_writing.wait()
|
||||
await asyncio.sleep(0, loop=self.loop)
|
||||
self.assertTrue(stream.finished)
|
||||
self.assertFalse(stream.running)
|
||||
|
@ -247,7 +247,7 @@ class TestStreamManager(BlobExchangeTestBase):
|
|||
self.assertEqual(stored_status, "finished")
|
||||
|
||||
await self.stream_manager.delete_stream(stream, True)
|
||||
self.assertSetEqual(self.stream_manager.streams, set())
|
||||
self.assertDictEqual(self.stream_manager.streams, {})
|
||||
self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
|
||||
stored_status = await self.client_storage.run_and_return_one_or_none(
|
||||
"select status from file where stream_hash=?", stream_hash
|
||||
|
@ -257,7 +257,7 @@ class TestStreamManager(BlobExchangeTestBase):
|
|||
|
||||
async def _test_download_error_on_start(self, expected_error, timeout=None):
|
||||
with self.assertRaises(expected_error):
|
||||
await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout=timeout)
|
||||
await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout)
|
||||
|
||||
async def _test_download_error_analytics_on_start(self, expected_error, timeout=None):
|
||||
received = []
|
||||
|
@ -321,9 +321,9 @@ class TestStreamManager(BlobExchangeTestBase):
|
|||
await self.setup_stream_manager(old_sort=old_sort)
|
||||
self.stream_manager.analytics_manager._post = check_post
|
||||
|
||||
self.assertSetEqual(self.stream_manager.streams, set())
|
||||
self.assertDictEqual(self.stream_manager.streams, {})
|
||||
stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
|
||||
await stream.downloader.stream_finished_event.wait()
|
||||
await stream.finished_writing.wait()
|
||||
await asyncio.sleep(0, loop=self.loop)
|
||||
self.stream_manager.stop()
|
||||
self.client_blob_manager.stop()
|
||||
|
@ -333,8 +333,11 @@ class TestStreamManager(BlobExchangeTestBase):
|
|||
await self.client_blob_manager.setup()
|
||||
await self.stream_manager.start()
|
||||
self.assertEqual(1, len(self.stream_manager.streams))
|
||||
self.assertEqual(stream.sd_hash, list(self.stream_manager.streams)[0].sd_hash)
|
||||
self.assertEqual('stopped', list(self.stream_manager.streams)[0].status)
|
||||
self.assertListEqual([self.sd_hash], list(self.stream_manager.streams.keys()))
|
||||
for blob_hash in [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]]:
|
||||
blob_status = await self.client_storage.get_blob_status(blob_hash)
|
||||
self.assertEqual('pending', blob_status)
|
||||
self.assertEqual('stopped', self.stream_manager.streams[self.sd_hash].status)
|
||||
|
||||
sd_blob = self.client_blob_manager.get_blob(stream.sd_hash)
|
||||
self.assertTrue(sd_blob.file_exists)
|
||||
|
|
Loading…
Add table
Reference in a new issue