Fix unhandled errors #15

Merged
jackrobison merged 4 commits from fix-unhandled-errors into master 2019-08-15 23:00:49 +02:00
8 changed files with 305 additions and 197 deletions

View file

@ -4,6 +4,7 @@ import typing
import logging import logging
from aioupnp.protocols.scpd import scpd_post from aioupnp.protocols.scpd import scpd_post
from aioupnp.device import Service from aioupnp.device import Service
from aioupnp.fault import UPnPError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -35,17 +36,38 @@ class GetGenericPortMappingEntryResponse(typing.NamedTuple):
lease_time: int lease_time: int
def recast_return(return_annotation, result: typing.Dict[str, typing.Union[int, str]], class SCPDRequestDebuggingInfo(typing.NamedTuple):
method: str
kwargs: typing.Dict[str, typing.Union[str, int, bool]]
response_xml: bytes
result: typing.Optional[typing.Union[str, int, bool, GetSpecificPortMappingEntryResponse,
GetGenericPortMappingEntryResponse]]
err: typing.Optional[Exception]
ts: float
def recast_return(return_annotation, result: typing.Union[str, int, bool, typing.Dict[str, typing.Union[int, str]]],
result_keys: typing.List[str]) -> typing.Optional[ result_keys: typing.List[str]) -> typing.Optional[
typing.Union[str, int, bool, GetSpecificPortMappingEntryResponse, GetGenericPortMappingEntryResponse]]: typing.Union[str, int, bool, GetSpecificPortMappingEntryResponse, GetGenericPortMappingEntryResponse]]:
if len(result_keys) == 1: if len(result_keys) == 1:
single_result = result[result_keys[0]] if isinstance(result, (str, int, bool)):
single_result = result
else:
if result_keys[0] in result:
single_result = result[result_keys[0]]
else: # check for the field having incorrect capitalization
flattened = {k.lower(): v for k, v in result.items()}
if result_keys[0].lower() in flattened:
single_result = flattened[result_keys[0].lower()]
else:
raise UPnPError(f"expected response key {result_keys[0]}, got {list(result.keys())}")
if return_annotation is bool: if return_annotation is bool:
return soap_bool(single_result) return soap_bool(single_result)
if return_annotation is str: if return_annotation is str:
return soap_optional_str(single_result) return soap_optional_str(single_result)
return int(result[result_keys[0]]) if result_keys[0] in result else None return None if single_result is None else int(single_result)
elif return_annotation in [GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse]: elif return_annotation in [GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse]:
assert isinstance(result, dict)
arg_types: typing.Dict[str, typing.Type[typing.Any]] = return_annotation._field_types arg_types: typing.Dict[str, typing.Type[typing.Any]] = return_annotation._field_types
assert len(arg_types) == len(result_keys) assert len(arg_types) == len(result_keys)
recast_results: typing.Dict[str, typing.Optional[typing.Union[str, int, bool]]] = {} recast_results: typing.Dict[str, typing.Optional[typing.Union[str, int, bool]]] = {}
@ -108,11 +130,7 @@ class SOAPCommands:
self._base_address = base_address self._base_address = base_address
self._port = port self._port = port
self._requests: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes, self._request_debug_infos: typing.List[SCPDRequestDebuggingInfo] = []
typing.Optional[typing.Union[str, int, bool,
GetSpecificPortMappingEntryResponse,
GetGenericPortMappingEntryResponse]],
typing.Optional[Exception], float]] = []
def is_registered(self, name: str) -> bool: def is_registered(self, name: str) -> bool:
if name not in self.SOAP_COMMANDS: if name not in self.SOAP_COMMANDS:
@ -147,11 +165,17 @@ class SOAPCommands:
) )
if err is not None: if err is not None:
assert isinstance(xml_bytes, bytes) assert isinstance(xml_bytes, bytes)
self._requests.append((name, kwargs, xml_bytes, None, err, time.time())) self._request_debug_infos.append(SCPDRequestDebuggingInfo(name, kwargs, xml_bytes, None, err, time.time()))
raise err raise err
assert 'return' in annotations assert 'return' in annotations
result = recast_return(annotations['return'], response, output_names) try:
self._requests.append((name, kwargs, xml_bytes, result, None, time.time())) result = recast_return(annotations['return'], response, output_names)
self._request_debug_infos.append(SCPDRequestDebuggingInfo(name, kwargs, xml_bytes, result, None, time.time()))
except Exception as err:
if isinstance(err, asyncio.CancelledError):
raise
self._request_debug_infos.append(SCPDRequestDebuggingInfo(name, kwargs, xml_bytes, None, err, time.time()))
raise UPnPError(f"Raised {str(type(err).__name__)}({str(err)}) parsing response for {name}")
return result return result
if not len(list(k for k in annotations if k != 'return')): if not len(list(k for k in annotations if k != 'return')):

