import os
import logging
import asyncio
import sqlite3
import platform
from binascii import hexlify
from dataclasses import dataclass
from contextvars import ContextVar
from concurrent.futures.thread import ThreadPoolExecutor
from concurrent.futures.process import ProcessPoolExecutor
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional


log = logging.getLogger(__name__)
# Print tracebacks for exceptions raised inside user-defined SQLite callbacks
# (functions, aggregates, converters, authorizers, ...) instead of having sqlite3
# silently discard them.
sqlite3.enable_callback_tracebacks(True)


@dataclass
class ReaderProcessState:
    """Per-worker state of the reader executor.

    ``initializer`` creates one per worker and stores it in ``reader_context``;
    ``run_read_only_fetchall``/``run_read_only_fetchone`` read it back.
    """
    cursor: sqlite3.Cursor  # cursor on the worker's own read connection


# Holds the current reader worker's ReaderProcessState: set by `initializer`,
# read by the run_read_only_* helpers. The variable itself is never None;
# `.get()` raises LookupError in a context where `initializer` has not run.
reader_context: ContextVar[ReaderProcessState] = ContextVar('reader_context')


def initializer(path):
    """Set up one reader-executor worker.

    Opens a dedicated connection to ``path`` that yields rows as dicts (via
    ``dict_row_factory``) and switches the database to WAL journal mode, under
    which SQLite lets readers proceed while a writer is active. The connection's
    cursor is then published through ``reader_context`` for the
    ``run_read_only_*`` helpers.
    """
    connection = sqlite3.connect(path)
    connection.row_factory = dict_row_factory
    connection.executescript("pragma journal_mode=WAL;")
    state = ReaderProcessState(cursor=connection.cursor())
    reader_context.set(state)


def run_read_only_fetchall(sql, params):
    """Execute ``sql`` with ``params`` on this worker's reader cursor and return all rows.

    Intended to run inside a reader-executor worker, where ``initializer`` has
    stored the cursor in ``reader_context``. Failures are logged and re-raised.
    """
    reader_cursor = reader_context.get().cursor
    try:
        all_rows = reader_cursor.execute(sql, params).fetchall()
    except Exception as error:
        log.exception('Error running transaction:', exc_info=error)
        raise
    return all_rows


def run_read_only_fetchone(sql, params):
    """Execute ``sql`` with ``params`` on this worker's reader cursor and return the first row.

    Returns None when the query yields no rows. Intended to run inside a
    reader-executor worker, where ``initializer`` has stored the cursor in
    ``reader_context``. Failures are logged and re-raised.
    """
    reader_cursor = reader_context.get().cursor
    try:
        first_row = reader_cursor.execute(sql, params).fetchone()
    except Exception as error:
        log.exception('Error running transaction:', exc_info=error)
        raise
    return first_row


# Readers run in threads on Windows and Android (detected via the
# ANDROID_ARGUMENT environment variable) and in separate processes elsewhere.
# NOTE(review): presumably because process pools are unavailable or unreliable
# on those platforms — confirm before changing.
ReaderExecutorClass = (
    ThreadPoolExecutor
    if platform.system() == 'Windows' or 'ANDROID_ARGUMENT' in os.environ
    else ProcessPoolExecutor
)


