Faster gateway discovery #18

Merged
jackrobison merged 7 commits from fast-discover into master 2019-10-25 22:12:41 +02:00
5 changed files with 150 additions and 177 deletions
Showing only changes of commit 655d2ff623 - Show all commits

View file

@ -75,8 +75,7 @@ def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
'interface': 'default',
'gateway_address': '',
'lan_address': '',
'timeout': 30,
'unicast': False
'timeout': 3,
}
options: typing.Dict[str, typing.Union[bool, str, int]] = OrderedDict()
@ -114,10 +113,9 @@ def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
gateway_address: str = str(options.pop('gateway_address'))
timeout: int = int(options.pop('timeout'))
interface: str = str(options.pop('interface'))
unicast: bool = bool(options.pop('unicast'))
run_cli(
command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop
command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, kwargs, loop
)
return 0

View file

@ -2,13 +2,12 @@ import re
import logging
import typing
import asyncio
from collections import OrderedDict
from typing import Dict, List
from typing import Dict, List, Optional
from aioupnp.util import get_dict_val_case_insensitive
from aioupnp.constants import SPEC_VERSION, SERVICE
from aioupnp.commands import SOAPCommands, SCPDRequestDebuggingInfo
from aioupnp.commands import SOAPCommands
from aioupnp.device import Device, Service
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search, multi_m_search
from aioupnp.protocols.ssdp import m_search, multi_m_search
from aioupnp.protocols.scpd import scpd_get
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.util import flatten_keys
@ -70,7 +69,7 @@ def parse_location(location: bytes) -> typing.Tuple[bytes, int]:
class Gateway:
def __init__(self, ok_packet: SSDPDatagram, lan_address: str, gateway_address: str,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self._loop = loop or asyncio.get_event_loop()
self._ok_packet = ok_packet
self._lan_address = lan_address
@ -90,10 +89,10 @@ class Gateway:
assert self.base_ip == gateway_address.encode()
self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1]
self.spec_version: typing.Optional[str] = None
self.url_base: typing.Optional[str] = None
self.spec_version: Optional[str] = None
self.url_base: Optional[str] = None
self._device: typing.Optional[Device] = None
self._device: Optional[Device] = None
self._devices: List[Device] = []
self._services: List[Service] = []
@ -128,7 +127,7 @@ class Gateway:
devices[device.udn] = device
return devices
# def get_service(self, service_type: str) -> typing.Optional[Service]:
# def get_service(self, service_type: str) -> Optional[Service]:
# for service in self._services:
# if service.serviceType and service.serviceType.lower() == service_type.lower():
# return service
@ -155,73 +154,78 @@ class Gateway:
}
@classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
unicast: bool = False) -> 'Gateway':
ignored: typing.Set[str] = set()
async def _try_gateway_from_ssdp(cls, datagram: SSDPDatagram, lan_address: str,
gateway_address: str,
loop: Optional[asyncio.AbstractEventLoop] = None) -> Optional['Gateway']:
required_commands: typing.List[str] = [
'AddPortMapping',
'DeletePortMapping',
'GetExternalIPAddress'
]
while True:
if not igd_args:
datagram = await multi_m_search(
lan_address, gateway_address, timeout, loop, ignored, unicast
)
else:
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, ignored, unicast)
try:
gateway = cls(datagram, lan_address, gateway_address, loop=loop)
log.debug('get gateway descriptor %s', datagram.location)
await gateway.discover_commands()
requirements_met = all([gateway.commands.is_registered(required) for required in required_commands])
if not requirements_met:
not_met = [
required for required in required_commands if not gateway.commands.is_registered(required)
]
assert datagram.location is not None
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
gateway.manufacturer_string, gateway.location, not_met)
ignored.add(datagram.location)
continue
else:
log.debug('found gateway %s at %s', gateway.manufacturer_string or "device", datagram.location)
return gateway
except (asyncio.TimeoutError, UPnPError) as err:
try:
gateway = cls(datagram, lan_address, gateway_address, loop=loop)
log.debug('get gateway descriptor %s', datagram.location)
await gateway.discover_commands()
requirements_met = all([gateway.commands.is_registered(required) for required in required_commands])
if not requirements_met:
not_met = [
required for required in required_commands if not gateway.commands.is_registered(required)
]
assert datagram.location is not None
log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err))
ignored.add(datagram.location)
continue
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
gateway.manufacturer_string, gateway.location, not_met)
return None
else:
log.debug('found gateway %s at %s', gateway.manufacturer_string or "device", datagram.location)
return gateway
except (asyncio.TimeoutError, UPnPError) as err:
assert datagram.location is not None
log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err))
return None
@classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
unicast: typing.Optional[bool] = None) -> 'Gateway':
async def _gateway_from_igd_args(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: Optional[typing.Dict[str, typing.Union[int, str]]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, set())
gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop)
if not gateway:
raise UPnPError("no gateway found for given args")
return gateway
@classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 3,
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
ignored: typing.Set[str] = set()
ssdp_proto = await multi_m_search(
lan_address, gateway_address, timeout, loop, ignored
)
try:
while True:
datagram = await ssdp_proto.devices.get()
if datagram.location in ignored:
continue
gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop)
if gateway:
return gateway
else:
ignored.add(datagram.location)
finally:
ssdp_proto.disconnect()
@classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 3,
igd_args: Optional[typing.Dict[str, typing.Union[int, str]]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
loop = loop or asyncio.get_event_loop()
if unicast is not None:
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast)
done, pending = await asyncio.wait([
cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop, unicast=True
),
cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop, unicast=False
)
], return_when=asyncio.tasks.FIRST_COMPLETED, loop=loop)
for task in pending:
task.cancel()
for task in done:
try:
task.exception()
except asyncio.CancelledError:
pass
results: typing.List['asyncio.Future[Gateway]'] = list(done)
return results[0].result()
if igd_args:
return await cls._gateway_from_igd_args(lan_address, gateway_address, timeout, igd_args, loop)
try:
return await asyncio.wait_for(loop.create_task(
cls._discover_gateway(lan_address, gateway_address, timeout, loop)
), timeout, loop=loop)
except asyncio.TimeoutError:
raise UPnPError(f"M-SEARCH for {gateway_address}:1900 timed out")
async def discover_commands(self) -> None:
response, xml_bytes, get_err = await scpd_get(
@ -266,7 +270,7 @@ class Gateway:
return None
async def register_commands(self, service: Service,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
if not service.SCPDURL:
raise UPnPError("no scpd url")
if not service.serviceType:

View file

@ -105,7 +105,7 @@ async def scpd_get(control_url: str, address: str, port: int,
typing.Dict[str, typing.Any], bytes, typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop()
packet = serialize_scpd_get(control_url, address)
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = asyncio.Future(loop=loop)
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future()
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished)
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
proto_factory, address, port
@ -141,7 +141,7 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para
**kwargs: typing.Dict[str, typing.Any]
) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop()
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = asyncio.Future(loop=loop)
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future()
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\
SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode())

