From 5ee4b2173a7cd76b0f6b2497a156f85ac63dad5c Mon Sep 17 00:00:00 2001
From: Victor Shyba <victor1984@riseup.net>
Date: Thu, 31 Jan 2019 01:13:01 -0300
Subject: [PATCH] fix probes bugs, partial DHT refactor

---
 lbrynet/dht/protocol/iterative_find.py | 196 +++++++++++--------------
 1 file changed, 82 insertions(+), 114 deletions(-)

diff --git a/lbrynet/dht/protocol/iterative_find.py b/lbrynet/dht/protocol/iterative_find.py
index 8b899013b..f0b905295 100644
--- a/lbrynet/dht/protocol/iterative_find.py
+++ b/lbrynet/dht/protocol/iterative_find.py
@@ -1,7 +1,9 @@
 import asyncio
+from binascii import hexlify
+from itertools import chain
+
 import typing
 import logging
-from lbrynet.utils import drain_tasks
 from lbrynet.dht import constants
 from lbrynet.dht.error import RemoteException
 from lbrynet.dht.protocol.distance import Distance
@@ -90,8 +92,8 @@ class IterativeFinder:
         self.exclude = exclude or []
 
         self.shortlist: typing.List['KademliaPeer'] = get_shortlist(routing_table, key, shortlist)
-        self.active: typing.List['KademliaPeer'] = []
-        self.contacted: typing.List[typing.Tuple[str, int]] = []
+        self.active: typing.Set['KademliaPeer'] = set()
+        self.contacted: typing.Set[typing.Tuple[str, int]] = set()
         self.distance = Distance(key)
 
         self.closest_peer: typing.Optional['KademliaPeer'] = None if not self.shortlist else self.shortlist[0]
@@ -99,14 +101,12 @@ class IterativeFinder:
 
         self.iteration_queue = asyncio.Queue(loop=self.loop)
 
-        self.running_probes: typing.List[asyncio.Task] = []
-        self.lock = asyncio.Lock(loop=self.loop)
+        self.running_probes: typing.Set[asyncio.Task] = set()
         self.iteration_count = 0
         self.bottom_out_count = 0
         self.running = False
         self.tasks: typing.List[asyncio.Task] = []
         self.delayed_calls: typing.List[asyncio.Handle] = []
-        self.finished = asyncio.Event(loop=self.loop)
 
     async def send_probe(self, peer: 'KademliaPeer') -> FindResponse:
         """
@@ -114,9 +114,16 @@ class IterativeFinder:
         """
         raise NotImplementedError()
 
+    def search_exhausted(self):
+        """
+        This method ends the iterator because there are no more peers to contact.
+        Override to provide final results before the iterator ends.
+        """
+        self.iteration_queue.put_nowait(None)
+
     def check_result_ready(self, response: FindResponse):
         """
-        Called with a lock after adding peers from an rpc result to the shortlist.
+        Called after adding peers from an rpc result to the shortlist.
         This method is responsible for putting a result for the generator into the Queue
         """
         raise NotImplementedError()
@@ -129,9 +136,7 @@ class IterativeFinder:
         return []
 
     def _is_closer(self, peer: 'KademliaPeer') -> bool:
-        if not self.closest_peer:
-            return True
-        return self.distance.is_closer(peer.node_id, self.closest_peer.node_id)
+        return not self.closest_peer or self.distance.is_closer(peer.node_id, self.closest_peer.node_id)
 
     def _update_closest(self):
         self.shortlist.sort(key=lambda peer: self.distance(peer.node_id), reverse=True)
@@ -141,21 +146,18 @@ class IterativeFinder:
                 self.closest_peer = self.shortlist[-1]
 
     async def _handle_probe_result(self, peer: 'KademliaPeer', response: FindResponse):
-        async with self.lock:
-            if peer not in self.shortlist:
-                self.shortlist.append(peer)
-            if peer not in self.active:
-                self.active.append(peer)
-            for contact_triple in response.get_close_triples():
-                addr_tuple = (contact_triple[1], contact_triple[2])
-                if addr_tuple not in self.contacted:  # and not self.peer_manager.is_ignored(addr_tuple)
-                    found_peer = self.peer_manager.get_kademlia_peer(
-                        contact_triple[0], contact_triple[1], contact_triple[2]
-                    )
-                    if found_peer not in self.shortlist and self.peer_manager.peer_is_good(peer) is not False:
-                        self.shortlist.append(found_peer)
-            self._update_closest()
-            self.check_result_ready(response)
+        if peer not in self.shortlist:
+            self.shortlist.append(peer)
+        if peer not in self.active:
+            self.active.add(peer)
+        for contact_triple in response.get_close_triples():
+            node_id, address, udp_port = contact_triple
+            if (address, udp_port) not in self.contacted:  # and not self.peer_manager.is_ignored(addr_tuple)
+                found_peer = self.peer_manager.get_kademlia_peer(node_id, address, udp_port)
+                if found_peer not in self.shortlist and self.peer_manager.peer_is_good(peer) is not False:
+                    self.shortlist.append(found_peer)
+        self._update_closest()
+        self.check_result_ready(response)
 
     async def _send_probe(self, peer: 'KademliaPeer'):
         try:
