diff --git a/.travis.yml b/.travis.yml index a97073d..cfb1f56 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,7 +9,7 @@ jobs: name: "mypy" before_install: - pip install mypy lxml - - pip install -e . + - pip install -e .[test] script: - mypy . --txt-report . --scripts-are-modules; cat index.txt; rm index.txt @@ -19,7 +19,7 @@ jobs: name: "Unit Tests w/ Python 3.7" before_install: - pip install pylint coverage - - pip install -e . + - pip install -e .[test] script: - HOME=/tmp coverage run --source=aioupnp -m unittest -v diff --git a/aioupnp/gateway.py b/aioupnp/gateway.py index 3e9ea45..02242c8 100644 --- a/aioupnp/gateway.py +++ b/aioupnp/gateway.py @@ -271,7 +271,6 @@ class Gateway: name, param_types, return_types, inputs, outputs, soap_socket) setattr(command, "__doc__", current.__doc__) setattr(self.commands, command.method, command) - self._registered_commands[command.method] = service.serviceType log.debug("registered %s::%s", service.serviceType, command.method) except AttributeError: diff --git a/aioupnp/protocols/ssdp.py b/aioupnp/protocols/ssdp.py index 5238c2c..7a4e694 100644 --- a/aioupnp/protocols/ssdp.py +++ b/aioupnp/protocols/ssdp.py @@ -1,5 +1,4 @@ import re -import socket import binascii import asyncio import logging @@ -45,8 +44,6 @@ class SSDPProtocol(MulticastProtocol): a, s = t[0], t[1] if (address == a) and (s in [packet.st, "upnp:rootdevice"]): f: Future = t[2] - # h: asyncio.Handle = t[3] - # h.cancel() if f not in set_futures: set_futures.append(f) if not f.done(): @@ -68,7 +65,6 @@ class SSDPProtocol(MulticastProtocol): for datagram in datagrams: packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram) assert packet.st is not None - # h = asyncio.get_running_loop().call_later(timeout, fut.cancel) self._pending_searches.append((address, packet.st, fut)) packets.append(packet) self.send_many_m_searches(address, packets), @@ -108,18 +104,19 @@ class SSDPProtocol(MulticastProtocol): # return -async def listen_ssdp(lan_address: str, gateway_address: str, ssdp_socket: socket.socket = None, +async def listen_ssdp(lan_address: str, gateway_address: str, loop=None, ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[DatagramTransport, SSDPProtocol, str, str]: - loop = asyncio.get_event_loop_policy().get_event_loop() + loop = loop or asyncio.get_event_loop_policy().get_event_loop() try: - sock = ssdp_socket or SSDPProtocol.create_multicast_socket(lan_address) + sock = SSDPProtocol.create_multicast_socket(lan_address) listen_result: typing.Tuple = await loop.create_datagram_endpoint( lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock ) transport: DatagramTransport = listen_result[0] protocol: SSDPProtocol = listen_result[1] except Exception as err: + print(err) raise UPnPError(err) try: protocol.join_group(protocol.multicast_address, protocol.bind_address) @@ -132,24 +129,25 @@ async def listen_ssdp(lan_address: str, gateway_address: str, ssdp_socket: socke async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1, - ssdp_socket: socket.socket = None, ignored: typing.Set[str] = None, + loop=None, ignored: typing.Set[str] = None, unicast: bool = False) -> SSDPDatagram: transport, protocol, gateway_address, lan_address = await listen_ssdp( - lan_address, gateway_address, ssdp_socket, ignored, unicast + lan_address, gateway_address, loop, ignored, unicast ) try: - return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]) + return await asyncio.wait_for( + protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]), timeout + ) except (asyncio.TimeoutError, asyncio.CancelledError): raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) finally: protocol.disconnect() -async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, - ssdp_socket: socket.socket = None, +async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, loop=None, ignored: typing.Set[str] = None, unicast: bool = False) -> typing.List[OrderedDict]: transport, protocol, gateway_address, lan_address = await listen_ssdp( - lan_address, gateway_address, ssdp_socket, ignored, unicast + lan_address, gateway_address, loop, ignored, unicast ) packet_args = list(packet_generator()) batch_size = 2 @@ -168,17 +166,15 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) -async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, - ssdp_socket: socket.socket = None, +async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, loop=None, ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[OrderedDict, SSDPDatagram]: # we don't know which packet the gateway replies to, so send small batches at a time - - args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket, ignored, unicast) + args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast) # check the args in the batch that got a reply one at a time to see which one worked for args in args_to_try: try: - packet = await m_search(lan_address, gateway_address, args, 3, ignored=ignored, unicast=unicast) + packet = await m_search(lan_address, gateway_address, args, 3, loop=loop, ignored=ignored, unicast=unicast) return args, packet except UPnPError: continue diff --git a/aioupnp/protocols/test_common.py b/aioupnp/protocols/test_common.py new file mode 100644 index 0000000..972c8fc --- /dev/null +++ b/aioupnp/protocols/test_common.py @@ -0,0 +1,60 @@ +import asyncio +import inspect +import contextlib +import socket +import mock +import unittest + + +def async_test(f): + def wrapper(*args, **kwargs): + if inspect.iscoroutinefunction(f): + future = f(*args, **kwargs) + else: + coroutine = asyncio.coroutine(f) + future = coroutine(*args, **kwargs) + asyncio.get_event_loop().run_until_complete(future) + + return wrapper + + +class TestBase(unittest.TestCase): + def setUp(self): + self.loop = asyncio.get_event_loop_policy().get_event_loop() + + +@contextlib.contextmanager +def mock_datagram_endpoint_factory(loop, expected_addr, replies=None, delay_reply=0.0, sent_packets=None): + sent_packets = sent_packets if sent_packets is not None else [] + replies = replies or {} + + def sendto(p: asyncio.DatagramProtocol): + def _sendto(data, addr): + sent_packets.append(data) + if (data, addr) in replies: + loop.call_later(delay_reply, p.datagram_received, replies[(data, addr)], (expected_addr, 1900)) + return _sendto + + async def create_datagram_endpoint(proto_lam, sock=None): + protocol = proto_lam() + transport = asyncio.DatagramTransport(extra={'socket': mock_sock}) + transport.close = lambda: mock_sock.close() + mock_sock.sendto = sendto(protocol) + transport.sendto = mock_sock.sendto + protocol.connection_made(transport) + return transport, protocol + + with mock.patch('socket.socket') as mock_socket: + mock_sock = mock.Mock(spec=socket.socket) + mock_sock.setsockopt = lambda *_: None + mock_sock.bind = lambda *_: None + mock_sock.setblocking = lambda *_: None + mock_sock.getsockname = lambda: "0.0.0.0" + mock_sock.getpeername = lambda: "" + mock_sock.close = lambda: None + mock_sock.type = socket.SOCK_DGRAM + mock_sock.fileno = lambda: 7 + + mock_socket.return_value = mock_sock + loop.create_datagram_endpoint = create_datagram_endpoint + yield diff --git a/aioupnp/protocols/test_ssdp.py b/aioupnp/protocols/test_ssdp.py new file mode 100644 index 0000000..5a42330 --- /dev/null +++ b/aioupnp/protocols/test_ssdp.py @@ -0,0 +1,86 @@ +from collections import OrderedDict +from aioupnp.fault import UPnPError +from aioupnp.protocols.m_search_patterns import packet_generator +from aioupnp.serialization.ssdp import SSDPDatagram +from aioupnp.constants import SSDP_IP_ADDRESS +from aioupnp.protocols.ssdp import fuzzy_m_search, m_search +from aioupnp.protocols.test_common import TestBase, async_test, mock_datagram_endpoint_factory + + +class TestSSDP(TestBase): + packet_args = list(packet_generator()) + byte_packets = [SSDPDatagram("M-SEARCH", p).encode().encode() for p in packet_args] + + successful_args = OrderedDict([ + ("HOST", "239.255.255.250:1900"), + ("MAN", "ssdp:discover"), + ("MX", 1), + ("ST", "urn:schemas-upnp-org:device:WANDevice:1") + ]) + query_packet = SSDPDatagram("M-SEARCH", successful_args) + + reply_args = OrderedDict([ + ("CACHE_CONTROL", "max-age=1800"), + ("LOCATION", "http://10.0.0.1:49152/InternetGatewayDevice.xml"), + ("SERVER", "Linux, UPnP/1.0, DIR-890L Ver 1.20"), + ("ST", "urn:schemas-upnp-org:device:WANDevice:1"), + ("USN", "uuid:22222222-3333-4444-5555-666666666666::urn:schemas-upnp-org:device:WANDevice:1") + ]) + reply_packet = SSDPDatagram("OK", reply_args) + + @async_test + async def test_m_search_reply_unicast(self): + replies = { + (self.query_packet.encode().encode(), ("10.0.0.1", 1900)): self.reply_packet.encode().encode() + } + sent = [] + + with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies, sent_packets=sent): + reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=True) + + self.assertEqual(reply.encode(), self.reply_packet.encode()) + self.assertListEqual(sent, [self.query_packet.encode().encode()]) + + with self.assertRaises(UPnPError): + with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies): + await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=False) + + @async_test + async def test_m_search_reply_multicast(self): + replies = { + (self.query_packet.encode().encode(), (SSDP_IP_ADDRESS, 1900)): self.reply_packet.encode().encode() + } + sent = [] + + with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies, sent_packets=sent): + reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop) + + self.assertEqual(reply.encode(), self.reply_packet.encode()) + self.assertListEqual(sent, [self.query_packet.encode().encode()]) + + with self.assertRaises(UPnPError): + with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies): + await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=True) + + @async_test + async def test_packets_sent_fuzzy_m_search(self): + sent = [] + + with self.assertRaises(UPnPError): + with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", sent_packets=sent): + await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop) + + self.assertListEqual(sent, self.byte_packets) + + @async_test + async def test_packets_fuzzy_m_search(self): + replies = { + (self.query_packet.encode().encode(), (SSDP_IP_ADDRESS, 1900)): self.reply_packet.encode().encode() + } + sent = [] + + with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies, sent_packets=sent): + args, reply = await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop) + + self.assertEqual(reply.encode(), self.reply_packet.encode()) + self.assertEqual(args, self.successful_args) diff --git a/aioupnp/upnp.py b/aioupnp/upnp.py index 8e29498..fe50a62 100644 --- a/aioupnp/upnp.py +++ b/aioupnp/upnp.py @@ -61,8 +61,7 @@ class UPnP: @classmethod @cli async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, - igd_args: OrderedDict = None, interface_name: str = 'default', - ssdp_socket: socket.socket = None) -> Dict: + igd_args: OrderedDict = None, interface_name: str = 'default') -> Dict: try: lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) assert gateway_address and lan_address @@ -70,10 +69,10 @@ class UPnP: raise UPnPError("failed to get lan and gateway addresses for interface \"%s\": %s" % (interface_name, str(err))) if not igd_args: - igd_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket) + igd_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout) else: igd_args = OrderedDict(igd_args) - datagram = await m_search(lan_address, gateway_address, igd_args, timeout, ssdp_socket) + datagram = await m_search(lan_address, gateway_address, igd_args, timeout) return { 'lan_address': lan_address, 'gateway_address': gateway_address, diff --git a/setup.py b/setup.py index cfd75f1..093d062 100644 --- a/setup.py +++ b/setup.py @@ -27,4 +27,9 @@ setup( install_requires=[ 'netifaces', ], + extras_require={ + 'test': ( + 'mock', + ) + } )