DHT fixes from review and an attempt at removing hashing and equals (#1370)

* use int to_bytes/from_bytes instead of struct
* fix ping queue bug and dht functional tests
* run functional tests on travis
* re-add contact comparison unit test
* dont need __ne__ if its just inverting __eq__ result
This commit is contained in:
shyba 2018-08-22 01:12:46 -03:00 committed by Jack Robison
parent 593d0046bd
commit eab95a6246
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
10 changed files with 48 additions and 69 deletions

View file

@ -26,10 +26,20 @@ jobs:
script: HOME=/tmp coverage run --source=lbrynet -m twisted.trial --reactor=asyncio tests.unit
after_success:
- bash <(curl -s https://codecov.io/bash)
- <<: *tests
name: "Unit Tests w/ Python 3.6"
python: "3.6"
- <<: *tests
name: "DHT Tests w/ Python 3.7"
script: HOME=/tmp coverage run --source=lbrynet -m twisted.trial --reactor=asyncio tests.functional
- <<: *tests
name: "DHT Tests w/ Python 3.6"
python: "3.6"
script: HOME=/tmp coverage run --source=lbrynet -m twisted.trial --reactor=asyncio tests.functional
- name: "Integration Tests"
install:
- pip install tox-travis coverage

View file

@ -1,7 +1,6 @@
import ipaddress
from binascii import hexlify
from functools import reduce
from lbrynet.dht import constants
@ -98,23 +97,12 @@ class _Contact:
return None
def __eq__(self, other):
if isinstance(other, _Contact):
return self.id == other.id
elif isinstance(other, str):
return self.id == other
else:
return False
def __ne__(self, other):
if isinstance(other, _Contact):
return self.id != other.id
elif isinstance(other, str):
return self.id != other
else:
return True
if not isinstance(other, _Contact):
raise TypeError("invalid type to compare with Contact: %s" % str(type(other)))
return (self.id, self.address, self.port) == (other.id, other.address, other.port)
def __hash__(self):
return int(hexlify(self.id), 16) if self.id else int(sum(int(x) for x in self.address.split('.')) + self.port)
return hash((self.id, self.address, self.port))
def compact_ip(self):
compact_ip = reduce(

View file

@ -1,5 +1,4 @@
import logging
import struct
from twisted.internet import defer
from .distance import Distance
from .error import TimeoutError
@ -17,7 +16,7 @@ def get_contact(contact_list, node_id, address, port):
def expand_peer(compact_peer_info):
host = "{}.{}.{}.{}".format(*compact_peer_info[:4])
port, = struct.unpack('>H', compact_peer_info[4:6])
port = int.from_bytes(compact_peer_info[4:6], 'big')
peer_node_id = compact_peer_info[6:]
return (peer_node_id, host, port)

View file

@ -1,5 +1,4 @@
import logging
from binascii import hexlify
from . import constants
from .distance import Distance
@ -138,7 +137,7 @@ class KBucket:
@rtype: bool
"""
if isinstance(key, bytes):
key = int(hexlify(key), 16)
key = int.from_bytes(key, 'big')
return self.rangeMin <= key < self.rangeMax
def __len__(self):

View file

@ -1,6 +1,5 @@
import binascii
import hashlib
import struct
import logging
from functools import reduce
@ -150,9 +149,6 @@ class Node(MockKademliaHelper):
return '<%s.%s object; ID: %s, IP address: %s, UDP port: %d>' % (
self.__module__, self.__class__.__name__, binascii.hexlify(self.node_id), self.externalIP, self.port)
def __hash__(self):
return int(binascii.hexlify(self.node_id), 16)
@defer.inlineCallbacks
def stop(self):
# stop LoopingCalls:
@ -512,7 +508,7 @@ class Node(MockKademliaHelper):
elif not self.verify_token(token, compact_ip):
raise ValueError("Invalid token")
if 0 <= port <= 65536:
compact_port = struct.pack('>H', port)
compact_port = port.to_bytes(2, 'big')
else:
raise TypeError('Invalid port: {}'.format(port))
compact_address = compact_ip + compact_port + rpc_contact.id
@ -577,7 +573,7 @@ class Node(MockKademliaHelper):
# if we don't have k storing peers to return and we have this hash locally, include our contact information
if len(peers) < constants.k and key in self._dataStore.completed_blobs:
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), self.externalIP.split('.'), bytearray())
compact_port = struct.pack('>H', self.peerPort)
compact_port = self.peerPort.to_bytes(2, 'big')
compact_address = compact_ip + compact_port + self.node_id
peers.append(compact_address)

View file

@ -30,7 +30,7 @@ class PingQueue:
self._process_lc = node.get_looping_call(self._semaphore.run, self._process)
def _add_contact(self, contact, delay=None):
if contact in self._enqueued_contacts:
if (contact.address, contact.port) in [(c.address, c.port) for c in self._enqueued_contacts]:
return defer.succeed(None)
delay = delay or constants.checkRefreshInterval
self._enqueued_contacts[contact] = self._get_time() + delay

View file

@ -1,4 +1,5 @@
import logging
import binascii
from twisted.trial import unittest
from twisted.internet import defer, task
@ -91,7 +92,7 @@ class TestKademliaBase(unittest.TestCase):
online.add(n.externalIP)
return online
def show_info(self):
def show_info(self, show_contacts=False):
known = set()
for n in self._seeds:
known.update([(c.id, c.address, c.port) for c in n.contacts])
@ -99,12 +100,13 @@ class TestKademliaBase(unittest.TestCase):
known.update([(c.id, c.address, c.port) for c in n.contacts])
log.info("Routable: %i/%i", len(known), len(self.nodes) + len(self._seeds))
for n in self._seeds:
log.info("seed %s has %i contacts in %i buckets", n.externalIP, len(n.contacts),
len([b for b in n._routingTable._buckets if b.getContacts()]))
for n in self.nodes:
log.info("node %s has %i contacts in %i buckets", n.externalIP, len(n.contacts),
len([b for b in n._routingTable._buckets if b.getContacts()]))
if show_contacts:
for n in self._seeds:
log.info("seed %s (%s) has %i contacts in %i buckets", n.externalIP, binascii.hexlify(n.node_id)[:8], len(n.contacts),
len([b for b in n._routingTable._buckets if b.getContacts()]))
for n in self.nodes:
log.info("node %s (%s) has %i contacts in %i buckets", n.externalIP, binascii.hexlify(n.node_id)[:8], len(n.contacts),
len([b for b in n._routingTable._buckets if b.getContacts()]))
@defer.inlineCallbacks
def setUp(self):
@ -128,13 +130,10 @@ class TestKademliaBase(unittest.TestCase):
yield self.run_reactor(constants.checkRefreshInterval+1, seed_dl)
while len(self.nodes + self._seeds) < self.network_size:
network_dl = []
# fixme: We are starting one by one to reduce flakiness on time advance.
# fixme: When that improves, get back to 10+!
for i in range(min(1, self.network_size - len(self._seeds) - len(self.nodes))):
for i in range(min(10, self.network_size - len(self._seeds) - len(self.nodes))):
network_dl.append(self.add_node(known_addresses))
yield self.run_reactor(constants.checkRefreshInterval*2+1, network_dl)
self.assertEqual(len(self.nodes + self._seeds), self.network_size)
self.pump_clock(3600)
self.verify_all_nodes_are_routable()
self.verify_all_nodes_are_pingable()

View file

@ -113,10 +113,9 @@ def address_generator(address=(10, 42, 42, 1)):
address = increment(address)
def mock_node_generator(count=None, mock_node_ids=MOCK_DHT_NODES):
def mock_node_generator(count=None, mock_node_ids=None):
if mock_node_ids is None:
mock_node_ids = MOCK_DHT_NODES
mock_node_ids = list(mock_node_ids)
for num, node_ip in enumerate(address_generator()):
if count and num >= count:

View file

@ -12,7 +12,6 @@ class TestKademliaBootstrap(TestKademliaBase):
pass
@unittest.SkipTest
class TestKademliaBootstrap40Nodes(TestKademliaBase):
network_size = 40

View file

@ -6,7 +6,7 @@ from lbrynet.dht.contact import ContactManager
from lbrynet.dht import constants
class ContactOperatorsTest(unittest.TestCase):
class ContactTest(unittest.TestCase):
""" Basic tests case for boolean operators on the Contact class """
def setUp(self):
self.contact_manager = ContactManager()
@ -14,7 +14,7 @@ class ContactOperatorsTest(unittest.TestCase):
make_contact = self.contact_manager.make_contact
self.first_contact = make_contact(self.node_ids[1], '127.0.0.1', 1000, None, 1)
self.second_contact = make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32)
self.second_contact_copy = make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32)
self.second_contact_second_reference = make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32)
self.first_contact_different_values = make_contact(self.node_ids[1], '192.168.1.20', 1000, None, 50)
def test_make_contact_error_cases(self):
@ -28,31 +28,27 @@ class ContactOperatorsTest(unittest.TestCase):
ValueError, self.contact_manager.make_contact, b'not valid node id', '192.168.1.20.1', 1000, None)
def test_no_duplicate_contact_objects(self):
self.assertTrue(self.second_contact is self.second_contact_copy)
self.assertTrue(self.second_contact is self.second_contact_second_reference)
self.assertTrue(self.first_contact is not self.first_contact_different_values)
def test_boolean(self):
""" Test "equals" and "not equals" comparisons """
self.assertNotEqual(
self.first_contact, self.second_contact,
'Contacts with different IDs should not be equal.')
self.assertEqual(
self.first_contact, self.first_contact_different_values,
'Contacts with same IDs should be equal, even if their other values differ.')
self.assertEqual(
self.second_contact, self.second_contact_copy,
'Different copies of the same Contact instance should be equal')
def test_illogical_comparisons(self):
""" Test comparisons with non-Contact and non-str types """
msg = '"{}" operator: Contact object should not be equal to {} type'
for item in (123, [1, 2, 3], {'key': 'value'}):
self.assertNotEqual(
self.first_contact, item,
msg.format('eq', type(item).__name__))
self.assertTrue(
self.first_contact != item,
msg.format('ne', type(item).__name__))
self.first_contact, self.contact_manager.make_contact(
self.first_contact.id, self.first_contact.address, self.first_contact.port + 1, None, 32
)
)
self.assertNotEqual(
self.first_contact, self.contact_manager.make_contact(
self.first_contact.id, '193.168.1.1', self.first_contact.port, None, 32
)
)
self.assertNotEqual(
self.first_contact, self.contact_manager.make_contact(
generate_id(), self.first_contact.address, self.first_contact.port, None, 32
)
)
self.assertEqual(self.second_contact, self.second_contact_second_reference)
def test_compact_ip(self):
self.assertEqual(self.first_contact.compact_ip(), b'\x7f\x00\x00\x01')
@ -62,12 +58,6 @@ class ContactOperatorsTest(unittest.TestCase):
self.assertEqual(self.first_contact.log_id(False), hexlify(self.node_ids[1]))
self.assertEqual(self.first_contact.log_id(True), hexlify(self.node_ids[1])[:8])
def test_hash(self):
# fails with "TypeError: unhashable type: '_Contact'" if __hash__ is not implemented
self.assertEqual(
len({self.first_contact, self.second_contact, self.second_contact_copy}), 2
)
class TestContactLastReplied(unittest.TestCase):
def setUp(self):