Use done and pending, add unit test

This commit is contained in:
hackrush 2019-02-06 00:46:49 +05:30
parent ee8ff02d52
commit ca5ff5f225
2 changed files with 35 additions and 23 deletions

View file

@ -155,8 +155,8 @@ class Gateway:
}
@classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int=30,
igd_args: OrderedDict=None, loop=None, unicast: bool=False):
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
igd_args: OrderedDict = None, loop=None, unicast: bool = False):
ignored: set = set()
required_commands = [
'AddPortMapping',
@ -181,7 +181,7 @@ class Gateway:
required for required in required_commands if required not in gateway._registered_commands
]
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
gateway.manufacturer_string, gateway.location, not_met)
gateway.manufacturer_string, gateway.location, not_met)
ignored.add(datagram.location)
continue
else:
@ -197,26 +197,25 @@ class Gateway:
igd_args: OrderedDict = None, loop=None, unicast: bool = None):
if unicast is not None:
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast)
loop = loop or asyncio.get_event_loop()
with_unicast = loop.create_task(cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop, unicast=True
))
without_unicast = loop.create_task(cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop, unicast=False
))
await asyncio.wait([with_unicast, without_unicast], return_when=asyncio.tasks.FIRST_COMPLETED)
if with_unicast and not with_unicast.done():
with_unicast.cancel()
if without_unicast.done():
return without_unicast.result()
elif without_unicast and not without_unicast.done():
without_unicast.cancel()
if with_unicast.done() and not with_unicast.cancelled():
return with_unicast.result()
else:
with_unicast.exception()
without_unicast.exception()
return with_unicast.result()
done, pending = await asyncio.wait([
cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop, unicast=True
),
cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop, unicast=False
)
], return_when=asyncio.tasks.FIRST_COMPLETED)
for task in pending:
task.cancel()
for task in done:
try:
task.exception()
except asyncio.CancelledError:
pass
return list(done)[0].result()
async def discover_commands(self, loop=None):
response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port, loop=loop)

View file

@ -1,3 +1,6 @@
import asyncio
from aioupnp.fault import UPnPError
from tests import TestBase
from tests.mocks import mock_tcp_and_udp
from collections import OrderedDict
@ -62,6 +65,16 @@ class TestDiscoverDLinkDIR890L(TestBase):
'GetExternalIPAddress': 'urn:schemas-upnp-org:service:WANIPConnection:1'
}
async def test_discover_gateway(self):
with self.assertRaises(UPnPError) as e1:
with mock_tcp_and_udp(self.loop):
await Gateway.discover_gateway("10.0.0.2", "10.0.0.1", 2)
with self.assertRaises(UPnPError) as e2:
with mock_tcp_and_udp(self.loop):
await Gateway.discover_gateway("10.0.0.2", "10.0.0.1", 2, unicast=False)
self.assertEqual(str(e1.exception), "M-SEARCH for 10.0.0.1:1900 timed out")
self.assertEqual(str(e2.exception), "M-SEARCH for 10.0.0.1:1900 timed out")
async def test_discover_commands(self):
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address)