Merge pull request #1987 from lbryio/more-dht-tests

improve dht unit tests
This commit is contained in:
Jack Robison 2019-03-15 15:38:51 -04:00 committed by GitHub
commit 58743ba19b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 323 additions and 38 deletions

View file

@ -13,11 +13,6 @@ def cancel_task(task: typing.Optional[asyncio.Task]):
task.cancel()
def cancel_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
for task in tasks:
cancel_task(task)
def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
while tasks:
cancel_task(tasks.pop())

View file

@ -10,11 +10,13 @@ class Distance:
def __init__(self, key: bytes):
if len(key) != constants.hash_length:
raise ValueError("invalid key length: %i" % len(key))
raise ValueError(f"invalid key length: {len(key)}")
self.key = key
self.val_key_one = int.from_bytes(key, 'big')
def __call__(self, key_two: bytes) -> int:
if len(key_two) != constants.hash_length:
raise ValueError(f"invalid length of key to compare: {len(key_two)}")
val_key_two = int.from_bytes(key_two, 'big')
return self.val_key_one ^ val_key_two

View file

@ -65,45 +65,36 @@ class RequestDatagram(KademliaDatagramBase):
@classmethod
def make_ping(cls, from_node_id: bytes, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
if rpc_id and len(rpc_id) != constants.rpc_id_length:
raise ValueError("invalid rpc id length")
rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
if len(from_node_id) != constants.hash_bits // 8:
raise ValueError("invalid node id")
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'ping')
@classmethod
def make_store(cls, from_node_id: bytes, blob_hash: bytes, token: bytes, port: int,
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
if rpc_id and len(rpc_id) != constants.rpc_id_length:
raise ValueError("invalid rpc id length")
if not rpc_id:
rpc_id = constants.generate_id()[:constants.rpc_id_length]
if len(from_node_id) != constants.hash_bits // 8:
raise ValueError("invalid node id")
rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
if len(blob_hash) != constants.hash_bits // 8:
raise ValueError(f"invalid blob hash length: {len(blob_hash)}")
if not 0 < port < 65536:
raise ValueError(f"invalid port: {port}")
if len(token) != constants.hash_bits // 8:
raise ValueError(f"invalid token length: {len(token)}")
store_args = [blob_hash, token, port, from_node_id, 0]
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'store', store_args)
@classmethod
def make_find_node(cls, from_node_id: bytes, key: bytes,
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
if rpc_id and len(rpc_id) != constants.rpc_id_length:
raise ValueError("invalid rpc id length")
if not rpc_id:
rpc_id = constants.generate_id()[:constants.rpc_id_length]
if len(from_node_id) != constants.hash_bits // 8:
raise ValueError("invalid node id")
rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
if len(key) != constants.hash_bits // 8:
raise ValueError(f"invalid key length: {len(key)}")
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findNode', [key])
@classmethod
def make_find_value(cls, from_node_id: bytes, key: bytes,
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
if rpc_id and len(rpc_id) != constants.rpc_id_length:
raise ValueError("invalid rpc id length")
if not rpc_id:
rpc_id = constants.generate_id()[:constants.rpc_id_length]
if len(from_node_id) != constants.hash_bits // 8:
raise ValueError("invalid node id")
rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
if len(key) != constants.hash_bits // 8:
raise ValueError(f"invalid key length: {len(key)}")
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key])
@ -147,8 +138,6 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa
}
primitive: typing.Dict = bdecode(datagram)
if not isinstance(primitive, dict):
raise ValueError("invalid datagram type")
if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
datagram_type = primitive[0] # pylint: disable=unsubscriptable-object
else:
@ -162,14 +151,19 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa
)
def make_compact_ip(address: str):
return reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray())
def make_compact_ip(address: str) -> bytearray:
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray())
if len(compact_ip) != 4:
raise ValueError(f"invalid IPv4 length")
return compact_ip
def make_compact_address(node_id: bytes, address: str, port: int) -> bytearray:
compact_ip = make_compact_ip(address)
if not 0 <= port <= 65536:
if not 0 < port < 65536:
raise ValueError(f'Invalid port: {port}')
if len(node_id) != constants.hash_bits // 8:
raise ValueError(f"invalid node node_id length")
return compact_ip + port.to_bytes(2, 'big') + node_id
@ -177,4 +171,8 @@ def decode_compact_address(compact_address: bytes) -> typing.Tuple[bytes, str, i
address = "{}.{}.{}.{}".format(*compact_address[:4])
port = int.from_bytes(compact_address[4:6], 'big')
node_id = compact_address[6:]
if not 0 < port < 65536:
raise ValueError(f'Invalid port: {port}')
if len(node_id) != constants.hash_bits // 8:
raise ValueError(f"invalid node node_id length")
return node_id, address, port

View file

@ -0,0 +1,93 @@
import asyncio
from torba.testcase import AsyncioTestCase
from lbrynet.dht.protocol.data_store import DictDataStore
from lbrynet.dht.peer import PeerManager
class DataStoreTests(AsyncioTestCase):
def setUp(self):
self.loop = asyncio.get_event_loop()
self.peer_manager = PeerManager(self.loop)
self.data_store = DictDataStore(self.loop, self.peer_manager)
def _test_add_peer_to_blob(self, blob=b'2' * 48, node_id=b'1' * 48, address='1.2.3.4', tcp_port=3333,
udp_port=4444):
peer = self.peer_manager.get_kademlia_peer(node_id, address, udp_port)
peer.update_tcp_port(tcp_port)
before = self.data_store.get_peers_for_blob(blob)
self.data_store.add_peer_to_blob(peer, blob, peer.compact_address_tcp(), 0, 0, peer.node_id)
self.assertListEqual(before + [peer], self.data_store.get_peers_for_blob(blob))
return peer
def test_add_peer_to_blob(self, blob=b'f' * 48, peers=None):
peers = peers or [
(b'a' * 48, '1.2.3.4'),
(b'b' * 48, '1.2.3.5'),
(b'c' * 48, '1.2.3.6'),
]
self.assertListEqual([], self.data_store.get_peers_for_blob(blob))
peer_objects = []
for (node_id, address) in peers:
peer_objects.append(self._test_add_peer_to_blob(blob=blob, node_id=node_id, address=address))
self.assertTrue(self.data_store.has_peers_for_blob(blob))
self.assertEqual(len(self.data_store.get_peers_for_blob(blob)), len(peers))
return peer_objects
def test_get_storing_contacts(self, peers=None, blob1=b'd' * 48, blob2=b'e' * 48):
peers = peers or [
(b'a' * 48, '1.2.3.4'),
(b'b' * 48, '1.2.3.5'),
(b'c' * 48, '1.2.3.6'),
]
peer_objs1 = self.test_add_peer_to_blob(blob=blob1, peers=peers)
self.assertEqual(len(peers), len(peer_objs1))
self.assertEqual(len(peers), len(self.data_store.get_storing_contacts()))
peer_objs2 = self.test_add_peer_to_blob(blob=blob2, peers=peers)
self.assertEqual(len(peers), len(peer_objs2))
self.assertEqual(len(peers), len(self.data_store.get_storing_contacts()))
for o1, o2 in zip(peer_objs1, peer_objs2):
self.assertIs(o1, o2)
def test_remove_expired_peers(self):
peers = [
(b'a' * 48, '1.2.3.4'),
(b'b' * 48, '1.2.3.5'),
(b'c' * 48, '1.2.3.6'),
]
blob1 = b'd' * 48
blob2 = b'e' * 48
self.data_store.removed_expired_peers() # nothing should happen
self.test_get_storing_contacts(peers, blob1, blob2)
self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), len(peers))
self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers))
self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers))
# expire the first peer from blob1
first = self.data_store._data_store[blob1][0]
self.data_store._data_store[blob1][0] = (first[0], first[1], first[2], -86401, first[4])
self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers))
self.data_store.removed_expired_peers()
self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), len(peers) - 1)
self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers))
self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers))
# expire the first peer from blob2
first = self.data_store._data_store[blob2][0]
self.data_store._data_store[blob2][0] = (first[0], first[1], first[2], -86401, first[4])
self.data_store.removed_expired_peers()
self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), len(peers) - 1)
self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers) - 1)
self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers) - 1)
# expire the second and third peers from blob1
first = self.data_store._data_store[blob2][0]
self.data_store._data_store[blob1][0] = (first[0], first[1], first[2], -86401, first[4])
second = self.data_store._data_store[blob2][1]
self.data_store._data_store[blob1][1] = (second[0], second[1], second[2], -86401, second[4])
self.data_store.removed_expired_peers()
self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), 0)
self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers) - 1)
self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers) - 1)

