aioupnp/txupnp/mocks.py

178 lines
6.5 KiB
Python
Raw Normal View History

2018-10-01 09:51:31 -04:00
import os
import json
2018-09-25 14:54:39 -04:00
import logging
2018-10-01 09:51:31 -04:00
from twisted.internet import task, defer
from twisted.internet.error import ConnectionDone
2018-09-25 14:54:39 -04:00
from twisted.internet.protocol import DatagramProtocol
2018-10-01 09:51:31 -04:00
from twisted.python.failure import Failure
from twisted.test.proto_helpers import _FakePort
from txupnp.ssdp_datagram import SSDPDatagram
2018-09-25 14:54:39 -04:00
# Module-level logger named after this module, so its level and handlers can
# be configured independently instead of writing straight to the root logger.
log = logging.getLogger(__name__)
2018-10-01 09:51:31 -04:00
class MockResponse:
    """Stand-in for a response object carrying a fixed body.

    ``headers`` starts out as an empty dict and ``content()`` yields the stored
    body through an already-fired Deferred (presumably mimicking a treq-style
    response API -- confirm against callers).
    """

    def __init__(self, content):
        self.headers = {}
        self._content = content

    def content(self):
        """Return a Deferred that has already fired with the stored body."""
        body = self._content
        return defer.succeed(body)
class MockDevice:
    """A device fixture loaded from a JSON file.

    The fixture lives in ``tests/devices`` under the directory that contains
    this package, in a file named ``"<manufacturer> <model>"``. Its parsed JSON
    content is stored in ``device_dict``.
    """

    def __init__(self, manufacturer, model):
        self.manufacturer = manufacturer
        self.model = model
        repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        device_path = os.path.join(
            repo_root, "tests", "devices", "{} {}".format(manufacturer, model)
        )
        # Explicit check instead of ``assert``: asserts are stripped under -O,
        # and this gives a message naming the missing fixture.
        if not os.path.isfile(device_path):
            raise FileNotFoundError("no mock device fixture at {}".format(device_path))
        with open(device_path, "r", encoding="utf-8") as f:
            self.device_dict = json.load(f)

    def __repr__(self):
        return "MockDevice(manufacturer={}, model={})".format(self.manufacturer, self.model)
def get_mock_devices():
    """Return a MockDevice for every fixture file in ``tests/devices``.

    The directory is ``tests/devices`` under the directory containing this
    package, which is where MockDevice reads fixtures from. Entries whose names
    contain ``.py`` or ``pycache`` are skipped, as are names with no space
    separating manufacturer and model. Results are ordered by file name.
    """
    devices_dir = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "tests", "devices"
    )
    devices = []
    for name in sorted(os.listdir(devices_dir)):
        if ".py" in name or "pycache" in name:
            continue
        # Split on the first space only, so "<manufacturer> <model>"
        # round-trips back to the same file name inside MockDevice.
        manufacturer, sep, model = name.partition(" ")
        if not sep:
            continue
        devices.append(MockDevice(manufacturer, model))
    return devices
def get_device_test_case(manufacturer: str, model: str) -> MockDevice:
    """Build the MockDevice whose fixture file is ``"<manufacturer> <model>"``.

    Scans ``tests/devices`` (under the directory containing this package),
    ignoring entries whose names contain ``.py`` or ``pycache``. Raises
    IndexError when no fixture name matches.
    """
    devices_dir = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "tests", "devices"
    )
    wanted = [manufacturer, model]
    matches = []
    for name in os.listdir(devices_dir):
        if ".py" in name or "pycache" in name:
            continue
        if name.split(" ") != wanted:
            continue
        parts = name.split(" ")
        matches.append(MockDevice(parts[0], parts[1]))
    return matches[0]
