fix parsing Netgear Nighthawk AC2350 xml, update tests

This commit is contained in:
Jack Robison 2018-10-29 13:02:19 -04:00
parent cfce30fa3a
commit 475a0c7738
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
7 changed files with 119 additions and 108 deletions

View file

@ -8,7 +8,6 @@ class CaseInsensitive:
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
for k, v in kwargs.items(): for k, v in kwargs.items():
if not k.startswith("_"): if not k.startswith("_"):
getattr(self, k)
setattr(self, k, v) setattr(self, k, v)
def __getattr__(self, item): def __getattr__(self, item):
@ -22,6 +21,9 @@ class CaseInsensitive:
if k.lower() == item.lower(): if k.lower() == item.lower():
self.__dict__[k] = value self.__dict__[k] = value
return return
if not item.startswith("_"):
self.__dict__[item] = value
return
raise AttributeError(item) raise AttributeError(item)
def as_dict(self) -> dict: def as_dict(self) -> dict:

View file

@ -221,8 +221,9 @@ class Gateway:
if not self.url_base: if not self.url_base:
self.url_base = self.base_address.decode() self.url_base = self.base_address.decode()
if response: if response:
device_dict = get_dict_val_case_insensitive(response, "device")
self._device = Device( self._device = Device(
self._devices, self._services, **get_dict_val_case_insensitive(response, "device") self._devices, self._services, **device_dict
) )
else: else:
self._device = Device(self._devices, self._services) self._device = Device(self._devices, self._services)

View file

@ -13,6 +13,10 @@ XML_ROOT_SANITY_PATTERN = re.compile(
"(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\}([\w|\d|\:|\-|\_]*))" "(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\}([\w|\d|\:|\-|\_]*))"
) )
XML_OTHER_KEYS = re.compile(
"{[\w|\:\/\.]*}|(\w*)"
)
def serialize_scpd_get(path: str, address: str) -> bytes: def serialize_scpd_get(path: str, address: str) -> bytes:
if "http://" in address: if "http://" in address:
@ -39,13 +43,31 @@ def deserialize_scpd_get_response(content: bytes) -> Dict:
parsed = CONTENT_PATTERN.findall(content) parsed = CONTENT_PATTERN.findall(content)
content = b'' if not parsed else parsed[0][0] content = b'' if not parsed else parsed[0][0]
xml_dict = etree_to_dict(ElementTree.fromstring(content.decode())) xml_dict = etree_to_dict(ElementTree.fromstring(content.decode()))
schema_key = DEVICE return parse_device_dict(xml_dict)
root = ROOT return {}
for k in xml_dict.keys():
def parse_device_dict(xml_dict: dict) -> Dict:
keys = list(xml_dict.keys())
for k in keys:
m = XML_ROOT_SANITY_PATTERN.findall(k) m = XML_ROOT_SANITY_PATTERN.findall(k)
if len(m) == 3 and m[1][0] and m[2][5]: if len(m) == 3 and m[1][0] and m[2][5]:
schema_key = m[1][0] schema_key = m[1][0]
root = m[2][5] root = m[2][5]
break xml_dict = flatten_keys(xml_dict, "{%s}" % schema_key)[root]
return flatten_keys(xml_dict, "{%s}" % schema_key)[root] result = {}
return {} for k, v in xml_dict.items():
if isinstance(xml_dict[k], dict):
inner_d = {}
for inner_k, inner_v in xml_dict[k].items():
parsed_k = XML_OTHER_KEYS.findall(inner_k)
if len(parsed_k) == 2:
inner_d[parsed_k[0]] = inner_v
else:
assert len(parsed_k) == 3
inner_d[parsed_k[1]] = inner_v
result[k] = inner_d
else:
result[k] = v
return result

View file

