Merge pull request #1866 from lbryio/non-async-close-blob

refactor blob.close to be non-async, speed up deleting blobs and streams
This commit is contained in:
Jack Robison 2019-02-06 13:52:37 -05:00 committed by GitHub
commit 3508da4993
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 256 additions and 84 deletions

View file

@ -132,16 +132,18 @@ class BlobFile:
with open(self.file_path, 'rb') as handle: with open(self.file_path, 'rb') as handle:
return await self.loop.sendfile(writer.transport, handle, count=self.get_length()) return await self.loop.sendfile(writer.transport, handle, count=self.get_length())
async def close(self): def close(self):
while self.writers: while self.writers:
self.writers.pop().finished.cancel() self.writers.pop().finished.cancel()
async def delete(self): async def delete(self):
await self.close() self.close()
async with self.blob_write_lock: async with self.blob_write_lock:
self.saved_verified_blob = False self.saved_verified_blob = False
if os.path.isfile(self.file_path): if os.path.isfile(self.file_path):
os.remove(self.file_path) os.remove(self.file_path)
self.verified.clear()
self.finished_writing.clear()
def decrypt(self, key: bytes, iv: bytes) -> bytes: def decrypt(self, key: bytes, iv: bytes) -> bytes:
""" """

View file

@ -63,21 +63,21 @@ class BlobFileManager:
blob_hashes = await self.storage.get_all_blob_hashes() blob_hashes = await self.storage.get_all_blob_hashes()
return self.check_completed_blobs(blob_hashes) return self.check_completed_blobs(blob_hashes)
async def delete_blobs(self, blob_hashes: typing.List[str]): async def delete_blob(self, blob_hash: str):
bh_to_delete_from_db = []
for blob_hash in blob_hashes:
if not blob_hash:
continue
try: try:
blob = self.get_blob(blob_hash) blob = self.get_blob(blob_hash)
await blob.delete() await blob.delete()
bh_to_delete_from_db.append(blob_hash)
except Exception as e: except Exception as e:
log.warning("Failed to delete blob file. Reason: %s", e) log.warning("Failed to delete blob file. Reason: %s", e)
if blob_hash in self.completed_blob_hashes: if blob_hash in self.completed_blob_hashes:
self.completed_blob_hashes.remove(blob_hash) self.completed_blob_hashes.remove(blob_hash)
if blob_hash in self.blobs: if blob_hash in self.blobs:
del self.blobs[blob_hash] del self.blobs[blob_hash]
async def delete_blobs(self, blob_hashes: typing.List[str], delete_from_db: typing.Optional[bool] = True):
bh_to_delete_from_db = []
await asyncio.gather(*map(self.delete_blob, blob_hashes), loop=self.loop)
if delete_from_db:
try: try:
await self.storage.delete_blobs_from_db(bh_to_delete_from_db) await self.storage.delete_blobs_from_db(bh_to_delete_from_db)
except IntegrityError as err: except IntegrityError as err:

View file

@ -86,7 +86,7 @@ class BlobDownloader:
peer, task = self.active_connections.popitem() peer, task = self.active_connections.popitem()
if task and not task.done(): if task and not task.done():
task.cancel() task.cancel()
await blob.close() blob.close()
log.debug("downloaded %s", blob_hash[:8]) log.debug("downloaded %s", blob_hash[:8])
return blob return blob
except asyncio.CancelledError: except asyncio.CancelledError:

View file

