Merge pull request #3045 from lbryio/bencode-byte-keys

add forward compatibility for byte datagram keys
This commit is contained in:
Lex Berezhny 2020-09-29 07:23:49 -04:00 committed by GitHub
commit 318cc15323
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 7 deletions

View file

@ -144,7 +144,7 @@ class ErrorDatagram(KademliaDatagramBase):
self.response = response.decode()
def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]:
def _decode_datagram(datagram: bytes):
msg_types = {
REQUEST_TYPE: RequestDatagram,
RESPONSE_TYPE: ResponseDatagram,
@ -152,19 +152,29 @@ def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDa
}
primitive: typing.Dict = bdecode(datagram)
if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
datagram_type = primitive[0] # pylint: disable=unsubscriptable-object
converted = {
str(k).encode() if not isinstance(k, bytes) else k: v for k, v in primitive.items()
}
if converted[b'0'] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
datagram_type = converted[b'0'] # pylint: disable=unsubscriptable-object
else:
raise ValueError("invalid datagram type")
datagram_class = msg_types[datagram_type]
decoded = {
k: primitive[i] # pylint: disable=unsubscriptable-object
k: converted[str(i).encode()] # pylint: disable=unsubscriptable-object
for i, k in enumerate(datagram_class.required_fields)
if i in primitive # pylint: disable=unsupported-membership-test
if str(i).encode() in converted # pylint: disable=unsupported-membership-test
}
for i, _ in enumerate(OPTIONAL_FIELDS):
if i + OPTIONAL_ARG_OFFSET in primitive:
decoded[i + OPTIONAL_ARG_OFFSET] = primitive[i + OPTIONAL_ARG_OFFSET]
if str(i + OPTIONAL_ARG_OFFSET).encode() in converted:
decoded[i + OPTIONAL_ARG_OFFSET] = converted[str(i + OPTIONAL_ARG_OFFSET).encode()]
return decoded, datagram_class
def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]:
decoded, datagram_class = _decode_datagram(datagram)
return datagram_class(**decoded)

View file

@ -1,7 +1,9 @@
import binascii
import unittest
from lbry.dht.error import DecodeError
from lbry.dht.serialization.bencoding import _bencode
from lbry.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram, ErrorDatagram
from lbry.dht.serialization.datagram import _decode_datagram
from lbry.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE, ERROR_TYPE
from lbry.dht.serialization.datagram import make_compact_address, decode_compact_address
@ -139,6 +141,38 @@ class TestDatagram(unittest.TestCase):
self.assertEqual(datagram.packet_type, REQUEST_TYPE)
self.assertEqual(b'ping', datagram.method)
def test_str_or_int_keys(self):
datagram = decode_datagram(_bencode({
b'0': 0,
b'1': b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc',
b'2': b'111111111111111111111111111111111111111111111111',
b'3': b'ping',
b'4': [{b'protocolVersion': 1}],
b'5': b'should not error'
}))
self.assertEqual(datagram.packet_type, REQUEST_TYPE)
self.assertEqual(b'ping', datagram.method)
def test_mixed_str_or_int_keys(self):
# datagram, _ = _bencode({
# b'0': 0,
# 1: b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc',
# b'2': b'111111111111111111111111111111111111111111111111',
# 3: b'ping',
# b'4': [{b'protocolVersion': 1}],
# b'5': b'should not error'
# }))
encoded = binascii.unhexlify(b"64313a3069306569316532303a0abcb5269d6cfc1e87a08e920bf39fe9df8e92fc313a3234383a313131313131313131313131313131313131313131313131313131313131313131313131313131313131313131313131693365343a70696e67313a346c6431353a70726f746f636f6c56657273696f6e6931656565313a3531363a73686f756c64206e6f74206572726f7265")
self.assertDictEqual(
{
'packet_type': 0,
'rpc_id': b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc',
'node_id': b'111111111111111111111111111111111111111111111111',
'method': b'ping',
'args': [{b'protocolVersion': 1}]
}, _decode_datagram(encoded)[0]
)
class TestCompactAddress(unittest.TestCase):
def test_encode_decode(self, address='1.2.3.4', port=4444, node_id=b'1' * 48):