From 55e1621637cec950f2866409b21c9bce34f70f8c Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Thu, 18 Oct 2018 10:46:12 -0400 Subject: [PATCH] fix scpd post dropping some responses -improve debug_gateway and generate_test_data --- aioupnp/gateway.py | 30 ++++++-- aioupnp/protocols/scpd.py | 136 ++++++++++++++++++++++------------ aioupnp/protocols/soap.py | 30 ++++---- aioupnp/serialization/soap.py | 7 +- aioupnp/upnp.py | 36 +++++++-- 5 files changed, 161 insertions(+), 78 deletions(-) diff --git a/aioupnp/gateway.py b/aioupnp/gateway.py index b27baf2..6bd38aa 100644 --- a/aioupnp/gateway.py +++ b/aioupnp/gateway.py @@ -125,13 +125,23 @@ class Gateway: return None @property - def _soap_requests(self) -> Dict: - return { - name: getattr(self.commands, name)._requests for name in self._registered_commands.keys() - } + def soap_requests(self) -> List: + soap_call_infos = [] + for name in self._registered_commands.keys(): + if not hasattr(getattr(self.commands, name), "_requests"): + continue + soap_call_infos.extend([ + (name, request_args, raw_response, decoded_response, soap_error, ts) + for ( + request_args, raw_response, decoded_response, soap_error, ts + ) in getattr(self.commands, name)._requests + ]) + soap_call_infos.sort(key=lambda x: x[5]) + return soap_call_infos def debug_gateway(self) -> Dict: return { + 'manufacturer_string': self.manufacturer_string, 'gateway_address': self.base_ip, 'gateway_descriptor': self.gateway_descriptor(), 'gateway_xml': self._xml_response, @@ -142,7 +152,7 @@ class Gateway: 'soap_port': self.port, 'registered_soap_commands': self._registered_commands, 'unsupported_soap_commands': self._unsupported_actions, - 'soap_requests': self._soap_requests + 'soap_requests': self.soap_requests } @classmethod @@ -163,7 +173,7 @@ class Gateway: await gateway.discover_commands(soap_socket) log.debug('found gateway device %s', datagram.location) return gateway - except asyncio.TimeoutError: + except (asyncio.TimeoutError, UPnPError): log.debug("get %s timed out, looking for other devices", datagram.location) ignored.add(datagram.location) continue @@ -189,8 +199,10 @@ class Gateway: return result async def discover_commands(self, soap_socket: socket.socket = None): - response, xml_bytes = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port) + response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port) self._xml_response = xml_bytes + if get_err is not None: + raise get_err self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION) self.url_base = get_dict_val_case_insensitive(response, "urlbase") if not self.url_base: @@ -207,9 +219,11 @@ class Gateway: async def register_commands(self, service: Service, soap_socket: socket.socket = None): if not service.SCPDURL: raise UPnPError("no scpd url") - service_dict, xml_bytes = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port) + service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port) self._service_descriptors[service.SCPDURL] = xml_bytes + if get_err is not None: + raise get_err if not service_dict: return diff --git a/aioupnp/protocols/scpd.py b/aioupnp/protocols/scpd.py index b10b6bb..25ef079 100644 --- a/aioupnp/protocols/scpd.py +++ b/aioupnp/protocols/scpd.py @@ -1,10 +1,13 @@ import logging import socket import typing +import re +from collections import OrderedDict from xml.etree import ElementTree import asyncio from asyncio.protocols import Protocol from aioupnp.fault import UPnPError +from aioupnp.util import get_dict_val_case_insensitive from aioupnp.serialization.scpd import deserialize_scpd_get_response from aioupnp.serialization.scpd import serialize_scpd_get from aioupnp.serialization.soap import serialize_soap_post, deserialize_soap_post_response @@ -13,87 +16,126 @@ from aioupnp.serialization.soap import serialize_soap_post, deserialize_soap_pos log = logging.getLogger(__name__) -class SCPDHTTPClientProtocol(Protocol): - POST = 'POST' - GET = 'GET' +HTTP_CODE_REGEX = re.compile(b"^HTTP[\/]{0,1}1\.[1|0] (\d\d\d)(.*)$") - def __init__(self, method: str, message: bytes, finished: asyncio.Future, soap_method: str=None, - soap_service_id: str=None, close_after_send: bool = False) -> None: - self.method = method - assert soap_service_id is not None and soap_method is not None if method == 'POST' else True, \ - 'soap args not provided' + +def parse_headers(response: bytes) -> typing.Tuple[OrderedDict, int, bytes]: + lines = response.split(b'\r\n') + headers = OrderedDict([ + (l.split(b':')[0], b':'.join(l.split(b':')[1:]).lstrip(b' ').rstrip(b' ')) + for l in response.split(b'\r\n') + ]) + if len(lines) != len(headers): + raise ValueError("duplicate headers") + http_response = tuple(headers.keys())[0] + response_code, message = HTTP_CODE_REGEX.findall(http_response)[0] + del headers[http_response] + return headers, int(response_code), message + + +class SCPDHTTPClientProtocol(Protocol): + """ + This class will make HTTP GET and POST requests + + It differs from spec HTTP in that the version string can be invalid, all we care about is the xml body + and devices respond with an invalid HTTP version line + """ + + def __init__(self, message: bytes, finished: asyncio.Future, soap_method: str=None, + soap_service_id: str=None) -> None: self.message = message self.response_buff = b"" self.finished = finished self.soap_method = soap_method self.soap_service_id = soap_service_id - self.close_after_send = close_after_send + + self._response_code: int = 0 + self._response_msg: bytes = b"" + self._content_length: int = 0 + self._got_headers = False + self._headers: dict = {} + self._body = b"" def connection_made(self, transport): transport.write(self.message) - if self.close_after_send: - self.finished.set_result(None) def data_received(self, data): self.response_buff += data - if self.method == self.GET: - try: - packet = deserialize_scpd_get_response(self.response_buff) - if packet: - self.finished.set_result(packet) - return - except ElementTree.ParseError: - pass - except UPnPError as err: - self.finished.set_exception(err) - elif self.method == self.POST: - try: - packet = deserialize_soap_post_response(self.response_buff, self.soap_method, self.soap_service_id) - if packet: - self.finished.set_result(packet) - return - except ElementTree.ParseError: - pass - except UPnPError as err: - self.finished.set_exception(err) + for i, line in enumerate(self.response_buff.split(b'\r\n')): + if not line: # we hit the blank line between the headers and the body + if i == (len(self.response_buff.split(b'\r\n')) - 1): + continue # the body is still yet to be written + if not self._got_headers: + self._headers, self._response_code, self._response_msg = parse_headers( + b'\r\n'.join(self.response_buff.split(b'\r\n')[:i]) + ) + content_length = get_dict_val_case_insensitive(self._headers, b'Content-Length') + if content_length is None: + return + self._content_length = int(content_length or 0) + self._got_headers = True + body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:]) + if self._content_length == len(body): + self.finished.set_result((body, self._response_code, self._response_msg)) + elif self._content_length > len(body): + pass + else: + self.finished.set_exception( + UPnPError( + "too many bytes written to response (%i vs %i expected)" % ( + len(body), self._content_length + ) + ) + ) + return -async def scpd_get(control_url: str, address: str, port: int) -> typing.Tuple[typing.Dict, bytes]: +async def scpd_get(control_url: str, address: str, port: int) -> typing.Tuple[typing.Dict, bytes, + typing.Optional[Exception]]: loop = asyncio.get_running_loop() finished: asyncio.Future = asyncio.Future() packet = serialize_scpd_get(control_url, address) transport, protocol = await loop.create_connection( - lambda : SCPDHTTPClientProtocol('GET', packet, finished), address, port + lambda : SCPDHTTPClientProtocol(packet, finished), address, port ) assert isinstance(protocol, SCPDHTTPClientProtocol) - parsed: typing.Dict = {} + error = None try: - parsed = await asyncio.wait_for(finished, 1.0) - except UPnPError: - return parsed, protocol.response_buff + body, response_code, response_msg = await asyncio.wait_for(finished, 1.0) + except asyncio.TimeoutError: + error = UPnPError("get request timed out") + body = b'' finally: transport.close() - return parsed, protocol.response_buff + if not error: + try: + return deserialize_scpd_get_response(body), body, None + except ElementTree.ParseError as err: + error = UPnPError(err) + return {}, body, error async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes, - close_after_send: bool, soap_socket: socket.socket = None, - **kwargs) -> typing.Tuple[typing.Dict, bytes]: + soap_socket: socket.socket = None, **kwargs) -> typing.Tuple[typing.Dict, bytes, + typing.Optional[Exception]]: loop = asyncio.get_running_loop() finished: asyncio.Future = asyncio.Future() packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs) transport, protocol = await loop.create_connection( lambda : SCPDHTTPClientProtocol( - 'POST', packet, finished, soap_method=method, soap_service_id=service_id.decode(), - close_after_send=close_after_send + packet, finished, soap_method=method, soap_service_id=service_id.decode(), ), address, port, sock=soap_socket ) assert isinstance(protocol, SCPDHTTPClientProtocol) - parsed: typing.Dict = {} try: - parsed = await asyncio.wait_for(finished, 1.0) - except UPnPError: - return parsed, protocol.response_buff + body, response_code, response_msg = await asyncio.wait_for(finished, 1.0) + except asyncio.TimeoutError: + return {}, b'', UPnPError("Timeout") finally: transport.close() - return parsed, protocol.response_buff + try: + return ( + deserialize_soap_post_response(body, method, service_id.decode()), body, None + ) + except (ElementTree.ParseError, UPnPError) as err: + return {}, body, UPnPError(err) diff --git a/aioupnp/protocols/soap.py b/aioupnp/protocols/soap.py index 89d3cc5..f230590 100644 --- a/aioupnp/protocols/soap.py +++ b/aioupnp/protocols/soap.py @@ -2,6 +2,7 @@ import logging import socket import asyncio import typing +import time from aioupnp.protocols.scpd import scpd_post from aioupnp.fault import UPnPError @@ -39,20 +40,21 @@ class SOAPCommand: async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]: if set(kwargs.keys()) != set(self.param_types.keys()): raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys())) - close_after_send = not self.return_types or self.return_types == [None] soap_kwargs = {n: safe_type(self.param_types[n])(kwargs[n]) for n in self.param_types.keys()} - try: - response, xml_bytes = await scpd_post( - self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, self.service_id, - close_after_send, self.soap_socket, **soap_kwargs - ) - except asyncio.TimeoutError as err: - raise UPnPError(err) - - self._requests.append((soap_kwargs, xml_bytes)) + response, xml_bytes, err = await scpd_post( + self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, + self.service_id, self.soap_socket, **soap_kwargs + ) + if err is not None: + self._requests.append((soap_kwargs, xml_bytes, None, err, time.time())) + raise err if not response: - return None - result = tuple([safe_type(self.return_types[n])(response.get(n)) for n in self.return_order]) - if len(result) == 1: - return result[0] + result = None + else: + recast_result = tuple([safe_type(self.return_types[n])(response.get(n)) for n in self.return_order]) + if len(recast_result) == 1: + result = recast_result[0] + else: + result = recast_result + self._requests.append((soap_kwargs, xml_bytes, result, None, time.time())) return result diff --git a/aioupnp/serialization/soap.py b/aioupnp/serialization/soap.py index 55ecfed..fd5b906 100644 --- a/aioupnp/serialization/soap.py +++ b/aioupnp/serialization/soap.py @@ -53,11 +53,12 @@ def deserialize_soap_post_response(response: bytes, method: str, service_id: str response_body = flatten_keys(envelope[BODY], "{%s}" % service_id) body = handle_fault(response_body) # raises UPnPError if there is a fault response_key = None - + if not body: + return {} for key in body: if method in key: response_key = key break if not response_key: - raise UPnPError("unknown response fields for %s") - return body[response_key] \ No newline at end of file + raise UPnPError("unknown response fields for %s: %s" % (method, body)) + return body[response_key] diff --git a/aioupnp/upnp.py b/aioupnp/upnp.py index cb26f0e..6d37c09 100644 --- a/aioupnp/upnp.py +++ b/aioupnp/upnp.py @@ -133,11 +133,11 @@ class UPnP: result = await self.gateway.commands.GetSpecificPortMappingEntry( NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol ) - if isinstance(self.gateway.commands.GetSpecificPortMappingEntry, SOAPCommand): + if result and isinstance(self.gateway.commands.GetSpecificPortMappingEntry, SOAPCommand): return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)} - return {} except UPnPError: - return {} + pass + return {} @cli async def delete_port_mapping(self, external_port: int, protocol: str) -> None: @@ -196,18 +196,42 @@ class UPnP: except (UPnPError, NotImplementedError): print("failed to get the external ip") try: - redirects = await self.get_redirects() - print("got redirects:\n%s" % redirects) + await self.get_redirects() + print("got redirects") except (UPnPError, NotImplementedError): print("failed to get redirects") + try: + await self.get_specific_port_mapping(4567, "UDP") + print("got specific mapping") + except (UPnPError, NotImplementedError): + print("failed to get specific mapping") try: ext_port = await self.get_next_mapping(4567, "UDP", "aioupnp test mapping") print("set up external mapping to port %i" % ext_port) + try: + await self.get_specific_port_mapping(4567, "UDP") + print("got specific mapping") + except (UPnPError, NotImplementedError): + print("failed to get specific mapping") + try: + await self.get_redirects() + print("got redirects") + except (UPnPError, NotImplementedError): + print("failed to get redirects") await self.delete_port_mapping(ext_port, "UDP") print("deleted mapping") except (UPnPError, NotImplementedError): print("failed to add and remove a mapping") - + try: + await self.get_redirects() + print("got redirects") + except (UPnPError, NotImplementedError): + print("failed to get redirects") + try: + await self.get_specific_port_mapping(4567, "UDP") + print("got specific mapping") + except (UPnPError, NotImplementedError): + print("failed to get specific mapping") if self.gateway.devices: device = list(self.gateway.devices.values())[0] assert device.manufacturer and device.modelName