This commit is contained in:
Jack Robison 2018-10-08 14:47:37 -04:00
parent 0eca7478e2
commit 1098177bae
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
18 changed files with 335 additions and 159 deletions

1
.gitignore vendored
View file

@ -4,3 +4,4 @@ _trial_temp/
build/ build/
dist/ dist/
.coverage .coverage
.mypy_cache/

View file

@ -4,11 +4,12 @@ language: python
python: "3.7" python: "3.7"
before_install: before_install:
- pip install pylint coverage - pip install pylint coverage mypy
- pip install -e . - pip install -e .
# - pylint aioupnp
script: script:
# - pylint aioupnp
- mypy .
- HOME=/tmp coverage run --source=aioupnp -m unittest -v - HOME=/tmp coverage run --source=aioupnp -m unittest -v
after_success: after_success:

View file

@ -1,5 +1,6 @@
def none_or_str(x): from typing import Tuple, Union
return None if not x or x == 'None' else str(x)
none_or_str = Union[None, str]
class SCPDCommands: class SCPDCommands:
@ -14,12 +15,13 @@ class SCPDCommands:
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
async def GetNATRSIPStatus() -> (bool, bool): async def GetNATRSIPStatus() -> Tuple[bool, bool]:
"""Returns (NewRSIPAvailable, NewNATEnabled)""" """Returns (NewRSIPAvailable, NewNATEnabled)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
async def GetGenericPortMappingEntry(NewPortMappingIndex: int) -> (none_or_str, int, str, int, str, bool, str, int): async def GetGenericPortMappingEntry(NewPortMappingIndex: int) -> Tuple[none_or_str, int, str, int, str,
bool, str, int]:
""" """
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled, Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
NewPortMappingDescription, NewLeaseDuration) NewPortMappingDescription, NewLeaseDuration)
@ -27,7 +29,8 @@ class SCPDCommands:
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
async def GetSpecificPortMappingEntry(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> (int, str, bool, str, int): async def GetSpecificPortMappingEntry(NewRemoteHost: str, NewExternalPort: int,
NewProtocol: str) -> Tuple[int, str, bool, str, int]:
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)""" """Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
raise NotImplementedError() raise NotImplementedError()
@ -42,12 +45,12 @@ class SCPDCommands:
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
async def GetConnectionTypeInfo() -> (str, str): async def GetConnectionTypeInfo() -> Tuple[str, str]:
"""Returns (NewConnectionType, NewPossibleConnectionTypes)""" """Returns (NewConnectionType, NewPossibleConnectionTypes)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
async def GetStatusInfo() -> (str, str, int): async def GetStatusInfo() -> Tuple[str, str, int]:
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)""" """Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
raise NotImplementedError() raise NotImplementedError()
@ -92,7 +95,7 @@ class SCPDCommands:
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
async def X_GetICSStatistics() -> (int, int, int, int, str, str): async def X_GetICSStatistics() -> Tuple[int, int, int, int, str, str]:
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)""" """Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
raise NotImplementedError() raise NotImplementedError()
@ -119,6 +122,6 @@ class SCPDCommands:
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
async def GetActiveConnections() -> (str, str): async def GetActiveConnections() -> Tuple[str, str]:
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID""" """Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
raise NotImplementedError() raise NotImplementedError()

View file

@ -1,10 +1,11 @@
import logging import logging
from typing import List
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class CaseInsensitive: class CaseInsensitive:
def __init__(self, **kwargs): def __init__(self, **kwargs) -> None:
not_evaluated = {} not_evaluated = {}
for k, v in kwargs.items(): for k, v in kwargs.items():
if k.startswith("_"): if k.startswith("_"):
@ -22,6 +23,7 @@ class CaseInsensitive:
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if k.lower() == case_insensitive.lower(): if k.lower() == case_insensitive.lower():
return k return k
raise AttributeError(case_insensitive)
def __getattr__(self, item): def __getattr__(self, item):
if item in self.__dict__: if item in self.__dict__:
@ -75,7 +77,7 @@ class Device(CaseInsensitive):
presentationURL = None presentationURL = None
iconList = None iconList = None
def __init__(self, devices, services, **kwargs): def __init__(self, devices: List, services: List, **kwargs) -> None:
super(Device, self).__init__(**kwargs) super(Device, self).__init__(**kwargs)
if self.serviceList and "service" in self.serviceList: if self.serviceList and "service" in self.serviceList:
new_services = self.serviceList["service"] new_services = self.serviceList["service"]

View file

@ -1,4 +1,7 @@
import logging import logging
import socket
import typing
from typing import Dict, List, Union, Type, Tuple
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from aioupnp.constants import SPEC_VERSION, UPNP_ORG_IGD, SERVICE from aioupnp.constants import SPEC_VERSION, UPNP_ORG_IGD, SERVICE
from aioupnp.commands import SCPDCommands from aioupnp.commands import SCPDCommands
@ -7,11 +10,16 @@ from aioupnp.protocols.ssdp import m_search
from aioupnp.protocols.scpd import scpd_get from aioupnp.protocols.scpd import scpd_get
from aioupnp.protocols.soap import SCPDCommand from aioupnp.protocols.soap import SCPDCommand
from aioupnp.util import flatten_keys from aioupnp.util import flatten_keys
from aioupnp.fault import UPnPError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
return_type_lambas = {
Union[None, str]: lambda x: x if x is not None and str(x).lower() not in ['none', 'nil'] else None
}
def get_action_list(element_dict: dict) -> list: # [(<method>, [<input1>, ...], [<output1, ...]), ...]
def get_action_list(element_dict: dict) -> List: # [(<method>, [<input1>, ...], [<output1, ...]), ...]
service_info = flatten_keys(element_dict, "{%s}" % SERVICE) service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
if "actionList" in service_info: if "actionList" in service_info:
action_list = service_info["actionList"] action_list = service_info["actionList"]
@ -20,7 +28,7 @@ def get_action_list(element_dict: dict) -> list: # [(<method>, [<input1>, ...],
if not len(action_list): # it could be an empty string if not len(action_list): # it could be an empty string
return [] return []
result = [] result: list = []
if isinstance(action_list["action"], dict): if isinstance(action_list["action"], dict):
arg_dicts = action_list["action"]['argumentList']['argument'] arg_dicts = action_list["action"]['argumentList']['argument']
if not isinstance(arg_dicts, list): # when there is one arg if not isinstance(arg_dicts, list): # when there is one arg
@ -97,21 +105,22 @@ class Gateway:
return r return r
@property @property
def services(self) -> dict: def services(self) -> Dict:
if not self._device: if not self._device:
return {} return {}
return {service.serviceType: service for service in self._services} return {service.serviceType: service for service in self._services}
@property @property
def devices(self) -> dict: def devices(self) -> Dict:
if not self._device: if not self._device:
return {} return {}
return {device.udn: device for device in self._devices} return {device.udn: device for device in self._devices}
def get_service(self, service_type: str) -> Service: def get_service(self, service_type: str) -> Union[Type[Service], None]:
for service in self._services: for service in self._services:
if service.serviceType.lower() == service_type.lower(): if service.serviceType.lower() == service_type.lower():
return service return service
return None
def debug_commands(self): def debug_commands(self):
return { return {
@ -121,13 +130,14 @@ class Gateway:
@classmethod @classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 1, async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 1,
service: str = UPNP_ORG_IGD): service: str = UPNP_ORG_IGD, ssdp_socket: socket.socket = None,
datagram = await m_search(lan_address, gateway_address, timeout, service) soap_socket: socket.socket = None):
datagram = await m_search(lan_address, gateway_address, timeout, service, ssdp_socket)
gateway = cls(**datagram.as_dict()) gateway = cls(**datagram.as_dict())
await gateway.discover_commands() await gateway.discover_commands(soap_socket)
return gateway return gateway
async def discover_commands(self): async def discover_commands(self, soap_socket: socket.socket = None):
response = await scpd_get("/" + self.path.decode(), self.base_ip.decode(), self.port) response = await scpd_get("/" + self.path.decode(), self.base_ip.decode(), self.port)
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION) self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
@ -141,9 +151,11 @@ class Gateway:
else: else:
self._device = Device(self._devices, self._services) self._device = Device(self._devices, self._services)
for service_type in self.services.keys(): for service_type in self.services.keys():
await self.register_commands(self.services[service_type]) await self.register_commands(self.services[service_type], soap_socket)
async def register_commands(self, service: Service): async def register_commands(self, service: Service, soap_socket: socket.socket = None):
if not service.SCPDURL:
raise UPnPError("no scpd url")
service_dict = await scpd_get(("" if service.SCPDURL.startswith("/") else "/") + service.SCPDURL, service_dict = await scpd_get(("" if service.SCPDURL.startswith("/") else "/") + service.SCPDURL,
self.base_ip.decode(), self.port) self.base_ip.decode(), self.port)
if not service_dict: if not service_dict:
@ -157,8 +169,10 @@ class Gateway:
annotations = current.__annotations__ annotations = current.__annotations__
return_types = annotations.get('return', None) return_types = annotations.get('return', None)
if return_types: if return_types:
if not isinstance(return_types, tuple): if isinstance(return_types, type):
return_types = (return_types, ) return_types = (return_types, )
else:
return_types = tuple([return_type_lambas.get(a, a) for a in return_types.__args__])
return_types = {r: t for r, t in zip(outputs, return_types)} return_types = {r: t for r, t in zip(outputs, return_types)}
param_types = {} param_types = {}
for param_name, param_type in annotations.items(): for param_name, param_type in annotations.items():
@ -167,7 +181,7 @@ class Gateway:
param_types[param_name] = param_type param_types[param_name] = param_type
command = SCPDCommand( command = SCPDCommand(
self.base_ip.decode(), self.port, service.controlURL, service.serviceType.encode(), self.base_ip.decode(), self.port, service.controlURL, service.serviceType.encode(),
name, param_types, return_types, inputs, outputs) name, param_types, return_types, inputs, outputs, soap_socket)
setattr(command, "__doc__", current.__doc__) setattr(command, "__doc__", current.__doc__)
setattr(self.commands, command.method, command) setattr(self.commands, command.method, command)

View file

