import base64
import codecs
import datetime
import random
import socket
import time
import string
import sys
import json
import typing
import asyncio
import ssl
import logging
import ipaddress
import contextlib
import functools
import collections
import hashlib
import pkg_resources

import certifi
import aiohttp
from prometheus_client import Counter
from lbry.schema.claim import Claim


# Logger named after this module; all helpers below log through it.
log = logging.getLogger(name=__name__)


# defining these time functions here allows for easier overriding in testing
def now():
    """Return the current local time as a naive ``datetime.datetime``."""
    current_time = datetime.datetime.now()
    return current_time


def utcnow():
    """Return the current UTC time as a *naive* ``datetime.datetime``.

    ``datetime.datetime.utcnow()`` is deprecated since Python 3.12, so the value is
    taken from a timezone-aware ``now(timezone.utc)``. The tzinfo is then stripped,
    because callers such as ``isonow`` expect a naive value and append ``'Z'``
    themselves.
    """
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)


def isonow():
    """Return the current UTC time as an ISO-8601 string with a trailing ``'Z'``."""
    return f"{utcnow().isoformat()}Z"


def today():
    """Return ``datetime.datetime.today()`` (current local date and time, naive)."""
    moment = datetime.datetime.today()
    return moment


def timedelta(**kwargs):
    """Build a ``datetime.timedelta`` from keyword arguments such as ``days=``/``seconds=``."""
    duration = datetime.timedelta(**kwargs)
    return duration


def datetime_obj(*args, **kwargs):
    """Construct a ``datetime.datetime``, forwarding all arguments unchanged."""
    constructor = datetime.datetime
    return constructor(*args, **kwargs)


def get_lbry_hash_obj():
    """Return a new, empty SHA-384 hash object."""
    hasher = hashlib.sha384()
    return hasher


def generate_id(num=None):
    """Return the raw digest of ``str(num)`` using ``get_lbry_hash_obj()``.

    If ``num`` is None, a 512-bit value from ``random.getrandbits`` is hashed instead,
    giving a random id.
    """
    hasher = get_lbry_hash_obj()
    seed = random.getrandbits(512) if num is None else num
    hasher.update(str(seed).encode())
    return hasher.digest()


def version_is_greater_than(version_a, version_b):
    """Return True when version string ``version_a`` sorts strictly after ``version_b``."""
    parse = pkg_resources.parse_version
    return parse(version_a) > parse(version_b)


def rot13(some_str):
    """Apply the ROT13 letter substitution to ``some_str`` (applying it twice is a no-op)."""
    codec_name = 'rot_13'
    return codecs.encode(some_str, codec_name)


def deobfuscate(obfustacated):
    """Undo ``obfuscate``: ROT13 the text, base64-decode it, and decode the bytes to str."""
    base64_text = rot13(obfustacated)
    raw = base64.b64decode(base64_text)
    return raw.decode()


def obfuscate(plain):
    """Base64-encode ``plain`` and ROT13 the resulting text.

    ``plain`` may be ``bytes`` or ``str``; a ``str`` is UTF-8 encoded first. This
    matches ``deobfuscate``, which decodes with the default UTF-8 codec and returns
    ``str``, so its output can be fed straight back in here.
    """
    if isinstance(plain, str):
        plain = plain.encode()
    return rot13(base64.b64encode(plain).decode())


def check_connection(server="lbry.com", port=80, timeout=5) -> bool:
    """Return True if a TCP connection to ``server:port`` can be opened.

    If the hostname cannot be resolved, the function tries 8.8.8.8:53 directly
    (bypassing DNS) and returns whether that connection succeeds. Any other
    ``OSError`` yields False.
    """
    log.debug('Checking connection to %s:%s', server, port)

    def probe(host, host_port):
        # open and immediately close a connection; raises OSError on failure
        socket.create_connection((host, host_port), timeout).close()

    try:
        server = socket.gethostbyname(server)
        probe(server, port)
    except (socket.gaierror, socket.herror):
        log.debug("Failed to connect to %s:%s. Unable to resolve domain. Trying to bypass DNS",
                  server, port)
        try:
            probe("8.8.8.8", 53)
        except OSError:
            return False
        return True
    except OSError:
        return False
    return True


def random_string(length=10, chars=string.ascii_lowercase):
    """Return ``length`` characters drawn independently from ``chars`` via ``random.choice``."""
    picked = (random.choice(chars) for _ in range(length))
    return ''.join(picked)


def short_hash(hash_str):
    """Return the first six characters of ``hash_str`` (or all of it if shorter)."""
    prefix_length = 6
    return hash_str[:prefix_length]


def get_sd_hash(stream_info):
    """Return the sd hash referenced by ``stream_info``, or None.

    ``stream_info`` is either a ``Claim`` (read from ``stream.source.sd_hash``) or a
    dict, which is walked along ``claim -> value -> stream -> source -> source``,
    treating missing levels as empty. A warning is logged when the dict lookup
    comes back empty.
    """
    if not stream_info:
        return None
    if isinstance(stream_info, Claim):
        return stream_info.stream.source.sd_hash
    node = stream_info.get('claim', {})
    for field in ('value', 'stream', 'source'):
        node = node.get(field, {})
    result = node.get('source')
    if not result:
        log.warning("Unable to get sd_hash")
    return result


def json_dumps_pretty(obj, **kwargs):
    """Serialize ``obj`` to JSON with sorted keys and 2-space indentation.

    ``separators`` is pinned to ``(',', ': ')``. Extra keyword arguments go to
    ``json.dumps``; repeating one of the pinned options raises ``TypeError``.
    """
    pretty_options = {'sort_keys': True, 'indent': 2, 'separators': (',', ': ')}
    return json.dumps(obj, **pretty_options, **kwargs)


def cancel_task(task: typing.Optional[asyncio.Task]):
    """Request cancellation of ``task``; does nothing for a falsy value or a finished task."""
    if not task or task.done():
        return
    task.cancel()


def cancel_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
    """Apply ``cancel_task`` to every entry of ``tasks``, leaving the list itself untouched."""
    for pending in tasks:
        cancel_task(pending)


def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
    """Empty ``tasks`` in place, popping from the end and passing each entry to ``cancel_task``."""
    while tasks:
        last = tasks.pop()
        cancel_task(last)


def async_timed_cache(duration: int):
    """Decorator factory that caches an async function's results for ``duration`` seconds.

    Results are keyed on the positional args plus the keyword args in call order, so
    every argument must be hashable and ``f(a=1, b=2)`` / ``f(b=2, a=1)`` are cached
    separately. Exceptions are not cached. Expired entries are pruned whenever a new
    result is stored, so the cache does not grow without bound when the function is
    called with many distinct arguments. Concurrent misses for the same key are not
    de-duplicated.
    """
    def wrapper(func):
        cache: typing.Dict[typing.Tuple,
                           typing.Tuple[typing.Any, float]] = {}

        @functools.wraps(func)
        async def _inner(*args, **kwargs) -> typing.Any:
            loop = asyncio.get_running_loop()
            time_now = loop.time()
            key = (args, tuple(kwargs.items()))
            cached = cache.get(key)
            if cached is not None and time_now - cached[1] < duration:
                return cached[0]
            to_cache = await func(*args, **kwargs)
            # drop every expired entry before inserting so stale keys cannot accumulate;
            # entries stored by concurrent calls during the await are newer than
            # time_now and are kept
            stale = [k for k, (_, stored_at) in cache.items() if time_now - stored_at >= duration]
            for stale_key in stale:
                del cache[stale_key]
            cache[key] = to_cache, time_now
            return to_cache
        return _inner
    return wrapper


def cache_concurrent(async_fn):
    """
    When the decorated function has concurrent calls made to it with the same arguments, only run it once

    Calls with equal ``(args, kwargs-in-order)`` that overlap in time await one shared
    task. Once the shared task finishes, the next call starts a fresh one; results are
    not cached beyond that.
    """
    cache: typing.Dict = {}

    @functools.wraps(async_fn)
    async def wrapper(*args, **kwargs):
        key = (args, tuple(kwargs.items()))
        task = cache.get(key)
        if task is None:
            task = cache[key] = asyncio.create_task(async_fn(*args, **kwargs))
        try:
            return await task
        finally:
            # Only remove the entry if it is still *our* task. A caller that wakes up late
            # must not evict a newer in-flight task another caller has already started.
            if cache.get(key) is task:
                del cache[key]

    return wrapper


