import asyncio
import binascii
import logging
import typing
from json.decoder import JSONDecodeError
from lbry.blob_exchange.serialization import BlobResponse, BlobRequest, blob_response_types
from lbry.blob_exchange.serialization import BlobAvailabilityResponse, BlobPriceResponse, BlobDownloadResponse, \
    BlobPaymentAddressResponse

if typing.TYPE_CHECKING:
    from lbry.blob.blob_manager import BlobManager

log = logging.getLogger(__name__)

# a standard request will be 295 bytes
MAX_REQUEST_SIZE = 1200


class BlobServerProtocol(asyncio.Protocol):
    def __init__(self, loop: asyncio.AbstractEventLoop, blob_manager: 'BlobManager', lbrycrd_address: str,
                 idle_timeout: float = 30.0, transfer_timeout: float = 60.0):
        self.loop = loop
        self.blob_manager = blob_manager
        self.idle_timeout = idle_timeout
        self.transfer_timeout = transfer_timeout
        self.server_task: typing.Optional[asyncio.Task] = None
        self.started_listening = asyncio.Event(loop=self.loop)
        self.buf = b''
        self.transport: typing.Optional[asyncio.Transport] = None
        self.lbrycrd_address = lbrycrd_address
        self.peer_address_and_port: typing.Optional[str] = None
        self.started_transfer = asyncio.Event(loop=self.loop)
        self.transfer_finished = asyncio.Event(loop=self.loop)
        self.close_on_idle_task: typing.Optional[asyncio.Task] = None

    async def close_on_idle(self):
        while self.transport:
            try:
                await asyncio.wait_for(self.started_transfer.wait(), self.idle_timeout, loop=self.loop)
            except asyncio.TimeoutError:
                log.debug("closing idle connection from %s", self.peer_address_and_port)
                return self.close()
            self.started_transfer.clear()
            await self.transfer_finished.wait()
            self.transfer_finished.clear()

    def close(self):
        if self.transport:
            self.transport.close()

    def connection_made(self, transport):
        self.transport = transport
        self.close_on_idle_task = self.loop.create_task(self.close_on_idle())
        self.peer_address_and_port = "%s:%i" % self.transport.get_extra_info('peername')
        self.blob_manager.connection_manager.connection_received(self.peer_address_and_port)
        log.debug("received connection from %s", self.peer_address_and_port)

    def connection_lost(self, exc: typing.Optional[Exception]) -> None:
        log.debug("lost connection from %s", self.peer_address_and_port)
        self.blob_manager.connection_manager.incoming_connection_lost(self.peer_address_and_port)
        self.transport = None
        if self.close_on_idle_task and not self.close_on_idle_task.done():
            self.close_on_idle_task.cancel()
        self.close_on_idle_task = None

    def send_response(self, responses: typing.List[blob_response_types]):
        to_send = []
        while responses:
            to_send.append(responses.pop())
        serialized = BlobResponse(to_send).serialize()
        self.transport.write(serialized)
        self.blob_manager.connection_manager.sent_data(self.peer_address_and_port, len(serialized))

    async def handle_request(self, request: BlobRequest):
        addr = self.transport.get_extra_info('peername')
        peer_address, peer_port = addr

        responses = []
        address_request = request.get_address_request()
        if address_request:
            responses.append(BlobPaymentAddressResponse(lbrycrd_address=self.lbrycrd_address))
        availability_request = request.get_availability_request()
        if availability_request:
            responses.append(BlobAvailabilityResponse(available_blobs=list(set(
                filter(lambda blob_hash: blob_hash in self.blob_manager.completed_blob_hashes,
                       availability_request.requested_blobs)
            ))))
        price_request = request.get_price_request()
        if price_request:
            responses.append(BlobPriceResponse(blob_data_payment_rate='RATE_ACCEPTED'))
        download_request = request.get_blob_request()

        if download_request:
            blob = self.blob_manager.get_blob(download_request.requested_blob)
            if blob.get_is_verified():
                incoming_blob = {'blob_hash': blob.blob_hash, 'length': blob.length}
                responses.append(BlobDownloadResponse(incoming_blob=incoming_blob))
                self.send_response(responses)
                blob_hash = blob.blob_hash[:8]
                log.debug("send %s to %s:%i", blob_hash, peer_address, peer_port)
                self.started_transfer.set()
                try:
                    sent = await asyncio.wait_for(blob.sendfile(self), self.transfer_timeout, loop=self.loop)
                    if sent and sent > 0:
                        self.blob_manager.connection_manager.sent_data(self.peer_address_and_port, sent)
                        log.info("sent %s (%i bytes) to %s:%i", blob_hash, sent, peer_address, peer_port)
                    else:
                        self.close()
                        log.debug("stopped sending %s to %s:%i", blob_hash, peer_address, peer_port)
                        return
                except (OSError, ValueError, asyncio.TimeoutError) as err:
                    if isinstance(err, asyncio.TimeoutError):
                        log.debug("timed out sending blob %s to %s", blob_hash, peer_address)
                    else:
                        log.warning("could not read blob %s to send %s:%i", blob_hash, peer_address, peer_port)
                    self.close()
                    return
                finally:
                    self.transfer_finished.set()
            else:
                log.info("don't have %s to send %s:%i", blob.blob_hash[:8], peer_address, peer_port)
        if responses and not self.transport.is_closing():
            self.send_response(responses)

    def data_received(self, data):
        request = None
        if len(self.buf) + len(data or b'') >= MAX_REQUEST_SIZE:
            log.warning("request from %s is too large", self.peer_address_and_port)
            self.close()
            return
        if data:
            self.blob_manager.connection_manager.received_data(self.peer_address_and_port, len(data))
            _, separator, remainder = data.rpartition(b'}')
            if not separator:
                self.buf += data
                return
            try:
                request = BlobRequest.deserialize(self.buf + data)
                self.buf = remainder
            except JSONDecodeError:
                log.error("request from %s is not valid json (%i bytes): %s", self.peer_address_and_port,
                          len(self.buf + data), '' if not data else binascii.hexlify(self.buf + data).decode())
                self.close()
                return
        if not request.requests:
            log.error("failed to decode request from %s (%i bytes): %s", self.peer_address_and_port,
                      len(self.buf + data), '' if not data else binascii.hexlify(self.buf + data).decode())
            self.close()
            return
        self.loop.create_task(self.handle_request(request))


class BlobServer:
    def __init__(self, loop: asyncio.AbstractEventLoop, blob_manager: 'BlobManager', lbrycrd_address: str,
                 idle_timeout: float = 30.0, transfer_timeout: float = 60.0):
        self.loop = loop
        self.blob_manager = blob_manager
        self.server_task: typing.Optional[asyncio.Task] = None
        self.started_listening = asyncio.Event(loop=self.loop)
        self.lbrycrd_address = lbrycrd_address
        self.idle_timeout = idle_timeout
        self.transfer_timeout = transfer_timeout
        self.server_protocol_class = BlobServerProtocol

    def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'):
        if self.server_task is not None:
            raise Exception("already running")

        async def _start_server():
            server = await self.loop.create_server(
                lambda: self.server_protocol_class(self.loop, self.blob_manager, self.lbrycrd_address,
                                                   self.idle_timeout, self.transfer_timeout),
                interface, port
            )
            self.started_listening.set()
            log.info("Blob server listening on TCP %s:%i", interface, port)
            async with server:
                await server.serve_forever()

        self.server_task = self.loop.create_task(_start_server())

    def stop_server(self):
        if self.server_task:
            self.server_task.cancel()
            self.server_task = None
            log.info("Stopped blob server")