from collections import defaultdict
import logging
import os
import unqlite
import time
from Crypto.Hash import SHA512
from Crypto.PublicKey import RSA
from lbrynet.core.client.ClientRequest import ClientRequest
from lbrynet.core.Error import RequestCanceledError
from lbrynet.interfaces import IRequestCreator, IQueryHandlerFactory, IQueryHandler, IWallet
from lbrynet.pointtraderclient import pointtraderclient
from twisted.internet import defer, threads
from zope.interface import implements
from twisted.python.failure import Failure
from lbrynet.core.Wallet import ReservedPoints


log = logging.getLogger(__name__)


class PTCWallet(object):
    """Send payments to peers and check that expected payments are received.

    Payments go through ``pointtraderclient``. This class is only intended
    to be used for testing.
    """
    implements(IWallet)

    def __init__(self, db_dir):
        """
        @param db_dir: directory in which the ``ptcwallet.db`` UnQLite database
            (used to persist this wallet's private key) is opened
        """
        self.db_dir = db_dir
        self.db = None
        self.private_key = None
        self.encoded_public_key = None
        # peer -> encoded public key the peer gave us
        self.peer_pub_keys = {}
        # peer -> points scheduled by send_points but not yet sent
        self.queued_payments = defaultdict(int)
        # peer -> list of (amount, time added) we expect that peer to pay us
        self.expected_payments = defaultdict(list)
        # encoded sender public key -> list of (amount, time seen)
        self.received_payments = defaultdict(list)
        self.next_manage_call = None
        self.payment_check_window = 3 * 60  # 3 minutes
        # Start already outside the window so manage() does not poll for
        # transactions until add_expected_payment is first called.
        self.new_payments_expected_time = time.time() - self.payment_check_window
        # SHA512 hex digests of transactions that have already been counted
        self.known_transactions = []
        self.total_reserved_points = 0.0
        self.wallet_balance = 0.0

    def manage(self):
        """Send payments, ensure expected payments are received.

        Polls for new incoming transactions (only while a payment has been
        expected within the last payment_check_window seconds), checks the
        standing of each peer, sends queued payments, reschedules itself to
        run again in 60 seconds and finally refreshes wallet_balance.

        @return: Deferred which fires after both the payment check and the
            queued sends have completed and the balance has been refreshed
        """

        from twisted.internet import reactor

        if time.time() < self.new_payments_expected_time + self.payment_check_window:
            d1 = self._get_new_payments()
        else:
            d1 = defer.succeed(None)
        d1.addCallback(lambda _: self._check_good_standing())
        d2 = self._send_queued_points()
        self.next_manage_call = reactor.callLater(60, self.manage)
        dl = defer.DeferredList([d1, d2])
        dl.addCallback(lambda _: self.get_balance())

        def set_balance(balance):
            self.wallet_balance = balance

        dl.addCallback(set_balance)
        return dl

    def stop(self):
        """Stop the periodic manage() loop after running it one final time.

        @return: Deferred from the final manage() call
        """
        # Guard with active() so a call that has already fired or been
        # cancelled does not raise from cancel().
        if self.next_manage_call is not None and self.next_manage_call.active():
            self.next_manage_call.cancel()
        self.next_manage_call = None
        d = self.manage()
        # manage() always reschedules itself; undo that reschedule.
        self.next_manage_call.cancel()
        self.next_manage_call = None
        self.db = None
        return d

    def start(self):
        """Open the database, load or create the private key, start manage().

        If no private key is stored, a 4096-bit RSA key is generated in a
        thread and registered with the point trader server. If registration
        reports success, the key is saved to the database.

        @return: Deferred which fires with True once manage() has started
        """

        def save_key(success, private_key):
            if success is True:
                # Chain the save so start() waits for it and a failure
                # propagates to the caller instead of being dropped.
                d_s = self._save_private_key(private_key.exportKey())
                d_s.addCallback(lambda _: True)
                return d_s
            return False

        def register_private_key(private_key):
            self.private_key = private_key
            self.encoded_public_key = self.private_key.publickey().exportKey()
            d_r = pointtraderclient.register_new_account(private_key)
            d_r.addCallback(save_key, private_key)
            return d_r

        def ensure_private_key_exists(encoded_private_key):
            if encoded_private_key is not None:
                self.private_key = RSA.importKey(encoded_private_key)
                self.encoded_public_key = self.private_key.publickey().exportKey()
                return True
            # Key generation is slow, so keep it off the reactor thread.
            create_d = threads.deferToThread(RSA.generate, 4096)
            create_d.addCallback(register_private_key)
            return create_d

        def start_manage():
            self.manage()
            return True

        d = self._open_db()
        d.addCallback(lambda _: self._get_wallet_private_key())
        d.addCallback(ensure_private_key_exists)
        d.addCallback(lambda _: start_manage())
        return d

    def get_info_exchanger(self):
        """Return a request creator that exchanges public keys with peers."""
        return PointTraderKeyExchanger(self)

    def get_wallet_info_query_handler_factory(self):
        """Return a factory for handlers answering peers' public key queries."""
        return PointTraderKeyQueryHandlerFactory(self)

    def reserve_points(self, peer, amount):
        """Ensure a certain amount of points are available to be sent as
        payment, before the service is rendered

        @param peer: The peer to which the payment will ultimately be sent

        @param amount: The amount of points to reserve

        @return: A ReservedPoints object which is given to send_points
            once the service has been rendered, or None if the unreserved
            balance is smaller than amount
        """
        if self.wallet_balance >= self.total_reserved_points + amount:
            self.total_reserved_points += amount
            return ReservedPoints(peer, amount)
        return None

    def cancel_point_reservation(self, reserved_points):
        """
        Return all of the points that were reserved previously for some ReservedPoints object

        @param reserved_points: ReservedPoints previously returned by reserve_points

        @return: None
        """
        self.total_reserved_points -= reserved_points.amount

    def send_points(self, reserved_points, amount):
        """
        Schedule a payment to be sent to a peer

        The payment is queued and actually sent by the next manage() call.

        @param reserved_points: ReservedPoints object previously returned by reserve_points

        @param amount: amount of points to actually send, must be less than or equal to the
            amount reserved in reserved_points (not enforced here)

        @return: Deferred which fires when the payment has been scheduled
        """
        self.queued_payments[reserved_points.identifier] += amount
        # make any unused points available
        self.total_reserved_points -= reserved_points.amount - amount
        reserved_points.identifier.update_stats('points_sent', amount)
        return defer.succeed(True)

    def _send_queued_points(self):
        """Send every queued payment whose peer's public key is known.

        Payments to peers without a known key stay queued for a later call.

        @return: DeferredList over the individual send_points calls
        """
        ds = []
        # Iterate over a snapshot because sent entries are deleted below.
        for peer, points in list(self.queued_payments.items()):
            if peer in self.peer_pub_keys:
                d = pointtraderclient.send_points(
                    self.private_key, self.peer_pub_keys[peer], points)
                # Debit optimistically, before the send is confirmed.
                self.wallet_balance -= points
                self.total_reserved_points -= points
                ds.append(d)
                del self.queued_payments[peer]
            else:
                log.warning("Don't have a payment address for peer %s. Can't send %s points.",
                            str(peer), str(points))
        return defer.DeferredList(ds)

    def get_balance(self):
        """Return a Deferred firing with the balance of this wallet"""
        return pointtraderclient.get_balance(self.private_key)

    def add_expected_payment(self, peer, amount):
        """Increase the number of points expected to be paid by a peer"""
        self.expected_payments[peer].append((amount, time.time()))
        # Re-open the window during which manage() polls for transactions.
        self.new_payments_expected_time = time.time()
        peer.update_stats('expected_points', amount)

    def set_public_key_for_peer(self, peer, pub_key):
        """Record the encoded public key to which payments to peer are sent."""
        self.peer_pub_keys[peer] = pub_key

    def _get_new_payments(self):
        """Fetch recent transactions and record new ones paid to this wallet.

        @return: Deferred which fires once the transactions are processed
        """

        def add_new_transactions(transactions):
            # NOTE(review): transactions appear to be (sender_key, recipient_key,
            # amount, extra) sequences -- confirm against pointtraderclient.
            for transaction in transactions:
                if transaction[1] != self.encoded_public_key:
                    continue
                t_hash = SHA512.new()
                t_hash.update(transaction[0])
                t_hash.update(transaction[1])
                t_hash.update(str(transaction[2]))
                t_hash.update(transaction[3])
                digest = t_hash.hexdigest()
                # The digest lets a transaction seen in several polls count once.
                if digest not in self.known_transactions:
                    self.known_transactions.append(digest)
                    self._add_received_payment(transaction[0], transaction[2])

        d = pointtraderclient.get_recent_transactions(self.private_key)
        d.addCallback(add_new_transactions)
        return d

    def _add_received_payment(self, encoded_other_public_key, amount):
        """Record that amount was received from the given sender key now."""
        self.received_payments[encoded_other_public_key].append((amount, time.time()))

    def _check_good_standing(self):
        """Log a warning for each peer that has paid less than expected.

        Only expected payments added more than 90 seconds ago are counted.
        Peers whose public key is unknown cannot be matched against received
        payments; they are skipped (with a warning if payment is overdue)
        instead of raising KeyError.
        """
        expected_cutoff = time.time() - 90
        for peer, expected_payments in self.expected_payments.items():
            min_expected_balance = sum(
                a[0] for a in expected_payments if a[1] < expected_cutoff)
            pub_key = self.peer_pub_keys.get(peer)
            if pub_key is None:
                if min_expected_balance > 0:
                    log.warning(
                        "Cannot check standing of %s: no public key known, "
                        "expected amount = %s", peer, min_expected_balance)
                continue
            # .get() avoids inserting empty entries into the defaultdict.
            received_balance = sum(
                a[0] for a in self.received_payments.get(pub_key, []))
            if min_expected_balance > received_balance:
                log.warning(
                    "Account in bad standing: %s (pub_key: %s), expected amount = %s, "
                    "received_amount = %s",
                    peer, pub_key, min_expected_balance, received_balance)

    def _open_db(self):
        """Open ptcwallet.db in db_dir in a thread; return a Deferred."""
        def open_db():
            self.db = unqlite.UnQLite(os.path.join(self.db_dir, "ptcwallet.db"))
        return threads.deferToThread(open_db)

    def _save_private_key(self, private_key):
        """Store the exported private key in the db in a thread; return a Deferred."""
        def save_key():
            self.db['private_key'] = private_key
        return threads.deferToThread(save_key)

    def _get_wallet_private_key(self):
        """Return a Deferred firing with the stored private key, or None."""
        def get_key():
            if 'private_key' in self.db:
                return self.db['private_key']
            return None
        return threads.deferToThread(get_key)


