Merge pull request #18 from lbryio/fast-discover

Faster gateway discovery
This commit is contained in:
Jack Robison 2019-10-25 16:12:40 -04:00 committed by GitHub
commit b3b020e01b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 245 additions and 244 deletions

View file

@ -75,8 +75,7 @@ def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
'interface': 'default',
'gateway_address': '',
'lan_address': '',
'timeout': 30,
'unicast': False
'timeout': 3,
}
options: typing.Dict[str, typing.Union[bool, str, int]] = OrderedDict()
@ -114,10 +113,9 @@ def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
gateway_address: str = str(options.pop('gateway_address'))
timeout: int = int(options.pop('timeout'))
interface: str = str(options.pop('interface'))
unicast: bool = bool(options.pop('unicast'))
run_cli(
command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop
command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, kwargs, loop
)
return 0

View file

@ -2,13 +2,12 @@ import re
import logging
import typing
import asyncio
from collections import OrderedDict
from typing import Dict, List
from typing import Dict, List, Optional
from aioupnp.util import get_dict_val_case_insensitive
from aioupnp.constants import SPEC_VERSION, SERVICE
from aioupnp.commands import SOAPCommands, SCPDRequestDebuggingInfo
from aioupnp.commands import SOAPCommands
from aioupnp.device import Device, Service
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search
from aioupnp.protocols.ssdp import m_search, multi_m_search
from aioupnp.protocols.scpd import scpd_get
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.util import flatten_keys
@ -69,12 +68,10 @@ def parse_location(location: bytes) -> typing.Tuple[bytes, int]:
class Gateway:
def __init__(self, ok_packet: SSDPDatagram, m_search_args: typing.Dict[str, typing.Union[int, str]],
lan_address: str, gateway_address: str,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
def __init__(self, ok_packet: SSDPDatagram, lan_address: str, gateway_address: str,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self._loop = loop or asyncio.get_event_loop()
self._ok_packet = ok_packet
self._m_search_args = m_search_args
self._lan_address = lan_address
self.usn: bytes = (ok_packet.usn or '').encode()
self.ext: bytes = (ok_packet.ext or '').encode()
@ -92,10 +89,10 @@ class Gateway:
assert self.base_ip == gateway_address.encode()
self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1]
self.spec_version: typing.Optional[str] = None
self.url_base: typing.Optional[str] = None
self.spec_version: Optional[str] = None
self.url_base: Optional[str] = None
self._device: typing.Optional[Device] = None
self._device: Optional[Device] = None
self._devices: List[Device] = []
self._services: List[Service] = []
@ -130,7 +127,7 @@ class Gateway:
devices[device.udn] = device
return devices
# def get_service(self, service_type: str) -> typing.Optional[Service]:
# def get_service(self, service_type: str) -> Optional[Service]:
# for service in self._services:
# if service.serviceType and service.serviceType.lower() == service_type.lower():
# return service
@ -149,7 +146,6 @@ class Gateway:
'gateway_xml': self._xml_response.decode(),
'services_xml': self._service_descriptors,
'services': {service.SCPDURL: service.as_dict() for service in self._services},
'm_search_args': OrderedDict(self._m_search_args),
'reply': self._ok_packet.as_dict(),
'soap_port': self.port,
'registered_soap_commands': self._registered_commands,
@ -158,74 +154,79 @@ class Gateway:
}
@classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
unicast: bool = False) -> 'Gateway':
ignored: typing.Set[str] = set()
async def _try_gateway_from_ssdp(cls, datagram: SSDPDatagram, lan_address: str,
gateway_address: str,
loop: Optional[asyncio.AbstractEventLoop] = None) -> Optional['Gateway']:
required_commands: typing.List[str] = [
'AddPortMapping',
'DeletePortMapping',
'GetExternalIPAddress'
]
while True:
if not igd_args:
m_search_args, datagram = await fuzzy_m_search(
lan_address, gateway_address, timeout, loop, ignored, unicast
)
else:
m_search_args = OrderedDict(igd_args)
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, ignored, unicast)
try:
gateway = cls(datagram, m_search_args, lan_address, gateway_address, loop=loop)
log.debug('get gateway descriptor %s', datagram.location)
await gateway.discover_commands()
requirements_met = all([gateway.commands.is_registered(required) for required in required_commands])
if not requirements_met:
not_met = [
required for required in required_commands if not gateway.commands.is_registered(required)
]
assert datagram.location is not None
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
gateway.manufacturer_string, gateway.location, not_met)
ignored.add(datagram.location)
continue
else:
log.debug('found gateway %s at %s', gateway.manufacturer_string or "device", datagram.location)
return gateway
except (asyncio.TimeoutError, UPnPError) as err:
try:
gateway = cls(datagram, lan_address, gateway_address, loop=loop)
log.debug('get gateway descriptor %s', datagram.location)
await gateway.discover_commands()
requirements_met = all([gateway.commands.is_registered(required) for required in required_commands])
if not requirements_met:
not_met = [
required for required in required_commands if not gateway.commands.is_registered(required)
]
assert datagram.location is not None
log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err))
ignored.add(datagram.location)
continue
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
gateway.manufacturer_string, gateway.location, not_met)
return None
else:
log.debug('found gateway %s at %s', gateway.manufacturer_string or "device", datagram.location)
return gateway
except (asyncio.TimeoutError, UPnPError) as err:
assert datagram.location is not None
log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err))
return None
@classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
unicast: typing.Optional[bool] = None) -> 'Gateway':
async def _gateway_from_igd_args(cls, lan_address: str, gateway_address: str,
igd_args: typing.Dict[str, typing.Union[int, str]],
timeout: int = 30,
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop)
gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop)
if not gateway:
raise UPnPError("no gateway found for given args")
return gateway
@classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 3,
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
ignored: typing.Set[str] = set()
ssdp_proto = await multi_m_search(
lan_address, gateway_address, timeout, loop
)
try:
while True:
datagram = await ssdp_proto.devices.get()
if datagram.location in ignored:
continue
gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop)
if gateway:
return gateway
elif datagram.location:
ignored.add(datagram.location)
finally:
ssdp_proto.disconnect()
@classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 3,
igd_args: Optional[typing.Dict[str, typing.Union[int, str]]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
loop = loop or asyncio.get_event_loop()
if unicast is not None:
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast)
done, pending = await asyncio.wait([
cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop, unicast=True
),
cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop, unicast=False
)
], return_when=asyncio.tasks.FIRST_COMPLETED, loop=loop)
for task in pending:
task.cancel()
for task in done:
try:
task.exception()
except asyncio.CancelledError:
pass
results: typing.List['asyncio.Future[Gateway]'] = list(done)
return results[0].result()
if igd_args:
return await cls._gateway_from_igd_args(lan_address, gateway_address, igd_args, timeout, loop)
try:
return await asyncio.wait_for(loop.create_task(
cls._discover_gateway(lan_address, gateway_address, timeout, loop)
), timeout, loop=loop)
except asyncio.TimeoutError:
raise UPnPError(f"M-SEARCH for {gateway_address}:1900 timed out")
async def discover_commands(self) -> None:
response, xml_bytes, get_err = await scpd_get(
@ -270,7 +271,7 @@ class Gateway:
return None
async def register_commands(self, service: Service,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
if not service.SCPDURL:
raise UPnPError("no scpd url")
if not service.serviceType:

View file

@ -105,7 +105,7 @@ async def scpd_get(control_url: str, address: str, port: int,
typing.Dict[str, typing.Any], bytes, typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop()
packet = serialize_scpd_get(control_url, address)
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = asyncio.Future(loop=loop)
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future()
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished)
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
proto_factory, address, port
@ -141,7 +141,7 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para
**kwargs: typing.Dict[str, typing.Any]
) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop()
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = asyncio.Future(loop=loop)
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future()
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\
SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode())

View file

@ -4,6 +4,7 @@ import asyncio
import logging
import typing
import socket
from typing import List, Set, Dict, Tuple, Optional
from asyncio.transports import DatagramTransport
from aioupnp.fault import UPnPError
from aioupnp.serialization.ssdp import SSDPDatagram
@ -16,17 +17,22 @@ ADDRESS_REGEX = re.compile("^http:\/\/(\d+\.\d+\.\d+\.\d+)\:(\d*)(\/[\w|\/|\:|\-
log = logging.getLogger(__name__)
class PendingSearch(typing.NamedTuple):
address: str
st: str
fut: 'asyncio.Future[SSDPDatagram]'
class SSDPProtocol(MulticastProtocol):
def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Optional[typing.Set[str]] = None,
unicast: bool = False, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
def __init__(self, multicast_address: str, lan_address: str,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(multicast_address, lan_address)
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
self.transport: typing.Optional[DatagramTransport] = None
self._unicast = unicast
self._ignored: typing.Set[str] = ignored or set() # ignored locations
self._pending_searches: typing.List[typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = []
self.notifications: typing.List[SSDPDatagram] = []
self.transport: Optional[DatagramTransport] = None
self._pending_searches: List[PendingSearch] = []
self.notifications: List[SSDPDatagram] = []
self.connected = asyncio.Event(loop=self.loop)
self.devices: 'asyncio.Queue[SSDPDatagram]' = asyncio.Queue(loop=self.loop)
def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore
super().connection_made(transport)
@ -45,46 +51,56 @@ class SSDPProtocol(MulticastProtocol):
return None
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
if packet.location not in self._ignored:
# TODO: fix this
tmp: typing.List[typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = []
set_futures: typing.List['asyncio.Future[SSDPDatagram]'] = []
while len(self._pending_searches):
t = self._pending_searches.pop()
if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]):
f = t[2]
if f not in set_futures:
set_futures.append(f)
if not f.done():
f.set_result(packet)
elif t[2] not in set_futures:
tmp.append(t)
while tmp:
self._pending_searches.append(tmp.pop())
return None
futures: Set['asyncio.Future[SSDPDatagram]'] = set()
replied: List[PendingSearch] = []
for pending in self._pending_searches:
# if pending.address == address and pending.st in (packet.st, "upnp:rootdevice"):
if pending.address == address and pending.st == packet.st:
replied.append(pending)
if pending.fut not in futures:
futures.add(pending.fut)
if replied:
self.devices.put_nowait(packet)
while replied:
self._pending_searches.remove(replied.pop())
while futures:
fut = futures.pop()
if not fut.done():
fut.set_result(packet)
def _send_m_search(self, address: str, packet: SSDPDatagram, fut: 'asyncio.Future[SSDPDatagram]') -> None:
dest = address if self._unicast else SSDP_IP_ADDRESS
if not self.transport:
if not fut.done():
fut.set_exception(UPnPError("SSDP transport not connected"))
return None
log.debug("send m search to %s: %s", dest, packet.st)
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
return None
return
assert packet.st is not None
self._pending_searches.append(
PendingSearch(address, packet.st, fut)
)
self.transport.sendto(packet.encode().encode(), (SSDP_IP_ADDRESS, SSDP_PORT))
async def m_search(self, address: str, timeout: float,
datagrams: typing.List[typing.Dict[str, typing.Union[str, int]]]) -> SSDPDatagram:
fut: 'asyncio.Future[SSDPDatagram]' = asyncio.Future(loop=self.loop)
# also send unicast
log.debug("send m search to %s: %s", address, packet.st)
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT))
def send_m_searches(self, address: str,
datagrams: List[Dict[str, typing.Union[str, int]]]) -> 'asyncio.Future[SSDPDatagram]':
fut: 'asyncio.Future[SSDPDatagram]' = self.loop.create_future()
for datagram in datagrams:
packet = SSDPDatagram("M-SEARCH", datagram)
assert packet.st is not None
self._pending_searches.append(
(address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet, fut))
)
return await asyncio.wait_for(fut, timeout)
self._send_m_search(address, packet, fut)
return fut
def datagram_received(self, data: bytes, addr: typing.Tuple[str, int]) -> None: # type: ignore
async def m_search(self, address: str, timeout: float,
datagrams: List[Dict[str, typing.Union[str, int]]]) -> SSDPDatagram:
fut = self.send_m_searches(address, datagrams)
return await asyncio.wait_for(fut, timeout, loop=self.loop)
def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None: # type: ignore
if addr[0] == self.bind_address:
return None
try:
@ -94,7 +110,6 @@ class SSDPProtocol(MulticastProtocol):
log.warning("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err,
binascii.hexlify(data))
return None
if packet._packet_type == packet._OK:
self._callback_m_search_ok(addr[0], packet)
return None
@ -118,14 +133,13 @@ class SSDPProtocol(MulticastProtocol):
# return
async def listen_ssdp(lan_address: str, gateway_address: str, loop: typing.Optional[asyncio.AbstractEventLoop] = None,
ignored: typing.Optional[typing.Set[str]] = None,
unicast: bool = False) -> typing.Tuple[SSDPProtocol, str, str]:
async def listen_ssdp(lan_address: str, gateway_address: str,
loop: Optional[asyncio.AbstractEventLoop] = None) -> Tuple[SSDPProtocol, str, str]:
loop = loop or asyncio.get_event_loop()
try:
sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address)
listen_result: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
listen_result: Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address), sock=sock
)
protocol = listen_result[1]
assert isinstance(protocol, SSDPProtocol)
@ -137,58 +151,28 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop: typing.Optio
return protocol, gateway_address, lan_address
async def m_search(lan_address: str, gateway_address: str, datagram_args: typing.Dict[str, typing.Union[int, str]],
timeout: int = 1, loop: typing.Optional[asyncio.AbstractEventLoop] = None,
ignored: typing.Set[str] = None, unicast: bool = False) -> SSDPDatagram:
async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[str, typing.Union[int, str]],
timeout: int = 1,
loop: Optional[asyncio.AbstractEventLoop] = None) -> SSDPDatagram:
protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast
lan_address, gateway_address, loop
)
try:
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
except (asyncio.TimeoutError, asyncio.CancelledError):
except asyncio.TimeoutError:
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
finally:
protocol.disconnect()
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
ignored: typing.Set[str] = None,
unicast: bool = False) -> typing.List[typing.Dict[str, typing.Union[int, str]]]:
async def multi_m_search(lan_address: str, gateway_address: str, timeout: int = 3,
loop: Optional[asyncio.AbstractEventLoop] = None) -> SSDPProtocol:
loop = loop or asyncio.get_event_loop()
protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast
lan_address, gateway_address, loop
)
await protocol.connected.wait()
packet_args = list(packet_generator())
batch_size = 2
batch_timeout = float(timeout) / float(len(packet_args))
while packet_args:
args = packet_args[:batch_size]
packet_args = packet_args[batch_size:]
log.debug("sending batch of %i M-SEARCH attempts", batch_size)
try:
await protocol.m_search(gateway_address, batch_timeout, args)
except asyncio.TimeoutError:
continue
else:
protocol.disconnect()
return args
protocol.disconnect()
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
ignored: typing.Set[str] = None,
unicast: bool = False) -> typing.Tuple[typing.Dict[str,
typing.Union[int, str]], SSDPDatagram]:
# we don't know which packet the gateway replies to, so send small batches at a time
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast)
# check the args in the batch that got a reply one at a time to see which one worked
for args in args_to_try:
try:
packet = await m_search(lan_address, gateway_address, args, 3, loop=loop, ignored=ignored, unicast=unicast)
return args, packet
except UPnPError:
continue
raise UPnPError("failed to discover gateway")
fut = asyncio.ensure_future(protocol.send_m_searches(
address=gateway_address, datagrams=list(packet_generator())
), loop=loop)
loop.call_later(timeout, lambda: None if not fut or fut.done() else fut.cancel())
return protocol

