improve lbrynet.dht.serialization unit tests
This commit is contained in:
parent
3b0ba3e534
commit
6565ca8558
4 changed files with 100 additions and 37 deletions
|
@ -13,11 +13,6 @@ def cancel_task(task: typing.Optional[asyncio.Task]):
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
|
|
||||||
def cancel_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
|
|
||||||
for task in tasks:
|
|
||||||
cancel_task(task)
|
|
||||||
|
|
||||||
|
|
||||||
def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
|
def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
|
||||||
while tasks:
|
while tasks:
|
||||||
cancel_task(tasks.pop())
|
cancel_task(tasks.pop())
|
||||||
|
|
|
@ -65,45 +65,36 @@ class RequestDatagram(KademliaDatagramBase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def make_ping(cls, from_node_id: bytes, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
|
def make_ping(cls, from_node_id: bytes, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
|
||||||
if rpc_id and len(rpc_id) != constants.rpc_id_length:
|
|
||||||
raise ValueError("invalid rpc id length")
|
|
||||||
rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
|
rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
|
||||||
if len(from_node_id) != constants.hash_bits // 8:
|
|
||||||
raise ValueError("invalid node id")
|
|
||||||
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'ping')
|
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'ping')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def make_store(cls, from_node_id: bytes, blob_hash: bytes, token: bytes, port: int,
|
def make_store(cls, from_node_id: bytes, blob_hash: bytes, token: bytes, port: int,
|
||||||
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
|
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
|
||||||
if rpc_id and len(rpc_id) != constants.rpc_id_length:
|
rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
|
||||||
raise ValueError("invalid rpc id length")
|
if len(blob_hash) != constants.hash_bits // 8:
|
||||||
if not rpc_id:
|
raise ValueError(f"invalid blob hash length: {len(blob_hash)}")
|
||||||
rpc_id = constants.generate_id()[:constants.rpc_id_length]
|
if not 0 < port < 65536:
|
||||||
if len(from_node_id) != constants.hash_bits // 8:
|
raise ValueError(f"invalid port: {port}")
|
||||||
raise ValueError("invalid node id")
|
if len(token) != constants.hash_bits // 8:
|
||||||
|
raise ValueError(f"invalid token length: {len(token)}")
|
||||||
store_args = [blob_hash, token, port, from_node_id, 0]
|
store_args = [blob_hash, token, port, from_node_id, 0]
|
||||||
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'store', store_args)
|
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'store', store_args)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def make_find_node(cls, from_node_id: bytes, key: bytes,
|
def make_find_node(cls, from_node_id: bytes, key: bytes,
|
||||||
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
|
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
|
||||||
if rpc_id and len(rpc_id) != constants.rpc_id_length:
|
rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
|
||||||
raise ValueError("invalid rpc id length")
|
if len(key) != constants.hash_bits // 8:
|
||||||
if not rpc_id:
|
raise ValueError(f"invalid key length: {len(key)}")
|
||||||
rpc_id = constants.generate_id()[:constants.rpc_id_length]
|
|
||||||
if len(from_node_id) != constants.hash_bits // 8:
|
|
||||||
raise ValueError("invalid node id")
|
|
||||||
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findNode', [key])
|
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findNode', [key])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def make_find_value(cls, from_node_id: bytes, key: bytes,
|
def make_find_value(cls, from_node_id: bytes, key: bytes,
|
||||||
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
|
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
|
||||||
if rpc_id and len(rpc_id) != constants.rpc_id_length:
|
rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
|
||||||
raise ValueError("invalid rpc id length")
|
if len(key) != constants.hash_bits // 8:
|
||||||
if not rpc_id:
|
raise ValueError(f"invalid key length: {len(key)}")
|
||||||
rpc_id = constants.generate_id()[:constants.rpc_id_length]
|
|
||||||
if len(from_node_id) != constants.hash_bits // 8:
|
|
||||||
raise ValueError("invalid node id")
|
|
||||||
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key])
|
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key])
|
||||||
|
|
||||||
|
|
||||||
|
@ -147,8 +138,6 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa
|
||||||
}
|
}
|
||||||
|
|
||||||
primitive: typing.Dict = bdecode(datagram)
|
primitive: typing.Dict = bdecode(datagram)
|
||||||
if not isinstance(primitive, dict):
|
|
||||||
raise ValueError("invalid datagram type")
|
|
||||||
if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
|
if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
|
||||||
datagram_type = primitive[0] # pylint: disable=unsubscriptable-object
|
datagram_type = primitive[0] # pylint: disable=unsubscriptable-object
|
||||||
else:
|
else:
|
||||||
|
@ -162,14 +151,19 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def make_compact_ip(address: str):
|
def make_compact_ip(address: str) -> bytearray:
|
||||||
return reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray())
|
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray())
|
||||||
|
if len(compact_ip) != 4:
|
||||||
|
raise ValueError(f"invalid IPv4 length")
|
||||||
|
return compact_ip
|
||||||
|
|
||||||
|
|
||||||
def make_compact_address(node_id: bytes, address: str, port: int) -> bytearray:
|
def make_compact_address(node_id: bytes, address: str, port: int) -> bytearray:
|
||||||
compact_ip = make_compact_ip(address)
|
compact_ip = make_compact_ip(address)
|
||||||
if not 0 <= port <= 65536:
|
if not 0 < port < 65536:
|
||||||
raise ValueError(f'Invalid port: {port}')
|
raise ValueError(f'Invalid port: {port}')
|
||||||
|
if len(node_id) != constants.hash_bits // 8:
|
||||||
|
raise ValueError(f"invalid node node_id length")
|
||||||
return compact_ip + port.to_bytes(2, 'big') + node_id
|
return compact_ip + port.to_bytes(2, 'big') + node_id
|
||||||
|
|
||||||
|
|
||||||
|
@ -177,4 +171,8 @@ def decode_compact_address(compact_address: bytes) -> typing.Tuple[bytes, str, i
|
||||||
address = "{}.{}.{}.{}".format(*compact_address[:4])
|
address = "{}.{}.{}.{}".format(*compact_address[:4])
|
||||||
port = int.from_bytes(compact_address[4:6], 'big')
|
port = int.from_bytes(compact_address[4:6], 'big')
|
||||||
node_id = compact_address[6:]
|
node_id = compact_address[6:]
|
||||||
|
if not 0 < port < 65536:
|
||||||
|
raise ValueError(f'Invalid port: {port}')
|
||||||
|
if len(node_id) != constants.hash_bits // 8:
|
||||||
|
raise ValueError(f"invalid node node_id length")
|
||||||
return node_id, address, port
|
return node_id, address, port
|
||||||
|
|
|
@ -62,3 +62,4 @@ class EncodeDecodeTest(unittest.TestCase):
|
||||||
def test_decode_error(self):
|
def test_decode_error(self):
|
||||||
self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz', True)
|
self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz', True)
|
||||||
self.assertRaises(DecodeError, bdecode, b'', True)
|
self.assertRaises(DecodeError, bdecode, b'', True)
|
||||||
|
self.assertRaises(DecodeError, bdecode, b'l4:spami42ee')
|
||||||
|
|
|
@ -1,11 +1,17 @@
|
||||||
import unittest
|
import unittest
|
||||||
from lbrynet.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram
|
from lbrynet.dht.error import DecodeError
|
||||||
from lbrynet.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE
|
from lbrynet.dht.serialization.bencoding import _bencode
|
||||||
|
from lbrynet.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram, ErrorDatagram
|
||||||
|
from lbrynet.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE, ERROR_TYPE
|
||||||
|
from lbrynet.dht.serialization.datagram import make_compact_address, decode_compact_address
|
||||||
|
|
||||||
|
|
||||||
class TestDatagram(unittest.TestCase):
|
class TestDatagram(unittest.TestCase):
|
||||||
def test_ping_request_datagram(self):
|
def test_ping_request_datagram(self):
|
||||||
serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'ping', []).bencode()
|
self.assertRaises(ValueError, RequestDatagram.make_ping, b'1' * 48, b'1' * 21)
|
||||||
|
self.assertRaises(ValueError, RequestDatagram.make_ping, b'1' * 47, b'1' * 20)
|
||||||
|
self.assertEqual(20, len(RequestDatagram.make_ping(b'1' * 48).rpc_id))
|
||||||
|
serialized = RequestDatagram.make_ping(b'1' * 48, b'1' * 20).bencode()
|
||||||
decoded = decode_datagram(serialized)
|
decoded = decode_datagram(serialized)
|
||||||
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
|
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
|
||||||
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||||
|
@ -14,6 +20,9 @@ class TestDatagram(unittest.TestCase):
|
||||||
self.assertListEqual(decoded.args, [{b'protocolVersion': 1}])
|
self.assertListEqual(decoded.args, [{b'protocolVersion': 1}])
|
||||||
|
|
||||||
def test_ping_response(self):
|
def test_ping_response(self):
|
||||||
|
self.assertRaises(ValueError, ResponseDatagram, RESPONSE_TYPE, b'1' * 21, b'1' * 48, b'pong')
|
||||||
|
self.assertRaises(ValueError, ResponseDatagram, RESPONSE_TYPE, b'1' * 20, b'1' * 49, b'pong')
|
||||||
|
self.assertRaises(ValueError, ResponseDatagram, 5, b'1' * 20, b'1' * 48, b'pong')
|
||||||
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, b'pong').bencode()
|
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, b'pong').bencode()
|
||||||
decoded = decode_datagram(serialized)
|
decoded = decode_datagram(serialized)
|
||||||
self.assertEqual(decoded.packet_type, RESPONSE_TYPE)
|
self.assertEqual(decoded.packet_type, RESPONSE_TYPE)
|
||||||
|
@ -22,7 +31,12 @@ class TestDatagram(unittest.TestCase):
|
||||||
self.assertEqual(decoded.response, b'pong')
|
self.assertEqual(decoded.response, b'pong')
|
||||||
|
|
||||||
def test_find_node_request_datagram(self):
|
def test_find_node_request_datagram(self):
|
||||||
serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findNode', [b'2' * 48]).bencode()
|
self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 49, b'2' * 48, b'1' * 20)
|
||||||
|
self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 48, b'2' * 49, b'1' * 20)
|
||||||
|
self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 48, b'2' * 48, b'1' * 21)
|
||||||
|
self.assertEqual(20, len(RequestDatagram.make_find_node(b'1' * 48, b'2' * 48).rpc_id))
|
||||||
|
|
||||||
|
serialized = RequestDatagram.make_find_node(b'1' * 48, b'2' * 48, b'1' * 20).bencode()
|
||||||
decoded = decode_datagram(serialized)
|
decoded = decode_datagram(serialized)
|
||||||
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
|
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
|
||||||
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||||
|
@ -42,7 +56,12 @@ class TestDatagram(unittest.TestCase):
|
||||||
self.assertEqual(decoded.response, expected)
|
self.assertEqual(decoded.response, expected)
|
||||||
|
|
||||||
def test_find_value_request(self):
|
def test_find_value_request(self):
|
||||||
serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findValue', [b'2' * 48]).bencode()
|
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 49, b'2' * 48, b'1' * 20)
|
||||||
|
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 49, b'1' * 20)
|
||||||
|
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 21)
|
||||||
|
self.assertEqual(20, len(RequestDatagram.make_find_value(b'1' * 48, b'2' * 48).rpc_id))
|
||||||
|
|
||||||
|
serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20).bencode()
|
||||||
decoded = decode_datagram(serialized)
|
decoded = decode_datagram(serialized)
|
||||||
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
|
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
|
||||||
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||||
|
@ -58,3 +77,53 @@ class TestDatagram(unittest.TestCase):
|
||||||
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||||
self.assertEqual(decoded.node_id, b'1' * 48)
|
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||||
self.assertDictEqual(decoded.response, found_value_response)
|
self.assertDictEqual(decoded.response, found_value_response)
|
||||||
|
|
||||||
|
def test_store_request(self):
|
||||||
|
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 47, b'2' * 48, b'3' * 48, 3333, b'1' * 20)
|
||||||
|
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 49, b'3' * 48, 3333, b'1' * 20)
|
||||||
|
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 47, 3333, b'1' * 20)
|
||||||
|
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 48, -3333, b'1' * 20)
|
||||||
|
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 48, 3333, b'1' * 21)
|
||||||
|
|
||||||
|
serialized = RequestDatagram.make_store(b'1' * 48, b'2' * 48, b'3' * 48, 3333, b'1' * 20).bencode()
|
||||||
|
decoded = decode_datagram(serialized)
|
||||||
|
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
|
||||||
|
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||||
|
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||||
|
self.assertEqual(decoded.method, b'store')
|
||||||
|
|
||||||
|
def test_error_datagram(self):
|
||||||
|
serialized = ErrorDatagram(ERROR_TYPE, b'1' * 20, b'1' * 48, b'FakeErrorType', b'more info').bencode()
|
||||||
|
decoded = decode_datagram(serialized)
|
||||||
|
self.assertEqual(decoded.packet_type, ERROR_TYPE)
|
||||||
|
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||||
|
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||||
|
self.assertEqual(decoded.exception_type, 'FakeErrorType')
|
||||||
|
self.assertEqual(decoded.response, 'more info')
|
||||||
|
|
||||||
|
def test_invalid_datagram_type(self):
|
||||||
|
serialized = b'di0ei5ei1e20:11111111111111111111i2e48:11111111111111111111' \
|
||||||
|
b'1111111111111111111111111111i3e13:FakeErrorTypei4e9:more infoe'
|
||||||
|
self.assertRaises(ValueError, decode_datagram, serialized)
|
||||||
|
self.assertRaises(DecodeError, decode_datagram, _bencode([1, 2, 3, 4]))
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompactAddress(unittest.TestCase):
|
||||||
|
def test_encode_decode(self, address='1.2.3.4', port=4444, node_id=b'1' * 48):
|
||||||
|
decoded = decode_compact_address(make_compact_address(node_id, address, port))
|
||||||
|
self.assertEqual((node_id, address, port), decoded)
|
||||||
|
|
||||||
|
def test_errors(self):
|
||||||
|
self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4', 0)
|
||||||
|
self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4', 65536)
|
||||||
|
self.assertRaises(
|
||||||
|
ValueError, decode_compact_address,
|
||||||
|
b'\x01\x02\x03\x04\x00\x00111111111111111111111111111111111111111111111111'
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4.5', 4444)
|
||||||
|
self.assertRaises(ValueError, make_compact_address, b'1' * 47, '1.2.3.4', 4444)
|
||||||
|
self.assertRaises(
|
||||||
|
ValueError, decode_compact_address,
|
||||||
|
b'\x01\x02\x03\x04\x11\\11111111111111111111111111111111111111111111111'
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in a new issue