fix scpd post dropping some responses

-improve debug_gateway and generate_test_data
This commit is contained in:
Jack Robison 2018-10-18 10:46:12 -04:00
parent efad9b0d51
commit 55e1621637
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
5 changed files with 161 additions and 78 deletions

View file

@ -125,13 +125,23 @@ class Gateway:
return None return None
@property @property
def _soap_requests(self) -> Dict: def soap_requests(self) -> List:
return { soap_call_infos = []
name: getattr(self.commands, name)._requests for name in self._registered_commands.keys() for name in self._registered_commands.keys():
} if not hasattr(getattr(self.commands, name), "_requests"):
continue
soap_call_infos.extend([
(name, request_args, raw_response, decoded_response, soap_error, ts)
for (
request_args, raw_response, decoded_response, soap_error, ts
) in getattr(self.commands, name)._requests
])
soap_call_infos.sort(key=lambda x: x[5])
return soap_call_infos
def debug_gateway(self) -> Dict: def debug_gateway(self) -> Dict:
return { return {
'manufacturer_string': self.manufacturer_string,
'gateway_address': self.base_ip, 'gateway_address': self.base_ip,
'gateway_descriptor': self.gateway_descriptor(), 'gateway_descriptor': self.gateway_descriptor(),
'gateway_xml': self._xml_response, 'gateway_xml': self._xml_response,
@ -142,7 +152,7 @@ class Gateway:
'soap_port': self.port, 'soap_port': self.port,
'registered_soap_commands': self._registered_commands, 'registered_soap_commands': self._registered_commands,
'unsupported_soap_commands': self._unsupported_actions, 'unsupported_soap_commands': self._unsupported_actions,
'soap_requests': self._soap_requests 'soap_requests': self.soap_requests
} }
@classmethod @classmethod
@ -163,7 +173,7 @@ class Gateway:
await gateway.discover_commands(soap_socket) await gateway.discover_commands(soap_socket)
log.debug('found gateway device %s', datagram.location) log.debug('found gateway device %s', datagram.location)
return gateway return gateway
except asyncio.TimeoutError: except (asyncio.TimeoutError, UPnPError):
log.debug("get %s timed out, looking for other devices", datagram.location) log.debug("get %s timed out, looking for other devices", datagram.location)
ignored.add(datagram.location) ignored.add(datagram.location)
continue continue
@ -189,8 +199,10 @@ class Gateway:
return result return result
async def discover_commands(self, soap_socket: socket.socket = None): async def discover_commands(self, soap_socket: socket.socket = None):
response, xml_bytes = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port) response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port)
self._xml_response = xml_bytes self._xml_response = xml_bytes
if get_err is not None:
raise get_err
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION) self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
self.url_base = get_dict_val_case_insensitive(response, "urlbase") self.url_base = get_dict_val_case_insensitive(response, "urlbase")
if not self.url_base: if not self.url_base:
@ -207,9 +219,11 @@ class Gateway:
async def register_commands(self, service: Service, soap_socket: socket.socket = None): async def register_commands(self, service: Service, soap_socket: socket.socket = None):
if not service.SCPDURL: if not service.SCPDURL:
raise UPnPError("no scpd url") raise UPnPError("no scpd url")
service_dict, xml_bytes = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port) service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port)
self._service_descriptors[service.SCPDURL] = xml_bytes self._service_descriptors[service.SCPDURL] = xml_bytes
if get_err is not None:
raise get_err
if not service_dict: if not service_dict:
return return

View file

