Tighten up IterativeFinder logic to respect max_records better, and wait after task cancel().

Also make IterativeFinder a proper AsyncGenerator. This gives it an offically recognized aclose() method and could help with clean finalization.
This commit is contained in:
Jonathan Moody 2022-04-11 18:17:16 -04:00
parent 5c708e1c6f
commit b036961954

View file

@ -1,6 +1,7 @@
import asyncio import asyncio
from itertools import chain from itertools import chain
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
from collections.abc import AsyncGenerator
import typing import typing
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -71,7 +72,7 @@ def get_shortlist(routing_table: 'TreeRoutingTable', key: bytes,
return shortlist or routing_table.find_close_peers(key) return shortlist or routing_table.find_close_peers(key)
class IterativeFinder: class IterativeFinder(AsyncGenerator):
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes, routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
max_results: typing.Optional[int] = constants.K, max_results: typing.Optional[int] = constants.K,
@ -98,6 +99,8 @@ class IterativeFinder:
self.iteration_count = 0 self.iteration_count = 0
self.running = False self.running = False
self.tasks: typing.List[asyncio.Task] = [] self.tasks: typing.List[asyncio.Task] = []
self.generator = None
for peer in get_shortlist(routing_table, key, shortlist): for peer in get_shortlist(routing_table, key, shortlist):
if peer.node_id: if peer.node_id:
self._add_active(peer, force=True) self._add_active(peer, force=True)
@ -163,12 +166,16 @@ class IterativeFinder:
except asyncio.TimeoutError: except asyncio.TimeoutError:
self._reset_closest(peer) self._reset_closest(peer)
return return
except asyncio.CancelledError:
log.debug("%s[%x] cancelled probe",
type(self).__name__, id(self))
return
except ValueError as err: except ValueError as err:
log.warning(str(err)) log.warning(str(err))
self._reset_closest(peer) self._reset_closest(peer)
return return
except TransportNotConnected: except TransportNotConnected:
return self.aclose() return self._aclose()
except RemoteException: except RemoteException:
self._reset_closest(peer) self._reset_closest(peer)
return return
@ -182,13 +189,17 @@ class IterativeFinder:
added = 0 added = 0
for index, peer in enumerate(self.active.keys()): for index, peer in enumerate(self.active.keys()):
if index == 0: if index == 0:
log.debug("closest to probe: %s", peer.node_id.hex()[:8]) log.debug("%s[%x] closest to probe: %s",
type(self).__name__, id(self),
peer.node_id.hex()[:8])
if peer in self.contacted: if peer in self.contacted:
continue continue
if len(self.running_probes) >= constants.ALPHA: if len(self.running_probes) >= constants.ALPHA:
break break
if index > (constants.K + len(self.running_probes)): if index > (constants.K + len(self.running_probes)):
break break
if self.iteration_count + self.iteration_queue.qsize() >= self.max_results:
break
origin_address = (peer.address, peer.udp_port) origin_address = (peer.address, peer.udp_port)
if origin_address in self.exclude: if origin_address in self.exclude:
continue continue
@ -198,9 +209,13 @@ class IterativeFinder:
continue continue
self._schedule_probe(peer) self._schedule_probe(peer)
added += 1 added += 1
log.debug("running %d probes for key %s", len(self.running_probes), self.key.hex()[:8]) log.debug("%s[%x] running %d probes for key %s",
type(self).__name__, id(self),
len(self.running_probes), self.key.hex()[:8])
if not added and not self.running_probes: if not added and not self.running_probes:
log.debug("search for %s exhausted", self.key.hex()[:8]) log.debug("%s[%x] search for %s exhausted",
type(self).__name__, id(self),
self.key.hex()[:8])
self.search_exhausted() self.search_exhausted()
def _schedule_probe(self, peer: 'KademliaPeer'): def _schedule_probe(self, peer: 'KademliaPeer'):
@ -217,38 +232,76 @@ class IterativeFinder:
self.running_probes[peer] = t self.running_probes[peer] = t
def _log_state(self): def _log_state(self):
log.debug("[%s] check result: %i active nodes %i contacted", log.debug("%s[%x] [%s] check result: %i active nodes %i contacted %i produced %i queued",
self.key.hex()[:8], len(self.active), len(self.contacted)) type(self).__name__, id(self), self.key.hex()[:8],
len(self.active), len(self.contacted),
self.iteration_count, self.iteration_queue.qsize())
async def _generator_func(self):
try:
while self.iteration_count < self.max_results:
if self.iteration_count == 0:
result = self.get_initial_result() or await self.iteration_queue.get()
else:
result = await self.iteration_queue.get()
if not result:
# no more results
await self._aclose(reason="no more results")
self.generator = None
return
self.iteration_count += 1
yield result
# reached max_results limit
await self._aclose(reason="max_results reached")
self.generator = None
return
except asyncio.CancelledError:
await self._aclose(reason="cancelled")
self.generator = None
raise
except GeneratorExit:
await self._aclose(reason="generator exit")
self.generator = None
raise
def __aiter__(self): def __aiter__(self):
if self.running: if self.running:
raise Exception("already running") raise Exception("already running")
self.running = True self.running = True
self.generator = self._generator_func()
self.loop.call_soon(self._search_round) self.loop.call_soon(self._search_round)
return self return super().__aiter__()
async def __anext__(self) -> typing.List['KademliaPeer']: async def __anext__(self) -> typing.List['KademliaPeer']:
try: return await super().__anext__()
if self.iteration_count == 0:
result = self.get_initial_result() or await self.iteration_queue.get()
else:
result = await self.iteration_queue.get()
if not result:
raise StopAsyncIteration
self.iteration_count += 1
return result
except (asyncio.CancelledError, StopAsyncIteration):
self.loop.call_soon(self.aclose)
raise
def aclose(self): async def asend(self, val):
return await self.generator.asend(val)
async def athrow(self, typ, val=None, tb=None):
return await self.generator.athrow(typ, val, tb)
async def _aclose(self, reason="?"):
self.running = False self.running = False
self.iteration_queue.put_nowait(None) running_tasks = list(chain(self.tasks, self.running_probes.values()))
for task in chain(self.tasks, self.running_probes.values()): for task in running_tasks:
task.cancel() task.cancel()
if len(running_tasks):
await asyncio.wait(running_tasks, loop=self.loop)
log.debug("%s[%x] [%s] async close because %s: %i active nodes %i contacted %i produced %i queued",
type(self).__name__, id(self), self.key.hex()[:8],
reason, len(self.active), len(self.contacted),
self.iteration_count, self.iteration_queue.qsize())
self.tasks.clear() self.tasks.clear()
self.running_probes.clear() self.running_probes.clear()
async def aclose(self):
if self.generator:
await super().aclose()
self.generator = None
log.debug("%s[%x] [%s] async close completed",
type(self).__name__, id(self), self.key.hex()[:8])
class IterativeNodeFinder(IterativeFinder): class IterativeNodeFinder(IterativeFinder):
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',