2018-09-25 14:54:39 -04:00
class MockMulticastTransport:
    """In-memory UDP transport attached to a mock network object.

    Instead of touching a socket, ``write`` calls ``datagramReceived`` on the
    protocols of matching peers found in ``network.peers``. Group membership is
    tracked in ``network.group`` (group address -> list of joined interfaces).
    """

    def __init__(self, address, port, max_packet_size, network, protocol):
        self.address = address
        self.port = port
        self.max_packet_size = max_packet_size
        self._network = network
        self._protocol = protocol

    def write(self, data, address):
        """Deliver ``data`` to every peer reachable at ``address``.

        If ``address[0]`` is a joined group, every interface in that group is a
        destination. Otherwise only peers whose interface equals ``address[0]``
        are. Peers on the sender's own interface address are skipped, and the
        destination port (``address[1]``) is not checked.
        """
        host = address[0]
        if host in self._network.group:
            destinations = self._network.group[host]
        else:
            # Wrap in a list: with a bare string the membership test below
            # would be a substring match ("10.0.0.1" in "10.0.0.10").
            destinations = [host]
        # Snapshot the peers so a receiver that registers/unregisters a peer
        # while handling the datagram cannot break this iteration.
        for peer_key, peer in list(self._network.peers.items()):
            if peer_key[0] in destinations and peer.address != self.address:
                peer._protocol.datagramReceived(data, (self.address, self.port))

    def setTTL(self, ttl):
        # TTL has no meaning on the in-memory network.
        pass

    def joinGroup(self, address, interface=None):
        """Record ``interface`` as a member of multicast group ``address``."""
        self._network.group.setdefault(address, []).append(interface)

    def leaveGroup(self, address, interface=None):
        """Remove one membership of ``interface`` from group ``address``.

        The group entry is created empty if it did not exist yet.
        """
        members = self._network.group.setdefault(address, [])
        if interface in members:
            members.remove(interface)
class MockTCPTransport(_FakePort):
    """Fake TCP transport that answers HTTP requests from recorded fixtures.

    As indexed by ``write``, ``mock_requests['POST']`` maps each URL to a list
    of ``{'request': str, 'response': str}`` dicts, and
    ``mock_requests['GET']`` maps each URL to a single such dict. When written
    data equals a recorded request exactly, ``callback`` is invoked with the
    UTF-8-encoded recorded response.
    """

    def __init__(self, address, port, callback, mock_requests):
        super().__init__((address, port))
        self._callback = callback
        self._mock_requests = mock_requests

    def write(self, data):
        """Reply to ``data`` with the first matching recorded response."""
        if data.startswith(b"POST"):
            request = data.decode()
            for packets in self._mock_requests['POST'].values():
                for request_response in packets:
                    if request == request_response['request']:
                        self._callback(request_response['response'].encode())
                        return
        elif data.startswith(b"GET"):
            request = data.decode()
            for packets in self._mock_requests['GET'].values():
                if request == packets['request']:
                    self._callback(packets['response'].encode())
                    return
        # No fixture matched, so no reply will ever be delivered; log the
        # request so a missing or stale fixture is easy to spot.
        log.warning("no mock response recorded for request: %r", data)
class MockMulticastPort(_FakePort):
    """Fake listening port that wires a protocol to a mock transport.

    The port's host address is ``(address, 1900)``. Starting it attaches
    ``transport`` to the protocol and starts the protocol. Stopping it stops
    the protocol and then calls ``remover`` (presumably to unregister from the
    mock network -- confirm against the caller).
    """

    def __init__(self, protocol, remover, address, transport):
        super().__init__((address, 1900))
        self.transport = transport
        self._remover = remover
        self.protocol = protocol

    def startListening(self, reason=None):
        proto = self.protocol
        proto.transport = self.transport
        return proto.startProtocol()

    def stopListening(self, reason=None):
        outcome = self.protocol.stopProtocol()
        self._remover()
        return outcome
