Merge pull request #2100 from lbryio/junction_fixes

DHT junction fixes for cases that can lock up
This commit is contained in:
Jack Robison 2019-05-08 23:52:54 -04:00 committed by GitHub
commit c5ae59bff7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 16 deletions

View file

@ -26,12 +26,11 @@ class AsyncGeneratorJunction:
def __init__(self, loop: asyncio.BaseEventLoop, queue: typing.Optional[asyncio.Queue] = None): def __init__(self, loop: asyncio.BaseEventLoop, queue: typing.Optional[asyncio.Queue] = None):
self.loop = loop self.loop = loop
self.__iterator_queue = asyncio.Queue(loop=loop)
self.result_queue = queue or asyncio.Queue(loop=loop) self.result_queue = queue or asyncio.Queue(loop=loop)
self.tasks: typing.List[asyncio.Task] = [] self.tasks: typing.List[asyncio.Task] = []
self.running_iterators: typing.Dict[typing.AsyncGenerator, bool] = {} self.running_iterators: typing.Dict[typing.AsyncGenerator, bool] = {}
self.generator_queue: asyncio.Queue = asyncio.Queue(loop=self.loop) self.generator_queue: asyncio.Queue = asyncio.Queue(loop=self.loop)
self.can_iterate = asyncio.Event(loop=self.loop)
self.finished = asyncio.Event(loop=self.loop)
@property @property
def running(self): def running(self):
@ -42,15 +41,16 @@ class AsyncGeneratorJunction:
try: try:
async for item in iterator: async for item in iterator:
self.result_queue.put_nowait(item) self.result_queue.put_nowait(item)
self.__iterator_queue.put_nowait(item)
finally: finally:
self.running_iterators[iterator] = False self.running_iterators[iterator] = False
if not self.running:
self.__iterator_queue.put_nowait(StopAsyncIteration)
while True: while True:
async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType'] = await self.generator_queue.get() async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType'] = await self.generator_queue.get()
self.running_iterators[async_gen] = True self.running_iterators[async_gen] = True
self.tasks.append(self.loop.create_task(iterate(async_gen))) self.tasks.append(self.loop.create_task(iterate(async_gen)))
if not self.can_iterate.is_set():
self.can_iterate.set()
def add_generator(self, async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType']): def add_generator(self, async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType']):
""" """
@ -62,14 +62,10 @@ class AsyncGeneratorJunction:
return self return self
async def __anext__(self): async def __anext__(self):
if not self.can_iterate.is_set(): result = await self.__iterator_queue.get()
await self.can_iterate.wait() if result is StopAsyncIteration:
if not self.running: raise result
raise StopAsyncIteration() return result
try:
return await self.result_queue.get()
finally:
self.awaiting = None
def aclose(self): def aclose(self):
async def _aclose(): async def _aclose():
@ -80,8 +76,6 @@ class AsyncGeneratorJunction:
self.running_iterators[iterator] = False self.running_iterators[iterator] = False
drain_tasks(self.tasks) drain_tasks(self.tasks)
raise StopAsyncIteration() raise StopAsyncIteration()
if not self.finished.is_set():
self.finished.set()
return self.loop.create_task(_aclose()) return self.loop.create_task(_aclose())
async def __aenter__(self): async def __aenter__(self):

View file

@ -169,7 +169,9 @@ class IterativeFinder:
log.warning(str(err)) log.warning(str(err))
self.active.discard(peer) self.active.discard(peer)
return return
except (RemoteException, TransportNotConnected): except TransportNotConnected:
return self.aclose()
except RemoteException:
return return
return await self._handle_probe_result(peer, response) return await self._handle_probe_result(peer, response)

View file

@ -17,10 +17,10 @@ class MockAsyncGen:
return self return self
async def __anext__(self): async def __anext__(self):
await asyncio.sleep(self.delay, loop=self.loop)
if self.count > self.stop_cnt - 1: if self.count > self.stop_cnt - 1:
raise StopAsyncIteration() raise StopAsyncIteration()
self.count += 1 self.count += 1
await asyncio.sleep(self.delay, loop=self.loop)
return self.result return self.result
async def aclose(self): async def aclose(self):
@ -48,6 +48,18 @@ class TestAsyncGeneratorJunction(AsyncioTestCase):
self.assertEqual(fast_gen.called_close, True) self.assertEqual(fast_gen.called_close, True)
self.assertEqual(slow_gen.called_close, True) self.assertEqual(slow_gen.called_close, True)
async def test_nothing_to_yield(self):
async def __nothing():
for _ in []:
yield self.fail("nada")
await self._test_junction([], __nothing())
async def test_fast_iteratiors(self):
async def __gotta_go_fast():
for _ in range(10):
yield 0
await self._test_junction([0]*40, __gotta_go_fast(), __gotta_go_fast(), __gotta_go_fast(), __gotta_go_fast())
@unittest.SkipTest @unittest.SkipTest
async def test_one_stopped_first(self): async def test_one_stopped_first(self):
expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2] expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]