diff --git a/lbry/dht/protocol/iterative_find.py b/lbry/dht/protocol/iterative_find.py index ab4a66e3d..9a557ef51 100644 --- a/lbry/dht/protocol/iterative_find.py +++ b/lbry/dht/protocol/iterative_find.py @@ -1,7 +1,7 @@ import asyncio from itertools import chain from collections import defaultdict, OrderedDict -from collections.abc import AsyncGenerator +from collections.abc import AsyncIterator import typing import logging from typing import TYPE_CHECKING @@ -72,7 +72,7 @@ def get_shortlist(routing_table: 'TreeRoutingTable', key: bytes, return shortlist or routing_table.find_close_peers(key) -class IterativeFinder(AsyncGenerator): +class IterativeFinder(AsyncIterator): def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes, max_results: typing.Optional[int] = constants.K, @@ -99,8 +99,6 @@ class IterativeFinder(AsyncGenerator): self.iteration_count = 0 self.running = False self.tasks: typing.List[asyncio.Task] = [] - self.generator = None - for peer in get_shortlist(routing_table, key, shortlist): if peer.node_id: self._add_active(peer, force=True) @@ -154,7 +152,7 @@ class IterativeFinder(AsyncGenerator): log.warning("misbehaving peer %s:%i returned peer with reserved ip %s:%i", peer.address, peer.udp_port, address, udp_port) self.check_result_ready(response) - self._log_state() + self._log_state(reason="check result") def _reset_closest(self, peer): if peer in self.active: @@ -169,7 +167,7 @@ class IterativeFinder(AsyncGenerator): except asyncio.CancelledError: log.debug("%s[%x] cancelled probe", type(self).__name__, id(self)) - return + raise except ValueError as err: log.warning(str(err)) self._reset_closest(peer) @@ -199,8 +197,6 @@ class IterativeFinder(AsyncGenerator): break if index > (constants.K + len(self.running_probes)): break - if self.iteration_count + self.iteration_queue.qsize() >= self.max_results: - break origin_address = (peer.address, peer.udp_port) if origin_address in self.exclude: continue @@ -232,76 +228,54 @@ class IterativeFinder(AsyncGenerator): t.add_done_callback(callback) self.running_probes[peer] = t - def _log_state(self): - log.debug("%s[%x] [%s] check result: %i active nodes %i contacted %i produced %i queued", + def _log_state(self, reason="?"): + log.debug("%s[%x] [%s] %s: %i active nodes %i contacted %i produced %i queued", type(self).__name__, id(self), self.key.hex()[:8], - len(self.active), len(self.contacted), + reason, len(self.active), len(self.contacted), self.iteration_count, self.iteration_queue.qsize()) - async def _generator_func(self): - try: - while self.iteration_count < self.max_results: - if self.iteration_count == 0: - result = self.get_initial_result() or await self.iteration_queue.get() - else: - result = await self.iteration_queue.get() - if not result: - # no more results - await self._aclose(reason="no more results") - self.generator = None - return - self.iteration_count += 1 - yield result - # reached max_results limit - await self._aclose(reason="max_results reached") - self.generator = None - return - except asyncio.CancelledError: - await self._aclose(reason="cancelled") - self.generator = None - raise - except GeneratorExit: - await self._aclose(reason="generator exit") - self.generator = None - raise - def __aiter__(self): if self.running: raise Exception("already running") self.running = True - self.generator = self._generator_func() self.loop.call_soon(self._search_round) - return super().__aiter__() + return self async def __anext__(self) -> typing.List['KademliaPeer']: - return await super().__anext__() - - async def asend(self, value): - return await self.generator.asend(value) - - async def athrow(self, typ, val=None, tb=None): - return await self.generator.athrow(typ, val, tb) + try: + if self.iteration_count == 0: + result = self.get_initial_result() or await self.iteration_queue.get() + else: + result = await self.iteration_queue.get() + if not result: + raise StopAsyncIteration + self.iteration_count += 1 + return result + except asyncio.CancelledError: + await self._aclose(reason="cancelled") + raise + except StopAsyncIteration: + await self._aclose(reason="no more results") + raise async def _aclose(self, reason="?"): - self.running = False - running_tasks = list(chain(self.tasks, self.running_probes.values())) - for task in running_tasks: - task.cancel() - log.debug("%s[%x] [%s] async close because %s: %i active nodes %i contacted %i produced %i queued", + log.debug("%s[%x] [%s] shutdown because %s: %i active nodes %i contacted %i produced %i queued", type(self).__name__, id(self), self.key.hex()[:8], reason, len(self.active), len(self.contacted), self.iteration_count, self.iteration_queue.qsize()) + self.running = False + self.iteration_queue.put_nowait(None) + for task in chain(self.tasks, self.running_probes.values()): + task.cancel() self.tasks.clear() self.running_probes.clear() async def aclose(self): - if self.generator: - await super().aclose() - self.generator = None + if self.running: + await self._aclose(reason="aclose") log.debug("%s[%x] [%s] async close completed", type(self).__name__, id(self), self.key.hex()[:8]) - class IterativeNodeFinder(IterativeFinder): def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,