merged aiorpcx into torba.rpc
This commit is contained in:
parent
ba5fc2a627
commit
899a6f0d4a
15 changed files with 2582 additions and 14 deletions
1
setup.py
1
setup.py
|
@ -10,7 +10,6 @@ with open(os.path.join(BASE, 'README.md'), encoding='utf-8') as fh:
|
|||
|
||||
REQUIRES = [
|
||||
'aiohttp',
|
||||
'aiorpcx==0.9.0',
|
||||
'coincurve',
|
||||
'pbkdf2',
|
||||
'cryptography',
|
||||
|
|
13
torba/rpc/__init__.py
Normal file
13
torba/rpc/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
from .curio import *
|
||||
from .framing import *
|
||||
from .jsonrpc import *
|
||||
from .socks import *
|
||||
from .session import *
|
||||
from .util import *
|
||||
|
||||
__all__ = (curio.__all__ +
|
||||
framing.__all__ +
|
||||
jsonrpc.__all__ +
|
||||
socks.__all__ +
|
||||
session.__all__ +
|
||||
util.__all__)
|
411
torba/rpc/curio.py
Normal file
411
torba/rpc/curio.py
Normal file
|
@ -0,0 +1,411 @@
|
|||
# The code below is mostly my own but based on the interfaces of the
|
||||
# curio library by David Beazley. I'm considering switching to using
|
||||
# curio. In the mean-time this is an attempt to provide a similar
|
||||
# clean, pure-async interface and move away from direct
|
||||
# framework-specific dependencies. As asyncio differs in its design
|
||||
# it is not possible to provide identical semantics.
|
||||
#
|
||||
# The curio library is distributed under the following licence:
|
||||
#
|
||||
# Copyright (C) 2015-2017
|
||||
# David Beazley (Dabeaz LLC)
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are
|
||||
# met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of the David Beazley or Dabeaz LLC may be used to
|
||||
# endorse or promote products derived from this software without
|
||||
# specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from asyncio import (
|
||||
CancelledError, get_event_loop, Queue, Event, Lock, Semaphore,
|
||||
sleep, Task
|
||||
)
|
||||
from collections import deque
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
|
||||
from .util import normalize_corofunc, check_task
|
||||
|
||||
|
||||
__all__ = (
|
||||
'Queue', 'Event', 'Lock', 'Semaphore', 'sleep', 'CancelledError',
|
||||
'run_in_thread', 'spawn', 'spawn_sync', 'TaskGroup',
|
||||
'TaskTimeout', 'TimeoutCancellationError', 'UncaughtTimeoutError',
|
||||
'timeout_after', 'timeout_at', 'ignore_after', 'ignore_at',
|
||||
)
|
||||
|
||||
|
||||
async def run_in_thread(func, *args):
|
||||
'''Run a function in a separate thread, and await its completion.'''
|
||||
return await get_event_loop().run_in_executor(None, func, *args)
|
||||
|
||||
|
||||
async def spawn(coro, *args, loop=None, report_crash=True):
|
||||
return spawn_sync(coro, *args, loop=loop, report_crash=report_crash)
|
||||
|
||||
|
||||
def spawn_sync(coro, *args, loop=None, report_crash=True):
|
||||
coro = normalize_corofunc(coro, args)
|
||||
loop = loop or get_event_loop()
|
||||
task = loop.create_task(coro)
|
||||
if report_crash:
|
||||
task.add_done_callback(partial(check_task, logging))
|
||||
return task
|
||||
|
||||
|
||||
class TaskGroup(object):
|
||||
'''A class representing a group of executing tasks. tasks is an
|
||||
optional set of existing tasks to put into the group. New tasks
|
||||
can later be added using the spawn() method below. wait specifies
|
||||
the policy used for waiting for tasks. See the join() method
|
||||
below. Each TaskGroup is an independent entity. Task groups do not
|
||||
form a hierarchy or any kind of relationship to other previously
|
||||
created task groups or tasks. Moreover, Tasks created by the top
|
||||
level spawn() function are not placed into any task group. To
|
||||
create a task in a group, it should be created using
|
||||
TaskGroup.spawn() or explicitly added using TaskGroup.add_task().
|
||||
|
||||
completed attribute: the first task that completed with a result
|
||||
in the group. Takes into account the wait option used in the
|
||||
TaskGroup constructor (but not in the join method)`.
|
||||
'''
|
||||
|
||||
def __init__(self, tasks=(), *, wait=all):
|
||||
if wait not in (any, all, object):
|
||||
raise ValueError('invalid wait argument')
|
||||
self._done = deque()
|
||||
self._pending = set()
|
||||
self._wait = wait
|
||||
self._done_event = Event()
|
||||
self._logger = logging.getLogger(self.__class__.__name__)
|
||||
self._closed = False
|
||||
self.completed = None
|
||||
for task in tasks:
|
||||
self._add_task(task)
|
||||
|
||||
def _add_task(self, task):
|
||||
'''Add an already existing task to the task group.'''
|
||||
if hasattr(task, '_task_group'):
|
||||
raise RuntimeError('task is already part of a group')
|
||||
if self._closed:
|
||||
raise RuntimeError('task group is closed')
|
||||
task._task_group = self
|
||||
if task.done():
|
||||
self._done.append(task)
|
||||
else:
|
||||
self._pending.add(task)
|
||||
task.add_done_callback(self._on_done)
|
||||
|
||||
def _on_done(self, task):
|
||||
task._task_group = None
|
||||
self._pending.remove(task)
|
||||
self._done.append(task)
|
||||
self._done_event.set()
|
||||
if self.completed is None:
|
||||
if not task.cancelled() and not task.exception():
|
||||
if self._wait is object and task.result() is None:
|
||||
pass
|
||||
else:
|
||||
self.completed = task
|
||||
|
||||
async def spawn(self, coro, *args):
|
||||
'''Create a new task that’s part of the group. Returns a Task
|
||||
instance.
|
||||
'''
|
||||
task = await spawn(coro, *args, report_crash=False)
|
||||
self._add_task(task)
|
||||
return task
|
||||
|
||||
async def add_task(self, task):
|
||||
'''Add an already existing task to the task group.'''
|
||||
self._add_task(task)
|
||||
|
||||
async def next_done(self):
|
||||
'''Returns the next completed task. Returns None if no more tasks
|
||||
remain. A TaskGroup may also be used as an asynchronous iterator.
|
||||
'''
|
||||
if not self._done and self._pending:
|
||||
self._done_event.clear()
|
||||
await self._done_event.wait()
|
||||
if self._done:
|
||||
return self._done.popleft()
|
||||
return None
|
||||
|
||||
async def next_result(self):
|
||||
'''Returns the result of the next completed task. If the task failed
|
||||
with an exception, that exception is raised. A RuntimeError
|
||||
exception is raised if this is called when no remaining tasks
|
||||
are available.'''
|
||||
task = await self.next_done()
|
||||
if not task:
|
||||
raise RuntimeError('no tasks remain')
|
||||
return task.result()
|
||||
|
||||
async def join(self):
|
||||
'''Wait for tasks in the group to terminate according to the wait
|
||||
policy for the group.
|
||||
|
||||
If the join() operation itself is cancelled, all remaining
|
||||
tasks in the group are also cancelled.
|
||||
|
||||
If a TaskGroup is used as a context manager, the join() method
|
||||
is called on context-exit.
|
||||
|
||||
Once join() returns, no more tasks may be added to the task
|
||||
group. Tasks can be added while join() is running.
|
||||
'''
|
||||
def errored(task):
|
||||
return not task.cancelled() and task.exception()
|
||||
|
||||
try:
|
||||
if self._wait in (all, object):
|
||||
while True:
|
||||
task = await self.next_done()
|
||||
if task is None:
|
||||
return
|
||||
if errored(task):
|
||||
break
|
||||
if self._wait is object:
|
||||
if task.cancelled() or task.result() is not None:
|
||||
return
|
||||
else: # any
|
||||
task = await self.next_done()
|
||||
if task is None or not errored(task):
|
||||
return
|
||||
finally:
|
||||
await self.cancel_remaining()
|
||||
|
||||
if errored(task):
|
||||
raise task.exception()
|
||||
|
||||
async def cancel_remaining(self):
|
||||
'''Cancel all remaining tasks.'''
|
||||
self._closed = True
|
||||
for task in list(self._pending):
|
||||
task.cancel()
|
||||
with suppress(CancelledError):
|
||||
await task
|
||||
|
||||
def closed(self):
|
||||
return self._closed
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
task = await self.next_done()
|
||||
if task:
|
||||
return task
|
||||
raise StopAsyncIteration
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
if exc_type:
|
||||
await self.cancel_remaining()
|
||||
else:
|
||||
await self.join()
|
||||
|
||||
|
||||
class TaskTimeout(CancelledError):
|
||||
|
||||
def __init__(self, secs):
|
||||
self.secs = secs
|
||||
|
||||
def __str__(self):
|
||||
return f'task timed out after {self.args[0]}s'
|
||||
|
||||
|
||||
class TimeoutCancellationError(CancelledError):
|
||||
pass
|
||||
|
||||
|
||||
class UncaughtTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _set_new_deadline(task, deadline):
|
||||
def timeout_task():
|
||||
# Unfortunately task.cancel is all we can do with asyncio
|
||||
task.cancel()
|
||||
task._timed_out = deadline
|
||||
task._deadline_handle = task._loop.call_at(deadline, timeout_task)
|
||||
|
||||
|
||||
def _set_task_deadline(task, deadline):
|
||||
deadlines = getattr(task, '_deadlines', [])
|
||||
if deadlines:
|
||||
if deadline < min(deadlines):
|
||||
task._deadline_handle.cancel()
|
||||
_set_new_deadline(task, deadline)
|
||||
else:
|
||||
_set_new_deadline(task, deadline)
|
||||
deadlines.append(deadline)
|
||||
task._deadlines = deadlines
|
||||
task._timed_out = None
|
||||
|
||||
|
||||
def _unset_task_deadline(task):
|
||||
deadlines = task._deadlines
|
||||
timed_out_deadline = task._timed_out
|
||||
uncaught = timed_out_deadline not in deadlines
|
||||
task._deadline_handle.cancel()
|
||||
deadlines.pop()
|
||||
if deadlines:
|
||||
_set_new_deadline(task, min(deadlines))
|
||||
return timed_out_deadline, uncaught
|
||||
|
||||
|
||||
class TimeoutAfter(object):
|
||||
|
||||
def __init__(self, deadline, *, ignore=False, absolute=False):
|
||||
self._deadline = deadline
|
||||
self._ignore = ignore
|
||||
self._absolute = absolute
|
||||
self.expired = False
|
||||
|
||||
async def __aenter__(self):
|
||||
task = asyncio.current_task()
|
||||
loop_time = task._loop.time()
|
||||
if self._absolute:
|
||||
self._secs = self._deadline - loop_time
|
||||
else:
|
||||
self._secs = self._deadline
|
||||
self._deadline += loop_time
|
||||
_set_task_deadline(task, self._deadline)
|
||||
self.expired = False
|
||||
self._task = task
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
timed_out_deadline, uncaught = _unset_task_deadline(self._task)
|
||||
if exc_type not in (CancelledError, TaskTimeout,
|
||||
TimeoutCancellationError):
|
||||
return False
|
||||
if timed_out_deadline == self._deadline:
|
||||
self.expired = True
|
||||
if self._ignore:
|
||||
return True
|
||||
raise TaskTimeout(self._secs) from None
|
||||
if timed_out_deadline is None:
|
||||
assert exc_type is CancelledError
|
||||
return False
|
||||
if uncaught:
|
||||
raise UncaughtTimeoutError('uncaught timeout received')
|
||||
if exc_type is TimeoutCancellationError:
|
||||
return False
|
||||
raise TimeoutCancellationError(timed_out_deadline) from None
|
||||
|
||||
|
||||
async def _timeout_after_func(seconds, absolute, coro, args):
|
||||
coro = normalize_corofunc(coro, args)
|
||||
async with TimeoutAfter(seconds, absolute=absolute):
|
||||
return await coro
|
||||
|
||||
|
||||
def timeout_after(seconds, coro=None, *args):
|
||||
'''Execute the specified coroutine and return its result. However,
|
||||
issue a cancellation request to the calling task after seconds
|
||||
have elapsed. When this happens, a TaskTimeout exception is
|
||||
raised. If coro is None, the result of this function serves
|
||||
as an asynchronous context manager that applies a timeout to a
|
||||
block of statements.
|
||||
|
||||
timeout_after() may be composed with other timeout_after()
|
||||
operations (i.e., nested timeouts). If an outer timeout expires
|
||||
first, then TimeoutCancellationError is raised instead of
|
||||
TaskTimeout. If an inner timeout expires and fails to properly
|
||||
TaskTimeout, a UncaughtTimeoutError is raised in the outer
|
||||
timeout.
|
||||
|
||||
'''
|
||||
if coro:
|
||||
return _timeout_after_func(seconds, False, coro, args)
|
||||
|
||||
return TimeoutAfter(seconds)
|
||||
|
||||
|
||||
def timeout_at(clock, coro=None, *args):
|
||||
'''Execute the specified coroutine and return its result. However,
|
||||
issue a cancellation request to the calling task after seconds
|
||||
have elapsed. When this happens, a TaskTimeout exception is
|
||||
raised. If coro is None, the result of this function serves
|
||||
as an asynchronous context manager that applies a timeout to a
|
||||
block of statements.
|
||||
|
||||
timeout_after() may be composed with other timeout_after()
|
||||
operations (i.e., nested timeouts). If an outer timeout expires
|
||||
first, then TimeoutCancellationError is raised instead of
|
||||
TaskTimeout. If an inner timeout expires and fails to properly
|
||||
TaskTimeout, a UncaughtTimeoutError is raised in the outer
|
||||
timeout.
|
||||
|
||||
'''
|
||||
if coro:
|
||||
return _timeout_after_func(clock, True, coro, args)
|
||||
|
||||
return TimeoutAfter(clock, absolute=True)
|
||||
|
||||
|
||||
async def _ignore_after_func(seconds, absolute, coro, args, timeout_result):
|
||||
coro = normalize_corofunc(coro, args)
|
||||
async with TimeoutAfter(seconds, absolute=absolute, ignore=True):
|
||||
return await coro
|
||||
|
||||
return timeout_result
|
||||
|
||||
|
||||
def ignore_after(seconds, coro=None, *args, timeout_result=None):
|
||||
'''Execute the specified coroutine and return its result. Issue a
|
||||
cancellation request after seconds have elapsed. When a timeout
|
||||
occurs, no exception is raised. Instead, timeout_result is
|
||||
returned.
|
||||
|
||||
If coro is None, the result is an asynchronous context manager
|
||||
that applies a timeout to a block of statements. For the context
|
||||
manager case, the resulting context manager object has an expired
|
||||
attribute set to True if time expired.
|
||||
|
||||
Note: ignore_after() may also be composed with other timeout
|
||||
operations. TimeoutCancellationError and UncaughtTimeoutError
|
||||
exceptions might be raised according to the same rules as for
|
||||
timeout_after().
|
||||
'''
|
||||
if coro:
|
||||
return _ignore_after_func(seconds, False, coro, args, timeout_result)
|
||||
|
||||
return TimeoutAfter(seconds, ignore=True)
|
||||
|
||||
|
||||
def ignore_at(clock, coro=None, *args, timeout_result=None):
|
||||
'''
|
||||
Stop the enclosed task or block of code at an absolute
|
||||
clock value. Same usage as ignore_after().
|
||||
'''
|
||||
if coro:
|
||||
return _ignore_after_func(clock, True, coro, args, timeout_result)
|
||||
|
||||
return TimeoutAfter(clock, absolute=True, ignore=True)
|
239
torba/rpc/framing.py
Normal file
239
torba/rpc/framing.py
Normal file
|
@ -0,0 +1,239 @@
|
|||
# Copyright (c) 2018, Neil Booth
|
||||
#
|
||||
# All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
'''RPC message framing in a byte stream.'''
|
||||
|
||||
__all__ = ('FramerBase', 'NewlineFramer', 'BinaryFramer', 'BitcoinFramer',
|
||||
'OversizedPayloadError', 'BadChecksumError', 'BadMagicError')
|
||||
|
||||
from hashlib import sha256 as _sha256
|
||||
from struct import Struct
|
||||
from asyncio import Queue
|
||||
|
||||
|
||||
class FramerBase(object):
|
||||
'''Abstract base class for a framer.
|
||||
|
||||
A framer breaks an incoming byte stream into protocol messages,
|
||||
buffering if necesary. It also frames outgoing messages into
|
||||
a byte stream.
|
||||
'''
|
||||
|
||||
def frame(self, message):
|
||||
'''Return the framed message.'''
|
||||
raise NotImplementedError
|
||||
|
||||
def received_bytes(self, data):
|
||||
'''Pass incoming network bytes.'''
|
||||
raise NotImplementedError
|
||||
|
||||
async def receive_message(self):
|
||||
'''Wait for a complete unframed message to arrive, and return it.'''
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class NewlineFramer(FramerBase):
|
||||
'''A framer for a protocol where messages are separated by newlines.'''
|
||||
|
||||
# The default max_size value is motivated by JSONRPC, where a
|
||||
# normal request will be 250 bytes or less, and a reasonable
|
||||
# batch may contain 4000 requests.
|
||||
def __init__(self, max_size=250 * 4000):
|
||||
'''max_size - an anti-DoS measure. If, after processing an incoming
|
||||
message, buffered data would exceed max_size bytes, that
|
||||
buffered data is dropped entirely and the framer waits for a
|
||||
newline character to re-synchronize the stream.
|
||||
'''
|
||||
self.max_size = max_size
|
||||
self.queue = Queue()
|
||||
self.received_bytes = self.queue.put_nowait
|
||||
self.synchronizing = False
|
||||
self.residual = b''
|
||||
|
||||
def frame(self, message):
|
||||
return message + b'\n'
|
||||
|
||||
async def receive_message(self):
|
||||
parts = []
|
||||
buffer_size = 0
|
||||
while True:
|
||||
part = self.residual
|
||||
self.residual = b''
|
||||
if not part:
|
||||
part = await self.queue.get()
|
||||
|
||||
npos = part.find(b'\n')
|
||||
if npos == -1:
|
||||
parts.append(part)
|
||||
buffer_size += len(part)
|
||||
# Ignore over-sized messages; re-synchronize
|
||||
if buffer_size <= self.max_size:
|
||||
continue
|
||||
self.synchronizing = True
|
||||
raise MemoryError(f'dropping message over {self.max_size:,d} '
|
||||
f'bytes and re-synchronizing')
|
||||
|
||||
tail, self.residual = part[:npos], part[npos + 1:]
|
||||
if self.synchronizing:
|
||||
self.synchronizing = False
|
||||
return await self.receive_message()
|
||||
else:
|
||||
parts.append(tail)
|
||||
return b''.join(parts)
|
||||
|
||||
|
||||
class ByteQueue(object):
|
||||
'''A producer-comsumer queue. Incoming network data is put as it
|
||||
arrives, and the consumer calls an async method waiting for data of
|
||||
a specific length.'''
|
||||
|
||||
def __init__(self):
|
||||
self.queue = Queue()
|
||||
self.parts = []
|
||||
self.parts_len = 0
|
||||
self.put_nowait = self.queue.put_nowait
|
||||
|
||||
async def receive(self, size):
|
||||
while self.parts_len < size:
|
||||
part = await self.queue.get()
|
||||
self.parts.append(part)
|
||||
self.parts_len += len(part)
|
||||
self.parts_len -= size
|
||||
whole = b''.join(self.parts)
|
||||
self.parts = [whole[size:]]
|
||||
return whole[:size]
|
||||
|
||||
|
||||
class BinaryFramer(object):
|
||||
'''A framer for binary messaging protocols.'''
|
||||
|
||||
def __init__(self):
|
||||
self.byte_queue = ByteQueue()
|
||||
self.message_queue = Queue()
|
||||
self.received_bytes = self.byte_queue.put_nowait
|
||||
|
||||
def frame(self, message):
|
||||
command, payload = message
|
||||
return b''.join((
|
||||
self._build_header(command, payload),
|
||||
payload
|
||||
))
|
||||
|
||||
async def receive_message(self):
|
||||
command, payload_len, checksum = await self._receive_header()
|
||||
payload = await self.byte_queue.receive(payload_len)
|
||||
payload_checksum = self._checksum(payload)
|
||||
if payload_checksum != checksum:
|
||||
raise BadChecksumError(payload_checksum, checksum)
|
||||
return command, payload
|
||||
|
||||
def _checksum(self, payload):
|
||||
raise NotImplementedError
|
||||
|
||||
def _build_header(self, command, payload):
|
||||
raise NotImplementedError
|
||||
|
||||
async def _receive_header(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Helpers
|
||||
struct_le_I = Struct('<I')
|
||||
pack_le_uint32 = struct_le_I.pack
|
||||
|
||||
|
||||
def sha256(x):
|
||||
'''Simple wrapper of hashlib sha256.'''
|
||||
return _sha256(x).digest()
|
||||
|
||||
|
||||
def double_sha256(x):
|
||||
'''SHA-256 of SHA-256, as used extensively in bitcoin.'''
|
||||
return sha256(sha256(x))
|
||||
|
||||
|
||||
class BadChecksumError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BadMagicError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class OversizedPayloadError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BitcoinFramer(BinaryFramer):
|
||||
'''Provides a framer of binary message payloads in the style of the
|
||||
Bitcoin network protocol.
|
||||
|
||||
Each binary message has the following elements, in order:
|
||||
|
||||
Magic - to confirm network (currently unused for stream sync)
|
||||
Command - padded command
|
||||
Length - payload length in bytes
|
||||
Checksum - checksum of the payload
|
||||
Payload - binary payload
|
||||
|
||||
Call frame(command, payload) to get a framed message.
|
||||
Pass incoming network bytes to received_bytes().
|
||||
Wait on receive_message() to get incoming (command, payload) pairs.
|
||||
'''
|
||||
|
||||
def __init__(self, magic, max_block_size):
|
||||
def pad_command(command):
|
||||
fill = 12 - len(command)
|
||||
if fill < 0:
|
||||
raise ValueError(f'command {command} too long')
|
||||
return command + bytes(fill)
|
||||
|
||||
super().__init__()
|
||||
self._magic = magic
|
||||
self._max_block_size = max_block_size
|
||||
self._pad_command = pad_command
|
||||
self._unpack = Struct(f'<4s12sI4s').unpack
|
||||
|
||||
def _checksum(self, payload):
|
||||
return double_sha256(payload)[:4]
|
||||
|
||||
def _build_header(self, command, payload):
|
||||
return b''.join((
|
||||
self._magic,
|
||||
self._pad_command(command),
|
||||
pack_le_uint32(len(payload)),
|
||||
self._checksum(payload)
|
||||
))
|
||||
|
||||
async def _receive_header(self):
|
||||
header = await self.byte_queue.receive(24)
|
||||
magic, command, payload_len, checksum = self._unpack(header)
|
||||
if magic != self._magic:
|
||||
raise BadMagicError(magic, self._magic)
|
||||
command = command.rstrip(b'\0')
|
||||
if payload_len > 1024 * 1024:
|
||||
if command != b'block' or payload_len > self._max_block_size:
|
||||
raise OversizedPayloadError(command, payload_len)
|
||||
return command, payload_len, checksum
|
801
torba/rpc/jsonrpc.py
Normal file
801
torba/rpc/jsonrpc.py
Normal file
|
@ -0,0 +1,801 @@
|
|||
# Copyright (c) 2018, Neil Booth
|
||||
#
|
||||
# All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
'''Classes for JSONRPC versions 1.0 and 2.0, and a loose interpretation.'''
|
||||
|
||||
__all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
|
||||
'JSONRPCAutoDetect', 'Request', 'Notification', 'Batch',
|
||||
'RPCError', 'ProtocolError',
|
||||
'JSONRPCConnection', 'handler_invocation')
|
||||
|
||||
import itertools
|
||||
import json
|
||||
from functools import partial
|
||||
from numbers import Number
|
||||
|
||||
import attr
|
||||
from asyncio import Queue, Event, CancelledError
|
||||
from .util import signature_info
|
||||
|
||||
|
||||
class SingleRequest(object):
|
||||
__slots__ = ('method', 'args')
|
||||
|
||||
def __init__(self, method, args):
|
||||
if not isinstance(method, str):
|
||||
raise ProtocolError(JSONRPC.METHOD_NOT_FOUND,
|
||||
'method must be a string')
|
||||
if not isinstance(args, (list, tuple, dict)):
|
||||
raise ProtocolError.invalid_args('request arguments must be a '
|
||||
'list or a dictionary')
|
||||
self.args = args
|
||||
self.method = method
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.method!r}, {self.args!r})'
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, self.__class__) and
|
||||
self.method == other.method and self.args == other.args)
|
||||
|
||||
|
||||
class Request(SingleRequest):
|
||||
def send_result(self, response):
|
||||
return None
|
||||
|
||||
|
||||
class Notification(SingleRequest):
|
||||
pass
|
||||
|
||||
|
||||
class Batch(object):
|
||||
__slots__ = ('items', )
|
||||
|
||||
def __init__(self, items):
|
||||
if not isinstance(items, (list, tuple)):
|
||||
raise ProtocolError.invalid_request('items must be a list')
|
||||
if not items:
|
||||
raise ProtocolError.empty_batch()
|
||||
if not (all(isinstance(item, SingleRequest) for item in items) or
|
||||
all(isinstance(item, Response) for item in items)):
|
||||
raise ProtocolError.invalid_request('batch must be homogeneous')
|
||||
self.items = items
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.items[item]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.items)
|
||||
|
||||
def __repr__(self):
|
||||
return f'Batch({len(self.items)} items)'
|
||||
|
||||
|
||||
class Response(object):
|
||||
__slots__ = ('result', )
|
||||
|
||||
def __init__(self, result):
|
||||
# Type checking happens when converting to a message
|
||||
self.result = result
|
||||
|
||||
|
||||
class CodeMessageError(Exception):
|
||||
|
||||
def __init__(self, code, message):
|
||||
super().__init__(code, message)
|
||||
|
||||
@property
|
||||
def code(self):
|
||||
return self.args[0]
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return self.args[1]
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, self.__class__) and
|
||||
self.code == other.code and self.message == other.message)
|
||||
|
||||
def __hash__(self):
|
||||
# overridden to make the exception hashable
|
||||
# see https://bugs.python.org/issue28603
|
||||
return hash((self.code, self.message))
|
||||
|
||||
@classmethod
|
||||
def invalid_args(cls, message):
|
||||
return cls(JSONRPC.INVALID_ARGS, message)
|
||||
|
||||
@classmethod
|
||||
def invalid_request(cls, message):
|
||||
return cls(JSONRPC.INVALID_REQUEST, message)
|
||||
|
||||
@classmethod
|
||||
def empty_batch(cls):
|
||||
return cls.invalid_request('batch is empty')
|
||||
|
||||
|
||||
class RPCError(CodeMessageError):
|
||||
pass
|
||||
|
||||
|
||||
class ProtocolError(CodeMessageError):
|
||||
|
||||
def __init__(self, code, message):
|
||||
super().__init__(code, message)
|
||||
# If not None send this unframed message over the network
|
||||
self.error_message = None
|
||||
# If the error was in a JSON response message; its message ID.
|
||||
# Since None can be a response message ID, "id" means the
|
||||
# error was not sent in a JSON response
|
||||
self.response_msg_id = id
|
||||
|
||||
|
||||
class JSONRPC(object):
|
||||
'''Abstract base class that interprets and constructs JSON RPC messages.'''
|
||||
|
||||
# Error codes. See http://www.jsonrpc.org/specification
|
||||
PARSE_ERROR = -32700
|
||||
INVALID_REQUEST = -32600
|
||||
METHOD_NOT_FOUND = -32601
|
||||
INVALID_ARGS = -32602
|
||||
INTERNAL_ERROR = -32603
|
||||
# Codes specific to this library
|
||||
ERROR_CODE_UNAVAILABLE = -100
|
||||
|
||||
# Can be overridden by derived classes
|
||||
allow_batches = True
|
||||
|
||||
@classmethod
|
||||
def _message_id(cls, message, require_id):
|
||||
'''Validate the message is a dictionary and return its ID.
|
||||
|
||||
Raise an error if the message is invalid or the ID is of an
|
||||
invalid type. If it has no ID, raise an error if require_id
|
||||
is True, otherwise return None.
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _validate_message(cls, message):
|
||||
'''Validate other parts of the message other than those
|
||||
done in _message_id.'''
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _request_args(cls, request):
|
||||
'''Validate the existence and type of the arguments passed
|
||||
in the request dictionary.'''
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _process_request(cls, payload):
|
||||
request_id = None
|
||||
try:
|
||||
request_id = cls._message_id(payload, False)
|
||||
cls._validate_message(payload)
|
||||
method = payload.get('method')
|
||||
if request_id is None:
|
||||
item = Notification(method, cls._request_args(payload))
|
||||
else:
|
||||
item = Request(method, cls._request_args(payload))
|
||||
return item, request_id
|
||||
except ProtocolError as error:
|
||||
code, message = error.code, error.message
|
||||
raise cls._error(code, message, True, request_id)
|
||||
|
||||
@classmethod
|
||||
def _process_response(cls, payload):
|
||||
request_id = None
|
||||
try:
|
||||
request_id = cls._message_id(payload, True)
|
||||
cls._validate_message(payload)
|
||||
return Response(cls.response_value(payload)), request_id
|
||||
except ProtocolError as error:
|
||||
code, message = error.code, error.message
|
||||
raise cls._error(code, message, False, request_id)
|
||||
|
||||
@classmethod
|
||||
def _message_to_payload(cls, message):
|
||||
'''Returns a Python object or a ProtocolError.'''
|
||||
try:
|
||||
return json.loads(message.decode())
|
||||
except UnicodeDecodeError:
|
||||
message = 'messages must be encoded in UTF-8'
|
||||
except json.JSONDecodeError:
|
||||
message = 'invalid JSON'
|
||||
raise cls._error(cls.PARSE_ERROR, message, True, None)
|
||||
|
||||
@classmethod
|
||||
def _error(cls, code, message, send, msg_id):
|
||||
error = ProtocolError(code, message)
|
||||
if send:
|
||||
error.error_message = cls.response_message(error, msg_id)
|
||||
else:
|
||||
error.response_msg_id = msg_id
|
||||
return error
|
||||
|
||||
#
|
||||
# External API
|
||||
#
|
||||
|
||||
@classmethod
|
||||
def message_to_item(cls, message):
|
||||
'''Translate an unframed received message and return an
|
||||
(item, request_id) pair.
|
||||
|
||||
The item can be a Request, Notification, Response or a list.
|
||||
|
||||
A JSON RPC error response is returned as an RPCError inside a
|
||||
Response object.
|
||||
|
||||
If a Batch is returned, request_id is an iterable of request
|
||||
ids, one per batch member.
|
||||
|
||||
If the message violates the protocol in some way a
|
||||
ProtocolError is returned, except if the message was
|
||||
determined to be a response, in which case the ProtocolError
|
||||
is placed inside a Response object. This is so that client
|
||||
code can mark a request as having been responded to even if
|
||||
the response was bad.
|
||||
|
||||
raises: ProtocolError
|
||||
'''
|
||||
payload = cls._message_to_payload(message)
|
||||
if isinstance(payload, dict):
|
||||
if 'method' in payload:
|
||||
return cls._process_request(payload)
|
||||
else:
|
||||
return cls._process_response(payload)
|
||||
elif isinstance(payload, list) and cls.allow_batches:
|
||||
if not payload:
|
||||
raise cls._error(JSONRPC.INVALID_REQUEST, 'batch is empty',
|
||||
True, None)
|
||||
return payload, None
|
||||
raise cls._error(cls.INVALID_REQUEST,
|
||||
'request object must be a dictionary', True, None)
|
||||
|
||||
# Message formation
|
||||
@classmethod
|
||||
def request_message(cls, item, request_id):
|
||||
'''Convert an RPCRequest item to a message.'''
|
||||
assert isinstance(item, Request)
|
||||
return cls.encode_payload(cls.request_payload(item, request_id))
|
||||
|
||||
@classmethod
|
||||
def notification_message(cls, item):
|
||||
'''Convert an RPCRequest item to a message.'''
|
||||
assert isinstance(item, Notification)
|
||||
return cls.encode_payload(cls.request_payload(item, None))
|
||||
|
||||
@classmethod
|
||||
def response_message(cls, result, request_id):
|
||||
'''Convert a response result (or RPCError) to a message.'''
|
||||
if isinstance(result, CodeMessageError):
|
||||
payload = cls.error_payload(result, request_id)
|
||||
else:
|
||||
payload = cls.response_payload(result, request_id)
|
||||
return cls.encode_payload(payload)
|
||||
|
||||
@classmethod
|
||||
def batch_message(cls, batch, request_ids):
|
||||
'''Convert a request Batch to a message.'''
|
||||
assert isinstance(batch, Batch)
|
||||
if not cls.allow_batches:
|
||||
raise ProtocolError.invalid_request(
|
||||
'protocol does not permit batches')
|
||||
id_iter = iter(request_ids)
|
||||
rm = cls.request_message
|
||||
nm = cls.notification_message
|
||||
parts = (rm(request, next(id_iter)) if isinstance(request, Request)
|
||||
else nm(request) for request in batch)
|
||||
return cls.batch_message_from_parts(parts)
|
||||
|
||||
@classmethod
|
||||
def batch_message_from_parts(cls, messages):
|
||||
'''Convert messages, one per batch item, into a batch message. At
|
||||
least one message must be passed.
|
||||
'''
|
||||
# Comma-separate the messages and wrap the lot in square brackets
|
||||
middle = b', '.join(messages)
|
||||
if not middle:
|
||||
raise ProtocolError.empty_batch()
|
||||
return b''.join([b'[', middle, b']'])
|
||||
|
||||
@classmethod
|
||||
def encode_payload(cls, payload):
|
||||
'''Encode a Python object as JSON and convert it to bytes.'''
|
||||
try:
|
||||
return json.dumps(payload).encode()
|
||||
except TypeError:
|
||||
msg = f'JSON payload encoding error: {payload}'
|
||||
raise ProtocolError(cls.INTERNAL_ERROR, msg) from None
|
||||
|
||||
|
||||
class JSONRPCv1(JSONRPC):
|
||||
'''JSON RPC version 1.0.'''
|
||||
|
||||
allow_batches = False
|
||||
|
||||
@classmethod
|
||||
def _message_id(cls, message, require_id):
|
||||
# JSONv1 requires an ID always, but without constraint on its type
|
||||
# No need to test for a dictionary here as we don't handle batches.
|
||||
if 'id' not in message:
|
||||
raise ProtocolError.invalid_request('request has no "id"')
|
||||
return message['id']
|
||||
|
||||
@classmethod
|
||||
def _request_args(cls, request):
|
||||
args = request.get('params')
|
||||
if not isinstance(args, list):
|
||||
raise ProtocolError.invalid_args(
|
||||
f'invalid request arguments: {args}')
|
||||
return args
|
||||
|
||||
@classmethod
|
||||
def _best_effort_error(cls, error):
|
||||
# Do our best to interpret the error
|
||||
code = cls.ERROR_CODE_UNAVAILABLE
|
||||
message = 'no error message provided'
|
||||
if isinstance(error, str):
|
||||
message = error
|
||||
elif isinstance(error, int):
|
||||
code = error
|
||||
elif isinstance(error, dict):
|
||||
if isinstance(error.get('message'), str):
|
||||
message = error['message']
|
||||
if isinstance(error.get('code'), int):
|
||||
code = error['code']
|
||||
|
||||
return RPCError(code, message)
|
||||
|
||||
@classmethod
|
||||
def response_value(cls, payload):
|
||||
if 'result' not in payload or 'error' not in payload:
|
||||
raise ProtocolError.invalid_request(
|
||||
'response must contain both "result" and "error"')
|
||||
|
||||
result = payload['result']
|
||||
error = payload['error']
|
||||
if error is None:
|
||||
return result # It seems None can be a valid result
|
||||
if result is not None:
|
||||
raise ProtocolError.invalid_request(
|
||||
'response has a "result" and an "error"')
|
||||
|
||||
return cls._best_effort_error(error)
|
||||
|
||||
@classmethod
|
||||
def request_payload(cls, request, request_id):
|
||||
'''JSON v1 request (or notification) payload.'''
|
||||
if isinstance(request.args, dict):
|
||||
raise ProtocolError.invalid_args(
|
||||
'JSONRPCv1 does not support named arguments')
|
||||
return {
|
||||
'method': request.method,
|
||||
'params': request.args,
|
||||
'id': request_id
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def response_payload(cls, result, request_id):
|
||||
'''JSON v1 response payload.'''
|
||||
return {
|
||||
'result': result,
|
||||
'error': None,
|
||||
'id': request_id
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def error_payload(cls, error, request_id):
|
||||
return {
|
||||
'result': None,
|
||||
'error': {'code': error.code, 'message': error.message},
|
||||
'id': request_id
|
||||
}
|
||||
|
||||
|
||||
class JSONRPCv2(JSONRPC):
|
||||
'''JSON RPC version 2.0.'''
|
||||
|
||||
@classmethod
|
||||
def _message_id(cls, message, require_id):
|
||||
if not isinstance(message, dict):
|
||||
raise ProtocolError.invalid_request(
|
||||
'request object must be a dictionary')
|
||||
if 'id' in message:
|
||||
request_id = message['id']
|
||||
if not isinstance(request_id, (Number, str, type(None))):
|
||||
raise ProtocolError.invalid_request(
|
||||
f'invalid "id": {request_id}')
|
||||
return request_id
|
||||
else:
|
||||
if require_id:
|
||||
raise ProtocolError.invalid_request('request has no "id"')
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _validate_message(cls, message):
|
||||
if message.get('jsonrpc') != '2.0':
|
||||
raise ProtocolError.invalid_request('"jsonrpc" is not "2.0"')
|
||||
|
||||
@classmethod
|
||||
def _request_args(cls, request):
|
||||
args = request.get('params', [])
|
||||
if not isinstance(args, (dict, list)):
|
||||
raise ProtocolError.invalid_args(
|
||||
f'invalid request arguments: {args}')
|
||||
return args
|
||||
|
||||
@classmethod
|
||||
def response_value(cls, payload):
|
||||
if 'result' in payload:
|
||||
if 'error' in payload:
|
||||
raise ProtocolError.invalid_request(
|
||||
'response contains both "result" and "error"')
|
||||
return payload['result']
|
||||
|
||||
if 'error' not in payload:
|
||||
raise ProtocolError.invalid_request(
|
||||
'response contains neither "result" nor "error"')
|
||||
|
||||
# Return an RPCError object
|
||||
error = payload['error']
|
||||
if isinstance(error, dict):
|
||||
code = error.get('code')
|
||||
message = error.get('message')
|
||||
if isinstance(code, int) and isinstance(message, str):
|
||||
return RPCError(code, message)
|
||||
|
||||
raise ProtocolError.invalid_request(
|
||||
f'ill-formed response error object: {error}')
|
||||
|
||||
@classmethod
|
||||
def request_payload(cls, request, request_id):
|
||||
'''JSON v2 request (or notification) payload.'''
|
||||
payload = {
|
||||
'jsonrpc': '2.0',
|
||||
'method': request.method,
|
||||
}
|
||||
# A notification?
|
||||
if request_id is not None:
|
||||
payload['id'] = request_id
|
||||
# Preserve empty dicts as missing params is read as an array
|
||||
if request.args or request.args == {}:
|
||||
payload['params'] = request.args
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
def response_payload(cls, result, request_id):
|
||||
'''JSON v2 response payload.'''
|
||||
return {
|
||||
'jsonrpc': '2.0',
|
||||
'result': result,
|
||||
'id': request_id
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def error_payload(cls, error, request_id):
|
||||
return {
|
||||
'jsonrpc': '2.0',
|
||||
'error': {'code': error.code, 'message': error.message},
|
||||
'id': request_id
|
||||
}
|
||||
|
||||
|
||||
class JSONRPCLoose(JSONRPC):
|
||||
'''A relaxed versin of JSON RPC.'''
|
||||
|
||||
# Don't be so loose we accept any old message ID
|
||||
_message_id = JSONRPCv2._message_id
|
||||
_validate_message = JSONRPC._validate_message
|
||||
_request_args = JSONRPCv2._request_args
|
||||
# Outoing messages are JSONRPCv2 so we give the other side the
|
||||
# best chance to assume / detect JSONRPCv2 as default protocol.
|
||||
error_payload = JSONRPCv2.error_payload
|
||||
request_payload = JSONRPCv2.request_payload
|
||||
response_payload = JSONRPCv2.response_payload
|
||||
|
||||
@classmethod
|
||||
def response_value(cls, payload):
|
||||
# Return result, unless it is None and there is an error
|
||||
if payload.get('error') is not None:
|
||||
if payload.get('result') is not None:
|
||||
raise ProtocolError.invalid_request(
|
||||
'response contains both "result" and "error"')
|
||||
return JSONRPCv1._best_effort_error(payload['error'])
|
||||
|
||||
if 'result' not in payload:
|
||||
raise ProtocolError.invalid_request(
|
||||
'response contains neither "result" nor "error"')
|
||||
|
||||
# Can be None
|
||||
return payload['result']
|
||||
|
||||
|
||||
class JSONRPCAutoDetect(JSONRPCv2):
|
||||
|
||||
@classmethod
|
||||
def message_to_item(cls, message):
|
||||
return cls.detect_protocol(message), None
|
||||
|
||||
@classmethod
|
||||
def detect_protocol(cls, message):
|
||||
'''Attempt to detect the protocol from the message.'''
|
||||
main = cls._message_to_payload(message)
|
||||
|
||||
def protocol_for_payload(payload):
|
||||
if not isinstance(payload, dict):
|
||||
return JSONRPCLoose # Will error
|
||||
# Obey an explicit "jsonrpc"
|
||||
version = payload.get('jsonrpc')
|
||||
if version == '2.0':
|
||||
return JSONRPCv2
|
||||
if version == '1.0':
|
||||
return JSONRPCv1
|
||||
|
||||
# Now to decide between JSONRPCLoose and JSONRPCv1 if possible
|
||||
if 'result' in payload and 'error' in payload:
|
||||
return JSONRPCv1
|
||||
return JSONRPCLoose
|
||||
|
||||
if isinstance(main, list):
|
||||
parts = set(protocol_for_payload(payload) for payload in main)
|
||||
# If all same protocol, return it
|
||||
if len(parts) == 1:
|
||||
return parts.pop()
|
||||
# If strict protocol detected, return it, preferring JSONRPCv2.
|
||||
# This means a batch of JSONRPCv1 will fail
|
||||
for protocol in (JSONRPCv2, JSONRPCv1):
|
||||
if protocol in parts:
|
||||
return protocol
|
||||
# Will error if no parts
|
||||
return JSONRPCLoose
|
||||
|
||||
return protocol_for_payload(main)
|
||||
|
||||
|
||||
class JSONRPCConnection(object):
|
||||
'''Maintains state of a JSON RPC connection, in particular
|
||||
encapsulating the handling of request IDs.
|
||||
|
||||
protocol - the JSON RPC protocol to follow
|
||||
max_response_size - responses over this size send an error response
|
||||
instead.
|
||||
'''
|
||||
|
||||
_id_counter = itertools.count()
|
||||
|
||||
def __init__(self, protocol):
|
||||
self._protocol = protocol
|
||||
# Sent Requests and Batches that have not received a response.
|
||||
# The key is its request ID; for a batch it is sorted tuple
|
||||
# of request IDs
|
||||
self._requests = {}
|
||||
# A public attribute intended to be settable dynamically
|
||||
self.max_response_size = 0
|
||||
|
||||
def _oversized_response_message(self, request_id):
|
||||
text = f'response too large (over {self.max_response_size:,d} bytes'
|
||||
error = RPCError.invalid_request(text)
|
||||
return self._protocol.response_message(error, request_id)
|
||||
|
||||
def _receive_response(self, result, request_id):
|
||||
if request_id not in self._requests:
|
||||
if request_id is None and isinstance(result, RPCError):
|
||||
message = f'diagnostic error received: {result}'
|
||||
else:
|
||||
message = f'response to unsent request (ID: {request_id})'
|
||||
raise ProtocolError.invalid_request(message) from None
|
||||
request, event = self._requests.pop(request_id)
|
||||
event.result = result
|
||||
event.set()
|
||||
return []
|
||||
|
||||
def _receive_request_batch(self, payloads):
|
||||
def item_send_result(request_id, result):
|
||||
nonlocal size
|
||||
part = protocol.response_message(result, request_id)
|
||||
size += len(part) + 2
|
||||
if size > self.max_response_size > 0:
|
||||
part = self._oversized_response_message(request_id)
|
||||
parts.append(part)
|
||||
if len(parts) == count:
|
||||
return protocol.batch_message_from_parts(parts)
|
||||
return None
|
||||
|
||||
parts = []
|
||||
items = []
|
||||
size = 0
|
||||
count = 0
|
||||
protocol = self._protocol
|
||||
for payload in payloads:
|
||||
try:
|
||||
item, request_id = protocol._process_request(payload)
|
||||
items.append(item)
|
||||
if isinstance(item, Request):
|
||||
count += 1
|
||||
item.send_result = partial(item_send_result, request_id)
|
||||
except ProtocolError as error:
|
||||
count += 1
|
||||
parts.append(error.error_message)
|
||||
|
||||
if not items and parts:
|
||||
error = ProtocolError(0, "")
|
||||
error.error_message = protocol.batch_message_from_parts(parts)
|
||||
raise error
|
||||
return items
|
||||
|
||||
def _receive_response_batch(self, payloads):
|
||||
request_ids = []
|
||||
results = []
|
||||
for payload in payloads:
|
||||
# Let ProtocolError exceptions through
|
||||
item, request_id = self._protocol._process_response(payload)
|
||||
request_ids.append(request_id)
|
||||
results.append(item.result)
|
||||
|
||||
ordered = sorted(zip(request_ids, results), key=lambda t: t[0])
|
||||
ordered_ids, ordered_results = zip(*ordered)
|
||||
if ordered_ids not in self._requests:
|
||||
raise ProtocolError.invalid_request('response to unsent batch')
|
||||
request_batch, event = self._requests.pop(ordered_ids)
|
||||
event.result = ordered_results
|
||||
event.set()
|
||||
return []
|
||||
|
||||
def _send_result(self, request_id, result):
|
||||
message = self._protocol.response_message(result, request_id)
|
||||
if len(message) > self.max_response_size > 0:
|
||||
message = self._oversized_response_message(request_id)
|
||||
return message
|
||||
|
||||
def _event(self, request, request_id):
|
||||
event = Event()
|
||||
self._requests[request_id] = (request, event)
|
||||
return event
|
||||
|
||||
#
|
||||
# External API
|
||||
#
|
||||
def send_request(self, request):
|
||||
'''Send a Request. Return a (message, event) pair.
|
||||
|
||||
The message is an unframed message to send over the network.
|
||||
Wait on the event for the response; which will be in the
|
||||
"result" attribute.
|
||||
|
||||
Raises: ProtocolError if the request violates the protocol
|
||||
in some way..
|
||||
'''
|
||||
request_id = next(self._id_counter)
|
||||
message = self._protocol.request_message(request, request_id)
|
||||
return message, self._event(request, request_id)
|
||||
|
||||
def send_notification(self, notification):
|
||||
return self._protocol.notification_message(notification)
|
||||
|
||||
def send_batch(self, batch):
|
||||
ids = tuple(next(self._id_counter)
|
||||
for request in batch if isinstance(request, Request))
|
||||
message = self._protocol.batch_message(batch, ids)
|
||||
event = self._event(batch, ids) if ids else None
|
||||
return message, event
|
||||
|
||||
def receive_message(self, message):
|
||||
'''Call with an unframed message received from the network.
|
||||
|
||||
Raises: ProtocolError if the message violates the protocol in
|
||||
some way. However, if it happened in a response that can be
|
||||
paired with a request, the ProtocolError is instead set in the
|
||||
result attribute of the send_request() that caused the error.
|
||||
'''
|
||||
try:
|
||||
item, request_id = self._protocol.message_to_item(message)
|
||||
except ProtocolError as e:
|
||||
if e.response_msg_id is not id:
|
||||
return self._receive_response(e, e.response_msg_id)
|
||||
raise
|
||||
|
||||
if isinstance(item, Request):
|
||||
item.send_result = partial(self._send_result, request_id)
|
||||
return [item]
|
||||
if isinstance(item, Notification):
|
||||
return [item]
|
||||
if isinstance(item, Response):
|
||||
return self._receive_response(item.result, request_id)
|
||||
if isinstance(item, list):
|
||||
if all(isinstance(payload, dict)
|
||||
and ('result' in payload or 'error' in payload)
|
||||
for payload in item):
|
||||
return self._receive_response_batch(item)
|
||||
else:
|
||||
return self._receive_request_batch(item)
|
||||
else:
|
||||
# Protocol auto-detection hack
|
||||
assert issubclass(item, JSONRPC)
|
||||
self._protocol = item
|
||||
return self.receive_message(message)
|
||||
|
||||
def cancel_pending_requests(self):
|
||||
'''Cancel all pending requests.'''
|
||||
exception = CancelledError()
|
||||
for request, event in self._requests.values():
|
||||
event.result = exception
|
||||
event.set()
|
||||
self._requests.clear()
|
||||
|
||||
def pending_requests(self):
|
||||
'''All sent requests that have not received a response.'''
|
||||
return [request for request, event in self._requests.values()]
|
||||
|
||||
|
||||
def handler_invocation(handler, request):
|
||||
method, args = request.method, request.args
|
||||
if handler is None:
|
||||
raise RPCError(JSONRPC.METHOD_NOT_FOUND,
|
||||
f'unknown method "{method}"')
|
||||
|
||||
# We must test for too few and too many arguments. How
|
||||
# depends on whether the arguments were passed as a list or as
|
||||
# a dictionary.
|
||||
info = signature_info(handler)
|
||||
if isinstance(args, (tuple, list)):
|
||||
if len(args) < info.min_args:
|
||||
s = '' if len(args) == 1 else 's'
|
||||
raise RPCError.invalid_args(
|
||||
f'{len(args)} argument{s} passed to method '
|
||||
f'"{method}" but it requires {info.min_args}')
|
||||
if info.max_args is not None and len(args) > info.max_args:
|
||||
s = '' if len(args) == 1 else 's'
|
||||
raise RPCError.invalid_args(
|
||||
f'{len(args)} argument{s} passed to method '
|
||||
f'{method} taking at most {info.max_args}')
|
||||
return partial(handler, *args)
|
||||
|
||||
# Arguments passed by name
|
||||
if info.other_names is None:
|
||||
raise RPCError.invalid_args(f'method "{method}" cannot '
|
||||
f'be called with named arguments')
|
||||
|
||||
missing = set(info.required_names).difference(args)
|
||||
if missing:
|
||||
s = '' if len(missing) == 1 else 's'
|
||||
missing = ', '.join(sorted(f'"{name}"' for name in missing))
|
||||
raise RPCError.invalid_args(f'method "{method}" requires '
|
||||
f'parameter{s} {missing}')
|
||||
|
||||
if info.other_names is not any:
|
||||
excess = set(args).difference(info.required_names)
|
||||
excess = excess.difference(info.other_names)
|
||||
if excess:
|
||||
s = '' if len(excess) == 1 else 's'
|
||||
excess = ', '.join(sorted(f'"{name}"' for name in excess))
|
||||
raise RPCError.invalid_args(f'method "{method}" does not '
|
||||
f'take parameter{s} {excess}')
|
||||
return partial(handler, **args)
|
549
torba/rpc/session.py
Normal file
549
torba/rpc/session.py
Normal file
|
@ -0,0 +1,549 @@
|
|||
# Copyright (c) 2018, Neil Booth
|
||||
#
|
||||
# All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
|
||||
__all__ = ('Connector', 'RPCSession', 'MessageSession', 'Server',
|
||||
'BatchError')
|
||||
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from contextlib import suppress
|
||||
|
||||
from . import *
|
||||
from .util import Concurrency
|
||||
|
||||
|
||||
class Connector(object):
|
||||
|
||||
def __init__(self, session_factory, host=None, port=None, proxy=None,
|
||||
**kwargs):
|
||||
self.session_factory = session_factory
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.proxy = proxy
|
||||
self.loop = kwargs.get('loop', asyncio.get_event_loop())
|
||||
self.kwargs = kwargs
|
||||
|
||||
async def create_connection(self):
|
||||
'''Initiate a connection.'''
|
||||
connector = self.proxy or self.loop
|
||||
return await connector.create_connection(
|
||||
self.session_factory, self.host, self.port, **self.kwargs)
|
||||
|
||||
async def __aenter__(self):
|
||||
transport, self.protocol = await self.create_connection()
|
||||
# By default, do not limit outgoing connections
|
||||
self.protocol.bw_limit = 0
|
||||
return self.protocol
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
await self.protocol.close()
|
||||
|
||||
|
||||
class SessionBase(asyncio.Protocol):
|
||||
'''Base class of networking sessions.
|
||||
|
||||
There is no client / server distinction other than who initiated
|
||||
the connection.
|
||||
|
||||
To initiate a connection to a remote server pass host, port and
|
||||
proxy to the constructor, and then call create_connection(). Each
|
||||
successful call should have a corresponding call to close().
|
||||
|
||||
Alternatively if used in a with statement, the connection is made
|
||||
on entry to the block, and closed on exit from the block.
|
||||
'''
|
||||
|
||||
max_errors = 10
|
||||
|
||||
def __init__(self, *, framer=None, loop=None):
|
||||
self.framer = framer or self.default_framer()
|
||||
self.loop = loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.transport = None
|
||||
# Set when a connection is made
|
||||
self._address = None
|
||||
self._proxy_address = None
|
||||
# For logger.debug messsages
|
||||
self.verbosity = 0
|
||||
# Cleared when the send socket is full
|
||||
self._can_send = Event()
|
||||
self._can_send.set()
|
||||
self._pm_task = None
|
||||
self._task_group = TaskGroup()
|
||||
# Force-close a connection if a send doesn't succeed in this time
|
||||
self.max_send_delay = 60
|
||||
# Statistics. The RPC object also keeps its own statistics.
|
||||
self.start_time = time.time()
|
||||
self.errors = 0
|
||||
self.send_count = 0
|
||||
self.send_size = 0
|
||||
self.last_send = self.start_time
|
||||
self.recv_count = 0
|
||||
self.recv_size = 0
|
||||
self.last_recv = self.start_time
|
||||
# Bandwidth usage per hour before throttling starts
|
||||
self.bw_limit = 2000000
|
||||
self.bw_time = self.start_time
|
||||
self.bw_charge = 0
|
||||
# Concurrency control
|
||||
self.max_concurrent = 6
|
||||
self._concurrency = Concurrency(self.max_concurrent)
|
||||
|
||||
async def _update_concurrency(self):
|
||||
# A non-positive value means not to limit concurrency
|
||||
if self.bw_limit <= 0:
|
||||
return
|
||||
now = time.time()
|
||||
# Reduce the recorded usage in proportion to the elapsed time
|
||||
refund = (now - self.bw_time) * (self.bw_limit / 3600)
|
||||
self.bw_charge = max(0, self.bw_charge - int(refund))
|
||||
self.bw_time = now
|
||||
# Reduce concurrency allocation by 1 for each whole bw_limit used
|
||||
throttle = int(self.bw_charge / self.bw_limit)
|
||||
target = max(1, self.max_concurrent - throttle)
|
||||
current = self._concurrency.max_concurrent
|
||||
if target != current:
|
||||
self.logger.info(f'changing task concurrency from {current} '
|
||||
f'to {target}')
|
||||
await self._concurrency.set_max_concurrent(target)
|
||||
|
||||
def _using_bandwidth(self, size):
|
||||
'''Called when sending or receiving size bytes.'''
|
||||
self.bw_charge += size
|
||||
|
||||
async def _process_messages(self):
|
||||
'''Process incoming messages asynchronously and consume the
|
||||
results.
|
||||
'''
|
||||
async def collect_tasks():
|
||||
next_done = task_group.next_done
|
||||
while True:
|
||||
await next_done()
|
||||
|
||||
task_group = self._task_group
|
||||
async with task_group:
|
||||
await self.spawn(self._receive_messages)
|
||||
await self.spawn(collect_tasks)
|
||||
|
||||
async def _limited_wait(self, secs):
|
||||
# Wait at most secs seconds to send, otherwise abort the connection
|
||||
try:
|
||||
async with timeout_after(secs):
|
||||
await self._can_send.wait()
|
||||
except TaskTimeout:
|
||||
self.abort()
|
||||
raise
|
||||
|
||||
async def _send_message(self, message):
|
||||
if not self._can_send.is_set():
|
||||
await self._limited_wait(self.max_send_delay)
|
||||
if not self.is_closing():
|
||||
framed_message = self.framer.frame(message)
|
||||
self.send_size += len(framed_message)
|
||||
self._using_bandwidth(len(framed_message))
|
||||
self.send_count += 1
|
||||
self.last_send = time.time()
|
||||
if self.verbosity >= 4:
|
||||
self.logger.debug(f'Sending framed message {framed_message}')
|
||||
self.transport.write(framed_message)
|
||||
|
||||
def _bump_errors(self):
|
||||
self.errors += 1
|
||||
if self.errors >= self.max_errors:
|
||||
# Don't await self.close() because that is self-cancelling
|
||||
self._close()
|
||||
|
||||
def _close(self):
|
||||
if self.transport:
|
||||
self.transport.close()
|
||||
|
||||
# asyncio framework
|
||||
def data_received(self, framed_message):
|
||||
'''Called by asyncio when a message comes in.'''
|
||||
if self.verbosity >= 4:
|
||||
self.logger.debug(f'Received framed message {framed_message}')
|
||||
self.recv_size += len(framed_message)
|
||||
self._using_bandwidth(len(framed_message))
|
||||
self.framer.received_bytes(framed_message)
|
||||
|
||||
def pause_writing(self):
|
||||
'''Transport calls when the send buffer is full.'''
|
||||
if not self.is_closing():
|
||||
self._can_send.clear()
|
||||
self.transport.pause_reading()
|
||||
|
||||
def resume_writing(self):
|
||||
'''Transport calls when the send buffer has room.'''
|
||||
if not self._can_send.is_set():
|
||||
self._can_send.set()
|
||||
self.transport.resume_reading()
|
||||
|
||||
def connection_made(self, transport):
|
||||
'''Called by asyncio when a connection is established.
|
||||
|
||||
Derived classes overriding this method must call this first.'''
|
||||
self.transport = transport
|
||||
# This would throw if called on a closed SSL transport. Fixed
|
||||
# in asyncio in Python 3.6.1 and 3.5.4
|
||||
peer_address = transport.get_extra_info('peername')
|
||||
# If the Socks proxy was used then _address is already set to
|
||||
# the remote address
|
||||
if self._address:
|
||||
self._proxy_address = peer_address
|
||||
else:
|
||||
self._address = peer_address
|
||||
self._pm_task = spawn_sync(self._process_messages(), loop=self.loop)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
'''Called by asyncio when the connection closes.
|
||||
|
||||
Tear down things done in connection_made.'''
|
||||
self._address = None
|
||||
self.transport = None
|
||||
self._pm_task.cancel()
|
||||
# Release waiting tasks
|
||||
self._can_send.set()
|
||||
|
||||
# External API
|
||||
def default_framer(self):
|
||||
'''Return a default framer.'''
|
||||
raise NotImplementedError
|
||||
|
||||
def peer_address(self):
|
||||
'''Returns the peer's address (Python networking address), or None if
|
||||
no connection or an error.
|
||||
|
||||
This is the result of socket.getpeername() when the connection
|
||||
was made.
|
||||
'''
|
||||
return self._address
|
||||
|
||||
def peer_address_str(self):
|
||||
'''Returns the peer's IP address and port as a human-readable
|
||||
string.'''
|
||||
if not self._address:
|
||||
return 'unknown'
|
||||
ip_addr_str, port = self._address[:2]
|
||||
if ':' in ip_addr_str:
|
||||
return f'[{ip_addr_str}]:{port}'
|
||||
else:
|
||||
return f'{ip_addr_str}:{port}'
|
||||
|
||||
async def spawn(self, coro, *args):
|
||||
'''If the session is connected, spawn a task that is cancelled
|
||||
on disconnect, and return it. Otherwise return None.'''
|
||||
group = self._task_group
|
||||
if not group.closed():
|
||||
return await group.spawn(coro, *args)
|
||||
else:
|
||||
return None
|
||||
|
||||
def is_closing(self):
|
||||
'''Return True if the connection is closing.'''
|
||||
return not self.transport or self.transport.is_closing()
|
||||
|
||||
def abort(self):
|
||||
'''Forcefully close the connection.'''
|
||||
if self.transport:
|
||||
self.transport.abort()
|
||||
|
||||
async def close(self, *, force_after=30):
|
||||
'''Close the connection and return when closed.'''
|
||||
self._close()
|
||||
if self._pm_task:
|
||||
with suppress(CancelledError):
|
||||
async with ignore_after(force_after):
|
||||
await self._pm_task
|
||||
self.abort()
|
||||
await self._pm_task
|
||||
|
||||
|
||||
class MessageSession(SessionBase):
|
||||
'''Session class for protocols where messages are not tied to responses,
|
||||
such as the Bitcoin protocol.
|
||||
|
||||
To use as a client (connection-opening) session, pass host, port
|
||||
and perhaps a proxy.
|
||||
'''
|
||||
async def _receive_messages(self):
|
||||
while not self.is_closing():
|
||||
try:
|
||||
message = await self.framer.receive_message()
|
||||
except BadMagicError as e:
|
||||
magic, expected = e.args
|
||||
self.logger.error(
|
||||
f'bad network magic: got {magic} expected {expected}, '
|
||||
f'disconnecting'
|
||||
)
|
||||
self._close()
|
||||
except OversizedPayloadError as e:
|
||||
command, payload_len = e.args
|
||||
self.logger.error(
|
||||
f'oversized payload of {payload_len:,d} bytes to command '
|
||||
f'{command}, disconnecting'
|
||||
)
|
||||
self._close()
|
||||
except BadChecksumError as e:
|
||||
payload_checksum, claimed_checksum = e.args
|
||||
self.logger.warning(
|
||||
f'checksum mismatch: actual {payload_checksum.hex()} '
|
||||
f'vs claimed {claimed_checksum.hex()}'
|
||||
)
|
||||
self._bump_errors()
|
||||
else:
|
||||
self.last_recv = time.time()
|
||||
self.recv_count += 1
|
||||
if self.recv_count % 10 == 0:
|
||||
await self._update_concurrency()
|
||||
await self.spawn(self._throttled_message(message))
|
||||
|
||||
async def _throttled_message(self, message):
|
||||
'''Process a single request, respecting the concurrency limit.'''
|
||||
async with self._concurrency.semaphore:
|
||||
try:
|
||||
await self.handle_message(message)
|
||||
except ProtocolError as e:
|
||||
self.logger.error(f'{e}')
|
||||
self._bump_errors()
|
||||
except CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
self.logger.exception(f'exception handling {message}')
|
||||
self._bump_errors()
|
||||
|
||||
# External API
|
||||
def default_framer(self):
|
||||
'''Return a bitcoin framer.'''
|
||||
return BitcoinFramer(bytes.fromhex('e3e1f3e8'), 128_000_000)
|
||||
|
||||
async def handle_message(self, message):
|
||||
'''message is a (command, payload) pair.'''
|
||||
pass
|
||||
|
||||
async def send_message(self, message):
|
||||
'''Send a message (command, payload) over the network.'''
|
||||
await self._send_message(message)
|
||||
|
||||
|
||||
class BatchError(Exception):
|
||||
|
||||
def __init__(self, request):
|
||||
self.request = request # BatchRequest object
|
||||
|
||||
|
||||
class BatchRequest(object):
|
||||
'''Used to build a batch request to send to the server. Stores
|
||||
the
|
||||
|
||||
Attributes batch and results are initially None.
|
||||
|
||||
Adding an invalid request or notification immediately raises a
|
||||
ProtocolError.
|
||||
|
||||
On exiting the with clause, it will:
|
||||
|
||||
1) create a Batch object for the requests in the order they were
|
||||
added. If the batch is empty this raises a ProtocolError.
|
||||
|
||||
2) set the "batch" attribute to be that batch
|
||||
|
||||
3) send the batch request and wait for a response
|
||||
|
||||
4) raise a ProtocolError if the protocol was violated by the
|
||||
server. Currently this only happens if it gave more than one
|
||||
response to any request
|
||||
|
||||
5) otherwise there is precisely one response to each Request. Set
|
||||
the "results" attribute to the tuple of results; the responses
|
||||
are ordered to match the Requests in the batch. Notifications
|
||||
do not get a response.
|
||||
|
||||
6) if raise_errors is True and any individual response was a JSON
|
||||
RPC error response, or violated the protocol in some way, a
|
||||
BatchError exception is raised. Otherwise the caller can be
|
||||
certain each request returned a standard result.
|
||||
'''
|
||||
|
||||
def __init__(self, session, raise_errors):
|
||||
self._session = session
|
||||
self._raise_errors = raise_errors
|
||||
self._requests = []
|
||||
self.batch = None
|
||||
self.results = None
|
||||
|
||||
def add_request(self, method, args=()):
|
||||
self._requests.append(Request(method, args))
|
||||
|
||||
def add_notification(self, method, args=()):
|
||||
self._requests.append(Notification(method, args))
|
||||
|
||||
def __len__(self):
|
||||
return len(self._requests)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
if exc_type is None:
|
||||
self.batch = Batch(self._requests)
|
||||
message, event = self._session.connection.send_batch(self.batch)
|
||||
await self._session._send_message(message)
|
||||
await event.wait()
|
||||
self.results = event.result
|
||||
if self._raise_errors:
|
||||
if any(isinstance(item, Exception) for item in event.result):
|
||||
raise BatchError(self)
|
||||
|
||||
|
||||
class RPCSession(SessionBase):
|
||||
'''Base class for protocols where a message can lead to a response,
|
||||
for example JSON RPC.'''
|
||||
|
||||
def __init__(self, *, framer=None, loop=None, connection=None):
|
||||
super().__init__(framer=framer, loop=loop)
|
||||
self.connection = connection or self.default_connection()
|
||||
|
||||
async def _receive_messages(self):
|
||||
while not self.is_closing():
|
||||
try:
|
||||
message = await self.framer.receive_message()
|
||||
except MemoryError as e:
|
||||
self.logger.warning(f'{e!r}')
|
||||
continue
|
||||
|
||||
self.last_recv = time.time()
|
||||
self.recv_count += 1
|
||||
if self.recv_count % 10 == 0:
|
||||
await self._update_concurrency()
|
||||
|
||||
try:
|
||||
requests = self.connection.receive_message(message)
|
||||
except ProtocolError as e:
|
||||
self.logger.debug(f'{e}')
|
||||
if e.error_message:
|
||||
await self._send_message(e.error_message)
|
||||
if e.code == JSONRPC.PARSE_ERROR:
|
||||
self.max_errors = 0
|
||||
self._bump_errors()
|
||||
else:
|
||||
for request in requests:
|
||||
await self.spawn(self._throttled_request(request))
|
||||
|
||||
async def _throttled_request(self, request):
|
||||
'''Process a single request, respecting the concurrency limit.'''
|
||||
async with self._concurrency.semaphore:
|
||||
try:
|
||||
result = await self.handle_request(request)
|
||||
except (ProtocolError, RPCError) as e:
|
||||
result = e
|
||||
except CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
self.logger.exception(f'exception handling {request}')
|
||||
result = RPCError(JSONRPC.INTERNAL_ERROR,
|
||||
'internal server error')
|
||||
if isinstance(request, Request):
|
||||
message = request.send_result(result)
|
||||
if message:
|
||||
await self._send_message(message)
|
||||
if isinstance(result, Exception):
|
||||
self._bump_errors()
|
||||
|
||||
def connection_lost(self, exc):
|
||||
# Cancel pending requests and message processing
|
||||
self.connection.cancel_pending_requests()
|
||||
super().connection_lost(exc)
|
||||
|
||||
# External API
|
||||
def default_connection(self):
|
||||
'''Return a default connection if the user provides none.'''
|
||||
return JSONRPCConnection(JSONRPCv2)
|
||||
|
||||
def default_framer(self):
|
||||
'''Return a default framer.'''
|
||||
return NewlineFramer()
|
||||
|
||||
async def handle_request(self, request):
|
||||
pass
|
||||
|
||||
async def send_request(self, method, args=()):
|
||||
'''Send an RPC request over the network.'''
|
||||
message, event = self.connection.send_request(Request(method, args))
|
||||
await self._send_message(message)
|
||||
await event.wait()
|
||||
result = event.result
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
return result
|
||||
|
||||
async def send_notification(self, method, args=()):
|
||||
'''Send an RPC notification over the network.'''
|
||||
message = self.connection.send_notification(Notification(method, args))
|
||||
await self._send_message(message)
|
||||
|
||||
def send_batch(self, raise_errors=False):
|
||||
'''Return a BatchRequest. Intended to be used like so:
|
||||
|
||||
async with session.send_batch() as batch:
|
||||
batch.add_request("method1")
|
||||
batch.add_request("sum", (x, y))
|
||||
batch.add_notification("updated")
|
||||
|
||||
for result in batch.results:
|
||||
...
|
||||
|
||||
Note that in some circumstances exceptions can be raised; see
|
||||
BatchRequest doc string.
|
||||
'''
|
||||
return BatchRequest(self, raise_errors)
|
||||
|
||||
|
||||
class Server(object):
|
||||
'''A simple wrapper around an asyncio.Server object.'''
|
||||
|
||||
def __init__(self, session_factory, host=None, port=None, *,
|
||||
loop=None, **kwargs):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.loop = loop or asyncio.get_event_loop()
|
||||
self.server = None
|
||||
self._session_factory = session_factory
|
||||
self._kwargs = kwargs
|
||||
|
||||
async def listen(self):
|
||||
self.server = await self.loop.create_server(
|
||||
self._session_factory, self.host, self.port, **self._kwargs)
|
||||
|
||||
async def close(self):
|
||||
'''Close the listening socket. This does not close any ServerSession
|
||||
objects created to handle incoming connections.
|
||||
'''
|
||||
if self.server:
|
||||
self.server.close()
|
||||
await self.server.wait_closed()
|
||||
self.server = None
|
439
torba/rpc/socks.py
Normal file
439
torba/rpc/socks.py
Normal file
|
@ -0,0 +1,439 @@
|
|||
# Copyright (c) 2018, Neil Booth
|
||||
#
|
||||
# All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
'''SOCKS proxying.'''
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import collections
|
||||
import ipaddress
|
||||
import socket
|
||||
import struct
|
||||
from functools import partial
|
||||
|
||||
|
||||
__all__ = ('SOCKSUserAuth', 'SOCKS4', 'SOCKS4a', 'SOCKS5', 'SOCKSProxy',
|
||||
'SOCKSError', 'SOCKSProtocolError', 'SOCKSFailure')
|
||||
|
||||
|
||||
SOCKSUserAuth = collections.namedtuple("SOCKSUserAuth", "username password")
|
||||
|
||||
|
||||
class SOCKSError(Exception):
|
||||
'''Base class for SOCKS exceptions. Each raised exception will be
|
||||
an instance of a derived class.'''
|
||||
|
||||
|
||||
class SOCKSProtocolError(SOCKSError):
|
||||
'''Raised when the proxy does not follow the SOCKS protocol'''
|
||||
|
||||
|
||||
class SOCKSFailure(SOCKSError):
|
||||
'''Raised when the proxy refuses or fails to make a connection'''
|
||||
|
||||
|
||||
class NeedData(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SOCKSBase(object):
|
||||
|
||||
@classmethod
|
||||
def name(cls):
|
||||
return cls.__name__
|
||||
|
||||
def __init__(self):
|
||||
self._buffer = bytes()
|
||||
self._state = self._start
|
||||
|
||||
def _read(self, size):
|
||||
if len(self._buffer) < size:
|
||||
raise NeedData(size - len(self._buffer))
|
||||
result = self._buffer[:size]
|
||||
self._buffer = self._buffer[size:]
|
||||
return result
|
||||
|
||||
def receive_data(self, data):
|
||||
self._buffer += data
|
||||
|
||||
def next_message(self):
|
||||
return self._state()
|
||||
|
||||
|
||||
class SOCKS4(SOCKSBase):
|
||||
'''SOCKS4 protocol wrapper.'''
|
||||
|
||||
# See http://ftp.icm.edu.pl/packages/socks/socks4/SOCKS4.protocol
|
||||
REPLY_CODES = {
|
||||
90: 'request granted',
|
||||
91: 'request rejected or failed',
|
||||
92: ('request rejected because SOCKS server cannot connect '
|
||||
'to identd on the client'),
|
||||
93: ('request rejected because the client program and identd '
|
||||
'report different user-ids')
|
||||
}
|
||||
|
||||
def __init__(self, dst_host, dst_port, auth):
|
||||
super().__init__()
|
||||
self._dst_host = self._check_host(dst_host)
|
||||
self._dst_port = dst_port
|
||||
self._auth = auth
|
||||
|
||||
@classmethod
|
||||
def _check_host(cls, host):
|
||||
if not isinstance(host, ipaddress.IPv4Address):
|
||||
try:
|
||||
host = ipaddress.IPv4Address(host)
|
||||
except ValueError:
|
||||
raise SOCKSProtocolError(
|
||||
f'SOCKS4 requires an IPv4 address: {host}') from None
|
||||
return host
|
||||
|
||||
def _start(self):
|
||||
self._state = self._first_response
|
||||
|
||||
if isinstance(self._dst_host, ipaddress.IPv4Address):
|
||||
# SOCKS4
|
||||
dst_ip_packed = self._dst_host.packed
|
||||
host_bytes = b''
|
||||
else:
|
||||
# SOCKS4a
|
||||
dst_ip_packed = b'\0\0\0\1'
|
||||
host_bytes = self._dst_host.encode() + b'\0'
|
||||
|
||||
if isinstance(self._auth, SOCKSUserAuth):
|
||||
user_id = self._auth.username.encode()
|
||||
else:
|
||||
user_id = b''
|
||||
|
||||
# Send TCP/IP stream CONNECT request
|
||||
return b''.join([b'\4\1', struct.pack('>H', self._dst_port),
|
||||
dst_ip_packed, user_id, b'\0', host_bytes])
|
||||
|
||||
def _first_response(self):
|
||||
# Wait for 8-byte response
|
||||
data = self._read(8)
|
||||
if data[0] != 0:
|
||||
raise SOCKSProtocolError(f'invalid {self.name()} proxy '
|
||||
f'response: {data}')
|
||||
reply_code = data[1]
|
||||
if reply_code != 90:
|
||||
msg = self.REPLY_CODES.get(
|
||||
reply_code, f'unknown {self.name()} reply code {reply_code}')
|
||||
raise SOCKSFailure(f'{self.name()} proxy request failed: {msg}')
|
||||
|
||||
# Other fields ignored
|
||||
return None
|
||||
|
||||
|
||||
class SOCKS4a(SOCKS4):
|
||||
|
||||
@classmethod
|
||||
def _check_host(cls, host):
|
||||
if not isinstance(host, (str, ipaddress.IPv4Address)):
|
||||
raise SOCKSProtocolError(
|
||||
f'SOCKS4a requires an IPv4 address or host name: {host}')
|
||||
return host
|
||||
|
||||
|
||||
class SOCKS5(SOCKSBase):
|
||||
'''SOCKS protocol wrapper.'''
|
||||
|
||||
# See https://tools.ietf.org/html/rfc1928
|
||||
ERROR_CODES = {
|
||||
1: 'general SOCKS server failure',
|
||||
2: 'connection not allowed by ruleset',
|
||||
3: 'network unreachable',
|
||||
4: 'host unreachable',
|
||||
5: 'connection refused',
|
||||
6: 'TTL expired',
|
||||
7: 'command not supported',
|
||||
8: 'address type not supported',
|
||||
}
|
||||
|
||||
def __init__(self, dst_host, dst_port, auth):
|
||||
super().__init__()
|
||||
self._dst_bytes = self._destination_bytes(dst_host, dst_port)
|
||||
self._auth_bytes, self._auth_methods = self._authentication(auth)
|
||||
|
||||
def _destination_bytes(self, host, port):
|
||||
if isinstance(host, ipaddress.IPv4Address):
|
||||
addr_bytes = b'\1' + host.packed
|
||||
elif isinstance(host, ipaddress.IPv6Address):
|
||||
addr_bytes = b'\4' + host.packed
|
||||
elif isinstance(host, str):
|
||||
host = host.encode()
|
||||
if len(host) > 255:
|
||||
raise SOCKSProtocolError(f'hostname too long: '
|
||||
f'{len(host)} bytes')
|
||||
addr_bytes = b'\3' + bytes([len(host)]) + host
|
||||
else:
|
||||
raise SOCKSProtocolError(f'SOCKS5 requires an IPv4 address, IPv6 '
|
||||
f'address, or host name: {host}')
|
||||
return addr_bytes + struct.pack('>H', port)
|
||||
|
||||
def _authentication(self, auth):
|
||||
if isinstance(auth, SOCKSUserAuth):
|
||||
user_bytes = auth.username.encode()
|
||||
if not 0 < len(user_bytes) < 256:
|
||||
raise SOCKSProtocolError(f'username {auth.username} has '
|
||||
f'invalid length {len(user_bytes)}')
|
||||
pwd_bytes = auth.password.encode()
|
||||
if not 0 < len(pwd_bytes) < 256:
|
||||
raise SOCKSProtocolError(f'password has invalid length '
|
||||
f'{len(pwd_bytes)}')
|
||||
return b''.join([bytes([1, len(user_bytes)]), user_bytes,
|
||||
bytes([len(pwd_bytes)]), pwd_bytes]), [0, 2]
|
||||
return b'', [0]
|
||||
|
||||
def _start(self):
|
||||
self._state = self._first_response
|
||||
return (b'\5' + bytes([len(self._auth_methods)])
|
||||
+ bytes(m for m in self._auth_methods))
|
||||
|
||||
def _first_response(self):
|
||||
# Wait for 2-byte response
|
||||
data = self._read(2)
|
||||
if data[0] != 5:
|
||||
raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}')
|
||||
if data[1] not in self._auth_methods:
|
||||
raise SOCKSFailure('SOCKS5 proxy rejected authentication methods')
|
||||
|
||||
# Authenticate if user-password authentication
|
||||
if data[1] == 2:
|
||||
self._state = self._auth_response
|
||||
return self._auth_bytes
|
||||
return self._request_connection()
|
||||
|
||||
def _auth_response(self):
|
||||
data = self._read(2)
|
||||
if data[0] != 1:
|
||||
raise SOCKSProtocolError(f'invalid SOCKS5 proxy auth '
|
||||
f'response: {data}')
|
||||
if data[1] != 0:
|
||||
raise SOCKSFailure(f'SOCKS5 proxy auth failure code: '
|
||||
f'{data[1]}')
|
||||
|
||||
return self._request_connection()
|
||||
|
||||
def _request_connection(self):
|
||||
# Send connection request
|
||||
self._state = self._connect_response
|
||||
return b'\5\1\0' + self._dst_bytes
|
||||
|
||||
def _connect_response(self):
|
||||
data = self._read(5)
|
||||
if data[0] != 5 or data[2] != 0 or data[3] not in (1, 3, 4):
|
||||
raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}')
|
||||
if data[1] != 0:
|
||||
raise SOCKSFailure(self.ERROR_CODES.get(
|
||||
data[1], f'unknown SOCKS5 error code: {data[1]}'))
|
||||
|
||||
if data[3] == 1:
|
||||
addr_len = 3 # IPv4
|
||||
elif data[3] == 3:
|
||||
addr_len = data[4] # Hostname
|
||||
else:
|
||||
addr_len = 15 # IPv6
|
||||
|
||||
self._state = partial(self._connect_response_rest, addr_len)
|
||||
return self.next_message()
|
||||
|
||||
def _connect_response_rest(self, addr_len):
|
||||
self._read(addr_len + 2)
|
||||
return None
|
||||
|
||||
|
||||
class SOCKSProxy(object):
|
||||
|
||||
def __init__(self, address, protocol, auth):
|
||||
'''A SOCKS proxy at an address following a SOCKS protocol. auth is an
|
||||
authentication method to use when connecting, or None.
|
||||
|
||||
address is a (host, port) pair; for IPv6 it can instead be a
|
||||
(host, port, flowinfo, scopeid) 4-tuple.
|
||||
'''
|
||||
self.address = address
|
||||
self.protocol = protocol
|
||||
self.auth = auth
|
||||
# Set on each successful connection via the proxy to the
|
||||
# result of socket.getpeername()
|
||||
self.peername = None
|
||||
|
||||
def __str__(self):
|
||||
auth = 'username' if self.auth else 'none'
|
||||
return f'{self.protocol.name()} proxy at {self.address}, auth: {auth}'
|
||||
|
||||
async def _handshake(self, client, sock, loop):
|
||||
while True:
|
||||
count = 0
|
||||
try:
|
||||
message = client.next_message()
|
||||
except NeedData as e:
|
||||
count = e.args[0]
|
||||
else:
|
||||
if message is None:
|
||||
return
|
||||
await loop.sock_sendall(sock, message)
|
||||
|
||||
if count:
|
||||
data = await loop.sock_recv(sock, count)
|
||||
if not data:
|
||||
raise SOCKSProtocolError("EOF received")
|
||||
client.receive_data(data)
|
||||
|
||||
async def _connect_one(self, host, port):
|
||||
'''Connect to the proxy and perform a handshake requesting a
|
||||
connection to (host, port).
|
||||
|
||||
Return the open socket on success, or the exception on failure.
|
||||
'''
|
||||
client = self.protocol(host, port, self.auth)
|
||||
sock = socket.socket()
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
# A non-blocking socket is required by loop socket methods
|
||||
sock.setblocking(False)
|
||||
await loop.sock_connect(sock, self.address)
|
||||
await self._handshake(client, sock, loop)
|
||||
self.peername = sock.getpeername()
|
||||
return sock
|
||||
except Exception as e:
|
||||
# Don't close - see https://github.com/kyuupichan/aiorpcX/issues/8
|
||||
if sys.platform.startswith('linux'):
|
||||
sock.close()
|
||||
return e
|
||||
|
||||
async def _connect(self, addresses):
|
||||
'''Connect to the proxy and perform a handshake requesting a
|
||||
connection to each address in addresses.
|
||||
|
||||
Return an (open_socket, address) pair on success.
|
||||
'''
|
||||
assert len(addresses) > 0
|
||||
|
||||
exceptions = []
|
||||
for address in addresses:
|
||||
host, port = address[:2]
|
||||
sock = await self._connect_one(host, port)
|
||||
if isinstance(sock, socket.socket):
|
||||
return sock, address
|
||||
exceptions.append(sock)
|
||||
|
||||
strings = set(f'{exc!r}' for exc in exceptions)
|
||||
raise (exceptions[0] if len(strings) == 1 else
|
||||
OSError(f'multiple exceptions: {", ".join(strings)}'))
|
||||
|
||||
async def _detect_proxy(self):
|
||||
'''Return True if it appears we can connect to a SOCKS proxy,
|
||||
otherwise False.
|
||||
'''
|
||||
if self.protocol is SOCKS4a:
|
||||
host, port = 'www.apple.com', 80
|
||||
else:
|
||||
host, port = ipaddress.IPv4Address('8.8.8.8'), 53
|
||||
|
||||
sock = await self._connect_one(host, port)
|
||||
if isinstance(sock, socket.socket):
|
||||
sock.close()
|
||||
return True
|
||||
|
||||
# SOCKSFailure indicates something failed, but that we are
|
||||
# likely talking to a proxy
|
||||
return isinstance(sock, SOCKSFailure)
|
||||
|
||||
@classmethod
|
||||
async def auto_detect_address(cls, address, auth):
|
||||
'''Try to detect a SOCKS proxy at address using the authentication
|
||||
method (or None). SOCKS5, SOCKS4a and SOCKS are tried in
|
||||
order. If a SOCKS proxy is detected a SOCKSProxy object is
|
||||
returned.
|
||||
|
||||
Returning a SOCKSProxy does not mean it is functioning - for
|
||||
example, it may have no network connectivity.
|
||||
|
||||
If no proxy is detected return None.
|
||||
'''
|
||||
for protocol in (SOCKS5, SOCKS4a, SOCKS4):
|
||||
proxy = cls(address, protocol, auth)
|
||||
if await proxy._detect_proxy():
|
||||
return proxy
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def auto_detect_host(cls, host, ports, auth):
|
||||
'''Try to detect a SOCKS proxy on a host on one of the ports.
|
||||
|
||||
Calls auto_detect for the ports in order. Returns SOCKS are
|
||||
tried in order; a SOCKSProxy object for the first detected
|
||||
proxy is returned.
|
||||
|
||||
Returning a SOCKSProxy does not mean it is functioning - for
|
||||
example, it may have no network connectivity.
|
||||
|
||||
If no proxy is detected return None.
|
||||
'''
|
||||
for port in ports:
|
||||
address = (host, port)
|
||||
proxy = await cls.auto_detect_address(address, auth)
|
||||
if proxy:
|
||||
return proxy
|
||||
|
||||
return None
|
||||
|
||||
async def create_connection(self, protocol_factory, host, port, *,
|
||||
resolve=False, ssl=None,
|
||||
family=0, proto=0, flags=0):
|
||||
'''Set up a connection to (host, port) through the proxy.
|
||||
|
||||
If resolve is True then host is resolved locally with
|
||||
getaddrinfo using family, proto and flags, otherwise the proxy
|
||||
is asked to resolve host.
|
||||
|
||||
The function signature is similar to loop.create_connection()
|
||||
with the same result. The attribute _address is set on the
|
||||
protocol to the address of the successful remote connection.
|
||||
Additionally raises SOCKSError if something goes wrong with
|
||||
the proxy handshake.
|
||||
'''
|
||||
loop = asyncio.get_event_loop()
|
||||
if resolve:
|
||||
infos = await loop.getaddrinfo(host, port, family=family,
|
||||
type=socket.SOCK_STREAM,
|
||||
proto=proto, flags=flags)
|
||||
addresses = [info[4] for info in infos]
|
||||
else:
|
||||
addresses = [(host, port)]
|
||||
|
||||
sock, address = await self._connect(addresses)
|
||||
|
||||
def set_address():
|
||||
protocol = protocol_factory()
|
||||
protocol._address = address
|
||||
return protocol
|
||||
|
||||
return await loop.create_connection(
|
||||
set_address, sock=sock, ssl=ssl,
|
||||
server_hostname=host if ssl else None)
|
120
torba/rpc/util.py
Normal file
120
torba/rpc/util.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
# Copyright (c) 2018, Neil Booth
|
||||
#
|
||||
# All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
__all__ = ()
|
||||
|
||||
|
||||
import asyncio
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
import inspect
|
||||
|
||||
|
||||
def normalize_corofunc(corofunc, args):
|
||||
if asyncio.iscoroutine(corofunc):
|
||||
if args != ():
|
||||
raise ValueError('args cannot be passed with a coroutine')
|
||||
return corofunc
|
||||
return corofunc(*args)
|
||||
|
||||
|
||||
def is_async_call(func):
|
||||
'''inspect.iscoroutinefunction that looks through partials.'''
|
||||
while isinstance(func, partial):
|
||||
func = func.func
|
||||
return inspect.iscoroutinefunction(func)
|
||||
|
||||
|
||||
# other_params: None means cannot be called with keyword arguments only
|
||||
# any means any name is good
|
||||
SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args '
|
||||
'required_names other_names')
|
||||
|
||||
|
||||
def signature_info(func):
|
||||
params = inspect.signature(func).parameters
|
||||
min_args = max_args = 0
|
||||
required_names = []
|
||||
other_names = []
|
||||
no_names = False
|
||||
for p in params.values():
|
||||
if p.kind == p.POSITIONAL_OR_KEYWORD:
|
||||
max_args += 1
|
||||
if p.default is p.empty:
|
||||
min_args += 1
|
||||
required_names.append(p.name)
|
||||
else:
|
||||
other_names.append(p.name)
|
||||
elif p.kind == p.KEYWORD_ONLY:
|
||||
other_names.append(p.name)
|
||||
elif p.kind == p.VAR_POSITIONAL:
|
||||
max_args = None
|
||||
elif p.kind == p.VAR_KEYWORD:
|
||||
other_names = any
|
||||
elif p.kind == p.POSITIONAL_ONLY:
|
||||
max_args += 1
|
||||
if p.default is p.empty:
|
||||
min_args += 1
|
||||
no_names = True
|
||||
|
||||
if no_names:
|
||||
other_names = None
|
||||
|
||||
return SignatureInfo(min_args, max_args, required_names, other_names)
|
||||
|
||||
|
||||
class Concurrency(object):
|
||||
|
||||
def __init__(self, max_concurrent):
|
||||
self._require_non_negative(max_concurrent)
|
||||
self._max_concurrent = max_concurrent
|
||||
self.semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
def _require_non_negative(self, value):
|
||||
if not isinstance(value, int) or value < 0:
|
||||
raise RuntimeError('concurrency must be a natural number')
|
||||
|
||||
@property
|
||||
def max_concurrent(self):
|
||||
return self._max_concurrent
|
||||
|
||||
async def set_max_concurrent(self, value):
|
||||
self._require_non_negative(value)
|
||||
diff = value - self._max_concurrent
|
||||
self._max_concurrent = value
|
||||
if diff >= 0:
|
||||
for _ in range(diff):
|
||||
self.semaphore.release()
|
||||
else:
|
||||
for _ in range(-diff):
|
||||
await self.semaphore.acquire()
|
||||
|
||||
|
||||
def check_task(logger, task):
|
||||
if not task.cancelled():
|
||||
try:
|
||||
task.result()
|
||||
except Exception:
|
||||
logger.error('task crashed: %r', task, exc_info=True)
|
|
@ -15,7 +15,7 @@ from struct import pack, unpack
|
|||
import time
|
||||
from functools import partial
|
||||
|
||||
from aiorpcx import TaskGroup, run_in_thread
|
||||
from torba.rpc import TaskGroup, run_in_thread
|
||||
|
||||
import torba
|
||||
from torba.server.daemon import DaemonError
|
||||
|
|
|
@ -22,7 +22,7 @@ from torba.server.util import hex_to_bytes, class_logger,\
|
|||
unpack_le_uint16_from, pack_varint
|
||||
from torba.server.hash import hex_str_to_hash, hash_to_hex_str
|
||||
from torba.server.tx import DeserializerDecred
|
||||
from aiorpcx import JSONRPC
|
||||
from torba.rpc import JSONRPC
|
||||
|
||||
|
||||
class DaemonError(Exception):
|
||||
|
|
|
@ -19,8 +19,8 @@ from glob import glob
|
|||
from struct import pack, unpack
|
||||
|
||||
import attr
|
||||
from aiorpcx import run_in_thread, sleep
|
||||
|
||||
from torba.rpc import run_in_thread, sleep
|
||||
from torba.server import util
|
||||
from torba.server.hash import hash_to_hex_str, HASHX_LEN
|
||||
from torba.server.merkle import Merkle, MerkleCache
|
||||
|
|
|
@ -14,8 +14,8 @@ from asyncio import Lock
|
|||
from collections import defaultdict
|
||||
|
||||
import attr
|
||||
from aiorpcx import TaskGroup, run_in_thread, sleep
|
||||
|
||||
from torba.rpc import TaskGroup, run_in_thread, sleep
|
||||
from torba.server.hash import hash_to_hex_str, hex_str_to_hash
|
||||
from torba.server.util import class_logger, chunks
|
||||
from torba.server.db import UTXO
|
||||
|
|
|
@ -28,8 +28,7 @@
|
|||
|
||||
from math import ceil, log
|
||||
|
||||
from aiorpcx import Event
|
||||
|
||||
from torba.rpc import Event
|
||||
from torba.server.hash import double_sha256
|
||||
|
||||
|
||||
|
|
|
@ -14,11 +14,10 @@ import ssl
|
|||
import time
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
from aiorpcx import (Connector, RPCSession, SOCKSProxy,
|
||||
from torba.rpc import (Connector, RPCSession, SOCKSProxy,
|
||||
Notification, handler_invocation,
|
||||
SOCKSError, RPCError, TaskTimeout, TaskGroup, Event,
|
||||
sleep, run_in_thread, ignore_after, timeout_after)
|
||||
|
||||
from torba.server.peer import Peer
|
||||
from torba.server.util import class_logger, protocol_tuple
|
||||
|
||||
|
|
|
@ -19,13 +19,12 @@ import time
|
|||
from collections import defaultdict
|
||||
from functools import partial
|
||||
|
||||
from aiorpcx import (
|
||||
import torba
|
||||
from torba.rpc import (
|
||||
RPCSession, JSONRPCAutoDetect, JSONRPCConnection,
|
||||
TaskGroup, handler_invocation, RPCError, Request, ignore_after, sleep,
|
||||
Event
|
||||
)
|
||||
|
||||
import torba
|
||||
from torba.server import text
|
||||
from torba.server import util
|
||||
from torba.server.hash import (sha256, hash_to_hex_str, hex_str_to_hash,
|
||||
|
@ -664,9 +663,9 @@ class SessionBase(RPCSession):
|
|||
super().connection_lost(exc)
|
||||
self.session_mgr.remove_session(self)
|
||||
msg = ''
|
||||
if not self.can_send.is_set():
|
||||
if not self._can_send.is_set():
|
||||
msg += ' whilst paused'
|
||||
if self.concurrency.max_concurrent != self.max_concurrent:
|
||||
if self._concurrency.max_concurrent != self.max_concurrent:
|
||||
msg += ' whilst throttled'
|
||||
if self.send_size >= 1024*1024:
|
||||
msg += ('. Sent {:,d} bytes in {:,d} messages'
|
||||
|
|
Loading…
Reference in a new issue