@ -5,42 +5,44 @@ from asyncio.transports import DatagramTransport
class MulticastProtocol(DatagramProtocol): class MulticastProtocol(DatagramProtocol):
def __init__(self, multicast_address: str, bind_address: str) -> None:
self.multicast_address = multicast_address
self.bind_address = bind_address
self.transport: DatagramTransport
@property @property
def socket(self) -> socket.socket: def sock(self) -> socket.socket:
return self.transport.get_extra_info('socket') s: socket.socket = self.transport.get_extra_info(name='socket')
return s
def get_ttl(self) -> int: def get_ttl(self) -> int:
return self.socket.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL) return self.sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
def set_ttl(self, ttl: int = 1) -> None: def set_ttl(self, ttl: int = 1) -> None:
self.socket.setsockopt( self.sock.setsockopt(
socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl) socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl)
) )
def join_group(self, addr: str, interface: str) -> None: def join_group(self, multicast_address: str, bind_address: str) -> None:
addr = socket.inet_aton(addr) self.sock.setsockopt(
interface = socket.inet_aton(interface)
self.socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, addr + interface)
def leave_group(self, addr: str, interface: str) -> None:
addr = socket.inet_aton(addr)
interface = socket.inet_aton(interface)
self.socket.setsockopt(socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, addr + interface)
def connection_made(self, transport: DatagramTransport) -> None:
self.transport = transport
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@classmethod
def create_socket(cls, bind_address: str, multicast_address: str):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((bind_address, 0))
sock.setsockopt(
socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address) socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
) )
def leave_group(self, multicast_address: str, bind_address: str) -> None:
self.sock.setsockopt(
socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP,
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
)
def connection_made(self, transport) -> None:
self.transport = transport
self.join_group(self.multicast_address, self.bind_address)
@classmethod
def create_multicast_socket(cls, bind_address: str):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((bind_address, 0))
sock.setblocking(False) sock.setblocking(False)
return sock return sock

View file

@ -1,4 +1,5 @@
import logging import logging
import socket
from xml.etree import ElementTree from xml.etree import ElementTree
import asyncio import asyncio
from asyncio.protocols import Protocol from asyncio.protocols import Protocol
@ -16,7 +17,7 @@ class SCPDHTTPClientProtocol(Protocol):
GET = 'GET' GET = 'GET'
def __init__(self, method: str, message: bytes, finished: asyncio.Future, soap_method: str=None, def __init__(self, method: str, message: bytes, finished: asyncio.Future, soap_method: str=None,
soap_service_id: str=None, close_after_send: bool = False): soap_service_id: str=None, close_after_send: bool = False) -> None:
self.method = method self.method = method
assert soap_service_id is not None and soap_method is not None if method == 'POST' else True, \ assert soap_service_id is not None and soap_method is not None if method == 'POST' else True, \
'soap args not provided' 'soap args not provided'
@ -60,7 +61,7 @@ class SCPDHTTPClientProtocol(Protocol):
async def scpd_get(control_url: str, address: str, port: int) -> dict: async def scpd_get(control_url: str, address: str, port: int) -> dict:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
finished = asyncio.Future() finished: asyncio.Future = asyncio.Future()
packet = serialize_scpd_get(control_url, address) packet = serialize_scpd_get(control_url, address)
transport, protocol = await loop.create_connection( transport, protocol = await loop.create_connection(
lambda : SCPDHTTPClientProtocol('GET', packet, finished), address, port lambda : SCPDHTTPClientProtocol('GET', packet, finished), address, port
@ -72,15 +73,15 @@ async def scpd_get(control_url: str, address: str, port: int) -> dict:
async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes, async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes,
close_after_send: bool, **kwargs): close_after_send: bool, soap_socket: socket.socket = None, **kwargs):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
finished = asyncio.Future() finished: asyncio.Future = asyncio.Future()
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs) packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
transport, protocol = await loop.create_connection( transport, protocol = await loop.create_connection(
lambda : SCPDHTTPClientProtocol( lambda : SCPDHTTPClientProtocol(
'POST', packet, finished, soap_method=method, soap_service_id=service_id.decode(), 'POST', packet, finished, soap_method=method, soap_service_id=service_id.decode(),
close_after_send=close_after_send close_after_send=close_after_send
), address, port ), address, port, sock=soap_socket
) )
try: try:
return await asyncio.wait_for(finished, 1.0) return await asyncio.wait_for(finished, 1.0)

View file

@ -1,4 +1,6 @@
import logging import logging
import socket
import typing
from aioupnp.protocols.scpd import scpd_post from aioupnp.protocols.scpd import scpd_post
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -6,7 +8,8 @@ log = logging.getLogger(__name__)
class SCPDCommand: class SCPDCommand:
def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str, def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str,
param_types: dict, return_types: dict, param_order: list, return_order: list): param_types: dict, return_types: dict, param_order: list, return_order: list,
soap_socket: socket.socket = None) -> None:
self.gateway_address = gateway_address self.gateway_address = gateway_address
self.service_port = service_port self.service_port = service_port
self.control_url = control_url self.control_url = control_url
@ -16,17 +19,19 @@ class SCPDCommand:
self.param_order = param_order self.param_order = param_order
self.return_types = return_types self.return_types = return_types
self.return_order = return_order self.return_order = return_order
self.soap_socket = soap_socket
async def __call__(self, **kwargs): async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]:
if set(kwargs.keys()) != set(self.param_types.keys()): if set(kwargs.keys()) != set(self.param_types.keys()):
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys())) raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys()))
close_after_send = not self.return_types or self.return_types == [None] close_after_send = not self.return_types or self.return_types == [None]
response = await scpd_post( response = await scpd_post(
self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, self.service_id, self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, self.service_id,
close_after_send, **{n: self.param_types[n](kwargs[n]) for n in self.param_types.keys()} close_after_send, self.soap_socket, **{n: self.param_types[n](kwargs[n]) for n in self.param_types.keys()}
) )
extracted_response = tuple([None if self.return_types[n] is None else self.return_types[n](response[n]) result = tuple([self.return_types[n](response.get(n)) for n in self.return_order])
for n in self.return_order]) or (None, ) if not result:
if len(extracted_response) == 1: return None
return extracted_response[0] if len(result) == 1:
return extracted_response return result[0]
return result

View file