@ -474,7 +474,12 @@ class KademliaProtocol(DatagramProtocol):
remote_exception = RemoteException(f"{error_datagram.exception_type}({error_datagram.response})") remote_exception = RemoteException(f"{error_datagram.exception_type}({error_datagram.response})")
if error_datagram.rpc_id in self.sent_messages: if error_datagram.rpc_id in self.sent_messages:
peer, df, request = self.sent_messages.pop(error_datagram.rpc_id) peer, df, request = self.sent_messages.pop(error_datagram.rpc_id)
if (peer.address, peer.udp_port) != address:
df.set_exception(RemoteException(
f"response from {address[0]}:{address[1]}, "
f"expected {peer.address}:{peer.udp_port}")
)
return
error_msg = f"" \ error_msg = f"" \
f"Error sending '{request.method}' to {peer.address}:{peer.udp_port}\n" \ f"Error sending '{request.method}' to {peer.address}:{peer.udp_port}\n" \
f"Args: {request.args}\n" \ f"Args: {request.args}\n" \
@ -484,11 +489,6 @@ class KademliaProtocol(DatagramProtocol):
else: else:
log.warning("known dht protocol backwards compatibility error with %s:%i (lbrynet v%s)", log.warning("known dht protocol backwards compatibility error with %s:%i (lbrynet v%s)",
peer.address, peer.udp_port, old_protocol_errors[error_datagram.response]) peer.address, peer.udp_port, old_protocol_errors[error_datagram.response])
# reject replies coming from a different address than what we sent our request to
if (peer.address, peer.udp_port) != address:
log.error("node id mismatch in reply")
remote_exception = TimeoutError(peer.node_id)
df.set_exception(remote_exception) df.set_exception(remote_exception)
return return
else: else:

View file

@ -323,7 +323,7 @@ class BlobComponent(Component):
async def stop(self): async def stop(self):
while self.blob_manager and self.blob_manager.blobs: while self.blob_manager and self.blob_manager.blobs:
_, blob = self.blob_manager.blobs.popitem() _, blob = self.blob_manager.blobs.popitem()
await blob.close() blob.close()
async def get_status(self): async def get_status(self):
count = 0 count = 0

View file

@ -1614,7 +1614,7 @@ class Daemon(metaclass=JSONRPCServerType):
await self.stream_manager.start_stream(stream) await self.stream_manager.start_stream(stream)
msg = "Resumed download" msg = "Resumed download"
elif status == 'stop' and stream.running: elif status == 'stop' and stream.running:
stream.stop_download() await self.stream_manager.stop_stream(stream)
msg = "Stopped download" msg = "Stopped download"
else: else:
msg = ( msg = (

View file

@ -43,7 +43,7 @@ class StreamAssembler:
self.written_bytes: int = 0 self.written_bytes: int = 0
async def _decrypt_blob(self, blob: 'BlobFile', blob_info: 'BlobInfo', key: str): async def _decrypt_blob(self, blob: 'BlobFile', blob_info: 'BlobInfo', key: str):
if not blob or self.stream_handle.closed: if not blob or not self.stream_handle or self.stream_handle.closed:
return False return False
def _decrypt_and_write(): def _decrypt_and_write():
@ -56,7 +56,6 @@ class StreamAssembler:
self.stream_handle.flush() self.stream_handle.flush()
self.written_bytes += len(_decrypted) self.written_bytes += len(_decrypted)
log.debug("decrypted %s", blob.blob_hash[:8]) log.debug("decrypted %s", blob.blob_hash[:8])
self.wrote_bytes_event.set()
await self.loop.run_in_executor(None, _decrypt_and_write) await self.loop.run_in_executor(None, _decrypt_and_write)
return True return True
@ -86,17 +85,23 @@ class StreamAssembler:
self.sd_blob, self.descriptor self.sd_blob, self.descriptor
) )
await self.blob_manager.blob_completed(self.sd_blob) await self.blob_manager.blob_completed(self.sd_blob)
written_blobs = None
try:
with open(self.output_path, 'wb') as stream_handle: with open(self.output_path, 'wb') as stream_handle:
self.stream_handle = stream_handle self.stream_handle = stream_handle
for i, blob_info in enumerate(self.descriptor.blobs[:-1]): for i, blob_info in enumerate(self.descriptor.blobs[:-1]):
if blob_info.blob_num != i: if blob_info.blob_num != i:
log.error("sd blob %s is invalid, cannot assemble stream", self.descriptor.sd_hash) log.error("sd blob %s is invalid, cannot assemble stream", self.descriptor.sd_hash)
return return
while not stream_handle.closed: while self.stream_handle and not self.stream_handle.closed:
try: try:
blob = await self.get_blob(blob_info.blob_hash, blob_info.length) blob = await self.get_blob(blob_info.blob_hash, blob_info.length)
if await self._decrypt_blob(blob, blob_info, self.descriptor.key): if await self._decrypt_blob(blob, blob_info, self.descriptor.key):
await self.blob_manager.blob_completed(blob) await self.blob_manager.blob_completed(blob)
written_blobs = i
if not self.wrote_bytes_event.is_set():
self.wrote_bytes_event.set()
log.debug("written %i/%i", written_blobs, len(self.descriptor.blobs) - 2)
break break
except FileNotFoundError: except FileNotFoundError:
log.debug("stream assembler stopped") log.debug("stream assembler stopped")
@ -105,9 +110,14 @@ class StreamAssembler:
log.warning("failed to decrypt blob %s for stream %s", blob_info.blob_hash, log.warning("failed to decrypt blob %s for stream %s", blob_info.blob_hash,
self.descriptor.sd_hash) self.descriptor.sd_hash)
continue continue
finally:
self.stream_finished_event.set() if written_blobs == len(self.descriptor.blobs) - 2:
log.debug("finished decrypting and assembling stream")
await self.after_finished() await self.after_finished()
self.stream_finished_event.set()
else:
log.debug("stream decryption and assembly did not finish (%i/%i blobs are done)", written_blobs or 0,
len(self.descriptor.blobs) - 2)
async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile': async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
return self.blob_manager.get_blob(blob_hash, length) return self.blob_manager.get_blob(blob_hash, length)