View file

@ -164,10 +164,15 @@ class SSDPDatagram:
@classmethod
def decode(cls, datagram: bytes) -> 'SSDPDatagram':
packet = cls._from_string(datagram.decode())
try:
packet = cls._from_string(datagram.decode())
except UnicodeDecodeError:
raise UPnPError(
f"failed to decode datagram: {binascii.hexlify(datagram).decode()}"
)
if packet is None:
raise UPnPError(
"failed to decode datagram: {}".format(binascii.hexlify(datagram))
f"failed to decode datagram: {binascii.hexlify(datagram).decode()}"
)
for attr_name in packet._required_fields[packet._packet_type]:
if getattr(packet, attr_name, None) is None:

View file

@ -4,12 +4,10 @@
import logging
import json
import asyncio
from collections import OrderedDict
from typing import Tuple, Dict, List, Union, Optional, Any
from aioupnp.fault import UPnPError
from aioupnp.gateway import Gateway
from aioupnp.interfaces import get_gateway_and_lan_addresses
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.commands import GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse
@ -61,7 +59,7 @@ class UPnP:
return lan_address, gateway_address
@classmethod
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 30,
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 3,
igd_args: Optional[Dict[str, Union[str, int]]] = None, interface_name: str = 'default',
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'UPnP':
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
@ -72,7 +70,7 @@ class UPnP:
@classmethod
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
unicast: bool = True, interface_name: str = 'default',
interface_name: str = 'default',
igd_args: Optional[Dict[str, Union[str, int]]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None
) -> Dict[str, Union[str, Dict[str, Union[str, int]]]]:
@ -82,7 +80,6 @@ class UPnP:
:param lan_address: (str) the local interface ipv4 address
:param gateway_address: (str) the gateway ipv4 address
:param timeout: (int) m search timeout
:param unicast: (bool) use unicast
:param interface_name: (str) name of the network interface
:param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided.
@ -101,16 +98,14 @@ class UPnP:
except Exception as err:
raise UPnPError("failed to get lan and gateway addresses for interface \"%s\": %s" % (interface_name,
str(err)))
if not igd_args:
igd_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, loop, unicast=unicast)
else:
igd_args = OrderedDict(igd_args)
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, unicast=unicast)
gateway = await Gateway.discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop
)
return {
'lan_address': lan_address,
'gateway_address': gateway_address,
'm_search_kwargs': SSDPDatagram("M-SEARCH", igd_args).get_cli_igd_kwargs(),
'discover_reply': datagram.as_dict()
# 'm_search_kwargs': SSDPDatagram("M-SEARCH", igd_args).get_cli_igd_kwargs(),
'discover_reply': gateway._ok_packet.as_dict()
}
async def get_external_ip(self) -> str:
@ -372,20 +367,20 @@ cli_commands = [
def run_cli(method: str, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
unicast: bool = True, kwargs: Optional[Dict[str, str]] = None,
gateway_address: str = '', timeout: int = 3, interface_name: str = 'default',
kwargs: Optional[Dict[str, str]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
kwargs = kwargs or {}
igd_args = igd_args
timeout = int(timeout)
loop = loop or asyncio.get_event_loop()
fut: 'asyncio.Future' = asyncio.Future(loop=loop)
fut: 'asyncio.Future' = loop.create_future()
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
if method == 'm_search': # if we're only m_searching don't do any device discovery
fn = lambda *_a, **_kw: UPnP.m_search(
lan_address, gateway_address, timeout, unicast, interface_name, igd_args, loop
lan_address, gateway_address, timeout, interface_name, igd_args, loop
)
else: # automatically discover the gateway
try:

View file

@ -16,7 +16,8 @@ except ImportError:
@contextlib.contextmanager
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False):
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False,
raise_oserror_on_bind=False):
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
udp_replies = udp_replies or {}
@ -72,7 +73,13 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r
with mock.patch('socket.socket') as mock_socket:
mock_sock = mock.Mock(spec=socket.socket)
mock_sock.setsockopt = lambda *_: None
mock_sock.bind = lambda *_: None
def bind(*_):
if raise_oserror_on_bind:
raise OSError()
return
mock_sock.bind = bind
mock_sock.setblocking = lambda *_: None
mock_sock.getsockname = lambda: "0.0.0.0"
mock_sock.getpeername = lambda: ""

View file

@ -3,7 +3,7 @@ from aioupnp.fault import UPnPError
from aioupnp.protocols.m_search_patterns import packet_generator
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.constants import SSDP_IP_ADDRESS
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search, SSDPProtocol
from aioupnp.protocols.ssdp import m_search, SSDPProtocol
from tests import AsyncioTestCase, mock_tcp_and_udp
@ -28,6 +28,11 @@ class TestSSDP(AsyncioTestCase):
])
reply_packet = SSDPDatagram("OK", reply_args)
async def test_socket_setup_error(self):
with mock_tcp_and_udp(self.loop, raise_oserror_on_bind=True):
with self.assertRaises(UPnPError):
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop)
async def test_transport_not_connected_error(self):
try:
await SSDPProtocol('', '').m_search('1.2.3.4', 2, [self.query_packet.as_dict()])
@ -35,6 +40,16 @@ class TestSSDP(AsyncioTestCase):
except UPnPError as err:
self.assertEqual(str(err), "SSDP transport not connected")
async def test_deadbeef_response(self):
replies = {
(self.query_packet.encode().encode(), ("10.0.0.1", 1900)): b'\xde\xad\xbe\xef'
}
sent = []
with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1", sent_udp_packets=sent):
with self.assertRaises(UPnPError):
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop)
async def test_m_search_reply_unicast(self):
replies = {
(self.query_packet.encode().encode(), ("10.0.0.1", 1900)): self.reply_packet.encode().encode()
@ -42,14 +57,14 @@ class TestSSDP(AsyncioTestCase):
sent = []
with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1", sent_udp_packets=sent):
reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=True)
reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop)
self.assertEqual(reply.encode(), self.reply_packet.encode())
self.assertListEqual(sent, [self.query_packet.encode().encode()])
self.assertIn(self.query_packet.encode().encode(), sent)
with self.assertRaises(UPnPError):
with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", udp_replies=replies):
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=False)
with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.10", udp_replies=replies):
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop)
async def test_m_search_reply_multicast(self):
replies = {
@ -61,39 +76,40 @@ class TestSSDP(AsyncioTestCase):
reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop)
self.assertEqual(reply.encode(), self.reply_packet.encode())
self.assertListEqual(sent, [self.query_packet.encode().encode()])
self.assertIn(self.query_packet.encode().encode(), sent)
with self.assertRaises(UPnPError):
with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1"):
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=True)
with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.10"):
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop)
async def test_packets_sent_fuzzy_m_search(self):
sent = []
with self.assertRaises(UPnPError):
with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", sent_udp_packets=sent):
await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop)
self.assertListEqual(sent, self.byte_packets)
async def test_packets_fuzzy_m_search(self):
replies = {
(self.query_packet.encode().encode(), (SSDP_IP_ADDRESS, 1900)): self.reply_packet.encode().encode()
}
sent = []
with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", udp_replies=replies, sent_udp_packets=sent):
args, reply = await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop)
self.assertEqual(reply.encode(), self.reply_packet.encode())
self.assertEqual(args, self.successful_args)
async def test_packets_sent_fuzzy_m_search_ignore_invalid_datagram_replies(self):
sent = []
with self.assertRaises(UPnPError):
with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", sent_udp_packets=sent,
add_potato_datagrams=True):
await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop)
self.assertListEqual(sent, self.byte_packets)
# async def test_packets_sent_fuzzy_m_search(self):
# sent = []
#
# with self.assertRaises(UPnPError):
# with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", sent_udp_packets=sent):
# await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop)
# for packet in self.byte_packets:
# self.assertIn(packet, sent)
#
# async def test_packets_fuzzy_m_search(self):
# replies = {
# (self.query_packet.encode().encode(), (SSDP_IP_ADDRESS, 1900)): self.reply_packet.encode().encode()
# }
# sent = []
#
# with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", udp_replies=replies, sent_udp_packets=sent):
# args, reply = await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop)
#
# self.assertEqual(reply.encode(), self.reply_packet.encode())
# self.assertEqual(args, self.successful_args)
#
# async def test_packets_sent_fuzzy_m_search_ignore_invalid_datagram_replies(self):
# sent = []
#
# with self.assertRaises(UPnPError):
# with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", sent_udp_packets=sent,
# add_potato_datagrams=True):
# await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop)
#
# for packet in self.byte_packets:
# self.assertIn(packet, sent)

