test ssdp

This commit is contained in:
Jack Robison 2018-10-21 21:11:41 -04:00
parent 1cfe84dcef
commit e9757666ab
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
7 changed files with 170 additions and 25 deletions

View file

@ -9,7 +9,7 @@ jobs:
name: "mypy"
before_install:
- pip install mypy lxml
- pip install -e .
- pip install -e .[test]
script:
- mypy . --txt-report . --scripts-are-modules; cat index.txt; rm index.txt
@ -19,7 +19,7 @@ jobs:
name: "Unit Tests w/ Python 3.7"
before_install:
- pip install pylint coverage
- pip install -e .
- pip install -e .[test]
script:
- HOME=/tmp coverage run --source=aioupnp -m unittest -v

View file

@ -271,7 +271,6 @@ class Gateway:
name, param_types, return_types, inputs, outputs, soap_socket)
setattr(command, "__doc__", current.__doc__)
setattr(self.commands, command.method, command)
self._registered_commands[command.method] = service.serviceType
log.debug("registered %s::%s", service.serviceType, command.method)
except AttributeError:

View file

@ -1,5 +1,4 @@
import re
import socket
import binascii
import asyncio
import logging
@ -45,8 +44,6 @@ class SSDPProtocol(MulticastProtocol):
a, s = t[0], t[1]
if (address == a) and (s in [packet.st, "upnp:rootdevice"]):
f: Future = t[2]
# h: asyncio.Handle = t[3]
# h.cancel()
if f not in set_futures:
set_futures.append(f)
if not f.done():
@ -68,7 +65,6 @@ class SSDPProtocol(MulticastProtocol):
for datagram in datagrams:
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram)
assert packet.st is not None
# h = asyncio.get_running_loop().call_later(timeout, fut.cancel)
self._pending_searches.append((address, packet.st, fut))
packets.append(packet)
self.send_many_m_searches(address, packets),
@ -108,18 +104,19 @@ class SSDPProtocol(MulticastProtocol):
# return
async def listen_ssdp(lan_address: str, gateway_address: str, ssdp_socket: socket.socket = None,
async def listen_ssdp(lan_address: str, gateway_address: str, loop=None,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[DatagramTransport,
SSDPProtocol, str, str]:
loop = asyncio.get_event_loop_policy().get_event_loop()
loop = loop or asyncio.get_event_loop_policy().get_event_loop()
try:
sock = ssdp_socket or SSDPProtocol.create_multicast_socket(lan_address)
sock = SSDPProtocol.create_multicast_socket(lan_address)
listen_result: typing.Tuple = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
)
transport: DatagramTransport = listen_result[0]
protocol: SSDPProtocol = listen_result[1]
except Exception as err:
print(err)
raise UPnPError(err)
try:
protocol.join_group(protocol.multicast_address, protocol.bind_address)
@ -132,24 +129,25 @@ async def listen_ssdp(lan_address: str, gateway_address: str, ssdp_socket: socke
async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1,
ssdp_socket: socket.socket = None, ignored: typing.Set[str] = None,
loop=None, ignored: typing.Set[str] = None,
unicast: bool = False) -> SSDPDatagram:
transport, protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, ssdp_socket, ignored, unicast
lan_address, gateway_address, loop, ignored, unicast
)
try:
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
return await asyncio.wait_for(
protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]), timeout
)
except (asyncio.TimeoutError, asyncio.CancelledError):
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
finally:
protocol.disconnect()
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
ssdp_socket: socket.socket = None,
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, loop=None,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.List[OrderedDict]:
transport, protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, ssdp_socket, ignored, unicast
lan_address, gateway_address, loop, ignored, unicast
)
packet_args = list(packet_generator())
batch_size = 2
@ -168,17 +166,15 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
ssdp_socket: socket.socket = None,
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, loop=None,
ignored: typing.Set[str] = None, unicast: bool = False) -> typing.Tuple[OrderedDict,
SSDPDatagram]:
# we don't know which packet the gateway replies to, so send small batches at a time
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket, ignored, unicast)
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast)
# check the args in the batch that got a reply one at a time to see which one worked
for args in args_to_try:
try:
packet = await m_search(lan_address, gateway_address, args, 3, ignored=ignored, unicast=unicast)
packet = await m_search(lan_address, gateway_address, args, 3, loop=loop, ignored=ignored, unicast=unicast)
return args, packet
except UPnPError:
continue

View file

