improve lbrynet.dht.serialization unit tests

This commit is contained in:
Jack Robison 2019-03-14 18:41:24 -04:00
parent 3b0ba3e534
commit 6565ca8558
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
4 changed files with 100 additions and 37 deletions

View file

@ -13,11 +13,6 @@ def cancel_task(task: typing.Optional[asyncio.Task]):
task.cancel() task.cancel()
def cancel_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
for task in tasks:
cancel_task(task)
def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]): def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
while tasks: while tasks:
cancel_task(tasks.pop()) cancel_task(tasks.pop())

View file

@ -65,45 +65,36 @@ class RequestDatagram(KademliaDatagramBase):
@classmethod @classmethod
def make_ping(cls, from_node_id: bytes, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram': def make_ping(cls, from_node_id: bytes, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
if rpc_id and len(rpc_id) != constants.rpc_id_length:
raise ValueError("invalid rpc id length")
rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length] rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
if len(from_node_id) != constants.hash_bits // 8:
raise ValueError("invalid node id")
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'ping') return cls(REQUEST_TYPE, rpc_id, from_node_id, b'ping')
@classmethod @classmethod
def make_store(cls, from_node_id: bytes, blob_hash: bytes, token: bytes, port: int, def make_store(cls, from_node_id: bytes, blob_hash: bytes, token: bytes, port: int,
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram': rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
if rpc_id and len(rpc_id) != constants.rpc_id_length: rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
raise ValueError("invalid rpc id length") if len(blob_hash) != constants.hash_bits // 8:
if not rpc_id: raise ValueError(f"invalid blob hash length: {len(blob_hash)}")
rpc_id = constants.generate_id()[:constants.rpc_id_length] if not 0 < port < 65536:
if len(from_node_id) != constants.hash_bits // 8: raise ValueError(f"invalid port: {port}")
raise ValueError("invalid node id") if len(token) != constants.hash_bits // 8:
raise ValueError(f"invalid token length: {len(token)}")
store_args = [blob_hash, token, port, from_node_id, 0] store_args = [blob_hash, token, port, from_node_id, 0]
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'store', store_args) return cls(REQUEST_TYPE, rpc_id, from_node_id, b'store', store_args)
@classmethod @classmethod
def make_find_node(cls, from_node_id: bytes, key: bytes, def make_find_node(cls, from_node_id: bytes, key: bytes,
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram': rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
if rpc_id and len(rpc_id) != constants.rpc_id_length: rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
raise ValueError("invalid rpc id length") if len(key) != constants.hash_bits // 8:
if not rpc_id: raise ValueError(f"invalid key length: {len(key)}")
rpc_id = constants.generate_id()[:constants.rpc_id_length]
if len(from_node_id) != constants.hash_bits // 8:
raise ValueError("invalid node id")
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findNode', [key]) return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findNode', [key])
@classmethod @classmethod
def make_find_value(cls, from_node_id: bytes, key: bytes, def make_find_value(cls, from_node_id: bytes, key: bytes,
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram': rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
if rpc_id and len(rpc_id) != constants.rpc_id_length: rpc_id = rpc_id or constants.generate_id()[:constants.rpc_id_length]
raise ValueError("invalid rpc id length") if len(key) != constants.hash_bits // 8:
if not rpc_id: raise ValueError(f"invalid key length: {len(key)}")
rpc_id = constants.generate_id()[:constants.rpc_id_length]
if len(from_node_id) != constants.hash_bits // 8:
raise ValueError("invalid node id")
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key]) return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key])
@ -147,8 +138,6 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa
} }
primitive: typing.Dict = bdecode(datagram) primitive: typing.Dict = bdecode(datagram)
if not isinstance(primitive, dict):
raise ValueError("invalid datagram type")
if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
datagram_type = primitive[0] # pylint: disable=unsubscriptable-object datagram_type = primitive[0] # pylint: disable=unsubscriptable-object
else: else:
@ -162,14 +151,19 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa
) )
def make_compact_ip(address: str): def make_compact_ip(address: str) -> bytearray:
return reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray()) compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray())
if len(compact_ip) != 4:
raise ValueError(f"invalid IPv4 length")
return compact_ip
def make_compact_address(node_id: bytes, address: str, port: int) -> bytearray: def make_compact_address(node_id: bytes, address: str, port: int) -> bytearray:
compact_ip = make_compact_ip(address) compact_ip = make_compact_ip(address)
if not 0 <= port <= 65536: if not 0 < port < 65536:
raise ValueError(f'Invalid port: {port}') raise ValueError(f'Invalid port: {port}')
if len(node_id) != constants.hash_bits // 8:
raise ValueError(f"invalid node node_id length")
return compact_ip + port.to_bytes(2, 'big') + node_id return compact_ip + port.to_bytes(2, 'big') + node_id
@ -177,4 +171,8 @@ def decode_compact_address(compact_address: bytes) -> typing.Tuple[bytes, str, i
address = "{}.{}.{}.{}".format(*compact_address[:4]) address = "{}.{}.{}.{}".format(*compact_address[:4])
port = int.from_bytes(compact_address[4:6], 'big') port = int.from_bytes(compact_address[4:6], 'big')
node_id = compact_address[6:] node_id = compact_address[6:]
if not 0 < port < 65536:
raise ValueError(f'Invalid port: {port}')
if len(node_id) != constants.hash_bits // 8:
raise ValueError(f"invalid node node_id length")
return node_id, address, port return node_id, address, port

View file

@ -62,3 +62,4 @@ class EncodeDecodeTest(unittest.TestCase):
def test_decode_error(self): def test_decode_error(self):
self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz', True) self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz', True)
self.assertRaises(DecodeError, bdecode, b'', True) self.assertRaises(DecodeError, bdecode, b'', True)
self.assertRaises(DecodeError, bdecode, b'l4:spami42ee')

