merged aiorpcx into torba.rpc
This commit is contained in:
parent
ba5fc2a627
commit
899a6f0d4a
15 changed files with 2582 additions and 14 deletions
1
setup.py
1
setup.py
|
@ -10,7 +10,6 @@ with open(os.path.join(BASE, 'README.md'), encoding='utf-8') as fh:
|
||||||
|
|
||||||
REQUIRES = [
|
REQUIRES = [
|
||||||
'aiohttp',
|
'aiohttp',
|
||||||
'aiorpcx==0.9.0',
|
|
||||||
'coincurve',
|
'coincurve',
|
||||||
'pbkdf2',
|
'pbkdf2',
|
||||||
'cryptography',
|
'cryptography',
|
||||||
|
|
13
torba/rpc/__init__.py
Normal file
13
torba/rpc/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
from .curio import *
|
||||||
|
from .framing import *
|
||||||
|
from .jsonrpc import *
|
||||||
|
from .socks import *
|
||||||
|
from .session import *
|
||||||
|
from .util import *
|
||||||
|
|
||||||
|
__all__ = (curio.__all__ +
|
||||||
|
framing.__all__ +
|
||||||
|
jsonrpc.__all__ +
|
||||||
|
socks.__all__ +
|
||||||
|
session.__all__ +
|
||||||
|
util.__all__)
|
411
torba/rpc/curio.py
Normal file
411
torba/rpc/curio.py
Normal file
|
@ -0,0 +1,411 @@
|
||||||
|
# The code below is mostly my own but based on the interfaces of the
|
||||||
|
# curio library by David Beazley. I'm considering switching to using
|
||||||
|
# curio. In the mean-time this is an attempt to provide a similar
|
||||||
|
# clean, pure-async interface and move away from direct
|
||||||
|
# framework-specific dependencies. As asyncio differs in its design
|
||||||
|
# it is not possible to provide identical semantics.
|
||||||
|
#
|
||||||
|
# The curio library is distributed under the following licence:
|
||||||
|
#
|
||||||
|
# Copyright (C) 2015-2017
|
||||||
|
# David Beazley (Dabeaz LLC)
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions are
|
||||||
|
# met:
|
||||||
|
#
|
||||||
|
# * Redistributions of source code must retain the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer.
|
||||||
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer in the documentation
|
||||||
|
# and/or other materials provided with the distribution.
|
||||||
|
# * Neither the name of the David Beazley or Dabeaz LLC may be used to
|
||||||
|
# endorse or promote products derived from this software without
|
||||||
|
# specific prior written permission.
|
||||||
|
#
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from asyncio import (
|
||||||
|
CancelledError, get_event_loop, Queue, Event, Lock, Semaphore,
|
||||||
|
sleep, Task
|
||||||
|
)
|
||||||
|
from collections import deque
|
||||||
|
from contextlib import suppress
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from .util import normalize_corofunc, check_task
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
'Queue', 'Event', 'Lock', 'Semaphore', 'sleep', 'CancelledError',
|
||||||
|
'run_in_thread', 'spawn', 'spawn_sync', 'TaskGroup',
|
||||||
|
'TaskTimeout', 'TimeoutCancellationError', 'UncaughtTimeoutError',
|
||||||
|
'timeout_after', 'timeout_at', 'ignore_after', 'ignore_at',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_in_thread(func, *args):
|
||||||
|
'''Run a function in a separate thread, and await its completion.'''
|
||||||
|
return await get_event_loop().run_in_executor(None, func, *args)
|
||||||
|
|
||||||
|
|
||||||
|
async def spawn(coro, *args, loop=None, report_crash=True):
|
||||||
|
return spawn_sync(coro, *args, loop=loop, report_crash=report_crash)
|
||||||
|
|
||||||
|
|
||||||
|
def spawn_sync(coro, *args, loop=None, report_crash=True):
|
||||||
|
coro = normalize_corofunc(coro, args)
|
||||||
|
loop = loop or get_event_loop()
|
||||||
|
task = loop.create_task(coro)
|
||||||
|
if report_crash:
|
||||||
|
task.add_done_callback(partial(check_task, logging))
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
class TaskGroup(object):
|
||||||
|
'''A class representing a group of executing tasks. tasks is an
|
||||||
|
optional set of existing tasks to put into the group. New tasks
|
||||||
|
can later be added using the spawn() method below. wait specifies
|
||||||
|
the policy used for waiting for tasks. See the join() method
|
||||||
|
below. Each TaskGroup is an independent entity. Task groups do not
|
||||||
|
form a hierarchy or any kind of relationship to other previously
|
||||||
|
created task groups or tasks. Moreover, Tasks created by the top
|
||||||
|
level spawn() function are not placed into any task group. To
|
||||||
|
create a task in a group, it should be created using
|
||||||
|
TaskGroup.spawn() or explicitly added using TaskGroup.add_task().
|
||||||
|
|
||||||
|
completed attribute: the first task that completed with a result
|
||||||
|
in the group. Takes into account the wait option used in the
|
||||||
|
TaskGroup constructor (but not in the join method)`.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, tasks=(), *, wait=all):
|
||||||
|
if wait not in (any, all, object):
|
||||||
|
raise ValueError('invalid wait argument')
|
||||||
|
self._done = deque()
|
||||||
|
self._pending = set()
|
||||||
|
self._wait = wait
|
||||||
|
self._done_event = Event()
|
||||||
|
self._logger = logging.getLogger(self.__class__.__name__)
|
||||||
|
self._closed = False
|
||||||
|
self.completed = None
|
||||||
|
for task in tasks:
|
||||||
|
self._add_task(task)
|
||||||
|
|
||||||
|
def _add_task(self, task):
|
||||||
|
'''Add an already existing task to the task group.'''
|
||||||
|
if hasattr(task, '_task_group'):
|
||||||
|
raise RuntimeError('task is already part of a group')
|
||||||
|
if self._closed:
|
||||||
|
raise RuntimeError('task group is closed')
|
||||||
|
task._task_group = self
|
||||||
|
if task.done():
|
||||||
|
self._done.append(task)
|
||||||
|
else:
|
||||||
|
self._pending.add(task)
|
||||||
|
task.add_done_callback(self._on_done)
|
||||||
|
|
||||||
|
def _on_done(self, task):
|
||||||
|
task._task_group = None
|
||||||
|
self._pending.remove(task)
|
||||||
|
self._done.append(task)
|
||||||
|
self._done_event.set()
|
||||||
|
if self.completed is None:
|
||||||
|
if not task.cancelled() and not task.exception():
|
||||||
|
if self._wait is object and task.result() is None:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.completed = task
|
||||||
|
|
||||||
|
async def spawn(self, coro, *args):
|
||||||
|
'''Create a new task that’s part of the group. Returns a Task
|
||||||
|
instance.
|
||||||
|
'''
|
||||||
|
task = await spawn(coro, *args, report_crash=False)
|
||||||
|
self._add_task(task)
|
||||||
|
return task
|
||||||
|
|
||||||
|
async def add_task(self, task):
|
||||||
|
'''Add an already existing task to the task group.'''
|
||||||
|
self._add_task(task)
|
||||||
|
|
||||||
|
async def next_done(self):
|
||||||
|
'''Returns the next completed task. Returns None if no more tasks
|
||||||
|
remain. A TaskGroup may also be used as an asynchronous iterator.
|
||||||
|
'''
|
||||||
|
if not self._done and self._pending:
|
||||||
|
self._done_event.clear()
|
||||||
|
await self._done_event.wait()
|
||||||
|
if self._done:
|
||||||
|
return self._done.popleft()
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def next_result(self):
|
||||||
|
'''Returns the result of the next completed task. If the task failed
|
||||||
|
with an exception, that exception is raised. A RuntimeError
|
||||||
|
exception is raised if this is called when no remaining tasks
|
||||||
|
are available.'''
|
||||||
|
task = await self.next_done()
|
||||||
|
if not task:
|
||||||
|
raise RuntimeError('no tasks remain')
|
||||||
|
return task.result()
|
||||||
|
|
||||||
|
async def join(self):
|
||||||
|
'''Wait for tasks in the group to terminate according to the wait
|
||||||
|
policy for the group.
|
||||||
|
|
||||||
|
If the join() operation itself is cancelled, all remaining
|
||||||
|
tasks in the group are also cancelled.
|
||||||
|
|
||||||
|
If a TaskGroup is used as a context manager, the join() method
|
||||||
|
is called on context-exit.
|
||||||
|
|
||||||
|
Once join() returns, no more tasks may be added to the task
|
||||||
|
group. Tasks can be added while join() is running.
|
||||||
|
'''
|
||||||
|
def errored(task):
|
||||||
|
return not task.cancelled() and task.exception()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self._wait in (all, object):
|
||||||
|
while True:
|
||||||
|
task = await self.next_done()
|
||||||
|
if task is None:
|
||||||
|
return
|
||||||
|
if errored(task):
|
||||||
|
break
|
||||||
|
if self._wait is object:
|
||||||
|
if task.cancelled() or task.result() is not None:
|
||||||
|
return
|
||||||
|
else: # any
|
||||||
|
task = await self.next_done()
|
||||||
|
if task is None or not errored(task):
|
||||||
|
return
|
||||||
|
finally:
|
||||||
|
await self.cancel_remaining()
|
||||||
|
|
||||||
|
if errored(task):
|
||||||
|
raise task.exception()
|
||||||
|
|
||||||
|
async def cancel_remaining(self):
|
||||||
|
'''Cancel all remaining tasks.'''
|
||||||
|
self._closed = True
|
||||||
|
for task in list(self._pending):
|
||||||
|
task.cancel()
|
||||||
|
with suppress(CancelledError):
|
||||||
|
await task
|
||||||
|
|
||||||
|
def closed(self):
|
||||||
|
return self._closed
|
||||||
|
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
task = await self.next_done()
|
||||||
|
if task:
|
||||||
|
return task
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
|
if exc_type:
|
||||||
|
await self.cancel_remaining()
|
||||||
|
else:
|
||||||
|
await self.join()
|
||||||
|
|
||||||
|
|
||||||
|
class TaskTimeout(CancelledError):
|
||||||
|
|
||||||
|
def __init__(self, secs):
|
||||||
|
self.secs = secs
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f'task timed out after {self.args[0]}s'
|
||||||
|
|
||||||
|
|
||||||
|
class TimeoutCancellationError(CancelledError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UncaughtTimeoutError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _set_new_deadline(task, deadline):
|
||||||
|
def timeout_task():
|
||||||
|
# Unfortunately task.cancel is all we can do with asyncio
|
||||||
|
task.cancel()
|
||||||
|
task._timed_out = deadline
|
||||||
|
task._deadline_handle = task._loop.call_at(deadline, timeout_task)
|
||||||
|
|
||||||
|
|
||||||
|
def _set_task_deadline(task, deadline):
|
||||||
|
deadlines = getattr(task, '_deadlines', [])
|
||||||
|
if deadlines:
|
||||||
|
if deadline < min(deadlines):
|
||||||
|
task._deadline_handle.cancel()
|
||||||
|
_set_new_deadline(task, deadline)
|
||||||
|
else:
|
||||||
|
_set_new_deadline(task, deadline)
|
||||||
|
deadlines.append(deadline)
|
||||||
|
task._deadlines = deadlines
|
||||||
|
task._timed_out = None
|
||||||
|
|
||||||
|
|
||||||
|
def _unset_task_deadline(task):
|
||||||
|
deadlines = task._deadlines
|
||||||
|
timed_out_deadline = task._timed_out
|
||||||
|
uncaught = timed_out_deadline not in deadlines
|
||||||
|
task._deadline_handle.cancel()
|
||||||
|
deadlines.pop()
|
||||||
|
if deadlines:
|
||||||
|
_set_new_deadline(task, min(deadlines))
|
||||||
|
return timed_out_deadline, uncaught
|
||||||
|
|
||||||
|
|
||||||
|
class TimeoutAfter(object):
|
||||||
|
|
||||||
|
def __init__(self, deadline, *, ignore=False, absolute=False):
|
||||||
|
self._deadline = deadline
|
||||||
|
self._ignore = ignore
|
||||||
|
self._absolute = absolute
|
||||||
|
self.expired = False
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
task = asyncio.current_task()
|
||||||
|
loop_time = task._loop.time()
|
||||||
|
if self._absolute:
|
||||||
|
self._secs = self._deadline - loop_time
|
||||||
|
else:
|
||||||
|
self._secs = self._deadline
|
||||||
|
self._deadline += loop_time
|
||||||
|
_set_task_deadline(task, self._deadline)
|
||||||
|
self.expired = False
|
||||||
|
self._task = task
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
|
timed_out_deadline, uncaught = _unset_task_deadline(self._task)
|
||||||
|
if exc_type not in (CancelledError, TaskTimeout,
|
||||||
|
TimeoutCancellationError):
|
||||||
|
return False
|
||||||
|
if timed_out_deadline == self._deadline:
|
||||||
|
self.expired = True
|
||||||
|
if self._ignore:
|
||||||
|
return True
|
||||||
|
raise TaskTimeout(self._secs) from None
|
||||||
|
if timed_out_deadline is None:
|
||||||
|
assert exc_type is CancelledError
|
||||||
|
return False
|
||||||
|
if uncaught:
|
||||||
|
raise UncaughtTimeoutError('uncaught timeout received')
|
||||||
|
if exc_type is TimeoutCancellationError:
|
||||||
|
return False
|
||||||
|
raise TimeoutCancellationError(timed_out_deadline) from None
|
||||||
|
|
||||||
|
|
||||||
|
async def _timeout_after_func(seconds, absolute, coro, args):
|
||||||
|
coro = normalize_corofunc(coro, args)
|
||||||
|
async with TimeoutAfter(seconds, absolute=absolute):
|
||||||
|
return await coro
|
||||||
|
|
||||||
|
|
||||||
|
def timeout_after(seconds, coro=None, *args):
|
||||||
|
'''Execute the specified coroutine and return its result. However,
|
||||||
|
issue a cancellation request to the calling task after seconds
|
||||||
|
have elapsed. When this happens, a TaskTimeout exception is
|
||||||
|
raised. If coro is None, the result of this function serves
|
||||||
|
as an asynchronous context manager that applies a timeout to a
|
||||||
|
block of statements.
|
||||||
|
|
||||||
|
timeout_after() may be composed with other timeout_after()
|
||||||
|
operations (i.e., nested timeouts). If an outer timeout expires
|
||||||
|
first, then TimeoutCancellationError is raised instead of
|
||||||
|
TaskTimeout. If an inner timeout expires and fails to properly
|
||||||
|
TaskTimeout, a UncaughtTimeoutError is raised in the outer
|
||||||
|
timeout.
|
||||||
|
|
||||||
|
'''
|
||||||
|
if coro:
|
||||||
|
return _timeout_after_func(seconds, False, coro, args)
|
||||||
|
|
||||||
|
return TimeoutAfter(seconds)
|
||||||
|
|
||||||
|
|
||||||
|
def timeout_at(clock, coro=None, *args):
|
||||||
|
'''Execute the specified coroutine and return its result. However,
|
||||||
|
issue a cancellation request to the calling task after seconds
|
||||||
|
have elapsed. When this happens, a TaskTimeout exception is
|
||||||
|
raised. If coro is None, the result of this function serves
|
||||||
|
as an asynchronous context manager that applies a timeout to a
|
||||||
|
block of statements.
|
||||||
|
|
||||||
|
timeout_after() may be composed with other timeout_after()
|
||||||
|
operations (i.e., nested timeouts). If an outer timeout expires
|
||||||
|
first, then TimeoutCancellationError is raised instead of
|
||||||
|
TaskTimeout. If an inner timeout expires and fails to properly
|
||||||
|
TaskTimeout, a UncaughtTimeoutError is raised in the outer
|
||||||
|
timeout.
|
||||||
|
|
||||||
|
'''
|
||||||
|
if coro:
|
||||||
|
return _timeout_after_func(clock, True, coro, args)
|
||||||
|
|
||||||
|
return TimeoutAfter(clock, absolute=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def _ignore_after_func(seconds, absolute, coro, args, timeout_result):
|
||||||
|
coro = normalize_corofunc(coro, args)
|
||||||
|
async with TimeoutAfter(seconds, absolute=absolute, ignore=True):
|
||||||
|
return await coro
|
||||||
|
|
||||||
|
return timeout_result
|
||||||
|
|
||||||
|
|
||||||
|
def ignore_after(seconds, coro=None, *args, timeout_result=None):
|
||||||
|
'''Execute the specified coroutine and return its result. Issue a
|
||||||
|
cancellation request after seconds have elapsed. When a timeout
|
||||||
|
occurs, no exception is raised. Instead, timeout_result is
|
||||||
|
returned.
|
||||||
|
|
||||||
|
If coro is None, the result is an asynchronous context manager
|
||||||
|
that applies a timeout to a block of statements. For the context
|
||||||
|
manager case, the resulting context manager object has an expired
|
||||||
|
attribute set to True if time expired.
|
||||||
|
|
||||||
|
Note: ignore_after() may also be composed with other timeout
|
||||||
|
operations. TimeoutCancellationError and UncaughtTimeoutError
|
||||||
|
exceptions might be raised according to the same rules as for
|
||||||
|
timeout_after().
|
||||||
|
'''
|
||||||
|
if coro:
|
||||||
|
return _ignore_after_func(seconds, False, coro, args, timeout_result)
|
||||||
|
|
||||||
|
return TimeoutAfter(seconds, ignore=True)
|
||||||
|
|
||||||
|
|
||||||
|
def ignore_at(clock, coro=None, *args, timeout_result=None):
|
||||||
|
'''
|
||||||
|
Stop the enclosed task or block of code at an absolute
|
||||||
|
clock value. Same usage as ignore_after().
|
||||||
|
'''
|
||||||
|
if coro:
|
||||||
|
return _ignore_after_func(clock, True, coro, args, timeout_result)
|
||||||
|
|
||||||
|
return TimeoutAfter(clock, absolute=True, ignore=True)
|
239
torba/rpc/framing.py
Normal file
239
torba/rpc/framing.py
Normal file
|
@ -0,0 +1,239 @@
|
||||||
|
# Copyright (c) 2018, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
'''RPC message framing in a byte stream.'''
|
||||||
|
|
||||||
|
__all__ = ('FramerBase', 'NewlineFramer', 'BinaryFramer', 'BitcoinFramer',
|
||||||
|
'OversizedPayloadError', 'BadChecksumError', 'BadMagicError')
|
||||||
|
|
||||||
|
from hashlib import sha256 as _sha256
|
||||||
|
from struct import Struct
|
||||||
|
from asyncio import Queue
|
||||||
|
|
||||||
|
|
||||||
|
class FramerBase(object):
|
||||||
|
'''Abstract base class for a framer.
|
||||||
|
|
||||||
|
A framer breaks an incoming byte stream into protocol messages,
|
||||||
|
buffering if necesary. It also frames outgoing messages into
|
||||||
|
a byte stream.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def frame(self, message):
|
||||||
|
'''Return the framed message.'''
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def received_bytes(self, data):
|
||||||
|
'''Pass incoming network bytes.'''
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def receive_message(self):
|
||||||
|
'''Wait for a complete unframed message to arrive, and return it.'''
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class NewlineFramer(FramerBase):
|
||||||
|
'''A framer for a protocol where messages are separated by newlines.'''
|
||||||
|
|
||||||
|
# The default max_size value is motivated by JSONRPC, where a
|
||||||
|
# normal request will be 250 bytes or less, and a reasonable
|
||||||
|
# batch may contain 4000 requests.
|
||||||
|
def __init__(self, max_size=250 * 4000):
|
||||||
|
'''max_size - an anti-DoS measure. If, after processing an incoming
|
||||||
|
message, buffered data would exceed max_size bytes, that
|
||||||
|
buffered data is dropped entirely and the framer waits for a
|
||||||
|
newline character to re-synchronize the stream.
|
||||||
|
'''
|
||||||
|
self.max_size = max_size
|
||||||
|
self.queue = Queue()
|
||||||
|
self.received_bytes = self.queue.put_nowait
|
||||||
|
self.synchronizing = False
|
||||||
|
self.residual = b''
|
||||||
|
|
||||||
|
def frame(self, message):
|
||||||
|
return message + b'\n'
|
||||||
|
|
||||||
|
async def receive_message(self):
|
||||||
|
parts = []
|
||||||
|
buffer_size = 0
|
||||||
|
while True:
|
||||||
|
part = self.residual
|
||||||
|
self.residual = b''
|
||||||
|
if not part:
|
||||||
|
part = await self.queue.get()
|
||||||
|
|
||||||
|
npos = part.find(b'\n')
|
||||||
|
if npos == -1:
|
||||||
|
parts.append(part)
|
||||||
|
buffer_size += len(part)
|
||||||
|
# Ignore over-sized messages; re-synchronize
|
||||||
|
if buffer_size <= self.max_size:
|
||||||
|
continue
|
||||||
|
self.synchronizing = True
|
||||||
|
raise MemoryError(f'dropping message over {self.max_size:,d} '
|
||||||
|
f'bytes and re-synchronizing')
|
||||||
|
|
||||||
|
tail, self.residual = part[:npos], part[npos + 1:]
|
||||||
|
if self.synchronizing:
|
||||||
|
self.synchronizing = False
|
||||||
|
return await self.receive_message()
|
||||||
|
else:
|
||||||
|
parts.append(tail)
|
||||||
|
return b''.join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
class ByteQueue(object):
|
||||||
|
'''A producer-comsumer queue. Incoming network data is put as it
|
||||||
|
arrives, and the consumer calls an async method waiting for data of
|
||||||
|
a specific length.'''
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.queue = Queue()
|
||||||
|
self.parts = []
|
||||||
|
self.parts_len = 0
|
||||||
|
self.put_nowait = self.queue.put_nowait
|
||||||
|
|
||||||
|
async def receive(self, size):
|
||||||
|
while self.parts_len < size:
|
||||||
|
part = await self.queue.get()
|
||||||
|
self.parts.append(part)
|
||||||
|
self.parts_len += len(part)
|
||||||
|
self.parts_len -= size
|
||||||
|
whole = b''.join(self.parts)
|
||||||
|
self.parts = [whole[size:]]
|
||||||
|
return whole[:size]
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryFramer(object):
|
||||||
|
'''A framer for binary messaging protocols.'''
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.byte_queue = ByteQueue()
|
||||||
|
self.message_queue = Queue()
|
||||||
|
self.received_bytes = self.byte_queue.put_nowait
|
||||||
|
|
||||||
|
def frame(self, message):
|
||||||
|
command, payload = message
|
||||||
|
return b''.join((
|
||||||
|
self._build_header(command, payload),
|
||||||
|
payload
|
||||||
|
))
|
||||||
|
|
||||||
|
async def receive_message(self):
|
||||||
|
command, payload_len, checksum = await self._receive_header()
|
||||||
|
payload = await self.byte_queue.receive(payload_len)
|
||||||
|
payload_checksum = self._checksum(payload)
|
||||||
|
if payload_checksum != checksum:
|
||||||
|
raise BadChecksumError(payload_checksum, checksum)
|
||||||
|
return command, payload
|
||||||
|
|
||||||
|
def _checksum(self, payload):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _build_header(self, command, payload):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def _receive_header(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
# Helpers
|
||||||
|
struct_le_I = Struct('<I')
|
||||||
|
pack_le_uint32 = struct_le_I.pack
|
||||||
|
|
||||||
|
|
||||||
|
def sha256(x):
|
||||||
|
'''Simple wrapper of hashlib sha256.'''
|
||||||
|
return _sha256(x).digest()
|
||||||
|
|
||||||
|
|
||||||
|
def double_sha256(x):
|
||||||
|
'''SHA-256 of SHA-256, as used extensively in bitcoin.'''
|
||||||
|
return sha256(sha256(x))
|
||||||
|
|
||||||
|
|
||||||
|
class BadChecksumError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BadMagicError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OversizedPayloadError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BitcoinFramer(BinaryFramer):
|
||||||
|
'''Provides a framer of binary message payloads in the style of the
|
||||||
|
Bitcoin network protocol.
|
||||||
|
|
||||||
|
Each binary message has the following elements, in order:
|
||||||
|
|
||||||
|
Magic - to confirm network (currently unused for stream sync)
|
||||||
|
Command - padded command
|
||||||
|
Length - payload length in bytes
|
||||||
|
Checksum - checksum of the payload
|
||||||
|
Payload - binary payload
|
||||||
|
|
||||||
|
Call frame(command, payload) to get a framed message.
|
||||||
|
Pass incoming network bytes to received_bytes().
|
||||||
|
Wait on receive_message() to get incoming (command, payload) pairs.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, magic, max_block_size):
|
||||||
|
def pad_command(command):
|
||||||
|
fill = 12 - len(command)
|
||||||
|
if fill < 0:
|
||||||
|
raise ValueError(f'command {command} too long')
|
||||||
|
return command + bytes(fill)
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self._magic = magic
|
||||||
|
self._max_block_size = max_block_size
|
||||||
|
self._pad_command = pad_command
|
||||||
|
self._unpack = Struct(f'<4s12sI4s').unpack
|
||||||
|
|
||||||
|
def _checksum(self, payload):
|
||||||
|
return double_sha256(payload)[:4]
|
||||||
|
|
||||||
|
def _build_header(self, command, payload):
|
||||||
|
return b''.join((
|
||||||
|
self._magic,
|
||||||
|
self._pad_command(command),
|
||||||
|
pack_le_uint32(len(payload)),
|
||||||
|
self._checksum(payload)
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _receive_header(self):
|
||||||
|
header = await self.byte_queue.receive(24)
|
||||||
|
magic, command, payload_len, checksum = self._unpack(header)
|
||||||
|
if magic != self._magic:
|
||||||
|
raise BadMagicError(magic, self._magic)
|
||||||
|
command = command.rstrip(b'\0')
|
||||||
|
if payload_len > 1024 * 1024:
|
||||||
|
if command != b'block' or payload_len > self._max_block_size:
|
||||||
|
raise OversizedPayloadError(command, payload_len)
|
||||||
|
return command, payload_len, checksum
|
801
torba/rpc/jsonrpc.py
Normal file
801
torba/rpc/jsonrpc.py
Normal file
|
@ -0,0 +1,801 @@
|
||||||
|
# Copyright (c) 2018, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
'''Classes for JSONRPC versions 1.0 and 2.0, and a loose interpretation.'''
|
||||||
|
|
||||||
|
__all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
|
||||||
|
'JSONRPCAutoDetect', 'Request', 'Notification', 'Batch',
|
||||||
|
'RPCError', 'ProtocolError',
|
||||||
|
'JSONRPCConnection', 'handler_invocation')
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
import json
|
||||||
|
from functools import partial
|
||||||
|
from numbers import Number
|
||||||
|
|
||||||
|
import attr
|
||||||
|
from asyncio import Queue, Event, CancelledError
|
||||||
|
from .util import signature_info
|
||||||
|
|
||||||
|
|
||||||
|
class SingleRequest(object):
|
||||||
|
__slots__ = ('method', 'args')
|
||||||
|
|
||||||
|
def __init__(self, method, args):
|
||||||
|
if not isinstance(method, str):
|
||||||
|
raise ProtocolError(JSONRPC.METHOD_NOT_FOUND,
|
||||||
|
'method must be a string')
|
||||||
|
if not isinstance(args, (list, tuple, dict)):
|
||||||
|
raise ProtocolError.invalid_args('request arguments must be a '
|
||||||
|
'list or a dictionary')
|
||||||
|
self.args = args
|
||||||
|
self.method = method
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'{self.__class__.__name__}({self.method!r}, {self.args!r})'
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return (isinstance(other, self.__class__) and
|
||||||
|
self.method == other.method and self.args == other.args)
|
||||||
|
|
||||||
|
|
||||||
|
class Request(SingleRequest):
|
||||||
|
def send_result(self, response):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class Notification(SingleRequest):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Batch(object):
|
||||||
|
__slots__ = ('items', )
|
||||||
|
|
||||||
|
def __init__(self, items):
|
||||||
|
if not isinstance(items, (list, tuple)):
|
||||||
|
raise ProtocolError.invalid_request('items must be a list')
|
||||||
|
if not items:
|
||||||
|
raise ProtocolError.empty_batch()
|
||||||
|
if not (all(isinstance(item, SingleRequest) for item in items) or
|
||||||
|
all(isinstance(item, Response) for item in items)):
|
||||||
|
raise ProtocolError.invalid_request('batch must be homogeneous')
|
||||||
|
self.items = items
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.items)
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.items[item]
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.items)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'Batch({len(self.items)} items)'
|
||||||
|
|
||||||
|
|
||||||
|
class Response(object):
|
||||||
|
__slots__ = ('result', )
|
||||||
|
|
||||||
|
def __init__(self, result):
|
||||||
|
# Type checking happens when converting to a message
|
||||||
|
self.result = result
|
||||||
|
|
||||||
|
|
||||||
|
class CodeMessageError(Exception):
|
||||||
|
|
||||||
|
def __init__(self, code, message):
|
||||||
|
super().__init__(code, message)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def code(self):
|
||||||
|
return self.args[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def message(self):
|
||||||
|
return self.args[1]
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return (isinstance(other, self.__class__) and
|
||||||
|
self.code == other.code and self.message == other.message)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
# overridden to make the exception hashable
|
||||||
|
# see https://bugs.python.org/issue28603
|
||||||
|
return hash((self.code, self.message))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def invalid_args(cls, message):
|
||||||
|
return cls(JSONRPC.INVALID_ARGS, message)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def invalid_request(cls, message):
|
||||||
|
return cls(JSONRPC.INVALID_REQUEST, message)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def empty_batch(cls):
|
||||||
|
return cls.invalid_request('batch is empty')
|
||||||
|
|
||||||
|
|
||||||
|
class RPCError(CodeMessageError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolError(CodeMessageError):
|
||||||
|
|
||||||
|
def __init__(self, code, message):
|
||||||
|
super().__init__(code, message)
|
||||||
|
# If not None send this unframed message over the network
|
||||||
|
self.error_message = None
|
||||||
|
# If the error was in a JSON response message; its message ID.
|
||||||
|
# Since None can be a response message ID, "id" means the
|
||||||
|
# error was not sent in a JSON response
|
||||||
|
self.response_msg_id = id
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPC(object):
|
||||||
|
'''Abstract base class that interprets and constructs JSON RPC messages.'''
|
||||||
|
|
||||||
|
# Error codes. See http://www.jsonrpc.org/specification
|
||||||
|
PARSE_ERROR = -32700
|
||||||
|
INVALID_REQUEST = -32600
|
||||||
|
METHOD_NOT_FOUND = -32601
|
||||||
|
INVALID_ARGS = -32602
|
||||||
|
INTERNAL_ERROR = -32603
|
||||||
|
# Codes specific to this library
|
||||||
|
ERROR_CODE_UNAVAILABLE = -100
|
||||||
|
|
||||||
|
# Can be overridden by derived classes
|
||||||
|
allow_batches = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _message_id(cls, message, require_id):
|
||||||
|
'''Validate the message is a dictionary and return its ID.
|
||||||
|
|
||||||
|
Raise an error if the message is invalid or the ID is of an
|
||||||
|
invalid type. If it has no ID, raise an error if require_id
|
||||||
|
is True, otherwise return None.
|
||||||
|
'''
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_message(cls, message):
|
||||||
|
'''Validate other parts of the message other than those
|
||||||
|
done in _message_id.'''
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _request_args(cls, request):
|
||||||
|
'''Validate the existence and type of the arguments passed
|
||||||
|
in the request dictionary.'''
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _process_request(cls, payload):
|
||||||
|
request_id = None
|
||||||
|
try:
|
||||||
|
request_id = cls._message_id(payload, False)
|
||||||
|
cls._validate_message(payload)
|
||||||
|
method = payload.get('method')
|
||||||
|
if request_id is None:
|
||||||
|
item = Notification(method, cls._request_args(payload))
|
||||||
|
else:
|
||||||
|
item = Request(method, cls._request_args(payload))
|
||||||
|
return item, request_id
|
||||||
|
except ProtocolError as error:
|
||||||
|
code, message = error.code, error.message
|
||||||
|
raise cls._error(code, message, True, request_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _process_response(cls, payload):
|
||||||
|
request_id = None
|
||||||
|
try:
|
||||||
|
request_id = cls._message_id(payload, True)
|
||||||
|
cls._validate_message(payload)
|
||||||
|
return Response(cls.response_value(payload)), request_id
|
||||||
|
except ProtocolError as error:
|
||||||
|
code, message = error.code, error.message
|
||||||
|
raise cls._error(code, message, False, request_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _message_to_payload(cls, message):
|
||||||
|
'''Returns a Python object or a ProtocolError.'''
|
||||||
|
try:
|
||||||
|
return json.loads(message.decode())
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
message = 'messages must be encoded in UTF-8'
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
message = 'invalid JSON'
|
||||||
|
raise cls._error(cls.PARSE_ERROR, message, True, None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _error(cls, code, message, send, msg_id):
|
||||||
|
error = ProtocolError(code, message)
|
||||||
|
if send:
|
||||||
|
error.error_message = cls.response_message(error, msg_id)
|
||||||
|
else:
|
||||||
|
error.response_msg_id = msg_id
|
||||||
|
return error
|
||||||
|
|
||||||
|
#
|
||||||
|
# External API
|
||||||
|
#
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def message_to_item(cls, message):
|
||||||
|
'''Translate an unframed received message and return an
|
||||||
|
(item, request_id) pair.
|
||||||
|
|
||||||
|
The item can be a Request, Notification, Response or a list.
|
||||||
|
|
||||||
|
A JSON RPC error response is returned as an RPCError inside a
|
||||||
|
Response object.
|
||||||
|
|
||||||
|
If a Batch is returned, request_id is an iterable of request
|
||||||
|
ids, one per batch member.
|
||||||
|
|
||||||
|
If the message violates the protocol in some way a
|
||||||
|
ProtocolError is returned, except if the message was
|
||||||
|
determined to be a response, in which case the ProtocolError
|
||||||
|
is placed inside a Response object. This is so that client
|
||||||
|
code can mark a request as having been responded to even if
|
||||||
|
the response was bad.
|
||||||
|
|
||||||
|
raises: ProtocolError
|
||||||
|
'''
|
||||||
|
payload = cls._message_to_payload(message)
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
if 'method' in payload:
|
||||||
|
return cls._process_request(payload)
|
||||||
|
else:
|
||||||
|
return cls._process_response(payload)
|
||||||
|
elif isinstance(payload, list) and cls.allow_batches:
|
||||||
|
if not payload:
|
||||||
|
raise cls._error(JSONRPC.INVALID_REQUEST, 'batch is empty',
|
||||||
|
True, None)
|
||||||
|
return payload, None
|
||||||
|
raise cls._error(cls.INVALID_REQUEST,
|
||||||
|
'request object must be a dictionary', True, None)
|
||||||
|
|
||||||
|
# Message formation
|
||||||
|
@classmethod
|
||||||
|
def request_message(cls, item, request_id):
|
||||||
|
'''Convert an RPCRequest item to a message.'''
|
||||||
|
assert isinstance(item, Request)
|
||||||
|
return cls.encode_payload(cls.request_payload(item, request_id))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def notification_message(cls, item):
|
||||||
|
'''Convert an RPCRequest item to a message.'''
|
||||||
|
assert isinstance(item, Notification)
|
||||||
|
return cls.encode_payload(cls.request_payload(item, None))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def response_message(cls, result, request_id):
|
||||||
|
'''Convert a response result (or RPCError) to a message.'''
|
||||||
|
if isinstance(result, CodeMessageError):
|
||||||
|
payload = cls.error_payload(result, request_id)
|
||||||
|
else:
|
||||||
|
payload = cls.response_payload(result, request_id)
|
||||||
|
return cls.encode_payload(payload)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def batch_message(cls, batch, request_ids):
|
||||||
|
'''Convert a request Batch to a message.'''
|
||||||
|
assert isinstance(batch, Batch)
|
||||||
|
if not cls.allow_batches:
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'protocol does not permit batches')
|
||||||
|
id_iter = iter(request_ids)
|
||||||
|
rm = cls.request_message
|
||||||
|
nm = cls.notification_message
|
||||||
|
parts = (rm(request, next(id_iter)) if isinstance(request, Request)
|
||||||
|
else nm(request) for request in batch)
|
||||||
|
return cls.batch_message_from_parts(parts)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def batch_message_from_parts(cls, messages):
|
||||||
|
'''Convert messages, one per batch item, into a batch message. At
|
||||||
|
least one message must be passed.
|
||||||
|
'''
|
||||||
|
# Comma-separate the messages and wrap the lot in square brackets
|
||||||
|
middle = b', '.join(messages)
|
||||||
|
if not middle:
|
||||||
|
raise ProtocolError.empty_batch()
|
||||||
|
return b''.join([b'[', middle, b']'])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def encode_payload(cls, payload):
|
||||||
|
'''Encode a Python object as JSON and convert it to bytes.'''
|
||||||
|
try:
|
||||||
|
return json.dumps(payload).encode()
|
||||||
|
except TypeError:
|
||||||
|
msg = f'JSON payload encoding error: {payload}'
|
||||||
|
raise ProtocolError(cls.INTERNAL_ERROR, msg) from None
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCv1(JSONRPC):
|
||||||
|
'''JSON RPC version 1.0.'''
|
||||||
|
|
||||||
|
allow_batches = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _message_id(cls, message, require_id):
|
||||||
|
# JSONv1 requires an ID always, but without constraint on its type
|
||||||
|
# No need to test for a dictionary here as we don't handle batches.
|
||||||
|
if 'id' not in message:
|
||||||
|
raise ProtocolError.invalid_request('request has no "id"')
|
||||||
|
return message['id']
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _request_args(cls, request):
|
||||||
|
args = request.get('params')
|
||||||
|
if not isinstance(args, list):
|
||||||
|
raise ProtocolError.invalid_args(
|
||||||
|
f'invalid request arguments: {args}')
|
||||||
|
return args
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _best_effort_error(cls, error):
|
||||||
|
# Do our best to interpret the error
|
||||||
|
code = cls.ERROR_CODE_UNAVAILABLE
|
||||||
|
message = 'no error message provided'
|
||||||
|
if isinstance(error, str):
|
||||||
|
message = error
|
||||||
|
elif isinstance(error, int):
|
||||||
|
code = error
|
||||||
|
elif isinstance(error, dict):
|
||||||
|
if isinstance(error.get('message'), str):
|
||||||
|
message = error['message']
|
||||||
|
if isinstance(error.get('code'), int):
|
||||||
|
code = error['code']
|
||||||
|
|
||||||
|
return RPCError(code, message)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def response_value(cls, payload):
|
||||||
|
if 'result' not in payload or 'error' not in payload:
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'response must contain both "result" and "error"')
|
||||||
|
|
||||||
|
result = payload['result']
|
||||||
|
error = payload['error']
|
||||||
|
if error is None:
|
||||||
|
return result # It seems None can be a valid result
|
||||||
|
if result is not None:
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'response has a "result" and an "error"')
|
||||||
|
|
||||||
|
return cls._best_effort_error(error)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def request_payload(cls, request, request_id):
|
||||||
|
'''JSON v1 request (or notification) payload.'''
|
||||||
|
if isinstance(request.args, dict):
|
||||||
|
raise ProtocolError.invalid_args(
|
||||||
|
'JSONRPCv1 does not support named arguments')
|
||||||
|
return {
|
||||||
|
'method': request.method,
|
||||||
|
'params': request.args,
|
||||||
|
'id': request_id
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def response_payload(cls, result, request_id):
|
||||||
|
'''JSON v1 response payload.'''
|
||||||
|
return {
|
||||||
|
'result': result,
|
||||||
|
'error': None,
|
||||||
|
'id': request_id
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def error_payload(cls, error, request_id):
|
||||||
|
return {
|
||||||
|
'result': None,
|
||||||
|
'error': {'code': error.code, 'message': error.message},
|
||||||
|
'id': request_id
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCv2(JSONRPC):
|
||||||
|
'''JSON RPC version 2.0.'''
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _message_id(cls, message, require_id):
|
||||||
|
if not isinstance(message, dict):
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'request object must be a dictionary')
|
||||||
|
if 'id' in message:
|
||||||
|
request_id = message['id']
|
||||||
|
if not isinstance(request_id, (Number, str, type(None))):
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
f'invalid "id": {request_id}')
|
||||||
|
return request_id
|
||||||
|
else:
|
||||||
|
if require_id:
|
||||||
|
raise ProtocolError.invalid_request('request has no "id"')
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_message(cls, message):
|
||||||
|
if message.get('jsonrpc') != '2.0':
|
||||||
|
raise ProtocolError.invalid_request('"jsonrpc" is not "2.0"')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _request_args(cls, request):
|
||||||
|
args = request.get('params', [])
|
||||||
|
if not isinstance(args, (dict, list)):
|
||||||
|
raise ProtocolError.invalid_args(
|
||||||
|
f'invalid request arguments: {args}')
|
||||||
|
return args
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def response_value(cls, payload):
|
||||||
|
if 'result' in payload:
|
||||||
|
if 'error' in payload:
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'response contains both "result" and "error"')
|
||||||
|
return payload['result']
|
||||||
|
|
||||||
|
if 'error' not in payload:
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'response contains neither "result" nor "error"')
|
||||||
|
|
||||||
|
# Return an RPCError object
|
||||||
|
error = payload['error']
|
||||||
|
if isinstance(error, dict):
|
||||||
|
code = error.get('code')
|
||||||
|
message = error.get('message')
|
||||||
|
if isinstance(code, int) and isinstance(message, str):
|
||||||
|
return RPCError(code, message)
|
||||||
|
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
f'ill-formed response error object: {error}')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def request_payload(cls, request, request_id):
|
||||||
|
'''JSON v2 request (or notification) payload.'''
|
||||||
|
payload = {
|
||||||
|
'jsonrpc': '2.0',
|
||||||
|
'method': request.method,
|
||||||
|
}
|
||||||
|
# A notification?
|
||||||
|
if request_id is not None:
|
||||||
|
payload['id'] = request_id
|
||||||
|
# Preserve empty dicts as missing params is read as an array
|
||||||
|
if request.args or request.args == {}:
|
||||||
|
payload['params'] = request.args
|
||||||
|
return payload
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def response_payload(cls, result, request_id):
|
||||||
|
'''JSON v2 response payload.'''
|
||||||
|
return {
|
||||||
|
'jsonrpc': '2.0',
|
||||||
|
'result': result,
|
||||||
|
'id': request_id
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def error_payload(cls, error, request_id):
|
||||||
|
return {
|
||||||
|
'jsonrpc': '2.0',
|
||||||
|
'error': {'code': error.code, 'message': error.message},
|
||||||
|
'id': request_id
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCLoose(JSONRPC):
|
||||||
|
'''A relaxed versin of JSON RPC.'''
|
||||||
|
|
||||||
|
# Don't be so loose we accept any old message ID
|
||||||
|
_message_id = JSONRPCv2._message_id
|
||||||
|
_validate_message = JSONRPC._validate_message
|
||||||
|
_request_args = JSONRPCv2._request_args
|
||||||
|
# Outoing messages are JSONRPCv2 so we give the other side the
|
||||||
|
# best chance to assume / detect JSONRPCv2 as default protocol.
|
||||||
|
error_payload = JSONRPCv2.error_payload
|
||||||
|
request_payload = JSONRPCv2.request_payload
|
||||||
|
response_payload = JSONRPCv2.response_payload
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def response_value(cls, payload):
|
||||||
|
# Return result, unless it is None and there is an error
|
||||||
|
if payload.get('error') is not None:
|
||||||
|
if payload.get('result') is not None:
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'response contains both "result" and "error"')
|
||||||
|
return JSONRPCv1._best_effort_error(payload['error'])
|
||||||
|
|
||||||
|
if 'result' not in payload:
|
||||||
|
raise ProtocolError.invalid_request(
|
||||||
|
'response contains neither "result" nor "error"')
|
||||||
|
|
||||||
|
# Can be None
|
||||||
|
return payload['result']
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCAutoDetect(JSONRPCv2):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def message_to_item(cls, message):
|
||||||
|
return cls.detect_protocol(message), None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def detect_protocol(cls, message):
|
||||||
|
'''Attempt to detect the protocol from the message.'''
|
||||||
|
main = cls._message_to_payload(message)
|
||||||
|
|
||||||
|
def protocol_for_payload(payload):
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
return JSONRPCLoose # Will error
|
||||||
|
# Obey an explicit "jsonrpc"
|
||||||
|
version = payload.get('jsonrpc')
|
||||||
|
if version == '2.0':
|
||||||
|
return JSONRPCv2
|
||||||
|
if version == '1.0':
|
||||||
|
return JSONRPCv1
|
||||||
|
|
||||||
|
# Now to decide between JSONRPCLoose and JSONRPCv1 if possible
|
||||||
|
if 'result' in payload and 'error' in payload:
|
||||||
|
return JSONRPCv1
|
||||||
|
return JSONRPCLoose
|
||||||
|
|
||||||
|
if isinstance(main, list):
|
||||||
|
parts = set(protocol_for_payload(payload) for payload in main)
|
||||||
|
# If all same protocol, return it
|
||||||
|
if len(parts) == 1:
|
||||||
|
return parts.pop()
|
||||||
|
# If strict protocol detected, return it, preferring JSONRPCv2.
|
||||||
|
# This means a batch of JSONRPCv1 will fail
|
||||||
|
for protocol in (JSONRPCv2, JSONRPCv1):
|
||||||
|
if protocol in parts:
|
||||||
|
return protocol
|
||||||
|
# Will error if no parts
|
||||||
|
return JSONRPCLoose
|
||||||
|
|
||||||
|
return protocol_for_payload(main)
|
||||||
|
|
||||||
|
|
||||||
|
class JSONRPCConnection(object):
|
||||||
|
'''Maintains state of a JSON RPC connection, in particular
|
||||||
|
encapsulating the handling of request IDs.
|
||||||
|
|
||||||
|
protocol - the JSON RPC protocol to follow
|
||||||
|
max_response_size - responses over this size send an error response
|
||||||
|
instead.
|
||||||
|
'''
|
||||||
|
|
||||||
|
_id_counter = itertools.count()
|
||||||
|
|
||||||
|
def __init__(self, protocol):
|
||||||
|
self._protocol = protocol
|
||||||
|
# Sent Requests and Batches that have not received a response.
|
||||||
|
# The key is its request ID; for a batch it is sorted tuple
|
||||||
|
# of request IDs
|
||||||
|
self._requests = {}
|
||||||
|
# A public attribute intended to be settable dynamically
|
||||||
|
self.max_response_size = 0
|
||||||
|
|
||||||
|
def _oversized_response_message(self, request_id):
|
||||||
|
text = f'response too large (over {self.max_response_size:,d} bytes'
|
||||||
|
error = RPCError.invalid_request(text)
|
||||||
|
return self._protocol.response_message(error, request_id)
|
||||||
|
|
||||||
|
def _receive_response(self, result, request_id):
|
||||||
|
if request_id not in self._requests:
|
||||||
|
if request_id is None and isinstance(result, RPCError):
|
||||||
|
message = f'diagnostic error received: {result}'
|
||||||
|
else:
|
||||||
|
message = f'response to unsent request (ID: {request_id})'
|
||||||
|
raise ProtocolError.invalid_request(message) from None
|
||||||
|
request, event = self._requests.pop(request_id)
|
||||||
|
event.result = result
|
||||||
|
event.set()
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _receive_request_batch(self, payloads):
|
||||||
|
def item_send_result(request_id, result):
|
||||||
|
nonlocal size
|
||||||
|
part = protocol.response_message(result, request_id)
|
||||||
|
size += len(part) + 2
|
||||||
|
if size > self.max_response_size > 0:
|
||||||
|
part = self._oversized_response_message(request_id)
|
||||||
|
parts.append(part)
|
||||||
|
if len(parts) == count:
|
||||||
|
return protocol.batch_message_from_parts(parts)
|
||||||
|
return None
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
items = []
|
||||||
|
size = 0
|
||||||
|
count = 0
|
||||||
|
protocol = self._protocol
|
||||||
|
for payload in payloads:
|
||||||
|
try:
|
||||||
|
item, request_id = protocol._process_request(payload)
|
||||||
|
items.append(item)
|
||||||
|
if isinstance(item, Request):
|
||||||
|
count += 1
|
||||||
|
item.send_result = partial(item_send_result, request_id)
|
||||||
|
except ProtocolError as error:
|
||||||
|
count += 1
|
||||||
|
parts.append(error.error_message)
|
||||||
|
|
||||||
|
if not items and parts:
|
||||||
|
error = ProtocolError(0, "")
|
||||||
|
error.error_message = protocol.batch_message_from_parts(parts)
|
||||||
|
raise error
|
||||||
|
return items
|
||||||
|
|
||||||
|
def _receive_response_batch(self, payloads):
|
||||||
|
request_ids = []
|
||||||
|
results = []
|
||||||
|
for payload in payloads:
|
||||||
|
# Let ProtocolError exceptions through
|
||||||
|
item, request_id = self._protocol._process_response(payload)
|
||||||
|
request_ids.append(request_id)
|
||||||
|
results.append(item.result)
|
||||||
|
|
||||||
|
ordered = sorted(zip(request_ids, results), key=lambda t: t[0])
|
||||||
|
ordered_ids, ordered_results = zip(*ordered)
|
||||||
|
if ordered_ids not in self._requests:
|
||||||
|
raise ProtocolError.invalid_request('response to unsent batch')
|
||||||
|
request_batch, event = self._requests.pop(ordered_ids)
|
||||||
|
event.result = ordered_results
|
||||||
|
event.set()
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _send_result(self, request_id, result):
|
||||||
|
message = self._protocol.response_message(result, request_id)
|
||||||
|
if len(message) > self.max_response_size > 0:
|
||||||
|
message = self._oversized_response_message(request_id)
|
||||||
|
return message
|
||||||
|
|
||||||
|
def _event(self, request, request_id):
|
||||||
|
event = Event()
|
||||||
|
self._requests[request_id] = (request, event)
|
||||||
|
return event
|
||||||
|
|
||||||
|
#
|
||||||
|
# External API
|
||||||
|
#
|
||||||
|
def send_request(self, request):
|
||||||
|
'''Send a Request. Return a (message, event) pair.
|
||||||
|
|
||||||
|
The message is an unframed message to send over the network.
|
||||||
|
Wait on the event for the response; which will be in the
|
||||||
|
"result" attribute.
|
||||||
|
|
||||||
|
Raises: ProtocolError if the request violates the protocol
|
||||||
|
in some way..
|
||||||
|
'''
|
||||||
|
request_id = next(self._id_counter)
|
||||||
|
message = self._protocol.request_message(request, request_id)
|
||||||
|
return message, self._event(request, request_id)
|
||||||
|
|
||||||
|
def send_notification(self, notification):
|
||||||
|
return self._protocol.notification_message(notification)
|
||||||
|
|
||||||
|
def send_batch(self, batch):
|
||||||
|
ids = tuple(next(self._id_counter)
|
||||||
|
for request in batch if isinstance(request, Request))
|
||||||
|
message = self._protocol.batch_message(batch, ids)
|
||||||
|
event = self._event(batch, ids) if ids else None
|
||||||
|
return message, event
|
||||||
|
|
||||||
|
def receive_message(self, message):
|
||||||
|
'''Call with an unframed message received from the network.
|
||||||
|
|
||||||
|
Raises: ProtocolError if the message violates the protocol in
|
||||||
|
some way. However, if it happened in a response that can be
|
||||||
|
paired with a request, the ProtocolError is instead set in the
|
||||||
|
result attribute of the send_request() that caused the error.
|
||||||
|
'''
|
||||||
|
try:
|
||||||
|
item, request_id = self._protocol.message_to_item(message)
|
||||||
|
except ProtocolError as e:
|
||||||
|
if e.response_msg_id is not id:
|
||||||
|
return self._receive_response(e, e.response_msg_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if isinstance(item, Request):
|
||||||
|
item.send_result = partial(self._send_result, request_id)
|
||||||
|
return [item]
|
||||||
|
if isinstance(item, Notification):
|
||||||
|
return [item]
|
||||||
|
if isinstance(item, Response):
|
||||||
|
return self._receive_response(item.result, request_id)
|
||||||
|
if isinstance(item, list):
|
||||||
|
if all(isinstance(payload, dict)
|
||||||
|
and ('result' in payload or 'error' in payload)
|
||||||
|
for payload in item):
|
||||||
|
return self._receive_response_batch(item)
|
||||||
|
else:
|
||||||
|
return self._receive_request_batch(item)
|
||||||
|
else:
|
||||||
|
# Protocol auto-detection hack
|
||||||
|
assert issubclass(item, JSONRPC)
|
||||||
|
self._protocol = item
|
||||||
|
return self.receive_message(message)
|
||||||
|
|
||||||
|
def cancel_pending_requests(self):
|
||||||
|
'''Cancel all pending requests.'''
|
||||||
|
exception = CancelledError()
|
||||||
|
for request, event in self._requests.values():
|
||||||
|
event.result = exception
|
||||||
|
event.set()
|
||||||
|
self._requests.clear()
|
||||||
|
|
||||||
|
def pending_requests(self):
|
||||||
|
'''All sent requests that have not received a response.'''
|
||||||
|
return [request for request, event in self._requests.values()]
|
||||||
|
|
||||||
|
|
||||||
|
def handler_invocation(handler, request):
|
||||||
|
method, args = request.method, request.args
|
||||||
|
if handler is None:
|
||||||
|
raise RPCError(JSONRPC.METHOD_NOT_FOUND,
|
||||||
|
f'unknown method "{method}"')
|
||||||
|
|
||||||
|
# We must test for too few and too many arguments. How
|
||||||
|
# depends on whether the arguments were passed as a list or as
|
||||||
|
# a dictionary.
|
||||||
|
info = signature_info(handler)
|
||||||
|
if isinstance(args, (tuple, list)):
|
||||||
|
if len(args) < info.min_args:
|
||||||
|
s = '' if len(args) == 1 else 's'
|
||||||
|
raise RPCError.invalid_args(
|
||||||
|
f'{len(args)} argument{s} passed to method '
|
||||||
|
f'"{method}" but it requires {info.min_args}')
|
||||||
|
if info.max_args is not None and len(args) > info.max_args:
|
||||||
|
s = '' if len(args) == 1 else 's'
|
||||||
|
raise RPCError.invalid_args(
|
||||||
|
f'{len(args)} argument{s} passed to method '
|
||||||
|
f'{method} taking at most {info.max_args}')
|
||||||
|
return partial(handler, *args)
|
||||||
|
|
||||||
|
# Arguments passed by name
|
||||||
|
if info.other_names is None:
|
||||||
|
raise RPCError.invalid_args(f'method "{method}" cannot '
|
||||||
|
f'be called with named arguments')
|
||||||
|
|
||||||
|
missing = set(info.required_names).difference(args)
|
||||||
|
if missing:
|
||||||
|
s = '' if len(missing) == 1 else 's'
|
||||||
|
missing = ', '.join(sorted(f'"{name}"' for name in missing))
|
||||||
|
raise RPCError.invalid_args(f'method "{method}" requires '
|
||||||
|
f'parameter{s} {missing}')
|
||||||
|
|
||||||
|
if info.other_names is not any:
|
||||||
|
excess = set(args).difference(info.required_names)
|
||||||
|
excess = excess.difference(info.other_names)
|
||||||
|
if excess:
|
||||||
|
s = '' if len(excess) == 1 else 's'
|
||||||
|
excess = ', '.join(sorted(f'"{name}"' for name in excess))
|
||||||
|
raise RPCError.invalid_args(f'method "{method}" does not '
|
||||||
|
f'take parameter{s} {excess}')
|
||||||
|
return partial(handler, **args)
|
549
torba/rpc/session.py
Normal file
549
torba/rpc/session.py
Normal file
|
@ -0,0 +1,549 @@
|
||||||
|
# Copyright (c) 2018, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ('Connector', 'RPCSession', 'MessageSession', 'Server',
|
||||||
|
'BatchError')
|
||||||
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from contextlib import suppress
|
||||||
|
|
||||||
|
from . import *
|
||||||
|
from .util import Concurrency
|
||||||
|
|
||||||
|
|
||||||
|
class Connector(object):
|
||||||
|
|
||||||
|
def __init__(self, session_factory, host=None, port=None, proxy=None,
|
||||||
|
**kwargs):
|
||||||
|
self.session_factory = session_factory
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.proxy = proxy
|
||||||
|
self.loop = kwargs.get('loop', asyncio.get_event_loop())
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
async def create_connection(self):
|
||||||
|
'''Initiate a connection.'''
|
||||||
|
connector = self.proxy or self.loop
|
||||||
|
return await connector.create_connection(
|
||||||
|
self.session_factory, self.host, self.port, **self.kwargs)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
transport, self.protocol = await self.create_connection()
|
||||||
|
# By default, do not limit outgoing connections
|
||||||
|
self.protocol.bw_limit = 0
|
||||||
|
return self.protocol
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
|
await self.protocol.close()
|
||||||
|
|
||||||
|
|
||||||
|
class SessionBase(asyncio.Protocol):
|
||||||
|
'''Base class of networking sessions.
|
||||||
|
|
||||||
|
There is no client / server distinction other than who initiated
|
||||||
|
the connection.
|
||||||
|
|
||||||
|
To initiate a connection to a remote server pass host, port and
|
||||||
|
proxy to the constructor, and then call create_connection(). Each
|
||||||
|
successful call should have a corresponding call to close().
|
||||||
|
|
||||||
|
Alternatively if used in a with statement, the connection is made
|
||||||
|
on entry to the block, and closed on exit from the block.
|
||||||
|
'''
|
||||||
|
|
||||||
|
max_errors = 10
|
||||||
|
|
||||||
|
def __init__(self, *, framer=None, loop=None):
|
||||||
|
self.framer = framer or self.default_framer()
|
||||||
|
self.loop = loop or asyncio.get_event_loop()
|
||||||
|
self.logger = logging.getLogger(self.__class__.__name__)
|
||||||
|
self.transport = None
|
||||||
|
# Set when a connection is made
|
||||||
|
self._address = None
|
||||||
|
self._proxy_address = None
|
||||||
|
# For logger.debug messsages
|
||||||
|
self.verbosity = 0
|
||||||
|
# Cleared when the send socket is full
|
||||||
|
self._can_send = Event()
|
||||||
|
self._can_send.set()
|
||||||
|
self._pm_task = None
|
||||||
|
self._task_group = TaskGroup()
|
||||||
|
# Force-close a connection if a send doesn't succeed in this time
|
||||||
|
self.max_send_delay = 60
|
||||||
|
# Statistics. The RPC object also keeps its own statistics.
|
||||||
|
self.start_time = time.time()
|
||||||
|
self.errors = 0
|
||||||
|
self.send_count = 0
|
||||||
|
self.send_size = 0
|
||||||
|
self.last_send = self.start_time
|
||||||
|
self.recv_count = 0
|
||||||
|
self.recv_size = 0
|
||||||
|
self.last_recv = self.start_time
|
||||||
|
# Bandwidth usage per hour before throttling starts
|
||||||
|
self.bw_limit = 2000000
|
||||||
|
self.bw_time = self.start_time
|
||||||
|
self.bw_charge = 0
|
||||||
|
# Concurrency control
|
||||||
|
self.max_concurrent = 6
|
||||||
|
self._concurrency = Concurrency(self.max_concurrent)
|
||||||
|
|
||||||
|
async def _update_concurrency(self):
|
||||||
|
# A non-positive value means not to limit concurrency
|
||||||
|
if self.bw_limit <= 0:
|
||||||
|
return
|
||||||
|
now = time.time()
|
||||||
|
# Reduce the recorded usage in proportion to the elapsed time
|
||||||
|
refund = (now - self.bw_time) * (self.bw_limit / 3600)
|
||||||
|
self.bw_charge = max(0, self.bw_charge - int(refund))
|
||||||
|
self.bw_time = now
|
||||||
|
# Reduce concurrency allocation by 1 for each whole bw_limit used
|
||||||
|
throttle = int(self.bw_charge / self.bw_limit)
|
||||||
|
target = max(1, self.max_concurrent - throttle)
|
||||||
|
current = self._concurrency.max_concurrent
|
||||||
|
if target != current:
|
||||||
|
self.logger.info(f'changing task concurrency from {current} '
|
||||||
|
f'to {target}')
|
||||||
|
await self._concurrency.set_max_concurrent(target)
|
||||||
|
|
||||||
|
def _using_bandwidth(self, size):
|
||||||
|
'''Called when sending or receiving size bytes.'''
|
||||||
|
self.bw_charge += size
|
||||||
|
|
||||||
|
async def _process_messages(self):
|
||||||
|
'''Process incoming messages asynchronously and consume the
|
||||||
|
results.
|
||||||
|
'''
|
||||||
|
async def collect_tasks():
|
||||||
|
next_done = task_group.next_done
|
||||||
|
while True:
|
||||||
|
await next_done()
|
||||||
|
|
||||||
|
task_group = self._task_group
|
||||||
|
async with task_group:
|
||||||
|
await self.spawn(self._receive_messages)
|
||||||
|
await self.spawn(collect_tasks)
|
||||||
|
|
||||||
|
async def _limited_wait(self, secs):
|
||||||
|
# Wait at most secs seconds to send, otherwise abort the connection
|
||||||
|
try:
|
||||||
|
async with timeout_after(secs):
|
||||||
|
await self._can_send.wait()
|
||||||
|
except TaskTimeout:
|
||||||
|
self.abort()
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _send_message(self, message):
|
||||||
|
if not self._can_send.is_set():
|
||||||
|
await self._limited_wait(self.max_send_delay)
|
||||||
|
if not self.is_closing():
|
||||||
|
framed_message = self.framer.frame(message)
|
||||||
|
self.send_size += len(framed_message)
|
||||||
|
self._using_bandwidth(len(framed_message))
|
||||||
|
self.send_count += 1
|
||||||
|
self.last_send = time.time()
|
||||||
|
if self.verbosity >= 4:
|
||||||
|
self.logger.debug(f'Sending framed message {framed_message}')
|
||||||
|
self.transport.write(framed_message)
|
||||||
|
|
||||||
|
def _bump_errors(self):
|
||||||
|
self.errors += 1
|
||||||
|
if self.errors >= self.max_errors:
|
||||||
|
# Don't await self.close() because that is self-cancelling
|
||||||
|
self._close()
|
||||||
|
|
||||||
|
def _close(self):
|
||||||
|
if self.transport:
|
||||||
|
self.transport.close()
|
||||||
|
|
||||||
|
# asyncio framework
|
||||||
|
def data_received(self, framed_message):
|
||||||
|
'''Called by asyncio when a message comes in.'''
|
||||||
|
if self.verbosity >= 4:
|
||||||
|
self.logger.debug(f'Received framed message {framed_message}')
|
||||||
|
self.recv_size += len(framed_message)
|
||||||
|
self._using_bandwidth(len(framed_message))
|
||||||
|
self.framer.received_bytes(framed_message)
|
||||||
|
|
||||||
|
def pause_writing(self):
|
||||||
|
'''Transport calls when the send buffer is full.'''
|
||||||
|
if not self.is_closing():
|
||||||
|
self._can_send.clear()
|
||||||
|
self.transport.pause_reading()
|
||||||
|
|
||||||
|
def resume_writing(self):
|
||||||
|
'''Transport calls when the send buffer has room.'''
|
||||||
|
if not self._can_send.is_set():
|
||||||
|
self._can_send.set()
|
||||||
|
self.transport.resume_reading()
|
||||||
|
|
||||||
|
def connection_made(self, transport):
|
||||||
|
'''Called by asyncio when a connection is established.
|
||||||
|
|
||||||
|
Derived classes overriding this method must call this first.'''
|
||||||
|
self.transport = transport
|
||||||
|
# This would throw if called on a closed SSL transport. Fixed
|
||||||
|
# in asyncio in Python 3.6.1 and 3.5.4
|
||||||
|
peer_address = transport.get_extra_info('peername')
|
||||||
|
# If the Socks proxy was used then _address is already set to
|
||||||
|
# the remote address
|
||||||
|
if self._address:
|
||||||
|
self._proxy_address = peer_address
|
||||||
|
else:
|
||||||
|
self._address = peer_address
|
||||||
|
self._pm_task = spawn_sync(self._process_messages(), loop=self.loop)
|
||||||
|
|
||||||
|
def connection_lost(self, exc):
|
||||||
|
'''Called by asyncio when the connection closes.
|
||||||
|
|
||||||
|
Tear down things done in connection_made.'''
|
||||||
|
self._address = None
|
||||||
|
self.transport = None
|
||||||
|
self._pm_task.cancel()
|
||||||
|
# Release waiting tasks
|
||||||
|
self._can_send.set()
|
||||||
|
|
||||||
|
# External API
|
||||||
|
def default_framer(self):
|
||||||
|
'''Return a default framer.'''
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def peer_address(self):
|
||||||
|
'''Returns the peer's address (Python networking address), or None if
|
||||||
|
no connection or an error.
|
||||||
|
|
||||||
|
This is the result of socket.getpeername() when the connection
|
||||||
|
was made.
|
||||||
|
'''
|
||||||
|
return self._address
|
||||||
|
|
||||||
|
def peer_address_str(self):
|
||||||
|
'''Returns the peer's IP address and port as a human-readable
|
||||||
|
string.'''
|
||||||
|
if not self._address:
|
||||||
|
return 'unknown'
|
||||||
|
ip_addr_str, port = self._address[:2]
|
||||||
|
if ':' in ip_addr_str:
|
||||||
|
return f'[{ip_addr_str}]:{port}'
|
||||||
|
else:
|
||||||
|
return f'{ip_addr_str}:{port}'
|
||||||
|
|
||||||
|
async def spawn(self, coro, *args):
|
||||||
|
'''If the session is connected, spawn a task that is cancelled
|
||||||
|
on disconnect, and return it. Otherwise return None.'''
|
||||||
|
group = self._task_group
|
||||||
|
if not group.closed():
|
||||||
|
return await group.spawn(coro, *args)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def is_closing(self):
|
||||||
|
'''Return True if the connection is closing.'''
|
||||||
|
return not self.transport or self.transport.is_closing()
|
||||||
|
|
||||||
|
def abort(self):
|
||||||
|
'''Forcefully close the connection.'''
|
||||||
|
if self.transport:
|
||||||
|
self.transport.abort()
|
||||||
|
|
||||||
|
async def close(self, *, force_after=30):
|
||||||
|
'''Close the connection and return when closed.'''
|
||||||
|
self._close()
|
||||||
|
if self._pm_task:
|
||||||
|
with suppress(CancelledError):
|
||||||
|
async with ignore_after(force_after):
|
||||||
|
await self._pm_task
|
||||||
|
self.abort()
|
||||||
|
await self._pm_task
|
||||||
|
|
||||||
|
|
||||||
|
class MessageSession(SessionBase):
|
||||||
|
'''Session class for protocols where messages are not tied to responses,
|
||||||
|
such as the Bitcoin protocol.
|
||||||
|
|
||||||
|
To use as a client (connection-opening) session, pass host, port
|
||||||
|
and perhaps a proxy.
|
||||||
|
'''
|
||||||
|
async def _receive_messages(self):
|
||||||
|
while not self.is_closing():
|
||||||
|
try:
|
||||||
|
message = await self.framer.receive_message()
|
||||||
|
except BadMagicError as e:
|
||||||
|
magic, expected = e.args
|
||||||
|
self.logger.error(
|
||||||
|
f'bad network magic: got {magic} expected {expected}, '
|
||||||
|
f'disconnecting'
|
||||||
|
)
|
||||||
|
self._close()
|
||||||
|
except OversizedPayloadError as e:
|
||||||
|
command, payload_len = e.args
|
||||||
|
self.logger.error(
|
||||||
|
f'oversized payload of {payload_len:,d} bytes to command '
|
||||||
|
f'{command}, disconnecting'
|
||||||
|
)
|
||||||
|
self._close()
|
||||||
|
except BadChecksumError as e:
|
||||||
|
payload_checksum, claimed_checksum = e.args
|
||||||
|
self.logger.warning(
|
||||||
|
f'checksum mismatch: actual {payload_checksum.hex()} '
|
||||||
|
f'vs claimed {claimed_checksum.hex()}'
|
||||||
|
)
|
||||||
|
self._bump_errors()
|
||||||
|
else:
|
||||||
|
self.last_recv = time.time()
|
||||||
|
self.recv_count += 1
|
||||||
|
if self.recv_count % 10 == 0:
|
||||||
|
await self._update_concurrency()
|
||||||
|
await self.spawn(self._throttled_message(message))
|
||||||
|
|
||||||
|
async def _throttled_message(self, message):
|
||||||
|
'''Process a single request, respecting the concurrency limit.'''
|
||||||
|
async with self._concurrency.semaphore:
|
||||||
|
try:
|
||||||
|
await self.handle_message(message)
|
||||||
|
except ProtocolError as e:
|
||||||
|
self.logger.error(f'{e}')
|
||||||
|
self._bump_errors()
|
||||||
|
except CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
self.logger.exception(f'exception handling {message}')
|
||||||
|
self._bump_errors()
|
||||||
|
|
||||||
|
# External API
|
||||||
|
def default_framer(self):
|
||||||
|
'''Return a bitcoin framer.'''
|
||||||
|
return BitcoinFramer(bytes.fromhex('e3e1f3e8'), 128_000_000)
|
||||||
|
|
||||||
|
async def handle_message(self, message):
|
||||||
|
'''message is a (command, payload) pair.'''
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def send_message(self, message):
|
||||||
|
'''Send a message (command, payload) over the network.'''
|
||||||
|
await self._send_message(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchError(Exception):
|
||||||
|
|
||||||
|
def __init__(self, request):
|
||||||
|
self.request = request # BatchRequest object
|
||||||
|
|
||||||
|
|
||||||
|
class BatchRequest(object):
|
||||||
|
'''Used to build a batch request to send to the server. Stores
|
||||||
|
the
|
||||||
|
|
||||||
|
Attributes batch and results are initially None.
|
||||||
|
|
||||||
|
Adding an invalid request or notification immediately raises a
|
||||||
|
ProtocolError.
|
||||||
|
|
||||||
|
On exiting the with clause, it will:
|
||||||
|
|
||||||
|
1) create a Batch object for the requests in the order they were
|
||||||
|
added. If the batch is empty this raises a ProtocolError.
|
||||||
|
|
||||||
|
2) set the "batch" attribute to be that batch
|
||||||
|
|
||||||
|
3) send the batch request and wait for a response
|
||||||
|
|
||||||
|
4) raise a ProtocolError if the protocol was violated by the
|
||||||
|
server. Currently this only happens if it gave more than one
|
||||||
|
response to any request
|
||||||
|
|
||||||
|
5) otherwise there is precisely one response to each Request. Set
|
||||||
|
the "results" attribute to the tuple of results; the responses
|
||||||
|
are ordered to match the Requests in the batch. Notifications
|
||||||
|
do not get a response.
|
||||||
|
|
||||||
|
6) if raise_errors is True and any individual response was a JSON
|
||||||
|
RPC error response, or violated the protocol in some way, a
|
||||||
|
BatchError exception is raised. Otherwise the caller can be
|
||||||
|
certain each request returned a standard result.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, session, raise_errors):
|
||||||
|
self._session = session
|
||||||
|
self._raise_errors = raise_errors
|
||||||
|
self._requests = []
|
||||||
|
self.batch = None
|
||||||
|
self.results = None
|
||||||
|
|
||||||
|
def add_request(self, method, args=()):
|
||||||
|
self._requests.append(Request(method, args))
|
||||||
|
|
||||||
|
def add_notification(self, method, args=()):
|
||||||
|
self._requests.append(Notification(method, args))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._requests)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
|
if exc_type is None:
|
||||||
|
self.batch = Batch(self._requests)
|
||||||
|
message, event = self._session.connection.send_batch(self.batch)
|
||||||
|
await self._session._send_message(message)
|
||||||
|
await event.wait()
|
||||||
|
self.results = event.result
|
||||||
|
if self._raise_errors:
|
||||||
|
if any(isinstance(item, Exception) for item in event.result):
|
||||||
|
raise BatchError(self)
|
||||||
|
|
||||||
|
|
||||||
|
class RPCSession(SessionBase):
|
||||||
|
'''Base class for protocols where a message can lead to a response,
|
||||||
|
for example JSON RPC.'''
|
||||||
|
|
||||||
|
def __init__(self, *, framer=None, loop=None, connection=None):
|
||||||
|
super().__init__(framer=framer, loop=loop)
|
||||||
|
self.connection = connection or self.default_connection()
|
||||||
|
|
||||||
|
async def _receive_messages(self):
|
||||||
|
while not self.is_closing():
|
||||||
|
try:
|
||||||
|
message = await self.framer.receive_message()
|
||||||
|
except MemoryError as e:
|
||||||
|
self.logger.warning(f'{e!r}')
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.last_recv = time.time()
|
||||||
|
self.recv_count += 1
|
||||||
|
if self.recv_count % 10 == 0:
|
||||||
|
await self._update_concurrency()
|
||||||
|
|
||||||
|
try:
|
||||||
|
requests = self.connection.receive_message(message)
|
||||||
|
except ProtocolError as e:
|
||||||
|
self.logger.debug(f'{e}')
|
||||||
|
if e.error_message:
|
||||||
|
await self._send_message(e.error_message)
|
||||||
|
if e.code == JSONRPC.PARSE_ERROR:
|
||||||
|
self.max_errors = 0
|
||||||
|
self._bump_errors()
|
||||||
|
else:
|
||||||
|
for request in requests:
|
||||||
|
await self.spawn(self._throttled_request(request))
|
||||||
|
|
||||||
|
async def _throttled_request(self, request):
|
||||||
|
'''Process a single request, respecting the concurrency limit.'''
|
||||||
|
async with self._concurrency.semaphore:
|
||||||
|
try:
|
||||||
|
result = await self.handle_request(request)
|
||||||
|
except (ProtocolError, RPCError) as e:
|
||||||
|
result = e
|
||||||
|
except CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
self.logger.exception(f'exception handling {request}')
|
||||||
|
result = RPCError(JSONRPC.INTERNAL_ERROR,
|
||||||
|
'internal server error')
|
||||||
|
if isinstance(request, Request):
|
||||||
|
message = request.send_result(result)
|
||||||
|
if message:
|
||||||
|
await self._send_message(message)
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
self._bump_errors()
|
||||||
|
|
||||||
|
def connection_lost(self, exc):
|
||||||
|
# Cancel pending requests and message processing
|
||||||
|
self.connection.cancel_pending_requests()
|
||||||
|
super().connection_lost(exc)
|
||||||
|
|
||||||
|
# External API
|
||||||
|
def default_connection(self):
|
||||||
|
'''Return a default connection if the user provides none.'''
|
||||||
|
return JSONRPCConnection(JSONRPCv2)
|
||||||
|
|
||||||
|
def default_framer(self):
|
||||||
|
'''Return a default framer.'''
|
||||||
|
return NewlineFramer()
|
||||||
|
|
||||||
|
async def handle_request(self, request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def send_request(self, method, args=()):
|
||||||
|
'''Send an RPC request over the network.'''
|
||||||
|
message, event = self.connection.send_request(Request(method, args))
|
||||||
|
await self._send_message(message)
|
||||||
|
await event.wait()
|
||||||
|
result = event.result
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
raise result
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def send_notification(self, method, args=()):
|
||||||
|
'''Send an RPC notification over the network.'''
|
||||||
|
message = self.connection.send_notification(Notification(method, args))
|
||||||
|
await self._send_message(message)
|
||||||
|
|
||||||
|
def send_batch(self, raise_errors=False):
|
||||||
|
'''Return a BatchRequest. Intended to be used like so:
|
||||||
|
|
||||||
|
async with session.send_batch() as batch:
|
||||||
|
batch.add_request("method1")
|
||||||
|
batch.add_request("sum", (x, y))
|
||||||
|
batch.add_notification("updated")
|
||||||
|
|
||||||
|
for result in batch.results:
|
||||||
|
...
|
||||||
|
|
||||||
|
Note that in some circumstances exceptions can be raised; see
|
||||||
|
BatchRequest doc string.
|
||||||
|
'''
|
||||||
|
return BatchRequest(self, raise_errors)
|
||||||
|
|
||||||
|
|
||||||
|
class Server(object):
|
||||||
|
'''A simple wrapper around an asyncio.Server object.'''
|
||||||
|
|
||||||
|
def __init__(self, session_factory, host=None, port=None, *,
|
||||||
|
loop=None, **kwargs):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.loop = loop or asyncio.get_event_loop()
|
||||||
|
self.server = None
|
||||||
|
self._session_factory = session_factory
|
||||||
|
self._kwargs = kwargs
|
||||||
|
|
||||||
|
async def listen(self):
|
||||||
|
self.server = await self.loop.create_server(
|
||||||
|
self._session_factory, self.host, self.port, **self._kwargs)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
'''Close the listening socket. This does not close any ServerSession
|
||||||
|
objects created to handle incoming connections.
|
||||||
|
'''
|
||||||
|
if self.server:
|
||||||
|
self.server.close()
|
||||||
|
await self.server.wait_closed()
|
||||||
|
self.server = None
|
439
torba/rpc/socks.py
Normal file
439
torba/rpc/socks.py
Normal file
|
@ -0,0 +1,439 @@
|
||||||
|
# Copyright (c) 2018, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
'''SOCKS proxying.'''
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
import collections
|
||||||
|
import ipaddress
|
||||||
|
import socket
|
||||||
|
import struct
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ('SOCKSUserAuth', 'SOCKS4', 'SOCKS4a', 'SOCKS5', 'SOCKSProxy',
|
||||||
|
'SOCKSError', 'SOCKSProtocolError', 'SOCKSFailure')
|
||||||
|
|
||||||
|
|
||||||
|
SOCKSUserAuth = collections.namedtuple("SOCKSUserAuth", "username password")
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKSError(Exception):
|
||||||
|
'''Base class for SOCKS exceptions. Each raised exception will be
|
||||||
|
an instance of a derived class.'''
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKSProtocolError(SOCKSError):
|
||||||
|
'''Raised when the proxy does not follow the SOCKS protocol'''
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKSFailure(SOCKSError):
|
||||||
|
'''Raised when the proxy refuses or fails to make a connection'''
|
||||||
|
|
||||||
|
|
||||||
|
class NeedData(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKSBase(object):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def name(cls):
|
||||||
|
return cls.__name__
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._buffer = bytes()
|
||||||
|
self._state = self._start
|
||||||
|
|
||||||
|
def _read(self, size):
|
||||||
|
if len(self._buffer) < size:
|
||||||
|
raise NeedData(size - len(self._buffer))
|
||||||
|
result = self._buffer[:size]
|
||||||
|
self._buffer = self._buffer[size:]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def receive_data(self, data):
|
||||||
|
self._buffer += data
|
||||||
|
|
||||||
|
def next_message(self):
|
||||||
|
return self._state()
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKS4(SOCKSBase):
|
||||||
|
'''SOCKS4 protocol wrapper.'''
|
||||||
|
|
||||||
|
# See http://ftp.icm.edu.pl/packages/socks/socks4/SOCKS4.protocol
|
||||||
|
REPLY_CODES = {
|
||||||
|
90: 'request granted',
|
||||||
|
91: 'request rejected or failed',
|
||||||
|
92: ('request rejected because SOCKS server cannot connect '
|
||||||
|
'to identd on the client'),
|
||||||
|
93: ('request rejected because the client program and identd '
|
||||||
|
'report different user-ids')
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, dst_host, dst_port, auth):
|
||||||
|
super().__init__()
|
||||||
|
self._dst_host = self._check_host(dst_host)
|
||||||
|
self._dst_port = dst_port
|
||||||
|
self._auth = auth
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _check_host(cls, host):
|
||||||
|
if not isinstance(host, ipaddress.IPv4Address):
|
||||||
|
try:
|
||||||
|
host = ipaddress.IPv4Address(host)
|
||||||
|
except ValueError:
|
||||||
|
raise SOCKSProtocolError(
|
||||||
|
f'SOCKS4 requires an IPv4 address: {host}') from None
|
||||||
|
return host
|
||||||
|
|
||||||
|
def _start(self):
|
||||||
|
self._state = self._first_response
|
||||||
|
|
||||||
|
if isinstance(self._dst_host, ipaddress.IPv4Address):
|
||||||
|
# SOCKS4
|
||||||
|
dst_ip_packed = self._dst_host.packed
|
||||||
|
host_bytes = b''
|
||||||
|
else:
|
||||||
|
# SOCKS4a
|
||||||
|
dst_ip_packed = b'\0\0\0\1'
|
||||||
|
host_bytes = self._dst_host.encode() + b'\0'
|
||||||
|
|
||||||
|
if isinstance(self._auth, SOCKSUserAuth):
|
||||||
|
user_id = self._auth.username.encode()
|
||||||
|
else:
|
||||||
|
user_id = b''
|
||||||
|
|
||||||
|
# Send TCP/IP stream CONNECT request
|
||||||
|
return b''.join([b'\4\1', struct.pack('>H', self._dst_port),
|
||||||
|
dst_ip_packed, user_id, b'\0', host_bytes])
|
||||||
|
|
||||||
|
def _first_response(self):
|
||||||
|
# Wait for 8-byte response
|
||||||
|
data = self._read(8)
|
||||||
|
if data[0] != 0:
|
||||||
|
raise SOCKSProtocolError(f'invalid {self.name()} proxy '
|
||||||
|
f'response: {data}')
|
||||||
|
reply_code = data[1]
|
||||||
|
if reply_code != 90:
|
||||||
|
msg = self.REPLY_CODES.get(
|
||||||
|
reply_code, f'unknown {self.name()} reply code {reply_code}')
|
||||||
|
raise SOCKSFailure(f'{self.name()} proxy request failed: {msg}')
|
||||||
|
|
||||||
|
# Other fields ignored
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKS4a(SOCKS4):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _check_host(cls, host):
|
||||||
|
if not isinstance(host, (str, ipaddress.IPv4Address)):
|
||||||
|
raise SOCKSProtocolError(
|
||||||
|
f'SOCKS4a requires an IPv4 address or host name: {host}')
|
||||||
|
return host
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKS5(SOCKSBase):
|
||||||
|
'''SOCKS protocol wrapper.'''
|
||||||
|
|
||||||
|
# See https://tools.ietf.org/html/rfc1928
|
||||||
|
ERROR_CODES = {
|
||||||
|
1: 'general SOCKS server failure',
|
||||||
|
2: 'connection not allowed by ruleset',
|
||||||
|
3: 'network unreachable',
|
||||||
|
4: 'host unreachable',
|
||||||
|
5: 'connection refused',
|
||||||
|
6: 'TTL expired',
|
||||||
|
7: 'command not supported',
|
||||||
|
8: 'address type not supported',
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, dst_host, dst_port, auth):
|
||||||
|
super().__init__()
|
||||||
|
self._dst_bytes = self._destination_bytes(dst_host, dst_port)
|
||||||
|
self._auth_bytes, self._auth_methods = self._authentication(auth)
|
||||||
|
|
||||||
|
def _destination_bytes(self, host, port):
|
||||||
|
if isinstance(host, ipaddress.IPv4Address):
|
||||||
|
addr_bytes = b'\1' + host.packed
|
||||||
|
elif isinstance(host, ipaddress.IPv6Address):
|
||||||
|
addr_bytes = b'\4' + host.packed
|
||||||
|
elif isinstance(host, str):
|
||||||
|
host = host.encode()
|
||||||
|
if len(host) > 255:
|
||||||
|
raise SOCKSProtocolError(f'hostname too long: '
|
||||||
|
f'{len(host)} bytes')
|
||||||
|
addr_bytes = b'\3' + bytes([len(host)]) + host
|
||||||
|
else:
|
||||||
|
raise SOCKSProtocolError(f'SOCKS5 requires an IPv4 address, IPv6 '
|
||||||
|
f'address, or host name: {host}')
|
||||||
|
return addr_bytes + struct.pack('>H', port)
|
||||||
|
|
||||||
|
def _authentication(self, auth):
|
||||||
|
if isinstance(auth, SOCKSUserAuth):
|
||||||
|
user_bytes = auth.username.encode()
|
||||||
|
if not 0 < len(user_bytes) < 256:
|
||||||
|
raise SOCKSProtocolError(f'username {auth.username} has '
|
||||||
|
f'invalid length {len(user_bytes)}')
|
||||||
|
pwd_bytes = auth.password.encode()
|
||||||
|
if not 0 < len(pwd_bytes) < 256:
|
||||||
|
raise SOCKSProtocolError(f'password has invalid length '
|
||||||
|
f'{len(pwd_bytes)}')
|
||||||
|
return b''.join([bytes([1, len(user_bytes)]), user_bytes,
|
||||||
|
bytes([len(pwd_bytes)]), pwd_bytes]), [0, 2]
|
||||||
|
return b'', [0]
|
||||||
|
|
||||||
|
def _start(self):
|
||||||
|
self._state = self._first_response
|
||||||
|
return (b'\5' + bytes([len(self._auth_methods)])
|
||||||
|
+ bytes(m for m in self._auth_methods))
|
||||||
|
|
||||||
|
def _first_response(self):
|
||||||
|
# Wait for 2-byte response
|
||||||
|
data = self._read(2)
|
||||||
|
if data[0] != 5:
|
||||||
|
raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}')
|
||||||
|
if data[1] not in self._auth_methods:
|
||||||
|
raise SOCKSFailure('SOCKS5 proxy rejected authentication methods')
|
||||||
|
|
||||||
|
# Authenticate if user-password authentication
|
||||||
|
if data[1] == 2:
|
||||||
|
self._state = self._auth_response
|
||||||
|
return self._auth_bytes
|
||||||
|
return self._request_connection()
|
||||||
|
|
||||||
|
def _auth_response(self):
|
||||||
|
data = self._read(2)
|
||||||
|
if data[0] != 1:
|
||||||
|
raise SOCKSProtocolError(f'invalid SOCKS5 proxy auth '
|
||||||
|
f'response: {data}')
|
||||||
|
if data[1] != 0:
|
||||||
|
raise SOCKSFailure(f'SOCKS5 proxy auth failure code: '
|
||||||
|
f'{data[1]}')
|
||||||
|
|
||||||
|
return self._request_connection()
|
||||||
|
|
||||||
|
def _request_connection(self):
|
||||||
|
# Send connection request
|
||||||
|
self._state = self._connect_response
|
||||||
|
return b'\5\1\0' + self._dst_bytes
|
||||||
|
|
||||||
|
def _connect_response(self):
|
||||||
|
data = self._read(5)
|
||||||
|
if data[0] != 5 or data[2] != 0 or data[3] not in (1, 3, 4):
|
||||||
|
raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}')
|
||||||
|
if data[1] != 0:
|
||||||
|
raise SOCKSFailure(self.ERROR_CODES.get(
|
||||||
|
data[1], f'unknown SOCKS5 error code: {data[1]}'))
|
||||||
|
|
||||||
|
if data[3] == 1:
|
||||||
|
addr_len = 3 # IPv4
|
||||||
|
elif data[3] == 3:
|
||||||
|
addr_len = data[4] # Hostname
|
||||||
|
else:
|
||||||
|
addr_len = 15 # IPv6
|
||||||
|
|
||||||
|
self._state = partial(self._connect_response_rest, addr_len)
|
||||||
|
return self.next_message()
|
||||||
|
|
||||||
|
def _connect_response_rest(self, addr_len):
|
||||||
|
self._read(addr_len + 2)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKSProxy(object):
|
||||||
|
|
||||||
|
def __init__(self, address, protocol, auth):
|
||||||
|
'''A SOCKS proxy at an address following a SOCKS protocol. auth is an
|
||||||
|
authentication method to use when connecting, or None.
|
||||||
|
|
||||||
|
address is a (host, port) pair; for IPv6 it can instead be a
|
||||||
|
(host, port, flowinfo, scopeid) 4-tuple.
|
||||||
|
'''
|
||||||
|
self.address = address
|
||||||
|
self.protocol = protocol
|
||||||
|
self.auth = auth
|
||||||
|
# Set on each successful connection via the proxy to the
|
||||||
|
# result of socket.getpeername()
|
||||||
|
self.peername = None
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
auth = 'username' if self.auth else 'none'
|
||||||
|
return f'{self.protocol.name()} proxy at {self.address}, auth: {auth}'
|
||||||
|
|
||||||
|
async def _handshake(self, client, sock, loop):
|
||||||
|
while True:
|
||||||
|
count = 0
|
||||||
|
try:
|
||||||
|
message = client.next_message()
|
||||||
|
except NeedData as e:
|
||||||
|
count = e.args[0]
|
||||||
|
else:
|
||||||
|
if message is None:
|
||||||
|
return
|
||||||
|
await loop.sock_sendall(sock, message)
|
||||||
|
|
||||||
|
if count:
|
||||||
|
data = await loop.sock_recv(sock, count)
|
||||||
|
if not data:
|
||||||
|
raise SOCKSProtocolError("EOF received")
|
||||||
|
client.receive_data(data)
|
||||||
|
|
||||||
|
async def _connect_one(self, host, port):
|
||||||
|
'''Connect to the proxy and perform a handshake requesting a
|
||||||
|
connection to (host, port).
|
||||||
|
|
||||||
|
Return the open socket on success, or the exception on failure.
|
||||||
|
'''
|
||||||
|
client = self.protocol(host, port, self.auth)
|
||||||
|
sock = socket.socket()
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
# A non-blocking socket is required by loop socket methods
|
||||||
|
sock.setblocking(False)
|
||||||
|
await loop.sock_connect(sock, self.address)
|
||||||
|
await self._handshake(client, sock, loop)
|
||||||
|
self.peername = sock.getpeername()
|
||||||
|
return sock
|
||||||
|
except Exception as e:
|
||||||
|
# Don't close - see https://github.com/kyuupichan/aiorpcX/issues/8
|
||||||
|
if sys.platform.startswith('linux'):
|
||||||
|
sock.close()
|
||||||
|
return e
|
||||||
|
|
||||||
|
async def _connect(self, addresses):
|
||||||
|
'''Connect to the proxy and perform a handshake requesting a
|
||||||
|
connection to each address in addresses.
|
||||||
|
|
||||||
|
Return an (open_socket, address) pair on success.
|
||||||
|
'''
|
||||||
|
assert len(addresses) > 0
|
||||||
|
|
||||||
|
exceptions = []
|
||||||
|
for address in addresses:
|
||||||
|
host, port = address[:2]
|
||||||
|
sock = await self._connect_one(host, port)
|
||||||
|
if isinstance(sock, socket.socket):
|
||||||
|
return sock, address
|
||||||
|
exceptions.append(sock)
|
||||||
|
|
||||||
|
strings = set(f'{exc!r}' for exc in exceptions)
|
||||||
|
raise (exceptions[0] if len(strings) == 1 else
|
||||||
|
OSError(f'multiple exceptions: {", ".join(strings)}'))
|
||||||
|
|
||||||
|
async def _detect_proxy(self):
|
||||||
|
'''Return True if it appears we can connect to a SOCKS proxy,
|
||||||
|
otherwise False.
|
||||||
|
'''
|
||||||
|
if self.protocol is SOCKS4a:
|
||||||
|
host, port = 'www.apple.com', 80
|
||||||
|
else:
|
||||||
|
host, port = ipaddress.IPv4Address('8.8.8.8'), 53
|
||||||
|
|
||||||
|
sock = await self._connect_one(host, port)
|
||||||
|
if isinstance(sock, socket.socket):
|
||||||
|
sock.close()
|
||||||
|
return True
|
||||||
|
|
||||||
|
# SOCKSFailure indicates something failed, but that we are
|
||||||
|
# likely talking to a proxy
|
||||||
|
return isinstance(sock, SOCKSFailure)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def auto_detect_address(cls, address, auth):
|
||||||
|
'''Try to detect a SOCKS proxy at address using the authentication
|
||||||
|
method (or None). SOCKS5, SOCKS4a and SOCKS are tried in
|
||||||
|
order. If a SOCKS proxy is detected a SOCKSProxy object is
|
||||||
|
returned.
|
||||||
|
|
||||||
|
Returning a SOCKSProxy does not mean it is functioning - for
|
||||||
|
example, it may have no network connectivity.
|
||||||
|
|
||||||
|
If no proxy is detected return None.
|
||||||
|
'''
|
||||||
|
for protocol in (SOCKS5, SOCKS4a, SOCKS4):
|
||||||
|
proxy = cls(address, protocol, auth)
|
||||||
|
if await proxy._detect_proxy():
|
||||||
|
return proxy
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def auto_detect_host(cls, host, ports, auth):
|
||||||
|
'''Try to detect a SOCKS proxy on a host on one of the ports.
|
||||||
|
|
||||||
|
Calls auto_detect for the ports in order. Returns SOCKS are
|
||||||
|
tried in order; a SOCKSProxy object for the first detected
|
||||||
|
proxy is returned.
|
||||||
|
|
||||||
|
Returning a SOCKSProxy does not mean it is functioning - for
|
||||||
|
example, it may have no network connectivity.
|
||||||
|
|
||||||
|
If no proxy is detected return None.
|
||||||
|
'''
|
||||||
|
for port in ports:
|
||||||
|
address = (host, port)
|
||||||
|
proxy = await cls.auto_detect_address(address, auth)
|
||||||
|
if proxy:
|
||||||
|
return proxy
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def create_connection(self, protocol_factory, host, port, *,
|
||||||
|
resolve=False, ssl=None,
|
||||||
|
family=0, proto=0, flags=0):
|
||||||
|
'''Set up a connection to (host, port) through the proxy.
|
||||||
|
|
||||||
|
If resolve is True then host is resolved locally with
|
||||||
|
getaddrinfo using family, proto and flags, otherwise the proxy
|
||||||
|
is asked to resolve host.
|
||||||
|
|
||||||
|
The function signature is similar to loop.create_connection()
|
||||||
|
with the same result. The attribute _address is set on the
|
||||||
|
protocol to the address of the successful remote connection.
|
||||||
|
Additionally raises SOCKSError if something goes wrong with
|
||||||
|
the proxy handshake.
|
||||||
|
'''
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if resolve:
|
||||||
|
infos = await loop.getaddrinfo(host, port, family=family,
|
||||||
|
type=socket.SOCK_STREAM,
|
||||||
|
proto=proto, flags=flags)
|
||||||
|
addresses = [info[4] for info in infos]
|
||||||
|
else:
|
||||||
|
addresses = [(host, port)]
|
||||||
|
|
||||||
|
sock, address = await self._connect(addresses)
|
||||||
|
|
||||||
|
def set_address():
|
||||||
|
protocol = protocol_factory()
|
||||||
|
protocol._address = address
|
||||||
|
return protocol
|
||||||
|
|
||||||
|
return await loop.create_connection(
|
||||||
|
set_address, sock=sock, ssl=ssl,
|
||||||
|
server_hostname=host if ssl else None)
|
120
torba/rpc/util.py
Normal file
120
torba/rpc/util.py
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
# Copyright (c) 2018, Neil Booth
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# The MIT License (MIT)
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
# a copy of this software and associated documentation files (the
|
||||||
|
# "Software"), to deal in the Software without restriction, including
|
||||||
|
# without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
# permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
# the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
__all__ = ()
|
||||||
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections import namedtuple
|
||||||
|
from functools import partial
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_corofunc(corofunc, args):
|
||||||
|
if asyncio.iscoroutine(corofunc):
|
||||||
|
if args != ():
|
||||||
|
raise ValueError('args cannot be passed with a coroutine')
|
||||||
|
return corofunc
|
||||||
|
return corofunc(*args)
|
||||||
|
|
||||||
|
|
||||||
|
def is_async_call(func):
|
||||||
|
'''inspect.iscoroutinefunction that looks through partials.'''
|
||||||
|
while isinstance(func, partial):
|
||||||
|
func = func.func
|
||||||
|
return inspect.iscoroutinefunction(func)
|
||||||
|
|
||||||
|
|
||||||
|
# other_params: None means cannot be called with keyword arguments only
|
||||||
|
# any means any name is good
|
||||||
|
SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args '
|
||||||
|
'required_names other_names')
|
||||||
|
|
||||||
|
|
||||||
|
def signature_info(func):
|
||||||
|
params = inspect.signature(func).parameters
|
||||||
|
min_args = max_args = 0
|
||||||
|
required_names = []
|
||||||
|
other_names = []
|
||||||
|
no_names = False
|
||||||
|
for p in params.values():
|
||||||
|
if p.kind == p.POSITIONAL_OR_KEYWORD:
|
||||||
|
max_args += 1
|
||||||
|
if p.default is p.empty:
|
||||||
|
min_args += 1
|
||||||
|
required_names.append(p.name)
|
||||||
|
else:
|
||||||
|
other_names.append(p.name)
|
||||||
|
elif p.kind == p.KEYWORD_ONLY:
|
||||||
|
other_names.append(p.name)
|
||||||
|
elif p.kind == p.VAR_POSITIONAL:
|
||||||
|
max_args = None
|
||||||
|
elif p.kind == p.VAR_KEYWORD:
|
||||||
|
other_names = any
|
||||||
|
elif p.kind == p.POSITIONAL_ONLY:
|
||||||
|
max_args += 1
|
||||||
|
if p.default is p.empty:
|
||||||
|
min_args += 1
|
||||||
|
no_names = True
|
||||||
|
|
||||||
|
if no_names:
|
||||||
|
other_names = None
|
||||||
|
|
||||||
|
return SignatureInfo(min_args, max_args, required_names, other_names)
|
||||||
|
|
||||||
|
|
||||||
|
class Concurrency(object):
|
||||||
|
|
||||||
|
def __init__(self, max_concurrent):
|
||||||
|
self._require_non_negative(max_concurrent)
|
||||||
|
self._max_concurrent = max_concurrent
|
||||||
|
self.semaphore = asyncio.Semaphore(max_concurrent)
|
||||||
|
|
||||||
|
def _require_non_negative(self, value):
|
||||||
|
if not isinstance(value, int) or value < 0:
|
||||||
|
raise RuntimeError('concurrency must be a natural number')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_concurrent(self):
|
||||||
|
return self._max_concurrent
|
||||||
|
|
||||||
|
async def set_max_concurrent(self, value):
|
||||||
|
self._require_non_negative(value)
|
||||||
|
diff = value - self._max_concurrent
|
||||||
|
self._max_concurrent = value
|
||||||
|
if diff >= 0:
|
||||||
|
for _ in range(diff):
|
||||||
|
self.semaphore.release()
|
||||||
|
else:
|
||||||
|
for _ in range(-diff):
|
||||||
|
await self.semaphore.acquire()
|
||||||
|
|
||||||
|
|
||||||
|
def check_task(logger, task):
|
||||||
|
if not task.cancelled():
|
||||||
|
try:
|
||||||
|
task.result()
|
||||||
|
except Exception:
|
||||||
|
logger.error('task crashed: %r', task, exc_info=True)
|
|
@ -15,7 +15,7 @@ from struct import pack, unpack
|
||||||
import time
|
import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from aiorpcx import TaskGroup, run_in_thread
|
from torba.rpc import TaskGroup, run_in_thread
|
||||||
|
|
||||||
import torba
|
import torba
|
||||||
from torba.server.daemon import DaemonError
|
from torba.server.daemon import DaemonError
|
||||||
|
|
|
@ -22,7 +22,7 @@ from torba.server.util import hex_to_bytes, class_logger,\
|
||||||
unpack_le_uint16_from, pack_varint
|
unpack_le_uint16_from, pack_varint
|
||||||
from torba.server.hash import hex_str_to_hash, hash_to_hex_str
|
from torba.server.hash import hex_str_to_hash, hash_to_hex_str
|
||||||
from torba.server.tx import DeserializerDecred
|
from torba.server.tx import DeserializerDecred
|
||||||
from aiorpcx import JSONRPC
|
from torba.rpc import JSONRPC
|
||||||
|
|
||||||
|
|
||||||
class DaemonError(Exception):
|
class DaemonError(Exception):
|
||||||
|
|
|
@ -19,8 +19,8 @@ from glob import glob
|
||||||
from struct import pack, unpack
|
from struct import pack, unpack
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from aiorpcx import run_in_thread, sleep
|
|
||||||
|
|
||||||
|
from torba.rpc import run_in_thread, sleep
|
||||||
from torba.server import util
|
from torba.server import util
|
||||||
from torba.server.hash import hash_to_hex_str, HASHX_LEN
|
from torba.server.hash import hash_to_hex_str, HASHX_LEN
|
||||||
from torba.server.merkle import Merkle, MerkleCache
|
from torba.server.merkle import Merkle, MerkleCache
|
||||||
|
|
|
@ -14,8 +14,8 @@ from asyncio import Lock
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from aiorpcx import TaskGroup, run_in_thread, sleep
|
|
||||||
|
|
||||||
|
from torba.rpc import TaskGroup, run_in_thread, sleep
|
||||||
from torba.server.hash import hash_to_hex_str, hex_str_to_hash
|
from torba.server.hash import hash_to_hex_str, hex_str_to_hash
|
||||||
from torba.server.util import class_logger, chunks
|
from torba.server.util import class_logger, chunks
|
||||||
from torba.server.db import UTXO
|
from torba.server.db import UTXO
|
||||||
|
|
|
@ -28,8 +28,7 @@
|
||||||
|
|
||||||
from math import ceil, log
|
from math import ceil, log
|
||||||
|
|
||||||
from aiorpcx import Event
|
from torba.rpc import Event
|
||||||
|
|
||||||
from torba.server.hash import double_sha256
|
from torba.server.hash import double_sha256
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,10 @@ import ssl
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict, Counter
|
from collections import defaultdict, Counter
|
||||||
|
|
||||||
from aiorpcx import (Connector, RPCSession, SOCKSProxy,
|
from torba.rpc import (Connector, RPCSession, SOCKSProxy,
|
||||||
Notification, handler_invocation,
|
Notification, handler_invocation,
|
||||||
SOCKSError, RPCError, TaskTimeout, TaskGroup, Event,
|
SOCKSError, RPCError, TaskTimeout, TaskGroup, Event,
|
||||||
sleep, run_in_thread, ignore_after, timeout_after)
|
sleep, run_in_thread, ignore_after, timeout_after)
|
||||||
|
|
||||||
from torba.server.peer import Peer
|
from torba.server.peer import Peer
|
||||||
from torba.server.util import class_logger, protocol_tuple
|
from torba.server.util import class_logger, protocol_tuple
|
||||||
|
|
||||||
|
|
|
@ -19,13 +19,12 @@ import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from aiorpcx import (
|
import torba
|
||||||
|
from torba.rpc import (
|
||||||
RPCSession, JSONRPCAutoDetect, JSONRPCConnection,
|
RPCSession, JSONRPCAutoDetect, JSONRPCConnection,
|
||||||
TaskGroup, handler_invocation, RPCError, Request, ignore_after, sleep,
|
TaskGroup, handler_invocation, RPCError, Request, ignore_after, sleep,
|
||||||
Event
|
Event
|
||||||
)
|
)
|
||||||
|
|
||||||
import torba
|
|
||||||
from torba.server import text
|
from torba.server import text
|
||||||
from torba.server import util
|
from torba.server import util
|
||||||
from torba.server.hash import (sha256, hash_to_hex_str, hex_str_to_hash,
|
from torba.server.hash import (sha256, hash_to_hex_str, hex_str_to_hash,
|
||||||
|
@ -664,9 +663,9 @@ class SessionBase(RPCSession):
|
||||||
super().connection_lost(exc)
|
super().connection_lost(exc)
|
||||||
self.session_mgr.remove_session(self)
|
self.session_mgr.remove_session(self)
|
||||||
msg = ''
|
msg = ''
|
||||||
if not self.can_send.is_set():
|
if not self._can_send.is_set():
|
||||||
msg += ' whilst paused'
|
msg += ' whilst paused'
|
||||||
if self.concurrency.max_concurrent != self.max_concurrent:
|
if self._concurrency.max_concurrent != self.max_concurrent:
|
||||||
msg += ' whilst throttled'
|
msg += ' whilst throttled'
|
||||||
if self.send_size >= 1024*1024:
|
if self.send_size >= 1024*1024:
|
||||||
msg += ('. Sent {:,d} bytes in {:,d} messages'
|
msg += ('. Sent {:,d} bytes in {:,d} messages'
|
||||||
|
|
Loading…
Reference in a new issue