refactoring of DHT tests and fixed encoding bug when dealing with bytearray
This commit is contained in:
parent
2d4bf73632
commit
bc24dbea29
6 changed files with 179 additions and 442 deletions
|
@ -1,141 +1,75 @@
|
|||
from __future__ import print_function
|
||||
from .error import DecodeError
|
||||
import sys
|
||||
if sys.version_info > (3,):
|
||||
long = int
|
||||
raw = ord
|
||||
else:
|
||||
raw = lambda x: x
|
||||
|
||||
class Encoding:
|
||||
""" Interface for RPC message encoders/decoders
|
||||
|
||||
All encoding implementations used with this library should inherit and
|
||||
implement this.
|
||||
"""
|
||||
|
||||
def encode(self, data):
|
||||
""" Encode the specified data
|
||||
|
||||
@param data: The data to encode
|
||||
This method has to support encoding of the following
|
||||
types: C{str}, C{int} and C{long}
|
||||
Any additional data types may be supported as long as the
|
||||
implementing class's C{decode()} method can successfully
|
||||
decode them.
|
||||
|
||||
@return: The encoded data
|
||||
@rtype: str
|
||||
"""
|
||||
|
||||
def decode(self, data):
|
||||
""" Decode the specified data string
|
||||
|
||||
@param data: The data (byte string) to decode.
|
||||
@type data: str
|
||||
|
||||
@return: The decoded data (in its correct type)
|
||||
"""
|
||||
|
||||
|
||||
class Bencode(Encoding):
|
||||
""" Implementation of a Bencode-based algorithm (Bencode is the encoding
|
||||
algorithm used by Bittorrent).
|
||||
|
||||
@note: This algorithm differs from the "official" Bencode algorithm in
|
||||
that it can encode/decode floating point values in addition to
|
||||
integers.
|
||||
"""
|
||||
|
||||
def encode(self, data):
|
||||
""" Encoder implementation of the Bencode algorithm
|
||||
|
||||
@param data: The data to encode
|
||||
@type data: int, long, tuple, list, dict or str
|
||||
|
||||
@return: The encoded data
|
||||
@rtype: str
|
||||
"""
|
||||
if isinstance(data, (int, long)):
|
||||
def bencode(data):
|
||||
""" Encoder implementation of the Bencode algorithm (Bittorrent). """
|
||||
if isinstance(data, int):
|
||||
return b'i%de' % data
|
||||
elif isinstance(data, bytes):
|
||||
elif isinstance(data, (bytes, bytearray)):
|
||||
return b'%d:%s' % (len(data), data)
|
||||
elif isinstance(data, str):
|
||||
return b'%d:' % (len(data)) + data.encode()
|
||||
return b'%d:%s' % (len(data), data.encode())
|
||||
elif isinstance(data, (list, tuple)):
|
||||
encodedListItems = b''
|
||||
encoded_list_items = b''
|
||||
for item in data:
|
||||
encodedListItems += self.encode(item)
|
||||
return b'l%se' % encodedListItems
|
||||
encoded_list_items += bencode(item)
|
||||
return b'l%se' % encoded_list_items
|
||||
elif isinstance(data, dict):
|
||||
encodedDictItems = b''
|
||||
encoded_dict_items = b''
|
||||
keys = data.keys()
|
||||
for key in sorted(keys):
|
||||
encodedDictItems += self.encode(key) # TODO: keys should always be bytestrings
|
||||
encodedDictItems += self.encode(data[key])
|
||||
return b'd%se' % encodedDictItems
|
||||
encoded_dict_items += bencode(key)
|
||||
encoded_dict_items += bencode(data[key])
|
||||
return b'd%se' % encoded_dict_items
|
||||
else:
|
||||
raise TypeError("Cannot bencode '%s' object" % type(data))
|
||||
|
||||
def decode(self, data):
|
||||
""" Decoder implementation of the Bencode algorithm
|
||||
|
||||
@param data: The encoded data
|
||||
@type data: str
|
||||
|
||||
@note: This is a convenience wrapper for the recursive decoding
|
||||
algorithm, C{_decodeRecursive}
|
||||
|
||||
@return: The decoded data, as a native Python type
|
||||
@rtype: int, list, dict or str
|
||||
"""
|
||||
def bdecode(data):
|
||||
""" Decoder implementation of the Bencode algorithm. """
|
||||
assert type(data) == bytes # fixme: _maybe_ remove this after porting
|
||||
if len(data) == 0:
|
||||
raise DecodeError('Cannot decode empty string')
|
||||
try:
|
||||
return self._decodeRecursive(data)[0]
|
||||
return _decode_recursive(data)[0]
|
||||
except ValueError as e:
|
||||
raise DecodeError(e.message)
|
||||
raise DecodeError(str(e))
|
||||
|
||||
@staticmethod
|
||||
def _decodeRecursive(data, startIndex=0):
|
||||
""" Actual implementation of the recursive Bencode algorithm
|
||||
|
||||
Do not call this; use C{decode()} instead
|
||||
"""
|
||||
if data[startIndex] == raw('i'):
|
||||
endPos = data[startIndex:].find(b'e') + startIndex
|
||||
return int(data[startIndex + 1:endPos]), endPos + 1
|
||||
elif data[startIndex] == raw('l'):
|
||||
startIndex += 1
|
||||
decodedList = []
|
||||
while data[startIndex] != raw('e'):
|
||||
listData, startIndex = Bencode._decodeRecursive(data, startIndex)
|
||||
decodedList.append(listData)
|
||||
return decodedList, startIndex + 1
|
||||
elif data[startIndex] == raw('d'):
|
||||
startIndex += 1
|
||||
decodedDict = {}
|
||||
while data[startIndex] != raw('e'):
|
||||
key, startIndex = Bencode._decodeRecursive(data, startIndex)
|
||||
value, startIndex = Bencode._decodeRecursive(data, startIndex)
|
||||
decodedDict[key] = value
|
||||
return decodedDict, startIndex
|
||||
elif data[startIndex] == raw('f'):
|
||||
def _decode_recursive(data, start_index=0):
|
||||
if data[start_index] == ord('i'):
|
||||
end_pos = data[start_index:].find(b'e') + start_index
|
||||
return int(data[start_index + 1:end_pos]), end_pos + 1
|
||||
elif data[start_index] == ord('l'):
|
||||
start_index += 1
|
||||
decoded_list = []
|
||||
while data[start_index] != ord('e'):
|
||||
list_data, start_index = _decode_recursive(data, start_index)
|
||||
decoded_list.append(list_data)
|
||||
return decoded_list, start_index + 1
|
||||
elif data[start_index] == ord('d'):
|
||||
start_index += 1
|
||||
decoded_dict = {}
|
||||
while data[start_index] != ord('e'):
|
||||
key, start_index = _decode_recursive(data, start_index)
|
||||
value, start_index = _decode_recursive(data, start_index)
|
||||
decoded_dict[key] = value
|
||||
return decoded_dict, start_index
|
||||
elif data[start_index] == ord('f'):
|
||||
# This (float data type) is a non-standard extension to the original Bencode algorithm
|
||||
endPos = data[startIndex:].find(b'e') + startIndex
|
||||
return float(data[startIndex + 1:endPos]), endPos + 1
|
||||
elif data[startIndex] == raw('n'):
|
||||
end_pos = data[start_index:].find(b'e') + start_index
|
||||
return float(data[start_index + 1:end_pos]), end_pos + 1
|
||||
elif data[start_index] == ord('n'):
|
||||
# This (None/NULL data type) is a non-standard extension
|
||||
# to the original Bencode algorithm
|
||||
return None, startIndex + 1
|
||||
return None, start_index + 1
|
||||
else:
|
||||
splitPos = data[startIndex:].find(b':') + startIndex
|
||||
split_pos = data[start_index:].find(b':') + start_index
|
||||
try:
|
||||
length = int(data[startIndex:splitPos])
|
||||
length = int(data[start_index:split_pos])
|
||||
except ValueError:
|
||||
raise DecodeError()
|
||||
startIndex = splitPos + 1
|
||||
endPos = startIndex + length
|
||||
bytes = data[startIndex:endPos]
|
||||
return bytes, endPos
|
||||
start_index = split_pos + 1
|
||||
end_pos = start_index + length
|
||||
b = data[start_index:end_pos]
|
||||
return b, end_pos
|
||||
|
|
|
@ -97,7 +97,6 @@ class KademliaProtocol(protocol.DatagramProtocol):
|
|||
|
||||
def __init__(self, node):
|
||||
self._node = node
|
||||
self._encoder = encoding.Bencode()
|
||||
self._translator = msgformat.DefaultFormat()
|
||||
self._sentMessages = {}
|
||||
self._partialMessages = {}
|
||||
|
@ -163,7 +162,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
|
|||
msg = msgtypes.RequestMessage(self._node.node_id, method, self._migrate_outgoing_rpc_args(contact, method,
|
||||
*args))
|
||||
msgPrimitive = self._translator.toPrimitive(msg)
|
||||
encodedMsg = self._encoder.encode(msgPrimitive)
|
||||
encodedMsg = encoding.bencode(msgPrimitive)
|
||||
|
||||
if args:
|
||||
log.debug("%s:%i SEND CALL %s(%s) TO %s:%i", self._node.externalIP, self._node.port, method,
|
||||
|
@ -237,7 +236,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
|
|||
else:
|
||||
return
|
||||
try:
|
||||
msgPrimitive = self._encoder.decode(datagram)
|
||||
msgPrimitive = encoding.bdecode(datagram)
|
||||
message = self._translator.fromPrimitive(msgPrimitive)
|
||||
except (encoding.DecodeError, ValueError) as err:
|
||||
# We received some rubbish here
|
||||
|
@ -394,7 +393,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
|
|||
"""
|
||||
msg = msgtypes.ResponseMessage(rpcID, self._node.node_id, response)
|
||||
msgPrimitive = self._translator.toPrimitive(msg)
|
||||
encodedMsg = self._encoder.encode(msgPrimitive)
|
||||
encodedMsg = encoding.bencode(msgPrimitive)
|
||||
self._send(encodedMsg, rpcID, (contact.address, contact.port))
|
||||
|
||||
def _sendError(self, contact, rpcID, exceptionType, exceptionMessage):
|
||||
|
@ -403,7 +402,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
|
|||
exceptionMessage = exceptionMessage.encode()
|
||||
msg = msgtypes.ErrorMessage(rpcID, self._node.node_id, exceptionType, exceptionMessage)
|
||||
msgPrimitive = self._translator.toPrimitive(msg)
|
||||
encodedMsg = self._encoder.encode(msgPrimitive)
|
||||
encodedMsg = encoding.bencode(msgPrimitive)
|
||||
self._send(encodedMsg, rpcID, (contact.address, contact.port))
|
||||
|
||||
def _handleRPC(self, senderContact, rpcID, method, args):
|
||||
|
|
|
@ -1,18 +1,14 @@
|
|||
import struct
|
||||
import hashlib
|
||||
import logging
|
||||
from binascii import unhexlify, hexlify
|
||||
from binascii import unhexlify
|
||||
|
||||
from twisted.internet import defer, error
|
||||
from lbrynet.dht.encoding import Bencode
|
||||
from lbrynet.dht import encoding
|
||||
from lbrynet.dht.error import DecodeError
|
||||
from lbrynet.dht.msgformat import DefaultFormat
|
||||
from lbrynet.dht.msgtypes import ResponseMessage, RequestMessage, ErrorMessage
|
||||
import sys
|
||||
if sys.version_info > (3,):
|
||||
unicode = str
|
||||
|
||||
_encode = Bencode()
|
||||
_datagram_formatter = DefaultFormat()
|
||||
|
||||
log = logging.getLogger()
|
||||
|
@ -138,7 +134,7 @@ def debug_kademlia_packet(data, source, destination, node):
|
|||
if log.level != logging.DEBUG:
|
||||
return
|
||||
try:
|
||||
packet = _datagram_formatter.fromPrimitive(_encode.decode(data))
|
||||
packet = _datagram_formatter.fromPrimitive(encoding.bdecode(data))
|
||||
if isinstance(packet, RequestMessage):
|
||||
log.debug("request %s --> %s %s (node time %s)", source[0], destination[0], packet.request,
|
||||
node.clock.seconds())
|
||||
|
|
|
@ -1,47 +1,50 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# This library is free software, distributed under the terms of
|
||||
# the GNU Lesser General Public License Version 3, or any later version.
|
||||
# See the COPYING file included in this archive
|
||||
|
||||
from twisted.trial import unittest
|
||||
import lbrynet.dht.encoding
|
||||
from lbrynet.dht.encoding import bencode, bdecode, DecodeError
|
||||
|
||||
|
||||
class BencodeTest(unittest.TestCase):
|
||||
""" Basic tests case for the Bencode implementation """
|
||||
def setUp(self):
|
||||
self.encoding = lbrynet.dht.encoding.Bencode()
|
||||
# Thanks goes to wikipedia for the initial test cases ;-)
|
||||
self.cases = ((42, b'i42e'),
|
||||
(b'spam', b'4:spam'),
|
||||
([b'spam', 42], b'l4:spami42ee'),
|
||||
({b'foo': 42, b'bar': b'spam'}, b'd3:bar4:spam3:fooi42ee'),
|
||||
# ...and now the "real life" tests
|
||||
([[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]],
|
||||
b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'))
|
||||
# The following test cases are "bad"; i.e. sending rubbish into the decoder to test
|
||||
# what exceptions get thrown
|
||||
self.badDecoderCases = (b'abcdefghijklmnopqrstuvwxyz',
|
||||
b'')
|
||||
class EncodeDecodeTest(unittest.TestCase):
|
||||
|
||||
def testEncoder(self):
|
||||
""" Tests the bencode encoder """
|
||||
for value, encodedValue in self.cases:
|
||||
result = self.encoding.encode(value)
|
||||
self.assertEqual(
|
||||
result, encodedValue,
|
||||
'Value "%s" not correctly encoded! Expected "%s", got "%s"' %
|
||||
(value, encodedValue, result))
|
||||
def test_integer(self):
|
||||
self.assertEqual(bencode(42), b'i42e')
|
||||
|
||||
def testDecoder(self):
|
||||
""" Tests the bencode decoder """
|
||||
for value, encodedValue in self.cases:
|
||||
result = self.encoding.decode(encodedValue)
|
||||
self.assertEqual(
|
||||
result, value,
|
||||
'Value "%s" not correctly decoded! Expected "%s", got "%s"' %
|
||||
(encodedValue, value, result))
|
||||
for encodedValue in self.badDecoderCases:
|
||||
self.assertRaises(
|
||||
lbrynet.dht.encoding.DecodeError, self.encoding.decode, encodedValue)
|
||||
self.assertEqual(bdecode(b'i42e'), 42)
|
||||
|
||||
def test_bytes(self):
|
||||
self.assertEqual(bencode(b''), b'0:')
|
||||
self.assertEqual(bencode(b'spam'), b'4:spam')
|
||||
self.assertEqual(bencode(b'4:spam'), b'6:4:spam')
|
||||
self.assertEqual(bencode(bytearray(b'spam')), b'4:spam')
|
||||
|
||||
self.assertEqual(bdecode(b'0:'), b'')
|
||||
self.assertEqual(bdecode(b'4:spam'), b'spam')
|
||||
self.assertEqual(bdecode(b'6:4:spam'), b'4:spam')
|
||||
|
||||
def test_string(self):
|
||||
self.assertEqual(bencode(''), b'0:')
|
||||
self.assertEqual(bencode('spam'), b'4:spam')
|
||||
self.assertEqual(bencode('4:spam'), b'6:4:spam')
|
||||
|
||||
def test_list(self):
|
||||
self.assertEqual(bencode([b'spam', 42]), b'l4:spami42ee')
|
||||
|
||||
self.assertEqual(bdecode(b'l4:spami42ee'), [b'spam', 42])
|
||||
|
||||
def test_dict(self):
|
||||
self.assertEqual(bencode({b'foo': 42, b'bar': b'spam'}), b'd3:bar4:spam3:fooi42ee')
|
||||
|
||||
self.assertEqual(bdecode(b'd3:bar4:spam3:fooi42ee'), {b'foo': 42, b'bar': b'spam'})
|
||||
|
||||
def test_mixed(self):
|
||||
self.assertEqual(bencode(
|
||||
[[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]),
|
||||
b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'
|
||||
)
|
||||
|
||||
self.assertEqual(bdecode(
|
||||
b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'),
|
||||
[[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]
|
||||
)
|
||||
|
||||
def test_decode_error(self):
|
||||
self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz')
|
||||
self.assertRaises(DecodeError, bdecode, b'')
|
||||
|
|
|
@ -1,56 +1,40 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# This library is free software, distributed under the terms of
|
||||
# the GNU Lesser General Public License Version 3, or any later version.
|
||||
# See the COPYING file included in this archive
|
||||
|
||||
import hashlib
|
||||
from twisted.trial import unittest
|
||||
import struct
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import defer
|
||||
from lbrynet.dht.node import Node
|
||||
from lbrynet.dht import constants
|
||||
from lbrynet.core.utils import generate_id
|
||||
|
||||
|
||||
class NodeIDTest(unittest.TestCase):
|
||||
""" Test case for the Node class's ID """
|
||||
|
||||
def setUp(self):
|
||||
self.node = Node()
|
||||
|
||||
def testAutoCreatedID(self):
|
||||
""" Tests if a new node has a valid node ID """
|
||||
self.assertEqual(type(self.node.node_id), bytes, 'Node does not have a valid ID')
|
||||
self.assertEqual(len(self.node.node_id), 48, 'Node ID length is incorrect! '
|
||||
'Expected 384 bits, got %d bits.' %
|
||||
(len(self.node.node_id) * 8))
|
||||
def test_new_node_has_auto_created_id(self):
|
||||
self.assertEqual(type(self.node.node_id), bytes)
|
||||
self.assertEqual(len(self.node.node_id), 48)
|
||||
|
||||
def testUniqueness(self):
|
||||
""" Tests the uniqueness of the values created by the NodeID generator """
|
||||
generatedIDs = []
|
||||
def test_uniqueness_and_length_of_generated_ids(self):
|
||||
previous_ids = []
|
||||
for i in range(100):
|
||||
newID = self.node._generateID()
|
||||
# ugly uniqueness test
|
||||
self.assertFalse(newID in generatedIDs, 'Generated ID #%d not unique!' % (i+1))
|
||||
generatedIDs.append(newID)
|
||||
|
||||
def testKeyLength(self):
|
||||
""" Tests the key Node ID key length """
|
||||
for i in range(20):
|
||||
id = self.node._generateID()
|
||||
# Key length: 20 bytes == 160 bits
|
||||
self.assertEqual(len(id), 48,
|
||||
'Length of generated ID is incorrect! Expected 384 bits, '
|
||||
'got %d bits.' % (len(id)*8))
|
||||
new_id = self.node._generateID()
|
||||
self.assertNotIn(new_id, previous_ids, 'id at index {} not unique'.format(i))
|
||||
self.assertEqual(len(new_id), 48, 'id at index {} wrong length: {}'.format(i, len(new_id)))
|
||||
previous_ids.append(new_id)
|
||||
|
||||
|
||||
class NodeDataTest(unittest.TestCase):
|
||||
""" Test case for the Node class's data-related functions """
|
||||
|
||||
def setUp(self):
|
||||
h = hashlib.sha384()
|
||||
h.update(b'test')
|
||||
self.node = Node()
|
||||
self.contact = self.node.contact_manager.make_contact(h.digest(), '127.0.0.1', 12345, self.node._protocol)
|
||||
self.contact = self.node.contact_manager.make_contact(
|
||||
h.digest(), '127.0.0.1', 12345, self.node._protocol)
|
||||
self.token = self.node.make_token(self.contact.compact_ip())
|
||||
self.cases = []
|
||||
for i in range(5):
|
||||
|
@ -59,19 +43,18 @@ class NodeDataTest(unittest.TestCase):
|
|||
self.cases.append((h.digest(), 5001+2*i))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def testStore(self):
|
||||
def test_store(self):
|
||||
""" Tests if the node can store (and privately retrieve) some data """
|
||||
for key, port in self.cases:
|
||||
yield self.node.store( # pylint: disable=too-many-function-args
|
||||
yield self.node.store(
|
||||
self.contact, key, self.token, port, self.contact.id, 0
|
||||
)
|
||||
for key, value in self.cases:
|
||||
expected_result = self.contact.compact_ip() + struct.pack('>H', value) + \
|
||||
self.contact.id
|
||||
expected_result = self.contact.compact_ip() + struct.pack('>H', value) + self.contact.id
|
||||
self.assertTrue(self.node._dataStore.hasPeersForBlob(key),
|
||||
'Stored key not found in node\'s DataStore: "%s"' % key)
|
||||
"Stored key not found in node's DataStore: '%s'" % key)
|
||||
self.assertTrue(expected_result in self.node._dataStore.getPeersForBlob(key),
|
||||
'Stored val not found in node\'s DataStore: key:"%s" port:"%s" %s'
|
||||
"Stored val not found in node's DataStore: key:'%s' port:'%s' %s"
|
||||
% (key, value, self.node._dataStore.getPeersForBlob(key)))
|
||||
|
||||
|
||||
|
@ -81,182 +64,25 @@ class NodeContactTest(unittest.TestCase):
|
|||
self.node = Node()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def testAddContact(self):
|
||||
def test_add_contact(self):
|
||||
""" Tests if a contact can be added and retrieved correctly """
|
||||
# Create the contact
|
||||
h = hashlib.sha384()
|
||||
h.update(b'node1')
|
||||
contactID = h.digest()
|
||||
contact = self.node.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.node._protocol)
|
||||
contact_id = generate_id(b'node1')
|
||||
contact = self.node.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.node._protocol)
|
||||
# Now add it...
|
||||
yield self.node.addContact(contact)
|
||||
# ...and request the closest nodes to it using FIND_NODE
|
||||
closestNodes = self.node._routingTable.findCloseNodes(contactID, constants.k)
|
||||
self.assertEqual(len(closestNodes), 1, 'Wrong amount of contacts returned; '
|
||||
'expected 1, got %d' % len(closestNodes))
|
||||
self.assertTrue(contact in closestNodes, 'Added contact not found by issueing '
|
||||
'_findCloseNodes()')
|
||||
closest_nodes = self.node._routingTable.findCloseNodes(contact_id, constants.k)
|
||||
self.assertEqual(len(closest_nodes), 1)
|
||||
self.assertIn(contact, closest_nodes)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def testAddSelfAsContact(self):
|
||||
def test_add_self_as_contact(self):
|
||||
""" Tests the node's behaviour when attempting to add itself as a contact """
|
||||
# Create a contact with the same ID as the local node's ID
|
||||
contact = self.node.contact_manager.make_contact(self.node.node_id, '127.0.0.1', 9182, None)
|
||||
# Now try to add it
|
||||
yield self.node.addContact(contact)
|
||||
# ...and request the closest nodes to it using FIND_NODE
|
||||
closestNodes = self.node._routingTable.findCloseNodes(self.node.node_id,
|
||||
constants.k)
|
||||
self.assertFalse(contact in closestNodes, 'Node added itself as a contact')
|
||||
|
||||
|
||||
# class FakeRPCProtocol(protocol.DatagramProtocol):
|
||||
# def __init__(self):
|
||||
# self.reactor = selectreactor.SelectReactor()
|
||||
# self.testResponse = None
|
||||
# self.network = None
|
||||
#
|
||||
# def createNetwork(self, contactNetwork):
|
||||
# """
|
||||
# set up a list of contacts together with their closest contacts
|
||||
# @param contactNetwork: a sequence of tuples, each containing a contact together with its
|
||||
# closest contacts: C{(<contact>, <closest contact 1, ...,closest contact n>)}
|
||||
# """
|
||||
# self.network = contactNetwork
|
||||
#
|
||||
# def sendRPC(self, contact, method, args, rawResponse=False):
|
||||
# """ Fake RPC protocol; allows entangled.kademlia.contact.Contact objects to "send" RPCs"""
|
||||
#
|
||||
# h = hashlib.sha384()
|
||||
# h.update('rpcId')
|
||||
# rpc_id = h.digest()[:20]
|
||||
#
|
||||
# if method == "findNode":
|
||||
# # get the specific contacts closest contacts
|
||||
# closestContacts = []
|
||||
# closestContactsList = []
|
||||
# for contactTuple in self.network:
|
||||
# if contact == contactTuple[0]:
|
||||
# # get the list of closest contacts for this contact
|
||||
# closestContactsList = contactTuple[1]
|
||||
# # Pack the closest contacts into a ResponseMessage
|
||||
# for closeContact in closestContactsList:
|
||||
# closestContacts.append((closeContact.id, closeContact.address, closeContact.port))
|
||||
#
|
||||
# message = ResponseMessage(rpc_id, contact.id, closestContacts)
|
||||
# df = defer.Deferred()
|
||||
# df.callback((message, (contact.address, contact.port)))
|
||||
# return df
|
||||
# elif method == "findValue":
|
||||
# for contactTuple in self.network:
|
||||
# if contact == contactTuple[0]:
|
||||
# # Get the data stored by this remote contact
|
||||
# dataDict = contactTuple[2]
|
||||
# dataKey = dataDict.keys()[0]
|
||||
# data = dataDict.get(dataKey)
|
||||
# # Check if this contact has the requested value
|
||||
# if dataKey == args[0]:
|
||||
# # Return the data value
|
||||
# response = dataDict
|
||||
# print "data found at contact: " + contact.id
|
||||
# else:
|
||||
# # Return the closest contact to the requested data key
|
||||
# print "data not found at contact: " + contact.id
|
||||
# closeContacts = contactTuple[1]
|
||||
# closestContacts = []
|
||||
# for closeContact in closeContacts:
|
||||
# closestContacts.append((closeContact.id, closeContact.address,
|
||||
# closeContact.port))
|
||||
# response = closestContacts
|
||||
#
|
||||
# # Create the response message
|
||||
# message = ResponseMessage(rpc_id, contact.id, response)
|
||||
# df = defer.Deferred()
|
||||
# df.callback((message, (contact.address, contact.port)))
|
||||
# return df
|
||||
#
|
||||
# def _send(self, data, rpcID, address):
|
||||
# """ fake sending data """
|
||||
#
|
||||
#
|
||||
# class NodeLookupTest(unittest.TestCase):
|
||||
# """ Test case for the Node class's iterativeFind node lookup algorithm """
|
||||
#
|
||||
# def setUp(self):
|
||||
# # create a fake protocol to imitate communication with other nodes
|
||||
# self._protocol = FakeRPCProtocol()
|
||||
# # Note: The reactor is never started for this test. All deferred calls run sequentially,
|
||||
# # since there is no asynchronous network communication
|
||||
# # create the node to be tested in isolation
|
||||
# h = hashlib.sha384()
|
||||
# h.update('node1')
|
||||
# node_id = str(h.digest())
|
||||
# self.node = Node(node_id, 4000, None, None, self._protocol)
|
||||
# self.updPort = 81173
|
||||
# self.contactsAmount = 80
|
||||
# # Reinitialise the routing table
|
||||
# self.node._routingTable = TreeRoutingTable(self.node.node_id)
|
||||
#
|
||||
# # create 160 bit node ID's for test purposes
|
||||
# self.testNodeIDs = []
|
||||
# idNum = int(self.node.node_id.encode('hex'), 16)
|
||||
# for i in range(self.contactsAmount):
|
||||
# # create the testNodeIDs in ascending order, away from the actual node ID,
|
||||
# # with regards to the distance metric
|
||||
# self.testNodeIDs.append(str("%X" % (idNum + i + 1)).decode('hex'))
|
||||
#
|
||||
# # generate contacts
|
||||
# self.contacts = []
|
||||
# for i in range(self.contactsAmount):
|
||||
# contact = self.node.contact_manager.make_contact(self.testNodeIDs[i], "127.0.0.1",
|
||||
# self.updPort + i + 1, self._protocol)
|
||||
# self.contacts.append(contact)
|
||||
#
|
||||
# # create the network of contacts in format: (contact, closest contacts)
|
||||
# contactNetwork = ((self.contacts[0], self.contacts[8:15]),
|
||||
# (self.contacts[1], self.contacts[16:23]),
|
||||
# (self.contacts[2], self.contacts[24:31]),
|
||||
# (self.contacts[3], self.contacts[32:39]),
|
||||
# (self.contacts[4], self.contacts[40:47]),
|
||||
# (self.contacts[5], self.contacts[48:55]),
|
||||
# (self.contacts[6], self.contacts[56:63]),
|
||||
# (self.contacts[7], self.contacts[64:71]),
|
||||
# (self.contacts[8], self.contacts[72:79]),
|
||||
# (self.contacts[40], self.contacts[41:48]),
|
||||
# (self.contacts[41], self.contacts[41:48]),
|
||||
# (self.contacts[42], self.contacts[41:48]),
|
||||
# (self.contacts[43], self.contacts[41:48]),
|
||||
# (self.contacts[44], self.contacts[41:48]),
|
||||
# (self.contacts[45], self.contacts[41:48]),
|
||||
# (self.contacts[46], self.contacts[41:48]),
|
||||
# (self.contacts[47], self.contacts[41:48]),
|
||||
# (self.contacts[48], self.contacts[41:48]),
|
||||
# (self.contacts[50], self.contacts[0:7]),
|
||||
# (self.contacts[51], self.contacts[8:15]),
|
||||
# (self.contacts[52], self.contacts[16:23]))
|
||||
#
|
||||
# contacts_with_datastores = []
|
||||
#
|
||||
# for contact_tuple in contactNetwork:
|
||||
# contacts_with_datastores.append((contact_tuple[0], contact_tuple[1],
|
||||
# DictDataStore()))
|
||||
# self._protocol.createNetwork(contacts_with_datastores)
|
||||
#
|
||||
# # @defer.inlineCallbacks
|
||||
# # def testNodeBootStrap(self):
|
||||
# # """ Test bootstrap with the closest possible contacts """
|
||||
# # # Set the expected result
|
||||
# # expectedResult = {item.id for item in self.contacts[0:8]}
|
||||
# #
|
||||
# # activeContacts = yield self.node._iterativeFind(self.node.node_id, self.contacts[0:8])
|
||||
# #
|
||||
# # # Check the length of the active contacts
|
||||
# # self.failUnlessEqual(activeContacts.__len__(), expectedResult.__len__(),
|
||||
# # "More active contacts should exist, there should be %d "
|
||||
# # "contacts but there are %d" % (len(expectedResult),
|
||||
# # len(activeContacts)))
|
||||
# #
|
||||
# # # Check that the received active contacts are the same as the input contacts
|
||||
# # self.failUnlessEqual({contact.id for contact in activeContacts}, expectedResult,
|
||||
# # "Active should only contain the closest possible contacts"
|
||||
# # " which were used as input for the boostrap")
|
||||
closest_nodes = self.node._routingTable.findCloseNodes(self.node.node_id, constants.k)
|
||||
self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact.')
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import hashlib
|
||||
from binascii import hexlify, unhexlify
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
@ -7,9 +6,7 @@ from lbrynet.dht import constants
|
|||
from lbrynet.dht.routingtable import TreeRoutingTable
|
||||
from lbrynet.dht.contact import ContactManager
|
||||
from lbrynet.dht.distance import Distance
|
||||
import sys
|
||||
if sys.version_info > (3,):
|
||||
long = int
|
||||
from lbrynet.core.utils import generate_id
|
||||
|
||||
|
||||
class FakeRPCProtocol(object):
|
||||
|
@ -21,76 +18,61 @@ class FakeRPCProtocol(object):
|
|||
class TreeRoutingTableTest(unittest.TestCase):
|
||||
""" Test case for the RoutingTable class """
|
||||
def setUp(self):
|
||||
h = hashlib.sha384()
|
||||
h.update(b'node1')
|
||||
self.contact_manager = ContactManager()
|
||||
self.nodeID = h.digest()
|
||||
self.nodeID = generate_id(b'node1')
|
||||
self.protocol = FakeRPCProtocol()
|
||||
self.routingTable = TreeRoutingTable(self.nodeID)
|
||||
|
||||
def testDistance(self):
|
||||
def test_distance(self):
|
||||
""" Test to see if distance method returns correct result"""
|
||||
|
||||
# testList holds a couple 3-tuple (variable1, variable2, result)
|
||||
basicTestList = [(bytes(b'\xaa' * 48), bytes(b'\x55' * 48), long(hexlify(bytes(b'\xff' * 48)), 16))]
|
||||
|
||||
for test in basicTestList:
|
||||
result = Distance(test[0])(test[1])
|
||||
self.assertFalse(result != test[2], 'Result of _distance() should be %s but %s returned' %
|
||||
(test[2], result))
|
||||
d = Distance(bytes((170,) * 48))
|
||||
result = d(bytes((85,) * 48))
|
||||
expected = int(hexlify(bytes((255,) * 48)), 16)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def testAddContact(self):
|
||||
def test_add_contact(self):
|
||||
""" Tests if a contact can be added and retrieved correctly """
|
||||
# Create the contact
|
||||
h = hashlib.sha384()
|
||||
h.update(b'node2')
|
||||
contactID = h.digest()
|
||||
contact = self.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.protocol)
|
||||
contact_id = generate_id(b'node2')
|
||||
contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
|
||||
# Now add it...
|
||||
yield self.routingTable.addContact(contact)
|
||||
# ...and request the closest nodes to it (will retrieve it)
|
||||
closestNodes = self.routingTable.findCloseNodes(contactID)
|
||||
self.assertEqual(len(closestNodes), 1, 'Wrong amount of contacts returned; expected 1,'
|
||||
' got %d' % len(closestNodes))
|
||||
self.assertTrue(contact in closestNodes, 'Added contact not found by issueing '
|
||||
'_findCloseNodes()')
|
||||
closest_nodes = self.routingTable.findCloseNodes(contact_id)
|
||||
self.assertEqual(len(closest_nodes), 1)
|
||||
self.assertIn(contact, closest_nodes)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def testGetContact(self):
|
||||
def test_get_contact(self):
|
||||
""" Tests if a specific existing contact can be retrieved correctly """
|
||||
h = hashlib.sha384()
|
||||
h.update(b'node2')
|
||||
contactID = h.digest()
|
||||
contact = self.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.protocol)
|
||||
contact_id = generate_id(b'node2')
|
||||
contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
|
||||
# Now add it...
|
||||
yield self.routingTable.addContact(contact)
|
||||
# ...and get it again
|
||||
sameContact = self.routingTable.getContact(contactID)
|
||||
self.assertEqual(contact, sameContact, 'getContact() should return the same contact')
|
||||
same_contact = self.routingTable.getContact(contact_id)
|
||||
self.assertEqual(contact, same_contact, 'getContact() should return the same contact')
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def testAddParentNodeAsContact(self):
|
||||
def test_add_parent_node_as_contact(self):
|
||||
"""
|
||||
Tests the routing table's behaviour when attempting to add its parent node as a contact
|
||||
"""
|
||||
|
||||
# Create a contact with the same ID as the local node's ID
|
||||
contact = self.contact_manager.make_contact(self.nodeID, '127.0.0.1', 9182, self.protocol)
|
||||
# Now try to add it
|
||||
yield self.routingTable.addContact(contact)
|
||||
# ...and request the closest nodes to it using FIND_NODE
|
||||
closestNodes = self.routingTable.findCloseNodes(self.nodeID, constants.k)
|
||||
self.assertFalse(contact in closestNodes, 'Node added itself as a contact')
|
||||
closest_nodes = self.routingTable.findCloseNodes(self.nodeID, constants.k)
|
||||
self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact')
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def testRemoveContact(self):
|
||||
def test_remove_contact(self):
|
||||
""" Tests contact removal """
|
||||
# Create the contact
|
||||
h = hashlib.sha384()
|
||||
h.update(b'node2')
|
||||
contactID = h.digest()
|
||||
contact = self.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.protocol)
|
||||
contact_id = generate_id(b'node2')
|
||||
contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
|
||||
# Now add it...
|
||||
yield self.routingTable.addContact(contact)
|
||||
# Verify addition
|
||||
|
@ -100,25 +82,22 @@ class TreeRoutingTableTest(unittest.TestCase):
|
|||
self.assertEqual(len(self.routingTable._buckets[0]), 0, 'Contact not removed properly')
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def testSplitBucket(self):
|
||||
def test_split_bucket(self):
|
||||
""" Tests if the the routing table correctly dynamically splits k-buckets """
|
||||
self.assertEqual(self.routingTable._buckets[0].rangeMax, 2**384,
|
||||
'Initial k-bucket range should be 0 <= range < 2**384')
|
||||
# Add k contacts
|
||||
for i in range(constants.k):
|
||||
h = hashlib.sha384()
|
||||
h.update(b'remote node %d' % i)
|
||||
nodeID = h.digest()
|
||||
contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol)
|
||||
node_id = generate_id(b'remote node %d' % i)
|
||||
contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol)
|
||||
yield self.routingTable.addContact(contact)
|
||||
|
||||
self.assertEqual(len(self.routingTable._buckets), 1,
|
||||
'Only k nodes have been added; the first k-bucket should now '
|
||||
'be full, but should not yet be split')
|
||||
# Now add 1 more contact
|
||||
h = hashlib.sha384()
|
||||
h.update(b'yet another remote node')
|
||||
nodeID = h.digest()
|
||||
contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol)
|
||||
node_id = generate_id(b'yet another remote node')
|
||||
contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol)
|
||||
yield self.routingTable.addContact(contact)
|
||||
self.assertEqual(len(self.routingTable._buckets), 2,
|
||||
'k+1 nodes have been added; the first k-bucket should have been '
|
||||
|
@ -134,7 +113,7 @@ class TreeRoutingTableTest(unittest.TestCase):
|
|||
'not divided properly')
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def testFullSplit(self):
|
||||
def test_full_split(self):
|
||||
"""
|
||||
Test that a bucket is not split if it is full, but the new contact is not closer than the kth closest contact
|
||||
"""
|
||||
|
|
Loading…
Add table
Reference in a new issue