Merge branch 'testing'
This commit is contained in:
commit
d9fee45fc7
18 changed files with 1017 additions and 948 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -3,3 +3,4 @@
|
|||
_trial_temp/
|
||||
build/
|
||||
dist/
|
||||
.coverage
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
# UPnP for Twisted
|
||||
|
||||
`txupnp` is a python2/3 library to interact with UPnP gateways using Twisted
|
||||
`txupnp` is a python 3 library to interact with UPnP gateways using `twisted`
|
||||
|
||||
## Installation
|
||||
|
||||
|
|
2
setup.py
2
setup.py
|
@ -21,7 +21,7 @@ setup(
|
|||
long_description=long_description,
|
||||
url="https://github.com/lbryio/txupnp",
|
||||
license=__license__,
|
||||
packages=find_packages(),
|
||||
packages=find_packages(exclude=['tests']),
|
||||
entry_points={'console_scripts': console_scripts},
|
||||
install_requires=[
|
||||
'twisted[tls]',
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
import logging
|
||||
|
||||
log = logging.getLogger("txupnp")
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(logging.Formatter('%(asctime)-15s-%(filename)s:%(lineno)s->%(message)s'))
|
||||
log.addHandler(handler)
|
||||
log.setLevel(logging.INFO)
|
208
tests/devices/Cisco CGA4131COM
Normal file
208
tests/devices/Cisco CGA4131COM
Normal file
File diff suppressed because one or more lines are too long
105
tests/test_devices.py
Normal file
105
tests/test_devices.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
from twisted.internet import reactor, defer
|
||||
from twisted.trial import unittest
|
||||
from txupnp.constants import SSDP_PORT, SSDP_IP_ADDRESS
|
||||
from txupnp.upnp import UPnP
|
||||
from txupnp.scpd import SCPDCommand
|
||||
from txupnp.gateway import Service
|
||||
from txupnp.fault import UPnPError
|
||||
from txupnp.mocks import MockReactor, MockSSDPServiceGatewayProtocol, get_device_test_case
|
||||
from txupnp.util import verify_return_types
|
||||
|
||||
|
||||
class TestDevice(unittest.TestCase):
|
||||
manufacturer, model = "Cisco", "CGA4131COM"
|
||||
|
||||
device = get_device_test_case(manufacturer, model)
|
||||
router_address = device.device_dict['router_address']
|
||||
client_address = device.device_dict['client_address']
|
||||
expected_devices = device.device_dict['expected_devices']
|
||||
packets_rx = device.device_dict['ssdp']['received']
|
||||
packets_tx = device.device_dict['ssdp']['sent']
|
||||
expected_available_commands = device.device_dict['commands']['available']
|
||||
scdp_packets = device.device_dict['scpd']
|
||||
|
||||
def setUp(self):
|
||||
fake_reactor = MockReactor(self.client_address, self.scdp_packets)
|
||||
reactor.listenMulticast = fake_reactor.listenMulticast
|
||||
self.reactor = reactor
|
||||
server_protocol = MockSSDPServiceGatewayProtocol(
|
||||
self.client_address, self.router_address, self.packets_rx, self.packets_tx
|
||||
)
|
||||
self.server_port = self.reactor.listenMulticast(SSDP_PORT, server_protocol, interface=self.router_address)
|
||||
self.server_port.transport.joinGroup(SSDP_IP_ADDRESS, interface=self.router_address)
|
||||
|
||||
self.upnp = UPnP(
|
||||
self.reactor, debug_ssdp=True, router_ip=self.router_address,
|
||||
lan_ip=self.client_address, iface_name='mock'
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
self.upnp.sspd_factory.disconnect()
|
||||
self.server_port.stopListening()
|
||||
|
||||
|
||||
class TestSSDP(TestDevice):
|
||||
@defer.inlineCallbacks
|
||||
def test_discover_device(self):
|
||||
result = yield self.upnp.m_search(self.router_address, timeout=1)
|
||||
self.assertEqual(len(self.expected_devices), len(result))
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertDictEqual(self.expected_devices[0], result[0])
|
||||
|
||||
|
||||
class TestSCPD(TestDevice):
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
fake_reactor = MockReactor(self.client_address, self.scdp_packets)
|
||||
reactor.listenMulticast = fake_reactor.listenMulticast
|
||||
reactor.connectTCP = fake_reactor.connectTCP
|
||||
self.reactor = reactor
|
||||
server_protocol = MockSSDPServiceGatewayProtocol(
|
||||
self.client_address, self.router_address, self.packets_rx, self.packets_tx
|
||||
)
|
||||
self.server_port = self.reactor.listenMulticast(SSDP_PORT, server_protocol, interface=self.router_address)
|
||||
self.server_port.transport.joinGroup(SSDP_IP_ADDRESS, interface=self.router_address)
|
||||
|
||||
self.upnp = UPnP(
|
||||
self.reactor, debug_ssdp=True, router_ip=self.router_address,
|
||||
lan_ip=self.client_address, iface_name='mock'
|
||||
)
|
||||
yield self.upnp.discover()
|
||||
|
||||
def test_parse_available_commands(self):
|
||||
self.assertDictEqual(self.expected_available_commands, self.upnp.gateway.debug_commands()['available'])
|
||||
|
||||
def test_parse_gateway(self):
|
||||
self.assertDictEqual(self.device.device_dict['gateway_dict'], self.upnp.gateway.as_dict())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_commands(self):
|
||||
method, args, expected = self.device.device_dict['soap'][0]
|
||||
command1 = getattr(self.upnp, method)
|
||||
result = yield command1(*tuple(args))
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
method, args, expected = self.device.device_dict['soap'][1]
|
||||
command2 = getattr(self.upnp, method)
|
||||
result = yield command2(*tuple(args))
|
||||
result = [[i for i in r] for r in result]
|
||||
self.assertListEqual(result, expected)
|
||||
|
||||
method, args, expected = self.device.device_dict['soap'][2]
|
||||
command3 = getattr(self.upnp, method)
|
||||
result = yield command3(*tuple(args))
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
method, args, expected = self.device.device_dict['soap'][3]
|
||||
command4 = getattr(self.upnp, method)
|
||||
result = yield command4(*tuple(args))
|
||||
result = [r for r in result]
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
method, args, expected = self.device.device_dict['soap'][4]
|
||||
command5 = getattr(self.upnp, method)
|
||||
result = yield command5(*tuple(args))
|
||||
self.assertEqual(result, expected)
|
|
@ -1,45 +0,0 @@
|
|||
from twisted.internet import reactor, defer
|
||||
from twisted.trial import unittest
|
||||
from txupnp.constants import SSDP_PORT
|
||||
from txupnp import mocks
|
||||
from txupnp.ssdp import SSDPFactory
|
||||
|
||||
|
||||
class TestDiscoverGateway(unittest.TestCase):
|
||||
router_address = '10.0.0.1'
|
||||
client_address = '10.0.0.10'
|
||||
service_name = 'WANCommonInterfaceConfig:1'
|
||||
st = 'urn:schemas-upnp-org:service:%s' % service_name
|
||||
port = 49152
|
||||
location = 'InternetGatewayDevice.xml'
|
||||
usn = 'uuid:00000000-0000-0000-0000-000000000000::%s' % st
|
||||
version = 'Linux, UPnP/1.0, DIR-890L Ver 1.20'
|
||||
|
||||
expected_devices = [
|
||||
{
|
||||
'cache_control': 'max-age=1800',
|
||||
'location': 'http://%s:%i/%s' % (router_address, port, location),
|
||||
'server': version,
|
||||
'st': st,
|
||||
'usn': usn
|
||||
}
|
||||
]
|
||||
|
||||
def setUp(self):
|
||||
fake_reactor = mocks.MockReactor()
|
||||
reactor.listenMulticast = fake_reactor.listenMulticast
|
||||
self.reactor = reactor
|
||||
server_protocol = mocks.MockSSDPServiceGatewayProtocol(
|
||||
self.router_address, self.service_name, self.st, self.port, self.location, self.usn, self.version
|
||||
)
|
||||
self.server_port = self.reactor.listenMulticast(SSDP_PORT, server_protocol)
|
||||
|
||||
def tearDown(self):
|
||||
self.server_port.stopListening()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_discover(self):
|
||||
client_factory = SSDPFactory(self.reactor, self.client_address, self.router_address)
|
||||
result = yield client_factory.m_search(self.router_address)
|
||||
self.assertListEqual(self.expected_devices, result)
|
||||
client_factory.disconnect()
|
|
@ -1,3 +1,5 @@
|
|||
import os
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
from twisted.internet import reactor, defer
|
||||
|
@ -6,11 +8,6 @@ from txupnp.upnp import UPnP
|
|||
log = logging.getLogger("txupnp")
|
||||
|
||||
|
||||
def debug_device(u, include_gateway_xml=False, *_):
|
||||
print(u.get_debug_info(include_gateway_xml=include_gateway_xml))
|
||||
return defer.succeed(None)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_external_ip(u, *_):
|
||||
ip = yield u.get_external_ip()
|
||||
|
@ -30,7 +27,7 @@ def list_mappings(u, *_):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def add_mapping(u, *_):
|
||||
port = 4567
|
||||
port = 51413
|
||||
protocol = "UDP"
|
||||
description = "txupnp test mapping"
|
||||
ext_port = yield u.get_next_mapping(port, protocol, description)
|
||||
|
@ -50,12 +47,64 @@ def delete_mapping(u, *_):
|
|||
print("removed mapping")
|
||||
|
||||
|
||||
def _encode(x):
|
||||
if isinstance(x, bytes):
|
||||
return x.decode()
|
||||
elif isinstance(x, Exception):
|
||||
return str(x)
|
||||
return x
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def generate_test_data(u, *_):
|
||||
external_ip = yield u.get_external_ip()
|
||||
redirects = yield u.get_redirects()
|
||||
ext_port = yield u.get_next_mapping(4567, "UDP", "txupnp test mapping")
|
||||
delete = yield u.delete_port_mapping(ext_port, "UDP")
|
||||
after_delete = yield u.get_specific_port_mapping(ext_port, "UDP")
|
||||
|
||||
commands_test_case = (
|
||||
("get_external_ip", (), "1.2.3.4"),
|
||||
("get_redirects", (), redirects),
|
||||
("get_next_mapping", (4567, "UDP", "txupnp test mapping"), ext_port),
|
||||
("delete_port_mapping", (ext_port, "UDP"), delete),
|
||||
("get_specific_port_mapping", (ext_port, "UDP"), after_delete),
|
||||
)
|
||||
|
||||
gateway = u.gateway
|
||||
device = list(gateway.devices.values())[0]
|
||||
assert device.manufacturer and device.modelName
|
||||
device_path = os.path.join(os.getcwd(), "%s %s" % (device.manufacturer, device.modelName))
|
||||
commands = gateway.debug_commands()
|
||||
with open(device_path, "w") as f:
|
||||
f.write(json.dumps({
|
||||
"router_address": u.router_ip,
|
||||
"client_address": u.lan_address,
|
||||
"port": gateway.port,
|
||||
"gateway_dict": gateway.as_dict(),
|
||||
'expected_devices': [
|
||||
{
|
||||
'cache_control': 'max-age=1800',
|
||||
'location': gateway.location,
|
||||
'server': gateway.server,
|
||||
'st': gateway.urn,
|
||||
'usn': gateway.usn
|
||||
}
|
||||
],
|
||||
'commands': commands,
|
||||
'ssdp': u.sspd_factory.get_ssdp_packet_replay(),
|
||||
'scpd': gateway.requester.dump_packets(),
|
||||
'soap': commands_test_case
|
||||
}, default=_encode, indent=2).replace(external_ip, "1.2.3.4"))
|
||||
print("Generated test data! -> %s" % device_path)
|
||||
|
||||
|
||||
cli_commands = {
|
||||
"debug_device": debug_device,
|
||||
"get_external_ip": get_external_ip,
|
||||
"list_mappings": list_mappings,
|
||||
"add_mapping": add_mapping,
|
||||
"delete_mapping": delete_mapping
|
||||
"delete_mapping": delete_mapping,
|
||||
"generate_test_data": generate_test_data,
|
||||
}
|
||||
|
||||
|
||||
|
@ -85,9 +134,9 @@ def main():
|
|||
parser.add_argument("--include_igd_xml", dest="include_igd_xml", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
if args.debug_logging:
|
||||
from twisted.python import log as tx_log
|
||||
observer = tx_log.PythonLoggingObserver(loggerName="txupnp")
|
||||
observer.start()
|
||||
# from twisted.python import log as tx_log
|
||||
# observer = tx_log.PythonLoggingObserver(loggerName="txupnp")
|
||||
# observer.start()
|
||||
log.setLevel(logging.DEBUG)
|
||||
command = args.command
|
||||
command = command.replace("-", "_")
|
||||
|
@ -98,7 +147,7 @@ def main():
|
|||
def show(err):
|
||||
print("error: {}".format(err))
|
||||
|
||||
u = UPnP(reactor)
|
||||
u = UPnP(reactor, debug_ssdp=(command == "generate_test_data"))
|
||||
d = u.discover()
|
||||
d.addCallback(run_command, u, command, args.include_igd_xml)
|
||||
d.addErrback(show)
|
||||
|
|
137
txupnp/commands.py
Normal file
137
txupnp/commands.py
Normal file
|
@ -0,0 +1,137 @@
|
|||
from txupnp.util import return_types, none_or_str, none
|
||||
|
||||
|
||||
class SCPDCommands: # TODO use type annotations
|
||||
def debug_commands(self) -> dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(none)
|
||||
def AddPortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int,
|
||||
NewInternalClient: str, NewEnabled: bool, NewPortMappingDescription: str,
|
||||
NewLeaseDuration: str = '') -> None:
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(bool, bool)
|
||||
def GetNATRSIPStatus() -> (bool, bool):
|
||||
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(none_or_str, int, str, int, str, bool, str, int)
|
||||
def GetGenericPortMappingEntry(NewPortMappingIndex) -> (none_or_str, int, str, int, str, bool, str, int):
|
||||
"""
|
||||
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
|
||||
NewPortMappingDescription, NewLeaseDuration)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(int, str, bool, str, int)
|
||||
def GetSpecificPortMappingEntry(NewRemoteHost, NewExternalPort, NewProtocol) -> (int, str, bool, str, int):
|
||||
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(none)
|
||||
def SetConnectionType(NewConnectionType) -> None:
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(str)
|
||||
def GetExternalIPAddress() -> str:
|
||||
"""Returns (NewExternalIPAddress)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(str, str)
|
||||
def GetConnectionTypeInfo() -> (str, str):
|
||||
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(str, str, int)
|
||||
def GetStatusInfo() -> (str, str, int):
|
||||
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(none)
|
||||
def ForceTermination() -> None:
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(none)
|
||||
def DeletePortMapping(NewRemoteHost, NewExternalPort, NewProtocol) -> None:
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(none)
|
||||
def RequestConnection() -> None:
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def GetCommonLinkProperties():
|
||||
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def GetTotalBytesSent():
|
||||
"""Returns (NewTotalBytesSent)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def GetTotalBytesReceived():
|
||||
"""Returns (NewTotalBytesReceived)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def GetTotalPacketsSent():
|
||||
"""Returns (NewTotalPacketsSent)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def GetTotalPacketsReceived():
|
||||
"""Returns (NewTotalPacketsReceived)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def X_GetICSStatistics():
|
||||
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def GetDefaultConnectionService():
|
||||
"""Returns (NewDefaultConnectionService)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def SetDefaultConnectionService(NewDefaultConnectionService) -> None:
|
||||
"""Returns (None)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(none)
|
||||
def SetEnabledForInternet(NewEnabledForInternet) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(bool)
|
||||
def GetEnabledForInternet() -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def GetMaximumActiveConnections(NewActiveConnectionIndex):
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(str, str)
|
||||
def GetActiveConnections() -> (str, str):
|
||||
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
|
||||
raise NotImplementedError()
|
|
@ -20,15 +20,7 @@ IP_SCHEMA = 'urn:schemas-upnp-org:service:WANIPConnection:1'
|
|||
|
||||
service_types = [
|
||||
UPNP_ORG_IGD,
|
||||
WIFI_ALLIANCE_ORG_IGD,
|
||||
WAN_SCHEMA,
|
||||
LAYER_SCHEMA,
|
||||
IP_SCHEMA,
|
||||
|
||||
|
||||
CONTROL,
|
||||
SERVICE,
|
||||
DEVICE,
|
||||
# WIFI_ALLIANCE_ORG_IGD,
|
||||
]
|
||||
|
||||
SSDP_IP_ADDRESS = '239.255.255.250'
|
||||
|
|
|
@ -1,94 +0,0 @@
|
|||
import logging
|
||||
from twisted.web.client import HTTPConnectionPool, _HTTP11ClientFactory
|
||||
from twisted.web._newclient import HTTPClientParser, BadResponseVersion, HTTP11ClientProtocol, RequestNotSent
|
||||
from twisted.web._newclient import TransportProxyProducer, RequestGenerationFailed
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.internet.defer import Deferred, fail, maybeDeferred
|
||||
from twisted.internet.defer import CancelledError
|
||||
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
|
||||
class DirtyHTTPParser(HTTPClientParser):
|
||||
def parseVersion(self, strversion):
|
||||
"""
|
||||
Parse version strings of the form Protocol '/' Major '.' Minor. E.g.
|
||||
b'HTTP/1.1'. Returns (protocol, major, minor). Will raise ValueError
|
||||
on bad syntax.
|
||||
"""
|
||||
try:
|
||||
proto, strnumber = strversion.split(b'/')
|
||||
major, minor = strnumber.split(b'.')
|
||||
major, minor = int(major), int(minor)
|
||||
except ValueError as e:
|
||||
log.exception("got a bad http version: %s", strversion)
|
||||
if b'HTTP1.1' in strversion:
|
||||
return ("HTTP", 1, 1)
|
||||
raise BadResponseVersion(str(e), strversion)
|
||||
if major < 0 or minor < 0:
|
||||
raise BadResponseVersion(u"version may not be negative",
|
||||
strversion)
|
||||
return (proto, major, minor)
|
||||
|
||||
|
||||
class DirtyHTTPClientProtocol(HTTP11ClientProtocol):
|
||||
def request(self, request):
|
||||
if self._state != 'QUIESCENT':
|
||||
return fail(RequestNotSent())
|
||||
|
||||
self._state = 'TRANSMITTING'
|
||||
_requestDeferred = maybeDeferred(request.writeTo, self.transport)
|
||||
|
||||
def cancelRequest(ign):
|
||||
# Explicitly cancel the request's deferred if it's still trying to
|
||||
# write when this request is cancelled.
|
||||
if self._state in (
|
||||
'TRANSMITTING', 'TRANSMITTING_AFTER_RECEIVING_RESPONSE'):
|
||||
_requestDeferred.cancel()
|
||||
else:
|
||||
self.transport.abortConnection()
|
||||
self._disconnectParser(Failure(CancelledError()))
|
||||
|
||||
self._finishedRequest = Deferred(cancelRequest)
|
||||
|
||||
# Keep track of the Request object in case we need to call stopWriting
|
||||
# on it.
|
||||
self._currentRequest = request
|
||||
|
||||
self._transportProxy = TransportProxyProducer(self.transport)
|
||||
self._parser = DirtyHTTPParser(request, self._finishResponse)
|
||||
self._parser.makeConnection(self._transportProxy)
|
||||
self._responseDeferred = self._parser._responseDeferred
|
||||
|
||||
def cbRequestWritten(ignored):
|
||||
if self._state == 'TRANSMITTING':
|
||||
self._state = 'WAITING'
|
||||
self._responseDeferred.chainDeferred(self._finishedRequest)
|
||||
|
||||
def ebRequestWriting(err):
|
||||
if self._state == 'TRANSMITTING':
|
||||
self._state = 'GENERATION_FAILED'
|
||||
self.transport.abortConnection()
|
||||
self._finishedRequest.errback(
|
||||
Failure(RequestGenerationFailed([err])))
|
||||
else:
|
||||
self._log.failure(
|
||||
u'Error writing request, but not in valid state '
|
||||
u'to finalize request: {state}',
|
||||
failure=err,
|
||||
state=self._state
|
||||
)
|
||||
|
||||
_requestDeferred.addCallbacks(cbRequestWritten, ebRequestWriting)
|
||||
|
||||
return self._finishedRequest
|
||||
|
||||
|
||||
class DirtyHTTP11ClientFactory(_HTTP11ClientFactory):
|
||||
def buildProtocol(self, addr):
|
||||
return DirtyHTTPClientProtocol(self._quiescentCallback)
|
||||
|
||||
|
||||
class DirtyPool(HTTPConnectionPool):
|
||||
_factory = DirtyHTTP11ClientFactory
|
|
@ -1,23 +1,12 @@
|
|||
import logging
|
||||
from twisted.internet import defer
|
||||
import treq
|
||||
import re
|
||||
from xml.etree import ElementTree
|
||||
from txupnp.util import etree_to_dict, flatten_keys, get_dict_val_case_insensitive
|
||||
from txupnp.util import BASE_PORT_REGEX, BASE_ADDRESS_REGEX
|
||||
from txupnp.constants import DEVICE, ROOT
|
||||
from txupnp.scpd import SCPDCommand, SCPDRequester
|
||||
from txupnp.util import get_dict_val_case_insensitive, verify_return_types, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
|
||||
from txupnp.constants import SPEC_VERSION
|
||||
from txupnp.commands import SCPDCommands
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
service_type_pattern = re.compile(
|
||||
"(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\})"
|
||||
)
|
||||
|
||||
xml_root_sanity_pattern = re.compile(
|
||||
"(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\}([\w|\d|\:|\-|\_]*))"
|
||||
)
|
||||
|
||||
|
||||
class CaseInsensitive:
|
||||
def __init__(self, **kwargs):
|
||||
|
@ -112,7 +101,7 @@ class Device(CaseInsensitive):
|
|||
|
||||
|
||||
class Gateway:
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, reactor, **kwargs):
|
||||
flattened = {
|
||||
k.lower(): v for k, v in kwargs.items()
|
||||
}
|
||||
|
@ -133,9 +122,10 @@ class Gateway:
|
|||
self.date = date.encode()
|
||||
self.urn = st.encode()
|
||||
|
||||
self._xml_response = ""
|
||||
self._service_descriptors = {}
|
||||
self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0]
|
||||
self.port = int(BASE_PORT_REGEX.findall(self.location)[0])
|
||||
self.xml_response = None
|
||||
self.spec_version = None
|
||||
self.url_base = None
|
||||
|
||||
|
@ -143,55 +133,65 @@ class Gateway:
|
|||
self._devices = []
|
||||
self._services = []
|
||||
|
||||
def debug_device(self, include_xml: bool = False, include_services: bool = True) -> dict:
|
||||
r = {
|
||||
'server': self.server,
|
||||
'urlBase': self.url_base,
|
||||
'location': self.location,
|
||||
"specVersion": self.spec_version,
|
||||
'usn': self.usn,
|
||||
'urn': self.urn,
|
||||
'devices': [device.as_dict() for device in self._devices]
|
||||
}
|
||||
if include_xml:
|
||||
r['xml_response'] = self.xml_response
|
||||
if include_services:
|
||||
r['services'] = [service.as_dict() for service in self._services]
|
||||
self._reactor = reactor
|
||||
self._unsupported_actions = {}
|
||||
self._registered_commands = {}
|
||||
self.commands = SCPDCommands()
|
||||
self.requester = SCPDRequester(self._reactor)
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
r = {
|
||||
'server': self.server.decode(),
|
||||
'urlBase': self.url_base,
|
||||
'location': self.location.decode(),
|
||||
"specVersion": self.spec_version,
|
||||
'usn': self.usn.decode(),
|
||||
'urn': self.urn.decode(),
|
||||
}
|
||||
return r
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def discover_services(self):
|
||||
log.debug("querying %s", self.location)
|
||||
response = yield treq.get(self.location)
|
||||
self.xml_response = yield response.content()
|
||||
if not self.xml_response:
|
||||
log.warning("service sent an empty reply\n%s", self.debug_device())
|
||||
xml_dict = etree_to_dict(ElementTree.fromstring(self.xml_response))
|
||||
schema_key = DEVICE
|
||||
root = ROOT
|
||||
if len(xml_dict) > 1:
|
||||
log.warning(xml_dict.keys())
|
||||
for k in xml_dict.keys():
|
||||
m = xml_root_sanity_pattern.findall(k)
|
||||
if len(m) == 3 and m[1][0] and m[2][5]:
|
||||
schema_key = m[1][0]
|
||||
root = m[2][5]
|
||||
break
|
||||
|
||||
flattened_xml = flatten_keys(xml_dict, "{%s}" % schema_key)[root]
|
||||
self.spec_version = get_dict_val_case_insensitive(flattened_xml, SPEC_VERSION)
|
||||
self.url_base = get_dict_val_case_insensitive(flattened_xml, "urlbase")
|
||||
|
||||
if flattened_xml:
|
||||
def discover_commands(self):
|
||||
response = yield self.requester.scpd_get(self.location.decode().split(self.base_address.decode())[1], self.base_address.decode(), self.port)
|
||||
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
|
||||
self.url_base = get_dict_val_case_insensitive(response, "urlbase")
|
||||
if not self.url_base:
|
||||
self.url_base = self.base_address.decode()
|
||||
if response:
|
||||
self._device = Device(
|
||||
self._devices, self._services, **get_dict_val_case_insensitive(flattened_xml, "device")
|
||||
self._devices, self._services, **get_dict_val_case_insensitive(response, "device")
|
||||
)
|
||||
log.debug("finished setting up root gateway. %i devices and %i services", len(self.devices),
|
||||
len(self.services))
|
||||
else:
|
||||
self._device = Device(self._devices, self._services)
|
||||
log.debug("finished setting up gateway:\n%s", self.debug_device())
|
||||
for service_type in self.services.keys():
|
||||
service = self.services[service_type]
|
||||
yield self.register_commands(service)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register_commands(self, service: Service):
|
||||
try:
|
||||
action_list = yield self.requester.scpd_get_supported_actions(service, self.base_address.decode(), self.port)
|
||||
except Exception as err:
|
||||
log.exception("failed to register service %s: %s", service.serviceType, str(err))
|
||||
return
|
||||
for name, inputs, outputs in action_list:
|
||||
try:
|
||||
command = SCPDCommand(self.requester, self.base_address, self.port,
|
||||
service.controlURL.encode(),
|
||||
service.serviceType.encode(), name, inputs, outputs)
|
||||
current = getattr(self.commands, command.method)
|
||||
if hasattr(current, "_return_types"):
|
||||
command._process_result = verify_return_types(*current._return_types)(command._process_result)
|
||||
setattr(command, "__doc__", current.__doc__)
|
||||
setattr(self.commands, command.method, command)
|
||||
self._registered_commands[command.method] = service.serviceType
|
||||
log.debug("registered %s::%s", service.serviceType, command.method)
|
||||
except AttributeError:
|
||||
s = self._unsupported_actions.get(service.serviceType, [])
|
||||
s.append(name)
|
||||
self._unsupported_actions[service.serviceType] = s
|
||||
log.debug("available command for %s does not have a wrapper implemented: %s %s %s",
|
||||
service.serviceType, name, inputs, outputs)
|
||||
|
||||
@property
|
||||
def services(self) -> dict:
|
||||
|
@ -205,7 +205,13 @@ class Gateway:
|
|||
return {}
|
||||
return {device.udn: device for device in self._devices}
|
||||
|
||||
def get_service(self, service_type) -> Service:
|
||||
def get_service(self, service_type: str) -> Service:
|
||||
for service in self._services:
|
||||
if service.serviceType.lower() == service_type.lower():
|
||||
return service
|
||||
|
||||
def debug_commands(self):
|
||||
return {
|
||||
'available': self._registered_commands,
|
||||
'failed': self._unsupported_actions
|
||||
}
|
||||
|
|
170
txupnp/mocks.py
170
txupnp/mocks.py
|
@ -1,44 +1,118 @@
|
|||
import os
|
||||
import json
|
||||
import logging
|
||||
import binascii
|
||||
from twisted.internet import task
|
||||
from txupnp.constants import SSDP_IP_ADDRESS
|
||||
from txupnp.fault import UPnPError
|
||||
from txupnp.ssdp_datagram import SSDPDatagram
|
||||
from twisted.internet import task, defer
|
||||
from twisted.internet.error import ConnectionDone
|
||||
from twisted.internet.protocol import DatagramProtocol
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.test.proto_helpers import _FakePort
|
||||
from txupnp.ssdp_datagram import SSDPDatagram
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, content):
|
||||
self._content = content
|
||||
self.headers = {}
|
||||
|
||||
def content(self):
|
||||
return defer.succeed(self._content)
|
||||
|
||||
|
||||
class MockDevice:
|
||||
def __init__(self, manufacturer, model):
|
||||
self.manufacturer = manufacturer
|
||||
self.model = model
|
||||
device_path = os.path.join(
|
||||
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "tests", "devices"), "{} {}".format(manufacturer, model)
|
||||
)
|
||||
assert os.path.isfile(device_path)
|
||||
with open(device_path, "r") as f:
|
||||
self.device_dict = json.loads(f.read())
|
||||
|
||||
def __repr__(self):
|
||||
return "MockDevice(manufacturer={}, model={})".format(self.manufacturer, self.model)
|
||||
|
||||
|
||||
def get_mock_devices():
|
||||
return [
|
||||
MockDevice(path.split(" ")[0], path.split(" ")[1])
|
||||
for path in os.listdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), "devices"))
|
||||
if ".py" not in path and "pycache" not in path
|
||||
]
|
||||
|
||||
|
||||
def get_device_test_case(manufacturer: str, model: str) -> MockDevice:
|
||||
r = [
|
||||
MockDevice(path.split(" ")[0], path.split(" ")[1])
|
||||
for path in os.listdir(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "tests", "devices"))
|
||||
if ".py" not in path and "pycache" not in path and path.split(" ") == [manufacturer, model]
|
||||
]
|
||||
return r[0]
|
||||
|
||||
|
||||
class MockMulticastTransport:
|
||||
def __init__(self, address, port, max_packet_size, network):
|
||||
def __init__(self, address, port, max_packet_size, network, protocol):
|
||||
self.address = address
|
||||
self.port = port
|
||||
self.max_packet_size = max_packet_size
|
||||
self._network = network
|
||||
self._protocol = protocol
|
||||
|
||||
def write(self, data, address):
|
||||
if address in self._network.peers:
|
||||
for dest in self._network.peers[address]:
|
||||
dest.datagramReceived(data, (self.address, self.port))
|
||||
else: # the node is sending to an address that doesnt currently exist, act like it never arrived
|
||||
pass
|
||||
if address[0] in self._network.group:
|
||||
destinations = self._network.group[address[0]]
|
||||
else:
|
||||
destinations = address[0]
|
||||
for address, dest in self._network.peers.items():
|
||||
if address[0] in destinations and dest.address != self.address:
|
||||
dest._protocol.datagramReceived(data, (self.address, self.port))
|
||||
|
||||
def setTTL(self, ttl):
|
||||
pass
|
||||
|
||||
def joinGroup(self, address, interface=None):
|
||||
pass
|
||||
group = self._network.group.get(address, [])
|
||||
group.append(interface)
|
||||
self._network.group[address] = group
|
||||
|
||||
def leaveGroup(self, address, interface=None):
|
||||
pass
|
||||
group = self._network.group.get(address, [])
|
||||
if interface in group:
|
||||
group.remove(interface)
|
||||
self._network.group[address] = group
|
||||
|
||||
|
||||
class MockMulticastPort(object):
|
||||
def __init__(self, protocol, remover):
|
||||
class MockTCPTransport(_FakePort):
|
||||
def __init__(self, address, port, callback, mock_requests):
|
||||
super().__init__((address, port))
|
||||
self._callback = callback
|
||||
self._mock_requests = mock_requests
|
||||
|
||||
def write(self, data):
|
||||
if data.startswith(b"POST"):
|
||||
for url, packets in self._mock_requests['POST'].items():
|
||||
for request_response in packets:
|
||||
if data.decode() == request_response['request']:
|
||||
self._callback(request_response['response'].encode())
|
||||
return
|
||||
elif data.startswith(b"GET"):
|
||||
for url, packets in self._mock_requests['GET'].items():
|
||||
if data.decode() == packets['request']:
|
||||
self._callback(packets['response'].encode())
|
||||
return
|
||||
|
||||
|
||||
class MockMulticastPort(_FakePort):
|
||||
def __init__(self, protocol, remover, address, transport):
|
||||
super().__init__((address, 1900))
|
||||
self.protocol = protocol
|
||||
self._remover = remover
|
||||
self.transport = transport
|
||||
|
||||
def startListening(self, reason=None):
|
||||
self.protocol.transport = self.transport
|
||||
return self.protocol.startProtocol()
|
||||
|
||||
def stopListening(self, reason=None):
|
||||
|
@ -49,57 +123,55 @@ class MockMulticastPort(object):
|
|||
|
||||
class MockNetwork:
|
||||
def __init__(self):
|
||||
self.peers = {} # (interface, port): (protocol, max_packet_size)
|
||||
self.peers = {}
|
||||
self.group = {}
|
||||
|
||||
def add_peer(self, port, protocol, interface, maxPacketSize):
|
||||
protocol.transport = MockMulticastTransport(interface, port, maxPacketSize, self)
|
||||
peers = self.peers.get((interface, port), [])
|
||||
peers.append(protocol)
|
||||
self.peers[(interface, port)] = peers
|
||||
transport = MockMulticastTransport(interface, port, maxPacketSize, self, protocol)
|
||||
self.peers[(interface, port)] = transport
|
||||
|
||||
def remove_peer():
|
||||
if self.peers.get((interface, port)):
|
||||
self.peers[(interface, port)].remove(protocol)
|
||||
if not self.peers.get((interface, port)):
|
||||
del self.peers[(interface, port)]
|
||||
del protocol.transport
|
||||
return remove_peer
|
||||
return transport, remove_peer
|
||||
|
||||
|
||||
class MockReactor(task.Clock):
|
||||
def __init__(self):
|
||||
def __init__(self, client_addr, mock_scpd_requests):
|
||||
super().__init__()
|
||||
self.client_addr = client_addr
|
||||
self._mock_scpd_requests = mock_scpd_requests
|
||||
self.network = MockNetwork()
|
||||
|
||||
def listenMulticast(self, port, protocol, interface=SSDP_IP_ADDRESS, maxPacketSize=8192, listenMultiple=True):
|
||||
remover = self.network.add_peer(port, protocol, interface, maxPacketSize)
|
||||
port = MockMulticastPort(protocol, remover)
|
||||
def listenMulticast(self, port, protocol, interface=None, maxPacketSize=8192, listenMultiple=True):
|
||||
interface = interface or self.client_addr
|
||||
transport, remover = self.network.add_peer(port, protocol, interface, maxPacketSize)
|
||||
port = MockMulticastPort(protocol, remover, interface, transport)
|
||||
port.startListening()
|
||||
return port
|
||||
|
||||
def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
|
||||
protocol = factory.buildProtocol(host)
|
||||
|
||||
def _write_and_close(data):
|
||||
protocol.dataReceived(data)
|
||||
protocol.connectionLost(Failure(ConnectionDone()))
|
||||
|
||||
protocol.transport = MockTCPTransport(host, port, _write_and_close, self._mock_scpd_requests)
|
||||
protocol.connectionMade()
|
||||
|
||||
|
||||
class MockSSDPServiceGatewayProtocol(DatagramProtocol):
|
||||
def __init__(self, iface, service_name, st, port, location, usn, version):
|
||||
def __init__(self, client_addr: int, iface: str, packets_rx: list, packets_tx: list):
|
||||
self.client_addr = client_addr
|
||||
self.iface = iface
|
||||
self.service_name = service_name
|
||||
self.gateway_st = st
|
||||
self.gateway_location = location
|
||||
self.gateway_usn = usn
|
||||
self.gateway_version = version
|
||||
self.gateway_port = port
|
||||
self.packets_tx = [SSDPDatagram.decode(packet.encode()) for packet in packets_tx] # sent by client
|
||||
self.packets_rx = [((addr, port), SSDPDatagram.decode(packet.encode())) for (addr, port), packet in packets_rx] # rx by client
|
||||
|
||||
def datagramReceived(self, datagram, address):
|
||||
try:
|
||||
packet = SSDPDatagram.decode(datagram)
|
||||
except UPnPError as err:
|
||||
log.error("failed to decode SSDP packet from %s:%i: %s\npacket: %s", address[0], address[1], err,
|
||||
binascii.hexlify(datagram))
|
||||
return
|
||||
|
||||
if packet._packet_type == packet._M_SEARCH:
|
||||
if packet.man == "ssdp:discover" and packet.st == self.gateway_st:
|
||||
location = 'http://{}:{}/{}'.format(self.iface, self.gateway_port, self.gateway_location)
|
||||
response = SSDPDatagram(SSDPDatagram._OK, st=self.gateway_st, cache_control='max-age=1800',
|
||||
location=location, server=self.gateway_version, usn=self.gateway_usn)
|
||||
self.transport.write(response.encode().encode(), address)
|
||||
|
||||
packet = SSDPDatagram.decode(datagram)
|
||||
if packet.st in map(lambda p: p[1].st, self.packets_rx): # this contains one of the service types the server replied to
|
||||
reply = list(filter(lambda p: p[1].st == packet.st, self.packets_rx))[0][1]
|
||||
self.transport.write(reply.encode().encode(), (self.client_addr, 1900))
|
||||
else:
|
||||
pass
|
||||
|
|
652
txupnp/scpd.py
652
txupnp/scpd.py
|
@ -1,48 +1,224 @@
|
|||
import re
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from twisted.internet import defer, threads
|
||||
from twisted.web.client import Agent
|
||||
import treq
|
||||
from treq.client import HTTPClient
|
||||
from xml.etree import ElementTree
|
||||
from txupnp.util import etree_to_dict, flatten_keys, return_types, verify_return_types, none_or_str, none
|
||||
from twisted.internet.protocol import Protocol, ClientFactory
|
||||
from twisted.internet import defer, error
|
||||
from txupnp.constants import XML_VERSION, DEVICE, ROOT, SERVICE, ENVELOPE, BODY
|
||||
from txupnp.util import etree_to_dict, flatten_keys
|
||||
from txupnp.fault import handle_fault, UPnPError
|
||||
from txupnp.constants import SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types, ENVELOPE, XML_VERSION
|
||||
from txupnp.constants import BODY, POST
|
||||
from txupnp.dirty_pool import DirtyPool
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
CONTENT_PATTERN = re.compile(
|
||||
"(\<\?xml version=\"1\.0\"\?\>(\s*.)*|\>)".encode()
|
||||
)
|
||||
CONTENT_NO_XML_VERSION_PATTERN = re.compile(
|
||||
"(\<s\:Envelope xmlns\:s=\"http\:\/\/schemas\.xmlsoap\.org\/soap\/envelope\/\"(\s*.)*\>)".encode()
|
||||
)
|
||||
|
||||
class StringProducer:
|
||||
def __init__(self, body):
|
||||
self.body = body
|
||||
self.length = len(body)
|
||||
|
||||
def startProducing(self, consumer):
|
||||
consumer.write(self.body)
|
||||
return defer.succeed(None)
|
||||
|
||||
def pauseProducing(self):
|
||||
pass
|
||||
|
||||
def stopProducing(self):
|
||||
pass
|
||||
XML_ROOT_SANITY_PATTERN = re.compile(
|
||||
"(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\}([\w|\d|\:|\-|\_]*))"
|
||||
)
|
||||
|
||||
|
||||
def xml_arg(name, arg):
|
||||
return "<%s>%s</%s>" % (name, arg, name)
|
||||
def parse_service_description(content: bytes):
|
||||
element_dict = etree_to_dict(ElementTree.fromstring(content.decode()))
|
||||
service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
|
||||
if "scpd" not in service_info:
|
||||
return []
|
||||
action_list = service_info["scpd"]["actionList"]
|
||||
if not len(action_list): # it could be an empty string
|
||||
return []
|
||||
result = []
|
||||
if isinstance(action_list["action"], dict):
|
||||
arg_dicts = action_list["action"]['argumentList']['argument']
|
||||
if not isinstance(arg_dicts, list): # when there is one arg, ew
|
||||
arg_dicts = [arg_dicts]
|
||||
return [[
|
||||
action_list["action"]['name'],
|
||||
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
|
||||
[i['name'] for i in arg_dicts if i['direction'] == 'out']
|
||||
]]
|
||||
for action in action_list["action"]:
|
||||
if not action.get('argumentList'):
|
||||
result.append((action['name'], [], []))
|
||||
else:
|
||||
arg_dicts = action['argumentList']['argument']
|
||||
if not isinstance(arg_dicts, list): # when there is one arg, ew
|
||||
arg_dicts = [arg_dicts]
|
||||
result.append((
|
||||
action['name'],
|
||||
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
|
||||
[i['name'] for i in arg_dicts if i['direction'] == 'out']
|
||||
))
|
||||
return result
|
||||
|
||||
|
||||
def get_soap_body(service_name, method, param_names, **kwargs):
|
||||
args = "".join(xml_arg(n, kwargs.get(n)) for n in param_names)
|
||||
return '\n%s\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" s:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body><u:%s xmlns:u="%s">%s</u:%s></s:Body></s:Envelope>' % (XML_VERSION, method, service_name, args, method)
|
||||
class SCPDHTTPClientProtocol(Protocol):
|
||||
def connectionMade(self):
|
||||
self.response_buff = b""
|
||||
self.factory.reactor.callLater(0, self.transport.write, self.factory.packet)
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.response_buff += data
|
||||
|
||||
def connectionLost(self, reason):
|
||||
if reason.trap(error.ConnectionDone):
|
||||
if XML_VERSION.encode() in self.response_buff:
|
||||
parsed = CONTENT_PATTERN.findall(self.response_buff)
|
||||
result = b'' if not parsed else parsed[0][0]
|
||||
self.factory.finished_deferred.callback(result)
|
||||
else:
|
||||
parsed = CONTENT_NO_XML_VERSION_PATTERN.findall(self.response_buff)
|
||||
result = b'' if not parsed else XML_VERSION.encode() + b'\r\n' + parsed[0][0]
|
||||
self.factory.finished_deferred.callback(result)
|
||||
|
||||
|
||||
class _SCPDCommand(object):
|
||||
def __init__(self, http_client, gateway_address, service_port, control_url, service_id, method, param_names,
|
||||
class SCPDHTTPClientFactory(ClientFactory):
|
||||
protocol = SCPDHTTPClientProtocol
|
||||
|
||||
def __init__(self, reactor, packet):
|
||||
self.reactor = reactor
|
||||
self.finished_deferred = defer.Deferred()
|
||||
self.packet = packet
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
p = self.protocol()
|
||||
p.factory = self
|
||||
return p
|
||||
|
||||
@classmethod
|
||||
def post(cls, reactor, command, **kwargs):
|
||||
args = "".join("<%s>%s</%s>" % (n, kwargs.get(n), n) for n in command.param_names)
|
||||
soap_body = ('\r\n%s\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" '
|
||||
's:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>'
|
||||
'<u:%s xmlns:u="%s">%s</u:%s></s:Body></s:Envelope>' % (
|
||||
XML_VERSION, command.method, command.service_id.decode(),
|
||||
args, command.method))
|
||||
data = (
|
||||
(
|
||||
'POST %s HTTP/1.1\r\n'
|
||||
'Host: %s\r\n'
|
||||
'User-Agent: Debian/buster/sid, UPnP/1.0, MiniUPnPc/1.9\r\n'
|
||||
'Content-Length: %i\r\n'
|
||||
'Content-Type: text/xml\r\n'
|
||||
'SOAPAction: \"%s#%s\"\r\n'
|
||||
'Connection: Close\r\n'
|
||||
'Cache-Control: no-cache\r\n'
|
||||
'Pragma: no-cache\r\n'
|
||||
'%s'
|
||||
'\r\n'
|
||||
) % (
|
||||
command.control_url.decode(), # could be just / even if it shouldn't be
|
||||
command.gateway_address.decode().split("http://")[1],
|
||||
len(soap_body),
|
||||
command.service_id.decode(), # maybe no quotes
|
||||
command.method,
|
||||
soap_body
|
||||
)
|
||||
).encode()
|
||||
return cls(reactor, data)
|
||||
|
||||
@classmethod
|
||||
def get(cls, reactor, control_url: str, address: str):
|
||||
data = (
|
||||
(
|
||||
'GET %s HTTP/1.1\r\n'
|
||||
'Accept-Encoding: gzip\r\n'
|
||||
'Host: %s\r\n'
|
||||
'\r\n'
|
||||
) % (control_url, address)
|
||||
).encode()
|
||||
return cls(reactor, data)
|
||||
|
||||
|
||||
class SCPDRequester:
|
||||
client_factory = SCPDHTTPClientFactory
|
||||
|
||||
def __init__(self, reactor):
|
||||
self._reactor = reactor
|
||||
self._get_requests = {}
|
||||
self._post_requests = {}
|
||||
|
||||
def _save_get(self, request: bytes, response: bytes, destination: str) -> None:
|
||||
self._get_requests[destination.lstrip("/")] = {
|
||||
'request': request,
|
||||
'response': response
|
||||
}
|
||||
|
||||
def _save_post(self, request: bytes, response: bytes, destination: str) -> None:
|
||||
p = self._post_requests.get(destination.lstrip("/"), [])
|
||||
p.append({
|
||||
'request': request,
|
||||
'response': response,
|
||||
})
|
||||
self._post_requests[destination.lstrip("/")] = p
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _scpd_get_soap_xml(self, control_url: str, address: str, service_port: int) -> bytes:
|
||||
factory = self.client_factory.get(self._reactor, control_url, address)
|
||||
url = address.split("http://")[1].split(":")[0]
|
||||
self._reactor.connectTCP(url, service_port, factory)
|
||||
xml_response_bytes = yield factory.finished_deferred
|
||||
self._save_get(factory.packet, xml_response_bytes, control_url)
|
||||
return xml_response_bytes
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def scpd_post_soap(self, command, **kwargs) -> tuple:
|
||||
factory = self.client_factory.post(self._reactor, command, **kwargs)
|
||||
url = command.gateway_address.split(b"http://")[1].split(b":")[0]
|
||||
self._reactor.connectTCP(url.decode(), command.service_port, factory)
|
||||
xml_response_bytes = yield factory.finished_deferred
|
||||
self._save_post(
|
||||
factory.packet, xml_response_bytes, command.gateway_address.decode() + command.control_url.decode()
|
||||
)
|
||||
content_dict = etree_to_dict(ElementTree.fromstring(xml_response_bytes.decode()))
|
||||
envelope = content_dict[ENVELOPE]
|
||||
response_body = flatten_keys(envelope[BODY], "{%s}" % command.service_id)
|
||||
body = handle_fault(response_body) # raises UPnPError if there is a fault
|
||||
response_key = None
|
||||
for key in body:
|
||||
if command.method in key:
|
||||
response_key = key
|
||||
break
|
||||
if not response_key:
|
||||
raise UPnPError("unknown response fields for %s")
|
||||
response = body[response_key]
|
||||
extracted_response = tuple([response[n] for n in command.returns])
|
||||
return extracted_response
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def scpd_get_supported_actions(self, service, address: str, port: int) -> list:
|
||||
xml_bytes = yield self._scpd_get_soap_xml(service.SCPDURL, address, port)
|
||||
return parse_service_description(xml_bytes)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def scpd_get(self, control_url: str, service_address: str, service_port: int) -> dict:
|
||||
xml_bytes = yield self._scpd_get_soap_xml(control_url, service_address, service_port)
|
||||
xml_dict = etree_to_dict(ElementTree.fromstring(xml_bytes.decode()))
|
||||
schema_key = DEVICE
|
||||
root = ROOT
|
||||
for k in xml_dict.keys():
|
||||
m = XML_ROOT_SANITY_PATTERN.findall(k)
|
||||
if len(m) == 3 and m[1][0] and m[2][5]:
|
||||
schema_key = m[1][0]
|
||||
root = m[2][5]
|
||||
break
|
||||
flattened_xml = flatten_keys(xml_dict, "{%s}" % schema_key)[root]
|
||||
return flattened_xml
|
||||
|
||||
def dump_packets(self) -> dict:
|
||||
return {
|
||||
'GET': self._get_requests,
|
||||
'POST': self._post_requests
|
||||
}
|
||||
|
||||
|
||||
class SCPDCommand:
|
||||
def __init__(self, scpd_requester: SCPDRequester, gateway_address, service_port, control_url, service_id, method,
|
||||
param_names,
|
||||
returns):
|
||||
self._http_client = http_client
|
||||
self.scpd_requester = scpd_requester
|
||||
self.gateway_address = gateway_address
|
||||
self.service_port = service_port
|
||||
self.control_url = control_url
|
||||
|
@ -51,59 +227,8 @@ class _SCPDCommand(object):
|
|||
self.param_names = param_names
|
||||
self.returns = returns
|
||||
|
||||
def extract_body(self, xml_response):
|
||||
content_dict = etree_to_dict(ElementTree.fromstring(xml_response))
|
||||
envelope = content_dict[ENVELOPE]
|
||||
return flatten_keys(envelope[BODY], "{%s}" % self.service_id)
|
||||
|
||||
def extract_response(self, body):
|
||||
body = handle_fault(body) # raises UPnPError if there is a fault
|
||||
response_key = None
|
||||
for key in body:
|
||||
if self.method in key:
|
||||
response_key = key
|
||||
break
|
||||
if not response_key:
|
||||
raise UPnPError("unknown response fields for %s")
|
||||
response = body[response_key]
|
||||
extracted_response = tuple([response[n] for n in self.returns])
|
||||
if len(extracted_response) == 1:
|
||||
return extracted_response[0]
|
||||
return extracted_response
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_upnp_soap(self, **kwargs):
|
||||
soap_body = get_soap_body(self.service_id, self.method, self.param_names, **kwargs).encode()
|
||||
headers = OrderedDict((
|
||||
('SOAPAction', '%s#%s' % (self.service_id, self.method)),
|
||||
('Host', ('%s:%i' % (SSDP_IP_ADDRESS, self.service_port))),
|
||||
('Content-Type', 'text/xml'),
|
||||
('Content-Length', len(soap_body))
|
||||
))
|
||||
log.debug("send POST to %s\nheaders: %s\nbody:%s\n", self.control_url, headers, soap_body)
|
||||
try:
|
||||
response = yield self._http_client.request(
|
||||
POST, url=self.control_url, data=soap_body, headers=headers
|
||||
)
|
||||
except Exception as err:
|
||||
log.error("error (%s) sending POST to %s\nheaders: %s\nbody:%s\n", err, self.control_url, headers,
|
||||
soap_body)
|
||||
raise UPnPError(err)
|
||||
|
||||
xml_response = yield response.content()
|
||||
try:
|
||||
response = self.extract_response(self.extract_body(xml_response))
|
||||
except UPnPError:
|
||||
raise
|
||||
except Exception as err:
|
||||
log.debug("error extracting response (%s) to %s:\n%s", err, self.method, xml_response)
|
||||
raise err
|
||||
if not response:
|
||||
log.debug("empty response to %s\n%s", self.method, xml_response)
|
||||
defer.returnValue(response)
|
||||
|
||||
@staticmethod
|
||||
def _process_result(results):
|
||||
def _process_result(*results):
|
||||
"""
|
||||
this method gets decorated automatically with a function that maps result types to the types
|
||||
defined in the @return_types decorator
|
||||
|
@ -114,367 +239,10 @@ class _SCPDCommand(object):
|
|||
def __call__(self, **kwargs):
|
||||
if set(kwargs.keys()) != set(self.param_names):
|
||||
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_names))
|
||||
response = yield self.send_upnp_soap(**kwargs)
|
||||
response = yield self.scpd_requester.scpd_post_soap(self, **kwargs)
|
||||
try:
|
||||
result = self._process_result(response)
|
||||
result = self._process_result(*response)
|
||||
except Exception as err:
|
||||
log.error("error formatting response (%s):\n%s", err, response)
|
||||
raise err
|
||||
defer.returnValue(result)
|
||||
|
||||
|
||||
class SCPDResponse(object):
|
||||
def __init__(self, url, headers, content):
|
||||
self.url = url
|
||||
self.headers = headers
|
||||
self.content = content
|
||||
|
||||
def get_element_tree(self):
|
||||
return ElementTree.fromstring(self.content)
|
||||
|
||||
def get_element_dict(self, service_key):
|
||||
return flatten_keys(etree_to_dict(self.get_element_tree()), "{%s}" % service_key)
|
||||
|
||||
def get_action_list(self):
|
||||
return self.get_element_dict(SERVICE)["scpd"]["actionList"]["action"]
|
||||
|
||||
def get_device_info(self):
|
||||
return self.get_element_dict(DEVICE)[ROOT]
|
||||
|
||||
|
||||
class SCPDCommandRunner(object):
|
||||
def __init__(self, gateway, reactor, treq_get=None):
|
||||
self._gateway = gateway
|
||||
self._unsupported_actions = {}
|
||||
self._registered_commands = {}
|
||||
self._reactor = reactor
|
||||
self._connection_pool = DirtyPool(reactor)
|
||||
self._agent = Agent(reactor, connectTimeout=1, pool=self._connection_pool)
|
||||
self._http_client = HTTPClient(self._agent, data_to_body_producer=StringProducer)
|
||||
self._treq_get = treq_get or treq.get
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _discover_commands(self, service):
|
||||
scpd_url = self._gateway.base_address + service.SCPDURL.encode()
|
||||
response = yield self._treq_get(scpd_url)
|
||||
content = yield response.content()
|
||||
try:
|
||||
scpd_response = SCPDResponse(scpd_url, response.headers, content)
|
||||
for action_dict in scpd_response.get_action_list():
|
||||
self._register_command(action_dict, service.serviceType)
|
||||
except Exception as err:
|
||||
log.exception("failed to parse scpd response (%s) from %s\nheaders:\n%s\ncontent\n%s",
|
||||
err, scpd_url, response.headers, content)
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def discover_commands(self):
|
||||
for service_type in service_types:
|
||||
service = self._gateway.get_service(service_type)
|
||||
if not service:
|
||||
continue
|
||||
yield self._discover_commands(service)
|
||||
log.debug(self.debug_commands())
|
||||
|
||||
@staticmethod
|
||||
def _soap_function_info(action_dict):
|
||||
if not action_dict.get('argumentList'):
|
||||
return (
|
||||
action_dict['name'],
|
||||
[],
|
||||
[]
|
||||
)
|
||||
arg_dicts = action_dict['argumentList']['argument']
|
||||
if not isinstance(arg_dicts, list): # when there is one arg, ew
|
||||
arg_dicts = [arg_dicts]
|
||||
return (
|
||||
action_dict['name'],
|
||||
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
|
||||
[i['name'] for i in arg_dicts if i['direction'] == 'out']
|
||||
)
|
||||
|
||||
def _patch_command(self, action_info, service_type):
|
||||
name, inputs, outputs = self._soap_function_info(action_info)
|
||||
command = _SCPDCommand(self._http_client, self._gateway.base_address, self._gateway.port,
|
||||
self._gateway.base_address + self._gateway.get_service(service_type).controlURL.encode(),
|
||||
self._gateway.get_service(service_type).serviceType.encode(), name, inputs, outputs)
|
||||
current = getattr(self, command.method)
|
||||
if hasattr(current, "_return_types"):
|
||||
command._process_result = verify_return_types(*current._return_types)(command._process_result)
|
||||
setattr(command, "__doc__", current.__doc__)
|
||||
setattr(self, command.method, command)
|
||||
self._registered_commands[command.method] = service_type
|
||||
log.debug("registered %s %s", service_type, action_info['name'])
|
||||
return True
|
||||
|
||||
def _register_command(self, action_info, service_type):
|
||||
try:
|
||||
return self._patch_command(action_info, service_type)
|
||||
except Exception as err:
|
||||
s = self._unsupported_actions.get(service_type, [])
|
||||
s.append((action_info, err))
|
||||
self._unsupported_actions[service_type] = s
|
||||
log.error("available command for %s does not have a wrapper implemented: %s", service_type, action_info)
|
||||
|
||||
def debug_commands(self):
|
||||
return {
|
||||
'available': self._registered_commands,
|
||||
'failed': self._unsupported_actions
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@return_types(none)
|
||||
def AddPortMapping(NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient,
|
||||
NewEnabled, NewPortMappingDescription, NewLeaseDuration=''):
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(bool, bool)
|
||||
def GetNATRSIPStatus():
|
||||
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(none_or_str, int, str, int, str, bool, str, int)
|
||||
def GetGenericPortMappingEntry(NewPortMappingIndex):
|
||||
"""
|
||||
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
|
||||
NewPortMappingDescription, NewLeaseDuration)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(int, str, bool, str, int)
|
||||
def GetSpecificPortMappingEntry(NewRemoteHost, NewExternalPort, NewProtocol):
|
||||
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(none)
|
||||
def SetConnectionType(NewConnectionType):
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(str)
|
||||
def GetExternalIPAddress():
|
||||
"""Returns (NewExternalIPAddress)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(str, str)
|
||||
def GetConnectionTypeInfo():
|
||||
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(str, str, int)
|
||||
def GetStatusInfo():
|
||||
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(none)
|
||||
def ForceTermination():
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(none)
|
||||
def DeletePortMapping(NewRemoteHost, NewExternalPort, NewProtocol):
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(none)
|
||||
def RequestConnection():
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def GetCommonLinkProperties():
|
||||
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def GetTotalBytesSent():
|
||||
"""Returns (NewTotalBytesSent)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def GetTotalBytesReceived():
|
||||
"""Returns (NewTotalBytesReceived)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def GetTotalPacketsSent():
|
||||
"""Returns (NewTotalPacketsSent)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def GetTotalPacketsReceived():
|
||||
"""Returns (NewTotalPacketsReceived)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def X_GetICSStatistics():
|
||||
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def GetDefaultConnectionService():
|
||||
"""Returns (NewDefaultConnectionService)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def SetDefaultConnectionService(NewDefaultConnectionService):
|
||||
"""Returns (None)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(none)
|
||||
def SetEnabledForInternet(NewEnabledForInternet):
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(bool)
|
||||
def GetEnabledForInternet():
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def GetMaximumActiveConnections(NewActiveConnectionIndex):
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@return_types(str, str)
|
||||
def GetActiveConnections():
|
||||
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class UPnPFallback(object):
|
||||
def __init__(self):
|
||||
try:
|
||||
import miniupnpc
|
||||
self._upnp = miniupnpc.UPnP()
|
||||
self.available = True
|
||||
except ImportError:
|
||||
self._upnp = None
|
||||
self.available = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def discover(self):
|
||||
if not self.available:
|
||||
raise NotImplementedError()
|
||||
devices = yield threads.deferToThread(self._upnp.discover)
|
||||
if devices:
|
||||
self.device_url = yield threads.deferToThread(self._upnp.selectigd)
|
||||
else:
|
||||
self.device_url = None
|
||||
|
||||
defer.returnValue(devices > 0)
|
||||
|
||||
@return_types(none)
|
||||
def AddPortMapping(self, NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient,
|
||||
NewEnabled, NewPortMappingDescription, NewLeaseDuration=''):
|
||||
"""Returns None"""
|
||||
if not self.available:
|
||||
raise NotImplementedError()
|
||||
return threads.deferToThread(self._upnp.addportmapping, NewExternalPort, NewProtocol, NewInternalClient,
|
||||
NewInternalPort, NewPortMappingDescription, NewLeaseDuration)
|
||||
|
||||
def GetNATRSIPStatus(self):
|
||||
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@return_types(none_or_str, int, str, int, str, bool, str, int)
|
||||
@defer.inlineCallbacks
|
||||
def GetGenericPortMappingEntry(self, NewPortMappingIndex):
|
||||
"""
|
||||
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
|
||||
NewPortMappingDescription, NewLeaseDuration)
|
||||
"""
|
||||
if not self.available:
|
||||
raise NotImplementedError()
|
||||
result = yield threads.deferToThread(self._upnp.getgenericportmapping, NewPortMappingIndex)
|
||||
if not result:
|
||||
raise UPnPError()
|
||||
ext_port, protocol, (int_host, int_port), desc, enabled, remote_host, lease = result
|
||||
defer.returnValue((remote_host, ext_port, protocol, int_port, int_host, enabled, desc, lease))
|
||||
|
||||
@return_types(int, str, bool, str, int)
|
||||
def GetSpecificPortMappingEntry(self, NewRemoteHost, NewExternalPort, NewProtocol):
|
||||
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
|
||||
if not self.available:
|
||||
raise NotImplementedError()
|
||||
return threads.deferToThread(self._upnp.getspecificportmapping, NewExternalPort, NewProtocol)
|
||||
|
||||
def SetConnectionType(self, NewConnectionType):
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@return_types(str)
|
||||
def GetExternalIPAddress(self):
|
||||
"""Returns (NewExternalIPAddress)"""
|
||||
if not self.available:
|
||||
raise NotImplementedError()
|
||||
return threads.deferToThread(self._upnp.externalipaddress)
|
||||
|
||||
def GetConnectionTypeInfo(self):
|
||||
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@return_types(str, str, int)
|
||||
def GetStatusInfo(self):
|
||||
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
|
||||
if not self.available:
|
||||
raise NotImplementedError()
|
||||
return threads.deferToThread(self._upnp.statusinfo)
|
||||
|
||||
def ForceTermination(self):
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@return_types(none)
|
||||
def DeletePortMapping(self, NewRemoteHost, NewExternalPort, NewProtocol):
|
||||
"""Returns None"""
|
||||
if not self.available:
|
||||
raise NotImplementedError()
|
||||
return threads.deferToThread(self._upnp.deleteportmapping, NewExternalPort, NewProtocol)
|
||||
|
||||
def RequestConnection(self):
|
||||
"""Returns None"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def GetCommonLinkProperties(self):
|
||||
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def GetTotalBytesSent(self):
|
||||
"""Returns (NewTotalBytesSent)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def GetTotalBytesReceived(self):
|
||||
"""Returns (NewTotalBytesReceived)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def GetTotalPacketsSent(self):
|
||||
"""Returns (NewTotalPacketsSent)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def GetTotalPacketsReceived(self):
|
||||
"""Returns (NewTotalPacketsReceived)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def X_GetICSStatistics(self):
|
||||
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def GetDefaultConnectionService(self):
|
||||
"""Returns (NewDefaultConnectionService)"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def SetDefaultConnectionService(self, NewDefaultConnectionService):
|
||||
"""Returns (None)"""
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -1,79 +0,0 @@
|
|||
import logging
|
||||
import netifaces
|
||||
from twisted.internet import defer
|
||||
from txupnp.ssdp import SSDPFactory
|
||||
from txupnp.scpd import SCPDCommandRunner
|
||||
from txupnp.gateway import Gateway
|
||||
from txupnp.fault import UPnPError
|
||||
from txupnp.constants import UPNP_ORG_IGD
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SOAPServiceManager(object):
|
||||
def __init__(self, reactor, treq_get=None):
|
||||
self._reactor = reactor
|
||||
self.router_ip, self.iface_name = netifaces.gateways()['default'][netifaces.AF_INET]
|
||||
self.lan_address = netifaces.ifaddresses(self.iface_name)[netifaces.AF_INET][0]['addr']
|
||||
self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip)
|
||||
self._command_runners = {}
|
||||
self._selected_runner = UPNP_ORG_IGD
|
||||
self._treq_get = treq_get
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def discover_services(self, address=None, timeout=30, max_devices=1):
|
||||
server_infos = yield self.sspd_factory.m_search(
|
||||
address or self.router_ip, timeout=timeout, max_devices=max_devices
|
||||
)
|
||||
locations = []
|
||||
for server_info in server_infos:
|
||||
if 'st' in server_info and server_info['st'] not in self._command_runners:
|
||||
locations.append(server_info['location'])
|
||||
gateway = Gateway(**server_info)
|
||||
yield gateway.discover_services()
|
||||
command_runner = SCPDCommandRunner(gateway, self._reactor, self._treq_get)
|
||||
yield command_runner.discover_commands()
|
||||
self._command_runners[gateway.urn.decode()] = command_runner
|
||||
elif 'st' not in server_info:
|
||||
log.error("don't know how to handle gateway: %s", server_info)
|
||||
continue
|
||||
defer.returnValue(len(self._command_runners) > 0)
|
||||
|
||||
def set_runner(self, urn):
|
||||
if urn not in self._command_runners:
|
||||
raise IndexError(urn)
|
||||
self._command_runners = urn
|
||||
|
||||
def get_runner(self):
|
||||
if self._command_runners and not self._selected_runner in self._command_runners:
|
||||
self._selected_runner = list(self._command_runners.keys())[0]
|
||||
if not self._command_runners:
|
||||
raise UPnPError("no devices found")
|
||||
return self._command_runners[self._selected_runner]
|
||||
|
||||
def get_available_runners(self):
|
||||
return self._command_runners.keys()
|
||||
|
||||
def debug(self, include_gateway_xml=False):
|
||||
results = []
|
||||
for runner in self._command_runners.values():
|
||||
gateway = runner._gateway
|
||||
info = gateway.debug_device(include_xml=include_gateway_xml)
|
||||
commands = runner.debug_commands()
|
||||
service_result = []
|
||||
for service in info['services']:
|
||||
service_commands = []
|
||||
unavailable = []
|
||||
for command, service_type in commands['available'].items():
|
||||
if service['serviceType'] == service_type:
|
||||
service_commands.append(command)
|
||||
for command, service_type in commands['failed'].items():
|
||||
if service['serviceType'] == service_type:
|
||||
unavailable.append(command)
|
||||
services_with_commands = dict(service)
|
||||
services_with_commands['available_commands'] = service_commands
|
||||
services_with_commands['unavailable_commands'] = unavailable
|
||||
service_result.append(services_with_commands)
|
||||
info['services'] = service_result
|
||||
results.append(info)
|
||||
return results
|
|
@ -12,7 +12,8 @@ log = logging.getLogger(__name__)
|
|||
|
||||
class SSDPProtocol(DatagramProtocol):
|
||||
def __init__(self, reactor, iface, router, ssdp_address=SSDP_IP_ADDRESS,
|
||||
ssdp_port=SSDP_PORT, ttl=1, max_devices=None):
|
||||
ssdp_port=SSDP_PORT, ttl=1, max_devices=None, debug_packets=False,
|
||||
debug_sent=None, debug_received=None):
|
||||
self._reactor = reactor
|
||||
self._sem = defer.DeferredSemaphore(1)
|
||||
self.discover_callbacks = {}
|
||||
|
@ -24,34 +25,36 @@ class SSDPProtocol(DatagramProtocol):
|
|||
self._start = None
|
||||
self.max_devices = max_devices
|
||||
self.devices = []
|
||||
self.debug_packets = debug_packets
|
||||
self.debug_sent = debug_sent if debug_sent is not None else []
|
||||
self.debug_received = debug_received if debug_sent is not None else []
|
||||
|
||||
def _send_m_search(self, service=UPNP_ORG_IGD):
|
||||
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, host=SSDP_HOST, st=service, man=SSDP_DISCOVER, mx=1)
|
||||
log.debug("sending packet to %s:\n%s", SSDP_HOST, packet.encode())
|
||||
try:
|
||||
self.transport.write(packet.encode().encode(), (self.ssdp_address, self.ssdp_port))
|
||||
msg_bytes = packet.encode().encode()
|
||||
if self.debug_packets:
|
||||
self.debug_sent.append(msg_bytes)
|
||||
self.transport.write(msg_bytes, (self.ssdp_address, self.ssdp_port))
|
||||
except Exception as err:
|
||||
log.exception("failed to write %s to %s:%i", packet.encode(), self.ssdp_address, self.ssdp_port)
|
||||
raise err
|
||||
|
||||
@staticmethod
|
||||
def _gather(finished_deferred, max_results):
|
||||
results = []
|
||||
|
||||
def _gather(finished_deferred, max_results, results: list):
|
||||
def discover_cb(packet):
|
||||
if not finished_deferred.called:
|
||||
if not finished_deferred.called and packet.st in service_types:
|
||||
results.append(packet.as_dict())
|
||||
if len(results) >= max_results:
|
||||
finished_deferred.callback(results)
|
||||
|
||||
return discover_cb
|
||||
|
||||
def m_search(self, address=None, timeout=1, max_devices=1):
|
||||
address = address or self.iface
|
||||
|
||||
def m_search(self, address, timeout, max_devices):
|
||||
# return deferred for a pending call if we have one
|
||||
if address in self.discover_callbacks:
|
||||
d = self.protocol.discover_callbacks[address][1]
|
||||
d = self.discover_callbacks[address][1]
|
||||
if not d.called: # the existing deferred has already fired, make a new one
|
||||
return d
|
||||
|
||||
|
@ -63,7 +66,7 @@ class SSDPProtocol(DatagramProtocol):
|
|||
d = defer.Deferred()
|
||||
d.addTimeout(timeout, self._reactor)
|
||||
d.addErrback(_trap_timeout_and_return_results)
|
||||
found_cb = self._gather(d, max_devices)
|
||||
found_cb = self._gather(d, max_devices, self.devices)
|
||||
self.discover_callbacks[address] = found_cb, d
|
||||
for st in service_types:
|
||||
self._send_m_search(service=st)
|
||||
|
@ -73,11 +76,12 @@ class SSDPProtocol(DatagramProtocol):
|
|||
self._start = self._reactor.seconds()
|
||||
self.transport.setTTL(self.ttl)
|
||||
self.transport.joinGroup(self.ssdp_address, interface=self.iface)
|
||||
self.m_search()
|
||||
|
||||
def datagramReceived(self, datagram, address):
|
||||
if address[0] == self.iface:
|
||||
return
|
||||
if self.debug_packets:
|
||||
self.debug_received.append((address, datagram))
|
||||
try:
|
||||
packet = SSDPDatagram.decode(datagram)
|
||||
log.debug("decoded %s from %s:%i:\n%s", packet.get_friendly_name(), address[0], address[1], packet.encode())
|
||||
|
@ -90,23 +94,30 @@ class SSDPProtocol(DatagramProtocol):
|
|||
return
|
||||
if packet._packet_type == packet._OK:
|
||||
log.debug("%s:%i replied to our m-search with new xml url: %s", address[0], address[1], packet.location)
|
||||
if packet.st not in map(lambda p: p['st'], self.devices):
|
||||
self.devices.append(packet.as_dict())
|
||||
log.debug("%i device%s so far", len(self.devices), "" if len(self.devices) < 2 else "s")
|
||||
if address[0] in self.discover_callbacks:
|
||||
# if address[0] in self.discover_callbacks and packet.location not in map(lambda p: p['location'], self.devices):
|
||||
if packet.location not in map(lambda p: p['location'], self.devices):
|
||||
if address[0] not in self.discover_callbacks:
|
||||
self.devices.append(packet.as_dict())
|
||||
else:
|
||||
self._sem.run(self.discover_callbacks[address[0]][0], packet)
|
||||
else:
|
||||
log.info("ignored packet from %s:%s (%s) %s", address[0], address[1], packet._packet_type, packet.location)
|
||||
elif packet._packet_type == packet._NOTIFY:
|
||||
log.debug("%s:%i sent us a notification (type: %s), url: %s", address[0], address[1], packet.nts,
|
||||
packet.location)
|
||||
|
||||
|
||||
class SSDPFactory(object):
|
||||
def __init__(self, reactor, lan_address, router_address):
|
||||
class SSDPFactory:
|
||||
def __init__(self, reactor, lan_address, router_address, debug_packets=False):
|
||||
self.lan_address = lan_address
|
||||
self.router_address = router_address
|
||||
self._reactor = reactor
|
||||
self.protocol = None
|
||||
self.port = None
|
||||
self.debug_packets = debug_packets
|
||||
self.debug_sent = []
|
||||
self.debug_received = []
|
||||
self.server_infos = []
|
||||
|
||||
def disconnect(self):
|
||||
if not self.port:
|
||||
|
@ -118,13 +129,15 @@ class SSDPFactory(object):
|
|||
|
||||
def connect(self):
|
||||
if not self.protocol:
|
||||
self.protocol = SSDPProtocol(self._reactor, self.lan_address, self.router_address)
|
||||
self.protocol = SSDPProtocol(self._reactor, self.lan_address, self.router_address,
|
||||
debug_packets=self.debug_packets, debug_sent=self.debug_sent,
|
||||
debug_received=self.debug_received)
|
||||
if not self.port:
|
||||
self._reactor.addSystemEventTrigger("before", "shutdown", self.disconnect)
|
||||
self.port = self._reactor.listenMulticast(self.protocol.ssdp_port, self.protocol, listenMultiple=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def m_search(self, address, timeout=1, max_devices=1):
|
||||
def m_search(self, address, timeout, max_devices):
|
||||
"""
|
||||
Perform a M-SEARCH (HTTP over UDP) and gather the results
|
||||
|
||||
|
@ -140,7 +153,16 @@ class SSDPFactory(object):
|
|||
'usn': (str) usn
|
||||
}, ...]
|
||||
"""
|
||||
|
||||
self.connect()
|
||||
server_infos = yield self.protocol.m_search(address, timeout, max_devices)
|
||||
for server_info in server_infos:
|
||||
self.server_infos.append(server_info)
|
||||
defer.returnValue(server_infos)
|
||||
|
||||
def get_ssdp_packet_replay(self) -> dict:
|
||||
return {
|
||||
'lan_address': self.lan_address,
|
||||
'router_address': self.router_address,
|
||||
'sent': self.debug_sent,
|
||||
'received': self.debug_received,
|
||||
}
|
||||
|
|
166
txupnp/upnp.py
166
txupnp/upnp.py
|
@ -1,45 +1,29 @@
|
|||
import netifaces
|
||||
import logging
|
||||
import json
|
||||
from twisted.internet import defer
|
||||
from txupnp.fault import UPnPError
|
||||
from txupnp.soap import SOAPServiceManager
|
||||
from txupnp.scpd import UPnPFallback
|
||||
from txupnp.ssdp import SSDPFactory
|
||||
from txupnp.gateway import Gateway
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UPnP(object):
|
||||
def __init__(self, reactor, try_miniupnpc_fallback=True, treq_get=None):
|
||||
class UPnP:
|
||||
def __init__(self, reactor, try_miniupnpc_fallback=False, debug_ssdp=False, router_ip=None,
|
||||
lan_ip=None, iface_name=None):
|
||||
self._reactor = reactor
|
||||
if router_ip and lan_ip and iface_name:
|
||||
self.router_ip, self.lan_address, self.iface_name = router_ip, lan_ip, iface_name
|
||||
else:
|
||||
self.router_ip, self.iface_name = netifaces.gateways()['default'][netifaces.AF_INET]
|
||||
self.lan_address = netifaces.ifaddresses(self.iface_name)[netifaces.AF_INET][0]['addr']
|
||||
self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip, debug_packets=debug_ssdp)
|
||||
self.try_miniupnpc_fallback = try_miniupnpc_fallback
|
||||
self.soap_manager = SOAPServiceManager(reactor, treq_get=treq_get)
|
||||
self.miniupnpc_runner = None
|
||||
self.miniupnpc_igd_url = None
|
||||
self.gateway = None
|
||||
|
||||
@property
|
||||
def lan_address(self):
|
||||
return self.soap_manager.lan_address
|
||||
|
||||
@property
|
||||
def commands(self):
|
||||
try:
|
||||
runner = self.soap_manager.get_runner()
|
||||
required_commands = [
|
||||
"GetExternalIPAddress",
|
||||
"AddPortMapping",
|
||||
"GetSpecificPortMappingEntry",
|
||||
"GetGenericPortMappingEntry",
|
||||
"DeletePortMapping"
|
||||
]
|
||||
if all((command in runner._registered_commands for command in required_commands)):
|
||||
return runner
|
||||
raise UPnPError("required commands not found")
|
||||
except UPnPError as err:
|
||||
if self.try_miniupnpc_fallback and self.miniupnpc_runner:
|
||||
return self.miniupnpc_runner
|
||||
log.warning("upnp is not available: %s", err)
|
||||
|
||||
def m_search(self, address, timeout=30, max_devices=2):
|
||||
def m_search(self, address, timeout=1, max_devices=1):
|
||||
"""
|
||||
Perform a HTTP over UDP M-SEARCH query
|
||||
|
||||
|
@ -51,68 +35,52 @@ class UPnP(object):
|
|||
'usn': <usn>
|
||||
}, ...]
|
||||
"""
|
||||
return self.soap_manager.sspd_factory.m_search(address, timeout=timeout, max_devices=max_devices)
|
||||
return self.sspd_factory.m_search(address, timeout=timeout, max_devices=max_devices)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def discover(self, timeout=1, max_devices=1, keep_listening=False, try_txupnp=True):
|
||||
found = False
|
||||
if not try_txupnp and not self.try_miniupnpc_fallback:
|
||||
log.warning("nothing left to try")
|
||||
if try_txupnp:
|
||||
try:
|
||||
found = yield self.soap_manager.discover_services(timeout=timeout, max_devices=max_devices)
|
||||
except defer.TimeoutError:
|
||||
found = False
|
||||
finally:
|
||||
if not keep_listening:
|
||||
self.soap_manager.sspd_factory.disconnect()
|
||||
def _discover(self, timeout=1, max_devices=1):
|
||||
server_infos = yield self.sspd_factory.m_search(
|
||||
self.router_ip, timeout=timeout, max_devices=max_devices
|
||||
)
|
||||
server_info = server_infos[0]
|
||||
if 'st' in server_info:
|
||||
gateway = Gateway(reactor=self._reactor, **server_info)
|
||||
yield gateway.discover_commands()
|
||||
self.gateway = gateway
|
||||
defer.returnValue(True)
|
||||
elif 'st' not in server_info:
|
||||
log.error("don't know how to handle gateway: %s", server_info)
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def discover(self, timeout=1, max_devices=1):
|
||||
try:
|
||||
found = yield self._discover(timeout=timeout, max_devices=max_devices)
|
||||
except defer.TimeoutError:
|
||||
found = False
|
||||
finally:
|
||||
self.sspd_factory.disconnect()
|
||||
if found:
|
||||
try:
|
||||
runner = self.soap_manager.get_runner()
|
||||
required_commands = [
|
||||
"GetExternalIPAddress",
|
||||
"AddPortMapping",
|
||||
"GetSpecificPortMappingEntry",
|
||||
"GetGenericPortMappingEntry",
|
||||
"DeletePortMapping"
|
||||
]
|
||||
found = all((command in runner._registered_commands for command in required_commands))
|
||||
except UPnPError:
|
||||
found = False
|
||||
if not found and self.try_miniupnpc_fallback:
|
||||
found = yield self.start_miniupnpc_fallback()
|
||||
log.debug("found upnp device")
|
||||
else:
|
||||
log.debug("failed to find upnp device")
|
||||
defer.returnValue(found)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start_miniupnpc_fallback(self):
|
||||
found = False
|
||||
if not self.miniupnpc_runner:
|
||||
log.debug("trying miniupnpc fallback")
|
||||
fallback = UPnPFallback()
|
||||
success = yield fallback.discover()
|
||||
self.miniupnpc_igd_url = fallback.device_url
|
||||
if success:
|
||||
log.info("successfully started miniupnpc fallback")
|
||||
self.miniupnpc_runner = fallback
|
||||
found = True
|
||||
if not found:
|
||||
log.warning("failed to find upnp gateway using miniupnpc fallback")
|
||||
defer.returnValue(found)
|
||||
def get_external_ip(self) -> str:
|
||||
return self.gateway.commands.GetExternalIPAddress()
|
||||
|
||||
def get_external_ip(self):
|
||||
return self.commands.GetExternalIPAddress()
|
||||
|
||||
def add_port_mapping(self, external_port, protocol, internal_port, lan_address, description):
|
||||
return self.commands.AddPortMapping(
|
||||
def add_port_mapping(self, external_port: int, protocol: str, internal_port, lan_address: str,
|
||||
description: str) -> None:
|
||||
return self.gateway.commands.AddPortMapping(
|
||||
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol,
|
||||
NewInternalPort=internal_port, NewInternalClient=lan_address,
|
||||
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration=""
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_port_mapping_by_index(self, index):
|
||||
def get_port_mapping_by_index(self, index: int) -> (str, int, str, int, str, bool, str, int):
|
||||
try:
|
||||
redirect = yield self.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
|
||||
redirect = yield self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
|
||||
defer.returnValue(redirect)
|
||||
except UPnPError:
|
||||
defer.returnValue(None)
|
||||
|
@ -129,7 +97,7 @@ class UPnP(object):
|
|||
defer.returnValue(redirects)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_specific_port_mapping(self, external_port, protocol):
|
||||
def get_specific_port_mapping(self, external_port: int, protocol: str) -> (int, str, bool, str, int):
|
||||
"""
|
||||
:param external_port: (int) external port to listen on
|
||||
:param protocol: (str) 'UDP' | 'TCP'
|
||||
|
@ -137,41 +105,23 @@ class UPnP(object):
|
|||
"""
|
||||
|
||||
try:
|
||||
result = yield self.commands.GetSpecificPortMappingEntry(
|
||||
result = yield self.gateway.commands.GetSpecificPortMappingEntry(
|
||||
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol
|
||||
)
|
||||
defer.returnValue(result)
|
||||
except UPnPError:
|
||||
defer.returnValue(None)
|
||||
|
||||
def delete_port_mapping(self, external_port, protocol, new_remote_host=""):
|
||||
def delete_port_mapping(self, external_port: int, protocol: str) -> None:
|
||||
"""
|
||||
:param external_port: (int) external port to listen on
|
||||
:param protocol: (str) 'UDP' | 'TCP'
|
||||
:return: None
|
||||
"""
|
||||
return self.commands.DeletePortMapping(
|
||||
NewRemoteHost=new_remote_host, NewExternalPort=external_port, NewProtocol=protocol
|
||||
return self.gateway.commands.DeletePortMapping(
|
||||
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
|
||||
)
|
||||
|
||||
def get_rsip_nat_status(self):
|
||||
"""
|
||||
:return: (bool) NewRSIPAvailable, (bool) NewNATEnabled
|
||||
"""
|
||||
return self.commands.GetNATRSIPStatus()
|
||||
|
||||
def get_status_info(self):
|
||||
"""
|
||||
:return: (str) NewConnectionStatus, (str) NewLastConnectionError, (int) NewUptime
|
||||
"""
|
||||
return self.commands.GetStatusInfo()
|
||||
|
||||
def get_connection_type_info(self):
|
||||
"""
|
||||
:return: (str) NewConnectionType (str), NewPossibleConnectionTypes (str)
|
||||
"""
|
||||
return self.commands.GetConnectionTypeInfo()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_next_mapping(self, port, protocol, description, internal_port=None):
|
||||
if protocol not in ["UDP", "TCP"]:
|
||||
|
@ -192,15 +142,3 @@ class UPnP(object):
|
|||
port, protocol, internal_port, self.lan_address, description
|
||||
)
|
||||
defer.returnValue(port)
|
||||
|
||||
def get_debug_info(self, include_gateway_xml=False):
|
||||
def default_byte(x):
|
||||
if isinstance(x, bytes):
|
||||
return x.decode()
|
||||
return x
|
||||
return json.dumps({
|
||||
'txupnp': self.soap_manager.debug(include_gateway_xml=include_gateway_xml),
|
||||
'miniupnpc_igd_url': self.miniupnpc_igd_url
|
||||
},
|
||||
indent=2, default=default_byte
|
||||
)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import re
|
||||
import functools
|
||||
from collections import defaultdict
|
||||
from twisted.internet import defer
|
||||
from xml.etree import ElementTree
|
||||
|
||||
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
|
||||
|
@ -59,13 +58,11 @@ def verify_return_types(*types):
|
|||
|
||||
def _verify_return_types(fn):
|
||||
@functools.wraps(fn)
|
||||
def _inner(response):
|
||||
if isinstance(response, (list, tuple)):
|
||||
r = tuple(t(r) for t, r in zip(types, response))
|
||||
if len(r) == 1:
|
||||
return fn(r[0])
|
||||
return fn(r)
|
||||
return fn(types[0](response))
|
||||
def _inner(*result):
|
||||
r = fn(*tuple(t(r) for t, r in zip(types, result)))
|
||||
if isinstance(r, tuple) and len(r) == 1:
|
||||
return r[0]
|
||||
return r
|
||||
return _inner
|
||||
return _verify_return_types
|
||||
|
||||
|
@ -85,18 +82,3 @@ def return_types(*types):
|
|||
none_or_str = lambda x: None if not x or x == 'None' else str(x)
|
||||
|
||||
none = lambda _: None
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def DeferredDict(d, consumeErrors=False):
|
||||
keys = []
|
||||
dl = []
|
||||
response = {}
|
||||
for k, v in d.items():
|
||||
keys.append(k)
|
||||
dl.append(v)
|
||||
results = yield defer.DeferredList(dl, consumeErrors=consumeErrors)
|
||||
for k, (success, result) in zip(keys, results):
|
||||
if success:
|
||||
response[k] = result
|
||||
defer.returnValue(response)
|
||||
|
|
Loading…
Reference in a new issue