Merge pull request #2247 from lbryio/fix-unable-to-serialize-response

paginated deterministically shuffled find_value
This commit is contained in:
Jack Robison 2019-06-18 23:25:12 -04:00 committed by GitHub
commit bfeed40bfe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 129 additions and 36 deletions

View file

@ -24,7 +24,7 @@ persistent=yes
load-plugins= load-plugins=
# Use multiple processes to speed up Pylint. # Use multiple processes to speed up Pylint.
jobs=1 jobs=4
# Allow loading of arbitrary C extensions. Extensions are imported into the # Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code. # active Python interpreter and may run arbitrary code.
@ -77,8 +77,6 @@ disable=
dangerous-default-value, dangerous-default-value,
duplicate-code, duplicate-code,
fixme, fixme,
global-statement,
inherit-non-class,
invalid-name, invalid-name,
len-as-condition, len-as-condition,
locally-disabled, locally-disabled,
@ -111,13 +109,11 @@ disable=
unnecessary-lambda, unnecessary-lambda,
unused-argument, unused-argument,
unused-variable, unused-variable,
wildcard-import,
wrong-import-order, wrong-import-order,
wrong-import-position, wrong-import-position,
deprecated-lambda, deprecated-lambda,
simplifiable-if-statement, simplifiable-if-statement,
unidiomatic-typecheck, unidiomatic-typecheck,
global-at-module-level,
inconsistent-return-statements, inconsistent-return-statements,
keyword-arg-before-vararg, keyword-arg-before-vararg,
assignment-from-no-return, assignment-from-no-return,

View file

@ -18,11 +18,10 @@ data_expiration = 86400 # 24 hours
token_secret_refresh_interval = 300 # 5 minutes token_secret_refresh_interval = 300 # 5 minutes
maybe_ping_delay = 300 # 5 minutes maybe_ping_delay = 300 # 5 minutes
check_refresh_interval = refresh_interval / 5 check_refresh_interval = refresh_interval / 5
max_datagram_size = 8192 # 8 KB
rpc_id_length = 20 rpc_id_length = 20
protocol_version = 1 protocol_version = 1
bottom_out_limit = 3 bottom_out_limit = 3
msg_size_limit = max_datagram_size - 26 msg_size_limit = 1400
def digest(data: bytes) -> bytes: def digest(data: bytes) -> bytes:

View file