@ -5,91 +5,18 @@ import mock
@contextlib.contextmanager @contextlib.contextmanager
def mock_datagram_endpoint_factory(loop, expected_addr, replies=None, delay_reply=0.0, sent_packets=None): def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
sent_packets = sent_packets if sent_packets is not None else [] tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None):
replies = replies or {}
def sendto(p: asyncio.DatagramProtocol):
def _sendto(data, addr):
sent_packets.append(data)
if (data, addr) in replies:
loop.call_later(delay_reply, p.datagram_received, replies[(data, addr)], (expected_addr, 1900))
return _sendto
async def create_datagram_endpoint(proto_lam, sock=None):
protocol = proto_lam()
transport = asyncio.DatagramTransport(extra={'socket': mock_sock})
transport.close = lambda: mock_sock.close()
mock_sock.sendto = sendto(protocol)
transport.sendto = mock_sock.sendto
protocol.connection_made(transport)
return transport, protocol
with mock.patch('socket.socket') as mock_socket:
mock_sock = mock.Mock(spec=socket.socket)
mock_sock.setsockopt = lambda *_: None
mock_sock.bind = lambda *_: None
mock_sock.setblocking = lambda *_: None
mock_sock.getsockname = lambda: "0.0.0.0"
mock_sock.getpeername = lambda: ""
mock_sock.close = lambda: None
mock_sock.type = socket.SOCK_DGRAM
mock_sock.fileno = lambda: 7
mock_socket.return_value = mock_sock
loop.create_datagram_endpoint = create_datagram_endpoint
yield
@contextlib.contextmanager
def mock_tcp_endpoint_factory(loop, replies=None, delay_reply=0.0, sent_packets=None):
sent_packets = sent_packets if sent_packets is not None else []
replies = replies or {}
def write(p: asyncio.Protocol):
def _write(data):
sent_packets.append(data)
if data in replies:
loop.call_later(delay_reply, p.data_received, replies[data])
return _write
async def create_connection(protocol_factory, host=None, port=None):
protocol = protocol_factory()
transport = asyncio.Transport(extra={'socket': mock_sock})
transport.close = lambda: mock_sock.close()
mock_sock.write = write(protocol)
transport.write = mock_sock.write
protocol.connection_made(transport)
return transport, protocol
with mock.patch('socket.socket') as mock_socket:
mock_sock = mock.Mock(spec=socket.socket)
mock_sock.setsockopt = lambda *_: None
mock_sock.bind = lambda *_: None
mock_sock.setblocking = lambda *_: None
mock_sock.getsockname = lambda: "0.0.0.0"
mock_sock.getpeername = lambda: ""
mock_sock.close = lambda: None
mock_sock.type = socket.SOCK_STREAM
mock_sock.fileno = lambda: 7
mock_socket.return_value = mock_sock
loop.create_connection = create_connection
yield
@contextlib.contextmanager
def mock_tcp_and_udp(loop, udp_expected_addr, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
tcp_replies=None, tcp_delay_reply=0.0, tcp_sent_packets=None):
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else [] sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
udp_replies = udp_replies or {} udp_replies = udp_replies or {}
tcp_sent_packets = tcp_sent_packets if tcp_sent_packets is not None else [] sent_tcp_packets = sent_tcp_packets if sent_tcp_packets is not None else []
tcp_replies = tcp_replies or {} tcp_replies = tcp_replies or {}
async def create_connection(protocol_factory, host=None, port=None): async def create_connection(protocol_factory, host=None, port=None):
def write(p: asyncio.Protocol): def write(p: asyncio.Protocol):
def _write(data): def _write(data):
tcp_sent_packets.append(data) sent_tcp_packets.append(data)
if data in tcp_replies: if data in tcp_replies:
loop.call_later(tcp_delay_reply, p.data_received, tcp_replies[data]) loop.call_later(tcp_delay_reply, p.data_received, tcp_replies[data])

View file

