Call .exception() if both discoveries fail together, else return result

This commit is contained in:
hackrush 2019-02-02 12:22:20 +05:30
parent daaa41ba3d
commit ee8ff02d52
3 changed files with 22 additions and 15 deletions

View file

@ -197,19 +197,26 @@ class Gateway:
igd_args: OrderedDict = None, loop=None, unicast: bool = None): igd_args: OrderedDict = None, loop=None, unicast: bool = None):
if unicast is not None: if unicast is not None:
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast) return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop, unicast)
return await cls._discover_gateway(lan_address, gateway_address, timeout, igd_args, loop) loop = loop or asyncio.get_event_loop()
done, pending = await asyncio.wait([ with_unicast = loop.create_task(cls._discover_gateway(
cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop, unicast=True lan_address, gateway_address, timeout, igd_args, loop, unicast=True
), ))
cls._discover_gateway( without_unicast = loop.create_task(cls._discover_gateway(
lan_address, gateway_address, timeout, igd_args, loop, unicast=False lan_address, gateway_address, timeout, igd_args, loop, unicast=False
)], return_when=asyncio.tasks.FIRST_COMPLETED ))
) await asyncio.wait([with_unicast, without_unicast], return_when=asyncio.tasks.FIRST_COMPLETED)
for task in list(pending): if with_unicast and not with_unicast.done():
task.cancel() with_unicast.cancel()
result = list(done)[0].result() if without_unicast.done():
return result return without_unicast.result()
elif without_unicast and not without_unicast.done():
without_unicast.cancel()
if with_unicast.done() and not with_unicast.cancelled():
return with_unicast.result()
else:
with_unicast.exception()
without_unicast.exception()
return with_unicast.result()
async def discover_commands(self, loop=None): async def discover_commands(self, loop=None):
response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port, loop=loop) response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port, loop=loop)

View file

@ -394,7 +394,7 @@ class UPnP:
try: try:
result = fut.result() result = fut.result()
except UPnPError as err: except UPnPError as err:
print("aioupnp encountered an error:\n%s" % str(err)) print("aioupnp encountered an error: %s" % str(err))
return return
if isinstance(result, (list, tuple, dict)): if isinstance(result, (list, tuple, dict)):

View file

@ -85,7 +85,7 @@ class TestCLI(TestBase):
def test_m_search(self): def test_m_search(self):
actual_output = StringIO() actual_output = StringIO()
timeout_msg = "aioupnp encountered an error:\nM-SEARCH for 10.0.0.1:1900 timed out\n" timeout_msg = "aioupnp encountered an error: M-SEARCH for 10.0.0.1:1900 timed out\n"
with contextlib.redirect_stdout(actual_output): with contextlib.redirect_stdout(actual_output):
with mock_tcp_and_udp(self.loop, '10.0.0.1', tcp_replies=self.scpd_replies, udp_replies=self.udp_replies): with mock_tcp_and_udp(self.loop, '10.0.0.1', tcp_replies=self.scpd_replies, udp_replies=self.udp_replies):
main( main(