@ -1,9 +1,9 @@
import re import re
import socket
import binascii import binascii
import asyncio import asyncio
import logging import logging
from typing import DefaultDict from typing import Dict, List, Tuple
from asyncio.coroutines import coroutine
from asyncio.futures import Future from asyncio.futures import Future
from asyncio.transports import DatagramTransport from asyncio.transports import DatagramTransport
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
@ -17,17 +17,12 @@ log = logging.getLogger(__name__)
class SSDPProtocol(MulticastProtocol): class SSDPProtocol(MulticastProtocol):
def __init__(self, lan_address): def __init__(self, multicast_address: str, lan_address: str) -> None:
super().__init__() super().__init__(multicast_address, lan_address)
self.lan_address = lan_address self.lan_address = lan_address
self.discover_callbacks: DefaultDict[coroutine] = {} self.discover_callbacks: Dict = {}
self.transport: DatagramTransport self.notifications: List = []
self.notifications = [] self.replies: List = []
self.replies = []
def connection_made(self, transport: DatagramTransport):
super().connection_made(transport)
self.set_ttl(1)
async def m_search(self, address, timeout: int = 1, service=UPNP_ORG_IGD) -> SSDPDatagram: async def m_search(self, address, timeout: int = 1, service=UPNP_ORG_IGD) -> SSDPDatagram:
if (address, service) in self.discover_callbacks: if (address, service) in self.discover_callbacks:
@ -37,7 +32,7 @@ class SSDPProtocol(MulticastProtocol):
mx=1 mx=1
) )
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT)) self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT))
f = Future() f: Future = Future()
self.discover_callbacks[(address, service)] = f self.discover_callbacks[(address, service)] = f
return await asyncio.wait_for(f, timeout) return await asyncio.wait_for(f, timeout)
@ -57,8 +52,9 @@ class SSDPProtocol(MulticastProtocol):
if (addr[0], packet.st) in self.discover_callbacks: if (addr[0], packet.st) in self.discover_callbacks:
if packet.st not in map(lambda p: p['st'], self.replies): if packet.st not in map(lambda p: p['st'], self.replies):
self.replies.append(packet) self.replies.append(packet)
f: Future = self.discover_callbacks.pop((addr[0], packet.st)) ok_fut: Future = self.discover_callbacks.pop((addr[0], packet.st))
f.set_result(packet) ok_fut.set_result(packet)
return
elif packet._packet_type == packet._NOTIFY: elif packet._packet_type == packet._NOTIFY:
if packet.nt == SSDP_ROOT_DEVICE: if packet.nt == SSDP_ROOT_DEVICE:
@ -70,21 +66,27 @@ class SSDPProtocol(MulticastProtocol):
break break
if key: if key:
log.debug("got a notification with the requested m-search info") log.debug("got a notification with the requested m-search info")
f: Future = self.discover_callbacks.pop(key) notify_fut: Future = self.discover_callbacks.pop(key)
f.set_result(SSDPDatagram( notify_fut.set_result(SSDPDatagram(
SSDPDatagram._OK, cache_control='', location=packet.location, server=packet.server, SSDPDatagram._OK, cache_control='', location=packet.location, server=packet.server,
st=UPNP_ORG_IGD, usn=packet.usn st=UPNP_ORG_IGD, usn=packet.usn
)) ))
self.notifications.append(packet.as_dict()) self.notifications.append(packet.as_dict())
return
async def listen_ssdp(lan_address: str, gateway_address: str) -> (DatagramTransport, SSDPProtocol, str, str): async def listen_ssdp(lan_address: str, gateway_address: str,
ssdp_socket: socket.socket = None) -> Tuple[DatagramTransport, SSDPProtocol,
str, str]:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
try: try:
sock = SSDPProtocol.create_socket(lan_address, SSDP_IP_ADDRESS) sock = ssdp_socket or SSDPProtocol.create_multicast_socket(lan_address)
transport, protocol = await loop.create_datagram_endpoint( listen_result: Tuple = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(lan_address), sock=sock lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address), sock=sock
) )
transport: DatagramTransport = listen_result[0]
protocol: SSDPProtocol = listen_result[1]
protocol.set_ttl(1)
except Exception: except Exception:
log.exception("failed to create multicast socket %s:%i", lan_address, SSDP_PORT) log.exception("failed to create multicast socket %s:%i", lan_address, SSDP_PORT)
raise raise
@ -92,12 +94,12 @@ async def listen_ssdp(lan_address: str, gateway_address: str) -> (DatagramTransp
async def m_search(lan_address: str, gateway_address: str, timeout: int = 1, async def m_search(lan_address: str, gateway_address: str, timeout: int = 1,
service: str = UPNP_ORG_IGD) -> SSDPDatagram: service: str = UPNP_ORG_IGD, ssdp_socket: socket.socket = None) -> SSDPDatagram:
transport, protocol, gateway_address, lan_address = await listen_ssdp( transport, protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address lan_address, gateway_address, ssdp_socket
) )
try: try:
return await protocol.m_search(gateway_address, timeout=timeout, service=service) return await protocol.m_search(address=gateway_address, timeout=timeout, service=service)
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
finally: finally:

View file

