diff --git a/aioupnp/commands.py b/aioupnp/commands.py index 34e4dc2..a254216 100644 --- a/aioupnp/commands.py +++ b/aioupnp/commands.py @@ -4,6 +4,7 @@ import typing import logging from aioupnp.protocols.scpd import scpd_post from aioupnp.device import Service +from aioupnp.fault import UPnPError log = logging.getLogger(__name__) @@ -45,17 +46,28 @@ class SCPDRequestDebuggingInfo(typing.NamedTuple): ts: float -def recast_return(return_annotation, result: typing.Dict[str, typing.Union[int, str]], +def recast_return(return_annotation, result: typing.Union[str, int, bool, typing.Dict[str, typing.Union[int, str]]], result_keys: typing.List[str]) -> typing.Optional[ typing.Union[str, int, bool, GetSpecificPortMappingEntryResponse, GetGenericPortMappingEntryResponse]]: if len(result_keys) == 1: - single_result = result[result_keys[0]] + if isinstance(result, (str, int, bool)): + single_result = result + else: + if result_keys[0] in result: + single_result = result[result_keys[0]] + else: # check for the field having incorrect capitalization + flattened = {k.lower(): v for k, v in result.items()} + if result_keys[0].lower() in flattened: + single_result = flattened[result_keys[0].lower()] + else: + raise UPnPError(f"expected response key {result_keys[0]}, got {list(result.keys())}") if return_annotation is bool: return soap_bool(single_result) if return_annotation is str: return soap_optional_str(single_result) - return int(result[result_keys[0]]) if result_keys[0] in result else None + return None if single_result is None else int(single_result) elif return_annotation in [GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse]: + assert isinstance(result, dict) arg_types: typing.Dict[str, typing.Type[typing.Any]] = return_annotation._field_types assert len(arg_types) == len(result_keys) recast_results: typing.Dict[str, typing.Optional[typing.Union[str, int, bool]]] = {} diff --git a/tests/test_upnp.py b/tests/test_upnp.py index bd93570..96527ba 100644 --- a/tests/test_upnp.py +++ b/tests/test_upnp.py @@ -51,6 +51,43 @@ class TestGetExternalIPAddress(UPnPCommandTestCase): self.assertEqual("11.222.3.44", external_ip) +class TestMalformedGetExternalIPAddressResponse(UPnPCommandTestCase): + client_address = '11.2.3.222' + get_ip_request = b'POST /soap.cgi?service=WANIPConn1 HTTP/1.1\r\nHost: 11.2.3.4\r\nUser-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\nContent-Length: 285\r\nContent-Type: text/xml\r\nSOAPAction: "urn:schemas-upnp-org:service:WANIPConnection:1#GetExternalIPAddress"\r\nConnection: Close\r\nCache-Control: no-cache\r\nPragma: no-cache\r\n\r\n\r\n\r\n' + + async def test_response_key_mismatch(self): + self.replies.update({self.get_ip_request: b"HTTP/1.1 200 OK\r\nServer: WebServer\r\nDate: Wed, 22 May 2019 03:25:57 GMT\r\nConnection: close\r\nCONTENT-TYPE: text/xml; charset=\"utf-8\"\r\nCONTENT-LENGTH: 333 \r\nEXT:\r\n\r\n\n\n\t\n\t\t\n" + b"11.222.3.44\n\n\t\n\n"}) + self.addCleanup(self.replies.pop, self.get_ip_request) + with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): + gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address) + await gateway.discover_commands(self.loop) + upnp = UPnP(self.client_address, self.gateway_address, gateway) + with self.assertRaises(UPnPError): + await upnp.get_external_ip() + + async def test_response_key_case_sensitivity(self): + self.replies.update({self.get_ip_request: b"HTTP/1.1 200 OK\r\nServer: WebServer\r\nDate: Wed, 22 May 2019 03:25:57 GMT\r\nConnection: close\r\nCONTENT-TYPE: text/xml; charset=\"utf-8\"\r\nCONTENT-LENGTH: 365 \r\nEXT:\r\n\r\n\n\n\t\n\t\t\n" + b"11.222.3.44\n\n\t\n\n"}) + self.addCleanup(self.replies.pop, self.get_ip_request) + with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): + gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address) + await gateway.discover_commands(self.loop) + upnp = UPnP(self.client_address, self.gateway_address, gateway) + external_ip = await upnp.get_external_ip() + self.assertEqual("11.222.3.44", external_ip) + + async def test_non_encapsulated_single_field_response(self): + self.replies.update({self.get_ip_request: b"HTTP/1.1 200 OK\r\nServer: WebServer\r\nDate: Wed, 22 May 2019 03:25:57 GMT\r\nConnection: close\r\nCONTENT-TYPE: text/xml; charset=\"utf-8\"\r\nCONTENT-LENGTH: 320 \r\nEXT:\r\n\r\n\n\n\t\n\t\t\n" + b"11.222.3.44\n\n\t\n\n"}) + self.addCleanup(self.replies.pop, self.get_ip_request) + with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): + gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address) + await gateway.discover_commands(self.loop) + upnp = UPnP(self.client_address, self.gateway_address, gateway) + external_ip = await upnp.get_external_ip() + self.assertEqual("11.222.3.44", external_ip) + class TestGetGenericPortMappingEntry(UPnPCommandTestCase): def setUp(self) -> None: query = b'POST /soap.cgi?service=WANIPConn1 HTTP/1.1\r\nHost: 11.2.3.4\r\nUser-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\nContent-Length: 341\r\nContent-Type: text/xml\r\nSOAPAction: "urn:schemas-upnp-org:service:WANIPConnection:1#GetGenericPortMappingEntry"\r\nConnection: Close\r\nCache-Control: no-cache\r\nPragma: no-cache\r\n\r\n\r\n0\r\n'