preserve capitalization of kwargs

This commit is contained in:
Jack Robison 2018-10-10 19:37:18 -04:00
parent b0d1f7a193
commit ba8be4746a
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
2 changed files with 116 additions and 39 deletions

View file

@ -1,6 +1,8 @@
import re
import logging
import binascii
import json
from collections import OrderedDict
from typing import Dict, List
from aioupnp.fault import UPnPError
from aioupnp.constants import line_separator
@ -74,30 +76,36 @@ class SSDPDatagram(object):
]
}
def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None,
cache_control=None, server=None, date=None, ext=None, **kwargs) -> None:
def __init__(self, packet_type, kwargs: OrderedDict = None) -> None:
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
raise UPnPError("unknown packet type: {}".format(packet_type))
self._packet_type = packet_type
self.host = host
self.st = st
self.man = man
self.mx = mx
self.nt = nt
self.nts = nts
self.usn = usn
self.location = location
self.cache_control = cache_control
self.server = server
self.date = date
self.ext = ext
kwargs = kwargs or OrderedDict()
self._field_order: list = [
k.lower().replace("-", "_") for k in kwargs.keys()
]
self.host = None
self.st = None
self.man = None
self.mx = None
self.nt = None
self.nts = None
self.usn = None
self.location = None
self.cache_control = None
self.server = None
self.date = None
self.ext = None
for k, v in kwargs.items():
if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None:
setattr(self, k.lower(), v)
normalized = k.lower().replace("-", "_")
if not normalized.startswith("_") and hasattr(self, normalized) and getattr(self,normalized) is None:
setattr(self, normalized, v)
self._case_mappings: dict = {k.lower(): k for k in kwargs.keys()}
def __repr__(self) -> str:
return ("SSDPDatagram(packet_type=%s, " % self._packet_type) + \
", ".join("%s=%s" % (n, v) for n, v in self.as_dict().items()) + ")"
return self.as_json()
# return ("SSDPDatagram(packet_type=%s, " % self._packet_type) + \
# ", ".join("%s=%s" % (n, v) for n, v in self.as_dict().items()) + ")"
def __getitem__(self, item):
for i in self._required_fields[self._packet_type]:
@ -110,7 +118,9 @@ class SSDPDatagram(object):
def encode(self, trailing_newlines: int = 2) -> str:
lines = [self._start_lines[self._packet_type]]
for attr_name in self._required_fields[self._packet_type]:
for attr_name in self._field_order:
if attr_name not in self._required_fields[self._packet_type]:
continue
attr = getattr(self, attr_name)
if attr is None:
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
@ -118,15 +128,18 @@ class SSDPDatagram(object):
value = str(attr)
else:
value = attr
lines.append("{}: {}".format(attr_name.upper(), value))
lines.append("{}: {}".format(self._case_mappings.get(attr_name.lower(), attr_name.upper()), value))
serialized = line_separator.join(lines)
for _ in range(trailing_newlines):
serialized += line_separator
return serialized
def as_dict(self) -> Dict:
def as_dict(self) -> OrderedDict:
return self._lines_to_content_dict(self.encode().split(line_separator))
def as_json(self) -> str:
return json.dumps(self.as_dict(), indent=2)
@classmethod
def decode(cls, datagram: bytes):
packet = cls._from_string(datagram.decode())
@ -143,8 +156,8 @@ class SSDPDatagram(object):
return packet
@classmethod
def _lines_to_content_dict(cls, lines: list) -> Dict:
result: dict = {}
def _lines_to_content_dict(cls, lines: list) -> OrderedDict:
result: OrderedDict = OrderedDict()
for line in lines:
if not line:
continue
@ -152,9 +165,10 @@ class SSDPDatagram(object):
for name, (pattern, field_type) in cls._patterns.items():
if name not in result and pattern.findall(line):
match = pattern.findall(line)[-1][-1]
result[name] = field_type(match.lstrip(" ").rstrip(" "))
result[line[:len(name)]] = field_type(match.lstrip(" ").rstrip(" "))
matched = True
break
if not matched:
if cls._vendor_field_pattern.findall(line):
match = cls._vendor_field_pattern.findall(line)[-1]
@ -177,12 +191,12 @@ class SSDPDatagram(object):
@classmethod
def _from_response(cls, lines: List):
return cls(cls._OK, **cls._lines_to_content_dict(lines))
return cls(cls._OK, cls._lines_to_content_dict(lines))
@classmethod
def _from_notify(cls, lines: List):
return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines))
return cls(cls._NOTIFY, cls._lines_to_content_dict(lines))
@classmethod
def _from_request(cls, lines: List):
return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines))
return cls(cls._M_SEARCH, cls._lines_to_content_dict(lines))