@ -1,4 +1,5 @@
import re import re
from typing import Dict
from xml.etree import ElementTree from xml.etree import ElementTree
from aioupnp.constants import XML_VERSION, DEVICE, ROOT from aioupnp.constants import XML_VERSION, DEVICE, ROOT
from aioupnp.util import etree_to_dict, flatten_keys from aioupnp.util import etree_to_dict, flatten_keys
@ -33,7 +34,7 @@ def serialize_scpd_get(path: str, address: str) -> bytes:
).encode() ).encode()
def deserialize_scpd_get_response(content: bytes) -> dict: def deserialize_scpd_get_response(content: bytes) -> Dict:
if XML_VERSION.encode() in content: if XML_VERSION.encode() in content:
parsed = CONTENT_PATTERN.findall(content) parsed = CONTENT_PATTERN.findall(content)
content = b'' if not parsed else parsed[0][0] content = b'' if not parsed else parsed[0][0]
@ -47,3 +48,4 @@ def deserialize_scpd_get_response(content: bytes) -> dict:
root = m[2][5] root = m[2][5]
break break
return flatten_keys(xml_dict, "{%s}" % schema_key)[root] return flatten_keys(xml_dict, "{%s}" % schema_key)[root]
return {}

View file

@ -1,6 +1,7 @@
import re import re
import logging import logging
import binascii import binascii
from typing import Dict, List
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.constants import line_separator from aioupnp.constants import line_separator
@ -69,13 +70,8 @@ class SSDPDatagram(object):
] ]
} }
_marshallers = {
'mx': str,
'man': lambda x: ("\"%s\"" % x)
}
def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None, def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None,
cache_control=None, server=None, date=None, ext=None, **kwargs): cache_control=None, server=None, date=None, ext=None, **kwargs) -> None:
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]: if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
raise UPnPError("unknown packet type: {}".format(packet_type)) raise UPnPError("unknown packet type: {}".format(packet_type))
self._packet_type = packet_type self._packet_type = packet_type
@ -95,8 +91,9 @@ class SSDPDatagram(object):
if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None: if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None:
setattr(self, k.lower(), v) setattr(self, k.lower(), v)
def __repr__(self): def __repr__(self) -> str:
return ("SSDPDatagram(packet_type=%s, " % self._packet_type) + ", ".join("%s=%s" % (n, v) for n, v in self.as_dict().items()) + ")" return ("SSDPDatagram(packet_type=%s, " % self._packet_type) + \
", ".join("%s=%s" % (n, v) for n, v in self.as_dict().items()) + ")"
def __getitem__(self, item): def __getitem__(self, item):
for i in self._required_fields[self._packet_type]: for i in self._required_fields[self._packet_type]:
@ -104,17 +101,19 @@ class SSDPDatagram(object):
return getattr(self, i) return getattr(self, i)
raise KeyError(item) raise KeyError(item)
def get_friendly_name(self): def get_friendly_name(self) -> str:
return self._friendly_names[self._packet_type] return self._friendly_names[self._packet_type]
def encode(self, trailing_newlines=2): def encode(self, trailing_newlines: int = 2) -> str:
lines = [self._start_lines[self._packet_type]] lines = [self._start_lines[self._packet_type]]
for attr_name in self._required_fields[self._packet_type]: for attr_name in self._required_fields[self._packet_type]:
attr = getattr(self, attr_name) attr = getattr(self, attr_name)
if attr is None: if attr is None:
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name)) raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
if attr_name in self._marshallers: if attr_name == 'mx':
value = self._marshallers[attr_name](attr) value = str(attr)
elif attr_name == 'man':
value = "\"%s\"" % attr
else: else:
value = attr value = attr
lines.append("{}: {}".format(attr_name.upper(), value)) lines.append("{}: {}".format(attr_name.upper(), value))
@ -123,7 +122,7 @@ class SSDPDatagram(object):
serialized += line_separator serialized += line_separator
return serialized return serialized
def as_dict(self): def as_dict(self) -> Dict:
return self._lines_to_content_dict(self.encode().split(line_separator)) return self._lines_to_content_dict(self.encode().split(line_separator))
@classmethod @classmethod
@ -142,8 +141,8 @@ class SSDPDatagram(object):
return packet return packet
@classmethod @classmethod
def _lines_to_content_dict(cls, lines: list) -> dict: def _lines_to_content_dict(cls, lines: list) -> Dict:
result = {} result: dict = {}
for line in lines: for line in lines:
if not line: if not line:
continue continue
@ -175,13 +174,13 @@ class SSDPDatagram(object):
return cls._from_response(lines[1:]) return cls._from_response(lines[1:])
@classmethod @classmethod
def _from_response(cls, lines): def _from_response(cls, lines: List):
return cls(cls._OK, **cls._lines_to_content_dict(lines)) return cls(cls._OK, **cls._lines_to_content_dict(lines))
@classmethod @classmethod
def _from_notify(cls, lines): def _from_notify(cls, lines: List):
return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines)) return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines))
@classmethod @classmethod
def _from_request(cls, lines): def _from_request(cls, lines: List):
return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines)) return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines))

View file

