diff --git a/aioupnp/__main__.py b/aioupnp/__main__.py index 0f50e28..33ef8b6 100644 --- a/aioupnp/__main__.py +++ b/aioupnp/__main__.py @@ -75,8 +75,7 @@ def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None, 'interface': 'default', 'gateway_address': '', 'lan_address': '', - 'timeout': 30, - 'unicast': False + 'timeout': 3, } options: typing.Dict[str, typing.Union[bool, str, int]] = OrderedDict() @@ -114,10 +113,9 @@ def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None, gateway_address: str = str(options.pop('gateway_address')) timeout: int = int(options.pop('timeout')) interface: str = str(options.pop('interface')) - unicast: bool = bool(options.pop('unicast')) run_cli( - command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, unicast, kwargs, loop + command.replace('-', '_'), options, lan_address, gateway_address, timeout, interface, kwargs, loop ) return 0 diff --git a/aioupnp/gateway.py b/aioupnp/gateway.py index 95dd81c..d59b889 100644 --- a/aioupnp/gateway.py +++ b/aioupnp/gateway.py @@ -2,13 +2,12 @@ import re import logging import typing import asyncio -from collections import OrderedDict -from typing import Dict, List +from typing import Dict, List, Optional from aioupnp.util import get_dict_val_case_insensitive from aioupnp.constants import SPEC_VERSION, SERVICE -from aioupnp.commands import SOAPCommands, SCPDRequestDebuggingInfo +from aioupnp.commands import SOAPCommands from aioupnp.device import Device, Service -from aioupnp.protocols.ssdp import fuzzy_m_search, m_search +from aioupnp.protocols.ssdp import m_search, multi_m_search from aioupnp.protocols.scpd import scpd_get from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.util import flatten_keys @@ -69,12 +68,10 @@ def parse_location(location: bytes) -> typing.Tuple[bytes, int]: class Gateway: - def __init__(self, ok_packet: SSDPDatagram, m_search_args: typing.Dict[str, typing.Union[int, str]], - lan_address: str, gateway_address: str, - loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None: + def __init__(self, ok_packet: SSDPDatagram, lan_address: str, gateway_address: str, + loop: Optional[asyncio.AbstractEventLoop] = None) -> None: self._loop = loop or asyncio.get_event_loop() self._ok_packet = ok_packet - self._m_search_args = m_search_args self._lan_address = lan_address self.usn: bytes = (ok_packet.usn or '').encode() self.ext: bytes = (ok_packet.ext or '').encode() @@ -92,10 +89,10 @@ class Gateway: assert self.base_ip == gateway_address.encode() self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1] - self.spec_version: typing.Optional[str] = None - self.url_base: typing.Optional[str] = None + self.spec_version: Optional[str] = None + self.url_base: Optional[str] = None - self._device: typing.Optional[Device] = None + self._device: Optional[Device] = None self._devices: List[Device] = [] self._services: List[Service] = [] @@ -130,7 +127,7 @@ class Gateway: devices[device.udn] = device return devices - # def get_service(self, service_type: str) -> typing.Optional[Service]: + # def get_service(self, service_type: str) -> Optional[Service]: # for service in self._services: # if service.serviceType and service.serviceType.lower() == service_type.lower(): # return service @@ -149,7 +146,6 @@ class Gateway: 'gateway_xml': self._xml_response.decode(), 'services_xml': self._service_descriptors, 'services': {service.SCPDURL: service.as_dict() for service in self._services}, - 'm_search_args': OrderedDict(self._m_search_args), 'reply': self._ok_packet.as_dict(), 'soap_port': self.port, 'registered_soap_commands': self._registered_commands, @@ -158,74 +154,79 @@ class Gateway: } @classmethod - async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, - igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None, - loop: typing.Optional[asyncio.AbstractEventLoop] = None, - unicast: bool = False) -> 'Gateway': - ignored: typing.Set[str] = set() + async def _try_gateway_from_ssdp(cls, datagram: SSDPDatagram, lan_address: str, + gateway_address: str, + loop: Optional[asyncio.AbstractEventLoop] = None) -> Optional['Gateway']: required_commands: typing.List[str] = [ 'AddPortMapping', 'DeletePortMapping', 'GetExternalIPAddress' ] - while True: - if not igd_args: - m_search_args, datagram = await fuzzy_m_search( - lan_address, gateway_address, timeout, loop, ignored, unicast - ) - else: - m_search_args = OrderedDict(igd_args) - datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, ignored, unicast) - try: - gateway = cls(datagram, m_search_args, lan_address, gateway_address, loop=loop) - log.debug('get gateway descriptor %s', datagram.location) - await gateway.discover_commands() - requirements_met = all([gateway.commands.is_registered(required) for required in required_commands]) - if not requirements_met: - not_met = [ - required for required in required_commands if not gateway.commands.is_registered(required) - ] - assert datagram.location is not None - log.debug("found gateway %s at %s, but it does not implement required soap commands: %s", - gateway.manufacturer_string, gateway.location, not_met) - ignored.add(datagram.location) - continue - else: - log.debug('found gateway %s at %s', gateway.manufacturer_string or "device", datagram.location) - return gateway - except (asyncio.TimeoutError, UPnPError) as err: + try: + gateway = cls(datagram, lan_address, gateway_address, loop=loop) + log.debug('get gateway descriptor %s', datagram.location) + await gateway.discover_commands() + requirements_met = all([gateway.commands.is_registered(required) for required in required_commands]) + if not requirements_met: + not_met = [ + required for required in required_commands if not gateway.commands.is_registered(required) + ] assert datagram.location is not None - log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err)) - ignored.add(datagram.location) - continue + log.debug("found gateway %s at %s, but it does not implement required soap commands: %s", + gateway.manufacturer_string, gateway.location, not_met) + return None + else: + log.debug('found gateway %s at %s', gateway.manufacturer_string or "device", datagram.location) + return gateway + except (asyncio.TimeoutError, UPnPError) as err: + assert datagram.location is not None + log.debug("get %s failed (%s), looking for other devices", datagram.location, str(err)) + return None @classmethod - async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, - igd_args: typing.Optional[typing.Dict[str, typing.Union[int, str]]] = None, - loop: typing.Optional[asyncio.AbstractEventLoop] = None, - unicast: typing.Optional[bool] = None) -> 'Gateway': + async def _gateway_from_igd_args(cls, lan_address: str, gateway_address: str, + igd_args: typing.Dict[str, typing.Union[int, str]], + timeout: int = 30, + loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway': + datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop) + gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop) + if not gateway: + raise UPnPError("no gateway found for given args") + return gateway + + @classmethod + async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 3, + loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway': + ignored: typing.Set[str] = set() + ssdp_proto = await multi_m_search( + lan_address, gateway_address, timeout, loop + ) + try: + while True: + datagram = await ssdp_proto.devices.get() + if datagram.location in ignored: + continue + gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop) + if gateway: + return gateway + elif datagram.location: + ignored.add(datagram.location) + finally: + ssdp_proto.disconnect() + + @classmethod + async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 3, + igd_args: Optional[typing.Dict[str, typing.Union[int, str]]] = None, + loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway': loop = loop or asyncio.get_event_loop() - if unicast is not None: - return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast) - - done, pending = await asyncio.wait([ - cls._discover_gateway( - lan_address, gateway_address, timeout, igd_args, loop, unicast=True - ), - cls._discover_gateway( - lan_address, gateway_address, timeout, igd_args, loop, unicast=False - ) - ], return_when=asyncio.tasks.FIRST_COMPLETED, loop=loop) - - for task in pending: - task.cancel() - for task in done: - try: - task.exception() - except asyncio.CancelledError: - pass - results: typing.List['asyncio.Future[Gateway]'] = list(done) - return results[0].result() + if igd_args: + return await cls._gateway_from_igd_args(lan_address, gateway_address, igd_args, timeout, loop) + try: + return await asyncio.wait_for(loop.create_task( + cls._discover_gateway(lan_address, gateway_address, timeout, loop) + ), timeout, loop=loop) + except asyncio.TimeoutError: + raise UPnPError(f"M-SEARCH for {gateway_address}:1900 timed out") async def discover_commands(self) -> None: response, xml_bytes, get_err = await scpd_get( @@ -270,7 +271,7 @@ class Gateway: return None async def register_commands(self, service: Service, - loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None: + loop: Optional[asyncio.AbstractEventLoop] = None) -> None: if not service.SCPDURL: raise UPnPError("no scpd url") if not service.serviceType: diff --git a/aioupnp/protocols/scpd.py b/aioupnp/protocols/scpd.py index a85baab..cc57bb9 100644 --- a/aioupnp/protocols/scpd.py +++ b/aioupnp/protocols/scpd.py @@ -105,7 +105,7 @@ async def scpd_get(control_url: str, address: str, port: int, typing.Dict[str, typing.Any], bytes, typing.Optional[Exception]]: loop = loop or asyncio.get_event_loop() packet = serialize_scpd_get(control_url, address) - finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = asyncio.Future(loop=loop) + finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future() proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished) connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection( proto_factory, address, port @@ -141,7 +141,7 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para **kwargs: typing.Dict[str, typing.Any] ) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]: loop = loop or asyncio.get_event_loop() - finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = asyncio.Future(loop=loop) + finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future() packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs) proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\ SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode()) diff --git a/aioupnp/protocols/ssdp.py b/aioupnp/protocols/ssdp.py index 0f8782d..c7babb8 100644 --- a/aioupnp/protocols/ssdp.py +++ b/aioupnp/protocols/ssdp.py @@ -4,6 +4,7 @@ import asyncio import logging import typing import socket +from typing import List, Set, Dict, Tuple, Optional from asyncio.transports import DatagramTransport from aioupnp.fault import UPnPError from aioupnp.serialization.ssdp import SSDPDatagram @@ -16,17 +17,22 @@ ADDRESS_REGEX = re.compile("^http:\/\/(\d+\.\d+\.\d+\.\d+)\:(\d*)(\/[\w|\/|\:|\- log = logging.getLogger(__name__) +class PendingSearch(typing.NamedTuple): + address: str + st: str + fut: 'asyncio.Future[SSDPDatagram]' + + class SSDPProtocol(MulticastProtocol): - def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Optional[typing.Set[str]] = None, - unicast: bool = False, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None: + def __init__(self, multicast_address: str, lan_address: str, + loop: Optional[asyncio.AbstractEventLoop] = None) -> None: super().__init__(multicast_address, lan_address) self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() - self.transport: typing.Optional[DatagramTransport] = None - self._unicast = unicast - self._ignored: typing.Set[str] = ignored or set() # ignored locations - self._pending_searches: typing.List[typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = [] - self.notifications: typing.List[SSDPDatagram] = [] + self.transport: Optional[DatagramTransport] = None + self._pending_searches: List[PendingSearch] = [] + self.notifications: List[SSDPDatagram] = [] self.connected = asyncio.Event(loop=self.loop) + self.devices: 'asyncio.Queue[SSDPDatagram]' = asyncio.Queue(loop=self.loop) def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore super().connection_made(transport) @@ -45,46 +51,56 @@ class SSDPProtocol(MulticastProtocol): return None def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None: - if packet.location not in self._ignored: - # TODO: fix this - tmp: typing.List[typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = [] - set_futures: typing.List['asyncio.Future[SSDPDatagram]'] = [] - while len(self._pending_searches): - t = self._pending_searches.pop() - if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]): - f = t[2] - if f not in set_futures: - set_futures.append(f) - if not f.done(): - f.set_result(packet) - elif t[2] not in set_futures: - tmp.append(t) - while tmp: - self._pending_searches.append(tmp.pop()) - return None + futures: Set['asyncio.Future[SSDPDatagram]'] = set() + replied: List[PendingSearch] = [] + + for pending in self._pending_searches: + # if pending.address == address and pending.st in (packet.st, "upnp:rootdevice"): + if pending.address == address and pending.st == packet.st: + replied.append(pending) + if pending.fut not in futures: + futures.add(pending.fut) + if replied: + self.devices.put_nowait(packet) + + while replied: + self._pending_searches.remove(replied.pop()) + + while futures: + fut = futures.pop() + if not fut.done(): + fut.set_result(packet) def _send_m_search(self, address: str, packet: SSDPDatagram, fut: 'asyncio.Future[SSDPDatagram]') -> None: - dest = address if self._unicast else SSDP_IP_ADDRESS if not self.transport: if not fut.done(): fut.set_exception(UPnPError("SSDP transport not connected")) - return None - log.debug("send m search to %s: %s", dest, packet.st) - self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT)) - return None + return + assert packet.st is not None + self._pending_searches.append( + PendingSearch(address, packet.st, fut) + ) + self.transport.sendto(packet.encode().encode(), (SSDP_IP_ADDRESS, SSDP_PORT)) - async def m_search(self, address: str, timeout: float, - datagrams: typing.List[typing.Dict[str, typing.Union[str, int]]]) -> SSDPDatagram: - fut: 'asyncio.Future[SSDPDatagram]' = asyncio.Future(loop=self.loop) + # also send unicast + log.debug("send m search to %s: %s", address, packet.st) + self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT)) + + def send_m_searches(self, address: str, + datagrams: List[Dict[str, typing.Union[str, int]]]) -> 'asyncio.Future[SSDPDatagram]': + fut: 'asyncio.Future[SSDPDatagram]' = self.loop.create_future() for datagram in datagrams: packet = SSDPDatagram("M-SEARCH", datagram) assert packet.st is not None - self._pending_searches.append( - (address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet, fut)) - ) - return await asyncio.wait_for(fut, timeout) + self._send_m_search(address, packet, fut) + return fut - def datagram_received(self, data: bytes, addr: typing.Tuple[str, int]) -> None: # type: ignore + async def m_search(self, address: str, timeout: float, + datagrams: List[Dict[str, typing.Union[str, int]]]) -> SSDPDatagram: + fut = self.send_m_searches(address, datagrams) + return await asyncio.wait_for(fut, timeout, loop=self.loop) + + def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None: # type: ignore if addr[0] == self.bind_address: return None try: @@ -94,7 +110,6 @@ class SSDPProtocol(MulticastProtocol): log.warning("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err, binascii.hexlify(data)) return None - if packet._packet_type == packet._OK: self._callback_m_search_ok(addr[0], packet) return None @@ -118,14 +133,13 @@ class SSDPProtocol(MulticastProtocol): # return -async def listen_ssdp(lan_address: str, gateway_address: str, loop: typing.Optional[asyncio.AbstractEventLoop] = None, - ignored: typing.Optional[typing.Set[str]] = None, - unicast: bool = False) -> typing.Tuple[SSDPProtocol, str, str]: +async def listen_ssdp(lan_address: str, gateway_address: str, + loop: Optional[asyncio.AbstractEventLoop] = None) -> Tuple[SSDPProtocol, str, str]: loop = loop or asyncio.get_event_loop() try: sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address) - listen_result: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint( - lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock + listen_result: Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint( + lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address), sock=sock ) protocol = listen_result[1] assert isinstance(protocol, SSDPProtocol) @@ -137,58 +151,28 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop: typing.Optio return protocol, gateway_address, lan_address -async def m_search(lan_address: str, gateway_address: str, datagram_args: typing.Dict[str, typing.Union[int, str]], - timeout: int = 1, loop: typing.Optional[asyncio.AbstractEventLoop] = None, - ignored: typing.Set[str] = None, unicast: bool = False) -> SSDPDatagram: +async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[str, typing.Union[int, str]], + timeout: int = 1, + loop: Optional[asyncio.AbstractEventLoop] = None) -> SSDPDatagram: protocol, gateway_address, lan_address = await listen_ssdp( - lan_address, gateway_address, loop, ignored, unicast + lan_address, gateway_address, loop ) try: return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]) - except (asyncio.TimeoutError, asyncio.CancelledError): + except asyncio.TimeoutError: raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) finally: protocol.disconnect() -async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, - loop: typing.Optional[asyncio.AbstractEventLoop] = None, - ignored: typing.Set[str] = None, - unicast: bool = False) -> typing.List[typing.Dict[str, typing.Union[int, str]]]: +async def multi_m_search(lan_address: str, gateway_address: str, timeout: int = 3, + loop: Optional[asyncio.AbstractEventLoop] = None) -> SSDPProtocol: + loop = loop or asyncio.get_event_loop() protocol, gateway_address, lan_address = await listen_ssdp( - lan_address, gateway_address, loop, ignored, unicast + lan_address, gateway_address, loop ) - await protocol.connected.wait() - packet_args = list(packet_generator()) - batch_size = 2 - batch_timeout = float(timeout) / float(len(packet_args)) - while packet_args: - args = packet_args[:batch_size] - packet_args = packet_args[batch_size:] - log.debug("sending batch of %i M-SEARCH attempts", batch_size) - try: - await protocol.m_search(gateway_address, batch_timeout, args) - except asyncio.TimeoutError: - continue - else: - protocol.disconnect() - return args - protocol.disconnect() - raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) - - -async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, - loop: typing.Optional[asyncio.AbstractEventLoop] = None, - ignored: typing.Set[str] = None, - unicast: bool = False) -> typing.Tuple[typing.Dict[str, - typing.Union[int, str]], SSDPDatagram]: - # we don't know which packet the gateway replies to, so send small batches at a time - args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast) - # check the args in the batch that got a reply one at a time to see which one worked - for args in args_to_try: - try: - packet = await m_search(lan_address, gateway_address, args, 3, loop=loop, ignored=ignored, unicast=unicast) - return args, packet - except UPnPError: - continue - raise UPnPError("failed to discover gateway") + fut = asyncio.ensure_future(protocol.send_m_searches( + address=gateway_address, datagrams=list(packet_generator()) + ), loop=loop) + loop.call_later(timeout, lambda: None if not fut or fut.done() else fut.cancel()) + return protocol diff --git a/aioupnp/serialization/ssdp.py b/aioupnp/serialization/ssdp.py index 99a3808..9a252f5 100644 --- a/aioupnp/serialization/ssdp.py +++ b/aioupnp/serialization/ssdp.py @@ -164,10 +164,15 @@ class SSDPDatagram: @classmethod def decode(cls, datagram: bytes) -> 'SSDPDatagram': - packet = cls._from_string(datagram.decode()) + try: + packet = cls._from_string(datagram.decode()) + except UnicodeDecodeError: + raise UPnPError( + f"failed to decode datagram: {binascii.hexlify(datagram).decode()}" + ) if packet is None: raise UPnPError( - "failed to decode datagram: {}".format(binascii.hexlify(datagram)) + f"failed to decode datagram: {binascii.hexlify(datagram).decode()}" ) for attr_name in packet._required_fields[packet._packet_type]: if getattr(packet, attr_name, None) is None: diff --git a/aioupnp/upnp.py b/aioupnp/upnp.py index 9b3eefc..430f3de 100644 --- a/aioupnp/upnp.py +++ b/aioupnp/upnp.py @@ -4,12 +4,10 @@ import logging import json import asyncio -from collections import OrderedDict from typing import Tuple, Dict, List, Union, Optional, Any from aioupnp.fault import UPnPError from aioupnp.gateway import Gateway from aioupnp.interfaces import get_gateway_and_lan_addresses -from aioupnp.protocols.ssdp import m_search, fuzzy_m_search from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.commands import GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse @@ -61,7 +59,7 @@ class UPnP: return lan_address, gateway_address @classmethod - async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 30, + async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 3, igd_args: Optional[Dict[str, Union[str, int]]] = None, interface_name: str = 'default', loop: Optional[asyncio.AbstractEventLoop] = None) -> 'UPnP': lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) @@ -72,7 +70,7 @@ class UPnP: @classmethod async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, - unicast: bool = True, interface_name: str = 'default', + interface_name: str = 'default', igd_args: Optional[Dict[str, Union[str, int]]] = None, loop: Optional[asyncio.AbstractEventLoop] = None ) -> Dict[str, Union[str, Dict[str, Union[str, int]]]]: @@ -82,7 +80,6 @@ class UPnP: :param lan_address: (str) the local interface ipv4 address :param gateway_address: (str) the gateway ipv4 address :param timeout: (int) m search timeout - :param unicast: (bool) use unicast :param interface_name: (str) name of the network interface :param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided. @@ -101,16 +98,14 @@ class UPnP: except Exception as err: raise UPnPError("failed to get lan and gateway addresses for interface \"%s\": %s" % (interface_name, str(err))) - if not igd_args: - igd_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, loop, unicast=unicast) - else: - igd_args = OrderedDict(igd_args) - datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, unicast=unicast) + gateway = await Gateway.discover_gateway( + lan_address, gateway_address, timeout, igd_args, loop + ) return { 'lan_address': lan_address, 'gateway_address': gateway_address, - 'm_search_kwargs': SSDPDatagram("M-SEARCH", igd_args).get_cli_igd_kwargs(), - 'discover_reply': datagram.as_dict() + # 'm_search_kwargs': SSDPDatagram("M-SEARCH", igd_args).get_cli_igd_kwargs(), + 'discover_reply': gateway._ok_packet.as_dict() } async def get_external_ip(self) -> str: @@ -372,20 +367,20 @@ cli_commands = [ def run_cli(method: str, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '', - gateway_address: str = '', timeout: int = 30, interface_name: str = 'default', - unicast: bool = True, kwargs: Optional[Dict[str, str]] = None, + gateway_address: str = '', timeout: int = 3, interface_name: str = 'default', + kwargs: Optional[Dict[str, str]] = None, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: kwargs = kwargs or {} igd_args = igd_args timeout = int(timeout) loop = loop or asyncio.get_event_loop() - fut: 'asyncio.Future' = asyncio.Future(loop=loop) + fut: 'asyncio.Future' = loop.create_future() async def wrapper(): # wrap the upnp setup and call of the command in a coroutine if method == 'm_search': # if we're only m_searching don't do any device discovery fn = lambda *_a, **_kw: UPnP.m_search( - lan_address, gateway_address, timeout, unicast, interface_name, igd_args, loop + lan_address, gateway_address, timeout, interface_name, igd_args, loop ) else: # automatically discover the gateway try: diff --git a/tests/__init__.py b/tests/__init__.py index acd71a5..62f39b9 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -16,7 +16,8 @@ except ImportError: @contextlib.contextmanager def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None, - tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False): + tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False, + raise_oserror_on_bind=False): sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else [] udp_replies = udp_replies or {} @@ -72,7 +73,13 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r with mock.patch('socket.socket') as mock_socket: mock_sock = mock.Mock(spec=socket.socket) mock_sock.setsockopt = lambda *_: None - mock_sock.bind = lambda *_: None + + def bind(*_): + if raise_oserror_on_bind: + raise OSError() + return + + mock_sock.bind = bind mock_sock.setblocking = lambda *_: None mock_sock.getsockname = lambda: "0.0.0.0" mock_sock.getpeername = lambda: "" diff --git a/tests/protocols/test_ssdp.py b/tests/protocols/test_ssdp.py index 243c94f..fec6621 100644 --- a/tests/protocols/test_ssdp.py +++ b/tests/protocols/test_ssdp.py @@ -3,7 +3,7 @@ from aioupnp.fault import UPnPError from aioupnp.protocols.m_search_patterns import packet_generator from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.constants import SSDP_IP_ADDRESS -from aioupnp.protocols.ssdp import fuzzy_m_search, m_search, SSDPProtocol +from aioupnp.protocols.ssdp import m_search, SSDPProtocol from tests import AsyncioTestCase, mock_tcp_and_udp @@ -28,6 +28,11 @@ class TestSSDP(AsyncioTestCase): ]) reply_packet = SSDPDatagram("OK", reply_args) + async def test_socket_setup_error(self): + with mock_tcp_and_udp(self.loop, raise_oserror_on_bind=True): + with self.assertRaises(UPnPError): + await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop) + async def test_transport_not_connected_error(self): try: await SSDPProtocol('', '').m_search('1.2.3.4', 2, [self.query_packet.as_dict()]) @@ -35,6 +40,16 @@ class TestSSDP(AsyncioTestCase): except UPnPError as err: self.assertEqual(str(err), "SSDP transport not connected") + async def test_deadbeef_response(self): + replies = { + (self.query_packet.encode().encode(), ("10.0.0.1", 1900)): b'\xde\xad\xbe\xef' + } + sent = [] + + with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1", sent_udp_packets=sent): + with self.assertRaises(UPnPError): + await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop) + async def test_m_search_reply_unicast(self): replies = { (self.query_packet.encode().encode(), ("10.0.0.1", 1900)): self.reply_packet.encode().encode() @@ -42,14 +57,14 @@ class TestSSDP(AsyncioTestCase): sent = [] with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1", sent_udp_packets=sent): - reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=True) + reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop) self.assertEqual(reply.encode(), self.reply_packet.encode()) - self.assertListEqual(sent, [self.query_packet.encode().encode()]) + self.assertIn(self.query_packet.encode().encode(), sent) with self.assertRaises(UPnPError): - with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", udp_replies=replies): - await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=False) + with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.10", udp_replies=replies): + await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop) async def test_m_search_reply_multicast(self): replies = { @@ -61,39 +76,40 @@ class TestSSDP(AsyncioTestCase): reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop) self.assertEqual(reply.encode(), self.reply_packet.encode()) - self.assertListEqual(sent, [self.query_packet.encode().encode()]) + self.assertIn(self.query_packet.encode().encode(), sent) with self.assertRaises(UPnPError): - with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1"): - await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=True) + with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.10"): + await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop) - async def test_packets_sent_fuzzy_m_search(self): - sent = [] - - with self.assertRaises(UPnPError): - with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", sent_udp_packets=sent): - await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop) - - self.assertListEqual(sent, self.byte_packets) - - async def test_packets_fuzzy_m_search(self): - replies = { - (self.query_packet.encode().encode(), (SSDP_IP_ADDRESS, 1900)): self.reply_packet.encode().encode() - } - sent = [] - - with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", udp_replies=replies, sent_udp_packets=sent): - args, reply = await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop) - - self.assertEqual(reply.encode(), self.reply_packet.encode()) - self.assertEqual(args, self.successful_args) - - async def test_packets_sent_fuzzy_m_search_ignore_invalid_datagram_replies(self): - sent = [] - - with self.assertRaises(UPnPError): - with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", sent_udp_packets=sent, - add_potato_datagrams=True): - await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop) - - self.assertListEqual(sent, self.byte_packets) \ No newline at end of file + # async def test_packets_sent_fuzzy_m_search(self): + # sent = [] + # + # with self.assertRaises(UPnPError): + # with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", sent_udp_packets=sent): + # await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop) + # for packet in self.byte_packets: + # self.assertIn(packet, sent) + # + # async def test_packets_fuzzy_m_search(self): + # replies = { + # (self.query_packet.encode().encode(), (SSDP_IP_ADDRESS, 1900)): self.reply_packet.encode().encode() + # } + # sent = [] + # + # with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", udp_replies=replies, sent_udp_packets=sent): + # args, reply = await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop) + # + # self.assertEqual(reply.encode(), self.reply_packet.encode()) + # self.assertEqual(args, self.successful_args) + # + # async def test_packets_sent_fuzzy_m_search_ignore_invalid_datagram_replies(self): + # sent = [] + # + # with self.assertRaises(UPnPError): + # with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", sent_udp_packets=sent, + # add_potato_datagrams=True): + # await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop) + # + # for packet in self.byte_packets: + # self.assertIn(packet, sent) \ No newline at end of file diff --git a/tests/test_cli.py b/tests/test_cli.py index f7a0052..722161e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -10,7 +10,6 @@ from aioupnp.__main__ import main m_search_cli_result = """{ "lan_address": "10.0.0.2", "gateway_address": "10.0.0.1", - "m_search_kwargs": "--HOST=239.255.255.250:1900 --MAN=ssdp:discover --MX=1 --ST=urn:schemas-upnp-org:device:WANDevice:1", "discover_reply": { "CACHE_CONTROL": "max-age=1800", "LOCATION": "http://10.0.0.1:49152/InternetGatewayDevice.xml", @@ -22,14 +21,13 @@ m_search_cli_result = """{ m_search_help_msg = """aioupnp [-h] [--debug_logging] m_search [--lan_address=] [--gateway_address=] - [--timeout=] [--unicast] [--interface_name=] [--
=
, ...] + [--timeout=] [--interface_name=] [--
=
, ...] Perform a M-SEARCH for a upnp gateway. :param lan_address: (str) the local interface ipv4 address :param gateway_address: (str) the gateway ipv4 address :param timeout: (int) m search timeout -:param unicast: (bool) use unicast :param interface_name: (str) name of the network interface :param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided. @@ -219,7 +217,7 @@ class TestCLI(AsyncioTestCase): actual_output = StringIO() timeout_msg = "aioupnp encountered an error: M-SEARCH for 10.0.0.1:1900 timed out\n" with contextlib.redirect_stdout(actual_output): - with mock_tcp_and_udp(self.loop, '10.0.0.1', tcp_replies=self.scpd_replies, udp_replies=self.udp_replies): + with mock_tcp_and_udp(self.loop, '10.0.0.1', tcp_replies={}, udp_replies={}): main( [None, '--timeout=1', '--gateway_address=10.0.0.1', '--lan_address=10.0.0.2', 'm-search'], self.loop @@ -230,7 +228,7 @@ class TestCLI(AsyncioTestCase): with contextlib.redirect_stdout(actual_output): with mock_tcp_and_udp(self.loop, '10.0.0.1', tcp_replies=self.scpd_replies, udp_replies=self.udp_replies): main( - [None, '--timeout=1', '--gateway_address=10.0.0.1', '--lan_address=10.0.0.2', '--unicast', 'm-search'], + [None, '--timeout=1', '--gateway_address=10.0.0.1', '--lan_address=10.0.0.2', 'm-search'], self.loop ) self.assertEqual(m_search_cli_result, actual_output.getvalue()) diff --git a/tests/test_gateway.py b/tests/test_gateway.py index 1512947..6b0a50b 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -176,8 +176,7 @@ class TestDiscoverDLinkDIR890L(AsyncioTestCase): [('serviceType', 'urn:schemas-upnp-org:service:WANIPConnection:1'), ('serviceId', 'urn:upnp-org:serviceId:WANIPConn1'), ('controlURL', '/soap.cgi?service=WANIPConn1'), ('eventSubURL', '/gena.cgi?service=WANIPConn1'), ('SCPDURL', '/WANIPConnection.xml')])}, - 'm_search_args': OrderedDict([('HOST', '239.255.255.250:1900'), ('MAN', 'ssdp:discover'), ('MX', 1), - ('ST', 'urn:schemas-upnp-org:device:WANDevice:1')]), 'reply': OrderedDict( + 'reply': OrderedDict( [('CACHE_CONTROL', 'max-age=1800'), ('LOCATION', 'http://10.0.0.1:49152/InternetGatewayDevice.xml'), ('SERVER', 'Linux, UPnP/1.0, DIR-890L Ver 1.20'), ('ST', 'urn:schemas-upnp-org:device:WANDevice:1'), ('USN', 'uuid:11111111-2222-3333-4444-555555555555::urn:schemas-upnp-org:device:WANDevice:1')]), @@ -232,14 +231,14 @@ class TestDiscoverDLinkDIR890L(AsyncioTestCase): with self.assertRaises(UPnPError) as e2: with mock_tcp_and_udp(self.loop): await Gateway.discover_gateway(self.client_address, self.gateway_info['gateway_address'], 2, - unicast=False, loop=self.loop) + loop=self.loop) self.assertEqual(str(e1.exception), f"M-SEARCH for {self.gateway_info['gateway_address']}:1900 timed out") self.assertEqual(str(e2.exception), f"M-SEARCH for {self.gateway_info['gateway_address']}:1900 timed out") async def test_discover_commands(self): with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): gateway = Gateway( - SSDPDatagram("OK", self.gateway_info['reply']), self.gateway_info['m_search_args'], + SSDPDatagram("OK", self.gateway_info['reply']), self.client_address, self.gateway_info['gateway_address'], loop=self.loop ) await gateway.discover_commands() @@ -274,9 +273,7 @@ class TestDiscoverNetgearNighthawkAC2350(TestDiscoverDLinkDIR890L): [('serviceType', 'urn:schemas-upnp-org:service:WANIPConnection:1'), ('serviceId', 'urn:upnp-org:serviceId:WANIPConn1'), ('controlURL', '/ctl/IPConn'), ('eventSubURL', '/evt/IPConn'), ('SCPDURL', '/WANIPCn.xml')])}, - 'm_search_args': OrderedDict( - [('HOST', '239.255.255.250:1900'), ('MAN', '"ssdp:discover"'), ('MX', 1), - ('ST', 'upnp:rootdevice')]), 'reply': OrderedDict( + 'reply': OrderedDict( [('CACHE_CONTROL', 'max-age=1800'), ('ST', 'upnp:rootdevice'), ('USN', 'uuid:11111111-2222-3333-4444-555555555555::upnp:rootdevice'), ('Server', 'R7500v2 UPnP/1.0 miniupnpd/1.0'), ('Location', 'http://192.168.0.1:5555/rootDesc.xml')]), diff --git a/tests/test_upnp.py b/tests/test_upnp.py index fcc9388..096a0d4 100644 --- a/tests/test_upnp.py +++ b/tests/test_upnp.py @@ -45,7 +45,7 @@ class TestGetExternalIPAddress(UPnPCommandTestCase): self.replies.update({request: b"HTTP/1.1 200 OK\r\nServer: WebServer\r\nDate: Wed, 22 May 2019 03:25:57 GMT\r\nConnection: close\r\nCONTENT-TYPE: text/xml; charset=\"utf-8\"\r\nCONTENT-LENGTH: 365 \r\nEXT:\r\n\r\n\n\n\t\n\t\t\n11.222.3.44\n\n\t\n\n"}) self.addCleanup(self.replies.pop, request) with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): - gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop) + gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop) await gateway.discover_commands() upnp = UPnP(self.client_address, self.gateway_address, gateway) external_ip = await upnp.get_external_ip() @@ -57,7 +57,7 @@ class TestGetExternalIPAddress(UPnPCommandTestCase): request: b"HTTP/1.1 200 OK\r\nServer: WebServer\r\nDate: Wed, 22 May 2019 03:25:57 GMT\r\nConnection: close\r\nCONTENT-TYPE: text/xml; charset=\"utf-8\"\r\nCONTENT-LENGTH: 354 \r\nEXT:\r\n\r\n\n\n\t\n\t\t\n\n\n\t\n\n"}) self.addCleanup(self.replies.pop, request) with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): - gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop) + gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop) await gateway.discover_commands() upnp = UPnP(self.client_address, self.gateway_address, gateway) with self.assertRaises(UPnPError): @@ -73,7 +73,7 @@ class TestMalformedGetExternalIPAddressResponse(UPnPCommandTestCase): b"11.222.3.44\n\n\t\n\n"}) self.addCleanup(self.replies.pop, self.get_ip_request) with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): - gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop) + gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop) await gateway.discover_commands() upnp = UPnP(self.client_address, self.gateway_address, gateway) with self.assertRaises(UPnPError): @@ -84,7 +84,7 @@ class TestMalformedGetExternalIPAddressResponse(UPnPCommandTestCase): b"11.222.3.44\n\n\t\n\n"}) self.addCleanup(self.replies.pop, self.get_ip_request) with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): - gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop) + gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop) await gateway.discover_commands() upnp = UPnP(self.client_address, self.gateway_address, gateway) external_ip = await upnp.get_external_ip() @@ -95,7 +95,7 @@ class TestMalformedGetExternalIPAddressResponse(UPnPCommandTestCase): b"11.222.3.44\n\n\t\n\n"}) self.addCleanup(self.replies.pop, self.get_ip_request) with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): - gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop) + gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop) await gateway.discover_commands() upnp = UPnP(self.client_address, self.gateway_address, gateway) external_ip = await upnp.get_external_ip() @@ -113,7 +113,7 @@ class TestGetGenericPortMappingEntry(UPnPCommandTestCase): async def test_get_port_mapping_by_index(self): with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): - gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop) + gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop) await gateway.discover_commands() upnp = UPnP(self.client_address, self.gateway_address, gateway) result = await upnp.get_port_mapping_by_index(0) @@ -135,7 +135,7 @@ class TestGetNextPortMapping(UPnPCommandTestCase): async def test_get_next_mapping(self): with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): - gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop) + gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop) await gateway.discover_commands() upnp = UPnP(self.client_address, self.gateway_address, gateway) ext_port = await upnp.get_next_mapping(4567, "UDP", "aioupnp test mapping") @@ -155,7 +155,7 @@ class TestGetSpecificPortMapping(UPnPCommandTestCase): async def test_get_specific_port_mapping(self): with mock_tcp_and_udp(self.loop, tcp_replies=self.replies): - gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop) + gateway = Gateway(self.reply, self.client_address, self.gateway_address, loop=self.loop) await gateway.discover_commands() upnp = UPnP(self.client_address, self.gateway_address, gateway) try: