combine SessionBase into LBRYElectrumX class

-refactor response handler
-drop local rpc
-remove some unused fields from env.py
This commit is contained in:
Jack Robison 2022-03-09 21:19:18 -05:00
parent c5dc8d5cad
commit 4488dafeda
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
2 changed files with 204 additions and 303 deletions

View file

@ -31,9 +31,9 @@ class Env:
pass
def __init__(self, db_dir=None, daemon_url=None, host=None, rpc_host=None, elastic_host=None,
elastic_port=None, loop_policy=None, max_query_workers=None, websocket_host=None, websocket_port=None,
elastic_port=None, loop_policy=None, max_query_workers=None,
chain=None, es_index_prefix=None, cache_MB=None, reorg_limit=None, tcp_port=None,
udp_port=None, ssl_port=None, ssl_certfile=None, ssl_keyfile=None, rpc_port=None,
udp_port=None, ssl_port=None, ssl_certfile=None, ssl_keyfile=None,
prometheus_port=None, max_subscriptions=None, banner_file=None, anon_logs=None, log_sessions=None,
allow_lan_udp=None, cache_all_tx_hashes=None, cache_all_claim_txos=None, country=None,
payment_address=None, donation_address=None, max_send=None, max_receive=None, max_sessions=None,
@ -56,8 +56,6 @@ class Env:
)
self.obsolete(['UTXO_MB', 'HIST_MB', 'NETWORK'])
self.max_query_workers = max_query_workers if max_query_workers is not None else self.integer('MAX_QUERY_WORKERS', 4)
self.websocket_host = websocket_host if websocket_host is not None else self.default('WEBSOCKET_HOST', self.host)
self.websocket_port = websocket_port if websocket_port is not None else self.integer('WEBSOCKET_PORT', None)
if chain == 'mainnet':
self.coin = LBCMainNet
elif chain == 'testnet':
@ -74,7 +72,6 @@ class Env:
if self.ssl_port:
self.ssl_certfile = ssl_certfile if ssl_certfile is not None else self.required('SSL_CERTFILE')
self.ssl_keyfile = ssl_keyfile if ssl_keyfile is not None else self.required('SSL_KEYFILE')
self.rpc_port = rpc_port if rpc_port is not None else self.integer('RPC_PORT', 8000)
self.prometheus_port = prometheus_port if prometheus_port is not None else self.integer('PROMETHEUS_PORT', 0)
self.max_subscriptions = max_subscriptions if max_subscriptions is not None else self.integer('MAX_SUBSCRIPTIONS', 10000)
self.banner_file = banner_file if banner_file is not None else self.default('BANNER_FILE', None)
@ -307,14 +304,7 @@ class Env:
help='TCP port to listen on for hub server')
parser.add_argument('--udp_port', type=int, default=cls.integer('UDP_PORT', 50001),
help='UDP port to listen on for hub server')
parser.add_argument('--rpc_host', default=cls.default('RPC_HOST', 'localhost'), type=str,
help='Listening interface for admin rpc')
parser.add_argument('--rpc_port', default=cls.integer('RPC_PORT', 8000), type=int,
help='Listening port for admin rpc')
parser.add_argument('--websocket_host', default=cls.default('WEBSOCKET_HOST', 'localhost'), type=str,
help='Listening interface for websocket')
parser.add_argument('--websocket_port', default=cls.integer('WEBSOCKET_PORT', None), type=int,
help='Listening port for websocket')
parser.add_argument('--ssl_port', default=cls.integer('SSL_PORT', None), type=int,
help='SSL port to listen on for hub server')
@ -376,12 +366,12 @@ class Env:
def from_arg_parser(cls, args):
return cls(
db_dir=args.db_dir, daemon_url=args.daemon_url, db_max_open_files=args.db_max_open_files,
host=args.host, rpc_host=args.rpc_host, elastic_host=args.elastic_host, elastic_port=args.elastic_port,
loop_policy=args.loop_policy, max_query_workers=args.max_query_workers, websocket_host=args.websocket_host,
websocket_port=args.websocket_port, chain=args.chain, es_index_prefix=args.es_index_prefix,
host=args.host, elastic_host=args.elastic_host, elastic_port=args.elastic_port,
loop_policy=args.loop_policy, max_query_workers=args.max_query_workers,
chain=args.chain, es_index_prefix=args.es_index_prefix,
cache_MB=args.cache_MB, reorg_limit=args.reorg_limit, tcp_port=args.tcp_port,
udp_port=args.udp_port, ssl_port=args.ssl_port, ssl_certfile=args.ssl_certfile,
ssl_keyfile=args.ssl_keyfile, rpc_port=args.rpc_port, prometheus_port=args.prometheus_port,
ssl_keyfile=args.ssl_keyfile, prometheus_port=args.prometheus_port,
max_subscriptions=args.max_subscriptions, banner_file=args.banner_file, anon_logs=args.anon_logs,
log_sessions=None, allow_lan_udp=args.allow_lan_udp,
cache_all_tx_hashes=args.cache_all_tx_hashes, cache_all_claim_txos=args.cache_all_claim_txos,

View file

@ -8,7 +8,6 @@ import asyncio
import logging
import itertools
import collections
import inspect
from bisect import bisect_right
from asyncio import Event, sleep
from collections import defaultdict, namedtuple
@ -44,83 +43,6 @@ SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args '
'required_names other_names')
@lru_cache(256)
def signature_info(func):
params = inspect.signature(func).parameters
min_args = max_args = 0
required_names = []
other_names = []
no_names = False
for p in params.values():
if p.kind == p.POSITIONAL_OR_KEYWORD:
max_args += 1
if p.default is p.empty:
min_args += 1
required_names.append(p.name)
else:
other_names.append(p.name)
elif p.kind == p.KEYWORD_ONLY:
other_names.append(p.name)
elif p.kind == p.VAR_POSITIONAL:
max_args = None
elif p.kind == p.VAR_KEYWORD:
other_names = any
elif p.kind == p.POSITIONAL_ONLY:
max_args += 1
if p.default is p.empty:
min_args += 1
no_names = True
if no_names:
other_names = None
return SignatureInfo(min_args, max_args, required_names, other_names)
def handler_invocation(handler, request):
method, args = request.method, request.args
if handler is None:
raise RPCError(JSONRPC.METHOD_NOT_FOUND,
f'unknown method "{method}"')
# We must test for too few and too many arguments. How
# depends on whether the arguments were passed as a list or as
# a dictionary.
info = signature_info(handler)
if isinstance(args, (tuple, list)):
if len(args) < info.min_args:
s = '' if len(args) == 1 else 's'
raise RPCError.invalid_args(
f'{len(args)} argument{s} passed to method '
f'"{method}" but it requires {info.min_args}')
if info.max_args is not None and len(args) > info.max_args:
s = '' if len(args) == 1 else 's'
raise RPCError.invalid_args(
f'{len(args)} argument{s} passed to method '
f'{method} taking at most {info.max_args}')
return partial(handler, *args)
# Arguments passed by name
if info.other_names is None:
raise RPCError.invalid_args(f'method "{method}" cannot '
f'be called with named arguments')
missing = set(info.required_names).difference(args)
if missing:
s = '' if len(missing) == 1 else 's'
missing = ', '.join(sorted(f'"{name}"' for name in missing))
raise RPCError.invalid_args(f'method "{method}" requires '
f'parameter{s} {missing}')
if info.other_names is not any:
excess = set(args).difference(info.required_names)
excess = excess.difference(info.other_names)
if excess:
s = '' if len(excess) == 1 else 's'
excess = ', '.join(sorted(f'"{name}"' for name in excess))
raise RPCError.invalid_args(f'method "{method}" does not '
f'take parameter{s} {excess}')
return partial(handler, **args)
def scripthash_to_hashX(scripthash: str) -> bytes:
@ -284,10 +206,10 @@ class SessionManager:
async def _start_server(self, kind, *args, **kw_args):
loop = asyncio.get_event_loop()
if kind == 'RPC':
protocol_class = LocalRPC
else:
if kind == 'TCP':
protocol_class = LBRYElectrumX
else:
raise ValueError(kind)
protocol_factory = partial(protocol_class, self, self.db,
self.mempool, kind)
@ -608,9 +530,6 @@ class SessionManager:
"""Start the RPC server if enabled. When the event is triggered,
start TCP and SSL servers."""
try:
if self.env.rpc_port is not None:
await self._start_server('RPC', self.env.cs_host(for_rpc=True),
self.env.rpc_port)
self.logger.info(f'max session count: {self.env.max_sessions:,d}')
self.logger.info(f'session timeout: '
f'{self.env.session_timeout:,d} seconds')
@ -717,18 +636,18 @@ class SessionManager:
self.session_event.set()
class SessionBase(asyncio.Protocol):
"""Base class of ElectrumX JSON sessions.
class LBRYElectrumX(asyncio.Protocol):
"""A TCP server that handles incoming Electrum connections."""
Each session runs its tasks in asynchronous parallelism with other
sessions.
"""
PROTOCOL_MIN = PROTOCOL_MIN
PROTOCOL_MAX = PROTOCOL_MAX
max_errors = math.inf # don't disconnect people for errors! let them happen...
version = __version__
cached_server_features = {}
MAX_CHUNK_SIZE = 40960
session_counter = itertools.count()
request_handlers: typing.Dict[str, typing.Callable] = {}
version = '0.5.7'
RESPONSE_TIMES = Histogram("response_time", "Response times", namespace=NAMESPACE,
labelnames=("method", "version"), buckets=HISTOGRAM_BUCKETS)
NOTIFICATION_COUNT = Counter("notification", "Number of notifications sent (for subscriptions)",
@ -785,9 +704,193 @@ class SessionBase(asyncio.Protocol):
self.txs_sent = 0
self.log_me = False
self.daemon_request = self.session_manager.daemon_request
# Hijack the connection so we can log messages
self._receive_message_orig = self.connection.receive_message
self.connection.receive_message = self.receive_message
if not LBRYElectrumX.cached_server_features:
LBRYElectrumX.set_server_features(self.env)
self.subscribe_headers = False
self.subscribe_headers_raw = False
self.subscribe_peers = False
self.connection.max_response_size = self.env.max_send
self.hashX_subs = {}
self.sv_seen = False
self.protocol_tuple = self.PROTOCOL_MIN
self.protocol_string = None
self.daemon = self.session_manager.daemon
self.db: 'HubDB' = self.session_manager.db
def data_received(self, framed_message):
"""Called by asyncio when a message comes in."""
self.last_packet_received = time.perf_counter()
if self.verbosity >= 4:
self.logger.debug(f'Received framed message {framed_message}')
self.recv_size += len(framed_message)
self.framer.received_bytes(framed_message)
def pause_writing(self):
"""Transport calls when the send buffer is full."""
if not self.is_closing():
self._can_send.clear()
self.transport.pause_reading()
def resume_writing(self):
"""Transport calls when the send buffer has room."""
if not self._can_send.is_set():
self._can_send.set()
self.transport.resume_reading()
def connection_made(self, transport):
"""Handle an incoming client connection."""
self.transport = transport
# This would throw if called on a closed SSL transport. Fixed
# in asyncio in Python 3.6.1 and 3.5.4
peer_address = transport.get_extra_info('peername')
# If the Socks proxy was used then _address is already set to
# the remote address
if self._address:
self._proxy_address = peer_address
else:
self._address = peer_address
self._pm_task = self.loop.create_task(self._receive_messages())
self.session_id = next(self.session_counter)
context = {'conn_id': f'{self.session_id}'}
self.logger = logging.getLogger(__name__) # util.ConnectionLogger(self.logger, context)
self.group = self.session_manager.add_session(self)
self.session_manager.session_count_metric.labels(version=self.client_version).inc()
peer_addr_str = self.peer_address_str()
self.logger.info(f'{self.kind} {peer_addr_str}, '
f'{self.session_manager.session_count():,d} total')
def connection_lost(self, exc):
"""Handle client disconnection."""
self.connection.raise_pending_requests(exc)
self._address = None
self.transport = None
self._task_group.cancel()
if self._pm_task:
self._pm_task.cancel()
# Release waiting tasks
self._can_send.set()
self.session_manager.remove_session(self)
self.session_manager.session_count_metric.labels(version=self.client_version).dec()
msg = ''
if not self._can_send.is_set():
msg += ' whilst paused'
if self.send_size >= 1024 * 1024:
msg += ('. Sent {:,d} bytes in {:,d} messages'
.format(self.send_size, self.send_count))
if msg:
msg = 'disconnected' + msg
self.logger.info(msg)
def default_framer(self):
return NewlineFramer(self.env.max_receive)
def peer_address_str(self, *, for_log=True):
"""Returns the peer's IP address and port as a human-readable
string, respecting anon logs if the output is for a log."""
if for_log and self.anon_logs:
return 'xx.xx.xx.xx:xx'
if not self._address:
return 'unknown'
ip_addr_str, port = self._address[:2]
if ':' in ip_addr_str:
return f'[{ip_addr_str}]:{port}'
else:
return f'{ip_addr_str}:{port}'
def receive_message(self, message):
if self.log_me:
self.logger.info(f'processing {message}')
return self._receive_message_orig(message)
def toggle_logging(self):
self.log_me = not self.log_me
def count_pending_items(self):
return len(self.connection.pending_requests())
def semaphore(self):
return Semaphores([self.group.semaphore])
async def handle_request(self, request):
"""Handle an incoming request. ElectrumX doesn't receive
notifications from client sessions.
"""
self.session_manager.request_count_metric.labels(method=request.method, version=self.client_version).inc()
if isinstance(request, Request):
method = request.method
if method == 'blockchain.block.get_chunk':
coro = self.block_get_chunk
elif method == 'blockchain.block.get_header':
coro = self.block_get_header
elif method == 'blockchain.block.get_server_height':
coro = self.get_server_height
elif method == 'blockchain.scripthash.get_history':
coro = self.scripthash_get_history
elif method == 'blockchain.scripthash.get_mempool':
coro = self.scripthash_get_mempool
elif method == 'blockchain.scripthash.subscribe':
coro = self.scripthash_subscribe
elif method == 'blockchain.transaction.broadcast':
coro = self.transaction_broadcast
elif method == 'blockchain.transaction.get':
coro = self.transaction_get
elif method == 'blockchain.transaction.get_batch':
coro = self.transaction_get_batch
elif method == 'blockchain.transaction.info':
coro = self.transaction_info
elif method == 'blockchain.transaction.get_merkle':
coro = self.transaction_merkle
elif method == 'blockchain.transaction.get_height':
coro = self.transaction_get_height
elif method == 'blockchain.block.headers':
coro = self.block_headers
elif method == 'server.banner':
coro = self.banner
elif method == 'server.payment_address':
coro = self.payment_address
elif method == 'server.donation_address':
coro = self.donation_address
elif method == 'server.features':
coro = self.server_features_async
elif method == 'server.peers.subscribe':
coro = self.peers_subscribe
elif method == 'server.version':
coro = self.server_version
elif method == 'blockchain.claimtrie.search':
coro = self.claimtrie_search
elif method == 'blockchain.claimtrie.resolve':
coro = self.claimtrie_resolve
elif method == 'blockchain.claimtrie.getclaimbyid':
coro = self.claimtrie_getclaimbyid
elif method == 'mempool.get_fee_histogram':
coro = self.mempool_compact_histogram
elif method == 'server.ping':
coro = self.ping
elif method == 'blockchain.headers.subscribe':
coro = self.headers_subscribe_False
elif method == 'blockchain.address.get_history':
coro = self.address_get_history
elif method == 'blockchain.address.get_mempool':
coro = self.address_get_mempool
elif method == 'blockchain.address.subscribe':
coro = self.address_subscribe
elif method == 'blockchain.address.unsubscribe':
coro = self.address_unsubscribe
elif method == 'blockchain.estimatefee':
coro = self.estimatefee
elif method == 'blockchain.relayfee':
coro = self.relayfee
else:
raise ValueError
else:
raise ValueError
if isinstance(request.args, dict):
return await coro(**request.args)
return await coro(*request.args)
async def _limited_wait(self, secs):
try:
@ -962,187 +1065,6 @@ class SessionBase(asyncio.Protocol):
"""
return BatchRequest(self, raise_errors)
def data_received(self, framed_message):
"""Called by asyncio when a message comes in."""
self.last_packet_received = time.perf_counter()
if self.verbosity >= 4:
self.logger.debug(f'Received framed message {framed_message}')
self.recv_size += len(framed_message)
self.framer.received_bytes(framed_message)
def pause_writing(self):
"""Transport calls when the send buffer is full."""
if not self.is_closing():
self._can_send.clear()
self.transport.pause_reading()
def resume_writing(self):
"""Transport calls when the send buffer has room."""
if not self._can_send.is_set():
self._can_send.set()
self.transport.resume_reading()
def default_framer(self):
return NewlineFramer(self.env.max_receive)
def peer_address_str(self, *, for_log=True):
"""Returns the peer's IP address and port as a human-readable
string, respecting anon logs if the output is for a log."""
if for_log and self.anon_logs:
return 'xx.xx.xx.xx:xx'
if not self._address:
return 'unknown'
ip_addr_str, port = self._address[:2]
if ':' in ip_addr_str:
return f'[{ip_addr_str}]:{port}'
else:
return f'{ip_addr_str}:{port}'
def receive_message(self, message):
if self.log_me:
self.logger.info(f'processing {message}')
return self._receive_message_orig(message)
def toggle_logging(self):
self.log_me = not self.log_me
def connection_made(self, transport):
"""Handle an incoming client connection."""
self.transport = transport
# This would throw if called on a closed SSL transport. Fixed
# in asyncio in Python 3.6.1 and 3.5.4
peer_address = transport.get_extra_info('peername')
# If the Socks proxy was used then _address is already set to
# the remote address
if self._address:
self._proxy_address = peer_address
else:
self._address = peer_address
self._pm_task = self.loop.create_task(self._receive_messages())
self.session_id = next(self.session_counter)
context = {'conn_id': f'{self.session_id}'}
self.logger = logging.getLogger(__name__) #util.ConnectionLogger(self.logger, context)
self.group = self.session_manager.add_session(self)
self.session_manager.session_count_metric.labels(version=self.client_version).inc()
peer_addr_str = self.peer_address_str()
self.logger.info(f'{self.kind} {peer_addr_str}, '
f'{self.session_manager.session_count():,d} total')
def connection_lost(self, exc):
"""Handle client disconnection."""
self.connection.raise_pending_requests(exc)
self._address = None
self.transport = None
self._task_group.cancel()
if self._pm_task:
self._pm_task.cancel()
# Release waiting tasks
self._can_send.set()
self.session_manager.remove_session(self)
self.session_manager.session_count_metric.labels(version=self.client_version).dec()
msg = ''
if not self._can_send.is_set():
msg += ' whilst paused'
if self.send_size >= 1024*1024:
msg += ('. Sent {:,d} bytes in {:,d} messages'
.format(self.send_size, self.send_count))
if msg:
msg = 'disconnected' + msg
self.logger.info(msg)
def count_pending_items(self):
return len(self.connection.pending_requests())
def semaphore(self):
return Semaphores([self.group.semaphore])
def sub_count(self):
return 0
async def handle_request(self, request):
"""Handle an incoming request. ElectrumX doesn't receive
notifications from client sessions.
"""
self.session_manager.request_count_metric.labels(method=request.method, version=self.client_version).inc()
if isinstance(request, Request):
handler = self.request_handlers.get(request.method)
handler = partial(handler, self)
else:
handler = None
coro = handler_invocation(handler, request)()
return await coro
class LBRYElectrumX(SessionBase):
"""A TCP server that handles incoming Electrum connections."""
PROTOCOL_MIN = PROTOCOL_MIN
PROTOCOL_MAX = PROTOCOL_MAX
max_errors = math.inf # don't disconnect people for errors! let them happen...
version = __version__
cached_server_features = {}
@classmethod
def initialize_request_handlers(cls):
cls.request_handlers.update({
'blockchain.block.get_chunk': cls.block_get_chunk,
'blockchain.block.get_header': cls.block_get_header,
'blockchain.estimatefee': cls.estimatefee,
'blockchain.relayfee': cls.relayfee,
# 'blockchain.scripthash.get_balance': cls.scripthash_get_balance,
'blockchain.scripthash.get_history': cls.scripthash_get_history,
'blockchain.scripthash.get_mempool': cls.scripthash_get_mempool,
# 'blockchain.scripthash.listunspent': cls.scripthash_listunspent,
'blockchain.scripthash.subscribe': cls.scripthash_subscribe,
'blockchain.transaction.broadcast': cls.transaction_broadcast,
'blockchain.transaction.get': cls.transaction_get,
'blockchain.transaction.get_batch': cls.transaction_get_batch,
'blockchain.transaction.info': cls.transaction_info,
'blockchain.transaction.get_merkle': cls.transaction_merkle,
# 'server.add_peer': cls.add_peer,
'server.banner': cls.banner,
'server.payment_address': cls.payment_address,
'server.donation_address': cls.donation_address,
'server.features': cls.server_features_async,
'server.peers.subscribe': cls.peers_subscribe,
'server.version': cls.server_version,
'blockchain.transaction.get_height': cls.transaction_get_height,
'blockchain.claimtrie.search': cls.claimtrie_search,
'blockchain.claimtrie.resolve': cls.claimtrie_resolve,
'blockchain.claimtrie.getclaimbyid': cls.claimtrie_getclaimbyid,
# 'blockchain.claimtrie.getclaimsbyids': cls.claimtrie_getclaimsbyids,
'blockchain.block.get_server_height': cls.get_server_height,
'mempool.get_fee_histogram': cls.mempool_compact_histogram,
'blockchain.block.headers': cls.block_headers,
'server.ping': cls.ping,
'blockchain.headers.subscribe': cls.headers_subscribe_False,
# 'blockchain.address.get_balance': cls.address_get_balance,
'blockchain.address.get_history': cls.address_get_history,
'blockchain.address.get_mempool': cls.address_get_mempool,
# 'blockchain.address.listunspent': cls.address_listunspent,
'blockchain.address.subscribe': cls.address_subscribe,
'blockchain.address.unsubscribe': cls.address_unsubscribe,
})
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not LBRYElectrumX.request_handlers:
LBRYElectrumX.initialize_request_handlers()
if not LBRYElectrumX.cached_server_features:
LBRYElectrumX.set_server_features(self.env)
self.subscribe_headers = False
self.subscribe_headers_raw = False
self.subscribe_peers = False
self.connection.max_response_size = self.env.max_send
self.hashX_subs = {}
self.sv_seen = False
self.protocol_tuple = self.PROTOCOL_MIN
self.protocol_string = None
self.daemon = self.session_manager.daemon
self.db: 'HubDB' = self.session_manager.db
@classmethod
def protocol_min_max_strings(cls):
return [version_string(ver)
@ -1795,17 +1717,6 @@ class LBRYElectrumX(SessionBase):
return result[tx_hash][1]
class LocalRPC(SessionBase):
"""A local TCP RPC server session."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = 'RPC'
self.connection._max_response_size = 0
def protocol_version_string(self):
return 'RPC'
def get_from_possible_keys(dictionary, *keys):
for key in keys: