diff --git a/lbrynet/dht/protocol/iterative_find.py b/lbrynet/dht/protocol/iterative_find.py index 8b899013b..f0b905295 100644 --- a/lbrynet/dht/protocol/iterative_find.py +++ b/lbrynet/dht/protocol/iterative_find.py @@ -1,7 +1,9 @@ import asyncio +from binascii import hexlify +from itertools import chain + import typing import logging -from lbrynet.utils import drain_tasks from lbrynet.dht import constants from lbrynet.dht.error import RemoteException from lbrynet.dht.protocol.distance import Distance @@ -90,8 +92,8 @@ class IterativeFinder: self.exclude = exclude or [] self.shortlist: typing.List['KademliaPeer'] = get_shortlist(routing_table, key, shortlist) - self.active: typing.List['KademliaPeer'] = [] - self.contacted: typing.List[typing.Tuple[str, int]] = [] + self.active: typing.Set['KademliaPeer'] = set() + self.contacted: typing.Set[typing.Tuple[str, int]] = set() self.distance = Distance(key) self.closest_peer: typing.Optional['KademliaPeer'] = None if not self.shortlist else self.shortlist[0] @@ -99,14 +101,12 @@ class IterativeFinder: self.iteration_queue = asyncio.Queue(loop=self.loop) - self.running_probes: typing.List[asyncio.Task] = [] - self.lock = asyncio.Lock(loop=self.loop) + self.running_probes: typing.Set[asyncio.Task] = set() self.iteration_count = 0 self.bottom_out_count = 0 self.running = False self.tasks: typing.List[asyncio.Task] = [] self.delayed_calls: typing.List[asyncio.Handle] = [] - self.finished = asyncio.Event(loop=self.loop) async def send_probe(self, peer: 'KademliaPeer') -> FindResponse: """ @@ -114,9 +114,16 @@ class IterativeFinder: """ raise NotImplementedError() + def search_exhausted(self): + """ + This method ends the iterator due no more peers to contact. + Override to provide last time results. + """ + self.iteration_queue.put_nowait(None) + def check_result_ready(self, response: FindResponse): """ - Called with a lock after adding peers from an rpc result to the shortlist. + Called after adding peers from an rpc result to the shortlist. This method is responsible for putting a result for the generator into the Queue """ raise NotImplementedError() @@ -129,9 +136,7 @@ class IterativeFinder: return [] def _is_closer(self, peer: 'KademliaPeer') -> bool: - if not self.closest_peer: - return True - return self.distance.is_closer(peer.node_id, self.closest_peer.node_id) + return not self.closest_peer or self.distance.is_closer(peer.node_id, self.closest_peer.node_id) def _update_closest(self): self.shortlist.sort(key=lambda peer: self.distance(peer.node_id), reverse=True) @@ -141,21 +146,18 @@ class IterativeFinder: self.closest_peer = self.shortlist[-1] async def _handle_probe_result(self, peer: 'KademliaPeer', response: FindResponse): - async with self.lock: - if peer not in self.shortlist: - self.shortlist.append(peer) - if peer not in self.active: - self.active.append(peer) - for contact_triple in response.get_close_triples(): - addr_tuple = (contact_triple[1], contact_triple[2]) - if addr_tuple not in self.contacted: # and not self.peer_manager.is_ignored(addr_tuple) - found_peer = self.peer_manager.get_kademlia_peer( - contact_triple[0], contact_triple[1], contact_triple[2] - ) - if found_peer not in self.shortlist and self.peer_manager.peer_is_good(peer) is not False: - self.shortlist.append(found_peer) - self._update_closest() - self.check_result_ready(response) + if peer not in self.shortlist: + self.shortlist.append(peer) + if peer not in self.active: + self.active.add(peer) + for contact_triple in response.get_close_triples(): + node_id, address, udp_port = contact_triple + if (address, udp_port) not in self.contacted: # and not self.peer_manager.is_ignored(addr_tuple) + found_peer = self.peer_manager.get_kademlia_peer(node_id, address, udp_port) + if found_peer not in self.shortlist and self.peer_manager.peer_is_good(peer) is not False: + self.shortlist.append(found_peer) + self._update_closest() + self.check_result_ready(response) async def _send_probe(self, peer: 'KademliaPeer'): try: @@ -163,13 +165,11 @@ class IterativeFinder: except asyncio.CancelledError: return except asyncio.TimeoutError: - if peer in self.active: - self.active.remove(peer) + self.active.discard(peer) return except ValueError as err: log.warning(str(err)) - if peer in self.active: - self.active.remove(peer) + self.active.discard(peer) return except RemoteException: return @@ -181,31 +181,35 @@ class IterativeFinder: """ added = 0 - async with self.lock: - self.shortlist.sort(key=lambda p: self.distance(p.node_id), reverse=True) - while self.running and len(self.shortlist) and added < constants.alpha: - peer = self.shortlist.pop() - origin_address = (peer.address, peer.udp_port) - if origin_address in self.exclude or self.peer_manager.peer_is_good(peer) is False: - continue - if peer.node_id == self.protocol.node_id: - continue - if (peer.address, peer.udp_port) == (self.protocol.external_ip, self.protocol.udp_port): - continue - if (peer.address, peer.udp_port) not in self.contacted: - self.contacted.append((peer.address, peer.udp_port)) + self.shortlist.sort(key=lambda p: self.distance(p.node_id), reverse=True) + while self.running and len(self.shortlist) and added < constants.alpha: + peer = self.shortlist.pop() + origin_address = (peer.address, peer.udp_port) + if origin_address in self.exclude or self.peer_manager.peer_is_good(peer) is False: + continue + if peer.node_id == self.protocol.node_id: + continue + if (peer.address, peer.udp_port) == (self.protocol.external_ip, self.protocol.udp_port): + continue + if (peer.address, peer.udp_port) not in self.contacted: + self.contacted.add((peer.address, peer.udp_port)) - t = self.loop.create_task(self._send_probe(peer)) + t = self.loop.create_task(self._send_probe(peer)) - def callback(_): - if t and t in self.running_probes: - self.running_probes.remove(t) - if not self.running_probes and self.shortlist: - self.tasks.append(self.loop.create_task(self._search_task(0.0))) + def callback(_): + self.running_probes.difference_update({ + probe for probe in self.running_probes if probe.done() or probe == t + }) + if not self.running_probes and self.shortlist: + self.tasks.append(self.loop.create_task(self._search_task(0.0))) - t.add_done_callback(callback) - self.running_probes.append(t) - added += 1 + t.add_done_callback(callback) + self.running_probes.add(t) + added += 1 + log.debug("running %d probes", len(self.running_probes)) + if not added and not self.running_probes: + log.debug("search for %s exhausted", hexlify(self.key)[:8]) + self.search_exhausted() async def _search_task(self, delay: typing.Optional[float] = constants.iterative_lookup_delay): try: @@ -215,70 +219,41 @@ class IterativeFinder: self.delayed_calls.append(self.loop.call_later(delay, self._search)) except (asyncio.CancelledError, StopAsyncIteration): if self.running: - drain_tasks(self.running_probes) - self.running = False + self.loop.call_soon(self.aclose) def _search(self): self.tasks.append(self.loop.create_task(self._search_task())) - def search(self): + def __aiter__(self): if self.running: raise Exception("already running") self.running = True self._search() - - async def next_queue_or_finished(self) -> typing.List['KademliaPeer']: - peers = self.loop.create_task(self.iteration_queue.get()) - finished = self.loop.create_task(self.finished.wait()) - err = None - try: - await asyncio.wait([peers, finished], loop=self.loop, return_when='FIRST_COMPLETED') - if peers.done(): - return peers.result() - raise StopAsyncIteration() - except asyncio.CancelledError as error: - err = error - finally: - if not finished.done() and not finished.cancelled(): - finished.cancel() - if not peers.done() and not peers.cancelled(): - peers.cancel() - if err: - raise err - - def __aiter__(self): - self.search() return self async def __anext__(self) -> typing.List['KademliaPeer']: try: if self.iteration_count == 0: - initial_results = self.get_initial_result() - if initial_results: - self.iteration_queue.put_nowait(initial_results) - result = await self.next_queue_or_finished() + result = self.get_initial_result() or await self.iteration_queue.get() + else: + result = await self.iteration_queue.get() + if not result: + raise StopAsyncIteration self.iteration_count += 1 return result except (asyncio.CancelledError, StopAsyncIteration): - await self.aclose() + self.loop.call_soon(self.aclose) raise def aclose(self): self.running = False + self.iteration_queue.put_nowait(None) + for task in chain(self.tasks, self.running_probes, self.delayed_calls): + task.cancel() + self.tasks.clear() + self.running_probes.clear() + self.delayed_calls.clear() - async def _aclose(): - async with self.lock: - self.running = False - if not self.finished.is_set(): - self.finished.set() - drain_tasks(self.tasks) - drain_tasks(self.running_probes) - while self.delayed_calls: - timer = self.delayed_calls.pop() - if timer: - timer.cancel() - - return asyncio.ensure_future(_aclose(), loop=self.loop) class IterativeNodeFinder(IterativeFinder): @@ -295,24 +270,26 @@ class IterativeNodeFinder(IterativeFinder): response = await self.protocol.get_rpc_peer(peer).find_node(self.key) return FindNodeResponse(self.key, response) - def put_result(self, from_list: typing.List['KademliaPeer']): - not_yet_yielded = [peer for peer in from_list if peer not in self.yielded_peers] + def search_exhausted(self): + self.put_result(self.active, finish=True) + + def put_result(self, from_iter: typing.Iterable['KademliaPeer'], finish=False): + not_yet_yielded = [peer for peer in from_iter if peer not in self.yielded_peers] not_yet_yielded.sort(key=lambda peer: self.distance(peer.node_id)) to_yield = not_yet_yielded[:min(constants.k, len(not_yet_yielded))] if to_yield: for peer in to_yield: self.yielded_peers.add(peer) self.iteration_queue.put_nowait(to_yield) + if finish: + self.iteration_queue.put_nowait(None) def check_result_ready(self, response: FindNodeResponse): found = response.found and self.key != self.protocol.node_id if found: log.info("found") - self.put_result(self.shortlist) - if not self.finished.is_set(): - self.finished.set() - return + return self.put_result(self.shortlist, finish=True) if self.prev_closest_peer and self.closest_peer and not self._is_closer(self.prev_closest_peer): # log.info("improving, %i %i %i %i %i", len(self.shortlist), len(self.active), len(self.contacted), # self.bottom_out_count, self.iteration_count) @@ -323,16 +300,10 @@ class IterativeNodeFinder(IterativeFinder): self.bottom_out_count) if self.bottom_out_count >= self.bottom_out_limit or self.iteration_count >= self.bottom_out_limit: log.info("limit hit") - self.put_result(self.active) - if not self.finished.is_set(): - self.finished.set() - return - if self.max_results and len(self.active) - len(self.yielded_peers) >= self.max_results: + self.put_result(self.active, True) + elif self.max_results and len(self.active) - len(self.yielded_peers) >= self.max_results: log.info("max results") - self.put_result(self.active) - if not self.finished.is_set(): - self.finished.set() - return + self.put_result(self.active, True) class IterativeValueFinder(IterativeFinder): @@ -366,14 +337,11 @@ class IterativeValueFinder(IterativeFinder): # log.info("enough blob peers found") # if not self.finished.is_set(): # self.finished.set() - return - if self.prev_closest_peer and self.closest_peer: + elif self.prev_closest_peer and self.closest_peer: self.bottom_out_count += 1 if self.bottom_out_count >= self.bottom_out_limit: log.info("blob peer search bottomed out") - if not self.finished.is_set(): - self.finished.set() - return + self.iteration_queue.put_nowait(None) def get_initial_result(self) -> typing.List['KademliaPeer']: if self.protocol.data_store.has_peers_for_blob(self.key):