case insensitivity

This commit is contained in:
Jack Robison 2018-07-31 16:53:08 -04:00
parent bb8a1e33d6
commit b011360814
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
9 changed files with 221 additions and 120 deletions

View file

@ -38,9 +38,13 @@ def main():
if command not in ['debug_device', 'list_mappings']: if command not in ['debug_device', 'list_mappings']:
return sys.exit(0) return sys.exit(0)
def show(err):
print("error: {}".format(err))
u = UPnP(reactor) u = UPnP(reactor)
d = u.discover() d = u.discover()
d.addCallback(run_command, u, command) d.addCallback(run_command, u, command)
d.addErrback(show)
d.addBoth(lambda _: reactor.callLater(0, reactor.stop)) d.addBoth(lambda _: reactor.callLater(0, reactor.stop))
reactor.run() reactor.run()

View file

@ -10,16 +10,22 @@ BODY = "{http://schemas.xmlsoap.org/soap/envelope/}Body"
CONTROL = 'urn:schemas-upnp-org:control-1-0' CONTROL = 'urn:schemas-upnp-org:control-1-0'
SERVICE = 'urn:schemas-upnp-org:service-1-0' SERVICE = 'urn:schemas-upnp-org:service-1-0'
DEVICE = 'urn:schemas-upnp-org:device-1-0' DEVICE = 'urn:schemas-upnp-org:device-1-0'
GATEWAY_SCHEMA = 'urn:schemas-upnp-org:device:InternetGatewayDevice:1'
WIFI_ALLIANCE_ORG_IGD = "urn:schemas-wifialliance-org:device:WFADevice:1"
UPNP_ORG_IGD = 'urn:schemas-upnp-org:device:InternetGatewayDevice:1'
WAN_SCHEMA = 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1' WAN_SCHEMA = 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1'
LAYER_SCHEMA = 'urn:schemas-upnp-org:service:Layer3Forwarding:1' LAYER_SCHEMA = 'urn:schemas-upnp-org:service:Layer3Forwarding:1'
IP_SCHEMA = 'urn:schemas-upnp-org:service:WANIPConnection:1' IP_SCHEMA = 'urn:schemas-upnp-org:service:WANIPConnection:1'
service_types = [ service_types = [
GATEWAY_SCHEMA, UPNP_ORG_IGD,
WIFI_ALLIANCE_ORG_IGD,
WAN_SCHEMA, WAN_SCHEMA,
LAYER_SCHEMA, LAYER_SCHEMA,
IP_SCHEMA, IP_SCHEMA,
CONTROL, CONTROL,
SERVICE, SERVICE,
DEVICE, DEVICE,

View file

@ -1,90 +1,120 @@
import logging import logging
from twisted.internet import defer from twisted.internet import defer
import treq import treq
import re
from xml.etree import ElementTree from xml.etree import ElementTree
from txupnp.util import etree_to_dict, flatten_keys from txupnp.util import etree_to_dict, flatten_keys, get_dict_val_case_insensitive
from txupnp.util import BASE_PORT_REGEX, BASE_ADDRESS_REGEX from txupnp.util import BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from txupnp.constants import DEVICE, ROOT from txupnp.constants import DEVICE, ROOT
from txupnp.constants import SPEC_VERSION from txupnp.constants import SPEC_VERSION
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
service_type_pattern = re.compile(
"(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\})"
)
class Service(object): xml_root_sanity_pattern = re.compile(
def __init__(self, serviceType, serviceId, SCPDURL, eventSubURL, controlURL): "(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\}([\w|\d|\:|\-|\_]*))"
self.service_type = serviceType )
self.service_id = serviceId
self.control_path = controlURL
self.subscribe_path = eventSubURL
self.scpd_path = SCPDURL
def get_info(self):
return {
"service_type": self.service_type,
"service_id": self.service_id,
"control_path": self.control_path,
"subscribe_path": self.subscribe_path,
"scpd_path": self.scpd_path
}
class Device(object): class CaseInsensitive(object):
def __init__(self, _root_device, deviceType=None, friendlyName=None, manufacturer=None, manufacturerURL=None, def __init__(self, **kwargs):
modelDescription=None, modelName=None, modelNumber=None, modelURL=None, serialNumber=None, not_evaluated = {}
UDN=None, serviceList=None, deviceList=None, **kwargs): for k, v in kwargs.items():
serviceList = serviceList or {} if k.startswith("_"):
deviceList = deviceList or {} not_evaluated[k] = v
self._root_device = _root_device continue
self.device_type = deviceType
self.friendly_name = friendlyName
self.manufacturer = manufacturer
self.manufacturer_url = manufacturerURL
self.model_description = modelDescription
self.model_name = modelName
self.model_number = modelNumber
self.model_url = modelURL
self.serial_number = serialNumber
self.udn = UDN
services = serviceList["service"]
if isinstance(services, dict):
services = [services]
services = [Service(**service) for service in services]
self._root_device.services.extend(services)
devices = [Device(self._root_device, **deviceList[k]) for k in deviceList]
self._root_device.devices.extend(devices)
def get_info(self):
return {
'device_type': self.device_type,
'friendly_name': self.friendly_name,
'manufacturers': self.manufacturer,
'model_name': self.model_name,
'model_number': self.model_number,
'serial_number': self.serial_number,
'udn': self.udn
}
class RootDevice(object):
def __init__(self, xml_string):
try: try:
root = flatten_keys(etree_to_dict(ElementTree.fromstring(xml_string)), "{%s}" % DEVICE)[ROOT] getattr(self, k)
except Exception as err: setattr(self, k, v)
if xml_string: except AttributeError as err:
log.exception("failed to decode xml: %s\n%s", err, xml_string) not_evaluated[k] = v
root = {} if not_evaluated:
self.spec_version = root.get(SPEC_VERSION) log.error("%s did not apply kwargs: %s", self.__class__.__name__, not_evaluated)
self.url_base = root.get("URLBase")
self.devices = [] def _get_attr_name(self, case_insensitive):
self.services = [] for k, v in self.__dict__.items():
if root: if k.lower() == case_insensitive.lower():
root_device = Device(self, **(root["device"])) return k
self.devices.append(root_device)
log.debug("finished setting up root gateway. %i devices and %i services", len(self.devices), len(self.services)) def __getattr__(self, item):
if item in self.__dict__:
return self.__dict__[item]
for k, v in self.__class__.__dict__.items():
if k.lower() == item.lower():
if k not in self.__dict__:
self.__dict__[k] = v
return v
raise AttributeError(item)
def __setattr__(self, item, value):
if item in self.__dict__:
self.__dict__[item] = value
return
to_update = None
for k, v in self.__dict__.items():
if k.lower() == item.lower():
to_update = k
break
self.__dict__[to_update or item] = value
def as_dict(self):
return {
k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v)
}
class Service(CaseInsensitive):
serviceType = None
serviceId = None
controlURL = None
eventSubURL = None
SCPDURL = None
class Device(CaseInsensitive):
serviceList = None
deviceList = None
deviceType = None
friendlyName = None
manufacturer = None
manufacturerURL = None
modelDescription = None
modelName = None
modelNumber = None
modelURL = None
serialNumber = None
udn = None
presentationURL = None
iconList = None
def __init__(self, devices, services, **kwargs):
super(Device, self).__init__(**kwargs)
if self.serviceList and "service" in self.serviceList:
new_services = self.serviceList["service"]
if isinstance(new_services, dict):
new_services = [new_services]
services.extend([Service(**service) for service in new_services])
if self.deviceList:
devices.extend([Device(devices, services, **kw) for kw in self.deviceList.values()])
class Gateway(object): class Gateway(object):
def __init__(self, usn, server, location, st, cache_control="", date="", ext=""): def __init__(self, **kwargs):
flattened = {
k.lower(): v for k, v in kwargs.items()
}
usn = flattened["usn"]
server = flattened["server"]
location = flattened["location"]
st = flattened["st"]
cache_control = flattened.get("cache_control") or flattened.get("cache-control") or ""
date = flattened.get("date", "")
ext = flattened.get("ext", "")
self.usn = usn.encode() self.usn = usn.encode()
self.ext = ext.encode() self.ext = ext.encode()
self.server = server.encode() self.server = server.encode()
@ -92,54 +122,79 @@ class Gateway(object):
self.cache_control = cache_control.encode() self.cache_control = cache_control.encode()
self.date = date.encode() self.date = date.encode()
self.urn = st.encode() self.urn = st.encode()
self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0] self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0]
self.port = int(BASE_PORT_REGEX.findall(self.location)[0]) self.port = int(BASE_PORT_REGEX.findall(self.location)[0])
self._device = None self.xml_response = None
self.spec_version = None
self.url_base = None
def debug_device(self): self._device = None
devices = [] self._devices = []
for device in self._device.devices: self._services = []
info = device.get_info()
devices.append(info) def debug_device(self, include_xml=False, include_services=True):
services = [] r = {
for service in self._device.services: 'server': self.server,
info = service.get_info() 'urlBase': self.url_base,
services.append(info) 'location': self.location,
return { "specVersion": self.spec_version,
'root_url': self.base_address,
'gateway_xml_url': self.location,
'usn': self.usn, 'usn': self.usn,
'devices': devices, 'urn': self.urn,
'services': services
} }
if include_xml:
r['xml_response'] = self.xml_response
if include_services:
r['services'] = [service.as_dict() for service in self._services]
return r
@defer.inlineCallbacks @defer.inlineCallbacks
def discover_services(self): def discover_services(self):
log.debug("querying %s", self.location) log.debug("querying %s", self.location)
response = yield treq.get(self.location) response = yield treq.get(self.location)
response_xml = yield response.content() self.xml_response = yield response.content()
if not response_xml: if not self.xml_response:
log.error("service sent an empty reply\n%s", self.debug_device()) log.error("service sent an empty reply\n%s", self.debug_device())
try: xml_dict = etree_to_dict(ElementTree.fromstring(self.xml_response))
self._device = RootDevice(response_xml) schema_key = DEVICE
except Exception as err: root = ROOT
log.error("error parsing gateway: %s\n%s\n\n%s", err, self.debug_device(), response_xml) if len(xml_dict) > 1:
self._device = RootDevice("") log.warning(xml_dict.keys())
for k in xml_dict.keys():
m = xml_root_sanity_pattern.findall(k)
if len(m) == 3 and m[1][0] and m[2][5]:
schema_key = m[1][0]
root = m[2][5]
break
flattened_xml = flatten_keys(xml_dict, "{%s}" % schema_key)[root]
self.spec_version = get_dict_val_case_insensitive(flattened_xml, SPEC_VERSION)
self.url_base = get_dict_val_case_insensitive(flattened_xml, "urlbase")
if flattened_xml:
self._device = Device(
self._devices, self._services, **get_dict_val_case_insensitive(flattened_xml, "device")
)
log.debug("finished setting up root gateway. %i devices and %i services", len(self.devices),
len(self.services))
else:
self._device = Device(self._devices, self._services)
log.debug("finished setting up gateway:\n%s", self.debug_device()) log.debug("finished setting up gateway:\n%s", self.debug_device())
@property @property
def services(self): def services(self):
if not self._device: if not self._device:
return {} return {}
return {service.service_type: service for service in self._device.services} return {service.serviceType: service for service in self._services}
@property @property
def devices(self): def devices(self):
if not self._device: if not self._device:
return {} return {}
return {device.udn: device for device in self._device.devices} return {device.udn: device for device in self._devices}
def get_service(self, service_type): def get_service(self, service_type):
for service in self._device.services: for service in self._services:
if service.service_type.lower() == service_type.lower(): if service.serviceType.lower() == service_type.lower():
return service return service

