fix scpd post dropping some responses

-improve debug_gateway and generate_test_data
This commit is contained in:
Jack Robison 2018-10-18 10:46:12 -04:00
parent efad9b0d51
commit 55e1621637
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
5 changed files with 161 additions and 78 deletions

View file

@ -125,13 +125,23 @@ class Gateway:
return None
@property
def _soap_requests(self) -> Dict:
return {
name: getattr(self.commands, name)._requests for name in self._registered_commands.keys()
}
def soap_requests(self) -> List:
soap_call_infos = []
for name in self._registered_commands.keys():
if not hasattr(getattr(self.commands, name), "_requests"):
continue
soap_call_infos.extend([
(name, request_args, raw_response, decoded_response, soap_error, ts)
for (
request_args, raw_response, decoded_response, soap_error, ts
) in getattr(self.commands, name)._requests
])
soap_call_infos.sort(key=lambda x: x[5])
return soap_call_infos
def debug_gateway(self) -> Dict:
return {
'manufacturer_string': self.manufacturer_string,
'gateway_address': self.base_ip,
'gateway_descriptor': self.gateway_descriptor(),
'gateway_xml': self._xml_response,
@ -142,7 +152,7 @@ class Gateway:
'soap_port': self.port,
'registered_soap_commands': self._registered_commands,
'unsupported_soap_commands': self._unsupported_actions,
'soap_requests': self._soap_requests
'soap_requests': self.soap_requests
}
@classmethod
@ -163,7 +173,7 @@ class Gateway:
await gateway.discover_commands(soap_socket)
log.debug('found gateway device %s', datagram.location)
return gateway
except asyncio.TimeoutError:
except (asyncio.TimeoutError, UPnPError):
log.debug("get %s timed out, looking for other devices", datagram.location)
ignored.add(datagram.location)
continue
@ -189,8 +199,10 @@ class Gateway:
return result
async def discover_commands(self, soap_socket: socket.socket = None):
response, xml_bytes = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port)
response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port)
self._xml_response = xml_bytes
if get_err is not None:
raise get_err
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
self.url_base = get_dict_val_case_insensitive(response, "urlbase")
if not self.url_base:
@ -207,9 +219,11 @@ class Gateway:
async def register_commands(self, service: Service, soap_socket: socket.socket = None):
if not service.SCPDURL:
raise UPnPError("no scpd url")
service_dict, xml_bytes = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port)
service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port)
self._service_descriptors[service.SCPDURL] = xml_bytes
if get_err is not None:
raise get_err
if not service_dict:
return

View file

