add mx argument

This commit is contained in:
Jack Robison 2018-10-10 15:18:11 -04:00
parent 28c01b5b31
commit b0d1f7a193
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
5 changed files with 22 additions and 21 deletions

View file

@ -47,6 +47,7 @@ def main():
'timeout': 1, 'timeout': 1,
'service': '', # if not provided try all of them 'service': '', # if not provided try all of them
'man': '', 'man': '',
'mx': 1,
'return_as_json': True 'return_as_json': True
} }
@ -80,8 +81,8 @@ def main():
UPnP.run_cli( UPnP.run_cli(
command.replace('-', '_'), options.pop('lan_address'), options.pop('gateway_address'), command.replace('-', '_'), options.pop('lan_address'), options.pop('gateway_address'),
options.pop('timeout'), options.pop('service'), options.pop('man'), options.pop('interface'), options.pop('timeout'), options.pop('service'), options.pop('man'), options.pop('mx'),
kwargs options.pop('interface'), kwargs
) )

View file

@ -3,7 +3,7 @@ from typing import Tuple, Union
none_or_str = Union[None, str] none_or_str = Union[None, str]
class SCPDCommands: class SOAPCommands:
def debug_commands(self) -> dict: def debug_commands(self) -> dict:
raise NotImplementedError() raise NotImplementedError()

View file

@ -3,7 +3,7 @@ import socket
from typing import Dict, List, Union, Type from typing import Dict, List, Union, Type
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from aioupnp.constants import SPEC_VERSION, UPNP_ORG_IGD, SERVICE from aioupnp.constants import SPEC_VERSION, UPNP_ORG_IGD, SERVICE
from aioupnp.commands import SCPDCommands from aioupnp.commands import SOAPCommands
from aioupnp.device import Device, Service from aioupnp.device import Device, Service
from aioupnp.protocols.ssdp import m_search from aioupnp.protocols.ssdp import m_search
from aioupnp.protocols.scpd import scpd_get from aioupnp.protocols.scpd import scpd_get
@ -90,7 +90,7 @@ class Gateway:
self._unsupported_actions = {} self._unsupported_actions = {}
self._registered_commands = {} self._registered_commands = {}
self.commands = SCPDCommands() self.commands = SOAPCommands()
def gateway_descriptor(self) -> dict: def gateway_descriptor(self) -> dict:
r = { r = {
@ -129,9 +129,9 @@ class Gateway:
@classmethod @classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 1, async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 1,
service: str = UPNP_ORG_IGD, man: str = '', service: str = UPNP_ORG_IGD, man: str = '', mx: int = 1,
ssdp_socket: socket.socket = None, soap_socket: socket.socket = None): ssdp_socket: socket.socket = None, soap_socket: socket.socket = None):
datagram = await m_search(lan_address, gateway_address, timeout, service, man, ssdp_socket) datagram = await m_search(lan_address, gateway_address, timeout, service, man, mx, ssdp_socket)
gateway = cls(**datagram.as_dict()) gateway = cls(**datagram.as_dict())
await gateway.discover_commands(soap_socket) await gateway.discover_commands(soap_socket)
return gateway return gateway

View file

