dropping custom wallet rpc implementation

This commit is contained in:
Lex Berezhny 2020-04-25 08:05:46 -04:00
parent ba154c799e
commit fe547f1b0e
6 changed files with 0 additions and 2101 deletions

View file

@ -1,11 +0,0 @@
from .framing import *
from .jsonrpc import *
from .socks import *
from .session import *
from .util import *
__all__ = (framing.__all__ +
jsonrpc.__all__ +
socks.__all__ +
session.__all__ +
util.__all__)

View file

@ -1,239 +0,0 @@
# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""RPC message framing in a byte stream."""
__all__ = ('FramerBase', 'NewlineFramer', 'BinaryFramer', 'BitcoinFramer',
'OversizedPayloadError', 'BadChecksumError', 'BadMagicError')
from hashlib import sha256 as _sha256
from struct import Struct
from asyncio import Queue
class FramerBase:
"""Abstract base class for a framer.
A framer breaks an incoming byte stream into protocol messages,
buffering if necessary. It also frames outgoing messages into
a byte stream.
"""
def frame(self, message):
"""Return the framed message."""
raise NotImplementedError
def received_bytes(self, data):
"""Pass incoming network bytes."""
raise NotImplementedError
async def receive_message(self):
"""Wait for a complete unframed message to arrive, and return it."""
raise NotImplementedError
class NewlineFramer(FramerBase):
"""A framer for a protocol where messages are separated by newlines."""
# The default max_size value is motivated by JSONRPC, where a
# normal request will be 250 bytes or less, and a reasonable
# batch may contain 4000 requests.
def __init__(self, max_size=250 * 4000):
"""max_size - an anti-DoS measure. If, after processing an incoming
message, buffered data would exceed max_size bytes, that
buffered data is dropped entirely and the framer waits for a
newline character to re-synchronize the stream.
"""
self.max_size = max_size
self.queue = Queue()
self.received_bytes = self.queue.put_nowait
self.synchronizing = False
self.residual = b''
def frame(self, message):
return message + b'\n'
async def receive_message(self):
parts = []
buffer_size = 0
while True:
part = self.residual
self.residual = b''
if not part:
part = await self.queue.get()
npos = part.find(b'\n')
if npos == -1:
parts.append(part)
buffer_size += len(part)
# Ignore over-sized messages; re-synchronize
if buffer_size <= self.max_size:
continue
self.synchronizing = True
raise MemoryError(f'dropping message over {self.max_size:,d} '
f'bytes and re-synchronizing')
tail, self.residual = part[:npos], part[npos + 1:]
if self.synchronizing:
self.synchronizing = False
return await self.receive_message()
else:
parts.append(tail)
return b''.join(parts)
class ByteQueue:
"""A producer-comsumer queue. Incoming network data is put as it
arrives, and the consumer calls an async method waiting for data of
a specific length."""
def __init__(self):
self.queue = Queue()
self.parts = []
self.parts_len = 0
self.put_nowait = self.queue.put_nowait
async def receive(self, size):
while self.parts_len < size:
part = await self.queue.get()
self.parts.append(part)
self.parts_len += len(part)
self.parts_len -= size
whole = b''.join(self.parts)
self.parts = [whole[size:]]
return whole[:size]
class BinaryFramer:
"""A framer for binary messaging protocols."""
def __init__(self):
self.byte_queue = ByteQueue()
self.message_queue = Queue()
self.received_bytes = self.byte_queue.put_nowait
def frame(self, message):
command, payload = message
return b''.join((
self._build_header(command, payload),
payload
))
async def receive_message(self):
command, payload_len, checksum = await self._receive_header()
payload = await self.byte_queue.receive(payload_len)
payload_checksum = self._checksum(payload)
if payload_checksum != checksum:
raise BadChecksumError(payload_checksum, checksum)
return command, payload
def _checksum(self, payload):
raise NotImplementedError
def _build_header(self, command, payload):
raise NotImplementedError
async def _receive_header(self):
raise NotImplementedError
# Helpers
struct_le_I = Struct('<I')
pack_le_uint32 = struct_le_I.pack
def sha256(x):
"""Simple wrapper of hashlib sha256."""
return _sha256(x).digest()
def double_sha256(x):
"""SHA-256 of SHA-256, as used extensively in bitcoin."""
return sha256(sha256(x))
class BadChecksumError(Exception):
pass
class BadMagicError(Exception):
pass
class OversizedPayloadError(Exception):
pass
class BitcoinFramer(BinaryFramer):
"""Provides a framer of binary message payloads in the style of the
Bitcoin network protocol.
Each binary message has the following elements, in order:
Magic - to confirm network (currently unused for stream sync)
Command - padded command
Length - payload length in bytes
Checksum - checksum of the payload
Payload - binary payload
Call frame(command, payload) to get a framed message.
Pass incoming network bytes to received_bytes().
Wait on receive_message() to get incoming (command, payload) pairs.
"""
def __init__(self, magic, max_block_size):
def pad_command(command):
fill = 12 - len(command)
if fill < 0:
raise ValueError(f'command {command} too long')
return command + bytes(fill)
super().__init__()
self._magic = magic
self._max_block_size = max_block_size
self._pad_command = pad_command
self._unpack = Struct(f'<4s12sI4s').unpack
def _checksum(self, payload):
return double_sha256(payload)[:4]
def _build_header(self, command, payload):
return b''.join((
self._magic,
self._pad_command(command),
pack_le_uint32(len(payload)),
self._checksum(payload)
))
async def _receive_header(self):
header = await self.byte_queue.receive(24)
magic, command, payload_len, checksum = self._unpack(header)
if magic != self._magic:
raise BadMagicError(magic, self._magic)
command = command.rstrip(b'\0')
if payload_len > 1024 * 1024:
if command != b'block' or payload_len > self._max_block_size:
raise OversizedPayloadError(command, payload_len)
return command, payload_len, checksum

View file

