From 2ad22d7d19783ab2d3c7104bd8451a0073029a4d Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Tue, 18 Jun 2019 20:08:50 -0400 Subject: [PATCH] paginated deterministically shuffled find_value - fixes https://github.com/lbryio/lbry/issues/2244 - reduce the max DHT datagram size to 1400 bytes - truncate `contacts` field of find_value response datagrams to the k closest (8) - truncate peers in find_node response datagrams to the 2k closest (16) - remove `contacts` field from find_value responses beyond `page` 0 (the first/default) - deterministically shuffle the peers for a blob in a find_value response - add optional `page` argument to `find_value` and `p` field to find_value responses containing the number of pages of k peers for the blob - test one blob being announced by 150 different peers to one peer - speed up pylint and remove some disabled checks --- .pylintrc | 6 +- lbrynet/dht/constants.py | 3 +- lbrynet/dht/protocol/protocol.py | 61 ++++++++++++------- lbrynet/dht/serialization/bencoding.py | 2 +- lbrynet/dht/serialization/datagram.py | 9 ++- tests/unit/dht/protocol/test_protocol.py | 5 +- tests/unit/dht/serialization/test_datagram.py | 24 +++++++- tests/unit/dht/test_blob_announcer.py | 55 +++++++++++++++++ 8 files changed, 129 insertions(+), 36 deletions(-) diff --git a/.pylintrc b/.pylintrc index 6d7cb1ad9..83aaa46f6 100644 --- a/.pylintrc +++ b/.pylintrc @@ -24,7 +24,7 @@ persistent=yes load-plugins= # Use multiple processes to speed up Pylint. -jobs=1 +jobs=4 # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. @@ -77,8 +77,6 @@ disable= dangerous-default-value, duplicate-code, fixme, - global-statement, - inherit-non-class, invalid-name, len-as-condition, locally-disabled, @@ -111,13 +109,11 @@ disable= unnecessary-lambda, unused-argument, unused-variable, - wildcard-import, wrong-import-order, wrong-import-position, deprecated-lambda, simplifiable-if-statement, unidiomatic-typecheck, - global-at-module-level, inconsistent-return-statements, keyword-arg-before-vararg, assignment-from-no-return, diff --git a/lbrynet/dht/constants.py b/lbrynet/dht/constants.py index c64720706..a9754c2b6 100644 --- a/lbrynet/dht/constants.py +++ b/lbrynet/dht/constants.py @@ -18,11 +18,10 @@ data_expiration = 86400 # 24 hours token_secret_refresh_interval = 300 # 5 minutes maybe_ping_delay = 300 # 5 minutes check_refresh_interval = refresh_interval / 5 -max_datagram_size = 8192 # 8 KB rpc_id_length = 20 protocol_version = 1 bottom_out_limit = 3 -msg_size_limit = max_datagram_size - 26 +msg_size_limit = 1400 def digest(data: bytes) -> bytes: diff --git a/lbrynet/dht/protocol/protocol.py b/lbrynet/dht/protocol/protocol.py index fdd61be15..2de36c00a 100644 --- a/lbrynet/dht/protocol/protocol.py +++ b/lbrynet/dht/protocol/protocol.py @@ -5,12 +5,13 @@ import hashlib import asyncio import typing import binascii +import random from asyncio.protocols import DatagramProtocol from asyncio.transports import DatagramTransport from lbrynet.dht import constants from lbrynet.dht.serialization.datagram import decode_datagram, ErrorDatagram, ResponseDatagram, RequestDatagram -from lbrynet.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE +from lbrynet.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE, PAGE_KEY from lbrynet.dht.error import RemoteException, TransportNotConnected from lbrynet.dht.protocol.routing_table import TreeRoutingTable from lbrynet.dht.protocol.data_store import DictDataStore @@ -67,19 +68,23 @@ class KademliaRPC: contacts = self.protocol.routing_table.find_close_peers(key, sender_node_id=rpc_contact.node_id) contact_triples = [] - for contact in contacts: + for contact in contacts[:constants.k * 2]: contact_triples.append((contact.node_id, contact.address, contact.udp_port)) return contact_triples - def find_value(self, rpc_contact: 'KademliaPeer', key: bytes): + def find_value(self, rpc_contact: 'KademliaPeer', key: bytes, page: int = 0): + page = page if page > 0 else 0 + if len(key) != constants.hash_length: raise ValueError("invalid blob_exchange hash length: %i" % len(key)) response = { b'token': self.make_token(rpc_contact.compact_ip()), - b'contacts': self.find_node(rpc_contact, key) } + if not page: + response[b'contacts'] = self.find_node(rpc_contact, key)[:constants.k] + if self.protocol.protocol_version: response[b'protocolVersion'] = self.protocol.protocol_version @@ -92,8 +97,14 @@ class KademliaRPC: # if we don't have k storing peers to return and we have this hash locally, include our contact information if len(peers) < constants.k and binascii.hexlify(key).decode() in self.protocol.data_store.completed_blobs: peers.append(self.compact_address()) - if peers: - response[key] = peers + if not peers: + response[PAGE_KEY] = 0 + else: + response[PAGE_KEY] = (len(peers) // (constants.k + 1)) + 1 # how many pages of peers we have for the blob + if len(peers) > constants.k: + random.Random(self.protocol.node_id).shuffle(peers) + if page * constants.k < len(peers): + response[key] = peers[page * constants.k:page * constants.k + constants.k] return response def refresh_token(self): # TODO: this needs to be called periodically @@ -166,7 +177,7 @@ class RemoteKademliaRPC: ) return [(node_id, address.decode(), udp_port) for node_id, address, udp_port in response.response] - async def find_value(self, key: bytes) -> typing.Union[typing.Dict]: + async def find_value(self, key: bytes, page: int = 0) -> typing.Union[typing.Dict]: """ :return: { b'token': , @@ -177,7 +188,7 @@ class RemoteKademliaRPC: if len(key) != constants.hash_bits // 8: raise ValueError(f"invalid length of find value key: {len(key)}") response = await self.protocol.send_request( - self.peer, RequestDatagram.make_find_value(self.protocol.node_id, key) + self.peer, RequestDatagram.make_find_value(self.protocol.node_id, key, page=page) ) self.peer_tracker.update_token(self.peer.node_id, response.response[b'token']) return response.response @@ -406,24 +417,25 @@ class KademliaProtocol(DatagramProtocol): raise AttributeError('Invalid method: %s' % message.method.decode()) if message.args and isinstance(message.args[-1], dict) and b'protocolVersion' in message.args[-1]: # args don't need reformatting - a, kw = tuple(message.args[:-1]), message.args[-1] + args, kw = tuple(message.args[:-1]), message.args[-1] else: - a, kw = self._migrate_incoming_rpc_args(sender_contact, message.method, *message.args) + args, kw = self._migrate_incoming_rpc_args(sender_contact, message.method, *message.args) log.debug("%s:%i RECV CALL %s %s:%i", self.external_ip, self.udp_port, message.method.decode(), sender_contact.address, sender_contact.udp_port) if method == b'ping': result = self.node_rpc.ping() elif method == b'store': - blob_hash, token, port, original_publisher_id, age = a + blob_hash, token, port, original_publisher_id, age = args[:5] result = self.node_rpc.store(sender_contact, blob_hash, token, port) - elif method == b'findNode': - key, = a - result = self.node_rpc.find_node(sender_contact, key) else: - assert method == b'findValue' - key, = a - result = self.node_rpc.find_value(sender_contact, key) + key = args[0] + page = kw.get(PAGE_KEY, 0) + if method == b'findNode': + result = self.node_rpc.find_node(sender_contact, key) + else: + assert method == b'findValue' + result = self.node_rpc.find_value(sender_contact, key, page) self.send_response( sender_contact, ResponseDatagram(RESPONSE_TYPE, message.rpc_id, self.node_id, result), @@ -569,15 +581,18 @@ class KademliaProtocol(DatagramProtocol): def send_error(self, peer: 'KademliaPeer', error: ErrorDatagram): self._send(peer, error) - def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram, - ErrorDatagram]): + def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]): if not self.transport or self.transport.is_closing(): raise TransportNotConnected() data = message.bencode() if len(data) > constants.msg_size_limit: - log.exception("unexpected: %i vs %i", len(data), constants.msg_size_limit) - raise ValueError() + log.warning("cannot send datagram larger than %i bytes (packet is %i bytes)", + constants.msg_size_limit, len(data)) + log.debug("Packet is too large to send: %s", binascii.hexlify(data[:3500]).decode()) + raise ValueError( + f"cannot send datagram larger than {constants.msg_size_limit} bytes (packet is {len(data)} bytes)" + ) if isinstance(message, (RequestDatagram, ResponseDatagram)): assert message.node_id == self.node_id, message if isinstance(message, RequestDatagram): @@ -642,12 +657,12 @@ class KademliaProtocol(DatagramProtocol): log.debug("Timeout while storing blob_hash %s at %s", binascii.hexlify(hash_value).decode()[:8], peer) except ValueError as err: log.error("Unexpected response: %s" % err) - except Exception as err: + except RemoteException as err: if 'Invalid token' in str(err): self.peer_manager.clear_token(peer.node_id) try: return await __store() - except: + except (ValueError, asyncio.TimeoutError, RemoteException): return peer.node_id, False else: log.exception("Unexpected error while storing blob_hash") diff --git a/lbrynet/dht/serialization/bencoding.py b/lbrynet/dht/serialization/bencoding.py index 5efdf969d..dbb6895ee 100644 --- a/lbrynet/dht/serialization/bencoding.py +++ b/lbrynet/dht/serialization/bencoding.py @@ -63,7 +63,7 @@ def bencode(data: typing.Dict) -> bytes: def bdecode(data: bytes, allow_non_dict_return: typing.Optional[bool] = False) -> typing.Dict: - assert type(data) == bytes, DecodeError(f"invalid data type: {str(type(data))}") + assert isinstance(data, bytes), DecodeError(f"invalid data type: {str(type(data))}") if len(data) == 0: raise DecodeError('Cannot decode empty string') diff --git a/lbrynet/dht/serialization/datagram.py b/lbrynet/dht/serialization/datagram.py index 231fbbb59..70c84fbd0 100644 --- a/lbrynet/dht/serialization/datagram.py +++ b/lbrynet/dht/serialization/datagram.py @@ -7,6 +7,9 @@ REQUEST_TYPE = 0 RESPONSE_TYPE = 1 ERROR_TYPE = 2 +# bencode representation of argument keys +PAGE_KEY = b'p' + class KademliaDatagramBase: """ @@ -91,11 +94,13 @@ class RequestDatagram(KademliaDatagramBase): @classmethod def make_find_value(cls, from_node_id: bytes, key: bytes, - rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram': + rpc_id: typing.Optional[bytes] = None, page: int = 0) -> 'RequestDatagram': rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length] if len(key) != constants.hash_bits // 8: raise ValueError(f"invalid key length: {len(key)}") - return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key]) + if page < 0: + raise ValueError(f"cannot request a negative page ({page})") + return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key, {PAGE_KEY: page}]) class ResponseDatagram(KademliaDatagramBase): diff --git a/tests/unit/dht/protocol/test_protocol.py b/tests/unit/dht/protocol/test_protocol.py index bfad0165f..6669d955a 100644 --- a/tests/unit/dht/protocol/test_protocol.py +++ b/tests/unit/dht/protocol/test_protocol.py @@ -83,12 +83,15 @@ class TestProtocol(AsyncioTestCase): find_value_response = peer1.node_rpc.find_value(peer3, b'2' * 48) self.assertEqual(len(find_value_response[b'contacts']), 0) self.assertSetEqual( - {b'2' * 48, b'token', b'protocolVersion', b'contacts'}, set(find_value_response.keys()) + {b'2' * 48, b'token', b'protocolVersion', b'contacts', b'p'}, set(find_value_response.keys()) ) self.assertEqual(2, len(find_value_response[b'2' * 48])) self.assertEqual(find_value_response[b'2' * 48][0], peer2_from_peer1.compact_address_tcp()) self.assertDictEqual(bdecode(bencode(find_value_response)), find_value_response) + find_value_page_above_pages_response = peer1.node_rpc.find_value(peer3, b'2' * 48, page=10) + self.assertNotIn(b'2' * 48, find_value_page_above_pages_response) + peer1.stop() peer2.stop() peer1.disconnect() diff --git a/tests/unit/dht/serialization/test_datagram.py b/tests/unit/dht/serialization/test_datagram.py index 6f473d212..565f7f1a5 100644 --- a/tests/unit/dht/serialization/test_datagram.py +++ b/tests/unit/dht/serialization/test_datagram.py @@ -59,17 +59,28 @@ class TestDatagram(unittest.TestCase): self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 49, b'2' * 48, b'1' * 20) self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 49, b'1' * 20) self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 21) + self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 20, -1) self.assertEqual(20, len(RequestDatagram.make_find_value(b'1' * 48, b'2' * 48).rpc_id)) + # default page argument serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20).bencode() decoded = decode_datagram(serialized) self.assertEqual(decoded.packet_type, REQUEST_TYPE) self.assertEqual(decoded.rpc_id, b'1' * 20) self.assertEqual(decoded.node_id, b'1' * 48) self.assertEqual(decoded.method, b'findValue') - self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1}]) + self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1, b'p': 0}]) - def test_find_value_response(self): + # nondefault page argument + serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20, 1).bencode() + decoded = decode_datagram(serialized) + self.assertEqual(decoded.packet_type, REQUEST_TYPE) + self.assertEqual(decoded.rpc_id, b'1' * 20) + self.assertEqual(decoded.node_id, b'1' * 48) + self.assertEqual(decoded.method, b'findValue') + self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1, b'p': 1}]) + + def test_find_value_response_without_pages_field(self): found_value_response = {b'2' * 48: [b'\x7f\x00\x00\x01']} serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, found_value_response).bencode() decoded = decode_datagram(serialized) @@ -78,6 +89,15 @@ class TestDatagram(unittest.TestCase): self.assertEqual(decoded.node_id, b'1' * 48) self.assertDictEqual(decoded.response, found_value_response) + def test_find_value_response_with_pages_field(self): + found_value_response = {b'2' * 48: [b'\x7f\x00\x00\x01'], b'p': 1} + serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, found_value_response).bencode() + decoded = decode_datagram(serialized) + self.assertEqual(decoded.packet_type, RESPONSE_TYPE) + self.assertEqual(decoded.rpc_id, b'1' * 20) + self.assertEqual(decoded.node_id, b'1' * 48) + self.assertDictEqual(decoded.response, found_value_response) + def test_store_request(self): self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 47, b'2' * 48, b'3' * 48, 3333, b'1' * 20) self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 49, b'3' * 48, 3333, b'1' * 20) diff --git a/tests/unit/dht/test_blob_announcer.py b/tests/unit/dht/test_blob_announcer.py index 1f4c535a8..49eb3c673 100644 --- a/tests/unit/dht/test_blob_announcer.py +++ b/tests/unit/dht/test_blob_announcer.py @@ -1,6 +1,7 @@ import contextlib import typing import binascii +import socket import asyncio from torba.testcase import AsyncioTestCase from tests import dht_mocks @@ -26,6 +27,7 @@ class TestBlobAnnouncer(AsyncioTestCase): for node_id, address in peer_addresses: await self.add_peer(node_id, address) self.node.joined.set() + self.node._refresh_task = self.loop.create_task(self.node.refresh_node()) async def add_peer(self, node_id, address, add_to_routing_table=True): n = Node(self.loop, PeerManager(self.loop), node_id, 4444, 4444, 3333, address) @@ -113,3 +115,56 @@ class TestBlobAnnouncer(AsyncioTestCase): self.assertEqual(self.node.protocol.node_id, found_peers[0].node_id) self.assertEqual(self.node.protocol.external_ip, found_peers[0].address) self.assertEqual(self.node.protocol.peer_port, found_peers[0].tcp_port) + + async def test_popular_blob(self): + peer_count = 150 + addresses = [ + (constants.generate_id(i + 1), socket.inet_ntoa(int(i + 1).to_bytes(length=4, byteorder='big'))) + for i in range(peer_count) + ] + blob_hash = b'1' * 48 + + async with self._test_network_context(peer_addresses=addresses): + total_seen = set() + announced_to = self.nodes[0] + for i in range(1, peer_count): + node = self.nodes[i] + kad_peer = announced_to.protocol.peer_manager.get_kademlia_peer( + node.protocol.node_id, node.protocol.external_ip, node.protocol.udp_port + ) + await announced_to.protocol._add_peer(kad_peer) + peer = node.protocol.get_rpc_peer( + node.protocol.peer_manager.get_kademlia_peer( + announced_to.protocol.node_id, + announced_to.protocol.external_ip, + announced_to.protocol.udp_port + ) + ) + response = await peer.store(blob_hash) + self.assertEqual(response, b'OK') + peers_for_blob = await peer.find_value(blob_hash, 0) + if i == 1: + self.assertTrue(blob_hash not in peers_for_blob) + self.assertEqual(peers_for_blob[b'p'], 0) + else: + self.assertEqual(len(peers_for_blob[blob_hash]), min(i - 1, constants.k)) + self.assertEqual(len(announced_to.protocol.data_store.get_peers_for_blob(blob_hash)), i) + if i - 1 > constants.k: + self.assertEqual(len(peers_for_blob[b'contacts']), constants.k) + self.assertEqual(peers_for_blob[b'p'], ((i - 1) // (constants.k + 1)) + 1) + seen = set(peers_for_blob[blob_hash]) + self.assertEqual(len(seen), constants.k) + self.assertEqual(len(peers_for_blob[blob_hash]), len(seen)) + + for pg in range(1, peers_for_blob[b'p']): + page_x = await peer.find_value(blob_hash, pg) + self.assertNotIn(b'contacts', page_x) + page_x_set = set(page_x[blob_hash]) + self.assertEqual(len(page_x[blob_hash]), len(page_x_set)) + self.assertTrue(len(page_x_set) > 0) + self.assertSetEqual(seen.intersection(page_x_set), set()) + seen.intersection_update(page_x_set) + total_seen.update(page_x_set) + else: + self.assertEqual(len(peers_for_blob[b'contacts']), i - 1) + self.assertEqual(len(total_seen), peer_count - 2)