handle multiple gateways replying on the same physical device

-try sending to both the router address and the multicast address
This commit is contained in:
Jack Robison 2018-10-17 18:57:02 -04:00
parent a415943ddf
commit 3ab4e1d887
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
2 changed files with 85 additions and 35 deletions

View file

@ -1,7 +1,8 @@
import logging
import socket
import asyncio
from collections import OrderedDict
from typing import Dict, List, Union, Type
from typing import Dict, List, Union, Type, Set
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from aioupnp.constants import SPEC_VERSION, SERVICE
from aioupnp.commands import SOAPCommands
@ -144,18 +145,48 @@ class Gateway:
'soap_requests': self._soap_requests
}
@classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: OrderedDict = None, ssdp_socket: socket.socket = None,
soap_socket: socket.socket = None, unicast: bool = False):
ignored: set = set()
while True:
if not igd_args:
m_search_args, datagram = await asyncio.wait_for(fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket,
ignored, unicast), timeout)
else:
m_search_args = OrderedDict(igd_args)
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, ssdp_socket, ignored,
unicast)
try:
gateway = cls(datagram, m_search_args, lan_address, gateway_address)
await gateway.discover_commands(soap_socket)
log.debug('found gateway device %s', datagram.location)
return gateway
except asyncio.TimeoutError:
log.debug("get %s timed out, looking for other devices", datagram.location)
ignored.add(datagram.location)
continue
@classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: OrderedDict = None, ssdp_socket: socket.socket = None,
soap_socket: socket.socket = None):
if not igd_args:
m_search_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket)
else:
m_search_args = OrderedDict(igd_args)
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, ssdp_socket)
gateway = cls(datagram, m_search_args, lan_address, gateway_address)
await gateway.discover_commands(soap_socket)
return gateway
soap_socket: socket.socket = None, unicast: bool = None):
if unicast is not None:
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, ssdp_socket,
soap_socket, unicast=unicast)
done, pending = await asyncio.wait([
cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, ssdp_socket, soap_socket, unicast=True
),
cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, ssdp_socket, soap_socket, unicast=False
)], return_when=asyncio.tasks.FIRST_COMPLETED
)
for task in list(pending):
task.cancel()
result = list(done)[0].result()
return result
async def discover_commands(self, soap_socket: socket.socket = None):
response, xml_bytes = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port)

View file