@ -1,10 +1,13 @@
import logging import logging
import socket import socket
import typing import typing
import re
from collections import OrderedDict
from xml.etree import ElementTree from xml.etree import ElementTree
import asyncio import asyncio
from asyncio.protocols import Protocol from asyncio.protocols import Protocol
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.util import get_dict_val_case_insensitive
from aioupnp.serialization.scpd import deserialize_scpd_get_response from aioupnp.serialization.scpd import deserialize_scpd_get_response
from aioupnp.serialization.scpd import serialize_scpd_get from aioupnp.serialization.scpd import serialize_scpd_get
from aioupnp.serialization.soap import serialize_soap_post, deserialize_soap_post_response from aioupnp.serialization.soap import serialize_soap_post, deserialize_soap_post_response
@ -13,87 +16,126 @@ from aioupnp.serialization.soap import serialize_soap_post, deserialize_soap_pos
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class SCPDHTTPClientProtocol(Protocol): HTTP_CODE_REGEX = re.compile(b"^HTTP[\/]{0,1}1\.[1|0] (\d\d\d)(.*)$")
POST = 'POST'
GET = 'GET'
def __init__(self, method: str, message: bytes, finished: asyncio.Future, soap_method: str=None,
soap_service_id: str=None, close_after_send: bool = False) -> None: def parse_headers(response: bytes) -> typing.Tuple[OrderedDict, int, bytes]:
self.method = method lines = response.split(b'\r\n')
assert soap_service_id is not None and soap_method is not None if method == 'POST' else True, \ headers = OrderedDict([
'soap args not provided' (l.split(b':')[0], b':'.join(l.split(b':')[1:]).lstrip(b' ').rstrip(b' '))
for l in response.split(b'\r\n')
])
if len(lines) != len(headers):
raise ValueError("duplicate headers")
http_response = tuple(headers.keys())[0]
response_code, message = HTTP_CODE_REGEX.findall(http_response)[0]
del headers[http_response]
return headers, int(response_code), message
class SCPDHTTPClientProtocol(Protocol):
"""
This class will make HTTP GET and POST requests
It differs from spec HTTP in that the version string can be invalid, all we care about is the xml body
and devices respond with an invalid HTTP version line
"""
def __init__(self, message: bytes, finished: asyncio.Future, soap_method: str=None,
soap_service_id: str=None) -> None:
self.message = message self.message = message
self.response_buff = b"" self.response_buff = b""
self.finished = finished self.finished = finished
self.soap_method = soap_method self.soap_method = soap_method
self.soap_service_id = soap_service_id self.soap_service_id = soap_service_id
self.close_after_send = close_after_send
self._response_code: int = 0
self._response_msg: bytes = b""
self._content_length: int = 0
self._got_headers = False
self._headers: dict = {}
self._body = b""
def connection_made(self, transport): def connection_made(self, transport):
transport.write(self.message) transport.write(self.message)
if self.close_after_send:
self.finished.set_result(None)
def data_received(self, data): def data_received(self, data):
self.response_buff += data self.response_buff += data
if self.method == self.GET: for i, line in enumerate(self.response_buff.split(b'\r\n')):
try: if not line: # we hit the blank line between the headers and the body
packet = deserialize_scpd_get_response(self.response_buff) if i == (len(self.response_buff.split(b'\r\n')) - 1):
if packet: continue # the body is still yet to be written
self.finished.set_result(packet) if not self._got_headers:
return self._headers, self._response_code, self._response_msg = parse_headers(
except ElementTree.ParseError: b'\r\n'.join(self.response_buff.split(b'\r\n')[:i])
pass )
except UPnPError as err: content_length = get_dict_val_case_insensitive(self._headers, b'Content-Length')
self.finished.set_exception(err) if content_length is None:
elif self.method == self.POST: return
try: self._content_length = int(content_length or 0)
packet = deserialize_soap_post_response(self.response_buff, self.soap_method, self.soap_service_id) self._got_headers = True
if packet: body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:])
self.finished.set_result(packet) if self._content_length == len(body):
return self.finished.set_result((body, self._response_code, self._response_msg))
except ElementTree.ParseError: elif self._content_length > len(body):
pass pass
except UPnPError as err: else:
self.finished.set_exception(err) self.finished.set_exception(
UPnPError(
"too many bytes written to response (%i vs %i expected)" % (
len(body), self._content_length
)
)
)
return
async def scpd_get(control_url: str, address: str, port: int) -> typing.Tuple[typing.Dict, bytes]: async def scpd_get(control_url: str, address: str, port: int) -> typing.Tuple[typing.Dict, bytes,
typing.Optional[Exception]]:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
finished: asyncio.Future = asyncio.Future() finished: asyncio.Future = asyncio.Future()
packet = serialize_scpd_get(control_url, address) packet = serialize_scpd_get(control_url, address)
transport, protocol = await loop.create_connection( transport, protocol = await loop.create_connection(
lambda : SCPDHTTPClientProtocol('GET', packet, finished), address, port lambda : SCPDHTTPClientProtocol(packet, finished), address, port
) )
assert isinstance(protocol, SCPDHTTPClientProtocol) assert isinstance(protocol, SCPDHTTPClientProtocol)
parsed: typing.Dict = {} error = None
try: try:
parsed = await asyncio.wait_for(finished, 1.0) body, response_code, response_msg = await asyncio.wait_for(finished, 1.0)
except UPnPError: except asyncio.TimeoutError:
return parsed, protocol.response_buff error = UPnPError("get request timed out")
body = b''
finally: finally:
transport.close() transport.close()
return parsed, protocol.response_buff if not error:
try:
return deserialize_scpd_get_response(body), body, None
except ElementTree.ParseError as err:
error = UPnPError(err)
return {}, body, error
async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes, async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes,
close_after_send: bool, soap_socket: socket.socket = None, soap_socket: socket.socket = None, **kwargs) -> typing.Tuple[typing.Dict, bytes,
**kwargs) -> typing.Tuple[typing.Dict, bytes]: typing.Optional[Exception]]:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
finished: asyncio.Future = asyncio.Future() finished: asyncio.Future = asyncio.Future()
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs) packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
transport, protocol = await loop.create_connection( transport, protocol = await loop.create_connection(
lambda : SCPDHTTPClientProtocol( lambda : SCPDHTTPClientProtocol(
'POST', packet, finished, soap_method=method, soap_service_id=service_id.decode(), packet, finished, soap_method=method, soap_service_id=service_id.decode(),
close_after_send=close_after_send
), address, port, sock=soap_socket ), address, port, sock=soap_socket
) )
assert isinstance(protocol, SCPDHTTPClientProtocol) assert isinstance(protocol, SCPDHTTPClientProtocol)
parsed: typing.Dict = {}
try: try:
parsed = await asyncio.wait_for(finished, 1.0) body, response_code, response_msg = await asyncio.wait_for(finished, 1.0)
except UPnPError: except asyncio.TimeoutError:
return parsed, protocol.response_buff return {}, b'', UPnPError("Timeout")
finally: finally:
transport.close() transport.close()
return parsed, protocol.response_buff try:
return (
deserialize_soap_post_response(body, method, service_id.decode()), body, None
)
except (ElementTree.ParseError, UPnPError) as err:
return {}, body, UPnPError(err)

