This commit is contained in:
Jack Robison 2020-11-02 10:52:36 -05:00
parent 33f9597939
commit cc69e88d1a
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
6 changed files with 54 additions and 32 deletions

View file

@ -134,7 +134,7 @@ class SOAPCommands:
def is_registered(self, name: str) -> bool: def is_registered(self, name: str) -> bool:
if name not in self.SOAP_COMMANDS: if name not in self.SOAP_COMMANDS:
raise ValueError("unknown command") raise ValueError("unknown command") # pragma: no cover
for service in self._registered.values(): for service in self._registered.values():
if name in service: if name in service:
return True return True
@ -142,11 +142,11 @@ class SOAPCommands:
def get_service(self, name: str) -> Service: def get_service(self, name: str) -> Service:
if name not in self.SOAP_COMMANDS: if name not in self.SOAP_COMMANDS:
raise ValueError("unknown command") raise ValueError("unknown command") # pragma: no cover
for service, commands in self._registered.items(): for service, commands in self._registered.items():
if name in commands: if name in commands:
return service return service
raise ValueError(name) raise ValueError(name) # pragma: no cover
def _register_soap_wrapper(self, name: str) -> None: def _register_soap_wrapper(self, name: str) -> None:
annotations: typing.Dict[str, typing.Any] = typing.get_type_hints(getattr(self, name)) annotations: typing.Dict[str, typing.Any] = typing.get_type_hints(getattr(self, name))
@ -173,7 +173,7 @@ class SOAPCommands:
self._request_debug_infos.append(SCPDRequestDebuggingInfo(name, kwargs, xml_bytes, result, None, time.time())) self._request_debug_infos.append(SCPDRequestDebuggingInfo(name, kwargs, xml_bytes, result, None, time.time()))
except Exception as err: except Exception as err:
if isinstance(err, asyncio.CancelledError): if isinstance(err, asyncio.CancelledError):
raise raise # pragma: no cover
self._request_debug_infos.append(SCPDRequestDebuggingInfo(name, kwargs, xml_bytes, None, err, time.time())) self._request_debug_infos.append(SCPDRequestDebuggingInfo(name, kwargs, xml_bytes, None, err, time.time()))
raise UPnPError(f"Raised {str(type(err).__name__)}({str(err)}) parsing response for {name}") raise UPnPError(f"Raised {str(type(err).__name__)}({str(err)}) parsing response for {name}")
return result return result
@ -253,8 +253,8 @@ class SOAPCommands:
raise NotImplementedError() # pragma: no cover raise NotImplementedError() # pragma: no cover
assert name in self._wrappers_no_args assert name in self._wrappers_no_args
result: str = await self._wrappers_no_args[name]() result: str = await self._wrappers_no_args[name]()
if not result: # if not result:
raise UPnPError("Got null external ip address") # raise UPnPError("Got null external ip address")
return result return result
# async def GetNATRSIPStatus(self) -> Tuple[bool, bool]: # async def GetNATRSIPStatus(self) -> Tuple[bool, bool]:

View file

@ -1,4 +1,3 @@
import socket
from collections import OrderedDict from collections import OrderedDict
import typing import typing
import netifaces import netifaces
@ -25,11 +24,11 @@ def _get_gateways() -> typing.Dict[typing.Union[str, int],
def get_interfaces() -> typing.Dict[str, typing.Tuple[str, str]]: def get_interfaces() -> typing.Dict[str, typing.Tuple[str, str]]:
gateways = _get_gateways() gateways = _get_gateways()
infos = gateways[socket.AF_INET] infos = gateways[netifaces.AF_INET]
assert isinstance(infos, list), TypeError(f"expected list from netifaces, got a dict") assert isinstance(infos, list), TypeError(f"expected list from netifaces, got a dict")
interface_infos: typing.List[typing.Tuple[str, str, bool]] = infos interface_infos: typing.List[typing.Tuple[str, str, bool]] = infos
result: typing.Dict[str, typing.Tuple[str, str]] = OrderedDict( result: typing.Dict[str, typing.Tuple[str, str]] = OrderedDict(
(interface_name, (router_address, ifaddresses(interface_name)[socket.AF_INET][0]['addr'])) (interface_name, (router_address, ifaddresses(interface_name)[netifaces.AF_INET][0]['addr']))
for router_address, interface_name, _ in interface_infos for router_address, interface_name, _ in interface_infos
) )
for interface_name in _get_interfaces(): for interface_name in _get_interfaces():
@ -43,7 +42,7 @@ def get_interfaces() -> typing.Dict[str, typing.Tuple[str, str]]:
_default = gateways['default'] _default = gateways['default']
assert isinstance(_default, dict), TypeError(f"expected dict from netifaces, got a list") assert isinstance(_default, dict), TypeError(f"expected dict from netifaces, got a list")
default: typing.Dict[int, typing.Tuple[str, str]] = _default default: typing.Dict[int, typing.Tuple[str, str]] = _default
result['default'] = result[default[socket.AF_INET][1]] result['default'] = result[default[netifaces.AF_INET][1]]
return result return result

View file

@ -13,7 +13,7 @@ CONTENT_NO_XML_VERSION_PATTERN = re.compile(
def serialize_soap_post(method: str, param_names: typing.List[str], service_id: bytes, gateway_address: bytes, def serialize_soap_post(method: str, param_names: typing.List[str], service_id: bytes, gateway_address: bytes,
control_url: bytes, **kwargs: typing.Dict[str, str]) -> bytes: control_url: bytes, **kwargs: typing.Dict[str, str]) -> bytes:
args = "".join(f"<{n}>{kwargs.get(n)}</{n}>" for n in param_names) args = "".join(f"<{param_name}>{kwargs.get(param_name, '')}</{param_name}>" for param_name in param_names)
soap_body = (f'\r\n{XML_VERSION}\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" ' soap_body = (f'\r\n{XML_VERSION}\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" '
f's:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>' f's:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>'
f'<u:{method} xmlns:u="{service_id.decode()}">{args}</u:{method}></s:Body></s:Envelope>') f'<u:{method} xmlns:u="{service_id.decode()}">{args}</u:{method}></s:Body></s:Envelope>')

View file

@ -410,12 +410,12 @@ def run_cli(method: str, igd_args: Dict[str, Union[bool, str, int]], lan_address
u = await UPnP.discover( u = await UPnP.discover(
lan_address, gateway_address, timeout, igd_args, interface_name, loop=loop lan_address, gateway_address, timeout, igd_args, interface_name, loop=loop
) )
except UPnPError as err: except UPnPError as err: # pragma: no cover
fut.set_exception(err) fut.set_exception(err)
return return
if method not in cli_commands: if method not in cli_commands:
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method)) fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method)) # pragma: no cover
return return # pragma: no cover
else: else:
fn = getattr(u, method) fn = getattr(u, method)
@ -424,8 +424,7 @@ def run_cli(method: str, igd_args: Dict[str, Union[bool, str, int]], lan_address
fut.set_result(result) fut.set_result(result)
except UPnPError as err: except UPnPError as err:
fut.set_exception(err) fut.set_exception(err)
except Exception as err: # pragma: no cover
except Exception as err:
log.exception("uncaught error") log.exception("uncaught error")
fut.set_exception(UPnPError("uncaught error: %s" % str(err))) fut.set_exception(UPnPError("uncaught error: %s" % str(err)))

View file

@ -368,3 +368,8 @@ class TestActiontec(AsyncioTestCase):
class TestNewMediaNet(TestActiontec): class TestNewMediaNet(TestActiontec):
name = "NewMedia-NET GmbH Generic X86" name = "NewMedia-NET GmbH Generic X86"
async def replay(self, u: UPnP):
self.assertEqual('11.222.33.111', await u.get_external_ip())
await u.get_redirects()
# print(await u.get_next_mapping(4567, 'UDP', 'aioupnp test mapping'))

View file

@ -1,4 +1,6 @@
from unittest import mock from unittest import mock
from collections import OrderedDict
from aioupnp import interfaces
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.upnp import UPnP from aioupnp.upnp import UPnP
from tests import AsyncioTestCase from tests import AsyncioTestCase
@ -29,7 +31,6 @@ class mock_netifaces:
@staticmethod @staticmethod
def ifaddresses(interface): def ifaddresses(interface):
return { return {
"test0": {
17: [ 17: [
{ {
"addr": "01:02:03:04:05:06", "addr": "01:02:03:04:05:06",
@ -43,8 +44,13 @@ class mock_netifaces:
"broadcast": "192.168.1.255" "broadcast": "192.168.1.255"
} }
], ],
}, }
}[interface]
class mock_netifaces_extra_interface(mock_netifaces):
@staticmethod
def interfaces():
return ['lo', 'test0', 'test1']
class TestParseInterfaces(AsyncioTestCase): class TestParseInterfaces(AsyncioTestCase):
@ -68,3 +74,16 @@ class TestParseInterfaces(AsyncioTestCase):
else: else:
self.assertTrue(False) self.assertTrue(False)
self.assertEqual(len(checked), 1) self.assertEqual(len(checked), 1)
def test_guess_gateway(self):
# handle edge case where netifaces gives more interfaces than it does gateways
with mock.patch('aioupnp.interfaces.get_netifaces') as patch:
patch.return_value = mock_netifaces_extra_interface
self.assertDictEqual(
OrderedDict(
[
('test0', ('192.168.1.1', '192.168.1.2')),
('test1', ('192.168.1.1', '192.168.1.2')),
('default', ('192.168.1.1', '192.168.1.2'))
]), interfaces.get_interfaces()
)