@ -1,10 +1,13 @@
import logging
import socket
import typing
import re
from collections import OrderedDict
from xml.etree import ElementTree
import asyncio
from asyncio.protocols import Protocol
from aioupnp.fault import UPnPError
from aioupnp.util import get_dict_val_case_insensitive
from aioupnp.serialization.scpd import deserialize_scpd_get_response
from aioupnp.serialization.scpd import serialize_scpd_get
from aioupnp.serialization.soap import serialize_soap_post, deserialize_soap_post_response
@ -13,87 +16,126 @@ from aioupnp.serialization.soap import serialize_soap_post, deserialize_soap_pos
log = logging.getLogger(__name__)
class SCPDHTTPClientProtocol(Protocol):
POST = 'POST'
GET = 'GET'
HTTP_CODE_REGEX = re.compile(b"^HTTP[\/]{0,1}1\.[1|0] (\d\d\d)(.*)$")
def __init__(self, method: str, message: bytes, finished: asyncio.Future, soap_method: str=None,
soap_service_id: str=None, close_after_send: bool = False) -> None:
self.method = method
assert soap_service_id is not None and soap_method is not None if method == 'POST' else True, \
'soap args not provided'
def parse_headers(response: bytes) -> typing.Tuple[OrderedDict, int, bytes]:
lines = response.split(b'\r\n')
headers = OrderedDict([
(l.split(b':')[0], b':'.join(l.split(b':')[1:]).lstrip(b' ').rstrip(b' '))
for l in response.split(b'\r\n')
])
if len(lines) != len(headers):
raise ValueError("duplicate headers")
http_response = tuple(headers.keys())[0]
response_code, message = HTTP_CODE_REGEX.findall(http_response)[0]
del headers[http_response]
return headers, int(response_code), message
class SCPDHTTPClientProtocol(Protocol):
"""
This class will make HTTP GET and POST requests
It differs from spec HTTP in that the version string can be invalid, all we care about is the xml body
and devices respond with an invalid HTTP version line
"""
def __init__(self, message: bytes, finished: asyncio.Future, soap_method: str=None,
soap_service_id: str=None) -> None:
self.message = message
self.response_buff = b""
self.finished = finished
self.soap_method = soap_method
self.soap_service_id = soap_service_id
self.close_after_send = close_after_send
self._response_code: int = 0
self._response_msg: bytes = b""
self._content_length: int = 0
self._got_headers = False
self._headers: dict = {}
self._body = b""
def connection_made(self, transport):
transport.write(self.message)
if self.close_after_send:
self.finished.set_result(None)
def data_received(self, data):
self.response_buff += data
if self.method == self.GET:
try:
packet = deserialize_scpd_get_response(self.response_buff)
if packet:
self.finished.set_result(packet)
return
except ElementTree.ParseError:
pass
except UPnPError as err:
self.finished.set_exception(err)
elif self.method == self.POST:
try:
packet = deserialize_soap_post_response(self.response_buff, self.soap_method, self.soap_service_id)
if packet:
self.finished.set_result(packet)
return
except ElementTree.ParseError:
pass
except UPnPError as err:
self.finished.set_exception(err)
for i, line in enumerate(self.response_buff.split(b'\r\n')):
if not line: # we hit the blank line between the headers and the body
if i == (len(self.response_buff.split(b'\r\n')) - 1):
continue # the body is still yet to be written
if not self._got_headers:
self._headers, self._response_code, self._response_msg = parse_headers(
b'\r\n'.join(self.response_buff.split(b'\r\n')[:i])
)
content_length = get_dict_val_case_insensitive(self._headers, b'Content-Length')
if content_length is None:
return
self._content_length = int(content_length or 0)
self._got_headers = True
body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:])
if self._content_length == len(body):
self.finished.set_result((body, self._response_code, self._response_msg))
elif self._content_length > len(body):
pass
else:
self.finished.set_exception(
UPnPError(
"too many bytes written to response (%i vs %i expected)" % (
len(body), self._content_length
)
)
)
return
async def scpd_get(control_url: str, address: str, port: int) -> typing.Tuple[typing.Dict, bytes]:
async def scpd_get(control_url: str, address: str, port: int) -> typing.Tuple[typing.Dict, bytes,
typing.Optional[Exception]]:
loop = asyncio.get_running_loop()
finished: asyncio.Future = asyncio.Future()
packet = serialize_scpd_get(control_url, address)
transport, protocol = await loop.create_connection(
lambda : SCPDHTTPClientProtocol('GET', packet, finished), address, port
lambda : SCPDHTTPClientProtocol(packet, finished), address, port
)
assert isinstance(protocol, SCPDHTTPClientProtocol)
parsed: typing.Dict = {}
error = None
try:
parsed = await asyncio.wait_for(finished, 1.0)
except UPnPError:
return parsed, protocol.response_buff
body, response_code, response_msg = await asyncio.wait_for(finished, 1.0)
except asyncio.TimeoutError:
error = UPnPError("get request timed out")
body = b''
finally:
transport.close()
return parsed, protocol.response_buff
if not error:
try:
return deserialize_scpd_get_response(body), body, None
except ElementTree.ParseError as err:
error = UPnPError(err)
return {}, body, error
async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes,
close_after_send: bool, soap_socket: socket.socket = None,
**kwargs) -> typing.Tuple[typing.Dict, bytes]:
soap_socket: socket.socket = None, **kwargs) -> typing.Tuple[typing.Dict, bytes,
typing.Optional[Exception]]:
loop = asyncio.get_running_loop()
finished: asyncio.Future = asyncio.Future()
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
transport, protocol = await loop.create_connection(
lambda : SCPDHTTPClientProtocol(
'POST', packet, finished, soap_method=method, soap_service_id=service_id.decode(),
close_after_send=close_after_send
packet, finished, soap_method=method, soap_service_id=service_id.decode(),
), address, port, sock=soap_socket
)
assert isinstance(protocol, SCPDHTTPClientProtocol)
parsed: typing.Dict = {}
try:
parsed = await asyncio.wait_for(finished, 1.0)
except UPnPError:
return parsed, protocol.response_buff
body, response_code, response_msg = await asyncio.wait_for(finished, 1.0)
except asyncio.TimeoutError:
return {}, b'', UPnPError("Timeout")
finally:
transport.close()
return parsed, protocol.response_buff
try:
return (
deserialize_soap_post_response(body, method, service_id.decode()), body, None
)
except (ElementTree.ParseError, UPnPError) as err:
return {}, body, UPnPError(err)

