fix scpd post dropping some responses
-improve debug_gateway and generate_test_data
This commit is contained in:
parent
efad9b0d51
commit
55e1621637
5 changed files with 161 additions and 78 deletions
|
@ -125,13 +125,23 @@ class Gateway:
|
|||
return None
|
||||
|
||||
@property
|
||||
def _soap_requests(self) -> Dict:
|
||||
return {
|
||||
name: getattr(self.commands, name)._requests for name in self._registered_commands.keys()
|
||||
}
|
||||
def soap_requests(self) -> List:
|
||||
soap_call_infos = []
|
||||
for name in self._registered_commands.keys():
|
||||
if not hasattr(getattr(self.commands, name), "_requests"):
|
||||
continue
|
||||
soap_call_infos.extend([
|
||||
(name, request_args, raw_response, decoded_response, soap_error, ts)
|
||||
for (
|
||||
request_args, raw_response, decoded_response, soap_error, ts
|
||||
) in getattr(self.commands, name)._requests
|
||||
])
|
||||
soap_call_infos.sort(key=lambda x: x[5])
|
||||
return soap_call_infos
|
||||
|
||||
def debug_gateway(self) -> Dict:
|
||||
return {
|
||||
'manufacturer_string': self.manufacturer_string,
|
||||
'gateway_address': self.base_ip,
|
||||
'gateway_descriptor': self.gateway_descriptor(),
|
||||
'gateway_xml': self._xml_response,
|
||||
|
@ -142,7 +152,7 @@ class Gateway:
|
|||
'soap_port': self.port,
|
||||
'registered_soap_commands': self._registered_commands,
|
||||
'unsupported_soap_commands': self._unsupported_actions,
|
||||
'soap_requests': self._soap_requests
|
||||
'soap_requests': self.soap_requests
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
@ -163,7 +173,7 @@ class Gateway:
|
|||
await gateway.discover_commands(soap_socket)
|
||||
log.debug('found gateway device %s', datagram.location)
|
||||
return gateway
|
||||
except asyncio.TimeoutError:
|
||||
except (asyncio.TimeoutError, UPnPError):
|
||||
log.debug("get %s timed out, looking for other devices", datagram.location)
|
||||
ignored.add(datagram.location)
|
||||
continue
|
||||
|
@ -189,8 +199,10 @@ class Gateway:
|
|||
return result
|
||||
|
||||
async def discover_commands(self, soap_socket: socket.socket = None):
|
||||
response, xml_bytes = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port)
|
||||
response, xml_bytes, get_err = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port)
|
||||
self._xml_response = xml_bytes
|
||||
if get_err is not None:
|
||||
raise get_err
|
||||
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
|
||||
self.url_base = get_dict_val_case_insensitive(response, "urlbase")
|
||||
if not self.url_base:
|
||||
|
@ -207,9 +219,11 @@ class Gateway:
|
|||
async def register_commands(self, service: Service, soap_socket: socket.socket = None):
|
||||
if not service.SCPDURL:
|
||||
raise UPnPError("no scpd url")
|
||||
service_dict, xml_bytes = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port)
|
||||
service_dict, xml_bytes, get_err = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port)
|
||||
self._service_descriptors[service.SCPDURL] = xml_bytes
|
||||
|
||||
if get_err is not None:
|
||||
raise get_err
|
||||
if not service_dict:
|
||||
return
|
||||
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
import logging
|
||||
import socket
|
||||
import typing
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from xml.etree import ElementTree
|
||||
import asyncio
|
||||
from asyncio.protocols import Protocol
|
||||
from aioupnp.fault import UPnPError
|
||||
from aioupnp.util import get_dict_val_case_insensitive
|
||||
from aioupnp.serialization.scpd import deserialize_scpd_get_response
|
||||
from aioupnp.serialization.scpd import serialize_scpd_get
|
||||
from aioupnp.serialization.soap import serialize_soap_post, deserialize_soap_post_response
|
||||
|
@ -13,87 +16,126 @@ from aioupnp.serialization.soap import serialize_soap_post, deserialize_soap_pos
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SCPDHTTPClientProtocol(Protocol):
|
||||
POST = 'POST'
|
||||
GET = 'GET'
|
||||
HTTP_CODE_REGEX = re.compile(b"^HTTP[\/]{0,1}1\.[1|0] (\d\d\d)(.*)$")
|
||||
|
||||
def __init__(self, method: str, message: bytes, finished: asyncio.Future, soap_method: str=None,
|
||||
soap_service_id: str=None, close_after_send: bool = False) -> None:
|
||||
self.method = method
|
||||
assert soap_service_id is not None and soap_method is not None if method == 'POST' else True, \
|
||||
'soap args not provided'
|
||||
|
||||
def parse_headers(response: bytes) -> typing.Tuple[OrderedDict, int, bytes]:
|
||||
lines = response.split(b'\r\n')
|
||||
headers = OrderedDict([
|
||||
(l.split(b':')[0], b':'.join(l.split(b':')[1:]).lstrip(b' ').rstrip(b' '))
|
||||
for l in response.split(b'\r\n')
|
||||
])
|
||||
if len(lines) != len(headers):
|
||||
raise ValueError("duplicate headers")
|
||||
http_response = tuple(headers.keys())[0]
|
||||
response_code, message = HTTP_CODE_REGEX.findall(http_response)[0]
|
||||
del headers[http_response]
|
||||
return headers, int(response_code), message
|
||||
|
||||
|
||||
class SCPDHTTPClientProtocol(Protocol):
|
||||
"""
|
||||
This class will make HTTP GET and POST requests
|
||||
|
||||
It differs from spec HTTP in that the version string can be invalid, all we care about is the xml body
|
||||
and devices respond with an invalid HTTP version line
|
||||
"""
|
||||
|
||||
def __init__(self, message: bytes, finished: asyncio.Future, soap_method: str=None,
|
||||
soap_service_id: str=None) -> None:
|
||||
self.message = message
|
||||
self.response_buff = b""
|
||||
self.finished = finished
|
||||
self.soap_method = soap_method
|
||||
self.soap_service_id = soap_service_id
|
||||
self.close_after_send = close_after_send
|
||||
|
||||
self._response_code: int = 0
|
||||
self._response_msg: bytes = b""
|
||||
self._content_length: int = 0
|
||||
self._got_headers = False
|
||||
self._headers: dict = {}
|
||||
self._body = b""
|
||||
|
||||
def connection_made(self, transport):
|
||||
transport.write(self.message)
|
||||
if self.close_after_send:
|
||||
self.finished.set_result(None)
|
||||
|
||||
def data_received(self, data):
|
||||
self.response_buff += data
|
||||
if self.method == self.GET:
|
||||
try:
|
||||
packet = deserialize_scpd_get_response(self.response_buff)
|
||||
if packet:
|
||||
self.finished.set_result(packet)
|
||||
return
|
||||
except ElementTree.ParseError:
|
||||
pass
|
||||
except UPnPError as err:
|
||||
self.finished.set_exception(err)
|
||||
elif self.method == self.POST:
|
||||
try:
|
||||
packet = deserialize_soap_post_response(self.response_buff, self.soap_method, self.soap_service_id)
|
||||
if packet:
|
||||
self.finished.set_result(packet)
|
||||
return
|
||||
except ElementTree.ParseError:
|
||||
pass
|
||||
except UPnPError as err:
|
||||
self.finished.set_exception(err)
|
||||
for i, line in enumerate(self.response_buff.split(b'\r\n')):
|
||||
if not line: # we hit the blank line between the headers and the body
|
||||
if i == (len(self.response_buff.split(b'\r\n')) - 1):
|
||||
continue # the body is still yet to be written
|
||||
if not self._got_headers:
|
||||
self._headers, self._response_code, self._response_msg = parse_headers(
|
||||
b'\r\n'.join(self.response_buff.split(b'\r\n')[:i])
|
||||
)
|
||||
content_length = get_dict_val_case_insensitive(self._headers, b'Content-Length')
|
||||
if content_length is None:
|
||||
return
|
||||
self._content_length = int(content_length or 0)
|
||||
self._got_headers = True
|
||||
body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:])
|
||||
if self._content_length == len(body):
|
||||
self.finished.set_result((body, self._response_code, self._response_msg))
|
||||
elif self._content_length > len(body):
|
||||
pass
|
||||
else:
|
||||
self.finished.set_exception(
|
||||
UPnPError(
|
||||
"too many bytes written to response (%i vs %i expected)" % (
|
||||
len(body), self._content_length
|
||||
)
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
async def scpd_get(control_url: str, address: str, port: int) -> typing.Tuple[typing.Dict, bytes]:
|
||||
async def scpd_get(control_url: str, address: str, port: int) -> typing.Tuple[typing.Dict, bytes,
|
||||
typing.Optional[Exception]]:
|
||||
loop = asyncio.get_running_loop()
|
||||
finished: asyncio.Future = asyncio.Future()
|
||||
packet = serialize_scpd_get(control_url, address)
|
||||
transport, protocol = await loop.create_connection(
|
||||
lambda : SCPDHTTPClientProtocol('GET', packet, finished), address, port
|
||||
lambda : SCPDHTTPClientProtocol(packet, finished), address, port
|
||||
)
|
||||
assert isinstance(protocol, SCPDHTTPClientProtocol)
|
||||
parsed: typing.Dict = {}
|
||||
error = None
|
||||
try:
|
||||
parsed = await asyncio.wait_for(finished, 1.0)
|
||||
except UPnPError:
|
||||
return parsed, protocol.response_buff
|
||||
body, response_code, response_msg = await asyncio.wait_for(finished, 1.0)
|
||||
except asyncio.TimeoutError:
|
||||
error = UPnPError("get request timed out")
|
||||
body = b''
|
||||
finally:
|
||||
transport.close()
|
||||
return parsed, protocol.response_buff
|
||||
if not error:
|
||||
try:
|
||||
return deserialize_scpd_get_response(body), body, None
|
||||
except ElementTree.ParseError as err:
|
||||
error = UPnPError(err)
|
||||
return {}, body, error
|
||||
|
||||
|
||||
async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes,
|
||||
close_after_send: bool, soap_socket: socket.socket = None,
|
||||
**kwargs) -> typing.Tuple[typing.Dict, bytes]:
|
||||
soap_socket: socket.socket = None, **kwargs) -> typing.Tuple[typing.Dict, bytes,
|
||||
typing.Optional[Exception]]:
|
||||
loop = asyncio.get_running_loop()
|
||||
finished: asyncio.Future = asyncio.Future()
|
||||
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
|
||||
transport, protocol = await loop.create_connection(
|
||||
lambda : SCPDHTTPClientProtocol(
|
||||
'POST', packet, finished, soap_method=method, soap_service_id=service_id.decode(),
|
||||
close_after_send=close_after_send
|
||||
packet, finished, soap_method=method, soap_service_id=service_id.decode(),
|
||||
), address, port, sock=soap_socket
|
||||
)
|
||||
assert isinstance(protocol, SCPDHTTPClientProtocol)
|
||||
parsed: typing.Dict = {}
|
||||
try:
|
||||
parsed = await asyncio.wait_for(finished, 1.0)
|
||||
except UPnPError:
|
||||
return parsed, protocol.response_buff
|
||||
body, response_code, response_msg = await asyncio.wait_for(finished, 1.0)
|
||||
except asyncio.TimeoutError:
|
||||
return {}, b'', UPnPError("Timeout")
|
||||
finally:
|
||||
transport.close()
|
||||
return parsed, protocol.response_buff
|
||||
try:
|
||||
return (
|
||||
deserialize_soap_post_response(body, method, service_id.decode()), body, None
|
||||
)
|
||||
except (ElementTree.ParseError, UPnPError) as err:
|
||||
return {}, body, UPnPError(err)
|
||||
|
|
|
@ -2,6 +2,7 @@ import logging
|
|||
import socket
|
||||
import asyncio
|
||||
import typing
|
||||
import time
|
||||
from aioupnp.protocols.scpd import scpd_post
|
||||
from aioupnp.fault import UPnPError
|
||||
|
||||
|
@ -39,20 +40,21 @@ class SOAPCommand:
|
|||
async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]:
|
||||
if set(kwargs.keys()) != set(self.param_types.keys()):
|
||||
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys()))
|
||||
close_after_send = not self.return_types or self.return_types == [None]
|
||||
soap_kwargs = {n: safe_type(self.param_types[n])(kwargs[n]) for n in self.param_types.keys()}
|
||||
try:
|
||||
response, xml_bytes = await scpd_post(
|
||||
self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, self.service_id,
|
||||
close_after_send, self.soap_socket, **soap_kwargs
|
||||
)
|
||||
except asyncio.TimeoutError as err:
|
||||
raise UPnPError(err)
|
||||
|
||||
self._requests.append((soap_kwargs, xml_bytes))
|
||||
response, xml_bytes, err = await scpd_post(
|
||||
self.control_url, self.gateway_address, self.service_port, self.method, self.param_order,
|
||||
self.service_id, self.soap_socket, **soap_kwargs
|
||||
)
|
||||
if err is not None:
|
||||
self._requests.append((soap_kwargs, xml_bytes, None, err, time.time()))
|
||||
raise err
|
||||
if not response:
|
||||
return None
|
||||
result = tuple([safe_type(self.return_types[n])(response.get(n)) for n in self.return_order])
|
||||
if len(result) == 1:
|
||||
return result[0]
|
||||
result = None
|
||||
else:
|
||||
recast_result = tuple([safe_type(self.return_types[n])(response.get(n)) for n in self.return_order])
|
||||
if len(recast_result) == 1:
|
||||
result = recast_result[0]
|
||||
else:
|
||||
result = recast_result
|
||||
self._requests.append((soap_kwargs, xml_bytes, result, None, time.time()))
|
||||
return result
|
||||
|
|
|
@ -53,11 +53,12 @@ def deserialize_soap_post_response(response: bytes, method: str, service_id: str
|
|||
response_body = flatten_keys(envelope[BODY], "{%s}" % service_id)
|
||||
body = handle_fault(response_body) # raises UPnPError if there is a fault
|
||||
response_key = None
|
||||
|
||||
if not body:
|
||||
return {}
|
||||
for key in body:
|
||||
if method in key:
|
||||
response_key = key
|
||||
break
|
||||
if not response_key:
|
||||
raise UPnPError("unknown response fields for %s")
|
||||
raise UPnPError("unknown response fields for %s: %s" % (method, body))
|
||||
return body[response_key]
|
|
@ -133,11 +133,11 @@ class UPnP:
|
|||
result = await self.gateway.commands.GetSpecificPortMappingEntry(
|
||||
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
|
||||
)
|
||||
if isinstance(self.gateway.commands.GetSpecificPortMappingEntry, SOAPCommand):
|
||||
if result and isinstance(self.gateway.commands.GetSpecificPortMappingEntry, SOAPCommand):
|
||||
return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)}
|
||||
return {}
|
||||
except UPnPError:
|
||||
return {}
|
||||
pass
|
||||
return {}
|
||||
|
||||
@cli
|
||||
async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
|
||||
|
@ -196,18 +196,42 @@ class UPnP:
|
|||
except (UPnPError, NotImplementedError):
|
||||
print("failed to get the external ip")
|
||||
try:
|
||||
redirects = await self.get_redirects()
|
||||
print("got redirects:\n%s" % redirects)
|
||||
await self.get_redirects()
|
||||
print("got redirects")
|
||||
except (UPnPError, NotImplementedError):
|
||||
print("failed to get redirects")
|
||||
try:
|
||||
await self.get_specific_port_mapping(4567, "UDP")
|
||||
print("got specific mapping")
|
||||
except (UPnPError, NotImplementedError):
|
||||
print("failed to get specific mapping")
|
||||
try:
|
||||
ext_port = await self.get_next_mapping(4567, "UDP", "aioupnp test mapping")
|
||||
print("set up external mapping to port %i" % ext_port)
|
||||
try:
|
||||
await self.get_specific_port_mapping(4567, "UDP")
|
||||
print("got specific mapping")
|
||||
except (UPnPError, NotImplementedError):
|
||||
print("failed to get specific mapping")
|
||||
try:
|
||||
await self.get_redirects()
|
||||
print("got redirects")
|
||||
except (UPnPError, NotImplementedError):
|
||||
print("failed to get redirects")
|
||||
await self.delete_port_mapping(ext_port, "UDP")
|
||||
print("deleted mapping")
|
||||
except (UPnPError, NotImplementedError):
|
||||
print("failed to add and remove a mapping")
|
||||
|
||||
try:
|
||||
await self.get_redirects()
|
||||
print("got redirects")
|
||||
except (UPnPError, NotImplementedError):
|
||||
print("failed to get redirects")
|
||||
try:
|
||||
await self.get_specific_port_mapping(4567, "UDP")
|
||||
print("got specific mapping")
|
||||
except (UPnPError, NotImplementedError):
|
||||
print("failed to get specific mapping")
|
||||
if self.gateway.devices:
|
||||
device = list(self.gateway.devices.values())[0]
|
||||
assert device.manufacturer and device.modelName
|
||||
|
|
Loading…
Reference in a new issue