improve type hints

This commit is contained in:
binaryflesh 2019-04-24 07:11:50 -05:00
parent a181c0cb87
commit 3785318f76
6 changed files with 302 additions and 134 deletions

View file

@ -3,6 +3,8 @@ import sys
import textwrap import textwrap
from collections import OrderedDict from collections import OrderedDict
from aioupnp.upnp import UPnP from aioupnp.upnp import UPnP
from typing import Any, Optional
from asyncio import AbstractEventLoop
log = logging.getLogger("aioupnp") log = logging.getLogger("aioupnp")
handler = logging.StreamHandler() handler = logging.StreamHandler()
@ -16,7 +18,7 @@ base_usage = "\n".join(textwrap.wrap(
100, subsequent_indent=' ', break_long_words=False)) + "\n" 100, subsequent_indent=' ', break_long_words=False)) + "\n"
def get_help(command): def get_help(command: str) -> str:
fn = getattr(UPnP, command) fn = getattr(UPnP, command)
params = command + " " + " ".join(["[--%s=<%s>]" % (k, k) for k in fn.__annotations__ if k != 'return']) params = command + " " + " ".join(["[--%s=<%s>]" % (k, k) for k in fn.__annotations__ if k != 'return'])
return base_usage + "\n".join( return base_usage + "\n".join(
@ -24,7 +26,7 @@ def get_help(command):
) )
def main(argv=None, loop=None): def main(argv: str = None, loop: Any[Optional[AbstractEventLoop], None] = None) -> None:
argv = argv or sys.argv argv = argv or sys.argv
commands = [n for n in dir(UPnP) if hasattr(getattr(UPnP, n, None), "_cli")] commands = [n for n in dir(UPnP) if hasattr(getattr(UPnP, n, None), "_cli")]
help_str = "\n".join(textwrap.wrap( help_str = "\n".join(textwrap.wrap(

View file

@ -1,8 +1,9 @@
import logging import logging
import time import time
import typing from typing import Any, Optional
from typing import Tuple, Union, List from typing import Tuple, Union, List
from aioupnp.protocols.scpd import scpd_post from aioupnp.protocols.scpd import scpd_post
from asyncio import AbstractEventLoop
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
none_or_str = Union[None, str] none_or_str = Union[None, str]
@ -11,34 +12,60 @@ return_type_lambas = {
} }
def safe_type(t): def safe_type(t: Any[tuple, list, dict, set]) -> Any[type, list, dict, set]:
if t is typing.Tuple: """Return input if type safe.
:param t:
:return:
"""
if isinstance(t, Tuple):
return tuple return tuple
if t is typing.List: if isinstance(t, List):
return list return list
if t is typing.Dict: if isinstance(t, dict):
return dict return dict
if t is typing.Set: if isinstance(t, set):
return set return set
return t return t
class SOAPCommand: class SOAPCommand:
"""SOAP Command."""
def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str, def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str,
param_types: dict, return_types: dict, param_order: list, return_order: list, loop=None) -> None: param_types: dict, return_types: dict, param_order: list, return_order: list,
self.gateway_address = gateway_address loop: Any[Optional[AbstractEventLoop], None] = None) -> None:
self.service_port = service_port """
self.control_url = control_url
self.service_id = service_id :param gateway_address:
self.method = method :param service_port:
:param control_url:
:param service_id:
:param method:
:param param_types:
:param return_types:
:param param_order:
:param return_order:
:param loop:
"""
self.gateway_address: str = gateway_address
self.service_port: int = service_port
self.control_url: str = control_url
self.service_id: bytes = service_id
self.method: str = method
self.param_types = param_types self.param_types = param_types
self.param_order = param_order self.param_order = param_order
self.return_types = return_types self.return_types = return_types
self.return_order = return_order self.return_order = return_order
self.loop = loop self.loop: Any[AbstractEventLoop, None] = loop
self._requests: typing.List = [] self._requests: list = []
async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]: async def __call__(self, **kwargs) -> Union[None, dict, list, tuple]:
"""Supports Call.
:param kwargs:
:return:
"""
if set(kwargs.keys()) != set(self.param_types.keys()): if set(kwargs.keys()) != set(self.param_types.keys()):
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys())) raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys()))
soap_kwargs = {n: safe_type(self.param_types[n])(kwargs[n]) for n in self.param_types.keys()} soap_kwargs = {n: safe_type(self.param_types[n])(kwargs[n]) for n in self.param_types.keys()}
@ -72,7 +99,7 @@ class SOAPCommands:
to their expected types. to their expected types.
""" """
SOAP_COMMANDS = [ SOAP_COMMANDS: List[str] = [
'AddPortMapping', 'AddPortMapping',
'GetNATRSIPStatus', 'GetNATRSIPStatus',
'GetGenericPortMappingEntry', 'GetGenericPortMappingEntry',
@ -102,10 +129,24 @@ class SOAPCommands:
] ]
def __init__(self): def __init__(self):
self._registered = set() """SOAPCommand."""
self._registered: set = set()
def register(self, base_ip: bytes, port: int, name: str, control_url: str, def register(self, base_ip: bytes, port: int, name: str, control_url: str,
service_type: bytes, inputs: List, outputs: List, loop=None) -> None: service_type: bytes, inputs: List, outputs: List,
loop: Any[Optional[AbstractEventLoop], None] = None) -> None:
"""Register Service.
:param base_ip:
:param port:
:param name:
:param control_url:
:param service_type:
:param inputs:
:param outputs:
:param loop:
:return:
"""
if name not in self.SOAP_COMMANDS or name in self._registered: if name not in self.SOAP_COMMANDS or name in self._registered:
raise AttributeError(name) raise AttributeError(name)
current = getattr(self, name) current = getattr(self, name)
@ -131,120 +172,176 @@ class SOAPCommands:
self._registered.add(command.method) self._registered.add(command.method)
@staticmethod @staticmethod
async def AddPortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int, async def add_port_mapping(new_remote_host: str, new_external_port: int, new_protocol: str, new_internal_port: int,
NewInternalClient: str, NewEnabled: int, NewPortMappingDescription: str, new_internal_client: str, new_enabled: int, new_port_mapping_description: str,
NewLeaseDuration: str) -> None: new_lease_duration: str) -> Any:
"""Returns None""" """Returns None"""
raise NotImplementedError() raise NotImplementedError()
AddPortMapping = add_port_mapping
@staticmethod @staticmethod
async def GetNATRSIPStatus() -> Tuple[bool, bool]: async def get_NATRSIP_status() -> Any:
"""Returns (NewRSIPAvailable, NewNATEnabled)""" """Returns (NewRSIPAvailable, NewNATEnabled)"""
raise NotImplementedError() raise NotImplementedError()
GetNATRSIPStatus = get_NATRSIP_status
@staticmethod @staticmethod
async def GetGenericPortMappingEntry(NewPortMappingIndex: int) -> Tuple[str, int, str, int, str, async def get_generic_port_mapping_entry(new_port_mapping_index: int) -> Any:
bool, str, int]:
""" """
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled, Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
NewPortMappingDescription, NewLeaseDuration) NewPortMappingDescription, NewLeaseDuration)
""" """
raise NotImplementedError() raise NotImplementedError()
GetGenericPortMappingEntry = get_generic_port_mapping_entry
@staticmethod @staticmethod
async def GetSpecificPortMappingEntry(NewRemoteHost: str, NewExternalPort: int, async def get_specific_port_mapping_entry(new_remote_host: str, new_external_port: int, new_protocol: str) -> Any:
NewProtocol: str) -> Tuple[int, str, bool, str, int]:
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)""" """Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
raise NotImplementedError() raise NotImplementedError()
GetSpecificPortMappingEntry = get_specific_port_mapping_entry
@staticmethod @staticmethod
async def SetConnectionType(NewConnectionType: str) -> None: async def set_connection_type(new_conn_type: str) -> Any:
"""Returns None""" """Returns None"""
raise NotImplementedError() raise NotImplementedError()
SetConnectionType = set_connection_type
@staticmethod @staticmethod
async def GetExternalIPAddress() -> str: async def get_external_ip_address() -> Any:
"""Returns (NewExternalIPAddress)""" """Returns (NewExternalIPAddress)"""
raise NotImplementedError() raise NotImplementedError()
GetExternalIPAddress = get_external_ip_address
@staticmethod @staticmethod
async def GetConnectionTypeInfo() -> Tuple[str, str]: async def get_connection_type_info() -> Any:
"""Returns (NewConnectionType, NewPossibleConnectionTypes)""" """Returns (NewConnectionType, NewPossibleConnectionTypes)"""
raise NotImplementedError() raise NotImplementedError()
GetConnectionTypeInfo = get_connection_type_info
@staticmethod @staticmethod
async def GetStatusInfo() -> Tuple[str, str, int]: async def get_status_info() -> Any:
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)""" """Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
raise NotImplementedError() raise NotImplementedError()
GetStatusInfo = get_status_info
@staticmethod @staticmethod
async def ForceTermination() -> None: async def force_termination() -> Any:
"""Returns None""" """Returns None"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
async def DeletePortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> None: async def delete_port_mapping(new_remote_host: str, new_external_port: int, new_protocol: str) -> Any:
"""Returns None""" """Returns None"""
raise NotImplementedError() raise NotImplementedError()
DeletePortMapping = delete_port_mapping
@staticmethod @staticmethod
async def RequestConnection() -> None: async def request_connection() -> Any:
"""Returns None""" """Returns None"""
raise NotImplementedError() raise NotImplementedError()
RequestConnection = request_connection
@staticmethod @staticmethod
async def GetCommonLinkProperties(): async def get_common_link_properties() -> Any:
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)""" """Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
raise NotImplementedError() raise NotImplementedError()
GetCommonLinkProperties = get_common_link_properties
@staticmethod @staticmethod
async def GetTotalBytesSent(): async def get_total_bytes_sent() -> Any:
"""Returns (NewTotalBytesSent)""" """Returns (NewTotalBytesSent)"""
raise NotImplementedError() raise NotImplementedError()
GetTotalBytesSent = get_total_bytes_sent
@staticmethod @staticmethod
async def GetTotalBytesReceived(): async def get_total_bytes_received() -> Any:
"""Returns (NewTotalBytesReceived)""" """Returns (NewTotalBytesReceived)"""
raise NotImplementedError() raise NotImplementedError()
GetTotalBytesRecieved = get_total_bytes_received
@staticmethod @staticmethod
async def GetTotalPacketsSent(): async def get_total_packets_sent() -> Any:
"""Returns (NewTotalPacketsSent)""" """Returns (NewTotalPacketsSent)"""
raise NotImplementedError() raise NotImplementedError()
GetTotalPacketsSent = get_total_packets_sent
@staticmethod @staticmethod
def GetTotalPacketsReceived(): def get_total_packets_received() -> Any:
"""Returns (NewTotalPacketsReceived)""" """Returns (NewTotalPacketsReceived)"""
raise NotImplementedError() raise NotImplementedError()
GetTotalPacketsReceived = get_total_packets_received
@staticmethod @staticmethod
async def X_GetICSStatistics() -> Tuple[int, int, int, int, str, str]: async def x_get_ICS_statistics() -> Any:
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)""" """Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
raise NotImplementedError() raise NotImplementedError()
X_GetICSStatistics = x_get_ICS_statistics
@staticmethod @staticmethod
async def GetDefaultConnectionService(): async def get_default_connection_service() -> Any:
"""Returns (NewDefaultConnectionService)""" """Returns (NewDefaultConnectionService)"""
raise NotImplementedError() raise NotImplementedError()
GetDefaultConnectionService = get_default_connection_service
@staticmethod @staticmethod
async def SetDefaultConnectionService(NewDefaultConnectionService: str) -> None: async def set_default_connection_service(new_default_connection_service: str) -> Any:
"""Returns (None)""" """Returns (None)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod SetDefaultConnectionService = set_default_connection_service
async def SetEnabledForInternet(NewEnabledForInternet: bool) -> None:
raise NotImplementedError()
@staticmethod @staticmethod
async def GetEnabledForInternet() -> bool: async def set_enabled_for_internet(new_enabled_for_internet: bool) -> Any:
"""
:param new_enabled_for_internet:
:return:
"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod SetEnabledForInternet = set_enabled_for_internet
async def GetMaximumActiveConnections(NewActiveConnectionIndex: int):
raise NotImplementedError()
@staticmethod @staticmethod
async def GetActiveConnections() -> Tuple[str, str]: async def get_enabled_for_internet() -> Any:
"""
:return bool?:
"""
raise NotImplementedError()
GetEnabledForInternet = get_enabled_for_internet
@staticmethod
async def get_maximum_active_connections(new_active_connection_index: int) -> Any:
"""
:param new_active_connection_index:
:return:
"""
raise NotImplementedError()
GetMaximumActiveConnections = get_maximum_active_connections
@staticmethod
async def get_active_connections() -> Any:
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID""" """Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
raise NotImplementedError() raise NotImplementedError()
GetActiveConnections = get_active_connections

View file

@ -1,25 +1,24 @@
POST = "POST" POST: str = "POST"
ROOT = "root" ROOT: str = "root"
SPEC_VERSION = "specVersion" SPEC_VERSION: str = "specVersion"
XML_VERSION = "<?xml version=\"1.0\"?>" XML_VERSION: str = "<?xml version=\"1.0\"?>"
FAULT = "{http://schemas.xmlsoap.org/soap/envelope/}Fault" FAULT: str = "{http://schemas.xmlsoap.org/soap/envelope/}Fault"
ENVELOPE = "{http://schemas.xmlsoap.org/soap/envelope/}Envelope" ENVELOPE: str = "{http://schemas.xmlsoap.org/soap/envelope/}Envelope"
BODY = "{http://schemas.xmlsoap.org/soap/envelope/}Body" BODY: str = "{http://schemas.xmlsoap.org/soap/envelope/}Body"
CONTROL: str = 'urn:schemas-upnp-org:control-1-0'
SERVICE: str = 'urn:schemas-upnp-org:service-1-0'
DEVICE: str = 'urn:schemas-upnp-org:device-1-0'
CONTROL = 'urn:schemas-upnp-org:control-1-0' WIFI_ALLIANCE_ORG_IGD: str = "urn:schemas-wifialliance-org:device:WFADevice:1"
SERVICE = 'urn:schemas-upnp-org:service-1-0' UPNP_ORG_IGD: str = 'urn:schemas-upnp-org:device:InternetGatewayDevice:1'
DEVICE = 'urn:schemas-upnp-org:device-1-0'
WIFI_ALLIANCE_ORG_IGD = "urn:schemas-wifialliance-org:device:WFADevice:1" WAN_SCHEMA: str = 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1'
UPNP_ORG_IGD = 'urn:schemas-upnp-org:device:InternetGatewayDevice:1' LAYER_SCHEMA: str = 'urn:schemas-upnp-org:service:Layer3Forwarding:1'
IP_SCHEMA: str = 'urn:schemas-upnp-org:service:WANIPConnection:1'
WAN_SCHEMA = 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1' SSDP_IP_ADDRESS: str = '239.255.255.250'
LAYER_SCHEMA = 'urn:schemas-upnp-org:service:Layer3Forwarding:1' SSDP_PORT: int = 1900
IP_SCHEMA = 'urn:schemas-upnp-org:service:WANIPConnection:1' SSDP_HOST: str = "%s:%i" % (SSDP_IP_ADDRESS, SSDP_PORT)
SSDP_DISCOVER: str = "ssdp:discover"
SSDP_IP_ADDRESS = '239.255.255.250' line_separator: str = "\r\n"
SSDP_PORT = 1900
SSDP_HOST = "%s:%i" % (SSDP_IP_ADDRESS, SSDP_PORT)
SSDP_DISCOVER = "ssdp:discover"
line_separator = "\r\n"

View file

@ -1,22 +1,39 @@
import logging import logging
from typing import List from typing import List, Any, Optional, Dict
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class CaseInsensitive: class CaseInsensitive:
"""Case Insensitive."""
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
"""CaseInsensitive
:param kwargs:
"""
for k, v in kwargs.items(): for k, v in kwargs.items():
if not k.startswith("_"): if not k.startswith("_"):
setattr(self, k, v) setattr(self, k, v)
def __getattr__(self, item): def __getattr__(self, item: str) -> Any[str, Optional[AttributeError]]:
"""
:param item:
:return:
"""
for k in self.__class__.__dict__.keys(): for k in self.__class__.__dict__.keys():
if k.lower() == item.lower(): if k.lower() == item.lower():
return self.__dict__.get(k) return self.__dict__.get(k)
raise AttributeError(item) raise AttributeError(item)
def __setattr__(self, item, value): def __setattr__(self, item: str, value: str) -> Any[None, Optional[AttributeError]]:
"""
:param item:
:param value:
:return:
"""
for k, v in self.__class__.__dict__.items(): for k, v in self.__class__.__dict__.items():
if k.lower() == item.lower(): if k.lower() == item.lower():
self.__dict__[k] = value self.__dict__[k] = value
@ -27,12 +44,19 @@ class CaseInsensitive:
raise AttributeError(item) raise AttributeError(item)
def as_dict(self) -> dict: def as_dict(self) -> dict:
"""
:return:
"""
return { return {
k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v) k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v)
} }
class Service(CaseInsensitive): class Service(CaseInsensitive):
"""
"""
serviceType = None serviceType = None
serviceId = None serviceId = None
controlURL = None controlURL = None
@ -41,13 +65,15 @@ class Service(CaseInsensitive):
class Device(CaseInsensitive): class Device(CaseInsensitive):
"""Device."""
serviceList = None serviceList = None
deviceList = None deviceList = None
deviceType = None deviceType = None
friendlyName = None friendlyName = None
manufacturer = None manufacturer = None
manufacturerURL = None manufacturerURL = None
modelDescription = None modelDescription = None
modelName = None modelName = None
modelNumber = None modelNumber = None
modelURL = None modelURL = None
@ -58,6 +84,12 @@ class Device(CaseInsensitive):
iconList = None iconList = None
def __init__(self, devices: List, services: List, **kwargs) -> None: def __init__(self, devices: List, services: List, **kwargs) -> None:
"""Device().
:param devices:
:param services:
:param kwargs:
"""
super(Device, self).__init__(**kwargs) super(Device, self).__init__(**kwargs)
if self.serviceList and "service" in self.serviceList: if self.serviceList and "service" in self.serviceList:
new_services = self.serviceList["service"] new_services = self.serviceList["service"]

View file

@ -1,12 +1,17 @@
from aioupnp.util import flatten_keys from aioupnp.util import flatten_keys
from aioupnp.constants import FAULT, CONTROL from aioupnp.constants import FAULT, CONTROL
from typing import Dict, Any, Optional
class UPnPError(Exception): class UPnPError(Exception):
"""UPnPError."""
pass pass
def handle_fault(response: dict) -> dict: def handle_fault(response: Dict) -> Any[dict, Optional[UPnPError]]:
"""Handle Fault.
:param dict response: Response
"""
if FAULT in response: if FAULT in response:
fault = flatten_keys(response[FAULT], "{%s}" % CONTROL) fault = flatten_keys(response[FAULT], "{%s}" % CONTROL)
error_description = fault['detail']['UPnPError']['errorDescription'] error_description = fault['detail']['UPnPError']['errorDescription']

View file

@ -1,8 +1,11 @@
import logging import logging
import socket import socket
from typing import Dict, List, Union, Type, Any, Optional, Set, Awaitable, TYPE_CHECKING, NoReturn
import asyncio import asyncio
from collections import OrderedDict if TYPE_CHECKING:
from typing import Dict, List, Union, Type from asyncio import AbstractEventLoop, TimeoutError
from collections import OrderedDict
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from aioupnp.constants import SPEC_VERSION, SERVICE from aioupnp.constants import SPEC_VERSION, SERVICE
from aioupnp.commands import SOAPCommands from aioupnp.commands import SOAPCommands
@ -20,7 +23,12 @@ return_type_lambas = {
} }
def get_action_list(element_dict: dict) -> List: # [(<method>, [<input1>, ...], [<output1, ...]), ...] def get_action_list(element_dict: Dict) -> List: # [(<method>, [<input1>, ...], [<output1, ...]), ...]
"""Get Action List.
:param element_dict:
:return:
"""
service_info = flatten_keys(element_dict, "{%s}" % SERVICE) service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
if "actionList" in service_info: if "actionList" in service_info:
action_list = service_info["actionList"] action_list = service_info["actionList"]
@ -55,51 +63,68 @@ def get_action_list(element_dict: dict) -> List: # [(<method>, [<input1>, ...],
class Gateway: class Gateway:
"""Gateway."""
def __init__(self, ok_packet: SSDPDatagram, m_search_args: OrderedDict, lan_address: str, def __init__(self, ok_packet: SSDPDatagram, m_search_args: OrderedDict, lan_address: str,
gateway_address: str) -> None: gateway_address: str) -> None:
self._ok_packet = ok_packet """Gateway object.
self._m_search_args = m_search_args
self._lan_address = lan_address
self.usn = (ok_packet.usn or '').encode()
self.ext = (ok_packet.ext or '').encode()
self.server = (ok_packet.server or '').encode()
self.location = (ok_packet.location or '').encode()
self.cache_control = (ok_packet.cache_control or '').encode()
self.date = (ok_packet.date or '').encode()
self.urn = (ok_packet.st or '').encode()
self._xml_response = b"" :param ok_packet:
:param m_search_args:
:param lan_address:
:param gateway_address:
"""
self._ok_packet: SSDPDatagram = ok_packet
self._m_search_args: OrderedDict = m_search_args
self._lan_address: str = lan_address
self.usn: bytes = (ok_packet.usn or '').encode()
self.ext: bytes = (ok_packet.ext or '').encode()
self.server: bytes = (ok_packet.server or '').encode()
self.location: bytes = (ok_packet.location or '').encode()
self.cache_control: bytes = (ok_packet.cache_control or '').encode()
self.date: bytes = (ok_packet.date or '').encode()
self.urn: bytes = (ok_packet.st or '').encode()
self._xml_response: bytes = b''
self._service_descriptors: Dict = {} self._service_descriptors: Dict = {}
self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0] self.base_address: bytes = BASE_ADDRESS_REGEX.findall(self.location)[0]
self.port = int(BASE_PORT_REGEX.findall(self.location)[0]) self.port: int = int(BASE_PORT_REGEX.findall(self.location)[0])
self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0] self.base_ip: bytes = self.base_address.lstrip(b'http://').split(b':')[0]
assert self.base_ip == gateway_address.encode() assert self.base_ip == gateway_address.encode()
self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1] self.path: bytes = self.location.split(b'%s:%i/' % (self.base_ip, self.port))[1]
self.spec_version = None self.spec_version = None
self.url_base = None self.url_base = None
self._device: Union[None, Device] = None self._device: Any[None, Optional[Device]] = None
self._devices: List = [] self._devices: List = []
self._services: List = [] self._services: List = []
self._unsupported_actions: Dict = {} self._unsupported_actions: Dict = {}
self._registered_commands: Dict = {} self._registered_commands: Dict = {}
self.commands = SOAPCommands() self.commands = SOAPCommands()
def gateway_descriptor(self) -> dict: def gateway_descriptor(self) -> Dict:
"""Gateway Descriptor.
:return: dict
"""
r = { r = {
'server': self.server.decode(), "server": self.server.decode(),
'urlBase': self.url_base, "urlBase": self.url_base,
'location': self.location.decode(), "location": self.location.decode(),
"specVersion": self.spec_version, "specVersion": self.spec_version,
'usn': self.usn.decode(), "usn": self.usn.decode(),
'urn': self.urn.decode(), "urn": self.urn.decode(),
} }
return r return r
@property @property
def manufacturer_string(self) -> str: def manufacturer_string(self) -> str:
"""Manufacturer string.
:return str: Manufacturer string.
"""
if not self.devices: if not self.devices:
return "UNKNOWN GATEWAY" return "UNKNOWN GATEWAY"
device = list(self.devices.values())[0] device = list(self.devices.values())[0]
@ -107,17 +132,25 @@ class Gateway:
@property @property
def services(self) -> Dict: def services(self) -> Dict:
"""Services.
:return dict: Services.
"""
if not self._device: if not self._device:
return {} return {}
return {service.serviceType: service for service in self._services} return {service.serviceType: service for service in self._services}
@property @property
def devices(self) -> Dict: def devices(self) -> Dict:
"""Devices
:return dict: Devices.
"""
if not self._device: if not self._device:
return {} return {}
return {device.udn: device for device in self._devices} return {device.udn: device for device in self._devices}
def get_service(self, service_type: str) -> Union[Type[Service], None]: def get_service(self, service_type: str) -> Any[Type[Service], None]:
for service in self._services: for service in self._services:
if service.serviceType.lower() == service_type.lower(): if service.serviceType.lower() == service_type.lower():
return service return service
@ -140,29 +173,27 @@ class Gateway:
def debug_gateway(self) -> Dict: def debug_gateway(self) -> Dict:
return { return {
'manufacturer_string': self.manufacturer_string, "manufacturer_string": self.manufacturer_string,
'gateway_address': self.base_ip, "gateway_address": self.base_ip,
'gateway_descriptor': self.gateway_descriptor(), "gateway_descriptor": self.gateway_descriptor(),
'gateway_xml': self._xml_response, "gateway_xml": self._xml_response,
'services_xml': self._service_descriptors, "services_xml": self._service_descriptors,
'services': {service.SCPDURL: service.as_dict() for service in self._services}, "services": {service.SCPDURL: service.as_dict() for service in self._services},
'm_search_args': [(k, v) for (k, v) in self._m_search_args.items()], "m_search_args": [(k, v) for (k, v) in self._m_search_args.items()],
'reply': self._ok_packet.as_dict(), "reply": self._ok_packet.as_dict(),
'soap_port': self.port, "soap_port": self.port,
'registered_soap_commands': self._registered_commands, "registered_soap_commands": self._registered_commands,
'unsupported_soap_commands': self._unsupported_actions, "unsupported_soap_commands": self._unsupported_actions,
'soap_requests': self.soap_requests "soap_requests": self.soap_requests
} }
@classmethod @classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: OrderedDict = None, loop=None, unicast: bool = False): igd_args: Any[Optional[OrderedDict], None] = None,
ignored: set = set() loop: Any[Optional[AbstractEventLoop], None] = None,
required_commands = [ unicast: bool = False) -> Any[__class__, None]:
'AddPortMapping', ignored: Set = set()
'DeletePortMapping', required_commands = ["AddPortMapping", "DeletePortMapping", "GetExternalIPAddress"]
'GetExternalIPAddress'
]
while True: while True:
if not igd_args: if not igd_args:
m_search_args, datagram = await fuzzy_m_search( m_search_args, datagram = await fuzzy_m_search(
@ -173,28 +204,30 @@ class Gateway:
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, ignored, unicast) datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, ignored, unicast)
try: try:
gateway = cls(datagram, m_search_args, lan_address, gateway_address) gateway = cls(datagram, m_search_args, lan_address, gateway_address)
log.debug('get gateway descriptor %s', datagram.location) log.debug("Get gateway descriptor %s.", datagram.location)
await gateway.discover_commands(loop) await gateway.discover_commands(loop)
requirements_met = all([required in gateway._registered_commands for required in required_commands]) requirements_met = all([required in gateway._registered_commands for required in required_commands])
if not requirements_met: if not requirements_met:
not_met = [ not_met = [
required for required in required_commands if required not in gateway._registered_commands required for required in required_commands if required not in gateway._registered_commands
] ]
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s", log.debug("Found gateway %s at %s, however it does not implement required soap commands: %s.",
gateway.manufacturer_string, gateway.location, not_met) gateway.manufacturer_string, gateway.location, not_met)
ignored.add(datagram.location) ignored.add(datagram.location)
continue continue
else: else:
log.debug('found gateway device %s', datagram.location) log.debug("Found gateway device %s.", datagram.location)
return gateway return gateway
except (asyncio.TimeoutError, UPnPError) as err: except (TimeoutError, UPnPError) as err:
log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err)) log.debug("Get %s failed (%s), looking for other devices.", datagram.location, str(err))
ignored.add(datagram.location) ignored.add(datagram.location)
continue continue
@classmethod @classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: OrderedDict = None, loop=None, unicast: bool = None): igd_args: Any[Optional[OrderedDict], None] = None,
loop = Any[Optional[AbstractEventLoop], None],
unicast: Any[Optional[bool], None] = None) -> __class__():
if unicast is not None: if unicast is not None:
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast) return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast)
@ -217,7 +250,7 @@ class Gateway:
return list(done)[0].result() return list(done)[0].result()
async def discover_commands(self, loop=None): async def discover_commands(self, loop: Any[Optional[AbstractEventLoop], None] = None) -> NoReturn:
response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port, loop=loop) response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port, loop=loop)
self._xml_response = xml_bytes self._xml_response = xml_bytes
if get_err is not None: if get_err is not None:
@ -236,9 +269,9 @@ class Gateway:
for service_type in self.services.keys(): for service_type in self.services.keys():
await self.register_commands(self.services[service_type], loop) await self.register_commands(self.services[service_type], loop)
async def register_commands(self, service: Service, loop=None): async def register_commands(self, service: Service, loop: Any[Optional[AbstractEventLoop], None] = None) -> Any[None, Optional[UPnPError]]:
if not service.SCPDURL: if not service.SCPDURL:
raise UPnPError("no scpd url") raise UPnPError("No scpd url.")
log.debug("get descriptor for %s from %s", service.serviceType, service.SCPDURL) log.debug("get descriptor for %s from %s", service.serviceType, service.SCPDURL)
service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port) service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port)
@ -247,7 +280,7 @@ class Gateway:
if get_err is not None: if get_err is not None:
log.debug("failed to get descriptor for %s from %s", service.serviceType, service.SCPDURL) log.debug("failed to get descriptor for %s from %s", service.serviceType, service.SCPDURL)
if xml_bytes: if xml_bytes:
log.debug("response: %s", xml_bytes.decode()) log.debug("Response: %s.", xml_bytes)
return return
if not service_dict: if not service_dict:
return return