From 3a64ceb4d66b27f1d3303cd6c4fc10ab182cdd83 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Mon, 14 Sep 2020 15:43:02 -0400 Subject: [PATCH] add forward compatibility for byte datagram keys --- lbry/dht/serialization/datagram.py | 24 +++++++++---- tests/unit/dht/serialization/test_datagram.py | 34 +++++++++++++++++++ 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/lbry/dht/serialization/datagram.py b/lbry/dht/serialization/datagram.py index 5ac6d489f..5008907f0 100644 --- a/lbry/dht/serialization/datagram.py +++ b/lbry/dht/serialization/datagram.py @@ -144,7 +144,7 @@ class ErrorDatagram(KademliaDatagramBase): self.response = response.decode() -def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]: +def _decode_datagram(datagram: bytes): msg_types = { REQUEST_TYPE: RequestDatagram, RESPONSE_TYPE: ResponseDatagram, @@ -152,19 +152,29 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa } primitive: typing.Dict = bdecode(datagram) - if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object - datagram_type = primitive[0] # pylint: disable=unsubscriptable-object + + converted = { + str(k).encode() if not isinstance(k, bytes) else k: v for k, v in primitive.items() + } + + if converted[b'0'] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object + datagram_type = converted[b'0'] # pylint: disable=unsubscriptable-object else: raise ValueError("invalid datagram type") datagram_class = msg_types[datagram_type] decoded = { - k: primitive[i] # pylint: disable=unsubscriptable-object + k: converted[str(i).encode()] # pylint: disable=unsubscriptable-object for i, k in enumerate(datagram_class.required_fields) - if i in primitive # pylint: disable=unsupported-membership-test + if str(i).encode() in converted # pylint: disable=unsupported-membership-test } for i, _ in enumerate(OPTIONAL_FIELDS): - if i + OPTIONAL_ARG_OFFSET in primitive: - decoded[i + OPTIONAL_ARG_OFFSET] = primitive[i + OPTIONAL_ARG_OFFSET] + if str(i + OPTIONAL_ARG_OFFSET).encode() in converted: + decoded[i + OPTIONAL_ARG_OFFSET] = converted[str(i + OPTIONAL_ARG_OFFSET).encode()] + return decoded, datagram_class + + +def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]: + decoded, datagram_class = _decode_datagram(datagram) return datagram_class(**decoded) diff --git a/tests/unit/dht/serialization/test_datagram.py b/tests/unit/dht/serialization/test_datagram.py index ff502cbf0..1bda095d6 100644 --- a/tests/unit/dht/serialization/test_datagram.py +++ b/tests/unit/dht/serialization/test_datagram.py @@ -1,7 +1,9 @@ +import binascii import unittest from lbry.dht.error import DecodeError from lbry.dht.serialization.bencoding import _bencode from lbry.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram, ErrorDatagram +from lbry.dht.serialization.datagram import _decode_datagram from lbry.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE, ERROR_TYPE from lbry.dht.serialization.datagram import make_compact_address, decode_compact_address @@ -139,6 +141,38 @@ class TestDatagram(unittest.TestCase): self.assertEqual(datagram.packet_type, REQUEST_TYPE) self.assertEqual(b'ping', datagram.method) + def test_str_or_int_keys(self): + datagram = decode_datagram(_bencode({ + b'0': 0, + b'1': b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc', + b'2': b'111111111111111111111111111111111111111111111111', + b'3': b'ping', + b'4': [{b'protocolVersion': 1}], + b'5': b'should not error' + })) + self.assertEqual(datagram.packet_type, REQUEST_TYPE) + self.assertEqual(b'ping', datagram.method) + + def test_mixed_str_or_int_keys(self): + # datagram, _ = _bencode({ + # b'0': 0, + # 1: b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc', + # b'2': b'111111111111111111111111111111111111111111111111', + # 3: b'ping', + # b'4': [{b'protocolVersion': 1}], + # b'5': b'should not error' + # })) + encoded = binascii.unhexlify(b"64313a3069306569316532303a0abcb5269d6cfc1e87a08e920bf39fe9df8e92fc313a3234383a313131313131313131313131313131313131313131313131313131313131313131313131313131313131313131313131693365343a70696e67313a346c6431353a70726f746f636f6c56657273696f6e6931656565313a3531363a73686f756c64206e6f74206572726f7265") + self.assertDictEqual( + { + 'packet_type': 0, + 'rpc_id': b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc', + 'node_id': b'111111111111111111111111111111111111111111111111', + 'method': b'ping', + 'args': [{b'protocolVersion': 1}] + }, _decode_datagram(encoded)[0] + ) + class TestCompactAddress(unittest.TestCase): def test_encode_decode(self, address='1.2.3.4', port=4444, node_id=b'1' * 48):