replace fuzzy_m_search with multi_m_search

-lower search timeout from 30s->3s
-send unicast/multicast at the same time, remove optional arg
This commit is contained in:
Jack Robison 2019-10-25 01:30:49 -04:00
parent d8f309f8fe
commit 655d2ff623
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
5 changed files with 150 additions and 177 deletions

View file

@ -75,8 +75,7 @@ def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
'interface': 'default', 'interface': 'default',
'gateway_address': '', 'gateway_address': '',
'lan_address': '', 'lan_address': '',
'timeout': 30, 'timeout': 3,
'unicast': False
} }
options: typing.Dict[str, typing.Union[bool, str, int]] = OrderedDict() options: typing.Dict[str, typing.Union[bool, str, int]] = OrderedDict()
@ -114,10 +113,9 @@ def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
gateway_address: str = str(options.pop('gateway_address')) gateway_address: str = str(options.pop('gateway_address'))
timeout: int = int(options.pop('timeout')) timeout: int = int(options.pop('timeout'))
interface: str = str(options.pop('interface')) interface: str = str(options.pop('interface'))
unicast: bool = bool(options.pop('unicast'))
run_cli( run_cli(
command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, kwargs, loop
) )
return 0 return 0

View file

@ -2,13 +2,12 @@ import re
import logging import logging
import typing import typing
import asyncio import asyncio
from collections import OrderedDict from typing import Dict, List, Optional
from typing import Dict, List
from aioupnp.util import get_dict_val_case_insensitive from aioupnp.util import get_dict_val_case_insensitive
from aioupnp.constants import SPEC_VERSION, SERVICE from aioupnp.constants import SPEC_VERSION, SERVICE
from aioupnp.commands import SOAPCommands, SCPDRequestDebuggingInfo from aioupnp.commands import SOAPCommands
from aioupnp.device import Device, Service from aioupnp.device import Device, Service
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search, multi_m_search from aioupnp.protocols.ssdp import m_search, multi_m_search
from aioupnp.protocols.scpd import scpd_get from aioupnp.protocols.scpd import scpd_get
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.util import flatten_keys from aioupnp.util import flatten_keys
@ -70,7 +69,7 @@ def parse_location(location: bytes) -> typing.Tuple[bytes, int]:
class Gateway: class Gateway:
def __init__(self, ok_packet: SSDPDatagram, lan_address: str, gateway_address: str, def __init__(self, ok_packet: SSDPDatagram, lan_address: str, gateway_address: str,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None: loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self._loop = loop or asyncio.get_event_loop() self._loop = loop or asyncio.get_event_loop()
self._ok_packet = ok_packet self._ok_packet = ok_packet
self._lan_address = lan_address self._lan_address = lan_address
@ -90,10 +89,10 @@ class Gateway:
assert self.base_ip == gateway_address.encode() assert self.base_ip == gateway_address.encode()
self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1] self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1]
self.spec_version: typing.Optional[str] = None self.spec_version: Optional[str] = None
self.url_base: typing.Optional[str] = None self.url_base: Optional[str] = None
self._device: typing.Optional[Device] = None self._device: Optional[Device] = None
self._devices: List[Device] = [] self._devices: List[Device] = []
self._services: List[Service] = [] self._services: List[Service] = []
@ -128,7 +127,7 @@ class Gateway:
devices[device.udn] = device devices[device.udn] = device
return devices return devices
# def get_service(self, service_type: str) -> typing.Optional[Service]: # def get_service(self, service_type: str) -> Optional[Service]:
# for service in self._services: # for service in self._services:
# if service.serviceType and service.serviceType.lower() == service_type.lower(): # if service.serviceType and service.serviceType.lower() == service_type.lower():
# return service # return service
@ -155,73 +154,78 @@ class Gateway:
} }
@classmethod @classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, async def _try_gateway_from_ssdp(cls, datagram: SSDPDatagram, lan_address: str,
igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None, gateway_address: str,
loop: typing.Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None) -> Optional['Gateway']:
unicast: bool = False) -> 'Gateway':
ignored: typing.Set[str] = set()
required_commands: typing.List[str] = [ required_commands: typing.List[str] = [
'AddPortMapping', 'AddPortMapping',
'DeletePortMapping', 'DeletePortMapping',
'GetExternalIPAddress' 'GetExternalIPAddress'
] ]
while True: try:
if not igd_args: gateway = cls(datagram, lan_address, gateway_address, loop=loop)
datagram = await multi_m_search( log.debug('get gateway descriptor %s', datagram.location)
lan_address, gateway_address, timeout, loop, ignored, unicast await gateway.discover_commands()
) requirements_met = all([gateway.commands.is_registered(required) for required in required_commands])
else: if not requirements_met:
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, ignored, unicast) not_met = [
try: required for required in required_commands if not gateway.commands.is_registered(required)
gateway = cls(datagram, lan_address, gateway_address, loop=loop) ]
log.debug('get gateway descriptor %s', datagram.location)
await gateway.discover_commands()
requirements_met = all([gateway.commands.is_registered(required) for required in required_commands])
if not requirements_met:
not_met = [
required for required in required_commands if not gateway.commands.is_registered(required)
]
assert datagram.location is not None
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
gateway.manufacturer_string, gateway.location, not_met)
ignored.add(datagram.location)
continue
else:
log.debug('found gateway %s at %s', gateway.manufacturer_string or "device", datagram.location)
return gateway
except (asyncio.TimeoutError, UPnPError) as err:
assert datagram.location is not None assert datagram.location is not None
log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err)) log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
ignored.add(datagram.location) gateway.manufacturer_string, gateway.location, not_met)
continue return None
else:
log.debug('found gateway %s at %s', gateway.manufacturer_string or "device", datagram.location)
return gateway
except (asyncio.TimeoutError, UPnPError) as err:
assert datagram.location is not None
log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err))
return None
@classmethod @classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, async def _gateway_from_igd_args(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None, igd_args: Optional[typing.Dict[str, typing.Union[int, str]]] = None,
loop: typing.Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
unicast: typing.Optional[bool] = None) -> 'Gateway': datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, set())
gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop)
if not gateway:
raise UPnPError("no gateway found for given args")
return gateway
@classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 3,
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
ignored: typing.Set[str] = set()
ssdp_proto = await multi_m_search(
lan_address, gateway_address, timeout, loop, ignored
)
try:
while True:
datagram = await ssdp_proto.devices.get()
if datagram.location in ignored:
continue
gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop)
if gateway:
return gateway
else:
ignored.add(datagram.location)
finally:
ssdp_proto.disconnect()
@classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 3,
igd_args: Optional[typing.Dict[str, typing.Union[int, str]]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
loop = loop or asyncio.get_event_loop() loop = loop or asyncio.get_event_loop()
if unicast is not None: if igd_args:
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast) return await cls._gateway_from_igd_args(lan_address, gateway_address, timeout, igd_args, loop)
try:
done, pending = await asyncio.wait([ return await asyncio.wait_for(loop.create_task(
cls._discover_gateway( cls._discover_gateway(lan_address, gateway_address, timeout, loop)
lan_address, gateway_address, timeout, igd_args, loop, unicast=True ), timeout, loop=loop)
), except asyncio.TimeoutError:
cls._discover_gateway( raise UPnPError(f"M-SEARCH for {gateway_address}:1900 timed out")
lan_address, gateway_address, timeout, igd_args, loop, unicast=False
)
], return_when=asyncio.tasks.FIRST_COMPLETED, loop=loop)
for task in pending:
task.cancel()
for task in done:
try:
task.exception()
except asyncio.CancelledError:
pass
results: typing.List['asyncio.Future[Gateway]'] = list(done)
return results[0].result()
async def discover_commands(self) -> None: async def discover_commands(self) -> None:
response, xml_bytes, get_err = await scpd_get( response, xml_bytes, get_err = await scpd_get(
@ -266,7 +270,7 @@ class Gateway:
return None return None
async def register_commands(self, service: Service, async def register_commands(self, service: Service,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None: loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
if not service.SCPDURL: if not service.SCPDURL:
raise UPnPError("no scpd url") raise UPnPError("no scpd url")
if not service.serviceType: if not service.serviceType:

View file

@ -105,7 +105,7 @@ async def scpd_get(control_url: str, address: str, port: int,
typing.Dict[str, typing.Any], bytes, typing.Optional[Exception]]: typing.Dict[str, typing.Any], bytes, typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop() loop = loop or asyncio.get_event_loop()
packet = serialize_scpd_get(control_url, address) packet = serialize_scpd_get(control_url, address)
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = asyncio.Future(loop=loop) finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future()
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished) proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished)
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection( connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
proto_factory, address, port proto_factory, address, port
@ -141,7 +141,7 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para
**kwargs: typing.Dict[str, typing.Any] **kwargs: typing.Dict[str, typing.Any]
) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]: ) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop() loop = loop or asyncio.get_event_loop()
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = asyncio.Future(loop=loop) finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future()
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs) packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\ proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\
SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode()) SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode())

