This commit is contained in:
Jack Robison 2018-10-08 14:47:37 -04:00
parent 0eca7478e2
commit 1098177bae
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
18 changed files with 335 additions and 159 deletions

1
.gitignore vendored
View file

@ -4,3 +4,4 @@ _trial_temp/
build/
dist/
.coverage
.mypy_cache/

View file

@ -4,11 +4,12 @@ language: python
python: "3.7"
before_install:
- pip install pylint coverage
- pip install pylint coverage mypy
- pip install -e .
# - pylint aioupnp
script:
# - pylint aioupnp
- mypy .
- HOME=/tmp coverage run --source=aioupnp -m unittest -v
after_success:

View file

@ -1,5 +1,6 @@
def none_or_str(x):
return None if not x or x == 'None' else str(x)
from typing import Tuple, Union
none_or_str = Union[None, str]
class SCPDCommands:
@ -14,12 +15,13 @@ class SCPDCommands:
raise NotImplementedError()
@staticmethod
async def GetNATRSIPStatus() -> (bool, bool):
async def GetNATRSIPStatus() -> Tuple[bool, bool]:
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
raise NotImplementedError()
@staticmethod
async def GetGenericPortMappingEntry(NewPortMappingIndex: int) -> (none_or_str, int, str, int, str, bool, str, int):
async def GetGenericPortMappingEntry(NewPortMappingIndex: int) -> Tuple[none_or_str, int, str, int, str,
bool, str, int]:
"""
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
NewPortMappingDescription, NewLeaseDuration)
@ -27,7 +29,8 @@ class SCPDCommands:
raise NotImplementedError()
@staticmethod
async def GetSpecificPortMappingEntry(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> (int, str, bool, str, int):
async def GetSpecificPortMappingEntry(NewRemoteHost: str, NewExternalPort: int,
NewProtocol: str) -> Tuple[int, str, bool, str, int]:
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
raise NotImplementedError()
@ -42,12 +45,12 @@ class SCPDCommands:
raise NotImplementedError()
@staticmethod
async def GetConnectionTypeInfo() -> (str, str):
async def GetConnectionTypeInfo() -> Tuple[str, str]:
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
raise NotImplementedError()
@staticmethod
async def GetStatusInfo() -> (str, str, int):
async def GetStatusInfo() -> Tuple[str, str, int]:
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
raise NotImplementedError()
@ -92,7 +95,7 @@ class SCPDCommands:
raise NotImplementedError()
@staticmethod
async def X_GetICSStatistics() -> (int, int, int, int, str, str):
async def X_GetICSStatistics() -> Tuple[int, int, int, int, str, str]:
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
raise NotImplementedError()
@ -119,6 +122,6 @@ class SCPDCommands:
raise NotImplementedError()
@staticmethod
async def GetActiveConnections() -> (str, str):
async def GetActiveConnections() -> Tuple[str, str]:
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
raise NotImplementedError()

View file

@ -1,10 +1,11 @@
import logging
from typing import List
log = logging.getLogger(__name__)
class CaseInsensitive:
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
not_evaluated = {}
for k, v in kwargs.items():
if k.startswith("_"):
@ -22,6 +23,7 @@ class CaseInsensitive:
for k, v in self.__dict__.items():
if k.lower() == case_insensitive.lower():
return k
raise AttributeError(case_insensitive)
def __getattr__(self, item):
if item in self.__dict__:
@ -75,7 +77,7 @@ class Device(CaseInsensitive):
presentationURL = None
iconList = None
def __init__(self, devices, services, **kwargs):
def __init__(self, devices: List, services: List, **kwargs) -> None:
super(Device, self).__init__(**kwargs)
if self.serviceList and "service" in self.serviceList:
new_services = self.serviceList["service"]

View file

@ -1,4 +1,7 @@
import logging
import socket
import typing
from typing import Dict, List, Union, Type, Tuple
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from aioupnp.constants import SPEC_VERSION, UPNP_ORG_IGD, SERVICE
from aioupnp.commands import SCPDCommands
@ -7,11 +10,16 @@ from aioupnp.protocols.ssdp import m_search
from aioupnp.protocols.scpd import scpd_get
from aioupnp.protocols.soap import SCPDCommand
from aioupnp.util import flatten_keys
from aioupnp.fault import UPnPError
log = logging.getLogger(__name__)
return_type_lambas = {
Union[None, str]: lambda x: x if x is not None and str(x).lower() not in ['none', 'nil'] else None
}
def get_action_list(element_dict: dict) -> list: # [(<method>, [<input1>, ...], [<output1, ...]), ...]
def get_action_list(element_dict: dict) -> List: # [(<method>, [<input1>, ...], [<output1, ...]), ...]
service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
if "actionList" in service_info:
action_list = service_info["actionList"]
@ -20,7 +28,7 @@ def get_action_list(element_dict: dict) -> list: # [(<method>, [<input1>, ...],
if not len(action_list): # it could be an empty string
return []
result = []
result: list = []
if isinstance(action_list["action"], dict):
arg_dicts = action_list["action"]['argumentList']['argument']
if not isinstance(arg_dicts, list): # when there is one arg
@ -97,21 +105,22 @@ class Gateway:
return r
@property
def services(self) -> dict:
def services(self) -> Dict:
if not self._device:
return {}
return {service.serviceType: service for service in self._services}
@property
def devices(self) -> dict:
def devices(self) -> Dict:
if not self._device:
return {}
return {device.udn: device for device in self._devices}
def get_service(self, service_type: str) -> Service:
def get_service(self, service_type: str) -> Union[Type[Service], None]:
for service in self._services:
if service.serviceType.lower() == service_type.lower():
return service
return None
def debug_commands(self):
return {
@ -121,13 +130,14 @@ class Gateway:
@classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 1,
service: str = UPNP_ORG_IGD):
datagram = await m_search(lan_address, gateway_address, timeout, service)
service: str = UPNP_ORG_IGD, ssdp_socket: socket.socket = None,
soap_socket: socket.socket = None):
datagram = await m_search(lan_address, gateway_address, timeout, service, ssdp_socket)
gateway = cls(**datagram.as_dict())
await gateway.discover_commands()
await gateway.discover_commands(soap_socket)
return gateway
async def discover_commands(self):
async def discover_commands(self, soap_socket: socket.socket = None):
response = await scpd_get("/" + self.path.decode(), self.base_ip.decode(), self.port)
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
@ -141,9 +151,11 @@ class Gateway:
else:
self._device = Device(self._devices, self._services)
for service_type in self.services.keys():
await self.register_commands(self.services[service_type])
await self.register_commands(self.services[service_type], soap_socket)
async def register_commands(self, service: Service):
async def register_commands(self, service: Service, soap_socket: socket.socket = None):
if not service.SCPDURL:
raise UPnPError("no scpd url")
service_dict = await scpd_get(("" if service.SCPDURL.startswith("/") else "/") + service.SCPDURL,
self.base_ip.decode(), self.port)
if not service_dict:
@ -157,8 +169,10 @@ class Gateway:
annotations = current.__annotations__
return_types = annotations.get('return', None)
if return_types:
if not isinstance(return_types, tuple):
if isinstance(return_types, type):
return_types = (return_types, )
else:
return_types = tuple([return_type_lambas.get(a, a) for a in return_types.__args__])
return_types = {r: t for r, t in zip(outputs, return_types)}
param_types = {}
for param_name, param_type in annotations.items():
@ -167,7 +181,7 @@ class Gateway:
param_types[param_name] = param_type
command = SCPDCommand(
self.base_ip.decode(), self.port, service.controlURL, service.serviceType.encode(),
name, param_types, return_types, inputs, outputs)
name, param_types, return_types, inputs, outputs, soap_socket)
setattr(command, "__doc__", current.__doc__)
setattr(self.commands, command.method, command)

View file

@ -5,42 +5,44 @@ from asyncio.transports import DatagramTransport
class MulticastProtocol(DatagramProtocol):
def __init__(self, multicast_address: str, bind_address: str) -> None:
self.multicast_address = multicast_address
self.bind_address = bind_address
self.transport: DatagramTransport
@property
def socket(self) -> socket.socket:
return self.transport.get_extra_info('socket')
def sock(self) -> socket.socket:
s: socket.socket = self.transport.get_extra_info(name='socket')
return s
def get_ttl(self) -> int:
return self.socket.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
return self.sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
def set_ttl(self, ttl: int = 1) -> None:
self.socket.setsockopt(
self.sock.setsockopt(
socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl)
)
def join_group(self, addr: str, interface: str) -> None:
addr = socket.inet_aton(addr)
interface = socket.inet_aton(interface)
self.socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, addr + interface)
def leave_group(self, addr: str, interface: str) -> None:
addr = socket.inet_aton(addr)
interface = socket.inet_aton(interface)
self.socket.setsockopt(socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, addr + interface)
def connection_made(self, transport: DatagramTransport) -> None:
self.transport = transport
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@classmethod
def create_socket(cls, bind_address: str, multicast_address: str):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((bind_address, 0))
sock.setsockopt(
def join_group(self, multicast_address: str, bind_address: str) -> None:
self.sock.setsockopt(
socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
)
def leave_group(self, multicast_address: str, bind_address: str) -> None:
self.sock.setsockopt(
socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP,
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
)
def connection_made(self, transport) -> None:
self.transport = transport
self.join_group(self.multicast_address, self.bind_address)
@classmethod
def create_multicast_socket(cls, bind_address: str):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((bind_address, 0))
sock.setblocking(False)
return sock

View file

@ -1,4 +1,5 @@
import logging
import socket
from xml.etree import ElementTree
import asyncio
from asyncio.protocols import Protocol
@ -16,7 +17,7 @@ class SCPDHTTPClientProtocol(Protocol):
GET = 'GET'
def __init__(self, method: str, message: bytes, finished: asyncio.Future, soap_method: str=None,
soap_service_id: str=None, close_after_send: bool = False):
soap_service_id: str=None, close_after_send: bool = False) -> None:
self.method = method
assert soap_service_id is not None and soap_method is not None if method == 'POST' else True, \
'soap args not provided'
@ -60,7 +61,7 @@ class SCPDHTTPClientProtocol(Protocol):
async def scpd_get(control_url: str, address: str, port: int) -> dict:
loop = asyncio.get_running_loop()
finished = asyncio.Future()
finished: asyncio.Future = asyncio.Future()
packet = serialize_scpd_get(control_url, address)
transport, protocol = await loop.create_connection(
lambda : SCPDHTTPClientProtocol('GET', packet, finished), address, port
@ -72,15 +73,15 @@ async def scpd_get(control_url: str, address: str, port: int) -> dict:
async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes,
close_after_send: bool, **kwargs):
close_after_send: bool, soap_socket: socket.socket = None, **kwargs):
loop = asyncio.get_running_loop()
finished = asyncio.Future()
finished: asyncio.Future = asyncio.Future()
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
transport, protocol = await loop.create_connection(
lambda : SCPDHTTPClientProtocol(
'POST', packet, finished, soap_method=method, soap_service_id=service_id.decode(),
close_after_send=close_after_send
), address, port
), address, port, sock=soap_socket
)
try:
return await asyncio.wait_for(finished, 1.0)

View file

@ -1,4 +1,6 @@
import logging
import socket
import typing
from aioupnp.protocols.scpd import scpd_post
log = logging.getLogger(__name__)
@ -6,7 +8,8 @@ log = logging.getLogger(__name__)
class SCPDCommand:
def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str,
param_types: dict, return_types: dict, param_order: list, return_order: list):
param_types: dict, return_types: dict, param_order: list, return_order: list,
soap_socket: socket.socket = None) -> None:
self.gateway_address = gateway_address
self.service_port = service_port
self.control_url = control_url
@ -16,17 +19,19 @@ class SCPDCommand:
self.param_order = param_order
self.return_types = return_types
self.return_order = return_order
self.soap_socket = soap_socket
async def __call__(self, **kwargs):
async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]:
if set(kwargs.keys()) != set(self.param_types.keys()):
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys()))
close_after_send = not self.return_types or self.return_types == [None]
response = await scpd_post(
self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, self.service_id,
close_after_send, **{n: self.param_types[n](kwargs[n]) for n in self.param_types.keys()}
close_after_send, self.soap_socket, **{n: self.param_types[n](kwargs[n]) for n in self.param_types.keys()}
)
extracted_response = tuple([None if self.return_types[n] is None else self.return_types[n](response[n])
for n in self.return_order]) or (None, )
if len(extracted_response) == 1:
return extracted_response[0]
return extracted_response
result = tuple([self.return_types[n](response.get(n)) for n in self.return_order])
if not result:
return None
if len(result) == 1:
return result[0]
return result

View file

@ -1,9 +1,9 @@
import re
import socket
import binascii
import asyncio
import logging
from typing import DefaultDict
from asyncio.coroutines import coroutine
from typing import Dict, List, Tuple
from asyncio.futures import Future
from asyncio.transports import DatagramTransport
from aioupnp.fault import UPnPError
@ -17,17 +17,12 @@ log = logging.getLogger(__name__)
class SSDPProtocol(MulticastProtocol):
def __init__(self, lan_address):
super().__init__()
def __init__(self, multicast_address: str, lan_address: str) -> None:
super().__init__(multicast_address, lan_address)
self.lan_address = lan_address
self.discover_callbacks: DefaultDict[coroutine] = {}
self.transport: DatagramTransport
self.notifications = []
self.replies = []
def connection_made(self, transport: DatagramTransport):
super().connection_made(transport)
self.set_ttl(1)
self.discover_callbacks: Dict = {}
self.notifications: List = []
self.replies: List = []
async def m_search(self, address, timeout: int = 1, service=UPNP_ORG_IGD) -> SSDPDatagram:
if (address, service) in self.discover_callbacks:
@ -37,7 +32,7 @@ class SSDPProtocol(MulticastProtocol):
mx=1
)
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT))
f = Future()
f: Future = Future()
self.discover_callbacks[(address, service)] = f
return await asyncio.wait_for(f, timeout)
@ -57,8 +52,9 @@ class SSDPProtocol(MulticastProtocol):
if (addr[0], packet.st) in self.discover_callbacks:
if packet.st not in map(lambda p: p['st'], self.replies):
self.replies.append(packet)
f: Future = self.discover_callbacks.pop((addr[0], packet.st))
f.set_result(packet)
ok_fut: Future = self.discover_callbacks.pop((addr[0], packet.st))
ok_fut.set_result(packet)
return
elif packet._packet_type == packet._NOTIFY:
if packet.nt == SSDP_ROOT_DEVICE:
@ -70,21 +66,27 @@ class SSDPProtocol(MulticastProtocol):
break
if key:
log.debug("got a notification with the requested m-search info")
f: Future = self.discover_callbacks.pop(key)
f.set_result(SSDPDatagram(
notify_fut: Future = self.discover_callbacks.pop(key)
notify_fut.set_result(SSDPDatagram(
SSDPDatagram._OK, cache_control='', location=packet.location, server=packet.server,
st=UPNP_ORG_IGD, usn=packet.usn
))
self.notifications.append(packet.as_dict())
return
async def listen_ssdp(lan_address: str, gateway_address: str) -> (DatagramTransport, SSDPProtocol, str, str):
async def listen_ssdp(lan_address: str, gateway_address: str,
ssdp_socket: socket.socket = None) -> Tuple[DatagramTransport, SSDPProtocol,
str, str]:
loop = asyncio.get_running_loop()
try:
sock = SSDPProtocol.create_socket(lan_address, SSDP_IP_ADDRESS)
transport, protocol = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(lan_address), sock=sock
sock = ssdp_socket or SSDPProtocol.create_multicast_socket(lan_address)
listen_result: Tuple = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address), sock=sock
)
transport: DatagramTransport = listen_result[0]
protocol: SSDPProtocol = listen_result[1]
protocol.set_ttl(1)
except Exception:
log.exception("failed to create multicast socket %s:%i", lan_address, SSDP_PORT)
raise
@ -92,12 +94,12 @@ async def listen_ssdp(lan_address: str, gateway_address: str) -> (DatagramTransp
async def m_search(lan_address: str, gateway_address: str, timeout: int = 1,
service: str = UPNP_ORG_IGD) -> SSDPDatagram:
service: str = UPNP_ORG_IGD, ssdp_socket: socket.socket = None) -> SSDPDatagram:
transport, protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address
lan_address, gateway_address, ssdp_socket
)
try:
return await protocol.m_search(gateway_address, timeout=timeout, service=service)
return await protocol.m_search(address=gateway_address, timeout=timeout, service=service)
except asyncio.TimeoutError:
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
finally:

View file

@ -1,4 +1,5 @@
import re
from typing import Dict
from xml.etree import ElementTree
from aioupnp.constants import XML_VERSION, DEVICE, ROOT
from aioupnp.util import etree_to_dict, flatten_keys
@ -33,7 +34,7 @@ def serialize_scpd_get(path: str, address: str) -> bytes:
).encode()
def deserialize_scpd_get_response(content: bytes) -> dict:
def deserialize_scpd_get_response(content: bytes) -> Dict:
if XML_VERSION.encode() in content:
parsed = CONTENT_PATTERN.findall(content)
content = b'' if not parsed else parsed[0][0]
@ -47,3 +48,4 @@ def deserialize_scpd_get_response(content: bytes) -> dict:
root = m[2][5]
break
return flatten_keys(xml_dict, "{%s}" % schema_key)[root]
return {}

View file

@ -1,6 +1,7 @@
import re
import logging
import binascii
from typing import Dict, List
from aioupnp.fault import UPnPError
from aioupnp.constants import line_separator
@ -69,13 +70,8 @@ class SSDPDatagram(object):
]
}
_marshallers = {
'mx': str,
'man': lambda x: ("\"%s\"" % x)
}
def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None,
cache_control=None, server=None, date=None, ext=None, **kwargs):
cache_control=None, server=None, date=None, ext=None, **kwargs) -> None:
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
raise UPnPError("unknown packet type: {}".format(packet_type))
self._packet_type = packet_type
@ -95,8 +91,9 @@ class SSDPDatagram(object):
if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None:
setattr(self, k.lower(), v)
def __repr__(self):
return ("SSDPDatagram(packet_type=%s, " % self._packet_type) + ", ".join("%s=%s" % (n, v) for n, v in self.as_dict().items()) + ")"
def __repr__(self) -> str:
return ("SSDPDatagram(packet_type=%s, " % self._packet_type) + \
", ".join("%s=%s" % (n, v) for n, v in self.as_dict().items()) + ")"
def __getitem__(self, item):
for i in self._required_fields[self._packet_type]:
@ -104,17 +101,19 @@ class SSDPDatagram(object):
return getattr(self, i)
raise KeyError(item)
def get_friendly_name(self):
def get_friendly_name(self) -> str:
return self._friendly_names[self._packet_type]
def encode(self, trailing_newlines=2):
def encode(self, trailing_newlines: int = 2) -> str:
lines = [self._start_lines[self._packet_type]]
for attr_name in self._required_fields[self._packet_type]:
attr = getattr(self, attr_name)
if attr is None:
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
if attr_name in self._marshallers:
value = self._marshallers[attr_name](attr)
if attr_name == 'mx':
value = str(attr)
elif attr_name == 'man':
value = "\"%s\"" % attr
else:
value = attr
lines.append("{}: {}".format(attr_name.upper(), value))
@ -123,7 +122,7 @@ class SSDPDatagram(object):
serialized += line_separator
return serialized
def as_dict(self):
def as_dict(self) -> Dict:
return self._lines_to_content_dict(self.encode().split(line_separator))
@classmethod
@ -142,8 +141,8 @@ class SSDPDatagram(object):
return packet
@classmethod
def _lines_to_content_dict(cls, lines: list) -> dict:
result = {}
def _lines_to_content_dict(cls, lines: list) -> Dict:
result: dict = {}
for line in lines:
if not line:
continue
@ -175,13 +174,13 @@ class SSDPDatagram(object):
return cls._from_response(lines[1:])
@classmethod
def _from_response(cls, lines):
def _from_response(cls, lines: List):
return cls(cls._OK, **cls._lines_to_content_dict(lines))
@classmethod
def _from_notify(cls, lines):
def _from_notify(cls, lines: List):
return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines))
@classmethod
def _from_request(cls, lines):
def _from_request(cls, lines: List):
return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines))

View file

@ -3,7 +3,9 @@ from aioupnp.serialization.soap import serialize_soap_post
class TestSOAPSerialization(unittest.TestCase):
method, param_names, gateway_address, kwargs = "GetExternalIPAddress", [], b'10.0.0.1', {}
param_names: list = []
kwargs: dict = {}
method, gateway_address = "GetExternalIPAddress", b'10.0.0.1'
st, lan_address, path = b'urn:schemas-upnp-org:service:WANIPConnection:1', '10.0.0.1', b'/soap.cgi?service=WANIPConn1'
expected_result = b'POST /soap.cgi?service=WANIPConn1 HTTP/1.1\r\n' \
b'Host: 10.0.0.1\r\nUser-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n' \
@ -15,7 +17,6 @@ class TestSOAPSerialization(unittest.TestCase):
b'<s:Body><u:GetExternalIPAddress xmlns:u="urn:schemas-upnp-org:service:WANIPConnection:1">' \
b'</u:GetExternalIPAddress></s:Body></s:Envelope>\r\n'
def test_serialize_get(self):
self.assertEqual(serialize_soap_post(
self.method, self.param_names, self.st, self.gateway_address, self.path, **self.kwargs

View file

@ -1,6 +1,7 @@
import unittest
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.fault import UPnPError
from aioupnp.constants import UPNP_ORG_IGD
class TestParseMSearchRequest(unittest.TestCase):
@ -11,7 +12,7 @@ class TestParseMSearchRequest(unittest.TestCase):
b'MX: 1\r\n' \
b'\r\n'
def test_parse_m_search_response(self):
def test_parse_m_search(self):
packet = SSDPDatagram.decode(self.datagram)
self.assertTrue(packet._packet_type, packet._M_SEARCH)
self.assertEqual(packet.host, '239.255.255.250:1900')
@ -46,6 +47,30 @@ class TestParseMSearchResponse(unittest.TestCase):
)
class TestParseMSearchResponseRedSonic(TestParseMSearchResponse):
datagram = \
b"HTTP/1.1 200 OK\r\n" \
b"CACHE-CONTROL: max-age=1800\r\n" \
b"DATE: Thu, 04 Oct 2018 22:59:40 GMT\r\n" \
b"EXT:\r\n" \
b"LOCATION: http://10.1.10.1:49152/IGDdevicedesc_brlan0.xml\r\n" \
b"OPT: \"http://schemas.upnp.org/upnp/1/0/\"; ns=01\r\n" \
b"01-NLS: 00000000-0000-0000-0000-000000000000\r\n" \
b"SERVER: Linux/3.14.28-Prod_17.2, UPnP/1.0, Portable SDK for UPnP devices/1.6.22\r\n" \
b"X-User-Agent: redsonic\r\n" \
b"ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" \
b"USN: uuid:00000000-0000-0000-0000-000000000000::urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" \
b"\r\n"
def test_parse_m_search_response(self):
packet = SSDPDatagram.decode(self.datagram)
self.assertTrue(packet._packet_type, packet._OK)
self.assertEqual(packet.cache_control, 'max-age=1800')
self.assertEqual(packet.location, 'http://10.1.10.1:49152/IGDdevicedesc_brlan0.xml')
self.assertEqual(packet.server, 'Linux/3.14.28-Prod_17.2, UPnP/1.0, Portable SDK for UPnP devices/1.6.22')
self.assertEqual(packet.st, UPNP_ORG_IGD)
class TestParseMSearchResponseDashCacheControl(TestParseMSearchResponse):
datagram = "\r\n".join([
'HTTP/1.1 200 OK',

View file

@ -1,8 +1,9 @@
import os
import socket
import logging
import json
import asyncio
import functools
from typing import Tuple, Dict, List, Union
from aioupnp.fault import UPnPError
from aioupnp.gateway import Gateway
from aioupnp.constants import UPNP_ORG_IGD
@ -12,19 +13,9 @@ from aioupnp.protocols.ssdp import m_search
log = logging.getLogger(__name__)
def cli(format_result=None):
def _cli(fn):
@functools.wraps(fn)
async def _inner(*args, **kwargs):
result = await fn(*args, **kwargs)
if not format_result or not result or not isinstance(result, (list, dict, tuple)):
return result
self = args[0]
return {k: v for k, v in zip(getattr(self.gateway.commands, format_result).return_order, result)}
f = _inner
f._cli = True
return f
return _cli
def cli(fn):
fn._cli = True
return fn
def _encode(x):
@ -36,14 +27,14 @@ def _encode(x):
class UPnP:
def __init__(self, lan_address: str, gateway_address: str, gateway: Gateway):
def __init__(self, lan_address: str, gateway_address: str, gateway: Gateway) -> None:
self.lan_address = lan_address
self.gateway_address = gateway_address
self.gateway = gateway
@classmethod
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
interface_name: str = 'default') -> (str, str):
interface_name: str = 'default') -> Tuple[str, str]:
if not lan_address or not gateway_address:
gateway_addr, lan_addr = get_gateway_and_lan_addresses(interface_name)
lan_address = lan_address or lan_addr
@ -52,18 +43,19 @@ class UPnP:
@classmethod
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
service: str = UPNP_ORG_IGD, interface_name: str = 'default'):
service: str = UPNP_ORG_IGD, interface_name: str = 'default',
ssdp_socket: socket.socket = None, soap_socket: socket.socket = None):
try:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
except Exception as err:
raise UPnPError("failed to get lan and gateway addresses: %s" % str(err))
gateway = await Gateway.discover_gateway(lan_address, gateway_address, timeout, service)
gateway = await Gateway.discover_gateway(lan_address, gateway_address, timeout, service, ssdp_socket, soap_socket)
return cls(lan_address, gateway_address, gateway)
@classmethod
@cli()
@cli
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
service: str = UPNP_ORG_IGD, interface_name: str = 'default') -> dict:
service: str = UPNP_ORG_IGD, interface_name: str = 'default') -> Dict:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
datagram = await m_search(lan_address, gateway_address, timeout, service)
return {
@ -72,32 +64,39 @@ class UPnP:
'discover_reply': datagram.as_dict()
}
@cli()
@cli
async def get_external_ip(self) -> str:
return await self.gateway.commands.GetExternalIPAddress()
@cli("AddPortMapping")
@cli
async def add_port_mapping(self, external_port: int, protocol: str, internal_port, lan_address: str,
description: str) -> None:
return await self.gateway.commands.AddPortMapping(
await self.gateway.commands.AddPortMapping(
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol,
NewInternalPort=internal_port, NewInternalClient=lan_address,
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration=""
)
return
@cli("GetGenericPortMappingEntry")
async def get_port_mapping_by_index(self, index: int) -> dict:
return await self._get_port_mapping_by_index(index)
@cli
async def get_port_mapping_by_index(self, index: int) -> Dict:
result = await self._get_port_mapping_by_index(index)
if result:
return {
k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result)
}
return {}
async def _get_port_mapping_by_index(self, index: int) -> (str, int, str, int, str, bool, str, int):
async def _get_port_mapping_by_index(self, index: int) -> Union[Tuple[str, int, str, int, str, bool, str, int],
None]:
try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
return redirect
except UPnPError:
return
return None
@cli()
async def get_redirects(self) -> list:
@cli
async def get_redirects(self) -> List[Dict]:
redirects = []
cnt = 0
redirect = await self.get_port_mapping_by_index(cnt)
@ -107,8 +106,8 @@ class UPnP:
redirect = await self.get_port_mapping_by_index(cnt)
return redirects
@cli("GetSpecificPortMappingEntry")
async def get_specific_port_mapping(self, external_port: int, protocol: str):
@cli
async def get_specific_port_mapping(self, external_port: int, protocol: str) -> Dict:
"""
:param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP'
@ -116,13 +115,14 @@ class UPnP:
"""
try:
return await self.gateway.commands.GetSpecificPortMappingEntry(
result = await self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol
)
return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)}
except UPnPError:
return
return {}
@cli()
@cli
async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
"""
:param external_port: (int) external port to listen on
@ -133,7 +133,7 @@ class UPnP:
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
)
@cli("AddPortMapping")
@cli
async def get_next_mapping(self, port: int, protocol: str, description: str, internal_port: int=None) -> int:
if protocol not in ["UDP", "TCP"]:
raise UPnPError("unsupported protocol: {}".format(protocol))
@ -163,14 +163,14 @@ class UPnP:
)
return port
@cli()
async def get_soap_commands(self) -> dict:
@cli
async def get_soap_commands(self) -> Dict:
return {
'supported': list(self.gateway._registered_commands.keys()),
'unsupported': self.gateway._unsupported_actions
}
@cli()
@cli
async def generate_test_data(self):
external_ip = await self.get_external_ip()
redirects = await self.get_redirects()
@ -216,13 +216,13 @@ class UPnP:
@classmethod
def run_cli(cls, method, lan_address: str = '', gateway_address: str = '', timeout: int = 60,
service: str = UPNP_ORG_IGD, interface_name: str = 'default',
kwargs: dict = None):
kwargs: dict = None) -> None:
kwargs = kwargs or {}
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
fut = asyncio.Future()
fut: asyncio.Future = asyncio.Future()
async def wrapper():
if method == 'm_search':

