mypy
This commit is contained in:
parent
0eca7478e2
commit
1098177bae
18 changed files with 335 additions and 159 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -4,3 +4,4 @@ _trial_temp/
|
||||||
build/
|
build/
|
||||||
dist/
|
dist/
|
||||||
.coverage
|
.coverage
|
||||||
|
.mypy_cache/
|
|
@ -4,11 +4,12 @@ language: python
|
||||||
python: "3.7"
|
python: "3.7"
|
||||||
|
|
||||||
before_install:
|
before_install:
|
||||||
- pip install pylint coverage
|
- pip install pylint coverage mypy
|
||||||
- pip install -e .
|
- pip install -e .
|
||||||
# - pylint aioupnp
|
|
||||||
|
|
||||||
script:
|
script:
|
||||||
|
# - pylint aioupnp
|
||||||
|
- mypy .
|
||||||
- HOME=/tmp coverage run --source=aioupnp -m unittest -v
|
- HOME=/tmp coverage run --source=aioupnp -m unittest -v
|
||||||
|
|
||||||
after_success:
|
after_success:
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
def none_or_str(x):
|
from typing import Tuple, Union
|
||||||
return None if not x or x == 'None' else str(x)
|
|
||||||
|
none_or_str = Union[None, str]
|
||||||
|
|
||||||
|
|
||||||
class SCPDCommands:
|
class SCPDCommands:
|
||||||
|
@ -14,12 +15,13 @@ class SCPDCommands:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def GetNATRSIPStatus() -> (bool, bool):
|
async def GetNATRSIPStatus() -> Tuple[bool, bool]:
|
||||||
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
|
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def GetGenericPortMappingEntry(NewPortMappingIndex: int) -> (none_or_str, int, str, int, str, bool, str, int):
|
async def GetGenericPortMappingEntry(NewPortMappingIndex: int) -> Tuple[none_or_str, int, str, int, str,
|
||||||
|
bool, str, int]:
|
||||||
"""
|
"""
|
||||||
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
|
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
|
||||||
NewPortMappingDescription, NewLeaseDuration)
|
NewPortMappingDescription, NewLeaseDuration)
|
||||||
|
@ -27,7 +29,8 @@ class SCPDCommands:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def GetSpecificPortMappingEntry(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> (int, str, bool, str, int):
|
async def GetSpecificPortMappingEntry(NewRemoteHost: str, NewExternalPort: int,
|
||||||
|
NewProtocol: str) -> Tuple[int, str, bool, str, int]:
|
||||||
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
|
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@ -42,12 +45,12 @@ class SCPDCommands:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def GetConnectionTypeInfo() -> (str, str):
|
async def GetConnectionTypeInfo() -> Tuple[str, str]:
|
||||||
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
|
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def GetStatusInfo() -> (str, str, int):
|
async def GetStatusInfo() -> Tuple[str, str, int]:
|
||||||
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
|
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@ -92,7 +95,7 @@ class SCPDCommands:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def X_GetICSStatistics() -> (int, int, int, int, str, str):
|
async def X_GetICSStatistics() -> Tuple[int, int, int, int, str, str]:
|
||||||
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
|
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@ -119,6 +122,6 @@ class SCPDCommands:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def GetActiveConnections() -> (str, str):
|
async def GetActiveConnections() -> Tuple[str, str]:
|
||||||
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
|
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
import logging
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CaseInsensitive:
|
class CaseInsensitive:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs) -> None:
|
||||||
not_evaluated = {}
|
not_evaluated = {}
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if k.startswith("_"):
|
if k.startswith("_"):
|
||||||
|
@ -22,6 +23,7 @@ class CaseInsensitive:
|
||||||
for k, v in self.__dict__.items():
|
for k, v in self.__dict__.items():
|
||||||
if k.lower() == case_insensitive.lower():
|
if k.lower() == case_insensitive.lower():
|
||||||
return k
|
return k
|
||||||
|
raise AttributeError(case_insensitive)
|
||||||
|
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
if item in self.__dict__:
|
if item in self.__dict__:
|
||||||
|
@ -75,7 +77,7 @@ class Device(CaseInsensitive):
|
||||||
presentationURL = None
|
presentationURL = None
|
||||||
iconList = None
|
iconList = None
|
||||||
|
|
||||||
def __init__(self, devices, services, **kwargs):
|
def __init__(self, devices: List, services: List, **kwargs) -> None:
|
||||||
super(Device, self).__init__(**kwargs)
|
super(Device, self).__init__(**kwargs)
|
||||||
if self.serviceList and "service" in self.serviceList:
|
if self.serviceList and "service" in self.serviceList:
|
||||||
new_services = self.serviceList["service"]
|
new_services = self.serviceList["service"]
|
||||||
|
|
|
@ -1,4 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
|
import socket
|
||||||
|
import typing
|
||||||
|
from typing import Dict, List, Union, Type, Tuple
|
||||||
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
|
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
|
||||||
from aioupnp.constants import SPEC_VERSION, UPNP_ORG_IGD, SERVICE
|
from aioupnp.constants import SPEC_VERSION, UPNP_ORG_IGD, SERVICE
|
||||||
from aioupnp.commands import SCPDCommands
|
from aioupnp.commands import SCPDCommands
|
||||||
|
@ -7,11 +10,16 @@ from aioupnp.protocols.ssdp import m_search
|
||||||
from aioupnp.protocols.scpd import scpd_get
|
from aioupnp.protocols.scpd import scpd_get
|
||||||
from aioupnp.protocols.soap import SCPDCommand
|
from aioupnp.protocols.soap import SCPDCommand
|
||||||
from aioupnp.util import flatten_keys
|
from aioupnp.util import flatten_keys
|
||||||
|
from aioupnp.fault import UPnPError
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
return_type_lambas = {
|
||||||
|
Union[None, str]: lambda x: x if x is not None and str(x).lower() not in ['none', 'nil'] else None
|
||||||
|
}
|
||||||
|
|
||||||
def get_action_list(element_dict: dict) -> list: # [(<method>, [<input1>, ...], [<output1, ...]), ...]
|
|
||||||
|
def get_action_list(element_dict: dict) -> List: # [(<method>, [<input1>, ...], [<output1, ...]), ...]
|
||||||
service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
|
service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
|
||||||
if "actionList" in service_info:
|
if "actionList" in service_info:
|
||||||
action_list = service_info["actionList"]
|
action_list = service_info["actionList"]
|
||||||
|
@ -20,7 +28,7 @@ def get_action_list(element_dict: dict) -> list: # [(<method>, [<input1>, ...],
|
||||||
if not len(action_list): # it could be an empty string
|
if not len(action_list): # it could be an empty string
|
||||||
return []
|
return []
|
||||||
|
|
||||||
result = []
|
result: list = []
|
||||||
if isinstance(action_list["action"], dict):
|
if isinstance(action_list["action"], dict):
|
||||||
arg_dicts = action_list["action"]['argumentList']['argument']
|
arg_dicts = action_list["action"]['argumentList']['argument']
|
||||||
if not isinstance(arg_dicts, list): # when there is one arg
|
if not isinstance(arg_dicts, list): # when there is one arg
|
||||||
|
@ -97,21 +105,22 @@ class Gateway:
|
||||||
return r
|
return r
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def services(self) -> dict:
|
def services(self) -> Dict:
|
||||||
if not self._device:
|
if not self._device:
|
||||||
return {}
|
return {}
|
||||||
return {service.serviceType: service for service in self._services}
|
return {service.serviceType: service for service in self._services}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def devices(self) -> dict:
|
def devices(self) -> Dict:
|
||||||
if not self._device:
|
if not self._device:
|
||||||
return {}
|
return {}
|
||||||
return {device.udn: device for device in self._devices}
|
return {device.udn: device for device in self._devices}
|
||||||
|
|
||||||
def get_service(self, service_type: str) -> Service:
|
def get_service(self, service_type: str) -> Union[Type[Service], None]:
|
||||||
for service in self._services:
|
for service in self._services:
|
||||||
if service.serviceType.lower() == service_type.lower():
|
if service.serviceType.lower() == service_type.lower():
|
||||||
return service
|
return service
|
||||||
|
return None
|
||||||
|
|
||||||
def debug_commands(self):
|
def debug_commands(self):
|
||||||
return {
|
return {
|
||||||
|
@ -121,13 +130,14 @@ class Gateway:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 1,
|
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 1,
|
||||||
service: str = UPNP_ORG_IGD):
|
service: str = UPNP_ORG_IGD, ssdp_socket: socket.socket = None,
|
||||||
datagram = await m_search(lan_address, gateway_address, timeout, service)
|
soap_socket: socket.socket = None):
|
||||||
|
datagram = await m_search(lan_address, gateway_address, timeout, service, ssdp_socket)
|
||||||
gateway = cls(**datagram.as_dict())
|
gateway = cls(**datagram.as_dict())
|
||||||
await gateway.discover_commands()
|
await gateway.discover_commands(soap_socket)
|
||||||
return gateway
|
return gateway
|
||||||
|
|
||||||
async def discover_commands(self):
|
async def discover_commands(self, soap_socket: socket.socket = None):
|
||||||
response = await scpd_get("/" + self.path.decode(), self.base_ip.decode(), self.port)
|
response = await scpd_get("/" + self.path.decode(), self.base_ip.decode(), self.port)
|
||||||
|
|
||||||
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
|
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
|
||||||
|
@ -141,9 +151,11 @@ class Gateway:
|
||||||
else:
|
else:
|
||||||
self._device = Device(self._devices, self._services)
|
self._device = Device(self._devices, self._services)
|
||||||
for service_type in self.services.keys():
|
for service_type in self.services.keys():
|
||||||
await self.register_commands(self.services[service_type])
|
await self.register_commands(self.services[service_type], soap_socket)
|
||||||
|
|
||||||
async def register_commands(self, service: Service):
|
async def register_commands(self, service: Service, soap_socket: socket.socket = None):
|
||||||
|
if not service.SCPDURL:
|
||||||
|
raise UPnPError("no scpd url")
|
||||||
service_dict = await scpd_get(("" if service.SCPDURL.startswith("/") else "/") + service.SCPDURL,
|
service_dict = await scpd_get(("" if service.SCPDURL.startswith("/") else "/") + service.SCPDURL,
|
||||||
self.base_ip.decode(), self.port)
|
self.base_ip.decode(), self.port)
|
||||||
if not service_dict:
|
if not service_dict:
|
||||||
|
@ -157,8 +169,10 @@ class Gateway:
|
||||||
annotations = current.__annotations__
|
annotations = current.__annotations__
|
||||||
return_types = annotations.get('return', None)
|
return_types = annotations.get('return', None)
|
||||||
if return_types:
|
if return_types:
|
||||||
if not isinstance(return_types, tuple):
|
if isinstance(return_types, type):
|
||||||
return_types = (return_types, )
|
return_types = (return_types, )
|
||||||
|
else:
|
||||||
|
return_types = tuple([return_type_lambas.get(a, a) for a in return_types.__args__])
|
||||||
return_types = {r: t for r, t in zip(outputs, return_types)}
|
return_types = {r: t for r, t in zip(outputs, return_types)}
|
||||||
param_types = {}
|
param_types = {}
|
||||||
for param_name, param_type in annotations.items():
|
for param_name, param_type in annotations.items():
|
||||||
|
@ -167,7 +181,7 @@ class Gateway:
|
||||||
param_types[param_name] = param_type
|
param_types[param_name] = param_type
|
||||||
command = SCPDCommand(
|
command = SCPDCommand(
|
||||||
self.base_ip.decode(), self.port, service.controlURL, service.serviceType.encode(),
|
self.base_ip.decode(), self.port, service.controlURL, service.serviceType.encode(),
|
||||||
name, param_types, return_types, inputs, outputs)
|
name, param_types, return_types, inputs, outputs, soap_socket)
|
||||||
setattr(command, "__doc__", current.__doc__)
|
setattr(command, "__doc__", current.__doc__)
|
||||||
setattr(self.commands, command.method, command)
|
setattr(self.commands, command.method, command)
|
||||||
|
|
||||||
|
|
|
@ -5,42 +5,44 @@ from asyncio.transports import DatagramTransport
|
||||||
|
|
||||||
|
|
||||||
class MulticastProtocol(DatagramProtocol):
|
class MulticastProtocol(DatagramProtocol):
|
||||||
|
def __init__(self, multicast_address: str, bind_address: str) -> None:
|
||||||
|
self.multicast_address = multicast_address
|
||||||
|
self.bind_address = bind_address
|
||||||
|
self.transport: DatagramTransport
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def socket(self) -> socket.socket:
|
def sock(self) -> socket.socket:
|
||||||
return self.transport.get_extra_info('socket')
|
s: socket.socket = self.transport.get_extra_info(name='socket')
|
||||||
|
return s
|
||||||
|
|
||||||
def get_ttl(self) -> int:
|
def get_ttl(self) -> int:
|
||||||
return self.socket.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
|
return self.sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
|
||||||
|
|
||||||
def set_ttl(self, ttl: int = 1) -> None:
|
def set_ttl(self, ttl: int = 1) -> None:
|
||||||
self.socket.setsockopt(
|
self.sock.setsockopt(
|
||||||
socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl)
|
socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl)
|
||||||
)
|
)
|
||||||
|
|
||||||
def join_group(self, addr: str, interface: str) -> None:
|
def join_group(self, multicast_address: str, bind_address: str) -> None:
|
||||||
addr = socket.inet_aton(addr)
|
self.sock.setsockopt(
|
||||||
interface = socket.inet_aton(interface)
|
|
||||||
self.socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, addr + interface)
|
|
||||||
|
|
||||||
def leave_group(self, addr: str, interface: str) -> None:
|
|
||||||
addr = socket.inet_aton(addr)
|
|
||||||
interface = socket.inet_aton(interface)
|
|
||||||
self.socket.setsockopt(socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, addr + interface)
|
|
||||||
|
|
||||||
def connection_made(self, transport: DatagramTransport) -> None:
|
|
||||||
self.transport = transport
|
|
||||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_socket(cls, bind_address: str, multicast_address: str):
|
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
|
|
||||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
||||||
sock.bind((bind_address, 0))
|
|
||||||
sock.setsockopt(
|
|
||||||
socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
|
socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
|
||||||
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
|
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def leave_group(self, multicast_address: str, bind_address: str) -> None:
|
||||||
|
self.sock.setsockopt(
|
||||||
|
socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP,
|
||||||
|
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
|
||||||
|
)
|
||||||
|
|
||||||
|
def connection_made(self, transport) -> None:
|
||||||
|
self.transport = transport
|
||||||
|
self.join_group(self.multicast_address, self.bind_address)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_multicast_socket(cls, bind_address: str):
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
|
||||||
|
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
sock.bind((bind_address, 0))
|
||||||
sock.setblocking(False)
|
sock.setblocking(False)
|
||||||
return sock
|
return sock
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
import socket
|
||||||
from xml.etree import ElementTree
|
from xml.etree import ElementTree
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio.protocols import Protocol
|
from asyncio.protocols import Protocol
|
||||||
|
@ -16,7 +17,7 @@ class SCPDHTTPClientProtocol(Protocol):
|
||||||
GET = 'GET'
|
GET = 'GET'
|
||||||
|
|
||||||
def __init__(self, method: str, message: bytes, finished: asyncio.Future, soap_method: str=None,
|
def __init__(self, method: str, message: bytes, finished: asyncio.Future, soap_method: str=None,
|
||||||
soap_service_id: str=None, close_after_send: bool = False):
|
soap_service_id: str=None, close_after_send: bool = False) -> None:
|
||||||
self.method = method
|
self.method = method
|
||||||
assert soap_service_id is not None and soap_method is not None if method == 'POST' else True, \
|
assert soap_service_id is not None and soap_method is not None if method == 'POST' else True, \
|
||||||
'soap args not provided'
|
'soap args not provided'
|
||||||
|
@ -60,7 +61,7 @@ class SCPDHTTPClientProtocol(Protocol):
|
||||||
|
|
||||||
async def scpd_get(control_url: str, address: str, port: int) -> dict:
|
async def scpd_get(control_url: str, address: str, port: int) -> dict:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
finished = asyncio.Future()
|
finished: asyncio.Future = asyncio.Future()
|
||||||
packet = serialize_scpd_get(control_url, address)
|
packet = serialize_scpd_get(control_url, address)
|
||||||
transport, protocol = await loop.create_connection(
|
transport, protocol = await loop.create_connection(
|
||||||
lambda : SCPDHTTPClientProtocol('GET', packet, finished), address, port
|
lambda : SCPDHTTPClientProtocol('GET', packet, finished), address, port
|
||||||
|
@ -72,15 +73,15 @@ async def scpd_get(control_url: str, address: str, port: int) -> dict:
|
||||||
|
|
||||||
|
|
||||||
async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes,
|
async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes,
|
||||||
close_after_send: bool, **kwargs):
|
close_after_send: bool, soap_socket: socket.socket = None, **kwargs):
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
finished = asyncio.Future()
|
finished: asyncio.Future = asyncio.Future()
|
||||||
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
|
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
|
||||||
transport, protocol = await loop.create_connection(
|
transport, protocol = await loop.create_connection(
|
||||||
lambda : SCPDHTTPClientProtocol(
|
lambda : SCPDHTTPClientProtocol(
|
||||||
'POST', packet, finished, soap_method=method, soap_service_id=service_id.decode(),
|
'POST', packet, finished, soap_method=method, soap_service_id=service_id.decode(),
|
||||||
close_after_send=close_after_send
|
close_after_send=close_after_send
|
||||||
), address, port
|
), address, port, sock=soap_socket
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
return await asyncio.wait_for(finished, 1.0)
|
return await asyncio.wait_for(finished, 1.0)
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
|
import socket
|
||||||
|
import typing
|
||||||
from aioupnp.protocols.scpd import scpd_post
|
from aioupnp.protocols.scpd import scpd_post
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -6,7 +8,8 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
class SCPDCommand:
|
class SCPDCommand:
|
||||||
def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str,
|
def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str,
|
||||||
param_types: dict, return_types: dict, param_order: list, return_order: list):
|
param_types: dict, return_types: dict, param_order: list, return_order: list,
|
||||||
|
soap_socket: socket.socket = None) -> None:
|
||||||
self.gateway_address = gateway_address
|
self.gateway_address = gateway_address
|
||||||
self.service_port = service_port
|
self.service_port = service_port
|
||||||
self.control_url = control_url
|
self.control_url = control_url
|
||||||
|
@ -16,17 +19,19 @@ class SCPDCommand:
|
||||||
self.param_order = param_order
|
self.param_order = param_order
|
||||||
self.return_types = return_types
|
self.return_types = return_types
|
||||||
self.return_order = return_order
|
self.return_order = return_order
|
||||||
|
self.soap_socket = soap_socket
|
||||||
|
|
||||||
async def __call__(self, **kwargs):
|
async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]:
|
||||||
if set(kwargs.keys()) != set(self.param_types.keys()):
|
if set(kwargs.keys()) != set(self.param_types.keys()):
|
||||||
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys()))
|
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys()))
|
||||||
close_after_send = not self.return_types or self.return_types == [None]
|
close_after_send = not self.return_types or self.return_types == [None]
|
||||||
response = await scpd_post(
|
response = await scpd_post(
|
||||||
self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, self.service_id,
|
self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, self.service_id,
|
||||||
close_after_send, **{n: self.param_types[n](kwargs[n]) for n in self.param_types.keys()}
|
close_after_send, self.soap_socket, **{n: self.param_types[n](kwargs[n]) for n in self.param_types.keys()}
|
||||||
)
|
)
|
||||||
extracted_response = tuple([None if self.return_types[n] is None else self.return_types[n](response[n])
|
result = tuple([self.return_types[n](response.get(n)) for n in self.return_order])
|
||||||
for n in self.return_order]) or (None, )
|
if not result:
|
||||||
if len(extracted_response) == 1:
|
return None
|
||||||
return extracted_response[0]
|
if len(result) == 1:
|
||||||
return extracted_response
|
return result[0]
|
||||||
|
return result
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
import re
|
import re
|
||||||
|
import socket
|
||||||
import binascii
|
import binascii
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import DefaultDict
|
from typing import Dict, List, Tuple
|
||||||
from asyncio.coroutines import coroutine
|
|
||||||
from asyncio.futures import Future
|
from asyncio.futures import Future
|
||||||
from asyncio.transports import DatagramTransport
|
from asyncio.transports import DatagramTransport
|
||||||
from aioupnp.fault import UPnPError
|
from aioupnp.fault import UPnPError
|
||||||
|
@ -17,17 +17,12 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SSDPProtocol(MulticastProtocol):
|
class SSDPProtocol(MulticastProtocol):
|
||||||
def __init__(self, lan_address):
|
def __init__(self, multicast_address: str, lan_address: str) -> None:
|
||||||
super().__init__()
|
super().__init__(multicast_address, lan_address)
|
||||||
self.lan_address = lan_address
|
self.lan_address = lan_address
|
||||||
self.discover_callbacks: DefaultDict[coroutine] = {}
|
self.discover_callbacks: Dict = {}
|
||||||
self.transport: DatagramTransport
|
self.notifications: List = []
|
||||||
self.notifications = []
|
self.replies: List = []
|
||||||
self.replies = []
|
|
||||||
|
|
||||||
def connection_made(self, transport: DatagramTransport):
|
|
||||||
super().connection_made(transport)
|
|
||||||
self.set_ttl(1)
|
|
||||||
|
|
||||||
async def m_search(self, address, timeout: int = 1, service=UPNP_ORG_IGD) -> SSDPDatagram:
|
async def m_search(self, address, timeout: int = 1, service=UPNP_ORG_IGD) -> SSDPDatagram:
|
||||||
if (address, service) in self.discover_callbacks:
|
if (address, service) in self.discover_callbacks:
|
||||||
|
@ -37,7 +32,7 @@ class SSDPProtocol(MulticastProtocol):
|
||||||
mx=1
|
mx=1
|
||||||
)
|
)
|
||||||
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT))
|
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT))
|
||||||
f = Future()
|
f: Future = Future()
|
||||||
self.discover_callbacks[(address, service)] = f
|
self.discover_callbacks[(address, service)] = f
|
||||||
return await asyncio.wait_for(f, timeout)
|
return await asyncio.wait_for(f, timeout)
|
||||||
|
|
||||||
|
@ -57,8 +52,9 @@ class SSDPProtocol(MulticastProtocol):
|
||||||
if (addr[0], packet.st) in self.discover_callbacks:
|
if (addr[0], packet.st) in self.discover_callbacks:
|
||||||
if packet.st not in map(lambda p: p['st'], self.replies):
|
if packet.st not in map(lambda p: p['st'], self.replies):
|
||||||
self.replies.append(packet)
|
self.replies.append(packet)
|
||||||
f: Future = self.discover_callbacks.pop((addr[0], packet.st))
|
ok_fut: Future = self.discover_callbacks.pop((addr[0], packet.st))
|
||||||
f.set_result(packet)
|
ok_fut.set_result(packet)
|
||||||
|
return
|
||||||
|
|
||||||
elif packet._packet_type == packet._NOTIFY:
|
elif packet._packet_type == packet._NOTIFY:
|
||||||
if packet.nt == SSDP_ROOT_DEVICE:
|
if packet.nt == SSDP_ROOT_DEVICE:
|
||||||
|
@ -70,21 +66,27 @@ class SSDPProtocol(MulticastProtocol):
|
||||||
break
|
break
|
||||||
if key:
|
if key:
|
||||||
log.debug("got a notification with the requested m-search info")
|
log.debug("got a notification with the requested m-search info")
|
||||||
f: Future = self.discover_callbacks.pop(key)
|
notify_fut: Future = self.discover_callbacks.pop(key)
|
||||||
f.set_result(SSDPDatagram(
|
notify_fut.set_result(SSDPDatagram(
|
||||||
SSDPDatagram._OK, cache_control='', location=packet.location, server=packet.server,
|
SSDPDatagram._OK, cache_control='', location=packet.location, server=packet.server,
|
||||||
st=UPNP_ORG_IGD, usn=packet.usn
|
st=UPNP_ORG_IGD, usn=packet.usn
|
||||||
))
|
))
|
||||||
self.notifications.append(packet.as_dict())
|
self.notifications.append(packet.as_dict())
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
async def listen_ssdp(lan_address: str, gateway_address: str) -> (DatagramTransport, SSDPProtocol, str, str):
|
async def listen_ssdp(lan_address: str, gateway_address: str,
|
||||||
|
ssdp_socket: socket.socket = None) -> Tuple[DatagramTransport, SSDPProtocol,
|
||||||
|
str, str]:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
try:
|
try:
|
||||||
sock = SSDPProtocol.create_socket(lan_address, SSDP_IP_ADDRESS)
|
sock = ssdp_socket or SSDPProtocol.create_multicast_socket(lan_address)
|
||||||
transport, protocol = await loop.create_datagram_endpoint(
|
listen_result: Tuple = await loop.create_datagram_endpoint(
|
||||||
lambda: SSDPProtocol(lan_address), sock=sock
|
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address), sock=sock
|
||||||
)
|
)
|
||||||
|
transport: DatagramTransport = listen_result[0]
|
||||||
|
protocol: SSDPProtocol = listen_result[1]
|
||||||
|
protocol.set_ttl(1)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception("failed to create multicast socket %s:%i", lan_address, SSDP_PORT)
|
log.exception("failed to create multicast socket %s:%i", lan_address, SSDP_PORT)
|
||||||
raise
|
raise
|
||||||
|
@ -92,12 +94,12 @@ async def listen_ssdp(lan_address: str, gateway_address: str) -> (DatagramTransp
|
||||||
|
|
||||||
|
|
||||||
async def m_search(lan_address: str, gateway_address: str, timeout: int = 1,
|
async def m_search(lan_address: str, gateway_address: str, timeout: int = 1,
|
||||||
service: str = UPNP_ORG_IGD) -> SSDPDatagram:
|
service: str = UPNP_ORG_IGD, ssdp_socket: socket.socket = None) -> SSDPDatagram:
|
||||||
transport, protocol, gateway_address, lan_address = await listen_ssdp(
|
transport, protocol, gateway_address, lan_address = await listen_ssdp(
|
||||||
lan_address, gateway_address
|
lan_address, gateway_address, ssdp_socket
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
return await protocol.m_search(gateway_address, timeout=timeout, service=service)
|
return await protocol.m_search(address=gateway_address, timeout=timeout, service=service)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
|
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
|
||||||
finally:
|
finally:
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import re
|
import re
|
||||||
|
from typing import Dict
|
||||||
from xml.etree import ElementTree
|
from xml.etree import ElementTree
|
||||||
from aioupnp.constants import XML_VERSION, DEVICE, ROOT
|
from aioupnp.constants import XML_VERSION, DEVICE, ROOT
|
||||||
from aioupnp.util import etree_to_dict, flatten_keys
|
from aioupnp.util import etree_to_dict, flatten_keys
|
||||||
|
@ -33,7 +34,7 @@ def serialize_scpd_get(path: str, address: str) -> bytes:
|
||||||
).encode()
|
).encode()
|
||||||
|
|
||||||
|
|
||||||
def deserialize_scpd_get_response(content: bytes) -> dict:
|
def deserialize_scpd_get_response(content: bytes) -> Dict:
|
||||||
if XML_VERSION.encode() in content:
|
if XML_VERSION.encode() in content:
|
||||||
parsed = CONTENT_PATTERN.findall(content)
|
parsed = CONTENT_PATTERN.findall(content)
|
||||||
content = b'' if not parsed else parsed[0][0]
|
content = b'' if not parsed else parsed[0][0]
|
||||||
|
@ -47,3 +48,4 @@ def deserialize_scpd_get_response(content: bytes) -> dict:
|
||||||
root = m[2][5]
|
root = m[2][5]
|
||||||
break
|
break
|
||||||
return flatten_keys(xml_dict, "{%s}" % schema_key)[root]
|
return flatten_keys(xml_dict, "{%s}" % schema_key)[root]
|
||||||
|
return {}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
import binascii
|
import binascii
|
||||||
|
from typing import Dict, List
|
||||||
from aioupnp.fault import UPnPError
|
from aioupnp.fault import UPnPError
|
||||||
from aioupnp.constants import line_separator
|
from aioupnp.constants import line_separator
|
||||||
|
|
||||||
|
@ -69,13 +70,8 @@ class SSDPDatagram(object):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
_marshallers = {
|
|
||||||
'mx': str,
|
|
||||||
'man': lambda x: ("\"%s\"" % x)
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None,
|
def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None,
|
||||||
cache_control=None, server=None, date=None, ext=None, **kwargs):
|
cache_control=None, server=None, date=None, ext=None, **kwargs) -> None:
|
||||||
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
|
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
|
||||||
raise UPnPError("unknown packet type: {}".format(packet_type))
|
raise UPnPError("unknown packet type: {}".format(packet_type))
|
||||||
self._packet_type = packet_type
|
self._packet_type = packet_type
|
||||||
|
@ -95,8 +91,9 @@ class SSDPDatagram(object):
|
||||||
if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None:
|
if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None:
|
||||||
setattr(self, k.lower(), v)
|
setattr(self, k.lower(), v)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return ("SSDPDatagram(packet_type=%s, " % self._packet_type) + ", ".join("%s=%s" % (n, v) for n, v in self.as_dict().items()) + ")"
|
return ("SSDPDatagram(packet_type=%s, " % self._packet_type) + \
|
||||||
|
", ".join("%s=%s" % (n, v) for n, v in self.as_dict().items()) + ")"
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
for i in self._required_fields[self._packet_type]:
|
for i in self._required_fields[self._packet_type]:
|
||||||
|
@ -104,17 +101,19 @@ class SSDPDatagram(object):
|
||||||
return getattr(self, i)
|
return getattr(self, i)
|
||||||
raise KeyError(item)
|
raise KeyError(item)
|
||||||
|
|
||||||
def get_friendly_name(self):
|
def get_friendly_name(self) -> str:
|
||||||
return self._friendly_names[self._packet_type]
|
return self._friendly_names[self._packet_type]
|
||||||
|
|
||||||
def encode(self, trailing_newlines=2):
|
def encode(self, trailing_newlines: int = 2) -> str:
|
||||||
lines = [self._start_lines[self._packet_type]]
|
lines = [self._start_lines[self._packet_type]]
|
||||||
for attr_name in self._required_fields[self._packet_type]:
|
for attr_name in self._required_fields[self._packet_type]:
|
||||||
attr = getattr(self, attr_name)
|
attr = getattr(self, attr_name)
|
||||||
if attr is None:
|
if attr is None:
|
||||||
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
|
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
|
||||||
if attr_name in self._marshallers:
|
if attr_name == 'mx':
|
||||||
value = self._marshallers[attr_name](attr)
|
value = str(attr)
|
||||||
|
elif attr_name == 'man':
|
||||||
|
value = "\"%s\"" % attr
|
||||||
else:
|
else:
|
||||||
value = attr
|
value = attr
|
||||||
lines.append("{}: {}".format(attr_name.upper(), value))
|
lines.append("{}: {}".format(attr_name.upper(), value))
|
||||||
|
@ -123,7 +122,7 @@ class SSDPDatagram(object):
|
||||||
serialized += line_separator
|
serialized += line_separator
|
||||||
return serialized
|
return serialized
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self) -> Dict:
|
||||||
return self._lines_to_content_dict(self.encode().split(line_separator))
|
return self._lines_to_content_dict(self.encode().split(line_separator))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -142,8 +141,8 @@ class SSDPDatagram(object):
|
||||||
return packet
|
return packet
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _lines_to_content_dict(cls, lines: list) -> dict:
|
def _lines_to_content_dict(cls, lines: list) -> Dict:
|
||||||
result = {}
|
result: dict = {}
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if not line:
|
if not line:
|
||||||
continue
|
continue
|
||||||
|
@ -175,13 +174,13 @@ class SSDPDatagram(object):
|
||||||
return cls._from_response(lines[1:])
|
return cls._from_response(lines[1:])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_response(cls, lines):
|
def _from_response(cls, lines: List):
|
||||||
return cls(cls._OK, **cls._lines_to_content_dict(lines))
|
return cls(cls._OK, **cls._lines_to_content_dict(lines))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_notify(cls, lines):
|
def _from_notify(cls, lines: List):
|
||||||
return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines))
|
return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_request(cls, lines):
|
def _from_request(cls, lines: List):
|
||||||
return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines))
|
return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines))
|
||||||
|
|
|
@ -3,7 +3,9 @@ from aioupnp.serialization.soap import serialize_soap_post
|
||||||
|
|
||||||
|
|
||||||
class TestSOAPSerialization(unittest.TestCase):
|
class TestSOAPSerialization(unittest.TestCase):
|
||||||
method, param_names, gateway_address, kwargs = "GetExternalIPAddress", [], b'10.0.0.1', {}
|
param_names: list = []
|
||||||
|
kwargs: dict = {}
|
||||||
|
method, gateway_address = "GetExternalIPAddress", b'10.0.0.1'
|
||||||
st, lan_address, path = b'urn:schemas-upnp-org:service:WANIPConnection:1', '10.0.0.1', b'/soap.cgi?service=WANIPConn1'
|
st, lan_address, path = b'urn:schemas-upnp-org:service:WANIPConnection:1', '10.0.0.1', b'/soap.cgi?service=WANIPConn1'
|
||||||
expected_result = b'POST /soap.cgi?service=WANIPConn1 HTTP/1.1\r\n' \
|
expected_result = b'POST /soap.cgi?service=WANIPConn1 HTTP/1.1\r\n' \
|
||||||
b'Host: 10.0.0.1\r\nUser-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n' \
|
b'Host: 10.0.0.1\r\nUser-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n' \
|
||||||
|
@ -15,7 +17,6 @@ class TestSOAPSerialization(unittest.TestCase):
|
||||||
b'<s:Body><u:GetExternalIPAddress xmlns:u="urn:schemas-upnp-org:service:WANIPConnection:1">' \
|
b'<s:Body><u:GetExternalIPAddress xmlns:u="urn:schemas-upnp-org:service:WANIPConnection:1">' \
|
||||||
b'</u:GetExternalIPAddress></s:Body></s:Envelope>\r\n'
|
b'</u:GetExternalIPAddress></s:Body></s:Envelope>\r\n'
|
||||||
|
|
||||||
|
|
||||||
def test_serialize_get(self):
|
def test_serialize_get(self):
|
||||||
self.assertEqual(serialize_soap_post(
|
self.assertEqual(serialize_soap_post(
|
||||||
self.method, self.param_names, self.st, self.gateway_address, self.path, **self.kwargs
|
self.method, self.param_names, self.st, self.gateway_address, self.path, **self.kwargs
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
from aioupnp.serialization.ssdp import SSDPDatagram
|
from aioupnp.serialization.ssdp import SSDPDatagram
|
||||||
from aioupnp.fault import UPnPError
|
from aioupnp.fault import UPnPError
|
||||||
|
from aioupnp.constants import UPNP_ORG_IGD
|
||||||
|
|
||||||
|
|
||||||
class TestParseMSearchRequest(unittest.TestCase):
|
class TestParseMSearchRequest(unittest.TestCase):
|
||||||
|
@ -11,7 +12,7 @@ class TestParseMSearchRequest(unittest.TestCase):
|
||||||
b'MX: 1\r\n' \
|
b'MX: 1\r\n' \
|
||||||
b'\r\n'
|
b'\r\n'
|
||||||
|
|
||||||
def test_parse_m_search_response(self):
|
def test_parse_m_search(self):
|
||||||
packet = SSDPDatagram.decode(self.datagram)
|
packet = SSDPDatagram.decode(self.datagram)
|
||||||
self.assertTrue(packet._packet_type, packet._M_SEARCH)
|
self.assertTrue(packet._packet_type, packet._M_SEARCH)
|
||||||
self.assertEqual(packet.host, '239.255.255.250:1900')
|
self.assertEqual(packet.host, '239.255.255.250:1900')
|
||||||
|
@ -46,6 +47,30 @@ class TestParseMSearchResponse(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseMSearchResponseRedSonic(TestParseMSearchResponse):
|
||||||
|
datagram = \
|
||||||
|
b"HTTP/1.1 200 OK\r\n" \
|
||||||
|
b"CACHE-CONTROL: max-age=1800\r\n" \
|
||||||
|
b"DATE: Thu, 04 Oct 2018 22:59:40 GMT\r\n" \
|
||||||
|
b"EXT:\r\n" \
|
||||||
|
b"LOCATION: http://10.1.10.1:49152/IGDdevicedesc_brlan0.xml\r\n" \
|
||||||
|
b"OPT: \"http://schemas.upnp.org/upnp/1/0/\"; ns=01\r\n" \
|
||||||
|
b"01-NLS: 00000000-0000-0000-0000-000000000000\r\n" \
|
||||||
|
b"SERVER: Linux/3.14.28-Prod_17.2, UPnP/1.0, Portable SDK for UPnP devices/1.6.22\r\n" \
|
||||||
|
b"X-User-Agent: redsonic\r\n" \
|
||||||
|
b"ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" \
|
||||||
|
b"USN: uuid:00000000-0000-0000-0000-000000000000::urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" \
|
||||||
|
b"\r\n"
|
||||||
|
|
||||||
|
def test_parse_m_search_response(self):
|
||||||
|
packet = SSDPDatagram.decode(self.datagram)
|
||||||
|
self.assertTrue(packet._packet_type, packet._OK)
|
||||||
|
self.assertEqual(packet.cache_control, 'max-age=1800')
|
||||||
|
self.assertEqual(packet.location, 'http://10.1.10.1:49152/IGDdevicedesc_brlan0.xml')
|
||||||
|
self.assertEqual(packet.server, 'Linux/3.14.28-Prod_17.2, UPnP/1.0, Portable SDK for UPnP devices/1.6.22')
|
||||||
|
self.assertEqual(packet.st, UPNP_ORG_IGD)
|
||||||
|
|
||||||
|
|
||||||
class TestParseMSearchResponseDashCacheControl(TestParseMSearchResponse):
|
class TestParseMSearchResponseDashCacheControl(TestParseMSearchResponse):
|
||||||
datagram = "\r\n".join([
|
datagram = "\r\n".join([
|
||||||
'HTTP/1.1 200 OK',
|
'HTTP/1.1 200 OK',
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import os
|
import os
|
||||||
|
import socket
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
from typing import Tuple, Dict, List, Union
|
||||||
from aioupnp.fault import UPnPError
|
from aioupnp.fault import UPnPError
|
||||||
from aioupnp.gateway import Gateway
|
from aioupnp.gateway import Gateway
|
||||||
from aioupnp.constants import UPNP_ORG_IGD
|
from aioupnp.constants import UPNP_ORG_IGD
|
||||||
|
@ -12,19 +13,9 @@ from aioupnp.protocols.ssdp import m_search
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def cli(format_result=None):
|
def cli(fn):
|
||||||
def _cli(fn):
|
fn._cli = True
|
||||||
@functools.wraps(fn)
|
return fn
|
||||||
async def _inner(*args, **kwargs):
|
|
||||||
result = await fn(*args, **kwargs)
|
|
||||||
if not format_result or not result or not isinstance(result, (list, dict, tuple)):
|
|
||||||
return result
|
|
||||||
self = args[0]
|
|
||||||
return {k: v for k, v in zip(getattr(self.gateway.commands, format_result).return_order, result)}
|
|
||||||
f = _inner
|
|
||||||
f._cli = True
|
|
||||||
return f
|
|
||||||
return _cli
|
|
||||||
|
|
||||||
|
|
||||||
def _encode(x):
|
def _encode(x):
|
||||||
|
@ -36,14 +27,14 @@ def _encode(x):
|
||||||
|
|
||||||
|
|
||||||
class UPnP:
|
class UPnP:
|
||||||
def __init__(self, lan_address: str, gateway_address: str, gateway: Gateway):
|
def __init__(self, lan_address: str, gateway_address: str, gateway: Gateway) -> None:
|
||||||
self.lan_address = lan_address
|
self.lan_address = lan_address
|
||||||
self.gateway_address = gateway_address
|
self.gateway_address = gateway_address
|
||||||
self.gateway = gateway
|
self.gateway = gateway
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
|
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
|
||||||
interface_name: str = 'default') -> (str, str):
|
interface_name: str = 'default') -> Tuple[str, str]:
|
||||||
if not lan_address or not gateway_address:
|
if not lan_address or not gateway_address:
|
||||||
gateway_addr, lan_addr = get_gateway_and_lan_addresses(interface_name)
|
gateway_addr, lan_addr = get_gateway_and_lan_addresses(interface_name)
|
||||||
lan_address = lan_address or lan_addr
|
lan_address = lan_address or lan_addr
|
||||||
|
@ -52,18 +43,19 @@ class UPnP:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
|
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
|
||||||
service: str = UPNP_ORG_IGD, interface_name: str = 'default'):
|
service: str = UPNP_ORG_IGD, interface_name: str = 'default',
|
||||||
|
ssdp_socket: socket.socket = None, soap_socket: socket.socket = None):
|
||||||
try:
|
try:
|
||||||
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
|
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise UPnPError("failed to get lan and gateway addresses: %s" % str(err))
|
raise UPnPError("failed to get lan and gateway addresses: %s" % str(err))
|
||||||
gateway = await Gateway.discover_gateway(lan_address, gateway_address, timeout, service)
|
gateway = await Gateway.discover_gateway(lan_address, gateway_address, timeout, service, ssdp_socket, soap_socket)
|
||||||
return cls(lan_address, gateway_address, gateway)
|
return cls(lan_address, gateway_address, gateway)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@cli()
|
@cli
|
||||||
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
|
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
|
||||||
service: str = UPNP_ORG_IGD, interface_name: str = 'default') -> dict:
|
service: str = UPNP_ORG_IGD, interface_name: str = 'default') -> Dict:
|
||||||
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
|
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
|
||||||
datagram = await m_search(lan_address, gateway_address, timeout, service)
|
datagram = await m_search(lan_address, gateway_address, timeout, service)
|
||||||
return {
|
return {
|
||||||
|
@ -72,32 +64,39 @@ class UPnP:
|
||||||
'discover_reply': datagram.as_dict()
|
'discover_reply': datagram.as_dict()
|
||||||
}
|
}
|
||||||
|
|
||||||
@cli()
|
@cli
|
||||||
async def get_external_ip(self) -> str:
|
async def get_external_ip(self) -> str:
|
||||||
return await self.gateway.commands.GetExternalIPAddress()
|
return await self.gateway.commands.GetExternalIPAddress()
|
||||||
|
|
||||||
@cli("AddPortMapping")
|
@cli
|
||||||
async def add_port_mapping(self, external_port: int, protocol: str, internal_port, lan_address: str,
|
async def add_port_mapping(self, external_port: int, protocol: str, internal_port, lan_address: str,
|
||||||
description: str) -> None:
|
description: str) -> None:
|
||||||
return await self.gateway.commands.AddPortMapping(
|
await self.gateway.commands.AddPortMapping(
|
||||||
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol,
|
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol,
|
||||||
NewInternalPort=internal_port, NewInternalClient=lan_address,
|
NewInternalPort=internal_port, NewInternalClient=lan_address,
|
||||||
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration=""
|
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration=""
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
@cli("GetGenericPortMappingEntry")
|
@cli
|
||||||
async def get_port_mapping_by_index(self, index: int) -> dict:
|
async def get_port_mapping_by_index(self, index: int) -> Dict:
|
||||||
return await self._get_port_mapping_by_index(index)
|
result = await self._get_port_mapping_by_index(index)
|
||||||
|
if result:
|
||||||
|
return {
|
||||||
|
k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result)
|
||||||
|
}
|
||||||
|
return {}
|
||||||
|
|
||||||
async def _get_port_mapping_by_index(self, index: int) -> (str, int, str, int, str, bool, str, int):
|
async def _get_port_mapping_by_index(self, index: int) -> Union[Tuple[str, int, str, int, str, bool, str, int],
|
||||||
|
None]:
|
||||||
try:
|
try:
|
||||||
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
|
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
|
||||||
return redirect
|
return redirect
|
||||||
except UPnPError:
|
except UPnPError:
|
||||||
return
|
return None
|
||||||
|
|
||||||
@cli()
|
@cli
|
||||||
async def get_redirects(self) -> list:
|
async def get_redirects(self) -> List[Dict]:
|
||||||
redirects = []
|
redirects = []
|
||||||
cnt = 0
|
cnt = 0
|
||||||
redirect = await self.get_port_mapping_by_index(cnt)
|
redirect = await self.get_port_mapping_by_index(cnt)
|
||||||
|
@ -107,8 +106,8 @@ class UPnP:
|
||||||
redirect = await self.get_port_mapping_by_index(cnt)
|
redirect = await self.get_port_mapping_by_index(cnt)
|
||||||
return redirects
|
return redirects
|
||||||
|
|
||||||
@cli("GetSpecificPortMappingEntry")
|
@cli
|
||||||
async def get_specific_port_mapping(self, external_port: int, protocol: str):
|
async def get_specific_port_mapping(self, external_port: int, protocol: str) -> Dict:
|
||||||
"""
|
"""
|
||||||
:param external_port: (int) external port to listen on
|
:param external_port: (int) external port to listen on
|
||||||
:param protocol: (str) 'UDP' | 'TCP'
|
:param protocol: (str) 'UDP' | 'TCP'
|
||||||
|
@ -116,13 +115,14 @@ class UPnP:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await self.gateway.commands.GetSpecificPortMappingEntry(
|
result = await self.gateway.commands.GetSpecificPortMappingEntry(
|
||||||
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol
|
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol
|
||||||
)
|
)
|
||||||
|
return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)}
|
||||||
except UPnPError:
|
except UPnPError:
|
||||||
return
|
return {}
|
||||||
|
|
||||||
@cli()
|
@cli
|
||||||
async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
|
async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
|
||||||
"""
|
"""
|
||||||
:param external_port: (int) external port to listen on
|
:param external_port: (int) external port to listen on
|
||||||
|
@ -133,7 +133,7 @@ class UPnP:
|
||||||
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
|
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
|
||||||
)
|
)
|
||||||
|
|
||||||
@cli("AddPortMapping")
|
@cli
|
||||||
async def get_next_mapping(self, port: int, protocol: str, description: str, internal_port: int=None) -> int:
|
async def get_next_mapping(self, port: int, protocol: str, description: str, internal_port: int=None) -> int:
|
||||||
if protocol not in ["UDP", "TCP"]:
|
if protocol not in ["UDP", "TCP"]:
|
||||||
raise UPnPError("unsupported protocol: {}".format(protocol))
|
raise UPnPError("unsupported protocol: {}".format(protocol))
|
||||||
|
@ -163,14 +163,14 @@ class UPnP:
|
||||||
)
|
)
|
||||||
return port
|
return port
|
||||||
|
|
||||||
@cli()
|
@cli
|
||||||
async def get_soap_commands(self) -> dict:
|
async def get_soap_commands(self) -> Dict:
|
||||||
return {
|
return {
|
||||||
'supported': list(self.gateway._registered_commands.keys()),
|
'supported': list(self.gateway._registered_commands.keys()),
|
||||||
'unsupported': self.gateway._unsupported_actions
|
'unsupported': self.gateway._unsupported_actions
|
||||||
}
|
}
|
||||||
|
|
||||||
@cli()
|
@cli
|
||||||
async def generate_test_data(self):
|
async def generate_test_data(self):
|
||||||
external_ip = await self.get_external_ip()
|
external_ip = await self.get_external_ip()
|
||||||
redirects = await self.get_redirects()
|
redirects = await self.get_redirects()
|
||||||
|
@ -216,13 +216,13 @@ class UPnP:
|
||||||
@classmethod
|
@classmethod
|
||||||
def run_cli(cls, method, lan_address: str = '', gateway_address: str = '', timeout: int = 60,
|
def run_cli(cls, method, lan_address: str = '', gateway_address: str = '', timeout: int = 60,
|
||||||
service: str = UPNP_ORG_IGD, interface_name: str = 'default',
|
service: str = UPNP_ORG_IGD, interface_name: str = 'default',
|
||||||
kwargs: dict = None):
|
kwargs: dict = None) -> None:
|
||||||
kwargs = kwargs or {}
|
kwargs = kwargs or {}
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
fut = asyncio.Future()
|
fut: asyncio.Future = asyncio.Future()
|
||||||
|
|
||||||
async def wrapper():
|
async def wrapper():
|
||||||
if method == 'm_search':
|
if method == 'm_search':
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import re
|
import re
|
||||||
import socket
|
import socket
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from typing import Tuple, Dict
|
||||||
from xml.etree import ElementTree
|
from xml.etree import ElementTree
|
||||||
import netifaces
|
import netifaces
|
||||||
|
|
||||||
|
@ -9,15 +10,17 @@ BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode
|
||||||
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
|
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
|
||||||
|
|
||||||
|
|
||||||
def etree_to_dict(t: ElementTree) -> dict:
|
def etree_to_dict(t: ElementTree.Element) -> Dict:
|
||||||
d = {t.tag: {} if t.attrib else None}
|
d: dict = {}
|
||||||
|
if t.attrib:
|
||||||
|
d[t.tag] = {}
|
||||||
children = list(t)
|
children = list(t)
|
||||||
if children:
|
if children:
|
||||||
dd = defaultdict(list)
|
dd: dict = defaultdict(list)
|
||||||
for dc in map(etree_to_dict, children):
|
for dc in map(etree_to_dict, children):
|
||||||
for k, v in dc.items():
|
for k, v in dc.items():
|
||||||
dd[k].append(v)
|
dd[k].append(v)
|
||||||
d = {t.tag: {k: v[0] if len(v) == 1 else v for k, v in dd.items()}}
|
d[t.tag] = {k: v[0] if len(v) == 1 else v for k, v in dd.items()}
|
||||||
if t.attrib:
|
if t.attrib:
|
||||||
d[t.tag].update(('@' + k, v) for k, v in t.attrib.items())
|
d[t.tag].update(('@' + k, v) for k, v in t.attrib.items())
|
||||||
if t.text:
|
if t.text:
|
||||||
|
@ -81,8 +84,8 @@ def get_interfaces():
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
||||||
def get_gateway_and_lan_addresses(interface_name: str) -> (str, str):
|
def get_gateway_and_lan_addresses(interface_name: str) -> Tuple[str, str]:
|
||||||
for iface_name, (gateway, lan) in get_interfaces().items():
|
for iface_name, (gateway, lan) in get_interfaces().items():
|
||||||
if interface_name == iface_name:
|
if interface_name == iface_name:
|
||||||
return gateway, lan
|
return gateway, lan
|
||||||
return None, None
|
return '', ''
|
||||||
|
|
4
mypy.ini
Normal file
4
mypy.ini
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
[mypy]
|
||||||
|
python_version = 3.7
|
||||||
|
mypy_path=stubs
|
||||||
|
cache_dir=/dev/null
|
2
setup.py
2
setup.py
|
@ -1,5 +1,5 @@
|
||||||
import os
|
import os
|
||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages # type: ignore
|
||||||
from aioupnp import __version__, __name__, __email__, __author__, __license__
|
from aioupnp import __version__, __name__, __email__, __author__, __license__
|
||||||
|
|
||||||
console_scripts = [
|
console_scripts = [
|
||||||
|
|
111
stubs/netifaces.py
Normal file
111
stubs/netifaces.py
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
import typing
|
||||||
|
|
||||||
|
|
||||||
|
AF_APPLETALK = 5
|
||||||
|
AF_ASH = 18
|
||||||
|
AF_ATMPVC = 8
|
||||||
|
AF_ATMSVC = 20
|
||||||
|
AF_AX25 = 3
|
||||||
|
AF_BLUETOOTH = 31
|
||||||
|
AF_BRIDGE = 7
|
||||||
|
AF_DECnet = 12
|
||||||
|
AF_ECONET = 19
|
||||||
|
AF_FILE = 1
|
||||||
|
AF_INET = 2
|
||||||
|
AF_INET6 = 10
|
||||||
|
AF_IPX = 4
|
||||||
|
AF_IRDA = 23
|
||||||
|
AF_ISDN = 34
|
||||||
|
AF_KEY = 15
|
||||||
|
AF_LINK = 17
|
||||||
|
AF_NETBEUI = 13
|
||||||
|
AF_NETLINK = 16
|
||||||
|
AF_NETROM = 6
|
||||||
|
AF_PACKET = 17
|
||||||
|
AF_PPPOX = 24
|
||||||
|
AF_ROSE = 11
|
||||||
|
AF_ROUTE = 16
|
||||||
|
AF_SECURITY = 14
|
||||||
|
AF_SNA = 22
|
||||||
|
AF_UNIX = 1
|
||||||
|
AF_UNSPEC = 0
|
||||||
|
AF_WANPIPE = 25
|
||||||
|
AF_X25 = 9
|
||||||
|
|
||||||
|
version = '0.10.7'
|
||||||
|
|
||||||
|
|
||||||
|
# functions
|
||||||
|
|
||||||
|
def gateways(*args, **kwargs) -> typing.List: # real signature unknown
|
||||||
|
"""
|
||||||
|
Obtain a list of the gateways on this machine.
|
||||||
|
|
||||||
|
Returns a dict whose keys are equal to the address family constants,
|
||||||
|
e.g. netifaces.AF_INET, and whose values are a list of tuples of the
|
||||||
|
format (<address>, <interface>, <is_default>).
|
||||||
|
|
||||||
|
There is also a special entry with the key 'default', which you can use
|
||||||
|
to quickly obtain the default gateway for a particular address family.
|
||||||
|
|
||||||
|
There may in general be multiple gateways; different address
|
||||||
|
families may have different gateway settings (e.g. AF_INET vs AF_INET6)
|
||||||
|
and on some systems it's also possible to have interface-specific
|
||||||
|
default gateways.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def ifaddresses(*args, **kwargs) -> typing.Dict: # real signature unknown
|
||||||
|
"""
|
||||||
|
Obtain information about the specified network interface.
|
||||||
|
|
||||||
|
Returns a dict whose keys are equal to the address family constants,
|
||||||
|
e.g. netifaces.AF_INET, and whose values are a list of addresses in
|
||||||
|
that family that are attached to the network interface.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def interfaces(*args, **kwargs) -> typing.List: # real signature unknown
|
||||||
|
""" Obtain a list of the interfaces available on this machine. """
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# no classes
|
||||||
|
# variables with complex values
|
||||||
|
|
||||||
|
address_families = {
|
||||||
|
0: 'AF_UNSPEC',
|
||||||
|
1: 'AF_FILE',
|
||||||
|
2: 'AF_INET',
|
||||||
|
3: 'AF_AX25',
|
||||||
|
4: 'AF_IPX',
|
||||||
|
5: 'AF_APPLETALK',
|
||||||
|
6: 'AF_NETROM',
|
||||||
|
7: 'AF_BRIDGE',
|
||||||
|
8: 'AF_ATMPVC',
|
||||||
|
9: 'AF_X25',
|
||||||
|
10: 'AF_INET6',
|
||||||
|
11: 'AF_ROSE',
|
||||||
|
12: 'AF_DECnet',
|
||||||
|
13: 'AF_NETBEUI',
|
||||||
|
14: 'AF_SECURITY',
|
||||||
|
15: 'AF_KEY',
|
||||||
|
16: 'AF_NETLINK',
|
||||||
|
17: 'AF_PACKET',
|
||||||
|
18: 'AF_ASH',
|
||||||
|
19: 'AF_ECONET',
|
||||||
|
20: 'AF_ATMSVC',
|
||||||
|
22: 'AF_SNA',
|
||||||
|
23: 'AF_IRDA',
|
||||||
|
24: 'AF_PPPOX',
|
||||||
|
25: 'AF_WANPIPE',
|
||||||
|
31: 'AF_BLUETOOTH',
|
||||||
|
34: 'AF_ISDN',
|
||||||
|
}
|
||||||
|
|
||||||
|
__loader__ = None # (!) real value is ''
|
||||||
|
|
||||||
|
__spec__ = None # (!) real value is ''
|
||||||
|
|
Loading…
Reference in a new issue