@ -1,804 +0,0 @@
# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""Classes for JSONRPC versions 1.0 and 2.0, and a loose interpretation."""
__all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
'JSONRPCAutoDetect', 'Request', 'Notification', 'Batch',
'RPCError', 'ProtocolError',
'JSONRPCConnection', 'handler_invocation')
import itertools
import json
import typing
import asyncio
from functools import partial
from numbers import Number
import attr
from asyncio import Queue, Event, CancelledError
from .util import signature_info
class SingleRequest:
__slots__ = ('method', 'args')
def __init__(self, method, args):
if not isinstance(method, str):
raise ProtocolError(JSONRPC.METHOD_NOT_FOUND,
'method must be a string')
if not isinstance(args, (list, tuple, dict)):
raise ProtocolError.invalid_args('request arguments must be a '
'list or a dictionary')
self.args = args
self.method = method
def __repr__(self):
return f'{self.__class__.__name__}({self.method!r}, {self.args!r})'
def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.method == other.method and self.args == other.args)
class Request(SingleRequest):
def send_result(self, response):
return None
class Notification(SingleRequest):
pass
class Batch:
__slots__ = ('items', )
def __init__(self, items):
if not isinstance(items, (list, tuple)):
raise ProtocolError.invalid_request('items must be a list')
if not items:
raise ProtocolError.empty_batch()
if not (all(isinstance(item, SingleRequest) for item in items) or
all(isinstance(item, Response) for item in items)):
raise ProtocolError.invalid_request('batch must be homogeneous')
self.items = items
def __len__(self):
return len(self.items)
def __getitem__(self, item):
return self.items[item]
def __iter__(self):
return iter(self.items)
def __repr__(self):
return f'Batch({len(self.items)} items)'
class Response:
__slots__ = ('result', )
def __init__(self, result):
# Type checking happens when converting to a message
self.result = result
class CodeMessageError(Exception):
def __init__(self, code, message):
super().__init__(code, message)
@property
def code(self):
return self.args[0]
@property
def message(self):
return self.args[1]
def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.code == other.code and self.message == other.message)
def __hash__(self):
# overridden to make the exception hashable
# see https://bugs.python.org/issue28603
return hash((self.code, self.message))
@classmethod
def invalid_args(cls, message):
return cls(JSONRPC.INVALID_ARGS, message)
@classmethod
def invalid_request(cls, message):
return cls(JSONRPC.INVALID_REQUEST, message)
@classmethod
def empty_batch(cls):
return cls.invalid_request('batch is empty')
class RPCError(CodeMessageError):
pass
class ProtocolError(CodeMessageError):
def __init__(self, code, message):
super().__init__(code, message)
# If not None send this unframed message over the network
self.error_message = None
# If the error was in a JSON response message; its message ID.
# Since None can be a response message ID, "id" means the
# error was not sent in a JSON response
self.response_msg_id = id
class JSONRPC:
"""Abstract base class that interprets and constructs JSON RPC messages."""
# Error codes. See http://www.jsonrpc.org/specification
PARSE_ERROR = -32700
INVALID_REQUEST = -32600
METHOD_NOT_FOUND = -32601
INVALID_ARGS = -32602
INTERNAL_ERROR = -32603
QUERY_TIMEOUT = -32000
# Codes specific to this library
ERROR_CODE_UNAVAILABLE = -100
# Can be overridden by derived classes
allow_batches = True
@classmethod
def _message_id(cls, message, require_id):
"""Validate the message is a dictionary and return its ID.
Raise an error if the message is invalid or the ID is of an
invalid type. If it has no ID, raise an error if require_id
is True, otherwise return None.
"""
raise NotImplementedError
@classmethod
def _validate_message(cls, message):
"""Validate other parts of the message other than those
done in _message_id."""
pass
@classmethod
def _request_args(cls, request):
"""Validate the existence and type of the arguments passed
in the request dictionary."""
raise NotImplementedError
@classmethod
def _process_request(cls, payload):
request_id = None
try:
request_id = cls._message_id(payload, False)
cls._validate_message(payload)
method = payload.get('method')
if request_id is None:
item = Notification(method, cls._request_args(payload))
else:
item = Request(method, cls._request_args(payload))
return item, request_id
except ProtocolError as error:
code, message = error.code, error.message
raise cls._error(code, message, True, request_id)
@classmethod
def _process_response(cls, payload):
request_id = None
try:
request_id = cls._message_id(payload, True)
cls._validate_message(payload)
return Response(cls.response_value(payload)), request_id
except ProtocolError as error:
code, message = error.code, error.message
raise cls._error(code, message, False, request_id)
@classmethod
def _message_to_payload(cls, message):
"""Returns a Python object or a ProtocolError."""
try:
return json.loads(message.decode())
except UnicodeDecodeError:
message = 'messages must be encoded in UTF-8'
except json.JSONDecodeError:
message = 'invalid JSON'
raise cls._error(cls.PARSE_ERROR, message, True, None)
@classmethod
def _error(cls, code, message, send, msg_id):
error = ProtocolError(code, message)
if send:
error.error_message = cls.response_message(error, msg_id)
else:
error.response_msg_id = msg_id
return error
#
# External API
#
@classmethod
def message_to_item(cls, message):
"""Translate an unframed received message and return an
(item, request_id) pair.
The item can be a Request, Notification, Response or a list.
A JSON RPC error response is returned as an RPCError inside a
Response object.
If a Batch is returned, request_id is an iterable of request
ids, one per batch member.
If the message violates the protocol in some way a
ProtocolError is returned, except if the message was
determined to be a response, in which case the ProtocolError
is placed inside a Response object. This is so that client
code can mark a request as having been responded to even if
the response was bad.
raises: ProtocolError
"""
payload = cls._message_to_payload(message)
if isinstance(payload, dict):
if 'method' in payload:
return cls._process_request(payload)
else:
return cls._process_response(payload)
elif isinstance(payload, list) and cls.allow_batches:
if not payload:
raise cls._error(JSONRPC.INVALID_REQUEST, 'batch is empty',
True, None)
return payload, None
raise cls._error(cls.INVALID_REQUEST,
'request object must be a dictionary', True, None)
# Message formation
@classmethod
def request_message(cls, item, request_id):
"""Convert an RPCRequest item to a message."""
assert isinstance(item, Request)
return cls.encode_payload(cls.request_payload(item, request_id))
@classmethod
def notification_message(cls, item):
"""Convert an RPCRequest item to a message."""
assert isinstance(item, Notification)
return cls.encode_payload(cls.request_payload(item, None))
@classmethod
def response_message(cls, result, request_id):
"""Convert a response result (or RPCError) to a message."""
if isinstance(result, CodeMessageError):
payload = cls.error_payload(result, request_id)
else:
payload = cls.response_payload(result, request_id)
return cls.encode_payload(payload)
@classmethod
def batch_message(cls, batch, request_ids):
"""Convert a request Batch to a message."""
assert isinstance(batch, Batch)
if not cls.allow_batches:
raise ProtocolError.invalid_request(
'protocol does not permit batches')
id_iter = iter(request_ids)
rm = cls.request_message
nm = cls.notification_message
parts = (rm(request, next(id_iter)) if isinstance(request, Request)
else nm(request) for request in batch)
return cls.batch_message_from_parts(parts)
@classmethod
def batch_message_from_parts(cls, messages):
"""Convert messages, one per batch item, into a batch message. At
least one message must be passed.
"""
# Comma-separate the messages and wrap the lot in square brackets
middle = b', '.join(messages)
if not middle:
raise ProtocolError.empty_batch()
return b''.join([b'[', middle, b']'])
@classmethod
def encode_payload(cls, payload):
"""Encode a Python object as JSON and convert it to bytes."""
try:
return json.dumps(payload).encode()
except TypeError:
msg = f'JSON payload encoding error: {payload}'
raise ProtocolError(cls.INTERNAL_ERROR, msg) from None
class JSONRPCv1(JSONRPC):
"""JSON RPC version 1.0."""
allow_batches = False
@classmethod
def _message_id(cls, message, require_id):
# JSONv1 requires an ID always, but without constraint on its type
# No need to test for a dictionary here as we don't handle batches.
if 'id' not in message:
raise ProtocolError.invalid_request('request has no "id"')
return message['id']
@classmethod
def _request_args(cls, request):
args = request.get('params')
if not isinstance(args, list):
raise ProtocolError.invalid_args(
f'invalid request arguments: {args}')
return args
@classmethod
def _best_effort_error(cls, error):
# Do our best to interpret the error
code = cls.ERROR_CODE_UNAVAILABLE
message = 'no error message provided'
if isinstance(error, str):
message = error
elif isinstance(error, int):
code = error
elif isinstance(error, dict):
if isinstance(error.get('message'), str):
message = error['message']
if isinstance(error.get('code'), int):
code = error['code']
return RPCError(code, message)
@classmethod
def response_value(cls, payload):
if 'result' not in payload or 'error' not in payload:
raise ProtocolError.invalid_request(
'response must contain both "result" and "error"')
result = payload['result']
error = payload['error']
if error is None:
return result # It seems None can be a valid result
if result is not None:
raise ProtocolError.invalid_request(
'response has a "result" and an "error"')
return cls._best_effort_error(error)
@classmethod
def request_payload(cls, request, request_id):
"""JSON v1 request (or notification) payload."""
if isinstance(request.args, dict):
raise ProtocolError.invalid_args(
'JSONRPCv1 does not support named arguments')
return {
'method': request.method,
'params': request.args,
'id': request_id
}
@classmethod
def response_payload(cls, result, request_id):
"""JSON v1 response payload."""
return {
'result': result,
'error': None,
'id': request_id
}
@classmethod
def error_payload(cls, error, request_id):
return {
'result': None,
'error': {'code': error.code, 'message': error.message},
'id': request_id
}
class JSONRPCv2(JSONRPC):
"""JSON RPC version 2.0."""
@classmethod
def _message_id(cls, message, require_id):
if not isinstance(message, dict):
raise ProtocolError.invalid_request(
'request object must be a dictionary')
if 'id' in message:
request_id = message['id']
if not isinstance(request_id, (Number, str, type(None))):
raise ProtocolError.invalid_request(
f'invalid "id": {request_id}')
return request_id
else:
if require_id:
raise ProtocolError.invalid_request('request has no "id"')
return None
@classmethod
def _validate_message(cls, message):
if message.get('jsonrpc') != '2.0':
raise ProtocolError.invalid_request('"jsonrpc" is not "2.0"')
@classmethod
def _request_args(cls, request):
args = request.get('params', [])
if not isinstance(args, (dict, list)):
raise ProtocolError.invalid_args(
f'invalid request arguments: {args}')
return args
@classmethod
def response_value(cls, payload):
if 'result' in payload:
if 'error' in payload:
raise ProtocolError.invalid_request(
'response contains both "result" and "error"')
return payload['result']
if 'error' not in payload:
raise ProtocolError.invalid_request(
'response contains neither "result" nor "error"')
# Return an RPCError object
error = payload['error']
if isinstance(error, dict):
code = error.get('code')
message = error.get('message')
if isinstance(code, int) and isinstance(message, str):
return RPCError(code, message)
raise ProtocolError.invalid_request(
f'ill-formed response error object: {error}')
@classmethod
def request_payload(cls, request, request_id):
"""JSON v2 request (or notification) payload."""
payload = {
'jsonrpc': '2.0',
'method': request.method,
}
# A notification?
if request_id is not None:
payload['id'] = request_id
# Preserve empty dicts as missing params is read as an array
if request.args or request.args == {}:
payload['params'] = request.args
return payload
@classmethod
def response_payload(cls, result, request_id):
"""JSON v2 response payload."""
return {
'jsonrpc': '2.0',
'result': result,
'id': request_id
}
@classmethod
def error_payload(cls, error, request_id):
return {
'jsonrpc': '2.0',
'error': {'code': error.code, 'message': error.message},
'id': request_id
}
class JSONRPCLoose(JSONRPC):
"""A relaxed version of JSON RPC."""
# Don't be so loose we accept any old message ID
_message_id = JSONRPCv2._message_id
_validate_message = JSONRPC._validate_message
_request_args = JSONRPCv2._request_args
# Outoing messages are JSONRPCv2 so we give the other side the
# best chance to assume / detect JSONRPCv2 as default protocol.
error_payload = JSONRPCv2.error_payload
request_payload = JSONRPCv2.request_payload
response_payload = JSONRPCv2.response_payload
@classmethod
def response_value(cls, payload):
# Return result, unless it is None and there is an error
if payload.get('error') is not None:
if payload.get('result') is not None:
raise ProtocolError.invalid_request(
'response contains both "result" and "error"')
return JSONRPCv1._best_effort_error(payload['error'])
if 'result' not in payload:
raise ProtocolError.invalid_request(
'response contains neither "result" nor "error"')
# Can be None
return payload['result']
class JSONRPCAutoDetect(JSONRPCv2):
@classmethod
def message_to_item(cls, message):
return cls.detect_protocol(message), None
@classmethod
def detect_protocol(cls, message):
"""Attempt to detect the protocol from the message."""
main = cls._message_to_payload(message)
def protocol_for_payload(payload):
if not isinstance(payload, dict):
return JSONRPCLoose # Will error
# Obey an explicit "jsonrpc"
version = payload.get('jsonrpc')
if version == '2.0':
return JSONRPCv2
if version == '1.0':
return JSONRPCv1
# Now to decide between JSONRPCLoose and JSONRPCv1 if possible
if 'result' in payload and 'error' in payload:
return JSONRPCv1
return JSONRPCLoose
if isinstance(main, list):
parts = {protocol_for_payload(payload) for payload in main}
# If all same protocol, return it
if len(parts) == 1:
return parts.pop()
# If strict protocol detected, return it, preferring JSONRPCv2.
# This means a batch of JSONRPCv1 will fail
for protocol in (JSONRPCv2, JSONRPCv1):
if protocol in parts:
return protocol
# Will error if no parts
return JSONRPCLoose
return protocol_for_payload(main)
class JSONRPCConnection:
"""Maintains state of a JSON RPC connection, in particular
encapsulating the handling of request IDs.
protocol - the JSON RPC protocol to follow
max_response_size - responses over this size send an error response
instead.
"""
_id_counter = itertools.count()
def __init__(self, protocol):
self._protocol = protocol
# Sent Requests and Batches that have not received a response.
# The key is its request ID; for a batch it is sorted tuple
# of request IDs
self._requests: typing.Dict[str, typing.Tuple[Request, Event]] = {}
# A public attribute intended to be settable dynamically
self.max_response_size = 0
def _oversized_response_message(self, request_id):
text = f'response too large (over {self.max_response_size:,d} bytes'
error = RPCError.invalid_request(text)
return self._protocol.response_message(error, request_id)
def _receive_response(self, result, request_id):
if request_id not in self._requests:
if request_id is None and isinstance(result, RPCError):
message = f'diagnostic error received: {result}'
else:
message = f'response to unsent request (ID: {request_id})'
raise ProtocolError.invalid_request(message) from None
request, event = self._requests.pop(request_id)
event.result = result
event.set()
return []
def _receive_request_batch(self, payloads):
def item_send_result(request_id, result):
nonlocal size
part = protocol.response_message(result, request_id)
size += len(part) + 2
if size > self.max_response_size > 0:
part = self._oversized_response_message(request_id)
parts.append(part)
if len(parts) == count:
return protocol.batch_message_from_parts(parts)
return None
parts = []
items = []
size = 0
count = 0
protocol = self._protocol
for payload in payloads:
try:
item, request_id = protocol._process_request(payload)
items.append(item)
if isinstance(item, Request):
count += 1
item.send_result = partial(item_send_result, request_id)
except ProtocolError as error:
count += 1
parts.append(error.error_message)
if not items and parts:
protocol_error = ProtocolError(0, "")
protocol_error.error_message = protocol.batch_message_from_parts(parts)
raise protocol_error
return items
def _receive_response_batch(self, payloads):
request_ids = []
results = []
for payload in payloads:
# Let ProtocolError exceptions through
item, request_id = self._protocol._process_response(payload)
request_ids.append(request_id)
results.append(item.result)
ordered = sorted(zip(request_ids, results), key=lambda t: t[0])
ordered_ids, ordered_results = zip(*ordered)
if ordered_ids not in self._requests:
raise ProtocolError.invalid_request('response to unsent batch')
request_batch, event = self._requests.pop(ordered_ids)
event.result = ordered_results
event.set()
return []
def _send_result(self, request_id, result):
message = self._protocol.response_message(result, request_id)
if len(message) > self.max_response_size > 0:
message = self._oversized_response_message(request_id)
return message
def _event(self, request, request_id):
event = Event()
self._requests[request_id] = (request, event)
return event
#
# External API
#
def send_request(self, request: Request) -> typing.Tuple[bytes, Event]:
"""Send a Request. Return a (message, event) pair.
The message is an unframed message to send over the network.
Wait on the event for the response; which will be in the
"result" attribute.
Raises: ProtocolError if the request violates the protocol
in some way..
"""
request_id = next(self._id_counter)
message = self._protocol.request_message(request, request_id)
return message, self._event(request, request_id)
def send_notification(self, notification):
return self._protocol.notification_message(notification)
def send_batch(self, batch):
ids = tuple(next(self._id_counter)
for request in batch if isinstance(request, Request))
message = self._protocol.batch_message(batch, ids)
event = self._event(batch, ids) if ids else None
return message, event
def receive_message(self, message):
"""Call with an unframed message received from the network.
Raises: ProtocolError if the message violates the protocol in
some way. However, if it happened in a response that can be
paired with a request, the ProtocolError is instead set in the
result attribute of the send_request() that caused the error.
"""
try:
item, request_id = self._protocol.message_to_item(message)
except ProtocolError as e:
if e.response_msg_id is not id:
return self._receive_response(e, e.response_msg_id)
raise
if isinstance(item, Request):
item.send_result = partial(self._send_result, request_id)
return [item]
if isinstance(item, Notification):
return [item]
if isinstance(item, Response):
return self._receive_response(item.result, request_id)
if isinstance(item, list):
if all(isinstance(payload, dict)
and ('result' in payload or 'error' in payload)
for payload in item):
return self._receive_response_batch(item)
else:
return self._receive_request_batch(item)
else:
# Protocol auto-detection hack
assert issubclass(item, JSONRPC)
self._protocol = item
return self.receive_message(message)
def raise_pending_requests(self, exception):
exception = exception or asyncio.TimeoutError()
for request, event in self._requests.values():
event.result = exception
event.set()
self._requests.clear()
def pending_requests(self):
"""All sent requests that have not received a response."""
return [request for request, event in self._requests.values()]
def handler_invocation(handler, request):
method, args = request.method, request.args
if handler is None:
raise RPCError(JSONRPC.METHOD_NOT_FOUND,
f'unknown method "{method}"')
# We must test for too few and too many arguments. How
# depends on whether the arguments were passed as a list or as
# a dictionary.
info = signature_info(handler)
if isinstance(args, (tuple, list)):
if len(args) < info.min_args:
s = '' if len(args) == 1 else 's'
raise RPCError.invalid_args(
f'{len(args)} argument{s} passed to method '
f'"{method}" but it requires {info.min_args}')
if info.max_args is not None and len(args) > info.max_args:
s = '' if len(args) == 1 else 's'
raise RPCError.invalid_args(
f'{len(args)} argument{s} passed to method '
f'{method} taking at most {info.max_args}')
return partial(handler, *args)
# Arguments passed by name
if info.other_names is None:
raise RPCError.invalid_args(f'method "{method}" cannot '
f'be called with named arguments')
missing = set(info.required_names).difference(args)
if missing:
s = '' if len(missing) == 1 else 's'
missing = ', '.join(sorted(f'"{name}"' for name in missing))
raise RPCError.invalid_args(f'method "{method}" requires '
f'parameter{s} {missing}')
if info.other_names is not any:
excess = set(args).difference(info.required_names)
excess = excess.difference(info.other_names)
if excess:
s = '' if len(excess) == 1 else 's'
excess = ', '.join(sorted(f'"{name}"' for name in excess))
raise RPCError.invalid_args(f'method "{method}" does not '
f'take parameter{s} {excess}')
return partial(handler, **args)

