preserve capitalization of kwargs

This commit is contained in:
Jack Robison 2018-10-10 19:37:18 -04:00
parent b0d1f7a193
commit ba8be4746a
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
2 changed files with 116 additions and 39 deletions

View file

@ -1,6 +1,8 @@
import re import re
import logging import logging
import binascii import binascii
import json
from collections import OrderedDict
from typing import Dict, List from typing import Dict, List
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.constants import line_separator from aioupnp.constants import line_separator
@ -74,30 +76,36 @@ class SSDPDatagram(object):
] ]
} }
def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None, def __init__(self, packet_type, kwargs: OrderedDict = None) -> None:
cache_control=None, server=None, date=None, ext=None, **kwargs) -> None:
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]: if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
raise UPnPError("unknown packet type: {}".format(packet_type)) raise UPnPError("unknown packet type: {}".format(packet_type))
self._packet_type = packet_type self._packet_type = packet_type
self.host = host kwargs = kwargs or OrderedDict()
self.st = st self._field_order: list = [
self.man = man k.lower().replace("-", "_") for k in kwargs.keys()
self.mx = mx ]
self.nt = nt self.host = None
self.nts = nts self.st = None
self.usn = usn self.man = None
self.location = location self.mx = None
self.cache_control = cache_control self.nt = None
self.server = server self.nts = None
self.date = date self.usn = None
self.ext = ext self.location = None
self.cache_control = None
self.server = None
self.date = None
self.ext = None
for k, v in kwargs.items(): for k, v in kwargs.items():
if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None: normalized = k.lower().replace("-", "_")
setattr(self, k.lower(), v) if not normalized.startswith("_") and hasattr(self, normalized) and getattr(self,normalized) is None:
setattr(self, normalized, v)
self._case_mappings: dict = {k.lower(): k for k in kwargs.keys()}
def __repr__(self) -> str: def __repr__(self) -> str:
return ("SSDPDatagram(packet_type=%s, " % self._packet_type) + \ return self.as_json()
", ".join("%s=%s" % (n, v) for n, v in self.as_dict().items()) + ")" # return ("SSDPDatagram(packet_type=%s, " % self._packet_type) + \
# ", ".join("%s=%s" % (n, v) for n, v in self.as_dict().items()) + ")"
def __getitem__(self, item): def __getitem__(self, item):
for i in self._required_fields[self._packet_type]: for i in self._required_fields[self._packet_type]:
@ -110,7 +118,9 @@ class SSDPDatagram(object):
def encode(self, trailing_newlines: int = 2) -> str: def encode(self, trailing_newlines: int = 2) -> str:
lines = [self._start_lines[self._packet_type]] lines = [self._start_lines[self._packet_type]]
for attr_name in self._required_fields[self._packet_type]: for attr_name in self._field_order:
if attr_name not in self._required_fields[self._packet_type]:
continue
attr = getattr(self, attr_name) attr = getattr(self, attr_name)
if attr is None: if attr is None:
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name)) raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
@ -118,15 +128,18 @@ class SSDPDatagram(object):
value = str(attr) value = str(attr)
else: else:
value = attr value = attr
lines.append("{}: {}".format(attr_name.upper(), value)) lines.append("{}: {}".format(self._case_mappings.get(attr_name.lower(), attr_name.upper()), value))
serialized = line_separator.join(lines) serialized = line_separator.join(lines)
for _ in range(trailing_newlines): for _ in range(trailing_newlines):
serialized += line_separator serialized += line_separator
return serialized return serialized
def as_dict(self) -> Dict: def as_dict(self) -> OrderedDict:
return self._lines_to_content_dict(self.encode().split(line_separator)) return self._lines_to_content_dict(self.encode().split(line_separator))
def as_json(self) -> str:
return json.dumps(self.as_dict(), indent=2)
@classmethod @classmethod
def decode(cls, datagram: bytes): def decode(cls, datagram: bytes):
packet = cls._from_string(datagram.decode()) packet = cls._from_string(datagram.decode())
@ -143,8 +156,8 @@ class SSDPDatagram(object):
return packet return packet
@classmethod @classmethod
def _lines_to_content_dict(cls, lines: list) -> Dict: def _lines_to_content_dict(cls, lines: list) -> OrderedDict:
result: dict = {} result: OrderedDict = OrderedDict()
for line in lines: for line in lines:
if not line: if not line:
continue continue
@ -152,9 +165,10 @@ class SSDPDatagram(object):
for name, (pattern, field_type) in cls._patterns.items(): for name, (pattern, field_type) in cls._patterns.items():
if name not in result and pattern.findall(line): if name not in result and pattern.findall(line):
match = pattern.findall(line)[-1][-1] match = pattern.findall(line)[-1][-1]
result[name] = field_type(match.lstrip(" ").rstrip(" ")) result[line[:len(name)]] = field_type(match.lstrip(" ").rstrip(" "))
matched = True matched = True
break break
if not matched: if not matched:
if cls._vendor_field_pattern.findall(line): if cls._vendor_field_pattern.findall(line):
match = cls._vendor_field_pattern.findall(line)[-1] match = cls._vendor_field_pattern.findall(line)[-1]
@ -177,12 +191,12 @@ class SSDPDatagram(object):
@classmethod @classmethod
def _from_response(cls, lines: List): def _from_response(cls, lines: List):
return cls(cls._OK, **cls._lines_to_content_dict(lines)) return cls(cls._OK, cls._lines_to_content_dict(lines))
@classmethod @classmethod
def _from_notify(cls, lines: List): def _from_notify(cls, lines: List):
return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines)) return cls(cls._NOTIFY, cls._lines_to_content_dict(lines))
@classmethod @classmethod
def _from_request(cls, lines: List): def _from_request(cls, lines: List):
return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines)) return cls(cls._M_SEARCH, cls._lines_to_content_dict(lines))

