Fix uncaught ConnectionError in scpd_post and scpd_get #22
6 changed files with 34 additions and 18 deletions
|
@ -9,7 +9,7 @@ jobs:
|
||||||
name: "mypy"
|
name: "mypy"
|
||||||
before_install:
|
before_install:
|
||||||
- pip install mypy lxml
|
- pip install mypy lxml
|
||||||
- pip install -e .[test]
|
- pip install -e .
|
||||||
|
|
||||||
script:
|
script:
|
||||||
- mypy aioupnp --txt-report . --scripts-are-modules; cat index.txt; rm index.txt
|
- mypy aioupnp --txt-report . --scripts-are-modules; cat index.txt; rm index.txt
|
||||||
|
@ -20,7 +20,7 @@ jobs:
|
||||||
python: "3.8"
|
python: "3.8"
|
||||||
before_install:
|
before_install:
|
||||||
- pip install pylint coverage
|
- pip install pylint coverage
|
||||||
- pip install -e .[test]
|
- pip install -e .
|
||||||
|
|
||||||
script:
|
script:
|
||||||
- HOME=/tmp coverage run -m unittest discover -v tests
|
- HOME=/tmp coverage run -m unittest discover -v tests
|
||||||
|
|
|
@ -107,9 +107,12 @@ async def scpd_get(control_url: str, address: str, port: int,
|
||||||
packet = serialize_scpd_get(control_url, address)
|
packet = serialize_scpd_get(control_url, address)
|
||||||
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future()
|
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future()
|
||||||
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished)
|
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished)
|
||||||
|
try:
|
||||||
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
|
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
|
||||||
proto_factory, address, port
|
proto_factory, address, port
|
||||||
)
|
)
|
||||||
|
except ConnectionError as err:
|
||||||
|
return {}, b'', UPnPError(f"{err.__class__.__name__}({str(err)})")
|
||||||
protocol = connect_tup[1]
|
protocol = connect_tup[1]
|
||||||
transport = connect_tup[0]
|
transport = connect_tup[0]
|
||||||
assert isinstance(protocol, SCPDHTTPClientProtocol)
|
assert isinstance(protocol, SCPDHTTPClientProtocol)
|
||||||
|
@ -145,9 +148,12 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para
|
||||||
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
|
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
|
||||||
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\
|
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\
|
||||||
SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode())
|
SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode())
|
||||||
|
try:
|
||||||
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
|
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
|
||||||
proto_factory, address, port
|
proto_factory, address, port
|
||||||
)
|
)
|
||||||
|
except ConnectionError as err:
|
||||||
|
return {}, b'', UPnPError(f"{err.__class__.__name__}({str(err)})")
|
||||||
protocol = connect_tup[1]
|
protocol = connect_tup[1]
|
||||||
transport = connect_tup[0]
|
transport = connect_tup[0]
|
||||||
assert isinstance(protocol, SCPDHTTPClientProtocol)
|
assert isinstance(protocol, SCPDHTTPClientProtocol)
|
||||||
|
|
|
@ -7,7 +7,7 @@ from aioupnp.util import flatten_keys
|
||||||
|
|
||||||
|
|
||||||
CONTENT_PATTERN = re.compile(
|
CONTENT_PATTERN = re.compile(
|
||||||
"(\<\?xml version=\"1\.0\"\?\>(\s*.)*|\>)".encode()
|
"(\<\?xml version=\"1\.0\"\?\>(\s*.)*|\>)"
|
||||||
)
|
)
|
||||||
|
|
||||||
XML_ROOT_SANITY_PATTERN = re.compile(
|
XML_ROOT_SANITY_PATTERN = re.compile(
|
||||||
|
@ -39,8 +39,8 @@ def serialize_scpd_get(path: str, address: str) -> bytes:
|
||||||
|
|
||||||
def deserialize_scpd_get_response(content: bytes) -> Dict[str, Any]:
|
def deserialize_scpd_get_response(content: bytes) -> Dict[str, Any]:
|
||||||
if XML_VERSION.encode() in content:
|
if XML_VERSION.encode() in content:
|
||||||
parsed: List[Tuple[bytes, bytes]] = CONTENT_PATTERN.findall(content)
|
parsed: List[Tuple[bytes, bytes]] = CONTENT_PATTERN.findall(content.decode())
|
||||||
xml_dict = xml_to_dict((b'' if not parsed else parsed[0][0]).decode())
|
xml_dict = xml_to_dict('' if not parsed else parsed[0][0])
|
||||||
return parse_device_dict(xml_dict)
|
return parse_device_dict(xml_dict)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
7
setup.py
7
setup.py
|
@ -38,10 +38,5 @@ setup(
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'netifaces',
|
'netifaces',
|
||||||
'defusedxml'
|
'defusedxml'
|
||||||
],
|
]
|
||||||
extras_require={
|
|
||||||
'test': (
|
|
||||||
'mock',
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,7 +18,7 @@ except ImportError:
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
|
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
|
||||||
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False,
|
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False,
|
||||||
raise_oserror_on_bind=False):
|
raise_oserror_on_bind=False, raise_connectionerror=False):
|
||||||
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
|
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
|
||||||
udp_replies = udp_replies or {}
|
udp_replies = udp_replies or {}
|
||||||
|
|
||||||
|
@ -26,6 +26,9 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r
|
||||||
tcp_replies = tcp_replies or {}
|
tcp_replies = tcp_replies or {}
|
||||||
|
|
||||||
async def create_connection(protocol_factory, host=None, port=None):
|
async def create_connection(protocol_factory, host=None, port=None):
|
||||||
|
if raise_connectionerror:
|
||||||
|
raise ConnectionRefusedError()
|
||||||
|
|
||||||
def get_write(p: asyncio.Protocol):
|
def get_write(p: asyncio.Protocol):
|
||||||
def _write(data):
|
def _write(data):
|
||||||
sent_tcp_packets.append(data)
|
sent_tcp_packets.append(data)
|
||||||
|
|
|
@ -202,6 +202,18 @@ class TestSCPDPost(AsyncioTestCase):
|
||||||
self.assertEqual(b'', raw)
|
self.assertEqual(b'', raw)
|
||||||
self.assertDictEqual({}, result)
|
self.assertDictEqual({}, result)
|
||||||
|
|
||||||
|
async def test_scpd_post_connection_error(self):
|
||||||
|
sent = []
|
||||||
|
replies = {}
|
||||||
|
with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent, raise_connectionerror=True):
|
||||||
|
result, raw, err = await scpd_post(
|
||||||
|
self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop
|
||||||
|
)
|
||||||
|
self.assertIsInstance(err, UPnPError)
|
||||||
|
self.assertEqual('ConnectionRefusedError()', str(err))
|
||||||
|
self.assertEqual(b'', raw)
|
||||||
|
self.assertDictEqual({}, result)
|
||||||
|
|
||||||
async def test_scpd_post_bad_xml_response(self):
|
async def test_scpd_post_bad_xml_response(self):
|
||||||
sent = []
|
sent = []
|
||||||
replies = {self.post_bytes: self.bad_envelope_response}
|
replies = {self.post_bytes: self.bad_envelope_response}
|
||||||
|
|
Loading…
Reference in a new issue