diff --git a/lbrynet/dht/protocol/async_generator_junction.py b/lbrynet/dht/protocol/async_generator_junction.py index cbcb1ba84..79db6a55d 100644 --- a/lbrynet/dht/protocol/async_generator_junction.py +++ b/lbrynet/dht/protocol/async_generator_junction.py @@ -13,11 +13,6 @@ def cancel_task(task: typing.Optional[asyncio.Task]): task.cancel() -def cancel_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]): - for task in tasks: - cancel_task(task) - - def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]): while tasks: cancel_task(tasks.pop()) diff --git a/lbrynet/dht/serialization/datagram.py b/lbrynet/dht/serialization/datagram.py index 2d2104ee3..231fbbb59 100644 --- a/lbrynet/dht/serialization/datagram.py +++ b/lbrynet/dht/serialization/datagram.py @@ -65,45 +65,36 @@ class RequestDatagram(KademliaDatagramBase): @classmethod def make_ping(cls, from_node_id: bytes, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram': - if rpc_id and len(rpc_id) != constants.rpc_id_length: - raise ValueError("invalid rpc id length") rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length] - if len(from_node_id) != constants.hash_bits // 8: - raise ValueError("invalid node id") return cls(REQUEST_TYPE, rpc_id, from_node_id, b'ping') @classmethod def make_store(cls, from_node_id: bytes, blob_hash: bytes, token: bytes, port: int, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram': - if rpc_id and len(rpc_id) != constants.rpc_id_length: - raise ValueError("invalid rpc id length") - if not rpc_id: - rpc_id = constants.generate_id()[:constants.rpc_id_length] - if len(from_node_id) != constants.hash_bits // 8: - raise ValueError("invalid node id") + rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length] + if len(blob_hash) != constants.hash_bits // 8: + raise ValueError(f"invalid blob hash length: {len(blob_hash)}") + if not 0 < port < 65536: + raise ValueError(f"invalid port: {port}") + if len(token) != constants.hash_bits // 8: + raise ValueError(f"invalid token length: {len(token)}") store_args = [blob_hash, token, port, from_node_id, 0] return cls(REQUEST_TYPE, rpc_id, from_node_id, b'store', store_args) @classmethod def make_find_node(cls, from_node_id: bytes, key: bytes, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram': - if rpc_id and len(rpc_id) != constants.rpc_id_length: - raise ValueError("invalid rpc id length") - if not rpc_id: - rpc_id = constants.generate_id()[:constants.rpc_id_length] - if len(from_node_id) != constants.hash_bits // 8: - raise ValueError("invalid node id") + rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length] + if len(key) != constants.hash_bits // 8: + raise ValueError(f"invalid key length: {len(key)}") return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findNode', [key]) @classmethod def make_find_value(cls, from_node_id: bytes, key: bytes, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram': - if rpc_id and len(rpc_id) != constants.rpc_id_length: - raise ValueError("invalid rpc id length") - if not rpc_id: - rpc_id = constants.generate_id()[:constants.rpc_id_length] - if len(from_node_id) != constants.hash_bits // 8: - raise ValueError("invalid node id") + rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length] + if len(key) != constants.hash_bits // 8: + raise ValueError(f"invalid key length: {len(key)}") return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key]) @@ -147,8 +138,6 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa } primitive: typing.Dict = bdecode(datagram) - if not isinstance(primitive, dict): - raise ValueError("invalid datagram type") if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object datagram_type = primitive[0] # pylint: disable=unsubscriptable-object else: @@ -162,14 +151,19 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa ) -def make_compact_ip(address: str): - return reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray()) +def make_compact_ip(address: str) -> bytearray: + compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray()) + if len(compact_ip) != 4: + raise ValueError(f"invalid IPv4 length") + return compact_ip def make_compact_address(node_id: bytes, address: str, port: int) -> bytearray: compact_ip = make_compact_ip(address) - if not 0 <= port <= 65536: + if not 0 < port < 65536: raise ValueError(f'Invalid port: {port}') + if len(node_id) != constants.hash_bits // 8: + raise ValueError(f"invalid node node_id length") return compact_ip + port.to_bytes(2, 'big') + node_id @@ -177,4 +171,8 @@ def decode_compact_address(compact_address: bytes) -> typing.Tuple[bytes, str, i address = "{}.{}.{}.{}".format(*compact_address[:4]) port = int.from_bytes(compact_address[4:6], 'big') node_id = compact_address[6:] + if not 0 < port < 65536: + raise ValueError(f'Invalid port: {port}') + if len(node_id) != constants.hash_bits // 8: + raise ValueError(f"invalid node node_id length") return node_id, address, port diff --git a/tests/unit/dht/serialization/test_bencoding.py b/tests/unit/dht/serialization/test_bencoding.py index b516d2e46..983ab224c 100644 --- a/tests/unit/dht/serialization/test_bencoding.py +++ b/tests/unit/dht/serialization/test_bencoding.py @@ -62,3 +62,4 @@ class EncodeDecodeTest(unittest.TestCase): def test_decode_error(self): self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz', True) self.assertRaises(DecodeError, bdecode, b'', True) + self.assertRaises(DecodeError, bdecode, b'l4:spami42ee') diff --git a/tests/unit/dht/serialization/test_datagram.py b/tests/unit/dht/serialization/test_datagram.py index 738ebaa6c..6f473d212 100644 --- a/tests/unit/dht/serialization/test_datagram.py +++ b/tests/unit/dht/serialization/test_datagram.py @@ -1,11 +1,17 @@ import unittest -from lbrynet.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram -from lbrynet.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE +from lbrynet.dht.error import DecodeError +from lbrynet.dht.serialization.bencoding import _bencode +from lbrynet.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram, ErrorDatagram +from lbrynet.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE, ERROR_TYPE +from lbrynet.dht.serialization.datagram import make_compact_address, decode_compact_address class TestDatagram(unittest.TestCase): def test_ping_request_datagram(self): - serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'ping', []).bencode() + self.assertRaises(ValueError, RequestDatagram.make_ping, b'1' * 48, b'1' * 21) + self.assertRaises(ValueError, RequestDatagram.make_ping, b'1' * 47, b'1' * 20) + self.assertEqual(20, len(RequestDatagram.make_ping(b'1' * 48).rpc_id)) + serialized = RequestDatagram.make_ping(b'1' * 48, b'1' * 20).bencode() decoded = decode_datagram(serialized) self.assertEqual(decoded.packet_type, REQUEST_TYPE) self.assertEqual(decoded.rpc_id, b'1' * 20) @@ -14,6 +20,9 @@ class TestDatagram(unittest.TestCase): self.assertListEqual(decoded.args, [{b'protocolVersion': 1}]) def test_ping_response(self): + self.assertRaises(ValueError, ResponseDatagram, RESPONSE_TYPE, b'1' * 21, b'1' * 48, b'pong') + self.assertRaises(ValueError, ResponseDatagram, RESPONSE_TYPE, b'1' * 20, b'1' * 49, b'pong') + self.assertRaises(ValueError, ResponseDatagram, 5, b'1' * 20, b'1' * 48, b'pong') serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, b'pong').bencode() decoded = decode_datagram(serialized) self.assertEqual(decoded.packet_type, RESPONSE_TYPE) @@ -22,7 +31,12 @@ class TestDatagram(unittest.TestCase): self.assertEqual(decoded.response, b'pong') def test_find_node_request_datagram(self): - serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findNode', [b'2' * 48]).bencode() + self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 49, b'2' * 48, b'1' * 20) + self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 48, b'2' * 49, b'1' * 20) + self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 48, b'2' * 48, b'1' * 21) + self.assertEqual(20, len(RequestDatagram.make_find_node(b'1' * 48, b'2' * 48).rpc_id)) + + serialized = RequestDatagram.make_find_node(b'1' * 48, b'2' * 48, b'1' * 20).bencode() decoded = decode_datagram(serialized) self.assertEqual(decoded.packet_type, REQUEST_TYPE) self.assertEqual(decoded.rpc_id, b'1' * 20) @@ -42,7 +56,12 @@ class TestDatagram(unittest.TestCase): self.assertEqual(decoded.response, expected) def test_find_value_request(self): - serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findValue', [b'2' * 48]).bencode() + self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 49, b'2' * 48, b'1' * 20) + self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 49, b'1' * 20) + self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 21) + self.assertEqual(20, len(RequestDatagram.make_find_value(b'1' * 48, b'2' * 48).rpc_id)) + + serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20).bencode() decoded = decode_datagram(serialized) self.assertEqual(decoded.packet_type, REQUEST_TYPE) self.assertEqual(decoded.rpc_id, b'1' * 20) @@ -58,3 +77,53 @@ class TestDatagram(unittest.TestCase): self.assertEqual(decoded.rpc_id, b'1' * 20) self.assertEqual(decoded.node_id, b'1' * 48) self.assertDictEqual(decoded.response, found_value_response) + + def test_store_request(self): + self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 47, b'2' * 48, b'3' * 48, 3333, b'1' * 20) + self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 49, b'3' * 48, 3333, b'1' * 20) + self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 47, 3333, b'1' * 20) + self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 48, -3333, b'1' * 20) + self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 48, 3333, b'1' * 21) + + serialized = RequestDatagram.make_store(b'1' * 48, b'2' * 48, b'3' * 48, 3333, b'1' * 20).bencode() + decoded = decode_datagram(serialized) + self.assertEqual(decoded.packet_type, REQUEST_TYPE) + self.assertEqual(decoded.rpc_id, b'1' * 20) + self.assertEqual(decoded.node_id, b'1' * 48) + self.assertEqual(decoded.method, b'store') + + def test_error_datagram(self): + serialized = ErrorDatagram(ERROR_TYPE, b'1' * 20, b'1' * 48, b'FakeErrorType', b'more info').bencode() + decoded = decode_datagram(serialized) + self.assertEqual(decoded.packet_type, ERROR_TYPE) + self.assertEqual(decoded.rpc_id, b'1' * 20) + self.assertEqual(decoded.node_id, b'1' * 48) + self.assertEqual(decoded.exception_type, 'FakeErrorType') + self.assertEqual(decoded.response, 'more info') + + def test_invalid_datagram_type(self): + serialized = b'di0ei5ei1e20:11111111111111111111i2e48:11111111111111111111' \ + b'1111111111111111111111111111i3e13:FakeErrorTypei4e9:more infoe' + self.assertRaises(ValueError, decode_datagram, serialized) + self.assertRaises(DecodeError, decode_datagram, _bencode([1, 2, 3, 4])) + + +class TestCompactAddress(unittest.TestCase): + def test_encode_decode(self, address='1.2.3.4', port=4444, node_id=b'1' * 48): + decoded = decode_compact_address(make_compact_address(node_id, address, port)) + self.assertEqual((node_id, address, port), decoded) + + def test_errors(self): + self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4', 0) + self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4', 65536) + self.assertRaises( + ValueError, decode_compact_address, + b'\x01\x02\x03\x04\x00\x00111111111111111111111111111111111111111111111111' + ) + + self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4.5', 4444) + self.assertRaises(ValueError, make_compact_address, b'1' * 47, '1.2.3.4', 4444) + self.assertRaises( + ValueError, decode_compact_address, + b'\x01\x02\x03\x04\x11\\11111111111111111111111111111111111111111111111' + )