From 49c51267dc0e26cda5d1916f49154dfd80dee936 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Mon, 12 Aug 2019 12:53:25 -0400 Subject: [PATCH] update get_blob_writer to accept an asyncio.Transport instead of an address/port --- lbry/lbry/blob/blob_file.py | 20 +++++++++---------- lbry/lbry/blob/writer.py | 5 +++-- lbry/lbry/blob_exchange/client.py | 3 ++- lbry/lbry/stream/reflector/server.py | 4 ++-- .../unit/blob_exchange/test_transfer_blob.py | 9 ++++++--- 5 files changed, 22 insertions(+), 19 deletions(-) diff --git a/lbry/lbry/blob/blob_file.py b/lbry/lbry/blob/blob_file.py index 45ac2b26f..fc8f705c1 100644 --- a/lbry/lbry/blob/blob_file.py +++ b/lbry/lbry/blob/blob_file.py @@ -79,7 +79,7 @@ class AbstractBlob: self.length = length self.blob_completed_callback = blob_completed_callback self.blob_directory = blob_directory - self.writers: typing.Dict[typing.Tuple[typing.Optional[str], typing.Optional[int]], HashBlobWriter] = {} + self.writers: typing.Dict[typing.Optional[asyncio.Transport], HashBlobWriter] = {} self.verified: asyncio.Event = asyncio.Event(loop=self.loop) self.writing: asyncio.Event = asyncio.Event(loop=self.loop) self.readers: typing.List[typing.BinaryIO] = [] @@ -199,17 +199,16 @@ class AbstractBlob: if self.blob_completed_callback: self.blob_completed_callback(self) - def get_blob_writer(self, peer_address: typing.Optional[str] = None, - peer_port: typing.Optional[int] = None) -> HashBlobWriter: - if (peer_address, peer_port) in self.writers and not self.writers[(peer_address, peer_port)].closed(): - raise OSError(f"attempted to download blob twice from {peer_address}:{peer_port}") + def get_blob_writer(self, transport: typing.Optional[asyncio.Transport] = None) -> HashBlobWriter: + if transport and transport in self.writers and not self.writers[transport].closed(): + raise OSError(f"attempted to download blob twice from {transport.get_extra_info('peername')}") fut = asyncio.Future(loop=self.loop) writer = HashBlobWriter(self.blob_hash, self.get_length, fut) - self.writers[(peer_address, peer_port)] = writer + self.writers[transport] = writer def remove_writer(_): - if (peer_address, peer_port) in self.writers: - del self.writers[(peer_address, peer_port)] + if transport in self.writers: + del self.writers[transport] fut.add_done_callback(remove_writer) @@ -299,11 +298,10 @@ class BlobFile(AbstractBlob): def is_writeable(self) -> bool: return super().is_writeable() and not os.path.isfile(self.file_path) - def get_blob_writer(self, peer_address: typing.Optional[str] = None, - peer_port: typing.Optional[str] = None) -> HashBlobWriter: + def get_blob_writer(self, transport: typing.Optional[asyncio.Transport] = None) -> HashBlobWriter: if self.file_exists: raise OSError(f"File already exists '{self.file_path}'") - return super().get_blob_writer(peer_address, peer_port) + return super().get_blob_writer(transport) @contextlib.contextmanager def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]: diff --git a/lbry/lbry/blob/writer.py b/lbry/lbry/blob/writer.py index 608abf83f..9323f7dc3 100644 --- a/lbry/lbry/blob/writer.py +++ b/lbry/lbry/blob/writer.py @@ -1,16 +1,17 @@ import typing import logging -import asyncio from io import BytesIO from lbry.error import InvalidBlobHashError, InvalidDataError from lbry.cryptoutils import get_lbry_hash_obj +if typing.TYPE_CHECKING: + import asyncio log = logging.getLogger(__name__) class HashBlobWriter: def __init__(self, expected_blob_hash: str, get_length: typing.Callable[[], int], - finished: asyncio.Future): + finished: 'asyncio.Future[bytes]'): self.expected_blob_hash = expected_blob_hash self.get_length = get_length self.buffer = BytesIO() diff --git a/lbry/lbry/blob_exchange/client.py b/lbry/lbry/blob_exchange/client.py index e73c35ace..c8321beeb 100644 --- a/lbry/lbry/blob_exchange/client.py +++ b/lbry/lbry/blob_exchange/client.py @@ -85,6 +85,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol): self._blob_bytes_received += len(data) try: self.writer.write(data) + except IOError as err: log.error("error downloading blob from %s:%i: %s", self.peer_address, self.peer_port, err) if self._response_fut and not self._response_fut.done(): @@ -180,7 +181,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol): return 0, self.transport try: self._blob_bytes_received = 0 - self.blob, self.writer = blob, blob.get_blob_writer(self.peer_address, self.peer_port) + self.blob, self.writer = blob, blob.get_blob_writer(self.transport) self._response_fut = asyncio.Future(loop=self.loop) return await self._download_blob() except OSError: diff --git a/lbry/lbry/stream/reflector/server.py b/lbry/lbry/stream/reflector/server.py index daadbbcd9..489a19003 100644 --- a/lbry/lbry/stream/reflector/server.py +++ b/lbry/lbry/stream/reflector/server.py @@ -72,7 +72,7 @@ class ReflectorServerProtocol(asyncio.Protocol): return self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size']) if not self.sd_blob.get_is_verified(): - self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername')) + self.writer = self.sd_blob.get_blob_writer(self.transport) self.incoming.set() self.send_response({"send_sd_blob": True}) try: @@ -111,7 +111,7 @@ class ReflectorServerProtocol(asyncio.Protocol): return blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size']) if not blob.get_is_verified(): - self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername')) + self.writer = blob.get_blob_writer(self.transport) self.incoming.set() self.send_response({"send_blob": True}) try: diff --git a/lbry/tests/unit/blob_exchange/test_transfer_blob.py b/lbry/tests/unit/blob_exchange/test_transfer_blob.py index ae9764939..3058cc158 100644 --- a/lbry/tests/unit/blob_exchange/test_transfer_blob.py +++ b/lbry/tests/unit/blob_exchange/test_transfer_blob.py @@ -135,12 +135,15 @@ class TestBlobExchange(BlobExchangeTestBase): write_blob(blob_bytes) blob._write_blob = wrap_write_blob - writer1 = blob.get_blob_writer(peer_port=1) - writer2 = blob.get_blob_writer(peer_port=2) + t1 = asyncio.Transport() + t2 = asyncio.Transport() + + writer1 = blob.get_blob_writer(t1) + writer2 = blob.get_blob_writer(t2) reader1_ctx_before_write = blob.reader_context() with self.assertRaises(OSError): - blob.get_blob_writer(peer_port=2) + blob.get_blob_writer(t2) with self.assertRaises(OSError): with blob.reader_context(): pass