better ssdp datagram parsing

This commit is contained in:
Jack Robison 2018-07-29 17:32:14 -04:00
parent a0625153dc
commit d6695c7925
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
6 changed files with 310 additions and 139 deletions

View file

@ -1,4 +1,4 @@
import binascii import json
import logging import logging
from twisted.internet import defer from twisted.internet import defer
import treq import treq
@ -19,6 +19,15 @@ class Service(object):
self.subscribe_path = eventSubURL self.subscribe_path = eventSubURL
self.scpd_path = SCPDURL self.scpd_path = SCPDURL
def get_info(self):
return {
"service_type": self.service_type,
"service_id": self.service_id,
"control_path": self.control_path,
"subscribe_path": self.subscribe_path,
"scpd_path": self.scpd_path
}
class Device(object): class Device(object):
def __init__(self, _root_device, deviceType=None, friendlyName=None, manufacturer=None, manufacturerURL=None, def __init__(self, _root_device, deviceType=None, friendlyName=None, manufacturer=None, manufacturerURL=None,
@ -45,6 +54,17 @@ class Device(object):
devices = [Device(self._root_device, **deviceList[k]) for k in deviceList] devices = [Device(self._root_device, **deviceList[k]) for k in deviceList]
self._root_device.devices.extend(devices) self._root_device.devices.extend(devices)
def get_info(self):
return {
'device_type': self.device_type,
'friendly_name': self.friendly_name,
'manufacturers': self.manufacturer,
'model_name': self.model_name,
'model_number': self.model_number,
'serial_number': self.serial_number,
'udn': self.udn
}
class RootDevice(object): class RootDevice(object):
def __init__(self, xml_string): def __init__(self, xml_string):
@ -61,7 +81,7 @@ class RootDevice(object):
if root: if root:
root_device = Device(self, **(root["device"])) root_device = Device(self, **(root["device"]))
self.devices.append(root_device) self.devices.append(root_device)
log.info("finished setting up root device. %i devices and %i services", len(self.devices), len(self.services)) log.info("finished setting up root gateway. %i devices and %i services", len(self.devices), len(self.services))
class Gateway(object): class Gateway(object):
@ -77,14 +97,41 @@ class Gateway(object):
self.port = int(BASE_PORT_REGEX.findall(self.location)[0]) self.port = int(BASE_PORT_REGEX.findall(self.location)[0])
self._device = None self._device = None
def debug_device(self):
def default_byte(x):
if isinstance(x, bytes):
return x.decode()
return x
devices = []
for device in self._device.devices:
info = device.get_info()
devices.append(info)
services = []
for service in self._device.services:
info = service.get_info()
services.append(info)
return json.dumps({
'root_url': self.base_address,
'gateway_xml_url': self.location,
'usn': self.usn,
'devices': devices,
'services': services
}, indent=2, default=default_byte)
@defer.inlineCallbacks @defer.inlineCallbacks
def discover_services(self): def discover_services(self):
log.info("querying %s", self.location) log.info("querying %s", self.location)
response = yield treq.get(self.location) response = yield treq.get(self.location)
response_xml = yield response.content() response_xml = yield response.content()
self._device = RootDevice(response_xml) if not response_xml:
if not self._device.devices or not self._device.services: log.error("service sent an empty reply\n%s", self.debug_device())
log.error("failed to parse device: \n%s", response_xml) try:
self._device = RootDevice(response_xml)
except Exception as err:
log.error("error parsing gateway: %s\n%s\n\n%s", err, self.debug_device(), response_xml)
self._device = RootDevice("")
log.debug("finished setting up gateway:\n%s", self.debug_device())
@property @property
def services(self): def services(self):

View file

@ -87,7 +87,7 @@ class _SCPDCommand(object):
xml_response = yield response.content() xml_response = yield response.content()
response = self.extract_response(self.extract_body(xml_response)) response = self.extract_response(self.extract_body(xml_response))
if not response: if not response:
log.error("empty response to %s\n%s", self.method, xml_response) log.debug("empty response to %s\n%s", self.method, xml_response)
defer.returnValue(response) defer.returnValue(response)
@staticmethod @staticmethod
@ -155,7 +155,7 @@ class SCPDCommandRunner(object):
@staticmethod @staticmethod
def _soap_function_info(action_dict): def _soap_function_info(action_dict):
if not action_dict.get('argumentList'): if not action_dict.get('argumentList'):
log.warning("don't know how to handle argument list: %s", action_dict) log.debug("don't know how to handle argument list: %s", action_dict)
return ( return (
action_dict['name'], action_dict['name'],
[], [],
@ -185,7 +185,7 @@ class SCPDCommandRunner(object):
command._process_result = _return_types(*current._return_types)(command._process_result) command._process_result = _return_types(*current._return_types)(command._process_result)
setattr(command, "__doc__", current.__doc__) setattr(command, "__doc__", current.__doc__)
setattr(self, command.method, command) setattr(self, command.method, command)
log.info("registered %s %s", service_type, action_info['name']) log.debug("registered %s %s", service_type, action_info['name'])
def _register_command(self, action_info, service_type): def _register_command(self, action_info, service_type):
try: try:

View file

@ -14,14 +14,14 @@ class SOAPServiceManager(object):
def __init__(self, reactor): def __init__(self, reactor):
self._reactor = reactor self._reactor = reactor
self.iface_name, self.router_ip, self.lan_address = get_lan_info() self.iface_name, self.router_ip, self.lan_address = get_lan_info()
self.sspd_factory = SSDPFactory(self.lan_address, self._reactor) self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip)
self._command_runners = {} self._command_runners = {}
self._selected_runner = GATEWAY_SCHEMA self._selected_runner = GATEWAY_SCHEMA
@defer.inlineCallbacks @defer.inlineCallbacks
def discover_services(self, address=None, ttl=30, max_devices=2): def discover_services(self, address=None, timeout=30, max_devices=1):
server_infos = yield self.sspd_factory.m_search( server_infos = yield self.sspd_factory.m_search(
address or self.router_ip, ttl=ttl, max_devices=max_devices address or self.router_ip, timeout=timeout, max_devices=max_devices
) )
locations = [] locations = []
for server_info in server_infos: for server_info in server_infos:

View file

@ -1,67 +1,188 @@
import logging import logging
import binascii import binascii
import re
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.protocol import DatagramProtocol from twisted.internet.protocol import DatagramProtocol
from txupnp.fault import UPnPError from txupnp.fault import UPnPError
from txupnp.constants import GATEWAY_SCHEMA, M_SEARCH_TEMPLATE, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, SSDP_ALL from txupnp.constants import GATEWAY_SCHEMA, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, SSDP_ALL
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def parse_http_fields(content_lines): SSDP_HOST = "%s:%i" % (SSDP_IP_ADDRESS, SSDP_PORT)
def flatten(s, lower=True): SSDP_BYEBYE = "ssdp:byebye"
r = s.rstrip(":").rstrip(" ").lstrip(" ").replace("-", "_") SSDP_UPDATE = "ssdp:update"
if lower: SSDP_ROOT_DEVICE = "upnp:rootdevice"
return r.lower() line_separator = "\r\n"
return r
result = {}
for l in content_lines:
split = l.decode().split(":")
if split and split[0]:
k = split[0]
v = ":".join(split[1:])
result[flatten(k)] = flatten(v, lower=False)
return result
#
# return {
# (k.lower().rstrip(":".encode()).replace("-".encode(), "_".encode())).decode(): v.decode()
# for k, v in {
# l.split(": ".encode())[0]: "".encode().join(l.split(": ".encode())[1:])
#
# }.items() if k
# }
def parse_ssdp_request(operation, port, protocol, content_lines): class SSDPDatagram(object):
if operation != "NOTIFY".encode(): _M_SEARCH = "M-SEARCH"
log.warning("unsupported operation: %s", operation) _NOTIFY = "NOTIFY"
raise UPnPError("unsupported operation: %s" % operation) _OK = "OK"
if port != "*".encode():
log.warning("unexpected port: %s", port)
raise UPnPError("unexpected port: %s" % port)
return parse_http_fields(content_lines)
_start_lines = {
_M_SEARCH: "M-SEARCH * HTTP/1.1",
_NOTIFY: "NOTIFY * HTTP/1.1",
_OK: "HTTP/1.1 200 OK"
}
def parse_ssdp_response(code, response, content_lines): _friendly_names = {
try: _M_SEARCH: "m-search",
if int(code) != 200: _NOTIFY: "notify",
raise UPnPError("unexpected http response code: %i" % int(code)) _OK: "m-search response"
except ValueError: }
log.error(response)
raise UPnPError("unexpected http response code: %s" % code) _vendor_field_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$")
if response != "OK".encode():
raise UPnPError("unexpected response: %s" % response) _patterns = {
return parse_http_fields(content_lines) 'host': (re.compile("^(?i)(host):(.*)$"), str),
'st': (re.compile("^(?i)(st):(.*)$"), str),
'man': (re.compile("^(?i)(man):|(\"(.*)\")$"), str),
'mx': (re.compile("^(?i)(mx):(.*)$"), int),
'nt': (re.compile("^(?i)(nt):(.*)$"), str),
'nts': (re.compile("^(?i)(nts):(.*)$"), str),
'usn': (re.compile("^(?i)(usn):(.*)$"), str),
'location': (re.compile("^(?i)(location):(.*)$"), str),
'cache_control': (re.compile("^(?i)(cache-control):(.*)$"), str),
'server': (re.compile("^(?i)(server):(.*)$"), str),
}
_required_fields = {
_M_SEARCH: [
'host',
'st',
'man',
'mx',
],
_OK: [
'cache_control',
# 'date',
# 'ext',
'location',
'server',
'st',
'usn'
]
}
_marshallers = {
'mx': str,
'man': lambda x: ("\"%s\"" % x)
}
def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None,
cache_control=None, server=None, date=None, ext=None, **kwargs):
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
raise UPnPError("unknown packet type: {}".format(packet_type))
self._packet_type = packet_type
self.host = host
self.st = st
self.man = man
self.mx = mx
self.nt = nt
self.nts = nts
self.usn = usn
self.location = location
self.cache_control = cache_control
self.server = server
self.date = date
self.ext = ext
for k, v in kwargs.items():
if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None:
setattr(self, k.lower(), v)
def __getitem__(self, item):
for i in self._required_fields[self._packet_type]:
if i.lower() == item.lower():
return getattr(self, i)
raise KeyError(item)
def get_friendly_name(self):
return self._friendly_names[self._packet_type]
def encode(self, trailing_newlines=2):
lines = [self._start_lines[self._packet_type]]
for attr_name in self._required_fields[self._packet_type]:
attr = getattr(self, attr_name)
if attr is None:
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
if attr_name in self._marshallers:
value = self._marshallers[attr_name](attr)
else:
value = attr
lines.append("{}: {}".format(attr_name.upper(), value))
serialized = line_separator.join(lines)
for _ in range(trailing_newlines):
serialized += line_separator
return serialized
def as_dict(self):
return self._lines_to_content_dict(self.encode().split(line_separator))
@classmethod
def decode(cls, datagram):
packet = cls._from_string(datagram.decode())
for attr_name in packet._required_fields[packet._packet_type]:
attr = getattr(packet, attr_name)
if attr is None:
raise UPnPError(
"required field for {} is missing from m-search response: {}".format(packet._packet_type, attr_name)
)
return packet
@classmethod
def _lines_to_content_dict(cls, lines):
result = {}
for line in lines:
if not line:
continue
matched = False
for name, (pattern, field_type) in cls._patterns.items():
if name not in result and pattern.findall(line):
match = pattern.findall(line)[-1][-1]
result[name] = field_type(match.lstrip(" ").rstrip(" "))
matched = True
break
if not matched:
if cls._vendor_field_pattern.findall(line):
match = cls._vendor_field_pattern.findall(line)[-1]
vendor_key = match[0].lstrip(" ").rstrip(" ")
# vendor_domain = match[1].lstrip(" ").rstrip(" ")
value = match[2].lstrip(" ").rstrip(" ")
if vendor_key not in result:
result[vendor_key] = value
return result
@classmethod
def _from_string(cls, datagram):
lines = [l for l in datagram.split(line_separator) if l]
if lines[0] == cls._start_lines[cls._M_SEARCH]:
return cls._from_request(lines[1:])
if lines[0] == cls._start_lines[cls._NOTIFY]:
return cls._from_notify(lines[1:])
if lines[0] == cls._start_lines[cls._OK]:
return cls._from_response(lines[1:])
@classmethod
def _from_response(cls, lines):
return cls(cls._OK, **cls._lines_to_content_dict(lines))
@classmethod
def _from_notify(cls, lines):
return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines))
@classmethod
def _from_request(cls, lines):
return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines))
class SSDPProtocol(DatagramProtocol): class SSDPProtocol(DatagramProtocol):
def __init__(self, reactor, finished_deferred, iface, router, ssdp_address=SSDP_IP_ADDRESS, def __init__(self, reactor, iface, router, ssdp_address=SSDP_IP_ADDRESS,
ssdp_port=SSDP_PORT, ttl=1, max_devices=None): ssdp_port=SSDP_PORT, ttl=1, max_devices=None):
self._reactor = reactor self._reactor = reactor
self._sem = defer.DeferredSemaphore(1) self._sem = defer.DeferredSemaphore(1)
self.finished_deferred = finished_deferred self.discover_callbacks = {}
self.iface = iface self.iface = iface
self.router = router self.router = router
self.ssdp_address = ssdp_address self.ssdp_address = ssdp_address
@ -72,93 +193,81 @@ class SSDPProtocol(DatagramProtocol):
self.devices = [] self.devices = []
def startProtocol(self): def startProtocol(self):
return self._sem.run(self.do_start)
def send_m_search(self):
data = M_SEARCH_TEMPLATE.format(self.ssdp_address, self.ssdp_port, GATEWAY_SCHEMA, SSDP_ALL, self.ttl)
try:
log.info("sending m-search (%i bytes) to %s:%i", len(data), self.ssdp_address, self.ssdp_port)
self.transport.write(data.encode(), (self.ssdp_address, self.ssdp_port))
except Exception as err:
log.exception("failed to write %s to %s:%i", binascii.hexlify(data), self.ssdp_address, self.ssdp_port)
raise err
def parse_ssdp_datagram(self, datagram):
lines = datagram.split("\r\n".encode())
header_pieces = lines[0].split(" ".encode())
protocols = {
"HTTP/1.1".encode()
}
operations = {
"M-SEARCH".encode(),
"NOTIFY".encode()
}
if header_pieces[0] in operations:
if header_pieces[2] not in protocols:
raise UPnPError("unknown protocol: %s" % header_pieces[2])
return parse_ssdp_request(header_pieces[0], header_pieces[1], header_pieces[2], lines[1:])
if header_pieces[0] in protocols:
parsed = parse_ssdp_response(header_pieces[1], header_pieces[2], lines[1:])
log.info("received reply (%i bytes) to SSDP request (%f) (%s) %s", len(datagram),
self._reactor.seconds() - self._start, parsed['location'], parsed['server'])
return parsed
raise UPnPError("don't know how to decode datagram: %s" % binascii.hexlify(datagram))
def do_start(self):
self._start = self._reactor.seconds() self._start = self._reactor.seconds()
self.finished_deferred.addTimeout(self.ttl, self._reactor)
self.transport.setTTL(self.ttl) self.transport.setTTL(self.ttl)
self.transport.joinGroup(self.ssdp_address, interface=self.iface) self.transport.joinGroup(self.ssdp_address, interface=self.iface)
self.send_m_search()
for st in [SSDP_ALL, SSDP_ROOT_DEVICE, GATEWAY_SCHEMA, GATEWAY_SCHEMA.lower()]:
self.send_m_search(service=st)
def send_m_search(self, service=GATEWAY_SCHEMA):
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, host=SSDP_HOST, st=service, man=SSDP_DISCOVER, mx=1)
log.debug("writing packet:\n%s", packet.encode())
log.info("sending m-search (%i bytes) to %s:%i", len(packet.encode()), self.ssdp_address, self.ssdp_port)
try:
self.transport.write(packet.encode().encode(), (self.ssdp_address, self.ssdp_port))
except Exception as err:
log.exception("failed to write %s to %s:%i", binascii.hexlify(packet.encode()), self.ssdp_address, self.ssdp_port)
raise err
def leave_group(self): def leave_group(self):
self.transport.leaveGroup(self.ssdp_address, interface=self.iface) self.transport.leaveGroup(self.ssdp_address, interface=self.iface)
def datagramReceived(self, datagram, addr): def datagramReceived(self, datagram, address):
self._sem.run(self.handle_datagram, datagram, addr) if address[0] == self.iface:
return
def handle_datagram(self, datagram, address): try:
if address[0] == self.router: packet = SSDPDatagram.decode(datagram)
try: log.debug("decoded %s from %s:%i:\n%s", packet.get_friendly_name(), address[0], address[1], packet.encode())
parsed = self.parse_ssdp_datagram(datagram) except Exception:
self.devices.append(parsed) log.exception("failed to decode: %s", binascii.hexlify(datagram))
log.info("found %i/%s so far", len(self.devices), self.max_devices) return
if not self.finished_deferred.called: if packet._packet_type == packet._OK:
if not self.max_devices or (self.max_devices and len(self.devices) >= self.max_devices): log.info("%s:%i replied to our m-search with new xml url: %s", address[0], address[1], packet.location)
self._sem.run(self.finished_deferred.callback, self.devices)
except UPnPError as err:
log.error("error decoding SSDP response from %s:%s (error: %s)\n%s", address[0], address[1], str(err), binascii.hexlify(datagram))
raise err
elif address[0] != self.iface:
log.info("received %i bytes from %s:%i\n%s", len(datagram), address[0], address[1], binascii.hexlify(datagram))
else: else:
pass # loopback log.info("%s:%i notified us of a service type: %s", address[0], address[1], packet.st)
if packet.st not in map(lambda p: p['st'], self.devices):
self.devices.append(packet.as_dict())
log.info("%i device%s so far", len(self.devices), "" if len(self.devices) < 2 else "s")
if address[0] in self.discover_callbacks:
self._sem.run(self.discover_callbacks[address[0]][0], packet)
def gather(finished_deferred, max_results):
results = []
def discover_cb(packet):
if not finished_deferred.called:
results.append(packet.as_dict())
if len(results) >= max_results:
finished_deferred.callback(results)
return discover_cb
class SSDPFactory(object): class SSDPFactory(object):
def __init__(self, lan_address, reactor): def __init__(self, reactor, lan_address, router_address):
self.lan_address = lan_address self.lan_address = lan_address
self.router_address = router_address
self._reactor = reactor self._reactor = reactor
self.protocol = None self.protocol = SSDPProtocol(self._reactor, self.lan_address, self.router_address)
self.port = None self.port = None
self.finished_deferred = defer.Deferred()
def stop(self): def disconnect(self):
try: if self.protocol:
self.protocol.leave_group() self.protocol.leave_group()
self.port.stopListening() self.protocol = None
except: if not self.port:
pass return
self.port.stopListening()
self.port = None
def connect(self, address, ttl, max_devices=1): def connect(self):
self.protocol = SSDPProtocol(self._reactor, self.finished_deferred, self.lan_address, address, ttl=ttl, self._reactor.addSystemEventTrigger("before", "shutdown", self.disconnect)
max_devices=max_devices)
self.port = self._reactor.listenMulticast(self.protocol.ssdp_port, self.protocol, listenMultiple=True) self.port = self._reactor.listenMulticast(self.protocol.ssdp_port, self.protocol, listenMultiple=True)
self._reactor.addSystemEventTrigger("before", "shutdown", self.stop)
return self.finished_deferred
@defer.inlineCallbacks @defer.inlineCallbacks
def m_search(self, address, ttl=30, max_devices=2): def m_search(self, address, timeout=30, max_devices=1):
""" """
Perform a HTTP over UDP M-SEARCH query Perform a HTTP over UDP M-SEARCH query
@ -170,11 +279,18 @@ class SSDPFactory(object):
'usn': <usn> 'usn': <usn>
}, ...] }, ...]
""" """
d = self.connect(address, ttl, max_devices=max_devices)
self.connect()
if address in self.protocol.discover_callbacks:
d = self.protocol.discover_callbacks[address][1]
else:
d = defer.Deferred()
d.addTimeout(timeout, self._reactor)
found_cb = gather(d, max_devices)
self.protocol.discover_callbacks[address] = found_cb, d
try: try:
server_infos = yield d server_infos = yield d
except defer.TimeoutError: except defer.TimeoutError:
server_infos = self.protocol.devices server_infos = self.protocol.devices
log.info("found %i devices", len(server_infos))
self.stop()
defer.returnValue(server_infos) defer.returnValue(server_infos)