View file

@ -1,11 +1,17 @@
import unittest import unittest
from lbrynet.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram from lbrynet.dht.error import DecodeError
from lbrynet.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE from lbrynet.dht.serialization.bencoding import _bencode
from lbrynet.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram, ErrorDatagram
from lbrynet.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE, ERROR_TYPE
from lbrynet.dht.serialization.datagram import make_compact_address, decode_compact_address
class TestDatagram(unittest.TestCase): class TestDatagram(unittest.TestCase):
def test_ping_request_datagram(self): def test_ping_request_datagram(self):
serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'ping', []).bencode() self.assertRaises(ValueError, RequestDatagram.make_ping, b'1' * 48, b'1' * 21)
self.assertRaises(ValueError, RequestDatagram.make_ping, b'1' * 47, b'1' * 20)
self.assertEqual(20, len(RequestDatagram.make_ping(b'1' * 48).rpc_id))
serialized = RequestDatagram.make_ping(b'1' * 48, b'1' * 20).bencode()
decoded = decode_datagram(serialized) decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, REQUEST_TYPE) self.assertEqual(decoded.packet_type, REQUEST_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20) self.assertEqual(decoded.rpc_id, b'1' * 20)
@ -14,6 +20,9 @@ class TestDatagram(unittest.TestCase):
self.assertListEqual(decoded.args, [{b'protocolVersion': 1}]) self.assertListEqual(decoded.args, [{b'protocolVersion': 1}])
def test_ping_response(self): def test_ping_response(self):
self.assertRaises(ValueError, ResponseDatagram, RESPONSE_TYPE, b'1' * 21, b'1' * 48, b'pong')
self.assertRaises(ValueError, ResponseDatagram, RESPONSE_TYPE, b'1' * 20, b'1' * 49, b'pong')
self.assertRaises(ValueError, ResponseDatagram, 5, b'1' * 20, b'1' * 48, b'pong')
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, b'pong').bencode() serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, b'pong').bencode()
decoded = decode_datagram(serialized) decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, RESPONSE_TYPE) self.assertEqual(decoded.packet_type, RESPONSE_TYPE)
@ -22,7 +31,12 @@ class TestDatagram(unittest.TestCase):
self.assertEqual(decoded.response, b'pong') self.assertEqual(decoded.response, b'pong')
def test_find_node_request_datagram(self): def test_find_node_request_datagram(self):
serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findNode', [b'2' * 48]).bencode() self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 49, b'2' * 48, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 48, b'2' * 49, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_find_node, b'1' * 48, b'2' * 48, b'1' * 21)
self.assertEqual(20, len(RequestDatagram.make_find_node(b'1' * 48, b'2' * 48).rpc_id))
serialized = RequestDatagram.make_find_node(b'1' * 48, b'2' * 48, b'1' * 20).bencode()
decoded = decode_datagram(serialized) decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, REQUEST_TYPE) self.assertEqual(decoded.packet_type, REQUEST_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20) self.assertEqual(decoded.rpc_id, b'1' * 20)
@ -42,7 +56,12 @@ class TestDatagram(unittest.TestCase):
self.assertEqual(decoded.response, expected) self.assertEqual(decoded.response, expected)
def test_find_value_request(self): def test_find_value_request(self):
serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findValue', [b'2' * 48]).bencode() self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 49, b'2' * 48, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 49, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_find_value, b'1' * 48, b'2' * 48, b'1' * 21)
self.assertEqual(20, len(RequestDatagram.make_find_value(b'1' * 48, b'2' * 48).rpc_id))
serialized = RequestDatagram.make_find_value(b'1' * 48, b'2' * 48, b'1' * 20).bencode()
decoded = decode_datagram(serialized) decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, REQUEST_TYPE) self.assertEqual(decoded.packet_type, REQUEST_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20) self.assertEqual(decoded.rpc_id, b'1' * 20)
@ -58,3 +77,53 @@ class TestDatagram(unittest.TestCase):
self.assertEqual(decoded.rpc_id, b'1' * 20) self.assertEqual(decoded.rpc_id, b'1' * 20)
self.assertEqual(decoded.node_id, b'1' * 48) self.assertEqual(decoded.node_id, b'1' * 48)
self.assertDictEqual(decoded.response, found_value_response) self.assertDictEqual(decoded.response, found_value_response)
def test_store_request(self):
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 47, b'2' * 48, b'3' * 48, 3333, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 49, b'3' * 48, 3333, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 47, 3333, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 48, -3333, b'1' * 20)
self.assertRaises(ValueError, RequestDatagram.make_store, b'1' * 48, b'2' * 48, b'3' * 48, 3333, b'1' * 21)
serialized = RequestDatagram.make_store(b'1' * 48, b'2' * 48, b'3' * 48, 3333, b'1' * 20).bencode()
decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20)
self.assertEqual(decoded.node_id, b'1' * 48)
self.assertEqual(decoded.method, b'store')
def test_error_datagram(self):
serialized = ErrorDatagram(ERROR_TYPE, b'1' * 20, b'1' * 48, b'FakeErrorType', b'more info').bencode()
decoded = decode_datagram(serialized)
self.assertEqual(decoded.packet_type, ERROR_TYPE)
self.assertEqual(decoded.rpc_id, b'1' * 20)
self.assertEqual(decoded.node_id, b'1' * 48)
self.assertEqual(decoded.exception_type, 'FakeErrorType')
self.assertEqual(decoded.response, 'more info')
def test_invalid_datagram_type(self):
serialized = b'di0ei5ei1e20:11111111111111111111i2e48:11111111111111111111' \
b'1111111111111111111111111111i3e13:FakeErrorTypei4e9:more infoe'
self.assertRaises(ValueError, decode_datagram, serialized)
self.assertRaises(DecodeError, decode_datagram, _bencode([1, 2, 3, 4]))
class TestCompactAddress(unittest.TestCase):
def test_encode_decode(self, address='1.2.3.4', port=4444, node_id=b'1' * 48):
decoded = decode_compact_address(make_compact_address(node_id, address, port))
self.assertEqual((node_id, address, port), decoded)
def test_errors(self):
self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4', 0)
self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4', 65536)
self.assertRaises(
ValueError, decode_compact_address,
b'\x01\x02\x03\x04\x00\x00111111111111111111111111111111111111111111111111'
)
self.assertRaises(ValueError, make_compact_address, b'1' * 48, '1.2.3.4.5', 4444)
self.assertRaises(ValueError, make_compact_address, b'1' * 47, '1.2.3.4', 4444)
self.assertRaises(
ValueError, decode_compact_address,
b'\x01\x02\x03\x04\x11\\11111111111111111111111111111111111111111111111'
)