Merge pull request #2247 from lbryio/fix-unable-to-serialize-response

paginated deterministically shuffled find_value
This commit is contained in:
Jack Robison 2019-06-18 23:25:12 -04:00 committed by GitHub
commit bfeed40bfe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 129 additions and 36 deletions

View file

@ -24,7 +24,7 @@ persistent=yes
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=1
jobs=4
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
@ -77,8 +77,6 @@ disable=
dangerous-default-value,
duplicate-code,
fixme,
global-statement,
inherit-non-class,
invalid-name,
len-as-condition,
locally-disabled,
@ -111,13 +109,11 @@ disable=
unnecessary-lambda,
unused-argument,
unused-variable,
wildcard-import,
wrong-import-order,
wrong-import-position,
deprecated-lambda,
simplifiable-if-statement,
unidiomatic-typecheck,
global-at-module-level,
inconsistent-return-statements,
keyword-arg-before-vararg,
assignment-from-no-return,

View file

@ -18,11 +18,10 @@ data_expiration = 86400 # 24 hours
token_secret_refresh_interval = 300 # 5 minutes
maybe_ping_delay = 300 # 5 minutes
check_refresh_interval = refresh_interval / 5
max_datagram_size = 8192 # 8 KB
rpc_id_length = 20
protocol_version = 1
bottom_out_limit = 3
msg_size_limit = max_datagram_size - 26
msg_size_limit = 1400
def digest(data: bytes) -> bytes:

View file

