Fix uncaught ConnectionError in scpd_post and scpd_get #22

Merged
jackrobison merged 3 commits from catch-connection-error into master 2020-01-15 22:37:03 +01:00
3 changed files with 28 additions and 7 deletions
Showing only changes of commit 99fabb2a65 - Show all commits

View file

@ -107,9 +107,12 @@ async def scpd_get(control_url: str, address: str, port: int,
packet = serialize_scpd_get(control_url, address) packet = serialize_scpd_get(control_url, address)
finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future() finished: 'asyncio.Future[typing.Tuple[bytes, bytes, int, bytes]]' = loop.create_future()
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished) proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda: SCPDHTTPClientProtocol(packet, finished)
try:
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection( connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
proto_factory, address, port proto_factory, address, port
) )
except ConnectionError as err:
return {}, b'', UPnPError(f"{err.__class__.__name__}({str(err)})")
protocol = connect_tup[1] protocol = connect_tup[1]
transport = connect_tup[0] transport = connect_tup[0]
assert isinstance(protocol, SCPDHTTPClientProtocol) assert isinstance(protocol, SCPDHTTPClientProtocol)
@ -145,9 +148,12 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs) packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\ proto_factory: typing.Callable[[], SCPDHTTPClientProtocol] = lambda:\
SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode()) SCPDHTTPClientProtocol(packet, finished, soap_method=method, soap_service_id=service_id.decode())
try:
connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection( connect_tup: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_connection(
proto_factory, address, port proto_factory, address, port
) )
except ConnectionError as err:
return {}, b'', UPnPError(f"{err.__class__.__name__}({str(err)})")
protocol = connect_tup[1] protocol = connect_tup[1]
transport = connect_tup[0] transport = connect_tup[0]
assert isinstance(protocol, SCPDHTTPClientProtocol) assert isinstance(protocol, SCPDHTTPClientProtocol)

View file

@ -18,7 +18,7 @@ except ImportError:
@contextlib.contextmanager @contextlib.contextmanager
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None, def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False, tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False,
raise_oserror_on_bind=False): raise_oserror_on_bind=False, raise_connectionerror=False):
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else [] sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
udp_replies = udp_replies or {} udp_replies = udp_replies or {}
@ -26,6 +26,9 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r
tcp_replies = tcp_replies or {} tcp_replies = tcp_replies or {}
async def create_connection(protocol_factory, host=None, port=None): async def create_connection(protocol_factory, host=None, port=None):
if raise_connectionerror:
raise ConnectionRefusedError()
def get_write(p: asyncio.Protocol): def get_write(p: asyncio.Protocol):
def _write(data): def _write(data):
sent_tcp_packets.append(data) sent_tcp_packets.append(data)

View file

@ -202,6 +202,18 @@ class TestSCPDPost(AsyncioTestCase):
self.assertEqual(b'', raw) self.assertEqual(b'', raw)
self.assertDictEqual({}, result) self.assertDictEqual({}, result)
async def test_scpd_post_connection_error(self):
sent = []
replies = {}
with mock_tcp_and_udp(self.loop, tcp_replies=replies, sent_tcp_packets=sent, raise_connectionerror=True):
result, raw, err = await scpd_post(
self.path, self.gateway_address, self.port, self.method, self.param_names, self.st, self.loop
)
self.assertIsInstance(err, UPnPError)
self.assertEqual('ConnectionRefusedError()', str(err))
self.assertEqual(b'', raw)
self.assertDictEqual({}, result)
async def test_scpd_post_bad_xml_response(self): async def test_scpd_post_bad_xml_response(self):
sent = [] sent = []
replies = {self.post_bytes: self.bad_envelope_response} replies = {self.post_bytes: self.bad_envelope_response}