View file

@ -1,9 +1,83 @@
import unittest
from collections import OrderedDict
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.fault import UPnPError
from aioupnp.constants import UPNP_ORG_IGD, SSDP_DISCOVER
class TestMSearchDatagramSerialization(unittest.TestCase):
packet = \
b'M-SEARCH * HTTP/1.1\r\n' \
b'Host: 239.255.255.250:1900\r\n' \
b'Man: "ssdp:discover"\r\n' \
b'ST: ssdp:all\r\n' \
b'MX: 5\r\n' \
b'\r\n'
datagram_args = OrderedDict([
('Host', "{}:{}".format('239.255.255.250', 1900)),
('Man', '"ssdp:discover"'),
('ST', 'ssdp:all'),
('MX', 5),
])
def test_deserialize_and_reserialize(self):
packet1 = SSDPDatagram.decode(self.packet)
packet2 = SSDPDatagram("M-SEARCH", self.datagram_args)
self.assertEqual(packet2.encode(), packet1.encode())
class TestSerializationOrder(TestMSearchDatagramSerialization):
packet = \
b'M-SEARCH * HTTP/1.1\r\n' \
b'Host: 239.255.255.250:1900\r\n' \
b'ST: ssdp:all\r\n' \
b'Man: "ssdp:discover"\r\n' \
b'MX: 5\r\n' \
b'\r\n'
datagram_args = OrderedDict([
('Host', "{}:{}".format('239.255.255.250', 1900)),
('ST', 'ssdp:all'),
('Man', '"ssdp:discover"'),
('MX', 5),
])
class TestSerializationPreserveCase(TestMSearchDatagramSerialization):
packet = \
b'M-SEARCH * HTTP/1.1\r\n' \
b'HOST: 239.255.255.250:1900\r\n' \
b'ST: ssdp:all\r\n' \
b'Man: "ssdp:discover"\r\n' \
b'mx: 5\r\n' \
b'\r\n'
datagram_args = OrderedDict([
('HOST', "{}:{}".format('239.255.255.250', 1900)),
('ST', 'ssdp:all'),
('Man', '"ssdp:discover"'),
('mx', 5),
])
class TestSerializationPreserveAllLowerCase(TestMSearchDatagramSerialization):
packet = \
b'M-SEARCH * HTTP/1.1\r\n' \
b'host: 239.255.255.250:1900\r\n' \
b'st: ssdp:all\r\n' \
b'man: "ssdp:discover"\r\n' \
b'mx: 5\r\n' \
b'\r\n'
datagram_args = OrderedDict([
('host', "{}:{}".format('239.255.255.250', 1900)),
('st', 'ssdp:all'),
('man', '"ssdp:discover"'),
('mx', 5),
])
class TestParseMSearchRequestWithQuotes(unittest.TestCase):
datagram = b"M-SEARCH * HTTP/1.1\r\n" \
b"HOST: 239.255.255.250:1900\r\n" \
@ -20,17 +94,6 @@ class TestParseMSearchRequestWithQuotes(unittest.TestCase):
self.assertEqual(packet.man, '"ssdp:discover"')
self.assertEqual(packet.mx, 1)
def test_serialize_m_search(self):
packet = SSDPDatagram.decode(self.datagram)
self.assertEqual(packet.encode().encode(), self.datagram)
self.assertEqual(
self.datagram, SSDPDatagram(
SSDPDatagram._M_SEARCH, host="{}:{}".format('239.255.255.250', 1900), st=UPNP_ORG_IGD,
man='\"%s\"' % SSDP_DISCOVER, mx=1
).encode().encode()
)
class TestParseMSearchRequestWithoutQuotes(unittest.TestCase):
datagram = b'M-SEARCH * HTTP/1.1\r\n' \
@ -178,7 +241,7 @@ class TestParseNotify(unittest.TestCase):
packet = SSDPDatagram.decode(self.datagram)
self.assertTrue(packet._packet_type, packet._NOTIFY)
self.assertEqual(packet.host, '239.255.255.250:1900')
self.assertEqual(packet.cache_control, 'max-age=180')
self.assertEqual(packet.cache_control, 'max-age=180') # this is an optional field
self.assertEqual(packet.location, 'http://192.168.1.1:5431/dyndev/uuid:000c-29ea-247500c00068')
self.assertEqual(packet.nt, 'upnp:rootdevice')
self.assertEqual(packet.nts, 'ssdp:alive')