@ -5,12 +5,13 @@ import hashlib
import asyncio
import typing
import binascii
import random
from asyncio.protocols import DatagramProtocol
from asyncio.transports import DatagramTransport
from lbrynet.dht import constants
from lbrynet.dht.serialization.datagram import decode_datagram, ErrorDatagram, ResponseDatagram, RequestDatagram
from lbrynet.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE
from lbrynet.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE, PAGE_KEY
from lbrynet.dht.error import RemoteException, TransportNotConnected
from lbrynet.dht.protocol.routing_table import TreeRoutingTable
from lbrynet.dht.protocol.data_store import DictDataStore
@ -67,19 +68,23 @@ class KademliaRPC:
contacts = self.protocol.routing_table.find_close_peers(key, sender_node_id=rpc_contact.node_id)
contact_triples = []
for contact in contacts:
for contact in contacts[:constants.k * 2]:
contact_triples.append((contact.node_id, contact.address, contact.udp_port))
return contact_triples
def find_value(self, rpc_contact: 'KademliaPeer', key: bytes):
def find_value(self, rpc_contact: 'KademliaPeer', key: bytes, page: int = 0):
page = page if page > 0 else 0
if len(key) != constants.hash_length:
raise ValueError("invalid blob_exchange hash length: %i" % len(key))
response = {
b'token': self.make_token(rpc_contact.compact_ip()),
b'contacts': self.find_node(rpc_contact, key)
}
if not page:
response[b'contacts'] = self.find_node(rpc_contact, key)[:constants.k]
if self.protocol.protocol_version:
response[b'protocolVersion'] = self.protocol.protocol_version
@ -92,8 +97,14 @@ class KademliaRPC:
# if we don't have k storing peers to return and we have this hash locally, include our contact information
if len(peers) < constants.k and binascii.hexlify(key).decode() in self.protocol.data_store.completed_blobs:
peers.append(self.compact_address())
if peers:
response[key] = peers
if not peers:
response[PAGE_KEY] = 0
else:
response[PAGE_KEY] = (len(peers) // (constants.k + 1)) + 1 # how many pages of peers we have for the blob
if len(peers) > constants.k:
random.Random(self.protocol.node_id).shuffle(peers)
if page * constants.k < len(peers):
response[key] = peers[page * constants.k:page * constants.k + constants.k]
return response
def refresh_token(self): # TODO: this needs to be called periodically
@ -166,7 +177,7 @@ class RemoteKademliaRPC:
)
return [(node_id, address.decode(), udp_port) for node_id, address, udp_port in response.response]
async def find_value(self, key: bytes) -> typing.Union[typing.Dict]:
async def find_value(self, key: bytes, page: int = 0) -> typing.Union[typing.Dict]:
"""
:return: {
b'token': <token bytes>,
@ -177,7 +188,7 @@ class RemoteKademliaRPC:
if len(key) != constants.hash_bits // 8:
raise ValueError(f"invalid length of find value key: {len(key)}")
response = await self.protocol.send_request(
self.peer, RequestDatagram.make_find_value(self.protocol.node_id, key)
self.peer, RequestDatagram.make_find_value(self.protocol.node_id, key, page=page)
)
self.peer_tracker.update_token(self.peer.node_id, response.response[b'token'])
return response.response
@ -406,24 +417,25 @@ class KademliaProtocol(DatagramProtocol):
raise AttributeError('Invalid method: %s' % message.method.decode())
if message.args and isinstance(message.args[-1], dict) and b'protocolVersion' in message.args[-1]:
# args don't need reformatting
a, kw = tuple(message.args[:-1]), message.args[-1]
args, kw = tuple(message.args[:-1]), message.args[-1]
else:
a, kw = self._migrate_incoming_rpc_args(sender_contact, message.method, *message.args)
args, kw = self._migrate_incoming_rpc_args(sender_contact, message.method, *message.args)
log.debug("%s:%i RECV CALL %s %s:%i", self.external_ip, self.udp_port, message.method.decode(),
sender_contact.address, sender_contact.udp_port)
if method == b'ping':
result = self.node_rpc.ping()
elif method == b'store':
blob_hash, token, port, original_publisher_id, age = a
blob_hash, token, port, original_publisher_id, age = args[:5]
result = self.node_rpc.store(sender_contact, blob_hash, token, port)
elif method == b'findNode':
key, = a
else:
key = args[0]
page = kw.get(PAGE_KEY, 0)
if method == b'findNode':
result = self.node_rpc.find_node(sender_contact, key)
else:
assert method == b'findValue'
key, = a
result = self.node_rpc.find_value(sender_contact, key)
result = self.node_rpc.find_value(sender_contact, key, page)
self.send_response(
sender_contact, ResponseDatagram(RESPONSE_TYPE, message.rpc_id, self.node_id, result),
@ -569,15 +581,18 @@ class KademliaProtocol(DatagramProtocol):
def send_error(self, peer: 'KademliaPeer', error: ErrorDatagram):
self._send(peer, error)
def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram,
ErrorDatagram]):
def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]):
if not self.transport or self.transport.is_closing():
raise TransportNotConnected()
data = message.bencode()
if len(data) > constants.msg_size_limit:
log.exception("unexpected: %i vs %i", len(data), constants.msg_size_limit)
raise ValueError()
log.warning("cannot send datagram larger than %i bytes (packet is %i bytes)",
constants.msg_size_limit, len(data))
log.debug("Packet is too large to send: %s", binascii.hexlify(data[:3500]).decode())
raise ValueError(
f"cannot send datagram larger than {constants.msg_size_limit} bytes (packet is {len(data)} bytes)"
)
if isinstance(message, (RequestDatagram, ResponseDatagram)):
assert message.node_id == self.node_id, message
if isinstance(message, RequestDatagram):
@ -642,12 +657,12 @@ class KademliaProtocol(DatagramProtocol):
log.debug("Timeout while storing blob_hash %s at %s", binascii.hexlify(hash_value).decode()[:8], peer)
except ValueError as err:
log.error("Unexpected response: %s" % err)
except Exception as err:
except RemoteException as err:
if 'Invalid token' in str(err):
self.peer_manager.clear_token(peer.node_id)
try:
return await __store()
except:
except (ValueError, asyncio.TimeoutError, RemoteException):
return peer.node_id, False
else:
log.exception("Unexpected error while storing blob_hash")

View file

@ -63,7 +63,7 @@ def bencode(data: typing.Dict) -> bytes:
def bdecode(data: bytes, allow_non_dict_return: typing.Optional[bool] = False) -> typing.Dict:
assert type(data) == bytes, DecodeError(f"invalid data type: {str(type(data))}")
assert isinstance(data, bytes), DecodeError(f"invalid data type: {str(type(data))}")
if len(data) == 0:
raise DecodeError('Cannot decode empty string')

View file

@ -7,6 +7,9 @@ REQUEST_TYPE = 0
RESPONSE_TYPE = 1
ERROR_TYPE = 2
# bencode representation of argument keys
PAGE_KEY = b'p'
class KademliaDatagramBase:
"""
@ -91,11 +94,13 @@ class RequestDatagram(KademliaDatagramBase):
@classmethod
def make_find_value(cls, from_node_id: bytes, key: bytes,
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
rpc_id: typing.Optional[bytes] = None, page: int = 0) -> 'RequestDatagram':
rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
if len(key) != constants.hash_bits // 8:
raise ValueError(f"invalid key length: {len(key)}")
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key])
if page < 0:
raise ValueError(f"cannot request a negative page ({page})")
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key, {PAGE_KEY: page}])
class ResponseDatagram(KademliaDatagramBase):

View file

@ -83,12 +83,15 @@ class TestProtocol(AsyncioTestCase):
find_value_response = peer1.node_rpc.find_value(peer3, b'2' * 48)
self.assertEqual(len(find_value_response[b'contacts']), 0)
self.assertSetEqual(
{b'2' * 48, b'token', b'protocolVersion', b'contacts'}, set(find_value_response.keys())
{b'2' * 48, b'token', b'protocolVersion', b'contacts', b'p'}, set(find_value_response.keys())
)
self.assertEqual(2, len(find_value_response[b'2' * 48]))
self.assertEqual(find_value_response[b'2' * 48][0], peer2_from_peer1.compact_address_tcp())
self.assertDictEqual(bdecode(bencode(find_value_response)), find_value_response)
find_value_page_above_pages_response = peer1.node_rpc.find_value(peer3, b'2' * 48, page=10)
self.assertNotIn(b'2' * 48, find_value_page_above_pages_response)
peer1.stop()
peer2.stop()
peer1.disconnect()

