diff --git a/lbry/lbry/dht/serialization/datagram.py b/lbry/lbry/dht/serialization/datagram.py index 5550f86cc..d84401228 100644 --- a/lbry/lbry/dht/serialization/datagram.py +++ b/lbry/lbry/dht/serialization/datagram.py @@ -7,9 +7,13 @@ REQUEST_TYPE = 0 RESPONSE_TYPE = 1 ERROR_TYPE = 2 +OPTIONAL_ARG_OFFSET = 100 + # bencode representation of argument keys PAGE_KEY = b'p' +OPTIONAL_FIELDS = () + class KademliaDatagramBase: """ @@ -18,7 +22,7 @@ class KademliaDatagramBase: these correspond to the packet_type, rpc_id, and node_id args """ - fields = [ + required_fields = [ 'packet_type', 'rpc_id', 'node_id' @@ -38,13 +42,18 @@ class KademliaDatagramBase: self.node_id = node_id def bencode(self) -> bytes: - return bencode({ - i: getattr(self, k) for i, k in enumerate(self.fields) - }) + datagram = { + i: getattr(self, k) for i, k in enumerate(self.required_fields) + } + for i, k in enumerate(OPTIONAL_FIELDS): + v = getattr(self, k, None) + if v is not None: + datagram[i + OPTIONAL_ARG_OFFSET] = v + return bencode(datagram) class RequestDatagram(KademliaDatagramBase): - fields = [ + required_fields = [ 'packet_type', 'rpc_id', 'node_id', @@ -104,7 +113,7 @@ class RequestDatagram(KademliaDatagramBase): class ResponseDatagram(KademliaDatagramBase): - fields = [ + required_fields = [ 'packet_type', 'rpc_id', 'node_id', @@ -119,7 +128,7 @@ class ResponseDatagram(KademliaDatagramBase): class ErrorDatagram(KademliaDatagramBase): - fields = [ + required_fields = [ 'packet_type', 'rpc_id', 'node_id', @@ -148,12 +157,15 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa else: raise ValueError("invalid datagram type") datagram_class = msg_types[datagram_type] - return datagram_class(**{ - k: primitive[i] # pylint: disable=unsubscriptable-object - for i, k in enumerate(datagram_class.fields) - if i in primitive # pylint: disable=unsupported-membership-test - } - ) + decoded = { + k: primitive[i] # pylint: disable=unsubscriptable-object + for i, k in enumerate(datagram_class.required_fields) + if i in primitive # pylint: disable=unsupported-membership-test + } + for i, k in enumerate(OPTIONAL_FIELDS): + if i + OPTIONAL_ARG_OFFSET in primitive: + decoded[i + OPTIONAL_ARG_OFFSET] = primitive[i + OPTIONAL_ARG_OFFSET] + return datagram_class(**decoded) def make_compact_ip(address: str) -> bytearray: diff --git a/lbry/tests/unit/dht/serialization/test_datagram.py b/lbry/tests/unit/dht/serialization/test_datagram.py index 5544f13cd..ff502cbf0 100644 --- a/lbry/tests/unit/dht/serialization/test_datagram.py +++ b/lbry/tests/unit/dht/serialization/test_datagram.py @@ -127,6 +127,18 @@ class TestDatagram(unittest.TestCase): self.assertRaises(ValueError, decode_datagram, serialized) self.assertRaises(DecodeError, decode_datagram, _bencode([1, 2, 3, 4])) + def test_optional_field_backwards_compatible(self): + datagram = decode_datagram(_bencode({ + 0: 0, + 1: b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc', + 2: b'111111111111111111111111111111111111111111111111', + 3: b'ping', + 4: [{b'protocolVersion': 1}], + 5: b'should not error' + })) + self.assertEqual(datagram.packet_type, REQUEST_TYPE) + self.assertEqual(b'ping', datagram.method) + class TestCompactAddress(unittest.TestCase): def test_encode_decode(self, address='1.2.3.4', port=4444, node_id=b'1' * 48):