test cli
This commit is contained in:
parent
2f8b10f3c7
commit
ab0b9707f2
11 changed files with 306 additions and 136 deletions
|
@ -24,7 +24,8 @@ def get_help(command):
|
|||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main(argv=None, loop=None):
|
||||
argv = argv or sys.argv
|
||||
commands = [n for n in dir(UPnP) if hasattr(getattr(UPnP, n, None), "_cli")]
|
||||
help_str = "\n".join(textwrap.wrap(
|
||||
" | ".join(commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False
|
||||
|
@ -40,7 +41,7 @@ def main():
|
|||
"For help with a specific command:" \
|
||||
" aioupnp help <command>\n" % (base_usage, help_str)
|
||||
|
||||
args = sys.argv[1:]
|
||||
args = argv[1:]
|
||||
if args[0] in ['help', '-h', '--help']:
|
||||
if len(args) > 1:
|
||||
if args[1] in commands:
|
||||
|
@ -53,6 +54,7 @@ def main():
|
|||
'gateway_address': '',
|
||||
'lan_address': '',
|
||||
'timeout': 30,
|
||||
'unicast': False
|
||||
}
|
||||
|
||||
options = OrderedDict()
|
||||
|
@ -88,7 +90,7 @@ def main():
|
|||
|
||||
UPnP.run_cli(
|
||||
command.replace('-', '_'), options, options.pop('lan_address'), options.pop('gateway_address'),
|
||||
options.pop('timeout'), options.pop('interface'), kwargs
|
||||
options.pop('timeout'), options.pop('interface'), options.pop('unicast'), kwargs, loop
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -181,7 +181,7 @@ class Gateway:
|
|||
not_met = [
|
||||
required for required in required_commands if required not in gateway._registered_commands
|
||||
]
|
||||
log.warning("found gateway %s at %s, but it does not implement required soap commands: %s",
|
||||
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
|
||||
gateway.manufacturer_string, gateway.location, not_met)
|
||||
ignored.add(datagram.location)
|
||||
continue
|
||||
|
|
|
@ -60,7 +60,9 @@ class UPnP:
|
|||
@classmethod
|
||||
@cli
|
||||
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
|
||||
igd_args: OrderedDict = None, interface_name: str = 'default') -> Dict:
|
||||
igd_args: OrderedDict = None, unicast: bool = True, interface_name: str = 'default',
|
||||
loop=None) -> Dict:
|
||||
if not lan_address or not gateway_address:
|
||||
try:
|
||||
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
|
||||
assert gateway_address and lan_address
|
||||
|
@ -68,10 +70,10 @@ class UPnP:
|
|||
raise UPnPError("failed to get lan and gateway addresses for interface \"%s\": %s" % (interface_name,
|
||||
str(err)))
|
||||
if not igd_args:
|
||||
igd_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout)
|
||||
igd_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, loop, unicast=unicast)
|
||||
else:
|
||||
igd_args = OrderedDict(igd_args)
|
||||
datagram = await m_search(lan_address, gateway_address, igd_args, timeout)
|
||||
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, unicast=unicast)
|
||||
return {
|
||||
'lan_address': lan_address,
|
||||
'gateway_address': gateway_address,
|
||||
|
@ -340,7 +342,7 @@ class UPnP:
|
|||
|
||||
@classmethod
|
||||
def run_cli(cls, method, igd_args: OrderedDict, lan_address: str = '', gateway_address: str = '', timeout: int = 30,
|
||||
interface_name: str = 'default', kwargs: dict = None) -> None:
|
||||
interface_name: str = 'default', unicast: bool = True, kwargs: dict = None, loop=None) -> None:
|
||||
"""
|
||||
:param method: the command name
|
||||
:param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided
|
||||
|
@ -349,18 +351,19 @@ class UPnP:
|
|||
:param timeout: timeout, in seconds
|
||||
:param interface_name: name of the network interface, the default is aliased to 'default'
|
||||
:param kwargs: keyword arguments for the command
|
||||
:param loop: EventLoop, used for testing
|
||||
"""
|
||||
kwargs = kwargs or {}
|
||||
igd_args = igd_args
|
||||
timeout = int(timeout)
|
||||
loop = asyncio.get_event_loop_policy().get_event_loop()
|
||||
loop = loop or asyncio.get_event_loop_policy().get_event_loop()
|
||||
fut: asyncio.Future = asyncio.Future()
|
||||
|
||||
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
|
||||
|
||||
if method == 'm_search': # if we're only m_searching don't do any device discovery
|
||||
fn = lambda *_a, **_kw: cls.m_search(
|
||||
lan_address, gateway_address, timeout, igd_args, interface_name
|
||||
lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop
|
||||
)
|
||||
else: # automatically discover the gateway
|
||||
try:
|
||||
|
|
2
setup.py
2
setup.py
|
@ -22,7 +22,7 @@ setup(
|
|||
long_description_content_type='text/markdown',
|
||||
url="https://github.com/lbryio/aioupnp",
|
||||
license=__license__,
|
||||
packages=find_packages(),
|
||||
packages=find_packages(exclude=('tests',)),
|
||||
entry_points={'console_scripts': console_scripts},
|
||||
install_requires=[
|
||||
'netifaces',
|
||||
|
|
113
tests/__init__.py
Normal file
113
tests/__init__.py
Normal file
|
@ -0,0 +1,113 @@
|
|||
import asyncio
|
||||
import unittest
|
||||
from unittest.case import _Outcome
|
||||
|
||||
try:
|
||||
from asyncio.runners import _cancel_all_tasks
|
||||
except ImportError:
|
||||
# this is only available in py3.7
|
||||
def _cancel_all_tasks(loop):
|
||||
pass
|
||||
|
||||
|
||||
class TestBase(unittest.TestCase):
|
||||
# Implementation inspired by discussion:
|
||||
# https://bugs.python.org/issue32972
|
||||
|
||||
async def asyncSetUp(self):
|
||||
pass
|
||||
|
||||
async def asyncTearDown(self):
|
||||
pass
|
||||
|
||||
async def doAsyncCleanups(self):
|
||||
pass
|
||||
|
||||
def run(self, result=None):
|
||||
orig_result = result
|
||||
if result is None:
|
||||
result = self.defaultTestResult()
|
||||
startTestRun = getattr(result, 'startTestRun', None)
|
||||
if startTestRun is not None:
|
||||
startTestRun()
|
||||
|
||||
result.startTest(self)
|
||||
|
||||
testMethod = getattr(self, self._testMethodName)
|
||||
if (getattr(self.__class__, "__unittest_skip__", False) or
|
||||
getattr(testMethod, "__unittest_skip__", False)):
|
||||
# If the class or method was skipped.
|
||||
try:
|
||||
skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
|
||||
or getattr(testMethod, '__unittest_skip_why__', ''))
|
||||
self._addSkip(result, self, skip_why)
|
||||
finally:
|
||||
result.stopTest(self)
|
||||
return
|
||||
expecting_failure_method = getattr(testMethod,
|
||||
"__unittest_expecting_failure__", False)
|
||||
expecting_failure_class = getattr(self,
|
||||
"__unittest_expecting_failure__", False)
|
||||
expecting_failure = expecting_failure_class or expecting_failure_method
|
||||
outcome = _Outcome(result)
|
||||
try:
|
||||
self._outcome = outcome
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.set_debug(True)
|
||||
|
||||
with outcome.testPartExecutor(self):
|
||||
self.setUp()
|
||||
loop.run_until_complete(self.asyncSetUp())
|
||||
if outcome.success:
|
||||
outcome.expecting_failure = expecting_failure
|
||||
with outcome.testPartExecutor(self, isTest=True):
|
||||
possible_coroutine = testMethod()
|
||||
if asyncio.iscoroutine(possible_coroutine):
|
||||
loop.run_until_complete(possible_coroutine)
|
||||
outcome.expecting_failure = False
|
||||
with outcome.testPartExecutor(self):
|
||||
loop.run_until_complete(self.asyncTearDown())
|
||||
self.tearDown()
|
||||
finally:
|
||||
try:
|
||||
_cancel_all_tasks(loop)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
self.doCleanups()
|
||||
|
||||
for test, reason in outcome.skipped:
|
||||
self._addSkip(result, test, reason)
|
||||
self._feedErrorsToResult(result, outcome.errors)
|
||||
if outcome.success:
|
||||
if expecting_failure:
|
||||
if outcome.expectedFailure:
|
||||
self._addExpectedFailure(result, outcome.expectedFailure)
|
||||
else:
|
||||
self._addUnexpectedSuccess(result)
|
||||
else:
|
||||
result.addSuccess(self)
|
||||
return result
|
||||
finally:
|
||||
result.stopTest(self)
|
||||
if orig_result is None:
|
||||
stopTestRun = getattr(result, 'stopTestRun', None)
|
||||
if stopTestRun is not None:
|
||||
stopTestRun()
|
||||
|
||||
# explicitly break reference cycles:
|
||||
# outcome.errors -> frame -> outcome -> outcome.errors
|
||||
# outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure
|
||||
outcome.errors.clear()
|
||||
outcome.expectedFailure = None
|
||||
|
||||
# clear the outcome, no more needed
|
||||
self._outcome = None
|
||||
|
||||
def setUp(self):
|
||||
self.loop = asyncio.get_event_loop_policy().get_event_loop()
|
|
@ -75,3 +75,63 @@ def mock_tcp_endpoint_factory(loop, replies=None, delay_reply=0.0, sent_packets=
|
|||
mock_socket.return_value = mock_sock
|
||||
loop.create_connection = create_connection
|
||||
yield
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_tcp_and_udp(loop, udp_expected_addr, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
|
||||
tcp_replies=None, tcp_delay_reply=0.0, tcp_sent_packets=None):
|
||||
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
|
||||
udp_replies = udp_replies or {}
|
||||
|
||||
tcp_sent_packets = tcp_sent_packets if tcp_sent_packets is not None else []
|
||||
tcp_replies = tcp_replies or {}
|
||||
|
||||
async def create_connection(protocol_factory, host=None, port=None):
|
||||
def write(p: asyncio.Protocol):
|
||||
def _write(data):
|
||||
tcp_sent_packets.append(data)
|
||||
if data in tcp_replies:
|
||||
loop.call_later(tcp_delay_reply, p.data_received, tcp_replies[data])
|
||||
|
||||
return _write
|
||||
|
||||
protocol = protocol_factory()
|
||||
transport = asyncio.Transport(extra={'socket': mock.Mock(spec=socket.socket)})
|
||||
transport.close = lambda: None
|
||||
transport.write = write(protocol)
|
||||
protocol.connection_made(transport)
|
||||
return transport, protocol
|
||||
|
||||
async def create_datagram_endpoint(proto_lam, sock=None):
|
||||
def sendto(p: asyncio.DatagramProtocol):
|
||||
def _sendto(data, addr):
|
||||
sent_udp_packets.append(data)
|
||||
if (data, addr) in udp_replies:
|
||||
loop.call_later(udp_delay_reply, p.datagram_received, udp_replies[(data, addr)],
|
||||
(udp_expected_addr, 1900))
|
||||
|
||||
return _sendto
|
||||
|
||||
protocol = proto_lam()
|
||||
transport = asyncio.DatagramTransport(extra={'socket': mock_sock})
|
||||
transport.close = lambda: mock_sock.close()
|
||||
mock_sock.sendto = sendto(protocol)
|
||||
transport.sendto = mock_sock.sendto
|
||||
protocol.connection_made(transport)
|
||||
return transport, protocol
|
||||
|
||||
with mock.patch('socket.socket') as mock_socket:
|
||||
mock_sock = mock.Mock(spec=socket.socket)
|
||||
mock_sock.setsockopt = lambda *_: None
|
||||
mock_sock.bind = lambda *_: None
|
||||
mock_sock.setblocking = lambda *_: None
|
||||
mock_sock.getsockname = lambda: "0.0.0.0"
|
||||
mock_sock.getpeername = lambda: ""
|
||||
mock_sock.close = lambda: None
|
||||
mock_sock.type = socket.SOCK_DGRAM
|
||||
mock_sock.fileno = lambda: 7
|
||||
|
||||
mock_socket.return_value = mock_sock
|
||||
loop.create_datagram_endpoint = create_datagram_endpoint
|
||||
loop.create_connection = create_connection
|
||||
yield
|
|
@ -1,113 +0,0 @@
|
|||
import asyncio
|
||||
import unittest
|
||||
from unittest.case import _Outcome
|
||||
|
||||
try:
|
||||
from asyncio.runners import _cancel_all_tasks
|
||||
except ImportError:
|
||||
# this is only available in py3.7
|
||||
def _cancel_all_tasks(loop):
|
||||
pass
|
||||
|
||||
|
||||
class TestBase(unittest.TestCase):
|
||||
# Implementation inspired by discussion:
|
||||
# https://bugs.python.org/issue32972
|
||||
|
||||
async def asyncSetUp(self):
|
||||
pass
|
||||
|
||||
async def asyncTearDown(self):
|
||||
pass
|
||||
|
||||
async def doAsyncCleanups(self):
|
||||
pass
|
||||
|
||||
def run(self, result=None):
|
||||
orig_result = result
|
||||
if result is None:
|
||||
result = self.defaultTestResult()
|
||||
startTestRun = getattr(result, 'startTestRun', None)
|
||||
if startTestRun is not None:
|
||||
startTestRun()
|
||||
|
||||
result.startTest(self)
|
||||
|
||||
testMethod = getattr(self, self._testMethodName)
|
||||
if (getattr(self.__class__, "__unittest_skip__", False) or
|
||||
getattr(testMethod, "__unittest_skip__", False)):
|
||||
# If the class or method was skipped.
|
||||
try:
|
||||
skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
|
||||
or getattr(testMethod, '__unittest_skip_why__', ''))
|
||||
self._addSkip(result, self, skip_why)
|
||||
finally:
|
||||
result.stopTest(self)
|
||||
return
|
||||
expecting_failure_method = getattr(testMethod,
|
||||
"__unittest_expecting_failure__", False)
|
||||
expecting_failure_class = getattr(self,
|
||||
"__unittest_expecting_failure__", False)
|
||||
expecting_failure = expecting_failure_class or expecting_failure_method
|
||||
outcome = _Outcome(result)
|
||||
try:
|
||||
self._outcome = outcome
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.set_debug(True)
|
||||
|
||||
with outcome.testPartExecutor(self):
|
||||
self.setUp()
|
||||
loop.run_until_complete(self.asyncSetUp())
|
||||
if outcome.success:
|
||||
outcome.expecting_failure = expecting_failure
|
||||
with outcome.testPartExecutor(self, isTest=True):
|
||||
possible_coroutine = testMethod()
|
||||
if asyncio.iscoroutine(possible_coroutine):
|
||||
loop.run_until_complete(possible_coroutine)
|
||||
outcome.expecting_failure = False
|
||||
with outcome.testPartExecutor(self):
|
||||
loop.run_until_complete(self.asyncTearDown())
|
||||
self.tearDown()
|
||||
finally:
|
||||
try:
|
||||
_cancel_all_tasks(loop)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
self.doCleanups()
|
||||
|
||||
for test, reason in outcome.skipped:
|
||||
self._addSkip(result, test, reason)
|
||||
self._feedErrorsToResult(result, outcome.errors)
|
||||
if outcome.success:
|
||||
if expecting_failure:
|
||||
if outcome.expectedFailure:
|
||||
self._addExpectedFailure(result, outcome.expectedFailure)
|
||||
else:
|
||||
self._addUnexpectedSuccess(result)
|
||||
else:
|
||||
result.addSuccess(self)
|
||||
return result
|
||||
finally:
|
||||
result.stopTest(self)
|
||||
if orig_result is None:
|
||||
stopTestRun = getattr(result, 'stopTestRun', None)
|
||||
if stopTestRun is not None:
|
||||
stopTestRun()
|
||||
|
||||
# explicitly break reference cycles:
|
||||
# outcome.errors -> frame -> outcome -> outcome.errors
|
||||
# outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure
|
||||
outcome.errors.clear()
|
||||
outcome.expectedFailure = None
|
||||
|
||||
# clear the outcome, no more needed
|
||||
self._outcome = None
|
||||
|
||||
def setUp(self):
|
||||
self.loop = asyncio.get_event_loop_policy().get_event_loop()
|
|
@ -1,7 +1,7 @@
|
|||
from aioupnp.fault import UPnPError
|
||||
from aioupnp.protocols.scpd import scpd_post, scpd_get
|
||||
from . import TestBase
|
||||
from .mocks import mock_tcp_endpoint_factory
|
||||
from tests import TestBase
|
||||
from tests.mocks import mock_tcp_endpoint_factory
|
||||
|
||||
|
||||
class TestSCPDGet(TestBase):
|
||||
|
|
|
@ -4,8 +4,8 @@ from aioupnp.protocols.m_search_patterns import packet_generator
|
|||
from aioupnp.serialization.ssdp import SSDPDatagram
|
||||
from aioupnp.constants import SSDP_IP_ADDRESS
|
||||
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search
|
||||
from . import TestBase
|
||||
from .mocks import mock_datagram_endpoint_factory
|
||||
from tests import TestBase
|
||||
from tests.mocks import mock_datagram_endpoint_factory
|
||||
|
||||
|
||||
class TestSSDP(TestBase):
|
||||
|
|
104
tests/test_cli.py
Normal file
104
tests/test_cli.py
Normal file
File diff suppressed because one or more lines are too long
|
@ -1,11 +1,12 @@
|
|||
from aioupnp.fault import UPnPError
|
||||
from aioupnp.protocols.scpd import scpd_post, scpd_get
|
||||
from . import TestBase
|
||||
from .mocks import mock_tcp_endpoint_factory
|
||||
from tests import TestBase
|
||||
from tests.mocks import mock_tcp_endpoint_factory
|
||||
from collections import OrderedDict
|
||||
from aioupnp.gateway import Gateway
|
||||
from aioupnp.serialization.ssdp import SSDPDatagram
|
||||
|
||||
|
||||
class TestDiscoverCommands(TestBase):
|
||||
gateway_address = "10.0.0.1"
|
||||
soap_port = 49152
|
Loading…
Reference in a new issue