View file

@ -86,7 +86,7 @@ class StreamDescriptor:
writer = sd_blob.open_for_writing() writer = sd_blob.open_for_writing()
writer.write(sd_data) writer.write(sd_data)
await sd_blob.verified.wait() await sd_blob.verified.wait()
await sd_blob.close() sd_blob.close()
return sd_blob return sd_blob
@classmethod @classmethod

View file

@ -63,6 +63,10 @@ class StreamDownloader(StreamAssembler):
self.fixed_peers_handle.cancel() self.fixed_peers_handle.cancel()
self.fixed_peers_handle = None self.fixed_peers_handle = None
self.blob_downloader = None self.blob_downloader = None
if self.stream_handle:
if not self.stream_handle.closed:
self.stream_handle.close()
self.stream_handle = None
async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile': async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
return await self.blob_downloader.download_blob(blob_hash, length) return await self.blob_downloader.download_blob(blob_hash, length)

View file

@ -104,8 +104,12 @@ class ManagedStream:
def blobs_remaining(self) -> int: def blobs_remaining(self) -> int:
return self.blobs_in_stream - self.blobs_completed return self.blobs_in_stream - self.blobs_completed
@property
def full_path(self) -> str:
return os.path.join(self.download_directory, os.path.basename(self.file_name))
def as_dict(self) -> typing.Dict: def as_dict(self) -> typing.Dict:
full_path = os.path.join(self.download_directory, self.file_name) full_path = self.full_path
if not os.path.isfile(full_path): if not os.path.isfile(full_path):
full_path = None full_path = None
mime_type = guess_media_type(os.path.basename(self.file_name)) mime_type = guess_media_type(os.path.basename(self.file_name))
@ -170,12 +174,7 @@ class ManagedStream:
def stop_download(self): def stop_download(self):
if self.downloader: if self.downloader:
self.downloader.stop() self.downloader.stop()
if not self.downloader.stream_finished_event.is_set() and self.downloader.wrote_bytes_event.is_set(): self.downloader = None
path = os.path.join(self.download_directory, self.file_name)
if os.path.isfile(path):
os.remove(path)
if not self.finished:
self.update_status(self.STATUS_STOPPED)
async def upload_to_reflector(self, host: str, port: int) -> typing.List[str]: async def upload_to_reflector(self, host: str, port: int) -> typing.List[str]:
sent = [] sent = []

View file