class MockNetwork:
    """In-memory registry shared by mock multicast transports.

    ``peers`` maps ``(interface, port)`` to the MockMulticastTransport bound
    there. ``group`` maps a group address to the list of interfaces that
    joined it.
    """

    def __init__(self):
        self.peers = {}
        self.group = {}

    def add_peer(self, port, protocol, interface, maxPacketSize):
        """Register a new transport for ``protocol`` at ``(interface, port)``.

        Returns ``(transport, remove_peer)``. Calling ``remove_peer()``
        unregisters that transport.
        """
        key = (interface, port)
        transport = MockMulticastTransport(interface, port, maxPacketSize, self, protocol)
        self.peers[key] = transport

        def remove_peer():
            # Only drop the entry if it still holds *this* transport. A later
            # add_peer on the same (interface, port) replaces the entry, and a
            # stale remover must not unregister that newer peer.
            if self.peers.get(key) is transport:
                del self.peers[key]

        return transport, remove_peer
2018-09-25 14:54:39 -04:00
class MockReactor(task.Clock):
    """Fake reactor built on ``task.Clock``.

    Multicast listeners are attached to an in-memory MockNetwork. TCP
    connections get a fake transport that answers from the recorded
    ``mock_scpd_requests`` fixtures.
    """

    def __init__(self, client_addr, mock_scpd_requests):
        super().__init__()
        self.client_addr = client_addr
        self._mock_scpd_requests = mock_scpd_requests
        self.network = MockNetwork()

    def listenMulticast(self, port, protocol, interface=None, maxPacketSize=8192, listenMultiple=True):
        """Register ``protocol`` on the mock network and start listening.

        ``interface`` falls back to ``client_addr`` when falsy.
        ``listenMultiple`` is accepted for signature compatibility and unused.
        """
        bind_iface = interface if interface else self.client_addr
        transport, remover = self.network.add_peer(port, protocol, bind_iface, maxPacketSize)
        listening_port = MockMulticastPort(protocol, remover, bind_iface, transport)
        listening_port.startListening()
        return listening_port

    def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
        """Build a protocol from ``factory`` and connect it to a fake transport.

        When the transport recognises a written request, the recorded response
        is fed to the protocol, followed by a ConnectionDone connection loss.
        ``timeout`` and ``bindAddress`` are unused, and nothing is returned.
        """
        proto = factory.buildProtocol(host)

        def deliver_then_close(response):
            proto.dataReceived(response)
            proto.connectionLost(Failure(ConnectionDone()))

        proto.transport = MockTCPTransport(host, port, deliver_then_close, self._mock_scpd_requests)
        proto.connectionMade()
2018-09-25 14:54:39 -04:00
class MockSSDPServiceGatewayProtocol(DatagramProtocol):
    """Mock gateway-side SSDP protocol that replays recorded replies.

    ``packets_tx`` holds the datagrams the client sends and ``packets_rx`` the
    ``((addr, port), datagram)`` pairs the client receives. Both are given as
    strings and decoded into SSDPDatagram objects.
    """

    def __init__(self, client_addr: str, iface: str, packets_rx: list, packets_tx: list):
        # client_addr is used as the destination host for replies.
        self.client_addr = client_addr
        self.iface = iface
        # sent by client
        self.packets_tx = [SSDPDatagram.decode(packet.encode()) for packet in packets_tx]
        # received by client
        self.packets_rx = [
            ((addr, port), SSDPDatagram.decode(packet.encode()))
            for (addr, port), packet in packets_rx
        ]

    def datagramReceived(self, datagram, address):
        """Reply to a datagram whose ``st`` matches a recorded reply.

        The first recorded reply with an equal ``st`` is re-encoded and written
        to ``(client_addr, 1900)``. Datagrams with no matching reply are
        ignored.
        """
        search_target = SSDPDatagram.decode(datagram).st
        # Single pass: pick the first recorded reply for this service type.
        reply = next((rx for _, rx in self.packets_rx if rx.st == search_target), None)
        if reply is not None:
            self.transport.write(reply.encode().encode(), (self.client_addr, 1900))