From bc24dbea2926a875999a791da86769cb2c24ee7a Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Mon, 30 Jul 2018 21:23:38 -0400 Subject: [PATCH] refactoring of DHT tests and fixed encoding bug when dealing with bytearray --- lbrynet/dht/encoding.py | 200 +++++++-------------- lbrynet/dht/protocol.py | 9 +- tests/functional/dht/mock_transport.py | 10 +- tests/unit/dht/test_encoding.py | 87 +++++----- tests/unit/dht/test_node.py | 232 ++++--------------------- tests/unit/dht/test_routingtable.py | 83 ++++----- 6 files changed, 179 insertions(+), 442 deletions(-) diff --git a/lbrynet/dht/encoding.py b/lbrynet/dht/encoding.py index 631c65e36..f31bd119f 100644 --- a/lbrynet/dht/encoding.py +++ b/lbrynet/dht/encoding.py @@ -1,141 +1,75 @@ -from __future__ import print_function from .error import DecodeError -import sys -if sys.version_info > (3,): - long = int - raw = ord -else: - raw = lambda x: x - -class Encoding: - """ Interface for RPC message encoders/decoders - - All encoding implementations used with this library should inherit and - implement this. - """ - - def encode(self, data): - """ Encode the specified data - - @param data: The data to encode - This method has to support encoding of the following - types: C{str}, C{int} and C{long} - Any additional data types may be supported as long as the - implementing class's C{decode()} method can successfully - decode them. - - @return: The encoded data - @rtype: str - """ - - def decode(self, data): - """ Decode the specified data string - - @param data: The data (byte string) to decode. - @type data: str - - @return: The decoded data (in its correct type) - """ -class Bencode(Encoding): - """ Implementation of a Bencode-based algorithm (Bencode is the encoding - algorithm used by Bittorrent). +def bencode(data): + """ Encoder implementation of the Bencode algorithm (Bittorrent). """ + if isinstance(data, int): + return b'i%de' % data + elif isinstance(data, (bytes, bytearray)): + return b'%d:%s' % (len(data), data) + elif isinstance(data, str): + return b'%d:%s' % (len(data), data.encode()) + elif isinstance(data, (list, tuple)): + encoded_list_items = b'' + for item in data: + encoded_list_items += bencode(item) + return b'l%se' % encoded_list_items + elif isinstance(data, dict): + encoded_dict_items = b'' + keys = data.keys() + for key in sorted(keys): + encoded_dict_items += bencode(key) + encoded_dict_items += bencode(data[key]) + return b'd%se' % encoded_dict_items + else: + raise TypeError("Cannot bencode '%s' object" % type(data)) - @note: This algorithm differs from the "official" Bencode algorithm in - that it can encode/decode floating point values in addition to - integers. - """ - def encode(self, data): - """ Encoder implementation of the Bencode algorithm +def bdecode(data): + """ Decoder implementation of the Bencode algorithm. """ + assert type(data) == bytes # fixme: _maybe_ remove this after porting + if len(data) == 0: + raise DecodeError('Cannot decode empty string') + try: + return _decode_recursive(data)[0] + except ValueError as e: + raise DecodeError(str(e)) - @param data: The data to encode - @type data: int, long, tuple, list, dict or str - @return: The encoded data - @rtype: str - """ - if isinstance(data, (int, long)): - return b'i%de' % data - elif isinstance(data, bytes): - return b'%d:%s' % (len(data), data) - elif isinstance(data, str): - return b'%d:' % (len(data)) + data.encode() - elif isinstance(data, (list, tuple)): - encodedListItems = b'' - for item in data: - encodedListItems += self.encode(item) - return b'l%se' % encodedListItems - elif isinstance(data, dict): - encodedDictItems = b'' - keys = data.keys() - for key in sorted(keys): - encodedDictItems += self.encode(key) # TODO: keys should always be bytestrings - encodedDictItems += self.encode(data[key]) - return b'd%se' % encodedDictItems - else: - raise TypeError("Cannot bencode '%s' object" % type(data)) - - def decode(self, data): - """ Decoder implementation of the Bencode algorithm - - @param data: The encoded data - @type data: str - - @note: This is a convenience wrapper for the recursive decoding - algorithm, C{_decodeRecursive} - - @return: The decoded data, as a native Python type - @rtype: int, list, dict or str - """ - assert type(data) == bytes # fixme: _maybe_ remove this after porting - if len(data) == 0: - raise DecodeError('Cannot decode empty string') +def _decode_recursive(data, start_index=0): + if data[start_index] == ord('i'): + end_pos = data[start_index:].find(b'e') + start_index + return int(data[start_index + 1:end_pos]), end_pos + 1 + elif data[start_index] == ord('l'): + start_index += 1 + decoded_list = [] + while data[start_index] != ord('e'): + list_data, start_index = _decode_recursive(data, start_index) + decoded_list.append(list_data) + return decoded_list, start_index + 1 + elif data[start_index] == ord('d'): + start_index += 1 + decoded_dict = {} + while data[start_index] != ord('e'): + key, start_index = _decode_recursive(data, start_index) + value, start_index = _decode_recursive(data, start_index) + decoded_dict[key] = value + return decoded_dict, start_index + elif data[start_index] == ord('f'): + # This (float data type) is a non-standard extension to the original Bencode algorithm + end_pos = data[start_index:].find(b'e') + start_index + return float(data[start_index + 1:end_pos]), end_pos + 1 + elif data[start_index] == ord('n'): + # This (None/NULL data type) is a non-standard extension + # to the original Bencode algorithm + return None, start_index + 1 + else: + split_pos = data[start_index:].find(b':') + start_index try: - return self._decodeRecursive(data)[0] - except ValueError as e: - raise DecodeError(e.message) - - @staticmethod - def _decodeRecursive(data, startIndex=0): - """ Actual implementation of the recursive Bencode algorithm - - Do not call this; use C{decode()} instead - """ - if data[startIndex] == raw('i'): - endPos = data[startIndex:].find(b'e') + startIndex - return int(data[startIndex + 1:endPos]), endPos + 1 - elif data[startIndex] == raw('l'): - startIndex += 1 - decodedList = [] - while data[startIndex] != raw('e'): - listData, startIndex = Bencode._decodeRecursive(data, startIndex) - decodedList.append(listData) - return decodedList, startIndex + 1 - elif data[startIndex] == raw('d'): - startIndex += 1 - decodedDict = {} - while data[startIndex] != raw('e'): - key, startIndex = Bencode._decodeRecursive(data, startIndex) - value, startIndex = Bencode._decodeRecursive(data, startIndex) - decodedDict[key] = value - return decodedDict, startIndex - elif data[startIndex] == raw('f'): - # This (float data type) is a non-standard extension to the original Bencode algorithm - endPos = data[startIndex:].find(b'e') + startIndex - return float(data[startIndex + 1:endPos]), endPos + 1 - elif data[startIndex] == raw('n'): - # This (None/NULL data type) is a non-standard extension - # to the original Bencode algorithm - return None, startIndex + 1 - else: - splitPos = data[startIndex:].find(b':') + startIndex - try: - length = int(data[startIndex:splitPos]) - except ValueError: - raise DecodeError() - startIndex = splitPos + 1 - endPos = startIndex + length - bytes = data[startIndex:endPos] - return bytes, endPos + length = int(data[start_index:split_pos]) + except ValueError: + raise DecodeError() + start_index = split_pos + 1 + end_pos = start_index + length + b = data[start_index:end_pos] + return b, end_pos diff --git a/lbrynet/dht/protocol.py b/lbrynet/dht/protocol.py index 341be18b9..01c3bddb5 100644 --- a/lbrynet/dht/protocol.py +++ b/lbrynet/dht/protocol.py @@ -97,7 +97,6 @@ class KademliaProtocol(protocol.DatagramProtocol): def __init__(self, node): self._node = node - self._encoder = encoding.Bencode() self._translator = msgformat.DefaultFormat() self._sentMessages = {} self._partialMessages = {} @@ -163,7 +162,7 @@ class KademliaProtocol(protocol.DatagramProtocol): msg = msgtypes.RequestMessage(self._node.node_id, method, self._migrate_outgoing_rpc_args(contact, method, *args)) msgPrimitive = self._translator.toPrimitive(msg) - encodedMsg = self._encoder.encode(msgPrimitive) + encodedMsg = encoding.bencode(msgPrimitive) if args: log.debug("%s:%i SEND CALL %s(%s) TO %s:%i", self._node.externalIP, self._node.port, method, @@ -237,7 +236,7 @@ class KademliaProtocol(protocol.DatagramProtocol): else: return try: - msgPrimitive = self._encoder.decode(datagram) + msgPrimitive = encoding.bdecode(datagram) message = self._translator.fromPrimitive(msgPrimitive) except (encoding.DecodeError, ValueError) as err: # We received some rubbish here @@ -394,7 +393,7 @@ class KademliaProtocol(protocol.DatagramProtocol): """ msg = msgtypes.ResponseMessage(rpcID, self._node.node_id, response) msgPrimitive = self._translator.toPrimitive(msg) - encodedMsg = self._encoder.encode(msgPrimitive) + encodedMsg = encoding.bencode(msgPrimitive) self._send(encodedMsg, rpcID, (contact.address, contact.port)) def _sendError(self, contact, rpcID, exceptionType, exceptionMessage): @@ -403,7 +402,7 @@ class KademliaProtocol(protocol.DatagramProtocol): exceptionMessage = exceptionMessage.encode() msg = msgtypes.ErrorMessage(rpcID, self._node.node_id, exceptionType, exceptionMessage) msgPrimitive = self._translator.toPrimitive(msg) - encodedMsg = self._encoder.encode(msgPrimitive) + encodedMsg = encoding.bencode(msgPrimitive) self._send(encodedMsg, rpcID, (contact.address, contact.port)) def _handleRPC(self, senderContact, rpcID, method, args): diff --git a/tests/functional/dht/mock_transport.py b/tests/functional/dht/mock_transport.py index a000b1773..ac006f6f4 100644 --- a/tests/functional/dht/mock_transport.py +++ b/tests/functional/dht/mock_transport.py @@ -1,18 +1,14 @@ import struct import hashlib import logging -from binascii import unhexlify, hexlify +from binascii import unhexlify from twisted.internet import defer, error -from lbrynet.dht.encoding import Bencode +from lbrynet.dht import encoding from lbrynet.dht.error import DecodeError from lbrynet.dht.msgformat import DefaultFormat from lbrynet.dht.msgtypes import ResponseMessage, RequestMessage, ErrorMessage -import sys -if sys.version_info > (3,): - unicode = str -_encode = Bencode() _datagram_formatter = DefaultFormat() log = logging.getLogger() @@ -138,7 +134,7 @@ def debug_kademlia_packet(data, source, destination, node): if log.level != logging.DEBUG: return try: - packet = _datagram_formatter.fromPrimitive(_encode.decode(data)) + packet = _datagram_formatter.fromPrimitive(encoding.bdecode(data)) if isinstance(packet, RequestMessage): log.debug("request %s --> %s %s (node time %s)", source[0], destination[0], packet.request, node.clock.seconds()) diff --git a/tests/unit/dht/test_encoding.py b/tests/unit/dht/test_encoding.py index bf968c04d..da29c67b1 100644 --- a/tests/unit/dht/test_encoding.py +++ b/tests/unit/dht/test_encoding.py @@ -1,47 +1,50 @@ -#!/usr/bin/env python -# -# This library is free software, distributed under the terms of -# the GNU Lesser General Public License Version 3, or any later version. -# See the COPYING file included in this archive - from twisted.trial import unittest -import lbrynet.dht.encoding +from lbrynet.dht.encoding import bencode, bdecode, DecodeError -class BencodeTest(unittest.TestCase): - """ Basic tests case for the Bencode implementation """ - def setUp(self): - self.encoding = lbrynet.dht.encoding.Bencode() - # Thanks goes to wikipedia for the initial test cases ;-) - self.cases = ((42, b'i42e'), - (b'spam', b'4:spam'), - ([b'spam', 42], b'l4:spami42ee'), - ({b'foo': 42, b'bar': b'spam'}, b'd3:bar4:spam3:fooi42ee'), - # ...and now the "real life" tests - ([[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]], - b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee')) - # The following test cases are "bad"; i.e. sending rubbish into the decoder to test - # what exceptions get thrown - self.badDecoderCases = (b'abcdefghijklmnopqrstuvwxyz', - b'') +class EncodeDecodeTest(unittest.TestCase): - def testEncoder(self): - """ Tests the bencode encoder """ - for value, encodedValue in self.cases: - result = self.encoding.encode(value) - self.assertEqual( - result, encodedValue, - 'Value "%s" not correctly encoded! Expected "%s", got "%s"' % - (value, encodedValue, result)) + def test_integer(self): + self.assertEqual(bencode(42), b'i42e') - def testDecoder(self): - """ Tests the bencode decoder """ - for value, encodedValue in self.cases: - result = self.encoding.decode(encodedValue) - self.assertEqual( - result, value, - 'Value "%s" not correctly decoded! Expected "%s", got "%s"' % - (encodedValue, value, result)) - for encodedValue in self.badDecoderCases: - self.assertRaises( - lbrynet.dht.encoding.DecodeError, self.encoding.decode, encodedValue) + self.assertEqual(bdecode(b'i42e'), 42) + + def test_bytes(self): + self.assertEqual(bencode(b''), b'0:') + self.assertEqual(bencode(b'spam'), b'4:spam') + self.assertEqual(bencode(b'4:spam'), b'6:4:spam') + self.assertEqual(bencode(bytearray(b'spam')), b'4:spam') + + self.assertEqual(bdecode(b'0:'), b'') + self.assertEqual(bdecode(b'4:spam'), b'spam') + self.assertEqual(bdecode(b'6:4:spam'), b'4:spam') + + def test_string(self): + self.assertEqual(bencode(''), b'0:') + self.assertEqual(bencode('spam'), b'4:spam') + self.assertEqual(bencode('4:spam'), b'6:4:spam') + + def test_list(self): + self.assertEqual(bencode([b'spam', 42]), b'l4:spami42ee') + + self.assertEqual(bdecode(b'l4:spami42ee'), [b'spam', 42]) + + def test_dict(self): + self.assertEqual(bencode({b'foo': 42, b'bar': b'spam'}), b'd3:bar4:spam3:fooi42ee') + + self.assertEqual(bdecode(b'd3:bar4:spam3:fooi42ee'), {b'foo': 42, b'bar': b'spam'}) + + def test_mixed(self): + self.assertEqual(bencode( + [[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]), + b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee' + ) + + self.assertEqual(bdecode( + b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'), + [[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]] + ) + + def test_decode_error(self): + self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz') + self.assertRaises(DecodeError, bdecode, b'') diff --git a/tests/unit/dht/test_node.py b/tests/unit/dht/test_node.py index 8fe1b3378..0d6e2e232 100644 --- a/tests/unit/dht/test_node.py +++ b/tests/unit/dht/test_node.py @@ -1,56 +1,40 @@ -#!/usr/bin/env python -# -# This library is free software, distributed under the terms of -# the GNU Lesser General Public License Version 3, or any later version. -# See the COPYING file included in this archive - import hashlib -from twisted.trial import unittest import struct +from twisted.trial import unittest from twisted.internet import defer from lbrynet.dht.node import Node from lbrynet.dht import constants +from lbrynet.core.utils import generate_id class NodeIDTest(unittest.TestCase): - """ Test case for the Node class's ID """ + def setUp(self): self.node = Node() - def testAutoCreatedID(self): - """ Tests if a new node has a valid node ID """ - self.assertEqual(type(self.node.node_id), bytes, 'Node does not have a valid ID') - self.assertEqual(len(self.node.node_id), 48, 'Node ID length is incorrect! ' - 'Expected 384 bits, got %d bits.' % - (len(self.node.node_id) * 8)) + def test_new_node_has_auto_created_id(self): + self.assertEqual(type(self.node.node_id), bytes) + self.assertEqual(len(self.node.node_id), 48) - def testUniqueness(self): - """ Tests the uniqueness of the values created by the NodeID generator """ - generatedIDs = [] + def test_uniqueness_and_length_of_generated_ids(self): + previous_ids = [] for i in range(100): - newID = self.node._generateID() - # ugly uniqueness test - self.assertFalse(newID in generatedIDs, 'Generated ID #%d not unique!' % (i+1)) - generatedIDs.append(newID) - - def testKeyLength(self): - """ Tests the key Node ID key length """ - for i in range(20): - id = self.node._generateID() - # Key length: 20 bytes == 160 bits - self.assertEqual(len(id), 48, - 'Length of generated ID is incorrect! Expected 384 bits, ' - 'got %d bits.' % (len(id)*8)) + new_id = self.node._generateID() + self.assertNotIn(new_id, previous_ids, 'id at index {} not unique'.format(i)) + self.assertEqual(len(new_id), 48, 'id at index {} wrong length: {}'.format(i, len(new_id))) + previous_ids.append(new_id) class NodeDataTest(unittest.TestCase): """ Test case for the Node class's data-related functions """ + def setUp(self): h = hashlib.sha384() h.update(b'test') self.node = Node() - self.contact = self.node.contact_manager.make_contact(h.digest(), '127.0.0.1', 12345, self.node._protocol) + self.contact = self.node.contact_manager.make_contact( + h.digest(), '127.0.0.1', 12345, self.node._protocol) self.token = self.node.make_token(self.contact.compact_ip()) self.cases = [] for i in range(5): @@ -59,19 +43,18 @@ class NodeDataTest(unittest.TestCase): self.cases.append((h.digest(), 5001+2*i)) @defer.inlineCallbacks - def testStore(self): + def test_store(self): """ Tests if the node can store (and privately retrieve) some data """ for key, port in self.cases: - yield self.node.store( # pylint: disable=too-many-function-args + yield self.node.store( self.contact, key, self.token, port, self.contact.id, 0 ) for key, value in self.cases: - expected_result = self.contact.compact_ip() + struct.pack('>H', value) + \ - self.contact.id + expected_result = self.contact.compact_ip() + struct.pack('>H', value) + self.contact.id self.assertTrue(self.node._dataStore.hasPeersForBlob(key), - 'Stored key not found in node\'s DataStore: "%s"' % key) + "Stored key not found in node's DataStore: '%s'" % key) self.assertTrue(expected_result in self.node._dataStore.getPeersForBlob(key), - 'Stored val not found in node\'s DataStore: key:"%s" port:"%s" %s' + "Stored val not found in node's DataStore: key:'%s' port:'%s' %s" % (key, value, self.node._dataStore.getPeersForBlob(key))) @@ -81,182 +64,25 @@ class NodeContactTest(unittest.TestCase): self.node = Node() @defer.inlineCallbacks - def testAddContact(self): + def test_add_contact(self): """ Tests if a contact can be added and retrieved correctly """ # Create the contact - h = hashlib.sha384() - h.update(b'node1') - contactID = h.digest() - contact = self.node.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.node._protocol) + contact_id = generate_id(b'node1') + contact = self.node.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.node._protocol) # Now add it... yield self.node.addContact(contact) # ...and request the closest nodes to it using FIND_NODE - closestNodes = self.node._routingTable.findCloseNodes(contactID, constants.k) - self.assertEqual(len(closestNodes), 1, 'Wrong amount of contacts returned; ' - 'expected 1, got %d' % len(closestNodes)) - self.assertTrue(contact in closestNodes, 'Added contact not found by issueing ' - '_findCloseNodes()') + closest_nodes = self.node._routingTable.findCloseNodes(contact_id, constants.k) + self.assertEqual(len(closest_nodes), 1) + self.assertIn(contact, closest_nodes) @defer.inlineCallbacks - def testAddSelfAsContact(self): + def test_add_self_as_contact(self): """ Tests the node's behaviour when attempting to add itself as a contact """ # Create a contact with the same ID as the local node's ID contact = self.node.contact_manager.make_contact(self.node.node_id, '127.0.0.1', 9182, None) # Now try to add it yield self.node.addContact(contact) # ...and request the closest nodes to it using FIND_NODE - closestNodes = self.node._routingTable.findCloseNodes(self.node.node_id, - constants.k) - self.assertFalse(contact in closestNodes, 'Node added itself as a contact') - - -# class FakeRPCProtocol(protocol.DatagramProtocol): -# def __init__(self): -# self.reactor = selectreactor.SelectReactor() -# self.testResponse = None -# self.network = None -# -# def createNetwork(self, contactNetwork): -# """ -# set up a list of contacts together with their closest contacts -# @param contactNetwork: a sequence of tuples, each containing a contact together with its -# closest contacts: C{(, )} -# """ -# self.network = contactNetwork -# -# def sendRPC(self, contact, method, args, rawResponse=False): -# """ Fake RPC protocol; allows entangled.kademlia.contact.Contact objects to "send" RPCs""" -# -# h = hashlib.sha384() -# h.update('rpcId') -# rpc_id = h.digest()[:20] -# -# if method == "findNode": -# # get the specific contacts closest contacts -# closestContacts = [] -# closestContactsList = [] -# for contactTuple in self.network: -# if contact == contactTuple[0]: -# # get the list of closest contacts for this contact -# closestContactsList = contactTuple[1] -# # Pack the closest contacts into a ResponseMessage -# for closeContact in closestContactsList: -# closestContacts.append((closeContact.id, closeContact.address, closeContact.port)) -# -# message = ResponseMessage(rpc_id, contact.id, closestContacts) -# df = defer.Deferred() -# df.callback((message, (contact.address, contact.port))) -# return df -# elif method == "findValue": -# for contactTuple in self.network: -# if contact == contactTuple[0]: -# # Get the data stored by this remote contact -# dataDict = contactTuple[2] -# dataKey = dataDict.keys()[0] -# data = dataDict.get(dataKey) -# # Check if this contact has the requested value -# if dataKey == args[0]: -# # Return the data value -# response = dataDict -# print "data found at contact: " + contact.id -# else: -# # Return the closest contact to the requested data key -# print "data not found at contact: " + contact.id -# closeContacts = contactTuple[1] -# closestContacts = [] -# for closeContact in closeContacts: -# closestContacts.append((closeContact.id, closeContact.address, -# closeContact.port)) -# response = closestContacts -# -# # Create the response message -# message = ResponseMessage(rpc_id, contact.id, response) -# df = defer.Deferred() -# df.callback((message, (contact.address, contact.port))) -# return df -# -# def _send(self, data, rpcID, address): -# """ fake sending data """ -# -# -# class NodeLookupTest(unittest.TestCase): -# """ Test case for the Node class's iterativeFind node lookup algorithm """ -# -# def setUp(self): -# # create a fake protocol to imitate communication with other nodes -# self._protocol = FakeRPCProtocol() -# # Note: The reactor is never started for this test. All deferred calls run sequentially, -# # since there is no asynchronous network communication -# # create the node to be tested in isolation -# h = hashlib.sha384() -# h.update('node1') -# node_id = str(h.digest()) -# self.node = Node(node_id, 4000, None, None, self._protocol) -# self.updPort = 81173 -# self.contactsAmount = 80 -# # Reinitialise the routing table -# self.node._routingTable = TreeRoutingTable(self.node.node_id) -# -# # create 160 bit node ID's for test purposes -# self.testNodeIDs = [] -# idNum = int(self.node.node_id.encode('hex'), 16) -# for i in range(self.contactsAmount): -# # create the testNodeIDs in ascending order, away from the actual node ID, -# # with regards to the distance metric -# self.testNodeIDs.append(str("%X" % (idNum + i + 1)).decode('hex')) -# -# # generate contacts -# self.contacts = [] -# for i in range(self.contactsAmount): -# contact = self.node.contact_manager.make_contact(self.testNodeIDs[i], "127.0.0.1", -# self.updPort + i + 1, self._protocol) -# self.contacts.append(contact) -# -# # create the network of contacts in format: (contact, closest contacts) -# contactNetwork = ((self.contacts[0], self.contacts[8:15]), -# (self.contacts[1], self.contacts[16:23]), -# (self.contacts[2], self.contacts[24:31]), -# (self.contacts[3], self.contacts[32:39]), -# (self.contacts[4], self.contacts[40:47]), -# (self.contacts[5], self.contacts[48:55]), -# (self.contacts[6], self.contacts[56:63]), -# (self.contacts[7], self.contacts[64:71]), -# (self.contacts[8], self.contacts[72:79]), -# (self.contacts[40], self.contacts[41:48]), -# (self.contacts[41], self.contacts[41:48]), -# (self.contacts[42], self.contacts[41:48]), -# (self.contacts[43], self.contacts[41:48]), -# (self.contacts[44], self.contacts[41:48]), -# (self.contacts[45], self.contacts[41:48]), -# (self.contacts[46], self.contacts[41:48]), -# (self.contacts[47], self.contacts[41:48]), -# (self.contacts[48], self.contacts[41:48]), -# (self.contacts[50], self.contacts[0:7]), -# (self.contacts[51], self.contacts[8:15]), -# (self.contacts[52], self.contacts[16:23])) -# -# contacts_with_datastores = [] -# -# for contact_tuple in contactNetwork: -# contacts_with_datastores.append((contact_tuple[0], contact_tuple[1], -# DictDataStore())) -# self._protocol.createNetwork(contacts_with_datastores) -# -# # @defer.inlineCallbacks -# # def testNodeBootStrap(self): -# # """ Test bootstrap with the closest possible contacts """ -# # # Set the expected result -# # expectedResult = {item.id for item in self.contacts[0:8]} -# # -# # activeContacts = yield self.node._iterativeFind(self.node.node_id, self.contacts[0:8]) -# # -# # # Check the length of the active contacts -# # self.failUnlessEqual(activeContacts.__len__(), expectedResult.__len__(), -# # "More active contacts should exist, there should be %d " -# # "contacts but there are %d" % (len(expectedResult), -# # len(activeContacts))) -# # -# # # Check that the received active contacts are the same as the input contacts -# # self.failUnlessEqual({contact.id for contact in activeContacts}, expectedResult, -# # "Active should only contain the closest possible contacts" -# # " which were used as input for the boostrap") + closest_nodes = self.node._routingTable.findCloseNodes(self.node.node_id, constants.k) + self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact.') diff --git a/tests/unit/dht/test_routingtable.py b/tests/unit/dht/test_routingtable.py index fadd3ddef..e29477ebc 100644 --- a/tests/unit/dht/test_routingtable.py +++ b/tests/unit/dht/test_routingtable.py @@ -1,4 +1,3 @@ -import hashlib from binascii import hexlify, unhexlify from twisted.trial import unittest @@ -7,9 +6,7 @@ from lbrynet.dht import constants from lbrynet.dht.routingtable import TreeRoutingTable from lbrynet.dht.contact import ContactManager from lbrynet.dht.distance import Distance -import sys -if sys.version_info > (3,): - long = int +from lbrynet.core.utils import generate_id class FakeRPCProtocol(object): @@ -21,76 +18,61 @@ class FakeRPCProtocol(object): class TreeRoutingTableTest(unittest.TestCase): """ Test case for the RoutingTable class """ def setUp(self): - h = hashlib.sha384() - h.update(b'node1') self.contact_manager = ContactManager() - self.nodeID = h.digest() + self.nodeID = generate_id(b'node1') self.protocol = FakeRPCProtocol() self.routingTable = TreeRoutingTable(self.nodeID) - def testDistance(self): + def test_distance(self): """ Test to see if distance method returns correct result""" - - # testList holds a couple 3-tuple (variable1, variable2, result) - basicTestList = [(bytes(b'\xaa' * 48), bytes(b'\x55' * 48), long(hexlify(bytes(b'\xff' * 48)), 16))] - - for test in basicTestList: - result = Distance(test[0])(test[1]) - self.assertFalse(result != test[2], 'Result of _distance() should be %s but %s returned' % - (test[2], result)) + d = Distance(bytes((170,) * 48)) + result = d(bytes((85,) * 48)) + expected = int(hexlify(bytes((255,) * 48)), 16) + self.assertEqual(result, expected) @defer.inlineCallbacks - def testAddContact(self): + def test_add_contact(self): """ Tests if a contact can be added and retrieved correctly """ # Create the contact - h = hashlib.sha384() - h.update(b'node2') - contactID = h.digest() - contact = self.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.protocol) + contact_id = generate_id(b'node2') + contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol) # Now add it... yield self.routingTable.addContact(contact) # ...and request the closest nodes to it (will retrieve it) - closestNodes = self.routingTable.findCloseNodes(contactID) - self.assertEqual(len(closestNodes), 1, 'Wrong amount of contacts returned; expected 1,' - ' got %d' % len(closestNodes)) - self.assertTrue(contact in closestNodes, 'Added contact not found by issueing ' - '_findCloseNodes()') + closest_nodes = self.routingTable.findCloseNodes(contact_id) + self.assertEqual(len(closest_nodes), 1) + self.assertIn(contact, closest_nodes) @defer.inlineCallbacks - def testGetContact(self): + def test_get_contact(self): """ Tests if a specific existing contact can be retrieved correctly """ - h = hashlib.sha384() - h.update(b'node2') - contactID = h.digest() - contact = self.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.protocol) + contact_id = generate_id(b'node2') + contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol) # Now add it... yield self.routingTable.addContact(contact) # ...and get it again - sameContact = self.routingTable.getContact(contactID) - self.assertEqual(contact, sameContact, 'getContact() should return the same contact') + same_contact = self.routingTable.getContact(contact_id) + self.assertEqual(contact, same_contact, 'getContact() should return the same contact') @defer.inlineCallbacks - def testAddParentNodeAsContact(self): + def test_add_parent_node_as_contact(self): """ Tests the routing table's behaviour when attempting to add its parent node as a contact """ - # Create a contact with the same ID as the local node's ID contact = self.contact_manager.make_contact(self.nodeID, '127.0.0.1', 9182, self.protocol) # Now try to add it yield self.routingTable.addContact(contact) # ...and request the closest nodes to it using FIND_NODE - closestNodes = self.routingTable.findCloseNodes(self.nodeID, constants.k) - self.assertFalse(contact in closestNodes, 'Node added itself as a contact') + closest_nodes = self.routingTable.findCloseNodes(self.nodeID, constants.k) + self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact') @defer.inlineCallbacks - def testRemoveContact(self): + def test_remove_contact(self): """ Tests contact removal """ # Create the contact - h = hashlib.sha384() - h.update(b'node2') - contactID = h.digest() - contact = self.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.protocol) + contact_id = generate_id(b'node2') + contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol) # Now add it... yield self.routingTable.addContact(contact) # Verify addition @@ -100,25 +82,22 @@ class TreeRoutingTableTest(unittest.TestCase): self.assertEqual(len(self.routingTable._buckets[0]), 0, 'Contact not removed properly') @defer.inlineCallbacks - def testSplitBucket(self): + def test_split_bucket(self): """ Tests if the the routing table correctly dynamically splits k-buckets """ self.assertEqual(self.routingTable._buckets[0].rangeMax, 2**384, 'Initial k-bucket range should be 0 <= range < 2**384') # Add k contacts for i in range(constants.k): - h = hashlib.sha384() - h.update(b'remote node %d' % i) - nodeID = h.digest() - contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol) + node_id = generate_id(b'remote node %d' % i) + contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol) yield self.routingTable.addContact(contact) + self.assertEqual(len(self.routingTable._buckets), 1, 'Only k nodes have been added; the first k-bucket should now ' 'be full, but should not yet be split') # Now add 1 more contact - h = hashlib.sha384() - h.update(b'yet another remote node') - nodeID = h.digest() - contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol) + node_id = generate_id(b'yet another remote node') + contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol) yield self.routingTable.addContact(contact) self.assertEqual(len(self.routingTable._buckets), 2, 'k+1 nodes have been added; the first k-bucket should have been ' @@ -134,7 +113,7 @@ class TreeRoutingTableTest(unittest.TestCase): 'not divided properly') @defer.inlineCallbacks - def testFullSplit(self): + def test_full_split(self): """ Test that a bucket is not split if it is full, but the new contact is not closer than the kth closest contact """