Merge branch 'testing'

This commit is contained in:
Jack Robison 2018-10-04 17:17:52 -04:00
commit d9fee45fc7
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
18 changed files with 1017 additions and 948 deletions

1
.gitignore vendored
View file

@ -3,3 +3,4 @@
_trial_temp/
build/
dist/
.coverage

View file

@ -2,7 +2,7 @@
# UPnP for Twisted
`txupnp` is a python2/3 library to interact with UPnP gateways using Twisted
`txupnp` is a python 3 library to interact with UPnP gateways using `twisted`
## Installation

View file

@ -21,7 +21,7 @@ setup(
long_description=long_description,
url="https://github.com/lbryio/txupnp",
license=__license__,
packages=find_packages(),
packages=find_packages(exclude=['tests']),
entry_points={'console_scripts': console_scripts},
install_requires=[
'twisted[tls]',

View file

@ -0,0 +1,7 @@
import logging
log = logging.getLogger("txupnp")
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)-15s-%(filename)s:%(lineno)s->%(message)s'))
log.addHandler(handler)
log.setLevel(logging.INFO)

File diff suppressed because one or more lines are too long

105
tests/test_devices.py Normal file
View file

@ -0,0 +1,105 @@
from twisted.internet import reactor, defer
from twisted.trial import unittest
from txupnp.constants import SSDP_PORT, SSDP_IP_ADDRESS
from txupnp.upnp import UPnP
from txupnp.scpd import SCPDCommand
from txupnp.gateway import Service
from txupnp.fault import UPnPError
from txupnp.mocks import MockReactor, MockSSDPServiceGatewayProtocol, get_device_test_case
from txupnp.util import verify_return_types
class TestDevice(unittest.TestCase):
manufacturer, model = "Cisco", "CGA4131COM"
device = get_device_test_case(manufacturer, model)
router_address = device.device_dict['router_address']
client_address = device.device_dict['client_address']
expected_devices = device.device_dict['expected_devices']
packets_rx = device.device_dict['ssdp']['received']
packets_tx = device.device_dict['ssdp']['sent']
expected_available_commands = device.device_dict['commands']['available']
scdp_packets = device.device_dict['scpd']
def setUp(self):
fake_reactor = MockReactor(self.client_address, self.scdp_packets)
reactor.listenMulticast = fake_reactor.listenMulticast
self.reactor = reactor
server_protocol = MockSSDPServiceGatewayProtocol(
self.client_address, self.router_address, self.packets_rx, self.packets_tx
)
self.server_port = self.reactor.listenMulticast(SSDP_PORT, server_protocol, interface=self.router_address)
self.server_port.transport.joinGroup(SSDP_IP_ADDRESS, interface=self.router_address)
self.upnp = UPnP(
self.reactor, debug_ssdp=True, router_ip=self.router_address,
lan_ip=self.client_address, iface_name='mock'
)
def tearDown(self):
self.upnp.sspd_factory.disconnect()
self.server_port.stopListening()
class TestSSDP(TestDevice):
@defer.inlineCallbacks
def test_discover_device(self):
result = yield self.upnp.m_search(self.router_address, timeout=1)
self.assertEqual(len(self.expected_devices), len(result))
self.assertEqual(len(result), 1)
self.assertDictEqual(self.expected_devices[0], result[0])
class TestSCPD(TestDevice):
@defer.inlineCallbacks
def setUp(self):
fake_reactor = MockReactor(self.client_address, self.scdp_packets)
reactor.listenMulticast = fake_reactor.listenMulticast
reactor.connectTCP = fake_reactor.connectTCP
self.reactor = reactor
server_protocol = MockSSDPServiceGatewayProtocol(
self.client_address, self.router_address, self.packets_rx, self.packets_tx
)
self.server_port = self.reactor.listenMulticast(SSDP_PORT, server_protocol, interface=self.router_address)
self.server_port.transport.joinGroup(SSDP_IP_ADDRESS, interface=self.router_address)
self.upnp = UPnP(
self.reactor, debug_ssdp=True, router_ip=self.router_address,
lan_ip=self.client_address, iface_name='mock'
)
yield self.upnp.discover()
def test_parse_available_commands(self):
self.assertDictEqual(self.expected_available_commands, self.upnp.gateway.debug_commands()['available'])
def test_parse_gateway(self):
self.assertDictEqual(self.device.device_dict['gateway_dict'], self.upnp.gateway.as_dict())
@defer.inlineCallbacks
def test_commands(self):
method, args, expected = self.device.device_dict['soap'][0]
command1 = getattr(self.upnp, method)
result = yield command1(*tuple(args))
self.assertEqual(result, expected)
method, args, expected = self.device.device_dict['soap'][1]
command2 = getattr(self.upnp, method)
result = yield command2(*tuple(args))
result = [[i for i in r] for r in result]
self.assertListEqual(result, expected)
method, args, expected = self.device.device_dict['soap'][2]
command3 = getattr(self.upnp, method)
result = yield command3(*tuple(args))
self.assertEqual(result, expected)
method, args, expected = self.device.device_dict['soap'][3]
command4 = getattr(self.upnp, method)
result = yield command4(*tuple(args))
result = [r for r in result]
self.assertEqual(result, expected)
method, args, expected = self.device.device_dict['soap'][4]
command5 = getattr(self.upnp, method)
result = yield command5(*tuple(args))
self.assertEqual(result, expected)

View file

@ -1,45 +0,0 @@
from twisted.internet import reactor, defer
from twisted.trial import unittest
from txupnp.constants import SSDP_PORT
from txupnp import mocks
from txupnp.ssdp import SSDPFactory
class TestDiscoverGateway(unittest.TestCase):
router_address = '10.0.0.1'
client_address = '10.0.0.10'
service_name = 'WANCommonInterfaceConfig:1'
st = 'urn:schemas-upnp-org:service:%s' % service_name
port = 49152
location = 'InternetGatewayDevice.xml'
usn = 'uuid:00000000-0000-0000-0000-000000000000::%s' % st
version = 'Linux, UPnP/1.0, DIR-890L Ver 1.20'
expected_devices = [
{
'cache_control': 'max-age=1800',
'location': 'http://%s:%i/%s' % (router_address, port, location),
'server': version,
'st': st,
'usn': usn
}
]
def setUp(self):
fake_reactor = mocks.MockReactor()
reactor.listenMulticast = fake_reactor.listenMulticast
self.reactor = reactor
server_protocol = mocks.MockSSDPServiceGatewayProtocol(
self.router_address, self.service_name, self.st, self.port, self.location, self.usn, self.version
)
self.server_port = self.reactor.listenMulticast(SSDP_PORT, server_protocol)
def tearDown(self):
self.server_port.stopListening()
@defer.inlineCallbacks
def test_discover(self):
client_factory = SSDPFactory(self.reactor, self.client_address, self.router_address)
result = yield client_factory.m_search(self.router_address)
self.assertListEqual(self.expected_devices, result)
client_factory.disconnect()

View file

@ -1,3 +1,5 @@
import os
import json
import argparse
import logging
from twisted.internet import reactor, defer
@ -6,11 +8,6 @@ from txupnp.upnp import UPnP
log = logging.getLogger("txupnp")
def debug_device(u, include_gateway_xml=False, *_):
print(u.get_debug_info(include_gateway_xml=include_gateway_xml))
return defer.succeed(None)
@defer.inlineCallbacks
def get_external_ip(u, *_):
ip = yield u.get_external_ip()
@ -30,7 +27,7 @@ def list_mappings(u, *_):
@defer.inlineCallbacks
def add_mapping(u, *_):
port = 4567
port = 51413
protocol = "UDP"
description = "txupnp test mapping"
ext_port = yield u.get_next_mapping(port, protocol, description)
@ -50,12 +47,64 @@ def delete_mapping(u, *_):
print("removed mapping")
def _encode(x):
if isinstance(x, bytes):
return x.decode()
elif isinstance(x, Exception):
return str(x)
return x
@defer.inlineCallbacks
def generate_test_data(u, *_):
external_ip = yield u.get_external_ip()
redirects = yield u.get_redirects()
ext_port = yield u.get_next_mapping(4567, "UDP", "txupnp test mapping")
delete = yield u.delete_port_mapping(ext_port, "UDP")
after_delete = yield u.get_specific_port_mapping(ext_port, "UDP")
commands_test_case = (
("get_external_ip", (), "1.2.3.4"),
("get_redirects", (), redirects),
("get_next_mapping", (4567, "UDP", "txupnp test mapping"), ext_port),
("delete_port_mapping", (ext_port, "UDP"), delete),
("get_specific_port_mapping", (ext_port, "UDP"), after_delete),
)
gateway = u.gateway
device = list(gateway.devices.values())[0]
assert device.manufacturer and device.modelName
device_path = os.path.join(os.getcwd(), "%s %s" % (device.manufacturer, device.modelName))
commands = gateway.debug_commands()
with open(device_path, "w") as f:
f.write(json.dumps({
"router_address": u.router_ip,
"client_address": u.lan_address,
"port": gateway.port,
"gateway_dict": gateway.as_dict(),
'expected_devices': [
{
'cache_control': 'max-age=1800',
'location': gateway.location,
'server': gateway.server,
'st': gateway.urn,
'usn': gateway.usn
}
],
'commands': commands,
'ssdp': u.sspd_factory.get_ssdp_packet_replay(),
'scpd': gateway.requester.dump_packets(),
'soap': commands_test_case
}, default=_encode, indent=2).replace(external_ip, "1.2.3.4"))
print("Generated test data! -> %s" % device_path)
cli_commands = {
"debug_device": debug_device,
"get_external_ip": get_external_ip,
"list_mappings": list_mappings,
"add_mapping": add_mapping,
"delete_mapping": delete_mapping
"delete_mapping": delete_mapping,
"generate_test_data": generate_test_data,
}
@ -85,9 +134,9 @@ def main():
parser.add_argument("--include_igd_xml", dest="include_igd_xml", default=False, action="store_true")
args = parser.parse_args()
if args.debug_logging:
from twisted.python import log as tx_log
observer = tx_log.PythonLoggingObserver(loggerName="txupnp")
observer.start()
# from twisted.python import log as tx_log
# observer = tx_log.PythonLoggingObserver(loggerName="txupnp")
# observer.start()
log.setLevel(logging.DEBUG)
command = args.command
command = command.replace("-", "_")
@ -98,7 +147,7 @@ def main():
def show(err):
print("error: {}".format(err))
u = UPnP(reactor)
u = UPnP(reactor, debug_ssdp=(command == "generate_test_data"))
d = u.discover()
d.addCallback(run_command, u, command, args.include_igd_xml)
d.addErrback(show)

137
txupnp/commands.py Normal file
View file

@ -0,0 +1,137 @@
from txupnp.util import return_types, none_or_str, none
class SCPDCommands: # TODO use type annotations
def debug_commands(self) -> dict:
raise NotImplementedError()
@staticmethod
@return_types(none)
def AddPortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int,
NewInternalClient: str, NewEnabled: bool, NewPortMappingDescription: str,
NewLeaseDuration: str = '') -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(bool, bool)
def GetNATRSIPStatus() -> (bool, bool):
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
raise NotImplementedError()
@staticmethod
@return_types(none_or_str, int, str, int, str, bool, str, int)
def GetGenericPortMappingEntry(NewPortMappingIndex) -> (none_or_str, int, str, int, str, bool, str, int):
"""
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
NewPortMappingDescription, NewLeaseDuration)
"""
raise NotImplementedError()
@staticmethod
@return_types(int, str, bool, str, int)
def GetSpecificPortMappingEntry(NewRemoteHost, NewExternalPort, NewProtocol) -> (int, str, bool, str, int):
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def SetConnectionType(NewConnectionType) -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(str)
def GetExternalIPAddress() -> str:
"""Returns (NewExternalIPAddress)"""
raise NotImplementedError()
@staticmethod
@return_types(str, str)
def GetConnectionTypeInfo() -> (str, str):
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
raise NotImplementedError()
@staticmethod
@return_types(str, str, int)
def GetStatusInfo() -> (str, str, int):
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def ForceTermination() -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def DeletePortMapping(NewRemoteHost, NewExternalPort, NewProtocol) -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def RequestConnection() -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
def GetCommonLinkProperties():
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
raise NotImplementedError()
@staticmethod
def GetTotalBytesSent():
"""Returns (NewTotalBytesSent)"""
raise NotImplementedError()
@staticmethod
def GetTotalBytesReceived():
"""Returns (NewTotalBytesReceived)"""
raise NotImplementedError()
@staticmethod
def GetTotalPacketsSent():
"""Returns (NewTotalPacketsSent)"""
raise NotImplementedError()
@staticmethod
def GetTotalPacketsReceived():
"""Returns (NewTotalPacketsReceived)"""
raise NotImplementedError()
@staticmethod
def X_GetICSStatistics():
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
raise NotImplementedError()
@staticmethod
def GetDefaultConnectionService():
"""Returns (NewDefaultConnectionService)"""
raise NotImplementedError()
@staticmethod
def SetDefaultConnectionService(NewDefaultConnectionService) -> None:
"""Returns (None)"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def SetEnabledForInternet(NewEnabledForInternet) -> None:
raise NotImplementedError()
@staticmethod
@return_types(bool)
def GetEnabledForInternet() -> bool:
raise NotImplementedError()
@staticmethod
def GetMaximumActiveConnections(NewActiveConnectionIndex):
raise NotImplementedError()
@staticmethod
@return_types(str, str)
def GetActiveConnections() -> (str, str):
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
raise NotImplementedError()

View file

@ -20,15 +20,7 @@ IP_SCHEMA = 'urn:schemas-upnp-org:service:WANIPConnection:1'
service_types = [
UPNP_ORG_IGD,
WIFI_ALLIANCE_ORG_IGD,
WAN_SCHEMA,
LAYER_SCHEMA,
IP_SCHEMA,
CONTROL,
SERVICE,
DEVICE,
# WIFI_ALLIANCE_ORG_IGD,
]
SSDP_IP_ADDRESS = '239.255.255.250'

View file

@ -1,94 +0,0 @@
import logging
from twisted.web.client import HTTPConnectionPool, _HTTP11ClientFactory
from twisted.web._newclient import HTTPClientParser, BadResponseVersion, HTTP11ClientProtocol, RequestNotSent
from twisted.web._newclient import TransportProxyProducer, RequestGenerationFailed
from twisted.python.failure import Failure
from twisted.internet.defer import Deferred, fail, maybeDeferred
from twisted.internet.defer import CancelledError
log = logging.getLogger()
class DirtyHTTPParser(HTTPClientParser):
def parseVersion(self, strversion):
"""
Parse version strings of the form Protocol '/' Major '.' Minor. E.g.
b'HTTP/1.1'. Returns (protocol, major, minor). Will raise ValueError
on bad syntax.
"""
try:
proto, strnumber = strversion.split(b'/')
major, minor = strnumber.split(b'.')
major, minor = int(major), int(minor)
except ValueError as e:
log.exception("got a bad http version: %s", strversion)
if b'HTTP1.1' in strversion:
return ("HTTP", 1, 1)
raise BadResponseVersion(str(e), strversion)
if major < 0 or minor < 0:
raise BadResponseVersion(u"version may not be negative",
strversion)
return (proto, major, minor)
class DirtyHTTPClientProtocol(HTTP11ClientProtocol):
def request(self, request):
if self._state != 'QUIESCENT':
return fail(RequestNotSent())
self._state = 'TRANSMITTING'
_requestDeferred = maybeDeferred(request.writeTo, self.transport)
def cancelRequest(ign):
# Explicitly cancel the request's deferred if it's still trying to
# write when this request is cancelled.
if self._state in (
'TRANSMITTING', 'TRANSMITTING_AFTER_RECEIVING_RESPONSE'):
_requestDeferred.cancel()
else:
self.transport.abortConnection()
self._disconnectParser(Failure(CancelledError()))
self._finishedRequest = Deferred(cancelRequest)
# Keep track of the Request object in case we need to call stopWriting
# on it.
self._currentRequest = request
self._transportProxy = TransportProxyProducer(self.transport)
self._parser = DirtyHTTPParser(request, self._finishResponse)
self._parser.makeConnection(self._transportProxy)
self._responseDeferred = self._parser._responseDeferred
def cbRequestWritten(ignored):
if self._state == 'TRANSMITTING':
self._state = 'WAITING'
self._responseDeferred.chainDeferred(self._finishedRequest)
def ebRequestWriting(err):
if self._state == 'TRANSMITTING':
self._state = 'GENERATION_FAILED'
self.transport.abortConnection()
self._finishedRequest.errback(
Failure(RequestGenerationFailed([err])))
else:
self._log.failure(
u'Error writing request, but not in valid state '
u'to finalize request: {state}',
failure=err,
state=self._state
)
_requestDeferred.addCallbacks(cbRequestWritten, ebRequestWriting)
return self._finishedRequest
class DirtyHTTP11ClientFactory(_HTTP11ClientFactory):
def buildProtocol(self, addr):
return DirtyHTTPClientProtocol(self._quiescentCallback)
class DirtyPool(HTTPConnectionPool):
_factory = DirtyHTTP11ClientFactory

View file

@ -1,23 +1,12 @@
import logging
from twisted.internet import defer
import treq
import re
from xml.etree import ElementTree
from txupnp.util import etree_to_dict, flatten_keys, get_dict_val_case_insensitive
from txupnp.util import BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from txupnp.constants import DEVICE, ROOT
from txupnp.scpd import SCPDCommand, SCPDRequester
from txupnp.util import get_dict_val_case_insensitive, verify_return_types, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from txupnp.constants import SPEC_VERSION
from txupnp.commands import SCPDCommands
log = logging.getLogger(__name__)
service_type_pattern = re.compile(
"(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\})"
)
xml_root_sanity_pattern = re.compile(
"(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\}([\w|\d|\:|\-|\_]*))"
)
class CaseInsensitive:
def __init__(self, **kwargs):
@ -112,7 +101,7 @@ class Device(CaseInsensitive):
class Gateway:
def __init__(self, **kwargs):
def __init__(self, reactor, **kwargs):
flattened = {
k.lower(): v for k, v in kwargs.items()
}
@ -133,9 +122,10 @@ class Gateway:
self.date = date.encode()
self.urn = st.encode()
self._xml_response = ""
self._service_descriptors = {}
self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0]
self.port = int(BASE_PORT_REGEX.findall(self.location)[0])
self.xml_response = None
self.spec_version = None
self.url_base = None
@ -143,55 +133,65 @@ class Gateway:
self._devices = []
self._services = []
def debug_device(self, include_xml: bool = False, include_services: bool = True) -> dict:
r = {
'server': self.server,
'urlBase': self.url_base,
'location': self.location,
"specVersion": self.spec_version,
'usn': self.usn,
'urn': self.urn,
'devices': [device.as_dict() for device in self._devices]
}
if include_xml:
r['xml_response'] = self.xml_response
if include_services:
r['services'] = [service.as_dict() for service in self._services]
self._reactor = reactor
self._unsupported_actions = {}
self._registered_commands = {}
self.commands = SCPDCommands()
self.requester = SCPDRequester(self._reactor)
def as_dict(self) -> dict:
r = {
'server': self.server.decode(),
'urlBase': self.url_base,
'location': self.location.decode(),
"specVersion": self.spec_version,
'usn': self.usn.decode(),
'urn': self.urn.decode(),
}
return r
@defer.inlineCallbacks
def discover_services(self):
log.debug("querying %s", self.location)
response = yield treq.get(self.location)
self.xml_response = yield response.content()
if not self.xml_response:
log.warning("service sent an empty reply\n%s", self.debug_device())
xml_dict = etree_to_dict(ElementTree.fromstring(self.xml_response))
schema_key = DEVICE
root = ROOT
if len(xml_dict) > 1:
log.warning(xml_dict.keys())
for k in xml_dict.keys():
m = xml_root_sanity_pattern.findall(k)
if len(m) == 3 and m[1][0] and m[2][5]:
schema_key = m[1][0]
root = m[2][5]
break
flattened_xml = flatten_keys(xml_dict, "{%s}" % schema_key)[root]
self.spec_version = get_dict_val_case_insensitive(flattened_xml, SPEC_VERSION)
self.url_base = get_dict_val_case_insensitive(flattened_xml, "urlbase")
if flattened_xml:
def discover_commands(self):
response = yield self.requester.scpd_get(self.location.decode().split(self.base_address.decode())[1], self.base_address.decode(), self.port)
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
self.url_base = get_dict_val_case_insensitive(response, "urlbase")
if not self.url_base:
self.url_base = self.base_address.decode()
if response:
self._device = Device(
self._devices, self._services, **get_dict_val_case_insensitive(flattened_xml, "device")
self._devices, self._services, **get_dict_val_case_insensitive(response, "device")
)
log.debug("finished setting up root gateway. %i devices and %i services", len(self.devices),
len(self.services))
else:
self._device = Device(self._devices, self._services)
log.debug("finished setting up gateway:\n%s", self.debug_device())
for service_type in self.services.keys():
service = self.services[service_type]
yield self.register_commands(service)
@defer.inlineCallbacks
def register_commands(self, service: Service):
try:
action_list = yield self.requester.scpd_get_supported_actions(service, self.base_address.decode(), self.port)
except Exception as err:
log.exception("failed to register service %s: %s", service.serviceType, str(err))
return
for name, inputs, outputs in action_list:
try:
command = SCPDCommand(self.requester, self.base_address, self.port,
service.controlURL.encode(),
service.serviceType.encode(), name, inputs, outputs)
current = getattr(self.commands, command.method)
if hasattr(current, "_return_types"):
command._process_result = verify_return_types(*current._return_types)(command._process_result)
setattr(command, "__doc__", current.__doc__)
setattr(self.commands, command.method, command)
self._registered_commands[command.method] = service.serviceType
log.debug("registered %s::%s", service.serviceType, command.method)
except AttributeError:
s = self._unsupported_actions.get(service.serviceType, [])
s.append(name)
self._unsupported_actions[service.serviceType] = s
log.debug("available command for %s does not have a wrapper implemented: %s %s %s",
service.serviceType, name, inputs, outputs)
@property
def services(self) -> dict:
@ -205,7 +205,13 @@ class Gateway:
return {}
return {device.udn: device for device in self._devices}
def get_service(self, service_type) -> Service:
def get_service(self, service_type: str) -> Service:
for service in self._services:
if service.serviceType.lower() == service_type.lower():
return service
def debug_commands(self):
return {
'available': self._registered_commands,
'failed': self._unsupported_actions
}

View file

@ -1,44 +1,118 @@
import os
import json
import logging
import binascii
from twisted.internet import task
from txupnp.constants import SSDP_IP_ADDRESS
from txupnp.fault import UPnPError
from txupnp.ssdp_datagram import SSDPDatagram
from twisted.internet import task, defer
from twisted.internet.error import ConnectionDone
from twisted.internet.protocol import DatagramProtocol
from twisted.python.failure import Failure
from twisted.test.proto_helpers import _FakePort
from txupnp.ssdp_datagram import SSDPDatagram
log = logging.getLogger()
class MockResponse:
def __init__(self, content):
self._content = content
self.headers = {}
def content(self):
return defer.succeed(self._content)
class MockDevice:
def __init__(self, manufacturer, model):
self.manufacturer = manufacturer
self.model = model
device_path = os.path.join(
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "tests", "devices"), "{} {}".format(manufacturer, model)
)
assert os.path.isfile(device_path)
with open(device_path, "r") as f:
self.device_dict = json.loads(f.read())
def __repr__(self):
return "MockDevice(manufacturer={}, model={})".format(self.manufacturer, self.model)
def get_mock_devices():
return [
MockDevice(path.split(" ")[0], path.split(" ")[1])
for path in os.listdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), "devices"))
if ".py" not in path and "pycache" not in path
]
def get_device_test_case(manufacturer: str, model: str) -> MockDevice:
r = [
MockDevice(path.split(" ")[0], path.split(" ")[1])
for path in os.listdir(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "tests", "devices"))
if ".py" not in path and "pycache" not in path and path.split(" ") == [manufacturer, model]
]
return r[0]
class MockMulticastTransport:
def __init__(self, address, port, max_packet_size, network):
def __init__(self, address, port, max_packet_size, network, protocol):
self.address = address
self.port = port
self.max_packet_size = max_packet_size
self._network = network
self._protocol = protocol
def write(self, data, address):
if address in self._network.peers:
for dest in self._network.peers[address]:
dest.datagramReceived(data, (self.address, self.port))
else: # the node is sending to an address that doesnt currently exist, act like it never arrived
pass
if address[0] in self._network.group:
destinations = self._network.group[address[0]]
else:
destinations = address[0]
for address, dest in self._network.peers.items():
if address[0] in destinations and dest.address != self.address:
dest._protocol.datagramReceived(data, (self.address, self.port))
def setTTL(self, ttl):
pass
def joinGroup(self, address, interface=None):
pass
group = self._network.group.get(address, [])
group.append(interface)
self._network.group[address] = group
def leaveGroup(self, address, interface=None):
pass
group = self._network.group.get(address, [])
if interface in group:
group.remove(interface)
self._network.group[address] = group
class MockMulticastPort(object):
def __init__(self, protocol, remover):
class MockTCPTransport(_FakePort):
def __init__(self, address, port, callback, mock_requests):
super().__init__((address, port))
self._callback = callback
self._mock_requests = mock_requests
def write(self, data):
if data.startswith(b"POST"):
for url, packets in self._mock_requests['POST'].items():
for request_response in packets:
if data.decode() == request_response['request']:
self._callback(request_response['response'].encode())
return
elif data.startswith(b"GET"):
for url, packets in self._mock_requests['GET'].items():
if data.decode() == packets['request']:
self._callback(packets['response'].encode())
return
class MockMulticastPort(_FakePort):
def __init__(self, protocol, remover, address, transport):
super().__init__((address, 1900))
self.protocol = protocol
self._remover = remover
self.transport = transport
def startListening(self, reason=None):
self.protocol.transport = self.transport
return self.protocol.startProtocol()
def stopListening(self, reason=None):
@ -49,57 +123,55 @@ class MockMulticastPort(object):
class MockNetwork:
def __init__(self):
self.peers = {} # (interface, port): (protocol, max_packet_size)
self.peers = {}
self.group = {}
def add_peer(self, port, protocol, interface, maxPacketSize):
protocol.transport = MockMulticastTransport(interface, port, maxPacketSize, self)
peers = self.peers.get((interface, port), [])
peers.append(protocol)
self.peers[(interface, port)] = peers
transport = MockMulticastTransport(interface, port, maxPacketSize, self, protocol)
self.peers[(interface, port)] = transport
def remove_peer():
if self.peers.get((interface, port)):
self.peers[(interface, port)].remove(protocol)
if not self.peers.get((interface, port)):
del self.peers[(interface, port)]
del protocol.transport
return remove_peer
return transport, remove_peer
class MockReactor(task.Clock):
def __init__(self):
def __init__(self, client_addr, mock_scpd_requests):
super().__init__()
self.client_addr = client_addr
self._mock_scpd_requests = mock_scpd_requests
self.network = MockNetwork()
def listenMulticast(self, port, protocol, interface=SSDP_IP_ADDRESS, maxPacketSize=8192, listenMultiple=True):
remover = self.network.add_peer(port, protocol, interface, maxPacketSize)
port = MockMulticastPort(protocol, remover)
def listenMulticast(self, port, protocol, interface=None, maxPacketSize=8192, listenMultiple=True):
interface = interface or self.client_addr
transport, remover = self.network.add_peer(port, protocol, interface, maxPacketSize)
port = MockMulticastPort(protocol, remover, interface, transport)
port.startListening()
return port
def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
protocol = factory.buildProtocol(host)
def _write_and_close(data):
protocol.dataReceived(data)
protocol.connectionLost(Failure(ConnectionDone()))
protocol.transport = MockTCPTransport(host, port, _write_and_close, self._mock_scpd_requests)
protocol.connectionMade()
class MockSSDPServiceGatewayProtocol(DatagramProtocol):
def __init__(self, iface, service_name, st, port, location, usn, version):
def __init__(self, client_addr: int, iface: str, packets_rx: list, packets_tx: list):
self.client_addr = client_addr
self.iface = iface
self.service_name = service_name
self.gateway_st = st
self.gateway_location = location
self.gateway_usn = usn
self.gateway_version = version
self.gateway_port = port
self.packets_tx = [SSDPDatagram.decode(packet.encode()) for packet in packets_tx] # sent by client
self.packets_rx = [((addr, port), SSDPDatagram.decode(packet.encode())) for (addr, port), packet in packets_rx] # rx by client
def datagramReceived(self, datagram, address):
try:
packet = SSDPDatagram.decode(datagram)
except UPnPError as err:
log.error("failed to decode SSDP packet from %s:%i: %s\npacket: %s", address[0], address[1], err,
binascii.hexlify(datagram))
return
if packet._packet_type == packet._M_SEARCH:
if packet.man == "ssdp:discover" and packet.st == self.gateway_st:
location = 'http://{}:{}/{}'.format(self.iface, self.gateway_port, self.gateway_location)
response = SSDPDatagram(SSDPDatagram._OK, st=self.gateway_st, cache_control='max-age=1800',
location=location, server=self.gateway_version, usn=self.gateway_usn)
self.transport.write(response.encode().encode(), address)
packet = SSDPDatagram.decode(datagram)
if packet.st in map(lambda p: p[1].st, self.packets_rx): # this contains one of the service types the server replied to
reply = list(filter(lambda p: p[1].st == packet.st, self.packets_rx))[0][1]
self.transport.write(reply.encode().encode(), (self.client_addr, 1900))
else:
pass

View file

@ -1,48 +1,224 @@
import re
import logging
from collections import OrderedDict
from twisted.internet import defer, threads
from twisted.web.client import Agent
import treq
from treq.client import HTTPClient
from xml.etree import ElementTree
from txupnp.util import etree_to_dict, flatten_keys, return_types, verify_return_types, none_or_str, none
from twisted.internet.protocol import Protocol, ClientFactory
from twisted.internet import defer, error
from txupnp.constants import XML_VERSION, DEVICE, ROOT, SERVICE, ENVELOPE, BODY
from txupnp.util import etree_to_dict, flatten_keys
from txupnp.fault import handle_fault, UPnPError
from txupnp.constants import SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types, ENVELOPE, XML_VERSION
from txupnp.constants import BODY, POST
from txupnp.dirty_pool import DirtyPool
log = logging.getLogger(__name__)
CONTENT_PATTERN = re.compile(
"(\<\?xml version=\"1\.0\"\?\>(\s*.)*|\>)".encode()
)
CONTENT_NO_XML_VERSION_PATTERN = re.compile(
"(\<s\:Envelope xmlns\:s=\"http\:\/\/schemas\.xmlsoap\.org\/soap\/envelope\/\"(\s*.)*\>)".encode()
)
class StringProducer:
def __init__(self, body):
self.body = body
self.length = len(body)
def startProducing(self, consumer):
consumer.write(self.body)
return defer.succeed(None)
def pauseProducing(self):
pass
def stopProducing(self):
pass
XML_ROOT_SANITY_PATTERN = re.compile(
"(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\}([\w|\d|\:|\-|\_]*))"
)
def xml_arg(name, arg):
return "<%s>%s</%s>" % (name, arg, name)
def parse_service_description(content: bytes):
element_dict = etree_to_dict(ElementTree.fromstring(content.decode()))
service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
if "scpd" not in service_info:
return []
action_list = service_info["scpd"]["actionList"]
if not len(action_list): # it could be an empty string
return []
result = []
if isinstance(action_list["action"], dict):
arg_dicts = action_list["action"]['argumentList']['argument']
if not isinstance(arg_dicts, list): # when there is one arg, ew
arg_dicts = [arg_dicts]
return [[
action_list["action"]['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
]]
for action in action_list["action"]:
if not action.get('argumentList'):
result.append((action['name'], [], []))
else:
arg_dicts = action['argumentList']['argument']
if not isinstance(arg_dicts, list): # when there is one arg, ew
arg_dicts = [arg_dicts]
result.append((
action['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
))
return result
def get_soap_body(service_name, method, param_names, **kwargs):
args = "".join(xml_arg(n, kwargs.get(n)) for n in param_names)
return '\n%s\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" s:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body><u:%s xmlns:u="%s">%s</u:%s></s:Body></s:Envelope>' % (XML_VERSION, method, service_name, args, method)
class SCPDHTTPClientProtocol(Protocol):
def connectionMade(self):
self.response_buff = b""
self.factory.reactor.callLater(0, self.transport.write, self.factory.packet)
def dataReceived(self, data):
self.response_buff += data
def connectionLost(self, reason):
if reason.trap(error.ConnectionDone):
if XML_VERSION.encode() in self.response_buff:
parsed = CONTENT_PATTERN.findall(self.response_buff)
result = b'' if not parsed else parsed[0][0]
self.factory.finished_deferred.callback(result)
else:
parsed = CONTENT_NO_XML_VERSION_PATTERN.findall(self.response_buff)
result = b'' if not parsed else XML_VERSION.encode() + b'\r\n' + parsed[0][0]
self.factory.finished_deferred.callback(result)
class _SCPDCommand(object):
def __init__(self, http_client, gateway_address, service_port, control_url, service_id, method, param_names,
class SCPDHTTPClientFactory(ClientFactory):
protocol = SCPDHTTPClientProtocol
def __init__(self, reactor, packet):
self.reactor = reactor
self.finished_deferred = defer.Deferred()
self.packet = packet
def buildProtocol(self, addr):
p = self.protocol()
p.factory = self
return p
@classmethod
def post(cls, reactor, command, **kwargs):
args = "".join("<%s>%s</%s>" % (n, kwargs.get(n), n) for n in command.param_names)
soap_body = ('\r\n%s\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" '
's:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>'
'<u:%s xmlns:u="%s">%s</u:%s></s:Body></s:Envelope>' % (
XML_VERSION, command.method, command.service_id.decode(),
args, command.method))
data = (
(
'POST %s HTTP/1.1\r\n'
'Host: %s\r\n'
'User-Agent: Debian/buster/sid, UPnP/1.0, MiniUPnPc/1.9\r\n'
'Content-Length: %i\r\n'
'Content-Type: text/xml\r\n'
'SOAPAction: \"%s#%s\"\r\n'
'Connection: Close\r\n'
'Cache-Control: no-cache\r\n'
'Pragma: no-cache\r\n'
'%s'
'\r\n'
) % (
command.control_url.decode(), # could be just / even if it shouldn't be
command.gateway_address.decode().split("http://")[1],
len(soap_body),
command.service_id.decode(), # maybe no quotes
command.method,
soap_body
)
).encode()
return cls(reactor, data)
@classmethod
def get(cls, reactor, control_url: str, address: str):
data = (
(
'GET %s HTTP/1.1\r\n'
'Accept-Encoding: gzip\r\n'
'Host: %s\r\n'
'\r\n'
) % (control_url, address)
).encode()
return cls(reactor, data)
class SCPDRequester:
client_factory = SCPDHTTPClientFactory
def __init__(self, reactor):
self._reactor = reactor
self._get_requests = {}
self._post_requests = {}
def _save_get(self, request: bytes, response: bytes, destination: str) -> None:
self._get_requests[destination.lstrip("/")] = {
'request': request,
'response': response
}
def _save_post(self, request: bytes, response: bytes, destination: str) -> None:
p = self._post_requests.get(destination.lstrip("/"), [])
p.append({
'request': request,
'response': response,
})
self._post_requests[destination.lstrip("/")] = p
@defer.inlineCallbacks
def _scpd_get_soap_xml(self, control_url: str, address: str, service_port: int) -> bytes:
factory = self.client_factory.get(self._reactor, control_url, address)
url = address.split("http://")[1].split(":")[0]
self._reactor.connectTCP(url, service_port, factory)
xml_response_bytes = yield factory.finished_deferred
self._save_get(factory.packet, xml_response_bytes, control_url)
return xml_response_bytes
@defer.inlineCallbacks
def scpd_post_soap(self, command, **kwargs) -> tuple:
factory = self.client_factory.post(self._reactor, command, **kwargs)
url = command.gateway_address.split(b"http://")[1].split(b":")[0]
self._reactor.connectTCP(url.decode(), command.service_port, factory)
xml_response_bytes = yield factory.finished_deferred
self._save_post(
factory.packet, xml_response_bytes, command.gateway_address.decode() + command.control_url.decode()
)
content_dict = etree_to_dict(ElementTree.fromstring(xml_response_bytes.decode()))
envelope = content_dict[ENVELOPE]
response_body = flatten_keys(envelope[BODY], "{%s}" % command.service_id)
body = handle_fault(response_body) # raises UPnPError if there is a fault
response_key = None
for key in body:
if command.method in key:
response_key = key
break
if not response_key:
raise UPnPError("unknown response fields for %s")
response = body[response_key]
extracted_response = tuple([response[n] for n in command.returns])
return extracted_response
@defer.inlineCallbacks
def scpd_get_supported_actions(self, service, address: str, port: int) -> list:
xml_bytes = yield self._scpd_get_soap_xml(service.SCPDURL, address, port)
return parse_service_description(xml_bytes)
@defer.inlineCallbacks
def scpd_get(self, control_url: str, service_address: str, service_port: int) -> dict:
xml_bytes = yield self._scpd_get_soap_xml(control_url, service_address, service_port)
xml_dict = etree_to_dict(ElementTree.fromstring(xml_bytes.decode()))
schema_key = DEVICE
root = ROOT
for k in xml_dict.keys():
m = XML_ROOT_SANITY_PATTERN.findall(k)
if len(m) == 3 and m[1][0] and m[2][5]:
schema_key = m[1][0]
root = m[2][5]
break
flattened_xml = flatten_keys(xml_dict, "{%s}" % schema_key)[root]
return flattened_xml
def dump_packets(self) -> dict:
return {
'GET': self._get_requests,
'POST': self._post_requests
}
class SCPDCommand:
def __init__(self, scpd_requester: SCPDRequester, gateway_address, service_port, control_url, service_id, method,
param_names,
returns):
self._http_client = http_client
self.scpd_requester = scpd_requester
self.gateway_address = gateway_address
self.service_port = service_port
self.control_url = control_url
@ -51,59 +227,8 @@ class _SCPDCommand(object):
self.param_names = param_names
self.returns = returns
def extract_body(self, xml_response):
content_dict = etree_to_dict(ElementTree.fromstring(xml_response))
envelope = content_dict[ENVELOPE]
return flatten_keys(envelope[BODY], "{%s}" % self.service_id)
def extract_response(self, body):
body = handle_fault(body) # raises UPnPError if there is a fault
response_key = None
for key in body:
if self.method in key:
response_key = key
break
if not response_key:
raise UPnPError("unknown response fields for %s")
response = body[response_key]
extracted_response = tuple([response[n] for n in self.returns])
if len(extracted_response) == 1:
return extracted_response[0]
return extracted_response
@defer.inlineCallbacks
def send_upnp_soap(self, **kwargs):
soap_body = get_soap_body(self.service_id, self.method, self.param_names, **kwargs).encode()
headers = OrderedDict((
('SOAPAction', '%s#%s' % (self.service_id, self.method)),
('Host', ('%s:%i' % (SSDP_IP_ADDRESS, self.service_port))),
('Content-Type', 'text/xml'),
('Content-Length', len(soap_body))
))
log.debug("send POST to %s\nheaders: %s\nbody:%s\n", self.control_url, headers, soap_body)
try:
response = yield self._http_client.request(
POST, url=self.control_url, data=soap_body, headers=headers
)
except Exception as err:
log.error("error (%s) sending POST to %s\nheaders: %s\nbody:%s\n", err, self.control_url, headers,
soap_body)
raise UPnPError(err)
xml_response = yield response.content()
try:
response = self.extract_response(self.extract_body(xml_response))
except UPnPError:
raise
except Exception as err:
log.debug("error extracting response (%s) to %s:\n%s", err, self.method, xml_response)
raise err
if not response:
log.debug("empty response to %s\n%s", self.method, xml_response)
defer.returnValue(response)
@staticmethod
def _process_result(results):
def _process_result(*results):
"""
this method gets decorated automatically with a function that maps result types to the types
defined in the @return_types decorator
@ -114,367 +239,10 @@ class _SCPDCommand(object):
def __call__(self, **kwargs):
if set(kwargs.keys()) != set(self.param_names):
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_names))
response = yield self.send_upnp_soap(**kwargs)
response = yield self.scpd_requester.scpd_post_soap(self, **kwargs)
try:
result = self._process_result(response)
result = self._process_result(*response)
except Exception as err:
log.error("error formatting response (%s):\n%s", err, response)
raise err
defer.returnValue(result)
class SCPDResponse(object):
def __init__(self, url, headers, content):
self.url = url
self.headers = headers
self.content = content
def get_element_tree(self):
return ElementTree.fromstring(self.content)
def get_element_dict(self, service_key):
return flatten_keys(etree_to_dict(self.get_element_tree()), "{%s}" % service_key)
def get_action_list(self):
return self.get_element_dict(SERVICE)["scpd"]["actionList"]["action"]
def get_device_info(self):
return self.get_element_dict(DEVICE)[ROOT]
class SCPDCommandRunner(object):
def __init__(self, gateway, reactor, treq_get=None):
self._gateway = gateway
self._unsupported_actions = {}
self._registered_commands = {}
self._reactor = reactor
self._connection_pool = DirtyPool(reactor)
self._agent = Agent(reactor, connectTimeout=1, pool=self._connection_pool)
self._http_client = HTTPClient(self._agent, data_to_body_producer=StringProducer)
self._treq_get = treq_get or treq.get
@defer.inlineCallbacks
def _discover_commands(self, service):
scpd_url = self._gateway.base_address + service.SCPDURL.encode()
response = yield self._treq_get(scpd_url)
content = yield response.content()
try:
scpd_response = SCPDResponse(scpd_url, response.headers, content)
for action_dict in scpd_response.get_action_list():
self._register_command(action_dict, service.serviceType)
except Exception as err:
log.exception("failed to parse scpd response (%s) from %s\nheaders:\n%s\ncontent\n%s",
err, scpd_url, response.headers, content)
defer.returnValue(None)
@defer.inlineCallbacks
def discover_commands(self):
for service_type in service_types:
service = self._gateway.get_service(service_type)
if not service:
continue
yield self._discover_commands(service)
log.debug(self.debug_commands())
@staticmethod
def _soap_function_info(action_dict):
if not action_dict.get('argumentList'):
return (
action_dict['name'],
[],
[]
)
arg_dicts = action_dict['argumentList']['argument']
if not isinstance(arg_dicts, list): # when there is one arg, ew
arg_dicts = [arg_dicts]
return (
action_dict['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
)
def _patch_command(self, action_info, service_type):
name, inputs, outputs = self._soap_function_info(action_info)
command = _SCPDCommand(self._http_client, self._gateway.base_address, self._gateway.port,
self._gateway.base_address + self._gateway.get_service(service_type).controlURL.encode(),
self._gateway.get_service(service_type).serviceType.encode(), name, inputs, outputs)
current = getattr(self, command.method)
if hasattr(current, "_return_types"):
command._process_result = verify_return_types(*current._return_types)(command._process_result)
setattr(command, "__doc__", current.__doc__)
setattr(self, command.method, command)
self._registered_commands[command.method] = service_type
log.debug("registered %s %s", service_type, action_info['name'])
return True
def _register_command(self, action_info, service_type):
try:
return self._patch_command(action_info, service_type)
except Exception as err:
s = self._unsupported_actions.get(service_type, [])
s.append((action_info, err))
self._unsupported_actions[service_type] = s
log.error("available command for %s does not have a wrapper implemented: %s", service_type, action_info)
def debug_commands(self):
return {
'available': self._registered_commands,
'failed': self._unsupported_actions
}
@staticmethod
@return_types(none)
def AddPortMapping(NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient,
NewEnabled, NewPortMappingDescription, NewLeaseDuration=''):
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(bool, bool)
def GetNATRSIPStatus():
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
raise NotImplementedError()
@staticmethod
@return_types(none_or_str, int, str, int, str, bool, str, int)
def GetGenericPortMappingEntry(NewPortMappingIndex):
"""
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
NewPortMappingDescription, NewLeaseDuration)
"""
raise NotImplementedError()
@staticmethod
@return_types(int, str, bool, str, int)
def GetSpecificPortMappingEntry(NewRemoteHost, NewExternalPort, NewProtocol):
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def SetConnectionType(NewConnectionType):
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(str)
def GetExternalIPAddress():
"""Returns (NewExternalIPAddress)"""
raise NotImplementedError()
@staticmethod
@return_types(str, str)
def GetConnectionTypeInfo():
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
raise NotImplementedError()
@staticmethod
@return_types(str, str, int)
def GetStatusInfo():
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def ForceTermination():
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def DeletePortMapping(NewRemoteHost, NewExternalPort, NewProtocol):
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def RequestConnection():
"""Returns None"""
raise NotImplementedError()
@staticmethod
def GetCommonLinkProperties():
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
raise NotImplementedError()
@staticmethod
def GetTotalBytesSent():
"""Returns (NewTotalBytesSent)"""
raise NotImplementedError()
@staticmethod
def GetTotalBytesReceived():
"""Returns (NewTotalBytesReceived)"""
raise NotImplementedError()
@staticmethod
def GetTotalPacketsSent():
"""Returns (NewTotalPacketsSent)"""
raise NotImplementedError()
@staticmethod
def GetTotalPacketsReceived():
"""Returns (NewTotalPacketsReceived)"""
raise NotImplementedError()
@staticmethod
def X_GetICSStatistics():
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
raise NotImplementedError()
@staticmethod
def GetDefaultConnectionService():
"""Returns (NewDefaultConnectionService)"""
raise NotImplementedError()
@staticmethod
def SetDefaultConnectionService(NewDefaultConnectionService):
"""Returns (None)"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def SetEnabledForInternet(NewEnabledForInternet):
raise NotImplementedError()
@staticmethod
@return_types(bool)
def GetEnabledForInternet():
raise NotImplementedError()
@staticmethod
def GetMaximumActiveConnections(NewActiveConnectionIndex):
raise NotImplementedError()
@staticmethod
@return_types(str, str)
def GetActiveConnections():
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
raise NotImplementedError()
class UPnPFallback(object):
def __init__(self):
try:
import miniupnpc
self._upnp = miniupnpc.UPnP()
self.available = True
except ImportError:
self._upnp = None
self.available = False
@defer.inlineCallbacks
def discover(self):
if not self.available:
raise NotImplementedError()
devices = yield threads.deferToThread(self._upnp.discover)
if devices:
self.device_url = yield threads.deferToThread(self._upnp.selectigd)
else:
self.device_url = None
defer.returnValue(devices > 0)
@return_types(none)
def AddPortMapping(self, NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient,
NewEnabled, NewPortMappingDescription, NewLeaseDuration=''):
"""Returns None"""
if not self.available:
raise NotImplementedError()
return threads.deferToThread(self._upnp.addportmapping, NewExternalPort, NewProtocol, NewInternalClient,
NewInternalPort, NewPortMappingDescription, NewLeaseDuration)
def GetNATRSIPStatus(self):
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
raise NotImplementedError()
@return_types(none_or_str, int, str, int, str, bool, str, int)
@defer.inlineCallbacks
def GetGenericPortMappingEntry(self, NewPortMappingIndex):
"""
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
NewPortMappingDescription, NewLeaseDuration)
"""
if not self.available:
raise NotImplementedError()
result = yield threads.deferToThread(self._upnp.getgenericportmapping, NewPortMappingIndex)
if not result:
raise UPnPError()
ext_port, protocol, (int_host, int_port), desc, enabled, remote_host, lease = result
defer.returnValue((remote_host, ext_port, protocol, int_port, int_host, enabled, desc, lease))
@return_types(int, str, bool, str, int)
def GetSpecificPortMappingEntry(self, NewRemoteHost, NewExternalPort, NewProtocol):
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
if not self.available:
raise NotImplementedError()
return threads.deferToThread(self._upnp.getspecificportmapping, NewExternalPort, NewProtocol)
def SetConnectionType(self, NewConnectionType):
"""Returns None"""
raise NotImplementedError()
@return_types(str)
def GetExternalIPAddress(self):
"""Returns (NewExternalIPAddress)"""
if not self.available:
raise NotImplementedError()
return threads.deferToThread(self._upnp.externalipaddress)
def GetConnectionTypeInfo(self):
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
raise NotImplementedError()
@return_types(str, str, int)
def GetStatusInfo(self):
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
if not self.available:
raise NotImplementedError()
return threads.deferToThread(self._upnp.statusinfo)
def ForceTermination(self):
"""Returns None"""
raise NotImplementedError()
@return_types(none)
def DeletePortMapping(self, NewRemoteHost, NewExternalPort, NewProtocol):
"""Returns None"""
if not self.available:
raise NotImplementedError()
return threads.deferToThread(self._upnp.deleteportmapping, NewExternalPort, NewProtocol)
def RequestConnection(self):
"""Returns None"""
raise NotImplementedError()
def GetCommonLinkProperties(self):
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
raise NotImplementedError()
def GetTotalBytesSent(self):
"""Returns (NewTotalBytesSent)"""
raise NotImplementedError()
def GetTotalBytesReceived(self):
"""Returns (NewTotalBytesReceived)"""
raise NotImplementedError()
def GetTotalPacketsSent(self):
"""Returns (NewTotalPacketsSent)"""
raise NotImplementedError()
def GetTotalPacketsReceived(self):
"""Returns (NewTotalPacketsReceived)"""
raise NotImplementedError()
def X_GetICSStatistics(self):
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
raise NotImplementedError()
def GetDefaultConnectionService(self):
"""Returns (NewDefaultConnectionService)"""
raise NotImplementedError()
def SetDefaultConnectionService(self, NewDefaultConnectionService):
"""Returns (None)"""
raise NotImplementedError()

View file

@ -1,79 +0,0 @@
import logging
import netifaces
from twisted.internet import defer
from txupnp.ssdp import SSDPFactory
from txupnp.scpd import SCPDCommandRunner
from txupnp.gateway import Gateway
from txupnp.fault import UPnPError
from txupnp.constants import UPNP_ORG_IGD
log = logging.getLogger(__name__)
class SOAPServiceManager(object):
def __init__(self, reactor, treq_get=None):
self._reactor = reactor
self.router_ip, self.iface_name = netifaces.gateways()['default'][netifaces.AF_INET]
self.lan_address = netifaces.ifaddresses(self.iface_name)[netifaces.AF_INET][0]['addr']
self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip)
self._command_runners = {}
self._selected_runner = UPNP_ORG_IGD
self._treq_get = treq_get
@defer.inlineCallbacks
def discover_services(self, address=None, timeout=30, max_devices=1):
server_infos = yield self.sspd_factory.m_search(
address or self.router_ip, timeout=timeout, max_devices=max_devices
)
locations = []
for server_info in server_infos:
if 'st' in server_info and server_info['st'] not in self._command_runners:
locations.append(server_info['location'])
gateway = Gateway(**server_info)
yield gateway.discover_services()
command_runner = SCPDCommandRunner(gateway, self._reactor, self._treq_get)
yield command_runner.discover_commands()
self._command_runners[gateway.urn.decode()] = command_runner
elif 'st' not in server_info:
log.error("don't know how to handle gateway: %s", server_info)
continue
defer.returnValue(len(self._command_runners) > 0)
def set_runner(self, urn):
if urn not in self._command_runners:
raise IndexError(urn)
self._command_runners = urn
def get_runner(self):
if self._command_runners and not self._selected_runner in self._command_runners:
self._selected_runner = list(self._command_runners.keys())[0]
if not self._command_runners:
raise UPnPError("no devices found")
return self._command_runners[self._selected_runner]
def get_available_runners(self):
return self._command_runners.keys()
def debug(self, include_gateway_xml=False):
results = []
for runner in self._command_runners.values():
gateway = runner._gateway
info = gateway.debug_device(include_xml=include_gateway_xml)
commands = runner.debug_commands()
service_result = []
for service in info['services']:
service_commands = []
unavailable = []
for command, service_type in commands['available'].items():
if service['serviceType'] == service_type:
service_commands.append(command)
for command, service_type in commands['failed'].items():
if service['serviceType'] == service_type:
unavailable.append(command)
services_with_commands = dict(service)
services_with_commands['available_commands'] = service_commands
services_with_commands['unavailable_commands'] = unavailable
service_result.append(services_with_commands)
info['services'] = service_result
results.append(info)
return results

View file

@ -12,7 +12,8 @@ log = logging.getLogger(__name__)
class SSDPProtocol(DatagramProtocol):
def __init__(self, reactor, iface, router, ssdp_address=SSDP_IP_ADDRESS,
ssdp_port=SSDP_PORT, ttl=1, max_devices=None):
ssdp_port=SSDP_PORT, ttl=1, max_devices=None, debug_packets=False,
debug_sent=None, debug_received=None):
self._reactor = reactor
self._sem = defer.DeferredSemaphore(1)
self.discover_callbacks = {}
@ -24,34 +25,36 @@ class SSDPProtocol(DatagramProtocol):
self._start = None
self.max_devices = max_devices
self.devices = []
self.debug_packets = debug_packets
self.debug_sent = debug_sent if debug_sent is not None else []
self.debug_received = debug_received if debug_sent is not None else []
def _send_m_search(self, service=UPNP_ORG_IGD):
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, host=SSDP_HOST, st=service, man=SSDP_DISCOVER, mx=1)
log.debug("sending packet to %s:\n%s", SSDP_HOST, packet.encode())
try:
self.transport.write(packet.encode().encode(), (self.ssdp_address, self.ssdp_port))
msg_bytes = packet.encode().encode()
if self.debug_packets:
self.debug_sent.append(msg_bytes)
self.transport.write(msg_bytes, (self.ssdp_address, self.ssdp_port))
except Exception as err:
log.exception("failed to write %s to %s:%i", packet.encode(), self.ssdp_address, self.ssdp_port)
raise err
@staticmethod
def _gather(finished_deferred, max_results):
results = []
def _gather(finished_deferred, max_results, results: list):
def discover_cb(packet):
if not finished_deferred.called:
if not finished_deferred.called and packet.st in service_types:
results.append(packet.as_dict())
if len(results) >= max_results:
finished_deferred.callback(results)
return discover_cb
def m_search(self, address=None, timeout=1, max_devices=1):
address = address or self.iface
def m_search(self, address, timeout, max_devices):
# return deferred for a pending call if we have one
if address in self.discover_callbacks:
d = self.protocol.discover_callbacks[address][1]
d = self.discover_callbacks[address][1]
if not d.called: # the existing deferred has already fired, make a new one
return d
@ -63,7 +66,7 @@ class SSDPProtocol(DatagramProtocol):
d = defer.Deferred()
d.addTimeout(timeout, self._reactor)
d.addErrback(_trap_timeout_and_return_results)
found_cb = self._gather(d, max_devices)
found_cb = self._gather(d, max_devices, self.devices)
self.discover_callbacks[address] = found_cb, d
for st in service_types:
self._send_m_search(service=st)
@ -73,11 +76,12 @@ class SSDPProtocol(DatagramProtocol):
self._start = self._reactor.seconds()
self.transport.setTTL(self.ttl)
self.transport.joinGroup(self.ssdp_address, interface=self.iface)
self.m_search()
def datagramReceived(self, datagram, address):
if address[0] == self.iface:
return
if self.debug_packets:
self.debug_received.append((address, datagram))
try:
packet = SSDPDatagram.decode(datagram)
log.debug("decoded %s from %s:%i:\n%s", packet.get_friendly_name(), address[0], address[1], packet.encode())
@ -90,23 +94,30 @@ class SSDPProtocol(DatagramProtocol):
return
if packet._packet_type == packet._OK:
log.debug("%s:%i replied to our m-search with new xml url: %s", address[0], address[1], packet.location)
if packet.st not in map(lambda p: p['st'], self.devices):
self.devices.append(packet.as_dict())
log.debug("%i device%s so far", len(self.devices), "" if len(self.devices) < 2 else "s")
if address[0] in self.discover_callbacks:
# if address[0] in self.discover_callbacks and packet.location not in map(lambda p: p['location'], self.devices):
if packet.location not in map(lambda p: p['location'], self.devices):
if address[0] not in self.discover_callbacks:
self.devices.append(packet.as_dict())
else:
self._sem.run(self.discover_callbacks[address[0]][0], packet)
else:
log.info("ignored packet from %s:%s (%s) %s", address[0], address[1], packet._packet_type, packet.location)
elif packet._packet_type == packet._NOTIFY:
log.debug("%s:%i sent us a notification (type: %s), url: %s", address[0], address[1], packet.nts,
packet.location)
class SSDPFactory(object):
def __init__(self, reactor, lan_address, router_address):
class SSDPFactory:
def __init__(self, reactor, lan_address, router_address, debug_packets=False):
self.lan_address = lan_address
self.router_address = router_address
self._reactor = reactor
self.protocol = None
self.port = None
self.debug_packets = debug_packets
self.debug_sent = []
self.debug_received = []
self.server_infos = []
def disconnect(self):
if not self.port:
@ -118,13 +129,15 @@ class SSDPFactory(object):
def connect(self):
if not self.protocol:
self.protocol = SSDPProtocol(self._reactor, self.lan_address, self.router_address)
self.protocol = SSDPProtocol(self._reactor, self.lan_address, self.router_address,
debug_packets=self.debug_packets, debug_sent=self.debug_sent,
debug_received=self.debug_received)
if not self.port:
self._reactor.addSystemEventTrigger("before", "shutdown", self.disconnect)
self.port = self._reactor.listenMulticast(self.protocol.ssdp_port, self.protocol, listenMultiple=True)
@defer.inlineCallbacks
def m_search(self, address, timeout=1, max_devices=1):
def m_search(self, address, timeout, max_devices):
"""
Perform a M-SEARCH (HTTP over UDP) and gather the results
@ -140,7 +153,16 @@ class SSDPFactory(object):
'usn': (str) usn
}, ...]
"""
self.connect()
server_infos = yield self.protocol.m_search(address, timeout, max_devices)
for server_info in server_infos:
self.server_infos.append(server_info)
defer.returnValue(server_infos)
def get_ssdp_packet_replay(self) -> dict:
return {
'lan_address': self.lan_address,
'router_address': self.router_address,
'sent': self.debug_sent,
'received': self.debug_received,
}

View file

@ -1,45 +1,29 @@
import netifaces
import logging
import json
from twisted.internet import defer
from txupnp.fault import UPnPError
from txupnp.soap import SOAPServiceManager
from txupnp.scpd import UPnPFallback
from txupnp.ssdp import SSDPFactory
from txupnp.gateway import Gateway
log = logging.getLogger(__name__)
class UPnP(object):
def __init__(self, reactor, try_miniupnpc_fallback=True, treq_get=None):
class UPnP:
def __init__(self, reactor, try_miniupnpc_fallback=False, debug_ssdp=False, router_ip=None,
lan_ip=None, iface_name=None):
self._reactor = reactor
if router_ip and lan_ip and iface_name:
self.router_ip, self.lan_address, self.iface_name = router_ip, lan_ip, iface_name
else:
self.router_ip, self.iface_name = netifaces.gateways()['default'][netifaces.AF_INET]
self.lan_address = netifaces.ifaddresses(self.iface_name)[netifaces.AF_INET][0]['addr']
self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip, debug_packets=debug_ssdp)
self.try_miniupnpc_fallback = try_miniupnpc_fallback
self.soap_manager = SOAPServiceManager(reactor, treq_get=treq_get)
self.miniupnpc_runner = None
self.miniupnpc_igd_url = None
self.gateway = None
@property
def lan_address(self):
return self.soap_manager.lan_address
@property
def commands(self):
try:
runner = self.soap_manager.get_runner()
required_commands = [
"GetExternalIPAddress",
"AddPortMapping",
"GetSpecificPortMappingEntry",
"GetGenericPortMappingEntry",
"DeletePortMapping"
]
if all((command in runner._registered_commands for command in required_commands)):
return runner
raise UPnPError("required commands not found")
except UPnPError as err:
if self.try_miniupnpc_fallback and self.miniupnpc_runner:
return self.miniupnpc_runner
log.warning("upnp is not available: %s", err)
def m_search(self, address, timeout=30, max_devices=2):
def m_search(self, address, timeout=1, max_devices=1):
"""
Perform a HTTP over UDP M-SEARCH query
@ -51,68 +35,52 @@ class UPnP(object):
'usn': <usn>
}, ...]
"""
return self.soap_manager.sspd_factory.m_search(address, timeout=timeout, max_devices=max_devices)
return self.sspd_factory.m_search(address, timeout=timeout, max_devices=max_devices)
@defer.inlineCallbacks
def discover(self, timeout=1, max_devices=1, keep_listening=False, try_txupnp=True):
found = False
if not try_txupnp and not self.try_miniupnpc_fallback:
log.warning("nothing left to try")
if try_txupnp:
try:
found = yield self.soap_manager.discover_services(timeout=timeout, max_devices=max_devices)
except defer.TimeoutError:
found = False
finally:
if not keep_listening:
self.soap_manager.sspd_factory.disconnect()
def _discover(self, timeout=1, max_devices=1):
server_infos = yield self.sspd_factory.m_search(
self.router_ip, timeout=timeout, max_devices=max_devices
)
server_info = server_infos[0]
if 'st' in server_info:
gateway = Gateway(reactor=self._reactor, **server_info)
yield gateway.discover_commands()
self.gateway = gateway
defer.returnValue(True)
elif 'st' not in server_info:
log.error("don't know how to handle gateway: %s", server_info)
defer.returnValue(False)
@defer.inlineCallbacks
def discover(self, timeout=1, max_devices=1):
try:
found = yield self._discover(timeout=timeout, max_devices=max_devices)
except defer.TimeoutError:
found = False
finally:
self.sspd_factory.disconnect()
if found:
try:
runner = self.soap_manager.get_runner()
required_commands = [
"GetExternalIPAddress",
"AddPortMapping",
"GetSpecificPortMappingEntry",
"GetGenericPortMappingEntry",
"DeletePortMapping"
]
found = all((command in runner._registered_commands for command in required_commands))
except UPnPError:
found = False
if not found and self.try_miniupnpc_fallback:
found = yield self.start_miniupnpc_fallback()
log.debug("found upnp device")
else:
log.debug("failed to find upnp device")
defer.returnValue(found)
@defer.inlineCallbacks
def start_miniupnpc_fallback(self):
found = False
if not self.miniupnpc_runner:
log.debug("trying miniupnpc fallback")
fallback = UPnPFallback()
success = yield fallback.discover()
self.miniupnpc_igd_url = fallback.device_url
if success:
log.info("successfully started miniupnpc fallback")
self.miniupnpc_runner = fallback
found = True
if not found:
log.warning("failed to find upnp gateway using miniupnpc fallback")
defer.returnValue(found)
def get_external_ip(self) -> str:
return self.gateway.commands.GetExternalIPAddress()
def get_external_ip(self):
return self.commands.GetExternalIPAddress()
def add_port_mapping(self, external_port, protocol, internal_port, lan_address, description):
return self.commands.AddPortMapping(
def add_port_mapping(self, external_port: int, protocol: str, internal_port, lan_address: str,
description: str) -> None:
return self.gateway.commands.AddPortMapping(
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol,
NewInternalPort=internal_port, NewInternalClient=lan_address,
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration=""
)
@defer.inlineCallbacks
def get_port_mapping_by_index(self, index):
def get_port_mapping_by_index(self, index: int) -> (str, int, str, int, str, bool, str, int):
try:
redirect = yield self.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
redirect = yield self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
defer.returnValue(redirect)
except UPnPError:
defer.returnValue(None)
@ -129,7 +97,7 @@ class UPnP(object):
defer.returnValue(redirects)
@defer.inlineCallbacks
def get_specific_port_mapping(self, external_port, protocol):
def get_specific_port_mapping(self, external_port: int, protocol: str) -> (int, str, bool, str, int):
"""
:param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP'
@ -137,41 +105,23 @@ class UPnP(object):
"""
try:
result = yield self.commands.GetSpecificPortMappingEntry(
result = yield self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol
)
defer.returnValue(result)
except UPnPError:
defer.returnValue(None)
def delete_port_mapping(self, external_port, protocol, new_remote_host=""):
def delete_port_mapping(self, external_port: int, protocol: str) -> None:
"""
:param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP'
:return: None
"""
return self.commands.DeletePortMapping(
NewRemoteHost=new_remote_host, NewExternalPort=external_port, NewProtocol=protocol
return self.gateway.commands.DeletePortMapping(
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
)
def get_rsip_nat_status(self):
"""
:return: (bool) NewRSIPAvailable, (bool) NewNATEnabled
"""
return self.commands.GetNATRSIPStatus()
def get_status_info(self):
"""
:return: (str) NewConnectionStatus, (str) NewLastConnectionError, (int) NewUptime
"""
return self.commands.GetStatusInfo()
def get_connection_type_info(self):
"""
:return: (str) NewConnectionType (str), NewPossibleConnectionTypes (str)
"""
return self.commands.GetConnectionTypeInfo()
@defer.inlineCallbacks
def get_next_mapping(self, port, protocol, description, internal_port=None):
if protocol not in ["UDP", "TCP"]:
@ -192,15 +142,3 @@ class UPnP(object):
port, protocol, internal_port, self.lan_address, description
)
defer.returnValue(port)
def get_debug_info(self, include_gateway_xml=False):
def default_byte(x):
if isinstance(x, bytes):
return x.decode()
return x
return json.dumps({
'txupnp': self.soap_manager.debug(include_gateway_xml=include_gateway_xml),
'miniupnpc_igd_url': self.miniupnpc_igd_url
},
indent=2, default=default_byte
)

View file

@ -1,7 +1,6 @@
import re
import functools
from collections import defaultdict
from twisted.internet import defer
from xml.etree import ElementTree
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
@ -59,13 +58,11 @@ def verify_return_types(*types):
def _verify_return_types(fn):
@functools.wraps(fn)
def _inner(response):
if isinstance(response, (list, tuple)):
r = tuple(t(r) for t, r in zip(types, response))
if len(r) == 1:
return fn(r[0])
return fn(r)
return fn(types[0](response))
def _inner(*result):
r = fn(*tuple(t(r) for t, r in zip(types, result)))
if isinstance(r, tuple) and len(r) == 1:
return r[0]
return r
return _inner
return _verify_return_types
@ -85,18 +82,3 @@ def return_types(*types):
none_or_str = lambda x: None if not x or x == 'None' else str(x)
none = lambda _: None
@defer.inlineCallbacks
def DeferredDict(d, consumeErrors=False):
keys = []
dl = []
response = {}
for k, v in d.items():
keys.append(k)
dl.append(v)
results = yield defer.DeferredList(dl, consumeErrors=consumeErrors)
for k, (success, result) in zip(keys, results):
if success:
response[k] = result
defer.returnValue(response)