class AIOSQLite:
    """Asyncio wrapper around SQLite with one writer connection and a pool of readers.

    All writes go through :meth:`run`, which serializes them with ``write_lock`` and
    executes each one as an explicit transaction on ``writer_connection`` inside the
    single-threaded ``writer_executor``. Queries issued with ``read_only=True`` are
    sent to ``reader_executor`` (threads or processes, see ``ReaderExecutorClass``),
    where each worker has its own connection; they are submitted only once no writer
    is queued or running (tracked by ``writers`` / ``read_ready``).

    Rows from the writer path use ``writer_connection``'s default row factory
    (plain tuples), while rows from the reader path are whatever the reader
    workers' connections produce.
    """

    # Either a thread or a process pool depending on ``ReaderExecutorClass``;
    # assigned in ``connect``.
    reader_executor: Union[ThreadPoolExecutor, ProcessPoolExecutor]

    def __init__(self):
        # has to be single threaded as there is no mapping of thread:connection
        self.writer_executor = ThreadPoolExecutor(max_workers=1)
        self.writer_connection: Optional[sqlite3.Connection] = None
        self._closing = False
        self.query_count = 0
        self.write_lock = asyncio.Lock()
        # number of write calls currently queued on or holding write_lock
        self.writers = 0
        # set while no writer is pending; read_only queries wait on it
        self.read_ready = asyncio.Event()

    @classmethod
    async def connect(cls, path: Union[bytes, str], *args, **kwargs):
        """Create an instance for ``path``: start the reader pool and open the writer.

        ``args``/``kwargs`` are forwarded to :func:`sqlite3.connect` for the writer
        connection only; reader workers are set up by ``initializer(path)``.
        """
        sqlite3.enable_callback_tracebacks(True)
        db = cls()

        def _connect_writer():
            db.writer_connection = sqlite3.connect(path, *args, **kwargs)

        # os.cpu_count() returns None when the CPU count cannot be determined
        readers = max((os.cpu_count() or 1) - 2, 2)
        db.reader_executor = ReaderExecutorClass(
            max_workers=readers, initializer=initializer, initargs=(path, )
        )
        await asyncio.get_running_loop().run_in_executor(db.writer_executor, _connect_writer)
        db.read_ready.set()
        return db

    async def close(self):
        """Close the writer connection and shut down both executors (idempotent)."""
        if self._closing:
            return
        self._closing = True
        await asyncio.get_running_loop().run_in_executor(self.writer_executor, self.writer_connection.close)
        # NOTE(review): shutdown(wait=True) blocks the event loop until workers exit.
        self.writer_executor.shutdown(wait=True)
        self.reader_executor.shutdown(wait=True)
        self.read_ready.clear()
        self.writer_connection = None

    def executemany(self, sql: str, params: Iterable):
        """Run ``sql`` once per parameter set in ``params`` within one write transaction."""
        params = params if params is not None else []
        # this fetchall is needed to prevent SQLITE_MISUSE
        return self.run(lambda conn: conn.executemany(sql, params).fetchall())

    def executescript(self, script: str) -> Awaitable:
        """Run a multi-statement ``script`` on the writer connection via :meth:`run`."""
        return self.run(lambda conn: conn.executescript(script))

    async def _execute_fetch(self, sql: str, parameters: Optional[Iterable] = None,
                             read_only=False, fetch_all: bool = False) -> Any:
        """Execute ``sql`` and return all rows (``fetch_all``) or the first row / None.

        With ``read_only`` the query goes to a reader worker once no writer is
        pending; otherwise it runs as a write transaction through :meth:`run`.
        """
        parameters = parameters if parameters is not None else []
        if read_only:
            read_only_fn = run_read_only_fetchall if fetch_all else run_read_only_fetchone
            while self.writers:
                await self.read_ready.wait()
            return await asyncio.get_running_loop().run_in_executor(
                self.reader_executor, read_only_fn, sql, parameters
            )
        if fetch_all:
            return await self.run(lambda conn: conn.execute(sql, parameters).fetchall())
        return await self.run(lambda conn: conn.execute(sql, parameters).fetchone())

    async def execute_fetchall(self, sql: str, parameters: Optional[Iterable] = None,
                               read_only=False) -> List[Union[dict, tuple]]:
        """Execute ``sql`` and return every resulting row."""
        return await self._execute_fetch(sql, parameters, read_only, fetch_all=True)

    async def execute_fetchone(self, sql: str, parameters: Optional[Iterable] = None,
                               read_only=False) -> Optional[Union[dict, tuple]]:
        """Execute ``sql`` and return the first resulting row, or None if there is none."""
        return await self._execute_fetch(sql, parameters, read_only, fetch_all=False)

    def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
        """Execute one statement as a write transaction; resolves to the writer cursor."""
        parameters = parameters if parameters is not None else []
        return self.run(lambda conn: conn.execute(sql, parameters))

    async def run(self, fun, *args, **kwargs):
        """Run ``fun(writer_connection, *args, **kwargs)`` as one write transaction.

        Writers are serialized by ``write_lock``; while any writer is queued or
        running, ``read_ready`` is cleared so read_only queries wait. Returns
        ``fun``'s result; exceptions from ``fun`` are re-raised after a rollback.
        """
        self.writers += 1
        self.read_ready.clear()
        try:
            # The counter is released in `finally` even when this task is
            # cancelled while still waiting for the lock; otherwise `writers`
            # would stay > 0 and read_only queries would wait forever.
            async with self.write_lock:
                return await asyncio.get_running_loop().run_in_executor(
                    self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs)
                )
        finally:
            self.writers -= 1
            if not self.writers:
                self.read_ready.set()

    def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs):
        """Executed on the writer thread: BEGIN, call ``fun``, COMMIT (ROLLBACK on error)."""
        self.writer_connection.execute('begin')
        try:
            self.query_count += 1
            result = fun(self.writer_connection, *args, **kwargs)  # type: ignore
            self.writer_connection.commit()
            return result
        except Exception as e:
            log.exception('Error running transaction:', exc_info=e)
            self.writer_connection.rollback()
            log.warning("rolled back")
            raise

    def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable:
        """Like :meth:`run`, but with ``PRAGMA foreign_keys`` switched off meanwhile.

        Unlike :meth:`run` this does not take ``write_lock`` or bump ``writers``;
        ordering relative to other writes comes only from the single-threaded
        ``writer_executor``, and read_only queries do not wait for it.
        """
        return asyncio.get_event_loop().run_in_executor(
            self.writer_executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs
        )

    def __run_transaction_with_foreign_keys_disabled(self,
                                                     fun: Callable[[sqlite3.Connection, Any, Any], Any],
                                                     args, kwargs):
        """Executed on the writer thread: disable FKs, run the transaction, re-enable FKs.

        Raises sqlite3.IntegrityError if foreign keys are already disabled, since
        re-enabling them afterwards would then change the connection's state.
        """
        foreign_keys_enabled, = self.writer_connection.execute("pragma foreign_keys").fetchone()
        if not foreign_keys_enabled:
            raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead")
        try:
            # toggled outside the transaction; SQLite ignores this pragma inside one
            self.writer_connection.execute('pragma foreign_keys=off').fetchone()
            return self.__run_transaction(fun, *args, **kwargs)
        finally:
            self.writer_connection.execute('pragma foreign_keys=on').fetchone()


