From 3234d70270c946a85293880951dbec41ddfc0027 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Tue, 23 Apr 2019 09:52:20 -0400 Subject: [PATCH] add peer address/port to hash blob writer --- lbrynet/blob/blob_file.py | 27 ++++++++++++++++----------- lbrynet/blob_exchange/client.py | 5 +++-- lbrynet/stream/reflector/server.py | 4 ++-- tests/unit/blob/test_blob_file.py | 2 +- 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/lbrynet/blob/blob_file.py b/lbrynet/blob/blob_file.py index fef7396fd..290ed071a 100644 --- a/lbrynet/blob/blob_file.py +++ b/lbrynet/blob/blob_file.py @@ -79,7 +79,7 @@ class AbstractBlob: self.length = length self.blob_completed_callback = blob_completed_callback self.blob_directory = blob_directory - self.writers: typing.List[HashBlobWriter] = [] + self.writers: typing.Dict[typing.Tuple[typing.Optional[str], typing.Optional[int]], HashBlobWriter] = {} self.verified: asyncio.Event = asyncio.Event(loop=self.loop) self.writing: asyncio.Event = asyncio.Event(loop=self.loop) self.readers: typing.List[typing.BinaryIO] = [] @@ -99,7 +99,7 @@ class AbstractBlob: @contextlib.contextmanager def reader_context(self) -> typing.ContextManager[typing.BinaryIO]: if not self.is_readable(): - raise OSError("not readable") + raise OSError(f"{str(type(self))} not readable, {len(self.readers)} readers {len(self.writers)} writers") with self._reader_context() as reader: try: self.readers.append(reader) @@ -142,7 +142,8 @@ class AbstractBlob: def close(self): while self.writers: - self.writers.pop().finished.cancel() + peer, writer = self.writers.popitem() + writer.finished.cancel() while self.readers: reader = self.readers.pop() if reader: @@ -198,10 +199,13 @@ class AbstractBlob: else: self.verified.set() - def get_blob_writer(self) -> HashBlobWriter: + def get_blob_writer(self, peer_address: typing.Optional[str] = None, + peer_port: typing.Optional[int] = None) -> HashBlobWriter: + if (peer_address, peer_port) in self.writers: + log.exception("attempted to download blob twice from %s:%s", peer_address, peer_port) fut = asyncio.Future(loop=self.loop) writer = HashBlobWriter(self.blob_hash, self.get_length, fut) - self.writers.append(writer) + self.writers[(peer_address, peer_port)] = writer def writer_finished_callback(finished: asyncio.Future): try: @@ -210,18 +214,18 @@ class AbstractBlob: raise err verified_bytes = finished.result() while self.writers: - other = self.writers.pop() + _, other = self.writers.popitem() if other is not writer: other.finished.cancel() self.save_verified_blob(verified_bytes) return except (InvalidBlobHashError, InvalidDataError) as error: - log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error)) + log.warning("writer error downloading %s: %s", self.blob_hash[:8], str(error)) except (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError): pass finally: - if writer in self.writers: - self.writers.remove(writer) + if (peer_address, peer_port) in self.writers: + self.writers.pop((peer_address, peer_port)) fut.add_done_callback(writer_finished_callback) return writer @@ -292,10 +296,11 @@ class BlobFile(AbstractBlob): def is_writeable(self) -> bool: return super().is_writeable() and not os.path.isfile(self.file_path) - def get_blob_writer(self) -> HashBlobWriter: + def get_blob_writer(self, peer_address: typing.Optional[str] = None, + peer_port: typing.Optional[str] = None) -> HashBlobWriter: if self.file_exists: raise OSError(f"File already exists '{self.file_path}'") - return super().get_blob_writer() + return super().get_blob_writer(peer_address, peer_port) @contextlib.contextmanager def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]: diff --git a/lbrynet/blob_exchange/client.py b/lbrynet/blob_exchange/client.py index 10c5d10df..11f0d1804 100644 --- a/lbrynet/blob_exchange/client.py +++ b/lbrynet/blob_exchange/client.py @@ -148,8 +148,9 @@ class BlobExchangeClientProtocol(asyncio.Protocol): if blob.get_is_verified() or not blob.is_writeable(): return 0, self.transport try: - blob.get_blob_writer() - self.blob, self.writer, self._blob_bytes_received = blob, blob.get_blob_writer(), 0 + + self.blob, self.writer, self._blob_bytes_received = blob, blob.get_blob_writer(self.peer_address, + self.peer_port), 0 self._response_fut = asyncio.Future(loop=self.loop) return await self._download_blob() except OSError as e: diff --git a/lbrynet/stream/reflector/server.py b/lbrynet/stream/reflector/server.py index 8d4c0289e..fd43f0c3a 100644 --- a/lbrynet/stream/reflector/server.py +++ b/lbrynet/stream/reflector/server.py @@ -63,7 +63,7 @@ class ReflectorServerProtocol(asyncio.Protocol): return self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size']) if not self.sd_blob.get_is_verified(): - self.writer = self.sd_blob.get_blob_writer() + self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername')) self.incoming.set() self.send_response({"send_sd_blob": True}) try: @@ -102,7 +102,7 @@ class ReflectorServerProtocol(asyncio.Protocol): return blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size']) if not blob.get_is_verified(): - self.writer = blob.get_blob_writer() + self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername')) self.incoming.set() self.send_response({"send_blob": True}) try: diff --git a/tests/unit/blob/test_blob_file.py b/tests/unit/blob/test_blob_file.py index 90b5d9ce9..8b519e27e 100644 --- a/tests/unit/blob/test_blob_file.py +++ b/tests/unit/blob/test_blob_file.py @@ -41,7 +41,7 @@ class TestBlob(AsyncioTestCase): async def _test_close_writers_on_finished(self, blob_class=AbstractBlob, blob_directory=None): blob = self._get_blob(blob_class, blob_directory=blob_directory) - writers = [blob.get_blob_writer() for _ in range(5)] + writers = [blob.get_blob_writer('1.2.3.4', port) for port in range(5)] self.assertEqual(5, len(blob.writers)) # test that writing too much causes the writer to fail with InvalidDataError and to be removed