refactoring of DHT tests and fixed encoding bug when dealing with bytearray

This commit is contained in:
Lex Berezhny 2018-07-30 21:23:38 -04:00 committed by Jack Robison
parent 2d4bf73632
commit bc24dbea29
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
6 changed files with 179 additions and 442 deletions

View file

@ -1,141 +1,75 @@
from __future__ import print_function
from .error import DecodeError from .error import DecodeError
import sys
if sys.version_info > (3,):
long = int
raw = ord
else:
raw = lambda x: x
class Encoding:
""" Interface for RPC message encoders/decoders
All encoding implementations used with this library should inherit and
implement this.
"""
def encode(self, data):
""" Encode the specified data
@param data: The data to encode
This method has to support encoding of the following
types: C{str}, C{int} and C{long}
Any additional data types may be supported as long as the
implementing class's C{decode()} method can successfully
decode them.
@return: The encoded data
@rtype: str
"""
def decode(self, data):
""" Decode the specified data string
@param data: The data (byte string) to decode.
@type data: str
@return: The decoded data (in its correct type)
"""
class Bencode(Encoding): def bencode(data):
""" Implementation of a Bencode-based algorithm (Bencode is the encoding """ Encoder implementation of the Bencode algorithm (Bittorrent). """
algorithm used by Bittorrent). if isinstance(data, int):
return b'i%de' % data
elif isinstance(data, (bytes, bytearray)):
return b'%d:%s' % (len(data), data)
elif isinstance(data, str):
return b'%d:%s' % (len(data), data.encode())
elif isinstance(data, (list, tuple)):
encoded_list_items = b''
for item in data:
encoded_list_items += bencode(item)
return b'l%se' % encoded_list_items
elif isinstance(data, dict):
encoded_dict_items = b''
keys = data.keys()
for key in sorted(keys):
encoded_dict_items += bencode(key)
encoded_dict_items += bencode(data[key])
return b'd%se' % encoded_dict_items
else:
raise TypeError("Cannot bencode '%s' object" % type(data))
@note: This algorithm differs from the "official" Bencode algorithm in
that it can encode/decode floating point values in addition to
integers.
"""
def encode(self, data): def bdecode(data):
""" Encoder implementation of the Bencode algorithm """ Decoder implementation of the Bencode algorithm. """
assert type(data) == bytes # fixme: _maybe_ remove this after porting
if len(data) == 0:
raise DecodeError('Cannot decode empty string')
try:
return _decode_recursive(data)[0]
except ValueError as e:
raise DecodeError(str(e))
@param data: The data to encode
@type data: int, long, tuple, list, dict or str
@return: The encoded data def _decode_recursive(data, start_index=0):
@rtype: str if data[start_index] == ord('i'):
""" end_pos = data[start_index:].find(b'e') + start_index
if isinstance(data, (int, long)): return int(data[start_index + 1:end_pos]), end_pos + 1
return b'i%de' % data elif data[start_index] == ord('l'):
elif isinstance(data, bytes): start_index += 1
return b'%d:%s' % (len(data), data) decoded_list = []
elif isinstance(data, str): while data[start_index] != ord('e'):
return b'%d:' % (len(data)) + data.encode() list_data, start_index = _decode_recursive(data, start_index)
elif isinstance(data, (list, tuple)): decoded_list.append(list_data)
encodedListItems = b'' return decoded_list, start_index + 1
for item in data: elif data[start_index] == ord('d'):
encodedListItems += self.encode(item) start_index += 1
return b'l%se' % encodedListItems decoded_dict = {}
elif isinstance(data, dict): while data[start_index] != ord('e'):
encodedDictItems = b'' key, start_index = _decode_recursive(data, start_index)
keys = data.keys() value, start_index = _decode_recursive(data, start_index)
for key in sorted(keys): decoded_dict[key] = value
encodedDictItems += self.encode(key) # TODO: keys should always be bytestrings return decoded_dict, start_index
encodedDictItems += self.encode(data[key]) elif data[start_index] == ord('f'):
return b'd%se' % encodedDictItems # This (float data type) is a non-standard extension to the original Bencode algorithm
else: end_pos = data[start_index:].find(b'e') + start_index
raise TypeError("Cannot bencode '%s' object" % type(data)) return float(data[start_index + 1:end_pos]), end_pos + 1
elif data[start_index] == ord('n'):
def decode(self, data): # This (None/NULL data type) is a non-standard extension
""" Decoder implementation of the Bencode algorithm # to the original Bencode algorithm
return None, start_index + 1
@param data: The encoded data else:
@type data: str split_pos = data[start_index:].find(b':') + start_index
@note: This is a convenience wrapper for the recursive decoding
algorithm, C{_decodeRecursive}
@return: The decoded data, as a native Python type
@rtype: int, list, dict or str
"""
assert type(data) == bytes # fixme: _maybe_ remove this after porting
if len(data) == 0:
raise DecodeError('Cannot decode empty string')
try: try:
return self._decodeRecursive(data)[0] length = int(data[start_index:split_pos])
except ValueError as e: except ValueError:
raise DecodeError(e.message) raise DecodeError()
start_index = split_pos + 1
@staticmethod end_pos = start_index + length
def _decodeRecursive(data, startIndex=0): b = data[start_index:end_pos]
""" Actual implementation of the recursive Bencode algorithm return b, end_pos
Do not call this; use C{decode()} instead
"""
if data[startIndex] == raw('i'):
endPos = data[startIndex:].find(b'e') + startIndex
return int(data[startIndex + 1:endPos]), endPos + 1
elif data[startIndex] == raw('l'):
startIndex += 1
decodedList = []
while data[startIndex] != raw('e'):
listData, startIndex = Bencode._decodeRecursive(data, startIndex)
decodedList.append(listData)
return decodedList, startIndex + 1
elif data[startIndex] == raw('d'):
startIndex += 1
decodedDict = {}
while data[startIndex] != raw('e'):
key, startIndex = Bencode._decodeRecursive(data, startIndex)
value, startIndex = Bencode._decodeRecursive(data, startIndex)
decodedDict[key] = value
return decodedDict, startIndex
elif data[startIndex] == raw('f'):
# This (float data type) is a non-standard extension to the original Bencode algorithm
endPos = data[startIndex:].find(b'e') + startIndex
return float(data[startIndex + 1:endPos]), endPos + 1
elif data[startIndex] == raw('n'):
# This (None/NULL data type) is a non-standard extension
# to the original Bencode algorithm
return None, startIndex + 1
else:
splitPos = data[startIndex:].find(b':') + startIndex
try:
length = int(data[startIndex:splitPos])
except ValueError:
raise DecodeError()
startIndex = splitPos + 1
endPos = startIndex + length
bytes = data[startIndex:endPos]
return bytes, endPos

