diff --git a/lbry/dht/node.py b/lbry/dht/node.py
index 3f3c73776..2bb6edfcd 100644
--- a/lbry/dht/node.py
+++ b/lbry/dht/node.py
@@ -5,7 +5,7 @@ import socket
 
 from prometheus_client import Gauge
 
-from lbry.utils import resolve_host
+from lbry.utils import aclosing, resolve_host
 from lbry.dht import constants
 from lbry.dht.peer import make_kademlia_peer
 from lbry.dht.protocol.distance import Distance
@@ -217,13 +217,10 @@ class Node:
                           shortlist: typing.Optional[typing.List['KademliaPeer']] = None
                           ) -> typing.List['KademliaPeer']:
         peers = []
-        node_finder = self.get_iterative_node_finder(
-            node_id, shortlist=shortlist, max_results=max_results)
-        try:
+        async with aclosing(self.get_iterative_node_finder(
+                node_id, shortlist=shortlist, max_results=max_results)) as node_finder:
             async for iteration_peers in node_finder:
                 peers.extend(iteration_peers)
-        finally:
-            await node_finder.aclose()
         distance = Distance(node_id)
         peers.sort(key=lambda peer: distance(peer.node_id))
         return peers[:count]
@@ -249,8 +246,7 @@ class Node:
 
         # prioritize peers who reply to a dht ping first
         # this minimizes attempting to make tcp connections that won't work later to dead or unreachable peers
-        value_finder = self.get_iterative_value_finder(bytes.fromhex(blob_hash))
-        try:
+        async with aclosing(self.get_iterative_value_finder(bytes.fromhex(blob_hash))) as value_finder:
             async for results in value_finder:
                 to_put = []
                 for peer in results:
@@ -280,8 +276,6 @@ class Node:
                             log.debug("skip bad peer %s:%i for %s", peer.address, peer.tcp_port, blob_hash)
                 if to_put:
                     result_queue.put_nowait(to_put)
-        finally:
-            await value_finder.aclose()
 
     def accumulate_peers(self, search_queue: asyncio.Queue,
                          peer_queue: typing.Optional[asyncio.Queue] = None
diff --git a/lbry/utils.py b/lbry/utils.py
index 7a92ccc6a..cebba675e 100644
--- a/lbry/utils.py
+++ b/lbry/utils.py
@@ -130,6 +130,12 @@ def get_sd_hash(stream_info):
 def json_dumps_pretty(obj, **kwargs):
     return json.dumps(obj, sort_keys=True, indent=2, separators=(',', ': '), **kwargs)
 
+@contextlib.asynccontextmanager
+async def aclosing(thing):
+    try:
+        yield thing
+    finally:
+        await thing.aclose()
 
 def async_timed_cache(duration: int):
     def wrapper(func):
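
A minimal standalone sketch of the aclosing() pattern this patch introduces, assuming only what the diff shows: the helper wraps an async generator and awaits its aclose() in a finally block, which is the same guarantee the removed try/finally code in node.py provided. The ticker() generator and main() driver below are illustrative stand-ins (not part of the patch) for get_iterative_node_finder() / get_iterative_value_finder(); the helper itself mirrors contextlib.aclosing() from the Python 3.10+ standard library.

import asyncio
import contextlib


@contextlib.asynccontextmanager
async def aclosing(thing):
    # same shape as the helper added to lbry/utils.py in this patch
    try:
        yield thing
    finally:
        await thing.aclose()


async def ticker(limit: int):
    # illustrative stand-in for the DHT iterative finders
    for i in range(limit):
        yield i
        await asyncio.sleep(0)


async def main():
    results = []
    # the generator's aclose() runs even if the loop exits early or raises,
    # matching what the old try/finally blocks in node.py guaranteed
    async with aclosing(ticker(10)) as gen:
        async for value in gen:
            results.append(value)
            if value >= 2:
                break
    print(results)  # [0, 1, 2]


asyncio.run(main())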