refactoring of DHT tests and fixed encoding bug when dealing with bytearray

This commit is contained in:
Lex Berezhny 2018-07-30 21:23:38 -04:00 committed by Jack Robison
parent 2d4bf73632
commit bc24dbea29
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
6 changed files with 179 additions and 442 deletions

View file

@ -1,141 +1,75 @@
from __future__ import print_function
from .error import DecodeError
import sys
if sys.version_info > (3,):
long = int
raw = ord
else:
raw = lambda x: x
class Encoding:
""" Interface for RPC message encoders/decoders
All encoding implementations used with this library should inherit and
implement this.
"""
def encode(self, data):
""" Encode the specified data
@param data: The data to encode
This method has to support encoding of the following
types: C{str}, C{int} and C{long}
Any additional data types may be supported as long as the
implementing class's C{decode()} method can successfully
decode them.
@return: The encoded data
@rtype: str
"""
def decode(self, data):
""" Decode the specified data string
@param data: The data (byte string) to decode.
@type data: str
@return: The decoded data (in its correct type)
"""
class Bencode(Encoding):
""" Implementation of a Bencode-based algorithm (Bencode is the encoding
algorithm used by Bittorrent).
def bencode(data):
""" Encoder implementation of the Bencode algorithm (Bittorrent). """
if isinstance(data, int):
return b'i%de' % data
elif isinstance(data, (bytes, bytearray)):
return b'%d:%s' % (len(data), data)
elif isinstance(data, str):
return b'%d:%s' % (len(data), data.encode())
elif isinstance(data, (list, tuple)):
encoded_list_items = b''
for item in data:
encoded_list_items += bencode(item)
return b'l%se' % encoded_list_items
elif isinstance(data, dict):
encoded_dict_items = b''
keys = data.keys()
for key in sorted(keys):
encoded_dict_items += bencode(key)
encoded_dict_items += bencode(data[key])
return b'd%se' % encoded_dict_items
else:
raise TypeError("Cannot bencode '%s' object" % type(data))
@note: This algorithm differs from the "official" Bencode algorithm in
that it can encode/decode floating point values in addition to
integers.
"""
def encode(self, data):
""" Encoder implementation of the Bencode algorithm
def bdecode(data):
""" Decoder implementation of the Bencode algorithm. """
assert type(data) == bytes # fixme: _maybe_ remove this after porting
if len(data) == 0:
raise DecodeError('Cannot decode empty string')
try:
return _decode_recursive(data)[0]
except ValueError as e:
raise DecodeError(str(e))
@param data: The data to encode
@type data: int, long, tuple, list, dict or str
@return: The encoded data
@rtype: str
"""
if isinstance(data, (int, long)):
return b'i%de' % data
elif isinstance(data, bytes):
return b'%d:%s' % (len(data), data)
elif isinstance(data, str):
return b'%d:' % (len(data)) + data.encode()
elif isinstance(data, (list, tuple)):
encodedListItems = b''
for item in data:
encodedListItems += self.encode(item)
return b'l%se' % encodedListItems
elif isinstance(data, dict):
encodedDictItems = b''
keys = data.keys()
for key in sorted(keys):
encodedDictItems += self.encode(key) # TODO: keys should always be bytestrings
encodedDictItems += self.encode(data[key])
return b'd%se' % encodedDictItems
else:
raise TypeError("Cannot bencode '%s' object" % type(data))
def decode(self, data):
""" Decoder implementation of the Bencode algorithm
@param data: The encoded data
@type data: str
@note: This is a convenience wrapper for the recursive decoding
algorithm, C{_decodeRecursive}
@return: The decoded data, as a native Python type
@rtype: int, list, dict or str
"""
assert type(data) == bytes # fixme: _maybe_ remove this after porting
if len(data) == 0:
raise DecodeError('Cannot decode empty string')
def _decode_recursive(data, start_index=0):
if data[start_index] == ord('i'):
end_pos = data[start_index:].find(b'e') + start_index
return int(data[start_index + 1:end_pos]), end_pos + 1
elif data[start_index] == ord('l'):
start_index += 1
decoded_list = []
while data[start_index] != ord('e'):
list_data, start_index = _decode_recursive(data, start_index)
decoded_list.append(list_data)
return decoded_list, start_index + 1
elif data[start_index] == ord('d'):
start_index += 1
decoded_dict = {}
while data[start_index] != ord('e'):
key, start_index = _decode_recursive(data, start_index)
value, start_index = _decode_recursive(data, start_index)
decoded_dict[key] = value
return decoded_dict, start_index
elif data[start_index] == ord('f'):
# This (float data type) is a non-standard extension to the original Bencode algorithm
end_pos = data[start_index:].find(b'e') + start_index
return float(data[start_index + 1:end_pos]), end_pos + 1
elif data[start_index] == ord('n'):
# This (None/NULL data type) is a non-standard extension
# to the original Bencode algorithm
return None, start_index + 1
else:
split_pos = data[start_index:].find(b':') + start_index
try:
return self._decodeRecursive(data)[0]
except ValueError as e:
raise DecodeError(e.message)
@staticmethod
def _decodeRecursive(data, startIndex=0):
""" Actual implementation of the recursive Bencode algorithm
Do not call this; use C{decode()} instead
"""
if data[startIndex] == raw('i'):
endPos = data[startIndex:].find(b'e') + startIndex
return int(data[startIndex + 1:endPos]), endPos + 1
elif data[startIndex] == raw('l'):
startIndex += 1
decodedList = []
while data[startIndex] != raw('e'):
listData, startIndex = Bencode._decodeRecursive(data, startIndex)
decodedList.append(listData)
return decodedList, startIndex + 1
elif data[startIndex] == raw('d'):
startIndex += 1
decodedDict = {}
while data[startIndex] != raw('e'):
key, startIndex = Bencode._decodeRecursive(data, startIndex)
value, startIndex = Bencode._decodeRecursive(data, startIndex)
decodedDict[key] = value
return decodedDict, startIndex
elif data[startIndex] == raw('f'):
# This (float data type) is a non-standard extension to the original Bencode algorithm
endPos = data[startIndex:].find(b'e') + startIndex
return float(data[startIndex + 1:endPos]), endPos + 1
elif data[startIndex] == raw('n'):
# This (None/NULL data type) is a non-standard extension
# to the original Bencode algorithm
return None, startIndex + 1
else:
splitPos = data[startIndex:].find(b':') + startIndex
try:
length = int(data[startIndex:splitPos])
except ValueError:
raise DecodeError()
startIndex = splitPos + 1
endPos = startIndex + length
bytes = data[startIndex:endPos]
return bytes, endPos
length = int(data[start_index:split_pos])
except ValueError:
raise DecodeError()
start_index = split_pos + 1
end_pos = start_index + length
b = data[start_index:end_pos]
return b, end_pos