@ -25,15 +25,15 @@ class SSDPProtocol(MulticastProtocol):
self.notifications: List = [] self.notifications: List = []
self.replies: List = [] self.replies: List = []
def send_m_search_packet(self, service, address, man): def send_m_search_packet(self, service, address, man, mx):
packet = SSDPDatagram( packet = SSDPDatagram(
SSDPDatagram._M_SEARCH, host="{}:{}".format(SSDP_IP_ADDRESS, SSDP_PORT), st=service, SSDPDatagram._M_SEARCH, host="{}:{}".format(SSDP_IP_ADDRESS, SSDP_PORT), st=service,
man=man, mx=1 man=man, mx=mx
) )
log.debug("sending packet to %s:%i: %s", address, SSDP_PORT, packet) log.debug("sending packet to %s:%i: %s", address, SSDP_PORT, packet)
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT)) self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT))
async def m_search(self, address, timeout: int = 1, service='', man='') -> SSDPDatagram: async def m_search(self, address: str, timeout: int = 1, service: str = '', man: str = '', mx: int = 1) -> SSDPDatagram:
if (address, service) in self.discover_callbacks: if (address, service) in self.discover_callbacks:
return self.discover_callbacks[(address, service)] return self.discover_callbacks[(address, service)]
man = man or SSDP_DISCOVER man = man or SSDP_DISCOVER
@ -49,10 +49,10 @@ class SSDPProtocol(MulticastProtocol):
# D-Link works with both # D-Link works with both
# Cisco only works with quotes # Cisco only works with quotes
self.send_m_search_packet(service, address, '\"%s\"' % man) self.send_m_search_packet(service, address, '\"%s\"' % man, mx)
# DD-WRT only works without quotes # DD-WRT only works without quotes
self.send_m_search_packet(service, address, man) self.send_m_search_packet(service, address, man, mx)
f: Future = Future() f: Future = Future()
f.add_done_callback(lambda _f: outer_fut.set_result(_f.result())) f.add_done_callback(lambda _f: outer_fut.set_result(_f.result()))
@ -125,12 +125,12 @@ async def listen_ssdp(lan_address: str, gateway_address: str,
async def m_search(lan_address: str, gateway_address: str, timeout: int = 1, async def m_search(lan_address: str, gateway_address: str, timeout: int = 1,
service: str = '', man: str = '', ssdp_socket: socket.socket = None) -> SSDPDatagram: service: str = '', man: str = '', mx: int = 1, ssdp_socket: socket.socket = None) -> SSDPDatagram:
transport, protocol, gateway_address, lan_address = await listen_ssdp( transport, protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, ssdp_socket lan_address, gateway_address, ssdp_socket
) )
try: try:
return await protocol.m_search(address=gateway_address, timeout=timeout, service=service, man=man) return await protocol.m_search(address=gateway_address, timeout=timeout, service=service, man=man, mx=mx)
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
finally: finally:

View file

@ -42,23 +42,23 @@ class UPnP:
@classmethod @classmethod
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
service: str = '', man: str = '', interface_name: str = 'default', service: str = '', man: str = '', mx: int = 1, interface_name: str = 'default',
ssdp_socket: socket.socket = None, soap_socket: socket.socket = None): ssdp_socket: socket.socket = None, soap_socket: socket.socket = None):
try: try:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
except Exception as err: except Exception as err:
raise UPnPError("failed to get lan and gateway addresses: %s" % str(err)) raise UPnPError("failed to get lan and gateway addresses: %s" % str(err))
gateway = await Gateway.discover_gateway( gateway = await Gateway.discover_gateway(
lan_address, gateway_address, timeout, service, man, ssdp_socket, soap_socket lan_address, gateway_address, timeout, service, man, mx, ssdp_socket, soap_socket
) )
return cls(lan_address, gateway_address, gateway) return cls(lan_address, gateway_address, gateway)
@classmethod @classmethod
@cli @cli
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
service: str = '', man: str = '', interface_name: str = 'default') -> Dict: service: str = '', man: str = '', mx: int = 1, interface_name: str = 'default') -> Dict:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
datagram = await m_search(lan_address, gateway_address, timeout, service, man) datagram = await m_search(lan_address, gateway_address, timeout, service, man, mx)
return { return {
'lan_address': lan_address, 'lan_address': lan_address,
'gateway_address': gateway_address, 'gateway_address': gateway_address,
@ -216,7 +216,7 @@ class UPnP:
@classmethod @classmethod
def run_cli(cls, method, lan_address: str = '', gateway_address: str = '', timeout: int = 60, def run_cli(cls, method, lan_address: str = '', gateway_address: str = '', timeout: int = 60,
service: str = '', man: str = '', interface_name: str = 'default', service: str = '', man: str = '', mx: int = 1, interface_name: str = 'default',
kwargs: dict = None) -> None: kwargs: dict = None) -> None:
kwargs = kwargs or {} kwargs = kwargs or {}
timeout = int(timeout) timeout = int(timeout)
@ -231,12 +231,12 @@ class UPnP:
async def wrapper(): async def wrapper():
if method == 'm_search': if method == 'm_search':
fn = lambda *_a, **_kw: cls.m_search( fn = lambda *_a, **_kw: cls.m_search(
lan_address, gateway_address, timeout, service, man, interface_name lan_address, gateway_address, timeout, service, man, mx, interface_name
) )
else: else:
try: try:
u = await cls.discover( u = await cls.discover(
lan_address, gateway_address, timeout, service, man, interface_name lan_address, gateway_address, timeout, service, man, mx, interface_name
) )
except UPnPError as err: except UPnPError as err:
fut.set_exception(err) fut.set_exception(err)