lbry-sdk/lbrynet/dht/protocol/async_generator_junction.py

101 lines
3.4 KiB
Python
Raw Normal View History

2019-01-22 12:49:43 -05:00
import asyncio
import typing
import logging
import traceback
if typing.TYPE_CHECKING:
from types import AsyncGeneratorType
log = logging.getLogger(__name__)
def cancel_task(task: typing.Optional[asyncio.Task]):
if task and not (task.done() or task.cancelled()):
task.cancel()
def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
while tasks:
cancel_task(tasks.pop())
class AsyncGeneratorJunction:
"""
A helper to interleave the results from multiple async generators into one
async generator.
"""
def __init__(self, loop: asyncio.BaseEventLoop, queue: typing.Optional[asyncio.Queue] = None):
self.loop = loop
self.result_queue = queue or asyncio.Queue(loop=loop)
self.tasks: typing.List[asyncio.Task] = []
self.running_iterators: typing.Dict[typing.AsyncGenerator, bool] = {}
self.generator_queue: asyncio.Queue = asyncio.Queue(loop=self.loop)
self.can_iterate = asyncio.Event(loop=self.loop)
self.finished = asyncio.Event(loop=self.loop)
@property
def running(self):
return any(self.running_iterators.values())
async def wait_for_generators(self):
async def iterate(iterator: typing.AsyncGenerator):
try:
async for item in iterator:
self.result_queue.put_nowait(item)
finally:
self.running_iterators[iterator] = False
while True:
async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType'] = await self.generator_queue.get()
self.running_iterators[async_gen] = True
self.tasks.append(self.loop.create_task(iterate(async_gen)))
if not self.can_iterate.is_set():
self.can_iterate.set()
def add_generator(self, async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType']):
"""
Add an async generator. This can be called during an iteration of the generator junction.
"""
self.generator_queue.put_nowait(async_gen)
def __aiter__(self):
return self
async def __anext__(self):
if not self.can_iterate.is_set():
await self.can_iterate.wait()
if not self.running:
raise StopAsyncIteration()
try:
return await self.result_queue.get()
finally:
self.awaiting = None
def aclose(self):
async def _aclose():
for iterator in list(self.running_iterators.keys()):
result = iterator.aclose()
if asyncio.iscoroutine(result):
await result
self.running_iterators[iterator] = False
drain_tasks(self.tasks)
raise StopAsyncIteration()
if not self.finished.is_set():
self.finished.set()
return self.loop.create_task(_aclose())
async def __aenter__(self):
self.tasks.append(self.loop.create_task(self.wait_for_generators()))
return self
async def __aexit__(self, exc_type, exc, tb):
try:
await self.aclose()
except StopAsyncIteration:
pass
finally:
if exc_type:
2019-01-28 15:06:33 -05:00
if exc_type not in (asyncio.CancelledError, asyncio.TimeoutError, StopAsyncIteration, GeneratorExit):
2019-01-22 12:49:43 -05:00
err = traceback.format_exception(exc_type, exc, tb)
log.error(err)