Tighten up IterativeFinder logic to respect max_records better, and wait after task cancel().
Also make IterativeFinder a proper AsyncGenerator. This gives it an offically recognized aclose() method and could help with clean finalization.
This commit is contained in:
parent
5c708e1c6f
commit
b036961954
1 changed files with 76 additions and 23 deletions
|
@ -1,6 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from collections import defaultdict, OrderedDict
|
from collections import defaultdict, OrderedDict
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
import typing
|
import typing
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
@ -71,7 +72,7 @@ def get_shortlist(routing_table: 'TreeRoutingTable', key: bytes,
|
||||||
return shortlist or routing_table.find_close_peers(key)
|
return shortlist or routing_table.find_close_peers(key)
|
||||||
|
|
||||||
|
|
||||||
class IterativeFinder:
|
class IterativeFinder(AsyncGenerator):
|
||||||
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
|
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
|
||||||
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
|
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
|
||||||
max_results: typing.Optional[int] = constants.K,
|
max_results: typing.Optional[int] = constants.K,
|
||||||
|
@ -98,6 +99,8 @@ class IterativeFinder:
|
||||||
self.iteration_count = 0
|
self.iteration_count = 0
|
||||||
self.running = False
|
self.running = False
|
||||||
self.tasks: typing.List[asyncio.Task] = []
|
self.tasks: typing.List[asyncio.Task] = []
|
||||||
|
self.generator = None
|
||||||
|
|
||||||
for peer in get_shortlist(routing_table, key, shortlist):
|
for peer in get_shortlist(routing_table, key, shortlist):
|
||||||
if peer.node_id:
|
if peer.node_id:
|
||||||
self._add_active(peer, force=True)
|
self._add_active(peer, force=True)
|
||||||
|
@ -163,12 +166,16 @@ class IterativeFinder:
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
self._reset_closest(peer)
|
self._reset_closest(peer)
|
||||||
return
|
return
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
log.debug("%s[%x] cancelled probe",
|
||||||
|
type(self).__name__, id(self))
|
||||||
|
return
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
log.warning(str(err))
|
log.warning(str(err))
|
||||||
self._reset_closest(peer)
|
self._reset_closest(peer)
|
||||||
return
|
return
|
||||||
except TransportNotConnected:
|
except TransportNotConnected:
|
||||||
return self.aclose()
|
return self._aclose()
|
||||||
except RemoteException:
|
except RemoteException:
|
||||||
self._reset_closest(peer)
|
self._reset_closest(peer)
|
||||||
return
|
return
|
||||||
|
@ -182,13 +189,17 @@ class IterativeFinder:
|
||||||
added = 0
|
added = 0
|
||||||
for index, peer in enumerate(self.active.keys()):
|
for index, peer in enumerate(self.active.keys()):
|
||||||
if index == 0:
|
if index == 0:
|
||||||
log.debug("closest to probe: %s", peer.node_id.hex()[:8])
|
log.debug("%s[%x] closest to probe: %s",
|
||||||
|
type(self).__name__, id(self),
|
||||||
|
peer.node_id.hex()[:8])
|
||||||
if peer in self.contacted:
|
if peer in self.contacted:
|
||||||
continue
|
continue
|
||||||
if len(self.running_probes) >= constants.ALPHA:
|
if len(self.running_probes) >= constants.ALPHA:
|
||||||
break
|
break
|
||||||
if index > (constants.K + len(self.running_probes)):
|
if index > (constants.K + len(self.running_probes)):
|
||||||
break
|
break
|
||||||
|
if self.iteration_count + self.iteration_queue.qsize() >= self.max_results:
|
||||||
|
break
|
||||||
origin_address = (peer.address, peer.udp_port)
|
origin_address = (peer.address, peer.udp_port)
|
||||||
if origin_address in self.exclude:
|
if origin_address in self.exclude:
|
||||||
continue
|
continue
|
||||||
|
@ -198,9 +209,13 @@ class IterativeFinder:
|
||||||
continue
|
continue
|
||||||
self._schedule_probe(peer)
|
self._schedule_probe(peer)
|
||||||
added += 1
|
added += 1
|
||||||
log.debug("running %d probes for key %s", len(self.running_probes), self.key.hex()[:8])
|
log.debug("%s[%x] running %d probes for key %s",
|
||||||
|
type(self).__name__, id(self),
|
||||||
|
len(self.running_probes), self.key.hex()[:8])
|
||||||
if not added and not self.running_probes:
|
if not added and not self.running_probes:
|
||||||
log.debug("search for %s exhausted", self.key.hex()[:8])
|
log.debug("%s[%x] search for %s exhausted",
|
||||||
|
type(self).__name__, id(self),
|
||||||
|
self.key.hex()[:8])
|
||||||
self.search_exhausted()
|
self.search_exhausted()
|
||||||
|
|
||||||
def _schedule_probe(self, peer: 'KademliaPeer'):
|
def _schedule_probe(self, peer: 'KademliaPeer'):
|
||||||
|
@ -217,38 +232,76 @@ class IterativeFinder:
|
||||||
self.running_probes[peer] = t
|
self.running_probes[peer] = t
|
||||||
|
|
||||||
def _log_state(self):
|
def _log_state(self):
|
||||||
log.debug("[%s] check result: %i active nodes %i contacted",
|
log.debug("%s[%x] [%s] check result: %i active nodes %i contacted %i produced %i queued",
|
||||||
self.key.hex()[:8], len(self.active), len(self.contacted))
|
type(self).__name__, id(self), self.key.hex()[:8],
|
||||||
|
len(self.active), len(self.contacted),
|
||||||
|
self.iteration_count, self.iteration_queue.qsize())
|
||||||
|
|
||||||
def __aiter__(self):
|
async def _generator_func(self):
|
||||||
if self.running:
|
|
||||||
raise Exception("already running")
|
|
||||||
self.running = True
|
|
||||||
self.loop.call_soon(self._search_round)
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __anext__(self) -> typing.List['KademliaPeer']:
|
|
||||||
try:
|
try:
|
||||||
|
while self.iteration_count < self.max_results:
|
||||||
if self.iteration_count == 0:
|
if self.iteration_count == 0:
|
||||||
result = self.get_initial_result() or await self.iteration_queue.get()
|
result = self.get_initial_result() or await self.iteration_queue.get()
|
||||||
else:
|
else:
|
||||||
result = await self.iteration_queue.get()
|
result = await self.iteration_queue.get()
|
||||||
if not result:
|
if not result:
|
||||||
raise StopAsyncIteration
|
# no more results
|
||||||
|
await self._aclose(reason="no more results")
|
||||||
|
self.generator = None
|
||||||
|
return
|
||||||
self.iteration_count += 1
|
self.iteration_count += 1
|
||||||
return result
|
yield result
|
||||||
except (asyncio.CancelledError, StopAsyncIteration):
|
# reached max_results limit
|
||||||
self.loop.call_soon(self.aclose)
|
await self._aclose(reason="max_results reached")
|
||||||
|
self.generator = None
|
||||||
|
return
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
await self._aclose(reason="cancelled")
|
||||||
|
self.generator = None
|
||||||
|
raise
|
||||||
|
except GeneratorExit:
|
||||||
|
await self._aclose(reason="generator exit")
|
||||||
|
self.generator = None
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def aclose(self):
|
def __aiter__(self):
|
||||||
|
if self.running:
|
||||||
|
raise Exception("already running")
|
||||||
|
self.running = True
|
||||||
|
self.generator = self._generator_func()
|
||||||
|
self.loop.call_soon(self._search_round)
|
||||||
|
return super().__aiter__()
|
||||||
|
|
||||||
|
async def __anext__(self) -> typing.List['KademliaPeer']:
|
||||||
|
return await super().__anext__()
|
||||||
|
|
||||||
|
async def asend(self, val):
|
||||||
|
return await self.generator.asend(val)
|
||||||
|
|
||||||
|
async def athrow(self, typ, val=None, tb=None):
|
||||||
|
return await self.generator.athrow(typ, val, tb)
|
||||||
|
|
||||||
|
async def _aclose(self, reason="?"):
|
||||||
self.running = False
|
self.running = False
|
||||||
self.iteration_queue.put_nowait(None)
|
running_tasks = list(chain(self.tasks, self.running_probes.values()))
|
||||||
for task in chain(self.tasks, self.running_probes.values()):
|
for task in running_tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
if len(running_tasks):
|
||||||
|
await asyncio.wait(running_tasks, loop=self.loop)
|
||||||
|
log.debug("%s[%x] [%s] async close because %s: %i active nodes %i contacted %i produced %i queued",
|
||||||
|
type(self).__name__, id(self), self.key.hex()[:8],
|
||||||
|
reason, len(self.active), len(self.contacted),
|
||||||
|
self.iteration_count, self.iteration_queue.qsize())
|
||||||
self.tasks.clear()
|
self.tasks.clear()
|
||||||
self.running_probes.clear()
|
self.running_probes.clear()
|
||||||
|
|
||||||
|
async def aclose(self):
|
||||||
|
if self.generator:
|
||||||
|
await super().aclose()
|
||||||
|
self.generator = None
|
||||||
|
log.debug("%s[%x] [%s] async close completed",
|
||||||
|
type(self).__name__, id(self), self.key.hex()[:8])
|
||||||
|
|
||||||
|
|
||||||
class IterativeNodeFinder(IterativeFinder):
|
class IterativeNodeFinder(IterativeFinder):
|
||||||
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
|
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
|
||||||
|
|
Loading…
Reference in a new issue