View file

@ -1,9 +1,83 @@
import unittest import unittest
from collections import OrderedDict
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.constants import UPNP_ORG_IGD, SSDP_DISCOVER from aioupnp.constants import UPNP_ORG_IGD, SSDP_DISCOVER
class TestMSearchDatagramSerialization(unittest.TestCase):
packet = \
b'M-SEARCH * HTTP/1.1\r\n' \
b'Host: 239.255.255.250:1900\r\n' \
b'Man: "ssdp:discover"\r\n' \
b'ST: ssdp:all\r\n' \
b'MX: 5\r\n' \
b'\r\n'
datagram_args = OrderedDict([
('Host', "{}:{}".format('239.255.255.250', 1900)),
('Man', '"ssdp:discover"'),
('ST', 'ssdp:all'),
('MX', 5),
])
def test_deserialize_and_reserialize(self):
packet1 = SSDPDatagram.decode(self.packet)
packet2 = SSDPDatagram("M-SEARCH", self.datagram_args)
self.assertEqual(packet2.encode(), packet1.encode())
class TestSerializationOrder(TestMSearchDatagramSerialization):
packet = \
b'M-SEARCH * HTTP/1.1\r\n' \
b'Host: 239.255.255.250:1900\r\n' \
b'ST: ssdp:all\r\n' \
b'Man: "ssdp:discover"\r\n' \
b'MX: 5\r\n' \
b'\r\n'
datagram_args = OrderedDict([
('Host', "{}:{}".format('239.255.255.250', 1900)),
('ST', 'ssdp:all'),
('Man', '"ssdp:discover"'),
('MX', 5),
])
class TestSerializationPreserveCase(TestMSearchDatagramSerialization):
packet = \
b'M-SEARCH * HTTP/1.1\r\n' \
b'HOST: 239.255.255.250:1900\r\n' \
b'ST: ssdp:all\r\n' \
b'Man: "ssdp:discover"\r\n' \
b'mx: 5\r\n' \
b'\r\n'
datagram_args = OrderedDict([
('HOST', "{}:{}".format('239.255.255.250', 1900)),
('ST', 'ssdp:all'),
('Man', '"ssdp:discover"'),
('mx', 5),
])
class TestSerializationPreserveAllLowerCase(TestMSearchDatagramSerialization):
packet = \
b'M-SEARCH * HTTP/1.1\r\n' \
b'host: 239.255.255.250:1900\r\n' \
b'st: ssdp:all\r\n' \
b'man: "ssdp:discover"\r\n' \
b'mx: 5\r\n' \
b'\r\n'
datagram_args = OrderedDict([
('host', "{}:{}".format('239.255.255.250', 1900)),
('st', 'ssdp:all'),
('man', '"ssdp:discover"'),
('mx', 5),
])
class TestParseMSearchRequestWithQuotes(unittest.TestCase): class TestParseMSearchRequestWithQuotes(unittest.TestCase):
datagram = b"M-SEARCH * HTTP/1.1\r\n" \ datagram = b"M-SEARCH * HTTP/1.1\r\n" \
b"HOST: 239.255.255.250:1900\r\n" \ b"HOST: 239.255.255.250:1900\r\n" \
@ -20,17 +94,6 @@ class TestParseMSearchRequestWithQuotes(unittest.TestCase):
self.assertEqual(packet.man, '"ssdp:discover"') self.assertEqual(packet.man, '"ssdp:discover"')
self.assertEqual(packet.mx, 1) self.assertEqual(packet.mx, 1)
def test_serialize_m_search(self):
packet = SSDPDatagram.decode(self.datagram)
self.assertEqual(packet.encode().encode(), self.datagram)
self.assertEqual(
self.datagram, SSDPDatagram(
SSDPDatagram._M_SEARCH, host="{}:{}".format('239.255.255.250', 1900), st=UPNP_ORG_IGD,
man='\"%s\"' % SSDP_DISCOVER, mx=1
).encode().encode()
)
class TestParseMSearchRequestWithoutQuotes(unittest.TestCase): class TestParseMSearchRequestWithoutQuotes(unittest.TestCase):
datagram = b'M-SEARCH * HTTP/1.1\r\n' \ datagram = b'M-SEARCH * HTTP/1.1\r\n' \
@ -178,7 +241,7 @@ class TestParseNotify(unittest.TestCase):
packet = SSDPDatagram.decode(self.datagram) packet = SSDPDatagram.decode(self.datagram)
self.assertTrue(packet._packet_type, packet._NOTIFY) self.assertTrue(packet._packet_type, packet._NOTIFY)
self.assertEqual(packet.host, '239.255.255.250:1900') self.assertEqual(packet.host, '239.255.255.250:1900')
self.assertEqual(packet.cache_control, 'max-age=180') self.assertEqual(packet.cache_control, 'max-age=180') # this is an optional field
self.assertEqual(packet.location, 'http://192.168.1.1:5431/dyndev/uuid:000c-29ea-247500c00068') self.assertEqual(packet.location, 'http://192.168.1.1:5431/dyndev/uuid:000c-29ea-247500c00068')
self.assertEqual(packet.nt, 'upnp:rootdevice') self.assertEqual(packet.nt, 'upnp:rootdevice')
self.assertEqual(packet.nts, 'ssdp:alive') self.assertEqual(packet.nts, 'ssdp:alive')