diff --git a/lbry/dht/protocol/iterative_find.py b/lbry/dht/protocol/iterative_find.py index ab89edddc..2b1e70e7a 100644 --- a/lbry/dht/protocol/iterative_find.py +++ b/lbry/dht/protocol/iterative_find.py @@ -1,6 +1,7 @@ import asyncio from itertools import chain from collections import defaultdict, OrderedDict +from collections.abc import AsyncGenerator import typing import logging from typing import TYPE_CHECKING @@ -71,7 +72,7 @@ def get_shortlist(routing_table: 'TreeRoutingTable', key: bytes, return shortlist or routing_table.find_close_peers(key) -class IterativeFinder: +class IterativeFinder(AsyncGenerator): def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes, max_results: typing.Optional[int] = constants.K, @@ -98,6 +99,8 @@ class IterativeFinder: self.iteration_count = 0 self.running = False self.tasks: typing.List[asyncio.Task] = [] + self.generator = None + for peer in get_shortlist(routing_table, key, shortlist): if peer.node_id: self._add_active(peer, force=True) @@ -163,12 +166,16 @@ class IterativeFinder: except asyncio.TimeoutError: self._reset_closest(peer) return + except asyncio.CancelledError: + log.debug("%s[%x] cancelled probe", + type(self).__name__, id(self)) + return except ValueError as err: log.warning(str(err)) self._reset_closest(peer) return except TransportNotConnected: - return self.aclose() + return self._aclose() except RemoteException: self._reset_closest(peer) return @@ -182,13 +189,17 @@ class IterativeFinder: added = 0 for index, peer in enumerate(self.active.keys()): if index == 0: - log.debug("closest to probe: %s", peer.node_id.hex()[:8]) + log.debug("%s[%x] closest to probe: %s", + type(self).__name__, id(self), + peer.node_id.hex()[:8]) if peer in self.contacted: continue if len(self.running_probes) >= constants.ALPHA: break if index > (constants.K + len(self.running_probes)): break + if self.iteration_count + self.iteration_queue.qsize() >= self.max_results: + break origin_address = (peer.address, peer.udp_port) if origin_address in self.exclude: continue @@ -198,9 +209,13 @@ class IterativeFinder: continue self._schedule_probe(peer) added += 1 - log.debug("running %d probes for key %s", len(self.running_probes), self.key.hex()[:8]) + log.debug("%s[%x] running %d probes for key %s", + type(self).__name__, id(self), + len(self.running_probes), self.key.hex()[:8]) if not added and not self.running_probes: - log.debug("search for %s exhausted", self.key.hex()[:8]) + log.debug("%s[%x] search for %s exhausted", + type(self).__name__, id(self), + self.key.hex()[:8]) self.search_exhausted() def _schedule_probe(self, peer: 'KademliaPeer'): @@ -217,38 +232,76 @@ class IterativeFinder: self.running_probes[peer] = t def _log_state(self): - log.debug("[%s] check result: %i active nodes %i contacted", - self.key.hex()[:8], len(self.active), len(self.contacted)) + log.debug("%s[%x] [%s] check result: %i active nodes %i contacted %i produced %i queued", + type(self).__name__, id(self), self.key.hex()[:8], + len(self.active), len(self.contacted), + self.iteration_count, self.iteration_queue.qsize()) + + async def _generator_func(self): + try: + while self.iteration_count < self.max_results: + if self.iteration_count == 0: + result = self.get_initial_result() or await self.iteration_queue.get() + else: + result = await self.iteration_queue.get() + if not result: + # no more results + await self._aclose(reason="no more results") + self.generator = None + return + self.iteration_count += 1 + yield result + # reached max_results limit + await self._aclose(reason="max_results reached") + self.generator = None + return + except asyncio.CancelledError: + await self._aclose(reason="cancelled") + self.generator = None + raise + except GeneratorExit: + await self._aclose(reason="generator exit") + self.generator = None + raise def __aiter__(self): if self.running: raise Exception("already running") self.running = True + self.generator = self._generator_func() self.loop.call_soon(self._search_round) - return self + return super().__aiter__() async def __anext__(self) -> typing.List['KademliaPeer']: - try: - if self.iteration_count == 0: - result = self.get_initial_result() or await self.iteration_queue.get() - else: - result = await self.iteration_queue.get() - if not result: - raise StopAsyncIteration - self.iteration_count += 1 - return result - except (asyncio.CancelledError, StopAsyncIteration): - self.loop.call_soon(self.aclose) - raise + return await super().__anext__() - def aclose(self): + async def asend(self, val): + return await self.generator.asend(val) + + async def athrow(self, typ, val=None, tb=None): + return await self.generator.athrow(typ, val, tb) + + async def _aclose(self, reason="?"): self.running = False - self.iteration_queue.put_nowait(None) - for task in chain(self.tasks, self.running_probes.values()): + running_tasks = list(chain(self.tasks, self.running_probes.values())) + for task in running_tasks: task.cancel() + if len(running_tasks): + await asyncio.wait(running_tasks, loop=self.loop) + log.debug("%s[%x] [%s] async close because %s: %i active nodes %i contacted %i produced %i queued", + type(self).__name__, id(self), self.key.hex()[:8], + reason, len(self.active), len(self.contacted), + self.iteration_count, self.iteration_queue.qsize()) self.tasks.clear() self.running_probes.clear() + async def aclose(self): + if self.generator: + await super().aclose() + self.generator = None + log.debug("%s[%x] [%s] async close completed", + type(self).__name__, id(self), self.key.hex()[:8]) + class IterativeNodeFinder(IterativeFinder): def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',