import os
import ssl
import math
import time
import codecs
import typing
import asyncio
import logging
import itertools
import collections
from bisect import bisect_right
from asyncio import Event, sleep
from collections import defaultdict
from functools import partial

from elasticsearch import ConnectionTimeout
from prometheus_client import Counter, Info, Histogram, Gauge

import lbry
from lbry.error import ResolveCensoredError, TooManyClaimSearchParametersError
from lbry.build_info import BUILD, COMMIT_HASH, DOCKER_TAG
from lbry.schema.result import Outputs
from lbry.wallet.server.db.db import HubDB
from lbry.wallet.server.websocket import AdminWebSocket
from lbry.wallet.rpc.framing import NewlineFramer

import lbry.wallet.server.version as VERSION

from lbry.wallet.rpc import (
    RPCSession, JSONRPCAutoDetect, JSONRPCConnection,
    handler_invocation, RPCError, Request, JSONRPC, Notification, Batch
)
from lbry.wallet.server import util
from lbry.wallet.server.hash import sha256, hash_to_hex_str, hex_str_to_hash, HASHX_LEN, Base58Error
from lbry.wallet.server.daemon import DaemonError
from lbry.wallet.server.db.elasticsearch import SearchIndex

if typing.TYPE_CHECKING:
    from lbry.wallet.server.env import Env
    from lbry.wallet.server.daemon import Daemon
    from lbry.wallet.server.mempool import MemPool

BAD_REQUEST = 1
DAEMON_ERROR = 2

log = logging.getLogger(__name__)


def scripthash_to_hashX(scripthash: str) -> bytes:
    try:
        bin_hash = hex_str_to_hash(scripthash)
        if len(bin_hash) == 32:
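            # only the truncated HASHX_LEN-byte prefix is used as the index key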
            return bin_hash[:HASHX_LEN]
    except Exception:
        pass
    raise RPCError(BAD_REQUEST, f'{scripthash} is not a valid script hash')


def non_negative_integer(value) -> int:
    """Return param value it is or can be converted to a non-negative
    integer, otherwise raise an RPCError."""
    try:
        value = int(value)
        if value >= 0:
            return value
    except (ValueError, TypeError):
        pass
    raise RPCError(BAD_REQUEST,
                   f'{value} should be a non-negative integer')


def assert_boolean(value) -> bool:
    """Return param value it is boolean otherwise raise an RPCError."""
    if value in (False, True):
        return value
    raise RPCError(BAD_REQUEST, f'{value} should be a boolean value')


def assert_tx_hash(value: str) -> None:
    """Raise an RPCError if the value is not a valid transaction
    hash."""
    try:
        if len(util.hex_to_bytes(value)) == 32:
            return
    except Exception:
        pass
    raise RPCError(BAD_REQUEST, f'{value} should be a transaction hash')


class Semaphores:
    """For aiorpcX's semaphore handling."""

    def __init__(self, semaphores):
        self.semaphores = semaphores
        self.acquired = []

    async def __aenter__(self):
        for semaphore in self.semaphores:
            await semaphore.acquire()
            self.acquired.append(semaphore)

    async def __aexit__(self, exc_type, exc_value, traceback):
        for semaphore in self.acquired:
            semaphore.release()


class SessionGroup:

    def __init__(self, gid: int):
        self.gid = gid
        # Concurrency per group
        self.semaphore = asyncio.Semaphore(20)


NAMESPACE = "wallet_server"
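# histogram bucket boundaries, in seconds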
HISTOGRAM_BUCKETS = (
    .005, .01, .025, .05, .075, .1, .25, .5, .75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 30.0, 60.0, float('inf')
)


class SessionManager:
    """Holds global state about all sessions."""

    version_info_metric = Info(
        'build', 'Wallet server build info (e.g. version, commit hash)', namespace=NAMESPACE
    )
    version_info_metric.info({
        'build': BUILD,
        "commit": COMMIT_HASH,
        "docker_tag": DOCKER_TAG,
        'version': lbry.__version__,
        "min_version": util.version_string(VERSION.PROTOCOL_MIN),
        "cpu_count": str(os.cpu_count())
    })
    session_count_metric = Gauge("session_count", "Number of connected client sessions", namespace=NAMESPACE,
                                      labelnames=("version",))
    request_count_metric = Counter("requests_count", "Number of requests received", namespace=NAMESPACE,
                                   labelnames=("method", "version"))
    tx_request_count_metric = Counter("requested_transaction", "Number of transactions requested", namespace=NAMESPACE)
    tx_replied_count_metric = Counter("replied_transaction", "Number of transactions responded", namespace=NAMESPACE)
    urls_to_resolve_count_metric = Counter("urls_to_resolve", "Number of urls to resolve", namespace=NAMESPACE)
    resolved_url_count_metric = Counter("resolved_url", "Number of resolved urls", namespace=NAMESPACE)

    interrupt_count_metric = Counter("interrupt", "Number of interrupted queries", namespace=NAMESPACE)
    db_operational_error_metric = Counter(
        "operational_error", "Number of queries that raised operational errors", namespace=NAMESPACE
    )
    db_error_metric = Counter(
        "internal_error", "Number of queries raising unexpected errors", namespace=NAMESPACE
    )
    executor_time_metric = Histogram(
        "executor_time", "SQLite executor times", namespace=NAMESPACE, buckets=HISTOGRAM_BUCKETS
    )
    pending_query_metric = Gauge(
        "pending_queries_count", "Number of pending and running sqlite queries", namespace=NAMESPACE
    )

    client_version_metric = Counter(
        "clients", "Number of connections received per client version",
        namespace=NAMESPACE, labelnames=("version",)
    )
    address_history_metric = Histogram(
        "address_history", "Time to fetch an address history",
        namespace=NAMESPACE, buckets=HISTOGRAM_BUCKETS
    )
    notifications_in_flight_metric = Gauge(
        "notifications_in_flight", "Count of notifications in flight",
        namespace=NAMESPACE
    )
    notifications_sent_metric = Histogram(
        "notifications_sent", "Time to send an address notification",
        namespace=NAMESPACE, buckets=HISTOGRAM_BUCKETS
    )

    def __init__(self, env: 'Env', db: HubDB, mempool: 'MemPool', history_cache, resolve_cache, resolve_outputs_cache,
                 daemon: 'Daemon', shutdown_event: asyncio.Event,
                 on_available_callback: typing.Callable[[], None], on_unavailable_callback: typing.Callable[[], None]):
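        # enforce a floor on the maximum response size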
        env.max_send = max(350000, env.max_send)
        self.env = env
        self.db = db
        self.on_available_callback = on_available_callback
        self.on_unavailable_callback = on_unavailable_callback
        self.daemon = daemon
        self.mempool = mempool
        self.shutdown_event = shutdown_event
        self.logger = util.class_logger(__name__, self.__class__.__name__)
        self.servers: typing.Dict[str, asyncio.AbstractServer] = {}
        self.sessions: typing.Dict[int, 'LBRYElectrumX'] = {}
        self.hashx_subscriptions_by_session: typing.DefaultDict[str, typing.Set[int]] = defaultdict(set)
        self.mempool_statuses = {}
        self.cur_group = SessionGroup(0)
        self.txs_sent = 0
        self.start_time = time.time()
        self.history_cache = history_cache
        self.resolve_cache = resolve_cache
        self.resolve_outputs_cache = resolve_outputs_cache
        self.notified_height: typing.Optional[int] = None
        # Cache some idea of room to avoid recounting on each subscription
        self.subs_room = 0

        self.session_event = Event()

        # Search index
        self.search_index = SearchIndex(
            self.env.es_index_prefix, self.env.database_query_timeout,
            elastic_host=env.elastic_host, elastic_port=env.elastic_port
        )

    async def _start_server(self, kind, *args, **kw_args):
        loop = asyncio.get_event_loop()

        if kind == 'RPC':
            protocol_class = LocalRPC
        else:
            protocol_class = self.env.coin.SESSIONCLS
        protocol_factory = partial(protocol_class, self, self.db,
                                   self.mempool, kind)

        host, port = args[:2]
        try:
            self.servers[kind] = await loop.create_server(protocol_factory, *args, **kw_args)
        except OSError as e:    # don't suppress CancelledError
            self.logger.error(f'{kind} server failed to listen on {host}:'
                              f'{port:d}: {e!r}')
        else:
            self.logger.info(f'{kind} server listening on {host}:{port:d}')

    async def _start_external_servers(self):
        """Start listening on TCP and SSL ports, but only if the respective
        port was given in the environment.
        """
        env = self.env
        host = env.cs_host(for_rpc=False)
        if env.tcp_port is not None:
            await self._start_server('TCP', host, env.tcp_port)
        if env.ssl_port is not None:
            sslc = ssl.SSLContext(ssl.PROTOCOL_TLS)
            sslc.load_cert_chain(env.ssl_certfile, keyfile=env.ssl_keyfile)
            await self._start_server('SSL', host, env.ssl_port, ssl=sslc)

    async def _close_servers(self, kinds):
        """Close the servers of the given kinds (TCP etc.)."""
        if kinds:
            self.logger.info('closing down {} listening servers'
                             .format(', '.join(kinds)))
        for kind in kinds:
            server = self.servers.pop(kind, None)
            if server:
                server.close()
                await server.wait_closed()

    async def _manage_servers(self):
        paused = False
        max_sessions = self.env.max_sessions
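        # resume accepting connections once the session count falls back to 95% of the maximum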
        low_watermark = int(max_sessions * 0.95)
        while True:
            await self.session_event.wait()
            self.session_event.clear()
            if not paused and len(self.sessions) >= max_sessions:
                self.on_unavailable_callback()
                self.logger.info(f'maximum sessions {max_sessions:,d} '
                                 f'reached, stopping new connections until '
                                 f'count drops to {low_watermark:,d}')
                await self._close_servers(['TCP', 'SSL'])
                paused = True
            # Start listening for incoming connections if paused and
            # session count has fallen
            if paused and len(self.sessions) <= low_watermark:
                self.on_available_callback()
                self.logger.info('resuming listening for incoming connections')
                await self._start_external_servers()
                paused = False

    def _group_map(self):
        group_map = defaultdict(list)
        for session in self.sessions.values():
            group_map[session.group].append(session)
        return group_map

    def _sub_count(self) -> int:
        return sum(s.sub_count() for s in self.sessions.values())

    def _lookup_session(self, session_id):
        try:
            session_id = int(session_id)
        except Exception:
            pass
        else:
            for session in self.sessions.values():
                if session.session_id == session_id:
                    return session
        return None

    async def _for_each_session(self, session_ids, operation):
        if not isinstance(session_ids, list):
            raise RPCError(BAD_REQUEST, 'expected a list of session IDs')

        result = []
        for session_id in session_ids:
            session = self._lookup_session(session_id)
            if session:
                result.append(await operation(session))
            else:
                result.append(f'unknown session: {session_id}')
        return result

    async def _clear_stale_sessions(self):
        """Cut off sessions that haven't done anything for 10 minutes."""
        session_timeout = self.env.session_timeout
        while True:
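            # check for stale sessions ten times per timeout period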
            await sleep(session_timeout // 10)
            stale_cutoff = time.perf_counter() - session_timeout
            stale_sessions = [session for session in self.sessions.values()
                              if session.last_recv < stale_cutoff]
            if stale_sessions:
                text = ', '.join(str(session.session_id)
                                 for session in stale_sessions)
                self.logger.info(f'closing stale connections {text}')
                # Give the sockets some time to close gracefully
                await asyncio.wait([
                    session.close(force_after=session_timeout // 10) for session in stale_sessions
                ])

            # Consolidate small groups
            group_map = self._group_map()
            groups = [group for group, sessions in group_map.items()
                      if len(sessions) <= 5]  # fixme: apply session cost here
            if len(groups) > 1:
                new_group = groups[-1]
                for group in groups:
                    for session in group_map[group]:
                        session.group = new_group

    def _get_info(self):
        """A summary of server state."""
        group_map = self._group_map()
        method_counts = collections.defaultdict(int)
        error_count = 0
        logged = 0
        paused = 0
        pending_requests = 0
        closing = 0

        for s in self.sessions.values():
            error_count += s.errors
            if s.log_me:
                logged += 1
            if not s._can_send.is_set():
                paused += 1
            pending_requests += s.count_pending_items()
            if s.is_closing():
                closing += 1
            for request, _ in s.connection._requests.values():
                method_counts[request.method] += 1
        return {
            'closing': closing,
            'daemon': self.daemon.logged_url(),
            'daemon_height': self.daemon.cached_height(),
            'db_height': self.db.db_height,
            'errors': error_count,
            'groups': len(group_map),
            'logged': logged,
            'paused': paused,
            'pid': os.getpid(),
            'peers': [],
            'requests': pending_requests,
            'method_counts': method_counts,
            'sessions': self.session_count(),
            'subs': self._sub_count(),
            'txs_sent': self.txs_sent,
            'uptime': util.formatted_time(time.time() - self.start_time),
            'version': lbry.__version__,
        }

    def _group_data(self):
        """Returned to the RPC 'groups' call."""
        result = []
        group_map = self._group_map()
        for group, sessions in group_map.items():
            result.append([group.gid,
                           len(sessions),
                           sum(s.bw_charge for s in sessions),
                           sum(s.count_pending_items() for s in sessions),
                           sum(s.txs_sent for s in sessions),
                           sum(s.sub_count() for s in sessions),
                           sum(s.recv_count for s in sessions),
                           sum(s.recv_size for s in sessions),
                           sum(s.send_count for s in sessions),
                           sum(s.send_size for s in sessions),
                           ])
        return result

    async def _electrum_and_raw_headers(self, height):
        raw_header = await self.raw_header(height)
        electrum_header = self.env.coin.electrum_header(raw_header, height)
        return electrum_header, raw_header

    async def _refresh_hsub_results(self, height):
        """Refresh the cached header subscription responses to be for height,
        and record that as notified_height.
        """
        # Paranoia: a reorg could race and leave db_height lower
        height = min(height, self.db.db_height)
        electrum, raw = await self._electrum_and_raw_headers(height)
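        # index 0 holds the deserialized header, index 1 the raw form; sessions select by subscribe_headers_raw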
        self.hsub_results = (electrum, {'hex': raw.hex(), 'height': height})
        self.notified_height = height

    # --- LocalRPC command handlers

    async def rpc_add_peer(self, real_name):
        """Add a peer.

        real_name: "bch.electrumx.cash t50001 s50002" for example
        """
        self._notify_peer(real_name)
        return f"peer '{real_name}' added"

    async def rpc_disconnect(self, session_ids):
        """Disconnect sessions.

        session_ids: array of session IDs
        """
        async def close(session):
            """Close the session's transport."""
            await session.close(force_after=2)
            return f'disconnected {session.session_id}'

        return await self._for_each_session(session_ids, close)

    async def rpc_log(self, session_ids):
        """Toggle logging of sessions.

        session_ids: array of session IDs
        """
        async def toggle_logging(session):
            """Toggle logging of the session."""
            session.toggle_logging()
            return f'log {session.session_id}: {session.log_me}'

        return await self._for_each_session(session_ids, toggle_logging)

    async def rpc_daemon_url(self, daemon_url):
        """Replace the daemon URL."""
        daemon_url = daemon_url or self.env.daemon_url
        try:
            self.daemon.set_url(daemon_url)
        except Exception as e:
            raise RPCError(BAD_REQUEST, f'an error occurred: {e!r}')
        return f'now using daemon at {self.daemon.logged_url()}'

    async def rpc_stop(self):
        """Shut down the server cleanly."""
        self.shutdown_event.set()
        return 'stopping'

    async def rpc_getinfo(self):
        """Return summary information about the server process."""
        return self._get_info()

    async def rpc_groups(self):
        """Return statistics about the session groups."""
        return self._group_data()

    async def rpc_peers(self):
        """Return a list of data about server peers."""
        return self.env.peer_hubs

    async def rpc_query(self, items, limit):
        """Return a list of data about server peers."""
        coin = self.env.coin
        db = self.db
        lines = []

        def arg_to_hashX(arg):
            try:
                script = bytes.fromhex(arg)
                lines.append(f'Script: {arg}')
                return coin.hashX_from_script(script)
            except ValueError:
                pass

            try:
                hashX = coin.address_to_hashX(arg)
            except Base58Error as e:
                lines.append(e.args[0])
                return None
            lines.append(f'Address: {arg}')
            return hashX

        for arg in items:
            hashX = arg_to_hashX(arg)
            if not hashX:
                continue
            n = None
            history = await db.limited_history(hashX, limit=limit)
            for n, (tx_hash, height) in enumerate(history):
                lines.append(f'History #{n:,d}: height {height:,d} '
                             f'tx_hash {hash_to_hex_str(tx_hash)}')
            if n is None:
                lines.append('No history found')
            n = None
            utxos = await db.all_utxos(hashX)
            for n, utxo in enumerate(utxos, start=1):
                lines.append(f'UTXO #{n:,d}: tx_hash '
                             f'{hash_to_hex_str(utxo.tx_hash)} '
                             f'tx_pos {utxo.tx_pos:,d} height '
                             f'{utxo.height:,d} value {utxo.value:,d}')
                if n == limit:
                    break
            if n is None:
                lines.append('No UTXOs found')

            balance = sum(utxo.value for utxo in utxos)
            lines.append(f'Balance: {coin.decimal_value(balance):,f} '
                         f'{coin.SHORTNAME}')

        return lines

    # async def rpc_reorg(self, count):
    #     """Force a reorg of the given number of blocks.
    #
    #     count: number of blocks to reorg
    #     """
    #     count = non_negative_integer(count)
    #     if not self.bp.force_chain_reorg(count):
    #         raise RPCError(BAD_REQUEST, 'still catching up with daemon')
    #     return f'scheduled a reorg of {count:,d} blocks'

    # --- External Interface

    async def serve(self, mempool, server_listening_event):
        """Start the RPC server if enabled.  When the event is triggered,
        start TCP and SSL servers."""
        try:
            if self.env.rpc_port is not None:
                await self._start_server('RPC', self.env.cs_host(for_rpc=True),
                                         self.env.rpc_port)
            self.logger.info(f'max session count: {self.env.max_sessions:,d}')
            self.logger.info(f'session timeout: '
                             f'{self.env.session_timeout:,d} seconds')
            self.logger.info(f'max response size {self.env.max_send:,d} bytes')
            if self.env.drop_client is not None:
                self.logger.info(f'drop clients matching: {self.env.drop_client.pattern}')
            # Start notifications; initialize hsub_results
            await mempool.start(self.db.db_height, self)
            await self.start_other()
            await self._start_external_servers()
            server_listening_event.set()
            self.on_available_callback()
            # Peer discovery should start after the external servers
            # because we connect to ourself
            await asyncio.wait([
                self._clear_stale_sessions(),
                self._manage_servers()
            ])
        except Exception as err:
            if not isinstance(err, asyncio.CancelledError):
                log.exception("hub server died")
            raise err
        finally:
            await self._close_servers(list(self.servers.keys()))
            log.info("disconnect %i sessions", len(self.sessions))
            if self.sessions:
                await asyncio.wait([
                    session.close(force_after=1) for session in self.sessions.values()
                ])
            await self.stop_other()

    async def start_other(self):
        pass

    async def stop_other(self):
        pass

    def session_count(self) -> int:
        """The number of connections that we've sent something to."""
        return len(self.sessions)

    async def daemon_request(self, method, *args):
        """Catch a DaemonError and convert it to an RPCError."""
        try:
            return await getattr(self.daemon, method)(*args)
        except DaemonError as e:
            raise RPCError(DAEMON_ERROR, f'daemon error: {e!r}') from None

    async def raw_header(self, height):
        """Return the binary header at the given height."""
        try:
            return await self.db.raw_header(height)
        except IndexError:
            raise RPCError(BAD_REQUEST, f'height {height:,d} '
                                        'out of range') from None

    async def electrum_header(self, height):
        """Return the deserialized header at the given height."""
        electrum_header, _ = await self._electrum_and_raw_headers(height)
        return electrum_header

    async def broadcast_transaction(self, raw_tx):
        hex_hash = await self.daemon.broadcast_transaction(raw_tx)
        self.txs_sent += 1
        return hex_hash

    async def limited_history(self, hashX):
        """A caching layer."""
        if hashX not in self.history_cache:
            # History DoS limit.  Each element of history is about 99
            # bytes when encoded as JSON.  This limits resource usage
            # on bloated history requests, and uses a smaller divisor
            # so large requests are logged before refusing them.
            limit = self.env.max_send // 97
            self.history_cache[hashX] = await self.db.limited_history(hashX, limit=limit)
        return self.history_cache[hashX]

    def _notify_peer(self, peer):
        notify_tasks = [
            session.send_notification('blockchain.peers.subscribe', [peer])
            for session in self.sessions.values() if session.subscribe_peers
        ]
        if notify_tasks:
            self.logger.info(f'notify {len(notify_tasks)} sessions of new peers')
            asyncio.create_task(asyncio.wait(notify_tasks))

    def add_session(self, session):
        self.sessions[id(session)] = session
        self.session_event.set()
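        # sessions are grouped into 15 minute (900 second) buckets measured from server start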
        gid = int(session.start_time - self.start_time) // 900
        if self.cur_group.gid != gid:
            self.cur_group = SessionGroup(gid)
        return self.cur_group

    def remove_session(self, session):
        """Remove a session from our sessions list if there."""
        session_id = id(session)
        for hashX in session.hashX_subs:
            sessions = self.hashx_subscriptions_by_session[hashX]
            sessions.remove(session_id)
            if not sessions:
                self.hashx_subscriptions_by_session.pop(hashX)
        self.sessions.pop(session_id)
        self.session_event.set()


class SessionBase(RPCSession):
    """Base class of ElectrumX JSON sessions.

    Each session runs its tasks in asynchronous parallelism with other
    sessions.
    """

    MAX_CHUNK_SIZE = 40960
    session_counter = itertools.count()
    request_handlers: typing.Dict[str, typing.Callable] = {}
    version = '0.5.7'

    def __init__(self, session_manager: 'LBRYSessionManager', db: 'HubDB', mempool: 'MemPool', kind: str):
        connection = JSONRPCConnection(JSONRPCAutoDetect)
        self.env = session_manager.env
        super().__init__(connection=connection)
        self.logger = util.class_logger(__name__, self.__class__.__name__)
        self.session_manager = session_manager
        self.db = db
        self.mempool = mempool
        self.kind = kind  # 'RPC', 'TCP' etc.
        self.coin = self.env.coin
        self.anon_logs = self.env.anon_logs
        self.txs_sent = 0
        self.log_me = False
        self.daemon_request = self.session_manager.daemon_request
        # Hijack the connection so we can log messages
        self._receive_message_orig = self.connection.receive_message
        self.connection.receive_message = self.receive_message


    def default_framer(self):
        return NewlineFramer(self.env.max_receive)

    def peer_address_str(self, *, for_log=True):
        """Returns the peer's IP address and port as a human-readable
        string, respecting anon logs if the output is for a log."""
        if for_log and self.anon_logs:
            return 'xx.xx.xx.xx:xx'
        return super().peer_address_str()

    def receive_message(self, message):
        if self.log_me:
            self.logger.info(f'processing {message}')
        return self._receive_message_orig(message)

    def toggle_logging(self):
        self.log_me = not self.log_me

    def connection_made(self, transport):
        """Handle an incoming client connection."""
        super().connection_made(transport)
        self.session_id = next(self.session_counter)
        context = {'conn_id': f'{self.session_id}'}
        self.logger = util.ConnectionLogger(self.logger, context)
        self.group = self.session_manager.add_session(self)
        self.session_manager.session_count_metric.labels(version=self.client_version).inc()
        peer_addr_str = self.peer_address_str()
        self.logger.info(f'{self.kind} {peer_addr_str}, '
                         f'{self.session_manager.session_count():,d} total')

    def connection_lost(self, exc):
        """Handle client disconnection."""
        super().connection_lost(exc)
        self.session_manager.remove_session(self)
        self.session_manager.session_count_metric.labels(version=self.client_version).dec()
        msg = ''
        if not self._can_send.is_set():
            msg += ' whilst paused'
        if self.send_size >= 1024*1024:
            msg += ('.  Sent {:,d} bytes in {:,d} messages'
                    .format(self.send_size, self.send_count))
        if msg:
            msg = 'disconnected' + msg
            self.logger.info(msg)

    def count_pending_items(self):
        return len(self.connection.pending_requests())

    def semaphore(self):
        return Semaphores([self.group.semaphore])

    def sub_count(self):
        return 0

    async def handle_request(self, request):
        """Handle an incoming request.  ElectrumX doesn't receive
        notifications from client sessions.
        """
        self.session_manager.request_count_metric.labels(method=request.method, version=self.client_version).inc()
        if isinstance(request, Request):
            handler = self.request_handlers.get(request.method)
            if handler is not None:
                handler = partial(handler, self)
        else:
            handler = None
        coro = handler_invocation(handler, request)()
        return await coro


class LBRYSessionManager(SessionManager):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.websocket = None
        # self.metrics = ServerLoadData()
        # self.metrics_loop = None
        self.running = False
        if self.env.websocket_host is not None and self.env.websocket_port is not None:
            self.websocket = AdminWebSocket(self)

    # async def process_metrics(self):
    #     while self.running:
    #         data = self.metrics.to_json_and_reset({
    #             'sessions': self.session_count(),
    #             'height': self.db.db_height,
    #         })
    #         if self.websocket is not None:
    #             self.websocket.send_message(data)
    #         await asyncio.sleep(1)

    async def start_other(self):
        self.running = True
        if self.websocket is not None:
            await self.websocket.start()

    async def stop_other(self):
        self.running = False
        if self.websocket is not None:
            await self.websocket.stop()


class LBRYElectrumX(SessionBase):
    """A TCP server that handles incoming Electrum connections."""

    PROTOCOL_MIN = VERSION.PROTOCOL_MIN
    PROTOCOL_MAX = VERSION.PROTOCOL_MAX
    max_errors = math.inf  # don't disconnect people for errors! let them happen...
    session_manager: LBRYSessionManager
    version = lbry.__version__
    cached_server_features = {}

    @classmethod
    def initialize_request_handlers(cls):
        cls.request_handlers.update({
            'blockchain.block.get_chunk': cls.block_get_chunk,
            'blockchain.block.get_header': cls.block_get_header,
            'blockchain.estimatefee': cls.estimatefee,
            'blockchain.relayfee': cls.relayfee,
            # 'blockchain.scripthash.get_balance': cls.scripthash_get_balance,
            'blockchain.scripthash.get_history': cls.scripthash_get_history,
            'blockchain.scripthash.get_mempool': cls.scripthash_get_mempool,
            # 'blockchain.scripthash.listunspent': cls.scripthash_listunspent,
            'blockchain.scripthash.subscribe': cls.scripthash_subscribe,
            'blockchain.transaction.broadcast': cls.transaction_broadcast,
            'blockchain.transaction.get': cls.transaction_get,
            'blockchain.transaction.get_batch': cls.transaction_get_batch,
            'blockchain.transaction.info': cls.transaction_info,
            'blockchain.transaction.get_merkle': cls.transaction_merkle,
            # 'server.add_peer': cls.add_peer,
            'server.banner': cls.banner,
            'server.payment_address': cls.payment_address,
            'server.donation_address': cls.donation_address,
            'server.features': cls.server_features_async,
            'server.peers.subscribe': cls.peers_subscribe,
            'server.version': cls.server_version,
            'blockchain.transaction.get_height': cls.transaction_get_height,
            'blockchain.claimtrie.search': cls.claimtrie_search,
            'blockchain.claimtrie.resolve': cls.claimtrie_resolve,
            'blockchain.claimtrie.getclaimbyid': cls.claimtrie_getclaimbyid,
            # 'blockchain.claimtrie.getclaimsbyids': cls.claimtrie_getclaimsbyids,
            'blockchain.block.get_server_height': cls.get_server_height,
            'mempool.get_fee_histogram': cls.mempool_compact_histogram,
            'blockchain.block.headers': cls.block_headers,
            'server.ping': cls.ping,
            'blockchain.headers.subscribe': cls.headers_subscribe_False,
            # 'blockchain.address.get_balance': cls.address_get_balance,
            'blockchain.address.get_history': cls.address_get_history,
            'blockchain.address.get_mempool': cls.address_get_mempool,
            # 'blockchain.address.listunspent': cls.address_listunspent,
            'blockchain.address.subscribe': cls.address_subscribe,
            'blockchain.address.unsubscribe': cls.address_unsubscribe,
        })

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not LBRYElectrumX.request_handlers:
            LBRYElectrumX.initialize_request_handlers()
        if not LBRYElectrumX.cached_server_features:
            LBRYElectrumX.set_server_features(self.env)
        self.subscribe_headers = False
        self.subscribe_headers_raw = False
        self.subscribe_peers = False
        self.connection.max_response_size = self.env.max_send
        self.hashX_subs = {}
        self.sv_seen = False
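        # assume the minimum protocol version until the client negotiates one via server.version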
        self.protocol_tuple = self.PROTOCOL_MIN
        self.protocol_string = None
        self.daemon = self.session_manager.daemon
        self.db: HubDB = self.session_manager.db

    @classmethod
    def protocol_min_max_strings(cls):
        return [util.version_string(ver)
                for ver in (cls.PROTOCOL_MIN, cls.PROTOCOL_MAX)]

    @classmethod
    def set_server_features(cls, env):
        """Return the server features dictionary."""
        min_str, max_str = cls.protocol_min_max_strings()
        cls.cached_server_features.update({
            'hosts': env.hosts_dict(),
            'pruning': None,
            'server_version': cls.version,
            'protocol_min': min_str,
            'protocol_max': max_str,
            'genesis_hash': env.coin.GENESIS_HASH,
            'description': env.description,
            'payment_address': env.payment_address,
            'donation_address': env.donation_address,
            'daily_fee': env.daily_fee,
            'hash_function': 'sha256',
            'trending_algorithm': 'fast_ar'
        })

    async def server_features_async(self):
        return self.cached_server_features

    @classmethod
    def server_version_args(cls):
        """The arguments to a server.version RPC call to a peer."""
        return [cls.version, cls.protocol_min_max_strings()]

    def protocol_version_string(self):
        return util.version_string(self.protocol_tuple)

    def sub_count(self):
        return len(self.hashX_subs)

    async def get_hashX_status(self, hashX: bytes):
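        # per the Electrum protocol, the status is the sha256 of the concatenated
        # '<tx_hash>:<height>:' strings for confirmed history followed by mempool history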
        mempool_history = self.mempool.transaction_summaries(hashX)
        history = ''.join(f'{hash_to_hex_str(tx_hash)}:'
                          f'{height:d}:'
                          for tx_hash, height in await self.session_manager.limited_history(hashX))
        history += ''.join(f'{hash_to_hex_str(tx.hash)}:'
                           f'{-tx.has_unconfirmed_inputs:d}:'
                           for tx in mempool_history)
        if history:
            status = sha256(history.encode()).hex()
        else:
            status = None
        return history, status, len(mempool_history) > 0

    async def send_history_notifications(self, *hashXes: bytes):
        notifications = []
        for hashX in hashXes:
            alias = self.hashX_subs[hashX]
            if len(alias) == 64:
                method = 'blockchain.scripthash.subscribe'
            else:
                method = 'blockchain.address.subscribe'
            start = time.perf_counter()
            history, status, mempool_status = await self.get_hashX_status(hashX)
            if mempool_status:
                self.session_manager.mempool_statuses[hashX] = status
            else:
                self.session_manager.mempool_statuses.pop(hashX, None)

            self.session_manager.address_history_metric.observe(time.perf_counter() - start)
            notifications.append((method, (alias, status)))

        start = time.perf_counter()
        self.session_manager.notifications_in_flight_metric.inc()
        for method, args in notifications:
            self.NOTIFICATION_COUNT.labels(method=method, version=self.client_version).inc()
        try:
            await self.send_notifications(
                Batch([Notification(method, (alias, status)) for (method, (alias, status)) in notifications])
            )
            self.session_manager.notifications_sent_metric.observe(time.perf_counter() - start)
        finally:
            self.session_manager.notifications_in_flight_metric.dec()

    # def get_metrics_or_placeholder_for_api(self, query_name):
    #     """ Do not hold on to a reference to the metrics
    #         returned by this method past an `await` or
    #         you may be working with a stale metrics object.
    #     """
    #     if self.env.track_metrics:
    #         # return self.session_manager.metrics.for_api(query_name)
    #     else:
    #         return APICallMetrics(query_name)


    # async def run_and_cache_query(self, query_name, kwargs):
    #     start = time.perf_counter()
    #     if isinstance(kwargs, dict):
    #         kwargs['release_time'] = format_release_time(kwargs.get('release_time'))
    #     try:
    #         self.session_manager.pending_query_metric.inc()
    #         return await self.db.search_index.session_query(query_name, kwargs)
    #     except ConnectionTimeout:
    #         self.session_manager.interrupt_count_metric.inc()
    #         raise RPCError(JSONRPC.QUERY_TIMEOUT, 'query timed out')
    #     finally:
    #         self.session_manager.pending_query_metric.dec()
    #         self.session_manager.executor_time_metric.observe(time.perf_counter() - start)

    async def mempool_compact_histogram(self):
        return []  # self.mempool.compact_fee_histogram()

    async def claimtrie_search(self, **kwargs):
        start = time.perf_counter()
        if 'release_time' in kwargs:
            release_time = kwargs.pop('release_time')
            release_times = release_time if isinstance(release_time, list) else [release_time]
            try:
                kwargs['release_time'] = [format_release_time(release_time) for release_time in release_times]
            except ValueError:
                pass
        try:
            self.session_manager.pending_query_metric.inc()
            if 'channel' in kwargs:
                channel_url = kwargs.pop('channel')
                _, channel_claim, _, _ = await self.db.resolve(channel_url)
                if not channel_claim or isinstance(channel_claim, (ResolveCensoredError, LookupError, ValueError)):
                    return Outputs.to_base64([], [], 0, None, None)
                kwargs['channel_id'] = channel_claim.claim_hash.hex()
            return await self.session_manager.search_index.cached_search(kwargs)
        except ConnectionTimeout:
            self.session_manager.interrupt_count_metric.inc()
            raise RPCError(JSONRPC.QUERY_TIMEOUT, 'query timed out')
        except TooManyClaimSearchParametersError as err:
            await asyncio.sleep(2)
            self.logger.warning("Got an invalid query from %s, for %s with more than %d elements.",
                                self.peer_address()[0], err.key, err.limit)
            return RPCError(1, str(err))
        finally:
            self.session_manager.pending_query_metric.dec()
            self.session_manager.executor_time_metric.observe(time.perf_counter() - start)

    async def _cached_resolve_url(self, url):
        if url not in self.session_manager.resolve_cache:
            self.session_manager.resolve_cache[url] = await self.loop.run_in_executor(self.db._executor, self.db._resolve, url)
        return self.session_manager.resolve_cache[url]

    async def claimtrie_resolve(self, *urls) -> str:
        sorted_urls = tuple(sorted(urls))
        self.session_manager.urls_to_resolve_count_metric.inc(len(sorted_urls))
        try:
            if sorted_urls in self.session_manager.resolve_outputs_cache:
                return self.session_manager.resolve_outputs_cache[sorted_urls]
            rows, extra = [], []
            for url in urls:
                if url not in self.session_manager.resolve_cache:
                    self.session_manager.resolve_cache[url] = await self._cached_resolve_url(url)
                stream, channel, repost, reposted_channel = self.session_manager.resolve_cache[url]
                if isinstance(channel, ResolveCensoredError):
                    rows.append(channel)
                    extra.append(channel.censor_row)
                elif isinstance(stream, ResolveCensoredError):
                    rows.append(stream)
                    extra.append(stream.censor_row)
                elif channel and not stream:
                    rows.append(channel)
                    # print("resolved channel", channel.name.decode())
                    if repost:
                        extra.append(repost)
                    if reposted_channel:
                        extra.append(reposted_channel)
                elif stream:
                    # print("resolved stream", stream.name.decode())
                    rows.append(stream)
                    if channel:
                        # print("and channel", channel.name.decode())
                        extra.append(channel)
                    if repost:
                        extra.append(repost)
                    if reposted_channel:
                        extra.append(reposted_channel)
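                # yield to the event loop between urls so long resolves don't starve other sessions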
                await asyncio.sleep(0)
            self.session_manager.resolve_outputs_cache[sorted_urls] = result = await self.loop.run_in_executor(
                None, Outputs.to_base64, rows, extra, 0, None, None
            )
            return result
        finally:
            self.session_manager.resolved_url_count_metric.inc(len(sorted_urls))

    async def get_server_height(self):
        return self.db.db_height

    async def transaction_get_height(self, tx_hash):
        self.assert_tx_hash(tx_hash)

        def get_height():
            v = self.db.prefix_db.tx_num.get(tx_hash)
            if v:
                return bisect_right(self.db.tx_counts, v.tx_num)
            return self.mempool.get_mempool_height(tx_hash)

        return await asyncio.get_event_loop().run_in_executor(self.db._executor, get_height)

    async def claimtrie_getclaimbyid(self, claim_id):
        rows = []
        extra = []
        stream = await self.db.fs_getclaimbyid(claim_id)
        if not stream:
            stream = LookupError(f"Could not find claim at {claim_id}")
        rows.append(stream)
        return Outputs.to_base64(rows, extra, 0, None, None)

    def assert_tx_hash(self, value):
        """Raise an RPCError if the value is not a valid transaction
        hash."""
        try:
            if len(util.hex_to_bytes(value)) == 32:
                return
        except Exception:
            pass
        raise RPCError(1, f'{value} should be a transaction hash')

    async def subscribe_headers_result(self):
        """The result of a header subscription or notification."""
        return self.session_manager.hsub_results[self.subscribe_headers_raw]

    async def _headers_subscribe(self, raw):
        """Subscribe to get headers of new blocks."""
        self.subscribe_headers_raw = assert_boolean(raw)
        self.subscribe_headers = True
        return await self.subscribe_headers_result()

    async def headers_subscribe(self):
        """Subscribe to get raw headers of new blocks."""
        return await self._headers_subscribe(True)

    async def headers_subscribe_True(self, raw=True):
        """Subscribe to get headers of new blocks."""
        return await self._headers_subscribe(raw)

    async def headers_subscribe_False(self, raw=False):
        """Subscribe to get headers of new blocks."""
        return await self._headers_subscribe(raw)

    async def add_peer(self, features):
        """Add a peer (but only if the peer resolves to the source)."""
        return await self.peer_mgr.on_add_peer(features, self.peer_address())

    async def peers_subscribe(self):
        """Return the server peers as a list of (ip, host, details) tuples."""
        self.subscribe_peers = True
        return self.env.peer_hubs

    async def address_status(self, hashX):
        """Returns an address status.

        Status is a hex string, but must be None if there is no history.
        """
        # Note history is ordered and mempool unordered in electrum-server
        # For mempool, height is -1 if it has unconfirmed inputs, otherwise 0
        _, status, has_mempool_history = await self.get_hashX_status(hashX)
        if has_mempool_history:
            self.session_manager.mempool_statuses[hashX] = status
        else:
            self.session_manager.mempool_statuses.pop(hashX, None)
        return status

    # async def hashX_listunspent(self, hashX):
    #     """Return the list of UTXOs of a script hash, including mempool
    #     effects."""
    #     utxos = await self.db.all_utxos(hashX)
    #     utxos = sorted(utxos)
    #     utxos.extend(await self.mempool.unordered_UTXOs(hashX))
    #     spends = await self.mempool.potential_spends(hashX)
    #
    #     return [{'tx_hash': hash_to_hex_str(utxo.tx_hash),
    #              'tx_pos': utxo.tx_pos,
    #              'height': utxo.height, 'value': utxo.value}
    #             for utxo in utxos
    #             if (utxo.tx_hash, utxo.tx_pos) not in spends]

    async def hashX_subscribe(self, hashX, alias):
        self.hashX_subs[hashX] = alias
        self.session_manager.hashx_subscriptions_by_session[hashX].add(id(self))
        return await self.address_status(hashX)

    async def hashX_unsubscribe(self, hashX, alias):
        self.hashX_subs.pop(hashX, None)
        sessions = self.session_manager.hashx_subscriptions_by_session[hashX]
        sessions.discard(id(self))
        if not sessions:
            self.session_manager.hashx_subscriptions_by_session.pop(hashX, None)

    def address_to_hashX(self, address):
        try:
            return self.coin.address_to_hashX(address)
        except Exception:
            pass
        raise RPCError(BAD_REQUEST, f'{address} is not a valid address')

    # async def address_get_balance(self, address):
    #     """Return the confirmed and unconfirmed balance of an address."""
    #     hashX = self.address_to_hashX(address)
    #     return await self.get_balance(hashX)

    async def address_get_history(self, address):
        """Return the confirmed and unconfirmed history of an address."""
        hashX = self.address_to_hashX(address)
        return await self.confirmed_and_unconfirmed_history(hashX)

    async def address_get_mempool(self, address):
        """Return the mempool transactions touching an address."""
        hashX = self.address_to_hashX(address)
        return self.unconfirmed_history(hashX)

    # async def address_listunspent(self, address):
    #     """Return the list of UTXOs of an address."""
    #     hashX = self.address_to_hashX(address)
    #     return await self.hashX_listunspent(hashX)

    async def address_subscribe(self, *addresses):
        """Subscribe to an address.

        address: the address to subscribe to"""
        if len(addresses) > 1000:
            raise RPCError(BAD_REQUEST, f'too many addresses in subscription request: {len(addresses)}')
        results = []
        for address in addresses:
            results.append(await self.hashX_subscribe(self.address_to_hashX(address), address))
            await asyncio.sleep(0)
        return results

    async def address_unsubscribe(self, address):
        """Unsubscribe an address.

        address: the address to unsubscribe"""
        hashX = self.address_to_hashX(address)
        return await self.hashX_unsubscribe(hashX, address)

    # async def get_balance(self, hashX):
    #     utxos = await self.db.all_utxos(hashX)
    #     confirmed = sum(utxo.value for utxo in utxos)
    #     unconfirmed = await self.mempool.balance_delta(hashX)
    #     return {'confirmed': confirmed, 'unconfirmed': unconfirmed}

    # async def scripthash_get_balance(self, scripthash):
    #     """Return the confirmed and unconfirmed balance of a scripthash."""
    #     hashX = scripthash_to_hashX(scripthash)
    #     return await self.get_balance(hashX)

    def unconfirmed_history(self, hashX):
        # Note unconfirmed history is unordered in electrum-server
        # height is -1 if it has unconfirmed inputs, otherwise 0
        return [{'tx_hash': hash_to_hex_str(tx.hash),
                 'height': -tx.has_unconfirmed_inputs,
                 'fee': tx.fee}
                for tx in self.mempool.transaction_summaries(hashX)]

    async def confirmed_and_unconfirmed_history(self, hashX):
        # Note history is ordered but unconfirmed is unordered in e-s
        history = await self.session_manager.limited_history(hashX)
        conf = [{'tx_hash': hash_to_hex_str(tx_hash), 'height': height}
                for tx_hash, height in history]
        return conf + self.unconfirmed_history(hashX)

    async def scripthash_get_history(self, scripthash):
        """Return the confirmed and unconfirmed history of a scripthash."""
        hashX = scripthash_to_hashX(scripthash)
        return await self.confirmed_and_unconfirmed_history(hashX)

    async def scripthash_get_mempool(self, scripthash):
        """Return the mempool transactions touching a scripthash."""
        hashX = scripthash_to_hashX(scripthash)
        return self.unconfirmed_history(hashX)

    # async def scripthash_listunspent(self, scripthash):
    #     """Return the list of UTXOs of a scripthash."""
    #     hashX = scripthash_to_hashX(scripthash)
    #     return await self.hashX_listunspent(hashX)

    async def scripthash_subscribe(self, scripthash):
        """Subscribe to a script hash.

        scripthash: the SHA256 hash of the script to subscribe to"""
        hashX = scripthash_to_hashX(scripthash)
        return await self.hashX_subscribe(hashX, scripthash)

    async def _merkle_proof(self, cp_height, height):
        max_height = self.db.db_height
        if not height <= cp_height <= max_height:
            raise RPCError(BAD_REQUEST,
                           f'require header height {height:,d} <= '
                           f'cp_height {cp_height:,d} <= '
                           f'chain height {max_height:,d}')
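        # the checkpoint merkle tree covers headers 0..cp_height, hence cp_height + 1 leaves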
        branch, root = await self.db.header_branch_and_root(cp_height + 1, height)
        return {
            'branch': [hash_to_hex_str(elt) for elt in branch],
            'root': hash_to_hex_str(root),
        }

    async def block_headers(self, start_height, count, cp_height=0, b64=False):
        """Return count concatenated block headers as hex for the main chain;
        starting at start_height.

        start_height and count must be non-negative integers.  At most
        MAX_CHUNK_SIZE headers will be returned.
        """
        start_height = non_negative_integer(start_height)
        count = non_negative_integer(count)
        cp_height = non_negative_integer(cp_height)

        max_size = self.MAX_CHUNK_SIZE
        count = min(count, max_size)
        headers, count = await self.db.read_headers(start_height, count)

        if b64:
            headers = self.db.encode_headers(start_height, count, headers)
        else:
            headers = headers.hex()
        result = {
            'base64' if b64 else 'hex': headers,
            'count': count,
            'max': max_size
        }
        if count and cp_height:
            last_height = start_height + count - 1
            result.update(await self._merkle_proof(cp_height, last_height))
        return result

    async def block_get_chunk(self, index):
        """Return a chunk of block headers as a hexadecimal string.

        index: the chunk index"""
        index = non_negative_integer(index)
        size = self.coin.CHUNK_SIZE
        start_height = index * size
        headers, _ = await self.db.read_headers(start_height, size)
        return headers.hex()

    async def block_get_header(self, height):
        """The deserialized header at a given height.

        height: the header's height"""
        height = non_negative_integer(height)
        return await self.session_manager.electrum_header(height)

    def is_tor(self):
        """Try to detect if the connection is to a tor hidden service we are
        running."""
        peername = self.peer_mgr.proxy_peername()
        if not peername:
            return False
        peer_address = self.peer_address()
        return peer_address and peer_address[0] == peername[0]

    async def replaced_banner(self, banner):
        network_info = await self.daemon_request('getnetworkinfo')
        ni_version = network_info['version']
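        # the daemon packs its version as major*1000000 + minor*10000 + revision*100 + build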
        major, minor = divmod(ni_version, 1000000)
        minor, revision = divmod(minor, 10000)
        revision //= 100
        daemon_version = f'{major:d}.{minor:d}.{revision:d}'
        for pair in [
            ('$SERVER_VERSION', self.version),
            ('$DAEMON_VERSION', daemon_version),
            ('$DAEMON_SUBVERSION', network_info['subversion']),
            ('$PAYMENT_ADDRESS', self.env.payment_address),
            ('$DONATION_ADDRESS', self.env.donation_address),
        ]:
            banner = banner.replace(*pair)
        return banner

    async def payment_address(self):
        """Return the payment address as a string, empty if there is none."""
        return self.env.payment_address

    async def donation_address(self):
        """Return the donation address as a string, empty if there is none."""
        return self.env.donation_address

    async def banner(self):
        """Return the server banner text."""
        banner = f'You are connected to an {self.version} server.'
        banner_file = self.env.banner_file
        if banner_file:
            try:
                with codecs.open(banner_file, 'r', 'utf-8') as f:
                    banner = f.read()
            except Exception as e:
                self.logger.error(f'reading banner file {banner_file}: {e!r}')
            else:
                banner = await self.replaced_banner(banner)

        return banner

    async def relayfee(self):
        """The minimum fee a low-priority tx must pay in order to be accepted
        to the daemon's memory pool."""
        return await self.daemon_request('relayfee')

    async def estimatefee(self, number):
        """The estimated transaction fee per kilobyte to be paid for a
        transaction to be included within a certain number of blocks.

        number: the number of blocks
        """
        number = non_negative_integer(number)
        return await self.daemon_request('estimatefee', number)

    async def ping(self):
        """Serves as a connection keep-alive mechanism and for the client to
        confirm the server is still responding.
        """
        return None

    async def server_version(self, client_name='', protocol_version=None):
        """Returns the server version as a string.

        client_name: a string identifying the client
        protocol_version: the protocol version spoken by the client
        """
        if self.protocol_string is not None:
            return self.version, self.protocol_string
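        # protocol 1.4 and later allow a client to call server.version only once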
        if self.sv_seen and self.protocol_tuple >= (1, 4):
            raise RPCError(BAD_REQUEST, 'server.version already sent')
        self.sv_seen = True

        if client_name:
            client_name = str(client_name)
            if self.env.drop_client is not None and \
                    self.env.drop_client.match(client_name):
                self.close_after_send = True
                raise RPCError(BAD_REQUEST, f'unsupported client: {client_name}')
            if self.client_version != client_name[:17]:
                self.session_manager.session_count_metric.labels(version=self.client_version).dec()
                self.client_version = client_name[:17]
                self.session_manager.session_count_metric.labels(version=self.client_version).inc()
        self.session_manager.client_version_metric.labels(version=self.client_version).inc()

        # Find the highest common protocol version.  Disconnect if
        # that protocol version is unsupported.
        ptuple, client_min = util.protocol_version(protocol_version, self.PROTOCOL_MIN, self.PROTOCOL_MAX)
        if ptuple is None:
            ptuple, client_min = util.protocol_version(protocol_version, (1, 1, 0), (1, 4, 0))
            if ptuple is None:
                self.close_after_send = True
                raise RPCError(BAD_REQUEST, f'unsupported protocol version: {protocol_version}')

        self.protocol_tuple = ptuple
        self.protocol_string = util.version_string(ptuple)
        return self.version, self.protocol_string

    async def transaction_broadcast(self, raw_tx):
        """Broadcast a raw transaction to the network.

        raw_tx: the raw transaction as a hexadecimal string"""
        # This returns errors as JSON RPC errors, as is natural
        try:
            hex_hash = await self.session_manager.broadcast_transaction(raw_tx)
            self.txs_sent += 1
            # self.mempool.wakeup.set()
            # await asyncio.sleep(0.5)
            self.logger.info(f'sent tx: {hex_hash}')
            return hex_hash
        except DaemonError as e:
            error, = e.args
            message = error['message']
            self.logger.info(f'error sending transaction: {message}')
            raise RPCError(BAD_REQUEST, 'the transaction was rejected by '
                                        f'network rules.\n\n{message}\n[{raw_tx}]')

    async def transaction_info(self, tx_hash: str):
        return (await self.transaction_get_batch(tx_hash))[tx_hash]

    async def transaction_get_batch(self, *tx_hashes):
        self.session_manager.tx_request_count_metric.inc(len(tx_hashes))
        if len(tx_hashes) > 100:
            raise RPCError(BAD_REQUEST, f'too many tx hashes in request: {len(tx_hashes)}')
        for tx_hash in tx_hashes:
            assert_tx_hash(tx_hash)
        batch_result = await self.db.get_transactions_and_merkles(tx_hashes)
        needed_merkles = {}

        for tx_hash in tx_hashes:
            if tx_hash in batch_result and batch_result[tx_hash][0]:
                continue
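            # tx hashes are displayed big-endian but keyed little-endian, so reverse the bytes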
            tx_hash_bytes = bytes.fromhex(tx_hash)[::-1]
            mempool_tx = self.mempool.txs.get(tx_hash_bytes, None)
            if mempool_tx:
                raw_tx, block_hash = mempool_tx.raw_tx.hex(), None
            else:
                tx_info = await self.daemon_request('getrawtransaction', tx_hash, 1)
                raw_tx = tx_info['hex']
                block_hash = tx_info.get('blockhash')
            if block_hash:
                block = await self.daemon.deserialised_block(block_hash)
                height = block['height']
                try:
                    pos = block['tx'].index(tx_hash)
                except ValueError:
                    raise RPCError(BAD_REQUEST, f'tx hash {tx_hash} not in '
                                                f'block {block_hash} at height {height:,d}')
                needed_merkles[tx_hash] = raw_tx, block['tx'], pos, height
            else:
                batch_result[tx_hash] = [raw_tx, {'block_height': -1}]

        if needed_merkles:
            for tx_hash, (raw_tx, block_txs, pos, block_height) in needed_merkles.items():
                batch_result[tx_hash] = raw_tx, {
                    'merkle': self._get_merkle_branch(block_txs, pos),
                    'pos': pos,
                    'block_height': block_height
                }
                await asyncio.sleep(0)  # heavy call, give other tasks a chance
        self.session_manager.tx_replied_count_metric.inc(len(tx_hashes))
        return batch_result

    async def transaction_get(self, tx_hash, verbose=False):
        """Return the serialized raw transaction given its hash

        tx_hash: the transaction hash as a hexadecimal string
        verbose: passed on to the daemon
        """
        assert_tx_hash(tx_hash)
        if verbose not in (True, False):
            raise RPCError(BAD_REQUEST, f'"verbose" must be a boolean')

        return await self.daemon_request('getrawtransaction', tx_hash, int(verbose))

    def _get_merkle_branch(self, tx_hashes, tx_pos):
        """Return a merkle branch to a transaction.

        tx_hashes: ordered list of hex strings of tx hashes in a block
        tx_pos: index of transaction in tx_hashes to create branch for
        """
        hashes = [hex_str_to_hash(hash) for hash in tx_hashes]
        branch, root = self.db.merkle.branch_and_root(hashes, tx_pos)
        branch = [hash_to_hex_str(hash) for hash in branch]
        return branch

    async def transaction_merkle(self, tx_hash, height):
        """Return the markle branch to a confirmed transaction given its hash
        and height.

        tx_hash: the transaction hash as a hexadecimal string
        height: the height of the block it is in
        """
        assert_tx_hash(tx_hash)
        result = await self.transaction_get_batch(tx_hash)
        if tx_hash not in result or result[tx_hash][1]['block_height'] <= 0:
            raise RPCError(BAD_REQUEST, f'tx hash {tx_hash} not in '
                                        f'block at height {height:,d}')
        return result[tx_hash][1]


class LocalRPC(SessionBase):
    """A local TCP RPC server session."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.client = 'RPC'
        self.connection._max_response_size = 0

    def protocol_version_string(self):
        return 'RPC'


def get_from_possible_keys(dictionary, *keys):
    for key in keys:
        if key in dictionary:
            return dictionary[key]


def format_release_time(release_time):
    # round the release time up to a multiple of 360 seconds so it caches better
    # also set a default so we don't show claims in the future
    def roundup_time(number, factor=360):
        return int(1 + int(number / factor)) * factor
    if isinstance(release_time, str) and len(release_time) > 0:
        time_digits = ''.join(filter(str.isdigit, release_time))
        time_prefix = release_time[:-len(time_digits)]
        return time_prefix + str(roundup_time(int(time_digits)))
    elif isinstance(release_time, int):
        return roundup_time(release_time)