add optional_fields to KademliaDatagramBase

-update KademliaDatagramBase.bencode and decode_datagram
This commit is contained in:
Jack Robison 2019-10-16 10:19:59 -04:00
parent aa7c0a3544
commit 874c28bd88
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2

View file

@ -7,9 +7,13 @@ REQUEST_TYPE = 0
RESPONSE_TYPE = 1 RESPONSE_TYPE = 1
ERROR_TYPE = 2 ERROR_TYPE = 2
OPTIONAL_ARG_OFFSET = 100
# bencode representation of argument keys # bencode representation of argument keys
PAGE_KEY = b'p' PAGE_KEY = b'p'
OPTIONAL_FIELDS = ()
class KademliaDatagramBase: class KademliaDatagramBase:
""" """
@ -18,7 +22,7 @@ class KademliaDatagramBase:
these correspond to the packet_type, rpc_id, and node_id args these correspond to the packet_type, rpc_id, and node_id args
""" """
fields = [ required_fields = [
'packet_type', 'packet_type',
'rpc_id', 'rpc_id',
'node_id' 'node_id'
@ -38,13 +42,18 @@ class KademliaDatagramBase:
self.node_id = node_id self.node_id = node_id
def bencode(self) -> bytes: def bencode(self) -> bytes:
return bencode({ datagram = {
i: getattr(self, k) for i, k in enumerate(self.fields) i: getattr(self, k) for i, k in enumerate(self.required_fields)
}) }
for i, k in enumerate(OPTIONAL_FIELDS):
v = getattr(self, k, None)
if v is not None:
datagram[i + OPTIONAL_ARG_OFFSET] = v
return bencode(datagram)
class RequestDatagram(KademliaDatagramBase): class RequestDatagram(KademliaDatagramBase):
fields = [ required_fields = [
'packet_type', 'packet_type',
'rpc_id', 'rpc_id',
'node_id', 'node_id',
@ -104,7 +113,7 @@ class RequestDatagram(KademliaDatagramBase):
class ResponseDatagram(KademliaDatagramBase): class ResponseDatagram(KademliaDatagramBase):
fields = [ required_fields = [
'packet_type', 'packet_type',
'rpc_id', 'rpc_id',
'node_id', 'node_id',
@ -119,7 +128,7 @@ class ResponseDatagram(KademliaDatagramBase):
class ErrorDatagram(KademliaDatagramBase): class ErrorDatagram(KademliaDatagramBase):
fields = [ required_fields = [
'packet_type', 'packet_type',
'rpc_id', 'rpc_id',
'node_id', 'node_id',
@ -148,12 +157,15 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa
else: else:
raise ValueError("invalid datagram type") raise ValueError("invalid datagram type")
datagram_class = msg_types[datagram_type] datagram_class = msg_types[datagram_type]
return datagram_class(**{ decoded = {
k: primitive[i] # pylint: disable=unsubscriptable-object k: primitive[i] # pylint: disable=unsubscriptable-object
for i, k in enumerate(datagram_class.fields) for i, k in enumerate(datagram_class.required_fields)
if i in primitive # pylint: disable=unsupported-membership-test if i in primitive # pylint: disable=unsupported-membership-test
} }
) for i, k in enumerate(OPTIONAL_FIELDS):
if i + OPTIONAL_ARG_OFFSET in primitive:
decoded[i + OPTIONAL_ARG_OFFSET] = primitive[i + OPTIONAL_ARG_OFFSET]
return datagram_class(**decoded)
def make_compact_ip(address: str) -> bytearray: def make_compact_ip(address: str) -> bytearray: