Faster gateway discovery #18
5 changed files with 150 additions and 177 deletions
|
@ -75,8 +75,7 @@ def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
|
|||
'interface': 'default',
|
||||
'gateway_address': '',
|
||||
'lan_address': '',
|
||||
'timeout': 30,
|
||||
'unicast': False
|
||||
'timeout': 3,
|
||||
}
|
||||
|
||||
options: typing.Dict[str, typing.Union[bool, str, int]] = OrderedDict()
|
||||
|
@ -114,10 +113,9 @@ def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
|
|||
gateway_address: str = str(options.pop('gateway_address'))
|
||||
timeout: int = int(options.pop('timeout'))
|
||||
interface: str = str(options.pop('interface'))
|
||||
unicast: bool = bool(options.pop('unicast'))
|
||||
|
||||
run_cli(
|
||||
command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop
|
||||
command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, kwargs, loop
|
||||
)
|
||||
return 0
|
||||
|
||||
|
|
|
@ -2,13 +2,12 @@ import re
|
|||
import logging
|
||||
import typing
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
from aioupnp.util import get_dict_val_case_insensitive
|
||||
from aioupnp.constants import SPEC_VERSION, SERVICE
|
||||
from aioupnp.commands import SOAPCommands, SCPDRequestDebuggingInfo
|
||||
from aioupnp.commands import SOAPCommands
|
||||
from aioupnp.device import Device, Service
|
||||
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search, multi_m_search
|
||||
from aioupnp.protocols.ssdp import m_search, multi_m_search
|
||||
from aioupnp.protocols.scpd import scpd_get
|
||||
from aioupnp.serialization.ssdp import SSDPDatagram
|
||||
from aioupnp.util import flatten_keys
|
||||
|
@ -70,7 +69,7 @@ def parse_location(location: bytes) -> typing.Tuple[bytes, int]:
|
|||
|
||||
class Gateway:
|
||||
def __init__(self, ok_packet: SSDPDatagram, lan_address: str, gateway_address: str,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
self._ok_packet = ok_packet
|
||||
self._lan_address = lan_address
|
||||
|
@ -90,10 +89,10 @@ class Gateway:
|
|||
assert self.base_ip == gateway_address.encode()
|
||||
self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1]
|
||||
|
||||
self.spec_version: typing.Optional[str] = None
|
||||
self.url_base: typing.Optional[str] = None
|
||||
self.spec_version: Optional[str] = None
|
||||
self.url_base: Optional[str] = None
|
||||
|
||||
self._device: typing.Optional[Device] = None
|
||||
self._device: Optional[Device] = None
|
||||
self._devices: List[Device] = []
|
||||
self._services: List[Service] = []
|
||||
|
||||
|
@ -128,7 +127,7 @@ class Gateway:
|
|||
devices[device.udn] = device
|
||||
return devices
|
||||
|
||||
# def get_service(self, service_type: str) -> typing.Optional[Service]:
|
||||
# def get_service(self, service_type: str) -> Optional[Service]:
|
||||
# for service in self._services:
|
||||
# if service.serviceType and service.serviceType.lower() == service_type.lower():
|
||||
# return service
|
||||
|
@ -155,73 +154,78 @@ class Gateway:
|
|||
}
|
||||
|
||||
@classmethod
|
||||
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
|
||||
unicast: bool = False) -> 'Gateway':
|
||||
ignored: typing.Set[str] = set()
|
||||
async def _try_gateway_from_ssdp(cls, datagram: SSDPDatagram, lan_address: str,
|
||||
gateway_address: str,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> Optional['Gateway']:
|
||||
required_commands: typing.List[str] = [
|
||||
'AddPortMapping',
|
||||
'DeletePortMapping',
|
||||
'GetExternalIPAddress'
|
||||
]
|
||||
while True:
|
||||
if not igd_args:
|
||||
datagram = await multi_m_search(
|
||||
lan_address, gateway_address, timeout, loop, ignored, unicast
|
||||
)
|
||||
else:
|
||||
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, ignored, unicast)
|
||||
try:
|
||||
gateway = cls(datagram, lan_address, gateway_address, loop=loop)
|
||||
log.debug('get gateway descriptor %s', datagram.location)
|
||||
await gateway.discover_commands()
|
||||
requirements_met = all([gateway.commands.is_registered(required) for required in required_commands])
|
||||
if not requirements_met:
|
||||
not_met = [
|
||||
required for required in required_commands if not gateway.commands.is_registered(required)
|
||||
]
|
||||
assert datagram.location is not None
|
||||
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
|
||||
gateway.manufacturer_string, gateway.location, not_met)
|
||||
ignored.add(datagram.location)
|
||||
continue
|
||||
else:
|
||||
log.debug('found gateway %s at %s', gateway.manufacturer_string or "device", datagram.location)
|
||||
return gateway
|
||||
except (asyncio.TimeoutError, UPnPError) as err:
|
||||
try:
|
||||
gateway = cls(datagram, lan_address, gateway_address, loop=loop)
|
||||
log.debug('get gateway descriptor %s', datagram.location)
|
||||
await gateway.discover_commands()
|
||||
requirements_met = all([gateway.commands.is_registered(required) for required in required_commands])
|
||||
if not requirements_met:
|
||||
not_met = [
|
||||
required for required in required_commands if not gateway.commands.is_registered(required)
|
||||
]
|
||||
assert datagram.location is not None
|
||||
log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err))
|
||||
ignored.add(datagram.location)
|
||||
continue
|
||||
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
|
||||
gateway.manufacturer_string, gateway.location, not_met)
|
||||
return None
|
||||
else:
|
||||
log.debug('found gateway %s at %s', gateway.manufacturer_string or "device", datagram.location)
|
||||
return gateway
|
||||
except (asyncio.TimeoutError, UPnPError) as err:
|
||||
assert datagram.location is not None
|
||||
log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err))
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
|
||||
unicast: typing.Optional[bool] = None) -> 'Gateway':
|
||||
async def _gateway_from_igd_args(cls, lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
igd_args: Optional[typing.Dict[str, typing.Union[int, str]]] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
|
||||
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, set())
|
||||
gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop)
|
||||
if not gateway:
|
||||
raise UPnPError("no gateway found for given args")
|
||||
return gateway
|
||||
|
||||
@classmethod
|
||||
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 3,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
|
||||
ignored: typing.Set[str] = set()
|
||||
ssdp_proto = await multi_m_search(
|
||||
lan_address, gateway_address, timeout, loop, ignored
|
||||
)
|
||||
try:
|
||||
while True:
|
||||
datagram = await ssdp_proto.devices.get()
|
||||
if datagram.location in ignored:
|
||||
continue
|
||||
gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop)
|
||||
if gateway:
|
||||
return gateway
|
||||
else:
|
||||
ignored.add(datagram.location)
|
||||
finally:
|
||||
ssdp_proto.disconnect()
|
||||
|
||||
@classmethod
|
||||
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 3,
|
||||
igd_args: Optional[typing.Dict[str, typing.Union[int, str]]] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
if unicast is not None:
|
||||
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast)
|
||||
|
||||
done, pending = await asyncio.wait([
|
||||
cls._discover_gateway(
|
||||
lan_address, gateway_address, timeout, igd_args, loop, unicast=True
|
||||
),
|
||||
cls._discover_gateway(
|
||||
lan_address, gateway_address, timeout, igd_args, loop, unicast=False
|
||||
)
|
||||
], return_when=asyncio.tasks.FIRST_COMPLETED, loop=loop)
|
||||
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
for task in done:
|
||||
try:
|
||||
task.exception()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
results: typing.List['asyncio.Future[Gateway]'] = list(done)
|
||||
return results[0].result()
|
||||
if igd_args:
|
||||
return await cls._gateway_from_igd_args(lan_address, gateway_address, timeout, igd_args, loop)
|
||||
try:
|
||||
return await asyncio.wait_for(loop.create_task(
|
||||
cls._discover_gateway(lan_address, gateway_address, timeout, loop)
|
||||
), timeout, loop=loop)
|
||||
except asyncio.TimeoutError:
|
||||
raise UPnPError(f"M-SEARCH for {gateway_address}:1900 timed out")
|
||||
|
||||
async def discover_commands(self) -> None:
|
||||
response, xml_bytes, get_err = await scpd_get(
|
||||
|
@ -266,7 +270,7 @@ class Gateway:
|
|||
return None
|
||||
|
||||
async def register_commands(self, service: Service,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
if not service.SCPDURL:
|
||||
raise UPnPError("no scpd url")
|
||||
if not service.serviceType:
|
||||
|
|
|
@ -105,7 +105,7 @@ async def scpd_get(control_url: str, address: str, port: int,
|
|||
typing.Dict[str, typing.Any], bytes, typing.Optional[Exception]]:
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
packet = serialize_scpd_get(control_url, address)
|
||||
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = asyncio.Future(loop=loop)
|
||||
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future()
|
||||
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished)
|
||||
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
|
||||
proto_factory, address, port
|
||||
|
@ -141,7 +141,7 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para
|
|||
**kwargs: typing.Dict[str, typing.Any]
|
||||
) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]:
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = asyncio.Future(loop=loop)
|
||||
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future()
|
||||
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
|
||||
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\
|
||||
SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode())
|
||||
|
|
|
@ -17,17 +17,23 @@ ADDRESS_REGEX = re.compile("^http:\/\/(\d+\.\d+\.\d+\.\d+)\:(\d*)(\/[\w|\/|\:|\-
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PendingSearch(typing.NamedTuple):
|
||||
address: str
|
||||
st: str
|
||||
fut: 'asyncio.Future[SSDPDatagram]'
|
||||
|
||||
|
||||
class SSDPProtocol(MulticastProtocol):
|
||||
def __init__(self, multicast_address: str, lan_address: str, ignored: Optional[Set[str]] = None,
|
||||
unicast: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
super().__init__(multicast_address, lan_address)
|
||||
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
|
||||
self.transport: Optional[DatagramTransport] = None
|
||||
self._unicast = unicast
|
||||
self._ignored: Set[str] = ignored or set() # ignored locations
|
||||
self._pending_searches: List[Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = []
|
||||
self._pending_searches: List[PendingSearch] = []
|
||||
self.notifications: List[SSDPDatagram] = []
|
||||
self.connected = asyncio.Event(loop=self.loop)
|
||||
self.devices: 'asyncio.Queue[SSDPDatagram]' = asyncio.Queue(loop=self.loop)
|
||||
|
||||
def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore
|
||||
super().connection_made(transport)
|
||||
|
@ -46,43 +52,58 @@ class SSDPProtocol(MulticastProtocol):
|
|||
return None
|
||||
|
||||
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
|
||||
if packet.location not in self._ignored:
|
||||
# TODO: fix this
|
||||
tmp: List[Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = []
|
||||
set_futures: List['asyncio.Future[SSDPDatagram]'] = []
|
||||
while len(self._pending_searches):
|
||||
t = self._pending_searches.pop()
|
||||
if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]):
|
||||
f = t[2]
|
||||
if f not in set_futures:
|
||||
set_futures.append(f)
|
||||
if not f.done():
|
||||
f.set_result(packet)
|
||||
elif t[2] not in set_futures:
|
||||
tmp.append(t)
|
||||
while tmp:
|
||||
self._pending_searches.append(tmp.pop())
|
||||
return None
|
||||
if packet.location in self._ignored:
|
||||
return
|
||||
|
||||
futures: Set['asyncio.Future[SSDPDatagram]'] = set()
|
||||
replied: List[PendingSearch] = []
|
||||
|
||||
for pending in self._pending_searches:
|
||||
# if pending.address == address and pending.st in (packet.st, "upnp:rootdevice"):
|
||||
if pending.address == address and pending.st == packet.st:
|
||||
replied.append(pending)
|
||||
if pending.fut not in futures:
|
||||
futures.add(pending.fut)
|
||||
if replied:
|
||||
self.devices.put_nowait(packet)
|
||||
|
||||
while replied:
|
||||
self._pending_searches.remove(replied.pop())
|
||||
|
||||
while futures:
|
||||
fut = futures.pop()
|
||||
if not fut.done():
|
||||
fut.set_result(packet)
|
||||
|
||||
def _send_m_search(self, address: str, packet: SSDPDatagram, fut: 'asyncio.Future[SSDPDatagram]') -> None:
|
||||
dest = address if self._unicast else SSDP_IP_ADDRESS
|
||||
if not self.transport:
|
||||
if not fut.done():
|
||||
fut.set_exception(UPnPError("SSDP transport not connected"))
|
||||
return None
|
||||
log.debug("send m search to %s: %s", dest, packet.st)
|
||||
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
|
||||
return None
|
||||
return
|
||||
if fut.done():
|
||||
return
|
||||
self._pending_searches.append(
|
||||
PendingSearch(address, packet.st, fut)
|
||||
)
|
||||
|
||||
async def m_search(self, address: str, timeout: float,
|
||||
datagrams: List[Dict[str, typing.Union[str, int]]]) -> SSDPDatagram:
|
||||
self.transport.sendto(packet.encode().encode(), (SSDP_IP_ADDRESS, SSDP_PORT))
|
||||
|
||||
# also send unicast
|
||||
log.debug("send m search to %s: %s", address, packet.st)
|
||||
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT))
|
||||
|
||||
def send_m_searches(self, address: str,
|
||||
datagrams: List[Dict[str, typing.Union[str, int]]]) -> 'asyncio.Future[SSDPDatagram]':
|
||||
fut: 'asyncio.Future[SSDPDatagram]' = self.loop.create_future()
|
||||
for datagram in datagrams:
|
||||
packet = SSDPDatagram("M-SEARCH", datagram)
|
||||
assert packet.st is not None
|
||||
self._pending_searches.append(
|
||||
(address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet, fut))
|
||||
)
|
||||
self._send_m_search(address, packet, fut)
|
||||
return fut
|
||||
|
||||
async def m_search(self, address: str, timeout: float,
|
||||
datagrams: List[Dict[str, typing.Union[str, int]]]) -> SSDPDatagram:
|
||||
fut = self.send_m_searches(address, datagrams)
|
||||
return await asyncio.wait_for(fut, timeout, loop=self.loop)
|
||||
|
||||
def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None: # type: ignore
|
||||
|
@ -95,7 +116,6 @@ class SSDPProtocol(MulticastProtocol):
|
|||
log.warning("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err,
|
||||
binascii.hexlify(data))
|
||||
return None
|
||||
|
||||
if packet._packet_type == packet._OK:
|
||||
self._callback_m_search_ok(addr[0], packet)
|
||||
return None
|
||||
|
@ -120,13 +140,12 @@ class SSDPProtocol(MulticastProtocol):
|
|||
|
||||
|
||||
async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
ignored: Optional[Set[str]] = None,
|
||||
unicast: bool = False) -> Tuple[SSDPProtocol, str, str]:
|
||||
ignored: Optional[Set[str]] = None) -> Tuple[SSDPProtocol, str, str]:
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
try:
|
||||
sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address)
|
||||
listen_result: Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
|
||||
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
|
||||
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored), sock=sock
|
||||
)
|
||||
protocol = listen_result[1]
|
||||
assert isinstance(protocol, SSDPProtocol)
|
||||
|
@ -140,9 +159,9 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asy
|
|||
|
||||
async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[str, typing.Union[int, str]],
|
||||
timeout: int = 1, loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
ignored: Set[str] = None, unicast: bool = False) -> SSDPDatagram:
|
||||
ignored: Set[str] = None) -> SSDPDatagram:
|
||||
protocol, gateway_address, lan_address = await listen_ssdp(
|
||||
lan_address, gateway_address, loop, ignored, unicast
|
||||
lan_address, gateway_address, loop, ignored
|
||||
)
|
||||
try:
|
||||
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
|
||||
|
@ -154,56 +173,13 @@ async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[s
|
|||
|
||||
async def multi_m_search(lan_address: str, gateway_address: str, timeout: int = 3,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
ignored: Set[str] = None, unicast: bool = False) -> SSDPDatagram:
|
||||
ignored: Set[str] = None) -> SSDPProtocol:
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
protocol, gateway_address, lan_address = await listen_ssdp(
|
||||
lan_address, gateway_address, loop, ignored, unicast
|
||||
lan_address, gateway_address, loop, ignored
|
||||
)
|
||||
datagram_args = list(packet_generator())
|
||||
try:
|
||||
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=datagram_args)
|
||||
except asyncio.TimeoutError:
|
||||
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
|
||||
finally:
|
||||
protocol.disconnect()
|
||||
|
||||
|
||||
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
ignored: Set[str] = None,
|
||||
unicast: bool = False) -> List[Dict[str, typing.Union[int, str]]]:
|
||||
protocol, gateway_address, lan_address = await listen_ssdp(
|
||||
lan_address, gateway_address, loop, ignored, unicast
|
||||
)
|
||||
await protocol.connected.wait()
|
||||
packet_args = list(packet_generator())
|
||||
batch_size = 2
|
||||
batch_timeout = float(timeout) / float(len(packet_args))
|
||||
while packet_args:
|
||||
args = packet_args[:batch_size]
|
||||
packet_args = packet_args[batch_size:]
|
||||
log.debug("sending batch of %i M-SEARCH attempts", batch_size)
|
||||
try:
|
||||
await protocol.m_search(gateway_address, batch_timeout, args)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
else:
|
||||
protocol.disconnect()
|
||||
return args
|
||||
protocol.disconnect()
|
||||
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
|
||||
|
||||
|
||||
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
ignored: Set[str] = None,
|
||||
unicast: bool = False) -> Tuple[Dict[str, typing.Union[int, str]], SSDPDatagram]:
|
||||
# we don't know which packet the gateway replies to, so send small batches at a time
|
||||
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast)
|
||||
# check the args in the batch that got a reply one at a time to see which one worked
|
||||
for args in args_to_try:
|
||||
try:
|
||||
packet = await m_search(lan_address, gateway_address, args, 3, loop=loop, ignored=ignored, unicast=unicast)
|
||||
return args, packet
|
||||
except UPnPError:
|
||||
continue
|
||||
raise UPnPError("failed to discover gateway")
|
||||
fut = asyncio.ensure_future(protocol.send_m_searches(
|
||||
address=gateway_address, datagrams=list(packet_generator())
|
||||
), loop=loop)
|
||||
loop.call_later(timeout, lambda: None if not fut or fut.done() else fut.cancel())
|
||||
return protocol
|
||||
|
|
|
@ -4,12 +4,10 @@
|
|||
import logging
|
||||
import json
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from typing import Tuple, Dict, List, Union, Optional, Any
|
||||
from aioupnp.fault import UPnPError
|
||||
from aioupnp.gateway import Gateway
|
||||
from aioupnp.interfaces import get_gateway_and_lan_addresses
|
||||
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
|
||||
from aioupnp.serialization.ssdp import SSDPDatagram
|
||||
from aioupnp.commands import GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse
|
||||
|
||||
|
@ -61,7 +59,7 @@ class UPnP:
|
|||
return lan_address, gateway_address
|
||||
|
||||
@classmethod
|
||||
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 30,
|
||||
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 3,
|
||||
igd_args: Optional[Dict[str, Union[str, int]]] = None, interface_name: str = 'default',
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'UPnP':
|
||||
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
|
||||
|
@ -72,7 +70,7 @@ class UPnP:
|
|||
|
||||
@classmethod
|
||||
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
|
||||
unicast: bool = True, interface_name: str = 'default',
|
||||
interface_name: str = 'default',
|
||||
igd_args: Optional[Dict[str, Union[str, int]]] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
) -> Dict[str, Union[str, Dict[str, Union[str, int]]]]:
|
||||
|
@ -82,7 +80,6 @@ class UPnP:
|
|||
:param lan_address: (str) the local interface ipv4 address
|
||||
:param gateway_address: (str) the gateway ipv4 address
|
||||
:param timeout: (int) m search timeout
|
||||
:param unicast: (bool) use unicast
|
||||
:param interface_name: (str) name of the network interface
|
||||
:param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided.
|
||||
|
||||
|
@ -101,16 +98,14 @@ class UPnP:
|
|||
except Exception as err:
|
||||
raise UPnPError("failed to get lan and gateway addresses for interface \"%s\": %s" % (interface_name,
|
||||
str(err)))
|
||||
if not igd_args:
|
||||
igd_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, loop, unicast=unicast)
|
||||
else:
|
||||
igd_args = OrderedDict(igd_args)
|
||||
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, unicast=unicast)
|
||||
gateway = await Gateway.discover_gateway(
|
||||
lan_address, gateway_address, timeout, igd_args, loop
|
||||
)
|
||||
return {
|
||||
'lan_address': lan_address,
|
||||
'gateway_address': gateway_address,
|
||||
'm_search_kwargs': SSDPDatagram("M-SEARCH", igd_args).get_cli_igd_kwargs(),
|
||||
'discover_reply': datagram.as_dict()
|
||||
# 'm_search_kwargs': SSDPDatagram("M-SEARCH", igd_args).get_cli_igd_kwargs(),
|
||||
'discover_reply': gateway._ok_packet.as_dict()
|
||||
}
|
||||
|
||||
async def get_external_ip(self) -> str:
|
||||
|
@ -372,20 +367,20 @@ cli_commands = [
|
|||
|
||||
|
||||
def run_cli(method: str, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
|
||||
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
|
||||
unicast: bool = True, kwargs: Optional[Dict[str, str]] = None,
|
||||
gateway_address: str = '', timeout: int = 3, interface_name: str = 'default',
|
||||
kwargs: Optional[Dict[str, str]] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
|
||||
kwargs = kwargs or {}
|
||||
igd_args = igd_args
|
||||
timeout = int(timeout)
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
fut: 'asyncio.Future' = asyncio.Future(loop=loop)
|
||||
fut: 'asyncio.Future' = loop.create_future()
|
||||
|
||||
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
|
||||
if method == 'm_search': # if we're only m_searching don't do any device discovery
|
||||
fn = lambda *_a, **_kw: UPnP.m_search(
|
||||
lan_address, gateway_address, timeout, unicast, interface_name, igd_args, loop
|
||||
lan_address, gateway_address, timeout, interface_name, igd_args, loop
|
||||
)
|
||||
else: # automatically discover the gateway
|
||||
try:
|
||||
|
|
Loading…
Reference in a new issue