forked from LBRYCommunity/lbry-sdk
Simplify by eliminating AsyncGenerator base and generator function. Remove any new places enforcing max_results.
This commit is contained in:
parent
530f9c72ea
commit
e5e9873f79
1 changed files with 30 additions and 56 deletions
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
from itertools import chain
|
||||
from collections import defaultdict, OrderedDict
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncIterator
|
||||
import typing
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
@ -72,7 +72,7 @@ def get_shortlist(routing_table: 'TreeRoutingTable', key: bytes,
|
|||
return shortlist or routing_table.find_close_peers(key)
|
||||
|
||||
|
||||
class IterativeFinder(AsyncGenerator):
|
||||
class IterativeFinder(AsyncIterator):
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
|
||||
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
|
||||
max_results: typing.Optional[int] = constants.K,
|
||||
|
@ -99,8 +99,6 @@ class IterativeFinder(AsyncGenerator):
|
|||
self.iteration_count = 0
|
||||
self.running = False
|
||||
self.tasks: typing.List[asyncio.Task] = []
|
||||
self.generator = None
|
||||
|
||||
for peer in get_shortlist(routing_table, key, shortlist):
|
||||
if peer.node_id:
|
||||
self._add_active(peer, force=True)
|
||||
|
@ -154,7 +152,7 @@ class IterativeFinder(AsyncGenerator):
|
|||
log.warning("misbehaving peer %s:%i returned peer with reserved ip %s:%i", peer.address,
|
||||
peer.udp_port, address, udp_port)
|
||||
self.check_result_ready(response)
|
||||
self._log_state()
|
||||
self._log_state(reason="check result")
|
||||
|
||||
def _reset_closest(self, peer):
|
||||
if peer in self.active:
|
||||
|
@ -169,7 +167,7 @@ class IterativeFinder(AsyncGenerator):
|
|||
except asyncio.CancelledError:
|
||||
log.debug("%s[%x] cancelled probe",
|
||||
type(self).__name__, id(self))
|
||||
return
|
||||
raise
|
||||
except ValueError as err:
|
||||
log.warning(str(err))
|
||||
self._reset_closest(peer)
|
||||
|
@ -199,8 +197,6 @@ class IterativeFinder(AsyncGenerator):
|
|||
break
|
||||
if index > (constants.K + len(self.running_probes)):
|
||||
break
|
||||
if self.iteration_count + self.iteration_queue.qsize() >= self.max_results:
|
||||
break
|
||||
origin_address = (peer.address, peer.udp_port)
|
||||
if origin_address in self.exclude:
|
||||
continue
|
||||
|
@ -232,76 +228,54 @@ class IterativeFinder(AsyncGenerator):
|
|||
t.add_done_callback(callback)
|
||||
self.running_probes[peer] = t
|
||||
|
||||
def _log_state(self):
|
||||
log.debug("%s[%x] [%s] check result: %i active nodes %i contacted %i produced %i queued",
|
||||
def _log_state(self, reason="?"):
|
||||
log.debug("%s[%x] [%s] %s: %i active nodes %i contacted %i produced %i queued",
|
||||
type(self).__name__, id(self), self.key.hex()[:8],
|
||||
len(self.active), len(self.contacted),
|
||||
reason, len(self.active), len(self.contacted),
|
||||
self.iteration_count, self.iteration_queue.qsize())
|
||||
|
||||
async def _generator_func(self):
|
||||
try:
|
||||
while self.iteration_count < self.max_results:
|
||||
if self.iteration_count == 0:
|
||||
result = self.get_initial_result() or await self.iteration_queue.get()
|
||||
else:
|
||||
result = await self.iteration_queue.get()
|
||||
if not result:
|
||||
# no more results
|
||||
await self._aclose(reason="no more results")
|
||||
self.generator = None
|
||||
return
|
||||
self.iteration_count += 1
|
||||
yield result
|
||||
# reached max_results limit
|
||||
await self._aclose(reason="max_results reached")
|
||||
self.generator = None
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
await self._aclose(reason="cancelled")
|
||||
self.generator = None
|
||||
raise
|
||||
except GeneratorExit:
|
||||
await self._aclose(reason="generator exit")
|
||||
self.generator = None
|
||||
raise
|
||||
|
||||
def __aiter__(self):
|
||||
if self.running:
|
||||
raise Exception("already running")
|
||||
self.running = True
|
||||
self.generator = self._generator_func()
|
||||
self.loop.call_soon(self._search_round)
|
||||
return super().__aiter__()
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> typing.List['KademliaPeer']:
|
||||
return await super().__anext__()
|
||||
|
||||
async def asend(self, value):
|
||||
return await self.generator.asend(value)
|
||||
|
||||
async def athrow(self, typ, val=None, tb=None):
|
||||
return await self.generator.athrow(typ, val, tb)
|
||||
try:
|
||||
if self.iteration_count == 0:
|
||||
result = self.get_initial_result() or await self.iteration_queue.get()
|
||||
else:
|
||||
result = await self.iteration_queue.get()
|
||||
if not result:
|
||||
raise StopAsyncIteration
|
||||
self.iteration_count += 1
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
await self._aclose(reason="cancelled")
|
||||
raise
|
||||
except StopAsyncIteration:
|
||||
await self._aclose(reason="no more results")
|
||||
raise
|
||||
|
||||
async def _aclose(self, reason="?"):
|
||||
self.running = False
|
||||
running_tasks = list(chain(self.tasks, self.running_probes.values()))
|
||||
for task in running_tasks:
|
||||
task.cancel()
|
||||
log.debug("%s[%x] [%s] async close because %s: %i active nodes %i contacted %i produced %i queued",
|
||||
log.debug("%s[%x] [%s] shutdown because %s: %i active nodes %i contacted %i produced %i queued",
|
||||
type(self).__name__, id(self), self.key.hex()[:8],
|
||||
reason, len(self.active), len(self.contacted),
|
||||
self.iteration_count, self.iteration_queue.qsize())
|
||||
self.running = False
|
||||
self.iteration_queue.put_nowait(None)
|
||||
for task in chain(self.tasks, self.running_probes.values()):
|
||||
task.cancel()
|
||||
self.tasks.clear()
|
||||
self.running_probes.clear()
|
||||
|
||||
async def aclose(self):
|
||||
if self.generator:
|
||||
await super().aclose()
|
||||
self.generator = None
|
||||
if self.running:
|
||||
await self._aclose(reason="aclose")
|
||||
log.debug("%s[%x] [%s] async close completed",
|
||||
type(self).__name__, id(self), self.key.hex()[:8])
|
||||
|
||||
|
||||
class IterativeNodeFinder(IterativeFinder):
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
|
||||
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
|
||||
|
|
Loading…
Reference in a new issue