From eab95a6246720b7e10630291d192d27b7bf81b20 Mon Sep 17 00:00:00 2001 From: shyba <shyba@users.noreply.github.com> Date: Wed, 22 Aug 2018 01:12:46 -0300 Subject: [PATCH] DHT fixes from review and an attempt at removing hashing and equals (#1370) * use int to_bytes/from_bytes instead of struct * fix ping queue bug and dht functional tests * run functional tests on travis * re-add contact comparison unit test * dont need __ne__ if its just inverting __eq__ result --- .travis.yml | 10 ++++ lbrynet/dht/contact.py | 20 ++------ lbrynet/dht/iterativefind.py | 3 +- lbrynet/dht/kbucket.py | 3 +- lbrynet/dht/node.py | 8 +--- lbrynet/dht/protocol.py | 2 +- tests/functional/dht/dht_test_environment.py | 21 ++++----- tests/functional/dht/mock_transport.py | 3 +- .../functional/dht/test_bootstrap_network.py | 1 - tests/unit/dht/test_contact.py | 46 ++++++++----------- 10 files changed, 48 insertions(+), 69 deletions(-) diff --git a/.travis.yml b/.travis.yml index 3b293e0fa..2c0935b0a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -26,10 +26,20 @@ jobs: script: HOME=/tmp coverage run --source=lbrynet -m twisted.trial --reactor=asyncio tests.unit after_success: - bash <(curl -s https://codecov.io/bash) + - <<: *tests name: "Unit Tests w/ Python 3.6" python: "3.6" + - <<: *tests + name: "DHT Tests w/ Python 3.7" + script: HOME=/tmp coverage run --source=lbrynet -m twisted.trial --reactor=asyncio tests.functional + + - <<: *tests + name: "DHT Tests w/ Python 3.6" + python: "3.6" + script: HOME=/tmp coverage run --source=lbrynet -m twisted.trial --reactor=asyncio tests.functional + - name: "Integration Tests" install: - pip install tox-travis coverage diff --git a/lbrynet/dht/contact.py b/lbrynet/dht/contact.py index ebbda0fba..101492ef3 100644 --- a/lbrynet/dht/contact.py +++ b/lbrynet/dht/contact.py @@ -1,7 +1,6 @@ import ipaddress from binascii import hexlify from functools import reduce - from lbrynet.dht import constants @@ -98,23 +97,12 @@ class _Contact: return None def __eq__(self, other): - if isinstance(other, _Contact): - return self.id == other.id - elif isinstance(other, str): - return self.id == other - else: - return False - - def __ne__(self, other): - if isinstance(other, _Contact): - return self.id != other.id - elif isinstance(other, str): - return self.id != other - else: - return True + if not isinstance(other, _Contact): + raise TypeError("invalid type to compare with Contact: %s" % str(type(other))) + return (self.id, self.address, self.port) == (other.id, other.address, other.port) def __hash__(self): - return int(hexlify(self.id), 16) if self.id else int(sum(int(x) for x in self.address.split('.')) + self.port) + return hash((self.id, self.address, self.port)) def compact_ip(self): compact_ip = reduce( diff --git a/lbrynet/dht/iterativefind.py b/lbrynet/dht/iterativefind.py index 26608ead6..765c548dc 100644 --- a/lbrynet/dht/iterativefind.py +++ b/lbrynet/dht/iterativefind.py @@ -1,5 +1,4 @@ import logging -import struct from twisted.internet import defer from .distance import Distance from .error import TimeoutError @@ -17,7 +16,7 @@ def get_contact(contact_list, node_id, address, port): def expand_peer(compact_peer_info): host = "{}.{}.{}.{}".format(*compact_peer_info[:4]) - port, = struct.unpack('>H', compact_peer_info[4:6]) + port = int.from_bytes(compact_peer_info[4:6], 'big') peer_node_id = compact_peer_info[6:] return (peer_node_id, host, port) diff --git a/lbrynet/dht/kbucket.py b/lbrynet/dht/kbucket.py index a4756bed8..7fffb4ce7 100644 --- a/lbrynet/dht/kbucket.py +++ b/lbrynet/dht/kbucket.py @@ -1,5 +1,4 @@ import logging -from binascii import hexlify from . import constants from .distance import Distance @@ -138,7 +137,7 @@ class KBucket: @rtype: bool """ if isinstance(key, bytes): - key = int(hexlify(key), 16) + key = int.from_bytes(key, 'big') return self.rangeMin <= key < self.rangeMax def __len__(self): diff --git a/lbrynet/dht/node.py b/lbrynet/dht/node.py index 2e22f5439..ce9d2da81 100644 --- a/lbrynet/dht/node.py +++ b/lbrynet/dht/node.py @@ -1,6 +1,5 @@ import binascii import hashlib -import struct import logging from functools import reduce @@ -150,9 +149,6 @@ class Node(MockKademliaHelper): return '<%s.%s object; ID: %s, IP address: %s, UDP port: %d>' % ( self.__module__, self.__class__.__name__, binascii.hexlify(self.node_id), self.externalIP, self.port) - def __hash__(self): - return int(binascii.hexlify(self.node_id), 16) - @defer.inlineCallbacks def stop(self): # stop LoopingCalls: @@ -512,7 +508,7 @@ class Node(MockKademliaHelper): elif not self.verify_token(token, compact_ip): raise ValueError("Invalid token") if 0 <= port <= 65536: - compact_port = struct.pack('>H', port) + compact_port = port.to_bytes(2, 'big') else: raise TypeError('Invalid port: {}'.format(port)) compact_address = compact_ip + compact_port + rpc_contact.id @@ -577,7 +573,7 @@ class Node(MockKademliaHelper): # if we don't have k storing peers to return and we have this hash locally, include our contact information if len(peers) < constants.k and key in self._dataStore.completed_blobs: compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), self.externalIP.split('.'), bytearray()) - compact_port = struct.pack('>H', self.peerPort) + compact_port = self.peerPort.to_bytes(2, 'big') compact_address = compact_ip + compact_port + self.node_id peers.append(compact_address) diff --git a/lbrynet/dht/protocol.py b/lbrynet/dht/protocol.py index 01c3bddb5..e3130468c 100644 --- a/lbrynet/dht/protocol.py +++ b/lbrynet/dht/protocol.py @@ -30,7 +30,7 @@ class PingQueue: self._process_lc = node.get_looping_call(self._semaphore.run, self._process) def _add_contact(self, contact, delay=None): - if contact in self._enqueued_contacts: + if (contact.address, contact.port) in [(c.address, c.port) for c in self._enqueued_contacts]: return defer.succeed(None) delay = delay or constants.checkRefreshInterval self._enqueued_contacts[contact] = self._get_time() + delay diff --git a/tests/functional/dht/dht_test_environment.py b/tests/functional/dht/dht_test_environment.py index 4451ae770..8c431c5ec 100644 --- a/tests/functional/dht/dht_test_environment.py +++ b/tests/functional/dht/dht_test_environment.py @@ -1,4 +1,5 @@ import logging +import binascii from twisted.trial import unittest from twisted.internet import defer, task @@ -91,7 +92,7 @@ class TestKademliaBase(unittest.TestCase): online.add(n.externalIP) return online - def show_info(self): + def show_info(self, show_contacts=False): known = set() for n in self._seeds: known.update([(c.id, c.address, c.port) for c in n.contacts]) @@ -99,12 +100,13 @@ class TestKademliaBase(unittest.TestCase): known.update([(c.id, c.address, c.port) for c in n.contacts]) log.info("Routable: %i/%i", len(known), len(self.nodes) + len(self._seeds)) - for n in self._seeds: - log.info("seed %s has %i contacts in %i buckets", n.externalIP, len(n.contacts), - len([b for b in n._routingTable._buckets if b.getContacts()])) - for n in self.nodes: - log.info("node %s has %i contacts in %i buckets", n.externalIP, len(n.contacts), - len([b for b in n._routingTable._buckets if b.getContacts()])) + if show_contacts: + for n in self._seeds: + log.info("seed %s (%s) has %i contacts in %i buckets", n.externalIP, binascii.hexlify(n.node_id)[:8], len(n.contacts), + len([b for b in n._routingTable._buckets if b.getContacts()])) + for n in self.nodes: + log.info("node %s (%s) has %i contacts in %i buckets", n.externalIP, binascii.hexlify(n.node_id)[:8], len(n.contacts), + len([b for b in n._routingTable._buckets if b.getContacts()])) @defer.inlineCallbacks def setUp(self): @@ -128,13 +130,10 @@ class TestKademliaBase(unittest.TestCase): yield self.run_reactor(constants.checkRefreshInterval+1, seed_dl) while len(self.nodes + self._seeds) < self.network_size: network_dl = [] - # fixme: We are starting one by one to reduce flakiness on time advance. - # fixme: When that improves, get back to 10+! - for i in range(min(1, self.network_size - len(self._seeds) - len(self.nodes))): + for i in range(min(10, self.network_size - len(self._seeds) - len(self.nodes))): network_dl.append(self.add_node(known_addresses)) yield self.run_reactor(constants.checkRefreshInterval*2+1, network_dl) self.assertEqual(len(self.nodes + self._seeds), self.network_size) - self.pump_clock(3600) self.verify_all_nodes_are_routable() self.verify_all_nodes_are_pingable() diff --git a/tests/functional/dht/mock_transport.py b/tests/functional/dht/mock_transport.py index ac006f6f4..cbeaf66c7 100644 --- a/tests/functional/dht/mock_transport.py +++ b/tests/functional/dht/mock_transport.py @@ -113,10 +113,9 @@ def address_generator(address=(10, 42, 42, 1)): address = increment(address) -def mock_node_generator(count=None, mock_node_ids=MOCK_DHT_NODES): +def mock_node_generator(count=None, mock_node_ids=None): if mock_node_ids is None: mock_node_ids = MOCK_DHT_NODES - mock_node_ids = list(mock_node_ids) for num, node_ip in enumerate(address_generator()): if count and num >= count: diff --git a/tests/functional/dht/test_bootstrap_network.py b/tests/functional/dht/test_bootstrap_network.py index a9276c45f..82b2fc410 100644 --- a/tests/functional/dht/test_bootstrap_network.py +++ b/tests/functional/dht/test_bootstrap_network.py @@ -12,7 +12,6 @@ class TestKademliaBootstrap(TestKademliaBase): pass -@unittest.SkipTest class TestKademliaBootstrap40Nodes(TestKademliaBase): network_size = 40 diff --git a/tests/unit/dht/test_contact.py b/tests/unit/dht/test_contact.py index 1ab72a487..abbd99de0 100644 --- a/tests/unit/dht/test_contact.py +++ b/tests/unit/dht/test_contact.py @@ -6,7 +6,7 @@ from lbrynet.dht.contact import ContactManager from lbrynet.dht import constants -class ContactOperatorsTest(unittest.TestCase): +class ContactTest(unittest.TestCase): """ Basic tests case for boolean operators on the Contact class """ def setUp(self): self.contact_manager = ContactManager() @@ -14,7 +14,7 @@ class ContactOperatorsTest(unittest.TestCase): make_contact = self.contact_manager.make_contact self.first_contact = make_contact(self.node_ids[1], '127.0.0.1', 1000, None, 1) self.second_contact = make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32) - self.second_contact_copy = make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32) + self.second_contact_second_reference = make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32) self.first_contact_different_values = make_contact(self.node_ids[1], '192.168.1.20', 1000, None, 50) def test_make_contact_error_cases(self): @@ -28,31 +28,27 @@ class ContactOperatorsTest(unittest.TestCase): ValueError, self.contact_manager.make_contact, b'not valid node id', '192.168.1.20.1', 1000, None) def test_no_duplicate_contact_objects(self): - self.assertTrue(self.second_contact is self.second_contact_copy) + self.assertTrue(self.second_contact is self.second_contact_second_reference) self.assertTrue(self.first_contact is not self.first_contact_different_values) def test_boolean(self): """ Test "equals" and "not equals" comparisons """ self.assertNotEqual( - self.first_contact, self.second_contact, - 'Contacts with different IDs should not be equal.') - self.assertEqual( - self.first_contact, self.first_contact_different_values, - 'Contacts with same IDs should be equal, even if their other values differ.') - self.assertEqual( - self.second_contact, self.second_contact_copy, - 'Different copies of the same Contact instance should be equal') - - def test_illogical_comparisons(self): - """ Test comparisons with non-Contact and non-str types """ - msg = '"{}" operator: Contact object should not be equal to {} type' - for item in (123, [1, 2, 3], {'key': 'value'}): - self.assertNotEqual( - self.first_contact, item, - msg.format('eq', type(item).__name__)) - self.assertTrue( - self.first_contact != item, - msg.format('ne', type(item).__name__)) + self.first_contact, self.contact_manager.make_contact( + self.first_contact.id, self.first_contact.address, self.first_contact.port + 1, None, 32 + ) + ) + self.assertNotEqual( + self.first_contact, self.contact_manager.make_contact( + self.first_contact.id, '193.168.1.1', self.first_contact.port, None, 32 + ) + ) + self.assertNotEqual( + self.first_contact, self.contact_manager.make_contact( + generate_id(), self.first_contact.address, self.first_contact.port, None, 32 + ) + ) + self.assertEqual(self.second_contact, self.second_contact_second_reference) def test_compact_ip(self): self.assertEqual(self.first_contact.compact_ip(), b'\x7f\x00\x00\x01') @@ -62,12 +58,6 @@ class ContactOperatorsTest(unittest.TestCase): self.assertEqual(self.first_contact.log_id(False), hexlify(self.node_ids[1])) self.assertEqual(self.first_contact.log_id(True), hexlify(self.node_ids[1])[:8]) - def test_hash(self): - # fails with "TypeError: unhashable type: '_Contact'" if __hash__ is not implemented - self.assertEqual( - len({self.first_contact, self.second_contact, self.second_contact_copy}), 2 - ) - class TestContactLastReplied(unittest.TestCase): def setUp(self):