aioupnp/aioupnp/gateway.py
2020-10-13 15:22:49 -04:00

304 lines
14 KiB
Python

import re
import logging
import typing
import asyncio
from typing import Dict, List, Optional
from aioupnp.util import get_dict_val_case_insensitive
from aioupnp.constants import SPEC_VERSION, SERVICE
from aioupnp.commands import SOAPCommands
from aioupnp.device import Device, Service
from aioupnp.protocols.ssdp import m_search, multi_m_search
from aioupnp.protocols.scpd import scpd_get
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.util import flatten_keys
from aioupnp.fault import UPnPError
log = logging.getLogger(__name__)
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
def get_action_list(element_dict: typing.Dict[str, typing.Union[str, typing.Dict[str, str],
typing.List[typing.Dict[str, typing.Dict[str, str]]]]]
) -> typing.List[typing.Tuple[str, typing.List[str], typing.List[str]]]:
service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
result: typing.List[typing.Tuple[str, typing.List[str], typing.List[str]]] = []
if "actionList" in service_info:
action_list = service_info["actionList"]
else:
return result
if not len(action_list): # it could be an empty string
return result
action = action_list["action"]
if isinstance(action, dict):
arg_dicts: typing.List[typing.Dict[str, str]] = []
if not isinstance(action['argumentList']['argument'], list): # when there is one arg
arg_dicts.extend([action['argumentList']['argument']])
else:
arg_dicts.extend(action['argumentList']['argument'])
result.append((action_list["action"]['name'], [i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']))
return result
assert isinstance(action, list)
for _action in action:
if not _action.get('argumentList'):
result.append((_action['name'], [], []))
else:
if not isinstance(_action['argumentList']['argument'], list): # when there is one arg
arg_dicts = [_action['argumentList']['argument']]
else:
arg_dicts = _action['argumentList']['argument']
result.append((
_action['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
))
return result
def parse_location(location: bytes) -> typing.Tuple[bytes, int]:
base_address_result: typing.List[bytes] = BASE_ADDRESS_REGEX.findall(location)
base_address = base_address_result[0]
port_result: typing.List[bytes] = BASE_PORT_REGEX.findall(location)
port = int(port_result[0])
return base_address, port
class Gateway:
def __init__(self, ok_packet: SSDPDatagram, lan_address: str, gateway_address: str,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self._loop = loop or asyncio.get_event_loop()
self._ok_packet = ok_packet
self._lan_address = lan_address
self.usn: bytes = (ok_packet.usn or '').encode()
self.ext: bytes = (ok_packet.ext or '').encode()
self.server: bytes = (ok_packet.server or '').encode()
self.location: bytes = (ok_packet.location or '').encode()
self.cache_control: bytes = (ok_packet.cache_control or '').encode()
self.date: bytes = (ok_packet.date or '').encode()
self.urn: bytes = (ok_packet.st or '').encode()
self._xml_response: bytes = b""
self._service_descriptors: Dict[str, str] = {}
self.base_address, self.port = parse_location(self.location)
self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0]
assert self.base_ip == gateway_address.encode()
self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1]
self.spec_version: Optional[str] = None
self.url_base: Optional[str] = None
self._device: Optional[Device] = None
self._devices: List[Device] = []
self._services: List[Service] = []
self._unsupported_actions: Dict[str, typing.List[str]] = {}
self._registered_commands: Dict[str, str] = {}
self.commands = SOAPCommands(self._loop, self.base_ip, self.port)
@property
def manufacturer_string(self) -> str:
manufacturer_string = "UNKNOWN GATEWAY"
if self.devices:
devices: typing.List[Device] = list(self.devices.values())
device = devices[0]
manufacturer_string = f"{device.manufacturer} {device.modelName}"
return manufacturer_string
@property
def services(self) -> Dict[str, Service]:
services: Dict[str, Service] = {}
if self._services:
for service in self._services:
if service.serviceType is not None:
services[service.serviceType] = service
return services
@property
def devices(self) -> Dict[str, Device]:
devices: Dict[str, Device] = {}
if self._device:
for device in self._devices:
if device.udn is not None:
devices[device.udn] = device
return devices
# def get_service(self, service_type: str) -> Optional[Service]:
# for service in self._services:
# if service.serviceType and service.serviceType.lower() == service_type.lower():
# return service
# return None
def debug_gateway(self) -> Dict[str, typing.Union[str, bytes, int, Dict, List]]:
return {
'manufacturer_string': self.manufacturer_string,
'gateway_address': self.base_ip.decode(),
'server': self.server.decode(),
'urlBase': self.url_base or '',
'location': self.location.decode(),
"specVersion": self.spec_version or '',
'usn': self.usn.decode(),
'urn': self.urn.decode(),
'gateway_xml': self._xml_response.decode(),
'services_xml': self._service_descriptors,
'services': {service.SCPDURL: service.as_dict() for service in self._services},
'reply': self._ok_packet.as_dict(),
'soap_port': self.port,
'registered_soap_commands': self._registered_commands,
'unsupported_soap_commands': self._unsupported_actions,
'soap_requests': list(self.commands._request_debug_infos)
}
@classmethod
async def _try_gateway_from_ssdp(cls, datagram: SSDPDatagram, lan_address: str,
gateway_address: str,
loop: Optional[asyncio.AbstractEventLoop] = None) -> Optional['Gateway']:
required_commands: typing.List[str] = [
'AddPortMapping',
'DeletePortMapping',
'GetExternalIPAddress'
]
try:
gateway = cls(datagram, lan_address, gateway_address, loop=loop)
log.debug('get gateway descriptor %s', datagram.location)
await gateway.discover_commands()
requirements_met = all([gateway.commands.is_registered(required) for required in required_commands])
if not requirements_met:
not_met = [
required for required in required_commands if not gateway.commands.is_registered(required)
]
assert datagram.location is not None
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
gateway.manufacturer_string, gateway.location, not_met)
return None
else:
log.debug('found gateway %s at %s', gateway.manufacturer_string or "device", datagram.location)
return gateway
except (asyncio.TimeoutError, UPnPError) as err:
assert datagram.location is not None
log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err))
return None
@classmethod
async def _gateway_from_igd_args(cls, lan_address: str, gateway_address: str,
igd_args: typing.Dict[str, typing.Union[int, str]],
timeout: int = 30,
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop)
gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop)
if not gateway:
raise UPnPError("no gateway found for given args")
return gateway
@classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 3,
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
ignored: typing.Set[str] = set()
ssdp_proto = await multi_m_search(
lan_address, gateway_address, timeout, loop
)
try:
while True:
datagram = await ssdp_proto.devices.get()
if datagram.location in ignored:
continue
gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop)
if gateway:
return gateway
elif datagram.location:
ignored.add(datagram.location)
finally:
ssdp_proto.disconnect()
@classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 3,
igd_args: Optional[typing.Dict[str, typing.Union[int, str]]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
loop = loop or asyncio.get_event_loop()
if igd_args:
return await cls._gateway_from_igd_args(lan_address, gateway_address, igd_args, timeout, loop)
try:
return await asyncio.wait_for(loop.create_task(
cls._discover_gateway(lan_address, gateway_address, timeout, loop)
), timeout, loop=loop)
except asyncio.TimeoutError:
raise UPnPError(f"M-SEARCH for {gateway_address}:1900 timed out")
async def discover_commands(self) -> None:
response, xml_bytes, get_err = await scpd_get(
self.path.decode(), self.base_ip.decode(), self.port, loop=self._loop
)
self._xml_response = xml_bytes
if get_err is not None:
raise get_err
spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
if isinstance(spec_version, bytes):
self.spec_version = spec_version.decode()
else:
self.spec_version = spec_version
url_base = get_dict_val_case_insensitive(response, "urlbase")
if isinstance(url_base, bytes):
self.url_base = url_base.decode()
else:
self.url_base = url_base
if not self.url_base:
self.url_base = self.base_address.decode()
if response:
source_keys: typing.List[str] = list(response.keys())
matches: typing.List[str] = list(filter(lambda x: x.lower() == "device", source_keys))
match_key = matches[0]
match: dict = response[match_key]
# if not len(match):
# return None
# if len(match) > 1:
# raise KeyError("overlapping keys")
# if len(match) == 1:
# matched_key: typing.AnyStr = match[0]
# return source[matched_key]
# raise KeyError("overlapping keys")
self._device = Device(
self._devices, self._services, **match
)
else:
self._device = Device(self._devices, self._services)
for service_type in self.services.keys():
await self.register_commands(self.services[service_type], self._loop)
return None
async def register_commands(self, service: Service,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
if not service.SCPDURL:
raise UPnPError("no scpd url")
if not service.serviceType:
raise UPnPError("no service type")
log.debug("get descriptor for %s from %s", service.serviceType, service.SCPDURL)
service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port, loop=loop)
self._service_descriptors[service.SCPDURL] = xml_bytes.decode()
if get_err is not None:
log.debug("failed to get descriptor for %s from %s", service.serviceType, service.SCPDURL)
if xml_bytes:
log.debug("response: %s", xml_bytes.decode())
return None
if not service_dict:
return None
action_list = get_action_list(service_dict)
for name, inputs, outputs in action_list:
try:
self.commands.register(name, service, inputs, outputs)
self._registered_commands[name] = service.serviceType
log.debug("registered %s::%s", service.serviceType, name)
except AttributeError:
self._unsupported_actions.setdefault(service.serviceType, [])
self._unsupported_actions[service.serviceType].append(name)
log.debug("available command for %s does not have a wrapper implemented: %s %s %s",
service.serviceType, name, inputs, outputs)
log.debug("registered service %s", service.serviceType)
return None