preserve capitalization of kwargs
This commit is contained in:
parent
b0d1f7a193
commit
ba8be4746a
2 changed files with 116 additions and 39 deletions
|
@ -1,6 +1,8 @@
|
|||
import re
|
||||
import logging
|
||||
import binascii
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List
|
||||
from aioupnp.fault import UPnPError
|
||||
from aioupnp.constants import line_separator
|
||||
|
@ -74,30 +76,36 @@ class SSDPDatagram(object):
|
|||
]
|
||||
}
|
||||
|
||||
def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None,
|
||||
cache_control=None, server=None, date=None, ext=None, **kwargs) -> None:
|
||||
def __init__(self, packet_type, kwargs: OrderedDict = None) -> None:
|
||||
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
|
||||
raise UPnPError("unknown packet type: {}".format(packet_type))
|
||||
self._packet_type = packet_type
|
||||
self.host = host
|
||||
self.st = st
|
||||
self.man = man
|
||||
self.mx = mx
|
||||
self.nt = nt
|
||||
self.nts = nts
|
||||
self.usn = usn
|
||||
self.location = location
|
||||
self.cache_control = cache_control
|
||||
self.server = server
|
||||
self.date = date
|
||||
self.ext = ext
|
||||
kwargs = kwargs or OrderedDict()
|
||||
self._field_order: list = [
|
||||
k.lower().replace("-", "_") for k in kwargs.keys()
|
||||
]
|
||||
self.host = None
|
||||
self.st = None
|
||||
self.man = None
|
||||
self.mx = None
|
||||
self.nt = None
|
||||
self.nts = None
|
||||
self.usn = None
|
||||
self.location = None
|
||||
self.cache_control = None
|
||||
self.server = None
|
||||
self.date = None
|
||||
self.ext = None
|
||||
for k, v in kwargs.items():
|
||||
if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None:
|
||||
setattr(self, k.lower(), v)
|
||||
normalized = k.lower().replace("-", "_")
|
||||
if not normalized.startswith("_") and hasattr(self, normalized) and getattr(self,normalized) is None:
|
||||
setattr(self, normalized, v)
|
||||
self._case_mappings: dict = {k.lower(): k for k in kwargs.keys()}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return ("SSDPDatagram(packet_type=%s, " % self._packet_type) + \
|
||||
", ".join("%s=%s" % (n, v) for n, v in self.as_dict().items()) + ")"
|
||||
return self.as_json()
|
||||
# return ("SSDPDatagram(packet_type=%s, " % self._packet_type) + \
|
||||
# ", ".join("%s=%s" % (n, v) for n, v in self.as_dict().items()) + ")"
|
||||
|
||||
def __getitem__(self, item):
|
||||
for i in self._required_fields[self._packet_type]:
|
||||
|
@ -110,7 +118,9 @@ class SSDPDatagram(object):
|
|||
|
||||
def encode(self, trailing_newlines: int = 2) -> str:
|
||||
lines = [self._start_lines[self._packet_type]]
|
||||
for attr_name in self._required_fields[self._packet_type]:
|
||||
for attr_name in self._field_order:
|
||||
if attr_name not in self._required_fields[self._packet_type]:
|
||||
continue
|
||||
attr = getattr(self, attr_name)
|
||||
if attr is None:
|
||||
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
|
||||
|
@ -118,15 +128,18 @@ class SSDPDatagram(object):
|
|||
value = str(attr)
|
||||
else:
|
||||
value = attr
|
||||
lines.append("{}: {}".format(attr_name.upper(), value))
|
||||
lines.append("{}: {}".format(self._case_mappings.get(attr_name.lower(), attr_name.upper()), value))
|
||||
serialized = line_separator.join(lines)
|
||||
for _ in range(trailing_newlines):
|
||||
serialized += line_separator
|
||||
return serialized
|
||||
|
||||
def as_dict(self) -> Dict:
|
||||
def as_dict(self) -> OrderedDict:
|
||||
return self._lines_to_content_dict(self.encode().split(line_separator))
|
||||
|
||||
def as_json(self) -> str:
|
||||
return json.dumps(self.as_dict(), indent=2)
|
||||
|
||||
@classmethod
|
||||
def decode(cls, datagram: bytes):
|
||||
packet = cls._from_string(datagram.decode())
|
||||
|
@ -143,8 +156,8 @@ class SSDPDatagram(object):
|
|||
return packet
|
||||
|
||||
@classmethod
|
||||
def _lines_to_content_dict(cls, lines: list) -> Dict:
|
||||
result: dict = {}
|
||||
def _lines_to_content_dict(cls, lines: list) -> OrderedDict:
|
||||
result: OrderedDict = OrderedDict()
|
||||
for line in lines:
|
||||
if not line:
|
||||
continue
|
||||
|
@ -152,9 +165,10 @@ class SSDPDatagram(object):
|
|||
for name, (pattern, field_type) in cls._patterns.items():
|
||||
if name not in result and pattern.findall(line):
|
||||
match = pattern.findall(line)[-1][-1]
|
||||
result[name] = field_type(match.lstrip(" ").rstrip(" "))
|
||||
result[line[:len(name)]] = field_type(match.lstrip(" ").rstrip(" "))
|
||||
matched = True
|
||||
break
|
||||
|
||||
if not matched:
|
||||
if cls._vendor_field_pattern.findall(line):
|
||||
match = cls._vendor_field_pattern.findall(line)[-1]
|
||||
|
@ -177,12 +191,12 @@ class SSDPDatagram(object):
|
|||
|
||||
@classmethod
|
||||
def _from_response(cls, lines: List):
|
||||
return cls(cls._OK, **cls._lines_to_content_dict(lines))
|
||||
return cls(cls._OK, cls._lines_to_content_dict(lines))
|
||||
|
||||
@classmethod
|
||||
def _from_notify(cls, lines: List):
|
||||
return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines))
|
||||
return cls(cls._NOTIFY, cls._lines_to_content_dict(lines))
|
||||
|
||||
@classmethod
|
||||
def _from_request(cls, lines: List):
|
||||
return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines))
|
||||
return cls(cls._M_SEARCH, cls._lines_to_content_dict(lines))
|
||||
|
|
|
@ -1,9 +1,83 @@
|
|||
import unittest
|
||||
from collections import OrderedDict
|
||||
from aioupnp.serialization.ssdp import SSDPDatagram
|
||||
from aioupnp.fault import UPnPError
|
||||
from aioupnp.constants import UPNP_ORG_IGD, SSDP_DISCOVER
|
||||
|
||||
|
||||
class TestMSearchDatagramSerialization(unittest.TestCase):
|
||||
packet = \
|
||||
b'M-SEARCH * HTTP/1.1\r\n' \
|
||||
b'Host: 239.255.255.250:1900\r\n' \
|
||||
b'Man: "ssdp:discover"\r\n' \
|
||||
b'ST: ssdp:all\r\n' \
|
||||
b'MX: 5\r\n' \
|
||||
b'\r\n'
|
||||
|
||||
datagram_args = OrderedDict([
|
||||
('Host', "{}:{}".format('239.255.255.250', 1900)),
|
||||
('Man', '"ssdp:discover"'),
|
||||
('ST', 'ssdp:all'),
|
||||
('MX', 5),
|
||||
])
|
||||
|
||||
def test_deserialize_and_reserialize(self):
|
||||
packet1 = SSDPDatagram.decode(self.packet)
|
||||
packet2 = SSDPDatagram("M-SEARCH", self.datagram_args)
|
||||
self.assertEqual(packet2.encode(), packet1.encode())
|
||||
|
||||
|
||||
class TestSerializationOrder(TestMSearchDatagramSerialization):
|
||||
packet = \
|
||||
b'M-SEARCH * HTTP/1.1\r\n' \
|
||||
b'Host: 239.255.255.250:1900\r\n' \
|
||||
b'ST: ssdp:all\r\n' \
|
||||
b'Man: "ssdp:discover"\r\n' \
|
||||
b'MX: 5\r\n' \
|
||||
b'\r\n'
|
||||
|
||||
datagram_args = OrderedDict([
|
||||
('Host', "{}:{}".format('239.255.255.250', 1900)),
|
||||
('ST', 'ssdp:all'),
|
||||
('Man', '"ssdp:discover"'),
|
||||
('MX', 5),
|
||||
])
|
||||
|
||||
|
||||
class TestSerializationPreserveCase(TestMSearchDatagramSerialization):
|
||||
packet = \
|
||||
b'M-SEARCH * HTTP/1.1\r\n' \
|
||||
b'HOST: 239.255.255.250:1900\r\n' \
|
||||
b'ST: ssdp:all\r\n' \
|
||||
b'Man: "ssdp:discover"\r\n' \
|
||||
b'mx: 5\r\n' \
|
||||
b'\r\n'
|
||||
|
||||
datagram_args = OrderedDict([
|
||||
('HOST', "{}:{}".format('239.255.255.250', 1900)),
|
||||
('ST', 'ssdp:all'),
|
||||
('Man', '"ssdp:discover"'),
|
||||
('mx', 5),
|
||||
])
|
||||
|
||||
|
||||
class TestSerializationPreserveAllLowerCase(TestMSearchDatagramSerialization):
|
||||
packet = \
|
||||
b'M-SEARCH * HTTP/1.1\r\n' \
|
||||
b'host: 239.255.255.250:1900\r\n' \
|
||||
b'st: ssdp:all\r\n' \
|
||||
b'man: "ssdp:discover"\r\n' \
|
||||
b'mx: 5\r\n' \
|
||||
b'\r\n'
|
||||
|
||||
datagram_args = OrderedDict([
|
||||
('host', "{}:{}".format('239.255.255.250', 1900)),
|
||||
('st', 'ssdp:all'),
|
||||
('man', '"ssdp:discover"'),
|
||||
('mx', 5),
|
||||
])
|
||||
|
||||
|
||||
class TestParseMSearchRequestWithQuotes(unittest.TestCase):
|
||||
datagram = b"M-SEARCH * HTTP/1.1\r\n" \
|
||||
b"HOST: 239.255.255.250:1900\r\n" \
|
||||
|
@ -20,17 +94,6 @@ class TestParseMSearchRequestWithQuotes(unittest.TestCase):
|
|||
self.assertEqual(packet.man, '"ssdp:discover"')
|
||||
self.assertEqual(packet.mx, 1)
|
||||
|
||||
def test_serialize_m_search(self):
|
||||
packet = SSDPDatagram.decode(self.datagram)
|
||||
self.assertEqual(packet.encode().encode(), self.datagram)
|
||||
|
||||
self.assertEqual(
|
||||
self.datagram, SSDPDatagram(
|
||||
SSDPDatagram._M_SEARCH, host="{}:{}".format('239.255.255.250', 1900), st=UPNP_ORG_IGD,
|
||||
man='\"%s\"' % SSDP_DISCOVER, mx=1
|
||||
).encode().encode()
|
||||
)
|
||||
|
||||
|
||||
class TestParseMSearchRequestWithoutQuotes(unittest.TestCase):
|
||||
datagram = b'M-SEARCH * HTTP/1.1\r\n' \
|
||||
|
@ -178,7 +241,7 @@ class TestParseNotify(unittest.TestCase):
|
|||
packet = SSDPDatagram.decode(self.datagram)
|
||||
self.assertTrue(packet._packet_type, packet._NOTIFY)
|
||||
self.assertEqual(packet.host, '239.255.255.250:1900')
|
||||
self.assertEqual(packet.cache_control, 'max-age=180')
|
||||
self.assertEqual(packet.cache_control, 'max-age=180') # this is an optional field
|
||||
self.assertEqual(packet.location, 'http://192.168.1.1:5431/dyndev/uuid:000c-29ea-247500c00068')
|
||||
self.assertEqual(packet.nt, 'upnp:rootdevice')
|
||||
self.assertEqual(packet.nts, 'ssdp:alive')
|
||||
|
|
Loading…
Reference in a new issue