View file

@ -1,513 +0,0 @@
# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
__all__ = ('Connector', 'RPCSession', 'MessageSession', 'Server',
'BatchError')
import asyncio
from asyncio import Event, CancelledError
import logging
import time
from contextlib import suppress
from lbry.wallet.tasks import TaskGroup
from .jsonrpc import Request, JSONRPCConnection, JSONRPCv2, JSONRPC, Batch, Notification
from .jsonrpc import RPCError, ProtocolError
from .framing import BadMagicError, BadChecksumError, OversizedPayloadError, BitcoinFramer, NewlineFramer
from lbry.wallet.server.prometheus import NOTIFICATION_COUNT, RESPONSE_TIMES, REQUEST_ERRORS_COUNT, RESET_CONNECTIONS
class Connector:
def __init__(self, session_factory, host=None, port=None, proxy=None,
**kwargs):
self.session_factory = session_factory
self.host = host
self.port = port
self.proxy = proxy
self.loop = kwargs.get('loop', asyncio.get_event_loop())
self.kwargs = kwargs
async def create_connection(self):
"""Initiate a connection."""
connector = self.proxy or self.loop
return await connector.create_connection(
self.session_factory, self.host, self.port, **self.kwargs)
async def __aenter__(self):
transport, self.protocol = await self.create_connection()
return self.protocol
async def __aexit__(self, exc_type, exc_value, traceback):
await self.protocol.close()
class SessionBase(asyncio.Protocol):
"""Base class of networking sessions.
There is no client / server distinction other than who initiated
the connection.
To initiate a connection to a remote server pass host, port and
proxy to the constructor, and then call create_connection(). Each
successful call should have a corresponding call to close().
Alternatively if used in a with statement, the connection is made
on entry to the block, and closed on exit from the block.
"""
max_errors = 10
def __init__(self, *, framer=None, loop=None):
self.framer = framer or self.default_framer()
self.loop = loop or asyncio.get_event_loop()
self.logger = logging.getLogger(self.__class__.__name__)
self.transport = None
# Set when a connection is made
self._address = None
self._proxy_address = None
# For logger.debug messages
self.verbosity = 0
# Cleared when the send socket is full
self._can_send = Event()
self._can_send.set()
self._pm_task = None
self._task_group = TaskGroup(self.loop)
# Force-close a connection if a send doesn't succeed in this time
self.max_send_delay = 60
# Statistics. The RPC object also keeps its own statistics.
self.start_time = time.perf_counter()
self.errors = 0
self.send_count = 0
self.send_size = 0
self.last_send = self.start_time
self.recv_count = 0
self.recv_size = 0
self.last_recv = self.start_time
self.last_packet_received = self.start_time
async def _limited_wait(self, secs):
try:
await asyncio.wait_for(self._can_send.wait(), secs)
except asyncio.TimeoutError:
self.abort()
raise asyncio.TimeoutError(f'task timed out after {secs}s')
async def _send_message(self, message):
if not self._can_send.is_set():
await self._limited_wait(self.max_send_delay)
if not self.is_closing():
framed_message = self.framer.frame(message)
self.send_size += len(framed_message)
self.send_count += 1
self.last_send = time.perf_counter()
if self.verbosity >= 4:
self.logger.debug(f'Sending framed message {framed_message}')
self.transport.write(framed_message)
def _bump_errors(self):
self.errors += 1
if self.errors >= self.max_errors:
# Don't await self.close() because that is self-cancelling
self._close()
def _close(self):
if self.transport:
self.transport.close()
# asyncio framework
def data_received(self, framed_message):
"""Called by asyncio when a message comes in."""
self.last_packet_received = time.perf_counter()
if self.verbosity >= 4:
self.logger.debug(f'Received framed message {framed_message}')
self.recv_size += len(framed_message)
self.framer.received_bytes(framed_message)
def pause_writing(self):
"""Transport calls when the send buffer is full."""
if not self.is_closing():
self._can_send.clear()
self.transport.pause_reading()
def resume_writing(self):
"""Transport calls when the send buffer has room."""
if not self._can_send.is_set():
self._can_send.set()
self.transport.resume_reading()
def connection_made(self, transport):
"""Called by asyncio when a connection is established.
Derived classes overriding this method must call this first."""
self.transport = transport
# This would throw if called on a closed SSL transport. Fixed
# in asyncio in Python 3.6.1 and 3.5.4
peer_address = transport.get_extra_info('peername')
# If the Socks proxy was used then _address is already set to
# the remote address
if self._address:
self._proxy_address = peer_address
else:
self._address = peer_address
self._pm_task = self.loop.create_task(self._receive_messages())
def connection_lost(self, exc):
"""Called by asyncio when the connection closes.
Tear down things done in connection_made."""
self._address = None
self.transport = None
self._task_group.cancel()
if self._pm_task:
self._pm_task.cancel()
# Release waiting tasks
self._can_send.set()
# External API
def default_framer(self):
"""Return a default framer."""
raise NotImplementedError
def peer_address(self):
"""Returns the peer's address (Python networking address), or None if
no connection or an error.
This is the result of socket.getpeername() when the connection
was made.
"""
return self._address
def peer_address_str(self):
"""Returns the peer's IP address and port as a human-readable
string."""
if not self._address:
return 'unknown'
ip_addr_str, port = self._address[:2]
if ':' in ip_addr_str:
return f'[{ip_addr_str}]:{port}'
else:
return f'{ip_addr_str}:{port}'
def is_closing(self):
"""Return True if the connection is closing."""
return not self.transport or self.transport.is_closing()
def abort(self):
"""Forcefully close the connection."""
if self.transport:
self.transport.abort()
# TODO: replace with synchronous_close
async def close(self, *, force_after=30):
"""Close the connection and return when closed."""
self._close()
if self._pm_task:
with suppress(CancelledError):
await asyncio.wait([self._pm_task], timeout=force_after)
self.abort()
await self._pm_task
def synchronous_close(self):
self._close()
if self._pm_task and not self._pm_task.done():
self._pm_task.cancel()
class MessageSession(SessionBase):
"""Session class for protocols where messages are not tied to responses,
such as the Bitcoin protocol.
To use as a client (connection-opening) session, pass host, port
and perhaps a proxy.
"""
async def _receive_messages(self):
while not self.is_closing():
try:
message = await self.framer.receive_message()
except BadMagicError as e:
magic, expected = e.args
self.logger.error(
f'bad network magic: got {magic} expected {expected}, '
f'disconnecting'
)
self._close()
except OversizedPayloadError as e:
command, payload_len = e.args
self.logger.error(
f'oversized payload of {payload_len:,d} bytes to command '
f'{command}, disconnecting'
)
self._close()
except BadChecksumError as e:
payload_checksum, claimed_checksum = e.args
self.logger.warning(
f'checksum mismatch: actual {payload_checksum.hex()} '
f'vs claimed {claimed_checksum.hex()}'
)
self._bump_errors()
else:
self.last_recv = time.perf_counter()
self.recv_count += 1
await self._task_group.add(self._handle_message(message))
async def _handle_message(self, message):
try:
await self.handle_message(message)
except ProtocolError as e:
self.logger.error(f'{e}')
self._bump_errors()
except CancelledError:
raise
except Exception:
self.logger.exception(f'exception handling {message}')
self._bump_errors()
# External API
def default_framer(self):
"""Return a bitcoin framer."""
return BitcoinFramer(bytes.fromhex('e3e1f3e8'), 128_000_000)
async def handle_message(self, message):
"""message is a (command, payload) pair."""
pass
async def send_message(self, message):
"""Send a message (command, payload) over the network."""
await self._send_message(message)
class BatchError(Exception):
def __init__(self, request):
self.request = request # BatchRequest object
class BatchRequest:
"""Used to build a batch request to send to the server. Stores
the
Attributes batch and results are initially None.
Adding an invalid request or notification immediately raises a
ProtocolError.
On exiting the with clause, it will:
1) create a Batch object for the requests in the order they were
added. If the batch is empty this raises a ProtocolError.
2) set the "batch" attribute to be that batch
3) send the batch request and wait for a response
4) raise a ProtocolError if the protocol was violated by the
server. Currently this only happens if it gave more than one
response to any request
5) otherwise there is precisely one response to each Request. Set
the "results" attribute to the tuple of results; the responses
are ordered to match the Requests in the batch. Notifications
do not get a response.
6) if raise_errors is True and any individual response was a JSON
RPC error response, or violated the protocol in some way, a
BatchError exception is raised. Otherwise the caller can be
certain each request returned a standard result.
"""
def __init__(self, session, raise_errors):
self._session = session
self._raise_errors = raise_errors
self._requests = []
self.batch = None
self.results = None
def add_request(self, method, args=()):
self._requests.append(Request(method, args))
def add_notification(self, method, args=()):
self._requests.append(Notification(method, args))
def __len__(self):
return len(self._requests)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
if exc_type is None:
self.batch = Batch(self._requests)
message, event = self._session.connection.send_batch(self.batch)
await self._session._send_message(message)
await event.wait()
self.results = event.result
if self._raise_errors:
if any(isinstance(item, Exception) for item in event.result):
raise BatchError(self)
class RPCSession(SessionBase):
"""Base class for protocols where a message can lead to a response,
for example JSON RPC."""
def __init__(self, *, framer=None, loop=None, connection=None):
super().__init__(framer=framer, loop=loop)
self.connection = connection or self.default_connection()
self.client_version = 'unknown'
async def _receive_messages(self):
while not self.is_closing():
try:
message = await self.framer.receive_message()
except MemoryError:
self.logger.warning('received oversized message from %s:%s, dropping connection',
self._address[0], self._address[1])
RESET_CONNECTIONS.labels(version=self.client_version).inc()
self._close()
return
self.last_recv = time.perf_counter()
self.recv_count += 1
try:
requests = self.connection.receive_message(message)
except ProtocolError as e:
self.logger.debug(f'{e}')
if e.error_message:
await self._send_message(e.error_message)
if e.code == JSONRPC.PARSE_ERROR:
self.max_errors = 0
self._bump_errors()
else:
for request in requests:
await self._task_group.add(self._handle_request(request))
async def _handle_request(self, request):
start = time.perf_counter()
try:
result = await self.handle_request(request)
except (ProtocolError, RPCError) as e:
result = e
except CancelledError:
raise
except Exception:
self.logger.exception(f'exception handling {request}')
result = RPCError(JSONRPC.INTERNAL_ERROR,
'internal server error')
if isinstance(request, Request):
message = request.send_result(result)
RESPONSE_TIMES.labels(
method=request.method,
version=self.client_version
).observe(time.perf_counter() - start)
if message:
await self._send_message(message)
if isinstance(result, Exception):
self._bump_errors()
REQUEST_ERRORS_COUNT.labels(
method=request.method,
version=self.client_version
).inc()
def connection_lost(self, exc):
# Cancel pending requests and message processing
self.connection.raise_pending_requests(exc)
super().connection_lost(exc)
# External API
def default_connection(self):
"""Return a default connection if the user provides none."""
return JSONRPCConnection(JSONRPCv2)
def default_framer(self):
"""Return a default framer."""
return NewlineFramer()
async def handle_request(self, request):
pass
async def send_request(self, method, args=()):
"""Send an RPC request over the network."""
if self.is_closing():
raise asyncio.TimeoutError("Trying to send request on a recently dropped connection.")
message, event = self.connection.send_request(Request(method, args))
await self._send_message(message)
await event.wait()
result = event.result
if isinstance(result, Exception):
raise result
return result
async def send_notification(self, method, args=()):
"""Send an RPC notification over the network."""
message = self.connection.send_notification(Notification(method, args))
NOTIFICATION_COUNT.labels(method=method, version=self.client_version).inc()
await self._send_message(message)
def send_batch(self, raise_errors=False):
"""Return a BatchRequest. Intended to be used like so:
async with session.send_batch() as batch:
batch.add_request("method1")
batch.add_request("sum", (x, y))
batch.add_notification("updated")
for result in batch.results:
...
Note that in some circumstances exceptions can be raised; see
BatchRequest doc string.
"""
return BatchRequest(self, raise_errors)
class Server:
"""A simple wrapper around an asyncio.Server object."""
def __init__(self, session_factory, host=None, port=None, *,
loop=None, **kwargs):
self.host = host
self.port = port
self.loop = loop or asyncio.get_event_loop()
self.server = None
self._session_factory = session_factory
self._kwargs = kwargs
async def listen(self):
self.server = await self.loop.create_server(
self._session_factory, self.host, self.port, **self._kwargs)
async def close(self):
"""Close the listening socket. This does not close any ServerSession
objects created to handle incoming connections.
"""
if self.server:
self.server.close()
await self.server.wait_closed()
self.server = None

View file

@ -1,439 +0,0 @@
# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""SOCKS proxying."""
import sys
import asyncio
import collections
import ipaddress
import socket
import struct
from functools import partial
__all__ = ('SOCKSUserAuth', 'SOCKS4', 'SOCKS4a', 'SOCKS5', 'SOCKSProxy',
'SOCKSError', 'SOCKSProtocolError', 'SOCKSFailure')
SOCKSUserAuth = collections.namedtuple("SOCKSUserAuth", "username password")
class SOCKSError(Exception):
"""Base class for SOCKS exceptions. Each raised exception will be
an instance of a derived class."""
class SOCKSProtocolError(SOCKSError):
"""Raised when the proxy does not follow the SOCKS protocol"""
class SOCKSFailure(SOCKSError):
"""Raised when the proxy refuses or fails to make a connection"""
class NeedData(Exception):
pass
class SOCKSBase:
@classmethod
def name(cls):
return cls.__name__
def __init__(self):
self._buffer = bytes()
self._state = self._start
def _read(self, size):
if len(self._buffer) < size:
raise NeedData(size - len(self._buffer))
result = self._buffer[:size]
self._buffer = self._buffer[size:]
return result
def receive_data(self, data):
self._buffer += data
def next_message(self):
return self._state()
class SOCKS4(SOCKSBase):
"""SOCKS4 protocol wrapper."""
# See http://ftp.icm.edu.pl/packages/socks/socks4/SOCKS4.protocol
REPLY_CODES = {
90: 'request granted',
91: 'request rejected or failed',
92: ('request rejected because SOCKS server cannot connect '
'to identd on the client'),
93: ('request rejected because the client program and identd '
'report different user-ids')
}
def __init__(self, dst_host, dst_port, auth):
super().__init__()
self._dst_host = self._check_host(dst_host)
self._dst_port = dst_port
self._auth = auth
@classmethod
def _check_host(cls, host):
if not isinstance(host, ipaddress.IPv4Address):
try:
host = ipaddress.IPv4Address(host)
except ValueError:
raise SOCKSProtocolError(
f'SOCKS4 requires an IPv4 address: {host}') from None
return host
def _start(self):
self._state = self._first_response
if isinstance(self._dst_host, ipaddress.IPv4Address):
# SOCKS4
dst_ip_packed = self._dst_host.packed
host_bytes = b''
else:
# SOCKS4a
dst_ip_packed = b'\0\0\0\1'
host_bytes = self._dst_host.encode() + b'\0'
if isinstance(self._auth, SOCKSUserAuth):
user_id = self._auth.username.encode()
else:
user_id = b''
# Send TCP/IP stream CONNECT request
return b''.join([b'\4\1', struct.pack('>H', self._dst_port),
dst_ip_packed, user_id, b'\0', host_bytes])
def _first_response(self):
# Wait for 8-byte response
data = self._read(8)
if data[0] != 0:
raise SOCKSProtocolError(f'invalid {self.name()} proxy '
f'response: {data}')
reply_code = data[1]
if reply_code != 90:
msg = self.REPLY_CODES.get(
reply_code, f'unknown {self.name()} reply code {reply_code}')
raise SOCKSFailure(f'{self.name()} proxy request failed: {msg}')
# Other fields ignored
return None
class SOCKS4a(SOCKS4):
@classmethod
def _check_host(cls, host):
if not isinstance(host, (str, ipaddress.IPv4Address)):
raise SOCKSProtocolError(
f'SOCKS4a requires an IPv4 address or host name: {host}')
return host
class SOCKS5(SOCKSBase):
"""SOCKS protocol wrapper."""
# See https://tools.ietf.org/html/rfc1928
ERROR_CODES = {
1: 'general SOCKS server failure',
2: 'connection not allowed by ruleset',
3: 'network unreachable',
4: 'host unreachable',
5: 'connection refused',
6: 'TTL expired',
7: 'command not supported',
8: 'address type not supported',
}
def __init__(self, dst_host, dst_port, auth):
super().__init__()
self._dst_bytes = self._destination_bytes(dst_host, dst_port)
self._auth_bytes, self._auth_methods = self._authentication(auth)
def _destination_bytes(self, host, port):
if isinstance(host, ipaddress.IPv4Address):
addr_bytes = b'\1' + host.packed
elif isinstance(host, ipaddress.IPv6Address):
addr_bytes = b'\4' + host.packed
elif isinstance(host, str):
host = host.encode()
if len(host) > 255:
raise SOCKSProtocolError(f'hostname too long: '
f'{len(host)} bytes')
addr_bytes = b'\3' + bytes([len(host)]) + host
else:
raise SOCKSProtocolError(f'SOCKS5 requires an IPv4 address, IPv6 '
f'address, or host name: {host}')
return addr_bytes + struct.pack('>H', port)
def _authentication(self, auth):
if isinstance(auth, SOCKSUserAuth):
user_bytes = auth.username.encode()
if not 0 < len(user_bytes) < 256:
raise SOCKSProtocolError(f'username {auth.username} has '
f'invalid length {len(user_bytes)}')
pwd_bytes = auth.password.encode()
if not 0 < len(pwd_bytes) < 256:
raise SOCKSProtocolError(f'password has invalid length '
f'{len(pwd_bytes)}')
return b''.join([bytes([1, len(user_bytes)]), user_bytes,
bytes([len(pwd_bytes)]), pwd_bytes]), [0, 2]
return b'', [0]
def _start(self):
self._state = self._first_response
return (b'\5' + bytes([len(self._auth_methods)])
+ bytes(m for m in self._auth_methods))
def _first_response(self):
# Wait for 2-byte response
data = self._read(2)
if data[0] != 5:
raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}')
if data[1] not in self._auth_methods:
raise SOCKSFailure('SOCKS5 proxy rejected authentication methods')
# Authenticate if user-password authentication
if data[1] == 2:
self._state = self._auth_response
return self._auth_bytes
return self._request_connection()
def _auth_response(self):
data = self._read(2)
if data[0] != 1:
raise SOCKSProtocolError(f'invalid SOCKS5 proxy auth '
f'response: {data}')
if data[1] != 0:
raise SOCKSFailure(f'SOCKS5 proxy auth failure code: '
f'{data[1]}')
return self._request_connection()
def _request_connection(self):
# Send connection request
self._state = self._connect_response
return b'\5\1\0' + self._dst_bytes
def _connect_response(self):
data = self._read(5)
if data[0] != 5 or data[2] != 0 or data[3] not in (1, 3, 4):
raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}')
if data[1] != 0:
raise SOCKSFailure(self.ERROR_CODES.get(
data[1], f'unknown SOCKS5 error code: {data[1]}'))
if data[3] == 1:
addr_len = 3 # IPv4
elif data[3] == 3:
addr_len = data[4] # Hostname
else:
addr_len = 15 # IPv6
self._state = partial(self._connect_response_rest, addr_len)
return self.next_message()
def _connect_response_rest(self, addr_len):
self._read(addr_len + 2)
return None
class SOCKSProxy:
def __init__(self, address, protocol, auth):
"""A SOCKS proxy at an address following a SOCKS protocol. auth is an
authentication method to use when connecting, or None.
address is a (host, port) pair; for IPv6 it can instead be a
(host, port, flowinfo, scopeid) 4-tuple.
"""
self.address = address
self.protocol = protocol
self.auth = auth
# Set on each successful connection via the proxy to the
# result of socket.getpeername()
self.peername = None
def __str__(self):
auth = 'username' if self.auth else 'none'
return f'{self.protocol.name()} proxy at {self.address}, auth: {auth}'
async def _handshake(self, client, sock, loop):
while True:
count = 0
try:
message = client.next_message()
except NeedData as e:
count = e.args[0]
else:
if message is None:
return
await loop.sock_sendall(sock, message)
if count:
data = await loop.sock_recv(sock, count)
if not data:
raise SOCKSProtocolError("EOF received")
client.receive_data(data)
async def _connect_one(self, host, port):
"""Connect to the proxy and perform a handshake requesting a
connection to (host, port).
Return the open socket on success, or the exception on failure.
"""
client = self.protocol(host, port, self.auth)
sock = socket.socket()
loop = asyncio.get_event_loop()
try:
# A non-blocking socket is required by loop socket methods
sock.setblocking(False)
await loop.sock_connect(sock, self.address)
await self._handshake(client, sock, loop)
self.peername = sock.getpeername()
return sock
except Exception as e:
# Don't close - see https://github.com/kyuupichan/aiorpcX/issues/8
if sys.platform.startswith('linux') or sys.platform == "darwin":
sock.close()
return e
async def _connect(self, addresses):
"""Connect to the proxy and perform a handshake requesting a
connection to each address in addresses.
Return an (open_socket, address) pair on success.
"""
assert len(addresses) > 0
exceptions = []
for address in addresses:
host, port = address[:2]
sock = await self._connect_one(host, port)
if isinstance(sock, socket.socket):
return sock, address
exceptions.append(sock)
strings = {f'{exc!r}' for exc in exceptions}
raise (exceptions[0] if len(strings) == 1 else
OSError(f'multiple exceptions: {", ".join(strings)}'))
async def _detect_proxy(self):
"""Return True if it appears we can connect to a SOCKS proxy,
otherwise False.
"""
if self.protocol is SOCKS4a:
host, port = 'www.apple.com', 80
else:
host, port = ipaddress.IPv4Address('8.8.8.8'), 53
sock = await self._connect_one(host, port)
if isinstance(sock, socket.socket):
sock.close()
return True
# SOCKSFailure indicates something failed, but that we are
# likely talking to a proxy
return isinstance(sock, SOCKSFailure)
@classmethod
async def auto_detect_address(cls, address, auth):
"""Try to detect a SOCKS proxy at address using the authentication
method (or None). SOCKS5, SOCKS4a and SOCKS are tried in
order. If a SOCKS proxy is detected a SOCKSProxy object is
returned.
Returning a SOCKSProxy does not mean it is functioning - for
example, it may have no network connectivity.
If no proxy is detected return None.
"""
for protocol in (SOCKS5, SOCKS4a, SOCKS4):
proxy = cls(address, protocol, auth)
if await proxy._detect_proxy():
return proxy
return None
@classmethod
async def auto_detect_host(cls, host, ports, auth):
"""Try to detect a SOCKS proxy on a host on one of the ports.
Calls auto_detect for the ports in order. Returns SOCKS are
tried in order; a SOCKSProxy object for the first detected
proxy is returned.
Returning a SOCKSProxy does not mean it is functioning - for
example, it may have no network connectivity.
If no proxy is detected return None.
"""
for port in ports:
address = (host, port)
proxy = await cls.auto_detect_address(address, auth)
if proxy:
return proxy
return None
async def create_connection(self, protocol_factory, host, port, *,
resolve=False, ssl=None,
family=0, proto=0, flags=0):
"""Set up a connection to (host, port) through the proxy.
If resolve is True then host is resolved locally with
getaddrinfo using family, proto and flags, otherwise the proxy
is asked to resolve host.
The function signature is similar to loop.create_connection()
with the same result. The attribute _address is set on the
protocol to the address of the successful remote connection.
Additionally raises SOCKSError if something goes wrong with
the proxy handshake.
"""
loop = asyncio.get_event_loop()
if resolve:
infos = await loop.getaddrinfo(host, port, family=family,
type=socket.SOCK_STREAM,
proto=proto, flags=flags)
addresses = [info[4] for info in infos]
else:
addresses = [(host, port)]
sock, address = await self._connect(addresses)
def set_address():
protocol = protocol_factory()
protocol._address = address
return protocol
return await loop.create_connection(
set_address, sock=sock, ssl=ssl,
server_hostname=host if ssl else None)

