This commit is contained in:
Jack Robison 2019-03-31 13:42:27 -04:00
parent f125468ebf
commit 3a916a8e8e
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
14 changed files with 381 additions and 441 deletions

View file

@@ -32,6 +32,12 @@ class DownloadDataTimeout(Exception):
self.download = download
class ResolveTimeout(Exception):
def __init__(self, uri):
super().__init__(f'Failed to resolve "{uri}" within the timeout')
self.uri = uri
class RequestCanceledError(Exception):
pass
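The new ResolveTimeout mirrors the shape of the other timeout errors in this module: a formatted message plus the offending input kept as an attribute. A minimal sketch of the intended raising pattern, assuming only the standard library (the wrapper name resolve_with_timeout is illustrative; the commit applies this inline in StreamManager further down):

    import asyncio

    class ResolveTimeout(Exception):
        def __init__(self, uri):
            super().__init__(f'Failed to resolve "{uri}" within the timeout')
            self.uri = uri

    async def resolve_with_timeout(resolve_coro, uri: str, timeout: float = 3.0):
        # convert the anonymous asyncio.TimeoutError into a typed error
        # that carries the uri, so callers can report which lookup failed
        try:
            return await asyncio.wait_for(resolve_coro, timeout)
        except asyncio.TimeoutError:
            raise ResolveTimeout(uri)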

View file

@@ -897,7 +897,7 @@ class Daemon(metaclass=JSONRPCServerType):
"""
try:
stream = await self.stream_manager.download_stream_from_uri(
uri, timeout, self.exchange_rate_manager, file_name
uri, self.exchange_rate_manager, timeout, file_name
)
if not stream:
raise DownloadSDTimeout(uri)
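Note the two call-site lines above: timeout and exchange_rate_manager trade places to match the reordered signature introduced later in this commit, where timeout becomes optional. Since both are passed positionally, a hedged alternative (not what the commit does) is to bind the trailing arguments by keyword so a future reorder cannot silently swap them:

    stream = await self.stream_manager.download_stream_from_uri(
        uri,
        self.exchange_rate_manager,
        timeout=timeout,
        file_name=file_name,
    )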

View file

@@ -423,6 +423,17 @@ class SQLiteStorage(SQLiteMixin):
}
return self.db.run(_sync_blobs)
def sync_files_to_blobs(self):
def _sync_blobs(transaction: sqlite3.Connection) -> None:
transaction.executemany(
"update file set status='stopped' where stream_hash=?",
transaction.execute(
"select distinct sb.stream_hash from stream_blob sb "
"inner join blob b on b.blob_hash=sb.blob_hash and b.status=='pending'"
).fetchall()
)
return self.db.run(_sync_blobs)
# # # # # # # # # stream functions # # # # # # # # #
async def stream_exists(self, sd_hash: str) -> bool:
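A self-contained sketch of what the new sync_files_to_blobs does at startup, using plain sqlite3 against the lbrynet schema referenced above (file, stream_blob, and blob tables); standard '=' is used here where the commit uses SQLite's '==' synonym:

    import sqlite3

    def sync_files_to_blobs(transaction: sqlite3.Connection) -> None:
        # any file whose stream still has a 'pending' blob was interrupted
        # mid-download, so flip its status to 'stopped' before startup
        # decides which streams to resume
        interrupted = transaction.execute(
            "select distinct sb.stream_hash from stream_blob sb "
            "inner join blob b on b.blob_hash=sb.blob_hash and b.status='pending'"
        ).fetchall()
        transaction.executemany(
            "update file set status='stopped' where stream_hash=?", interrupted
        )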

View file

@@ -2,6 +2,7 @@ import asyncio
import typing
import logging
import binascii
from lbrynet.error import DownloadSDTimeout
from lbrynet.utils import resolve_host
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.blob_exchange.downloader import BlobDownloader
@@ -32,6 +33,8 @@ class StreamDownloader:
self.fixed_peers_handle: typing.Optional[asyncio.Handle] = None
self.fixed_peers_delay: typing.Optional[float] = None
self.added_fixed_peers = False
self.time_to_descriptor: typing.Optional[float] = None
self.time_to_first_bytes: typing.Optional[float] = None
async def add_fixed_peers(self):
def _delayed_add_fixed_peers():
@@ -59,8 +62,16 @@ class StreamDownloader:
# download or get the sd blob
sd_blob = self.blob_manager.get_blob(self.sd_hash)
if not sd_blob.get_is_verified():
sd_blob = await self.blob_downloader.download_blob(self.sd_hash)
try:
now = self.loop.time()
sd_blob = await asyncio.wait_for(
self.blob_downloader.download_blob(self.sd_hash),
self.config.blob_download_timeout, loop=self.loop
)
log.info("downloaded sd blob %s", self.sd_hash)
self.time_to_descriptor = self.loop.time() - now
except asyncio.TimeoutError:
raise DownloadSDTimeout(self.sd_hash)
# parse the descriptor
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
@@ -101,12 +112,18 @@ class StreamDownloader:
binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode())
)
async def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob'):
async def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes:
return await self.loop.run_in_executor(None, self._decrypt_blob, blob_info, blob)
async def read_blob(self, blob_info: 'BlobInfo') -> bytes:
start = None
if self.time_to_first_bytes is None:
start = self.loop.time()
blob = await self.download_stream_blob(blob_info)
return await self.decrypt_blob(blob_info, blob)
decrypted = await self.decrypt_blob(blob_info, blob)
if start is not None:
self.time_to_first_bytes = self.loop.time() - start
return decrypted
def stop(self):
if self.accumulate_task:
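The two new attributes turn the downloader into its own latency probe: time_to_descriptor is measured around the sd blob fetch, and time_to_first_bytes around the first read_blob call; StreamManager reads both in its analytics event at the end of this commit. A minimal sketch of the pattern with the lbrynet specifics stubbed out (class and method names here are illustrative):

    import asyncio
    import typing

    class TimingProbe:
        def __init__(self, loop: asyncio.AbstractEventLoop):
            self.loop = loop
            self.time_to_descriptor: typing.Optional[float] = None
            self.time_to_first_bytes: typing.Optional[float] = None

        async def timed_descriptor_fetch(self, fetch_coro, timeout: float):
            # a TimeoutError here is re-raised by the real code as
            # DownloadSDTimeout(sd_hash)
            start = self.loop.time()
            blob = await asyncio.wait_for(fetch_coro, timeout)
            self.time_to_descriptor = self.loop.time() - start
            return blob

        async def timed_first_read(self, read_coro):
            # only the first read is timed; later reads leave the metric alone
            start = self.loop.time() if self.time_to_first_bytes is None else None
            data = await read_coro
            if start is not None:
                self.time_to_first_bytes = self.loop.time() - start
            return data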

View file

@@ -9,13 +9,13 @@ from lbrynet.stream.downloader import StreamDownloader
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.reflector.client import StreamReflectorClient
from lbrynet.extras.daemon.storage import StoredStreamClaim
from lbrynet.blob import MAX_BLOB_SIZE
if typing.TYPE_CHECKING:
from lbrynet.conf import Config
from lbrynet.schema.claim import Claim
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob.blob_info import BlobInfo
from lbrynet.dht.node import Node
from lbrynet.extras.daemon.analytics import AnalyticsManager
log = logging.getLogger(__name__)
@@ -43,7 +43,8 @@ class ManagedStream:
sd_hash: str, download_directory: typing.Optional[str] = None, file_name: typing.Optional[str] = None,
status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredStreamClaim] = None,
download_id: typing.Optional[str] = None, rowid: typing.Optional[int] = None,
descriptor: typing.Optional[StreamDescriptor] = None):
descriptor: typing.Optional[StreamDescriptor] = None,
analytics_manager: typing.Optional['AnalyticsManager'] = None):
self.loop = loop
self.config = config
self.blob_manager = blob_manager
@@ -56,11 +57,13 @@ class ManagedStream:
self.rowid = rowid
self.written_bytes = 0
self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor)
self.analytics_manager = analytics_manager
self.fully_reflected = asyncio.Event(loop=self.loop)
self.file_output_task: typing.Optional[asyncio.Task] = None
self.delayed_stop: typing.Optional[asyncio.Handle] = None
self.saving = asyncio.Event(loop=self.loop)
self.finished_writing = asyncio.Event(loop=self.loop)
self.started_writing = asyncio.Event(loop=self.loop)
@property
def descriptor(self) -> StreamDescriptor:
@@ -217,16 +220,18 @@ class ManagedStream:
return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path),
os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor)
async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True):
async def setup(self, node: typing.Optional['Node'] = None, save_file: typing.Optional[bool] = True,
file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None):
await self.downloader.start(node)
if not save_file:
if not save_file and not file_name:
if not await self.blob_manager.storage.file_exists(self.sd_hash):
self.rowid = self.blob_manager.storage.save_downloaded_file(
self.stream_hash, None, None, 0.0
)
self.update_delayed_stop()
else:
await self.save_file()
await self.save_file(file_name, download_directory)
await self.started_writing.wait()
self.update_status(ManagedStream.STATUS_RUNNING)
await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING)
@@ -235,7 +240,6 @@ class ManagedStream:
log.info("Stopping inactive download for stream %s", self.sd_hash)
self.stop_download()
log.info("update delayed stop")
if self.delayed_stop:
self.delayed_stop.cancel()
self.delayed_stop = self.loop.call_later(60, _delayed_stop)
@@ -259,6 +263,7 @@ class ManagedStream:
async def _save_file(self, output_path: str):
self.saving.set()
self.finished_writing.clear()
self.started_writing.clear()
try:
with open(output_path, 'wb') as file_write_handle:
async for blob_info, decrypted in self.aiter_read_stream():
@@ -266,14 +271,21 @@ class ManagedStream:
file_write_handle.write(decrypted)
file_write_handle.flush()
self.written_bytes += len(decrypted)
if not self.started_writing.is_set():
self.started_writing.set()
self.update_status(ManagedStream.STATUS_FINISHED)
await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_FINISHED)
if self.analytics_manager:
self.loop.create_task(self.analytics_manager.send_download_finished(
self.download_id, self.claim_name, self.sd_hash
))
self.finished_writing.set()
except Exception as err:
if os.path.isfile(output_path):
log.info("removing incomplete download %s for %s", output_path, self.sd_hash)
os.remove(output_path)
if not isinstance(err, asyncio.CancelledError):
log.exception("unexpected error encountered writing file for stream %s", self.sd_hash)
if os.path.isfile(output_path):
log.info("removing incomplete download %s", output_path)
os.remove(output_path)
raise err
finally:
self.saving.clear()
@@ -282,10 +294,9 @@ class ManagedStream:
if self.file_output_task and not self.file_output_task.done():
self.file_output_task.cancel()
if self.delayed_stop:
log.info('cancel delayed stop')
self.delayed_stop.cancel()
self.delayed_stop = None
self.download_directory = download_directory or self.download_directory
self.download_directory = download_directory or self.download_directory or self.config.download_dir
if not self.download_directory:
raise ValueError("no directory to download to")
if not (file_name or self._file_name or self.descriptor.suggested_file_name):
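setup() now blocks on started_writing rather than on completion, so callers regain control as soon as the first decrypted bytes hit disk instead of waiting for the whole file. A sketch of the three-event handshake in _save_file, with the decrypted chunk source stubbed as an async iterator:

    import asyncio
    import typing

    async def save_file(chunks: typing.AsyncIterator[bytes], path: str,
                        saving: asyncio.Event, started_writing: asyncio.Event,
                        finished_writing: asyncio.Event):
        # `saving` marks the writer task as active, `started_writing` fires
        # on the first chunk (this is what setup() awaits), and
        # `finished_writing` fires only after every chunk is flushed
        saving.set()
        started_writing.clear()
        finished_writing.clear()
        try:
            with open(path, 'wb') as handle:
                async for chunk in chunks:
                    handle.write(chunk)
                    handle.flush()
                    if not started_writing.is_set():
                        started_writing.set()
            finished_writing.set()
        finally:
            saving.clear()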

View file

@@ -6,8 +6,8 @@ import logging
import random
from decimal import Decimal
from lbrynet.error import ResolveError, InvalidStreamDescriptorError, KeyFeeAboveMaxAllowed, InsufficientFundsError
# DownloadDataTimeout, DownloadSDTimeout
from lbrynet.utils import generate_id, cache_concurrent
from lbrynet.error import DownloadSDTimeout, DownloadDataTimeout, ResolveTimeout
from lbrynet.utils import cache_concurrent
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.managed_stream import ManagedStream
from lbrynet.schema.claim import Claim
@@ -96,11 +96,10 @@ class StreamManager:
await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED)
async def start_stream(self, stream: ManagedStream):
await stream.setup(self.node, save_file=not self.config.streaming_only)
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
stream.update_status(ManagedStream.STATUS_RUNNING)
await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_RUNNING)
self.wait_for_stream_finished(stream)
await stream.setup(self.node, save_file=not self.config.streaming_only)
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
async def recover_streams(self, file_infos: typing.List[typing.Dict]):
to_restore = []
@@ -139,13 +138,14 @@ class StreamManager:
return
stream = ManagedStream(
self.loop, self.config, self.blob_manager, descriptor.sd_hash, download_directory, file_name, status,
claim, rowid=rowid, descriptor=descriptor
claim, rowid=rowid, descriptor=descriptor, analytics_manager=self.analytics_manager
)
self.streams[sd_hash] = stream
async def load_streams_from_database(self):
to_recover = []
to_start = []
await self.storage.sync_files_to_blobs()
for file_info in await self.storage.get_all_lbry_files():
if not self.blob_manager.get_blob(file_info['sd_hash']).get_is_verified():
to_recover.append(file_info)
@@ -181,10 +181,10 @@ class StreamManager:
while True:
if self.config.reflect_streams and self.config.reflector_servers:
sd_hashes = await self.storage.get_streams_to_re_reflect()
streams = list(filter(lambda s: s in sd_hashes, self.streams.keys()))
sd_hashes = [sd for sd in sd_hashes if sd in self.streams]
batch = []
while streams:
stream = streams.pop()
while sd_hashes:
stream = self.streams[sd_hashes.pop()]
if not stream.fully_reflected.is_set():
host, port = random.choice(self.config.reflector_servers)
batch.append(stream.upload_to_reflector(host, port))
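This loop leans on self.streams now being a dict keyed by sd_hash (see the test changes at the bottom of this commit), so membership and lookup are O(1). A sketch of the batching, with a flush size assumed since the hunk is cut off before the gather:

    import asyncio
    import random

    async def reflect_streams_once(streams: dict, sd_hashes: list,
                                   reflector_servers: list, batch_size: int = 10):
        # keep only hashes we actually manage, then upload every stream
        # that has not been fully reflected, a batch at a time
        sd_hashes = [sd for sd in sd_hashes if sd in streams]
        batch = []
        while sd_hashes:
            stream = streams[sd_hashes.pop()]
            if not stream.fully_reflected.is_set():
                host, port = random.choice(reflector_servers)
                batch.append(stream.upload_to_reflector(host, port))
            if len(batch) >= batch_size:      # assumed flush point
                await asyncio.gather(*batch)
                batch.clear()
        if batch:
            await asyncio.gather(*batch)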
@@ -198,7 +198,7 @@ class StreamManager:
async def start(self):
await self.load_streams_from_database()
self.resume_downloading_task = self.loop.create_task(self.resume())
# self.re_reflect_task = self.loop.create_task(self.reflect_streams())
self.re_reflect_task = self.loop.create_task(self.reflect_streams())
def stop(self):
if self.resume_downloading_task and not self.resume_downloading_task.done():
@@ -279,28 +279,11 @@ class StreamManager:
streams.reverse()
return streams
def wait_for_stream_finished(self, stream: ManagedStream):
async def _wait_for_stream_finished():
if stream.downloader and stream.running:
await stream.finished_writing.wait()
stream.update_status(ManagedStream.STATUS_FINISHED)
if self.analytics_manager:
self.loop.create_task(self.analytics_manager.send_download_finished(
stream.download_id, stream.claim_name, stream.sd_hash
))
task = self.loop.create_task(_wait_for_stream_finished())
self.update_stream_finished_futs.append(task)
task.add_done_callback(
lambda _: None if task not in self.update_stream_finished_futs else
self.update_stream_finished_futs.remove(task)
)
async def _check_update_or_replace(self, outpoint: str, claim_id: str, claim: Claim) -> typing.Tuple[
typing.Optional[ManagedStream], typing.Optional[ManagedStream]]:
existing = self.get_filtered_streams(outpoint=outpoint)
if existing:
if not existing[0].running:
if existing[0].status == ManagedStream.STATUS_STOPPED:
await self.start_stream(existing[0])
return existing[0], None
existing = self.get_filtered_streams(sd_hash=claim.stream.source.sd_hash)
@@ -323,34 +306,27 @@ class StreamManager:
return None, existing_for_claim_id[0]
return None, None
# async def start_downloader(self, got_descriptor_time: asyncio.Future, downloader: EncryptedStreamDownloader,
# download_id: str, outpoint: str, claim: Claim, resolved: typing.Dict,
# file_name: typing.Optional[str] = None) -> ManagedStream:
# start_time = self.loop.time()
# downloader.download(self.node)
# await downloader.got_descriptor.wait()
# got_descriptor_time.set_result(self.loop.time() - start_time)
# rowid = await self._store_stream(downloader)
# await self.storage.save_content_claim(
# downloader.descriptor.stream_hash, outpoint
# )
# stream = ManagedStream(self.loop, self.blob_manager, rowid, downloader.descriptor, self.config.download_dir,
# file_name, downloader, ManagedStream.STATUS_RUNNING, download_id=download_id)
# stream.set_claim(resolved, claim)
# await stream.downloader.wrote_bytes_event.wait()
# self.streams.add(stream)
# return stream
@cache_concurrent
async def download_stream_from_uri(self, uri, timeout: float, exchange_rate_manager: 'ExchangeRateManager',
file_name: typing.Optional[str] = None) -> ManagedStream:
async def download_stream_from_uri(self, uri, exchange_rate_manager: 'ExchangeRateManager',
timeout: typing.Optional[float] = None,
file_name: typing.Optional[str] = None,
download_directory: typing.Optional[str] = None,
save_file: bool = True, resolve_timeout: float = 3.0) -> ManagedStream:
timeout = timeout or self.config.download_timeout
start_time = self.loop.time()
resolved_time = None
stream = None
error = None
outpoint = None
try:
# resolve the claim
parsed_uri = parse_lbry_uri(uri)
if parsed_uri.is_channel:
raise ResolveError("cannot download a channel claim, specify a /path")
# resolve the claim
resolved_result = await self.wallet.ledger.resolve(0, 10, uri)
try:
resolved_result = await asyncio.wait_for(self.wallet.ledger.resolve(0, 1, uri), resolve_timeout)
except asyncio.TimeoutError:
raise ResolveTimeout(uri)
await self.storage.save_claims_for_resolve([
value for value in resolved_result.values() if 'error' not in value
])
@@ -371,8 +347,7 @@ class StreamManager:
return updated_stream
# check that the fee is payable
fee_amount, fee_address = None, None
if claim.stream.has_fee:
if not to_replace and claim.stream.has_fee:
fee_amount = round(exchange_rate_manager.convert_currency(
claim.stream.fee.currency, "LBC", claim.stream.fee.amount
), 5)
@@ -389,97 +364,54 @@ class StreamManager:
log.warning(msg)
raise InsufficientFundsError(msg)
fee_address = claim.stream.fee.address
# content_fee_tx = await self.wallet.send_amount_to_address(
# lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1')
# )
handled_fee_time = self.loop.time() - resolved_time - start_time
# download the stream
download_id = binascii.hexlify(generate_id()).decode()
download_dir = self.config.download_dir
save_file = True
if not file_name and self.config.streaming_only:
download_dir, file_name = None, None
save_file = False
stream = ManagedStream(
self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_dir,
file_name, ManagedStream.STATUS_RUNNING, download_id=download_id
await self.wallet.send_amount_to_address(
lbc_to_dewies(str(fee_amount)), fee_address.encode('latin1')
)
log.info("paid fee of %s for %s", fee_amount, uri)
await stream.setup(self.node, save_file=save_file)
download_directory = download_directory or self.config.download_dir
if not file_name and (self.config.streaming_only or not save_file):
download_directory, file_name = None, None
stream = ManagedStream(
self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_directory,
file_name, ManagedStream.STATUS_RUNNING, analytics_manager=self.analytics_manager
)
try:
await asyncio.wait_for(stream.setup(
self.node, save_file=save_file, file_name=file_name, download_directory=download_directory
), timeout, loop=self.loop)
except asyncio.TimeoutError:
if not stream.descriptor:
raise DownloadSDTimeout(stream.sd_hash)
raise DownloadDataTimeout(stream.sd_hash)
if to_replace: # delete old stream now that the replacement has started downloading
await self.delete_stream(to_replace)
stream.set_claim(resolved, claim)
await self.storage.save_content_claim(stream.stream_hash, outpoint)
self.streams[stream.sd_hash] = stream
# stream = None
# descriptor_time_fut = self.loop.create_future()
# start_download_time = self.loop.time()
# time_to_descriptor = None
# time_to_first_bytes = None
# error = None
# try:
# stream = await asyncio.wait_for(
# asyncio.ensure_future(
# self.start_downloader(descriptor_time_fut, downloader, download_id, outpoint, claim, resolved,
# file_name)
# ), timeout
# )
# time_to_descriptor = await descriptor_time_fut
# time_to_first_bytes = self.loop.time() - start_download_time - time_to_descriptor
# self.wait_for_stream_finished(stream)
# if fee_address and fee_amount and not to_replace:
#
# elif to_replace: # delete old stream now that the replacement has started downloading
# await self.delete_stream(to_replace)
# except asyncio.TimeoutError:
# if descriptor_time_fut.done():
# time_to_descriptor = descriptor_time_fut.result()
# error = DownloadDataTimeout(downloader.sd_hash)
# self.blob_manager.delete_blob(downloader.sd_hash)
# await self.storage.delete_stream(downloader.descriptor)
# else:
# descriptor_time_fut.cancel()
# error = DownloadSDTimeout(downloader.sd_hash)
# if stream:
# await self.stop_stream(stream)
# else:
# downloader.stop()
# if error:
# log.warning(error)
# if self.analytics_manager:
# self.loop.create_task(
# self.analytics_manager.send_time_to_first_bytes(
# resolved_time, self.loop.time() - start_time, download_id, parse_lbry_uri(uri).name, outpoint,
# None if not stream else len(stream.downloader.blob_downloader.active_connections),
# None if not stream else len(stream.downloader.blob_downloader.scores),
# False if not downloader else downloader.added_fixed_peers,
# self.config.fixed_peer_delay if not downloader else downloader.fixed_peers_delay,
# claim.source_hash.decode(), time_to_descriptor,
# None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].blob_hash,
# None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].length,
# time_to_first_bytes, None if not error else error.__class__.__name__
# )
# )
# if error:
# raise error
return stream
# async def download_stream_from_uri(self, uri, exchange_rate_manager: 'ExchangeRateManager',
# file_name: typing.Optional[str] = None,
# timeout: typing.Optional[float] = None) -> ManagedStream:
# timeout = timeout or self.config.download_timeout
# if uri in self.starting_streams:
# return await self.starting_streams[uri]
# fut = asyncio.Future(loop=self.loop)
# self.starting_streams[uri] = fut
# try:
# stream = await self._download_stream_from_uri(uri, timeout, exchange_rate_manager, file_name)
# fut.set_result(stream)
# except Exception as err:
# fut.set_exception(err)
# try:
# return await fut
# finally:
# del self.starting_streams[uri]
except Exception as err:
error = err
if stream and stream.descriptor:
await self.storage.delete_stream(stream.descriptor)
finally:
if self.analytics_manager and (error or (stream and (stream.downloader.time_to_descriptor or
stream.downloader.time_to_first_bytes))):
self.loop.create_task(
self.analytics_manager.send_time_to_first_bytes(
resolved_time, self.loop.time() - start_time, None if not stream else stream.download_id,
uri, outpoint,
None if not stream else len(stream.downloader.blob_downloader.active_connections),
None if not stream else len(stream.downloader.blob_downloader.scores),
False if not stream else stream.downloader.added_fixed_peers,
self.config.fixed_peer_delay if not stream else stream.downloader.fixed_peers_delay,
None if not stream else stream.sd_hash,
None if not stream else stream.downloader.time_to_descriptor,
None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].blob_hash,
None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].length,
None if not stream else stream.downloader.time_to_first_bytes,
None if not error else error.__class__.__name__
)
)
if error:
raise error
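The rewritten download_stream_from_uri replaces the old wait_for_stream_finished helper and the commented-out scaffolding with one try/except/finally: any failure is captured into error rather than raised on the spot, the finally block emits a single "Time To First Bytes" analytics event covering success and failure alike, and only then is the error re-raised. A stripped-down sketch of that control flow (function and parameter names are illustrative):

    import asyncio

    async def download_with_analytics(loop, do_download, send_event, timeout: float):
        # capture the error instead of letting it propagate immediately,
        # so exactly one analytics event is sent per attempt
        start_time = loop.time()
        stream, error = None, None
        try:
            stream = await asyncio.wait_for(do_download(), timeout)
            return stream
        except Exception as err:
            error = err
        finally:
            if error or stream is not None:
                loop.create_task(send_event(
                    loop.time() - start_time,
                    None if not error else error.__class__.__name__,
                ))
            if error:
                raise error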

View file

@@ -28,9 +28,9 @@ class TestBlobfile(AsyncioTestCase):
self.assertEqual(blob.get_is_verified(), False)
self.assertNotIn(blob_hash, blob_manager.completed_blob_hashes)
writer = blob.open_for_writing()
writer = blob.get_blob_writer()
writer.write(blob_bytes)
await blob.finished_writing.wait()
await blob.verified.wait()
self.assertTrue(os.path.isfile(blob.file_path))
self.assertEqual(blob.get_is_verified(), True)
self.assertIn(blob_hash, blob_manager.completed_blob_hashes)
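The blob API rename threads through every test file in this commit: open_for_writing() becomes get_blob_writer() and the finished_writing event becomes verified, which fires only once the written bytes pass the hash check. The new sequence in isolation (the blob object is assumed to come from a BlobManager as above):

    async def write_and_verify(blob, blob_bytes: bytes):
        writer = blob.get_blob_writer()   # was: blob.open_for_writing()
        writer.write(blob_bytes)
        await blob.verified.wait()        # was: blob.finished_writing.wait()
        assert blob.get_is_verified()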

View file

@@ -11,7 +11,7 @@ from lbrynet.conf import Config
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob_exchange.server import BlobServer, BlobServerProtocol
from lbrynet.blob_exchange.client import BlobExchangeClientProtocol, request_blob
from lbrynet.blob_exchange.client import request_blob
from lbrynet.dht.peer import KademliaPeer, PeerManager
# import logging
@@ -58,9 +58,9 @@ class TestBlobExchange(BlobExchangeTestBase):
async def _add_blob_to_server(self, blob_hash: str, blob_bytes: bytes):
# add the blob on the server
server_blob = self.server_blob_manager.get_blob(blob_hash, len(blob_bytes))
writer = server_blob.open_for_writing()
writer = server_blob.get_blob_writer()
writer.write(blob_bytes)
await server_blob.finished_writing.wait()
await server_blob.verified.wait()
self.assertTrue(os.path.isfile(server_blob.file_path))
self.assertEqual(server_blob.get_is_verified(), True)
@@ -68,11 +68,14 @@ class TestBlobExchange(BlobExchangeTestBase):
client_blob = self.client_blob_manager.get_blob(blob_hash)
# download the blob
downloaded = await request_blob(self.loop, client_blob, self.server_from_client.address,
downloaded, transport = await request_blob(self.loop, client_blob, self.server_from_client.address,
self.server_from_client.tcp_port, 2, 3)
await client_blob.finished_writing.wait()
self.assertIsNotNone(transport)
self.addCleanup(transport.close)
await client_blob.verified.wait()
self.assertEqual(client_blob.get_is_verified(), True)
self.assertTrue(downloaded)
self.addCleanup(client_blob.close)
async def test_transfer_sd_blob(self):
sd_hash = "3e2706157a59aaa47ef52bc264fce488078b4026c0b9bab649a8f2fe1ecc5e5cad7182a2bb7722460f856831a1ac0f02"
@@ -112,7 +115,7 @@ class TestBlobExchange(BlobExchangeTestBase):
),
self._test_transfer_blob(blob_hash)
)
await second_client_blob.finished_writing.wait()
await second_client_blob.verified.wait()
self.assertEqual(second_client_blob.get_is_verified(), True)
async def test_host_different_blobs_to_multiple_peers_at_once(self):
@@ -143,7 +146,7 @@ class TestBlobExchange(BlobExchangeTestBase):
server_from_second_client.tcp_port, 2, 3
),
self._test_transfer_blob(sd_hash),
second_client_blob.finished_writing.wait()
second_client_blob.verified.wait()
)
self.assertEqual(second_client_blob.get_is_verified(), True)
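request_blob now returns a (downloaded, transport) pair instead of a bare bool, so callers can keep the connection alive for reuse and are responsible for closing it, as the addCleanup calls above show. The new contract in isolation, with the positional arguments mirroring the tests (loop, blob, address, port, then the 2s connect and 3s download timeouts):

    downloaded, transport = await request_blob(loop, client_blob,
                                               "127.0.0.1", 3333, 2, 3)
    assert downloaded                  # True once the blob body arrived
    if transport is not None:
        transport.close()              # the caller now owns the connection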

View file

@@ -1,12 +1,13 @@
import asyncio
from torba.testcase import AsyncioTestCase
from unittest import mock, TestCase
from lbrynet.dht.protocol.data_store import DictDataStore
from lbrynet.dht.peer import PeerManager
class DataStoreTests(AsyncioTestCase):
class DataStoreTests(TestCase):
def setUp(self):
self.loop = asyncio.get_event_loop()
self.loop = mock.Mock(spec=asyncio.BaseEventLoop)
self.loop.time = lambda: 0.0
self.peer_manager = PeerManager(self.loop)
self.data_store = DictDataStore(self.loop, self.peer_manager)
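Dropping AsyncioTestCase for a plain TestCase works here because the code under test only needs the loop's clock, not a running loop; a Mock with a pinned time() makes expiry checks deterministic. The same trick in isolation:

    import asyncio
    from unittest import mock

    loop = mock.Mock(spec=asyncio.BaseEventLoop)
    loop.time = lambda: 0.0   # freeze the clock: nothing ever expires
    assert loop.time() == 0.0
    # every other loop method is a Mock, which is fine so long as the
    # code under test never actually schedules or awaits anything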

View file

@@ -1,117 +0,0 @@
import os
import asyncio
import tempfile
import shutil
from torba.testcase import AsyncioTestCase
from lbrynet.conf import Config
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_manager import BlobManager
from lbrynet.stream.assembler import StreamAssembler
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.stream_manager import StreamManager
class TestStreamAssembler(AsyncioTestCase):
def setUp(self):
self.loop = asyncio.get_event_loop()
self.key = b'deadbeef' * 4
self.cleartext = b'test'
async def test_create_and_decrypt_one_blob_stream(self, corrupt=False):
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
self.storage = SQLiteStorage(Config(), ":memory:")
await self.storage.open()
self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage)
download_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(download_dir))
# create the stream
file_path = os.path.join(tmp_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(self.cleartext)
sd = await StreamDescriptor.create_stream(self.loop, tmp_dir, file_path, key=self.key)
# copy blob files
sd_hash = sd.calculate_sd_hash()
shutil.copy(os.path.join(tmp_dir, sd_hash), os.path.join(download_dir, sd_hash))
for blob_info in sd.blobs:
if blob_info.blob_hash:
shutil.copy(os.path.join(tmp_dir, blob_info.blob_hash), os.path.join(download_dir, blob_info.blob_hash))
if corrupt and blob_info.length == MAX_BLOB_SIZE:
with open(os.path.join(download_dir, blob_info.blob_hash), "rb+") as handle:
handle.truncate()
handle.flush()
downloader_storage = SQLiteStorage(Config(), os.path.join(download_dir, "lbrynet.sqlite"))
await downloader_storage.open()
# add the blobs to the blob table (this would happen upon a blob download finishing)
downloader_blob_manager = BlobManager(self.loop, download_dir, downloader_storage)
descriptor = await downloader_blob_manager.get_stream_descriptor(sd_hash)
# assemble the decrypted file
assembler = StreamAssembler(self.loop, downloader_blob_manager, descriptor.sd_hash)
await assembler.assemble_decrypted_stream(download_dir)
if corrupt:
return self.assertFalse(os.path.isfile(os.path.join(download_dir, "test_file")))
with open(os.path.join(download_dir, "test_file"), "rb") as f:
decrypted = f.read()
self.assertEqual(decrypted, self.cleartext)
self.assertEqual(True, self.blob_manager.get_blob(sd_hash).get_is_verified())
self.assertEqual(True, self.blob_manager.get_blob(descriptor.blobs[0].blob_hash).get_is_verified())
# it's all blobs + sd blob - last blob, which is the same size as descriptor.blobs
self.assertEqual(len(descriptor.blobs), len(await downloader_storage.get_all_finished_blobs()))
self.assertEqual(
[descriptor.sd_hash, descriptor.blobs[0].blob_hash], await downloader_storage.get_blobs_to_announce()
)
await downloader_storage.close()
await self.storage.close()
async def test_create_and_decrypt_multi_blob_stream(self):
self.cleartext = b'test\n' * 20000000
await self.test_create_and_decrypt_one_blob_stream()
async def test_create_and_decrypt_padding(self):
for i in range(16):
self.cleartext = os.urandom((MAX_BLOB_SIZE*2) + i)
await self.test_create_and_decrypt_one_blob_stream()
for i in range(16):
self.cleartext = os.urandom((MAX_BLOB_SIZE*2) - i)
await self.test_create_and_decrypt_one_blob_stream()
async def test_create_and_decrypt_random(self):
self.cleartext = os.urandom(20000000)
await self.test_create_and_decrypt_one_blob_stream()
async def test_create_managed_stream_announces(self):
# setup a blob manager
storage = SQLiteStorage(Config(), ":memory:")
await storage.open()
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
blob_manager = BlobManager(self.loop, tmp_dir, storage)
stream_manager = StreamManager(self.loop, Config(), blob_manager, None, storage, None)
# create the stream
download_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(download_dir))
file_path = os.path.join(download_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(b'testtest')
stream = await stream_manager.create_stream(file_path)
self.assertEqual(
[stream.sd_hash, stream.descriptor.blobs[0].blob_hash],
await storage.get_blobs_to_announce())
async def test_create_truncate_and_handle_stream(self):
self.cleartext = b'potato' * 1337 * 5279
# The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated
await asyncio.wait_for(self.test_create_and_decrypt_one_blob_stream(corrupt=True), timeout=5)

View file

@@ -1,102 +0,0 @@
import os
import time
import unittest
from unittest import mock
import asyncio
from lbrynet.blob_exchange.serialization import BlobResponse
from lbrynet.blob_exchange.server import BlobServerProtocol
from lbrynet.conf import Config
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.downloader import StreamDownloader
from lbrynet.dht.node import Node
from lbrynet.dht.peer import KademliaPeer
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase
class TestStreamDownloader(BlobExchangeTestBase):
async def setup_stream(self, blob_count: int = 10):
self.stream_bytes = b''
for _ in range(blob_count):
self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1))
# create the stream
file_path = os.path.join(self.server_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(self.stream_bytes)
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
self.sd_hash = descriptor.calculate_sd_hash()
conf = Config(data_dir=self.server_dir, wallet_dir=self.server_dir, download_dir=self.server_dir,
reflector_servers=[])
self.downloader = StreamDownloader(self.loop, conf, self.client_blob_manager, self.sd_hash)
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
await self.setup_stream(blob_count)
mock_node = mock.Mock(spec=Node)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
self.downloader.download(mock_node)
await self.downloader.stream_finished_event.wait()
self.assertTrue(self.downloader.stream_handle.closed)
self.assertTrue(os.path.isfile(self.downloader.output_path))
self.downloader.stop()
self.assertIs(self.downloader.stream_handle, None)
self.assertTrue(os.path.isfile(self.downloader.output_path))
with open(self.downloader.output_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
await asyncio.sleep(0.01)
async def test_transfer_stream(self):
await self._test_transfer_stream(10)
@unittest.SkipTest
async def test_transfer_hundred_blob_stream(self):
await self._test_transfer_stream(100)
async def test_transfer_stream_bad_first_peer_good_second(self):
await self.setup_stream(2)
mock_node = mock.Mock(spec=Node)
q = asyncio.Queue()
bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([bad_peer])
self.loop.call_later(1, q2.put_nowait, [self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = _mock_accumulate_peers
self.downloader.download(mock_node)
await self.downloader.stream_finished_event.wait()
self.assertTrue(os.path.isfile(self.downloader.output_path))
with open(self.downloader.output_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
# self.assertIs(self.server_from_client.tcp_last_down, None)
# self.assertIsNot(bad_peer.tcp_last_down, None)
async def test_client_chunked_response(self):
self.server.stop_server()
class ChunkedServerProtocol(BlobServerProtocol):
def send_response(self, responses):
to_send = []
while responses:
to_send.append(responses.pop())
for byte in BlobResponse(to_send).serialize():
self.transport.write(bytes([byte]))
self.server.server_protocol_class = ChunkedServerProtocol
self.server.start_server(33333, '127.0.0.1')
self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes))
await asyncio.wait_for(self._test_transfer_stream(10), timeout=2)
self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes))

View file

@@ -0,0 +1,175 @@
import os
import shutil
import unittest
from unittest import mock
import asyncio
from lbrynet.blob.blob_file import MAX_BLOB_SIZE
from lbrynet.blob_exchange.serialization import BlobResponse
from lbrynet.blob_exchange.server import BlobServerProtocol
from lbrynet.dht.node import Node
from lbrynet.dht.peer import KademliaPeer
from lbrynet.extras.daemon.storage import StoredStreamClaim
from lbrynet.stream.managed_stream import ManagedStream
from lbrynet.stream.descriptor import StreamDescriptor
from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase
def get_mock_node(loop):
mock_node = mock.Mock(spec=Node)
mock_node.joined = asyncio.Event(loop=loop)
mock_node.joined.set()
return mock_node
class TestManagedStream(BlobExchangeTestBase):
async def create_stream(self, blob_count: int = 10):
self.stream_bytes = b''
for _ in range(blob_count):
self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1))
# create the stream
file_path = os.path.join(self.server_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(self.stream_bytes)
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
self.sd_hash = descriptor.calculate_sd_hash()
return descriptor
async def setup_stream(self, blob_count: int = 10):
await self.create_stream(blob_count)
self.stream = ManagedStream(self.loop, self.client_config, self.client_blob_manager, self.sd_hash,
self.client_dir)
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
await self.setup_stream(blob_count)
mock_node = mock.Mock(spec=Node)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
await self.stream.setup(mock_node, save_file=True)
await self.stream.finished_writing.wait()
self.assertTrue(os.path.isfile(self.stream.full_path))
self.stream.stop_download()
self.assertTrue(os.path.isfile(self.stream.full_path))
with open(self.stream.full_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
await asyncio.sleep(0.01)
async def test_transfer_stream(self):
await self._test_transfer_stream(10)
@unittest.SkipTest
async def test_transfer_hundred_blob_stream(self):
await self._test_transfer_stream(100)
async def test_transfer_stream_bad_first_peer_good_second(self):
await self.setup_stream(2)
mock_node = mock.Mock(spec=Node)
q = asyncio.Queue()
bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334)
def _mock_accumulate_peers(q1, q2):
async def _task():
pass
q2.put_nowait([bad_peer])
self.loop.call_later(1, q2.put_nowait, [self.server_from_client])
return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = _mock_accumulate_peers
await self.stream.setup(mock_node, save_file=True)
await self.stream.finished_writing.wait()
self.assertTrue(os.path.isfile(self.stream.full_path))
with open(self.stream.full_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes)
# self.assertIs(self.server_from_client.tcp_last_down, None)
# self.assertIsNot(bad_peer.tcp_last_down, None)
async def test_client_chunked_response(self):
self.server.stop_server()
class ChunkedServerProtocol(BlobServerProtocol):
def send_response(self, responses):
to_send = []
while responses:
to_send.append(responses.pop())
for byte in BlobResponse(to_send).serialize():
self.transport.write(bytes([byte]))
self.server.server_protocol_class = ChunkedServerProtocol
self.server.start_server(33333, '127.0.0.1')
self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes))
await asyncio.wait_for(self._test_transfer_stream(10), timeout=2)
self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes))
async def test_create_and_decrypt_one_blob_stream(self, blobs=1, corrupt=False):
descriptor = await self.create_stream(blobs)
# copy blob files
shutil.copy(os.path.join(self.server_blob_manager.blob_dir, self.sd_hash),
os.path.join(self.client_blob_manager.blob_dir, self.sd_hash))
self.stream = ManagedStream(self.loop, self.client_config, self.client_blob_manager, self.sd_hash,
self.client_dir)
for blob_info in descriptor.blobs[:-1]:
shutil.copy(os.path.join(self.server_blob_manager.blob_dir, blob_info.blob_hash),
os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash))
if corrupt and blob_info.length == MAX_BLOB_SIZE:
with open(os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash), "rb+") as handle:
handle.truncate()
handle.flush()
await self.stream.setup()
await self.stream.finished_writing.wait()
if corrupt:
return self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
with open(os.path.join(self.client_dir, "test_file"), "rb") as f:
decrypted = f.read()
self.assertEqual(decrypted, self.stream_bytes)
self.assertEqual(True, self.client_blob_manager.get_blob(self.sd_hash).get_is_verified())
self.assertEqual(
True, self.client_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified()
)
#
# # it's all blobs + sd blob - last blob, which is the same size as descriptor.blobs
# self.assertEqual(len(descriptor.blobs), len(await downloader_storage.get_all_finished_blobs()))
# self.assertEqual(
# [descriptor.sd_hash, descriptor.blobs[0].blob_hash], await downloader_storage.get_blobs_to_announce()
# )
#
# await downloader_storage.close()
# await self.storage.close()
async def test_create_and_decrypt_multi_blob_stream(self):
await self.test_create_and_decrypt_one_blob_stream(10)
# async def test_create_managed_stream_announces(self):
# # setup a blob manager
# storage = SQLiteStorage(Config(), ":memory:")
# await storage.open()
# tmp_dir = tempfile.mkdtemp()
# self.addCleanup(lambda: shutil.rmtree(tmp_dir))
# blob_manager = BlobManager(self.loop, tmp_dir, storage)
# stream_manager = StreamManager(self.loop, Config(), blob_manager, None, storage, None)
# # create the stream
# download_dir = tempfile.mkdtemp()
# self.addCleanup(lambda: shutil.rmtree(download_dir))
# file_path = os.path.join(download_dir, "test_file")
# with open(file_path, 'wb') as f:
# f.write(b'testtest')
#
# stream = await stream_manager.create_stream(file_path)
# self.assertEqual(
# [stream.sd_hash, stream.descriptor.blobs[0].blob_hash],
# await storage.get_blobs_to_announce())
# async def test_create_truncate_and_handle_stream(self):
# # The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated
# await asyncio.wait_for(self.test_create_and_decrypt_one_blob_stream(corrupt=True), timeout=5)

View file

@@ -99,7 +99,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase):
blob = blob_manager.get_blob(sd_hash)
blob.set_length(len(sd_bytes))
writer = blob.open_for_writing()
writer = blob.get_blob_writer()
writer.write(sd_bytes)
await blob.verified.wait()
descriptor = await StreamDescriptor.from_stream_descriptor_blob(
@@ -116,7 +116,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase):
sd_hash = '9313d1807551186126acc3662e74d9de29cede78d4f133349ace846273ef116b9bb86be86c54509eb84840e4b032f6b2'
with open(os.path.join(tmp_dir, sd_hash), 'wb') as handle:
handle.write(b'doesnt work')
blob = BlobFile(loop, tmp_dir, sd_hash)
blob = BlobFile(loop, sd_hash, blob_directory=tmp_dir)
self.assertTrue(blob.file_exists)
self.assertIsNotNone(blob.length)
with self.assertRaises(InvalidStreamDescriptorError):
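BlobFile's constructor changed shape along with the reader/writer split: the hash now comes right after the loop and the directory moves behind a keyword argument, presumably because the new abstract blob base class has no on-disk directory at all. The call change in isolation:

    # before: blob = BlobFile(loop, tmp_dir, sd_hash)
    blob = BlobFile(loop, sd_hash, blob_directory=tmp_dir)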

View file

@@ -192,8 +192,8 @@ class TestStreamManager(BlobExchangeTestBase):
def check_post(event):
self.assertEqual(event['event'], 'Time To First Bytes')
self.assertEqual(event['properties']['error'], 'DownloadSDTimeout')
self.assertEqual(event['properties']['tried_peers_count'], None)
self.assertEqual(event['properties']['active_peer_count'], None)
self.assertEqual(event['properties']['tried_peers_count'], 0)
self.assertEqual(event['properties']['active_peer_count'], 0)
self.assertEqual(event['properties']['use_fixed_peers'], False)
self.assertEqual(event['properties']['added_fixed_peers'], False)
self.assertEqual(event['properties']['fixed_peer_delay'], None)
@@ -213,10 +213,10 @@ class TestStreamManager(BlobExchangeTestBase):
self.stream_manager.analytics_manager._post = check_post
self.assertSetEqual(self.stream_manager.streams, set())
self.assertDictEqual(self.stream_manager.streams, {})
stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
stream_hash = stream.stream_hash
self.assertSetEqual(self.stream_manager.streams, {stream})
self.assertDictEqual(self.stream_manager.streams, {stream.sd_hash: stream})
self.assertTrue(stream.running)
self.assertFalse(stream.finished)
self.assertTrue(os.path.isfile(os.path.join(self.client_dir, "test_file")))
@@ -236,7 +236,7 @@ class TestStreamManager(BlobExchangeTestBase):
self.assertEqual(stored_status, "stopped")
await self.stream_manager.start_stream(stream)
await stream.downloader.stream_finished_event.wait()
await stream.finished_writing.wait()
await asyncio.sleep(0, loop=self.loop)
self.assertTrue(stream.finished)
self.assertFalse(stream.running)
@@ -247,7 +247,7 @@ class TestStreamManager(BlobExchangeTestBase):
self.assertEqual(stored_status, "finished")
await self.stream_manager.delete_stream(stream, True)
self.assertSetEqual(self.stream_manager.streams, set())
self.assertDictEqual(self.stream_manager.streams, {})
self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
stored_status = await self.client_storage.run_and_return_one_or_none(
"select status from file where stream_hash=?", stream_hash
@@ -257,7 +257,7 @@ class TestStreamManager(BlobExchangeTestBase):
async def _test_download_error_on_start(self, expected_error, timeout=None):
with self.assertRaises(expected_error):
await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout=timeout)
await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout)
async def _test_download_error_analytics_on_start(self, expected_error, timeout=None):
received = []
@@ -321,9 +321,9 @@ class TestStreamManager(BlobExchangeTestBase):
await self.setup_stream_manager(old_sort=old_sort)
self.stream_manager.analytics_manager._post = check_post
self.assertSetEqual(self.stream_manager.streams, set())
self.assertDictEqual(self.stream_manager.streams, {})
stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
await stream.downloader.stream_finished_event.wait()
await stream.finished_writing.wait()
await asyncio.sleep(0, loop=self.loop)
self.stream_manager.stop()
self.client_blob_manager.stop()
@@ -333,8 +333,11 @@ class TestStreamManager(BlobExchangeTestBase):
await self.client_blob_manager.setup()
await self.stream_manager.start()
self.assertEqual(1, len(self.stream_manager.streams))
self.assertEqual(stream.sd_hash, list(self.stream_manager.streams)[0].sd_hash)
self.assertEqual('stopped', list(self.stream_manager.streams)[0].status)
self.assertListEqual([self.sd_hash], list(self.stream_manager.streams.keys()))
for blob_hash in [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]]:
blob_status = await self.client_storage.get_blob_status(blob_hash)
self.assertEqual('pending', blob_status)
self.assertEqual('stopped', self.stream_manager.streams[self.sd_hash].status)
sd_blob = self.client_blob_manager.get_blob(stream.sd_hash)
self.assertTrue(sd_blob.file_exists)
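These assertions capture the central data-structure change of the commit: StreamManager.streams went from a set of ManagedStream objects to a dict keyed by sd_hash. A toy sketch of what keyed access buys the call sites (FakeStream stands in for ManagedStream):

    class FakeStream:
        def __init__(self, sd_hash: str):
            self.sd_hash = sd_hash

    stream = FakeStream("ab" * 48)             # sd hashes are 96 hex chars
    streams = {}                               # was: set()
    streams[stream.sd_hash] = stream           # register by key
    assert stream.sd_hash in streams           # O(1) membership test
    assert streams[stream.sd_hash] is stream   # O(1) lookup, no scanning
    del streams[stream.sd_hash]                # delete without searching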