aioupnp/txupnp/ssdp.py
Jack Robison 69f35aa54b
refactor
2018-10-04 17:06:02 -04:00

168 lines
7.1 KiB
Python

import logging
import binascii
from twisted.internet import defer
from twisted.internet.protocol import DatagramProtocol
from txupnp.constants import UPNP_ORG_IGD, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, service_types
from txupnp.constants import SSDP_HOST
from txupnp.fault import UPnPError
from txupnp.ssdp_datagram import SSDPDatagram
log = logging.getLogger(__name__)
class SSDPProtocol(DatagramProtocol):
    """
    Multicast datagram protocol that sends SSDP M-SEARCH queries and collects the replies.

    Replies with the "OK" packet type are de-duplicated by their ``location`` and stored in
    ``self.devices`` as dictionaries (the result of ``SSDPDatagram.as_dict()``).
    """

    def __init__(self, reactor, iface, router, ssdp_address=SSDP_IP_ADDRESS,
                 ssdp_port=SSDP_PORT, ttl=1, max_devices=None, debug_packets=False,
                 debug_sent=None, debug_received=None):
        """
        :param reactor: reactor used for timestamps and for scheduling the m-search timeout
        :param iface: local interface address; used to join the multicast group and to skip
                      datagrams whose source address equals it
        :param router: router address (only stored by this class)
        :param ssdp_address: multicast group address to join and send queries to
        :param ssdp_port: port queries are sent to
        :param ttl: TTL applied to the transport when the protocol starts
        :param max_devices: stored on the instance; ``m_search`` takes its own limit
        :param debug_packets: when true, record sent and received datagrams
        :param debug_sent: optional list that sent datagrams are appended to
        :param debug_received: optional list that ``(address, datagram)`` tuples are appended to
        """
        self._reactor = reactor
        self._sem = defer.DeferredSemaphore(1)
        self.discover_callbacks = {}
        self.iface = iface
        self.router = router
        self.ssdp_address = ssdp_address
        self.ssdp_port = ssdp_port
        self.ttl = ttl
        self._start = None
        self.max_devices = max_devices
        self.devices = []
        self.debug_packets = debug_packets
        self.debug_sent = debug_sent if debug_sent is not None else []
        # each debug list is checked on its own so a caller may supply just one of them
        self.debug_received = debug_received if debug_received is not None else []

    def _send_m_search(self, service=UPNP_ORG_IGD):
        """
        Build an M-SEARCH datagram for ``service`` and write it to the SSDP address.

        Any exception raised while encoding to bytes or writing is logged and re-raised.
        """
        packet = SSDPDatagram(SSDPDatagram._M_SEARCH, host=SSDP_HOST, st=service, man=SSDP_DISCOVER, mx=1)
        encoded = packet.encode()
        log.debug("sending packet to %s:\n%s", SSDP_HOST, encoded)
        try:
            msg_bytes = encoded.encode()
            if self.debug_packets:
                self.debug_sent.append(msg_bytes)
            self.transport.write(msg_bytes, (self.ssdp_address, self.ssdp_port))
        except Exception:
            log.exception("failed to write %s to %s:%i", encoded, self.ssdp_address, self.ssdp_port)
            raise  # bare raise keeps the original traceback

    @staticmethod
    def _gather(finished_deferred, max_results, results: list):
        """
        Return a callback that appends matching packets to ``results``.

        A packet is kept only while ``finished_deferred`` has not fired and when its ``st``
        is one of ``service_types``. Once ``results`` holds at least ``max_results`` entries
        the deferred is fired with ``results``. A ``max_results`` of None never fires the
        deferred early, leaving completion to the caller's timeout.
        """
        def discover_cb(packet):
            if finished_deferred.called or packet.st not in service_types:
                return
            results.append(packet.as_dict())
            if max_results is not None and len(results) >= max_results:
                finished_deferred.callback(results)
        return discover_cb

    def m_search(self, address, timeout, max_devices):
        """
        Send an M-SEARCH for every entry of ``service_types`` and wait for replies.

        :param address: address whose replies are routed to this search
        :param timeout: seconds before the search gives up and returns what was found
        :param max_devices: finish early once ``self.devices`` holds this many entries
        :return: Deferred firing with ``self.devices``
        """
        # reuse the deferred of a search that is still pending for this address;
        # if it has already fired, fall through and start a new one
        pending = self.discover_callbacks.get(address)
        if pending is not None and not pending[1].called:
            return pending[1]

        def _trap_timeout_and_return_results(err):
            # a timeout is not an error here: hand back whatever was collected so far
            if err.check(defer.TimeoutError):
                return self.devices
            return err  # propagate any other failure down the errback chain

        d = defer.Deferred()
        d.addTimeout(timeout, self._reactor)
        d.addErrback(_trap_timeout_and_return_results)
        found_cb = self._gather(d, max_devices, self.devices)
        self.discover_callbacks[address] = found_cb, d
        for st in service_types:
            self._send_m_search(service=st)
        return d

    def startProtocol(self):
        """Record the start time, apply the TTL and join the SSDP multicast group."""
        self._start = self._reactor.seconds()
        self.transport.setTTL(self.ttl)
        self.transport.joinGroup(self.ssdp_address, interface=self.iface)

    def datagramReceived(self, datagram, address):
        """
        Decode an incoming datagram and dispatch it by packet type.

        Datagrams from our own interface address are ignored. Packets that fail to decode
        are logged and dropped. "OK" replies with a location not seen before are either
        stored directly (no search registered for the sender) or passed to the sender's
        discover callback; notifications are only logged.
        """
        if address[0] == self.iface:
            return
        if self.debug_packets:
            self.debug_received.append((address, datagram))
        try:
            packet = SSDPDatagram.decode(datagram)
            log.debug("decoded %s from %s:%i:\n%s", packet.get_friendly_name(), address[0], address[1],
                      packet.encode())
        except UPnPError as err:
            log.error("failed to decode SSDP packet from %s:%i: %s\npacket: %s", address[0], address[1], err,
                      binascii.hexlify(datagram))
            return
        except Exception:
            log.exception("failed to decode: %s", binascii.hexlify(datagram))
            return

        if packet._packet_type == packet._OK:
            log.debug("%s:%i replied to our m-search with new xml url: %s", address[0], address[1], packet.location)
            known_locations = [device['location'] for device in self.devices]
            if packet.location in known_locations:
                log.info("ignored packet from %s:%s (%s) %s", address[0], address[1], packet._packet_type,
                         packet.location)
            elif address[0] not in self.discover_callbacks:
                # nobody is searching for this sender; remember the device anyway
                self.devices.append(packet.as_dict())
            else:
                # run the discover callback under the semaphore so callbacks do not overlap
                self._sem.run(self.discover_callbacks[address[0]][0], packet)
        elif packet._packet_type == packet._NOTIFY:
            log.debug("%s:%i sent us a notification (type: %s), url: %s", address[0], address[1], packet.nts,
                      packet.location)