View file

@ -97,7 +97,6 @@ class KademliaProtocol(protocol.DatagramProtocol):
def __init__(self, node):
self._node = node
self._encoder = encoding.Bencode()
self._translator = msgformat.DefaultFormat()
self._sentMessages = {}
self._partialMessages = {}
@ -163,7 +162,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
msg = msgtypes.RequestMessage(self._node.node_id, method, self._migrate_outgoing_rpc_args(contact, method,
*args))
msgPrimitive = self._translator.toPrimitive(msg)
encodedMsg = self._encoder.encode(msgPrimitive)
encodedMsg = encoding.bencode(msgPrimitive)
if args:
log.debug("%s:%i SEND CALL %s(%s) TO %s:%i", self._node.externalIP, self._node.port, method,
@ -237,7 +236,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
else:
return
try:
msgPrimitive = self._encoder.decode(datagram)
msgPrimitive = encoding.bdecode(datagram)
message = self._translator.fromPrimitive(msgPrimitive)
except (encoding.DecodeError, ValueError) as err:
# We received some rubbish here
@ -394,7 +393,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
"""
msg = msgtypes.ResponseMessage(rpcID, self._node.node_id, response)
msgPrimitive = self._translator.toPrimitive(msg)
encodedMsg = self._encoder.encode(msgPrimitive)
encodedMsg = encoding.bencode(msgPrimitive)
self._send(encodedMsg, rpcID, (contact.address, contact.port))
def _sendError(self, contact, rpcID, exceptionType, exceptionMessage):
@ -403,7 +402,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
exceptionMessage = exceptionMessage.encode()
msg = msgtypes.ErrorMessage(rpcID, self._node.node_id, exceptionType, exceptionMessage)
msgPrimitive = self._translator.toPrimitive(msg)
encodedMsg = self._encoder.encode(msgPrimitive)
encodedMsg = encoding.bencode(msgPrimitive)
self._send(encodedMsg, rpcID, (contact.address, contact.port))
def _handleRPC(self, senderContact, rpcID, method, args):

View file

@ -1,18 +1,14 @@
import struct
import hashlib
import logging
from binascii import unhexlify, hexlify
from binascii import unhexlify
from twisted.internet import defer, error
from lbrynet.dht.encoding import Bencode
from lbrynet.dht import encoding
from lbrynet.dht.error import DecodeError
from lbrynet.dht.msgformat import DefaultFormat
from lbrynet.dht.msgtypes import ResponseMessage, RequestMessage, ErrorMessage
import sys
if sys.version_info > (3,):
unicode = str
_encode = Bencode()
_datagram_formatter = DefaultFormat()
log = logging.getLogger()
@ -138,7 +134,7 @@ def debug_kademlia_packet(data, source, destination, node):
if log.level != logging.DEBUG:
return
try:
packet = _datagram_formatter.fromPrimitive(_encode.decode(data))
packet = _datagram_formatter.fromPrimitive(encoding.bdecode(data))
if isinstance(packet, RequestMessage):
log.debug("request %s --> %s %s (node time %s)", source[0], destination[0], packet.request,
node.clock.seconds())

View file

@ -1,47 +1,50 @@
#!/usr/bin/env python
#
# This library is free software, distributed under the terms of
# the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive
from twisted.trial import unittest
import lbrynet.dht.encoding
from lbrynet.dht.encoding import bencode, bdecode, DecodeError
class BencodeTest(unittest.TestCase):
""" Basic tests case for the Bencode implementation """
def setUp(self):
self.encoding = lbrynet.dht.encoding.Bencode()
# Thanks goes to wikipedia for the initial test cases ;-)
self.cases = ((42, b'i42e'),
(b'spam', b'4:spam'),
([b'spam', 42], b'l4:spami42ee'),
({b'foo': 42, b'bar': b'spam'}, b'd3:bar4:spam3:fooi42ee'),
# ...and now the "real life" tests
([[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]],
b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'))
# The following test cases are "bad"; i.e. sending rubbish into the decoder to test
# what exceptions get thrown
self.badDecoderCases = (b'abcdefghijklmnopqrstuvwxyz',
b'')
class EncodeDecodeTest(unittest.TestCase):
def testEncoder(self):
""" Tests the bencode encoder """
for value, encodedValue in self.cases:
result = self.encoding.encode(value)
self.assertEqual(
result, encodedValue,
'Value "%s" not correctly encoded! Expected "%s", got "%s"' %
(value, encodedValue, result))
def test_integer(self):
self.assertEqual(bencode(42), b'i42e')
def testDecoder(self):
""" Tests the bencode decoder """
for value, encodedValue in self.cases:
result = self.encoding.decode(encodedValue)
self.assertEqual(
result, value,
'Value "%s" not correctly decoded! Expected "%s", got "%s"' %
(encodedValue, value, result))
for encodedValue in self.badDecoderCases:
self.assertRaises(
lbrynet.dht.encoding.DecodeError, self.encoding.decode, encodedValue)
self.assertEqual(bdecode(b'i42e'), 42)
def test_bytes(self):
self.assertEqual(bencode(b''), b'0:')
self.assertEqual(bencode(b'spam'), b'4:spam')
self.assertEqual(bencode(b'4:spam'), b'6:4:spam')
self.assertEqual(bencode(bytearray(b'spam')), b'4:spam')
self.assertEqual(bdecode(b'0:'), b'')
self.assertEqual(bdecode(b'4:spam'), b'spam')
self.assertEqual(bdecode(b'6:4:spam'), b'4:spam')
def test_string(self):
self.assertEqual(bencode(''), b'0:')
self.assertEqual(bencode('spam'), b'4:spam')
self.assertEqual(bencode('4:spam'), b'6:4:spam')
def test_list(self):
self.assertEqual(bencode([b'spam', 42]), b'l4:spami42ee')
self.assertEqual(bdecode(b'l4:spami42ee'), [b'spam', 42])
def test_dict(self):
self.assertEqual(bencode({b'foo': 42, b'bar': b'spam'}), b'd3:bar4:spam3:fooi42ee')
self.assertEqual(bdecode(b'd3:bar4:spam3:fooi42ee'), {b'foo': 42, b'bar': b'spam'})
def test_mixed(self):
self.assertEqual(bencode(
[[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]),
b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'
)
self.assertEqual(bdecode(
b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'),
[[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]
)
def test_decode_error(self):
self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz')
self.assertRaises(DecodeError, bdecode, b'')

View file

@ -1,56 +1,40 @@
#!/usr/bin/env python
#
# This library is free software, distributed under the terms of
# the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive
import hashlib
from twisted.trial import unittest
import struct
from twisted.trial import unittest
from twisted.internet import defer
from lbrynet.dht.node import Node
from lbrynet.dht import constants
from lbrynet.core.utils import generate_id
class NodeIDTest(unittest.TestCase):
""" Test case for the Node class's ID """
def setUp(self):
self.node = Node()
def testAutoCreatedID(self):
""" Tests if a new node has a valid node ID """
self.assertEqual(type(self.node.node_id), bytes, 'Node does not have a valid ID')
self.assertEqual(len(self.node.node_id), 48, 'Node ID length is incorrect! '
'Expected 384 bits, got %d bits.' %
(len(self.node.node_id) * 8))
def test_new_node_has_auto_created_id(self):
self.assertEqual(type(self.node.node_id), bytes)
self.assertEqual(len(self.node.node_id), 48)
def testUniqueness(self):
""" Tests the uniqueness of the values created by the NodeID generator """
generatedIDs = []
def test_uniqueness_and_length_of_generated_ids(self):
previous_ids = []
for i in range(100):
newID = self.node._generateID()
# ugly uniqueness test
self.assertFalse(newID in generatedIDs, 'Generated ID #%d not unique!' % (i+1))
generatedIDs.append(newID)
def testKeyLength(self):
""" Tests the key Node ID key length """
for i in range(20):
id = self.node._generateID()
# Key length: 20 bytes == 160 bits
self.assertEqual(len(id), 48,
'Length of generated ID is incorrect! Expected 384 bits, '
'got %d bits.' % (len(id)*8))
new_id = self.node._generateID()
self.assertNotIn(new_id, previous_ids, 'id at index {} not unique'.format(i))
self.assertEqual(len(new_id), 48, 'id at index {} wrong length: {}'.format(i, len(new_id)))
previous_ids.append(new_id)
class NodeDataTest(unittest.TestCase):
""" Test case for the Node class's data-related functions """
def setUp(self):
h = hashlib.sha384()
h.update(b'test')
self.node = Node()
self.contact = self.node.contact_manager.make_contact(h.digest(), '127.0.0.1', 12345, self.node._protocol)
self.contact = self.node.contact_manager.make_contact(
h.digest(), '127.0.0.1', 12345, self.node._protocol)
self.token = self.node.make_token(self.contact.compact_ip())
self.cases = []
for i in range(5):
@ -59,19 +43,18 @@ class NodeDataTest(unittest.TestCase):
self.cases.append((h.digest(), 5001+2*i))
@defer.inlineCallbacks
def testStore(self):
def test_store(self):
""" Tests if the node can store (and privately retrieve) some data """
for key, port in self.cases:
yield self.node.store( # pylint: disable=too-many-function-args
yield self.node.store(
self.contact, key, self.token, port, self.contact.id, 0
)
for key, value in self.cases:
expected_result = self.contact.compact_ip() + struct.pack('>H', value) + \
self.contact.id
expected_result = self.contact.compact_ip() + struct.pack('>H', value) + self.contact.id
self.assertTrue(self.node._dataStore.hasPeersForBlob(key),
'Stored key not found in node\'s DataStore: "%s"' % key)
"Stored key not found in node's DataStore: '%s'" % key)
self.assertTrue(expected_result in self.node._dataStore.getPeersForBlob(key),
'Stored val not found in node\'s DataStore: key:"%s" port:"%s" %s'
"Stored val not found in node's DataStore: key:'%s' port:'%s' %s"
% (key, value, self.node._dataStore.getPeersForBlob(key)))
@ -81,182 +64,25 @@ class NodeContactTest(unittest.TestCase):
self.node = Node()
@defer.inlineCallbacks
def testAddContact(self):
def test_add_contact(self):
""" Tests if a contact can be added and retrieved correctly """
# Create the contact
h = hashlib.sha384()
h.update(b'node1')
contactID = h.digest()
contact = self.node.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.node._protocol)
contact_id = generate_id(b'node1')
contact = self.node.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.node._protocol)
# Now add it...
yield self.node.addContact(contact)
# ...and request the closest nodes to it using FIND_NODE
closestNodes = self.node._routingTable.findCloseNodes(contactID, constants.k)
self.assertEqual(len(closestNodes), 1, 'Wrong amount of contacts returned; '
'expected 1, got %d' % len(closestNodes))
self.assertTrue(contact in closestNodes, 'Added contact not found by issueing '
'_findCloseNodes()')
closest_nodes = self.node._routingTable.findCloseNodes(contact_id, constants.k)
self.assertEqual(len(closest_nodes), 1)
self.assertIn(contact, closest_nodes)
@defer.inlineCallbacks
def testAddSelfAsContact(self):
def test_add_self_as_contact(self):
""" Tests the node's behaviour when attempting to add itself as a contact """
# Create a contact with the same ID as the local node's ID
contact = self.node.contact_manager.make_contact(self.node.node_id, '127.0.0.1', 9182, None)
# Now try to add it
yield self.node.addContact(contact)
# ...and request the closest nodes to it using FIND_NODE
closestNodes = self.node._routingTable.findCloseNodes(self.node.node_id,
constants.k)
self.assertFalse(contact in closestNodes, 'Node added itself as a contact')
# class FakeRPCProtocol(protocol.DatagramProtocol):
# def __init__(self):
# self.reactor = selectreactor.SelectReactor()
# self.testResponse = None
# self.network = None
#
# def createNetwork(self, contactNetwork):
# """
# set up a list of contacts together with their closest contacts
# @param contactNetwork: a sequence of tuples, each containing a contact together with its
# closest contacts: C{(<contact>, <closest contact 1, ...,closest contact n>)}
# """
# self.network = contactNetwork
#
# def sendRPC(self, contact, method, args, rawResponse=False):
# """ Fake RPC protocol; allows entangled.kademlia.contact.Contact objects to "send" RPCs"""
#
# h = hashlib.sha384()
# h.update('rpcId')
# rpc_id = h.digest()[:20]
#
# if method == "findNode":
# # get the specific contacts closest contacts
# closestContacts = []
# closestContactsList = []
# for contactTuple in self.network:
# if contact == contactTuple[0]:
# # get the list of closest contacts for this contact
# closestContactsList = contactTuple[1]
# # Pack the closest contacts into a ResponseMessage
# for closeContact in closestContactsList:
# closestContacts.append((closeContact.id, closeContact.address, closeContact.port))
#
# message = ResponseMessage(rpc_id, contact.id, closestContacts)
# df = defer.Deferred()
# df.callback((message, (contact.address, contact.port)))
# return df
# elif method == "findValue":
# for contactTuple in self.network:
# if contact == contactTuple[0]:
# # Get the data stored by this remote contact
# dataDict = contactTuple[2]
# dataKey = dataDict.keys()[0]
# data = dataDict.get(dataKey)
# # Check if this contact has the requested value
# if dataKey == args[0]:
# # Return the data value
# response = dataDict
# print "data found at contact: " + contact.id
# else:
# # Return the closest contact to the requested data key
# print "data not found at contact: " + contact.id
# closeContacts = contactTuple[1]
# closestContacts = []
# for closeContact in closeContacts:
# closestContacts.append((closeContact.id, closeContact.address,
# closeContact.port))
# response = closestContacts
#
# # Create the response message
# message = ResponseMessage(rpc_id, contact.id, response)
# df = defer.Deferred()
# df.callback((message, (contact.address, contact.port)))
# return df
#
# def _send(self, data, rpcID, address):
# """ fake sending data """
#
#
# class NodeLookupTest(unittest.TestCase):
# """ Test case for the Node class's iterativeFind node lookup algorithm """
#
# def setUp(self):
# # create a fake protocol to imitate communication with other nodes
# self._protocol = FakeRPCProtocol()
# # Note: The reactor is never started for this test. All deferred calls run sequentially,
# # since there is no asynchronous network communication
# # create the node to be tested in isolation
# h = hashlib.sha384()
# h.update('node1')
# node_id = str(h.digest())
# self.node = Node(node_id, 4000, None, None, self._protocol)
# self.updPort = 81173
# self.contactsAmount = 80
# # Reinitialise the routing table
# self.node._routingTable = TreeRoutingTable(self.node.node_id)
#
# # create 160 bit node ID's for test purposes
# self.testNodeIDs = []
# idNum = int(self.node.node_id.encode('hex'), 16)
# for i in range(self.contactsAmount):
# # create the testNodeIDs in ascending order, away from the actual node ID,
# # with regards to the distance metric
# self.testNodeIDs.append(str("%X" % (idNum + i + 1)).decode('hex'))
#
# # generate contacts
# self.contacts = []
# for i in range(self.contactsAmount):
# contact = self.node.contact_manager.make_contact(self.testNodeIDs[i], "127.0.0.1",
# self.updPort + i + 1, self._protocol)
# self.contacts.append(contact)
#
# # create the network of contacts in format: (contact, closest contacts)
# contactNetwork = ((self.contacts[0], self.contacts[8:15]),
# (self.contacts[1], self.contacts[16:23]),
# (self.contacts[2], self.contacts[24:31]),
# (self.contacts[3], self.contacts[32:39]),
# (self.contacts[4], self.contacts[40:47]),
# (self.contacts[5], self.contacts[48:55]),
# (self.contacts[6], self.contacts[56:63]),
# (self.contacts[7], self.contacts[64:71]),
# (self.contacts[8], self.contacts[72:79]),
# (self.contacts[40], self.contacts[41:48]),
# (self.contacts[41], self.contacts[41:48]),
# (self.contacts[42], self.contacts[41:48]),
# (self.contacts[43], self.contacts[41:48]),
# (self.contacts[44], self.contacts[41:48]),
# (self.contacts[45], self.contacts[41:48]),
# (self.contacts[46], self.contacts[41:48]),
# (self.contacts[47], self.contacts[41:48]),
# (self.contacts[48], self.contacts[41:48]),
# (self.contacts[50], self.contacts[0:7]),
# (self.contacts[51], self.contacts[8:15]),
# (self.contacts[52], self.contacts[16:23]))
#
# contacts_with_datastores = []
#
# for contact_tuple in contactNetwork:
# contacts_with_datastores.append((contact_tuple[0], contact_tuple[1],
# DictDataStore()))
# self._protocol.createNetwork(contacts_with_datastores)
#
# # @defer.inlineCallbacks
# # def testNodeBootStrap(self):
# # """ Test bootstrap with the closest possible contacts """
# # # Set the expected result
# # expectedResult = {item.id for item in self.contacts[0:8]}
# #
# # activeContacts = yield self.node._iterativeFind(self.node.node_id, self.contacts[0:8])
# #
# # # Check the length of the active contacts
# # self.failUnlessEqual(activeContacts.__len__(), expectedResult.__len__(),
# # "More active contacts should exist, there should be %d "
# # "contacts but there are %d" % (len(expectedResult),
# # len(activeContacts)))
# #
# # # Check that the received active contacts are the same as the input contacts
# # self.failUnlessEqual({contact.id for contact in activeContacts}, expectedResult,
# # "Active should only contain the closest possible contacts"
# # " which were used as input for the boostrap")
closest_nodes = self.node._routingTable.findCloseNodes(self.node.node_id, constants.k)
self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact.')