@async_timed_cache(300)
async def resolve_host(url: str, port: int, proto: str) -> str:
    """Resolve ``url`` to an address string; results are cached for 300 seconds.

    ``'localhost'`` (any case) maps to ``'127.0.0.1'`` and IP literals are returned
    unchanged. Other names go through ``loop.getaddrinfo`` restricted to ``AF_INET``
    and the socket type/protocol matching ``proto``, and the first result's host is
    returned.

    :raises ValueError: if ``proto`` is not ``'udp'`` or ``'tcp'`` (a subclass of the
        ``Exception`` raised previously, so existing handlers still catch it).
    :raises OSError: presumably ``socket.gaierror`` when the name cannot be resolved.
    """
    if proto not in ('udp', 'tcp'):
        raise ValueError("invalid protocol")
    if url.lower() == 'localhost':
        return '127.0.0.1'
    try:
        ipaddress.ip_address(url)
    except ValueError:
        pass
    else:
        return url  # already an IP literal, nothing to resolve
    is_tcp = proto == 'tcp'
    loop = asyncio.get_running_loop()
    addr_info = await loop.getaddrinfo(
        url, port,
        proto=socket.IPPROTO_TCP if is_tcp else socket.IPPROTO_UDP,
        type=socket.SOCK_STREAM if is_tcp else socket.SOCK_DGRAM,
        family=socket.AF_INET
    )
    # entries are (family, type, proto, canonname, sockaddr); sockaddr[0] is the host
    return addr_info[0][4][0]


class LRUCacheWithMetrics:
    """Least-recently-used cache on an ``OrderedDict`` with optional hit/miss counters.

    ``get`` (and ``obj[key]``) moves a hit to the most-recent end and returns
    ``default`` (``None`` for ``obj[key]``) on a miss instead of raising. ``set`` evicts
    from the least-recent end until there is room for a new key. A ``capacity`` of
    zero or less stores nothing. When ``metric_name`` is given, prometheus counters
    ``<namespace>_<metric_name>_cache_hit_count`` / ``..._miss_count`` are incremented
    by ``get``.
    """
    __slots__ = [
        'capacity',
        'cache',
        '_track_metrics',
        'hits',
        'misses'
    ]

    def __init__(self, capacity: int, metric_name: typing.Optional[str] = None, namespace: str = "daemon_cache"):
        self.capacity = capacity
        self.cache = collections.OrderedDict()
        if metric_name is None:
            self._track_metrics = False
            self.hits = self.misses = None
        else:
            self._track_metrics = True
            try:
                self.hits = Counter(
                    f"{metric_name}_cache_hit_count", "Number of cache hits", namespace=namespace
                )
                self.misses = Counter(
                    f"{metric_name}_cache_miss_count", "Number of cache misses", namespace=namespace
                )
            except ValueError as err:
                # presumably a duplicate metric registration; run without metrics instead
                log.debug("failed to set up prometheus %s cache hit/miss metrics: %s", metric_name, err)
                self._track_metrics = False
                self.hits = self.misses = None

    def get(self, key, default=None):
        """Return the value for ``key`` (marking it most recently used), else ``default``."""
        try:
            value = self.cache.pop(key)
        except KeyError:
            if self._track_metrics:
                self.misses.inc()
            return default
        if self._track_metrics:
            self.hits.inc()
        self.cache[key] = value  # re-insert at the most-recently-used end
        return value

    def set(self, key, value):
        """Store ``value`` under ``key`` as the most recently used entry, evicting as needed."""
        if self.capacity <= 0:
            # a zero-capacity cache stores nothing (like functools.lru_cache(maxsize=0));
            # evicting from an empty dict would otherwise raise KeyError
            return
        try:
            self.cache.pop(key)
        except KeyError:
            # loop rather than evict once so a capacity lowered at runtime is honoured
            while len(self.cache) >= self.capacity:
                self.cache.popitem(last=False)
        self.cache[key] = value

    def clear(self):
        """Remove all entries."""
        self.cache.clear()

    def pop(self, key):
        """Remove and return the value for ``key``; raises ``KeyError`` if absent."""
        return self.cache.pop(key)

    def __setitem__(self, key, value):
        return self.set(key, value)

    def __getitem__(self, item):
        return self.get(item)

    def __contains__(self, item) -> bool:
        # membership test only: does not refresh recency or touch the metrics
        return item in self.cache

    def __len__(self):
        return len(self.cache)

    def __delitem__(self, key):
        self.cache.pop(key)

    def __del__(self):
        self.clear()


class LRUCache:
    """Least-recently-used cache on an ``OrderedDict`` (no metrics).

    ``get``/``obj[key]`` move a hit to the most-recent end and return ``default``/``None``
    on a miss. ``set`` evicts from the least-recent end until there is room for a new
    key. A ``capacity`` of zero or less stores nothing.
    """
    __slots__ = [
        'capacity',
        'cache'
    ]

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.cache = collections.OrderedDict()

    def get(self, key, default=None):
        """Return the value for ``key`` (marking it most recently used), else ``default``."""
        try:
            value = self.cache.pop(key)
        except KeyError:
            return default
        self.cache[key] = value  # re-insert at the most-recently-used end
        return value

    def set(self, key, value):
        """Store ``value`` under ``key`` as the most recently used entry, evicting as needed."""
        if self.capacity <= 0:
            # a zero-capacity cache stores nothing (like functools.lru_cache(maxsize=0));
            # evicting from an empty dict would otherwise raise KeyError
            return
        try:
            self.cache.pop(key)
        except KeyError:
            # loop rather than evict once so a capacity lowered at runtime is honoured
            while len(self.cache) >= self.capacity:
                self.cache.popitem(last=False)
        self.cache[key] = value

    def items(self):
        """Return a view of ``(key, value)`` pairs, least recently used first."""
        return self.cache.items()

    def clear(self):
        """Remove all entries."""
        self.cache.clear()

    def pop(self, key, default=None):
        """Remove and return the value for ``key``, or ``default`` if absent."""
        return self.cache.pop(key, default)

    def __setitem__(self, key, value):
        return self.set(key, value)

    def __getitem__(self, item):
        return self.get(item)

    def __contains__(self, item) -> bool:
        # membership test only: does not refresh recency
        return item in self.cache

    def __len__(self):
        return len(self.cache)

    def __delitem__(self, key):
        self.cache.pop(key)

    def __del__(self):
        self.clear()


def lru_cache_concurrent(cache_size: typing.Optional[int] = None,
                         override_lru_cache: typing.Optional[LRUCacheWithMetrics] = None):
    """Decorator: LRU-cache an async function's results and share in-flight calls.

    Concurrent calls with equal ``(args, kwargs-in-order)`` await one shared task.
    Completed results are stored in ``override_lru_cache`` if given, otherwise in a
    new ``LRUCacheWithMetrics(cache_size)``. A single ``get`` on that cache decides
    hit or miss, so a metrics-enabled cache records misses as well as hits, and a
    cached ``None`` still counts as a hit.

    :raises ValueError: if neither a truthy ``cache_size`` nor ``override_lru_cache`` is given.
    """
    if not cache_size and override_lru_cache is None:
        raise ValueError("invalid cache size")
    concurrent_cache = {}
    lru_cache = override_lru_cache if override_lru_cache is not None else LRUCacheWithMetrics(cache_size)
    missing = object()  # sentinel distinguishing "not cached" from a cached None

    def wrapper(async_fn):

        @functools.wraps(async_fn)
        async def _inner(*args, **kwargs):
            key = (args, tuple(kwargs.items()))
            cached = lru_cache.get(key, missing)
            if cached is not missing:
                return cached

            task = concurrent_cache.get(key)
            if task is None:
                task = concurrent_cache[key] = asyncio.create_task(async_fn(*args, **kwargs))

            try:
                result = await task
                lru_cache.set(key, result)
                return result
            finally:
                # only drop our own task; a late waiter must not evict a newer in-flight one
                if concurrent_cache.get(key) is task:
                    del concurrent_cache[key]
        return _inner
    return wrapper


def get_ssl_context() -> ssl.SSLContext:
    """Return a client-side TLS context that verifies servers against certifi's CA bundle.

    ``Purpose.SERVER_AUTH`` is the purpose for connecting *to* servers: certificate and
    hostname verification are enabled. ``Purpose.CLIENT_AUTH`` would build a
    server-side context, which disables verification on older Pythons and cannot wrap
    client sockets on 3.10+. ``certifi.where()`` returns the path of a single PEM
    bundle file, so it is passed as ``cafile``; ``capath`` expects a directory.
    """
    return ssl.create_default_context(
        purpose=ssl.Purpose.SERVER_AUTH, cafile=certifi.where()
    )


