From 99fabb2a65041ecbb8f1a1d29d70bc2bc231f0ca Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Wed, 15 Jan 2020 16:07:21 -0500 Subject: [PATCH 1/3] handle ConnectionError in scpd_get and scpd_post --- aioupnp/protocols/scpd.py | 18 ++++++++++++------ tests/__init__.py | 5 ++++- tests/protocols/test_scpd.py | 12 ++++++++++++ 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/aioupnp/protocols/scpd.py b/aioupnp/protocols/scpd.py index cc57bb9..a1bd4dc 100644 --- a/aioupnp/protocols/scpd.py +++ b/aioupnp/protocols/scpd.py @@ -107,9 +107,12 @@ async def scpd_get(control_url: str, address: str, port: int, packet = serialize_scpd_get(control_url, address) finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future() proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished) - connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection( - proto_factory, address, port - ) + try: + connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection( + proto_factory, address, port + ) + except ConnectionError as err: + return {}, b'', UPnPError(f"{err.__class__.__name__}({str(err)})") protocol = connect_tup[1] transport = connect_tup[0] assert isinstance(protocol, SCPDHTTPClientProtocol) @@ -145,9 +148,12 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs) proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\ SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode()) - connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection( - proto_factory, address, port - ) + try: + connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection( + proto_factory, address, port + ) + except ConnectionError as err: + return {}, b'', UPnPError(f"{err.__class__.__name__}({str(err)})") protocol = connect_tup[1] transport = connect_tup[0] assert isinstance(protocol, SCPDHTTPClientProtocol) diff --git a/tests/__init__.py b/tests/__init__.py index 27ad4f5..45c9d29 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -18,7 +18,7 @@ except ImportError: @contextlib.contextmanager def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None, tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False, - raise_oserror_on_bind=False): + raise_oserror_on_bind=False, raise_connectionerror=False): sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else [] udp_replies = udp_replies or {} @@ -26,6 +26,9 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r tcp_replies = tcp_replies or {} async def create_connection(protocol_factory, host=None, port=None): + if raise_connectionerror: + raise ConnectionRefusedError() + def get_write(p: asyncio.Protocol): def _write(data): sent_tcp_packets.append(data) diff --git a/tests/protocols/test_scpd.py b/tests/protocols/test_scpd.py index 876008e..357b880 100644 --- a/tests/protocols/test_scpd.py +++ b/tests/protocols/test_scpd.py @@ -202,6 +202,18 @@ class TestSCPDPost(AsyncioTestCase): self.assertEqual(b'', raw) self.assertDictEqual({}, result) + async def test_scpd_post_connection_error(self): + sent = [] + replies = {} + with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent, raise_connectionerror=True): + result, raw, err = await scpd_post( + self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop + ) + self.assertIsInstance(err, UPnPError) + self.assertEqual('ConnectionRefusedError()', str(err)) + self.assertEqual(b'', raw) + self.assertDictEqual({}, result) + async def test_scpd_post_bad_xml_response(self): sent = [] replies = {self.post_bytes: self.bad_envelope_response} -- 2.45.2 From 1c6fd317968d09be073a3f4680601dca897ba36b Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Wed, 15 Jan 2020 16:08:38 -0500 Subject: [PATCH 2/3] remove mock testing requirement --- .travis.yml | 4 ++-- setup.py | 7 +------ 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/.travis.yml b/.travis.yml index 13bb672..d69d447 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,7 +9,7 @@ jobs: name: "mypy" before_install: - pip install mypy lxml - - pip install -e .[test] + - pip install -e . script: - mypy aioupnp --txt-report . --scripts-are-modules; cat index.txt; rm index.txt @@ -20,7 +20,7 @@ jobs: python: "3.8" before_install: - pip install pylint coverage - - pip install -e .[test] + - pip install -e . script: - HOME=/tmp coverage run -m unittest discover -v tests diff --git a/setup.py b/setup.py index 145fcb4..a0cad75 100644 --- a/setup.py +++ b/setup.py @@ -38,10 +38,5 @@ setup( install_requires=[ 'netifaces', 'defusedxml' - ], - extras_require={ - 'test': ( - 'mock', - ) - } + ] ) -- 2.45.2 From 6d52d76af6b2bcc50b852ed2c6dae4560fe4bea0 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Wed, 15 Jan 2020 16:28:53 -0500 Subject: [PATCH 3/3] bytes/str --- aioupnp/serialization/scpd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aioupnp/serialization/scpd.py b/aioupnp/serialization/scpd.py index 7694455..4026302 100644 --- a/aioupnp/serialization/scpd.py +++ b/aioupnp/serialization/scpd.py @@ -7,7 +7,7 @@ from aioupnp.util import flatten_keys CONTENT_PATTERN = re.compile( - "(\<\?xml version=\"1\.0\"\?\>(\s*.)*|\>)".encode() + "(\<\?xml version=\"1\.0\"\?\>(\s*.)*|\>)" ) XML_ROOT_SANITY_PATTERN = re.compile( @@ -39,8 +39,8 @@ def serialize_scpd_get(path: str, address: str) -> bytes: def deserialize_scpd_get_response(content: bytes) -> Dict[str, Any]: if XML_VERSION.encode() in content: - parsed: List[Tuple[bytes, bytes]] = CONTENT_PATTERN.findall(content) - xml_dict = xml_to_dict((b'' if not parsed else parsed[0][0]).decode()) + parsed: List[Tuple[bytes, bytes]] = CONTENT_PATTERN.findall(content.decode()) + xml_dict = xml_to_dict('' if not parsed else parsed[0][0]) return parse_device_dict(xml_dict) return {} -- 2.45.2