cleanup, support mocks

This commit is contained in:
Jack Robison 2018-09-25 14:52:29 -04:00
parent f8b1a1326a
commit f5deb00a50
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
8 changed files with 41 additions and 37 deletions

View file

@ -6,7 +6,7 @@ class UPnPError(Exception):
pass pass
def handle_fault(response): def handle_fault(response: dict) -> dict:
if FAULT in response: if FAULT in response:
fault = flatten_keys(response[FAULT], "{%s}" % CONTROL) fault = flatten_keys(response[FAULT], "{%s}" % CONTROL)
error_description = fault['detail']['UPnPError']['errorDescription'] error_description = fault['detail']['UPnPError']['errorDescription']

View file

@ -19,7 +19,7 @@ xml_root_sanity_pattern = re.compile(
) )
class CaseInsensitive(object): class CaseInsensitive:
def __init__(self, **kwargs): def __init__(self, **kwargs):
not_evaluated = {} not_evaluated = {}
for k, v in kwargs.items(): for k, v in kwargs.items():
@ -34,7 +34,7 @@ class CaseInsensitive(object):
if not_evaluated: if not_evaluated:
log.debug("%s did not apply kwargs: %s", self.__class__.__name__, not_evaluated) log.debug("%s did not apply kwargs: %s", self.__class__.__name__, not_evaluated)
def _get_attr_name(self, case_insensitive): def _get_attr_name(self, case_insensitive: str) -> str:
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if k.lower() == case_insensitive.lower(): if k.lower() == case_insensitive.lower():
return k return k
@ -60,7 +60,7 @@ class CaseInsensitive(object):
break break
self.__dict__[to_update or item] = value self.__dict__[to_update or item] = value
def as_dict(self): def as_dict(self) -> dict:
return { return {
k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v) k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v)
} }
@ -111,7 +111,7 @@ class Device(CaseInsensitive):
log.warning("failed to parse device:\n%s", kw) log.warning("failed to parse device:\n%s", kw)
class Gateway(object): class Gateway:
def __init__(self, **kwargs): def __init__(self, **kwargs):
flattened = { flattened = {
k.lower(): v for k, v in kwargs.items() k.lower(): v for k, v in kwargs.items()
@ -143,7 +143,7 @@ class Gateway(object):
self._devices = [] self._devices = []
self._services = [] self._services = []
def debug_device(self, include_xml=False, include_services=True): def debug_device(self, include_xml: bool = False, include_services: bool = True) -> dict:
r = { r = {
'server': self.server, 'server': self.server,
'urlBase': self.url_base, 'urlBase': self.url_base,
@ -194,18 +194,18 @@ class Gateway(object):
log.debug("finished setting up gateway:\n%s", self.debug_device()) log.debug("finished setting up gateway:\n%s", self.debug_device())
@property @property
def services(self): def services(self) -> dict:
if not self._device: if not self._device:
return {} return {}
return {service.serviceType: service for service in self._services} return {service.serviceType: service for service in self._services}
@property @property
def devices(self): def devices(self) -> dict:
if not self._device: if not self._device:
return {} return {}
return {device.udn: device for device in self._devices} return {device.udn: device for device in self._devices}
def get_service(self, service_type): def get_service(self, service_type) -> Service:
for service in self._services: for service in self._services:
if service.serviceType.lower() == service_type.lower(): if service.serviceType.lower() == service_type.lower():
return service return service

View file

@ -5,7 +5,7 @@ from twisted.web.client import Agent
import treq import treq
from treq.client import HTTPClient from treq.client import HTTPClient
from xml.etree import ElementTree from xml.etree import ElementTree
from txupnp.util import etree_to_dict, flatten_keys, return_types, _return_types, none_or_str, none from txupnp.util import etree_to_dict, flatten_keys, return_types, verify_return_types, none_or_str, none
from txupnp.fault import handle_fault, UPnPError from txupnp.fault import handle_fault, UPnPError
from txupnp.constants import SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types, ENVELOPE, XML_VERSION from txupnp.constants import SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types, ENVELOPE, XML_VERSION
from txupnp.constants import BODY, POST from txupnp.constants import BODY, POST
@ -14,7 +14,7 @@ from txupnp.dirty_pool import DirtyPool
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class StringProducer(object): class StringProducer:
def __init__(self, body): def __init__(self, body):
self.body = body self.body = body
self.length = len(body) self.length = len(body)
@ -64,7 +64,7 @@ class _SCPDCommand(object):
response_key = key response_key = key
break break
if not response_key: if not response_key:
raise UPnPError("unknown response fields") raise UPnPError("unknown response fields for %s")
response = body[response_key] response = body[response_key]
extracted_response = tuple([response[n] for n in self.returns]) extracted_response = tuple([response[n] for n in self.returns])
if len(extracted_response) == 1: if len(extracted_response) == 1:
@ -143,7 +143,7 @@ class SCPDResponse(object):
class SCPDCommandRunner(object): class SCPDCommandRunner(object):
def __init__(self, gateway, reactor): def __init__(self, gateway, reactor, treq_get=None):
self._gateway = gateway self._gateway = gateway
self._unsupported_actions = {} self._unsupported_actions = {}
self._registered_commands = {} self._registered_commands = {}
@ -151,15 +151,15 @@ class SCPDCommandRunner(object):
self._connection_pool = DirtyPool(reactor) self._connection_pool = DirtyPool(reactor)
self._agent = Agent(reactor, connectTimeout=1, pool=self._connection_pool) self._agent = Agent(reactor, connectTimeout=1, pool=self._connection_pool)
self._http_client = HTTPClient(self._agent, data_to_body_producer=StringProducer) self._http_client = HTTPClient(self._agent, data_to_body_producer=StringProducer)
self._treq_get = treq_get or treq.get
@defer.inlineCallbacks @defer.inlineCallbacks
def _discover_commands(self, service): def _discover_commands(self, service):
scpd_url = self._gateway.base_address + service.SCPDURL.encode() scpd_url = self._gateway.base_address + service.SCPDURL.encode()
response = yield treq.get(scpd_url) response = yield self._treq_get(scpd_url)
content = yield response.content() content = yield response.content()
try: try:
scpd_response = SCPDResponse(scpd_url, scpd_response = SCPDResponse(scpd_url, response.headers, content)
response.headers, content)
for action_dict in scpd_response.get_action_list(): for action_dict in scpd_response.get_action_list():
self._register_command(action_dict, service.serviceType) self._register_command(action_dict, service.serviceType)
except Exception as err: except Exception as err:
@ -200,7 +200,7 @@ class SCPDCommandRunner(object):
self._gateway.get_service(service_type).serviceType.encode(), name, inputs, outputs) self._gateway.get_service(service_type).serviceType.encode(), name, inputs, outputs)
current = getattr(self, command.method) current = getattr(self, command.method)
if hasattr(current, "_return_types"): if hasattr(current, "_return_types"):
command._process_result = _return_types(*current._return_types)(command._process_result) command._process_result = verify_return_types(*current._return_types)(command._process_result)
setattr(command, "__doc__", current.__doc__) setattr(command, "__doc__", current.__doc__)
setattr(self, command.method, command) setattr(self, command.method, command)
self._registered_commands[command.method] = service_type self._registered_commands[command.method] = service_type

View file

@ -1,6 +1,6 @@
import logging import logging
import netifaces
from twisted.internet import defer from twisted.internet import defer
from txupnp.util import get_lan_info
from txupnp.ssdp import SSDPFactory from txupnp.ssdp import SSDPFactory
from txupnp.scpd import SCPDCommandRunner from txupnp.scpd import SCPDCommandRunner
from txupnp.gateway import Gateway from txupnp.gateway import Gateway
@ -11,12 +11,14 @@ log = logging.getLogger(__name__)
class SOAPServiceManager(object): class SOAPServiceManager(object):
def __init__(self, reactor): def __init__(self, reactor, treq_get=None):
self._reactor = reactor self._reactor = reactor
self.iface_name, self.router_ip, self.lan_address = get_lan_info() self.router_ip, self.iface_name = netifaces.gateways()['default'][netifaces.AF_INET]
self.lan_address = netifaces.ifaddresses(self.iface_name)[netifaces.AF_INET][0]['addr']
self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip) self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip)
self._command_runners = {} self._command_runners = {}
self._selected_runner = UPNP_ORG_IGD self._selected_runner = UPNP_ORG_IGD
self._treq_get = treq_get
@defer.inlineCallbacks @defer.inlineCallbacks
def discover_services(self, address=None, timeout=30, max_devices=1): def discover_services(self, address=None, timeout=30, max_devices=1):
@ -29,7 +31,7 @@ class SOAPServiceManager(object):
locations.append(server_info['location']) locations.append(server_info['location'])
gateway = Gateway(**server_info) gateway = Gateway(**server_info)
yield gateway.discover_services() yield gateway.discover_services()
command_runner = SCPDCommandRunner(gateway, self._reactor) command_runner = SCPDCommandRunner(gateway, self._reactor, self._treq_get)
yield command_runner.discover_commands() yield command_runner.discover_commands()
self._command_runners[gateway.urn.decode()] = command_runner self._command_runners[gateway.urn.decode()] = command_runner
elif 'st' not in server_info: elif 'st' not in server_info:

