Merge pull request #2247 from lbryio/fix-unable-to-serialize-response
paginated deterministically shuffled find_value
This commit is contained in:
commit
bfeed40bfe
8 changed files with 129 additions and 36 deletions
|
@ -24,7 +24,7 @@ persistent=yes
|
||||||
load-plugins=
|
load-plugins=
|
||||||
|
|
||||||
# Use multiple processes to speed up Pylint.
|
# Use multiple processes to speed up Pylint.
|
||||||
jobs=1
|
jobs=4
|
||||||
|
|
||||||
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||||
# active Python interpreter and may run arbitrary code.
|
# active Python interpreter and may run arbitrary code.
|
||||||
|
@ -77,8 +77,6 @@ disable=
|
||||||
dangerous-default-value,
|
dangerous-default-value,
|
||||||
duplicate-code,
|
duplicate-code,
|
||||||
fixme,
|
fixme,
|
||||||
global-statement,
|
|
||||||
inherit-non-class,
|
|
||||||
invalid-name,
|
invalid-name,
|
||||||
len-as-condition,
|
len-as-condition,
|
||||||
locally-disabled,
|
locally-disabled,
|
||||||
|
@ -111,13 +109,11 @@ disable=
|
||||||
unnecessary-lambda,
|
unnecessary-lambda,
|
||||||
unused-argument,
|
unused-argument,
|
||||||
unused-variable,
|
unused-variable,
|
||||||
wildcard-import,
|
|
||||||
wrong-import-order,
|
wrong-import-order,
|
||||||
wrong-import-position,
|
wrong-import-position,
|
||||||
deprecated-lambda,
|
deprecated-lambda,
|
||||||
simplifiable-if-statement,
|
simplifiable-if-statement,
|
||||||
unidiomatic-typecheck,
|
unidiomatic-typecheck,
|
||||||
global-at-module-level,
|
|
||||||
inconsistent-return-statements,
|
inconsistent-return-statements,
|
||||||
keyword-arg-before-vararg,
|
keyword-arg-before-vararg,
|
||||||
assignment-from-no-return,
|
assignment-from-no-return,
|
||||||
|
|
|
@ -18,11 +18,10 @@ data_expiration = 86400 # 24 hours
|
||||||
token_secret_refresh_interval = 300 # 5 minutes
|
token_secret_refresh_interval = 300 # 5 minutes
|
||||||
maybe_ping_delay = 300 # 5 minutes
|
maybe_ping_delay = 300 # 5 minutes
|
||||||
check_refresh_interval = refresh_interval / 5
|
check_refresh_interval = refresh_interval / 5
|
||||||
max_datagram_size = 8192 # 8 KB
|
|
||||||
rpc_id_length = 20
|
rpc_id_length = 20
|
||||||
protocol_version = 1
|
protocol_version = 1
|
||||||
bottom_out_limit = 3
|
bottom_out_limit = 3
|
||||||
msg_size_limit = max_datagram_size - 26
|
msg_size_limit = 1400
|
||||||
|
|
||||||
|
|
||||||
def digest(data: bytes) -> bytes:
|
def digest(data: bytes) -> bytes:
|
||||||
|
|
|
@ -5,12 +5,13 @@ import hashlib
|
||||||
import asyncio
|
import asyncio
|
||||||
import typing
|
import typing
|
||||||
import binascii
|
import binascii
|
||||||
|
import random
|
||||||
from asyncio.protocols import DatagramProtocol
|
from asyncio.protocols import DatagramProtocol
|
||||||
from asyncio.transports import DatagramTransport
|
from asyncio.transports import DatagramTransport
|
||||||
|
|
||||||
from lbrynet.dht import constants
|
from lbrynet.dht import constants
|
||||||
from lbrynet.dht.serialization.datagram import decode_datagram, ErrorDatagram, ResponseDatagram, RequestDatagram
|
from lbrynet.dht.serialization.datagram import decode_datagram, ErrorDatagram, ResponseDatagram, RequestDatagram
|
||||||
from lbrynet.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE
|
from lbrynet.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE, PAGE_KEY
|
||||||
from lbrynet.dht.error import RemoteException, TransportNotConnected
|
from lbrynet.dht.error import RemoteException, TransportNotConnected
|
||||||
from lbrynet.dht.protocol.routing_table import TreeRoutingTable
|
from lbrynet.dht.protocol.routing_table import TreeRoutingTable
|
||||||
from lbrynet.dht.protocol.data_store import DictDataStore
|
from lbrynet.dht.protocol.data_store import DictDataStore
|
||||||
|
@ -67,19 +68,23 @@ class KademliaRPC:
|
||||||
|
|
||||||
contacts = self.protocol.routing_table.find_close_peers(key, sender_node_id=rpc_contact.node_id)
|
contacts = self.protocol.routing_table.find_close_peers(key, sender_node_id=rpc_contact.node_id)
|
||||||
contact_triples = []
|
contact_triples = []
|
||||||
for contact in contacts:
|
for contact in contacts[:constants.k * 2]:
|
||||||
contact_triples.append((contact.node_id, contact.address, contact.udp_port))
|
contact_triples.append((contact.node_id, contact.address, contact.udp_port))
|
||||||
return contact_triples
|
return contact_triples
|
||||||
|
|
||||||
def find_value(self, rpc_contact: 'KademliaPeer', key: bytes):
|
def find_value(self, rpc_contact: 'KademliaPeer', key: bytes, page: int = 0):
|
||||||
|
page = page if page > 0 else 0
|
||||||
|
|
||||||
if len(key) != constants.hash_length:
|
if len(key) != constants.hash_length:
|
||||||
raise ValueError("invalid blob_exchange hash length: %i" % len(key))
|
raise ValueError("invalid blob_exchange hash length: %i" % len(key))
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
b'token': self.make_token(rpc_contact.compact_ip()),
|
b'token': self.make_token(rpc_contact.compact_ip()),
|
||||||
b'contacts': self.find_node(rpc_contact, key)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if not page:
|
||||||
|
response[b'contacts'] = self.find_node(rpc_contact, key)[:constants.k]
|
||||||
|
|
||||||
if self.protocol.protocol_version:
|
if self.protocol.protocol_version:
|
||||||
response[b'protocolVersion'] = self.protocol.protocol_version
|
response[b'protocolVersion'] = self.protocol.protocol_version
|
||||||
|
|
||||||
|
@ -92,8 +97,14 @@ class KademliaRPC:
|
||||||
# if we don't have k storing peers to return and we have this hash locally, include our contact information
|
# if we don't have k storing peers to return and we have this hash locally, include our contact information
|
||||||
if len(peers) < constants.k and binascii.hexlify(key).decode() in self.protocol.data_store.completed_blobs:
|
if len(peers) < constants.k and binascii.hexlify(key).decode() in self.protocol.data_store.completed_blobs:
|
||||||
peers.append(self.compact_address())
|
peers.append(self.compact_address())
|
||||||
if peers:
|
if not peers:
|
||||||
response[key] = peers
|
response[PAGE_KEY] = 0
|
||||||
|
else:
|
||||||
|
response[PAGE_KEY] = (len(peers) // (constants.k + 1)) + 1 # how many pages of peers we have for the blob
|
||||||
|
if len(peers) > constants.k:
|
||||||
|
random.Random(self.protocol.node_id).shuffle(peers)
|
||||||
|
if page * constants.k < len(peers):
|
||||||
|
response[key] = peers[page * constants.k:page * constants.k + constants.k]
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def refresh_token(self): # TODO: this needs to be called periodically
|
def refresh_token(self): # TODO: this needs to be called periodically
|
||||||
|
@ -166,7 +177,7 @@ class RemoteKademliaRPC:
|
||||||
)
|
)
|
||||||
return [(node_id, address.decode(), udp_port) for node_id, address, udp_port in response.response]
|
return [(node_id, address.decode(), udp_port) for node_id, address, udp_port in response.response]
|
||||||
|
|
||||||
async def find_value(self, key: bytes) -> typing.Union[typing.Dict]:
|
async def find_value(self, key: bytes, page: int = 0) -> typing.Union[typing.Dict]:
|
||||||
"""
|
"""
|
||||||
:return: {
|
:return: {
|
||||||
b'token': <token bytes>,
|
b'token': <token bytes>,
|
||||||
|
@ -177,7 +188,7 @@ class RemoteKademliaRPC:
|
||||||
if len(key) != constants.hash_bits // 8:
|
if len(key) != constants.hash_bits // 8:
|
||||||
raise ValueError(f"invalid length of find value key: {len(key)}")
|
raise ValueError(f"invalid length of find value key: {len(key)}")
|
||||||
response = await self.protocol.send_request(
|
response = await self.protocol.send_request(
|
||||||
self.peer, RequestDatagram.make_find_value(self.protocol.node_id, key)
|
self.peer, RequestDatagram.make_find_value(self.protocol.node_id, key, page=page)
|
||||||
)
|
)
|
||||||
self.peer_tracker.update_token(self.peer.node_id, response.response[b'token'])
|
self.peer_tracker.update_token(self.peer.node_id, response.response[b'token'])
|
||||||
return response.response
|
return response.response
|
||||||
|
@ -406,24 +417,25 @@ class KademliaProtocol(DatagramProtocol):
|
||||||
raise AttributeError('Invalid method: %s' % message.method.decode())
|
raise AttributeError('Invalid method: %s' % message.method.decode())
|
||||||
if message.args and isinstance(message.args[-1], dict) and b'protocolVersion' in message.args[-1]:
|
if message.args and isinstance(message.args[-1], dict) and b'protocolVersion' in message.args[-1]:
|
||||||
# args don't need reformatting
|
# args don't need reformatting
|
||||||
a, kw = tuple(message.args[:-1]), message.args[-1]
|
args, kw = tuple(message.args[:-1]), message.args[-1]
|
||||||
else:
|
else:
|
||||||
a, kw = self._migrate_incoming_rpc_args(sender_contact, message.method, *message.args)
|
args, kw = self._migrate_incoming_rpc_args(sender_contact, message.method, *message.args)
|
||||||
log.debug("%s:%i RECV CALL %s %s:%i", self.external_ip, self.udp_port, message.method.decode(),
|
log.debug("%s:%i RECV CALL %s %s:%i", self.external_ip, self.udp_port, message.method.decode(),
|
||||||
sender_contact.address, sender_contact.udp_port)
|
sender_contact.address, sender_contact.udp_port)
|
||||||
|
|
||||||
if method == b'ping':
|
if method == b'ping':
|
||||||
result = self.node_rpc.ping()
|
result = self.node_rpc.ping()
|
||||||
elif method == b'store':
|
elif method == b'store':
|
||||||
blob_hash, token, port, original_publisher_id, age = a
|
blob_hash, token, port, original_publisher_id, age = args[:5]
|
||||||
result = self.node_rpc.store(sender_contact, blob_hash, token, port)
|
result = self.node_rpc.store(sender_contact, blob_hash, token, port)
|
||||||
elif method == b'findNode':
|
else:
|
||||||
key, = a
|
key = args[0]
|
||||||
|
page = kw.get(PAGE_KEY, 0)
|
||||||
|
if method == b'findNode':
|
||||||
result = self.node_rpc.find_node(sender_contact, key)
|
result = self.node_rpc.find_node(sender_contact, key)
|
||||||
else:
|
else:
|
||||||
assert method == b'findValue'
|
assert method == b'findValue'
|
||||||
key, = a
|
result = self.node_rpc.find_value(sender_contact, key, page)
|
||||||
result = self.node_rpc.find_value(sender_contact, key)
|
|
||||||
|
|
||||||
self.send_response(
|
self.send_response(
|
||||||
sender_contact, ResponseDatagram(RESPONSE_TYPE, message.rpc_id, self.node_id, result),
|
sender_contact, ResponseDatagram(RESPONSE_TYPE, message.rpc_id, self.node_id, result),
|
||||||
|
@ -569,15 +581,18 @@ class KademliaProtocol(DatagramProtocol):
|
||||||
def send_error(self, peer: 'KademliaPeer', error: ErrorDatagram):
|
def send_error(self, peer: 'KademliaPeer', error: ErrorDatagram):
|
||||||
self._send(peer, error)
|
self._send(peer, error)
|
||||||
|
|
||||||
def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram,
|
def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]):
|
||||||
ErrorDatagram]):
|
|
||||||
if not self.transport or self.transport.is_closing():
|
if not self.transport or self.transport.is_closing():
|
||||||
raise TransportNotConnected()
|
raise TransportNotConnected()
|
||||||
|
|
||||||
data = message.bencode()
|
data = message.bencode()
|
||||||
if len(data) > constants.msg_size_limit:
|
if len(data) > constants.msg_size_limit:
|
||||||
log.exception("unexpected: %i vs %i", len(data), constants.msg_size_limit)
|
log.warning("cannot send datagram larger than %i bytes (packet is %i bytes)",
|
||||||
raise ValueError()
|
constants.msg_size_limit, len(data))
|
||||||
|
log.debug("Packet is too large to send: %s", binascii.hexlify(data[:3500]).decode())
|
||||||
|
raise ValueError(
|
||||||
|
f"cannot send datagram larger than {constants.msg_size_limit} bytes (packet is {len(data)} bytes)"
|
||||||
|
)
|
||||||
if isinstance(message, (RequestDatagram, ResponseDatagram)):
|
if isinstance(message, (RequestDatagram, ResponseDatagram)):
|
||||||
assert message.node_id == self.node_id, message
|
assert message.node_id == self.node_id, message
|
||||||
if isinstance(message, RequestDatagram):
|
if isinstance(message, RequestDatagram):
|
||||||
|
@ -642,12 +657,12 @@ class KademliaProtocol(DatagramProtocol):
|
||||||
log.debug("Timeout while storing blob_hash %s at %s", binascii.hexlify(hash_value).decode()[:8], peer)
|
log.debug("Timeout while storing blob_hash %s at %s", binascii.hexlify(hash_value).decode()[:8], peer)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
log.error("Unexpected response: %s" % err)
|
log.error("Unexpected response: %s" % err)
|
||||||
except Exception as err:
|
except RemoteException as err:
|
||||||
if 'Invalid token' in str(err):
|
if 'Invalid token' in str(err):
|
||||||
self.peer_manager.clear_token(peer.node_id)
|
self.peer_manager.clear_token(peer.node_id)
|
||||||
try:
|
try:
|
||||||
return await __store()
|
return await __store()
|
||||||
except:
|
except (ValueError, asyncio.TimeoutError, RemoteException):
|
||||||
return peer.node_id, False
|
return peer.node_id, False
|
||||||
else:
|
else:
|
||||||
log.exception("Unexpected error while storing blob_hash")
|
log.exception("Unexpected error while storing blob_hash")
|
||||||
|
|
|
@ -63,7 +63,7 @@ def bencode(data: typing.Dict) -> bytes:
|
||||||
|
|
||||||
|
|
||||||
def bdecode(data: bytes, allow_non_dict_return: typing.Optional[bool] = False) -> typing.Dict:
|
def bdecode(data: bytes, allow_non_dict_return: typing.Optional[bool] = False) -> typing.Dict:
|
||||||
assert type(data) == bytes, DecodeError(f"invalid data type: {str(type(data))}")
|
assert isinstance(data, bytes), DecodeError(f"invalid data type: {str(type(data))}")
|
||||||
|
|
||||||
if len(data) == 0:
|
if len(data) == 0:
|
||||||
raise DecodeError('Cannot decode empty string')
|
raise DecodeError('Cannot decode empty string')
|
||||||
|
|
|
@ -7,6 +7,9 @@ REQUEST_TYPE = 0
|
||||||
RESPONSE_TYPE = 1
|
RESPONSE_TYPE = 1
|
||||||
ERROR_TYPE = 2
|
ERROR_TYPE = 2
|
||||||
|
|
||||||
|
# bencode representation of argument keys
|
||||||
|
PAGE_KEY = b'p'
|
||||||
|
|
||||||
|
|
||||||
class KademliaDatagramBase:
|
class KademliaDatagramBase:
|
||||||
"""
|
"""
|
||||||
|
@ -91,11 +94,13 @@ class RequestDatagram(KademliaDatagramBase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def make_find_value(cls, from_node_id: bytes, key: bytes,
|
def make_find_value(cls, from_node_id: bytes, key: bytes,
|
||||||
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
|
rpc_id: typing.Optional[bytes] = None, page: int = 0) -> 'RequestDatagram':
|
||||||
rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
|
rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
|
||||||
if len(key) != constants.hash_bits // 8:
|
if len(key) != constants.hash_bits // 8:
|
||||||
raise ValueError(f"invalid key length: {len(key)}")
|
raise ValueError(f"invalid key length: {len(key)}")
|
||||||
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key])
|
if page < 0:
|
||||||
|
raise ValueError(f"cannot request a negative page ({page})")
|
||||||
|
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key, {PAGE_KEY: page}])
|
||||||
|
|
||||||
|
|
||||||
class ResponseDatagram(KademliaDatagramBase):
|
class ResponseDatagram(KademliaDatagramBase):
|
||||||
|
|
|
@ -83,12 +83,15 @@ class TestProtocol(AsyncioTestCase):
|
||||||
find_value_response = peer1.node_rpc.find_value(peer3, b'2' * 48)
|
find_value_response = peer1.node_rpc.find_value(peer3, b'2' * 48)
|
||||||
self.assertEqual(len(find_value_response[b'contacts']), 0)
|
self.assertEqual(len(find_value_response[b'contacts']), 0)
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{b'2' * 48, b'token', b'protocolVersion', b'contacts'}, set(find_value_response.keys())
|
{b'2' * 48, b'token', b'protocolVersion', b'contacts', b'p'}, set(find_value_response.keys())
|
||||||
)
|
)
|
||||||
self.assertEqual(2, len(find_value_response[b'2' * 48]))
|
self.assertEqual(2, len(find_value_response[b'2' * 48]))
|
||||||
self.assertEqual(find_value_response[b'2' * 48][0], peer2_from_peer1.compact_address_tcp())
|
self.assertEqual(find_value_response[b'2' * 48][0], peer2_from_peer1.compact_address_tcp())
|
||||||
self.assertDictEqual(bdecode(bencode(find_value_response)), find_value_response)
|
self.assertDictEqual(bdecode(bencode(find_value_response)), find_value_response)
|
||||||
|
|
||||||
|
find_value_page_above_pages_response = peer1.node_rpc.find_value(peer3, b'2' * 48, page=10)
|
||||||
|
self.assertNotIn(b'2' * 48, find_value_page_above_pages_response)
|
||||||
|
|
||||||
peer1.stop()
|
peer1.stop()
|
||||||
peer2.stop()
|
peer2.stop()
|
||||||
peer1.disconnect()
|
peer1.disconnect()
|
||||||
|
|
|
@ -59,17 +59,28 @@ class TestDatagram(unittest.TestCase):
|
||||||
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 49, b'2' * 48, b'1' * 20)
|
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 49, b'2' * 48, b'1' * 20)
|
||||||
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 49, b'1' * 20)
|
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 49, b'1' * 20)
|
||||||
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 21)
|
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 21)
|
||||||
|
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 20, -1)
|
||||||
self.assertEqual(20, len(RequestDatagram.make_find_value(b'1' * 48, b'2' * 48).rpc_id))
|
self.assertEqual(20, len(RequestDatagram.make_find_value(b'1' * 48, b'2' * 48).rpc_id))
|
||||||
|
|
||||||
|
# default page argument
|
||||||
serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20).bencode()
|
serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20).bencode()
|
||||||
decoded = decode_datagram(serialized)
|
decoded = decode_datagram(serialized)
|
||||||
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
|
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
|
||||||
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||||
self.assertEqual(decoded.node_id, b'1' * 48)
|
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||||
self.assertEqual(decoded.method, b'findValue')
|
self.assertEqual(decoded.method, b'findValue')
|
||||||
self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1}])
|
self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1, b'p': 0}])
|
||||||
|
|
||||||
def test_find_value_response(self):
|
# nondefault page argument
|
||||||
|
serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20, 1).bencode()
|
||||||
|
decoded = decode_datagram(serialized)
|
||||||
|
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
|
||||||
|
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||||
|
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||||
|
self.assertEqual(decoded.method, b'findValue')
|
||||||
|
self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1, b'p': 1}])
|
||||||
|
|
||||||
|
def test_find_value_response_without_pages_field(self):
|
||||||
found_value_response = {b'2' * 48: [b'\x7f\x00\x00\x01']}
|
found_value_response = {b'2' * 48: [b'\x7f\x00\x00\x01']}
|
||||||
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, found_value_response).bencode()
|
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, found_value_response).bencode()
|
||||||
decoded = decode_datagram(serialized)
|
decoded = decode_datagram(serialized)
|
||||||
|
@ -78,6 +89,15 @@ class TestDatagram(unittest.TestCase):
|
||||||
self.assertEqual(decoded.node_id, b'1' * 48)
|
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||||
self.assertDictEqual(decoded.response, found_value_response)
|
self.assertDictEqual(decoded.response, found_value_response)
|
||||||
|
|
||||||
|
def test_find_value_response_with_pages_field(self):
|
||||||
|
found_value_response = {b'2' * 48: [b'\x7f\x00\x00\x01'], b'p': 1}
|
||||||
|
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, found_value_response).bencode()
|
||||||
|
decoded = decode_datagram(serialized)
|
||||||
|
self.assertEqual(decoded.packet_type, RESPONSE_TYPE)
|
||||||
|
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||||
|
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||||
|
self.assertDictEqual(decoded.response, found_value_response)
|
||||||
|
|
||||||
def test_store_request(self):
|
def test_store_request(self):
|
||||||
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 47, b'2' * 48, b'3' * 48, 3333, b'1' * 20)
|
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 47, b'2' * 48, b'3' * 48, 3333, b'1' * 20)
|
||||||
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 49, b'3' * 48, 3333, b'1' * 20)
|
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 49, b'3' * 48, 3333, b'1' * 20)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import contextlib
|
import contextlib
|
||||||
import typing
|
import typing
|
||||||
import binascii
|
import binascii
|
||||||
|
import socket
|
||||||
import asyncio
|
import asyncio
|
||||||
from torba.testcase import AsyncioTestCase
|
from torba.testcase import AsyncioTestCase
|
||||||
from tests import dht_mocks
|
from tests import dht_mocks
|
||||||
|
@ -26,6 +27,7 @@ class TestBlobAnnouncer(AsyncioTestCase):
|
||||||
for node_id, address in peer_addresses:
|
for node_id, address in peer_addresses:
|
||||||
await self.add_peer(node_id, address)
|
await self.add_peer(node_id, address)
|
||||||
self.node.joined.set()
|
self.node.joined.set()
|
||||||
|
self.node._refresh_task = self.loop.create_task(self.node.refresh_node())
|
||||||
|
|
||||||
async def add_peer(self, node_id, address, add_to_routing_table=True):
|
async def add_peer(self, node_id, address, add_to_routing_table=True):
|
||||||
n = Node(self.loop, PeerManager(self.loop), node_id, 4444, 4444, 3333, address)
|
n = Node(self.loop, PeerManager(self.loop), node_id, 4444, 4444, 3333, address)
|
||||||
|
@ -113,3 +115,56 @@ class TestBlobAnnouncer(AsyncioTestCase):
|
||||||
self.assertEqual(self.node.protocol.node_id, found_peers[0].node_id)
|
self.assertEqual(self.node.protocol.node_id, found_peers[0].node_id)
|
||||||
self.assertEqual(self.node.protocol.external_ip, found_peers[0].address)
|
self.assertEqual(self.node.protocol.external_ip, found_peers[0].address)
|
||||||
self.assertEqual(self.node.protocol.peer_port, found_peers[0].tcp_port)
|
self.assertEqual(self.node.protocol.peer_port, found_peers[0].tcp_port)
|
||||||
|
|
||||||
|
async def test_popular_blob(self):
|
||||||
|
peer_count = 150
|
||||||
|
addresses = [
|
||||||
|
(constants.generate_id(i + 1), socket.inet_ntoa(int(i + 1).to_bytes(length=4, byteorder='big')))
|
||||||
|
for i in range(peer_count)
|
||||||
|
]
|
||||||
|
blob_hash = b'1' * 48
|
||||||
|
|
||||||
|
async with self._test_network_context(peer_addresses=addresses):
|
||||||
|
total_seen = set()
|
||||||
|
announced_to = self.nodes[0]
|
||||||
|
for i in range(1, peer_count):
|
||||||
|
node = self.nodes[i]
|
||||||
|
kad_peer = announced_to.protocol.peer_manager.get_kademlia_peer(
|
||||||
|
node.protocol.node_id, node.protocol.external_ip, node.protocol.udp_port
|
||||||
|
)
|
||||||
|
await announced_to.protocol._add_peer(kad_peer)
|
||||||
|
peer = node.protocol.get_rpc_peer(
|
||||||
|
node.protocol.peer_manager.get_kademlia_peer(
|
||||||
|
announced_to.protocol.node_id,
|
||||||
|
announced_to.protocol.external_ip,
|
||||||
|
announced_to.protocol.udp_port
|
||||||
|
)
|
||||||
|
)
|
||||||
|
response = await peer.store(blob_hash)
|
||||||
|
self.assertEqual(response, b'OK')
|
||||||
|
peers_for_blob = await peer.find_value(blob_hash, 0)
|
||||||
|
if i == 1:
|
||||||
|
self.assertTrue(blob_hash not in peers_for_blob)
|
||||||
|
self.assertEqual(peers_for_blob[b'p'], 0)
|
||||||
|
else:
|
||||||
|
self.assertEqual(len(peers_for_blob[blob_hash]), min(i - 1, constants.k))
|
||||||
|
self.assertEqual(len(announced_to.protocol.data_store.get_peers_for_blob(blob_hash)), i)
|
||||||
|
if i - 1 > constants.k:
|
||||||
|
self.assertEqual(len(peers_for_blob[b'contacts']), constants.k)
|
||||||
|
self.assertEqual(peers_for_blob[b'p'], ((i - 1) // (constants.k + 1)) + 1)
|
||||||
|
seen = set(peers_for_blob[blob_hash])
|
||||||
|
self.assertEqual(len(seen), constants.k)
|
||||||
|
self.assertEqual(len(peers_for_blob[blob_hash]), len(seen))
|
||||||
|
|
||||||
|
for pg in range(1, peers_for_blob[b'p']):
|
||||||
|
page_x = await peer.find_value(blob_hash, pg)
|
||||||
|
self.assertNotIn(b'contacts', page_x)
|
||||||
|
page_x_set = set(page_x[blob_hash])
|
||||||
|
self.assertEqual(len(page_x[blob_hash]), len(page_x_set))
|
||||||
|
self.assertTrue(len(page_x_set) > 0)
|
||||||
|
self.assertSetEqual(seen.intersection(page_x_set), set())
|
||||||
|
seen.intersection_update(page_x_set)
|
||||||
|
total_seen.update(page_x_set)
|
||||||
|
else:
|
||||||
|
self.assertEqual(len(peers_for_blob[b'contacts']), i - 1)
|
||||||
|
self.assertEqual(len(total_seen), peer_count - 2)
|
||||||
|
|
Loading…
Reference in a new issue