diff --git a/aioupnp/__main__.py b/aioupnp/__main__.py index abef621..7ee15ea 100644 --- a/aioupnp/__main__.py +++ b/aioupnp/__main__.py @@ -4,7 +4,7 @@ import logging import textwrap import typing from collections import OrderedDict -from aioupnp.upnp import UPnP +from aioupnp.upnp import run_cli, UPnP from aioupnp.commands import SOAPCommands log = logging.getLogger("aioupnp") @@ -100,7 +100,7 @@ def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None, interface: str = str(options.pop('interface')) unicast: bool = bool(options.pop('unicast')) - UPnP.run_cli( + run_cli( command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop ) return 0 diff --git a/aioupnp/upnp.py b/aioupnp/upnp.py index 7e8a27f..98dcb40 100644 --- a/aioupnp/upnp.py +++ b/aioupnp/upnp.py @@ -11,21 +11,17 @@ from aioupnp.gateway import Gateway from aioupnp.interfaces import get_gateway_and_lan_addresses from aioupnp.protocols.ssdp import m_search, fuzzy_m_search from aioupnp.serialization.ssdp import SSDPDatagram +from aioupnp.commands import SOAPCommands log = logging.getLogger(__name__) -def cli(fn): - fn._cli = True - return fn - - -def _encode(x): - if isinstance(x, bytes): - return x.decode() - elif isinstance(x, Exception): - return str(x) - return x +# def _encode(x): +# if isinstance(x, bytes): +# return x.decode() +# elif isinstance(x, Exception): +# return str(x) +# return x class UPnP: @@ -36,7 +32,7 @@ class UPnP: @classmethod def get_annotations(cls, command: str) -> Dict[str, type]: - return getattr(Gateway.commands, command).__annotations__ + return getattr(SOAPCommands, command).__annotations__ @classmethod def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '', @@ -61,7 +57,6 @@ class UPnP: return cls(lan_address, gateway_address, gateway) @classmethod - @cli async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, igd_args: Optional[Dict[str, Union[int, str]]] = None, unicast: bool = True, interface_name: str = 'default', @@ -86,92 +81,87 @@ class UPnP: 'discover_reply': datagram.as_dict() } - @cli async def get_external_ip(self) -> str: return await self.gateway.commands.GetExternalIPAddress() - @cli async def add_port_mapping(self, external_port: int, protocol: str, internal_port: int, lan_address: str, description: str) -> None: - return await self.gateway.commands.AddPortMapping( + await self.gateway.commands.AddPortMapping( NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol, NewInternalPort=internal_port, NewInternalClient=lan_address, NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration='0' ) + return None - @cli - async def get_port_mapping_by_index(self, index: int): - return await self._get_port_mapping_by_index(index) - # if result: - # if self.gateway.commands.is_registered('GetGenericPortMappingEntry'): - # return { - # k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result) - # } - # return {} - - async def _get_port_mapping_by_index(self, index: int) -> Optional[Tuple[Optional[str], int, str, - int, str, bool, str, int]]: + async def get_port_mapping_by_index(self, index: int) -> Optional[Tuple[str, int, str, int, str, bool, str, int]]: try: redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index) return redirect except UPnPError: return None - @cli - async def get_redirects(self) -> List[Dict]: - redirects = [] + async def get_redirects(self) -> List[Tuple[str, int, str, int, str, bool, str, int]]: + redirects: List[Tuple[str, int, str, int, str, bool, str, int]] = [] cnt = 0 - redirect = await self.get_port_mapping_by_index(cnt) - while redirect: + redirect: Optional[Tuple[str, int, str, int, str, bool, str, int]] = None + try: + redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt) + except UPnPError: + pass + while redirect is not None: redirects.append(redirect) cnt += 1 - redirect = await self.get_port_mapping_by_index(cnt) + try: + redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt) + except UPnPError: + pass return redirects - @cli - async def get_specific_port_mapping(self, external_port: int, protocol: str): + async def get_specific_port_mapping(self, external_port: int, protocol: str) -> Tuple[int, str, bool, str, int]: """ :param external_port: (int) external port to listen on :param protocol: (str) 'UDP' | 'TCP' :return: (int) , (str) , (bool) , (str) , (int) """ - - # try: return await self.gateway.commands.GetSpecificPortMappingEntry( NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol ) - # except UPnPError: - # pass - # return {} - @cli async def delete_port_mapping(self, external_port: int, protocol: str) -> None: """ :param external_port: (int) external port to listen on :param protocol: (str) 'UDP' | 'TCP' :return: None """ - return await self.gateway.commands.DeletePortMapping( + await self.gateway.commands.DeletePortMapping( NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol ) + return None - @cli async def get_next_mapping(self, port: int, protocol: str, description: str, internal_port: Optional[int] = None) -> int: if protocol not in ["UDP", "TCP"]: raise UPnPError("unsupported protocol: {}".format(protocol)) - internal_port = int(internal_port or port) - requested_port = int(internal_port) - redirect_tups = [] + _internal_port = int(internal_port or port) + requested_port = int(_internal_port) + redirect_tups: List[Tuple[str, int, str, int, str, bool, str, int]] = [] cnt = 0 port = int(port) - redirect = await self._get_port_mapping_by_index(cnt) - while redirect: + redirect: Optional[Tuple[str, int, str, int, str, bool, str, int]] = None + try: + redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt) + except UPnPError: + pass + while redirect is not None: redirect_tups.append(redirect) cnt += 1 - redirect = await self._get_port_mapping_by_index(cnt) + try: + redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=cnt) + except UPnPError as err: + if "ArrayIndex" in str(err): + break - redirects = { + redirects: Dict[Tuple[int, str], Tuple[str, int, str]] = { (ext_port, proto): (int_host, int_port, desc) for (ext_host, ext_port, proto, int_port, int_host, enabled, desc, _) in redirect_tups } @@ -181,18 +171,20 @@ class UPnP: if int_host == self.lan_address and int_port == requested_port and desc == description: return port port += 1 - await self.add_port_mapping( # set one up - port, protocol, internal_port, self.lan_address, description + await self.gateway.commands.AddPortMapping( + NewRemoteHost='', NewExternalPort=port, NewProtocol=protocol, + NewInternalPort=_internal_port, NewInternalClient=self.lan_address, + NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration='0' ) return port - @cli - async def debug_gateway(self) -> str: - return json.dumps({ - "gateway": self.gateway.debug_gateway(), - "client_address": self.lan_address, - }, default=_encode, indent=2) - + # @cli + # async def debug_gateway(self) -> str: + # return json.dumps({ + # "gateway": self.gateway.debug_gateway(), + # "client_address": self.lan_address, + # }, default=_encode, indent=2) + # # @property # def zipped_debugging_info(self) -> str: # return base64.b64encode(zlib.compress( @@ -291,66 +283,81 @@ class UPnP: # """Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID""" # return await self.gateway.commands.GetActiveConnections() - @classmethod - def run_cli(cls, method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '', - gateway_address: str = '', timeout: int = 30, interface_name: str = 'default', - unicast: bool = True, kwargs: Optional[Dict] = None, loop=None) -> None: - """ - :param method: the command name - :param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided - :param lan_address: the ip address of the local interface - :param gateway_address: the ip address of the gateway - :param timeout: timeout, in seconds - :param interface_name: name of the network interface, the default is aliased to 'default' - :param kwargs: keyword arguments for the command - :param loop: EventLoop, used for testing - """ - kwargs = kwargs or {} - igd_args = igd_args - timeout = int(timeout) - loop = loop or asyncio.get_event_loop() - fut: 'asyncio.Future' = asyncio.Future(loop=loop) - async def wrapper(): # wrap the upnp setup and call of the command in a coroutine +def run_cli(method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '', + gateway_address: str = '', timeout: int = 30, interface_name: str = 'default', + unicast: bool = True, kwargs: Optional[Dict] = None, + loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + """ + :param method: the command name + :param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided + :param lan_address: the ip address of the local interface + :param gateway_address: the ip address of the gateway + :param timeout: timeout, in seconds + :param interface_name: name of the network interface, the default is aliased to 'default' + :param kwargs: keyword arguments for the command + :param loop: EventLoop, used for testing + """ - if method == 'm_search': # if we're only m_searching don't do any device discovery - fn = lambda *_a, **_kw: cls.m_search( - lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop + + kwargs = kwargs or {} + igd_args = igd_args + timeout = int(timeout) + loop = loop or asyncio.get_event_loop() + fut: 'asyncio.Future' = asyncio.Future(loop=loop) + + async def wrapper(): # wrap the upnp setup and call of the command in a coroutine + cli_commands = [ + 'm_search', + 'get_external_ip', + 'add_port_mapping', + 'get_port_mapping_by_index', + 'get_redirects', + 'get_specific_port_mapping', + 'delete_port_mapping', + 'get_next_mapping' + ] + + if method == 'm_search': # if we're only m_searching don't do any device discovery + fn = lambda *_a, **_kw: UPnP.m_search( + lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop + ) + else: # automatically discover the gateway + try: + u = await UPnP.discover( + lan_address, gateway_address, timeout, igd_args, interface_name, loop=loop ) - else: # automatically discover the gateway - try: - u = await cls.discover( - lan_address, gateway_address, timeout, igd_args, interface_name, loop=loop - ) - except UPnPError as err: - fut.set_exception(err) - return - if hasattr(u, method) and hasattr(getattr(u, method), "_cli"): - fn = getattr(u, method) - else: - fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method)) - return - try: # call the command - result = await fn(**{k: fn.__annotations__[k](v) for k, v in kwargs.items()}) - fut.set_result(result) except UPnPError as err: fut.set_exception(err) + return + if method not in cli_commands: + fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method)) + return + else: + fn = getattr(u, method) - except Exception as err: - log.exception("uncaught error") - fut.set_exception(UPnPError("uncaught error: %s" % str(err))) - - if not hasattr(UPnP, method) or not hasattr(getattr(UPnP, method), "_cli"): - fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method)) - else: - loop.run_until_complete(wrapper()) - try: - result = fut.result() + try: # call the command + result = await fn(**{k: fn.__annotations__[k](v) for k, v in kwargs.items()}) + fut.set_result(result) except UPnPError as err: - print("aioupnp encountered an error: %s" % str(err)) - return + fut.set_exception(err) - if isinstance(result, (list, tuple, dict)): - print(json.dumps(result, indent=2)) - else: - print(result) + except Exception as err: + log.exception("uncaught error") + fut.set_exception(UPnPError("uncaught error: %s" % str(err))) + + if not hasattr(UPnP, method): + fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method)) + else: + loop.run_until_complete(wrapper()) + try: + result = fut.result() + except UPnPError as err: + print("aioupnp encountered an error: %s" % str(err)) + return + + if isinstance(result, (list, tuple, dict)): + print(json.dumps(result, indent=2)) + else: + print(result) + return diff --git a/stubs/defusedxml.py b/stubs/defusedxml.py index db9fd1c..bb94534 100644 --- a/stubs/defusedxml.py +++ b/stubs/defusedxml.py @@ -20,4 +20,4 @@ class ElementTree: @classmethod def fromstring(cls, xml_str: str) -> 'ElementTree': - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError()