lbry-sdk/tests/unit/dht/test_routingtable.py

214 lines
11 KiB
Python
Raw Normal View History

2018-07-18 03:26:08 +02:00
from binascii import hexlify, unhexlify
from twisted.trial import unittest
from twisted.internet import defer
from lbrynet.dht import constants
from lbrynet.dht.routingtable import TreeRoutingTable
from lbrynet.dht.contact import ContactManager
from lbrynet.dht.distance import Distance
from lbrynet.core.utils import generate_id
2018-02-15 00:28:28 +01:00
2017-11-16 20:34:48 +01:00
class FakeRPCProtocol(object):
    """ Stand-in RPC protocol so lbrynet.dht.contact.Contact objects can "send" RPCs in tests """

    def sendRPC(self, *args, **kwargs):
        """ Accept any arguments, ignore them, and hand back defer.succeed(None) """
        return defer.succeed(None)
2017-11-16 20:34:48 +01:00
class TreeRoutingTableTest(unittest.TestCase):
    """ Test case for the RoutingTable class """

    def setUp(self):
        """ Create a fresh contact manager, fake protocol and routing table for every test """
        self.contact_manager = ContactManager()
        self.nodeID = generate_id(b'node1')
        self.protocol = FakeRPCProtocol()
        self.routingTable = TreeRoutingTable(self.nodeID)

    def _make_contact(self, node_id):
        """ Build a contact for node_id at the fixed test address, backed by the fake protocol """
        return self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol)

    def test_distance(self):
        """ Test to see if distance method returns correct result """
        # 0xAA ^ 0x55 == 0xFF, so the expected value is 48 bytes of 0xFF read as an integer
        d = Distance(bytes((170,) * 48))
        result = d(bytes((85,) * 48))
        expected = int(hexlify(bytes((255,) * 48)), 16)
        self.assertEqual(result, expected)

    @defer.inlineCallbacks
    def test_add_contact(self):
        """ Tests if a contact can be added and retrieved correctly """
        # Create the contact
        contact_id = generate_id(b'node2')
        contact = self._make_contact(contact_id)
        # Now add it...
        yield self.routingTable.addContact(contact)
        # ...and request the closest nodes to it (will retrieve it)
        closest_nodes = self.routingTable.findCloseNodes(contact_id)
        self.assertEqual(len(closest_nodes), 1)
        self.assertIn(contact, closest_nodes)

    @defer.inlineCallbacks
    def test_get_contact(self):
        """ Tests if a specific existing contact can be retrieved correctly """
        contact_id = generate_id(b'node2')
        contact = self._make_contact(contact_id)
        # Now add it...
        yield self.routingTable.addContact(contact)
        # ...and get it again
        same_contact = self.routingTable.getContact(contact_id)
        self.assertEqual(contact, same_contact, 'getContact() should return the same contact')

    @defer.inlineCallbacks
    def test_add_parent_node_as_contact(self):
        """
        Tests the routing table's behaviour when attempting to add its parent node as a contact
        """
        # Create a contact with the same ID as the local node's ID
        contact = self._make_contact(self.nodeID)
        # Now try to add it
        yield self.routingTable.addContact(contact)
        # ...and request the closest nodes to it using FIND_NODE
        closest_nodes = self.routingTable.findCloseNodes(self.nodeID, constants.k)
        self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact')

    @defer.inlineCallbacks
    def test_remove_contact(self):
        """ Tests contact removal """
        # Create the contact
        contact = self._make_contact(generate_id(b'node2'))
        # Now add it...
        yield self.routingTable.addContact(contact)
        # Verify addition
        self.assertEqual(len(self.routingTable._buckets[0]), 1, 'Contact not added properly')
        # Now remove it
        self.routingTable.removeContact(contact)
        self.assertEqual(len(self.routingTable._buckets[0]), 0, 'Contact not removed properly')

    @defer.inlineCallbacks
    def test_split_bucket(self):
        """ Tests if the routing table correctly dynamically splits k-buckets """
        self.assertEqual(self.routingTable._buckets[0].rangeMax, 2**384,
                         'Initial k-bucket range should be 0 <= range < 2**384')
        # Add k contacts
        for i in range(constants.k):
            node_id = generate_id(b'remote node %d' % i)
            yield self.routingTable.addContact(self._make_contact(node_id))
        self.assertEqual(len(self.routingTable._buckets), 1,
                         'Only k nodes have been added; the first k-bucket should now '
                         'be full, but should not yet be split')
        # Now add 1 more contact
        node_id = generate_id(b'yet another remote node')
        yield self.routingTable.addContact(self._make_contact(node_id))
        self.assertEqual(len(self.routingTable._buckets), 2,
                         'k+1 nodes have been added; the first k-bucket should have been '
                         'split into two new buckets')
        self.assertNotEqual(self.routingTable._buckets[0].rangeMax, 2**384,
                            'K-bucket was split, but its range was not properly adjusted')
        self.assertEqual(self.routingTable._buckets[1].rangeMax, 2**384,
                         'K-bucket was split, but the second (new) bucket\'s '
                         'max range was not set properly')
        self.assertEqual(self.routingTable._buckets[0].rangeMax,
                         self.routingTable._buckets[1].rangeMin,
                         'K-bucket was split, but the min/max ranges were '
                         'not divided properly')

    @defer.inlineCallbacks
    def test_full_split(self):
        """
        Test that a bucket is not split if it is full, but the new contact is not closer
        than the kth closest contact
        """
        self.routingTable._parentNodeID = b'\xff' * 48

        # Every id is 48 bytes: one distinguishing leading byte followed by 47 zero bytes
        zero_tail = b'00' * 47
        leading_bytes = (b'10', b'20', b'30', b'40', b'50', b'60', b'70', b'80', b'ff', b'01')
        node_ids = [unhexlify(lead + zero_tail) for lead in leading_bytes]

        # Add all ten contacts; the assertions below expect them to land 8 + 2 in two buckets
        for node_id in node_ids:
            yield self.routingTable.addContact(self._make_contact(node_id))
        self.assertEqual(len(self.routingTable._buckets), 2)
        self.assertEqual(len(self.routingTable._buckets[0]._contacts), 8)
        self.assertEqual(len(self.routingTable._buckets[1]._contacts), 2)

        # try adding a contact who is further from us than the k'th known contact
        contact = self._make_contact(unhexlify(b'02' + zero_tail))
        bucket_index = self.routingTable._kbucketIndex(contact.id)
        self.assertFalse(self.routingTable._shouldSplit(bucket_index, contact.id))
        yield self.routingTable.addContact(contact)

        # Nothing may have changed: same bucket count and sizes, and the new contact is absent
        self.assertEqual(len(self.routingTable._buckets), 2)
        self.assertEqual(len(self.routingTable._buckets[0]._contacts), 8)
        self.assertEqual(len(self.routingTable._buckets[1]._contacts), 2)
        self.assertNotIn(contact, self.routingTable._buckets[0]._contacts)
        self.assertNotIn(contact, self.routingTable._buckets[1]._contacts)
# class KeyErrorFixedTest(unittest.TestCase):
# """ Basic test case for boolean operators on the Contact class """
#
# def setUp(self):
# own_id = (2 ** constants.key_bits) - 1
# # carefully chosen own_id. here's the logic
# # we want a bunch of buckets (k+1, to be exact), and we want to make sure own_id
# # is not in bucket 0. so we put own_id at the end so we can keep splitting by adding to the
# # end
#
# self.table = lbrynet.dht.routingtable.OptimizedTreeRoutingTable(own_id)
#
# def fill_bucket(self, bucket_min):
# bucket_size = lbrynet.dht.constants.k
# for i in range(bucket_min, bucket_min + bucket_size):
# self.table.addContact(lbrynet.dht.contact.Contact(long(i), '127.0.0.1', 9999, None))
#
# def overflow_bucket(self, bucket_min):
# bucket_size = lbrynet.dht.constants.k
# self.fill_bucket(bucket_min)
# self.table.addContact(
# lbrynet.dht.contact.Contact(long(bucket_min + bucket_size + 1),
# '127.0.0.1', 9999, None))
#
# def testKeyError(self):
#
# # find middle, so we know where bucket will split
# bucket_middle = self.table._buckets[0].rangeMax / 2
#
# # fill last bucket
# self.fill_bucket(self.table._buckets[0].rangeMax - lbrynet.dht.constants.k - 1)
# # -1 in previous line because own_id is in last bucket
#
# # fill/overflow 7 more buckets
# bucket_start = 0
# for i in range(0, lbrynet.dht.constants.k):
# self.overflow_bucket(bucket_start)
# bucket_start += bucket_middle / (2 ** i)
#
# # replacement cache now has k-1 entries.
# # adding one more contact to bucket 0 used to cause a KeyError, but it should work
# self.table.addContact(
# lbrynet.dht.contact.Contact(long(lbrynet.dht.constants.k + 2), '127.0.0.1', 9999, None))
#
# # import math
# # print ""
# # for i, bucket in enumerate(self.table._buckets):
# # print "Bucket " + str(i) + " (2 ** " + str(
# # math.log(bucket.rangeMin, 2) if bucket.rangeMin > 0 else 0) + " <= x < 2 ** "+str(
# # math.log(bucket.rangeMax, 2)) + ")"
# # for c in bucket.getContacts():
# # print " contact " + str(c.id)
# # for key, bucket in self.table._replacementCache.items():
# # print "Replacement Cache for Bucket " + str(key)
# # for c in bucket:
# # print " contact " + str(c.id)