diff --git a/lbry/wallet/rpc/__init__.py b/lbry/wallet/rpc/__init__.py deleted file mode 100644 index c2b585547..000000000 --- a/lbry/wallet/rpc/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .framing import * -from .jsonrpc import * -from .socks import * -from .session import * -from .util import * - -__all__ = (framing.__all__ + - jsonrpc.__all__ + - socks.__all__ + - session.__all__ + - util.__all__) diff --git a/lbry/wallet/rpc/framing.py b/lbry/wallet/rpc/framing.py deleted file mode 100644 index 3cdf67b16..000000000 --- a/lbry/wallet/rpc/framing.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright (c) 2018, Neil Booth -# -# All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -"""RPC message framing in a byte stream.""" - -__all__ = ('FramerBase', 'NewlineFramer', 'BinaryFramer', 'BitcoinFramer', - 'OversizedPayloadError', 'BadChecksumError', 'BadMagicError') - -from hashlib import sha256 as _sha256 -from struct import Struct -from asyncio import Queue - - -class FramerBase: - """Abstract base class for a framer. - - A framer breaks an incoming byte stream into protocol messages, - buffering if necessary. It also frames outgoing messages into - a byte stream. - """ - - def frame(self, message): - """Return the framed message.""" - raise NotImplementedError - - def received_bytes(self, data): - """Pass incoming network bytes.""" - raise NotImplementedError - - async def receive_message(self): - """Wait for a complete unframed message to arrive, and return it.""" - raise NotImplementedError - - -class NewlineFramer(FramerBase): - """A framer for a protocol where messages are separated by newlines.""" - - # The default max_size value is motivated by JSONRPC, where a - # normal request will be 250 bytes or less, and a reasonable - # batch may contain 4000 requests. - def __init__(self, max_size=250 * 4000): - """max_size - an anti-DoS measure. If, after processing an incoming - message, buffered data would exceed max_size bytes, that - buffered data is dropped entirely and the framer waits for a - newline character to re-synchronize the stream. - """ - self.max_size = max_size - self.queue = Queue() - self.received_bytes = self.queue.put_nowait - self.synchronizing = False - self.residual = b'' - - def frame(self, message): - return message + b'\n' - - async def receive_message(self): - parts = [] - buffer_size = 0 - while True: - part = self.residual - self.residual = b'' - if not part: - part = await self.queue.get() - - npos = part.find(b'\n') - if npos == -1: - parts.append(part) - buffer_size += len(part) - # Ignore over-sized messages; re-synchronize - if buffer_size <= self.max_size: - continue - self.synchronizing = True - raise MemoryError(f'dropping message over {self.max_size:,d} ' - f'bytes and re-synchronizing') - - tail, self.residual = part[:npos], part[npos + 1:] - if self.synchronizing: - self.synchronizing = False - return await self.receive_message() - else: - parts.append(tail) - return b''.join(parts) - - -class ByteQueue: - """A producer-comsumer queue. Incoming network data is put as it - arrives, and the consumer calls an async method waiting for data of - a specific length.""" - - def __init__(self): - self.queue = Queue() - self.parts = [] - self.parts_len = 0 - self.put_nowait = self.queue.put_nowait - - async def receive(self, size): - while self.parts_len < size: - part = await self.queue.get() - self.parts.append(part) - self.parts_len += len(part) - self.parts_len -= size - whole = b''.join(self.parts) - self.parts = [whole[size:]] - return whole[:size] - - -class BinaryFramer: - """A framer for binary messaging protocols.""" - - def __init__(self): - self.byte_queue = ByteQueue() - self.message_queue = Queue() - self.received_bytes = self.byte_queue.put_nowait - - def frame(self, message): - command, payload = message - return b''.join(( - self._build_header(command, payload), - payload - )) - - async def receive_message(self): - command, payload_len, checksum = await self._receive_header() - payload = await self.byte_queue.receive(payload_len) - payload_checksum = self._checksum(payload) - if payload_checksum != checksum: - raise BadChecksumError(payload_checksum, checksum) - return command, payload - - def _checksum(self, payload): - raise NotImplementedError - - def _build_header(self, command, payload): - raise NotImplementedError - - async def _receive_header(self): - raise NotImplementedError - - -# Helpers -struct_le_I = Struct(' 1024 * 1024: - if command != b'block' or payload_len > self._max_block_size: - raise OversizedPayloadError(command, payload_len) - return command, payload_len, checksum diff --git a/lbry/wallet/rpc/jsonrpc.py b/lbry/wallet/rpc/jsonrpc.py deleted file mode 100644 index 61de3f366..000000000 --- a/lbry/wallet/rpc/jsonrpc.py +++ /dev/null @@ -1,804 +0,0 @@ -# Copyright (c) 2018, Neil Booth -# -# All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -"""Classes for JSONRPC versions 1.0 and 2.0, and a loose interpretation.""" - -__all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose', - 'JSONRPCAutoDetect', 'Request', 'Notification', 'Batch', - 'RPCError', 'ProtocolError', - 'JSONRPCConnection', 'handler_invocation') - -import itertools -import json -import typing -import asyncio -from functools import partial -from numbers import Number - -import attr -from asyncio import Queue, Event, CancelledError -from .util import signature_info - - -class SingleRequest: - __slots__ = ('method', 'args') - - def __init__(self, method, args): - if not isinstance(method, str): - raise ProtocolError(JSONRPC.METHOD_NOT_FOUND, - 'method must be a string') - if not isinstance(args, (list, tuple, dict)): - raise ProtocolError.invalid_args('request arguments must be a ' - 'list or a dictionary') - self.args = args - self.method = method - - def __repr__(self): - return f'{self.__class__.__name__}({self.method!r}, {self.args!r})' - - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.method == other.method and self.args == other.args) - - -class Request(SingleRequest): - def send_result(self, response): - return None - - -class Notification(SingleRequest): - pass - - -class Batch: - __slots__ = ('items', ) - - def __init__(self, items): - if not isinstance(items, (list, tuple)): - raise ProtocolError.invalid_request('items must be a list') - if not items: - raise ProtocolError.empty_batch() - if not (all(isinstance(item, SingleRequest) for item in items) or - all(isinstance(item, Response) for item in items)): - raise ProtocolError.invalid_request('batch must be homogeneous') - self.items = items - - def __len__(self): - return len(self.items) - - def __getitem__(self, item): - return self.items[item] - - def __iter__(self): - return iter(self.items) - - def __repr__(self): - return f'Batch({len(self.items)} items)' - - -class Response: - __slots__ = ('result', ) - - def __init__(self, result): - # Type checking happens when converting to a message - self.result = result - - -class CodeMessageError(Exception): - - def __init__(self, code, message): - super().__init__(code, message) - - @property - def code(self): - return self.args[0] - - @property - def message(self): - return self.args[1] - - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.code == other.code and self.message == other.message) - - def __hash__(self): - # overridden to make the exception hashable - # see https://bugs.python.org/issue28603 - return hash((self.code, self.message)) - - @classmethod - def invalid_args(cls, message): - return cls(JSONRPC.INVALID_ARGS, message) - - @classmethod - def invalid_request(cls, message): - return cls(JSONRPC.INVALID_REQUEST, message) - - @classmethod - def empty_batch(cls): - return cls.invalid_request('batch is empty') - - -class RPCError(CodeMessageError): - pass - - -class ProtocolError(CodeMessageError): - - def __init__(self, code, message): - super().__init__(code, message) - # If not None send this unframed message over the network - self.error_message = None - # If the error was in a JSON response message; its message ID. - # Since None can be a response message ID, "id" means the - # error was not sent in a JSON response - self.response_msg_id = id - - -class JSONRPC: - """Abstract base class that interprets and constructs JSON RPC messages.""" - - # Error codes. See http://www.jsonrpc.org/specification - PARSE_ERROR = -32700 - INVALID_REQUEST = -32600 - METHOD_NOT_FOUND = -32601 - INVALID_ARGS = -32602 - INTERNAL_ERROR = -32603 - QUERY_TIMEOUT = -32000 - - # Codes specific to this library - ERROR_CODE_UNAVAILABLE = -100 - - # Can be overridden by derived classes - allow_batches = True - - @classmethod - def _message_id(cls, message, require_id): - """Validate the message is a dictionary and return its ID. - - Raise an error if the message is invalid or the ID is of an - invalid type. If it has no ID, raise an error if require_id - is True, otherwise return None. - """ - raise NotImplementedError - - @classmethod - def _validate_message(cls, message): - """Validate other parts of the message other than those - done in _message_id.""" - pass - - @classmethod - def _request_args(cls, request): - """Validate the existence and type of the arguments passed - in the request dictionary.""" - raise NotImplementedError - - @classmethod - def _process_request(cls, payload): - request_id = None - try: - request_id = cls._message_id(payload, False) - cls._validate_message(payload) - method = payload.get('method') - if request_id is None: - item = Notification(method, cls._request_args(payload)) - else: - item = Request(method, cls._request_args(payload)) - return item, request_id - except ProtocolError as error: - code, message = error.code, error.message - raise cls._error(code, message, True, request_id) - - @classmethod - def _process_response(cls, payload): - request_id = None - try: - request_id = cls._message_id(payload, True) - cls._validate_message(payload) - return Response(cls.response_value(payload)), request_id - except ProtocolError as error: - code, message = error.code, error.message - raise cls._error(code, message, False, request_id) - - @classmethod - def _message_to_payload(cls, message): - """Returns a Python object or a ProtocolError.""" - try: - return json.loads(message.decode()) - except UnicodeDecodeError: - message = 'messages must be encoded in UTF-8' - except json.JSONDecodeError: - message = 'invalid JSON' - raise cls._error(cls.PARSE_ERROR, message, True, None) - - @classmethod - def _error(cls, code, message, send, msg_id): - error = ProtocolError(code, message) - if send: - error.error_message = cls.response_message(error, msg_id) - else: - error.response_msg_id = msg_id - return error - - # - # External API - # - - @classmethod - def message_to_item(cls, message): - """Translate an unframed received message and return an - (item, request_id) pair. - - The item can be a Request, Notification, Response or a list. - - A JSON RPC error response is returned as an RPCError inside a - Response object. - - If a Batch is returned, request_id is an iterable of request - ids, one per batch member. - - If the message violates the protocol in some way a - ProtocolError is returned, except if the message was - determined to be a response, in which case the ProtocolError - is placed inside a Response object. This is so that client - code can mark a request as having been responded to even if - the response was bad. - - raises: ProtocolError - """ - payload = cls._message_to_payload(message) - if isinstance(payload, dict): - if 'method' in payload: - return cls._process_request(payload) - else: - return cls._process_response(payload) - elif isinstance(payload, list) and cls.allow_batches: - if not payload: - raise cls._error(JSONRPC.INVALID_REQUEST, 'batch is empty', - True, None) - return payload, None - raise cls._error(cls.INVALID_REQUEST, - 'request object must be a dictionary', True, None) - - # Message formation - @classmethod - def request_message(cls, item, request_id): - """Convert an RPCRequest item to a message.""" - assert isinstance(item, Request) - return cls.encode_payload(cls.request_payload(item, request_id)) - - @classmethod - def notification_message(cls, item): - """Convert an RPCRequest item to a message.""" - assert isinstance(item, Notification) - return cls.encode_payload(cls.request_payload(item, None)) - - @classmethod - def response_message(cls, result, request_id): - """Convert a response result (or RPCError) to a message.""" - if isinstance(result, CodeMessageError): - payload = cls.error_payload(result, request_id) - else: - payload = cls.response_payload(result, request_id) - return cls.encode_payload(payload) - - @classmethod - def batch_message(cls, batch, request_ids): - """Convert a request Batch to a message.""" - assert isinstance(batch, Batch) - if not cls.allow_batches: - raise ProtocolError.invalid_request( - 'protocol does not permit batches') - id_iter = iter(request_ids) - rm = cls.request_message - nm = cls.notification_message - parts = (rm(request, next(id_iter)) if isinstance(request, Request) - else nm(request) for request in batch) - return cls.batch_message_from_parts(parts) - - @classmethod - def batch_message_from_parts(cls, messages): - """Convert messages, one per batch item, into a batch message. At - least one message must be passed. - """ - # Comma-separate the messages and wrap the lot in square brackets - middle = b', '.join(messages) - if not middle: - raise ProtocolError.empty_batch() - return b''.join([b'[', middle, b']']) - - @classmethod - def encode_payload(cls, payload): - """Encode a Python object as JSON and convert it to bytes.""" - try: - return json.dumps(payload).encode() - except TypeError: - msg = f'JSON payload encoding error: {payload}' - raise ProtocolError(cls.INTERNAL_ERROR, msg) from None - - -class JSONRPCv1(JSONRPC): - """JSON RPC version 1.0.""" - - allow_batches = False - - @classmethod - def _message_id(cls, message, require_id): - # JSONv1 requires an ID always, but without constraint on its type - # No need to test for a dictionary here as we don't handle batches. - if 'id' not in message: - raise ProtocolError.invalid_request('request has no "id"') - return message['id'] - - @classmethod - def _request_args(cls, request): - args = request.get('params') - if not isinstance(args, list): - raise ProtocolError.invalid_args( - f'invalid request arguments: {args}') - return args - - @classmethod - def _best_effort_error(cls, error): - # Do our best to interpret the error - code = cls.ERROR_CODE_UNAVAILABLE - message = 'no error message provided' - if isinstance(error, str): - message = error - elif isinstance(error, int): - code = error - elif isinstance(error, dict): - if isinstance(error.get('message'), str): - message = error['message'] - if isinstance(error.get('code'), int): - code = error['code'] - - return RPCError(code, message) - - @classmethod - def response_value(cls, payload): - if 'result' not in payload or 'error' not in payload: - raise ProtocolError.invalid_request( - 'response must contain both "result" and "error"') - - result = payload['result'] - error = payload['error'] - if error is None: - return result # It seems None can be a valid result - if result is not None: - raise ProtocolError.invalid_request( - 'response has a "result" and an "error"') - - return cls._best_effort_error(error) - - @classmethod - def request_payload(cls, request, request_id): - """JSON v1 request (or notification) payload.""" - if isinstance(request.args, dict): - raise ProtocolError.invalid_args( - 'JSONRPCv1 does not support named arguments') - return { - 'method': request.method, - 'params': request.args, - 'id': request_id - } - - @classmethod - def response_payload(cls, result, request_id): - """JSON v1 response payload.""" - return { - 'result': result, - 'error': None, - 'id': request_id - } - - @classmethod - def error_payload(cls, error, request_id): - return { - 'result': None, - 'error': {'code': error.code, 'message': error.message}, - 'id': request_id - } - - -class JSONRPCv2(JSONRPC): - """JSON RPC version 2.0.""" - - @classmethod - def _message_id(cls, message, require_id): - if not isinstance(message, dict): - raise ProtocolError.invalid_request( - 'request object must be a dictionary') - if 'id' in message: - request_id = message['id'] - if not isinstance(request_id, (Number, str, type(None))): - raise ProtocolError.invalid_request( - f'invalid "id": {request_id}') - return request_id - else: - if require_id: - raise ProtocolError.invalid_request('request has no "id"') - return None - - @classmethod - def _validate_message(cls, message): - if message.get('jsonrpc') != '2.0': - raise ProtocolError.invalid_request('"jsonrpc" is not "2.0"') - - @classmethod - def _request_args(cls, request): - args = request.get('params', []) - if not isinstance(args, (dict, list)): - raise ProtocolError.invalid_args( - f'invalid request arguments: {args}') - return args - - @classmethod - def response_value(cls, payload): - if 'result' in payload: - if 'error' in payload: - raise ProtocolError.invalid_request( - 'response contains both "result" and "error"') - return payload['result'] - - if 'error' not in payload: - raise ProtocolError.invalid_request( - 'response contains neither "result" nor "error"') - - # Return an RPCError object - error = payload['error'] - if isinstance(error, dict): - code = error.get('code') - message = error.get('message') - if isinstance(code, int) and isinstance(message, str): - return RPCError(code, message) - - raise ProtocolError.invalid_request( - f'ill-formed response error object: {error}') - - @classmethod - def request_payload(cls, request, request_id): - """JSON v2 request (or notification) payload.""" - payload = { - 'jsonrpc': '2.0', - 'method': request.method, - } - # A notification? - if request_id is not None: - payload['id'] = request_id - # Preserve empty dicts as missing params is read as an array - if request.args or request.args == {}: - payload['params'] = request.args - return payload - - @classmethod - def response_payload(cls, result, request_id): - """JSON v2 response payload.""" - return { - 'jsonrpc': '2.0', - 'result': result, - 'id': request_id - } - - @classmethod - def error_payload(cls, error, request_id): - return { - 'jsonrpc': '2.0', - 'error': {'code': error.code, 'message': error.message}, - 'id': request_id - } - - -class JSONRPCLoose(JSONRPC): - """A relaxed version of JSON RPC.""" - - # Don't be so loose we accept any old message ID - _message_id = JSONRPCv2._message_id - _validate_message = JSONRPC._validate_message - _request_args = JSONRPCv2._request_args - # Outoing messages are JSONRPCv2 so we give the other side the - # best chance to assume / detect JSONRPCv2 as default protocol. - error_payload = JSONRPCv2.error_payload - request_payload = JSONRPCv2.request_payload - response_payload = JSONRPCv2.response_payload - - @classmethod - def response_value(cls, payload): - # Return result, unless it is None and there is an error - if payload.get('error') is not None: - if payload.get('result') is not None: - raise ProtocolError.invalid_request( - 'response contains both "result" and "error"') - return JSONRPCv1._best_effort_error(payload['error']) - - if 'result' not in payload: - raise ProtocolError.invalid_request( - 'response contains neither "result" nor "error"') - - # Can be None - return payload['result'] - - -class JSONRPCAutoDetect(JSONRPCv2): - - @classmethod - def message_to_item(cls, message): - return cls.detect_protocol(message), None - - @classmethod - def detect_protocol(cls, message): - """Attempt to detect the protocol from the message.""" - main = cls._message_to_payload(message) - - def protocol_for_payload(payload): - if not isinstance(payload, dict): - return JSONRPCLoose # Will error - # Obey an explicit "jsonrpc" - version = payload.get('jsonrpc') - if version == '2.0': - return JSONRPCv2 - if version == '1.0': - return JSONRPCv1 - - # Now to decide between JSONRPCLoose and JSONRPCv1 if possible - if 'result' in payload and 'error' in payload: - return JSONRPCv1 - return JSONRPCLoose - - if isinstance(main, list): - parts = {protocol_for_payload(payload) for payload in main} - # If all same protocol, return it - if len(parts) == 1: - return parts.pop() - # If strict protocol detected, return it, preferring JSONRPCv2. - # This means a batch of JSONRPCv1 will fail - for protocol in (JSONRPCv2, JSONRPCv1): - if protocol in parts: - return protocol - # Will error if no parts - return JSONRPCLoose - - return protocol_for_payload(main) - - -class JSONRPCConnection: - """Maintains state of a JSON RPC connection, in particular - encapsulating the handling of request IDs. - - protocol - the JSON RPC protocol to follow - max_response_size - responses over this size send an error response - instead. - """ - - _id_counter = itertools.count() - - def __init__(self, protocol): - self._protocol = protocol - # Sent Requests and Batches that have not received a response. - # The key is its request ID; for a batch it is sorted tuple - # of request IDs - self._requests: typing.Dict[str, typing.Tuple[Request, Event]] = {} - # A public attribute intended to be settable dynamically - self.max_response_size = 0 - - def _oversized_response_message(self, request_id): - text = f'response too large (over {self.max_response_size:,d} bytes' - error = RPCError.invalid_request(text) - return self._protocol.response_message(error, request_id) - - def _receive_response(self, result, request_id): - if request_id not in self._requests: - if request_id is None and isinstance(result, RPCError): - message = f'diagnostic error received: {result}' - else: - message = f'response to unsent request (ID: {request_id})' - raise ProtocolError.invalid_request(message) from None - request, event = self._requests.pop(request_id) - event.result = result - event.set() - return [] - - def _receive_request_batch(self, payloads): - def item_send_result(request_id, result): - nonlocal size - part = protocol.response_message(result, request_id) - size += len(part) + 2 - if size > self.max_response_size > 0: - part = self._oversized_response_message(request_id) - parts.append(part) - if len(parts) == count: - return protocol.batch_message_from_parts(parts) - return None - - parts = [] - items = [] - size = 0 - count = 0 - protocol = self._protocol - for payload in payloads: - try: - item, request_id = protocol._process_request(payload) - items.append(item) - if isinstance(item, Request): - count += 1 - item.send_result = partial(item_send_result, request_id) - except ProtocolError as error: - count += 1 - parts.append(error.error_message) - - if not items and parts: - protocol_error = ProtocolError(0, "") - protocol_error.error_message = protocol.batch_message_from_parts(parts) - raise protocol_error - return items - - def _receive_response_batch(self, payloads): - request_ids = [] - results = [] - for payload in payloads: - # Let ProtocolError exceptions through - item, request_id = self._protocol._process_response(payload) - request_ids.append(request_id) - results.append(item.result) - - ordered = sorted(zip(request_ids, results), key=lambda t: t[0]) - ordered_ids, ordered_results = zip(*ordered) - if ordered_ids not in self._requests: - raise ProtocolError.invalid_request('response to unsent batch') - request_batch, event = self._requests.pop(ordered_ids) - event.result = ordered_results - event.set() - return [] - - def _send_result(self, request_id, result): - message = self._protocol.response_message(result, request_id) - if len(message) > self.max_response_size > 0: - message = self._oversized_response_message(request_id) - return message - - def _event(self, request, request_id): - event = Event() - self._requests[request_id] = (request, event) - return event - - # - # External API - # - def send_request(self, request: Request) -> typing.Tuple[bytes, Event]: - """Send a Request. Return a (message, event) pair. - - The message is an unframed message to send over the network. - Wait on the event for the response; which will be in the - "result" attribute. - - Raises: ProtocolError if the request violates the protocol - in some way.. - """ - request_id = next(self._id_counter) - message = self._protocol.request_message(request, request_id) - return message, self._event(request, request_id) - - def send_notification(self, notification): - return self._protocol.notification_message(notification) - - def send_batch(self, batch): - ids = tuple(next(self._id_counter) - for request in batch if isinstance(request, Request)) - message = self._protocol.batch_message(batch, ids) - event = self._event(batch, ids) if ids else None - return message, event - - def receive_message(self, message): - """Call with an unframed message received from the network. - - Raises: ProtocolError if the message violates the protocol in - some way. However, if it happened in a response that can be - paired with a request, the ProtocolError is instead set in the - result attribute of the send_request() that caused the error. - """ - try: - item, request_id = self._protocol.message_to_item(message) - except ProtocolError as e: - if e.response_msg_id is not id: - return self._receive_response(e, e.response_msg_id) - raise - - if isinstance(item, Request): - item.send_result = partial(self._send_result, request_id) - return [item] - if isinstance(item, Notification): - return [item] - if isinstance(item, Response): - return self._receive_response(item.result, request_id) - if isinstance(item, list): - if all(isinstance(payload, dict) - and ('result' in payload or 'error' in payload) - for payload in item): - return self._receive_response_batch(item) - else: - return self._receive_request_batch(item) - else: - # Protocol auto-detection hack - assert issubclass(item, JSONRPC) - self._protocol = item - return self.receive_message(message) - - def raise_pending_requests(self, exception): - exception = exception or asyncio.TimeoutError() - for request, event in self._requests.values(): - event.result = exception - event.set() - self._requests.clear() - - def pending_requests(self): - """All sent requests that have not received a response.""" - return [request for request, event in self._requests.values()] - - -def handler_invocation(handler, request): - method, args = request.method, request.args - if handler is None: - raise RPCError(JSONRPC.METHOD_NOT_FOUND, - f'unknown method "{method}"') - - # We must test for too few and too many arguments. How - # depends on whether the arguments were passed as a list or as - # a dictionary. - info = signature_info(handler) - if isinstance(args, (tuple, list)): - if len(args) < info.min_args: - s = '' if len(args) == 1 else 's' - raise RPCError.invalid_args( - f'{len(args)} argument{s} passed to method ' - f'"{method}" but it requires {info.min_args}') - if info.max_args is not None and len(args) > info.max_args: - s = '' if len(args) == 1 else 's' - raise RPCError.invalid_args( - f'{len(args)} argument{s} passed to method ' - f'{method} taking at most {info.max_args}') - return partial(handler, *args) - - # Arguments passed by name - if info.other_names is None: - raise RPCError.invalid_args(f'method "{method}" cannot ' - f'be called with named arguments') - - missing = set(info.required_names).difference(args) - if missing: - s = '' if len(missing) == 1 else 's' - missing = ', '.join(sorted(f'"{name}"' for name in missing)) - raise RPCError.invalid_args(f'method "{method}" requires ' - f'parameter{s} {missing}') - - if info.other_names is not any: - excess = set(args).difference(info.required_names) - excess = excess.difference(info.other_names) - if excess: - s = '' if len(excess) == 1 else 's' - excess = ', '.join(sorted(f'"{name}"' for name in excess)) - raise RPCError.invalid_args(f'method "{method}" does not ' - f'take parameter{s} {excess}') - return partial(handler, **args) diff --git a/lbry/wallet/rpc/session.py b/lbry/wallet/rpc/session.py deleted file mode 100644 index 53c164f4f..000000000 --- a/lbry/wallet/rpc/session.py +++ /dev/null @@ -1,513 +0,0 @@ -# Copyright (c) 2018, Neil Booth -# -# All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - - -__all__ = ('Connector', 'RPCSession', 'MessageSession', 'Server', - 'BatchError') - - -import asyncio -from asyncio import Event, CancelledError -import logging -import time -from contextlib import suppress - -from lbry.wallet.tasks import TaskGroup - -from .jsonrpc import Request, JSONRPCConnection, JSONRPCv2, JSONRPC, Batch, Notification -from .jsonrpc import RPCError, ProtocolError -from .framing import BadMagicError, BadChecksumError, OversizedPayloadError, BitcoinFramer, NewlineFramer -from lbry.wallet.server.prometheus import NOTIFICATION_COUNT, RESPONSE_TIMES, REQUEST_ERRORS_COUNT, RESET_CONNECTIONS - - -class Connector: - - def __init__(self, session_factory, host=None, port=None, proxy=None, - **kwargs): - self.session_factory = session_factory - self.host = host - self.port = port - self.proxy = proxy - self.loop = kwargs.get('loop', asyncio.get_event_loop()) - self.kwargs = kwargs - - async def create_connection(self): - """Initiate a connection.""" - connector = self.proxy or self.loop - return await connector.create_connection( - self.session_factory, self.host, self.port, **self.kwargs) - - async def __aenter__(self): - transport, self.protocol = await self.create_connection() - return self.protocol - - async def __aexit__(self, exc_type, exc_value, traceback): - await self.protocol.close() - - -class SessionBase(asyncio.Protocol): - """Base class of networking sessions. - - There is no client / server distinction other than who initiated - the connection. - - To initiate a connection to a remote server pass host, port and - proxy to the constructor, and then call create_connection(). Each - successful call should have a corresponding call to close(). - - Alternatively if used in a with statement, the connection is made - on entry to the block, and closed on exit from the block. - """ - - max_errors = 10 - - def __init__(self, *, framer=None, loop=None): - self.framer = framer or self.default_framer() - self.loop = loop or asyncio.get_event_loop() - self.logger = logging.getLogger(self.__class__.__name__) - self.transport = None - # Set when a connection is made - self._address = None - self._proxy_address = None - # For logger.debug messages - self.verbosity = 0 - # Cleared when the send socket is full - self._can_send = Event() - self._can_send.set() - self._pm_task = None - self._task_group = TaskGroup(self.loop) - # Force-close a connection if a send doesn't succeed in this time - self.max_send_delay = 60 - # Statistics. The RPC object also keeps its own statistics. - self.start_time = time.perf_counter() - self.errors = 0 - self.send_count = 0 - self.send_size = 0 - self.last_send = self.start_time - self.recv_count = 0 - self.recv_size = 0 - self.last_recv = self.start_time - self.last_packet_received = self.start_time - - async def _limited_wait(self, secs): - try: - await asyncio.wait_for(self._can_send.wait(), secs) - except asyncio.TimeoutError: - self.abort() - raise asyncio.TimeoutError(f'task timed out after {secs}s') - - async def _send_message(self, message): - if not self._can_send.is_set(): - await self._limited_wait(self.max_send_delay) - if not self.is_closing(): - framed_message = self.framer.frame(message) - self.send_size += len(framed_message) - self.send_count += 1 - self.last_send = time.perf_counter() - if self.verbosity >= 4: - self.logger.debug(f'Sending framed message {framed_message}') - self.transport.write(framed_message) - - def _bump_errors(self): - self.errors += 1 - if self.errors >= self.max_errors: - # Don't await self.close() because that is self-cancelling - self._close() - - def _close(self): - if self.transport: - self.transport.close() - - # asyncio framework - def data_received(self, framed_message): - """Called by asyncio when a message comes in.""" - self.last_packet_received = time.perf_counter() - if self.verbosity >= 4: - self.logger.debug(f'Received framed message {framed_message}') - self.recv_size += len(framed_message) - self.framer.received_bytes(framed_message) - - def pause_writing(self): - """Transport calls when the send buffer is full.""" - if not self.is_closing(): - self._can_send.clear() - self.transport.pause_reading() - - def resume_writing(self): - """Transport calls when the send buffer has room.""" - if not self._can_send.is_set(): - self._can_send.set() - self.transport.resume_reading() - - def connection_made(self, transport): - """Called by asyncio when a connection is established. - - Derived classes overriding this method must call this first.""" - self.transport = transport - # This would throw if called on a closed SSL transport. Fixed - # in asyncio in Python 3.6.1 and 3.5.4 - peer_address = transport.get_extra_info('peername') - # If the Socks proxy was used then _address is already set to - # the remote address - if self._address: - self._proxy_address = peer_address - else: - self._address = peer_address - self._pm_task = self.loop.create_task(self._receive_messages()) - - def connection_lost(self, exc): - """Called by asyncio when the connection closes. - - Tear down things done in connection_made.""" - self._address = None - self.transport = None - self._task_group.cancel() - if self._pm_task: - self._pm_task.cancel() - # Release waiting tasks - self._can_send.set() - - # External API - def default_framer(self): - """Return a default framer.""" - raise NotImplementedError - - def peer_address(self): - """Returns the peer's address (Python networking address), or None if - no connection or an error. - - This is the result of socket.getpeername() when the connection - was made. - """ - return self._address - - def peer_address_str(self): - """Returns the peer's IP address and port as a human-readable - string.""" - if not self._address: - return 'unknown' - ip_addr_str, port = self._address[:2] - if ':' in ip_addr_str: - return f'[{ip_addr_str}]:{port}' - else: - return f'{ip_addr_str}:{port}' - - def is_closing(self): - """Return True if the connection is closing.""" - return not self.transport or self.transport.is_closing() - - def abort(self): - """Forcefully close the connection.""" - if self.transport: - self.transport.abort() - - # TODO: replace with synchronous_close - async def close(self, *, force_after=30): - """Close the connection and return when closed.""" - self._close() - if self._pm_task: - with suppress(CancelledError): - await asyncio.wait([self._pm_task], timeout=force_after) - self.abort() - await self._pm_task - - def synchronous_close(self): - self._close() - if self._pm_task and not self._pm_task.done(): - self._pm_task.cancel() - - -class MessageSession(SessionBase): - """Session class for protocols where messages are not tied to responses, - such as the Bitcoin protocol. - - To use as a client (connection-opening) session, pass host, port - and perhaps a proxy. - """ - async def _receive_messages(self): - while not self.is_closing(): - try: - message = await self.framer.receive_message() - except BadMagicError as e: - magic, expected = e.args - self.logger.error( - f'bad network magic: got {magic} expected {expected}, ' - f'disconnecting' - ) - self._close() - except OversizedPayloadError as e: - command, payload_len = e.args - self.logger.error( - f'oversized payload of {payload_len:,d} bytes to command ' - f'{command}, disconnecting' - ) - self._close() - except BadChecksumError as e: - payload_checksum, claimed_checksum = e.args - self.logger.warning( - f'checksum mismatch: actual {payload_checksum.hex()} ' - f'vs claimed {claimed_checksum.hex()}' - ) - self._bump_errors() - else: - self.last_recv = time.perf_counter() - self.recv_count += 1 - await self._task_group.add(self._handle_message(message)) - - async def _handle_message(self, message): - try: - await self.handle_message(message) - except ProtocolError as e: - self.logger.error(f'{e}') - self._bump_errors() - except CancelledError: - raise - except Exception: - self.logger.exception(f'exception handling {message}') - self._bump_errors() - - # External API - def default_framer(self): - """Return a bitcoin framer.""" - return BitcoinFramer(bytes.fromhex('e3e1f3e8'), 128_000_000) - - async def handle_message(self, message): - """message is a (command, payload) pair.""" - pass - - async def send_message(self, message): - """Send a message (command, payload) over the network.""" - await self._send_message(message) - - -class BatchError(Exception): - - def __init__(self, request): - self.request = request # BatchRequest object - - -class BatchRequest: - """Used to build a batch request to send to the server. Stores - the - - Attributes batch and results are initially None. - - Adding an invalid request or notification immediately raises a - ProtocolError. - - On exiting the with clause, it will: - - 1) create a Batch object for the requests in the order they were - added. If the batch is empty this raises a ProtocolError. - - 2) set the "batch" attribute to be that batch - - 3) send the batch request and wait for a response - - 4) raise a ProtocolError if the protocol was violated by the - server. Currently this only happens if it gave more than one - response to any request - - 5) otherwise there is precisely one response to each Request. Set - the "results" attribute to the tuple of results; the responses - are ordered to match the Requests in the batch. Notifications - do not get a response. - - 6) if raise_errors is True and any individual response was a JSON - RPC error response, or violated the protocol in some way, a - BatchError exception is raised. Otherwise the caller can be - certain each request returned a standard result. - """ - - def __init__(self, session, raise_errors): - self._session = session - self._raise_errors = raise_errors - self._requests = [] - self.batch = None - self.results = None - - def add_request(self, method, args=()): - self._requests.append(Request(method, args)) - - def add_notification(self, method, args=()): - self._requests.append(Notification(method, args)) - - def __len__(self): - return len(self._requests) - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_value, traceback): - if exc_type is None: - self.batch = Batch(self._requests) - message, event = self._session.connection.send_batch(self.batch) - await self._session._send_message(message) - await event.wait() - self.results = event.result - if self._raise_errors: - if any(isinstance(item, Exception) for item in event.result): - raise BatchError(self) - - -class RPCSession(SessionBase): - """Base class for protocols where a message can lead to a response, - for example JSON RPC.""" - - def __init__(self, *, framer=None, loop=None, connection=None): - super().__init__(framer=framer, loop=loop) - self.connection = connection or self.default_connection() - self.client_version = 'unknown' - - async def _receive_messages(self): - while not self.is_closing(): - try: - message = await self.framer.receive_message() - except MemoryError: - self.logger.warning('received oversized message from %s:%s, dropping connection', - self._address[0], self._address[1]) - RESET_CONNECTIONS.labels(version=self.client_version).inc() - self._close() - return - - self.last_recv = time.perf_counter() - self.recv_count += 1 - - try: - requests = self.connection.receive_message(message) - except ProtocolError as e: - self.logger.debug(f'{e}') - if e.error_message: - await self._send_message(e.error_message) - if e.code == JSONRPC.PARSE_ERROR: - self.max_errors = 0 - self._bump_errors() - else: - for request in requests: - await self._task_group.add(self._handle_request(request)) - - async def _handle_request(self, request): - start = time.perf_counter() - try: - result = await self.handle_request(request) - except (ProtocolError, RPCError) as e: - result = e - except CancelledError: - raise - except Exception: - self.logger.exception(f'exception handling {request}') - result = RPCError(JSONRPC.INTERNAL_ERROR, - 'internal server error') - if isinstance(request, Request): - message = request.send_result(result) - RESPONSE_TIMES.labels( - method=request.method, - version=self.client_version - ).observe(time.perf_counter() - start) - if message: - await self._send_message(message) - if isinstance(result, Exception): - self._bump_errors() - REQUEST_ERRORS_COUNT.labels( - method=request.method, - version=self.client_version - ).inc() - - def connection_lost(self, exc): - # Cancel pending requests and message processing - self.connection.raise_pending_requests(exc) - super().connection_lost(exc) - - # External API - def default_connection(self): - """Return a default connection if the user provides none.""" - return JSONRPCConnection(JSONRPCv2) - - def default_framer(self): - """Return a default framer.""" - return NewlineFramer() - - async def handle_request(self, request): - pass - - async def send_request(self, method, args=()): - """Send an RPC request over the network.""" - if self.is_closing(): - raise asyncio.TimeoutError("Trying to send request on a recently dropped connection.") - message, event = self.connection.send_request(Request(method, args)) - await self._send_message(message) - await event.wait() - result = event.result - if isinstance(result, Exception): - raise result - return result - - async def send_notification(self, method, args=()): - """Send an RPC notification over the network.""" - message = self.connection.send_notification(Notification(method, args)) - NOTIFICATION_COUNT.labels(method=method, version=self.client_version).inc() - await self._send_message(message) - - def send_batch(self, raise_errors=False): - """Return a BatchRequest. Intended to be used like so: - - async with session.send_batch() as batch: - batch.add_request("method1") - batch.add_request("sum", (x, y)) - batch.add_notification("updated") - - for result in batch.results: - ... - - Note that in some circumstances exceptions can be raised; see - BatchRequest doc string. - """ - return BatchRequest(self, raise_errors) - - -class Server: - """A simple wrapper around an asyncio.Server object.""" - - def __init__(self, session_factory, host=None, port=None, *, - loop=None, **kwargs): - self.host = host - self.port = port - self.loop = loop or asyncio.get_event_loop() - self.server = None - self._session_factory = session_factory - self._kwargs = kwargs - - async def listen(self): - self.server = await self.loop.create_server( - self._session_factory, self.host, self.port, **self._kwargs) - - async def close(self): - """Close the listening socket. This does not close any ServerSession - objects created to handle incoming connections. - """ - if self.server: - self.server.close() - await self.server.wait_closed() - self.server = None diff --git a/lbry/wallet/rpc/socks.py b/lbry/wallet/rpc/socks.py deleted file mode 100644 index bfaadc8b6..000000000 --- a/lbry/wallet/rpc/socks.py +++ /dev/null @@ -1,439 +0,0 @@ -# Copyright (c) 2018, Neil Booth -# -# All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -"""SOCKS proxying.""" - -import sys -import asyncio -import collections -import ipaddress -import socket -import struct -from functools import partial - - -__all__ = ('SOCKSUserAuth', 'SOCKS4', 'SOCKS4a', 'SOCKS5', 'SOCKSProxy', - 'SOCKSError', 'SOCKSProtocolError', 'SOCKSFailure') - - -SOCKSUserAuth = collections.namedtuple("SOCKSUserAuth", "username password") - - -class SOCKSError(Exception): - """Base class for SOCKS exceptions. Each raised exception will be - an instance of a derived class.""" - - -class SOCKSProtocolError(SOCKSError): - """Raised when the proxy does not follow the SOCKS protocol""" - - -class SOCKSFailure(SOCKSError): - """Raised when the proxy refuses or fails to make a connection""" - - -class NeedData(Exception): - pass - - -class SOCKSBase: - - @classmethod - def name(cls): - return cls.__name__ - - def __init__(self): - self._buffer = bytes() - self._state = self._start - - def _read(self, size): - if len(self._buffer) < size: - raise NeedData(size - len(self._buffer)) - result = self._buffer[:size] - self._buffer = self._buffer[size:] - return result - - def receive_data(self, data): - self._buffer += data - - def next_message(self): - return self._state() - - -class SOCKS4(SOCKSBase): - """SOCKS4 protocol wrapper.""" - - # See http://ftp.icm.edu.pl/packages/socks/socks4/SOCKS4.protocol - REPLY_CODES = { - 90: 'request granted', - 91: 'request rejected or failed', - 92: ('request rejected because SOCKS server cannot connect ' - 'to identd on the client'), - 93: ('request rejected because the client program and identd ' - 'report different user-ids') - } - - def __init__(self, dst_host, dst_port, auth): - super().__init__() - self._dst_host = self._check_host(dst_host) - self._dst_port = dst_port - self._auth = auth - - @classmethod - def _check_host(cls, host): - if not isinstance(host, ipaddress.IPv4Address): - try: - host = ipaddress.IPv4Address(host) - except ValueError: - raise SOCKSProtocolError( - f'SOCKS4 requires an IPv4 address: {host}') from None - return host - - def _start(self): - self._state = self._first_response - - if isinstance(self._dst_host, ipaddress.IPv4Address): - # SOCKS4 - dst_ip_packed = self._dst_host.packed - host_bytes = b'' - else: - # SOCKS4a - dst_ip_packed = b'\0\0\0\1' - host_bytes = self._dst_host.encode() + b'\0' - - if isinstance(self._auth, SOCKSUserAuth): - user_id = self._auth.username.encode() - else: - user_id = b'' - - # Send TCP/IP stream CONNECT request - return b''.join([b'\4\1', struct.pack('>H', self._dst_port), - dst_ip_packed, user_id, b'\0', host_bytes]) - - def _first_response(self): - # Wait for 8-byte response - data = self._read(8) - if data[0] != 0: - raise SOCKSProtocolError(f'invalid {self.name()} proxy ' - f'response: {data}') - reply_code = data[1] - if reply_code != 90: - msg = self.REPLY_CODES.get( - reply_code, f'unknown {self.name()} reply code {reply_code}') - raise SOCKSFailure(f'{self.name()} proxy request failed: {msg}') - - # Other fields ignored - return None - - -class SOCKS4a(SOCKS4): - - @classmethod - def _check_host(cls, host): - if not isinstance(host, (str, ipaddress.IPv4Address)): - raise SOCKSProtocolError( - f'SOCKS4a requires an IPv4 address or host name: {host}') - return host - - -class SOCKS5(SOCKSBase): - """SOCKS protocol wrapper.""" - - # See https://tools.ietf.org/html/rfc1928 - ERROR_CODES = { - 1: 'general SOCKS server failure', - 2: 'connection not allowed by ruleset', - 3: 'network unreachable', - 4: 'host unreachable', - 5: 'connection refused', - 6: 'TTL expired', - 7: 'command not supported', - 8: 'address type not supported', - } - - def __init__(self, dst_host, dst_port, auth): - super().__init__() - self._dst_bytes = self._destination_bytes(dst_host, dst_port) - self._auth_bytes, self._auth_methods = self._authentication(auth) - - def _destination_bytes(self, host, port): - if isinstance(host, ipaddress.IPv4Address): - addr_bytes = b'\1' + host.packed - elif isinstance(host, ipaddress.IPv6Address): - addr_bytes = b'\4' + host.packed - elif isinstance(host, str): - host = host.encode() - if len(host) > 255: - raise SOCKSProtocolError(f'hostname too long: ' - f'{len(host)} bytes') - addr_bytes = b'\3' + bytes([len(host)]) + host - else: - raise SOCKSProtocolError(f'SOCKS5 requires an IPv4 address, IPv6 ' - f'address, or host name: {host}') - return addr_bytes + struct.pack('>H', port) - - def _authentication(self, auth): - if isinstance(auth, SOCKSUserAuth): - user_bytes = auth.username.encode() - if not 0 < len(user_bytes) < 256: - raise SOCKSProtocolError(f'username {auth.username} has ' - f'invalid length {len(user_bytes)}') - pwd_bytes = auth.password.encode() - if not 0 < len(pwd_bytes) < 256: - raise SOCKSProtocolError(f'password has invalid length ' - f'{len(pwd_bytes)}') - return b''.join([bytes([1, len(user_bytes)]), user_bytes, - bytes([len(pwd_bytes)]), pwd_bytes]), [0, 2] - return b'', [0] - - def _start(self): - self._state = self._first_response - return (b'\5' + bytes([len(self._auth_methods)]) - + bytes(m for m in self._auth_methods)) - - def _first_response(self): - # Wait for 2-byte response - data = self._read(2) - if data[0] != 5: - raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}') - if data[1] not in self._auth_methods: - raise SOCKSFailure('SOCKS5 proxy rejected authentication methods') - - # Authenticate if user-password authentication - if data[1] == 2: - self._state = self._auth_response - return self._auth_bytes - return self._request_connection() - - def _auth_response(self): - data = self._read(2) - if data[0] != 1: - raise SOCKSProtocolError(f'invalid SOCKS5 proxy auth ' - f'response: {data}') - if data[1] != 0: - raise SOCKSFailure(f'SOCKS5 proxy auth failure code: ' - f'{data[1]}') - - return self._request_connection() - - def _request_connection(self): - # Send connection request - self._state = self._connect_response - return b'\5\1\0' + self._dst_bytes - - def _connect_response(self): - data = self._read(5) - if data[0] != 5 or data[2] != 0 or data[3] not in (1, 3, 4): - raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}') - if data[1] != 0: - raise SOCKSFailure(self.ERROR_CODES.get( - data[1], f'unknown SOCKS5 error code: {data[1]}')) - - if data[3] == 1: - addr_len = 3 # IPv4 - elif data[3] == 3: - addr_len = data[4] # Hostname - else: - addr_len = 15 # IPv6 - - self._state = partial(self._connect_response_rest, addr_len) - return self.next_message() - - def _connect_response_rest(self, addr_len): - self._read(addr_len + 2) - return None - - -class SOCKSProxy: - - def __init__(self, address, protocol, auth): - """A SOCKS proxy at an address following a SOCKS protocol. auth is an - authentication method to use when connecting, or None. - - address is a (host, port) pair; for IPv6 it can instead be a - (host, port, flowinfo, scopeid) 4-tuple. - """ - self.address = address - self.protocol = protocol - self.auth = auth - # Set on each successful connection via the proxy to the - # result of socket.getpeername() - self.peername = None - - def __str__(self): - auth = 'username' if self.auth else 'none' - return f'{self.protocol.name()} proxy at {self.address}, auth: {auth}' - - async def _handshake(self, client, sock, loop): - while True: - count = 0 - try: - message = client.next_message() - except NeedData as e: - count = e.args[0] - else: - if message is None: - return - await loop.sock_sendall(sock, message) - - if count: - data = await loop.sock_recv(sock, count) - if not data: - raise SOCKSProtocolError("EOF received") - client.receive_data(data) - - async def _connect_one(self, host, port): - """Connect to the proxy and perform a handshake requesting a - connection to (host, port). - - Return the open socket on success, or the exception on failure. - """ - client = self.protocol(host, port, self.auth) - sock = socket.socket() - loop = asyncio.get_event_loop() - try: - # A non-blocking socket is required by loop socket methods - sock.setblocking(False) - await loop.sock_connect(sock, self.address) - await self._handshake(client, sock, loop) - self.peername = sock.getpeername() - return sock - except Exception as e: - # Don't close - see https://github.com/kyuupichan/aiorpcX/issues/8 - if sys.platform.startswith('linux') or sys.platform == "darwin": - sock.close() - return e - - async def _connect(self, addresses): - """Connect to the proxy and perform a handshake requesting a - connection to each address in addresses. - - Return an (open_socket, address) pair on success. - """ - assert len(addresses) > 0 - - exceptions = [] - for address in addresses: - host, port = address[:2] - sock = await self._connect_one(host, port) - if isinstance(sock, socket.socket): - return sock, address - exceptions.append(sock) - - strings = {f'{exc!r}' for exc in exceptions} - raise (exceptions[0] if len(strings) == 1 else - OSError(f'multiple exceptions: {", ".join(strings)}')) - - async def _detect_proxy(self): - """Return True if it appears we can connect to a SOCKS proxy, - otherwise False. - """ - if self.protocol is SOCKS4a: - host, port = 'www.apple.com', 80 - else: - host, port = ipaddress.IPv4Address('8.8.8.8'), 53 - - sock = await self._connect_one(host, port) - if isinstance(sock, socket.socket): - sock.close() - return True - - # SOCKSFailure indicates something failed, but that we are - # likely talking to a proxy - return isinstance(sock, SOCKSFailure) - - @classmethod - async def auto_detect_address(cls, address, auth): - """Try to detect a SOCKS proxy at address using the authentication - method (or None). SOCKS5, SOCKS4a and SOCKS are tried in - order. If a SOCKS proxy is detected a SOCKSProxy object is - returned. - - Returning a SOCKSProxy does not mean it is functioning - for - example, it may have no network connectivity. - - If no proxy is detected return None. - """ - for protocol in (SOCKS5, SOCKS4a, SOCKS4): - proxy = cls(address, protocol, auth) - if await proxy._detect_proxy(): - return proxy - return None - - @classmethod - async def auto_detect_host(cls, host, ports, auth): - """Try to detect a SOCKS proxy on a host on one of the ports. - - Calls auto_detect for the ports in order. Returns SOCKS are - tried in order; a SOCKSProxy object for the first detected - proxy is returned. - - Returning a SOCKSProxy does not mean it is functioning - for - example, it may have no network connectivity. - - If no proxy is detected return None. - """ - for port in ports: - address = (host, port) - proxy = await cls.auto_detect_address(address, auth) - if proxy: - return proxy - - return None - - async def create_connection(self, protocol_factory, host, port, *, - resolve=False, ssl=None, - family=0, proto=0, flags=0): - """Set up a connection to (host, port) through the proxy. - - If resolve is True then host is resolved locally with - getaddrinfo using family, proto and flags, otherwise the proxy - is asked to resolve host. - - The function signature is similar to loop.create_connection() - with the same result. The attribute _address is set on the - protocol to the address of the successful remote connection. - Additionally raises SOCKSError if something goes wrong with - the proxy handshake. - """ - loop = asyncio.get_event_loop() - if resolve: - infos = await loop.getaddrinfo(host, port, family=family, - type=socket.SOCK_STREAM, - proto=proto, flags=flags) - addresses = [info[4] for info in infos] - else: - addresses = [(host, port)] - - sock, address = await self._connect(addresses) - - def set_address(): - protocol = protocol_factory() - protocol._address = address - return protocol - - return await loop.create_connection( - set_address, sock=sock, ssl=ssl, - server_hostname=host if ssl else None) diff --git a/lbry/wallet/rpc/util.py b/lbry/wallet/rpc/util.py deleted file mode 100644 index 1fb743de3..000000000 --- a/lbry/wallet/rpc/util.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) 2018, Neil Booth -# -# All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -__all__ = () - - -import asyncio -from collections import namedtuple -import inspect - -# other_params: None means cannot be called with keyword arguments only -# any means any name is good -SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args ' - 'required_names other_names') - - -def signature_info(func): - params = inspect.signature(func).parameters - min_args = max_args = 0 - required_names = [] - other_names = [] - no_names = False - for p in params.values(): - if p.kind == p.POSITIONAL_OR_KEYWORD: - max_args += 1 - if p.default is p.empty: - min_args += 1 - required_names.append(p.name) - else: - other_names.append(p.name) - elif p.kind == p.KEYWORD_ONLY: - other_names.append(p.name) - elif p.kind == p.VAR_POSITIONAL: - max_args = None - elif p.kind == p.VAR_KEYWORD: - other_names = any - elif p.kind == p.POSITIONAL_ONLY: - max_args += 1 - if p.default is p.empty: - min_args += 1 - no_names = True - - if no_names: - other_names = None - - return SignatureInfo(min_args, max_args, required_names, other_names) - - -class Concurrency: - - def __init__(self, max_concurrent): - self._require_non_negative(max_concurrent) - self._max_concurrent = max_concurrent - self.semaphore = asyncio.Semaphore(max_concurrent) - - def _require_non_negative(self, value): - if not isinstance(value, int) or value < 0: - raise RuntimeError('concurrency must be a natural number') - - @property - def max_concurrent(self): - return self._max_concurrent - - async def set_max_concurrent(self, value): - self._require_non_negative(value) - diff = value - self._max_concurrent - self._max_concurrent = value - if diff >= 0: - for _ in range(diff): - self.semaphore.release() - else: - for _ in range(-diff): - await self.semaphore.acquire()