forked from LBRYCommunity/lbry-sdk
add forward compatibility for byte datagram keys
This commit is contained in:
parent
d0f21c0095
commit
3a64ceb4d6
2 changed files with 51 additions and 7 deletions
|
@ -144,7 +144,7 @@ class ErrorDatagram(KademliaDatagramBase):
|
||||||
self.response = response.decode()
|
self.response = response.decode()
|
||||||
|
|
||||||
|
|
||||||
def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]:
|
def _decode_datagram(datagram: bytes):
|
||||||
msg_types = {
|
msg_types = {
|
||||||
REQUEST_TYPE: RequestDatagram,
|
REQUEST_TYPE: RequestDatagram,
|
||||||
RESPONSE_TYPE: ResponseDatagram,
|
RESPONSE_TYPE: ResponseDatagram,
|
||||||
|
@ -152,19 +152,29 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa
|
||||||
}
|
}
|
||||||
|
|
||||||
primitive: typing.Dict = bdecode(datagram)
|
primitive: typing.Dict = bdecode(datagram)
|
||||||
if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
|
|
||||||
datagram_type = primitive[0] # pylint: disable=unsubscriptable-object
|
converted = {
|
||||||
|
str(k).encode() if not isinstance(k, bytes) else k: v for k, v in primitive.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
if converted[b'0'] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
|
||||||
|
datagram_type = converted[b'0'] # pylint: disable=unsubscriptable-object
|
||||||
else:
|
else:
|
||||||
raise ValueError("invalid datagram type")
|
raise ValueError("invalid datagram type")
|
||||||
datagram_class = msg_types[datagram_type]
|
datagram_class = msg_types[datagram_type]
|
||||||
decoded = {
|
decoded = {
|
||||||
k: primitive[i] # pylint: disable=unsubscriptable-object
|
k: converted[str(i).encode()] # pylint: disable=unsubscriptable-object
|
||||||
for i, k in enumerate(datagram_class.required_fields)
|
for i, k in enumerate(datagram_class.required_fields)
|
||||||
if i in primitive # pylint: disable=unsupported-membership-test
|
if str(i).encode() in converted # pylint: disable=unsupported-membership-test
|
||||||
}
|
}
|
||||||
for i, _ in enumerate(OPTIONAL_FIELDS):
|
for i, _ in enumerate(OPTIONAL_FIELDS):
|
||||||
if i + OPTIONAL_ARG_OFFSET in primitive:
|
if str(i + OPTIONAL_ARG_OFFSET).encode() in converted:
|
||||||
decoded[i + OPTIONAL_ARG_OFFSET] = primitive[i + OPTIONAL_ARG_OFFSET]
|
decoded[i + OPTIONAL_ARG_OFFSET] = converted[str(i + OPTIONAL_ARG_OFFSET).encode()]
|
||||||
|
return decoded, datagram_class
|
||||||
|
|
||||||
|
|
||||||
|
def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]:
|
||||||
|
decoded, datagram_class = _decode_datagram(datagram)
|
||||||
return datagram_class(**decoded)
|
return datagram_class(**decoded)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
|
import binascii
|
||||||
import unittest
|
import unittest
|
||||||
from lbry.dht.error import DecodeError
|
from lbry.dht.error import DecodeError
|
||||||
from lbry.dht.serialization.bencoding import _bencode
|
from lbry.dht.serialization.bencoding import _bencode
|
||||||
from lbry.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram, ErrorDatagram
|
from lbry.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram, ErrorDatagram
|
||||||
|
from lbry.dht.serialization.datagram import _decode_datagram
|
||||||
from lbry.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE, ERROR_TYPE
|
from lbry.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE, ERROR_TYPE
|
||||||
from lbry.dht.serialization.datagram import make_compact_address, decode_compact_address
|
from lbry.dht.serialization.datagram import make_compact_address, decode_compact_address
|
||||||
|
|
||||||
|
@ -139,6 +141,38 @@ class TestDatagram(unittest.TestCase):
|
||||||
self.assertEqual(datagram.packet_type, REQUEST_TYPE)
|
self.assertEqual(datagram.packet_type, REQUEST_TYPE)
|
||||||
self.assertEqual(b'ping', datagram.method)
|
self.assertEqual(b'ping', datagram.method)
|
||||||
|
|
||||||
|
def test_str_or_int_keys(self):
|
||||||
|
datagram = decode_datagram(_bencode({
|
||||||
|
b'0': 0,
|
||||||
|
b'1': b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc',
|
||||||
|
b'2': b'111111111111111111111111111111111111111111111111',
|
||||||
|
b'3': b'ping',
|
||||||
|
b'4': [{b'protocolVersion': 1}],
|
||||||
|
b'5': b'should not error'
|
||||||
|
}))
|
||||||
|
self.assertEqual(datagram.packet_type, REQUEST_TYPE)
|
||||||
|
self.assertEqual(b'ping', datagram.method)
|
||||||
|
|
||||||
|
def test_mixed_str_or_int_keys(self):
|
||||||
|
# datagram, _ = _bencode({
|
||||||
|
# b'0': 0,
|
||||||
|
# 1: b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc',
|
||||||
|
# b'2': b'111111111111111111111111111111111111111111111111',
|
||||||
|
# 3: b'ping',
|
||||||
|
# b'4': [{b'protocolVersion': 1}],
|
||||||
|
# b'5': b'should not error'
|
||||||
|
# }))
|
||||||
|
encoded = binascii.unhexlify(b"64313a3069306569316532303a0abcb5269d6cfc1e87a08e920bf39fe9df8e92fc313a3234383a313131313131313131313131313131313131313131313131313131313131313131313131313131313131313131313131693365343a70696e67313a346c6431353a70726f746f636f6c56657273696f6e6931656565313a3531363a73686f756c64206e6f74206572726f7265")
|
||||||
|
self.assertDictEqual(
|
||||||
|
{
|
||||||
|
'packet_type': 0,
|
||||||
|
'rpc_id': b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc',
|
||||||
|
'node_id': b'111111111111111111111111111111111111111111111111',
|
||||||
|
'method': b'ping',
|
||||||
|
'args': [{b'protocolVersion': 1}]
|
||||||
|
}, _decode_datagram(encoded)[0]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestCompactAddress(unittest.TestCase):
|
class TestCompactAddress(unittest.TestCase):
|
||||||
def test_encode_decode(self, address='1.2.3.4', port=4444, node_id=b'1' * 48):
|
def test_encode_decode(self, address='1.2.3.4', port=4444, node_id=b'1' * 48):
|
||||||
|
|
Loading…
Reference in a new issue