diff --git a/lbrynet/dht/protocol/routing_table.py b/lbrynet/dht/protocol/routing_table.py index c33f44e16..71577b7c7 100644 --- a/lbrynet/dht/protocol/routing_table.py +++ b/lbrynet/dht/protocol/routing_table.py @@ -29,6 +29,7 @@ class KBucket: self.range_max = range_max self.peers: typing.List['KademliaPeer'] = [] self._node_id = node_id + self._distance_to_self = Distance(node_id) def add_peer(self, peer: 'KademliaPeer') -> bool: """ Add contact to _contact list in the right order. This will move the @@ -130,7 +131,7 @@ class KBucket: if not. @rtype: bool """ - return self.range_min <= int.from_bytes(key, 'big') < self.range_max + return self.range_min <= self._distance_to_self(key) < self.range_max def __len__(self) -> int: return len(self.peers) @@ -170,7 +171,7 @@ class TreeRoutingTable: def should_split(self, bucket_index: int, to_add: bytes) -> bool: # https://stackoverflow.com/questions/32129978/highly-unbalanced-kademlia-routing-table/32187456#32187456 - if self.buckets[bucket_index].key_in_range(self._parent_node_id): + if not bucket_index: return True contacts = self.get_peers() distance = Distance(self._parent_node_id) @@ -236,11 +237,15 @@ class TreeRoutingTable: def random_id_in_bucket_range(self, bucket_index: int) -> bytes: random_id = int(random.randrange(self.buckets[bucket_index].range_min, self.buckets[bucket_index].range_max)) - return random_id.to_bytes(constants.hash_length, 'big') + return Distance( + self._parent_node_id + )(random_id.to_bytes(constants.hash_length, 'big')).to_bytes(constants.hash_length, 'big') def midpoint_id_in_bucket_range(self, bucket_index: int) -> bytes: half = int((self.buckets[bucket_index].range_max - self.buckets[bucket_index].range_min) // 2) - return int(self.buckets[bucket_index].range_min + half).to_bytes(constants.hash_length, 'big') + return Distance(self._parent_node_id)( + int(self.buckets[bucket_index].range_min + half).to_bytes(constants.hash_length, 'big') + ).to_bytes(constants.hash_length, 'big') def split_bucket(self, old_bucket_index: int) -> None: """ Splits the specified k-bucket into two new buckets which together diff --git a/tests/unit/dht/routing/test_routing_table.py b/tests/unit/dht/routing/test_routing_table.py index 183e04165..a27b316e2 100644 --- a/tests/unit/dht/routing/test_routing_table.py +++ b/tests/unit/dht/routing/test_routing_table.py @@ -6,6 +6,29 @@ from lbrynet.dht.node import Node from lbrynet.dht.peer import PeerManager +expected_ranges = [ + ( + 0, + 2462625387274654950767440006258975862817483704404090416746768337765357610718575663213391640930307227550414249394176 + ), + ( + 2462625387274654950767440006258975862817483704404090416746768337765357610718575663213391640930307227550414249394176, + 4925250774549309901534880012517951725634967408808180833493536675530715221437151326426783281860614455100828498788352 + ), + ( + 4925250774549309901534880012517951725634967408808180833493536675530715221437151326426783281860614455100828498788352, + 9850501549098619803069760025035903451269934817616361666987073351061430442874302652853566563721228910201656997576704 + ), + ( + 9850501549098619803069760025035903451269934817616361666987073351061430442874302652853566563721228910201656997576704, + 19701003098197239606139520050071806902539869635232723333974146702122860885748605305707133127442457820403313995153408 + ), + ( + 19701003098197239606139520050071806902539869635232723333974146702122860885748605305707133127442457820403313995153408, + 39402006196394479212279040100143613805079739270465446667948293404245721771497210611414266254884915640806627990306816 + ) +] + class TestRouting(AsyncioTestCase): async def test_fill_one_bucket(self): loop = asyncio.get_event_loop() @@ -43,6 +66,42 @@ class TestRouting(AsyncioTestCase): for node in nodes.values(): node.protocol.stop() + async def test_split_buckets(self): + loop = asyncio.get_event_loop() + peer_addresses = [ + (constants.generate_id(1), '1.2.3.1'), + ] + for i in range(2, 200): + peer_addresses.append((constants.generate_id(i), f'1.2.3.{i}')) + with dht_mocks.mock_network_loop(loop): + nodes = { + i: Node(loop, PeerManager(loop), node_id, 4444, 4444, 3333, address) + for i, (node_id, address) in enumerate(peer_addresses) + } + node_1 = nodes[0] + for i in range(1, len(peer_addresses)): + node = nodes[i] + peer = node_1.protocol.peer_manager.get_kademlia_peer( + node.protocol.node_id, node.protocol.external_ip, + udp_port=node.protocol.udp_port + ) + # set all of the peers to good (as to not attempt pinging stale ones during split) + node_1.protocol.peer_manager.report_last_replied(peer.address, peer.udp_port) + node_1.protocol.peer_manager.report_last_replied(peer.address, peer.udp_port) + await node_1.protocol.add_peer(peer) + # check that bucket 0 is always the one covering the local node id + self.assertEqual(True, node_1.protocol.routing_table.buckets[0].key_in_range(node_1.protocol.node_id)) + self.assertEqual(40, len(node_1.protocol.routing_table.get_peers())) + self.assertEqual(len(expected_ranges), len(node_1.protocol.routing_table.buckets)) + covered = 0 + for (expected_min, expected_max), bucket in zip(expected_ranges, node_1.protocol.routing_table.buckets): + self.assertEqual(expected_min, bucket.range_min) + self.assertEqual(expected_max, bucket.range_max) + covered += bucket.range_max - bucket.range_min + self.assertEqual(2**384, covered) + for node in nodes.values(): + node.stop() + # from binascii import hexlify, unhexlify #