View file

@ -2,6 +2,7 @@ import logging
import socket
import asyncio
import typing
import time
from aioupnp.protocols.scpd import scpd_post
from aioupnp.fault import UPnPError
@ -39,20 +40,21 @@ class SOAPCommand:
async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]:
if set(kwargs.keys()) != set(self.param_types.keys()):
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys()))
close_after_send = not self.return_types or self.return_types == [None]
soap_kwargs = {n: safe_type(self.param_types[n])(kwargs[n]) for n in self.param_types.keys()}
try:
response, xml_bytes = await scpd_post(
self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, self.service_id,
close_after_send, self.soap_socket, **soap_kwargs
)
except asyncio.TimeoutError as err:
raise UPnPError(err)
self._requests.append((soap_kwargs, xml_bytes))
response, xml_bytes, err = await scpd_post(
self.control_url, self.gateway_address, self.service_port, self.method, self.param_order,
self.service_id, self.soap_socket, **soap_kwargs
)
if err is not None:
self._requests.append((soap_kwargs, xml_bytes, None, err, time.time()))
raise err
if not response:
return None
result = tuple([safe_type(self.return_types[n])(response.get(n)) for n in self.return_order])
if len(result) == 1:
return result[0]
result = None
else:
recast_result = tuple([safe_type(self.return_types[n])(response.get(n)) for n in self.return_order])
if len(recast_result) == 1:
result = recast_result[0]
else:
result = recast_result
self._requests.append((soap_kwargs, xml_bytes, result, None, time.time()))
return result

View file

@ -53,11 +53,12 @@ def deserialize_soap_post_response(response: bytes, method: str, service_id: str
response_body = flatten_keys(envelope[BODY], "{%s}" % service_id)
body = handle_fault(response_body) # raises UPnPError if there is a fault
response_key = None
if not body:
return {}
for key in body:
if method in key:
response_key = key
break
if not response_key:
raise UPnPError("unknown response fields for %s")
raise UPnPError("unknown response fields for %s: %s" % (method, body))
return body[response_key]

View file

@ -133,11 +133,11 @@ class UPnP:
result = await self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
)
if isinstance(self.gateway.commands.GetSpecificPortMappingEntry, SOAPCommand):
if result and isinstance(self.gateway.commands.GetSpecificPortMappingEntry, SOAPCommand):
return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)}
return {}
except UPnPError:
return {}
pass
return {}
@cli
async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
@ -196,18 +196,42 @@ class UPnP:
except (UPnPError, NotImplementedError):
print("failed to get the external ip")
try:
redirects = await self.get_redirects()
print("got redirects:\n%s" % redirects)
await self.get_redirects()
print("got redirects")
except (UPnPError, NotImplementedError):
print("failed to get redirects")
try:
await self.get_specific_port_mapping(4567, "UDP")
print("got specific mapping")
except (UPnPError, NotImplementedError):
print("failed to get specific mapping")
try:
ext_port = await self.get_next_mapping(4567, "UDP", "aioupnp test mapping")
print("set up external mapping to port %i" % ext_port)
try:
await self.get_specific_port_mapping(4567, "UDP")
print("got specific mapping")
except (UPnPError, NotImplementedError):
print("failed to get specific mapping")
try:
await self.get_redirects()
print("got redirects")
except (UPnPError, NotImplementedError):
print("failed to get redirects")
await self.delete_port_mapping(ext_port, "UDP")
print("deleted mapping")
except (UPnPError, NotImplementedError):
print("failed to add and remove a mapping")
try:
await self.get_redirects()
print("got redirects")
except (UPnPError, NotImplementedError):
print("failed to get redirects")
try:
await self.get_specific_port_mapping(4567, "UDP")
print("got specific mapping")
except (UPnPError, NotImplementedError):
print("failed to get specific mapping")
if self.gateway.devices:
device = list(self.gateway.devices.values())[0]
assert device.manufacturer and device.modelName