mypy refactor, improve coverage #12

Merged
jackrobison merged 9 commits from mypy-refactor into master 2019-05-22 09:05:10 +02:00
36 changed files with 2525 additions and 1160 deletions

4
.coveragerc Normal file
View file

@ -0,0 +1,4 @@
[run]
omit =
tests/*
stubs/*

3
.gitignore vendored
View file

@ -3,6 +3,9 @@
_trial_temp/ _trial_temp/
build/ build/
dist/ dist/
html/
index.html
mypy-html.css
.coverage .coverage
.mypy_cache/ .mypy_cache/
aioupnp.spec aioupnp.spec

442
.pylintrc Normal file
View file

@ -0,0 +1,442 @@
[MASTER]
# Specify a configuration file.
#rcfile=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS,schema
# Add files or directories matching the regex patterns to the
# blacklist. The regex matches against base names, not paths.
# `\.#.*` - add emacs tmp files to the blacklist
ignore-patterns=\.#.*
# Pickle collected data for later comparisons.
persistent=yes
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=1
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code
extension-pkg-whitelist=netifaces,
# Allow optimization of some AST trees. This will activate a peephole AST
# optimizer, which will apply various small optimizations. For instance, it can
# be used to obtain the result of joining multiple strings with the addition
# operator. Joining a lot of strings can lead to a maximum recursion error in
# Pylint and this flag can prevent that. It has one side effect, the resulting
# AST will be different than the one from reality.
optimize-ast=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then re-enable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=
anomalous-backslash-in-string,
arguments-differ,
attribute-defined-outside-init,
bad-continuation,
bare-except,
broad-except,
cell-var-from-loop,
consider-iterating-dictionary,
dangerous-default-value,
duplicate-code,
fixme,
global-statement,
inherit-non-class,
invalid-name,
len-as-condition,
locally-disabled,
logging-not-lazy,
missing-docstring,
no-else-return,
no-init,
no-member,
no-self-use,
protected-access,
redefined-builtin,
redefined-outer-name,
redefined-variable-type,
relative-import,
signature-differs,
super-init-not-called,
too-few-public-methods,
too-many-arguments,
too-many-branches,
too-many-instance-attributes,
too-many-lines,
too-many-locals,
too-many-nested-blocks,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
trailing-newlines,
undefined-loop-variable,
ungrouped-imports,
unnecessary-lambda,
unused-argument,
unused-variable,
wildcard-import,
wrong-import-order,
wrong-import-position,
deprecated-lambda,
simplifiable-if-statement,
unidiomatic-typecheck,
global-at-module-level,
inconsistent-return-statements,
keyword-arg-before-vararg,
assignment-from-no-return,
useless-return,
assignment-from-none,
stop-iteration-return,
unsubscriptable-object,
unsupported-membership-test
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text
# Put messages in a separate file for each module / package specified on the
# command line instead of printing them on stdout. Reports (if any) will be
# written in a file name "pylint_global.[txt|html]".
files-output=no
# Tells whether to display a full report or only the messages
reports=no
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
[VARIABLES]
# Tells whether we should check for unused import in __init__ files.
init-import=no
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=_$|dummy
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging
[BASIC]
# List of builtins function names that should not be used, separated by a comma
bad-functions=map,filter,input
# Good variable names which should always be accepted, separated by a comma
# allow `d` as its used frequently for deferred callback chains
good-names=i,j,k,ex,Run,_,d
# Bad variable names which should always be refused, separated by a comma
bad-names=foo,bar,baz,toto,tutu,tata
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# Regular expression matching correct function names
function-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for function names
function-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct variable names
variable-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for variable names
variable-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct constant names
const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
# Naming hint for constant names
const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
# Regular expression matching correct attribute names
attr-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for attribute names
attr-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct argument names
argument-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for argument names
argument-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct class attribute names
class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
# Naming hint for class attribute names
class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
# Regular expression matching correct inline iteration names
inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
# Naming hint for inline iteration names
inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
# Regular expression matching correct class names
class-rgx=[A-Z_][a-zA-Z0-9]+$
# Naming hint for class names
class-name-hint=[A-Z_][a-zA-Z0-9]+$
# Regular expression matching correct module names
module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
# Naming hint for module names
module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
# Regular expression matching correct method names
method-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for method names
method-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
[ELIF]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
[SPELLING]
# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=120
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,dict-separator
# Maximum number of lines in a module
max-module-lines=1000
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,XXX,TODO
[SIMILARITIES]
# Minimum lines number of a similarity.
min-similarity-lines=4
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
[TYPECHECK]
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=leveldb,distutils
# Ignoring distutils because: https://github.com/PyCQA/pylint/issues/73
# List of classes names for which member attributes should not be checked
# (useful for classes with attributes dynamically set). This supports can work
# with qualified names.
# ignored-classes=
[IMPORTS]
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,TERMIOS,Bastion,rexec,miniupnpc
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
[DESIGN]
# Maximum number of arguments for function / method
max-args=10
# Argument names that match this expression will be ignored. Default to name
# with leading underscore
ignored-argument-names=_.*
# Maximum number of locals for function / method body
max-locals=15
# Maximum number of return / yield for function / method body
max-returns=6
# Maximum number of branch for function / method body
max-branches=12
# Maximum number of statements in function / method body
max-statements=50
# Maximum number of parents for a class (see R0901).
max-parents=8
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of boolean expressions in a if statement
max-bool-expr=5
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,__new__,setUp
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,_fields,_replace,_source,_make
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=Exception

View file

@ -12,7 +12,7 @@ jobs:
- pip install -e .[test] - pip install -e .[test]
script: script:
- mypy . --txt-report . --scripts-are-modules; cat index.txt; rm index.txt - mypy aioupnp --txt-report . --scripts-are-modules; cat index.txt; rm index.txt
- &tests - &tests
stage: test stage: test

View file

@ -1,4 +1,4 @@
__version__ = "0.0.12" __version__ = "0.0.13a"
__name__ = "aioupnp" __name__ = "aioupnp"
__author__ = "Jack Robison" __author__ = "Jack Robison"
__maintainer__ = "Jack Robison" __maintainer__ = "Jack Robison"

View file

@ -1,8 +1,11 @@
import logging
import sys import sys
import asyncio
import logging
import textwrap import textwrap
import typing
from collections import OrderedDict from collections import OrderedDict
from aioupnp.upnp import UPnP from aioupnp.upnp import run_cli, UPnP
from aioupnp.commands import SOAPCommands
log = logging.getLogger("aioupnp") log = logging.getLogger("aioupnp")
handler = logging.StreamHandler() handler = logging.StreamHandler()
@ -16,17 +19,18 @@ base_usage = "\n".join(textwrap.wrap(
100, subsequent_indent=' ', break_long_words=False)) + "\n" 100, subsequent_indent=' ', break_long_words=False)) + "\n"
def get_help(command): def get_help(command: str) -> str:
fn = getattr(UPnP, command) annotations = UPnP.get_annotations(command)
params = command + " " + " ".join(["[--%s=<%s>]" % (k, k) for k in fn.__annotations__ if k != 'return']) params = command + " " + " ".join(["[--%s=<%s>]" % (k, str(v)) for k, v in annotations.items() if k != 'return'])
return base_usage + "\n".join( return base_usage + "\n".join(
textwrap.wrap(params, 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False) textwrap.wrap(params, 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False)
) )
def main(argv=None, loop=None): def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
argv = argv or sys.argv loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> int:
commands = [n for n in dir(UPnP) if hasattr(getattr(UPnP, n, None), "_cli")] argv = argv or list(sys.argv)
commands = list(SOAPCommands.SOAP_COMMANDS)
help_str = "\n".join(textwrap.wrap( help_str = "\n".join(textwrap.wrap(
" | ".join(commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False " | ".join(commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False
)) ))
@ -41,14 +45,16 @@ def main(argv=None, loop=None):
"For help with a specific command:" \ "For help with a specific command:" \
" aioupnp help <command>\n" % (base_usage, help_str) " aioupnp help <command>\n" % (base_usage, help_str)
args = argv[1:] args: typing.List[str] = [str(arg) for arg in argv[1:]]
if args[0] in ['help', '-h', '--help']: if args[0] in ['help', '-h', '--help']:
if len(args) > 1: if len(args) > 1:
if args[1] in commands: if args[1] in commands:
sys.exit(get_help(args[1])) print(get_help(args[1]))
sys.exit(print(usage)) return 0
print(usage)
return 0
defaults = { defaults: typing.Dict[str, typing.Union[bool, str, int]] = {
'debug_logging': False, 'debug_logging': False,
'interface': 'default', 'interface': 'default',
'gateway_address': '', 'gateway_address': '',
@ -57,22 +63,22 @@ def main(argv=None, loop=None):
'unicast': False 'unicast': False
} }
options = OrderedDict() options: typing.Dict[str, typing.Union[bool, str, int]] = OrderedDict()
command = None command = None
for arg in args: for arg in args:
if arg.startswith("--"): if arg.startswith("--"):
if "=" in arg: if "=" in arg:
k, v = arg.split("=") k, v = arg.split("=")
options[k.lstrip('--')] = v
else: else:
k, v = arg, True options[arg.lstrip('--')] = True
k = k.lstrip('--')
options[k] = v
else: else:
command = arg command = arg
break break
if not command: if not command:
print("no command given") print("no command given")
sys.exit(print(usage)) print(usage)
return 0
kwargs = {} kwargs = {}
for arg in args[len(options)+1:]: for arg in args[len(options)+1:]:
if arg.startswith("--"): if arg.startswith("--"):
@ -81,18 +87,24 @@ def main(argv=None, loop=None):
kwargs[k] = v kwargs[k] = v
else: else:
break break
for k, v in defaults.items(): for k in defaults:
if k not in options: if k not in options:
options[k] = v options[k] = defaults[k]
if options.pop('debug_logging'): if options.pop('debug_logging'):
log.setLevel(logging.DEBUG) log.setLevel(logging.DEBUG)
UPnP.run_cli( lan_address: str = str(options.pop('lan_address'))
command.replace('-', '_'), options, options.pop('lan_address'), options.pop('gateway_address'), gateway_address: str = str(options.pop('gateway_address'))
options.pop('timeout'), options.pop('interface'), options.pop('unicast'), kwargs, loop timeout: int = int(options.pop('timeout'))
interface: str = str(options.pop('interface'))
unicast: bool = bool(options.pop('unicast'))
run_cli(
command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop
) )
return 0
if __name__ == "__main__": if __name__ == "__main__":
main() sys.exit(main())

View file

@ -1,64 +1,44 @@
import logging import asyncio
import time import time
import typing import typing
from typing import Tuple, Union, List import logging
from typing import Tuple
from aioupnp.protocols.scpd import scpd_post from aioupnp.protocols.scpd import scpd_post
from aioupnp.device import Service
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
none_or_str = Union[None, str]
return_type_lambas = {
Union[None, str]: lambda x: x if x is not None and str(x).lower() not in ['none', 'nil'] else None
}
def safe_type(t): def soap_optional_str(x: typing.Optional[str]) -> typing.Optional[str]:
if t is typing.Tuple: return x if x is not None and str(x).lower() not in ['none', 'nil'] else None
return tuple
if t is typing.List:
return list
if t is typing.Dict:
return dict
if t is typing.Set:
return set
return t
class SOAPCommand: def soap_bool(x: typing.Optional[str]) -> bool:
def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str, return False if not x or str(x).lower() in ['false', 'False'] else True
param_types: dict, return_types: dict, param_order: list, return_order: list, loop=None) -> None:
self.gateway_address = gateway_address
self.service_port = service_port
self.control_url = control_url
self.service_id = service_id
self.method = method
self.param_types = param_types
self.param_order = param_order
self.return_types = return_types
self.return_order = return_order
self.loop = loop
self._requests: typing.List = []
async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]:
if set(kwargs.keys()) != set(self.param_types.keys()): def recast_single_result(t: type, result: typing.Any) -> typing.Optional[typing.Union[str, int, float, bool]]:
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys())) if t is bool:
soap_kwargs = {n: safe_type(self.param_types[n])(kwargs[n]) for n in self.param_types.keys()} return soap_bool(result)
response, xml_bytes, err = await scpd_post( if t is str:
self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, return soap_optional_str(result)
self.service_id, self.loop, **soap_kwargs return t(result)
)
if err is not None:
self._requests.append((soap_kwargs, xml_bytes, None, err, time.time())) def recast_return(return_annotation, result: typing.Dict[str, typing.Union[int, str]],
raise err result_keys: typing.List[str]) -> typing.Tuple:
if not response: if return_annotation is None or len(result_keys) == 0:
result = None return ()
else: if len(result_keys) == 1:
recast_result = tuple([safe_type(self.return_types[n])(response.get(n)) for n in self.return_order]) assert len(result_keys) == 1
if len(recast_result) == 1: single_result = result[result_keys[0]]
result = recast_result[0] return (recast_single_result(return_annotation, single_result), )
else: annotated_args: typing.List[type] = list(return_annotation.__args__)
result = recast_result assert len(annotated_args) == len(result_keys)
self._requests.append((soap_kwargs, xml_bytes, result, None, time.time())) recast_results: typing.List[typing.Optional[typing.Union[str, int, float, bool]]] = []
return result for type_annotation, result_key in zip(annotated_args, result_keys):
recast_results.append(recast_single_result(type_annotation, result.get(result_key, None)))
return tuple(recast_results)
class SOAPCommands: class SOAPCommands:
@ -72,179 +52,319 @@ class SOAPCommands:
to their expected types. to their expected types.
""" """
SOAP_COMMANDS = [ SOAP_COMMANDS: typing.List[str] = [
'AddPortMapping', 'AddPortMapping',
'GetNATRSIPStatus',
'GetGenericPortMappingEntry', 'GetGenericPortMappingEntry',
'GetSpecificPortMappingEntry', 'GetSpecificPortMappingEntry',
'SetConnectionType',
'GetExternalIPAddress',
'GetConnectionTypeInfo',
'GetStatusInfo',
'ForceTermination',
'DeletePortMapping', 'DeletePortMapping',
'RequestConnection', 'GetExternalIPAddress',
'GetCommonLinkProperties', # 'SetConnectionType',
'GetTotalBytesSent', # 'GetNATRSIPStatus',
'GetTotalBytesReceived', # 'GetConnectionTypeInfo',
'GetTotalPacketsSent', # 'GetStatusInfo',
'GetTotalPacketsReceived', # 'ForceTermination',
'X_GetICSStatistics', # 'RequestConnection',
'GetDefaultConnectionService', # 'GetCommonLinkProperties',
'NewDefaultConnectionService', # 'GetTotalBytesSent',
'NewEnabledForInternet', # 'GetTotalBytesReceived',
'SetDefaultConnectionService', # 'GetTotalPacketsSent',
'SetEnabledForInternet', # 'GetTotalPacketsReceived',
'GetEnabledForInternet', # 'X_GetICSStatistics',
'NewActiveConnectionIndex', # 'GetDefaultConnectionService',
'GetMaximumActiveConnections', # 'SetDefaultConnectionService',
'GetActiveConnections' # 'SetEnabledForInternet',
# 'GetEnabledForInternet',
# 'GetMaximumActiveConnections',
# 'GetActiveConnections'
] ]
def __init__(self): def __init__(self, loop: asyncio.AbstractEventLoop, base_address: bytes, port: int) -> None:
self._registered = set() self._loop = loop
self._registered: typing.Dict[Service,
typing.Dict[str, typing.Tuple[typing.List[str], typing.List[str]]]] = {}
self._wrappers_no_args: typing.Dict[str, typing.Callable[[], typing.Awaitable[typing.Any]]] = {}
self._wrappers_kwargs: typing.Dict[str, typing.Callable[..., typing.Awaitable[typing.Any]]] = {}
def register(self, base_ip: bytes, port: int, name: str, control_url: str, self._base_address = base_address
service_type: bytes, inputs: List, outputs: List, loop=None) -> None: self._port = port
if name not in self.SOAP_COMMANDS or name in self._registered: self._requests: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
raise AttributeError(name) typing.Tuple, typing.Optional[Exception], float]] = []
current = getattr(self, name)
annotations = current.__annotations__ def is_registered(self, name: str) -> bool:
return_types = annotations.get('return', None) if name not in self.SOAP_COMMANDS:
if return_types: raise ValueError("unknown command")
if hasattr(return_types, '__args__'): for service in self._registered.values():
return_types = tuple([return_type_lambas.get(a, a) for a in return_types.__args__]) if name in service:
elif isinstance(return_types, type): return True
return_types = (return_types,) return False
return_types = {r: t for r, t in zip(outputs, return_types)}
param_types = {} def get_service(self, name: str) -> Service:
for param_name, param_type in annotations.items(): if name not in self.SOAP_COMMANDS:
if param_name == "return": raise ValueError("unknown command")
continue for service, commands in self._registered.items():
param_types[param_name] = param_type if name in commands:
command = SOAPCommand( return service
base_ip.decode(), port, control_url, service_type, raise ValueError(name)
name, param_types, return_types, inputs, outputs, loop=loop
def _register_soap_wrapper(self, name: str) -> None:
annotations: typing.Dict[str, typing.Any] = typing.get_type_hints(getattr(self, name))
service = self.get_service(name)
input_names: typing.List[str] = self._registered[service][name][0]
output_names: typing.List[str] = self._registered[service][name][1]
async def wrapper(**kwargs: typing.Any) -> typing.Tuple:
assert service.controlURL is not None
assert service.serviceType is not None
response, xml_bytes, err = await scpd_post(
service.controlURL, self._base_address.decode(), self._port, name, input_names,
service.serviceType.encode(), self._loop, **kwargs
) )
setattr(command, "__doc__", current.__doc__) if err is not None:
setattr(self, command.method, command) assert isinstance(xml_bytes, bytes)
self._registered.add(command.method) self._requests.append((name, kwargs, xml_bytes, (), err, time.time()))
raise err
assert 'return' in annotations
result = recast_return(annotations['return'], response, output_names)
@staticmethod self._requests.append((name, kwargs, xml_bytes, result, None, time.time()))
async def AddPortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int, return result
if not len(list(k for k in annotations if k != 'return')):
self._wrappers_no_args[name] = wrapper
else:
self._wrappers_kwargs[name] = wrapper
return None
def register(self, name: str, service: Service, inputs: typing.List[str], outputs: typing.List[str]) -> None:
if name not in self.SOAP_COMMANDS:
raise AttributeError(name)
if self.is_registered(name):
raise AttributeError(f"{name} is already a registered SOAP command")
if service not in self._registered:
self._registered[service] = {}
self._registered[service][name] = inputs, outputs
self._register_soap_wrapper(name)
async def AddPortMapping(self, NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int,
NewInternalClient: str, NewEnabled: int, NewPortMappingDescription: str, NewInternalClient: str, NewEnabled: int, NewPortMappingDescription: str,
NewLeaseDuration: str) -> None: NewLeaseDuration: str) -> None:
"""Returns None""" """Returns None"""
name = "AddPortMapping"
if not self.is_registered(name):
raise NotImplementedError() raise NotImplementedError()
assert name in self._wrappers_kwargs
await self._wrappers_kwargs[name](
NewRemoteHost=NewRemoteHost, NewExternalPort=NewExternalPort, NewProtocol=NewProtocol,
NewInternalPort=NewInternalPort, NewInternalClient=NewInternalClient, NewEnabled=NewEnabled,
NewPortMappingDescription=NewPortMappingDescription, NewLeaseDuration=NewLeaseDuration
)
return None
@staticmethod async def GetGenericPortMappingEntry(self, NewPortMappingIndex: int) -> Tuple[str, int, str, int, str,
async def GetNATRSIPStatus() -> Tuple[bool, bool]:
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
raise NotImplementedError()
@staticmethod
async def GetGenericPortMappingEntry(NewPortMappingIndex: int) -> Tuple[str, int, str, int, str,
bool, str, int]: bool, str, int]:
""" """
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled, Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
NewPortMappingDescription, NewLeaseDuration) NewPortMappingDescription, NewLeaseDuration)
""" """
name = "GetGenericPortMappingEntry"
if not self.is_registered(name):
raise NotImplementedError() raise NotImplementedError()
assert name in self._wrappers_kwargs
result: Tuple[str, int, str, int, str, bool, str, int] = await self._wrappers_kwargs[name](
NewPortMappingIndex=NewPortMappingIndex
)
return result
@staticmethod async def GetSpecificPortMappingEntry(self, NewRemoteHost: str, NewExternalPort: int,
async def GetSpecificPortMappingEntry(NewRemoteHost: str, NewExternalPort: int,
NewProtocol: str) -> Tuple[int, str, bool, str, int]: NewProtocol: str) -> Tuple[int, str, bool, str, int]:
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)""" """Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
name = "GetSpecificPortMappingEntry"
if not self.is_registered(name):
raise NotImplementedError() raise NotImplementedError()
assert name in self._wrappers_kwargs
result: Tuple[int, str, bool, str, int] = await self._wrappers_kwargs[name](
NewRemoteHost=NewRemoteHost, NewExternalPort=NewExternalPort, NewProtocol=NewProtocol
)
return result
@staticmethod async def DeletePortMapping(self, NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> None:
async def SetConnectionType(NewConnectionType: str) -> None:
"""Returns None""" """Returns None"""
name = "DeletePortMapping"
if not self.is_registered(name):
raise NotImplementedError() raise NotImplementedError()
assert name in self._wrappers_kwargs
await self._wrappers_kwargs[name](
NewRemoteHost=NewRemoteHost, NewExternalPort=NewExternalPort, NewProtocol=NewProtocol
)
return None
@staticmethod async def GetExternalIPAddress(self) -> str:
async def GetExternalIPAddress() -> str:
"""Returns (NewExternalIPAddress)""" """Returns (NewExternalIPAddress)"""
name = "GetExternalIPAddress"
if not self.is_registered(name):
raise NotImplementedError() raise NotImplementedError()
assert name in self._wrappers_no_args
result: Tuple[str] = await self._wrappers_no_args[name]()
return result[0]
@staticmethod # async def GetNATRSIPStatus(self) -> Tuple[bool, bool]:
async def GetConnectionTypeInfo() -> Tuple[str, str]: # """Returns (NewRSIPAvailable, NewNATEnabled)"""
"""Returns (NewConnectionType, NewPossibleConnectionTypes)""" # name = "GetNATRSIPStatus"
raise NotImplementedError() # if not self.is_registered(name):
# raise NotImplementedError()
@staticmethod # assert name in self._wrappers_no_args
async def GetStatusInfo() -> Tuple[str, str, int]: # result: Tuple[bool, bool] = await self._wrappers_no_args[name]()
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)""" # return result[0], result[1]
raise NotImplementedError() #
# async def SetConnectionType(self, NewConnectionType: str) -> None:
@staticmethod # """Returns None"""
async def ForceTermination() -> None: # name = "SetConnectionType"
"""Returns None""" # if not self.is_registered(name):
raise NotImplementedError() # raise NotImplementedError()
# assert name in self._wrappers_kwargs
@staticmethod # await self._wrappers_kwargs[name](NewConnectionType=NewConnectionType)
async def DeletePortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> None: # return None
"""Returns None""" #
raise NotImplementedError() # async def GetConnectionTypeInfo(self) -> Tuple[str, str]:
# """Returns (NewConnectionType, NewPossibleConnectionTypes)"""
@staticmethod # name = "GetConnectionTypeInfo"
async def RequestConnection() -> None: # if not self.is_registered(name):
"""Returns None""" # raise NotImplementedError()
raise NotImplementedError() # assert name in self._wrappers_no_args
# result: Tuple[str, str] = await self._wrappers_no_args[name]()
@staticmethod # return result
async def GetCommonLinkProperties(): #
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)""" # async def GetStatusInfo(self) -> Tuple[str, str, int]:
raise NotImplementedError() # """Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
# name = "GetStatusInfo"
@staticmethod # if not self.is_registered(name):
async def GetTotalBytesSent(): # raise NotImplementedError()
"""Returns (NewTotalBytesSent)""" # assert name in self._wrappers_no_args
raise NotImplementedError() # result: Tuple[str, str, int] = await self._wrappers_no_args[name]()
# return result
@staticmethod #
async def GetTotalBytesReceived(): # async def ForceTermination(self) -> None:
"""Returns (NewTotalBytesReceived)""" # """Returns None"""
raise NotImplementedError() # name = "ForceTermination"
# if not self.is_registered(name):
@staticmethod # raise NotImplementedError()
async def GetTotalPacketsSent(): # assert name in self._wrappers_no_args
"""Returns (NewTotalPacketsSent)""" # await self._wrappers_no_args[name]()
raise NotImplementedError() # return None
#
@staticmethod # async def RequestConnection(self) -> None:
def GetTotalPacketsReceived(): # """Returns None"""
"""Returns (NewTotalPacketsReceived)""" # name = "RequestConnection"
raise NotImplementedError() # if not self.is_registered(name):
# raise NotImplementedError()
@staticmethod # assert name in self._wrappers_no_args
async def X_GetICSStatistics() -> Tuple[int, int, int, int, str, str]: # await self._wrappers_no_args[name]()
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)""" # return None
raise NotImplementedError() #
# async def GetCommonLinkProperties(self) -> Tuple[str, int, int, str]:
@staticmethod # """Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate,
async def GetDefaultConnectionService(): # NewPhysicalLinkStatus)"""
"""Returns (NewDefaultConnectionService)""" # name = "GetCommonLinkProperties"
raise NotImplementedError() # if not self.is_registered(name):
# raise NotImplementedError()
@staticmethod # assert name in self._wrappers_no_args
async def SetDefaultConnectionService(NewDefaultConnectionService: str) -> None: # result: Tuple[str, int, int, str] = await self._wrappers_no_args[name]()
"""Returns (None)""" # return result
raise NotImplementedError() #
# async def GetTotalBytesSent(self) -> int:
@staticmethod # """Returns (NewTotalBytesSent)"""
async def SetEnabledForInternet(NewEnabledForInternet: bool) -> None: # name = "GetTotalBytesSent"
raise NotImplementedError() # if not self.is_registered(name):
# raise NotImplementedError()
@staticmethod # assert name in self._wrappers_no_args
async def GetEnabledForInternet() -> bool: # result: Tuple[int] = await self._wrappers_no_args[name]()
raise NotImplementedError() # return result[0]
#
@staticmethod # async def GetTotalBytesReceived(self) -> int:
async def GetMaximumActiveConnections(NewActiveConnectionIndex: int): # """Returns (NewTotalBytesReceived)"""
raise NotImplementedError() # name = "GetTotalBytesReceived"
# if not self.is_registered(name):
@staticmethod # raise NotImplementedError()
async def GetActiveConnections() -> Tuple[str, str]: # assert name in self._wrappers_no_args
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID""" # result: Tuple[int] = await self._wrappers_no_args[name]()
raise NotImplementedError() # return result[0]
#
# async def GetTotalPacketsSent(self) -> int:
# """Returns (NewTotalPacketsSent)"""
# name = "GetTotalPacketsSent"
# if not self.is_registered(name):
# raise NotImplementedError()
# assert name in self._wrappers_no_args
# result: Tuple[int] = await self._wrappers_no_args[name]()
# return result[0]
#
# async def GetTotalPacketsReceived(self) -> int:
# """Returns (NewTotalPacketsReceived)"""
# name = "GetTotalPacketsReceived"
# if not self.is_registered(name):
# raise NotImplementedError()
# assert name in self._wrappers_no_args
# result: Tuple[int] = await self._wrappers_no_args[name]()
# return result[0]
#
# async def X_GetICSStatistics(self) -> Tuple[int, int, int, int, str, str]:
# """Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived,
# Layer1DownstreamMaxBitRate, Uptime)"""
# name = "X_GetICSStatistics"
# if not self.is_registered(name):
# raise NotImplementedError()
# assert name in self._wrappers_no_args
# result: Tuple[int, int, int, int, str, str] = await self._wrappers_no_args[name]()
# return result
#
# async def GetDefaultConnectionService(self) -> str:
# """Returns (NewDefaultConnectionService)"""
# name = "GetDefaultConnectionService"
# if not self.is_registered(name):
# raise NotImplementedError()
# assert name in self._wrappers_no_args
# result: Tuple[str] = await self._wrappers_no_args[name]()
# return result[0]
#
# async def SetDefaultConnectionService(self, NewDefaultConnectionService: str) -> None:
# """Returns (None)"""
# name = "SetDefaultConnectionService"
# if not self.is_registered(name):
# raise NotImplementedError()
# assert name in self._wrappers_kwargs
# await self._wrappers_kwargs[name](NewDefaultConnectionService=NewDefaultConnectionService)
# return None
#
# async def SetEnabledForInternet(self, NewEnabledForInternet: bool) -> None:
# name = "SetEnabledForInternet"
# if not self.is_registered(name):
# raise NotImplementedError()
# assert name in self._wrappers_kwargs
# await self._wrappers_kwargs[name](NewEnabledForInternet=NewEnabledForInternet)
# return None
#
# async def GetEnabledForInternet(self) -> bool:
# name = "GetEnabledForInternet"
# if not self.is_registered(name):
# raise NotImplementedError()
# assert name in self._wrappers_no_args
# result: Tuple[bool] = await self._wrappers_no_args[name]()
# return result[0]
#
# async def GetMaximumActiveConnections(self, NewActiveConnectionIndex: int) -> None:
# name = "GetMaximumActiveConnections"
# if not self.is_registered(name):
# raise NotImplementedError()
# assert name in self._wrappers_kwargs
# await self._wrappers_kwargs[name](NewActiveConnectionIndex=NewActiveConnectionIndex)
# return None
#
# async def GetActiveConnections(self) -> Tuple[str, str]:
# """Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
# name = "GetActiveConnections"
# if not self.is_registered(name):
# raise NotImplementedError()
# assert name in self._wrappers_no_args
# result: Tuple[str, str] = await self._wrappers_no_args[name]()
# return result

View file

@ -1,23 +1,34 @@
from collections import OrderedDict
import typing
import logging import logging
from typing import List
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class CaseInsensitive: class CaseInsensitive:
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any],
for k, v in kwargs.items(): typing.List[typing.Any]]]) -> None:
keys: typing.List[str] = list(kwargs.keys())
for k in keys:
if not k.startswith("_"): if not k.startswith("_"):
setattr(self, k, v) assert k in kwargs
setattr(self, k, kwargs[k])
def __getattr__(self, item): def __getattr__(self, item: str) -> typing.Union[str, typing.Dict[str, typing.Any], typing.List]:
for k in self.__class__.__dict__.keys(): keys: typing.List[str] = list(self.__class__.__dict__.keys())
for k in keys:
if k.lower() == item.lower(): if k.lower() == item.lower():
return self.__dict__.get(k) value: typing.Optional[typing.Union[str, typing.Dict[str, typing.Any],
typing.List]] = self.__dict__.get(k)
assert value is not None and isinstance(value, (str, dict, list))
return value
raise AttributeError(item) raise AttributeError(item)
def __setattr__(self, item, value): def __setattr__(self, item: str,
for k, v in self.__class__.__dict__.items(): value: typing.Union[str, typing.Dict[str, typing.Any], typing.List]) -> None:
assert isinstance(value, (str, dict)), ValueError(f"got type {str(type(value))}, expected str")
keys: typing.List[str] = list(self.__class__.__dict__.keys())
for k in keys:
if k.lower() == item.lower(): if k.lower() == item.lower():
self.__dict__[k] = value self.__dict__[k] = value
return return
@ -26,52 +37,57 @@ class CaseInsensitive:
return return
raise AttributeError(item) raise AttributeError(item)
def as_dict(self) -> dict: def as_dict(self) -> typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List]]:
return { result: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List]] = OrderedDict()
k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v) keys: typing.List[str] = list(self.__dict__.keys())
} for k in keys:
if not k.startswith("_"):
result[k] = self.__getattr__(k)
return result
class Service(CaseInsensitive): class Service(CaseInsensitive):
serviceType = None serviceType: typing.Optional[str] = None
serviceId = None serviceId: typing.Optional[str] = None
controlURL = None controlURL: typing.Optional[str] = None
eventSubURL = None eventSubURL: typing.Optional[str] = None
SCPDURL = None SCPDURL: typing.Optional[str] = None
class Device(CaseInsensitive): class Device(CaseInsensitive):
serviceList = None serviceList: typing.Optional[typing.Dict[str, typing.Union[typing.Dict[str, typing.Any], typing.List]]] = None
deviceList = None deviceList: typing.Optional[typing.Dict[str, typing.Union[typing.Dict[str, typing.Any], typing.List]]] = None
deviceType = None deviceType: typing.Optional[str] = None
friendlyName = None friendlyName: typing.Optional[str] = None
manufacturer = None manufacturer: typing.Optional[str] = None
manufacturerURL = None manufacturerURL: typing.Optional[str] = None
modelDescription = None modelDescription: typing.Optional[str] = None
modelName = None modelName: typing.Optional[str] = None
modelNumber = None modelNumber: typing.Optional[str] = None
modelURL = None modelURL: typing.Optional[str] = None
serialNumber = None serialNumber: typing.Optional[str] = None
udn = None udn: typing.Optional[str] = None
upc = None upc: typing.Optional[str] = None
presentationURL = None presentationURL: typing.Optional[str] = None
iconList = None iconList: typing.Optional[str] = None
def __init__(self, devices: List, services: List, **kwargs) -> None: def __init__(self, devices: typing.List['Device'], services: typing.List[Service],
**kwargs: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List]]) -> None:
super(Device, self).__init__(**kwargs) super(Device, self).__init__(**kwargs)
if self.serviceList and "service" in self.serviceList: if self.serviceList and "service" in self.serviceList:
new_services = self.serviceList["service"] if isinstance(self.serviceList['service'], dict):
if isinstance(new_services, dict): assert isinstance(self.serviceList['service'], dict)
new_services = [new_services] svc_list: typing.Dict[str, typing.Any] = self.serviceList['service']
services.extend([Service(**service) for service in new_services]) services.append(Service(**svc_list))
elif isinstance(self.serviceList['service'], list):
services.extend(Service(**svc) for svc in self.serviceList["service"])
if self.deviceList: if self.deviceList:
for kw in self.deviceList.values(): for kw in self.deviceList.values():
if isinstance(kw, dict): if isinstance(kw, dict):
d = Device(devices, services, **kw) devices.append(Device(devices, services, **kw))
devices.append(d)
elif isinstance(kw, list): elif isinstance(kw, list):
for _inner_kw in kw: for _inner_kw in kw:
d = Device(devices, services, **_inner_kw) devices.append(Device(devices, services, **_inner_kw))
devices.append(d)
else: else:
log.warning("failed to parse device:\n%s", kw) log.warning("failed to parse device:\n%s", kw)

View file

@ -1,14 +1,2 @@
from aioupnp.util import flatten_keys
from aioupnp.constants import FAULT, CONTROL
class UPnPError(Exception): class UPnPError(Exception):
pass pass
def handle_fault(response: dict) -> dict:
if FAULT in response:
fault = flatten_keys(response[FAULT], "{%s}" % CONTROL)
error_description = fault['detail']['UPnPError']['errorDescription']
raise UPnPError(error_description)
return response

View file

@ -1,9 +1,10 @@
import re
import logging import logging
import socket import typing
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, List, Union, Type from typing import Dict, List, Union
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX from aioupnp.util import get_dict_val_case_insensitive
from aioupnp.constants import SPEC_VERSION, SERVICE from aioupnp.constants import SPEC_VERSION, SERVICE
from aioupnp.commands import SOAPCommands from aioupnp.commands import SOAPCommands
from aioupnp.device import Device, Service from aioupnp.device import Device, Service
@ -15,77 +16,92 @@ from aioupnp.fault import UPnPError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
return_type_lambas = { BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
Union[None, str]: lambda x: x if x is not None and str(x).lower() not in ['none', 'nil'] else None BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
}
def get_action_list(element_dict: dict) -> List: # [(<method>, [<input1>, ...], [<output1, ...]), ...] def get_action_list(element_dict: typing.Dict[str, typing.Union[str, typing.Dict[str, str],
typing.List[typing.Dict[str, typing.Dict[str, str]]]]]
) -> typing.List[typing.Tuple[str, typing.List[str], typing.List[str]]]:
service_info = flatten_keys(element_dict, "{%s}" % SERVICE) service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
result: typing.List[typing.Tuple[str, typing.List[str], typing.List[str]]] = []
if "actionList" in service_info: if "actionList" in service_info:
action_list = service_info["actionList"] action_list = service_info["actionList"]
else: else:
return [] return result
if not len(action_list): # it could be an empty string if not len(action_list): # it could be an empty string
return [] return result
result: list = [] action = action_list["action"]
if isinstance(action_list["action"], dict): if isinstance(action, dict):
arg_dicts = action_list["action"]['argumentList']['argument'] arg_dicts: typing.List[typing.Dict[str, str]] = []
if not isinstance(arg_dicts, list): # when there is one arg if not isinstance(action['argumentList']['argument'], list): # when there is one arg
arg_dicts = [arg_dicts] arg_dicts.extend([action['argumentList']['argument']])
return [[
action_list["action"]['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
]]
for action in action_list["action"]:
if not action.get('argumentList'):
result.append((action['name'], [], []))
else: else:
arg_dicts = action['argumentList']['argument'] arg_dicts.extend(action['argumentList']['argument'])
if not isinstance(arg_dicts, list): # when there is one arg
arg_dicts = [arg_dicts] result.append((action_list["action"]['name'], [i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']))
return result
assert isinstance(action, list)
for _action in action:
if not _action.get('argumentList'):
result.append((_action['name'], [], []))
else:
if not isinstance(_action['argumentList']['argument'], list): # when there is one arg
arg_dicts = [_action['argumentList']['argument']]
else:
arg_dicts = _action['argumentList']['argument']
result.append(( result.append((
action['name'], _action['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'], [i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out'] [i['name'] for i in arg_dicts if i['direction'] == 'out']
)) ))
return result return result
def parse_location(location: bytes) -> typing.Tuple[bytes, int]:
base_address_result: typing.List[bytes] = BASE_ADDRESS_REGEX.findall(location)
base_address = base_address_result[0]
port_result: typing.List[bytes] = BASE_PORT_REGEX.findall(location)
port = int(port_result[0])
return base_address, port
class Gateway: class Gateway:
def __init__(self, ok_packet: SSDPDatagram, m_search_args: OrderedDict, lan_address: str, def __init__(self, ok_packet: SSDPDatagram, m_search_args: typing.Dict[str, typing.Union[int, str]],
gateway_address: str) -> None: lan_address: str, gateway_address: str,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
self._loop = loop or asyncio.get_event_loop()
self._ok_packet = ok_packet self._ok_packet = ok_packet
self._m_search_args = m_search_args self._m_search_args = m_search_args
self._lan_address = lan_address self._lan_address = lan_address
self.usn = (ok_packet.usn or '').encode() self.usn: bytes = (ok_packet.usn or '').encode()
self.ext = (ok_packet.ext or '').encode() self.ext: bytes = (ok_packet.ext or '').encode()
self.server = (ok_packet.server or '').encode() self.server: bytes = (ok_packet.server or '').encode()
self.location = (ok_packet.location or '').encode() self.location: bytes = (ok_packet.location or '').encode()
self.cache_control = (ok_packet.cache_control or '').encode() self.cache_control: bytes = (ok_packet.cache_control or '').encode()
self.date = (ok_packet.date or '').encode() self.date: bytes = (ok_packet.date or '').encode()
self.urn = (ok_packet.st or '').encode() self.urn: bytes = (ok_packet.st or '').encode()
self._xml_response = b"" self._xml_response: bytes = b""
self._service_descriptors: Dict = {} self._service_descriptors: Dict[str, bytes] = {}
self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0]
self.port = int(BASE_PORT_REGEX.findall(self.location)[0]) self.base_address, self.port = parse_location(self.location)
self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0] self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0]
assert self.base_ip == gateway_address.encode() assert self.base_ip == gateway_address.encode()
self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1] self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1]
self.spec_version = None self.spec_version: typing.Optional[str] = None
self.url_base = None self.url_base: typing.Optional[str] = None
self._device: Union[None, Device] = None self._device: typing.Optional[Device] = None
self._devices: List = [] self._devices: List[Device] = []
self._services: List = [] self._services: List[Service] = []
self._unsupported_actions: Dict = {} self._unsupported_actions: Dict[str, typing.List[str]] = {}
self._registered_commands: Dict = {} self._registered_commands: Dict[str, str] = {}
self.commands = SOAPCommands() self.commands = SOAPCommands(self._loop, self.base_ip, self.port)
def gateway_descriptor(self) -> dict: def gateway_descriptor(self) -> dict:
r = { r = {
@ -102,14 +118,15 @@ class Gateway:
def manufacturer_string(self) -> str: def manufacturer_string(self) -> str:
if not self.devices: if not self.devices:
return "UNKNOWN GATEWAY" return "UNKNOWN GATEWAY"
device = list(self.devices.values())[0] devices: typing.List[Device] = list(self.devices.values())
return "%s %s" % (device.manufacturer, device.modelName) device = devices[0]
return f"{device.manufacturer} {device.modelName}"
@property @property
def services(self) -> Dict: def services(self) -> Dict[str, Service]:
if not self._device: if not self._device:
return {} return {}
return {service.serviceType: service for service in self._services} return {str(service.serviceType): service for service in self._services}
@property @property
def devices(self) -> Dict: def devices(self) -> Dict:
@ -117,28 +134,29 @@ class Gateway:
return {} return {}
return {device.udn: device for device in self._devices} return {device.udn: device for device in self._devices}
def get_service(self, service_type: str) -> Union[Type[Service], None]: def get_service(self, service_type: str) -> typing.Optional[Service]:
for service in self._services: for service in self._services:
if service.serviceType.lower() == service_type.lower(): if service.serviceType and service.serviceType.lower() == service_type.lower():
return service return service
return None return None
@property @property
def soap_requests(self) -> List: def soap_requests(self) -> typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
soap_call_infos = [] typing.Optional[typing.Tuple],
for name in self._registered_commands.keys(): typing.Optional[Exception], float]]:
if not hasattr(getattr(self.commands, name), "_requests"): soap_call_infos: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
continue typing.Optional[typing.Tuple],
typing.Optional[Exception], float]] = []
soap_call_infos.extend([ soap_call_infos.extend([
(name, request_args, raw_response, decoded_response, soap_error, ts) (name, request_args, raw_response, decoded_response, soap_error, ts)
for ( for (
request_args, raw_response, decoded_response, soap_error, ts name, request_args, raw_response, decoded_response, soap_error, ts
) in getattr(self.commands, name)._requests ) in self.commands._requests
]) ])
soap_call_infos.sort(key=lambda x: x[5]) soap_call_infos.sort(key=lambda x: x[5])
return soap_call_infos return soap_call_infos
def debug_gateway(self) -> Dict: def debug_gateway(self) -> Dict[str, Union[str, bytes, int, Dict, List]]:
return { return {
'manufacturer_string': self.manufacturer_string, 'manufacturer_string': self.manufacturer_string,
'gateway_address': self.base_ip, 'gateway_address': self.base_ip,
@ -156,9 +174,11 @@ class Gateway:
@classmethod @classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: OrderedDict = None, loop=None, unicast: bool = False): igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None,
ignored: set = set() loop: typing.Optional[asyncio.AbstractEventLoop] = None,
required_commands = [ unicast: bool = False) -> 'Gateway':
ignored: typing.Set[str] = set()
required_commands: typing.List[str] = [
'AddPortMapping', 'AddPortMapping',
'DeletePortMapping', 'DeletePortMapping',
'GetExternalIPAddress' 'GetExternalIPAddress'
@ -172,14 +192,15 @@ class Gateway:
m_search_args = OrderedDict(igd_args) m_search_args = OrderedDict(igd_args)
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, ignored, unicast) datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, ignored, unicast)
try: try:
gateway = cls(datagram, m_search_args, lan_address, gateway_address) gateway = cls(datagram, m_search_args, lan_address, gateway_address, loop=loop)
log.debug('get gateway descriptor %s', datagram.location) log.debug('get gateway descriptor %s', datagram.location)
await gateway.discover_commands(loop) await gateway.discover_commands(loop)
requirements_met = all([required in gateway._registered_commands for required in required_commands]) requirements_met = all([gateway.commands.is_registered(required) for required in required_commands])
if not requirements_met: if not requirements_met:
not_met = [ not_met = [
required for required in required_commands if required not in gateway._registered_commands required for required in required_commands if not gateway.commands.is_registered(required)
] ]
assert datagram.location is not None
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s", log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
gateway.manufacturer_string, gateway.location, not_met) gateway.manufacturer_string, gateway.location, not_met)
ignored.add(datagram.location) ignored.add(datagram.location)
@ -188,13 +209,17 @@ class Gateway:
log.debug('found gateway device %s', datagram.location) log.debug('found gateway device %s', datagram.location)
return gateway return gateway
except (asyncio.TimeoutError, UPnPError) as err: except (asyncio.TimeoutError, UPnPError) as err:
assert datagram.location is not None
log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err)) log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err))
ignored.add(datagram.location) ignored.add(datagram.location)
continue continue
@classmethod @classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: OrderedDict = None, loop=None, unicast: bool = None): igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
unicast: typing.Optional[bool] = None) -> 'Gateway':
loop = loop or asyncio.get_event_loop()
if unicast is not None: if unicast is not None:
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast) return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast)
@ -205,7 +230,7 @@ class Gateway:
cls._discover_gateway( cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop, unicast=False lan_address, gateway_address, timeout, igd_args, loop, unicast=False
) )
], return_when=asyncio.tasks.FIRST_COMPLETED) ], return_when=asyncio.tasks.FIRST_COMPLETED, loop=loop)
for task in pending: for task in pending:
task.cancel() task.cancel()
@ -214,56 +239,78 @@ class Gateway:
task.exception() task.exception()
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
results: typing.List['asyncio.Future[Gateway]'] = list(done)
return results[0].result()
return list(done)[0].result() async def discover_commands(self, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
async def discover_commands(self, loop=None):
response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port, loop=loop) response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port, loop=loop)
self._xml_response = xml_bytes self._xml_response = xml_bytes
if get_err is not None: if get_err is not None:
raise get_err raise get_err
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION) spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
self.url_base = get_dict_val_case_insensitive(response, "urlbase") if isinstance(spec_version, bytes):
self.spec_version = spec_version.decode()
else:
self.spec_version = spec_version
url_base = get_dict_val_case_insensitive(response, "urlbase")
if isinstance(url_base, bytes):
self.url_base = url_base.decode()
else:
self.url_base = url_base
if not self.url_base: if not self.url_base:
self.url_base = self.base_address.decode() self.url_base = self.base_address.decode()
if response: if response:
device_dict = get_dict_val_case_insensitive(response, "device") source_keys: typing.List[str] = list(response.keys())
matches: typing.List[str] = list(filter(lambda x: x.lower() == "device", source_keys))
match_key = matches[0]
match: dict = response[match_key]
# if not len(match):
# return None
# if len(match) > 1:
# raise KeyError("overlapping keys")
# if len(match) == 1:
# matched_key: typing.AnyStr = match[0]
# return source[matched_key]
# raise KeyError("overlapping keys")
self._device = Device( self._device = Device(
self._devices, self._services, **device_dict self._devices, self._services, **match
) )
else: else:
self._device = Device(self._devices, self._services) self._device = Device(self._devices, self._services)
for service_type in self.services.keys(): for service_type in self.services.keys():
await self.register_commands(self.services[service_type], loop) await self.register_commands(self.services[service_type], loop)
return None
async def register_commands(self, service: Service, loop=None): async def register_commands(self, service: Service,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
if not service.SCPDURL: if not service.SCPDURL:
raise UPnPError("no scpd url") raise UPnPError("no scpd url")
if not service.serviceType:
raise UPnPError("no service type")
log.debug("get descriptor for %s from %s", service.serviceType, service.SCPDURL) log.debug("get descriptor for %s from %s", service.serviceType, service.SCPDURL)
service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port) service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port, loop=loop)
self._service_descriptors[service.SCPDURL] = xml_bytes self._service_descriptors[service.SCPDURL] = xml_bytes
if get_err is not None: if get_err is not None:
log.debug("failed to get descriptor for %s from %s", service.serviceType, service.SCPDURL) log.debug("failed to get descriptor for %s from %s", service.serviceType, service.SCPDURL)
if xml_bytes: if xml_bytes:
log.debug("response: %s", xml_bytes.decode()) log.debug("response: %s", xml_bytes.decode())
return return None
if not service_dict: if not service_dict:
return return None
action_list = get_action_list(service_dict) action_list = get_action_list(service_dict)
for name, inputs, outputs in action_list: for name, inputs, outputs in action_list:
try: try:
self.commands.register(self.base_ip, self.port, name, service.controlURL, service.serviceType.encode(), self.commands.register(name, service, inputs, outputs)
inputs, outputs, loop)
self._registered_commands[name] = service.serviceType self._registered_commands[name] = service.serviceType
log.debug("registered %s::%s", service.serviceType, name) log.debug("registered %s::%s", service.serviceType, name)
except AttributeError: except AttributeError:
s = self._unsupported_actions.get(service.serviceType, []) self._unsupported_actions.setdefault(service.serviceType, [])
s.append(name) self._unsupported_actions[service.serviceType].append(name)
self._unsupported_actions[service.serviceType] = s
log.debug("available command for %s does not have a wrapper implemented: %s %s %s", log.debug("available command for %s does not have a wrapper implemented: %s %s %s",
service.serviceType, name, inputs, outputs) service.serviceType, name, inputs, outputs)
log.debug("registered service %s", service.serviceType) log.debug("registered service %s", service.serviceType)
return None

54
aioupnp/interfaces.py Normal file
View file

@ -0,0 +1,54 @@
import socket
from collections import OrderedDict
import typing
import netifaces
from aioupnp.fault import UPnPError
def get_netifaces():
return netifaces
def ifaddresses(iface: str) -> typing.Dict[int, typing.List[typing.Dict[str, str]]]:
return get_netifaces().ifaddresses(iface)
def _get_interfaces() -> typing.List[str]:
return get_netifaces().interfaces()
def _get_gateways() -> typing.Dict[typing.Union[str, int],
typing.Union[typing.Dict[int, typing.Tuple[str, str]],
typing.List[typing.Tuple[str, str, bool]]]]:
return get_netifaces().gateways()
def get_interfaces() -> typing.Dict[str, typing.Tuple[str, str]]:
gateways = _get_gateways()
infos = gateways[socket.AF_INET]
assert isinstance(infos, list), TypeError(f"expected list from netifaces, got a dict")
interface_infos: typing.List[typing.Tuple[str, str, bool]] = infos
result: typing.Dict[str, typing.Tuple[str, str]] = OrderedDict(
(interface_name, (router_address, ifaddresses(interface_name)[socket.AF_INET][0]['addr']))
for router_address, interface_name, _ in interface_infos
)
for interface_name in _get_interfaces():
if interface_name in ['lo', 'localhost'] or interface_name in result:
continue
addresses = ifaddresses(interface_name)
if netifaces.AF_INET in addresses:
address = addresses[netifaces.AF_INET][0]['addr']
gateway_guess = ".".join(address.split(".")[:-1] + ["1"])
result[interface_name] = (gateway_guess, address)
_default = gateways['default']
assert isinstance(_default, dict), TypeError(f"expected dict from netifaces, got a list")
default: typing.Dict[int, typing.Tuple[str, str]] = _default
result['default'] = result[default[socket.AF_INET][1]]
return result
def get_gateway_and_lan_addresses(interface_name: str) -> typing.Tuple[str, str]:
for iface_name, (gateway, lan) in get_interfaces().items():
if interface_name == iface_name:
return gateway, lan
raise UPnPError(f'failed to get lan and gateway addresses for {interface_name}')

View file

@ -44,10 +44,11 @@ ST
characters in the domain name must be replaced with hyphens in accordance with RFC 2141. characters in the domain name must be replaced with hyphens in accordance with RFC 2141.
""" """
import typing
from collections import OrderedDict from collections import OrderedDict
from aioupnp.constants import SSDP_DISCOVER, SSDP_HOST from aioupnp.constants import SSDP_DISCOVER, SSDP_HOST
SEARCH_TARGETS = [ SEARCH_TARGETS: typing.List[str] = [
'upnp:rootdevice', 'upnp:rootdevice',
'urn:schemas-upnp-org:device:InternetGatewayDevice:1', 'urn:schemas-upnp-org:device:InternetGatewayDevice:1',
'urn:schemas-wifialliance-org:device:WFADevice:1', 'urn:schemas-wifialliance-org:device:WFADevice:1',
@ -58,7 +59,8 @@ SEARCH_TARGETS = [
] ]
def format_packet_args(order: list, **kwargs): def format_packet_args(order: typing.List[str],
kwargs: typing.Dict[str, typing.Union[int, str]]) -> typing.Dict[str, typing.Union[int, str]]:
args = [] args = []
for o in order: for o in order:
for k, v in kwargs.items(): for k, v in kwargs.items():
@ -68,18 +70,18 @@ def format_packet_args(order: list, **kwargs):
return OrderedDict(args) return OrderedDict(args)
def packet_generator(): def packet_generator() -> typing.Iterator[typing.Dict[str, typing.Union[int, str]]]:
for st in SEARCH_TARGETS: for st in SEARCH_TARGETS:
order = ["HOST", "MAN", "MX", "ST"] order = ["HOST", "MAN", "MX", "ST"]
yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
yield format_packet_args(order, Host=SSDP_HOST, Man='"%s"' % SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'Host': SSDP_HOST, 'Man': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': SSDP_DISCOVER, 'MX': 1, 'ST': st})
yield format_packet_args(order, Host=SSDP_HOST, Man=SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'Host': SSDP_HOST, 'Man': SSDP_DISCOVER, 'MX': 1, 'ST': st})
order = ["HOST", "MAN", "ST", "MX"] order = ["HOST", "MAN", "ST", "MX"]
yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': SSDP_DISCOVER, 'MX': 1, 'ST': st})
order = ["HOST", "ST", "MAN", "MX"] order = ["HOST", "ST", "MAN", "MX"]
yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': SSDP_DISCOVER, 'MX': 1, 'ST': st})

View file

@ -1,45 +1,70 @@
import struct import struct
import socket import socket
import typing
from asyncio.protocols import DatagramProtocol from asyncio.protocols import DatagramProtocol
from asyncio.transports import DatagramTransport from asyncio.transports import BaseTransport
from unittest import mock
def _get_sock(transport: typing.Optional[BaseTransport]) -> typing.Optional[socket.socket]:
if transport is None or not hasattr(transport, "_extra"):
return None
sock: typing.Optional[socket.socket] = transport.get_extra_info('socket', None)
assert sock is None or isinstance(sock, (socket.SocketType, mock.MagicMock))
return sock
class MulticastProtocol(DatagramProtocol): class MulticastProtocol(DatagramProtocol):
def __init__(self, multicast_address: str, bind_address: str) -> None: def __init__(self, multicast_address: str, bind_address: str) -> None:
self.multicast_address = multicast_address self.multicast_address = multicast_address
self.bind_address = bind_address self.bind_address = bind_address
self.transport: DatagramTransport self.transport: typing.Optional[BaseTransport] = None
@property @property
def sock(self) -> socket.socket: def sock(self) -> typing.Optional[socket.socket]:
s: socket.socket = self.transport.get_extra_info(name='socket') return _get_sock(self.transport)
return s
def get_ttl(self) -> int: def get_ttl(self) -> int:
return self.sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL) sock = self.sock
if not sock:
raise ValueError("not connected")
return sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
def set_ttl(self, ttl: int = 1) -> None: def set_ttl(self, ttl: int = 1) -> None:
self.sock.setsockopt( sock = self.sock
if not sock:
return None
sock.setsockopt(
socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl) socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl)
) )
return None
def join_group(self, multicast_address: str, bind_address: str) -> None: def join_group(self, multicast_address: str, bind_address: str) -> None:
self.sock.setsockopt( sock = self.sock
if not sock:
return None
sock.setsockopt(
socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address) socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
) )
return None
def leave_group(self, multicast_address: str, bind_address: str) -> None: def leave_group(self, multicast_address: str, bind_address: str) -> None:
self.sock.setsockopt( sock = self.sock
if not sock:
raise ValueError("not connected")
sock.setsockopt(
socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP,
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address) socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
) )
return None
def connection_made(self, transport) -> None: def connection_made(self, transport: BaseTransport) -> None:
self.transport = transport self.transport = transport
return None
@classmethod @classmethod
def create_multicast_socket(cls, bind_address: str): def create_multicast_socket(cls, bind_address: str) -> socket.socket:
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((bind_address, 0)) sock.bind((bind_address, 0))

View file

@ -2,7 +2,6 @@ import logging
import typing import typing
import re import re
from collections import OrderedDict from collections import OrderedDict
from xml.etree import ElementTree
import asyncio import asyncio
from asyncio.protocols import Protocol from asyncio.protocols import Protocol
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
@ -18,16 +17,22 @@ log = logging.getLogger(__name__)
HTTP_CODE_REGEX = re.compile(b"^HTTP[\/]{0,1}1\.[1|0] (\d\d\d)(.*)$") HTTP_CODE_REGEX = re.compile(b"^HTTP[\/]{0,1}1\.[1|0] (\d\d\d)(.*)$")
def parse_headers(response: bytes) -> typing.Tuple[OrderedDict, int, bytes]: def parse_http_response_code(http_response: bytes) -> typing.Tuple[bytes, bytes]:
parsed: typing.List[typing.Tuple[bytes, bytes]] = HTTP_CODE_REGEX.findall(http_response)
return parsed[0]
def parse_headers(response: bytes) -> typing.Tuple[typing.Dict[bytes, bytes], int, bytes]:
lines = response.split(b'\r\n') lines = response.split(b'\r\n')
headers = OrderedDict([ headers: typing.Dict[bytes, bytes] = OrderedDict([
(l.split(b':')[0], b':'.join(l.split(b':')[1:]).lstrip(b' ').rstrip(b' ')) (l.split(b':')[0], b':'.join(l.split(b':')[1:]).lstrip(b' ').rstrip(b' '))
for l in response.split(b'\r\n') for l in response.split(b'\r\n')
]) ])
if len(lines) != len(headers): if len(lines) != len(headers):
raise ValueError("duplicate headers") raise ValueError("duplicate headers")
http_response = tuple(headers.keys())[0] header_keys: typing.List[bytes] = list(headers.keys())
response_code, message = HTTP_CODE_REGEX.findall(http_response)[0] http_response = header_keys[0]
response_code, message = parse_http_response_code(http_response)
del headers[http_response] del headers[http_response]
return headers, int(response_code), message return headers, int(response_code), message
@ -40,37 +45,42 @@ class SCPDHTTPClientProtocol(Protocol):
and devices respond with an invalid HTTP version line and devices respond with an invalid HTTP version line
""" """
def __init__(self, message: bytes, finished: asyncio.Future, soap_method: str=None, def __init__(self, message: bytes, finished: 'asyncio.Future[typing.Tuple[bytes, int, bytes]]',
soap_service_id: str=None) -> None: soap_method: typing.Optional[str] = None, soap_service_id: typing.Optional[str] = None) -> None:
self.message = message self.message = message
self.response_buff = b"" self.response_buff = b""
self.finished = finished self.finished = finished
self.soap_method = soap_method self.soap_method = soap_method
self.soap_service_id = soap_service_id self.soap_service_id = soap_service_id
self._response_code = 0
self._response_code: int = 0 self._response_msg = b""
self._response_msg: bytes = b"" self._content_length = 0
self._content_length: int = 0
self._got_headers = False self._got_headers = False
self._headers: dict = {} self._headers: typing.Dict[bytes, bytes] = {}
self._body = b"" self._body = b""
self.transport: typing.Optional[asyncio.WriteTransport] = None
def connection_made(self, transport): def connection_made(self, transport: asyncio.BaseTransport) -> None:
transport.write(self.message) assert isinstance(transport, asyncio.WriteTransport)
self.transport = transport
self.transport.write(self.message)
return None
def data_received(self, data): def data_received(self, data: bytes) -> None:
self.response_buff += data self.response_buff += data
for i, line in enumerate(self.response_buff.split(b'\r\n')): for i, line in enumerate(self.response_buff.split(b'\r\n')):
if not line: # we hit the blank line between the headers and the body if not line: # we hit the blank line between the headers and the body
if i == (len(self.response_buff.split(b'\r\n')) - 1): if i == (len(self.response_buff.split(b'\r\n')) - 1):
return # the body is still yet to be written return None # the body is still yet to be written
if not self._got_headers: if not self._got_headers:
self._headers, self._response_code, self._response_msg = parse_headers( self._headers, self._response_code, self._response_msg = parse_headers(
b'\r\n'.join(self.response_buff.split(b'\r\n')[:i]) b'\r\n'.join(self.response_buff.split(b'\r\n')[:i])
) )
content_length = get_dict_val_case_insensitive(self._headers, b'Content-Length') content_length = get_dict_val_case_insensitive(
self._headers, b'Content-Length'
)
if content_length is None: if content_length is None:
return return None
self._content_length = int(content_length or 0) self._content_length = int(content_length or 0)
self._got_headers = True self._got_headers = True
body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:]) body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:])
@ -86,21 +96,28 @@ class SCPDHTTPClientProtocol(Protocol):
) )
) )
) )
return return None
return None
async def scpd_get(control_url: str, address: str, port: int, loop=None) -> typing.Tuple[typing.Dict, bytes, async def scpd_get(control_url: str, address: str, port: int,
typing.Optional[Exception]]: loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> typing.Tuple[
loop = loop or asyncio.get_event_loop_policy().get_event_loop() typing.Dict[str, typing.Any], bytes, typing.Optional[Exception]]:
finished: asyncio.Future = asyncio.Future() loop = loop or asyncio.get_event_loop()
packet = serialize_scpd_get(control_url, address) packet = serialize_scpd_get(control_url, address)
transport, protocol = await loop.create_connection( finished: 'asyncio.Future[typing.Tuple[bytes, int, bytes]]' = asyncio.Future(loop=loop)
lambda : SCPDHTTPClientProtocol(packet, finished), address, port proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished)
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
proto_factory, address, port
) )
protocol = connect_tup[1]
transport = connect_tup[0]
assert isinstance(protocol, SCPDHTTPClientProtocol) assert isinstance(protocol, SCPDHTTPClientProtocol)
error = None error = None
wait_task: typing.Awaitable[typing.Tuple[bytes, int, bytes]] = asyncio.wait_for(protocol.finished, 1.0, loop=loop)
try: try:
body, response_code, response_msg = await asyncio.wait_for(finished, 1.0) body, response_code, response_msg = await wait_task
except asyncio.TimeoutError: except asyncio.TimeoutError:
error = UPnPError("get request timed out") error = UPnPError("get request timed out")
body = b'' body = b''
@ -112,33 +129,41 @@ async def scpd_get(control_url: str, address: str, port: int, loop=None) -> typi
if not error: if not error:
try: try:
return deserialize_scpd_get_response(body), body, None return deserialize_scpd_get_response(body), body, None
except ElementTree.ParseError as err: except Exception as err:
error = UPnPError(err) error = UPnPError(err)
return {}, body, error return {}, body, error
async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes, async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes,
loop=None, **kwargs) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]: loop: typing.Optional[asyncio.AbstractEventLoop] = None,
loop = loop or asyncio.get_event_loop_policy().get_event_loop() **kwargs: typing.Dict[str, typing.Any]
finished: asyncio.Future = asyncio.Future() ) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop()
finished: 'asyncio.Future[typing.Tuple[bytes, int, bytes]]' = asyncio.Future(loop=loop)
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs) packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
transport, protocol = await loop.create_connection( proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\
lambda : SCPDHTTPClientProtocol( SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode())
packet, finished, soap_method=method, soap_service_id=service_id.decode(), connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
), address, port proto_factory, address, port
) )
protocol = connect_tup[1]
transport = connect_tup[0]
assert isinstance(protocol, SCPDHTTPClientProtocol) assert isinstance(protocol, SCPDHTTPClientProtocol)
try: try:
body, response_code, response_msg = await asyncio.wait_for(finished, 1.0) wait_task: typing.Awaitable[typing.Tuple[bytes, int, bytes]] = asyncio.wait_for(finished, 1.0, loop=loop)
body, response_code, response_msg = await wait_task
except asyncio.TimeoutError: except asyncio.TimeoutError:
return {}, b'', UPnPError("Timeout") return {}, b'', UPnPError("Timeout")
except UPnPError as err: except UPnPError as err:
return {}, protocol.response_buff, err return {}, protocol.response_buff, err
finally: finally:
# raw_response = protocol.response_buff
transport.close() transport.close()
try: try:
return ( return (
deserialize_soap_post_response(body, method, service_id.decode()), body, None deserialize_soap_post_response(body, method, service_id.decode()), body, None
) )
except (ElementTree.ParseError, UPnPError) as err: except Exception as err:
return {}, body, UPnPError(err) return {}, body, UPnPError(err)

View file

@ -3,8 +3,7 @@ import binascii
import asyncio import asyncio
import logging import logging
import typing import typing
from collections import OrderedDict import socket
from asyncio.futures import Future
from asyncio.transports import DatagramTransport from asyncio.transports import DatagramTransport
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
@ -18,32 +17,48 @@ log = logging.getLogger(__name__)
class SSDPProtocol(MulticastProtocol): class SSDPProtocol(MulticastProtocol):
def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Set[str] = None, def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Optional[typing.Set[str]] = None,
unicast: bool = False) -> None: unicast: bool = False, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(multicast_address, lan_address) super().__init__(multicast_address, lan_address)
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
self.transport: typing.Optional[DatagramTransport] = None
self._unicast = unicast self._unicast = unicast
self._ignored: typing.Set[str] = ignored or set() # ignored locations self._ignored: typing.Set[str] = ignored or set() # ignored locations
self._pending_searches: typing.List[typing.Tuple[str, str, Future, asyncio.Handle]] = [] self._pending_searches: typing.List[typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = []
self.notifications: typing.List = [] self.notifications: typing.List[SSDPDatagram] = []
self.connected = asyncio.Event(loop=self.loop)
def disconnect(self): def connection_made(self, transport) -> None:
# assert isinstance(transport, asyncio.DatagramTransport), str(type(transport))
super().connection_made(transport)
self.connected.set()
def disconnect(self) -> None:
if self.transport: if self.transport:
try:
self.leave_group(self.multicast_address, self.bind_address)
except ValueError:
pass
except Exception:
log.exception("unexpected error leaving multicast group")
self.transport.close() self.transport.close()
self.connected.clear()
while self._pending_searches: while self._pending_searches:
pending = self._pending_searches.pop()[2] pending = self._pending_searches.pop()[2]
if not pending.cancelled() and not pending.done(): if not pending.cancelled() and not pending.done():
pending.cancel() pending.cancel()
return None
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None: def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
if packet.location in self._ignored: if packet.location in self._ignored:
return return None
tmp: typing.List = [] # TODO: fix this
set_futures: typing.List = [] tmp: typing.List[typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = []
while self._pending_searches: set_futures: typing.List['asyncio.Future[SSDPDatagram]'] = []
t: tuple = self._pending_searches.pop() while len(self._pending_searches):
a, s = t[0], t[1] t: typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle] = self._pending_searches.pop()
if (address == a) and (s in [packet.st, "upnp:rootdevice"]): if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]):
f: Future = t[2] f = t[2]
if f not in set_futures: if f not in set_futures:
set_futures.append(f) set_futures.append(f)
if not f.done(): if not f.done():
@ -52,38 +67,41 @@ class SSDPProtocol(MulticastProtocol):
tmp.append(t) tmp.append(t)
while tmp: while tmp:
self._pending_searches.append(tmp.pop()) self._pending_searches.append(tmp.pop())
return None
def send_many_m_searches(self, address: str, packets: typing.List[SSDPDatagram]): def _send_m_search(self, address: str, packet: SSDPDatagram) -> None:
dest = address if self._unicast else SSDP_IP_ADDRESS dest = address if self._unicast else SSDP_IP_ADDRESS
for packet in packets: if not self.transport:
raise UPnPError("SSDP transport not connected")
log.debug("send m search to %s: %s", dest, packet.st) log.debug("send m search to %s: %s", dest, packet.st)
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT)) self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
return None
async def m_search(self, address: str, timeout: float, datagrams: typing.List[OrderedDict]) -> SSDPDatagram: async def m_search(self, address: str, timeout: float,
fut: Future = Future() datagrams: typing.List[typing.Dict[str, typing.Union[str, int]]]) -> SSDPDatagram:
packets: typing.List[SSDPDatagram] = [] fut: 'asyncio.Future[SSDPDatagram]' = asyncio.Future(loop=self.loop)
for datagram in datagrams: for datagram in datagrams:
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram) packet = SSDPDatagram("M-SEARCH", datagram)
assert packet.st is not None assert packet.st is not None
self._pending_searches.append((address, packet.st, fut)) self._pending_searches.append(
packets.append(packet) (address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet))
self.send_many_m_searches(address, packets), )
return await fut return await asyncio.wait_for(fut, timeout)
def datagram_received(self, data, addr) -> None: def datagram_received(self, data: bytes, addr: typing.Tuple[str, int]) -> None: # type: ignore
if addr[0] == self.bind_address: if addr[0] == self.bind_address:
return return None
try: try:
packet = SSDPDatagram.decode(data) packet = SSDPDatagram.decode(data)
log.debug("decoded packet from %s:%i: %s", addr[0], addr[1], packet) log.debug("decoded packet from %s:%i: %s", addr[0], addr[1], packet)
except UPnPError as err: except UPnPError as err:
log.error("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err, log.error("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err,
binascii.hexlify(data)) binascii.hexlify(data))
return return None
if packet._packet_type == packet._OK: if packet._packet_type == packet._OK:
self._callback_m_search_ok(addr[0], packet) self._callback_m_search_ok(addr[0], packet)
return return None
# elif packet._packet_type == packet._NOTIFY: # elif packet._packet_type == packet._NOTIFY:
# log.debug("%s:%i sent us a notification: %s", packet) # log.debug("%s:%i sent us a notification: %s", packet)
# if packet.nt == SSDP_ROOT_DEVICE: # if packet.nt == SSDP_ROOT_DEVICE:
@ -104,17 +122,18 @@ class SSDPProtocol(MulticastProtocol):
# return # return
async def listen_ssdp(lan_address: str, gateway_address: str, loop=None, async def listen_ssdp(lan_address: str, gateway_address: str, loop: typing.Optional[asyncio.AbstractEventLoop] = None,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[DatagramTransport, ignored: typing.Optional[typing.Set[str]] = None,
SSDPProtocol, str, str]: unicast: bool = False) -> typing.Tuple[SSDPProtocol, str, str]:
loop = loop or asyncio.get_event_loop_policy().get_event_loop() loop = loop or asyncio.get_event_loop()
try: try:
sock = SSDPProtocol.create_multicast_socket(lan_address) sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address)
listen_result: typing.Tuple = await loop.create_datagram_endpoint( listen_result: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
) )
transport: DatagramTransport = listen_result[0] transport = listen_result[0]
protocol: SSDPProtocol = listen_result[1] protocol = listen_result[1]
assert isinstance(protocol, SSDPProtocol)
except Exception as err: except Exception as err:
print(err) print(err)
raise UPnPError(err) raise UPnPError(err)
@ -125,30 +144,31 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop=None,
protocol.disconnect() protocol.disconnect()
raise UPnPError(err) raise UPnPError(err)
return transport, protocol, gateway_address, lan_address return protocol, gateway_address, lan_address
async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1, async def m_search(lan_address: str, gateway_address: str, datagram_args: typing.Dict[str, typing.Union[int, str]],
loop=None, ignored: typing.Set[str] = None, timeout: int = 1, loop: typing.Optional[asyncio.AbstractEventLoop] = None,
unicast: bool = False) -> SSDPDatagram: ignored: typing.Set[str] = None, unicast: bool = False) -> SSDPDatagram:
transport, protocol, gateway_address, lan_address = await listen_ssdp( protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast lan_address, gateway_address, loop, ignored, unicast
) )
try: try:
return await asyncio.wait_for( return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]), timeout
)
except (asyncio.TimeoutError, asyncio.CancelledError): except (asyncio.TimeoutError, asyncio.CancelledError):
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
finally: finally:
protocol.disconnect() protocol.disconnect()
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, loop=None, async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.List[OrderedDict]: loop: typing.Optional[asyncio.AbstractEventLoop] = None,
transport, protocol, gateway_address, lan_address = await listen_ssdp( ignored: typing.Set[str] = None,
unicast: bool = False) -> typing.List[typing.Dict[str, typing.Union[int, str]]]:
protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast lan_address, gateway_address, loop, ignored, unicast
) )
await protocol.connected.wait()
packet_args = list(packet_generator()) packet_args = list(packet_generator())
batch_size = 2 batch_size = 2
batch_timeout = float(timeout) / float(len(packet_args)) batch_timeout = float(timeout) / float(len(packet_args))
@ -157,7 +177,7 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
packet_args = packet_args[batch_size:] packet_args = packet_args[batch_size:]
log.debug("sending batch of %i M-SEARCH attempts", batch_size) log.debug("sending batch of %i M-SEARCH attempts", batch_size)
try: try:
await asyncio.wait_for(protocol.m_search(gateway_address, batch_timeout, args), batch_timeout) await protocol.m_search(gateway_address, batch_timeout, args)
protocol.disconnect() protocol.disconnect()
return args return args
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -166,9 +186,11 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, loop=None, async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[OrderedDict, loop: typing.Optional[asyncio.AbstractEventLoop] = None,
SSDPDatagram]: ignored: typing.Set[str] = None,
unicast: bool = False) -> typing.Tuple[typing.Dict[str,
typing.Union[int, str]], SSDPDatagram]:
# we don't know which packet the gateway replies to, so send small batches at a time # we don't know which packet the gateway replies to, so send small batches at a time
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast) args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast)
# check the args in the batch that got a reply one at a time to see which one worked # check the args in the batch that got a reply one at a time to see which one worked

View file

@ -1,8 +1,9 @@
import re import re
from typing import Dict from typing import Dict, Any, List, Tuple
from xml.etree import ElementTree from aioupnp.fault import UPnPError
from aioupnp.constants import XML_VERSION, DEVICE, ROOT from aioupnp.constants import XML_VERSION
from aioupnp.util import etree_to_dict, flatten_keys from aioupnp.serialization.xml import xml_to_dict
from aioupnp.util import flatten_keys
CONTENT_PATTERN = re.compile( CONTENT_PATTERN = re.compile(
@ -28,34 +29,38 @@ def serialize_scpd_get(path: str, address: str) -> bytes:
if not path.startswith("/"): if not path.startswith("/"):
path = "/" + path path = "/" + path
return ( return (
( f'GET {path} HTTP/1.1\r\n'
'GET %s HTTP/1.1\r\n' f'Accept-Encoding: gzip\r\n'
'Accept-Encoding: gzip\r\n' f'Host: {host}\r\n'
'Host: %s\r\n' f'Connection: Close\r\n'
'Connection: Close\r\n' f'\r\n'
'\r\n'
) % (path, host)
).encode() ).encode()
def deserialize_scpd_get_response(content: bytes) -> Dict: def deserialize_scpd_get_response(content: bytes) -> Dict[str, Any]:
if XML_VERSION.encode() in content: if XML_VERSION.encode() in content:
parsed = CONTENT_PATTERN.findall(content) parsed: List[Tuple[bytes, bytes]] = CONTENT_PATTERN.findall(content)
content = b'' if not parsed else parsed[0][0] xml_dict = xml_to_dict((b'' if not parsed else parsed[0][0]).decode())
xml_dict = etree_to_dict(ElementTree.fromstring(content.decode()))
return parse_device_dict(xml_dict) return parse_device_dict(xml_dict)
return {} return {}
def parse_device_dict(xml_dict: dict) -> Dict: def parse_device_dict(xml_dict: Dict[str, Any]) -> Dict[str, Any]:
keys = list(xml_dict.keys()) keys = list(xml_dict.keys())
found = False
for k in keys: for k in keys:
m = XML_ROOT_SANITY_PATTERN.findall(k) m: List[Tuple[str, str, str, str, str, str]] = XML_ROOT_SANITY_PATTERN.findall(k)
if len(m) == 3 and m[1][0] and m[2][5]: if len(m) == 3 and m[1][0] and m[2][5]:
schema_key = m[1][0] schema_key: str = m[1][0]
root = m[2][5] root: str = m[2][5]
xml_dict = flatten_keys(xml_dict, "{%s}" % schema_key)[root] flattened = flatten_keys(xml_dict, "{%s}" % schema_key)
if root not in flattened:
raise UPnPError("root device not found")
xml_dict = flattened[root]
found = True
break break
if not found:
raise UPnPError("device not found")
result = {} result = {}
for k, v in xml_dict.items(): for k, v in xml_dict.items():
if isinstance(xml_dict[k], dict): if isinstance(xml_dict[k], dict):
@ -65,10 +70,9 @@ def parse_device_dict(xml_dict: dict) -> Dict:
if len(parsed_k) == 2: if len(parsed_k) == 2:
inner_d[parsed_k[0]] = inner_v inner_d[parsed_k[0]] = inner_v
else: else:
assert len(parsed_k) == 3 assert len(parsed_k) == 3, f"expected len=3, got {len(parsed_k)}"
inner_d[parsed_k[1]] = inner_v inner_d[parsed_k[1]] = inner_v
result[k] = inner_d result[k] = inner_d
else: else:
result[k] = v result[k] = v
return result return result

View file

@ -1,64 +1,65 @@
import re import re
from xml.etree import ElementTree import typing
from aioupnp.util import etree_to_dict, flatten_keys from aioupnp.util import flatten_keys
from aioupnp.fault import handle_fault, UPnPError from aioupnp.fault import UPnPError
from aioupnp.constants import XML_VERSION, ENVELOPE, BODY from aioupnp.constants import XML_VERSION, ENVELOPE, BODY, FAULT, CONTROL
from aioupnp.serialization.xml import xml_to_dict
CONTENT_NO_XML_VERSION_PATTERN = re.compile( CONTENT_NO_XML_VERSION_PATTERN = re.compile(
"(\<s\:Envelope xmlns\:s=\"http\:\/\/schemas\.xmlsoap\.org\/soap\/envelope\/\"(\s*.)*\>)".encode() "(\<s\:Envelope xmlns\:s=\"http\:\/\/schemas\.xmlsoap\.org\/soap\/envelope\/\"(\s*.)*\>)".encode()
) )
def serialize_soap_post(method: str, param_names: list, service_id: bytes, gateway_address: bytes, def serialize_soap_post(method: str, param_names: typing.List[str], service_id: bytes, gateway_address: bytes,
control_url: bytes, **kwargs) -> bytes: control_url: bytes, **kwargs: typing.Dict[str, str]) -> bytes:
args = "".join("<%s>%s</%s>" % (n, kwargs.get(n), n) for n in param_names) args = "".join(f"<{n}>{kwargs.get(n)}</{n}>" for n in param_names)
soap_body = ('\r\n%s\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" ' soap_body = (f'\r\n{XML_VERSION}\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" '
's:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>' f's:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>'
'<u:%s xmlns:u="%s">%s</u:%s></s:Body></s:Envelope>' % ( f'<u:{method} xmlns:u="{service_id.decode()}">{args}</u:{method}></s:Body></s:Envelope>')
XML_VERSION, method, service_id.decode(),
args, method))
if "http://" in gateway_address.decode(): if "http://" in gateway_address.decode():
host = gateway_address.decode().split("http://")[1] host = gateway_address.decode().split("http://")[1]
else: else:
host = gateway_address.decode() host = gateway_address.decode()
return ( return (
( f'POST {control_url.decode()} HTTP/1.1\r\n' # could be just / even if it shouldn't be
'POST %s HTTP/1.1\r\n' f'Host: {host}\r\n'
'Host: %s\r\n' f'User-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n'
'User-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n' f'Content-Length: {len(soap_body)}\r\n'
'Content-Length: %i\r\n' f'Content-Type: text/xml\r\n'
'Content-Type: text/xml\r\n' f'SOAPAction: \"{service_id.decode()}#{method}\"\r\n'
'SOAPAction: \"%s#%s\"\r\n' f'Connection: Close\r\n'
'Connection: Close\r\n' f'Cache-Control: no-cache\r\n'
'Cache-Control: no-cache\r\n' f'Pragma: no-cache\r\n'
'Pragma: no-cache\r\n' f'{soap_body}'
'%s' f'\r\n'
'\r\n'
) % (
control_url.decode(), # could be just / even if it shouldn't be
host,
len(soap_body),
service_id.decode(), # maybe no quotes
method,
soap_body
)
).encode() ).encode()
def deserialize_soap_post_response(response: bytes, method: str, service_id: str) -> dict: def deserialize_soap_post_response(response: bytes, method: str,
parsed = CONTENT_NO_XML_VERSION_PATTERN.findall(response) service_id: str) -> typing.Dict[str, typing.Dict[str, str]]:
parsed: typing.List[typing.List[bytes]] = CONTENT_NO_XML_VERSION_PATTERN.findall(response)
content = b'' if not parsed else parsed[0][0] content = b'' if not parsed else parsed[0][0]
content_dict = etree_to_dict(ElementTree.fromstring(content.decode())) content_dict = xml_to_dict(content.decode())
envelope = content_dict[ENVELOPE] envelope = content_dict[ENVELOPE]
response_body = flatten_keys(envelope[BODY], "{%s}" % service_id) if not isinstance(envelope[BODY], dict):
body = handle_fault(response_body) # raises UPnPError if there is a fault # raise UPnPError('blank response')
return {} # TODO: raise
response_body: typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]] = flatten_keys(
envelope[BODY], f"{'{' + service_id + '}'}"
)
if not response_body:
# raise UPnPError('blank response')
return {} # TODO: raise
if FAULT in response_body:
fault: typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]] = flatten_keys(
response_body[FAULT], "{%s}" % CONTROL
)
raise UPnPError(fault['detail']['UPnPError']['errorDescription'])
response_key = None response_key = None
if not body: for key in response_body:
return {}
for key in body:
if method in key: if method in key:
response_key = key response_key = key
break break
if not response_key: if not response_key:
raise UPnPError("unknown response fields for %s: %s" % (method, body)) raise UPnPError(f"unknown response fields for {method}: {response_body}")
return body[response_key] return response_body[response_key]

View file

@ -3,43 +3,67 @@ import logging
import binascii import binascii
import json import json
from collections import OrderedDict from collections import OrderedDict
from typing import List from typing import List, Optional, Dict, Union, Tuple, Callable
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.constants import line_separator from aioupnp.constants import line_separator
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
_template = "(?i)^(%s):[ ]*(.*)$" _template = "(?i)^(%s):[ ]*(.*)$"
ssdp_datagram_patterns = {
'host': (re.compile("(?i)^(host):(.*)$"), str),
'st': (re.compile(_template % 'st'), str),
'man': (re.compile(_template % 'man'), str),
'mx': (re.compile(_template % 'mx'), int),
'nt': (re.compile(_template % 'nt'), str),
'nts': (re.compile(_template % 'nts'), str),
'usn': (re.compile(_template % 'usn'), str),
'location': (re.compile(_template % 'location'), str),
'cache_control': (re.compile(_template % 'cache[-|_]control'), str),
'server': (re.compile(_template % 'server'), str),
}
vendor_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$") vendor_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$")
class SSDPDatagram(object): def match_vendor(line: str) -> Optional[Tuple[str, str]]:
match: List[Tuple[str, str, str]] = vendor_pattern.findall(line)
if match:
vendor_key: str = match[-1][0].lstrip(" ").rstrip(" ")
vendor_value: str = match[-1][2].lstrip(" ").rstrip(" ")
return vendor_key, vendor_value
return None
def compile_find(pattern: str) -> Callable[[str], Optional[str]]:
p = re.compile(pattern)
def find(line: str) -> Optional[str]:
result: List[List[str]] = []
for outer in p.findall(line):
result.append([])
for inner in outer:
result[-1].append(inner)
if result:
return result[-1][-1].lstrip(" ").rstrip(" ")
return None
return find
ssdp_datagram_patterns: Dict[str, Callable[[str], Optional[str]]] = {
'host': compile_find("(?i)^(host):(.*)$"),
'st': compile_find(_template % 'st'),
'man': compile_find(_template % 'man'),
'mx': compile_find(_template % 'mx'),
'nt': compile_find(_template % 'nt'),
'nts': compile_find(_template % 'nts'),
'usn': compile_find(_template % 'usn'),
'location': compile_find(_template % 'location'),
'cache_control': compile_find(_template % 'cache[-|_]control'),
'server': compile_find(_template % 'server'),
}
class SSDPDatagram:
_M_SEARCH = "M-SEARCH" _M_SEARCH = "M-SEARCH"
_NOTIFY = "NOTIFY" _NOTIFY = "NOTIFY"
_OK = "OK" _OK = "OK"
_start_lines = { _start_lines: Dict[str, str] = {
_M_SEARCH: "M-SEARCH * HTTP/1.1", _M_SEARCH: "M-SEARCH * HTTP/1.1",
_NOTIFY: "NOTIFY * HTTP/1.1", _NOTIFY: "NOTIFY * HTTP/1.1",
_OK: "HTTP/1.1 200 OK" _OK: "HTTP/1.1 200 OK"
} }
_friendly_names = { _friendly_names: Dict[str, str] = {
_M_SEARCH: "m-search", _M_SEARCH: "m-search",
_NOTIFY: "notify", _NOTIFY: "notify",
_OK: "m-search response" _OK: "m-search response"
@ -47,9 +71,7 @@ class SSDPDatagram(object):
_vendor_field_pattern = vendor_pattern _vendor_field_pattern = vendor_pattern
_patterns = ssdp_datagram_patterns _required_fields: Dict[str, List[str]] = {
_required_fields = {
_M_SEARCH: [ _M_SEARCH: [
'host', 'host',
'man', 'man',
@ -75,137 +97,137 @@ class SSDPDatagram(object):
] ]
} }
def __init__(self, packet_type, kwargs: OrderedDict = None) -> None: def __init__(self, packet_type: str, kwargs: Optional[Dict[str, Union[str, int]]] = None) -> None:
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]: if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
raise UPnPError("unknown packet type: {}".format(packet_type)) raise UPnPError("unknown packet type: {}".format(packet_type))
self._packet_type = packet_type self._packet_type = packet_type
kwargs = kwargs or OrderedDict() kw: Dict[str, Union[str, int]] = kwargs or OrderedDict()
self._field_order: list = [ self._field_order: List[str] = [
k.lower().replace("-", "_") for k in kwargs.keys() k.lower().replace("-", "_") for k in kw.keys()
] ]
self.host = None self.host: Optional[str] = None
self.man = None self.man: Optional[str] = None
self.mx = None self.mx: Optional[Union[str, int]] = None
self.st = None self.st: Optional[str] = None
self.nt = None self.nt: Optional[str] = None
self.nts = None self.nts: Optional[str] = None
self.usn = None self.usn: Optional[str] = None
self.location = None self.location: Optional[str] = None
self.cache_control = None self.cache_control: Optional[str] = None
self.server = None self.server: Optional[str] = None
self.date = None self.date: Optional[str] = None
self.ext = None self.ext: Optional[str] = None
for k, v in kwargs.items(): for k, v in kw.items():
normalized = k.lower().replace("-", "_") normalized = k.lower().replace("-", "_")
if not normalized.startswith("_") and hasattr(self, normalized) and getattr(self, normalized) is None: if not normalized.startswith("_") and hasattr(self, normalized):
if getattr(self, normalized, None) is None:
setattr(self, normalized, v) setattr(self, normalized, v)
self._case_mappings: dict = {k.lower(): k for k in kwargs.keys()} self._case_mappings: Dict[str, str] = {k.lower(): k for k in kw.keys()}
for k in self._required_fields[self._packet_type]: for k in self._required_fields[self._packet_type]:
if getattr(self, k) is None: if getattr(self, k, None) is None:
raise UPnPError("missing required field %s" % k) raise UPnPError("missing required field %s" % k)
def get_cli_igd_kwargs(self) -> str: def get_cli_igd_kwargs(self) -> str:
fields = [] fields = []
for field in self._field_order: for field in self._field_order:
v = getattr(self, field) v = getattr(self, field, None)
if v is None:
raise UPnPError("missing required field %s" % field)
fields.append("--%s=%s" % (self._case_mappings.get(field, field), v)) fields.append("--%s=%s" % (self._case_mappings.get(field, field), v))
return " ".join(fields) return " ".join(fields)
def __repr__(self) -> str: def __repr__(self) -> str:
return self.as_json() return self.as_json()
def __getitem__(self, item): def __getitem__(self, item: str) -> Union[str, int]:
for i in self._required_fields[self._packet_type]: for i in self._required_fields[self._packet_type]:
if i.lower() == item.lower(): if i.lower() == item.lower():
return getattr(self, i) return getattr(self, i)
raise KeyError(item) raise KeyError(item)
def get_friendly_name(self) -> str:
return self._friendly_names[self._packet_type]
def encode(self, trailing_newlines: int = 2) -> str: def encode(self, trailing_newlines: int = 2) -> str:
lines = [self._start_lines[self._packet_type]] lines = [self._start_lines[self._packet_type]]
for attr_name in self._field_order: lines.extend(
if attr_name not in self._required_fields[self._packet_type]: f"{self._case_mappings.get(attr_name.lower(), attr_name.upper())}: {str(getattr(self, attr_name))}"
continue for attr_name in self._field_order if attr_name in self._required_fields[self._packet_type]
attr = getattr(self, attr_name) )
if attr is None:
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
if attr_name == 'mx':
value = str(attr)
else:
value = attr
lines.append("{}: {}".format(self._case_mappings.get(attr_name.lower(), attr_name.upper()), value))
serialized = line_separator.join(lines) serialized = line_separator.join(lines)
for _ in range(trailing_newlines): for _ in range(trailing_newlines):
serialized += line_separator serialized += line_separator
return serialized return serialized
def as_dict(self) -> OrderedDict: def as_dict(self) -> Dict[str, Union[str, int]]:
return self._lines_to_content_dict(self.encode().split(line_separator)) return self._lines_to_content_dict(self.encode().split(line_separator))
def as_json(self) -> str: def as_json(self) -> str:
return json.dumps(self.as_dict(), indent=2) return json.dumps(self.as_dict(), indent=2)
@classmethod @classmethod
def decode(cls, datagram: bytes): def decode(cls, datagram: bytes) -> 'SSDPDatagram':
packet = cls._from_string(datagram.decode()) packet = cls._from_string(datagram.decode())
if packet is None: if packet is None:
raise UPnPError( raise UPnPError(
"failed to decode datagram: {}".format(binascii.hexlify(datagram)) "failed to decode datagram: {}".format(binascii.hexlify(datagram))
) )
for attr_name in packet._required_fields[packet._packet_type]: for attr_name in packet._required_fields[packet._packet_type]:
attr = getattr(packet, attr_name) if getattr(packet, attr_name, None) is None:
if attr is None:
raise UPnPError( raise UPnPError(
"required field for {} is missing from m-search response: {}".format(packet._packet_type, attr_name) "required field for {} is missing from m-search response: {}".format(packet._packet_type, attr_name)
) )
return packet return packet
@classmethod @classmethod
def _lines_to_content_dict(cls, lines: list) -> OrderedDict: def _lines_to_content_dict(cls, lines: List[str]) -> Dict[str, Union[str, int]]:
result: OrderedDict = OrderedDict() result: Dict[str, Union[str, int]] = OrderedDict()
matched_keys: List[str] = []
for line in lines: for line in lines:
if not line: if not line:
continue continue
matched = False matched = False
for name, (pattern, field_type) in cls._patterns.items(): for name, pattern in ssdp_datagram_patterns.items():
if name not in result and pattern.findall(line): if name not in matched_keys:
match = pattern.findall(line)[-1][-1] if name.lower() == 'mx':
result[line[:len(name)]] = field_type(match.lstrip(" ").rstrip(" ")) _matched_int = pattern(line)
if _matched_int is not None:
match_int = int(_matched_int)
result[line[:len(name)]] = match_int
matched = True matched = True
matched_keys.append(name)
break
else:
match = pattern(line)
if match is not None:
result[line[:len(name)]] = match
matched = True
matched_keys.append(name)
break break
if not matched: if not matched:
if cls._vendor_field_pattern.findall(line): matched_vendor = match_vendor(line)
match = cls._vendor_field_pattern.findall(line)[-1] if matched_vendor and matched_vendor[0] not in result:
vendor_key = match[0].lstrip(" ").rstrip(" ") result[matched_vendor[0]] = matched_vendor[1]
# vendor_domain = match[1].lstrip(" ").rstrip(" ")
value = match[2].lstrip(" ").rstrip(" ")
if vendor_key not in result:
result[vendor_key] = value
return result return result
@classmethod @classmethod
def _from_string(cls, datagram: str): def _from_string(cls, datagram: str) -> Optional['SSDPDatagram']:
lines = [l for l in datagram.split(line_separator) if l] lines = [l for l in datagram.split(line_separator) if l]
if not lines: if not lines:
return return None
if lines[0] == cls._start_lines[cls._M_SEARCH]: if lines[0] == cls._start_lines[cls._M_SEARCH]:
return cls._from_request(lines[1:]) return cls._from_request(lines[1:])
if lines[0] in [cls._start_lines[cls._NOTIFY], cls._start_lines[cls._NOTIFY] + " "]: if lines[0] in [cls._start_lines[cls._NOTIFY], cls._start_lines[cls._NOTIFY] + " "]:
return cls._from_notify(lines[1:]) return cls._from_notify(lines[1:])
if lines[0] == cls._start_lines[cls._OK]: if lines[0] == cls._start_lines[cls._OK]:
return cls._from_response(lines[1:]) return cls._from_response(lines[1:])
return None
@classmethod @classmethod
def _from_response(cls, lines: List): def _from_response(cls, lines: List) -> 'SSDPDatagram':
return cls(cls._OK, cls._lines_to_content_dict(lines)) return cls(cls._OK, cls._lines_to_content_dict(lines))
@classmethod @classmethod
def _from_notify(cls, lines: List): def _from_notify(cls, lines: List) -> 'SSDPDatagram':
return cls(cls._NOTIFY, cls._lines_to_content_dict(lines)) return cls(cls._NOTIFY, cls._lines_to_content_dict(lines))
@classmethod @classmethod
def _from_request(cls, lines: List): def _from_request(cls, lines: List) -> 'SSDPDatagram':
return cls(cls._M_SEARCH, cls._lines_to_content_dict(lines)) return cls(cls._M_SEARCH, cls._lines_to_content_dict(lines))

View file

@ -0,0 +1,81 @@
import typing
from collections import OrderedDict
from defusedxml import ElementTree
str_any_dict = typing.Dict[str, typing.Any]
def parse_xml(xml_str: str) -> ElementTree:
element: ElementTree = ElementTree.fromstring(xml_str)
return element
def _element_text(element_tree: ElementTree) -> typing.Optional[str]:
# if element_tree.attrib:
# element: typing.Dict[str, str] = OrderedDict()
# for k, v in element_tree.attrib.items():
# element['@' + k] = v
# if not element_tree.text:
# return element
# element['#text'] = element_tree.text.strip()
# return element
if element_tree.text:
return element_tree.text.strip()
return None
def _get_element_children(element_tree: ElementTree) -> typing.Dict[str, typing.Union[
str, typing.List[typing.Union[typing.Dict[str, str], str]]]]:
element_children = _get_child_dicts(element_tree)
element: typing.Dict[str, typing.Any] = OrderedDict()
keys: typing.List[str] = list(element_children.keys())
for k in keys:
v: typing.Union[str, typing.List[typing.Any], typing.Dict[str, typing.Any]] = element_children[k]
if len(v) == 1 and isinstance(v, list):
l: typing.List[typing.Union[typing.Dict[str, str], str]] = v
element[k] = l[0]
else:
element[k] = v
return element
def _get_child_dicts(element: ElementTree) -> typing.Dict[str, typing.List[typing.Union[typing.Dict[str, str], str]]]:
children_dicts: typing.Dict[str, typing.List[typing.Union[typing.Dict[str, str], str]]] = OrderedDict()
children: typing.List[ElementTree] = list(element)
for child in children:
child_dict = _recursive_element_to_dict(child)
child_keys: typing.List[str] = list(child_dict.keys())
for k in child_keys:
assert k in child_dict
v: typing.Union[typing.Dict[str, str], str] = child_dict[k]
if k not in children_dicts.keys():
new_item = [v]
children_dicts[k] = new_item
else:
sublist = children_dicts[k]
assert isinstance(sublist, list)
sublist.append(v)
return children_dicts
def _recursive_element_to_dict(element_tree: ElementTree) -> typing.Dict[str, typing.Any]:
if len(element_tree):
element_result: typing.Dict[str, typing.Dict[str, typing.Union[
str, typing.List[typing.Union[str, typing.Dict[str, typing.Any]]]]]] = OrderedDict()
children_element = _get_element_children(element_tree)
if element_tree.tag is not None:
element_result[element_tree.tag] = children_element
return element_result
else:
element_text = _element_text(element_tree)
if element_text is not None:
base_element_result: typing.Dict[str, typing.Any] = OrderedDict()
if element_tree.tag is not None:
base_element_result[element_tree.tag] = element_text
return base_element_result
null_result: typing.Dict[str, str] = OrderedDict()
return null_result
def xml_to_dict(xml_str: str) -> typing.Dict[str, typing.Any]:
return _recursive_element_to_dict(parse_xml(xml_str))

View file

@ -1,32 +1,27 @@
import os # import os
# import zlib
# import base64
import logging import logging
import json import json
import asyncio import asyncio
import zlib
import base64
from collections import OrderedDict from collections import OrderedDict
from typing import Tuple, Dict, List, Union from typing import Tuple, Dict, List, Union, Optional
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.gateway import Gateway from aioupnp.gateway import Gateway
from aioupnp.util import get_gateway_and_lan_addresses from aioupnp.interfaces import get_gateway_and_lan_addresses
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
from aioupnp.commands import SOAPCommand
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.commands import SOAPCommands
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def cli(fn): # def _encode(x):
fn._cli = True # if isinstance(x, bytes):
return fn # return x.decode()
# elif isinstance(x, Exception):
# return str(x)
def _encode(x): # return x
if isinstance(x, bytes):
return x.decode()
elif isinstance(x, Exception):
return str(x)
return x
class UPnP: class UPnP:
@ -35,6 +30,10 @@ class UPnP:
self.gateway_address = gateway_address self.gateway_address = gateway_address
self.gateway = gateway self.gateway = gateway
@classmethod
def get_annotations(cls, command: str) -> Dict[str, type]:
return getattr(SOAPCommands, command).__annotations__
@classmethod @classmethod
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '', def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
interface_name: str = 'default') -> Tuple[str, str]: interface_name: str = 'default') -> Tuple[str, str]:
@ -46,21 +45,20 @@ class UPnP:
@classmethod @classmethod
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 30, async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 30,
igd_args: OrderedDict = None, interface_name: str = 'default', loop=None): igd_args: Optional[Dict[str, Union[str, int]]] = None, interface_name: str = 'default',
try: loop: Optional[asyncio.AbstractEventLoop] = None) -> 'UPnP':
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
except Exception as err:
raise UPnPError("failed to get lan and gateway addresses: %s" % str(err))
gateway = await Gateway.discover_gateway( gateway = await Gateway.discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop lan_address, gateway_address, timeout, igd_args, loop
) )
return cls(lan_address, gateway_address, gateway) return cls(lan_address, gateway_address, gateway)
@classmethod @classmethod
@cli
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
igd_args: OrderedDict = None, unicast: bool = True, interface_name: str = 'default', igd_args: Optional[Dict[str, Union[int, str]]] = None,
loop=None) -> Dict: unicast: bool = True, interface_name: str = 'default',
loop: Optional[asyncio.AbstractEventLoop] = None
) -> Dict[str, Union[str, Dict[str, Union[int, str]]]]:
if not lan_address or not gateway_address: if not lan_address or not gateway_address:
try: try:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
@ -80,93 +78,74 @@ class UPnP:
'discover_reply': datagram.as_dict() 'discover_reply': datagram.as_dict()
} }
@cli
async def get_external_ip(self) -> str: async def get_external_ip(self) -> str:
return await self.gateway.commands.GetExternalIPAddress() return await self.gateway.commands.GetExternalIPAddress()
@cli
async def add_port_mapping(self, external_port: int, protocol: str, internal_port: int, lan_address: str, async def add_port_mapping(self, external_port: int, protocol: str, internal_port: int, lan_address: str,
description: str) -> None: description: str) -> None:
return await self.gateway.commands.AddPortMapping( await self.gateway.commands.AddPortMapping(
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol, NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol,
NewInternalPort=internal_port, NewInternalClient=lan_address, NewInternalPort=internal_port, NewInternalClient=lan_address,
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration='0' NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration='0'
) )
@cli
async def get_port_mapping_by_index(self, index: int) -> Dict:
result = await self._get_port_mapping_by_index(index)
if result:
if isinstance(self.gateway.commands.GetGenericPortMappingEntry, SOAPCommand):
return {
k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result)
}
return {}
async def _get_port_mapping_by_index(self, index: int) -> Union[None, Tuple[Union[None, str], int, str,
int, str, bool, str, int]]:
try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
return redirect
except UPnPError:
return None return None
@cli async def get_port_mapping_by_index(self, index: int) -> Tuple[str, int, str, int, str, bool, str, int]:
async def get_redirects(self) -> List[Dict]: return await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
redirects = []
async def get_redirects(self) -> List[Tuple[str, int, str, int, str, bool, str, int]]:
redirects: List[Tuple[str, int, str, int, str, bool, str, int]] = []
cnt = 0 cnt = 0
try:
redirect = await self.get_port_mapping_by_index(cnt) redirect = await self.get_port_mapping_by_index(cnt)
while redirect: except UPnPError:
return redirects
while redirect is not None:
redirects.append(redirect) redirects.append(redirect)
cnt += 1 cnt += 1
try:
redirect = await self.get_port_mapping_by_index(cnt) redirect = await self.get_port_mapping_by_index(cnt)
except UPnPError:
break
return redirects return redirects
@cli async def get_specific_port_mapping(self, external_port: int, protocol: str) -> Tuple[int, str, bool, str, int]:
async def get_specific_port_mapping(self, external_port: int, protocol: str) -> Dict:
""" """
:param external_port: (int) external port to listen on :param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP' :param protocol: (str) 'UDP' | 'TCP'
:return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time> :return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time>
""" """
return await self.gateway.commands.GetSpecificPortMappingEntry(
try:
result = await self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
) )
if result and isinstance(self.gateway.commands.GetSpecificPortMappingEntry, SOAPCommand):
return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)}
except UPnPError:
pass
return {}
@cli
async def delete_port_mapping(self, external_port: int, protocol: str) -> None: async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
""" """
:param external_port: (int) external port to listen on :param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP' :param protocol: (str) 'UDP' | 'TCP'
:return: None :return: None
""" """
return await self.gateway.commands.DeletePortMapping( await self.gateway.commands.DeletePortMapping(
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
) )
return None
@cli async def get_next_mapping(self, port: int, protocol: str, description: str,
async def get_next_mapping(self, port: int, protocol: str, description: str, internal_port: int=None) -> int: internal_port: Optional[int] = None) -> int:
if protocol not in ["UDP", "TCP"]: """
raise UPnPError("unsupported protocol: {}".format(protocol)) :param port: (int) external port to redirect from
internal_port = int(internal_port or port) :param protocol: (str) 'UDP' | 'TCP'
requested_port = int(internal_port) :param description: (str) mapping description
redirect_tups = [] :param internal_port: (int) internal port to redirect to
cnt = 0
:return: (int) <mapped port>
"""
_internal_port = int(internal_port or port)
requested_port = int(_internal_port)
port = int(port) port = int(port)
redirect = await self._get_port_mapping_by_index(cnt) redirect_tups = await self.get_redirects()
while redirect:
redirect_tups.append(redirect)
cnt += 1
redirect = await self._get_port_mapping_by_index(cnt)
redirects = { redirects: Dict[Tuple[int, str], Tuple[str, int, str]] = {
(ext_port, proto): (int_host, int_port, desc) (ext_port, proto): (int_host, int_port, desc)
for (ext_host, ext_port, proto, int_port, int_host, enabled, desc, _) in redirect_tups for (ext_host, ext_port, proto, int_port, int_host, enabled, desc, _) in redirect_tups
} }
@ -176,172 +155,119 @@ class UPnP:
if int_host == self.lan_address and int_port == requested_port and desc == description: if int_host == self.lan_address and int_port == requested_port and desc == description:
return port return port
port += 1 port += 1
await self.add_port_mapping( # set one up await self.add_port_mapping(port, protocol, _internal_port, self.lan_address, description)
port, protocol, internal_port, self.lan_address, description
)
return port return port
@cli # @cli
async def debug_gateway(self) -> str: # async def debug_gateway(self) -> str:
return json.dumps({ # return json.dumps({
"gateway": self.gateway.debug_gateway(), # "gateway": self.gateway.debug_gateway(),
"client_address": self.lan_address, # "client_address": self.lan_address,
}, default=_encode, indent=2) # }, default=_encode, indent=2)
#
# @property
# def zipped_debugging_info(self) -> str:
# return base64.b64encode(zlib.compress(
# json.dumps({
# "gateway": self.gateway.debug_gateway(),
# "client_address": self.lan_address,
# }, default=_encode, indent=2).encode()
# )).decode()
#
# @cli
# async def get_natrsip_status(self) -> Tuple[bool, bool]:
# """Returns (NewRSIPAvailable, NewNATEnabled)"""
# return await self.gateway.commands.GetNATRSIPStatus()
#
# @cli
# async def set_connection_type(self, NewConnectionType: str) -> None:
# """Returns None"""
# return await self.gateway.commands.SetConnectionType(NewConnectionType)
#
# @cli
# async def get_connection_type_info(self) -> Tuple[str, str]:
# """Returns (NewConnectionType, NewPossibleConnectionTypes)"""
# return await self.gateway.commands.GetConnectionTypeInfo()
#
# @cli
# async def get_status_info(self) -> Tuple[str, str, int]:
# """Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
# return await self.gateway.commands.GetStatusInfo()
#
# @cli
# async def force_termination(self) -> None:
# """Returns None"""
# return await self.gateway.commands.ForceTermination()
#
# @cli
# async def request_connection(self) -> None:
# """Returns None"""
# return await self.gateway.commands.RequestConnection()
#
# @cli
# async def get_common_link_properties(self):
# """Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate,
# NewPhysicalLinkStatus)"""
# return await self.gateway.commands.GetCommonLinkProperties()
#
# @cli
# async def get_total_bytes_sent(self) -> int:
# """Returns (NewTotalBytesSent)"""
# return await self.gateway.commands.GetTotalBytesSent()
#
# @cli
# async def get_total_bytes_received(self):
# """Returns (NewTotalBytesReceived)"""
# return await self.gateway.commands.GetTotalBytesReceived()
#
# @cli
# async def get_total_packets_sent(self):
# """Returns (NewTotalPacketsSent)"""
# return await self.gateway.commands.GetTotalPacketsSent()
#
# @cli
# async def get_total_packets_received(self):
# """Returns (NewTotalPacketsReceived)"""
# return await self.gateway.commands.GetTotalPacketsReceived()
#
# @cli
# async def x_get_ics_statistics(self) -> Tuple[int, int, int, int, str, str]:
# """Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived,
# Layer1DownstreamMaxBitRate, Uptime)"""
# return await self.gateway.commands.X_GetICSStatistics()
#
# @cli
# async def get_default_connection_service(self):
# """Returns (NewDefaultConnectionService)"""
# return await self.gateway.commands.GetDefaultConnectionService()
#
# @cli
# async def set_default_connection_service(self, NewDefaultConnectionService: str) -> None:
# """Returns (None)"""
# return await self.gateway.commands.SetDefaultConnectionService(NewDefaultConnectionService)
#
# @cli
# async def set_enabled_for_internet(self, NewEnabledForInternet: bool) -> None:
# return await self.gateway.commands.SetEnabledForInternet(NewEnabledForInternet)
#
# @cli
# async def get_enabled_for_internet(self) -> bool:
# return await self.gateway.commands.GetEnabledForInternet()
#
# @cli
# async def get_maximum_active_connections(self, NewActiveConnectionIndex: int):
# return await self.gateway.commands.GetMaximumActiveConnections(NewActiveConnectionIndex)
#
# @cli
# async def get_active_connections(self) -> Tuple[str, str]:
# """Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
# return await self.gateway.commands.GetActiveConnections()
@property
def zipped_debugging_info(self) -> str:
return base64.b64encode(zlib.compress(
json.dumps({
"gateway": self.gateway.debug_gateway(),
"client_address": self.lan_address,
}, default=_encode, indent=2).encode()
)).decode()
@cli def run_cli(method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
async def generate_test_data(self): gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
print("found gateway via M-SEARCH") unicast: bool = True, kwargs: Optional[Dict] = None,
try: loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
external_ip = await self.get_external_ip()
print("got external ip: %s" % external_ip)
except (UPnPError, NotImplementedError):
print("failed to get the external ip")
try:
await self.get_redirects()
print("got redirects")
except (UPnPError, NotImplementedError):
print("failed to get redirects")
try:
await self.get_specific_port_mapping(4567, "UDP")
print("got specific mapping")
except (UPnPError, NotImplementedError):
print("failed to get specific mapping")
try:
ext_port = await self.get_next_mapping(4567, "UDP", "aioupnp test mapping")
print("set up external mapping to port %i" % ext_port)
try:
await self.get_specific_port_mapping(4567, "UDP")
print("got specific mapping")
except (UPnPError, NotImplementedError):
print("failed to get specific mapping")
try:
await self.get_redirects()
print("got redirects")
except (UPnPError, NotImplementedError):
print("failed to get redirects")
await self.delete_port_mapping(ext_port, "UDP")
print("deleted mapping")
except (UPnPError, NotImplementedError):
print("failed to add and remove a mapping")
try:
await self.get_redirects()
print("got redirects")
except (UPnPError, NotImplementedError):
print("failed to get redirects")
try:
await self.get_specific_port_mapping(4567, "UDP")
print("got specific mapping")
except (UPnPError, NotImplementedError):
print("failed to get specific mapping")
if self.gateway.devices:
device = list(self.gateway.devices.values())[0]
assert device.manufacturer and device.modelName
device_path = os.path.join(os.getcwd(), self.gateway.manufacturer_string)
else:
device_path = os.path.join(os.getcwd(), "UNKNOWN GATEWAY")
with open(device_path, "w") as f:
f.write(await self.debug_gateway())
return "Generated test data! -> %s" % device_path
@cli
async def get_natrsip_status(self) -> Tuple[bool, bool]:
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
return await self.gateway.commands.GetNATRSIPStatus()
@cli
async def set_connection_type(self, NewConnectionType: str) -> None:
"""Returns None"""
return await self.gateway.commands.SetConnectionType(NewConnectionType)
@cli
async def get_connection_type_info(self) -> Tuple[str, str]:
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
return await self.gateway.commands.GetConnectionTypeInfo()
@cli
async def get_status_info(self) -> Tuple[str, str, int]:
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
return await self.gateway.commands.GetStatusInfo()
@cli
async def force_termination(self) -> None:
"""Returns None"""
return await self.gateway.commands.ForceTermination()
@cli
async def request_connection(self) -> None:
"""Returns None"""
return await self.gateway.commands.RequestConnection()
@cli
async def get_common_link_properties(self):
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
return await self.gateway.commands.GetCommonLinkProperties()
@cli
async def get_total_bytes_sent(self):
"""Returns (NewTotalBytesSent)"""
return await self.gateway.commands.GetTotalBytesSent()
@cli
async def get_total_bytes_received(self):
"""Returns (NewTotalBytesReceived)"""
return await self.gateway.commands.GetTotalBytesReceived()
@cli
async def get_total_packets_sent(self):
"""Returns (NewTotalPacketsSent)"""
return await self.gateway.commands.GetTotalPacketsSent()
@cli
async def get_total_packets_received(self):
"""Returns (NewTotalPacketsReceived)"""
return await self.gateway.commands.GetTotalPacketsReceived()
@cli
async def x_get_ics_statistics(self) -> Tuple[int, int, int, int, str, str]:
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
return await self.gateway.commands.X_GetICSStatistics()
@cli
async def get_default_connection_service(self):
"""Returns (NewDefaultConnectionService)"""
return await self.gateway.commands.GetDefaultConnectionService()
@cli
async def set_default_connection_service(self, NewDefaultConnectionService: str) -> None:
"""Returns (None)"""
return await self.gateway.commands.SetDefaultConnectionService(NewDefaultConnectionService)
@cli
async def set_enabled_for_internet(self, NewEnabledForInternet: bool) -> None:
return await self.gateway.commands.SetEnabledForInternet(NewEnabledForInternet)
@cli
async def get_enabled_for_internet(self) -> bool:
return await self.gateway.commands.GetEnabledForInternet()
@cli
async def get_maximum_active_connections(self, NewActiveConnectionIndex: int):
return await self.gateway.commands.GetMaximumActiveConnections(NewActiveConnectionIndex)
@cli
async def get_active_connections(self) -> Tuple[str, str]:
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
return await self.gateway.commands.GetActiveConnections()
@classmethod
def run_cli(cls, method, igd_args: OrderedDict, lan_address: str = '', gateway_address: str = '', timeout: int = 30,
interface_name: str = 'default', unicast: bool = True, kwargs: dict = None, loop=None) -> None:
""" """
:param method: the command name :param method: the command name
:param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided :param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided
@ -352,31 +278,44 @@ class UPnP:
:param kwargs: keyword arguments for the command :param kwargs: keyword arguments for the command
:param loop: EventLoop, used for testing :param loop: EventLoop, used for testing
""" """
kwargs = kwargs or {} kwargs = kwargs or {}
igd_args = igd_args igd_args = igd_args
timeout = int(timeout) timeout = int(timeout)
loop = loop or asyncio.get_event_loop_policy().get_event_loop() loop = loop or asyncio.get_event_loop()
fut: asyncio.Future = asyncio.Future() fut: 'asyncio.Future' = asyncio.Future(loop=loop)
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
cli_commands = [
'm_search',
'get_external_ip',
'add_port_mapping',
'get_port_mapping_by_index',
'get_redirects',
'get_specific_port_mapping',
'delete_port_mapping',
'get_next_mapping'
]
if method == 'm_search': # if we're only m_searching don't do any device discovery if method == 'm_search': # if we're only m_searching don't do any device discovery
fn = lambda *_a, **_kw: cls.m_search( fn = lambda *_a, **_kw: UPnP.m_search(
lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop
) )
else: # automatically discover the gateway else: # automatically discover the gateway
try: try:
u = await cls.discover( u = await UPnP.discover(
lan_address, gateway_address, timeout, igd_args, interface_name, loop=loop lan_address, gateway_address, timeout, igd_args, interface_name, loop=loop
) )
except UPnPError as err: except UPnPError as err:
fut.set_exception(err) fut.set_exception(err)
return return
if hasattr(u, method) and hasattr(getattr(u, method), "_cli"): if method not in cli_commands:
fn = getattr(u, method)
else:
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method)) fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
return return
else:
fn = getattr(u, method)
try: # call the command try: # call the command
result = await fn(**{k: fn.__annotations__[k](v) for k, v in kwargs.items()}) result = await fn(**{k: fn.__annotations__[k](v) for k, v in kwargs.items()})
fut.set_result(result) fut.set_result(result)
@ -387,7 +326,7 @@ class UPnP:
log.exception("uncaught error") log.exception("uncaught error")
fut.set_exception(UPnPError("uncaught error: %s" % str(err))) fut.set_exception(UPnPError("uncaught error: %s" % str(err)))
if not hasattr(UPnP, method) or not hasattr(getattr(UPnP, method), "_cli"): if not hasattr(UPnP, method):
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method)) fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
else: else:
loop.run_until_complete(wrapper()) loop.run_until_complete(wrapper())
@ -398,6 +337,7 @@ class UPnP:
return return
if isinstance(result, (list, tuple, dict)): if isinstance(result, (list, tuple, dict)):
print(json.dumps(result, indent=2, default=_encode)) print(json.dumps(result, indent=2))
else: else:
print(result) print(result)
return

View file

@ -1,91 +1,49 @@
import re import typing
import socket from collections import OrderedDict
from collections import defaultdict
from typing import Tuple, Dict
from xml.etree import ElementTree
import netifaces
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode()) str_any_dict = typing.Dict[str, typing.Any]
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
def etree_to_dict(t: ElementTree.Element) -> Dict: def _recursive_flatten(to_flatten: typing.Any, strip: str) -> typing.Any:
d: dict = {} if not isinstance(to_flatten, (list, dict)):
if t.attrib: return to_flatten
d[t.tag] = {} if isinstance(to_flatten, list):
children = list(t) assert isinstance(to_flatten, list)
if children: return [_recursive_flatten(i, strip) for i in to_flatten]
dd: dict = defaultdict(list) assert isinstance(to_flatten, dict)
for dc in map(etree_to_dict, children): keys: typing.List[str] = list(to_flatten.keys())
for k, v in dc.items(): copy: str_any_dict = OrderedDict()
dd[k].append(v) for k in keys:
d[t.tag] = {k: v[0] if len(v) == 1 else v for k, v in dd.items()} item: typing.Any = to_flatten[k]
if t.attrib:
d[t.tag].update(('@' + k, v) for k, v in t.attrib.items())
if t.text:
text = t.text.strip()
if children or t.attrib:
if text:
d[t.tag]['#text'] = text
else:
d[t.tag] = text
return d
def flatten_keys(d, strip):
if not isinstance(d, (list, dict)):
return d
if isinstance(d, list):
return [flatten_keys(i, strip) for i in d]
t = {}
for k, v in d.items():
if strip in k and strip != k: if strip in k and strip != k:
t[k.split(strip)[1]] = flatten_keys(v, strip) copy[k.split(strip)[1]] = _recursive_flatten(item, strip)
else: else:
t[k] = flatten_keys(v, strip) copy[k] = _recursive_flatten(item, strip)
return t return copy
def get_dict_val_case_insensitive(d, k): def flatten_keys(to_flatten: str_any_dict, strip: str) -> str_any_dict:
match = list(filter(lambda x: x.lower() == k.lower(), d.keys())) keys: typing.List[str] = list(to_flatten.keys())
if not match: copy: str_any_dict = OrderedDict()
return for k in keys:
item = to_flatten[k]
if strip in k and strip != k:
new_key: str = k.split(strip)[1]
copy[new_key] = _recursive_flatten(item, strip)
else:
copy[k] = _recursive_flatten(item, strip)
return copy
def get_dict_val_case_insensitive(source: typing.Dict[typing.AnyStr, typing.AnyStr],
key: typing.AnyStr) -> typing.Optional[typing.AnyStr]:
match: typing.List[typing.AnyStr] = list(filter(lambda x: x.lower() == key.lower(), source.keys()))
if not len(match):
return None
if len(match) > 1: if len(match) > 1:
raise KeyError("overlapping keys") raise KeyError("overlapping keys")
return d[match[0]] if len(match) == 1:
matched_key: typing.AnyStr = match[0]
# import struct return source[matched_key]
# import fcntl raise KeyError("overlapping keys")
# def get_ip_address(ifname):
# SIOCGIFADDR = 0x8915
# s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# return socket.inet_ntoa(fcntl.ioctl(
# s.fileno(),
# SIOCGIFADDR,
# struct.pack(b'256s', ifname[:15].encode())
# )[20:24])
def get_interfaces():
r = {
interface_name: (router_address, netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['addr'])
for router_address, interface_name, _ in netifaces.gateways()[socket.AF_INET]
}
for interface_name in netifaces.interfaces():
if interface_name in ['lo', 'localhost'] or interface_name in r:
continue
addresses = netifaces.ifaddresses(interface_name)
if netifaces.AF_INET in addresses:
address = addresses[netifaces.AF_INET][0]['addr']
gateway_guess = ".".join(address.split(".")[:-1] + ["1"])
r[interface_name] = (gateway_guess, address)
r['default'] = r[netifaces.gateways()['default'][netifaces.AF_INET][1]]
return r
def get_gateway_and_lan_addresses(interface_name: str) -> Tuple[str, str]:
for iface_name, (gateway, lan) in get_interfaces().items():
if interface_name == iface_name:
return gateway, lan
return '', ''

View file

@ -37,7 +37,7 @@ setup(
packages=find_packages(exclude=('tests',)), packages=find_packages(exclude=('tests',)),
entry_points={'console_scripts': console_scripts}, entry_points={'console_scripts': console_scripts},
install_requires=[ install_requires=[
'netifaces', 'netifaces', 'defusedxml'
], ],
extras_require={ extras_require={
'test': ( 'test': (

23
stubs/defusedxml.py Normal file
View file

@ -0,0 +1,23 @@
import typing
class ElementTree:
tag: typing.Optional[str] = None
"""The element's name."""
attrib: typing.Optional[typing.Dict[str, str]] = None
"""Dictionary of the element's attributes."""
text: typing.Optional[str] = None
tail: typing.Optional[str] = None
def __len__(self) -> int:
raise NotImplementedError()
def __iter__(self) -> typing.Iterator['ElementTree']:
raise NotImplementedError()
@classmethod
def fromstring(cls, xml_str: str) -> 'ElementTree':
raise NotImplementedError()

View file

@ -36,8 +36,9 @@ version = '0.10.7'
# functions # functions
def gateways(*args, **kwargs) -> typing.Dict[typing.Union[str, int],
def gateways(*args, **kwargs) -> typing.List: # real signature unknown typing.Union[typing.Dict[int, typing.Tuple[str, str]],
typing.List[typing.Tuple[str, str, bool]]]]:
""" """
Obtain a list of the gateways on this machine. Obtain a list of the gateways on this machine.
@ -56,7 +57,7 @@ def gateways(*args, **kwargs) -> typing.List: # real signature unknown
pass pass
def ifaddresses(*args, **kwargs) -> typing.Dict: # real signature unknown def ifaddresses(*args, **kwargs) -> typing.Dict[int, typing.List[typing.Dict[str, str]]]:
""" """
Obtain information about the specified network interface. Obtain information about the specified network interface.
@ -67,7 +68,7 @@ def ifaddresses(*args, **kwargs) -> typing.Dict: # real signature unknown
pass pass
def interfaces(*args, **kwargs) -> typing.List: # real signature unknown def interfaces(*args, **kwargs) -> typing.List[str]:
""" Obtain a list of the interfaces available on this machine. """ """ Obtain a list of the interfaces available on this machine. """
pass pass

View file

@ -1,7 +1,11 @@
import asyncio import asyncio
import unittest import unittest
import contextlib
import socket
from unittest import mock
from unittest.case import _Outcome from unittest.case import _Outcome
try: try:
from asyncio.runners import _cancel_all_tasks from asyncio.runners import _cancel_all_tasks
except ImportError: except ImportError:
@ -10,30 +14,92 @@ except ImportError:
pass pass
class TestBase(unittest.TestCase): @contextlib.contextmanager
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None):
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
udp_replies = udp_replies or {}
sent_tcp_packets = sent_tcp_packets if sent_tcp_packets is not None else []
tcp_replies = tcp_replies or {}
async def create_connection(protocol_factory, host=None, port=None):
def write(p: asyncio.Protocol):
def _write(data):
sent_tcp_packets.append(data)
if data in tcp_replies:
loop.call_later(tcp_delay_reply, p.data_received, tcp_replies[data])
return
else:
pass
return _write
protocol = protocol_factory()
transport = asyncio.Transport(extra={'socket': mock.Mock(spec=socket.socket)})
transport.close = lambda: None
transport.write = write(protocol)
protocol.connection_made(transport)
return transport, protocol
async def create_datagram_endpoint(proto_lam, sock=None):
def sendto(p: asyncio.DatagramProtocol):
def _sendto(data, addr):
sent_udp_packets.append(data)
if (data, addr) in udp_replies:
loop.call_later(udp_delay_reply, p.datagram_received, udp_replies[(data, addr)],
(udp_expected_addr, 1900))
return _sendto
protocol = proto_lam()
transport = asyncio.DatagramTransport(extra={'socket': mock_sock})
transport.close = lambda: mock_sock.close()
mock_sock.sendto = sendto(protocol)
transport.sendto = mock_sock.sendto
protocol.connection_made(transport)
return transport, protocol
with mock.patch('socket.socket') as mock_socket:
mock_sock = mock.Mock(spec=socket.socket)
mock_sock.setsockopt = lambda *_: None
mock_sock.bind = lambda *_: None
mock_sock.setblocking = lambda *_: None
mock_sock.getsockname = lambda: "0.0.0.0"
mock_sock.getpeername = lambda: ""
mock_sock.close = lambda: None
mock_sock.type = socket.SOCK_DGRAM
mock_sock.fileno = lambda: 7
mock_socket.return_value = mock_sock
loop.create_datagram_endpoint = create_datagram_endpoint
loop.create_connection = create_connection
yield
class AsyncioTestCase(unittest.TestCase):
# Implementation inspired by discussion: # Implementation inspired by discussion:
# https://bugs.python.org/issue32972 # https://bugs.python.org/issue32972
async def asyncSetUp(self): maxDiff = None
async def asyncSetUp(self): # pylint: disable=C0103
pass pass
async def asyncTearDown(self): async def asyncTearDown(self): # pylint: disable=C0103
pass pass
async def doAsyncCleanups(self): def run(self, result=None): # pylint: disable=R0915
pass
def run(self, result=None):
orig_result = result orig_result = result
if result is None: if result is None:
result = self.defaultTestResult() result = self.defaultTestResult()
startTestRun = getattr(result, 'startTestRun', None) startTestRun = getattr(result, 'startTestRun', None) # pylint: disable=C0103
if startTestRun is not None: if startTestRun is not None:
startTestRun() startTestRun()
result.startTest(self) result.startTest(self)
testMethod = getattr(self, self._testMethodName) testMethod = getattr(self, self._testMethodName) # pylint: disable=C0103
if (getattr(self.__class__, "__unittest_skip__", False) or if (getattr(self.__class__, "__unittest_skip__", False) or
getattr(testMethod, "__unittest_skip__", False)): getattr(testMethod, "__unittest_skip__", False)):
# If the class or method was skipped. # If the class or method was skipped.
@ -50,36 +116,36 @@ class TestBase(unittest.TestCase):
"__unittest_expecting_failure__", False) "__unittest_expecting_failure__", False)
expecting_failure = expecting_failure_class or expecting_failure_method expecting_failure = expecting_failure_class or expecting_failure_method
outcome = _Outcome(result) outcome = _Outcome(result)
self.loop = asyncio.new_event_loop() # pylint: disable=W0201
asyncio.set_event_loop(self.loop)
self.loop.set_debug(True)
try: try:
self._outcome = outcome self._outcome = outcome
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
loop.set_debug(True)
with outcome.testPartExecutor(self): with outcome.testPartExecutor(self):
self.setUp() self.setUp()
loop.run_until_complete(self.asyncSetUp()) self.loop.run_until_complete(self.asyncSetUp())
if outcome.success: if outcome.success:
outcome.expecting_failure = expecting_failure outcome.expecting_failure = expecting_failure
with outcome.testPartExecutor(self, isTest=True): with outcome.testPartExecutor(self, isTest=True):
possible_coroutine = testMethod() maybe_coroutine = testMethod()
if asyncio.iscoroutine(possible_coroutine): if asyncio.iscoroutine(maybe_coroutine):
loop.run_until_complete(possible_coroutine) self.loop.run_until_complete(maybe_coroutine)
outcome.expecting_failure = False outcome.expecting_failure = False
with outcome.testPartExecutor(self): with outcome.testPartExecutor(self):
loop.run_until_complete(self.asyncTearDown()) self.loop.run_until_complete(self.asyncTearDown())
self.tearDown() self.tearDown()
finally:
self.doAsyncCleanups()
try: try:
_cancel_all_tasks(loop) _cancel_all_tasks(self.loop)
loop.run_until_complete(loop.shutdown_asyncgens()) self.loop.run_until_complete(self.loop.shutdown_asyncgens())
finally: finally:
asyncio.set_event_loop(None) asyncio.set_event_loop(None)
loop.close() self.loop.close()
self.doCleanups()
for test, reason in outcome.skipped: for test, reason in outcome.skipped:
self._addSkip(result, test, reason) self._addSkip(result, test, reason)
@ -96,9 +162,9 @@ class TestBase(unittest.TestCase):
finally: finally:
result.stopTest(self) result.stopTest(self)
if orig_result is None: if orig_result is None:
stopTestRun = getattr(result, 'stopTestRun', None) stopTestRun = getattr(result, 'stopTestRun', None) # pylint: disable=C0103
if stopTestRun is not None: if stopTestRun is not None:
stopTestRun() stopTestRun() # pylint: disable=E1102
# explicitly break reference cycles: # explicitly break reference cycles:
# outcome.errors -> frame -> outcome -> outcome.errors # outcome.errors -> frame -> outcome -> outcome.errors
@ -109,5 +175,11 @@ class TestBase(unittest.TestCase):
# clear the outcome, no more needed # clear the outcome, no more needed
self._outcome = None self._outcome = None
def setUp(self): def doAsyncCleanups(self): # pylint: disable=C0103
self.loop = asyncio.get_event_loop_policy().get_event_loop() outcome = self._outcome or _Outcome()
while self._cleanups:
function, args, kwargs = self._cleanups.pop()
with outcome.testPartExecutor(self):
maybe_coroutine = function(*args, **kwargs)
if asyncio.iscoroutine(maybe_coroutine):
self.loop.run_until_complete(maybe_coroutine)

0
tests/generate_test.py Normal file
View file

View file

@ -1,64 +0,0 @@
import asyncio
import contextlib
import socket
import mock
@contextlib.contextmanager
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None):
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
udp_replies = udp_replies or {}
sent_tcp_packets = sent_tcp_packets if sent_tcp_packets is not None else []
tcp_replies = tcp_replies or {}
async def create_connection(protocol_factory, host=None, port=None):
def write(p: asyncio.Protocol):
def _write(data):
sent_tcp_packets.append(data)
if data in tcp_replies:
loop.call_later(tcp_delay_reply, p.data_received, tcp_replies[data])
return _write
protocol = protocol_factory()
transport = asyncio.Transport(extra={'socket': mock.Mock(spec=socket.socket)})
transport.close = lambda: None
transport.write = write(protocol)
protocol.connection_made(transport)
return transport, protocol
async def create_datagram_endpoint(proto_lam, sock=None):
def sendto(p: asyncio.DatagramProtocol):
def _sendto(data, addr):
sent_udp_packets.append(data)
if (data, addr) in udp_replies:
loop.call_later(udp_delay_reply, p.datagram_received, udp_replies[(data, addr)],
(udp_expected_addr, 1900))
return _sendto
protocol = proto_lam()
transport = asyncio.DatagramTransport(extra={'socket': mock_sock})
transport.close = lambda: mock_sock.close()
mock_sock.sendto = sendto(protocol)
transport.sendto = mock_sock.sendto
protocol.connection_made(transport)
return transport, protocol
with mock.patch('socket.socket') as mock_socket:
mock_sock = mock.Mock(spec=socket.socket)
mock_sock.setsockopt = lambda *_: None
mock_sock.bind = lambda *_: None
mock_sock.setblocking = lambda *_: None
mock_sock.getsockname = lambda: "0.0.0.0"
mock_sock.getpeername = lambda: ""
mock_sock.close = lambda: None
mock_sock.type = socket.SOCK_DGRAM
mock_sock.fileno = lambda: 7
mock_socket.return_value = mock_sock
loop.create_datagram_endpoint = create_datagram_endpoint
loop.create_connection = create_connection
yield

View file

@ -0,0 +1,23 @@
import unittest
from asyncio import DatagramTransport
from aioupnp.protocols.multicast import MulticastProtocol
class TestMulticast(unittest.TestCase):
def test_it(self):
class none_socket:
sock = None
def get(self, name, default=None):
return default
protocol = MulticastProtocol('1.2.3.4', '1.2.3.4')
transport = DatagramTransport(none_socket())
protocol.set_ttl(1)
with self.assertRaises(ValueError):
_ = protocol.get_ttl()
protocol.connection_made(transport)
protocol.set_ttl(1)
with self.assertRaises(ValueError):
_ = protocol.get_ttl()

View file

@ -1,10 +1,9 @@
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.protocols.scpd import scpd_post, scpd_get from aioupnp.protocols.scpd import scpd_post, scpd_get
from tests import TestBase from tests import AsyncioTestCase, mock_tcp_and_udp
from tests.mocks import mock_tcp_and_udp
class TestSCPDGet(TestBase): class TestSCPDGet(AsyncioTestCase):
path, lan_address, port = '/IGDdevicedesc_brlan0.xml', '10.1.10.1', 49152 path, lan_address, port = '/IGDdevicedesc_brlan0.xml', '10.1.10.1', 49152
get_request = b'GET /IGDdevicedesc_brlan0.xml HTTP/1.1\r\n' \ get_request = b'GET /IGDdevicedesc_brlan0.xml HTTP/1.1\r\n' \
b'Accept-Encoding: gzip\r\nHost: 10.1.10.1\r\nConnection: Close\r\n\r\n' b'Accept-Encoding: gzip\r\nHost: 10.1.10.1\r\nConnection: Close\r\n\r\n'
@ -142,7 +141,7 @@ class TestSCPDGet(TestBase):
self.assertTrue(str(err).startswith('too many bytes written')) self.assertTrue(str(err).startswith('too many bytes written'))
class TestSCPDPost(TestBase): class TestSCPDPost(AsyncioTestCase):
param_names: list = [] param_names: list = []
kwargs: dict = {} kwargs: dict = {}
method, gateway_address, port = "GetExternalIPAddress", '10.0.0.1', 49152 method, gateway_address, port = "GetExternalIPAddress", '10.0.0.1', 49152

View file

@ -4,11 +4,10 @@ from aioupnp.protocols.m_search_patterns import packet_generator
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.constants import SSDP_IP_ADDRESS from aioupnp.constants import SSDP_IP_ADDRESS
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search from aioupnp.protocols.ssdp import fuzzy_m_search, m_search
from tests import TestBase from tests import AsyncioTestCase, mock_tcp_and_udp
from tests.mocks import mock_tcp_and_udp
class TestSSDP(TestBase): class TestSSDP(AsyncioTestCase):
packet_args = list(packet_generator()) packet_args = list(packet_generator())
byte_packets = [SSDPDatagram("M-SEARCH", p).encode().encode() for p in packet_args] byte_packets = [SSDPDatagram("M-SEARCH", p).encode().encode() for p in packet_args]

View file

@ -1,5 +1,7 @@
import unittest import unittest
from aioupnp.fault import UPnPError
from aioupnp.serialization.scpd import serialize_scpd_get, deserialize_scpd_get_response from aioupnp.serialization.scpd import serialize_scpd_get, deserialize_scpd_get_response
from aioupnp.serialization.xml import xml_to_dict
from aioupnp.device import Device from aioupnp.device import Device
from aioupnp.util import get_dict_val_case_insensitive from aioupnp.util import get_dict_val_case_insensitive
@ -20,6 +22,28 @@ class TestSCPDSerialization(unittest.TestCase):
b"\r\n" \ b"\r\n" \
b"<?xml version=\"1.0\"?>\n<root xmlns=\"urn:schemas-upnp-org:device-1-0\">\n<specVersion>\n<major>1</major>\n<minor>0</minor>\n</specVersion>\n<device>\n<deviceType>urn:schemas-upnp-org:device:InternetGatewayDevice:1</deviceType>\n<friendlyName>CGA4131COM</friendlyName>\n<manufacturer>Cisco</manufacturer>\n<manufacturerURL>http://www.cisco.com/</manufacturerURL>\n<modelDescription>CGA4131COM</modelDescription>\n<modelName>CGA4131COM</modelName>\n<modelNumber>CGA4131COM</modelNumber>\n<modelURL>http://www.cisco.com</modelURL>\n<serialNumber></serialNumber>\n<UDN>uuid:11111111-2222-3333-4444-555555555556</UDN>\n<UPC>CGA4131COM</UPC>\n<serviceList>\n<service>\n<serviceType>urn:schemas-upnp-org:service:Layer3Forwarding:1</serviceType>\n<serviceId>urn:upnp-org:serviceId:L3Forwarding1</serviceId>\n<SCPDURL>/Layer3ForwardingSCPD.xml</SCPDURL>\n<controlURL>/upnp/control/Layer3Forwarding</controlURL>\n<eventSubURL>/upnp/event/Layer3Forwarding</eventSubURL>\n</service>\n</serviceList>\n<deviceList>\n<device>\n<deviceType>urn:schemas-upnp-org:device:WANDevice:1</deviceType>\n<friendlyName>WANDevice:1</friendlyName>\n<manufacturer>Cisco</manufacturer>\n<manufacturerURL>http://www.cisco.com/</manufacturerURL>\n<modelDescription>CGA4131COM</modelDescription>\n<modelName>CGA4131COM</modelName>\n<modelNumber>CGA4131COM</modelNumber>\n<modelURL>http://www.cisco.com</modelURL>\n<serialNumber></serialNumber>\n<UDN>uuid:11111111-2222-3333-4444-555555555556</UDN>\n<UPC>CGA4131COM</UPC>\n<serviceList>\n<service>\n<serviceType>urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1</serviceType>\n<serviceId>urn:upnp-org:serviceId:WANCommonIFC1</serviceId>\n<SCPDURL>/WANCommonInterfaceConfigSCPD.xml</SCPDURL>\n<controlURL>/upnp/control/WANCommonInterfaceConfig0</controlURL>\n<eventSubURL>/upnp/event/WANCommonInterfaceConfig0</eventSubURL>\n</service>\n</serviceList>\n<deviceList>\n <device>\n <deviceType>urn:schemas-upnp-org:device:WANConnectionDevice:1</deviceType>\n <friendlyName>WANConnectionDevice:1</friendlyName>\n <manufacturer>Cisco</manufacturer>\n <manufacturerURL>http://www.cisco.com/</manufacturerURL>\n <modelDescription>CGA4131COM</modelDescription>\n <modelName>CGA4131COM</modelName>\n <modelNumber>CGA4131COM</modelNumber>\n <modelURL>http://www.cisco.com</modelURL>\n <serialNumber></serialNumber>\n <UDN>uuid:11111111-2222-3333-4444-555555555555</UDN>\n <UPC>CGA4131COM</UPC>\n <serviceList>\n <service>\n <serviceType>urn:schemas-upnp-org:service:WANIPConnection:1</serviceType>\n <serviceId>urn:upnp-org:serviceId:WANIPConn1</serviceId>\n <SCPDURL>/WANIPConnectionServiceSCPD.xml</SCPDURL>\n <controlURL>/upnp/control/WANIPConnection0</controlURL>\n <eventSubURL>/upnp/event/WANIPConnection0</eventSubURL>\n </service>\n </serviceList>\n </device>\n</deviceList>\n</device>\n</deviceList>\n<presentationURL>http://10.1.10.1/</presentationURL></device>\n</root>\n" b"<?xml version=\"1.0\"?>\n<root xmlns=\"urn:schemas-upnp-org:device-1-0\">\n<specVersion>\n<major>1</major>\n<minor>0</minor>\n</specVersion>\n<device>\n<deviceType>urn:schemas-upnp-org:device:InternetGatewayDevice:1</deviceType>\n<friendlyName>CGA4131COM</friendlyName>\n<manufacturer>Cisco</manufacturer>\n<manufacturerURL>http://www.cisco.com/</manufacturerURL>\n<modelDescription>CGA4131COM</modelDescription>\n<modelName>CGA4131COM</modelName>\n<modelNumber>CGA4131COM</modelNumber>\n<modelURL>http://www.cisco.com</modelURL>\n<serialNumber></serialNumber>\n<UDN>uuid:11111111-2222-3333-4444-555555555556</UDN>\n<UPC>CGA4131COM</UPC>\n<serviceList>\n<service>\n<serviceType>urn:schemas-upnp-org:service:Layer3Forwarding:1</serviceType>\n<serviceId>urn:upnp-org:serviceId:L3Forwarding1</serviceId>\n<SCPDURL>/Layer3ForwardingSCPD.xml</SCPDURL>\n<controlURL>/upnp/control/Layer3Forwarding</controlURL>\n<eventSubURL>/upnp/event/Layer3Forwarding</eventSubURL>\n</service>\n</serviceList>\n<deviceList>\n<device>\n<deviceType>urn:schemas-upnp-org:device:WANDevice:1</deviceType>\n<friendlyName>WANDevice:1</friendlyName>\n<manufacturer>Cisco</manufacturer>\n<manufacturerURL>http://www.cisco.com/</manufacturerURL>\n<modelDescription>CGA4131COM</modelDescription>\n<modelName>CGA4131COM</modelName>\n<modelNumber>CGA4131COM</modelNumber>\n<modelURL>http://www.cisco.com</modelURL>\n<serialNumber></serialNumber>\n<UDN>uuid:11111111-2222-3333-4444-555555555556</UDN>\n<UPC>CGA4131COM</UPC>\n<serviceList>\n<service>\n<serviceType>urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1</serviceType>\n<serviceId>urn:upnp-org:serviceId:WANCommonIFC1</serviceId>\n<SCPDURL>/WANCommonInterfaceConfigSCPD.xml</SCPDURL>\n<controlURL>/upnp/control/WANCommonInterfaceConfig0</controlURL>\n<eventSubURL>/upnp/event/WANCommonInterfaceConfig0</eventSubURL>\n</service>\n</serviceList>\n<deviceList>\n <device>\n <deviceType>urn:schemas-upnp-org:device:WANConnectionDevice:1</deviceType>\n <friendlyName>WANConnectionDevice:1</friendlyName>\n <manufacturer>Cisco</manufacturer>\n <manufacturerURL>http://www.cisco.com/</manufacturerURL>\n <modelDescription>CGA4131COM</modelDescription>\n <modelName>CGA4131COM</modelName>\n <modelNumber>CGA4131COM</modelNumber>\n <modelURL>http://www.cisco.com</modelURL>\n <serialNumber></serialNumber>\n <UDN>uuid:11111111-2222-3333-4444-555555555555</UDN>\n <UPC>CGA4131COM</UPC>\n <serviceList>\n <service>\n <serviceType>urn:schemas-upnp-org:service:WANIPConnection:1</serviceType>\n <serviceId>urn:upnp-org:serviceId:WANIPConn1</serviceId>\n <SCPDURL>/WANIPConnectionServiceSCPD.xml</SCPDURL>\n <controlURL>/upnp/control/WANIPConnection0</controlURL>\n <eventSubURL>/upnp/event/WANIPConnection0</eventSubURL>\n </service>\n </serviceList>\n </device>\n</deviceList>\n</device>\n</deviceList>\n<presentationURL>http://10.1.10.1/</presentationURL></device>\n</root>\n"
response_bad_root_device_name = b"HTTP/1.1 200 OK\r\n" \
b"CONTENT-LENGTH: 2972\r\n" \
b"CONTENT-TYPE: text/xml\r\n" \
b"DATE: Thu, 18 Oct 2018 01:20:23 GMT\r\n" \
b"LAST-MODIFIED: Fri, 28 Sep 2018 18:35:48 GMT\r\n" \
b"SERVER: Linux/3.14.28-Prod_17.2, UPnP/1.0, Portable SDK for UPnP devices/1.6.22\r\n" \
b"X-User-Agent: redsonic\r\n" \
b"CONNECTION: close\r\n" \
b"\r\n" \
b"<?xml version=\"1.0\"?>\n<root xmlns=\"urn:schemas-upnp-org:device-1-?\">\n<specVersion>\n<major>1</major>\n<minor>0</minor>\n</specVersion>\n<device>\n<deviceType>urn:schemas-upnp-org:device:InternetGatewayDevic3:1</deviceType>\n<friendlyName>CGA4131COM</friendlyName>\n<manufacturer>Cisco</manufacturer>\n<manufacturerURL>http://www.cisco.com/</manufacturerURL>\n<modelDescription>CGA4131COM</modelDescription>\n<modelName>CGA4131COM</modelName>\n<modelNumber>CGA4131COM</modelNumber>\n<modelURL>http://www.cisco.com</modelURL>\n<serialNumber></serialNumber>\n<UDN>uuid:11111111-2222-3333-4444-555555555556</UDN>\n<UPC>CGA4131COM</UPC>\n<serviceList>\n<service>\n<serviceType>urn:schemas-upnp-org:service:Layer3Forwarding:1</serviceType>\n<serviceId>urn:upnp-org:serviceId:L3Forwarding1</serviceId>\n<SCPDURL>/Layer3ForwardingSCPD.xml</SCPDURL>\n<controlURL>/upnp/control/Layer3Forwarding</controlURL>\n<eventSubURL>/upnp/event/Layer3Forwarding</eventSubURL>\n</service>\n</serviceList>\n<deviceList>\n<device>\n<deviceType>urn:schemas-upnp-org:device:WANDevice:1</deviceType>\n<friendlyName>WANDevice:1</friendlyName>\n<manufacturer>Cisco</manufacturer>\n<manufacturerURL>http://www.cisco.com/</manufacturerURL>\n<modelDescription>CGA4131COM</modelDescription>\n<modelName>CGA4131COM</modelName>\n<modelNumber>CGA4131COM</modelNumber>\n<modelURL>http://www.cisco.com</modelURL>\n<serialNumber></serialNumber>\n<UDN>uuid:11111111-2222-3333-4444-555555555556</UDN>\n<UPC>CGA4131COM</UPC>\n<serviceList>\n<service>\n<serviceType>urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1</serviceType>\n<serviceId>urn:upnp-org:serviceId:WANCommonIFC1</serviceId>\n<SCPDURL>/WANCommonInterfaceConfigSCPD.xml</SCPDURL>\n<controlURL>/upnp/control/WANCommonInterfaceConfig0</controlURL>\n<eventSubURL>/upnp/event/WANCommonInterfaceConfig0</eventSubURL>\n</service>\n</serviceList>\n<deviceList>\n <device>\n <deviceType>urn:schemas-upnp-org:device:WANConnectionDevice:1</deviceType>\n <friendlyName>WANConnectionDevice:1</friendlyName>\n <manufacturer>Cisco</manufacturer>\n <manufacturerURL>http://www.cisco.com/</manufacturerURL>\n <modelDescription>CGA4131COM</modelDescription>\n <modelName>CGA4131COM</modelName>\n <modelNumber>CGA4131COM</modelNumber>\n <modelURL>http://www.cisco.com</modelURL>\n <serialNumber></serialNumber>\n <UDN>uuid:11111111-2222-3333-4444-555555555555</UDN>\n <UPC>CGA4131COM</UPC>\n <serviceList>\n <service>\n <serviceType>urn:schemas-upnp-org:service:WANIPConnection:1</serviceType>\n <serviceId>urn:upnp-org:serviceId:WANIPConn1</serviceId>\n <SCPDURL>/WANIPConnectionServiceSCPD.xml</SCPDURL>\n <controlURL>/upnp/control/WANIPConnection0</controlURL>\n <eventSubURL>/upnp/event/WANIPConnection0</eventSubURL>\n </service>\n </serviceList>\n </device>\n</deviceList>\n</device>\n</deviceList>\n<presentationURL>http://10.1.10.1/</presentationURL></device>\n</root>\n"
response_bad_root_xmls = b"HTTP/1.1 200 OK\r\n" \
b"CONTENT-LENGTH: 2972\r\n" \
b"CONTENT-TYPE: text/xml\r\n" \
b"DATE: Thu, 18 Oct 2018 01:20:23 GMT\r\n" \
b"LAST-MODIFIED: Fri, 28 Sep 2018 18:35:48 GMT\r\n" \
b"SERVER: Linux/3.14.28-Prod_17.2, UPnP/1.0, Portable SDK for UPnP devices/1.6.22\r\n" \
b"X-User-Agent: redsonic\r\n" \
b"CONNECTION: close\r\n" \
b"\r\n" \
b"<?xml version=\"1.0\"?>\n<root xmlns=\"urn:schemas-upnp--org:device-1-0\">\n<specVersion>\n<major>1</major>\n<minor>0</minor>\n</specVersion>\n<device>\n<deviceType>urn:schemas-upnp-org:device:InternetGatewayDevic3:1</deviceType>\n<friendlyName>CGA4131COM</friendlyName>\n<manufacturer>Cisco</manufacturer>\n<manufacturerURL>http://www.cisco.com/</manufacturerURL>\n<modelDescription>CGA4131COM</modelDescription>\n<modelName>CGA4131COM</modelName>\n<modelNumber>CGA4131COM</modelNumber>\n<modelURL>http://www.cisco.com</modelURL>\n<serialNumber></serialNumber>\n<UDN>uuid:11111111-2222-3333-4444-555555555556</UDN>\n<UPC>CGA4131COM</UPC>\n<serviceList>\n<service>\n<serviceType>urn:schemas-upnp-org:service:Layer3Forwarding:1</serviceType>\n<serviceId>urn:upnp-org:serviceId:L3Forwarding1</serviceId>\n<SCPDURL>/Layer3ForwardingSCPD.xml</SCPDURL>\n<controlURL>/upnp/control/Layer3Forwarding</controlURL>\n<eventSubURL>/upnp/event/Layer3Forwarding</eventSubURL>\n</service>\n</serviceList>\n<deviceList>\n<device>\n<deviceType>urn:schemas-upnp-org:device:WANDevice:1</deviceType>\n<friendlyName>WANDevice:1</friendlyName>\n<manufacturer>Cisco</manufacturer>\n<manufacturerURL>http://www.cisco.com/</manufacturerURL>\n<modelDescription>CGA4131COM</modelDescription>\n<modelName>CGA4131COM</modelName>\n<modelNumber>CGA4131COM</modelNumber>\n<modelURL>http://www.cisco.com</modelURL>\n<serialNumber></serialNumber>\n<UDN>uuid:11111111-2222-3333-4444-555555555556</UDN>\n<UPC>CGA4131COM</UPC>\n<serviceList>\n<service>\n<serviceType>urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1</serviceType>\n<serviceId>urn:upnp-org:serviceId:WANCommonIFC1</serviceId>\n<SCPDURL>/WANCommonInterfaceConfigSCPD.xml</SCPDURL>\n<controlURL>/upnp/control/WANCommonInterfaceConfig0</controlURL>\n<eventSubURL>/upnp/event/WANCommonInterfaceConfig0</eventSubURL>\n</service>\n</serviceList>\n<deviceList>\n <device>\n <deviceType>urn:schemas-upnp-org:device:WANConnectionDevice:1</deviceType>\n <friendlyName>WANConnectionDevice:1</friendlyName>\n <manufacturer>Cisco</manufacturer>\n <manufacturerURL>http://www.cisco.com/</manufacturerURL>\n <modelDescription>CGA4131COM</modelDescription>\n <modelName>CGA4131COM</modelName>\n <modelNumber>CGA4131COM</modelNumber>\n <modelURL>http://www.cisco.com</modelURL>\n <serialNumber></serialNumber>\n <UDN>uuid:11111111-2222-3333-4444-555555555555</UDN>\n <UPC>CGA4131COM</UPC>\n <serviceList>\n <service>\n <serviceType>urn:schemas-upnp-org:service:WANIPConnection:1</serviceType>\n <serviceId>urn:upnp-org:serviceId:WANIPConn1</serviceId>\n <SCPDURL>/WANIPConnectionServiceSCPD.xml</SCPDURL>\n <controlURL>/upnp/control/WANIPConnection0</controlURL>\n <eventSubURL>/upnp/event/WANIPConnection0</eventSubURL>\n </service>\n </serviceList>\n </device>\n</deviceList>\n</device>\n</deviceList>\n<presentationURL>http://10.1.10.1/</presentationURL></device>\n</root>\n"
expected_parsed = { expected_parsed = {
'specVersion': {'major': '1', 'minor': '0'}, 'specVersion': {'major': '1', 'minor': '0'},
'device': { 'device': {
@ -94,6 +118,87 @@ class TestSCPDSerialization(unittest.TestCase):
def test_serialize_get(self): def test_serialize_get(self):
self.assertEqual(serialize_scpd_get(self.path, self.lan_address), self.get_request) self.assertEqual(serialize_scpd_get(self.path, self.lan_address), self.get_request)
self.assertEqual(serialize_scpd_get(self.path, 'http://' + self.lan_address), self.get_request)
self.assertEqual(serialize_scpd_get(self.path, 'http://' + self.lan_address + ':1337'), self.get_request)
self.assertEqual(serialize_scpd_get(self.path, self.lan_address + ':1337'), self.get_request)
def test_parse_device_response_xml(self):
self.assertDictEqual(
xml_to_dict('<?xml version="1.0"?>\n<root xmlns="urn:schemas-upnp-org:device-1-0">\n\t<specVersion>\n\t\t<major>1</major>\n\t\t<minor>0</minor>\n\t</specVersion>\n\t<URLBase>http://10.0.0.1:49152</URLBase>\n\t<device>\n\t\t<deviceType>urn:schemas-upnp-org:device:InternetGatewayDevice:1</deviceType>\n\t\t<friendlyName>Wireless Broadband Router</friendlyName>\n\t\t<manufacturer>D-Link Corporation</manufacturer>\n\t\t<manufacturerURL>http://www.dlink.com</manufacturerURL>\n\t\t<modelDescription>D-Link Router</modelDescription>\n\t\t<modelName>D-Link Router</modelName>\n\t\t<modelNumber>DIR-890L</modelNumber>\n\t\t<modelURL>http://www.dlink.com</modelURL>\n\t\t<serialNumber>120</serialNumber>\n\t\t<UDN>uuid:11111111-2222-3333-4444-555555555555</UDN>\n\t\t<iconList>\n\t\t\t<icon>\n\t\t\t\t<mimetype>image/gif</mimetype>\n\t\t\t\t<width>118</width>\n\t\t\t\t<height>119</height>\n\t\t\t\t<depth>8</depth>\n\t\t\t\t<url>/ligd.gif</url>\n\t\t\t</icon>\n\t\t</iconList>\n\t\t<serviceList>\n\t\t\t<service>\n\t\t\t\t<serviceType>urn:schemas-microsoft-com:service:OSInfo:1</serviceType>\n\t\t\t\t<serviceId>urn:microsoft-com:serviceId:OSInfo1</serviceId>\n\t\t\t\t<controlURL>/soap.cgi?service=OSInfo1</controlURL>\n\t\t\t\t<eventSubURL>/gena.cgi?service=OSInfo1</eventSubURL>\n\t\t\t\t<SCPDURL>/OSInfo.xml</SCPDURL>\n\t\t\t</service>\n\t\t\t<service>\n\t\t\t\t<serviceType>urn:schemas-upnp-org:service:Layer3Forwarding:1</serviceType>\n\t\t\t\t<serviceId>urn:upnp-org:serviceId:L3Forwarding1</serviceId>\n\t\t\t\t<controlURL>/soap.cgi?service=L3Forwarding1</controlURL>\n\t\t\t\t<eventSubURL>/gena.cgi?service=L3Forwarding1</eventSubURL>\n\t\t\t\t<SCPDURL>/Layer3Forwarding.xml</SCPDURL>\n\t\t\t</service>\n\t\t</serviceList>\n\t\t<deviceList>\n\t\t\t<device>\n\t\t\t\t<deviceType>urn:schemas-upnp-org:device:WANDevice:1</deviceType>\n\t\t\t\t<friendlyName>WANDevice</friendlyName>\n\t\t\t\t<manufacturer>D-Link</manufacturer>\n\t\t\t\t<manufacturerURL>http://www.dlink.com</manufacturerURL>\n\t\t\t\t<modelDescription>WANDevice</modelDescription>\n\t\t\t\t<modelName>DIR-890L</modelName>\n\t\t\t\t<modelNumber>1</modelNumber>\n\t\t\t\t<modelURL>http://www.dlink.com</modelURL>\n\t\t\t\t<serialNumber>120</serialNumber>\n\t\t\t\t<UDN>uuid:11111111-2222-3333-4444-555555555555</UDN>\n\t\t\t\t<serviceList>\n\t\t\t\t\t<service>\n\t\t\t\t\t\t<serviceType>urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1</serviceType>\n\t\t\t\t\t\t<serviceId>urn:upnp-org:serviceId:WANCommonIFC1</serviceId>\n\t\t\t\t\t\t<controlURL>/soap.cgi?service=WANCommonIFC1</controlURL>\n\t\t\t\t\t\t<eventSubURL>/gena.cgi?service=WANCommonIFC1</eventSubURL>\n\t\t\t\t\t\t<SCPDURL>/WANCommonInterfaceConfig.xml</SCPDURL>\n\t\t\t\t\t</service>\n\t\t\t\t</serviceList>\n\t\t\t\t<deviceList>\n\t\t\t\t\t<device>\n\t\t\t\t\t\t<deviceType>urn:schemas-upnp-org:device:WANConnectionDevice:1</deviceType>\n\t\t\t\t\t\t<friendlyName>WANConnectionDevice</friendlyName>\n\t\t\t\t\t\t<manufacturer>D-Link</manufacturer>\n\t\t\t\t\t\t<manufacturerURL>http://www.dlink.com</manufacturerURL>\n\t\t\t\t\t\t<modelDescription>WanConnectionDevice</modelDescription>\n\t\t\t\t\t\t<modelName>DIR-890L</modelName>\n\t\t\t\t\t\t<modelNumber>1</modelNumber>\n\t\t\t\t\t\t<modelURL>http://www.dlink.com</modelURL>\n\t\t\t\t\t\t<serialNumber>120</serialNumber>\n\t\t\t\t\t\t<UDN>uuid:11111111-2222-3333-4444-555555555555</UDN>\n\t\t\t\t\t\t<serviceList>\n\t\t\t\t\t\t\t<service>\n\t\t\t\t\t\t\t\t<serviceType>urn:schemas-upnp-org:service:WANEthernetLinkConfig:1</serviceType>\n\t\t\t\t\t\t\t\t<serviceId>urn:upnp-org:serviceId:WANEthLinkC1</serviceId>\n\t\t\t\t\t\t\t\t<controlURL>/soap.cgi?service=WANEthLinkC1</controlURL>\n\t\t\t\t\t\t\t\t<eventSubURL>/gena.cgi?service=WANEthLinkC1</eventSubURL>\n\t\t\t\t\t\t\t\t<SCPDURL>/WANEthernetLinkConfig.xml</SCPDURL>\n\t\t\t\t\t\t\t</service>\n\t\t\t\t\t\t\t<service>\n\t\t\t\t\t\t\t\t<serviceType>urn:schemas-upnp-org:service:WANIPConnection:1</serviceType>\n\t\t\t\t\t\t\t\t<serviceId>urn:upnp-org:serviceId:WANIPConn1</serviceId>\n\t\t\t\t\t\t\t\t<controlURL>/soap.cgi?service=WANIPConn1</controlURL>\n\t\t\t\t\t\t\t\t<eventSubURL>/gena.cgi?service=WANIPConn1</eventSubURL>\n\t\t\t\t\t\t\t\t<SCPDURL>/WANIPConnection.xml</SCPDURL>\n\t\t\t\t\t\t\t</service>\n\t\t\t\t\t\t</serviceList>\n\t\t\t\t\t</device>\n\t\t\t\t</deviceList>\n\t\t\t</device>\n\t\t</deviceList>\n\t\t<presentationURL>http://10.0.0.1</presentationURL>\n\t</device>\n</root>\n'),
{'{urn:schemas-upnp-org:device-1-0}root': {
'{urn:schemas-upnp-org:device-1-0}specVersion': {'{urn:schemas-upnp-org:device-1-0}major': '1',
'{urn:schemas-upnp-org:device-1-0}minor': '0'},
'{urn:schemas-upnp-org:device-1-0}URLBase': 'http://10.0.0.1:49152',
'{urn:schemas-upnp-org:device-1-0}device': {
'{urn:schemas-upnp-org:device-1-0}deviceType': 'urn:schemas-upnp-org:device:InternetGatewayDevice:1',
'{urn:schemas-upnp-org:device-1-0}friendlyName': 'Wireless Broadband Router',
'{urn:schemas-upnp-org:device-1-0}manufacturer': 'D-Link Corporation',
'{urn:schemas-upnp-org:device-1-0}manufacturerURL': 'http://www.dlink.com',
'{urn:schemas-upnp-org:device-1-0}modelDescription': 'D-Link Router',
'{urn:schemas-upnp-org:device-1-0}modelName': 'D-Link Router',
'{urn:schemas-upnp-org:device-1-0}modelNumber': 'DIR-890L',
'{urn:schemas-upnp-org:device-1-0}modelURL': 'http://www.dlink.com',
'{urn:schemas-upnp-org:device-1-0}serialNumber': '120',
'{urn:schemas-upnp-org:device-1-0}UDN': 'uuid:11111111-2222-3333-4444-555555555555',
'{urn:schemas-upnp-org:device-1-0}iconList': {'{urn:schemas-upnp-org:device-1-0}icon': {
'{urn:schemas-upnp-org:device-1-0}mimetype': 'image/gif',
'{urn:schemas-upnp-org:device-1-0}width': '118',
'{urn:schemas-upnp-org:device-1-0}height': '119', '{urn:schemas-upnp-org:device-1-0}depth': '8',
'{urn:schemas-upnp-org:device-1-0}url': '/ligd.gif'}},
'{urn:schemas-upnp-org:device-1-0}serviceList': {'{urn:schemas-upnp-org:device-1-0}service': [
{'{urn:schemas-upnp-org:device-1-0}serviceType': 'urn:schemas-microsoft-com:service:OSInfo:1',
'{urn:schemas-upnp-org:device-1-0}serviceId': 'urn:microsoft-com:serviceId:OSInfo1',
'{urn:schemas-upnp-org:device-1-0}controlURL': '/soap.cgi?service=OSInfo1',
'{urn:schemas-upnp-org:device-1-0}eventSubURL': '/gena.cgi?service=OSInfo1',
'{urn:schemas-upnp-org:device-1-0}SCPDURL': '/OSInfo.xml'}, {
'{urn:schemas-upnp-org:device-1-0}serviceType': 'urn:schemas-upnp-org:service:Layer3Forwarding:1',
'{urn:schemas-upnp-org:device-1-0}serviceId': 'urn:upnp-org:serviceId:L3Forwarding1',
'{urn:schemas-upnp-org:device-1-0}controlURL': '/soap.cgi?service=L3Forwarding1',
'{urn:schemas-upnp-org:device-1-0}eventSubURL': '/gena.cgi?service=L3Forwarding1',
'{urn:schemas-upnp-org:device-1-0}SCPDURL': '/Layer3Forwarding.xml'}]},
'{urn:schemas-upnp-org:device-1-0}deviceList': {'{urn:schemas-upnp-org:device-1-0}device': {
'{urn:schemas-upnp-org:device-1-0}deviceType': 'urn:schemas-upnp-org:device:WANDevice:1',
'{urn:schemas-upnp-org:device-1-0}friendlyName': 'WANDevice',
'{urn:schemas-upnp-org:device-1-0}manufacturer': 'D-Link',
'{urn:schemas-upnp-org:device-1-0}manufacturerURL': 'http://www.dlink.com',
'{urn:schemas-upnp-org:device-1-0}modelDescription': 'WANDevice',
'{urn:schemas-upnp-org:device-1-0}modelName': 'DIR-890L',
'{urn:schemas-upnp-org:device-1-0}modelNumber': '1',
'{urn:schemas-upnp-org:device-1-0}modelURL': 'http://www.dlink.com',
'{urn:schemas-upnp-org:device-1-0}serialNumber': '120',
'{urn:schemas-upnp-org:device-1-0}UDN': 'uuid:11111111-2222-3333-4444-555555555555',
'{urn:schemas-upnp-org:device-1-0}serviceList': {'{urn:schemas-upnp-org:device-1-0}service': {
'{urn:schemas-upnp-org:device-1-0}serviceType': 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1',
'{urn:schemas-upnp-org:device-1-0}serviceId': 'urn:upnp-org:serviceId:WANCommonIFC1',
'{urn:schemas-upnp-org:device-1-0}controlURL': '/soap.cgi?service=WANCommonIFC1',
'{urn:schemas-upnp-org:device-1-0}eventSubURL': '/gena.cgi?service=WANCommonIFC1',
'{urn:schemas-upnp-org:device-1-0}SCPDURL': '/WANCommonInterfaceConfig.xml'}},
'{urn:schemas-upnp-org:device-1-0}deviceList': {'{urn:schemas-upnp-org:device-1-0}device': {
'{urn:schemas-upnp-org:device-1-0}deviceType': 'urn:schemas-upnp-org:device:WANConnectionDevice:1',
'{urn:schemas-upnp-org:device-1-0}friendlyName': 'WANConnectionDevice',
'{urn:schemas-upnp-org:device-1-0}manufacturer': 'D-Link',
'{urn:schemas-upnp-org:device-1-0}manufacturerURL': 'http://www.dlink.com',
'{urn:schemas-upnp-org:device-1-0}modelDescription': 'WanConnectionDevice',
'{urn:schemas-upnp-org:device-1-0}modelName': 'DIR-890L',
'{urn:schemas-upnp-org:device-1-0}modelNumber': '1',
'{urn:schemas-upnp-org:device-1-0}modelURL': 'http://www.dlink.com',
'{urn:schemas-upnp-org:device-1-0}serialNumber': '120',
'{urn:schemas-upnp-org:device-1-0}UDN': 'uuid:11111111-2222-3333-4444-555555555555',
'{urn:schemas-upnp-org:device-1-0}serviceList': {
'{urn:schemas-upnp-org:device-1-0}service': [{
'{urn:schemas-upnp-org:device-1-0}serviceType': 'urn:schemas-upnp-org:service:WANEthernetLinkConfig:1',
'{urn:schemas-upnp-org:device-1-0}serviceId': 'urn:upnp-org:serviceId:WANEthLinkC1',
'{urn:schemas-upnp-org:device-1-0}controlURL': '/soap.cgi?service=WANEthLinkC1',
'{urn:schemas-upnp-org:device-1-0}eventSubURL': '/gena.cgi?service=WANEthLinkC1',
'{urn:schemas-upnp-org:device-1-0}SCPDURL': '/WANEthernetLinkConfig.xml'},
{
'{urn:schemas-upnp-org:device-1-0}serviceType': 'urn:schemas-upnp-org:service:WANIPConnection:1',
'{urn:schemas-upnp-org:device-1-0}serviceId': 'urn:upnp-org:serviceId:WANIPConn1',
'{urn:schemas-upnp-org:device-1-0}controlURL': '/soap.cgi?service=WANIPConn1',
'{urn:schemas-upnp-org:device-1-0}eventSubURL': '/gena.cgi?service=WANIPConn1',
'{urn:schemas-upnp-org:device-1-0}SCPDURL': '/WANIPConnection.xml'}]}}}}},
'{urn:schemas-upnp-org:device-1-0}presentationURL': 'http://10.0.0.1'}}}
)
def test_deserialize_get_response(self): def test_deserialize_get_response(self):
self.assertDictEqual(deserialize_scpd_get_response(self.response), self.expected_parsed) self.assertDictEqual(deserialize_scpd_get_response(self.response), self.expected_parsed)
@ -101,6 +206,14 @@ class TestSCPDSerialization(unittest.TestCase):
def test_deserialize_blank(self): def test_deserialize_blank(self):
self.assertDictEqual(deserialize_scpd_get_response(b''), {}) self.assertDictEqual(deserialize_scpd_get_response(b''), {})
def test_fail_to_deserialize_invalid_root_device(self):
with self.assertRaises(UPnPError):
deserialize_scpd_get_response(self.response_bad_root_device_name)
def test_fail_to_deserialize_invalid_root_xmls(self):
with self.assertRaises(UPnPError):
deserialize_scpd_get_response(self.response_bad_root_xmls)
def test_deserialize_to_device_object(self): def test_deserialize_to_device_object(self):
devices = [] devices = []
services = [] services = []
@ -173,3 +286,78 @@ class TestSCPDSerialization(unittest.TestCase):
}, 'presentationURL': 'http://10.1.10.1/' }, 'presentationURL': 'http://10.1.10.1/'
} }
self.assertDictEqual(expected_result, device.as_dict()) self.assertDictEqual(expected_result, device.as_dict())
def test_deserialize_another_device(self):
xml_bytes = b"<?xml version=\"1.0\"?>\n<root xmlns=\"urn:schemas-upnp-org:device-1-0\">\n<specVersion>\n<major>1</major>\n<minor>0</minor>\n</specVersion>\n<device>\n<deviceType>urn:schemas-upnp-org:device:InternetGatewayDevice:1</deviceType>\n<friendlyName>CGA4131COM</friendlyName>\n<manufacturer>Cisco</manufacturer>\n<manufacturerURL>http://www.cisco.com/</manufacturerURL>\n<modelDescription>CGA4131COM</modelDescription>\n<modelName>CGA4131COM</modelName>\n<modelNumber>CGA4131COM</modelNumber>\n<modelURL>http://www.cisco.com</modelURL>\n<serialNumber></serialNumber>\n<UDN>uuid:11111111-2222-3333-4444-555555555556</UDN>\n<UPC>CGA4131COM</UPC>\n<serviceList>\n<service>\n<serviceType>urn:schemas-upnp-org:service:Layer3Forwarding:1</serviceType>\n<serviceId>urn:upnp-org:serviceId:L3Forwarding1</serviceId>\n<SCPDURL>/Layer3ForwardingSCPD.xml</SCPDURL>\n<controlURL>/upnp/control/Layer3Forwarding</controlURL>\n<eventSubURL>/upnp/event/Layer3Forwarding</eventSubURL>\n</service>\n</serviceList>\n<deviceList>\n<device>\n<deviceType>urn:schemas-upnp-org:device:WANDevice:1</deviceType>\n<friendlyName>WANDevice:1</friendlyName>\n<manufacturer>Cisco</manufacturer>\n<manufacturerURL>http://www.cisco.com/</manufacturerURL>\n<modelDescription>CGA4131COM</modelDescription>\n<modelName>CGA4131COM</modelName>\n<modelNumber>CGA4131COM</modelNumber>\n<modelURL>http://www.cisco.com</modelURL>\n<serialNumber></serialNumber>\n<UDN>uuid:ebf5a0a0-1dd1-11b2-a92f-603d266f9915</UDN>\n<UPC>CGA4131COM</UPC>\n<serviceList>\n<service>\n<serviceType>urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1</serviceType>\n<serviceId>urn:upnp-org:serviceId:WANCommonIFC1</serviceId>\n<SCPDURL>/WANCommonInterfaceConfigSCPD.xml</SCPDURL>\n<controlURL>/upnp/control/WANCommonInterfaceConfig0</controlURL>\n<eventSubURL>/upnp/event/WANCommonInterfaceConfig0</eventSubURL>\n</service>\n</serviceList>\n<deviceList>\n <device>\n <deviceType>urn:schemas-upnp-org:device:WANConnectionDevice:1</deviceType>\n <friendlyName>WANConnectionDevice:1</friendlyName>\n <manufacturer>Cisco</manufacturer>\n <manufacturerURL>http://www.cisco.com/</manufacturerURL>\n <modelDescription>CGA4131COM</modelDescription>\n <modelName>CGA4131COM</modelName>\n <modelNumber>CGA4131COM</modelNumber>\n <modelURL>http://www.cisco.com</modelURL>\n <serialNumber></serialNumber>\n <UDN>uuid:11111111-2222-3333-4444-555555555555</UDN>\n <UPC>CGA4131COM</UPC>\n <serviceList>\n <service>\n <serviceType>urn:schemas-upnp-org:service:WANIPConnection:1</serviceType>\n <serviceId>urn:upnp-org:serviceId:WANIPConn1</serviceId>\n <SCPDURL>/WANIPConnectionServiceSCPD.xml</SCPDURL>\n <controlURL>/upnp/control/WANIPConnection0</controlURL>\n <eventSubURL>/upnp/event/WANIPConnection0</eventSubURL>\n </service>\n </serviceList>\n </device>\n</deviceList>\n</device>\n</deviceList>\n<presentationURL>http://10.1.10.1/</presentationURL></device>\n</root>\n"
expected_parsed = {
'specVersion': {'major': '1', 'minor': '0'},
'device': {
'deviceType': 'urn:schemas-upnp-org:device:InternetGatewayDevice:1',
'friendlyName': 'CGA4131COM',
'manufacturer': 'Cisco',
'manufacturerURL': 'http://www.cisco.com/',
'modelDescription': 'CGA4131COM',
'modelName': 'CGA4131COM',
'modelNumber': 'CGA4131COM',
'modelURL': 'http://www.cisco.com',
'UDN': 'uuid:11111111-2222-3333-4444-555555555556',
'UPC': 'CGA4131COM',
'serviceList': {
'service': {
'serviceType': 'urn:schemas-upnp-org:service:Layer3Forwarding:1',
'serviceId': 'urn:upnp-org:serviceId:L3Forwarding1',
'SCPDURL': '/Layer3ForwardingSCPD.xml',
'controlURL': '/upnp/control/Layer3Forwarding',
'eventSubURL': '/upnp/event/Layer3Forwarding'
}
},
'deviceList': {
'device': {
'deviceType': 'urn:schemas-upnp-org:device:WANDevice:1',
'friendlyName': 'WANDevice:1',
'manufacturer': 'Cisco',
'manufacturerURL': 'http://www.cisco.com/',
'modelDescription': 'CGA4131COM',
'modelName': 'CGA4131COM',
'modelNumber': 'CGA4131COM',
'modelURL': 'http://www.cisco.com',
'UDN': 'uuid:ebf5a0a0-1dd1-11b2-a92f-603d266f9915',
'UPC': 'CGA4131COM',
'serviceList': {
'service': {
'serviceType': 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1',
'serviceId': 'urn:upnp-org:serviceId:WANCommonIFC1',
'SCPDURL': '/WANCommonInterfaceConfigSCPD.xml',
'controlURL': '/upnp/control/WANCommonInterfaceConfig0',
'eventSubURL': '/upnp/event/WANCommonInterfaceConfig0'
}
},
'deviceList': {
'device': {
'deviceType': 'urn:schemas-upnp-org:device:WANConnectionDevice:1',
'friendlyName': 'WANConnectionDevice:1',
'manufacturer': 'Cisco',
'manufacturerURL': 'http://www.cisco.com/',
'modelDescription': 'CGA4131COM',
'modelName': 'CGA4131COM',
'modelNumber': 'CGA4131COM',
'modelURL': 'http://www.cisco.com',
'UDN': 'uuid:11111111-2222-3333-4444-555555555555',
'UPC': 'CGA4131COM',
'serviceList': {
'service': {
'serviceType': 'urn:schemas-upnp-org:service:WANIPConnection:1',
'serviceId': 'urn:upnp-org:serviceId:WANIPConn1',
'SCPDURL': '/WANIPConnectionServiceSCPD.xml',
'controlURL': '/upnp/control/WANIPConnection0',
'eventSubURL': '/upnp/event/WANIPConnection0'
}
}
}
}
}
},
'presentationURL': 'http://10.1.10.1/'
}
}
self.assertDictEqual(expected_parsed, deserialize_scpd_get_response(xml_bytes))

View file

@ -28,6 +28,26 @@ class TestSOAPSerialization(unittest.TestCase):
b"\r\n" \ b"\r\n" \
b"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\"><s:Body>\n<u:GetExternalIPAddressResponse xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n<NewExternalIPAddress>11.22.33.44</NewExternalIPAddress>\r\n</u:GetExternalIPAddressResponse>\r\n</s:Body> </s:Envelope>" b"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\"><s:Body>\n<u:GetExternalIPAddressResponse xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n<NewExternalIPAddress>11.22.33.44</NewExternalIPAddress>\r\n</u:GetExternalIPAddressResponse>\r\n</s:Body> </s:Envelope>"
blank_response = b"HTTP/1.1 200 OK\r\n" \
b"CONTENT-LENGTH: 148\r\n" \
b"CONTENT-TYPE: text/xml; charset=\"utf-8\"\r\n" \
b"DATE: Thu, 18 Oct 2018 01:20:23 GMT\r\n" \
b"EXT:\r\n" \
b"SERVER: Linux/3.14.28-Prod_17.2, UPnP/1.0, Portable SDK for UPnP devices/1.6.22\r\n" \
b"X-User-Agent: redsonic\r\n" \
b"\r\n" \
b"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\"><s:Body>\n</s:Body> </s:Envelope>"
blank_response_body = b"HTTP/1.1 200 OK\r\n" \
b"CONTENT-LENGTH: 280\r\n" \
b"CONTENT-TYPE: text/xml; charset=\"utf-8\"\r\n" \
b"DATE: Thu, 18 Oct 2018 01:20:23 GMT\r\n" \
b"EXT:\r\n" \
b"SERVER: Linux/3.14.28-Prod_17.2, UPnP/1.0, Portable SDK for UPnP devices/1.6.22\r\n" \
b"X-User-Agent: redsonic\r\n" \
b"\r\n" \
b"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\"><s:Body>\n<u:GetExternalIPAddressResponse xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\"></u:GetExternalIPAddressResponse>\r\n</s:Body> </s:Envelope>"
error_response = b"HTTP/1.1 500 Internal Server Error\r\n" \ error_response = b"HTTP/1.1 500 Internal Server Error\r\n" \
b"Server: WebServer\r\n" \ b"Server: WebServer\r\n" \
b"Date: Thu, 11 Oct 2018 22:16:17 GMT\r\n" \ b"Date: Thu, 11 Oct 2018 22:16:17 GMT\r\n" \
@ -43,12 +63,29 @@ class TestSOAPSerialization(unittest.TestCase):
self.method, self.param_names, self.st, self.gateway_address, self.path, **self.kwargs self.method, self.param_names, self.st, self.gateway_address, self.path, **self.kwargs
), self.post_bytes) ), self.post_bytes)
def test_serialize_post_http_host(self):
self.assertEqual(serialize_soap_post(
self.method, self.param_names, self.st, b'http://' + self.gateway_address, self.path, **self.kwargs
), self.post_bytes)
def test_deserialize_post_response(self): def test_deserialize_post_response(self):
self.assertDictEqual( self.assertDictEqual(
deserialize_soap_post_response(self.post_response, self.method, service_id=self.st.decode()), deserialize_soap_post_response(self.post_response, self.method, service_id=self.st.decode()),
{'NewExternalIPAddress': '11.22.33.44'} {'NewExternalIPAddress': '11.22.33.44'}
) )
def test_deserialize_error_response_field_not_found(self):
with self.assertRaises(UPnPError) as e:
deserialize_soap_post_response(self.post_response, self.method + 'derp', service_id=self.st.decode())
self.assertTrue(str(e.exception).startswith('unknown response fields for GetExternalIPAddressderp'))
def test_deserialize_blank_response(self):
# TODO: these seem like they should error... this test will break and have to be updated
self.assertDictEqual({}, deserialize_soap_post_response(self.blank_response, self.method,
service_id=self.st.decode()))
self.assertDictEqual({}, deserialize_soap_post_response(self.blank_response_body, self.method,
service_id=self.st.decode()))
def test_raise_from_error_response(self): def test_raise_from_error_response(self):
raised = False raised = False
try: try:

View file

@ -1,7 +1,6 @@
import contextlib import contextlib
from io import StringIO from io import StringIO
from tests import TestBase from tests import AsyncioTestCase, mock_tcp_and_udp
from tests.mocks import mock_tcp_and_udp
from collections import OrderedDict from collections import OrderedDict
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.protocols.m_search_patterns import packet_generator from aioupnp.protocols.m_search_patterns import packet_generator
@ -22,7 +21,7 @@ m_search_cli_result = """{
}\n""" }\n"""
class TestCLI(TestBase): class TestCLI(AsyncioTestCase):
gateway_address = "10.0.0.1" gateway_address = "10.0.0.1"
soap_port = 49152 soap_port = 49152
m_search_args = OrderedDict([ m_search_args = OrderedDict([
@ -78,7 +77,7 @@ class TestCLI(TestBase):
with contextlib.redirect_stdout(actual_output): with contextlib.redirect_stdout(actual_output):
with mock_tcp_and_udp(self.loop, '10.0.0.1', tcp_replies=self.scpd_replies, udp_replies=self.udp_replies): with mock_tcp_and_udp(self.loop, '10.0.0.1', tcp_replies=self.scpd_replies, udp_replies=self.udp_replies):
main( main(
(None, '--timeout=1', '--gateway_address=10.0.0.1', '--lan_address=10.0.0.2', 'get-external-ip'), [None, '--timeout=1', '--gateway_address=10.0.0.1', '--lan_address=10.0.0.2', 'get-external-ip'],
self.loop self.loop
) )
self.assertEqual("11.22.33.44\n", actual_output.getvalue()) self.assertEqual("11.22.33.44\n", actual_output.getvalue())
@ -89,7 +88,7 @@ class TestCLI(TestBase):
with contextlib.redirect_stdout(actual_output): with contextlib.redirect_stdout(actual_output):
with mock_tcp_and_udp(self.loop, '10.0.0.1', tcp_replies=self.scpd_replies, udp_replies=self.udp_replies): with mock_tcp_and_udp(self.loop, '10.0.0.1', tcp_replies=self.scpd_replies, udp_replies=self.udp_replies):
main( main(
(None, '--timeout=1', '--gateway_address=10.0.0.1', '--lan_address=10.0.0.2', 'm-search'), [None, '--timeout=1', '--gateway_address=10.0.0.1', '--lan_address=10.0.0.2', 'm-search'],
self.loop self.loop
) )
self.assertEqual(timeout_msg, actual_output.getvalue()) self.assertEqual(timeout_msg, actual_output.getvalue())
@ -98,7 +97,7 @@ class TestCLI(TestBase):
with contextlib.redirect_stdout(actual_output): with contextlib.redirect_stdout(actual_output):
with mock_tcp_and_udp(self.loop, '10.0.0.1', tcp_replies=self.scpd_replies, udp_replies=self.udp_replies): with mock_tcp_and_udp(self.loop, '10.0.0.1', tcp_replies=self.scpd_replies, udp_replies=self.udp_replies):
main( main(
(None, '--timeout=1', '--gateway_address=10.0.0.1', '--lan_address=10.0.0.2', '--unicast', 'm-search'), [None, '--timeout=1', '--gateway_address=10.0.0.1', '--lan_address=10.0.0.2', '--unicast', 'm-search'],
self.loop self.loop
) )
self.assertEqual(m_search_cli_result, actual_output.getvalue()) self.assertEqual(m_search_cli_result, actual_output.getvalue())

View file

@ -1,10 +1,7 @@
import asyncio
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from tests import TestBase from tests import AsyncioTestCase, mock_tcp_and_udp
from tests.mocks import mock_tcp_and_udp
from collections import OrderedDict from collections import OrderedDict
from aioupnp.gateway import Gateway from aioupnp.gateway import Gateway, get_action_list
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
@ -14,7 +11,137 @@ def gen_get_bytes(location: str, host: str) -> bytes:
).encode() ).encode()
class TestDiscoverDLinkDIR890L(TestBase): class TestParseActionList(AsyncioTestCase):
test_action_list = {'actionList': {
'action': [OrderedDict([('name', 'SetConnectionType'), ('argumentList', OrderedDict([('argument', OrderedDict(
[('name', 'NewConnectionType'), ('direction', 'in'), ('relatedStateVariable', 'ConnectionType')]))]))]),
OrderedDict([('name', 'GetConnectionTypeInfo'), ('argumentList', OrderedDict([('argument', [
OrderedDict([('name', 'NewConnectionType'), ('direction', 'out'),
('relatedStateVariable', 'ConnectionType')]), OrderedDict(
[('name', 'NewPossibleConnectionTypes'), ('direction', 'out'),
('relatedStateVariable', 'PossibleConnectionTypes')])])]))]),
OrderedDict([('name', 'RequestConnection')]), OrderedDict([('name', 'ForceTermination')]),
OrderedDict([('name', 'GetStatusInfo'), ('argumentList', OrderedDict([('argument', [OrderedDict(
[('name', 'NewConnectionStatus'), ('direction', 'out'),
('relatedStateVariable', 'ConnectionStatus')]), OrderedDict(
[('name', 'NewLastConnectionError'), ('direction', 'out'),
('relatedStateVariable', 'LastConnectionError')]), OrderedDict(
[('name', 'NewUptime'), ('direction', 'out'), ('relatedStateVariable', 'Uptime')])])]))]),
OrderedDict([('name', 'GetNATRSIPStatus'), ('argumentList', OrderedDict([('argument', [OrderedDict(
[('name', 'NewRSIPAvailable'), ('direction', 'out'),
('relatedStateVariable', 'RSIPAvailable')]), OrderedDict(
[('name', 'NewNATEnabled'), ('direction', 'out'),
('relatedStateVariable', 'NATEnabled')])])]))]), OrderedDict(
[('name', 'GetGenericPortMappingEntry'), ('argumentList', OrderedDict([('argument', [OrderedDict(
[('name', 'NewPortMappingIndex'), ('direction', 'in'),
('relatedStateVariable', 'PortMappingNumberOfEntries')]), OrderedDict(
[('name', 'NewRemoteHost'), ('direction', 'out'), ('relatedStateVariable', 'RemoteHost')]),
OrderedDict(
[('name', 'NewExternalPort'), ('direction', 'out'), ('relatedStateVariable', 'ExternalPort')]),
OrderedDict(
[('name', 'NewProtocol'), ('direction', 'out'),
('relatedStateVariable', 'PortMappingProtocol')]),
OrderedDict([('name',
'NewInternalPort'),
('direction',
'out'), (
'relatedStateVariable',
'InternalPort')]),
OrderedDict([('name',
'NewInternalClient'),
('direction',
'out'), (
'relatedStateVariable',
'InternalClient')]),
OrderedDict([('name',
'NewEnabled'),
('direction',
'out'), (
'relatedStateVariable',
'PortMappingEnabled')]),
OrderedDict([('name',
'NewPortMappingDescription'),
('direction',
'out'), (
'relatedStateVariable',
'PortMappingDescription')]),
OrderedDict([('name',
'NewLeaseDuration'),
('direction',
'out'), (
'relatedStateVariable',
'PortMappingLeaseDuration')])])]))]),
OrderedDict([('name', 'GetSpecificPortMappingEntry'), ('argumentList', OrderedDict([('argument', [
OrderedDict(
[('name', 'NewRemoteHost'), ('direction', 'in'), ('relatedStateVariable', 'RemoteHost')]),
OrderedDict([('name', 'NewExternalPort'), ('direction', 'in'),
('relatedStateVariable', 'ExternalPort')]), OrderedDict(
[('name', 'NewProtocol'), ('direction', 'in'),
('relatedStateVariable', 'PortMappingProtocol')]), OrderedDict(
[('name', 'NewInternalPort'), ('direction', 'out'),
('relatedStateVariable', 'InternalPort')]), OrderedDict(
[('name', 'NewInternalClient'), ('direction', 'out'),
('relatedStateVariable', 'InternalClient')]), OrderedDict(
[('name', 'NewEnabled'), ('direction', 'out'),
('relatedStateVariable', 'PortMappingEnabled')]), OrderedDict(
[('name', 'NewPortMappingDescription'), ('direction', 'out'),
('relatedStateVariable', 'PortMappingDescription')]), OrderedDict(
[('name', 'NewLeaseDuration'), ('direction', 'out'),
('relatedStateVariable', 'PortMappingLeaseDuration')])])]))]), OrderedDict(
[('name', 'AddPortMapping'), ('argumentList', OrderedDict([('argument', [
OrderedDict(
[('name', 'NewRemoteHost'), ('direction', 'in'), ('relatedStateVariable', 'RemoteHost')]),
OrderedDict(
[('name', 'NewExternalPort'), ('direction', 'in'), ('relatedStateVariable', 'ExternalPort')]),
OrderedDict(
[('name', 'NewProtocol'), ('direction', 'in'),
('relatedStateVariable', 'PortMappingProtocol')]),
OrderedDict(
[('name', 'NewInternalPort'), ('direction', 'in'), ('relatedStateVariable', 'InternalPort')]),
OrderedDict(
[('name', 'NewInternalClient'), ('direction', 'in'),
('relatedStateVariable', 'InternalClient')]),
OrderedDict(
[('name', 'NewEnabled'), ('direction', 'in'), ('relatedStateVariable', 'PortMappingEnabled')]),
OrderedDict([('name', 'NewPortMappingDescription'), ('direction', 'in'),
('relatedStateVariable', 'PortMappingDescription')]), OrderedDict(
[('name', 'NewLeaseDuration'), ('direction', 'in'),
('relatedStateVariable', 'PortMappingLeaseDuration')])])]))]), OrderedDict(
[('name', 'DeletePortMapping'), ('argumentList', OrderedDict([('argument', [
OrderedDict(
[('name', 'NewRemoteHost'), ('direction', 'in'), ('relatedStateVariable', 'RemoteHost')]),
OrderedDict(
[('name', 'NewExternalPort'), ('direction', 'in'), ('relatedStateVariable', 'ExternalPort')]),
OrderedDict(
[('name', 'NewProtocol'), ('direction', 'in'),
('relatedStateVariable', 'PortMappingProtocol')])])]))]),
OrderedDict([('name', 'GetExternalIPAddress'),
('argumentList', OrderedDict(
[('argument', OrderedDict([('name', 'NewExternalIPAddress'),
('direction', 'out'),
('relatedStateVariable', 'ExternalIPAddress')]))]))])]}}
def test_parse_expected_action_list(self):
expected = [('SetConnectionType', ['NewConnectionType'], []),
('GetConnectionTypeInfo', [], ['NewConnectionType', 'NewPossibleConnectionTypes']),
('RequestConnection', [], []), ('ForceTermination', [], []),
('GetStatusInfo', [], ['NewConnectionStatus', 'NewLastConnectionError', 'NewUptime']),
('GetNATRSIPStatus', [], ['NewRSIPAvailable', 'NewNATEnabled']), (
'GetGenericPortMappingEntry', ['NewPortMappingIndex'],
['NewRemoteHost', 'NewExternalPort', 'NewProtocol', 'NewInternalPort', 'NewInternalClient',
'NewEnabled', 'NewPortMappingDescription', 'NewLeaseDuration']), (
'GetSpecificPortMappingEntry', ['NewRemoteHost', 'NewExternalPort', 'NewProtocol'],
['NewInternalPort', 'NewInternalClient', 'NewEnabled', 'NewPortMappingDescription',
'NewLeaseDuration']), ('AddPortMapping',
['NewRemoteHost', 'NewExternalPort', 'NewProtocol', 'NewInternalPort',
'NewInternalClient', 'NewEnabled', 'NewPortMappingDescription',
'NewLeaseDuration'], []),
('DeletePortMapping', ['NewRemoteHost', 'NewExternalPort', 'NewProtocol'], []),
('GetExternalIPAddress', [], ['NewExternalIPAddress'])]
self.assertEqual(expected, get_action_list(self.test_action_list))
class TestDiscoverDLinkDIR890L(AsyncioTestCase):
gateway_address = "10.0.0.1" gateway_address = "10.0.0.1"
client_address = "10.0.0.2" client_address = "10.0.0.2"
soap_port = 49152 soap_port = 49152
@ -44,20 +171,20 @@ class TestDiscoverDLinkDIR890L(TestBase):
} }
expected_commands = { expected_commands = {
'GetDefaultConnectionService': 'urn:schemas-upnp-org:service:Layer3Forwarding:1', # 'GetDefaultConnectionService': 'urn:schemas-upnp-org:service:Layer3Forwarding:1',
'SetDefaultConnectionService': 'urn:schemas-upnp-org:service:Layer3Forwarding:1', # 'SetDefaultConnectionService': 'urn:schemas-upnp-org:service:Layer3Forwarding:1',
'GetCommonLinkProperties': 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1', # 'GetCommonLinkProperties': 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1',
'GetTotalBytesSent': 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1', # 'GetTotalBytesSent': 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1',
'GetTotalBytesReceived': 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1', # 'GetTotalBytesReceived': 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1',
'GetTotalPacketsSent': 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1', # 'GetTotalPacketsSent': 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1',
'GetTotalPacketsReceived': 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1', # 'GetTotalPacketsReceived': 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1',
'X_GetICSStatistics': 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1', # 'X_GetICSStatistics': 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1',
'SetConnectionType': 'urn:schemas-upnp-org:service:WANIPConnection:1', # 'SetConnectionType': 'urn:schemas-upnp-org:service:WANIPConnection:1',
'GetConnectionTypeInfo': 'urn:schemas-upnp-org:service:WANIPConnection:1', # 'GetConnectionTypeInfo': 'urn:schemas-upnp-org:service:WANIPConnection:1',
'RequestConnection': 'urn:schemas-upnp-org:service:WANIPConnection:1', # 'RequestConnection': 'urn:schemas-upnp-org:service:WANIPConnection:1',
'ForceTermination': 'urn:schemas-upnp-org:service:WANIPConnection:1', # 'ForceTermination': 'urn:schemas-upnp-org:service:WANIPConnection:1',
'GetStatusInfo': 'urn:schemas-upnp-org:service:WANIPConnection:1', # 'GetStatusInfo': 'urn:schemas-upnp-org:service:WANIPConnection:1',
'GetNATRSIPStatus': 'urn:schemas-upnp-org:service:WANIPConnection:1', # 'GetNATRSIPStatus': 'urn:schemas-upnp-org:service:WANIPConnection:1',
'GetGenericPortMappingEntry': 'urn:schemas-upnp-org:service:WANIPConnection:1', 'GetGenericPortMappingEntry': 'urn:schemas-upnp-org:service:WANIPConnection:1',
'GetSpecificPortMappingEntry': 'urn:schemas-upnp-org:service:WANIPConnection:1', 'GetSpecificPortMappingEntry': 'urn:schemas-upnp-org:service:WANIPConnection:1',
'AddPortMapping': 'urn:schemas-upnp-org:service:WANIPConnection:1', 'AddPortMapping': 'urn:schemas-upnp-org:service:WANIPConnection:1',
@ -111,22 +238,22 @@ class TestDiscoverNetgearNighthawkAC2350(TestDiscoverDLinkDIR890L):
} }
expected_commands = { expected_commands = {
"SetDefaultConnectionService": "urn:schemas-upnp-org:service:Layer3Forwarding:1", # "SetDefaultConnectionService": "urn:schemas-upnp-org:service:Layer3Forwarding:1",
"GetDefaultConnectionService": "urn:schemas-upnp-org:service:Layer3Forwarding:1", # "GetDefaultConnectionService": "urn:schemas-upnp-org:service:Layer3Forwarding:1",
"GetCommonLinkProperties": "urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1", # "GetCommonLinkProperties": "urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1",
"GetTotalBytesSent": "urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1", # "GetTotalBytesSent": "urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1",
"GetTotalBytesReceived": "urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1", # "GetTotalBytesReceived": "urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1",
"GetTotalPacketsSent": "urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1", # "GetTotalPacketsSent": "urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1",
"GetTotalPacketsReceived": "urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1", # "GetTotalPacketsReceived": "urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1",
"AddPortMapping": "urn:schemas-upnp-org:service:WANIPConnection:1", "AddPortMapping": "urn:schemas-upnp-org:service:WANIPConnection:1",
"GetExternalIPAddress": "urn:schemas-upnp-org:service:WANIPConnection:1", "GetExternalIPAddress": "urn:schemas-upnp-org:service:WANIPConnection:1",
"DeletePortMapping": "urn:schemas-upnp-org:service:WANIPConnection:1", "DeletePortMapping": "urn:schemas-upnp-org:service:WANIPConnection:1",
"SetConnectionType": "urn:schemas-upnp-org:service:WANIPConnection:1", # "SetConnectionType": "urn:schemas-upnp-org:service:WANIPConnection:1",
"GetConnectionTypeInfo": "urn:schemas-upnp-org:service:WANIPConnection:1", # "GetConnectionTypeInfo": "urn:schemas-upnp-org:service:WANIPConnection:1",
"RequestConnection": "urn:schemas-upnp-org:service:WANIPConnection:1", # "RequestConnection": "urn:schemas-upnp-org:service:WANIPConnection:1",
"ForceTermination": "urn:schemas-upnp-org:service:WANIPConnection:1", # "ForceTermination": "urn:schemas-upnp-org:service:WANIPConnection:1",
"GetStatusInfo": "urn:schemas-upnp-org:service:WANIPConnection:1", # "GetStatusInfo": "urn:schemas-upnp-org:service:WANIPConnection:1",
"GetNATRSIPStatus": "urn:schemas-upnp-org:service:WANIPConnection:1", # "GetNATRSIPStatus": "urn:schemas-upnp-org:service:WANIPConnection:1",
"GetGenericPortMappingEntry": "urn:schemas-upnp-org:service:WANIPConnection:1", "GetGenericPortMappingEntry": "urn:schemas-upnp-org:service:WANIPConnection:1",
"GetSpecificPortMappingEntry": "urn:schemas-upnp-org:service:WANIPConnection:1" "GetSpecificPortMappingEntry": "urn:schemas-upnp-org:service:WANIPConnection:1"
} }

70
tests/test_interfaces.py Normal file
View file

@ -0,0 +1,70 @@
from unittest import mock
from aioupnp.fault import UPnPError
from aioupnp.upnp import UPnP
from tests import AsyncioTestCase
class mock_netifaces:
@staticmethod
def gateways():
return {
"default": {
2: [
"192.168.1.1",
"test0"
]
},
2: [
[
"192.168.1.1",
"test0",
True
]
]
}
@staticmethod
def interfaces():
return ['test0']
@staticmethod
def ifaddresses(interface):
return {
"test0": {
17: [
{
"addr": "01:02:03:04:05:06",
"broadcast": "ff:ff:ff:ff:ff:ff"
}
],
2: [
{
"addr": "192.168.1.2",
"netmask": "255.255.255.0",
"broadcast": "192.168.1.255"
}
],
},
}[interface]
class TestParseInterfaces(AsyncioTestCase):
def test_parse_interfaces(self):
with mock.patch('aioupnp.interfaces.get_netifaces') as patch:
patch.return_value = mock_netifaces
lan, gateway = UPnP.get_lan_and_gateway(interface_name='test0')
self.assertEqual(gateway, '192.168.1.1')
self.assertEqual(lan, '192.168.1.2')
async def test_netifaces_fail(self):
checked = []
with mock.patch('aioupnp.interfaces.get_netifaces') as patch:
patch.return_value = mock_netifaces
try:
await UPnP.discover(interface_name='test1')
except UPnPError as err:
self.assertEqual(str(err), 'failed to get lan and gateway addresses for test1')
checked.append(True)
else:
self.assertTrue(False)
self.assertTrue(len(checked) == 1)

125
tests/test_upnp.py Normal file

File diff suppressed because one or more lines are too long