View file

@ -6,7 +6,7 @@ from collections import OrderedDict
from typing import Dict, List from typing import Dict, List
from aioupnp.util import get_dict_val_case_insensitive from aioupnp.util import get_dict_val_case_insensitive
from aioupnp.constants import SPEC_VERSION, SERVICE from aioupnp.constants import SPEC_VERSION, SERVICE
from aioupnp.commands import SOAPCommands from aioupnp.commands import SOAPCommands, SCPDRequestDebuggingInfo
from aioupnp.device import Device, Service from aioupnp.device import Device, Service
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search from aioupnp.protocols.ssdp import fuzzy_m_search, m_search
from aioupnp.protocols.scpd import scpd_get from aioupnp.protocols.scpd import scpd_get
@ -85,7 +85,7 @@ class Gateway:
self.urn: bytes = (ok_packet.st or '').encode() self.urn: bytes = (ok_packet.st or '').encode()
self._xml_response: bytes = b"" self._xml_response: bytes = b""
self._service_descriptors: Dict[str, bytes] = {} self._service_descriptors: Dict[str, str] = {}
self.base_address, self.port = parse_location(self.location) self.base_address, self.port = parse_location(self.location)
self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0] self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0]
@ -103,17 +103,6 @@ class Gateway:
self._registered_commands: Dict[str, str] = {} self._registered_commands: Dict[str, str] = {}
self.commands = SOAPCommands(self._loop, self.base_ip, self.port) self.commands = SOAPCommands(self._loop, self.base_ip, self.port)
# def gateway_descriptor(self) -> dict:
# r = {
# 'server': self.server.decode(),
# 'urlBase': self.url_base,
# 'location': self.location.decode(),
# "specVersion": self.spec_version,
# 'usn': self.usn.decode(),
# 'urn': self.urn.decode(),
# }
# return r
@property @property
def manufacturer_string(self) -> str: def manufacturer_string(self) -> str:
manufacturer_string = "UNKNOWN GATEWAY" manufacturer_string = "UNKNOWN GATEWAY"
@ -147,37 +136,26 @@ class Gateway:
# return service # return service
# return None # return None
# @property def debug_gateway(self) -> Dict[str, typing.Union[str, bytes, int, Dict, List]]:
# def soap_requests(self) -> typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes, return {
# typing.Optional[typing.Tuple], 'manufacturer_string': self.manufacturer_string,
# typing.Optional[Exception], float]]: 'gateway_address': self.base_ip.decode(),
# soap_call_infos: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes, 'server': self.server.decode(),
# typing.Optional[typing.Tuple], 'urlBase': self.url_base or '',
# typing.Optional[Exception], float]] = [] 'location': self.location.decode(),
# soap_call_infos.extend([ "specVersion": self.spec_version or '',
# (name, request_args, raw_response, decoded_response, soap_error, ts) 'usn': self.usn.decode(),
# for ( 'urn': self.urn.decode(),
# name, request_args, raw_response, decoded_response, soap_error, ts 'gateway_xml': self._xml_response.decode(),
# ) in self.commands._requests 'services_xml': self._service_descriptors,
# ]) 'services': {service.SCPDURL: service.as_dict() for service in self._services},
# soap_call_infos.sort(key=lambda x: x[5]) 'm_search_args': OrderedDict(self._m_search_args),
# return soap_call_infos 'reply': self._ok_packet.as_dict(),
'soap_port': self.port,
# def debug_gateway(self) -> Dict[str, Union[str, bytes, int, Dict, List]]: 'registered_soap_commands': self._registered_commands,
# return { 'unsupported_soap_commands': self._unsupported_actions,
# 'manufacturer_string': self.manufacturer_string, 'soap_requests': list(self.commands._request_debug_infos)
# 'gateway_address': self.base_ip, }
# 'gateway_descriptor': self.gateway_descriptor(),
# 'gateway_xml': self._xml_response,
# 'services_xml': self._service_descriptors,
# 'services': {service.SCPDURL: service.as_dict() for service in self._services},
# 'm_search_args': [(k, v) for (k, v) in self._m_search_args.items()],
# 'reply': self._ok_packet.as_dict(),
# 'soap_port': self.port,
# 'registered_soap_commands': self._registered_commands,
# 'unsupported_soap_commands': self._unsupported_actions,
# 'soap_requests': self.soap_requests
# }
@classmethod @classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
@ -201,7 +179,7 @@ class Gateway:
try: try:
gateway = cls(datagram, m_search_args, lan_address, gateway_address, loop=loop) gateway = cls(datagram, m_search_args, lan_address, gateway_address, loop=loop)
log.debug('get gateway descriptor %s', datagram.location) log.debug('get gateway descriptor %s', datagram.location)
await gateway.discover_commands(loop) await gateway.discover_commands()
requirements_met = all([gateway.commands.is_registered(required) for required in required_commands]) requirements_met = all([gateway.commands.is_registered(required) for required in required_commands])
if not requirements_met: if not requirements_met:
not_met = [ not_met = [
@ -249,8 +227,10 @@ class Gateway:
results: typing.List['asyncio.Future[Gateway]'] = list(done) results: typing.List['asyncio.Future[Gateway]'] = list(done)
return results[0].result() return results[0].result()
async def discover_commands(self, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None: async def discover_commands(self) -> None:
response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port, loop=loop) response, xml_bytes, get_err = await scpd_get(
self.path.decode(), self.base_ip.decode(), self.port, loop=self._loop
)
self._xml_response = xml_bytes self._xml_response = xml_bytes
if get_err is not None: if get_err is not None:
raise get_err raise get_err
@ -286,7 +266,7 @@ class Gateway:
else: else:
self._device = Device(self._devices, self._services) self._device = Device(self._devices, self._services)
for service_type in self.services.keys(): for service_type in self.services.keys():
await self.register_commands(self.services[service_type], loop) await self.register_commands(self.services[service_type], self._loop)
return None return None
async def register_commands(self, service: Service, async def register_commands(self, service: Service,
@ -298,7 +278,7 @@ class Gateway:
log.debug("get descriptor for %s from %s", service.serviceType, service.SCPDURL) log.debug("get descriptor for %s from %s", service.serviceType, service.SCPDURL)
service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port, loop=loop) service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port, loop=loop)
self._service_descriptors[service.SCPDURL] = xml_bytes self._service_descriptors[service.SCPDURL] = xml_bytes.decode()
if get_err is not None: if get_err is not None:
log.debug("failed to get descriptor for %s from %s", service.serviceType, service.SCPDURL) log.debug("failed to get descriptor for %s from %s", service.serviceType, service.SCPDURL)

View file

@ -45,7 +45,7 @@ class SCPDHTTPClientProtocol(Protocol):
and devices respond with an invalid HTTP version line and devices respond with an invalid HTTP version line
""" """
def __init__(self, message: bytes, finished: 'asyncio.Future[typing.Tuple[bytes, int, bytes]]', def __init__(self, message: bytes, finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]',
soap_method: typing.Optional[str] = None, soap_service_id: typing.Optional[str] = None) -> None: soap_method: typing.Optional[str] = None, soap_service_id: typing.Optional[str] = None) -> None:
self.message = message self.message = message
self.response_buff = b"" self.response_buff = b""
@ -85,7 +85,7 @@ class SCPDHTTPClientProtocol(Protocol):
self._got_headers = True self._got_headers = True
body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:]) body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:])
if self._content_length == len(body): if self._content_length == len(body):
self.finished.set_result((body, self._response_code, self._response_msg)) self.finished.set_result((self.response_buff, body, self._response_code, self._response_msg))
elif self._content_length > len(body): elif self._content_length > len(body):
pass pass
else: else:
@ -105,7 +105,7 @@ async def scpd_get(control_url: str, address: str, port: int,
typing.Dict[str, typing.Any], bytes, typing.Optional[Exception]]: typing.Dict[str, typing.Any], bytes, typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop() loop = loop or asyncio.get_event_loop()
packet = serialize_scpd_get(control_url, address) packet = serialize_scpd_get(control_url, address)
finished: 'asyncio.Future[typing.Tuple[bytes, int, bytes]]' = asyncio.Future(loop=loop) finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = asyncio.Future(loop=loop)
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished) proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished)
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection( connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
proto_factory, address, port proto_factory, address, port
@ -115,24 +115,25 @@ async def scpd_get(control_url: str, address: str, port: int,
assert isinstance(protocol, SCPDHTTPClientProtocol) assert isinstance(protocol, SCPDHTTPClientProtocol)
error = None error = None
wait_task: typing.Awaitable[typing.Tuple[bytes, int, bytes]] = asyncio.wait_for(protocol.finished, 1.0, loop=loop) wait_task: typing.Awaitable[typing.Tuple[bytes, bytes, int, bytes]] = asyncio.wait_for(protocol.finished, 1.0, loop=loop)
body = b''
raw_response = b''
try: try:
body, response_code, response_msg = await wait_task raw_response, body, response_code, response_msg = await wait_task
except asyncio.TimeoutError: except asyncio.TimeoutError:
error = UPnPError("get request timed out") error = UPnPError("get request timed out")
body = b''
except UPnPError as err: except UPnPError as err:
error = err error = err
body = protocol.response_buff raw_response = protocol.response_buff
finally: finally:
transport.close() transport.close()
if not error: if not error:
try: try:
return deserialize_scpd_get_response(body), body, None return deserialize_scpd_get_response(body), raw_response, None
except Exception as err: except Exception as err:
error = UPnPError(err) error = UPnPError(err)
return {}, body, error return {}, raw_response, error
async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes, async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes,
@ -140,7 +141,7 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para
**kwargs: typing.Dict[str, typing.Any] **kwargs: typing.Dict[str, typing.Any]
) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]: ) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop() loop = loop or asyncio.get_event_loop()
finished: 'asyncio.Future[typing.Tuple[bytes, int, bytes]]' = asyncio.Future(loop=loop) finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = asyncio.Future(loop=loop)
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs) packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\ proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\
SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode()) SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode())
@ -152,18 +153,17 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para
assert isinstance(protocol, SCPDHTTPClientProtocol) assert isinstance(protocol, SCPDHTTPClientProtocol)
try: try:
wait_task: typing.Awaitable[typing.Tuple[bytes, int, bytes]] = asyncio.wait_for(finished, 1.0, loop=loop) wait_task: typing.Awaitable[typing.Tuple[bytes, bytes, int, bytes]] = asyncio.wait_for(finished, 1.0, loop=loop)
body, response_code, response_msg = await wait_task raw_response, body, response_code, response_msg = await wait_task
except asyncio.TimeoutError: except asyncio.TimeoutError:
return {}, b'', UPnPError("Timeout") return {}, b'', UPnPError("Timeout")
except UPnPError as err: except UPnPError as err:
return {}, protocol.response_buff, err return {}, protocol.response_buff, err
finally: finally:
# raw_response = protocol.response_buff
transport.close() transport.close()
try: try:
return ( return (
deserialize_soap_post_response(body, method, service_id.decode()), body, None deserialize_soap_post_response(body, method, service_id.decode()), raw_response, None
) )
except Exception as err: except Exception as err:
return {}, body, UPnPError(err) return {}, raw_response, UPnPError(err)