@ -5,12 +5,13 @@ import hashlib
import asyncio import asyncio
import typing import typing
import binascii import binascii
import random
from asyncio.protocols import DatagramProtocol from asyncio.protocols import DatagramProtocol
from asyncio.transports import DatagramTransport from asyncio.transports import DatagramTransport
from lbrynet.dht import constants from lbrynet.dht import constants
from lbrynet.dht.serialization.datagram import decode_datagram, ErrorDatagram, ResponseDatagram, RequestDatagram from lbrynet.dht.serialization.datagram import decode_datagram, ErrorDatagram, ResponseDatagram, RequestDatagram
from lbrynet.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE from lbrynet.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE, PAGE_KEY
from lbrynet.dht.error import RemoteException, TransportNotConnected from lbrynet.dht.error import RemoteException, TransportNotConnected
from lbrynet.dht.protocol.routing_table import TreeRoutingTable from lbrynet.dht.protocol.routing_table import TreeRoutingTable
from lbrynet.dht.protocol.data_store import DictDataStore from lbrynet.dht.protocol.data_store import DictDataStore
@ -67,19 +68,23 @@ class KademliaRPC:
contacts = self.protocol.routing_table.find_close_peers(key, sender_node_id=rpc_contact.node_id) contacts = self.protocol.routing_table.find_close_peers(key, sender_node_id=rpc_contact.node_id)
contact_triples = [] contact_triples = []
for contact in contacts: for contact in contacts[:constants.k * 2]:
contact_triples.append((contact.node_id, contact.address, contact.udp_port)) contact_triples.append((contact.node_id, contact.address, contact.udp_port))
return contact_triples return contact_triples
def find_value(self, rpc_contact: 'KademliaPeer', key: bytes): def find_value(self, rpc_contact: 'KademliaPeer', key: bytes, page: int = 0):
page = page if page > 0 else 0
if len(key) != constants.hash_length: if len(key) != constants.hash_length:
raise ValueError("invalid blob_exchange hash length: %i" % len(key)) raise ValueError("invalid blob_exchange hash length: %i" % len(key))
response = { response = {
b'token': self.make_token(rpc_contact.compact_ip()), b'token': self.make_token(rpc_contact.compact_ip()),
b'contacts': self.find_node(rpc_contact, key)
} }
if not page:
response[b'contacts'] = self.find_node(rpc_contact, key)[:constants.k]
if self.protocol.protocol_version: if self.protocol.protocol_version:
response[b'protocolVersion'] = self.protocol.protocol_version response[b'protocolVersion'] = self.protocol.protocol_version
@ -92,8 +97,14 @@ class KademliaRPC:
# if we don't have k storing peers to return and we have this hash locally, include our contact information # if we don't have k storing peers to return and we have this hash locally, include our contact information
if len(peers) < constants.k and binascii.hexlify(key).decode() in self.protocol.data_store.completed_blobs: if len(peers) < constants.k and binascii.hexlify(key).decode() in self.protocol.data_store.completed_blobs:
peers.append(self.compact_address()) peers.append(self.compact_address())
if peers: if not peers:
response[key] = peers response[PAGE_KEY] = 0
else:
response[PAGE_KEY] = (len(peers) // (constants.k + 1)) + 1 # how many pages of peers we have for the blob
if len(peers) > constants.k:
random.Random(self.protocol.node_id).shuffle(peers)
if page * constants.k < len(peers):
response[key] = peers[page * constants.k:page * constants.k + constants.k]
return response return response
def refresh_token(self): # TODO: this needs to be called periodically def refresh_token(self): # TODO: this needs to be called periodically
@ -166,7 +177,7 @@ class RemoteKademliaRPC:
) )
return [(node_id, address.decode(), udp_port) for node_id, address, udp_port in response.response] return [(node_id, address.decode(), udp_port) for node_id, address, udp_port in response.response]
async def find_value(self, key: bytes) -> typing.Union[typing.Dict]: async def find_value(self, key: bytes, page: int = 0) -> typing.Union[typing.Dict]:
""" """
:return: { :return: {
b'token': <token bytes>, b'token': <token bytes>,
@ -177,7 +188,7 @@ class RemoteKademliaRPC:
if len(key) != constants.hash_bits // 8: if len(key) != constants.hash_bits // 8:
raise ValueError(f"invalid length of find value key: {len(key)}") raise ValueError(f"invalid length of find value key: {len(key)}")
response = await self.protocol.send_request( response = await self.protocol.send_request(
self.peer, RequestDatagram.make_find_value(self.protocol.node_id, key) self.peer, RequestDatagram.make_find_value(self.protocol.node_id, key, page=page)
) )
self.peer_tracker.update_token(self.peer.node_id, response.response[b'token']) self.peer_tracker.update_token(self.peer.node_id, response.response[b'token'])
return response.response return response.response
@ -406,24 +417,25 @@ class KademliaProtocol(DatagramProtocol):
raise AttributeError('Invalid method: %s' % message.method.decode()) raise AttributeError('Invalid method: %s' % message.method.decode())
if message.args and isinstance(message.args[-1], dict) and b'protocolVersion' in message.args[-1]: if message.args and isinstance(message.args[-1], dict) and b'protocolVersion' in message.args[-1]:
# args don't need reformatting # args don't need reformatting
a, kw = tuple(message.args[:-1]), message.args[-1] args, kw = tuple(message.args[:-1]), message.args[-1]
else: else:
a, kw = self._migrate_incoming_rpc_args(sender_contact, message.method, *message.args) args, kw = self._migrate_incoming_rpc_args(sender_contact, message.method, *message.args)
log.debug("%s:%i RECV CALL %s %s:%i", self.external_ip, self.udp_port, message.method.decode(), log.debug("%s:%i RECV CALL %s %s:%i", self.external_ip, self.udp_port, message.method.decode(),
sender_contact.address, sender_contact.udp_port) sender_contact.address, sender_contact.udp_port)
if method == b'ping': if method == b'ping':
result = self.node_rpc.ping() result = self.node_rpc.ping()
elif method == b'store': elif method == b'store':
blob_hash, token, port, original_publisher_id, age = a blob_hash, token, port, original_publisher_id, age = args[:5]
result = self.node_rpc.store(sender_contact, blob_hash, token, port) result = self.node_rpc.store(sender_contact, blob_hash, token, port)
elif method == b'findNode':
key, = a
result = self.node_rpc.find_node(sender_contact, key)
else: else:
assert method == b'findValue' key = args[0]
key, = a page = kw.get(PAGE_KEY, 0)
result = self.node_rpc.find_value(sender_contact, key) if method == b'findNode':
result = self.node_rpc.find_node(sender_contact, key)
else:
assert method == b'findValue'
result = self.node_rpc.find_value(sender_contact, key, page)
self.send_response( self.send_response(
sender_contact, ResponseDatagram(RESPONSE_TYPE, message.rpc_id, self.node_id, result), sender_contact, ResponseDatagram(RESPONSE_TYPE, message.rpc_id, self.node_id, result),
@ -569,15 +581,18 @@ class KademliaProtocol(DatagramProtocol):
def send_error(self, peer: 'KademliaPeer', error: ErrorDatagram): def send_error(self, peer: 'KademliaPeer', error: ErrorDatagram):
self._send(peer, error) self._send(peer, error)
def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram, def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]):
ErrorDatagram]):
if not self.transport or self.transport.is_closing(): if not self.transport or self.transport.is_closing():
raise TransportNotConnected() raise TransportNotConnected()
data = message.bencode() data = message.bencode()
if len(data) > constants.msg_size_limit: if len(data) > constants.msg_size_limit:
log.exception("unexpected: %i vs %i", len(data), constants.msg_size_limit) log.warning("cannot send datagram larger than %i bytes (packet is %i bytes)",
raise ValueError() constants.msg_size_limit, len(data))
log.debug("Packet is too large to send: %s", binascii.hexlify(data[:3500]).decode())
raise ValueError(
f"cannot send datagram larger than {constants.msg_size_limit} bytes (packet is {len(data)} bytes)"
)
if isinstance(message, (RequestDatagram, ResponseDatagram)): if isinstance(message, (RequestDatagram, ResponseDatagram)):
assert message.node_id == self.node_id, message assert message.node_id == self.node_id, message
if isinstance(message, RequestDatagram): if isinstance(message, RequestDatagram):
@ -642,12 +657,12 @@ class KademliaProtocol(DatagramProtocol):
log.debug("Timeout while storing blob_hash %s at %s", binascii.hexlify(hash_value).decode()[:8], peer) log.debug("Timeout while storing blob_hash %s at %s", binascii.hexlify(hash_value).decode()[:8], peer)
except ValueError as err: except ValueError as err:
log.error("Unexpected response: %s" % err) log.error("Unexpected response: %s" % err)
except Exception as err: except RemoteException as err:
if 'Invalid token' in str(err): if 'Invalid token' in str(err):
self.peer_manager.clear_token(peer.node_id) self.peer_manager.clear_token(peer.node_id)
try: try:
return await __store() return await __store()
except: except (ValueError, asyncio.TimeoutError, RemoteException):
return peer.node_id, False return peer.node_id, False
else: else:
log.exception("Unexpected error while storing blob_hash") log.exception("Unexpected error while storing blob_hash")