View file

@ -139,14 +139,14 @@ class SCPDCommandRunner(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _discover_commands(self, service): def _discover_commands(self, service):
scpd_url = self._gateway.base_address + service.scpd_path.encode() scpd_url = self._gateway.base_address + service.SCPDURL.encode()
response = yield treq.get(scpd_url) response = yield treq.get(scpd_url)
content = yield response.content() content = yield response.content()
try: try:
scpd_response = SCPDResponse(scpd_url, scpd_response = SCPDResponse(scpd_url,
response.headers, content) response.headers, content)
for action_dict in scpd_response.get_action_list(): for action_dict in scpd_response.get_action_list():
self._register_command(action_dict, service.service_type) self._register_command(action_dict, service.serviceType)
except Exception as err: except Exception as err:
log.exception("failed to parse scpd response (%s) from %s\nheaders:\n%s\ncontent\n%s", log.exception("failed to parse scpd response (%s) from %s\nheaders:\n%s\ncontent\n%s",
err, scpd_url, response.headers, content) err, scpd_url, response.headers, content)
@ -182,8 +182,8 @@ class SCPDCommandRunner(object):
def _patch_command(self, action_info, service_type): def _patch_command(self, action_info, service_type):
name, inputs, outputs = self._soap_function_info(action_info) name, inputs, outputs = self._soap_function_info(action_info)
command = _SCPDCommand(self._gateway.base_address, self._gateway.port, command = _SCPDCommand(self._gateway.base_address, self._gateway.port,
self._gateway.base_address + self._gateway.get_service(service_type).control_path.encode(), self._gateway.base_address + self._gateway.get_service(service_type).controlURL.encode(),
self._gateway.get_service(service_type).service_id.encode(), name, inputs, outputs, self._gateway.get_service(service_type).serviceId.encode(), name, inputs, outputs,
self._reactor, self._connection_pool, self._agent, self._http_client) self._reactor, self._connection_pool, self._agent, self._http_client)
current = getattr(self, command.method) current = getattr(self, command.method)
if hasattr(current, "_return_types"): if hasattr(current, "_return_types"):
@ -336,9 +336,9 @@ class UPnPFallback(object):
raise NotImplementedError() raise NotImplementedError()
devices = yield threads.deferToThread(self._upnp.discover) devices = yield threads.deferToThread(self._upnp.discover)
if devices: if devices:
device_url = yield threads.deferToThread(self._upnp.selectigd) self.device_url = yield threads.deferToThread(self._upnp.selectigd)
else: else:
device_url = None self.device_url = None
defer.returnValue(devices > 0) defer.returnValue(devices > 0)

View file

@ -5,7 +5,7 @@ from txupnp.ssdp import SSDPFactory
from txupnp.scpd import SCPDCommandRunner from txupnp.scpd import SCPDCommandRunner
from txupnp.gateway import Gateway from txupnp.gateway import Gateway
from txupnp.fault import UPnPError from txupnp.fault import UPnPError
from txupnp.constants import GATEWAY_SCHEMA from txupnp.constants import UPNP_ORG_IGD
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -16,7 +16,7 @@ class SOAPServiceManager(object):
self.iface_name, self.router_ip, self.lan_address = get_lan_info() self.iface_name, self.router_ip, self.lan_address = get_lan_info()
self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip) self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip)
self._command_runners = {} self._command_runners = {}
self._selected_runner = GATEWAY_SCHEMA self._selected_runner = UPNP_ORG_IGD
@defer.inlineCallbacks @defer.inlineCallbacks
def discover_services(self, address=None, timeout=30, max_devices=1): def discover_services(self, address=None, timeout=30, max_devices=1):
@ -57,6 +57,21 @@ class SOAPServiceManager(object):
for runner in self._command_runners.values(): for runner in self._command_runners.values():
gateway = runner._gateway gateway = runner._gateway
info = gateway.debug_device() info = gateway.debug_device()
info.update(runner.debug_commands()) commands = runner.debug_commands()
service_result = []
for service in info['services']:
service_commands = []
unavailable = []
for command, service_type in commands['available'].items():
if service['serviceType'] == service_type:
service_commands.append(command)
for command, service_type in commands['failed'].items():
if service['serviceType'] == service_type:
unavailable.append(command)
services_with_commands = dict(service)
services_with_commands['available_commands'] = service_commands
services_with_commands['unavailable_commands'] = unavailable
service_result.append(services_with_commands)
info['services'] = service_result
results.append(info) results.append(info)
return results return results