@ -0,0 +1,60 @@
import asyncio
import inspect
import contextlib
import socket
import mock
import unittest
def async_test(f):
def wrapper(*args, **kwargs):
if inspect.iscoroutinefunction(f):
future = f(*args, **kwargs)
else:
coroutine = asyncio.coroutine(f)
future = coroutine(*args, **kwargs)
asyncio.get_event_loop().run_until_complete(future)
return wrapper
class TestBase(unittest.TestCase):
def setUp(self):
self.loop = asyncio.get_event_loop_policy().get_event_loop()
@contextlib.contextmanager
def mock_datagram_endpoint_factory(loop, expected_addr, replies=None, delay_reply=0.0, sent_packets=None):
sent_packets = sent_packets if sent_packets is not None else []
replies = replies or {}
def sendto(p: asyncio.DatagramProtocol):
def _sendto(data, addr):
sent_packets.append(data)
if (data, addr) in replies:
loop.call_later(delay_reply, p.datagram_received, replies[(data, addr)], (expected_addr, 1900))
return _sendto
async def create_datagram_endpoint(proto_lam, sock=None):
protocol = proto_lam()
transport = asyncio.DatagramTransport(extra={'socket': mock_sock})
transport.close = lambda: mock_sock.close()
mock_sock.sendto = sendto(protocol)
transport.sendto = mock_sock.sendto
protocol.connection_made(transport)
return transport, protocol
with mock.patch('socket.socket') as mock_socket:
mock_sock = mock.Mock(spec=socket.socket)
mock_sock.setsockopt = lambda *_: None
mock_sock.bind = lambda *_: None
mock_sock.setblocking = lambda *_: None
mock_sock.getsockname = lambda: "0.0.0.0"
mock_sock.getpeername = lambda: ""
mock_sock.close = lambda: None
mock_sock.type = socket.SOCK_DGRAM
mock_sock.fileno = lambda: 7
mock_socket.return_value = mock_sock
loop.create_datagram_endpoint = create_datagram_endpoint
yield

View file

@ -0,0 +1,86 @@
from collections import OrderedDict
from aioupnp.fault import UPnPError
from aioupnp.protocols.m_search_patterns import packet_generator
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.constants import SSDP_IP_ADDRESS
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search
from aioupnp.protocols.test_common import TestBase, async_test, mock_datagram_endpoint_factory
class TestSSDP(TestBase):
packet_args = list(packet_generator())
byte_packets = [SSDPDatagram("M-SEARCH", p).encode().encode() for p in packet_args]
successful_args = OrderedDict([
("HOST", "239.255.255.250:1900"),
("MAN", "ssdp:discover"),
("MX", 1),
("ST", "urn:schemas-upnp-org:device:WANDevice:1")
])
query_packet = SSDPDatagram("M-SEARCH", successful_args)
reply_args = OrderedDict([
("CACHE_CONTROL", "max-age=1800"),
("LOCATION", "http://10.0.0.1:49152/InternetGatewayDevice.xml"),
("SERVER", "Linux, UPnP/1.0, DIR-890L Ver 1.20"),
("ST", "urn:schemas-upnp-org:device:WANDevice:1"),
("USN", "uuid:22222222-3333-4444-5555-666666666666::urn:schemas-upnp-org:device:WANDevice:1")
])
reply_packet = SSDPDatagram("OK", reply_args)
@async_test
async def test_m_search_reply_unicast(self):
replies = {
(self.query_packet.encode().encode(), ("10.0.0.1", 1900)): self.reply_packet.encode().encode()
}
sent = []
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies, sent_packets=sent):
reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=True)
self.assertEqual(reply.encode(), self.reply_packet.encode())
self.assertListEqual(sent, [self.query_packet.encode().encode()])
with self.assertRaises(UPnPError):
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies):
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=False)
@async_test
async def test_m_search_reply_multicast(self):
replies = {
(self.query_packet.encode().encode(), (SSDP_IP_ADDRESS, 1900)): self.reply_packet.encode().encode()
}
sent = []
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies, sent_packets=sent):
reply = await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop)
self.assertEqual(reply.encode(), self.reply_packet.encode())
self.assertListEqual(sent, [self.query_packet.encode().encode()])
with self.assertRaises(UPnPError):
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies):
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop, unicast=True)
@async_test
async def test_packets_sent_fuzzy_m_search(self):
sent = []
with self.assertRaises(UPnPError):
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", sent_packets=sent):
await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop)
self.assertListEqual(sent, self.byte_packets)
@async_test
async def test_packets_fuzzy_m_search(self):
replies = {
(self.query_packet.encode().encode(), (SSDP_IP_ADDRESS, 1900)): self.reply_packet.encode().encode()
}
sent = []
with mock_datagram_endpoint_factory(self.loop, "10.0.0.1", replies=replies, sent_packets=sent):
args, reply = await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop)
self.assertEqual(reply.encode(), self.reply_packet.encode())
self.assertEqual(args, self.successful_args)

View file

@ -61,8 +61,7 @@ class UPnP:
@classmethod
@cli
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
igd_args: OrderedDict = None, interface_name: str = 'default',
ssdp_socket: socket.socket = None) -> Dict:
igd_args: OrderedDict = None, interface_name: str = 'default') -> Dict:
try:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
assert gateway_address and lan_address
@ -70,10 +69,10 @@ class UPnP:
raise UPnPError("failed to get lan and gateway addresses for interface \"%s\": %s" % (interface_name,
str(err)))
if not igd_args:
igd_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket)
igd_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout)
else:
igd_args = OrderedDict(igd_args)
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, ssdp_socket)
datagram = await m_search(lan_address, gateway_address, igd_args, timeout)
return {
'lan_address': lan_address,
'gateway_address': gateway_address,

View file

@ -27,4 +27,9 @@ setup(
install_requires=[
'netifaces',
],
extras_require={
'test': (
'mock',
)
}
)