View file

@ -63,7 +63,7 @@ def bencode(data: typing.Dict) -> bytes:
def bdecode(data: bytes, allow_non_dict_return: typing.Optional[bool] = False) -> typing.Dict: def bdecode(data: bytes, allow_non_dict_return: typing.Optional[bool] = False) -> typing.Dict:
assert type(data) == bytes, DecodeError(f"invalid data type: {str(type(data))}") assert isinstance(data, bytes), DecodeError(f"invalid data type: {str(type(data))}")
if len(data) == 0: if len(data) == 0:
raise DecodeError('Cannot decode empty string') raise DecodeError('Cannot decode empty string')

View file

@ -7,6 +7,9 @@ REQUEST_TYPE = 0
RESPONSE_TYPE = 1 RESPONSE_TYPE = 1
ERROR_TYPE = 2 ERROR_TYPE = 2
# bencode representation of argument keys
PAGE_KEY = b'p'
class KademliaDatagramBase: class KademliaDatagramBase:
""" """
@ -91,11 +94,13 @@ class RequestDatagram(KademliaDatagramBase):
@classmethod @classmethod
def make_find_value(cls, from_node_id: bytes, key: bytes, def make_find_value(cls, from_node_id: bytes, key: bytes,
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram': rpc_id: typing.Optional[bytes] = None, page: int = 0) -> 'RequestDatagram':
rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length] rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
if len(key) != constants.hash_bits // 8: if len(key) != constants.hash_bits // 8:
raise ValueError(f"invalid key length: {len(key)}") raise ValueError(f"invalid key length: {len(key)}")
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key]) if page < 0:
raise ValueError(f"cannot request a negative page ({page})")
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key, {PAGE_KEY: page}])
class ResponseDatagram(KademliaDatagramBase): class ResponseDatagram(KademliaDatagramBase):