@ -3,7 +3,9 @@ from aioupnp.serialization.soap import serialize_soap_post
class TestSOAPSerialization(unittest.TestCase): class TestSOAPSerialization(unittest.TestCase):
method, param_names, gateway_address, kwargs = "GetExternalIPAddress", [], b'10.0.0.1', {} param_names: list = []
kwargs: dict = {}
method, gateway_address = "GetExternalIPAddress", b'10.0.0.1'
st, lan_address, path = b'urn:schemas-upnp-org:service:WANIPConnection:1', '10.0.0.1', b'/soap.cgi?service=WANIPConn1' st, lan_address, path = b'urn:schemas-upnp-org:service:WANIPConnection:1', '10.0.0.1', b'/soap.cgi?service=WANIPConn1'
expected_result = b'POST /soap.cgi?service=WANIPConn1 HTTP/1.1\r\n' \ expected_result = b'POST /soap.cgi?service=WANIPConn1 HTTP/1.1\r\n' \
b'Host: 10.0.0.1\r\nUser-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n' \ b'Host: 10.0.0.1\r\nUser-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n' \
@ -15,7 +17,6 @@ class TestSOAPSerialization(unittest.TestCase):
b'<s:Body><u:GetExternalIPAddress xmlns:u="urn:schemas-upnp-org:service:WANIPConnection:1">' \ b'<s:Body><u:GetExternalIPAddress xmlns:u="urn:schemas-upnp-org:service:WANIPConnection:1">' \
b'</u:GetExternalIPAddress></s:Body></s:Envelope>\r\n' b'</u:GetExternalIPAddress></s:Body></s:Envelope>\r\n'
def test_serialize_get(self): def test_serialize_get(self):
self.assertEqual(serialize_soap_post( self.assertEqual(serialize_soap_post(
self.method, self.param_names, self.st, self.gateway_address, self.path, **self.kwargs self.method, self.param_names, self.st, self.gateway_address, self.path, **self.kwargs

View file

@ -1,6 +1,7 @@
import unittest import unittest
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.constants import UPNP_ORG_IGD
class TestParseMSearchRequest(unittest.TestCase): class TestParseMSearchRequest(unittest.TestCase):
@ -11,7 +12,7 @@ class TestParseMSearchRequest(unittest.TestCase):
b'MX: 1\r\n' \ b'MX: 1\r\n' \
b'\r\n' b'\r\n'
def test_parse_m_search_response(self): def test_parse_m_search(self):
packet = SSDPDatagram.decode(self.datagram) packet = SSDPDatagram.decode(self.datagram)
self.assertTrue(packet._packet_type, packet._M_SEARCH) self.assertTrue(packet._packet_type, packet._M_SEARCH)
self.assertEqual(packet.host, '239.255.255.250:1900') self.assertEqual(packet.host, '239.255.255.250:1900')
@ -46,6 +47,30 @@ class TestParseMSearchResponse(unittest.TestCase):
) )
class TestParseMSearchResponseRedSonic(TestParseMSearchResponse):
datagram = \
b"HTTP/1.1 200 OK\r\n" \
b"CACHE-CONTROL: max-age=1800\r\n" \
b"DATE: Thu, 04 Oct 2018 22:59:40 GMT\r\n" \
b"EXT:\r\n" \
b"LOCATION: http://10.1.10.1:49152/IGDdevicedesc_brlan0.xml\r\n" \
b"OPT: \"http://schemas.upnp.org/upnp/1/0/\"; ns=01\r\n" \
b"01-NLS: 00000000-0000-0000-0000-000000000000\r\n" \
b"SERVER: Linux/3.14.28-Prod_17.2, UPnP/1.0, Portable SDK for UPnP devices/1.6.22\r\n" \
b"X-User-Agent: redsonic\r\n" \
b"ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" \
b"USN: uuid:00000000-0000-0000-0000-000000000000::urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" \
b"\r\n"
def test_parse_m_search_response(self):
packet = SSDPDatagram.decode(self.datagram)
self.assertTrue(packet._packet_type, packet._OK)
self.assertEqual(packet.cache_control, 'max-age=1800')
self.assertEqual(packet.location, 'http://10.1.10.1:49152/IGDdevicedesc_brlan0.xml')
self.assertEqual(packet.server, 'Linux/3.14.28-Prod_17.2, UPnP/1.0, Portable SDK for UPnP devices/1.6.22')
self.assertEqual(packet.st, UPNP_ORG_IGD)
class TestParseMSearchResponseDashCacheControl(TestParseMSearchResponse): class TestParseMSearchResponseDashCacheControl(TestParseMSearchResponse):
datagram = "\r\n".join([ datagram = "\r\n".join([
'HTTP/1.1 200 OK', 'HTTP/1.1 200 OK',

View file

@ -1,8 +1,9 @@
import os import os
import socket
import logging import logging
import json import json
import asyncio import asyncio
import functools from typing import Tuple, Dict, List, Union
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.gateway import Gateway from aioupnp.gateway import Gateway
from aioupnp.constants import UPNP_ORG_IGD from aioupnp.constants import UPNP_ORG_IGD
@ -12,19 +13,9 @@ from aioupnp.protocols.ssdp import m_search
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def cli(format_result=None): def cli(fn):
def _cli(fn): fn._cli = True
@functools.wraps(fn) return fn
async def _inner(*args, **kwargs):
result = await fn(*args, **kwargs)
if not format_result or not result or not isinstance(result, (list, dict, tuple)):
return result
self = args[0]
return {k: v for k, v in zip(getattr(self.gateway.commands, format_result).return_order, result)}
f = _inner
f._cli = True
return f
return _cli
def _encode(x): def _encode(x):
@ -36,14 +27,14 @@ def _encode(x):
class UPnP: class UPnP:
def __init__(self, lan_address: str, gateway_address: str, gateway: Gateway): def __init__(self, lan_address: str, gateway_address: str, gateway: Gateway) -> None:
self.lan_address = lan_address self.lan_address = lan_address
self.gateway_address = gateway_address self.gateway_address = gateway_address
self.gateway = gateway self.gateway = gateway
@classmethod @classmethod
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '', def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
interface_name: str = 'default') -> (str, str): interface_name: str = 'default') -> Tuple[str, str]:
if not lan_address or not gateway_address: if not lan_address or not gateway_address:
gateway_addr, lan_addr = get_gateway_and_lan_addresses(interface_name) gateway_addr, lan_addr = get_gateway_and_lan_addresses(interface_name)
lan_address = lan_address or lan_addr lan_address = lan_address or lan_addr
@ -52,18 +43,19 @@ class UPnP:
@classmethod @classmethod
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
service: str = UPNP_ORG_IGD, interface_name: str = 'default'): service: str = UPNP_ORG_IGD, interface_name: str = 'default',
ssdp_socket: socket.socket = None, soap_socket: socket.socket = None):
try: try:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
except Exception as err: except Exception as err:
raise UPnPError("failed to get lan and gateway addresses: %s" % str(err)) raise UPnPError("failed to get lan and gateway addresses: %s" % str(err))
gateway = await Gateway.discover_gateway(lan_address, gateway_address, timeout, service) gateway = await Gateway.discover_gateway(lan_address, gateway_address, timeout, service, ssdp_socket, soap_socket)
return cls(lan_address, gateway_address, gateway) return cls(lan_address, gateway_address, gateway)
@classmethod @classmethod
@cli() @cli
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
service: str = UPNP_ORG_IGD, interface_name: str = 'default') -> dict: service: str = UPNP_ORG_IGD, interface_name: str = 'default') -> Dict:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
datagram = await m_search(lan_address, gateway_address, timeout, service) datagram = await m_search(lan_address, gateway_address, timeout, service)
return { return {
@ -72,32 +64,39 @@ class UPnP:
'discover_reply': datagram.as_dict() 'discover_reply': datagram.as_dict()
} }
@cli() @cli
async def get_external_ip(self) -> str: async def get_external_ip(self) -> str:
return await self.gateway.commands.GetExternalIPAddress() return await self.gateway.commands.GetExternalIPAddress()
@cli("AddPortMapping") @cli
async def add_port_mapping(self, external_port: int, protocol: str, internal_port, lan_address: str, async def add_port_mapping(self, external_port: int, protocol: str, internal_port, lan_address: str,
description: str) -> None: description: str) -> None:
return await self.gateway.commands.AddPortMapping( await self.gateway.commands.AddPortMapping(
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol, NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol,
NewInternalPort=internal_port, NewInternalClient=lan_address, NewInternalPort=internal_port, NewInternalClient=lan_address,
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration="" NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration=""
) )
return
@cli("GetGenericPortMappingEntry") @cli
async def get_port_mapping_by_index(self, index: int) -> dict: async def get_port_mapping_by_index(self, index: int) -> Dict:
return await self._get_port_mapping_by_index(index) result = await self._get_port_mapping_by_index(index)
if result:
return {
k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result)
}
return {}
async def _get_port_mapping_by_index(self, index: int) -> (str, int, str, int, str, bool, str, int): async def _get_port_mapping_by_index(self, index: int) -> Union[Tuple[str, int, str, int, str, bool, str, int],
None]:
try: try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index) redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
return redirect return redirect
except UPnPError: except UPnPError:
return return None
@cli() @cli
async def get_redirects(self) -> list: async def get_redirects(self) -> List[Dict]:
redirects = [] redirects = []
cnt = 0 cnt = 0
redirect = await self.get_port_mapping_by_index(cnt) redirect = await self.get_port_mapping_by_index(cnt)
@ -107,8 +106,8 @@ class UPnP:
redirect = await self.get_port_mapping_by_index(cnt) redirect = await self.get_port_mapping_by_index(cnt)
return redirects return redirects
@cli("GetSpecificPortMappingEntry") @cli
async def get_specific_port_mapping(self, external_port: int, protocol: str): async def get_specific_port_mapping(self, external_port: int, protocol: str) -> Dict:
""" """
:param external_port: (int) external port to listen on :param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP' :param protocol: (str) 'UDP' | 'TCP'
@ -116,13 +115,14 @@ class UPnP:
""" """
try: try:
return await self.gateway.commands.GetSpecificPortMappingEntry( result = await self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol
) )
return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)}
except UPnPError: except UPnPError:
return return {}
@cli() @cli
async def delete_port_mapping(self, external_port: int, protocol: str) -> None: async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
""" """
:param external_port: (int) external port to listen on :param external_port: (int) external port to listen on
@ -133,7 +133,7 @@ class UPnP:
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
) )
@cli("AddPortMapping") @cli
async def get_next_mapping(self, port: int, protocol: str, description: str, internal_port: int=None) -> int: async def get_next_mapping(self, port: int, protocol: str, description: str, internal_port: int=None) -> int:
if protocol not in ["UDP", "TCP"]: if protocol not in ["UDP", "TCP"]:
raise UPnPError("unsupported protocol: {}".format(protocol)) raise UPnPError("unsupported protocol: {}".format(protocol))
@ -163,14 +163,14 @@ class UPnP:
) )
return port return port
@cli() @cli
async def get_soap_commands(self) -> dict: async def get_soap_commands(self) -> Dict:
return { return {
'supported': list(self.gateway._registered_commands.keys()), 'supported': list(self.gateway._registered_commands.keys()),
'unsupported': self.gateway._unsupported_actions 'unsupported': self.gateway._unsupported_actions
} }
@cli() @cli
async def generate_test_data(self): async def generate_test_data(self):
external_ip = await self.get_external_ip() external_ip = await self.get_external_ip()
redirects = await self.get_redirects() redirects = await self.get_redirects()
@ -216,13 +216,13 @@ class UPnP:
@classmethod @classmethod
def run_cli(cls, method, lan_address: str = '', gateway_address: str = '', timeout: int = 60, def run_cli(cls, method, lan_address: str = '', gateway_address: str = '', timeout: int = 60,
service: str = UPNP_ORG_IGD, interface_name: str = 'default', service: str = UPNP_ORG_IGD, interface_name: str = 'default',
kwargs: dict = None): kwargs: dict = None) -> None:
kwargs = kwargs or {} kwargs = kwargs or {}
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
fut = asyncio.Future() fut: asyncio.Future = asyncio.Future()
async def wrapper(): async def wrapper():
if method == 'm_search': if method == 'm_search':

