From 3ab4e1d887c28caa0e756c0e2ce0da8eb3846fa5 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Wed, 17 Oct 2018 18:57:02 -0400 Subject: [PATCH] handle multiple gateways replying on the same physical device -try sending to both the router address and the multicast address --- aioupnp/gateway.py | 51 +++++++++++++++++++++++------ aioupnp/protocols/ssdp.py | 69 +++++++++++++++++++++++++-------------- 2 files changed, 85 insertions(+), 35 deletions(-) diff --git a/aioupnp/gateway.py b/aioupnp/gateway.py index 6bbf48d..5186e56 100644 --- a/aioupnp/gateway.py +++ b/aioupnp/gateway.py @@ -1,7 +1,8 @@ import logging import socket +import asyncio from collections import OrderedDict -from typing import Dict, List, Union, Type +from typing import Dict, List, Union, Type, Set from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX from aioupnp.constants import SPEC_VERSION, SERVICE from aioupnp.commands import SOAPCommands @@ -144,18 +145,48 @@ class Gateway: 'soap_requests': self._soap_requests } + @classmethod + async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, + igd_args: OrderedDict = None, ssdp_socket: socket.socket = None, + soap_socket: socket.socket = None, unicast: bool = False): + ignored: set = set() + while True: + if not igd_args: + m_search_args, datagram = await asyncio.wait_for(fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket, + ignored, unicast), timeout) + else: + m_search_args = OrderedDict(igd_args) + datagram = await m_search(lan_address, gateway_address, igd_args, timeout, ssdp_socket, ignored, + unicast) + try: + gateway = cls(datagram, m_search_args, lan_address, gateway_address) + await gateway.discover_commands(soap_socket) + log.debug('found gateway device %s', datagram.location) + return gateway + except asyncio.TimeoutError: + log.debug("get %s timed out, looking for other devices", datagram.location) + ignored.add(datagram.location) + continue + @classmethod async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, igd_args: OrderedDict = None, ssdp_socket: socket.socket = None, - soap_socket: socket.socket = None): - if not igd_args: - m_search_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket) - else: - m_search_args = OrderedDict(igd_args) - datagram = await m_search(lan_address, gateway_address, igd_args, timeout, ssdp_socket) - gateway = cls(datagram, m_search_args, lan_address, gateway_address) - await gateway.discover_commands(soap_socket) - return gateway + soap_socket: socket.socket = None, unicast: bool = None): + if unicast is not None: + return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, ssdp_socket, + soap_socket, unicast=unicast) + done, pending = await asyncio.wait([ + cls._discover_gateway( + lan_address, gateway_address, timeout, igd_args, ssdp_socket, soap_socket, unicast=True + ), + cls._discover_gateway( + lan_address, gateway_address, timeout, igd_args, ssdp_socket, soap_socket, unicast=False + )], return_when=asyncio.tasks.FIRST_COMPLETED + ) + for task in list(pending): + task.cancel() + result = list(done)[0].result() + return result async def discover_commands(self, soap_socket: socket.socket = None): response, xml_bytes = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port) diff --git a/aioupnp/protocols/ssdp.py b/aioupnp/protocols/ssdp.py index 6c89d3a..6fd4437 100644 --- a/aioupnp/protocols/ssdp.py +++ b/aioupnp/protocols/ssdp.py @@ -19,14 +19,25 @@ log = logging.getLogger(__name__) class SSDPProtocol(MulticastProtocol): - def __init__(self, multicast_address: str, lan_address: str) -> None: + def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Set[str] = None, + unicast: bool = False) -> None: super().__init__(multicast_address, lan_address) - self.lan_address = lan_address + self._unicast = unicast + self._ignored: typing.Set[str] = ignored or set() # ignored locations self._pending_searches: typing.List[typing.Tuple[str, str, Future, asyncio.Handle]] = [] - self.notifications: typing.List = [] + def disconnect(self): + if self.transport: + self.transport.close() + while self._pending_searches: + pending = self._pending_searches.pop()[2] + if not pending.cancelled() and not pending.done(): + pending.cancel() + def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None: + if packet.location in self._ignored: + return tmp: typing.List = [] set_futures: typing.List = [] while self._pending_searches: @@ -34,8 +45,8 @@ class SSDPProtocol(MulticastProtocol): a, s = t[0], t[1] if (address == a) and (s in [packet.st, "upnp:rootdevice"]): f: Future = t[2] - h: asyncio.Handle = t[3] - h.cancel() + # h: asyncio.Handle = t[3] + # h.cancel() if f not in set_futures: set_futures.append(f) if not f.done(): @@ -46,9 +57,10 @@ class SSDPProtocol(MulticastProtocol): self._pending_searches.append(tmp.pop()) def send_many_m_searches(self, address: str, packets: typing.List[SSDPDatagram]): + dest = address if self._unicast else SSDP_IP_ADDRESS for packet in packets: - log.debug("send m search to %s: %s", address, packet.st) - self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT)) + log.debug("send m search to %s: %s", dest, packet.st) + self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT)) async def m_search(self, address: str, timeout: float, datagrams: typing.List[OrderedDict]) -> SSDPDatagram: fut: Future = Future() @@ -56,14 +68,14 @@ class SSDPProtocol(MulticastProtocol): for datagram in datagrams: packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram) assert packet.st is not None - h = asyncio.get_running_loop().call_later(timeout, fut.cancel) - self._pending_searches.append((address, packet.st, fut, h)) + # h = asyncio.get_running_loop().call_later(timeout, fut.cancel) + self._pending_searches.append((address, packet.st, fut)) packets.append(packet) self.send_many_m_searches(address, packets), return await fut def datagram_received(self, data, addr) -> None: - if addr[0] == self.lan_address: + if addr[0] == self.bind_address: return try: packet = SSDPDatagram.decode(data) @@ -96,14 +108,14 @@ class SSDPProtocol(MulticastProtocol): # return -async def listen_ssdp(lan_address: str, gateway_address: str, - ssdp_socket: socket.socket = None) -> typing.Tuple[DatagramTransport, SSDPProtocol, - str, str]: +async def listen_ssdp(lan_address: str, gateway_address: str, ssdp_socket: socket.socket = None, + ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[DatagramTransport, + SSDPProtocol, str, str]: loop = asyncio.get_running_loop() try: sock = ssdp_socket or SSDPProtocol.create_multicast_socket(lan_address) listen_result: typing.Tuple = await loop.create_datagram_endpoint( - lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address), sock=sock + lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock ) transport: DatagramTransport = listen_result[0] protocol: SSDPProtocol = listen_result[1] @@ -113,29 +125,31 @@ async def listen_ssdp(lan_address: str, gateway_address: str, protocol.join_group(protocol.multicast_address, protocol.bind_address) protocol.set_ttl(1) except Exception as err: - transport.close() + protocol.disconnect() raise UPnPError(err) return transport, protocol, gateway_address, lan_address async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1, - ssdp_socket: socket.socket = None) -> SSDPDatagram: + ssdp_socket: socket.socket = None, ignored: typing.Set[str] = None, + unicast: bool = False) -> SSDPDatagram: transport, protocol, gateway_address, lan_address = await listen_ssdp( - lan_address, gateway_address, ssdp_socket + lan_address, gateway_address, ssdp_socket, ignored, unicast ) try: return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]) except (asyncio.TimeoutError, asyncio.CancelledError): raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) finally: - transport.close() + protocol.disconnect() async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, - ssdp_socket: socket.socket = None) -> typing.List[OrderedDict]: + ssdp_socket: socket.socket = None, + ignored: typing.Set[str] = None, unicast: bool = False) -> typing.List[OrderedDict]: transport, protocol, gateway_address, lan_address = await listen_ssdp( - lan_address, gateway_address, ssdp_socket + lan_address, gateway_address, ssdp_socket, ignored, unicast ) packet_args = list(packet_generator()) batch_size = 2 @@ -145,21 +159,26 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = packet_args = packet_args[batch_size:] log.debug("sending batch of %i M-SEARCH attempts", batch_size) try: - await protocol.m_search(gateway_address, batch_timeout, args) + await asyncio.wait_for(protocol.m_search(gateway_address, batch_timeout, args), timeout) + protocol.disconnect() return args - except (asyncio.TimeoutError, asyncio.CancelledError): + except asyncio.TimeoutError: continue + protocol.disconnect() raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, - ssdp_socket: socket.socket = None) -> typing.Tuple[OrderedDict, SSDPDatagram]: + ssdp_socket: socket.socket = None, + ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[OrderedDict, + SSDPDatagram]: # we don't know which packet the gateway replies to, so send small batches at a time - args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket) + + args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket, ignored, unicast) # check the args in the batch that got a reply one at a time to see which one worked for args in args_to_try: try: - packet = await m_search(lan_address, gateway_address, args, 3) + packet = await m_search(lan_address, gateway_address, args, 3, ignored=ignored, unicast=unicast) return args, packet except UPnPError: continue