dropping custom wallet rpc implementation
This commit is contained in:
parent
ba154c799e
commit
fe547f1b0e
6 changed files with 0 additions and 2101 deletions
|
@ -1,11 +0,0 @@
|
||||||
from .framing import *
|
|
||||||
from .jsonrpc import *
|
|
||||||
from .socks import *
|
|
||||||
from .session import *
|
|
||||||
from .util import *
|
|
||||||
|
|
||||||
__all__ = (framing.__all__ +
|
|
||||||
jsonrpc.__all__ +
|
|
||||||
socks.__all__ +
|
|
||||||
session.__all__ +
|
|
||||||
util.__all__)
|
|
|
@ -1,239 +0,0 @@
|
||||||
# Copyright (c) 2018, Neil Booth
|
|
||||||
#
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# The MIT License (MIT)
|
|
||||||
#
|
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining
|
|
||||||
# a copy of this software and associated documentation files (the
|
|
||||||
# "Software"), to deal in the Software without restriction, including
|
|
||||||
# without limitation the rights to use, copy, modify, merge, publish,
|
|
||||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
|
||||||
# permit persons to whom the Software is furnished to do so, subject to
|
|
||||||
# the following conditions:
|
|
||||||
#
|
|
||||||
# The above copyright notice and this permission notice shall be
|
|
||||||
# included in all copies or substantial portions of the Software.
|
|
||||||
#
|
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
||||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
||||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
|
||||||
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
|
||||||
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
|
||||||
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
|
||||||
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
||||||
|
|
||||||
"""RPC message framing in a byte stream."""
|
|
||||||
|
|
||||||
__all__ = ('FramerBase', 'NewlineFramer', 'BinaryFramer', 'BitcoinFramer',
|
|
||||||
'OversizedPayloadError', 'BadChecksumError', 'BadMagicError')
|
|
||||||
|
|
||||||
from hashlib import sha256 as _sha256
|
|
||||||
from struct import Struct
|
|
||||||
from asyncio import Queue
|
|
||||||
|
|
||||||
|
|
||||||
class FramerBase:
|
|
||||||
"""Abstract base class for a framer.
|
|
||||||
|
|
||||||
A framer breaks an incoming byte stream into protocol messages,
|
|
||||||
buffering if necessary. It also frames outgoing messages into
|
|
||||||
a byte stream.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def frame(self, message):
|
|
||||||
"""Return the framed message."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def received_bytes(self, data):
|
|
||||||
"""Pass incoming network bytes."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def receive_message(self):
|
|
||||||
"""Wait for a complete unframed message to arrive, and return it."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class NewlineFramer(FramerBase):
|
|
||||||
"""A framer for a protocol where messages are separated by newlines."""
|
|
||||||
|
|
||||||
# The default max_size value is motivated by JSONRPC, where a
|
|
||||||
# normal request will be 250 bytes or less, and a reasonable
|
|
||||||
# batch may contain 4000 requests.
|
|
||||||
def __init__(self, max_size=250 * 4000):
|
|
||||||
"""max_size - an anti-DoS measure. If, after processing an incoming
|
|
||||||
message, buffered data would exceed max_size bytes, that
|
|
||||||
buffered data is dropped entirely and the framer waits for a
|
|
||||||
newline character to re-synchronize the stream.
|
|
||||||
"""
|
|
||||||
self.max_size = max_size
|
|
||||||
self.queue = Queue()
|
|
||||||
self.received_bytes = self.queue.put_nowait
|
|
||||||
self.synchronizing = False
|
|
||||||
self.residual = b''
|
|
||||||
|
|
||||||
def frame(self, message):
|
|
||||||
return message + b'\n'
|
|
||||||
|
|
||||||
async def receive_message(self):
|
|
||||||
parts = []
|
|
||||||
buffer_size = 0
|
|
||||||
while True:
|
|
||||||
part = self.residual
|
|
||||||
self.residual = b''
|
|
||||||
if not part:
|
|
||||||
part = await self.queue.get()
|
|
||||||
|
|
||||||
npos = part.find(b'\n')
|
|
||||||
if npos == -1:
|
|
||||||
parts.append(part)
|
|
||||||
buffer_size += len(part)
|
|
||||||
# Ignore over-sized messages; re-synchronize
|
|
||||||
if buffer_size <= self.max_size:
|
|
||||||
continue
|
|
||||||
self.synchronizing = True
|
|
||||||
raise MemoryError(f'dropping message over {self.max_size:,d} '
|
|
||||||
f'bytes and re-synchronizing')
|
|
||||||
|
|
||||||
tail, self.residual = part[:npos], part[npos + 1:]
|
|
||||||
if self.synchronizing:
|
|
||||||
self.synchronizing = False
|
|
||||||
return await self.receive_message()
|
|
||||||
else:
|
|
||||||
parts.append(tail)
|
|
||||||
return b''.join(parts)
|
|
||||||
|
|
||||||
|
|
||||||
class ByteQueue:
|
|
||||||
"""A producer-comsumer queue. Incoming network data is put as it
|
|
||||||
arrives, and the consumer calls an async method waiting for data of
|
|
||||||
a specific length."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.queue = Queue()
|
|
||||||
self.parts = []
|
|
||||||
self.parts_len = 0
|
|
||||||
self.put_nowait = self.queue.put_nowait
|
|
||||||
|
|
||||||
async def receive(self, size):
|
|
||||||
while self.parts_len < size:
|
|
||||||
part = await self.queue.get()
|
|
||||||
self.parts.append(part)
|
|
||||||
self.parts_len += len(part)
|
|
||||||
self.parts_len -= size
|
|
||||||
whole = b''.join(self.parts)
|
|
||||||
self.parts = [whole[size:]]
|
|
||||||
return whole[:size]
|
|
||||||
|
|
||||||
|
|
||||||
class BinaryFramer:
|
|
||||||
"""A framer for binary messaging protocols."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.byte_queue = ByteQueue()
|
|
||||||
self.message_queue = Queue()
|
|
||||||
self.received_bytes = self.byte_queue.put_nowait
|
|
||||||
|
|
||||||
def frame(self, message):
|
|
||||||
command, payload = message
|
|
||||||
return b''.join((
|
|
||||||
self._build_header(command, payload),
|
|
||||||
payload
|
|
||||||
))
|
|
||||||
|
|
||||||
async def receive_message(self):
|
|
||||||
command, payload_len, checksum = await self._receive_header()
|
|
||||||
payload = await self.byte_queue.receive(payload_len)
|
|
||||||
payload_checksum = self._checksum(payload)
|
|
||||||
if payload_checksum != checksum:
|
|
||||||
raise BadChecksumError(payload_checksum, checksum)
|
|
||||||
return command, payload
|
|
||||||
|
|
||||||
def _checksum(self, payload):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def _build_header(self, command, payload):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def _receive_header(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
# Helpers
|
|
||||||
struct_le_I = Struct('<I')
|
|
||||||
pack_le_uint32 = struct_le_I.pack
|
|
||||||
|
|
||||||
|
|
||||||
def sha256(x):
|
|
||||||
"""Simple wrapper of hashlib sha256."""
|
|
||||||
return _sha256(x).digest()
|
|
||||||
|
|
||||||
|
|
||||||
def double_sha256(x):
|
|
||||||
"""SHA-256 of SHA-256, as used extensively in bitcoin."""
|
|
||||||
return sha256(sha256(x))
|
|
||||||
|
|
||||||
|
|
||||||
class BadChecksumError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BadMagicError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class OversizedPayloadError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BitcoinFramer(BinaryFramer):
|
|
||||||
"""Provides a framer of binary message payloads in the style of the
|
|
||||||
Bitcoin network protocol.
|
|
||||||
|
|
||||||
Each binary message has the following elements, in order:
|
|
||||||
|
|
||||||
Magic - to confirm network (currently unused for stream sync)
|
|
||||||
Command - padded command
|
|
||||||
Length - payload length in bytes
|
|
||||||
Checksum - checksum of the payload
|
|
||||||
Payload - binary payload
|
|
||||||
|
|
||||||
Call frame(command, payload) to get a framed message.
|
|
||||||
Pass incoming network bytes to received_bytes().
|
|
||||||
Wait on receive_message() to get incoming (command, payload) pairs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, magic, max_block_size):
|
|
||||||
def pad_command(command):
|
|
||||||
fill = 12 - len(command)
|
|
||||||
if fill < 0:
|
|
||||||
raise ValueError(f'command {command} too long')
|
|
||||||
return command + bytes(fill)
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
self._magic = magic
|
|
||||||
self._max_block_size = max_block_size
|
|
||||||
self._pad_command = pad_command
|
|
||||||
self._unpack = Struct(f'<4s12sI4s').unpack
|
|
||||||
|
|
||||||
def _checksum(self, payload):
|
|
||||||
return double_sha256(payload)[:4]
|
|
||||||
|
|
||||||
def _build_header(self, command, payload):
|
|
||||||
return b''.join((
|
|
||||||
self._magic,
|
|
||||||
self._pad_command(command),
|
|
||||||
pack_le_uint32(len(payload)),
|
|
||||||
self._checksum(payload)
|
|
||||||
))
|
|
||||||
|
|
||||||
async def _receive_header(self):
|
|
||||||
header = await self.byte_queue.receive(24)
|
|
||||||
magic, command, payload_len, checksum = self._unpack(header)
|
|
||||||
if magic != self._magic:
|
|
||||||
raise BadMagicError(magic, self._magic)
|
|
||||||
command = command.rstrip(b'\0')
|
|
||||||
if payload_len > 1024 * 1024:
|
|
||||||
if command != b'block' or payload_len > self._max_block_size:
|
|
||||||
raise OversizedPayloadError(command, payload_len)
|
|
||||||
return command, payload_len, checksum
|
|
|
@ -1,804 +0,0 @@
|
||||||
# Copyright (c) 2018, Neil Booth
|
|
||||||
#
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# The MIT License (MIT)
|
|
||||||
#
|
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining
|
|
||||||
# a copy of this software and associated documentation files (the
|
|
||||||
# "Software"), to deal in the Software without restriction, including
|
|
||||||
# without limitation the rights to use, copy, modify, merge, publish,
|
|
||||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
|
||||||
# permit persons to whom the Software is furnished to do so, subject to
|
|
||||||
# the following conditions:
|
|
||||||
#
|
|
||||||
# The above copyright notice and this permission notice shall be
|
|
||||||
# included in all copies or substantial portions of the Software.
|
|
||||||
#
|
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
||||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
||||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
|
||||||
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
|
||||||
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
|
||||||
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
|
||||||
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
||||||
|
|
||||||
"""Classes for JSONRPC versions 1.0 and 2.0, and a loose interpretation."""
|
|
||||||
|
|
||||||
__all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
|
|
||||||
'JSONRPCAutoDetect', 'Request', 'Notification', 'Batch',
|
|
||||||
'RPCError', 'ProtocolError',
|
|
||||||
'JSONRPCConnection', 'handler_invocation')
|
|
||||||
|
|
||||||
import itertools
|
|
||||||
import json
|
|
||||||
import typing
|
|
||||||
import asyncio
|
|
||||||
from functools import partial
|
|
||||||
from numbers import Number
|
|
||||||
|
|
||||||
import attr
|
|
||||||
from asyncio import Queue, Event, CancelledError
|
|
||||||
from .util import signature_info
|
|
||||||
|
|
||||||
|
|
||||||
class SingleRequest:
|
|
||||||
__slots__ = ('method', 'args')
|
|
||||||
|
|
||||||
def __init__(self, method, args):
|
|
||||||
if not isinstance(method, str):
|
|
||||||
raise ProtocolError(JSONRPC.METHOD_NOT_FOUND,
|
|
||||||
'method must be a string')
|
|
||||||
if not isinstance(args, (list, tuple, dict)):
|
|
||||||
raise ProtocolError.invalid_args('request arguments must be a '
|
|
||||||
'list or a dictionary')
|
|
||||||
self.args = args
|
|
||||||
self.method = method
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f'{self.__class__.__name__}({self.method!r}, {self.args!r})'
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
return (isinstance(other, self.__class__) and
|
|
||||||
self.method == other.method and self.args == other.args)
|
|
||||||
|
|
||||||
|
|
||||||
class Request(SingleRequest):
|
|
||||||
def send_result(self, response):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class Notification(SingleRequest):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Batch:
|
|
||||||
__slots__ = ('items', )
|
|
||||||
|
|
||||||
def __init__(self, items):
|
|
||||||
if not isinstance(items, (list, tuple)):
|
|
||||||
raise ProtocolError.invalid_request('items must be a list')
|
|
||||||
if not items:
|
|
||||||
raise ProtocolError.empty_batch()
|
|
||||||
if not (all(isinstance(item, SingleRequest) for item in items) or
|
|
||||||
all(isinstance(item, Response) for item in items)):
|
|
||||||
raise ProtocolError.invalid_request('batch must be homogeneous')
|
|
||||||
self.items = items
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.items)
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
return self.items[item]
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return iter(self.items)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f'Batch({len(self.items)} items)'
|
|
||||||
|
|
||||||
|
|
||||||
class Response:
|
|
||||||
__slots__ = ('result', )
|
|
||||||
|
|
||||||
def __init__(self, result):
|
|
||||||
# Type checking happens when converting to a message
|
|
||||||
self.result = result
|
|
||||||
|
|
||||||
|
|
||||||
class CodeMessageError(Exception):
|
|
||||||
|
|
||||||
def __init__(self, code, message):
|
|
||||||
super().__init__(code, message)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def code(self):
|
|
||||||
return self.args[0]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def message(self):
|
|
||||||
return self.args[1]
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
return (isinstance(other, self.__class__) and
|
|
||||||
self.code == other.code and self.message == other.message)
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
# overridden to make the exception hashable
|
|
||||||
# see https://bugs.python.org/issue28603
|
|
||||||
return hash((self.code, self.message))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def invalid_args(cls, message):
|
|
||||||
return cls(JSONRPC.INVALID_ARGS, message)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def invalid_request(cls, message):
|
|
||||||
return cls(JSONRPC.INVALID_REQUEST, message)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def empty_batch(cls):
|
|
||||||
return cls.invalid_request('batch is empty')
|
|
||||||
|
|
||||||
|
|
||||||
class RPCError(CodeMessageError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ProtocolError(CodeMessageError):
|
|
||||||
|
|
||||||
def __init__(self, code, message):
|
|
||||||
super().__init__(code, message)
|
|
||||||
# If not None send this unframed message over the network
|
|
||||||
self.error_message = None
|
|
||||||
# If the error was in a JSON response message; its message ID.
|
|
||||||
# Since None can be a response message ID, "id" means the
|
|
||||||
# error was not sent in a JSON response
|
|
||||||
self.response_msg_id = id
|
|
||||||
|
|
||||||
|
|
||||||
class JSONRPC:
|
|
||||||
"""Abstract base class that interprets and constructs JSON RPC messages."""
|
|
||||||
|
|
||||||
# Error codes. See http://www.jsonrpc.org/specification
|
|
||||||
PARSE_ERROR = -32700
|
|
||||||
INVALID_REQUEST = -32600
|
|
||||||
METHOD_NOT_FOUND = -32601
|
|
||||||
INVALID_ARGS = -32602
|
|
||||||
INTERNAL_ERROR = -32603
|
|
||||||
QUERY_TIMEOUT = -32000
|
|
||||||
|
|
||||||
# Codes specific to this library
|
|
||||||
ERROR_CODE_UNAVAILABLE = -100
|
|
||||||
|
|
||||||
# Can be overridden by derived classes
|
|
||||||
allow_batches = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _message_id(cls, message, require_id):
|
|
||||||
"""Validate the message is a dictionary and return its ID.
|
|
||||||
|
|
||||||
Raise an error if the message is invalid or the ID is of an
|
|
||||||
invalid type. If it has no ID, raise an error if require_id
|
|
||||||
is True, otherwise return None.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _validate_message(cls, message):
|
|
||||||
"""Validate other parts of the message other than those
|
|
||||||
done in _message_id."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _request_args(cls, request):
|
|
||||||
"""Validate the existence and type of the arguments passed
|
|
||||||
in the request dictionary."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _process_request(cls, payload):
|
|
||||||
request_id = None
|
|
||||||
try:
|
|
||||||
request_id = cls._message_id(payload, False)
|
|
||||||
cls._validate_message(payload)
|
|
||||||
method = payload.get('method')
|
|
||||||
if request_id is None:
|
|
||||||
item = Notification(method, cls._request_args(payload))
|
|
||||||
else:
|
|
||||||
item = Request(method, cls._request_args(payload))
|
|
||||||
return item, request_id
|
|
||||||
except ProtocolError as error:
|
|
||||||
code, message = error.code, error.message
|
|
||||||
raise cls._error(code, message, True, request_id)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _process_response(cls, payload):
|
|
||||||
request_id = None
|
|
||||||
try:
|
|
||||||
request_id = cls._message_id(payload, True)
|
|
||||||
cls._validate_message(payload)
|
|
||||||
return Response(cls.response_value(payload)), request_id
|
|
||||||
except ProtocolError as error:
|
|
||||||
code, message = error.code, error.message
|
|
||||||
raise cls._error(code, message, False, request_id)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _message_to_payload(cls, message):
|
|
||||||
"""Returns a Python object or a ProtocolError."""
|
|
||||||
try:
|
|
||||||
return json.loads(message.decode())
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
message = 'messages must be encoded in UTF-8'
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
message = 'invalid JSON'
|
|
||||||
raise cls._error(cls.PARSE_ERROR, message, True, None)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _error(cls, code, message, send, msg_id):
|
|
||||||
error = ProtocolError(code, message)
|
|
||||||
if send:
|
|
||||||
error.error_message = cls.response_message(error, msg_id)
|
|
||||||
else:
|
|
||||||
error.response_msg_id = msg_id
|
|
||||||
return error
|
|
||||||
|
|
||||||
#
|
|
||||||
# External API
|
|
||||||
#
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def message_to_item(cls, message):
|
|
||||||
"""Translate an unframed received message and return an
|
|
||||||
(item, request_id) pair.
|
|
||||||
|
|
||||||
The item can be a Request, Notification, Response or a list.
|
|
||||||
|
|
||||||
A JSON RPC error response is returned as an RPCError inside a
|
|
||||||
Response object.
|
|
||||||
|
|
||||||
If a Batch is returned, request_id is an iterable of request
|
|
||||||
ids, one per batch member.
|
|
||||||
|
|
||||||
If the message violates the protocol in some way a
|
|
||||||
ProtocolError is returned, except if the message was
|
|
||||||
determined to be a response, in which case the ProtocolError
|
|
||||||
is placed inside a Response object. This is so that client
|
|
||||||
code can mark a request as having been responded to even if
|
|
||||||
the response was bad.
|
|
||||||
|
|
||||||
raises: ProtocolError
|
|
||||||
"""
|
|
||||||
payload = cls._message_to_payload(message)
|
|
||||||
if isinstance(payload, dict):
|
|
||||||
if 'method' in payload:
|
|
||||||
return cls._process_request(payload)
|
|
||||||
else:
|
|
||||||
return cls._process_response(payload)
|
|
||||||
elif isinstance(payload, list) and cls.allow_batches:
|
|
||||||
if not payload:
|
|
||||||
raise cls._error(JSONRPC.INVALID_REQUEST, 'batch is empty',
|
|
||||||
True, None)
|
|
||||||
return payload, None
|
|
||||||
raise cls._error(cls.INVALID_REQUEST,
|
|
||||||
'request object must be a dictionary', True, None)
|
|
||||||
|
|
||||||
# Message formation
|
|
||||||
@classmethod
|
|
||||||
def request_message(cls, item, request_id):
|
|
||||||
"""Convert an RPCRequest item to a message."""
|
|
||||||
assert isinstance(item, Request)
|
|
||||||
return cls.encode_payload(cls.request_payload(item, request_id))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def notification_message(cls, item):
|
|
||||||
"""Convert an RPCRequest item to a message."""
|
|
||||||
assert isinstance(item, Notification)
|
|
||||||
return cls.encode_payload(cls.request_payload(item, None))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def response_message(cls, result, request_id):
|
|
||||||
"""Convert a response result (or RPCError) to a message."""
|
|
||||||
if isinstance(result, CodeMessageError):
|
|
||||||
payload = cls.error_payload(result, request_id)
|
|
||||||
else:
|
|
||||||
payload = cls.response_payload(result, request_id)
|
|
||||||
return cls.encode_payload(payload)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def batch_message(cls, batch, request_ids):
|
|
||||||
"""Convert a request Batch to a message."""
|
|
||||||
assert isinstance(batch, Batch)
|
|
||||||
if not cls.allow_batches:
|
|
||||||
raise ProtocolError.invalid_request(
|
|
||||||
'protocol does not permit batches')
|
|
||||||
id_iter = iter(request_ids)
|
|
||||||
rm = cls.request_message
|
|
||||||
nm = cls.notification_message
|
|
||||||
parts = (rm(request, next(id_iter)) if isinstance(request, Request)
|
|
||||||
else nm(request) for request in batch)
|
|
||||||
return cls.batch_message_from_parts(parts)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def batch_message_from_parts(cls, messages):
|
|
||||||
"""Convert messages, one per batch item, into a batch message. At
|
|
||||||
least one message must be passed.
|
|
||||||
"""
|
|
||||||
# Comma-separate the messages and wrap the lot in square brackets
|
|
||||||
middle = b', '.join(messages)
|
|
||||||
if not middle:
|
|
||||||
raise ProtocolError.empty_batch()
|
|
||||||
return b''.join([b'[', middle, b']'])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def encode_payload(cls, payload):
|
|
||||||
"""Encode a Python object as JSON and convert it to bytes."""
|
|
||||||
try:
|
|
||||||
return json.dumps(payload).encode()
|
|
||||||
except TypeError:
|
|
||||||
msg = f'JSON payload encoding error: {payload}'
|
|
||||||
raise ProtocolError(cls.INTERNAL_ERROR, msg) from None
|
|
||||||
|
|
||||||
|
|
||||||
class JSONRPCv1(JSONRPC):
|
|
||||||
"""JSON RPC version 1.0."""
|
|
||||||
|
|
||||||
allow_batches = False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _message_id(cls, message, require_id):
|
|
||||||
# JSONv1 requires an ID always, but without constraint on its type
|
|
||||||
# No need to test for a dictionary here as we don't handle batches.
|
|
||||||
if 'id' not in message:
|
|
||||||
raise ProtocolError.invalid_request('request has no "id"')
|
|
||||||
return message['id']
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _request_args(cls, request):
|
|
||||||
args = request.get('params')
|
|
||||||
if not isinstance(args, list):
|
|
||||||
raise ProtocolError.invalid_args(
|
|
||||||
f'invalid request arguments: {args}')
|
|
||||||
return args
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _best_effort_error(cls, error):
|
|
||||||
# Do our best to interpret the error
|
|
||||||
code = cls.ERROR_CODE_UNAVAILABLE
|
|
||||||
message = 'no error message provided'
|
|
||||||
if isinstance(error, str):
|
|
||||||
message = error
|
|
||||||
elif isinstance(error, int):
|
|
||||||
code = error
|
|
||||||
elif isinstance(error, dict):
|
|
||||||
if isinstance(error.get('message'), str):
|
|
||||||
message = error['message']
|
|
||||||
if isinstance(error.get('code'), int):
|
|
||||||
code = error['code']
|
|
||||||
|
|
||||||
return RPCError(code, message)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def response_value(cls, payload):
|
|
||||||
if 'result' not in payload or 'error' not in payload:
|
|
||||||
raise ProtocolError.invalid_request(
|
|
||||||
'response must contain both "result" and "error"')
|
|
||||||
|
|
||||||
result = payload['result']
|
|
||||||
error = payload['error']
|
|
||||||
if error is None:
|
|
||||||
return result # It seems None can be a valid result
|
|
||||||
if result is not None:
|
|
||||||
raise ProtocolError.invalid_request(
|
|
||||||
'response has a "result" and an "error"')
|
|
||||||
|
|
||||||
return cls._best_effort_error(error)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def request_payload(cls, request, request_id):
|
|
||||||
"""JSON v1 request (or notification) payload."""
|
|
||||||
if isinstance(request.args, dict):
|
|
||||||
raise ProtocolError.invalid_args(
|
|
||||||
'JSONRPCv1 does not support named arguments')
|
|
||||||
return {
|
|
||||||
'method': request.method,
|
|
||||||
'params': request.args,
|
|
||||||
'id': request_id
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def response_payload(cls, result, request_id):
|
|
||||||
"""JSON v1 response payload."""
|
|
||||||
return {
|
|
||||||
'result': result,
|
|
||||||
'error': None,
|
|
||||||
'id': request_id
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def error_payload(cls, error, request_id):
|
|
||||||
return {
|
|
||||||
'result': None,
|
|
||||||
'error': {'code': error.code, 'message': error.message},
|
|
||||||
'id': request_id
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class JSONRPCv2(JSONRPC):
|
|
||||||
"""JSON RPC version 2.0."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _message_id(cls, message, require_id):
|
|
||||||
if not isinstance(message, dict):
|
|
||||||
raise ProtocolError.invalid_request(
|
|
||||||
'request object must be a dictionary')
|
|
||||||
if 'id' in message:
|
|
||||||
request_id = message['id']
|
|
||||||
if not isinstance(request_id, (Number, str, type(None))):
|
|
||||||
raise ProtocolError.invalid_request(
|
|
||||||
f'invalid "id": {request_id}')
|
|
||||||
return request_id
|
|
||||||
else:
|
|
||||||
if require_id:
|
|
||||||
raise ProtocolError.invalid_request('request has no "id"')
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _validate_message(cls, message):
|
|
||||||
if message.get('jsonrpc') != '2.0':
|
|
||||||
raise ProtocolError.invalid_request('"jsonrpc" is not "2.0"')
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _request_args(cls, request):
|
|
||||||
args = request.get('params', [])
|
|
||||||
if not isinstance(args, (dict, list)):
|
|
||||||
raise ProtocolError.invalid_args(
|
|
||||||
f'invalid request arguments: {args}')
|
|
||||||
return args
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def response_value(cls, payload):
|
|
||||||
if 'result' in payload:
|
|
||||||
if 'error' in payload:
|
|
||||||
raise ProtocolError.invalid_request(
|
|
||||||
'response contains both "result" and "error"')
|
|
||||||
return payload['result']
|
|
||||||
|
|
||||||
if 'error' not in payload:
|
|
||||||
raise ProtocolError.invalid_request(
|
|
||||||
'response contains neither "result" nor "error"')
|
|
||||||
|
|
||||||
# Return an RPCError object
|
|
||||||
error = payload['error']
|
|
||||||
if isinstance(error, dict):
|
|
||||||
code = error.get('code')
|
|
||||||
message = error.get('message')
|
|
||||||
if isinstance(code, int) and isinstance(message, str):
|
|
||||||
return RPCError(code, message)
|
|
||||||
|
|
||||||
raise ProtocolError.invalid_request(
|
|
||||||
f'ill-formed response error object: {error}')
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def request_payload(cls, request, request_id):
|
|
||||||
"""JSON v2 request (or notification) payload."""
|
|
||||||
payload = {
|
|
||||||
'jsonrpc': '2.0',
|
|
||||||
'method': request.method,
|
|
||||||
}
|
|
||||||
# A notification?
|
|
||||||
if request_id is not None:
|
|
||||||
payload['id'] = request_id
|
|
||||||
# Preserve empty dicts as missing params is read as an array
|
|
||||||
if request.args or request.args == {}:
|
|
||||||
payload['params'] = request.args
|
|
||||||
return payload
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def response_payload(cls, result, request_id):
|
|
||||||
"""JSON v2 response payload."""
|
|
||||||
return {
|
|
||||||
'jsonrpc': '2.0',
|
|
||||||
'result': result,
|
|
||||||
'id': request_id
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def error_payload(cls, error, request_id):
|
|
||||||
return {
|
|
||||||
'jsonrpc': '2.0',
|
|
||||||
'error': {'code': error.code, 'message': error.message},
|
|
||||||
'id': request_id
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class JSONRPCLoose(JSONRPC):
|
|
||||||
"""A relaxed version of JSON RPC."""
|
|
||||||
|
|
||||||
# Don't be so loose we accept any old message ID
|
|
||||||
_message_id = JSONRPCv2._message_id
|
|
||||||
_validate_message = JSONRPC._validate_message
|
|
||||||
_request_args = JSONRPCv2._request_args
|
|
||||||
# Outoing messages are JSONRPCv2 so we give the other side the
|
|
||||||
# best chance to assume / detect JSONRPCv2 as default protocol.
|
|
||||||
error_payload = JSONRPCv2.error_payload
|
|
||||||
request_payload = JSONRPCv2.request_payload
|
|
||||||
response_payload = JSONRPCv2.response_payload
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def response_value(cls, payload):
|
|
||||||
# Return result, unless it is None and there is an error
|
|
||||||
if payload.get('error') is not None:
|
|
||||||
if payload.get('result') is not None:
|
|
||||||
raise ProtocolError.invalid_request(
|
|
||||||
'response contains both "result" and "error"')
|
|
||||||
return JSONRPCv1._best_effort_error(payload['error'])
|
|
||||||
|
|
||||||
if 'result' not in payload:
|
|
||||||
raise ProtocolError.invalid_request(
|
|
||||||
'response contains neither "result" nor "error"')
|
|
||||||
|
|
||||||
# Can be None
|
|
||||||
return payload['result']
|
|
||||||
|
|
||||||
|
|
||||||
class JSONRPCAutoDetect(JSONRPCv2):
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def message_to_item(cls, message):
|
|
||||||
return cls.detect_protocol(message), None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def detect_protocol(cls, message):
|
|
||||||
"""Attempt to detect the protocol from the message."""
|
|
||||||
main = cls._message_to_payload(message)
|
|
||||||
|
|
||||||
def protocol_for_payload(payload):
|
|
||||||
if not isinstance(payload, dict):
|
|
||||||
return JSONRPCLoose # Will error
|
|
||||||
# Obey an explicit "jsonrpc"
|
|
||||||
version = payload.get('jsonrpc')
|
|
||||||
if version == '2.0':
|
|
||||||
return JSONRPCv2
|
|
||||||
if version == '1.0':
|
|
||||||
return JSONRPCv1
|
|
||||||
|
|
||||||
# Now to decide between JSONRPCLoose and JSONRPCv1 if possible
|
|
||||||
if 'result' in payload and 'error' in payload:
|
|
||||||
return JSONRPCv1
|
|
||||||
return JSONRPCLoose
|
|
||||||
|
|
||||||
if isinstance(main, list):
|
|
||||||
parts = {protocol_for_payload(payload) for payload in main}
|
|
||||||
# If all same protocol, return it
|
|
||||||
if len(parts) == 1:
|
|
||||||
return parts.pop()
|
|
||||||
# If strict protocol detected, return it, preferring JSONRPCv2.
|
|
||||||
# This means a batch of JSONRPCv1 will fail
|
|
||||||
for protocol in (JSONRPCv2, JSONRPCv1):
|
|
||||||
if protocol in parts:
|
|
||||||
return protocol
|
|
||||||
# Will error if no parts
|
|
||||||
return JSONRPCLoose
|
|
||||||
|
|
||||||
return protocol_for_payload(main)
|
|
||||||
|
|
||||||
|
|
||||||
class JSONRPCConnection:
|
|
||||||
"""Maintains state of a JSON RPC connection, in particular
|
|
||||||
encapsulating the handling of request IDs.
|
|
||||||
|
|
||||||
protocol - the JSON RPC protocol to follow
|
|
||||||
max_response_size - responses over this size send an error response
|
|
||||||
instead.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_id_counter = itertools.count()
|
|
||||||
|
|
||||||
def __init__(self, protocol):
|
|
||||||
self._protocol = protocol
|
|
||||||
# Sent Requests and Batches that have not received a response.
|
|
||||||
# The key is its request ID; for a batch it is sorted tuple
|
|
||||||
# of request IDs
|
|
||||||
self._requests: typing.Dict[str, typing.Tuple[Request, Event]] = {}
|
|
||||||
# A public attribute intended to be settable dynamically
|
|
||||||
self.max_response_size = 0
|
|
||||||
|
|
||||||
def _oversized_response_message(self, request_id):
|
|
||||||
text = f'response too large (over {self.max_response_size:,d} bytes'
|
|
||||||
error = RPCError.invalid_request(text)
|
|
||||||
return self._protocol.response_message(error, request_id)
|
|
||||||
|
|
||||||
def _receive_response(self, result, request_id):
|
|
||||||
if request_id not in self._requests:
|
|
||||||
if request_id is None and isinstance(result, RPCError):
|
|
||||||
message = f'diagnostic error received: {result}'
|
|
||||||
else:
|
|
||||||
message = f'response to unsent request (ID: {request_id})'
|
|
||||||
raise ProtocolError.invalid_request(message) from None
|
|
||||||
request, event = self._requests.pop(request_id)
|
|
||||||
event.result = result
|
|
||||||
event.set()
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _receive_request_batch(self, payloads):
|
|
||||||
def item_send_result(request_id, result):
|
|
||||||
nonlocal size
|
|
||||||
part = protocol.response_message(result, request_id)
|
|
||||||
size += len(part) + 2
|
|
||||||
if size > self.max_response_size > 0:
|
|
||||||
part = self._oversized_response_message(request_id)
|
|
||||||
parts.append(part)
|
|
||||||
if len(parts) == count:
|
|
||||||
return protocol.batch_message_from_parts(parts)
|
|
||||||
return None
|
|
||||||
|
|
||||||
parts = []
|
|
||||||
items = []
|
|
||||||
size = 0
|
|
||||||
count = 0
|
|
||||||
protocol = self._protocol
|
|
||||||
for payload in payloads:
|
|
||||||
try:
|
|
||||||
item, request_id = protocol._process_request(payload)
|
|
||||||
items.append(item)
|
|
||||||
if isinstance(item, Request):
|
|
||||||
count += 1
|
|
||||||
item.send_result = partial(item_send_result, request_id)
|
|
||||||
except ProtocolError as error:
|
|
||||||
count += 1
|
|
||||||
parts.append(error.error_message)
|
|
||||||
|
|
||||||
if not items and parts:
|
|
||||||
protocol_error = ProtocolError(0, "")
|
|
||||||
protocol_error.error_message = protocol.batch_message_from_parts(parts)
|
|
||||||
raise protocol_error
|
|
||||||
return items
|
|
||||||
|
|
||||||
def _receive_response_batch(self, payloads):
|
|
||||||
request_ids = []
|
|
||||||
results = []
|
|
||||||
for payload in payloads:
|
|
||||||
# Let ProtocolError exceptions through
|
|
||||||
item, request_id = self._protocol._process_response(payload)
|
|
||||||
request_ids.append(request_id)
|
|
||||||
results.append(item.result)
|
|
||||||
|
|
||||||
ordered = sorted(zip(request_ids, results), key=lambda t: t[0])
|
|
||||||
ordered_ids, ordered_results = zip(*ordered)
|
|
||||||
if ordered_ids not in self._requests:
|
|
||||||
raise ProtocolError.invalid_request('response to unsent batch')
|
|
||||||
request_batch, event = self._requests.pop(ordered_ids)
|
|
||||||
event.result = ordered_results
|
|
||||||
event.set()
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _send_result(self, request_id, result):
|
|
||||||
message = self._protocol.response_message(result, request_id)
|
|
||||||
if len(message) > self.max_response_size > 0:
|
|
||||||
message = self._oversized_response_message(request_id)
|
|
||||||
return message
|
|
||||||
|
|
||||||
def _event(self, request, request_id):
|
|
||||||
event = Event()
|
|
||||||
self._requests[request_id] = (request, event)
|
|
||||||
return event
|
|
||||||
|
|
||||||
#
|
|
||||||
# External API
|
|
||||||
#
|
|
||||||
def send_request(self, request: Request) -> typing.Tuple[bytes, Event]:
|
|
||||||
"""Send a Request. Return a (message, event) pair.
|
|
||||||
|
|
||||||
The message is an unframed message to send over the network.
|
|
||||||
Wait on the event for the response; which will be in the
|
|
||||||
"result" attribute.
|
|
||||||
|
|
||||||
Raises: ProtocolError if the request violates the protocol
|
|
||||||
in some way..
|
|
||||||
"""
|
|
||||||
request_id = next(self._id_counter)
|
|
||||||
message = self._protocol.request_message(request, request_id)
|
|
||||||
return message, self._event(request, request_id)
|
|
||||||
|
|
||||||
def send_notification(self, notification):
|
|
||||||
return self._protocol.notification_message(notification)
|
|
||||||
|
|
||||||
def send_batch(self, batch):
|
|
||||||
ids = tuple(next(self._id_counter)
|
|
||||||
for request in batch if isinstance(request, Request))
|
|
||||||
message = self._protocol.batch_message(batch, ids)
|
|
||||||
event = self._event(batch, ids) if ids else None
|
|
||||||
return message, event
|
|
||||||
|
|
||||||
def receive_message(self, message):
|
|
||||||
"""Call with an unframed message received from the network.
|
|
||||||
|
|
||||||
Raises: ProtocolError if the message violates the protocol in
|
|
||||||
some way. However, if it happened in a response that can be
|
|
||||||
paired with a request, the ProtocolError is instead set in the
|
|
||||||
result attribute of the send_request() that caused the error.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
item, request_id = self._protocol.message_to_item(message)
|
|
||||||
except ProtocolError as e:
|
|
||||||
if e.response_msg_id is not id:
|
|
||||||
return self._receive_response(e, e.response_msg_id)
|
|
||||||
raise
|
|
||||||
|
|
||||||
if isinstance(item, Request):
|
|
||||||
item.send_result = partial(self._send_result, request_id)
|
|
||||||
return [item]
|
|
||||||
if isinstance(item, Notification):
|
|
||||||
return [item]
|
|
||||||
if isinstance(item, Response):
|
|
||||||
return self._receive_response(item.result, request_id)
|
|
||||||
if isinstance(item, list):
|
|
||||||
if all(isinstance(payload, dict)
|
|
||||||
and ('result' in payload or 'error' in payload)
|
|
||||||
for payload in item):
|
|
||||||
return self._receive_response_batch(item)
|
|
||||||
else:
|
|
||||||
return self._receive_request_batch(item)
|
|
||||||
else:
|
|
||||||
# Protocol auto-detection hack
|
|
||||||
assert issubclass(item, JSONRPC)
|
|
||||||
self._protocol = item
|
|
||||||
return self.receive_message(message)
|
|
||||||
|
|
||||||
def raise_pending_requests(self, exception):
|
|
||||||
exception = exception or asyncio.TimeoutError()
|
|
||||||
for request, event in self._requests.values():
|
|
||||||
event.result = exception
|
|
||||||
event.set()
|
|
||||||
self._requests.clear()
|
|
||||||
|
|
||||||
def pending_requests(self):
|
|
||||||
"""All sent requests that have not received a response."""
|
|
||||||
return [request for request, event in self._requests.values()]
|
|
||||||
|
|
||||||
|
|
||||||
def handler_invocation(handler, request):
|
|
||||||
method, args = request.method, request.args
|
|
||||||
if handler is None:
|
|
||||||
raise RPCError(JSONRPC.METHOD_NOT_FOUND,
|
|
||||||
f'unknown method "{method}"')
|
|
||||||
|
|
||||||
# We must test for too few and too many arguments. How
|
|
||||||
# depends on whether the arguments were passed as a list or as
|
|
||||||
# a dictionary.
|
|
||||||
info = signature_info(handler)
|
|
||||||
if isinstance(args, (tuple, list)):
|
|
||||||
if len(args) < info.min_args:
|
|
||||||
s = '' if len(args) == 1 else 's'
|
|
||||||
raise RPCError.invalid_args(
|
|
||||||
f'{len(args)} argument{s} passed to method '
|
|
||||||
f'"{method}" but it requires {info.min_args}')
|
|
||||||
if info.max_args is not None and len(args) > info.max_args:
|
|
||||||
s = '' if len(args) == 1 else 's'
|
|
||||||
raise RPCError.invalid_args(
|
|
||||||
f'{len(args)} argument{s} passed to method '
|
|
||||||
f'{method} taking at most {info.max_args}')
|
|
||||||
return partial(handler, *args)
|
|
||||||
|
|
||||||
# Arguments passed by name
|
|
||||||
if info.other_names is None:
|
|
||||||
raise RPCError.invalid_args(f'method "{method}" cannot '
|
|
||||||
f'be called with named arguments')
|
|
||||||
|
|
||||||
missing = set(info.required_names).difference(args)
|
|
||||||
if missing:
|
|
||||||
s = '' if len(missing) == 1 else 's'
|
|
||||||
missing = ', '.join(sorted(f'"{name}"' for name in missing))
|
|
||||||
raise RPCError.invalid_args(f'method "{method}" requires '
|
|
||||||
f'parameter{s} {missing}')
|
|
||||||
|
|
||||||
if info.other_names is not any:
|
|
||||||
excess = set(args).difference(info.required_names)
|
|
||||||
excess = excess.difference(info.other_names)
|
|
||||||
if excess:
|
|
||||||
s = '' if len(excess) == 1 else 's'
|
|
||||||
excess = ', '.join(sorted(f'"{name}"' for name in excess))
|
|
||||||
raise RPCError.invalid_args(f'method "{method}" does not '
|
|
||||||
f'take parameter{s} {excess}')
|
|
||||||
return partial(handler, **args)
|
|
|
@ -1,513 +0,0 @@
|
||||||
# Copyright (c) 2018, Neil Booth
|
|
||||||
#
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# The MIT License (MIT)
|
|
||||||
#
|
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining
|
|
||||||
# a copy of this software and associated documentation files (the
|
|
||||||
# "Software"), to deal in the Software without restriction, including
|
|
||||||
# without limitation the rights to use, copy, modify, merge, publish,
|
|
||||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
|
||||||
# permit persons to whom the Software is furnished to do so, subject to
|
|
||||||
# the following conditions:
|
|
||||||
#
|
|
||||||
# The above copyright notice and this permission notice shall be
|
|
||||||
# included in all copies or substantial portions of the Software.
|
|
||||||
#
|
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
||||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
||||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
|
||||||
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
|
||||||
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
|
||||||
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
|
||||||
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ('Connector', 'RPCSession', 'MessageSession', 'Server',
|
|
||||||
'BatchError')
|
|
||||||
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from asyncio import Event, CancelledError
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from contextlib import suppress
|
|
||||||
|
|
||||||
from lbry.wallet.tasks import TaskGroup
|
|
||||||
|
|
||||||
from .jsonrpc import Request, JSONRPCConnection, JSONRPCv2, JSONRPC, Batch, Notification
|
|
||||||
from .jsonrpc import RPCError, ProtocolError
|
|
||||||
from .framing import BadMagicError, BadChecksumError, OversizedPayloadError, BitcoinFramer, NewlineFramer
|
|
||||||
from lbry.wallet.server.prometheus import NOTIFICATION_COUNT, RESPONSE_TIMES, REQUEST_ERRORS_COUNT, RESET_CONNECTIONS
|
|
||||||
|
|
||||||
|
|
||||||
class Connector:
|
|
||||||
|
|
||||||
def __init__(self, session_factory, host=None, port=None, proxy=None,
|
|
||||||
**kwargs):
|
|
||||||
self.session_factory = session_factory
|
|
||||||
self.host = host
|
|
||||||
self.port = port
|
|
||||||
self.proxy = proxy
|
|
||||||
self.loop = kwargs.get('loop', asyncio.get_event_loop())
|
|
||||||
self.kwargs = kwargs
|
|
||||||
|
|
||||||
async def create_connection(self):
|
|
||||||
"""Initiate a connection."""
|
|
||||||
connector = self.proxy or self.loop
|
|
||||||
return await connector.create_connection(
|
|
||||||
self.session_factory, self.host, self.port, **self.kwargs)
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
|
||||||
transport, self.protocol = await self.create_connection()
|
|
||||||
return self.protocol
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
||||||
await self.protocol.close()
|
|
||||||
|
|
||||||
|
|
||||||
class SessionBase(asyncio.Protocol):
|
|
||||||
"""Base class of networking sessions.
|
|
||||||
|
|
||||||
There is no client / server distinction other than who initiated
|
|
||||||
the connection.
|
|
||||||
|
|
||||||
To initiate a connection to a remote server pass host, port and
|
|
||||||
proxy to the constructor, and then call create_connection(). Each
|
|
||||||
successful call should have a corresponding call to close().
|
|
||||||
|
|
||||||
Alternatively if used in a with statement, the connection is made
|
|
||||||
on entry to the block, and closed on exit from the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
max_errors = 10
|
|
||||||
|
|
||||||
def __init__(self, *, framer=None, loop=None):
|
|
||||||
self.framer = framer or self.default_framer()
|
|
||||||
self.loop = loop or asyncio.get_event_loop()
|
|
||||||
self.logger = logging.getLogger(self.__class__.__name__)
|
|
||||||
self.transport = None
|
|
||||||
# Set when a connection is made
|
|
||||||
self._address = None
|
|
||||||
self._proxy_address = None
|
|
||||||
# For logger.debug messages
|
|
||||||
self.verbosity = 0
|
|
||||||
# Cleared when the send socket is full
|
|
||||||
self._can_send = Event()
|
|
||||||
self._can_send.set()
|
|
||||||
self._pm_task = None
|
|
||||||
self._task_group = TaskGroup(self.loop)
|
|
||||||
# Force-close a connection if a send doesn't succeed in this time
|
|
||||||
self.max_send_delay = 60
|
|
||||||
# Statistics. The RPC object also keeps its own statistics.
|
|
||||||
self.start_time = time.perf_counter()
|
|
||||||
self.errors = 0
|
|
||||||
self.send_count = 0
|
|
||||||
self.send_size = 0
|
|
||||||
self.last_send = self.start_time
|
|
||||||
self.recv_count = 0
|
|
||||||
self.recv_size = 0
|
|
||||||
self.last_recv = self.start_time
|
|
||||||
self.last_packet_received = self.start_time
|
|
||||||
|
|
||||||
async def _limited_wait(self, secs):
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._can_send.wait(), secs)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
self.abort()
|
|
||||||
raise asyncio.TimeoutError(f'task timed out after {secs}s')
|
|
||||||
|
|
||||||
async def _send_message(self, message):
|
|
||||||
if not self._can_send.is_set():
|
|
||||||
await self._limited_wait(self.max_send_delay)
|
|
||||||
if not self.is_closing():
|
|
||||||
framed_message = self.framer.frame(message)
|
|
||||||
self.send_size += len(framed_message)
|
|
||||||
self.send_count += 1
|
|
||||||
self.last_send = time.perf_counter()
|
|
||||||
if self.verbosity >= 4:
|
|
||||||
self.logger.debug(f'Sending framed message {framed_message}')
|
|
||||||
self.transport.write(framed_message)
|
|
||||||
|
|
||||||
def _bump_errors(self):
|
|
||||||
self.errors += 1
|
|
||||||
if self.errors >= self.max_errors:
|
|
||||||
# Don't await self.close() because that is self-cancelling
|
|
||||||
self._close()
|
|
||||||
|
|
||||||
def _close(self):
|
|
||||||
if self.transport:
|
|
||||||
self.transport.close()
|
|
||||||
|
|
||||||
# asyncio framework
|
|
||||||
def data_received(self, framed_message):
|
|
||||||
"""Called by asyncio when a message comes in."""
|
|
||||||
self.last_packet_received = time.perf_counter()
|
|
||||||
if self.verbosity >= 4:
|
|
||||||
self.logger.debug(f'Received framed message {framed_message}')
|
|
||||||
self.recv_size += len(framed_message)
|
|
||||||
self.framer.received_bytes(framed_message)
|
|
||||||
|
|
||||||
def pause_writing(self):
|
|
||||||
"""Transport calls when the send buffer is full."""
|
|
||||||
if not self.is_closing():
|
|
||||||
self._can_send.clear()
|
|
||||||
self.transport.pause_reading()
|
|
||||||
|
|
||||||
def resume_writing(self):
|
|
||||||
"""Transport calls when the send buffer has room."""
|
|
||||||
if not self._can_send.is_set():
|
|
||||||
self._can_send.set()
|
|
||||||
self.transport.resume_reading()
|
|
||||||
|
|
||||||
def connection_made(self, transport):
|
|
||||||
"""Called by asyncio when a connection is established.
|
|
||||||
|
|
||||||
Derived classes overriding this method must call this first."""
|
|
||||||
self.transport = transport
|
|
||||||
# This would throw if called on a closed SSL transport. Fixed
|
|
||||||
# in asyncio in Python 3.6.1 and 3.5.4
|
|
||||||
peer_address = transport.get_extra_info('peername')
|
|
||||||
# If the Socks proxy was used then _address is already set to
|
|
||||||
# the remote address
|
|
||||||
if self._address:
|
|
||||||
self._proxy_address = peer_address
|
|
||||||
else:
|
|
||||||
self._address = peer_address
|
|
||||||
self._pm_task = self.loop.create_task(self._receive_messages())
|
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
|
||||||
"""Called by asyncio when the connection closes.
|
|
||||||
|
|
||||||
Tear down things done in connection_made."""
|
|
||||||
self._address = None
|
|
||||||
self.transport = None
|
|
||||||
self._task_group.cancel()
|
|
||||||
if self._pm_task:
|
|
||||||
self._pm_task.cancel()
|
|
||||||
# Release waiting tasks
|
|
||||||
self._can_send.set()
|
|
||||||
|
|
||||||
# External API
|
|
||||||
def default_framer(self):
|
|
||||||
"""Return a default framer."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def peer_address(self):
|
|
||||||
"""Returns the peer's address (Python networking address), or None if
|
|
||||||
no connection or an error.
|
|
||||||
|
|
||||||
This is the result of socket.getpeername() when the connection
|
|
||||||
was made.
|
|
||||||
"""
|
|
||||||
return self._address
|
|
||||||
|
|
||||||
def peer_address_str(self):
|
|
||||||
"""Returns the peer's IP address and port as a human-readable
|
|
||||||
string."""
|
|
||||||
if not self._address:
|
|
||||||
return 'unknown'
|
|
||||||
ip_addr_str, port = self._address[:2]
|
|
||||||
if ':' in ip_addr_str:
|
|
||||||
return f'[{ip_addr_str}]:{port}'
|
|
||||||
else:
|
|
||||||
return f'{ip_addr_str}:{port}'
|
|
||||||
|
|
||||||
def is_closing(self):
|
|
||||||
"""Return True if the connection is closing."""
|
|
||||||
return not self.transport or self.transport.is_closing()
|
|
||||||
|
|
||||||
def abort(self):
|
|
||||||
"""Forcefully close the connection."""
|
|
||||||
if self.transport:
|
|
||||||
self.transport.abort()
|
|
||||||
|
|
||||||
# TODO: replace with synchronous_close
|
|
||||||
async def close(self, *, force_after=30):
|
|
||||||
"""Close the connection and return when closed."""
|
|
||||||
self._close()
|
|
||||||
if self._pm_task:
|
|
||||||
with suppress(CancelledError):
|
|
||||||
await asyncio.wait([self._pm_task], timeout=force_after)
|
|
||||||
self.abort()
|
|
||||||
await self._pm_task
|
|
||||||
|
|
||||||
def synchronous_close(self):
|
|
||||||
self._close()
|
|
||||||
if self._pm_task and not self._pm_task.done():
|
|
||||||
self._pm_task.cancel()
|
|
||||||
|
|
||||||
|
|
||||||
class MessageSession(SessionBase):
|
|
||||||
"""Session class for protocols where messages are not tied to responses,
|
|
||||||
such as the Bitcoin protocol.
|
|
||||||
|
|
||||||
To use as a client (connection-opening) session, pass host, port
|
|
||||||
and perhaps a proxy.
|
|
||||||
"""
|
|
||||||
async def _receive_messages(self):
|
|
||||||
while not self.is_closing():
|
|
||||||
try:
|
|
||||||
message = await self.framer.receive_message()
|
|
||||||
except BadMagicError as e:
|
|
||||||
magic, expected = e.args
|
|
||||||
self.logger.error(
|
|
||||||
f'bad network magic: got {magic} expected {expected}, '
|
|
||||||
f'disconnecting'
|
|
||||||
)
|
|
||||||
self._close()
|
|
||||||
except OversizedPayloadError as e:
|
|
||||||
command, payload_len = e.args
|
|
||||||
self.logger.error(
|
|
||||||
f'oversized payload of {payload_len:,d} bytes to command '
|
|
||||||
f'{command}, disconnecting'
|
|
||||||
)
|
|
||||||
self._close()
|
|
||||||
except BadChecksumError as e:
|
|
||||||
payload_checksum, claimed_checksum = e.args
|
|
||||||
self.logger.warning(
|
|
||||||
f'checksum mismatch: actual {payload_checksum.hex()} '
|
|
||||||
f'vs claimed {claimed_checksum.hex()}'
|
|
||||||
)
|
|
||||||
self._bump_errors()
|
|
||||||
else:
|
|
||||||
self.last_recv = time.perf_counter()
|
|
||||||
self.recv_count += 1
|
|
||||||
await self._task_group.add(self._handle_message(message))
|
|
||||||
|
|
||||||
async def _handle_message(self, message):
|
|
||||||
try:
|
|
||||||
await self.handle_message(message)
|
|
||||||
except ProtocolError as e:
|
|
||||||
self.logger.error(f'{e}')
|
|
||||||
self._bump_errors()
|
|
||||||
except CancelledError:
|
|
||||||
raise
|
|
||||||
except Exception:
|
|
||||||
self.logger.exception(f'exception handling {message}')
|
|
||||||
self._bump_errors()
|
|
||||||
|
|
||||||
# External API
|
|
||||||
def default_framer(self):
|
|
||||||
"""Return a bitcoin framer."""
|
|
||||||
return BitcoinFramer(bytes.fromhex('e3e1f3e8'), 128_000_000)
|
|
||||||
|
|
||||||
async def handle_message(self, message):
|
|
||||||
"""message is a (command, payload) pair."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def send_message(self, message):
|
|
||||||
"""Send a message (command, payload) over the network."""
|
|
||||||
await self._send_message(message)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchError(Exception):
|
|
||||||
|
|
||||||
def __init__(self, request):
|
|
||||||
self.request = request # BatchRequest object
|
|
||||||
|
|
||||||
|
|
||||||
class BatchRequest:
|
|
||||||
"""Used to build a batch request to send to the server. Stores
|
|
||||||
the
|
|
||||||
|
|
||||||
Attributes batch and results are initially None.
|
|
||||||
|
|
||||||
Adding an invalid request or notification immediately raises a
|
|
||||||
ProtocolError.
|
|
||||||
|
|
||||||
On exiting the with clause, it will:
|
|
||||||
|
|
||||||
1) create a Batch object for the requests in the order they were
|
|
||||||
added. If the batch is empty this raises a ProtocolError.
|
|
||||||
|
|
||||||
2) set the "batch" attribute to be that batch
|
|
||||||
|
|
||||||
3) send the batch request and wait for a response
|
|
||||||
|
|
||||||
4) raise a ProtocolError if the protocol was violated by the
|
|
||||||
server. Currently this only happens if it gave more than one
|
|
||||||
response to any request
|
|
||||||
|
|
||||||
5) otherwise there is precisely one response to each Request. Set
|
|
||||||
the "results" attribute to the tuple of results; the responses
|
|
||||||
are ordered to match the Requests in the batch. Notifications
|
|
||||||
do not get a response.
|
|
||||||
|
|
||||||
6) if raise_errors is True and any individual response was a JSON
|
|
||||||
RPC error response, or violated the protocol in some way, a
|
|
||||||
BatchError exception is raised. Otherwise the caller can be
|
|
||||||
certain each request returned a standard result.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, session, raise_errors):
|
|
||||||
self._session = session
|
|
||||||
self._raise_errors = raise_errors
|
|
||||||
self._requests = []
|
|
||||||
self.batch = None
|
|
||||||
self.results = None
|
|
||||||
|
|
||||||
def add_request(self, method, args=()):
|
|
||||||
self._requests.append(Request(method, args))
|
|
||||||
|
|
||||||
def add_notification(self, method, args=()):
|
|
||||||
self._requests.append(Notification(method, args))
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self._requests)
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
||||||
if exc_type is None:
|
|
||||||
self.batch = Batch(self._requests)
|
|
||||||
message, event = self._session.connection.send_batch(self.batch)
|
|
||||||
await self._session._send_message(message)
|
|
||||||
await event.wait()
|
|
||||||
self.results = event.result
|
|
||||||
if self._raise_errors:
|
|
||||||
if any(isinstance(item, Exception) for item in event.result):
|
|
||||||
raise BatchError(self)
|
|
||||||
|
|
||||||
|
|
||||||
class RPCSession(SessionBase):
|
|
||||||
"""Base class for protocols where a message can lead to a response,
|
|
||||||
for example JSON RPC."""
|
|
||||||
|
|
||||||
def __init__(self, *, framer=None, loop=None, connection=None):
|
|
||||||
super().__init__(framer=framer, loop=loop)
|
|
||||||
self.connection = connection or self.default_connection()
|
|
||||||
self.client_version = 'unknown'
|
|
||||||
|
|
||||||
async def _receive_messages(self):
|
|
||||||
while not self.is_closing():
|
|
||||||
try:
|
|
||||||
message = await self.framer.receive_message()
|
|
||||||
except MemoryError:
|
|
||||||
self.logger.warning('received oversized message from %s:%s, dropping connection',
|
|
||||||
self._address[0], self._address[1])
|
|
||||||
RESET_CONNECTIONS.labels(version=self.client_version).inc()
|
|
||||||
self._close()
|
|
||||||
return
|
|
||||||
|
|
||||||
self.last_recv = time.perf_counter()
|
|
||||||
self.recv_count += 1
|
|
||||||
|
|
||||||
try:
|
|
||||||
requests = self.connection.receive_message(message)
|
|
||||||
except ProtocolError as e:
|
|
||||||
self.logger.debug(f'{e}')
|
|
||||||
if e.error_message:
|
|
||||||
await self._send_message(e.error_message)
|
|
||||||
if e.code == JSONRPC.PARSE_ERROR:
|
|
||||||
self.max_errors = 0
|
|
||||||
self._bump_errors()
|
|
||||||
else:
|
|
||||||
for request in requests:
|
|
||||||
await self._task_group.add(self._handle_request(request))
|
|
||||||
|
|
||||||
async def _handle_request(self, request):
|
|
||||||
start = time.perf_counter()
|
|
||||||
try:
|
|
||||||
result = await self.handle_request(request)
|
|
||||||
except (ProtocolError, RPCError) as e:
|
|
||||||
result = e
|
|
||||||
except CancelledError:
|
|
||||||
raise
|
|
||||||
except Exception:
|
|
||||||
self.logger.exception(f'exception handling {request}')
|
|
||||||
result = RPCError(JSONRPC.INTERNAL_ERROR,
|
|
||||||
'internal server error')
|
|
||||||
if isinstance(request, Request):
|
|
||||||
message = request.send_result(result)
|
|
||||||
RESPONSE_TIMES.labels(
|
|
||||||
method=request.method,
|
|
||||||
version=self.client_version
|
|
||||||
).observe(time.perf_counter() - start)
|
|
||||||
if message:
|
|
||||||
await self._send_message(message)
|
|
||||||
if isinstance(result, Exception):
|
|
||||||
self._bump_errors()
|
|
||||||
REQUEST_ERRORS_COUNT.labels(
|
|
||||||
method=request.method,
|
|
||||||
version=self.client_version
|
|
||||||
).inc()
|
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
|
||||||
# Cancel pending requests and message processing
|
|
||||||
self.connection.raise_pending_requests(exc)
|
|
||||||
super().connection_lost(exc)
|
|
||||||
|
|
||||||
# External API
|
|
||||||
def default_connection(self):
|
|
||||||
"""Return a default connection if the user provides none."""
|
|
||||||
return JSONRPCConnection(JSONRPCv2)
|
|
||||||
|
|
||||||
def default_framer(self):
|
|
||||||
"""Return a default framer."""
|
|
||||||
return NewlineFramer()
|
|
||||||
|
|
||||||
async def handle_request(self, request):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def send_request(self, method, args=()):
|
|
||||||
"""Send an RPC request over the network."""
|
|
||||||
if self.is_closing():
|
|
||||||
raise asyncio.TimeoutError("Trying to send request on a recently dropped connection.")
|
|
||||||
message, event = self.connection.send_request(Request(method, args))
|
|
||||||
await self._send_message(message)
|
|
||||||
await event.wait()
|
|
||||||
result = event.result
|
|
||||||
if isinstance(result, Exception):
|
|
||||||
raise result
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def send_notification(self, method, args=()):
|
|
||||||
"""Send an RPC notification over the network."""
|
|
||||||
message = self.connection.send_notification(Notification(method, args))
|
|
||||||
NOTIFICATION_COUNT.labels(method=method, version=self.client_version).inc()
|
|
||||||
await self._send_message(message)
|
|
||||||
|
|
||||||
def send_batch(self, raise_errors=False):
|
|
||||||
"""Return a BatchRequest. Intended to be used like so:
|
|
||||||
|
|
||||||
async with session.send_batch() as batch:
|
|
||||||
batch.add_request("method1")
|
|
||||||
batch.add_request("sum", (x, y))
|
|
||||||
batch.add_notification("updated")
|
|
||||||
|
|
||||||
for result in batch.results:
|
|
||||||
...
|
|
||||||
|
|
||||||
Note that in some circumstances exceptions can be raised; see
|
|
||||||
BatchRequest doc string.
|
|
||||||
"""
|
|
||||||
return BatchRequest(self, raise_errors)
|
|
||||||
|
|
||||||
|
|
||||||
class Server:
|
|
||||||
"""A simple wrapper around an asyncio.Server object."""
|
|
||||||
|
|
||||||
def __init__(self, session_factory, host=None, port=None, *,
|
|
||||||
loop=None, **kwargs):
|
|
||||||
self.host = host
|
|
||||||
self.port = port
|
|
||||||
self.loop = loop or asyncio.get_event_loop()
|
|
||||||
self.server = None
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._kwargs = kwargs
|
|
||||||
|
|
||||||
async def listen(self):
|
|
||||||
self.server = await self.loop.create_server(
|
|
||||||
self._session_factory, self.host, self.port, **self._kwargs)
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
"""Close the listening socket. This does not close any ServerSession
|
|
||||||
objects created to handle incoming connections.
|
|
||||||
"""
|
|
||||||
if self.server:
|
|
||||||
self.server.close()
|
|
||||||
await self.server.wait_closed()
|
|
||||||
self.server = None
|
|
|
@ -1,439 +0,0 @@
|
||||||
# Copyright (c) 2018, Neil Booth
|
|
||||||
#
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# The MIT License (MIT)
|
|
||||||
#
|
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining
|
|
||||||
# a copy of this software and associated documentation files (the
|
|
||||||
# "Software"), to deal in the Software without restriction, including
|
|
||||||
# without limitation the rights to use, copy, modify, merge, publish,
|
|
||||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
|
||||||
# permit persons to whom the Software is furnished to do so, subject to
|
|
||||||
# the following conditions:
|
|
||||||
#
|
|
||||||
# The above copyright notice and this permission notice shall be
|
|
||||||
# included in all copies or substantial portions of the Software.
|
|
||||||
#
|
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
||||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
||||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
|
||||||
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
|
||||||
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
|
||||||
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
|
||||||
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
||||||
|
|
||||||
"""SOCKS proxying."""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import asyncio
|
|
||||||
import collections
|
|
||||||
import ipaddress
|
|
||||||
import socket
|
|
||||||
import struct
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ('SOCKSUserAuth', 'SOCKS4', 'SOCKS4a', 'SOCKS5', 'SOCKSProxy',
|
|
||||||
'SOCKSError', 'SOCKSProtocolError', 'SOCKSFailure')
|
|
||||||
|
|
||||||
|
|
||||||
SOCKSUserAuth = collections.namedtuple("SOCKSUserAuth", "username password")
|
|
||||||
|
|
||||||
|
|
||||||
class SOCKSError(Exception):
|
|
||||||
"""Base class for SOCKS exceptions. Each raised exception will be
|
|
||||||
an instance of a derived class."""
|
|
||||||
|
|
||||||
|
|
||||||
class SOCKSProtocolError(SOCKSError):
|
|
||||||
"""Raised when the proxy does not follow the SOCKS protocol"""
|
|
||||||
|
|
||||||
|
|
||||||
class SOCKSFailure(SOCKSError):
|
|
||||||
"""Raised when the proxy refuses or fails to make a connection"""
|
|
||||||
|
|
||||||
|
|
||||||
class NeedData(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SOCKSBase:
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def name(cls):
|
|
||||||
return cls.__name__
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._buffer = bytes()
|
|
||||||
self._state = self._start
|
|
||||||
|
|
||||||
def _read(self, size):
|
|
||||||
if len(self._buffer) < size:
|
|
||||||
raise NeedData(size - len(self._buffer))
|
|
||||||
result = self._buffer[:size]
|
|
||||||
self._buffer = self._buffer[size:]
|
|
||||||
return result
|
|
||||||
|
|
||||||
def receive_data(self, data):
|
|
||||||
self._buffer += data
|
|
||||||
|
|
||||||
def next_message(self):
|
|
||||||
return self._state()
|
|
||||||
|
|
||||||
|
|
||||||
class SOCKS4(SOCKSBase):
|
|
||||||
"""SOCKS4 protocol wrapper."""
|
|
||||||
|
|
||||||
# See http://ftp.icm.edu.pl/packages/socks/socks4/SOCKS4.protocol
|
|
||||||
REPLY_CODES = {
|
|
||||||
90: 'request granted',
|
|
||||||
91: 'request rejected or failed',
|
|
||||||
92: ('request rejected because SOCKS server cannot connect '
|
|
||||||
'to identd on the client'),
|
|
||||||
93: ('request rejected because the client program and identd '
|
|
||||||
'report different user-ids')
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, dst_host, dst_port, auth):
|
|
||||||
super().__init__()
|
|
||||||
self._dst_host = self._check_host(dst_host)
|
|
||||||
self._dst_port = dst_port
|
|
||||||
self._auth = auth
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _check_host(cls, host):
|
|
||||||
if not isinstance(host, ipaddress.IPv4Address):
|
|
||||||
try:
|
|
||||||
host = ipaddress.IPv4Address(host)
|
|
||||||
except ValueError:
|
|
||||||
raise SOCKSProtocolError(
|
|
||||||
f'SOCKS4 requires an IPv4 address: {host}') from None
|
|
||||||
return host
|
|
||||||
|
|
||||||
def _start(self):
|
|
||||||
self._state = self._first_response
|
|
||||||
|
|
||||||
if isinstance(self._dst_host, ipaddress.IPv4Address):
|
|
||||||
# SOCKS4
|
|
||||||
dst_ip_packed = self._dst_host.packed
|
|
||||||
host_bytes = b''
|
|
||||||
else:
|
|
||||||
# SOCKS4a
|
|
||||||
dst_ip_packed = b'\0\0\0\1'
|
|
||||||
host_bytes = self._dst_host.encode() + b'\0'
|
|
||||||
|
|
||||||
if isinstance(self._auth, SOCKSUserAuth):
|
|
||||||
user_id = self._auth.username.encode()
|
|
||||||
else:
|
|
||||||
user_id = b''
|
|
||||||
|
|
||||||
# Send TCP/IP stream CONNECT request
|
|
||||||
return b''.join([b'\4\1', struct.pack('>H', self._dst_port),
|
|
||||||
dst_ip_packed, user_id, b'\0', host_bytes])
|
|
||||||
|
|
||||||
def _first_response(self):
|
|
||||||
# Wait for 8-byte response
|
|
||||||
data = self._read(8)
|
|
||||||
if data[0] != 0:
|
|
||||||
raise SOCKSProtocolError(f'invalid {self.name()} proxy '
|
|
||||||
f'response: {data}')
|
|
||||||
reply_code = data[1]
|
|
||||||
if reply_code != 90:
|
|
||||||
msg = self.REPLY_CODES.get(
|
|
||||||
reply_code, f'unknown {self.name()} reply code {reply_code}')
|
|
||||||
raise SOCKSFailure(f'{self.name()} proxy request failed: {msg}')
|
|
||||||
|
|
||||||
# Other fields ignored
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class SOCKS4a(SOCKS4):
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _check_host(cls, host):
|
|
||||||
if not isinstance(host, (str, ipaddress.IPv4Address)):
|
|
||||||
raise SOCKSProtocolError(
|
|
||||||
f'SOCKS4a requires an IPv4 address or host name: {host}')
|
|
||||||
return host
|
|
||||||
|
|
||||||
|
|
||||||
class SOCKS5(SOCKSBase):
|
|
||||||
"""SOCKS protocol wrapper."""
|
|
||||||
|
|
||||||
# See https://tools.ietf.org/html/rfc1928
|
|
||||||
ERROR_CODES = {
|
|
||||||
1: 'general SOCKS server failure',
|
|
||||||
2: 'connection not allowed by ruleset',
|
|
||||||
3: 'network unreachable',
|
|
||||||
4: 'host unreachable',
|
|
||||||
5: 'connection refused',
|
|
||||||
6: 'TTL expired',
|
|
||||||
7: 'command not supported',
|
|
||||||
8: 'address type not supported',
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, dst_host, dst_port, auth):
|
|
||||||
super().__init__()
|
|
||||||
self._dst_bytes = self._destination_bytes(dst_host, dst_port)
|
|
||||||
self._auth_bytes, self._auth_methods = self._authentication(auth)
|
|
||||||
|
|
||||||
def _destination_bytes(self, host, port):
|
|
||||||
if isinstance(host, ipaddress.IPv4Address):
|
|
||||||
addr_bytes = b'\1' + host.packed
|
|
||||||
elif isinstance(host, ipaddress.IPv6Address):
|
|
||||||
addr_bytes = b'\4' + host.packed
|
|
||||||
elif isinstance(host, str):
|
|
||||||
host = host.encode()
|
|
||||||
if len(host) > 255:
|
|
||||||
raise SOCKSProtocolError(f'hostname too long: '
|
|
||||||
f'{len(host)} bytes')
|
|
||||||
addr_bytes = b'\3' + bytes([len(host)]) + host
|
|
||||||
else:
|
|
||||||
raise SOCKSProtocolError(f'SOCKS5 requires an IPv4 address, IPv6 '
|
|
||||||
f'address, or host name: {host}')
|
|
||||||
return addr_bytes + struct.pack('>H', port)
|
|
||||||
|
|
||||||
def _authentication(self, auth):
|
|
||||||
if isinstance(auth, SOCKSUserAuth):
|
|
||||||
user_bytes = auth.username.encode()
|
|
||||||
if not 0 < len(user_bytes) < 256:
|
|
||||||
raise SOCKSProtocolError(f'username {auth.username} has '
|
|
||||||
f'invalid length {len(user_bytes)}')
|
|
||||||
pwd_bytes = auth.password.encode()
|
|
||||||
if not 0 < len(pwd_bytes) < 256:
|
|
||||||
raise SOCKSProtocolError(f'password has invalid length '
|
|
||||||
f'{len(pwd_bytes)}')
|
|
||||||
return b''.join([bytes([1, len(user_bytes)]), user_bytes,
|
|
||||||
bytes([len(pwd_bytes)]), pwd_bytes]), [0, 2]
|
|
||||||
return b'', [0]
|
|
||||||
|
|
||||||
def _start(self):
|
|
||||||
self._state = self._first_response
|
|
||||||
return (b'\5' + bytes([len(self._auth_methods)])
|
|
||||||
+ bytes(m for m in self._auth_methods))
|
|
||||||
|
|
||||||
def _first_response(self):
|
|
||||||
# Wait for 2-byte response
|
|
||||||
data = self._read(2)
|
|
||||||
if data[0] != 5:
|
|
||||||
raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}')
|
|
||||||
if data[1] not in self._auth_methods:
|
|
||||||
raise SOCKSFailure('SOCKS5 proxy rejected authentication methods')
|
|
||||||
|
|
||||||
# Authenticate if user-password authentication
|
|
||||||
if data[1] == 2:
|
|
||||||
self._state = self._auth_response
|
|
||||||
return self._auth_bytes
|
|
||||||
return self._request_connection()
|
|
||||||
|
|
||||||
def _auth_response(self):
|
|
||||||
data = self._read(2)
|
|
||||||
if data[0] != 1:
|
|
||||||
raise SOCKSProtocolError(f'invalid SOCKS5 proxy auth '
|
|
||||||
f'response: {data}')
|
|
||||||
if data[1] != 0:
|
|
||||||
raise SOCKSFailure(f'SOCKS5 proxy auth failure code: '
|
|
||||||
f'{data[1]}')
|
|
||||||
|
|
||||||
return self._request_connection()
|
|
||||||
|
|
||||||
def _request_connection(self):
|
|
||||||
# Send connection request
|
|
||||||
self._state = self._connect_response
|
|
||||||
return b'\5\1\0' + self._dst_bytes
|
|
||||||
|
|
||||||
def _connect_response(self):
|
|
||||||
data = self._read(5)
|
|
||||||
if data[0] != 5 or data[2] != 0 or data[3] not in (1, 3, 4):
|
|
||||||
raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}')
|
|
||||||
if data[1] != 0:
|
|
||||||
raise SOCKSFailure(self.ERROR_CODES.get(
|
|
||||||
data[1], f'unknown SOCKS5 error code: {data[1]}'))
|
|
||||||
|
|
||||||
if data[3] == 1:
|
|
||||||
addr_len = 3 # IPv4
|
|
||||||
elif data[3] == 3:
|
|
||||||
addr_len = data[4] # Hostname
|
|
||||||
else:
|
|
||||||
addr_len = 15 # IPv6
|
|
||||||
|
|
||||||
self._state = partial(self._connect_response_rest, addr_len)
|
|
||||||
return self.next_message()
|
|
||||||
|
|
||||||
def _connect_response_rest(self, addr_len):
|
|
||||||
self._read(addr_len + 2)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class SOCKSProxy:
|
|
||||||
|
|
||||||
def __init__(self, address, protocol, auth):
|
|
||||||
"""A SOCKS proxy at an address following a SOCKS protocol. auth is an
|
|
||||||
authentication method to use when connecting, or None.
|
|
||||||
|
|
||||||
address is a (host, port) pair; for IPv6 it can instead be a
|
|
||||||
(host, port, flowinfo, scopeid) 4-tuple.
|
|
||||||
"""
|
|
||||||
self.address = address
|
|
||||||
self.protocol = protocol
|
|
||||||
self.auth = auth
|
|
||||||
# Set on each successful connection via the proxy to the
|
|
||||||
# result of socket.getpeername()
|
|
||||||
self.peername = None
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
auth = 'username' if self.auth else 'none'
|
|
||||||
return f'{self.protocol.name()} proxy at {self.address}, auth: {auth}'
|
|
||||||
|
|
||||||
async def _handshake(self, client, sock, loop):
|
|
||||||
while True:
|
|
||||||
count = 0
|
|
||||||
try:
|
|
||||||
message = client.next_message()
|
|
||||||
except NeedData as e:
|
|
||||||
count = e.args[0]
|
|
||||||
else:
|
|
||||||
if message is None:
|
|
||||||
return
|
|
||||||
await loop.sock_sendall(sock, message)
|
|
||||||
|
|
||||||
if count:
|
|
||||||
data = await loop.sock_recv(sock, count)
|
|
||||||
if not data:
|
|
||||||
raise SOCKSProtocolError("EOF received")
|
|
||||||
client.receive_data(data)
|
|
||||||
|
|
||||||
async def _connect_one(self, host, port):
|
|
||||||
"""Connect to the proxy and perform a handshake requesting a
|
|
||||||
connection to (host, port).
|
|
||||||
|
|
||||||
Return the open socket on success, or the exception on failure.
|
|
||||||
"""
|
|
||||||
client = self.protocol(host, port, self.auth)
|
|
||||||
sock = socket.socket()
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
try:
|
|
||||||
# A non-blocking socket is required by loop socket methods
|
|
||||||
sock.setblocking(False)
|
|
||||||
await loop.sock_connect(sock, self.address)
|
|
||||||
await self._handshake(client, sock, loop)
|
|
||||||
self.peername = sock.getpeername()
|
|
||||||
return sock
|
|
||||||
except Exception as e:
|
|
||||||
# Don't close - see https://github.com/kyuupichan/aiorpcX/issues/8
|
|
||||||
if sys.platform.startswith('linux') or sys.platform == "darwin":
|
|
||||||
sock.close()
|
|
||||||
return e
|
|
||||||
|
|
||||||
async def _connect(self, addresses):
|
|
||||||
"""Connect to the proxy and perform a handshake requesting a
|
|
||||||
connection to each address in addresses.
|
|
||||||
|
|
||||||
Return an (open_socket, address) pair on success.
|
|
||||||
"""
|
|
||||||
assert len(addresses) > 0
|
|
||||||
|
|
||||||
exceptions = []
|
|
||||||
for address in addresses:
|
|
||||||
host, port = address[:2]
|
|
||||||
sock = await self._connect_one(host, port)
|
|
||||||
if isinstance(sock, socket.socket):
|
|
||||||
return sock, address
|
|
||||||
exceptions.append(sock)
|
|
||||||
|
|
||||||
strings = {f'{exc!r}' for exc in exceptions}
|
|
||||||
raise (exceptions[0] if len(strings) == 1 else
|
|
||||||
OSError(f'multiple exceptions: {", ".join(strings)}'))
|
|
||||||
|
|
||||||
async def _detect_proxy(self):
|
|
||||||
"""Return True if it appears we can connect to a SOCKS proxy,
|
|
||||||
otherwise False.
|
|
||||||
"""
|
|
||||||
if self.protocol is SOCKS4a:
|
|
||||||
host, port = 'www.apple.com', 80
|
|
||||||
else:
|
|
||||||
host, port = ipaddress.IPv4Address('8.8.8.8'), 53
|
|
||||||
|
|
||||||
sock = await self._connect_one(host, port)
|
|
||||||
if isinstance(sock, socket.socket):
|
|
||||||
sock.close()
|
|
||||||
return True
|
|
||||||
|
|
||||||
# SOCKSFailure indicates something failed, but that we are
|
|
||||||
# likely talking to a proxy
|
|
||||||
return isinstance(sock, SOCKSFailure)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def auto_detect_address(cls, address, auth):
|
|
||||||
"""Try to detect a SOCKS proxy at address using the authentication
|
|
||||||
method (or None). SOCKS5, SOCKS4a and SOCKS are tried in
|
|
||||||
order. If a SOCKS proxy is detected a SOCKSProxy object is
|
|
||||||
returned.
|
|
||||||
|
|
||||||
Returning a SOCKSProxy does not mean it is functioning - for
|
|
||||||
example, it may have no network connectivity.
|
|
||||||
|
|
||||||
If no proxy is detected return None.
|
|
||||||
"""
|
|
||||||
for protocol in (SOCKS5, SOCKS4a, SOCKS4):
|
|
||||||
proxy = cls(address, protocol, auth)
|
|
||||||
if await proxy._detect_proxy():
|
|
||||||
return proxy
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def auto_detect_host(cls, host, ports, auth):
|
|
||||||
"""Try to detect a SOCKS proxy on a host on one of the ports.
|
|
||||||
|
|
||||||
Calls auto_detect for the ports in order. Returns SOCKS are
|
|
||||||
tried in order; a SOCKSProxy object for the first detected
|
|
||||||
proxy is returned.
|
|
||||||
|
|
||||||
Returning a SOCKSProxy does not mean it is functioning - for
|
|
||||||
example, it may have no network connectivity.
|
|
||||||
|
|
||||||
If no proxy is detected return None.
|
|
||||||
"""
|
|
||||||
for port in ports:
|
|
||||||
address = (host, port)
|
|
||||||
proxy = await cls.auto_detect_address(address, auth)
|
|
||||||
if proxy:
|
|
||||||
return proxy
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def create_connection(self, protocol_factory, host, port, *,
|
|
||||||
resolve=False, ssl=None,
|
|
||||||
family=0, proto=0, flags=0):
|
|
||||||
"""Set up a connection to (host, port) through the proxy.
|
|
||||||
|
|
||||||
If resolve is True then host is resolved locally with
|
|
||||||
getaddrinfo using family, proto and flags, otherwise the proxy
|
|
||||||
is asked to resolve host.
|
|
||||||
|
|
||||||
The function signature is similar to loop.create_connection()
|
|
||||||
with the same result. The attribute _address is set on the
|
|
||||||
protocol to the address of the successful remote connection.
|
|
||||||
Additionally raises SOCKSError if something goes wrong with
|
|
||||||
the proxy handshake.
|
|
||||||
"""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
if resolve:
|
|
||||||
infos = await loop.getaddrinfo(host, port, family=family,
|
|
||||||
type=socket.SOCK_STREAM,
|
|
||||||
proto=proto, flags=flags)
|
|
||||||
addresses = [info[4] for info in infos]
|
|
||||||
else:
|
|
||||||
addresses = [(host, port)]
|
|
||||||
|
|
||||||
sock, address = await self._connect(addresses)
|
|
||||||
|
|
||||||
def set_address():
|
|
||||||
protocol = protocol_factory()
|
|
||||||
protocol._address = address
|
|
||||||
return protocol
|
|
||||||
|
|
||||||
return await loop.create_connection(
|
|
||||||
set_address, sock=sock, ssl=ssl,
|
|
||||||
server_hostname=host if ssl else None)
|
|
|
@ -1,95 +0,0 @@
|
||||||
# Copyright (c) 2018, Neil Booth
|
|
||||||
#
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# The MIT License (MIT)
|
|
||||||
#
|
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining
|
|
||||||
# a copy of this software and associated documentation files (the
|
|
||||||
# "Software"), to deal in the Software without restriction, including
|
|
||||||
# without limitation the rights to use, copy, modify, merge, publish,
|
|
||||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
|
||||||
# permit persons to whom the Software is furnished to do so, subject to
|
|
||||||
# the following conditions:
|
|
||||||
#
|
|
||||||
# The above copyright notice and this permission notice shall be
|
|
||||||
# included in all copies or substantial portions of the Software.
|
|
||||||
#
|
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
||||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
||||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
|
||||||
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
|
||||||
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
|
||||||
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
|
||||||
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
||||||
|
|
||||||
__all__ = ()
|
|
||||||
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from collections import namedtuple
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
# other_params: None means cannot be called with keyword arguments only
|
|
||||||
# any means any name is good
|
|
||||||
SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args '
|
|
||||||
'required_names other_names')
|
|
||||||
|
|
||||||
|
|
||||||
def signature_info(func):
|
|
||||||
params = inspect.signature(func).parameters
|
|
||||||
min_args = max_args = 0
|
|
||||||
required_names = []
|
|
||||||
other_names = []
|
|
||||||
no_names = False
|
|
||||||
for p in params.values():
|
|
||||||
if p.kind == p.POSITIONAL_OR_KEYWORD:
|
|
||||||
max_args += 1
|
|
||||||
if p.default is p.empty:
|
|
||||||
min_args += 1
|
|
||||||
required_names.append(p.name)
|
|
||||||
else:
|
|
||||||
other_names.append(p.name)
|
|
||||||
elif p.kind == p.KEYWORD_ONLY:
|
|
||||||
other_names.append(p.name)
|
|
||||||
elif p.kind == p.VAR_POSITIONAL:
|
|
||||||
max_args = None
|
|
||||||
elif p.kind == p.VAR_KEYWORD:
|
|
||||||
other_names = any
|
|
||||||
elif p.kind == p.POSITIONAL_ONLY:
|
|
||||||
max_args += 1
|
|
||||||
if p.default is p.empty:
|
|
||||||
min_args += 1
|
|
||||||
no_names = True
|
|
||||||
|
|
||||||
if no_names:
|
|
||||||
other_names = None
|
|
||||||
|
|
||||||
return SignatureInfo(min_args, max_args, required_names, other_names)
|
|
||||||
|
|
||||||
|
|
||||||
class Concurrency:
|
|
||||||
|
|
||||||
def __init__(self, max_concurrent):
|
|
||||||
self._require_non_negative(max_concurrent)
|
|
||||||
self._max_concurrent = max_concurrent
|
|
||||||
self.semaphore = asyncio.Semaphore(max_concurrent)
|
|
||||||
|
|
||||||
def _require_non_negative(self, value):
|
|
||||||
if not isinstance(value, int) or value < 0:
|
|
||||||
raise RuntimeError('concurrency must be a natural number')
|
|
||||||
|
|
||||||
@property
|
|
||||||
def max_concurrent(self):
|
|
||||||
return self._max_concurrent
|
|
||||||
|
|
||||||
async def set_max_concurrent(self, value):
|
|
||||||
self._require_non_negative(value)
|
|
||||||
diff = value - self._max_concurrent
|
|
||||||
self._max_concurrent = value
|
|
||||||
if diff >= 0:
|
|
||||||
for _ in range(diff):
|
|
||||||
self.semaphore.release()
|
|
||||||
else:
|
|
||||||
for _ in range(-diff):
|
|
||||||
await self.semaphore.acquire()
|
|
Loading…
Reference in a new issue