cleanup, support mocks

This commit is contained in:
Jack Robison 2018-09-25 14:52:29 -04:00
parent f8b1a1326a
commit f5deb00a50
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
8 changed files with 41 additions and 37 deletions

View file

@ -6,7 +6,7 @@ class UPnPError(Exception):
pass
def handle_fault(response):
def handle_fault(response: dict) -> dict:
if FAULT in response:
fault = flatten_keys(response[FAULT], "{%s}" % CONTROL)
error_description = fault['detail']['UPnPError']['errorDescription']

View file

@ -19,7 +19,7 @@ xml_root_sanity_pattern = re.compile(
)
class CaseInsensitive(object):
class CaseInsensitive:
def __init__(self, **kwargs):
not_evaluated = {}
for k, v in kwargs.items():
@ -34,7 +34,7 @@ class CaseInsensitive(object):
if not_evaluated:
log.debug("%s did not apply kwargs: %s", self.__class__.__name__, not_evaluated)
def _get_attr_name(self, case_insensitive):
def _get_attr_name(self, case_insensitive: str) -> str:
for k, v in self.__dict__.items():
if k.lower() == case_insensitive.lower():
return k
@ -60,7 +60,7 @@ class CaseInsensitive(object):
break
self.__dict__[to_update or item] = value
def as_dict(self):
def as_dict(self) -> dict:
return {
k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v)
}
@ -111,7 +111,7 @@ class Device(CaseInsensitive):
log.warning("failed to parse device:\n%s", kw)
class Gateway(object):
class Gateway:
def __init__(self, **kwargs):
flattened = {
k.lower(): v for k, v in kwargs.items()
@ -143,7 +143,7 @@ class Gateway(object):
self._devices = []
self._services = []
def debug_device(self, include_xml=False, include_services=True):
def debug_device(self, include_xml: bool = False, include_services: bool = True) -> dict:
r = {
'server': self.server,
'urlBase': self.url_base,
@ -194,18 +194,18 @@ class Gateway(object):
log.debug("finished setting up gateway:\n%s", self.debug_device())
@property
def services(self):
def services(self) -> dict:
if not self._device:
return {}
return {service.serviceType: service for service in self._services}
@property
def devices(self):
def devices(self) -> dict:
if not self._device:
return {}
return {device.udn: device for device in self._devices}
def get_service(self, service_type):
def get_service(self, service_type) -> Service:
for service in self._services:
if service.serviceType.lower() == service_type.lower():
return service

View file

@ -5,7 +5,7 @@ from twisted.web.client import Agent
import treq
from treq.client import HTTPClient
from xml.etree import ElementTree
from txupnp.util import etree_to_dict, flatten_keys, return_types, _return_types, none_or_str, none
from txupnp.util import etree_to_dict, flatten_keys, return_types, verify_return_types, none_or_str, none
from txupnp.fault import handle_fault, UPnPError
from txupnp.constants import SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types, ENVELOPE, XML_VERSION
from txupnp.constants import BODY, POST
@ -14,7 +14,7 @@ from txupnp.dirty_pool import DirtyPool
log = logging.getLogger(__name__)
class StringProducer(object):
class StringProducer:
def __init__(self, body):
self.body = body
self.length = len(body)
@ -64,7 +64,7 @@ class _SCPDCommand(object):
response_key = key
break
if not response_key:
raise UPnPError("unknown response fields")
raise UPnPError("unknown response fields for %s")
response = body[response_key]
extracted_response = tuple([response[n] for n in self.returns])
if len(extracted_response) == 1:
@ -143,7 +143,7 @@ class SCPDResponse(object):
class SCPDCommandRunner(object):
def __init__(self, gateway, reactor):
def __init__(self, gateway, reactor, treq_get=None):
self._gateway = gateway
self._unsupported_actions = {}
self._registered_commands = {}
@ -151,15 +151,15 @@ class SCPDCommandRunner(object):
self._connection_pool = DirtyPool(reactor)
self._agent = Agent(reactor, connectTimeout=1, pool=self._connection_pool)
self._http_client = HTTPClient(self._agent, data_to_body_producer=StringProducer)
self._treq_get = treq_get or treq.get
@defer.inlineCallbacks
def _discover_commands(self, service):
scpd_url = self._gateway.base_address + service.SCPDURL.encode()
response = yield treq.get(scpd_url)
response = yield self._treq_get(scpd_url)
content = yield response.content()
try:
scpd_response = SCPDResponse(scpd_url,
response.headers, content)
scpd_response = SCPDResponse(scpd_url, response.headers, content)
for action_dict in scpd_response.get_action_list():
self._register_command(action_dict, service.serviceType)
except Exception as err:
@ -200,7 +200,7 @@ class SCPDCommandRunner(object):
self._gateway.get_service(service_type).serviceType.encode(), name, inputs, outputs)
current = getattr(self, command.method)
if hasattr(current, "_return_types"):
command._process_result = _return_types(*current._return_types)(command._process_result)
command._process_result = verify_return_types(*current._return_types)(command._process_result)
setattr(command, "__doc__", current.__doc__)
setattr(self, command.method, command)
self._registered_commands[command.method] = service_type

View file

@ -1,6 +1,6 @@
import logging
import netifaces
from twisted.internet import defer
from txupnp.util import get_lan_info
from txupnp.ssdp import SSDPFactory
from txupnp.scpd import SCPDCommandRunner
from txupnp.gateway import Gateway
@ -11,12 +11,14 @@ log = logging.getLogger(__name__)
class SOAPServiceManager(object):
def __init__(self, reactor):
def __init__(self, reactor, treq_get=None):
self._reactor = reactor
self.iface_name, self.router_ip, self.lan_address = get_lan_info()
self.router_ip, self.iface_name = netifaces.gateways()['default'][netifaces.AF_INET]
self.lan_address = netifaces.ifaddresses(self.iface_name)[netifaces.AF_INET][0]['addr']
self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip)
self._command_runners = {}
self._selected_runner = UPNP_ORG_IGD
self._treq_get = treq_get
@defer.inlineCallbacks
def discover_services(self, address=None, timeout=30, max_devices=1):
@ -29,7 +31,7 @@ class SOAPServiceManager(object):
locations.append(server_info['location'])
gateway = Gateway(**server_info)
yield gateway.discover_services()
command_runner = SCPDCommandRunner(gateway, self._reactor)
command_runner = SCPDCommandRunner(gateway, self._reactor, self._treq_get)
yield command_runner.discover_commands()
self._command_runners[gateway.urn.decode()] = command_runner
elif 'st' not in server_info:

