mypy refactor

This commit is contained in:
Jack Robison 2019-05-21 18:16:30 -04:00
parent a404269d91
commit 4137f7cd8a
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
20 changed files with 1357 additions and 643 deletions

4
.coveragerc Normal file
View file

@ -0,0 +1,4 @@
[run]
omit =
tests/*
stubs/*

3
.gitignore vendored
View file

@ -3,6 +3,9 @@
_trial_temp/
build/
dist/
html/
index.html
mypy-html.css
.coverage
.mypy_cache/
aioupnp.spec

440
.pylintrc Normal file
View file

@ -0,0 +1,440 @@
[MASTER]
# Specify a configuration file.
#rcfile=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS,schema
# Add files or directories matching the regex patterns to the
# blacklist. The regex matches against base names, not paths.
# `\.#.*` - add emacs tmp files to the blacklist
ignore-patterns=\.#.*
# Pickle collected data for later comparisons.
persistent=yes
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=1
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code
# extension-pkg-whitelist=
# Allow optimization of some AST trees. This will activate a peephole AST
# optimizer, which will apply various small optimizations. For instance, it can
# be used to obtain the result of joining multiple strings with the addition
# operator. Joining a lot of strings can lead to a maximum recursion error in
# Pylint and this flag can prevent that. It has one side effect, the resulting
# AST will be different than the one from reality.
optimize-ast=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then re-enable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=
anomalous-backslash-in-string,
arguments-differ,
attribute-defined-outside-init,
bad-continuation,
bare-except,
broad-except,
cell-var-from-loop,
consider-iterating-dictionary,
dangerous-default-value,
duplicate-code,
fixme,
global-statement,
inherit-non-class,
invalid-name,
len-as-condition,
locally-disabled,
logging-not-lazy,
missing-docstring,
no-else-return,
no-init,
no-member,
no-self-use,
protected-access,
redefined-builtin,
redefined-outer-name,
redefined-variable-type,
relative-import,
signature-differs,
super-init-not-called,
too-few-public-methods,
too-many-arguments,
too-many-branches,
too-many-instance-attributes,
too-many-lines,
too-many-locals,
too-many-nested-blocks,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
trailing-newlines,
undefined-loop-variable,
ungrouped-imports,
unnecessary-lambda,
unused-argument,
unused-variable,
wildcard-import,
wrong-import-order,
wrong-import-position,
deprecated-lambda,
simplifiable-if-statement,
unidiomatic-typecheck,
global-at-module-level,
inconsistent-return-statements,
keyword-arg-before-vararg,
assignment-from-no-return,
useless-return,
assignment-from-none,
stop-iteration-return
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text
# Put messages in a separate file for each module / package specified on the
# command line instead of printing them on stdout. Reports (if any) will be
# written in a file name "pylint_global.[txt|html]".
files-output=no
# Tells whether to display a full report or only the messages
reports=no
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
[VARIABLES]
# Tells whether we should check for unused import in __init__ files.
init-import=no
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=_$|dummy
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging
[BASIC]
# List of builtins function names that should not be used, separated by a comma
bad-functions=map,filter,input
# Good variable names which should always be accepted, separated by a comma
# allow `d` as its used frequently for deferred callback chains
good-names=i,j,k,ex,Run,_,d
# Bad variable names which should always be refused, separated by a comma
bad-names=foo,bar,baz,toto,tutu,tata
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# Regular expression matching correct function names
function-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for function names
function-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct variable names
variable-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for variable names
variable-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct constant names
const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
# Naming hint for constant names
const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
# Regular expression matching correct attribute names
attr-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for attribute names
attr-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct argument names
argument-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for argument names
argument-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct class attribute names
class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
# Naming hint for class attribute names
class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
# Regular expression matching correct inline iteration names
inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
# Naming hint for inline iteration names
inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
# Regular expression matching correct class names
class-rgx=[A-Z_][a-zA-Z0-9]+$
# Naming hint for class names
class-name-hint=[A-Z_][a-zA-Z0-9]+$
# Regular expression matching correct module names
module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
# Naming hint for module names
module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
# Regular expression matching correct method names
method-rgx=[a-z_][a-z0-9_]{2,30}$
# Naming hint for method names
method-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
[ELIF]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
[SPELLING]
# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=120
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,dict-separator
# Maximum number of lines in a module
max-module-lines=1000
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,XXX,TODO
[SIMILARITIES]
# Minimum lines number of a similarity.
min-similarity-lines=4
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
[TYPECHECK]
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=leveldb,distutils
# Ignoring distutils because: https://github.com/PyCQA/pylint/issues/73
# List of classes names for which member attributes should not be checked
# (useful for classes with attributes dynamically set). This supports can work
# with qualified names.
# ignored-classes=
[IMPORTS]
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,TERMIOS,Bastion,rexec,miniupnpc
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
[DESIGN]
# Maximum number of arguments for function / method
max-args=10
# Argument names that match this expression will be ignored. Default to name
# with leading underscore
ignored-argument-names=_.*
# Maximum number of locals for function / method body
max-locals=15
# Maximum number of return / yield for function / method body
max-returns=6
# Maximum number of branch for function / method body
max-branches=12
# Maximum number of statements in function / method body
max-statements=50
# Maximum number of parents for a class (see R0901).
max-parents=8
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of boolean expressions in a if statement
max-bool-expr=5
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,__new__,setUp
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,_fields,_replace,_source,_make
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=Exception

View file

@ -1,8 +1,11 @@
import logging
import sys
import asyncio
import logging
import textwrap
import typing
from collections import OrderedDict
from aioupnp.upnp import UPnP
from aioupnp.commands import SOAPCommands
log = logging.getLogger("aioupnp")
handler = logging.StreamHandler()
@ -16,17 +19,18 @@ base_usage = "\n".join(textwrap.wrap(
100, subsequent_indent=' ', break_long_words=False)) + "\n"
def get_help(command):
fn = getattr(UPnP, command)
params = command + " " + " ".join(["[--%s=<%s>]" % (k, k) for k in fn.__annotations__ if k != 'return'])
def get_help(command: str) -> str:
annotations = UPnP.get_annotations(command)
params = command + " " + " ".join(["[--%s=<%s>]" % (k, str(v)) for k, v in annotations.items() if k != 'return'])
return base_usage + "\n".join(
textwrap.wrap(params, 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False)
)
def main(argv=None, loop=None):
argv = argv or sys.argv
commands = [n for n in dir(UPnP) if hasattr(getattr(UPnP, n, None), "_cli")]
def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> int:
argv = argv or list(sys.argv)
commands = list(SOAPCommands.SOAP_COMMANDS)
help_str = "\n".join(textwrap.wrap(
" | ".join(commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False
))
@ -41,14 +45,16 @@ def main(argv=None, loop=None):
"For help with a specific command:" \
" aioupnp help <command>\n" % (base_usage, help_str)
args = argv[1:]
args: typing.List[str] = [str(arg) for arg in argv[1:]]
if args[0] in ['help', '-h', '--help']:
if len(args) > 1:
if args[1] in commands:
sys.exit(get_help(args[1]))
sys.exit(print(usage))
print(get_help(args[1]))
return 0
print(usage)
return 0
defaults = {
defaults: typing.Dict[str, typing.Union[bool, str, int]] = {
'debug_logging': False,
'interface': 'default',
'gateway_address': '',
@ -57,22 +63,22 @@ def main(argv=None, loop=None):
'unicast': False
}
options = OrderedDict()
options: typing.Dict[str, typing.Union[bool, str, int]] = OrderedDict()
command = None
for arg in args:
if arg.startswith("--"):
if "=" in arg:
k, v = arg.split("=")
options[k.lstrip('--')] = v
else:
k, v = arg, True
k = k.lstrip('--')
options[k] = v
options[arg.lstrip('--')] = True
else:
command = arg
break
if not command:
print("no command given")
sys.exit(print(usage))
print(usage)
return 0
kwargs = {}
for arg in args[len(options)+1:]:
if arg.startswith("--"):
@ -81,18 +87,24 @@ def main(argv=None, loop=None):
kwargs[k] = v
else:
break
for k, v in defaults.items():
for k in defaults:
if k not in options:
options[k] = v
options[k] = defaults[k]
if options.pop('debug_logging'):
log.setLevel(logging.DEBUG)
lan_address: str = str(options.pop('lan_address'))
gateway_address: str = str(options.pop('gateway_address'))
timeout: int = int(options.pop('timeout'))
interface: str = str(options.pop('interface'))
unicast: bool = bool(options.pop('unicast'))
UPnP.run_cli(
command.replace('-', '_'), options, options.pop('lan_address'), options.pop('gateway_address'),
options.pop('timeout'), options.pop('interface'), options.pop('unicast'), kwargs, loop
command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop
)
return 0
if __name__ == "__main__":
main()
sys.exit(main())

View file

@ -1,64 +1,67 @@
import logging
import asyncio
import time
import typing
from typing import Tuple, Union, List
import functools
import logging
from typing import Tuple
from aioupnp.protocols.scpd import scpd_post
from aioupnp.device import Service
log = logging.getLogger(__name__)
none_or_str = Union[None, str]
return_type_lambas = {
Union[None, str]: lambda x: x if x is not None and str(x).lower() not in ['none', 'nil'] else None
}
def safe_type(t):
if t is typing.Tuple:
return tuple
if t is typing.List:
return list
if t is typing.Dict:
return dict
if t is typing.Set:
return set
return t
def soap_optional_str(x: typing.Optional[str]) -> typing.Optional[str]:
return x if x is not None and str(x).lower() not in ['none', 'nil'] else None
class SOAPCommand:
def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str,
param_types: dict, return_types: dict, param_order: list, return_order: list, loop=None) -> None:
self.gateway_address = gateway_address
self.service_port = service_port
self.control_url = control_url
self.service_id = service_id
self.method = method
self.param_types = param_types
self.param_order = param_order
self.return_types = return_types
self.return_order = return_order
self.loop = loop
self._requests: typing.List = []
def soap_bool(x: typing.Optional[str]) -> bool:
return False if not x or str(x).lower() in ['false', 'False'] else True
async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]:
if set(kwargs.keys()) != set(self.param_types.keys()):
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys()))
soap_kwargs = {n: safe_type(self.param_types[n])(kwargs[n]) for n in self.param_types.keys()}
def recast_single_result(t, result):
if t is bool:
return soap_bool(result)
if t is str:
return soap_optional_str(result)
return t(result)
def recast_return(return_annotation, result, result_keys: typing.List[str]):
if return_annotation is None:
return None
if len(result_keys) == 1:
assert len(result_keys) == 1
single_result = result[result_keys[0]]
return recast_single_result(return_annotation, single_result)
annotated_args: typing.List[type] = list(return_annotation.__args__)
assert len(annotated_args) == len(result_keys)
recast_results: typing.List[typing.Optional[typing.Union[str, int, bool, bytes]]] = []
for type_annotation, result_key in zip(annotated_args, result_keys):
recast_results.append(recast_single_result(type_annotation, result[result_key]))
return tuple(recast_results)
def soap_command(fn):
@functools.wraps(fn)
async def wrapper(self: 'SOAPCommands', **kwargs):
if not self.is_registered(fn.__name__):
return fn(self, **kwargs)
service = self.get_service(fn.__name__)
assert service.controlURL is not None
assert service.serviceType is not None
response, xml_bytes, err = await scpd_post(
self.control_url, self.gateway_address, self.service_port, self.method, self.param_order,
self.service_id, self.loop, **soap_kwargs
service.controlURL, self._base_address.decode(), self._port, fn.__name__, self._registered[service][fn.__name__][0],
service.serviceType.encode(), self._loop, **kwargs
)
if err is not None:
self._requests.append((soap_kwargs, xml_bytes, None, err, time.time()))
self._requests.append((fn.__name__, kwargs, xml_bytes, None, err, time.time()))
raise err
if not response:
result = None
else:
recast_result = tuple([safe_type(self.return_types[n])(response.get(n)) for n in self.return_order])
if len(recast_result) == 1:
result = recast_result[0]
else:
result = recast_result
self._requests.append((soap_kwargs, xml_bytes, result, None, time.time()))
result = recast_return(fn.__annotations__.get('return'), response, self._registered[service][fn.__name__][1])
self._requests.append((fn.__name__, kwargs, xml_bytes, result, None, time.time()))
return result
return wrapper
class SOAPCommands:
@ -72,7 +75,7 @@ class SOAPCommands:
to their expected types.
"""
SOAP_COMMANDS = [
SOAP_COMMANDS: typing.List[str] = [
'AddPortMapping',
'GetNATRSIPStatus',
'GetGenericPortMappingEntry',
@ -91,59 +94,63 @@ class SOAPCommands:
'GetTotalPacketsReceived',
'X_GetICSStatistics',
'GetDefaultConnectionService',
'NewDefaultConnectionService',
'NewEnabledForInternet',
'SetDefaultConnectionService',
'SetEnabledForInternet',
'GetEnabledForInternet',
'NewActiveConnectionIndex',
'GetMaximumActiveConnections',
'GetActiveConnections'
]
def __init__(self):
self._registered = set()
def __init__(self, loop: asyncio.AbstractEventLoop, base_address: bytes, port: int) -> None:
self._loop = loop
self._registered: typing.Dict[Service,
typing.Dict[str, typing.Tuple[typing.List[str], typing.List[str]]]] = {}
self._base_address = base_address
self._port = port
self._requests: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
typing.Optional[typing.Dict[str, typing.Any]],
typing.Optional[Exception], float]] = []
def register(self, base_ip: bytes, port: int, name: str, control_url: str,
service_type: bytes, inputs: List, outputs: List, loop=None) -> None:
if name not in self.SOAP_COMMANDS or name in self._registered:
def is_registered(self, name: str) -> bool:
if name not in self.SOAP_COMMANDS:
raise ValueError("unknown command")
for service in self._registered.values():
if name in service:
return True
return False
def get_service(self, name: str) -> Service:
if name not in self.SOAP_COMMANDS:
raise ValueError("unknown command")
for service, commands in self._registered.items():
if name in commands:
return service
raise ValueError(name)
def register(self, name: str, service: Service, inputs: typing.List[str], outputs: typing.List[str]) -> None:
# control_url: str, service_type: bytes,
if name not in self.SOAP_COMMANDS:
raise AttributeError(name)
current = getattr(self, name)
annotations = current.__annotations__
return_types = annotations.get('return', None)
if return_types:
if hasattr(return_types, '__args__'):
return_types = tuple([return_type_lambas.get(a, a) for a in return_types.__args__])
elif isinstance(return_types, type):
return_types = (return_types,)
return_types = {r: t for r, t in zip(outputs, return_types)}
param_types = {}
for param_name, param_type in annotations.items():
if param_name == "return":
continue
param_types[param_name] = param_type
command = SOAPCommand(
base_ip.decode(), port, control_url, service_type,
name, param_types, return_types, inputs, outputs, loop=loop
)
setattr(command, "__doc__", current.__doc__)
setattr(self, command.method, command)
self._registered.add(command.method)
if self.is_registered(name):
raise AttributeError(f"{name} is already a registered SOAP command")
if service not in self._registered:
self._registered[service] = {}
self._registered[service][name] = inputs, outputs
@staticmethod
async def AddPortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int,
NewInternalClient: str, NewEnabled: int, NewPortMappingDescription: str,
NewLeaseDuration: str) -> None:
@soap_command
async def AddPortMapping(self, NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int,
NewInternalClient: str, NewEnabled: int, NewPortMappingDescription: str,
NewLeaseDuration: str) -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
async def GetNATRSIPStatus() -> Tuple[bool, bool]:
@soap_command
async def GetNATRSIPStatus(self) -> Tuple[bool, bool]:
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
raise NotImplementedError()
@staticmethod
async def GetGenericPortMappingEntry(NewPortMappingIndex: int) -> Tuple[str, int, str, int, str,
@soap_command
async def GetGenericPortMappingEntry(self, NewPortMappingIndex: int) -> Tuple[str, int, str, int, str,
bool, str, int]:
"""
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
@ -151,100 +158,100 @@ class SOAPCommands:
"""
raise NotImplementedError()
@staticmethod
async def GetSpecificPortMappingEntry(NewRemoteHost: str, NewExternalPort: int,
@soap_command
async def GetSpecificPortMappingEntry(self, NewRemoteHost: str, NewExternalPort: int,
NewProtocol: str) -> Tuple[int, str, bool, str, int]:
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
raise NotImplementedError()
@staticmethod
async def SetConnectionType(NewConnectionType: str) -> None:
@soap_command
async def SetConnectionType(self, NewConnectionType: str) -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
async def GetExternalIPAddress() -> str:
@soap_command
async def GetExternalIPAddress(self) -> str:
"""Returns (NewExternalIPAddress)"""
raise NotImplementedError()
@staticmethod
async def GetConnectionTypeInfo() -> Tuple[str, str]:
@soap_command
async def GetConnectionTypeInfo(self) -> Tuple[str, str]:
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
raise NotImplementedError()
@staticmethod
async def GetStatusInfo() -> Tuple[str, str, int]:
@soap_command
async def GetStatusInfo(self) -> Tuple[str, str, int]:
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
raise NotImplementedError()
@staticmethod
async def ForceTermination() -> None:
@soap_command
async def ForceTermination(self) -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
async def DeletePortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> None:
@soap_command
async def DeletePortMapping(self, NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
async def RequestConnection() -> None:
@soap_command
async def RequestConnection(self) -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
async def GetCommonLinkProperties():
@soap_command
async def GetCommonLinkProperties(self) -> Tuple[str, int, int, str]:
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
raise NotImplementedError()
@staticmethod
async def GetTotalBytesSent():
@soap_command
async def GetTotalBytesSent(self) -> int:
"""Returns (NewTotalBytesSent)"""
raise NotImplementedError()
@staticmethod
async def GetTotalBytesReceived():
@soap_command
async def GetTotalBytesReceived(self) -> int:
"""Returns (NewTotalBytesReceived)"""
raise NotImplementedError()
@staticmethod
async def GetTotalPacketsSent():
@soap_command
async def GetTotalPacketsSent(self) -> int:
"""Returns (NewTotalPacketsSent)"""
raise NotImplementedError()
@staticmethod
def GetTotalPacketsReceived():
@soap_command
async def GetTotalPacketsReceived(self) -> int:
"""Returns (NewTotalPacketsReceived)"""
raise NotImplementedError()
@staticmethod
async def X_GetICSStatistics() -> Tuple[int, int, int, int, str, str]:
@soap_command
async def X_GetICSStatistics(self) -> Tuple[int, int, int, int, str, str]:
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
raise NotImplementedError()
@staticmethod
async def GetDefaultConnectionService():
@soap_command
async def GetDefaultConnectionService(self) -> str:
"""Returns (NewDefaultConnectionService)"""
raise NotImplementedError()
@staticmethod
async def SetDefaultConnectionService(NewDefaultConnectionService: str) -> None:
@soap_command
async def SetDefaultConnectionService(self, NewDefaultConnectionService: str) -> None:
"""Returns (None)"""
raise NotImplementedError()
@staticmethod
async def SetEnabledForInternet(NewEnabledForInternet: bool) -> None:
@soap_command
async def SetEnabledForInternet(self, NewEnabledForInternet: bool) -> None:
raise NotImplementedError()
@staticmethod
async def GetEnabledForInternet() -> bool:
@soap_command
async def GetEnabledForInternet(self) -> bool:
raise NotImplementedError()
@staticmethod
async def GetMaximumActiveConnections(NewActiveConnectionIndex: int):
@soap_command
async def GetMaximumActiveConnections(self, NewActiveConnectionIndex: int):
raise NotImplementedError()
@staticmethod
async def GetActiveConnections() -> Tuple[str, str]:
@soap_command
async def GetActiveConnections(self) -> Tuple[str, str]:
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
raise NotImplementedError()

View file

@ -1,23 +1,33 @@
from collections import OrderedDict
import typing
import logging
from typing import List
log = logging.getLogger(__name__)
class CaseInsensitive:
def __init__(self, **kwargs) -> None:
for k, v in kwargs.items():
def __init__(self, **kwargs: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List[typing.Any]]]) -> None:
keys: typing.List[str] = list(kwargs.keys())
for k in keys:
if not k.startswith("_"):
setattr(self, k, v)
assert k in kwargs
setattr(self, k, kwargs[k])
def __getattr__(self, item):
for k in self.__class__.__dict__.keys():
def __getattr__(self, item: str) -> typing.Union[str, typing.Dict[str, typing.Any], typing.List]:
keys: typing.List[str] = list(self.__class__.__dict__.keys())
for k in keys:
if k.lower() == item.lower():
return self.__dict__.get(k)
value: typing.Optional[typing.Union[str, typing.Dict[str, typing.Any],
typing.List]] = self.__dict__.get(k)
assert value is not None and isinstance(value, (str, dict, list))
return value
raise AttributeError(item)
def __setattr__(self, item, value):
for k, v in self.__class__.__dict__.items():
def __setattr__(self, item: str,
value: typing.Union[str, typing.Dict[str, typing.Any], typing.List]) -> None:
assert isinstance(value, (str, dict)), ValueError(f"got type {str(type(value))}, expected str")
keys: typing.List[str] = list(self.__class__.__dict__.keys())
for k in keys:
if k.lower() == item.lower():
self.__dict__[k] = value
return
@ -26,52 +36,57 @@ class CaseInsensitive:
return
raise AttributeError(item)
def as_dict(self) -> dict:
return {
k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v)
}
def as_dict(self) -> typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List]]:
result: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List]] = OrderedDict()
keys: typing.List[str] = list(self.__dict__.keys())
for k in keys:
if not k.startswith("_"):
result[k] = self.__getattr__(k)
return result
class Service(CaseInsensitive):
serviceType = None
serviceId = None
controlURL = None
eventSubURL = None
SCPDURL = None
serviceType: typing.Optional[str] = None
serviceId: typing.Optional[str] = None
controlURL: typing.Optional[str] = None
eventSubURL: typing.Optional[str] = None
SCPDURL: typing.Optional[str] = None
class Device(CaseInsensitive):
serviceList = None
deviceList = None
deviceType = None
friendlyName = None
manufacturer = None
manufacturerURL = None
modelDescription = None
modelName = None
modelNumber = None
modelURL = None
serialNumber = None
udn = None
upc = None
presentationURL = None
iconList = None
serviceList: typing.Optional[typing.Dict[str, typing.Union[typing.Dict[str, typing.Any], typing.List]]] = None
deviceList: typing.Optional[typing.Dict[str, typing.Union[typing.Dict[str, typing.Any], typing.List]]] = None
deviceType: typing.Optional[str] = None
friendlyName: typing.Optional[str] = None
manufacturer: typing.Optional[str] = None
manufacturerURL: typing.Optional[str] = None
modelDescription: typing.Optional[str] = None
modelName: typing.Optional[str] = None
modelNumber: typing.Optional[str] = None
modelURL: typing.Optional[str] = None
serialNumber: typing.Optional[str] = None
udn: typing.Optional[str] = None
upc: typing.Optional[str] = None
presentationURL: typing.Optional[str] = None
iconList: typing.Optional[str] = None
def __init__(self, devices: List, services: List, **kwargs) -> None:
def __init__(self, devices: typing.List['Device'], services: typing.List[Service],
**kwargs: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any], typing.List]]) -> None:
super(Device, self).__init__(**kwargs)
if self.serviceList and "service" in self.serviceList:
new_services = self.serviceList["service"]
if isinstance(new_services, dict):
new_services = [new_services]
services.extend([Service(**service) for service in new_services])
if isinstance(self.serviceList['service'], dict):
assert isinstance(self.serviceList['service'], dict)
svc_list: typing.Dict[str, typing.Any] = self.serviceList['service']
services.append(Service(**svc_list))
elif isinstance(self.serviceList['service'], list):
services.extend(Service(**svc) for svc in self.serviceList["service"])
if self.deviceList:
for kw in self.deviceList.values():
if isinstance(kw, dict):
d = Device(devices, services, **kw)
devices.append(d)
devices.append(Device(devices, services, **kw))
elif isinstance(kw, list):
for _inner_kw in kw:
d = Device(devices, services, **_inner_kw)
devices.append(d)
devices.append(Device(devices, services, **_inner_kw))
else:
log.warning("failed to parse device:\n%s", kw)

View file

@ -1,14 +1,2 @@
from aioupnp.util import flatten_keys
from aioupnp.constants import FAULT, CONTROL
class UPnPError(Exception):
pass
def handle_fault(response: dict) -> dict:
if FAULT in response:
fault = flatten_keys(response[FAULT], "{%s}" % CONTROL)
error_description = fault['detail']['UPnPError']['errorDescription']
raise UPnPError(error_description)
return response

View file

@ -1,9 +1,10 @@
import re
import logging
import socket
import typing
import asyncio
from collections import OrderedDict
from typing import Dict, List, Union, Type
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from typing import Dict, List, Union
from aioupnp.util import get_dict_val_case_insensitive
from aioupnp.constants import SPEC_VERSION, SERVICE
from aioupnp.commands import SOAPCommands
from aioupnp.device import Device, Service
@ -15,77 +16,94 @@ from aioupnp.fault import UPnPError
log = logging.getLogger(__name__)
return_type_lambas = {
Union[None, str]: lambda x: x if x is not None and str(x).lower() not in ['none', 'nil'] else None
}
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
def get_action_list(element_dict: dict) -> List: # [(<method>, [<input1>, ...], [<output1, ...]), ...]
def get_action_list(element_dict: typing.Dict[str, typing.Union[str, typing.Dict[str, str],
typing.List[typing.Dict[str, typing.Dict[str, str]]]]]
) -> typing.List[typing.Tuple[str, typing.List[str], typing.List[str]]]:
service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
result: typing.List[typing.Tuple[str, typing.List[str], typing.List[str]]] = []
if "actionList" in service_info:
action_list = service_info["actionList"]
else:
return []
return result
if not len(action_list): # it could be an empty string
return []
return result
result: list = []
if isinstance(action_list["action"], dict):
arg_dicts = action_list["action"]['argumentList']['argument']
if not isinstance(arg_dicts, list): # when there is one arg
arg_dicts = [arg_dicts]
return [[
action_list["action"]['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
]]
for action in action_list["action"]:
if not action.get('argumentList'):
result.append((action['name'], [], []))
action = action_list["action"]
if isinstance(action, dict):
arg_dicts: typing.List[typing.Dict[str, str]] = []
if not isinstance(action['argumentList']['argument'], list): # when there is one arg
arg_dicts.extend([action['argumentList']['argument']])
else:
arg_dicts = action['argumentList']['argument']
if not isinstance(arg_dicts, list): # when there is one arg
arg_dicts = [arg_dicts]
arg_dicts.extend(action['argumentList']['argument'])
result.append((action_list["action"]['name'], [i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']))
return result
assert isinstance(action, list)
for _action in action:
if not _action.get('argumentList'):
result.append((_action['name'], [], []))
else:
if not isinstance(_action['argumentList']['argument'], list): # when there is one arg
arg_dicts = [_action['argumentList']['argument']]
else:
arg_dicts = _action['argumentList']['argument']
result.append((
action['name'],
_action['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
))
return result
def parse_location(location: bytes) -> typing.Tuple[bytes, int]:
base_address_result: typing.List[bytes] = BASE_ADDRESS_REGEX.findall(location)
base_address = base_address_result[0]
port_result: typing.List[bytes] = BASE_PORT_REGEX.findall(location)
port = int(port_result[0])
return base_address, port
class Gateway:
def __init__(self, ok_packet: SSDPDatagram, m_search_args: OrderedDict, lan_address: str,
gateway_address: str) -> None:
commands: SOAPCommands
def __init__(self, ok_packet: SSDPDatagram, m_search_args: typing.Dict[str, typing.Union[int, str]],
lan_address: str, gateway_address: str,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
self._loop = loop or asyncio.get_event_loop()
self._ok_packet = ok_packet
self._m_search_args = m_search_args
self._lan_address = lan_address
self.usn = (ok_packet.usn or '').encode()
self.ext = (ok_packet.ext or '').encode()
self.server = (ok_packet.server or '').encode()
self.location = (ok_packet.location or '').encode()
self.cache_control = (ok_packet.cache_control or '').encode()
self.date = (ok_packet.date or '').encode()
self.urn = (ok_packet.st or '').encode()
self.usn: bytes = (ok_packet.usn or '').encode()
self.ext: bytes = (ok_packet.ext or '').encode()
self.server: bytes = (ok_packet.server or '').encode()
self.location: bytes = (ok_packet.location or '').encode()
self.cache_control: bytes = (ok_packet.cache_control or '').encode()
self.date: bytes = (ok_packet.date or '').encode()
self.urn: bytes = (ok_packet.st or '').encode()
self._xml_response = b""
self._service_descriptors: Dict = {}
self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0]
self.port = int(BASE_PORT_REGEX.findall(self.location)[0])
self._xml_response: bytes = b""
self._service_descriptors: Dict[str, bytes] = {}
self.base_address, self.port = parse_location(self.location)
self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0]
assert self.base_ip == gateway_address.encode()
self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1]
self.spec_version = None
self.url_base = None
self.spec_version: typing.Optional[str] = None
self.url_base: typing.Optional[str] = None
self._device: Union[None, Device] = None
self._devices: List = []
self._services: List = []
self._device: typing.Optional[Device] = None
self._devices: List[Device] = []
self._services: List[Service] = []
self._unsupported_actions: Dict = {}
self._registered_commands: Dict = {}
self.commands = SOAPCommands()
self._unsupported_actions: Dict[str, typing.List[str]] = {}
self._registered_commands: Dict[str, str] = {}
self.commands = SOAPCommands(self._loop, self.base_ip, self.port)
def gateway_descriptor(self) -> dict:
r = {
@ -102,14 +120,15 @@ class Gateway:
def manufacturer_string(self) -> str:
if not self.devices:
return "UNKNOWN GATEWAY"
device = list(self.devices.values())[0]
return "%s %s" % (device.manufacturer, device.modelName)
devices: typing.List[Device] = list(self.devices.values())
device = devices[0]
return f"{device.manufacturer} {device.modelName}"
@property
def services(self) -> Dict:
def services(self) -> Dict[str, Service]:
if not self._device:
return {}
return {service.serviceType: service for service in self._services}
return {str(service.serviceType): service for service in self._services}
@property
def devices(self) -> Dict:
@ -117,28 +136,29 @@ class Gateway:
return {}
return {device.udn: device for device in self._devices}
def get_service(self, service_type: str) -> Union[Type[Service], None]:
def get_service(self, service_type: str) -> typing.Optional[Service]:
for service in self._services:
if service.serviceType.lower() == service_type.lower():
if service.serviceType and service.serviceType.lower() == service_type.lower():
return service
return None
@property
def soap_requests(self) -> List:
soap_call_infos = []
for name in self._registered_commands.keys():
if not hasattr(getattr(self.commands, name), "_requests"):
continue
soap_call_infos.extend([
(name, request_args, raw_response, decoded_response, soap_error, ts)
for (
request_args, raw_response, decoded_response, soap_error, ts
) in getattr(self.commands, name)._requests
])
def soap_requests(self) -> typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
typing.Optional[typing.Dict[str, typing.Any]],
typing.Optional[Exception], float]]:
soap_call_infos: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
typing.Optional[typing.Dict[str, typing.Any]],
typing.Optional[Exception], float]] = []
soap_call_infos.extend([
(name, request_args, raw_response, decoded_response, soap_error, ts)
for (
name, request_args, raw_response, decoded_response, soap_error, ts
) in self.commands._requests
])
soap_call_infos.sort(key=lambda x: x[5])
return soap_call_infos
def debug_gateway(self) -> Dict:
def debug_gateway(self) -> Dict[str, Union[str, bytes, int, Dict, List]]:
return {
'manufacturer_string': self.manufacturer_string,
'gateway_address': self.base_ip,
@ -156,9 +176,11 @@ class Gateway:
@classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: OrderedDict = None, loop=None, unicast: bool = False):
ignored: set = set()
required_commands = [
igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
unicast: bool = False) -> 'Gateway':
ignored: typing.Set[str] = set()
required_commands: typing.List[str] = [
'AddPortMapping',
'DeletePortMapping',
'GetExternalIPAddress'
@ -166,20 +188,21 @@ class Gateway:
while True:
if not igd_args:
m_search_args, datagram = await fuzzy_m_search(
lan_address, gateway_address, timeout, loop, ignored, unicast
lan_address, gateway_address, timeout, loop, ignored, unicast
)
else:
m_search_args = OrderedDict(igd_args)
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, ignored, unicast)
try:
gateway = cls(datagram, m_search_args, lan_address, gateway_address)
gateway = cls(datagram, m_search_args, lan_address, gateway_address, loop=loop)
log.debug('get gateway descriptor %s', datagram.location)
await gateway.discover_commands(loop)
requirements_met = all([required in gateway._registered_commands for required in required_commands])
requirements_met = all([gateway.commands.is_registered(required) for required in required_commands])
if not requirements_met:
not_met = [
required for required in required_commands if required not in gateway._registered_commands
required for required in required_commands if not gateway.commands.is_registered(required)
]
assert datagram.location is not None
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
gateway.manufacturer_string, gateway.location, not_met)
ignored.add(datagram.location)
@ -188,13 +211,17 @@ class Gateway:
log.debug('found gateway device %s', datagram.location)
return gateway
except (asyncio.TimeoutError, UPnPError) as err:
assert datagram.location is not None
log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err))
ignored.add(datagram.location)
continue
@classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: OrderedDict = None, loop=None, unicast: bool = None):
igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
unicast: typing.Optional[bool] = None) -> 'Gateway':
loop = loop or asyncio.get_event_loop()
if unicast is not None:
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast)
@ -205,7 +232,7 @@ class Gateway:
cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop, unicast=False
)
], return_when=asyncio.tasks.FIRST_COMPLETED)
], return_when=asyncio.tasks.FIRST_COMPLETED, loop=loop)
for task in pending:
task.cancel()
@ -214,56 +241,78 @@ class Gateway:
task.exception()
except asyncio.CancelledError:
pass
results: typing.List[asyncio.Future['Gateway']] = list(done)
return results[0].result()
return list(done)[0].result()
async def discover_commands(self, loop=None):
async def discover_commands(self, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port, loop=loop)
self._xml_response = xml_bytes
if get_err is not None:
raise get_err
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
self.url_base = get_dict_val_case_insensitive(response, "urlbase")
spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
if isinstance(spec_version, bytes):
self.spec_version = spec_version.decode()
else:
self.spec_version = spec_version
url_base = get_dict_val_case_insensitive(response, "urlbase")
if isinstance(url_base, bytes):
self.url_base = url_base.decode()
else:
self.url_base = url_base
if not self.url_base:
self.url_base = self.base_address.decode()
if response:
device_dict = get_dict_val_case_insensitive(response, "device")
source_keys: typing.List[str] = list(response.keys())
matches: typing.List[str] = list(filter(lambda x: x.lower() == "device", source_keys))
match_key = matches[0]
match: dict = response[match_key]
# if not len(match):
# return None
# if len(match) > 1:
# raise KeyError("overlapping keys")
# if len(match) == 1:
# matched_key: typing.AnyStr = match[0]
# return source[matched_key]
# raise KeyError("overlapping keys")
self._device = Device(
self._devices, self._services, **device_dict
self._devices, self._services, **match
)
else:
self._device = Device(self._devices, self._services)
for service_type in self.services.keys():
await self.register_commands(self.services[service_type], loop)
return None
async def register_commands(self, service: Service, loop=None):
async def register_commands(self, service: Service,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
if not service.SCPDURL:
raise UPnPError("no scpd url")
if not service.serviceType:
raise UPnPError("no service type")
log.debug("get descriptor for %s from %s", service.serviceType, service.SCPDURL)
service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port)
service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port, loop=loop)
self._service_descriptors[service.SCPDURL] = xml_bytes
if get_err is not None:
log.debug("failed to get descriptor for %s from %s", service.serviceType, service.SCPDURL)
if xml_bytes:
log.debug("response: %s", xml_bytes.decode())
return
return None
if not service_dict:
return
return None
action_list = get_action_list(service_dict)
for name, inputs, outputs in action_list:
try:
self.commands.register(self.base_ip, self.port, name, service.controlURL, service.serviceType.encode(),
inputs, outputs, loop)
self.commands.register(name, service, inputs, outputs)
self._registered_commands[name] = service.serviceType
log.debug("registered %s::%s", service.serviceType, name)
except AttributeError:
s = self._unsupported_actions.get(service.serviceType, [])
s.append(name)
self._unsupported_actions[service.serviceType] = s
self._unsupported_actions.setdefault(service.serviceType, [])
self._unsupported_actions[service.serviceType].append(name)
log.debug("available command for %s does not have a wrapper implemented: %s %s %s",
service.serviceType, name, inputs, outputs)
log.debug("registered service %s", service.serviceType)
return None

51
aioupnp/interfaces.py Normal file
View file

@ -0,0 +1,51 @@
import socket
from collections import OrderedDict
import typing
import netifaces
def get_netifaces():
return netifaces
def ifaddresses(iface: str):
return get_netifaces().ifaddresses(iface)
def _get_interfaces():
return get_netifaces().interfaces()
def _get_gateways():
return get_netifaces().gateways()
def get_interfaces() -> typing.Dict[str, typing.Tuple[str, str]]:
gateways = _get_gateways()
infos = gateways[socket.AF_INET]
assert isinstance(infos, list), TypeError(f"expected list from netifaces, got a dict")
interface_infos: typing.List[typing.Tuple[str, str, bool]] = infos
result: typing.Dict[str, typing.Tuple[str, str]] = OrderedDict(
(interface_name, (router_address, ifaddresses(interface_name)[netifaces.AF_INET][0]['addr']))
for router_address, interface_name, _ in interface_infos
)
for interface_name in _get_interfaces():
if interface_name in ['lo', 'localhost'] or interface_name in result:
continue
addresses = ifaddresses(interface_name)
if netifaces.AF_INET in addresses:
address = addresses[netifaces.AF_INET][0]['addr']
gateway_guess = ".".join(address.split(".")[:-1] + ["1"])
result[interface_name] = (gateway_guess, address)
_default = gateways['default']
assert isinstance(_default, dict), TypeError(f"expected dict from netifaces, got a list")
default: typing.Dict[int, typing.Tuple[str, str]] = _default
result['default'] = result[default[netifaces.AF_INET][1]]
return result
def get_gateway_and_lan_addresses(interface_name: str) -> typing.Tuple[str, str]:
for iface_name, (gateway, lan) in get_interfaces().items():
if interface_name == iface_name:
return gateway, lan
return '', ''

View file

@ -44,10 +44,11 @@ ST
characters in the domain name must be replaced with hyphens in accordance with RFC 2141.
"""
import typing
from collections import OrderedDict
from aioupnp.constants import SSDP_DISCOVER, SSDP_HOST
SEARCH_TARGETS = [
SEARCH_TARGETS: typing.List[str] = [
'upnp:rootdevice',
'urn:schemas-upnp-org:device:InternetGatewayDevice:1',
'urn:schemas-wifialliance-org:device:WFADevice:1',
@ -58,7 +59,8 @@ SEARCH_TARGETS = [
]
def format_packet_args(order: list, **kwargs):
def format_packet_args(order: typing.List[str],
kwargs: typing.Dict[str, typing.Union[int, str]]) -> typing.Dict[str, typing.Union[int, str]]:
args = []
for o in order:
for k, v in kwargs.items():
@ -68,18 +70,18 @@ def format_packet_args(order: list, **kwargs):
return OrderedDict(args)
def packet_generator():
def packet_generator() -> typing.Iterator[typing.Dict[str, typing.Union[int, str]]]:
for st in SEARCH_TARGETS:
order = ["HOST", "MAN", "MX", "ST"]
yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st)
yield format_packet_args(order, Host=SSDP_HOST, Man='"%s"' % SSDP_DISCOVER, MX=1, ST=st)
yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st)
yield format_packet_args(order, Host=SSDP_HOST, Man=SSDP_DISCOVER, MX=1, ST=st)
yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
yield format_packet_args(order, {'Host': SSDP_HOST, 'Man': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': SSDP_DISCOVER, 'MX': 1, 'ST': st})
yield format_packet_args(order, {'Host': SSDP_HOST, 'Man': SSDP_DISCOVER, 'MX': 1, 'ST': st})
order = ["HOST", "MAN", "ST", "MX"]
yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st)
yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st)
yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': SSDP_DISCOVER, 'MX': 1, 'ST': st})
order = ["HOST", "ST", "MAN", "MX"]
yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st)
yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st)
yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': '"%s"' % SSDP_DISCOVER, 'MX': 1, 'ST': st})
yield format_packet_args(order, {'HOST': SSDP_HOST, 'MAN': SSDP_DISCOVER, 'MX': 1, 'ST': st})

View file

@ -1,45 +1,70 @@
import struct
import socket
import typing
from asyncio.protocols import DatagramProtocol
from asyncio.transports import DatagramTransport
from asyncio.transports import BaseTransport
from unittest import mock
def _get_sock(transport: typing.Optional[BaseTransport]) -> typing.Optional[socket.socket]:
if transport is None or not hasattr(transport, "_extra"):
return None
sock: typing.Optional[socket.socket] = transport.get_extra_info('socket', None)
assert sock is None or isinstance(sock, socket.SocketType) or isinstance(sock, mock.MagicMock)
return sock
class MulticastProtocol(DatagramProtocol):
def __init__(self, multicast_address: str, bind_address: str) -> None:
self.multicast_address = multicast_address
self.bind_address = bind_address
self.transport: DatagramTransport
self.transport: typing.Optional[BaseTransport] = None
@property
def sock(self) -> socket.socket:
s: socket.socket = self.transport.get_extra_info(name='socket')
return s
def sock(self) -> typing.Optional[socket.socket]:
return _get_sock(self.transport)
def get_ttl(self) -> int:
return self.sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
sock = self.sock
if not sock:
raise ValueError("not connected")
return sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
def set_ttl(self, ttl: int = 1) -> None:
self.sock.setsockopt(
sock = self.sock
if not sock:
return None
sock.setsockopt(
socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl)
)
return None
def join_group(self, multicast_address: str, bind_address: str) -> None:
self.sock.setsockopt(
sock = self.sock
if not sock:
return None
sock.setsockopt(
socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
)
return None
def leave_group(self, multicast_address: str, bind_address: str) -> None:
self.sock.setsockopt(
sock = self.sock
if not sock:
raise ValueError("not connected")
sock.setsockopt(
socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP,
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
)
return None
def connection_made(self, transport) -> None:
def connection_made(self, transport: BaseTransport) -> None:
self.transport = transport
return None
@classmethod
def create_multicast_socket(cls, bind_address: str):
def create_multicast_socket(cls, bind_address: str) -> socket.socket:
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((bind_address, 0))

View file

@ -2,7 +2,6 @@ import logging
import typing
import re
from collections import OrderedDict
from xml.etree import ElementTree
import asyncio
from asyncio.protocols import Protocol
from aioupnp.fault import UPnPError
@ -18,16 +17,22 @@ log = logging.getLogger(__name__)
HTTP_CODE_REGEX = re.compile(b"^HTTP[\/]{0,1}1\.[1|0] (\d\d\d)(.*)$")
def parse_headers(response: bytes) -> typing.Tuple[OrderedDict, int, bytes]:
def parse_http_response_code(http_response: bytes) -> typing.Tuple[bytes, bytes]:
parsed: typing.List[typing.Tuple[bytes, bytes]] = HTTP_CODE_REGEX.findall(http_response)
return parsed[0]
def parse_headers(response: bytes) -> typing.Tuple[typing.Dict[bytes, bytes], int, bytes]:
lines = response.split(b'\r\n')
headers = OrderedDict([
headers: typing.Dict[bytes, bytes] = OrderedDict([
(l.split(b':')[0], b':'.join(l.split(b':')[1:]).lstrip(b' ').rstrip(b' '))
for l in response.split(b'\r\n')
])
if len(lines) != len(headers):
raise ValueError("duplicate headers")
http_response = tuple(headers.keys())[0]
response_code, message = HTTP_CODE_REGEX.findall(http_response)[0]
header_keys: typing.List[bytes] = list(headers.keys())
http_response = header_keys[0]
response_code, message = parse_http_response_code(http_response)
del headers[http_response]
return headers, int(response_code), message
@ -40,37 +45,42 @@ class SCPDHTTPClientProtocol(Protocol):
and devices respond with an invalid HTTP version line
"""
def __init__(self, message: bytes, finished: asyncio.Future, soap_method: str=None,
soap_service_id: str=None) -> None:
def __init__(self, message: bytes, finished: 'asyncio.Future[typing.Tuple[bytes, int, bytes]]',
soap_method: typing.Optional[str] = None, soap_service_id: typing.Optional[str] = None) -> None:
self.message = message
self.response_buff = b""
self.finished = finished
self.soap_method = soap_method
self.soap_service_id = soap_service_id
self._response_code: int = 0
self._response_msg: bytes = b""
self._content_length: int = 0
self._response_code = 0
self._response_msg = b""
self._content_length = 0
self._got_headers = False
self._headers: dict = {}
self._headers: typing.Dict[bytes, bytes] = {}
self._body = b""
self.transport: typing.Optional[asyncio.WriteTransport] = None
def connection_made(self, transport):
transport.write(self.message)
def connection_made(self, transport: asyncio.BaseTransport) -> None:
assert isinstance(transport, asyncio.WriteTransport)
self.transport = transport
self.transport.write(self.message)
return None
def data_received(self, data):
def data_received(self, data: bytes) -> None:
self.response_buff += data
for i, line in enumerate(self.response_buff.split(b'\r\n')):
if not line: # we hit the blank line between the headers and the body
if i == (len(self.response_buff.split(b'\r\n')) - 1):
return # the body is still yet to be written
return None # the body is still yet to be written
if not self._got_headers:
self._headers, self._response_code, self._response_msg = parse_headers(
b'\r\n'.join(self.response_buff.split(b'\r\n')[:i])
)
content_length = get_dict_val_case_insensitive(self._headers, b'Content-Length')
content_length = get_dict_val_case_insensitive(
self._headers, b'Content-Length'
)
if content_length is None:
return
return None
self._content_length = int(content_length or 0)
self._got_headers = True
body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:])
@ -86,21 +96,28 @@ class SCPDHTTPClientProtocol(Protocol):
)
)
)
return
return None
return None
async def scpd_get(control_url: str, address: str, port: int, loop=None) -> typing.Tuple[typing.Dict, bytes,
typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop_policy().get_event_loop()
finished: asyncio.Future = asyncio.Future()
async def scpd_get(control_url: str, address: str, port: int,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> typing.Tuple[typing.Dict[str, typing.Any], bytes,
typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop()
packet = serialize_scpd_get(control_url, address)
transport, protocol = await loop.create_connection(
lambda : SCPDHTTPClientProtocol(packet, finished), address, port
finished: asyncio.Future[typing.Tuple[bytes, int, bytes]] = asyncio.Future(loop=loop)
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished)
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
proto_factory, address, port
)
protocol = connect_tup[1]
transport = connect_tup[0]
assert isinstance(protocol, SCPDHTTPClientProtocol)
error = None
wait_task: typing.Awaitable[typing.Tuple[bytes, int, bytes]] = asyncio.wait_for(protocol.finished, 1.0, loop=loop)
try:
body, response_code, response_msg = await asyncio.wait_for(finished, 1.0)
body, response_code, response_msg = await wait_task
except asyncio.TimeoutError:
error = UPnPError("get request timed out")
body = b''
@ -112,24 +129,31 @@ async def scpd_get(control_url: str, address: str, port: int, loop=None) -> typi
if not error:
try:
return deserialize_scpd_get_response(body), body, None
except ElementTree.ParseError as err:
except Exception as err:
error = UPnPError(err)
return {}, body, error
async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes,
loop=None, **kwargs) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop_policy().get_event_loop()
finished: asyncio.Future = asyncio.Future()
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
**kwargs: typing.Dict[str, typing.Any]
) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop()
finished: asyncio.Future[typing.Tuple[bytes, int, bytes]] = asyncio.Future(loop=loop)
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
transport, protocol = await loop.create_connection(
lambda : SCPDHTTPClientProtocol(
packet, finished, soap_method=method, soap_service_id=service_id.decode(),
), address, port
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\
SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode())
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
proto_factory, address, port
)
protocol = connect_tup[1]
transport = connect_tup[0]
assert isinstance(protocol, SCPDHTTPClientProtocol)
try:
body, response_code, response_msg = await asyncio.wait_for(finished, 1.0)
wait_task: typing.Awaitable[typing.Tuple[bytes, int, bytes]] = asyncio.wait_for(finished, 1.0, loop=loop)
body, response_code, response_msg = await wait_task
except asyncio.TimeoutError:
return {}, b'', UPnPError("Timeout")
except UPnPError as err:
@ -140,5 +164,5 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para
return (
deserialize_soap_post_response(body, method, service_id.decode()), body, None
)
except (ElementTree.ParseError, UPnPError) as err:
except Exception as err:
return {}, body, UPnPError(err)

View file

@ -3,8 +3,8 @@ import binascii
import asyncio
import logging
import typing
import socket
from collections import OrderedDict
from asyncio.futures import Future
from asyncio.transports import DatagramTransport
from aioupnp.fault import UPnPError
from aioupnp.serialization.ssdp import SSDPDatagram
@ -18,32 +18,48 @@ log = logging.getLogger(__name__)
class SSDPProtocol(MulticastProtocol):
def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Set[str] = None,
unicast: bool = False) -> None:
def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Optional[typing.Set[str]] = None,
unicast: bool = False, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(multicast_address, lan_address)
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
self.transport: typing.Optional[DatagramTransport] = None
self._unicast = unicast
self._ignored: typing.Set[str] = ignored or set() # ignored locations
self._pending_searches: typing.List[typing.Tuple[str, str, Future, asyncio.Handle]] = []
self.notifications: typing.List = []
self._pending_searches: typing.List[typing.Tuple[str, str, asyncio.Future[SSDPDatagram], asyncio.Handle]] = []
self.notifications: typing.List[SSDPDatagram] = []
self.connected = asyncio.Event(loop=self.loop)
def disconnect(self):
def connection_made(self, transport) -> None:
# assert isinstance(transport, asyncio.DatagramTransport), str(type(transport))
super().connection_made(transport)
self.connected.set()
def disconnect(self) -> None:
if self.transport:
try:
self.leave_group(self.multicast_address, self.bind_address)
except ValueError:
pass
except Exception:
log.exception("unexpected error leaving multicast group")
self.transport.close()
self.connected.clear()
while self._pending_searches:
pending = self._pending_searches.pop()[2]
if not pending.cancelled() and not pending.done():
pending.cancel()
return None
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
if packet.location in self._ignored:
return
tmp: typing.List = []
set_futures: typing.List = []
while self._pending_searches:
t: tuple = self._pending_searches.pop()
a, s = t[0], t[1]
if (address == a) and (s in [packet.st, "upnp:rootdevice"]):
f: Future = t[2]
return None
# TODO: fix this
tmp: typing.List[typing.Tuple[str, str, asyncio.Future[SSDPDatagram], asyncio.Handle]] = []
set_futures: typing.List[asyncio.Future[SSDPDatagram]] = []
while len(self._pending_searches):
t: typing.Tuple[str, str, asyncio.Future[SSDPDatagram], asyncio.Handle] = self._pending_searches.pop()
if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]):
f: asyncio.Future[SSDPDatagram] = t[2]
if f not in set_futures:
set_futures.append(f)
if not f.done():
@ -52,38 +68,41 @@ class SSDPProtocol(MulticastProtocol):
tmp.append(t)
while tmp:
self._pending_searches.append(tmp.pop())
return None
def send_many_m_searches(self, address: str, packets: typing.List[SSDPDatagram]):
def _send_m_search(self, address: str, packet: SSDPDatagram) -> None:
dest = address if self._unicast else SSDP_IP_ADDRESS
for packet in packets:
log.debug("send m search to %s: %s", dest, packet.st)
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
if not self.transport:
raise UPnPError("SSDP transport not connected")
log.debug("send m search to %s: %s", dest, packet.st)
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
return None
async def m_search(self, address: str, timeout: float, datagrams: typing.List[OrderedDict]) -> SSDPDatagram:
fut: Future = Future()
packets: typing.List[SSDPDatagram] = []
async def m_search(self, address: str, timeout: float,
datagrams: typing.List[typing.Dict[str, typing.Union[str, int]]]) -> SSDPDatagram:
fut: asyncio.Future[SSDPDatagram] = asyncio.Future(loop=self.loop)
for datagram in datagrams:
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram)
packet = SSDPDatagram("M-SEARCH", datagram)
assert packet.st is not None
self._pending_searches.append((address, packet.st, fut))
packets.append(packet)
self.send_many_m_searches(address, packets),
return await fut
self._pending_searches.append(
(address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet))
)
return await asyncio.wait_for(fut, timeout)
def datagram_received(self, data, addr) -> None:
def datagram_received(self, data: bytes, addr: typing.Tuple[str, int]) -> None: # type: ignore
if addr[0] == self.bind_address:
return
return None
try:
packet = SSDPDatagram.decode(data)
log.debug("decoded packet from %s:%i: %s", addr[0], addr[1], packet)
except UPnPError as err:
log.error("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err,
binascii.hexlify(data))
return
return None
if packet._packet_type == packet._OK:
self._callback_m_search_ok(addr[0], packet)
return
return None
# elif packet._packet_type == packet._NOTIFY:
# log.debug("%s:%i sent us a notification: %s", packet)
# if packet.nt == SSDP_ROOT_DEVICE:
@ -104,17 +123,18 @@ class SSDPProtocol(MulticastProtocol):
# return
async def listen_ssdp(lan_address: str, gateway_address: str, loop=None,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[DatagramTransport,
SSDPProtocol, str, str]:
loop = loop or asyncio.get_event_loop_policy().get_event_loop()
async def listen_ssdp(lan_address: str, gateway_address: str, loop: typing.Optional[asyncio.AbstractEventLoop] = None,
ignored: typing.Optional[typing.Set[str]] = None,
unicast: bool = False) -> typing.Tuple[SSDPProtocol, str, str]:
loop = loop or asyncio.get_event_loop()
try:
sock = SSDPProtocol.create_multicast_socket(lan_address)
listen_result: typing.Tuple = await loop.create_datagram_endpoint(
sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address)
listen_result: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
)
transport: DatagramTransport = listen_result[0]
protocol: SSDPProtocol = listen_result[1]
transport = listen_result[0]
protocol = listen_result[1]
assert isinstance(protocol, SSDPProtocol)
except Exception as err:
print(err)
raise UPnPError(err)
@ -125,30 +145,31 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop=None,
protocol.disconnect()
raise UPnPError(err)
return transport, protocol, gateway_address, lan_address
return protocol, gateway_address, lan_address
async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1,
loop=None, ignored: typing.Set[str] = None,
unicast: bool = False) -> SSDPDatagram:
transport, protocol, gateway_address, lan_address = await listen_ssdp(
async def m_search(lan_address: str, gateway_address: str, datagram_args: typing.Dict[str, typing.Union[int, str]],
timeout: int = 1, loop: typing.Optional[asyncio.AbstractEventLoop] = None,
ignored: typing.Set[str] = None, unicast: bool = False) -> SSDPDatagram:
protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast
)
try:
return await asyncio.wait_for(
protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]), timeout
)
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
except (asyncio.TimeoutError, asyncio.CancelledError):
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
finally:
protocol.disconnect()
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, loop=None,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.List[OrderedDict]:
transport, protocol, gateway_address, lan_address = await listen_ssdp(
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
ignored: typing.Set[str] = None,
unicast: bool = False) -> typing.List[typing.Dict[str, typing.Union[int, str]]]:
protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast
)
await protocol.connected.wait()
packet_args = list(packet_generator())
batch_size = 2
batch_timeout = float(timeout) / float(len(packet_args))
@ -157,7 +178,7 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
packet_args = packet_args[batch_size:]
log.debug("sending batch of %i M-SEARCH attempts", batch_size)
try:
await asyncio.wait_for(protocol.m_search(gateway_address, batch_timeout, args), batch_timeout)
await protocol.m_search(gateway_address, batch_timeout, args)
protocol.disconnect()
return args
except asyncio.TimeoutError:
@ -166,9 +187,11 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, loop=None,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[OrderedDict,
SSDPDatagram]:
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
ignored: typing.Set[str] = None,
unicast: bool = False) -> typing.Tuple[typing.Dict[str,
typing.Union[int, str]], SSDPDatagram]:
# we don't know which packet the gateway replies to, so send small batches at a time
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast)
# check the args in the batch that got a reply one at a time to see which one worked

View file

@ -1,8 +1,9 @@
import re
from typing import Dict
from xml.etree import ElementTree
from aioupnp.constants import XML_VERSION, DEVICE, ROOT
from aioupnp.util import etree_to_dict, flatten_keys
from typing import Dict, Any, List, Tuple
from aioupnp.fault import UPnPError
from aioupnp.constants import XML_VERSION
from aioupnp.serialization.xml import xml_to_dict
from aioupnp.util import flatten_keys
CONTENT_PATTERN = re.compile(
@ -28,34 +29,38 @@ def serialize_scpd_get(path: str, address: str) -> bytes:
if not path.startswith("/"):
path = "/" + path
return (
(
'GET %s HTTP/1.1\r\n'
'Accept-Encoding: gzip\r\n'
'Host: %s\r\n'
'Connection: Close\r\n'
'\r\n'
) % (path, host)
f'GET {path} HTTP/1.1\r\n'
f'Accept-Encoding: gzip\r\n'
f'Host: {host}\r\n'
f'Connection: Close\r\n'
f'\r\n'
).encode()
def deserialize_scpd_get_response(content: bytes) -> Dict:
def deserialize_scpd_get_response(content: bytes) -> Dict[str, Any]:
if XML_VERSION.encode() in content:
parsed = CONTENT_PATTERN.findall(content)
content = b'' if not parsed else parsed[0][0]
xml_dict = etree_to_dict(ElementTree.fromstring(content.decode()))
parsed: List[Tuple[bytes, bytes]] = CONTENT_PATTERN.findall(content)
xml_dict = xml_to_dict((b'' if not parsed else parsed[0][0]).decode())
return parse_device_dict(xml_dict)
return {}
def parse_device_dict(xml_dict: dict) -> Dict:
def parse_device_dict(xml_dict: Dict[str, Any]) -> Dict[str, Any]:
keys = list(xml_dict.keys())
found = False
for k in keys:
m = XML_ROOT_SANITY_PATTERN.findall(k)
m: List[Tuple[str, str, str, str, str, str]] = XML_ROOT_SANITY_PATTERN.findall(k)
if len(m) == 3 and m[1][0] and m[2][5]:
schema_key = m[1][0]
root = m[2][5]
xml_dict = flatten_keys(xml_dict, "{%s}" % schema_key)[root]
schema_key: str = m[1][0]
root: str = m[2][5]
flattened = flatten_keys(xml_dict, "{%s}" % schema_key)
if root not in flattened:
raise UPnPError("root device not found")
xml_dict = flattened[root]
found = True
break
if not found:
raise UPnPError("device not found")
result = {}
for k, v in xml_dict.items():
if isinstance(xml_dict[k], dict):
@ -65,10 +70,9 @@ def parse_device_dict(xml_dict: dict) -> Dict:
if len(parsed_k) == 2:
inner_d[parsed_k[0]] = inner_v
else:
assert len(parsed_k) == 3
assert len(parsed_k) == 3, f"expected len=3, got {len(parsed_k)}"
inner_d[parsed_k[1]] = inner_v
result[k] = inner_d
else:
result[k] = v
return result

View file

@ -1,64 +1,65 @@
import re
from xml.etree import ElementTree
from aioupnp.util import etree_to_dict, flatten_keys
from aioupnp.fault import handle_fault, UPnPError
from aioupnp.constants import XML_VERSION, ENVELOPE, BODY
import typing
from aioupnp.util import flatten_keys
from aioupnp.fault import UPnPError
from aioupnp.constants import XML_VERSION, ENVELOPE, BODY, FAULT, CONTROL
from aioupnp.serialization.xml import xml_to_dict
CONTENT_NO_XML_VERSION_PATTERN = re.compile(
"(\<s\:Envelope xmlns\:s=\"http\:\/\/schemas\.xmlsoap\.org\/soap\/envelope\/\"(\s*.)*\>)".encode()
)
def serialize_soap_post(method: str, param_names: list, service_id: bytes, gateway_address: bytes,
control_url: bytes, **kwargs) -> bytes:
args = "".join("<%s>%s</%s>" % (n, kwargs.get(n), n) for n in param_names)
soap_body = ('\r\n%s\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" '
's:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>'
'<u:%s xmlns:u="%s">%s</u:%s></s:Body></s:Envelope>' % (
XML_VERSION, method, service_id.decode(),
args, method))
def serialize_soap_post(method: str, param_names: typing.List[str], service_id: bytes, gateway_address: bytes,
control_url: bytes, **kwargs: typing.Dict[str, str]) -> bytes:
args = "".join(f"<{n}>{kwargs.get(n)}</{n}>" for n in param_names)
soap_body = (f'\r\n{XML_VERSION}\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" '
f's:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>'
f'<u:{method} xmlns:u="{service_id.decode()}">{args}</u:{method}></s:Body></s:Envelope>')
if "http://" in gateway_address.decode():
host = gateway_address.decode().split("http://")[1]
else:
host = gateway_address.decode()
return (
(
'POST %s HTTP/1.1\r\n'
'Host: %s\r\n'
'User-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n'
'Content-Length: %i\r\n'
'Content-Type: text/xml\r\n'
'SOAPAction: \"%s#%s\"\r\n'
'Connection: Close\r\n'
'Cache-Control: no-cache\r\n'
'Pragma: no-cache\r\n'
'%s'
'\r\n'
) % (
control_url.decode(), # could be just / even if it shouldn't be
host,
len(soap_body),
service_id.decode(), # maybe no quotes
method,
soap_body
)
f'POST {control_url.decode()} HTTP/1.1\r\n' # could be just / even if it shouldn't be
f'Host: {host}\r\n'
f'User-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n'
f'Content-Length: {len(soap_body)}\r\n'
f'Content-Type: text/xml\r\n'
f'SOAPAction: \"{service_id.decode()}#{method}\"\r\n'
f'Connection: Close\r\n'
f'Cache-Control: no-cache\r\n'
f'Pragma: no-cache\r\n'
f'{soap_body}'
f'\r\n'
).encode()
def deserialize_soap_post_response(response: bytes, method: str, service_id: str) -> dict:
parsed = CONTENT_NO_XML_VERSION_PATTERN.findall(response)
def deserialize_soap_post_response(response: bytes, method: str,
service_id: str) -> typing.Dict[str, typing.Dict[str, str]]:
parsed: typing.List[typing.List[bytes]] = CONTENT_NO_XML_VERSION_PATTERN.findall(response)
content = b'' if not parsed else parsed[0][0]
content_dict = etree_to_dict(ElementTree.fromstring(content.decode()))
content_dict = xml_to_dict(content.decode())
envelope = content_dict[ENVELOPE]
response_body = flatten_keys(envelope[BODY], "{%s}" % service_id)
body = handle_fault(response_body) # raises UPnPError if there is a fault
if not isinstance(envelope[BODY], dict):
# raise UPnPError('blank response')
return {} # TODO: raise
response_body: typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]] = flatten_keys(
envelope[BODY], f"{'{' + service_id + '}'}"
)
if not response_body:
# raise UPnPError('blank response')
return {} # TODO: raise
if FAULT in response_body:
fault: typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]] = flatten_keys(
response_body[FAULT], "{%s}" % CONTROL
)
raise UPnPError(fault['detail']['UPnPError']['errorDescription'])
response_key = None
if not body:
return {}
for key in body:
for key in response_body:
if method in key:
response_key = key
break
if not response_key:
raise UPnPError("unknown response fields for %s: %s" % (method, body))
return body[response_key]
raise UPnPError(f"unknown response fields for {method}: {response_body}")
return response_body[response_key]

View file

@ -3,43 +3,67 @@ import logging
import binascii
import json
from collections import OrderedDict
from typing import List
from typing import List, Optional, Dict, Union, Tuple, Callable
from aioupnp.fault import UPnPError
from aioupnp.constants import line_separator
log = logging.getLogger(__name__)
_template = "(?i)^(%s):[ ]*(.*)$"
ssdp_datagram_patterns = {
'host': (re.compile("(?i)^(host):(.*)$"), str),
'st': (re.compile(_template % 'st'), str),
'man': (re.compile(_template % 'man'), str),
'mx': (re.compile(_template % 'mx'), int),
'nt': (re.compile(_template % 'nt'), str),
'nts': (re.compile(_template % 'nts'), str),
'usn': (re.compile(_template % 'usn'), str),
'location': (re.compile(_template % 'location'), str),
'cache_control': (re.compile(_template % 'cache[-|_]control'), str),
'server': (re.compile(_template % 'server'), str),
}
vendor_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$")
class SSDPDatagram(object):
def match_vendor(line: str) -> Optional[Tuple[str, str]]:
match: List[Tuple[str, str, str]] = vendor_pattern.findall(line)
if match:
vendor_key: str = match[-1][0].lstrip(" ").rstrip(" ")
vendor_value: str = match[-1][2].lstrip(" ").rstrip(" ")
return vendor_key, vendor_value
return None
def compile_find(pattern: str) -> Callable[[str], Optional[str]]:
p = re.compile(pattern)
def find(line: str) -> Optional[str]:
result: List[List[str]] = []
for outer in p.findall(line):
result.append([])
for inner in outer:
result[-1].append(inner)
if result:
return result[-1][-1].lstrip(" ").rstrip(" ")
return None
return find
ssdp_datagram_patterns: Dict[str, Callable[[str], Optional[str]]] = {
'host': compile_find("(?i)^(host):(.*)$"),
'st': compile_find(_template % 'st'),
'man': compile_find(_template % 'man'),
'mx': compile_find(_template % 'mx'),
'nt': compile_find(_template % 'nt'),
'nts': compile_find(_template % 'nts'),
'usn': compile_find(_template % 'usn'),
'location': compile_find(_template % 'location'),
'cache_control': compile_find(_template % 'cache[-|_]control'),
'server': compile_find(_template % 'server'),
}
class SSDPDatagram:
_M_SEARCH = "M-SEARCH"
_NOTIFY = "NOTIFY"
_OK = "OK"
_start_lines = {
_start_lines: Dict[str, str] = {
_M_SEARCH: "M-SEARCH * HTTP/1.1",
_NOTIFY: "NOTIFY * HTTP/1.1",
_OK: "HTTP/1.1 200 OK"
}
_friendly_names = {
_friendly_names: Dict[str, str] = {
_M_SEARCH: "m-search",
_NOTIFY: "notify",
_OK: "m-search response"
@ -47,9 +71,7 @@ class SSDPDatagram(object):
_vendor_field_pattern = vendor_pattern
_patterns = ssdp_datagram_patterns
_required_fields = {
_required_fields: Dict[str, List[str]] = {
_M_SEARCH: [
'host',
'man',
@ -75,137 +97,137 @@ class SSDPDatagram(object):
]
}
def __init__(self, packet_type, kwargs: OrderedDict = None) -> None:
def __init__(self, packet_type: str, kwargs: Optional[Dict[str, Union[str, int]]] = None) -> None:
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
raise UPnPError("unknown packet type: {}".format(packet_type))
self._packet_type = packet_type
kwargs = kwargs or OrderedDict()
self._field_order: list = [
k.lower().replace("-", "_") for k in kwargs.keys()
kw: Dict[str, Union[str, int]] = kwargs or OrderedDict()
self._field_order: List[str] = [
k.lower().replace("-", "_") for k in kw.keys()
]
self.host = None
self.man = None
self.mx = None
self.st = None
self.nt = None
self.nts = None
self.usn = None
self.location = None
self.cache_control = None
self.server = None
self.date = None
self.ext = None
for k, v in kwargs.items():
self.host: Optional[str] = None
self.man: Optional[str] = None
self.mx: Optional[Union[str, int]] = None
self.st: Optional[str] = None
self.nt: Optional[str] = None
self.nts: Optional[str] = None
self.usn: Optional[str] = None
self.location: Optional[str] = None
self.cache_control: Optional[str] = None
self.server: Optional[str] = None
self.date: Optional[str] = None
self.ext: Optional[str] = None
for k, v in kw.items():
normalized = k.lower().replace("-", "_")
if not normalized.startswith("_") and hasattr(self, normalized) and getattr(self, normalized) is None:
setattr(self, normalized, v)
self._case_mappings: dict = {k.lower(): k for k in kwargs.keys()}
if not normalized.startswith("_") and hasattr(self, normalized):
if getattr(self, normalized, None) is None:
setattr(self, normalized, v)
self._case_mappings: Dict[str, str] = {k.lower(): k for k in kw.keys()}
for k in self._required_fields[self._packet_type]:
if getattr(self, k) is None:
if getattr(self, k, None) is None:
raise UPnPError("missing required field %s" % k)
def get_cli_igd_kwargs(self) -> str:
fields = []
for field in self._field_order:
v = getattr(self, field)
v = getattr(self, field, None)
if v is None:
raise UPnPError("missing required field %s" % field)
fields.append("--%s=%s" % (self._case_mappings.get(field, field), v))
return " ".join(fields)
def __repr__(self) -> str:
return self.as_json()
def __getitem__(self, item):
def __getitem__(self, item: str) -> Union[str, int]:
for i in self._required_fields[self._packet_type]:
if i.lower() == item.lower():
return getattr(self, i)
raise KeyError(item)
def get_friendly_name(self) -> str:
return self._friendly_names[self._packet_type]
def encode(self, trailing_newlines: int = 2) -> str:
lines = [self._start_lines[self._packet_type]]
for attr_name in self._field_order:
if attr_name not in self._required_fields[self._packet_type]:
continue
attr = getattr(self, attr_name)
if attr is None:
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
if attr_name == 'mx':
value = str(attr)
else:
value = attr
lines.append("{}: {}".format(self._case_mappings.get(attr_name.lower(), attr_name.upper()), value))
lines.extend(
f"{self._case_mappings.get(attr_name.lower(), attr_name.upper())}: {str(getattr(self, attr_name))}"
for attr_name in self._field_order if attr_name in self._required_fields[self._packet_type]
)
serialized = line_separator.join(lines)
for _ in range(trailing_newlines):
serialized += line_separator
return serialized
def as_dict(self) -> OrderedDict:
def as_dict(self) -> Dict[str, Union[str, int]]:
return self._lines_to_content_dict(self.encode().split(line_separator))
def as_json(self) -> str:
return json.dumps(self.as_dict(), indent=2)
@classmethod
def decode(cls, datagram: bytes):
def decode(cls, datagram: bytes) -> 'SSDPDatagram':
packet = cls._from_string(datagram.decode())
if packet is None:
raise UPnPError(
"failed to decode datagram: {}".format(binascii.hexlify(datagram))
)
for attr_name in packet._required_fields[packet._packet_type]:
attr = getattr(packet, attr_name)
if attr is None:
if getattr(packet, attr_name, None) is None:
raise UPnPError(
"required field for {} is missing from m-search response: {}".format(packet._packet_type, attr_name)
)
return packet
@classmethod
def _lines_to_content_dict(cls, lines: list) -> OrderedDict:
result: OrderedDict = OrderedDict()
def _lines_to_content_dict(cls, lines: List[str]) -> Dict[str, Union[str, int]]:
result: Dict[str, Union[str, int]] = OrderedDict()
matched_keys: List[str] = []
for line in lines:
if not line:
continue
matched = False
for name, (pattern, field_type) in cls._patterns.items():
if name not in result and pattern.findall(line):
match = pattern.findall(line)[-1][-1]
result[line[:len(name)]] = field_type(match.lstrip(" ").rstrip(" "))
matched = True
break
for name, pattern in ssdp_datagram_patterns.items():
if name not in matched_keys:
if name.lower() == 'mx':
_matched_int = pattern(line)
if _matched_int is not None:
match_int = int(_matched_int)
result[line[:len(name)]] = match_int
matched = True
matched_keys.append(name)
break
else:
match = pattern(line)
if match is not None:
result[line[:len(name)]] = match
matched = True
matched_keys.append(name)
break
if not matched:
if cls._vendor_field_pattern.findall(line):
match = cls._vendor_field_pattern.findall(line)[-1]
vendor_key = match[0].lstrip(" ").rstrip(" ")
# vendor_domain = match[1].lstrip(" ").rstrip(" ")
value = match[2].lstrip(" ").rstrip(" ")
if vendor_key not in result:
result[vendor_key] = value
matched_vendor = match_vendor(line)
if matched_vendor and matched_vendor[0] not in result:
result[matched_vendor[0]] = matched_vendor[1]
return result
@classmethod
def _from_string(cls, datagram: str):
def _from_string(cls, datagram: str) -> Optional['SSDPDatagram']:
lines = [l for l in datagram.split(line_separator) if l]
if not lines:
return
return None
if lines[0] == cls._start_lines[cls._M_SEARCH]:
return cls._from_request(lines[1:])
if lines[0] in [cls._start_lines[cls._NOTIFY], cls._start_lines[cls._NOTIFY] + " "]:
return cls._from_notify(lines[1:])
if lines[0] == cls._start_lines[cls._OK]:
return cls._from_response(lines[1:])
return None
@classmethod
def _from_response(cls, lines: List):
def _from_response(cls, lines: List) -> 'SSDPDatagram':
return cls(cls._OK, cls._lines_to_content_dict(lines))
@classmethod
def _from_notify(cls, lines: List):
def _from_notify(cls, lines: List) -> 'SSDPDatagram':
return cls(cls._NOTIFY, cls._lines_to_content_dict(lines))
@classmethod
def _from_request(cls, lines: List):
def _from_request(cls, lines: List) -> 'SSDPDatagram':
return cls(cls._M_SEARCH, cls._lines_to_content_dict(lines))

View file

@ -0,0 +1,81 @@
import typing
from collections import OrderedDict
from defusedxml import ElementTree
str_any_dict = typing.Dict[str, typing.Any]
def parse_xml(xml_str: str) -> ElementTree:
element: ElementTree = ElementTree.fromstring(xml_str)
return element
def _element_text(element_tree: ElementTree) -> typing.Optional[str]:
# if element_tree.attrib:
# element: typing.Dict[str, str] = OrderedDict()
# for k, v in element_tree.attrib.items():
# element['@' + k] = v
# if not element_tree.text:
# return element
# element['#text'] = element_tree.text.strip()
# return element
if element_tree.text:
return element_tree.text.strip()
return None
def _get_element_children(element_tree: ElementTree) -> typing.Dict[str, typing.Union[
str, typing.List[typing.Union[typing.Dict[str, str], str]]]]:
element_children = _get_child_dicts(element_tree)
element: typing.Dict[str, typing.Any] = OrderedDict()
keys: typing.List[str] = list(element_children.keys())
for k in keys:
v: typing.Union[str, typing.List[typing.Any], typing.Dict[str, typing.Any]] = element_children[k]
if len(v) == 1 and isinstance(v, list):
l: typing.List[typing.Union[typing.Dict[str, str], str]] = v
element[k] = l[0]
else:
element[k] = v
return element
def _get_child_dicts(element: ElementTree) -> typing.Dict[str, typing.List[typing.Union[typing.Dict[str, str], str]]]:
children_dicts: typing.Dict[str, typing.List[typing.Union[typing.Dict[str, str], str]]] = OrderedDict()
children: typing.List[ElementTree] = list(element)
for child in children:
child_dict = _recursive_element_to_dict(child)
child_keys: typing.List[str] = list(child_dict.keys())
for k in child_keys:
assert k in child_dict
v: typing.Union[typing.Dict[str, str], str] = child_dict[k]
if k not in children_dicts.keys():
new_item = [v]
children_dicts[k] = new_item
else:
sublist = children_dicts[k]
assert isinstance(sublist, list)
sublist.append(v)
return children_dicts
def _recursive_element_to_dict(element_tree: ElementTree) -> typing.Dict[str, typing.Any]:
if len(element_tree):
element_result: typing.Dict[str, typing.Dict[str, typing.Union[
str, typing.List[typing.Union[str, typing.Dict[str, typing.Any]]]]]] = OrderedDict()
children_element = _get_element_children(element_tree)
if element_tree.tag is not None:
element_result[element_tree.tag] = children_element
return element_result
else:
element_text = _element_text(element_tree)
if element_text is not None:
base_element_result: typing.Dict[str, typing.Any] = OrderedDict()
if element_tree.tag is not None:
base_element_result[element_tree.tag] = element_text
return base_element_result
null_result: typing.Dict[str, str] = OrderedDict()
return null_result
def xml_to_dict(xml_str: str) -> typing.Dict[str, typing.Any]:
return _recursive_element_to_dict(parse_xml(xml_str))

View file

@ -5,12 +5,11 @@ import asyncio
import zlib
import base64
from collections import OrderedDict
from typing import Tuple, Dict, List, Union
from typing import Tuple, Dict, List, Union, Optional, Callable
from aioupnp.fault import UPnPError
from aioupnp.gateway import Gateway
from aioupnp.util import get_gateway_and_lan_addresses
from aioupnp.interfaces import get_gateway_and_lan_addresses
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
from aioupnp.commands import SOAPCommand
from aioupnp.serialization.ssdp import SSDPDatagram
log = logging.getLogger(__name__)
@ -35,6 +34,10 @@ class UPnP:
self.gateway_address = gateway_address
self.gateway = gateway
@classmethod
def get_annotations(cls, command: str) -> Dict[str, type]:
return getattr(Gateway.commands, command).__annotations__
@classmethod
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
interface_name: str = 'default') -> Tuple[str, str]:
@ -59,8 +62,9 @@ class UPnP:
@classmethod
@cli
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
igd_args: OrderedDict = None, unicast: bool = True, interface_name: str = 'default',
loop=None) -> Dict:
igd_args: Optional[Dict[str, Union[int, str]]] = None,
unicast: bool = True, interface_name: str = 'default',
loop=None) -> Dict[str, Union[str, Dict[str, Union[int, str]]]]:
if not lan_address or not gateway_address:
try:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
@ -97,13 +101,13 @@ class UPnP:
async def get_port_mapping_by_index(self, index: int) -> Dict:
result = await self._get_port_mapping_by_index(index)
if result:
if isinstance(self.gateway.commands.GetGenericPortMappingEntry, SOAPCommand):
if self.gateway.commands.is_registered('GetGenericPortMappingEntry'):
return {
k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result)
}
return {}
async def _get_port_mapping_by_index(self, index: int) -> Union[None, Tuple[Union[None, str], int, str,
async def _get_port_mapping_by_index(self, index: int) -> Union[None, Tuple[Optional[str], int, str,
int, str, bool, str, int]]:
try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
@ -134,7 +138,7 @@ class UPnP:
result = await self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
)
if result and isinstance(self.gateway.commands.GetSpecificPortMappingEntry, SOAPCommand):
if result and self.gateway.commands.is_registered('GetSpecificPortMappingEntry'):
return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)}
except UPnPError:
pass
@ -152,7 +156,8 @@ class UPnP:
)
@cli
async def get_next_mapping(self, port: int, protocol: str, description: str, internal_port: int=None) -> int:
async def get_next_mapping(self, port: int, protocol: str, description: str,
internal_port: Optional[int] = None) -> int:
if protocol not in ["UDP", "TCP"]:
raise UPnPError("unsupported protocol: {}".format(protocol))
internal_port = int(internal_port or port)
@ -340,8 +345,9 @@ class UPnP:
return await self.gateway.commands.GetActiveConnections()
@classmethod
def run_cli(cls, method, igd_args: OrderedDict, lan_address: str = '', gateway_address: str = '', timeout: int = 30,
interface_name: str = 'default', unicast: bool = True, kwargs: dict = None, loop=None) -> None:
def run_cli(cls, method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
unicast: bool = True, kwargs: Optional[Dict] = None, loop=None) -> None:
"""
:param method: the command name
:param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided
@ -356,7 +362,7 @@ class UPnP:
igd_args = igd_args
timeout = int(timeout)
loop = loop or asyncio.get_event_loop_policy().get_event_loop()
fut: asyncio.Future = asyncio.Future()
fut: asyncio.Future = loop.create_future()
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine

View file

@ -1,91 +1,48 @@
import re
import socket
from collections import defaultdict
from typing import Tuple, Dict
from xml.etree import ElementTree
import netifaces
import typing
from collections import OrderedDict
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
str_any_dict = typing.Dict[str, typing.Any]
def etree_to_dict(t: ElementTree.Element) -> Dict:
d: dict = {}
if t.attrib:
d[t.tag] = {}
children = list(t)
if children:
dd: dict = defaultdict(list)
for dc in map(etree_to_dict, children):
for k, v in dc.items():
dd[k].append(v)
d[t.tag] = {k: v[0] if len(v) == 1 else v for k, v in dd.items()}
if t.attrib:
d[t.tag].update(('@' + k, v) for k, v in t.attrib.items())
if t.text:
text = t.text.strip()
if children or t.attrib:
if text:
d[t.tag]['#text'] = text
else:
d[t.tag] = text
return d
def flatten_keys(d, strip):
if not isinstance(d, (list, dict)):
return d
if isinstance(d, list):
return [flatten_keys(i, strip) for i in d]
t = {}
for k, v in d.items():
def _recursive_flatten(to_flatten: typing.Any, strip: str) -> typing.Any:
if not isinstance(to_flatten, (list, dict)):
return to_flatten
if isinstance(to_flatten, list):
assert isinstance(to_flatten, list)
return [_recursive_flatten(i, strip) for i in to_flatten]
assert isinstance(to_flatten, dict)
keys: typing.List[str] = list(to_flatten.keys())
copy: str_any_dict = OrderedDict()
for k in keys:
item: typing.Any = to_flatten[k]
if strip in k and strip != k:
t[k.split(strip)[1]] = flatten_keys(v, strip)
copy[k.split(strip)[1]] = _recursive_flatten(item, strip)
else:
t[k] = flatten_keys(v, strip)
return t
copy[k] = _recursive_flatten(item, strip)
return copy
def get_dict_val_case_insensitive(d, k):
match = list(filter(lambda x: x.lower() == k.lower(), d.keys()))
if not match:
return
def flatten_keys(to_flatten: str_any_dict, strip: str) -> str_any_dict:
keys: typing.List[str] = list(to_flatten.keys())
copy: str_any_dict = OrderedDict()
for k in keys:
item = to_flatten[k]
if strip in k and strip != k:
new_key: str = k.split(strip)[1]
copy[new_key] = _recursive_flatten(item, strip)
else:
copy[k] = _recursive_flatten(item, strip)
return copy
def get_dict_val_case_insensitive(source: typing.Dict[typing.AnyStr, typing.AnyStr], key: typing.AnyStr) -> typing.Optional[typing.AnyStr]:
match: typing.List[typing.AnyStr] = list(filter(lambda x: x.lower() == key.lower(), source.keys()))
if not len(match):
return None
if len(match) > 1:
raise KeyError("overlapping keys")
return d[match[0]]
# import struct
# import fcntl
# def get_ip_address(ifname):
# SIOCGIFADDR = 0x8915
# s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# return socket.inet_ntoa(fcntl.ioctl(
# s.fileno(),
# SIOCGIFADDR,
# struct.pack(b'256s', ifname[:15].encode())
# )[20:24])
def get_interfaces():
r = {
interface_name: (router_address, netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['addr'])
for router_address, interface_name, _ in netifaces.gateways()[socket.AF_INET]
}
for interface_name in netifaces.interfaces():
if interface_name in ['lo', 'localhost'] or interface_name in r:
continue
addresses = netifaces.ifaddresses(interface_name)
if netifaces.AF_INET in addresses:
address = addresses[netifaces.AF_INET][0]['addr']
gateway_guess = ".".join(address.split(".")[:-1] + ["1"])
r[interface_name] = (gateway_guess, address)
r['default'] = r[netifaces.gateways()['default'][netifaces.AF_INET][1]]
return r
def get_gateway_and_lan_addresses(interface_name: str) -> Tuple[str, str]:
for iface_name, (gateway, lan) in get_interfaces().items():
if interface_name == iface_name:
return gateway, lan
return '', ''
if len(match) == 1:
matched_key: typing.AnyStr = match[0]
return source[matched_key]
raise KeyError("overlapping keys")

View file

@ -37,7 +37,7 @@ setup(
packages=find_packages(exclude=('tests',)),
entry_points={'console_scripts': console_scripts},
install_requires=[
'netifaces',
'netifaces', 'defusedxml'
],
extras_require={
'test': (