@ -1,7 +1,7 @@
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.protocols.scpd import scpd_post, scpd_get from aioupnp.protocols.scpd import scpd_post, scpd_get
from tests import TestBase from tests import TestBase
from tests.mocks import mock_tcp_endpoint_factory from tests.mocks import mock_tcp_and_udp
class TestSCPDGet(TestBase): class TestSCPDGet(TestBase):
@ -107,7 +107,7 @@ class TestSCPDGet(TestBase):
async def test_scpd_get(self): async def test_scpd_get(self):
sent = [] sent = []
replies = {self.get_request: self.response} replies = {self.get_request: self.response}
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent): with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
result, raw, err = await scpd_get(self.path, self.lan_address, self.port, self.loop) result, raw, err = await scpd_get(self.path, self.lan_address, self.port, self.loop)
self.assertEqual(None, err) self.assertEqual(None, err)
self.assertDictEqual(self.expected_parsed, result) self.assertDictEqual(self.expected_parsed, result)
@ -115,7 +115,7 @@ class TestSCPDGet(TestBase):
async def test_scpd_get_timeout(self): async def test_scpd_get_timeout(self):
sent = [] sent = []
replies = {} replies = {}
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent): with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
result, raw, err = await scpd_get(self.path, self.lan_address, self.port, self.loop) result, raw, err = await scpd_get(self.path, self.lan_address, self.port, self.loop)
self.assertTrue(isinstance(err, UPnPError)) self.assertTrue(isinstance(err, UPnPError))
self.assertDictEqual({}, result) self.assertDictEqual({}, result)
@ -124,7 +124,7 @@ class TestSCPDGet(TestBase):
async def test_scpd_get_bad_xml(self): async def test_scpd_get_bad_xml(self):
sent = [] sent = []
replies = {self.get_request: self.bad_response} replies = {self.get_request: self.bad_response}
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent): with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
result, raw, err = await scpd_get(self.path, self.lan_address, self.port, self.loop) result, raw, err = await scpd_get(self.path, self.lan_address, self.port, self.loop)
self.assertDictEqual({}, result) self.assertDictEqual({}, result)
self.assertEqual(self.bad_xml, raw) self.assertEqual(self.bad_xml, raw)
@ -134,7 +134,7 @@ class TestSCPDGet(TestBase):
async def test_scpd_get_overrun_content_length(self): async def test_scpd_get_overrun_content_length(self):
sent = [] sent = []
replies = {self.get_request: self.bad_response + b'\r\n'} replies = {self.get_request: self.bad_response + b'\r\n'}
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent): with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
result, raw, err = await scpd_get(self.path, self.lan_address, self.port, self.loop) result, raw, err = await scpd_get(self.path, self.lan_address, self.port, self.loop)
self.assertDictEqual({}, result) self.assertDictEqual({}, result)
self.assertEqual(self.bad_response + b'\r\n', raw) self.assertEqual(self.bad_response + b'\r\n', raw)
@ -183,7 +183,7 @@ class TestSCPDPost(TestBase):
async def test_scpd_post(self): async def test_scpd_post(self):
sent = [] sent = []
replies = {self.post_bytes: self.post_response} replies = {self.post_bytes: self.post_response}
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent): with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
result, raw, err = await scpd_post( result, raw, err = await scpd_post(
self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop
) )
@ -194,7 +194,7 @@ class TestSCPDPost(TestBase):
async def test_scpd_post_timeout(self): async def test_scpd_post_timeout(self):
sent = [] sent = []
replies = {} replies = {}
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent): with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
result, raw, err = await scpd_post( result, raw, err = await scpd_post(
self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop
) )
@ -206,7 +206,7 @@ class TestSCPDPost(TestBase):
async def test_scpd_post_bad_xml_response(self): async def test_scpd_post_bad_xml_response(self):
sent = [] sent = []
replies = {self.post_bytes: self.bad_envelope_response} replies = {self.post_bytes: self.bad_envelope_response}
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent): with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
result, raw, err = await scpd_post( result, raw, err = await scpd_post(
self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop
) )
@ -218,7 +218,7 @@ class TestSCPDPost(TestBase):
async def test_scpd_post_overrun_response(self): async def test_scpd_post_overrun_response(self):
sent = [] sent = []
replies = {self.post_bytes: self.post_response + b'\r\n'} replies = {self.post_bytes: self.post_response + b'\r\n'}
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent): with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
result, raw, err = await scpd_post( result, raw, err = await scpd_post(
self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop
) )

View file

@ -5,7 +5,7 @@ from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.constants import SSDP_IP_ADDRESS from aioupnp.constants import SSDP_IP_ADDRESS
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search from aioupnp.protocols.ssdp import fuzzy_m_search, m_search
from tests import TestBase from tests import TestBase
from tests.mocks import mock_datagram_endpoint_factory from tests.mocks import mock_tcp_and_udp
class TestSSDP(TestBase): class TestSSDP(TestBase):
@ -35,14 +35,14 @@ class TestSSDP(TestBase):
} }
sent = [] sent = []
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies, sent_packets=sent): with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1", sent_udp_packets=sent):
reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=True) reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=True)
self.assertEqual(reply.encode(), self.reply_packet.encode()) self.assertEqual(reply.encode(), self.reply_packet.encode())
self.assertListEqual(sent, [self.query_packet.encode().encode()]) self.assertListEqual(sent, [self.query_packet.encode().encode()])
with self.assertRaises(UPnPError): with self.assertRaises(UPnPError):
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies): with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", udp_replies=replies):
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=False) await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=False)
async def test_m_search_reply_multicast(self): async def test_m_search_reply_multicast(self):
@ -51,21 +51,21 @@ class TestSSDP(TestBase):
} }
sent = [] sent = []
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies, sent_packets=sent): with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1", sent_udp_packets=sent):
reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop) reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop)
self.assertEqual(reply.encode(), self.reply_packet.encode()) self.assertEqual(reply.encode(), self.reply_packet.encode())
self.assertListEqual(sent, [self.query_packet.encode().encode()]) self.assertListEqual(sent, [self.query_packet.encode().encode()])
with self.assertRaises(UPnPError): with self.assertRaises(UPnPError):
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies): with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1"):
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=True) await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=True)
async def test_packets_sent_fuzzy_m_search(self): async def test_packets_sent_fuzzy_m_search(self):
sent = [] sent = []
with self.assertRaises(UPnPError): with self.assertRaises(UPnPError):
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", sent_packets=sent): with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", sent_udp_packets=sent):
await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop) await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop)
self.assertListEqual(sent, self.byte_packets) self.assertListEqual(sent, self.byte_packets)
@ -76,7 +76,7 @@ class TestSSDP(TestBase):
} }
sent = [] sent = []
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies, sent_packets=sent): with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", udp_replies=replies, sent_udp_packets=sent):
args, reply = await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop) args, reply = await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop)
self.assertEqual(reply.encode(), self.reply_packet.encode()) self.assertEqual(reply.encode(), self.reply_packet.encode())

File diff suppressed because one or more lines are too long