View file

@ -1,6 +1,7 @@
import re
import socket
from collections import defaultdict
from typing import Tuple, Dict
from xml.etree import ElementTree
import netifaces
@ -9,15 +10,17 @@ BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
def etree_to_dict(t: ElementTree) -> dict:
d = {t.tag: {} if t.attrib else None}
def etree_to_dict(t: ElementTree.Element) -> Dict:
d: dict = {}
if t.attrib:
d[t.tag] = {}
children = list(t)
if children:
dd = defaultdict(list)
dd: dict = defaultdict(list)
for dc in map(etree_to_dict, children):
for k, v in dc.items():
dd[k].append(v)
d = {t.tag: {k: v[0] if len(v) == 1 else v for k, v in dd.items()}}
d[t.tag] = {k: v[0] if len(v) == 1 else v for k, v in dd.items()}
if t.attrib:
d[t.tag].update(('@' + k, v) for k, v in t.attrib.items())
if t.text:
@ -81,8 +84,8 @@ def get_interfaces():
return r
def get_gateway_and_lan_addresses(interface_name: str) -> (str, str):
def get_gateway_and_lan_addresses(interface_name: str) -> Tuple[str, str]:
for iface_name, (gateway, lan) in get_interfaces().items():
if interface_name == iface_name:
return gateway, lan
return None, None
return '', ''

