cleanup, support mocks
This commit is contained in:
parent
f8b1a1326a
commit
f5deb00a50
8 changed files with 41 additions and 37 deletions
|
@ -6,7 +6,7 @@ class UPnPError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def handle_fault(response):
|
def handle_fault(response: dict) -> dict:
|
||||||
if FAULT in response:
|
if FAULT in response:
|
||||||
fault = flatten_keys(response[FAULT], "{%s}" % CONTROL)
|
fault = flatten_keys(response[FAULT], "{%s}" % CONTROL)
|
||||||
error_description = fault['detail']['UPnPError']['errorDescription']
|
error_description = fault['detail']['UPnPError']['errorDescription']
|
||||||
|
|
|
@ -19,7 +19,7 @@ xml_root_sanity_pattern = re.compile(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CaseInsensitive(object):
|
class CaseInsensitive:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
not_evaluated = {}
|
not_evaluated = {}
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
|
@ -34,7 +34,7 @@ class CaseInsensitive(object):
|
||||||
if not_evaluated:
|
if not_evaluated:
|
||||||
log.debug("%s did not apply kwargs: %s", self.__class__.__name__, not_evaluated)
|
log.debug("%s did not apply kwargs: %s", self.__class__.__name__, not_evaluated)
|
||||||
|
|
||||||
def _get_attr_name(self, case_insensitive):
|
def _get_attr_name(self, case_insensitive: str) -> str:
|
||||||
for k, v in self.__dict__.items():
|
for k, v in self.__dict__.items():
|
||||||
if k.lower() == case_insensitive.lower():
|
if k.lower() == case_insensitive.lower():
|
||||||
return k
|
return k
|
||||||
|
@ -60,7 +60,7 @@ class CaseInsensitive(object):
|
||||||
break
|
break
|
||||||
self.__dict__[to_update or item] = value
|
self.__dict__[to_update or item] = value
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self) -> dict:
|
||||||
return {
|
return {
|
||||||
k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v)
|
k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v)
|
||||||
}
|
}
|
||||||
|
@ -111,7 +111,7 @@ class Device(CaseInsensitive):
|
||||||
log.warning("failed to parse device:\n%s", kw)
|
log.warning("failed to parse device:\n%s", kw)
|
||||||
|
|
||||||
|
|
||||||
class Gateway(object):
|
class Gateway:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
flattened = {
|
flattened = {
|
||||||
k.lower(): v for k, v in kwargs.items()
|
k.lower(): v for k, v in kwargs.items()
|
||||||
|
@ -143,7 +143,7 @@ class Gateway(object):
|
||||||
self._devices = []
|
self._devices = []
|
||||||
self._services = []
|
self._services = []
|
||||||
|
|
||||||
def debug_device(self, include_xml=False, include_services=True):
|
def debug_device(self, include_xml: bool = False, include_services: bool = True) -> dict:
|
||||||
r = {
|
r = {
|
||||||
'server': self.server,
|
'server': self.server,
|
||||||
'urlBase': self.url_base,
|
'urlBase': self.url_base,
|
||||||
|
@ -194,18 +194,18 @@ class Gateway(object):
|
||||||
log.debug("finished setting up gateway:\n%s", self.debug_device())
|
log.debug("finished setting up gateway:\n%s", self.debug_device())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def services(self):
|
def services(self) -> dict:
|
||||||
if not self._device:
|
if not self._device:
|
||||||
return {}
|
return {}
|
||||||
return {service.serviceType: service for service in self._services}
|
return {service.serviceType: service for service in self._services}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def devices(self):
|
def devices(self) -> dict:
|
||||||
if not self._device:
|
if not self._device:
|
||||||
return {}
|
return {}
|
||||||
return {device.udn: device for device in self._devices}
|
return {device.udn: device for device in self._devices}
|
||||||
|
|
||||||
def get_service(self, service_type):
|
def get_service(self, service_type) -> Service:
|
||||||
for service in self._services:
|
for service in self._services:
|
||||||
if service.serviceType.lower() == service_type.lower():
|
if service.serviceType.lower() == service_type.lower():
|
||||||
return service
|
return service
|
||||||
|
|
|
@ -5,7 +5,7 @@ from twisted.web.client import Agent
|
||||||
import treq
|
import treq
|
||||||
from treq.client import HTTPClient
|
from treq.client import HTTPClient
|
||||||
from xml.etree import ElementTree
|
from xml.etree import ElementTree
|
||||||
from txupnp.util import etree_to_dict, flatten_keys, return_types, _return_types, none_or_str, none
|
from txupnp.util import etree_to_dict, flatten_keys, return_types, verify_return_types, none_or_str, none
|
||||||
from txupnp.fault import handle_fault, UPnPError
|
from txupnp.fault import handle_fault, UPnPError
|
||||||
from txupnp.constants import SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types, ENVELOPE, XML_VERSION
|
from txupnp.constants import SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types, ENVELOPE, XML_VERSION
|
||||||
from txupnp.constants import BODY, POST
|
from txupnp.constants import BODY, POST
|
||||||
|
@ -14,7 +14,7 @@ from txupnp.dirty_pool import DirtyPool
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class StringProducer(object):
|
class StringProducer:
|
||||||
def __init__(self, body):
|
def __init__(self, body):
|
||||||
self.body = body
|
self.body = body
|
||||||
self.length = len(body)
|
self.length = len(body)
|
||||||
|
@ -64,7 +64,7 @@ class _SCPDCommand(object):
|
||||||
response_key = key
|
response_key = key
|
||||||
break
|
break
|
||||||
if not response_key:
|
if not response_key:
|
||||||
raise UPnPError("unknown response fields")
|
raise UPnPError("unknown response fields for %s")
|
||||||
response = body[response_key]
|
response = body[response_key]
|
||||||
extracted_response = tuple([response[n] for n in self.returns])
|
extracted_response = tuple([response[n] for n in self.returns])
|
||||||
if len(extracted_response) == 1:
|
if len(extracted_response) == 1:
|
||||||
|
@ -143,7 +143,7 @@ class SCPDResponse(object):
|
||||||
|
|
||||||
|
|
||||||
class SCPDCommandRunner(object):
|
class SCPDCommandRunner(object):
|
||||||
def __init__(self, gateway, reactor):
|
def __init__(self, gateway, reactor, treq_get=None):
|
||||||
self._gateway = gateway
|
self._gateway = gateway
|
||||||
self._unsupported_actions = {}
|
self._unsupported_actions = {}
|
||||||
self._registered_commands = {}
|
self._registered_commands = {}
|
||||||
|
@ -151,15 +151,15 @@ class SCPDCommandRunner(object):
|
||||||
self._connection_pool = DirtyPool(reactor)
|
self._connection_pool = DirtyPool(reactor)
|
||||||
self._agent = Agent(reactor, connectTimeout=1, pool=self._connection_pool)
|
self._agent = Agent(reactor, connectTimeout=1, pool=self._connection_pool)
|
||||||
self._http_client = HTTPClient(self._agent, data_to_body_producer=StringProducer)
|
self._http_client = HTTPClient(self._agent, data_to_body_producer=StringProducer)
|
||||||
|
self._treq_get = treq_get or treq.get
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _discover_commands(self, service):
|
def _discover_commands(self, service):
|
||||||
scpd_url = self._gateway.base_address + service.SCPDURL.encode()
|
scpd_url = self._gateway.base_address + service.SCPDURL.encode()
|
||||||
response = yield treq.get(scpd_url)
|
response = yield self._treq_get(scpd_url)
|
||||||
content = yield response.content()
|
content = yield response.content()
|
||||||
try:
|
try:
|
||||||
scpd_response = SCPDResponse(scpd_url,
|
scpd_response = SCPDResponse(scpd_url, response.headers, content)
|
||||||
response.headers, content)
|
|
||||||
for action_dict in scpd_response.get_action_list():
|
for action_dict in scpd_response.get_action_list():
|
||||||
self._register_command(action_dict, service.serviceType)
|
self._register_command(action_dict, service.serviceType)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
|
@ -200,7 +200,7 @@ class SCPDCommandRunner(object):
|
||||||
self._gateway.get_service(service_type).serviceType.encode(), name, inputs, outputs)
|
self._gateway.get_service(service_type).serviceType.encode(), name, inputs, outputs)
|
||||||
current = getattr(self, command.method)
|
current = getattr(self, command.method)
|
||||||
if hasattr(current, "_return_types"):
|
if hasattr(current, "_return_types"):
|
||||||
command._process_result = _return_types(*current._return_types)(command._process_result)
|
command._process_result = verify_return_types(*current._return_types)(command._process_result)
|
||||||
setattr(command, "__doc__", current.__doc__)
|
setattr(command, "__doc__", current.__doc__)
|
||||||
setattr(self, command.method, command)
|
setattr(self, command.method, command)
|
||||||
self._registered_commands[command.method] = service_type
|
self._registered_commands[command.method] = service_type
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
|
import netifaces
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from txupnp.util import get_lan_info
|
|
||||||
from txupnp.ssdp import SSDPFactory
|
from txupnp.ssdp import SSDPFactory
|
||||||
from txupnp.scpd import SCPDCommandRunner
|
from txupnp.scpd import SCPDCommandRunner
|
||||||
from txupnp.gateway import Gateway
|
from txupnp.gateway import Gateway
|
||||||
|
@ -11,12 +11,14 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SOAPServiceManager(object):
|
class SOAPServiceManager(object):
|
||||||
def __init__(self, reactor):
|
def __init__(self, reactor, treq_get=None):
|
||||||
self._reactor = reactor
|
self._reactor = reactor
|
||||||
self.iface_name, self.router_ip, self.lan_address = get_lan_info()
|
self.router_ip, self.iface_name = netifaces.gateways()['default'][netifaces.AF_INET]
|
||||||
|
self.lan_address = netifaces.ifaddresses(self.iface_name)[netifaces.AF_INET][0]['addr']
|
||||||
self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip)
|
self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip)
|
||||||
self._command_runners = {}
|
self._command_runners = {}
|
||||||
self._selected_runner = UPNP_ORG_IGD
|
self._selected_runner = UPNP_ORG_IGD
|
||||||
|
self._treq_get = treq_get
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def discover_services(self, address=None, timeout=30, max_devices=1):
|
def discover_services(self, address=None, timeout=30, max_devices=1):
|
||||||
|
@ -29,7 +31,7 @@ class SOAPServiceManager(object):
|
||||||
locations.append(server_info['location'])
|
locations.append(server_info['location'])
|
||||||
gateway = Gateway(**server_info)
|
gateway = Gateway(**server_info)
|
||||||
yield gateway.discover_services()
|
yield gateway.discover_services()
|
||||||
command_runner = SCPDCommandRunner(gateway, self._reactor)
|
command_runner = SCPDCommandRunner(gateway, self._reactor, self._treq_get)
|
||||||
yield command_runner.discover_commands()
|
yield command_runner.discover_commands()
|
||||||
self._command_runners[gateway.urn.decode()] = command_runner
|
self._command_runners[gateway.urn.decode()] = command_runner
|
||||||
elif 'st' not in server_info:
|
elif 'st' not in server_info:
|
||||||
|
|
|
@ -31,7 +31,7 @@ class SSDPProtocol(DatagramProtocol):
|
||||||
try:
|
try:
|
||||||
self.transport.write(packet.encode().encode(), (self.ssdp_address, self.ssdp_port))
|
self.transport.write(packet.encode().encode(), (self.ssdp_address, self.ssdp_port))
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
log.exception("failed to write %s to %s:%i", binascii.hexlify(packet.encode()), self.ssdp_address, self.ssdp_port)
|
log.exception("failed to write %s to %s:%i", packet.encode(), self.ssdp_address, self.ssdp_port)
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
|
import binascii
|
||||||
from txupnp.fault import UPnPError
|
from txupnp.fault import UPnPError
|
||||||
from txupnp.constants import line_separator
|
from txupnp.constants import line_separator
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
_ssdp_datagram_patterns = {
|
_ssdp_datagram_patterns = {
|
||||||
'host': (re.compile("^(?i)(host):(.*)$"), str),
|
'host': (re.compile("^(?i)(host):(.*)$"), str),
|
||||||
'st': (re.compile("^(?i)(st):(.*)$"), str),
|
'st': (re.compile("^(?i)(st):(.*)$"), str),
|
||||||
|
@ -135,7 +135,7 @@ class SSDPDatagram(object):
|
||||||
return packet
|
return packet
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _lines_to_content_dict(cls, lines):
|
def _lines_to_content_dict(cls, lines: list) -> dict:
|
||||||
result = {}
|
result = {}
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if not line:
|
if not line:
|
||||||
|
@ -158,7 +158,7 @@ class SSDPDatagram(object):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_string(cls, datagram):
|
def _from_string(cls, datagram: str):
|
||||||
lines = [l for l in datagram.split(line_separator) if l]
|
lines = [l for l in datagram.split(line_separator) if l]
|
||||||
if lines[0] == cls._start_lines[cls._M_SEARCH]:
|
if lines[0] == cls._start_lines[cls._M_SEARCH]:
|
||||||
return cls._from_request(lines[1:])
|
return cls._from_request(lines[1:])
|
||||||
|
|
|
@ -9,10 +9,10 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UPnP(object):
|
class UPnP(object):
|
||||||
def __init__(self, reactor, try_miniupnpc_fallback=True):
|
def __init__(self, reactor, try_miniupnpc_fallback=True, treq_get=None):
|
||||||
self._reactor = reactor
|
self._reactor = reactor
|
||||||
self.try_miniupnpc_fallback = try_miniupnpc_fallback
|
self.try_miniupnpc_fallback = try_miniupnpc_fallback
|
||||||
self.soap_manager = SOAPServiceManager(reactor)
|
self.soap_manager = SOAPServiceManager(reactor, treq_get=treq_get)
|
||||||
self.miniupnpc_runner = None
|
self.miniupnpc_runner = None
|
||||||
self.miniupnpc_igd_url = None
|
self.miniupnpc_igd_url = None
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,14 @@
|
||||||
import re
|
import re
|
||||||
import functools
|
import functools
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import netifaces
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from xml.etree import ElementTree
|
||||||
|
|
||||||
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
|
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
|
||||||
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
|
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
|
||||||
|
|
||||||
|
|
||||||
def etree_to_dict(t):
|
def etree_to_dict(t: ElementTree) -> dict:
|
||||||
d = {t.tag: {} if t.attrib else None}
|
d = {t.tag: {} if t.attrib else None}
|
||||||
children = list(t)
|
children = list(t)
|
||||||
if children:
|
if children:
|
||||||
|
@ -52,14 +52,12 @@ def get_dict_val_case_insensitive(d, k):
|
||||||
return d[match[0]]
|
return d[match[0]]
|
||||||
|
|
||||||
|
|
||||||
def get_lan_info():
|
def verify_return_types(*types):
|
||||||
gateway_address, iface_name = netifaces.gateways()['default'][netifaces.AF_INET]
|
"""
|
||||||
lan_addr = netifaces.ifaddresses(iface_name)[netifaces.AF_INET][0]['addr']
|
Attempt to recast results to expected result types
|
||||||
return iface_name, gateway_address, lan_addr
|
"""
|
||||||
|
|
||||||
|
def _verify_return_types(fn):
|
||||||
def _return_types(*types):
|
|
||||||
def _return_types_wrapper(fn):
|
|
||||||
@functools.wraps(fn)
|
@functools.wraps(fn)
|
||||||
def _inner(response):
|
def _inner(response):
|
||||||
if isinstance(response, (list, tuple)):
|
if isinstance(response, (list, tuple)):
|
||||||
|
@ -69,10 +67,14 @@ def _return_types(*types):
|
||||||
return fn(r)
|
return fn(r)
|
||||||
return fn(types[0](response))
|
return fn(types[0](response))
|
||||||
return _inner
|
return _inner
|
||||||
return _return_types_wrapper
|
return _verify_return_types
|
||||||
|
|
||||||
|
|
||||||
def return_types(*types):
|
def return_types(*types):
|
||||||
|
"""
|
||||||
|
Decorator to set the expected return types of a SOAP function call
|
||||||
|
"""
|
||||||
|
|
||||||
def return_types_wrapper(fn):
|
def return_types_wrapper(fn):
|
||||||
fn._return_types = types
|
fn._return_types = types
|
||||||
return fn
|
return fn
|
||||||
|
|
Loading…
Reference in a new issue