better ssdp datagram parsing
This commit is contained in:
parent
a0625153dc
commit
d6695c7925
6 changed files with 310 additions and 139 deletions
|
@ -1,4 +1,4 @@
|
|||
import binascii
|
||||
import json
|
||||
import logging
|
||||
from twisted.internet import defer
|
||||
import treq
|
||||
|
@ -19,6 +19,15 @@ class Service(object):
|
|||
self.subscribe_path = eventSubURL
|
||||
self.scpd_path = SCPDURL
|
||||
|
||||
def get_info(self):
|
||||
return {
|
||||
"service_type": self.service_type,
|
||||
"service_id": self.service_id,
|
||||
"control_path": self.control_path,
|
||||
"subscribe_path": self.subscribe_path,
|
||||
"scpd_path": self.scpd_path
|
||||
}
|
||||
|
||||
|
||||
class Device(object):
|
||||
def __init__(self, _root_device, deviceType=None, friendlyName=None, manufacturer=None, manufacturerURL=None,
|
||||
|
@ -45,6 +54,17 @@ class Device(object):
|
|||
devices = [Device(self._root_device, **deviceList[k]) for k in deviceList]
|
||||
self._root_device.devices.extend(devices)
|
||||
|
||||
def get_info(self):
|
||||
return {
|
||||
'device_type': self.device_type,
|
||||
'friendly_name': self.friendly_name,
|
||||
'manufacturers': self.manufacturer,
|
||||
'model_name': self.model_name,
|
||||
'model_number': self.model_number,
|
||||
'serial_number': self.serial_number,
|
||||
'udn': self.udn
|
||||
}
|
||||
|
||||
|
||||
class RootDevice(object):
|
||||
def __init__(self, xml_string):
|
||||
|
@ -61,7 +81,7 @@ class RootDevice(object):
|
|||
if root:
|
||||
root_device = Device(self, **(root["device"]))
|
||||
self.devices.append(root_device)
|
||||
log.info("finished setting up root device. %i devices and %i services", len(self.devices), len(self.services))
|
||||
log.info("finished setting up root gateway. %i devices and %i services", len(self.devices), len(self.services))
|
||||
|
||||
|
||||
class Gateway(object):
|
||||
|
@ -77,14 +97,41 @@ class Gateway(object):
|
|||
self.port = int(BASE_PORT_REGEX.findall(self.location)[0])
|
||||
self._device = None
|
||||
|
||||
def debug_device(self):
|
||||
def default_byte(x):
|
||||
if isinstance(x, bytes):
|
||||
return x.decode()
|
||||
return x
|
||||
|
||||
devices = []
|
||||
for device in self._device.devices:
|
||||
info = device.get_info()
|
||||
devices.append(info)
|
||||
services = []
|
||||
for service in self._device.services:
|
||||
info = service.get_info()
|
||||
services.append(info)
|
||||
return json.dumps({
|
||||
'root_url': self.base_address,
|
||||
'gateway_xml_url': self.location,
|
||||
'usn': self.usn,
|
||||
'devices': devices,
|
||||
'services': services
|
||||
}, indent=2, default=default_byte)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def discover_services(self):
|
||||
log.info("querying %s", self.location)
|
||||
response = yield treq.get(self.location)
|
||||
response_xml = yield response.content()
|
||||
if not response_xml:
|
||||
log.error("service sent an empty reply\n%s", self.debug_device())
|
||||
try:
|
||||
self._device = RootDevice(response_xml)
|
||||
if not self._device.devices or not self._device.services:
|
||||
log.error("failed to parse device: \n%s", response_xml)
|
||||
except Exception as err:
|
||||
log.error("error parsing gateway: %s\n%s\n\n%s", err, self.debug_device(), response_xml)
|
||||
self._device = RootDevice("")
|
||||
log.debug("finished setting up gateway:\n%s", self.debug_device())
|
||||
|
||||
@property
|
||||
def services(self):
|
||||
|
|
|
@ -87,7 +87,7 @@ class _SCPDCommand(object):
|
|||
xml_response = yield response.content()
|
||||
response = self.extract_response(self.extract_body(xml_response))
|
||||
if not response:
|
||||
log.error("empty response to %s\n%s", self.method, xml_response)
|
||||
log.debug("empty response to %s\n%s", self.method, xml_response)
|
||||
defer.returnValue(response)
|
||||
|
||||
@staticmethod
|
||||
|
@ -155,7 +155,7 @@ class SCPDCommandRunner(object):
|
|||
@staticmethod
|
||||
def _soap_function_info(action_dict):
|
||||
if not action_dict.get('argumentList'):
|
||||
log.warning("don't know how to handle argument list: %s", action_dict)
|
||||
log.debug("don't know how to handle argument list: %s", action_dict)
|
||||
return (
|
||||
action_dict['name'],
|
||||
[],
|
||||
|
@ -185,7 +185,7 @@ class SCPDCommandRunner(object):
|
|||
command._process_result = _return_types(*current._return_types)(command._process_result)
|
||||
setattr(command, "__doc__", current.__doc__)
|
||||
setattr(self, command.method, command)
|
||||
log.info("registered %s %s", service_type, action_info['name'])
|
||||
log.debug("registered %s %s", service_type, action_info['name'])
|
||||
|
||||
def _register_command(self, action_info, service_type):
|
||||
try:
|
||||
|
|
|
@ -14,14 +14,14 @@ class SOAPServiceManager(object):
|
|||
def __init__(self, reactor):
|
||||
self._reactor = reactor
|
||||
self.iface_name, self.router_ip, self.lan_address = get_lan_info()
|
||||
self.sspd_factory = SSDPFactory(self.lan_address, self._reactor)
|
||||
self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip)
|
||||
self._command_runners = {}
|
||||
self._selected_runner = GATEWAY_SCHEMA
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def discover_services(self, address=None, ttl=30, max_devices=2):
|
||||
def discover_services(self, address=None, timeout=30, max_devices=1):
|
||||
server_infos = yield self.sspd_factory.m_search(
|
||||
address or self.router_ip, ttl=ttl, max_devices=max_devices
|
||||
address or self.router_ip, timeout=timeout, max_devices=max_devices
|
||||
)
|
||||
locations = []
|
||||
for server_info in server_infos:
|
||||
|
|
336
txupnp/ssdp.py
336
txupnp/ssdp.py
|
@ -1,67 +1,188 @@
|
|||
import logging
|
||||
import binascii
|
||||
import re
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.protocol import DatagramProtocol
|
||||
from txupnp.fault import UPnPError
|
||||
from txupnp.constants import GATEWAY_SCHEMA, M_SEARCH_TEMPLATE, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, SSDP_ALL
|
||||
from txupnp.constants import GATEWAY_SCHEMA, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, SSDP_ALL
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_http_fields(content_lines):
|
||||
def flatten(s, lower=True):
|
||||
r = s.rstrip(":").rstrip(" ").lstrip(" ").replace("-", "_")
|
||||
if lower:
|
||||
return r.lower()
|
||||
return r
|
||||
SSDP_HOST = "%s:%i" % (SSDP_IP_ADDRESS, SSDP_PORT)
|
||||
SSDP_BYEBYE = "ssdp:byebye"
|
||||
SSDP_UPDATE = "ssdp:update"
|
||||
SSDP_ROOT_DEVICE = "upnp:rootdevice"
|
||||
line_separator = "\r\n"
|
||||
|
||||
|
||||
class SSDPDatagram(object):
|
||||
_M_SEARCH = "M-SEARCH"
|
||||
_NOTIFY = "NOTIFY"
|
||||
_OK = "OK"
|
||||
|
||||
_start_lines = {
|
||||
_M_SEARCH: "M-SEARCH * HTTP/1.1",
|
||||
_NOTIFY: "NOTIFY * HTTP/1.1",
|
||||
_OK: "HTTP/1.1 200 OK"
|
||||
}
|
||||
|
||||
_friendly_names = {
|
||||
_M_SEARCH: "m-search",
|
||||
_NOTIFY: "notify",
|
||||
_OK: "m-search response"
|
||||
}
|
||||
|
||||
_vendor_field_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$")
|
||||
|
||||
_patterns = {
|
||||
'host': (re.compile("^(?i)(host):(.*)$"), str),
|
||||
'st': (re.compile("^(?i)(st):(.*)$"), str),
|
||||
'man': (re.compile("^(?i)(man):|(\"(.*)\")$"), str),
|
||||
'mx': (re.compile("^(?i)(mx):(.*)$"), int),
|
||||
'nt': (re.compile("^(?i)(nt):(.*)$"), str),
|
||||
'nts': (re.compile("^(?i)(nts):(.*)$"), str),
|
||||
'usn': (re.compile("^(?i)(usn):(.*)$"), str),
|
||||
'location': (re.compile("^(?i)(location):(.*)$"), str),
|
||||
'cache_control': (re.compile("^(?i)(cache-control):(.*)$"), str),
|
||||
'server': (re.compile("^(?i)(server):(.*)$"), str),
|
||||
}
|
||||
|
||||
_required_fields = {
|
||||
_M_SEARCH: [
|
||||
'host',
|
||||
'st',
|
||||
'man',
|
||||
'mx',
|
||||
],
|
||||
_OK: [
|
||||
'cache_control',
|
||||
# 'date',
|
||||
# 'ext',
|
||||
'location',
|
||||
'server',
|
||||
'st',
|
||||
'usn'
|
||||
]
|
||||
}
|
||||
|
||||
_marshallers = {
|
||||
'mx': str,
|
||||
'man': lambda x: ("\"%s\"" % x)
|
||||
}
|
||||
|
||||
def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None,
|
||||
cache_control=None, server=None, date=None, ext=None, **kwargs):
|
||||
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
|
||||
raise UPnPError("unknown packet type: {}".format(packet_type))
|
||||
self._packet_type = packet_type
|
||||
self.host = host
|
||||
self.st = st
|
||||
self.man = man
|
||||
self.mx = mx
|
||||
self.nt = nt
|
||||
self.nts = nts
|
||||
self.usn = usn
|
||||
self.location = location
|
||||
self.cache_control = cache_control
|
||||
self.server = server
|
||||
self.date = date
|
||||
self.ext = ext
|
||||
for k, v in kwargs.items():
|
||||
if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None:
|
||||
setattr(self, k.lower(), v)
|
||||
|
||||
def __getitem__(self, item):
|
||||
for i in self._required_fields[self._packet_type]:
|
||||
if i.lower() == item.lower():
|
||||
return getattr(self, i)
|
||||
raise KeyError(item)
|
||||
|
||||
def get_friendly_name(self):
|
||||
return self._friendly_names[self._packet_type]
|
||||
|
||||
def encode(self, trailing_newlines=2):
|
||||
lines = [self._start_lines[self._packet_type]]
|
||||
for attr_name in self._required_fields[self._packet_type]:
|
||||
attr = getattr(self, attr_name)
|
||||
if attr is None:
|
||||
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
|
||||
if attr_name in self._marshallers:
|
||||
value = self._marshallers[attr_name](attr)
|
||||
else:
|
||||
value = attr
|
||||
lines.append("{}: {}".format(attr_name.upper(), value))
|
||||
serialized = line_separator.join(lines)
|
||||
for _ in range(trailing_newlines):
|
||||
serialized += line_separator
|
||||
return serialized
|
||||
|
||||
def as_dict(self):
|
||||
return self._lines_to_content_dict(self.encode().split(line_separator))
|
||||
|
||||
@classmethod
|
||||
def decode(cls, datagram):
|
||||
packet = cls._from_string(datagram.decode())
|
||||
for attr_name in packet._required_fields[packet._packet_type]:
|
||||
attr = getattr(packet, attr_name)
|
||||
if attr is None:
|
||||
raise UPnPError(
|
||||
"required field for {} is missing from m-search response: {}".format(packet._packet_type, attr_name)
|
||||
)
|
||||
return packet
|
||||
|
||||
@classmethod
|
||||
def _lines_to_content_dict(cls, lines):
|
||||
result = {}
|
||||
for l in content_lines:
|
||||
split = l.decode().split(":")
|
||||
if split and split[0]:
|
||||
k = split[0]
|
||||
v = ":".join(split[1:])
|
||||
result[flatten(k)] = flatten(v, lower=False)
|
||||
for line in lines:
|
||||
if not line:
|
||||
continue
|
||||
matched = False
|
||||
for name, (pattern, field_type) in cls._patterns.items():
|
||||
if name not in result and pattern.findall(line):
|
||||
match = pattern.findall(line)[-1][-1]
|
||||
result[name] = field_type(match.lstrip(" ").rstrip(" "))
|
||||
matched = True
|
||||
break
|
||||
if not matched:
|
||||
if cls._vendor_field_pattern.findall(line):
|
||||
match = cls._vendor_field_pattern.findall(line)[-1]
|
||||
vendor_key = match[0].lstrip(" ").rstrip(" ")
|
||||
# vendor_domain = match[1].lstrip(" ").rstrip(" ")
|
||||
value = match[2].lstrip(" ").rstrip(" ")
|
||||
if vendor_key not in result:
|
||||
result[vendor_key] = value
|
||||
return result
|
||||
|
||||
#
|
||||
# return {
|
||||
# (k.lower().rstrip(":".encode()).replace("-".encode(), "_".encode())).decode(): v.decode()
|
||||
# for k, v in {
|
||||
# l.split(": ".encode())[0]: "".encode().join(l.split(": ".encode())[1:])
|
||||
#
|
||||
# }.items() if k
|
||||
# }
|
||||
@classmethod
|
||||
def _from_string(cls, datagram):
|
||||
lines = [l for l in datagram.split(line_separator) if l]
|
||||
if lines[0] == cls._start_lines[cls._M_SEARCH]:
|
||||
return cls._from_request(lines[1:])
|
||||
if lines[0] == cls._start_lines[cls._NOTIFY]:
|
||||
return cls._from_notify(lines[1:])
|
||||
if lines[0] == cls._start_lines[cls._OK]:
|
||||
return cls._from_response(lines[1:])
|
||||
|
||||
@classmethod
|
||||
def _from_response(cls, lines):
|
||||
return cls(cls._OK, **cls._lines_to_content_dict(lines))
|
||||
|
||||
def parse_ssdp_request(operation, port, protocol, content_lines):
|
||||
if operation != "NOTIFY".encode():
|
||||
log.warning("unsupported operation: %s", operation)
|
||||
raise UPnPError("unsupported operation: %s" % operation)
|
||||
if port != "*".encode():
|
||||
log.warning("unexpected port: %s", port)
|
||||
raise UPnPError("unexpected port: %s" % port)
|
||||
return parse_http_fields(content_lines)
|
||||
@classmethod
|
||||
def _from_notify(cls, lines):
|
||||
return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines))
|
||||
|
||||
|
||||
def parse_ssdp_response(code, response, content_lines):
|
||||
try:
|
||||
if int(code) != 200:
|
||||
raise UPnPError("unexpected http response code: %i" % int(code))
|
||||
except ValueError:
|
||||
log.error(response)
|
||||
raise UPnPError("unexpected http response code: %s" % code)
|
||||
if response != "OK".encode():
|
||||
raise UPnPError("unexpected response: %s" % response)
|
||||
return parse_http_fields(content_lines)
|
||||
@classmethod
|
||||
def _from_request(cls, lines):
|
||||
return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines))
|
||||
|
||||
|
||||
class SSDPProtocol(DatagramProtocol):
|
||||
def __init__(self, reactor, finished_deferred, iface, router, ssdp_address=SSDP_IP_ADDRESS,
|
||||
def __init__(self, reactor, iface, router, ssdp_address=SSDP_IP_ADDRESS,
|
||||
ssdp_port=SSDP_PORT, ttl=1, max_devices=None):
|
||||
self._reactor = reactor
|
||||
self._sem = defer.DeferredSemaphore(1)
|
||||
self.finished_deferred = finished_deferred
|
||||
self.discover_callbacks = {}
|
||||
self.iface = iface
|
||||
self.router = router
|
||||
self.ssdp_address = ssdp_address
|
||||
|
@ -72,93 +193,81 @@ class SSDPProtocol(DatagramProtocol):
|
|||
self.devices = []
|
||||
|
||||
def startProtocol(self):
|
||||
return self._sem.run(self.do_start)
|
||||
|
||||
def send_m_search(self):
|
||||
data = M_SEARCH_TEMPLATE.format(self.ssdp_address, self.ssdp_port, GATEWAY_SCHEMA, SSDP_ALL, self.ttl)
|
||||
try:
|
||||
log.info("sending m-search (%i bytes) to %s:%i", len(data), self.ssdp_address, self.ssdp_port)
|
||||
self.transport.write(data.encode(), (self.ssdp_address, self.ssdp_port))
|
||||
except Exception as err:
|
||||
log.exception("failed to write %s to %s:%i", binascii.hexlify(data), self.ssdp_address, self.ssdp_port)
|
||||
raise err
|
||||
|
||||
def parse_ssdp_datagram(self, datagram):
|
||||
lines = datagram.split("\r\n".encode())
|
||||
header_pieces = lines[0].split(" ".encode())
|
||||
protocols = {
|
||||
"HTTP/1.1".encode()
|
||||
}
|
||||
operations = {
|
||||
"M-SEARCH".encode(),
|
||||
"NOTIFY".encode()
|
||||
}
|
||||
if header_pieces[0] in operations:
|
||||
if header_pieces[2] not in protocols:
|
||||
raise UPnPError("unknown protocol: %s" % header_pieces[2])
|
||||
return parse_ssdp_request(header_pieces[0], header_pieces[1], header_pieces[2], lines[1:])
|
||||
if header_pieces[0] in protocols:
|
||||
parsed = parse_ssdp_response(header_pieces[1], header_pieces[2], lines[1:])
|
||||
log.info("received reply (%i bytes) to SSDP request (%f) (%s) %s", len(datagram),
|
||||
self._reactor.seconds() - self._start, parsed['location'], parsed['server'])
|
||||
return parsed
|
||||
raise UPnPError("don't know how to decode datagram: %s" % binascii.hexlify(datagram))
|
||||
|
||||
def do_start(self):
|
||||
self._start = self._reactor.seconds()
|
||||
self.finished_deferred.addTimeout(self.ttl, self._reactor)
|
||||
self.transport.setTTL(self.ttl)
|
||||
self.transport.joinGroup(self.ssdp_address, interface=self.iface)
|
||||
self.send_m_search()
|
||||
|
||||
for st in [SSDP_ALL, SSDP_ROOT_DEVICE, GATEWAY_SCHEMA, GATEWAY_SCHEMA.lower()]:
|
||||
self.send_m_search(service=st)
|
||||
|
||||
def send_m_search(self, service=GATEWAY_SCHEMA):
|
||||
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, host=SSDP_HOST, st=service, man=SSDP_DISCOVER, mx=1)
|
||||
log.debug("writing packet:\n%s", packet.encode())
|
||||
log.info("sending m-search (%i bytes) to %s:%i", len(packet.encode()), self.ssdp_address, self.ssdp_port)
|
||||
try:
|
||||
self.transport.write(packet.encode().encode(), (self.ssdp_address, self.ssdp_port))
|
||||
except Exception as err:
|
||||
log.exception("failed to write %s to %s:%i", binascii.hexlify(packet.encode()), self.ssdp_address, self.ssdp_port)
|
||||
raise err
|
||||
|
||||
def leave_group(self):
|
||||
self.transport.leaveGroup(self.ssdp_address, interface=self.iface)
|
||||
|
||||
def datagramReceived(self, datagram, addr):
|
||||
self._sem.run(self.handle_datagram, datagram, addr)
|
||||
|
||||
def handle_datagram(self, datagram, address):
|
||||
if address[0] == self.router:
|
||||
def datagramReceived(self, datagram, address):
|
||||
if address[0] == self.iface:
|
||||
return
|
||||
try:
|
||||
parsed = self.parse_ssdp_datagram(datagram)
|
||||
self.devices.append(parsed)
|
||||
log.info("found %i/%s so far", len(self.devices), self.max_devices)
|
||||
if not self.finished_deferred.called:
|
||||
if not self.max_devices or (self.max_devices and len(self.devices) >= self.max_devices):
|
||||
self._sem.run(self.finished_deferred.callback, self.devices)
|
||||
except UPnPError as err:
|
||||
log.error("error decoding SSDP response from %s:%s (error: %s)\n%s", address[0], address[1], str(err), binascii.hexlify(datagram))
|
||||
raise err
|
||||
elif address[0] != self.iface:
|
||||
log.info("received %i bytes from %s:%i\n%s", len(datagram), address[0], address[1], binascii.hexlify(datagram))
|
||||
packet = SSDPDatagram.decode(datagram)
|
||||
log.debug("decoded %s from %s:%i:\n%s", packet.get_friendly_name(), address[0], address[1], packet.encode())
|
||||
except Exception:
|
||||
log.exception("failed to decode: %s", binascii.hexlify(datagram))
|
||||
return
|
||||
if packet._packet_type == packet._OK:
|
||||
log.info("%s:%i replied to our m-search with new xml url: %s", address[0], address[1], packet.location)
|
||||
else:
|
||||
pass # loopback
|
||||
log.info("%s:%i notified us of a service type: %s", address[0], address[1], packet.st)
|
||||
if packet.st not in map(lambda p: p['st'], self.devices):
|
||||
self.devices.append(packet.as_dict())
|
||||
log.info("%i device%s so far", len(self.devices), "" if len(self.devices) < 2 else "s")
|
||||
if address[0] in self.discover_callbacks:
|
||||
self._sem.run(self.discover_callbacks[address[0]][0], packet)
|
||||
|
||||
|
||||
def gather(finished_deferred, max_results):
|
||||
results = []
|
||||
|
||||
def discover_cb(packet):
|
||||
if not finished_deferred.called:
|
||||
results.append(packet.as_dict())
|
||||
if len(results) >= max_results:
|
||||
finished_deferred.callback(results)
|
||||
|
||||
return discover_cb
|
||||
|
||||
|
||||
class SSDPFactory(object):
|
||||
def __init__(self, lan_address, reactor):
|
||||
def __init__(self, reactor, lan_address, router_address):
|
||||
self.lan_address = lan_address
|
||||
self.router_address = router_address
|
||||
self._reactor = reactor
|
||||
self.protocol = None
|
||||
self.protocol = SSDPProtocol(self._reactor, self.lan_address, self.router_address)
|
||||
self.port = None
|
||||
self.finished_deferred = defer.Deferred()
|
||||
|
||||
def stop(self):
|
||||
try:
|
||||
def disconnect(self):
|
||||
if self.protocol:
|
||||
self.protocol.leave_group()
|
||||
self.protocol = None
|
||||
if not self.port:
|
||||
return
|
||||
self.port.stopListening()
|
||||
except:
|
||||
pass
|
||||
self.port = None
|
||||
|
||||
def connect(self, address, ttl, max_devices=1):
|
||||
self.protocol = SSDPProtocol(self._reactor, self.finished_deferred, self.lan_address, address, ttl=ttl,
|
||||
max_devices=max_devices)
|
||||
def connect(self):
|
||||
self._reactor.addSystemEventTrigger("before", "shutdown", self.disconnect)
|
||||
self.port = self._reactor.listenMulticast(self.protocol.ssdp_port, self.protocol, listenMultiple=True)
|
||||
self._reactor.addSystemEventTrigger("before", "shutdown", self.stop)
|
||||
return self.finished_deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def m_search(self, address, ttl=30, max_devices=2):
|
||||
def m_search(self, address, timeout=30, max_devices=1):
|
||||
"""
|
||||
Perform a HTTP over UDP M-SEARCH query
|
||||
|
||||
|
@ -170,11 +279,18 @@ class SSDPFactory(object):
|
|||
'usn': <usn>
|
||||
}, ...]
|
||||
"""
|
||||
d = self.connect(address, ttl, max_devices=max_devices)
|
||||
|
||||
self.connect()
|
||||
|
||||
if address in self.protocol.discover_callbacks:
|
||||
d = self.protocol.discover_callbacks[address][1]
|
||||
else:
|
||||
d = defer.Deferred()
|
||||
d.addTimeout(timeout, self._reactor)
|
||||
found_cb = gather(d, max_devices)
|
||||
self.protocol.discover_callbacks[address] = found_cb, d
|
||||
try:
|
||||
server_infos = yield d
|
||||
except defer.TimeoutError:
|
||||
server_infos = self.protocol.devices
|
||||
log.info("found %i devices", len(server_infos))
|
||||
self.stop()
|
||||
defer.returnValue(server_infos)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import sys
|
||||
import logging
|
||||
from twisted.internet import reactor, defer
|
||||
from txupnp.upnp import UPnP
|
||||
|
@ -7,10 +8,12 @@ log = logging.getLogger("txupnp")
|
|||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test(ext_port=4446, int_port=4446, proto='UDP'):
|
||||
def test(ext_port=4446, int_port=4446, proto='UDP', timeout=1):
|
||||
u = UPnP(reactor)
|
||||
found = yield u.discover()
|
||||
assert found, "M-SEARCH failed to find gateway"
|
||||
found = yield u.discover(timeout=timeout)
|
||||
if not found:
|
||||
print("failed to find gateway")
|
||||
defer.returnValue(None)
|
||||
external_ip = yield u.get_external_ip()
|
||||
assert external_ip, "Failed to get the external IP"
|
||||
log.info(external_ip)
|
||||
|
@ -45,17 +48,22 @@ def test(ext_port=4446, int_port=4446, proto='UDP'):
|
|||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def run_tests():
|
||||
for p in ['UDP', 'TCP']:
|
||||
yield test(proto=p)
|
||||
def run_tests(timeout=1):
|
||||
for p in ['UDP']:
|
||||
yield test(proto=p, timeout=timeout)
|
||||
|
||||
|
||||
def main():
|
||||
d = run_tests()
|
||||
def main(timeout):
|
||||
d = run_tests(timeout)
|
||||
d.addErrback(log.exception)
|
||||
d.addBoth(lambda _: reactor.callLater(0, reactor.stop))
|
||||
reactor.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
if len(sys.argv) > 1:
|
||||
log.setLevel(logging.DEBUG)
|
||||
timeout = int(sys.argv[1])
|
||||
else:
|
||||
timeout = 1
|
||||
main(timeout)
|
||||
|
|
|
@ -19,7 +19,7 @@ class UPnP(object):
|
|||
def commands(self):
|
||||
return self.soap_manager.get_runner()
|
||||
|
||||
def m_search(self, address, ttl=30, max_devices=2):
|
||||
def m_search(self, address, timeout=30, max_devices=2):
|
||||
"""
|
||||
Perform a HTTP over UDP M-SEARCH query
|
||||
|
||||
|
@ -31,12 +31,12 @@ class UPnP(object):
|
|||
'usn': <usn>
|
||||
}, ...]
|
||||
"""
|
||||
return self.soap_manager.sspd_factory.m_search(address, ttl=ttl, max_devices=max_devices)
|
||||
return self.soap_manager.sspd_factory.m_search(address, timeout=timeout, max_devices=max_devices)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def discover(self, ttl=30, max_devices=2):
|
||||
def discover(self, timeout=1, max_devices=1):
|
||||
try:
|
||||
yield self.soap_manager.discover_services(ttl=ttl, max_devices=max_devices)
|
||||
yield self.soap_manager.discover_services(timeout=timeout, max_devices=max_devices)
|
||||
except defer.TimeoutError:
|
||||
log.warning("failed to find upnp gateway")
|
||||
defer.returnValue(False)
|
||||
|
|
Loading…
Reference in a new issue