def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
    """Translate a constraints mapping into a SQL boolean expression plus named parameters.

    Keys are column names, optionally followed by an operator suffix and a ``#tag``
    that disambiguates repeated columns (the tag defaults to ``'0'`` and is appended
    to the generated parameter name):

    * ``col`` -> ``col = :param``; suffixes ``__not``, ``__lt``, ``__lte``, ``__gt``,
      ``__gte``, ``__like``, ``__not_like`` pick ``!=``, ``<``, ``<=``, ``>``, ``>=``,
      ``LIKE``, ``NOT LIKE``.
    * ``col__is_null`` / ``col__is_not_null`` -> ``col IS [NOT] NULL`` (value ignored).
    * ``col__in`` / ``col__not_in`` -> ``IN`` / ``NOT IN`` over a list/set/tuple
      (``=`` / ``!=`` for a single element), or a raw SQL string used as the
      sub-select; a falsy value adds no condition at all.
    * ``<name>__any`` / ``<name>__or`` / ``<name>__and`` -> the nested mapping is
      translated recursively and joined with OR / AND inside parentheses.
    * ``$name`` -> no SQL; only stores ``values['$name']``. ``$$col`` behaves like
      ``col`` but its parameter name keeps one ``$`` prefix.
    * an empty key (e.g. ``'#1'``) -> the value is appended verbatim as raw SQL.

    Dots in column names become underscores in parameter names.

    :param constraints: mapping of constraint keys to values.
    :param joiner: string placed between the individual conditions.
    :param prepend_key: prefix for every generated parameter name; nested groups
        pass their full path here so parameters from different groups never collide.
    :return: ``(sql, values)`` where ``sql`` is ``''`` if no condition was produced.
    :raises ValueError: if an ``__in``/``__not_in`` value is not a list, set, tuple or str.
    """
    sql, values = [], {}
    for key, constraint in constraints.items():
        tag = '0'
        if '#' in key:
            key, tag = key[:key.index('#')], key[key.index('#')+1:]
        col, op, key = key, '=', key.replace('.', '_')
        if not key:
            # raw SQL fragment
            sql.append(constraint)
            continue
        if key.startswith('$$'):
            col, key = col[2:], key[1:]
        elif key.startswith('$'):
            # parameter value only; referenced by raw SQL elsewhere
            values[key] = constraint
            continue
        if key.endswith('__not'):
            col, op = col[:-len('__not')], '!='
        elif key.endswith('__is_null'):
            col = col[:-len('__is_null')]
            sql.append(f'{col} IS NULL')
            continue
        if key.endswith('__is_not_null'):
            col = col[:-len('__is_not_null')]
            sql.append(f'{col} IS NOT NULL')
            continue
        if key.endswith('__lt'):
            col, op = col[:-len('__lt')], '<'
        elif key.endswith('__lte'):
            col, op = col[:-len('__lte')], '<='
        elif key.endswith('__gt'):
            col, op = col[:-len('__gt')], '>'
        elif key.endswith('__gte'):
            col, op = col[:-len('__gte')], '>='
        elif key.endswith('__like'):
            col, op = col[:-len('__like')], 'LIKE'
        elif key.endswith('__not_like'):
            col, op = col[:-len('__not_like')], 'NOT LIKE'
        elif key.endswith('__in') or key.endswith('__not_in'):
            if key.endswith('__in'):
                col, op, one_val_op = col[:-len('__in')], 'IN', '='
            else:
                col, op, one_val_op = col[:-len('__not_in')], 'NOT IN', '!='
            if constraint:
                if isinstance(constraint, (list, set, tuple)):
                    # parameter names carry prepend_key so that the same column
                    # used in two nested groups does not overwrite values
                    if len(constraint) == 1:
                        param = f'{prepend_key}{key}{tag}'
                        values[param] = next(iter(constraint))
                        sql.append(f'{col} {one_val_op} :{param}')
                    else:
                        keys = []
                        for i, val in enumerate(constraint):
                            param = f'{prepend_key}{key}{tag}_{i}'
                            keys.append(f':{param}')
                            values[param] = val
                        sql.append(f'{col} {op} ({", ".join(keys)})')
                elif isinstance(constraint, str):
                    sql.append(f'{col} {op} ({constraint})')
                else:
                    raise ValueError(f"{col} requires a list, set or string as constraint value.")
            continue
        elif key.endswith('__any') or key.endswith('__or'):
            where, subvalues = constraints_to_sql(constraint, ' OR ', prepend_key+key+tag+'_')
            sql.append(f'({where})')
            values.update(subvalues)
            continue
        if key.endswith('__and'):
            where, subvalues = constraints_to_sql(constraint, ' AND ', prepend_key+key+tag+'_')
            sql.append(f'({where})')
            values.update(subvalues)
            continue
        sql.append(f'{col} {op} :{prepend_key}{key}{tag}')
        values[prepend_key+key+tag] = constraint
    return joiner.join(sql) if sql else '', values


