From 0956d6a71ee2885223bbf329f90646579b5efabd Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Wed, 10 Oct 2018 19:39:45 -0400 Subject: [PATCH] fuzzy m search --- aioupnp/__main__.py | 18 ++-- aioupnp/gateway.py | 5 +- aioupnp/protocols/m_search_patterns.py | 81 ++++++++++++++++ aioupnp/protocols/ssdp.py | 129 ++++++++++++++----------- aioupnp/upnp.py | 22 +++-- 5 files changed, 177 insertions(+), 78 deletions(-) create mode 100644 aioupnp/protocols/m_search_patterns.py diff --git a/aioupnp/__main__.py b/aioupnp/__main__.py index 392747f..d219692 100644 --- a/aioupnp/__main__.py +++ b/aioupnp/__main__.py @@ -1,6 +1,8 @@ import logging import sys +from collections import OrderedDict from aioupnp.upnp import UPnP +from aioupnp.constants import UPNP_ORG_IGD, SSDP_DISCOVER, SSDP_HOST log = logging.getLogger("aioupnp") handler = logging.StreamHandler() @@ -45,13 +47,14 @@ def main(): 'gateway_address': '', 'lan_address': '', 'timeout': 1, - 'service': '', # if not provided try all of them - 'man': '', - 'mx': 1, - 'return_as_json': True + + 'HOST': SSDP_HOST, + 'ST': UPNP_ORG_IGD, + 'MAN': SSDP_DISCOVER, + 'MX': 1, } - options = {} + options = OrderedDict() command = None for arg in args: if arg.startswith("--"): @@ -80,9 +83,8 @@ def main(): log.setLevel(logging.DEBUG) UPnP.run_cli( - command.replace('-', '_'), options.pop('lan_address'), options.pop('gateway_address'), - options.pop('timeout'), options.pop('service'), options.pop('man'), options.pop('mx'), - options.pop('interface'), kwargs + command.replace('-', '_'), options, options.pop('lan_address'), options.pop('gateway_address'), + options.pop('timeout'), options.pop('interface'), kwargs ) diff --git a/aioupnp/gateway.py b/aioupnp/gateway.py index 42ce641..03bc27c 100644 --- a/aioupnp/gateway.py +++ b/aioupnp/gateway.py @@ -5,7 +5,7 @@ from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_AD from aioupnp.constants import SPEC_VERSION, UPNP_ORG_IGD, SERVICE from aioupnp.commands import SOAPCommands from aioupnp.device import Device, Service -from aioupnp.protocols.ssdp import m_search +from aioupnp.protocols.ssdp import fuzzy_m_search from aioupnp.protocols.scpd import scpd_get from aioupnp.protocols.soap import SCPDCommand from aioupnp.util import flatten_keys @@ -129,9 +129,8 @@ class Gateway: @classmethod async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 1, - service: str = UPNP_ORG_IGD, man: str = '', mx: int = 1, ssdp_socket: socket.socket = None, soap_socket: socket.socket = None): - datagram = await m_search(lan_address, gateway_address, timeout, service, man, mx, ssdp_socket) + datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket) gateway = cls(**datagram.as_dict()) await gateway.discover_commands(soap_socket) return gateway diff --git a/aioupnp/protocols/m_search_patterns.py b/aioupnp/protocols/m_search_patterns.py new file mode 100644 index 0000000..c3322fa --- /dev/null +++ b/aioupnp/protocols/m_search_patterns.py @@ -0,0 +1,81 @@ +M_SEARCH_ARG_PATTERNS = [ + # + [ + ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), + ('ST', lambda s: s), + ('MAN', lambda s: '"%s"' % s), + ('MX', lambda n: int(n)), + ], + [ + ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), + ('ST', lambda s: s), + ('Man', lambda s: '"%s"' % s), + ('MX', lambda n: int(n)), + ], + [ + ('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), + ('ST', lambda s: s), + ('Man', lambda s: '"%s"' % s), + ('MX', lambda n: int(n)), + ], + + # swap st and man + [ + ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), + ('MAN', lambda s: '"%s"' % s), + ('ST', lambda s: s), + ('MX', lambda n: int(n)), + ], + [ + ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), + ('Man', lambda s: '"%s"' % s), + ('ST', lambda s: s), + ('MX', lambda n: int(n)), + ], + [ + ('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), + ('Man', lambda s: '"%s"' % s), + ('ST', lambda s: s), + ('MX', lambda n: int(n)), + ], + + # repeat above but with no quotes on man + [ + ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), + ('ST', lambda s: s), + ('MAN', lambda s: s), + ('MX', lambda n: int(n)), + ], + [ + ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), + ('ST', lambda s: s), + ('Man', lambda s: s), + ('MX', lambda n: int(n)), + ], + [ + ('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), + ('ST', lambda s: s), + ('Man', lambda s: s), + ('MX', lambda n: int(n)), + ], + + # swap st and man + [ + ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), + ('MAN', lambda s: s), + ('ST', lambda s: s), + ('MX', lambda n: int(n)), + ], + [ + ('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), + ('Man', lambda s: s), + ('ST', lambda s: s), + ('MX', lambda n: int(n)), + ], + [ + ('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), + ('Man', lambda s: s), + ('ST', lambda s: str(s)), + ('MX', lambda n: int(n)), + ], +] \ No newline at end of file diff --git a/aioupnp/protocols/ssdp.py b/aioupnp/protocols/ssdp.py index 6f801d7..c7d0c82 100644 --- a/aioupnp/protocols/ssdp.py +++ b/aioupnp/protocols/ssdp.py @@ -3,14 +3,16 @@ import socket import binascii import asyncio import logging +from collections import OrderedDict from typing import Dict, List, Tuple from asyncio.futures import Future from asyncio.transports import DatagramTransport from aioupnp.fault import UPnPError from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.constants import UPNP_ORG_IGD, WIFI_ALLIANCE_ORG_IGD -from aioupnp.constants import SSDP_IP_ADDRESS, SSDP_PORT, SSDP_DISCOVER, SSDP_ROOT_DEVICE +from aioupnp.constants import SSDP_IP_ADDRESS, SSDP_PORT, SSDP_DISCOVER, SSDP_ROOT_DEVICE, SSDP_ALL from aioupnp.protocols.multicast import MulticastProtocol +from aioupnp.protocols.m_search_patterns import M_SEARCH_ARG_PATTERNS ADDRESS_REGEX = re.compile("^http:\/\/(\d+\.\d+\.\d+\.\d+)\:(\d*)(\/[\w|\/|\:|\-|\.]*)$") @@ -25,41 +27,17 @@ class SSDPProtocol(MulticastProtocol): self.notifications: List = [] self.replies: List = [] - def send_m_search_packet(self, service, address, man, mx): - packet = SSDPDatagram( - SSDPDatagram._M_SEARCH, host="{}:{}".format(SSDP_IP_ADDRESS, SSDP_PORT), st=service, - man=man, mx=mx - ) - log.debug("sending packet to %s:%i: %s", address, SSDP_PORT, packet) + def m_search(self, address: str, timeout: int, datagram_args: OrderedDict) -> Future: + packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram_args) + f: Future = Future() + futs = self.discover_callbacks.get((address, packet.st), []) + futs.append(f) + self.discover_callbacks[(address, packet.st)] = futs + log.debug("send m search to %s: %s", address, packet) self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT)) - async def m_search(self, address: str, timeout: int = 1, service: str = '', man: str = '', mx: int = 1) -> SSDPDatagram: - if (address, service) in self.discover_callbacks: - return self.discover_callbacks[(address, service)] - man = man or SSDP_DISCOVER - if not service: - services = [UPNP_ORG_IGD, WIFI_ALLIANCE_ORG_IGD] - else: - services = [service] - - search_futs: List[Future] = [] - outer_fut: Future = Future() - - for service in services: - # D-Link works with both - - # Cisco only works with quotes - self.send_m_search_packet(service, address, '\"%s\"' % man, mx) - - # DD-WRT only works without quotes - self.send_m_search_packet(service, address, man, mx) - - f: Future = Future() - f.add_done_callback(lambda _f: outer_fut.set_result(_f.result())) - self.discover_callbacks[(address, service)] = f - search_futs.append(f) - - return await asyncio.wait_for(outer_fut, timeout) + r: Future = asyncio.ensure_future(asyncio.wait_for(f, timeout)) + return r def datagram_received(self, data, addr) -> None: if addr[0] == self.lan_address: @@ -77,28 +55,29 @@ class SSDPProtocol(MulticastProtocol): log.debug("%s:%i replied to our m-search", addr[0], addr[1]) if packet.st not in map(lambda p: p['st'], self.replies): self.replies.append(packet) - ok_fut: Future = self.discover_callbacks.pop((addr[0], packet.st)) - ok_fut.set_result(packet) + for ok_fut in self.discover_callbacks[(addr[0], packet.st)]: + ok_fut.set_result(packet) + del self.discover_callbacks[(addr[0], packet.st)] return - elif packet._packet_type == packet._NOTIFY: - log.debug("%s:%i sent us a notification: %s", packet) - if packet.nt == SSDP_ROOT_DEVICE: - address, port, path = ADDRESS_REGEX.findall(packet.location)[0] - key = None - for (addr, service) in self.discover_callbacks: - if addr == address: - key = (addr, service) - break - if key: - log.debug("got a notification with the requested m-search info") - notify_fut: Future = self.discover_callbacks.pop(key) - notify_fut.set_result(SSDPDatagram( - SSDPDatagram._OK, cache_control='', location=packet.location, server=packet.server, - st=UPNP_ORG_IGD, usn=packet.usn - )) - self.notifications.append(packet.as_dict()) - return + # elif packet._packet_type == packet._NOTIFY: + # log.debug("%s:%i sent us a notification: %s", packet) + # if packet.nt == SSDP_ROOT_DEVICE: + # address, port, path = ADDRESS_REGEX.findall(packet.location)[0] + # key = None + # for (addr, service) in self.discover_callbacks: + # if addr == address: + # key = (addr, service) + # break + # if key: + # log.debug("got a notification with the requested m-search info") + # notify_fut: Future = self.discover_callbacks.pop(key) + # notify_fut.set_result(SSDPDatagram( + # SSDPDatagram._OK, cache_control='', location=packet.location, server=packet.server, + # st=UPNP_ORG_IGD, usn=packet.usn + # )) + # self.notifications.append(packet.as_dict()) + # return async def listen_ssdp(lan_address: str, gateway_address: str, @@ -124,14 +103,50 @@ async def listen_ssdp(lan_address: str, gateway_address: str, return transport, protocol, gateway_address, lan_address -async def m_search(lan_address: str, gateway_address: str, timeout: int = 1, - service: str = '', man: str = '', mx: int = 1, ssdp_socket: socket.socket = None) -> SSDPDatagram: +async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1, + ssdp_socket: socket.socket = None) -> SSDPDatagram: transport, protocol, gateway_address, lan_address = await listen_ssdp( lan_address, gateway_address, ssdp_socket ) try: - return await protocol.m_search(address=gateway_address, timeout=timeout, service=service, man=man, mx=mx) + return await protocol.m_search(address=gateway_address, timeout=timeout, datagram_args=datagram_args) except asyncio.TimeoutError: raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) finally: transport.close() + + +async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 1, + ssdp_socket: socket.socket = None) -> SSDPDatagram: + transport, protocol, gateway_address, lan_address = await listen_ssdp( + lan_address, gateway_address, ssdp_socket + ) + datagram_kwargs: list = [] + services = [UPNP_ORG_IGD, SSDP_ALL, WIFI_ALLIANCE_ORG_IGD] + mans = [SSDP_DISCOVER, SSDP_ROOT_DEVICE] + mx = 1 + + for service in services: + for man in mans: + for arg_pattern in M_SEARCH_ARG_PATTERNS: + dgram_kwargs: OrderedDict = OrderedDict() + for k, l in arg_pattern: + if k.lower() == 'host': + dgram_kwargs[k] = l(SSDP_IP_ADDRESS) + elif k.lower() == 'st': + dgram_kwargs[k] = l(service) + elif k.lower() == 'man': + dgram_kwargs[k] = l(man) + elif k.lower() == 'mx': + dgram_kwargs[k] = l(mx) + datagram_kwargs.append(dgram_kwargs) + + for i, args in enumerate(datagram_kwargs): + try: + result = await protocol.m_search(address=gateway_address, timeout=timeout, datagram_args=args) + transport.close() + return result + except TimeoutError: + pass + transport.close() + raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) diff --git a/aioupnp/upnp.py b/aioupnp/upnp.py index 8a3b49d..ac12533 100644 --- a/aioupnp/upnp.py +++ b/aioupnp/upnp.py @@ -3,11 +3,12 @@ import socket import logging import json import asyncio +from collections import OrderedDict from typing import Tuple, Dict, List, Union from aioupnp.fault import UPnPError from aioupnp.gateway import Gateway from aioupnp.util import get_gateway_and_lan_addresses -from aioupnp.protocols.ssdp import m_search +from aioupnp.protocols.ssdp import m_search, fuzzy_m_search log = logging.getLogger(__name__) @@ -42,23 +43,24 @@ class UPnP: @classmethod async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, - service: str = '', man: str = '', mx: int = 1, interface_name: str = 'default', + interface_name: str = 'default', ssdp_socket: socket.socket = None, soap_socket: socket.socket = None): try: lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) except Exception as err: raise UPnPError("failed to get lan and gateway addresses: %s" % str(err)) gateway = await Gateway.discover_gateway( - lan_address, gateway_address, timeout, service, man, mx, ssdp_socket, soap_socket + lan_address, gateway_address, timeout, ssdp_socket, soap_socket ) return cls(lan_address, gateway_address, gateway) @classmethod @cli async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, - service: str = '', man: str = '', mx: int = 1, interface_name: str = 'default') -> Dict: + args: OrderedDict = None, interface_name: str = 'default') -> Dict: + args = args or OrderedDict() lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) - datagram = await m_search(lan_address, gateway_address, timeout, service, man, mx) + datagram = await m_search(lan_address, gateway_address, args, timeout) return { 'lan_address': lan_address, 'gateway_address': gateway_address, @@ -215,10 +217,10 @@ class UPnP: return "Generated test data! -> %s" % device_path @classmethod - def run_cli(cls, method, lan_address: str = '', gateway_address: str = '', timeout: int = 60, - service: str = '', man: str = '', mx: int = 1, interface_name: str = 'default', - kwargs: dict = None) -> None: + def run_cli(cls, method, igd_args: OrderedDict, lan_address: str = '', gateway_address: str = '', timeout: int = 60, + interface_name: str = 'default', kwargs: dict = None) -> None: kwargs = kwargs or {} + igd_args = igd_args timeout = int(timeout) try: asyncio.get_running_loop() @@ -231,12 +233,12 @@ class UPnP: async def wrapper(): if method == 'm_search': fn = lambda *_a, **_kw: cls.m_search( - lan_address, gateway_address, timeout, service, man, mx, interface_name + lan_address, gateway_address, timeout, igd_args, interface_name ) else: try: u = await cls.discover( - lan_address, gateway_address, timeout, service, man, mx, interface_name + lan_address, gateway_address, timeout, interface_name ) except UPnPError as err: fut.set_exception(err)