diff --git a/lbry/tests/unit/blob/test_blob_file.py b/lbry/tests/unit/blob/test_blob_file.py index daadca847..28c6e061c 100644 --- a/lbry/tests/unit/blob/test_blob_file.py +++ b/lbry/tests/unit/blob/test_blob_file.py @@ -10,6 +10,22 @@ from lbry.blob.blob_manager import BlobManager from lbry.blob.blob_file import BlobFile, BlobBuffer, AbstractBlob +class MockTransport(asyncio.Transport): + def __init__(self): + self.closed = asyncio.Event() + self.paused = asyncio.Event() + self._extra = {} + + def close(self): + self.closed.set() + + def pause_reading(self) -> None: + self.paused.set() + + def resume_reading(self) -> None: + self.paused.clear() + + class TestBlob(AsyncioTestCase): blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed" blob_bytes = b'1' * ((2 * 2 ** 20) - 1) @@ -41,7 +57,7 @@ class TestBlob(AsyncioTestCase): async def _test_close_writers_on_finished(self, blob_class=AbstractBlob, blob_directory=None): blob = self._get_blob(blob_class, blob_directory=blob_directory) - writers = [blob.get_blob_writer('1.2.3.4', port) for port in range(5)] + writers = [blob.get_blob_writer(MockTransport()) for _ in range(5)] self.assertEqual(5, len(blob.writers)) # test that writing too much causes the writer to fail with InvalidDataError and to be removed @@ -137,14 +153,15 @@ class TestBlob(AsyncioTestCase): blob_directory = tempfile.mkdtemp() self.addCleanup(lambda: shutil.rmtree(blob_directory)) blob = self._get_blob(BlobBuffer, blob_directory=blob_directory) - writer = blob.get_blob_writer('1.1.1.1', 1337) + transport = MockTransport() + writer = blob.get_blob_writer(transport) self.assertEqual(1, len(blob.writers)) with self.assertRaises(OSError): - blob.get_blob_writer('1.1.1.1', 1337) + blob.get_blob_writer(transport) writer.close_handle() - self.assertTrue(blob.writers[('1.1.1.1', 1337)].closed()) - writer = blob.get_blob_writer('1.1.1.1', 1337) - self.assertEqual(blob.writers[('1.1.1.1', 1337)], writer) + self.assertTrue(blob.writers[(transport)].closed()) + writer = blob.get_blob_writer(transport) + self.assertEqual(blob.writers[transport], writer) writer.close_handle() await asyncio.sleep(0.000000001) # flush callbacks self.assertEqual(0, len(blob.writers))