@ -4,7 +4,7 @@ import typing
import binascii import binascii
import logging import logging
import random import random
from lbrynet.error import ResolveError from lbrynet.error import ResolveError, InvalidStreamDescriptorError
from lbrynet.stream.downloader import StreamDownloader from lbrynet.stream.downloader import StreamDownloader
from lbrynet.stream.managed_stream import ManagedStream from lbrynet.stream.managed_stream import ManagedStream
from lbrynet.schema.claim import ClaimDict from lbrynet.schema.claim import ClaimDict
@ -97,8 +97,9 @@ class StreamManager:
await asyncio.wait_for(self.loop.create_task(stream.downloader.got_descriptor.wait()), await asyncio.wait_for(self.loop.create_task(stream.downloader.got_descriptor.wait()),
self.config.download_timeout) self.config.download_timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
stream.stop_download() await self.stop_stream(stream)
stream.downloader = None if stream in self.streams:
self.streams.remove(stream)
return False return False
file_name = os.path.basename(stream.downloader.output_path) file_name = os.path.basename(stream.downloader.output_path)
await self.storage.change_file_download_dir_and_file_name( await self.storage.change_file_download_dir_and_file_name(
@ -108,6 +109,18 @@ class StreamManager:
return True return True
return True return True
async def stop_stream(self, stream: ManagedStream):
stream.stop_download()
if not stream.finished and os.path.isfile(stream.full_path):
try:
os.remove(stream.full_path)
except OSError as err:
log.warning("Failed to delete partial download %s from downloads directory: %s", stream.full_path,
str(err))
if stream.running:
stream.update_status(ManagedStream.STATUS_STOPPED)
await self.storage.change_file_status(stream.stream_hash, ManagedStream.STATUS_STOPPED)
def make_downloader(self, sd_hash: str, download_directory: str, file_name: str): def make_downloader(self, sd_hash: str, download_directory: str, file_name: str):
return StreamDownloader( return StreamDownloader(
self.loop, self.config, self.blob_manager, sd_hash, download_directory, file_name self.loop, self.config, self.blob_manager, sd_hash, download_directory, file_name
@ -116,13 +129,15 @@ class StreamManager:
async def add_stream(self, sd_hash: str, file_name: str, download_directory: str, status: str, claim): async def add_stream(self, sd_hash: str, file_name: str, download_directory: str, status: str, claim):
sd_blob = self.blob_manager.get_blob(sd_hash) sd_blob = self.blob_manager.get_blob(sd_hash)
if sd_blob.get_is_verified(): if sd_blob.get_is_verified():
try:
descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash) descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash)
except InvalidStreamDescriptorError as err:
log.warning("Failed to start stream for sd %s - %s", sd_hash, str(err))
return
downloader = self.make_downloader(descriptor.sd_hash, download_directory, file_name) downloader = self.make_downloader(descriptor.sd_hash, download_directory, file_name)
stream = ManagedStream( stream = ManagedStream(
self.loop, self.blob_manager, descriptor, self.loop, self.blob_manager, descriptor, download_directory, file_name, downloader, status, claim
download_directory,
file_name,
downloader, status, claim
) )
self.streams.add(stream) self.streams.add(stream)
self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream) self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream)
@ -194,18 +209,14 @@ class StreamManager:
return stream return stream
async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False): async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False):
stream.stop_download() await self.stop_stream(stream)
if stream in self.streams:
self.streams.remove(stream) self.streams.remove(stream)
blob_hashes = [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]]
await self.blob_manager.delete_blobs(blob_hashes, delete_from_db=False)
await self.storage.delete_stream(stream.descriptor) await self.storage.delete_stream(stream.descriptor)
if delete_file and os.path.isfile(stream.full_path):
blob_hashes = [stream.sd_hash] os.remove(stream.full_path)
for blob_info in stream.descriptor.blobs[:-1]:
blob_hashes.append(blob_info.blob_hash)
await self.blob_manager.delete_blobs(blob_hashes)
if delete_file:
path = os.path.join(stream.download_directory, stream.file_name)
if os.path.isfile(path):
os.remove(path)
def wait_for_stream_finished(self, stream: ManagedStream): def wait_for_stream_finished(self, stream: ManagedStream):
async def _wait_for_stream_finished(): async def _wait_for_stream_finished():
@ -267,7 +278,6 @@ class StreamManager:
fee_amount: typing.Optional[float] = 0.0, fee_amount: typing.Optional[float] = 0.0,
fee_address: typing.Optional[str] = None, fee_address: typing.Optional[str] = None,
should_pay: typing.Optional[bool] = True) -> typing.Optional[ManagedStream]: should_pay: typing.Optional[bool] = True) -> typing.Optional[ManagedStream]:
log.info("get lbry://%s#%s", claim_info['name'], claim_info['claim_id'])
claim = ClaimDict.load_dict(claim_info['value']) claim = ClaimDict.load_dict(claim_info['value'])
sd_hash = claim.source_hash.decode() sd_hash = claim.source_hash.decode()
if sd_hash in self.starting_streams: if sd_hash in self.starting_streams:
@ -294,7 +304,6 @@ class StreamManager:
finally: finally:
if sd_hash in self.starting_streams: if sd_hash in self.starting_streams:
del self.starting_streams[sd_hash] del self.starting_streams[sd_hash]
log.info("returned from get lbry://%s#%s", claim_info['name'], claim_info['claim_id'])
def get_stream_by_stream_hash(self, stream_hash: str) -> typing.Optional[ManagedStream]: def get_stream_by_stream_hash(self, stream_hash: str) -> typing.Optional[ManagedStream]:
streams = tuple(filter(lambda stream: stream.stream_hash == stream_hash, self.streams)) streams = tuple(filter(lambda stream: stream.stream_hash == stream_hash, self.streams))