@contextlib.asynccontextmanager
async def aiohttp_request(method, url, **kwargs) -> typing.AsyncIterator[aiohttp.ClientResponse]:
    """Open a dedicated ``aiohttp.ClientSession``, issue one request and yield the response.

    Use as ``async with aiohttp_request("get", url) as response: ...``. The response
    and its session are both closed when the block exits. ``kwargs`` are forwarded to
    ``ClientSession.request``. The undecorated function is an async generator, hence
    the ``AsyncIterator`` return annotation.
    """
    async with aiohttp.ClientSession() as session:
        async with session.request(method, url, **kwargs) as response:
            yield response


# The ipaddress module does not report these ranges as reserved, so
# is_valid_public_ipv4 rejects them explicitly.
CARRIER_GRADE_NAT_SUBNET = ipaddress.IPv4Network('100.64.0.0/10')  # carrier-grade NAT range
IPV4_TO_6_RELAY_SUBNET = ipaddress.IPv4Network('192.88.99.0/24')  # IPv4-to-IPv6 relay range


def is_valid_public_ipv4(address, allow_localhost: bool = False, allow_lan: bool = False):
    """Return True if ``address`` is a publicly routable IPv4 address.

    :param address: anything ``ipaddress.ip_address`` accepts (normally a string).
    :param allow_localhost: also accept IPv4 loopback addresses.
    :param allow_lan: also accept IPv4 addresses that ``ipaddress`` classifies as private.
    :return: False for unparsable input and for every IPv6 address. The version check
        runs before the ``allow_*`` shortcuts, so those flags never admit IPv6 addresses
        such as ``::1`` or ``fd00::1``.
    """
    try:
        parsed_ip = ipaddress.ip_address(address)
    except ValueError:  # ipaddress.AddressValueError is a ValueError subclass
        return False
    if parsed_ip.version != 4:
        return False
    if allow_localhost and parsed_ip.is_loopback:
        return True
    if allow_lan and parsed_ip.is_private:
        return True
    if any((parsed_ip.is_unspecified, parsed_ip.is_link_local, parsed_ip.is_loopback,
            parsed_ip.is_multicast, parsed_ip.is_reserved, parsed_ip.is_private)):
        return False
    # ranges that ipaddress does not flag as reserved
    return parsed_ip not in CARRIER_GRADE_NAT_SUBNET and parsed_ip not in IPV4_TO_6_RELAY_SUBNET


async def fallback_get_external_ip():  # used if spv servers can't be used for ip detection
    """Ask the ``api.lbry.com/ip`` HTTP endpoint for this host's external IP.

    :return: ``(ip, None)`` when the API reports success, otherwise ``(None, None)``.
        This covers request/JSON/shape errors as well as a response whose ``success``
        flag is false.
    """
    try:
        async with aiohttp_request("get", "https://api.lbry.com/ip") as resp:
            response = await resp.json()
            if response['success']:
                return response['data']['ip'], None
    except Exception:  # best effort: any failure just means "unknown"
        log.debug("failed to get external ip from api.lbry.com", exc_info=True)
    # always a 2-tuple (never a bare None) so callers can unpack the result
    return None, None