View file

@ -2,6 +2,7 @@ import logging
import socket import socket
import asyncio import asyncio
import typing import typing
import time
from aioupnp.protocols.scpd import scpd_post from aioupnp.protocols.scpd import scpd_post
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
@ -39,20 +40,21 @@ class SOAPCommand:
async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]: async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]:
if set(kwargs.keys()) != set(self.param_types.keys()): if set(kwargs.keys()) != set(self.param_types.keys()):
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys())) raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys()))
close_after_send = not self.return_types or self.return_types == [None]
soap_kwargs = {n: safe_type(self.param_types[n])(kwargs[n]) for n in self.param_types.keys()} soap_kwargs = {n: safe_type(self.param_types[n])(kwargs[n]) for n in self.param_types.keys()}
try: response, xml_bytes, err = await scpd_post(
response, xml_bytes = await scpd_post( self.control_url, self.gateway_address, self.service_port, self.method, self.param_order,
self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, self.service_id, self.service_id, self.soap_socket, **soap_kwargs
close_after_send, self.soap_socket, **soap_kwargs )
) if err is not None:
except asyncio.TimeoutError as err: self._requests.append((soap_kwargs, xml_bytes, None, err, time.time()))
raise UPnPError(err) raise err
self._requests.append((soap_kwargs, xml_bytes))
if not response: if not response:
return None result = None
result = tuple([safe_type(self.return_types[n])(response.get(n)) for n in self.return_order]) else:
if len(result) == 1: recast_result = tuple([safe_type(self.return_types[n])(response.get(n)) for n in self.return_order])
return result[0] if len(recast_result) == 1:
result = recast_result[0]
else:
result = recast_result
self._requests.append((soap_kwargs, xml_bytes, result, None, time.time()))
return result return result

