177 lines
6.5 KiB
177 lines
6.5 KiB
import os
import json
import logging
from twisted.internet import task, defer
from twisted.internet.error import ConnectionDone
from twisted.internet.protocol import DatagramProtocol
from twisted.python.failure import Failure
from twisted.test.proto_helpers import _FakePort
from txupnp.ssdp_datagram import SSDPDatagram
# Module-level logger; getLogger() called without a name returns the root logger.
log = logging.getLogger()
class MockResponse:
    """Canned response object exposing ``headers`` and a ``content()`` method."""

    def __init__(self, content):
        """Keep the given body and start with an empty header mapping."""
        self.headers = {}
        self._content = content

    def content(self):
        """Return an already-fired Deferred whose result is the stored body."""
        body = self._content
        return defer.succeed(body)
class MockDevice:
    """A recorded device fixture loaded from the ``tests/devices`` directory.

    The fixture file is named ``"<manufacturer> <model>"`` and contains JSON,
    which is parsed into ``device_dict``.
    """

    def __init__(self, manufacturer, model):
        """Load the JSON fixture for the given manufacturer and model.

        Args:
            manufacturer: first half of the fixture file name.
            model: second half of the fixture file name.

        Raises:
            AssertionError: if no fixture file exists for this device.
        """
        self.manufacturer = manufacturer
        self.model = model
        # Fixtures are looked up in <parent of this module's directory>/tests/devices.
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        device_path = os.path.join(
            base_dir, "tests", "devices", "{} {}".format(manufacturer, model)
        )
        assert os.path.isfile(device_path), "missing device fixture: {}".format(device_path)
        with open(device_path, "r") as f:
            self.device_dict = json.load(f)

    def __repr__(self):
        return "MockDevice(manufacturer={}, model={})".format(self.manufacturer, self.model)
def get_mock_devices():
    """Return a MockDevice for every fixture file in the ``devices`` directory.

    The directory is looked up next to this module. Entries whose name
    contains ``.py`` or ``pycache`` are skipped; every other entry name is
    expected to have the form ``"<manufacturer> <model>"``.
    """
    devices_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "devices")
    devices = []
    for path in os.listdir(devices_dir):
        if ".py" in path or "pycache" in path:
            continue
        parts = path.split(" ")  # split once rather than once per field
        devices.append(MockDevice(parts[0], parts[1]))
    return devices
def get_device_test_case(manufacturer: str, model: str) -> MockDevice:
    """Return the MockDevice whose fixture file is named ``"<manufacturer> <model>"``.

    Entries of ``tests/devices`` whose name contains ``.py`` or ``pycache``
    are ignored.

    Raises:
        IndexError: if no matching fixture file is present.
    """
    devices_dir = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "tests", "devices"
    )
    wanted = [manufacturer, model]
    matches = [
        # The name split equals `wanted`, so its two parts are exactly these arguments.
        MockDevice(manufacturer, model)
        for path in os.listdir(devices_dir)
        if ".py" not in path and "pycache" not in path and path.split(" ") == wanted
    ]
    return matches[0]
class MockMulticastTransport:
    """In-memory datagram transport that delivers to peers of a shared network.

    ``network`` must expose ``peers`` (a dict keyed by ``(interface, port)``
    whose values are transports like this one) and ``group`` (a dict mapping
    a group address to the list of interfaces that joined it).
    """

    def __init__(self, address, port, max_packet_size, network, protocol):
        self.address = address
        self.port = port
        self.max_packet_size = max_packet_size
        self._network = network
        self._protocol = protocol

    def write(self, data, address):
        """Deliver ``data`` to every other peer addressed by ``address``.

        If ``address[0]`` is a known group, the datagram goes to the group's
        member interfaces; otherwise it is treated as a single host address.
        Peers sharing this transport's own address never receive the datagram.
        """
        target = address[0]
        if target in self._network.group:
            destinations = self._network.group[target]
        else:
            # Wrap in a list so matching is by equality; a bare string would
            # match by substring ("10.0.0.1" in "10.0.0.10").
            destinations = [target]
        sender = (self.address, self.port)
        # Iterate over a snapshot: a receiver may register or remove peers
        # while handling the datagram.
        for (peer_host, _peer_port), peer in list(self._network.peers.items()):
            if peer_host in destinations and peer.address != self.address:
                peer._protocol.datagramReceived(data, sender)

    def setTTL(self, ttl):
        """Accept and ignore the TTL; it has no meaning for in-memory delivery."""
        pass

    def joinGroup(self, address, interface=None):
        """Record ``interface`` as a member of the group ``address``."""
        group = self._network.group.get(address, [])
        if interface not in group:
            group.append(interface)
        self._network.group[address] = group

    def leaveGroup(self, address, interface=None):
        """Remove ``interface`` from the group ``address`` if it is a member."""
        group = self._network.group.get(address, [])
        if interface in group:
            group.remove(interface)
        self._network.group[address] = group
