forked from LBRYCommunity/lbry-sdk
Merge pull request #2552 from lbryio/dht-extensions
Add optional extensions to DHT datagram format
This commit is contained in:
commit
72e486791b
2 changed files with 37 additions and 13 deletions
|
@ -7,9 +7,13 @@ REQUEST_TYPE = 0
|
||||||
RESPONSE_TYPE = 1
|
RESPONSE_TYPE = 1
|
||||||
ERROR_TYPE = 2
|
ERROR_TYPE = 2
|
||||||
|
|
||||||
|
OPTIONAL_ARG_OFFSET = 100
|
||||||
|
|
||||||
# bencode representation of argument keys
|
# bencode representation of argument keys
|
||||||
PAGE_KEY = b'p'
|
PAGE_KEY = b'p'
|
||||||
|
|
||||||
|
OPTIONAL_FIELDS = ()
|
||||||
|
|
||||||
|
|
||||||
class KademliaDatagramBase:
|
class KademliaDatagramBase:
|
||||||
"""
|
"""
|
||||||
|
@ -18,7 +22,7 @@ class KademliaDatagramBase:
|
||||||
these correspond to the packet_type, rpc_id, and node_id args
|
these correspond to the packet_type, rpc_id, and node_id args
|
||||||
"""
|
"""
|
||||||
|
|
||||||
fields = [
|
required_fields = [
|
||||||
'packet_type',
|
'packet_type',
|
||||||
'rpc_id',
|
'rpc_id',
|
||||||
'node_id'
|
'node_id'
|
||||||
|
@ -38,13 +42,18 @@ class KademliaDatagramBase:
|
||||||
self.node_id = node_id
|
self.node_id = node_id
|
||||||
|
|
||||||
def bencode(self) -> bytes:
|
def bencode(self) -> bytes:
|
||||||
return bencode({
|
datagram = {
|
||||||
i: getattr(self, k) for i, k in enumerate(self.fields)
|
i: getattr(self, k) for i, k in enumerate(self.required_fields)
|
||||||
})
|
}
|
||||||
|
for i, k in enumerate(OPTIONAL_FIELDS):
|
||||||
|
v = getattr(self, k, None)
|
||||||
|
if v is not None:
|
||||||
|
datagram[i + OPTIONAL_ARG_OFFSET] = v
|
||||||
|
return bencode(datagram)
|
||||||
|
|
||||||
|
|
||||||
class RequestDatagram(KademliaDatagramBase):
|
class RequestDatagram(KademliaDatagramBase):
|
||||||
fields = [
|
required_fields = [
|
||||||
'packet_type',
|
'packet_type',
|
||||||
'rpc_id',
|
'rpc_id',
|
||||||
'node_id',
|
'node_id',
|
||||||
|
@ -104,7 +113,7 @@ class RequestDatagram(KademliaDatagramBase):
|
||||||
|
|
||||||
|
|
||||||
class ResponseDatagram(KademliaDatagramBase):
|
class ResponseDatagram(KademliaDatagramBase):
|
||||||
fields = [
|
required_fields = [
|
||||||
'packet_type',
|
'packet_type',
|
||||||
'rpc_id',
|
'rpc_id',
|
||||||
'node_id',
|
'node_id',
|
||||||
|
@ -119,7 +128,7 @@ class ResponseDatagram(KademliaDatagramBase):
|
||||||
|
|
||||||
|
|
||||||
class ErrorDatagram(KademliaDatagramBase):
|
class ErrorDatagram(KademliaDatagramBase):
|
||||||
fields = [
|
required_fields = [
|
||||||
'packet_type',
|
'packet_type',
|
||||||
'rpc_id',
|
'rpc_id',
|
||||||
'node_id',
|
'node_id',
|
||||||
|
@ -148,12 +157,15 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa
|
||||||
else:
|
else:
|
||||||
raise ValueError("invalid datagram type")
|
raise ValueError("invalid datagram type")
|
||||||
datagram_class = msg_types[datagram_type]
|
datagram_class = msg_types[datagram_type]
|
||||||
return datagram_class(**{
|
decoded = {
|
||||||
k: primitive[i] # pylint: disable=unsubscriptable-object
|
k: primitive[i] # pylint: disable=unsubscriptable-object
|
||||||
for i, k in enumerate(datagram_class.fields)
|
for i, k in enumerate(datagram_class.required_fields)
|
||||||
if i in primitive # pylint: disable=unsupported-membership-test
|
if i in primitive # pylint: disable=unsupported-membership-test
|
||||||
}
|
}
|
||||||
)
|
for i, k in enumerate(OPTIONAL_FIELDS):
|
||||||
|
if i + OPTIONAL_ARG_OFFSET in primitive:
|
||||||
|
decoded[i + OPTIONAL_ARG_OFFSET] = primitive[i + OPTIONAL_ARG_OFFSET]
|
||||||
|
return datagram_class(**decoded)
|
||||||
|
|
||||||
|
|
||||||
def make_compact_ip(address: str) -> bytearray:
|
def make_compact_ip(address: str) -> bytearray:
|
||||||
|
|
|
@ -127,6 +127,18 @@ class TestDatagram(unittest.TestCase):
|
||||||
self.assertRaises(ValueError, decode_datagram, serialized)
|
self.assertRaises(ValueError, decode_datagram, serialized)
|
||||||
self.assertRaises(DecodeError, decode_datagram, _bencode([1, 2, 3, 4]))
|
self.assertRaises(DecodeError, decode_datagram, _bencode([1, 2, 3, 4]))
|
||||||
|
|
||||||
|
def test_optional_field_backwards_compatible(self):
|
||||||
|
datagram = decode_datagram(_bencode({
|
||||||
|
0: 0,
|
||||||
|
1: b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc',
|
||||||
|
2: b'111111111111111111111111111111111111111111111111',
|
||||||
|
3: b'ping',
|
||||||
|
4: [{b'protocolVersion': 1}],
|
||||||
|
5: b'should not error'
|
||||||
|
}))
|
||||||
|
self.assertEqual(datagram.packet_type, REQUEST_TYPE)
|
||||||
|
self.assertEqual(b'ping', datagram.method)
|
||||||
|
|
||||||
|
|
||||||
class TestCompactAddress(unittest.TestCase):
|
class TestCompactAddress(unittest.TestCase):
|
||||||
def test_encode_decode(self, address='1.2.3.4', port=4444, node_id=b'1' * 48):
|
def test_encode_decode(self, address='1.2.3.4', port=4444, node_id=b'1' * 48):
|
||||||
|
|
Loading…
Reference in a new issue