View file

@ -2,7 +2,7 @@ import logging
import binascii import binascii
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.protocol import DatagramProtocol from twisted.internet.protocol import DatagramProtocol
from txupnp.constants import GATEWAY_SCHEMA, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, service_types from txupnp.constants import UPNP_ORG_IGD, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, service_types
from txupnp.constants import SSDP_HOST from txupnp.constants import SSDP_HOST
from txupnp.fault import UPnPError from txupnp.fault import UPnPError
from txupnp.ssdp_datagram import SSDPDatagram from txupnp.ssdp_datagram import SSDPDatagram
@ -25,7 +25,7 @@ class SSDPProtocol(DatagramProtocol):
self.max_devices = max_devices self.max_devices = max_devices
self.devices = [] self.devices = []
def _send_m_search(self, service=GATEWAY_SCHEMA): def _send_m_search(self, service=UPNP_ORG_IGD):
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, host=SSDP_HOST, st=service, man=SSDP_DISCOVER, mx=1) packet = SSDPDatagram(SSDPDatagram._M_SEARCH, host=SSDP_HOST, st=service, man=SSDP_DISCOVER, mx=1)
log.debug("sending packet to %s:\n%s", SSDP_HOST, packet.encode()) log.debug("sending packet to %s:\n%s", SSDP_HOST, packet.encode())
try: try:

