mypy refactor, improve coverage #12

Merged
jackrobison merged 9 commits from mypy-refactor into master 2019-05-22 09:05:10 +02:00
3 changed files with 125 additions and 118 deletions
Showing only changes of commit 8457ef54e2 - Show all commits

View file

@ -4,7 +4,7 @@ import logging
import textwrap import textwrap
import typing import typing
from collections import OrderedDict from collections import OrderedDict
from aioupnp.upnp import UPnP from aioupnp.upnp import run_cli, UPnP
from aioupnp.commands import SOAPCommands from aioupnp.commands import SOAPCommands
log = logging.getLogger("aioupnp") log = logging.getLogger("aioupnp")
@ -100,7 +100,7 @@ def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
interface: str = str(options.pop('interface')) interface: str = str(options.pop('interface'))
unicast: bool = bool(options.pop('unicast')) unicast: bool = bool(options.pop('unicast'))
UPnP.run_cli( run_cli(
command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop
) )
return 0 return 0

View file

@ -11,21 +11,17 @@ from aioupnp.gateway import Gateway
from aioupnp.interfaces import get_gateway_and_lan_addresses from aioupnp.interfaces import get_gateway_and_lan_addresses
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.commands import SOAPCommands
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def cli(fn): # def _encode(x):
fn._cli = True # if isinstance(x, bytes):
return fn # return x.decode()
# elif isinstance(x, Exception):
# return str(x)
def _encode(x): # return x
if isinstance(x, bytes):
return x.decode()
elif isinstance(x, Exception):
return str(x)
return x
class UPnP: class UPnP:
@ -36,7 +32,7 @@ class UPnP:
@classmethod @classmethod
def get_annotations(cls, command: str) -> Dict[str, type]: def get_annotations(cls, command: str) -> Dict[str, type]:
return getattr(Gateway.commands, command).__annotations__ return getattr(SOAPCommands, command).__annotations__
@classmethod @classmethod
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '', def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
@ -61,7 +57,6 @@ class UPnP:
return cls(lan_address, gateway_address, gateway) return cls(lan_address, gateway_address, gateway)
@classmethod @classmethod
@cli
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
igd_args: Optional[Dict[str, Union[int, str]]] = None, igd_args: Optional[Dict[str, Union[int, str]]] = None,
unicast: bool = True, interface_name: str = 'default', unicast: bool = True, interface_name: str = 'default',
@ -86,92 +81,87 @@ class UPnP:
'discover_reply': datagram.as_dict() 'discover_reply': datagram.as_dict()
} }
@cli
async def get_external_ip(self) -> str: async def get_external_ip(self) -> str:
return await self.gateway.commands.GetExternalIPAddress() return await self.gateway.commands.GetExternalIPAddress()
@cli
async def add_port_mapping(self, external_port: int, protocol: str, internal_port: int, lan_address: str, async def add_port_mapping(self, external_port: int, protocol: str, internal_port: int, lan_address: str,
description: str) -> None: description: str) -> None:
return await self.gateway.commands.AddPortMapping( await self.gateway.commands.AddPortMapping(
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol, NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol,
NewInternalPort=internal_port, NewInternalClient=lan_address, NewInternalPort=internal_port, NewInternalClient=lan_address,
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration='0' NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration='0'
) )
return None
@cli async def get_port_mapping_by_index(self, index: int) -> Optional[Tuple[str, int, str, int, str, bool, str, int]]:
async def get_port_mapping_by_index(self, index: int):
return await self._get_port_mapping_by_index(index)
# if result:
# if self.gateway.commands.is_registered('GetGenericPortMappingEntry'):
# return {
# k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result)
# }
# return {}
async def _get_port_mapping_by_index(self, index: int) -> Optional[Tuple[Optional[str], int, str,
int, str, bool, str, int]]:
try: try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index) redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
return redirect return redirect
except UPnPError: except UPnPError:
return None return None
@cli async def get_redirects(self) -> List[Tuple[str, int, str, int, str, bool, str, int]]:
async def get_redirects(self) -> List[Dict]: redirects: List[Tuple[str, int, str, int, str, bool, str, int]] = []
redirects = []
cnt = 0 cnt = 0
redirect = await self.get_port_mapping_by_index(cnt) redirect: Optional[Tuple[str, int, str, int, str, bool, str, int]] = None
while redirect: try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
except UPnPError:
pass
while redirect is not None:
redirects.append(redirect) redirects.append(redirect)
cnt += 1 cnt += 1
redirect = await self.get_port_mapping_by_index(cnt) try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
except UPnPError:
pass
return redirects return redirects
@cli async def get_specific_port_mapping(self, external_port: int, protocol: str) -> Tuple[int, str, bool, str, int]:
async def get_specific_port_mapping(self, external_port: int, protocol: str):
""" """
:param external_port: (int) external port to listen on :param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP' :param protocol: (str) 'UDP' | 'TCP'
:return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time> :return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time>
""" """
# try:
return await self.gateway.commands.GetSpecificPortMappingEntry( return await self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
) )
# except UPnPError:
# pass
# return {}
@cli
async def delete_port_mapping(self, external_port: int, protocol: str) -> None: async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
""" """
:param external_port: (int) external port to listen on :param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP' :param protocol: (str) 'UDP' | 'TCP'
:return: None :return: None
""" """
return await self.gateway.commands.DeletePortMapping( await self.gateway.commands.DeletePortMapping(
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
) )
return None
@cli
async def get_next_mapping(self, port: int, protocol: str, description: str, async def get_next_mapping(self, port: int, protocol: str, description: str,
internal_port: Optional[int] = None) -> int: internal_port: Optional[int] = None) -> int:
if protocol not in ["UDP", "TCP"]: if protocol not in ["UDP", "TCP"]:
raise UPnPError("unsupported protocol: {}".format(protocol)) raise UPnPError("unsupported protocol: {}".format(protocol))
internal_port = int(internal_port or port) _internal_port = int(internal_port or port)
requested_port = int(internal_port) requested_port = int(_internal_port)
redirect_tups = [] redirect_tups: List[Tuple[str, int, str, int, str, bool, str, int]] = []
cnt = 0 cnt = 0
port = int(port) port = int(port)
redirect = await self._get_port_mapping_by_index(cnt) redirect: Optional[Tuple[str, int, str, int, str, bool, str, int]] = None
while redirect: try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
except UPnPError:
pass
while redirect is not None:
redirect_tups.append(redirect) redirect_tups.append(redirect)
cnt += 1 cnt += 1
redirect = await self._get_port_mapping_by_index(cnt) try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt)
except UPnPError as err:
if "ArrayIndex" in str(err):
break
redirects = { redirects: Dict[Tuple[int, str], Tuple[str, int, str]] = {
(ext_port, proto): (int_host, int_port, desc) (ext_port, proto): (int_host, int_port, desc)
for (ext_host, ext_port, proto, int_port, int_host, enabled, desc, _) in redirect_tups for (ext_host, ext_port, proto, int_port, int_host, enabled, desc, _) in redirect_tups
} }
@ -181,18 +171,20 @@ class UPnP:
if int_host == self.lan_address and int_port == requested_port and desc == description: if int_host == self.lan_address and int_port == requested_port and desc == description:
return port return port
port += 1 port += 1
await self.add_port_mapping( # set one up await self.gateway.commands.AddPortMapping(
port, protocol, internal_port, self.lan_address, description NewRemoteHost='', NewExternalPort=port, NewProtocol=protocol,
NewInternalPort=_internal_port, NewInternalClient=self.lan_address,
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration='0'
) )
return port return port
@cli # @cli
async def debug_gateway(self) -> str: # async def debug_gateway(self) -> str:
return json.dumps({ # return json.dumps({
"gateway": self.gateway.debug_gateway(), # "gateway": self.gateway.debug_gateway(),
"client_address": self.lan_address, # "client_address": self.lan_address,
}, default=_encode, indent=2) # }, default=_encode, indent=2)
#
# @property # @property
# def zipped_debugging_info(self) -> str: # def zipped_debugging_info(self) -> str:
# return base64.b64encode(zlib.compress( # return base64.b64encode(zlib.compress(
@ -291,66 +283,81 @@ class UPnP:
# """Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID""" # """Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
# return await self.gateway.commands.GetActiveConnections() # return await self.gateway.commands.GetActiveConnections()
@classmethod
def run_cli(cls, method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
unicast: bool = True, kwargs: Optional[Dict] = None, loop=None) -> None:
"""
:param method: the command name
:param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided
:param lan_address: the ip address of the local interface
:param gateway_address: the ip address of the gateway
:param timeout: timeout, in seconds
:param interface_name: name of the network interface, the default is aliased to 'default'
:param kwargs: keyword arguments for the command
:param loop: EventLoop, used for testing
"""
kwargs = kwargs or {}
igd_args = igd_args
timeout = int(timeout)
loop = loop or asyncio.get_event_loop()
fut: 'asyncio.Future' = asyncio.Future(loop=loop)
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine def run_cli(method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
unicast: bool = True, kwargs: Optional[Dict] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
"""
:param method: the command name
:param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided
:param lan_address: the ip address of the local interface
:param gateway_address: the ip address of the gateway
:param timeout: timeout, in seconds
:param interface_name: name of the network interface, the default is aliased to 'default'
:param kwargs: keyword arguments for the command
:param loop: EventLoop, used for testing
"""
if method == 'm_search': # if we're only m_searching don't do any device discovery
fn = lambda *_a, **_kw: cls.m_search( kwargs = kwargs or {}
lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop igd_args = igd_args
timeout = int(timeout)
loop = loop or asyncio.get_event_loop()
fut: 'asyncio.Future' = asyncio.Future(loop=loop)
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
cli_commands = [
'm_search',
'get_external_ip',
'add_port_mapping',
'get_port_mapping_by_index',
'get_redirects',
'get_specific_port_mapping',
'delete_port_mapping',
'get_next_mapping'
]
if method == 'm_search': # if we're only m_searching don't do any device discovery
fn = lambda *_a, **_kw: UPnP.m_search(
lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop
)
else: # automatically discover the gateway
try:
u = await UPnP.discover(
lan_address, gateway_address, timeout, igd_args, interface_name, loop=loop
) )
else: # automatically discover the gateway
try:
u = await cls.discover(
lan_address, gateway_address, timeout, igd_args, interface_name, loop=loop
)
except UPnPError as err:
fut.set_exception(err)
return
if hasattr(u, method) and hasattr(getattr(u, method), "_cli"):
fn = getattr(u, method)
else:
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
return
try: # call the command
result = await fn(**{k: fn.__annotations__[k](v) for k, v in kwargs.items()})
fut.set_result(result)
except UPnPError as err: except UPnPError as err:
fut.set_exception(err) fut.set_exception(err)
return
if method not in cli_commands:
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
return
else:
fn = getattr(u, method)
except Exception as err: try: # call the command
log.exception("uncaught error") result = await fn(**{k: fn.__annotations__[k](v) for k, v in kwargs.items()})
fut.set_exception(UPnPError("uncaught error: %s" % str(err))) fut.set_result(result)
if not hasattr(UPnP, method) or not hasattr(getattr(UPnP, method), "_cli"):
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
else:
loop.run_until_complete(wrapper())
try:
result = fut.result()
except UPnPError as err: except UPnPError as err:
print("aioupnp encountered an error: %s" % str(err)) fut.set_exception(err)
return
if isinstance(result, (list, tuple, dict)): except Exception as err:
print(json.dumps(result, indent=2)) log.exception("uncaught error")
else: fut.set_exception(UPnPError("uncaught error: %s" % str(err)))
print(result)
if not hasattr(UPnP, method):
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
else:
loop.run_until_complete(wrapper())
try:
result = fut.result()
except UPnPError as err:
print("aioupnp encountered an error: %s" % str(err))
return
if isinstance(result, (list, tuple, dict)):
print(json.dumps(result, indent=2))
else:
print(result)
return

View file

@ -20,4 +20,4 @@ class ElementTree:
@classmethod @classmethod
def fromstring(cls, xml_str: str) -> 'ElementTree': def fromstring(cls, xml_str: str) -> 'ElementTree':
raise NotImplementedError() raise NotImplementedError()