import asyncio
import random
import logging
import typing
import itertools

from lbrynet.dht import constants
from lbrynet.dht.protocol.distance import Distance
if typing.TYPE_CHECKING:
    from lbrynet.dht.peer import KademliaPeer, PeerManager

log = logging.getLogger(__name__)


class KBucket:
    """ An ordered list of peers whose node ids fall inside one slice of the ID space.

    The bucket covers the half-open integer range [range_min, range_max); node ids
    are interpreted as big-endian integers when they are tested against it. Peers
    are kept in the order in which they were last added, and add_peer never grows
    the list beyond constants.k entries.
    """

    def __init__(self, peer_manager: 'PeerManager', range_min: int, range_max: int, node_id: bytes):
        """
        @param peer_manager: Used to ask whether a peer's (node_id, address, udp_port)
                             triple is considered good
        @param range_min: The lower boundary (inclusive) for the range in the n-bit ID
                          space covered by this k-bucket
        @param range_max: The upper boundary (exclusive) for the range in the ID space
                          covered by this k-bucket
        @param node_id: The default id that get_peers sorts distances to
        """
        self._peer_manager = peer_manager
        # Starts at 0; the owning routing table overwrites it with the event loop time
        self.last_accessed = 0
        self.range_min = range_min
        self.range_max = range_max
        self.peers: typing.List['KademliaPeer'] = []
        self._node_id = node_id

    def add_peer(self, peer: 'KademliaPeer') -> bool:
        """ Add a peer to the bucket, or refresh its position if it is already present.

        A peer that is already in the bucket is moved to the end of the list, and the
        stored entry is replaced by the object passed in so that any add-on data it
        carries is updated as well. A new peer is only appended while the bucket
        holds fewer than constants.k peers.

        @param peer: The peer to add

        @return: C{True} if the peer is in the bucket after the call, C{False} if the
                 bucket is full and the peer was not already in it (no exception is
                 raised in that case)
        @rtype: bool
        """
        if peer in self.peers:
            # Drop the stored entry and append the new object so it ends up last
            self.peers.remove(peer)
            self.peers.append(peer)
            return True
        if len(self.peers) < constants.k:
            self.peers.append(peer)
            return True
        return False

    def get_peer(self, node_id: bytes) -> 'KademliaPeer':
        """ Return the first peer in the bucket whose node id equals node_id.

        @raise IndexError: If no peer in the bucket has that node id
        """
        for peer in self.peers:
            if peer.node_id == node_id:
                return peer
        raise IndexError(node_id)

    def get_peers(self, count=-1, exclude_contact=None, sort_distance_to=None) -> typing.List['KademliaPeer']:
        """ Returns a list containing up to the first count number of peers

        @param count: The amount of peers to return (if 0 or less, return all
                      peers); never more than constants.k are returned
        @type count: int
        @param exclude_contact: A node id to exclude; peers whose node id equals
                                it are discarded before sorting and truncating
        @type exclude_contact: bytes

        @param sort_distance_to: Sort by distance to this node id, defaulting to the
                                 node id the bucket was created with. If C{False},
                                 don't sort and keep the bucket's own order

        @return: Up to count peers in a new list; an empty list if no peer
                 qualifies
        @rtype: list
        """
        peers = [peer for peer in self.peers if peer.node_id != exclude_contact]

        # A non-positive count means "all of them"; either way never hand out more than k
        if count <= 0:
            count = len(peers)
        count = min(count, constants.k)

        if not peers:
            return peers

        if sort_distance_to is not False:
            # Build the distance metric once instead of once per peer being sorted
            distance = Distance(sort_distance_to or self._node_id)
            peers.sort(key=lambda peer: distance(peer.node_id))

        return peers[:count]

    def get_bad_or_unknown_peers(self) -> typing.List['KademliaPeer']:
        """ Return the peers, in bucket order, for which the peer manager's
        contact_triple_is_good check returns anything other than C{True}.
        """
        return [
            peer for peer in self.get_peers(sort_distance_to=False)
            if self._peer_manager.contact_triple_is_good(peer.node_id, peer.address, peer.udp_port) is not True
        ]

    def remove_peer(self, peer: 'KademliaPeer') -> None:
        """ Remove a peer from the bucket.

        @raise ValueError: If the peer is not in the bucket
        """
        self.peers.remove(peer)

    def key_in_range(self, key: bytes) -> bool:
        """ Tests whether the specified key (i.e. node ID) is in the range
        of the n-bit ID space covered by this k-bucket (in other words, it
        returns whether or not the specified key should be placed in this
        k-bucket)

        @param key: The key to test, read as a big-endian integer
        @type key: bytes

        @return: C{True} if range_min <= key < range_max, or C{False} if not.
        @rtype: bool
        """
        return self.range_min <= int.from_bytes(key, 'big') < self.range_max

    def __len__(self) -> int:
        """ The number of peers currently in the bucket. """
        return len(self.peers)

    def __contains__(self, item) -> bool:
        """ Whether item compares equal to a peer currently in the bucket. """
        return item in self.peers


