diff --git a/aioupnp/constants.py b/aioupnp/constants.py index cdb0c52..63e40ed 100644 --- a/aioupnp/constants.py +++ b/aioupnp/constants.py @@ -22,8 +22,4 @@ SSDP_IP_ADDRESS = '239.255.255.250' SSDP_PORT = 1900 SSDP_HOST = "%s:%i" % (SSDP_IP_ADDRESS, SSDP_PORT) SSDP_DISCOVER = "ssdp:discover" -SSDP_ALL = "ssdp:all" -SSDP_BYEBYE = "ssdp:byebye" -SSDP_UPDATE = "ssdp:update" -SSDP_ROOT_DEVICE = "upnp:rootdevice" line_separator = "\r\n" diff --git a/aioupnp/gateway.py b/aioupnp/gateway.py index 03bc27c..87e5d78 100644 --- a/aioupnp/gateway.py +++ b/aioupnp/gateway.py @@ -1,13 +1,15 @@ import logging import socket +from collections import OrderedDict from typing import Dict, List, Union, Type from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX -from aioupnp.constants import SPEC_VERSION, UPNP_ORG_IGD, SERVICE +from aioupnp.constants import SPEC_VERSION, SERVICE from aioupnp.commands import SOAPCommands from aioupnp.device import Device, Service from aioupnp.protocols.ssdp import fuzzy_m_search from aioupnp.protocols.scpd import scpd_get -from aioupnp.protocols.soap import SCPDCommand +from aioupnp.protocols.soap import SOAPCommand +from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.util import flatten_keys from aioupnp.fault import UPnPError @@ -53,43 +55,36 @@ def get_action_list(element_dict: dict) -> List: # [(, [, ...], class Gateway: - def __init__(self, **kwargs): - flattened = { - k.lower(): v for k, v in kwargs.items() - } - usn = flattened["usn"] - server = flattened["server"] - location = flattened["location"] - st = flattened["st"] + def __init__(self, ok_packet: SSDPDatagram, m_search_args: OrderedDict, lan_address: str, + gateway_address: str) -> None: + self._ok_packet = ok_packet + self._m_search_args = m_search_args + self._lan_address = lan_address + self.usn = (ok_packet.usn or '').encode() + self.ext = (ok_packet.ext or '').encode() + self.server = (ok_packet.server or '').encode() + self.location = (ok_packet.location or '').encode() + self.cache_control = (ok_packet.cache_control or '').encode() + self.date = (ok_packet.date or '').encode() + self.urn = (ok_packet.st or '').encode() - cache_control = flattened.get("cache_control") or flattened.get("cache-control") or "" - date = flattened.get("date", "") - ext = flattened.get("ext", "") - - self.usn = usn.encode() - self.ext = ext.encode() - self.server = server.encode() - self.location = location.encode() - self.cache_control = cache_control.encode() - self.date = date.encode() - self.urn = st.encode() - - self._xml_response = "" - self._service_descriptors = {} + self._xml_response = b"" + self._service_descriptors: Dict = {} self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0] self.port = int(BASE_PORT_REGEX.findall(self.location)[0]) self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0] + assert self.base_ip == gateway_address.encode() self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1] self.spec_version = None self.url_base = None - self._device = None - self._devices = [] - self._services = [] + self._device: Union[None, Device] = None + self._devices: List = [] + self._services: List = [] - self._unsupported_actions = {} - self._registered_commands = {} + self._unsupported_actions: Dict = {} + self._registered_commands: Dict = {} self.commands = SOAPCommands() def gateway_descriptor(self) -> dict: @@ -103,6 +98,12 @@ class Gateway: } return r + @property + def manufacturer_string(self) -> str: + if not self._device: + raise NotImplementedError() + return "%s %s" % (self._device.manufacturer, self._device.modelName) + @property def services(self) -> Dict: if not self._device: @@ -121,23 +122,36 @@ class Gateway: return service return None - def debug_commands(self): + @property + def _soap_requests(self) -> Dict: return { - 'available': self._registered_commands, - 'failed': self._unsupported_actions + name: getattr(self.commands, name)._requests for name in self._registered_commands.keys() + } + + def debug_gateway(self) -> Dict: + return { + 'gateway_address': self.base_ip, + 'soap_port': self.port, + 'm_search_args': self._m_search_args, + 'reply': self._ok_packet.as_dict(), + 'registered_soap_commands': self._registered_commands, + 'unsupported_soap_commands': self._unsupported_actions, + 'gateway_xml': self._xml_response, + 'service_descriptors': self._service_descriptors, + 'soap_requests': self._soap_requests } @classmethod - async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 1, + async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, ssdp_socket: socket.socket = None, soap_socket: socket.socket = None): - datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket) - gateway = cls(**datagram.as_dict()) + m_search_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket) + gateway = cls(datagram, m_search_args, lan_address, gateway_address) await gateway.discover_commands(soap_socket) return gateway async def discover_commands(self, soap_socket: socket.socket = None): - response = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port) - + response, xml_bytes = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port) + self._xml_response = xml_bytes self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION) self.url_base = get_dict_val_case_insensitive(response, "urlbase") if not self.url_base: @@ -154,7 +168,9 @@ class Gateway: async def register_commands(self, service: Service, soap_socket: socket.socket = None): if not service.SCPDURL: raise UPnPError("no scpd url") - service_dict = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port) + service_dict, xml_bytes = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port) + self._service_descriptors[service.SCPDURL] = xml_bytes + if not service_dict: return @@ -176,7 +192,7 @@ class Gateway: if param_name == "return": continue param_types[param_name] = param_type - command = SCPDCommand( + command = SOAPCommand( self.base_ip.decode(), self.port, service.controlURL, service.serviceType.encode(), name, param_types, return_types, inputs, outputs, soap_socket) setattr(command, "__doc__", current.__doc__) diff --git a/aioupnp/protocols/m_search_patterns.py b/aioupnp/protocols/m_search_patterns.py index c3322fa..a657b81 100644 --- a/aioupnp/protocols/m_search_patterns.py +++ b/aioupnp/protocols/m_search_patterns.py @@ -1,81 +1,84 @@ -M_SEARCH_ARG_PATTERNS = [ - # - [ - ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), - ('ST', lambda s: s), - ('MAN', lambda s: '"%s"' % s), - ('MX', lambda n: int(n)), - ], - [ - ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), - ('ST', lambda s: s), - ('Man', lambda s: '"%s"' % s), - ('MX', lambda n: int(n)), - ], - [ - ('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), - ('ST', lambda s: s), - ('Man', lambda s: '"%s"' % s), - ('MX', lambda n: int(n)), - ], +""" +Alleged SSDP discovery documentation - # swap st and man - [ - ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), - ('MAN', lambda s: '"%s"' % s), - ('ST', lambda s: s), - ('MX', lambda n: int(n)), - ], - [ - ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), - ('Man', lambda s: '"%s"' % s), - ('ST', lambda s: s), - ('MX', lambda n: int(n)), - ], - [ - ('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), - ('Man', lambda s: '"%s"' % s), - ('ST', lambda s: s), - ('MX', lambda n: int(n)), - ], +M-SEARCH * HTTP/1.1 - # repeat above but with no quotes on man - [ - ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), - ('ST', lambda s: s), - ('MAN', lambda s: s), - ('MX', lambda n: int(n)), - ], - [ - ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), - ('ST', lambda s: s), - ('Man', lambda s: s), - ('MX', lambda n: int(n)), - ], - [ - ('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), - ('ST', lambda s: s), - ('Man', lambda s: s), - ('MX', lambda n: int(n)), - ], +Headers +HOST + Required. Multicast channel and port reserved for SSDP by Internet Assigned Numbers Authority (IANA). Must be + 239.255.255.250:1900. If the port number (“:1900”) is omitted, the receiver should assume the default SSDP port + number of 1900. +MAN + Required by HTTP Extension Framework. Unlike the NTS and ST headers, the value of the MAN header is enclosed in + double quotes; it defines the scope (namespace) of the extension. Must be "ssdp:discover". +MX + Required. Maximum wait time in seconds. Should be between 1 and 120 inclusive. Device responses should be delayed a + random duration between 0 and this many seconds to balance load for the control point when it processes responses. + This value may be increased if a large number of devices are expected to respond. The MX value should not be + increased to accommodate network characteristics such as latency or propagation delay (for more details, see the + explanation below). Specified by UPnP vendor. Integer. +ST + Required. Search Target. Must be one of the following. (cf. NT header in NOTIFY with ssdp:alive above.) Single URI. - # swap st and man - [ - ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), - ('MAN', lambda s: s), - ('ST', lambda s: s), - ('MX', lambda n: int(n)), - ], - [ - ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), - ('Man', lambda s: s), - ('ST', lambda s: s), - ('MX', lambda n: int(n)), - ], - [ - ('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), - ('Man', lambda s: s), - ('ST', lambda s: str(s)), - ('MX', lambda n: int(n)), - ], -] \ No newline at end of file + ssdp:all + Search for all devices and services. + + upnp:rootdevice + Search for root devices only. + + uuid:device-UUID + Search for a particular device. Device UUID specified by UPnP vendor. + + urn:schemas-upnp-org:device:deviceType:v + Search for any device of this type. Device type and version defined by UPnP Forum working committee. + + urn:schemas-upnp-org:service:serviceType:v + Search for any service of this type. Service type and version defined by UPnP Forum working committee. + + urn:domain-name:device:deviceType:v + Search for any device of this type. Domain name, device type and version defined by UPnP vendor. Period + characters in the domain name must be replaced with hyphens in accordance with RFC 2141. + + urn:domain-name:service:serviceType:v + Search for any service of this type. Domain name, service type and version defined by UPnP vendor. Period + characters in the domain name must be replaced with hyphens in accordance with RFC 2141. +""" + +from collections import OrderedDict +from aioupnp.constants import SSDP_DISCOVER, SSDP_HOST + +SEARCH_TARGETS = [ + 'ssdp:all' + 'urn:schemas-upnp-org:device:InternetGatewayDevice:1', + 'upnp:rootdevice', + 'urn:schemas-wifialliance-org:device:WFADevice:1', + 'urn:schemas-upnp-org:device:WANDevice:1', +] + + +def format_packet_args(order: list, **kwargs): + args = [] + for o in order: + for k, v in kwargs.items(): + if k.lower() == o.lower(): + args.append((k, v)) + break + return OrderedDict(args) + + +def packet_generator(): + for st in SEARCH_TARGETS: + order = ["HOST", "MAN", "MX", "ST"] + yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st) + yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st) + + yield format_packet_args(order, Host=SSDP_HOST, Man=SSDP_DISCOVER, MX=1, ST=st) + yield format_packet_args(order, Host=SSDP_HOST, Man='"%s"' % SSDP_DISCOVER, MX=1, ST=st) + + order = ["HOST", "MAN", "ST", "MX"] + yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st) + yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st) + + order = ["HOST", "ST", "MAN", "MX"] + yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st) + yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st) diff --git a/aioupnp/protocols/scpd.py b/aioupnp/protocols/scpd.py index f3cad9a..b10b6bb 100644 --- a/aioupnp/protocols/scpd.py +++ b/aioupnp/protocols/scpd.py @@ -1,5 +1,6 @@ import logging import socket +import typing from xml.etree import ElementTree import asyncio from asyncio.protocols import Protocol @@ -38,43 +39,46 @@ class SCPDHTTPClientProtocol(Protocol): if self.method == self.GET: try: packet = deserialize_scpd_get_response(self.response_buff) - if not packet: - return - except ElementTree.ParseError: - pass - except UPnPError as err: - self.finished.set_exception(err) - else: - self.finished.set_result(packet) - elif self.method == self.POST: - try: - packet = deserialize_soap_post_response(self.response_buff, self.soap_method, self.soap_service_id) - if not packet: + if packet: + self.finished.set_result(packet) + return + except ElementTree.ParseError: + pass + except UPnPError as err: + self.finished.set_exception(err) + elif self.method == self.POST: + try: + packet = deserialize_soap_post_response(self.response_buff, self.soap_method, self.soap_service_id) + if packet: self.finished.set_result(packet) return except ElementTree.ParseError: pass except UPnPError as err: self.finished.set_exception(err) - else: - self.finished.set_result(packet) -async def scpd_get(control_url: str, address: str, port: int) -> dict: +async def scpd_get(control_url: str, address: str, port: int) -> typing.Tuple[typing.Dict, bytes]: loop = asyncio.get_running_loop() finished: asyncio.Future = asyncio.Future() packet = serialize_scpd_get(control_url, address) transport, protocol = await loop.create_connection( lambda : SCPDHTTPClientProtocol('GET', packet, finished), address, port ) + assert isinstance(protocol, SCPDHTTPClientProtocol) + parsed: typing.Dict = {} try: - return await asyncio.wait_for(finished, 1.0) + parsed = await asyncio.wait_for(finished, 1.0) + except UPnPError: + return parsed, protocol.response_buff finally: transport.close() + return parsed, protocol.response_buff async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes, - close_after_send: bool, soap_socket: socket.socket = None, **kwargs): + close_after_send: bool, soap_socket: socket.socket = None, + **kwargs) -> typing.Tuple[typing.Dict, bytes]: loop = asyncio.get_running_loop() finished: asyncio.Future = asyncio.Future() packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs) @@ -84,7 +88,12 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para close_after_send=close_after_send ), address, port, sock=soap_socket ) + assert isinstance(protocol, SCPDHTTPClientProtocol) + parsed: typing.Dict = {} try: - return await asyncio.wait_for(finished, 1.0) + parsed = await asyncio.wait_for(finished, 1.0) + except UPnPError: + return parsed, protocol.response_buff finally: transport.close() + return parsed, protocol.response_buff diff --git a/aioupnp/protocols/soap.py b/aioupnp/protocols/soap.py index 4967769..812a21d 100644 --- a/aioupnp/protocols/soap.py +++ b/aioupnp/protocols/soap.py @@ -6,7 +6,7 @@ from aioupnp.protocols.scpd import scpd_post log = logging.getLogger(__name__) -class SCPDCommand: +class SOAPCommand: def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str, param_types: dict, return_types: dict, param_order: list, return_order: list, soap_socket: socket.socket = None) -> None: @@ -20,18 +20,21 @@ class SCPDCommand: self.return_types = return_types self.return_order = return_order self.soap_socket = soap_socket + self._requests: typing.List = [] async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]: if set(kwargs.keys()) != set(self.param_types.keys()): raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys())) close_after_send = not self.return_types or self.return_types == [None] - response = await scpd_post( + soap_kwargs = {n: self.param_types[n](kwargs[n]) for n in self.param_types.keys()} + response, xml_bytes = await scpd_post( self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, self.service_id, - close_after_send, self.soap_socket, **{n: self.param_types[n](kwargs[n]) for n in self.param_types.keys()} + close_after_send, self.soap_socket, **soap_kwargs ) - result = tuple([self.return_types[n](response.get(n)) for n in self.return_order]) - if not result: + self._requests.append((soap_kwargs, xml_bytes)) + if not response: return None + result = tuple([self.return_types[n](response.get(n)) for n in self.return_order]) if len(result) == 1: return result[0] return result diff --git a/aioupnp/protocols/ssdp.py b/aioupnp/protocols/ssdp.py index fb69540..dde6684 100644 --- a/aioupnp/protocols/ssdp.py +++ b/aioupnp/protocols/ssdp.py @@ -9,10 +9,9 @@ from asyncio.futures import Future from asyncio.transports import DatagramTransport from aioupnp.fault import UPnPError from aioupnp.serialization.ssdp import SSDPDatagram -from aioupnp.constants import UPNP_ORG_IGD, WIFI_ALLIANCE_ORG_IGD -from aioupnp.constants import SSDP_IP_ADDRESS, SSDP_PORT, SSDP_DISCOVER, SSDP_ROOT_DEVICE, SSDP_ALL +from aioupnp.constants import SSDP_IP_ADDRESS, SSDP_PORT from aioupnp.protocols.multicast import MulticastProtocol -from aioupnp.protocols.m_search_patterns import M_SEARCH_ARG_PATTERNS +from aioupnp.protocols.m_search_patterns import packet_generator ADDRESS_REGEX = re.compile("^http:\/\/(\d+\.\d+\.\d+\.\d+)\:(\d*)(\/[\w|\/|\:|\-|\.]*)$") @@ -23,21 +22,45 @@ class SSDPProtocol(MulticastProtocol): def __init__(self, multicast_address: str, lan_address: str) -> None: super().__init__(multicast_address, lan_address) self.lan_address = lan_address - self.discover_callbacks: Dict = {} + self._pending_searches: List[Tuple[str, str, Future, asyncio.Handle]] = [] + self.notifications: List = [] - self.replies: List = [] - def m_search(self, address: str, timeout: int, datagram_args: OrderedDict) -> Future: - packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram_args) - f: Future = Future() - futs = self.discover_callbacks.get((address, packet.st), []) - futs.append(f) - self.discover_callbacks[(address, packet.st)] = futs - log.debug("send m search to %s: %s", address, packet) - self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT)) + def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None: + tmp: list = [] + set_futures: list = [] + while self._pending_searches: + t: tuple = self._pending_searches.pop() + a, s = t[0], t[1] + if (address == a) and (s in [packet.st, "upnp:rootdevice"]): + f: Future = t[2] + h: asyncio.Handle = t[3] + h.cancel() + if f not in set_futures: + set_futures.append(f) + if not f.done(): + f.set_result(packet) + elif t[2] not in set_futures: + tmp.append(t) + while tmp: + self._pending_searches.append(tmp.pop()) - r: Future = asyncio.ensure_future(asyncio.wait_for(f, timeout)) - return r + def send_many_m_searches(self, address: str, packets: List[SSDPDatagram]): + for packet in packets: + log.debug("send m search to %s: %s", address, packet.st) + self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT)) + + async def m_search(self, address: str, timeout: float, datagrams: List[OrderedDict]) -> SSDPDatagram: + fut: Future = Future() + packets: List[SSDPDatagram] = [] + for datagram in datagrams: + packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram) + assert packet.st is not None + h = asyncio.get_running_loop().call_later(timeout, fut.cancel) + self._pending_searches.append((address, packet.st, fut, h)) + packets.append(packet) + self.send_many_m_searches(address, packets), + return await fut def datagram_received(self, data, addr) -> None: if addr[0] == self.lan_address: @@ -51,15 +74,8 @@ class SSDPProtocol(MulticastProtocol): return if packet._packet_type == packet._OK: - if (addr[0], packet.st) in self.discover_callbacks: - log.debug("%s:%i replied to our m-search", addr[0], addr[1]) - if packet.st not in map(lambda p: p['st'], self.replies): - self.replies.append(packet) - for ok_fut in self.discover_callbacks[(addr[0], packet.st)]: - ok_fut.set_result(packet) - del self.discover_callbacks[(addr[0], packet.st)] - return - + self._callback_m_search_ok(addr[0], packet) + return # elif packet._packet_type == packet._NOTIFY: # log.debug("%s:%i sent us a notification: %s", packet) # if packet.nt == SSDP_ROOT_DEVICE: @@ -109,46 +125,42 @@ async def m_search(lan_address: str, gateway_address: str, datagram_args: Ordere lan_address, gateway_address, ssdp_socket ) try: - return await protocol.m_search(address=gateway_address, timeout=timeout, datagram_args=datagram_args) - except asyncio.TimeoutError: + return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]) + except (asyncio.TimeoutError, asyncio.CancelledError): raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) finally: transport.close() -async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 1, - ssdp_socket: socket.socket = None) -> SSDPDatagram: +async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, + ssdp_socket: socket.socket = None) -> List[OrderedDict]: transport, protocol, gateway_address, lan_address = await listen_ssdp( lan_address, gateway_address, ssdp_socket ) - datagram_kwargs: list = [] - services = [UPNP_ORG_IGD, SSDP_ALL, WIFI_ALLIANCE_ORG_IGD] - mans = [SSDP_DISCOVER, SSDP_ROOT_DEVICE] - mx = 1 - - for service in services: - for man in mans: - for arg_pattern in M_SEARCH_ARG_PATTERNS: - dgram_kwargs: OrderedDict = OrderedDict() - for k, l in arg_pattern: - if k.lower() == 'host': - dgram_kwargs[k] = l(SSDP_IP_ADDRESS) - elif k.lower() == 'st': - dgram_kwargs[k] = l(service) - elif k.lower() == 'man': - dgram_kwargs[k] = l(man) - elif k.lower() == 'mx': - dgram_kwargs[k] = l(mx) - datagram_kwargs.append(dgram_kwargs) - - for i, args in enumerate(datagram_kwargs): + packet_args = list(packet_generator()) + batch_size = 2 + b = 0 + batch_timeout = float(timeout) / float(len(packet_args)) + while packet_args: + args = packet_args[:batch_size] + packet_args = packet_args[batch_size:] + log.debug("sending batch of %i M-SEARCH attempts", batch_size) try: - result = await protocol.m_search(address=gateway_address, timeout=timeout, datagram_args=args) - transport.close() - return result - except asyncio.TimeoutError: - pass - except Exception as err: - log.error(err) - transport.close() + await protocol.m_search(gateway_address, batch_timeout, args) + return args + except (asyncio.TimeoutError, asyncio.CancelledError): + b += 1 + continue raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) + + +async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, + ssdp_socket: socket.socket = None) -> Tuple[OrderedDict, SSDPDatagram]: + args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket) + for args in args_to_try: + try: + packet = await m_search(lan_address, gateway_address, args, 3) + return args, packet + except UPnPError: + continue + raise UPnPError("failed to discover gateway") diff --git a/aioupnp/serialization/ssdp.py b/aioupnp/serialization/ssdp.py index b9bfc25..103296e 100644 --- a/aioupnp/serialization/ssdp.py +++ b/aioupnp/serialization/ssdp.py @@ -9,7 +9,6 @@ from aioupnp.constants import line_separator log = logging.getLogger(__name__) - _template = "^(?i)(%s):[ ]*(.*)$" @@ -54,8 +53,8 @@ class SSDPDatagram(object): _M_SEARCH: [ 'host', 'man', - 'st', 'mx', + 'st', ], _NOTIFY: [ 'host', @@ -85,9 +84,9 @@ class SSDPDatagram(object): k.lower().replace("-", "_") for k in kwargs.keys() ] self.host = None - self.st = None self.man = None self.mx = None + self.st = None self.nt = None self.nts = None self.usn = None diff --git a/aioupnp/upnp.py b/aioupnp/upnp.py index ac12533..8bfa3b5 100644 --- a/aioupnp/upnp.py +++ b/aioupnp/upnp.py @@ -8,7 +8,8 @@ from typing import Tuple, Dict, List, Union from aioupnp.fault import UPnPError from aioupnp.gateway import Gateway from aioupnp.util import get_gateway_and_lan_addresses -from aioupnp.protocols.ssdp import m_search, fuzzy_m_search +from aioupnp.protocols.ssdp import m_search +from aioupnp.protocols.soap import SOAPCommand log = logging.getLogger(__name__) @@ -42,7 +43,7 @@ class UPnP: return lan_address, gateway_address @classmethod - async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, + async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 30, interface_name: str = 'default', ssdp_socket: socket.socket = None, soap_socket: socket.socket = None): try: @@ -77,7 +78,7 @@ class UPnP: await self.gateway.commands.AddPortMapping( NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol, NewInternalPort=internal_port, NewInternalClient=lan_address, - NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration="" + NewEnabled=True, NewPortMappingDescription=description, NewLeaseDuration="" ) return @@ -85,13 +86,14 @@ class UPnP: async def get_port_mapping_by_index(self, index: int) -> Dict: result = await self._get_port_mapping_by_index(index) if result: - return { - k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result) - } + if isinstance(self.gateway.commands.GetGenericPortMappingEntry, SOAPCommand): + return { + k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result) + } return {} - async def _get_port_mapping_by_index(self, index: int) -> Union[Tuple[str, int, str, int, str, bool, str, int], - None]: + async def _get_port_mapping_by_index(self, index: int) -> Union[None, + Tuple[Union[None, str], int, str, int, str, bool, str, int]]: try: redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index) return redirect @@ -119,9 +121,11 @@ class UPnP: try: result = await self.gateway.commands.GetSpecificPortMappingEntry( - NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol + NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol ) - return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)} + if isinstance(self.gateway.commands.GetSpecificPortMappingEntry, SOAPCommand): + return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)} + return {} except UPnPError: return {} @@ -175,49 +179,38 @@ class UPnP: @cli async def generate_test_data(self): - external_ip = await self.get_external_ip() - redirects = await self.get_redirects() - ext_port = await self.get_next_mapping(4567, "UDP", "aioupnp test mapping") - delete = await self.delete_port_mapping(ext_port, "UDP") - after_delete = await self.get_specific_port_mapping(ext_port, "UDP") + print("found gateway via M-SEARCH") + try: + external_ip = await self.get_external_ip() + print("got external ip: %s" % external_ip) + except UPnPError: + print("failed to get the external ip") + try: + redirects = await self.get_redirects() + print("got redirects:\n%s" % redirects) + except UPnPError: + print("failed to get redirects") - commands_test_case = ( - ("get_external_ip", (), "1.2.3.4"), - ("get_redirects", (), redirects), - ("get_next_mapping", (4567, "UDP", "aioupnp test mapping"), ext_port), - ("delete_port_mapping", (ext_port, "UDP"), delete), - ("get_specific_port_mapping", (ext_port, "UDP"), after_delete), - ) + try: + ext_port = await self.get_next_mapping(4567, "UDP", "aioupnp test mapping") + print("set up external mapping to port %i" % ext_port) + await self.delete_port_mapping(ext_port, "UDP") + print("deleted mapping") + except UPnPError: + print("failed to add and remove a mapping") - gateway = self.gateway - device = list(gateway.devices.values())[0] + device = list(self.gateway.devices.values())[0] assert device.manufacturer and device.modelName - device_path = os.path.join(os.getcwd(), "%s %s" % (device.manufacturer, device.modelName)) - commands = gateway.debug_commands() + device_path = os.path.join(os.getcwd(), self.gateway.manufacturer_string) with open(device_path, "w") as f: f.write(json.dumps({ - "router_address": self.gateway_address, + "gateway": self.gateway.debug_gateway(), "client_address": self.lan_address, - "port": gateway.port, - "gateway_dict": gateway.gateway_descriptor(), - 'expected_devices': [ - { - 'cache_control': 'max-age=1800', - 'location': gateway.location, - 'server': gateway.server, - 'st': gateway.urn, - 'usn': gateway.usn - } - ], - 'commands': commands, - # 'ssdp': u.sspd_factory.get_ssdp_packet_replay(), - # 'scpd': gateway.requester.dump_packets(), - 'soap': commands_test_case - }, default=_encode, indent=2).replace(external_ip, "1.2.3.4")) + }, default=_encode, indent=2)) return "Generated test data! -> %s" % device_path @classmethod - def run_cli(cls, method, igd_args: OrderedDict, lan_address: str = '', gateway_address: str = '', timeout: int = 60, + def run_cli(cls, method, igd_args: OrderedDict, lan_address: str = '', gateway_address: str = '', timeout: int = 30, interface_name: str = 'default', kwargs: dict = None) -> None: kwargs = kwargs or {} igd_args = igd_args @@ -257,6 +250,9 @@ class UPnP: log.exception("uncaught error") fut.set_exception(UPnPError("uncaught error: %s" % str(err))) + if not hasattr(UPnP, method) or not hasattr(getattr(UPnP, method), "_cli"): + fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method)) + wrapper = lambda : None asyncio.run(wrapper()) try: result = fut.result()