class PointTraderKeyExchanger(object):
    """Request creator that sends the wallet's public key to a peer once per
    protocol and stores the public key the peer sends back."""
    implements(IRequestCreator)

    def __init__(self, wallet):
        self.wallet = wallet
        # protocols on which the key exchange request has already been queued
        self._protocols = []

    ######### IRequestCreator #########

    def send_next_request(self, peer, protocol):
        """Queue a public key exchange request on protocol unless already sent.

        @return: Deferred firing True if a request was queued, False if one
            was already queued on this protocol earlier
        """
        if protocol in self._protocols:
            return defer.succeed(False)
        r = ClientRequest({'public_key': self.wallet.encoded_public_key},
                          'public_key')
        d = protocol.add_request(r)
        d.addCallback(self._handle_exchange_response, peer, r, protocol)
        d.addErrback(self._request_failed, peer)
        self._protocols.append(protocol)
        return defer.succeed(True)

    ######### internal calls #########

    def _handle_exchange_response(self, response_dict, peer, request, protocol):
        """Store the public key a peer returned in response to our request.

        The response comes from a remote peer, so it is validated with
        explicit checks rather than ``assert`` (which ``-O`` strips).

        @raise ValueError: if the response lacks the expected identifier or
            the responding protocol was never sent a request by this object
        """
        if request.response_identifier not in response_dict:
            raise ValueError(
                "Expected %s in dict but did not get it" % request.response_identifier)
        if protocol not in self._protocols:
            raise ValueError("Responding protocol is not in our list of protocols")
        peer_pub_key = response_dict[request.response_identifier]
        self.wallet.set_public_key_for_peer(peer, peer_pub_key)
        return True

    def _request_failed(self, err, peer):
        """Log and propagate a failed exchange; swallow request cancellation."""
        if not err.check(RequestCanceledError):
            log.warning("A peer failed to send a valid public key response. Error: %s, peer: %s",
                        err.getErrorMessage(), str(peer))
            return err
        # Cancellation is expected (e.g. on shutdown); consume it silently.
        return None