View file

@ -97,7 +97,6 @@ class KademliaProtocol(protocol.DatagramProtocol):
def __init__(self, node): def __init__(self, node):
self._node = node self._node = node
self._encoder = encoding.Bencode()
self._translator = msgformat.DefaultFormat() self._translator = msgformat.DefaultFormat()
self._sentMessages = {} self._sentMessages = {}
self._partialMessages = {} self._partialMessages = {}
@ -163,7 +162,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
msg = msgtypes.RequestMessage(self._node.node_id, method, self._migrate_outgoing_rpc_args(contact, method, msg = msgtypes.RequestMessage(self._node.node_id, method, self._migrate_outgoing_rpc_args(contact, method,
*args)) *args))
msgPrimitive = self._translator.toPrimitive(msg) msgPrimitive = self._translator.toPrimitive(msg)
encodedMsg = self._encoder.encode(msgPrimitive) encodedMsg = encoding.bencode(msgPrimitive)
if args: if args:
log.debug("%s:%i SEND CALL %s(%s) TO %s:%i", self._node.externalIP, self._node.port, method, log.debug("%s:%i SEND CALL %s(%s) TO %s:%i", self._node.externalIP, self._node.port, method,
@ -237,7 +236,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
else: else:
return return
try: try:
msgPrimitive = self._encoder.decode(datagram) msgPrimitive = encoding.bdecode(datagram)
message = self._translator.fromPrimitive(msgPrimitive) message = self._translator.fromPrimitive(msgPrimitive)
except (encoding.DecodeError, ValueError) as err: except (encoding.DecodeError, ValueError) as err:
# We received some rubbish here # We received some rubbish here
@ -394,7 +393,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
""" """
msg = msgtypes.ResponseMessage(rpcID, self._node.node_id, response) msg = msgtypes.ResponseMessage(rpcID, self._node.node_id, response)
msgPrimitive = self._translator.toPrimitive(msg) msgPrimitive = self._translator.toPrimitive(msg)
encodedMsg = self._encoder.encode(msgPrimitive) encodedMsg = encoding.bencode(msgPrimitive)
self._send(encodedMsg, rpcID, (contact.address, contact.port)) self._send(encodedMsg, rpcID, (contact.address, contact.port))
def _sendError(self, contact, rpcID, exceptionType, exceptionMessage): def _sendError(self, contact, rpcID, exceptionType, exceptionMessage):
@ -403,7 +402,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
exceptionMessage = exceptionMessage.encode() exceptionMessage = exceptionMessage.encode()
msg = msgtypes.ErrorMessage(rpcID, self._node.node_id, exceptionType, exceptionMessage) msg = msgtypes.ErrorMessage(rpcID, self._node.node_id, exceptionType, exceptionMessage)
msgPrimitive = self._translator.toPrimitive(msg) msgPrimitive = self._translator.toPrimitive(msg)
encodedMsg = self._encoder.encode(msgPrimitive) encodedMsg = encoding.bencode(msgPrimitive)
self._send(encodedMsg, rpcID, (contact.address, contact.port)) self._send(encodedMsg, rpcID, (contact.address, contact.port))
def _handleRPC(self, senderContact, rpcID, method, args): def _handleRPC(self, senderContact, rpcID, method, args):

View file

@ -1,18 +1,14 @@
import struct import struct
import hashlib import hashlib
import logging import logging
from binascii import unhexlify, hexlify from binascii import unhexlify
from twisted.internet import defer, error from twisted.internet import defer, error
from lbrynet.dht.encoding import Bencode from lbrynet.dht import encoding
from lbrynet.dht.error import DecodeError from lbrynet.dht.error import DecodeError
from lbrynet.dht.msgformat import DefaultFormat from lbrynet.dht.msgformat import DefaultFormat
from lbrynet.dht.msgtypes import ResponseMessage, RequestMessage, ErrorMessage from lbrynet.dht.msgtypes import ResponseMessage, RequestMessage, ErrorMessage
import sys
if sys.version_info > (3,):
unicode = str
_encode = Bencode()
_datagram_formatter = DefaultFormat() _datagram_formatter = DefaultFormat()
log = logging.getLogger() log = logging.getLogger()
@ -138,7 +134,7 @@ def debug_kademlia_packet(data, source, destination, node):
if log.level != logging.DEBUG: if log.level != logging.DEBUG:
return return
try: try:
packet = _datagram_formatter.fromPrimitive(_encode.decode(data)) packet = _datagram_formatter.fromPrimitive(encoding.bdecode(data))
if isinstance(packet, RequestMessage): if isinstance(packet, RequestMessage):
log.debug("request %s --> %s %s (node time %s)", source[0], destination[0], packet.request, log.debug("request %s --> %s %s (node time %s)", source[0], destination[0], packet.request,
node.clock.seconds()) node.clock.seconds())

View file

@ -1,47 +1,50 @@
#!/usr/bin/env python
#
# This library is free software, distributed under the terms of
# the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive
from twisted.trial import unittest from twisted.trial import unittest
import lbrynet.dht.encoding from lbrynet.dht.encoding import bencode, bdecode, DecodeError
class BencodeTest(unittest.TestCase): class EncodeDecodeTest(unittest.TestCase):
""" Basic tests case for the Bencode implementation """
def setUp(self):
self.encoding = lbrynet.dht.encoding.Bencode()
# Thanks goes to wikipedia for the initial test cases ;-)
self.cases = ((42, b'i42e'),
(b'spam', b'4:spam'),
([b'spam', 42], b'l4:spami42ee'),
({b'foo': 42, b'bar': b'spam'}, b'd3:bar4:spam3:fooi42ee'),
# ...and now the "real life" tests
([[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]],
b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'))
# The following test cases are "bad"; i.e. sending rubbish into the decoder to test
# what exceptions get thrown
self.badDecoderCases = (b'abcdefghijklmnopqrstuvwxyz',
b'')
def testEncoder(self): def test_integer(self):
""" Tests the bencode encoder """ self.assertEqual(bencode(42), b'i42e')
for value, encodedValue in self.cases:
result = self.encoding.encode(value)
self.assertEqual(
result, encodedValue,
'Value "%s" not correctly encoded! Expected "%s", got "%s"' %
(value, encodedValue, result))
def testDecoder(self): self.assertEqual(bdecode(b'i42e'), 42)
""" Tests the bencode decoder """
for value, encodedValue in self.cases: def test_bytes(self):
result = self.encoding.decode(encodedValue) self.assertEqual(bencode(b''), b'0:')
self.assertEqual( self.assertEqual(bencode(b'spam'), b'4:spam')
result, value, self.assertEqual(bencode(b'4:spam'), b'6:4:spam')
'Value "%s" not correctly decoded! Expected "%s", got "%s"' % self.assertEqual(bencode(bytearray(b'spam')), b'4:spam')
(encodedValue, value, result))
for encodedValue in self.badDecoderCases: self.assertEqual(bdecode(b'0:'), b'')
self.assertRaises( self.assertEqual(bdecode(b'4:spam'), b'spam')
lbrynet.dht.encoding.DecodeError, self.encoding.decode, encodedValue) self.assertEqual(bdecode(b'6:4:spam'), b'4:spam')
def test_string(self):
self.assertEqual(bencode(''), b'0:')
self.assertEqual(bencode('spam'), b'4:spam')
self.assertEqual(bencode('4:spam'), b'6:4:spam')
def test_list(self):
self.assertEqual(bencode([b'spam', 42]), b'l4:spami42ee')
self.assertEqual(bdecode(b'l4:spami42ee'), [b'spam', 42])
def test_dict(self):
self.assertEqual(bencode({b'foo': 42, b'bar': b'spam'}), b'd3:bar4:spam3:fooi42ee')
self.assertEqual(bdecode(b'd3:bar4:spam3:fooi42ee'), {b'foo': 42, b'bar': b'spam'})
def test_mixed(self):
self.assertEqual(bencode(
[[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]),
b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'
)
self.assertEqual(bdecode(
b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'),
[[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]
)
def test_decode_error(self):
self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz')
self.assertRaises(DecodeError, bdecode, b'')

View file

@ -1,56 +1,40 @@
#!/usr/bin/env python
#
# This library is free software, distributed under the terms of
# the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive
import hashlib import hashlib
from twisted.trial import unittest
import struct import struct
from twisted.trial import unittest
from twisted.internet import defer from twisted.internet import defer
from lbrynet.dht.node import Node from lbrynet.dht.node import Node
from lbrynet.dht import constants from lbrynet.dht import constants
from lbrynet.core.utils import generate_id
class NodeIDTest(unittest.TestCase): class NodeIDTest(unittest.TestCase):
""" Test case for the Node class's ID """
def setUp(self): def setUp(self):
self.node = Node() self.node = Node()
def testAutoCreatedID(self): def test_new_node_has_auto_created_id(self):
""" Tests if a new node has a valid node ID """ self.assertEqual(type(self.node.node_id), bytes)
self.assertEqual(type(self.node.node_id), bytes, 'Node does not have a valid ID') self.assertEqual(len(self.node.node_id), 48)
self.assertEqual(len(self.node.node_id), 48, 'Node ID length is incorrect! '
'Expected 384 bits, got %d bits.' %
(len(self.node.node_id) * 8))
def testUniqueness(self): def test_uniqueness_and_length_of_generated_ids(self):
""" Tests the uniqueness of the values created by the NodeID generator """ previous_ids = []
generatedIDs = []
for i in range(100): for i in range(100):
newID = self.node._generateID() new_id = self.node._generateID()
# ugly uniqueness test self.assertNotIn(new_id, previous_ids, 'id at index {} not unique'.format(i))
self.assertFalse(newID in generatedIDs, 'Generated ID #%d not unique!' % (i+1)) self.assertEqual(len(new_id), 48, 'id at index {} wrong length: {}'.format(i, len(new_id)))
generatedIDs.append(newID) previous_ids.append(new_id)
def testKeyLength(self):
""" Tests the key Node ID key length """
for i in range(20):
id = self.node._generateID()
# Key length: 20 bytes == 160 bits
self.assertEqual(len(id), 48,
'Length of generated ID is incorrect! Expected 384 bits, '
'got %d bits.' % (len(id)*8))
class NodeDataTest(unittest.TestCase): class NodeDataTest(unittest.TestCase):
""" Test case for the Node class's data-related functions """ """ Test case for the Node class's data-related functions """
def setUp(self): def setUp(self):
h = hashlib.sha384() h = hashlib.sha384()
h.update(b'test') h.update(b'test')
self.node = Node() self.node = Node()
self.contact = self.node.contact_manager.make_contact(h.digest(), '127.0.0.1', 12345, self.node._protocol) self.contact = self.node.contact_manager.make_contact(
h.digest(), '127.0.0.1', 12345, self.node._protocol)
self.token = self.node.make_token(self.contact.compact_ip()) self.token = self.node.make_token(self.contact.compact_ip())
self.cases = [] self.cases = []
for i in range(5): for i in range(5):
@ -59,19 +43,18 @@ class NodeDataTest(unittest.TestCase):
self.cases.append((h.digest(), 5001+2*i)) self.cases.append((h.digest(), 5001+2*i))
@defer.inlineCallbacks @defer.inlineCallbacks
def testStore(self): def test_store(self):
""" Tests if the node can store (and privately retrieve) some data """ """ Tests if the node can store (and privately retrieve) some data """
for key, port in self.cases: for key, port in self.cases:
yield self.node.store( # pylint: disable=too-many-function-args yield self.node.store(
self.contact, key, self.token, port, self.contact.id, 0 self.contact, key, self.token, port, self.contact.id, 0
) )
for key, value in self.cases: for key, value in self.cases:
expected_result = self.contact.compact_ip() + struct.pack('>H', value) + \ expected_result = self.contact.compact_ip() + struct.pack('>H', value) + self.contact.id
self.contact.id
self.assertTrue(self.node._dataStore.hasPeersForBlob(key), self.assertTrue(self.node._dataStore.hasPeersForBlob(key),
'Stored key not found in node\'s DataStore: "%s"' % key) "Stored key not found in node's DataStore: '%s'" % key)
self.assertTrue(expected_result in self.node._dataStore.getPeersForBlob(key), self.assertTrue(expected_result in self.node._dataStore.getPeersForBlob(key),
'Stored val not found in node\'s DataStore: key:"%s" port:"%s" %s' "Stored val not found in node's DataStore: key:'%s' port:'%s' %s"
% (key, value, self.node._dataStore.getPeersForBlob(key))) % (key, value, self.node._dataStore.getPeersForBlob(key)))
@ -81,182 +64,25 @@ class NodeContactTest(unittest.TestCase):
self.node = Node() self.node = Node()
@defer.inlineCallbacks @defer.inlineCallbacks
def testAddContact(self): def test_add_contact(self):
""" Tests if a contact can be added and retrieved correctly """ """ Tests if a contact can be added and retrieved correctly """
# Create the contact # Create the contact
h = hashlib.sha384() contact_id = generate_id(b'node1')
h.update(b'node1') contact = self.node.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.node._protocol)
contactID = h.digest()
contact = self.node.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.node._protocol)
# Now add it... # Now add it...
yield self.node.addContact(contact) yield self.node.addContact(contact)
# ...and request the closest nodes to it using FIND_NODE # ...and request the closest nodes to it using FIND_NODE
closestNodes = self.node._routingTable.findCloseNodes(contactID, constants.k) closest_nodes = self.node._routingTable.findCloseNodes(contact_id, constants.k)
self.assertEqual(len(closestNodes), 1, 'Wrong amount of contacts returned; ' self.assertEqual(len(closest_nodes), 1)
'expected 1, got %d' % len(closestNodes)) self.assertIn(contact, closest_nodes)
self.assertTrue(contact in closestNodes, 'Added contact not found by issueing '
'_findCloseNodes()')
@defer.inlineCallbacks @defer.inlineCallbacks
def testAddSelfAsContact(self): def test_add_self_as_contact(self):
""" Tests the node's behaviour when attempting to add itself as a contact """ """ Tests the node's behaviour when attempting to add itself as a contact """
# Create a contact with the same ID as the local node's ID # Create a contact with the same ID as the local node's ID
contact = self.node.contact_manager.make_contact(self.node.node_id, '127.0.0.1', 9182, None) contact = self.node.contact_manager.make_contact(self.node.node_id, '127.0.0.1', 9182, None)
# Now try to add it # Now try to add it
yield self.node.addContact(contact) yield self.node.addContact(contact)
# ...and request the closest nodes to it using FIND_NODE # ...and request the closest nodes to it using FIND_NODE
closestNodes = self.node._routingTable.findCloseNodes(self.node.node_id, closest_nodes = self.node._routingTable.findCloseNodes(self.node.node_id, constants.k)
constants.k) self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact.')
self.assertFalse(contact in closestNodes, 'Node added itself as a contact')
# class FakeRPCProtocol(protocol.DatagramProtocol):
# def __init__(self):
# self.reactor = selectreactor.SelectReactor()
# self.testResponse = None
# self.network = None
#
# def createNetwork(self, contactNetwork):
# """
# set up a list of contacts together with their closest contacts
# @param contactNetwork: a sequence of tuples, each containing a contact together with its
# closest contacts: C{(<contact>, <closest contact 1, ...,closest contact n>)}
# """
# self.network = contactNetwork
#
# def sendRPC(self, contact, method, args, rawResponse=False):
# """ Fake RPC protocol; allows entangled.kademlia.contact.Contact objects to "send" RPCs"""
#
# h = hashlib.sha384()
# h.update('rpcId')
# rpc_id = h.digest()[:20]
#
# if method == "findNode":
# # get the specific contacts closest contacts
# closestContacts = []
# closestContactsList = []
# for contactTuple in self.network:
# if contact == contactTuple[0]:
# # get the list of closest contacts for this contact
# closestContactsList = contactTuple[1]
# # Pack the closest contacts into a ResponseMessage
# for closeContact in closestContactsList:
# closestContacts.append((closeContact.id, closeContact.address, closeContact.port))
#
# message = ResponseMessage(rpc_id, contact.id, closestContacts)
# df = defer.Deferred()
# df.callback((message, (contact.address, contact.port)))
# return df
# elif method == "findValue":
# for contactTuple in self.network:
# if contact == contactTuple[0]:
# # Get the data stored by this remote contact
# dataDict = contactTuple[2]
# dataKey = dataDict.keys()[0]
# data = dataDict.get(dataKey)
# # Check if this contact has the requested value
# if dataKey == args[0]:
# # Return the data value
# response = dataDict
# print "data found at contact: " + contact.id
# else:
# # Return the closest contact to the requested data key
# print "data not found at contact: " + contact.id
# closeContacts = contactTuple[1]
# closestContacts = []
# for closeContact in closeContacts:
# closestContacts.append((closeContact.id, closeContact.address,
# closeContact.port))
# response = closestContacts
#
# # Create the response message
# message = ResponseMessage(rpc_id, contact.id, response)
# df = defer.Deferred()
# df.callback((message, (contact.address, contact.port)))
# return df
#
# def _send(self, data, rpcID, address):
# """ fake sending data """
#
#
# class NodeLookupTest(unittest.TestCase):
# """ Test case for the Node class's iterativeFind node lookup algorithm """
#
# def setUp(self):
# # create a fake protocol to imitate communication with other nodes
# self._protocol = FakeRPCProtocol()
# # Note: The reactor is never started for this test. All deferred calls run sequentially,
# # since there is no asynchronous network communication
# # create the node to be tested in isolation
# h = hashlib.sha384()
# h.update('node1')
# node_id = str(h.digest())
# self.node = Node(node_id, 4000, None, None, self._protocol)
# self.updPort = 81173
# self.contactsAmount = 80
# # Reinitialise the routing table
# self.node._routingTable = TreeRoutingTable(self.node.node_id)
#
# # create 160 bit node ID's for test purposes
# self.testNodeIDs = []
# idNum = int(self.node.node_id.encode('hex'), 16)
# for i in range(self.contactsAmount):
# # create the testNodeIDs in ascending order, away from the actual node ID,
# # with regards to the distance metric
# self.testNodeIDs.append(str("%X" % (idNum + i + 1)).decode('hex'))
#
# # generate contacts
# self.contacts = []
# for i in range(self.contactsAmount):
# contact = self.node.contact_manager.make_contact(self.testNodeIDs[i], "127.0.0.1",
# self.updPort + i + 1, self._protocol)
# self.contacts.append(contact)
#
# # create the network of contacts in format: (contact, closest contacts)
# contactNetwork = ((self.contacts[0], self.contacts[8:15]),
# (self.contacts[1], self.contacts[16:23]),
# (self.contacts[2], self.contacts[24:31]),
# (self.contacts[3], self.contacts[32:39]),
# (self.contacts[4], self.contacts[40:47]),
# (self.contacts[5], self.contacts[48:55]),
# (self.contacts[6], self.contacts[56:63]),
# (self.contacts[7], self.contacts[64:71]),
# (self.contacts[8], self.contacts[72:79]),
# (self.contacts[40], self.contacts[41:48]),
# (self.contacts[41], self.contacts[41:48]),
# (self.contacts[42], self.contacts[41:48]),
# (self.contacts[43], self.contacts[41:48]),
# (self.contacts[44], self.contacts[41:48]),
# (self.contacts[45], self.contacts[41:48]),
# (self.contacts[46], self.contacts[41:48]),
# (self.contacts[47], self.contacts[41:48]),
# (self.contacts[48], self.contacts[41:48]),
# (self.contacts[50], self.contacts[0:7]),
# (self.contacts[51], self.contacts[8:15]),
# (self.contacts[52], self.contacts[16:23]))
#
# contacts_with_datastores = []
#
# for contact_tuple in contactNetwork:
# contacts_with_datastores.append((contact_tuple[0], contact_tuple[1],
# DictDataStore()))
# self._protocol.createNetwork(contacts_with_datastores)
#
# # @defer.inlineCallbacks
# # def testNodeBootStrap(self):
# # """ Test bootstrap with the closest possible contacts """
# # # Set the expected result
# # expectedResult = {item.id for item in self.contacts[0:8]}
# #
# # activeContacts = yield self.node._iterativeFind(self.node.node_id, self.contacts[0:8])
# #
# # # Check the length of the active contacts
# # self.failUnlessEqual(activeContacts.__len__(), expectedResult.__len__(),
# # "More active contacts should exist, there should be %d "
# # "contacts but there are %d" % (len(expectedResult),
# # len(activeContacts)))
# #
# # # Check that the received active contacts are the same as the input contacts
# # self.failUnlessEqual({contact.id for contact in activeContacts}, expectedResult,
# # "Active should only contain the closest possible contacts"
# # " which were used as input for the boostrap")

View file

@ -1,4 +1,3 @@
import hashlib
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from twisted.trial import unittest from twisted.trial import unittest
@ -7,9 +6,7 @@ from lbrynet.dht import constants
from lbrynet.dht.routingtable import TreeRoutingTable from lbrynet.dht.routingtable import TreeRoutingTable
from lbrynet.dht.contact import ContactManager from lbrynet.dht.contact import ContactManager
from lbrynet.dht.distance import Distance from lbrynet.dht.distance import Distance
import sys from lbrynet.core.utils import generate_id
if sys.version_info > (3,):
long = int
class FakeRPCProtocol(object): class FakeRPCProtocol(object):
@ -21,76 +18,61 @@ class FakeRPCProtocol(object):
class TreeRoutingTableTest(unittest.TestCase): class TreeRoutingTableTest(unittest.TestCase):
""" Test case for the RoutingTable class """ """ Test case for the RoutingTable class """
def setUp(self): def setUp(self):
h = hashlib.sha384()
h.update(b'node1')
self.contact_manager = ContactManager() self.contact_manager = ContactManager()
self.nodeID = h.digest() self.nodeID = generate_id(b'node1')
self.protocol = FakeRPCProtocol() self.protocol = FakeRPCProtocol()
self.routingTable = TreeRoutingTable(self.nodeID) self.routingTable = TreeRoutingTable(self.nodeID)
def testDistance(self): def test_distance(self):
""" Test to see if distance method returns correct result""" """ Test to see if distance method returns correct result"""
d = Distance(bytes((170,) * 48))
# testList holds a couple 3-tuple (variable1, variable2, result) result = d(bytes((85,) * 48))
basicTestList = [(bytes(b'\xaa' * 48), bytes(b'\x55' * 48), long(hexlify(bytes(b'\xff' * 48)), 16))] expected = int(hexlify(bytes((255,) * 48)), 16)
self.assertEqual(result, expected)
for test in basicTestList:
result = Distance(test[0])(test[1])
self.assertFalse(result != test[2], 'Result of _distance() should be %s but %s returned' %
(test[2], result))
@defer.inlineCallbacks @defer.inlineCallbacks
def testAddContact(self): def test_add_contact(self):
""" Tests if a contact can be added and retrieved correctly """ """ Tests if a contact can be added and retrieved correctly """
# Create the contact # Create the contact
h = hashlib.sha384() contact_id = generate_id(b'node2')
h.update(b'node2') contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
contactID = h.digest()
contact = self.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.protocol)
# Now add it... # Now add it...
yield self.routingTable.addContact(contact) yield self.routingTable.addContact(contact)
# ...and request the closest nodes to it (will retrieve it) # ...and request the closest nodes to it (will retrieve it)
closestNodes = self.routingTable.findCloseNodes(contactID) closest_nodes = self.routingTable.findCloseNodes(contact_id)
self.assertEqual(len(closestNodes), 1, 'Wrong amount of contacts returned; expected 1,' self.assertEqual(len(closest_nodes), 1)
' got %d' % len(closestNodes)) self.assertIn(contact, closest_nodes)
self.assertTrue(contact in closestNodes, 'Added contact not found by issueing '
'_findCloseNodes()')
@defer.inlineCallbacks @defer.inlineCallbacks
def testGetContact(self): def test_get_contact(self):
""" Tests if a specific existing contact can be retrieved correctly """ """ Tests if a specific existing contact can be retrieved correctly """
h = hashlib.sha384() contact_id = generate_id(b'node2')
h.update(b'node2') contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
contactID = h.digest()
contact = self.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.protocol)
# Now add it... # Now add it...
yield self.routingTable.addContact(contact) yield self.routingTable.addContact(contact)
# ...and get it again # ...and get it again
sameContact = self.routingTable.getContact(contactID) same_contact = self.routingTable.getContact(contact_id)
self.assertEqual(contact, sameContact, 'getContact() should return the same contact') self.assertEqual(contact, same_contact, 'getContact() should return the same contact')
@defer.inlineCallbacks @defer.inlineCallbacks
def testAddParentNodeAsContact(self): def test_add_parent_node_as_contact(self):
""" """
Tests the routing table's behaviour when attempting to add its parent node as a contact Tests the routing table's behaviour when attempting to add its parent node as a contact
""" """
# Create a contact with the same ID as the local node's ID # Create a contact with the same ID as the local node's ID
contact = self.contact_manager.make_contact(self.nodeID, '127.0.0.1', 9182, self.protocol) contact = self.contact_manager.make_contact(self.nodeID, '127.0.0.1', 9182, self.protocol)
# Now try to add it # Now try to add it
yield self.routingTable.addContact(contact) yield self.routingTable.addContact(contact)
# ...and request the closest nodes to it using FIND_NODE # ...and request the closest nodes to it using FIND_NODE
closestNodes = self.routingTable.findCloseNodes(self.nodeID, constants.k) closest_nodes = self.routingTable.findCloseNodes(self.nodeID, constants.k)
self.assertFalse(contact in closestNodes, 'Node added itself as a contact') self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact')
@defer.inlineCallbacks @defer.inlineCallbacks
def testRemoveContact(self): def test_remove_contact(self):
""" Tests contact removal """ """ Tests contact removal """
# Create the contact # Create the contact
h = hashlib.sha384() contact_id = generate_id(b'node2')
h.update(b'node2') contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
contactID = h.digest()
contact = self.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.protocol)
# Now add it... # Now add it...
yield self.routingTable.addContact(contact) yield self.routingTable.addContact(contact)
# Verify addition # Verify addition
@ -100,25 +82,22 @@ class TreeRoutingTableTest(unittest.TestCase):
self.assertEqual(len(self.routingTable._buckets[0]), 0, 'Contact not removed properly') self.assertEqual(len(self.routingTable._buckets[0]), 0, 'Contact not removed properly')
@defer.inlineCallbacks @defer.inlineCallbacks
def testSplitBucket(self): def test_split_bucket(self):
""" Tests if the the routing table correctly dynamically splits k-buckets """ """ Tests if the the routing table correctly dynamically splits k-buckets """
self.assertEqual(self.routingTable._buckets[0].rangeMax, 2**384, self.assertEqual(self.routingTable._buckets[0].rangeMax, 2**384,
'Initial k-bucket range should be 0 <= range < 2**384') 'Initial k-bucket range should be 0 <= range < 2**384')
# Add k contacts # Add k contacts
for i in range(constants.k): for i in range(constants.k):
h = hashlib.sha384() node_id = generate_id(b'remote node %d' % i)
h.update(b'remote node %d' % i) contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol)
nodeID = h.digest()
contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol)
yield self.routingTable.addContact(contact) yield self.routingTable.addContact(contact)
self.assertEqual(len(self.routingTable._buckets), 1, self.assertEqual(len(self.routingTable._buckets), 1,
'Only k nodes have been added; the first k-bucket should now ' 'Only k nodes have been added; the first k-bucket should now '
'be full, but should not yet be split') 'be full, but should not yet be split')
# Now add 1 more contact # Now add 1 more contact
h = hashlib.sha384() node_id = generate_id(b'yet another remote node')
h.update(b'yet another remote node') contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol)
nodeID = h.digest()
contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol)
yield self.routingTable.addContact(contact) yield self.routingTable.addContact(contact)
self.assertEqual(len(self.routingTable._buckets), 2, self.assertEqual(len(self.routingTable._buckets), 2,
'k+1 nodes have been added; the first k-bucket should have been ' 'k+1 nodes have been added; the first k-bucket should have been '
@ -134,7 +113,7 @@ class TreeRoutingTableTest(unittest.TestCase):
'not divided properly') 'not divided properly')
@defer.inlineCallbacks @defer.inlineCallbacks
def testFullSplit(self): def test_full_split(self):
""" """
Test that a bucket is not split if it is full, but the new contact is not closer than the kth closest contact Test that a bucket is not split if it is full, but the new contact is not closer than the kth closest contact
""" """