fix generator junction

This commit is contained in:
Victor Shyba 2019-05-08 23:02:21 -03:00
parent 77c69f661d
commit efbf2f49a9

View file

@ -26,12 +26,11 @@ class AsyncGeneratorJunction:
def __init__(self, loop: asyncio.BaseEventLoop, queue: typing.Optional[asyncio.Queue] = None):
self.loop = loop
self.__iterator_queue = asyncio.Queue(loop=loop)
self.result_queue = queue or asyncio.Queue(loop=loop)
self.tasks: typing.List[asyncio.Task] = []
self.running_iterators: typing.Dict[typing.AsyncGenerator, bool] = {}
self.generator_queue: asyncio.Queue = asyncio.Queue(loop=self.loop)
self.can_iterate = asyncio.Event(loop=self.loop)
self.finished = asyncio.Event(loop=self.loop)
@property
def running(self):
@ -42,15 +41,16 @@ class AsyncGeneratorJunction:
try:
async for item in iterator:
self.result_queue.put_nowait(item)
self.__iterator_queue.put_nowait(item)
finally:
self.running_iterators[iterator] = False
if not self.running:
self.__iterator_queue.put_nowait(StopAsyncIteration)
while True:
async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType'] = await self.generator_queue.get()
self.running_iterators[async_gen] = True
self.tasks.append(self.loop.create_task(iterate(async_gen)))
if not self.can_iterate.is_set():
self.can_iterate.set()
def add_generator(self, async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType']):
"""
@ -62,14 +62,10 @@ class AsyncGeneratorJunction:
return self
async def __anext__(self):
if not self.can_iterate.is_set():
await self.can_iterate.wait()
if not self.running:
raise StopAsyncIteration()
try:
return await self.result_queue.get()
finally:
self.awaiting = None
result = await self.__iterator_queue.get()
if result is StopAsyncIteration:
raise result
return result
def aclose(self):
async def _aclose():
@ -80,8 +76,6 @@ class AsyncGeneratorJunction:
self.running_iterators[iterator] = False
drain_tasks(self.tasks)
raise StopAsyncIteration()
if not self.finished.is_set():
self.finished.set()
return self.loop.create_task(_aclose())
async def __aenter__(self):