diff --git a/lbrynet/dht/protocol/async_generator_junction.py b/lbrynet/dht/protocol/async_generator_junction.py index 79db6a55d..4812ca522 100644 --- a/lbrynet/dht/protocol/async_generator_junction.py +++ b/lbrynet/dht/protocol/async_generator_junction.py @@ -26,12 +26,11 @@ class AsyncGeneratorJunction: def __init__(self, loop: asyncio.BaseEventLoop, queue: typing.Optional[asyncio.Queue] = None): self.loop = loop + self.__iterator_queue = asyncio.Queue(loop=loop) self.result_queue = queue or asyncio.Queue(loop=loop) self.tasks: typing.List[asyncio.Task] = [] self.running_iterators: typing.Dict[typing.AsyncGenerator, bool] = {} self.generator_queue: asyncio.Queue = asyncio.Queue(loop=self.loop) - self.can_iterate = asyncio.Event(loop=self.loop) - self.finished = asyncio.Event(loop=self.loop) @property def running(self): @@ -42,15 +41,16 @@ class AsyncGeneratorJunction: try: async for item in iterator: self.result_queue.put_nowait(item) + self.__iterator_queue.put_nowait(item) finally: self.running_iterators[iterator] = False + if not self.running: + self.__iterator_queue.put_nowait(StopAsyncIteration) while True: async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType'] = await self.generator_queue.get() self.running_iterators[async_gen] = True self.tasks.append(self.loop.create_task(iterate(async_gen))) - if not self.can_iterate.is_set(): - self.can_iterate.set() def add_generator(self, async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType']): """ @@ -62,14 +62,10 @@ class AsyncGeneratorJunction: return self async def __anext__(self): - if not self.can_iterate.is_set(): - await self.can_iterate.wait() - if not self.running: - raise StopAsyncIteration() - try: - return await self.result_queue.get() - finally: - self.awaiting = None + result = await self.__iterator_queue.get() + if result is StopAsyncIteration: + raise result + return result def aclose(self): async def _aclose(): @@ -80,8 +76,6 @@ class AsyncGeneratorJunction: self.running_iterators[iterator] = False drain_tasks(self.tasks) raise StopAsyncIteration() - if not self.finished.is_set(): - self.finished.set() return self.loop.create_task(_aclose()) async def __aenter__(self): diff --git a/lbrynet/dht/protocol/iterative_find.py b/lbrynet/dht/protocol/iterative_find.py index adbe202c2..87472d87c 100644 --- a/lbrynet/dht/protocol/iterative_find.py +++ b/lbrynet/dht/protocol/iterative_find.py @@ -169,7 +169,9 @@ class IterativeFinder: log.warning(str(err)) self.active.discard(peer) return - except (RemoteException, TransportNotConnected): + except TransportNotConnected: + return self.aclose() + except RemoteException: return return await self._handle_probe_result(peer, response) diff --git a/tests/unit/dht/protocol/test_async_gen_junction.py b/tests/unit/dht/protocol/test_async_gen_junction.py index 1d2b97718..5dd9c8f29 100644 --- a/tests/unit/dht/protocol/test_async_gen_junction.py +++ b/tests/unit/dht/protocol/test_async_gen_junction.py @@ -17,10 +17,10 @@ class MockAsyncGen: return self async def __anext__(self): + await asyncio.sleep(self.delay, loop=self.loop) if self.count > self.stop_cnt - 1: raise StopAsyncIteration() self.count += 1 - await asyncio.sleep(self.delay, loop=self.loop) return self.result async def aclose(self): @@ -48,6 +48,18 @@ class TestAsyncGeneratorJunction(AsyncioTestCase): self.assertEqual(fast_gen.called_close, True) self.assertEqual(slow_gen.called_close, True) + async def test_nothing_to_yield(self): + async def __nothing(): + for _ in []: + yield self.fail("nada") + await self._test_junction([], __nothing()) + + async def test_fast_iteratiors(self): + async def __gotta_go_fast(): + for _ in range(10): + yield 0 + await self._test_junction([0]*40, __gotta_go_fast(), __gotta_go_fast(), __gotta_go_fast(), __gotta_go_fast()) + @unittest.SkipTest async def test_one_stopped_first(self): expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]