async def _get_external_ip(default_servers) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
    """Detect this host's public IPv4 address by pinging SPV servers over UDP.

    Every ``(hostname, port)`` in ``default_servers`` is resolved with ``resolve_host``;
    failures are logged and skipped. The resolved addresses are then pinged one at a
    time, in random order, waiting up to 1 second for each pong.

    :param default_servers: iterable of ``(hostname, port)`` pairs.
    :return: ``(ip, hostname)``, where ``ip`` is the address reported in the first pong
        that passes ``is_valid_public_ipv4`` and ``hostname`` is the first hostname that
        resolved to the pinged server. Returns ``(None, None)`` if no server produced a
        usable pong.
    """
    # used if upnp is disabled or non-functioning
    from lbry.wallet.server.udp import SPVStatusClientProtocol  # pylint: disable=C0415

    # hostname -> (ip, port); filled in below but not otherwise read in this function
    hostname_to_ip = {}
    # (ip, port) -> hostnames that resolved to it; its keys are the ping targets, so
    # hostnames sharing one address are pinged only once
    ip_to_hostnames = collections.defaultdict(list)

    async def resolve_spv(server: str, port: int):
        # resolve one SPV hostname and record both mappings; any failure is logged and
        # the server is simply left out of the ping list
        try:
            server_addr = await resolve_host(server, port, 'udp')
            hostname_to_ip[server] = (server_addr, port)
            ip_to_hostnames[(server_addr, port)].append(server)
        except Exception:
            log.exception("error looking up dns for spv servers")

    # accumulate the dns results (all lookups run concurrently)
    await asyncio.gather(*(resolve_spv(server, port) for (server, port) in default_servers))

    loop = asyncio.get_event_loop()
    # handed to the protocol and read below as (_, pong) items; presumably the protocol
    # enqueues one item per received pong and the first element is the sender -- TODO confirm
    pong_responses = asyncio.Queue()
    connection = SPVStatusClientProtocol(pong_responses)
    try:
        await loop.create_datagram_endpoint(lambda: connection, ('0.0.0.0', 0))
        # could raise OSError if it can't bind; nothing here catches it, so it propagates
        randomized_servers = list(ip_to_hostnames.keys())
        random.shuffle(randomized_servers)
        for server in randomized_servers:
            connection.ping(server)
            try:
                # NOTE(review): a pong arriving after its 1 s timeout stays queued and is
                # consumed by the next iteration, so it can be credited to the wrong
                # server; matching on the first tuple element might avoid this -- verify.
                _, pong = await asyncio.wait_for(pong_responses.get(), 1)
                if is_valid_public_ipv4(pong.ip_address):
                    return pong.ip_address, ip_to_hostnames[server][0]
            except asyncio.TimeoutError:
                pass
        return None, None
    finally:
        # NOTE(review): this also runs when create_datagram_endpoint failed; confirm that
        # SPVStatusClientProtocol.close() copes with never having had a transport.
        connection.close()


async def get_external_ip(default_servers) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
    """Return ``(ip, hostname)`` from SPV-server detection.

    If that found no responding server (hostname is falsy), return whatever
    ``fallback_get_external_ip()`` produces instead.
    """
    spv_result = await _get_external_ip(default_servers)
    responding_server = spv_result[1]
    if responding_server:
        return spv_result
    return await fallback_get_external_ip()


def is_running_from_bundle():
    """Report whether the process runs from a PyInstaller bundle.

    Returns ``sys.frozen`` itself when it is falsy (``False`` if unset); otherwise
    returns whether ``sys._MEIPASS`` exists.
    """
    # see https://pyinstaller.readthedocs.io/en/stable/runtime-information.html
    frozen = getattr(sys, 'frozen', False)
    if not frozen:
        return frozen
    return hasattr(sys, '_MEIPASS')


class LockWithMetrics(asyncio.Lock):
    """``asyncio.Lock`` that reports wait time and hold time to two metric objects.

    Both metrics only need an ``observe(seconds: float)`` method (presumably prometheus
    Histograms/Summaries). ``acquire_metric`` receives the time spent in ``acquire``,
    including waits that end in cancellation. ``held_time_metric`` receives the time
    between a successful ``acquire`` and the matching ``release``.
    """

    def __init__(self, acquire_metric, held_time_metric, loop=None):
        # asyncio.Lock rejects any explicit ``loop`` argument (even None) from Python 3.10,
        # so only forward it when a caller actually supplied one.
        if loop is None:
            super().__init__()
        else:
            super().__init__(loop=loop)
        self._acquire_metric = acquire_metric
        self._lock_held_time_metric = held_time_metric
        self._lock_acquired_time = None  # perf_counter() timestamp of the current hold

    async def acquire(self):
        start = time.perf_counter()
        acquired = False
        try:
            result = await super().acquire()
            acquired = True
            return result
        finally:
            stamp = time.perf_counter()
            self._acquire_metric.observe(stamp - start)
            # only a successful acquire may set the hold start; a cancelled waiter must
            # not overwrite the timestamp of the coroutine that currently holds the lock
            if acquired:
                self._lock_acquired_time = stamp

    def release(self):
        # raises RuntimeError if the lock is not held; that error now propagates instead
        # of being masked by arithmetic on a missing timestamp
        super().release()
        acquired_at, self._lock_acquired_time = self._lock_acquired_time, None
        if acquired_at is not None:
            self._lock_held_time_metric.observe(time.perf_counter() - acquired_at)