View file

@ -1,5 +1,6 @@
import re import re
import typing import typing
import json
from aioupnp.util import flatten_keys from aioupnp.util import flatten_keys
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.constants import XML_VERSION, ENVELOPE, BODY, FAULT, CONTROL from aioupnp.constants import XML_VERSION, ENVELOPE, BODY, FAULT, CONTROL
@ -54,7 +55,10 @@ def deserialize_soap_post_response(response: bytes, method: str,
fault: typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]] = flatten_keys( fault: typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]] = flatten_keys(
response_body[FAULT], "{%s}" % CONTROL response_body[FAULT], "{%s}" % CONTROL
) )
raise UPnPError(fault['detail']['UPnPError']['errorDescription']) try:
raise UPnPError(fault['detail']['UPnPError']['errorDescription'])
except (KeyError, TypeError, ValueError):
raise UPnPError(f"Failed to decode error response: {json.dumps(fault)}")
response_key = None response_key = None
for key in response_body: for key in response_body:
if method in key: if method in key:

View file

@ -126,7 +126,7 @@ class TestSCPDGet(AsyncioTestCase):
with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent): with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
result, raw, err = await scpd_get(self.path, self.lan_address, self.port, self.loop) result, raw, err = await scpd_get(self.path, self.lan_address, self.port, self.loop)
self.assertDictEqual({}, result) self.assertDictEqual({}, result)
self.assertEqual(self.bad_xml, raw) self.assertEqual(self.bad_response, raw)
self.assertTrue(isinstance(err, UPnPError)) self.assertTrue(isinstance(err, UPnPError))
self.assertTrue(str(err).startswith('no element found')) self.assertTrue(str(err).startswith('no element found'))
@ -187,7 +187,7 @@ class TestSCPDPost(AsyncioTestCase):
self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop
) )
self.assertEqual(None, err) self.assertEqual(None, err)
self.assertEqual(self.envelope, raw) self.assertEqual(self.post_response, raw)
self.assertDictEqual({'NewExternalIPAddress': '11.22.33.44'}, result) self.assertDictEqual({'NewExternalIPAddress': '11.22.33.44'}, result)
async def test_scpd_post_timeout(self): async def test_scpd_post_timeout(self):
@ -211,7 +211,7 @@ class TestSCPDPost(AsyncioTestCase):
) )
self.assertTrue(isinstance(err, UPnPError)) self.assertTrue(isinstance(err, UPnPError))
self.assertTrue(str(err).startswith('no element found')) self.assertTrue(str(err).startswith('no element found'))
self.assertEqual(self.bad_envelope, raw) self.assertEqual(self.bad_envelope_response, raw)
self.assertDictEqual({}, result) self.assertDictEqual({}, result)
async def test_scpd_post_overrun_response(self): async def test_scpd_post_overrun_response(self):