View file

@ -31,7 +31,7 @@ class SSDPProtocol(DatagramProtocol):
try:
self.transport.write(packet.encode().encode(), (self.ssdp_address, self.ssdp_port))
except Exception as err:
log.exception("failed to write %s to %s:%i", binascii.hexlify(packet.encode()), self.ssdp_address, self.ssdp_port)
log.exception("failed to write %s to %s:%i", packet.encode(), self.ssdp_address, self.ssdp_port)
raise err
@staticmethod

View file

@ -1,11 +1,11 @@
import re
import logging
import binascii
from txupnp.fault import UPnPError
from txupnp.constants import line_separator
log = logging.getLogger(__name__)
_ssdp_datagram_patterns = {
'host': (re.compile("^(?i)(host):(.*)$"), str),
'st': (re.compile("^(?i)(st):(.*)$"), str),
@ -135,7 +135,7 @@ class SSDPDatagram(object):
return packet
@classmethod
def _lines_to_content_dict(cls, lines):
def _lines_to_content_dict(cls, lines: list) -> dict:
result = {}
for line in lines:
if not line:
@ -158,7 +158,7 @@ class SSDPDatagram(object):
return result
@classmethod
def _from_string(cls, datagram):
def _from_string(cls, datagram: str):
lines = [l for l in datagram.split(line_separator) if l]
if lines[0] == cls._start_lines[cls._M_SEARCH]:
return cls._from_request(lines[1:])

View file

@ -9,10 +9,10 @@ log = logging.getLogger(__name__)
class UPnP(object):
def __init__(self, reactor, try_miniupnpc_fallback=True):
def __init__(self, reactor, try_miniupnpc_fallback=True, treq_get=None):
self._reactor = reactor
self.try_miniupnpc_fallback = try_miniupnpc_fallback
self.soap_manager = SOAPServiceManager(reactor)
self.soap_manager = SOAPServiceManager(reactor, treq_get=treq_get)
self.miniupnpc_runner = None
self.miniupnpc_igd_url = None

View file

@ -1,14 +1,14 @@
import re
import functools
from collections import defaultdict
import netifaces
from twisted.internet import defer
from xml.etree import ElementTree
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
def etree_to_dict(t):
def etree_to_dict(t: ElementTree) -> dict:
d = {t.tag: {} if t.attrib else None}
children = list(t)
if children:
@ -52,14 +52,12 @@ def get_dict_val_case_insensitive(d, k):
return d[match[0]]
def get_lan_info():
gateway_address, iface_name = netifaces.gateways()['default'][netifaces.AF_INET]
lan_addr = netifaces.ifaddresses(iface_name)[netifaces.AF_INET][0]['addr']
return iface_name, gateway_address, lan_addr
def verify_return_types(*types):
"""
Attempt to recast results to expected result types
"""
def _return_types(*types):
def _return_types_wrapper(fn):
def _verify_return_types(fn):
@functools.wraps(fn)
def _inner(response):
if isinstance(response, (list, tuple)):
@ -69,10 +67,14 @@ def _return_types(*types):
return fn(r)
return fn(types[0](response))
return _inner
return _return_types_wrapper
return _verify_return_types
def return_types(*types):
"""
Decorator to set the expected return types of a SOAP function call
"""
def return_types_wrapper(fn):
fn._return_types = types
return fn