class TreeRoutingTable:
    """ This class implements a routing table used by a Node class.

    The Kademlia routing table is a binary tree whose leaves are k-buckets,
    where each k-bucket contains nodes with some common prefix of their IDs.
    This prefix is the k-bucket's position in the binary tree; it therefore
    covers some range of ID values, and together all of the k-buckets cover
    the entire n-bit ID (or key) space (with no overlap).

    @note: In this implementation, nodes in the tree (the k-buckets) are
    added dynamically, as needed; this technique is described in the 13-page
    version of the Kademlia paper, in section 2.4. It does, however, use the
    ping RPC-based k-bucket eviction algorithm described in section 2.2 of
    that paper.
    """

    def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager', parent_node_id: bytes):
        """
        @param loop: Event loop whose clock is used to timestamp bucket accesses
        @param peer_manager: Handed to every k-bucket this table creates
        @param parent_node_id: The id of the node that owns this routing table
        """
        self._loop = loop
        self._peer_manager = peer_manager
        self._parent_node_id = parent_node_id
        # Start with one bucket spanning the whole ID space; split_bucket carves it up later
        self.buckets: typing.List[KBucket] = [
            KBucket(
                self._peer_manager, range_min=0, range_max=2 ** constants.hash_bits, node_id=self._parent_node_id
            )
        ]

    def get_peers(self) -> typing.List['KademliaPeer']:
        """ Return the peers of every bucket as one new flat list, in bucket order. """
        return list(itertools.chain.from_iterable(bucket.peers for bucket in self.buckets))

    def should_split(self, bucket_index: int, to_add: bytes) -> bool:
        """ Decide whether the bucket at bucket_index should be split for the id to_add.

        The bucket whose range covers this table's own node id may always be split.
        For any other bucket, all known peers are ranked by distance to the own node
        id and the split is allowed only if to_add is closer than the k-th closest
        peer (or than the farthest known peer when fewer than k peers are known).

        @raise IndexError: If the table holds no peers at all and the bucket does
                           not cover the own node id
        """
        #  https://stackoverflow.com/questions/32129978/highly-unbalanced-kademlia-routing-table/32187456#32187456
        if self.buckets[bucket_index].key_in_range(self._parent_node_id):
            return True
        distance = Distance(self._parent_node_id)
        ranked = sorted(self.get_peers(), key=lambda peer: distance(peer.node_id))
        kth_contact = ranked[-1] if len(ranked) < constants.k else ranked[constants.k - 1]
        return distance(to_add) < distance(kth_contact.node_id)

    def find_close_peers(self, key: bytes, count: typing.Optional[int] = None,
                         sender_node_id: typing.Optional[bytes] = None) -> typing.List['KademliaPeer']:
        """ Return the known peers closest to key, nearest first.

        @param key: The id to measure distances to
        @param count: Maximum number of peers to return; constants.k if not given
                      (or given as 0)
        @param sender_node_id: An additional node id to leave out of the result

        @return: Up to count peers sorted by distance to key. Peers whose node id
                 is the own node id or sender_node_id are skipped, unless that id
                 is the key being looked up.
        """
        exclude = [self._parent_node_id]
        if sender_node_id:
            exclude.append(sender_node_id)
        # If the key itself is one of the excluded ids, allow it back in
        # (list.remove only drops its first occurrence)
        if key in exclude:
            exclude.remove(key)
        count = count or constants.k
        distance = Distance(key)
        candidates = [peer for peer in self.get_peers() if peer.node_id not in exclude]
        candidates.sort(key=lambda peer: distance(peer.node_id))
        return candidates[:count]

    def get_peer(self, contact_id: bytes) -> 'KademliaPeer':
        """
        @raise IndexError: No contact with the specified contact ID is known
                           by this node
        """
        return self.buckets[self.kbucket_index(contact_id)].get_peer(contact_id)

    def get_refresh_list(self, start_index: int = 0, force: bool = False) -> typing.List[bytes]:
        """ Return one id to look up for every bucket that is due for a refresh.

        @param start_index: Only buckets from this index onwards are considered
        @param force: If C{True}, include every considered bucket regardless of
                      when it was last accessed

        @return: The midpoint id of each considered bucket that is forced or whose
                 last access is at least constants.refresh_interval old
        """
        now = int(self._loop.time())
        return [
            self.midpoint_id_in_bucket_range(bucket_index)
            for bucket_index, bucket in enumerate(self.buckets[start_index:], start_index)
            if force or now - bucket.last_accessed >= constants.refresh_interval
        ]

    def remove_peer(self, peer: 'KademliaPeer') -> None:
        """ Remove a peer from the bucket covering its node id.

        Does nothing if the peer has no node id or is not in that bucket.
        """
        if not peer.node_id:
            return
        bucket_index = self.kbucket_index(peer.node_id)
        try:
            self.buckets[bucket_index].remove_peer(peer)
        except ValueError:
            # The peer was not in the bucket; there is nothing to remove
            return

    def touch_kbucket(self, key: bytes) -> None:
        """ Mark the bucket covering key as accessed now. """
        self.touch_kbucket_by_index(self.kbucket_index(key))

    def touch_kbucket_by_index(self, bucket_index: int):
        """ Store the current event loop time (truncated to an int) as the bucket's last access. """
        self.buckets[bucket_index].last_accessed = int(self._loop.time())

    def kbucket_index(self, key: bytes) -> int:
        """ Return the index of the first bucket whose range covers key.

        @return: The bucket's index in self.buckets. If no bucket covers the key,
                 len(self.buckets) is returned, which is not a valid index.
        """
        for index, bucket in enumerate(self.buckets):
            if bucket.key_in_range(key):
                return index
        return len(self.buckets)

    def random_id_in_bucket_range(self, bucket_index: int) -> bytes:
        """ Return a random id from the bucket's range as constants.hash_length big-endian bytes. """
        bucket = self.buckets[bucket_index]
        return random.randrange(bucket.range_min, bucket.range_max).to_bytes(constants.hash_length, 'big')

    def midpoint_id_in_bucket_range(self, bucket_index: int) -> bytes:
        """ Return the id halfway through the bucket's range as constants.hash_length big-endian bytes. """
        bucket = self.buckets[bucket_index]
        midpoint = bucket.range_min + (bucket.range_max - bucket.range_min) // 2
        return midpoint.to_bytes(constants.hash_length, 'big')

    def split_bucket(self, old_bucket_index: int) -> None:
        """ Splits the specified k-bucket into two new buckets which together
        cover the same range in the key/ID space

        The existing bucket keeps the lower half of the range; a new bucket for
        the upper half is inserted right after it and receives the peers whose
        node ids fall into that half.

        @param old_bucket_index: The index of k-bucket to split (in this table's
                                 list of k-buckets)
        @type old_bucket_index: int
        """
        old_bucket = self.buckets[old_bucket_index]
        split_point = old_bucket.range_max - (old_bucket.range_max - old_bucket.range_min) // 2
        # Create a new k-bucket to cover the range split off from the old bucket
        new_bucket = KBucket(self._peer_manager, split_point, old_bucket.range_max, self._parent_node_id)
        # Resize the range of the current (old) k-bucket
        old_bucket.range_max = split_point
        # Now, add the new bucket into the routing table tree
        self.buckets.insert(old_bucket_index + 1, new_bucket)
        # Copy all peers that belong to the new k-bucket into it...
        for contact in old_bucket.peers:
            if new_bucket.key_in_range(contact.node_id):
                new_bucket.add_peer(contact)
        # ...and only then remove them from the old bucket, so that the list
        # being iterated is never the one being modified
        for contact in new_bucket.peers:
            old_bucket.remove_peer(contact)

    def join_buckets(self):
        """ Remove empty buckets, handing their share of the ID space to their neighbours.

        Empty buckets are removed one at a time, lowest index first. A bucket
        with two neighbours has its range divided between them at its midpoint;
        a bucket at either end of the list gives its whole range to its only
        neighbour. Bucket ranges are half-open, so after every merge the
        neighbouring ranges still meet exactly and no id is left uncovered.

        A table that is down to a single bucket is left untouched even if that
        bucket is empty, since there is no neighbour to take over its range.
        """
        while len(self.buckets) > 1:
            empty_indexes = [index for index, bucket in enumerate(self.buckets) if not len(bucket)]
            if not empty_indexes:
                return
            log.info("join buckets %i", len(empty_indexes))
            index = empty_indexes[0]
            bucket = self.buckets[index]
            has_lower = index > 0
            has_higher = index + 1 < len(self.buckets)
            if has_lower and has_higher:
                # range_max is exclusive: both neighbours must use the same boundary,
                # otherwise the id just below the midpoint would belong to no bucket
                midpoint = bucket.range_min + (bucket.range_max - bucket.range_min) // 2
                self.buckets[index - 1].range_max = midpoint
                self.buckets[index + 1].range_min = midpoint
            elif has_lower:
                self.buckets[index - 1].range_max = bucket.range_max
            else:
                self.buckets[index + 1].range_min = bucket.range_min
            del self.buckets[index]

    def contact_in_routing_table(self, address_tuple: typing.Tuple[str, int]) -> bool:
        """ Tell whether any peer returned by the buckets has the given (address, udp_port). """
        return any(
            address_tuple[0] == contact.address and address_tuple[1] == contact.udp_port
            for bucket in self.buckets
            for contact in bucket.get_peers(sort_distance_to=False)
        )

    def buckets_with_contacts(self) -> int:
        """ Return the number of buckets that hold at least one peer. """
        return sum(1 for bucket in self.buckets if len(bucket))