mypy refactor, improve coverage #12
3 changed files with 125 additions and 118 deletions
|
@ -4,7 +4,7 @@ import logging
|
||||||
import textwrap
|
import textwrap
|
||||||
import typing
|
import typing
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from aioupnp.upnp import UPnP
|
from aioupnp.upnp import run_cli, UPnP
|
||||||
from aioupnp.commands import SOAPCommands
|
from aioupnp.commands import SOAPCommands
|
||||||
|
|
||||||
log = logging.getLogger("aioupnp")
|
log = logging.getLogger("aioupnp")
|
||||||
|
@ -100,7 +100,7 @@ def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
|
||||||
interface: str = str(options.pop('interface'))
|
interface: str = str(options.pop('interface'))
|
||||||
unicast: bool = bool(options.pop('unicast'))
|
unicast: bool = bool(options.pop('unicast'))
|
||||||
|
|
||||||
UPnP.run_cli(
|
run_cli(
|
||||||
command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop
|
command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop
|
||||||
)
|
)
|
||||||
return 0
|
return 0
|
||||||
|
|
237
aioupnp/upnp.py
237
aioupnp/upnp.py
|
@ -11,21 +11,17 @@ from aioupnp.gateway import Gateway
|
||||||
from aioupnp.interfaces import get_gateway_and_lan_addresses
|
from aioupnp.interfaces import get_gateway_and_lan_addresses
|
||||||
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
|
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
|
||||||
from aioupnp.serialization.ssdp import SSDPDatagram
|
from aioupnp.serialization.ssdp import SSDPDatagram
|
||||||
|
from aioupnp.commands import SOAPCommands
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def cli(fn):
|
# def _encode(x):
|
||||||
fn._cli = True
|
# if isinstance(x, bytes):
|
||||||
return fn
|
# return x.decode()
|
||||||
|
# elif isinstance(x, Exception):
|
||||||
|
# return str(x)
|
||||||
def _encode(x):
|
# return x
|
||||||
if isinstance(x, bytes):
|
|
||||||
return x.decode()
|
|
||||||
elif isinstance(x, Exception):
|
|
||||||
return str(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class UPnP:
|
class UPnP:
|
||||||
|
@ -36,7 +32,7 @@ class UPnP:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_annotations(cls, command: str) -> Dict[str, type]:
|
def get_annotations(cls, command: str) -> Dict[str, type]:
|
||||||
return getattr(Gateway.commands, command).__annotations__
|
return getattr(SOAPCommands, command).__annotations__
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
|
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
|
||||||
|
@ -61,7 +57,6 @@ class UPnP:
|
||||||
return cls(lan_address, gateway_address, gateway)
|
return cls(lan_address, gateway_address, gateway)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@cli
|
|
||||||
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
|
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
|
||||||
igd_args: Optional[Dict[str, Union[int, str]]] = None,
|
igd_args: Optional[Dict[str, Union[int, str]]] = None,
|
||||||
unicast: bool = True, interface_name: str = 'default',
|
unicast: bool = True, interface_name: str = 'default',
|
||||||
|
@ -86,92 +81,87 @@ class UPnP:
|
||||||
'discover_reply': datagram.as_dict()
|
'discover_reply': datagram.as_dict()
|
||||||
}
|
}
|
||||||
|
|
||||||
@cli
|
|
||||||
async def get_external_ip(self) -> str:
|
async def get_external_ip(self) -> str:
|
||||||
return await self.gateway.commands.GetExternalIPAddress()
|
return await self.gateway.commands.GetExternalIPAddress()
|
||||||
|
|
||||||
@cli
|
|
||||||
async def add_port_mapping(self, external_port: int, protocol: str, internal_port: int, lan_address: str,
|
async def add_port_mapping(self, external_port: int, protocol: str, internal_port: int, lan_address: str,
|
||||||
description: str) -> None:
|
description: str) -> None:
|
||||||
return await self.gateway.commands.AddPortMapping(
|
await self.gateway.commands.AddPortMapping(
|
||||||
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol,
|
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol,
|
||||||
NewInternalPort=internal_port, NewInternalClient=lan_address,
|
NewInternalPort=internal_port, NewInternalClient=lan_address,
|
||||||
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration='0'
|
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration='0'
|
||||||
)
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
@cli
|
async def get_port_mapping_by_index(self, index: int) -> Optional[Tuple[str, int, str, int, str, bool, str, int]]:
|
||||||
async def get_port_mapping_by_index(self, index: int):
|
|
||||||
return await self._get_port_mapping_by_index(index)
|
|
||||||
# if result:
|
|
||||||
# if self.gateway.commands.is_registered('GetGenericPortMappingEntry'):
|
|
||||||
# return {
|
|
||||||
# k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result)
|
|
||||||
# }
|
|
||||||
# return {}
|
|
||||||
|
|
||||||
async def _get_port_mapping_by_index(self, index: int) -> Optional[Tuple[Optional[str], int, str,
|
|
||||||
int, str, bool, str, int]]:
|
|
||||||
try:
|
try:
|
||||||
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
|
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
|
||||||
return redirect
|
return redirect
|
||||||
except UPnPError:
|
except UPnPError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@cli
|
async def get_redirects(self) -> List[Tuple[str, int, str, int, str, bool, str, int]]:
|
||||||
async def get_redirects(self) -> List[Dict]:
|
redirects: List[Tuple[str, int, str, int, str, bool, str, int]] = []
|
||||||
redirects = []
|
|
||||||
cnt = 0
|
cnt = 0
|
||||||
redirect = await self.get_port_mapping_by_index(cnt)
|
redirect: Optional[Tuple[str, int, str, int, str, bool, str, int]] = None
|
||||||
while redirect:
|
try:
|
||||||
|
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
|
||||||
|
except UPnPError:
|
||||||
|
pass
|
||||||
|
while redirect is not None:
|
||||||
redirects.append(redirect)
|
redirects.append(redirect)
|
||||||
cnt += 1
|
cnt += 1
|
||||||
redirect = await self.get_port_mapping_by_index(cnt)
|
try:
|
||||||
|
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
|
||||||
|
except UPnPError:
|
||||||
|
pass
|
||||||
return redirects
|
return redirects
|
||||||
|
|
||||||
@cli
|
async def get_specific_port_mapping(self, external_port: int, protocol: str) -> Tuple[int, str, bool, str, int]:
|
||||||
async def get_specific_port_mapping(self, external_port: int, protocol: str):
|
|
||||||
"""
|
"""
|
||||||
:param external_port: (int) external port to listen on
|
:param external_port: (int) external port to listen on
|
||||||
:param protocol: (str) 'UDP' | 'TCP'
|
:param protocol: (str) 'UDP' | 'TCP'
|
||||||
:return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time>
|
:return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# try:
|
|
||||||
return await self.gateway.commands.GetSpecificPortMappingEntry(
|
return await self.gateway.commands.GetSpecificPortMappingEntry(
|
||||||
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
|
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
|
||||||
)
|
)
|
||||||
# except UPnPError:
|
|
||||||
# pass
|
|
||||||
# return {}
|
|
||||||
|
|
||||||
@cli
|
|
||||||
async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
|
async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
|
||||||
"""
|
"""
|
||||||
:param external_port: (int) external port to listen on
|
:param external_port: (int) external port to listen on
|
||||||
:param protocol: (str) 'UDP' | 'TCP'
|
:param protocol: (str) 'UDP' | 'TCP'
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
return await self.gateway.commands.DeletePortMapping(
|
await self.gateway.commands.DeletePortMapping(
|
||||||
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
|
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
|
||||||
)
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
@cli
|
|
||||||
async def get_next_mapping(self, port: int, protocol: str, description: str,
|
async def get_next_mapping(self, port: int, protocol: str, description: str,
|
||||||
internal_port: Optional[int] = None) -> int:
|
internal_port: Optional[int] = None) -> int:
|
||||||
if protocol not in ["UDP", "TCP"]:
|
if protocol not in ["UDP", "TCP"]:
|
||||||
raise UPnPError("unsupported protocol: {}".format(protocol))
|
raise UPnPError("unsupported protocol: {}".format(protocol))
|
||||||
internal_port = int(internal_port or port)
|
_internal_port = int(internal_port or port)
|
||||||
requested_port = int(internal_port)
|
requested_port = int(_internal_port)
|
||||||
redirect_tups = []
|
redirect_tups: List[Tuple[str, int, str, int, str, bool, str, int]] = []
|
||||||
cnt = 0
|
cnt = 0
|
||||||
port = int(port)
|
port = int(port)
|
||||||
redirect = await self._get_port_mapping_by_index(cnt)
|
redirect: Optional[Tuple[str, int, str, int, str, bool, str, int]] = None
|
||||||
while redirect:
|
try:
|
||||||
|
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
|
||||||
|
except UPnPError:
|
||||||
|
pass
|
||||||
|
while redirect is not None:
|
||||||
redirect_tups.append(redirect)
|
redirect_tups.append(redirect)
|
||||||
cnt += 1
|
cnt += 1
|
||||||
redirect = await self._get_port_mapping_by_index(cnt)
|
try:
|
||||||
|
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
|
||||||
|
except UPnPError as err:
|
||||||
|
if "ArrayIndex" in str(err):
|
||||||
|
break
|
||||||
|
|
||||||
redirects = {
|
redirects: Dict[Tuple[int, str], Tuple[str, int, str]] = {
|
||||||
(ext_port, proto): (int_host, int_port, desc)
|
(ext_port, proto): (int_host, int_port, desc)
|
||||||
for (ext_host, ext_port, proto, int_port, int_host, enabled, desc, _) in redirect_tups
|
for (ext_host, ext_port, proto, int_port, int_host, enabled, desc, _) in redirect_tups
|
||||||
}
|
}
|
||||||
|
@ -181,18 +171,20 @@ class UPnP:
|
||||||
if int_host == self.lan_address and int_port == requested_port and desc == description:
|
if int_host == self.lan_address and int_port == requested_port and desc == description:
|
||||||
return port
|
return port
|
||||||
port += 1
|
port += 1
|
||||||
await self.add_port_mapping( # set one up
|
await self.gateway.commands.AddPortMapping(
|
||||||
port, protocol, internal_port, self.lan_address, description
|
NewRemoteHost='', NewExternalPort=port, NewProtocol=protocol,
|
||||||
|
NewInternalPort=_internal_port, NewInternalClient=self.lan_address,
|
||||||
|
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration='0'
|
||||||
)
|
)
|
||||||
return port
|
return port
|
||||||
|
|
||||||
@cli
|
# @cli
|
||||||
async def debug_gateway(self) -> str:
|
# async def debug_gateway(self) -> str:
|
||||||
return json.dumps({
|
# return json.dumps({
|
||||||
"gateway": self.gateway.debug_gateway(),
|
# "gateway": self.gateway.debug_gateway(),
|
||||||
"client_address": self.lan_address,
|
# "client_address": self.lan_address,
|
||||||
}, default=_encode, indent=2)
|
# }, default=_encode, indent=2)
|
||||||
|
#
|
||||||
# @property
|
# @property
|
||||||
# def zipped_debugging_info(self) -> str:
|
# def zipped_debugging_info(self) -> str:
|
||||||
# return base64.b64encode(zlib.compress(
|
# return base64.b64encode(zlib.compress(
|
||||||
|
@ -291,66 +283,81 @@ class UPnP:
|
||||||
# """Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
|
# """Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
|
||||||
# return await self.gateway.commands.GetActiveConnections()
|
# return await self.gateway.commands.GetActiveConnections()
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def run_cli(cls, method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
|
|
||||||
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
|
|
||||||
unicast: bool = True, kwargs: Optional[Dict] = None, loop=None) -> None:
|
|
||||||
"""
|
|
||||||
:param method: the command name
|
|
||||||
:param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided
|
|
||||||
:param lan_address: the ip address of the local interface
|
|
||||||
:param gateway_address: the ip address of the gateway
|
|
||||||
:param timeout: timeout, in seconds
|
|
||||||
:param interface_name: name of the network interface, the default is aliased to 'default'
|
|
||||||
:param kwargs: keyword arguments for the command
|
|
||||||
:param loop: EventLoop, used for testing
|
|
||||||
"""
|
|
||||||
kwargs = kwargs or {}
|
|
||||||
igd_args = igd_args
|
|
||||||
timeout = int(timeout)
|
|
||||||
loop = loop or asyncio.get_event_loop()
|
|
||||||
fut: 'asyncio.Future' = asyncio.Future(loop=loop)
|
|
||||||
|
|
||||||
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
|
def run_cli(method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
|
||||||
|
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
|
||||||
|
unicast: bool = True, kwargs: Optional[Dict] = None,
|
||||||
|
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||||
|
"""
|
||||||
|
:param method: the command name
|
||||||
|
:param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided
|
||||||
|
:param lan_address: the ip address of the local interface
|
||||||
|
:param gateway_address: the ip address of the gateway
|
||||||
|
:param timeout: timeout, in seconds
|
||||||
|
:param interface_name: name of the network interface, the default is aliased to 'default'
|
||||||
|
:param kwargs: keyword arguments for the command
|
||||||
|
:param loop: EventLoop, used for testing
|
||||||
|
"""
|
||||||
|
|
||||||
if method == 'm_search': # if we're only m_searching don't do any device discovery
|
|
||||||
fn = lambda *_a, **_kw: cls.m_search(
|
kwargs = kwargs or {}
|
||||||
lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop
|
igd_args = igd_args
|
||||||
|
timeout = int(timeout)
|
||||||
|
loop = loop or asyncio.get_event_loop()
|
||||||
|
fut: 'asyncio.Future' = asyncio.Future(loop=loop)
|
||||||
|
|
||||||
|
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
|
||||||
|
cli_commands = [
|
||||||
|
'm_search',
|
||||||
|
'get_external_ip',
|
||||||
|
'add_port_mapping',
|
||||||
|
'get_port_mapping_by_index',
|
||||||
|
'get_redirects',
|
||||||
|
'get_specific_port_mapping',
|
||||||
|
'delete_port_mapping',
|
||||||
|
'get_next_mapping'
|
||||||
|
]
|
||||||
|
|
||||||
|
if method == 'm_search': # if we're only m_searching don't do any device discovery
|
||||||
|
fn = lambda *_a, **_kw: UPnP.m_search(
|
||||||
|
lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop
|
||||||
|
)
|
||||||
|
else: # automatically discover the gateway
|
||||||
|
try:
|
||||||
|
u = await UPnP.discover(
|
||||||
|
lan_address, gateway_address, timeout, igd_args, interface_name, loop=loop
|
||||||
)
|
)
|
||||||
else: # automatically discover the gateway
|
|
||||||
try:
|
|
||||||
u = await cls.discover(
|
|
||||||
lan_address, gateway_address, timeout, igd_args, interface_name, loop=loop
|
|
||||||
)
|
|
||||||
except UPnPError as err:
|
|
||||||
fut.set_exception(err)
|
|
||||||
return
|
|
||||||
if hasattr(u, method) and hasattr(getattr(u, method), "_cli"):
|
|
||||||
fn = getattr(u, method)
|
|
||||||
else:
|
|
||||||
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
|
|
||||||
return
|
|
||||||
try: # call the command
|
|
||||||
result = await fn(**{k: fn.__annotations__[k](v) for k, v in kwargs.items()})
|
|
||||||
fut.set_result(result)
|
|
||||||
except UPnPError as err:
|
except UPnPError as err:
|
||||||
fut.set_exception(err)
|
fut.set_exception(err)
|
||||||
|
return
|
||||||
|
if method not in cli_commands:
|
||||||
|
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
fn = getattr(u, method)
|
||||||
|
|
||||||
except Exception as err:
|
try: # call the command
|
||||||
log.exception("uncaught error")
|
result = await fn(**{k: fn.__annotations__[k](v) for k, v in kwargs.items()})
|
||||||
fut.set_exception(UPnPError("uncaught error: %s" % str(err)))
|
fut.set_result(result)
|
||||||
|
|
||||||
if not hasattr(UPnP, method) or not hasattr(getattr(UPnP, method), "_cli"):
|
|
||||||
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
|
|
||||||
else:
|
|
||||||
loop.run_until_complete(wrapper())
|
|
||||||
try:
|
|
||||||
result = fut.result()
|
|
||||||
except UPnPError as err:
|
except UPnPError as err:
|
||||||
print("aioupnp encountered an error: %s" % str(err))
|
fut.set_exception(err)
|
||||||
return
|
|
||||||
|
|
||||||
if isinstance(result, (list, tuple, dict)):
|
except Exception as err:
|
||||||
print(json.dumps(result, indent=2))
|
log.exception("uncaught error")
|
||||||
else:
|
fut.set_exception(UPnPError("uncaught error: %s" % str(err)))
|
||||||
print(result)
|
|
||||||
|
if not hasattr(UPnP, method):
|
||||||
|
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
|
||||||
|
else:
|
||||||
|
loop.run_until_complete(wrapper())
|
||||||
|
try:
|
||||||
|
result = fut.result()
|
||||||
|
except UPnPError as err:
|
||||||
|
print("aioupnp encountered an error: %s" % str(err))
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(result, (list, tuple, dict)):
|
||||||
|
print(json.dumps(result, indent=2))
|
||||||
|
else:
|
||||||
|
print(result)
|
||||||
|
return
|
||||||
|
|
|
@ -20,4 +20,4 @@ class ElementTree:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fromstring(cls, xml_str: str) -> 'ElementTree':
|
def fromstring(cls, xml_str: str) -> 'ElementTree':
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
Loading…
Reference in a new issue