diff --git a/lbrynet/blob/blob_file.py b/lbrynet/blob/blob_file.py index 882902e9b..484e93122 100644 --- a/lbrynet/blob/blob_file.py +++ b/lbrynet/blob/blob_file.py @@ -143,7 +143,8 @@ class AbstractBlob: def close(self): while self.writers: peer, writer = self.writers.popitem() - writer.finished.cancel() + if writer and writer.finished and not writer.finished.done() and not self.loop.is_closed(): + writer.finished.cancel() while self.readers: reader = self.readers.pop() if reader: @@ -206,6 +207,12 @@ class AbstractBlob: writer = HashBlobWriter(self.blob_hash, self.get_length, fut) self.writers[(peer_address, peer_port)] = writer + def remove_writer(_): + if (peer_address, peer_port) in self.writers: + del self.writers[(peer_address, peer_port)] + + fut.add_done_callback(remove_writer) + def writer_finished_callback(finished: asyncio.Future): try: err = finished.exception() @@ -215,16 +222,13 @@ class AbstractBlob: while self.writers: _, other = self.writers.popitem() if other is not writer: - other.finished.cancel() + other.close_handle() self.save_verified_blob(verified_bytes) - return except (InvalidBlobHashError, InvalidDataError) as error: log.warning("writer error downloading %s: %s", self.blob_hash[:8], str(error)) except (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError): pass - finally: - if (peer_address, peer_port) in self.writers: - self.writers.pop((peer_address, peer_port)) + fut.add_done_callback(writer_finished_callback) return writer diff --git a/tests/unit/blob_exchange/test_transfer_blob.py b/tests/unit/blob_exchange/test_transfer_blob.py index 58d124951..2cc2c18c8 100644 --- a/tests/unit/blob_exchange/test_transfer_blob.py +++ b/tests/unit/blob_exchange/test_transfer_blob.py @@ -109,7 +109,7 @@ class TestBlobExchange(BlobExchangeTestBase): await self._add_blob_to_server(blob_hash, mock_blob_bytes) - second_client_blob = self.client_blob_manager.get_blob(blob_hash) + second_client_blob = second_client_blob_manager.get_blob(blob_hash) # download the blob await asyncio.gather( @@ -122,6 +122,62 @@ class TestBlobExchange(BlobExchangeTestBase): await second_client_blob.verified.wait() self.assertEqual(second_client_blob.get_is_verified(), True) + async def test_blob_writers_concurrency(self): + blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed" + mock_blob_bytes = b'1' * ((2 * 2 ** 20) - 1) + blob = self.server_blob_manager.get_blob(blob_hash) + write_blob = blob._write_blob + write_called_count = 0 + + def wrap_write_blob(blob_bytes): + nonlocal write_called_count + write_called_count += 1 + write_blob(blob_bytes) + blob._write_blob = wrap_write_blob + + writer1 = blob.get_blob_writer(peer_port=1) + writer2 = blob.get_blob_writer(peer_port=2) + reader1_ctx_before_write = blob.reader_context() + + with self.assertRaises(OSError): + blob.get_blob_writer(peer_port=2) + with self.assertRaises(OSError): + with blob.reader_context(): + pass + + blob.set_length(len(mock_blob_bytes)) + results = {} + + def check_finished_callback(writer, num): + def inner(writer_future: asyncio.Future): + results[num] = writer_future.result() + writer.finished.add_done_callback(inner) + + check_finished_callback(writer1, 1) + check_finished_callback(writer2, 2) + + def write_task(writer): + async def _inner(): + writer.write(mock_blob_bytes) + return self.loop.create_task(_inner()) + + await asyncio.gather(write_task(writer1), write_task(writer2), loop=self.loop) + + self.assertDictEqual({1: mock_blob_bytes, 2: mock_blob_bytes}, results) + self.assertEqual(1, write_called_count) + self.assertTrue(blob.get_is_verified()) + self.assertDictEqual({}, blob.writers) + + with reader1_ctx_before_write as f: + self.assertEqual(mock_blob_bytes, f.read()) + with blob.reader_context() as f: + self.assertEqual(mock_blob_bytes, f.read()) + with blob.reader_context() as f: + blob.close() + with self.assertRaises(ValueError): + f.read() + self.assertListEqual([], blob.readers) + async def test_host_different_blobs_to_multiple_peers_at_once(self): blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed" mock_blob_bytes = b'1' * ((2 * 2 ** 20) - 1)