forked from LBRYCommunity/lbry-sdk
add forward compatibility for byte datagram keys
This commit is contained in:
parent
d0f21c0095
commit
3a64ceb4d6
2 changed files with 51 additions and 7 deletions
|
@ -144,7 +144,7 @@ class ErrorDatagram(KademliaDatagramBase):
|
|||
self.response = response.decode()
|
||||
|
||||
|
||||
def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]:
|
||||
def _decode_datagram(datagram: bytes):
|
||||
msg_types = {
|
||||
REQUEST_TYPE: RequestDatagram,
|
||||
RESPONSE_TYPE: ResponseDatagram,
|
||||
|
@ -152,19 +152,29 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa
|
|||
}
|
||||
|
||||
primitive: typing.Dict = bdecode(datagram)
|
||||
if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
|
||||
datagram_type = primitive[0] # pylint: disable=unsubscriptable-object
|
||||
|
||||
converted = {
|
||||
str(k).encode() if not isinstance(k, bytes) else k: v for k, v in primitive.items()
|
||||
}
|
||||
|
||||
if converted[b'0'] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
|
||||
datagram_type = converted[b'0'] # pylint: disable=unsubscriptable-object
|
||||
else:
|
||||
raise ValueError("invalid datagram type")
|
||||
datagram_class = msg_types[datagram_type]
|
||||
decoded = {
|
||||
k: primitive[i] # pylint: disable=unsubscriptable-object
|
||||
k: converted[str(i).encode()] # pylint: disable=unsubscriptable-object
|
||||
for i, k in enumerate(datagram_class.required_fields)
|
||||
if i in primitive # pylint: disable=unsupported-membership-test
|
||||
if str(i).encode() in converted # pylint: disable=unsupported-membership-test
|
||||
}
|
||||
for i, _ in enumerate(OPTIONAL_FIELDS):
|
||||
if i + OPTIONAL_ARG_OFFSET in primitive:
|
||||
decoded[i + OPTIONAL_ARG_OFFSET] = primitive[i + OPTIONAL_ARG_OFFSET]
|
||||
if str(i + OPTIONAL_ARG_OFFSET).encode() in converted:
|
||||
decoded[i + OPTIONAL_ARG_OFFSET] = converted[str(i + OPTIONAL_ARG_OFFSET).encode()]
|
||||
return decoded, datagram_class
|
||||
|
||||
|
||||
def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]:
|
||||
decoded, datagram_class = _decode_datagram(datagram)
|
||||
return datagram_class(**decoded)
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import binascii
|
||||
import unittest
|
||||
from lbry.dht.error import DecodeError
|
||||
from lbry.dht.serialization.bencoding import _bencode
|
||||
from lbry.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram, ErrorDatagram
|
||||
from lbry.dht.serialization.datagram import _decode_datagram
|
||||
from lbry.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE, ERROR_TYPE
|
||||
from lbry.dht.serialization.datagram import make_compact_address, decode_compact_address
|
||||
|
||||
|
@ -139,6 +141,38 @@ class TestDatagram(unittest.TestCase):
|
|||
self.assertEqual(datagram.packet_type, REQUEST_TYPE)
|
||||
self.assertEqual(b'ping', datagram.method)
|
||||
|
||||
def test_str_or_int_keys(self):
|
||||
datagram = decode_datagram(_bencode({
|
||||
b'0': 0,
|
||||
b'1': b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc',
|
||||
b'2': b'111111111111111111111111111111111111111111111111',
|
||||
b'3': b'ping',
|
||||
b'4': [{b'protocolVersion': 1}],
|
||||
b'5': b'should not error'
|
||||
}))
|
||||
self.assertEqual(datagram.packet_type, REQUEST_TYPE)
|
||||
self.assertEqual(b'ping', datagram.method)
|
||||
|
||||
def test_mixed_str_or_int_keys(self):
|
||||
# datagram, _ = _bencode({
|
||||
# b'0': 0,
|
||||
# 1: b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc',
|
||||
# b'2': b'111111111111111111111111111111111111111111111111',
|
||||
# 3: b'ping',
|
||||
# b'4': [{b'protocolVersion': 1}],
|
||||
# b'5': b'should not error'
|
||||
# }))
|
||||
encoded = binascii.unhexlify(b"64313a3069306569316532303a0abcb5269d6cfc1e87a08e920bf39fe9df8e92fc313a3234383a313131313131313131313131313131313131313131313131313131313131313131313131313131313131313131313131693365343a70696e67313a346c6431353a70726f746f636f6c56657273696f6e6931656565313a3531363a73686f756c64206e6f74206572726f7265")
|
||||
self.assertDictEqual(
|
||||
{
|
||||
'packet_type': 0,
|
||||
'rpc_id': b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc',
|
||||
'node_id': b'111111111111111111111111111111111111111111111111',
|
||||
'method': b'ping',
|
||||
'args': [{b'protocolVersion': 1}]
|
||||
}, _decode_datagram(encoded)[0]
|
||||
)
|
||||
|
||||
|
||||
class TestCompactAddress(unittest.TestCase):
|
||||
def test_encode_decode(self, address='1.2.3.4', port=4444, node_id=b'1' * 48):
|
||||
|
|
Loading…
Reference in a new issue