fuzzy m search

This commit is contained in:
Jack Robison 2018-10-10 19:39:45 -04:00
parent ba8be4746a
commit 0956d6a71e
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
5 changed files with 177 additions and 78 deletions

View file

@ -1,6 +1,8 @@
import logging import logging
import sys import sys
from collections import OrderedDict
from aioupnp.upnp import UPnP from aioupnp.upnp import UPnP
from aioupnp.constants import UPNP_ORG_IGD, SSDP_DISCOVER, SSDP_HOST
log = logging.getLogger("aioupnp") log = logging.getLogger("aioupnp")
handler = logging.StreamHandler() handler = logging.StreamHandler()
@ -45,13 +47,14 @@ def main():
'gateway_address': '', 'gateway_address': '',
'lan_address': '', 'lan_address': '',
'timeout': 1, 'timeout': 1,
'service': '', # if not provided try all of them
'man': '', 'HOST': SSDP_HOST,
'mx': 1, 'ST': UPNP_ORG_IGD,
'return_as_json': True 'MAN': SSDP_DISCOVER,
'MX': 1,
} }
options = {} options = OrderedDict()
command = None command = None
for arg in args: for arg in args:
if arg.startswith("--"): if arg.startswith("--"):
@ -80,9 +83,8 @@ def main():
log.setLevel(logging.DEBUG) log.setLevel(logging.DEBUG)
UPnP.run_cli( UPnP.run_cli(
command.replace('-', '_'), options.pop('lan_address'), options.pop('gateway_address'), command.replace('-', '_'), options, options.pop('lan_address'), options.pop('gateway_address'),
options.pop('timeout'), options.pop('service'), options.pop('man'), options.pop('mx'), options.pop('timeout'), options.pop('interface'), kwargs
options.pop('interface'), kwargs
) )

View file

@ -5,7 +5,7 @@ from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_AD
from aioupnp.constants import SPEC_VERSION, UPNP_ORG_IGD, SERVICE from aioupnp.constants import SPEC_VERSION, UPNP_ORG_IGD, SERVICE
from aioupnp.commands import SOAPCommands from aioupnp.commands import SOAPCommands
from aioupnp.device import Device, Service from aioupnp.device import Device, Service
from aioupnp.protocols.ssdp import m_search from aioupnp.protocols.ssdp import fuzzy_m_search
from aioupnp.protocols.scpd import scpd_get from aioupnp.protocols.scpd import scpd_get
from aioupnp.protocols.soap import SCPDCommand from aioupnp.protocols.soap import SCPDCommand
from aioupnp.util import flatten_keys from aioupnp.util import flatten_keys
@ -129,9 +129,8 @@ class Gateway:
@classmethod @classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 1, async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 1,
service: str = UPNP_ORG_IGD, man: str = '', mx: int = 1,
ssdp_socket: socket.socket = None, soap_socket: socket.socket = None): ssdp_socket: socket.socket = None, soap_socket: socket.socket = None):
datagram = await m_search(lan_address, gateway_address, timeout, service, man, mx, ssdp_socket) datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket)
gateway = cls(**datagram.as_dict()) gateway = cls(**datagram.as_dict())
await gateway.discover_commands(soap_socket) await gateway.discover_commands(soap_socket)
return gateway return gateway

View file

