220 lines
8.1 KiB
Python
220 lines
8.1 KiB
Python
import asyncio
|
|
import unittest
|
|
import contextlib
|
|
import socket
|
|
from typing import Tuple, Optional, Any
|
|
from unittest import mock
|
|
from unittest.case import _Outcome
|
|
|
|
|
|
try:
|
|
from asyncio.runners import _cancel_all_tasks
|
|
except ImportError:
|
|
# this is only available in py3.7
|
|
def _cancel_all_tasks(loop):
|
|
pass
|
|
|
|
|
|
@contextlib.contextmanager
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
                     tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False,
                     raise_oserror_on_bind=False):
    """Replace *loop*'s TCP/UDP connection factories and ``socket.socket`` with in-memory fakes.

    While the context is active:

    * ``loop.create_connection`` builds the protocol and connects it to a fake transport. Every
      payload passed to ``transport.write`` is appended to *sent_tcp_packets*. When the payload is a
      key of *tcp_replies*, the mapped reply is fed to ``protocol.data_received`` in 100-byte
      chunks, each scheduled *tcp_delay_reply* seconds later.
    * ``loop.create_datagram_endpoint`` builds the protocol (which must expose a ``bind_address``
      attribute) and connects it to a fake datagram transport. Every payload passed to
      ``transport.sendto`` is appended to *sent_udp_packets* and echoed back to
      ``protocol.datagram_received`` from ``(protocol.bind_address, 1900)``. When
      ``(data, addr)`` is a key of *udp_replies*, the mapped reply is also delivered, from
      ``(udp_expected_addr, 1900)``. Both are scheduled *udp_delay_reply* seconds later. With
      *add_potato_datagrams*, a junk ``b'potato'`` datagram from ``('?.?.?.?', 1900)`` is
      delivered on the next loop iteration as well.
    * ``socket.socket(...)`` returns one shared ``Mock`` socket. Its ``bind`` raises ``OSError``
      when *raise_oserror_on_bind* is true.

    Everything patched, including the two loop methods, is restored when the context exits.

    :param loop: event loop whose ``create_connection``/``create_datagram_endpoint`` are patched.
    :param sent_udp_packets: list that collects outgoing UDP payloads; a fresh list when ``None``.
    :param sent_tcp_packets: list that collects outgoing TCP payloads; a fresh list when ``None``.
    """
    sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
    udp_replies = udp_replies or {}

    sent_tcp_packets = sent_tcp_packets if sent_tcp_packets is not None else []
    tcp_replies = tcp_replies or {}

    # Replies are split up so the protocol receives them over several data_received calls.
    tcp_chunk_size = 100

    async def create_connection(protocol_factory, host=None, port=None):
        protocol = protocol_factory()

        class MockTransport(asyncio.Transport):
            def close(self):
                return None

            def write(self, data):
                sent_tcp_packets.append(data)
                if data in tcp_replies:
                    reply = tcp_replies[data]
                    for start in range(0, len(reply), tcp_chunk_size):
                        loop.call_later(tcp_delay_reply, protocol.data_received,
                                        reply[start:start + tcp_chunk_size])

        transport = MockTransport(extra={'socket': mock.Mock(spec=socket.socket)})
        protocol.connection_made(transport)
        return transport, protocol

    async def create_datagram_endpoint(proto_lam, sock=None):
        protocol = proto_lam()

        class MockDatagramTransport(asyncio.DatagramTransport):
            def __init__(self, sock):
                super().__init__(extra={'socket': sock})
                self._sock = sock

            def close(self):
                self._sock.close()

            # addr defaults to None to match asyncio.DatagramTransport.sendto(data, addr=None).
            def sendto(self, data: Any, addr: Optional[Tuple[str, int]] = None) -> None:
                sent_udp_packets.append(data)
                # Every outgoing datagram is looped back to the sender's own bind address.
                loop.call_later(udp_delay_reply, protocol.datagram_received, data,
                                (protocol.bind_address, 1900))
                if add_potato_datagrams:
                    loop.call_soon(protocol.datagram_received, b'potato', ('?.?.?.?', 1900))
                if (data, addr) in udp_replies:
                    loop.call_later(udp_delay_reply, protocol.datagram_received, udp_replies[(data, addr)],
                                    (udp_expected_addr, 1900))

        # mock_sock is bound below, inside the patch, before this coroutine can ever be awaited.
        transport = MockDatagramTransport(mock_sock)
        protocol.connection_made(transport)
        return transport, protocol

    with contextlib.ExitStack() as stack:
        mock_socket = stack.enter_context(mock.patch('socket.socket'))
        mock_sock = mock.Mock(spec=socket.socket)
        mock_sock.setsockopt = lambda *_: None

        def bind(*_):
            if raise_oserror_on_bind:
                raise OSError()

        mock_sock.bind = bind
        mock_sock.setblocking = lambda *_: None
        mock_sock.getsockname = lambda: "0.0.0.0"
        mock_sock.getpeername = lambda: ""
        mock_sock.close = lambda: None
        mock_sock.type = socket.SOCK_DGRAM
        mock_sock.fileno = lambda: 7

        mock_socket.return_value = mock_sock

        # patch.object (unlike plain attribute assignment) puts the loop's own methods back
        # when the context exits, so the fakes never outlive this context manager.
        stack.enter_context(mock.patch.object(loop, 'create_datagram_endpoint', create_datagram_endpoint))
        stack.enter_context(mock.patch.object(loop, 'create_connection', create_connection))
        yield
|
|
|
|
|
|
class AsyncioTestCase(unittest.TestCase):
    """``unittest.TestCase`` that can run coroutine tests, async set-up/tear-down hooks and async cleanups.

    For every test, ``run`` does the following:

    * Creates a new event loop, stores it as ``self.loop``, installs it as the current event loop
      and turns on its debug mode.
    * Runs ``setUp``/``asyncSetUp``, the test method, ``asyncTearDown``/``tearDown`` and the
      registered cleanups, driving any coroutines on that loop.
    * Calls ``_cancel_all_tasks(self.loop)`` (a no-op fallback on interpreters without
      ``asyncio.runners``).
    * Shuts down async generators and closes the loop.
    """
    # Implementation inspired by discussion:
    # https://bugs.python.org/issue32972

    maxDiff = None  # never truncate diffs in assertion failure messages

    async def asyncSetUp(self):  # pylint: disable=C0103
        """Coroutine hook run on ``self.loop`` right after ``setUp()``; default does nothing."""
        pass

    async def asyncTearDown(self):  # pylint: disable=C0103
        """Coroutine hook run on ``self.loop`` just before ``tearDown()``; default does nothing."""
        pass

    def run(self, result=None):  # pylint: disable=R0915
        """Run one test on a dedicated event loop; otherwise follows ``unittest.TestCase.run``.

        :param result: ``TestResult`` to report into. When ``None``, a default result is created
            and its ``startTestRun``/``stopTestRun`` hooks are called.
        :returns: *result* after a normal run, or ``None`` when the test is skipped.

        NOTE(review): the body appears adapted from CPython's own ``TestCase.run``. It depends on
        private ``unittest`` internals (``_Outcome``, ``_addSkip``, ``_feedErrorsToResult``,
        ``_addExpectedFailure``, ``_addUnexpectedSuccess``, ``outcome.skipped``/``errors``). These
        are not a stable API and may differ between Python versions — confirm against the
        interpreter in use.

        NOTE(review): the skip path returns before the ``try``/``finally`` below. It therefore
        never calls ``stopTestRun``, even when this method created *result* itself.
        """
        orig_result = result
        if result is None:
            result = self.defaultTestResult()
            startTestRun = getattr(result, 'startTestRun', None)  # pylint: disable=C0103
            if startTestRun is not None:
                startTestRun()

        result.startTest(self)

        testMethod = getattr(self, self._testMethodName)  # pylint: disable=C0103
        if (getattr(self.__class__, "__unittest_skip__", False) or
                getattr(testMethod, "__unittest_skip__", False)):
            # If the class or method was skipped.
            try:
                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
                            or getattr(testMethod, '__unittest_skip_why__', ''))
                self._addSkip(result, self, skip_why)
            finally:
                result.stopTest(self)
            return
        # Support @unittest.expectedFailure on either the method or the class.
        expecting_failure_method = getattr(testMethod,
                                           "__unittest_expecting_failure__", False)
        expecting_failure_class = getattr(self,
                                          "__unittest_expecting_failure__", False)
        expecting_failure = expecting_failure_class or expecting_failure_method
        # Private unittest helper; its success/errors/skipped/expectedFailure state is read below.
        outcome = _Outcome(result)

        # A fresh loop for every test. It is installed as the current loop so code that
        # looks it up via asyncio sees this one. asyncio debug mode adds extra runtime checks.
        self.loop = asyncio.new_event_loop()  # pylint: disable=W0201
        asyncio.set_event_loop(self.loop)
        self.loop.set_debug(True)

        try:
            self._outcome = outcome

            # setUp and asyncSetUp share one executor section. The test body and tear-down
            # hooks only run while outcome.success is still true afterwards.
            with outcome.testPartExecutor(self):
                self.setUp()
                self.loop.run_until_complete(self.asyncSetUp())
            if outcome.success:
                outcome.expecting_failure = expecting_failure
                with outcome.testPartExecutor(self, isTest=True):
                    # Synchronous tests just run; an ``async def`` test returns a
                    # coroutine, which is driven to completion on the loop.
                    maybe_coroutine = testMethod()
                    if asyncio.iscoroutine(maybe_coroutine):
                        self.loop.run_until_complete(maybe_coroutine)
                outcome.expecting_failure = False
                with outcome.testPartExecutor(self):
                    self.loop.run_until_complete(self.asyncTearDown())
                    self.tearDown()

            # Cleanups run whether or not set-up succeeded, while the loop is still open.
            self.doAsyncCleanups()

            # Always detach and close the loop, even if task cancellation or
            # async-generator shutdown raises.
            try:
                _cancel_all_tasks(self.loop)
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            finally:
                asyncio.set_event_loop(None)
                self.loop.close()

            # Report the collected outcome to the result object.
            for test, reason in outcome.skipped:
                self._addSkip(result, test, reason)
            self._feedErrorsToResult(result, outcome.errors)
            if outcome.success:
                if expecting_failure:
                    if outcome.expectedFailure:
                        self._addExpectedFailure(result, outcome.expectedFailure)
                    else:
                        self._addUnexpectedSuccess(result)
                else:
                    result.addSuccess(self)
            return result
        finally:
            result.stopTest(self)
            if orig_result is None:
                stopTestRun = getattr(result, 'stopTestRun', None)  # pylint: disable=C0103
                if stopTestRun is not None:
                    stopTestRun()  # pylint: disable=E1102

            # explicitly break reference cycles:
            # outcome.errors -> frame -> outcome -> outcome.errors
            # outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure
            outcome.errors.clear()
            outcome.expectedFailure = None

            # clear the outcome, no more needed
            self._outcome = None

    def doAsyncCleanups(self):  # pylint: disable=C0103
        """Run callbacks registered with ``addCleanup`` in LIFO order.

        If a callback returns a coroutine, it is run to completion on ``self.loop``. Each
        callback runs in its own ``testPartExecutor`` section, presumably so that one failing
        cleanup is recorded and the remaining ones still run.

        NOTE(review): outside ``run()`` (no active ``self._outcome``), a throwaway ``_Outcome()``
        is used, so failures recorded on it are not fed to any result — confirm this is intended.
        """
        outcome = self._outcome or _Outcome()
        while self._cleanups:
            # pop() takes the most recently registered cleanup first (LIFO).
            function, args, kwargs = self._cleanups.pop()
            with outcome.testPartExecutor(self):
                maybe_coroutine = function(*args, **kwargs)
                if asyncio.iscoroutine(maybe_coroutine):
                    self.loop.run_until_complete(maybe_coroutine)
|