View file

@ -83,12 +83,15 @@ class TestProtocol(AsyncioTestCase):
find_value_response = peer1.node_rpc.find_value(peer3, b'2' * 48) find_value_response = peer1.node_rpc.find_value(peer3, b'2' * 48)
self.assertEqual(len(find_value_response[b'contacts']), 0) self.assertEqual(len(find_value_response[b'contacts']), 0)
self.assertSetEqual( self.assertSetEqual(
{b'2' * 48, b'token', b'protocolVersion', b'contacts'}, set(find_value_response.keys()) {b'2' * 48, b'token', b'protocolVersion', b'contacts', b'p'}, set(find_value_response.keys())
) )
self.assertEqual(2, len(find_value_response[b'2' * 48])) self.assertEqual(2, len(find_value_response[b'2' * 48]))
self.assertEqual(find_value_response[b'2' * 48][0], peer2_from_peer1.compact_address_tcp()) self.assertEqual(find_value_response[b'2' * 48][0], peer2_from_peer1.compact_address_tcp())
self.assertDictEqual(bdecode(bencode(find_value_response)), find_value_response) self.assertDictEqual(bdecode(bencode(find_value_response)), find_value_response)
find_value_page_above_pages_response = peer1.node_rpc.find_value(peer3, b'2' * 48, page=10)
self.assertNotIn(b'2' * 48, find_value_page_above_pages_response)
peer1.stop() peer1.stop()
peer2.stop() peer2.stop()
peer1.disconnect() peer1.disconnect()

View file

@ -59,17 +59,28 @@ class TestDatagram(unittest.TestCase):
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 49, b'2' * 48, b'1' * 20) self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 49, b'2' * 48, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 49, b'1' * 20) self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 49, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 21) self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 21)
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 20, -1)
self.assertEqual(20, len(RequestDatagram.make_find_value(b'1' * 48, b'2' * 48).rpc_id)) self.assertEqual(20, len(RequestDatagram.make_find_value(b'1' * 48, b'2' * 48).rpc_id))
# default page argument
serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20).bencode() serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20).bencode()
decoded = decode_datagram(serialized) decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, REQUEST_TYPE) self.assertEqual(decoded.packet_type, REQUEST_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20) self.assertEqual(decoded.rpc_id, b'1' * 20)
self.assertEqual(decoded.node_id, b'1' * 48) self.assertEqual(decoded.node_id, b'1' * 48)
self.assertEqual(decoded.method, b'findValue') self.assertEqual(decoded.method, b'findValue')
self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1}]) self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1, b'p': 0}])
def test_find_value_response(self): # nondefault page argument
serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20, 1).bencode()
decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20)
self.assertEqual(decoded.node_id, b'1' * 48)
self.assertEqual(decoded.method, b'findValue')
self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1, b'p': 1}])
def test_find_value_response_without_pages_field(self):
found_value_response = {b'2' * 48: [b'\x7f\x00\x00\x01']} found_value_response = {b'2' * 48: [b'\x7f\x00\x00\x01']}
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, found_value_response).bencode() serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, found_value_response).bencode()
decoded = decode_datagram(serialized) decoded = decode_datagram(serialized)
@ -78,6 +89,15 @@ class TestDatagram(unittest.TestCase):
self.assertEqual(decoded.node_id, b'1' * 48) self.assertEqual(decoded.node_id, b'1' * 48)
self.assertDictEqual(decoded.response, found_value_response) self.assertDictEqual(decoded.response, found_value_response)
def test_find_value_response_with_pages_field(self):
found_value_response = {b'2' * 48: [b'\x7f\x00\x00\x01'], b'p': 1}
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, found_value_response).bencode()
decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, RESPONSE_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20)
self.assertEqual(decoded.node_id, b'1' * 48)
self.assertDictEqual(decoded.response, found_value_response)
def test_store_request(self): def test_store_request(self):
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 47, b'2' * 48, b'3' * 48, 3333, b'1' * 20) self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 47, b'2' * 48, b'3' * 48, 3333, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 49, b'3' * 48, 3333, b'1' * 20) self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 49, b'3' * 48, 3333, b'1' * 20)

