"""Tests for lbrynet.dht.protocol.async_generator_junction.AsyncGeneratorJunction."""
import unittest
import asyncio

from torba.testcase import AsyncioTestCase
from lbrynet.dht.protocol.async_generator_junction import AsyncGeneratorJunction


class MockAsyncGen:
    """Async iterator that yields ``result`` up to ``stop_cnt`` times.

    Each item is returned after sleeping ``delay`` seconds, so instances with
    different delays interleave in a predictable order. ``called_close``
    records whether :meth:`aclose` was awaited, which lets the tests check
    that the junction closes the generators handed to it.
    """

    def __init__(self, loop, result, delay, stop_cnt=10):
        # ``loop`` is still accepted and stored so existing callers keep
        # working, but it is no longer forwarded to asyncio.sleep(): that
        # argument was deprecated in Python 3.8 and removed in 3.10.
        self.loop = loop
        self.result = result
        self.delay = delay
        self.count = 0
        self.stop_cnt = stop_cnt
        self.called_close = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Sleep ``delay`` seconds, then return ``result``.

        Raises:
            StopAsyncIteration: once ``stop_cnt`` items have been produced.
        """
        if self.count >= self.stop_cnt:
            raise StopAsyncIteration()
        self.count += 1
        await asyncio.sleep(self.delay)
        return self.result

    async def aclose(self):
        """Record that the consumer closed this generator."""
        self.called_close = True


class TestAsyncGeneratorJunction(AsyncioTestCase):
    """Ordering and cleanup behaviour of AsyncGeneratorJunction."""

    def setUp(self):
        super().setUp()
        self.loop = asyncio.get_event_loop()

    async def _test_junction(self, expected, *generators):
        """Run ``generators`` through a junction and check the yielded order.

        Every generator is added before iteration starts. The items the
        junction yields are collected in arrival order and must equal
        ``expected`` exactly.
        """
        order = []
        async with AsyncGeneratorJunction(self.loop) as junction:
            for generator in generators:
                junction.add_generator(generator)
            async for item in junction:
                order.append(item)
        self.assertListEqual(order, expected)

    async def test_yield_order(self):
        # fast_gen yields 1 every 0.2s and slow_gen yields 2 every 0.4s, 10
        # items each, so the 1s run out first and the tail is all 2s.
        expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 2, 2, 2]
        fast_gen = MockAsyncGen(self.loop, 1, 0.2)
        slow_gen = MockAsyncGen(self.loop, 2, 0.4)
        await self._test_junction(expected_order, fast_gen, slow_gen)
        self.assertTrue(fast_gen.called_close)
        self.assertTrue(slow_gen.called_close)

    # Use unittest.skip (a decorator), not unittest.SkipTest (an exception
    # class). Decorating with the exception replaces the method with a
    # non-callable exception instance, so the loader silently drops the test
    # instead of reporting it as skipped.
    # NOTE(review): the reason this test is disabled is not recorded here;
    # confirm and update the skip message.
    @unittest.skip("temporarily disabled")
    async def test_one_stopped_first(self):
        # fast_gen stops after 5 items; slow_gen supplies the remaining 2s.
        expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]
        fast_gen = MockAsyncGen(self.loop, 1, 0.2, 5)
        slow_gen = MockAsyncGen(self.loop, 2, 0.4)
        await self._test_junction(expected_order, fast_gen, slow_gen)
        self.assertTrue(fast_gen.called_close)
        self.assertTrue(slow_gen.called_close)

    async def test_with_non_async_gen_class(self):
        # Same schedule as test_one_stopped_first, but the fast side is a
        # native async generator function that returns after 5 items.
        expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]

        async def fast_gen():
            for i in range(10):
                if i == 5:
                    return
                await asyncio.sleep(0.2)
                yield 1

        slow_gen = MockAsyncGen(self.loop, 2, 0.4)
        await self._test_junction(expected_order, fast_gen(), slow_gen)
        self.assertTrue(slow_gen.called_close)

    async def test_stop_when_encapsulating_task_cancelled(self):
        # Cancel the task consuming the junction after 1s. The cancellation
        # must propagate, and both generators must still be closed.
        fast_gen = MockAsyncGen(self.loop, 1, 0.2)
        slow_gen = MockAsyncGen(self.loop, 2, 0.4)

        async def _task():
            async with AsyncGeneratorJunction(self.loop) as junction:
                junction.add_generator(fast_gen)
                junction.add_generator(slow_gen)
                async for _ in junction:
                    pass

        task = self.loop.create_task(_task())
        self.loop.call_later(1.0, task.cancel)
        with self.assertRaises(asyncio.CancelledError):
            await task
        self.assertTrue(fast_gen.called_close)
        self.assertTrue(slow_gen.called_close)