View file

@ -17,17 +17,23 @@ ADDRESS_REGEX = re.compile("^http:\/\/(\d+\.\d+\.\d+\.\d+)\:(\d*)(\/[\w|\/|\:|\-
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class PendingSearch(typing.NamedTuple):
address: str
st: str
fut: 'asyncio.Future[SSDPDatagram]'
class SSDPProtocol(MulticastProtocol): class SSDPProtocol(MulticastProtocol):
def __init__(self, multicast_address: str, lan_address: str, ignored: Optional[Set[str]] = None, def __init__(self, multicast_address: str, lan_address: str, ignored: Optional[Set[str]] = None,
unicast: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(multicast_address, lan_address) super().__init__(multicast_address, lan_address)
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
self.transport: Optional[DatagramTransport] = None self.transport: Optional[DatagramTransport] = None
self._unicast = unicast
self._ignored: Set[str] = ignored or set() # ignored locations self._ignored: Set[str] = ignored or set() # ignored locations
self._pending_searches: List[Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = [] self._pending_searches: List[PendingSearch] = []
self.notifications: List[SSDPDatagram] = [] self.notifications: List[SSDPDatagram] = []
self.connected = asyncio.Event(loop=self.loop) self.connected = asyncio.Event(loop=self.loop)
self.devices: 'asyncio.Queue[SSDPDatagram]' = asyncio.Queue(loop=self.loop)
def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore
super().connection_made(transport) super().connection_made(transport)
@ -46,43 +52,58 @@ class SSDPProtocol(MulticastProtocol):
return None return None
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None: def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
if packet.location not in self._ignored: if packet.location in self._ignored:
# TODO: fix this return
tmp: List[Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = []
set_futures: List['asyncio.Future[SSDPDatagram]'] = [] futures: Set['asyncio.Future[SSDPDatagram]'] = set()
while len(self._pending_searches): replied: List[PendingSearch] = []
t = self._pending_searches.pop()
if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]): for pending in self._pending_searches:
f = t[2] # if pending.address == address and pending.st in (packet.st, "upnp:rootdevice"):
if f not in set_futures: if pending.address == address and pending.st == packet.st:
set_futures.append(f) replied.append(pending)
if not f.done(): if pending.fut not in futures:
f.set_result(packet) futures.add(pending.fut)
elif t[2] not in set_futures: if replied:
tmp.append(t) self.devices.put_nowait(packet)
while tmp:
self._pending_searches.append(tmp.pop()) while replied:
return None self._pending_searches.remove(replied.pop())
while futures:
fut = futures.pop()
if not fut.done():
fut.set_result(packet)
def _send_m_search(self, address: str, packet: SSDPDatagram, fut: 'asyncio.Future[SSDPDatagram]') -> None: def _send_m_search(self, address: str, packet: SSDPDatagram, fut: 'asyncio.Future[SSDPDatagram]') -> None:
dest = address if self._unicast else SSDP_IP_ADDRESS
if not self.transport: if not self.transport:
if not fut.done(): if not fut.done():
fut.set_exception(UPnPError("SSDP transport not connected")) fut.set_exception(UPnPError("SSDP transport not connected"))
return None return
log.debug("send m search to %s: %s", dest, packet.st) if fut.done():
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT)) return
return None self._pending_searches.append(
PendingSearch(address, packet.st, fut)
)
async def m_search(self, address: str, timeout: float, self.transport.sendto(packet.encode().encode(), (SSDP_IP_ADDRESS, SSDP_PORT))
datagrams: List[Dict[str, typing.Union[str, int]]]) -> SSDPDatagram:
# also send unicast
log.debug("send m search to %s: %s", address, packet.st)
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT))
def send_m_searches(self, address: str,
datagrams: List[Dict[str, typing.Union[str, int]]]) -> 'asyncio.Future[SSDPDatagram]':
fut: 'asyncio.Future[SSDPDatagram]' = self.loop.create_future() fut: 'asyncio.Future[SSDPDatagram]' = self.loop.create_future()
for datagram in datagrams: for datagram in datagrams:
packet = SSDPDatagram("M-SEARCH", datagram) packet = SSDPDatagram("M-SEARCH", datagram)
assert packet.st is not None assert packet.st is not None
self._pending_searches.append( self._send_m_search(address, packet, fut)
(address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet, fut)) return fut
)
async def m_search(self, address: str, timeout: float,
datagrams: List[Dict[str, typing.Union[str, int]]]) -> SSDPDatagram:
fut = self.send_m_searches(address, datagrams)
return await asyncio.wait_for(fut, timeout, loop=self.loop) return await asyncio.wait_for(fut, timeout, loop=self.loop)
def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None: # type: ignore def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None: # type: ignore
@ -95,7 +116,6 @@ class SSDPProtocol(MulticastProtocol):
log.warning("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err, log.warning("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err,
binascii.hexlify(data)) binascii.hexlify(data))
return None return None
if packet._packet_type == packet._OK: if packet._packet_type == packet._OK:
self._callback_m_search_ok(addr[0], packet) self._callback_m_search_ok(addr[0], packet)
return None return None
@ -120,13 +140,12 @@ class SSDPProtocol(MulticastProtocol):
async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asyncio.AbstractEventLoop] = None, async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asyncio.AbstractEventLoop] = None,
ignored: Optional[Set[str]] = None, ignored: Optional[Set[str]] = None) -> Tuple[SSDPProtocol, str, str]:
unicast: bool = False) -> Tuple[SSDPProtocol, str, str]:
loop = loop or asyncio.get_event_loop() loop = loop or asyncio.get_event_loop()
try: try:
sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address) sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address)
listen_result: Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint( listen_result: Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored), sock=sock
) )
protocol = listen_result[1] protocol = listen_result[1]
assert isinstance(protocol, SSDPProtocol) assert isinstance(protocol, SSDPProtocol)
@ -140,9 +159,9 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asy
async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[str, typing.Union[int, str]], async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[str, typing.Union[int, str]],
timeout: int = 1, loop: Optional[asyncio.AbstractEventLoop] = None, timeout: int = 1, loop: Optional[asyncio.AbstractEventLoop] = None,
ignored: Set[str] = None, unicast: bool = False) -> SSDPDatagram: ignored: Set[str] = None) -> SSDPDatagram:
protocol, gateway_address, lan_address = await listen_ssdp( protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast lan_address, gateway_address, loop, ignored
) )
try: try:
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]) return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
@ -154,56 +173,13 @@ async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[s
async def multi_m_search(lan_address: str, gateway_address: str, timeout: int = 3, async def multi_m_search(lan_address: str, gateway_address: str, timeout: int = 3,
loop: Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None,
ignored: Set[str] = None, unicast: bool = False) -> SSDPDatagram: ignored: Set[str] = None) -> SSDPProtocol:
loop = loop or asyncio.get_event_loop()
protocol, gateway_address, lan_address = await listen_ssdp( protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast lan_address, gateway_address, loop, ignored
) )
datagram_args = list(packet_generator()) fut = asyncio.ensure_future(protocol.send_m_searches(
try: address=gateway_address, datagrams=list(packet_generator())
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=datagram_args) ), loop=loop)
except asyncio.TimeoutError: loop.call_later(timeout, lambda: None if not fut or fut.done() else fut.cancel())
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) return protocol
finally:
protocol.disconnect()
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
loop: Optional[asyncio.AbstractEventLoop] = None,
ignored: Set[str] = None,
unicast: bool = False) -> List[Dict[str, typing.Union[int, str]]]:
protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast
)
await protocol.connected.wait()
packet_args = list(packet_generator())
batch_size = 2
batch_timeout = float(timeout) / float(len(packet_args))
while packet_args:
args = packet_args[:batch_size]
packet_args = packet_args[batch_size:]
log.debug("sending batch of %i M-SEARCH attempts", batch_size)
try:
await protocol.m_search(gateway_address, batch_timeout, args)
except asyncio.TimeoutError:
continue
else:
protocol.disconnect()
return args
protocol.disconnect()
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
loop: Optional[asyncio.AbstractEventLoop] = None,
ignored: Set[str] = None,
unicast: bool = False) -> Tuple[Dict[str, typing.Union[int, str]], SSDPDatagram]:
# we don't know which packet the gateway replies to, so send small batches at a time
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast)
# check the args in the batch that got a reply one at a time to see which one worked
for args in args_to_try:
try:
packet = await m_search(lan_address, gateway_address, args, 3, loop=loop, ignored=ignored, unicast=unicast)
return args, packet
except UPnPError:
continue
raise UPnPError("failed to discover gateway")

View file

@ -4,12 +4,10 @@
import logging import logging
import json import json
import asyncio import asyncio
from collections import OrderedDict
from typing import Tuple, Dict, List, Union, Optional, Any from typing import Tuple, Dict, List, Union, Optional, Any
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.gateway import Gateway from aioupnp.gateway import Gateway
from aioupnp.interfaces import get_gateway_and_lan_addresses from aioupnp.interfaces import get_gateway_and_lan_addresses
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.commands import GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse from aioupnp.commands import GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse
@ -61,7 +59,7 @@ class UPnP:
return lan_address, gateway_address return lan_address, gateway_address
@classmethod @classmethod
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 30, async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 3,
igd_args: Optional[Dict[str, Union[str, int]]] = None, interface_name: str = 'default', igd_args: Optional[Dict[str, Union[str, int]]] = None, interface_name: str = 'default',
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'UPnP': loop: Optional[asyncio.AbstractEventLoop] = None) -> 'UPnP':
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
@ -72,7 +70,7 @@ class UPnP:
@classmethod @classmethod
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
unicast: bool = True, interface_name: str = 'default', interface_name: str = 'default',
igd_args: Optional[Dict[str, Union[str, int]]] = None, igd_args: Optional[Dict[str, Union[str, int]]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None loop: Optional[asyncio.AbstractEventLoop] = None
) -> Dict[str, Union[str, Dict[str, Union[str, int]]]]: ) -> Dict[str, Union[str, Dict[str, Union[str, int]]]]:
@ -82,7 +80,6 @@ class UPnP:
:param lan_address: (str) the local interface ipv4 address :param lan_address: (str) the local interface ipv4 address
:param gateway_address: (str) the gateway ipv4 address :param gateway_address: (str) the gateway ipv4 address
:param timeout: (int) m search timeout :param timeout: (int) m search timeout
:param unicast: (bool) use unicast
:param interface_name: (str) name of the network interface :param interface_name: (str) name of the network interface
:param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided. :param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided.
@ -101,16 +98,14 @@ class UPnP:
except Exception as err: except Exception as err:
raise UPnPError("failed to get lan and gateway addresses for interface \"%s\": %s" % (interface_name, raise UPnPError("failed to get lan and gateway addresses for interface \"%s\": %s" % (interface_name,
str(err))) str(err)))
if not igd_args: gateway = await Gateway.discover_gateway(
igd_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, loop, unicast=unicast) lan_address, gateway_address, timeout, igd_args, loop
else: )
igd_args = OrderedDict(igd_args)
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, unicast=unicast)
return { return {
'lan_address': lan_address, 'lan_address': lan_address,
'gateway_address': gateway_address, 'gateway_address': gateway_address,
'm_search_kwargs': SSDPDatagram("M-SEARCH", igd_args).get_cli_igd_kwargs(), # 'm_search_kwargs': SSDPDatagram("M-SEARCH", igd_args).get_cli_igd_kwargs(),
'discover_reply': datagram.as_dict() 'discover_reply': gateway._ok_packet.as_dict()
} }
async def get_external_ip(self) -> str: async def get_external_ip(self) -> str:
@ -372,20 +367,20 @@ cli_commands = [
def run_cli(method: str, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '', def run_cli(method: str, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default', gateway_address: str = '', timeout: int = 3, interface_name: str = 'default',
unicast: bool = True, kwargs: Optional[Dict[str, str]] = None, kwargs: Optional[Dict[str, str]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None: loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
kwargs = kwargs or {} kwargs = kwargs or {}
igd_args = igd_args igd_args = igd_args
timeout = int(timeout) timeout = int(timeout)
loop = loop or asyncio.get_event_loop() loop = loop or asyncio.get_event_loop()
fut: 'asyncio.Future' = asyncio.Future(loop=loop) fut: 'asyncio.Future' = loop.create_future()
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
if method == 'm_search': # if we're only m_searching don't do any device discovery if method == 'm_search': # if we're only m_searching don't do any device discovery
fn = lambda *_a, **_kw: UPnP.m_search( fn = lambda *_a, **_kw: UPnP.m_search(
lan_address, gateway_address, timeout, unicast, interface_name, igd_args, loop lan_address, gateway_address, timeout, interface_name, igd_args, loop
) )
else: # automatically discover the gateway else: # automatically discover the gateway
try: try: