From c06ec6cd6999fb13d422c10186a0c51c6de94b98 Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Fri, 8 Feb 2019 02:27:58 -0300 Subject: [PATCH] manage connections, reusing them when possible --- lbrynet/blob_exchange/client.py | 61 +++++++++++++++-------------- lbrynet/blob_exchange/downloader.py | 18 +++++++-- lbrynet/stream/downloader.py | 1 + 3 files changed, 47 insertions(+), 33 deletions(-) diff --git a/lbrynet/blob_exchange/client.py b/lbrynet/blob_exchange/client.py index 09e14f507..59ed7b56f 100644 --- a/lbrynet/blob_exchange/client.py +++ b/lbrynet/blob_exchange/client.py @@ -74,7 +74,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol): if self._response_fut and not self._response_fut.done(): self._response_fut.set_exception(err) - async def _download_blob(self) -> typing.Tuple[bool, bool]: + async def _download_blob(self) -> typing.Tuple[bool, typing.Optional[asyncio.Transport]]: """ :return: download success (bool), keep connection (bool) """ @@ -92,24 +92,24 @@ class BlobExchangeClientProtocol(asyncio.Protocol): log.warning("%s not in availability response from %s:%i", self.blob.blob_hash, self.peer_address, self.peer_port) log.warning(response.to_dict()) - return False, False + return False, self.close() elif availability_response.available_blobs and \ availability_response.available_blobs != [self.blob.blob_hash]: log.warning("blob availability response doesn't match our request from %s:%i", self.peer_address, self.peer_port) - return False, False + return False, self.close() if not price_response or price_response.blob_data_payment_rate != 'RATE_ACCEPTED': log.warning("data rate rejected by %s:%i", self.peer_address, self.peer_port) - return False, False + return False, self.close() if not blob_response or blob_response.error: log.warning("blob cant be downloaded from %s:%i", self.peer_address, self.peer_port) - return False, True + return False, self.transport if not blob_response.error and blob_response.blob_hash != self.blob.blob_hash: log.warning("incoming blob hash mismatch from %s:%i", self.peer_address, self.peer_port) - return False, False + return False, self.close() if self.blob.length is not None and self.blob.length != blob_response.length: log.warning("incoming blob unexpected length from %s:%i", self.peer_address, self.peer_port) - return False, False + return False, self.close() msg = f"downloading {self.blob.blob_hash[:8]} from {self.peer_address}:{self.peer_port}," \ f" timeout in {self.peer_timeout}" log.debug(msg) @@ -117,16 +117,14 @@ class BlobExchangeClientProtocol(asyncio.Protocol): await asyncio.wait_for(self.writer.finished, self.peer_timeout, loop=self.loop) log.info(msg) await self.blob.finished_writing.wait() - return True, True + return True, self.transport except asyncio.TimeoutError: - return False, False + return False, self.close() except (InvalidBlobHashError, InvalidDataError): log.warning("invalid blob from %s:%i", self.peer_address, self.peer_port) - return False, False - finally: - await self.close() + return False, self.close() - async def close(self): + def close(self): if self._response_fut and not self._response_fut.done(): self._response_fut.cancel() if self.writer and not self.writer.closed(): @@ -138,9 +136,9 @@ class BlobExchangeClientProtocol(asyncio.Protocol): self.transport.close() self.transport = None - async def download_blob(self, blob: 'BlobFile') -> typing.Tuple[bool, bool]: + async def download_blob(self, blob: 'BlobFile') -> typing.Tuple[bool, typing.Optional[asyncio.Transport]]: if blob.get_is_verified(): - return False, True + return False, self.transport try: self.blob, self.writer, self._blob_bytes_received = blob, blob.open_for_writing(), 0 self._response_fut = asyncio.Future(loop=self.loop) @@ -148,14 +146,14 @@ class BlobExchangeClientProtocol(asyncio.Protocol): except OSError: log.error("race happened downloading from %s:%i", self.peer_address, self.peer_port) # i'm not sure how to fix this race condition - jack - return False, True + return False, self.transport except asyncio.TimeoutError: if self._response_fut and not self._response_fut.done(): self._response_fut.cancel() - return False, False + self.close() + return False, None except asyncio.CancelledError: - if self._response_fut and not self._response_fut.done(): - self._response_fut.cancel() + self.close() raise def connection_made(self, transport: asyncio.Transport): @@ -166,24 +164,29 @@ class BlobExchangeClientProtocol(asyncio.Protocol): def connection_lost(self, reason): log.debug("connection lost to %s:%i (reason: %s, %s)", self.peer_address, self.peer_port, str(reason), str(type(reason))) - self.transport = None - self.loop.create_task(self.close()) + self.close() async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: str, tcp_port: int, - peer_connect_timeout: float, blob_download_timeout: float) -> typing.Tuple[bool, bool]: + peer_connect_timeout: float, blob_download_timeout: float, + connected_transport: asyncio.Transport = None)\ + -> typing.Tuple[bool, typing.Optional[asyncio.Transport]]: """ Returns [, ] """ - protocol = BlobExchangeClientProtocol(loop, blob_download_timeout) if blob.get_is_verified(): - return False, True + return False, connected_transport + protocol = BlobExchangeClientProtocol(loop, blob_download_timeout) + if connected_transport and not connected_transport.is_closing(): + connected_transport.set_protocol(protocol) + protocol.connection_made(connected_transport) + else: + connected_transport = None try: - await asyncio.wait_for(loop.create_connection(lambda: protocol, address, tcp_port), - peer_connect_timeout, loop=loop) + if not connected_transport: + await asyncio.wait_for(loop.create_connection(lambda: protocol, address, tcp_port), + peer_connect_timeout, loop=loop) return await protocol.download_blob(blob) except (asyncio.TimeoutError, ConnectionRefusedError, ConnectionAbortedError, OSError): - return False, False - finally: - await protocol.close() + return False, None diff --git a/lbrynet/blob_exchange/downloader.py b/lbrynet/blob_exchange/downloader.py index 86e50798d..7c62b1989 100644 --- a/lbrynet/blob_exchange/downloader.py +++ b/lbrynet/blob_exchange/downloader.py @@ -23,6 +23,7 @@ class BlobDownloader: self.active_connections: typing.Dict['KademliaPeer', asyncio.Task] = {} # active request_blob calls self.ignored: typing.Set['KademliaPeer'] = set() self.scores: typing.Dict['KademliaPeer', int] = {} + self.connections: typing.Dict['KademliaPeer', asyncio.Transport] = {} def should_race_continue(self): if len(self.active_connections) >= self.config.max_connections_per_download: @@ -38,15 +39,19 @@ class BlobDownloader: if blob.get_is_verified(): return self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones - success, keep_connection = await request_blob( + transport = self.connections.get(peer) + success, transport = await request_blob( self.loop, blob, peer.address, peer.tcp_port, self.config.peer_connect_timeout, - self.config.blob_download_timeout + self.config.blob_download_timeout, connected_transport=transport ) - if not keep_connection and peer not in self.ignored: + if not transport and peer not in self.ignored: self.ignored.add(peer) log.debug("drop peer %s:%i", peer.address, peer.tcp_port) - elif keep_connection: + if peer in self.connections: + del self.connections[peer] + elif transport: log.debug("keep peer %s:%i", peer.address, peer.tcp_port) + self.connections[peer] = transport self.scores[peer] = (self.scores.get(peer, 0) + 2) if success else 0 async def new_peer_or_finished(self, blob: 'BlobFile'): @@ -107,6 +112,10 @@ class BlobDownloader: log.exception(e) raise e + def close(self): + for transport in self.connections.values(): + transport.close() + async def download_blob(loop, config: 'Config', blob_manager: 'BlobFileManager', node: 'Node', blob_hash: str) -> 'BlobFile': @@ -119,3 +128,4 @@ async def download_blob(loop, config: 'Config', blob_manager: 'BlobFileManager', finally: if accumulate_task and not accumulate_task.done(): accumulate_task.cancel() + downloader.close() diff --git a/lbrynet/stream/downloader.py b/lbrynet/stream/downloader.py index e9142c2c4..c0408f0ce 100644 --- a/lbrynet/stream/downloader.py +++ b/lbrynet/stream/downloader.py @@ -51,6 +51,7 @@ class StreamDownloader(StreamAssembler): async def after_finished(self): log.info("downloaded stream %s -> %s", self.sd_hash, self.output_path) await self.blob_manager.storage.change_file_status(self.descriptor.stream_hash, 'finished') + self.blob_downloader.close() def stop(self): if self.accumulate_task: