preserve capitalization of kwargs
This commit is contained in:
parent
b0d1f7a193
commit
ba8be4746a
2 changed files with 116 additions and 39 deletions
|
@ -1,6 +1,8 @@
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
import binascii
|
import binascii
|
||||||
|
import json
|
||||||
|
from collections import OrderedDict
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from aioupnp.fault import UPnPError
|
from aioupnp.fault import UPnPError
|
||||||
from aioupnp.constants import line_separator
|
from aioupnp.constants import line_separator
|
||||||
|
@ -74,30 +76,36 @@ class SSDPDatagram(object):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None,
|
def __init__(self, packet_type, kwargs: OrderedDict = None) -> None:
|
||||||
cache_control=None, server=None, date=None, ext=None, **kwargs) -> None:
|
|
||||||
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
|
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
|
||||||
raise UPnPError("unknown packet type: {}".format(packet_type))
|
raise UPnPError("unknown packet type: {}".format(packet_type))
|
||||||
self._packet_type = packet_type
|
self._packet_type = packet_type
|
||||||
self.host = host
|
kwargs = kwargs or OrderedDict()
|
||||||
self.st = st
|
self._field_order: list = [
|
||||||
self.man = man
|
k.lower().replace("-", "_") for k in kwargs.keys()
|
||||||
self.mx = mx
|
]
|
||||||
self.nt = nt
|
self.host = None
|
||||||
self.nts = nts
|
self.st = None
|
||||||
self.usn = usn
|
self.man = None
|
||||||
self.location = location
|
self.mx = None
|
||||||
self.cache_control = cache_control
|
self.nt = None
|
||||||
self.server = server
|
self.nts = None
|
||||||
self.date = date
|
self.usn = None
|
||||||
self.ext = ext
|
self.location = None
|
||||||
|
self.cache_control = None
|
||||||
|
self.server = None
|
||||||
|
self.date = None
|
||||||
|
self.ext = None
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None:
|
normalized = k.lower().replace("-", "_")
|
||||||
setattr(self, k.lower(), v)
|
if not normalized.startswith("_") and hasattr(self, normalized) and getattr(self,normalized) is None:
|
||||||
|
setattr(self, normalized, v)
|
||||||
|
self._case_mappings: dict = {k.lower(): k for k in kwargs.keys()}
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return ("SSDPDatagram(packet_type=%s, " % self._packet_type) + \
|
return self.as_json()
|
||||||
", ".join("%s=%s" % (n, v) for n, v in self.as_dict().items()) + ")"
|
# return ("SSDPDatagram(packet_type=%s, " % self._packet_type) + \
|
||||||
|
# ", ".join("%s=%s" % (n, v) for n, v in self.as_dict().items()) + ")"
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
for i in self._required_fields[self._packet_type]:
|
for i in self._required_fields[self._packet_type]:
|
||||||
|
@ -110,7 +118,9 @@ class SSDPDatagram(object):
|
||||||
|
|
||||||
def encode(self, trailing_newlines: int = 2) -> str:
|
def encode(self, trailing_newlines: int = 2) -> str:
|
||||||
lines = [self._start_lines[self._packet_type]]
|
lines = [self._start_lines[self._packet_type]]
|
||||||
for attr_name in self._required_fields[self._packet_type]:
|
for attr_name in self._field_order:
|
||||||
|
if attr_name not in self._required_fields[self._packet_type]:
|
||||||
|
continue
|
||||||
attr = getattr(self, attr_name)
|
attr = getattr(self, attr_name)
|
||||||
if attr is None:
|
if attr is None:
|
||||||
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
|
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
|
||||||
|
@ -118,15 +128,18 @@ class SSDPDatagram(object):
|
||||||
value = str(attr)
|
value = str(attr)
|
||||||
else:
|
else:
|
||||||
value = attr
|
value = attr
|
||||||
lines.append("{}: {}".format(attr_name.upper(), value))
|
lines.append("{}: {}".format(self._case_mappings.get(attr_name.lower(), attr_name.upper()), value))
|
||||||
serialized = line_separator.join(lines)
|
serialized = line_separator.join(lines)
|
||||||
for _ in range(trailing_newlines):
|
for _ in range(trailing_newlines):
|
||||||
serialized += line_separator
|
serialized += line_separator
|
||||||
return serialized
|
return serialized
|
||||||
|
|
||||||
def as_dict(self) -> Dict:
|
def as_dict(self) -> OrderedDict:
|
||||||
return self._lines_to_content_dict(self.encode().split(line_separator))
|
return self._lines_to_content_dict(self.encode().split(line_separator))
|
||||||
|
|
||||||
|
def as_json(self) -> str:
|
||||||
|
return json.dumps(self.as_dict(), indent=2)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode(cls, datagram: bytes):
|
def decode(cls, datagram: bytes):
|
||||||
packet = cls._from_string(datagram.decode())
|
packet = cls._from_string(datagram.decode())
|
||||||
|
@ -143,8 +156,8 @@ class SSDPDatagram(object):
|
||||||
return packet
|
return packet
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _lines_to_content_dict(cls, lines: list) -> Dict:
|
def _lines_to_content_dict(cls, lines: list) -> OrderedDict:
|
||||||
result: dict = {}
|
result: OrderedDict = OrderedDict()
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if not line:
|
if not line:
|
||||||
continue
|
continue
|
||||||
|
@ -152,9 +165,10 @@ class SSDPDatagram(object):
|
||||||
for name, (pattern, field_type) in cls._patterns.items():
|
for name, (pattern, field_type) in cls._patterns.items():
|
||||||
if name not in result and pattern.findall(line):
|
if name not in result and pattern.findall(line):
|
||||||
match = pattern.findall(line)[-1][-1]
|
match = pattern.findall(line)[-1][-1]
|
||||||
result[name] = field_type(match.lstrip(" ").rstrip(" "))
|
result[line[:len(name)]] = field_type(match.lstrip(" ").rstrip(" "))
|
||||||
matched = True
|
matched = True
|
||||||
break
|
break
|
||||||
|
|
||||||
if not matched:
|
if not matched:
|
||||||
if cls._vendor_field_pattern.findall(line):
|
if cls._vendor_field_pattern.findall(line):
|
||||||
match = cls._vendor_field_pattern.findall(line)[-1]
|
match = cls._vendor_field_pattern.findall(line)[-1]
|
||||||
|
@ -177,12 +191,12 @@ class SSDPDatagram(object):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_response(cls, lines: List):
|
def _from_response(cls, lines: List):
|
||||||
return cls(cls._OK, **cls._lines_to_content_dict(lines))
|
return cls(cls._OK, cls._lines_to_content_dict(lines))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_notify(cls, lines: List):
|
def _from_notify(cls, lines: List):
|
||||||
return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines))
|
return cls(cls._NOTIFY, cls._lines_to_content_dict(lines))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_request(cls, lines: List):
|
def _from_request(cls, lines: List):
|
||||||
return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines))
|
return cls(cls._M_SEARCH, cls._lines_to_content_dict(lines))
|
||||||
|
|
|
@ -1,9 +1,83 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
from collections import OrderedDict
|
||||||
from aioupnp.serialization.ssdp import SSDPDatagram
|
from aioupnp.serialization.ssdp import SSDPDatagram
|
||||||
from aioupnp.fault import UPnPError
|
from aioupnp.fault import UPnPError
|
||||||
from aioupnp.constants import UPNP_ORG_IGD, SSDP_DISCOVER
|
from aioupnp.constants import UPNP_ORG_IGD, SSDP_DISCOVER
|
||||||
|
|
||||||
|
|
||||||
|
class TestMSearchDatagramSerialization(unittest.TestCase):
|
||||||
|
packet = \
|
||||||
|
b'M-SEARCH * HTTP/1.1\r\n' \
|
||||||
|
b'Host: 239.255.255.250:1900\r\n' \
|
||||||
|
b'Man: "ssdp:discover"\r\n' \
|
||||||
|
b'ST: ssdp:all\r\n' \
|
||||||
|
b'MX: 5\r\n' \
|
||||||
|
b'\r\n'
|
||||||
|
|
||||||
|
datagram_args = OrderedDict([
|
||||||
|
('Host', "{}:{}".format('239.255.255.250', 1900)),
|
||||||
|
('Man', '"ssdp:discover"'),
|
||||||
|
('ST', 'ssdp:all'),
|
||||||
|
('MX', 5),
|
||||||
|
])
|
||||||
|
|
||||||
|
def test_deserialize_and_reserialize(self):
|
||||||
|
packet1 = SSDPDatagram.decode(self.packet)
|
||||||
|
packet2 = SSDPDatagram("M-SEARCH", self.datagram_args)
|
||||||
|
self.assertEqual(packet2.encode(), packet1.encode())
|
||||||
|
|
||||||
|
|
||||||
|
class TestSerializationOrder(TestMSearchDatagramSerialization):
|
||||||
|
packet = \
|
||||||
|
b'M-SEARCH * HTTP/1.1\r\n' \
|
||||||
|
b'Host: 239.255.255.250:1900\r\n' \
|
||||||
|
b'ST: ssdp:all\r\n' \
|
||||||
|
b'Man: "ssdp:discover"\r\n' \
|
||||||
|
b'MX: 5\r\n' \
|
||||||
|
b'\r\n'
|
||||||
|
|
||||||
|
datagram_args = OrderedDict([
|
||||||
|
('Host', "{}:{}".format('239.255.255.250', 1900)),
|
||||||
|
('ST', 'ssdp:all'),
|
||||||
|
('Man', '"ssdp:discover"'),
|
||||||
|
('MX', 5),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
class TestSerializationPreserveCase(TestMSearchDatagramSerialization):
|
||||||
|
packet = \
|
||||||
|
b'M-SEARCH * HTTP/1.1\r\n' \
|
||||||
|
b'HOST: 239.255.255.250:1900\r\n' \
|
||||||
|
b'ST: ssdp:all\r\n' \
|
||||||
|
b'Man: "ssdp:discover"\r\n' \
|
||||||
|
b'mx: 5\r\n' \
|
||||||
|
b'\r\n'
|
||||||
|
|
||||||
|
datagram_args = OrderedDict([
|
||||||
|
('HOST', "{}:{}".format('239.255.255.250', 1900)),
|
||||||
|
('ST', 'ssdp:all'),
|
||||||
|
('Man', '"ssdp:discover"'),
|
||||||
|
('mx', 5),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
class TestSerializationPreserveAllLowerCase(TestMSearchDatagramSerialization):
|
||||||
|
packet = \
|
||||||
|
b'M-SEARCH * HTTP/1.1\r\n' \
|
||||||
|
b'host: 239.255.255.250:1900\r\n' \
|
||||||
|
b'st: ssdp:all\r\n' \
|
||||||
|
b'man: "ssdp:discover"\r\n' \
|
||||||
|
b'mx: 5\r\n' \
|
||||||
|
b'\r\n'
|
||||||
|
|
||||||
|
datagram_args = OrderedDict([
|
||||||
|
('host', "{}:{}".format('239.255.255.250', 1900)),
|
||||||
|
('st', 'ssdp:all'),
|
||||||
|
('man', '"ssdp:discover"'),
|
||||||
|
('mx', 5),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
class TestParseMSearchRequestWithQuotes(unittest.TestCase):
|
class TestParseMSearchRequestWithQuotes(unittest.TestCase):
|
||||||
datagram = b"M-SEARCH * HTTP/1.1\r\n" \
|
datagram = b"M-SEARCH * HTTP/1.1\r\n" \
|
||||||
b"HOST: 239.255.255.250:1900\r\n" \
|
b"HOST: 239.255.255.250:1900\r\n" \
|
||||||
|
@ -20,17 +94,6 @@ class TestParseMSearchRequestWithQuotes(unittest.TestCase):
|
||||||
self.assertEqual(packet.man, '"ssdp:discover"')
|
self.assertEqual(packet.man, '"ssdp:discover"')
|
||||||
self.assertEqual(packet.mx, 1)
|
self.assertEqual(packet.mx, 1)
|
||||||
|
|
||||||
def test_serialize_m_search(self):
|
|
||||||
packet = SSDPDatagram.decode(self.datagram)
|
|
||||||
self.assertEqual(packet.encode().encode(), self.datagram)
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
self.datagram, SSDPDatagram(
|
|
||||||
SSDPDatagram._M_SEARCH, host="{}:{}".format('239.255.255.250', 1900), st=UPNP_ORG_IGD,
|
|
||||||
man='\"%s\"' % SSDP_DISCOVER, mx=1
|
|
||||||
).encode().encode()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestParseMSearchRequestWithoutQuotes(unittest.TestCase):
|
class TestParseMSearchRequestWithoutQuotes(unittest.TestCase):
|
||||||
datagram = b'M-SEARCH * HTTP/1.1\r\n' \
|
datagram = b'M-SEARCH * HTTP/1.1\r\n' \
|
||||||
|
@ -178,7 +241,7 @@ class TestParseNotify(unittest.TestCase):
|
||||||
packet = SSDPDatagram.decode(self.datagram)
|
packet = SSDPDatagram.decode(self.datagram)
|
||||||
self.assertTrue(packet._packet_type, packet._NOTIFY)
|
self.assertTrue(packet._packet_type, packet._NOTIFY)
|
||||||
self.assertEqual(packet.host, '239.255.255.250:1900')
|
self.assertEqual(packet.host, '239.255.255.250:1900')
|
||||||
self.assertEqual(packet.cache_control, 'max-age=180')
|
self.assertEqual(packet.cache_control, 'max-age=180') # this is an optional field
|
||||||
self.assertEqual(packet.location, 'http://192.168.1.1:5431/dyndev/uuid:000c-29ea-247500c00068')
|
self.assertEqual(packet.location, 'http://192.168.1.1:5431/dyndev/uuid:000c-29ea-247500c00068')
|
||||||
self.assertEqual(packet.nt, 'upnp:rootdevice')
|
self.assertEqual(packet.nt, 'upnp:rootdevice')
|
||||||
self.assertEqual(packet.nts, 'ssdp:alive')
|
self.assertEqual(packet.nts, 'ssdp:alive')
|
||||||
|
|
Loading…
Reference in a new issue