aioupnp/aioupnp/gateway.py

283 lines
13 KiB
Python

import logging
import socket
import asyncio
from collections import OrderedDict
from typing import Dict, List, Union, Type
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from aioupnp.constants import SPEC_VERSION, SERVICE
from aioupnp.commands import SOAPCommands
from aioupnp.device import Device, Service
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search
from aioupnp.protocols.scpd import scpd_get
from aioupnp.protocols.soap import SOAPCommand
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.util import flatten_keys
from aioupnp.fault import UPnPError
log = logging.getLogger(__name__)


def _none_if_null_string(value):
    """Return ``None`` for ``None`` or the strings "none"/"nil" (any case), else ``value`` unchanged."""
    if value is None:
        return None
    if str(value).lower() in ('none', 'nil'):
        return None
    return value


# Converters keyed by an annotated return type. Gateway.register_commands substitutes these
# for matching entries of a command's return annotation before building its SOAPCommand.
return_type_lambas = {
    Union[None, str]: _none_if_null_string,
}
def get_action_list(element_dict: dict) -> List:  # [(<method>, [<input1>, ...], [<output1, ...]), ...]
    """Extract the actions declared in a parsed SCPD service descriptor.

    :param element_dict: the parsed service descriptor (as returned by ``scpd_get``); its keys are
        flattened with the ``{SERVICE}`` namespace prefix before ``actionList`` is looked up.
    :return: one ``(name, input_argument_names, output_argument_names)`` tuple per action, or an
        empty list when the descriptor declares no actions.
    """
    service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
    action_list = service_info.get("actionList")
    if not action_list:  # absent, None, or an empty element parsed as an empty string
        return []
    actions = action_list.get("action")
    if not actions:
        return []
    if isinstance(actions, dict):  # a single <action> is parsed as a dict rather than a list
        actions = [actions]
    result: list = []
    for action in actions:
        argument_list = action.get('argumentList')
        # an action may have no <argumentList> at all, or an empty one
        arg_dicts = (argument_list.get('argument') if argument_list else None) or []
        if not isinstance(arg_dicts, list):  # a single <argument> is parsed as a dict
            arg_dicts = [arg_dicts]
        result.append((
            action['name'],
            [arg['name'] for arg in arg_dicts if arg['direction'] == 'in'],
            [arg['name'] for arg in arg_dicts if arg['direction'] == 'out'],
        ))
    return result
class Gateway:
    """A UPnP gateway located through SSDP.

    Keeps the fields of the SSDP reply that located the gateway, fetches its device and service
    descriptors, and registers a ``SOAPCommand`` on ``self.commands`` for every advertised action
    that ``SOAPCommands`` has a wrapper method for.
    """

    def __init__(self, ok_packet: SSDPDatagram, m_search_args: OrderedDict, lan_address: str,
                 gateway_address: str) -> None:
        """Build a gateway from an SSDP M-SEARCH reply.

        :param ok_packet: the SSDP reply; its header fields are stored as bytes attributes.
        :param m_search_args: the M-SEARCH arguments that produced ``ok_packet``.
        :param lan_address: local address, stored as ``_lan_address``.
        :param gateway_address: expected gateway IP; must equal the host in the reply's location.
        :raises UPnPError: if the location cannot be parsed or names a host other than
            ``gateway_address``.
        """
        self._ok_packet = ok_packet
        self._m_search_args = m_search_args
        self._lan_address = lan_address
        self.usn = (ok_packet.usn or '').encode()
        self.ext = (ok_packet.ext or '').encode()
        self.server = (ok_packet.server or '').encode()
        self.location = (ok_packet.location or '').encode()
        self.cache_control = (ok_packet.cache_control or '').encode()
        self.date = (ok_packet.date or '').encode()
        self.urn = (ok_packet.st or '').encode()
        self._xml_response = b""
        self._service_descriptors: Dict = {}

        base_addresses = BASE_ADDRESS_REGEX.findall(self.location)
        ports = BASE_PORT_REGEX.findall(self.location)
        if not base_addresses or not ports:
            raise UPnPError("failed to parse gateway location %s" % self.location.decode())
        self.base_address = base_addresses[0]
        self.port = int(ports[0])
        # remove the scheme as a prefix: bytes.lstrip() strips a *set of characters*, not a prefix
        host = self.base_address
        if host.startswith(b"http://"):
            host = host[len(b"http://"):]
        self.base_ip = host.split(b":")[0]
        # raised (not asserted) so _discover_gateway can skip the reply instead of aborting
        if self.base_ip != gateway_address.encode():
            raise UPnPError("gateway location %s does not match gateway address %s"
                            % (self.location.decode(), gateway_address))
        _, separator, self.path = self.location.partition(b"%s:%i/" % (self.base_ip, self.port))
        if not separator:
            raise UPnPError("failed to find descriptor path in gateway location %s" % self.location.decode())

        self.spec_version = None
        self.url_base = None
        self._device: Union[None, Device] = None
        self._devices: List = []
        self._services: List = []
        self._unsupported_actions: Dict = {}
        self._registered_commands: Dict = {}
        self.commands = SOAPCommands()

    def gateway_descriptor(self) -> dict:
        """Return the SSDP/descriptor summary fields of this gateway as a plain dict of strings."""
        return {
            'server': self.server.decode(),
            'urlBase': self.url_base,
            'location': self.location.decode(),
            "specVersion": self.spec_version,
            'usn': self.usn.decode(),
            'urn': self.urn.decode(),
        }

    @property
    def manufacturer_string(self) -> str:
        """"<manufacturer> <modelName>" of the first known device, or "UNKNOWN GATEWAY"."""
        devices = self.devices
        if not devices:
            return "UNKNOWN GATEWAY"
        device = next(iter(devices.values()))
        return "%s %s" % (device.manufacturer, device.modelName)

    @property
    def services(self) -> Dict:
        """Known services keyed by ``serviceType``; empty until the root device has been parsed."""
        if not self._device:
            return {}
        return {service.serviceType: service for service in self._services}

    @property
    def devices(self) -> Dict:
        """Known devices keyed by ``udn``; empty until the root device has been parsed."""
        if not self._device:
            return {}
        return {device.udn: device for device in self._devices}

    def get_service(self, service_type: str) -> Union[Service, None]:
        """Return the first service whose type matches ``service_type`` case-insensitively, else None."""
        wanted = service_type.lower()
        for service in self._services:
            if service.serviceType.lower() == wanted:
                return service
        return None

    @property
    def soap_requests(self) -> List:
        """All recorded SOAP calls of registered commands, sorted by their timestamp.

        Each entry is ``(name, request_args, raw_response, decoded_response, soap_error, ts)``.
        """
        soap_call_infos = []
        for name in self._registered_commands:
            command = getattr(self.commands, name)
            if not hasattr(command, "_requests"):
                continue
            soap_call_infos.extend(
                (name, request_args, raw_response, decoded_response, soap_error, ts)
                for request_args, raw_response, decoded_response, soap_error, ts in command._requests
            )
        soap_call_infos.sort(key=lambda info: info[5])
        return soap_call_infos

    def debug_gateway(self) -> Dict:
        """Collect everything known about this gateway (replies, descriptors, commands) for debugging."""
        return {
            'manufacturer_string': self.manufacturer_string,
            'gateway_address': self.base_ip,
            'gateway_descriptor': self.gateway_descriptor(),
            'gateway_xml': self._xml_response,
            'services_xml': self._service_descriptors,
            'services': {service.SCPDURL: service.as_dict() for service in self._services},
            'm_search_args': list(self._m_search_args.items()),
            'reply': self._ok_packet.as_dict(),
            'soap_port': self.port,
            'registered_soap_commands': self._registered_commands,
            'unsupported_soap_commands': self._unsupported_actions,
            'soap_requests': self.soap_requests
        }

    @classmethod
    async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
                                igd_args: OrderedDict = None, ssdp_socket: socket.socket = None,
                                soap_socket: socket.socket = None, unicast: bool = False):
        """Search until a gateway implementing the required SOAP commands answers, and return it.

        Uses ``fuzzy_m_search`` unless ``igd_args`` is given, in which case ``m_search`` is used
        with those arguments. Replies whose descriptors fail with ``asyncio.TimeoutError`` or
        ``UPnPError``, or that lack a required command, are added to ``ignored`` and the search
        continues. Errors raised by the search itself propagate to the caller.
        """
        ignored: set = set()
        required_commands = [
            'AddPortMapping',
            'DeletePortMapping',
            'GetExternalIPAddress'
        ]
        while True:
            if not igd_args:
                m_search_args, datagram = await asyncio.wait_for(
                    fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket, ignored, unicast),
                    timeout
                )
            else:
                m_search_args = OrderedDict(igd_args)
                datagram = await m_search(lan_address, gateway_address, igd_args, timeout, ssdp_socket,
                                          ignored, unicast)
            try:
                gateway = cls(datagram, m_search_args, lan_address, gateway_address)
                log.debug('get gateway descriptor %s', datagram.location)
                await gateway.discover_commands(soap_socket)
            except (asyncio.TimeoutError, UPnPError) as err:
                log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err))
                ignored.add(datagram.location)
                continue
            not_met = [
                required for required in required_commands if required not in gateway._registered_commands
            ]
            if not_met:
                log.warning("found gateway %s at %s, but it does not implement required soap commands: %s",
                            gateway.manufacturer_string, gateway.location, not_met)
                ignored.add(datagram.location)
                continue
            log.debug('found gateway device %s', datagram.location)
            return gateway

    @classmethod
    async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
                               igd_args: OrderedDict = None, ssdp_socket: socket.socket = None,
                               soap_socket: socket.socket = None, unicast: bool = None):
        """Discover a usable gateway.

        When ``unicast`` is given, search with that mode only. Otherwise run a unicast and a
        multicast search concurrently and return the first gateway either one finds; if both
        fail, re-raise the error of the search that failed first. Any still-running search is
        cancelled before returning, including when this coroutine itself is cancelled.
        """
        if unicast is not None:
            return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, ssdp_socket,
                                               soap_socket, unicast=unicast)
        # asyncio.wait() must be given tasks: bare coroutines were deprecated in 3.8 and rejected in 3.11
        tasks = [
            asyncio.ensure_future(cls._discover_gateway(
                lan_address, gateway_address, timeout, igd_args, ssdp_socket, soap_socket, unicast=use_unicast
            ))
            for use_unicast in (True, False)
        ]
        first_error = None
        pending = set(tasks)
        try:
            while pending:
                done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
                for task in tasks:
                    if task not in done:
                        continue
                    try:
                        return task.result()
                    except Exception as err:  # this transport failed; keep waiting for the other one
                        if first_error is None:
                            first_error = err
            raise first_error
        finally:
            for task in tasks:
                if not task.done():
                    task.cancel()

    async def discover_commands(self, soap_socket: socket.socket = None):
        """Fetch the root descriptor at ``self.path``, build the device tree and register commands.

        ``url_base`` falls back to ``base_address`` when the descriptor has no URLBase. Commands are
        then registered for every service in ``self.services``.

        :raises: the error reported by ``scpd_get`` for the root descriptor (presumably a
            ``UPnPError`` — TODO confirm against ``scpd_get``).
        """
        response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port)
        self._xml_response = xml_bytes
        if get_err is not None:
            raise get_err
        self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
        self.url_base = get_dict_val_case_insensitive(response, "urlbase")
        if not self.url_base:
            self.url_base = self.base_address.decode()
        if response:
            # a descriptor without a <device> element must not turn into `**None`
            device_info = get_dict_val_case_insensitive(response, "device") or {}
            self._device = Device(self._devices, self._services, **device_info)
        else:
            self._device = Device(self._devices, self._services)
        for service in self.services.values():
            await self.register_commands(service, soap_socket)

    async def register_commands(self, service: Service, soap_socket: socket.socket = None):
        """Fetch ``service``'s SCPD descriptor and register a SOAPCommand for each wrapped action.

        Actions without a wrapper method on ``SOAPCommands`` are recorded in
        ``_unsupported_actions``. Returns without registering anything if the descriptor cannot be
        fetched or parses to nothing.

        :raises UPnPError: if the service has no SCPD URL.
        """
        if not service.SCPDURL:
            raise UPnPError("no scpd url")
        log.debug("get descriptor for %s from %s", service.serviceType, service.SCPDURL)
        service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port)
        self._service_descriptors[service.SCPDURL] = xml_bytes
        if get_err is not None:
            log.debug("failed to get descriptor for %s from %s", service.serviceType, service.SCPDURL)
            if xml_bytes:
                # a malformed reply must not make the debug log itself raise
                log.debug("response: %s", xml_bytes.decode(errors="replace"))
            return
        if not service_dict:
            return
        for name, inputs, outputs in get_action_list(service_dict):
            # only the wrapper lookup is guarded, so an AttributeError from building the command
            # is not misreported as a missing wrapper
            try:
                current = getattr(self.commands, name)
                annotations = current.__annotations__
            except AttributeError:
                self._unsupported_actions.setdefault(service.serviceType, []).append(name)
                log.debug("available command for %s does not have a wrapper implemented: %s %s %s",
                          service.serviceType, name, inputs, outputs)
                continue
            return_types = annotations.get('return', None)
            if return_types:
                if hasattr(return_types, '__args__'):
                    return_types = tuple(return_type_lambas.get(a, a) for a in return_types.__args__)
                elif isinstance(return_types, type):
                    return_types = (return_types, )
                # pair each declared output argument with its annotated type, in order
                return_types = dict(zip(outputs, return_types))
            param_types = {
                param_name: param_type
                for param_name, param_type in annotations.items() if param_name != "return"
            }
            command = SOAPCommand(
                self.base_ip.decode(), self.port, service.controlURL, service.serviceType.encode(),
                name, param_types, return_types, inputs, outputs, soap_socket)
            setattr(command, "__doc__", current.__doc__)
            setattr(self.commands, command.method, command)
            self._registered_commands[command.method] = service.serviceType
            log.debug("registered %s::%s", service.serviceType, command.method)
        log.debug("registered service %s", service.serviceType)