diff --git a/lbrynet/dht/node.py b/lbrynet/dht/node.py index b10884644..aa52a9067 100644 --- a/lbrynet/dht/node.py +++ b/lbrynet/dht/node.py @@ -74,24 +74,25 @@ class Node: await fut async def announce_blob(self, blob_hash: str) -> typing.List[bytes]: - announced_to_node_ids = [] - while not announced_to_node_ids: - hash_value = binascii.unhexlify(blob_hash.encode()) - assert len(hash_value) == constants.hash_length - peers = await self.peer_search(hash_value) + hash_value = binascii.unhexlify(blob_hash.encode()) + assert len(hash_value) == constants.hash_length + peers = await self.peer_search(hash_value) - if not self.protocol.external_ip: - raise Exception("Cannot determine external IP") - log.debug("Store to %i peers", len(peers)) - for peer in peers: - log.debug("store to %s %s %s", peer.address, peer.udp_port, peer.tcp_port) - stored_to_tup = await asyncio.gather( - *(self.protocol.store_to_peer(hash_value, peer) for peer in peers), loop=self.loop - ) - announced_to_node_ids.extend([node_id for node_id, contacted in stored_to_tup if contacted]) + if not self.protocol.external_ip: + raise Exception("Cannot determine external IP") + log.debug("Store to %i peers", len(peers)) + for peer in peers: + log.debug("store to %s %s %s", peer.address, peer.udp_port, peer.tcp_port) + stored_to_tup = await asyncio.gather( + *(self.protocol.store_to_peer(hash_value, peer) for peer in peers), loop=self.loop + ) + stored_to = [node_id for node_id, contacted in stored_to_tup if contacted] + if stored_to: log.info("Stored %s to %i of %i attempted peers", binascii.hexlify(hash_value).decode()[:8], - len(announced_to_node_ids), len(peers)) - return announced_to_node_ids + len(stored_to), len(peers)) + else: + log.warning("Failed announcing %s, stored to 0 peers") + return stored_to def stop(self) -> None: if self.joined.is_set(): diff --git a/lbrynet/dht/protocol/iterative_find.py b/lbrynet/dht/protocol/iterative_find.py index f60b3af6b..ad92342cf 100644 --- a/lbrynet/dht/protocol/iterative_find.py +++ b/lbrynet/dht/protocol/iterative_find.py @@ -276,8 +276,7 @@ class IterativeNodeFinder(IterativeFinder): not_yet_yielded.sort(key=lambda peer: self.distance(peer.node_id)) to_yield = not_yet_yielded[:min(constants.k, len(not_yet_yielded))] if to_yield: - for peer in to_yield: - self.yielded_peers.add(peer) + self.yielded_peers.update(to_yield) self.iteration_queue.put_nowait(to_yield) if finish: self.iteration_queue.put_nowait(None) diff --git a/tests/integration/test_dht.py b/tests/integration/test_dht.py index 8092116cb..afa2f1341 100644 --- a/tests/integration/test_dht.py +++ b/tests/integration/test_dht.py @@ -1,4 +1,5 @@ import asyncio +from binascii import hexlify from lbrynet.dht import constants from lbrynet.dht.node import Node @@ -6,7 +7,7 @@ from lbrynet.dht.peer import PeerManager, KademliaPeer from torba.testcase import AsyncioTestCase -class CLIIntegrationTest(AsyncioTestCase): +class DHTIntegrationTest(AsyncioTestCase): async def asyncSetUp(self): import logging @@ -24,16 +25,13 @@ class CLIIntegrationTest(AsyncioTestCase): self.nodes.append(node) self.known_node_addresses.append(('127.0.0.1', node_port)) await node.start_listening('127.0.0.1') + self.addCleanup(node.stop) for node in self.nodes: node.protocol.rpc_timeout = .2 node.protocol.ping_queue._default_delay = .5 node.start('127.0.0.1', self.known_node_addresses[:1]) await asyncio.gather(*[node.joined.wait() for node in self.nodes]) - async def asyncTearDown(self): - for node in self.nodes: - node.stop() - async def test_replace_bad_nodes(self): await self.setup_network(20) self.assertEquals(len(self.nodes), 20) @@ -55,3 +53,9 @@ class CLIIntegrationTest(AsyncioTestCase): self.assertIn(peer.node_id, good_nodes) + async def test_announce_no_peers(self): + await self.setup_network(1) + node = self.nodes[0] + blob_hash = hexlify(constants.generate_id(1337)).decode() + peers = await node.announce_blob(blob_hash) + self.assertEqual(len(peers), 0)