diff --git a/aioupnp/__main__.py b/aioupnp/__main__.py index 7ee15ea..0f50e28 100644 --- a/aioupnp/__main__.py +++ b/aioupnp/__main__.py @@ -4,8 +4,7 @@ import logging import textwrap import typing from collections import OrderedDict -from aioupnp.upnp import run_cli, UPnP -from aioupnp.commands import SOAPCommands +from aioupnp.upnp import run_cli, UPnP, cli_commands log = logging.getLogger("aioupnp") handler = logging.StreamHandler() @@ -20,36 +19,53 @@ base_usage = "\n".join(textwrap.wrap( def get_help(command: str) -> str: - annotations = UPnP.get_annotations(command) - params = command + " " + " ".join(["[--%s=<%s>]" % (k, str(v)) for k, v in annotations.items() if k != 'return']) - return base_usage + "\n".join( - textwrap.wrap(params, 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False) - ) + annotations, doc = UPnP.get_annotations(command) + doc = doc or "" + + arg_strs = [] + for k, v in annotations.items(): + if k not in ['return', 'igd_args', 'loop']: + t = str(v) if not hasattr(v, "__name__") else v.__name__ + if t == 'bool': + arg_strs.append(f"[--{k}]") + else: + arg_strs.append(f"[--{k}=<{t}>]") + elif k == 'igd_args': + arg_strs.append(f"[--
=
, ...]") + + params = " ".join(arg_strs) + usage = "\n".join(textwrap.wrap( + f"aioupnp [-h] [--debug_logging] {command} {params}", + 100, subsequent_indent=' ', break_long_words=False)) + "\n" + + return usage + textwrap.dedent(doc) def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> int: argv = argv or list(sys.argv) - commands = list(SOAPCommands.SOAP_COMMANDS) help_str = "\n".join(textwrap.wrap( - " | ".join(commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False + " | ".join(cli_commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False )) usage = \ - "\n%s\n" \ + "%s\n" \ "If m-search headers are provided as keyword arguments all of the headers to be used must be provided,\n" \ "in the order they are to be used. For example:\n" \ " aioupnp --HOST=239.255.255.250:1900 --MAN=\"ssdp:discover\" --MX=1 --ST=upnp:rootdevice m_search\n\n" \ "Commands:\n" \ "%s\n\n" \ "For help with a specific command:" \ - " aioupnp help \n" % (base_usage, help_str) + " aioupnp help " % (base_usage, help_str) args: typing.List[str] = [str(arg) for arg in argv[1:]] + if not args: + print(usage) + return 0 if args[0] in ['help', '-h', '--help']: if len(args) > 1: - if args[1] in commands: - print(get_help(args[1])) + if args[1].replace("-", "_") in cli_commands: + print(get_help(args[1].replace("-", "_"))) return 0 print(usage) return 0 diff --git a/aioupnp/commands.py b/aioupnp/commands.py index 3200a62..163ad44 100644 --- a/aioupnp/commands.py +++ b/aioupnp/commands.py @@ -2,43 +2,64 @@ import asyncio import time import typing import logging -from typing import Tuple from aioupnp.protocols.scpd import scpd_post from aioupnp.device import Service log = logging.getLogger(__name__) -def soap_optional_str(x: typing.Optional[str]) -> typing.Optional[str]: - return x if x is not None and str(x).lower() not in ['none', 'nil'] else None +def soap_optional_str(x: typing.Optional[typing.Union[str, int]]) -> typing.Optional[str]: + return str(x) if x is not None and str(x).lower() not in ['none', 'nil'] else None -def soap_bool(x: typing.Optional[str]) -> bool: +def soap_bool(x: typing.Optional[typing.Union[str, int]]) -> bool: return False if not x or str(x).lower() in ['false', 'False'] else True -def recast_single_result(t: type, result: typing.Any) -> typing.Optional[typing.Union[str, int, float, bool]]: - if t is bool: - return soap_bool(result) - if t is str: - return soap_optional_str(result) - return t(result) +class GetSpecificPortMappingEntryResponse(typing.NamedTuple): + internal_port: int + lan_address: str + enabled: bool + description: str + lease_time: int + + +class GetGenericPortMappingEntryResponse(typing.NamedTuple): + gateway_address: str + external_port: int + protocol: str + internal_port: int + lan_address: str + enabled: bool + description: str + lease_time: int def recast_return(return_annotation, result: typing.Dict[str, typing.Union[int, str]], - result_keys: typing.List[str]) -> typing.Tuple: - if return_annotation is None or len(result_keys) == 0: - return () + result_keys: typing.List[str]) -> typing.Optional[ + typing.Union[str, int, bool, GetSpecificPortMappingEntryResponse, GetGenericPortMappingEntryResponse]]: if len(result_keys) == 1: - assert len(result_keys) == 1 single_result = result[result_keys[0]] - return (recast_single_result(return_annotation, single_result), ) - annotated_args: typing.List[type] = list(return_annotation.__args__) - assert len(annotated_args) == len(result_keys) - recast_results: typing.List[typing.Optional[typing.Union[str, int, float, bool]]] = [] - for type_annotation, result_key in zip(annotated_args, result_keys): - recast_results.append(recast_single_result(type_annotation, result.get(result_key, None))) - return tuple(recast_results) + if return_annotation is bool: + return soap_bool(single_result) + if return_annotation is str: + return soap_optional_str(single_result) + return int(result[result_keys[0]]) if result_keys[0] in result else None + elif return_annotation in [GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse]: + arg_types: typing.Dict[str, typing.Type[typing.Any]] = return_annotation._field_types + assert len(arg_types) == len(result_keys) + recast_results: typing.Dict[str, typing.Optional[typing.Union[str, int, bool]]] = {} + for i, (field_name, result_key) in enumerate(zip(arg_types, result_keys)): + result_field_name = result_keys[i] + field_type = arg_types[field_name] + if field_type is bool: + recast_results[field_name] = soap_bool(result.get(result_field_name, None)) + elif field_type is str: + recast_results[field_name] = soap_optional_str(result.get(result_field_name, None)) + elif field_type is int: + recast_results[field_name] = int(result[result_field_name]) if result_field_name in result else None + return return_annotation(**recast_results) + return None class SOAPCommands: @@ -88,7 +109,10 @@ class SOAPCommands: self._base_address = base_address self._port = port self._requests: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes, - typing.Tuple, typing.Optional[Exception], float]] = [] + typing.Optional[typing.Union[str, int, bool, + GetSpecificPortMappingEntryResponse, + GetGenericPortMappingEntryResponse]], + typing.Optional[Exception], float]] = [] def is_registered(self, name: str) -> bool: if name not in self.SOAP_COMMANDS: @@ -112,7 +136,8 @@ class SOAPCommands: input_names: typing.List[str] = self._registered[service][name][0] output_names: typing.List[str] = self._registered[service][name][1] - async def wrapper(**kwargs: typing.Any) -> typing.Tuple: + async def wrapper(**kwargs: typing.Any) -> typing.Optional[ + typing.Union[str, int, bool, GetSpecificPortMappingEntryResponse, GetGenericPortMappingEntryResponse]]: assert service.controlURL is not None assert service.serviceType is not None @@ -122,11 +147,10 @@ class SOAPCommands: ) if err is not None: assert isinstance(xml_bytes, bytes) - self._requests.append((name, kwargs, xml_bytes, (), err, time.time())) + self._requests.append((name, kwargs, xml_bytes, None, err, time.time())) raise err assert 'return' in annotations result = recast_return(annotations['return'], response, output_names) - self._requests.append((name, kwargs, xml_bytes, result, None, time.time())) return result @@ -161,8 +185,7 @@ class SOAPCommands: ) return None - async def GetGenericPortMappingEntry(self, NewPortMappingIndex: int) -> Tuple[str, int, str, int, str, - bool, str, int]: + async def GetGenericPortMappingEntry(self, NewPortMappingIndex: int) -> GetGenericPortMappingEntryResponse: """ Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration) @@ -171,19 +194,19 @@ class SOAPCommands: if not self.is_registered(name): raise NotImplementedError() assert name in self._wrappers_kwargs - result: Tuple[str, int, str, int, str, bool, str, int] = await self._wrappers_kwargs[name]( + result: GetGenericPortMappingEntryResponse = await self._wrappers_kwargs[name]( NewPortMappingIndex=NewPortMappingIndex ) return result async def GetSpecificPortMappingEntry(self, NewRemoteHost: str, NewExternalPort: int, - NewProtocol: str) -> Tuple[int, str, bool, str, int]: + NewProtocol: str) -> GetSpecificPortMappingEntryResponse: """Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)""" name = "GetSpecificPortMappingEntry" if not self.is_registered(name): raise NotImplementedError() assert name in self._wrappers_kwargs - result: Tuple[int, str, bool, str, int] = await self._wrappers_kwargs[name]( + result: GetSpecificPortMappingEntryResponse = await self._wrappers_kwargs[name]( NewRemoteHost=NewRemoteHost, NewExternalPort=NewExternalPort, NewProtocol=NewProtocol ) return result @@ -205,8 +228,8 @@ class SOAPCommands: if not self.is_registered(name): raise NotImplementedError() assert name in self._wrappers_no_args - result: Tuple[str] = await self._wrappers_no_args[name]() - return result[0] + result: str = await self._wrappers_no_args[name]() + return result # async def GetNATRSIPStatus(self) -> Tuple[bool, bool]: # """Returns (NewRSIPAvailable, NewNATEnabled)""" diff --git a/aioupnp/gateway.py b/aioupnp/gateway.py index b40d8cb..cd2a68c 100644 --- a/aioupnp/gateway.py +++ b/aioupnp/gateway.py @@ -3,7 +3,7 @@ import logging import typing import asyncio from collections import OrderedDict -from typing import Dict, List, Union +from typing import Dict, List from aioupnp.util import get_dict_val_case_insensitive from aioupnp.constants import SPEC_VERSION, SERVICE from aioupnp.commands import SOAPCommands @@ -103,74 +103,81 @@ class Gateway: self._registered_commands: Dict[str, str] = {} self.commands = SOAPCommands(self._loop, self.base_ip, self.port) - def gateway_descriptor(self) -> dict: - r = { - 'server': self.server.decode(), - 'urlBase': self.url_base, - 'location': self.location.decode(), - "specVersion": self.spec_version, - 'usn': self.usn.decode(), - 'urn': self.urn.decode(), - } - return r + # def gateway_descriptor(self) -> dict: + # r = { + # 'server': self.server.decode(), + # 'urlBase': self.url_base, + # 'location': self.location.decode(), + # "specVersion": self.spec_version, + # 'usn': self.usn.decode(), + # 'urn': self.urn.decode(), + # } + # return r @property def manufacturer_string(self) -> str: - if not self.devices: - return "UNKNOWN GATEWAY" - devices: typing.List[Device] = list(self.devices.values()) - device = devices[0] - return f"{device.manufacturer} {device.modelName}" + manufacturer_string = "UNKNOWN GATEWAY" + if self.devices: + devices: typing.List[Device] = list(self.devices.values()) + device = devices[0] + manufacturer_string = f"{device.manufacturer} {device.modelName}" + return manufacturer_string @property def services(self) -> Dict[str, Service]: - if not self._device: - return {} - return {str(service.serviceType): service for service in self._services} + services: Dict[str, Service] = {} + if self._services: + for service in self._services: + if service.serviceType is not None: + services[service.serviceType] = service + return services @property - def devices(self) -> Dict: - if not self._device: - return {} - return {device.udn: device for device in self._devices} + def devices(self) -> Dict[str, Device]: + devices: Dict[str, Device] = {} + if self._device: + for device in self._devices: + if device.udn is not None: + devices[device.udn] = device + return devices - def get_service(self, service_type: str) -> typing.Optional[Service]: - for service in self._services: - if service.serviceType and service.serviceType.lower() == service_type.lower(): - return service - return None + # def get_service(self, service_type: str) -> typing.Optional[Service]: + # for service in self._services: + # if service.serviceType and service.serviceType.lower() == service_type.lower(): + # return service + # return None - @property - def soap_requests(self) -> typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes, - typing.Optional[typing.Tuple], - typing.Optional[Exception], float]]: - soap_call_infos: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes, - typing.Optional[typing.Tuple], - typing.Optional[Exception], float]] = [] - soap_call_infos.extend([ - (name, request_args, raw_response, decoded_response, soap_error, ts) - for ( - name, request_args, raw_response, decoded_response, soap_error, ts - ) in self.commands._requests - ]) - soap_call_infos.sort(key=lambda x: x[5]) - return soap_call_infos + # @property + # def soap_requests(self) -> typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes, + # typing.Optional[typing.Tuple], + # typing.Optional[Exception], float]]: + # soap_call_infos: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes, + # typing.Optional[typing.Tuple], + # typing.Optional[Exception], float]] = [] + # soap_call_infos.extend([ + # (name, request_args, raw_response, decoded_response, soap_error, ts) + # for ( + # name, request_args, raw_response, decoded_response, soap_error, ts + # ) in self.commands._requests + # ]) + # soap_call_infos.sort(key=lambda x: x[5]) + # return soap_call_infos - def debug_gateway(self) -> Dict[str, Union[str, bytes, int, Dict, List]]: - return { - 'manufacturer_string': self.manufacturer_string, - 'gateway_address': self.base_ip, - 'gateway_descriptor': self.gateway_descriptor(), - 'gateway_xml': self._xml_response, - 'services_xml': self._service_descriptors, - 'services': {service.SCPDURL: service.as_dict() for service in self._services}, - 'm_search_args': [(k, v) for (k, v) in self._m_search_args.items()], - 'reply': self._ok_packet.as_dict(), - 'soap_port': self.port, - 'registered_soap_commands': self._registered_commands, - 'unsupported_soap_commands': self._unsupported_actions, - 'soap_requests': self.soap_requests - } + # def debug_gateway(self) -> Dict[str, Union[str, bytes, int, Dict, List]]: + # return { + # 'manufacturer_string': self.manufacturer_string, + # 'gateway_address': self.base_ip, + # 'gateway_descriptor': self.gateway_descriptor(), + # 'gateway_xml': self._xml_response, + # 'services_xml': self._service_descriptors, + # 'services': {service.SCPDURL: service.as_dict() for service in self._services}, + # 'm_search_args': [(k, v) for (k, v) in self._m_search_args.items()], + # 'reply': self._ok_packet.as_dict(), + # 'soap_port': self.port, + # 'registered_soap_commands': self._registered_commands, + # 'unsupported_soap_commands': self._unsupported_actions, + # 'soap_requests': self.soap_requests + # } @classmethod async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, @@ -206,7 +213,7 @@ class Gateway: ignored.add(datagram.location) continue else: - log.debug('found gateway device %s', datagram.location) + log.debug('found gateway %s at %s', gateway.manufacturer_string or "device", datagram.location) return gateway except (asyncio.TimeoutError, UPnPError) as err: assert datagram.location is not None diff --git a/aioupnp/protocols/multicast.py b/aioupnp/protocols/multicast.py index 4aadb2e..76adb0c 100644 --- a/aioupnp/protocols/multicast.py +++ b/aioupnp/protocols/multicast.py @@ -2,11 +2,11 @@ import struct import socket import typing from asyncio.protocols import DatagramProtocol -from asyncio.transports import BaseTransport +from asyncio.transports import DatagramTransport from unittest import mock -def _get_sock(transport: typing.Optional[BaseTransport]) -> typing.Optional[socket.socket]: +def _get_sock(transport: typing.Optional[DatagramTransport]) -> typing.Optional[socket.socket]: if transport is None or not hasattr(transport, "_extra"): return None sock: typing.Optional[socket.socket] = transport.get_extra_info('socket', None) @@ -18,7 +18,7 @@ class MulticastProtocol(DatagramProtocol): def __init__(self, multicast_address: str, bind_address: str) -> None: self.multicast_address = multicast_address self.bind_address = bind_address - self.transport: typing.Optional[BaseTransport] = None + self.transport: typing.Optional[DatagramTransport] = None @property def sock(self) -> typing.Optional[socket.socket]: @@ -26,40 +26,37 @@ class MulticastProtocol(DatagramProtocol): def get_ttl(self) -> int: sock = self.sock - if not sock: - raise ValueError("not connected") - return sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL) + if sock: + return sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL) + return 0 def set_ttl(self, ttl: int = 1) -> None: sock = self.sock - if not sock: - return None - sock.setsockopt( - socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl) - ) + if sock: + sock.setsockopt( + socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl) + ) return None def join_group(self, multicast_address: str, bind_address: str) -> None: sock = self.sock - if not sock: - return None - sock.setsockopt( - socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, - socket.inet_aton(multicast_address) + socket.inet_aton(bind_address) - ) + if sock: + sock.setsockopt( + socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, + socket.inet_aton(multicast_address) + socket.inet_aton(bind_address) + ) return None def leave_group(self, multicast_address: str, bind_address: str) -> None: sock = self.sock - if not sock: - raise ValueError("not connected") - sock.setsockopt( - socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, - socket.inet_aton(multicast_address) + socket.inet_aton(bind_address) - ) + if sock: + sock.setsockopt( + socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, + socket.inet_aton(multicast_address) + socket.inet_aton(bind_address) + ) return None - def connection_made(self, transport: BaseTransport) -> None: + def connection_made(self, transport: DatagramTransport) -> None: # type: ignore self.transport = transport return None diff --git a/aioupnp/protocols/ssdp.py b/aioupnp/protocols/ssdp.py index 846f1ce..0f8782d 100644 --- a/aioupnp/protocols/ssdp.py +++ b/aioupnp/protocols/ssdp.py @@ -28,19 +28,14 @@ class SSDPProtocol(MulticastProtocol): self.notifications: typing.List[SSDPDatagram] = [] self.connected = asyncio.Event(loop=self.loop) - def connection_made(self, transport) -> None: - # assert isinstance(transport, asyncio.DatagramTransport), str(type(transport)) + def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore super().connection_made(transport) self.connected.set() + return None def disconnect(self) -> None: if self.transport: - try: - self.leave_group(self.multicast_address, self.bind_address) - except ValueError: - pass - except Exception: - log.exception("unexpected error leaving multicast group") + self.leave_group(self.multicast_address, self.bind_address) self.transport.close() self.connected.clear() while self._pending_searches: @@ -50,29 +45,30 @@ class SSDPProtocol(MulticastProtocol): return None def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None: - if packet.location in self._ignored: - return None - # TODO: fix this - tmp: typing.List[typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = [] - set_futures: typing.List['asyncio.Future[SSDPDatagram]'] = [] - while len(self._pending_searches): - t: typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle] = self._pending_searches.pop() - if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]): - f = t[2] - if f not in set_futures: - set_futures.append(f) - if not f.done(): - f.set_result(packet) - elif t[2] not in set_futures: - tmp.append(t) - while tmp: - self._pending_searches.append(tmp.pop()) + if packet.location not in self._ignored: + # TODO: fix this + tmp: typing.List[typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = [] + set_futures: typing.List['asyncio.Future[SSDPDatagram]'] = [] + while len(self._pending_searches): + t = self._pending_searches.pop() + if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]): + f = t[2] + if f not in set_futures: + set_futures.append(f) + if not f.done(): + f.set_result(packet) + elif t[2] not in set_futures: + tmp.append(t) + while tmp: + self._pending_searches.append(tmp.pop()) return None - def _send_m_search(self, address: str, packet: SSDPDatagram) -> None: + def _send_m_search(self, address: str, packet: SSDPDatagram, fut: 'asyncio.Future[SSDPDatagram]') -> None: dest = address if self._unicast else SSDP_IP_ADDRESS if not self.transport: - raise UPnPError("SSDP transport not connected") + if not fut.done(): + fut.set_exception(UPnPError("SSDP transport not connected")) + return None log.debug("send m search to %s: %s", dest, packet.st) self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT)) return None @@ -84,7 +80,7 @@ class SSDPProtocol(MulticastProtocol): packet = SSDPDatagram("M-SEARCH", datagram) assert packet.st is not None self._pending_searches.append( - (address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet)) + (address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet, fut)) ) return await asyncio.wait_for(fut, timeout) @@ -95,8 +91,8 @@ class SSDPProtocol(MulticastProtocol): packet = SSDPDatagram.decode(data) log.debug("decoded packet from %s:%i: %s", addr[0], addr[1], packet) except UPnPError as err: - log.error("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err, - binascii.hexlify(data)) + log.warning("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err, + binascii.hexlify(data)) return None if packet._packet_type == packet._OK: @@ -131,19 +127,13 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop: typing.Optio listen_result: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint( lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock ) - transport = listen_result[0] protocol = listen_result[1] assert isinstance(protocol, SSDPProtocol) except Exception as err: - print(err) raise UPnPError(err) - try: + else: protocol.join_group(protocol.multicast_address, protocol.bind_address) protocol.set_ttl(1) - except Exception as err: - protocol.disconnect() - raise UPnPError(err) - return protocol, gateway_address, lan_address @@ -178,10 +168,11 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = log.debug("sending batch of %i M-SEARCH attempts", batch_size) try: await protocol.m_search(gateway_address, batch_timeout, args) - protocol.disconnect() - return args except asyncio.TimeoutError: continue + else: + protocol.disconnect() + return args protocol.disconnect() raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) diff --git a/aioupnp/upnp.py b/aioupnp/upnp.py index a248c6c..9b3eefc 100644 --- a/aioupnp/upnp.py +++ b/aioupnp/upnp.py @@ -5,13 +5,14 @@ import logging import json import asyncio from collections import OrderedDict -from typing import Tuple, Dict, List, Union, Optional +from typing import Tuple, Dict, List, Union, Optional, Any from aioupnp.fault import UPnPError from aioupnp.gateway import Gateway from aioupnp.interfaces import get_gateway_and_lan_addresses from aioupnp.protocols.ssdp import m_search, fuzzy_m_search from aioupnp.serialization.ssdp import SSDPDatagram -from aioupnp.commands import SOAPCommands +from aioupnp.commands import GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse + log = logging.getLogger(__name__) @@ -31,11 +32,27 @@ class UPnP: self.gateway = gateway @classmethod - def get_annotations(cls, command: str) -> Dict[str, type]: - return getattr(SOAPCommands, command).__annotations__ + def get_annotations(cls, command: str) -> Tuple[Dict[str, Any], Optional[str]]: + if command == "m_search": + return cls.m_search.__annotations__, cls.m_search.__doc__ + if command == "get_external_ip": + return cls.get_external_ip.__annotations__, cls.get_external_ip.__doc__ + if command == "add_port_mapping": + return cls.add_port_mapping.__annotations__, cls.add_port_mapping.__doc__ + if command == "get_port_mapping_by_index": + return cls.get_port_mapping_by_index.__annotations__, cls.get_port_mapping_by_index.__doc__ + if command == "get_redirects": + return cls.get_redirects.__annotations__, cls.get_redirects.__doc__ + if command == "get_specific_port_mapping": + return cls.get_specific_port_mapping.__annotations__, cls.get_specific_port_mapping.__doc__ + if command == "delete_port_mapping": + return cls.delete_port_mapping.__annotations__, cls.delete_port_mapping.__doc__ + if command == "get_next_mapping": + return cls.get_next_mapping.__annotations__, cls.get_next_mapping.__doc__ + raise AttributeError(command) - @classmethod - def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '', + @staticmethod + def get_lan_and_gateway(lan_address: str = '', gateway_address: str = '', interface_name: str = 'default') -> Tuple[str, str]: if not lan_address or not gateway_address: gateway_addr, lan_addr = get_gateway_and_lan_addresses(interface_name) @@ -55,10 +72,28 @@ class UPnP: @classmethod async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, - igd_args: Optional[Dict[str, Union[int, str]]] = None, unicast: bool = True, interface_name: str = 'default', + igd_args: Optional[Dict[str, Union[str, int]]] = None, loop: Optional[asyncio.AbstractEventLoop] = None - ) -> Dict[str, Union[str, Dict[str, Union[int, str]]]]: + ) -> Dict[str, Union[str, Dict[str, Union[str, int]]]]: + """ + Perform a M-SEARCH for a upnp gateway. + + :param lan_address: (str) the local interface ipv4 address + :param gateway_address: (str) the gateway ipv4 address + :param timeout: (int) m search timeout + :param unicast: (bool) use unicast + :param interface_name: (str) name of the network interface + :param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided. + + :return: { + 'lan_address': (str) lan address, + 'gateway_address': (str) gateway address, + 'm_search_kwargs': (str) equivalent igd_args , + 'discover_reply': (dict) SSDP response datagram + } + """ + if not lan_address or not gateway_address: try: lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) @@ -79,10 +114,25 @@ class UPnP: } async def get_external_ip(self) -> str: + """ + Get the external ip address from the gateway + + :return: (str) external ip + """ return await self.gateway.commands.GetExternalIPAddress() async def add_port_mapping(self, external_port: int, protocol: str, internal_port: int, lan_address: str, description: str) -> None: + """ + Add a new port mapping + + :param external_port: (int) external port to map + :param protocol: (str) UDP | TCP + :param internal_port: (int) internal port + :param lan_address: (str) internal lan address + :param description: (str) mapping description + :return: None + """ await self.gateway.commands.AddPortMapping( NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol, NewInternalPort=internal_port, NewInternalClient=lan_address, @@ -90,11 +140,42 @@ class UPnP: ) return None - async def get_port_mapping_by_index(self, index: int) -> Tuple[str, int, str, int, str, bool, str, int]: + async def get_port_mapping_by_index(self, index: int) -> GetGenericPortMappingEntryResponse: + """ + Get information about a port mapping by index number + + :param index: (int) mapping index number + :return: NamedTuple[ + gateway_address: str + external_port: int + protocol: str + internal_port: int + lan_address: str + enabled: bool + description: str + lease_time: int + ] + """ return await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index) - async def get_redirects(self) -> List[Tuple[str, int, str, int, str, bool, str, int]]: - redirects: List[Tuple[str, int, str, int, str, bool, str, int]] = [] + async def get_redirects(self) -> List[GetGenericPortMappingEntryResponse]: + """ + Get information about all mapped ports + + :return: List[ + NamedTuple[ + gateway_address: str + external_port: int + protocol: str + internal_port: int + lan_address: str + enabled: bool + description: str + lease_time: int + ] + ] + """ + redirects: List[GetGenericPortMappingEntryResponse] = [] cnt = 0 try: redirect = await self.get_port_mapping_by_index(cnt) @@ -109,11 +190,19 @@ class UPnP: break return redirects - async def get_specific_port_mapping(self, external_port: int, protocol: str) -> Tuple[int, str, bool, str, int]: + async def get_specific_port_mapping(self, external_port: int, protocol: str) -> GetSpecificPortMappingEntryResponse: """ - :param external_port: (int) external port to listen on - :param protocol: (str) 'UDP' | 'TCP' - :return: (int) , (str) , (bool) , (str) , (int) + Get information about a port mapping by port number and protocol + + :param external_port: (int) port number + :param protocol: (str) UDP | TCP + :return: NamedTuple[ + internal_port: int + lan_address: str + enabled: bool + description: str + lease_time: int + ] """ return await self.gateway.commands.GetSpecificPortMappingEntry( NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol @@ -121,25 +210,31 @@ class UPnP: async def delete_port_mapping(self, external_port: int, protocol: str) -> None: """ - :param external_port: (int) external port to listen on - :param protocol: (str) 'UDP' | 'TCP' + Delete a port mapping + + :param external_port: (int) port number of mapping + :param protocol: (str) TCP | UDP :return: None """ await self.gateway.commands.DeletePortMapping( NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol ) + return None async def get_next_mapping(self, port: int, protocol: str, description: str, internal_port: Optional[int] = None) -> int: """ - :param port: (int) external port to redirect from - :param protocol: (str) 'UDP' | 'TCP' - :param description: (str) mapping description - :param internal_port: (int) internal port to redirect to + Get a new port mapping. If the requested port is not available, increment until the next free port is mapped - :return: (int) + :param port: (int) external port + :param protocol: (str) UDP | TCP + :param description: (str) mapping description + :param internal_port: (int) internal port + + :return: (int) mapped port """ + _internal_port = int(internal_port or port) requested_port = int(_internal_port) port = int(port) @@ -264,22 +359,23 @@ class UPnP: # return await self.gateway.commands.GetActiveConnections() -def run_cli(method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '', - gateway_address: str = '', timeout: int = 30, interface_name: str = 'default', - unicast: bool = True, kwargs: Optional[Dict] = None, - loop: Optional[asyncio.AbstractEventLoop] = None) -> None: - """ - :param method: the command name - :param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided - :param lan_address: the ip address of the local interface - :param gateway_address: the ip address of the gateway - :param timeout: timeout, in seconds - :param interface_name: name of the network interface, the default is aliased to 'default' - :param kwargs: keyword arguments for the command - :param loop: EventLoop, used for testing - """ +cli_commands = [ + 'm_search', + 'get_external_ip', + 'add_port_mapping', + 'get_port_mapping_by_index', + 'get_redirects', + 'get_specific_port_mapping', + 'delete_port_mapping', + 'get_next_mapping' +] +def run_cli(method: str, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '', + gateway_address: str = '', timeout: int = 30, interface_name: str = 'default', + unicast: bool = True, kwargs: Optional[Dict[str, str]] = None, + loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + kwargs = kwargs or {} igd_args = igd_args timeout = int(timeout) @@ -287,20 +383,9 @@ def run_cli(method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str fut: 'asyncio.Future' = asyncio.Future(loop=loop) async def wrapper(): # wrap the upnp setup and call of the command in a coroutine - cli_commands = [ - 'm_search', - 'get_external_ip', - 'add_port_mapping', - 'get_port_mapping_by_index', - 'get_redirects', - 'get_specific_port_mapping', - 'delete_port_mapping', - 'get_next_mapping' - ] - if method == 'm_search': # if we're only m_searching don't do any device discovery fn = lambda *_a, **_kw: UPnP.m_search( - lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop + lan_address, gateway_address, timeout, unicast, interface_name, igd_args, loop ) else: # automatically discover the gateway try: diff --git a/tests/__init__.py b/tests/__init__.py index 9d87eeb..acd71a5 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -16,7 +16,7 @@ except ImportError: @contextlib.contextmanager def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None, - tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None): + tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False): sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else [] udp_replies = udp_replies or {} @@ -28,7 +28,11 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r def _write(data): sent_tcp_packets.append(data) if data in tcp_replies: - loop.call_later(tcp_delay_reply, p.data_received, tcp_replies[data]) + reply = tcp_replies[data] + i = 0 + while i < len(reply): + loop.call_later(tcp_delay_reply, p.data_received, reply[i:i+100]) + i += 100 return else: pass @@ -46,6 +50,11 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r def sendto(p: asyncio.DatagramProtocol): def _sendto(data, addr): sent_udp_packets.append(data) + loop.call_later(udp_delay_reply, p.datagram_received, data, + (p.bind_address, 1900)) + if add_potato_datagrams: + loop.call_soon(p.datagram_received, b'potato', ('?.?.?.?', 1900)) + if (data, addr) in udp_replies: loop.call_later(udp_delay_reply, p.datagram_received, udp_replies[(data, addr)], (udp_expected_addr, 1900)) diff --git a/tests/protocols/test_multicast.py b/tests/protocols/test_multicast.py index 065f12e..37774fa 100644 --- a/tests/protocols/test_multicast.py +++ b/tests/protocols/test_multicast.py @@ -1,23 +1,30 @@ import unittest - +from unittest import mock +import socket +import struct from asyncio import DatagramTransport from aioupnp.protocols.multicast import MulticastProtocol class TestMulticast(unittest.TestCase): - def test_it(self): - class none_socket: - sock = None + def test_multicast(self): + _ttl = None + mock_socket = mock.MagicMock(spec=socket.socket) + def getsockopt(*_): + return _ttl - def get(self, name, default=None): - return default + def setsockopt(a, b, ttl: bytes): + nonlocal _ttl + _ttl, = struct.unpack('b', ttl) + + mock_socket.getsockopt = getsockopt + mock_socket.setsockopt = setsockopt protocol = MulticastProtocol('1.2.3.4', '1.2.3.4') - transport = DatagramTransport(none_socket()) - protocol.set_ttl(1) - with self.assertRaises(ValueError): - _ = protocol.get_ttl() + transport = DatagramTransport() + transport._extra = {'socket': mock_socket} + self.assertEqual(None, protocol.set_ttl(1)) + self.assertEqual(0, protocol.get_ttl()) protocol.connection_made(transport) - protocol.set_ttl(1) - with self.assertRaises(ValueError): - _ = protocol.get_ttl() + self.assertEqual(None, protocol.set_ttl(1)) + self.assertEqual(1, protocol.get_ttl()) diff --git a/tests/protocols/test_ssdp.py b/tests/protocols/test_ssdp.py index a36f1a0..243c94f 100644 --- a/tests/protocols/test_ssdp.py +++ b/tests/protocols/test_ssdp.py @@ -3,7 +3,7 @@ from aioupnp.fault import UPnPError from aioupnp.protocols.m_search_patterns import packet_generator from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.constants import SSDP_IP_ADDRESS -from aioupnp.protocols.ssdp import fuzzy_m_search, m_search +from aioupnp.protocols.ssdp import fuzzy_m_search, m_search, SSDPProtocol from tests import AsyncioTestCase, mock_tcp_and_udp @@ -28,6 +28,13 @@ class TestSSDP(AsyncioTestCase): ]) reply_packet = SSDPDatagram("OK", reply_args) + async def test_transport_not_connected_error(self): + try: + await SSDPProtocol('', '').m_search('1.2.3.4', 2, [self.query_packet.as_dict()]) + self.assertTrue(False) + except UPnPError as err: + self.assertEqual(str(err), "SSDP transport not connected") + async def test_m_search_reply_unicast(self): replies = { (self.query_packet.encode().encode(), ("10.0.0.1", 1900)): self.reply_packet.encode().encode() @@ -80,3 +87,13 @@ class TestSSDP(AsyncioTestCase): self.assertEqual(reply.encode(), self.reply_packet.encode()) self.assertEqual(args, self.successful_args) + + async def test_packets_sent_fuzzy_m_search_ignore_invalid_datagram_replies(self): + sent = [] + + with self.assertRaises(UPnPError): + with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", sent_udp_packets=sent, + add_potato_datagrams=True): + await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop) + + self.assertListEqual(sent, self.byte_packets) \ No newline at end of file diff --git a/tests/test_cli.py b/tests/test_cli.py index e72ea59..f7a0052 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -21,6 +21,139 @@ m_search_cli_result = """{ }\n""" +m_search_help_msg = """aioupnp [-h] [--debug_logging] m_search [--lan_address=] [--gateway_address=] + [--timeout=] [--unicast] [--interface_name=] [--
=
, ...] + +Perform a M-SEARCH for a upnp gateway. + +:param lan_address: (str) the local interface ipv4 address +:param gateway_address: (str) the gateway ipv4 address +:param timeout: (int) m search timeout +:param unicast: (bool) use unicast +:param interface_name: (str) name of the network interface +:param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided. + +:return: { + 'lan_address': (str) lan address, + 'gateway_address': (str) gateway address, + 'm_search_kwargs': (str) equivalent igd_args , + 'discover_reply': (dict) SSDP response datagram +}\n +""" + +expected_usage = """aioupnp [-h] [--debug_logging] [--interface=] [--gateway_address=] + [--lan_address=] [--timeout=] [(--=)...] + +If m-search headers are provided as keyword arguments all of the headers to be used must be provided, +in the order they are to be used. For example: + aioupnp --HOST=239.255.255.250:1900 --MAN="ssdp:discover" --MX=1 --ST=upnp:rootdevice m_search + +Commands: + m_search | get_external_ip | add_port_mapping | get_port_mapping_by_index | get_redirects | + get_specific_port_mapping | delete_port_mapping | get_next_mapping + +For help with a specific command: aioupnp help +""" + +expected_get_external_ip_usage = """aioupnp [-h] [--debug_logging] get_external_ip + +Get the external ip address from the gateway + +:return: (str) external ip + +""" + +expected_add_port_mapping_usage = """aioupnp [-h] [--debug_logging] add_port_mapping [--external_port=] [--protocol=] + [--internal_port=] [--lan_address=] [--description=] + +Add a new port mapping + +:param external_port: (int) external port to map +:param protocol: (str) UDP | TCP +:param internal_port: (int) internal port +:param lan_address: (str) internal lan address +:param description: (str) mapping description +:return: None + +""" + +expected_get_next_mapping_usage = """aioupnp [-h] [--debug_logging] get_next_mapping [--port=] [--protocol=] + [--description=] [--internal_port=] + +Get a new port mapping. If the requested port is not available, increment until the next free port is mapped + +:param port: (int) external port +:param protocol: (str) UDP | TCP +:param description: (str) mapping description +:param internal_port: (int) internal port + +:return: (int) mapped port + +""" + + +expected_delete_port_mapping_usage = """aioupnp [-h] [--debug_logging] delete_port_mapping [--external_port=] [--protocol=] + +Delete a port mapping + +:param external_port: (int) port number of mapping +:param protocol: (str) TCP | UDP +:return: None + +""" + +expected_get_specific_port_mapping_usage = """aioupnp [-h] [--debug_logging] get_specific_port_mapping [--external_port=] [--protocol=] + +Get information about a port mapping by port number and protocol + +:param external_port: (int) port number +:param protocol: (str) UDP | TCP +:return: NamedTuple[ + internal_port: int + lan_address: str + enabled: bool + description: str + lease_time: int +] + +""" +expected_get_redirects_usage = """aioupnp [-h] [--debug_logging] get_redirects + +Get information about all mapped ports + +:return: List[ + NamedTuple[ + gateway_address: str + external_port: int + protocol: str + internal_port: int + lan_address: str + enabled: bool + description: str + lease_time: int + ] +] + +""" +expected_get_port_mapping_by_index_usage = """aioupnp [-h] [--debug_logging] get_port_mapping_by_index [--index=] + +Get information about a port mapping by index number + +:param index: (int) mapping index number +:return: NamedTuple[ + gateway_address: str + external_port: int + protocol: str + internal_port: int + lan_address: str + enabled: bool + description: str + lease_time: int +] + +""" + + class TestCLI(AsyncioTestCase): gateway_address = "10.0.0.1" soap_port = 49152 @@ -101,3 +234,117 @@ class TestCLI(AsyncioTestCase): self.loop ) self.assertEqual(m_search_cli_result, actual_output.getvalue()) + + def test_usage(self): + actual_output = StringIO() + with contextlib.redirect_stdout(actual_output): + main( + [None, 'help'], + self.loop + ) + self.assertEqual(expected_usage, actual_output.getvalue()) + + actual_output = StringIO() + with contextlib.redirect_stdout(actual_output): + main( + [None, 'help', 'test'], + self.loop + ) + self.assertEqual(expected_usage, actual_output.getvalue()) + + actual_output = StringIO() + with contextlib.redirect_stdout(actual_output): + main( + [None, 'test', 'help'], + self.loop + ) + self.assertEqual("aioupnp encountered an error: \"test\" is not a recognized command\n", actual_output.getvalue()) + + actual_output = StringIO() + with contextlib.redirect_stdout(actual_output): + main( + [None, 'test'], + self.loop + ) + self.assertEqual("aioupnp encountered an error: \"test\" is not a recognized command\n", actual_output.getvalue()) + + actual_output = StringIO() + with contextlib.redirect_stdout(actual_output): + main( + [None], + self.loop + ) + self.assertEqual(expected_usage, actual_output.getvalue()) + + actual_output = StringIO() + with contextlib.redirect_stdout(actual_output): + main( + [None, "--something=test"], + self.loop + ) + self.assertEqual("no command given\n" + expected_usage, actual_output.getvalue()) + + def test_commands_help(self): + actual_output = StringIO() + with contextlib.redirect_stdout(actual_output): + main( + [None, 'help', 'm-search'], + self.loop + ) + self.assertEqual(m_search_help_msg, actual_output.getvalue()) + + actual_output = StringIO() + with contextlib.redirect_stdout(actual_output): + main( + [None, 'help', 'get-external-ip'], + self.loop + ) + + self.assertEqual(expected_get_external_ip_usage, actual_output.getvalue()) + + actual_output = StringIO() + with contextlib.redirect_stdout(actual_output): + main( + [None, 'help', 'add-port-mapping'], + self.loop + ) + self.assertEqual(expected_add_port_mapping_usage, actual_output.getvalue()) + actual_output = StringIO() + with contextlib.redirect_stdout(actual_output): + main( + [None, 'help', 'get-next-mapping'], + self.loop + ) + self.assertEqual(expected_get_next_mapping_usage, actual_output.getvalue()) + + actual_output = StringIO() + with contextlib.redirect_stdout(actual_output): + main( + [None, 'help', 'delete_port_mapping'], + self.loop + ) + self.assertEqual(expected_delete_port_mapping_usage, actual_output.getvalue()) + + actual_output = StringIO() + with contextlib.redirect_stdout(actual_output): + main( + [None, 'help', 'get_specific_port_mapping'], + self.loop + ) + self.assertEqual(expected_get_specific_port_mapping_usage, actual_output.getvalue()) + + actual_output = StringIO() + with contextlib.redirect_stdout(actual_output): + main( + [None, 'help', 'get_redirects'], + self.loop + ) + self.assertEqual(expected_get_redirects_usage, actual_output.getvalue()) + + actual_output = StringIO() + with contextlib.redirect_stdout(actual_output): + main( + [None, 'help', 'get_port_mapping_by_index'], + self.loop + ) + self.assertEqual(expected_get_port_mapping_by_index_usage, actual_output.getvalue()) diff --git a/tests/test_upnp.py b/tests/test_upnp.py index 0a6d2a3..bd93570 100644 --- a/tests/test_upnp.py +++ b/tests/test_upnp.py @@ -4,17 +4,7 @@ from aioupnp.upnp import UPnP from aioupnp.fault import UPnPError from aioupnp.gateway import Gateway from aioupnp.serialization.ssdp import SSDPDatagram - - -class TestGetAnnotations(AsyncioTestCase): - def test_get_annotations(self): - expected = { - 'NewRemoteHost': str, 'NewExternalPort': int, 'NewProtocol': str, 'NewInternalPort': int, - 'NewInternalClient': str, 'NewEnabled': int, 'NewPortMappingDescription': str, - 'NewLeaseDuration': str, 'return': None - } - - self.assertDictEqual(expected, UPnP.get_annotations('AddPortMapping')) +from aioupnp.commands import GetSpecificPortMappingEntryResponse, GetGenericPortMappingEntryResponse class UPnPCommandTestCase(AsyncioTestCase): @@ -76,7 +66,8 @@ class TestGetGenericPortMappingEntry(UPnPCommandTestCase): await gateway.discover_commands(self.loop) upnp = UPnP(self.client_address, self.gateway_address, gateway) result = await upnp.get_port_mapping_by_index(0) - self.assertEqual((None, 9308, 'UDP', 9308, "11.2.3.44", True, "11.2.3.44:9308 to 9308 (UDP)", 0), result) + self.assertEqual(GetGenericPortMappingEntryResponse(None, 9308, 'UDP', 9308, "11.2.3.44", True, + "11.2.3.44:9308 to 9308 (UDP)", 0), result) class TestGetNextPortMapping(UPnPCommandTestCase): @@ -120,6 +111,9 @@ class TestGetSpecificPortMapping(UPnPCommandTestCase): await upnp.get_specific_port_mapping(1000, 'UDP') except UPnPError: result = await upnp.get_specific_port_mapping(9308, 'UDP') - self.assertEqual((9308, '11.2.3.55', True, '11.2.3.55:9308 to 9308 (UDP)', 0), result) + self.assertEqual( + GetSpecificPortMappingEntryResponse(9308, '11.2.3.55', True, '11.2.3.55:9308 to 9308 (UDP)', 0), + result + ) else: self.assertTrue(False)