This commit is contained in:
Jack Robison 2019-06-26 16:27:43 -04:00
parent 8cba43bfed
commit c51db5294f
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
4 changed files with 32 additions and 21 deletions

View file

@ -32,6 +32,7 @@ __all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
import itertools import itertools
import json import json
import typing
from functools import partial from functools import partial
from numbers import Number from numbers import Number
@ -596,7 +597,7 @@ class JSONRPCConnection(object):
# Sent Requests and Batches that have not received a response. # Sent Requests and Batches that have not received a response.
# The key is its request ID; for a batch it is sorted tuple # The key is its request ID; for a batch it is sorted tuple
# of request IDs # of request IDs
self._requests = {} self._requests: typing.Dict[str, typing.Tuple[Request, Event]] = {}
# A public attribute intended to be settable dynamically # A public attribute intended to be settable dynamically
self.max_response_size = 0 self.max_response_size = 0
@ -683,7 +684,7 @@ class JSONRPCConnection(object):
# #
# External API # External API
# #
def send_request(self, request): def send_request(self, request: Request) -> typing.Tuple[bytes, Event]:
"""Send a Request. Return a (message, event) pair. """Send a Request. Return a (message, event) pair.
The message is an unframed message to send over the network. The message is an unframed message to send over the network.

View file

@ -70,6 +70,7 @@ class Peer:
# Transient, non-persisted metadata # Transient, non-persisted metadata
self.bad = False self.bad = False
self.other_port_pairs = set() self.other_port_pairs = set()
self.status = 2
@classmethod @classmethod
def peers_from_features(cls, features, source): def peers_from_features(cls, features, source):

View file

@ -12,6 +12,7 @@ import random
import socket import socket
import ssl import ssl
import time import time
import typing
from asyncio import Event, sleep from asyncio import Event, sleep
from collections import defaultdict, Counter from collections import defaultdict, Counter
@ -72,7 +73,7 @@ class PeerManager:
# ip_addr property is either None, an onion peer, or the # ip_addr property is either None, an onion peer, or the
# IP address that was connected to. Adding a peer will evict # IP address that was connected to. Adding a peer will evict
# any other peers with the same host name or IP address. # any other peers with the same host name or IP address.
self.peers = set() self.peers: typing.Set[Peer] = set()
self.permit_onion_peer_time = time.time() self.permit_onion_peer_time = time.time()
self.proxy = None self.proxy = None
self.group = TaskGroup() self.group = TaskGroup()
@ -394,7 +395,7 @@ class PeerManager:
self.group.add(self._detect_proxy()) self.group.add(self._detect_proxy())
self.group.add(self._import_peers()) self.group.add(self._import_peers())
def info(self): def info(self) -> typing.Dict[str, int]:
"""The number of peers.""" """The number of peers."""
self._set_peer_statuses() self._set_peer_statuses()
counter = Counter(peer.status for peer in self.peers) counter = Counter(peer.status for peer in self.peers)

View file