def query(select, **constraints) -> Tuple[str, Dict[str, Any]]:
    """Assemble a complete SELECT statement from ``select`` and keyword constraints.

    The reserved keywords ``limit``, ``offset``, ``order_by`` (str or list) and
    ``group_by`` become their own clauses. ``accounts`` becomes an ``account__in``
    constraint over each account's ``public_key.address``. All remaining keywords
    are translated by ``constraints_to_sql``.

    :return: ``(sql, values)`` with named parameters for ``values``.
    :raises ValueError: if ``order_by`` is neither a str nor a list.
    """
    limit = constraints.pop('limit', None)
    offset = constraints.pop('offset', None)
    order_by = constraints.pop('order_by', None)
    group_by = constraints.pop('group_by', None)

    accounts = constraints.pop('accounts', [])
    if accounts:
        constraints['account__in'] = [account.public_key.address for account in accounts]

    where, values = constraints_to_sql(constraints)
    parts = [select]
    if where:
        parts += ['WHERE', where]
    if group_by is not None:
        parts.append(f'GROUP BY {group_by}')
    if order_by:
        if isinstance(order_by, str):
            ordering = order_by
        elif isinstance(order_by, list):
            ordering = ', '.join(order_by)
        else:
            raise ValueError("order_by must be string or list")
        parts += ['ORDER BY', ordering]
    if limit is not None:
        parts.append(f'LIMIT {limit}')
    if offset is not None:
        parts.append(f'OFFSET {offset}')
    return ' '.join(parts), values


def interpolate(sql, values):
    """Inline the named parameters in ``values`` into ``sql`` as SQLite literals.

    Keys are substituted in reverse-sorted order, so a key that is a prefix of
    another (``x1`` vs ``x10``) is never replaced inside the longer placeholder.
    Rendering: bytes-like -> ``X'..'`` blob literal; str -> single-quoted with
    embedded quotes doubled; None -> ``NULL``; anything else -> ``str(value)``.

    :param sql: SQL text containing ``:name`` placeholders.
    :param values: mapping of placeholder names (without ``:``) to values.
    :return: the SQL text with placeholders replaced.
    """
    for key in sorted(values, reverse=True):
        value = values[key]
        if isinstance(value, (bytes, bytearray, memoryview)):
            literal = f"X'{hexlify(value).decode()}'"
        elif isinstance(value, str):
            # inside a SQL string literal a single quote is escaped by doubling it
            escaped = value.replace("'", "''")
            literal = f"'{escaped}'"
        elif value is None:
            literal = 'NULL'
        else:
            literal = str(value)
        sql = sql.replace(f":{key}", literal)
    return sql