View file

@ -10,7 +10,6 @@ from aioupnp.__main__ import main
m_search_cli_result = """{
"lan_address": "10.0.0.2",
"gateway_address": "10.0.0.1",
"m_search_kwargs": "--HOST=239.255.255.250:1900 --MAN=ssdp:discover --MX=1 --ST=urn:schemas-upnp-org:device:WANDevice:1",
"discover_reply": {
"CACHE_CONTROL": "max-age=1800",
"LOCATION": "http://10.0.0.1:49152/InternetGatewayDevice.xml",
@ -22,14 +21,13 @@ m_search_cli_result = """{
m_search_help_msg = """aioupnp [-h] [--debug_logging] m_search [--lan_address=<str>] [--gateway_address=<str>]
[--timeout=<int>] [--unicast] [--interface_name=<str>] [--<header key>=<header value>, ...]
[--timeout=<int>] [--interface_name=<str>] [--<header key>=<header value>, ...]
Perform a M-SEARCH for a upnp gateway.
:param lan_address: (str) the local interface ipv4 address
:param gateway_address: (str) the gateway ipv4 address
:param timeout: (int) m search timeout
:param unicast: (bool) use unicast
:param interface_name: (str) name of the network interface
:param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided.
@ -219,7 +217,7 @@ class TestCLI(AsyncioTestCase):
actual_output = StringIO()
timeout_msg = "aioupnp encountered an error: M-SEARCH for 10.0.0.1:1900 timed out\n"
with contextlib.redirect_stdout(actual_output):
with mock_tcp_and_udp(self.loop, '10.0.0.1', tcp_replies=self.scpd_replies, udp_replies=self.udp_replies):
with mock_tcp_and_udp(self.loop, '10.0.0.1', tcp_replies={}, udp_replies={}):
main(
[None, '--timeout=1', '--gateway_address=10.0.0.1', '--lan_address=10.0.0.2', 'm-search'],
self.loop
@ -230,7 +228,7 @@ class TestCLI(AsyncioTestCase):
with contextlib.redirect_stdout(actual_output):
with mock_tcp_and_udp(self.loop, '10.0.0.1', tcp_replies=self.scpd_replies, udp_replies=self.udp_replies):
main(
[None, '--timeout=1', '--gateway_address=10.0.0.1', '--lan_address=10.0.0.2', '--unicast', 'm-search'],
[None, '--timeout=1', '--gateway_address=10.0.0.1', '--lan_address=10.0.0.2', 'm-search'],
self.loop
)
self.assertEqual(m_search_cli_result, actual_output.getvalue())

View file

@ -176,8 +176,7 @@ class TestDiscoverDLinkDIR890L(AsyncioTestCase):
[('serviceType', 'urn:schemas-upnp-org:service:WANIPConnection:1'),
('serviceId', 'urn:upnp-org:serviceId:WANIPConn1'), ('controlURL', '/soap.cgi?service=WANIPConn1'),
('eventSubURL', '/gena.cgi?service=WANIPConn1'), ('SCPDURL', '/WANIPConnection.xml')])},
'm_search_args': OrderedDict([('HOST', '239.255.255.250:1900'), ('MAN', 'ssdp:discover'), ('MX', 1),
('ST', 'urn:schemas-upnp-org:device:WANDevice:1')]), 'reply': OrderedDict(
'reply': OrderedDict(
[('CACHE_CONTROL', 'max-age=1800'), ('LOCATION', 'http://10.0.0.1:49152/InternetGatewayDevice.xml'),
('SERVER', 'Linux, UPnP/1.0, DIR-890L Ver 1.20'), ('ST', 'urn:schemas-upnp-org:device:WANDevice:1'),
('USN', 'uuid:11111111-2222-3333-4444-555555555555::urn:schemas-upnp-org:device:WANDevice:1')]),
@ -232,14 +231,14 @@ class TestDiscoverDLinkDIR890L(AsyncioTestCase):
with self.assertRaises(UPnPError) as e2:
with mock_tcp_and_udp(self.loop):
await Gateway.discover_gateway(self.client_address, self.gateway_info['gateway_address'], 2,
unicast=False, loop=self.loop)
loop=self.loop)
self.assertEqual(str(e1.exception), f"M-SEARCH for {self.gateway_info['gateway_address']}:1900 timed out")
self.assertEqual(str(e2.exception), f"M-SEARCH for {self.gateway_info['gateway_address']}:1900 timed out")
async def test_discover_commands(self):
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(
SSDPDatagram("OK", self.gateway_info['reply']), self.gateway_info['m_search_args'],
SSDPDatagram("OK", self.gateway_info['reply']),
self.client_address, self.gateway_info['gateway_address'], loop=self.loop
)
await gateway.discover_commands()
@ -274,9 +273,7 @@ class TestDiscoverNetgearNighthawkAC2350(TestDiscoverDLinkDIR890L):
[('serviceType', 'urn:schemas-upnp-org:service:WANIPConnection:1'),
('serviceId', 'urn:upnp-org:serviceId:WANIPConn1'), ('controlURL', '/ctl/IPConn'),
('eventSubURL', '/evt/IPConn'), ('SCPDURL', '/WANIPCn.xml')])},
'm_search_args': OrderedDict(
[('HOST', '239.255.255.250:1900'), ('MAN', '"ssdp:discover"'), ('MX', 1),
('ST', 'upnp:rootdevice')]), 'reply': OrderedDict(
'reply': OrderedDict(
[('CACHE_CONTROL', 'max-age=1800'), ('ST', 'upnp:rootdevice'),
('USN', 'uuid:11111111-2222-3333-4444-555555555555::upnp:rootdevice'),
('Server', 'R7500v2 UPnP/1.0 miniupnpd/1.0'), ('Location', 'http://192.168.0.1:5555/rootDesc.xml')]),

View file

@ -45,7 +45,7 @@ class TestGetExternalIPAddress(UPnPCommandTestCase):
self.replies.update({request: b"HTTP/1.1 200 OK\r\nServer: WebServer\r\nDate: Wed, 22 May 2019 03:25:57 GMT\r\nConnection: close\r\nCONTENT-TYPE: text/xml; charset=\"utf-8\"\r\nCONTENT-LENGTH: 365 \r\nEXT:\r\n\r\n<?xml version=\"1.0\"?>\n<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\n\t<s:Body>\n\t\t<u:GetExternalIPAddressResponse xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\n<NewExternalIPAddress>11.222.3.44</NewExternalIPAddress>\n</u:GetExternalIPAddressResponse>\n\t</s:Body>\n</s:Envelope>\n"})
self.addCleanup(self.replies.pop, request)
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
external_ip = await upnp.get_external_ip()
@ -57,7 +57,7 @@ class TestGetExternalIPAddress(UPnPCommandTestCase):
request: b"HTTP/1.1 200 OK\r\nServer: WebServer\r\nDate: Wed, 22 May 2019 03:25:57 GMT\r\nConnection: close\r\nCONTENT-TYPE: text/xml; charset=\"utf-8\"\r\nCONTENT-LENGTH: 354 \r\nEXT:\r\n\r\n<?xml version=\"1.0\"?>\n<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\n\t<s:Body>\n\t\t<u:GetExternalIPAddressResponse xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\n<NewExternalIPAddress></NewExternalIPAddress>\n</u:GetExternalIPAddressResponse>\n\t</s:Body>\n</s:Envelope>\n"})
self.addCleanup(self.replies.pop, request)
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
with self.assertRaises(UPnPError):
@ -73,7 +73,7 @@ class TestMalformedGetExternalIPAddressResponse(UPnPCommandTestCase):
b"<derp>11.222.3.44</derp>\n</u:GetExternalIPAddressResponse>\n\t</s:Body>\n</s:Envelope>\n"})
self.addCleanup(self.replies.pop, self.get_ip_request)
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
with self.assertRaises(UPnPError):
@ -84,7 +84,7 @@ class TestMalformedGetExternalIPAddressResponse(UPnPCommandTestCase):
b"<newexternalipaddress>11.222.3.44</newexternalipaddress>\n</u:GetExternalIPAddressResponse>\n\t</s:Body>\n</s:Envelope>\n"})
self.addCleanup(self.replies.pop, self.get_ip_request)
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
external_ip = await upnp.get_external_ip()
@ -95,7 +95,7 @@ class TestMalformedGetExternalIPAddressResponse(UPnPCommandTestCase):
b"11.222.3.44\n</u:GetExternalIPAddressResponse>\n\t</s:Body>\n</s:Envelope>\n"})
self.addCleanup(self.replies.pop, self.get_ip_request)
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
external_ip = await upnp.get_external_ip()
@ -113,7 +113,7 @@ class TestGetGenericPortMappingEntry(UPnPCommandTestCase):
async def test_get_port_mapping_by_index(self):
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
result = await upnp.get_port_mapping_by_index(0)
@ -135,7 +135,7 @@ class TestGetNextPortMapping(UPnPCommandTestCase):
async def test_get_next_mapping(self):
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
ext_port = await upnp.get_next_mapping(4567, "UDP", "aioupnp test mapping")
@ -155,7 +155,7 @@ class TestGetSpecificPortMapping(UPnPCommandTestCase):
async def test_get_specific_port_mapping(self):
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
try: