fix scpd breaking if content-length is not provided

-applies to at least actiontec
-add actiontec and dd-wrt discover + get_external_ip replay tests
This commit is contained in:
Jack Robison 2020-11-01 15:25:51 -05:00
parent 454ad65450
commit beaa7bc3cb
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
5 changed files with 366 additions and 21 deletions

View file

@ -56,6 +56,7 @@ class SCPDHTTPClientProtocol(Protocol):
self._response_msg = b"" self._response_msg = b""
self._content_length = 0 self._content_length = 0
self._got_headers = False self._got_headers = False
self._has_content_length = True
self._headers: typing.Dict[bytes, bytes] = {} self._headers: typing.Dict[bytes, bytes] = {}
self._body = b"" self._body = b""
self.transport: typing.Optional[asyncio.WriteTransport] = None self.transport: typing.Optional[asyncio.WriteTransport] = None
@ -67,22 +68,30 @@ class SCPDHTTPClientProtocol(Protocol):
return None return None
def data_received(self, data: bytes) -> None: def data_received(self, data: bytes) -> None:
if self.finished.done(): # possible to hit during tests
return
self.response_buff += data self.response_buff += data
for i, line in enumerate(self.response_buff.split(b'\r\n')): for i, line in enumerate(self.response_buff.split(b'\r\n')):
if not line: # we hit the blank line between the headers and the body if not line: # we hit the blank line between the headers and the body
if i == (len(self.response_buff.split(b'\r\n')) - 1): if i == (len(self.response_buff.split(b'\r\n')) - 1):
return None # the body is still yet to be written return None # the body is still yet to be written
if not self._got_headers: if not self._got_headers:
try:
self._headers, self._response_code, self._response_msg = parse_headers( self._headers, self._response_code, self._response_msg = parse_headers(
b'\r\n'.join(self.response_buff.split(b'\r\n')[:i]) b'\r\n'.join(self.response_buff.split(b'\r\n')[:i])
) )
except ValueError as err:
self.finished.set_exception(UPnPError(str(err)))
return
content_length = get_dict_val_case_insensitive( content_length = get_dict_val_case_insensitive(
self._headers, b'Content-Length' self._headers, b'Content-Length'
) )
if content_length is None: if content_length is not None:
return None self._content_length = int(content_length)
self._content_length = int(content_length or 0) else:
self._has_content_length = False
self._got_headers = True self._got_headers = True
if self._got_headers and self._has_content_length:
body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:]) body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:])
if self._content_length == len(body): if self._content_length == len(body):
self.finished.set_result((self.response_buff, body, self._response_code, self._response_msg)) self.finished.set_result((self.response_buff, body, self._response_code, self._response_msg))
@ -96,6 +105,20 @@ class SCPDHTTPClientProtocol(Protocol):
) )
) )
) )
elif any(map(self.response_buff.endswith, (b"</root>\r\n", b"</scpd>\r\n"))):
# Actiontec has a router that doesn't give a Content-Length for the gateway xml
body = b'\r\n'.join(self.response_buff.split(b'\r\n')[i+1:])
self.finished.set_result((self.response_buff, body, self._response_code, self._response_msg))
elif len(self.response_buff) >= 65535:
self.finished.set_exception(
UPnPError(
"too many bytes written to response (%i) with unspecified content length" % len(self.response_buff)
)
)
return
else:
# needed for the actiontec case
pass
return None return None
return None return None

View file

@ -18,7 +18,7 @@ except ImportError:
@contextlib.contextmanager @contextlib.contextmanager
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None, def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False, tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False,
raise_oserror_on_bind=False, raise_connectionerror=False): raise_oserror_on_bind=False, raise_connectionerror=False, tcp_chunk_size=100):
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else [] sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
udp_replies = udp_replies or {} udp_replies = udp_replies or {}
@ -36,8 +36,8 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r
reply = tcp_replies[data] reply = tcp_replies[data]
i = 0 i = 0
while i < len(reply): while i < len(reply):
loop.call_later(tcp_delay_reply, p.data_received, reply[i:i+100]) loop.call_later(tcp_delay_reply, p.data_received, reply[i:i+tcp_chunk_size])
i += 100 i += tcp_chunk_size
return return
else: else:
pass pass
@ -71,6 +71,7 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r
loop.call_later(udp_delay_reply, p.datagram_received, udp_replies[(data, addr)], loop.call_later(udp_delay_reply, p.datagram_received, udp_replies[(data, addr)],
(udp_expected_addr, 1900)) (udp_expected_addr, 1900))
return _sendto return _sendto
protocol = proto_lam() protocol = proto_lam()

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1,8 +1,12 @@
import os
import json
from collections import OrderedDict
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from tests import AsyncioTestCase, mock_tcp_and_udp from tests import AsyncioTestCase, mock_tcp_and_udp
from collections import OrderedDict
from aioupnp.gateway import Gateway, get_action_list from aioupnp.gateway import Gateway, get_action_list
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.serialization.soap import serialize_soap_post
from aioupnp.upnp import UPnP
def gen_get_bytes(location: str, host: str) -> bytes: def gen_get_bytes(location: str, host: str) -> bytes:
@ -296,3 +300,71 @@ class TestDiscoverNetgearNighthawkAC2350(TestDiscoverDLinkDIR890L):
'RequestConnection', 'ForceTermination', 'RequestConnection', 'ForceTermination',
'GetStatusInfo', 'GetNATRSIPStatus']}, 'GetStatusInfo', 'GetNATRSIPStatus']},
'soap_requests': []} 'soap_requests': []}
class TestActiontec(AsyncioTestCase):
name = "Actiontec GT784WN"
_location_key = 'Location'
@property
def data_path(self):
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "replays", self.name)
def _get_location(self):
# return self.gateway_info['reply']['Location'].split(self.gateway_address)[-1]
return self.gateway_info['reply'][self._location_key].split(f"{self.gateway_address}:{self.gateway_info['soap_port']}")[-1]
def setUp(self) -> None:
with open(self.data_path, 'r') as f:
data = json.loads(f.read())
self.gateway_info = data['gateway']
self.client_address = data['client_address']
self.gateway_address = self.gateway_info['gateway_address']
self.udp_replies = {
(SSDPDatagram('M-SEARCH', self.gateway_info['m_search_args']).encode().encode(), ("239.255.255.250", 1900)): SSDPDatagram("OK", self.gateway_info['reply']).encode().encode()
}
self.tcp_replies = {
(
f"GET {path} HTTP/1.1\r\n"
f"Accept-Encoding: gzip\r\n"
f"Host: {self.gateway_info['gateway_address']}\r\n"
f"Connection: Close\r\n"
f"\r\n"
).encode(): xml_bytes.encode()
for path, xml_bytes in self.gateway_info['service_descriptors'].items()
}
self.tcp_replies.update({
(
f"GET {self._get_location()} HTTP/1.1\r\n"
f"Accept-Encoding: gzip\r\n"
f"Host: {self.gateway_info['gateway_address']}\r\n"
f"Connection: Close\r\n"
f"\r\n"
).encode(): self.gateway_info['gateway_xml'].encode()
})
self.registered_soap_commands = self.gateway_info['registered_soap_commands']
super().setUp()
async def setup_request_replay(self, u: UPnP):
for method, reqs in self.gateway_info['soap_requests'].items():
if not reqs:
continue
self.tcp_replies.update({
serialize_soap_post(
method, list(args.keys()), self.registered_soap_commands[method].encode(),
self.gateway_address.encode(), u.gateway.services[self.registered_soap_commands[method]].controlURL.encode()
): response.encode() for args, response in reqs
})
async def replay(self, u: UPnP):
self.assertEqual('11.222.33.111', await u.get_external_ip())
async def test_replay(self):
with mock_tcp_and_udp(self.loop, udp_replies=self.udp_replies, tcp_replies=self.tcp_replies, udp_expected_addr=self.gateway_address, tcp_chunk_size=1450):
u = await UPnP.discover(lan_address=self.client_address, gateway_address=self.gateway_address, loop=self.loop)
await self.setup_request_replay(u)
await self.replay(u)
class TestNewMediaNet(TestActiontec):
name = "NewMedia-NET GmbH Generic X86"