@ -0,0 +1,81 @@
M_SEARCH_ARG_PATTERNS = [
#
[
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('ST', lambda s: s),
('MAN', lambda s: '"%s"' % s),
('MX', lambda n: int(n)),
],
[
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('ST', lambda s: s),
('Man', lambda s: '"%s"' % s),
('MX', lambda n: int(n)),
],
[
('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('ST', lambda s: s),
('Man', lambda s: '"%s"' % s),
('MX', lambda n: int(n)),
],
# swap st and man
[
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('MAN', lambda s: '"%s"' % s),
('ST', lambda s: s),
('MX', lambda n: int(n)),
],
[
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('Man', lambda s: '"%s"' % s),
('ST', lambda s: s),
('MX', lambda n: int(n)),
],
[
('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('Man', lambda s: '"%s"' % s),
('ST', lambda s: s),
('MX', lambda n: int(n)),
],
# repeat above but with no quotes on man
[
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('ST', lambda s: s),
('MAN', lambda s: s),
('MX', lambda n: int(n)),
],
[
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('ST', lambda s: s),
('Man', lambda s: s),
('MX', lambda n: int(n)),
],
[
('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('ST', lambda s: s),
('Man', lambda s: s),
('MX', lambda n: int(n)),
],
# swap st and man
[
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('MAN', lambda s: s),
('ST', lambda s: s),
('MX', lambda n: int(n)),
],
[
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('Man', lambda s: s),
('ST', lambda s: s),
('MX', lambda n: int(n)),
],
[
('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('Man', lambda s: s),
('ST', lambda s: str(s)),
('MX', lambda n: int(n)),
],
]

View file

@ -3,14 +3,16 @@ import socket
import binascii import binascii
import asyncio import asyncio
import logging import logging
from collections import OrderedDict
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from asyncio.futures import Future from asyncio.futures import Future
from asyncio.transports import DatagramTransport from asyncio.transports import DatagramTransport
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.constants import UPNP_ORG_IGD, WIFI_ALLIANCE_ORG_IGD from aioupnp.constants import UPNP_ORG_IGD, WIFI_ALLIANCE_ORG_IGD
from aioupnp.constants import SSDP_IP_ADDRESS, SSDP_PORT, SSDP_DISCOVER, SSDP_ROOT_DEVICE from aioupnp.constants import SSDP_IP_ADDRESS, SSDP_PORT, SSDP_DISCOVER, SSDP_ROOT_DEVICE, SSDP_ALL
from aioupnp.protocols.multicast import MulticastProtocol from aioupnp.protocols.multicast import MulticastProtocol
from aioupnp.protocols.m_search_patterns import M_SEARCH_ARG_PATTERNS
ADDRESS_REGEX = re.compile("^http:\/\/(\d+\.\d+\.\d+\.\d+)\:(\d*)(\/[\w|\/|\:|\-|\.]*)$") ADDRESS_REGEX = re.compile("^http:\/\/(\d+\.\d+\.\d+\.\d+)\:(\d*)(\/[\w|\/|\:|\-|\.]*)$")
@ -25,41 +27,17 @@ class SSDPProtocol(MulticastProtocol):
self.notifications: List = [] self.notifications: List = []
self.replies: List = [] self.replies: List = []
def send_m_search_packet(self, service, address, man, mx): def m_search(self, address: str, timeout: int, datagram_args: OrderedDict) -> Future:
packet = SSDPDatagram( packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram_args)
SSDPDatagram._M_SEARCH, host="{}:{}".format(SSDP_IP_ADDRESS, SSDP_PORT), st=service, f: Future = Future()
man=man, mx=mx futs = self.discover_callbacks.get((address, packet.st), [])
) futs.append(f)
log.debug("sending packet to %s:%i: %s", address, SSDP_PORT, packet) self.discover_callbacks[(address, packet.st)] = futs
log.debug("send m search to %s: %s", address, packet)
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT)) self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT))
async def m_search(self, address: str, timeout: int = 1, service: str = '', man: str = '', mx: int = 1) -> SSDPDatagram: r: Future = asyncio.ensure_future(asyncio.wait_for(f, timeout))
if (address, service) in self.discover_callbacks: return r
return self.discover_callbacks[(address, service)]
man = man or SSDP_DISCOVER
if not service:
services = [UPNP_ORG_IGD, WIFI_ALLIANCE_ORG_IGD]
else:
services = [service]
search_futs: List[Future] = []
outer_fut: Future = Future()
for service in services:
# D-Link works with both
# Cisco only works with quotes
self.send_m_search_packet(service, address, '\"%s\"' % man, mx)
# DD-WRT only works without quotes
self.send_m_search_packet(service, address, man, mx)
f: Future = Future()
f.add_done_callback(lambda _f: outer_fut.set_result(_f.result()))
self.discover_callbacks[(address, service)] = f
search_futs.append(f)
return await asyncio.wait_for(outer_fut, timeout)
def datagram_received(self, data, addr) -> None: def datagram_received(self, data, addr) -> None:
if addr[0] == self.lan_address: if addr[0] == self.lan_address:
@ -77,28 +55,29 @@ class SSDPProtocol(MulticastProtocol):
log.debug("%s:%i replied to our m-search", addr[0], addr[1]) log.debug("%s:%i replied to our m-search", addr[0], addr[1])
if packet.st not in map(lambda p: p['st'], self.replies): if packet.st not in map(lambda p: p['st'], self.replies):
self.replies.append(packet) self.replies.append(packet)
ok_fut: Future = self.discover_callbacks.pop((addr[0], packet.st)) for ok_fut in self.discover_callbacks[(addr[0], packet.st)]:
ok_fut.set_result(packet) ok_fut.set_result(packet)
del self.discover_callbacks[(addr[0], packet.st)]
return return
elif packet._packet_type == packet._NOTIFY: # elif packet._packet_type == packet._NOTIFY:
log.debug("%s:%i sent us a notification: %s", packet) # log.debug("%s:%i sent us a notification: %s", packet)
if packet.nt == SSDP_ROOT_DEVICE: # if packet.nt == SSDP_ROOT_DEVICE:
address, port, path = ADDRESS_REGEX.findall(packet.location)[0] # address, port, path = ADDRESS_REGEX.findall(packet.location)[0]
key = None # key = None
for (addr, service) in self.discover_callbacks: # for (addr, service) in self.discover_callbacks:
if addr == address: # if addr == address:
key = (addr, service) # key = (addr, service)
break # break
if key: # if key:
log.debug("got a notification with the requested m-search info") # log.debug("got a notification with the requested m-search info")
notify_fut: Future = self.discover_callbacks.pop(key) # notify_fut: Future = self.discover_callbacks.pop(key)
notify_fut.set_result(SSDPDatagram( # notify_fut.set_result(SSDPDatagram(
SSDPDatagram._OK, cache_control='', location=packet.location, server=packet.server, # SSDPDatagram._OK, cache_control='', location=packet.location, server=packet.server,
st=UPNP_ORG_IGD, usn=packet.usn # st=UPNP_ORG_IGD, usn=packet.usn
)) # ))
self.notifications.append(packet.as_dict()) # self.notifications.append(packet.as_dict())
return # return
async def listen_ssdp(lan_address: str, gateway_address: str, async def listen_ssdp(lan_address: str, gateway_address: str,
@ -124,14 +103,50 @@ async def listen_ssdp(lan_address: str, gateway_address: str,
return transport, protocol, gateway_address, lan_address return transport, protocol, gateway_address, lan_address
async def m_search(lan_address: str, gateway_address: str, timeout: int = 1, async def m_search(lan_address: str, gateway_address: str, datagram_args: OrderedDict, timeout: int = 1,
service: str = '', man: str = '', mx: int = 1, ssdp_socket: socket.socket = None) -> SSDPDatagram: ssdp_socket: socket.socket = None) -> SSDPDatagram:
transport, protocol, gateway_address, lan_address = await listen_ssdp( transport, protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, ssdp_socket lan_address, gateway_address, ssdp_socket
) )
try: try:
return await protocol.m_search(address=gateway_address, timeout=timeout, service=service, man=man, mx=mx) return await protocol.m_search(address=gateway_address, timeout=timeout, datagram_args=datagram_args)
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
finally: finally:
transport.close() transport.close()
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 1,
ssdp_socket: socket.socket = None) -> SSDPDatagram:
transport, protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, ssdp_socket
)
datagram_kwargs: list = []
services = [UPNP_ORG_IGD, SSDP_ALL, WIFI_ALLIANCE_ORG_IGD]
mans = [SSDP_DISCOVER, SSDP_ROOT_DEVICE]
mx = 1
for service in services:
for man in mans:
for arg_pattern in M_SEARCH_ARG_PATTERNS:
dgram_kwargs: OrderedDict = OrderedDict()
for k, l in arg_pattern:
if k.lower() == 'host':
dgram_kwargs[k] = l(SSDP_IP_ADDRESS)
elif k.lower() == 'st':
dgram_kwargs[k] = l(service)
elif k.lower() == 'man':
dgram_kwargs[k] = l(man)
elif k.lower() == 'mx':
dgram_kwargs[k] = l(mx)
datagram_kwargs.append(dgram_kwargs)
for i, args in enumerate(datagram_kwargs):
try:
result = await protocol.m_search(address=gateway_address, timeout=timeout, datagram_args=args)
transport.close()
return result
except TimeoutError:
pass
transport.close()
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))

