From 99fabb2a65041ecbb8f1a1d29d70bc2bc231f0ca Mon Sep 17 00:00:00 2001
From: Jack Robison <jackrobison@lbry.io>
Date: Wed, 15 Jan 2020 16:07:21 -0500
Subject: [PATCH 1/3] handle ConnectionError in scpd_get and scpd_post

---
 aioupnp/protocols/scpd.py    | 18 ++++++++++++------
 tests/__init__.py            |  5 ++++-
 tests/protocols/test_scpd.py | 12 ++++++++++++
 3 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/aioupnp/protocols/scpd.py b/aioupnp/protocols/scpd.py
index cc57bb9..a1bd4dc 100644
--- a/aioupnp/protocols/scpd.py
+++ b/aioupnp/protocols/scpd.py
@@ -107,9 +107,12 @@ async def scpd_get(control_url: str, address: str, port: int,
     packet = serialize_scpd_get(control_url, address)
     finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future()
     proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished)
-    connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
-        proto_factory, address, port
-    )
+    try:
+        connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
+            proto_factory, address, port
+        )
+    except ConnectionError as err:
+        return {}, b'', UPnPError(f"{err.__class__.__name__}({str(err)})")
     protocol = connect_tup[1]
     transport = connect_tup[0]
     assert isinstance(protocol, SCPDHTTPClientProtocol)
@@ -145,9 +148,12 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para
     packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
     proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\
         SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode())
-    connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
-        proto_factory, address, port
-    )
+    try:
+        connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
+            proto_factory, address, port
+        )
+    except ConnectionError as err:
+        return {}, b'', UPnPError(f"{err.__class__.__name__}({str(err)})")
     protocol = connect_tup[1]
     transport = connect_tup[0]
     assert isinstance(protocol, SCPDHTTPClientProtocol)
diff --git a/tests/__init__.py b/tests/__init__.py
index 27ad4f5..45c9d29 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -18,7 +18,7 @@ except ImportError:
 @contextlib.contextmanager
 def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
                      tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False,
-                     raise_oserror_on_bind=False):
+                     raise_oserror_on_bind=False, raise_connectionerror=False):
     sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
     udp_replies = udp_replies or {}
 
@@ -26,6 +26,9 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r
     tcp_replies = tcp_replies or {}
 
     async def create_connection(protocol_factory, host=None, port=None):
+        if raise_connectionerror:
+            raise ConnectionRefusedError()
+
         def get_write(p: asyncio.Protocol):
             def _write(data):
                 sent_tcp_packets.append(data)
diff --git a/tests/protocols/test_scpd.py b/tests/protocols/test_scpd.py
index 876008e..357b880 100644
--- a/tests/protocols/test_scpd.py
+++ b/tests/protocols/test_scpd.py
@@ -202,6 +202,18 @@ class TestSCPDPost(AsyncioTestCase):
             self.assertEqual(b'', raw)
             self.assertDictEqual({}, result)
 
+    async def test_scpd_post_connection_error(self):
+        sent = []
+        replies = {}
+        with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent, raise_connectionerror=True):
+            result, raw, err = await scpd_post(
+                self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop
+            )
+            self.assertIsInstance(err, UPnPError)
+            self.assertEqual('ConnectionRefusedError()', str(err))
+            self.assertEqual(b'', raw)
+            self.assertDictEqual({}, result)
+
     async def test_scpd_post_bad_xml_response(self):
         sent = []
         replies = {self.post_bytes: self.bad_envelope_response}
-- 
2.49.1


From 1c6fd317968d09be073a3f4680601dca897ba36b Mon Sep 17 00:00:00 2001
From: Jack Robison <jackrobison@lbry.io>
Date: Wed, 15 Jan 2020 16:08:38 -0500
Subject: [PATCH 2/3] remove mock testing requirement

---
 .travis.yml | 4 ++--
 setup.py    | 7 +------
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index 13bb672..d69d447 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -9,7 +9,7 @@ jobs:
     name: "mypy"
     before_install:
     - pip install mypy lxml
-    - pip install -e .[test]
+    - pip install -e .
 
     script:
     - mypy aioupnp --txt-report . --scripts-are-modules; cat index.txt; rm index.txt
@@ -20,7 +20,7 @@ jobs:
       python: "3.8"
       before_install:
       - pip install pylint coverage
-      - pip install -e .[test]
+      - pip install -e .
 
       script:
       - HOME=/tmp coverage run -m unittest discover -v tests
diff --git a/setup.py b/setup.py
index 145fcb4..a0cad75 100644
--- a/setup.py
+++ b/setup.py
@@ -38,10 +38,5 @@ setup(
     install_requires=[
         'netifaces',
         'defusedxml'
-    ],
-    extras_require={
-        'test': (
-            'mock',
-        )
-    }
+    ]
 )
-- 
2.49.1


From 6d52d76af6b2bcc50b852ed2c6dae4560fe4bea0 Mon Sep 17 00:00:00 2001
From: Jack Robison <jackrobison@lbry.io>
Date: Wed, 15 Jan 2020 16:28:53 -0500
Subject: [PATCH 3/3] bytes/str

---
 aioupnp/serialization/scpd.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/aioupnp/serialization/scpd.py b/aioupnp/serialization/scpd.py
index 7694455..4026302 100644
--- a/aioupnp/serialization/scpd.py
+++ b/aioupnp/serialization/scpd.py
@@ -7,7 +7,7 @@ from aioupnp.util import flatten_keys
 
 
 CONTENT_PATTERN = re.compile(
-    "(\<\?xml version=\"1\.0\"\?\>(\s*.)*|\>)".encode()
+    "(\<\?xml version=\"1\.0\"\?\>(\s*.)*|\>)"
 )
 
 XML_ROOT_SANITY_PATTERN = re.compile(
@@ -39,8 +39,8 @@ def serialize_scpd_get(path: str, address: str) -> bytes:
 
 def deserialize_scpd_get_response(content: bytes) -> Dict[str, Any]:
     if XML_VERSION.encode() in content:
-        parsed: List[Tuple[bytes, bytes]] = CONTENT_PATTERN.findall(content)
-        xml_dict = xml_to_dict((b'' if not parsed else parsed[0][0]).decode())
+        parsed: List[Tuple[bytes, bytes]] = CONTENT_PATTERN.findall(content.decode())
+        xml_dict = xml_to_dict('' if not parsed else parsed[0][0])
         return parse_device_dict(xml_dict)
     return {}
 
-- 
2.49.1