View file

@ -1,6 +1,7 @@
import contextlib import contextlib
import typing import typing
import binascii import binascii
import socket
import asyncio import asyncio
from torba.testcase import AsyncioTestCase from torba.testcase import AsyncioTestCase
from tests import dht_mocks from tests import dht_mocks
@ -26,6 +27,7 @@ class TestBlobAnnouncer(AsyncioTestCase):
for node_id, address in peer_addresses: for node_id, address in peer_addresses:
await self.add_peer(node_id, address) await self.add_peer(node_id, address)
self.node.joined.set() self.node.joined.set()
self.node._refresh_task = self.loop.create_task(self.node.refresh_node())
async def add_peer(self, node_id, address, add_to_routing_table=True): async def add_peer(self, node_id, address, add_to_routing_table=True):
n = Node(self.loop, PeerManager(self.loop), node_id, 4444, 4444, 3333, address) n = Node(self.loop, PeerManager(self.loop), node_id, 4444, 4444, 3333, address)
@ -113,3 +115,56 @@ class TestBlobAnnouncer(AsyncioTestCase):
self.assertEqual(self.node.protocol.node_id, found_peers[0].node_id) self.assertEqual(self.node.protocol.node_id, found_peers[0].node_id)
self.assertEqual(self.node.protocol.external_ip, found_peers[0].address) self.assertEqual(self.node.protocol.external_ip, found_peers[0].address)
self.assertEqual(self.node.protocol.peer_port, found_peers[0].tcp_port) self.assertEqual(self.node.protocol.peer_port, found_peers[0].tcp_port)
async def test_popular_blob(self):
peer_count = 150
addresses = [
(constants.generate_id(i + 1), socket.inet_ntoa(int(i + 1).to_bytes(length=4, byteorder='big')))
for i in range(peer_count)
]
blob_hash = b'1' * 48
async with self._test_network_context(peer_addresses=addresses):
total_seen = set()
announced_to = self.nodes[0]
for i in range(1, peer_count):
node = self.nodes[i]
kad_peer = announced_to.protocol.peer_manager.get_kademlia_peer(
node.protocol.node_id, node.protocol.external_ip, node.protocol.udp_port
)
await announced_to.protocol._add_peer(kad_peer)
peer = node.protocol.get_rpc_peer(
node.protocol.peer_manager.get_kademlia_peer(
announced_to.protocol.node_id,
announced_to.protocol.external_ip,
announced_to.protocol.udp_port
)
)
response = await peer.store(blob_hash)
self.assertEqual(response, b'OK')
peers_for_blob = await peer.find_value(blob_hash, 0)
if i == 1:
self.assertTrue(blob_hash not in peers_for_blob)
self.assertEqual(peers_for_blob[b'p'], 0)
else:
self.assertEqual(len(peers_for_blob[blob_hash]), min(i - 1, constants.k))
self.assertEqual(len(announced_to.protocol.data_store.get_peers_for_blob(blob_hash)), i)
if i - 1 > constants.k:
self.assertEqual(len(peers_for_blob[b'contacts']), constants.k)
self.assertEqual(peers_for_blob[b'p'], ((i - 1) // (constants.k + 1)) + 1)
seen = set(peers_for_blob[blob_hash])
self.assertEqual(len(seen), constants.k)
self.assertEqual(len(peers_for_blob[blob_hash]), len(seen))
for pg in range(1, peers_for_blob[b'p']):
page_x = await peer.find_value(blob_hash, pg)
self.assertNotIn(b'contacts', page_x)
page_x_set = set(page_x[blob_hash])
self.assertEqual(len(page_x[blob_hash]), len(page_x_set))
self.assertTrue(len(page_x_set) > 0)
self.assertSetEqual(seen.intersection(page_x_set), set())
seen.intersection_update(page_x_set)
total_seen.update(page_x_set)
else:
self.assertEqual(len(peers_for_blob[b'contacts']), i - 1)
self.assertEqual(len(total_seen), peer_count - 2)