@ -16,6 +16,7 @@ import os
import pylru import pylru
import ssl import ssl
import time import time
import typing
from asyncio import Event, sleep from asyncio import Event, sleep
from collections import defaultdict from collections import defaultdict
from functools import partial from functools import partial
@ -31,13 +32,18 @@ from torba.server.hash import (sha256, hash_to_hex_str, hex_str_to_hash,
HASHX_LEN, Base58Error) HASHX_LEN, Base58Error)
from torba.server.daemon import DaemonError from torba.server.daemon import DaemonError
from torba.server.peers import PeerManager from torba.server.peers import PeerManager
if typing.TYPE_CHECKING:
from torba.server.env import Env
from torba.server.db import DB
from torba.server.block_processor import BlockProcessor
from torba.server.mempool import MemPool
from torba.server.daemon import Daemon
BAD_REQUEST = 1 BAD_REQUEST = 1
DAEMON_ERROR = 2 DAEMON_ERROR = 2
def scripthash_to_hashX(scripthash): def scripthash_to_hashX(scripthash: str) -> bytes:
try: try:
bin_hash = hex_str_to_hash(scripthash) bin_hash = hex_str_to_hash(scripthash)
if len(bin_hash) == 32: if len(bin_hash) == 32:
@ -47,7 +53,7 @@ def scripthash_to_hashX(scripthash):
raise RPCError(BAD_REQUEST, f'{scripthash} is not a valid script hash') raise RPCError(BAD_REQUEST, f'{scripthash} is not a valid script hash')
def non_negative_integer(value): def non_negative_integer(value) -> int:
"""Return param value it is or can be converted to a non-negative """Return param value it is or can be converted to a non-negative
integer, otherwise raise an RPCError.""" integer, otherwise raise an RPCError."""
try: try:
@ -60,14 +66,14 @@ def non_negative_integer(value):
f'{value} should be a non-negative integer') f'{value} should be a non-negative integer')
def assert_boolean(value): def assert_boolean(value) -> bool:
"""Return param value it is boolean otherwise raise an RPCError.""" """Return param value it is boolean otherwise raise an RPCError."""
if value in (False, True): if value in (False, True):
return value return value
raise RPCError(BAD_REQUEST, f'{value} should be a boolean value') raise RPCError(BAD_REQUEST, f'{value} should be a boolean value')
def assert_tx_hash(value): def assert_tx_hash(value: str) -> None:
"""Raise an RPCError if the value is not a valid transaction """Raise an RPCError if the value is not a valid transaction
hash.""" hash."""
try: try:
@ -97,7 +103,7 @@ class Semaphores:
class SessionGroup: class SessionGroup:
def __init__(self, gid): def __init__(self, gid: int):
self.gid = gid self.gid = gid
# Concurrency per group # Concurrency per group
self.semaphore = asyncio.Semaphore(20) self.semaphore = asyncio.Semaphore(20)
@ -106,7 +112,8 @@ class SessionGroup:
class SessionManager: class SessionManager:
"""Holds global state about all sessions.""" """Holds global state about all sessions."""
def __init__(self, env, db, bp, daemon, mempool, shutdown_event): def __init__(self, env: 'Env', db: 'DB', bp: 'BlockProcessor', daemon: 'Daemon', mempool: 'MemPool',
shutdown_event: asyncio.Event):
env.max_send = max(350000, env.max_send) env.max_send = max(350000, env.max_send)
self.env = env self.env = env
self.db = db self.db = db
@ -116,28 +123,29 @@ class SessionManager:
self.peer_mgr = PeerManager(env, db) self.peer_mgr = PeerManager(env, db)
self.shutdown_event = shutdown_event self.shutdown_event = shutdown_event
self.logger = util.class_logger(__name__, self.__class__.__name__) self.logger = util.class_logger(__name__, self.__class__.__name__)
self.servers = {} self.servers: typing.Dict[str, asyncio.AbstractServer] = {}
self.sessions = set() self.sessions: typing.Set['SessionBase'] = set()
self.max_subs = env.max_subs self.max_subs = env.max_subs
self.cur_group = SessionGroup(0) self.cur_group = SessionGroup(0)
self.txs_sent = 0 self.txs_sent = 0
self.start_time = time.time() self.start_time = time.time()
self.history_cache = pylru.lrucache(256) self.history_cache = pylru.lrucache(256)
self.notified_height = None self.notified_height: typing.Optional[int] = None
# Cache some idea of room to avoid recounting on each subscription # Cache some idea of room to avoid recounting on each subscription
self.subs_room = 0 self.subs_room = 0
# Masternode stuff only for such coins # Masternode stuff only for such coins
if issubclass(env.coin.SESSIONCLS, DashElectrumX): if issubclass(env.coin.SESSIONCLS, DashElectrumX):
self.mn_cache_height = 0 self.mn_cache_height = 0
self.mn_cache = [] self.mn_cache = [] # type: ignore
self.session_event = Event() self.session_event = Event()
# Set up the RPC request handlers # Set up the RPC request handlers
cmds = ('add_peer daemon_url disconnect getinfo groups log peers ' cmds = ('add_peer daemon_url disconnect getinfo groups log peers '
'query reorg sessions stop'.split()) 'query reorg sessions stop'.split())
LocalRPC.request_handlers = {cmd: getattr(self, 'rpc_' + cmd) LocalRPC.request_handlers.update(
for cmd in cmds} {cmd: getattr(self, 'rpc_' + cmd) for cmd in cmds}
)
async def _start_server(self, kind, *args, **kw_args): async def _start_server(self, kind, *args, **kw_args):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -147,11 +155,10 @@ class SessionManager:
protocol_class = self.env.coin.SESSIONCLS protocol_class = self.env.coin.SESSIONCLS
protocol_factory = partial(protocol_class, self, self.db, protocol_factory = partial(protocol_class, self, self.db,
self.mempool, self.peer_mgr, kind) self.mempool, self.peer_mgr, kind)
server = loop.create_server(protocol_factory, *args, **kw_args)
host, port = args[:2] host, port = args[:2]
try: try:
self.servers[kind] = await server self.servers[kind] = await loop.create_server(protocol_factory, *args, **kw_args)
except OSError as e: # don't suppress CancelledError except OSError as e: # don't suppress CancelledError
self.logger.error(f'{kind} server failed to listen on {host}:' self.logger.error(f'{kind} server failed to listen on {host}:'
f'{port:d} :{e!r}') f'{port:d} :{e!r}')
@ -219,7 +226,7 @@ class SessionManager:
group_map[session.group].append(session) group_map[session.group].append(session)
return group_map return group_map
def _sub_count(self): def _sub_count(self) -> int:
return sum(s.sub_count() for s in self.sessions) return sum(s.sub_count() for s in self.sessions)
def _lookup_session(self, session_id): def _lookup_session(self, session_id):
@ -533,7 +540,7 @@ class SessionManager:
session.close(force_after=1) for session in self.sessions session.close(force_after=1) for session in self.sessions
]) ])
def session_count(self): def session_count(self) -> int:
"""The number of connections that we've sent something to.""" """The number of connections that we've sent something to."""
return len(self.sessions) return len(self.sessions)
@ -620,6 +627,7 @@ class SessionBase(RPCSession):
MAX_CHUNK_SIZE = 2016 MAX_CHUNK_SIZE = 2016
session_counter = itertools.count() session_counter = itertools.count()
request_handlers: typing.Dict[str, typing.Callable] = {}
def __init__(self, session_mgr, db, mempool, peer_mgr, kind): def __init__(self, session_mgr, db, mempool, peer_mgr, kind):
connection = JSONRPCConnection(JSONRPCAutoDetect) connection = JSONRPCConnection(JSONRPCAutoDetect)