View file

@ -35,6 +35,14 @@ class DummyExchangeRateManager(exchange_rate_manager.ExchangeRateManager):
feed.market, rates[feed.market]['spot'], rates[feed.market]['ts']) feed.market, rates[feed.market]['spot'], rates[feed.market]['ts'])
def get_dummy_exchange_rate_manager(time):
rates = {
'BTCLBC': {'spot': 3.0, 'ts': time.time() + 1},
'USDBTC': {'spot': 2.0, 'ts': time.time() + 2}
}
return DummyExchangeRateManager([BTCLBCFeed()], rates)
class FeeFormatTest(unittest.TestCase): class FeeFormatTest(unittest.TestCase):
def test_fee_created_with_correct_inputs(self): def test_fee_created_with_correct_inputs(self):
fee_dict = { fee_dict = {

View file

@ -37,15 +37,16 @@ class TestStreamDownloader(BlobExchangeTestBase):
return q2, self.loop.create_task(_task()) return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
self.downloader.download(mock_node) self.downloader.download(mock_node)
await self.downloader.stream_finished_event.wait() await self.downloader.stream_finished_event.wait()
self.assertTrue(self.downloader.stream_handle.closed)
self.assertTrue(os.path.isfile(self.downloader.output_path))
self.downloader.stop() self.downloader.stop()
self.assertIs(self.downloader.stream_handle, None)
self.assertTrue(os.path.isfile(self.downloader.output_path)) self.assertTrue(os.path.isfile(self.downloader.output_path))
with open(self.downloader.output_path, 'rb') as f: with open(self.downloader.output_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes) self.assertEqual(f.read(), self.stream_bytes)
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
self.assertTrue(self.downloader.stream_handle.closed)
async def test_transfer_stream(self): async def test_transfer_stream(self):
await self._test_transfer_stream(10) await self._test_transfer_stream(10)

View file

@ -0,0 +1,139 @@
import os
import binascii
from unittest import mock
import asyncio
import time
from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase
from tests.unit.lbrynet_daemon.test_ExchangeRateManager import get_dummy_exchange_rate_manager
from lbrynet.extras.wallet.manager import LbryWalletManager
from lbrynet.stream.stream_manager import StreamManager
from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.dht.node import Node
from lbrynet.schema.claim import ClaimDict
def get_mock_node(peer):
def mock_accumulate_peers(q1: asyncio.Queue, q2: asyncio.Queue):
async def _task():
pass
q2.put_nowait([peer])
return q2, asyncio.create_task(_task())
mock_node = mock.Mock(spec=Node)
mock_node.accumulate_peers = mock_accumulate_peers
return mock_node
def get_mock_wallet(sd_hash, storage):
claim = {
"address": "bYFeMtSL7ARuG1iMpjFyrnTe4oJHSAVNXF",
"amount": "0.1",
"claim_id": "c49566d631226492317d06ad7fdbe1ed32925124",
"claim_sequence": 1,
"decoded_claim": True,
"depth": 1057,
"effective_amount": "0.1",
"has_signature": False,
"height": 514081,
"hex": "",
"name": "33rpm",
"nout": 0,
"permanent_url": "33rpm#c49566d631226492317d06ad7fdbe1ed32925124",
"supports": [],
"txid": "81ac52662af926fdf639d56920069e0f63449d4cde074c61717cb99ddde40e3c",
"value": {
"claimType": "streamType",
"stream": {
"metadata": {
"author": "",
"description": "",
"language": "en",
"license": "None",
"licenseUrl": "",
"nsfw": False,
"preview": "",
"thumbnail": "",
"title": "33rpm",
"version": "_0_1_0"
},
"source": {
"contentType": "image/png",
"source": sd_hash,
"sourceType": "lbry_sd_hash",
"version": "_0_0_1"
},
"version": "_0_0_1"
},
"version": "_0_0_1"
}
}
claim_dict = ClaimDict.load_dict(claim['value'])
claim['hex'] = binascii.hexlify(claim_dict.serialized).decode()
async def mock_resolve(*args):
await storage.save_claims([claim])
return {
claim['permanent_url']: claim
}
mock_wallet = mock.Mock(spec=LbryWalletManager)
mock_wallet.resolve = mock_resolve
return mock_wallet, claim['permanent_url']
class TestStreamManager(BlobExchangeTestBase):
async def asyncSetUp(self):
await super().asyncSetUp()
file_path = os.path.join(self.server_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(os.urandom(20000000))
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
self.sd_hash = descriptor.calculate_sd_hash()
self.mock_wallet, self.uri = get_mock_wallet(self.sd_hash, self.client_storage)
self.stream_manager = StreamManager(self.loop, self.client_config, self.client_blob_manager, self.mock_wallet,
self.client_storage, get_mock_node(self.server_from_client))
self.exchange_rate_manager = get_dummy_exchange_rate_manager(time)
async def test_download_stop_resume_delete(self):
self.assertSetEqual(self.stream_manager.streams, set())
stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager)
stream_hash = stream.stream_hash
self.assertSetEqual(self.stream_manager.streams, {stream})
self.assertTrue(stream.running)
self.assertFalse(stream.finished)
self.assertTrue(os.path.isfile(os.path.join(self.client_dir, "test_file")))
stored_status = await self.client_storage.run_and_return_one_or_none(
"select status from file where stream_hash=?", stream_hash
)
self.assertEqual(stored_status, "running")
await self.stream_manager.stop_stream(stream)
self.assertFalse(stream.finished)
self.assertFalse(stream.running)
self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
stored_status = await self.client_storage.run_and_return_one_or_none(
"select status from file where stream_hash=?", stream_hash
)
self.assertEqual(stored_status, "stopped")
await self.stream_manager.start_stream(stream)
await stream.downloader.stream_finished_event.wait()
await asyncio.sleep(0.01)
self.assertTrue(stream.finished)
self.assertFalse(stream.running)
self.assertTrue(os.path.isfile(os.path.join(self.client_dir, "test_file")))
stored_status = await self.client_storage.run_and_return_one_or_none(
"select status from file where stream_hash=?", stream_hash
)
self.assertEqual(stored_status, "finished")
await self.stream_manager.delete_stream(stream, True)
self.assertSetEqual(self.stream_manager.streams, set())
self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file")))
stored_status = await self.client_storage.run_and_return_one_or_none(
"select status from file where stream_hash=?", stream_hash
)
self.assertEqual(stored_status, None)