mypy refactor, improve coverage #12
3 changed files with 125 additions and 118 deletions
|
@ -4,7 +4,7 @@ import logging
|
|||
import textwrap
|
||||
import typing
|
||||
from collections import OrderedDict
|
||||
from aioupnp.upnp import UPnP
|
||||
from aioupnp.upnp import run_cli, UPnP
|
||||
from aioupnp.commands import SOAPCommands
|
||||
|
||||
log = logging.getLogger("aioupnp")
|
||||
|
@ -100,7 +100,7 @@ def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
|
|||
interface: str = str(options.pop('interface'))
|
||||
unicast: bool = bool(options.pop('unicast'))
|
||||
|
||||
UPnP.run_cli(
|
||||
run_cli(
|
||||
command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop
|
||||
)
|
||||
return 0
|
||||
|
|
145
aioupnp/upnp.py
145
aioupnp/upnp.py
|
@ -11,21 +11,17 @@ from aioupnp.gateway import Gateway
|
|||
from aioupnp.interfaces import get_gateway_and_lan_addresses
|
||||
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
|
||||
from aioupnp.serialization.ssdp import SSDPDatagram
|
||||
from aioupnp.commands import SOAPCommands
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cli(fn):
|
||||
fn._cli = True
|
||||
return fn
|
||||
|
||||
|
||||
def _encode(x):
|
||||
if isinstance(x, bytes):
|
||||
return x.decode()
|
||||
elif isinstance(x, Exception):
|
||||
return str(x)
|
||||
return x
|
||||
# def _encode(x):
|
||||
# if isinstance(x, bytes):
|
||||
# return x.decode()
|
||||
# elif isinstance(x, Exception):
|
||||
# return str(x)
|
||||
# return x
|
||||
|
||||
|
||||
class UPnP:
|
||||
|
@ -36,7 +32,7 @@ class UPnP:
|
|||
|
||||
@classmethod
|
||||
def get_annotations(cls, command: str) -> Dict[str, type]:
|
||||
return getattr(Gateway.commands, command).__annotations__
|
||||
return getattr(SOAPCommands, command).__annotations__
|
||||
|
||||
@classmethod
|
||||
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
|
||||
|
@ -61,7 +57,6 @@ class UPnP:
|
|||
return cls(lan_address, gateway_address, gateway)
|
||||
|
||||
@classmethod
|
||||
@cli
|
||||
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
|
||||
igd_args: Optional[Dict[str, Union[int, str]]] = None,
|
||||
unicast: bool = True, interface_name: str = 'default',
|
||||
|
@ -86,92 +81,87 @@ class UPnP:
|
|||
'discover_reply': datagram.as_dict()
|
||||
}
|
||||
|
||||
@cli
|
||||
async def get_external_ip(self) -> str:
|
||||
return await self.gateway.commands.GetExternalIPAddress()
|
||||
|
||||
@cli
|
||||
async def add_port_mapping(self, external_port: int, protocol: str, internal_port: int, lan_address: str,
|
||||
description: str) -> None:
|
||||
return await self.gateway.commands.AddPortMapping(
|
||||
await self.gateway.commands.AddPortMapping(
|
||||
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol,
|
||||
NewInternalPort=internal_port, NewInternalClient=lan_address,
|
||||
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration='0'
|
||||
)
|
||||
return None
|
||||
|
||||
@cli
|
||||
async def get_port_mapping_by_index(self, index: int):
|
||||
return await self._get_port_mapping_by_index(index)
|
||||
# if result:
|
||||
# if self.gateway.commands.is_registered('GetGenericPortMappingEntry'):
|
||||
# return {
|
||||
# k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result)
|
||||
# }
|
||||
# return {}
|
||||
|
||||
async def _get_port_mapping_by_index(self, index: int) -> Optional[Tuple[Optional[str], int, str,
|
||||
int, str, bool, str, int]]:
|
||||
async def get_port_mapping_by_index(self, index: int) -> Optional[Tuple[str, int, str, int, str, bool, str, int]]:
|
||||
try:
|
||||
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
|
||||
return redirect
|
||||
except UPnPError:
|
||||
return None
|
||||
|
||||
@cli
|
||||
async def get_redirects(self) -> List[Dict]:
|
||||
redirects = []
|
||||
async def get_redirects(self) -> List[Tuple[str, int, str, int, str, bool, str, int]]:
|
||||
redirects: List[Tuple[str, int, str, int, str, bool, str, int]] = []
|
||||
cnt = 0
|
||||
redirect = await self.get_port_mapping_by_index(cnt)
|
||||
while redirect:
|
||||
redirect: Optional[Tuple[str, int, str, int, str, bool, str, int]] = None
|
||||
try:
|
||||
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
|
||||
except UPnPError:
|
||||
pass
|
||||
while redirect is not None:
|
||||
redirects.append(redirect)
|
||||
cnt += 1
|
||||
redirect = await self.get_port_mapping_by_index(cnt)
|
||||
try:
|
||||
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
|
||||
except UPnPError:
|
||||
pass
|
||||
return redirects
|
||||
|
||||
@cli
|
||||
async def get_specific_port_mapping(self, external_port: int, protocol: str):
|
||||
async def get_specific_port_mapping(self, external_port: int, protocol: str) -> Tuple[int, str, bool, str, int]:
|
||||
"""
|
||||
:param external_port: (int) external port to listen on
|
||||
:param protocol: (str) 'UDP' | 'TCP'
|
||||
:return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time>
|
||||
"""
|
||||
|
||||
# try:
|
||||
return await self.gateway.commands.GetSpecificPortMappingEntry(
|
||||
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
|
||||
)
|
||||
# except UPnPError:
|
||||
# pass
|
||||
# return {}
|
||||
|
||||
@cli
|
||||
async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
|
||||
"""
|
||||
:param external_port: (int) external port to listen on
|
||||
:param protocol: (str) 'UDP' | 'TCP'
|
||||
:return: None
|
||||
"""
|
||||
return await self.gateway.commands.DeletePortMapping(
|
||||
await self.gateway.commands.DeletePortMapping(
|
||||
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
|
||||
)
|
||||
return None
|
||||
|
||||
@cli
|
||||
async def get_next_mapping(self, port: int, protocol: str, description: str,
|
||||
internal_port: Optional[int] = None) -> int:
|
||||
if protocol not in ["UDP", "TCP"]:
|
||||
raise UPnPError("unsupported protocol: {}".format(protocol))
|
||||
internal_port = int(internal_port or port)
|
||||
requested_port = int(internal_port)
|
||||
redirect_tups = []
|
||||
_internal_port = int(internal_port or port)
|
||||
requested_port = int(_internal_port)
|
||||
redirect_tups: List[Tuple[str, int, str, int, str, bool, str, int]] = []
|
||||
cnt = 0
|
||||
port = int(port)
|
||||
redirect = await self._get_port_mapping_by_index(cnt)
|
||||
while redirect:
|
||||
redirect: Optional[Tuple[str, int, str, int, str, bool, str, int]] = None
|
||||
try:
|
||||
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
|
||||
except UPnPError:
|
||||
pass
|
||||
while redirect is not None:
|
||||
redirect_tups.append(redirect)
|
||||
cnt += 1
|
||||
redirect = await self._get_port_mapping_by_index(cnt)
|
||||
try:
|
||||
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
|
||||
except UPnPError as err:
|
||||
if "ArrayIndex" in str(err):
|
||||
break
|
||||
|
||||
redirects = {
|
||||
redirects: Dict[Tuple[int, str], Tuple[str, int, str]] = {
|
||||
(ext_port, proto): (int_host, int_port, desc)
|
||||
for (ext_host, ext_port, proto, int_port, int_host, enabled, desc, _) in redirect_tups
|
||||
}
|
||||
|
@ -181,18 +171,20 @@ class UPnP:
|
|||
if int_host == self.lan_address and int_port == requested_port and desc == description:
|
||||
return port
|
||||
port += 1
|
||||
await self.add_port_mapping( # set one up
|
||||
port, protocol, internal_port, self.lan_address, description
|
||||
await self.gateway.commands.AddPortMapping(
|
||||
NewRemoteHost='', NewExternalPort=port, NewProtocol=protocol,
|
||||
NewInternalPort=_internal_port, NewInternalClient=self.lan_address,
|
||||
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration='0'
|
||||
)
|
||||
return port
|
||||
|
||||
@cli
|
||||
async def debug_gateway(self) -> str:
|
||||
return json.dumps({
|
||||
"gateway": self.gateway.debug_gateway(),
|
||||
"client_address": self.lan_address,
|
||||
}, default=_encode, indent=2)
|
||||
|
||||
# @cli
|
||||
# async def debug_gateway(self) -> str:
|
||||
# return json.dumps({
|
||||
# "gateway": self.gateway.debug_gateway(),
|
||||
# "client_address": self.lan_address,
|
||||
# }, default=_encode, indent=2)
|
||||
#
|
||||
# @property
|
||||
# def zipped_debugging_info(self) -> str:
|
||||
# return base64.b64encode(zlib.compress(
|
||||
|
@ -291,10 +283,11 @@ class UPnP:
|
|||
# """Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
|
||||
# return await self.gateway.commands.GetActiveConnections()
|
||||
|
||||
@classmethod
|
||||
def run_cli(cls, method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
|
||||
|
||||
def run_cli(method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
|
||||
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
|
||||
unicast: bool = True, kwargs: Optional[Dict] = None, loop=None) -> None:
|
||||
unicast: bool = True, kwargs: Optional[Dict] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
"""
|
||||
:param method: the command name
|
||||
:param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided
|
||||
|
@ -305,6 +298,8 @@ class UPnP:
|
|||
:param kwargs: keyword arguments for the command
|
||||
:param loop: EventLoop, used for testing
|
||||
"""
|
||||
|
||||
|
||||
kwargs = kwargs or {}
|
||||
igd_args = igd_args
|
||||
timeout = int(timeout)
|
||||
|
@ -312,24 +307,35 @@ class UPnP:
|
|||
fut: 'asyncio.Future' = asyncio.Future(loop=loop)
|
||||
|
||||
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
|
||||
cli_commands = [
|
||||
'm_search',
|
||||
'get_external_ip',
|
||||
'add_port_mapping',
|
||||
'get_port_mapping_by_index',
|
||||
'get_redirects',
|
||||
'get_specific_port_mapping',
|
||||
'delete_port_mapping',
|
||||
'get_next_mapping'
|
||||
]
|
||||
|
||||
if method == 'm_search': # if we're only m_searching don't do any device discovery
|
||||
fn = lambda *_a, **_kw: cls.m_search(
|
||||
fn = lambda *_a, **_kw: UPnP.m_search(
|
||||
lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop
|
||||
)
|
||||
else: # automatically discover the gateway
|
||||
try:
|
||||
u = await cls.discover(
|
||||
u = await UPnP.discover(
|
||||
lan_address, gateway_address, timeout, igd_args, interface_name, loop=loop
|
||||
)
|
||||
except UPnPError as err:
|
||||
fut.set_exception(err)
|
||||
return
|
||||
if hasattr(u, method) and hasattr(getattr(u, method), "_cli"):
|
||||
fn = getattr(u, method)
|
||||
else:
|
||||
if method not in cli_commands:
|
||||
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
|
||||
return
|
||||
else:
|
||||
fn = getattr(u, method)
|
||||
|
||||
try: # call the command
|
||||
result = await fn(**{k: fn.__annotations__[k](v) for k, v in kwargs.items()})
|
||||
fut.set_result(result)
|
||||
|
@ -340,7 +346,7 @@ class UPnP:
|
|||
log.exception("uncaught error")
|
||||
fut.set_exception(UPnPError("uncaught error: %s" % str(err)))
|
||||
|
||||
if not hasattr(UPnP, method) or not hasattr(getattr(UPnP, method), "_cli"):
|
||||
if not hasattr(UPnP, method):
|
||||
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
|
||||
else:
|
||||
loop.run_until_complete(wrapper())
|
||||
|
@ -354,3 +360,4 @@ class UPnP:
|
|||
print(json.dumps(result, indent=2))
|
||||
else:
|
||||
print(result)
|
||||
return
|
||||
|
|
Loading…
Reference in a new issue