Merge pull request #1866 from lbryio/non-async-close-blob
refactor blob.close to be non-async, speed up deleting blobs and streams
This commit is contained in:
commit
3508da4993
14 changed files with 256 additions and 84 deletions
|
@ -132,16 +132,18 @@ class BlobFile:
|
||||||
with open(self.file_path, 'rb') as handle:
|
with open(self.file_path, 'rb') as handle:
|
||||||
return await self.loop.sendfile(writer.transport, handle, count=self.get_length())
|
return await self.loop.sendfile(writer.transport, handle, count=self.get_length())
|
||||||
|
|
||||||
async def close(self):
|
def close(self):
|
||||||
while self.writers:
|
while self.writers:
|
||||||
self.writers.pop().finished.cancel()
|
self.writers.pop().finished.cancel()
|
||||||
|
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
await self.close()
|
self.close()
|
||||||
async with self.blob_write_lock:
|
async with self.blob_write_lock:
|
||||||
self.saved_verified_blob = False
|
self.saved_verified_blob = False
|
||||||
if os.path.isfile(self.file_path):
|
if os.path.isfile(self.file_path):
|
||||||
os.remove(self.file_path)
|
os.remove(self.file_path)
|
||||||
|
self.verified.clear()
|
||||||
|
self.finished_writing.clear()
|
||||||
|
|
||||||
def decrypt(self, key: bytes, iv: bytes) -> bytes:
|
def decrypt(self, key: bytes, iv: bytes) -> bytes:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -63,21 +63,21 @@ class BlobFileManager:
|
||||||
blob_hashes = await self.storage.get_all_blob_hashes()
|
blob_hashes = await self.storage.get_all_blob_hashes()
|
||||||
return self.check_completed_blobs(blob_hashes)
|
return self.check_completed_blobs(blob_hashes)
|
||||||
|
|
||||||
async def delete_blobs(self, blob_hashes: typing.List[str]):
|
async def delete_blob(self, blob_hash: str):
|
||||||
bh_to_delete_from_db = []
|
|
||||||
for blob_hash in blob_hashes:
|
|
||||||
if not blob_hash:
|
|
||||||
continue
|
|
||||||
try:
|
try:
|
||||||
blob = self.get_blob(blob_hash)
|
blob = self.get_blob(blob_hash)
|
||||||
await blob.delete()
|
await blob.delete()
|
||||||
bh_to_delete_from_db.append(blob_hash)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warning("Failed to delete blob file. Reason: %s", e)
|
log.warning("Failed to delete blob file. Reason: %s", e)
|
||||||
if blob_hash in self.completed_blob_hashes:
|
if blob_hash in self.completed_blob_hashes:
|
||||||
self.completed_blob_hashes.remove(blob_hash)
|
self.completed_blob_hashes.remove(blob_hash)
|
||||||
if blob_hash in self.blobs:
|
if blob_hash in self.blobs:
|
||||||
del self.blobs[blob_hash]
|
del self.blobs[blob_hash]
|
||||||
|
|
||||||
|
async def delete_blobs(self, blob_hashes: typing.List[str], delete_from_db: typing.Optional[bool] = True):
|
||||||
|
bh_to_delete_from_db = []
|
||||||
|
await asyncio.gather(*map(self.delete_blob, blob_hashes), loop=self.loop)
|
||||||
|
if delete_from_db:
|
||||||
try:
|
try:
|
||||||
await self.storage.delete_blobs_from_db(bh_to_delete_from_db)
|
await self.storage.delete_blobs_from_db(bh_to_delete_from_db)
|
||||||
except IntegrityError as err:
|
except IntegrityError as err:
|
||||||
|
|
|
@ -86,7 +86,7 @@ class BlobDownloader:
|
||||||
peer, task = self.active_connections.popitem()
|
peer, task = self.active_connections.popitem()
|
||||||
if task and not task.done():
|
if task and not task.done():
|
||||||
task.cancel()
|
task.cancel()
|
||||||
await blob.close()
|
blob.close()
|
||||||
log.debug("downloaded %s", blob_hash[:8])
|
log.debug("downloaded %s", blob_hash[:8])
|
||||||
return blob
|
return blob
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
|
|
|
@ -474,7 +474,12 @@ class KademliaProtocol(DatagramProtocol):
|
||||||
remote_exception = RemoteException(f"{error_datagram.exception_type}({error_datagram.response})")
|
remote_exception = RemoteException(f"{error_datagram.exception_type}({error_datagram.response})")
|
||||||
if error_datagram.rpc_id in self.sent_messages:
|
if error_datagram.rpc_id in self.sent_messages:
|
||||||
peer, df, request = self.sent_messages.pop(error_datagram.rpc_id)
|
peer, df, request = self.sent_messages.pop(error_datagram.rpc_id)
|
||||||
|
if (peer.address, peer.udp_port) != address:
|
||||||
|
df.set_exception(RemoteException(
|
||||||
|
f"response from {address[0]}:{address[1]}, "
|
||||||
|
f"expected {peer.address}:{peer.udp_port}")
|
||||||
|
)
|
||||||
|
return
|
||||||
error_msg = f"" \
|
error_msg = f"" \
|
||||||
f"Error sending '{request.method}' to {peer.address}:{peer.udp_port}\n" \
|
f"Error sending '{request.method}' to {peer.address}:{peer.udp_port}\n" \
|
||||||
f"Args: {request.args}\n" \
|
f"Args: {request.args}\n" \
|
||||||
|
@ -484,11 +489,6 @@ class KademliaProtocol(DatagramProtocol):
|
||||||
else:
|
else:
|
||||||
log.warning("known dht protocol backwards compatibility error with %s:%i (lbrynet v%s)",
|
log.warning("known dht protocol backwards compatibility error with %s:%i (lbrynet v%s)",
|
||||||
peer.address, peer.udp_port, old_protocol_errors[error_datagram.response])
|
peer.address, peer.udp_port, old_protocol_errors[error_datagram.response])
|
||||||
|
|
||||||
# reject replies coming from a different address than what we sent our request to
|
|
||||||
if (peer.address, peer.udp_port) != address:
|
|
||||||
log.error("node id mismatch in reply")
|
|
||||||
remote_exception = TimeoutError(peer.node_id)
|
|
||||||
df.set_exception(remote_exception)
|
df.set_exception(remote_exception)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -323,7 +323,7 @@ class BlobComponent(Component):
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
while self.blob_manager and self.blob_manager.blobs:
|
while self.blob_manager and self.blob_manager.blobs:
|
||||||
_, blob = self.blob_manager.blobs.popitem()
|
_, blob = self.blob_manager.blobs.popitem()
|
||||||
await blob.close()
|
blob.close()
|
||||||
|
|
||||||
async def get_status(self):
|
async def get_status(self):
|
||||||
count = 0
|
count = 0
|
||||||
|
|
|
@ -1614,7 +1614,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
||||||
await self.stream_manager.start_stream(stream)
|
await self.stream_manager.start_stream(stream)
|
||||||
msg = "Resumed download"
|
msg = "Resumed download"
|
||||||
elif status == 'stop' and stream.running:
|
elif status == 'stop' and stream.running:
|
||||||
stream.stop_download()
|
await self.stream_manager.stop_stream(stream)
|
||||||
msg = "Stopped download"
|
msg = "Stopped download"
|
||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
|
|
|
@ -43,7 +43,7 @@ class StreamAssembler:
|
||||||
self.written_bytes: int = 0
|
self.written_bytes: int = 0
|
||||||
|
|
||||||
async def _decrypt_blob(self, blob: 'BlobFile', blob_info: 'BlobInfo', key: str):
|
async def _decrypt_blob(self, blob: 'BlobFile', blob_info: 'BlobInfo', key: str):
|
||||||
if not blob or self.stream_handle.closed:
|
if not blob or not self.stream_handle or self.stream_handle.closed:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _decrypt_and_write():
|
def _decrypt_and_write():
|
||||||
|
@ -56,7 +56,6 @@ class StreamAssembler:
|
||||||
self.stream_handle.flush()
|
self.stream_handle.flush()
|
||||||
self.written_bytes += len(_decrypted)
|
self.written_bytes += len(_decrypted)
|
||||||
log.debug("decrypted %s", blob.blob_hash[:8])
|
log.debug("decrypted %s", blob.blob_hash[:8])
|
||||||
self.wrote_bytes_event.set()
|
|
||||||
|
|
||||||
await self.loop.run_in_executor(None, _decrypt_and_write)
|
await self.loop.run_in_executor(None, _decrypt_and_write)
|
||||||
return True
|
return True
|
||||||
|
@ -86,17 +85,23 @@ class StreamAssembler:
|
||||||
self.sd_blob, self.descriptor
|
self.sd_blob, self.descriptor
|
||||||
)
|
)
|
||||||
await self.blob_manager.blob_completed(self.sd_blob)
|
await self.blob_manager.blob_completed(self.sd_blob)
|
||||||
|
written_blobs = None
|
||||||
|
try:
|
||||||
with open(self.output_path, 'wb') as stream_handle:
|
with open(self.output_path, 'wb') as stream_handle:
|
||||||
self.stream_handle = stream_handle
|
self.stream_handle = stream_handle
|
||||||
for i, blob_info in enumerate(self.descriptor.blobs[:-1]):
|
for i, blob_info in enumerate(self.descriptor.blobs[:-1]):
|
||||||
if blob_info.blob_num != i:
|
if blob_info.blob_num != i:
|
||||||
log.error("sd blob %s is invalid, cannot assemble stream", self.descriptor.sd_hash)
|
log.error("sd blob %s is invalid, cannot assemble stream", self.descriptor.sd_hash)
|
||||||
return
|
return
|
||||||
while not stream_handle.closed:
|
while self.stream_handle and not self.stream_handle.closed:
|
||||||
try:
|
try:
|
||||||
blob = await self.get_blob(blob_info.blob_hash, blob_info.length)
|
blob = await self.get_blob(blob_info.blob_hash, blob_info.length)
|
||||||
if await self._decrypt_blob(blob, blob_info, self.descriptor.key):
|
if await self._decrypt_blob(blob, blob_info, self.descriptor.key):
|
||||||
await self.blob_manager.blob_completed(blob)
|
await self.blob_manager.blob_completed(blob)
|
||||||
|
written_blobs = i
|
||||||
|
if not self.wrote_bytes_event.is_set():
|
||||||
|
self.wrote_bytes_event.set()
|
||||||
|
log.debug("written %i/%i", written_blobs, len(self.descriptor.blobs) - 2)
|
||||||
break
|
break
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
log.debug("stream assembler stopped")
|
log.debug("stream assembler stopped")
|
||||||
|
@ -105,9 +110,14 @@ class StreamAssembler:
|
||||||
log.warning("failed to decrypt blob %s for stream %s", blob_info.blob_hash,
|
log.warning("failed to decrypt blob %s for stream %s", blob_info.blob_hash,
|
||||||
self.descriptor.sd_hash)
|
self.descriptor.sd_hash)
|
||||||
continue
|
continue
|
||||||
|
finally:
|
||||||
self.stream_finished_event.set()
|
if written_blobs == len(self.descriptor.blobs) - 2:
|
||||||
|
log.debug("finished decrypting and assembling stream")
|
||||||
await self.after_finished()
|
await self.after_finished()
|
||||||
|
self.stream_finished_event.set()
|
||||||
|
else:
|
||||||
|
log.debug("stream decryption and assembly did not finish (%i/%i blobs are done)", written_blobs or 0,
|
||||||
|
len(self.descriptor.blobs) - 2)
|
||||||
|
|
||||||
async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
|
async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
|
||||||
return self.blob_manager.get_blob(blob_hash, length)
|
return self.blob_manager.get_blob(blob_hash, length)
|
||||||
|
|
|
@ -86,7 +86,7 @@ class StreamDescriptor:
|
||||||
writer = sd_blob.open_for_writing()
|
writer = sd_blob.open_for_writing()
|
||||||
writer.write(sd_data)
|
writer.write(sd_data)
|
||||||
await sd_blob.verified.wait()
|
await sd_blob.verified.wait()
|
||||||
await sd_blob.close()
|
sd_blob.close()
|
||||||
return sd_blob
|
return sd_blob
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -63,6 +63,10 @@ class StreamDownloader(StreamAssembler):
|
||||||
self.fixed_peers_handle.cancel()
|
self.fixed_peers_handle.cancel()
|
||||||
self.fixed_peers_handle = None
|
self.fixed_peers_handle = None
|
||||||
self.blob_downloader = None
|
self.blob_downloader = None
|
||||||
|
if self.stream_handle:
|
||||||
|
if not self.stream_handle.closed:
|
||||||
|
self.stream_handle.close()
|
||||||
|
self.stream_handle = None
|
||||||
|
|
||||||
async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
|
async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
|
||||||
return await self.blob_downloader.download_blob(blob_hash, length)
|
return await self.blob_downloader.download_blob(blob_hash, length)
|
||||||
|
|
|
@ -104,8 +104,12 @@ class ManagedStream:
|
||||||
def blobs_remaining(self) -> int:
|
def blobs_remaining(self) -> int:
|
||||||
return self.blobs_in_stream - self.blobs_completed
|
return self.blobs_in_stream - self.blobs_completed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def full_path(self) -> str:
|
||||||
|
return os.path.join(self.download_directory, os.path.basename(self.file_name))
|
||||||
|
|
||||||
def as_dict(self) -> typing.Dict:
|
def as_dict(self) -> typing.Dict:
|
||||||
full_path = os.path.join(self.download_directory, self.file_name)
|
full_path = self.full_path
|
||||||
if not os.path.isfile(full_path):
|
if not os.path.isfile(full_path):
|
||||||
full_path = None
|
full_path = None
|
||||||
mime_type = guess_media_type(os.path.basename(self.file_name))
|
mime_type = guess_media_type(os.path.basename(self.file_name))
|
||||||
|
@ -170,12 +174,7 @@ class ManagedStream:
|
||||||
def stop_download(self):
|
def stop_download(self):
|
||||||
if self.downloader:
|
if self.downloader:
|
||||||
self.downloader.stop()
|
self.downloader.stop()
|
||||||
if not self.downloader.stream_finished_event.is_set() and self.downloader.wrote_bytes_event.is_set():
|
self.downloader = None
|
||||||
path = os.path.join(self.download_directory, self.file_name)
|
|
||||||
if os.path.isfile(path):
|
|
||||||
os.remove(path)
|
|
||||||
if not self.finished:
|
|
||||||
self.update_status(self.STATUS_STOPPED)
|
|
||||||
|
|
||||||
async def upload_to_reflector(self, host: str, port: int) -> typing.List[str]:
|
async def upload_to_reflector(self, host: str, port: int) -> typing.List[str]:
|
||||||
sent = []
|
sent = []
|
||||||
|
|
|
@ -4,7 +4,7 @@ import typing
|
||||||
import binascii
|
import binascii
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from lbrynet.error import ResolveError
|
from lbrynet.error import ResolveError, InvalidStreamDescriptorError
|
||||||
from lbrynet.stream.downloader import StreamDownloader
|
from lbrynet.stream.downloader import StreamDownloader
|
||||||
from lbrynet.stream.managed_stream import ManagedStream
|
from lbrynet.stream.managed_stream import ManagedStream
|
||||||
from lbrynet.schema.claim import ClaimDict
|
from lbrynet.schema.claim import ClaimDict
|
||||||
|
@ -97,8 +97,9 @@ class StreamManager:
|
||||||
await asyncio.wait_for(self.loop.create_task(stream.downloader.got_descriptor.wait()),
|
await asyncio.wait_for(self.loop.create_task(stream.downloader.got_descriptor.wait()),
|
||||||
self.config.download_timeout)
|
self.config.download_timeout)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
stream.stop_download()
|
await self.stop_stream(stream)
|
||||||
stream.downloader = None
|
if stream in self.streams:
|
||||||
|
self.streams.remove(stream)
|
||||||
return False
|
return False
|
||||||
file_name = os.path.basename(stream.downloader.output_path)
|
file_name = os.path.basename(stream.downloader.output_path)
|
||||||
await self.storage.change_file_download_dir_and_file_name(
|
await self.storage.change_file_download_dir_and_file_name(
|
||||||
|
@ -108,6 +109,18 @@ class StreamManager:
|
||||||
return True
|
return True
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def stop_stream(self, stream: ManagedStream):
|
||||||
|
stream.stop_download()
|
||||||
|
if not stream.finished and os.path.isfile(stream.full_path):
|
||||||
|
try:
|
||||||
|
os.remove(stream.full_path)
|
||||||
|
except OSError as err:
|
||||||
|
log.warning("Failed to delete partial download %s from downloads directory: %s", stream.full_path,
|
||||||
|
str(err))
|
||||||
|
if stream.running:
|
||||||
|
stream.update_status(ManagedStream.STATUS_STOPPED)
|
||||||
|
await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED)
|
||||||
|
|
||||||
def make_downloader(self, sd_hash: str, download_directory: str, file_name: str):
|
def make_downloader(self, sd_hash: str, download_directory: str, file_name: str):
|
||||||
return StreamDownloader(
|
return StreamDownloader(
|
||||||
self.loop, self.config, self.blob_manager, sd_hash, download_directory, file_name
|
self.loop, self.config, self.blob_manager, sd_hash, download_directory, file_name
|
||||||
|
@ -116,13 +129,15 @@ class StreamManager:
|
||||||
async def add_stream(self, sd_hash: str, file_name: str, download_directory: str, status: str, claim):
|
async def add_stream(self, sd_hash: str, file_name: str, download_directory: str, status: str, claim):
|
||||||
sd_blob = self.blob_manager.get_blob(sd_hash)
|
sd_blob = self.blob_manager.get_blob(sd_hash)
|
||||||
if sd_blob.get_is_verified():
|
if sd_blob.get_is_verified():
|
||||||
|
try:
|
||||||
descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash)
|
descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash)
|
||||||
|
except InvalidStreamDescriptorError as err:
|
||||||
|
log.warning("Failed to start stream for sd %s - %s", sd_hash, str(err))
|
||||||
|
return
|
||||||
|
|
||||||
downloader = self.make_downloader(descriptor.sd_hash, download_directory, file_name)
|
downloader = self.make_downloader(descriptor.sd_hash, download_directory, file_name)
|
||||||
stream = ManagedStream(
|
stream = ManagedStream(
|
||||||
self.loop, self.blob_manager, descriptor,
|
self.loop, self.blob_manager, descriptor, download_directory, file_name, downloader, status, claim
|
||||||
download_directory,
|
|
||||||
file_name,
|
|
||||||
downloader, status, claim
|
|
||||||
)
|
)
|
||||||
self.streams.add(stream)
|
self.streams.add(stream)
|
||||||
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
|
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
|
||||||
|
@ -194,18 +209,14 @@ class StreamManager:
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False):
|
async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False):
|
||||||
stream.stop_download()
|
await self.stop_stream(stream)
|
||||||
|
if stream in self.streams:
|
||||||
self.streams.remove(stream)
|
self.streams.remove(stream)
|
||||||
|
blob_hashes = [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]]
|
||||||
|
await self.blob_manager.delete_blobs(blob_hashes, delete_from_db=False)
|
||||||
await self.storage.delete_stream(stream.descriptor)
|
await self.storage.delete_stream(stream.descriptor)
|
||||||
|
if delete_file and os.path.isfile(stream.full_path):
|
||||||
blob_hashes = [stream.sd_hash]
|
os.remove(stream.full_path)
|
||||||
for blob_info in stream.descriptor.blobs[:-1]:
|
|
||||||
blob_hashes.append(blob_info.blob_hash)
|
|
||||||
await self.blob_manager.delete_blobs(blob_hashes)
|
|
||||||
if delete_file:
|
|
||||||
path = os.path.join(stream.download_directory, stream.file_name)
|
|
||||||
if os.path.isfile(path):
|
|
||||||
os.remove(path)
|
|
||||||
|
|
||||||
def wait_for_stream_finished(self, stream: ManagedStream):
|
def wait_for_stream_finished(self, stream: ManagedStream):
|
||||||
async def _wait_for_stream_finished():
|
async def _wait_for_stream_finished():
|
||||||
|
@ -267,7 +278,6 @@ class StreamManager:
|
||||||
fee_amount: typing.Optional[float] = 0.0,
|
fee_amount: typing.Optional[float] = 0.0,
|
||||||
fee_address: typing.Optional[str] = None,
|
fee_address: typing.Optional[str] = None,
|
||||||
should_pay: typing.Optional[bool] = True) -> typing.Optional[ManagedStream]:
|
should_pay: typing.Optional[bool] = True) -> typing.Optional[ManagedStream]:
|
||||||
log.info("get lbry://%s#%s", claim_info['name'], claim_info['claim_id'])
|
|
||||||
claim = ClaimDict.load_dict(claim_info['value'])
|
claim = ClaimDict.load_dict(claim_info['value'])
|
||||||
sd_hash = claim.source_hash.decode()
|
sd_hash = claim.source_hash.decode()
|
||||||
if sd_hash in self.starting_streams:
|
if sd_hash in self.starting_streams:
|
||||||
|
@ -294,7 +304,6 @@ class StreamManager:
|
||||||
finally:
|
finally:
|
||||||
if sd_hash in self.starting_streams:
|
if sd_hash in self.starting_streams:
|
||||||
del self.starting_streams[sd_hash]
|
del self.starting_streams[sd_hash]
|
||||||
log.info("returned from get lbry://%s#%s", claim_info['name'], claim_info['claim_id'])
|
|
||||||
|
|
||||||
def get_stream_by_stream_hash(self, stream_hash: str) -> typing.Optional[ManagedStream]:
|
def get_stream_by_stream_hash(self, stream_hash: str) -> typing.Optional[ManagedStream]:
|
||||||
streams = tuple(filter(lambda stream: stream.stream_hash == stream_hash, self.streams))
|
streams = tuple(filter(lambda stream: stream.stream_hash == stream_hash, self.streams))
|
||||||
|
|
|
@ -35,6 +35,14 @@ class DummyExchangeRateManager(exchange_rate_manager.ExchangeRateManager):
|
||||||
feed.market, rates[feed.market]['spot'], rates[feed.market]['ts'])
|
feed.market, rates[feed.market]['spot'], rates[feed.market]['ts'])
|
||||||
|
|
||||||
|
|
||||||
|
def get_dummy_exchange_rate_manager(time):
|
||||||
|
rates = {
|
||||||
|
'BTCLBC': {'spot': 3.0, 'ts': time.time() + 1},
|
||||||
|
'USDBTC': {'spot': 2.0, 'ts': time.time() + 2}
|
||||||
|
}
|
||||||
|
return DummyExchangeRateManager([BTCLBCFeed()], rates)
|
||||||
|
|
||||||
|
|
||||||
class FeeFormatTest(unittest.TestCase):
|
class FeeFormatTest(unittest.TestCase):
|
||||||
def test_fee_created_with_correct_inputs(self):
|
def test_fee_created_with_correct_inputs(self):
|
||||||
fee_dict = {
|
fee_dict = {
|
||||||
|
|
|
@ -37,15 +37,16 @@ class TestStreamDownloader(BlobExchangeTestBase):
|
||||||
return q2, self.loop.create_task(_task())
|
return q2, self.loop.create_task(_task())
|
||||||
|
|
||||||
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
|
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
|
||||||
|
|
||||||
self.downloader.download(mock_node)
|
self.downloader.download(mock_node)
|
||||||
await self.downloader.stream_finished_event.wait()
|
await self.downloader.stream_finished_event.wait()
|
||||||
|
self.assertTrue(self.downloader.stream_handle.closed)
|
||||||
|
self.assertTrue(os.path.isfile(self.downloader.output_path))
|
||||||
self.downloader.stop()
|
self.downloader.stop()
|
||||||
|
self.assertIs(self.downloader.stream_handle, None)
|
||||||
self.assertTrue(os.path.isfile(self.downloader.output_path))
|
self.assertTrue(os.path.isfile(self.downloader.output_path))
|
||||||
with open(self.downloader.output_path, 'rb') as f:
|
with open(self.downloader.output_path, 'rb') as f:
|
||||||
self.assertEqual(f.read(), self.stream_bytes)
|
self.assertEqual(f.read(), self.stream_bytes)
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
self.assertTrue(self.downloader.stream_handle.closed)
|
|
||||||
|
|
||||||
async def test_transfer_stream(self):
|
async def test_transfer_stream(self):
|
||||||
await self._test_transfer_stream(10)
|
await self._test_transfer_stream(10)
|
||||||
|
|
139
tests/unit/stream/test_stream_manager.py
Normal file
139
tests/unit/stream/test_stream_manager.py
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
import os
|
||||||
|
import binascii
|
||||||
|
from unittest import mock
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase
|
||||||
|
from tests.unit.lbrynet_daemon.test_ExchangeRateManager import get_dummy_exchange_rate_manager
|
||||||
|
|
||||||
|
from lbrynet.extras.wallet.manager import LbryWalletManager
|
||||||
|
from lbrynet.stream.stream_manager import StreamManager
|
||||||
|
from lbrynet.stream.descriptor import StreamDescriptor
|
||||||
|
from lbrynet.dht.node import Node
|
||||||
|
from lbrynet.schema.claim import ClaimDict
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_node(peer):
|
||||||
|
def mock_accumulate_peers(q1: asyncio.Queue, q2: asyncio.Queue):
|
||||||
|
async def _task():
|
||||||
|
pass
|
||||||
|
|
||||||
|
q2.put_nowait([peer])
|
||||||
|
return q2, asyncio.create_task(_task())
|
||||||
|
|
||||||
|
mock_node = mock.Mock(spec=Node)
|
||||||
|
mock_node.accumulate_peers = mock_accumulate_peers
|
||||||
|
return mock_node
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_wallet(sd_hash, storage):
|
||||||
|
claim = {
|
||||||
|
"address": "bYFeMtSL7ARuG1iMpjFyrnTe4oJHSAVNXF",
|
||||||
|
"amount": "0.1",
|
||||||
|
"claim_id": "c49566d631226492317d06ad7fdbe1ed32925124",
|
||||||
|
"claim_sequence": 1,
|
||||||
|
"decoded_claim": True,
|
||||||
|
"depth": 1057,
|
||||||
|
"effective_amount": "0.1",
|
||||||
|
"has_signature": False,
|
||||||
|
"height": 514081,
|
||||||
|
"hex": "",
|
||||||
|
"name": "33rpm",
|
||||||
|
"nout": 0,
|
||||||
|
"permanent_url": "33rpm#c49566d631226492317d06ad7fdbe1ed32925124",
|
||||||
|
"supports": [],
|
||||||
|
"txid": "81ac52662af926fdf639d56920069e0f63449d4cde074c61717cb99ddde40e3c",
|
||||||
|
"value": {
|
||||||
|
"claimType": "streamType",
|
||||||
|
"stream": {
|
||||||
|
"metadata": {
|
||||||
|
"author": "",
|
||||||
|
"description": "",
|
||||||
|
"language": "en",
|
||||||
|
"license": "None",
|
||||||
|
"licenseUrl": "",
|
||||||
|
"nsfw": False,
|
||||||
|
"preview": "",
|
||||||
|
"thumbnail": "",
|
||||||
|
"title": "33rpm",
|
||||||
|
"version": "_0_1_0"
|
||||||
|
},
|
||||||
|
"source": {
|
||||||
|
"contentType": "image/png",
|
||||||
|
"source": sd_hash,
|
||||||
|
"sourceType": "lbry_sd_hash",
|
||||||
|
"version": "_0_0_1"
|
||||||
|
},
|
||||||
|
"version": "_0_0_1"
|
||||||
|
},
|
||||||
|
"version": "_0_0_1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
claim_dict = ClaimDict.load_dict(claim['value'])
|
||||||
|
claim['hex'] = binascii.hexlify(claim_dict.serialized).decode()
|
||||||
|
|
||||||
|
async def mock_resolve(*args):
|
||||||
|
await storage.save_claims([claim])
|
||||||
|
return {
|
||||||
|
claim['permanent_url']: claim
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_wallet = mock.Mock(spec=LbryWalletManager)
|
||||||
|
mock_wallet.resolve = mock_resolve
|
||||||
|
return mock_wallet, claim['permanent_url']
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamManager(BlobExchangeTestBase):
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
await super().asyncSetUp()
|
||||||
|
file_path = os.path.join(self.server_dir, "test_file")
|
||||||
|
with open(file_path, 'wb') as f:
|
||||||
|
f.write(os.urandom(20000000))
|
||||||
|
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
|
||||||
|
self.sd_hash = descriptor.calculate_sd_hash()
|
||||||
|
self.mock_wallet, self.uri = get_mock_wallet(self.sd_hash, self.client_storage)
|
||||||
|
self.stream_manager = StreamManager(self.loop, self.client_config, self.client_blob_manager, self.mock_wallet,
|
||||||
|
self.client_storage, get_mock_node(self.server_from_client))
|
||||||
|
self.exchange_rate_manager = get_dummy_exchange_rate_manager(time)
|
||||||
|
|
||||||
|
async def test_download_stop_resume_delete(self):
|
||||||
|
self.assertSetEqual(self.stream_manager.streams, set())
|
||||||
|
stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
|
||||||
|
stream_hash = stream.stream_hash
|
||||||
|
self.assertSetEqual(self.stream_manager.streams, {stream})
|
||||||
|
self.assertTrue(stream.running)
|
||||||
|
self.assertFalse(stream.finished)
|
||||||
|
self.assertTrue(os.path.isfile(os.path.join(self.client_dir, "test_file")))
|
||||||
|
stored_status = await self.client_storage.run_and_return_one_or_none(
|
||||||
|
"select status from file where stream_hash=?", stream_hash
|
||||||
|
)
|
||||||
|
self.assertEqual(stored_status, "running")
|
||||||
|
|
||||||
|
await self.stream_manager.stop_stream(stream)
|
||||||
|
|
||||||
|
self.assertFalse(stream.finished)
|
||||||
|
self.assertFalse(stream.running)
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
|
||||||
|
stored_status = await self.client_storage.run_and_return_one_or_none(
|
||||||
|
"select status from file where stream_hash=?", stream_hash
|
||||||
|
)
|
||||||
|
self.assertEqual(stored_status, "stopped")
|
||||||
|
|
||||||
|
await self.stream_manager.start_stream(stream)
|
||||||
|
await stream.downloader.stream_finished_event.wait()
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
self.assertTrue(stream.finished)
|
||||||
|
self.assertFalse(stream.running)
|
||||||
|
self.assertTrue(os.path.isfile(os.path.join(self.client_dir, "test_file")))
|
||||||
|
stored_status = await self.client_storage.run_and_return_one_or_none(
|
||||||
|
"select status from file where stream_hash=?", stream_hash
|
||||||
|
)
|
||||||
|
self.assertEqual(stored_status, "finished")
|
||||||
|
|
||||||
|
await self.stream_manager.delete_stream(stream, True)
|
||||||
|
self.assertSetEqual(self.stream_manager.streams, set())
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
|
||||||
|
stored_status = await self.client_storage.run_and_return_one_or_none(
|
||||||
|
"select status from file where stream_hash=?", stream_hash
|
||||||
|
)
|
||||||
|
self.assertEqual(stored_status, None)
|
Loading…
Reference in a new issue