View file

@ -17,17 +17,23 @@ ADDRESS_REGEX = re.compile("^http:\/\/(\d+\.\d+\.\d+\.\d+)\:(\d*)(\/[\w|\/|\:|\-
log = logging.getLogger(__name__)
class PendingSearch(typing.NamedTuple):
address: str
st: str
fut: 'asyncio.Future[SSDPDatagram]'
class SSDPProtocol(MulticastProtocol):
def __init__(self, multicast_address: str, lan_address: str, ignored: Optional[Set[str]] = None,
unicast: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(multicast_address, lan_address)
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
self.transport: Optional[DatagramTransport] = None
self._unicast = unicast
self._ignored: Set[str] = ignored or set() # ignored locations
self._pending_searches: List[Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = []
self._pending_searches: List[PendingSearch] = []
self.notifications: List[SSDPDatagram] = []
self.connected = asyncio.Event(loop=self.loop)
self.devices: 'asyncio.Queue[SSDPDatagram]' = asyncio.Queue(loop=self.loop)
def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore
super().connection_made(transport)
@ -46,43 +52,58 @@ class SSDPProtocol(MulticastProtocol):
return None
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
if packet.location not in self._ignored:
# TODO: fix this
tmp: List[Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = []
set_futures: List['asyncio.Future[SSDPDatagram]'] = []
while len(self._pending_searches):
t = self._pending_searches.pop()
if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]):
f = t[2]
if f not in set_futures:
set_futures.append(f)
if not f.done():
f.set_result(packet)
elif t[2] not in set_futures:
tmp.append(t)
while tmp:
self._pending_searches.append(tmp.pop())
return None
if packet.location in self._ignored:
return
futures: Set['asyncio.Future[SSDPDatagram]'] = set()
replied: List[PendingSearch] = []
for pending in self._pending_searches:
# if pending.address == address and pending.st in (packet.st, "upnp:rootdevice"):
if pending.address == address and pending.st == packet.st:
replied.append(pending)
if pending.fut not in futures:
futures.add(pending.fut)
if replied:
self.devices.put_nowait(packet)
while replied:
self._pending_searches.remove(replied.pop())
while futures:
fut = futures.pop()
if not fut.done():
fut.set_result(packet)
def _send_m_search(self, address: str, packet: SSDPDatagram, fut: 'asyncio.Future[SSDPDatagram]') -> None:
dest = address if self._unicast else SSDP_IP_ADDRESS
if not self.transport:
if not fut.done():
fut.set_exception(UPnPError("SSDP transport not connected"))
return None
log.debug("send m search to %s: %s", dest, packet.st)
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
return None
return
if fut.done():
return
self._pending_searches.append(
PendingSearch(address, packet.st, fut)
)
async def m_search(self, address: str, timeout: float,
datagrams: List[Dict[str, typing.Union[str, int]]]) -> SSDPDatagram:
self.transport.sendto(packet.encode().encode(), (SSDP_IP_ADDRESS, SSDP_PORT))
# also send unicast
log.debug("send m search to %s: %s", address, packet.st)
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT))
def send_m_searches(self, address: str,
datagrams: List[Dict[str, typing.Union[str, int]]]) -> 'asyncio.Future[SSDPDatagram]':
fut: 'asyncio.Future[SSDPDatagram]' = self.loop.create_future()
for datagram in datagrams:
packet = SSDPDatagram("M-SEARCH", datagram)
assert packet.st is not None
self._pending_searches.append(
(address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet, fut))
)
self._send_m_search(address, packet, fut)
return fut
async def m_search(self, address: str, timeout: float,
datagrams: List[Dict[str, typing.Union[str, int]]]) -> SSDPDatagram:
fut = self.send_m_searches(address, datagrams)
return await asyncio.wait_for(fut, timeout, loop=self.loop)
def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None: # type: ignore
@ -95,7 +116,6 @@ class SSDPProtocol(MulticastProtocol):
log.warning("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err,
binascii.hexlify(data))
return None
if packet._packet_type == packet._OK:
self._callback_m_search_ok(addr[0], packet)
return None
@ -120,13 +140,12 @@ class SSDPProtocol(MulticastProtocol):
async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asyncio.AbstractEventLoop] = None,
ignored: Optional[Set[str]] = None,
unicast: bool = False) -> Tuple[SSDPProtocol, str, str]:
ignored: Optional[Set[str]] = None) -> Tuple[SSDPProtocol, str, str]:
loop = loop or asyncio.get_event_loop()
try:
sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address)
listen_result: Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored), sock=sock
)
protocol = listen_result[1]
assert isinstance(protocol, SSDPProtocol)
@ -140,9 +159,9 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asy
async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[str, typing.Union[int, str]],
timeout: int = 1, loop: Optional[asyncio.AbstractEventLoop] = None,
ignored: Set[str] = None, unicast: bool = False) -> SSDPDatagram:
ignored: Set[str] = None) -> SSDPDatagram:
protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast
lan_address, gateway_address, loop, ignored
)
try:
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
@ -154,56 +173,13 @@ async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[s
async def multi_m_search(lan_address: str, gateway_address: str, timeout: int = 3,
loop: Optional[asyncio.AbstractEventLoop] = None,
ignored: Set[str] = None, unicast: bool = False) -> SSDPDatagram:
ignored: Set[str] = None) -> SSDPProtocol:
loop = loop or asyncio.get_event_loop()
protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast
lan_address, gateway_address, loop, ignored
)
datagram_args = list(packet_generator())
try:
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=datagram_args)
except asyncio.TimeoutError:
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
finally:
protocol.disconnect()
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
loop: Optional[asyncio.AbstractEventLoop] = None,
ignored: Set[str] = None,
unicast: bool = False) -> List[Dict[str, typing.Union[int, str]]]:
protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast
)
await protocol.connected.wait()
packet_args = list(packet_generator())
batch_size = 2
batch_timeout = float(timeout) / float(len(packet_args))
while packet_args:
args = packet_args[:batch_size]
packet_args = packet_args[batch_size:]
log.debug("sending batch of %i M-SEARCH attempts", batch_size)
try:
await protocol.m_search(gateway_address, batch_timeout, args)
except asyncio.TimeoutError:
continue
else:
protocol.disconnect()
return args
protocol.disconnect()
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
loop: Optional[asyncio.AbstractEventLoop] = None,
ignored: Set[str] = None,
unicast: bool = False) -> Tuple[Dict[str, typing.Union[int, str]], SSDPDatagram]:
# we don't know which packet the gateway replies to, so send small batches at a time
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast)
# check the args in the batch that got a reply one at a time to see which one worked
for args in args_to_try:
try:
packet = await m_search(lan_address, gateway_address, args, 3, loop=loop, ignored=ignored, unicast=unicast)
return args, packet
except UPnPError:
continue
raise UPnPError("failed to discover gateway")
fut = asyncio.ensure_future(protocol.send_m_searches(
address=gateway_address, datagrams=list(packet_generator())
), loop=loop)
loop.call_later(timeout, lambda: None if not fut or fut.done() else fut.cancel())
return protocol