class SSDPFactory:
    """
    Owns a single ``SSDPProtocol`` and the multicast port it listens on.

    The protocol and port are created lazily by ``connect`` and torn down by ``disconnect``.
    Results of every ``m_search`` are accumulated in ``server_infos``.
    """

    def __init__(self, reactor, lan_address, router_address, debug_packets=False):
        """
        :param reactor: reactor used to open the multicast port
        :param lan_address: local interface address handed to the protocol
        :param router_address: router address handed to the protocol
        :param debug_packets: when true, the protocol records sent and received datagrams
        """
        self.lan_address = lan_address
        self.router_address = router_address
        self._reactor = reactor
        self.protocol = None
        self.port = None
        self.debug_packets = debug_packets
        self.debug_sent = []
        self.debug_received = []
        self.server_infos = []
        # id of the "before shutdown" trigger, kept so it is registered only once
        self._shutdown_trigger = None

    def disconnect(self):
        """Leave the multicast group and stop listening; does nothing when not connected."""
        if not self.port:
            return
        # leave the same group the protocol joined in startProtocol
        self.protocol.transport.leaveGroup(self.protocol.ssdp_address, interface=self.lan_address)
        self.port.stopListening()
        self.port = None
        self.protocol = None

    def connect(self):
        """Create the protocol and start listening for multicast datagrams if not already doing so."""
        if not self.protocol:
            self.protocol = SSDPProtocol(self._reactor, self.lan_address, self.router_address,
                                         debug_packets=self.debug_packets, debug_sent=self.debug_sent,
                                         debug_received=self.debug_received)
        if not self.port:
            # disconnect() is a no-op when there is no port, so one trigger covers every
            # later reconnect; registering again would only pile up duplicate triggers
            if self._shutdown_trigger is None:
                self._shutdown_trigger = self._reactor.addSystemEventTrigger("before", "shutdown",
                                                                             self.disconnect)
            self.port = self._reactor.listenMulticast(self.protocol.ssdp_port, self.protocol, listenMultiple=True)

    @defer.inlineCallbacks
    def m_search(self, address, timeout, max_devices):
        """
        Perform a M-SEARCH (HTTP over UDP) and gather the results

        :param address: (str) address to listen for responses from
        :param timeout: (int) timeout for the query
        :param max_devices: (int) block until timeout or at least this many devices are found
        :return: (list) [ (dict) {
            'server': (str) gateway os and version
            'location': (str) upnp gateway url,
            'cache-control': (str) max age,
            'date': (int) server time,
            'usn': (str) usn
        }, ...]
        """
        self.connect()
        server_infos = yield self.protocol.m_search(address, timeout, max_devices)
        self.server_infos.extend(server_infos)
        defer.returnValue(server_infos)

    def get_ssdp_packet_replay(self) -> dict:
        """Return the addresses and the recorded sent/received datagrams for replaying a session."""
        return {
            'lan_address': self.lan_address,
            'router_address': self.router_address,
            'sent': self.debug_sent,
            'received': self.debug_received,
        }