added AdvanceTimeTestCase

This commit is contained in:
Lex Berezhny 2019-01-07 03:02:15 -05:00
parent b9b411ec30
commit 442138ef36

View file

@ -1,5 +1,6 @@
import sys import sys
import logging import logging
import functools
import asyncio import asyncio
from asyncio.runners import _cancel_all_tasks # type: ignore from asyncio.runners import _cancel_all_tasks # type: ignore
import unittest import unittest
@ -93,35 +94,36 @@ class AsyncioTestCase(unittest.TestCase):
"__unittest_expecting_failure__", False) "__unittest_expecting_failure__", False)
expecting_failure = expecting_failure_class or expecting_failure_method expecting_failure = expecting_failure_class or expecting_failure_method
outcome = _Outcome(result) outcome = _Outcome(result)
loop = asyncio.new_event_loop()
self.loop = asyncio.new_event_loop() # pylint: disable=W0201
asyncio.set_event_loop(self.loop)
self.loop.set_debug(True)
try: try:
self._outcome = outcome self._outcome = outcome
asyncio.set_event_loop(loop)
loop.set_debug(True)
with outcome.testPartExecutor(self): with outcome.testPartExecutor(self):
self.setUp() self.setUp()
loop.run_until_complete(self.asyncSetUp()) self.loop.run_until_complete(self.asyncSetUp())
if outcome.success: if outcome.success:
outcome.expecting_failure = expecting_failure outcome.expecting_failure = expecting_failure
with outcome.testPartExecutor(self, isTest=True): with outcome.testPartExecutor(self, isTest=True):
maybe_coroutine = testMethod() maybe_coroutine = testMethod()
if asyncio.iscoroutine(maybe_coroutine): if asyncio.iscoroutine(maybe_coroutine):
loop.run_until_complete(maybe_coroutine) self.loop.run_until_complete(maybe_coroutine)
outcome.expecting_failure = False outcome.expecting_failure = False
with outcome.testPartExecutor(self): with outcome.testPartExecutor(self):
loop.run_until_complete(self.asyncTearDown()) self.loop.run_until_complete(self.asyncTearDown())
self.tearDown() self.tearDown()
self.doAsyncCleanups(loop) self.doAsyncCleanups()
try: try:
_cancel_all_tasks(loop) _cancel_all_tasks(self.loop)
loop.run_until_complete(loop.shutdown_asyncgens()) self.loop.run_until_complete(self.loop.shutdown_asyncgens())
finally: finally:
asyncio.set_event_loop(None) asyncio.set_event_loop(None)
loop.close() self.loop.close()
for test, reason in outcome.skipped: for test, reason in outcome.skipped:
self._addSkip(result, test, reason) self._addSkip(result, test, reason)
@ -151,14 +153,30 @@ class AsyncioTestCase(unittest.TestCase):
# clear the outcome, no more needed # clear the outcome, no more needed
self._outcome = None self._outcome = None
def doAsyncCleanups(self, loop): # pylint: disable=C0103 def doAsyncCleanups(self): # pylint: disable=C0103
outcome = self._outcome or _Outcome() outcome = self._outcome or _Outcome()
while self._cleanups: while self._cleanups:
function, args, kwargs = self._cleanups.pop() function, args, kwargs = self._cleanups.pop()
with outcome.testPartExecutor(self): with outcome.testPartExecutor(self):
maybe_coroutine = function(*args, **kwargs) maybe_coroutine = function(*args, **kwargs)
if asyncio.iscoroutine(maybe_coroutine): if asyncio.iscoroutine(maybe_coroutine):
loop.run_until_complete(maybe_coroutine) self.loop.run_until_complete(maybe_coroutine)
class AdvanceTimeTestCase(AsyncioTestCase):
async def asyncSetUp(self):
self._time = 0 # pylint: disable=W0201
self.loop.time = functools.wraps(self.loop.time)(lambda: self._time)
await super().asyncSetUp()
async def advance(self, seconds):
while self.loop._ready:
await asyncio.sleep(0)
self._time += seconds
await asyncio.sleep(0)
while self.loop._ready:
await asyncio.sleep(0)
class IntegrationTestCase(AsyncioTestCase): class IntegrationTestCase(AsyncioTestCase):