View file

@ -1,6 +1,7 @@
import re import re
import socket import socket
from collections import defaultdict from collections import defaultdict
from typing import Tuple, Dict
from xml.etree import ElementTree from xml.etree import ElementTree
import netifaces import netifaces
@ -9,15 +10,17 @@ BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode()) BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
def etree_to_dict(t: ElementTree) -> dict: def etree_to_dict(t: ElementTree.Element) -> Dict:
d = {t.tag: {} if t.attrib else None} d: dict = {}
if t.attrib:
d[t.tag] = {}
children = list(t) children = list(t)
if children: if children:
dd = defaultdict(list) dd: dict = defaultdict(list)
for dc in map(etree_to_dict, children): for dc in map(etree_to_dict, children):
for k, v in dc.items(): for k, v in dc.items():
dd[k].append(v) dd[k].append(v)
d = {t.tag: {k: v[0] if len(v) == 1 else v for k, v in dd.items()}} d[t.tag] = {k: v[0] if len(v) == 1 else v for k, v in dd.items()}
if t.attrib: if t.attrib:
d[t.tag].update(('@' + k, v) for k, v in t.attrib.items()) d[t.tag].update(('@' + k, v) for k, v in t.attrib.items())
if t.text: if t.text:
@ -81,8 +84,8 @@ def get_interfaces():
return r return r
def get_gateway_and_lan_addresses(interface_name: str) -> (str, str): def get_gateway_and_lan_addresses(interface_name: str) -> Tuple[str, str]:
for iface_name, (gateway, lan) in get_interfaces().items(): for iface_name, (gateway, lan) in get_interfaces().items():
if interface_name == iface_name: if interface_name == iface_name:
return gateway, lan return gateway, lan
return None, None return '', ''