4
mypy.ini Normal file
View file

@ -0,0 +1,4 @@
[mypy]
python_version = 3.7
mypy_path=stubs
cache_dir=/dev/null

View file

@ -1,5 +1,5 @@
import os
from setuptools import setup, find_packages
from setuptools import setup, find_packages # type: ignore
from aioupnp import __version__, __name__, __email__, __author__, __license__
console_scripts = [

111
stubs/netifaces.py Normal file
View file

@ -0,0 +1,111 @@
import typing
AF_APPLETALK = 5
AF_ASH = 18
AF_ATMPVC = 8
AF_ATMSVC = 20
AF_AX25 = 3
AF_BLUETOOTH = 31
AF_BRIDGE = 7
AF_DECnet = 12
AF_ECONET = 19
AF_FILE = 1
AF_INET = 2
AF_INET6 = 10
AF_IPX = 4
AF_IRDA = 23
AF_ISDN = 34
AF_KEY = 15
AF_LINK = 17
AF_NETBEUI = 13
AF_NETLINK = 16
AF_NETROM = 6
AF_PACKET = 17
AF_PPPOX = 24
AF_ROSE = 11
AF_ROUTE = 16
AF_SECURITY = 14
AF_SNA = 22
AF_UNIX = 1
AF_UNSPEC = 0
AF_WANPIPE = 25
AF_X25 = 9
version = '0.10.7'
# functions
def gateways(*args, **kwargs) -> typing.List: # real signature unknown
"""
Obtain a list of the gateways on this machine.
Returns a dict whose keys are equal to the address family constants,
e.g. netifaces.AF_INET, and whose values are a list of tuples of the
format (<address>, <interface>, <is_default>).
There is also a special entry with the key 'default', which you can use
to quickly obtain the default gateway for a particular address family.
There may in general be multiple gateways; different address
families may have different gateway settings (e.g. AF_INET vs AF_INET6)
and on some systems it's also possible to have interface-specific
default gateways.
"""
pass
def ifaddresses(*args, **kwargs) -> typing.Dict: # real signature unknown
"""
Obtain information about the specified network interface.
Returns a dict whose keys are equal to the address family constants,
e.g. netifaces.AF_INET, and whose values are a list of addresses in
that family that are attached to the network interface.
"""
pass
def interfaces(*args, **kwargs) -> typing.List: # real signature unknown
""" Obtain a list of the interfaces available on this machine. """
pass
# no classes
# variables with complex values
address_families = {
0: 'AF_UNSPEC',
1: 'AF_FILE',
2: 'AF_INET',
3: 'AF_AX25',
4: 'AF_IPX',
5: 'AF_APPLETALK',
6: 'AF_NETROM',
7: 'AF_BRIDGE',
8: 'AF_ATMPVC',
9: 'AF_X25',
10: 'AF_INET6',
11: 'AF_ROSE',
12: 'AF_DECnet',
13: 'AF_NETBEUI',
14: 'AF_SECURITY',
15: 'AF_KEY',
16: 'AF_NETLINK',
17: 'AF_PACKET',
18: 'AF_ASH',
19: 'AF_ECONET',
20: 'AF_ATMSVC',
22: 'AF_SNA',
23: 'AF_IRDA',
24: 'AF_PPPOX',
25: 'AF_WANPIPE',
31: 'AF_BLUETOOTH',
34: 'AF_ISDN',
}
__loader__ = None # (!) real value is ''
__spec__ = None # (!) real value is ''