mypy refactor

This commit is contained in:
Jack Robison 2019-05-21 18:16:30 -04:00
parent a404269d91
commit 4137f7cd8a
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
20 changed files with 1357 additions and 643 deletions

4
.coveragerc Normal file
View file

@ -0,0 +1,4 @@
[run]
omit =
tests/*
stubs/*

3
.gitignore vendored
View file

@ -3,6 +3,9 @@
_trial_temp/ _trial_temp/
build/ build/
dist/ dist/
html/
index.html
mypy-html.css
.coverage .coverage
.mypy_cache/ .mypy_cache/
aioupnp.spec aioupnp.spec

440
.pylintrc Normal file
View file

@ -0,0 +1,440 @@
[MASTER]
# Specify a configuration file.
#rcfile=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS,schema
# Add files or directories matching the regex patterns to the
# blacklist. The regex matches against base names, not paths.
# `\.#.*` - add emacs tmp files to the blacklist
ignore-patterns=\.#.*
# Pickle collected data for later comparisons.
persistent=yes
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=1
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code
# extension-pkg-whitelist=
# Allow optimization of some AST trees. This will activate a peephole AST
# optimizer, which will apply various small optimizations. For instance, it can
# be used to obtain the result of joining multiple strings with the addition
# operator. Joining a lot of strings can lead to a maximum recursion error in
# Pylint and this flag can prevent that. It has one side effect, the resulting
# AST will be different than the one from reality.
optimize-ast=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then re-enable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=
anomalous-backslash-in-string,
arguments-differ,
attribute-defined-outside-init,
bad-continuation,
bare-except,
broad-except,
cell-var-from-loop,
consider-iterating-dictionary,
dangerous-default-value,
duplicate-code,
fixme,
global-statement,
inherit-non-class,
invalid-name,
len-as-condition,
locally-disabled,
logging-not-lazy,
missing-docstring,
no-else-return,
no-init,
no-member,
no-self-use,
protected-access,
redefined-builtin,
redefined-outer-name,
redefined-variable-type,
relative-import,
signature-differs,
super-init-not-called,
too-few-public-methods,
too-many-arguments,
too-many-branches,
too-many-instance-attributes,
too-many-lines,
too-many-locals,
too-many-nested-blocks,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
trailing-newlines,
undefined-loop-variable,
ungrouped-imports,
unnecessary-lambda,
unused-argument,
unused-variable,
wildcard-import,
wrong-import-order,
wrong-import-position,
deprecated-lambda,
simplifiable-if-statement,
unidiomatic-typecheck,
global-at-module-level,
inconsistent-return-statements,
keyword-arg-before-vararg,
assignment-from-no-return,
useless-return,
assignment-from-none,
stop-iteration-return
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text
# Put messages in a separate file for each module / package specified on the
# command line instead of printing them on stdout. Reports (if any) will be
# written in a file name "pylint_global.[txt|html]".
files-output=no
# Tells whether to display a full report or only the messages
reports=no
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
[VARIABLES]
# Tells whether we should check for unused import in __init__ files.
init-import=no
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=_$|dummy
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging
[BASIC]
# List of builtins function names that should not be used, separated by a comma
bad-functions=map,filter,input
# Good variable names which should always be accepted, separated by a comma
# allow `d` as its used frequently for deferred callback chains
good-names=i,j,k,ex,Run,_,d
# Bad variable names which should always be refused, separated by a comma
bad-names=foo,bar,baz,toto,tutu,tata
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# Regular expression matching correct function names
function-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for function names
function-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct variable names
variable-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for variable names
variable-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct constant names
const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
# Naming hint for constant names
const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
# Regular expression matching correct attribute names
attr-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for attribute names
attr-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct argument names
argument-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for argument names
argument-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct class attribute names
class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
# Naming hint for class attribute names
class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
# Regular expression matching correct inline iteration names
inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
# Naming hint for inline iteration names
inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
# Regular expression matching correct class names
class-rgx=[A-Z_][a-zA-Z0-9]+$
# Naming hint for class names
class-name-hint=[A-Z_][a-zA-Z0-9]+$
# Regular expression matching correct module names
module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
# Naming hint for module names
module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
# Regular expression matching correct method names
method-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for method names
method-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
[ELIF]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
[SPELLING]
# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=120
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,dict-separator
# Maximum number of lines in a module
max-module-lines=1000
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,XXX,TODO
[SIMILARITIES]
# Minimum lines number of a similarity.
min-similarity-lines=4
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
[TYPECHECK]
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=leveldb,distutils
# Ignoring distutils because: https://github.com/PyCQA/pylint/issues/73
# List of classes names for which member attributes should not be checked
# (useful for classes with attributes dynamically set). This supports can work
# with qualified names.
# ignored-classes=
[IMPORTS]
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,TERMIOS,Bastion,rexec,miniupnpc
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
[DESIGN]
# Maximum number of arguments for function / method
max-args=10
# Argument names that match this expression will be ignored. Default to name
# with leading underscore
ignored-argument-names=_.*
# Maximum number of locals for function / method body
max-locals=15
# Maximum number of return / yield for function / method body
max-returns=6
# Maximum number of branch for function / method body
max-branches=12
# Maximum number of statements in function / method body
max-statements=50
# Maximum number of parents for a class (see R0901).
max-parents=8
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of boolean expressions in a if statement
max-bool-expr=5
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,__new__,setUp
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,_fields,_replace,_source,_make
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=Exception

View file

@ -1,8 +1,11 @@
import logging
import sys import sys
import asyncio
import logging
import textwrap import textwrap
import typing
from collections import OrderedDict from collections import OrderedDict
from aioupnp.upnp import UPnP from aioupnp.upnp import UPnP
from aioupnp.commands import SOAPCommands
log = logging.getLogger("aioupnp") log = logging.getLogger("aioupnp")
handler = logging.StreamHandler() handler = logging.StreamHandler()
@ -16,17 +19,18 @@ base_usage = "\n".join(textwrap.wrap(
100, subsequent_indent=' ', break_long_words=False)) + "\n" 100, subsequent_indent=' ', break_long_words=False)) + "\n"
def get_help(command): def get_help(command: str) -> str:
fn = getattr(UPnP, command) annotations = UPnP.get_annotations(command)
params = command + " " + " ".join(["[--%s=<%s>]" % (k, k) for k in fn.__annotations__ if k != 'return']) params = command + " " + " ".join(["[--%s=<%s>]" % (k, str(v)) for k, v in annotations.items() if k != 'return'])
return base_usage + "\n".join( return base_usage + "\n".join(
textwrap.wrap(params, 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False) textwrap.wrap(params, 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False)
) )
def main(argv=None, loop=None): def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
argv = argv or sys.argv loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> int:
commands = [n for n in dir(UPnP) if hasattr(getattr(UPnP, n, None), "_cli")] argv = argv or list(sys.argv)
commands = list(SOAPCommands.SOAP_COMMANDS)
help_str = "\n".join(textwrap.wrap( help_str = "\n".join(textwrap.wrap(
" | ".join(commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False " | ".join(commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False
)) ))
@ -41,14 +45,16 @@ def main(argv=None, loop=None):
"For help with a specific command:" \ "For help with a specific command:" \
" aioupnp help <command>\n" % (base_usage, help_str) " aioupnp help <command>\n" % (base_usage, help_str)
args = argv[1:] args: typing.List[str] = [str(arg) for arg in argv[1:]]
if args[0] in ['help', '-h', '--help']: if args[0] in ['help', '-h', '--help']:
if len(args) > 1: if len(args) > 1:
if args[1] in commands: if args[1] in commands:
sys.exit(get_help(args[1])) print(get_help(args[1]))
sys.exit(print(usage)) return 0
print(usage)
return 0
defaults = { defaults: typing.Dict[str, typing.Union[bool, str, int]] = {
'debug_logging': False, 'debug_logging': False,
'interface': 'default', 'interface': 'default',
'gateway_address': '', 'gateway_address': '',
@ -57,22 +63,22 @@ def main(argv=None, loop=None):
'unicast': False 'unicast': False
} }
options = OrderedDict() options: typing.Dict[str, typing.Union[bool, str, int]] = OrderedDict()
command = None command = None
for arg in args: for arg in args:
if arg.startswith("--"): if arg.startswith("--"):
if "=" in arg: if "=" in arg:
k, v = arg.split("=") k, v = arg.split("=")
options[k.lstrip('--')] = v
else: else:
k, v = arg, True options[arg.lstrip('--')] = True
k = k.lstrip('--')
options[k] = v
else: else:
command = arg command = arg
break break
if not command: if not command:
print("no command given") print("no command given")
sys.exit(print(usage)) print(usage)
return 0
kwargs = {} kwargs = {}
for arg in args[len(options)+1:]: for arg in args[len(options)+1:]:
if arg.startswith("--"): if arg.startswith("--"):
@ -81,18 +87,24 @@ def main(argv=None, loop=None):
kwargs[k] = v kwargs[k] = v
else: else:
break break
for k, v in defaults.items(): for k in defaults:
if k not in options: if k not in options:
options[k] = v options[k] = defaults[k]
if options.pop('debug_logging'): if options.pop('debug_logging'):
log.setLevel(logging.DEBUG) log.setLevel(logging.DEBUG)
lan_address: str = str(options.pop('lan_address'))
gateway_address: str = str(options.pop('gateway_address'))
timeout: int = int(options.pop('timeout'))
interface: str = str(options.pop('interface'))
unicast: bool = bool(options.pop('unicast'))
UPnP.run_cli( UPnP.run_cli(
command.replace('-', '_'), options, options.pop('lan_address'), options.pop('gateway_address'), command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop
options.pop('timeout'), options.pop('interface'), options.pop('unicast'), kwargs, loop
) )
return 0
if __name__ == "__main__": if __name__ == "__main__":
main() sys.exit(main())

View file

@ -1,64 +1,67 @@
import logging import asyncio
import time import time
import typing import typing
from typing import Tuple, Union, List import functools
import logging
from typing import Tuple
from aioupnp.protocols.scpd import scpd_post from aioupnp.protocols.scpd import scpd_post
from aioupnp.device import Service
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
none_or_str = Union[None, str]
return_type_lambas = {
Union[None, str]: lambda x: x if x is not None and str(x).lower() not in ['none', 'nil'] else None
}
def safe_type(t): def soap_optional_str(x: typing.Optional[str]) -> typing.Optional[str]:
if t is typing.Tuple: return x if x is not None and str(x).lower() not in ['none', 'nil'] else None
return tuple
if t is typing.List:
return list
if t is typing.Dict:
return dict
if t is typing.Set:
return set
return t
class SOAPCommand: def soap_bool(x: typing.Optional[str]) -> bool:
def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str, return False if not x or str(x).lower() in ['false', 'False'] else True
param_types: dict, return_types: dict, param_order: list, return_order: list, loop=None) -> None:
self.gateway_address = gateway_address
self.service_port = service_port
self.control_url = control_url
self.service_id = service_id
self.method = method
self.param_types = param_types
self.param_order = param_order
self.return_types = return_types
self.return_order = return_order
self.loop = loop
self._requests: typing.List = []
async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]:
if set(kwargs.keys()) != set(self.param_types.keys()): def recast_single_result(t, result):
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys())) if t is bool:
soap_kwargs = {n: safe_type(self.param_types[n])(kwargs[n]) for n in self.param_types.keys()} return soap_bool(result)
if t is str:
return soap_optional_str(result)
return t(result)
def recast_return(return_annotation, result, result_keys: typing.List[str]):
if return_annotation is None:
return None
if len(result_keys) == 1:
assert len(result_keys) == 1
single_result = result[result_keys[0]]
return recast_single_result(return_annotation, single_result)
annotated_args: typing.List[type] = list(return_annotation.__args__)
assert len(annotated_args) == len(result_keys)
recast_results: typing.List[typing.Optional[typing.Union[str, int, bool, bytes]]] = []
for type_annotation, result_key in zip(annotated_args, result_keys):
recast_results.append(recast_single_result(type_annotation, result[result_key]))
return tuple(recast_results)
def soap_command(fn):
@functools.wraps(fn)
async def wrapper(self: 'SOAPCommands', **kwargs):
if not self.is_registered(fn.__name__):
return fn(self, **kwargs)
service = self.get_service(fn.__name__)
assert service.controlURL is not None
assert service.serviceType is not None
response, xml_bytes, err = await scpd_post( response, xml_bytes, err = await scpd_post(
self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, service.controlURL, self._base_address.decode(), self._port, fn.__name__, self._registered[service][fn.__name__][0],
self.service_id, self.loop, **soap_kwargs service.serviceType.encode(), self._loop, **kwargs
) )
if err is not None: if err is not None:
self._requests.append((soap_kwargs, xml_bytes, None, err, time.time()))
self._requests.append((fn.__name__, kwargs, xml_bytes, None, err, time.time()))
raise err raise err
if not response: result = recast_return(fn.__annotations__.get('return'), response, self._registered[service][fn.__name__][1])
result = None self._requests.append((fn.__name__, kwargs, xml_bytes, result, None, time.time()))
else:
recast_result = tuple([safe_type(self.return_types[n])(response.get(n)) for n in self.return_order])
if len(recast_result) == 1:
result = recast_result[0]
else:
result = recast_result
self._requests.append((soap_kwargs, xml_bytes, result, None, time.time()))
return result return result
return wrapper
class SOAPCommands: class SOAPCommands:
@ -72,7 +75,7 @@ class SOAPCommands:
to their expected types. to their expected types.
""" """
SOAP_COMMANDS = [ SOAP_COMMANDS: typing.List[str] = [
'AddPortMapping', 'AddPortMapping',
'GetNATRSIPStatus', 'GetNATRSIPStatus',
'GetGenericPortMappingEntry', 'GetGenericPortMappingEntry',
@ -91,59 +94,63 @@ class SOAPCommands:
'GetTotalPacketsReceived', 'GetTotalPacketsReceived',
'X_GetICSStatistics', 'X_GetICSStatistics',
'GetDefaultConnectionService', 'GetDefaultConnectionService',
'NewDefaultConnectionService',
'NewEnabledForInternet',
'SetDefaultConnectionService', 'SetDefaultConnectionService',
'SetEnabledForInternet', 'SetEnabledForInternet',
'GetEnabledForInternet', 'GetEnabledForInternet',
'NewActiveConnectionIndex',
'GetMaximumActiveConnections', 'GetMaximumActiveConnections',
'GetActiveConnections' 'GetActiveConnections'
] ]
def __init__(self): def __init__(self, loop: asyncio.AbstractEventLoop, base_address: bytes, port: int) -> None:
self._registered = set() self._loop = loop
self._registered: typing.Dict[Service,
typing.Dict[str, typing.Tuple[typing.List[str], typing.List[str]]]] = {}
self._base_address = base_address
self._port = port
self._requests: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
typing.Optional[typing.Dict[str, typing.Any]],
typing.Optional[Exception], float]] = []
def register(self, base_ip: bytes, port: int, name: str, control_url: str, def is_registered(self, name: str) -> bool:
service_type: bytes, inputs: List, outputs: List, loop=None) -> None: if name not in self.SOAP_COMMANDS:
if name not in self.SOAP_COMMANDS or name in self._registered: raise ValueError("unknown command")
for service in self._registered.values():
if name in service:
return True
return False
def get_service(self, name: str) -> Service:
if name not in self.SOAP_COMMANDS:
raise ValueError("unknown command")
for service, commands in self._registered.items():
if name in commands:
return service
raise ValueError(name)
def register(self, name: str, service: Service, inputs: typing.List[str], outputs: typing.List[str]) -> None:
# control_url: str, service_type: bytes,
if name not in self.SOAP_COMMANDS:
raise AttributeError(name) raise AttributeError(name)
current = getattr(self, name) if self.is_registered(name):
annotations = current.__annotations__ raise AttributeError(f"{name} is already a registered SOAP command")
return_types = annotations.get('return', None) if service not in self._registered:
if return_types: self._registered[service] = {}
if hasattr(return_types, '__args__'): self._registered[service][name] = inputs, outputs
return_types = tuple([return_type_lambas.get(a, a) for a in return_types.__args__])
elif isinstance(return_types, type):
return_types = (return_types,)
return_types = {r: t for r, t in zip(outputs, return_types)}
param_types = {}
for param_name, param_type in annotations.items():
if param_name == "return":
continue
param_types[param_name] = param_type
command = SOAPCommand(
base_ip.decode(), port, control_url, service_type,
name, param_types, return_types, inputs, outputs, loop=loop
)
setattr(command, "__doc__", current.__doc__)
setattr(self, command.method, command)
self._registered.add(command.method)
@staticmethod @soap_command
async def AddPortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int, async def AddPortMapping(self, NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int,
NewInternalClient: str, NewEnabled: int, NewPortMappingDescription: str, NewInternalClient: str, NewEnabled: int, NewPortMappingDescription: str,
NewLeaseDuration: str) -> None: NewLeaseDuration: str) -> None:
"""Returns None""" """Returns None"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def GetNATRSIPStatus() -> Tuple[bool, bool]: async def GetNATRSIPStatus(self) -> Tuple[bool, bool]:
"""Returns (NewRSIPAvailable, NewNATEnabled)""" """Returns (NewRSIPAvailable, NewNATEnabled)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def GetGenericPortMappingEntry(NewPortMappingIndex: int) -> Tuple[str, int, str, int, str, async def GetGenericPortMappingEntry(self, NewPortMappingIndex: int) -> Tuple[str, int, str, int, str,
bool, str, int]: bool, str, int]:
""" """
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled, Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
@ -151,100 +158,100 @@ class SOAPCommands:
""" """
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def GetSpecificPortMappingEntry(NewRemoteHost: str, NewExternalPort: int, async def GetSpecificPortMappingEntry(self, NewRemoteHost: str, NewExternalPort: int,
NewProtocol: str) -> Tuple[int, str, bool, str, int]: NewProtocol: str) -> Tuple[int, str, bool, str, int]:
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)""" """Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def SetConnectionType(NewConnectionType: str) -> None: async def SetConnectionType(self, NewConnectionType: str) -> None:
"""Returns None""" """Returns None"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def GetExternalIPAddress() -> str: async def GetExternalIPAddress(self) -> str:
"""Returns (NewExternalIPAddress)""" """Returns (NewExternalIPAddress)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def GetConnectionTypeInfo() -> Tuple[str, str]: async def GetConnectionTypeInfo(self) -> Tuple[str, str]:
"""Returns (NewConnectionType, NewPossibleConnectionTypes)""" """Returns (NewConnectionType, NewPossibleConnectionTypes)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def GetStatusInfo() -> Tuple[str, str, int]: async def GetStatusInfo(self) -> Tuple[str, str, int]:
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)""" """Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def ForceTermination() -> None: async def ForceTermination(self) -> None:
"""Returns None""" """Returns None"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def DeletePortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> None: async def DeletePortMapping(self, NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> None:
"""Returns None""" """Returns None"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def RequestConnection() -> None: async def RequestConnection(self) -> None:
"""Returns None""" """Returns None"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def GetCommonLinkProperties(): async def GetCommonLinkProperties(self) -> Tuple[str, int, int, str]:
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)""" """Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def GetTotalBytesSent(): async def GetTotalBytesSent(self) -> int:
"""Returns (NewTotalBytesSent)""" """Returns (NewTotalBytesSent)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def GetTotalBytesReceived(): async def GetTotalBytesReceived(self) -> int:
"""Returns (NewTotalBytesReceived)""" """Returns (NewTotalBytesReceived)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def GetTotalPacketsSent(): async def GetTotalPacketsSent(self) -> int:
"""Returns (NewTotalPacketsSent)""" """Returns (NewTotalPacketsSent)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
def GetTotalPacketsReceived(): async def GetTotalPacketsReceived(self) -> int:
"""Returns (NewTotalPacketsReceived)""" """Returns (NewTotalPacketsReceived)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def X_GetICSStatistics() -> Tuple[int, int, int, int, str, str]: async def X_GetICSStatistics(self) -> Tuple[int, int, int, int, str, str]:
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)""" """Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def GetDefaultConnectionService(): async def GetDefaultConnectionService(self) -> str:
"""Returns (NewDefaultConnectionService)""" """Returns (NewDefaultConnectionService)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def SetDefaultConnectionService(NewDefaultConnectionService: str) -> None: async def SetDefaultConnectionService(self, NewDefaultConnectionService: str) -> None:
"""Returns (None)""" """Returns (None)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def SetEnabledForInternet(NewEnabledForInternet: bool) -> None: async def SetEnabledForInternet(self, NewEnabledForInternet: bool) -> None:
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def GetEnabledForInternet() -> bool: async def GetEnabledForInternet(self) -> bool:
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def GetMaximumActiveConnections(NewActiveConnectionIndex: int): async def GetMaximumActiveConnections(self, NewActiveConnectionIndex: int):
raise NotImplementedError() raise NotImplementedError()
@staticmethod @soap_command
async def GetActiveConnections() -> Tuple[str, str]: async def GetActiveConnections(self) -> Tuple[str, str]:
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID""" """Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
raise NotImplementedError() raise NotImplementedError()

View file

@ -1,23 +1,33 @@
from collections import OrderedDict
import typing
import logging import logging
from typing import List
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class CaseInsensitive: class CaseInsensitive:
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List[typing.Any]]]) -> None:
for k, v in kwargs.items(): keys: typing.List[str] = list(kwargs.keys())
for k in keys:
if not k.startswith("_"): if not k.startswith("_"):
setattr(self, k, v) assert k in kwargs
setattr(self, k, kwargs[k])
def __getattr__(self, item): def __getattr__(self, item: str) -> typing.Union[str, typing.Dict[str, typing.Any], typing.List]:
for k in self.__class__.__dict__.keys(): keys: typing.List[str] = list(self.__class__.__dict__.keys())
for k in keys:
if k.lower() == item.lower(): if k.lower() == item.lower():
return self.__dict__.get(k) value: typing.Optional[typing.Union[str, typing.Dict[str, typing.Any],
typing.List]] = self.__dict__.get(k)
assert value is not None and isinstance(value, (str, dict, list))
return value
raise AttributeError(item) raise AttributeError(item)
def __setattr__(self, item, value): def __setattr__(self, item: str,
for k, v in self.__class__.__dict__.items(): value: typing.Union[str, typing.Dict[str, typing.Any], typing.List]) -> None:
assert isinstance(value, (str, dict)), ValueError(f"got type {str(type(value))}, expected str")
keys: typing.List[str] = list(self.__class__.__dict__.keys())
for k in keys:
if k.lower() == item.lower(): if k.lower() == item.lower():
self.__dict__[k] = value self.__dict__[k] = value
return return
@ -26,52 +36,57 @@ class CaseInsensitive:
return return
raise AttributeError(item) raise AttributeError(item)
def as_dict(self) -> dict: def as_dict(self) -> typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List]]:
return { result: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List]] = OrderedDict()
k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v) keys: typing.List[str] = list(self.__dict__.keys())
} for k in keys:
if not k.startswith("_"):
result[k] = self.__getattr__(k)
return result
class Service(CaseInsensitive): class Service(CaseInsensitive):
serviceType = None serviceType: typing.Optional[str] = None
serviceId = None serviceId: typing.Optional[str] = None
controlURL = None controlURL: typing.Optional[str] = None
eventSubURL = None eventSubURL: typing.Optional[str] = None
SCPDURL = None SCPDURL: typing.Optional[str] = None
class Device(CaseInsensitive): class Device(CaseInsensitive):
serviceList = None serviceList: typing.Optional[typing.Dict[str, typing.Union[typing.Dict[str, typing.Any], typing.List]]] = None
deviceList = None deviceList: typing.Optional[typing.Dict[str, typing.Union[typing.Dict[str, typing.Any], typing.List]]] = None
deviceType = None deviceType: typing.Optional[str] = None
friendlyName = None friendlyName: typing.Optional[str] = None
manufacturer = None manufacturer: typing.Optional[str] = None
manufacturerURL = None manufacturerURL: typing.Optional[str] = None
modelDescription = None modelDescription: typing.Optional[str] = None
modelName = None modelName: typing.Optional[str] = None
modelNumber = None modelNumber: typing.Optional[str] = None
modelURL = None modelURL: typing.Optional[str] = None
serialNumber = None serialNumber: typing.Optional[str] = None
udn = None udn: typing.Optional[str] = None
upc = None upc: typing.Optional[str] = None
presentationURL = None presentationURL: typing.Optional[str] = None
iconList = None iconList: typing.Optional[str] = None
def __init__(self, devices: List, services: List, **kwargs) -> None: def __init__(self, devices: typing.List['Device'], services: typing.List[Service],
**kwargs: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List]]) -> None:
super(Device, self).__init__(**kwargs) super(Device, self).__init__(**kwargs)
if self.serviceList and "service" in self.serviceList: if self.serviceList and "service" in self.serviceList:
new_services = self.serviceList["service"] if isinstance(self.serviceList['service'], dict):
if isinstance(new_services, dict): assert isinstance(self.serviceList['service'], dict)
new_services = [new_services] svc_list: typing.Dict[str, typing.Any] = self.serviceList['service']
services.extend([Service(**service) for service in new_services]) services.append(Service(**svc_list))
elif isinstance(self.serviceList['service'], list):
services.extend(Service(**svc) for svc in self.serviceList["service"])
if self.deviceList: if self.deviceList:
for kw in self.deviceList.values(): for kw in self.deviceList.values():
if isinstance(kw, dict): if isinstance(kw, dict):
d = Device(devices, services, **kw) devices.append(Device(devices, services, **kw))
devices.append(d)
elif isinstance(kw, list): elif isinstance(kw, list):
for _inner_kw in kw: for _inner_kw in kw:
d = Device(devices, services, **_inner_kw) devices.append(Device(devices, services, **_inner_kw))
devices.append(d)
else: else:
log.warning("failed to parse device:\n%s", kw) log.warning("failed to parse device:\n%s", kw)

View file

@ -1,14 +1,2 @@
from aioupnp.util import flatten_keys
from aioupnp.constants import FAULT, CONTROL
class UPnPError(Exception): class UPnPError(Exception):
pass pass
def handle_fault(response: dict) -> dict:
if FAULT in response:
fault = flatten_keys(response[FAULT], "{%s}" % CONTROL)
error_description = fault['detail']['UPnPError']['errorDescription']
raise UPnPError(error_description)
return response

View file

@ -1,9 +1,10 @@
import re
import logging import logging
import socket import typing
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, List, Union, Type from typing import Dict, List, Union
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX from aioupnp.util import get_dict_val_case_insensitive
from aioupnp.constants import SPEC_VERSION, SERVICE from aioupnp.constants import SPEC_VERSION, SERVICE
from aioupnp.commands import SOAPCommands from aioupnp.commands import SOAPCommands
from aioupnp.device import Device, Service from aioupnp.device import Device, Service
@ -15,77 +16,94 @@ from aioupnp.fault import UPnPError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
return_type_lambas = { BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
Union[None, str]: lambda x: x if x is not None and str(x).lower() not in ['none', 'nil'] else None BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
}
def get_action_list(element_dict: dict) -> List: # [(<method>, [<input1>, ...], [<output1, ...]), ...] def get_action_list(element_dict: typing.Dict[str, typing.Union[str, typing.Dict[str, str],
typing.List[typing.Dict[str, typing.Dict[str, str]]]]]
) -> typing.List[typing.Tuple[str, typing.List[str], typing.List[str]]]:
service_info = flatten_keys(element_dict, "{%s}" % SERVICE) service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
result: typing.List[typing.Tuple[str, typing.List[str], typing.List[str]]] = []
if "actionList" in service_info: if "actionList" in service_info:
action_list = service_info["actionList"] action_list = service_info["actionList"]
else: else:
return [] return result
if not len(action_list): # it could be an empty string if not len(action_list): # it could be an empty string
return [] return result
result: list = [] action = action_list["action"]
if isinstance(action_list["action"], dict): if isinstance(action, dict):
arg_dicts = action_list["action"]['argumentList']['argument'] arg_dicts: typing.List[typing.Dict[str, str]] = []
if not isinstance(arg_dicts, list): # when there is one arg if not isinstance(action['argumentList']['argument'], list): # when there is one arg
arg_dicts = [arg_dicts] arg_dicts.extend([action['argumentList']['argument']])
return [[
action_list["action"]['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
]]
for action in action_list["action"]:
if not action.get('argumentList'):
result.append((action['name'], [], []))
else: else:
arg_dicts = action['argumentList']['argument'] arg_dicts.extend(action['argumentList']['argument'])
if not isinstance(arg_dicts, list): # when there is one arg
arg_dicts = [arg_dicts] result.append((action_list["action"]['name'], [i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']))
return result
assert isinstance(action, list)
for _action in action:
if not _action.get('argumentList'):
result.append((_action['name'], [], []))
else:
if not isinstance(_action['argumentList']['argument'], list): # when there is one arg
arg_dicts = [_action['argumentList']['argument']]
else:
arg_dicts = _action['argumentList']['argument']
result.append(( result.append((
action['name'], _action['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'], [i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out'] [i['name'] for i in arg_dicts if i['direction'] == 'out']
)) ))
return result return result
def parse_location(location: bytes) -> typing.Tuple[bytes, int]:
base_address_result: typing.List[bytes] = BASE_ADDRESS_REGEX.findall(location)
base_address = base_address_result[0]
port_result: typing.List[bytes] = BASE_PORT_REGEX.findall(location)
port = int(port_result[0])
return base_address, port
class Gateway: class Gateway:
def __init__(self, ok_packet: SSDPDatagram, m_search_args: OrderedDict, lan_address: str, commands: SOAPCommands
gateway_address: str) -> None:
def __init__(self, ok_packet: SSDPDatagram, m_search_args: typing.Dict[str, typing.Union[int, str]],
lan_address: str, gateway_address: str,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
self._loop = loop or asyncio.get_event_loop()
self._ok_packet = ok_packet self._ok_packet = ok_packet
self._m_search_args = m_search_args self._m_search_args = m_search_args
self._lan_address = lan_address self._lan_address = lan_address
self.usn = (ok_packet.usn or '').encode() self.usn: bytes = (ok_packet.usn or '').encode()
self.ext = (ok_packet.ext or '').encode() self.ext: bytes = (ok_packet.ext or '').encode()
self.server = (ok_packet.server or '').encode() self.server: bytes = (ok_packet.server or '').encode()
self.location = (ok_packet.location or '').encode() self.location: bytes = (ok_packet.location or '').encode()
self.cache_control = (ok_packet.cache_control or '').encode() self.cache_control: bytes = (ok_packet.cache_control or '').encode()
self.date = (ok_packet.date or '').encode() self.date: bytes = (ok_packet.date or '').encode()
self.urn = (ok_packet.st or '').encode() self.urn: bytes = (ok_packet.st or '').encode()
self._xml_response = b"" self._xml_response: bytes = b""
self._service_descriptors: Dict = {} self._service_descriptors: Dict[str, bytes] = {}
self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0]
self.port = int(BASE_PORT_REGEX.findall(self.location)[0]) self.base_address, self.port = parse_location(self.location)
self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0] self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0]
assert self.base_ip == gateway_address.encode() assert self.base_ip == gateway_address.encode()
self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1] self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1]
self.spec_version = None self.spec_version: typing.Optional[str] = None
self.url_base = None self.url_base: typing.Optional[str] = None
self._device: Union[None, Device] = None self._device: typing.Optional[Device] = None
self._devices: List = [] self._devices: List[Device] = []
self._services: List = [] self._services: List[Service] = []
self._unsupported_actions: Dict = {} self._unsupported_actions: Dict[str, typing.List[str]] = {}
self._registered_commands: Dict = {} self._registered_commands: Dict[str, str] = {}
self.commands = SOAPCommands() self.commands = SOAPCommands(self._loop, self.base_ip, self.port)
def gateway_descriptor(self) -> dict: def gateway_descriptor(self) -> dict:
r = { r = {
@ -102,14 +120,15 @@ class Gateway:
def manufacturer_string(self) -> str: def manufacturer_string(self) -> str:
if not self.devices: if not self.devices:
return "UNKNOWN GATEWAY" return "UNKNOWN GATEWAY"
device = list(self.devices.values())[0] devices: typing.List[Device] = list(self.devices.values())
return "%s %s" % (device.manufacturer, device.modelName) device = devices[0]
return f"{device.manufacturer} {device.modelName}"
@property @property
def services(self) -> Dict: def services(self) -> Dict[str, Service]:
if not self._device: if not self._device:
return {} return {}
return {service.serviceType: service for service in self._services} return {str(service.serviceType): service for service in self._services}
@property @property
def devices(self) -> Dict: def devices(self) -> Dict:
@ -117,28 +136,29 @@ class Gateway:
return {} return {}
return {device.udn: device for device in self._devices} return {device.udn: device for device in self._devices}
def get_service(self, service_type: str) -> Union[Type[Service], None]: def get_service(self, service_type: str) -> typing.Optional[Service]:
for service in self._services: for service in self._services:
if service.serviceType.lower() == service_type.lower(): if service.serviceType and service.serviceType.lower() == service_type.lower():
return service return service
return None return None
@property @property
def soap_requests(self) -> List: def soap_requests(self) -> typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
soap_call_infos = [] typing.Optional[typing.Dict[str, typing.Any]],
for name in self._registered_commands.keys(): typing.Optional[Exception], float]]:
if not hasattr(getattr(self.commands, name), "_requests"): soap_call_infos: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
continue typing.Optional[typing.Dict[str, typing.Any]],
soap_call_infos.extend([ typing.Optional[Exception], float]] = []
(name, request_args, raw_response, decoded_response, soap_error, ts) soap_call_infos.extend([
for ( (name, request_args, raw_response, decoded_response, soap_error, ts)
request_args, raw_response, decoded_response, soap_error, ts for (
) in getattr(self.commands, name)._requests name, request_args, raw_response, decoded_response, soap_error, ts
]) ) in self.commands._requests
])
soap_call_infos.sort(key=lambda x: x[5]) soap_call_infos.sort(key=lambda x: x[5])
return soap_call_infos return soap_call_infos
def debug_gateway(self) -> Dict: def debug_gateway(self) -> Dict[str, Union[str, bytes, int, Dict, List]]:
return { return {
'manufacturer_string': self.manufacturer_string, 'manufacturer_string': self.manufacturer_string,
'gateway_address': self.base_ip, 'gateway_address': self.base_ip,
@ -156,9 +176,11 @@ class Gateway:
@classmethod @classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: OrderedDict = None, loop=None, unicast: bool = False): igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None,
ignored: set = set() loop: typing.Optional[asyncio.AbstractEventLoop] = None,
required_commands = [ unicast: bool = False) -> 'Gateway':
ignored: typing.Set[str] = set()
required_commands: typing.List[str] = [
'AddPortMapping', 'AddPortMapping',
'DeletePortMapping', 'DeletePortMapping',
'GetExternalIPAddress' 'GetExternalIPAddress'
@ -166,20 +188,21 @@ class Gateway:
while True: while True:
if not igd_args: if not igd_args:
m_search_args, datagram = await fuzzy_m_search( m_search_args, datagram = await fuzzy_m_search(
lan_address, gateway_address, timeout, loop, ignored, unicast lan_address, gateway_address, timeout, loop, ignored, unicast
) )
else: else:
m_search_args = OrderedDict(igd_args) m_search_args = OrderedDict(igd_args)
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, ignored, unicast) datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, ignored, unicast)
try: try:
gateway = cls(datagram, m_search_args, lan_address, gateway_address) gateway = cls(datagram, m_search_args, lan_address, gateway_address, loop=loop)
log.debug('get gateway descriptor %s', datagram.location) log.debug('get gateway descriptor %s', datagram.location)
await gateway.discover_commands(loop) await gateway.discover_commands(loop)
requirements_met = all([required in gateway._registered_commands for required in required_commands]) requirements_met = all([gateway.commands.is_registered(required) for required in required_commands])
if not requirements_met: if not requirements_met:
not_met = [ not_met = [
required for required in required_commands if required not in gateway._registered_commands required for required in required_commands if not gateway.commands.is_registered(required)
] ]
assert datagram.location is not None
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s", log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
gateway.manufacturer_string, gateway.location, not_met) gateway.manufacturer_string, gateway.location, not_met)
ignored.add(datagram.location) ignored.add(datagram.location)
@ -188,13 +211,17 @@ class Gateway:
log.debug('found gateway device %s', datagram.location) log.debug('found gateway device %s', datagram.location)
return gateway return gateway
except (asyncio.TimeoutError, UPnPError) as err: except (asyncio.TimeoutError, UPnPError) as err:
assert datagram.location is not None
log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err)) log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err))
ignored.add(datagram.location) ignored.add(datagram.location)
continue continue
@classmethod @classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: OrderedDict = None, loop=None, unicast: bool = None): igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
unicast: typing.Optional[bool] = None) -> 'Gateway':
loop = loop or asyncio.get_event_loop()
if unicast is not None: if unicast is not None:
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast) return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast)
@ -205,7 +232,7 @@ class Gateway:
cls._discover_gateway( cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop, unicast=False lan_address, gateway_address, timeout, igd_args, loop, unicast=False
) )
], return_when=asyncio.tasks.FIRST_COMPLETED) ], return_when=asyncio.tasks.FIRST_COMPLETED, loop=loop)
for task in pending: for task in pending:
task.cancel() task.cancel()
@ -214,56 +241,78 @@ class Gateway:
task.exception() task.exception()
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
results: typing.List[asyncio.Future['Gateway']] = list(done)
return results[0].result()
return list(done)[0].result() async def discover_commands(self, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
async def discover_commands(self, loop=None):
response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port, loop=loop) response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port, loop=loop)
self._xml_response = xml_bytes self._xml_response = xml_bytes
if get_err is not None: if get_err is not None:
raise get_err raise get_err
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION) spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
self.url_base = get_dict_val_case_insensitive(response, "urlbase") if isinstance(spec_version, bytes):
self.spec_version = spec_version.decode()
else:
self.spec_version = spec_version
url_base = get_dict_val_case_insensitive(response, "urlbase")
if isinstance(url_base, bytes):
self.url_base = url_base.decode()
else:
self.url_base = url_base
if not self.url_base: if not self.url_base:
self.url_base = self.base_address.decode() self.url_base = self.base_address.decode()
if response: if response:
device_dict = get_dict_val_case_insensitive(response, "device") source_keys: typing.List[str] = list(response.keys())
matches: typing.List[str] = list(filter(lambda x: x.lower() == "device", source_keys))
match_key = matches[0]
match: dict = response[match_key]
# if not len(match):
# return None
# if len(match) > 1:
# raise KeyError("overlapping keys")
# if len(match) == 1:
# matched_key: typing.AnyStr = match[0]
# return source[matched_key]
# raise KeyError("overlapping keys")
self._device = Device( self._device = Device(
self._devices, self._services, **device_dict self._devices, self._services, **match
) )
else: else:
self._device = Device(self._devices, self._services) self._device = Device(self._devices, self._services)
for service_type in self.services.keys(): for service_type in self.services.keys():
await self.register_commands(self.services[service_type], loop) await self.register_commands(self.services[service_type], loop)
return None
async def register_commands(self, service: Service, loop=None): async def register_commands(self, service: Service,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
if not service.SCPDURL: if not service.SCPDURL:
raise UPnPError("no scpd url") raise UPnPError("no scpd url")
if not service.serviceType:
raise UPnPError("no service type")
log.debug("get descriptor for %s from %s", service.serviceType, service.SCPDURL) log.debug("get descriptor for %s from %s", service.serviceType, service.SCPDURL)
service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port) service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port, loop=loop)
self._service_descriptors[service.SCPDURL] = xml_bytes self._service_descriptors[service.SCPDURL] = xml_bytes
if get_err is not None: if get_err is not None:
log.debug("failed to get descriptor for %s from %s", service.serviceType, service.SCPDURL) log.debug("failed to get descriptor for %s from %s", service.serviceType, service.SCPDURL)
if xml_bytes: if xml_bytes:
log.debug("response: %s", xml_bytes.decode()) log.debug("response: %s", xml_bytes.decode())
return return None
if not service_dict: if not service_dict:
return return None
action_list = get_action_list(service_dict) action_list = get_action_list(service_dict)
for name, inputs, outputs in action_list: for name, inputs, outputs in action_list:
try: try:
self.commands.register(self.base_ip, self.port, name, service.controlURL, service.serviceType.encode(), self.commands.register(name, service, inputs, outputs)
inputs, outputs, loop)
self._registered_commands[name] = service.serviceType self._registered_commands[name] = service.serviceType
log.debug("registered %s::%s", service.serviceType, name) log.debug("registered %s::%s", service.serviceType, name)
except AttributeError: except AttributeError:
s = self._unsupported_actions.get(service.serviceType, []) self._unsupported_actions.setdefault(service.serviceType, [])
s.append(name) self._unsupported_actions[service.serviceType].append(name)
self._unsupported_actions[service.serviceType] = s
log.debug("available command for %s does not have a wrapper implemented: %s %s %s", log.debug("available command for %s does not have a wrapper implemented: %s %s %s",
service.serviceType, name, inputs, outputs) service.serviceType, name, inputs, outputs)
log.debug("registered service %s", service.serviceType) log.debug("registered service %s", service.serviceType)
return None

51
aioupnp/interfaces.py Normal file
View file

@ -0,0 +1,51 @@
import socket
from collections import OrderedDict
import typing
import netifaces
def get_netifaces():
return netifaces
def ifaddresses(iface: str):
return get_netifaces().ifaddresses(iface)
def _get_interfaces():
return get_netifaces().interfaces()
def _get_gateways():
return get_netifaces().gateways()
def get_interfaces() -> typing.Dict[str, typing.Tuple[str, str]]:
gateways = _get_gateways()
infos = gateways[socket.AF_INET]
assert isinstance(infos, list), TypeError(f"expected list from netifaces, got a dict")
interface_infos: typing.List[typing.Tuple[str, str, bool]] = infos
result: typing.Dict[str, typing.Tuple[str, str]] = OrderedDict(
(interface_name, (router_address, ifaddresses(interface_name)[netifaces.AF_INET][0]['addr']))
for router_address, interface_name, _ in interface_infos
)
for interface_name in _get_interfaces():
if interface_name in ['lo', 'localhost'] or interface_name in result:
continue
addresses = ifaddresses(interface_name)
if netifaces.AF_INET in addresses:
address = addresses[netifaces.AF_INET][0]['addr']
gateway_guess = ".".join(address.split(".")[:-1] + ["1"])
result[interface_name] = (gateway_guess, address)
_default = gateways['default']
assert isinstance(_default, dict), TypeError(f"expected dict from netifaces, got a list")
default: typing.Dict[int, typing.Tuple[str, str]] = _default
result['default'] = result[default[netifaces.AF_INET][1]]
return result
def get_gateway_and_lan_addresses(interface_name: str) -> typing.Tuple[str, str]:
for iface_name, (gateway, lan) in get_interfaces().items():
if interface_name == iface_name:
return gateway, lan
return '', ''

View file

@ -44,10 +44,11 @@ ST
characters in the domain name must be replaced with hyphens in accordance with RFC 2141. characters in the domain name must be replaced with hyphens in accordance with RFC 2141.
""" """
import typing
from collections import OrderedDict from collections import OrderedDict
from aioupnp.constants import SSDP_DISCOVER, SSDP_HOST from aioupnp.constants import SSDP_DISCOVER, SSDP_HOST
SEARCH_TARGETS = [ SEARCH_TARGETS: typing.List[str] = [
'upnp:rootdevice', 'upnp:rootdevice',
'urn:schemas-upnp-org:device:InternetGatewayDevice:1', 'urn:schemas-upnp-org:device:InternetGatewayDevice:1',
'urn:schemas-wifialliance-org:device:WFADevice:1', 'urn:schemas-wifialliance-org:device:WFADevice:1',
@ -58,7 +59,8 @@ SEARCH_TARGETS = [
] ]
def format_packet_args(order: list, **kwargs): def format_packet_args(order: typing.List[str],
kwargs: typing.Dict[str, typing.Union[int, str]]) -> typing.Dict[str, typing.Union[int, str]]:
args = [] args = []
for o in order: for o in order:
for k, v in kwargs.items(): for k, v in kwargs.items():
@ -68,18 +70,18 @@ def format_packet_args(order: list, **kwargs):
return OrderedDict(args) return OrderedDict(args)
def packet_generator(): def packet_generator() -> typing.Iterator[typing.Dict[str, typing.Union[int, str]]]:
for st in SEARCH_TARGETS: for st in SEARCH_TARGETS:
order = ["HOST", "MAN", "MX", "ST"] order = ["HOST", "MAN", "MX", "ST"]
yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
yield format_packet_args(order, Host=SSDP_HOST, Man='"%s"' % SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'Host': SSDP_HOST, 'Man': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': SSDP_DISCOVER, 'MX': 1, 'ST': st})
yield format_packet_args(order, Host=SSDP_HOST, Man=SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'Host': SSDP_HOST, 'Man': SSDP_DISCOVER, 'MX': 1, 'ST': st})
order = ["HOST", "MAN", "ST", "MX"] order = ["HOST", "MAN", "ST", "MX"]
yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': SSDP_DISCOVER, 'MX': 1, 'ST': st})
order = ["HOST", "ST", "MAN", "MX"] order = ["HOST", "ST", "MAN", "MX"]
yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st) yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': SSDP_DISCOVER, 'MX': 1, 'ST': st})

View file

@ -1,45 +1,70 @@
import struct import struct
import socket import socket
import typing
from asyncio.protocols import DatagramProtocol from asyncio.protocols import DatagramProtocol
from asyncio.transports import DatagramTransport from asyncio.transports import BaseTransport
from unittest import mock
def _get_sock(transport: typing.Optional[BaseTransport]) -> typing.Optional[socket.socket]:
if transport is None or not hasattr(transport, "_extra"):
return None
sock: typing.Optional[socket.socket] = transport.get_extra_info('socket', None)
assert sock is None or isinstance(sock, socket.SocketType) or isinstance(sock, mock.MagicMock)
return sock
class MulticastProtocol(DatagramProtocol): class MulticastProtocol(DatagramProtocol):
def __init__(self, multicast_address: str, bind_address: str) -> None: def __init__(self, multicast_address: str, bind_address: str) -> None:
self.multicast_address = multicast_address self.multicast_address = multicast_address
self.bind_address = bind_address self.bind_address = bind_address
self.transport: DatagramTransport self.transport: typing.Optional[BaseTransport] = None
@property @property
def sock(self) -> socket.socket: def sock(self) -> typing.Optional[socket.socket]:
s: socket.socket = self.transport.get_extra_info(name='socket') return _get_sock(self.transport)
return s
def get_ttl(self) -> int: def get_ttl(self) -> int:
return self.sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL) sock = self.sock
if not sock:
raise ValueError("not connected")
return sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
def set_ttl(self, ttl: int = 1) -> None: def set_ttl(self, ttl: int = 1) -> None:
self.sock.setsockopt( sock = self.sock
if not sock:
return None
sock.setsockopt(
socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl) socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl)
) )
return None
def join_group(self, multicast_address: str, bind_address: str) -> None: def join_group(self, multicast_address: str, bind_address: str) -> None:
self.sock.setsockopt( sock = self.sock
if not sock:
return None
sock.setsockopt(
socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address) socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
) )
return None
def leave_group(self, multicast_address: str, bind_address: str) -> None: def leave_group(self, multicast_address: str, bind_address: str) -> None:
self.sock.setsockopt( sock = self.sock
if not sock:
raise ValueError("not connected")
sock.setsockopt(
socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP,
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address) socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
) )
return None
def connection_made(self, transport) -> None: def connection_made(self, transport: BaseTransport) -> None:
self.transport = transport self.transport = transport
return None
@classmethod @classmethod
def create_multicast_socket(cls, bind_address: str): def create_multicast_socket(cls, bind_address: str) -> socket.socket:
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((bind_address, 0)) sock.bind((bind_address, 0))

View file

@ -2,7 +2,6 @@ import logging
import typing import typing
import re import re
from collections import OrderedDict from collections import OrderedDict
from xml.etree import ElementTree
import asyncio import asyncio
from asyncio.protocols import Protocol from asyncio.protocols import Protocol
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
@ -18,16 +17,22 @@ log = logging.getLogger(__name__)
HTTP_CODE_REGEX = re.compile(b"^HTTP[\/]{0,1}1\.[1|0] (\d\d\d)(.*)$") HTTP_CODE_REGEX = re.compile(b"^HTTP[\/]{0,1}1\.[1|0] (\d\d\d)(.*)$")
def parse_headers(response: bytes) -> typing.Tuple[OrderedDict, int, bytes]: def parse_http_response_code(http_response: bytes) -> typing.Tuple[bytes, bytes]:
parsed: typing.List[typing.Tuple[bytes, bytes]] = HTTP_CODE_REGEX.findall(http_response)
return parsed[0]
def parse_headers(response: bytes) -> typing.Tuple[typing.Dict[bytes, bytes], int, bytes]:
lines = response.split(b'\r\n') lines = response.split(b'\r\n')
headers = OrderedDict([ headers: typing.Dict[bytes, bytes] = OrderedDict([
(l.split(b':')[0], b':'.join(l.split(b':')[1:]).lstrip(b' ').rstrip(b' ')) (l.split(b':')[0], b':'.join(l.split(b':')[1:]).lstrip(b' ').rstrip(b' '))
for l in response.split(b'\r\n') for l in response.split(b'\r\n')
]) ])
if len(lines) != len(headers): if len(lines) != len(headers):
raise ValueError("duplicate headers") raise ValueError("duplicate headers")
http_response = tuple(headers.keys())[0] header_keys: typing.List[bytes] = list(headers.keys())
response_code, message = HTTP_CODE_REGEX.findall(http_response)[0] http_response = header_keys[0]
response_code, message = parse_http_response_code(http_response)
del headers[http_response] del headers[http_response]
return headers, int(response_code), message return headers, int(response_code), message
@ -40,37 +45,42 @@ class SCPDHTTPClientProtocol(Protocol):
and devices respond with an invalid HTTP version line and devices respond with an invalid HTTP version line
""" """
def __init__(self, message: bytes, finished: asyncio.Future, soap_method: str=None, def __init__(self, message: bytes, finished: 'asyncio.Future[typing.Tuple[bytes, int, bytes]]',
soap_service_id: str=None) -> None: soap_method: typing.Optional[str] = None, soap_service_id: typing.Optional[str] = None) -> None:
self.message = message self.message = message
self.response_buff = b"" self.response_buff = b""
self.finished = finished self.finished = finished
self.soap_method = soap_method self.soap_method = soap_method
self.soap_service_id = soap_service_id self.soap_service_id = soap_service_id
self._response_code = 0
self._response_code: int = 0 self._response_msg = b""
self._response_msg: bytes = b"" self._content_length = 0
self._content_length: int = 0
self._got_headers = False self._got_headers = False
self._headers: dict = {} self._headers: typing.Dict[bytes, bytes] = {}
self._body = b"" self._body = b""
self.transport: typing.Optional[asyncio.WriteTransport] = None
def connection_made(self, transport): def connection_made(self, transport: asyncio.BaseTransport) -> None:
transport.write(self.message) assert isinstance(transport, asyncio.WriteTransport)
self.transport = transport
self.transport.write(self.message)
return None
def data_received(self, data): def data_received(self, data: bytes) -> None:
self.response_buff += data self.response_buff += data
for i, line in enumerate(self.response_buff.split(b'\r\n')): for i, line in enumerate(self.response_buff.split(b'\r\n')):
if not line: # we hit the blank line between the headers and the body if not line: # we hit the blank line between the headers and the body
if i == (len(self.response_buff.split(b'\r\n')) - 1): if i == (len(self.response_buff.split(b'\r\n')) - 1):
return # the body is still yet to be written return None # the body is still yet to be written
if not self._got_headers: if not self._got_headers:
self._headers, self._response_code, self._response_msg = parse_headers( self._headers, self._response_code, self._response_msg = parse_headers(
b'\r\n'.join(self.response_buff.split(b'\r\n')[:i]) b'\r\n'.join(self.response_buff.split(b'\r\n')[:i])
) )
content_length = get_dict_val_case_insensitive(self._headers, b'Content-Length') content_length = get_dict_val_case_insensitive(
self._headers, b'Content-Length'
)
if content_length is None: if content_length is None:
return return None
self._content_length = int(content_length or 0) self._content_length = int(content_length or 0)
self._got_headers = True self._got_headers = True
body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:]) body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:])
@ -86,21 +96,28 @@ class SCPDHTTPClientProtocol(Protocol):
) )
) )
) )
return return None
return None
async def scpd_get(control_url: str, address: str, port: int, loop=None) -> typing.Tuple[typing.Dict, bytes, async def scpd_get(control_url: str, address: str, port: int,
typing.Optional[Exception]]: loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> typing.Tuple[typing.Dict[str, typing.Any], bytes,
loop = loop or asyncio.get_event_loop_policy().get_event_loop() typing.Optional[Exception]]:
finished: asyncio.Future = asyncio.Future() loop = loop or asyncio.get_event_loop()
packet = serialize_scpd_get(control_url, address) packet = serialize_scpd_get(control_url, address)
transport, protocol = await loop.create_connection( finished: asyncio.Future[typing.Tuple[bytes, int, bytes]] = asyncio.Future(loop=loop)
lambda : SCPDHTTPClientProtocol(packet, finished), address, port proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished)
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
proto_factory, address, port
) )
protocol = connect_tup[1]
transport = connect_tup[0]
assert isinstance(protocol, SCPDHTTPClientProtocol) assert isinstance(protocol, SCPDHTTPClientProtocol)
error = None error = None
wait_task: typing.Awaitable[typing.Tuple[bytes, int, bytes]] = asyncio.wait_for(protocol.finished, 1.0, loop=loop)
try: try:
body, response_code, response_msg = await asyncio.wait_for(finished, 1.0) body, response_code, response_msg = await wait_task
except asyncio.TimeoutError: except asyncio.TimeoutError:
error = UPnPError("get request timed out") error = UPnPError("get request timed out")
body = b'' body = b''
@ -112,24 +129,31 @@ async def scpd_get(control_url: str, address: str, port: int, loop=None) -> typi
if not error: if not error:
try: try:
return deserialize_scpd_get_response(body), body, None return deserialize_scpd_get_response(body), body, None
except ElementTree.ParseError as err: except Exception as err:
error = UPnPError(err) error = UPnPError(err)
return {}, body, error return {}, body, error
async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes, async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes,
loop=None, **kwargs) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]: loop: typing.Optional[asyncio.AbstractEventLoop] = None,
loop = loop or asyncio.get_event_loop_policy().get_event_loop() **kwargs: typing.Dict[str, typing.Any]
finished: asyncio.Future = asyncio.Future() ) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop()
finished: asyncio.Future[typing.Tuple[bytes, int, bytes]] = asyncio.Future(loop=loop)
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs) packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
transport, protocol = await loop.create_connection( proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\
lambda : SCPDHTTPClientProtocol( SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode())
packet, finished, soap_method=method, soap_service_id=service_id.decode(), connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
), address, port proto_factory, address, port
) )
protocol = connect_tup[1]
transport = connect_tup[0]
assert isinstance(protocol, SCPDHTTPClientProtocol) assert isinstance(protocol, SCPDHTTPClientProtocol)
try: try:
body, response_code, response_msg = await asyncio.wait_for(finished, 1.0) wait_task: typing.Awaitable[typing.Tuple[bytes, int, bytes]] = asyncio.wait_for(finished, 1.0, loop=loop)
body, response_code, response_msg = await wait_task
except asyncio.TimeoutError: except asyncio.TimeoutError:
return {}, b'', UPnPError("Timeout") return {}, b'', UPnPError("Timeout")
except UPnPError as err: except UPnPError as err:
@ -140,5 +164,5 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para
return ( return (
deserialize_soap_post_response(body, method, service_id.decode()), body, None deserialize_soap_post_response(body, method, service_id.decode()), body, None
) )
except (ElementTree.ParseError, UPnPError) as err: except Exception as err:
return {}, body, UPnPError(err) return {}, body, UPnPError(err)

View file

@ -3,8 +3,8 @@ import binascii
import asyncio import asyncio
import logging import logging
import typing import typing
import socket
from collections import OrderedDict from collections import OrderedDict
from asyncio.futures import Future
from asyncio.transports import DatagramTransport from asyncio.transports import DatagramTransport
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
@ -18,32 +18,48 @@ log = logging.getLogger(__name__)
class SSDPProtocol(MulticastProtocol): class SSDPProtocol(MulticastProtocol):
def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Set[str] = None, def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Optional[typing.Set[str]] = None,
unicast: bool = False) -> None: unicast: bool = False, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(multicast_address, lan_address) super().__init__(multicast_address, lan_address)
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
self.transport: typing.Optional[DatagramTransport] = None
self._unicast = unicast self._unicast = unicast
self._ignored: typing.Set[str] = ignored or set() # ignored locations self._ignored: typing.Set[str] = ignored or set() # ignored locations
self._pending_searches: typing.List[typing.Tuple[str, str, Future, asyncio.Handle]] = [] self._pending_searches: typing.List[typing.Tuple[str, str, asyncio.Future[SSDPDatagram], asyncio.Handle]] = []
self.notifications: typing.List = [] self.notifications: typing.List[SSDPDatagram] = []
self.connected = asyncio.Event(loop=self.loop)
def disconnect(self): def connection_made(self, transport) -> None:
# assert isinstance(transport, asyncio.DatagramTransport), str(type(transport))
super().connection_made(transport)
self.connected.set()
def disconnect(self) -> None:
if self.transport: if self.transport:
try:
self.leave_group(self.multicast_address, self.bind_address)
except ValueError:
pass
except Exception:
log.exception("unexpected error leaving multicast group")
self.transport.close() self.transport.close()
self.connected.clear()
while self._pending_searches: while self._pending_searches:
pending = self._pending_searches.pop()[2] pending = self._pending_searches.pop()[2]
if not pending.cancelled() and not pending.done(): if not pending.cancelled() and not pending.done():
pending.cancel() pending.cancel()
return None
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None: def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
if packet.location in self._ignored: if packet.location in self._ignored:
return return None
tmp: typing.List = [] # TODO: fix this
set_futures: typing.List = [] tmp: typing.List[typing.Tuple[str, str, asyncio.Future[SSDPDatagram], asyncio.Handle]] = []
while self._pending_searches: set_futures: typing.List[asyncio.Future[SSDPDatagram]] = []
t: tuple = self._pending_searches.pop() while len(self._pending_searches):
a, s = t[0], t[1] t: typing.Tuple[str, str, asyncio.Future[SSDPDatagram], asyncio.Handle] = self._pending_searches.pop()
if (address == a) and (s in [packet.st, "upnp:rootdevice"]): if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]):
f: Future = t[2] f: asyncio.Future[SSDPDatagram] = t[2]
if f not in set_futures: if f not in set_futures:
set_futures.append(f) set_futures.append(f)
if not f.done(): if not f.done():
@ -52,38 +68,41 @@ class SSDPProtocol(MulticastProtocol):
tmp.append(t) tmp.append(t)
while tmp: while tmp:
self._pending_searches.append(tmp.pop()) self._pending_searches.append(tmp.pop())
return None
def send_many_m_searches(self, address: str, packets: typing.List[SSDPDatagram]): def _send_m_search(self, address: str, packet: SSDPDatagram) -> None:
dest = address if self._unicast else SSDP_IP_ADDRESS dest = address if self._unicast else SSDP_IP_ADDRESS
for packet in packets: if not self.transport:
log.debug("send m search to %s: %s", dest, packet.st) raise UPnPError("SSDP transport not connected")
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT)) log.debug("send m search to %s: %s", dest, packet.st)
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
return None
async def m_search(self, address: str, timeout: float, datagrams: typing.List[OrderedDict]) -> SSDPDatagram: async def m_search(self, address: str, timeout: float,
fut: Future = Future() datagrams: typing.List[typing.Dict[str, typing.Union[str, int]]]) -> SSDPDatagram:
packets: typing.List[SSDPDatagram] = [] fut: asyncio.Future[SSDPDatagram] = asyncio.Future(loop=self.loop)
for datagram in datagrams: for datagram in datagrams:
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram) packet = SSDPDatagram("M-SEARCH", datagram)
assert packet.st is not None assert packet.st is not None
self._pending_searches.append((address, packet.st, fut)) self._pending_searches.append(
packets.append(packet) (address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet))
self.send_many_m_searches(address, packets), )
return await fut return await asyncio.wait_for(fut, timeout)
def datagram_received(self, data, addr) -> None: def datagram_received(self, data: bytes, addr: typing.Tuple[str, int]) -> None: # type: ignore
if addr[0] == self.bind_address: if addr[0] == self.bind_address:
return return None
try: try:
packet = SSDPDatagram.decode(data) packet = SSDPDatagram.decode(data)
log.debug("decoded packet from %s:%i: %s", addr[0], addr[1], packet) log.debug("decoded packet from %s:%i: %s", addr[0], addr[1], packet)
except UPnPError as err: except UPnPError as err:
log.error("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err, log.error("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err,
binascii.hexlify(data)) binascii.hexlify(data))
return return None
if packet._packet_type == packet._OK: if packet._packet_type == packet._OK:
self._callback_m_search_ok(addr[0], packet) self._callback_m_search_ok(addr[0], packet)
return return None
# elif packet._packet_type == packet._NOTIFY: # elif packet._packet_type == packet._NOTIFY:
# log.debug("%s:%i sent us a notification: %s", packet) # log.debug("%s:%i sent us a notification: %s", packet)
# if packet.nt == SSDP_ROOT_DEVICE: # if packet.nt == SSDP_ROOT_DEVICE:
@ -104,17 +123,18 @@ class SSDPProtocol(MulticastProtocol):
# return # return
async def listen_ssdp(lan_address: str, gateway_address: str, loop=None, async def listen_ssdp(lan_address: str, gateway_address: str, loop: typing.Optional[asyncio.AbstractEventLoop] = None,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[DatagramTransport, ignored: typing.Optional[typing.Set[str]] = None,
SSDPProtocol, str, str]: unicast: bool = False) -> typing.Tuple[SSDPProtocol, str, str]:
loop = loop or asyncio.get_event_loop_policy().get_event_loop() loop = loop or asyncio.get_event_loop()
try: try:
sock = SSDPProtocol.create_multicast_socket(lan_address) sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address)
listen_result: typing.Tuple = await loop.create_datagram_endpoint( listen_result: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
) )
transport: DatagramTransport = listen_result[0] transport = listen_result[0]
protocol: SSDPProtocol = listen_result[1] protocol = listen_result[1]
assert isinstance(protocol, SSDPProtocol)
except Exception as err: except Exception as err:
print(err) print(err)
raise UPnPError(err) raise UPnPError(err)
@ -125,30 +145,31 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop=None,
protocol.disconnect() protocol.disconnect()
raise UPnPError(err) raise UPnPError(err)
return transport, protocol, gateway_address, lan_address return protocol, gateway_address, lan_address
async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1, async def m_search(lan_address: str, gateway_address: str, datagram_args: typing.Dict[str, typing.Union[int, str]],
loop=None, ignored: typing.Set[str] = None, timeout: int = 1, loop: typing.Optional[asyncio.AbstractEventLoop] = None,
unicast: bool = False) -> SSDPDatagram: ignored: typing.Set[str] = None, unicast: bool = False) -> SSDPDatagram:
transport, protocol, gateway_address, lan_address = await listen_ssdp( protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast lan_address, gateway_address, loop, ignored, unicast
) )
try: try:
return await asyncio.wait_for( return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]), timeout
)
except (asyncio.TimeoutError, asyncio.CancelledError): except (asyncio.TimeoutError, asyncio.CancelledError):
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
finally: finally:
protocol.disconnect() protocol.disconnect()
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, loop=None, async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.List[OrderedDict]: loop: typing.Optional[asyncio.AbstractEventLoop] = None,
transport, protocol, gateway_address, lan_address = await listen_ssdp( ignored: typing.Set[str] = None,
unicast: bool = False) -> typing.List[typing.Dict[str, typing.Union[int, str]]]:
protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast lan_address, gateway_address, loop, ignored, unicast
) )
await protocol.connected.wait()
packet_args = list(packet_generator()) packet_args = list(packet_generator())
batch_size = 2 batch_size = 2
batch_timeout = float(timeout) / float(len(packet_args)) batch_timeout = float(timeout) / float(len(packet_args))
@ -157,7 +178,7 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
packet_args = packet_args[batch_size:] packet_args = packet_args[batch_size:]
log.debug("sending batch of %i M-SEARCH attempts", batch_size) log.debug("sending batch of %i M-SEARCH attempts", batch_size)
try: try:
await asyncio.wait_for(protocol.m_search(gateway_address, batch_timeout, args), batch_timeout) await protocol.m_search(gateway_address, batch_timeout, args)
protocol.disconnect() protocol.disconnect()
return args return args
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -166,9 +187,11 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, loop=None, async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[OrderedDict, loop: typing.Optional[asyncio.AbstractEventLoop] = None,
SSDPDatagram]: ignored: typing.Set[str] = None,
unicast: bool = False) -> typing.Tuple[typing.Dict[str,
typing.Union[int, str]], SSDPDatagram]:
# we don't know which packet the gateway replies to, so send small batches at a time # we don't know which packet the gateway replies to, so send small batches at a time
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast) args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast)
# check the args in the batch that got a reply one at a time to see which one worked # check the args in the batch that got a reply one at a time to see which one worked

View file

@ -1,8 +1,9 @@
import re import re
from typing import Dict from typing import Dict, Any, List, Tuple
from xml.etree import ElementTree from aioupnp.fault import UPnPError
from aioupnp.constants import XML_VERSION, DEVICE, ROOT from aioupnp.constants import XML_VERSION
from aioupnp.util import etree_to_dict, flatten_keys from aioupnp.serialization.xml import xml_to_dict
from aioupnp.util import flatten_keys
CONTENT_PATTERN = re.compile( CONTENT_PATTERN = re.compile(
@ -28,34 +29,38 @@ def serialize_scpd_get(path: str, address: str) -> bytes:
if not path.startswith("/"): if not path.startswith("/"):
path = "/" + path path = "/" + path
return ( return (
( f'GET {path} HTTP/1.1\r\n'
'GET %s HTTP/1.1\r\n' f'Accept-Encoding: gzip\r\n'
'Accept-Encoding: gzip\r\n' f'Host: {host}\r\n'
'Host: %s\r\n' f'Connection: Close\r\n'
'Connection: Close\r\n' f'\r\n'
'\r\n'
) % (path, host)
).encode() ).encode()
def deserialize_scpd_get_response(content: bytes) -> Dict: def deserialize_scpd_get_response(content: bytes) -> Dict[str, Any]:
if XML_VERSION.encode() in content: if XML_VERSION.encode() in content:
parsed = CONTENT_PATTERN.findall(content) parsed: List[Tuple[bytes, bytes]] = CONTENT_PATTERN.findall(content)
content = b'' if not parsed else parsed[0][0] xml_dict = xml_to_dict((b'' if not parsed else parsed[0][0]).decode())
xml_dict = etree_to_dict(ElementTree.fromstring(content.decode()))
return parse_device_dict(xml_dict) return parse_device_dict(xml_dict)
return {} return {}
def parse_device_dict(xml_dict: dict) -> Dict: def parse_device_dict(xml_dict: Dict[str, Any]) -> Dict[str, Any]:
keys = list(xml_dict.keys()) keys = list(xml_dict.keys())
found = False
for k in keys: for k in keys:
m = XML_ROOT_SANITY_PATTERN.findall(k) m: List[Tuple[str, str, str, str, str, str]] = XML_ROOT_SANITY_PATTERN.findall(k)
if len(m) == 3 and m[1][0] and m[2][5]: if len(m) == 3 and m[1][0] and m[2][5]:
schema_key = m[1][0] schema_key: str = m[1][0]
root = m[2][5] root: str = m[2][5]
xml_dict = flatten_keys(xml_dict, "{%s}" % schema_key)[root] flattened = flatten_keys(xml_dict, "{%s}" % schema_key)
if root not in flattened:
raise UPnPError("root device not found")
xml_dict = flattened[root]
found = True
break break
if not found:
raise UPnPError("device not found")
result = {} result = {}
for k, v in xml_dict.items(): for k, v in xml_dict.items():
if isinstance(xml_dict[k], dict): if isinstance(xml_dict[k], dict):
@ -65,10 +70,9 @@ def parse_device_dict(xml_dict: dict) -> Dict:
if len(parsed_k) == 2: if len(parsed_k) == 2:
inner_d[parsed_k[0]] = inner_v inner_d[parsed_k[0]] = inner_v
else: else:
assert len(parsed_k) == 3 assert len(parsed_k) == 3, f"expected len=3, got {len(parsed_k)}"
inner_d[parsed_k[1]] = inner_v inner_d[parsed_k[1]] = inner_v
result[k] = inner_d result[k] = inner_d
else: else:
result[k] = v result[k] = v
return result return result

View file

@ -1,64 +1,65 @@
import re import re
from xml.etree import ElementTree import typing
from aioupnp.util import etree_to_dict, flatten_keys from aioupnp.util import flatten_keys
from aioupnp.fault import handle_fault, UPnPError from aioupnp.fault import UPnPError
from aioupnp.constants import XML_VERSION, ENVELOPE, BODY from aioupnp.constants import XML_VERSION, ENVELOPE, BODY, FAULT, CONTROL
from aioupnp.serialization.xml import xml_to_dict
CONTENT_NO_XML_VERSION_PATTERN = re.compile( CONTENT_NO_XML_VERSION_PATTERN = re.compile(
"(\<s\:Envelope xmlns\:s=\"http\:\/\/schemas\.xmlsoap\.org\/soap\/envelope\/\"(\s*.)*\>)".encode() "(\<s\:Envelope xmlns\:s=\"http\:\/\/schemas\.xmlsoap\.org\/soap\/envelope\/\"(\s*.)*\>)".encode()
) )
def serialize_soap_post(method: str, param_names: list, service_id: bytes, gateway_address: bytes, def serialize_soap_post(method: str, param_names: typing.List[str], service_id: bytes, gateway_address: bytes,
control_url: bytes, **kwargs) -> bytes: control_url: bytes, **kwargs: typing.Dict[str, str]) -> bytes:
args = "".join("<%s>%s</%s>" % (n, kwargs.get(n), n) for n in param_names) args = "".join(f"<{n}>{kwargs.get(n)}</{n}>" for n in param_names)
soap_body = ('\r\n%s\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" ' soap_body = (f'\r\n{XML_VERSION}\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" '
's:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>' f's:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>'
'<u:%s xmlns:u="%s">%s</u:%s></s:Body></s:Envelope>' % ( f'<u:{method} xmlns:u="{service_id.decode()}">{args}</u:{method}></s:Body></s:Envelope>')
XML_VERSION, method, service_id.decode(),
args, method))
if "http://" in gateway_address.decode(): if "http://" in gateway_address.decode():
host = gateway_address.decode().split("http://")[1] host = gateway_address.decode().split("http://")[1]
else: else:
host = gateway_address.decode() host = gateway_address.decode()
return ( return (
( f'POST {control_url.decode()} HTTP/1.1\r\n' # could be just / even if it shouldn't be
'POST %s HTTP/1.1\r\n' f'Host: {host}\r\n'
'Host: %s\r\n' f'User-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n'
'User-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n' f'Content-Length: {len(soap_body)}\r\n'
'Content-Length: %i\r\n' f'Content-Type: text/xml\r\n'
'Content-Type: text/xml\r\n' f'SOAPAction: \"{service_id.decode()}#{method}\"\r\n'
'SOAPAction: \"%s#%s\"\r\n' f'Connection: Close\r\n'
'Connection: Close\r\n' f'Cache-Control: no-cache\r\n'
'Cache-Control: no-cache\r\n' f'Pragma: no-cache\r\n'
'Pragma: no-cache\r\n' f'{soap_body}'
'%s' f'\r\n'
'\r\n'
) % (
control_url.decode(), # could be just / even if it shouldn't be
host,
len(soap_body),
service_id.decode(), # maybe no quotes
method,
soap_body
)
).encode() ).encode()
def deserialize_soap_post_response(response: bytes, method: str, service_id: str) -> dict: def deserialize_soap_post_response(response: bytes, method: str,
parsed = CONTENT_NO_XML_VERSION_PATTERN.findall(response) service_id: str) -> typing.Dict[str, typing.Dict[str, str]]:
parsed: typing.List[typing.List[bytes]] = CONTENT_NO_XML_VERSION_PATTERN.findall(response)
content = b'' if not parsed else parsed[0][0] content = b'' if not parsed else parsed[0][0]
content_dict = etree_to_dict(ElementTree.fromstring(content.decode())) content_dict = xml_to_dict(content.decode())
envelope = content_dict[ENVELOPE] envelope = content_dict[ENVELOPE]
response_body = flatten_keys(envelope[BODY], "{%s}" % service_id) if not isinstance(envelope[BODY], dict):
body = handle_fault(response_body) # raises UPnPError if there is a fault # raise UPnPError('blank response')
return {} # TODO: raise
response_body: typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]] = flatten_keys(
envelope[BODY], f"{'{' + service_id + '}'}"
)
if not response_body:
# raise UPnPError('blank response')
return {} # TODO: raise
if FAULT in response_body:
fault: typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]] = flatten_keys(
response_body[FAULT], "{%s}" % CONTROL
)
raise UPnPError(fault['detail']['UPnPError']['errorDescription'])
response_key = None response_key = None
if not body: for key in response_body:
return {}
for key in body:
if method in key: if method in key:
response_key = key response_key = key
break break
if not response_key: if not response_key:
raise UPnPError("unknown response fields for %s: %s" % (method, body)) raise UPnPError(f"unknown response fields for {method}: {response_body}")
return body[response_key] return response_body[response_key]

View file

@ -3,43 +3,67 @@ import logging
import binascii import binascii
import json import json
from collections import OrderedDict from collections import OrderedDict
from typing import List from typing import List, Optional, Dict, Union, Tuple, Callable
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.constants import line_separator from aioupnp.constants import line_separator
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
_template = "(?i)^(%s):[ ]*(.*)$" _template = "(?i)^(%s):[ ]*(.*)$"
ssdp_datagram_patterns = {
'host': (re.compile("(?i)^(host):(.*)$"), str),
'st': (re.compile(_template % 'st'), str),
'man': (re.compile(_template % 'man'), str),
'mx': (re.compile(_template % 'mx'), int),
'nt': (re.compile(_template % 'nt'), str),
'nts': (re.compile(_template % 'nts'), str),
'usn': (re.compile(_template % 'usn'), str),
'location': (re.compile(_template % 'location'), str),
'cache_control': (re.compile(_template % 'cache[-|_]control'), str),
'server': (re.compile(_template % 'server'), str),
}
vendor_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$") vendor_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$")
class SSDPDatagram(object): def match_vendor(line: str) -> Optional[Tuple[str, str]]:
match: List[Tuple[str, str, str]] = vendor_pattern.findall(line)
if match:
vendor_key: str = match[-1][0].lstrip(" ").rstrip(" ")
vendor_value: str = match[-1][2].lstrip(" ").rstrip(" ")
return vendor_key, vendor_value
return None
def compile_find(pattern: str) -> Callable[[str], Optional[str]]:
p = re.compile(pattern)
def find(line: str) -> Optional[str]:
result: List[List[str]] = []
for outer in p.findall(line):
result.append([])
for inner in outer:
result[-1].append(inner)
if result:
return result[-1][-1].lstrip(" ").rstrip(" ")
return None
return find
ssdp_datagram_patterns: Dict[str, Callable[[str], Optional[str]]] = {
'host': compile_find("(?i)^(host):(.*)$"),
'st': compile_find(_template % 'st'),
'man': compile_find(_template % 'man'),
'mx': compile_find(_template % 'mx'),
'nt': compile_find(_template % 'nt'),
'nts': compile_find(_template % 'nts'),
'usn': compile_find(_template % 'usn'),
'location': compile_find(_template % 'location'),
'cache_control': compile_find(_template % 'cache[-|_]control'),
'server': compile_find(_template % 'server'),
}
class SSDPDatagram:
_M_SEARCH = "M-SEARCH" _M_SEARCH = "M-SEARCH"
_NOTIFY = "NOTIFY" _NOTIFY = "NOTIFY"
_OK = "OK" _OK = "OK"
_start_lines = { _start_lines: Dict[str, str] = {
_M_SEARCH: "M-SEARCH * HTTP/1.1", _M_SEARCH: "M-SEARCH * HTTP/1.1",
_NOTIFY: "NOTIFY * HTTP/1.1", _NOTIFY: "NOTIFY * HTTP/1.1",
_OK: "HTTP/1.1 200 OK" _OK: "HTTP/1.1 200 OK"
} }
_friendly_names = { _friendly_names: Dict[str, str] = {
_M_SEARCH: "m-search", _M_SEARCH: "m-search",
_NOTIFY: "notify", _NOTIFY: "notify",
_OK: "m-search response" _OK: "m-search response"
@ -47,9 +71,7 @@ class SSDPDatagram(object):
_vendor_field_pattern = vendor_pattern _vendor_field_pattern = vendor_pattern
_patterns = ssdp_datagram_patterns _required_fields: Dict[str, List[str]] = {
_required_fields = {
_M_SEARCH: [ _M_SEARCH: [
'host', 'host',
'man', 'man',
@ -75,137 +97,137 @@ class SSDPDatagram(object):
] ]
} }
def __init__(self, packet_type, kwargs: OrderedDict = None) -> None: def __init__(self, packet_type: str, kwargs: Optional[Dict[str, Union[str, int]]] = None) -> None:
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]: if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
raise UPnPError("unknown packet type: {}".format(packet_type)) raise UPnPError("unknown packet type: {}".format(packet_type))
self._packet_type = packet_type self._packet_type = packet_type
kwargs = kwargs or OrderedDict() kw: Dict[str, Union[str, int]] = kwargs or OrderedDict()
self._field_order: list = [ self._field_order: List[str] = [
k.lower().replace("-", "_") for k in kwargs.keys() k.lower().replace("-", "_") for k in kw.keys()
] ]
self.host = None self.host: Optional[str] = None
self.man = None self.man: Optional[str] = None
self.mx = None self.mx: Optional[Union[str, int]] = None
self.st = None self.st: Optional[str] = None
self.nt = None self.nt: Optional[str] = None
self.nts = None self.nts: Optional[str] = None
self.usn = None self.usn: Optional[str] = None
self.location = None self.location: Optional[str] = None
self.cache_control = None self.cache_control: Optional[str] = None
self.server = None self.server: Optional[str] = None
self.date = None self.date: Optional[str] = None
self.ext = None self.ext: Optional[str] = None
for k, v in kwargs.items(): for k, v in kw.items():
normalized = k.lower().replace("-", "_") normalized = k.lower().replace("-", "_")
if not normalized.startswith("_") and hasattr(self, normalized) and getattr(self, normalized) is None: if not normalized.startswith("_") and hasattr(self, normalized):
setattr(self, normalized, v) if getattr(self, normalized, None) is None:
self._case_mappings: dict = {k.lower(): k for k in kwargs.keys()} setattr(self, normalized, v)
self._case_mappings: Dict[str, str] = {k.lower(): k for k in kw.keys()}
for k in self._required_fields[self._packet_type]: for k in self._required_fields[self._packet_type]:
if getattr(self, k) is None: if getattr(self, k, None) is None:
raise UPnPError("missing required field %s" % k) raise UPnPError("missing required field %s" % k)
def get_cli_igd_kwargs(self) -> str: def get_cli_igd_kwargs(self) -> str:
fields = [] fields = []
for field in self._field_order: for field in self._field_order:
v = getattr(self, field) v = getattr(self, field, None)
if v is None:
raise UPnPError("missing required field %s" % field)
fields.append("--%s=%s" % (self._case_mappings.get(field, field), v)) fields.append("--%s=%s" % (self._case_mappings.get(field, field), v))
return " ".join(fields) return " ".join(fields)
def __repr__(self) -> str: def __repr__(self) -> str:
return self.as_json() return self.as_json()
def __getitem__(self, item): def __getitem__(self, item: str) -> Union[str, int]:
for i in self._required_fields[self._packet_type]: for i in self._required_fields[self._packet_type]:
if i.lower() == item.lower(): if i.lower() == item.lower():
return getattr(self, i) return getattr(self, i)
raise KeyError(item) raise KeyError(item)
def get_friendly_name(self) -> str:
return self._friendly_names[self._packet_type]
def encode(self, trailing_newlines: int = 2) -> str: def encode(self, trailing_newlines: int = 2) -> str:
lines = [self._start_lines[self._packet_type]] lines = [self._start_lines[self._packet_type]]
for attr_name in self._field_order: lines.extend(
if attr_name not in self._required_fields[self._packet_type]: f"{self._case_mappings.get(attr_name.lower(), attr_name.upper())}: {str(getattr(self, attr_name))}"
continue for attr_name in self._field_order if attr_name in self._required_fields[self._packet_type]
attr = getattr(self, attr_name) )
if attr is None:
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
if attr_name == 'mx':
value = str(attr)
else:
value = attr
lines.append("{}: {}".format(self._case_mappings.get(attr_name.lower(), attr_name.upper()), value))
serialized = line_separator.join(lines) serialized = line_separator.join(lines)
for _ in range(trailing_newlines): for _ in range(trailing_newlines):
serialized += line_separator serialized += line_separator
return serialized return serialized
def as_dict(self) -> OrderedDict: def as_dict(self) -> Dict[str, Union[str, int]]:
return self._lines_to_content_dict(self.encode().split(line_separator)) return self._lines_to_content_dict(self.encode().split(line_separator))
def as_json(self) -> str: def as_json(self) -> str:
return json.dumps(self.as_dict(), indent=2) return json.dumps(self.as_dict(), indent=2)
@classmethod @classmethod
def decode(cls, datagram: bytes): def decode(cls, datagram: bytes) -> 'SSDPDatagram':
packet = cls._from_string(datagram.decode()) packet = cls._from_string(datagram.decode())
if packet is None: if packet is None:
raise UPnPError( raise UPnPError(
"failed to decode datagram: {}".format(binascii.hexlify(datagram)) "failed to decode datagram: {}".format(binascii.hexlify(datagram))
) )
for attr_name in packet._required_fields[packet._packet_type]: for attr_name in packet._required_fields[packet._packet_type]:
attr = getattr(packet, attr_name) if getattr(packet, attr_name, None) is None:
if attr is None:
raise UPnPError( raise UPnPError(
"required field for {} is missing from m-search response: {}".format(packet._packet_type, attr_name) "required field for {} is missing from m-search response: {}".format(packet._packet_type, attr_name)
) )
return packet return packet
@classmethod @classmethod
def _lines_to_content_dict(cls, lines: list) -> OrderedDict: def _lines_to_content_dict(cls, lines: List[str]) -> Dict[str, Union[str, int]]:
result: OrderedDict = OrderedDict() result: Dict[str, Union[str, int]] = OrderedDict()
matched_keys: List[str] = []
for line in lines: for line in lines:
if not line: if not line:
continue continue
matched = False matched = False
for name, (pattern, field_type) in cls._patterns.items(): for name, pattern in ssdp_datagram_patterns.items():
if name not in result and pattern.findall(line): if name not in matched_keys:
match = pattern.findall(line)[-1][-1] if name.lower() == 'mx':
result[line[:len(name)]] = field_type(match.lstrip(" ").rstrip(" ")) _matched_int = pattern(line)
matched = True if _matched_int is not None:
break match_int = int(_matched_int)
result[line[:len(name)]] = match_int
matched = True
matched_keys.append(name)
break
else:
match = pattern(line)
if match is not None:
result[line[:len(name)]] = match
matched = True
matched_keys.append(name)
break
if not matched: if not matched:
if cls._vendor_field_pattern.findall(line): matched_vendor = match_vendor(line)
match = cls._vendor_field_pattern.findall(line)[-1] if matched_vendor and matched_vendor[0] not in result:
vendor_key = match[0].lstrip(" ").rstrip(" ") result[matched_vendor[0]] = matched_vendor[1]
# vendor_domain = match[1].lstrip(" ").rstrip(" ")
value = match[2].lstrip(" ").rstrip(" ")
if vendor_key not in result:
result[vendor_key] = value
return result return result
@classmethod @classmethod
def _from_string(cls, datagram: str): def _from_string(cls, datagram: str) -> Optional['SSDPDatagram']:
lines = [l for l in datagram.split(line_separator) if l] lines = [l for l in datagram.split(line_separator) if l]
if not lines: if not lines:
return return None
if lines[0] == cls._start_lines[cls._M_SEARCH]: if lines[0] == cls._start_lines[cls._M_SEARCH]:
return cls._from_request(lines[1:]) return cls._from_request(lines[1:])
if lines[0] in [cls._start_lines[cls._NOTIFY], cls._start_lines[cls._NOTIFY] + " "]: if lines[0] in [cls._start_lines[cls._NOTIFY], cls._start_lines[cls._NOTIFY] + " "]:
return cls._from_notify(lines[1:]) return cls._from_notify(lines[1:])
if lines[0] == cls._start_lines[cls._OK]: if lines[0] == cls._start_lines[cls._OK]:
return cls._from_response(lines[1:]) return cls._from_response(lines[1:])
return None
@classmethod @classmethod
def _from_response(cls, lines: List): def _from_response(cls, lines: List) -> 'SSDPDatagram':
return cls(cls._OK, cls._lines_to_content_dict(lines)) return cls(cls._OK, cls._lines_to_content_dict(lines))
@classmethod @classmethod
def _from_notify(cls, lines: List): def _from_notify(cls, lines: List) -> 'SSDPDatagram':
return cls(cls._NOTIFY, cls._lines_to_content_dict(lines)) return cls(cls._NOTIFY, cls._lines_to_content_dict(lines))
@classmethod @classmethod
def _from_request(cls, lines: List): def _from_request(cls, lines: List) -> 'SSDPDatagram':
return cls(cls._M_SEARCH, cls._lines_to_content_dict(lines)) return cls(cls._M_SEARCH, cls._lines_to_content_dict(lines))

View file

@ -0,0 +1,81 @@
import typing
from collections import OrderedDict
from defusedxml import ElementTree
str_any_dict = typing.Dict[str, typing.Any]
def parse_xml(xml_str: str) -> ElementTree:
element: ElementTree = ElementTree.fromstring(xml_str)
return element
def _element_text(element_tree: ElementTree) -> typing.Optional[str]:
# if element_tree.attrib:
# element: typing.Dict[str, str] = OrderedDict()
# for k, v in element_tree.attrib.items():
# element['@' + k] = v
# if not element_tree.text:
# return element
# element['#text'] = element_tree.text.strip()
# return element
if element_tree.text:
return element_tree.text.strip()
return None
def _get_element_children(element_tree: ElementTree) -> typing.Dict[str, typing.Union[
str, typing.List[typing.Union[typing.Dict[str, str], str]]]]:
element_children = _get_child_dicts(element_tree)
element: typing.Dict[str, typing.Any] = OrderedDict()
keys: typing.List[str] = list(element_children.keys())
for k in keys:
v: typing.Union[str, typing.List[typing.Any], typing.Dict[str, typing.Any]] = element_children[k]
if len(v) == 1 and isinstance(v, list):
l: typing.List[typing.Union[typing.Dict[str, str], str]] = v
element[k] = l[0]
else:
element[k] = v
return element
def _get_child_dicts(element: ElementTree) -> typing.Dict[str, typing.List[typing.Union[typing.Dict[str, str], str]]]:
children_dicts: typing.Dict[str, typing.List[typing.Union[typing.Dict[str, str], str]]] = OrderedDict()
children: typing.List[ElementTree] = list(element)
for child in children:
child_dict = _recursive_element_to_dict(child)
child_keys: typing.List[str] = list(child_dict.keys())
for k in child_keys:
assert k in child_dict
v: typing.Union[typing.Dict[str, str], str] = child_dict[k]
if k not in children_dicts.keys():
new_item = [v]
children_dicts[k] = new_item
else:
sublist = children_dicts[k]
assert isinstance(sublist, list)
sublist.append(v)
return children_dicts
def _recursive_element_to_dict(element_tree: ElementTree) -> typing.Dict[str, typing.Any]:
if len(element_tree):
element_result: typing.Dict[str, typing.Dict[str, typing.Union[
str, typing.List[typing.Union[str, typing.Dict[str, typing.Any]]]]]] = OrderedDict()
children_element = _get_element_children(element_tree)
if element_tree.tag is not None:
element_result[element_tree.tag] = children_element
return element_result
else:
element_text = _element_text(element_tree)
if element_text is not None:
base_element_result: typing.Dict[str, typing.Any] = OrderedDict()
if element_tree.tag is not None:
base_element_result[element_tree.tag] = element_text
return base_element_result
null_result: typing.Dict[str, str] = OrderedDict()
return null_result
def xml_to_dict(xml_str: str) -> typing.Dict[str, typing.Any]:
return _recursive_element_to_dict(parse_xml(xml_str))

View file

@ -5,12 +5,11 @@ import asyncio
import zlib import zlib
import base64 import base64
from collections import OrderedDict from collections import OrderedDict
from typing import Tuple, Dict, List, Union from typing import Tuple, Dict, List, Union, Optional, Callable
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.gateway import Gateway from aioupnp.gateway import Gateway
from aioupnp.util import get_gateway_and_lan_addresses from aioupnp.interfaces import get_gateway_and_lan_addresses
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
from aioupnp.commands import SOAPCommand
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -35,6 +34,10 @@ class UPnP:
self.gateway_address = gateway_address self.gateway_address = gateway_address
self.gateway = gateway self.gateway = gateway
@classmethod
def get_annotations(cls, command: str) -> Dict[str, type]:
return getattr(Gateway.commands, command).__annotations__
@classmethod @classmethod
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '', def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
interface_name: str = 'default') -> Tuple[str, str]: interface_name: str = 'default') -> Tuple[str, str]:
@ -59,8 +62,9 @@ class UPnP:
@classmethod @classmethod
@cli @cli
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
igd_args: OrderedDict = None, unicast: bool = True, interface_name: str = 'default', igd_args: Optional[Dict[str, Union[int, str]]] = None,
loop=None) -> Dict: unicast: bool = True, interface_name: str = 'default',
loop=None) -> Dict[str, Union[str, Dict[str, Union[int, str]]]]:
if not lan_address or not gateway_address: if not lan_address or not gateway_address:
try: try:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
@ -97,13 +101,13 @@ class UPnP:
async def get_port_mapping_by_index(self, index: int) -> Dict: async def get_port_mapping_by_index(self, index: int) -> Dict:
result = await self._get_port_mapping_by_index(index) result = await self._get_port_mapping_by_index(index)
if result: if result:
if isinstance(self.gateway.commands.GetGenericPortMappingEntry, SOAPCommand): if self.gateway.commands.is_registered('GetGenericPortMappingEntry'):
return { return {
k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result) k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result)
} }
return {} return {}
async def _get_port_mapping_by_index(self, index: int) -> Union[None, Tuple[Union[None, str], int, str, async def _get_port_mapping_by_index(self, index: int) -> Union[None, Tuple[Optional[str], int, str,
int, str, bool, str, int]]: int, str, bool, str, int]]:
try: try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index) redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
@ -134,7 +138,7 @@ class UPnP:
result = await self.gateway.commands.GetSpecificPortMappingEntry( result = await self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
) )
if result and isinstance(self.gateway.commands.GetSpecificPortMappingEntry, SOAPCommand): if result and self.gateway.commands.is_registered('GetSpecificPortMappingEntry'):
return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)} return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)}
except UPnPError: except UPnPError:
pass pass
@ -152,7 +156,8 @@ class UPnP:
) )
@cli @cli
async def get_next_mapping(self, port: int, protocol: str, description: str, internal_port: int=None) -> int: async def get_next_mapping(self, port: int, protocol: str, description: str,
internal_port: Optional[int] = None) -> int:
if protocol not in ["UDP", "TCP"]: if protocol not in ["UDP", "TCP"]:
raise UPnPError("unsupported protocol: {}".format(protocol)) raise UPnPError("unsupported protocol: {}".format(protocol))
internal_port = int(internal_port or port) internal_port = int(internal_port or port)
@ -340,8 +345,9 @@ class UPnP:
return await self.gateway.commands.GetActiveConnections() return await self.gateway.commands.GetActiveConnections()
@classmethod @classmethod
def run_cli(cls, method, igd_args: OrderedDict, lan_address: str = '', gateway_address: str = '', timeout: int = 30, def run_cli(cls, method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
interface_name: str = 'default', unicast: bool = True, kwargs: dict = None, loop=None) -> None: gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
unicast: bool = True, kwargs: Optional[Dict] = None, loop=None) -> None:
""" """
:param method: the command name :param method: the command name
:param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided :param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided
@ -356,7 +362,7 @@ class UPnP:
igd_args = igd_args igd_args = igd_args
timeout = int(timeout) timeout = int(timeout)
loop = loop or asyncio.get_event_loop_policy().get_event_loop() loop = loop or asyncio.get_event_loop_policy().get_event_loop()
fut: asyncio.Future = asyncio.Future() fut: asyncio.Future = loop.create_future()
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine async def wrapper(): # wrap the upnp setup and call of the command in a coroutine

View file

@ -1,91 +1,48 @@
import re import typing
import socket from collections import OrderedDict
from collections import defaultdict
from typing import Tuple, Dict
from xml.etree import ElementTree
import netifaces
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode()) str_any_dict = typing.Dict[str, typing.Any]
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
def etree_to_dict(t: ElementTree.Element) -> Dict: def _recursive_flatten(to_flatten: typing.Any, strip: str) -> typing.Any:
d: dict = {} if not isinstance(to_flatten, (list, dict)):
if t.attrib: return to_flatten
d[t.tag] = {} if isinstance(to_flatten, list):
children = list(t) assert isinstance(to_flatten, list)
if children: return [_recursive_flatten(i, strip) for i in to_flatten]
dd: dict = defaultdict(list) assert isinstance(to_flatten, dict)
for dc in map(etree_to_dict, children): keys: typing.List[str] = list(to_flatten.keys())
for k, v in dc.items(): copy: str_any_dict = OrderedDict()
dd[k].append(v) for k in keys:
d[t.tag] = {k: v[0] if len(v) == 1 else v for k, v in dd.items()} item: typing.Any = to_flatten[k]
if t.attrib:
d[t.tag].update(('@' + k, v) for k, v in t.attrib.items())
if t.text:
text = t.text.strip()
if children or t.attrib:
if text:
d[t.tag]['#text'] = text
else:
d[t.tag] = text
return d
def flatten_keys(d, strip):
if not isinstance(d, (list, dict)):
return d
if isinstance(d, list):
return [flatten_keys(i, strip) for i in d]
t = {}
for k, v in d.items():
if strip in k and strip != k: if strip in k and strip != k:
t[k.split(strip)[1]] = flatten_keys(v, strip) copy[k.split(strip)[1]] = _recursive_flatten(item, strip)
else: else:
t[k] = flatten_keys(v, strip) copy[k] = _recursive_flatten(item, strip)
return t return copy
def get_dict_val_case_insensitive(d, k): def flatten_keys(to_flatten: str_any_dict, strip: str) -> str_any_dict:
match = list(filter(lambda x: x.lower() == k.lower(), d.keys())) keys: typing.List[str] = list(to_flatten.keys())
if not match: copy: str_any_dict = OrderedDict()
return for k in keys:
item = to_flatten[k]
if strip in k and strip != k:
new_key: str = k.split(strip)[1]
copy[new_key] = _recursive_flatten(item, strip)
else:
copy[k] = _recursive_flatten(item, strip)
return copy
def get_dict_val_case_insensitive(source: typing.Dict[typing.AnyStr, typing.AnyStr], key: typing.AnyStr) -> typing.Optional[typing.AnyStr]:
match: typing.List[typing.AnyStr] = list(filter(lambda x: x.lower() == key.lower(), source.keys()))
if not len(match):
return None
if len(match) > 1: if len(match) > 1:
raise KeyError("overlapping keys") raise KeyError("overlapping keys")
return d[match[0]] if len(match) == 1:
matched_key: typing.AnyStr = match[0]
# import struct return source[matched_key]
# import fcntl raise KeyError("overlapping keys")
# def get_ip_address(ifname):
# SIOCGIFADDR = 0x8915
# s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# return socket.inet_ntoa(fcntl.ioctl(
# s.fileno(),
# SIOCGIFADDR,
# struct.pack(b'256s', ifname[:15].encode())
# )[20:24])
def get_interfaces():
r = {
interface_name: (router_address, netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['addr'])
for router_address, interface_name, _ in netifaces.gateways()[socket.AF_INET]
}
for interface_name in netifaces.interfaces():
if interface_name in ['lo', 'localhost'] or interface_name in r:
continue
addresses = netifaces.ifaddresses(interface_name)
if netifaces.AF_INET in addresses:
address = addresses[netifaces.AF_INET][0]['addr']
gateway_guess = ".".join(address.split(".")[:-1] + ["1"])
r[interface_name] = (gateway_guess, address)
r['default'] = r[netifaces.gateways()['default'][netifaces.AF_INET][1]]
return r
def get_gateway_and_lan_addresses(interface_name: str) -> Tuple[str, str]:
for iface_name, (gateway, lan) in get_interfaces().items():
if interface_name == iface_name:
return gateway, lan
return '', ''

View file

@ -37,7 +37,7 @@ setup(
packages=find_packages(exclude=('tests',)), packages=find_packages(exclude=('tests',)),
entry_points={'console_scripts': console_scripts}, entry_points={'console_scripts': console_scripts},
install_requires=[ install_requires=[
'netifaces', 'netifaces', 'defusedxml'
], ],
extras_require={ extras_require={
'test': ( 'test': (