mypy refactor
This commit is contained in:
parent
a404269d91
commit
4137f7cd8a
20 changed files with 1357 additions and 643 deletions
4
.coveragerc
Normal file
4
.coveragerc
Normal file
|
@ -0,0 +1,4 @@
|
|||
[run]
|
||||
omit =
|
||||
tests/*
|
||||
stubs/*
|
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -3,6 +3,9 @@
|
|||
_trial_temp/
|
||||
build/
|
||||
dist/
|
||||
html/
|
||||
index.html
|
||||
mypy-html.css
|
||||
.coverage
|
||||
.mypy_cache/
|
||||
aioupnp.spec
|
||||
|
|
440
.pylintrc
Normal file
440
.pylintrc
Normal file
|
@ -0,0 +1,440 @@
|
|||
[MASTER]
|
||||
|
||||
# Specify a configuration file.
|
||||
#rcfile=
|
||||
|
||||
# Python code to execute, usually for sys.path manipulation such as
|
||||
# pygtk.require().
|
||||
#init-hook=
|
||||
|
||||
# Add files or directories to the blacklist. They should be base names, not
|
||||
# paths.
|
||||
ignore=CVS,schema
|
||||
|
||||
# Add files or directories matching the regex patterns to the
|
||||
# blacklist. The regex matches against base names, not paths.
|
||||
# `\.#.*` - add emacs tmp files to the blacklist
|
||||
ignore-patterns=\.#.*
|
||||
|
||||
# Pickle collected data for later comparisons.
|
||||
persistent=yes
|
||||
|
||||
# List of plugins (as comma separated values of python modules names) to load,
|
||||
# usually to register additional checkers.
|
||||
load-plugins=
|
||||
|
||||
# Use multiple processes to speed up Pylint.
|
||||
jobs=1
|
||||
|
||||
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||
# active Python interpreter and may run arbitrary code.
|
||||
unsafe-load-any-extension=no
|
||||
|
||||
# A comma-separated list of package or module names from where C extensions may
|
||||
# be loaded. Extensions are loading into the active Python interpreter and may
|
||||
# run arbitrary code
|
||||
# extension-pkg-whitelist=
|
||||
|
||||
# Allow optimization of some AST trees. This will activate a peephole AST
|
||||
# optimizer, which will apply various small optimizations. For instance, it can
|
||||
# be used to obtain the result of joining multiple strings with the addition
|
||||
# operator. Joining a lot of strings can lead to a maximum recursion error in
|
||||
# Pylint and this flag can prevent that. It has one side effect, the resulting
|
||||
# AST will be different than the one from reality.
|
||||
optimize-ast=no
|
||||
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
|
||||
# Only show warnings with the listed confidence levels. Leave empty to show
|
||||
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
|
||||
confidence=
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
# either give multiple identifier separated by comma (,) or put this option
|
||||
# multiple time (only on the command line, not in the configuration file where
|
||||
# it should appear only once). See also the "--disable" option for examples.
|
||||
#enable=
|
||||
|
||||
# Disable the message, report, category or checker with the given id(s). You
|
||||
# can either give multiple identifiers separated by comma (,) or put this
|
||||
# option multiple times (only on the command line, not in the configuration
|
||||
# file where it should appear only once).You can also use "--disable=all" to
|
||||
# disable everything first and then re-enable specific checks. For example, if
|
||||
# you want to run only the similarities checker, you can use "--disable=all
|
||||
# --enable=similarities". If you want to run only the classes checker, but have
|
||||
# no Warning level messages displayed, use"--disable=all --enable=classes
|
||||
# --disable=W"
|
||||
disable=
|
||||
anomalous-backslash-in-string,
|
||||
arguments-differ,
|
||||
attribute-defined-outside-init,
|
||||
bad-continuation,
|
||||
bare-except,
|
||||
broad-except,
|
||||
cell-var-from-loop,
|
||||
consider-iterating-dictionary,
|
||||
dangerous-default-value,
|
||||
duplicate-code,
|
||||
fixme,
|
||||
global-statement,
|
||||
inherit-non-class,
|
||||
invalid-name,
|
||||
len-as-condition,
|
||||
locally-disabled,
|
||||
logging-not-lazy,
|
||||
missing-docstring,
|
||||
no-else-return,
|
||||
no-init,
|
||||
no-member,
|
||||
no-self-use,
|
||||
protected-access,
|
||||
redefined-builtin,
|
||||
redefined-outer-name,
|
||||
redefined-variable-type,
|
||||
relative-import,
|
||||
signature-differs,
|
||||
super-init-not-called,
|
||||
too-few-public-methods,
|
||||
too-many-arguments,
|
||||
too-many-branches,
|
||||
too-many-instance-attributes,
|
||||
too-many-lines,
|
||||
too-many-locals,
|
||||
too-many-nested-blocks,
|
||||
too-many-public-methods,
|
||||
too-many-return-statements,
|
||||
too-many-statements,
|
||||
trailing-newlines,
|
||||
undefined-loop-variable,
|
||||
ungrouped-imports,
|
||||
unnecessary-lambda,
|
||||
unused-argument,
|
||||
unused-variable,
|
||||
wildcard-import,
|
||||
wrong-import-order,
|
||||
wrong-import-position,
|
||||
deprecated-lambda,
|
||||
simplifiable-if-statement,
|
||||
unidiomatic-typecheck,
|
||||
global-at-module-level,
|
||||
inconsistent-return-statements,
|
||||
keyword-arg-before-vararg,
|
||||
assignment-from-no-return,
|
||||
useless-return,
|
||||
assignment-from-none,
|
||||
stop-iteration-return
|
||||
|
||||
|
||||
[REPORTS]
|
||||
|
||||
# Set the output format. Available formats are text, parseable, colorized, msvs
|
||||
# (visual studio) and html. You can also give a reporter class, eg
|
||||
# mypackage.mymodule.MyReporterClass.
|
||||
output-format=text
|
||||
|
||||
# Put messages in a separate file for each module / package specified on the
|
||||
# command line instead of printing them on stdout. Reports (if any) will be
|
||||
# written in a file name "pylint_global.[txt|html]".
|
||||
files-output=no
|
||||
|
||||
# Tells whether to display a full report or only the messages
|
||||
reports=no
|
||||
|
||||
# Python expression which should return a note less than 10 (10 is the highest
|
||||
# note). You have access to the variables errors warning, statement which
|
||||
# respectively contain the number of errors / warnings messages and the total
|
||||
# number of statements analyzed. This is used by the global evaluation report
|
||||
# (RP0004).
|
||||
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
|
||||
|
||||
# Template used to display messages. This is a python new-style format string
|
||||
# used to format the message information. See doc for all details
|
||||
#msg-template=
|
||||
|
||||
|
||||
[VARIABLES]
|
||||
|
||||
# Tells whether we should check for unused import in __init__ files.
|
||||
init-import=no
|
||||
|
||||
# A regular expression matching the name of dummy variables (i.e. expectedly
|
||||
# not used).
|
||||
dummy-variables-rgx=_$|dummy
|
||||
|
||||
# List of additional names supposed to be defined in builtins. Remember that
|
||||
# you should avoid to define new builtins when possible.
|
||||
additional-builtins=
|
||||
|
||||
# List of strings which can identify a callback function by name. A callback
|
||||
# name must start or end with one of those strings.
|
||||
callbacks=cb_,_cb
|
||||
|
||||
|
||||
[LOGGING]
|
||||
|
||||
# Logging modules to check that the string format arguments are in logging
|
||||
# function parameter format
|
||||
logging-modules=logging
|
||||
|
||||
|
||||
[BASIC]
|
||||
|
||||
# List of builtins function names that should not be used, separated by a comma
|
||||
bad-functions=map,filter,input
|
||||
|
||||
# Good variable names which should always be accepted, separated by a comma
|
||||
# allow `d` as its used frequently for deferred callback chains
|
||||
good-names=i,j,k,ex,Run,_,d
|
||||
|
||||
# Bad variable names which should always be refused, separated by a comma
|
||||
bad-names=foo,bar,baz,toto,tutu,tata
|
||||
|
||||
# Colon-delimited sets of names that determine each other's naming style when
|
||||
# the name regexes allow several styles.
|
||||
name-group=
|
||||
|
||||
# Include a hint for the correct naming format with invalid-name
|
||||
include-naming-hint=no
|
||||
|
||||
# Regular expression matching correct function names
|
||||
function-rgx=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Naming hint for function names
|
||||
function-name-hint=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Regular expression matching correct variable names
|
||||
variable-rgx=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Naming hint for variable names
|
||||
variable-name-hint=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Regular expression matching correct constant names
|
||||
const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
|
||||
|
||||
# Naming hint for constant names
|
||||
const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
|
||||
|
||||
# Regular expression matching correct attribute names
|
||||
attr-rgx=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Naming hint for attribute names
|
||||
attr-name-hint=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Regular expression matching correct argument names
|
||||
argument-rgx=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Naming hint for argument names
|
||||
argument-name-hint=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Regular expression matching correct class attribute names
|
||||
class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
|
||||
|
||||
# Naming hint for class attribute names
|
||||
class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
|
||||
|
||||
# Regular expression matching correct inline iteration names
|
||||
inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
|
||||
|
||||
# Naming hint for inline iteration names
|
||||
inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct class names
|
||||
class-rgx=[A-Z_][a-zA-Z0-9]+$
|
||||
|
||||
# Naming hint for class names
|
||||
class-name-hint=[A-Z_][a-zA-Z0-9]+$
|
||||
|
||||
# Regular expression matching correct module names
|
||||
module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
|
||||
|
||||
# Naming hint for module names
|
||||
module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
|
||||
|
||||
# Regular expression matching correct method names
|
||||
method-rgx=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Naming hint for method names
|
||||
method-name-hint=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Regular expression which should only match function or class names that do
|
||||
# not require a docstring.
|
||||
no-docstring-rgx=^_
|
||||
|
||||
# Minimum line length for functions/classes that require docstrings, shorter
|
||||
# ones are exempt.
|
||||
docstring-min-length=-1
|
||||
|
||||
|
||||
[ELIF]
|
||||
|
||||
# Maximum number of nested blocks for function / method body
|
||||
max-nested-blocks=5
|
||||
|
||||
|
||||
[SPELLING]
|
||||
|
||||
# Spelling dictionary name. Available dictionaries: none. To make it working
|
||||
# install python-enchant package.
|
||||
spelling-dict=
|
||||
|
||||
# List of comma separated words that should not be checked.
|
||||
spelling-ignore-words=
|
||||
|
||||
# A path to a file that contains private dictionary; one word per line.
|
||||
spelling-private-dict-file=
|
||||
|
||||
# Tells whether to store unknown words to indicated private dictionary in
|
||||
# --spelling-private-dict-file option instead of raising a message.
|
||||
spelling-store-unknown-words=no
|
||||
|
||||
|
||||
[FORMAT]
|
||||
|
||||
# Maximum number of characters on a single line.
|
||||
max-line-length=120
|
||||
|
||||
# Regexp for a line that is allowed to be longer than the limit.
|
||||
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
|
||||
|
||||
# Allow the body of an if to be on the same line as the test if there is no
|
||||
# else.
|
||||
single-line-if-stmt=no
|
||||
|
||||
# List of optional constructs for which whitespace checking is disabled. `dict-
|
||||
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
|
||||
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
|
||||
# `empty-line` allows space-only lines.
|
||||
no-space-check=trailing-comma,dict-separator
|
||||
|
||||
# Maximum number of lines in a module
|
||||
max-module-lines=1000
|
||||
|
||||
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
|
||||
# tab).
|
||||
indent-string=' '
|
||||
|
||||
# Number of spaces of indent required inside a hanging or continued line.
|
||||
indent-after-paren=4
|
||||
|
||||
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
||||
expected-line-ending-format=
|
||||
|
||||
|
||||
[MISCELLANEOUS]
|
||||
|
||||
# List of note tags to take in consideration, separated by a comma.
|
||||
notes=FIXME,XXX,TODO
|
||||
|
||||
|
||||
[SIMILARITIES]
|
||||
|
||||
# Minimum lines number of a similarity.
|
||||
min-similarity-lines=4
|
||||
|
||||
# Ignore comments when computing similarities.
|
||||
ignore-comments=yes
|
||||
|
||||
# Ignore docstrings when computing similarities.
|
||||
ignore-docstrings=yes
|
||||
|
||||
# Ignore imports when computing similarities.
|
||||
ignore-imports=no
|
||||
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
# Tells whether missing members accessed in mixin class should be ignored. A
|
||||
# mixin class is detected if its name ends with "mixin" (case insensitive).
|
||||
ignore-mixin-members=yes
|
||||
|
||||
# List of module names for which member attributes should not be checked
|
||||
# (useful for modules/projects where namespaces are manipulated during runtime
|
||||
# and thus existing member attributes cannot be deduced by static analysis. It
|
||||
# supports qualified module names, as well as Unix pattern matching.
|
||||
ignored-modules=leveldb,distutils
|
||||
# Ignoring distutils because: https://github.com/PyCQA/pylint/issues/73
|
||||
|
||||
# List of classes names for which member attributes should not be checked
|
||||
# (useful for classes with attributes dynamically set). This supports can work
|
||||
# with qualified names.
|
||||
# ignored-classes=
|
||||
|
||||
|
||||
|
||||
[IMPORTS]
|
||||
|
||||
# Deprecated modules which should not be used, separated by a comma
|
||||
deprecated-modules=regsub,TERMIOS,Bastion,rexec,miniupnpc
|
||||
|
||||
# Create a graph of every (i.e. internal and external) dependencies in the
|
||||
# given file (report RP0402 must not be disabled)
|
||||
import-graph=
|
||||
|
||||
# Create a graph of external dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
ext-import-graph=
|
||||
|
||||
# Create a graph of internal dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
int-import-graph=
|
||||
|
||||
|
||||
[DESIGN]
|
||||
|
||||
# Maximum number of arguments for function / method
|
||||
max-args=10
|
||||
|
||||
# Argument names that match this expression will be ignored. Default to name
|
||||
# with leading underscore
|
||||
ignored-argument-names=_.*
|
||||
|
||||
# Maximum number of locals for function / method body
|
||||
max-locals=15
|
||||
|
||||
# Maximum number of return / yield for function / method body
|
||||
max-returns=6
|
||||
|
||||
# Maximum number of branch for function / method body
|
||||
max-branches=12
|
||||
|
||||
# Maximum number of statements in function / method body
|
||||
max-statements=50
|
||||
|
||||
# Maximum number of parents for a class (see R0901).
|
||||
max-parents=8
|
||||
|
||||
# Maximum number of attributes for a class (see R0902).
|
||||
max-attributes=7
|
||||
|
||||
# Minimum number of public methods for a class (see R0903).
|
||||
min-public-methods=2
|
||||
|
||||
# Maximum number of public methods for a class (see R0904).
|
||||
max-public-methods=20
|
||||
|
||||
# Maximum number of boolean expressions in a if statement
|
||||
max-bool-expr=5
|
||||
|
||||
|
||||
[CLASSES]
|
||||
|
||||
# List of method names used to declare (i.e. assign) instance attributes.
|
||||
defining-attr-methods=__init__,__new__,setUp
|
||||
|
||||
# List of valid names for the first argument in a class method.
|
||||
valid-classmethod-first-arg=cls
|
||||
|
||||
# List of valid names for the first argument in a metaclass class method.
|
||||
valid-metaclass-classmethod-first-arg=mcs
|
||||
|
||||
# List of member names, which should be excluded from the protected access
|
||||
# warning.
|
||||
exclude-protected=_asdict,_fields,_replace,_source,_make
|
||||
|
||||
|
||||
[EXCEPTIONS]
|
||||
|
||||
# Exceptions that will emit a warning when being caught. Defaults to
|
||||
# "Exception"
|
||||
overgeneral-exceptions=Exception
|
|
@ -1,8 +1,11 @@
|
|||
import logging
|
||||
import sys
|
||||
import asyncio
|
||||
import logging
|
||||
import textwrap
|
||||
import typing
|
||||
from collections import OrderedDict
|
||||
from aioupnp.upnp import UPnP
|
||||
from aioupnp.commands import SOAPCommands
|
||||
|
||||
log = logging.getLogger("aioupnp")
|
||||
handler = logging.StreamHandler()
|
||||
|
@ -16,17 +19,18 @@ base_usage = "\n".join(textwrap.wrap(
|
|||
100, subsequent_indent=' ', break_long_words=False)) + "\n"
|
||||
|
||||
|
||||
def get_help(command):
|
||||
fn = getattr(UPnP, command)
|
||||
params = command + " " + " ".join(["[--%s=<%s>]" % (k, k) for k in fn.__annotations__ if k != 'return'])
|
||||
def get_help(command: str) -> str:
|
||||
annotations = UPnP.get_annotations(command)
|
||||
params = command + " " + " ".join(["[--%s=<%s>]" % (k, str(v)) for k, v in annotations.items() if k != 'return'])
|
||||
return base_usage + "\n".join(
|
||||
textwrap.wrap(params, 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False)
|
||||
)
|
||||
|
||||
|
||||
def main(argv=None, loop=None):
|
||||
argv = argv or sys.argv
|
||||
commands = [n for n in dir(UPnP) if hasattr(getattr(UPnP, n, None), "_cli")]
|
||||
def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> int:
|
||||
argv = argv or list(sys.argv)
|
||||
commands = list(SOAPCommands.SOAP_COMMANDS)
|
||||
help_str = "\n".join(textwrap.wrap(
|
||||
" | ".join(commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False
|
||||
))
|
||||
|
@ -41,14 +45,16 @@ def main(argv=None, loop=None):
|
|||
"For help with a specific command:" \
|
||||
" aioupnp help <command>\n" % (base_usage, help_str)
|
||||
|
||||
args = argv[1:]
|
||||
args: typing.List[str] = [str(arg) for arg in argv[1:]]
|
||||
if args[0] in ['help', '-h', '--help']:
|
||||
if len(args) > 1:
|
||||
if args[1] in commands:
|
||||
sys.exit(get_help(args[1]))
|
||||
sys.exit(print(usage))
|
||||
print(get_help(args[1]))
|
||||
return 0
|
||||
print(usage)
|
||||
return 0
|
||||
|
||||
defaults = {
|
||||
defaults: typing.Dict[str, typing.Union[bool, str, int]] = {
|
||||
'debug_logging': False,
|
||||
'interface': 'default',
|
||||
'gateway_address': '',
|
||||
|
@ -57,22 +63,22 @@ def main(argv=None, loop=None):
|
|||
'unicast': False
|
||||
}
|
||||
|
||||
options = OrderedDict()
|
||||
options: typing.Dict[str, typing.Union[bool, str, int]] = OrderedDict()
|
||||
command = None
|
||||
for arg in args:
|
||||
if arg.startswith("--"):
|
||||
if "=" in arg:
|
||||
k, v = arg.split("=")
|
||||
options[k.lstrip('--')] = v
|
||||
else:
|
||||
k, v = arg, True
|
||||
k = k.lstrip('--')
|
||||
options[k] = v
|
||||
options[arg.lstrip('--')] = True
|
||||
else:
|
||||
command = arg
|
||||
break
|
||||
if not command:
|
||||
print("no command given")
|
||||
sys.exit(print(usage))
|
||||
print(usage)
|
||||
return 0
|
||||
kwargs = {}
|
||||
for arg in args[len(options)+1:]:
|
||||
if arg.startswith("--"):
|
||||
|
@ -81,18 +87,24 @@ def main(argv=None, loop=None):
|
|||
kwargs[k] = v
|
||||
else:
|
||||
break
|
||||
for k, v in defaults.items():
|
||||
for k in defaults:
|
||||
if k not in options:
|
||||
options[k] = v
|
||||
options[k] = defaults[k]
|
||||
|
||||
if options.pop('debug_logging'):
|
||||
log.setLevel(logging.DEBUG)
|
||||
|
||||
lan_address: str = str(options.pop('lan_address'))
|
||||
gateway_address: str = str(options.pop('gateway_address'))
|
||||
timeout: int = int(options.pop('timeout'))
|
||||
interface: str = str(options.pop('interface'))
|
||||
unicast: bool = bool(options.pop('unicast'))
|
||||
|
||||
UPnP.run_cli(
|
||||
command.replace('-', '_'), options, options.pop('lan_address'), options.pop('gateway_address'),
|
||||
options.pop('timeout'), options.pop('interface'), options.pop('unicast'), kwargs, loop
|
||||
command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
sys.exit(main())
|
||||
|
|
|
@ -1,64 +1,67 @@
|
|||
import logging
|
||||
import asyncio
|
||||
import time
|
||||
import typing
|
||||
from typing import Tuple, Union, List
|
||||
import functools
|
||||
import logging
|
||||
from typing import Tuple
|
||||
from aioupnp.protocols.scpd import scpd_post
|
||||
from aioupnp.device import Service
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
none_or_str = Union[None, str]
|
||||
return_type_lambas = {
|
||||
Union[None, str]: lambda x: x if x is not None and str(x).lower() not in ['none', 'nil'] else None
|
||||
}
|
||||
|
||||
|
||||
def safe_type(t):
|
||||
if t is typing.Tuple:
|
||||
return tuple
|
||||
if t is typing.List:
|
||||
return list
|
||||
if t is typing.Dict:
|
||||
return dict
|
||||
if t is typing.Set:
|
||||
return set
|
||||
return t
|
||||
def soap_optional_str(x: typing.Optional[str]) -> typing.Optional[str]:
|
||||
return x if x is not None and str(x).lower() not in ['none', 'nil'] else None
|
||||
|
||||
|
||||
class SOAPCommand:
|
||||
def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str,
|
||||
param_types: dict, return_types: dict, param_order: list, return_order: list, loop=None) -> None:
|
||||
self.gateway_address = gateway_address
|
||||
self.service_port = service_port
|
||||
self.control_url = control_url
|
||||
self.service_id = service_id
|
||||
self.method = method
|
||||
self.param_types = param_types
|
||||
self.param_order = param_order
|
||||
self.return_types = return_types
|
||||
self.return_order = return_order
|
||||
self.loop = loop
|
||||
self._requests: typing.List = []
|
||||
def soap_bool(x: typing.Optional[str]) -> bool:
|
||||
return False if not x or str(x).lower() in ['false', 'False'] else True
|
||||
|
||||
async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]:
|
||||
if set(kwargs.keys()) != set(self.param_types.keys()):
|
||||
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys()))
|
||||
soap_kwargs = {n: safe_type(self.param_types[n])(kwargs[n]) for n in self.param_types.keys()}
|
||||
|
||||
def recast_single_result(t, result):
|
||||
if t is bool:
|
||||
return soap_bool(result)
|
||||
if t is str:
|
||||
return soap_optional_str(result)
|
||||
return t(result)
|
||||
|
||||
|
||||
def recast_return(return_annotation, result, result_keys: typing.List[str]):
|
||||
if return_annotation is None:
|
||||
return None
|
||||
if len(result_keys) == 1:
|
||||
assert len(result_keys) == 1
|
||||
single_result = result[result_keys[0]]
|
||||
return recast_single_result(return_annotation, single_result)
|
||||
|
||||
annotated_args: typing.List[type] = list(return_annotation.__args__)
|
||||
assert len(annotated_args) == len(result_keys)
|
||||
recast_results: typing.List[typing.Optional[typing.Union[str, int, bool, bytes]]] = []
|
||||
for type_annotation, result_key in zip(annotated_args, result_keys):
|
||||
recast_results.append(recast_single_result(type_annotation, result[result_key]))
|
||||
return tuple(recast_results)
|
||||
|
||||
|
||||
def soap_command(fn):
|
||||
@functools.wraps(fn)
|
||||
async def wrapper(self: 'SOAPCommands', **kwargs):
|
||||
if not self.is_registered(fn.__name__):
|
||||
return fn(self, **kwargs)
|
||||
service = self.get_service(fn.__name__)
|
||||
assert service.controlURL is not None
|
||||
assert service.serviceType is not None
|
||||
response, xml_bytes, err = await scpd_post(
|
||||
self.control_url, self.gateway_address, self.service_port, self.method, self.param_order,
|
||||
self.service_id, self.loop, **soap_kwargs
|
||||
service.controlURL, self._base_address.decode(), self._port, fn.__name__, self._registered[service][fn.__name__][0],
|
||||
service.serviceType.encode(), self._loop, **kwargs
|
||||
)
|
||||
if err is not None:
|
||||
self._requests.append((soap_kwargs, xml_bytes, None, err, time.time()))
|
||||
|
||||
self._requests.append((fn.__name__, kwargs, xml_bytes, None, err, time.time()))
|
||||
raise err
|
||||
if not response:
|
||||
result = None
|
||||
else:
|
||||
recast_result = tuple([safe_type(self.return_types[n])(response.get(n)) for n in self.return_order])
|
||||
if len(recast_result) == 1:
|
||||
result = recast_result[0]
|
||||
else:
|
||||
result = recast_result
|
||||
self._requests.append((soap_kwargs, xml_bytes, result, None, time.time()))
|
||||
result = recast_return(fn.__annotations__.get('return'), response, self._registered[service][fn.__name__][1])
|
||||
self._requests.append((fn.__name__, kwargs, xml_bytes, result, None, time.time()))
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
|
||||
class SOAPCommands:
|
||||
|
@ -72,7 +75,7 @@ class SOAPCommands:
|
|||
to their expected types.
|
||||
"""
|
||||
|
||||
SOAP_COMMANDS = [
|
||||
SOAP_COMMANDS: typing.List[str] = [
|
||||
'AddPortMapping',
|
||||
'GetNATRSIPStatus',
|
||||
'GetGenericPortMappingEntry',
|
||||
|
@ -91,59 +94,63 @@ class SOAPCommands:
|
|||
'GetTotalPacketsReceived',
|
||||
'X_GetICSStatistics',
|
||||
'GetDefaultConnectionService',
|
||||
'NewDefaultConnectionService',
|
||||
'NewEnabledForInternet',
|
||||
'SetDefaultConnectionService',
|
||||
'SetEnabledForInternet',
|
||||
'GetEnabledForInternet',
|
||||
'NewActiveConnectionIndex',
|
||||
'GetMaximumActiveConnections',
|
||||
'GetActiveConnections'
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self._registered = set()
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop, base_address: bytes, port: int) -> None:
|
||||
self._loop = loop
|
||||
self._registered: typing.Dict[Service,
|
||||
typing.Dict[str, typing.Tuple[typing.List[str], typing.List[str]]]] = {}
|
||||
self._base_address = base_address
|
||||
self._port = port
|
||||
self._requests: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
|
||||
typing.Optional[typing.Dict[str, typing.Any]],
|
||||
typing.Optional[Exception], float]] = []
|
||||
|
||||
def register(self, base_ip: bytes, port: int, name: str, control_url: str,
|
||||
service_type: bytes, inputs: List, outputs: List, loop=None) -> None:
|
||||
if name not in self.SOAP_COMMANDS or name in self._registered:
|
||||
def is_registered(self, name: str) -> bool:
|
||||
if name not in self.SOAP_COMMANDS:
|
||||
raise ValueError("unknown command")
|
||||
for service in self._registered.values():
|
||||
if name in service:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_service(self, name: str) -> Service:
|
||||
if name not in self.SOAP_COMMANDS:
|
||||
raise ValueError("unknown command")
|
||||
for service, commands in self._registered.items():
|
||||
if name in commands:
|
||||
return service
|
||||
raise ValueError(name)
|
||||
|
||||
def register(self, name: str, service: Service, inputs: typing.List[str], outputs: typing.List[str]) -> None:
|
||||
# control_url: str, service_type: bytes,
|
||||
if name not in self.SOAP_COMMANDS:
|
||||
raise AttributeError(name)
|
||||
current = getattr(self, name)
|
||||
annotations = current.__annotations__
|
||||
return_types = annotations.get('return', None)
|
||||
if return_types:
|
||||
if hasattr(return_types, '__args__'):
|
||||
return_types = tuple([return_type_lambas.get(a, a) for a in return_types.__args__])
|
||||
elif isinstance(return_types, type):
|
||||
return_types = (return_types,)
|
||||
return_types = {r: t for r, t in zip(outputs, return_types)}
|
||||
param_types = {}
|
||||
for param_name, param_type in annotations.items():
|
||||
if param_name == "return":
|
||||
continue
|
||||
param_types[param_name] = param_type
|
||||
command = SOAPCommand(
|
||||
base_ip.decode(), port, control_url, service_type,
|
||||
name, param_types, return_types, inputs, outputs, loop=loop
|
||||
)
|
||||
setattr(command, "__doc__", current.__doc__)
|
||||
setattr(self, command.method, command)
|
||||
self._registered.add(command.method)
|
||||
if self.is_registered(name):
|
||||
raise AttributeError(f"{name} is already a registered SOAP command")
|
||||
if service not in self._registered:
|
||||
self._registered[service] = {}
|
||||
self._registered[service][name] = inputs, outputs
|
||||
|
||||
@staticmethod
|
||||
async def AddPortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int,
|
||||
NewInternalClient: str, NewEnabled: int, NewPortMappingDescription: str,
|
||||
NewLeaseDuration: str) -> None:
|
||||
@soap_command
|
||||
async def AddPortMapping(self, NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int,
|
||||
NewInternalClient: str, NewEnabled: int, NewPortMappingDescription: str,
|
||||
NewLeaseDuration: str) -> None:
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def GetNATRSIPStatus() -> Tuple[bool, bool]:
|
||||
@soap_command
|
||||
async def GetNATRSIPStatus(self) -> Tuple[bool, bool]:
|
||||
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def GetGenericPortMappingEntry(NewPortMappingIndex: int) -> Tuple[str, int, str, int, str,
|
||||
@soap_command
|
||||
async def GetGenericPortMappingEntry(self, NewPortMappingIndex: int) -> Tuple[str, int, str, int, str,
|
||||
bool, str, int]:
|
||||
"""
|
||||
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
|
||||
|
@ -151,100 +158,100 @@ class SOAPCommands:
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def GetSpecificPortMappingEntry(NewRemoteHost: str, NewExternalPort: int,
|
||||
@soap_command
|
||||
async def GetSpecificPortMappingEntry(self, NewRemoteHost: str, NewExternalPort: int,
|
||||
NewProtocol: str) -> Tuple[int, str, bool, str, int]:
|
||||
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def SetConnectionType(NewConnectionType: str) -> None:
|
||||
@soap_command
|
||||
async def SetConnectionType(self, NewConnectionType: str) -> None:
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def GetExternalIPAddress() -> str:
|
||||
@soap_command
|
||||
async def GetExternalIPAddress(self) -> str:
|
||||
"""Returns (NewExternalIPAddress)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def GetConnectionTypeInfo() -> Tuple[str, str]:
|
||||
@soap_command
|
||||
async def GetConnectionTypeInfo(self) -> Tuple[str, str]:
|
||||
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def GetStatusInfo() -> Tuple[str, str, int]:
|
||||
@soap_command
|
||||
async def GetStatusInfo(self) -> Tuple[str, str, int]:
|
||||
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def ForceTermination() -> None:
|
||||
@soap_command
|
||||
async def ForceTermination(self) -> None:
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def DeletePortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> None:
|
||||
@soap_command
|
||||
async def DeletePortMapping(self, NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> None:
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def RequestConnection() -> None:
|
||||
@soap_command
|
||||
async def RequestConnection(self) -> None:
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def GetCommonLinkProperties():
|
||||
@soap_command
|
||||
async def GetCommonLinkProperties(self) -> Tuple[str, int, int, str]:
|
||||
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def GetTotalBytesSent():
|
||||
@soap_command
|
||||
async def GetTotalBytesSent(self) -> int:
|
||||
"""Returns (NewTotalBytesSent)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def GetTotalBytesReceived():
|
||||
@soap_command
|
||||
async def GetTotalBytesReceived(self) -> int:
|
||||
"""Returns (NewTotalBytesReceived)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def GetTotalPacketsSent():
|
||||
@soap_command
|
||||
async def GetTotalPacketsSent(self) -> int:
|
||||
"""Returns (NewTotalPacketsSent)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def GetTotalPacketsReceived():
|
||||
@soap_command
|
||||
async def GetTotalPacketsReceived(self) -> int:
|
||||
"""Returns (NewTotalPacketsReceived)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def X_GetICSStatistics() -> Tuple[int, int, int, int, str, str]:
|
||||
@soap_command
|
||||
async def X_GetICSStatistics(self) -> Tuple[int, int, int, int, str, str]:
|
||||
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def GetDefaultConnectionService():
|
||||
@soap_command
|
||||
async def GetDefaultConnectionService(self) -> str:
|
||||
"""Returns (NewDefaultConnectionService)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def SetDefaultConnectionService(NewDefaultConnectionService: str) -> None:
|
||||
@soap_command
|
||||
async def SetDefaultConnectionService(self, NewDefaultConnectionService: str) -> None:
|
||||
"""Returns (None)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def SetEnabledForInternet(NewEnabledForInternet: bool) -> None:
|
||||
@soap_command
|
||||
async def SetEnabledForInternet(self, NewEnabledForInternet: bool) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def GetEnabledForInternet() -> bool:
|
||||
@soap_command
|
||||
async def GetEnabledForInternet(self) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def GetMaximumActiveConnections(NewActiveConnectionIndex: int):
|
||||
@soap_command
|
||||
async def GetMaximumActiveConnections(self, NewActiveConnectionIndex: int):
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def GetActiveConnections() -> Tuple[str, str]:
|
||||
@soap_command
|
||||
async def GetActiveConnections(self) -> Tuple[str, str]:
|
||||
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -1,23 +1,33 @@
|
|||
from collections import OrderedDict
|
||||
import typing
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CaseInsensitive:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
for k, v in kwargs.items():
|
||||
def __init__(self, **kwargs: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List[typing.Any]]]) -> None:
|
||||
keys: typing.List[str] = list(kwargs.keys())
|
||||
for k in keys:
|
||||
if not k.startswith("_"):
|
||||
setattr(self, k, v)
|
||||
assert k in kwargs
|
||||
setattr(self, k, kwargs[k])
|
||||
|
||||
def __getattr__(self, item):
|
||||
for k in self.__class__.__dict__.keys():
|
||||
def __getattr__(self, item: str) -> typing.Union[str, typing.Dict[str, typing.Any], typing.List]:
|
||||
keys: typing.List[str] = list(self.__class__.__dict__.keys())
|
||||
for k in keys:
|
||||
if k.lower() == item.lower():
|
||||
return self.__dict__.get(k)
|
||||
value: typing.Optional[typing.Union[str, typing.Dict[str, typing.Any],
|
||||
typing.List]] = self.__dict__.get(k)
|
||||
assert value is not None and isinstance(value, (str, dict, list))
|
||||
return value
|
||||
raise AttributeError(item)
|
||||
|
||||
def __setattr__(self, item, value):
|
||||
for k, v in self.__class__.__dict__.items():
|
||||
def __setattr__(self, item: str,
|
||||
value: typing.Union[str, typing.Dict[str, typing.Any], typing.List]) -> None:
|
||||
assert isinstance(value, (str, dict)), ValueError(f"got type {str(type(value))}, expected str")
|
||||
keys: typing.List[str] = list(self.__class__.__dict__.keys())
|
||||
for k in keys:
|
||||
if k.lower() == item.lower():
|
||||
self.__dict__[k] = value
|
||||
return
|
||||
|
@ -26,52 +36,57 @@ class CaseInsensitive:
|
|||
return
|
||||
raise AttributeError(item)
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
return {
|
||||
k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v)
|
||||
}
|
||||
def as_dict(self) -> typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List]]:
|
||||
result: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List]] = OrderedDict()
|
||||
keys: typing.List[str] = list(self.__dict__.keys())
|
||||
for k in keys:
|
||||
if not k.startswith("_"):
|
||||
result[k] = self.__getattr__(k)
|
||||
return result
|
||||
|
||||
|
||||
class Service(CaseInsensitive):
|
||||
serviceType = None
|
||||
serviceId = None
|
||||
controlURL = None
|
||||
eventSubURL = None
|
||||
SCPDURL = None
|
||||
serviceType: typing.Optional[str] = None
|
||||
serviceId: typing.Optional[str] = None
|
||||
controlURL: typing.Optional[str] = None
|
||||
eventSubURL: typing.Optional[str] = None
|
||||
SCPDURL: typing.Optional[str] = None
|
||||
|
||||
|
||||
class Device(CaseInsensitive):
|
||||
serviceList = None
|
||||
deviceList = None
|
||||
deviceType = None
|
||||
friendlyName = None
|
||||
manufacturer = None
|
||||
manufacturerURL = None
|
||||
modelDescription = None
|
||||
modelName = None
|
||||
modelNumber = None
|
||||
modelURL = None
|
||||
serialNumber = None
|
||||
udn = None
|
||||
upc = None
|
||||
presentationURL = None
|
||||
iconList = None
|
||||
serviceList: typing.Optional[typing.Dict[str, typing.Union[typing.Dict[str, typing.Any], typing.List]]] = None
|
||||
deviceList: typing.Optional[typing.Dict[str, typing.Union[typing.Dict[str, typing.Any], typing.List]]] = None
|
||||
deviceType: typing.Optional[str] = None
|
||||
friendlyName: typing.Optional[str] = None
|
||||
manufacturer: typing.Optional[str] = None
|
||||
manufacturerURL: typing.Optional[str] = None
|
||||
modelDescription: typing.Optional[str] = None
|
||||
modelName: typing.Optional[str] = None
|
||||
modelNumber: typing.Optional[str] = None
|
||||
modelURL: typing.Optional[str] = None
|
||||
serialNumber: typing.Optional[str] = None
|
||||
udn: typing.Optional[str] = None
|
||||
upc: typing.Optional[str] = None
|
||||
presentationURL: typing.Optional[str] = None
|
||||
iconList: typing.Optional[str] = None
|
||||
|
||||
def __init__(self, devices: List, services: List, **kwargs) -> None:
|
||||
def __init__(self, devices: typing.List['Device'], services: typing.List[Service],
|
||||
**kwargs: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List]]) -> None:
|
||||
super(Device, self).__init__(**kwargs)
|
||||
if self.serviceList and "service" in self.serviceList:
|
||||
new_services = self.serviceList["service"]
|
||||
if isinstance(new_services, dict):
|
||||
new_services = [new_services]
|
||||
services.extend([Service(**service) for service in new_services])
|
||||
if isinstance(self.serviceList['service'], dict):
|
||||
assert isinstance(self.serviceList['service'], dict)
|
||||
svc_list: typing.Dict[str, typing.Any] = self.serviceList['service']
|
||||
services.append(Service(**svc_list))
|
||||
elif isinstance(self.serviceList['service'], list):
|
||||
services.extend(Service(**svc) for svc in self.serviceList["service"])
|
||||
|
||||
if self.deviceList:
|
||||
for kw in self.deviceList.values():
|
||||
if isinstance(kw, dict):
|
||||
d = Device(devices, services, **kw)
|
||||
devices.append(d)
|
||||
devices.append(Device(devices, services, **kw))
|
||||
elif isinstance(kw, list):
|
||||
for _inner_kw in kw:
|
||||
d = Device(devices, services, **_inner_kw)
|
||||
devices.append(d)
|
||||
devices.append(Device(devices, services, **_inner_kw))
|
||||
else:
|
||||
log.warning("failed to parse device:\n%s", kw)
|
||||
|
|
|
@ -1,14 +1,2 @@
|
|||
from aioupnp.util import flatten_keys
|
||||
from aioupnp.constants import FAULT, CONTROL
|
||||
|
||||
|
||||
class UPnPError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def handle_fault(response: dict) -> dict:
|
||||
if FAULT in response:
|
||||
fault = flatten_keys(response[FAULT], "{%s}" % CONTROL)
|
||||
error_description = fault['detail']['UPnPError']['errorDescription']
|
||||
raise UPnPError(error_description)
|
||||
return response
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import re
|
||||
import logging
|
||||
import socket
|
||||
import typing
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Union, Type
|
||||
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
|
||||
from typing import Dict, List, Union
|
||||
from aioupnp.util import get_dict_val_case_insensitive
|
||||
from aioupnp.constants import SPEC_VERSION, SERVICE
|
||||
from aioupnp.commands import SOAPCommands
|
||||
from aioupnp.device import Device, Service
|
||||
|
@ -15,77 +16,94 @@ from aioupnp.fault import UPnPError
|
|||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
return_type_lambas = {
|
||||
Union[None, str]: lambda x: x if x is not None and str(x).lower() not in ['none', 'nil'] else None
|
||||
}
|
||||
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
|
||||
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
|
||||
|
||||
|
||||
def get_action_list(element_dict: dict) -> List: # [(<method>, [<input1>, ...], [<output1, ...]), ...]
|
||||
def get_action_list(element_dict: typing.Dict[str, typing.Union[str, typing.Dict[str, str],
|
||||
typing.List[typing.Dict[str, typing.Dict[str, str]]]]]
|
||||
) -> typing.List[typing.Tuple[str, typing.List[str], typing.List[str]]]:
|
||||
service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
|
||||
result: typing.List[typing.Tuple[str, typing.List[str], typing.List[str]]] = []
|
||||
if "actionList" in service_info:
|
||||
action_list = service_info["actionList"]
|
||||
else:
|
||||
return []
|
||||
return result
|
||||
if not len(action_list): # it could be an empty string
|
||||
return []
|
||||
return result
|
||||
|
||||
result: list = []
|
||||
if isinstance(action_list["action"], dict):
|
||||
arg_dicts = action_list["action"]['argumentList']['argument']
|
||||
if not isinstance(arg_dicts, list): # when there is one arg
|
||||
arg_dicts = [arg_dicts]
|
||||
return [[
|
||||
action_list["action"]['name'],
|
||||
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
|
||||
[i['name'] for i in arg_dicts if i['direction'] == 'out']
|
||||
]]
|
||||
for action in action_list["action"]:
|
||||
if not action.get('argumentList'):
|
||||
result.append((action['name'], [], []))
|
||||
action = action_list["action"]
|
||||
if isinstance(action, dict):
|
||||
arg_dicts: typing.List[typing.Dict[str, str]] = []
|
||||
if not isinstance(action['argumentList']['argument'], list): # when there is one arg
|
||||
arg_dicts.extend([action['argumentList']['argument']])
|
||||
else:
|
||||
arg_dicts = action['argumentList']['argument']
|
||||
if not isinstance(arg_dicts, list): # when there is one arg
|
||||
arg_dicts = [arg_dicts]
|
||||
arg_dicts.extend(action['argumentList']['argument'])
|
||||
|
||||
result.append((action_list["action"]['name'], [i['name'] for i in arg_dicts if i['direction'] == 'in'],
|
||||
[i['name'] for i in arg_dicts if i['direction'] == 'out']))
|
||||
return result
|
||||
assert isinstance(action, list)
|
||||
for _action in action:
|
||||
if not _action.get('argumentList'):
|
||||
result.append((_action['name'], [], []))
|
||||
else:
|
||||
if not isinstance(_action['argumentList']['argument'], list): # when there is one arg
|
||||
arg_dicts = [_action['argumentList']['argument']]
|
||||
else:
|
||||
arg_dicts = _action['argumentList']['argument']
|
||||
result.append((
|
||||
action['name'],
|
||||
_action['name'],
|
||||
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
|
||||
[i['name'] for i in arg_dicts if i['direction'] == 'out']
|
||||
))
|
||||
return result
|
||||
|
||||
|
||||
def parse_location(location: bytes) -> typing.Tuple[bytes, int]:
|
||||
base_address_result: typing.List[bytes] = BASE_ADDRESS_REGEX.findall(location)
|
||||
base_address = base_address_result[0]
|
||||
port_result: typing.List[bytes] = BASE_PORT_REGEX.findall(location)
|
||||
port = int(port_result[0])
|
||||
return base_address, port
|
||||
|
||||
|
||||
class Gateway:
|
||||
def __init__(self, ok_packet: SSDPDatagram, m_search_args: OrderedDict, lan_address: str,
|
||||
gateway_address: str) -> None:
|
||||
commands: SOAPCommands
|
||||
|
||||
def __init__(self, ok_packet: SSDPDatagram, m_search_args: typing.Dict[str, typing.Union[int, str]],
|
||||
lan_address: str, gateway_address: str,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
self._ok_packet = ok_packet
|
||||
self._m_search_args = m_search_args
|
||||
self._lan_address = lan_address
|
||||
self.usn = (ok_packet.usn or '').encode()
|
||||
self.ext = (ok_packet.ext or '').encode()
|
||||
self.server = (ok_packet.server or '').encode()
|
||||
self.location = (ok_packet.location or '').encode()
|
||||
self.cache_control = (ok_packet.cache_control or '').encode()
|
||||
self.date = (ok_packet.date or '').encode()
|
||||
self.urn = (ok_packet.st or '').encode()
|
||||
self.usn: bytes = (ok_packet.usn or '').encode()
|
||||
self.ext: bytes = (ok_packet.ext or '').encode()
|
||||
self.server: bytes = (ok_packet.server or '').encode()
|
||||
self.location: bytes = (ok_packet.location or '').encode()
|
||||
self.cache_control: bytes = (ok_packet.cache_control or '').encode()
|
||||
self.date: bytes = (ok_packet.date or '').encode()
|
||||
self.urn: bytes = (ok_packet.st or '').encode()
|
||||
|
||||
self._xml_response = b""
|
||||
self._service_descriptors: Dict = {}
|
||||
self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0]
|
||||
self.port = int(BASE_PORT_REGEX.findall(self.location)[0])
|
||||
self._xml_response: bytes = b""
|
||||
self._service_descriptors: Dict[str, bytes] = {}
|
||||
|
||||
self.base_address, self.port = parse_location(self.location)
|
||||
self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0]
|
||||
assert self.base_ip == gateway_address.encode()
|
||||
self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1]
|
||||
|
||||
self.spec_version = None
|
||||
self.url_base = None
|
||||
self.spec_version: typing.Optional[str] = None
|
||||
self.url_base: typing.Optional[str] = None
|
||||
|
||||
self._device: Union[None, Device] = None
|
||||
self._devices: List = []
|
||||
self._services: List = []
|
||||
self._device: typing.Optional[Device] = None
|
||||
self._devices: List[Device] = []
|
||||
self._services: List[Service] = []
|
||||
|
||||
self._unsupported_actions: Dict = {}
|
||||
self._registered_commands: Dict = {}
|
||||
self.commands = SOAPCommands()
|
||||
self._unsupported_actions: Dict[str, typing.List[str]] = {}
|
||||
self._registered_commands: Dict[str, str] = {}
|
||||
self.commands = SOAPCommands(self._loop, self.base_ip, self.port)
|
||||
|
||||
def gateway_descriptor(self) -> dict:
|
||||
r = {
|
||||
|
@ -102,14 +120,15 @@ class Gateway:
|
|||
def manufacturer_string(self) -> str:
|
||||
if not self.devices:
|
||||
return "UNKNOWN GATEWAY"
|
||||
device = list(self.devices.values())[0]
|
||||
return "%s %s" % (device.manufacturer, device.modelName)
|
||||
devices: typing.List[Device] = list(self.devices.values())
|
||||
device = devices[0]
|
||||
return f"{device.manufacturer} {device.modelName}"
|
||||
|
||||
@property
|
||||
def services(self) -> Dict:
|
||||
def services(self) -> Dict[str, Service]:
|
||||
if not self._device:
|
||||
return {}
|
||||
return {service.serviceType: service for service in self._services}
|
||||
return {str(service.serviceType): service for service in self._services}
|
||||
|
||||
@property
|
||||
def devices(self) -> Dict:
|
||||
|
@ -117,28 +136,29 @@ class Gateway:
|
|||
return {}
|
||||
return {device.udn: device for device in self._devices}
|
||||
|
||||
def get_service(self, service_type: str) -> Union[Type[Service], None]:
|
||||
def get_service(self, service_type: str) -> typing.Optional[Service]:
|
||||
for service in self._services:
|
||||
if service.serviceType.lower() == service_type.lower():
|
||||
if service.serviceType and service.serviceType.lower() == service_type.lower():
|
||||
return service
|
||||
return None
|
||||
|
||||
@property
|
||||
def soap_requests(self) -> List:
|
||||
soap_call_infos = []
|
||||
for name in self._registered_commands.keys():
|
||||
if not hasattr(getattr(self.commands, name), "_requests"):
|
||||
continue
|
||||
soap_call_infos.extend([
|
||||
(name, request_args, raw_response, decoded_response, soap_error, ts)
|
||||
for (
|
||||
request_args, raw_response, decoded_response, soap_error, ts
|
||||
) in getattr(self.commands, name)._requests
|
||||
])
|
||||
def soap_requests(self) -> typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
|
||||
typing.Optional[typing.Dict[str, typing.Any]],
|
||||
typing.Optional[Exception], float]]:
|
||||
soap_call_infos: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
|
||||
typing.Optional[typing.Dict[str, typing.Any]],
|
||||
typing.Optional[Exception], float]] = []
|
||||
soap_call_infos.extend([
|
||||
(name, request_args, raw_response, decoded_response, soap_error, ts)
|
||||
for (
|
||||
name, request_args, raw_response, decoded_response, soap_error, ts
|
||||
) in self.commands._requests
|
||||
])
|
||||
soap_call_infos.sort(key=lambda x: x[5])
|
||||
return soap_call_infos
|
||||
|
||||
def debug_gateway(self) -> Dict:
|
||||
def debug_gateway(self) -> Dict[str, Union[str, bytes, int, Dict, List]]:
|
||||
return {
|
||||
'manufacturer_string': self.manufacturer_string,
|
||||
'gateway_address': self.base_ip,
|
||||
|
@ -156,9 +176,11 @@ class Gateway:
|
|||
|
||||
@classmethod
|
||||
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
igd_args: OrderedDict = None, loop=None, unicast: bool = False):
|
||||
ignored: set = set()
|
||||
required_commands = [
|
||||
igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
|
||||
unicast: bool = False) -> 'Gateway':
|
||||
ignored: typing.Set[str] = set()
|
||||
required_commands: typing.List[str] = [
|
||||
'AddPortMapping',
|
||||
'DeletePortMapping',
|
||||
'GetExternalIPAddress'
|
||||
|
@ -166,20 +188,21 @@ class Gateway:
|
|||
while True:
|
||||
if not igd_args:
|
||||
m_search_args, datagram = await fuzzy_m_search(
|
||||
lan_address, gateway_address, timeout, loop, ignored, unicast
|
||||
lan_address, gateway_address, timeout, loop, ignored, unicast
|
||||
)
|
||||
else:
|
||||
m_search_args = OrderedDict(igd_args)
|
||||
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, ignored, unicast)
|
||||
try:
|
||||
gateway = cls(datagram, m_search_args, lan_address, gateway_address)
|
||||
gateway = cls(datagram, m_search_args, lan_address, gateway_address, loop=loop)
|
||||
log.debug('get gateway descriptor %s', datagram.location)
|
||||
await gateway.discover_commands(loop)
|
||||
requirements_met = all([required in gateway._registered_commands for required in required_commands])
|
||||
requirements_met = all([gateway.commands.is_registered(required) for required in required_commands])
|
||||
if not requirements_met:
|
||||
not_met = [
|
||||
required for required in required_commands if required not in gateway._registered_commands
|
||||
required for required in required_commands if not gateway.commands.is_registered(required)
|
||||
]
|
||||
assert datagram.location is not None
|
||||
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
|
||||
gateway.manufacturer_string, gateway.location, not_met)
|
||||
ignored.add(datagram.location)
|
||||
|
@ -188,13 +211,17 @@ class Gateway:
|
|||
log.debug('found gateway device %s', datagram.location)
|
||||
return gateway
|
||||
except (asyncio.TimeoutError, UPnPError) as err:
|
||||
assert datagram.location is not None
|
||||
log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err))
|
||||
ignored.add(datagram.location)
|
||||
continue
|
||||
|
||||
@classmethod
|
||||
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
igd_args: OrderedDict = None, loop=None, unicast: bool = None):
|
||||
igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
|
||||
unicast: typing.Optional[bool] = None) -> 'Gateway':
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
if unicast is not None:
|
||||
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast)
|
||||
|
||||
|
@ -205,7 +232,7 @@ class Gateway:
|
|||
cls._discover_gateway(
|
||||
lan_address, gateway_address, timeout, igd_args, loop, unicast=False
|
||||
)
|
||||
], return_when=asyncio.tasks.FIRST_COMPLETED)
|
||||
], return_when=asyncio.tasks.FIRST_COMPLETED, loop=loop)
|
||||
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
@ -214,56 +241,78 @@ class Gateway:
|
|||
task.exception()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
results: typing.List[asyncio.Future['Gateway']] = list(done)
|
||||
return results[0].result()
|
||||
|
||||
return list(done)[0].result()
|
||||
|
||||
async def discover_commands(self, loop=None):
|
||||
async def discover_commands(self, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port, loop=loop)
|
||||
self._xml_response = xml_bytes
|
||||
if get_err is not None:
|
||||
raise get_err
|
||||
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
|
||||
self.url_base = get_dict_val_case_insensitive(response, "urlbase")
|
||||
spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
|
||||
if isinstance(spec_version, bytes):
|
||||
self.spec_version = spec_version.decode()
|
||||
else:
|
||||
self.spec_version = spec_version
|
||||
url_base = get_dict_val_case_insensitive(response, "urlbase")
|
||||
if isinstance(url_base, bytes):
|
||||
self.url_base = url_base.decode()
|
||||
else:
|
||||
self.url_base = url_base
|
||||
if not self.url_base:
|
||||
self.url_base = self.base_address.decode()
|
||||
if response:
|
||||
device_dict = get_dict_val_case_insensitive(response, "device")
|
||||
source_keys: typing.List[str] = list(response.keys())
|
||||
matches: typing.List[str] = list(filter(lambda x: x.lower() == "device", source_keys))
|
||||
match_key = matches[0]
|
||||
match: dict = response[match_key]
|
||||
# if not len(match):
|
||||
# return None
|
||||
# if len(match) > 1:
|
||||
# raise KeyError("overlapping keys")
|
||||
# if len(match) == 1:
|
||||
# matched_key: typing.AnyStr = match[0]
|
||||
# return source[matched_key]
|
||||
# raise KeyError("overlapping keys")
|
||||
|
||||
self._device = Device(
|
||||
self._devices, self._services, **device_dict
|
||||
self._devices, self._services, **match
|
||||
)
|
||||
else:
|
||||
self._device = Device(self._devices, self._services)
|
||||
for service_type in self.services.keys():
|
||||
await self.register_commands(self.services[service_type], loop)
|
||||
return None
|
||||
|
||||
async def register_commands(self, service: Service, loop=None):
|
||||
async def register_commands(self, service: Service,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
if not service.SCPDURL:
|
||||
raise UPnPError("no scpd url")
|
||||
if not service.serviceType:
|
||||
raise UPnPError("no service type")
|
||||
|
||||
log.debug("get descriptor for %s from %s", service.serviceType, service.SCPDURL)
|
||||
service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port)
|
||||
service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port, loop=loop)
|
||||
self._service_descriptors[service.SCPDURL] = xml_bytes
|
||||
|
||||
if get_err is not None:
|
||||
log.debug("failed to get descriptor for %s from %s", service.serviceType, service.SCPDURL)
|
||||
if xml_bytes:
|
||||
log.debug("response: %s", xml_bytes.decode())
|
||||
return
|
||||
return None
|
||||
if not service_dict:
|
||||
return
|
||||
return None
|
||||
|
||||
action_list = get_action_list(service_dict)
|
||||
|
||||
for name, inputs, outputs in action_list:
|
||||
try:
|
||||
self.commands.register(self.base_ip, self.port, name, service.controlURL, service.serviceType.encode(),
|
||||
inputs, outputs, loop)
|
||||
self.commands.register(name, service, inputs, outputs)
|
||||
self._registered_commands[name] = service.serviceType
|
||||
log.debug("registered %s::%s", service.serviceType, name)
|
||||
except AttributeError:
|
||||
s = self._unsupported_actions.get(service.serviceType, [])
|
||||
s.append(name)
|
||||
self._unsupported_actions[service.serviceType] = s
|
||||
self._unsupported_actions.setdefault(service.serviceType, [])
|
||||
self._unsupported_actions[service.serviceType].append(name)
|
||||
log.debug("available command for %s does not have a wrapper implemented: %s %s %s",
|
||||
service.serviceType, name, inputs, outputs)
|
||||
log.debug("registered service %s", service.serviceType)
|
||||
return None
|
||||
|
|
51
aioupnp/interfaces.py
Normal file
51
aioupnp/interfaces.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
import socket
|
||||
from collections import OrderedDict
|
||||
import typing
|
||||
import netifaces
|
||||
|
||||
|
||||
def get_netifaces():
|
||||
return netifaces
|
||||
|
||||
|
||||
def ifaddresses(iface: str):
|
||||
return get_netifaces().ifaddresses(iface)
|
||||
|
||||
|
||||
def _get_interfaces():
|
||||
return get_netifaces().interfaces()
|
||||
|
||||
|
||||
def _get_gateways():
|
||||
return get_netifaces().gateways()
|
||||
|
||||
|
||||
def get_interfaces() -> typing.Dict[str, typing.Tuple[str, str]]:
|
||||
gateways = _get_gateways()
|
||||
infos = gateways[socket.AF_INET]
|
||||
assert isinstance(infos, list), TypeError(f"expected list from netifaces, got a dict")
|
||||
interface_infos: typing.List[typing.Tuple[str, str, bool]] = infos
|
||||
result: typing.Dict[str, typing.Tuple[str, str]] = OrderedDict(
|
||||
(interface_name, (router_address, ifaddresses(interface_name)[netifaces.AF_INET][0]['addr']))
|
||||
for router_address, interface_name, _ in interface_infos
|
||||
)
|
||||
for interface_name in _get_interfaces():
|
||||
if interface_name in ['lo', 'localhost'] or interface_name in result:
|
||||
continue
|
||||
addresses = ifaddresses(interface_name)
|
||||
if netifaces.AF_INET in addresses:
|
||||
address = addresses[netifaces.AF_INET][0]['addr']
|
||||
gateway_guess = ".".join(address.split(".")[:-1] + ["1"])
|
||||
result[interface_name] = (gateway_guess, address)
|
||||
_default = gateways['default']
|
||||
assert isinstance(_default, dict), TypeError(f"expected dict from netifaces, got a list")
|
||||
default: typing.Dict[int, typing.Tuple[str, str]] = _default
|
||||
result['default'] = result[default[netifaces.AF_INET][1]]
|
||||
return result
|
||||
|
||||
|
||||
def get_gateway_and_lan_addresses(interface_name: str) -> typing.Tuple[str, str]:
|
||||
for iface_name, (gateway, lan) in get_interfaces().items():
|
||||
if interface_name == iface_name:
|
||||
return gateway, lan
|
||||
return '', ''
|
|
@ -44,10 +44,11 @@ ST
|
|||
characters in the domain name must be replaced with hyphens in accordance with RFC 2141.
|
||||
"""
|
||||
|
||||
import typing
|
||||
from collections import OrderedDict
|
||||
from aioupnp.constants import SSDP_DISCOVER, SSDP_HOST
|
||||
|
||||
SEARCH_TARGETS = [
|
||||
SEARCH_TARGETS: typing.List[str] = [
|
||||
'upnp:rootdevice',
|
||||
'urn:schemas-upnp-org:device:InternetGatewayDevice:1',
|
||||
'urn:schemas-wifialliance-org:device:WFADevice:1',
|
||||
|
@ -58,7 +59,8 @@ SEARCH_TARGETS = [
|
|||
]
|
||||
|
||||
|
||||
def format_packet_args(order: list, **kwargs):
|
||||
def format_packet_args(order: typing.List[str],
|
||||
kwargs: typing.Dict[str, typing.Union[int, str]]) -> typing.Dict[str, typing.Union[int, str]]:
|
||||
args = []
|
||||
for o in order:
|
||||
for k, v in kwargs.items():
|
||||
|
@ -68,18 +70,18 @@ def format_packet_args(order: list, **kwargs):
|
|||
return OrderedDict(args)
|
||||
|
||||
|
||||
def packet_generator():
|
||||
def packet_generator() -> typing.Iterator[typing.Dict[str, typing.Union[int, str]]]:
|
||||
for st in SEARCH_TARGETS:
|
||||
order = ["HOST", "MAN", "MX", "ST"]
|
||||
yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st)
|
||||
yield format_packet_args(order, Host=SSDP_HOST, Man='"%s"' % SSDP_DISCOVER, MX=1, ST=st)
|
||||
yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st)
|
||||
yield format_packet_args(order, Host=SSDP_HOST, Man=SSDP_DISCOVER, MX=1, ST=st)
|
||||
yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
|
||||
yield format_packet_args(order, {'Host': SSDP_HOST, 'Man': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
|
||||
yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': SSDP_DISCOVER, 'MX': 1, 'ST': st})
|
||||
yield format_packet_args(order, {'Host': SSDP_HOST, 'Man': SSDP_DISCOVER, 'MX': 1, 'ST': st})
|
||||
|
||||
order = ["HOST", "MAN", "ST", "MX"]
|
||||
yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st)
|
||||
yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st)
|
||||
yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
|
||||
yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': SSDP_DISCOVER, 'MX': 1, 'ST': st})
|
||||
|
||||
order = ["HOST", "ST", "MAN", "MX"]
|
||||
yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st)
|
||||
yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st)
|
||||
yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
|
||||
yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': SSDP_DISCOVER, 'MX': 1, 'ST': st})
|
||||
|
|
|
@ -1,45 +1,70 @@
|
|||
import struct
|
||||
import socket
|
||||
import typing
|
||||
from asyncio.protocols import DatagramProtocol
|
||||
from asyncio.transports import DatagramTransport
|
||||
from asyncio.transports import BaseTransport
|
||||
from unittest import mock
|
||||
|
||||
|
||||
def _get_sock(transport: typing.Optional[BaseTransport]) -> typing.Optional[socket.socket]:
|
||||
if transport is None or not hasattr(transport, "_extra"):
|
||||
return None
|
||||
sock: typing.Optional[socket.socket] = transport.get_extra_info('socket', None)
|
||||
assert sock is None or isinstance(sock, socket.SocketType) or isinstance(sock, mock.MagicMock)
|
||||
return sock
|
||||
|
||||
|
||||
class MulticastProtocol(DatagramProtocol):
|
||||
def __init__(self, multicast_address: str, bind_address: str) -> None:
|
||||
self.multicast_address = multicast_address
|
||||
self.bind_address = bind_address
|
||||
self.transport: DatagramTransport
|
||||
self.transport: typing.Optional[BaseTransport] = None
|
||||
|
||||
@property
|
||||
def sock(self) -> socket.socket:
|
||||
s: socket.socket = self.transport.get_extra_info(name='socket')
|
||||
return s
|
||||
def sock(self) -> typing.Optional[socket.socket]:
|
||||
return _get_sock(self.transport)
|
||||
|
||||
def get_ttl(self) -> int:
|
||||
return self.sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
|
||||
sock = self.sock
|
||||
if not sock:
|
||||
raise ValueError("not connected")
|
||||
return sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
|
||||
|
||||
def set_ttl(self, ttl: int = 1) -> None:
|
||||
self.sock.setsockopt(
|
||||
sock = self.sock
|
||||
if not sock:
|
||||
return None
|
||||
sock.setsockopt(
|
||||
socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl)
|
||||
)
|
||||
return None
|
||||
|
||||
def join_group(self, multicast_address: str, bind_address: str) -> None:
|
||||
self.sock.setsockopt(
|
||||
sock = self.sock
|
||||
if not sock:
|
||||
return None
|
||||
sock.setsockopt(
|
||||
socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
|
||||
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
|
||||
)
|
||||
return None
|
||||
|
||||
def leave_group(self, multicast_address: str, bind_address: str) -> None:
|
||||
self.sock.setsockopt(
|
||||
sock = self.sock
|
||||
if not sock:
|
||||
raise ValueError("not connected")
|
||||
sock.setsockopt(
|
||||
socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP,
|
||||
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
|
||||
)
|
||||
return None
|
||||
|
||||
def connection_made(self, transport) -> None:
|
||||
def connection_made(self, transport: BaseTransport) -> None:
|
||||
self.transport = transport
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def create_multicast_socket(cls, bind_address: str):
|
||||
def create_multicast_socket(cls, bind_address: str) -> socket.socket:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind((bind_address, 0))
|
||||
|
|
|
@ -2,7 +2,6 @@ import logging
|
|||
import typing
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from xml.etree import ElementTree
|
||||
import asyncio
|
||||
from asyncio.protocols import Protocol
|
||||
from aioupnp.fault import UPnPError
|
||||
|
@ -18,16 +17,22 @@ log = logging.getLogger(__name__)
|
|||
HTTP_CODE_REGEX = re.compile(b"^HTTP[\/]{0,1}1\.[1|0] (\d\d\d)(.*)$")
|
||||
|
||||
|
||||
def parse_headers(response: bytes) -> typing.Tuple[OrderedDict, int, bytes]:
|
||||
def parse_http_response_code(http_response: bytes) -> typing.Tuple[bytes, bytes]:
|
||||
parsed: typing.List[typing.Tuple[bytes, bytes]] = HTTP_CODE_REGEX.findall(http_response)
|
||||
return parsed[0]
|
||||
|
||||
|
||||
def parse_headers(response: bytes) -> typing.Tuple[typing.Dict[bytes, bytes], int, bytes]:
|
||||
lines = response.split(b'\r\n')
|
||||
headers = OrderedDict([
|
||||
headers: typing.Dict[bytes, bytes] = OrderedDict([
|
||||
(l.split(b':')[0], b':'.join(l.split(b':')[1:]).lstrip(b' ').rstrip(b' '))
|
||||
for l in response.split(b'\r\n')
|
||||
])
|
||||
if len(lines) != len(headers):
|
||||
raise ValueError("duplicate headers")
|
||||
http_response = tuple(headers.keys())[0]
|
||||
response_code, message = HTTP_CODE_REGEX.findall(http_response)[0]
|
||||
header_keys: typing.List[bytes] = list(headers.keys())
|
||||
http_response = header_keys[0]
|
||||
response_code, message = parse_http_response_code(http_response)
|
||||
del headers[http_response]
|
||||
return headers, int(response_code), message
|
||||
|
||||
|
@ -40,37 +45,42 @@ class SCPDHTTPClientProtocol(Protocol):
|
|||
and devices respond with an invalid HTTP version line
|
||||
"""
|
||||
|
||||
def __init__(self, message: bytes, finished: asyncio.Future, soap_method: str=None,
|
||||
soap_service_id: str=None) -> None:
|
||||
def __init__(self, message: bytes, finished: 'asyncio.Future[typing.Tuple[bytes, int, bytes]]',
|
||||
soap_method: typing.Optional[str] = None, soap_service_id: typing.Optional[str] = None) -> None:
|
||||
self.message = message
|
||||
self.response_buff = b""
|
||||
self.finished = finished
|
||||
self.soap_method = soap_method
|
||||
self.soap_service_id = soap_service_id
|
||||
|
||||
self._response_code: int = 0
|
||||
self._response_msg: bytes = b""
|
||||
self._content_length: int = 0
|
||||
self._response_code = 0
|
||||
self._response_msg = b""
|
||||
self._content_length = 0
|
||||
self._got_headers = False
|
||||
self._headers: dict = {}
|
||||
self._headers: typing.Dict[bytes, bytes] = {}
|
||||
self._body = b""
|
||||
self.transport: typing.Optional[asyncio.WriteTransport] = None
|
||||
|
||||
def connection_made(self, transport):
|
||||
transport.write(self.message)
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
assert isinstance(transport, asyncio.WriteTransport)
|
||||
self.transport = transport
|
||||
self.transport.write(self.message)
|
||||
return None
|
||||
|
||||
def data_received(self, data):
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self.response_buff += data
|
||||
for i, line in enumerate(self.response_buff.split(b'\r\n')):
|
||||
if not line: # we hit the blank line between the headers and the body
|
||||
if i == (len(self.response_buff.split(b'\r\n')) - 1):
|
||||
return # the body is still yet to be written
|
||||
return None # the body is still yet to be written
|
||||
if not self._got_headers:
|
||||
self._headers, self._response_code, self._response_msg = parse_headers(
|
||||
b'\r\n'.join(self.response_buff.split(b'\r\n')[:i])
|
||||
)
|
||||
content_length = get_dict_val_case_insensitive(self._headers, b'Content-Length')
|
||||
content_length = get_dict_val_case_insensitive(
|
||||
self._headers, b'Content-Length'
|
||||
)
|
||||
if content_length is None:
|
||||
return
|
||||
return None
|
||||
self._content_length = int(content_length or 0)
|
||||
self._got_headers = True
|
||||
body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:])
|
||||
|
@ -86,21 +96,28 @@ class SCPDHTTPClientProtocol(Protocol):
|
|||
)
|
||||
)
|
||||
)
|
||||
return
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
async def scpd_get(control_url: str, address: str, port: int, loop=None) -> typing.Tuple[typing.Dict, bytes,
|
||||
typing.Optional[Exception]]:
|
||||
loop = loop or asyncio.get_event_loop_policy().get_event_loop()
|
||||
finished: asyncio.Future = asyncio.Future()
|
||||
async def scpd_get(control_url: str, address: str, port: int,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> typing.Tuple[typing.Dict[str, typing.Any], bytes,
|
||||
typing.Optional[Exception]]:
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
packet = serialize_scpd_get(control_url, address)
|
||||
transport, protocol = await loop.create_connection(
|
||||
lambda : SCPDHTTPClientProtocol(packet, finished), address, port
|
||||
finished: asyncio.Future[typing.Tuple[bytes, int, bytes]] = asyncio.Future(loop=loop)
|
||||
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished)
|
||||
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
|
||||
proto_factory, address, port
|
||||
)
|
||||
protocol = connect_tup[1]
|
||||
transport = connect_tup[0]
|
||||
assert isinstance(protocol, SCPDHTTPClientProtocol)
|
||||
|
||||
error = None
|
||||
wait_task: typing.Awaitable[typing.Tuple[bytes, int, bytes]] = asyncio.wait_for(protocol.finished, 1.0, loop=loop)
|
||||
try:
|
||||
body, response_code, response_msg = await asyncio.wait_for(finished, 1.0)
|
||||
body, response_code, response_msg = await wait_task
|
||||
except asyncio.TimeoutError:
|
||||
error = UPnPError("get request timed out")
|
||||
body = b''
|
||||
|
@ -112,24 +129,31 @@ async def scpd_get(control_url: str, address: str, port: int, loop=None) -> typi
|
|||
if not error:
|
||||
try:
|
||||
return deserialize_scpd_get_response(body), body, None
|
||||
except ElementTree.ParseError as err:
|
||||
except Exception as err:
|
||||
error = UPnPError(err)
|
||||
|
||||
return {}, body, error
|
||||
|
||||
|
||||
async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes,
|
||||
loop=None, **kwargs) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]:
|
||||
loop = loop or asyncio.get_event_loop_policy().get_event_loop()
|
||||
finished: asyncio.Future = asyncio.Future()
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
|
||||
**kwargs: typing.Dict[str, typing.Any]
|
||||
) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]:
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
finished: asyncio.Future[typing.Tuple[bytes, int, bytes]] = asyncio.Future(loop=loop)
|
||||
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
|
||||
transport, protocol = await loop.create_connection(
|
||||
lambda : SCPDHTTPClientProtocol(
|
||||
packet, finished, soap_method=method, soap_service_id=service_id.decode(),
|
||||
), address, port
|
||||
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\
|
||||
SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode())
|
||||
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
|
||||
proto_factory, address, port
|
||||
)
|
||||
protocol = connect_tup[1]
|
||||
transport = connect_tup[0]
|
||||
assert isinstance(protocol, SCPDHTTPClientProtocol)
|
||||
|
||||
try:
|
||||
body, response_code, response_msg = await asyncio.wait_for(finished, 1.0)
|
||||
wait_task: typing.Awaitable[typing.Tuple[bytes, int, bytes]] = asyncio.wait_for(finished, 1.0, loop=loop)
|
||||
body, response_code, response_msg = await wait_task
|
||||
except asyncio.TimeoutError:
|
||||
return {}, b'', UPnPError("Timeout")
|
||||
except UPnPError as err:
|
||||
|
@ -140,5 +164,5 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para
|
|||
return (
|
||||
deserialize_soap_post_response(body, method, service_id.decode()), body, None
|
||||
)
|
||||
except (ElementTree.ParseError, UPnPError) as err:
|
||||
except Exception as err:
|
||||
return {}, body, UPnPError(err)
|
||||
|
|
|
@ -3,8 +3,8 @@ import binascii
|
|||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
import socket
|
||||
from collections import OrderedDict
|
||||
from asyncio.futures import Future
|
||||
from asyncio.transports import DatagramTransport
|
||||
from aioupnp.fault import UPnPError
|
||||
from aioupnp.serialization.ssdp import SSDPDatagram
|
||||
|
@ -18,32 +18,48 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class SSDPProtocol(MulticastProtocol):
|
||||
def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Set[str] = None,
|
||||
unicast: bool = False) -> None:
|
||||
def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Optional[typing.Set[str]] = None,
|
||||
unicast: bool = False, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
super().__init__(multicast_address, lan_address)
|
||||
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
|
||||
self.transport: typing.Optional[DatagramTransport] = None
|
||||
self._unicast = unicast
|
||||
self._ignored: typing.Set[str] = ignored or set() # ignored locations
|
||||
self._pending_searches: typing.List[typing.Tuple[str, str, Future, asyncio.Handle]] = []
|
||||
self.notifications: typing.List = []
|
||||
self._pending_searches: typing.List[typing.Tuple[str, str, asyncio.Future[SSDPDatagram], asyncio.Handle]] = []
|
||||
self.notifications: typing.List[SSDPDatagram] = []
|
||||
self.connected = asyncio.Event(loop=self.loop)
|
||||
|
||||
def disconnect(self):
|
||||
def connection_made(self, transport) -> None:
|
||||
# assert isinstance(transport, asyncio.DatagramTransport), str(type(transport))
|
||||
super().connection_made(transport)
|
||||
self.connected.set()
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if self.transport:
|
||||
try:
|
||||
self.leave_group(self.multicast_address, self.bind_address)
|
||||
except ValueError:
|
||||
pass
|
||||
except Exception:
|
||||
log.exception("unexpected error leaving multicast group")
|
||||
self.transport.close()
|
||||
self.connected.clear()
|
||||
while self._pending_searches:
|
||||
pending = self._pending_searches.pop()[2]
|
||||
if not pending.cancelled() and not pending.done():
|
||||
pending.cancel()
|
||||
return None
|
||||
|
||||
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
|
||||
if packet.location in self._ignored:
|
||||
return
|
||||
tmp: typing.List = []
|
||||
set_futures: typing.List = []
|
||||
while self._pending_searches:
|
||||
t: tuple = self._pending_searches.pop()
|
||||
a, s = t[0], t[1]
|
||||
if (address == a) and (s in [packet.st, "upnp:rootdevice"]):
|
||||
f: Future = t[2]
|
||||
return None
|
||||
# TODO: fix this
|
||||
tmp: typing.List[typing.Tuple[str, str, asyncio.Future[SSDPDatagram], asyncio.Handle]] = []
|
||||
set_futures: typing.List[asyncio.Future[SSDPDatagram]] = []
|
||||
while len(self._pending_searches):
|
||||
t: typing.Tuple[str, str, asyncio.Future[SSDPDatagram], asyncio.Handle] = self._pending_searches.pop()
|
||||
if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]):
|
||||
f: asyncio.Future[SSDPDatagram] = t[2]
|
||||
if f not in set_futures:
|
||||
set_futures.append(f)
|
||||
if not f.done():
|
||||
|
@ -52,38 +68,41 @@ class SSDPProtocol(MulticastProtocol):
|
|||
tmp.append(t)
|
||||
while tmp:
|
||||
self._pending_searches.append(tmp.pop())
|
||||
return None
|
||||
|
||||
def send_many_m_searches(self, address: str, packets: typing.List[SSDPDatagram]):
|
||||
def _send_m_search(self, address: str, packet: SSDPDatagram) -> None:
|
||||
dest = address if self._unicast else SSDP_IP_ADDRESS
|
||||
for packet in packets:
|
||||
log.debug("send m search to %s: %s", dest, packet.st)
|
||||
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
|
||||
if not self.transport:
|
||||
raise UPnPError("SSDP transport not connected")
|
||||
log.debug("send m search to %s: %s", dest, packet.st)
|
||||
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
|
||||
return None
|
||||
|
||||
async def m_search(self, address: str, timeout: float, datagrams: typing.List[OrderedDict]) -> SSDPDatagram:
|
||||
fut: Future = Future()
|
||||
packets: typing.List[SSDPDatagram] = []
|
||||
async def m_search(self, address: str, timeout: float,
|
||||
datagrams: typing.List[typing.Dict[str, typing.Union[str, int]]]) -> SSDPDatagram:
|
||||
fut: asyncio.Future[SSDPDatagram] = asyncio.Future(loop=self.loop)
|
||||
for datagram in datagrams:
|
||||
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram)
|
||||
packet = SSDPDatagram("M-SEARCH", datagram)
|
||||
assert packet.st is not None
|
||||
self._pending_searches.append((address, packet.st, fut))
|
||||
packets.append(packet)
|
||||
self.send_many_m_searches(address, packets),
|
||||
return await fut
|
||||
self._pending_searches.append(
|
||||
(address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet))
|
||||
)
|
||||
return await asyncio.wait_for(fut, timeout)
|
||||
|
||||
def datagram_received(self, data, addr) -> None:
|
||||
def datagram_received(self, data: bytes, addr: typing.Tuple[str, int]) -> None: # type: ignore
|
||||
if addr[0] == self.bind_address:
|
||||
return
|
||||
return None
|
||||
try:
|
||||
packet = SSDPDatagram.decode(data)
|
||||
log.debug("decoded packet from %s:%i: %s", addr[0], addr[1], packet)
|
||||
except UPnPError as err:
|
||||
log.error("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err,
|
||||
binascii.hexlify(data))
|
||||
return
|
||||
return None
|
||||
|
||||
if packet._packet_type == packet._OK:
|
||||
self._callback_m_search_ok(addr[0], packet)
|
||||
return
|
||||
return None
|
||||
# elif packet._packet_type == packet._NOTIFY:
|
||||
# log.debug("%s:%i sent us a notification: %s", packet)
|
||||
# if packet.nt == SSDP_ROOT_DEVICE:
|
||||
|
@ -104,17 +123,18 @@ class SSDPProtocol(MulticastProtocol):
|
|||
# return
|
||||
|
||||
|
||||
async def listen_ssdp(lan_address: str, gateway_address: str, loop=None,
|
||||
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[DatagramTransport,
|
||||
SSDPProtocol, str, str]:
|
||||
loop = loop or asyncio.get_event_loop_policy().get_event_loop()
|
||||
async def listen_ssdp(lan_address: str, gateway_address: str, loop: typing.Optional[asyncio.AbstractEventLoop] = None,
|
||||
ignored: typing.Optional[typing.Set[str]] = None,
|
||||
unicast: bool = False) -> typing.Tuple[SSDPProtocol, str, str]:
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
try:
|
||||
sock = SSDPProtocol.create_multicast_socket(lan_address)
|
||||
listen_result: typing.Tuple = await loop.create_datagram_endpoint(
|
||||
sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address)
|
||||
listen_result: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
|
||||
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
|
||||
)
|
||||
transport: DatagramTransport = listen_result[0]
|
||||
protocol: SSDPProtocol = listen_result[1]
|
||||
transport = listen_result[0]
|
||||
protocol = listen_result[1]
|
||||
assert isinstance(protocol, SSDPProtocol)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise UPnPError(err)
|
||||
|
@ -125,30 +145,31 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop=None,
|
|||
protocol.disconnect()
|
||||
raise UPnPError(err)
|
||||
|
||||
return transport, protocol, gateway_address, lan_address
|
||||
return protocol, gateway_address, lan_address
|
||||
|
||||
|
||||
async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1,
|
||||
loop=None, ignored: typing.Set[str] = None,
|
||||
unicast: bool = False) -> SSDPDatagram:
|
||||
transport, protocol, gateway_address, lan_address = await listen_ssdp(
|
||||
async def m_search(lan_address: str, gateway_address: str, datagram_args: typing.Dict[str, typing.Union[int, str]],
|
||||
timeout: int = 1, loop: typing.Optional[asyncio.AbstractEventLoop] = None,
|
||||
ignored: typing.Set[str] = None, unicast: bool = False) -> SSDPDatagram:
|
||||
protocol, gateway_address, lan_address = await listen_ssdp(
|
||||
lan_address, gateway_address, loop, ignored, unicast
|
||||
)
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]), timeout
|
||||
)
|
||||
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
|
||||
finally:
|
||||
protocol.disconnect()
|
||||
|
||||
|
||||
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, loop=None,
|
||||
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.List[OrderedDict]:
|
||||
transport, protocol, gateway_address, lan_address = await listen_ssdp(
|
||||
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
|
||||
ignored: typing.Set[str] = None,
|
||||
unicast: bool = False) -> typing.List[typing.Dict[str, typing.Union[int, str]]]:
|
||||
protocol, gateway_address, lan_address = await listen_ssdp(
|
||||
lan_address, gateway_address, loop, ignored, unicast
|
||||
)
|
||||
await protocol.connected.wait()
|
||||
packet_args = list(packet_generator())
|
||||
batch_size = 2
|
||||
batch_timeout = float(timeout) / float(len(packet_args))
|
||||
|
@ -157,7 +178,7 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
|
|||
packet_args = packet_args[batch_size:]
|
||||
log.debug("sending batch of %i M-SEARCH attempts", batch_size)
|
||||
try:
|
||||
await asyncio.wait_for(protocol.m_search(gateway_address, batch_timeout, args), batch_timeout)
|
||||
await protocol.m_search(gateway_address, batch_timeout, args)
|
||||
protocol.disconnect()
|
||||
return args
|
||||
except asyncio.TimeoutError:
|
||||
|
@ -166,9 +187,11 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
|
|||
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
|
||||
|
||||
|
||||
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, loop=None,
|
||||
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[OrderedDict,
|
||||
SSDPDatagram]:
|
||||
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
|
||||
ignored: typing.Set[str] = None,
|
||||
unicast: bool = False) -> typing.Tuple[typing.Dict[str,
|
||||
typing.Union[int, str]], SSDPDatagram]:
|
||||
# we don't know which packet the gateway replies to, so send small batches at a time
|
||||
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast)
|
||||
# check the args in the batch that got a reply one at a time to see which one worked
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import re
|
||||
from typing import Dict
|
||||
from xml.etree import ElementTree
|
||||
from aioupnp.constants import XML_VERSION, DEVICE, ROOT
|
||||
from aioupnp.util import etree_to_dict, flatten_keys
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from aioupnp.fault import UPnPError
|
||||
from aioupnp.constants import XML_VERSION
|
||||
from aioupnp.serialization.xml import xml_to_dict
|
||||
from aioupnp.util import flatten_keys
|
||||
|
||||
|
||||
CONTENT_PATTERN = re.compile(
|
||||
|
@ -28,34 +29,38 @@ def serialize_scpd_get(path: str, address: str) -> bytes:
|
|||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
return (
|
||||
(
|
||||
'GET %s HTTP/1.1\r\n'
|
||||
'Accept-Encoding: gzip\r\n'
|
||||
'Host: %s\r\n'
|
||||
'Connection: Close\r\n'
|
||||
'\r\n'
|
||||
) % (path, host)
|
||||
f'GET {path} HTTP/1.1\r\n'
|
||||
f'Accept-Encoding: gzip\r\n'
|
||||
f'Host: {host}\r\n'
|
||||
f'Connection: Close\r\n'
|
||||
f'\r\n'
|
||||
).encode()
|
||||
|
||||
|
||||
def deserialize_scpd_get_response(content: bytes) -> Dict:
|
||||
def deserialize_scpd_get_response(content: bytes) -> Dict[str, Any]:
|
||||
if XML_VERSION.encode() in content:
|
||||
parsed = CONTENT_PATTERN.findall(content)
|
||||
content = b'' if not parsed else parsed[0][0]
|
||||
xml_dict = etree_to_dict(ElementTree.fromstring(content.decode()))
|
||||
parsed: List[Tuple[bytes, bytes]] = CONTENT_PATTERN.findall(content)
|
||||
xml_dict = xml_to_dict((b'' if not parsed else parsed[0][0]).decode())
|
||||
return parse_device_dict(xml_dict)
|
||||
return {}
|
||||
|
||||
|
||||
def parse_device_dict(xml_dict: dict) -> Dict:
|
||||
def parse_device_dict(xml_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
keys = list(xml_dict.keys())
|
||||
found = False
|
||||
for k in keys:
|
||||
m = XML_ROOT_SANITY_PATTERN.findall(k)
|
||||
m: List[Tuple[str, str, str, str, str, str]] = XML_ROOT_SANITY_PATTERN.findall(k)
|
||||
if len(m) == 3 and m[1][0] and m[2][5]:
|
||||
schema_key = m[1][0]
|
||||
root = m[2][5]
|
||||
xml_dict = flatten_keys(xml_dict, "{%s}" % schema_key)[root]
|
||||
schema_key: str = m[1][0]
|
||||
root: str = m[2][5]
|
||||
flattened = flatten_keys(xml_dict, "{%s}" % schema_key)
|
||||
if root not in flattened:
|
||||
raise UPnPError("root device not found")
|
||||
xml_dict = flattened[root]
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
raise UPnPError("device not found")
|
||||
result = {}
|
||||
for k, v in xml_dict.items():
|
||||
if isinstance(xml_dict[k], dict):
|
||||
|
@ -65,10 +70,9 @@ def parse_device_dict(xml_dict: dict) -> Dict:
|
|||
if len(parsed_k) == 2:
|
||||
inner_d[parsed_k[0]] = inner_v
|
||||
else:
|
||||
assert len(parsed_k) == 3
|
||||
assert len(parsed_k) == 3, f"expected len=3, got {len(parsed_k)}"
|
||||
inner_d[parsed_k[1]] = inner_v
|
||||
result[k] = inner_d
|
||||
else:
|
||||
result[k] = v
|
||||
|
||||
return result
|
||||
|
|
|
@ -1,64 +1,65 @@
|
|||
import re
|
||||
from xml.etree import ElementTree
|
||||
from aioupnp.util import etree_to_dict, flatten_keys
|
||||
from aioupnp.fault import handle_fault, UPnPError
|
||||
from aioupnp.constants import XML_VERSION, ENVELOPE, BODY
|
||||
import typing
|
||||
from aioupnp.util import flatten_keys
|
||||
from aioupnp.fault import UPnPError
|
||||
from aioupnp.constants import XML_VERSION, ENVELOPE, BODY, FAULT, CONTROL
|
||||
from aioupnp.serialization.xml import xml_to_dict
|
||||
|
||||
CONTENT_NO_XML_VERSION_PATTERN = re.compile(
|
||||
"(\<s\:Envelope xmlns\:s=\"http\:\/\/schemas\.xmlsoap\.org\/soap\/envelope\/\"(\s*.)*\>)".encode()
|
||||
)
|
||||
|
||||
|
||||
def serialize_soap_post(method: str, param_names: list, service_id: bytes, gateway_address: bytes,
|
||||
control_url: bytes, **kwargs) -> bytes:
|
||||
args = "".join("<%s>%s</%s>" % (n, kwargs.get(n), n) for n in param_names)
|
||||
soap_body = ('\r\n%s\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" '
|
||||
's:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>'
|
||||
'<u:%s xmlns:u="%s">%s</u:%s></s:Body></s:Envelope>' % (
|
||||
XML_VERSION, method, service_id.decode(),
|
||||
args, method))
|
||||
def serialize_soap_post(method: str, param_names: typing.List[str], service_id: bytes, gateway_address: bytes,
|
||||
control_url: bytes, **kwargs: typing.Dict[str, str]) -> bytes:
|
||||
args = "".join(f"<{n}>{kwargs.get(n)}</{n}>" for n in param_names)
|
||||
soap_body = (f'\r\n{XML_VERSION}\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" '
|
||||
f's:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>'
|
||||
f'<u:{method} xmlns:u="{service_id.decode()}">{args}</u:{method}></s:Body></s:Envelope>')
|
||||
if "http://" in gateway_address.decode():
|
||||
host = gateway_address.decode().split("http://")[1]
|
||||
else:
|
||||
host = gateway_address.decode()
|
||||
return (
|
||||
(
|
||||
'POST %s HTTP/1.1\r\n'
|
||||
'Host: %s\r\n'
|
||||
'User-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n'
|
||||
'Content-Length: %i\r\n'
|
||||
'Content-Type: text/xml\r\n'
|
||||
'SOAPAction: \"%s#%s\"\r\n'
|
||||
'Connection: Close\r\n'
|
||||
'Cache-Control: no-cache\r\n'
|
||||
'Pragma: no-cache\r\n'
|
||||
'%s'
|
||||
'\r\n'
|
||||
) % (
|
||||
control_url.decode(), # could be just / even if it shouldn't be
|
||||
host,
|
||||
len(soap_body),
|
||||
service_id.decode(), # maybe no quotes
|
||||
method,
|
||||
soap_body
|
||||
)
|
||||
f'POST {control_url.decode()} HTTP/1.1\r\n' # could be just / even if it shouldn't be
|
||||
f'Host: {host}\r\n'
|
||||
f'User-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n'
|
||||
f'Content-Length: {len(soap_body)}\r\n'
|
||||
f'Content-Type: text/xml\r\n'
|
||||
f'SOAPAction: \"{service_id.decode()}#{method}\"\r\n'
|
||||
f'Connection: Close\r\n'
|
||||
f'Cache-Control: no-cache\r\n'
|
||||
f'Pragma: no-cache\r\n'
|
||||
f'{soap_body}'
|
||||
f'\r\n'
|
||||
).encode()
|
||||
|
||||
|
||||
def deserialize_soap_post_response(response: bytes, method: str, service_id: str) -> dict:
|
||||
parsed = CONTENT_NO_XML_VERSION_PATTERN.findall(response)
|
||||
def deserialize_soap_post_response(response: bytes, method: str,
|
||||
service_id: str) -> typing.Dict[str, typing.Dict[str, str]]:
|
||||
parsed: typing.List[typing.List[bytes]] = CONTENT_NO_XML_VERSION_PATTERN.findall(response)
|
||||
content = b'' if not parsed else parsed[0][0]
|
||||
content_dict = etree_to_dict(ElementTree.fromstring(content.decode()))
|
||||
content_dict = xml_to_dict(content.decode())
|
||||
envelope = content_dict[ENVELOPE]
|
||||
response_body = flatten_keys(envelope[BODY], "{%s}" % service_id)
|
||||
body = handle_fault(response_body) # raises UPnPError if there is a fault
|
||||
if not isinstance(envelope[BODY], dict):
|
||||
# raise UPnPError('blank response')
|
||||
return {} # TODO: raise
|
||||
response_body: typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]] = flatten_keys(
|
||||
envelope[BODY], f"{'{' + service_id + '}'}"
|
||||
)
|
||||
if not response_body:
|
||||
# raise UPnPError('blank response')
|
||||
return {} # TODO: raise
|
||||
if FAULT in response_body:
|
||||
fault: typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]] = flatten_keys(
|
||||
response_body[FAULT], "{%s}" % CONTROL
|
||||
)
|
||||
raise UPnPError(fault['detail']['UPnPError']['errorDescription'])
|
||||
response_key = None
|
||||
if not body:
|
||||
return {}
|
||||
for key in body:
|
||||
for key in response_body:
|
||||
if method in key:
|
||||
response_key = key
|
||||
break
|
||||
if not response_key:
|
||||
raise UPnPError("unknown response fields for %s: %s" % (method, body))
|
||||
return body[response_key]
|
||||
raise UPnPError(f"unknown response fields for {method}: {response_body}")
|
||||
return response_body[response_key]
|
||||
|
|
|
@ -3,43 +3,67 @@ import logging
|
|||
import binascii
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
from typing import List, Optional, Dict, Union, Tuple, Callable
|
||||
from aioupnp.fault import UPnPError
|
||||
from aioupnp.constants import line_separator
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_template = "(?i)^(%s):[ ]*(.*)$"
|
||||
|
||||
|
||||
ssdp_datagram_patterns = {
|
||||
'host': (re.compile("(?i)^(host):(.*)$"), str),
|
||||
'st': (re.compile(_template % 'st'), str),
|
||||
'man': (re.compile(_template % 'man'), str),
|
||||
'mx': (re.compile(_template % 'mx'), int),
|
||||
'nt': (re.compile(_template % 'nt'), str),
|
||||
'nts': (re.compile(_template % 'nts'), str),
|
||||
'usn': (re.compile(_template % 'usn'), str),
|
||||
'location': (re.compile(_template % 'location'), str),
|
||||
'cache_control': (re.compile(_template % 'cache[-|_]control'), str),
|
||||
'server': (re.compile(_template % 'server'), str),
|
||||
}
|
||||
|
||||
vendor_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$")
|
||||
|
||||
|
||||
class SSDPDatagram(object):
|
||||
def match_vendor(line: str) -> Optional[Tuple[str, str]]:
|
||||
match: List[Tuple[str, str, str]] = vendor_pattern.findall(line)
|
||||
if match:
|
||||
vendor_key: str = match[-1][0].lstrip(" ").rstrip(" ")
|
||||
vendor_value: str = match[-1][2].lstrip(" ").rstrip(" ")
|
||||
return vendor_key, vendor_value
|
||||
return None
|
||||
|
||||
|
||||
def compile_find(pattern: str) -> Callable[[str], Optional[str]]:
|
||||
p = re.compile(pattern)
|
||||
|
||||
def find(line: str) -> Optional[str]:
|
||||
result: List[List[str]] = []
|
||||
for outer in p.findall(line):
|
||||
result.append([])
|
||||
for inner in outer:
|
||||
result[-1].append(inner)
|
||||
if result:
|
||||
return result[-1][-1].lstrip(" ").rstrip(" ")
|
||||
return None
|
||||
|
||||
return find
|
||||
|
||||
|
||||
ssdp_datagram_patterns: Dict[str, Callable[[str], Optional[str]]] = {
|
||||
'host': compile_find("(?i)^(host):(.*)$"),
|
||||
'st': compile_find(_template % 'st'),
|
||||
'man': compile_find(_template % 'man'),
|
||||
'mx': compile_find(_template % 'mx'),
|
||||
'nt': compile_find(_template % 'nt'),
|
||||
'nts': compile_find(_template % 'nts'),
|
||||
'usn': compile_find(_template % 'usn'),
|
||||
'location': compile_find(_template % 'location'),
|
||||
'cache_control': compile_find(_template % 'cache[-|_]control'),
|
||||
'server': compile_find(_template % 'server'),
|
||||
}
|
||||
|
||||
|
||||
class SSDPDatagram:
|
||||
_M_SEARCH = "M-SEARCH"
|
||||
_NOTIFY = "NOTIFY"
|
||||
_OK = "OK"
|
||||
|
||||
_start_lines = {
|
||||
_start_lines: Dict[str, str] = {
|
||||
_M_SEARCH: "M-SEARCH * HTTP/1.1",
|
||||
_NOTIFY: "NOTIFY * HTTP/1.1",
|
||||
_OK: "HTTP/1.1 200 OK"
|
||||
}
|
||||
|
||||
_friendly_names = {
|
||||
_friendly_names: Dict[str, str] = {
|
||||
_M_SEARCH: "m-search",
|
||||
_NOTIFY: "notify",
|
||||
_OK: "m-search response"
|
||||
|
@ -47,9 +71,7 @@ class SSDPDatagram(object):
|
|||
|
||||
_vendor_field_pattern = vendor_pattern
|
||||
|
||||
_patterns = ssdp_datagram_patterns
|
||||
|
||||
_required_fields = {
|
||||
_required_fields: Dict[str, List[str]] = {
|
||||
_M_SEARCH: [
|
||||
'host',
|
||||
'man',
|
||||
|
@ -75,137 +97,137 @@ class SSDPDatagram(object):
|
|||
]
|
||||
}
|
||||
|
||||
def __init__(self, packet_type, kwargs: OrderedDict = None) -> None:
|
||||
def __init__(self, packet_type: str, kwargs: Optional[Dict[str, Union[str, int]]] = None) -> None:
|
||||
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
|
||||
raise UPnPError("unknown packet type: {}".format(packet_type))
|
||||
self._packet_type = packet_type
|
||||
kwargs = kwargs or OrderedDict()
|
||||
self._field_order: list = [
|
||||
k.lower().replace("-", "_") for k in kwargs.keys()
|
||||
kw: Dict[str, Union[str, int]] = kwargs or OrderedDict()
|
||||
self._field_order: List[str] = [
|
||||
k.lower().replace("-", "_") for k in kw.keys()
|
||||
]
|
||||
self.host = None
|
||||
self.man = None
|
||||
self.mx = None
|
||||
self.st = None
|
||||
self.nt = None
|
||||
self.nts = None
|
||||
self.usn = None
|
||||
self.location = None
|
||||
self.cache_control = None
|
||||
self.server = None
|
||||
self.date = None
|
||||
self.ext = None
|
||||
for k, v in kwargs.items():
|
||||
self.host: Optional[str] = None
|
||||
self.man: Optional[str] = None
|
||||
self.mx: Optional[Union[str, int]] = None
|
||||
self.st: Optional[str] = None
|
||||
self.nt: Optional[str] = None
|
||||
self.nts: Optional[str] = None
|
||||
self.usn: Optional[str] = None
|
||||
self.location: Optional[str] = None
|
||||
self.cache_control: Optional[str] = None
|
||||
self.server: Optional[str] = None
|
||||
self.date: Optional[str] = None
|
||||
self.ext: Optional[str] = None
|
||||
for k, v in kw.items():
|
||||
normalized = k.lower().replace("-", "_")
|
||||
if not normalized.startswith("_") and hasattr(self, normalized) and getattr(self, normalized) is None:
|
||||
setattr(self, normalized, v)
|
||||
self._case_mappings: dict = {k.lower(): k for k in kwargs.keys()}
|
||||
if not normalized.startswith("_") and hasattr(self, normalized):
|
||||
if getattr(self, normalized, None) is None:
|
||||
setattr(self, normalized, v)
|
||||
self._case_mappings: Dict[str, str] = {k.lower(): k for k in kw.keys()}
|
||||
for k in self._required_fields[self._packet_type]:
|
||||
if getattr(self, k) is None:
|
||||
if getattr(self, k, None) is None:
|
||||
raise UPnPError("missing required field %s" % k)
|
||||
|
||||
def get_cli_igd_kwargs(self) -> str:
|
||||
fields = []
|
||||
for field in self._field_order:
|
||||
v = getattr(self, field)
|
||||
v = getattr(self, field, None)
|
||||
if v is None:
|
||||
raise UPnPError("missing required field %s" % field)
|
||||
fields.append("--%s=%s" % (self._case_mappings.get(field, field), v))
|
||||
return " ".join(fields)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.as_json()
|
||||
|
||||
def __getitem__(self, item):
|
||||
def __getitem__(self, item: str) -> Union[str, int]:
|
||||
for i in self._required_fields[self._packet_type]:
|
||||
if i.lower() == item.lower():
|
||||
return getattr(self, i)
|
||||
raise KeyError(item)
|
||||
|
||||
def get_friendly_name(self) -> str:
|
||||
return self._friendly_names[self._packet_type]
|
||||
|
||||
def encode(self, trailing_newlines: int = 2) -> str:
|
||||
lines = [self._start_lines[self._packet_type]]
|
||||
for attr_name in self._field_order:
|
||||
if attr_name not in self._required_fields[self._packet_type]:
|
||||
continue
|
||||
attr = getattr(self, attr_name)
|
||||
if attr is None:
|
||||
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
|
||||
if attr_name == 'mx':
|
||||
value = str(attr)
|
||||
else:
|
||||
value = attr
|
||||
lines.append("{}: {}".format(self._case_mappings.get(attr_name.lower(), attr_name.upper()), value))
|
||||
lines.extend(
|
||||
f"{self._case_mappings.get(attr_name.lower(), attr_name.upper())}: {str(getattr(self, attr_name))}"
|
||||
for attr_name in self._field_order if attr_name in self._required_fields[self._packet_type]
|
||||
)
|
||||
serialized = line_separator.join(lines)
|
||||
for _ in range(trailing_newlines):
|
||||
serialized += line_separator
|
||||
return serialized
|
||||
|
||||
def as_dict(self) -> OrderedDict:
|
||||
def as_dict(self) -> Dict[str, Union[str, int]]:
|
||||
return self._lines_to_content_dict(self.encode().split(line_separator))
|
||||
|
||||
def as_json(self) -> str:
|
||||
return json.dumps(self.as_dict(), indent=2)
|
||||
|
||||
@classmethod
|
||||
def decode(cls, datagram: bytes):
|
||||
def decode(cls, datagram: bytes) -> 'SSDPDatagram':
|
||||
packet = cls._from_string(datagram.decode())
|
||||
if packet is None:
|
||||
raise UPnPError(
|
||||
"failed to decode datagram: {}".format(binascii.hexlify(datagram))
|
||||
)
|
||||
for attr_name in packet._required_fields[packet._packet_type]:
|
||||
attr = getattr(packet, attr_name)
|
||||
if attr is None:
|
||||
if getattr(packet, attr_name, None) is None:
|
||||
raise UPnPError(
|
||||
"required field for {} is missing from m-search response: {}".format(packet._packet_type, attr_name)
|
||||
)
|
||||
return packet
|
||||
|
||||
@classmethod
|
||||
def _lines_to_content_dict(cls, lines: list) -> OrderedDict:
|
||||
result: OrderedDict = OrderedDict()
|
||||
def _lines_to_content_dict(cls, lines: List[str]) -> Dict[str, Union[str, int]]:
|
||||
result: Dict[str, Union[str, int]] = OrderedDict()
|
||||
matched_keys: List[str] = []
|
||||
for line in lines:
|
||||
if not line:
|
||||
continue
|
||||
matched = False
|
||||
for name, (pattern, field_type) in cls._patterns.items():
|
||||
if name not in result and pattern.findall(line):
|
||||
match = pattern.findall(line)[-1][-1]
|
||||
result[line[:len(name)]] = field_type(match.lstrip(" ").rstrip(" "))
|
||||
matched = True
|
||||
break
|
||||
|
||||
for name, pattern in ssdp_datagram_patterns.items():
|
||||
if name not in matched_keys:
|
||||
if name.lower() == 'mx':
|
||||
_matched_int = pattern(line)
|
||||
if _matched_int is not None:
|
||||
match_int = int(_matched_int)
|
||||
result[line[:len(name)]] = match_int
|
||||
matched = True
|
||||
matched_keys.append(name)
|
||||
break
|
||||
else:
|
||||
match = pattern(line)
|
||||
if match is not None:
|
||||
result[line[:len(name)]] = match
|
||||
matched = True
|
||||
matched_keys.append(name)
|
||||
break
|
||||
if not matched:
|
||||
if cls._vendor_field_pattern.findall(line):
|
||||
match = cls._vendor_field_pattern.findall(line)[-1]
|
||||
vendor_key = match[0].lstrip(" ").rstrip(" ")
|
||||
# vendor_domain = match[1].lstrip(" ").rstrip(" ")
|
||||
value = match[2].lstrip(" ").rstrip(" ")
|
||||
if vendor_key not in result:
|
||||
result[vendor_key] = value
|
||||
matched_vendor = match_vendor(line)
|
||||
if matched_vendor and matched_vendor[0] not in result:
|
||||
result[matched_vendor[0]] = matched_vendor[1]
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _from_string(cls, datagram: str):
|
||||
def _from_string(cls, datagram: str) -> Optional['SSDPDatagram']:
|
||||
lines = [l for l in datagram.split(line_separator) if l]
|
||||
if not lines:
|
||||
return
|
||||
return None
|
||||
if lines[0] == cls._start_lines[cls._M_SEARCH]:
|
||||
return cls._from_request(lines[1:])
|
||||
if lines[0] in [cls._start_lines[cls._NOTIFY], cls._start_lines[cls._NOTIFY] + " "]:
|
||||
return cls._from_notify(lines[1:])
|
||||
if lines[0] == cls._start_lines[cls._OK]:
|
||||
return cls._from_response(lines[1:])
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _from_response(cls, lines: List):
|
||||
def _from_response(cls, lines: List) -> 'SSDPDatagram':
|
||||
return cls(cls._OK, cls._lines_to_content_dict(lines))
|
||||
|
||||
@classmethod
|
||||
def _from_notify(cls, lines: List):
|
||||
def _from_notify(cls, lines: List) -> 'SSDPDatagram':
|
||||
return cls(cls._NOTIFY, cls._lines_to_content_dict(lines))
|
||||
|
||||
@classmethod
|
||||
def _from_request(cls, lines: List):
|
||||
def _from_request(cls, lines: List) -> 'SSDPDatagram':
|
||||
return cls(cls._M_SEARCH, cls._lines_to_content_dict(lines))
|
||||
|
|
81
aioupnp/serialization/xml.py
Normal file
81
aioupnp/serialization/xml.py
Normal file
|
@ -0,0 +1,81 @@
|
|||
import typing
|
||||
from collections import OrderedDict
|
||||
from defusedxml import ElementTree
|
||||
|
||||
str_any_dict = typing.Dict[str, typing.Any]
|
||||
|
||||
|
||||
def parse_xml(xml_str: str) -> ElementTree:
|
||||
element: ElementTree = ElementTree.fromstring(xml_str)
|
||||
return element
|
||||
|
||||
|
||||
def _element_text(element_tree: ElementTree) -> typing.Optional[str]:
|
||||
# if element_tree.attrib:
|
||||
# element: typing.Dict[str, str] = OrderedDict()
|
||||
# for k, v in element_tree.attrib.items():
|
||||
# element['@' + k] = v
|
||||
# if not element_tree.text:
|
||||
# return element
|
||||
# element['#text'] = element_tree.text.strip()
|
||||
# return element
|
||||
if element_tree.text:
|
||||
return element_tree.text.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _get_element_children(element_tree: ElementTree) -> typing.Dict[str, typing.Union[
|
||||
str, typing.List[typing.Union[typing.Dict[str, str], str]]]]:
|
||||
element_children = _get_child_dicts(element_tree)
|
||||
element: typing.Dict[str, typing.Any] = OrderedDict()
|
||||
keys: typing.List[str] = list(element_children.keys())
|
||||
for k in keys:
|
||||
v: typing.Union[str, typing.List[typing.Any], typing.Dict[str, typing.Any]] = element_children[k]
|
||||
if len(v) == 1 and isinstance(v, list):
|
||||
l: typing.List[typing.Union[typing.Dict[str, str], str]] = v
|
||||
element[k] = l[0]
|
||||
else:
|
||||
element[k] = v
|
||||
return element
|
||||
|
||||
|
||||
def _get_child_dicts(element: ElementTree) -> typing.Dict[str, typing.List[typing.Union[typing.Dict[str, str], str]]]:
|
||||
children_dicts: typing.Dict[str, typing.List[typing.Union[typing.Dict[str, str], str]]] = OrderedDict()
|
||||
children: typing.List[ElementTree] = list(element)
|
||||
for child in children:
|
||||
child_dict = _recursive_element_to_dict(child)
|
||||
child_keys: typing.List[str] = list(child_dict.keys())
|
||||
for k in child_keys:
|
||||
assert k in child_dict
|
||||
v: typing.Union[typing.Dict[str, str], str] = child_dict[k]
|
||||
if k not in children_dicts.keys():
|
||||
new_item = [v]
|
||||
children_dicts[k] = new_item
|
||||
else:
|
||||
sublist = children_dicts[k]
|
||||
assert isinstance(sublist, list)
|
||||
sublist.append(v)
|
||||
return children_dicts
|
||||
|
||||
|
||||
def _recursive_element_to_dict(element_tree: ElementTree) -> typing.Dict[str, typing.Any]:
|
||||
if len(element_tree):
|
||||
element_result: typing.Dict[str, typing.Dict[str, typing.Union[
|
||||
str, typing.List[typing.Union[str, typing.Dict[str, typing.Any]]]]]] = OrderedDict()
|
||||
children_element = _get_element_children(element_tree)
|
||||
if element_tree.tag is not None:
|
||||
element_result[element_tree.tag] = children_element
|
||||
return element_result
|
||||
else:
|
||||
element_text = _element_text(element_tree)
|
||||
if element_text is not None:
|
||||
base_element_result: typing.Dict[str, typing.Any] = OrderedDict()
|
||||
if element_tree.tag is not None:
|
||||
base_element_result[element_tree.tag] = element_text
|
||||
return base_element_result
|
||||
null_result: typing.Dict[str, str] = OrderedDict()
|
||||
return null_result
|
||||
|
||||
|
||||
def xml_to_dict(xml_str: str) -> typing.Dict[str, typing.Any]:
|
||||
return _recursive_element_to_dict(parse_xml(xml_str))
|
|
@ -5,12 +5,11 @@ import asyncio
|
|||
import zlib
|
||||
import base64
|
||||
from collections import OrderedDict
|
||||
from typing import Tuple, Dict, List, Union
|
||||
from typing import Tuple, Dict, List, Union, Optional, Callable
|
||||
from aioupnp.fault import UPnPError
|
||||
from aioupnp.gateway import Gateway
|
||||
from aioupnp.util import get_gateway_and_lan_addresses
|
||||
from aioupnp.interfaces import get_gateway_and_lan_addresses
|
||||
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
|
||||
from aioupnp.commands import SOAPCommand
|
||||
from aioupnp.serialization.ssdp import SSDPDatagram
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -35,6 +34,10 @@ class UPnP:
|
|||
self.gateway_address = gateway_address
|
||||
self.gateway = gateway
|
||||
|
||||
@classmethod
|
||||
def get_annotations(cls, command: str) -> Dict[str, type]:
|
||||
return getattr(Gateway.commands, command).__annotations__
|
||||
|
||||
@classmethod
|
||||
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
|
||||
interface_name: str = 'default') -> Tuple[str, str]:
|
||||
|
@ -59,8 +62,9 @@ class UPnP:
|
|||
@classmethod
|
||||
@cli
|
||||
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
|
||||
igd_args: OrderedDict = None, unicast: bool = True, interface_name: str = 'default',
|
||||
loop=None) -> Dict:
|
||||
igd_args: Optional[Dict[str, Union[int, str]]] = None,
|
||||
unicast: bool = True, interface_name: str = 'default',
|
||||
loop=None) -> Dict[str, Union[str, Dict[str, Union[int, str]]]]:
|
||||
if not lan_address or not gateway_address:
|
||||
try:
|
||||
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
|
||||
|
@ -97,13 +101,13 @@ class UPnP:
|
|||
async def get_port_mapping_by_index(self, index: int) -> Dict:
|
||||
result = await self._get_port_mapping_by_index(index)
|
||||
if result:
|
||||
if isinstance(self.gateway.commands.GetGenericPortMappingEntry, SOAPCommand):
|
||||
if self.gateway.commands.is_registered('GetGenericPortMappingEntry'):
|
||||
return {
|
||||
k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result)
|
||||
}
|
||||
return {}
|
||||
|
||||
async def _get_port_mapping_by_index(self, index: int) -> Union[None, Tuple[Union[None, str], int, str,
|
||||
async def _get_port_mapping_by_index(self, index: int) -> Union[None, Tuple[Optional[str], int, str,
|
||||
int, str, bool, str, int]]:
|
||||
try:
|
||||
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
|
||||
|
@ -134,7 +138,7 @@ class UPnP:
|
|||
result = await self.gateway.commands.GetSpecificPortMappingEntry(
|
||||
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
|
||||
)
|
||||
if result and isinstance(self.gateway.commands.GetSpecificPortMappingEntry, SOAPCommand):
|
||||
if result and self.gateway.commands.is_registered('GetSpecificPortMappingEntry'):
|
||||
return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)}
|
||||
except UPnPError:
|
||||
pass
|
||||
|
@ -152,7 +156,8 @@ class UPnP:
|
|||
)
|
||||
|
||||
@cli
|
||||
async def get_next_mapping(self, port: int, protocol: str, description: str, internal_port: int=None) -> int:
|
||||
async def get_next_mapping(self, port: int, protocol: str, description: str,
|
||||
internal_port: Optional[int] = None) -> int:
|
||||
if protocol not in ["UDP", "TCP"]:
|
||||
raise UPnPError("unsupported protocol: {}".format(protocol))
|
||||
internal_port = int(internal_port or port)
|
||||
|
@ -340,8 +345,9 @@ class UPnP:
|
|||
return await self.gateway.commands.GetActiveConnections()
|
||||
|
||||
@classmethod
|
||||
def run_cli(cls, method, igd_args: OrderedDict, lan_address: str = '', gateway_address: str = '', timeout: int = 30,
|
||||
interface_name: str = 'default', unicast: bool = True, kwargs: dict = None, loop=None) -> None:
|
||||
def run_cli(cls, method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
|
||||
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
|
||||
unicast: bool = True, kwargs: Optional[Dict] = None, loop=None) -> None:
|
||||
"""
|
||||
:param method: the command name
|
||||
:param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided
|
||||
|
@ -356,7 +362,7 @@ class UPnP:
|
|||
igd_args = igd_args
|
||||
timeout = int(timeout)
|
||||
loop = loop or asyncio.get_event_loop_policy().get_event_loop()
|
||||
fut: asyncio.Future = asyncio.Future()
|
||||
fut: asyncio.Future = loop.create_future()
|
||||
|
||||
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
|
||||
|
||||
|
|
119
aioupnp/util.py
119
aioupnp/util.py
|
@ -1,91 +1,48 @@
|
|||
import re
|
||||
import socket
|
||||
from collections import defaultdict
|
||||
from typing import Tuple, Dict
|
||||
from xml.etree import ElementTree
|
||||
import netifaces
|
||||
import typing
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
|
||||
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
|
||||
str_any_dict = typing.Dict[str, typing.Any]
|
||||
|
||||
|
||||
def etree_to_dict(t: ElementTree.Element) -> Dict:
|
||||
d: dict = {}
|
||||
if t.attrib:
|
||||
d[t.tag] = {}
|
||||
children = list(t)
|
||||
if children:
|
||||
dd: dict = defaultdict(list)
|
||||
for dc in map(etree_to_dict, children):
|
||||
for k, v in dc.items():
|
||||
dd[k].append(v)
|
||||
d[t.tag] = {k: v[0] if len(v) == 1 else v for k, v in dd.items()}
|
||||
if t.attrib:
|
||||
d[t.tag].update(('@' + k, v) for k, v in t.attrib.items())
|
||||
if t.text:
|
||||
text = t.text.strip()
|
||||
if children or t.attrib:
|
||||
if text:
|
||||
d[t.tag]['#text'] = text
|
||||
else:
|
||||
d[t.tag] = text
|
||||
return d
|
||||
|
||||
|
||||
def flatten_keys(d, strip):
|
||||
if not isinstance(d, (list, dict)):
|
||||
return d
|
||||
if isinstance(d, list):
|
||||
return [flatten_keys(i, strip) for i in d]
|
||||
t = {}
|
||||
for k, v in d.items():
|
||||
def _recursive_flatten(to_flatten: typing.Any, strip: str) -> typing.Any:
|
||||
if not isinstance(to_flatten, (list, dict)):
|
||||
return to_flatten
|
||||
if isinstance(to_flatten, list):
|
||||
assert isinstance(to_flatten, list)
|
||||
return [_recursive_flatten(i, strip) for i in to_flatten]
|
||||
assert isinstance(to_flatten, dict)
|
||||
keys: typing.List[str] = list(to_flatten.keys())
|
||||
copy: str_any_dict = OrderedDict()
|
||||
for k in keys:
|
||||
item: typing.Any = to_flatten[k]
|
||||
if strip in k and strip != k:
|
||||
t[k.split(strip)[1]] = flatten_keys(v, strip)
|
||||
copy[k.split(strip)[1]] = _recursive_flatten(item, strip)
|
||||
else:
|
||||
t[k] = flatten_keys(v, strip)
|
||||
return t
|
||||
copy[k] = _recursive_flatten(item, strip)
|
||||
return copy
|
||||
|
||||
|
||||
def get_dict_val_case_insensitive(d, k):
|
||||
match = list(filter(lambda x: x.lower() == k.lower(), d.keys()))
|
||||
if not match:
|
||||
return
|
||||
def flatten_keys(to_flatten: str_any_dict, strip: str) -> str_any_dict:
|
||||
keys: typing.List[str] = list(to_flatten.keys())
|
||||
copy: str_any_dict = OrderedDict()
|
||||
for k in keys:
|
||||
item = to_flatten[k]
|
||||
if strip in k and strip != k:
|
||||
new_key: str = k.split(strip)[1]
|
||||
copy[new_key] = _recursive_flatten(item, strip)
|
||||
else:
|
||||
copy[k] = _recursive_flatten(item, strip)
|
||||
return copy
|
||||
|
||||
|
||||
def get_dict_val_case_insensitive(source: typing.Dict[typing.AnyStr, typing.AnyStr], key: typing.AnyStr) -> typing.Optional[typing.AnyStr]:
|
||||
match: typing.List[typing.AnyStr] = list(filter(lambda x: x.lower() == key.lower(), source.keys()))
|
||||
if not len(match):
|
||||
return None
|
||||
if len(match) > 1:
|
||||
raise KeyError("overlapping keys")
|
||||
return d[match[0]]
|
||||
|
||||
# import struct
|
||||
# import fcntl
|
||||
# def get_ip_address(ifname):
|
||||
# SIOCGIFADDR = 0x8915
|
||||
# s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
# return socket.inet_ntoa(fcntl.ioctl(
|
||||
# s.fileno(),
|
||||
# SIOCGIFADDR,
|
||||
# struct.pack(b'256s', ifname[:15].encode())
|
||||
# )[20:24])
|
||||
|
||||
|
||||
def get_interfaces():
|
||||
r = {
|
||||
interface_name: (router_address, netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['addr'])
|
||||
for router_address, interface_name, _ in netifaces.gateways()[socket.AF_INET]
|
||||
}
|
||||
for interface_name in netifaces.interfaces():
|
||||
if interface_name in ['lo', 'localhost'] or interface_name in r:
|
||||
continue
|
||||
addresses = netifaces.ifaddresses(interface_name)
|
||||
if netifaces.AF_INET in addresses:
|
||||
address = addresses[netifaces.AF_INET][0]['addr']
|
||||
gateway_guess = ".".join(address.split(".")[:-1] + ["1"])
|
||||
r[interface_name] = (gateway_guess, address)
|
||||
r['default'] = r[netifaces.gateways()['default'][netifaces.AF_INET][1]]
|
||||
return r
|
||||
|
||||
|
||||
def get_gateway_and_lan_addresses(interface_name: str) -> Tuple[str, str]:
|
||||
for iface_name, (gateway, lan) in get_interfaces().items():
|
||||
if interface_name == iface_name:
|
||||
return gateway, lan
|
||||
return '', ''
|
||||
if len(match) == 1:
|
||||
matched_key: typing.AnyStr = match[0]
|
||||
return source[matched_key]
|
||||
raise KeyError("overlapping keys")
|
||||
|
|
2
setup.py
2
setup.py
|
@ -37,7 +37,7 @@ setup(
|
|||
packages=find_packages(exclude=('tests',)),
|
||||
entry_points={'console_scripts': console_scripts},
|
||||
install_requires=[
|
||||
'netifaces',
|
||||
'netifaces', 'defusedxml'
|
||||
],
|
||||
extras_require={
|
||||
'test': (
|
||||
|
|
Loading…
Reference in a new issue