add tests + coverage

-test_socket_setup_error
-ssdp test_deadbeef_response
This commit is contained in:
Jack Robison 2019-10-25 14:47:43 -04:00
parent 1b25c009ca
commit 8e07d7f390
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
5 changed files with 41 additions and 19 deletions

View file

@ -188,7 +188,7 @@ class Gateway:
igd_args: typing.Dict[str, typing.Union[int, str]], igd_args: typing.Dict[str, typing.Union[int, str]],
timeout: int = 30, timeout: int = 30,
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway': loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, set()) datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop)
gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop) gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop)
if not gateway: if not gateway:
raise UPnPError("no gateway found for given args") raise UPnPError("no gateway found for given args")
@ -199,7 +199,7 @@ class Gateway:
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway': loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
ignored: typing.Set[str] = set() ignored: typing.Set[str] = set()
ssdp_proto = await multi_m_search( ssdp_proto = await multi_m_search(
lan_address, gateway_address, timeout, loop, ignored lan_address, gateway_address, timeout, loop
) )
try: try:
while True: while True:

View file

@ -24,12 +24,11 @@ class PendingSearch(typing.NamedTuple):
class SSDPProtocol(MulticastProtocol): class SSDPProtocol(MulticastProtocol):
def __init__(self, multicast_address: str, lan_address: str, ignored: Optional[Set[str]] = None, def __init__(self, multicast_address: str, lan_address: str,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None: loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(multicast_address, lan_address) super().__init__(multicast_address, lan_address)
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
self.transport: Optional[DatagramTransport] = None self.transport: Optional[DatagramTransport] = None
self._ignored: Set[str] = ignored or set() # ignored locations
self._pending_searches: List[PendingSearch] = [] self._pending_searches: List[PendingSearch] = []
self.notifications: List[SSDPDatagram] = [] self.notifications: List[SSDPDatagram] = []
self.connected = asyncio.Event(loop=self.loop) self.connected = asyncio.Event(loop=self.loop)
@ -52,9 +51,6 @@ class SSDPProtocol(MulticastProtocol):
return None return None
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None: def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
if packet.location in self._ignored:
return
futures: Set['asyncio.Future[SSDPDatagram]'] = set() futures: Set['asyncio.Future[SSDPDatagram]'] = set()
replied: List[PendingSearch] = [] replied: List[PendingSearch] = []
@ -137,13 +133,13 @@ class SSDPProtocol(MulticastProtocol):
# return # return
async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asyncio.AbstractEventLoop] = None, async def listen_ssdp(lan_address: str, gateway_address: str,
ignored: Optional[Set[str]] = None) -> Tuple[SSDPProtocol, str, str]: loop: Optional[asyncio.AbstractEventLoop] = None) -> Tuple[SSDPProtocol, str, str]:
loop = loop or asyncio.get_event_loop() loop = loop or asyncio.get_event_loop()
try: try:
sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address) sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address)
listen_result: Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint( listen_result: Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored), sock=sock lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address), sock=sock
) )
protocol = listen_result[1] protocol = listen_result[1]
assert isinstance(protocol, SSDPProtocol) assert isinstance(protocol, SSDPProtocol)
@ -156,10 +152,10 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asy
async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[str, typing.Union[int, str]], async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[str, typing.Union[int, str]],
timeout: int = 1, loop: Optional[asyncio.AbstractEventLoop] = None, timeout: int = 1,
ignored: Set[str] = None) -> SSDPDatagram: loop: Optional[asyncio.AbstractEventLoop] = None) -> SSDPDatagram:
protocol, gateway_address, lan_address = await listen_ssdp( protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored lan_address, gateway_address, loop
) )
try: try:
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args]) return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
@ -170,11 +166,10 @@ async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[s
async def multi_m_search(lan_address: str, gateway_address: str, timeout: int = 3, async def multi_m_search(lan_address: str, gateway_address: str, timeout: int = 3,
loop: Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None) -> SSDPProtocol:
ignored: Set[str] = None) -> SSDPProtocol:
loop = loop or asyncio.get_event_loop() loop = loop or asyncio.get_event_loop()
protocol, gateway_address, lan_address = await listen_ssdp( protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored lan_address, gateway_address, loop
) )
fut = asyncio.ensure_future(protocol.send_m_searches( fut = asyncio.ensure_future(protocol.send_m_searches(
address=gateway_address, datagrams=list(packet_generator()) address=gateway_address, datagrams=list(packet_generator())

View file

@ -164,7 +164,12 @@ class SSDPDatagram:
@classmethod @classmethod
def decode(cls, datagram: bytes) -> 'SSDPDatagram': def decode(cls, datagram: bytes) -> 'SSDPDatagram':
packet = cls._from_string(datagram.decode()) try:
packet = cls._from_string(datagram.decode())
except UnicodeDecodeError:
raise UPnPError(
f"failed to decode datagram: {binascii.hexlify(datagram).decode()}"
)
if packet is None: if packet is None:
raise UPnPError( raise UPnPError(
f"failed to decode datagram: {binascii.hexlify(datagram).decode()}" f"failed to decode datagram: {binascii.hexlify(datagram).decode()}"

View file

@ -16,7 +16,8 @@ except ImportError:
@contextlib.contextmanager @contextlib.contextmanager
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None, def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False): tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False,
raise_oserror_on_bind=False):
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else [] sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
udp_replies = udp_replies or {} udp_replies = udp_replies or {}
@ -72,7 +73,13 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r
with mock.patch('socket.socket') as mock_socket: with mock.patch('socket.socket') as mock_socket:
mock_sock = mock.Mock(spec=socket.socket) mock_sock = mock.Mock(spec=socket.socket)
mock_sock.setsockopt = lambda *_: None mock_sock.setsockopt = lambda *_: None
mock_sock.bind = lambda *_: None
def bind(*_):
if raise_oserror_on_bind:
raise OSError()
return
mock_sock.bind = bind
mock_sock.setblocking = lambda *_: None mock_sock.setblocking = lambda *_: None
mock_sock.getsockname = lambda: "0.0.0.0" mock_sock.getsockname = lambda: "0.0.0.0"
mock_sock.getpeername = lambda: "" mock_sock.getpeername = lambda: ""

View file

@ -28,6 +28,11 @@ class TestSSDP(AsyncioTestCase):
]) ])
reply_packet = SSDPDatagram("OK", reply_args) reply_packet = SSDPDatagram("OK", reply_args)
async def test_socket_setup_error(self):
with mock_tcp_and_udp(self.loop, raise_oserror_on_bind=True):
with self.assertRaises(UPnPError):
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop)
async def test_transport_not_connected_error(self): async def test_transport_not_connected_error(self):
try: try:
await SSDPProtocol('', '').m_search('1.2.3.4', 2, [self.query_packet.as_dict()]) await SSDPProtocol('', '').m_search('1.2.3.4', 2, [self.query_packet.as_dict()])
@ -35,6 +40,16 @@ class TestSSDP(AsyncioTestCase):
except UPnPError as err: except UPnPError as err:
self.assertEqual(str(err), "SSDP transport not connected") self.assertEqual(str(err), "SSDP transport not connected")
async def test_deadbeef_response(self):
replies = {
(self.query_packet.encode().encode(), ("10.0.0.1", 1900)): b'\xde\xad\xbe\xef'
}
sent = []
with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1", sent_udp_packets=sent):
with self.assertRaises(UPnPError):
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop)
async def test_m_search_reply_unicast(self): async def test_m_search_reply_unicast(self):
replies = { replies = {
(self.query_packet.encode().encode(), ("10.0.0.1", 1900)): self.reply_packet.encode().encode() (self.query_packet.encode().encode(), ("10.0.0.1", 1900)): self.reply_packet.encode().encode()