handle multiple gateways replying on the same physical device
-try sending to both the router address and the multicast address
This commit is contained in:
parent
a415943ddf
commit
3ab4e1d887
2 changed files with 85 additions and 35 deletions
|
@ -1,7 +1,8 @@
|
|||
import logging
|
||||
import socket
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Union, Type
|
||||
from typing import Dict, List, Union, Type, Set
|
||||
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
|
||||
from aioupnp.constants import SPEC_VERSION, SERVICE
|
||||
from aioupnp.commands import SOAPCommands
|
||||
|
@ -145,17 +146,47 @@ class Gateway:
|
|||
}
|
||||
|
||||
@classmethod
|
||||
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
igd_args: OrderedDict = None, ssdp_socket: socket.socket = None,
|
||||
soap_socket: socket.socket = None):
|
||||
soap_socket: socket.socket = None, unicast: bool = False):
|
||||
ignored: set = set()
|
||||
while True:
|
||||
if not igd_args:
|
||||
m_search_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket)
|
||||
m_search_args, datagram = await asyncio.wait_for(fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket,
|
||||
ignored, unicast), timeout)
|
||||
else:
|
||||
m_search_args = OrderedDict(igd_args)
|
||||
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, ssdp_socket)
|
||||
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, ssdp_socket, ignored,
|
||||
unicast)
|
||||
try:
|
||||
gateway = cls(datagram, m_search_args, lan_address, gateway_address)
|
||||
await gateway.discover_commands(soap_socket)
|
||||
log.debug('found gateway device %s', datagram.location)
|
||||
return gateway
|
||||
except asyncio.TimeoutError:
|
||||
log.debug("get %s timed out, looking for other devices", datagram.location)
|
||||
ignored.add(datagram.location)
|
||||
continue
|
||||
|
||||
@classmethod
|
||||
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
igd_args: OrderedDict = None, ssdp_socket: socket.socket = None,
|
||||
soap_socket: socket.socket = None, unicast: bool = None):
|
||||
if unicast is not None:
|
||||
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, ssdp_socket,
|
||||
soap_socket, unicast=unicast)
|
||||
done, pending = await asyncio.wait([
|
||||
cls._discover_gateway(
|
||||
lan_address, gateway_address, timeout, igd_args, ssdp_socket, soap_socket, unicast=True
|
||||
),
|
||||
cls._discover_gateway(
|
||||
lan_address, gateway_address, timeout, igd_args, ssdp_socket, soap_socket, unicast=False
|
||||
)], return_when=asyncio.tasks.FIRST_COMPLETED
|
||||
)
|
||||
for task in list(pending):
|
||||
task.cancel()
|
||||
result = list(done)[0].result()
|
||||
return result
|
||||
|
||||
async def discover_commands(self, soap_socket: socket.socket = None):
|
||||
response, xml_bytes = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port)
|
||||
|
|
|
@ -19,14 +19,25 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class SSDPProtocol(MulticastProtocol):
|
||||
def __init__(self, multicast_address: str, lan_address: str) -> None:
|
||||
def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Set[str] = None,
|
||||
unicast: bool = False) -> None:
|
||||
super().__init__(multicast_address, lan_address)
|
||||
self.lan_address = lan_address
|
||||
self._unicast = unicast
|
||||
self._ignored: typing.Set[str] = ignored or set() # ignored locations
|
||||
self._pending_searches: typing.List[typing.Tuple[str, str, Future, asyncio.Handle]] = []
|
||||
|
||||
self.notifications: typing.List = []
|
||||
|
||||
def disconnect(self):
|
||||
if self.transport:
|
||||
self.transport.close()
|
||||
while self._pending_searches:
|
||||
pending = self._pending_searches.pop()[2]
|
||||
if not pending.cancelled() and not pending.done():
|
||||
pending.cancel()
|
||||
|
||||
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
|
||||
if packet.location in self._ignored:
|
||||
return
|
||||
tmp: typing.List = []
|
||||
set_futures: typing.List = []
|
||||
while self._pending_searches:
|
||||
|
@ -34,8 +45,8 @@ class SSDPProtocol(MulticastProtocol):
|
|||
a, s = t[0], t[1]
|
||||
if (address == a) and (s in [packet.st, "upnp:rootdevice"]):
|
||||
f: Future = t[2]
|
||||
h: asyncio.Handle = t[3]
|
||||
h.cancel()
|
||||
# h: asyncio.Handle = t[3]
|
||||
# h.cancel()
|
||||
if f not in set_futures:
|
||||
set_futures.append(f)
|
||||
if not f.done():
|
||||
|
@ -46,9 +57,10 @@ class SSDPProtocol(MulticastProtocol):
|
|||
self._pending_searches.append(tmp.pop())
|
||||
|
||||
def send_many_m_searches(self, address: str, packets: typing.List[SSDPDatagram]):
|
||||
dest = address if self._unicast else SSDP_IP_ADDRESS
|
||||
for packet in packets:
|
||||
log.debug("send m search to %s: %s", address, packet.st)
|
||||
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT))
|
||||
log.debug("send m search to %s: %s", dest, packet.st)
|
||||
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
|
||||
|
||||
async def m_search(self, address: str, timeout: float, datagrams: typing.List[OrderedDict]) -> SSDPDatagram:
|
||||
fut: Future = Future()
|
||||
|
@ -56,14 +68,14 @@ class SSDPProtocol(MulticastProtocol):
|
|||
for datagram in datagrams:
|
||||
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram)
|
||||
assert packet.st is not None
|
||||
h = asyncio.get_running_loop().call_later(timeout, fut.cancel)
|
||||
self._pending_searches.append((address, packet.st, fut, h))
|
||||
# h = asyncio.get_running_loop().call_later(timeout, fut.cancel)
|
||||
self._pending_searches.append((address, packet.st, fut))
|
||||
packets.append(packet)
|
||||
self.send_many_m_searches(address, packets),
|
||||
return await fut
|
||||
|
||||
def datagram_received(self, data, addr) -> None:
|
||||
if addr[0] == self.lan_address:
|
||||
if addr[0] == self.bind_address:
|
||||
return
|
||||
try:
|
||||
packet = SSDPDatagram.decode(data)
|
||||
|
@ -96,14 +108,14 @@ class SSDPProtocol(MulticastProtocol):
|
|||
# return
|
||||
|
||||
|
||||
async def listen_ssdp(lan_address: str, gateway_address: str,
|
||||
ssdp_socket: socket.socket = None) -> typing.Tuple[DatagramTransport, SSDPProtocol,
|
||||
str, str]:
|
||||
async def listen_ssdp(lan_address: str, gateway_address: str, ssdp_socket: socket.socket = None,
|
||||
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[DatagramTransport,
|
||||
SSDPProtocol, str, str]:
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
sock = ssdp_socket or SSDPProtocol.create_multicast_socket(lan_address)
|
||||
listen_result: typing.Tuple = await loop.create_datagram_endpoint(
|
||||
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address), sock=sock
|
||||
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
|
||||
)
|
||||
transport: DatagramTransport = listen_result[0]
|
||||
protocol: SSDPProtocol = listen_result[1]
|
||||
|
@ -113,29 +125,31 @@ async def listen_ssdp(lan_address: str, gateway_address: str,
|
|||
protocol.join_group(protocol.multicast_address, protocol.bind_address)
|
||||
protocol.set_ttl(1)
|
||||
except Exception as err:
|
||||
transport.close()
|
||||
protocol.disconnect()
|
||||
raise UPnPError(err)
|
||||
|
||||
return transport, protocol, gateway_address, lan_address
|
||||
|
||||
|
||||
async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1,
|
||||
ssdp_socket: socket.socket = None) -> SSDPDatagram:
|
||||
ssdp_socket: socket.socket = None, ignored: typing.Set[str] = None,
|
||||
unicast: bool = False) -> SSDPDatagram:
|
||||
transport, protocol, gateway_address, lan_address = await listen_ssdp(
|
||||
lan_address, gateway_address, ssdp_socket
|
||||
lan_address, gateway_address, ssdp_socket, ignored, unicast
|
||||
)
|
||||
try:
|
||||
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
|
||||
finally:
|
||||
transport.close()
|
||||
protocol.disconnect()
|
||||
|
||||
|
||||
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
ssdp_socket: socket.socket = None) -> typing.List[OrderedDict]:
|
||||
ssdp_socket: socket.socket = None,
|
||||
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.List[OrderedDict]:
|
||||
transport, protocol, gateway_address, lan_address = await listen_ssdp(
|
||||
lan_address, gateway_address, ssdp_socket
|
||||
lan_address, gateway_address, ssdp_socket, ignored, unicast
|
||||
)
|
||||
packet_args = list(packet_generator())
|
||||
batch_size = 2
|
||||
|
@ -145,21 +159,26 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
|
|||
packet_args = packet_args[batch_size:]
|
||||
log.debug("sending batch of %i M-SEARCH attempts", batch_size)
|
||||
try:
|
||||
await protocol.m_search(gateway_address, batch_timeout, args)
|
||||
await asyncio.wait_for(protocol.m_search(gateway_address, batch_timeout, args), timeout)
|
||||
protocol.disconnect()
|
||||
return args
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
protocol.disconnect()
|
||||
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
|
||||
|
||||
|
||||
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
ssdp_socket: socket.socket = None) -> typing.Tuple[OrderedDict, SSDPDatagram]:
|
||||
ssdp_socket: socket.socket = None,
|
||||
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[OrderedDict,
|
||||
SSDPDatagram]:
|
||||
# we don't know which packet the gateway replies to, so send small batches at a time
|
||||
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket)
|
||||
|
||||
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket, ignored, unicast)
|
||||
# check the args in the batch that got a reply one at a time to see which one worked
|
||||
for args in args_to_try:
|
||||
try:
|
||||
packet = await m_search(lan_address, gateway_address, args, 3)
|
||||
packet = await m_search(lan_address, gateway_address, args, 3, ignored=ignored, unicast=unicast)
|
||||
return args, packet
|
||||
except UPnPError:
|
||||
continue
|
||||
|
|
Loading…
Reference in a new issue