class MockTCPTransport(_FakePort):
    """Fake TCP transport that answers written requests from canned data.

    ``mock_requests['POST']`` maps a url to a list of request/response dicts
    and ``mock_requests['GET']`` maps a url to a single such dict; each dict
    holds the expected raw request text under ``'request'``.
    """

    def __init__(self, address, port, callback, mock_requests):
        """
        Args:
            address: host passed on to ``_FakePort`` together with ``port``.
            port: port number.
            callback: called with the reply bytes when a request matches.
            mock_requests: canned requests, laid out as described on the class.
        """
        super().__init__((address, port))
        self._callback = callback
        self._mock_requests = mock_requests

    def write(self, data):
        """Reply through the callback when ``data`` equals a recorded request.

        Only data starting with ``POST`` or ``GET`` is considered; anything
        that matches no recorded request is ignored.
        """
        if data.startswith(b"POST"):
            candidates = [
                entry
                for packets in self._mock_requests['POST'].values()
                for entry in packets
            ]
        elif data.startswith(b"GET"):
            candidates = list(self._mock_requests['GET'].values())
        else:
            return
        text = data.decode()  # decode once; recorded requests are compared as text
        for entry in candidates:
            if text == entry['request']:
                # NOTE(review): assumes the canned reply is stored as text under
                # 'response' next to 'request' -- confirm against the fixtures.
                self._callback(entry['response'].encode())
                return
class MockMulticastPort(_FakePort):
    """Fake listening port that binds a datagram protocol to a mock transport."""

    def __init__(self, protocol, remover, address, transport):
        """
        Args:
            protocol: the datagram protocol served by this port.
            remover: zero-argument callable that unregisters the transport.
            address: host address; the port number is fixed at 1900.
            transport: transport handed to the protocol when listening starts.
        """
        super().__init__((address, 1900))
        self.protocol = protocol
        self._remover = remover
        self.transport = transport

    def startListening(self, reason=None):
        """Attach the transport to the protocol and start the protocol."""
        self.protocol.transport = self.transport
        return self.protocol.startProtocol()

    def stopListening(self, reason=None):
        """Stop the protocol and unregister the transport from the network."""
        result = self.protocol.stopProtocol()
        # Without this the stopped port would remain a delivery target.
        self._remover()
        return result
class MockNetwork:
    """Registry of mock transports (``peers``) and group membership (``group``)."""

    def __init__(self):
        self.group = {}
        self.peers = {}

    def add_peer(self, port, protocol, interface, maxPacketSize):
        """Create a transport and register it under ``(interface, port)``.

        Returns:
            A ``(transport, remove_peer)`` tuple, where ``remove_peer`` is a
            zero-argument callable that unregisters the transport again.
        """
        key = (interface, port)
        transport = MockMulticastTransport(interface, port, maxPacketSize, self, protocol)
        self.peers[key] = transport

        def remove_peer():
            # Only a present (and truthy) entry is deleted; otherwise do nothing.
            if not self.peers.get(key):
                return
            del self.peers[key]

        return transport, remove_peer
class MockReactor(task.Clock):
    """``task.Clock`` extended with in-memory multicast and canned TCP replies."""

    def __init__(self, client_addr, mock_scpd_requests):
        """
        Args:
            client_addr: interface address used when listenMulticast gets none.
            mock_scpd_requests: canned requests passed to each MockTCPTransport.
        """
        # task.Clock.__init__ sets up the pending-call bookkeeping used by
        # callLater()/advance(), so it has to run.
        super().__init__()
        self.client_addr = client_addr
        self._mock_scpd_requests = mock_scpd_requests
        self.network = MockNetwork()

    def listenMulticast(self, port, protocol, interface=None, maxPacketSize=8192, listenMultiple=True):
        """Register ``protocol`` on the mock network and return its started port."""
        interface = interface or self.client_addr
        transport, remover = self.network.add_peer(port, protocol, interface, maxPacketSize)
        listening_port = MockMulticastPort(protocol, remover, interface, transport)
        # A reactor hands back a port that is already listening; starting it
        # here is what attaches the transport to the protocol.
        listening_port.startListening()
        return listening_port

    def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
        """Build a protocol from ``factory`` and wire it to a MockTCPTransport."""
        protocol = factory.buildProtocol(host)

        def _write_and_close(data):
            # Deliver the canned reply, then report a clean connection close.
            protocol.dataReceived(data)
            protocol.connectionLost(Failure(ConnectionDone()))

        protocol.transport = MockTCPTransport(host, port, _write_and_close, self._mock_scpd_requests)
        # NOTE(review): assumes the protocol sends its request from
        # connectionMade(), as on a real connection -- confirm with the client code.
        protocol.connectionMade()
class MockSSDPServiceGatewayProtocol(DatagramProtocol):
    """Datagram protocol that answers SSDP datagrams from recorded replies."""

    def __init__(self, client_addr: int, iface: str, packets_rx: list, packets_tx: list):
        """Decode the recorded packets.

        ``packets_tx`` holds raw packet strings; ``packets_rx`` holds
        ``((addr, port), raw_packet)`` pairs.
        """
        self.client_addr = client_addr
        self.iface = iface
        # sent by client
        self.packets_tx = [SSDPDatagram.decode(raw.encode()) for raw in packets_tx]
        # rx by client
        self.packets_rx = [
            ((addr, port), SSDPDatagram.decode(raw.encode()))
            for (addr, port), raw in packets_rx
        ]

    def datagramReceived(self, datagram, address):
        """Send the first recorded reply whose ``st`` equals the incoming one."""
        packet = SSDPDatagram.decode(datagram)
        wanted_st = packet.st
        # Recorded replies for this service type; no match means no answer.
        replies = [recorded for _, recorded in self.packets_rx if recorded.st == wanted_st]
        if not replies:
            return
        self.transport.write(replies[0].encode().encode(), (self.client_addr, 1900))