aioupnp/aioupnp/protocols/scpd.py
2019-04-24 01:10:53 -05:00

162 lines
5.9 KiB
Python

import logging
import re
from collections import OrderedDict
from xml.etree import ElementTree
import asyncio
from asyncio.protocols import Protocol
from aioupnp.fault import UPnPError
from aioupnp.util import get_dict_val_case_insensitive
from aioupnp.serialization.scpd import deserialize_scpd_get_response
from aioupnp.serialization.scpd import serialize_scpd_get
from aioupnp.serialization.soap import serialize_soap_post, deserialize_soap_post_response
log = logging.getLogger(__name__)

# Matches a status line such as b'HTTP/1.1 200 OK'. The '/' is optional because
# some devices send e.g. b'HTTP1.1 200 OK'. Group 1 is the three-digit status
# code, group 2 is everything after it (including the leading space).
HTTP_CODE_REGEX = re.compile(rb'^HTTP/?1\.[01] (\d\d\d)(.*)$')


def parse_headers(response):
    """Parse the head (status line + header lines) of an HTTP response.

    :param bytes response: the status line followed by the header lines,
        separated by CRLF, without the blank line that terminates the head
    :return: ``(headers, response_code, message)`` where ``headers`` is an
        ``OrderedDict`` of header name -> value (both ``bytes``, value stripped
        of surrounding spaces), ``response_code`` is an ``int`` and ``message``
        is the ``bytes`` following the status code (with its leading space)
    :raises ValueError: if a header name appears more than once
    :raises IndexError: if the first line is not a recognised HTTP status line
    """
    status_line, *header_lines = response.split(b'\r\n')
    headers = OrderedDict()
    for line in header_lines:
        # Split on the first ':' only, so values like b'http://host:80/' stay intact.
        name, _, value = line.partition(b':')
        headers[name] = value.strip(b' ')
    # An OrderedDict collapses repeated names, so a size mismatch means duplicates.
    if len(header_lines) != len(headers):
        raise ValueError("duplicate headers")
    # The status line is matched whole, so a ':' in the reason phrase is preserved.
    response_code, message = HTTP_CODE_REGEX.findall(status_line)[0]
    return headers, int(response_code), message
class SCPDHTTPClientProtocol(Protocol):
    """Minimal HTTP client protocol used for SCPD GET and SOAP POST requests.

    It is deliberately more lenient than spec HTTP: some devices respond with an
    invalid HTTP version line, and all we care about is the XML body. The
    response is considered complete once ``Content-Length`` bytes of body have
    been buffered.
    """

    def __init__(self, message, finished, soap_method=None, soap_service_id=None):
        """Create a protocol that sends ``message`` and resolves ``finished``.

        :param bytes message: the fully serialized HTTP request, written as soon
            as the connection is made
        :param asyncio.Future finished: set to ``(body, response_code, response_msg)``
            when exactly ``Content-Length`` body bytes have arrived, or failed with
            ``UPnPError`` when more bytes than that arrive
        :param soap_method: SOAP action name for POST requests; ``None`` for GET
        :param soap_service_id: SOAP service id for POST requests; ``None`` for GET
        """
        self.message = message
        self.response_buff = b''  # every byte received so far (head + body)
        self.finished = finished
        self.soap_method = soap_method
        self.soap_service_id = soap_service_id
        self._response_code = 0
        self._response_msg = b''
        self._content_length = 0
        self._got_headers = False
        self._headers = {}
        self._body = b''

    def connection_made(self, transport):
        """Send the request as soon as the connection is established.

        :param asyncio.Transport transport: stream transport created by
            ``loop.create_connection``
        """
        # Stream (TCP) transports expose write(); sendto() only exists on
        # datagram transports and would raise AttributeError here.
        transport.write(self.message)

    def data_received(self, data):
        """Buffer incoming bytes and resolve ``finished`` once the body is complete.

        The headers are parsed once the blank line ending them has arrived. If
        the response carries no ``Content-Length`` header, ``finished`` is never
        resolved here and the awaiting side has to time out.

        :param bytes data: the newly received chunk
        """
        if self.finished.done():
            # The outcome was already delivered (or the waiter was cancelled);
            # setting it again would raise InvalidStateError inside the callback.
            return
        self.response_buff += data
        # The first CRLFCRLF separates the response head from the body.
        head, separator, body = self.response_buff.partition(b'\r\n\r\n')
        if not separator:
            return  # the end of the headers has not arrived yet
        if not self._got_headers:
            self._headers, self._response_code, self._response_msg = parse_headers(head)
            content_length = get_dict_val_case_insensitive(self._headers, b'Content-Length')
            if content_length is None:
                return
            self._content_length = int(content_length or 0)
            self._got_headers = True
        if len(body) == self._content_length:
            self.finished.set_result((body, self._response_code, self._response_msg))
        elif len(body) > self._content_length:
            self.finished.set_exception(
                UPnPError(
                    "too many bytes written to response (%i vs %i expected)" % (
                        len(body), self._content_length
                    )
                )
            )
        # otherwise the body is still incomplete: wait for more data
async def scpd_get(control_url, address, port, loop=None):
    """Fetch and deserialize an SCPD document with an HTTP GET request.

    :param control_url: location of the document, passed to ``serialize_scpd_get``
    :param address: host to connect to
    :param port: port to connect to
    :param loop: event loop to use; defaults to the event loop policy's current loop
    :return: ``(parsed, body, error)`` - the deserialized response, the raw body and
        ``None`` on success; otherwise ``{}``, the bytes received so far (``b''`` on
        timeout) and a ``UPnPError``
    """
    loop = loop or asyncio.get_event_loop_policy().get_event_loop()
    # Create the future on the same loop that drives the connection.
    finished = loop.create_future()
    packet = serialize_scpd_get(control_url, address)
    transport, protocol = await loop.create_connection(
        # A GET carries no SOAP method/service id; pass None explicitly.
        lambda: SCPDHTTPClientProtocol(packet, finished, None, None), address, port
    )
    assert isinstance(protocol, SCPDHTTPClientProtocol)
    error = None
    try:
        body, _response_code, _response_msg = await asyncio.wait_for(finished, 1.0)
    except asyncio.TimeoutError:
        error = UPnPError("get request timed out")
        body = b''
    except UPnPError as err:
        error = err
        body = protocol.response_buff
    finally:
        transport.close()
    if not error:
        try:
            return deserialize_scpd_get_response(body), body, None
        except ElementTree.ParseError as err:
            error = UPnPError(err)
    return {}, body, error
async def scpd_post(control_url, address, port, method, param_names, service_id, loop=None, **kwargs):
    """Invoke a SOAP action on a device with an HTTP POST request.

    :param str control_url: control URL of the service (encoded to bytes for the request)
    :param str address: host to connect to (also encoded into the request)
    :param port: port to connect to
    :param method: SOAP action name
    :param param_names: argument names of the action, forwarded to ``serialize_soap_post``
    :param bytes service_id: service id of the action
    :param loop: event loop to use; defaults to the event loop policy's current loop
    :param kwargs: forwarded to ``serialize_soap_post`` (presumably the action's
        argument values - confirm against the serializer)
    :return: ``(response, body, error)`` - the deserialized response, the raw body and
        ``None`` on success; otherwise ``{}``, the bytes received so far (``b''`` on
        timeout) and a ``UPnPError``
    """
    # `loop or ...` (not `loop |= ...`): bitwise-or on the default None raises TypeError.
    loop = loop or asyncio.get_event_loop_policy().get_event_loop()
    finished: asyncio.Future = loop.create_future()
    packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
    transport, protocol = await loop.create_connection(
        lambda: SCPDHTTPClientProtocol(
            packet, finished, soap_method=method, soap_service_id=service_id.decode(),
        ), address, port
    )
    assert isinstance(protocol, SCPDHTTPClientProtocol)
    try:
        body, _response_code, _response_msg = await asyncio.wait_for(finished, 1.0)
    except asyncio.TimeoutError:
        return {}, b'', UPnPError("Timeout")
    except UPnPError as err:
        return {}, protocol.response_buff, err
    finally:
        transport.close()
    try:
        return (
            deserialize_soap_post_response(body, method, service_id.decode()), body, None
        )
    except (ElementTree.ParseError, UPnPError) as err:
        return {}, body, UPnPError(err)