diff --git a/lbrynet/blob/blob_file.py b/lbrynet/blob/blob_file.py index 255cdcc72..445a8f5da 100644 --- a/lbrynet/blob/blob_file.py +++ b/lbrynet/blob/blob_file.py @@ -68,7 +68,8 @@ class AbstractBlob: 'blob_directory', 'writers', 'verified', - 'writing' + 'writing', + 'readers' ] def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None, @@ -82,18 +83,29 @@ class AbstractBlob: self.writers: typing.List[HashBlobWriter] = [] self.verified: asyncio.Event = asyncio.Event(loop=self.loop) self.writing: asyncio.Event = asyncio.Event(loop=self.loop) + self.readers: typing.List[typing.BinaryIO] = [] + if not is_valid_blobhash(blob_hash): raise InvalidBlobHashError(blob_hash) def __del__(self): - if self.writers or self.is_readable(): + if self.writers or self.readers: log.warning("%s not closed before being garbage collected", self.blob_hash) self.close() @contextlib.contextmanager - def reader_context(self) -> typing.ContextManager[typing.BinaryIO]: + def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]: raise NotImplementedError() + @contextlib.contextmanager + def reader_context(self) -> typing.ContextManager[typing.BinaryIO]: + try: + with self._reader_context() as reader: + self.readers.append(reader) + yield reader + finally: + self.readers = [reader for reader in self.readers if reader is not None] + def _write_blob(self, blob_bytes: bytes): raise NotImplementedError() @@ -129,6 +141,10 @@ class AbstractBlob: def close(self): while self.writers: self.writers.pop().finished.cancel() + while self.readers: + reader = self.readers.pop() + if reader: + reader.close() def delete(self): self.close() @@ -213,11 +229,11 @@ class BlobBuffer(AbstractBlob): def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None, blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], typing.Awaitable]] = None, blob_directory: typing.Optional[str] = None): - super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory) self._verified_bytes: typing.Optional[BytesIO] = None + super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory) @contextlib.contextmanager - def reader_context(self) -> typing.ContextManager[typing.BinaryIO]: + def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]: if not self.is_readable(): raise OSError("cannot open blob for reading") try: @@ -238,6 +254,11 @@ class BlobBuffer(AbstractBlob): self._verified_bytes = None return super().delete() + def __del__(self): + super().__del__() + if self._verified_bytes: + self.delete() + class BlobFile(AbstractBlob): """ @@ -272,7 +293,7 @@ class BlobFile(AbstractBlob): return super().get_blob_writer() @contextlib.contextmanager - def reader_context(self) -> typing.ContextManager[typing.BinaryIO]: + def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]: handle = open(self.file_path, 'rb') try: yield handle diff --git a/tests/unit/blob/test_blob_file.py b/tests/unit/blob/test_blob_file.py index 9a39440e2..5933a906d 100644 --- a/tests/unit/blob/test_blob_file.py +++ b/tests/unit/blob/test_blob_file.py @@ -150,3 +150,27 @@ class TestBlob(AsyncioTestCase): self.assertRaises(InvalidBlobHashError, BlobBuffer, self.loop, '', len(self.blob_bytes)) self.assertRaises(InvalidBlobHashError, BlobBuffer, self.loop, 'x' * 96, len(self.blob_bytes)) self.assertRaises(InvalidBlobHashError, BlobBuffer, self.loop, 'a' * 97, len(self.blob_bytes)) + + async def _test_close_reader(self, blob_class=AbstractBlob, blob_directory=None): + blob = await self._test_create_blob(blob_class, blob_directory) + reader = blob.reader_context() + self.assertEqual(0, len(blob.readers)) + + async def read_blob_buffer(): + with reader as read_handle: + self.assertEqual(1, len(blob.readers)) + await asyncio.sleep(2, loop=self.loop) + self.assertEqual(0, len(blob.readers)) + return read_handle.read() + + self.loop.call_later(1, blob.close) + with self.assertRaises(ValueError) as err: + read_task = self.loop.create_task(read_blob_buffer()) + await read_task + self.assertEqual(err.exception, ValueError("I/O operation on closed file")) + + async def test_close_reader(self): + tmp_dir = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmp_dir)) + await self._test_close_reader(BlobBuffer) + await self._test_close_reader(BlobFile, tmp_dir)