def constrain_single_or_list(constraints, column, value, convert=lambda x: x):
    """Add an equality or ``__in`` constraint on ``column`` to ``constraints``.

    ``None`` leaves ``constraints`` unchanged. A list is converted element-wise:
    a single element is stored under ``column``, several under
    ``f"{column}__in"``, and an empty list adds nothing. Any other value is
    converted and stored under ``column``. Returns the same (mutated) dict.
    """
    if value is None:
        return constraints
    if not isinstance(value, list):
        constraints[column] = convert(value)
        return constraints
    converted = list(map(convert, value))
    if len(converted) > 1:
        constraints[f"{column}__in"] = converted
    elif converted:
        constraints[column] = converted[0]
    return constraints


class SQLiteMixin:
    """Base for SQLite-backed stores: owns an ``AIOSQLite`` handle and sets up the schema.

    Subclasses supply ``CREATE_TABLES_QUERY`` and may set ``SCHEMA_VERSION`` to get
    drop-and-recreate behavior whenever the stored schema version differs.
    """

    SCHEMA_VERSION: Optional[str] = None
    CREATE_TABLES_QUERY: str
    MAX_QUERY_VARIABLES = 900

    CREATE_VERSION_TABLE = """
        create table if not exists version (
            version text
        );
    """

    def __init__(self, path):
        self._db_path = path
        self.db: Optional[AIOSQLite] = None  # set by open()
        self.ledger = None

    async def open(self):
        """Connect to the database and make sure its schema is current.

        With ``SCHEMA_VERSION`` set: if user tables exist and the stored version
        equals ``SCHEMA_VERSION``, nothing else is done. Otherwise all user tables
        are dropped, the version table is recreated and stamped, and
        ``CREATE_TABLES_QUERY`` is run. Without ``SCHEMA_VERSION`` only
        ``CREATE_TABLES_QUERY`` is run.
        """
        log.info("connecting to database: %s", self._db_path)
        self.db = await AIOSQLite.connect(self._db_path, isolation_level=None)
        if self.SCHEMA_VERSION:
            # Names starting with "sqlite_" are SQLite-internal tables (e.g.
            # sqlite_sequence, created for AUTOINCREMENT); SQLite refuses to drop
            # them, so they must not be part of the drop list.
            tables = [t[0] for t in await self.db.execute_fetchall(
                "SELECT name FROM sqlite_master WHERE type='table';"
            ) if not t[0].lower().startswith('sqlite_')]
            if tables:
                if 'version' in tables:
                    version = await self.db.execute_fetchone("SELECT version FROM version LIMIT 1;")
                    if version == (self.SCHEMA_VERSION,):
                        return
                await self.db.executescript('\n'.join(
                    'DROP TABLE "{}";'.format(table.replace('"', '""')) for table in tables
                ))
            await self.db.execute(self.CREATE_VERSION_TABLE)
            await self.db.execute("INSERT INTO version VALUES (?)", (self.SCHEMA_VERSION,))
        await self.db.executescript(self.CREATE_TABLES_QUERY)

    async def close(self):
        """Close the underlying database handle."""
        await self.db.close()

    @staticmethod
    def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False,
                    replace: bool = False) -> Tuple[str, List]:
        """Build a parameterized INSERT using ``data``'s keys as column names.

        ``replace`` (``OR REPLACE``) takes precedence over ``ignore_duplicate``
        (``OR IGNORE``). ``table`` and the column names are inserted verbatim,
        so they must be trusted identifiers.

        :return: ``(sql, values)`` with one ``?`` per value, in ``data`` order.
        """
        columns, values = [], []
        for column, value in data.items():
            columns.append(column)
            values.append(value)
        policy = ""
        if ignore_duplicate:
            policy = " OR IGNORE"
        if replace:
            policy = " OR REPLACE"
        sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
            policy, table, ', '.join(columns), ', '.join(['?'] * len(values))
        )
        return sql, values

    @staticmethod
    def _update_sql(table: str, data: dict, where: str,
                    constraints: Union[list, tuple]) -> Tuple[str, list]:
        """Build a parameterized UPDATE setting ``data``'s columns.

        ``where`` is raw SQL whose ``?`` placeholders are bound, in order, to
        ``constraints``; those values follow the SET values in the returned list.
        """
        columns, values = [], []
        for column, value in data.items():
            columns.append(f"{column} = ?")
            values.append(value)
        values.extend(constraints)
        sql = "UPDATE {} SET {} WHERE {}".format(
            table, ', '.join(columns), where
        )
        return sql, values


def dict_row_factory(cursor, row):
    """sqlite3 ``row_factory`` that returns each row as ``{column_name: value}``.

    When several result columns share a name, the right-most one wins.
    """
    return {column[0]: row[position] for position, column in enumerate(cursor.description)}