Simplify by eliminating AsyncGenerator base and generator function. Remove any new places enforcing max_results.

This commit is contained in:
Jonathan Moody 2022-05-03 16:41:32 -04:00
parent 530f9c72ea
commit e5e9873f79

View file

@ -1,7 +1,7 @@
import asyncio import asyncio
from itertools import chain from itertools import chain
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
from collections.abc import AsyncGenerator from collections.abc import AsyncIterator
import typing import typing
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -72,7 +72,7 @@ def get_shortlist(routing_table: 'TreeRoutingTable', key: bytes,
return shortlist or routing_table.find_close_peers(key) return shortlist or routing_table.find_close_peers(key)
class IterativeFinder(AsyncGenerator): class IterativeFinder(AsyncIterator):
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes, routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
max_results: typing.Optional[int] = constants.K, max_results: typing.Optional[int] = constants.K,
@ -99,8 +99,6 @@ class IterativeFinder(AsyncGenerator):
self.iteration_count = 0 self.iteration_count = 0
self.running = False self.running = False
self.tasks: typing.List[asyncio.Task] = [] self.tasks: typing.List[asyncio.Task] = []
self.generator = None
for peer in get_shortlist(routing_table, key, shortlist): for peer in get_shortlist(routing_table, key, shortlist):
if peer.node_id: if peer.node_id:
self._add_active(peer, force=True) self._add_active(peer, force=True)
@ -154,7 +152,7 @@ class IterativeFinder(AsyncGenerator):
log.warning("misbehaving peer %s:%i returned peer with reserved ip %s:%i", peer.address, log.warning("misbehaving peer %s:%i returned peer with reserved ip %s:%i", peer.address,
peer.udp_port, address, udp_port) peer.udp_port, address, udp_port)
self.check_result_ready(response) self.check_result_ready(response)
self._log_state() self._log_state(reason="check result")
def _reset_closest(self, peer): def _reset_closest(self, peer):
if peer in self.active: if peer in self.active:
@ -169,7 +167,7 @@ class IterativeFinder(AsyncGenerator):
except asyncio.CancelledError: except asyncio.CancelledError:
log.debug("%s[%x] cancelled probe", log.debug("%s[%x] cancelled probe",
type(self).__name__, id(self)) type(self).__name__, id(self))
return raise
except ValueError as err: except ValueError as err:
log.warning(str(err)) log.warning(str(err))
self._reset_closest(peer) self._reset_closest(peer)
@ -199,8 +197,6 @@ class IterativeFinder(AsyncGenerator):
break break
if index > (constants.K + len(self.running_probes)): if index > (constants.K + len(self.running_probes)):
break break
if self.iteration_count + self.iteration_queue.qsize() >= self.max_results:
break
origin_address = (peer.address, peer.udp_port) origin_address = (peer.address, peer.udp_port)
if origin_address in self.exclude: if origin_address in self.exclude:
continue continue
@ -232,76 +228,54 @@ class IterativeFinder(AsyncGenerator):
t.add_done_callback(callback) t.add_done_callback(callback)
self.running_probes[peer] = t self.running_probes[peer] = t
def _log_state(self): def _log_state(self, reason="?"):
log.debug("%s[%x] [%s] check result: %i active nodes %i contacted %i produced %i queued", log.debug("%s[%x] [%s] %s: %i active nodes %i contacted %i produced %i queued",
type(self).__name__, id(self), self.key.hex()[:8], type(self).__name__, id(self), self.key.hex()[:8],
len(self.active), len(self.contacted), reason, len(self.active), len(self.contacted),
self.iteration_count, self.iteration_queue.qsize()) self.iteration_count, self.iteration_queue.qsize())
async def _generator_func(self):
try:
while self.iteration_count < self.max_results:
if self.iteration_count == 0:
result = self.get_initial_result() or await self.iteration_queue.get()
else:
result = await self.iteration_queue.get()
if not result:
# no more results
await self._aclose(reason="no more results")
self.generator = None
return
self.iteration_count += 1
yield result
# reached max_results limit
await self._aclose(reason="max_results reached")
self.generator = None
return
except asyncio.CancelledError:
await self._aclose(reason="cancelled")
self.generator = None
raise
except GeneratorExit:
await self._aclose(reason="generator exit")
self.generator = None
raise
def __aiter__(self): def __aiter__(self):
if self.running: if self.running:
raise Exception("already running") raise Exception("already running")
self.running = True self.running = True
self.generator = self._generator_func()
self.loop.call_soon(self._search_round) self.loop.call_soon(self._search_round)
return super().__aiter__() return self
async def __anext__(self) -> typing.List['KademliaPeer']: async def __anext__(self) -> typing.List['KademliaPeer']:
return await super().__anext__() try:
if self.iteration_count == 0:
async def asend(self, value): result = self.get_initial_result() or await self.iteration_queue.get()
return await self.generator.asend(value) else:
result = await self.iteration_queue.get()
async def athrow(self, typ, val=None, tb=None): if not result:
return await self.generator.athrow(typ, val, tb) raise StopAsyncIteration
self.iteration_count += 1
return result
except asyncio.CancelledError:
await self._aclose(reason="cancelled")
raise
except StopAsyncIteration:
await self._aclose(reason="no more results")
raise
async def _aclose(self, reason="?"): async def _aclose(self, reason="?"):
self.running = False log.debug("%s[%x] [%s] shutdown because %s: %i active nodes %i contacted %i produced %i queued",
running_tasks = list(chain(self.tasks, self.running_probes.values()))
for task in running_tasks:
task.cancel()
log.debug("%s[%x] [%s] async close because %s: %i active nodes %i contacted %i produced %i queued",
type(self).__name__, id(self), self.key.hex()[:8], type(self).__name__, id(self), self.key.hex()[:8],
reason, len(self.active), len(self.contacted), reason, len(self.active), len(self.contacted),
self.iteration_count, self.iteration_queue.qsize()) self.iteration_count, self.iteration_queue.qsize())
self.running = False
self.iteration_queue.put_nowait(None)
for task in chain(self.tasks, self.running_probes.values()):
task.cancel()
self.tasks.clear() self.tasks.clear()
self.running_probes.clear() self.running_probes.clear()
async def aclose(self): async def aclose(self):
if self.generator: if self.running:
await super().aclose() await self._aclose(reason="aclose")
self.generator = None
log.debug("%s[%x] [%s] async close completed", log.debug("%s[%x] [%s] async close completed",
type(self).__name__, id(self), self.key.hex()[:8]) type(self).__name__, id(self), self.key.hex()[:8])
class IterativeNodeFinder(IterativeFinder): class IterativeNodeFinder(IterativeFinder):
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes, routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,