View file

@ -1,4 +1,3 @@
import hashlib
from binascii import hexlify, unhexlify
from twisted.trial import unittest
@ -7,9 +6,7 @@ from lbrynet.dht import constants
from lbrynet.dht.routingtable import TreeRoutingTable
from lbrynet.dht.contact import ContactManager
from lbrynet.dht.distance import Distance
import sys
if sys.version_info > (3,):
long = int
from lbrynet.core.utils import generate_id
class FakeRPCProtocol(object):
@ -21,76 +18,61 @@ class FakeRPCProtocol(object):
class TreeRoutingTableTest(unittest.TestCase):
""" Test case for the RoutingTable class """
def setUp(self):
h = hashlib.sha384()
h.update(b'node1')
self.contact_manager = ContactManager()
self.nodeID = h.digest()
self.nodeID = generate_id(b'node1')
self.protocol = FakeRPCProtocol()
self.routingTable = TreeRoutingTable(self.nodeID)
def testDistance(self):
def test_distance(self):
""" Test to see if distance method returns correct result"""
# testList holds a couple 3-tuple (variable1, variable2, result)
basicTestList = [(bytes(b'\xaa' * 48), bytes(b'\x55' * 48), long(hexlify(bytes(b'\xff' * 48)), 16))]
for test in basicTestList:
result = Distance(test[0])(test[1])
self.assertFalse(result != test[2], 'Result of _distance() should be %s but %s returned' %
(test[2], result))
d = Distance(bytes((170,) * 48))
result = d(bytes((85,) * 48))
expected = int(hexlify(bytes((255,) * 48)), 16)
self.assertEqual(result, expected)
@defer.inlineCallbacks
def testAddContact(self):
def test_add_contact(self):
""" Tests if a contact can be added and retrieved correctly """
# Create the contact
h = hashlib.sha384()
h.update(b'node2')
contactID = h.digest()
contact = self.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.protocol)
contact_id = generate_id(b'node2')
contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
# Now add it...
yield self.routingTable.addContact(contact)
# ...and request the closest nodes to it (will retrieve it)
closestNodes = self.routingTable.findCloseNodes(contactID)
self.assertEqual(len(closestNodes), 1, 'Wrong amount of contacts returned; expected 1,'
' got %d' % len(closestNodes))
self.assertTrue(contact in closestNodes, 'Added contact not found by issueing '
'_findCloseNodes()')
closest_nodes = self.routingTable.findCloseNodes(contact_id)
self.assertEqual(len(closest_nodes), 1)
self.assertIn(contact, closest_nodes)
@defer.inlineCallbacks
def testGetContact(self):
def test_get_contact(self):
""" Tests if a specific existing contact can be retrieved correctly """
h = hashlib.sha384()
h.update(b'node2')
contactID = h.digest()
contact = self.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.protocol)
contact_id = generate_id(b'node2')
contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
# Now add it...
yield self.routingTable.addContact(contact)
# ...and get it again
sameContact = self.routingTable.getContact(contactID)
self.assertEqual(contact, sameContact, 'getContact() should return the same contact')
same_contact = self.routingTable.getContact(contact_id)
self.assertEqual(contact, same_contact, 'getContact() should return the same contact')
@defer.inlineCallbacks
def testAddParentNodeAsContact(self):
def test_add_parent_node_as_contact(self):
"""
Tests the routing table's behaviour when attempting to add its parent node as a contact
"""
# Create a contact with the same ID as the local node's ID
contact = self.contact_manager.make_contact(self.nodeID, '127.0.0.1', 9182, self.protocol)
# Now try to add it
yield self.routingTable.addContact(contact)
# ...and request the closest nodes to it using FIND_NODE
closestNodes = self.routingTable.findCloseNodes(self.nodeID, constants.k)
self.assertFalse(contact in closestNodes, 'Node added itself as a contact')
closest_nodes = self.routingTable.findCloseNodes(self.nodeID, constants.k)
self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact')
@defer.inlineCallbacks
def testRemoveContact(self):
def test_remove_contact(self):
""" Tests contact removal """
# Create the contact
h = hashlib.sha384()
h.update(b'node2')
contactID = h.digest()
contact = self.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.protocol)
contact_id = generate_id(b'node2')
contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
# Now add it...
yield self.routingTable.addContact(contact)
# Verify addition
@ -100,25 +82,22 @@ class TreeRoutingTableTest(unittest.TestCase):
self.assertEqual(len(self.routingTable._buckets[0]), 0, 'Contact not removed properly')
@defer.inlineCallbacks
def testSplitBucket(self):
def test_split_bucket(self):
""" Tests if the the routing table correctly dynamically splits k-buckets """
self.assertEqual(self.routingTable._buckets[0].rangeMax, 2**384,
'Initial k-bucket range should be 0 <= range < 2**384')
# Add k contacts
for i in range(constants.k):
h = hashlib.sha384()
h.update(b'remote node %d' % i)
nodeID = h.digest()
contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol)
node_id = generate_id(b'remote node %d' % i)
contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol)
yield self.routingTable.addContact(contact)
self.assertEqual(len(self.routingTable._buckets), 1,
'Only k nodes have been added; the first k-bucket should now '
'be full, but should not yet be split')
# Now add 1 more contact
h = hashlib.sha384()
h.update(b'yet another remote node')
nodeID = h.digest()
contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol)
node_id = generate_id(b'yet another remote node')
contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol)
yield self.routingTable.addContact(contact)
self.assertEqual(len(self.routingTable._buckets), 2,
'k+1 nodes have been added; the first k-bucket should have been '
@ -134,7 +113,7 @@ class TreeRoutingTableTest(unittest.TestCase):
'not divided properly')
@defer.inlineCallbacks
def testFullSplit(self):
def test_full_split(self):
"""
Test that a bucket is not split if it is full, but the new contact is not closer than the kth closest contact
"""