diff --git a/lbrynet/dht/protocol/async_generator_junction.py b/lbrynet/dht/protocol/async_generator_junction.py index 79db6a55d..4812ca522 100644 --- a/lbrynet/dht/protocol/async_generator_junction.py +++ b/lbrynet/dht/protocol/async_generator_junction.py @@ -26,12 +26,11 @@ class AsyncGeneratorJunction: def __init__(self, loop: asyncio.BaseEventLoop, queue: typing.Optional[asyncio.Queue] = None): self.loop = loop + self.__iterator_queue = asyncio.Queue(loop=loop) self.result_queue = queue or asyncio.Queue(loop=loop) self.tasks: typing.List[asyncio.Task] = [] self.running_iterators: typing.Dict[typing.AsyncGenerator, bool] = {} self.generator_queue: asyncio.Queue = asyncio.Queue(loop=self.loop) - self.can_iterate = asyncio.Event(loop=self.loop) - self.finished = asyncio.Event(loop=self.loop) @property def running(self): @@ -42,15 +41,16 @@ class AsyncGeneratorJunction: try: async for item in iterator: self.result_queue.put_nowait(item) + self.__iterator_queue.put_nowait(item) finally: self.running_iterators[iterator] = False + if not self.running: + self.__iterator_queue.put_nowait(StopAsyncIteration) while True: async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType'] = await self.generator_queue.get() self.running_iterators[async_gen] = True self.tasks.append(self.loop.create_task(iterate(async_gen))) - if not self.can_iterate.is_set(): - self.can_iterate.set() def add_generator(self, async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType']): """ @@ -62,14 +62,10 @@ class AsyncGeneratorJunction: return self async def __anext__(self): - if not self.can_iterate.is_set(): - await self.can_iterate.wait() - if not self.running: - raise StopAsyncIteration() - try: - return await self.result_queue.get() - finally: - self.awaiting = None + result = await self.__iterator_queue.get() + if result is StopAsyncIteration: + raise result + return result def aclose(self): async def _aclose(): @@ -80,8 +76,6 @@ class AsyncGeneratorJunction: self.running_iterators[iterator] = False drain_tasks(self.tasks) raise StopAsyncIteration() - if not self.finished.is_set(): - self.finished.set() return self.loop.create_task(_aclose()) async def __aenter__(self):