import asyncio
import logging
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor

from lbry.testcase import AsyncioTestCase
from lbry.event import EventController, EventQueuePublisher
from lbry.tasks import TaskGroup


class StreamControllerTestCase(AsyncioTestCase):

    async def test_non_unique_events(self):
        events = []
        controller = EventController()
        controller.stream.listen(events.append)
        await controller.add("yo")
        await controller.add("yo")
        self.assertListEqual(events, ["yo", "yo"])

    async def test_unique_events(self):
        events = []
        controller = EventController(merge_repeated_events=True)
        controller.stream.listen(events.append)
        await controller.add("yo")
        await controller.add("yo")
        self.assertListEqual(events, ["yo"])

    async def test_sync_listener_errors(self):
        def bad_listener(_):
            raise ValueError('bad')
        controller = EventController()
        controller.stream.listen(bad_listener)
        with self.assertRaises(ValueError):
            await controller.add("yo")

    async def test_async_listener_errors(self):
        async def bad_listener(_):
            raise ValueError('bad')
        controller = EventController()
        controller.stream.listen(bad_listener)
        with self.assertRaises(ValueError):
            await controller.add("yo")

    async def test_first_event(self):
        controller = EventController()
        first = controller.stream.first
        await controller.add("one")
        second = controller.stream.first
        await controller.add("two")
        self.assertEqual("one", await first)
        self.assertEqual("two", await second)

    async def test_last_event(self):
        controller = EventController()
        last = controller.stream.last
        await controller.add("one")
        await controller.add("two")
        await controller.close()
        self.assertEqual("two", await last)


class TestEventQueuePublisher(AsyncioTestCase):

    async def test_event_buffering_avoids_overloading_asyncio(self):
        threads = 3
        generate_events = 3000
        expected_event_count = (threads * generate_events)-1

        queue = mp.Queue()
        executor = ThreadPoolExecutor(max_workers=threads)
        controller = EventController()
        events = []

        async def event_logger(e):
            await asyncio.sleep(0)
            events.append(e)

        controller.stream.listen(event_logger)
        until_all_consumed = controller.stream.where(lambda _: len(events) == expected_event_count)

        def event_producer(q, j):
            for i in range(generate_events):
                q.put(f'foo-{i}-{j}')

        with EventQueuePublisher(queue, controller), self.assertLogs() as logs:
            # assertLogs() requires that at least one message is logged
            # this is that one message:
            logging.getLogger().info("placeholder")
            await asyncio.wait([
                self.loop.run_in_executor(executor, event_producer, queue, j)
                for j in range(threads)
            ])
            await until_all_consumed
            # assert that there were no WARNINGs from asyncio about slow tasks
            # (should have exactly 1 log which is the placeholder above)
            self.assertEqual(['INFO:root:placeholder'], logs.output)


class TaskGroupTestCase(AsyncioTestCase):

    async def test_cancel_sets_it_done(self):
        group = TaskGroup()
        group.cancel()
        self.assertTrue(group.done.is_set())