handle multiple gateways replying on the same physical device
-try sending to both the router address and the multicast address
This commit is contained in:
parent
a415943ddf
commit
3ab4e1d887
2 changed files with 85 additions and 35 deletions
|
@ -1,7 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
|
import asyncio
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Dict, List, Union, Type
|
from typing import Dict, List, Union, Type, Set
|
||||||
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
|
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
|
||||||
from aioupnp.constants import SPEC_VERSION, SERVICE
|
from aioupnp.constants import SPEC_VERSION, SERVICE
|
||||||
from aioupnp.commands import SOAPCommands
|
from aioupnp.commands import SOAPCommands
|
||||||
|
@ -145,17 +146,47 @@ class Gateway:
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
|
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
|
||||||
igd_args: OrderedDict = None, ssdp_socket: socket.socket = None,
|
igd_args: OrderedDict = None, ssdp_socket: socket.socket = None,
|
||||||
soap_socket: socket.socket = None):
|
soap_socket: socket.socket = None, unicast: bool = False):
|
||||||
|
ignored: set = set()
|
||||||
|
while True:
|
||||||
if not igd_args:
|
if not igd_args:
|
||||||
m_search_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket)
|
m_search_args, datagram = await asyncio.wait_for(fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket,
|
||||||
|
ignored, unicast), timeout)
|
||||||
else:
|
else:
|
||||||
m_search_args = OrderedDict(igd_args)
|
m_search_args = OrderedDict(igd_args)
|
||||||
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, ssdp_socket)
|
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, ssdp_socket, ignored,
|
||||||
|
unicast)
|
||||||
|
try:
|
||||||
gateway = cls(datagram, m_search_args, lan_address, gateway_address)
|
gateway = cls(datagram, m_search_args, lan_address, gateway_address)
|
||||||
await gateway.discover_commands(soap_socket)
|
await gateway.discover_commands(soap_socket)
|
||||||
|
log.debug('found gateway device %s', datagram.location)
|
||||||
return gateway
|
return gateway
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
log.debug("get %s timed out, looking for other devices", datagram.location)
|
||||||
|
ignored.add(datagram.location)
|
||||||
|
continue
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
|
||||||
|
igd_args: OrderedDict = None, ssdp_socket: socket.socket = None,
|
||||||
|
soap_socket: socket.socket = None, unicast: bool = None):
|
||||||
|
if unicast is not None:
|
||||||
|
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, ssdp_socket,
|
||||||
|
soap_socket, unicast=unicast)
|
||||||
|
done, pending = await asyncio.wait([
|
||||||
|
cls._discover_gateway(
|
||||||
|
lan_address, gateway_address, timeout, igd_args, ssdp_socket, soap_socket, unicast=True
|
||||||
|
),
|
||||||
|
cls._discover_gateway(
|
||||||
|
lan_address, gateway_address, timeout, igd_args, ssdp_socket, soap_socket, unicast=False
|
||||||
|
)], return_when=asyncio.tasks.FIRST_COMPLETED
|
||||||
|
)
|
||||||
|
for task in list(pending):
|
||||||
|
task.cancel()
|
||||||
|
result = list(done)[0].result()
|
||||||
|
return result
|
||||||
|
|
||||||
async def discover_commands(self, soap_socket: socket.socket = None):
|
async def discover_commands(self, soap_socket: socket.socket = None):
|
||||||
response, xml_bytes = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port)
|
response, xml_bytes = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port)
|
||||||
|
|
|
@ -19,14 +19,25 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SSDPProtocol(MulticastProtocol):
|
class SSDPProtocol(MulticastProtocol):
|
||||||
def __init__(self, multicast_address: str, lan_address: str) -> None:
|
def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Set[str] = None,
|
||||||
|
unicast: bool = False) -> None:
|
||||||
super().__init__(multicast_address, lan_address)
|
super().__init__(multicast_address, lan_address)
|
||||||
self.lan_address = lan_address
|
self._unicast = unicast
|
||||||
|
self._ignored: typing.Set[str] = ignored or set() # ignored locations
|
||||||
self._pending_searches: typing.List[typing.Tuple[str, str, Future, asyncio.Handle]] = []
|
self._pending_searches: typing.List[typing.Tuple[str, str, Future, asyncio.Handle]] = []
|
||||||
|
|
||||||
self.notifications: typing.List = []
|
self.notifications: typing.List = []
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
if self.transport:
|
||||||
|
self.transport.close()
|
||||||
|
while self._pending_searches:
|
||||||
|
pending = self._pending_searches.pop()[2]
|
||||||
|
if not pending.cancelled() and not pending.done():
|
||||||
|
pending.cancel()
|
||||||
|
|
||||||
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
|
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
|
||||||
|
if packet.location in self._ignored:
|
||||||
|
return
|
||||||
tmp: typing.List = []
|
tmp: typing.List = []
|
||||||
set_futures: typing.List = []
|
set_futures: typing.List = []
|
||||||
while self._pending_searches:
|
while self._pending_searches:
|
||||||
|
@ -34,8 +45,8 @@ class SSDPProtocol(MulticastProtocol):
|
||||||
a, s = t[0], t[1]
|
a, s = t[0], t[1]
|
||||||
if (address == a) and (s in [packet.st, "upnp:rootdevice"]):
|
if (address == a) and (s in [packet.st, "upnp:rootdevice"]):
|
||||||
f: Future = t[2]
|
f: Future = t[2]
|
||||||
h: asyncio.Handle = t[3]
|
# h: asyncio.Handle = t[3]
|
||||||
h.cancel()
|
# h.cancel()
|
||||||
if f not in set_futures:
|
if f not in set_futures:
|
||||||
set_futures.append(f)
|
set_futures.append(f)
|
||||||
if not f.done():
|
if not f.done():
|
||||||
|
@ -46,9 +57,10 @@ class SSDPProtocol(MulticastProtocol):
|
||||||
self._pending_searches.append(tmp.pop())
|
self._pending_searches.append(tmp.pop())
|
||||||
|
|
||||||
def send_many_m_searches(self, address: str, packets: typing.List[SSDPDatagram]):
|
def send_many_m_searches(self, address: str, packets: typing.List[SSDPDatagram]):
|
||||||
|
dest = address if self._unicast else SSDP_IP_ADDRESS
|
||||||
for packet in packets:
|
for packet in packets:
|
||||||
log.debug("send m search to %s: %s", address, packet.st)
|
log.debug("send m search to %s: %s", dest, packet.st)
|
||||||
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT))
|
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
|
||||||
|
|
||||||
async def m_search(self, address: str, timeout: float, datagrams: typing.List[OrderedDict]) -> SSDPDatagram:
|
async def m_search(self, address: str, timeout: float, datagrams: typing.List[OrderedDict]) -> SSDPDatagram:
|
||||||
fut: Future = Future()
|
fut: Future = Future()
|
||||||
|
@ -56,14 +68,14 @@ class SSDPProtocol(MulticastProtocol):
|
||||||
for datagram in datagrams:
|
for datagram in datagrams:
|
||||||
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram)
|
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram)
|
||||||
assert packet.st is not None
|
assert packet.st is not None
|
||||||
h = asyncio.get_running_loop().call_later(timeout, fut.cancel)
|
# h = asyncio.get_running_loop().call_later(timeout, fut.cancel)
|
||||||
self._pending_searches.append((address, packet.st, fut, h))
|
self._pending_searches.append((address, packet.st, fut))
|
||||||
packets.append(packet)
|
packets.append(packet)
|
||||||
self.send_many_m_searches(address, packets),
|
self.send_many_m_searches(address, packets),
|
||||||
return await fut
|
return await fut
|
||||||
|
|
||||||
def datagram_received(self, data, addr) -> None:
|
def datagram_received(self, data, addr) -> None:
|
||||||
if addr[0] == self.lan_address:
|
if addr[0] == self.bind_address:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
packet = SSDPDatagram.decode(data)
|
packet = SSDPDatagram.decode(data)
|
||||||
|
@ -96,14 +108,14 @@ class SSDPProtocol(MulticastProtocol):
|
||||||
# return
|
# return
|
||||||
|
|
||||||
|
|
||||||
async def listen_ssdp(lan_address: str, gateway_address: str,
|
async def listen_ssdp(lan_address: str, gateway_address: str, ssdp_socket: socket.socket = None,
|
||||||
ssdp_socket: socket.socket = None) -> typing.Tuple[DatagramTransport, SSDPProtocol,
|
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[DatagramTransport,
|
||||||
str, str]:
|
SSDPProtocol, str, str]:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
try:
|
try:
|
||||||
sock = ssdp_socket or SSDPProtocol.create_multicast_socket(lan_address)
|
sock = ssdp_socket or SSDPProtocol.create_multicast_socket(lan_address)
|
||||||
listen_result: typing.Tuple = await loop.create_datagram_endpoint(
|
listen_result: typing.Tuple = await loop.create_datagram_endpoint(
|
||||||
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address), sock=sock
|
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
|
||||||
)
|
)
|
||||||
transport: DatagramTransport = listen_result[0]
|
transport: DatagramTransport = listen_result[0]
|
||||||
protocol: SSDPProtocol = listen_result[1]
|
protocol: SSDPProtocol = listen_result[1]
|
||||||
|
@ -113,29 +125,31 @@ async def listen_ssdp(lan_address: str, gateway_address: str,
|
||||||
protocol.join_group(protocol.multicast_address, protocol.bind_address)
|
protocol.join_group(protocol.multicast_address, protocol.bind_address)
|
||||||
protocol.set_ttl(1)
|
protocol.set_ttl(1)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
transport.close()
|
protocol.disconnect()
|
||||||
raise UPnPError(err)
|
raise UPnPError(err)
|
||||||
|
|
||||||
return transport, protocol, gateway_address, lan_address
|
return transport, protocol, gateway_address, lan_address
|
||||||
|
|
||||||
|
|
||||||
async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1,
|
async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1,
|
||||||
ssdp_socket: socket.socket = None) -> SSDPDatagram:
|
ssdp_socket: socket.socket = None, ignored: typing.Set[str] = None,
|
||||||
|
unicast: bool = False) -> SSDPDatagram:
|
||||||
transport, protocol, gateway_address, lan_address = await listen_ssdp(
|
transport, protocol, gateway_address, lan_address = await listen_ssdp(
|
||||||
lan_address, gateway_address, ssdp_socket
|
lan_address, gateway_address, ssdp_socket, ignored, unicast
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
|
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
|
||||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||||
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
|
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
|
||||||
finally:
|
finally:
|
||||||
transport.close()
|
protocol.disconnect()
|
||||||
|
|
||||||
|
|
||||||
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
|
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
|
||||||
ssdp_socket: socket.socket = None) -> typing.List[OrderedDict]:
|
ssdp_socket: socket.socket = None,
|
||||||
|
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.List[OrderedDict]:
|
||||||
transport, protocol, gateway_address, lan_address = await listen_ssdp(
|
transport, protocol, gateway_address, lan_address = await listen_ssdp(
|
||||||
lan_address, gateway_address, ssdp_socket
|
lan_address, gateway_address, ssdp_socket, ignored, unicast
|
||||||
)
|
)
|
||||||
packet_args = list(packet_generator())
|
packet_args = list(packet_generator())
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
|
@ -145,21 +159,26 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
|
||||||
packet_args = packet_args[batch_size:]
|
packet_args = packet_args[batch_size:]
|
||||||
log.debug("sending batch of %i M-SEARCH attempts", batch_size)
|
log.debug("sending batch of %i M-SEARCH attempts", batch_size)
|
||||||
try:
|
try:
|
||||||
await protocol.m_search(gateway_address, batch_timeout, args)
|
await asyncio.wait_for(protocol.m_search(gateway_address, batch_timeout, args), timeout)
|
||||||
|
protocol.disconnect()
|
||||||
return args
|
return args
|
||||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
protocol.disconnect()
|
||||||
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
|
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
|
||||||
|
|
||||||
|
|
||||||
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
|
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
|
||||||
ssdp_socket: socket.socket = None) -> typing.Tuple[OrderedDict, SSDPDatagram]:
|
ssdp_socket: socket.socket = None,
|
||||||
|
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[OrderedDict,
|
||||||
|
SSDPDatagram]:
|
||||||
# we don't know which packet the gateway replies to, so send small batches at a time
|
# we don't know which packet the gateway replies to, so send small batches at a time
|
||||||
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket)
|
|
||||||
|
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket, ignored, unicast)
|
||||||
# check the args in the batch that got a reply one at a time to see which one worked
|
# check the args in the batch that got a reply one at a time to see which one worked
|
||||||
for args in args_to_try:
|
for args in args_to_try:
|
||||||
try:
|
try:
|
||||||
packet = await m_search(lan_address, gateway_address, args, 3)
|
packet = await m_search(lan_address, gateway_address, args, 3, ignored=ignored, unicast=unicast)
|
||||||
return args, packet
|
return args, packet
|
||||||
except UPnPError:
|
except UPnPError:
|
||||||
continue
|
continue
|
||||||
|
|
Loading…
Reference in a new issue