class PointTraderKeyQueryHandlerFactory(object):
    """Creates PointTraderKeyQueryHandler objects that share one wallet."""
    implements(IQueryHandlerFactory)

    def __init__(self, wallet):
        self.wallet = wallet

    ######### IQueryHandlerFactory #########

    def build_query_handler(self):
        """Return a fresh query handler bound to this factory's wallet."""
        return PointTraderKeyQueryHandler(self.wallet)

    def get_primary_query_identifier(self):
        """Return the identifier of the query the built handlers answer."""
        return 'public_key'

    def get_description(self):
        """Return a human-readable description of this query handler."""
        description = ("Point Trader Address - an address for receiving payments on the "
                       "point trader testing network")
        return description


class PointTraderKeyQueryHandler(object):
    """Answers 'public_key' queries: records the peer's key in the wallet and
    replies with the wallet's own encoded public key."""
    implements(IQueryHandler)

    def __init__(self, wallet):
        self.wallet = wallet
        self.query_identifiers = ['public_key']
        # encoded public key received from the peer; None until one arrives
        self.public_key = None
        self.peer = None

    ######### IQueryHandler #########

    def register_with_request_handler(self, request_handler, peer):
        """Remember peer and register for this handler's query identifiers."""
        self.peer = peer
        request_handler.register_query_handler(self, self.query_identifiers)

    def handle_queries(self, queries):
        """Process the queries addressed to this handler.

        @return: Deferred firing with {'public_key': <our key>} when a valid
            key was received, with {} when no key was sent but one is
            already known, or failing with ValueError if the key is invalid
            or no key was ever received.
        """
        key_identifier = self.query_identifiers[0]
        if key_identifier not in queries:
            if self.public_key is not None:
                return defer.succeed({})
            log.warning("Expected a public key, but did not receive one")
            return defer.fail(Failure(ValueError("Expected but did not receive a public key")))

        candidate_key = queries[key_identifier]
        try:
            # Parsing is used only as validation; the parsed key is discarded.
            RSA.importKey(candidate_key)
        except (ValueError, TypeError, IndexError):
            log.warning("Client sent an invalid public key.")
            return defer.fail(Failure(ValueError("Client sent an invalid public key")))

        self.public_key = candidate_key
        self.wallet.set_public_key_for_peer(self.peer, self.public_key)
        log.debug("Received the client's public key: %s", str(self.public_key))
        return defer.succeed({'public_key': self.wallet.encoded_public_key})