From 8e07d7f39007c16d112c929cd4b2e1935aec8a04 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Fri, 25 Oct 2019 14:47:43 -0400 Subject: [PATCH] add tests + coverage -test_socket_setup_error -ssdp test_deadbeef_response --- aioupnp/gateway.py | 4 ++-- aioupnp/protocols/ssdp.py | 23 +++++++++-------------- aioupnp/serialization/ssdp.py | 7 ++++++- tests/__init__.py | 11 +++++++++-- tests/protocols/test_ssdp.py | 15 +++++++++++++++ 5 files changed, 41 insertions(+), 19 deletions(-) diff --git a/aioupnp/gateway.py b/aioupnp/gateway.py index 282ff63..d59b889 100644 --- a/aioupnp/gateway.py +++ b/aioupnp/gateway.py @@ -188,7 +188,7 @@ class Gateway: igd_args: typing.Dict[str, typing.Union[int, str]], timeout: int = 30, loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway': - datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, set()) + datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop) gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop) if not gateway: raise UPnPError("no gateway found for given args") @@ -199,7 +199,7 @@ class Gateway: loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway': ignored: typing.Set[str] = set() ssdp_proto = await multi_m_search( - lan_address, gateway_address, timeout, loop, ignored + lan_address, gateway_address, timeout, loop ) try: while True: diff --git a/aioupnp/protocols/ssdp.py b/aioupnp/protocols/ssdp.py index 17a3ba9..c7babb8 100644 --- a/aioupnp/protocols/ssdp.py +++ b/aioupnp/protocols/ssdp.py @@ -24,12 +24,11 @@ class PendingSearch(typing.NamedTuple): class SSDPProtocol(MulticastProtocol): - def __init__(self, multicast_address: str, lan_address: str, ignored: Optional[Set[str]] = None, + def __init__(self, multicast_address: str, lan_address: str, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: super().__init__(multicast_address, lan_address) self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() self.transport: Optional[DatagramTransport] = None - self._ignored: Set[str] = ignored or set() # ignored locations self._pending_searches: List[PendingSearch] = [] self.notifications: List[SSDPDatagram] = [] self.connected = asyncio.Event(loop=self.loop) @@ -52,9 +51,6 @@ class SSDPProtocol(MulticastProtocol): return None def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None: - if packet.location in self._ignored: - return - futures: Set['asyncio.Future[SSDPDatagram]'] = set() replied: List[PendingSearch] = [] @@ -137,13 +133,13 @@ class SSDPProtocol(MulticastProtocol): # return -async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asyncio.AbstractEventLoop] = None, - ignored: Optional[Set[str]] = None) -> Tuple[SSDPProtocol, str, str]: +async def listen_ssdp(lan_address: str, gateway_address: str, + loop: Optional[asyncio.AbstractEventLoop] = None) -> Tuple[SSDPProtocol, str, str]: loop = loop or asyncio.get_event_loop() try: sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address) listen_result: Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint( - lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored), sock=sock + lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address), sock=sock ) protocol = listen_result[1] assert isinstance(protocol, SSDPProtocol) @@ -156,10 +152,10 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asy async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[str, typing.Union[int, str]], - timeout: int = 1, loop: Optional[asyncio.AbstractEventLoop] = None, - ignored: Set[str] = None) -> SSDPDatagram: + timeout: int = 1, + loop: Optional[asyncio.AbstractEventLoop] = None) -> SSDPDatagram: protocol, gateway_address, lan_address = await listen_ssdp( - lan_address, gateway_address, loop, ignored + lan_address, gateway_address, loop ) try: return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]) @@ -170,11 +166,10 @@ async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[s async def multi_m_search(lan_address: str, gateway_address: str, timeout: int = 3, - loop: Optional[asyncio.AbstractEventLoop] = None, - ignored: Set[str] = None) -> SSDPProtocol: + loop: Optional[asyncio.AbstractEventLoop] = None) -> SSDPProtocol: loop = loop or asyncio.get_event_loop() protocol, gateway_address, lan_address = await listen_ssdp( - lan_address, gateway_address, loop, ignored + lan_address, gateway_address, loop ) fut = asyncio.ensure_future(protocol.send_m_searches( address=gateway_address, datagrams=list(packet_generator()) diff --git a/aioupnp/serialization/ssdp.py b/aioupnp/serialization/ssdp.py index 6cd9d50..9a252f5 100644 --- a/aioupnp/serialization/ssdp.py +++ b/aioupnp/serialization/ssdp.py @@ -164,7 +164,12 @@ class SSDPDatagram: @classmethod def decode(cls, datagram: bytes) -> 'SSDPDatagram': - packet = cls._from_string(datagram.decode()) + try: + packet = cls._from_string(datagram.decode()) + except UnicodeDecodeError: + raise UPnPError( + f"failed to decode datagram: {binascii.hexlify(datagram).decode()}" + ) if packet is None: raise UPnPError( f"failed to decode datagram: {binascii.hexlify(datagram).decode()}" diff --git a/tests/__init__.py b/tests/__init__.py index acd71a5..62f39b9 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -16,7 +16,8 @@ except ImportError: @contextlib.contextmanager def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None, - tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False): + tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False, + raise_oserror_on_bind=False): sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else [] udp_replies = udp_replies or {} @@ -72,7 +73,13 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r with mock.patch('socket.socket') as mock_socket: mock_sock = mock.Mock(spec=socket.socket) mock_sock.setsockopt = lambda *_: None - mock_sock.bind = lambda *_: None + + def bind(*_): + if raise_oserror_on_bind: + raise OSError() + return + + mock_sock.bind = bind mock_sock.setblocking = lambda *_: None mock_sock.getsockname = lambda: "0.0.0.0" mock_sock.getpeername = lambda: "" diff --git a/tests/protocols/test_ssdp.py b/tests/protocols/test_ssdp.py index 3beac3f..fec6621 100644 --- a/tests/protocols/test_ssdp.py +++ b/tests/protocols/test_ssdp.py @@ -28,6 +28,11 @@ class TestSSDP(AsyncioTestCase): ]) reply_packet = SSDPDatagram("OK", reply_args) + async def test_socket_setup_error(self): + with mock_tcp_and_udp(self.loop, raise_oserror_on_bind=True): + with self.assertRaises(UPnPError): + await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop) + async def test_transport_not_connected_error(self): try: await SSDPProtocol('', '').m_search('1.2.3.4', 2, [self.query_packet.as_dict()]) @@ -35,6 +40,16 @@ class TestSSDP(AsyncioTestCase): except UPnPError as err: self.assertEqual(str(err), "SSDP transport not connected") + async def test_deadbeef_response(self): + replies = { + (self.query_packet.encode().encode(), ("10.0.0.1", 1900)): b'\xde\xad\xbe\xef' + } + sent = [] + + with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1", sent_udp_packets=sent): + with self.assertRaises(UPnPError): + await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop) + async def test_m_search_reply_unicast(self): replies = { (self.query_packet.encode().encode(), ("10.0.0.1", 1900)): self.reply_packet.encode().encode()