diff --git a/lbry/lbry/dht/serialization/datagram.py b/lbry/lbry/dht/serialization/datagram.py index 5550f86cc..d84401228 100644 --- a/lbry/lbry/dht/serialization/datagram.py +++ b/lbry/lbry/dht/serialization/datagram.py @@ -7,9 +7,13 @@ REQUEST_TYPE = 0 RESPONSE_TYPE = 1 ERROR_TYPE = 2 +OPTIONAL_ARG_OFFSET = 100 + # bencode representation of argument keys PAGE_KEY = b'p' +OPTIONAL_FIELDS = () + class KademliaDatagramBase: """ @@ -18,7 +22,7 @@ class KademliaDatagramBase: these correspond to the packet_type, rpc_id, and node_id args """ - fields = [ + required_fields = [ 'packet_type', 'rpc_id', 'node_id' @@ -38,13 +42,18 @@ class KademliaDatagramBase: self.node_id = node_id def bencode(self) -> bytes: - return bencode({ - i: getattr(self, k) for i, k in enumerate(self.fields) - }) + datagram = { + i: getattr(self, k) for i, k in enumerate(self.required_fields) + } + for i, k in enumerate(OPTIONAL_FIELDS): + v = getattr(self, k, None) + if v is not None: + datagram[i + OPTIONAL_ARG_OFFSET] = v + return bencode(datagram) class RequestDatagram(KademliaDatagramBase): - fields = [ + required_fields = [ 'packet_type', 'rpc_id', 'node_id', @@ -104,7 +113,7 @@ class RequestDatagram(KademliaDatagramBase): class ResponseDatagram(KademliaDatagramBase): - fields = [ + required_fields = [ 'packet_type', 'rpc_id', 'node_id', @@ -119,7 +128,7 @@ class ResponseDatagram(KademliaDatagramBase): class ErrorDatagram(KademliaDatagramBase): - fields = [ + required_fields = [ 'packet_type', 'rpc_id', 'node_id', @@ -148,12 +157,15 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa else: raise ValueError("invalid datagram type") datagram_class = msg_types[datagram_type] - return datagram_class(**{ - k: primitive[i] # pylint: disable=unsubscriptable-object - for i, k in enumerate(datagram_class.fields) - if i in primitive # pylint: disable=unsupported-membership-test - } - ) + decoded = { + k: primitive[i] # pylint: disable=unsubscriptable-object + for i, k in enumerate(datagram_class.required_fields) + if i in primitive # pylint: disable=unsupported-membership-test + } + for i, k in enumerate(OPTIONAL_FIELDS): + if i + OPTIONAL_ARG_OFFSET in primitive: + decoded[i + OPTIONAL_ARG_OFFSET] = primitive[i + OPTIONAL_ARG_OFFSET] + return datagram_class(**decoded) def make_compact_ip(address: str) -> bytearray: