diff --git a/txupnp/fault.py b/txupnp/fault.py index 94b1084..70da4dd 100644 --- a/txupnp/fault.py +++ b/txupnp/fault.py @@ -6,7 +6,7 @@ class UPnPError(Exception): pass -def handle_fault(response): +def handle_fault(response: dict) -> dict: if FAULT in response: fault = flatten_keys(response[FAULT], "{%s}" % CONTROL) error_description = fault['detail']['UPnPError']['errorDescription'] diff --git a/txupnp/gateway.py b/txupnp/gateway.py index 726ad5e..77b36ca 100644 --- a/txupnp/gateway.py +++ b/txupnp/gateway.py @@ -19,7 +19,7 @@ xml_root_sanity_pattern = re.compile( ) -class CaseInsensitive(object): +class CaseInsensitive: def __init__(self, **kwargs): not_evaluated = {} for k, v in kwargs.items(): @@ -34,7 +34,7 @@ class CaseInsensitive(object): if not_evaluated: log.debug("%s did not apply kwargs: %s", self.__class__.__name__, not_evaluated) - def _get_attr_name(self, case_insensitive): + def _get_attr_name(self, case_insensitive: str) -> str: for k, v in self.__dict__.items(): if k.lower() == case_insensitive.lower(): return k @@ -60,7 +60,7 @@ class CaseInsensitive(object): break self.__dict__[to_update or item] = value - def as_dict(self): + def as_dict(self) -> dict: return { k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v) } @@ -111,7 +111,7 @@ class Device(CaseInsensitive): log.warning("failed to parse device:\n%s", kw) -class Gateway(object): +class Gateway: def __init__(self, **kwargs): flattened = { k.lower(): v for k, v in kwargs.items() @@ -143,7 +143,7 @@ class Gateway(object): self._devices = [] self._services = [] - def debug_device(self, include_xml=False, include_services=True): + def debug_device(self, include_xml: bool = False, include_services: bool = True) -> dict: r = { 'server': self.server, 'urlBase': self.url_base, @@ -194,18 +194,18 @@ class Gateway(object): log.debug("finished setting up gateway:\n%s", self.debug_device()) @property - def services(self): + def services(self) -> dict: if not self._device: return {} return {service.serviceType: service for service in self._services} @property - def devices(self): + def devices(self) -> dict: if not self._device: return {} return {device.udn: device for device in self._devices} - def get_service(self, service_type): + def get_service(self, service_type) -> Service: for service in self._services: if service.serviceType.lower() == service_type.lower(): return service diff --git a/txupnp/scpd.py b/txupnp/scpd.py index 9c0248a..961d406 100644 --- a/txupnp/scpd.py +++ b/txupnp/scpd.py @@ -5,7 +5,7 @@ from twisted.web.client import Agent import treq from treq.client import HTTPClient from xml.etree import ElementTree -from txupnp.util import etree_to_dict, flatten_keys, return_types, _return_types, none_or_str, none +from txupnp.util import etree_to_dict, flatten_keys, return_types, verify_return_types, none_or_str, none from txupnp.fault import handle_fault, UPnPError from txupnp.constants import SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types, ENVELOPE, XML_VERSION from txupnp.constants import BODY, POST @@ -14,7 +14,7 @@ from txupnp.dirty_pool import DirtyPool log = logging.getLogger(__name__) -class StringProducer(object): +class StringProducer: def __init__(self, body): self.body = body self.length = len(body) @@ -64,7 +64,7 @@ class _SCPDCommand(object): response_key = key break if not response_key: - raise UPnPError("unknown response fields") + raise UPnPError("unknown response fields for %s") response = body[response_key] extracted_response = tuple([response[n] for n in self.returns]) if len(extracted_response) == 1: @@ -143,7 +143,7 @@ class SCPDResponse(object): class SCPDCommandRunner(object): - def __init__(self, gateway, reactor): + def __init__(self, gateway, reactor, treq_get=None): self._gateway = gateway self._unsupported_actions = {} self._registered_commands = {} @@ -151,15 +151,15 @@ class SCPDCommandRunner(object): self._connection_pool = DirtyPool(reactor) self._agent = Agent(reactor, connectTimeout=1, pool=self._connection_pool) self._http_client = HTTPClient(self._agent, data_to_body_producer=StringProducer) + self._treq_get = treq_get or treq.get @defer.inlineCallbacks def _discover_commands(self, service): scpd_url = self._gateway.base_address + service.SCPDURL.encode() - response = yield treq.get(scpd_url) + response = yield self._treq_get(scpd_url) content = yield response.content() try: - scpd_response = SCPDResponse(scpd_url, - response.headers, content) + scpd_response = SCPDResponse(scpd_url, response.headers, content) for action_dict in scpd_response.get_action_list(): self._register_command(action_dict, service.serviceType) except Exception as err: @@ -200,7 +200,7 @@ class SCPDCommandRunner(object): self._gateway.get_service(service_type).serviceType.encode(), name, inputs, outputs) current = getattr(self, command.method) if hasattr(current, "_return_types"): - command._process_result = _return_types(*current._return_types)(command._process_result) + command._process_result = verify_return_types(*current._return_types)(command._process_result) setattr(command, "__doc__", current.__doc__) setattr(self, command.method, command) self._registered_commands[command.method] = service_type diff --git a/txupnp/soap.py b/txupnp/soap.py index 1e53b7b..34b200c 100644 --- a/txupnp/soap.py +++ b/txupnp/soap.py @@ -1,6 +1,6 @@ import logging +import netifaces from twisted.internet import defer -from txupnp.util import get_lan_info from txupnp.ssdp import SSDPFactory from txupnp.scpd import SCPDCommandRunner from txupnp.gateway import Gateway @@ -11,12 +11,14 @@ log = logging.getLogger(__name__) class SOAPServiceManager(object): - def __init__(self, reactor): + def __init__(self, reactor, treq_get=None): self._reactor = reactor - self.iface_name, self.router_ip, self.lan_address = get_lan_info() + self.router_ip, self.iface_name = netifaces.gateways()['default'][netifaces.AF_INET] + self.lan_address = netifaces.ifaddresses(self.iface_name)[netifaces.AF_INET][0]['addr'] self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip) self._command_runners = {} self._selected_runner = UPNP_ORG_IGD + self._treq_get = treq_get @defer.inlineCallbacks def discover_services(self, address=None, timeout=30, max_devices=1): @@ -29,7 +31,7 @@ class SOAPServiceManager(object): locations.append(server_info['location']) gateway = Gateway(**server_info) yield gateway.discover_services() - command_runner = SCPDCommandRunner(gateway, self._reactor) + command_runner = SCPDCommandRunner(gateway, self._reactor, self._treq_get) yield command_runner.discover_commands() self._command_runners[gateway.urn.decode()] = command_runner elif 'st' not in server_info: diff --git a/txupnp/ssdp.py b/txupnp/ssdp.py index e7f4397..a402710 100644 --- a/txupnp/ssdp.py +++ b/txupnp/ssdp.py @@ -31,7 +31,7 @@ class SSDPProtocol(DatagramProtocol): try: self.transport.write(packet.encode().encode(), (self.ssdp_address, self.ssdp_port)) except Exception as err: - log.exception("failed to write %s to %s:%i", binascii.hexlify(packet.encode()), self.ssdp_address, self.ssdp_port) + log.exception("failed to write %s to %s:%i", packet.encode(), self.ssdp_address, self.ssdp_port) raise err @staticmethod diff --git a/txupnp/ssdp_datagram.py b/txupnp/ssdp_datagram.py index dbc8e8f..0eed52d 100644 --- a/txupnp/ssdp_datagram.py +++ b/txupnp/ssdp_datagram.py @@ -1,11 +1,11 @@ import re import logging +import binascii from txupnp.fault import UPnPError from txupnp.constants import line_separator log = logging.getLogger(__name__) - _ssdp_datagram_patterns = { 'host': (re.compile("^(?i)(host):(.*)$"), str), 'st': (re.compile("^(?i)(st):(.*)$"), str), @@ -135,7 +135,7 @@ class SSDPDatagram(object): return packet @classmethod - def _lines_to_content_dict(cls, lines): + def _lines_to_content_dict(cls, lines: list) -> dict: result = {} for line in lines: if not line: @@ -158,7 +158,7 @@ class SSDPDatagram(object): return result @classmethod - def _from_string(cls, datagram): + def _from_string(cls, datagram: str): lines = [l for l in datagram.split(line_separator) if l] if lines[0] == cls._start_lines[cls._M_SEARCH]: return cls._from_request(lines[1:]) diff --git a/txupnp/upnp.py b/txupnp/upnp.py index 09ec414..ba5d670 100644 --- a/txupnp/upnp.py +++ b/txupnp/upnp.py @@ -9,10 +9,10 @@ log = logging.getLogger(__name__) class UPnP(object): - def __init__(self, reactor, try_miniupnpc_fallback=True): + def __init__(self, reactor, try_miniupnpc_fallback=True, treq_get=None): self._reactor = reactor self.try_miniupnpc_fallback = try_miniupnpc_fallback - self.soap_manager = SOAPServiceManager(reactor) + self.soap_manager = SOAPServiceManager(reactor, treq_get=treq_get) self.miniupnpc_runner = None self.miniupnpc_igd_url = None diff --git a/txupnp/util.py b/txupnp/util.py index 80eb0d1..f53aca7 100644 --- a/txupnp/util.py +++ b/txupnp/util.py @@ -1,14 +1,14 @@ import re import functools from collections import defaultdict -import netifaces from twisted.internet import defer +from xml.etree import ElementTree BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode()) BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode()) -def etree_to_dict(t): +def etree_to_dict(t: ElementTree) -> dict: d = {t.tag: {} if t.attrib else None} children = list(t) if children: @@ -52,14 +52,12 @@ def get_dict_val_case_insensitive(d, k): return d[match[0]] -def get_lan_info(): - gateway_address, iface_name = netifaces.gateways()['default'][netifaces.AF_INET] - lan_addr = netifaces.ifaddresses(iface_name)[netifaces.AF_INET][0]['addr'] - return iface_name, gateway_address, lan_addr +def verify_return_types(*types): + """ + Attempt to recast results to expected result types + """ - -def _return_types(*types): - def _return_types_wrapper(fn): + def _verify_return_types(fn): @functools.wraps(fn) def _inner(response): if isinstance(response, (list, tuple)): @@ -69,10 +67,14 @@ def _return_types(*types): return fn(r) return fn(types[0](response)) return _inner - return _return_types_wrapper + return _verify_return_types def return_types(*types): + """ + Decorator to set the expected return types of a SOAP function call + """ + def return_types_wrapper(fn): fn._return_types = types return fn