View file

@ -6,6 +6,22 @@ from txupnp.constants import line_separator
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
_ssdp_datagram_patterns = {
'host': (re.compile("^(?i)(host):(.*)$"), str),
'st': (re.compile("^(?i)(st):(.*)$"), str),
'man': (re.compile("^(?i)(man):|(\"(.*)\")$"), str),
'mx': (re.compile("^(?i)(mx):(.*)$"), int),
'nt': (re.compile("^(?i)(nt):(.*)$"), str),
'nts': (re.compile("^(?i)(nts):(.*)$"), str),
'usn': (re.compile("^(?i)(usn):(.*)$"), str),
'location': (re.compile("^(?i)(location):(.*)$"), str),
'cache_control': (re.compile("^(?i)(cache-control):(.*)$"), str),
'server': (re.compile("^(?i)(server):(.*)$"), str),
}
_vendor_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$")
class SSDPDatagram(object): class SSDPDatagram(object):
_M_SEARCH = "M-SEARCH" _M_SEARCH = "M-SEARCH"
_NOTIFY = "NOTIFY" _NOTIFY = "NOTIFY"
@ -23,20 +39,9 @@ class SSDPDatagram(object):
_OK: "m-search response" _OK: "m-search response"
} }
_vendor_field_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$") _vendor_field_pattern = _vendor_pattern
_patterns = { _patterns = _ssdp_datagram_patterns
'host': (re.compile("^(?i)(host):(.*)$"), str),
'st': (re.compile("^(?i)(st):(.*)$"), str),
'man': (re.compile("^(?i)(man):|(\"(.*)\")$"), str),
'mx': (re.compile("^(?i)(mx):(.*)$"), int),
'nt': (re.compile("^(?i)(nt):(.*)$"), str),
'nts': (re.compile("^(?i)(nts):(.*)$"), str),
'usn': (re.compile("^(?i)(usn):(.*)$"), str),
'location': (re.compile("^(?i)(location):(.*)$"), str),
'cache_control': (re.compile("^(?i)(cache-control):(.*)$"), str),
'server': (re.compile("^(?i)(server):(.*)$"), str),
}
_required_fields = { _required_fields = {
_M_SEARCH: [ _M_SEARCH: [

View file

@ -1,5 +1,6 @@
import logging import logging
import json import json
import treq
from twisted.internet import defer from twisted.internet import defer
from txupnp.fault import UPnPError from txupnp.fault import UPnPError
from txupnp.soap import SOAPServiceManager from txupnp.soap import SOAPServiceManager
@ -15,6 +16,7 @@ class UPnP(object):
self._miniupnpc_fallback = miniupnpc_fallback self._miniupnpc_fallback = miniupnpc_fallback
self.soap_manager = SOAPServiceManager(reactor) self.soap_manager = SOAPServiceManager(reactor)
self.miniupnpc_runner = None self.miniupnpc_runner = None
self._miniupnpc_igd_url = None
@property @property
def lan_address(self): def lan_address(self):
@ -57,6 +59,7 @@ class UPnP(object):
log.debug("trying miniupnpc fallback") log.debug("trying miniupnpc fallback")
fallback = UPnPFallback() fallback = UPnPFallback()
success = yield fallback.discover() success = yield fallback.discover()
self._miniupnpc_igd_url = fallback.device_url
if success: if success:
log.info("successfully started miniupnpc fallback") log.info("successfully started miniupnpc fallback")
self.miniupnpc_runner = fallback self.miniupnpc_runner = fallback
@ -164,4 +167,9 @@ class UPnP(object):
if isinstance(x, bytes): if isinstance(x, bytes):
return x.decode() return x.decode()
return x return x
return json.dumps(self.soap_manager.debug(), indent=2, default=default_byte) return json.dumps({
'txupnp': self.soap_manager.debug(),
'miniupnpc_igd_url': self._miniupnpc_igd_url
},
indent=2, default=default_byte
)

View file

@ -4,7 +4,6 @@ from collections import defaultdict
import netifaces import netifaces
from twisted.internet import defer from twisted.internet import defer
DEVICE_ELEMENT_REGEX = re.compile("^\{urn:schemas-upnp-org:device-\d-\d\}device$")
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode()) BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode()) BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
@ -44,6 +43,15 @@ def flatten_keys(d, strip):
return t return t
def get_dict_val_case_insensitive(d, k):
match = list(filter(lambda x: x.lower() == k.lower(), d.keys()))
if not match:
return
if len(match) > 1:
raise KeyError("overlapping keys")
return d[match[0]]
def get_lan_info(): def get_lan_info():
gateway_address, iface_name = netifaces.gateways()['default'][netifaces.AF_INET] gateway_address, iface_name = netifaces.gateways()['default'][netifaces.AF_INET]
lan_addr = netifaces.ifaddresses(iface_name)[netifaces.AF_INET][0]['addr'] lan_addr = netifaces.ifaddresses(iface_name)[netifaces.AF_INET][0]['addr']