4
mypy.ini Normal file
View file

@ -0,0 +1,4 @@
[mypy]
python_version = 3.7
mypy_path=stubs
cache_dir=/dev/null

View file

@ -1,5 +1,5 @@
import os import os
from setuptools import setup, find_packages from setuptools import setup, find_packages # type: ignore
from aioupnp import __version__, __name__, __email__, __author__, __license__ from aioupnp import __version__, __name__, __email__, __author__, __license__
console_scripts = [ console_scripts = [

111
stubs/netifaces.py Normal file
View file

@ -0,0 +1,111 @@
import typing
AF_APPLETALK = 5
AF_ASH = 18
AF_ATMPVC = 8
AF_ATMSVC = 20
AF_AX25 = 3
AF_BLUETOOTH = 31
AF_BRIDGE = 7
AF_DECnet = 12
AF_ECONET = 19
AF_FILE = 1
AF_INET = 2
AF_INET6 = 10
AF_IPX = 4
AF_IRDA = 23
AF_ISDN = 34
AF_KEY = 15
AF_LINK = 17
AF_NETBEUI = 13
AF_NETLINK = 16
AF_NETROM = 6
AF_PACKET = 17
AF_PPPOX = 24
AF_ROSE = 11
AF_ROUTE = 16
AF_SECURITY = 14
AF_SNA = 22
AF_UNIX = 1
AF_UNSPEC = 0
AF_WANPIPE = 25
AF_X25 = 9
version = '0.10.7'
# functions
def gateways(*args, **kwargs) -> typing.List: # real signature unknown
"""
Obtain a list of the gateways on this machine.
Returns a dict whose keys are equal to the address family constants,
e.g. netifaces.AF_INET, and whose values are a list of tuples of the
format (<address>, <interface>, <is_default>).
There is also a special entry with the key 'default', which you can use
to quickly obtain the default gateway for a particular address family.
There may in general be multiple gateways; different address
families may have different gateway settings (e.g. AF_INET vs AF_INET6)
and on some systems it's also possible to have interface-specific
default gateways.
"""
pass
def ifaddresses(*args, **kwargs) -> typing.Dict: # real signature unknown
"""
Obtain information about the specified network interface.
Returns a dict whose keys are equal to the address family constants,
e.g. netifaces.AF_INET, and whose values are a list of addresses in
that family that are attached to the network interface.
"""
pass
def interfaces(*args, **kwargs) -> typing.List: # real signature unknown
""" Obtain a list of the interfaces available on this machine. """
pass
# no classes
# variables with complex values
address_families = {
0: 'AF_UNSPEC',
1: 'AF_FILE',
2: 'AF_INET',
3: 'AF_AX25',
4: 'AF_IPX',
5: 'AF_APPLETALK',
6: 'AF_NETROM',
7: 'AF_BRIDGE',
8: 'AF_ATMPVC',
9: 'AF_X25',
10: 'AF_INET6',
11: 'AF_ROSE',
12: 'AF_DECnet',
13: 'AF_NETBEUI',
14: 'AF_SECURITY',
15: 'AF_KEY',
16: 'AF_NETLINK',
17: 'AF_PACKET',
18: 'AF_ASH',
19: 'AF_ECONET',
20: 'AF_ATMSVC',
22: 'AF_SNA',
23: 'AF_IRDA',
24: 'AF_PPPOX',
25: 'AF_WANPIPE',
31: 'AF_BLUETOOTH',
34: 'AF_ISDN',
}
__loader__ = None # (!) real value is ''
__spec__ = None # (!) real value is ''