forked from LBRYCommunity/lbry-sdk
Merge pull request #2247 from lbryio/fix-unable-to-serialize-response
paginated deterministically shuffled find_value
This commit is contained in:
commit
bfeed40bfe
8 changed files with 129 additions and 36 deletions
|
@ -24,7 +24,7 @@ persistent=yes
|
|||
load-plugins=
|
||||
|
||||
# Use multiple processes to speed up Pylint.
|
||||
jobs=1
|
||||
jobs=4
|
||||
|
||||
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||
# active Python interpreter and may run arbitrary code.
|
||||
|
@ -77,8 +77,6 @@ disable=
|
|||
dangerous-default-value,
|
||||
duplicate-code,
|
||||
fixme,
|
||||
global-statement,
|
||||
inherit-non-class,
|
||||
invalid-name,
|
||||
len-as-condition,
|
||||
locally-disabled,
|
||||
|
@ -111,13 +109,11 @@ disable=
|
|||
unnecessary-lambda,
|
||||
unused-argument,
|
||||
unused-variable,
|
||||
wildcard-import,
|
||||
wrong-import-order,
|
||||
wrong-import-position,
|
||||
deprecated-lambda,
|
||||
simplifiable-if-statement,
|
||||
unidiomatic-typecheck,
|
||||
global-at-module-level,
|
||||
inconsistent-return-statements,
|
||||
keyword-arg-before-vararg,
|
||||
assignment-from-no-return,
|
||||
|
|
|
@ -18,11 +18,10 @@ data_expiration = 86400 # 24 hours
|
|||
token_secret_refresh_interval = 300 # 5 minutes
|
||||
maybe_ping_delay = 300 # 5 minutes
|
||||
check_refresh_interval = refresh_interval / 5
|
||||
max_datagram_size = 8192 # 8 KB
|
||||
rpc_id_length = 20
|
||||
protocol_version = 1
|
||||
bottom_out_limit = 3
|
||||
msg_size_limit = max_datagram_size - 26
|
||||
msg_size_limit = 1400
|
||||
|
||||
|
||||
def digest(data: bytes) -> bytes:
|
||||
|
|
|
@ -5,12 +5,13 @@ import hashlib
|
|||
import asyncio
|
||||
import typing
|
||||
import binascii
|
||||
import random
|
||||
from asyncio.protocols import DatagramProtocol
|
||||
from asyncio.transports import DatagramTransport
|
||||
|
||||
from lbrynet.dht import constants
|
||||
from lbrynet.dht.serialization.datagram import decode_datagram, ErrorDatagram, ResponseDatagram, RequestDatagram
|
||||
from lbrynet.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE
|
||||
from lbrynet.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE, PAGE_KEY
|
||||
from lbrynet.dht.error import RemoteException, TransportNotConnected
|
||||
from lbrynet.dht.protocol.routing_table import TreeRoutingTable
|
||||
from lbrynet.dht.protocol.data_store import DictDataStore
|
||||
|
@ -67,19 +68,23 @@ class KademliaRPC:
|
|||
|
||||
contacts = self.protocol.routing_table.find_close_peers(key, sender_node_id=rpc_contact.node_id)
|
||||
contact_triples = []
|
||||
for contact in contacts:
|
||||
for contact in contacts[:constants.k * 2]:
|
||||
contact_triples.append((contact.node_id, contact.address, contact.udp_port))
|
||||
return contact_triples
|
||||
|
||||
def find_value(self, rpc_contact: 'KademliaPeer', key: bytes):
|
||||
def find_value(self, rpc_contact: 'KademliaPeer', key: bytes, page: int = 0):
|
||||
page = page if page > 0 else 0
|
||||
|
||||
if len(key) != constants.hash_length:
|
||||
raise ValueError("invalid blob_exchange hash length: %i" % len(key))
|
||||
|
||||
response = {
|
||||
b'token': self.make_token(rpc_contact.compact_ip()),
|
||||
b'contacts': self.find_node(rpc_contact, key)
|
||||
}
|
||||
|
||||
if not page:
|
||||
response[b'contacts'] = self.find_node(rpc_contact, key)[:constants.k]
|
||||
|
||||
if self.protocol.protocol_version:
|
||||
response[b'protocolVersion'] = self.protocol.protocol_version
|
||||
|
||||
|
@ -92,8 +97,14 @@ class KademliaRPC:
|
|||
# if we don't have k storing peers to return and we have this hash locally, include our contact information
|
||||
if len(peers) < constants.k and binascii.hexlify(key).decode() in self.protocol.data_store.completed_blobs:
|
||||
peers.append(self.compact_address())
|
||||
if peers:
|
||||
response[key] = peers
|
||||
if not peers:
|
||||
response[PAGE_KEY] = 0
|
||||
else:
|
||||
response[PAGE_KEY] = (len(peers) // (constants.k + 1)) + 1 # how many pages of peers we have for the blob
|
||||
if len(peers) > constants.k:
|
||||
random.Random(self.protocol.node_id).shuffle(peers)
|
||||
if page * constants.k < len(peers):
|
||||
response[key] = peers[page * constants.k:page * constants.k + constants.k]
|
||||
return response
|
||||
|
||||
def refresh_token(self): # TODO: this needs to be called periodically
|
||||
|
@ -166,7 +177,7 @@ class RemoteKademliaRPC:
|
|||
)
|
||||
return [(node_id, address.decode(), udp_port) for node_id, address, udp_port in response.response]
|
||||
|
||||
async def find_value(self, key: bytes) -> typing.Union[typing.Dict]:
|
||||
async def find_value(self, key: bytes, page: int = 0) -> typing.Union[typing.Dict]:
|
||||
"""
|
||||
:return: {
|
||||
b'token': <token bytes>,
|
||||
|
@ -177,7 +188,7 @@ class RemoteKademliaRPC:
|
|||
if len(key) != constants.hash_bits // 8:
|
||||
raise ValueError(f"invalid length of find value key: {len(key)}")
|
||||
response = await self.protocol.send_request(
|
||||
self.peer, RequestDatagram.make_find_value(self.protocol.node_id, key)
|
||||
self.peer, RequestDatagram.make_find_value(self.protocol.node_id, key, page=page)
|
||||
)
|
||||
self.peer_tracker.update_token(self.peer.node_id, response.response[b'token'])
|
||||
return response.response
|
||||
|
@ -406,24 +417,25 @@ class KademliaProtocol(DatagramProtocol):
|
|||
raise AttributeError('Invalid method: %s' % message.method.decode())
|
||||
if message.args and isinstance(message.args[-1], dict) and b'protocolVersion' in message.args[-1]:
|
||||
# args don't need reformatting
|
||||
a, kw = tuple(message.args[:-1]), message.args[-1]
|
||||
args, kw = tuple(message.args[:-1]), message.args[-1]
|
||||
else:
|
||||
a, kw = self._migrate_incoming_rpc_args(sender_contact, message.method, *message.args)
|
||||
args, kw = self._migrate_incoming_rpc_args(sender_contact, message.method, *message.args)
|
||||
log.debug("%s:%i RECV CALL %s %s:%i", self.external_ip, self.udp_port, message.method.decode(),
|
||||
sender_contact.address, sender_contact.udp_port)
|
||||
|
||||
if method == b'ping':
|
||||
result = self.node_rpc.ping()
|
||||
elif method == b'store':
|
||||
blob_hash, token, port, original_publisher_id, age = a
|
||||
blob_hash, token, port, original_publisher_id, age = args[:5]
|
||||
result = self.node_rpc.store(sender_contact, blob_hash, token, port)
|
||||
elif method == b'findNode':
|
||||
key, = a
|
||||
else:
|
||||
key = args[0]
|
||||
page = kw.get(PAGE_KEY, 0)
|
||||
if method == b'findNode':
|
||||
result = self.node_rpc.find_node(sender_contact, key)
|
||||
else:
|
||||
assert method == b'findValue'
|
||||
key, = a
|
||||
result = self.node_rpc.find_value(sender_contact, key)
|
||||
result = self.node_rpc.find_value(sender_contact, key, page)
|
||||
|
||||
self.send_response(
|
||||
sender_contact, ResponseDatagram(RESPONSE_TYPE, message.rpc_id, self.node_id, result),
|
||||
|
@ -569,15 +581,18 @@ class KademliaProtocol(DatagramProtocol):
|
|||
def send_error(self, peer: 'KademliaPeer', error: ErrorDatagram):
|
||||
self._send(peer, error)
|
||||
|
||||
def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram,
|
||||
ErrorDatagram]):
|
||||
def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]):
|
||||
if not self.transport or self.transport.is_closing():
|
||||
raise TransportNotConnected()
|
||||
|
||||
data = message.bencode()
|
||||
if len(data) > constants.msg_size_limit:
|
||||
log.exception("unexpected: %i vs %i", len(data), constants.msg_size_limit)
|
||||
raise ValueError()
|
||||
log.warning("cannot send datagram larger than %i bytes (packet is %i bytes)",
|
||||
constants.msg_size_limit, len(data))
|
||||
log.debug("Packet is too large to send: %s", binascii.hexlify(data[:3500]).decode())
|
||||
raise ValueError(
|
||||
f"cannot send datagram larger than {constants.msg_size_limit} bytes (packet is {len(data)} bytes)"
|
||||
)
|
||||
if isinstance(message, (RequestDatagram, ResponseDatagram)):
|
||||
assert message.node_id == self.node_id, message
|
||||
if isinstance(message, RequestDatagram):
|
||||
|
@ -642,12 +657,12 @@ class KademliaProtocol(DatagramProtocol):
|
|||
log.debug("Timeout while storing blob_hash %s at %s", binascii.hexlify(hash_value).decode()[:8], peer)
|
||||
except ValueError as err:
|
||||
log.error("Unexpected response: %s" % err)
|
||||
except Exception as err:
|
||||
except RemoteException as err:
|
||||
if 'Invalid token' in str(err):
|
||||
self.peer_manager.clear_token(peer.node_id)
|
||||
try:
|
||||
return await __store()
|
||||
except:
|
||||
except (ValueError, asyncio.TimeoutError, RemoteException):
|
||||
return peer.node_id, False
|
||||
else:
|
||||
log.exception("Unexpected error while storing blob_hash")
|
||||
|
|
|
@ -63,7 +63,7 @@ def bencode(data: typing.Dict) -> bytes:
|
|||
|
||||
|
||||
def bdecode(data: bytes, allow_non_dict_return: typing.Optional[bool] = False) -> typing.Dict:
|
||||
assert type(data) == bytes, DecodeError(f"invalid data type: {str(type(data))}")
|
||||
assert isinstance(data, bytes), DecodeError(f"invalid data type: {str(type(data))}")
|
||||
|
||||
if len(data) == 0:
|
||||
raise DecodeError('Cannot decode empty string')
|
||||
|
|
|
@ -7,6 +7,9 @@ REQUEST_TYPE = 0
|
|||
RESPONSE_TYPE = 1
|
||||
ERROR_TYPE = 2
|
||||
|
||||
# bencode representation of argument keys
|
||||
PAGE_KEY = b'p'
|
||||
|
||||
|
||||
class KademliaDatagramBase:
|
||||
"""
|
||||
|
@ -91,11 +94,13 @@ class RequestDatagram(KademliaDatagramBase):
|
|||
|
||||
@classmethod
|
||||
def make_find_value(cls, from_node_id: bytes, key: bytes,
|
||||
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
|
||||
rpc_id: typing.Optional[bytes] = None, page: int = 0) -> 'RequestDatagram':
|
||||
rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
|
||||
if len(key) != constants.hash_bits // 8:
|
||||
raise ValueError(f"invalid key length: {len(key)}")
|
||||
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key])
|
||||
if page < 0:
|
||||
raise ValueError(f"cannot request a negative page ({page})")
|
||||
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key, {PAGE_KEY: page}])
|
||||
|
||||
|
||||
class ResponseDatagram(KademliaDatagramBase):
|
||||
|
|
|
@ -83,12 +83,15 @@ class TestProtocol(AsyncioTestCase):
|
|||
find_value_response = peer1.node_rpc.find_value(peer3, b'2' * 48)
|
||||
self.assertEqual(len(find_value_response[b'contacts']), 0)
|
||||
self.assertSetEqual(
|
||||
{b'2' * 48, b'token', b'protocolVersion', b'contacts'}, set(find_value_response.keys())
|
||||
{b'2' * 48, b'token', b'protocolVersion', b'contacts', b'p'}, set(find_value_response.keys())
|
||||
)
|
||||
self.assertEqual(2, len(find_value_response[b'2' * 48]))
|
||||
self.assertEqual(find_value_response[b'2' * 48][0], peer2_from_peer1.compact_address_tcp())
|
||||
self.assertDictEqual(bdecode(bencode(find_value_response)), find_value_response)
|
||||
|
||||
find_value_page_above_pages_response = peer1.node_rpc.find_value(peer3, b'2' * 48, page=10)
|
||||
self.assertNotIn(b'2' * 48, find_value_page_above_pages_response)
|
||||
|
||||
peer1.stop()
|
||||
peer2.stop()
|
||||
peer1.disconnect()
|
||||
|
|
|
@ -59,17 +59,28 @@ class TestDatagram(unittest.TestCase):
|
|||
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 49, b'2' * 48, b'1' * 20)
|
||||
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 49, b'1' * 20)
|
||||
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 21)
|
||||
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 20, -1)
|
||||
self.assertEqual(20, len(RequestDatagram.make_find_value(b'1' * 48, b'2' * 48).rpc_id))
|
||||
|
||||
# default page argument
|
||||
serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20).bencode()
|
||||
decoded = decode_datagram(serialized)
|
||||
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
|
||||
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||
self.assertEqual(decoded.method, b'findValue')
|
||||
self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1}])
|
||||
self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1, b'p': 0}])
|
||||
|
||||
def test_find_value_response(self):
|
||||
# nondefault page argument
|
||||
serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20, 1).bencode()
|
||||
decoded = decode_datagram(serialized)
|
||||
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
|
||||
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||
self.assertEqual(decoded.method, b'findValue')
|
||||
self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1, b'p': 1}])
|
||||
|
||||
def test_find_value_response_without_pages_field(self):
|
||||
found_value_response = {b'2' * 48: [b'\x7f\x00\x00\x01']}
|
||||
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, found_value_response).bencode()
|
||||
decoded = decode_datagram(serialized)
|
||||
|
@ -78,6 +89,15 @@ class TestDatagram(unittest.TestCase):
|
|||
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||
self.assertDictEqual(decoded.response, found_value_response)
|
||||
|
||||
def test_find_value_response_with_pages_field(self):
|
||||
found_value_response = {b'2' * 48: [b'\x7f\x00\x00\x01'], b'p': 1}
|
||||
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, found_value_response).bencode()
|
||||
decoded = decode_datagram(serialized)
|
||||
self.assertEqual(decoded.packet_type, RESPONSE_TYPE)
|
||||
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||
self.assertDictEqual(decoded.response, found_value_response)
|
||||
|
||||
def test_store_request(self):
|
||||
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 47, b'2' * 48, b'3' * 48, 3333, b'1' * 20)
|
||||
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 49, b'3' * 48, 3333, b'1' * 20)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import contextlib
|
||||
import typing
|
||||
import binascii
|
||||
import socket
|
||||
import asyncio
|
||||
from torba.testcase import AsyncioTestCase
|
||||
from tests import dht_mocks
|
||||
|
@ -26,6 +27,7 @@ class TestBlobAnnouncer(AsyncioTestCase):
|
|||
for node_id, address in peer_addresses:
|
||||
await self.add_peer(node_id, address)
|
||||
self.node.joined.set()
|
||||
self.node._refresh_task = self.loop.create_task(self.node.refresh_node())
|
||||
|
||||
async def add_peer(self, node_id, address, add_to_routing_table=True):
|
||||
n = Node(self.loop, PeerManager(self.loop), node_id, 4444, 4444, 3333, address)
|
||||
|
@ -113,3 +115,56 @@ class TestBlobAnnouncer(AsyncioTestCase):
|
|||
self.assertEqual(self.node.protocol.node_id, found_peers[0].node_id)
|
||||
self.assertEqual(self.node.protocol.external_ip, found_peers[0].address)
|
||||
self.assertEqual(self.node.protocol.peer_port, found_peers[0].tcp_port)
|
||||
|
||||
async def test_popular_blob(self):
|
||||
peer_count = 150
|
||||
addresses = [
|
||||
(constants.generate_id(i + 1), socket.inet_ntoa(int(i + 1).to_bytes(length=4, byteorder='big')))
|
||||
for i in range(peer_count)
|
||||
]
|
||||
blob_hash = b'1' * 48
|
||||
|
||||
async with self._test_network_context(peer_addresses=addresses):
|
||||
total_seen = set()
|
||||
announced_to = self.nodes[0]
|
||||
for i in range(1, peer_count):
|
||||
node = self.nodes[i]
|
||||
kad_peer = announced_to.protocol.peer_manager.get_kademlia_peer(
|
||||
node.protocol.node_id, node.protocol.external_ip, node.protocol.udp_port
|
||||
)
|
||||
await announced_to.protocol._add_peer(kad_peer)
|
||||
peer = node.protocol.get_rpc_peer(
|
||||
node.protocol.peer_manager.get_kademlia_peer(
|
||||
announced_to.protocol.node_id,
|
||||
announced_to.protocol.external_ip,
|
||||
announced_to.protocol.udp_port
|
||||
)
|
||||
)
|
||||
response = await peer.store(blob_hash)
|
||||
self.assertEqual(response, b'OK')
|
||||
peers_for_blob = await peer.find_value(blob_hash, 0)
|
||||
if i == 1:
|
||||
self.assertTrue(blob_hash not in peers_for_blob)
|
||||
self.assertEqual(peers_for_blob[b'p'], 0)
|
||||
else:
|
||||
self.assertEqual(len(peers_for_blob[blob_hash]), min(i - 1, constants.k))
|
||||
self.assertEqual(len(announced_to.protocol.data_store.get_peers_for_blob(blob_hash)), i)
|
||||
if i - 1 > constants.k:
|
||||
self.assertEqual(len(peers_for_blob[b'contacts']), constants.k)
|
||||
self.assertEqual(peers_for_blob[b'p'], ((i - 1) // (constants.k + 1)) + 1)
|
||||
seen = set(peers_for_blob[blob_hash])
|
||||
self.assertEqual(len(seen), constants.k)
|
||||
self.assertEqual(len(peers_for_blob[blob_hash]), len(seen))
|
||||
|
||||
for pg in range(1, peers_for_blob[b'p']):
|
||||
page_x = await peer.find_value(blob_hash, pg)
|
||||
self.assertNotIn(b'contacts', page_x)
|
||||
page_x_set = set(page_x[blob_hash])
|
||||
self.assertEqual(len(page_x[blob_hash]), len(page_x_set))
|
||||
self.assertTrue(len(page_x_set) > 0)
|
||||
self.assertSetEqual(seen.intersection(page_x_set), set())
|
||||
seen.intersection_update(page_x_set)
|
||||
total_seen.update(page_x_set)
|
||||
else:
|
||||
self.assertEqual(len(peers_for_blob[b'contacts']), i - 1)
|
||||
self.assertEqual(len(total_seen), peer_count - 2)
|
||||
|
|
Loading…
Reference in a new issue