View file

@ -1,95 +0,0 @@
# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
__all__ = ()
import asyncio
from collections import namedtuple
import inspect
# other_params: None means cannot be called with keyword arguments only
# any means any name is good
SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args '
'required_names other_names')
def signature_info(func):
params = inspect.signature(func).parameters
min_args = max_args = 0
required_names = []
other_names = []
no_names = False
for p in params.values():
if p.kind == p.POSITIONAL_OR_KEYWORD:
max_args += 1
if p.default is p.empty:
min_args += 1
required_names.append(p.name)
else:
other_names.append(p.name)
elif p.kind == p.KEYWORD_ONLY:
other_names.append(p.name)
elif p.kind == p.VAR_POSITIONAL:
max_args = None
elif p.kind == p.VAR_KEYWORD:
other_names = any
elif p.kind == p.POSITIONAL_ONLY:
max_args += 1
if p.default is p.empty:
min_args += 1
no_names = True
if no_names:
other_names = None
return SignatureInfo(min_args, max_args, required_names, other_names)
class Concurrency:
def __init__(self, max_concurrent):
self._require_non_negative(max_concurrent)
self._max_concurrent = max_concurrent
self.semaphore = asyncio.Semaphore(max_concurrent)
def _require_non_negative(self, value):
if not isinstance(value, int) or value < 0:
raise RuntimeError('concurrency must be a natural number')
@property
def max_concurrent(self):
return self._max_concurrent
async def set_max_concurrent(self, value):
self._require_non_negative(value)
diff = value - self._max_concurrent
self._max_concurrent = value
if diff >= 0:
for _ in range(diff):
self.semaphore.release()
else:
for _ in range(-diff):
await self.semaphore.acquire()