diff --git a/scribe/env.py b/scribe/env.py index d1c6097..9fe231e 100644 --- a/scribe/env.py +++ b/scribe/env.py @@ -31,9 +31,9 @@ class Env: pass def __init__(self, db_dir=None, daemon_url=None, host=None, rpc_host=None, elastic_host=None, - elastic_port=None, loop_policy=None, max_query_workers=None, websocket_host=None, websocket_port=None, + elastic_port=None, loop_policy=None, max_query_workers=None, chain=None, es_index_prefix=None, cache_MB=None, reorg_limit=None, tcp_port=None, - udp_port=None, ssl_port=None, ssl_certfile=None, ssl_keyfile=None, rpc_port=None, + udp_port=None, ssl_port=None, ssl_certfile=None, ssl_keyfile=None, prometheus_port=None, max_subscriptions=None, banner_file=None, anon_logs=None, log_sessions=None, allow_lan_udp=None, cache_all_tx_hashes=None, cache_all_claim_txos=None, country=None, payment_address=None, donation_address=None, max_send=None, max_receive=None, max_sessions=None, @@ -56,8 +56,6 @@ class Env: ) self.obsolete(['UTXO_MB', 'HIST_MB', 'NETWORK']) self.max_query_workers = max_query_workers if max_query_workers is not None else self.integer('MAX_QUERY_WORKERS', 4) - self.websocket_host = websocket_host if websocket_host is not None else self.default('WEBSOCKET_HOST', self.host) - self.websocket_port = websocket_port if websocket_port is not None else self.integer('WEBSOCKET_PORT', None) if chain == 'mainnet': self.coin = LBCMainNet elif chain == 'testnet': @@ -74,7 +72,6 @@ class Env: if self.ssl_port: self.ssl_certfile = ssl_certfile if ssl_certfile is not None else self.required('SSL_CERTFILE') self.ssl_keyfile = ssl_keyfile if ssl_keyfile is not None else self.required('SSL_KEYFILE') - self.rpc_port = rpc_port if rpc_port is not None else self.integer('RPC_PORT', 8000) self.prometheus_port = prometheus_port if prometheus_port is not None else self.integer('PROMETHEUS_PORT', 0) self.max_subscriptions = max_subscriptions if max_subscriptions is not None else self.integer('MAX_SUBSCRIPTIONS', 10000) self.banner_file = banner_file if banner_file is not None else self.default('BANNER_FILE', None) @@ -307,14 +304,7 @@ class Env: help='TCP port to listen on for hub server') parser.add_argument('--udp_port', type=int, default=cls.integer('UDP_PORT', 50001), help='UDP port to listen on for hub server') - parser.add_argument('--rpc_host', default=cls.default('RPC_HOST', 'localhost'), type=str, - help='Listening interface for admin rpc') - parser.add_argument('--rpc_port', default=cls.integer('RPC_PORT', 8000), type=int, - help='Listening port for admin rpc') - parser.add_argument('--websocket_host', default=cls.default('WEBSOCKET_HOST', 'localhost'), type=str, - help='Listening interface for websocket') - parser.add_argument('--websocket_port', default=cls.integer('WEBSOCKET_PORT', None), type=int, - help='Listening port for websocket') + parser.add_argument('--ssl_port', default=cls.integer('SSL_PORT', None), type=int, help='SSL port to listen on for hub server') @@ -376,12 +366,12 @@ class Env: def from_arg_parser(cls, args): return cls( db_dir=args.db_dir, daemon_url=args.daemon_url, db_max_open_files=args.db_max_open_files, - host=args.host, rpc_host=args.rpc_host, elastic_host=args.elastic_host, elastic_port=args.elastic_port, - loop_policy=args.loop_policy, max_query_workers=args.max_query_workers, websocket_host=args.websocket_host, - websocket_port=args.websocket_port, chain=args.chain, es_index_prefix=args.es_index_prefix, + host=args.host, elastic_host=args.elastic_host, elastic_port=args.elastic_port, + loop_policy=args.loop_policy, max_query_workers=args.max_query_workers, + chain=args.chain, es_index_prefix=args.es_index_prefix, cache_MB=args.cache_MB, reorg_limit=args.reorg_limit, tcp_port=args.tcp_port, udp_port=args.udp_port, ssl_port=args.ssl_port, ssl_certfile=args.ssl_certfile, - ssl_keyfile=args.ssl_keyfile, rpc_port=args.rpc_port, prometheus_port=args.prometheus_port, + ssl_keyfile=args.ssl_keyfile, prometheus_port=args.prometheus_port, max_subscriptions=args.max_subscriptions, banner_file=args.banner_file, anon_logs=args.anon_logs, log_sessions=None, allow_lan_udp=args.allow_lan_udp, cache_all_tx_hashes=args.cache_all_tx_hashes, cache_all_claim_txos=args.cache_all_claim_txos, diff --git a/scribe/hub/session.py b/scribe/hub/session.py index a204b3d..591704c 100644 --- a/scribe/hub/session.py +++ b/scribe/hub/session.py @@ -8,7 +8,6 @@ import asyncio import logging import itertools import collections -import inspect from bisect import bisect_right from asyncio import Event, sleep from collections import defaultdict, namedtuple @@ -44,83 +43,6 @@ SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args ' 'required_names other_names') -@lru_cache(256) -def signature_info(func): - params = inspect.signature(func).parameters - min_args = max_args = 0 - required_names = [] - other_names = [] - no_names = False - for p in params.values(): - if p.kind == p.POSITIONAL_OR_KEYWORD: - max_args += 1 - if p.default is p.empty: - min_args += 1 - required_names.append(p.name) - else: - other_names.append(p.name) - elif p.kind == p.KEYWORD_ONLY: - other_names.append(p.name) - elif p.kind == p.VAR_POSITIONAL: - max_args = None - elif p.kind == p.VAR_KEYWORD: - other_names = any - elif p.kind == p.POSITIONAL_ONLY: - max_args += 1 - if p.default is p.empty: - min_args += 1 - no_names = True - - if no_names: - other_names = None - - return SignatureInfo(min_args, max_args, required_names, other_names) - - -def handler_invocation(handler, request): - method, args = request.method, request.args - if handler is None: - raise RPCError(JSONRPC.METHOD_NOT_FOUND, - f'unknown method "{method}"') - - # We must test for too few and too many arguments. How - # depends on whether the arguments were passed as a list or as - # a dictionary. - info = signature_info(handler) - if isinstance(args, (tuple, list)): - if len(args) < info.min_args: - s = '' if len(args) == 1 else 's' - raise RPCError.invalid_args( - f'{len(args)} argument{s} passed to method ' - f'"{method}" but it requires {info.min_args}') - if info.max_args is not None and len(args) > info.max_args: - s = '' if len(args) == 1 else 's' - raise RPCError.invalid_args( - f'{len(args)} argument{s} passed to method ' - f'{method} taking at most {info.max_args}') - return partial(handler, *args) - - # Arguments passed by name - if info.other_names is None: - raise RPCError.invalid_args(f'method "{method}" cannot ' - f'be called with named arguments') - - missing = set(info.required_names).difference(args) - if missing: - s = '' if len(missing) == 1 else 's' - missing = ', '.join(sorted(f'"{name}"' for name in missing)) - raise RPCError.invalid_args(f'method "{method}" requires ' - f'parameter{s} {missing}') - - if info.other_names is not any: - excess = set(args).difference(info.required_names) - excess = excess.difference(info.other_names) - if excess: - s = '' if len(excess) == 1 else 's' - excess = ', '.join(sorted(f'"{name}"' for name in excess)) - raise RPCError.invalid_args(f'method "{method}" does not ' - f'take parameter{s} {excess}') - return partial(handler, **args) def scripthash_to_hashX(scripthash: str) -> bytes: @@ -284,10 +206,10 @@ class SessionManager: async def _start_server(self, kind, *args, **kw_args): loop = asyncio.get_event_loop() - if kind == 'RPC': - protocol_class = LocalRPC - else: + if kind == 'TCP': protocol_class = LBRYElectrumX + else: + raise ValueError(kind) protocol_factory = partial(protocol_class, self, self.db, self.mempool, kind) @@ -608,9 +530,6 @@ class SessionManager: """Start the RPC server if enabled. When the event is triggered, start TCP and SSL servers.""" try: - if self.env.rpc_port is not None: - await self._start_server('RPC', self.env.cs_host(for_rpc=True), - self.env.rpc_port) self.logger.info(f'max session count: {self.env.max_sessions:,d}') self.logger.info(f'session timeout: ' f'{self.env.session_timeout:,d} seconds') @@ -717,18 +636,18 @@ class SessionManager: self.session_event.set() -class SessionBase(asyncio.Protocol): - """Base class of ElectrumX JSON sessions. +class LBRYElectrumX(asyncio.Protocol): + """A TCP server that handles incoming Electrum connections.""" - Each session runs its tasks in asynchronous parallelism with other - sessions. - """ + PROTOCOL_MIN = PROTOCOL_MIN + PROTOCOL_MAX = PROTOCOL_MAX + max_errors = math.inf # don't disconnect people for errors! let them happen... + version = __version__ + cached_server_features = {} MAX_CHUNK_SIZE = 40960 session_counter = itertools.count() request_handlers: typing.Dict[str, typing.Callable] = {} - version = '0.5.7' - RESPONSE_TIMES = Histogram("response_time", "Response times", namespace=NAMESPACE, labelnames=("method", "version"), buckets=HISTOGRAM_BUCKETS) NOTIFICATION_COUNT = Counter("notification", "Number of notifications sent (for subscriptions)", @@ -785,9 +704,193 @@ class SessionBase(asyncio.Protocol): self.txs_sent = 0 self.log_me = False self.daemon_request = self.session_manager.daemon_request - # Hijack the connection so we can log messages - self._receive_message_orig = self.connection.receive_message - self.connection.receive_message = self.receive_message + + if not LBRYElectrumX.cached_server_features: + LBRYElectrumX.set_server_features(self.env) + self.subscribe_headers = False + self.subscribe_headers_raw = False + self.subscribe_peers = False + self.connection.max_response_size = self.env.max_send + self.hashX_subs = {} + self.sv_seen = False + self.protocol_tuple = self.PROTOCOL_MIN + self.protocol_string = None + self.daemon = self.session_manager.daemon + self.db: 'HubDB' = self.session_manager.db + + def data_received(self, framed_message): + """Called by asyncio when a message comes in.""" + self.last_packet_received = time.perf_counter() + if self.verbosity >= 4: + self.logger.debug(f'Received framed message {framed_message}') + self.recv_size += len(framed_message) + self.framer.received_bytes(framed_message) + + def pause_writing(self): + """Transport calls when the send buffer is full.""" + if not self.is_closing(): + self._can_send.clear() + self.transport.pause_reading() + + def resume_writing(self): + """Transport calls when the send buffer has room.""" + if not self._can_send.is_set(): + self._can_send.set() + self.transport.resume_reading() + + def connection_made(self, transport): + """Handle an incoming client connection.""" + self.transport = transport + # This would throw if called on a closed SSL transport. Fixed + # in asyncio in Python 3.6.1 and 3.5.4 + peer_address = transport.get_extra_info('peername') + # If the Socks proxy was used then _address is already set to + # the remote address + if self._address: + self._proxy_address = peer_address + else: + self._address = peer_address + self._pm_task = self.loop.create_task(self._receive_messages()) + + self.session_id = next(self.session_counter) + context = {'conn_id': f'{self.session_id}'} + self.logger = logging.getLogger(__name__) # util.ConnectionLogger(self.logger, context) + self.group = self.session_manager.add_session(self) + self.session_manager.session_count_metric.labels(version=self.client_version).inc() + peer_addr_str = self.peer_address_str() + self.logger.info(f'{self.kind} {peer_addr_str}, ' + f'{self.session_manager.session_count():,d} total') + + def connection_lost(self, exc): + """Handle client disconnection.""" + self.connection.raise_pending_requests(exc) + self._address = None + self.transport = None + self._task_group.cancel() + if self._pm_task: + self._pm_task.cancel() + # Release waiting tasks + self._can_send.set() + + self.session_manager.remove_session(self) + self.session_manager.session_count_metric.labels(version=self.client_version).dec() + msg = '' + if not self._can_send.is_set(): + msg += ' whilst paused' + if self.send_size >= 1024 * 1024: + msg += ('. Sent {:,d} bytes in {:,d} messages' + .format(self.send_size, self.send_count)) + if msg: + msg = 'disconnected' + msg + self.logger.info(msg) + + def default_framer(self): + return NewlineFramer(self.env.max_receive) + + def peer_address_str(self, *, for_log=True): + """Returns the peer's IP address and port as a human-readable + string, respecting anon logs if the output is for a log.""" + if for_log and self.anon_logs: + return 'xx.xx.xx.xx:xx' + if not self._address: + return 'unknown' + ip_addr_str, port = self._address[:2] + if ':' in ip_addr_str: + return f'[{ip_addr_str}]:{port}' + else: + return f'{ip_addr_str}:{port}' + + def receive_message(self, message): + if self.log_me: + self.logger.info(f'processing {message}') + return self._receive_message_orig(message) + + def toggle_logging(self): + self.log_me = not self.log_me + + def count_pending_items(self): + return len(self.connection.pending_requests()) + + def semaphore(self): + return Semaphores([self.group.semaphore]) + + async def handle_request(self, request): + """Handle an incoming request. ElectrumX doesn't receive + notifications from client sessions. + """ + self.session_manager.request_count_metric.labels(method=request.method, version=self.client_version).inc() + + if isinstance(request, Request): + method = request.method + if method == 'blockchain.block.get_chunk': + coro = self.block_get_chunk + elif method == 'blockchain.block.get_header': + coro = self.block_get_header + elif method == 'blockchain.block.get_server_height': + coro = self.get_server_height + elif method == 'blockchain.scripthash.get_history': + coro = self.scripthash_get_history + elif method == 'blockchain.scripthash.get_mempool': + coro = self.scripthash_get_mempool + elif method == 'blockchain.scripthash.subscribe': + coro = self.scripthash_subscribe + elif method == 'blockchain.transaction.broadcast': + coro = self.transaction_broadcast + elif method == 'blockchain.transaction.get': + coro = self.transaction_get + elif method == 'blockchain.transaction.get_batch': + coro = self.transaction_get_batch + elif method == 'blockchain.transaction.info': + coro = self.transaction_info + elif method == 'blockchain.transaction.get_merkle': + coro = self.transaction_merkle + elif method == 'blockchain.transaction.get_height': + coro = self.transaction_get_height + elif method == 'blockchain.block.headers': + coro = self.block_headers + elif method == 'server.banner': + coro = self.banner + elif method == 'server.payment_address': + coro = self.payment_address + elif method == 'server.donation_address': + coro = self.donation_address + elif method == 'server.features': + coro = self.server_features_async + elif method == 'server.peers.subscribe': + coro = self.peers_subscribe + elif method == 'server.version': + coro = self.server_version + elif method == 'blockchain.claimtrie.search': + coro = self.claimtrie_search + elif method == 'blockchain.claimtrie.resolve': + coro = self.claimtrie_resolve + elif method == 'blockchain.claimtrie.getclaimbyid': + coro = self.claimtrie_getclaimbyid + elif method == 'mempool.get_fee_histogram': + coro = self.mempool_compact_histogram + elif method == 'server.ping': + coro = self.ping + elif method == 'blockchain.headers.subscribe': + coro = self.headers_subscribe_False + elif method == 'blockchain.address.get_history': + coro = self.address_get_history + elif method == 'blockchain.address.get_mempool': + coro = self.address_get_mempool + elif method == 'blockchain.address.subscribe': + coro = self.address_subscribe + elif method == 'blockchain.address.unsubscribe': + coro = self.address_unsubscribe + elif method == 'blockchain.estimatefee': + coro = self.estimatefee + elif method == 'blockchain.relayfee': + coro = self.relayfee + else: + raise ValueError + else: + raise ValueError + if isinstance(request.args, dict): + return await coro(**request.args) + return await coro(*request.args) async def _limited_wait(self, secs): try: @@ -962,187 +1065,6 @@ class SessionBase(asyncio.Protocol): """ return BatchRequest(self, raise_errors) - def data_received(self, framed_message): - """Called by asyncio when a message comes in.""" - self.last_packet_received = time.perf_counter() - if self.verbosity >= 4: - self.logger.debug(f'Received framed message {framed_message}') - self.recv_size += len(framed_message) - self.framer.received_bytes(framed_message) - - def pause_writing(self): - """Transport calls when the send buffer is full.""" - if not self.is_closing(): - self._can_send.clear() - self.transport.pause_reading() - - def resume_writing(self): - """Transport calls when the send buffer has room.""" - if not self._can_send.is_set(): - self._can_send.set() - self.transport.resume_reading() - - def default_framer(self): - return NewlineFramer(self.env.max_receive) - - def peer_address_str(self, *, for_log=True): - """Returns the peer's IP address and port as a human-readable - string, respecting anon logs if the output is for a log.""" - if for_log and self.anon_logs: - return 'xx.xx.xx.xx:xx' - if not self._address: - return 'unknown' - ip_addr_str, port = self._address[:2] - if ':' in ip_addr_str: - return f'[{ip_addr_str}]:{port}' - else: - return f'{ip_addr_str}:{port}' - - def receive_message(self, message): - if self.log_me: - self.logger.info(f'processing {message}') - return self._receive_message_orig(message) - - def toggle_logging(self): - self.log_me = not self.log_me - - def connection_made(self, transport): - """Handle an incoming client connection.""" - self.transport = transport - # This would throw if called on a closed SSL transport. Fixed - # in asyncio in Python 3.6.1 and 3.5.4 - peer_address = transport.get_extra_info('peername') - # If the Socks proxy was used then _address is already set to - # the remote address - if self._address: - self._proxy_address = peer_address - else: - self._address = peer_address - self._pm_task = self.loop.create_task(self._receive_messages()) - - self.session_id = next(self.session_counter) - context = {'conn_id': f'{self.session_id}'} - self.logger = logging.getLogger(__name__) #util.ConnectionLogger(self.logger, context) - self.group = self.session_manager.add_session(self) - self.session_manager.session_count_metric.labels(version=self.client_version).inc() - peer_addr_str = self.peer_address_str() - self.logger.info(f'{self.kind} {peer_addr_str}, ' - f'{self.session_manager.session_count():,d} total') - - def connection_lost(self, exc): - """Handle client disconnection.""" - self.connection.raise_pending_requests(exc) - self._address = None - self.transport = None - self._task_group.cancel() - if self._pm_task: - self._pm_task.cancel() - # Release waiting tasks - self._can_send.set() - - self.session_manager.remove_session(self) - self.session_manager.session_count_metric.labels(version=self.client_version).dec() - msg = '' - if not self._can_send.is_set(): - msg += ' whilst paused' - if self.send_size >= 1024*1024: - msg += ('. Sent {:,d} bytes in {:,d} messages' - .format(self.send_size, self.send_count)) - if msg: - msg = 'disconnected' + msg - self.logger.info(msg) - - def count_pending_items(self): - return len(self.connection.pending_requests()) - - def semaphore(self): - return Semaphores([self.group.semaphore]) - - def sub_count(self): - return 0 - - async def handle_request(self, request): - """Handle an incoming request. ElectrumX doesn't receive - notifications from client sessions. - """ - self.session_manager.request_count_metric.labels(method=request.method, version=self.client_version).inc() - if isinstance(request, Request): - handler = self.request_handlers.get(request.method) - handler = partial(handler, self) - else: - handler = None - coro = handler_invocation(handler, request)() - return await coro - - -class LBRYElectrumX(SessionBase): - """A TCP server that handles incoming Electrum connections.""" - - PROTOCOL_MIN = PROTOCOL_MIN - PROTOCOL_MAX = PROTOCOL_MAX - max_errors = math.inf # don't disconnect people for errors! let them happen... - version = __version__ - cached_server_features = {} - - @classmethod - def initialize_request_handlers(cls): - cls.request_handlers.update({ - 'blockchain.block.get_chunk': cls.block_get_chunk, - 'blockchain.block.get_header': cls.block_get_header, - 'blockchain.estimatefee': cls.estimatefee, - 'blockchain.relayfee': cls.relayfee, - # 'blockchain.scripthash.get_balance': cls.scripthash_get_balance, - 'blockchain.scripthash.get_history': cls.scripthash_get_history, - 'blockchain.scripthash.get_mempool': cls.scripthash_get_mempool, - # 'blockchain.scripthash.listunspent': cls.scripthash_listunspent, - 'blockchain.scripthash.subscribe': cls.scripthash_subscribe, - 'blockchain.transaction.broadcast': cls.transaction_broadcast, - 'blockchain.transaction.get': cls.transaction_get, - 'blockchain.transaction.get_batch': cls.transaction_get_batch, - 'blockchain.transaction.info': cls.transaction_info, - 'blockchain.transaction.get_merkle': cls.transaction_merkle, - # 'server.add_peer': cls.add_peer, - 'server.banner': cls.banner, - 'server.payment_address': cls.payment_address, - 'server.donation_address': cls.donation_address, - 'server.features': cls.server_features_async, - 'server.peers.subscribe': cls.peers_subscribe, - 'server.version': cls.server_version, - 'blockchain.transaction.get_height': cls.transaction_get_height, - 'blockchain.claimtrie.search': cls.claimtrie_search, - 'blockchain.claimtrie.resolve': cls.claimtrie_resolve, - 'blockchain.claimtrie.getclaimbyid': cls.claimtrie_getclaimbyid, - # 'blockchain.claimtrie.getclaimsbyids': cls.claimtrie_getclaimsbyids, - 'blockchain.block.get_server_height': cls.get_server_height, - 'mempool.get_fee_histogram': cls.mempool_compact_histogram, - 'blockchain.block.headers': cls.block_headers, - 'server.ping': cls.ping, - 'blockchain.headers.subscribe': cls.headers_subscribe_False, - # 'blockchain.address.get_balance': cls.address_get_balance, - 'blockchain.address.get_history': cls.address_get_history, - 'blockchain.address.get_mempool': cls.address_get_mempool, - # 'blockchain.address.listunspent': cls.address_listunspent, - 'blockchain.address.subscribe': cls.address_subscribe, - 'blockchain.address.unsubscribe': cls.address_unsubscribe, - }) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if not LBRYElectrumX.request_handlers: - LBRYElectrumX.initialize_request_handlers() - if not LBRYElectrumX.cached_server_features: - LBRYElectrumX.set_server_features(self.env) - self.subscribe_headers = False - self.subscribe_headers_raw = False - self.subscribe_peers = False - self.connection.max_response_size = self.env.max_send - self.hashX_subs = {} - self.sv_seen = False - self.protocol_tuple = self.PROTOCOL_MIN - self.protocol_string = None - self.daemon = self.session_manager.daemon - self.db: 'HubDB' = self.session_manager.db - @classmethod def protocol_min_max_strings(cls): return [version_string(ver) @@ -1795,17 +1717,6 @@ class LBRYElectrumX(SessionBase): return result[tx_hash][1] -class LocalRPC(SessionBase): - """A local TCP RPC server session.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.client = 'RPC' - self.connection._max_response_size = 0 - - def protocol_version_string(self): - return 'RPC' - def get_from_possible_keys(dictionary, *keys): for key in keys: