forked from LBRYCommunity/lbry-sdk
fix generator junction
This commit is contained in:
parent
77c69f661d
commit
efbf2f49a9
1 changed files with 8 additions and 14 deletions
|
@ -26,12 +26,11 @@ class AsyncGeneratorJunction:
|
||||||
|
|
||||||
def __init__(self, loop: asyncio.BaseEventLoop, queue: typing.Optional[asyncio.Queue] = None):
|
def __init__(self, loop: asyncio.BaseEventLoop, queue: typing.Optional[asyncio.Queue] = None):
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
|
self.__iterator_queue = asyncio.Queue(loop=loop)
|
||||||
self.result_queue = queue or asyncio.Queue(loop=loop)
|
self.result_queue = queue or asyncio.Queue(loop=loop)
|
||||||
self.tasks: typing.List[asyncio.Task] = []
|
self.tasks: typing.List[asyncio.Task] = []
|
||||||
self.running_iterators: typing.Dict[typing.AsyncGenerator, bool] = {}
|
self.running_iterators: typing.Dict[typing.AsyncGenerator, bool] = {}
|
||||||
self.generator_queue: asyncio.Queue = asyncio.Queue(loop=self.loop)
|
self.generator_queue: asyncio.Queue = asyncio.Queue(loop=self.loop)
|
||||||
self.can_iterate = asyncio.Event(loop=self.loop)
|
|
||||||
self.finished = asyncio.Event(loop=self.loop)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def running(self):
|
def running(self):
|
||||||
|
@ -42,15 +41,16 @@ class AsyncGeneratorJunction:
|
||||||
try:
|
try:
|
||||||
async for item in iterator:
|
async for item in iterator:
|
||||||
self.result_queue.put_nowait(item)
|
self.result_queue.put_nowait(item)
|
||||||
|
self.__iterator_queue.put_nowait(item)
|
||||||
finally:
|
finally:
|
||||||
self.running_iterators[iterator] = False
|
self.running_iterators[iterator] = False
|
||||||
|
if not self.running:
|
||||||
|
self.__iterator_queue.put_nowait(StopAsyncIteration)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType'] = await self.generator_queue.get()
|
async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType'] = await self.generator_queue.get()
|
||||||
self.running_iterators[async_gen] = True
|
self.running_iterators[async_gen] = True
|
||||||
self.tasks.append(self.loop.create_task(iterate(async_gen)))
|
self.tasks.append(self.loop.create_task(iterate(async_gen)))
|
||||||
if not self.can_iterate.is_set():
|
|
||||||
self.can_iterate.set()
|
|
||||||
|
|
||||||
def add_generator(self, async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType']):
|
def add_generator(self, async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType']):
|
||||||
"""
|
"""
|
||||||
|
@ -62,14 +62,10 @@ class AsyncGeneratorJunction:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __anext__(self):
|
async def __anext__(self):
|
||||||
if not self.can_iterate.is_set():
|
result = await self.__iterator_queue.get()
|
||||||
await self.can_iterate.wait()
|
if result is StopAsyncIteration:
|
||||||
if not self.running:
|
raise result
|
||||||
raise StopAsyncIteration()
|
return result
|
||||||
try:
|
|
||||||
return await self.result_queue.get()
|
|
||||||
finally:
|
|
||||||
self.awaiting = None
|
|
||||||
|
|
||||||
def aclose(self):
|
def aclose(self):
|
||||||
async def _aclose():
|
async def _aclose():
|
||||||
|
@ -80,8 +76,6 @@ class AsyncGeneratorJunction:
|
||||||
self.running_iterators[iterator] = False
|
self.running_iterators[iterator] = False
|
||||||
drain_tasks(self.tasks)
|
drain_tasks(self.tasks)
|
||||||
raise StopAsyncIteration()
|
raise StopAsyncIteration()
|
||||||
if not self.finished.is_set():
|
|
||||||
self.finished.set()
|
|
||||||
return self.loop.create_task(_aclose())
|
return self.loop.create_task(_aclose())
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
|
|
Loading…
Add table
Reference in a new issue