View file

@ -0,0 +1,13 @@
import unittest
from lbrynet.dht.protocol.distance import Distance
class DistanceTests(unittest.TestCase):
def test_invalid_key_length(self):
self.assertRaises(ValueError, Distance, b'1' * 47)
self.assertRaises(ValueError, Distance, b'1' * 49)
self.assertRaises(ValueError, Distance, b'')
self.assertRaises(ValueError, Distance(b'0' * 48), b'1' * 47)
self.assertRaises(ValueError, Distance(b'0' * 48), b'1' * 49)
self.assertRaises(ValueError, Distance(b'0' * 48), b'')

View file

@ -62,3 +62,4 @@ class EncodeDecodeTest(unittest.TestCase):
def test_decode_error(self):
self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz', True)
self.assertRaises(DecodeError, bdecode, b'', True)
self.assertRaises(DecodeError, bdecode, b'l4:spami42ee')

View file

@ -1,11 +1,17 @@
import unittest
from lbrynet.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram
from lbrynet.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE
from lbrynet.dht.error import DecodeError
from lbrynet.dht.serialization.bencoding import _bencode
from lbrynet.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram, ErrorDatagram
from lbrynet.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE, ERROR_TYPE
from lbrynet.dht.serialization.datagram import make_compact_address, decode_compact_address
class TestDatagram(unittest.TestCase):
def test_ping_request_datagram(self):
serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'ping', []).bencode()
self.assertRaises(ValueError, RequestDatagram.make_ping, b'1' * 48, b'1' * 21)
self.assertRaises(ValueError, RequestDatagram.make_ping, b'1' * 47, b'1' * 20)
self.assertEqual(20, len(RequestDatagram.make_ping(b'1' * 48).rpc_id))
serialized = RequestDatagram.make_ping(b'1' * 48, b'1' * 20).bencode()
decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20)
@ -14,6 +20,9 @@ class TestDatagram(unittest.TestCase):
self.assertListEqual(decoded.args, [{b'protocolVersion': 1}])
def test_ping_response(self):
self.assertRaises(ValueError, ResponseDatagram, RESPONSE_TYPE, b'1' * 21, b'1' * 48, b'pong')
self.assertRaises(ValueError, ResponseDatagram, RESPONSE_TYPE, b'1' * 20, b'1' * 49, b'pong')
self.assertRaises(ValueError, ResponseDatagram, 5, b'1' * 20, b'1' * 48, b'pong')
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, b'pong').bencode()
decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, RESPONSE_TYPE)
@ -22,7 +31,12 @@ class TestDatagram(unittest.TestCase):
self.assertEqual(decoded.response, b'pong')
def test_find_node_request_datagram(self):
serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findNode', [b'2' * 48]).bencode()
self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 49, b'2' * 48, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 48, b'2' * 49, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 48, b'2' * 48, b'1' * 21)
self.assertEqual(20, len(RequestDatagram.make_find_node(b'1' * 48, b'2' * 48).rpc_id))
serialized = RequestDatagram.make_find_node(b'1' * 48, b'2' * 48, b'1' * 20).bencode()
decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20)
@ -42,7 +56,12 @@ class TestDatagram(unittest.TestCase):
self.assertEqual(decoded.response, expected)
def test_find_value_request(self):
serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findValue', [b'2' * 48]).bencode()
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 49, b'2' * 48, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 49, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 21)
self.assertEqual(20, len(RequestDatagram.make_find_value(b'1' * 48, b'2' * 48).rpc_id))
serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20).bencode()
decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20)
@ -58,3 +77,53 @@ class TestDatagram(unittest.TestCase):
self.assertEqual(decoded.rpc_id, b'1' * 20)
self.assertEqual(decoded.node_id, b'1' * 48)
self.assertDictEqual(decoded.response, found_value_response)
def test_store_request(self):
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 47, b'2' * 48, b'3' * 48, 3333, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 49, b'3' * 48, 3333, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 47, 3333, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 48, -3333, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 48, 3333, b'1' * 21)
serialized = RequestDatagram.make_store(b'1' * 48, b'2' * 48, b'3' * 48, 3333, b'1' * 20).bencode()
decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20)
self.assertEqual(decoded.node_id, b'1' * 48)
self.assertEqual(decoded.method, b'store')
def test_error_datagram(self):
serialized = ErrorDatagram(ERROR_TYPE, b'1' * 20, b'1' * 48, b'FakeErrorType', b'more info').bencode()
decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, ERROR_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20)
self.assertEqual(decoded.node_id, b'1' * 48)
self.assertEqual(decoded.exception_type, 'FakeErrorType')
self.assertEqual(decoded.response, 'more info')
def test_invalid_datagram_type(self):
serialized = b'di0ei5ei1e20:11111111111111111111i2e48:11111111111111111111' \
b'1111111111111111111111111111i3e13:FakeErrorTypei4e9:more infoe'
self.assertRaises(ValueError, decode_datagram, serialized)
self.assertRaises(DecodeError, decode_datagram, _bencode([1, 2, 3, 4]))
class TestCompactAddress(unittest.TestCase):
def test_encode_decode(self, address='1.2.3.4', port=4444, node_id=b'1' * 48):
decoded = decode_compact_address(make_compact_address(node_id, address, port))
self.assertEqual((node_id, address, port), decoded)
def test_errors(self):
self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4', 0)
self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4', 65536)
self.assertRaises(
ValueError, decode_compact_address,
b'\x01\x02\x03\x04\x00\x00111111111111111111111111111111111111111111111111'
)
self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4.5', 4444)
self.assertRaises(ValueError, make_compact_address, b'1' * 47, '1.2.3.4', 4444)
self.assertRaises(
ValueError, decode_compact_address,
b'\x01\x02\x03\x04\x11\\11111111111111111111111111111111111111111111111'
)