View file

@ -4,12 +4,10 @@
import logging
import json
import asyncio
from collections import OrderedDict
from typing import Tuple, Dict, List, Union, Optional, Any
from aioupnp.fault import UPnPError
from aioupnp.gateway import Gateway
from aioupnp.interfaces import get_gateway_and_lan_addresses
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.commands import GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse
@ -61,7 +59,7 @@ class UPnP:
return lan_address, gateway_address
@classmethod
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 30,
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 3,
igd_args: Optional[Dict[str, Union[str, int]]] = None, interface_name: str = 'default',
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'UPnP':
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
@ -72,7 +70,7 @@ class UPnP:
@classmethod
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
unicast: bool = True, interface_name: str = 'default',
interface_name: str = 'default',
igd_args: Optional[Dict[str, Union[str, int]]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None
) -> Dict[str, Union[str, Dict[str, Union[str, int]]]]:
@ -82,7 +80,6 @@ class UPnP:
:param lan_address: (str) the local interface ipv4 address
:param gateway_address: (str) the gateway ipv4 address
:param timeout: (int) m search timeout
:param unicast: (bool) use unicast
:param interface_name: (str) name of the network interface
:param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided.
@ -101,16 +98,14 @@ class UPnP:
except Exception as err:
raise UPnPError("failed to get lan and gateway addresses for interface \"%s\": %s" % (interface_name,
str(err)))
if not igd_args:
igd_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, loop, unicast=unicast)
else:
igd_args = OrderedDict(igd_args)
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, unicast=unicast)
gateway = await Gateway.discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop
)
return {
'lan_address': lan_address,
'gateway_address': gateway_address,
'm_search_kwargs': SSDPDatagram("M-SEARCH", igd_args).get_cli_igd_kwargs(),
'discover_reply': datagram.as_dict()
# 'm_search_kwargs': SSDPDatagram("M-SEARCH", igd_args).get_cli_igd_kwargs(),
'discover_reply': gateway._ok_packet.as_dict()
}
async def get_external_ip(self) -> str:
@ -372,20 +367,20 @@ cli_commands = [
def run_cli(method: str, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
unicast: bool = True, kwargs: Optional[Dict[str, str]] = None,
gateway_address: str = '', timeout: int = 3, interface_name: str = 'default',
kwargs: Optional[Dict[str, str]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
kwargs = kwargs or {}
igd_args = igd_args
timeout = int(timeout)
loop = loop or asyncio.get_event_loop()
fut: 'asyncio.Future' = asyncio.Future(loop=loop)
fut: 'asyncio.Future' = loop.create_future()
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
if method == 'm_search': # if we're only m_searching don't do any device discovery
fn = lambda *_a, **_kw: UPnP.m_search(
lan_address, gateway_address, timeout, unicast, interface_name, igd_args, loop
lan_address, gateway_address, timeout, interface_name, igd_args, loop
)
else: # automatically discover the gateway
try: