merged aiorpcx into torba.rpc

Lex Berezhny 2018-12-05 00:40:06 -05:00
parent ba5fc2a627
commit 899a6f0d4a
15 changed files with 2582 additions and 14 deletions

setup.py

@@ -10,7 +10,6 @@ with open(os.path.join(BASE, 'README.md'), encoding='utf-8') as fh:
 REQUIRES = [
     'aiohttp',
-    'aiorpcx==0.9.0',
     'coincurve',
     'pbkdf2',
     'cryptography',

13 torba/rpc/__init__.py Normal file

@@ -0,0 +1,13 @@
from .curio import *
from .framing import *
from .jsonrpc import *
from .socks import *
from .session import *
from .util import *
__all__ = (curio.__all__ +
framing.__all__ +
jsonrpc.__all__ +
socks.__all__ +
session.__all__ +
util.__all__)

411 torba/rpc/curio.py Normal file

@@ -0,0 +1,411 @@
# The code below is mostly my own but based on the interfaces of the
# curio library by David Beazley. I'm considering switching to using
# curio. In the meantime this is an attempt to provide a similar
# clean, pure-async interface and move away from direct
# framework-specific dependencies. As asyncio differs in its design
# it is not possible to provide identical semantics.
#
# The curio library is distributed under the following licence:
#
# Copyright (C) 2015-2017
# David Beazley (Dabeaz LLC)
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of the David Beazley or Dabeaz LLC may be used to
# endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import logging
import asyncio
from asyncio import (
CancelledError, get_event_loop, Queue, Event, Lock, Semaphore,
sleep, Task
)
from collections import deque
from contextlib import suppress
from functools import partial
from .util import normalize_corofunc, check_task
__all__ = (
'Queue', 'Event', 'Lock', 'Semaphore', 'sleep', 'CancelledError',
'run_in_thread', 'spawn', 'spawn_sync', 'TaskGroup',
'TaskTimeout', 'TimeoutCancellationError', 'UncaughtTimeoutError',
'timeout_after', 'timeout_at', 'ignore_after', 'ignore_at',
)
async def run_in_thread(func, *args):
'''Run a function in a separate thread, and await its completion.'''
return await get_event_loop().run_in_executor(None, func, *args)
async def spawn(coro, *args, loop=None, report_crash=True):
return spawn_sync(coro, *args, loop=loop, report_crash=report_crash)
def spawn_sync(coro, *args, loop=None, report_crash=True):
coro = normalize_corofunc(coro, args)
loop = loop or get_event_loop()
task = loop.create_task(coro)
if report_crash:
task.add_done_callback(partial(check_task, logging))
return task
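# Usage sketch (editor's addition, not part of this commit): off-load a
# blocking call with run_in_thread and fire off a background coroutine
# with spawn. The demo function and its workload are illustrative only.
async def _demo_spawn():
    import hashlib
    # Run CPU-bound/blocking work in the default thread pool executor.
    h = await run_in_thread(hashlib.sha256, b'payload')
    # spawn returns a plain asyncio.Task; crashes are logged via check_task.
    task = await spawn(sleep, 0.1)
    await task
    return h.hexdigest()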
class TaskGroup(object):
'''A class representing a group of executing tasks. tasks is an
optional set of existing tasks to put into the group. New tasks
can later be added using the spawn() method below. wait specifies
the policy used for waiting for tasks. See the join() method
below. Each TaskGroup is an independent entity. Task groups do not
form a hierarchy or any kind of relationship to other previously
created task groups or tasks. Moreover, Tasks created by the top
level spawn() function are not placed into any task group. To
create a task in a group, it should be created using
TaskGroup.spawn() or explicitly added using TaskGroup.add_task().
completed attribute: the first task that completed with a result
in the group. Takes into account the wait option used in the
TaskGroup constructor (but not in the join method).
'''
def __init__(self, tasks=(), *, wait=all):
if wait not in (any, all, object):
raise ValueError('invalid wait argument')
self._done = deque()
self._pending = set()
self._wait = wait
self._done_event = Event()
self._logger = logging.getLogger(self.__class__.__name__)
self._closed = False
self.completed = None
for task in tasks:
self._add_task(task)
def _add_task(self, task):
'''Add an already existing task to the task group.'''
if hasattr(task, '_task_group'):
raise RuntimeError('task is already part of a group')
if self._closed:
raise RuntimeError('task group is closed')
task._task_group = self
if task.done():
self._done.append(task)
else:
self._pending.add(task)
task.add_done_callback(self._on_done)
def _on_done(self, task):
task._task_group = None
self._pending.remove(task)
self._done.append(task)
self._done_event.set()
if self.completed is None:
if not task.cancelled() and not task.exception():
if self._wait is object and task.result() is None:
pass
else:
self.completed = task
async def spawn(self, coro, *args):
'''Create a new task that's part of the group. Returns a Task
instance.
'''
task = await spawn(coro, *args, report_crash=False)
self._add_task(task)
return task
async def add_task(self, task):
'''Add an already existing task to the task group.'''
self._add_task(task)
async def next_done(self):
'''Returns the next completed task. Returns None if no more tasks
remain. A TaskGroup may also be used as an asynchronous iterator.
'''
if not self._done and self._pending:
self._done_event.clear()
await self._done_event.wait()
if self._done:
return self._done.popleft()
return None
async def next_result(self):
'''Returns the result of the next completed task. If the task failed
with an exception, that exception is raised. A RuntimeError
exception is raised if this is called when no remaining tasks
are available.'''
task = await self.next_done()
if not task:
raise RuntimeError('no tasks remain')
return task.result()
async def join(self):
'''Wait for tasks in the group to terminate according to the wait
policy for the group.
If the join() operation itself is cancelled, all remaining
tasks in the group are also cancelled.
If a TaskGroup is used as a context manager, the join() method
is called on context-exit.
Once join() returns, no more tasks may be added to the task
group. Tasks can be added while join() is running.
'''
def errored(task):
return not task.cancelled() and task.exception()
try:
if self._wait in (all, object):
while True:
task = await self.next_done()
if task is None:
return
if errored(task):
break
if self._wait is object:
if task.cancelled() or task.result() is not None:
return
else: # any
task = await self.next_done()
if task is None or not errored(task):
return
finally:
await self.cancel_remaining()
if errored(task):
raise task.exception()
async def cancel_remaining(self):
'''Cancel all remaining tasks.'''
self._closed = True
for task in list(self._pending):
task.cancel()
with suppress(CancelledError):
await task
def closed(self):
return self._closed
def __aiter__(self):
return self
async def __anext__(self):
task = await self.next_done()
if task:
return task
raise StopAsyncIteration
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
if exc_type:
await self.cancel_remaining()
else:
await self.join()
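# Usage sketch (editor's addition, not part of this commit): spawn two
# tasks in a group and consume them as they complete; join() runs on
# context exit. The coroutine below is a hypothetical example workload.
async def _demo_task_group():
    async def double(x):
        await sleep(0.01)
        return x * 2
    async with TaskGroup() as group:
        await group.spawn(double, 1)
        await group.spawn(double, 2)
        # A TaskGroup is an async iterator over its completed tasks.
        async for task in group:
            print('completed with', task.result())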
class TaskTimeout(CancelledError):
def __init__(self, secs):
self.secs = secs
def __str__(self):
return f'task timed out after {self.secs}s'
class TimeoutCancellationError(CancelledError):
pass
class UncaughtTimeoutError(Exception):
pass
def _set_new_deadline(task, deadline):
def timeout_task():
# Unfortunately task.cancel is all we can do with asyncio
task.cancel()
task._timed_out = deadline
task._deadline_handle = task._loop.call_at(deadline, timeout_task)
def _set_task_deadline(task, deadline):
deadlines = getattr(task, '_deadlines', [])
if deadlines:
if deadline < min(deadlines):
task._deadline_handle.cancel()
_set_new_deadline(task, deadline)
else:
_set_new_deadline(task, deadline)
deadlines.append(deadline)
task._deadlines = deadlines
task._timed_out = None
def _unset_task_deadline(task):
deadlines = task._deadlines
timed_out_deadline = task._timed_out
uncaught = timed_out_deadline not in deadlines
task._deadline_handle.cancel()
deadlines.pop()
if deadlines:
_set_new_deadline(task, min(deadlines))
return timed_out_deadline, uncaught
class TimeoutAfter(object):
def __init__(self, deadline, *, ignore=False, absolute=False):
self._deadline = deadline
self._ignore = ignore
self._absolute = absolute
self.expired = False
async def __aenter__(self):
task = asyncio.current_task()
loop_time = task._loop.time()
if self._absolute:
self._secs = self._deadline - loop_time
else:
self._secs = self._deadline
self._deadline += loop_time
_set_task_deadline(task, self._deadline)
self.expired = False
self._task = task
return self
async def __aexit__(self, exc_type, exc_value, traceback):
timed_out_deadline, uncaught = _unset_task_deadline(self._task)
if exc_type not in (CancelledError, TaskTimeout,
TimeoutCancellationError):
return False
if timed_out_deadline == self._deadline:
self.expired = True
if self._ignore:
return True
raise TaskTimeout(self._secs) from None
if timed_out_deadline is None:
assert exc_type is CancelledError
return False
if uncaught:
raise UncaughtTimeoutError('uncaught timeout received')
if exc_type is TimeoutCancellationError:
return False
raise TimeoutCancellationError(timed_out_deadline) from None
async def _timeout_after_func(seconds, absolute, coro, args):
coro = normalize_corofunc(coro, args)
async with TimeoutAfter(seconds, absolute=absolute):
return await coro
def timeout_after(seconds, coro=None, *args):
'''Execute the specified coroutine and return its result. However,
issue a cancellation request to the calling task after seconds
have elapsed. When this happens, a TaskTimeout exception is
raised. If coro is None, the result of this function serves
as an asynchronous context manager that applies a timeout to a
block of statements.
timeout_after() may be composed with other timeout_after()
operations (i.e., nested timeouts). If an outer timeout expires
first, then TimeoutCancellationError is raised instead of
TaskTimeout. If an inner timeout expires and its TaskTimeout is
not caught, an UncaughtTimeoutError is raised in the outer
timeout.
'''
if coro:
return _timeout_after_func(seconds, False, coro, args)
return TimeoutAfter(seconds)
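# Usage sketch (editor's addition, not part of this commit): both call
# forms of timeout_after, with deliberately short illustrative timeouts.
async def _demo_timeout_after():
    try:
        await timeout_after(0.1, sleep, 1)      # coroutine form
    except TaskTimeout:
        pass
    try:
        async with timeout_after(0.1):          # context-manager form
            await sleep(1)
    except TaskTimeout:
        pass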
def timeout_at(clock, coro=None, *args):
'''Execute the specified coroutine and return its result. However,
issue a cancellation request to the calling task once the event
loop's clock reaches the absolute time clock. When this happens,
a TaskTimeout exception is raised. If coro is None, the result of
this function serves as an asynchronous context manager that
applies a timeout to a block of statements.
timeout_at() may be composed with other timeout operations
(i.e., nested timeouts). If an outer timeout expires first, then
TimeoutCancellationError is raised instead of TaskTimeout. If an
inner timeout expires and its TaskTimeout is not caught, an
UncaughtTimeoutError is raised in the outer timeout.
'''
if coro:
return _timeout_after_func(clock, True, coro, args)
return TimeoutAfter(clock, absolute=True)
async def _ignore_after_func(seconds, absolute, coro, args, timeout_result):
coro = normalize_corofunc(coro, args)
async with TimeoutAfter(seconds, absolute=absolute, ignore=True):
return await coro
return timeout_result
def ignore_after(seconds, coro=None, *args, timeout_result=None):
'''Execute the specified coroutine and return its result. Issue a
cancellation request after seconds have elapsed. When a timeout
occurs, no exception is raised. Instead, timeout_result is
returned.
If coro is None, the result is an asynchronous context manager
that applies a timeout to a block of statements. For the context
manager case, the resulting context manager object has an expired
attribute set to True if time expired.
Note: ignore_after() may also be composed with other timeout
operations. TimeoutCancellationError and UncaughtTimeoutError
exceptions might be raised according to the same rules as for
timeout_after().
'''
if coro:
return _ignore_after_func(seconds, False, coro, args, timeout_result)
return TimeoutAfter(seconds, ignore=True)
def ignore_at(clock, coro=None, *args, timeout_result=None):
'''
Stop the enclosed task or block of code at an absolute
clock value. Same usage as ignore_after().
'''
if coro:
return _ignore_after_func(clock, True, coro, args, timeout_result)
return TimeoutAfter(clock, absolute=True, ignore=True)
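# Usage sketch (editor's addition, not part of this commit): ignore_after
# swallows the timeout; the coroutine form returns timeout_result and the
# context-manager form exposes an `expired` attribute.
async def _demo_ignore_after():
    result = await ignore_after(0.1, sleep, 1, timeout_result='gave up')
    assert result == 'gave up'
    async with ignore_after(0.1) as block:
        await sleep(1)
    assert block.expired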

239 torba/rpc/framing.py Normal file

@@ -0,0 +1,239 @@
# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''RPC message framing in a byte stream.'''
__all__ = ('FramerBase', 'NewlineFramer', 'BinaryFramer', 'BitcoinFramer',
'OversizedPayloadError', 'BadChecksumError', 'BadMagicError')
from hashlib import sha256 as _sha256
from struct import Struct
from asyncio import Queue
class FramerBase(object):
'''Abstract base class for a framer.
A framer breaks an incoming byte stream into protocol messages,
buffering if necessary. It also frames outgoing messages into
a byte stream.
'''
def frame(self, message):
'''Return the framed message.'''
raise NotImplementedError
def received_bytes(self, data):
'''Pass incoming network bytes.'''
raise NotImplementedError
async def receive_message(self):
'''Wait for a complete unframed message to arrive, and return it.'''
raise NotImplementedError
class NewlineFramer(FramerBase):
'''A framer for a protocol where messages are separated by newlines.'''
# The default max_size value is motivated by JSONRPC, where a
# normal request will be 250 bytes or less, and a reasonable
# batch may contain 4000 requests.
def __init__(self, max_size=250 * 4000):
'''max_size - an anti-DoS measure. If, after processing an incoming
message, buffered data would exceed max_size bytes, that
buffered data is dropped entirely and the framer waits for a
newline character to re-synchronize the stream.
'''
self.max_size = max_size
self.queue = Queue()
self.received_bytes = self.queue.put_nowait
self.synchronizing = False
self.residual = b''
def frame(self, message):
return message + b'\n'
async def receive_message(self):
parts = []
buffer_size = 0
while True:
part = self.residual
self.residual = b''
if not part:
part = await self.queue.get()
npos = part.find(b'\n')
if npos == -1:
parts.append(part)
buffer_size += len(part)
# Ignore over-sized messages; re-synchronize
if buffer_size <= self.max_size:
continue
self.synchronizing = True
raise MemoryError(f'dropping message over {self.max_size:,d} '
f'bytes and re-synchronizing')
tail, self.residual = part[:npos], part[npos + 1:]
if self.synchronizing:
self.synchronizing = False
return await self.receive_message()
else:
parts.append(tail)
return b''.join(parts)
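# Usage sketch (editor's addition, not part of this commit): feed raw
# network chunks in, get whole newline-delimited messages out. The JSON
# payloads are arbitrary examples.
async def _demo_newline_framer():
    framer = NewlineFramer()
    assert framer.frame(b'{"id": 1}') == b'{"id": 1}\n'
    framer.received_bytes(b'{"id": ')       # a partial message...
    framer.received_bytes(b'2}\nextra')     # ...completed, plus residue
    assert await framer.receive_message() == b'{"id": 2}'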
class ByteQueue(object):
'''A producer-consumer queue. Incoming network data is put as it
arrives, and the consumer calls an async method waiting for data of
a specific length.'''
def __init__(self):
self.queue = Queue()
self.parts = []
self.parts_len = 0
self.put_nowait = self.queue.put_nowait
async def receive(self, size):
while self.parts_len < size:
part = await self.queue.get()
self.parts.append(part)
self.parts_len += len(part)
self.parts_len -= size
whole = b''.join(self.parts)
self.parts = [whole[size:]]
return whole[:size]
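# Usage sketch (editor's addition, not part of this commit): the consumer
# asks for an exact byte count regardless of how the chunks arrived.
async def _demo_byte_queue():
    queue = ByteQueue()
    queue.put_nowait(b'abc')
    queue.put_nowait(b'defg')
    assert await queue.receive(5) == b'abcde'
    assert await queue.receive(2) == b'fg'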
class BinaryFramer(object):
'''A framer for binary messaging protocols.'''
def __init__(self):
self.byte_queue = ByteQueue()
self.message_queue = Queue()
self.received_bytes = self.byte_queue.put_nowait
def frame(self, message):
command, payload = message
return b''.join((
self._build_header(command, payload),
payload
))
async def receive_message(self):
command, payload_len, checksum = await self._receive_header()
payload = await self.byte_queue.receive(payload_len)
payload_checksum = self._checksum(payload)
if payload_checksum != checksum:
raise BadChecksumError(payload_checksum, checksum)
return command, payload
def _checksum(self, payload):
raise NotImplementedError
def _build_header(self, command, payload):
raise NotImplementedError
async def _receive_header(self):
raise NotImplementedError
# Helpers
struct_le_I = Struct('<I')
pack_le_uint32 = struct_le_I.pack
def sha256(x):
'''Simple wrapper of hashlib sha256.'''
return _sha256(x).digest()
def double_sha256(x):
'''SHA-256 of SHA-256, as used extensively in bitcoin.'''
return sha256(sha256(x))
class BadChecksumError(Exception):
pass
class BadMagicError(Exception):
pass
class OversizedPayloadError(Exception):
pass
class BitcoinFramer(BinaryFramer):
'''Provides a framer of binary message payloads in the style of the
Bitcoin network protocol.
Each binary message has the following elements, in order:
Magic - to confirm network (currently unused for stream sync)
Command - padded command
Length - payload length in bytes
Checksum - checksum of the payload
Payload - binary payload
Call frame((command, payload)) to get a framed message.
Pass incoming network bytes to received_bytes().
Wait on receive_message() to get incoming (command, payload) pairs.
'''
def __init__(self, magic, max_block_size):
def pad_command(command):
fill = 12 - len(command)
if fill < 0:
raise ValueError(f'command {command} too long')
return command + bytes(fill)
super().__init__()
self._magic = magic
self._max_block_size = max_block_size
self._pad_command = pad_command
self._unpack = Struct('<4s12sI4s').unpack
def _checksum(self, payload):
return double_sha256(payload)[:4]
def _build_header(self, command, payload):
return b''.join((
self._magic,
self._pad_command(command),
pack_le_uint32(len(payload)),
self._checksum(payload)
))
async def _receive_header(self):
header = await self.byte_queue.receive(24)
magic, command, payload_len, checksum = self._unpack(header)
if magic != self._magic:
raise BadMagicError(magic, self._magic)
command = command.rstrip(b'\0')
if payload_len > 1024 * 1024:
if command != b'block' or payload_len > self._max_block_size:
raise OversizedPayloadError(command, payload_len)
return command, payload_len, checksum
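# Usage sketch (editor's addition, not part of this commit): frame and
# unframe a ping message, using the same magic bytes and block-size cap
# the session module passes to its default framer.
async def _demo_bitcoin_framer():
    framer = BitcoinFramer(bytes.fromhex('e3e1f3e8'), 128_000_000)
    wire = framer.frame((b'ping', b'\x00' * 8))   # 8-byte nonce payload
    framer.received_bytes(wire)
    command, payload = await framer.receive_message()
    assert command == b'ping' and len(payload) == 8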

801 torba/rpc/jsonrpc.py Normal file

@@ -0,0 +1,801 @@
# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''Classes for JSONRPC versions 1.0 and 2.0, and a loose interpretation.'''
__all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
'JSONRPCAutoDetect', 'Request', 'Notification', 'Batch',
'RPCError', 'ProtocolError',
'JSONRPCConnection', 'handler_invocation')
import itertools
import json
from functools import partial
from numbers import Number
import attr
from asyncio import Queue, Event, CancelledError
from .util import signature_info
class SingleRequest(object):
__slots__ = ('method', 'args')
def __init__(self, method, args):
if not isinstance(method, str):
raise ProtocolError(JSONRPC.METHOD_NOT_FOUND,
'method must be a string')
if not isinstance(args, (list, tuple, dict)):
raise ProtocolError.invalid_args('request arguments must be a '
'list or a dictionary')
self.args = args
self.method = method
def __repr__(self):
return f'{self.__class__.__name__}({self.method!r}, {self.args!r})'
def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.method == other.method and self.args == other.args)
class Request(SingleRequest):
def send_result(self, response):
return None
class Notification(SingleRequest):
pass
class Batch(object):
__slots__ = ('items', )
def __init__(self, items):
if not isinstance(items, (list, tuple)):
raise ProtocolError.invalid_request('items must be a list')
if not items:
raise ProtocolError.empty_batch()
if not (all(isinstance(item, SingleRequest) for item in items) or
all(isinstance(item, Response) for item in items)):
raise ProtocolError.invalid_request('batch must be homogeneous')
self.items = items
def __len__(self):
return len(self.items)
def __getitem__(self, item):
return self.items[item]
def __iter__(self):
return iter(self.items)
def __repr__(self):
return f'Batch({len(self.items)} items)'
class Response(object):
__slots__ = ('result', )
def __init__(self, result):
# Type checking happens when converting to a message
self.result = result
class CodeMessageError(Exception):
def __init__(self, code, message):
super().__init__(code, message)
@property
def code(self):
return self.args[0]
@property
def message(self):
return self.args[1]
def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.code == other.code and self.message == other.message)
def __hash__(self):
# overridden to make the exception hashable
# see https://bugs.python.org/issue28603
return hash((self.code, self.message))
@classmethod
def invalid_args(cls, message):
return cls(JSONRPC.INVALID_ARGS, message)
@classmethod
def invalid_request(cls, message):
return cls(JSONRPC.INVALID_REQUEST, message)
@classmethod
def empty_batch(cls):
return cls.invalid_request('batch is empty')
class RPCError(CodeMessageError):
pass
class ProtocolError(CodeMessageError):
def __init__(self, code, message):
super().__init__(code, message)
# If not None send this unframed message over the network
self.error_message = None
# If the error was in a JSON response message, this is its message
# ID. Since None can be a valid response message ID, the sentinel
# "id" (the builtin) means the error was not in a JSON response
self.response_msg_id = id
class JSONRPC(object):
'''Abstract base class that interprets and constructs JSON RPC messages.'''
# Error codes. See http://www.jsonrpc.org/specification
PARSE_ERROR = -32700
INVALID_REQUEST = -32600
METHOD_NOT_FOUND = -32601
INVALID_ARGS = -32602
INTERNAL_ERROR = -32603
# Codes specific to this library
ERROR_CODE_UNAVAILABLE = -100
# Can be overridden by derived classes
allow_batches = True
@classmethod
def _message_id(cls, message, require_id):
'''Validate the message is a dictionary and return its ID.
Raise an error if the message is invalid or the ID is of an
invalid type. If it has no ID, raise an error if require_id
is True, otherwise return None.
'''
raise NotImplementedError
@classmethod
def _validate_message(cls, message):
'''Validate other parts of the message other than those
done in _message_id.'''
pass
@classmethod
def _request_args(cls, request):
'''Validate the existence and type of the arguments passed
in the request dictionary.'''
raise NotImplementedError
@classmethod
def _process_request(cls, payload):
request_id = None
try:
request_id = cls._message_id(payload, False)
cls._validate_message(payload)
method = payload.get('method')
if request_id is None:
item = Notification(method, cls._request_args(payload))
else:
item = Request(method, cls._request_args(payload))
return item, request_id
except ProtocolError as error:
code, message = error.code, error.message
raise cls._error(code, message, True, request_id)
@classmethod
def _process_response(cls, payload):
request_id = None
try:
request_id = cls._message_id(payload, True)
cls._validate_message(payload)
return Response(cls.response_value(payload)), request_id
except ProtocolError as error:
code, message = error.code, error.message
raise cls._error(code, message, False, request_id)
@classmethod
def _message_to_payload(cls, message):
'''Returns a Python object or a ProtocolError.'''
try:
return json.loads(message.decode())
except UnicodeDecodeError:
message = 'messages must be encoded in UTF-8'
except json.JSONDecodeError:
message = 'invalid JSON'
raise cls._error(cls.PARSE_ERROR, message, True, None)
@classmethod
def _error(cls, code, message, send, msg_id):
error = ProtocolError(code, message)
if send:
error.error_message = cls.response_message(error, msg_id)
else:
error.response_msg_id = msg_id
return error
#
# External API
#
@classmethod
def message_to_item(cls, message):
'''Translate an unframed received message and return an
(item, request_id) pair.
The item can be a Request, Notification, Response or a list.
A JSON RPC error response is returned as an RPCError inside a
Response object.
If a Batch is returned, request_id is an iterable of request
ids, one per batch member.
If the message violates the protocol in some way a
ProtocolError is returned, except if the message was
determined to be a response, in which case the ProtocolError
is placed inside a Response object. This is so that client
code can mark a request as having been responded to even if
the response was bad.
raises: ProtocolError
'''
payload = cls._message_to_payload(message)
if isinstance(payload, dict):
if 'method' in payload:
return cls._process_request(payload)
else:
return cls._process_response(payload)
elif isinstance(payload, list) and cls.allow_batches:
if not payload:
raise cls._error(JSONRPC.INVALID_REQUEST, 'batch is empty',
True, None)
return payload, None
raise cls._error(cls.INVALID_REQUEST,
'request object must be a dictionary', True, None)
# Message formation
@classmethod
def request_message(cls, item, request_id):
'''Convert an RPCRequest item to a message.'''
assert isinstance(item, Request)
return cls.encode_payload(cls.request_payload(item, request_id))
@classmethod
def notification_message(cls, item):
'''Convert an RPCRequest item to a message.'''
assert isinstance(item, Notification)
return cls.encode_payload(cls.request_payload(item, None))
@classmethod
def response_message(cls, result, request_id):
'''Convert a response result (or RPCError) to a message.'''
if isinstance(result, CodeMessageError):
payload = cls.error_payload(result, request_id)
else:
payload = cls.response_payload(result, request_id)
return cls.encode_payload(payload)
@classmethod
def batch_message(cls, batch, request_ids):
'''Convert a request Batch to a message.'''
assert isinstance(batch, Batch)
if not cls.allow_batches:
raise ProtocolError.invalid_request(
'protocol does not permit batches')
id_iter = iter(request_ids)
rm = cls.request_message
nm = cls.notification_message
parts = (rm(request, next(id_iter)) if isinstance(request, Request)
else nm(request) for request in batch)
return cls.batch_message_from_parts(parts)
@classmethod
def batch_message_from_parts(cls, messages):
'''Convert messages, one per batch item, into a batch message. At
least one message must be passed.
'''
# Comma-separate the messages and wrap the lot in square brackets
middle = b', '.join(messages)
if not middle:
raise ProtocolError.empty_batch()
return b''.join([b'[', middle, b']'])
@classmethod
def encode_payload(cls, payload):
'''Encode a Python object as JSON and convert it to bytes.'''
try:
return json.dumps(payload).encode()
except TypeError:
msg = f'JSON payload encoding error: {payload}'
raise ProtocolError(cls.INTERNAL_ERROR, msg) from None
class JSONRPCv1(JSONRPC):
'''JSON RPC version 1.0.'''
allow_batches = False
@classmethod
def _message_id(cls, message, require_id):
# JSONv1 requires an ID always, but without constraint on its type
# No need to test for a dictionary here as we don't handle batches.
if 'id' not in message:
raise ProtocolError.invalid_request('request has no "id"')
return message['id']
@classmethod
def _request_args(cls, request):
args = request.get('params')
if not isinstance(args, list):
raise ProtocolError.invalid_args(
f'invalid request arguments: {args}')
return args
@classmethod
def _best_effort_error(cls, error):
# Do our best to interpret the error
code = cls.ERROR_CODE_UNAVAILABLE
message = 'no error message provided'
if isinstance(error, str):
message = error
elif isinstance(error, int):
code = error
elif isinstance(error, dict):
if isinstance(error.get('message'), str):
message = error['message']
if isinstance(error.get('code'), int):
code = error['code']
return RPCError(code, message)
@classmethod
def response_value(cls, payload):
if 'result' not in payload or 'error' not in payload:
raise ProtocolError.invalid_request(
'response must contain both "result" and "error"')
result = payload['result']
error = payload['error']
if error is None:
return result # It seems None can be a valid result
if result is not None:
raise ProtocolError.invalid_request(
'response has a "result" and an "error"')
return cls._best_effort_error(error)
@classmethod
def request_payload(cls, request, request_id):
'''JSON v1 request (or notification) payload.'''
if isinstance(request.args, dict):
raise ProtocolError.invalid_args(
'JSONRPCv1 does not support named arguments')
return {
'method': request.method,
'params': request.args,
'id': request_id
}
@classmethod
def response_payload(cls, result, request_id):
'''JSON v1 response payload.'''
return {
'result': result,
'error': None,
'id': request_id
}
@classmethod
def error_payload(cls, error, request_id):
return {
'result': None,
'error': {'code': error.code, 'message': error.message},
'id': request_id
}
class JSONRPCv2(JSONRPC):
'''JSON RPC version 2.0.'''
@classmethod
def _message_id(cls, message, require_id):
if not isinstance(message, dict):
raise ProtocolError.invalid_request(
'request object must be a dictionary')
if 'id' in message:
request_id = message['id']
if not isinstance(request_id, (Number, str, type(None))):
raise ProtocolError.invalid_request(
f'invalid "id": {request_id}')
return request_id
else:
if require_id:
raise ProtocolError.invalid_request('request has no "id"')
return None
@classmethod
def _validate_message(cls, message):
if message.get('jsonrpc') != '2.0':
raise ProtocolError.invalid_request('"jsonrpc" is not "2.0"')
@classmethod
def _request_args(cls, request):
args = request.get('params', [])
if not isinstance(args, (dict, list)):
raise ProtocolError.invalid_args(
f'invalid request arguments: {args}')
return args
@classmethod
def response_value(cls, payload):
if 'result' in payload:
if 'error' in payload:
raise ProtocolError.invalid_request(
'response contains both "result" and "error"')
return payload['result']
if 'error' not in payload:
raise ProtocolError.invalid_request(
'response contains neither "result" nor "error"')
# Return an RPCError object
error = payload['error']
if isinstance(error, dict):
code = error.get('code')
message = error.get('message')
if isinstance(code, int) and isinstance(message, str):
return RPCError(code, message)
raise ProtocolError.invalid_request(
f'ill-formed response error object: {error}')
@classmethod
def request_payload(cls, request, request_id):
'''JSON v2 request (or notification) payload.'''
payload = {
'jsonrpc': '2.0',
'method': request.method,
}
# A notification?
if request_id is not None:
payload['id'] = request_id
# Preserve an empty dict, since missing params is read as an empty list
if request.args or request.args == {}:
payload['params'] = request.args
return payload
@classmethod
def response_payload(cls, result, request_id):
'''JSON v2 response payload.'''
return {
'jsonrpc': '2.0',
'result': result,
'id': request_id
}
@classmethod
def error_payload(cls, error, request_id):
return {
'jsonrpc': '2.0',
'error': {'code': error.code, 'message': error.message},
'id': request_id
}
class JSONRPCLoose(JSONRPC):
'''A relaxed version of JSON RPC.'''
# Don't be so loose we accept any old message ID
_message_id = JSONRPCv2._message_id
_validate_message = JSONRPC._validate_message
_request_args = JSONRPCv2._request_args
# Outgoing messages are JSONRPCv2 so we give the other side the
# best chance to assume / detect JSONRPCv2 as default protocol.
error_payload = JSONRPCv2.error_payload
request_payload = JSONRPCv2.request_payload
response_payload = JSONRPCv2.response_payload
@classmethod
def response_value(cls, payload):
# Return result, unless it is None and there is an error
if payload.get('error') is not None:
if payload.get('result') is not None:
raise ProtocolError.invalid_request(
'response contains both "result" and "error"')
return JSONRPCv1._best_effort_error(payload['error'])
if 'result' not in payload:
raise ProtocolError.invalid_request(
'response contains neither "result" nor "error"')
# Can be None
return payload['result']
class JSONRPCAutoDetect(JSONRPCv2):
@classmethod
def message_to_item(cls, message):
return cls.detect_protocol(message), None
@classmethod
def detect_protocol(cls, message):
'''Attempt to detect the protocol from the message.'''
main = cls._message_to_payload(message)
def protocol_for_payload(payload):
if not isinstance(payload, dict):
return JSONRPCLoose # Will error
# Obey an explicit "jsonrpc"
version = payload.get('jsonrpc')
if version == '2.0':
return JSONRPCv2
if version == '1.0':
return JSONRPCv1
# Now to decide between JSONRPCLoose and JSONRPCv1 if possible
if 'result' in payload and 'error' in payload:
return JSONRPCv1
return JSONRPCLoose
if isinstance(main, list):
parts = set(protocol_for_payload(payload) for payload in main)
# If all same protocol, return it
if len(parts) == 1:
return parts.pop()
# If strict protocol detected, return it, preferring JSONRPCv2.
# This means a batch of JSONRPCv1 will fail
for protocol in (JSONRPCv2, JSONRPCv1):
if protocol in parts:
return protocol
# Will error if no parts
return JSONRPCLoose
return protocol_for_payload(main)
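# Usage sketch (editor's addition, not part of this commit): detecting
# the protocol from single raw messages; the payloads are examples.
def _demo_detect_protocol():
    v2 = b'{"jsonrpc": "2.0", "method": "ping", "id": 1}'
    v1 = b'{"result": 5, "error": null, "id": 1}'
    assert JSONRPCAutoDetect.detect_protocol(v2) is JSONRPCv2
    assert JSONRPCAutoDetect.detect_protocol(v1) is JSONRPCv1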
class JSONRPCConnection(object):
'''Maintains state of a JSON RPC connection, in particular
encapsulating the handling of request IDs.
protocol - the JSON RPC protocol to follow
max_response_size - responses over this size send an error response
instead.
'''
_id_counter = itertools.count()
def __init__(self, protocol):
self._protocol = protocol
# Sent Requests and Batches that have not received a response.
# The key is its request ID; for a batch it is the sorted tuple
# of request IDs
self._requests = {}
# A public attribute intended to be settable dynamically
self.max_response_size = 0
def _oversized_response_message(self, request_id):
text = f'response too large (over {self.max_response_size:,d} bytes)'
error = RPCError.invalid_request(text)
return self._protocol.response_message(error, request_id)
def _receive_response(self, result, request_id):
if request_id not in self._requests:
if request_id is None and isinstance(result, RPCError):
message = f'diagnostic error received: {result}'
else:
message = f'response to unsent request (ID: {request_id})'
raise ProtocolError.invalid_request(message) from None
request, event = self._requests.pop(request_id)
event.result = result
event.set()
return []
def _receive_request_batch(self, payloads):
def item_send_result(request_id, result):
nonlocal size
part = protocol.response_message(result, request_id)
size += len(part) + 2
if size > self.max_response_size > 0:
part = self._oversized_response_message(request_id)
parts.append(part)
if len(parts) == count:
return protocol.batch_message_from_parts(parts)
return None
parts = []
items = []
size = 0
count = 0
protocol = self._protocol
for payload in payloads:
try:
item, request_id = protocol._process_request(payload)
items.append(item)
if isinstance(item, Request):
count += 1
item.send_result = partial(item_send_result, request_id)
except ProtocolError as error:
count += 1
parts.append(error.error_message)
if not items and parts:
error = ProtocolError(0, "")
error.error_message = protocol.batch_message_from_parts(parts)
raise error
return items
def _receive_response_batch(self, payloads):
request_ids = []
results = []
for payload in payloads:
# Let ProtocolError exceptions through
item, request_id = self._protocol._process_response(payload)
request_ids.append(request_id)
results.append(item.result)
ordered = sorted(zip(request_ids, results), key=lambda t: t[0])
ordered_ids, ordered_results = zip(*ordered)
if ordered_ids not in self._requests:
raise ProtocolError.invalid_request('response to unsent batch')
request_batch, event = self._requests.pop(ordered_ids)
event.result = ordered_results
event.set()
return []
def _send_result(self, request_id, result):
message = self._protocol.response_message(result, request_id)
if len(message) > self.max_response_size > 0:
message = self._oversized_response_message(request_id)
return message
def _event(self, request, request_id):
event = Event()
self._requests[request_id] = (request, event)
return event
#
# External API
#
def send_request(self, request):
'''Send a Request. Return a (message, event) pair.
The message is an unframed message to send over the network.
Wait on the event for the response; which will be in the
"result" attribute.
Raises: ProtocolError if the request violates the protocol
in some way.
'''
request_id = next(self._id_counter)
message = self._protocol.request_message(request, request_id)
return message, self._event(request, request_id)
def send_notification(self, notification):
return self._protocol.notification_message(notification)
def send_batch(self, batch):
ids = tuple(next(self._id_counter)
for request in batch if isinstance(request, Request))
message = self._protocol.batch_message(batch, ids)
event = self._event(batch, ids) if ids else None
return message, event
def receive_message(self, message):
'''Call with an unframed message received from the network.
Raises: ProtocolError if the message violates the protocol in
some way. However, if it happened in a response that can be
paired with a request, the ProtocolError is instead set in the
result attribute of the send_request() that caused the error.
'''
try:
item, request_id = self._protocol.message_to_item(message)
except ProtocolError as e:
if e.response_msg_id is not id:
return self._receive_response(e, e.response_msg_id)
raise
if isinstance(item, Request):
item.send_result = partial(self._send_result, request_id)
return [item]
if isinstance(item, Notification):
return [item]
if isinstance(item, Response):
return self._receive_response(item.result, request_id)
if isinstance(item, list):
if all(isinstance(payload, dict)
and ('result' in payload or 'error' in payload)
for payload in item):
return self._receive_response_batch(item)
else:
return self._receive_request_batch(item)
else:
# Protocol auto-detection hack
assert issubclass(item, JSONRPC)
self._protocol = item
return self.receive_message(message)
def cancel_pending_requests(self):
'''Cancel all pending requests.'''
exception = CancelledError()
for request, event in self._requests.values():
event.result = exception
event.set()
self._requests.clear()
def pending_requests(self):
'''All sent requests that have not received a response.'''
return [request for request, event in self._requests.values()]
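# Usage sketch (editor's addition, not part of this commit): a client-side
# round trip, pairing a response to its sent request. The id of 0 assumes
# a fresh class-level counter in a new process.
def _demo_connection_round_trip():
    connection = JSONRPCConnection(JSONRPCv2)
    message, event = connection.send_request(Request('ping', []))
    # ... write `message` to the network; later a response arrives:
    connection.receive_message(
        b'{"jsonrpc": "2.0", "result": "pong", "id": 0}')
    assert event.is_set() and event.result == 'pong'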
def handler_invocation(handler, request):
method, args = request.method, request.args
if handler is None:
raise RPCError(JSONRPC.METHOD_NOT_FOUND,
f'unknown method "{method}"')
# We must test for too few and too many arguments. How
# depends on whether the arguments were passed as a list or as
# a dictionary.
info = signature_info(handler)
if isinstance(args, (tuple, list)):
if len(args) < info.min_args:
s = '' if len(args) == 1 else 's'
raise RPCError.invalid_args(
f'{len(args)} argument{s} passed to method '
f'"{method}" but it requires {info.min_args}')
if info.max_args is not None and len(args) > info.max_args:
s = '' if len(args) == 1 else 's'
raise RPCError.invalid_args(
f'{len(args)} argument{s} passed to method '
f'{method} taking at most {info.max_args}')
return partial(handler, *args)
# Arguments passed by name
if info.other_names is None:
raise RPCError.invalid_args(f'method "{method}" cannot '
f'be called with named arguments')
missing = set(info.required_names).difference(args)
if missing:
s = '' if len(missing) == 1 else 's'
missing = ', '.join(sorted(f'"{name}"' for name in missing))
raise RPCError.invalid_args(f'method "{method}" requires '
f'parameter{s} {missing}')
if info.other_names is not any:
excess = set(args).difference(info.required_names)
excess = excess.difference(info.other_names)
if excess:
s = '' if len(excess) == 1 else 's'
excess = ', '.join(sorted(f'"{name}"' for name in excess))
raise RPCError.invalid_args(f'method "{method}" does not '
f'take parameter{s} {excess}')
return partial(handler, **args)
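# Usage sketch (editor's addition, not part of this commit): turn a
# request into a bound call after argument validation. Assumes
# util.signature_info reports one required positional argument for the
# handler, as its use above implies.
async def _demo_handler_invocation():
    async def echo(message):
        return message
    request = Request('echo', ['hi'])
    handler_call = handler_invocation(echo, request)
    assert await handler_call() == 'hi'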

549 torba/rpc/session.py Normal file

@@ -0,0 +1,549 @@
# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
__all__ = ('Connector', 'RPCSession', 'MessageSession', 'Server',
'BatchError')
import asyncio
import logging
import time
from contextlib import suppress
from . import *
from .util import Concurrency
class Connector(object):
def __init__(self, session_factory, host=None, port=None, proxy=None,
**kwargs):
self.session_factory = session_factory
self.host = host
self.port = port
self.proxy = proxy
self.loop = kwargs.get('loop', asyncio.get_event_loop())
self.kwargs = kwargs
async def create_connection(self):
'''Initiate a connection.'''
connector = self.proxy or self.loop
return await connector.create_connection(
self.session_factory, self.host, self.port, **self.kwargs)
async def __aenter__(self):
transport, self.protocol = await self.create_connection()
# By default, do not limit outgoing connections
self.protocol.bw_limit = 0
return self.protocol
async def __aexit__(self, exc_type, exc_value, traceback):
await self.protocol.close()
class SessionBase(asyncio.Protocol):
'''Base class of networking sessions.
There is no client / server distinction other than who initiated
the connection.
To initiate a connection to a remote server pass host, port and
proxy to the constructor, and then call create_connection(). Each
successful call should have a corresponding call to close().
Alternatively if used in a with statement, the connection is made
on entry to the block, and closed on exit from the block.
'''
max_errors = 10
def __init__(self, *, framer=None, loop=None):
self.framer = framer or self.default_framer()
self.loop = loop or asyncio.get_event_loop()
self.logger = logging.getLogger(self.__class__.__name__)
self.transport = None
# Set when a connection is made
self._address = None
self._proxy_address = None
# For logger.debug messages
self.verbosity = 0
# Cleared when the send socket is full
self._can_send = Event()
self._can_send.set()
self._pm_task = None
self._task_group = TaskGroup()
# Force-close a connection if a send doesn't succeed in this time
self.max_send_delay = 60
# Statistics. The RPC object also keeps its own statistics.
self.start_time = time.time()
self.errors = 0
self.send_count = 0
self.send_size = 0
self.last_send = self.start_time
self.recv_count = 0
self.recv_size = 0
self.last_recv = self.start_time
# Bandwidth usage per hour before throttling starts
self.bw_limit = 2000000
self.bw_time = self.start_time
self.bw_charge = 0
# Concurrency control
self.max_concurrent = 6
self._concurrency = Concurrency(self.max_concurrent)
async def _update_concurrency(self):
# A non-positive value means not to limit concurrency
if self.bw_limit <= 0:
return
now = time.time()
# Reduce the recorded usage in proportion to the elapsed time
refund = (now - self.bw_time) * (self.bw_limit / 3600)
self.bw_charge = max(0, self.bw_charge - int(refund))
self.bw_time = now
# Reduce concurrency allocation by 1 for each whole bw_limit used
throttle = int(self.bw_charge / self.bw_limit)
target = max(1, self.max_concurrent - throttle)
current = self._concurrency.max_concurrent
if target != current:
self.logger.info(f'changing task concurrency from {current} '
f'to {target}')
await self._concurrency.set_max_concurrent(target)
def _using_bandwidth(self, size):
'''Called when sending or receiving size bytes.'''
self.bw_charge += size
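# Worked example (editor's note, not part of this commit): with the
# default bw_limit of 2,000,000 bytes/hour, usage is refunded at
# 2,000,000 / 3600, about 555 bytes per second. A peer that has
# accumulated a bw_charge of 5,000,000 bytes gets
# throttle = int(5_000_000 / 2_000_000) = 2, so with max_concurrent = 6
# its target concurrency drops to 4.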
async def _process_messages(self):
'''Process incoming messages asynchronously and consume the
results.
'''
async def collect_tasks():
next_done = task_group.next_done
while True:
await next_done()
task_group = self._task_group
async with task_group:
await self.spawn(self._receive_messages)
await self.spawn(collect_tasks)
async def _limited_wait(self, secs):
# Wait at most secs seconds to send, otherwise abort the connection
try:
async with timeout_after(secs):
await self._can_send.wait()
except TaskTimeout:
self.abort()
raise
async def _send_message(self, message):
if not self._can_send.is_set():
await self._limited_wait(self.max_send_delay)
if not self.is_closing():
framed_message = self.framer.frame(message)
self.send_size += len(framed_message)
self._using_bandwidth(len(framed_message))
self.send_count += 1
self.last_send = time.time()
if self.verbosity >= 4:
self.logger.debug(f'Sending framed message {framed_message}')
self.transport.write(framed_message)
def _bump_errors(self):
self.errors += 1
if self.errors >= self.max_errors:
# Don't await self.close() because that is self-cancelling
self._close()
def _close(self):
if self.transport:
self.transport.close()
# asyncio framework
def data_received(self, framed_message):
'''Called by asyncio when a message comes in.'''
if self.verbosity >= 4:
self.logger.debug(f'Received framed message {framed_message}')
self.recv_size += len(framed_message)
self._using_bandwidth(len(framed_message))
self.framer.received_bytes(framed_message)
def pause_writing(self):
'''Transport calls when the send buffer is full.'''
if not self.is_closing():
self._can_send.clear()
self.transport.pause_reading()
def resume_writing(self):
'''Transport calls when the send buffer has room.'''
if not self._can_send.is_set():
self._can_send.set()
self.transport.resume_reading()
def connection_made(self, transport):
'''Called by asyncio when a connection is established.
Derived classes overriding this method must call this first.'''
self.transport = transport
# This would throw if called on a closed SSL transport. Fixed
# in asyncio in Python 3.6.1 and 3.5.4
peer_address = transport.get_extra_info('peername')
# If the Socks proxy was used then _address is already set to
# the remote address
if self._address:
self._proxy_address = peer_address
else:
self._address = peer_address
self._pm_task = spawn_sync(self._process_messages(), loop=self.loop)
def connection_lost(self, exc):
'''Called by asyncio when the connection closes.
Tear down things done in connection_made.'''
self._address = None
self.transport = None
self._pm_task.cancel()
# Release waiting tasks
self._can_send.set()
# External API
def default_framer(self):
'''Return a default framer.'''
raise NotImplementedError
def peer_address(self):
'''Returns the peer's address (Python networking address), or None if
no connection or an error.
This is the result of socket.getpeername() when the connection
was made.
'''
return self._address
def peer_address_str(self):
'''Returns the peer's IP address and port as a human-readable
string.'''
if not self._address:
return 'unknown'
ip_addr_str, port = self._address[:2]
if ':' in ip_addr_str:
return f'[{ip_addr_str}]:{port}'
else:
return f'{ip_addr_str}:{port}'
async def spawn(self, coro, *args):
'''If the session is connected, spawn a task that is cancelled
on disconnect, and return it. Otherwise return None.'''
group = self._task_group
if not group.closed():
return await group.spawn(coro, *args)
else:
return None
def is_closing(self):
'''Return True if the connection is closing.'''
return not self.transport or self.transport.is_closing()
def abort(self):
'''Forcefully close the connection.'''
if self.transport:
self.transport.abort()
async def close(self, *, force_after=30):
'''Close the connection and return when closed.'''
self._close()
if self._pm_task:
with suppress(CancelledError):
async with ignore_after(force_after):
await self._pm_task
self.abort()
await self._pm_task
class MessageSession(SessionBase):
'''Session class for protocols where messages are not tied to responses,
such as the Bitcoin protocol.
To use as a client (connection-opening) session, pass host, port
and perhaps a proxy.
'''
async def _receive_messages(self):
while not self.is_closing():
try:
message = await self.framer.receive_message()
except BadMagicError as e:
magic, expected = e.args
self.logger.error(
f'bad network magic: got {magic} expected {expected}, '
f'disconnecting'
)
self._close()
except OversizedPayloadError as e:
command, payload_len = e.args
self.logger.error(
f'oversized payload of {payload_len:,d} bytes to command '
f'{command}, disconnecting'
)
self._close()
except BadChecksumError as e:
payload_checksum, claimed_checksum = e.args
self.logger.warning(
f'checksum mismatch: actual {payload_checksum.hex()} '
f'vs claimed {claimed_checksum.hex()}'
)
self._bump_errors()
else:
self.last_recv = time.time()
self.recv_count += 1
if self.recv_count % 10 == 0:
await self._update_concurrency()
await self.spawn(self._throttled_message(message))
async def _throttled_message(self, message):
'''Process a single request, respecting the concurrency limit.'''
async with self._concurrency.semaphore:
try:
await self.handle_message(message)
except ProtocolError as e:
self.logger.error(f'{e}')
self._bump_errors()
except CancelledError:
raise
except Exception:
self.logger.exception(f'exception handling {message}')
self._bump_errors()
# External API
def default_framer(self):
'''Return a bitcoin framer.'''
return BitcoinFramer(bytes.fromhex('e3e1f3e8'), 128_000_000)
async def handle_message(self, message):
'''message is a (command, payload) pair.'''
pass
async def send_message(self, message):
'''Send a message (command, payload) over the network.'''
await self._send_message(message)
class BatchError(Exception):
def __init__(self, request):
self.request = request # BatchRequest object
class BatchRequest(object):
'''Used to build a batch request to send to the server.
Attributes batch and results are initially None.
Adding an invalid request or notification immediately raises a
ProtocolError.
On exiting the with clause, it will:
1) create a Batch object for the requests in the order they were
added. If the batch is empty this raises a ProtocolError.
2) set the "batch" attribute to be that batch
3) send the batch request and wait for a response
4) raise a ProtocolError if the protocol was violated by the
server. Currently this only happens if it gave more than one
response to any request
5) otherwise there is precisely one response to each Request. Set
the "results" attribute to the tuple of results; the responses
are ordered to match the Requests in the batch. Notifications
do not get a response.
6) if raise_errors is True and any individual response was a JSON
RPC error response, or violated the protocol in some way, a
BatchError exception is raised. Otherwise the caller can be
certain each request returned a standard result.
'''
def __init__(self, session, raise_errors):
self._session = session
self._raise_errors = raise_errors
self._requests = []
self.batch = None
self.results = None
def add_request(self, method, args=()):
self._requests.append(Request(method, args))
def add_notification(self, method, args=()):
self._requests.append(Notification(method, args))
def __len__(self):
return len(self._requests)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
if exc_type is None:
self.batch = Batch(self._requests)
message, event = self._session.connection.send_batch(self.batch)
await self._session._send_message(message)
await event.wait()
self.results = event.result
if self._raise_errors:
if any(isinstance(item, Exception) for item in event.result):
raise BatchError(self)
class RPCSession(SessionBase):
'''Base class for protocols where a message can lead to a response,
for example JSON RPC.'''
def __init__(self, *, framer=None, loop=None, connection=None):
super().__init__(framer=framer, loop=loop)
self.connection = connection or self.default_connection()
async def _receive_messages(self):
while not self.is_closing():
try:
message = await self.framer.receive_message()
except MemoryError as e:
self.logger.warning(f'{e!r}')
continue
self.last_recv = time.time()
self.recv_count += 1
if self.recv_count % 10 == 0:
await self._update_concurrency()
try:
requests = self.connection.receive_message(message)
except ProtocolError as e:
self.logger.debug(f'{e}')
if e.error_message:
await self._send_message(e.error_message)
if e.code == JSONRPC.PARSE_ERROR:
self.max_errors = 0
self._bump_errors()
else:
for request in requests:
await self.spawn(self._throttled_request(request))
async def _throttled_request(self, request):
'''Process a single request, respecting the concurrency limit.'''
async with self._concurrency.semaphore:
try:
result = await self.handle_request(request)
except (ProtocolError, RPCError) as e:
result = e
except CancelledError:
raise
except Exception:
self.logger.exception(f'exception handling {request}')
result = RPCError(JSONRPC.INTERNAL_ERROR,
'internal server error')
if isinstance(request, Request):
message = request.send_result(result)
if message:
await self._send_message(message)
if isinstance(result, Exception):
self._bump_errors()
def connection_lost(self, exc):
# Cancel pending requests and message processing
self.connection.cancel_pending_requests()
super().connection_lost(exc)
# External API
def default_connection(self):
'''Return a default connection if the user provides none.'''
return JSONRPCConnection(JSONRPCv2)
def default_framer(self):
'''Return a default framer.'''
return NewlineFramer()
async def handle_request(self, request):
pass
async def send_request(self, method, args=()):
'''Send an RPC request over the network.'''
message, event = self.connection.send_request(Request(method, args))
await self._send_message(message)
await event.wait()
result = event.result
if isinstance(result, Exception):
raise result
return result
async def send_notification(self, method, args=()):
'''Send an RPC notification over the network.'''
message = self.connection.send_notification(Notification(method, args))
await self._send_message(message)
def send_batch(self, raise_errors=False):
'''Return a BatchRequest. Intended to be used like so:
async with session.send_batch() as batch:
batch.add_request("method1")
batch.add_request("sum", (x, y))
batch.add_notification("updated")
for result in batch.results:
...
Note that in some circumstances exceptions can be raised; see
BatchRequest doc string.
'''
return BatchRequest(self, raise_errors)
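# A minimal sketch, not part of the library: handle_request is the
# override point for servicing incoming calls.  'echo' is a
# hypothetical method name; request.method / request.args and the
# JSONRPC.METHOD_NOT_FOUND constant are assumed from jsonrpc.py.
class _EchoSession(RPCSession):
    async def handle_request(self, request):
        if request.method == 'echo':
            return list(request.args)
        raise RPCError(JSONRPC.METHOD_NOT_FOUND,
                       f'unknown method "{request.method}"')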
class Server(object):
'''A simple wrapper around an asyncio.Server object.'''
def __init__(self, session_factory, host=None, port=None, *,
loop=None, **kwargs):
self.host = host
self.port = port
self.loop = loop or asyncio.get_event_loop()
self.server = None
self._session_factory = session_factory
self._kwargs = kwargs
async def listen(self):
self.server = await self.loop.create_server(
self._session_factory, self.host, self.port, **self._kwargs)
async def close(self):
'''Close the listening socket. This does not close any ServerSession
objects created to handle incoming connections.
'''
if self.server:
self.server.close()
await self.server.wait_closed()
self.server = None
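# Usage sketch, not part of the library: serve the hypothetical
# _EchoSession above and shut down cleanly.  Assumes this coroutine is
# run by an event loop.
async def _example_serve():
    server = Server(_EchoSession, 'localhost', 8000)
    await server.listen()
    try:
        await asyncio.sleep(3600)   # serve for an hour
    finally:
        await server.close()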

439
torba/rpc/socks.py Normal file
View file

@@ -0,0 +1,439 @@
# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''SOCKS proxying.'''
import sys
import asyncio
import collections
import ipaddress
import socket
import struct
from functools import partial
__all__ = ('SOCKSUserAuth', 'SOCKS4', 'SOCKS4a', 'SOCKS5', 'SOCKSProxy',
'SOCKSError', 'SOCKSProtocolError', 'SOCKSFailure')
SOCKSUserAuth = collections.namedtuple("SOCKSUserAuth", "username password")
class SOCKSError(Exception):
'''Base class for SOCKS exceptions. Each raised exception will be
an instance of a derived class.'''
class SOCKSProtocolError(SOCKSError):
    '''Raised when the proxy does not follow the SOCKS protocol.'''
class SOCKSFailure(SOCKSError):
    '''Raised when the proxy refuses or fails to make a connection.'''
class NeedData(Exception):
pass
class SOCKSBase(object):
@classmethod
def name(cls):
return cls.__name__
def __init__(self):
self._buffer = bytes()
self._state = self._start
def _read(self, size):
if len(self._buffer) < size:
raise NeedData(size - len(self._buffer))
result = self._buffer[:size]
self._buffer = self._buffer[size:]
return result
def receive_data(self, data):
self._buffer += data
def next_message(self):
return self._state()
class SOCKS4(SOCKSBase):
'''SOCKS4 protocol wrapper.'''
# See http://ftp.icm.edu.pl/packages/socks/socks4/SOCKS4.protocol
REPLY_CODES = {
90: 'request granted',
91: 'request rejected or failed',
92: ('request rejected because SOCKS server cannot connect '
'to identd on the client'),
93: ('request rejected because the client program and identd '
'report different user-ids')
}
def __init__(self, dst_host, dst_port, auth):
super().__init__()
self._dst_host = self._check_host(dst_host)
self._dst_port = dst_port
self._auth = auth
@classmethod
def _check_host(cls, host):
if not isinstance(host, ipaddress.IPv4Address):
try:
host = ipaddress.IPv4Address(host)
except ValueError:
raise SOCKSProtocolError(
f'SOCKS4 requires an IPv4 address: {host}') from None
return host
def _start(self):
self._state = self._first_response
if isinstance(self._dst_host, ipaddress.IPv4Address):
# SOCKS4
dst_ip_packed = self._dst_host.packed
host_bytes = b''
else:
# SOCKS4a
dst_ip_packed = b'\0\0\0\1'
host_bytes = self._dst_host.encode() + b'\0'
if isinstance(self._auth, SOCKSUserAuth):
user_id = self._auth.username.encode()
else:
user_id = b''
# Send TCP/IP stream CONNECT request
return b''.join([b'\4\1', struct.pack('>H', self._dst_port),
dst_ip_packed, user_id, b'\0', host_bytes])
def _first_response(self):
# Wait for 8-byte response
data = self._read(8)
if data[0] != 0:
raise SOCKSProtocolError(f'invalid {self.name()} proxy '
f'response: {data}')
reply_code = data[1]
if reply_code != 90:
msg = self.REPLY_CODES.get(
reply_code, f'unknown {self.name()} reply code {reply_code}')
raise SOCKSFailure(f'{self.name()} proxy request failed: {msg}')
# Other fields ignored
return None
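# Worked example, a sketch rather than part of the module: the SOCKS4
# CONNECT request for 1.2.3.4:8080 with no user id is nine bytes --
# version 0x04, command 0x01, the big-endian port, the packed IPv4
# address and a NUL terminating the empty user id.
def _socks4_wire_example():
    client = SOCKS4(ipaddress.IPv4Address('1.2.3.4'), 8080, None)
    request = client.next_message()
    assert request == (b'\4\1' + struct.pack('>H', 8080)
                       + b'\x01\x02\x03\x04' + b'\0')
    # An 8-byte reply with code 90 means success; None ends the handshake
    client.receive_data(b'\0' + bytes([90]) + b'\0' * 6)
    assert client.next_message() is None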
class SOCKS4a(SOCKS4):
@classmethod
def _check_host(cls, host):
if not isinstance(host, (str, ipaddress.IPv4Address)):
raise SOCKSProtocolError(
f'SOCKS4a requires an IPv4 address or host name: {host}')
return host
class SOCKS5(SOCKSBase):
    '''SOCKS5 protocol wrapper.'''
# See https://tools.ietf.org/html/rfc1928
ERROR_CODES = {
1: 'general SOCKS server failure',
2: 'connection not allowed by ruleset',
3: 'network unreachable',
4: 'host unreachable',
5: 'connection refused',
6: 'TTL expired',
7: 'command not supported',
8: 'address type not supported',
}
def __init__(self, dst_host, dst_port, auth):
super().__init__()
self._dst_bytes = self._destination_bytes(dst_host, dst_port)
self._auth_bytes, self._auth_methods = self._authentication(auth)
def _destination_bytes(self, host, port):
if isinstance(host, ipaddress.IPv4Address):
addr_bytes = b'\1' + host.packed
elif isinstance(host, ipaddress.IPv6Address):
addr_bytes = b'\4' + host.packed
elif isinstance(host, str):
host = host.encode()
if len(host) > 255:
raise SOCKSProtocolError(f'hostname too long: '
f'{len(host)} bytes')
addr_bytes = b'\3' + bytes([len(host)]) + host
else:
raise SOCKSProtocolError(f'SOCKS5 requires an IPv4 address, IPv6 '
f'address, or host name: {host}')
return addr_bytes + struct.pack('>H', port)
def _authentication(self, auth):
if isinstance(auth, SOCKSUserAuth):
user_bytes = auth.username.encode()
if not 0 < len(user_bytes) < 256:
raise SOCKSProtocolError(f'username {auth.username} has '
f'invalid length {len(user_bytes)}')
pwd_bytes = auth.password.encode()
if not 0 < len(pwd_bytes) < 256:
raise SOCKSProtocolError(f'password has invalid length '
f'{len(pwd_bytes)}')
return b''.join([bytes([1, len(user_bytes)]), user_bytes,
bytes([len(pwd_bytes)]), pwd_bytes]), [0, 2]
return b'', [0]
def _start(self):
self._state = self._first_response
return (b'\5' + bytes([len(self._auth_methods)])
+ bytes(m for m in self._auth_methods))
def _first_response(self):
# Wait for 2-byte response
data = self._read(2)
if data[0] != 5:
raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}')
if data[1] not in self._auth_methods:
raise SOCKSFailure('SOCKS5 proxy rejected authentication methods')
# Authenticate if user-password authentication
if data[1] == 2:
self._state = self._auth_response
return self._auth_bytes
return self._request_connection()
def _auth_response(self):
data = self._read(2)
if data[0] != 1:
raise SOCKSProtocolError(f'invalid SOCKS5 proxy auth '
f'response: {data}')
if data[1] != 0:
raise SOCKSFailure(f'SOCKS5 proxy auth failure code: '
f'{data[1]}')
return self._request_connection()
def _request_connection(self):
# Send connection request
self._state = self._connect_response
return b'\5\1\0' + self._dst_bytes
def _connect_response(self):
data = self._read(5)
if data[0] != 5 or data[2] != 0 or data[3] not in (1, 3, 4):
raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}')
if data[1] != 0:
raise SOCKSFailure(self.ERROR_CODES.get(
data[1], f'unknown SOCKS5 error code: {data[1]}'))
        if data[3] == 1:
            addr_len = 3   # IPv4: 4 address bytes, 1 already read
        elif data[3] == 3:
            addr_len = data[4]   # hostname of this many bytes follows
        else:
            addr_len = 15  # IPv6: 16 address bytes, 1 already read
self._state = partial(self._connect_response_rest, addr_len)
return self.next_message()
def _connect_response_rest(self, addr_len):
self._read(addr_len + 2)
return None
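# Worked example, a sketch rather than part of the module: an
# unauthenticated SOCKS5 handshake for example.com:443.  The greeting
# offers only method 0 (no authentication); once the proxy accepts it
# with b'\5\0', the CONNECT request carries an ATYP=3 hostname.
def _socks5_wire_example():
    client = SOCKS5('example.com', 443, None)
    assert client.next_message() == b'\5\1\0'
    client.receive_data(b'\5\0')   # proxy: version 5, method 0 chosen
    host = b'example.com'
    assert client.next_message() == (b'\5\1\0\3' + bytes([len(host)])
                                     + host + struct.pack('>H', 443))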
class SOCKSProxy(object):
def __init__(self, address, protocol, auth):
'''A SOCKS proxy at an address following a SOCKS protocol. auth is an
authentication method to use when connecting, or None.
address is a (host, port) pair; for IPv6 it can instead be a
(host, port, flowinfo, scopeid) 4-tuple.
'''
self.address = address
self.protocol = protocol
self.auth = auth
# Set on each successful connection via the proxy to the
# result of socket.getpeername()
self.peername = None
def __str__(self):
auth = 'username' if self.auth else 'none'
return f'{self.protocol.name()} proxy at {self.address}, auth: {auth}'
async def _handshake(self, client, sock, loop):
while True:
count = 0
try:
message = client.next_message()
except NeedData as e:
count = e.args[0]
else:
if message is None:
return
await loop.sock_sendall(sock, message)
if count:
data = await loop.sock_recv(sock, count)
if not data:
raise SOCKSProtocolError("EOF received")
client.receive_data(data)
async def _connect_one(self, host, port):
'''Connect to the proxy and perform a handshake requesting a
connection to (host, port).
Return the open socket on success, or the exception on failure.
'''
client = self.protocol(host, port, self.auth)
sock = socket.socket()
loop = asyncio.get_event_loop()
try:
# A non-blocking socket is required by loop socket methods
sock.setblocking(False)
await loop.sock_connect(sock, self.address)
await self._handshake(client, sock, loop)
self.peername = sock.getpeername()
return sock
except Exception as e:
            # Don't close the socket except on Linux - see
            # https://github.com/kyuupichan/aiorpcX/issues/8
if sys.platform.startswith('linux'):
sock.close()
return e
async def _connect(self, addresses):
'''Connect to the proxy and perform a handshake requesting a
connection to each address in addresses.
Return an (open_socket, address) pair on success.
'''
assert len(addresses) > 0
exceptions = []
for address in addresses:
host, port = address[:2]
sock = await self._connect_one(host, port)
if isinstance(sock, socket.socket):
return sock, address
exceptions.append(sock)
strings = set(f'{exc!r}' for exc in exceptions)
raise (exceptions[0] if len(strings) == 1 else
OSError(f'multiple exceptions: {", ".join(strings)}'))
async def _detect_proxy(self):
'''Return True if it appears we can connect to a SOCKS proxy,
otherwise False.
'''
if self.protocol is SOCKS4a:
host, port = 'www.apple.com', 80
else:
host, port = ipaddress.IPv4Address('8.8.8.8'), 53
sock = await self._connect_one(host, port)
if isinstance(sock, socket.socket):
sock.close()
return True
# SOCKSFailure indicates something failed, but that we are
# likely talking to a proxy
return isinstance(sock, SOCKSFailure)
@classmethod
async def auto_detect_address(cls, address, auth):
'''Try to detect a SOCKS proxy at address using the authentication
        method (or None). SOCKS5, SOCKS4a and SOCKS4 are tried in
order. If a SOCKS proxy is detected a SOCKSProxy object is
returned.
Returning a SOCKSProxy does not mean it is functioning - for
example, it may have no network connectivity.
If no proxy is detected return None.
'''
for protocol in (SOCKS5, SOCKS4a, SOCKS4):
proxy = cls(address, protocol, auth)
if await proxy._detect_proxy():
return proxy
return None
@classmethod
async def auto_detect_host(cls, host, ports, auth):
'''Try to detect a SOCKS proxy on a host on one of the ports.
        Calls auto_detect_address for the ports in order; a SOCKSProxy
        object for the first detected proxy is returned.
Returning a SOCKSProxy does not mean it is functioning - for
example, it may have no network connectivity.
If no proxy is detected return None.
'''
for port in ports:
address = (host, port)
proxy = await cls.auto_detect_address(address, auth)
if proxy:
return proxy
return None
async def create_connection(self, protocol_factory, host, port, *,
resolve=False, ssl=None,
family=0, proto=0, flags=0):
'''Set up a connection to (host, port) through the proxy.
If resolve is True then host is resolved locally with
getaddrinfo using family, proto and flags, otherwise the proxy
is asked to resolve host.
The function signature is similar to loop.create_connection()
with the same result. The attribute _address is set on the
protocol to the address of the successful remote connection.
Additionally raises SOCKSError if something goes wrong with
the proxy handshake.
'''
loop = asyncio.get_event_loop()
if resolve:
infos = await loop.getaddrinfo(host, port, family=family,
type=socket.SOCK_STREAM,
proto=proto, flags=flags)
addresses = [info[4] for info in infos]
else:
addresses = [(host, port)]
sock, address = await self._connect(addresses)
def set_address():
protocol = protocol_factory()
protocol._address = address
return protocol
return await loop.create_connection(
set_address, sock=sock, ssl=ssl,
server_hostname=host if ssl else None)
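# Usage sketch, not part of the module: probe the common local SOCKS
# ports and connect through whichever proxy answers.  The ports and
# target here are illustrative assumptions; protocol_factory is any
# asyncio protocol factory.
async def _example_via_proxy(protocol_factory):
    proxy = await SOCKSProxy.auto_detect_host('localhost',
                                              [9050, 1080], None)
    if proxy is None:
        raise SOCKSError('no SOCKS proxy detected')
    return await proxy.create_connection(
        protocol_factory, 'example.com', 80)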

120
torba/rpc/util.py Normal file
View file

@@ -0,0 +1,120 @@
# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
__all__ = ()
import asyncio
from collections import namedtuple
from functools import partial
import inspect
def normalize_corofunc(corofunc, args):
if asyncio.iscoroutine(corofunc):
if args != ():
raise ValueError('args cannot be passed with a coroutine')
return corofunc
return corofunc(*args)
def is_async_call(func):
'''inspect.iscoroutinefunction that looks through partials.'''
while isinstance(func, partial):
func = func.func
return inspect.iscoroutinefunction(func)
# other_names: None means the function cannot be called with keyword
# arguments; the builtin `any` means any keyword name is acceptable
SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args '
'required_names other_names')
def signature_info(func):
params = inspect.signature(func).parameters
min_args = max_args = 0
required_names = []
other_names = []
no_names = False
for p in params.values():
if p.kind == p.POSITIONAL_OR_KEYWORD:
max_args += 1
if p.default is p.empty:
min_args += 1
required_names.append(p.name)
else:
other_names.append(p.name)
elif p.kind == p.KEYWORD_ONLY:
other_names.append(p.name)
elif p.kind == p.VAR_POSITIONAL:
max_args = None
elif p.kind == p.VAR_KEYWORD:
other_names = any
elif p.kind == p.POSITIONAL_ONLY:
max_args += 1
if p.default is p.empty:
min_args += 1
no_names = True
if no_names:
other_names = None
return SignatureInfo(min_args, max_args, required_names, other_names)
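# Example, a sketch rather than part of the module: for the function
# below two positional arguments are required, a third is optional,
# and 'verbose' may only be passed by keyword.
def _transfer(source, dest, amount=0, *, verbose=False):
    pass

assert signature_info(_transfer) == SignatureInfo(
    2, 3, ['source', 'dest'], ['amount', 'verbose'])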
class Concurrency(object):
def __init__(self, max_concurrent):
self._require_non_negative(max_concurrent)
self._max_concurrent = max_concurrent
self.semaphore = asyncio.Semaphore(max_concurrent)
def _require_non_negative(self, value):
if not isinstance(value, int) or value < 0:
            raise RuntimeError('concurrency must be a non-negative integer')
@property
def max_concurrent(self):
return self._max_concurrent
async def set_max_concurrent(self, value):
self._require_non_negative(value)
diff = value - self._max_concurrent
self._max_concurrent = value
if diff >= 0:
for _ in range(diff):
self.semaphore.release()
else:
for _ in range(-diff):
await self.semaphore.acquire()
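# Usage sketch, not part of the module: bound concurrent calls to a
# handler, then tighten the limit at runtime; excess permits are
# reclaimed by awaiting the semaphore inside set_max_concurrent.
async def _example_throttled(handler, items):
    concurrency = Concurrency(max_concurrent=10)
    async def throttled(item):
        async with concurrency.semaphore:
            return await handler(item)
    results = await asyncio.gather(*(throttled(item) for item in items))
    await concurrency.set_max_concurrent(5)
    return results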
def check_task(logger, task):
if not task.cancelled():
try:
task.result()
except Exception:
logger.error('task crashed: %r', task, exc_info=True)
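# Usage sketch, not part of the module: attach check_task as a done
# callback so a crashed background task is logged instead of failing
# silently.
def _spawn_logged(loop, logger, coro):
    task = loop.create_task(coro)
    task.add_done_callback(partial(check_task, logger))
    return task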

View file

@@ -15,7 +15,7 @@ from struct import pack, unpack
import time
from functools import partial
from aiorpcx import TaskGroup, run_in_thread
from torba.rpc import TaskGroup, run_in_thread
import torba
from torba.server.daemon import DaemonError

View file

@@ -22,7 +22,7 @@ from torba.server.util import hex_to_bytes, class_logger,\
unpack_le_uint16_from, pack_varint
from torba.server.hash import hex_str_to_hash, hash_to_hex_str
from torba.server.tx import DeserializerDecred
from aiorpcx import JSONRPC
from torba.rpc import JSONRPC
class DaemonError(Exception):

View file

@@ -19,8 +19,8 @@ from glob import glob
from struct import pack, unpack
import attr
from aiorpcx import run_in_thread, sleep
from torba.rpc import run_in_thread, sleep
from torba.server import util
from torba.server.hash import hash_to_hex_str, HASHX_LEN
from torba.server.merkle import Merkle, MerkleCache

View file

@@ -14,8 +14,8 @@ from asyncio import Lock
from collections import defaultdict
import attr
from aiorpcx import TaskGroup, run_in_thread, sleep
from torba.rpc import TaskGroup, run_in_thread, sleep
from torba.server.hash import hash_to_hex_str, hex_str_to_hash
from torba.server.util import class_logger, chunks
from torba.server.db import UTXO

View file

@@ -28,8 +28,7 @@
from math import ceil, log
from aiorpcx import Event
from torba.rpc import Event
from torba.server.hash import double_sha256

View file

@@ -14,11 +14,10 @@ import ssl
import time
from collections import defaultdict, Counter
from aiorpcx import (Connector, RPCSession, SOCKSProxy,
from torba.rpc import (Connector, RPCSession, SOCKSProxy,
Notification, handler_invocation,
SOCKSError, RPCError, TaskTimeout, TaskGroup, Event,
sleep, run_in_thread, ignore_after, timeout_after)
from torba.server.peer import Peer
from torba.server.util import class_logger, protocol_tuple

View file

@@ -19,13 +19,12 @@ import time
from collections import defaultdict
from functools import partial
from aiorpcx import (
import torba
from torba.rpc import (
RPCSession, JSONRPCAutoDetect, JSONRPCConnection,
TaskGroup, handler_invocation, RPCError, Request, ignore_after, sleep,
Event
)
import torba
from torba.server import text
from torba.server import util
from torba.server.hash import (sha256, hash_to_hex_str, hex_str_to_hash,
@@ -664,9 +663,9 @@ class SessionBase(RPCSession):
super().connection_lost(exc)
self.session_mgr.remove_session(self)
msg = ''
if not self.can_send.is_set():
if not self._can_send.is_set():
msg += ' whilst paused'
if self.concurrency.max_concurrent != self.max_concurrent:
if self._concurrency.max_concurrent != self.max_concurrent:
msg += ' whilst throttled'
if self.send_size >= 1024*1024:
msg += ('. Sent {:,d} bytes in {:,d} messages'