From 6565ca85585e6221c7a21c1b4c88f534555a6fb0 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Thu, 14 Mar 2019 18:41:24 -0400 Subject: [PATCH 1/5] improve lbrynet.dht.serialization unit tests --- .../dht/protocol/async_generator_junction.py | 5 -- lbrynet/dht/serialization/datagram.py | 52 ++++++------ .../unit/dht/serialization/test_bencoding.py | 1 + tests/unit/dht/serialization/test_datagram.py | 79 +++++++++++++++++-- 4 files changed, 100 insertions(+), 37 deletions(-) diff --git a/lbrynet/dht/protocol/async_generator_junction.py b/lbrynet/dht/protocol/async_generator_junction.py index cbcb1ba84..79db6a55d 100644 --- a/lbrynet/dht/protocol/async_generator_junction.py +++ b/lbrynet/dht/protocol/async_generator_junction.py @@ -13,11 +13,6 @@ def cancel_task(task: typing.Optional[asyncio.Task]): task.cancel() -def cancel_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]): - for task in tasks: - cancel_task(task) - - def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]): while tasks: cancel_task(tasks.pop()) diff --git a/lbrynet/dht/serialization/datagram.py b/lbrynet/dht/serialization/datagram.py index 2d2104ee3..231fbbb59 100644 --- a/lbrynet/dht/serialization/datagram.py +++ b/lbrynet/dht/serialization/datagram.py @@ -65,45 +65,36 @@ class RequestDatagram(KademliaDatagramBase): @classmethod def make_ping(cls, from_node_id: bytes, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram': - if rpc_id and len(rpc_id) != constants.rpc_id_length: - raise ValueError("invalid rpc id length") rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length] - if len(from_node_id) != constants.hash_bits // 8: - raise ValueError("invalid node id") return cls(REQUEST_TYPE, rpc_id, from_node_id, b'ping') @classmethod def make_store(cls, from_node_id: bytes, blob_hash: bytes, token: bytes, port: int, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram': - if rpc_id and len(rpc_id) != constants.rpc_id_length: - raise ValueError("invalid rpc id length") - if not rpc_id: - rpc_id = constants.generate_id()[:constants.rpc_id_length] - if len(from_node_id) != constants.hash_bits // 8: - raise ValueError("invalid node id") + rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length] + if len(blob_hash) != constants.hash_bits // 8: + raise ValueError(f"invalid blob hash length: {len(blob_hash)}") + if not 0 < port < 65536: + raise ValueError(f"invalid port: {port}") + if len(token) != constants.hash_bits // 8: + raise ValueError(f"invalid token length: {len(token)}") store_args = [blob_hash, token, port, from_node_id, 0] return cls(REQUEST_TYPE, rpc_id, from_node_id, b'store', store_args) @classmethod def make_find_node(cls, from_node_id: bytes, key: bytes, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram': - if rpc_id and len(rpc_id) != constants.rpc_id_length: - raise ValueError("invalid rpc id length") - if not rpc_id: - rpc_id = constants.generate_id()[:constants.rpc_id_length] - if len(from_node_id) != constants.hash_bits // 8: - raise ValueError("invalid node id") + rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length] + if len(key) != constants.hash_bits // 8: + raise ValueError(f"invalid key length: {len(key)}") return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findNode', [key]) @classmethod def make_find_value(cls, from_node_id: bytes, key: bytes, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram': - if rpc_id and len(rpc_id) != constants.rpc_id_length: - raise ValueError("invalid rpc id length") - if not rpc_id: - rpc_id 
= constants.generate_id()[:constants.rpc_id_length]
-        if len(from_node_id) != constants.hash_bits // 8:
-            raise ValueError("invalid node id")
+        rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
+        if len(key) != constants.hash_bits // 8:
+            raise ValueError(f"invalid key length: {len(key)}")
         return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key])
 
 
@@ -147,8 +138,6 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa
     }
     primitive: typing.Dict = bdecode(datagram)
-    if not isinstance(primitive, dict):
-        raise ValueError("invalid datagram type")
     if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]:  # pylint: disable=unsubscriptable-object
         datagram_type = primitive[0]  # pylint: disable=unsubscriptable-object
     else:
@@ -162,14 +151,19 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa
     )
 
 
-def make_compact_ip(address: str):
-    return reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray())
+def make_compact_ip(address: str) -> bytearray:
+    compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray())
+    if len(compact_ip) != 4:
+        raise ValueError(f"invalid IPv4 length: {len(compact_ip)}")
+    return compact_ip
 
 
 def make_compact_address(node_id: bytes, address: str, port: int) -> bytearray:
     compact_ip = make_compact_ip(address)
-    if not 0 <= port <= 65536:
+    if not 0 < port < 65536:
         raise ValueError(f'Invalid port: {port}')
+    if len(node_id) != constants.hash_bits // 8:
+        raise ValueError(f"invalid node_id length: {len(node_id)}")
     return compact_ip + port.to_bytes(2, 'big') + node_id
 
 
@@ -177,4 +171,8 @@ def decode_compact_address(compact_address: bytes) -> typing.Tuple[bytes, str, i
     address = "{}.{}.{}.{}".format(*compact_address[:4])
     port = int.from_bytes(compact_address[4:6], 'big')
     node_id = compact_address[6:]
+    if not 0 < port < 65536:
+        raise ValueError(f'Invalid port: {port}')
+    if len(node_id) != constants.hash_bits // 8:
+        raise ValueError(f"invalid node_id length: {len(node_id)}")
     return node_id, address, port
diff --git a/tests/unit/dht/serialization/test_bencoding.py b/tests/unit/dht/serialization/test_bencoding.py
index b516d2e46..983ab224c 100644
--- a/tests/unit/dht/serialization/test_bencoding.py
+++ b/tests/unit/dht/serialization/test_bencoding.py
@@ -62,3 +62,4 @@ class EncodeDecodeTest(unittest.TestCase):
     def test_decode_error(self):
         self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz', True)
         self.assertRaises(DecodeError, bdecode, b'', True)
+        self.assertRaises(DecodeError, bdecode, b'l4:spami42ee')
diff --git a/tests/unit/dht/serialization/test_datagram.py b/tests/unit/dht/serialization/test_datagram.py
index 738ebaa6c..6f473d212 100644
--- a/tests/unit/dht/serialization/test_datagram.py
+++ b/tests/unit/dht/serialization/test_datagram.py
@@ -1,11 +1,17 @@
 import unittest
-from lbrynet.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram
-from lbrynet.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE
+from lbrynet.dht.error import DecodeError
+from lbrynet.dht.serialization.bencoding import _bencode
+from lbrynet.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram, ErrorDatagram
+from lbrynet.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE, ERROR_TYPE
+from lbrynet.dht.serialization.datagram import make_compact_address, decode_compact_address
 
 
 class TestDatagram(unittest.TestCase):
     def test_ping_request_datagram(self):
-        serialized = 
RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'ping', []).bencode() + self.assertRaises(ValueError, RequestDatagram.make_ping, b'1' * 48, b'1' * 21) + self.assertRaises(ValueError, RequestDatagram.make_ping, b'1' * 47, b'1' * 20) + self.assertEqual(20, len(RequestDatagram.make_ping(b'1' * 48).rpc_id)) + serialized = RequestDatagram.make_ping(b'1' * 48, b'1' * 20).bencode() decoded = decode_datagram(serialized) self.assertEqual(decoded.packet_type, REQUEST_TYPE) self.assertEqual(decoded.rpc_id, b'1' * 20) @@ -14,6 +20,9 @@ class TestDatagram(unittest.TestCase): self.assertListEqual(decoded.args, [{b'protocolVersion': 1}]) def test_ping_response(self): + self.assertRaises(ValueError, ResponseDatagram, RESPONSE_TYPE, b'1' * 21, b'1' * 48, b'pong') + self.assertRaises(ValueError, ResponseDatagram, RESPONSE_TYPE, b'1' * 20, b'1' * 49, b'pong') + self.assertRaises(ValueError, ResponseDatagram, 5, b'1' * 20, b'1' * 48, b'pong') serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, b'pong').bencode() decoded = decode_datagram(serialized) self.assertEqual(decoded.packet_type, RESPONSE_TYPE) @@ -22,7 +31,12 @@ class TestDatagram(unittest.TestCase): self.assertEqual(decoded.response, b'pong') def test_find_node_request_datagram(self): - serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findNode', [b'2' * 48]).bencode() + self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 49, b'2' * 48, b'1' * 20) + self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 48, b'2' * 49, b'1' * 20) + self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 48, b'2' * 48, b'1' * 21) + self.assertEqual(20, len(RequestDatagram.make_find_node(b'1' * 48, b'2' * 48).rpc_id)) + + serialized = RequestDatagram.make_find_node(b'1' * 48, b'2' * 48, b'1' * 20).bencode() decoded = decode_datagram(serialized) self.assertEqual(decoded.packet_type, REQUEST_TYPE) self.assertEqual(decoded.rpc_id, b'1' * 20) @@ -42,7 +56,12 @@ class TestDatagram(unittest.TestCase): self.assertEqual(decoded.response, expected) def test_find_value_request(self): - serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findValue', [b'2' * 48]).bencode() + self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 49, b'2' * 48, b'1' * 20) + self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 49, b'1' * 20) + self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 21) + self.assertEqual(20, len(RequestDatagram.make_find_value(b'1' * 48, b'2' * 48).rpc_id)) + + serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20).bencode() decoded = decode_datagram(serialized) self.assertEqual(decoded.packet_type, REQUEST_TYPE) self.assertEqual(decoded.rpc_id, b'1' * 20) @@ -58,3 +77,53 @@ class TestDatagram(unittest.TestCase): self.assertEqual(decoded.rpc_id, b'1' * 20) self.assertEqual(decoded.node_id, b'1' * 48) self.assertDictEqual(decoded.response, found_value_response) + + def test_store_request(self): + self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 47, b'2' * 48, b'3' * 48, 3333, b'1' * 20) + self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 49, b'3' * 48, 3333, b'1' * 20) + self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 47, 3333, b'1' * 20) + self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 48, -3333, b'1' * 20) + self.assertRaises(ValueError, 
RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 48, 3333, b'1' * 21) + + serialized = RequestDatagram.make_store(b'1' * 48, b'2' * 48, b'3' * 48, 3333, b'1' * 20).bencode() + decoded = decode_datagram(serialized) + self.assertEqual(decoded.packet_type, REQUEST_TYPE) + self.assertEqual(decoded.rpc_id, b'1' * 20) + self.assertEqual(decoded.node_id, b'1' * 48) + self.assertEqual(decoded.method, b'store') + + def test_error_datagram(self): + serialized = ErrorDatagram(ERROR_TYPE, b'1' * 20, b'1' * 48, b'FakeErrorType', b'more info').bencode() + decoded = decode_datagram(serialized) + self.assertEqual(decoded.packet_type, ERROR_TYPE) + self.assertEqual(decoded.rpc_id, b'1' * 20) + self.assertEqual(decoded.node_id, b'1' * 48) + self.assertEqual(decoded.exception_type, 'FakeErrorType') + self.assertEqual(decoded.response, 'more info') + + def test_invalid_datagram_type(self): + serialized = b'di0ei5ei1e20:11111111111111111111i2e48:11111111111111111111' \ + b'1111111111111111111111111111i3e13:FakeErrorTypei4e9:more infoe' + self.assertRaises(ValueError, decode_datagram, serialized) + self.assertRaises(DecodeError, decode_datagram, _bencode([1, 2, 3, 4])) + + +class TestCompactAddress(unittest.TestCase): + def test_encode_decode(self, address='1.2.3.4', port=4444, node_id=b'1' * 48): + decoded = decode_compact_address(make_compact_address(node_id, address, port)) + self.assertEqual((node_id, address, port), decoded) + + def test_errors(self): + self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4', 0) + self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4', 65536) + self.assertRaises( + ValueError, decode_compact_address, + b'\x01\x02\x03\x04\x00\x00111111111111111111111111111111111111111111111111' + ) + + self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4.5', 4444) + self.assertRaises(ValueError, make_compact_address, b'1' * 47, '1.2.3.4', 4444) + self.assertRaises( + ValueError, decode_compact_address, + b'\x01\x02\x03\x04\x11\\11111111111111111111111111111111111111111111111' + ) From 44f92a271f844a34052c984fb7ad74bebaf8c241 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Fri, 15 Mar 2019 12:01:21 -0400 Subject: [PATCH 2/5] move dht tests to reflect the real directory structure --- tests/unit/dht/{ => protocol}/test_async_gen_junction.py | 0 tests/unit/dht/{routing => protocol}/test_kbucket.py | 0 tests/unit/dht/{routing => protocol}/test_routing_table.py | 0 tests/unit/dht/routing/__init__.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename tests/unit/dht/{ => protocol}/test_async_gen_junction.py (100%) rename tests/unit/dht/{routing => protocol}/test_kbucket.py (100%) rename tests/unit/dht/{routing => protocol}/test_routing_table.py (100%) delete mode 100644 tests/unit/dht/routing/__init__.py diff --git a/tests/unit/dht/test_async_gen_junction.py b/tests/unit/dht/protocol/test_async_gen_junction.py similarity index 100% rename from tests/unit/dht/test_async_gen_junction.py rename to tests/unit/dht/protocol/test_async_gen_junction.py diff --git a/tests/unit/dht/routing/test_kbucket.py b/tests/unit/dht/protocol/test_kbucket.py similarity index 100% rename from tests/unit/dht/routing/test_kbucket.py rename to tests/unit/dht/protocol/test_kbucket.py diff --git a/tests/unit/dht/routing/test_routing_table.py b/tests/unit/dht/protocol/test_routing_table.py similarity index 100% rename from tests/unit/dht/routing/test_routing_table.py rename to tests/unit/dht/protocol/test_routing_table.py diff --git 
a/tests/unit/dht/routing/__init__.py b/tests/unit/dht/routing/__init__.py deleted file mode 100644 index e69de29bb..000000000 From 664f91bfab921a07d692b35e9d605510d0d8b891 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Fri, 15 Mar 2019 12:44:41 -0400 Subject: [PATCH 3/5] add lbrynet.dht.protocol.distance unit tests --- lbrynet/dht/protocol/distance.py | 4 +++- tests/unit/dht/protocol/test_distance.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 tests/unit/dht/protocol/test_distance.py diff --git a/lbrynet/dht/protocol/distance.py b/lbrynet/dht/protocol/distance.py index 516dfabc4..2b2577fa3 100644 --- a/lbrynet/dht/protocol/distance.py +++ b/lbrynet/dht/protocol/distance.py @@ -10,11 +10,13 @@ class Distance: def __init__(self, key: bytes): if len(key) != constants.hash_length: - raise ValueError("invalid key length: %i" % len(key)) + raise ValueError(f"invalid key length: {len(key)}") self.key = key self.val_key_one = int.from_bytes(key, 'big') def __call__(self, key_two: bytes) -> int: + if len(key_two) != constants.hash_length: + raise ValueError(f"invalid length of key to compare: {len(key_two)}") val_key_two = int.from_bytes(key_two, 'big') return self.val_key_one ^ val_key_two diff --git a/tests/unit/dht/protocol/test_distance.py b/tests/unit/dht/protocol/test_distance.py new file mode 100644 index 000000000..829bcf3f9 --- /dev/null +++ b/tests/unit/dht/protocol/test_distance.py @@ -0,0 +1,13 @@ +import unittest +from lbrynet.dht.protocol.distance import Distance + + +class DistanceTests(unittest.TestCase): + def test_invalid_key_length(self): + self.assertRaises(ValueError, Distance, b'1' * 47) + self.assertRaises(ValueError, Distance, b'1' * 49) + self.assertRaises(ValueError, Distance, b'') + + self.assertRaises(ValueError, Distance(b'0' * 48), b'1' * 47) + self.assertRaises(ValueError, Distance(b'0' * 48), b'1' * 49) + self.assertRaises(ValueError, Distance(b'0' * 48), b'') From 60a24f0e6e3e94b56d8c65d7f4dba511c46af3fc Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Fri, 15 Mar 2019 12:44:55 -0400 Subject: [PATCH 4/5] add lbrynet.dht.protocol.data_store unit tests --- tests/unit/dht/protocol/test_data_store.py | 93 ++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 tests/unit/dht/protocol/test_data_store.py diff --git a/tests/unit/dht/protocol/test_data_store.py b/tests/unit/dht/protocol/test_data_store.py new file mode 100644 index 000000000..54c58cce9 --- /dev/null +++ b/tests/unit/dht/protocol/test_data_store.py @@ -0,0 +1,93 @@ +import asyncio +from torba.testcase import AsyncioTestCase +from lbrynet.dht.protocol.data_store import DictDataStore +from lbrynet.dht.peer import PeerManager + + +class DataStoreTests(AsyncioTestCase): + def setUp(self): + self.loop = asyncio.get_event_loop() + self.peer_manager = PeerManager(self.loop) + self.data_store = DictDataStore(self.loop, self.peer_manager) + + def _test_add_peer_to_blob(self, blob=b'2' * 48, node_id=b'1' * 48, address='1.2.3.4', tcp_port=3333, + udp_port=4444): + peer = self.peer_manager.get_kademlia_peer(node_id, address, udp_port) + peer.update_tcp_port(tcp_port) + before = self.data_store.get_peers_for_blob(blob) + self.data_store.add_peer_to_blob(peer, blob, peer.compact_address_tcp(), 0, 0, peer.node_id) + self.assertListEqual(before + [peer], self.data_store.get_peers_for_blob(blob)) + return peer + + def test_add_peer_to_blob(self, blob=b'f' * 48, peers=None): + peers = peers or [ + (b'a' * 48, '1.2.3.4'), + (b'b' * 48, '1.2.3.5'), + (b'c' * 48, 
'1.2.3.6'),
+        ]
+        self.assertListEqual([], self.data_store.get_peers_for_blob(blob))
+        peer_objects = []
+        for (node_id, address) in peers:
+            peer_objects.append(self._test_add_peer_to_blob(blob=blob, node_id=node_id, address=address))
+        self.assertTrue(self.data_store.has_peers_for_blob(blob))
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob)), len(peers))
+        return peer_objects
+
+    def test_get_storing_contacts(self, peers=None, blob1=b'd' * 48, blob2=b'e' * 48):
+        peers = peers or [
+            (b'a' * 48, '1.2.3.4'),
+            (b'b' * 48, '1.2.3.5'),
+            (b'c' * 48, '1.2.3.6'),
+        ]
+        peer_objs1 = self.test_add_peer_to_blob(blob=blob1, peers=peers)
+        self.assertEqual(len(peers), len(peer_objs1))
+        self.assertEqual(len(peers), len(self.data_store.get_storing_contacts()))
+
+        peer_objs2 = self.test_add_peer_to_blob(blob=blob2, peers=peers)
+        self.assertEqual(len(peers), len(peer_objs2))
+        self.assertEqual(len(peers), len(self.data_store.get_storing_contacts()))
+
+        for o1, o2 in zip(peer_objs1, peer_objs2):
+            self.assertIs(o1, o2)
+
+    def test_remove_expired_peers(self):
+        peers = [
+            (b'a' * 48, '1.2.3.4'),
+            (b'b' * 48, '1.2.3.5'),
+            (b'c' * 48, '1.2.3.6'),
+        ]
+        blob1 = b'd' * 48
+        blob2 = b'e' * 48
+
+        self.data_store.removed_expired_peers()  # nothing should happen
+        self.test_get_storing_contacts(peers, blob1, blob2)
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), len(peers))
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers))
+        self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers))
+
+        # expire the first peer from blob1
+        first = self.data_store._data_store[blob1][0]
+        self.data_store._data_store[blob1][0] = (first[0], first[1], first[2], -86401, first[4])
+        self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers))
+        self.data_store.removed_expired_peers()
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), len(peers) - 1)
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers))
+        self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers))
+
+        # expire the first peer from blob2
+        first = self.data_store._data_store[blob2][0]
+        self.data_store._data_store[blob2][0] = (first[0], first[1], first[2], -86401, first[4])
+        self.data_store.removed_expired_peers()
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), len(peers) - 1)
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers) - 1)
+        self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers) - 1)
+
+        # expire the two remaining (second and third) peers from blob1
+        first = self.data_store._data_store[blob1][0]
+        self.data_store._data_store[blob1][0] = (first[0], first[1], first[2], -86401, first[4])
+        second = self.data_store._data_store[blob1][1]
+        self.data_store._data_store[blob1][1] = (second[0], second[1], second[2], -86401, second[4])
+        self.data_store.removed_expired_peers()
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), 0)
+        self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers) - 1)
+        self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers) - 1)
From f2fefbe287298b8a6b7a55397f9c2e3abda1bdf7 Mon Sep 17 00:00:00 2001
From: Jack Robison
Date: Fri, 15 Mar 2019 14:00:49 -0400
Subject: [PATCH 5/5] add lbrynet.dht.blob_announcer unit tests

---
 tests/unit/dht/test_blob_announcer.py | 114 ++++++++++++++++++++++++++
 1 file changed, 114 insertions(+)
 create mode 100644 tests/unit/dht/test_blob_announcer.py

diff --git 
a/tests/unit/dht/test_blob_announcer.py b/tests/unit/dht/test_blob_announcer.py new file mode 100644 index 000000000..654484af7 --- /dev/null +++ b/tests/unit/dht/test_blob_announcer.py @@ -0,0 +1,114 @@ +import contextlib +import typing +import binascii +import asyncio +from torba.testcase import AsyncioTestCase +from tests import dht_mocks +from lbrynet.conf import Config +from lbrynet.dht import constants +from lbrynet.dht.node import Node +from lbrynet.dht.peer import PeerManager +from lbrynet.dht.blob_announcer import BlobAnnouncer +from lbrynet.extras.daemon.storage import SQLiteStorage + + +class TestBlobAnnouncer(AsyncioTestCase): + async def setup_node(self, peer_addresses, address, node_id): + self.nodes: typing.Dict[int, Node] = {} + self.advance = dht_mocks.get_time_accelerator(self.loop, self.loop.time()) + self.conf = Config() + self.storage = SQLiteStorage(self.conf, ":memory:", self.loop, self.loop.time) + await self.storage.open() + self.peer_manager = PeerManager(self.loop) + self.node = Node(self.loop, self.peer_manager, node_id, 4444, 4444, 3333, address) + await self.node.start_listening(address) + self.blob_announcer = BlobAnnouncer(self.loop, self.node, self.storage) + for node_id, address in peer_addresses: + await self.add_peer(node_id, address) + self.node.joined.set() + + async def add_peer(self, node_id, address, add_to_routing_table=True): + n = Node(self.loop, PeerManager(self.loop), node_id, 4444, 4444, 3333, address) + await n.start_listening(address) + self.nodes.update({len(self.nodes): n}) + if add_to_routing_table: + await self.node.protocol.add_peer( + self.peer_manager.get_kademlia_peer( + n.protocol.node_id, n.protocol.external_ip, n.protocol.udp_port + ) + ) + + @contextlib.asynccontextmanager + async def _test_network_context(self, peer_addresses=None): + self.peer_addresses = peer_addresses or [ + (constants.generate_id(2), '1.2.3.2'), + (constants.generate_id(3), '1.2.3.3'), + (constants.generate_id(4), '1.2.3.4'), + (constants.generate_id(5), '1.2.3.5'), + (constants.generate_id(6), '1.2.3.6'), + (constants.generate_id(7), '1.2.3.7'), + (constants.generate_id(8), '1.2.3.8'), + (constants.generate_id(9), '1.2.3.9'), + ] + try: + with dht_mocks.mock_network_loop(self.loop): + await self.setup_node(self.peer_addresses, '1.2.3.1', constants.generate_id(1)) + yield + finally: + self.blob_announcer.stop() + self.node.stop() + for n in self.nodes.values(): + n.stop() + + async def chain_peer(self, node_id, address): + previous_last_node = self.nodes[len(self.nodes) - 1] + await self.add_peer(node_id, address, False) + last_node = self.nodes[len(self.nodes) - 1] + peer = last_node.protocol.get_rpc_peer( + last_node.protocol.peer_manager.get_kademlia_peer( + previous_last_node.protocol.node_id, previous_last_node.protocol.external_ip, + previous_last_node.protocol.udp_port + ) + ) + await peer.ping() + return peer + + async def test_announce_blobs(self): + blob1 = binascii.hexlify(b'1' * 48).decode() + blob2 = binascii.hexlify(b'2' * 48).decode() + + async with self._test_network_context(): + await self.storage.add_completed_blob(blob1, 1024) + await self.storage.add_completed_blob(blob2, 1024) + await self.storage.db.execute( + "update blob set next_announce_time=0, should_announce=1 where blob_hash in (?, ?)", + (blob1, blob2) + ) + to_announce = await self.storage.get_blobs_to_announce() + self.assertEqual(2, len(to_announce)) + self.blob_announcer.start() + await self.advance(61.0) + to_announce = await self.storage.get_blobs_to_announce() + 
self.assertEqual(0, len(to_announce)) + self.blob_announcer.stop() + + # test that we can route from a poorly connected peer all the way to the announced blob + + await self.chain_peer(constants.generate_id(10), '1.2.3.10') + await self.chain_peer(constants.generate_id(11), '1.2.3.11') + await self.chain_peer(constants.generate_id(12), '1.2.3.12') + await self.chain_peer(constants.generate_id(13), '1.2.3.13') + await self.chain_peer(constants.generate_id(14), '1.2.3.14') + + last = self.nodes[len(self.nodes) - 1] + search_q, peer_q = asyncio.Queue(loop=self.loop), asyncio.Queue(loop=self.loop) + search_q.put_nowait(blob1) + + _, task = last.accumulate_peers(search_q, peer_q) + found_peers = await peer_q.get() + task.cancel() + + self.assertEqual(1, len(found_peers)) + self.assertEqual(self.node.protocol.node_id, found_peers[0].node_id) + self.assertEqual(self.node.protocol.external_ip, found_peers[0].address) + self.assertEqual(self.node.protocol.peer_port, found_peers[0].tcp_port)
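
The request datagram factories exercised in PATCH 1/5 can also be checked end to end
outside the test suite. A minimal sketch of the encode/decode round trip, assuming a
Python environment with this patch series applied so that lbrynet.dht is importable:

    from lbrynet.dht.serialization.datagram import RequestDatagram, decode_datagram

    # the make_* factories generate a random 20-byte rpc_id when one is not supplied;
    # node id and rpc id length checks are left to the shared datagram constructor,
    # which is why the unit tests expect ValueError straight from the factories
    request = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48)
    decoded = decode_datagram(request.bencode())
    assert decoded.packet_type == request.packet_type
    assert decoded.rpc_id == request.rpc_id
    assert decoded.node_id == b'1' * 48
    assert decoded.method == b'findValue'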
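
The compact address helpers from PATCH 1/5 round-trip the same way; a short sketch under
the same assumptions:

    from lbrynet.dht.serialization.datagram import make_compact_address, decode_compact_address

    node_id = b'3' * 48
    # 4 bytes of IPv4 address + 2 bytes of big-endian port + 48 bytes of node id
    packed = bytes(make_compact_address(node_id, '1.2.3.4', 4444))
    assert len(packed) == 54
    assert decode_compact_address(packed) == (node_id, '1.2.3.4', 4444)

    # after this series both directions reject out-of-range ports and bad node id lengths
    for port in (0, 65536):
        try:
            make_compact_address(node_id, '1.2.3.4', port)
        except ValueError:
            pass
        else:
            raise AssertionError(f"port {port} should have been rejected")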