type annotations

This commit is contained in:
Jack Robison 2019-10-21 18:24:41 -04:00
parent b1c39cef7b
commit c3211215df
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2

View file

@ -4,6 +4,7 @@ import asyncio
import logging import logging
import typing import typing
import socket import socket
from typing import List, Set, Dict, Tuple, Optional
from asyncio.transports import DatagramTransport from asyncio.transports import DatagramTransport
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
@ -17,15 +18,15 @@ log = logging.getLogger(__name__)
class SSDPProtocol(MulticastProtocol): class SSDPProtocol(MulticastProtocol):
def __init__(self, multicast_address: str, lan_address: str, ignored: typing.Optional[typing.Set[str]] = None, def __init__(self, multicast_address: str, lan_address: str, ignored: Optional[Set[str]] = None,
unicast: bool = False, loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> None: unicast: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(multicast_address, lan_address) super().__init__(multicast_address, lan_address)
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
self.transport: typing.Optional[DatagramTransport] = None self.transport: Optional[DatagramTransport] = None
self._unicast = unicast self._unicast = unicast
self._ignored: typing.Set[str] = ignored or set() # ignored locations self._ignored: Set[str] = ignored or set() # ignored locations
self._pending_searches: typing.List[typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = [] self._pending_searches: List[Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = []
self.notifications: typing.List[SSDPDatagram] = [] self.notifications: List[SSDPDatagram] = []
self.connected = asyncio.Event(loop=self.loop) self.connected = asyncio.Event(loop=self.loop)
def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore
@ -47,8 +48,8 @@ class SSDPProtocol(MulticastProtocol):
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None: def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
if packet.location not in self._ignored: if packet.location not in self._ignored:
# TODO: fix this # TODO: fix this
tmp: typing.List[typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = [] tmp: List[Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = []
set_futures: typing.List['asyncio.Future[SSDPDatagram]'] = [] set_futures: List['asyncio.Future[SSDPDatagram]'] = []
while len(self._pending_searches): while len(self._pending_searches):
t = self._pending_searches.pop() t = self._pending_searches.pop()
if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]): if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]):
@ -74,8 +75,8 @@ class SSDPProtocol(MulticastProtocol):
return None return None
async def m_search(self, address: str, timeout: float, async def m_search(self, address: str, timeout: float,
datagrams: typing.List[typing.Dict[str, typing.Union[str, int]]]) -> SSDPDatagram: datagrams: List[Dict[str, typing.Union[str, int]]]) -> SSDPDatagram:
fut: 'asyncio.Future[SSDPDatagram]' = asyncio.Future(loop=self.loop) fut: 'asyncio.Future[SSDPDatagram]' = self.loop.create_future()
for datagram in datagrams: for datagram in datagrams:
packet = SSDPDatagram("M-SEARCH", datagram) packet = SSDPDatagram("M-SEARCH", datagram)
assert packet.st is not None assert packet.st is not None
@ -84,7 +85,7 @@ class SSDPProtocol(MulticastProtocol):
) )
return await asyncio.wait_for(fut, timeout) return await asyncio.wait_for(fut, timeout)
def datagram_received(self, data: bytes, addr: typing.Tuple[str, int]) -> None: # type: ignore def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None: # type: ignore
if addr[0] == self.bind_address: if addr[0] == self.bind_address:
return None return None
try: try:
@ -118,13 +119,13 @@ class SSDPProtocol(MulticastProtocol):
# return # return
async def listen_ssdp(lan_address: str, gateway_address: str, loop: typing.Optional[asyncio.AbstractEventLoop] = None, async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asyncio.AbstractEventLoop] = None,
ignored: typing.Optional[typing.Set[str]] = None, ignored: Optional[Set[str]] = None,
unicast: bool = False) -> typing.Tuple[SSDPProtocol, str, str]: unicast: bool = False) -> Tuple[SSDPProtocol, str, str]:
loop = loop or asyncio.get_event_loop() loop = loop or asyncio.get_event_loop()
try: try:
sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address) sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address)
listen_result: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint( listen_result: Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
) )
protocol = listen_result[1] protocol = listen_result[1]
@ -137,9 +138,9 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop: typing.Optio
return protocol, gateway_address, lan_address return protocol, gateway_address, lan_address
async def m_search(lan_address: str, gateway_address: str, datagram_args: typing.Dict[str, typing.Union[int, str]], async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[str, typing.Union[int, str]],
timeout: int = 1, loop: typing.Optional[asyncio.AbstractEventLoop] = None, timeout: int = 1, loop: Optional[asyncio.AbstractEventLoop] = None,
ignored: typing.Set[str] = None, unicast: bool = False) -> SSDPDatagram: ignored: Set[str] = None, unicast: bool = False) -> SSDPDatagram:
protocol, gateway_address, lan_address = await listen_ssdp( protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast lan_address, gateway_address, loop, ignored, unicast
) )
@ -152,9 +153,9 @@ async def m_search(lan_address: str, gateway_address: str, datagram_args: typing
async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
loop: typing.Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None,
ignored: typing.Set[str] = None, ignored: Set[str] = None,
unicast: bool = False) -> typing.List[typing.Dict[str, typing.Union[int, str]]]: unicast: bool = False) -> List[Dict[str, typing.Union[int, str]]]:
protocol, gateway_address, lan_address = await listen_ssdp( protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, loop, ignored, unicast lan_address, gateway_address, loop, ignored, unicast
) )
@ -178,10 +179,9 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30, async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
loop: typing.Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None,
ignored: typing.Set[str] = None, ignored: Set[str] = None,
unicast: bool = False) -> typing.Tuple[typing.Dict[str, unicast: bool = False) -> Tuple[Dict[str, typing.Union[int, str]], SSDPDatagram]:
typing.Union[int, str]], SSDPDatagram]:
# we don't know which packet the gateway replies to, so send small batches at a time # we don't know which packet the gateway replies to, so send small batches at a time
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast) args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, loop, ignored, unicast)
# check the args in the batch that got a reply one at a time to see which one worked # check the args in the batch that got a reply one at a time to see which one worked