Merge pull request #3599 from moodyjon/async-for-pr3504

Tighten up IterativeFinder async close behavior (DHT iterator continues after consumer breaks out of it)
This commit is contained in:
Victor Shyba 2022-05-23 11:12:40 -03:00 committed by GitHub
commit 1d95eb1549
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 45 deletions

View file

@ -5,7 +5,7 @@ import socket
from prometheus_client import Gauge from prometheus_client import Gauge
from lbry.utils import resolve_host from lbry.utils import aclosing, resolve_host
from lbry.dht import constants from lbry.dht import constants
from lbry.dht.peer import make_kademlia_peer from lbry.dht.peer import make_kademlia_peer
from lbry.dht.protocol.distance import Distance from lbry.dht.protocol.distance import Distance
@ -217,9 +217,10 @@ class Node:
shortlist: typing.Optional[typing.List['KademliaPeer']] = None shortlist: typing.Optional[typing.List['KademliaPeer']] = None
) -> typing.List['KademliaPeer']: ) -> typing.List['KademliaPeer']:
peers = [] peers = []
async for iteration_peers in self.get_iterative_node_finder( async with aclosing(self.get_iterative_node_finder(
node_id, shortlist=shortlist, max_results=max_results): node_id, shortlist=shortlist, max_results=max_results)) as node_finder:
peers.extend(iteration_peers) async for iteration_peers in node_finder:
peers.extend(iteration_peers)
distance = Distance(node_id) distance = Distance(node_id)
peers.sort(key=lambda peer: distance(peer.node_id)) peers.sort(key=lambda peer: distance(peer.node_id))
return peers[:count] return peers[:count]
@ -245,36 +246,36 @@ class Node:
# prioritize peers who reply to a dht ping first # prioritize peers who reply to a dht ping first
# this minimizes attempting to make tcp connections that won't work later to dead or unreachable peers # this minimizes attempting to make tcp connections that won't work later to dead or unreachable peers
async with aclosing(self.get_iterative_value_finder(bytes.fromhex(blob_hash))) as value_finder:
async for results in self.get_iterative_value_finder(bytes.fromhex(blob_hash)): async for results in value_finder:
to_put = [] to_put = []
for peer in results: for peer in results:
if peer.address == self.protocol.external_ip and self.protocol.peer_port == peer.tcp_port: if peer.address == self.protocol.external_ip and self.protocol.peer_port == peer.tcp_port:
continue continue
is_good = self.protocol.peer_manager.peer_is_good(peer) is_good = self.protocol.peer_manager.peer_is_good(peer)
if is_good: if is_good:
# the peer has replied recently over UDP, it can probably be reached on the TCP port # the peer has replied recently over UDP, it can probably be reached on the TCP port
to_put.append(peer) to_put.append(peer)
elif is_good is None: elif is_good is None:
if not peer.udp_port: if not peer.udp_port:
# TODO: use the same port for TCP and UDP # TODO: use the same port for TCP and UDP
# the udp port must be guessed # the udp port must be guessed
# default to the ports being the same. if the TCP port appears to be <=0.48.0 default, # default to the ports being the same. if the TCP port appears to be <=0.48.0 default,
# including on a network with several nodes, then assume the udp port is proportionately # including on a network with several nodes, then assume the udp port is proportionately
# based on a starting port of 4444 # based on a starting port of 4444
udp_port_to_try = peer.tcp_port udp_port_to_try = peer.tcp_port
if 3400 > peer.tcp_port > 3332: if 3400 > peer.tcp_port > 3332:
udp_port_to_try = (peer.tcp_port - 3333) + 4444 udp_port_to_try = (peer.tcp_port - 3333) + 4444
self.loop.create_task(put_into_result_queue_after_pong( self.loop.create_task(put_into_result_queue_after_pong(
make_kademlia_peer(peer.node_id, peer.address, udp_port_to_try, peer.tcp_port) make_kademlia_peer(peer.node_id, peer.address, udp_port_to_try, peer.tcp_port)
)) ))
else:
self.loop.create_task(put_into_result_queue_after_pong(peer))
else: else:
self.loop.create_task(put_into_result_queue_after_pong(peer)) # the peer is known to be bad/unreachable, skip trying to connect to it over TCP
else: log.debug("skip bad peer %s:%i for %s", peer.address, peer.tcp_port, blob_hash)
# the peer is known to be bad/unreachable, skip trying to connect to it over TCP if to_put:
log.debug("skip bad peer %s:%i for %s", peer.address, peer.tcp_port, blob_hash) result_queue.put_nowait(to_put)
if to_put:
result_queue.put_nowait(to_put)
def accumulate_peers(self, search_queue: asyncio.Queue, def accumulate_peers(self, search_queue: asyncio.Queue,
peer_queue: typing.Optional[asyncio.Queue] = None peer_queue: typing.Optional[asyncio.Queue] = None

View file

@ -1,6 +1,7 @@
import asyncio import asyncio
from itertools import chain from itertools import chain
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
from collections.abc import AsyncIterator
import typing import typing
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -71,7 +72,7 @@ def get_shortlist(routing_table: 'TreeRoutingTable', key: bytes,
return shortlist or routing_table.find_close_peers(key) return shortlist or routing_table.find_close_peers(key)
class IterativeFinder: class IterativeFinder(AsyncIterator):
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes, routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
max_results: typing.Optional[int] = constants.K, max_results: typing.Optional[int] = constants.K,
@ -151,7 +152,7 @@ class IterativeFinder:
log.warning("misbehaving peer %s:%i returned peer with reserved ip %s:%i", peer.address, log.warning("misbehaving peer %s:%i returned peer with reserved ip %s:%i", peer.address,
peer.udp_port, address, udp_port) peer.udp_port, address, udp_port)
self.check_result_ready(response) self.check_result_ready(response)
self._log_state() self._log_state(reason="check result")
def _reset_closest(self, peer): def _reset_closest(self, peer):
if peer in self.active: if peer in self.active:
@ -163,12 +164,17 @@ class IterativeFinder:
except asyncio.TimeoutError: except asyncio.TimeoutError:
self._reset_closest(peer) self._reset_closest(peer)
return return
except asyncio.CancelledError:
log.debug("%s[%x] cancelled probe",
type(self).__name__, id(self))
raise
except ValueError as err: except ValueError as err:
log.warning(str(err)) log.warning(str(err))
self._reset_closest(peer) self._reset_closest(peer)
return return
except TransportNotConnected: except TransportNotConnected:
return self.aclose() await self._aclose(reason="not connected")
return
except RemoteException: except RemoteException:
self._reset_closest(peer) self._reset_closest(peer)
return return
@ -182,7 +188,9 @@ class IterativeFinder:
added = 0 added = 0
for index, peer in enumerate(self.active.keys()): for index, peer in enumerate(self.active.keys()):
if index == 0: if index == 0:
log.debug("closest to probe: %s", peer.node_id.hex()[:8]) log.debug("%s[%x] closest to probe: %s",
type(self).__name__, id(self),
peer.node_id.hex()[:8])
if peer in self.contacted: if peer in self.contacted:
continue continue
if len(self.running_probes) >= constants.ALPHA: if len(self.running_probes) >= constants.ALPHA:
@ -198,9 +206,13 @@ class IterativeFinder:
continue continue
self._schedule_probe(peer) self._schedule_probe(peer)
added += 1 added += 1
log.debug("running %d probes for key %s", len(self.running_probes), self.key.hex()[:8]) log.debug("%s[%x] running %d probes for key %s",
type(self).__name__, id(self),
len(self.running_probes), self.key.hex()[:8])
if not added and not self.running_probes: if not added and not self.running_probes:
log.debug("search for %s exhausted", self.key.hex()[:8]) log.debug("%s[%x] search for %s exhausted",
type(self).__name__, id(self),
self.key.hex()[:8])
self.search_exhausted() self.search_exhausted()
def _schedule_probe(self, peer: 'KademliaPeer'): def _schedule_probe(self, peer: 'KademliaPeer'):
@ -216,9 +228,11 @@ class IterativeFinder:
t.add_done_callback(callback) t.add_done_callback(callback)
self.running_probes[peer] = t self.running_probes[peer] = t
def _log_state(self): def _log_state(self, reason="?"):
log.debug("[%s] check result: %i active nodes %i contacted", log.debug("%s[%x] [%s] %s: %i active nodes %i contacted %i produced %i queued",
self.key.hex()[:8], len(self.active), len(self.contacted)) type(self).__name__, id(self), self.key.hex()[:8],
reason, len(self.active), len(self.contacted),
self.iteration_count, self.iteration_queue.qsize())
def __aiter__(self): def __aiter__(self):
if self.running: if self.running:
@ -237,11 +251,18 @@ class IterativeFinder:
raise StopAsyncIteration raise StopAsyncIteration
self.iteration_count += 1 self.iteration_count += 1
return result return result
except (asyncio.CancelledError, StopAsyncIteration): except asyncio.CancelledError:
self.loop.call_soon(self.aclose) await self._aclose(reason="cancelled")
raise
except StopAsyncIteration:
await self._aclose(reason="no more results")
raise raise
def aclose(self): async def _aclose(self, reason="?"):
log.debug("%s[%x] [%s] shutdown because %s: %i active nodes %i contacted %i produced %i queued",
type(self).__name__, id(self), self.key.hex()[:8],
reason, len(self.active), len(self.contacted),
self.iteration_count, self.iteration_queue.qsize())
self.running = False self.running = False
self.iteration_queue.put_nowait(None) self.iteration_queue.put_nowait(None)
for task in chain(self.tasks, self.running_probes.values()): for task in chain(self.tasks, self.running_probes.values()):
@ -249,6 +270,11 @@ class IterativeFinder:
self.tasks.clear() self.tasks.clear()
self.running_probes.clear() self.running_probes.clear()
async def aclose(self):
if self.running:
await self._aclose(reason="aclose")
log.debug("%s[%x] [%s] async close completed",
type(self).__name__, id(self), self.key.hex()[:8])
class IterativeNodeFinder(IterativeFinder): class IterativeNodeFinder(IterativeFinder):
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',

View file

@ -130,6 +130,16 @@ def get_sd_hash(stream_info):
def json_dumps_pretty(obj, **kwargs): def json_dumps_pretty(obj, **kwargs):
return json.dumps(obj, sort_keys=True, indent=2, separators=(',', ': '), **kwargs) return json.dumps(obj, sort_keys=True, indent=2, separators=(',', ': '), **kwargs)
try:
# the standard contextlib.aclosing() is available in 3.10+
from contextlib import aclosing # pylint: disable=unused-import
except ImportError:
@contextlib.asynccontextmanager
async def aclosing(thing):
try:
yield thing
finally:
await thing.aclose()
def async_timed_cache(duration: int): def async_timed_cache(duration: int):
def wrapper(func): def wrapper(func):