@@ -163,13 +165,11 @@ class IterativeFinder:
         except asyncio.CancelledError:
             return
         except asyncio.TimeoutError:
-            if peer in self.active:
-                self.active.remove(peer)
+            self.active.discard(peer)
             return
         except ValueError as err:
             log.warning(str(err))
-            if peer in self.active:
-                self.active.remove(peer)
+            self.active.discard(peer)
             return
         except RemoteException:
             return
@@ -181,31 +181,35 @@ class IterativeFinder:
         """
 
         added = 0
-        async with self.lock:
-            self.shortlist.sort(key=lambda p: self.distance(p.node_id), reverse=True)
-            while self.running and len(self.shortlist) and added < constants.alpha:
-                peer = self.shortlist.pop()
-                origin_address = (peer.address, peer.udp_port)
-                if origin_address in self.exclude or self.peer_manager.peer_is_good(peer) is False:
-                    continue
-                if peer.node_id == self.protocol.node_id:
-                    continue
-                if (peer.address, peer.udp_port) == (self.protocol.external_ip, self.protocol.udp_port):
-                    continue
-                if (peer.address, peer.udp_port) not in self.contacted:
-                    self.contacted.append((peer.address, peer.udp_port))
+        self.shortlist.sort(key=lambda p: self.distance(p.node_id), reverse=True)
+        while self.running and len(self.shortlist) and added < constants.alpha:
+            peer = self.shortlist.pop()
+            origin_address = (peer.address, peer.udp_port)
+            if origin_address in self.exclude or self.peer_manager.peer_is_good(peer) is False:
+                continue
+            if peer.node_id == self.protocol.node_id:
+                continue
+            if (peer.address, peer.udp_port) == (self.protocol.external_ip, self.protocol.udp_port):
+                continue
+            if (peer.address, peer.udp_port) not in self.contacted:
+                self.contacted.add((peer.address, peer.udp_port))
 
-                    t = self.loop.create_task(self._send_probe(peer))
+                t = self.loop.create_task(self._send_probe(peer))
 
-                    def callback(_):
-                        if t and t in self.running_probes:
-                            self.running_probes.remove(t)
-                        if not self.running_probes and self.shortlist:
-                            self.tasks.append(self.loop.create_task(self._search_task(0.0)))
+                def callback(_):
+                    self.running_probes.difference_update({
+                        probe for probe in self.running_probes if probe.done() or probe == t
+                    })
+                    if not self.running_probes and self.shortlist:
+                        self.tasks.append(self.loop.create_task(self._search_task(0.0)))
 
-                    t.add_done_callback(callback)
-                    self.running_probes.append(t)
-                    added += 1
+                t.add_done_callback(callback)
+                self.running_probes.add(t)
+                added += 1
+        log.debug("running %d probes", len(self.running_probes))
+        if not added and not self.running_probes:
+            log.debug("search for %s exhausted", hexlify(self.key)[:8])
+            self.search_exhausted()
 
     async def _search_task(self, delay: typing.Optional[float] = constants.iterative_lookup_delay):
         try:
@@ -215,70 +219,41 @@ class IterativeFinder:
                 self.delayed_calls.append(self.loop.call_later(delay, self._search))
         except (asyncio.CancelledError, StopAsyncIteration):
             if self.running:
-                drain_tasks(self.running_probes)
-                self.running = False
+                self.loop.call_soon(self.aclose)
 
     def _search(self):
         self.tasks.append(self.loop.create_task(self._search_task()))
 
-    def search(self):
+    def __aiter__(self):
         if self.running:
             raise Exception("already running")
         self.running = True
         self._search()
-
-    async def next_queue_or_finished(self) -> typing.List['KademliaPeer']:
-        peers = self.loop.create_task(self.iteration_queue.get())
-        finished = self.loop.create_task(self.finished.wait())
-        err = None
-        try:
-            await asyncio.wait([peers, finished], loop=self.loop, return_when='FIRST_COMPLETED')
-            if peers.done():
-                return peers.result()
-            raise StopAsyncIteration()
-        except asyncio.CancelledError as error:
-            err = error
-        finally:
-            if not finished.done() and not finished.cancelled():
-                finished.cancel()
-            if not peers.done() and not peers.cancelled():
-                peers.cancel()
-            if err:
-                raise err
-
-    def __aiter__(self):
-        self.search()
         return self
 
     async def __anext__(self) -> typing.List['KademliaPeer']:
         try:
             if self.iteration_count == 0:
-                initial_results = self.get_initial_result()
-                if initial_results:
-                    self.iteration_queue.put_nowait(initial_results)
-            result = await self.next_queue_or_finished()
+                result = self.get_initial_result() or await self.iteration_queue.get()
+            else:
+                result = await self.iteration_queue.get()
+            if not result:
+                raise StopAsyncIteration
             self.iteration_count += 1
             return result
         except (asyncio.CancelledError, StopAsyncIteration):
-            await self.aclose()
+            self.loop.call_soon(self.aclose)
             raise
 
     def aclose(self):
         self.running = False
+        self.iteration_queue.put_nowait(None)
+        for task in chain(self.tasks, self.running_probes, self.delayed_calls):
+            task.cancel()
+        self.tasks.clear()
+        self.running_probes.clear()
+        self.delayed_calls.clear()
 
-        async def _aclose():
-            async with self.lock:
-                self.running = False
-                if not self.finished.is_set():
-                    self.finished.set()
-                drain_tasks(self.tasks)
-                drain_tasks(self.running_probes)
-                while self.delayed_calls:
-                    timer = self.delayed_calls.pop()
-                    if timer:
-                        timer.cancel()
-
-        return asyncio.ensure_future(_aclose(), loop=self.loop)
 
 
 class IterativeNodeFinder(IterativeFinder):
@@ -295,24 +270,26 @@ class IterativeNodeFinder(IterativeFinder):
         response = await self.protocol.get_rpc_peer(peer).find_node(self.key)
         return FindNodeResponse(self.key, response)
 
-    def put_result(self, from_list: typing.List['KademliaPeer']):
-        not_yet_yielded = [peer for peer in from_list if peer not in self.yielded_peers]
+    def search_exhausted(self):
+        self.put_result(self.active, finish=True)
+
+    def put_result(self, from_iter: typing.Iterable['KademliaPeer'], finish=False):
+        not_yet_yielded = [peer for peer in from_iter if peer not in self.yielded_peers]
         not_yet_yielded.sort(key=lambda peer: self.distance(peer.node_id))
         to_yield = not_yet_yielded[:min(constants.k, len(not_yet_yielded))]
         if to_yield:
             for peer in to_yield:
                 self.yielded_peers.add(peer)
             self.iteration_queue.put_nowait(to_yield)
+        if finish:
+            self.iteration_queue.put_nowait(None)
 
     def check_result_ready(self, response: FindNodeResponse):
         found = response.found and self.key != self.protocol.node_id
 
         if found:
             log.info("found")
-            self.put_result(self.shortlist)
-            if not self.finished.is_set():
-                self.finished.set()
-            return
+            return self.put_result(self.shortlist, finish=True)
         if self.prev_closest_peer and self.closest_peer and not self._is_closer(self.prev_closest_peer):
             # log.info("improving, %i %i %i %i %i", len(self.shortlist), len(self.active), len(self.contacted),
             #          self.bottom_out_count, self.iteration_count)
@@ -323,16 +300,10 @@ class IterativeNodeFinder(IterativeFinder):
                      self.bottom_out_count)
         if self.bottom_out_count >= self.bottom_out_limit or self.iteration_count >= self.bottom_out_limit:
             log.info("limit hit")
-            self.put_result(self.active)
-            if not self.finished.is_set():
-                self.finished.set()
-            return
-        if self.max_results and len(self.active) - len(self.yielded_peers) >= self.max_results:
+            self.put_result(self.active, True)
+        elif self.max_results and len(self.active) - len(self.yielded_peers) >= self.max_results:
             log.info("max results")
-            self.put_result(self.active)
-            if not self.finished.is_set():
-                self.finished.set()
-            return
+            self.put_result(self.active, True)
 
 
 class IterativeValueFinder(IterativeFinder):
@@ -366,14 +337,11 @@ class IterativeValueFinder(IterativeFinder):
                 #     log.info("enough blob peers found")
                 #     if not self.finished.is_set():
                 #         self.finished.set()
-            return
-        if self.prev_closest_peer and self.closest_peer:
+        elif self.prev_closest_peer and self.closest_peer:
             self.bottom_out_count += 1
             if self.bottom_out_count >= self.bottom_out_limit:
                 log.info("blob peer search bottomed out")
-                if not self.finished.is_set():
-                    self.finished.set()
-                return
+                self.iteration_queue.put_nowait(None)
 
     def get_initial_result(self) -> typing.List['KademliaPeer']:
         if self.protocol.data_store.has_peers_for_blob(self.key):