improve type hints
This commit is contained in:
parent
a181c0cb87
commit
3785318f76
6 changed files with 302 additions and 134 deletions
|
@ -3,6 +3,8 @@ import sys
|
|||
import textwrap
|
||||
from collections import OrderedDict
|
||||
from aioupnp.upnp import UPnP
|
||||
from typing import Any, Optional
|
||||
from asyncio import AbstractEventLoop
|
||||
|
||||
log = logging.getLogger("aioupnp")
|
||||
handler = logging.StreamHandler()
|
||||
|
@ -16,7 +18,7 @@ base_usage = "\n".join(textwrap.wrap(
|
|||
100, subsequent_indent=' ', break_long_words=False)) + "\n"
|
||||
|
||||
|
||||
def get_help(command):
|
||||
def get_help(command: str) -> str:
|
||||
fn = getattr(UPnP, command)
|
||||
params = command + " " + " ".join(["[--%s=<%s>]" % (k, k) for k in fn.__annotations__ if k != 'return'])
|
||||
return base_usage + "\n".join(
|
||||
|
@ -24,7 +26,7 @@ def get_help(command):
|
|||
)
|
||||
|
||||
|
||||
def main(argv=None, loop=None):
|
||||
def main(argv: str = None, loop: Any[Optional[AbstractEventLoop], None] = None) -> None:
|
||||
argv = argv or sys.argv
|
||||
commands = [n for n in dir(UPnP) if hasattr(getattr(UPnP, n, None), "_cli")]
|
||||
help_str = "\n".join(textwrap.wrap(
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import logging
|
||||
import time
|
||||
import typing
|
||||
from typing import Any, Optional
|
||||
from typing import Tuple, Union, List
|
||||
from aioupnp.protocols.scpd import scpd_post
|
||||
from asyncio import AbstractEventLoop
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
none_or_str = Union[None, str]
|
||||
|
@ -11,34 +12,60 @@ return_type_lambas = {
|
|||
}
|
||||
|
||||
|
||||
def safe_type(t):
|
||||
if t is typing.Tuple:
|
||||
def safe_type(t: Any[tuple, list, dict, set]) -> Any[type, list, dict, set]:
|
||||
"""Return input if type safe.
|
||||
|
||||
:param t:
|
||||
:return:
|
||||
"""
|
||||
if isinstance(t, Tuple):
|
||||
return tuple
|
||||
if t is typing.List:
|
||||
if isinstance(t, List):
|
||||
return list
|
||||
if t is typing.Dict:
|
||||
if isinstance(t, dict):
|
||||
return dict
|
||||
if t is typing.Set:
|
||||
if isinstance(t, set):
|
||||
return set
|
||||
return t
|
||||
|
||||
|
||||
class SOAPCommand:
|
||||
"""SOAP Command."""
|
||||
|
||||
def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str,
|
||||
param_types: dict, return_types: dict, param_order: list, return_order: list, loop=None) -> None:
|
||||
self.gateway_address = gateway_address
|
||||
self.service_port = service_port
|
||||
self.control_url = control_url
|
||||
self.service_id = service_id
|
||||
self.method = method
|
||||
param_types: dict, return_types: dict, param_order: list, return_order: list,
|
||||
loop: Any[Optional[AbstractEventLoop], None] = None) -> None:
|
||||
"""
|
||||
|
||||
:param gateway_address:
|
||||
:param service_port:
|
||||
:param control_url:
|
||||
:param service_id:
|
||||
:param method:
|
||||
:param param_types:
|
||||
:param return_types:
|
||||
:param param_order:
|
||||
:param return_order:
|
||||
:param loop:
|
||||
"""
|
||||
self.gateway_address: str = gateway_address
|
||||
self.service_port: int = service_port
|
||||
self.control_url: str = control_url
|
||||
self.service_id: bytes = service_id
|
||||
self.method: str = method
|
||||
self.param_types = param_types
|
||||
self.param_order = param_order
|
||||
self.return_types = return_types
|
||||
self.return_order = return_order
|
||||
self.loop = loop
|
||||
self._requests: typing.List = []
|
||||
self.loop: Any[AbstractEventLoop, None] = loop
|
||||
self._requests: list = []
|
||||
|
||||
async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]:
|
||||
async def __call__(self, **kwargs) -> Union[None, dict, list, tuple]:
|
||||
"""Supports Call.
|
||||
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
if set(kwargs.keys()) != set(self.param_types.keys()):
|
||||
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys()))
|
||||
soap_kwargs = {n: safe_type(self.param_types[n])(kwargs[n]) for n in self.param_types.keys()}
|
||||
|
@ -72,7 +99,7 @@ class SOAPCommands:
|
|||
to their expected types.
|
||||
"""
|
||||
|
||||
SOAP_COMMANDS = [
|
||||
SOAP_COMMANDS: List[str] = [
|
||||
'AddPortMapping',
|
||||
'GetNATRSIPStatus',
|
||||
'GetGenericPortMappingEntry',
|
||||
|
@ -102,10 +129,24 @@ class SOAPCommands:
|
|||
]
|
||||
|
||||
def __init__(self):
|
||||
self._registered = set()
|
||||
"""SOAPCommand."""
|
||||
self._registered: set = set()
|
||||
|
||||
def register(self, base_ip: bytes, port: int, name: str, control_url: str,
|
||||
service_type: bytes, inputs: List, outputs: List, loop=None) -> None:
|
||||
service_type: bytes, inputs: List, outputs: List,
|
||||
loop: Any[Optional[AbstractEventLoop], None] = None) -> None:
|
||||
"""Register Service.
|
||||
|
||||
:param base_ip:
|
||||
:param port:
|
||||
:param name:
|
||||
:param control_url:
|
||||
:param service_type:
|
||||
:param inputs:
|
||||
:param outputs:
|
||||
:param loop:
|
||||
:return:
|
||||
"""
|
||||
if name not in self.SOAP_COMMANDS or name in self._registered:
|
||||
raise AttributeError(name)
|
||||
current = getattr(self, name)
|
||||
|
@ -131,120 +172,176 @@ class SOAPCommands:
|
|||
self._registered.add(command.method)
|
||||
|
||||
@staticmethod
|
||||
async def AddPortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int,
|
||||
NewInternalClient: str, NewEnabled: int, NewPortMappingDescription: str,
|
||||
NewLeaseDuration: str) -> None:
|
||||
async def add_port_mapping(new_remote_host: str, new_external_port: int, new_protocol: str, new_internal_port: int,
|
||||
new_internal_client: str, new_enabled: int, new_port_mapping_description: str,
|
||||
new_lease_duration: str) -> Any:
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
AddPortMapping = add_port_mapping
|
||||
|
||||
@staticmethod
|
||||
async def GetNATRSIPStatus() -> Tuple[bool, bool]:
|
||||
async def get_NATRSIP_status() -> Any:
|
||||
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
GetNATRSIPStatus = get_NATRSIP_status
|
||||
|
||||
@staticmethod
|
||||
async def GetGenericPortMappingEntry(NewPortMappingIndex: int) -> Tuple[str, int, str, int, str,
|
||||
bool, str, int]:
|
||||
async def get_generic_port_mapping_entry(new_port_mapping_index: int) -> Any:
|
||||
"""
|
||||
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
|
||||
NewPortMappingDescription, NewLeaseDuration)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
GetGenericPortMappingEntry = get_generic_port_mapping_entry
|
||||
|
||||
@staticmethod
|
||||
async def GetSpecificPortMappingEntry(NewRemoteHost: str, NewExternalPort: int,
|
||||
NewProtocol: str) -> Tuple[int, str, bool, str, int]:
|
||||
async def get_specific_port_mapping_entry(new_remote_host: str, new_external_port: int, new_protocol: str) -> Any:
|
||||
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
GetSpecificPortMappingEntry = get_specific_port_mapping_entry
|
||||
|
||||
@staticmethod
|
||||
async def SetConnectionType(NewConnectionType: str) -> None:
|
||||
async def set_connection_type(new_conn_type: str) -> Any:
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
SetConnectionType = set_connection_type
|
||||
|
||||
@staticmethod
|
||||
async def GetExternalIPAddress() -> str:
|
||||
async def get_external_ip_address() -> Any:
|
||||
"""Returns (NewExternalIPAddress)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
GetExternalIPAddress = get_external_ip_address
|
||||
|
||||
@staticmethod
|
||||
async def GetConnectionTypeInfo() -> Tuple[str, str]:
|
||||
async def get_connection_type_info() -> Any:
|
||||
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
GetConnectionTypeInfo = get_connection_type_info
|
||||
|
||||
@staticmethod
|
||||
async def GetStatusInfo() -> Tuple[str, str, int]:
|
||||
async def get_status_info() -> Any:
|
||||
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
GetStatusInfo = get_status_info
|
||||
|
||||
@staticmethod
|
||||
async def ForceTermination() -> None:
|
||||
async def force_termination() -> Any:
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def DeletePortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> None:
|
||||
async def delete_port_mapping(new_remote_host: str, new_external_port: int, new_protocol: str) -> Any:
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
DeletePortMapping = delete_port_mapping
|
||||
|
||||
@staticmethod
|
||||
async def RequestConnection() -> None:
|
||||
async def request_connection() -> Any:
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
RequestConnection = request_connection
|
||||
|
||||
@staticmethod
|
||||
async def GetCommonLinkProperties():
|
||||
async def get_common_link_properties() -> Any:
|
||||
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
GetCommonLinkProperties = get_common_link_properties
|
||||
|
||||
@staticmethod
|
||||
async def GetTotalBytesSent():
|
||||
async def get_total_bytes_sent() -> Any:
|
||||
"""Returns (NewTotalBytesSent)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
GetTotalBytesSent = get_total_bytes_sent
|
||||
|
||||
@staticmethod
|
||||
async def GetTotalBytesReceived():
|
||||
async def get_total_bytes_received() -> Any:
|
||||
"""Returns (NewTotalBytesReceived)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
GetTotalBytesRecieved = get_total_bytes_received
|
||||
|
||||
@staticmethod
|
||||
async def GetTotalPacketsSent():
|
||||
async def get_total_packets_sent() -> Any:
|
||||
"""Returns (NewTotalPacketsSent)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
GetTotalPacketsSent = get_total_packets_sent
|
||||
|
||||
@staticmethod
|
||||
def GetTotalPacketsReceived():
|
||||
def get_total_packets_received() -> Any:
|
||||
"""Returns (NewTotalPacketsReceived)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
GetTotalPacketsReceived = get_total_packets_received
|
||||
|
||||
@staticmethod
|
||||
async def X_GetICSStatistics() -> Tuple[int, int, int, int, str, str]:
|
||||
async def x_get_ICS_statistics() -> Any:
|
||||
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
X_GetICSStatistics = x_get_ICS_statistics
|
||||
|
||||
@staticmethod
|
||||
async def GetDefaultConnectionService():
|
||||
async def get_default_connection_service() -> Any:
|
||||
"""Returns (NewDefaultConnectionService)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
GetDefaultConnectionService = get_default_connection_service
|
||||
|
||||
@staticmethod
|
||||
async def SetDefaultConnectionService(NewDefaultConnectionService: str) -> None:
|
||||
async def set_default_connection_service(new_default_connection_service: str) -> Any:
|
||||
"""Returns (None)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def SetEnabledForInternet(NewEnabledForInternet: bool) -> None:
|
||||
raise NotImplementedError()
|
||||
SetDefaultConnectionService = set_default_connection_service
|
||||
|
||||
@staticmethod
|
||||
async def GetEnabledForInternet() -> bool:
|
||||
async def set_enabled_for_internet(new_enabled_for_internet: bool) -> Any:
|
||||
"""
|
||||
|
||||
:param new_enabled_for_internet:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
async def GetMaximumActiveConnections(NewActiveConnectionIndex: int):
|
||||
raise NotImplementedError()
|
||||
SetEnabledForInternet = set_enabled_for_internet
|
||||
|
||||
@staticmethod
|
||||
async def GetActiveConnections() -> Tuple[str, str]:
|
||||
async def get_enabled_for_internet() -> Any:
|
||||
"""
|
||||
|
||||
:return bool?:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
GetEnabledForInternet = get_enabled_for_internet
|
||||
|
||||
@staticmethod
|
||||
async def get_maximum_active_connections(new_active_connection_index: int) -> Any:
|
||||
"""
|
||||
|
||||
:param new_active_connection_index:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
GetMaximumActiveConnections = get_maximum_active_connections
|
||||
|
||||
@staticmethod
|
||||
async def get_active_connections() -> Any:
|
||||
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
|
||||
raise NotImplementedError()
|
||||
|
||||
GetActiveConnections = get_active_connections
|
||||
|
|
|
@ -1,25 +1,24 @@
|
|||
POST = "POST"
|
||||
ROOT = "root"
|
||||
SPEC_VERSION = "specVersion"
|
||||
XML_VERSION = "<?xml version=\"1.0\"?>"
|
||||
FAULT = "{http://schemas.xmlsoap.org/soap/envelope/}Fault"
|
||||
ENVELOPE = "{http://schemas.xmlsoap.org/soap/envelope/}Envelope"
|
||||
BODY = "{http://schemas.xmlsoap.org/soap/envelope/}Body"
|
||||
POST: str = "POST"
|
||||
ROOT: str = "root"
|
||||
SPEC_VERSION: str = "specVersion"
|
||||
XML_VERSION: str = "<?xml version=\"1.0\"?>"
|
||||
FAULT: str = "{http://schemas.xmlsoap.org/soap/envelope/}Fault"
|
||||
ENVELOPE: str = "{http://schemas.xmlsoap.org/soap/envelope/}Envelope"
|
||||
BODY: str = "{http://schemas.xmlsoap.org/soap/envelope/}Body"
|
||||
|
||||
CONTROL: str = 'urn:schemas-upnp-org:control-1-0'
|
||||
SERVICE: str = 'urn:schemas-upnp-org:service-1-0'
|
||||
DEVICE: str = 'urn:schemas-upnp-org:device-1-0'
|
||||
|
||||
CONTROL = 'urn:schemas-upnp-org:control-1-0'
|
||||
SERVICE = 'urn:schemas-upnp-org:service-1-0'
|
||||
DEVICE = 'urn:schemas-upnp-org:device-1-0'
|
||||
WIFI_ALLIANCE_ORG_IGD: str = "urn:schemas-wifialliance-org:device:WFADevice:1"
|
||||
UPNP_ORG_IGD: str = 'urn:schemas-upnp-org:device:InternetGatewayDevice:1'
|
||||
|
||||
WIFI_ALLIANCE_ORG_IGD = "urn:schemas-wifialliance-org:device:WFADevice:1"
|
||||
UPNP_ORG_IGD = 'urn:schemas-upnp-org:device:InternetGatewayDevice:1'
|
||||
WAN_SCHEMA: str = 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1'
|
||||
LAYER_SCHEMA: str = 'urn:schemas-upnp-org:service:Layer3Forwarding:1'
|
||||
IP_SCHEMA: str = 'urn:schemas-upnp-org:service:WANIPConnection:1'
|
||||
|
||||
WAN_SCHEMA = 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1'
|
||||
LAYER_SCHEMA = 'urn:schemas-upnp-org:service:Layer3Forwarding:1'
|
||||
IP_SCHEMA = 'urn:schemas-upnp-org:service:WANIPConnection:1'
|
||||
|
||||
SSDP_IP_ADDRESS = '239.255.255.250'
|
||||
SSDP_PORT = 1900
|
||||
SSDP_HOST = "%s:%i" % (SSDP_IP_ADDRESS, SSDP_PORT)
|
||||
SSDP_DISCOVER = "ssdp:discover"
|
||||
line_separator = "\r\n"
|
||||
SSDP_IP_ADDRESS: str = '239.255.255.250'
|
||||
SSDP_PORT: int = 1900
|
||||
SSDP_HOST: str = "%s:%i" % (SSDP_IP_ADDRESS, SSDP_PORT)
|
||||
SSDP_DISCOVER: str = "ssdp:discover"
|
||||
line_separator: str = "\r\n"
|
||||
|
|
|
@ -1,22 +1,39 @@
|
|||
import logging
|
||||
from typing import List
|
||||
from typing import List, Any, Optional, Dict
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CaseInsensitive:
|
||||
"""Case Insensitive."""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
"""CaseInsensitive
|
||||
|
||||
:param kwargs:
|
||||
"""
|
||||
for k, v in kwargs.items():
|
||||
if not k.startswith("_"):
|
||||
setattr(self, k, v)
|
||||
|
||||
def __getattr__(self, item):
|
||||
def __getattr__(self, item: str) -> Any[str, Optional[AttributeError]]:
|
||||
"""
|
||||
|
||||
:param item:
|
||||
:return:
|
||||
"""
|
||||
for k in self.__class__.__dict__.keys():
|
||||
if k.lower() == item.lower():
|
||||
return self.__dict__.get(k)
|
||||
raise AttributeError(item)
|
||||
|
||||
def __setattr__(self, item, value):
|
||||
def __setattr__(self, item: str, value: str) -> Any[None, Optional[AttributeError]]:
|
||||
"""
|
||||
|
||||
:param item:
|
||||
:param value:
|
||||
:return:
|
||||
"""
|
||||
for k, v in self.__class__.__dict__.items():
|
||||
if k.lower() == item.lower():
|
||||
self.__dict__[k] = value
|
||||
|
@ -27,12 +44,19 @@ class CaseInsensitive:
|
|||
raise AttributeError(item)
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
"""
|
||||
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v)
|
||||
}
|
||||
|
||||
|
||||
class Service(CaseInsensitive):
|
||||
"""
|
||||
|
||||
"""
|
||||
serviceType = None
|
||||
serviceId = None
|
||||
controlURL = None
|
||||
|
@ -41,13 +65,15 @@ class Service(CaseInsensitive):
|
|||
|
||||
|
||||
class Device(CaseInsensitive):
|
||||
"""Device."""
|
||||
|
||||
serviceList = None
|
||||
deviceList = None
|
||||
deviceType = None
|
||||
friendlyName = None
|
||||
manufacturer = None
|
||||
manufacturerURL = None
|
||||
modelDescription = None
|
||||
modelDescription = None
|
||||
modelName = None
|
||||
modelNumber = None
|
||||
modelURL = None
|
||||
|
@ -58,6 +84,12 @@ class Device(CaseInsensitive):
|
|||
iconList = None
|
||||
|
||||
def __init__(self, devices: List, services: List, **kwargs) -> None:
|
||||
"""Device().
|
||||
|
||||
:param devices:
|
||||
:param services:
|
||||
:param kwargs:
|
||||
"""
|
||||
super(Device, self).__init__(**kwargs)
|
||||
if self.serviceList and "service" in self.serviceList:
|
||||
new_services = self.serviceList["service"]
|
||||
|
|
|
@ -1,12 +1,17 @@
|
|||
from aioupnp.util import flatten_keys
|
||||
from aioupnp.constants import FAULT, CONTROL
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
class UPnPError(Exception):
|
||||
"""UPnPError."""
|
||||
pass
|
||||
|
||||
|
||||
def handle_fault(response: dict) -> dict:
|
||||
def handle_fault(response: Dict) -> Any[dict, Optional[UPnPError]]:
|
||||
"""Handle Fault.
|
||||
|
||||
:param dict response: Response
|
||||
"""
|
||||
if FAULT in response:
|
||||
fault = flatten_keys(response[FAULT], "{%s}" % CONTROL)
|
||||
error_description = fault['detail']['UPnPError']['errorDescription']
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
import logging
|
||||
import socket
|
||||
from typing import Dict, List, Union, Type, Any, Optional, Set, Awaitable, TYPE_CHECKING, NoReturn
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Union, Type
|
||||
if TYPE_CHECKING:
|
||||
from asyncio import AbstractEventLoop, TimeoutError
|
||||
from collections import OrderedDict
|
||||
|
||||
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
|
||||
from aioupnp.constants import SPEC_VERSION, SERVICE
|
||||
from aioupnp.commands import SOAPCommands
|
||||
|
@ -20,7 +23,12 @@ return_type_lambas = {
|
|||
}
|
||||
|
||||
|
||||
def get_action_list(element_dict: dict) -> List: # [(<method>, [<input1>, ...], [<output1, ...]), ...]
|
||||
def get_action_list(element_dict: Dict) -> List: # [(<method>, [<input1>, ...], [<output1, ...]), ...]
|
||||
"""Get Action List.
|
||||
|
||||
:param element_dict:
|
||||
:return:
|
||||
"""
|
||||
service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
|
||||
if "actionList" in service_info:
|
||||
action_list = service_info["actionList"]
|
||||
|
@ -55,51 +63,68 @@ def get_action_list(element_dict: dict) -> List: # [(<method>, [<input1>, ...],
|
|||
|
||||
|
||||
class Gateway:
|
||||
"""Gateway."""
|
||||
|
||||
def __init__(self, ok_packet: SSDPDatagram, m_search_args: OrderedDict, lan_address: str,
|
||||
gateway_address: str) -> None:
|
||||
self._ok_packet = ok_packet
|
||||
self._m_search_args = m_search_args
|
||||
self._lan_address = lan_address
|
||||
self.usn = (ok_packet.usn or '').encode()
|
||||
self.ext = (ok_packet.ext or '').encode()
|
||||
self.server = (ok_packet.server or '').encode()
|
||||
self.location = (ok_packet.location or '').encode()
|
||||
self.cache_control = (ok_packet.cache_control or '').encode()
|
||||
self.date = (ok_packet.date or '').encode()
|
||||
self.urn = (ok_packet.st or '').encode()
|
||||
"""Gateway object.
|
||||
|
||||
self._xml_response = b""
|
||||
:param ok_packet:
|
||||
:param m_search_args:
|
||||
:param lan_address:
|
||||
:param gateway_address:
|
||||
"""
|
||||
self._ok_packet: SSDPDatagram = ok_packet
|
||||
self._m_search_args: OrderedDict = m_search_args
|
||||
self._lan_address: str = lan_address
|
||||
self.usn: bytes = (ok_packet.usn or '').encode()
|
||||
self.ext: bytes = (ok_packet.ext or '').encode()
|
||||
self.server: bytes = (ok_packet.server or '').encode()
|
||||
self.location: bytes = (ok_packet.location or '').encode()
|
||||
self.cache_control: bytes = (ok_packet.cache_control or '').encode()
|
||||
self.date: bytes = (ok_packet.date or '').encode()
|
||||
self.urn: bytes = (ok_packet.st or '').encode()
|
||||
|
||||
self._xml_response: bytes = b''
|
||||
self._service_descriptors: Dict = {}
|
||||
self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0]
|
||||
self.port = int(BASE_PORT_REGEX.findall(self.location)[0])
|
||||
self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0]
|
||||
self.base_address: bytes = BASE_ADDRESS_REGEX.findall(self.location)[0]
|
||||
self.port: int = int(BASE_PORT_REGEX.findall(self.location)[0])
|
||||
self.base_ip: bytes = self.base_address.lstrip(b'http://').split(b':')[0]
|
||||
assert self.base_ip == gateway_address.encode()
|
||||
self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1]
|
||||
self.path: bytes = self.location.split(b'%s:%i/' % (self.base_ip, self.port))[1]
|
||||
|
||||
self.spec_version = None
|
||||
self.url_base = None
|
||||
|
||||
self._device: Union[None, Device] = None
|
||||
self._device: Any[None, Optional[Device]] = None
|
||||
self._devices: List = []
|
||||
self._services: List = []
|
||||
|
||||
self._unsupported_actions: Dict = {}
|
||||
self._registered_commands: Dict = {}
|
||||
self.commands = SOAPCommands()
|
||||
self.commands = SOAPCommands()
|
||||
|
||||
def gateway_descriptor(self) -> dict:
|
||||
def gateway_descriptor(self) -> Dict:
|
||||
"""Gateway Descriptor.
|
||||
|
||||
:return: dict
|
||||
"""
|
||||
r = {
|
||||
'server': self.server.decode(),
|
||||
'urlBase': self.url_base,
|
||||
'location': self.location.decode(),
|
||||
"server": self.server.decode(),
|
||||
"urlBase": self.url_base,
|
||||
"location": self.location.decode(),
|
||||
"specVersion": self.spec_version,
|
||||
'usn': self.usn.decode(),
|
||||
'urn': self.urn.decode(),
|
||||
"usn": self.usn.decode(),
|
||||
"urn": self.urn.decode(),
|
||||
}
|
||||
return r
|
||||
|
||||
@property
|
||||
def manufacturer_string(self) -> str:
|
||||
"""Manufacturer string.
|
||||
|
||||
:return str: Manufacturer string.
|
||||
"""
|
||||
if not self.devices:
|
||||
return "UNKNOWN GATEWAY"
|
||||
device = list(self.devices.values())[0]
|
||||
|
@ -107,17 +132,25 @@ class Gateway:
|
|||
|
||||
@property
|
||||
def services(self) -> Dict:
|
||||
"""Services.
|
||||
|
||||
:return dict: Services.
|
||||
"""
|
||||
if not self._device:
|
||||
return {}
|
||||
return {service.serviceType: service for service in self._services}
|
||||
|
||||
@property
|
||||
def devices(self) -> Dict:
|
||||
"""Devices
|
||||
|
||||
:return dict: Devices.
|
||||
"""
|
||||
if not self._device:
|
||||
return {}
|
||||
return {device.udn: device for device in self._devices}
|
||||
|
||||
def get_service(self, service_type: str) -> Union[Type[Service], None]:
|
||||
def get_service(self, service_type: str) -> Any[Type[Service], None]:
|
||||
for service in self._services:
|
||||
if service.serviceType.lower() == service_type.lower():
|
||||
return service
|
||||
|
@ -140,29 +173,27 @@ class Gateway:
|
|||
|
||||
def debug_gateway(self) -> Dict:
|
||||
return {
|
||||
'manufacturer_string': self.manufacturer_string,
|
||||
'gateway_address': self.base_ip,
|
||||
'gateway_descriptor': self.gateway_descriptor(),
|
||||
'gateway_xml': self._xml_response,
|
||||
'services_xml': self._service_descriptors,
|
||||
'services': {service.SCPDURL: service.as_dict() for service in self._services},
|
||||
'm_search_args': [(k, v) for (k, v) in self._m_search_args.items()],
|
||||
'reply': self._ok_packet.as_dict(),
|
||||
'soap_port': self.port,
|
||||
'registered_soap_commands': self._registered_commands,
|
||||
'unsupported_soap_commands': self._unsupported_actions,
|
||||
'soap_requests': self.soap_requests
|
||||
"manufacturer_string": self.manufacturer_string,
|
||||
"gateway_address": self.base_ip,
|
||||
"gateway_descriptor": self.gateway_descriptor(),
|
||||
"gateway_xml": self._xml_response,
|
||||
"services_xml": self._service_descriptors,
|
||||
"services": {service.SCPDURL: service.as_dict() for service in self._services},
|
||||
"m_search_args": [(k, v) for (k, v) in self._m_search_args.items()],
|
||||
"reply": self._ok_packet.as_dict(),
|
||||
"soap_port": self.port,
|
||||
"registered_soap_commands": self._registered_commands,
|
||||
"unsupported_soap_commands": self._unsupported_actions,
|
||||
"soap_requests": self.soap_requests
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
igd_args: OrderedDict = None, loop=None, unicast: bool = False):
|
||||
ignored: set = set()
|
||||
required_commands = [
|
||||
'AddPortMapping',
|
||||
'DeletePortMapping',
|
||||
'GetExternalIPAddress'
|
||||
]
|
||||
igd_args: Any[Optional[OrderedDict], None] = None,
|
||||
loop: Any[Optional[AbstractEventLoop], None] = None,
|
||||
unicast: bool = False) -> Any[__class__, None]:
|
||||
ignored: Set = set()
|
||||
required_commands = ["AddPortMapping", "DeletePortMapping", "GetExternalIPAddress"]
|
||||
while True:
|
||||
if not igd_args:
|
||||
m_search_args, datagram = await fuzzy_m_search(
|
||||
|
@ -173,28 +204,30 @@ class Gateway:
|
|||
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, ignored, unicast)
|
||||
try:
|
||||
gateway = cls(datagram, m_search_args, lan_address, gateway_address)
|
||||
log.debug('get gateway descriptor %s', datagram.location)
|
||||
log.debug("Get gateway descriptor %s.", datagram.location)
|
||||
await gateway.discover_commands(loop)
|
||||
requirements_met = all([required in gateway._registered_commands for required in required_commands])
|
||||
if not requirements_met:
|
||||
not_met = [
|
||||
required for required in required_commands if required not in gateway._registered_commands
|
||||
]
|
||||
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
|
||||
log.debug("Found gateway %s at %s, however it does not implement required soap commands: %s.",
|
||||
gateway.manufacturer_string, gateway.location, not_met)
|
||||
ignored.add(datagram.location)
|
||||
continue
|
||||
else:
|
||||
log.debug('found gateway device %s', datagram.location)
|
||||
log.debug("Found gateway device %s.", datagram.location)
|
||||
return gateway
|
||||
except (asyncio.TimeoutError, UPnPError) as err:
|
||||
log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err))
|
||||
except (TimeoutError, UPnPError) as err:
|
||||
log.debug("Get %s failed (%s), looking for other devices.", datagram.location, str(err))
|
||||
ignored.add(datagram.location)
|
||||
continue
|
||||
|
||||
@classmethod
|
||||
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
igd_args: OrderedDict = None, loop=None, unicast: bool = None):
|
||||
igd_args: Any[Optional[OrderedDict], None] = None,
|
||||
loop = Any[Optional[AbstractEventLoop], None],
|
||||
unicast: Any[Optional[bool], None] = None) -> __class__():
|
||||
if unicast is not None:
|
||||
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast)
|
||||
|
||||
|
@ -217,7 +250,7 @@ class Gateway:
|
|||
|
||||
return list(done)[0].result()
|
||||
|
||||
async def discover_commands(self, loop=None):
|
||||
async def discover_commands(self, loop: Any[Optional[AbstractEventLoop], None] = None) -> NoReturn:
|
||||
response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port, loop=loop)
|
||||
self._xml_response = xml_bytes
|
||||
if get_err is not None:
|
||||
|
@ -236,9 +269,9 @@ class Gateway:
|
|||
for service_type in self.services.keys():
|
||||
await self.register_commands(self.services[service_type], loop)
|
||||
|
||||
async def register_commands(self, service: Service, loop=None):
|
||||
async def register_commands(self, service: Service, loop: Any[Optional[AbstractEventLoop], None] = None) -> Any[None, Optional[UPnPError]]:
|
||||
if not service.SCPDURL:
|
||||
raise UPnPError("no scpd url")
|
||||
raise UPnPError("No scpd url.")
|
||||
|
||||
log.debug("get descriptor for %s from %s", service.serviceType, service.SCPDURL)
|
||||
service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port)
|
||||
|
@ -247,7 +280,7 @@ class Gateway:
|
|||
if get_err is not None:
|
||||
log.debug("failed to get descriptor for %s from %s", service.serviceType, service.SCPDURL)
|
||||
if xml_bytes:
|
||||
log.debug("response: %s", xml_bytes.decode())
|
||||
log.debug("Response: %s.", xml_bytes)
|
||||
return
|
||||
if not service_dict:
|
||||
return
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue