fix parsing Netgear Nighthawk AC2350 xml, update tests
This commit is contained in:
parent
cfce30fa3a
commit
475a0c7738
7 changed files with 119 additions and 108 deletions
|
@ -8,7 +8,6 @@ class CaseInsensitive:
|
|||
def __init__(self, **kwargs) -> None:
|
||||
for k, v in kwargs.items():
|
||||
if not k.startswith("_"):
|
||||
getattr(self, k)
|
||||
setattr(self, k, v)
|
||||
|
||||
def __getattr__(self, item):
|
||||
|
@ -22,6 +21,9 @@ class CaseInsensitive:
|
|||
if k.lower() == item.lower():
|
||||
self.__dict__[k] = value
|
||||
return
|
||||
if not item.startswith("_"):
|
||||
self.__dict__[item] = value
|
||||
return
|
||||
raise AttributeError(item)
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
|
|
|
@ -221,8 +221,9 @@ class Gateway:
|
|||
if not self.url_base:
|
||||
self.url_base = self.base_address.decode()
|
||||
if response:
|
||||
device_dict = get_dict_val_case_insensitive(response, "device")
|
||||
self._device = Device(
|
||||
self._devices, self._services, **get_dict_val_case_insensitive(response, "device")
|
||||
self._devices, self._services, **device_dict
|
||||
)
|
||||
else:
|
||||
self._device = Device(self._devices, self._services)
|
||||
|
|
|
@ -13,6 +13,10 @@ XML_ROOT_SANITY_PATTERN = re.compile(
|
|||
"(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\}([\w|\d|\:|\-|\_]*))"
|
||||
)
|
||||
|
||||
XML_OTHER_KEYS = re.compile(
|
||||
"{[\w|\:\/\.]*}|(\w*)"
|
||||
)
|
||||
|
||||
|
||||
def serialize_scpd_get(path: str, address: str) -> bytes:
|
||||
if "http://" in address:
|
||||
|
@ -39,13 +43,31 @@ def deserialize_scpd_get_response(content: bytes) -> Dict:
|
|||
parsed = CONTENT_PATTERN.findall(content)
|
||||
content = b'' if not parsed else parsed[0][0]
|
||||
xml_dict = etree_to_dict(ElementTree.fromstring(content.decode()))
|
||||
schema_key = DEVICE
|
||||
root = ROOT
|
||||
for k in xml_dict.keys():
|
||||
m = XML_ROOT_SANITY_PATTERN.findall(k)
|
||||
if len(m) == 3 and m[1][0] and m[2][5]:
|
||||
schema_key = m[1][0]
|
||||
root = m[2][5]
|
||||
break
|
||||
return flatten_keys(xml_dict, "{%s}" % schema_key)[root]
|
||||
return parse_device_dict(xml_dict)
|
||||
return {}
|
||||
|
||||
|
||||
def parse_device_dict(xml_dict: dict) -> Dict:
|
||||
keys = list(xml_dict.keys())
|
||||
for k in keys:
|
||||
m = XML_ROOT_SANITY_PATTERN.findall(k)
|
||||
if len(m) == 3 and m[1][0] and m[2][5]:
|
||||
schema_key = m[1][0]
|
||||
root = m[2][5]
|
||||
xml_dict = flatten_keys(xml_dict, "{%s}" % schema_key)[root]
|
||||
result = {}
|
||||
for k, v in xml_dict.items():
|
||||
if isinstance(xml_dict[k], dict):
|
||||
inner_d = {}
|
||||
for inner_k, inner_v in xml_dict[k].items():
|
||||
parsed_k = XML_OTHER_KEYS.findall(inner_k)
|
||||
if len(parsed_k) == 2:
|
||||
inner_d[parsed_k[0]] = inner_v
|
||||
else:
|
||||
assert len(parsed_k) == 3
|
||||
inner_d[parsed_k[1]] = inner_v
|
||||
result[k] = inner_d
|
||||
else:
|
||||
result[k] = v
|
||||
|
||||
return result
|
||||
|
|
|
@ -5,91 +5,18 @@ import mock
|
|||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_datagram_endpoint_factory(loop, expected_addr, replies=None, delay_reply=0.0, sent_packets=None):
|
||||
sent_packets = sent_packets if sent_packets is not None else []
|
||||
replies = replies or {}
|
||||
|
||||
def sendto(p: asyncio.DatagramProtocol):
|
||||
def _sendto(data, addr):
|
||||
sent_packets.append(data)
|
||||
if (data, addr) in replies:
|
||||
loop.call_later(delay_reply, p.datagram_received, replies[(data, addr)], (expected_addr, 1900))
|
||||
return _sendto
|
||||
|
||||
async def create_datagram_endpoint(proto_lam, sock=None):
|
||||
protocol = proto_lam()
|
||||
transport = asyncio.DatagramTransport(extra={'socket': mock_sock})
|
||||
transport.close = lambda: mock_sock.close()
|
||||
mock_sock.sendto = sendto(protocol)
|
||||
transport.sendto = mock_sock.sendto
|
||||
protocol.connection_made(transport)
|
||||
return transport, protocol
|
||||
|
||||
with mock.patch('socket.socket') as mock_socket:
|
||||
mock_sock = mock.Mock(spec=socket.socket)
|
||||
mock_sock.setsockopt = lambda *_: None
|
||||
mock_sock.bind = lambda *_: None
|
||||
mock_sock.setblocking = lambda *_: None
|
||||
mock_sock.getsockname = lambda: "0.0.0.0"
|
||||
mock_sock.getpeername = lambda: ""
|
||||
mock_sock.close = lambda: None
|
||||
mock_sock.type = socket.SOCK_DGRAM
|
||||
mock_sock.fileno = lambda: 7
|
||||
|
||||
mock_socket.return_value = mock_sock
|
||||
loop.create_datagram_endpoint = create_datagram_endpoint
|
||||
yield
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_tcp_endpoint_factory(loop, replies=None, delay_reply=0.0, sent_packets=None):
|
||||
sent_packets = sent_packets if sent_packets is not None else []
|
||||
replies = replies or {}
|
||||
|
||||
def write(p: asyncio.Protocol):
|
||||
def _write(data):
|
||||
sent_packets.append(data)
|
||||
if data in replies:
|
||||
loop.call_later(delay_reply, p.data_received, replies[data])
|
||||
return _write
|
||||
|
||||
async def create_connection(protocol_factory, host=None, port=None):
|
||||
protocol = protocol_factory()
|
||||
transport = asyncio.Transport(extra={'socket': mock_sock})
|
||||
transport.close = lambda: mock_sock.close()
|
||||
mock_sock.write = write(protocol)
|
||||
transport.write = mock_sock.write
|
||||
protocol.connection_made(transport)
|
||||
return transport, protocol
|
||||
|
||||
with mock.patch('socket.socket') as mock_socket:
|
||||
mock_sock = mock.Mock(spec=socket.socket)
|
||||
mock_sock.setsockopt = lambda *_: None
|
||||
mock_sock.bind = lambda *_: None
|
||||
mock_sock.setblocking = lambda *_: None
|
||||
mock_sock.getsockname = lambda: "0.0.0.0"
|
||||
mock_sock.getpeername = lambda: ""
|
||||
mock_sock.close = lambda: None
|
||||
mock_sock.type = socket.SOCK_STREAM
|
||||
mock_sock.fileno = lambda: 7
|
||||
|
||||
mock_socket.return_value = mock_sock
|
||||
loop.create_connection = create_connection
|
||||
yield
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_tcp_and_udp(loop, udp_expected_addr, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
|
||||
tcp_replies=None, tcp_delay_reply=0.0, tcp_sent_packets=None):
|
||||
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
|
||||
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None):
|
||||
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
|
||||
udp_replies = udp_replies or {}
|
||||
|
||||
tcp_sent_packets = tcp_sent_packets if tcp_sent_packets is not None else []
|
||||
sent_tcp_packets = sent_tcp_packets if sent_tcp_packets is not None else []
|
||||
tcp_replies = tcp_replies or {}
|
||||
|
||||
async def create_connection(protocol_factory, host=None, port=None):
|
||||
def write(p: asyncio.Protocol):
|
||||
def _write(data):
|
||||
tcp_sent_packets.append(data)
|
||||
sent_tcp_packets.append(data)
|
||||
if data in tcp_replies:
|
||||
loop.call_later(tcp_delay_reply, p.data_received, tcp_replies[data])
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from aioupnp.fault import UPnPError
|
||||
from aioupnp.protocols.scpd import scpd_post, scpd_get
|
||||
from tests import TestBase
|
||||
from tests.mocks import mock_tcp_endpoint_factory
|
||||
from tests.mocks import mock_tcp_and_udp
|
||||
|
||||
|
||||
class TestSCPDGet(TestBase):
|
||||
|
@ -107,7 +107,7 @@ class TestSCPDGet(TestBase):
|
|||
async def test_scpd_get(self):
|
||||
sent = []
|
||||
replies = {self.get_request: self.response}
|
||||
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent):
|
||||
with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
|
||||
result, raw, err = await scpd_get(self.path, self.lan_address, self.port, self.loop)
|
||||
self.assertEqual(None, err)
|
||||
self.assertDictEqual(self.expected_parsed, result)
|
||||
|
@ -115,7 +115,7 @@ class TestSCPDGet(TestBase):
|
|||
async def test_scpd_get_timeout(self):
|
||||
sent = []
|
||||
replies = {}
|
||||
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent):
|
||||
with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
|
||||
result, raw, err = await scpd_get(self.path, self.lan_address, self.port, self.loop)
|
||||
self.assertTrue(isinstance(err, UPnPError))
|
||||
self.assertDictEqual({}, result)
|
||||
|
@ -124,7 +124,7 @@ class TestSCPDGet(TestBase):
|
|||
async def test_scpd_get_bad_xml(self):
|
||||
sent = []
|
||||
replies = {self.get_request: self.bad_response}
|
||||
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent):
|
||||
with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
|
||||
result, raw, err = await scpd_get(self.path, self.lan_address, self.port, self.loop)
|
||||
self.assertDictEqual({}, result)
|
||||
self.assertEqual(self.bad_xml, raw)
|
||||
|
@ -134,7 +134,7 @@ class TestSCPDGet(TestBase):
|
|||
async def test_scpd_get_overrun_content_length(self):
|
||||
sent = []
|
||||
replies = {self.get_request: self.bad_response + b'\r\n'}
|
||||
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent):
|
||||
with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
|
||||
result, raw, err = await scpd_get(self.path, self.lan_address, self.port, self.loop)
|
||||
self.assertDictEqual({}, result)
|
||||
self.assertEqual(self.bad_response + b'\r\n', raw)
|
||||
|
@ -183,7 +183,7 @@ class TestSCPDPost(TestBase):
|
|||
async def test_scpd_post(self):
|
||||
sent = []
|
||||
replies = {self.post_bytes: self.post_response}
|
||||
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent):
|
||||
with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
|
||||
result, raw, err = await scpd_post(
|
||||
self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop
|
||||
)
|
||||
|
@ -194,7 +194,7 @@ class TestSCPDPost(TestBase):
|
|||
async def test_scpd_post_timeout(self):
|
||||
sent = []
|
||||
replies = {}
|
||||
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent):
|
||||
with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
|
||||
result, raw, err = await scpd_post(
|
||||
self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop
|
||||
)
|
||||
|
@ -206,7 +206,7 @@ class TestSCPDPost(TestBase):
|
|||
async def test_scpd_post_bad_xml_response(self):
|
||||
sent = []
|
||||
replies = {self.post_bytes: self.bad_envelope_response}
|
||||
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent):
|
||||
with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
|
||||
result, raw, err = await scpd_post(
|
||||
self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop
|
||||
)
|
||||
|
@ -218,7 +218,7 @@ class TestSCPDPost(TestBase):
|
|||
async def test_scpd_post_overrun_response(self):
|
||||
sent = []
|
||||
replies = {self.post_bytes: self.post_response + b'\r\n'}
|
||||
with mock_tcp_endpoint_factory(self.loop, replies, sent_packets=sent):
|
||||
with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
|
||||
result, raw, err = await scpd_post(
|
||||
self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop
|
||||
)
|
||||
|
|
|
@ -5,7 +5,7 @@ from aioupnp.serialization.ssdp import SSDPDatagram
|
|||
from aioupnp.constants import SSDP_IP_ADDRESS
|
||||
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search
|
||||
from tests import TestBase
|
||||
from tests.mocks import mock_datagram_endpoint_factory
|
||||
from tests.mocks import mock_tcp_and_udp
|
||||
|
||||
|
||||
class TestSSDP(TestBase):
|
||||
|
@ -35,14 +35,14 @@ class TestSSDP(TestBase):
|
|||
}
|
||||
sent = []
|
||||
|
||||
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies, sent_packets=sent):
|
||||
with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1", sent_udp_packets=sent):
|
||||
reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=True)
|
||||
|
||||
self.assertEqual(reply.encode(), self.reply_packet.encode())
|
||||
self.assertListEqual(sent, [self.query_packet.encode().encode()])
|
||||
|
||||
with self.assertRaises(UPnPError):
|
||||
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies):
|
||||
with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", udp_replies=replies):
|
||||
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=False)
|
||||
|
||||
async def test_m_search_reply_multicast(self):
|
||||
|
@ -51,21 +51,21 @@ class TestSSDP(TestBase):
|
|||
}
|
||||
sent = []
|
||||
|
||||
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies, sent_packets=sent):
|
||||
with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1", sent_udp_packets=sent):
|
||||
reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop)
|
||||
|
||||
self.assertEqual(reply.encode(), self.reply_packet.encode())
|
||||
self.assertListEqual(sent, [self.query_packet.encode().encode()])
|
||||
|
||||
with self.assertRaises(UPnPError):
|
||||
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies):
|
||||
with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1"):
|
||||
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=True)
|
||||
|
||||
async def test_packets_sent_fuzzy_m_search(self):
|
||||
sent = []
|
||||
|
||||
with self.assertRaises(UPnPError):
|
||||
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", sent_packets=sent):
|
||||
with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", sent_udp_packets=sent):
|
||||
await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop)
|
||||
|
||||
self.assertListEqual(sent, self.byte_packets)
|
||||
|
@ -76,7 +76,7 @@ class TestSSDP(TestBase):
|
|||
}
|
||||
sent = []
|
||||
|
||||
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies, sent_packets=sent):
|
||||
with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", udp_replies=replies, sent_udp_packets=sent):
|
||||
args, reply = await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop)
|
||||
|
||||
self.assertEqual(reply.encode(), self.reply_packet.encode())
|
||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in a new issue