View file

@ -53,11 +53,12 @@ def deserialize_soap_post_response(response: bytes, method: str, service_id: str
response_body = flatten_keys(envelope[BODY], "{%s}" % service_id) response_body = flatten_keys(envelope[BODY], "{%s}" % service_id)
body = handle_fault(response_body) # raises UPnPError if there is a fault body = handle_fault(response_body) # raises UPnPError if there is a fault
response_key = None response_key = None
if not body:
return {}
for key in body: for key in body:
if method in key: if method in key:
response_key = key response_key = key
break break
if not response_key: if not response_key:
raise UPnPError("unknown response fields for %s") raise UPnPError("unknown response fields for %s: %s" % (method, body))
return body[response_key] return body[response_key]

View file

@ -133,11 +133,11 @@ class UPnP:
result = await self.gateway.commands.GetSpecificPortMappingEntry( result = await self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
) )
if isinstance(self.gateway.commands.GetSpecificPortMappingEntry, SOAPCommand): if result and isinstance(self.gateway.commands.GetSpecificPortMappingEntry, SOAPCommand):
return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)} return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)}
return {}
except UPnPError: except UPnPError:
return {} pass
return {}
@cli @cli
async def delete_port_mapping(self, external_port: int, protocol: str) -> None: async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
@ -196,18 +196,42 @@ class UPnP:
except (UPnPError, NotImplementedError): except (UPnPError, NotImplementedError):
print("failed to get the external ip") print("failed to get the external ip")
try: try:
redirects = await self.get_redirects() await self.get_redirects()
print("got redirects:\n%s" % redirects) print("got redirects")
except (UPnPError, NotImplementedError): except (UPnPError, NotImplementedError):
print("failed to get redirects") print("failed to get redirects")
try:
await self.get_specific_port_mapping(4567, "UDP")
print("got specific mapping")
except (UPnPError, NotImplementedError):
print("failed to get specific mapping")
try: try:
ext_port = await self.get_next_mapping(4567, "UDP", "aioupnp test mapping") ext_port = await self.get_next_mapping(4567, "UDP", "aioupnp test mapping")
print("set up external mapping to port %i" % ext_port) print("set up external mapping to port %i" % ext_port)
try:
await self.get_specific_port_mapping(4567, "UDP")
print("got specific mapping")
except (UPnPError, NotImplementedError):
print("failed to get specific mapping")
try:
await self.get_redirects()
print("got redirects")
except (UPnPError, NotImplementedError):
print("failed to get redirects")
await self.delete_port_mapping(ext_port, "UDP") await self.delete_port_mapping(ext_port, "UDP")
print("deleted mapping") print("deleted mapping")
except (UPnPError, NotImplementedError): except (UPnPError, NotImplementedError):
print("failed to add and remove a mapping") print("failed to add and remove a mapping")
try:
await self.get_redirects()
print("got redirects")
except (UPnPError, NotImplementedError):
print("failed to get redirects")
try:
await self.get_specific_port_mapping(4567, "UDP")
print("got specific mapping")
except (UPnPError, NotImplementedError):
print("failed to get specific mapping")
if self.gateway.devices: if self.gateway.devices:
device = list(self.gateway.devices.values())[0] device = list(self.gateway.devices.values())[0]
assert device.manufacturer and device.modelName assert device.manufacturer and device.modelName