@ -19,14 +19,25 @@ log = logging.getLogger(__name__)
class SSDPProtocol(MulticastProtocol):
def __init__(self, multicast_address: str, lan_address: str) -> None:
def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Set[str] = None,
unicast: bool = False) -> None:
super().__init__(multicast_address, lan_address)
self.lan_address = lan_address
self._unicast = unicast
self._ignored: typing.Set[str] = ignored or set() # ignored locations
self._pending_searches: typing.List[typing.Tuple[str, str, Future, asyncio.Handle]] = []
self.notifications: typing.List = []
def disconnect(self):
if self.transport:
self.transport.close()
while self._pending_searches:
pending = self._pending_searches.pop()[2]
if not pending.cancelled() and not pending.done():
pending.cancel()
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
if packet.location in self._ignored:
return
tmp: typing.List = []
set_futures: typing.List = []
while self._pending_searches:
@ -34,8 +45,8 @@ class SSDPProtocol(MulticastProtocol):
a, s = t[0], t[1]
if (address == a) and (s in [packet.st, "upnp:rootdevice"]):
f: Future = t[2]
h: asyncio.Handle = t[3]
h.cancel()
# h: asyncio.Handle = t[3]
# h.cancel()
if f not in set_futures:
set_futures.append(f)
if not f.done():
@ -46,9 +57,10 @@ class SSDPProtocol(MulticastProtocol):
self._pending_searches.append(tmp.pop())
def send_many_m_searches(self, address: str, packets: typing.List[SSDPDatagram]):
dest = address if self._unicast else SSDP_IP_ADDRESS
for packet in packets:
log.debug("send m search to %s: %s", address, packet.st)
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT))
log.debug("send m search to %s: %s", dest, packet.st)
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
async def m_search(self, address: str, timeout: float, datagrams: typing.List[OrderedDict]) -> SSDPDatagram:
fut: Future = Future()
@ -56,14 +68,14 @@ class SSDPProtocol(MulticastProtocol):
for datagram in datagrams:
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram)
assert packet.st is not None
h = asyncio.get_running_loop().call_later(timeout, fut.cancel)
self._pending_searches.append((address, packet.st, fut, h))
# h = asyncio.get_running_loop().call_later(timeout, fut.cancel)
self._pending_searches.append((address, packet.st, fut))
packets.append(packet)
self.send_many_m_searches(address, packets),
return await fut
def datagram_received(self, data, addr) -> None:
if addr[0] == self.lan_address:
if addr[0] == self.bind_address:
return
try:
packet = SSDPDatagram.decode(data)
@ -96,14 +108,14 @@ class SSDPProtocol(MulticastProtocol):
# return
async def listen_ssdp(lan_address: str, gateway_address: str,
ssdp_socket: socket.socket = None) -> typing.Tuple[DatagramTransport, SSDPProtocol,
str, str]:
async def listen_ssdp(lan_address: str, gateway_address: str, ssdp_socket: socket.socket = None,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[DatagramTransport,
SSDPProtocol, str, str]:
loop = asyncio.get_running_loop()
try:
sock = ssdp_socket or SSDPProtocol.create_multicast_socket(lan_address)
listen_result: typing.Tuple = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address), sock=sock
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
)
transport: DatagramTransport = listen_result[0]
protocol: SSDPProtocol = listen_result[1]
@ -113,29 +125,31 @@ async def listen_ssdp(lan_address: str, gateway_address: str,
protocol.join_group(protocol.multicast_address, protocol.bind_address)
protocol.set_ttl(1)
except Exception as err:
transport.close()
protocol.disconnect()
raise UPnPError(err)
return transport, protocol, gateway_address, lan_address
async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1,
ssdp_socket: socket.socket = None) -> SSDPDatagram:
ssdp_socket: socket.socket = None, ignored: typing.Set[str] = None,
unicast: bool = False) -> SSDPDatagram:
transport, protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, ssdp_socket
lan_address, gateway_address, ssdp_socket, ignored, unicast
)
try:
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
except (asyncio.TimeoutError, asyncio.CancelledError):
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
finally:
transport.close()
protocol.disconnect()
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
ssdp_socket: socket.socket = None) -> typing.List[OrderedDict]:
ssdp_socket: socket.socket = None,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.List[OrderedDict]:
transport, protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, ssdp_socket
lan_address, gateway_address, ssdp_socket, ignored, unicast
)
packet_args = list(packet_generator())
batch_size = 2
@ -145,21 +159,26 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
packet_args = packet_args[batch_size:]
log.debug("sending batch of %i M-SEARCH attempts", batch_size)
try:
await protocol.m_search(gateway_address, batch_timeout, args)
await asyncio.wait_for(protocol.m_search(gateway_address, batch_timeout, args), timeout)
protocol.disconnect()
return args
except (asyncio.TimeoutError, asyncio.CancelledError):
except asyncio.TimeoutError:
continue
protocol.disconnect()
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
ssdp_socket: socket.socket = None) -> typing.Tuple[OrderedDict, SSDPDatagram]:
ssdp_socket: socket.socket = None,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[OrderedDict,
SSDPDatagram]:
# we don't know which packet the gateway replies to, so send small batches at a time
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket)
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket, ignored, unicast)
# check the args in the batch that got a reply one at a time to see which one worked
for args in args_to_try:
try:
packet = await m_search(lan_address, gateway_address, args, 3)
packet = await m_search(lan_address, gateway_address, args, 3, ignored=ignored, unicast=unicast)
return args, packet
except UPnPError:
continue