mypy refactor, improve coverage #12

Merged
jackrobison merged 9 commits from mypy-refactor into master 2019-05-22 09:05:10 +02:00
3 changed files with 125 additions and 118 deletions
Showing only changes of commit 8457ef54e2 - Show all commits

View file

@ -4,7 +4,7 @@ import logging
import textwrap
import typing
from collections import OrderedDict
from aioupnp.upnp import UPnP
from aioupnp.upnp import run_cli, UPnP
from aioupnp.commands import SOAPCommands
log = logging.getLogger("aioupnp")
@ -100,7 +100,7 @@ def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
interface: str = str(options.pop('interface'))
unicast: bool = bool(options.pop('unicast'))
UPnP.run_cli(
run_cli(
command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop
)
return 0

View file

@ -11,21 +11,17 @@ from aioupnp.gateway import Gateway
from aioupnp.interfaces import get_gateway_and_lan_addresses
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.commands import SOAPCommands
log = logging.getLogger(__name__)
def cli(fn):
fn._cli = True
return fn
def _encode(x):
if isinstance(x, bytes):
return x.decode()
elif isinstance(x, Exception):
return str(x)
return x
# def _encode(x):
# if isinstance(x, bytes):
# return x.decode()
# elif isinstance(x, Exception):
# return str(x)
# return x
class UPnP:
@ -36,7 +32,7 @@ class UPnP:
@classmethod
def get_annotations(cls, command: str) -> Dict[str, type]:
return getattr(Gateway.commands, command).__annotations__
return getattr(SOAPCommands, command).__annotations__
@classmethod
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
@ -61,7 +57,6 @@ class UPnP:
return cls(lan_address, gateway_address, gateway)
@classmethod
@cli
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
igd_args: Optional[Dict[str, Union[int, str]]] = None,
unicast: bool = True, interface_name: str = 'default',
@ -86,92 +81,87 @@ class UPnP:
'discover_reply': datagram.as_dict()
}
@cli
async def get_external_ip(self) -> str:
return await self.gateway.commands.GetExternalIPAddress()
@cli
async def add_port_mapping(self, external_port: int, protocol: str, internal_port: int, lan_address: str,
description: str) -> None:
return await self.gateway.commands.AddPortMapping(
await self.gateway.commands.AddPortMapping(
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol,
NewInternalPort=internal_port, NewInternalClient=lan_address,
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration='0'
)
return None
@cli
async def get_port_mapping_by_index(self, index: int):
return await self._get_port_mapping_by_index(index)
# if result:
# if self.gateway.commands.is_registered('GetGenericPortMappingEntry'):
# return {
# k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result)
# }
# return {}
async def _get_port_mapping_by_index(self, index: int) -> Optional[Tuple[Optional[str], int, str,
int, str, bool, str, int]]:
async def get_port_mapping_by_index(self, index: int) -> Optional[Tuple[str, int, str, int, str, bool, str, int]]:
try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
return redirect
except UPnPError:
return None
@cli
async def get_redirects(self) -> List[Dict]:
redirects = []
async def get_redirects(self) -> List[Tuple[str, int, str, int, str, bool, str, int]]:
redirects: List[Tuple[str, int, str, int, str, bool, str, int]] = []
cnt = 0
redirect = await self.get_port_mapping_by_index(cnt)
while redirect:
redirect: Optional[Tuple[str, int, str, int, str, bool, str, int]] = None
try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
except UPnPError:
pass
while redirect is not None:
redirects.append(redirect)
cnt += 1
redirect = await self.get_port_mapping_by_index(cnt)
try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
except UPnPError:
pass
return redirects
@cli
async def get_specific_port_mapping(self, external_port: int, protocol: str):
async def get_specific_port_mapping(self, external_port: int, protocol: str) -> Tuple[int, str, bool, str, int]:
"""
:param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP'
:return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time>
"""
# try:
return await self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
)
# except UPnPError:
# pass
# return {}
@cli
async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
"""
:param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP'
:return: None
"""
return await self.gateway.commands.DeletePortMapping(
await self.gateway.commands.DeletePortMapping(
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
)
return None
@cli
async def get_next_mapping(self, port: int, protocol: str, description: str,
internal_port: Optional[int] = None) -> int:
if protocol not in ["UDP", "TCP"]:
raise UPnPError("unsupported protocol: {}".format(protocol))
internal_port = int(internal_port or port)
requested_port = int(internal_port)
redirect_tups = []
_internal_port = int(internal_port or port)
requested_port = int(_internal_port)
redirect_tups: List[Tuple[str, int, str, int, str, bool, str, int]] = []
cnt = 0
port = int(port)
redirect = await self._get_port_mapping_by_index(cnt)
while redirect:
redirect: Optional[Tuple[str, int, str, int, str, bool, str, int]] = None
try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
except UPnPError:
pass
while redirect is not None:
redirect_tups.append(redirect)
cnt += 1
redirect = await self._get_port_mapping_by_index(cnt)
try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
except UPnPError as err:
if "ArrayIndex" in str(err):
break
redirects = {
redirects: Dict[Tuple[int, str], Tuple[str, int, str]] = {
(ext_port, proto): (int_host, int_port, desc)
for (ext_host, ext_port, proto, int_port, int_host, enabled, desc, _) in redirect_tups
}
@ -181,18 +171,20 @@ class UPnP:
if int_host == self.lan_address and int_port == requested_port and desc == description:
return port
port += 1
await self.add_port_mapping( # set one up
port, protocol, internal_port, self.lan_address, description
await self.gateway.commands.AddPortMapping(
NewRemoteHost='', NewExternalPort=port, NewProtocol=protocol,
NewInternalPort=_internal_port, NewInternalClient=self.lan_address,
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration='0'
)
return port
@cli
async def debug_gateway(self) -> str:
return json.dumps({
"gateway": self.gateway.debug_gateway(),
"client_address": self.lan_address,
}, default=_encode, indent=2)
# @cli
# async def debug_gateway(self) -> str:
# return json.dumps({
# "gateway": self.gateway.debug_gateway(),
# "client_address": self.lan_address,
# }, default=_encode, indent=2)
#
# @property
# def zipped_debugging_info(self) -> str:
# return base64.b64encode(zlib.compress(
@ -291,10 +283,11 @@ class UPnP:
# """Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
# return await self.gateway.commands.GetActiveConnections()
@classmethod
def run_cli(cls, method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
def run_cli(method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
unicast: bool = True, kwargs: Optional[Dict] = None, loop=None) -> None:
unicast: bool = True, kwargs: Optional[Dict] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
"""
:param method: the command name
:param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided
@ -305,6 +298,8 @@ class UPnP:
:param kwargs: keyword arguments for the command
:param loop: EventLoop, used for testing
"""
kwargs = kwargs or {}
igd_args = igd_args
timeout = int(timeout)
@ -312,24 +307,35 @@ class UPnP:
fut: 'asyncio.Future' = asyncio.Future(loop=loop)
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
cli_commands = [
'm_search',
'get_external_ip',
'add_port_mapping',
'get_port_mapping_by_index',
'get_redirects',
'get_specific_port_mapping',
'delete_port_mapping',
'get_next_mapping'
]
if method == 'm_search': # if we're only m_searching don't do any device discovery
fn = lambda *_a, **_kw: cls.m_search(
fn = lambda *_a, **_kw: UPnP.m_search(
lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop
)
else: # automatically discover the gateway
try:
u = await cls.discover(
u = await UPnP.discover(
lan_address, gateway_address, timeout, igd_args, interface_name, loop=loop
)
except UPnPError as err:
fut.set_exception(err)
return
if hasattr(u, method) and hasattr(getattr(u, method), "_cli"):
fn = getattr(u, method)
else:
if method not in cli_commands:
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
return
else:
fn = getattr(u, method)
try: # call the command
result = await fn(**{k: fn.__annotations__[k](v) for k, v in kwargs.items()})
fut.set_result(result)
@ -340,7 +346,7 @@ class UPnP:
log.exception("uncaught error")
fut.set_exception(UPnPError("uncaught error: %s" % str(err)))
if not hasattr(UPnP, method) or not hasattr(getattr(UPnP, method), "_cli"):
if not hasattr(UPnP, method):
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
else:
loop.run_until_complete(wrapper())
@ -354,3 +360,4 @@ class UPnP:
print(json.dumps(result, indent=2))
else:
print(result)
return