View file

@ -59,17 +59,28 @@ class TestDatagram(unittest.TestCase):
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 49, b'2' * 48, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 49, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 21)
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 20, -1)
self.assertEqual(20, len(RequestDatagram.make_find_value(b'1' * 48, b'2' * 48).rpc_id))
# default page argument
serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20).bencode()
decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20)
self.assertEqual(decoded.node_id, b'1' * 48)
self.assertEqual(decoded.method, b'findValue')
self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1}])
self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1, b'p': 0}])
def test_find_value_response(self):
# nondefault page argument
serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20, 1).bencode()
decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20)
self.assertEqual(decoded.node_id, b'1' * 48)
self.assertEqual(decoded.method, b'findValue')
self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1, b'p': 1}])
def test_find_value_response_without_pages_field(self):
found_value_response = {b'2' * 48: [b'\x7f\x00\x00\x01']}
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, found_value_response).bencode()
decoded = decode_datagram(serialized)
@ -78,6 +89,15 @@ class TestDatagram(unittest.TestCase):
self.assertEqual(decoded.node_id, b'1' * 48)
self.assertDictEqual(decoded.response, found_value_response)
def test_find_value_response_with_pages_field(self):
found_value_response = {b'2' * 48: [b'\x7f\x00\x00\x01'], b'p': 1}
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, found_value_response).bencode()
decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, RESPONSE_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20)
self.assertEqual(decoded.node_id, b'1' * 48)
self.assertDictEqual(decoded.response, found_value_response)
def test_store_request(self):
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 47, b'2' * 48, b'3' * 48, 3333, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 49, b'3' * 48, 3333, b'1' * 20)

View file

@ -1,6 +1,7 @@
import contextlib
import typing
import binascii
import socket
import asyncio
from torba.testcase import AsyncioTestCase
from tests import dht_mocks
@ -26,6 +27,7 @@ class TestBlobAnnouncer(AsyncioTestCase):
for node_id, address in peer_addresses:
await self.add_peer(node_id, address)
self.node.joined.set()
self.node._refresh_task = self.loop.create_task(self.node.refresh_node())
async def add_peer(self, node_id, address, add_to_routing_table=True):
n = Node(self.loop, PeerManager(self.loop), node_id, 4444, 4444, 3333, address)
@ -113,3 +115,56 @@ class TestBlobAnnouncer(AsyncioTestCase):
self.assertEqual(self.node.protocol.node_id, found_peers[0].node_id)
self.assertEqual(self.node.protocol.external_ip, found_peers[0].address)
self.assertEqual(self.node.protocol.peer_port, found_peers[0].tcp_port)
async def test_popular_blob(self):
peer_count = 150
addresses = [
(constants.generate_id(i + 1), socket.inet_ntoa(int(i + 1).to_bytes(length=4, byteorder='big')))
for i in range(peer_count)
]
blob_hash = b'1' * 48
async with self._test_network_context(peer_addresses=addresses):
total_seen = set()
announced_to = self.nodes[0]
for i in range(1, peer_count):
node = self.nodes[i]
kad_peer = announced_to.protocol.peer_manager.get_kademlia_peer(
node.protocol.node_id, node.protocol.external_ip, node.protocol.udp_port
)
await announced_to.protocol._add_peer(kad_peer)
peer = node.protocol.get_rpc_peer(
node.protocol.peer_manager.get_kademlia_peer(
announced_to.protocol.node_id,
announced_to.protocol.external_ip,
announced_to.protocol.udp_port
)
)
response = await peer.store(blob_hash)
self.assertEqual(response, b'OK')
peers_for_blob = await peer.find_value(blob_hash, 0)
if i == 1:
self.assertTrue(blob_hash not in peers_for_blob)
self.assertEqual(peers_for_blob[b'p'], 0)
else:
self.assertEqual(len(peers_for_blob[blob_hash]), min(i - 1, constants.k))
self.assertEqual(len(announced_to.protocol.data_store.get_peers_for_blob(blob_hash)), i)
if i - 1 > constants.k:
self.assertEqual(len(peers_for_blob[b'contacts']), constants.k)
self.assertEqual(peers_for_blob[b'p'], ((i - 1) // (constants.k + 1)) + 1)
seen = set(peers_for_blob[blob_hash])
self.assertEqual(len(seen), constants.k)
self.assertEqual(len(peers_for_blob[blob_hash]), len(seen))
for pg in range(1, peers_for_blob[b'p']):
page_x = await peer.find_value(blob_hash, pg)
self.assertNotIn(b'contacts', page_x)
page_x_set = set(page_x[blob_hash])
self.assertEqual(len(page_x[blob_hash]), len(page_x_set))
self.assertTrue(len(page_x_set) > 0)
self.assertSetEqual(seen.intersection(page_x_set), set())
seen.intersection_update(page_x_set)
total_seen.update(page_x_set)
else:
self.assertEqual(len(peers_for_blob[b'contacts']), i - 1)
self.assertEqual(len(total_seen), peer_count - 2)