View file

@ -1,3 +1,4 @@
import sys
import logging import logging
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
from txupnp.upnp import UPnP from txupnp.upnp import UPnP
@ -7,10 +8,12 @@ log = logging.getLogger("txupnp")
@defer.inlineCallbacks @defer.inlineCallbacks
def test(ext_port=4446, int_port=4446, proto='UDP'): def test(ext_port=4446, int_port=4446, proto='UDP', timeout=1):
u = UPnP(reactor) u = UPnP(reactor)
found = yield u.discover() found = yield u.discover(timeout=timeout)
assert found, "M-SEARCH failed to find gateway" if not found:
print("failed to find gateway")
defer.returnValue(None)
external_ip = yield u.get_external_ip() external_ip = yield u.get_external_ip()
assert external_ip, "Failed to get the external IP" assert external_ip, "Failed to get the external IP"
log.info(external_ip) log.info(external_ip)
@ -45,17 +48,22 @@ def test(ext_port=4446, int_port=4446, proto='UDP'):
@defer.inlineCallbacks @defer.inlineCallbacks
def run_tests(): def run_tests(timeout=1):
for p in ['UDP', 'TCP']: for p in ['UDP']:
yield test(proto=p) yield test(proto=p, timeout=timeout)
def main(): def main(timeout):
d = run_tests() d = run_tests(timeout)
d.addErrback(log.exception) d.addErrback(log.exception)
d.addBoth(lambda _: reactor.callLater(0, reactor.stop)) d.addBoth(lambda _: reactor.callLater(0, reactor.stop))
reactor.run() reactor.run()
if __name__ == "__main__": if __name__ == "__main__":
main() if len(sys.argv) > 1:
log.setLevel(logging.DEBUG)
timeout = int(sys.argv[1])
else:
timeout = 1
main(timeout)

View file

@ -19,7 +19,7 @@ class UPnP(object):
def commands(self): def commands(self):
return self.soap_manager.get_runner() return self.soap_manager.get_runner()
def m_search(self, address, ttl=30, max_devices=2): def m_search(self, address, timeout=30, max_devices=2):
""" """
Perform a HTTP over UDP M-SEARCH query Perform a HTTP over UDP M-SEARCH query
@ -31,12 +31,12 @@ class UPnP(object):
'usn': <usn> 'usn': <usn>
}, ...] }, ...]
""" """
return self.soap_manager.sspd_factory.m_search(address, ttl=ttl, max_devices=max_devices) return self.soap_manager.sspd_factory.m_search(address, timeout=timeout, max_devices=max_devices)
@defer.inlineCallbacks @defer.inlineCallbacks
def discover(self, ttl=30, max_devices=2): def discover(self, timeout=1, max_devices=1):
try: try:
yield self.soap_manager.discover_services(ttl=ttl, max_devices=max_devices) yield self.soap_manager.discover_services(timeout=timeout, max_devices=max_devices)
except defer.TimeoutError: except defer.TimeoutError:
log.warning("failed to find upnp gateway") log.warning("failed to find upnp gateway")
defer.returnValue(False) defer.returnValue(False)