handle multiple gateways replying on the same physical device

-try sending to both the router address and the multicast address
This commit is contained in:
Jack Robison 2018-10-17 18:57:02 -04:00
parent a415943ddf
commit 3ab4e1d887
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
2 changed files with 85 additions and 35 deletions

View file

@ -1,7 +1,8 @@
import logging import logging
import socket import socket
import asyncio
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, List, Union, Type from typing import Dict, List, Union, Type, Set
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from aioupnp.constants import SPEC_VERSION, SERVICE from aioupnp.constants import SPEC_VERSION, SERVICE
from aioupnp.commands import SOAPCommands from aioupnp.commands import SOAPCommands
@ -144,18 +145,48 @@ class Gateway:
'soap_requests': self._soap_requests 'soap_requests': self._soap_requests
} }
@classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: OrderedDict = None, ssdp_socket: socket.socket = None,
soap_socket: socket.socket = None, unicast: bool = False):
ignored: set = set()
while True:
if not igd_args:
m_search_args, datagram = await asyncio.wait_for(fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket,
ignored, unicast), timeout)
else:
m_search_args = OrderedDict(igd_args)
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, ssdp_socket, ignored,
unicast)
try:
gateway = cls(datagram, m_search_args, lan_address, gateway_address)
await gateway.discover_commands(soap_socket)
log.debug('found gateway device %s', datagram.location)
return gateway
except asyncio.TimeoutError:
log.debug("get %s timed out, looking for other devices", datagram.location)
ignored.add(datagram.location)
continue
@classmethod @classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: OrderedDict = None, ssdp_socket: socket.socket = None, igd_args: OrderedDict = None, ssdp_socket: socket.socket = None,
soap_socket: socket.socket = None): soap_socket: socket.socket = None, unicast: bool = None):
if not igd_args: if unicast is not None:
m_search_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket) return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, ssdp_socket,
else: soap_socket, unicast=unicast)
m_search_args = OrderedDict(igd_args) done, pending = await asyncio.wait([
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, ssdp_socket) cls._discover_gateway(
gateway = cls(datagram, m_search_args, lan_address, gateway_address) lan_address, gateway_address, timeout, igd_args, ssdp_socket, soap_socket, unicast=True
await gateway.discover_commands(soap_socket) ),
return gateway cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, ssdp_socket, soap_socket, unicast=False
)], return_when=asyncio.tasks.FIRST_COMPLETED
)
for task in list(pending):
task.cancel()
result = list(done)[0].result()
return result
async def discover_commands(self, soap_socket: socket.socket = None): async def discover_commands(self, soap_socket: socket.socket = None):
response, xml_bytes = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port) response, xml_bytes = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port)

View file

