diff --git a/txupnp/cli.py b/txupnp/cli.py index 38746f3..acfc5cf 100644 --- a/txupnp/cli.py +++ b/txupnp/cli.py @@ -38,9 +38,13 @@ def main(): if command not in ['debug_device', 'list_mappings']: return sys.exit(0) + def show(err): + print("error: {}".format(err)) + u = UPnP(reactor) d = u.discover() d.addCallback(run_command, u, command) + d.addErrback(show) d.addBoth(lambda _: reactor.callLater(0, reactor.stop)) reactor.run() diff --git a/txupnp/constants.py b/txupnp/constants.py index 0c47477..8aad336 100644 --- a/txupnp/constants.py +++ b/txupnp/constants.py @@ -10,16 +10,22 @@ BODY = "{http://schemas.xmlsoap.org/soap/envelope/}Body" CONTROL = 'urn:schemas-upnp-org:control-1-0' SERVICE = 'urn:schemas-upnp-org:service-1-0' DEVICE = 'urn:schemas-upnp-org:device-1-0' -GATEWAY_SCHEMA = 'urn:schemas-upnp-org:device:InternetGatewayDevice:1' + +WIFI_ALLIANCE_ORG_IGD = "urn:schemas-wifialliance-org:device:WFADevice:1" +UPNP_ORG_IGD = 'urn:schemas-upnp-org:device:InternetGatewayDevice:1' + WAN_SCHEMA = 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1' LAYER_SCHEMA = 'urn:schemas-upnp-org:service:Layer3Forwarding:1' IP_SCHEMA = 'urn:schemas-upnp-org:service:WANIPConnection:1' service_types = [ - GATEWAY_SCHEMA, + UPNP_ORG_IGD, + WIFI_ALLIANCE_ORG_IGD, WAN_SCHEMA, LAYER_SCHEMA, IP_SCHEMA, + + CONTROL, SERVICE, DEVICE, diff --git a/txupnp/gateway.py b/txupnp/gateway.py index df757a3..44a0ac9 100644 --- a/txupnp/gateway.py +++ b/txupnp/gateway.py @@ -1,90 +1,120 @@ import logging from twisted.internet import defer import treq +import re from xml.etree import ElementTree -from txupnp.util import etree_to_dict, flatten_keys +from txupnp.util import etree_to_dict, flatten_keys, get_dict_val_case_insensitive from txupnp.util import BASE_PORT_REGEX, BASE_ADDRESS_REGEX from txupnp.constants import DEVICE, ROOT from txupnp.constants import SPEC_VERSION log = logging.getLogger(__name__) +service_type_pattern = re.compile( + "(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\})" +) -class Service(object): - def __init__(self, serviceType, serviceId, SCPDURL, eventSubURL, controlURL): - self.service_type = serviceType - self.service_id = serviceId - self.control_path = controlURL - self.subscribe_path = eventSubURL - self.scpd_path = SCPDURL +xml_root_sanity_pattern = re.compile( + "(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\}([\w|\d|\:|\-|\_]*))" +) - def get_info(self): + +class CaseInsensitive(object): + def __init__(self, **kwargs): + not_evaluated = {} + for k, v in kwargs.items(): + if k.startswith("_"): + not_evaluated[k] = v + continue + try: + getattr(self, k) + setattr(self, k, v) + except AttributeError as err: + not_evaluated[k] = v + if not_evaluated: + log.error("%s did not apply kwargs: %s", self.__class__.__name__, not_evaluated) + + def _get_attr_name(self, case_insensitive): + for k, v in self.__dict__.items(): + if k.lower() == case_insensitive.lower(): + return k + + def __getattr__(self, item): + if item in self.__dict__: + return self.__dict__[item] + for k, v in self.__class__.__dict__.items(): + if k.lower() == item.lower(): + if k not in self.__dict__: + self.__dict__[k] = v + return v + raise AttributeError(item) + + def __setattr__(self, item, value): + if item in self.__dict__: + self.__dict__[item] = value + return + to_update = None + for k, v in self.__dict__.items(): + if k.lower() == item.lower(): + to_update = k + break + self.__dict__[to_update or item] = value + + def as_dict(self): return { - "service_type": self.service_type, - "service_id": self.service_id, - "control_path": self.control_path, - "subscribe_path": self.subscribe_path, - "scpd_path": self.scpd_path + k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v) } -class Device(object): - def __init__(self, _root_device, deviceType=None, friendlyName=None, manufacturer=None, manufacturerURL=None, - modelDescription=None, modelName=None, modelNumber=None, modelURL=None, serialNumber=None, - UDN=None, serviceList=None, deviceList=None, **kwargs): - serviceList = serviceList or {} - deviceList = deviceList or {} - self._root_device = _root_device - self.device_type = deviceType - self.friendly_name = friendlyName - self.manufacturer = manufacturer - self.manufacturer_url = manufacturerURL - self.model_description = modelDescription - self.model_name = modelName - self.model_number = modelNumber - self.model_url = modelURL - self.serial_number = serialNumber - self.udn = UDN - services = serviceList["service"] - if isinstance(services, dict): - services = [services] - services = [Service(**service) for service in services] - self._root_device.services.extend(services) - devices = [Device(self._root_device, **deviceList[k]) for k in deviceList] - self._root_device.devices.extend(devices) - - def get_info(self): - return { - 'device_type': self.device_type, - 'friendly_name': self.friendly_name, - 'manufacturers': self.manufacturer, - 'model_name': self.model_name, - 'model_number': self.model_number, - 'serial_number': self.serial_number, - 'udn': self.udn - } +class Service(CaseInsensitive): + serviceType = None + serviceId = None + controlURL = None + eventSubURL = None + SCPDURL = None -class RootDevice(object): - def __init__(self, xml_string): - try: - root = flatten_keys(etree_to_dict(ElementTree.fromstring(xml_string)), "{%s}" % DEVICE)[ROOT] - except Exception as err: - if xml_string: - log.exception("failed to decode xml: %s\n%s", err, xml_string) - root = {} - self.spec_version = root.get(SPEC_VERSION) - self.url_base = root.get("URLBase") - self.devices = [] - self.services = [] - if root: - root_device = Device(self, **(root["device"])) - self.devices.append(root_device) - log.debug("finished setting up root gateway. %i devices and %i services", len(self.devices), len(self.services)) +class Device(CaseInsensitive): + serviceList = None + deviceList = None + deviceType = None + friendlyName = None + manufacturer = None + manufacturerURL = None + modelDescription = None + modelName = None + modelNumber = None + modelURL = None + serialNumber = None + udn = None + presentationURL = None + iconList = None + + def __init__(self, devices, services, **kwargs): + super(Device, self).__init__(**kwargs) + if self.serviceList and "service" in self.serviceList: + new_services = self.serviceList["service"] + if isinstance(new_services, dict): + new_services = [new_services] + services.extend([Service(**service) for service in new_services]) + if self.deviceList: + devices.extend([Device(devices, services, **kw) for kw in self.deviceList.values()]) class Gateway(object): - def __init__(self, usn, server, location, st, cache_control="", date="", ext=""): + def __init__(self, **kwargs): + flattened = { + k.lower(): v for k, v in kwargs.items() + } + usn = flattened["usn"] + server = flattened["server"] + location = flattened["location"] + st = flattened["st"] + + cache_control = flattened.get("cache_control") or flattened.get("cache-control") or "" + date = flattened.get("date", "") + ext = flattened.get("ext", "") + self.usn = usn.encode() self.ext = ext.encode() self.server = server.encode() @@ -92,54 +122,79 @@ class Gateway(object): self.cache_control = cache_control.encode() self.date = date.encode() self.urn = st.encode() + self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0] self.port = int(BASE_PORT_REGEX.findall(self.location)[0]) - self._device = None + self.xml_response = None + self.spec_version = None + self.url_base = None - def debug_device(self): - devices = [] - for device in self._device.devices: - info = device.get_info() - devices.append(info) - services = [] - for service in self._device.services: - info = service.get_info() - services.append(info) - return { - 'root_url': self.base_address, - 'gateway_xml_url': self.location, + self._device = None + self._devices = [] + self._services = [] + + def debug_device(self, include_xml=False, include_services=True): + r = { + 'server': self.server, + 'urlBase': self.url_base, + 'location': self.location, + "specVersion": self.spec_version, 'usn': self.usn, - 'devices': devices, - 'services': services + 'urn': self.urn, } + if include_xml: + r['xml_response'] = self.xml_response + if include_services: + r['services'] = [service.as_dict() for service in self._services] + + return r @defer.inlineCallbacks def discover_services(self): log.debug("querying %s", self.location) response = yield treq.get(self.location) - response_xml = yield response.content() - if not response_xml: + self.xml_response = yield response.content() + if not self.xml_response: log.error("service sent an empty reply\n%s", self.debug_device()) - try: - self._device = RootDevice(response_xml) - except Exception as err: - log.error("error parsing gateway: %s\n%s\n\n%s", err, self.debug_device(), response_xml) - self._device = RootDevice("") + xml_dict = etree_to_dict(ElementTree.fromstring(self.xml_response)) + schema_key = DEVICE + root = ROOT + if len(xml_dict) > 1: + log.warning(xml_dict.keys()) + for k in xml_dict.keys(): + m = xml_root_sanity_pattern.findall(k) + if len(m) == 3 and m[1][0] and m[2][5]: + schema_key = m[1][0] + root = m[2][5] + break + + flattened_xml = flatten_keys(xml_dict, "{%s}" % schema_key)[root] + self.spec_version = get_dict_val_case_insensitive(flattened_xml, SPEC_VERSION) + self.url_base = get_dict_val_case_insensitive(flattened_xml, "urlbase") + + if flattened_xml: + self._device = Device( + self._devices, self._services, **get_dict_val_case_insensitive(flattened_xml, "device") + ) + log.debug("finished setting up root gateway. %i devices and %i services", len(self.devices), + len(self.services)) + else: + self._device = Device(self._devices, self._services) log.debug("finished setting up gateway:\n%s", self.debug_device()) @property def services(self): if not self._device: return {} - return {service.service_type: service for service in self._device.services} + return {service.serviceType: service for service in self._services} @property def devices(self): if not self._device: return {} - return {device.udn: device for device in self._device.devices} + return {device.udn: device for device in self._devices} def get_service(self, service_type): - for service in self._device.services: - if service.service_type.lower() == service_type.lower(): + for service in self._services: + if service.serviceType.lower() == service_type.lower(): return service diff --git a/txupnp/scpd.py b/txupnp/scpd.py index 4f80c05..c9e239d 100644 --- a/txupnp/scpd.py +++ b/txupnp/scpd.py @@ -139,14 +139,14 @@ class SCPDCommandRunner(object): @defer.inlineCallbacks def _discover_commands(self, service): - scpd_url = self._gateway.base_address + service.scpd_path.encode() + scpd_url = self._gateway.base_address + service.SCPDURL.encode() response = yield treq.get(scpd_url) content = yield response.content() try: scpd_response = SCPDResponse(scpd_url, response.headers, content) for action_dict in scpd_response.get_action_list(): - self._register_command(action_dict, service.service_type) + self._register_command(action_dict, service.serviceType) except Exception as err: log.exception("failed to parse scpd response (%s) from %s\nheaders:\n%s\ncontent\n%s", err, scpd_url, response.headers, content) @@ -182,8 +182,8 @@ class SCPDCommandRunner(object): def _patch_command(self, action_info, service_type): name, inputs, outputs = self._soap_function_info(action_info) command = _SCPDCommand(self._gateway.base_address, self._gateway.port, - self._gateway.base_address + self._gateway.get_service(service_type).control_path.encode(), - self._gateway.get_service(service_type).service_id.encode(), name, inputs, outputs, + self._gateway.base_address + self._gateway.get_service(service_type).controlURL.encode(), + self._gateway.get_service(service_type).serviceId.encode(), name, inputs, outputs, self._reactor, self._connection_pool, self._agent, self._http_client) current = getattr(self, command.method) if hasattr(current, "_return_types"): @@ -336,9 +336,9 @@ class UPnPFallback(object): raise NotImplementedError() devices = yield threads.deferToThread(self._upnp.discover) if devices: - device_url = yield threads.deferToThread(self._upnp.selectigd) + self.device_url = yield threads.deferToThread(self._upnp.selectigd) else: - device_url = None + self.device_url = None defer.returnValue(devices > 0) diff --git a/txupnp/soap.py b/txupnp/soap.py index d1bb0ce..2c4b1ea 100644 --- a/txupnp/soap.py +++ b/txupnp/soap.py @@ -5,7 +5,7 @@ from txupnp.ssdp import SSDPFactory from txupnp.scpd import SCPDCommandRunner from txupnp.gateway import Gateway from txupnp.fault import UPnPError -from txupnp.constants import GATEWAY_SCHEMA +from txupnp.constants import UPNP_ORG_IGD log = logging.getLogger(__name__) @@ -16,7 +16,7 @@ class SOAPServiceManager(object): self.iface_name, self.router_ip, self.lan_address = get_lan_info() self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip) self._command_runners = {} - self._selected_runner = GATEWAY_SCHEMA + self._selected_runner = UPNP_ORG_IGD @defer.inlineCallbacks def discover_services(self, address=None, timeout=30, max_devices=1): @@ -57,6 +57,21 @@ class SOAPServiceManager(object): for runner in self._command_runners.values(): gateway = runner._gateway info = gateway.debug_device() - info.update(runner.debug_commands()) + commands = runner.debug_commands() + service_result = [] + for service in info['services']: + service_commands = [] + unavailable = [] + for command, service_type in commands['available'].items(): + if service['serviceType'] == service_type: + service_commands.append(command) + for command, service_type in commands['failed'].items(): + if service['serviceType'] == service_type: + unavailable.append(command) + services_with_commands = dict(service) + services_with_commands['available_commands'] = service_commands + services_with_commands['unavailable_commands'] = unavailable + service_result.append(services_with_commands) + info['services'] = service_result results.append(info) return results diff --git a/txupnp/ssdp.py b/txupnp/ssdp.py index 61f7a31..48f513d 100644 --- a/txupnp/ssdp.py +++ b/txupnp/ssdp.py @@ -2,7 +2,7 @@ import logging import binascii from twisted.internet import defer from twisted.internet.protocol import DatagramProtocol -from txupnp.constants import GATEWAY_SCHEMA, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, service_types +from txupnp.constants import UPNP_ORG_IGD, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, service_types from txupnp.constants import SSDP_HOST from txupnp.fault import UPnPError from txupnp.ssdp_datagram import SSDPDatagram @@ -25,7 +25,7 @@ class SSDPProtocol(DatagramProtocol): self.max_devices = max_devices self.devices = [] - def _send_m_search(self, service=GATEWAY_SCHEMA): + def _send_m_search(self, service=UPNP_ORG_IGD): packet = SSDPDatagram(SSDPDatagram._M_SEARCH, host=SSDP_HOST, st=service, man=SSDP_DISCOVER, mx=1) log.debug("sending packet to %s:\n%s", SSDP_HOST, packet.encode()) try: diff --git a/txupnp/ssdp_datagram.py b/txupnp/ssdp_datagram.py index 1dadf94..ad440ba 100644 --- a/txupnp/ssdp_datagram.py +++ b/txupnp/ssdp_datagram.py @@ -6,6 +6,22 @@ from txupnp.constants import line_separator log = logging.getLogger(__name__) +_ssdp_datagram_patterns = { + 'host': (re.compile("^(?i)(host):(.*)$"), str), + 'st': (re.compile("^(?i)(st):(.*)$"), str), + 'man': (re.compile("^(?i)(man):|(\"(.*)\")$"), str), + 'mx': (re.compile("^(?i)(mx):(.*)$"), int), + 'nt': (re.compile("^(?i)(nt):(.*)$"), str), + 'nts': (re.compile("^(?i)(nts):(.*)$"), str), + 'usn': (re.compile("^(?i)(usn):(.*)$"), str), + 'location': (re.compile("^(?i)(location):(.*)$"), str), + 'cache_control': (re.compile("^(?i)(cache-control):(.*)$"), str), + 'server': (re.compile("^(?i)(server):(.*)$"), str), +} + +_vendor_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$") + + class SSDPDatagram(object): _M_SEARCH = "M-SEARCH" _NOTIFY = "NOTIFY" @@ -23,20 +39,9 @@ class SSDPDatagram(object): _OK: "m-search response" } - _vendor_field_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$") + _vendor_field_pattern = _vendor_pattern - _patterns = { - 'host': (re.compile("^(?i)(host):(.*)$"), str), - 'st': (re.compile("^(?i)(st):(.*)$"), str), - 'man': (re.compile("^(?i)(man):|(\"(.*)\")$"), str), - 'mx': (re.compile("^(?i)(mx):(.*)$"), int), - 'nt': (re.compile("^(?i)(nt):(.*)$"), str), - 'nts': (re.compile("^(?i)(nts):(.*)$"), str), - 'usn': (re.compile("^(?i)(usn):(.*)$"), str), - 'location': (re.compile("^(?i)(location):(.*)$"), str), - 'cache_control': (re.compile("^(?i)(cache-control):(.*)$"), str), - 'server': (re.compile("^(?i)(server):(.*)$"), str), - } + _patterns = _ssdp_datagram_patterns _required_fields = { _M_SEARCH: [ diff --git a/txupnp/upnp.py b/txupnp/upnp.py index 6bff78c..d14a44c 100644 --- a/txupnp/upnp.py +++ b/txupnp/upnp.py @@ -1,5 +1,6 @@ import logging import json +import treq from twisted.internet import defer from txupnp.fault import UPnPError from txupnp.soap import SOAPServiceManager @@ -15,6 +16,7 @@ class UPnP(object): self._miniupnpc_fallback = miniupnpc_fallback self.soap_manager = SOAPServiceManager(reactor) self.miniupnpc_runner = None + self._miniupnpc_igd_url = None @property def lan_address(self): @@ -57,6 +59,7 @@ class UPnP(object): log.debug("trying miniupnpc fallback") fallback = UPnPFallback() success = yield fallback.discover() + self._miniupnpc_igd_url = fallback.device_url if success: log.info("successfully started miniupnpc fallback") self.miniupnpc_runner = fallback @@ -164,4 +167,9 @@ class UPnP(object): if isinstance(x, bytes): return x.decode() return x - return json.dumps(self.soap_manager.debug(), indent=2, default=default_byte) + return json.dumps({ + 'txupnp': self.soap_manager.debug(), + 'miniupnpc_igd_url': self._miniupnpc_igd_url + }, + indent=2, default=default_byte + ) diff --git a/txupnp/util.py b/txupnp/util.py index 030eb10..80eb0d1 100644 --- a/txupnp/util.py +++ b/txupnp/util.py @@ -4,7 +4,6 @@ from collections import defaultdict import netifaces from twisted.internet import defer -DEVICE_ELEMENT_REGEX = re.compile("^\{urn:schemas-upnp-org:device-\d-\d\}device$") BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode()) BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode()) @@ -44,6 +43,15 @@ def flatten_keys(d, strip): return t +def get_dict_val_case_insensitive(d, k): + match = list(filter(lambda x: x.lower() == k.lower(), d.keys())) + if not match: + return + if len(match) > 1: + raise KeyError("overlapping keys") + return d[match[0]] + + def get_lan_info(): gateway_address, iface_name = netifaces.gateways()['default'][netifaces.AF_INET] lan_addr = netifaces.ifaddresses(iface_name)[netifaces.AF_INET][0]['addr']