From 57ebbbcb7875b072b69d4fbe56163d01a1a79047 Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Fri, 18 Feb 2022 18:52:17 -0300 Subject: [PATCH] simplify dht mock and restore clock after accelerating --- tests/dht_mocks.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/tests/dht_mocks.py b/tests/dht_mocks.py index 2e01986c0..4bebcfaf1 100644 --- a/tests/dht_mocks.py +++ b/tests/dht_mocks.py @@ -9,7 +9,7 @@ if typing.TYPE_CHECKING: def get_time_accelerator(loop: asyncio.AbstractEventLoop, - now: typing.Optional[float] = None) -> typing.Callable[[float], typing.Awaitable[None]]: + instant_step: bool = False) -> typing.Callable[[float], typing.Awaitable[None]]: """ Returns an async advance() function @@ -17,32 +17,22 @@ def get_time_accelerator(loop: asyncio.AbstractEventLoop, made by call_later, call_at, and call_soon. """ - _time = now or loop.time() - loop.time = functools.wraps(loop.time)(lambda: _time) + original = loop.time + _drift = 0 + loop.time = functools.wraps(loop.time)(lambda: original() + _drift) async def accelerate_time(seconds: float) -> None: - nonlocal _time + nonlocal _drift if seconds < 0: raise ValueError(f'Cannot go back in time ({seconds} seconds)') - _time += seconds - await past_events() + _drift += seconds await asyncio.sleep(0) - async def past_events() -> None: - while loop._scheduled: - timer: asyncio.TimerHandle = loop._scheduled[0] - if timer not in loop._ready and timer._when <= _time: - loop._scheduled.remove(timer) - loop._ready.append(timer) - if timer._when > _time: - break - await asyncio.sleep(0) - async def accelerator(seconds: float): - steps = seconds * 10.0 + steps = seconds * 10.0 if not instant_step else 1 for _ in range(max(int(steps), 1)): - await accelerate_time(0.1) + await accelerate_time(seconds/steps) return accelerator