View file

@ -3,11 +3,12 @@ import socket
import logging import logging
import json import json
import asyncio import asyncio
from collections import OrderedDict
from typing import Tuple, Dict, List, Union from typing import Tuple, Dict, List, Union
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.gateway import Gateway from aioupnp.gateway import Gateway
from aioupnp.util import get_gateway_and_lan_addresses from aioupnp.util import get_gateway_and_lan_addresses
from aioupnp.protocols.ssdp import m_search from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -42,23 +43,24 @@ class UPnP:
@classmethod @classmethod
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
service: str = '', man: str = '', mx: int = 1, interface_name: str = 'default', interface_name: str = 'default',
ssdp_socket: socket.socket = None, soap_socket: socket.socket = None): ssdp_socket: socket.socket = None, soap_socket: socket.socket = None):
try: try:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
except Exception as err: except Exception as err:
raise UPnPError("failed to get lan and gateway addresses: %s" % str(err)) raise UPnPError("failed to get lan and gateway addresses: %s" % str(err))
gateway = await Gateway.discover_gateway( gateway = await Gateway.discover_gateway(
lan_address, gateway_address, timeout, service, man, mx, ssdp_socket, soap_socket lan_address, gateway_address, timeout, ssdp_socket, soap_socket
) )
return cls(lan_address, gateway_address, gateway) return cls(lan_address, gateway_address, gateway)
@classmethod @classmethod
@cli @cli
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
service: str = '', man: str = '', mx: int = 1, interface_name: str = 'default') -> Dict: args: OrderedDict = None, interface_name: str = 'default') -> Dict:
args = args or OrderedDict()
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
datagram = await m_search(lan_address, gateway_address, timeout, service, man, mx) datagram = await m_search(lan_address, gateway_address, args, timeout)
return { return {
'lan_address': lan_address, 'lan_address': lan_address,
'gateway_address': gateway_address, 'gateway_address': gateway_address,
@ -215,10 +217,10 @@ class UPnP:
return "Generated test data! -> %s" % device_path return "Generated test data! -> %s" % device_path
@classmethod @classmethod
def run_cli(cls, method, lan_address: str = '', gateway_address: str = '', timeout: int = 60, def run_cli(cls, method, igd_args: OrderedDict, lan_address: str = '', gateway_address: str = '', timeout: int = 60,
service: str = '', man: str = '', mx: int = 1, interface_name: str = 'default', interface_name: str = 'default', kwargs: dict = None) -> None:
kwargs: dict = None) -> None:
kwargs = kwargs or {} kwargs = kwargs or {}
igd_args = igd_args
timeout = int(timeout) timeout = int(timeout)
try: try:
asyncio.get_running_loop() asyncio.get_running_loop()
@ -231,12 +233,12 @@ class UPnP:
async def wrapper(): async def wrapper():
if method == 'm_search': if method == 'm_search':
fn = lambda *_a, **_kw: cls.m_search( fn = lambda *_a, **_kw: cls.m_search(
lan_address, gateway_address, timeout, service, man, mx, interface_name lan_address, gateway_address, timeout, igd_args, interface_name
) )
else: else:
try: try:
u = await cls.discover( u = await cls.discover(
lan_address, gateway_address, timeout, service, man, mx, interface_name lan_address, gateway_address, timeout, interface_name
) )
except UPnPError as err: except UPnPError as err:
fut.set_exception(err) fut.set_exception(err)