Merge pull request #1987 from lbryio/more-dht-tests

improve dht unit tests

commit 58743ba19b

12 changed files with 323 additions and 38 deletions
@@ -13,11 +13,6 @@ def cancel_task(task: typing.Optional[asyncio.Task]):
         task.cancel()


-def cancel_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
-    for task in tasks:
-        cancel_task(task)
-
-
 def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
     while tasks:
         cancel_task(tasks.pop())
@@ -10,11 +10,13 @@ class Distance:

     def __init__(self, key: bytes):
         if len(key) != constants.hash_length:
-            raise ValueError("invalid key length: %i" % len(key))
+            raise ValueError(f"invalid key length: {len(key)}")
         self.key = key
         self.val_key_one = int.from_bytes(key, 'big')

     def __call__(self, key_two: bytes) -> int:
+        if len(key_two) != constants.hash_length:
+            raise ValueError(f"invalid length of key to compare: {len(key_two)}")
         val_key_two = int.from_bytes(key_two, 'big')
         return self.val_key_one ^ val_key_two

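Note: with this change both sides of the XOR metric are length-checked. A minimal sketch of the resulting behavior, assuming the 48-byte key size that constants.hash_length denotes throughout these tests:

    from lbrynet.dht.protocol.distance import Distance

    d = Distance(b'0' * 48)   # a 48-byte reference key is accepted
    assert d(b'0' * 48) == 0  # XOR distance of identical keys is zero
    assert d(b'1' * 48) > 0   # any differing key has a positive distance
    try:
        d(b'1' * 47)          # a wrong-length key now raises instead of comparing
    except ValueError as e:
        print(e)              # "invalid length of key to compare: 47"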
@@ -65,45 +65,36 @@ class RequestDatagram(KademliaDatagramBase):

     @classmethod
     def make_ping(cls, from_node_id: bytes, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
-        if rpc_id and len(rpc_id) != constants.rpc_id_length:
-            raise ValueError("invalid rpc id length")
         rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
-        if len(from_node_id) != constants.hash_bits // 8:
-            raise ValueError("invalid node id")
         return cls(REQUEST_TYPE, rpc_id, from_node_id, b'ping')

     @classmethod
     def make_store(cls, from_node_id: bytes, blob_hash: bytes, token: bytes, port: int,
                    rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
-        if rpc_id and len(rpc_id) != constants.rpc_id_length:
-            raise ValueError("invalid rpc id length")
-        if not rpc_id:
-            rpc_id = constants.generate_id()[:constants.rpc_id_length]
-        if len(from_node_id) != constants.hash_bits // 8:
-            raise ValueError("invalid node id")
+        rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
+        if len(blob_hash) != constants.hash_bits // 8:
+            raise ValueError(f"invalid blob hash length: {len(blob_hash)}")
+        if not 0 < port < 65536:
+            raise ValueError(f"invalid port: {port}")
+        if len(token) != constants.hash_bits // 8:
+            raise ValueError(f"invalid token length: {len(token)}")
         store_args = [blob_hash, token, port, from_node_id, 0]
         return cls(REQUEST_TYPE, rpc_id, from_node_id, b'store', store_args)

     @classmethod
     def make_find_node(cls, from_node_id: bytes, key: bytes,
                        rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
-        if rpc_id and len(rpc_id) != constants.rpc_id_length:
-            raise ValueError("invalid rpc id length")
-        if not rpc_id:
-            rpc_id = constants.generate_id()[:constants.rpc_id_length]
-        if len(from_node_id) != constants.hash_bits // 8:
-            raise ValueError("invalid node id")
+        rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
+        if len(key) != constants.hash_bits // 8:
+            raise ValueError(f"invalid key length: {len(key)}")
         return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findNode', [key])

     @classmethod
     def make_find_value(cls, from_node_id: bytes, key: bytes,
                         rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
-        if rpc_id and len(rpc_id) != constants.rpc_id_length:
-            raise ValueError("invalid rpc id length")
-        if not rpc_id:
-            rpc_id = constants.generate_id()[:constants.rpc_id_length]
-        if len(from_node_id) != constants.hash_bits // 8:
-            raise ValueError("invalid node id")
+        rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
+        if len(key) != constants.hash_bits // 8:
+            raise ValueError(f"invalid key length: {len(key)}")
         return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key])

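Note: the rpc_id and from_node_id checks dropped from these factories are not lost; the tests below still expect ValueError for bad node IDs and rpc_ids, so that validation presumably lives in the KademliaDatagramBase constructor that every factory calls. What each factory itself enforces after this change, as a sketch:

    from lbrynet.dht.serialization.datagram import RequestDatagram

    # rpc_id is optional; a fresh 20-byte id is generated when omitted
    assert len(RequestDatagram.make_ping(b'1' * 48).rpc_id) == 20

    # make_store now validates its own arguments directly
    for blob_hash, token, port in [
        (b'2' * 47, b'3' * 48, 3333),  # blob hash not 48 bytes
        (b'2' * 48, b'3' * 48, 0),     # port outside 1-65535
        (b'2' * 48, b'3' * 47, 3333),  # token not 48 bytes
    ]:
        try:
            RequestDatagram.make_store(b'1' * 48, blob_hash, token, port)
        except ValueError:
            pass  # each case trips one of the checks added above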
@@ -147,8 +138,6 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa
     }

     primitive: typing.Dict = bdecode(datagram)
-    if not isinstance(primitive, dict):
-        raise ValueError("invalid datagram type")
     if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]:  # pylint: disable=unsubscriptable-object
         datagram_type = primitive[0]  # pylint: disable=unsubscriptable-object
     else:
@@ -162,14 +151,19 @@
     )


-def make_compact_ip(address: str):
-    return reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray())
+def make_compact_ip(address: str) -> bytearray:
+    compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray())
+    if len(compact_ip) != 4:
+        raise ValueError(f"invalid IPv4 length")
+    return compact_ip


 def make_compact_address(node_id: bytes, address: str, port: int) -> bytearray:
     compact_ip = make_compact_ip(address)
-    if not 0 <= port <= 65536:
+    if not 0 < port < 65536:
         raise ValueError(f'Invalid port: {port}')
+    if len(node_id) != constants.hash_bits // 8:
+        raise ValueError(f"invalid node node_id length")
     return compact_ip + port.to_bytes(2, 'big') + node_id

@@ -177,4 +171,8 @@ def decode_compact_address(compact_address: bytes) -> typing.Tuple[bytes, str, i
     address = "{}.{}.{}.{}".format(*compact_address[:4])
     port = int.from_bytes(compact_address[4:6], 'big')
     node_id = compact_address[6:]
+    if not 0 < port < 65536:
+        raise ValueError(f'Invalid port: {port}')
+    if len(node_id) != constants.hash_bits // 8:
+        raise ValueError(f"invalid node node_id length")
     return node_id, address, port
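Note: a compact address is 4 IP octets, a 2-byte big-endian port, and the 48-byte node ID, 54 bytes in all, and the new checks make encoding and decoding reject the same malformed values. A quick round trip:

    from lbrynet.dht.serialization.datagram import make_compact_address, decode_compact_address

    compact = make_compact_address(b'1' * 48, '1.2.3.4', 4444)
    assert len(compact) == 54
    assert bytes(compact[:4]) == b'\x01\x02\x03\x04'         # IP octets
    assert bytes(compact[4:6]) == (4444).to_bytes(2, 'big')  # port, big endian
    assert decode_compact_address(bytes(compact)) == (b'1' * 48, '1.2.3.4', 4444)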
tests/unit/dht/protocol/test_data_store.py (new file, 93 lines)
@@ -0,0 +1,93 @@
+import asyncio
+from torba.testcase import AsyncioTestCase
+from lbrynet.dht.protocol.data_store import DictDataStore
+from lbrynet.dht.peer import PeerManager
+
+
+class DataStoreTests(AsyncioTestCase):
+    def setUp(self):
+        self.loop = asyncio.get_event_loop()
+        self.peer_manager = PeerManager(self.loop)
+        self.data_store = DictDataStore(self.loop, self.peer_manager)
+
+    def _test_add_peer_to_blob(self, blob=b'2' * 48, node_id=b'1' * 48, address='1.2.3.4', tcp_port=3333,
+                               udp_port=4444):
+        peer = self.peer_manager.get_kademlia_peer(node_id, address, udp_port)
+        peer.update_tcp_port(tcp_port)
+        before = self.data_store.get_peers_for_blob(blob)
+        self.data_store.add_peer_to_blob(peer, blob, peer.compact_address_tcp(), 0, 0, peer.node_id)
+        self.assertListEqual(before + [peer], self.data_store.get_peers_for_blob(blob))
+        return peer
+
+    def test_add_peer_to_blob(self, blob=b'f' * 48, peers=None):
+        peers = peers or [
+            (b'a' * 48, '1.2.3.4'),
+            (b'b' * 48, '1.2.3.5'),
+            (b'c' * 48, '1.2.3.6'),
+        ]
+        self.assertListEqual([], self.data_store.get_peers_for_blob(blob))
+        peer_objects = []
+        for (node_id, address) in peers:
+            peer_objects.append(self._test_add_peer_to_blob(blob=blob, node_id=node_id, address=address))
+        self.assertTrue(self.data_store.has_peers_for_blob(blob))
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob)), len(peers))
+        return peer_objects
+
+    def test_get_storing_contacts(self, peers=None, blob1=b'd' * 48, blob2=b'e' * 48):
+        peers = peers or [
+            (b'a' * 48, '1.2.3.4'),
+            (b'b' * 48, '1.2.3.5'),
+            (b'c' * 48, '1.2.3.6'),
+        ]
+        peer_objs1 = self.test_add_peer_to_blob(blob=blob1, peers=peers)
+        self.assertEqual(len(peers), len(peer_objs1))
+        self.assertEqual(len(peers), len(self.data_store.get_storing_contacts()))
+
+        peer_objs2 = self.test_add_peer_to_blob(blob=blob2, peers=peers)
+        self.assertEqual(len(peers), len(peer_objs2))
+        self.assertEqual(len(peers), len(self.data_store.get_storing_contacts()))
+
+        for o1, o2 in zip(peer_objs1, peer_objs2):
+            self.assertIs(o1, o2)
+
+    def test_remove_expired_peers(self):
+        peers = [
+            (b'a' * 48, '1.2.3.4'),
+            (b'b' * 48, '1.2.3.5'),
+            (b'c' * 48, '1.2.3.6'),
+        ]
+        blob1 = b'd' * 48
+        blob2 = b'e' * 48
+
+        self.data_store.removed_expired_peers()  # nothing should happen
+        self.test_get_storing_contacts(peers, blob1, blob2)
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), len(peers))
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers))
+        self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers))
+
+        # expire the first peer from blob1
+        first = self.data_store._data_store[blob1][0]
+        self.data_store._data_store[blob1][0] = (first[0], first[1], first[2], -86401, first[4])
+        self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers))
+        self.data_store.removed_expired_peers()
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), len(peers) - 1)
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers))
+        self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers))
+
+        # expire the first peer from blob2
+        first = self.data_store._data_store[blob2][0]
+        self.data_store._data_store[blob2][0] = (first[0], first[1], first[2], -86401, first[4])
+        self.data_store.removed_expired_peers()
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), len(peers) - 1)
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers) - 1)
+        self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers) - 1)
+
+        # expire the second and third peers from blob1
+        first = self.data_store._data_store[blob2][0]
+        self.data_store._data_store[blob1][0] = (first[0], first[1], first[2], -86401, first[4])
+        second = self.data_store._data_store[blob2][1]
+        self.data_store._data_store[blob1][1] = (second[0], second[1], second[2], -86401, second[4])
+        self.data_store.removed_expired_peers()
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), 0)
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers) - 1)
+        self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers) - 1)
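Note: the expiration cases above reach into the store's internals. Each stored entry is a tuple whose fourth field the test rewrites to -86401, which reads as one second past a presumed 86,400-second (24-hour) expiration window; the exact tuple layout is internal to DictDataStore and is inferred here from the test itself.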
tests/unit/dht/protocol/test_distance.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+import unittest
+from lbrynet.dht.protocol.distance import Distance
+
+
+class DistanceTests(unittest.TestCase):
+    def test_invalid_key_length(self):
+        self.assertRaises(ValueError, Distance, b'1' * 47)
+        self.assertRaises(ValueError, Distance, b'1' * 49)
+        self.assertRaises(ValueError, Distance, b'')
+
+        self.assertRaises(ValueError, Distance(b'0' * 48), b'1' * 47)
+        self.assertRaises(ValueError, Distance(b'0' * 48), b'1' * 49)
+        self.assertRaises(ValueError, Distance(b'0' * 48), b'')
@@ -62,3 +62,4 @@ class EncodeDecodeTest(unittest.TestCase):
     def test_decode_error(self):
         self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz', True)
         self.assertRaises(DecodeError, bdecode, b'', True)
+        self.assertRaises(DecodeError, bdecode, b'l4:spami42ee')
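Note: the new case differs from the two above it. b'l4:spami42ee' is structurally valid bencode (the list [b'spam', 42]), so the DecodeError can only come from the top-level value not being a dictionary, which suggests the trailing True in the other two calls opts into non-dictionary payloads while this call leaves that flag at its default.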
@@ -1,11 +1,17 @@
 import unittest
-from lbrynet.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram
-from lbrynet.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE
+from lbrynet.dht.error import DecodeError
+from lbrynet.dht.serialization.bencoding import _bencode
+from lbrynet.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram, ErrorDatagram
+from lbrynet.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE, ERROR_TYPE
+from lbrynet.dht.serialization.datagram import make_compact_address, decode_compact_address


 class TestDatagram(unittest.TestCase):
     def test_ping_request_datagram(self):
-        serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'ping', []).bencode()
+        self.assertRaises(ValueError, RequestDatagram.make_ping, b'1' * 48, b'1' * 21)
+        self.assertRaises(ValueError, RequestDatagram.make_ping, b'1' * 47, b'1' * 20)
+        self.assertEqual(20, len(RequestDatagram.make_ping(b'1' * 48).rpc_id))
+        serialized = RequestDatagram.make_ping(b'1' * 48, b'1' * 20).bencode()
         decoded = decode_datagram(serialized)
         self.assertEqual(decoded.packet_type, REQUEST_TYPE)
         self.assertEqual(decoded.rpc_id, b'1' * 20)
@@ -14,6 +20,9 @@ class TestDatagram(unittest.TestCase):
         self.assertListEqual(decoded.args, [{b'protocolVersion': 1}])

     def test_ping_response(self):
+        self.assertRaises(ValueError, ResponseDatagram, RESPONSE_TYPE, b'1' * 21, b'1' * 48, b'pong')
+        self.assertRaises(ValueError, ResponseDatagram, RESPONSE_TYPE, b'1' * 20, b'1' * 49, b'pong')
+        self.assertRaises(ValueError, ResponseDatagram, 5, b'1' * 20, b'1' * 48, b'pong')
         serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, b'pong').bencode()
         decoded = decode_datagram(serialized)
         self.assertEqual(decoded.packet_type, RESPONSE_TYPE)
@@ -22,7 +31,12 @@ class TestDatagram(unittest.TestCase):
         self.assertEqual(decoded.response, b'pong')

     def test_find_node_request_datagram(self):
-        serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findNode', [b'2' * 48]).bencode()
+        self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 49, b'2' * 48, b'1' * 20)
+        self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 48, b'2' * 49, b'1' * 20)
+        self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 48, b'2' * 48, b'1' * 21)
+        self.assertEqual(20, len(RequestDatagram.make_find_node(b'1' * 48, b'2' * 48).rpc_id))
+
+        serialized = RequestDatagram.make_find_node(b'1' * 48, b'2' * 48, b'1' * 20).bencode()
         decoded = decode_datagram(serialized)
         self.assertEqual(decoded.packet_type, REQUEST_TYPE)
         self.assertEqual(decoded.rpc_id, b'1' * 20)
@@ -42,7 +56,12 @@ class TestDatagram(unittest.TestCase):
         self.assertEqual(decoded.response, expected)

     def test_find_value_request(self):
-        serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findValue', [b'2' * 48]).bencode()
+        self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 49, b'2' * 48, b'1' * 20)
+        self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 49, b'1' * 20)
+        self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 21)
+        self.assertEqual(20, len(RequestDatagram.make_find_value(b'1' * 48, b'2' * 48).rpc_id))
+
+        serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20).bencode()
         decoded = decode_datagram(serialized)
         self.assertEqual(decoded.packet_type, REQUEST_TYPE)
         self.assertEqual(decoded.rpc_id, b'1' * 20)
@@ -58,3 +77,53 @@ class TestDatagram(unittest.TestCase):
         self.assertEqual(decoded.rpc_id, b'1' * 20)
         self.assertEqual(decoded.node_id, b'1' * 48)
         self.assertDictEqual(decoded.response, found_value_response)
+
+    def test_store_request(self):
+        self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 47, b'2' * 48, b'3' * 48, 3333, b'1' * 20)
+        self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 49, b'3' * 48, 3333, b'1' * 20)
+        self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 47, 3333, b'1' * 20)
+        self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 48, -3333, b'1' * 20)
+        self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 48, 3333, b'1' * 21)
+
+        serialized = RequestDatagram.make_store(b'1' * 48, b'2' * 48, b'3' * 48, 3333, b'1' * 20).bencode()
+        decoded = decode_datagram(serialized)
+        self.assertEqual(decoded.packet_type, REQUEST_TYPE)
+        self.assertEqual(decoded.rpc_id, b'1' * 20)
+        self.assertEqual(decoded.node_id, b'1' * 48)
+        self.assertEqual(decoded.method, b'store')
+
+    def test_error_datagram(self):
+        serialized = ErrorDatagram(ERROR_TYPE, b'1' * 20, b'1' * 48, b'FakeErrorType', b'more info').bencode()
+        decoded = decode_datagram(serialized)
+        self.assertEqual(decoded.packet_type, ERROR_TYPE)
+        self.assertEqual(decoded.rpc_id, b'1' * 20)
+        self.assertEqual(decoded.node_id, b'1' * 48)
+        self.assertEqual(decoded.exception_type, 'FakeErrorType')
+        self.assertEqual(decoded.response, 'more info')
+
+    def test_invalid_datagram_type(self):
+        serialized = b'di0ei5ei1e20:11111111111111111111i2e48:11111111111111111111' \
+                     b'1111111111111111111111111111i3e13:FakeErrorTypei4e9:more infoe'
+        self.assertRaises(ValueError, decode_datagram, serialized)
+        self.assertRaises(DecodeError, decode_datagram, _bencode([1, 2, 3, 4]))
+
+
+class TestCompactAddress(unittest.TestCase):
+    def test_encode_decode(self, address='1.2.3.4', port=4444, node_id=b'1' * 48):
+        decoded = decode_compact_address(make_compact_address(node_id, address, port))
+        self.assertEqual((node_id, address, port), decoded)
+
+    def test_errors(self):
+        self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4', 0)
+        self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4', 65536)
+        self.assertRaises(
+            ValueError, decode_compact_address,
+            b'\x01\x02\x03\x04\x00\x00111111111111111111111111111111111111111111111111'
+        )
+
+        self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4.5', 4444)
+        self.assertRaises(ValueError, make_compact_address, b'1' * 47, '1.2.3.4', 4444)
+        self.assertRaises(
+            ValueError, decode_compact_address,
+            b'\x01\x02\x03\x04\x11\\11111111111111111111111111111111111111111111111'
+        )
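Note: the hand-built payload in test_invalid_datagram_type bencodes to a dict whose packet-type field (key 0) is 5. Assuming the usual 0/1/2 constants for request, response, and error, no recognized type matches, which is also why ResponseDatagram(5, ...) above raises. For reference:

    from lbrynet.dht.serialization.bencoding import bdecode

    payload = (b'di0ei5ei1e20:11111111111111111111i2e48:11111111111111111111'
               b'1111111111111111111111111111i3e13:FakeErrorTypei4e9:more infoe')
    decoded = bdecode(payload)
    assert decoded[0] == 5  # not REQUEST_TYPE (0), RESPONSE_TYPE (1), or ERROR_TYPE (2)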
tests/unit/dht/test_blob_announcer.py (new file, 114 lines)
@@ -0,0 +1,114 @@
+import contextlib
+import typing
+import binascii
+import asyncio
+from torba.testcase import AsyncioTestCase
+from tests import dht_mocks
+from lbrynet.conf import Config
+from lbrynet.dht import constants
+from lbrynet.dht.node import Node
+from lbrynet.dht.peer import PeerManager
+from lbrynet.dht.blob_announcer import BlobAnnouncer
+from lbrynet.extras.daemon.storage import SQLiteStorage
+
+
+class TestBlobAnnouncer(AsyncioTestCase):
+    async def setup_node(self, peer_addresses, address, node_id):
+        self.nodes: typing.Dict[int, Node] = {}
+        self.advance = dht_mocks.get_time_accelerator(self.loop, self.loop.time())
+        self.conf = Config()
+        self.storage = SQLiteStorage(self.conf, ":memory:", self.loop, self.loop.time)
+        await self.storage.open()
+        self.peer_manager = PeerManager(self.loop)
+        self.node = Node(self.loop, self.peer_manager, node_id, 4444, 4444, 3333, address)
+        await self.node.start_listening(address)
+        self.blob_announcer = BlobAnnouncer(self.loop, self.node, self.storage)
+        for node_id, address in peer_addresses:
+            await self.add_peer(node_id, address)
+        self.node.joined.set()
+
+    async def add_peer(self, node_id, address, add_to_routing_table=True):
+        n = Node(self.loop, PeerManager(self.loop), node_id, 4444, 4444, 3333, address)
+        await n.start_listening(address)
+        self.nodes.update({len(self.nodes): n})
+        if add_to_routing_table:
+            await self.node.protocol.add_peer(
+                self.peer_manager.get_kademlia_peer(
+                    n.protocol.node_id, n.protocol.external_ip, n.protocol.udp_port
+                )
+            )
+
+    @contextlib.asynccontextmanager
+    async def _test_network_context(self, peer_addresses=None):
+        self.peer_addresses = peer_addresses or [
+            (constants.generate_id(2), '1.2.3.2'),
+            (constants.generate_id(3), '1.2.3.3'),
+            (constants.generate_id(4), '1.2.3.4'),
+            (constants.generate_id(5), '1.2.3.5'),
+            (constants.generate_id(6), '1.2.3.6'),
+            (constants.generate_id(7), '1.2.3.7'),
+            (constants.generate_id(8), '1.2.3.8'),
+            (constants.generate_id(9), '1.2.3.9'),
+        ]
+        try:
+            with dht_mocks.mock_network_loop(self.loop):
+                await self.setup_node(self.peer_addresses, '1.2.3.1', constants.generate_id(1))
+                yield
+        finally:
+            self.blob_announcer.stop()
+            self.node.stop()
+            for n in self.nodes.values():
+                n.stop()
+
+    async def chain_peer(self, node_id, address):
+        previous_last_node = self.nodes[len(self.nodes) - 1]
+        await self.add_peer(node_id, address, False)
+        last_node = self.nodes[len(self.nodes) - 1]
+        peer = last_node.protocol.get_rpc_peer(
+            last_node.protocol.peer_manager.get_kademlia_peer(
+                previous_last_node.protocol.node_id, previous_last_node.protocol.external_ip,
+                previous_last_node.protocol.udp_port
+            )
+        )
+        await peer.ping()
+        return peer
+
+    async def test_announce_blobs(self):
+        blob1 = binascii.hexlify(b'1' * 48).decode()
+        blob2 = binascii.hexlify(b'2' * 48).decode()
+
+        async with self._test_network_context():
+            await self.storage.add_completed_blob(blob1, 1024)
+            await self.storage.add_completed_blob(blob2, 1024)
+            await self.storage.db.execute(
+                "update blob set next_announce_time=0, should_announce=1 where blob_hash in (?, ?)",
+                (blob1, blob2)
+            )
+            to_announce = await self.storage.get_blobs_to_announce()
+            self.assertEqual(2, len(to_announce))
+            self.blob_announcer.start()
+            await self.advance(61.0)
+            to_announce = await self.storage.get_blobs_to_announce()
+            self.assertEqual(0, len(to_announce))
+            self.blob_announcer.stop()
+
+            # test that we can route from a poorly connected peer all the way to the announced blob
+
+            await self.chain_peer(constants.generate_id(10), '1.2.3.10')
+            await self.chain_peer(constants.generate_id(11), '1.2.3.11')
+            await self.chain_peer(constants.generate_id(12), '1.2.3.12')
+            await self.chain_peer(constants.generate_id(13), '1.2.3.13')
+            await self.chain_peer(constants.generate_id(14), '1.2.3.14')
+
+            last = self.nodes[len(self.nodes) - 1]
+            search_q, peer_q = asyncio.Queue(loop=self.loop), asyncio.Queue(loop=self.loop)
+            search_q.put_nowait(blob1)
+
+            _, task = last.accumulate_peers(search_q, peer_q)
+            found_peers = await peer_q.get()
+            task.cancel()
+
+            self.assertEqual(1, len(found_peers))
+            self.assertEqual(self.node.protocol.node_id, found_peers[0].node_id)
+            self.assertEqual(self.node.protocol.external_ip, found_peers[0].address)
+            self.assertEqual(self.node.protocol.peer_port, found_peers[0].tcp_port)
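Note: the announcer test advances simulated time with dht_mocks.get_time_accelerator instead of real sleeps, so `await self.advance(61.0)` steps the clock past the announcer's first scheduled run instantly. The helper's implementation is not part of this diff; a hypothetical minimal sketch of the pattern, for illustration only:

    import asyncio
    import typing

    def get_time_accelerator(loop: asyncio.AbstractEventLoop,
                             now: float) -> typing.Callable[[float], typing.Awaitable[None]]:
        # Hypothetical sketch: patch the loop clock and repeatedly yield so
        # that call_later/call_at callbacks due before `now + seconds` run
        # without real waiting. The real dht_mocks helper may differ.
        _time = now
        loop.time = lambda: _time  # shadow the loop's clock on this instance

        async def advance(seconds: float) -> None:
            nonlocal _time
            for _ in range(int(seconds)):
                _time += 1.0
                await asyncio.sleep(0)  # let timer callbacks that came due fire

        return advance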