From bac7d99b8ab9cd85e7912acb01ed584f6cb4b545 Mon Sep 17 00:00:00 2001
From: Victor Shyba
Date: Sun, 12 May 2019 03:39:11 -0300
Subject: [PATCH] add ability to re-join network on disconnect + tests

---
 lbrynet/blob_exchange/client.py       |  4 +--
 lbrynet/dht/blob_announcer.py         |  2 +-
 lbrynet/dht/node.py                   | 42 +++++++++++++++++----------
 lbrynet/dht/peer.py                   |  4 +++
 lbrynet/dht/protocol/protocol.py      |  1 +
 tests/integration/test_dht.py         | 22 ++++++++++++--
 tests/unit/dht/test_blob_announcer.py |  3 +-
 7 files changed, 56 insertions(+), 22 deletions(-)

diff --git a/lbrynet/blob_exchange/client.py b/lbrynet/blob_exchange/client.py
index 1b4c46f01..4051d8dc3 100644
--- a/lbrynet/blob_exchange/client.py
+++ b/lbrynet/blob_exchange/client.py
@@ -31,8 +31,8 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
         self.closed = asyncio.Event(loop=self.loop)
 
     def data_received(self, data: bytes):
-        log.debug("%s:%d -- got %s bytes -- %s bytes on buffer -- %s blob bytes received",
-                  self.peer_address, self.peer_port, len(data), len(self.buf), self._blob_bytes_received)
+        #log.debug("%s:%d -- got %s bytes -- %s bytes on buffer -- %s blob bytes received",
+        #          self.peer_address, self.peer_port, len(data), len(self.buf), self._blob_bytes_received)
         if not self.transport or self.transport.is_closing():
             log.warning("transport closing, but got more bytes from %s:%i\n%s", self.peer_address, self.peer_port,
                         binascii.hexlify(data))
diff --git a/lbrynet/dht/blob_announcer.py b/lbrynet/dht/blob_announcer.py
index 9321b8536..1fc3bc069 100644
--- a/lbrynet/dht/blob_announcer.py
+++ b/lbrynet/dht/blob_announcer.py
@@ -33,7 +33,7 @@ class BlobAnnouncer:
         while batch_size:
             if not self.node.joined.is_set():
                 await self.node.joined.wait()
-            await asyncio.sleep(60)
+            await asyncio.sleep(60, loop=self.loop)
             if not self.node.protocol.routing_table.get_peers():
                 log.warning("No peers in DHT, announce round skipped")
                 continue
diff --git a/lbrynet/dht/node.py b/lbrynet/dht/node.py
index 1efe78037..62ec7f3d9 100644
--- a/lbrynet/dht/node.py
+++ b/lbrynet/dht/node.py
@@ -144,9 +144,21 @@ class Node:
             KademliaPeer(self.loop, address, udp_port=port)
             for (address, port) in known_node_addresses
         ]
-        while not len(self.protocol.routing_table.get_peers()):
-            peers.extend(await self.peer_search(self.protocol.node_id, shortlist=peers, count=32))
-            self.protocol.ping_queue.enqueue_maybe_ping(*peers, delay=0.0)
+        while True:
+            if not self.protocol.routing_table.get_peers():
+                if self.joined.is_set():
+                    self.joined.clear()
+                    self.protocol.peer_manager.reset()
+                self.protocol.ping_queue.enqueue_maybe_ping(*peers, delay=0.0)
+                peers.extend(await self.peer_search(self.protocol.node_id, shortlist=peers, count=32))
+                if self.protocol.routing_table.get_peers():
+                    self.joined.set()
+                    log.info(
+                        "Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
+                        self.protocol.routing_table.buckets_with_contacts())
+                else:
+                    continue
+            await asyncio.sleep(1, loop=self.loop)
 
         log.info("Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
                  self.protocol.routing_table.buckets_with_contacts())
@@ -186,27 +198,25 @@ class Node:
 
     async def _accumulate_search_junction(self, search_queue: asyncio.Queue,
                                           result_queue: asyncio.Queue):
-        ongoing = {}
+        tasks = []
 
         async def __start_producing_task():
            while True:
                blob_hash = await search_queue.get()
-                ongoing[blob_hash] = asyncio.create_task(self._value_producer(blob_hash, result_queue))
-        ongoing[''] = asyncio.create_task(__start_producing_task())
+                tasks.append(asyncio.create_task(self._value_producer(blob_hash, result_queue)))
+        tasks.append(asyncio.create_task(__start_producing_task()))
         try:
-            while True:
-                await asyncio.wait(ongoing.values(), return_when='FIRST_COMPLETED')
-                for key in list(ongoing.keys())[:]:
-                    if key and ongoing[key].done():
-                        ongoing[key] = asyncio.create_task(self._value_producer(key, result_queue))
+            await asyncio.wait(tasks)
         finally:
-            for task in ongoing.values():
+            for task in tasks:
                 task.cancel()
 
     async def _value_producer(self, blob_hash: str, result_queue: asyncio.Queue):
-        log.info("Searching %s", blob_hash[:8])
-        async for results in self.get_iterative_value_finder(binascii.unhexlify(blob_hash.encode())):
-            result_queue.put_nowait(results)
-        log.info("Search expired %s", blob_hash[:8])
+        for interval in range(1000):
+            log.info("Searching %s", blob_hash[:8])
+            async for results in self.get_iterative_value_finder(binascii.unhexlify(blob_hash.encode())):
+                result_queue.put_nowait(results)
+            log.info("Search expired %s", blob_hash[:8])
+            await asyncio.sleep(interval ** 2)
 
     def accumulate_peers(self, search_queue: asyncio.Queue,
                          peer_queue: typing.Optional[asyncio.Queue] = None) -> typing.Tuple[
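The core of this change is the new join loop in node.py: rather than looping only until the routing table is first populated, the node now polls forever, clears its joined event and the peer statistics whenever the table empties, and bootstraps again from the known addresses. The pattern can be sketched in isolation; a minimal model where node.ping and node.search stand in for the real ping-queue and iterative-search machinery (placeholders, not the lbrynet API):

    import asyncio

    async def maintain_membership(node, seed_peers, poll=1.0):
        # Sketch of the re-join loop; `node` is assumed to expose a `joined`
        # asyncio.Event, a routing table, a peer_manager, and ping/search
        # primitives (stand-ins for the lbrynet internals).
        while True:
            if not node.routing_table.get_peers():
                if node.joined.is_set():
                    node.joined.clear()        # we just lost the network
                    node.peer_manager.reset()  # drop stale per-peer stats
                node.ping(seed_peers)          # re-contact bootstrap nodes
                seed_peers.extend(await node.search(node.node_id))
                if node.routing_table.get_peers():
                    node.joined.set()          # back in the network
                else:
                    continue                   # retry at once, skip the sleep
            await asyncio.sleep(poll)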
diff --git a/lbrynet/dht/peer.py b/lbrynet/dht/peer.py
index 8253d1332..c4fb5a9ba 100644
--- a/lbrynet/dht/peer.py
+++ b/lbrynet/dht/peer.py
@@ -32,6 +32,10 @@ class PeerManager:
         self._node_id_reverse_mapping: typing.Dict[bytes, typing.Tuple[str, int]] = {}
         self._node_tokens: typing.Dict[bytes, (float, bytes)] = {}
 
+    def reset(self):
+        for statistic in (self._rpc_failures, self._last_replied, self._last_sent, self._last_requested):
+            statistic.clear()
+
     def report_failure(self, address: str, udp_port: int):
         now = self._loop.time()
         _, previous = self._rpc_failures.pop((address, udp_port), (None, None))
diff --git a/lbrynet/dht/protocol/protocol.py b/lbrynet/dht/protocol/protocol.py
index 346c7c8b4..03a8df7a8 100644
--- a/lbrynet/dht/protocol/protocol.py
+++ b/lbrynet/dht/protocol/protocol.py
@@ -275,6 +275,7 @@ class KademliaProtocol(DatagramProtocol):
         self._wakeup_routing_task = asyncio.Event(loop=self.loop)
         self.maintaing_routing_task: typing.Optional[asyncio.Task] = None
 
+    @functools.lru_cache(128)
     def get_rpc_peer(self, peer: 'KademliaPeer') -> RemoteKademliaRPC:
         return RemoteKademliaRPC(self.loop, self.peer_manager, self, peer)
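The one-line protocol.py change wraps get_rpc_peer in functools.lru_cache, so repeated calls for the same peer return one cached RemoteKademliaRPC wrapper instead of allocating a new object per request. A self-contained illustration of that caching behavior (the stub classes below are hypothetical); note that lru_cache on a method also includes self in the cache key, which keeps the instance alive while entries persist, and it requires the peer argument (here, KademliaPeer) to be hashable:

    import functools

    class RpcStub:
        def __init__(self, peer):
            self.peer = peer

    class ProtocolStub:
        @functools.lru_cache(128)
        def get_rpc_peer(self, peer):
            # Same (self, peer) pair -> the same cached RpcStub object.
            return RpcStub(peer)

    proto = ProtocolStub()
    assert proto.get_rpc_peer(('1.2.3.4', 4444)) is proto.get_rpc_peer(('1.2.3.4', 4444))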
diff --git a/tests/integration/test_dht.py b/tests/integration/test_dht.py
index afa2f1341..66f69d6a5 100644
--- a/tests/integration/test_dht.py
+++ b/tests/integration/test_dht.py
@@ -16,7 +16,7 @@ class DHTIntegrationTest(AsyncioTestCase):
         self.nodes = []
         self.known_node_addresses = []
 
-    async def setup_network(self, size: int, start_port=40000):
+    async def setup_network(self, size: int, start_port=40000, seed_nodes=1):
         for i in range(size):
             node_port = start_port + i
             node = Node(self.loop, PeerManager(self.loop), node_id=constants.generate_id(i),
@@ -29,7 +29,7 @@ class DHTIntegrationTest(AsyncioTestCase):
         for node in self.nodes:
             node.protocol.rpc_timeout = .2
             node.protocol.ping_queue._default_delay = .5
-            node.start('127.0.0.1', self.known_node_addresses[:1])
+            node.start('127.0.0.1', self.known_node_addresses[:seed_nodes])
         await asyncio.gather(*[node.joined.wait() for node in self.nodes])
 
     async def test_replace_bad_nodes(self):
@@ -52,6 +52,24 @@ class DHTIntegrationTest(AsyncioTestCase):
             for peer in node.protocol.routing_table.get_peers():
                 self.assertIn(peer.node_id, good_nodes)
 
+    async def test_re_join(self):
+        await self.setup_network(20, seed_nodes=10)
+        node = self.nodes[-1]
+        self.assertTrue(node.joined.is_set())
+        self.assertTrue(node.protocol.routing_table.get_peers())
+        for network_node in self.nodes[:-1]:
+            network_node.stop()
+        await node.refresh_node(True)
+        self.assertFalse(node.protocol.routing_table.get_peers())
+        for network_node in self.nodes[:-1]:
+            await network_node.start_listening('127.0.0.1')
+        self.assertFalse(node.protocol.routing_table.get_peers())
+        timeout = 20
+        while not node.protocol.routing_table.get_peers():
+            await asyncio.sleep(.1)
+            timeout -= 1
+            if not timeout:
+                self.fail("node didn't join back after 2 seconds")
+
     async def test_announce_no_peers(self):
         await self.setup_network(1)
diff --git a/tests/unit/dht/test_blob_announcer.py b/tests/unit/dht/test_blob_announcer.py
index 8e75e0e76..1f4c535a8 100644
--- a/tests/unit/dht/test_blob_announcer.py
+++ b/tests/unit/dht/test_blob_announcer.py
@@ -86,7 +86,8 @@ class TestBlobAnnouncer(AsyncioTestCase):
         to_announce = await self.storage.get_blobs_to_announce()
         self.assertEqual(2, len(to_announce))
         self.blob_announcer.start(batch_size=1)  # so it covers batching logic
-        await self.advance(61.0)
+        # takes 60 seconds to start, but we advance 120 to ensure it processed all batches
+        await self.advance(60.0 * 2)
         to_announce = await self.storage.get_blobs_to_announce()
         self.assertEqual(0, len(to_announce))
         self.blob_announcer.stop()
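test_re_join waits for recovery by polling the routing table every 0.1 s, failing after twenty attempts (2 seconds). That loop generalizes into a small reusable helper; a sketch only, where wait_until is a hypothetical name and not part of the test suite:

    import asyncio

    async def wait_until(predicate, timeout=2.0, interval=0.1):
        # Poll `predicate` until it returns a truthy value or `timeout`
        # seconds elapse, mirroring the loop at the end of test_re_join.
        loop = asyncio.get_event_loop()
        deadline = loop.time() + timeout
        while not predicate():
            if loop.time() >= deadline:
                raise AssertionError(f"condition not met within {timeout:.1f}s")
            await asyncio.sleep(interval)

    # Usage, equivalent to the tail of test_re_join:
    #     await wait_until(lambda: node.protocol.routing_table.get_peers())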