@ -19,14 +19,25 @@ log = logging.getLogger(__name__)
class SSDPProtocol(MulticastProtocol): class SSDPProtocol(MulticastProtocol):
def __init__(self, multicast_address: str, lan_address: str) -> None: def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Set[str] = None,
unicast: bool = False) -> None:
super().__init__(multicast_address, lan_address) super().__init__(multicast_address, lan_address)
self.lan_address = lan_address self._unicast = unicast
self._ignored: typing.Set[str] = ignored or set() # ignored locations
self._pending_searches: typing.List[typing.Tuple[str, str, Future, asyncio.Handle]] = [] self._pending_searches: typing.List[typing.Tuple[str, str, Future, asyncio.Handle]] = []
self.notifications: typing.List = [] self.notifications: typing.List = []
def disconnect(self):
if self.transport:
self.transport.close()
while self._pending_searches:
pending = self._pending_searches.pop()[2]
if not pending.cancelled() and not pending.done():
pending.cancel()
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None: def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
if packet.location in self._ignored:
return
tmp: typing.List = [] tmp: typing.List = []
set_futures: typing.List = [] set_futures: typing.List = []
while self._pending_searches: while self._pending_searches:
@ -34,8 +45,8 @@ class SSDPProtocol(MulticastProtocol):
a, s = t[0], t[1] a, s = t[0], t[1]
if (address == a) and (s in [packet.st, "upnp:rootdevice"]): if (address == a) and (s in [packet.st, "upnp:rootdevice"]):
f: Future = t[2] f: Future = t[2]
h: asyncio.Handle = t[3] # h: asyncio.Handle = t[3]
h.cancel() # h.cancel()
if f not in set_futures: if f not in set_futures:
set_futures.append(f) set_futures.append(f)
if not f.done(): if not f.done():
@ -46,9 +57,10 @@ class SSDPProtocol(MulticastProtocol):
self._pending_searches.append(tmp.pop()) self._pending_searches.append(tmp.pop())
def send_many_m_searches(self, address: str, packets: typing.List[SSDPDatagram]): def send_many_m_searches(self, address: str, packets: typing.List[SSDPDatagram]):
dest = address if self._unicast else SSDP_IP_ADDRESS
for packet in packets: for packet in packets:
log.debug("send m search to %s: %s", address, packet.st) log.debug("send m search to %s: %s", dest, packet.st)
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT)) self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
async def m_search(self, address: str, timeout: float, datagrams: typing.List[OrderedDict]) -> SSDPDatagram: async def m_search(self, address: str, timeout: float, datagrams: typing.List[OrderedDict]) -> SSDPDatagram:
fut: Future = Future() fut: Future = Future()
@ -56,14 +68,14 @@ class SSDPProtocol(MulticastProtocol):
for datagram in datagrams: for datagram in datagrams:
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram) packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram)
assert packet.st is not None assert packet.st is not None
h = asyncio.get_running_loop().call_later(timeout, fut.cancel) # h = asyncio.get_running_loop().call_later(timeout, fut.cancel)
self._pending_searches.append((address, packet.st, fut, h)) self._pending_searches.append((address, packet.st, fut))
packets.append(packet) packets.append(packet)
self.send_many_m_searches(address, packets), self.send_many_m_searches(address, packets),
return await fut return await fut
def datagram_received(self, data, addr) -> None: def datagram_received(self, data, addr) -> None:
if addr[0] == self.lan_address: if addr[0] == self.bind_address:
return return
try: try:
packet = SSDPDatagram.decode(data) packet = SSDPDatagram.decode(data)
@ -96,14 +108,14 @@ class SSDPProtocol(MulticastProtocol):
# return # return
async def listen_ssdp(lan_address: str, gateway_address: str, async def listen_ssdp(lan_address: str, gateway_address: str, ssdp_socket: socket.socket = None,
ssdp_socket: socket.socket = None) -> typing.Tuple[DatagramTransport, SSDPProtocol, ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[DatagramTransport,
str, str]: SSDPProtocol, str, str]:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
try: try:
sock = ssdp_socket or SSDPProtocol.create_multicast_socket(lan_address) sock = ssdp_socket or SSDPProtocol.create_multicast_socket(lan_address)
listen_result: typing.Tuple = await loop.create_datagram_endpoint( listen_result: typing.Tuple = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address), sock=sock lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
) )
transport: DatagramTransport = listen_result[0] transport: DatagramTransport = listen_result[0]
protocol: SSDPProtocol = listen_result[1] protocol: SSDPProtocol = listen_result[1]
@ -113,29 +125,31 @@ async def listen_ssdp(lan_address: str, gateway_address: str,
protocol.join_group(protocol.multicast_address, protocol.bind_address) protocol.join_group(protocol.multicast_address, protocol.bind_address)
protocol.set_ttl(1) protocol.set_ttl(1)
except Exception as err: except Exception as err:
transport.close() protocol.disconnect()
raise UPnPError(err) raise UPnPError(err)
return transport, protocol, gateway_address, lan_address return transport, protocol, gateway_address, lan_address
async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1, async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1,
ssdp_socket: socket.socket = None) -> SSDPDatagram: ssdp_socket: socket.socket = None, ignored: typing.Set[str] = None,
unicast: bool = False) -> SSDPDatagram:
transport, protocol, gateway_address, lan_address = await listen_ssdp( transport, protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, ssdp_socket lan_address, gateway_address, ssdp_socket, ignored, unicast
) )
try: try:
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]) return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
except (asyncio.TimeoutError, asyncio.CancelledError): except (asyncio.TimeoutError, asyncio.CancelledError):
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
finally: finally:
transport.close() protocol.disconnect()
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
ssdp_socket: socket.socket = None) -> typing.List[OrderedDict]: ssdp_socket: socket.socket = None,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.List[OrderedDict]:
transport, protocol, gateway_address, lan_address = await listen_ssdp( transport, protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, ssdp_socket lan_address, gateway_address, ssdp_socket, ignored, unicast
) )
packet_args = list(packet_generator()) packet_args = list(packet_generator())
batch_size = 2 batch_size = 2
@ -145,21 +159,26 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
packet_args = packet_args[batch_size:] packet_args = packet_args[batch_size:]
log.debug("sending batch of %i M-SEARCH attempts", batch_size) log.debug("sending batch of %i M-SEARCH attempts", batch_size)
try: try:
await protocol.m_search(gateway_address, batch_timeout, args) await asyncio.wait_for(protocol.m_search(gateway_address, batch_timeout, args), timeout)
protocol.disconnect()
return args return args
except (asyncio.TimeoutError, asyncio.CancelledError): except asyncio.TimeoutError:
continue continue
protocol.disconnect()
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
ssdp_socket: socket.socket = None) -> typing.Tuple[OrderedDict, SSDPDatagram]: ssdp_socket: socket.socket = None,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[OrderedDict,
SSDPDatagram]:
# we don't know which packet the gateway replies to, so send small batches at a time # we don't know which packet the gateway replies to, so send small batches at a time
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket)
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket, ignored, unicast)
# check the args in the batch that got a reply one at a time to see which one worked # check the args in the batch that got a reply one at a time to see which one worked
for args in args_to_try: for args in args_to_try:
try: try:
packet = await m_search(lan_address, gateway_address, args, 3) packet = await m_search(lan_address, gateway_address, args, 3, ignored=ignored, unicast=unicast)
return args, packet return args, packet
except UPnPError: except UPnPError:
continue continue