View file

@ -58,6 +58,16 @@ class TestSOAPSerialization(unittest.TestCase):
b"\r\n" \ b"\r\n" \
b"<?xml version=\"1.0\"?>\n<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\n\t<s:Body>\n\t\t<s:Fault>\n\t\t\t<faultcode>s:Client</faultcode>\n\t\t\t<faultstring>UPnPError</faultstring>\n\t\t\t<detail>\n\t\t\t\t<UPnPError xmlns=\"urn:schemas-upnp-org:control-1-0\">\n\t\t\t\t\t<errorCode>713</errorCode>\n\t\t\t\t\t<errorDescription>SpecifiedArrayIndexInvalid</errorDescription>\n\t\t\t\t</UPnPError>\n\t\t\t</detail>\n\t\t</s:Fault>\n\t</s:Body>\n</s:Envelope>\n" b"<?xml version=\"1.0\"?>\n<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\n\t<s:Body>\n\t\t<s:Fault>\n\t\t\t<faultcode>s:Client</faultcode>\n\t\t\t<faultstring>UPnPError</faultstring>\n\t\t\t<detail>\n\t\t\t\t<UPnPError xmlns=\"urn:schemas-upnp-org:control-1-0\">\n\t\t\t\t\t<errorCode>713</errorCode>\n\t\t\t\t\t<errorDescription>SpecifiedArrayIndexInvalid</errorDescription>\n\t\t\t\t</UPnPError>\n\t\t\t</detail>\n\t\t</s:Fault>\n\t</s:Body>\n</s:Envelope>\n"
error_response_no_description = b"HTTP/1.1 500 Internal Server Error\r\n" \
b"Server: WebServer\r\n" \
b"Date: Thu, 11 Oct 2018 22:16:17 GMT\r\n" \
b"Connection: close\r\n" \
b"CONTENT-TYPE: text/xml; charset=\"utf-8\"\r\n" \
b"CONTENT-LENGTH: 429 \r\n" \
b"EXT:\r\n" \
b"\r\n" \
b"<?xml version=\"1.0\"?>\n<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\n\t<s:Body>\n\t\t<s:Fault>\n\t\t\t<faultcode>s:Client</faultcode>\n\t\t\t<faultstring>UPnPError</faultstring>\n\t\t\t<detail>\n\t\t\t\t<UPnPError xmlns=\"urn:schemas-upnp-org:control-1-0\">\n\t\t\t\t\t<errorCode>713</errorCode>\n\t\t\t\t\t\n\t\t\t\t</UPnPError>\n\t\t\t</detail>\n\t\t</s:Fault>\n\t</s:Body>\n</s:Envelope>\n"
def test_serialize_post(self): def test_serialize_post(self):
self.assertEqual(serialize_soap_post( self.assertEqual(serialize_soap_post(
self.method, self.param_names, self.st, self.gateway_address, self.path, **self.kwargs self.method, self.param_names, self.st, self.gateway_address, self.path, **self.kwargs
@ -94,3 +104,13 @@ class TestSOAPSerialization(unittest.TestCase):
raised = True raised = True
self.assertTrue(str(err) == 'SpecifiedArrayIndexInvalid') self.assertTrue(str(err) == 'SpecifiedArrayIndexInvalid')
self.assertTrue(raised) self.assertTrue(raised)
def test_raise_from_error_response_without_error_description(self):
raised = False
expected = 'Failed to decode error response: {"faultcode": "s:Client", "faultstring": "UPnPError", "detail": {"UPnPError": {"errorCode": "713"}}}'
try:
deserialize_soap_post_response(self.error_response_no_description, self.method, service_id=self.st.decode())
except UPnPError as err:
raised = True
self.assertTrue(str(err) == expected)
self.assertTrue(raised)

File diff suppressed because one or more lines are too long

View file

@ -44,8 +44,46 @@ class TestGetExternalIPAddress(UPnPCommandTestCase):
async def test_get_external_ip(self): async def test_get_external_ip(self):
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address) gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands(self.loop) await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
external_ip = await upnp.get_external_ip()
self.assertEqual("11.222.3.44", external_ip)
class TestMalformedGetExternalIPAddressResponse(UPnPCommandTestCase):
client_address = '11.2.3.222'
get_ip_request = b'POST /soap.cgi?service=WANIPConn1 HTTP/1.1\r\nHost: 11.2.3.4\r\nUser-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\nContent-Length: 285\r\nContent-Type: text/xml\r\nSOAPAction: "urn:schemas-upnp-org:service:WANIPConnection:1#GetExternalIPAddress"\r\nConnection: Close\r\nCache-Control: no-cache\r\nPragma: no-cache\r\n\r\n<?xml version="1.0"?>\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" s:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body><u:GetExternalIPAddress xmlns:u="urn:schemas-upnp-org:service:WANIPConnection:1"></u:GetExternalIPAddress></s:Body></s:Envelope>\r\n'
async def test_response_key_mismatch(self):
self.replies.update({self.get_ip_request: b"HTTP/1.1 200 OK\r\nServer: WebServer\r\nDate: Wed, 22 May 2019 03:25:57 GMT\r\nConnection: close\r\nCONTENT-TYPE: text/xml; charset=\"utf-8\"\r\nCONTENT-LENGTH: 333 \r\nEXT:\r\n\r\n<?xml version=\"1.0\"?>\n<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\n\t<s:Body>\n\t\t<u:GetExternalIPAddressResponse xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\n"
b"<derp>11.222.3.44</derp>\n</u:GetExternalIPAddressResponse>\n\t</s:Body>\n</s:Envelope>\n"})
self.addCleanup(self.replies.pop, self.get_ip_request)
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
with self.assertRaises(UPnPError):
await upnp.get_external_ip()
async def test_response_key_case_sensitivity(self):
self.replies.update({self.get_ip_request: b"HTTP/1.1 200 OK\r\nServer: WebServer\r\nDate: Wed, 22 May 2019 03:25:57 GMT\r\nConnection: close\r\nCONTENT-TYPE: text/xml; charset=\"utf-8\"\r\nCONTENT-LENGTH: 365 \r\nEXT:\r\n\r\n<?xml version=\"1.0\"?>\n<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\n\t<s:Body>\n\t\t<u:GetExternalIPAddressResponse xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\n"
b"<newexternalipaddress>11.222.3.44</newexternalipaddress>\n</u:GetExternalIPAddressResponse>\n\t</s:Body>\n</s:Envelope>\n"})
self.addCleanup(self.replies.pop, self.get_ip_request)
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
external_ip = await upnp.get_external_ip()
self.assertEqual("11.222.3.44", external_ip)
async def test_non_encapsulated_single_field_response(self):
self.replies.update({self.get_ip_request: b"HTTP/1.1 200 OK\r\nServer: WebServer\r\nDate: Wed, 22 May 2019 03:25:57 GMT\r\nConnection: close\r\nCONTENT-TYPE: text/xml; charset=\"utf-8\"\r\nCONTENT-LENGTH: 320 \r\nEXT:\r\n\r\n<?xml version=\"1.0\"?>\n<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\n\t<s:Body>\n\t\t<u:GetExternalIPAddressResponse xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\n"
b"11.222.3.44\n</u:GetExternalIPAddressResponse>\n\t</s:Body>\n</s:Envelope>\n"})
self.addCleanup(self.replies.pop, self.get_ip_request)
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway) upnp = UPnP(self.client_address, self.gateway_address, gateway)
external_ip = await upnp.get_external_ip() external_ip = await upnp.get_external_ip()
self.assertEqual("11.222.3.44", external_ip) self.assertEqual("11.222.3.44", external_ip)
@ -62,8 +100,8 @@ class TestGetGenericPortMappingEntry(UPnPCommandTestCase):
async def test_get_port_mapping_by_index(self): async def test_get_port_mapping_by_index(self):
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address) gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands(self.loop) await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway) upnp = UPnP(self.client_address, self.gateway_address, gateway)
result = await upnp.get_port_mapping_by_index(0) result = await upnp.get_port_mapping_by_index(0)
self.assertEqual(GetGenericPortMappingEntryResponse(None, 9308, 'UDP', 9308, "11.2.3.44", True, self.assertEqual(GetGenericPortMappingEntryResponse(None, 9308, 'UDP', 9308, "11.2.3.44", True,
@ -84,8 +122,8 @@ class TestGetNextPortMapping(UPnPCommandTestCase):
async def test_get_next_mapping(self): async def test_get_next_mapping(self):
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address) gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands(self.loop) await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway) upnp = UPnP(self.client_address, self.gateway_address, gateway)
ext_port = await upnp.get_next_mapping(4567, "UDP", "aioupnp test mapping") ext_port = await upnp.get_next_mapping(4567, "UDP", "aioupnp test mapping")
self.assertEqual(4567, ext_port) self.assertEqual(4567, ext_port)
@ -104,8 +142,8 @@ class TestGetSpecificPortMapping(UPnPCommandTestCase):
async def test_get_specific_port_mapping(self): async def test_get_specific_port_mapping(self):
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address) gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands(self.loop) await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway) upnp = UPnP(self.client_address, self.gateway_address, gateway)
try: try:
await upnp.get_specific_port_mapping(1000, 'UDP') await upnp.get_specific_port_mapping(1000, 'UDP')