From 4137f7cd8a784f5b0705b35cb6984ec1ce06924c Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Tue, 21 May 2019 18:16:30 -0400 Subject: [PATCH] mypy refactor --- .coveragerc | 4 + .gitignore | 3 + .pylintrc | 440 +++++++++++++++++++++++++ aioupnp/__main__.py | 54 +-- aioupnp/commands.py | 255 +++++++------- aioupnp/device.py | 99 +++--- aioupnp/fault.py | 12 - aioupnp/gateway.py | 231 ++++++++----- aioupnp/interfaces.py | 51 +++ aioupnp/protocols/m_search_patterns.py | 24 +- aioupnp/protocols/multicast.py | 47 ++- aioupnp/protocols/scpd.py | 96 ++++-- aioupnp/protocols/ssdp.py | 129 +++++--- aioupnp/serialization/scpd.py | 48 +-- aioupnp/serialization/soap.py | 85 ++--- aioupnp/serialization/ssdp.py | 190 ++++++----- aioupnp/serialization/xml.py | 81 +++++ aioupnp/upnp.py | 30 +- aioupnp/util.py | 119 +++---- setup.py | 2 +- 20 files changed, 1357 insertions(+), 643 deletions(-) create mode 100644 .coveragerc create mode 100644 .pylintrc create mode 100644 aioupnp/interfaces.py create mode 100644 aioupnp/serialization/xml.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..d3ae8b7 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,4 @@ +[run] +omit = + tests/* + stubs/* diff --git a/.gitignore b/.gitignore index 28fe439..42d95e7 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,9 @@ _trial_temp/ build/ dist/ +html/ +index.html +mypy-html.css .coverage .mypy_cache/ aioupnp.spec diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..c97b3d2 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,440 @@ +[MASTER] + +# Specify a configuration file. +#rcfile= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=CVS,schema + +# Add files or directories matching the regex patterns to the +# blacklist. The regex matches against base names, not paths. +# `\.#.*` - add emacs tmp files to the blacklist +ignore-patterns=\.#.* + +# Pickle collected data for later comparisons. +persistent=yes + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=1 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code +# extension-pkg-whitelist= + +# Allow optimization of some AST trees. This will activate a peephole AST +# optimizer, which will apply various small optimizations. For instance, it can +# be used to obtain the result of joining multiple strings with the addition +# operator. Joining a lot of strings can lead to a maximum recursion error in +# Pylint and this flag can prevent that. It has one side effect, the resulting +# AST will be different than the one from reality. +optimize-ast=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +#enable= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable= + anomalous-backslash-in-string, + arguments-differ, + attribute-defined-outside-init, + bad-continuation, + bare-except, + broad-except, + cell-var-from-loop, + consider-iterating-dictionary, + dangerous-default-value, + duplicate-code, + fixme, + global-statement, + inherit-non-class, + invalid-name, + len-as-condition, + locally-disabled, + logging-not-lazy, + missing-docstring, + no-else-return, + no-init, + no-member, + no-self-use, + protected-access, + redefined-builtin, + redefined-outer-name, + redefined-variable-type, + relative-import, + signature-differs, + super-init-not-called, + too-few-public-methods, + too-many-arguments, + too-many-branches, + too-many-instance-attributes, + too-many-lines, + too-many-locals, + too-many-nested-blocks, + too-many-public-methods, + too-many-return-statements, + too-many-statements, + trailing-newlines, + undefined-loop-variable, + ungrouped-imports, + unnecessary-lambda, + unused-argument, + unused-variable, + wildcard-import, + wrong-import-order, + wrong-import-position, + deprecated-lambda, + simplifiable-if-statement, + unidiomatic-typecheck, + global-at-module-level, + inconsistent-return-statements, + keyword-arg-before-vararg, + assignment-from-no-return, + useless-return, + assignment-from-none, + stop-iteration-return + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". +files-output=no + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=_$|dummy + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging + + +[BASIC] + +# List of builtins function names that should not be used, separated by a comma +bad-functions=map,filter,input + +# Good variable names which should always be accepted, separated by a comma +# allow `d` as its used frequently for deferred callback chains +good-names=i,j,k,ex,Run,_,d + +# Bad variable names which should always be refused, separated by a comma +bad-names=foo,bar,baz,toto,tutu,tata + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# Regular expression matching correct function names +function-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for function names +function-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct variable names +variable-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for variable names +variable-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct constant names +const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Naming hint for constant names +const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Regular expression matching correct attribute names +attr-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for attribute names +attr-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct argument names +argument-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for argument names +argument-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ + +# Naming hint for class attribute names +class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ + +# Naming hint for inline iteration names +inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=[A-Z_][a-zA-Z0-9]+$ + +# Naming hint for class names +class-name-hint=[A-Z_][a-zA-Z0-9]+$ + +# Regular expression matching correct module names +module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Naming hint for module names +module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Regular expression matching correct method names +method-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for method names +method-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + + +[ELIF] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=120 + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + +# List of optional constructs for which whitespace checking is disabled. `dict- +# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. +# `trailing-comma` allows a space between comma and closing bracket: (a, ). +# `empty-line` allows space-only lines. +no-space-check=trailing-comma,dict-separator + +# Maximum number of lines in a module +max-module-lines=1000 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME,XXX,TODO + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[TYPECHECK] + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules=leveldb,distutils +# Ignoring distutils because: https://github.com/PyCQA/pylint/issues/73 + +# List of classes names for which member attributes should not be checked +# (useful for classes with attributes dynamically set). This supports can work +# with qualified names. +# ignored-classes= + + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub,TERMIOS,Bastion,rexec,miniupnpc + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + + +[DESIGN] + +# Maximum number of arguments for function / method +max-args=10 + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore +ignored-argument-names=_.* + +# Maximum number of locals for function / method body +max-locals=15 + +# Maximum number of return / yield for function / method body +max-returns=6 + +# Maximum number of branch for function / method body +max-branches=12 + +# Maximum number of statements in function / method body +max-statements=50 + +# Maximum number of parents for a class (see R0901). +max-parents=8 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of boolean expressions in a if statement +max-bool-expr=5 + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__,__new__,setUp + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=Exception diff --git a/aioupnp/__main__.py b/aioupnp/__main__.py index cd8fae8..abef621 100644 --- a/aioupnp/__main__.py +++ b/aioupnp/__main__.py @@ -1,8 +1,11 @@ -import logging import sys +import asyncio +import logging import textwrap +import typing from collections import OrderedDict from aioupnp.upnp import UPnP +from aioupnp.commands import SOAPCommands log = logging.getLogger("aioupnp") handler = logging.StreamHandler() @@ -16,17 +19,18 @@ base_usage = "\n".join(textwrap.wrap( 100, subsequent_indent=' ', break_long_words=False)) + "\n" -def get_help(command): - fn = getattr(UPnP, command) - params = command + " " + " ".join(["[--%s=<%s>]" % (k, k) for k in fn.__annotations__ if k != 'return']) +def get_help(command: str) -> str: + annotations = UPnP.get_annotations(command) + params = command + " " + " ".join(["[--%s=<%s>]" % (k, str(v)) for k, v in annotations.items() if k != 'return']) return base_usage + "\n".join( textwrap.wrap(params, 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False) ) -def main(argv=None, loop=None): - argv = argv or sys.argv - commands = [n for n in dir(UPnP) if hasattr(getattr(UPnP, n, None), "_cli")] +def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None, + loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> int: + argv = argv or list(sys.argv) + commands = list(SOAPCommands.SOAP_COMMANDS) help_str = "\n".join(textwrap.wrap( " | ".join(commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False )) @@ -41,14 +45,16 @@ def main(argv=None, loop=None): "For help with a specific command:" \ " aioupnp help \n" % (base_usage, help_str) - args = argv[1:] + args: typing.List[str] = [str(arg) for arg in argv[1:]] if args[0] in ['help', '-h', '--help']: if len(args) > 1: if args[1] in commands: - sys.exit(get_help(args[1])) - sys.exit(print(usage)) + print(get_help(args[1])) + return 0 + print(usage) + return 0 - defaults = { + defaults: typing.Dict[str, typing.Union[bool, str, int]] = { 'debug_logging': False, 'interface': 'default', 'gateway_address': '', @@ -57,22 +63,22 @@ def main(argv=None, loop=None): 'unicast': False } - options = OrderedDict() + options: typing.Dict[str, typing.Union[bool, str, int]] = OrderedDict() command = None for arg in args: if arg.startswith("--"): if "=" in arg: k, v = arg.split("=") + options[k.lstrip('--')] = v else: - k, v = arg, True - k = k.lstrip('--') - options[k] = v + options[arg.lstrip('--')] = True else: command = arg break if not command: print("no command given") - sys.exit(print(usage)) + print(usage) + return 0 kwargs = {} for arg in args[len(options)+1:]: if arg.startswith("--"): @@ -81,18 +87,24 @@ def main(argv=None, loop=None): kwargs[k] = v else: break - for k, v in defaults.items(): + for k in defaults: if k not in options: - options[k] = v + options[k] = defaults[k] if options.pop('debug_logging'): log.setLevel(logging.DEBUG) + lan_address: str = str(options.pop('lan_address')) + gateway_address: str = str(options.pop('gateway_address')) + timeout: int = int(options.pop('timeout')) + interface: str = str(options.pop('interface')) + unicast: bool = bool(options.pop('unicast')) + UPnP.run_cli( - command.replace('-', '_'), options, options.pop('lan_address'), options.pop('gateway_address'), - options.pop('timeout'), options.pop('interface'), options.pop('unicast'), kwargs, loop + command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop ) + return 0 if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/aioupnp/commands.py b/aioupnp/commands.py index d202d22..4e514bf 100644 --- a/aioupnp/commands.py +++ b/aioupnp/commands.py @@ -1,64 +1,67 @@ -import logging +import asyncio import time import typing -from typing import Tuple, Union, List +import functools +import logging +from typing import Tuple from aioupnp.protocols.scpd import scpd_post +from aioupnp.device import Service log = logging.getLogger(__name__) -none_or_str = Union[None, str] -return_type_lambas = { - Union[None, str]: lambda x: x if x is not None and str(x).lower() not in ['none', 'nil'] else None -} -def safe_type(t): - if t is typing.Tuple: - return tuple - if t is typing.List: - return list - if t is typing.Dict: - return dict - if t is typing.Set: - return set - return t +def soap_optional_str(x: typing.Optional[str]) -> typing.Optional[str]: + return x if x is not None and str(x).lower() not in ['none', 'nil'] else None -class SOAPCommand: - def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str, - param_types: dict, return_types: dict, param_order: list, return_order: list, loop=None) -> None: - self.gateway_address = gateway_address - self.service_port = service_port - self.control_url = control_url - self.service_id = service_id - self.method = method - self.param_types = param_types - self.param_order = param_order - self.return_types = return_types - self.return_order = return_order - self.loop = loop - self._requests: typing.List = [] +def soap_bool(x: typing.Optional[str]) -> bool: + return False if not x or str(x).lower() in ['false', 'False'] else True - async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]: - if set(kwargs.keys()) != set(self.param_types.keys()): - raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys())) - soap_kwargs = {n: safe_type(self.param_types[n])(kwargs[n]) for n in self.param_types.keys()} + +def recast_single_result(t, result): + if t is bool: + return soap_bool(result) + if t is str: + return soap_optional_str(result) + return t(result) + + +def recast_return(return_annotation, result, result_keys: typing.List[str]): + if return_annotation is None: + return None + if len(result_keys) == 1: + assert len(result_keys) == 1 + single_result = result[result_keys[0]] + return recast_single_result(return_annotation, single_result) + + annotated_args: typing.List[type] = list(return_annotation.__args__) + assert len(annotated_args) == len(result_keys) + recast_results: typing.List[typing.Optional[typing.Union[str, int, bool, bytes]]] = [] + for type_annotation, result_key in zip(annotated_args, result_keys): + recast_results.append(recast_single_result(type_annotation, result[result_key])) + return tuple(recast_results) + + +def soap_command(fn): + @functools.wraps(fn) + async def wrapper(self: 'SOAPCommands', **kwargs): + if not self.is_registered(fn.__name__): + return fn(self, **kwargs) + service = self.get_service(fn.__name__) + assert service.controlURL is not None + assert service.serviceType is not None response, xml_bytes, err = await scpd_post( - self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, - self.service_id, self.loop, **soap_kwargs + service.controlURL, self._base_address.decode(), self._port, fn.__name__, self._registered[service][fn.__name__][0], + service.serviceType.encode(), self._loop, **kwargs ) if err is not None: - self._requests.append((soap_kwargs, xml_bytes, None, err, time.time())) + + self._requests.append((fn.__name__, kwargs, xml_bytes, None, err, time.time())) raise err - if not response: - result = None - else: - recast_result = tuple([safe_type(self.return_types[n])(response.get(n)) for n in self.return_order]) - if len(recast_result) == 1: - result = recast_result[0] - else: - result = recast_result - self._requests.append((soap_kwargs, xml_bytes, result, None, time.time())) + result = recast_return(fn.__annotations__.get('return'), response, self._registered[service][fn.__name__][1]) + self._requests.append((fn.__name__, kwargs, xml_bytes, result, None, time.time())) return result + return wrapper class SOAPCommands: @@ -72,7 +75,7 @@ class SOAPCommands: to their expected types. """ - SOAP_COMMANDS = [ + SOAP_COMMANDS: typing.List[str] = [ 'AddPortMapping', 'GetNATRSIPStatus', 'GetGenericPortMappingEntry', @@ -91,59 +94,63 @@ class SOAPCommands: 'GetTotalPacketsReceived', 'X_GetICSStatistics', 'GetDefaultConnectionService', - 'NewDefaultConnectionService', - 'NewEnabledForInternet', 'SetDefaultConnectionService', 'SetEnabledForInternet', 'GetEnabledForInternet', - 'NewActiveConnectionIndex', 'GetMaximumActiveConnections', 'GetActiveConnections' ] - def __init__(self): - self._registered = set() + def __init__(self, loop: asyncio.AbstractEventLoop, base_address: bytes, port: int) -> None: + self._loop = loop + self._registered: typing.Dict[Service, + typing.Dict[str, typing.Tuple[typing.List[str], typing.List[str]]]] = {} + self._base_address = base_address + self._port = port + self._requests: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes, + typing.Optional[typing.Dict[str, typing.Any]], + typing.Optional[Exception], float]] = [] - def register(self, base_ip: bytes, port: int, name: str, control_url: str, - service_type: bytes, inputs: List, outputs: List, loop=None) -> None: - if name not in self.SOAP_COMMANDS or name in self._registered: + def is_registered(self, name: str) -> bool: + if name not in self.SOAP_COMMANDS: + raise ValueError("unknown command") + for service in self._registered.values(): + if name in service: + return True + return False + + def get_service(self, name: str) -> Service: + if name not in self.SOAP_COMMANDS: + raise ValueError("unknown command") + for service, commands in self._registered.items(): + if name in commands: + return service + raise ValueError(name) + + def register(self, name: str, service: Service, inputs: typing.List[str], outputs: typing.List[str]) -> None: + # control_url: str, service_type: bytes, + if name not in self.SOAP_COMMANDS: raise AttributeError(name) - current = getattr(self, name) - annotations = current.__annotations__ - return_types = annotations.get('return', None) - if return_types: - if hasattr(return_types, '__args__'): - return_types = tuple([return_type_lambas.get(a, a) for a in return_types.__args__]) - elif isinstance(return_types, type): - return_types = (return_types,) - return_types = {r: t for r, t in zip(outputs, return_types)} - param_types = {} - for param_name, param_type in annotations.items(): - if param_name == "return": - continue - param_types[param_name] = param_type - command = SOAPCommand( - base_ip.decode(), port, control_url, service_type, - name, param_types, return_types, inputs, outputs, loop=loop - ) - setattr(command, "__doc__", current.__doc__) - setattr(self, command.method, command) - self._registered.add(command.method) + if self.is_registered(name): + raise AttributeError(f"{name} is already a registered SOAP command") + if service not in self._registered: + self._registered[service] = {} + self._registered[service][name] = inputs, outputs - @staticmethod - async def AddPortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int, - NewInternalClient: str, NewEnabled: int, NewPortMappingDescription: str, - NewLeaseDuration: str) -> None: + @soap_command + async def AddPortMapping(self, NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int, + NewInternalClient: str, NewEnabled: int, NewPortMappingDescription: str, + NewLeaseDuration: str) -> None: """Returns None""" raise NotImplementedError() - @staticmethod - async def GetNATRSIPStatus() -> Tuple[bool, bool]: + @soap_command + async def GetNATRSIPStatus(self) -> Tuple[bool, bool]: """Returns (NewRSIPAvailable, NewNATEnabled)""" raise NotImplementedError() - @staticmethod - async def GetGenericPortMappingEntry(NewPortMappingIndex: int) -> Tuple[str, int, str, int, str, + @soap_command + async def GetGenericPortMappingEntry(self, NewPortMappingIndex: int) -> Tuple[str, int, str, int, str, bool, str, int]: """ Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled, @@ -151,100 +158,100 @@ class SOAPCommands: """ raise NotImplementedError() - @staticmethod - async def GetSpecificPortMappingEntry(NewRemoteHost: str, NewExternalPort: int, + @soap_command + async def GetSpecificPortMappingEntry(self, NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> Tuple[int, str, bool, str, int]: """Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)""" raise NotImplementedError() - @staticmethod - async def SetConnectionType(NewConnectionType: str) -> None: + @soap_command + async def SetConnectionType(self, NewConnectionType: str) -> None: """Returns None""" raise NotImplementedError() - @staticmethod - async def GetExternalIPAddress() -> str: + @soap_command + async def GetExternalIPAddress(self) -> str: """Returns (NewExternalIPAddress)""" raise NotImplementedError() - @staticmethod - async def GetConnectionTypeInfo() -> Tuple[str, str]: + @soap_command + async def GetConnectionTypeInfo(self) -> Tuple[str, str]: """Returns (NewConnectionType, NewPossibleConnectionTypes)""" raise NotImplementedError() - @staticmethod - async def GetStatusInfo() -> Tuple[str, str, int]: + @soap_command + async def GetStatusInfo(self) -> Tuple[str, str, int]: """Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)""" raise NotImplementedError() - @staticmethod - async def ForceTermination() -> None: + @soap_command + async def ForceTermination(self) -> None: """Returns None""" raise NotImplementedError() - @staticmethod - async def DeletePortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> None: + @soap_command + async def DeletePortMapping(self, NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> None: """Returns None""" raise NotImplementedError() - @staticmethod - async def RequestConnection() -> None: + @soap_command + async def RequestConnection(self) -> None: """Returns None""" raise NotImplementedError() - @staticmethod - async def GetCommonLinkProperties(): + @soap_command + async def GetCommonLinkProperties(self) -> Tuple[str, int, int, str]: """Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)""" raise NotImplementedError() - @staticmethod - async def GetTotalBytesSent(): + @soap_command + async def GetTotalBytesSent(self) -> int: """Returns (NewTotalBytesSent)""" raise NotImplementedError() - @staticmethod - async def GetTotalBytesReceived(): + @soap_command + async def GetTotalBytesReceived(self) -> int: """Returns (NewTotalBytesReceived)""" raise NotImplementedError() - @staticmethod - async def GetTotalPacketsSent(): + @soap_command + async def GetTotalPacketsSent(self) -> int: """Returns (NewTotalPacketsSent)""" raise NotImplementedError() - @staticmethod - def GetTotalPacketsReceived(): + @soap_command + async def GetTotalPacketsReceived(self) -> int: """Returns (NewTotalPacketsReceived)""" raise NotImplementedError() - @staticmethod - async def X_GetICSStatistics() -> Tuple[int, int, int, int, str, str]: + @soap_command + async def X_GetICSStatistics(self) -> Tuple[int, int, int, int, str, str]: """Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)""" raise NotImplementedError() - @staticmethod - async def GetDefaultConnectionService(): + @soap_command + async def GetDefaultConnectionService(self) -> str: """Returns (NewDefaultConnectionService)""" raise NotImplementedError() - @staticmethod - async def SetDefaultConnectionService(NewDefaultConnectionService: str) -> None: + @soap_command + async def SetDefaultConnectionService(self, NewDefaultConnectionService: str) -> None: """Returns (None)""" raise NotImplementedError() - @staticmethod - async def SetEnabledForInternet(NewEnabledForInternet: bool) -> None: + @soap_command + async def SetEnabledForInternet(self, NewEnabledForInternet: bool) -> None: raise NotImplementedError() - @staticmethod - async def GetEnabledForInternet() -> bool: + @soap_command + async def GetEnabledForInternet(self) -> bool: raise NotImplementedError() - @staticmethod - async def GetMaximumActiveConnections(NewActiveConnectionIndex: int): + @soap_command + async def GetMaximumActiveConnections(self, NewActiveConnectionIndex: int): raise NotImplementedError() - @staticmethod - async def GetActiveConnections() -> Tuple[str, str]: + @soap_command + async def GetActiveConnections(self) -> Tuple[str, str]: """Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID""" raise NotImplementedError() diff --git a/aioupnp/device.py b/aioupnp/device.py index 8502840..d6d84d4 100644 --- a/aioupnp/device.py +++ b/aioupnp/device.py @@ -1,23 +1,33 @@ +from collections import OrderedDict +import typing import logging -from typing import List log = logging.getLogger(__name__) class CaseInsensitive: - def __init__(self, **kwargs) -> None: - for k, v in kwargs.items(): + def __init__(self, **kwargs: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List[typing.Any]]]) -> None: + keys: typing.List[str] = list(kwargs.keys()) + for k in keys: if not k.startswith("_"): - setattr(self, k, v) + assert k in kwargs + setattr(self, k, kwargs[k]) - def __getattr__(self, item): - for k in self.__class__.__dict__.keys(): + def __getattr__(self, item: str) -> typing.Union[str, typing.Dict[str, typing.Any], typing.List]: + keys: typing.List[str] = list(self.__class__.__dict__.keys()) + for k in keys: if k.lower() == item.lower(): - return self.__dict__.get(k) + value: typing.Optional[typing.Union[str, typing.Dict[str, typing.Any], + typing.List]] = self.__dict__.get(k) + assert value is not None and isinstance(value, (str, dict, list)) + return value raise AttributeError(item) - def __setattr__(self, item, value): - for k, v in self.__class__.__dict__.items(): + def __setattr__(self, item: str, + value: typing.Union[str, typing.Dict[str, typing.Any], typing.List]) -> None: + assert isinstance(value, (str, dict)), ValueError(f"got type {str(type(value))}, expected str") + keys: typing.List[str] = list(self.__class__.__dict__.keys()) + for k in keys: if k.lower() == item.lower(): self.__dict__[k] = value return @@ -26,52 +36,57 @@ class CaseInsensitive: return raise AttributeError(item) - def as_dict(self) -> dict: - return { - k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v) - } + def as_dict(self) -> typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List]]: + result: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List]] = OrderedDict() + keys: typing.List[str] = list(self.__dict__.keys()) + for k in keys: + if not k.startswith("_"): + result[k] = self.__getattr__(k) + return result class Service(CaseInsensitive): - serviceType = None - serviceId = None - controlURL = None - eventSubURL = None - SCPDURL = None + serviceType: typing.Optional[str] = None + serviceId: typing.Optional[str] = None + controlURL: typing.Optional[str] = None + eventSubURL: typing.Optional[str] = None + SCPDURL: typing.Optional[str] = None class Device(CaseInsensitive): - serviceList = None - deviceList = None - deviceType = None - friendlyName = None - manufacturer = None - manufacturerURL = None - modelDescription = None - modelName = None - modelNumber = None - modelURL = None - serialNumber = None - udn = None - upc = None - presentationURL = None - iconList = None + serviceList: typing.Optional[typing.Dict[str, typing.Union[typing.Dict[str, typing.Any], typing.List]]] = None + deviceList: typing.Optional[typing.Dict[str, typing.Union[typing.Dict[str, typing.Any], typing.List]]] = None + deviceType: typing.Optional[str] = None + friendlyName: typing.Optional[str] = None + manufacturer: typing.Optional[str] = None + manufacturerURL: typing.Optional[str] = None + modelDescription: typing.Optional[str] = None + modelName: typing.Optional[str] = None + modelNumber: typing.Optional[str] = None + modelURL: typing.Optional[str] = None + serialNumber: typing.Optional[str] = None + udn: typing.Optional[str] = None + upc: typing.Optional[str] = None + presentationURL: typing.Optional[str] = None + iconList: typing.Optional[str] = None - def __init__(self, devices: List, services: List, **kwargs) -> None: + def __init__(self, devices: typing.List['Device'], services: typing.List[Service], + **kwargs: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List]]) -> None: super(Device, self).__init__(**kwargs) if self.serviceList and "service" in self.serviceList: - new_services = self.serviceList["service"] - if isinstance(new_services, dict): - new_services = [new_services] - services.extend([Service(**service) for service in new_services]) + if isinstance(self.serviceList['service'], dict): + assert isinstance(self.serviceList['service'], dict) + svc_list: typing.Dict[str, typing.Any] = self.serviceList['service'] + services.append(Service(**svc_list)) + elif isinstance(self.serviceList['service'], list): + services.extend(Service(**svc) for svc in self.serviceList["service"]) + if self.deviceList: for kw in self.deviceList.values(): if isinstance(kw, dict): - d = Device(devices, services, **kw) - devices.append(d) + devices.append(Device(devices, services, **kw)) elif isinstance(kw, list): for _inner_kw in kw: - d = Device(devices, services, **_inner_kw) - devices.append(d) + devices.append(Device(devices, services, **_inner_kw)) else: log.warning("failed to parse device:\n%s", kw) diff --git a/aioupnp/fault.py b/aioupnp/fault.py index 6acb914..5aa2630 100644 --- a/aioupnp/fault.py +++ b/aioupnp/fault.py @@ -1,14 +1,2 @@ -from aioupnp.util import flatten_keys -from aioupnp.constants import FAULT, CONTROL - - class UPnPError(Exception): pass - - -def handle_fault(response: dict) -> dict: - if FAULT in response: - fault = flatten_keys(response[FAULT], "{%s}" % CONTROL) - error_description = fault['detail']['UPnPError']['errorDescription'] - raise UPnPError(error_description) - return response diff --git a/aioupnp/gateway.py b/aioupnp/gateway.py index ad04143..1ed4b1e 100644 --- a/aioupnp/gateway.py +++ b/aioupnp/gateway.py @@ -1,9 +1,10 @@ +import re import logging -import socket +import typing import asyncio from collections import OrderedDict -from typing import Dict, List, Union, Type -from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX +from typing import Dict, List, Union +from aioupnp.util import get_dict_val_case_insensitive from aioupnp.constants import SPEC_VERSION, SERVICE from aioupnp.commands import SOAPCommands from aioupnp.device import Device, Service @@ -15,77 +16,94 @@ from aioupnp.fault import UPnPError log = logging.getLogger(__name__) -return_type_lambas = { - Union[None, str]: lambda x: x if x is not None and str(x).lower() not in ['none', 'nil'] else None -} +BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode()) +BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode()) -def get_action_list(element_dict: dict) -> List: # [(, [, ...], [ typing.List[typing.Tuple[str, typing.List[str], typing.List[str]]]: service_info = flatten_keys(element_dict, "{%s}" % SERVICE) + result: typing.List[typing.Tuple[str, typing.List[str], typing.List[str]]] = [] if "actionList" in service_info: action_list = service_info["actionList"] else: - return [] + return result if not len(action_list): # it could be an empty string - return [] + return result - result: list = [] - if isinstance(action_list["action"], dict): - arg_dicts = action_list["action"]['argumentList']['argument'] - if not isinstance(arg_dicts, list): # when there is one arg - arg_dicts = [arg_dicts] - return [[ - action_list["action"]['name'], - [i['name'] for i in arg_dicts if i['direction'] == 'in'], - [i['name'] for i in arg_dicts if i['direction'] == 'out'] - ]] - for action in action_list["action"]: - if not action.get('argumentList'): - result.append((action['name'], [], [])) + action = action_list["action"] + if isinstance(action, dict): + arg_dicts: typing.List[typing.Dict[str, str]] = [] + if not isinstance(action['argumentList']['argument'], list): # when there is one arg + arg_dicts.extend([action['argumentList']['argument']]) else: - arg_dicts = action['argumentList']['argument'] - if not isinstance(arg_dicts, list): # when there is one arg - arg_dicts = [arg_dicts] + arg_dicts.extend(action['argumentList']['argument']) + + result.append((action_list["action"]['name'], [i['name'] for i in arg_dicts if i['direction'] == 'in'], + [i['name'] for i in arg_dicts if i['direction'] == 'out'])) + return result + assert isinstance(action, list) + for _action in action: + if not _action.get('argumentList'): + result.append((_action['name'], [], [])) + else: + if not isinstance(_action['argumentList']['argument'], list): # when there is one arg + arg_dicts = [_action['argumentList']['argument']] + else: + arg_dicts = _action['argumentList']['argument'] result.append(( - action['name'], + _action['name'], [i['name'] for i in arg_dicts if i['direction'] == 'in'], [i['name'] for i in arg_dicts if i['direction'] == 'out'] )) return result +def parse_location(location: bytes) -> typing.Tuple[bytes, int]: + base_address_result: typing.List[bytes] = BASE_ADDRESS_REGEX.findall(location) + base_address = base_address_result[0] + port_result: typing.List[bytes] = BASE_PORT_REGEX.findall(location) + port = int(port_result[0]) + return base_address, port + + class Gateway: - def __init__(self, ok_packet: SSDPDatagram, m_search_args: OrderedDict, lan_address: str, - gateway_address: str) -> None: + commands: SOAPCommands + + def __init__(self, ok_packet: SSDPDatagram, m_search_args: typing.Dict[str, typing.Union[int, str]], + lan_address: str, gateway_address: str, + loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None: + self._loop = loop or asyncio.get_event_loop() self._ok_packet = ok_packet self._m_search_args = m_search_args self._lan_address = lan_address - self.usn = (ok_packet.usn or '').encode() - self.ext = (ok_packet.ext or '').encode() - self.server = (ok_packet.server or '').encode() - self.location = (ok_packet.location or '').encode() - self.cache_control = (ok_packet.cache_control or '').encode() - self.date = (ok_packet.date or '').encode() - self.urn = (ok_packet.st or '').encode() + self.usn: bytes = (ok_packet.usn or '').encode() + self.ext: bytes = (ok_packet.ext or '').encode() + self.server: bytes = (ok_packet.server or '').encode() + self.location: bytes = (ok_packet.location or '').encode() + self.cache_control: bytes = (ok_packet.cache_control or '').encode() + self.date: bytes = (ok_packet.date or '').encode() + self.urn: bytes = (ok_packet.st or '').encode() - self._xml_response = b"" - self._service_descriptors: Dict = {} - self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0] - self.port = int(BASE_PORT_REGEX.findall(self.location)[0]) + self._xml_response: bytes = b"" + self._service_descriptors: Dict[str, bytes] = {} + + self.base_address, self.port = parse_location(self.location) self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0] assert self.base_ip == gateway_address.encode() self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1] - self.spec_version = None - self.url_base = None + self.spec_version: typing.Optional[str] = None + self.url_base: typing.Optional[str] = None - self._device: Union[None, Device] = None - self._devices: List = [] - self._services: List = [] + self._device: typing.Optional[Device] = None + self._devices: List[Device] = [] + self._services: List[Service] = [] - self._unsupported_actions: Dict = {} - self._registered_commands: Dict = {} - self.commands = SOAPCommands() + self._unsupported_actions: Dict[str, typing.List[str]] = {} + self._registered_commands: Dict[str, str] = {} + self.commands = SOAPCommands(self._loop, self.base_ip, self.port) def gateway_descriptor(self) -> dict: r = { @@ -102,14 +120,15 @@ class Gateway: def manufacturer_string(self) -> str: if not self.devices: return "UNKNOWN GATEWAY" - device = list(self.devices.values())[0] - return "%s %s" % (device.manufacturer, device.modelName) + devices: typing.List[Device] = list(self.devices.values()) + device = devices[0] + return f"{device.manufacturer} {device.modelName}" @property - def services(self) -> Dict: + def services(self) -> Dict[str, Service]: if not self._device: return {} - return {service.serviceType: service for service in self._services} + return {str(service.serviceType): service for service in self._services} @property def devices(self) -> Dict: @@ -117,28 +136,29 @@ class Gateway: return {} return {device.udn: device for device in self._devices} - def get_service(self, service_type: str) -> Union[Type[Service], None]: + def get_service(self, service_type: str) -> typing.Optional[Service]: for service in self._services: - if service.serviceType.lower() == service_type.lower(): + if service.serviceType and service.serviceType.lower() == service_type.lower(): return service return None @property - def soap_requests(self) -> List: - soap_call_infos = [] - for name in self._registered_commands.keys(): - if not hasattr(getattr(self.commands, name), "_requests"): - continue - soap_call_infos.extend([ - (name, request_args, raw_response, decoded_response, soap_error, ts) - for ( - request_args, raw_response, decoded_response, soap_error, ts - ) in getattr(self.commands, name)._requests - ]) + def soap_requests(self) -> typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes, + typing.Optional[typing.Dict[str, typing.Any]], + typing.Optional[Exception], float]]: + soap_call_infos: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes, + typing.Optional[typing.Dict[str, typing.Any]], + typing.Optional[Exception], float]] = [] + soap_call_infos.extend([ + (name, request_args, raw_response, decoded_response, soap_error, ts) + for ( + name, request_args, raw_response, decoded_response, soap_error, ts + ) in self.commands._requests + ]) soap_call_infos.sort(key=lambda x: x[5]) return soap_call_infos - def debug_gateway(self) -> Dict: + def debug_gateway(self) -> Dict[str, Union[str, bytes, int, Dict, List]]: return { 'manufacturer_string': self.manufacturer_string, 'gateway_address': self.base_ip, @@ -156,9 +176,11 @@ class Gateway: @classmethod async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, - igd_args: OrderedDict = None, loop=None, unicast: bool = False): - ignored: set = set() - required_commands = [ + igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None, + loop: typing.Optional[asyncio.AbstractEventLoop] = None, + unicast: bool = False) -> 'Gateway': + ignored: typing.Set[str] = set() + required_commands: typing.List[str] = [ 'AddPortMapping', 'DeletePortMapping', 'GetExternalIPAddress' @@ -166,20 +188,21 @@ class Gateway: while True: if not igd_args: m_search_args, datagram = await fuzzy_m_search( - lan_address, gateway_address, timeout, loop, ignored, unicast + lan_address, gateway_address, timeout, loop, ignored, unicast ) else: m_search_args = OrderedDict(igd_args) datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, ignored, unicast) try: - gateway = cls(datagram, m_search_args, lan_address, gateway_address) + gateway = cls(datagram, m_search_args, lan_address, gateway_address, loop=loop) log.debug('get gateway descriptor %s', datagram.location) await gateway.discover_commands(loop) - requirements_met = all([required in gateway._registered_commands for required in required_commands]) + requirements_met = all([gateway.commands.is_registered(required) for required in required_commands]) if not requirements_met: not_met = [ - required for required in required_commands if required not in gateway._registered_commands + required for required in required_commands if not gateway.commands.is_registered(required) ] + assert datagram.location is not None log.debug("found gateway %s at %s, but it does not implement required soap commands: %s", gateway.manufacturer_string, gateway.location, not_met) ignored.add(datagram.location) @@ -188,13 +211,17 @@ class Gateway: log.debug('found gateway device %s', datagram.location) return gateway except (asyncio.TimeoutError, UPnPError) as err: + assert datagram.location is not None log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err)) ignored.add(datagram.location) continue @classmethod async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, - igd_args: OrderedDict = None, loop=None, unicast: bool = None): + igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None, + loop: typing.Optional[asyncio.AbstractEventLoop] = None, + unicast: typing.Optional[bool] = None) -> 'Gateway': + loop = loop or asyncio.get_event_loop() if unicast is not None: return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast) @@ -205,7 +232,7 @@ class Gateway: cls._discover_gateway( lan_address, gateway_address, timeout, igd_args, loop, unicast=False ) - ], return_when=asyncio.tasks.FIRST_COMPLETED) + ], return_when=asyncio.tasks.FIRST_COMPLETED, loop=loop) for task in pending: task.cancel() @@ -214,56 +241,78 @@ class Gateway: task.exception() except asyncio.CancelledError: pass + results: typing.List[asyncio.Future['Gateway']] = list(done) + return results[0].result() - return list(done)[0].result() - - async def discover_commands(self, loop=None): + async def discover_commands(self, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None: response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port, loop=loop) self._xml_response = xml_bytes if get_err is not None: raise get_err - self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION) - self.url_base = get_dict_val_case_insensitive(response, "urlbase") + spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION) + if isinstance(spec_version, bytes): + self.spec_version = spec_version.decode() + else: + self.spec_version = spec_version + url_base = get_dict_val_case_insensitive(response, "urlbase") + if isinstance(url_base, bytes): + self.url_base = url_base.decode() + else: + self.url_base = url_base if not self.url_base: self.url_base = self.base_address.decode() if response: - device_dict = get_dict_val_case_insensitive(response, "device") + source_keys: typing.List[str] = list(response.keys()) + matches: typing.List[str] = list(filter(lambda x: x.lower() == "device", source_keys)) + match_key = matches[0] + match: dict = response[match_key] + # if not len(match): + # return None + # if len(match) > 1: + # raise KeyError("overlapping keys") + # if len(match) == 1: + # matched_key: typing.AnyStr = match[0] + # return source[matched_key] + # raise KeyError("overlapping keys") + self._device = Device( - self._devices, self._services, **device_dict + self._devices, self._services, **match ) else: self._device = Device(self._devices, self._services) for service_type in self.services.keys(): await self.register_commands(self.services[service_type], loop) + return None - async def register_commands(self, service: Service, loop=None): + async def register_commands(self, service: Service, + loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None: if not service.SCPDURL: raise UPnPError("no scpd url") + if not service.serviceType: + raise UPnPError("no service type") log.debug("get descriptor for %s from %s", service.serviceType, service.SCPDURL) - service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port) + service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port, loop=loop) self._service_descriptors[service.SCPDURL] = xml_bytes if get_err is not None: log.debug("failed to get descriptor for %s from %s", service.serviceType, service.SCPDURL) if xml_bytes: log.debug("response: %s", xml_bytes.decode()) - return + return None if not service_dict: - return + return None action_list = get_action_list(service_dict) - for name, inputs, outputs in action_list: try: - self.commands.register(self.base_ip, self.port, name, service.controlURL, service.serviceType.encode(), - inputs, outputs, loop) + self.commands.register(name, service, inputs, outputs) self._registered_commands[name] = service.serviceType log.debug("registered %s::%s", service.serviceType, name) except AttributeError: - s = self._unsupported_actions.get(service.serviceType, []) - s.append(name) - self._unsupported_actions[service.serviceType] = s + self._unsupported_actions.setdefault(service.serviceType, []) + self._unsupported_actions[service.serviceType].append(name) log.debug("available command for %s does not have a wrapper implemented: %s %s %s", service.serviceType, name, inputs, outputs) log.debug("registered service %s", service.serviceType) + return None diff --git a/aioupnp/interfaces.py b/aioupnp/interfaces.py new file mode 100644 index 0000000..78f11ad --- /dev/null +++ b/aioupnp/interfaces.py @@ -0,0 +1,51 @@ +import socket +from collections import OrderedDict +import typing +import netifaces + + +def get_netifaces(): + return netifaces + + +def ifaddresses(iface: str): + return get_netifaces().ifaddresses(iface) + + +def _get_interfaces(): + return get_netifaces().interfaces() + + +def _get_gateways(): + return get_netifaces().gateways() + + +def get_interfaces() -> typing.Dict[str, typing.Tuple[str, str]]: + gateways = _get_gateways() + infos = gateways[socket.AF_INET] + assert isinstance(infos, list), TypeError(f"expected list from netifaces, got a dict") + interface_infos: typing.List[typing.Tuple[str, str, bool]] = infos + result: typing.Dict[str, typing.Tuple[str, str]] = OrderedDict( + (interface_name, (router_address, ifaddresses(interface_name)[netifaces.AF_INET][0]['addr'])) + for router_address, interface_name, _ in interface_infos + ) + for interface_name in _get_interfaces(): + if interface_name in ['lo', 'localhost'] or interface_name in result: + continue + addresses = ifaddresses(interface_name) + if netifaces.AF_INET in addresses: + address = addresses[netifaces.AF_INET][0]['addr'] + gateway_guess = ".".join(address.split(".")[:-1] + ["1"]) + result[interface_name] = (gateway_guess, address) + _default = gateways['default'] + assert isinstance(_default, dict), TypeError(f"expected dict from netifaces, got a list") + default: typing.Dict[int, typing.Tuple[str, str]] = _default + result['default'] = result[default[netifaces.AF_INET][1]] + return result + + +def get_gateway_and_lan_addresses(interface_name: str) -> typing.Tuple[str, str]: + for iface_name, (gateway, lan) in get_interfaces().items(): + if interface_name == iface_name: + return gateway, lan + return '', '' diff --git a/aioupnp/protocols/m_search_patterns.py b/aioupnp/protocols/m_search_patterns.py index 722e5e3..5aed709 100644 --- a/aioupnp/protocols/m_search_patterns.py +++ b/aioupnp/protocols/m_search_patterns.py @@ -44,10 +44,11 @@ ST characters in the domain name must be replaced with hyphens in accordance with RFC 2141. """ +import typing from collections import OrderedDict from aioupnp.constants import SSDP_DISCOVER, SSDP_HOST -SEARCH_TARGETS = [ +SEARCH_TARGETS: typing.List[str] = [ 'upnp:rootdevice', 'urn:schemas-upnp-org:device:InternetGatewayDevice:1', 'urn:schemas-wifialliance-org:device:WFADevice:1', @@ -58,7 +59,8 @@ SEARCH_TARGETS = [ ] -def format_packet_args(order: list, **kwargs): +def format_packet_args(order: typing.List[str], + kwargs: typing.Dict[str, typing.Union[int, str]]) -> typing.Dict[str, typing.Union[int, str]]: args = [] for o in order: for k, v in kwargs.items(): @@ -68,18 +70,18 @@ def format_packet_args(order: list, **kwargs): return OrderedDict(args) -def packet_generator(): +def packet_generator() -> typing.Iterator[typing.Dict[str, typing.Union[int, str]]]: for st in SEARCH_TARGETS: order = ["HOST", "MAN", "MX", "ST"] - yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st) - yield format_packet_args(order, Host=SSDP_HOST, Man='"%s"' % SSDP_DISCOVER, MX=1, ST=st) - yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st) - yield format_packet_args(order, Host=SSDP_HOST, Man=SSDP_DISCOVER, MX=1, ST=st) + yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st}) + yield format_packet_args(order, {'Host': SSDP_HOST, 'Man': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st}) + yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': SSDP_DISCOVER, 'MX': 1, 'ST': st}) + yield format_packet_args(order, {'Host': SSDP_HOST, 'Man': SSDP_DISCOVER, 'MX': 1, 'ST': st}) order = ["HOST", "MAN", "ST", "MX"] - yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st) - yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st) + yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st}) + yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': SSDP_DISCOVER, 'MX': 1, 'ST': st}) order = ["HOST", "ST", "MAN", "MX"] - yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st) - yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st) + yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st}) + yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': SSDP_DISCOVER, 'MX': 1, 'ST': st}) diff --git a/aioupnp/protocols/multicast.py b/aioupnp/protocols/multicast.py index 2068dc6..452b304 100644 --- a/aioupnp/protocols/multicast.py +++ b/aioupnp/protocols/multicast.py @@ -1,45 +1,70 @@ import struct import socket +import typing from asyncio.protocols import DatagramProtocol -from asyncio.transports import DatagramTransport +from asyncio.transports import BaseTransport +from unittest import mock + + +def _get_sock(transport: typing.Optional[BaseTransport]) -> typing.Optional[socket.socket]: + if transport is None or not hasattr(transport, "_extra"): + return None + sock: typing.Optional[socket.socket] = transport.get_extra_info('socket', None) + assert sock is None or isinstance(sock, socket.SocketType) or isinstance(sock, mock.MagicMock) + return sock class MulticastProtocol(DatagramProtocol): def __init__(self, multicast_address: str, bind_address: str) -> None: self.multicast_address = multicast_address self.bind_address = bind_address - self.transport: DatagramTransport + self.transport: typing.Optional[BaseTransport] = None @property - def sock(self) -> socket.socket: - s: socket.socket = self.transport.get_extra_info(name='socket') - return s + def sock(self) -> typing.Optional[socket.socket]: + return _get_sock(self.transport) def get_ttl(self) -> int: - return self.sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL) + sock = self.sock + if not sock: + raise ValueError("not connected") + return sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL) def set_ttl(self, ttl: int = 1) -> None: - self.sock.setsockopt( + sock = self.sock + if not sock: + return None + sock.setsockopt( socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl) ) + return None def join_group(self, multicast_address: str, bind_address: str) -> None: - self.sock.setsockopt( + sock = self.sock + if not sock: + return None + sock.setsockopt( socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(multicast_address) + socket.inet_aton(bind_address) ) + return None def leave_group(self, multicast_address: str, bind_address: str) -> None: - self.sock.setsockopt( + sock = self.sock + if not sock: + raise ValueError("not connected") + sock.setsockopt( socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, socket.inet_aton(multicast_address) + socket.inet_aton(bind_address) ) + return None - def connection_made(self, transport) -> None: + def connection_made(self, transport: BaseTransport) -> None: self.transport = transport + return None @classmethod - def create_multicast_socket(cls, bind_address: str): + def create_multicast_socket(cls, bind_address: str) -> socket.socket: sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind((bind_address, 0)) diff --git a/aioupnp/protocols/scpd.py b/aioupnp/protocols/scpd.py index 5b61212..3543ba8 100644 --- a/aioupnp/protocols/scpd.py +++ b/aioupnp/protocols/scpd.py @@ -2,7 +2,6 @@ import logging import typing import re from collections import OrderedDict -from xml.etree import ElementTree import asyncio from asyncio.protocols import Protocol from aioupnp.fault import UPnPError @@ -18,16 +17,22 @@ log = logging.getLogger(__name__) HTTP_CODE_REGEX = re.compile(b"^HTTP[\/]{0,1}1\.[1|0] (\d\d\d)(.*)$") -def parse_headers(response: bytes) -> typing.Tuple[OrderedDict, int, bytes]: +def parse_http_response_code(http_response: bytes) -> typing.Tuple[bytes, bytes]: + parsed: typing.List[typing.Tuple[bytes, bytes]] = HTTP_CODE_REGEX.findall(http_response) + return parsed[0] + + +def parse_headers(response: bytes) -> typing.Tuple[typing.Dict[bytes, bytes], int, bytes]: lines = response.split(b'\r\n') - headers = OrderedDict([ + headers: typing.Dict[bytes, bytes] = OrderedDict([ (l.split(b':')[0], b':'.join(l.split(b':')[1:]).lstrip(b' ').rstrip(b' ')) for l in response.split(b'\r\n') ]) if len(lines) != len(headers): raise ValueError("duplicate headers") - http_response = tuple(headers.keys())[0] - response_code, message = HTTP_CODE_REGEX.findall(http_response)[0] + header_keys: typing.List[bytes] = list(headers.keys()) + http_response = header_keys[0] + response_code, message = parse_http_response_code(http_response) del headers[http_response] return headers, int(response_code), message @@ -40,37 +45,42 @@ class SCPDHTTPClientProtocol(Protocol): and devices respond with an invalid HTTP version line """ - def __init__(self, message: bytes, finished: asyncio.Future, soap_method: str=None, - soap_service_id: str=None) -> None: + def __init__(self, message: bytes, finished: 'asyncio.Future[typing.Tuple[bytes, int, bytes]]', + soap_method: typing.Optional[str] = None, soap_service_id: typing.Optional[str] = None) -> None: self.message = message self.response_buff = b"" self.finished = finished self.soap_method = soap_method self.soap_service_id = soap_service_id - - self._response_code: int = 0 - self._response_msg: bytes = b"" - self._content_length: int = 0 + self._response_code = 0 + self._response_msg = b"" + self._content_length = 0 self._got_headers = False - self._headers: dict = {} + self._headers: typing.Dict[bytes, bytes] = {} self._body = b"" + self.transport: typing.Optional[asyncio.WriteTransport] = None - def connection_made(self, transport): - transport.write(self.message) + def connection_made(self, transport: asyncio.BaseTransport) -> None: + assert isinstance(transport, asyncio.WriteTransport) + self.transport = transport + self.transport.write(self.message) + return None - def data_received(self, data): + def data_received(self, data: bytes) -> None: self.response_buff += data for i, line in enumerate(self.response_buff.split(b'\r\n')): if not line: # we hit the blank line between the headers and the body if i == (len(self.response_buff.split(b'\r\n')) - 1): - return # the body is still yet to be written + return None # the body is still yet to be written if not self._got_headers: self._headers, self._response_code, self._response_msg = parse_headers( b'\r\n'.join(self.response_buff.split(b'\r\n')[:i]) ) - content_length = get_dict_val_case_insensitive(self._headers, b'Content-Length') + content_length = get_dict_val_case_insensitive( + self._headers, b'Content-Length' + ) if content_length is None: - return + return None self._content_length = int(content_length or 0) self._got_headers = True body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:]) @@ -86,21 +96,28 @@ class SCPDHTTPClientProtocol(Protocol): ) ) ) - return + return None + return None -async def scpd_get(control_url: str, address: str, port: int, loop=None) -> typing.Tuple[typing.Dict, bytes, - typing.Optional[Exception]]: - loop = loop or asyncio.get_event_loop_policy().get_event_loop() - finished: asyncio.Future = asyncio.Future() +async def scpd_get(control_url: str, address: str, port: int, + loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> typing.Tuple[typing.Dict[str, typing.Any], bytes, + typing.Optional[Exception]]: + loop = loop or asyncio.get_event_loop() packet = serialize_scpd_get(control_url, address) - transport, protocol = await loop.create_connection( - lambda : SCPDHTTPClientProtocol(packet, finished), address, port + finished: asyncio.Future[typing.Tuple[bytes, int, bytes]] = asyncio.Future(loop=loop) + proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished) + connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection( + proto_factory, address, port ) + protocol = connect_tup[1] + transport = connect_tup[0] assert isinstance(protocol, SCPDHTTPClientProtocol) + error = None + wait_task: typing.Awaitable[typing.Tuple[bytes, int, bytes]] = asyncio.wait_for(protocol.finished, 1.0, loop=loop) try: - body, response_code, response_msg = await asyncio.wait_for(finished, 1.0) + body, response_code, response_msg = await wait_task except asyncio.TimeoutError: error = UPnPError("get request timed out") body = b'' @@ -112,24 +129,31 @@ async def scpd_get(control_url: str, address: str, port: int, loop=None) -> typi if not error: try: return deserialize_scpd_get_response(body), body, None - except ElementTree.ParseError as err: + except Exception as err: error = UPnPError(err) + return {}, body, error async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes, - loop=None, **kwargs) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]: - loop = loop or asyncio.get_event_loop_policy().get_event_loop() - finished: asyncio.Future = asyncio.Future() + loop: typing.Optional[asyncio.AbstractEventLoop] = None, + **kwargs: typing.Dict[str, typing.Any] + ) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]: + loop = loop or asyncio.get_event_loop() + finished: asyncio.Future[typing.Tuple[bytes, int, bytes]] = asyncio.Future(loop=loop) packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs) - transport, protocol = await loop.create_connection( - lambda : SCPDHTTPClientProtocol( - packet, finished, soap_method=method, soap_service_id=service_id.decode(), - ), address, port + proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\ + SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode()) + connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection( + proto_factory, address, port ) + protocol = connect_tup[1] + transport = connect_tup[0] assert isinstance(protocol, SCPDHTTPClientProtocol) + try: - body, response_code, response_msg = await asyncio.wait_for(finished, 1.0) + wait_task: typing.Awaitable[typing.Tuple[bytes, int, bytes]] = asyncio.wait_for(finished, 1.0, loop=loop) + body, response_code, response_msg = await wait_task except asyncio.TimeoutError: return {}, b'', UPnPError("Timeout") except UPnPError as err: @@ -140,5 +164,5 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para return ( deserialize_soap_post_response(body, method, service_id.decode()), body, None ) - except (ElementTree.ParseError, UPnPError) as err: + except Exception as err: return {}, body, UPnPError(err) diff --git a/aioupnp/protocols/ssdp.py b/aioupnp/protocols/ssdp.py index 7a4e694..a43e117 100644 --- a/aioupnp/protocols/ssdp.py +++ b/aioupnp/protocols/ssdp.py @@ -3,8 +3,8 @@ import binascii import asyncio import logging import typing +import socket from collections import OrderedDict -from asyncio.futures import Future from asyncio.transports import DatagramTransport from aioupnp.fault import UPnPError from aioupnp.serialization.ssdp import SSDPDatagram @@ -18,32 +18,48 @@ log = logging.getLogger(__name__) class SSDPProtocol(MulticastProtocol): - def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Set[str] = None, - unicast: bool = False) -> None: + def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Optional[typing.Set[str]] = None, + unicast: bool = False, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None: super().__init__(multicast_address, lan_address) + self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() + self.transport: typing.Optional[DatagramTransport] = None self._unicast = unicast self._ignored: typing.Set[str] = ignored or set() # ignored locations - self._pending_searches: typing.List[typing.Tuple[str, str, Future, asyncio.Handle]] = [] - self.notifications: typing.List = [] + self._pending_searches: typing.List[typing.Tuple[str, str, asyncio.Future[SSDPDatagram], asyncio.Handle]] = [] + self.notifications: typing.List[SSDPDatagram] = [] + self.connected = asyncio.Event(loop=self.loop) - def disconnect(self): + def connection_made(self, transport) -> None: + # assert isinstance(transport, asyncio.DatagramTransport), str(type(transport)) + super().connection_made(transport) + self.connected.set() + + def disconnect(self) -> None: if self.transport: + try: + self.leave_group(self.multicast_address, self.bind_address) + except ValueError: + pass + except Exception: + log.exception("unexpected error leaving multicast group") self.transport.close() + self.connected.clear() while self._pending_searches: pending = self._pending_searches.pop()[2] if not pending.cancelled() and not pending.done(): pending.cancel() + return None def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None: if packet.location in self._ignored: - return - tmp: typing.List = [] - set_futures: typing.List = [] - while self._pending_searches: - t: tuple = self._pending_searches.pop() - a, s = t[0], t[1] - if (address == a) and (s in [packet.st, "upnp:rootdevice"]): - f: Future = t[2] + return None + # TODO: fix this + tmp: typing.List[typing.Tuple[str, str, asyncio.Future[SSDPDatagram], asyncio.Handle]] = [] + set_futures: typing.List[asyncio.Future[SSDPDatagram]] = [] + while len(self._pending_searches): + t: typing.Tuple[str, str, asyncio.Future[SSDPDatagram], asyncio.Handle] = self._pending_searches.pop() + if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]): + f: asyncio.Future[SSDPDatagram] = t[2] if f not in set_futures: set_futures.append(f) if not f.done(): @@ -52,38 +68,41 @@ class SSDPProtocol(MulticastProtocol): tmp.append(t) while tmp: self._pending_searches.append(tmp.pop()) + return None - def send_many_m_searches(self, address: str, packets: typing.List[SSDPDatagram]): + def _send_m_search(self, address: str, packet: SSDPDatagram) -> None: dest = address if self._unicast else SSDP_IP_ADDRESS - for packet in packets: - log.debug("send m search to %s: %s", dest, packet.st) - self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT)) + if not self.transport: + raise UPnPError("SSDP transport not connected") + log.debug("send m search to %s: %s", dest, packet.st) + self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT)) + return None - async def m_search(self, address: str, timeout: float, datagrams: typing.List[OrderedDict]) -> SSDPDatagram: - fut: Future = Future() - packets: typing.List[SSDPDatagram] = [] + async def m_search(self, address: str, timeout: float, + datagrams: typing.List[typing.Dict[str, typing.Union[str, int]]]) -> SSDPDatagram: + fut: asyncio.Future[SSDPDatagram] = asyncio.Future(loop=self.loop) for datagram in datagrams: - packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram) + packet = SSDPDatagram("M-SEARCH", datagram) assert packet.st is not None - self._pending_searches.append((address, packet.st, fut)) - packets.append(packet) - self.send_many_m_searches(address, packets), - return await fut + self._pending_searches.append( + (address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet)) + ) + return await asyncio.wait_for(fut, timeout) - def datagram_received(self, data, addr) -> None: + def datagram_received(self, data: bytes, addr: typing.Tuple[str, int]) -> None: # type: ignore if addr[0] == self.bind_address: - return + return None try: packet = SSDPDatagram.decode(data) log.debug("decoded packet from %s:%i: %s", addr[0], addr[1], packet) except UPnPError as err: log.error("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err, binascii.hexlify(data)) - return + return None if packet._packet_type == packet._OK: self._callback_m_search_ok(addr[0], packet) - return + return None # elif packet._packet_type == packet._NOTIFY: # log.debug("%s:%i sent us a notification: %s", packet) # if packet.nt == SSDP_ROOT_DEVICE: @@ -104,17 +123,18 @@ class SSDPProtocol(MulticastProtocol): # return -async def listen_ssdp(lan_address: str, gateway_address: str, loop=None, - ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[DatagramTransport, - SSDPProtocol, str, str]: - loop = loop or asyncio.get_event_loop_policy().get_event_loop() +async def listen_ssdp(lan_address: str, gateway_address: str, loop: typing.Optional[asyncio.AbstractEventLoop] = None, + ignored: typing.Optional[typing.Set[str]] = None, + unicast: bool = False) -> typing.Tuple[SSDPProtocol, str, str]: + loop = loop or asyncio.get_event_loop() try: - sock = SSDPProtocol.create_multicast_socket(lan_address) - listen_result: typing.Tuple = await loop.create_datagram_endpoint( + sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address) + listen_result: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint( lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock ) - transport: DatagramTransport = listen_result[0] - protocol: SSDPProtocol = listen_result[1] + transport = listen_result[0] + protocol = listen_result[1] + assert isinstance(protocol, SSDPProtocol) except Exception as err: print(err) raise UPnPError(err) @@ -125,30 +145,31 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop=None, protocol.disconnect() raise UPnPError(err) - return transport, protocol, gateway_address, lan_address + return protocol, gateway_address, lan_address -async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1, - loop=None, ignored: typing.Set[str] = None, - unicast: bool = False) -> SSDPDatagram: - transport, protocol, gateway_address, lan_address = await listen_ssdp( +async def m_search(lan_address: str, gateway_address: str, datagram_args: typing.Dict[str, typing.Union[int, str]], + timeout: int = 1, loop: typing.Optional[asyncio.AbstractEventLoop] = None, + ignored: typing.Set[str] = None, unicast: bool = False) -> SSDPDatagram: + protocol, gateway_address, lan_address = await listen_ssdp( lan_address, gateway_address, loop, ignored, unicast ) try: - return await asyncio.wait_for( - protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]), timeout - ) + return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]) except (asyncio.TimeoutError, asyncio.CancelledError): raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) finally: protocol.disconnect() -async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, loop=None, - ignored: typing.Set[str] = None, unicast: bool = False) -> typing.List[OrderedDict]: - transport, protocol, gateway_address, lan_address = await listen_ssdp( +async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, + loop: typing.Optional[asyncio.AbstractEventLoop] = None, + ignored: typing.Set[str] = None, + unicast: bool = False) -> typing.List[typing.Dict[str, typing.Union[int, str]]]: + protocol, gateway_address, lan_address = await listen_ssdp( lan_address, gateway_address, loop, ignored, unicast ) + await protocol.connected.wait() packet_args = list(packet_generator()) batch_size = 2 batch_timeout = float(timeout) / float(len(packet_args)) @@ -157,7 +178,7 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = packet_args = packet_args[batch_size:] log.debug("sending batch of %i M-SEARCH attempts", batch_size) try: - await asyncio.wait_for(protocol.m_search(gateway_address, batch_timeout, args), batch_timeout) + await protocol.m_search(gateway_address, batch_timeout, args) protocol.disconnect() return args except asyncio.TimeoutError: @@ -166,9 +187,11 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) -async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, loop=None, - ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[OrderedDict, - SSDPDatagram]: +async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, + loop: typing.Optional[asyncio.AbstractEventLoop] = None, + ignored: typing.Set[str] = None, + unicast: bool = False) -> typing.Tuple[typing.Dict[str, + typing.Union[int, str]], SSDPDatagram]: # we don't know which packet the gateway replies to, so send small batches at a time args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast) # check the args in the batch that got a reply one at a time to see which one worked diff --git a/aioupnp/serialization/scpd.py b/aioupnp/serialization/scpd.py index 75377cd..7694455 100644 --- a/aioupnp/serialization/scpd.py +++ b/aioupnp/serialization/scpd.py @@ -1,8 +1,9 @@ import re -from typing import Dict -from xml.etree import ElementTree -from aioupnp.constants import XML_VERSION, DEVICE, ROOT -from aioupnp.util import etree_to_dict, flatten_keys +from typing import Dict, Any, List, Tuple +from aioupnp.fault import UPnPError +from aioupnp.constants import XML_VERSION +from aioupnp.serialization.xml import xml_to_dict +from aioupnp.util import flatten_keys CONTENT_PATTERN = re.compile( @@ -28,34 +29,38 @@ def serialize_scpd_get(path: str, address: str) -> bytes: if not path.startswith("/"): path = "/" + path return ( - ( - 'GET %s HTTP/1.1\r\n' - 'Accept-Encoding: gzip\r\n' - 'Host: %s\r\n' - 'Connection: Close\r\n' - '\r\n' - ) % (path, host) + f'GET {path} HTTP/1.1\r\n' + f'Accept-Encoding: gzip\r\n' + f'Host: {host}\r\n' + f'Connection: Close\r\n' + f'\r\n' ).encode() -def deserialize_scpd_get_response(content: bytes) -> Dict: +def deserialize_scpd_get_response(content: bytes) -> Dict[str, Any]: if XML_VERSION.encode() in content: - parsed = CONTENT_PATTERN.findall(content) - content = b'' if not parsed else parsed[0][0] - xml_dict = etree_to_dict(ElementTree.fromstring(content.decode())) + parsed: List[Tuple[bytes, bytes]] = CONTENT_PATTERN.findall(content) + xml_dict = xml_to_dict((b'' if not parsed else parsed[0][0]).decode()) return parse_device_dict(xml_dict) return {} -def parse_device_dict(xml_dict: dict) -> Dict: +def parse_device_dict(xml_dict: Dict[str, Any]) -> Dict[str, Any]: keys = list(xml_dict.keys()) + found = False for k in keys: - m = XML_ROOT_SANITY_PATTERN.findall(k) + m: List[Tuple[str, str, str, str, str, str]] = XML_ROOT_SANITY_PATTERN.findall(k) if len(m) == 3 and m[1][0] and m[2][5]: - schema_key = m[1][0] - root = m[2][5] - xml_dict = flatten_keys(xml_dict, "{%s}" % schema_key)[root] + schema_key: str = m[1][0] + root: str = m[2][5] + flattened = flatten_keys(xml_dict, "{%s}" % schema_key) + if root not in flattened: + raise UPnPError("root device not found") + xml_dict = flattened[root] + found = True break + if not found: + raise UPnPError("device not found") result = {} for k, v in xml_dict.items(): if isinstance(xml_dict[k], dict): @@ -65,10 +70,9 @@ def parse_device_dict(xml_dict: dict) -> Dict: if len(parsed_k) == 2: inner_d[parsed_k[0]] = inner_v else: - assert len(parsed_k) == 3 + assert len(parsed_k) == 3, f"expected len=3, got {len(parsed_k)}" inner_d[parsed_k[1]] = inner_v result[k] = inner_d else: result[k] = v - return result diff --git a/aioupnp/serialization/soap.py b/aioupnp/serialization/soap.py index fd5b906..9ebc07a 100644 --- a/aioupnp/serialization/soap.py +++ b/aioupnp/serialization/soap.py @@ -1,64 +1,65 @@ import re -from xml.etree import ElementTree -from aioupnp.util import etree_to_dict, flatten_keys -from aioupnp.fault import handle_fault, UPnPError -from aioupnp.constants import XML_VERSION, ENVELOPE, BODY +import typing +from aioupnp.util import flatten_keys +from aioupnp.fault import UPnPError +from aioupnp.constants import XML_VERSION, ENVELOPE, BODY, FAULT, CONTROL +from aioupnp.serialization.xml import xml_to_dict CONTENT_NO_XML_VERSION_PATTERN = re.compile( "(\)".encode() ) -def serialize_soap_post(method: str, param_names: list, service_id: bytes, gateway_address: bytes, - control_url: bytes, **kwargs) -> bytes: - args = "".join("<%s>%s" % (n, kwargs.get(n), n) for n in param_names) - soap_body = ('\r\n%s\r\n' - '%s' % ( - XML_VERSION, method, service_id.decode(), - args, method)) +def serialize_soap_post(method: str, param_names: typing.List[str], service_id: bytes, gateway_address: bytes, + control_url: bytes, **kwargs: typing.Dict[str, str]) -> bytes: + args = "".join(f"<{n}>{kwargs.get(n)}" for n in param_names) + soap_body = (f'\r\n{XML_VERSION}\r\n' + f'{args}') if "http://" in gateway_address.decode(): host = gateway_address.decode().split("http://")[1] else: host = gateway_address.decode() return ( - ( - 'POST %s HTTP/1.1\r\n' - 'Host: %s\r\n' - 'User-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n' - 'Content-Length: %i\r\n' - 'Content-Type: text/xml\r\n' - 'SOAPAction: \"%s#%s\"\r\n' - 'Connection: Close\r\n' - 'Cache-Control: no-cache\r\n' - 'Pragma: no-cache\r\n' - '%s' - '\r\n' - ) % ( - control_url.decode(), # could be just / even if it shouldn't be - host, - len(soap_body), - service_id.decode(), # maybe no quotes - method, - soap_body - ) + f'POST {control_url.decode()} HTTP/1.1\r\n' # could be just / even if it shouldn't be + f'Host: {host}\r\n' + f'User-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n' + f'Content-Length: {len(soap_body)}\r\n' + f'Content-Type: text/xml\r\n' + f'SOAPAction: \"{service_id.decode()}#{method}\"\r\n' + f'Connection: Close\r\n' + f'Cache-Control: no-cache\r\n' + f'Pragma: no-cache\r\n' + f'{soap_body}' + f'\r\n' ).encode() -def deserialize_soap_post_response(response: bytes, method: str, service_id: str) -> dict: - parsed = CONTENT_NO_XML_VERSION_PATTERN.findall(response) +def deserialize_soap_post_response(response: bytes, method: str, + service_id: str) -> typing.Dict[str, typing.Dict[str, str]]: + parsed: typing.List[typing.List[bytes]] = CONTENT_NO_XML_VERSION_PATTERN.findall(response) content = b'' if not parsed else parsed[0][0] - content_dict = etree_to_dict(ElementTree.fromstring(content.decode())) + content_dict = xml_to_dict(content.decode()) envelope = content_dict[ENVELOPE] - response_body = flatten_keys(envelope[BODY], "{%s}" % service_id) - body = handle_fault(response_body) # raises UPnPError if there is a fault + if not isinstance(envelope[BODY], dict): + # raise UPnPError('blank response') + return {} # TODO: raise + response_body: typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]] = flatten_keys( + envelope[BODY], f"{'{' + service_id + '}'}" + ) + if not response_body: + # raise UPnPError('blank response') + return {} # TODO: raise + if FAULT in response_body: + fault: typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]] = flatten_keys( + response_body[FAULT], "{%s}" % CONTROL + ) + raise UPnPError(fault['detail']['UPnPError']['errorDescription']) response_key = None - if not body: - return {} - for key in body: + for key in response_body: if method in key: response_key = key break if not response_key: - raise UPnPError("unknown response fields for %s: %s" % (method, body)) - return body[response_key] + raise UPnPError(f"unknown response fields for {method}: {response_body}") + return response_body[response_key] diff --git a/aioupnp/serialization/ssdp.py b/aioupnp/serialization/ssdp.py index 824bbfe..99a3808 100644 --- a/aioupnp/serialization/ssdp.py +++ b/aioupnp/serialization/ssdp.py @@ -3,43 +3,67 @@ import logging import binascii import json from collections import OrderedDict -from typing import List +from typing import List, Optional, Dict, Union, Tuple, Callable from aioupnp.fault import UPnPError from aioupnp.constants import line_separator log = logging.getLogger(__name__) _template = "(?i)^(%s):[ ]*(.*)$" - - -ssdp_datagram_patterns = { - 'host': (re.compile("(?i)^(host):(.*)$"), str), - 'st': (re.compile(_template % 'st'), str), - 'man': (re.compile(_template % 'man'), str), - 'mx': (re.compile(_template % 'mx'), int), - 'nt': (re.compile(_template % 'nt'), str), - 'nts': (re.compile(_template % 'nts'), str), - 'usn': (re.compile(_template % 'usn'), str), - 'location': (re.compile(_template % 'location'), str), - 'cache_control': (re.compile(_template % 'cache[-|_]control'), str), - 'server': (re.compile(_template % 'server'), str), -} - vendor_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$") -class SSDPDatagram(object): +def match_vendor(line: str) -> Optional[Tuple[str, str]]: + match: List[Tuple[str, str, str]] = vendor_pattern.findall(line) + if match: + vendor_key: str = match[-1][0].lstrip(" ").rstrip(" ") + vendor_value: str = match[-1][2].lstrip(" ").rstrip(" ") + return vendor_key, vendor_value + return None + + +def compile_find(pattern: str) -> Callable[[str], Optional[str]]: + p = re.compile(pattern) + + def find(line: str) -> Optional[str]: + result: List[List[str]] = [] + for outer in p.findall(line): + result.append([]) + for inner in outer: + result[-1].append(inner) + if result: + return result[-1][-1].lstrip(" ").rstrip(" ") + return None + + return find + + +ssdp_datagram_patterns: Dict[str, Callable[[str], Optional[str]]] = { + 'host': compile_find("(?i)^(host):(.*)$"), + 'st': compile_find(_template % 'st'), + 'man': compile_find(_template % 'man'), + 'mx': compile_find(_template % 'mx'), + 'nt': compile_find(_template % 'nt'), + 'nts': compile_find(_template % 'nts'), + 'usn': compile_find(_template % 'usn'), + 'location': compile_find(_template % 'location'), + 'cache_control': compile_find(_template % 'cache[-|_]control'), + 'server': compile_find(_template % 'server'), +} + + +class SSDPDatagram: _M_SEARCH = "M-SEARCH" _NOTIFY = "NOTIFY" _OK = "OK" - _start_lines = { + _start_lines: Dict[str, str] = { _M_SEARCH: "M-SEARCH * HTTP/1.1", _NOTIFY: "NOTIFY * HTTP/1.1", _OK: "HTTP/1.1 200 OK" } - _friendly_names = { + _friendly_names: Dict[str, str] = { _M_SEARCH: "m-search", _NOTIFY: "notify", _OK: "m-search response" @@ -47,9 +71,7 @@ class SSDPDatagram(object): _vendor_field_pattern = vendor_pattern - _patterns = ssdp_datagram_patterns - - _required_fields = { + _required_fields: Dict[str, List[str]] = { _M_SEARCH: [ 'host', 'man', @@ -75,137 +97,137 @@ class SSDPDatagram(object): ] } - def __init__(self, packet_type, kwargs: OrderedDict = None) -> None: + def __init__(self, packet_type: str, kwargs: Optional[Dict[str, Union[str, int]]] = None) -> None: if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]: raise UPnPError("unknown packet type: {}".format(packet_type)) self._packet_type = packet_type - kwargs = kwargs or OrderedDict() - self._field_order: list = [ - k.lower().replace("-", "_") for k in kwargs.keys() + kw: Dict[str, Union[str, int]] = kwargs or OrderedDict() + self._field_order: List[str] = [ + k.lower().replace("-", "_") for k in kw.keys() ] - self.host = None - self.man = None - self.mx = None - self.st = None - self.nt = None - self.nts = None - self.usn = None - self.location = None - self.cache_control = None - self.server = None - self.date = None - self.ext = None - for k, v in kwargs.items(): + self.host: Optional[str] = None + self.man: Optional[str] = None + self.mx: Optional[Union[str, int]] = None + self.st: Optional[str] = None + self.nt: Optional[str] = None + self.nts: Optional[str] = None + self.usn: Optional[str] = None + self.location: Optional[str] = None + self.cache_control: Optional[str] = None + self.server: Optional[str] = None + self.date: Optional[str] = None + self.ext: Optional[str] = None + for k, v in kw.items(): normalized = k.lower().replace("-", "_") - if not normalized.startswith("_") and hasattr(self, normalized) and getattr(self, normalized) is None: - setattr(self, normalized, v) - self._case_mappings: dict = {k.lower(): k for k in kwargs.keys()} + if not normalized.startswith("_") and hasattr(self, normalized): + if getattr(self, normalized, None) is None: + setattr(self, normalized, v) + self._case_mappings: Dict[str, str] = {k.lower(): k for k in kw.keys()} for k in self._required_fields[self._packet_type]: - if getattr(self, k) is None: + if getattr(self, k, None) is None: raise UPnPError("missing required field %s" % k) def get_cli_igd_kwargs(self) -> str: fields = [] for field in self._field_order: - v = getattr(self, field) + v = getattr(self, field, None) + if v is None: + raise UPnPError("missing required field %s" % field) fields.append("--%s=%s" % (self._case_mappings.get(field, field), v)) return " ".join(fields) def __repr__(self) -> str: return self.as_json() - def __getitem__(self, item): + def __getitem__(self, item: str) -> Union[str, int]: for i in self._required_fields[self._packet_type]: if i.lower() == item.lower(): return getattr(self, i) raise KeyError(item) - def get_friendly_name(self) -> str: - return self._friendly_names[self._packet_type] - def encode(self, trailing_newlines: int = 2) -> str: lines = [self._start_lines[self._packet_type]] - for attr_name in self._field_order: - if attr_name not in self._required_fields[self._packet_type]: - continue - attr = getattr(self, attr_name) - if attr is None: - raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name)) - if attr_name == 'mx': - value = str(attr) - else: - value = attr - lines.append("{}: {}".format(self._case_mappings.get(attr_name.lower(), attr_name.upper()), value)) + lines.extend( + f"{self._case_mappings.get(attr_name.lower(), attr_name.upper())}: {str(getattr(self, attr_name))}" + for attr_name in self._field_order if attr_name in self._required_fields[self._packet_type] + ) serialized = line_separator.join(lines) for _ in range(trailing_newlines): serialized += line_separator return serialized - def as_dict(self) -> OrderedDict: + def as_dict(self) -> Dict[str, Union[str, int]]: return self._lines_to_content_dict(self.encode().split(line_separator)) def as_json(self) -> str: return json.dumps(self.as_dict(), indent=2) @classmethod - def decode(cls, datagram: bytes): + def decode(cls, datagram: bytes) -> 'SSDPDatagram': packet = cls._from_string(datagram.decode()) if packet is None: raise UPnPError( "failed to decode datagram: {}".format(binascii.hexlify(datagram)) ) for attr_name in packet._required_fields[packet._packet_type]: - attr = getattr(packet, attr_name) - if attr is None: + if getattr(packet, attr_name, None) is None: raise UPnPError( "required field for {} is missing from m-search response: {}".format(packet._packet_type, attr_name) ) return packet @classmethod - def _lines_to_content_dict(cls, lines: list) -> OrderedDict: - result: OrderedDict = OrderedDict() + def _lines_to_content_dict(cls, lines: List[str]) -> Dict[str, Union[str, int]]: + result: Dict[str, Union[str, int]] = OrderedDict() + matched_keys: List[str] = [] for line in lines: if not line: continue matched = False - for name, (pattern, field_type) in cls._patterns.items(): - if name not in result and pattern.findall(line): - match = pattern.findall(line)[-1][-1] - result[line[:len(name)]] = field_type(match.lstrip(" ").rstrip(" ")) - matched = True - break - + for name, pattern in ssdp_datagram_patterns.items(): + if name not in matched_keys: + if name.lower() == 'mx': + _matched_int = pattern(line) + if _matched_int is not None: + match_int = int(_matched_int) + result[line[:len(name)]] = match_int + matched = True + matched_keys.append(name) + break + else: + match = pattern(line) + if match is not None: + result[line[:len(name)]] = match + matched = True + matched_keys.append(name) + break if not matched: - if cls._vendor_field_pattern.findall(line): - match = cls._vendor_field_pattern.findall(line)[-1] - vendor_key = match[0].lstrip(" ").rstrip(" ") - # vendor_domain = match[1].lstrip(" ").rstrip(" ") - value = match[2].lstrip(" ").rstrip(" ") - if vendor_key not in result: - result[vendor_key] = value + matched_vendor = match_vendor(line) + if matched_vendor and matched_vendor[0] not in result: + result[matched_vendor[0]] = matched_vendor[1] return result @classmethod - def _from_string(cls, datagram: str): + def _from_string(cls, datagram: str) -> Optional['SSDPDatagram']: lines = [l for l in datagram.split(line_separator) if l] if not lines: - return + return None if lines[0] == cls._start_lines[cls._M_SEARCH]: return cls._from_request(lines[1:]) if lines[0] in [cls._start_lines[cls._NOTIFY], cls._start_lines[cls._NOTIFY] + " "]: return cls._from_notify(lines[1:]) if lines[0] == cls._start_lines[cls._OK]: return cls._from_response(lines[1:]) + return None @classmethod - def _from_response(cls, lines: List): + def _from_response(cls, lines: List) -> 'SSDPDatagram': return cls(cls._OK, cls._lines_to_content_dict(lines)) @classmethod - def _from_notify(cls, lines: List): + def _from_notify(cls, lines: List) -> 'SSDPDatagram': return cls(cls._NOTIFY, cls._lines_to_content_dict(lines)) @classmethod - def _from_request(cls, lines: List): + def _from_request(cls, lines: List) -> 'SSDPDatagram': return cls(cls._M_SEARCH, cls._lines_to_content_dict(lines)) diff --git a/aioupnp/serialization/xml.py b/aioupnp/serialization/xml.py new file mode 100644 index 0000000..98e2d01 --- /dev/null +++ b/aioupnp/serialization/xml.py @@ -0,0 +1,81 @@ +import typing +from collections import OrderedDict +from defusedxml import ElementTree + +str_any_dict = typing.Dict[str, typing.Any] + + +def parse_xml(xml_str: str) -> ElementTree: + element: ElementTree = ElementTree.fromstring(xml_str) + return element + + +def _element_text(element_tree: ElementTree) -> typing.Optional[str]: + # if element_tree.attrib: + # element: typing.Dict[str, str] = OrderedDict() + # for k, v in element_tree.attrib.items(): + # element['@' + k] = v + # if not element_tree.text: + # return element + # element['#text'] = element_tree.text.strip() + # return element + if element_tree.text: + return element_tree.text.strip() + return None + + +def _get_element_children(element_tree: ElementTree) -> typing.Dict[str, typing.Union[ + str, typing.List[typing.Union[typing.Dict[str, str], str]]]]: + element_children = _get_child_dicts(element_tree) + element: typing.Dict[str, typing.Any] = OrderedDict() + keys: typing.List[str] = list(element_children.keys()) + for k in keys: + v: typing.Union[str, typing.List[typing.Any], typing.Dict[str, typing.Any]] = element_children[k] + if len(v) == 1 and isinstance(v, list): + l: typing.List[typing.Union[typing.Dict[str, str], str]] = v + element[k] = l[0] + else: + element[k] = v + return element + + +def _get_child_dicts(element: ElementTree) -> typing.Dict[str, typing.List[typing.Union[typing.Dict[str, str], str]]]: + children_dicts: typing.Dict[str, typing.List[typing.Union[typing.Dict[str, str], str]]] = OrderedDict() + children: typing.List[ElementTree] = list(element) + for child in children: + child_dict = _recursive_element_to_dict(child) + child_keys: typing.List[str] = list(child_dict.keys()) + for k in child_keys: + assert k in child_dict + v: typing.Union[typing.Dict[str, str], str] = child_dict[k] + if k not in children_dicts.keys(): + new_item = [v] + children_dicts[k] = new_item + else: + sublist = children_dicts[k] + assert isinstance(sublist, list) + sublist.append(v) + return children_dicts + + +def _recursive_element_to_dict(element_tree: ElementTree) -> typing.Dict[str, typing.Any]: + if len(element_tree): + element_result: typing.Dict[str, typing.Dict[str, typing.Union[ + str, typing.List[typing.Union[str, typing.Dict[str, typing.Any]]]]]] = OrderedDict() + children_element = _get_element_children(element_tree) + if element_tree.tag is not None: + element_result[element_tree.tag] = children_element + return element_result + else: + element_text = _element_text(element_tree) + if element_text is not None: + base_element_result: typing.Dict[str, typing.Any] = OrderedDict() + if element_tree.tag is not None: + base_element_result[element_tree.tag] = element_text + return base_element_result + null_result: typing.Dict[str, str] = OrderedDict() + return null_result + + +def xml_to_dict(xml_str: str) -> typing.Dict[str, typing.Any]: + return _recursive_element_to_dict(parse_xml(xml_str)) diff --git a/aioupnp/upnp.py b/aioupnp/upnp.py index 9eab4e9..e16bb0f 100644 --- a/aioupnp/upnp.py +++ b/aioupnp/upnp.py @@ -5,12 +5,11 @@ import asyncio import zlib import base64 from collections import OrderedDict -from typing import Tuple, Dict, List, Union +from typing import Tuple, Dict, List, Union, Optional, Callable from aioupnp.fault import UPnPError from aioupnp.gateway import Gateway -from aioupnp.util import get_gateway_and_lan_addresses +from aioupnp.interfaces import get_gateway_and_lan_addresses from aioupnp.protocols.ssdp import m_search, fuzzy_m_search -from aioupnp.commands import SOAPCommand from aioupnp.serialization.ssdp import SSDPDatagram log = logging.getLogger(__name__) @@ -35,6 +34,10 @@ class UPnP: self.gateway_address = gateway_address self.gateway = gateway + @classmethod + def get_annotations(cls, command: str) -> Dict[str, type]: + return getattr(Gateway.commands, command).__annotations__ + @classmethod def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '', interface_name: str = 'default') -> Tuple[str, str]: @@ -59,8 +62,9 @@ class UPnP: @classmethod @cli async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, - igd_args: OrderedDict = None, unicast: bool = True, interface_name: str = 'default', - loop=None) -> Dict: + igd_args: Optional[Dict[str, Union[int, str]]] = None, + unicast: bool = True, interface_name: str = 'default', + loop=None) -> Dict[str, Union[str, Dict[str, Union[int, str]]]]: if not lan_address or not gateway_address: try: lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) @@ -97,13 +101,13 @@ class UPnP: async def get_port_mapping_by_index(self, index: int) -> Dict: result = await self._get_port_mapping_by_index(index) if result: - if isinstance(self.gateway.commands.GetGenericPortMappingEntry, SOAPCommand): + if self.gateway.commands.is_registered('GetGenericPortMappingEntry'): return { k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result) } return {} - async def _get_port_mapping_by_index(self, index: int) -> Union[None, Tuple[Union[None, str], int, str, + async def _get_port_mapping_by_index(self, index: int) -> Union[None, Tuple[Optional[str], int, str, int, str, bool, str, int]]: try: redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index) @@ -134,7 +138,7 @@ class UPnP: result = await self.gateway.commands.GetSpecificPortMappingEntry( NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol ) - if result and isinstance(self.gateway.commands.GetSpecificPortMappingEntry, SOAPCommand): + if result and self.gateway.commands.is_registered('GetSpecificPortMappingEntry'): return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)} except UPnPError: pass @@ -152,7 +156,8 @@ class UPnP: ) @cli - async def get_next_mapping(self, port: int, protocol: str, description: str, internal_port: int=None) -> int: + async def get_next_mapping(self, port: int, protocol: str, description: str, + internal_port: Optional[int] = None) -> int: if protocol not in ["UDP", "TCP"]: raise UPnPError("unsupported protocol: {}".format(protocol)) internal_port = int(internal_port or port) @@ -340,8 +345,9 @@ class UPnP: return await self.gateway.commands.GetActiveConnections() @classmethod - def run_cli(cls, method, igd_args: OrderedDict, lan_address: str = '', gateway_address: str = '', timeout: int = 30, - interface_name: str = 'default', unicast: bool = True, kwargs: dict = None, loop=None) -> None: + def run_cli(cls, method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '', + gateway_address: str = '', timeout: int = 30, interface_name: str = 'default', + unicast: bool = True, kwargs: Optional[Dict] = None, loop=None) -> None: """ :param method: the command name :param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided @@ -356,7 +362,7 @@ class UPnP: igd_args = igd_args timeout = int(timeout) loop = loop or asyncio.get_event_loop_policy().get_event_loop() - fut: asyncio.Future = asyncio.Future() + fut: asyncio.Future = loop.create_future() async def wrapper(): # wrap the upnp setup and call of the command in a coroutine diff --git a/aioupnp/util.py b/aioupnp/util.py index 71be110..3b19504 100644 --- a/aioupnp/util.py +++ b/aioupnp/util.py @@ -1,91 +1,48 @@ -import re -import socket -from collections import defaultdict -from typing import Tuple, Dict -from xml.etree import ElementTree -import netifaces +import typing +from collections import OrderedDict -BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode()) -BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode()) +str_any_dict = typing.Dict[str, typing.Any] -def etree_to_dict(t: ElementTree.Element) -> Dict: - d: dict = {} - if t.attrib: - d[t.tag] = {} - children = list(t) - if children: - dd: dict = defaultdict(list) - for dc in map(etree_to_dict, children): - for k, v in dc.items(): - dd[k].append(v) - d[t.tag] = {k: v[0] if len(v) == 1 else v for k, v in dd.items()} - if t.attrib: - d[t.tag].update(('@' + k, v) for k, v in t.attrib.items()) - if t.text: - text = t.text.strip() - if children or t.attrib: - if text: - d[t.tag]['#text'] = text - else: - d[t.tag] = text - return d - - -def flatten_keys(d, strip): - if not isinstance(d, (list, dict)): - return d - if isinstance(d, list): - return [flatten_keys(i, strip) for i in d] - t = {} - for k, v in d.items(): +def _recursive_flatten(to_flatten: typing.Any, strip: str) -> typing.Any: + if not isinstance(to_flatten, (list, dict)): + return to_flatten + if isinstance(to_flatten, list): + assert isinstance(to_flatten, list) + return [_recursive_flatten(i, strip) for i in to_flatten] + assert isinstance(to_flatten, dict) + keys: typing.List[str] = list(to_flatten.keys()) + copy: str_any_dict = OrderedDict() + for k in keys: + item: typing.Any = to_flatten[k] if strip in k and strip != k: - t[k.split(strip)[1]] = flatten_keys(v, strip) + copy[k.split(strip)[1]] = _recursive_flatten(item, strip) else: - t[k] = flatten_keys(v, strip) - return t + copy[k] = _recursive_flatten(item, strip) + return copy -def get_dict_val_case_insensitive(d, k): - match = list(filter(lambda x: x.lower() == k.lower(), d.keys())) - if not match: - return +def flatten_keys(to_flatten: str_any_dict, strip: str) -> str_any_dict: + keys: typing.List[str] = list(to_flatten.keys()) + copy: str_any_dict = OrderedDict() + for k in keys: + item = to_flatten[k] + if strip in k and strip != k: + new_key: str = k.split(strip)[1] + copy[new_key] = _recursive_flatten(item, strip) + else: + copy[k] = _recursive_flatten(item, strip) + return copy + + +def get_dict_val_case_insensitive(source: typing.Dict[typing.AnyStr, typing.AnyStr], key: typing.AnyStr) -> typing.Optional[typing.AnyStr]: + match: typing.List[typing.AnyStr] = list(filter(lambda x: x.lower() == key.lower(), source.keys())) + if not len(match): + return None if len(match) > 1: raise KeyError("overlapping keys") - return d[match[0]] - -# import struct -# import fcntl -# def get_ip_address(ifname): -# SIOCGIFADDR = 0x8915 -# s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) -# return socket.inet_ntoa(fcntl.ioctl( -# s.fileno(), -# SIOCGIFADDR, -# struct.pack(b'256s', ifname[:15].encode()) -# )[20:24]) - - -def get_interfaces(): - r = { - interface_name: (router_address, netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['addr']) - for router_address, interface_name, _ in netifaces.gateways()[socket.AF_INET] - } - for interface_name in netifaces.interfaces(): - if interface_name in ['lo', 'localhost'] or interface_name in r: - continue - addresses = netifaces.ifaddresses(interface_name) - if netifaces.AF_INET in addresses: - address = addresses[netifaces.AF_INET][0]['addr'] - gateway_guess = ".".join(address.split(".")[:-1] + ["1"]) - r[interface_name] = (gateway_guess, address) - r['default'] = r[netifaces.gateways()['default'][netifaces.AF_INET][1]] - return r - - -def get_gateway_and_lan_addresses(interface_name: str) -> Tuple[str, str]: - for iface_name, (gateway, lan) in get_interfaces().items(): - if interface_name == iface_name: - return gateway, lan - return '', '' + if len(match) == 1: + matched_key: typing.AnyStr = match[0] + return source[matched_key] + raise KeyError("overlapping keys") diff --git a/setup.py b/setup.py index caa9ce6..39bfa9b 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ setup( packages=find_packages(exclude=('tests',)), entry_points={'console_scripts': console_scripts}, install_requires=[ - 'netifaces', + 'netifaces', 'defusedxml' ], extras_require={ 'test': (