From c51db5294f346315d89fd4ea1951ac170be3181a Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Wed, 26 Jun 2019 16:27:43 -0400 Subject: [PATCH] typing --- torba/torba/rpc/jsonrpc.py | 5 +++-- torba/torba/server/peer.py | 1 + torba/torba/server/peers.py | 5 +++-- torba/torba/server/session.py | 42 +++++++++++++++++++++-------------- 4 files changed, 32 insertions(+), 21 deletions(-) diff --git a/torba/torba/rpc/jsonrpc.py b/torba/torba/rpc/jsonrpc.py index 84830ecd2..dc80b61a4 100644 --- a/torba/torba/rpc/jsonrpc.py +++ b/torba/torba/rpc/jsonrpc.py @@ -32,6 +32,7 @@ __all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose', import itertools import json +import typing from functools import partial from numbers import Number @@ -596,7 +597,7 @@ class JSONRPCConnection(object): # Sent Requests and Batches that have not received a response. # The key is its request ID; for a batch it is sorted tuple # of request IDs - self._requests = {} + self._requests: typing.Dict[str, typing.Tuple[Request, Event]] = {} # A public attribute intended to be settable dynamically self.max_response_size = 0 @@ -683,7 +684,7 @@ class JSONRPCConnection(object): # # External API # - def send_request(self, request): + def send_request(self, request: Request) -> typing.Tuple[bytes, Event]: """Send a Request. Return a (message, event) pair. The message is an unframed message to send over the network. diff --git a/torba/torba/server/peer.py b/torba/torba/server/peer.py index a3f268da3..e3c0c21d9 100644 --- a/torba/torba/server/peer.py +++ b/torba/torba/server/peer.py @@ -70,6 +70,7 @@ class Peer: # Transient, non-persisted metadata self.bad = False self.other_port_pairs = set() + self.status = 2 @classmethod def peers_from_features(cls, features, source): diff --git a/torba/torba/server/peers.py b/torba/torba/server/peers.py index 1ea4ebc91..cbaffd788 100644 --- a/torba/torba/server/peers.py +++ b/torba/torba/server/peers.py @@ -12,6 +12,7 @@ import random import socket import ssl import time +import typing from asyncio import Event, sleep from collections import defaultdict, Counter @@ -72,7 +73,7 @@ class PeerManager: # ip_addr property is either None, an onion peer, or the # IP address that was connected to. Adding a peer will evict # any other peers with the same host name or IP address. - self.peers = set() + self.peers: typing.Set[Peer] = set() self.permit_onion_peer_time = time.time() self.proxy = None self.group = TaskGroup() @@ -394,7 +395,7 @@ class PeerManager: self.group.add(self._detect_proxy()) self.group.add(self._import_peers()) - def info(self): + def info(self) -> typing.Dict[str, int]: """The number of peers.""" self._set_peer_statuses() counter = Counter(peer.status for peer in self.peers) diff --git a/torba/torba/server/session.py b/torba/torba/server/session.py index 230f46fa5..45900f8c2 100644 --- a/torba/torba/server/session.py +++ b/torba/torba/server/session.py @@ -16,6 +16,7 @@ import os import pylru import ssl import time +import typing from asyncio import Event, sleep from collections import defaultdict from functools import partial @@ -31,13 +32,18 @@ from torba.server.hash import (sha256, hash_to_hex_str, hex_str_to_hash, HASHX_LEN, Base58Error) from torba.server.daemon import DaemonError from torba.server.peers import PeerManager - +if typing.TYPE_CHECKING: + from torba.server.env import Env + from torba.server.db import DB + from torba.server.block_processor import BlockProcessor + from torba.server.mempool import MemPool + from torba.server.daemon import Daemon BAD_REQUEST = 1 DAEMON_ERROR = 2 -def scripthash_to_hashX(scripthash): +def scripthash_to_hashX(scripthash: str) -> bytes: try: bin_hash = hex_str_to_hash(scripthash) if len(bin_hash) == 32: @@ -47,7 +53,7 @@ def scripthash_to_hashX(scripthash): raise RPCError(BAD_REQUEST, f'{scripthash} is not a valid script hash') -def non_negative_integer(value): +def non_negative_integer(value) -> int: """Return param value it is or can be converted to a non-negative integer, otherwise raise an RPCError.""" try: @@ -60,14 +66,14 @@ def non_negative_integer(value): f'{value} should be a non-negative integer') -def assert_boolean(value): +def assert_boolean(value) -> bool: """Return param value it is boolean otherwise raise an RPCError.""" if value in (False, True): return value raise RPCError(BAD_REQUEST, f'{value} should be a boolean value') -def assert_tx_hash(value): +def assert_tx_hash(value: str) -> None: """Raise an RPCError if the value is not a valid transaction hash.""" try: @@ -97,7 +103,7 @@ class Semaphores: class SessionGroup: - def __init__(self, gid): + def __init__(self, gid: int): self.gid = gid # Concurrency per group self.semaphore = asyncio.Semaphore(20) @@ -106,7 +112,8 @@ class SessionGroup: class SessionManager: """Holds global state about all sessions.""" - def __init__(self, env, db, bp, daemon, mempool, shutdown_event): + def __init__(self, env: 'Env', db: 'DB', bp: 'BlockProcessor', daemon: 'Daemon', mempool: 'MemPool', + shutdown_event: asyncio.Event): env.max_send = max(350000, env.max_send) self.env = env self.db = db @@ -116,28 +123,29 @@ class SessionManager: self.peer_mgr = PeerManager(env, db) self.shutdown_event = shutdown_event self.logger = util.class_logger(__name__, self.__class__.__name__) - self.servers = {} - self.sessions = set() + self.servers: typing.Dict[str, asyncio.AbstractServer] = {} + self.sessions: typing.Set['SessionBase'] = set() self.max_subs = env.max_subs self.cur_group = SessionGroup(0) self.txs_sent = 0 self.start_time = time.time() self.history_cache = pylru.lrucache(256) - self.notified_height = None + self.notified_height: typing.Optional[int] = None # Cache some idea of room to avoid recounting on each subscription self.subs_room = 0 # Masternode stuff only for such coins if issubclass(env.coin.SESSIONCLS, DashElectrumX): self.mn_cache_height = 0 - self.mn_cache = [] + self.mn_cache = [] # type: ignore self.session_event = Event() # Set up the RPC request handlers cmds = ('add_peer daemon_url disconnect getinfo groups log peers ' 'query reorg sessions stop'.split()) - LocalRPC.request_handlers = {cmd: getattr(self, 'rpc_' + cmd) - for cmd in cmds} + LocalRPC.request_handlers.update( + {cmd: getattr(self, 'rpc_' + cmd) for cmd in cmds} + ) async def _start_server(self, kind, *args, **kw_args): loop = asyncio.get_event_loop() @@ -147,11 +155,10 @@ class SessionManager: protocol_class = self.env.coin.SESSIONCLS protocol_factory = partial(protocol_class, self, self.db, self.mempool, self.peer_mgr, kind) - server = loop.create_server(protocol_factory, *args, **kw_args) host, port = args[:2] try: - self.servers[kind] = await server + self.servers[kind] = await loop.create_server(protocol_factory, *args, **kw_args) except OSError as e: # don't suppress CancelledError self.logger.error(f'{kind} server failed to listen on {host}:' f'{port:d} :{e!r}') @@ -219,7 +226,7 @@ class SessionManager: group_map[session.group].append(session) return group_map - def _sub_count(self): + def _sub_count(self) -> int: return sum(s.sub_count() for s in self.sessions) def _lookup_session(self, session_id): @@ -533,7 +540,7 @@ class SessionManager: session.close(force_after=1) for session in self.sessions ]) - def session_count(self): + def session_count(self) -> int: """The number of connections that we've sent something to.""" return len(self.sessions) @@ -620,6 +627,7 @@ class SessionBase(RPCSession): MAX_CHUNK_SIZE = 2016 session_counter = itertools.count() + request_handlers: typing.Dict[str, typing.Callable] = {} def __init__(self, session_mgr, db, mempool, peer_mgr, kind): connection = JSONRPCConnection(JSONRPCAutoDetect)