simplify tests

-make it easy to add new device test cases
This commit is contained in:
Jack Robison 2019-08-15 16:09:16 -04:00
parent 4741842582
commit 21a379eb8c
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
5 changed files with 196 additions and 151 deletions

View file

@ -85,7 +85,7 @@ class Gateway:
self.urn: bytes = (ok_packet.st or '').encode()
self._xml_response: bytes = b""
self._service_descriptors: Dict[str, bytes] = {}
self._service_descriptors: Dict[str, str] = {}
self.base_address, self.port = parse_location(self.location)
self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0]
@ -139,17 +139,17 @@ class Gateway:
def debug_gateway(self) -> Dict[str, typing.Union[str, bytes, int, Dict, List]]:
return {
'manufacturer_string': self.manufacturer_string,
'gateway_address': self.base_ip,
'gateway_address': self.base_ip.decode(),
'server': self.server.decode(),
'urlBase': self.url_base or '',
'location': self.location.decode(),
"specVersion": self.spec_version or '',
'usn': self.usn.decode(),
'urn': self.urn.decode(),
'gateway_xml': self._xml_response,
'gateway_xml': self._xml_response.decode(),
'services_xml': self._service_descriptors,
'services': {service.SCPDURL: service.as_dict() for service in self._services},
'm_search_args': [(k, v) for (k, v) in self._m_search_args.items()],
'm_search_args': OrderedDict(self._m_search_args),
'reply': self._ok_packet.as_dict(),
'soap_port': self.port,
'registered_soap_commands': self._registered_commands,
@ -179,7 +179,7 @@ class Gateway:
try:
gateway = cls(datagram, m_search_args, lan_address, gateway_address, loop=loop)
log.debug('get gateway descriptor %s', datagram.location)
await gateway.discover_commands(loop)
await gateway.discover_commands()
requirements_met = all([gateway.commands.is_registered(required) for required in required_commands])
if not requirements_met:
not_met = [
@ -227,8 +227,10 @@ class Gateway:
results: typing.List['asyncio.Future[Gateway]'] = list(done)
return results[0].result()
async def discover_commands(self, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None:
response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port, loop=loop)
async def discover_commands(self) -> None:
response, xml_bytes, get_err = await scpd_get(
self.path.decode(), self.base_ip.decode(), self.port, loop=self._loop
)
self._xml_response = xml_bytes
if get_err is not None:
raise get_err
@ -264,7 +266,7 @@ class Gateway:
else:
self._device = Device(self._devices, self._services)
for service_type in self.services.keys():
await self.register_commands(self.services[service_type], loop)
await self.register_commands(self.services[service_type], self._loop)
return None
async def register_commands(self, service: Service,
@ -276,7 +278,7 @@ class Gateway:
log.debug("get descriptor for %s from %s", service.serviceType, service.SCPDURL)
service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port, loop=loop)
self._service_descriptors[service.SCPDURL] = xml_bytes
self._service_descriptors[service.SCPDURL] = xml_bytes.decode()
if get_err is not None:
log.debug("failed to get descriptor for %s from %s", service.serviceType, service.SCPDURL)

View file

@ -45,7 +45,7 @@ class SCPDHTTPClientProtocol(Protocol):
and devices respond with an invalid HTTP version line
"""
def __init__(self, message: bytes, finished: 'asyncio.Future[typing.Tuple[bytes, int, bytes]]',
def __init__(self, message: bytes, finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]',
soap_method: typing.Optional[str] = None, soap_service_id: typing.Optional[str] = None) -> None:
self.message = message
self.response_buff = b""
@ -85,7 +85,7 @@ class SCPDHTTPClientProtocol(Protocol):
self._got_headers = True
body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:])
if self._content_length == len(body):
self.finished.set_result((body, self._response_code, self._response_msg))
self.finished.set_result((self.response_buff, body, self._response_code, self._response_msg))
elif self._content_length > len(body):
pass
else:
@ -105,7 +105,7 @@ async def scpd_get(control_url: str, address: str, port: int,
typing.Dict[str, typing.Any], bytes, typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop()
packet = serialize_scpd_get(control_url, address)
finished: 'asyncio.Future[typing.Tuple[bytes, int, bytes]]' = asyncio.Future(loop=loop)
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = asyncio.Future(loop=loop)
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished)
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
proto_factory, address, port
@ -115,24 +115,25 @@ async def scpd_get(control_url: str, address: str, port: int,
assert isinstance(protocol, SCPDHTTPClientProtocol)
error = None
wait_task: typing.Awaitable[typing.Tuple[bytes, int, bytes]] = asyncio.wait_for(protocol.finished, 1.0, loop=loop)
wait_task: typing.Awaitable[typing.Tuple[bytes, bytes, int, bytes]] = asyncio.wait_for(protocol.finished, 1.0, loop=loop)
body = b''
raw_response = b''
try:
body, response_code, response_msg = await wait_task
raw_response, body, response_code, response_msg = await wait_task
except asyncio.TimeoutError:
error = UPnPError("get request timed out")
body = b''
except UPnPError as err:
error = err
body = protocol.response_buff
raw_response = protocol.response_buff
finally:
transport.close()
if not error:
try:
return deserialize_scpd_get_response(body), body, None
return deserialize_scpd_get_response(body), raw_response, None
except Exception as err:
error = UPnPError(err)
return {}, body, error
return {}, raw_response, error
async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes,
@ -140,7 +141,7 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para
**kwargs: typing.Dict[str, typing.Any]
) -> typing.Tuple[typing.Dict, bytes, typing.Optional[Exception]]:
loop = loop or asyncio.get_event_loop()
finished: 'asyncio.Future[typing.Tuple[bytes, int, bytes]]' = asyncio.Future(loop=loop)
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = asyncio.Future(loop=loop)
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\
SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode())
@ -152,18 +153,17 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para
assert isinstance(protocol, SCPDHTTPClientProtocol)
try:
wait_task: typing.Awaitable[typing.Tuple[bytes, int, bytes]] = asyncio.wait_for(finished, 1.0, loop=loop)
body, response_code, response_msg = await wait_task
wait_task: typing.Awaitable[typing.Tuple[bytes, bytes, int, bytes]] = asyncio.wait_for(finished, 1.0, loop=loop)
raw_response, body, response_code, response_msg = await wait_task
except asyncio.TimeoutError:
return {}, b'', UPnPError("Timeout")
except UPnPError as err:
return {}, protocol.response_buff, err
finally:
# raw_response = protocol.response_buff
transport.close()
try:
return (
deserialize_soap_post_response(body, method, service_id.decode()), body, None
deserialize_soap_post_response(body, method, service_id.decode()), raw_response, None
)
except Exception as err:
return {}, body, UPnPError(err)
return {}, raw_response, UPnPError(err)

View file

@ -126,7 +126,7 @@ class TestSCPDGet(AsyncioTestCase):
with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent):
result, raw, err = await scpd_get(self.path, self.lan_address, self.port, self.loop)
self.assertDictEqual({}, result)
self.assertEqual(self.bad_xml, raw)
self.assertEqual(self.bad_response, raw)
self.assertTrue(isinstance(err, UPnPError))
self.assertTrue(str(err).startswith('no element found'))
@ -187,7 +187,7 @@ class TestSCPDPost(AsyncioTestCase):
self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop
)
self.assertEqual(None, err)
self.assertEqual(self.envelope, raw)
self.assertEqual(self.post_response, raw)
self.assertDictEqual({'NewExternalIPAddress': '11.22.33.44'}, result)
async def test_scpd_post_timeout(self):
@ -211,7 +211,7 @@ class TestSCPDPost(AsyncioTestCase):
)
self.assertTrue(isinstance(err, UPnPError))
self.assertTrue(str(err).startswith('no element found'))
self.assertEqual(self.bad_envelope, raw)
self.assertEqual(self.bad_envelope_response, raw)
self.assertDictEqual({}, result)
async def test_scpd_post_overrun_response(self):

File diff suppressed because one or more lines are too long

View file

@ -44,8 +44,8 @@ class TestGetExternalIPAddress(UPnPCommandTestCase):
async def test_get_external_ip(self):
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address)
await gateway.discover_commands(self.loop)
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
external_ip = await upnp.get_external_ip()
self.assertEqual("11.222.3.44", external_ip)
@ -60,8 +60,8 @@ class TestMalformedGetExternalIPAddressResponse(UPnPCommandTestCase):
b"<derp>11.222.3.44</derp>\n</u:GetExternalIPAddressResponse>\n\t</s:Body>\n</s:Envelope>\n"})
self.addCleanup(self.replies.pop, self.get_ip_request)
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address)
await gateway.discover_commands(self.loop)
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
with self.assertRaises(UPnPError):
await upnp.get_external_ip()
@ -71,8 +71,8 @@ class TestMalformedGetExternalIPAddressResponse(UPnPCommandTestCase):
b"<newexternalipaddress>11.222.3.44</newexternalipaddress>\n</u:GetExternalIPAddressResponse>\n\t</s:Body>\n</s:Envelope>\n"})
self.addCleanup(self.replies.pop, self.get_ip_request)
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address)
await gateway.discover_commands(self.loop)
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
external_ip = await upnp.get_external_ip()
self.assertEqual("11.222.3.44", external_ip)
@ -82,12 +82,13 @@ class TestMalformedGetExternalIPAddressResponse(UPnPCommandTestCase):
b"11.222.3.44\n</u:GetExternalIPAddressResponse>\n\t</s:Body>\n</s:Envelope>\n"})
self.addCleanup(self.replies.pop, self.get_ip_request)
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address)
await gateway.discover_commands(self.loop)
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
external_ip = await upnp.get_external_ip()
self.assertEqual("11.222.3.44", external_ip)
class TestGetGenericPortMappingEntry(UPnPCommandTestCase):
def setUp(self) -> None:
query = b'POST /soap.cgi?service=WANIPConn1 HTTP/1.1\r\nHost: 11.2.3.4\r\nUser-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\nContent-Length: 341\r\nContent-Type: text/xml\r\nSOAPAction: "urn:schemas-upnp-org:service:WANIPConnection:1#GetGenericPortMappingEntry"\r\nConnection: Close\r\nCache-Control: no-cache\r\nPragma: no-cache\r\n\r\n<?xml version="1.0"?>\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" s:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body><u:GetGenericPortMappingEntry xmlns:u="urn:schemas-upnp-org:service:WANIPConnection:1"><NewPortMappingIndex>0</NewPortMappingIndex></u:GetGenericPortMappingEntry></s:Body></s:Envelope>\r\n'
@ -99,8 +100,8 @@ class TestGetGenericPortMappingEntry(UPnPCommandTestCase):
async def test_get_port_mapping_by_index(self):
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address)
await gateway.discover_commands(self.loop)
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
result = await upnp.get_port_mapping_by_index(0)
self.assertEqual(GetGenericPortMappingEntryResponse(None, 9308, 'UDP', 9308, "11.2.3.44", True,
@ -121,8 +122,8 @@ class TestGetNextPortMapping(UPnPCommandTestCase):
async def test_get_next_mapping(self):
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address)
await gateway.discover_commands(self.loop)
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
ext_port = await upnp.get_next_mapping(4567, "UDP", "aioupnp test mapping")
self.assertEqual(4567, ext_port)
@ -141,8 +142,8 @@ class TestGetSpecificPortMapping(UPnPCommandTestCase):
async def test_get_specific_port_mapping(self):
with mock_tcp_and_udp(self.loop, tcp_replies=self.replies):
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address)
await gateway.discover_commands(self.loop)
gateway = Gateway(self.reply, self.m_search_args, self.client_address, self.gateway_address, loop=self.loop)
await gateway.discover_commands()
upnp = UPnP(self.client_address, self.gateway_address, gateway)
try:
await upnp.get_specific_port_mapping(1000, 'UDP')