View file

@ -0,0 +1,114 @@
import contextlib
import typing
import binascii
import asyncio
from torba.testcase import AsyncioTestCase
from tests import dht_mocks
from lbrynet.conf import Config
from lbrynet.dht import constants
from lbrynet.dht.node import Node
from lbrynet.dht.peer import PeerManager
from lbrynet.dht.blob_announcer import BlobAnnouncer
from lbrynet.extras.daemon.storage import SQLiteStorage
class TestBlobAnnouncer(AsyncioTestCase):
async def setup_node(self, peer_addresses, address, node_id):
self.nodes: typing.Dict[int, Node] = {}
self.advance = dht_mocks.get_time_accelerator(self.loop, self.loop.time())
self.conf = Config()
self.storage = SQLiteStorage(self.conf, ":memory:", self.loop, self.loop.time)
await self.storage.open()
self.peer_manager = PeerManager(self.loop)
self.node = Node(self.loop, self.peer_manager, node_id, 4444, 4444, 3333, address)
await self.node.start_listening(address)
self.blob_announcer = BlobAnnouncer(self.loop, self.node, self.storage)
for node_id, address in peer_addresses:
await self.add_peer(node_id, address)
self.node.joined.set()
async def add_peer(self, node_id, address, add_to_routing_table=True):
n = Node(self.loop, PeerManager(self.loop), node_id, 4444, 4444, 3333, address)
await n.start_listening(address)
self.nodes.update({len(self.nodes): n})
if add_to_routing_table:
await self.node.protocol.add_peer(
self.peer_manager.get_kademlia_peer(
n.protocol.node_id, n.protocol.external_ip, n.protocol.udp_port
)
)
@contextlib.asynccontextmanager
async def _test_network_context(self, peer_addresses=None):
self.peer_addresses = peer_addresses or [
(constants.generate_id(2), '1.2.3.2'),
(constants.generate_id(3), '1.2.3.3'),
(constants.generate_id(4), '1.2.3.4'),
(constants.generate_id(5), '1.2.3.5'),
(constants.generate_id(6), '1.2.3.6'),
(constants.generate_id(7), '1.2.3.7'),
(constants.generate_id(8), '1.2.3.8'),
(constants.generate_id(9), '1.2.3.9'),
]
try:
with dht_mocks.mock_network_loop(self.loop):
await self.setup_node(self.peer_addresses, '1.2.3.1', constants.generate_id(1))
yield
finally:
self.blob_announcer.stop()
self.node.stop()
for n in self.nodes.values():
n.stop()
async def chain_peer(self, node_id, address):
previous_last_node = self.nodes[len(self.nodes) - 1]
await self.add_peer(node_id, address, False)
last_node = self.nodes[len(self.nodes) - 1]
peer = last_node.protocol.get_rpc_peer(
last_node.protocol.peer_manager.get_kademlia_peer(
previous_last_node.protocol.node_id, previous_last_node.protocol.external_ip,
previous_last_node.protocol.udp_port
)
)
await peer.ping()
return peer
async def test_announce_blobs(self):
blob1 = binascii.hexlify(b'1' * 48).decode()
blob2 = binascii.hexlify(b'2' * 48).decode()
async with self._test_network_context():
await self.storage.add_completed_blob(blob1, 1024)
await self.storage.add_completed_blob(blob2, 1024)
await self.storage.db.execute(
"update blob set next_announce_time=0, should_announce=1 where blob_hash in (?, ?)",
(blob1, blob2)
)
to_announce = await self.storage.get_blobs_to_announce()
self.assertEqual(2, len(to_announce))
self.blob_announcer.start()
await self.advance(61.0)
to_announce = await self.storage.get_blobs_to_announce()
self.assertEqual(0, len(to_announce))
self.blob_announcer.stop()
# test that we can route from a poorly connected peer all the way to the announced blob
await self.chain_peer(constants.generate_id(10), '1.2.3.10')
await self.chain_peer(constants.generate_id(11), '1.2.3.11')
await self.chain_peer(constants.generate_id(12), '1.2.3.12')
await self.chain_peer(constants.generate_id(13), '1.2.3.13')
await self.chain_peer(constants.generate_id(14), '1.2.3.14')
last = self.nodes[len(self.nodes) - 1]
search_q, peer_q = asyncio.Queue(loop=self.loop), asyncio.Queue(loop=self.loop)
search_q.put_nowait(blob1)
_, task = last.accumulate_peers(search_q, peer_q)
found_peers = await peer_q.get()
task.cancel()
self.assertEqual(1, len(found_peers))
self.assertEqual(self.node.protocol.node_id, found_peers[0].node_id)
self.assertEqual(self.node.protocol.external_ip, found_peers[0].address)
self.assertEqual(self.node.protocol.peer_port, found_peers[0].tcp_port)