View file

@ -31,7 +31,7 @@ class SSDPProtocol(DatagramProtocol):
try: try:
self.transport.write(packet.encode().encode(), (self.ssdp_address, self.ssdp_port)) self.transport.write(packet.encode().encode(), (self.ssdp_address, self.ssdp_port))
except Exception as err: except Exception as err:
log.exception("failed to write %s to %s:%i", binascii.hexlify(packet.encode()), self.ssdp_address, self.ssdp_port) log.exception("failed to write %s to %s:%i", packet.encode(), self.ssdp_address, self.ssdp_port)
raise err raise err
@staticmethod @staticmethod

View file

@ -1,11 +1,11 @@
import re import re
import logging import logging
import binascii
from txupnp.fault import UPnPError from txupnp.fault import UPnPError
from txupnp.constants import line_separator from txupnp.constants import line_separator
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
_ssdp_datagram_patterns = { _ssdp_datagram_patterns = {
'host': (re.compile("^(?i)(host):(.*)$"), str), 'host': (re.compile("^(?i)(host):(.*)$"), str),
'st': (re.compile("^(?i)(st):(.*)$"), str), 'st': (re.compile("^(?i)(st):(.*)$"), str),
@ -135,7 +135,7 @@ class SSDPDatagram(object):
return packet return packet
@classmethod @classmethod
def _lines_to_content_dict(cls, lines): def _lines_to_content_dict(cls, lines: list) -> dict:
result = {} result = {}
for line in lines: for line in lines:
if not line: if not line:
@ -158,7 +158,7 @@ class SSDPDatagram(object):
return result return result
@classmethod @classmethod
def _from_string(cls, datagram): def _from_string(cls, datagram: str):
lines = [l for l in datagram.split(line_separator) if l] lines = [l for l in datagram.split(line_separator) if l]
if lines[0] == cls._start_lines[cls._M_SEARCH]: if lines[0] == cls._start_lines[cls._M_SEARCH]:
return cls._from_request(lines[1:]) return cls._from_request(lines[1:])

View file

@ -9,10 +9,10 @@ log = logging.getLogger(__name__)
class UPnP(object): class UPnP(object):
def __init__(self, reactor, try_miniupnpc_fallback=True): def __init__(self, reactor, try_miniupnpc_fallback=True, treq_get=None):
self._reactor = reactor self._reactor = reactor
self.try_miniupnpc_fallback = try_miniupnpc_fallback self.try_miniupnpc_fallback = try_miniupnpc_fallback
self.soap_manager = SOAPServiceManager(reactor) self.soap_manager = SOAPServiceManager(reactor, treq_get=treq_get)
self.miniupnpc_runner = None self.miniupnpc_runner = None
self.miniupnpc_igd_url = None self.miniupnpc_igd_url = None

View file

@ -1,14 +1,14 @@
import re import re
import functools import functools
from collections import defaultdict from collections import defaultdict
import netifaces
from twisted.internet import defer from twisted.internet import defer
from xml.etree import ElementTree
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode()) BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode()) BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
def etree_to_dict(t): def etree_to_dict(t: ElementTree) -> dict:
d = {t.tag: {} if t.attrib else None} d = {t.tag: {} if t.attrib else None}
children = list(t) children = list(t)
if children: if children:
@ -52,14 +52,12 @@ def get_dict_val_case_insensitive(d, k):
return d[match[0]] return d[match[0]]
def get_lan_info(): def verify_return_types(*types):
gateway_address, iface_name = netifaces.gateways()['default'][netifaces.AF_INET] """
lan_addr = netifaces.ifaddresses(iface_name)[netifaces.AF_INET][0]['addr'] Attempt to recast results to expected result types
return iface_name, gateway_address, lan_addr """
def _verify_return_types(fn):
def _return_types(*types):
def _return_types_wrapper(fn):
@functools.wraps(fn) @functools.wraps(fn)
def _inner(response): def _inner(response):
if isinstance(response, (list, tuple)): if isinstance(response, (list, tuple)):
@ -69,10 +67,14 @@ def _return_types(*types):
return fn(r) return fn(r)
return fn(types[0](response)) return fn(types[0](response))
return _inner return _inner
return _return_types_wrapper return _verify_return_types
def return_types(*types): def return_types(*types):
"""
Decorator to set the expected return types of a SOAP function call
"""
def return_types_wrapper(fn): def return_types_wrapper(fn):
fn._return_types = types fn._return_types = types
return fn return fn