PEP-257 docstrings

Signed-off-by: binaryflesh <logan.campos123@gmail.com>
binaryflesh 2019-04-16 02:50:35 -05:00 committed by Lex Berezhny
parent f0c2d16749
commit 6788e09ae9
22 changed files with 644 additions and 644 deletions
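
The change is mechanical throughout: every '''…''' docstring becomes a """…""" docstring, the quoting style PEP 257 recommends. In miniature (a distilled example, not a line taken from the diff):

    # Before: triple single quotes
    def frame(message):
        '''Return the framed message.'''

    # After: the PEP 257 convention of triple double quotes
    def frame(message):
        """Return the framed message."""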

View file

@@ -23,7 +23,7 @@
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-'''RPC message framing in a byte stream.'''
+"""RPC message framing in a byte stream."""
__all__ = ('FramerBase', 'NewlineFramer', 'BinaryFramer', 'BitcoinFramer',
           'OversizedPayloadError', 'BadChecksumError', 'BadMagicError')
@@ -34,38 +34,38 @@ from asyncio import Queue
class FramerBase(object):
-    '''Abstract base class for a framer.
+    """Abstract base class for a framer.
    A framer breaks an incoming byte stream into protocol messages,
    buffering if necessary. It also frames outgoing messages into
    a byte stream.
-    '''
+    """
    def frame(self, message):
-        '''Return the framed message.'''
+        """Return the framed message."""
        raise NotImplementedError
    def received_bytes(self, data):
-        '''Pass incoming network bytes.'''
+        """Pass incoming network bytes."""
        raise NotImplementedError
    async def receive_message(self):
-        '''Wait for a complete unframed message to arrive, and return it.'''
+        """Wait for a complete unframed message to arrive, and return it."""
        raise NotImplementedError
class NewlineFramer(FramerBase):
-    '''A framer for a protocol where messages are separated by newlines.'''
+    """A framer for a protocol where messages are separated by newlines."""
    # The default max_size value is motivated by JSONRPC, where a
    # normal request will be 250 bytes or less, and a reasonable
    # batch may contain 4000 requests.
    def __init__(self, max_size=250 * 4000):
-        '''max_size - an anti-DoS measure. If, after processing an incoming
+        """max_size - an anti-DoS measure. If, after processing an incoming
        message, buffered data would exceed max_size bytes, that
        buffered data is dropped entirely and the framer waits for a
        newline character to re-synchronize the stream.
-        '''
+        """
        self.max_size = max_size
        self.queue = Queue()
        self.received_bytes = self.queue.put_nowait
@@ -105,9 +105,9 @@ class NewlineFramer(FramerBase):
class ByteQueue(object):
-    '''A producer-consumer queue. Incoming network data is put as it
+    """A producer-consumer queue. Incoming network data is put as it
    arrives, and the consumer calls an async method waiting for data of
-    a specific length.'''
+    a specific length."""
    def __init__(self):
        self.queue = Queue()
@@ -127,7 +127,7 @@ class ByteQueue(object):
class BinaryFramer(object):
-    '''A framer for binary messaging protocols.'''
+    """A framer for binary messaging protocols."""
    def __init__(self):
        self.byte_queue = ByteQueue()
@@ -165,12 +165,12 @@ pack_le_uint32 = struct_le_I.pack
def sha256(x):
-    '''Simple wrapper of hashlib sha256.'''
+    """Simple wrapper of hashlib sha256."""
    return _sha256(x).digest()
def double_sha256(x):
-    '''SHA-256 of SHA-256, as used extensively in bitcoin.'''
+    """SHA-256 of SHA-256, as used extensively in bitcoin."""
    return sha256(sha256(x))
@@ -187,7 +187,7 @@ class OversizedPayloadError(Exception):
class BitcoinFramer(BinaryFramer):
-    '''Provides a framer of binary message payloads in the style of the
+    """Provides a framer of binary message payloads in the style of the
    Bitcoin network protocol.
    Each binary message has the following elements, in order:
@@ -201,7 +201,7 @@ class BitcoinFramer(BinaryFramer):
    Call frame(command, payload) to get a framed message.
    Pass incoming network bytes to received_bytes().
    Wait on receive_message() to get incoming (command, payload) pairs.
-    '''
+    """
    def __init__(self, magic, max_block_size):
        def pad_command(command):
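
For orientation, a minimal sketch of how the framer API documented above is driven. This is illustrative code, not part of the commit; it assumes only what the docstrings state, plus the inference (from the NewlineFramer docstring) that framing appends the newline separator:

    import asyncio

    async def demo():
        framer = NewlineFramer()
        framed = framer.frame(b'{"id": 1}')        # framed message, ready to send
        framer.received_bytes(b'{"id": 2}\n')      # raw bytes as they arrive
        message = await framer.receive_message()   # one complete, unframed message
        assert message == b'{"id": 2}'

    asyncio.run(demo())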

View file

@@ -23,7 +23,7 @@
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-'''Classes for JSONRPC versions 1.0 and 2.0, and a loose interpretation.'''
+"""Classes for JSONRPC versions 1.0 and 2.0, and a loose interpretation."""
__all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
           'JSONRPCAutoDetect', 'Request', 'Notification', 'Batch',
@@ -156,7 +156,7 @@ class ProtocolError(CodeMessageError):
class JSONRPC(object):
-    '''Abstract base class that interprets and constructs JSON RPC messages.'''
+    """Abstract base class that interprets and constructs JSON RPC messages."""
    # Error codes. See http://www.jsonrpc.org/specification
    PARSE_ERROR = -32700
@@ -172,24 +172,24 @@ class JSONRPC(object):
    @classmethod
    def _message_id(cls, message, require_id):
-        '''Validate the message is a dictionary and return its ID.
+        """Validate the message is a dictionary and return its ID.
        Raise an error if the message is invalid or the ID is of an
        invalid type. If it has no ID, raise an error if require_id
        is True, otherwise return None.
-        '''
+        """
        raise NotImplementedError
    @classmethod
    def _validate_message(cls, message):
-        '''Validate other parts of the message other than those
-        done in _message_id.'''
+        """Validate other parts of the message other than those
+        done in _message_id."""
        pass
    @classmethod
    def _request_args(cls, request):
-        '''Validate the existence and type of the arguments passed
-        in the request dictionary.'''
+        """Validate the existence and type of the arguments passed
+        in the request dictionary."""
        raise NotImplementedError
    @classmethod
@@ -221,7 +221,7 @@ class JSONRPC(object):
    @classmethod
    def _message_to_payload(cls, message):
-        '''Returns a Python object or a ProtocolError.'''
+        """Returns a Python object or a ProtocolError."""
        try:
            return json.loads(message.decode())
        except UnicodeDecodeError:
@@ -245,7 +245,7 @@ class JSONRPC(object):
    @classmethod
    def message_to_item(cls, message):
-        '''Translate an unframed received message and return an
+        """Translate an unframed received message and return an
        (item, request_id) pair.
        The item can be a Request, Notification, Response or a list.
@@ -264,7 +264,7 @@ class JSONRPC(object):
        the response was bad.
        raises: ProtocolError
-        '''
+        """
        payload = cls._message_to_payload(message)
        if isinstance(payload, dict):
            if 'method' in payload:
@@ -282,19 +282,19 @@ class JSONRPC(object):
    # Message formation
    @classmethod
    def request_message(cls, item, request_id):
-        '''Convert an RPCRequest item to a message.'''
+        """Convert an RPCRequest item to a message."""
        assert isinstance(item, Request)
        return cls.encode_payload(cls.request_payload(item, request_id))
    @classmethod
    def notification_message(cls, item):
-        '''Convert an RPCRequest item to a message.'''
+        """Convert an RPCRequest item to a message."""
        assert isinstance(item, Notification)
        return cls.encode_payload(cls.request_payload(item, None))
    @classmethod
    def response_message(cls, result, request_id):
-        '''Convert a response result (or RPCError) to a message.'''
+        """Convert a response result (or RPCError) to a message."""
        if isinstance(result, CodeMessageError):
            payload = cls.error_payload(result, request_id)
        else:
@@ -303,7 +303,7 @@ class JSONRPC(object):
    @classmethod
    def batch_message(cls, batch, request_ids):
-        '''Convert a request Batch to a message.'''
+        """Convert a request Batch to a message."""
        assert isinstance(batch, Batch)
        if not cls.allow_batches:
            raise ProtocolError.invalid_request(
@@ -317,9 +317,9 @@ class JSONRPC(object):
    @classmethod
    def batch_message_from_parts(cls, messages):
-        '''Convert messages, one per batch item, into a batch message. At
+        """Convert messages, one per batch item, into a batch message. At
        least one message must be passed.
-        '''
+        """
        # Comma-separate the messages and wrap the lot in square brackets
        middle = b', '.join(messages)
        if not middle:
@@ -328,7 +328,7 @@ class JSONRPC(object):
    @classmethod
    def encode_payload(cls, payload):
-        '''Encode a Python object as JSON and convert it to bytes.'''
+        """Encode a Python object as JSON and convert it to bytes."""
        try:
            return json.dumps(payload).encode()
        except TypeError:
@@ -337,7 +337,7 @@ class JSONRPC(object):
class JSONRPCv1(JSONRPC):
-    '''JSON RPC version 1.0.'''
+    """JSON RPC version 1.0."""
    allow_batches = False
@@ -392,7 +392,7 @@ class JSONRPCv1(JSONRPC):
    @classmethod
    def request_payload(cls, request, request_id):
-        '''JSON v1 request (or notification) payload.'''
+        """JSON v1 request (or notification) payload."""
        if isinstance(request.args, dict):
            raise ProtocolError.invalid_args(
                'JSONRPCv1 does not support named arguments')
@@ -404,7 +404,7 @@ class JSONRPCv1(JSONRPC):
    @classmethod
    def response_payload(cls, result, request_id):
-        '''JSON v1 response payload.'''
+        """JSON v1 response payload."""
        return {
            'result': result,
            'error': None,
@@ -421,7 +421,7 @@ class JSONRPCv1(JSONRPC):
class JSONRPCv2(JSONRPC):
-    '''JSON RPC version 2.0.'''
+    """JSON RPC version 2.0."""
    @classmethod
    def _message_id(cls, message, require_id):
@@ -477,7 +477,7 @@ class JSONRPCv2(JSONRPC):
    @classmethod
    def request_payload(cls, request, request_id):
-        '''JSON v2 request (or notification) payload.'''
+        """JSON v2 request (or notification) payload."""
        payload = {
            'jsonrpc': '2.0',
            'method': request.method,
@@ -492,7 +492,7 @@ class JSONRPCv2(JSONRPC):
    @classmethod
    def response_payload(cls, result, request_id):
-        '''JSON v2 response payload.'''
+        """JSON v2 response payload."""
        return {
            'jsonrpc': '2.0',
            'result': result,
@@ -509,7 +509,7 @@ class JSONRPCv2(JSONRPC):
class JSONRPCLoose(JSONRPC):
-    '''A relaxed version of JSON RPC.'''
+    """A relaxed version of JSON RPC."""
    # Don't be so loose we accept any old message ID
    _message_id = JSONRPCv2._message_id
@@ -546,7 +546,7 @@ class JSONRPCAutoDetect(JSONRPCv2):
    @classmethod
    def detect_protocol(cls, message):
-        '''Attempt to detect the protocol from the message.'''
+        """Attempt to detect the protocol from the message."""
        main = cls._message_to_payload(message)
        def protocol_for_payload(payload):
@@ -581,13 +581,13 @@ class JSONRPCAutoDetect(JSONRPCv2):
class JSONRPCConnection(object):
-    '''Maintains state of a JSON RPC connection, in particular
+    """Maintains state of a JSON RPC connection, in particular
    encapsulating the handling of request IDs.
    protocol - the JSON RPC protocol to follow
    max_response_size - responses over this size send an error response
                        instead.
-    '''
+    """
    _id_counter = itertools.count()
@@ -684,7 +684,7 @@ class JSONRPCConnection(object):
    # External API
    #
    def send_request(self, request):
-        '''Send a Request. Return a (message, event) pair.
+        """Send a Request. Return a (message, event) pair.
        The message is an unframed message to send over the network.
        Wait on the event for the response, which will be in the
@@ -692,7 +692,7 @@ class JSONRPCConnection(object):
        Raises: ProtocolError if the request violates the protocol
        in some way.
-        '''
+        """
        request_id = next(self._id_counter)
        message = self._protocol.request_message(request, request_id)
        return message, self._event(request, request_id)
@@ -708,13 +708,13 @@ class JSONRPCConnection(object):
        return message, event
    def receive_message(self, message):
-        '''Call with an unframed message received from the network.
+        """Call with an unframed message received from the network.
        Raises: ProtocolError if the message violates the protocol in
        some way. However, if it happened in a response that can be
        paired with a request, the ProtocolError is instead set in the
        result attribute of the send_request() that caused the error.
-        '''
+        """
        try:
            item, request_id = self._protocol.message_to_item(message)
        except ProtocolError as e:
@@ -743,7 +743,7 @@ class JSONRPCConnection(object):
        return self.receive_message(message)
    def cancel_pending_requests(self):
-        '''Cancel all pending requests.'''
+        """Cancel all pending requests."""
        exception = CancelledError()
        for request, event in self._requests.values():
            event.result = exception
@@ -751,7 +751,7 @@ class JSONRPCConnection(object):
        self._requests.clear()
    def pending_requests(self):
-        '''All sent requests that have not received a response.'''
+        """All sent requests that have not received a response."""
        return [request for request, event in self._requests.values()]
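
The send_request and receive_message docstrings above together describe a full round trip. A hedged sketch of that cycle (names are from this diff; the transport is elided, and the response id of 0 assumes this is the connection's first request, since _id_counter starts at zero):

    async def round_trip(connection, send_over_network):
        # send_request returns an unframed message plus an event to wait on.
        message, event = connection.send_request(Request('ping', []))
        await send_over_network(message)
        # When the unframed response arrives, feeding it back pairs it
        # with the pending request and sets event.result.
        connection.receive_message(b'{"jsonrpc": "2.0", "result": "pong", "id": 0}')
        await event.wait()
        return event.result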

View file

@@ -54,7 +54,7 @@ class Connector:
        self.kwargs = kwargs
    async def create_connection(self):
-        '''Initiate a connection.'''
+        """Initiate a connection."""
        connector = self.proxy or self.loop
        return await connector.create_connection(
            self.session_factory, self.host, self.port, **self.kwargs)
@@ -70,7 +70,7 @@ class Connector:
class SessionBase(asyncio.Protocol):
-    '''Base class of networking sessions.
+    """Base class of networking sessions.
    There is no client / server distinction other than who initiated
    the connection.
@@ -81,7 +81,7 @@ class SessionBase(asyncio.Protocol):
    Alternatively if used in a with statement, the connection is made
    on entry to the block, and closed on exit from the block.
-    '''
+    """
    max_errors = 10
@@ -138,7 +138,7 @@ class SessionBase(asyncio.Protocol):
        await self._concurrency.set_max_concurrent(target)
    def _using_bandwidth(self, size):
-        '''Called when sending or receiving size bytes.'''
+        """Called when sending or receiving size bytes."""
        self.bw_charge += size
    async def _limited_wait(self, secs):
@@ -173,7 +173,7 @@ class SessionBase(asyncio.Protocol):
    # asyncio framework
    def data_received(self, framed_message):
-        '''Called by asyncio when a message comes in.'''
+        """Called by asyncio when a message comes in."""
        if self.verbosity >= 4:
            self.logger.debug(f'Received framed message {framed_message}')
        self.recv_size += len(framed_message)
@@ -181,21 +181,21 @@ class SessionBase(asyncio.Protocol):
        self.framer.received_bytes(framed_message)
    def pause_writing(self):
-        '''Transport calls when the send buffer is full.'''
+        """Transport calls when the send buffer is full."""
        if not self.is_closing():
            self._can_send.clear()
            self.transport.pause_reading()
    def resume_writing(self):
-        '''Transport calls when the send buffer has room.'''
+        """Transport calls when the send buffer has room."""
        if not self._can_send.is_set():
            self._can_send.set()
            self.transport.resume_reading()
    def connection_made(self, transport):
-        '''Called by asyncio when a connection is established.
-        Derived classes overriding this method must call this first.'''
+        """Called by asyncio when a connection is established.
+        Derived classes overriding this method must call this first."""
        self.transport = transport
        # This would throw if called on a closed SSL transport. Fixed
        # in asyncio in Python 3.6.1 and 3.5.4
@@ -209,9 +209,9 @@ class SessionBase(asyncio.Protocol):
        self._pm_task = self.loop.create_task(self._receive_messages())
    def connection_lost(self, exc):
-        '''Called by asyncio when the connection closes.
-        Tear down things done in connection_made.'''
+        """Called by asyncio when the connection closes.
+        Tear down things done in connection_made."""
        self._address = None
        self.transport = None
        self._task_group.cancel()
@@ -221,21 +221,21 @@ class SessionBase(asyncio.Protocol):
    # External API
    def default_framer(self):
-        '''Return a default framer.'''
+        """Return a default framer."""
        raise NotImplementedError
    def peer_address(self):
-        '''Returns the peer's address (Python networking address), or None if
+        """Returns the peer's address (Python networking address), or None if
        no connection or an error.
        This is the result of socket.getpeername() when the connection
        was made.
-        '''
+        """
        return self._address
    def peer_address_str(self):
-        '''Returns the peer's IP address and port as a human-readable
-        string.'''
+        """Returns the peer's IP address and port as a human-readable
+        string."""
        if not self._address:
            return 'unknown'
        ip_addr_str, port = self._address[:2]
@@ -245,16 +245,16 @@ class SessionBase(asyncio.Protocol):
        return f'{ip_addr_str}:{port}'
    def is_closing(self):
-        '''Return True if the connection is closing.'''
+        """Return True if the connection is closing."""
        return not self.transport or self.transport.is_closing()
    def abort(self):
-        '''Forcefully close the connection.'''
+        """Forcefully close the connection."""
        if self.transport:
            self.transport.abort()
    async def close(self, *, force_after=30):
-        '''Close the connection and return when closed.'''
+        """Close the connection and return when closed."""
        self._close()
        if self._pm_task:
            with suppress(CancelledError):
@@ -264,12 +264,12 @@ class SessionBase(asyncio.Protocol):
class MessageSession(SessionBase):
-    '''Session class for protocols where messages are not tied to responses,
+    """Session class for protocols where messages are not tied to responses,
    such as the Bitcoin protocol.
    To use as a client (connection-opening) session, pass host, port
    and perhaps a proxy.
-    '''
+    """
    async def _receive_messages(self):
        while not self.is_closing():
            try:
@@ -303,7 +303,7 @@ class MessageSession(SessionBase):
                await self._task_group.add(self._throttled_message(message))
    async def _throttled_message(self, message):
-        '''Process a single request, respecting the concurrency limit.'''
+        """Process a single request, respecting the concurrency limit."""
        async with self._concurrency.semaphore:
            try:
                await self.handle_message(message)
@@ -318,15 +318,15 @@ class MessageSession(SessionBase):
    # External API
    def default_framer(self):
-        '''Return a bitcoin framer.'''
+        """Return a bitcoin framer."""
        return BitcoinFramer(bytes.fromhex('e3e1f3e8'), 128_000_000)
    async def handle_message(self, message):
-        '''message is a (command, payload) pair.'''
+        """message is a (command, payload) pair."""
        pass
    async def send_message(self, message):
-        '''Send a message (command, payload) over the network.'''
+        """Send a message (command, payload) over the network."""
        await self._send_message(message)
@@ -337,7 +337,7 @@ class BatchError(Exception):
class BatchRequest(object):
-    '''Used to build a batch request to send to the server. Stores
+    """Used to build a batch request to send to the server. Stores
    the
    Attributes batch and results are initially None.
@@ -367,7 +367,7 @@ class BatchRequest(object):
    RPC error response, or violated the protocol in some way, a
    BatchError exception is raised. Otherwise the caller can be
    certain each request returned a standard result.
-    '''
+    """
    def __init__(self, session, raise_errors):
        self._session = session
@@ -401,8 +401,8 @@ class BatchRequest(object):
class RPCSession(SessionBase):
-    '''Base class for protocols where a message can lead to a response,
-    for example JSON RPC.'''
+    """Base class for protocols where a message can lead to a response,
+    for example JSON RPC."""
    def __init__(self, *, framer=None, loop=None, connection=None):
        super().__init__(framer=framer, loop=loop)
@@ -435,7 +435,7 @@ class RPCSession(SessionBase):
            await self._task_group.add(self._throttled_request(request))
    async def _throttled_request(self, request):
-        '''Process a single request, respecting the concurrency limit.'''
+        """Process a single request, respecting the concurrency limit."""
        async with self._concurrency.semaphore:
            try:
                result = await self.handle_request(request)
@@ -461,18 +461,18 @@ class RPCSession(SessionBase):
    # External API
    def default_connection(self):
-        '''Return a default connection if the user provides none.'''
+        """Return a default connection if the user provides none."""
        return JSONRPCConnection(JSONRPCv2)
    def default_framer(self):
-        '''Return a default framer.'''
+        """Return a default framer."""
        return NewlineFramer()
    async def handle_request(self, request):
        pass
    async def send_request(self, method, args=()):
-        '''Send an RPC request over the network.'''
+        """Send an RPC request over the network."""
        message, event = self.connection.send_request(Request(method, args))
        await self._send_message(message)
        await event.wait()
@@ -482,12 +482,12 @@ class RPCSession(SessionBase):
        return result
    async def send_notification(self, method, args=()):
-        '''Send an RPC notification over the network.'''
+        """Send an RPC notification over the network."""
        message = self.connection.send_notification(Notification(method, args))
        await self._send_message(message)
    def send_batch(self, raise_errors=False):
-        '''Return a BatchRequest. Intended to be used like so:
+        """Return a BatchRequest. Intended to be used like so:
        async with session.send_batch() as batch:
            batch.add_request("method1")
@@ -499,12 +499,12 @@ class RPCSession(SessionBase):
        Note that in some circumstances exceptions can be raised; see
        BatchRequest doc string.
-        '''
+        """
        return BatchRequest(self, raise_errors)
class Server(object):
-    '''A simple wrapper around an asyncio.Server object.'''
+    """A simple wrapper around an asyncio.Server object."""
    def __init__(self, session_factory, host=None, port=None, *,
                 loop=None, **kwargs):
@@ -520,9 +520,9 @@ class Server(object):
            self._session_factory, self.host, self.port, **self._kwargs)
    async def close(self):
-        '''Close the listening socket. This does not close any ServerSession
+        """Close the listening socket. This does not close any ServerSession
        objects created to handle incoming connections.
-        '''
+        """
        if self.server:
            self.server.close()
            await self.server.wait_closed()
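
The send_batch docstring above already sketches the intended call pattern; slightly expanded here (batch.add_request comes from the docstring itself, and the results attribute from the BatchRequest docstring; treat this as a sketch rather than the full API):

    async def sum_on_server(session):
        async with session.send_batch(raise_errors=True) as batch:
            batch.add_request("method1")
            batch.add_request("sum", [1, 2, 4])
        # On leaving the block the batch has been sent and answered;
        # batch.results then holds one result per request.
        return batch.results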

View file

@@ -23,7 +23,7 @@
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-'''SOCKS proxying.'''
+"""SOCKS proxying."""
import sys
import asyncio
@@ -42,16 +42,16 @@ SOCKSUserAuth = collections.namedtuple("SOCKSUserAuth", "username password")
class SOCKSError(Exception):
-    '''Base class for SOCKS exceptions. Each raised exception will be
-    an instance of a derived class.'''
+    """Base class for SOCKS exceptions. Each raised exception will be
+    an instance of a derived class."""
class SOCKSProtocolError(SOCKSError):
-    '''Raised when the proxy does not follow the SOCKS protocol'''
+    """Raised when the proxy does not follow the SOCKS protocol"""
class SOCKSFailure(SOCKSError):
-    '''Raised when the proxy refuses or fails to make a connection'''
+    """Raised when the proxy refuses or fails to make a connection"""
class NeedData(Exception):
@@ -83,7 +83,7 @@ class SOCKSBase(object):
class SOCKS4(SOCKSBase):
-    '''SOCKS4 protocol wrapper.'''
+    """SOCKS4 protocol wrapper."""
    # See http://ftp.icm.edu.pl/packages/socks/socks4/SOCKS4.protocol
    REPLY_CODES = {
@@ -159,7 +159,7 @@ class SOCKS4a(SOCKS4):
class SOCKS5(SOCKSBase):
-    '''SOCKS protocol wrapper.'''
+    """SOCKS protocol wrapper."""
    # See https://tools.ietf.org/html/rfc1928
    ERROR_CODES = {
@@ -269,12 +269,12 @@ class SOCKS5(SOCKSBase):
class SOCKSProxy(object):
    def __init__(self, address, protocol, auth):
-        '''A SOCKS proxy at an address following a SOCKS protocol. auth is an
+        """A SOCKS proxy at an address following a SOCKS protocol. auth is an
        authentication method to use when connecting, or None.
        address is a (host, port) pair; for IPv6 it can instead be a
        (host, port, flowinfo, scopeid) 4-tuple.
-        '''
+        """
        self.address = address
        self.protocol = protocol
        self.auth = auth
@@ -305,11 +305,11 @@ class SOCKSProxy(object):
            client.receive_data(data)
    async def _connect_one(self, host, port):
-        '''Connect to the proxy and perform a handshake requesting a
+        """Connect to the proxy and perform a handshake requesting a
        connection to (host, port).
        Return the open socket on success, or the exception on failure.
-        '''
+        """
        client = self.protocol(host, port, self.auth)
        sock = socket.socket()
        loop = asyncio.get_event_loop()
@@ -327,11 +327,11 @@ class SOCKSProxy(object):
            return e
    async def _connect(self, addresses):
-        '''Connect to the proxy and perform a handshake requesting a
+        """Connect to the proxy and perform a handshake requesting a
        connection to each address in addresses.
        Return an (open_socket, address) pair on success.
-        '''
+        """
        assert len(addresses) > 0
        exceptions = []
@@ -347,9 +347,9 @@ class SOCKSProxy(object):
                OSError(f'multiple exceptions: {", ".join(strings)}'))
    async def _detect_proxy(self):
-        '''Return True if it appears we can connect to a SOCKS proxy,
+        """Return True if it appears we can connect to a SOCKS proxy,
        otherwise False.
-        '''
+        """
        if self.protocol is SOCKS4a:
            host, port = 'www.apple.com', 80
        else:
@@ -366,7 +366,7 @@ class SOCKSProxy(object):
    @classmethod
    async def auto_detect_address(cls, address, auth):
-        '''Try to detect a SOCKS proxy at address using the authentication
+        """Try to detect a SOCKS proxy at address using the authentication
        method (or None). SOCKS5, SOCKS4a and SOCKS4 are tried in
        order. If a SOCKS proxy is detected a SOCKSProxy object is
        returned.
@@ -375,7 +375,7 @@ class SOCKSProxy(object):
        example, it may have no network connectivity.
        If no proxy is detected return None.
-        '''
+        """
        for protocol in (SOCKS5, SOCKS4a, SOCKS4):
            proxy = cls(address, protocol, auth)
            if await proxy._detect_proxy():
@@ -384,7 +384,7 @@ class SOCKSProxy(object):
    @classmethod
    async def auto_detect_host(cls, host, ports, auth):
-        '''Try to detect a SOCKS proxy on a host on one of the ports.
+        """Try to detect a SOCKS proxy on a host on one of the ports.
        Calls auto_detect_address for the ports in order; a
        SOCKSProxy object is returned for the first detected proxy.
@@ -394,7 +394,7 @@ class SOCKSProxy(object):
        example, it may have no network connectivity.
        If no proxy is detected return None.
-        '''
+        """
        for port in ports:
            address = (host, port)
            proxy = await cls.auto_detect_address(address, auth)
@@ -406,7 +406,7 @@ class SOCKSProxy(object):
    async def create_connection(self, protocol_factory, host, port, *,
                                resolve=False, ssl=None,
                                family=0, proto=0, flags=0):
-        '''Set up a connection to (host, port) through the proxy.
+        """Set up a connection to (host, port) through the proxy.
        If resolve is True then host is resolved locally with
        getaddrinfo using family, proto and flags, otherwise the proxy
@@ -417,7 +417,7 @@ class SOCKSProxy(object):
        protocol to the address of the successful remote connection.
        Additionally raises SOCKSError if something goes wrong with
        the proxy handshake.
-        '''
+        """
        loop = asyncio.get_event_loop()
        if resolve:
            infos = await loop.getaddrinfo(host, port, family=family,
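
Putting the SOCKSProxy docstrings together, detection followed by a tunnelled connection looks roughly like this (the proxy address and target are placeholders; auth of None means no authentication):

    async def connect_through_proxy(session_factory):
        # SOCKS5, SOCKS4a and SOCKS4 are tried in order at this address.
        proxy = await SOCKSProxy.auto_detect_address(('localhost', 9050), None)
        if proxy is None:
            return None  # nothing detected, e.g. no network connectivity
        # The proxy resolves the target host name itself unless resolve=True.
        return await proxy.create_connection(session_factory, 'example.com', 80)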

View file

@@ -6,7 +6,7 @@
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.
-'''Block prefetcher and chain processor.'''
+"""Block prefetcher and chain processor."""
import asyncio
@@ -21,7 +21,7 @@ from torba.server.db import FlushData
class Prefetcher:
-    '''Prefetches blocks (in the forward direction only).'''
+    """Prefetches blocks (in the forward direction only)."""
    def __init__(self, daemon, coin, blocks_event):
        self.logger = class_logger(__name__, self.__class__.__name__)
@@ -43,7 +43,7 @@ class Prefetcher:
        self.polling_delay = 5
    async def main_loop(self, bp_height):
-        '''Loop forever polling for more blocks.'''
+        """Loop forever polling for more blocks."""
        await self.reset_height(bp_height)
        while True:
            try:
@@ -55,7 +55,7 @@ class Prefetcher:
                self.logger.info(f'ignoring daemon error: {e}')
    def get_prefetched_blocks(self):
-        '''Called by block processor when it is processing queued blocks.'''
+        """Called by block processor when it is processing queued blocks."""
        blocks = self.blocks
        self.blocks = []
        self.cache_size = 0
@@ -63,12 +63,12 @@ class Prefetcher:
        return blocks
    async def reset_height(self, height):
-        '''Reset to prefetch blocks from the block processor's height.
+        """Reset to prefetch blocks from the block processor's height.
        Used in blockchain reorganisations. This coroutine can be
        called asynchronously to the _prefetch_blocks coroutine so we
        must synchronize with a semaphore.
-        '''
+        """
        async with self.semaphore:
            self.blocks.clear()
            self.cache_size = 0
@@ -86,10 +86,10 @@ class Prefetcher:
                            .format(daemon_height))
    async def _prefetch_blocks(self):
-        '''Prefetch some blocks and put them on the queue.
+        """Prefetch some blocks and put them on the queue.
        Repeats until the queue is full or caught up.
-        '''
+        """
        daemon = self.daemon
        daemon_height = await daemon.height()
        async with self.semaphore:
@@ -136,15 +136,15 @@ class Prefetcher:
class ChainError(Exception):
-    '''Raised on error processing blocks.'''
+    """Raised on error processing blocks."""
class BlockProcessor:
-    '''Process blocks and update the DB state to match.
+    """Process blocks and update the DB state to match.
    Employ a prefetcher to prefetch blocks in batches for processing.
    Coordinate backing up in case of chain reorganisations.
-    '''
+    """
    def __init__(self, env, db, daemon, notifications):
        self.env = env
@@ -187,9 +187,9 @@ class BlockProcessor:
        return await asyncio.shield(run_in_thread_locked())
    async def check_and_advance_blocks(self, raw_blocks):
-        '''Process the list of raw blocks passed. Detects and handles
+        """Process the list of raw blocks passed. Detects and handles
        reorgs.
-        '''
+        """
        if not raw_blocks:
            return
        first = self.height + 1
@@ -224,10 +224,10 @@ class BlockProcessor:
            await self.prefetcher.reset_height(self.height)
    async def reorg_chain(self, count=None):
-        '''Handle a chain reorganisation.
+        """Handle a chain reorganisation.
        Count is the number of blocks to simulate a reorg, or None for
-        a real reorg.'''
+        a real reorg."""
        if count is None:
            self.logger.info('chain reorg detected')
        else:
@@ -260,12 +260,12 @@ class BlockProcessor:
        await self.prefetcher.reset_height(self.height)
    async def reorg_hashes(self, count):
-        '''Return a pair (start, last, hashes) of blocks to back up during a
+        """Return a pair (start, last, hashes) of blocks to back up during a
        reorg.
        The hashes are returned in order of increasing height. Start
        is the height of the first hash, last of the last.
-        '''
+        """
        start, count = await self.calc_reorg_range(count)
        last = start + count - 1
        s = '' if count == 1 else 's'
@@ -275,11 +275,11 @@ class BlockProcessor:
        return start, last, await self.db.fs_block_hashes(start, count)
    async def calc_reorg_range(self, count):
-        '''Calculate the reorg range'''
+        """Calculate the reorg range"""
        def diff_pos(hashes1, hashes2):
-            '''Returns the index of the first difference in the hash lists.
-            If both lists match returns their length.'''
+            """Returns the index of the first difference in the hash lists.
+            If both lists match returns their length."""
            for n, (hash1, hash2) in enumerate(zip(hashes1, hashes2)):
                if hash1 != hash2:
                    return n
@@ -318,7 +318,7 @@ class BlockProcessor:
    # - Flushing
    def flush_data(self):
-        '''The data for a flush. The lock must be taken.'''
+        """The data for a flush. The lock must be taken."""
        assert self.state_lock.locked()
        return FlushData(self.height, self.tx_count, self.headers,
                         self.tx_hashes, self.undo_infos, self.utxo_cache,
@@ -342,7 +342,7 @@ class BlockProcessor:
        self.next_cache_check = time.time() + 30
    def check_cache_size(self):
-        '''Flush a cache if it gets too big.'''
+        """Flush a cache if it gets too big."""
        # Good average estimates based on traversal of subobjects and
        # requesting size from Python (see deep_getsizeof).
        one_MB = 1000*1000
@@ -368,10 +368,10 @@ class BlockProcessor:
            return None
    def advance_blocks(self, blocks):
-        '''Synchronously advance the blocks.
+        """Synchronously advance the blocks.
        It is already verified they correctly connect onto our tip.
-        '''
+        """
        min_height = self.db.min_undo_height(self.daemon.cached_height())
        height = self.height
@@ -436,11 +436,11 @@ class BlockProcessor:
        return undo_info
    def backup_blocks(self, raw_blocks):
-        '''Backup the raw blocks and flush.
+        """Backup the raw blocks and flush.
        The blocks should be in order of decreasing height, starting at
        self.height. A flush is performed once the blocks are backed up.
-        '''
+        """
        self.db.assert_flushed(self.flush_data())
        assert self.height >= len(raw_blocks)
@@ -500,7 +500,7 @@ class BlockProcessor:
        assert n == 0
        self.tx_count -= len(txs)
-    '''An in-memory UTXO cache, representing all changes to UTXO state
+    """An in-memory UTXO cache, representing all changes to UTXO state
    since the last DB flush.
    We want to store millions of these in memory for optimal
@@ -552,15 +552,15 @@ class BlockProcessor:
    looking up a UTXO the prefix space of the compressed hash needs to
    be searched and resolved if necessary with the tx_num. The
    collision rate is low (<0.1%).
-    '''
+    """
    def spend_utxo(self, tx_hash, tx_idx):
-        '''Spend a UTXO and return the 33-byte value.
+        """Spend a UTXO and return the 33-byte value.
        If the UTXO is not in the cache it must be on disk. We store
        all UTXOs so not finding one indicates a logic error or DB
        corruption.
-        '''
+        """
        # Fast track is it being in the cache
        idx_packed = pack('<H', tx_idx)
        cache_value = self.utxo_cache.pop(tx_hash + idx_packed, None)
@@ -599,7 +599,7 @@ class BlockProcessor:
                         .format(hash_to_hex_str(tx_hash), tx_idx))
    async def _process_prefetched_blocks(self):
-        '''Loop forever processing blocks as they arrive.'''
+        """Loop forever processing blocks as they arrive."""
        while True:
            if self.height == self.daemon.cached_height():
                if not self._caught_up_event.is_set():
@@ -635,7 +635,7 @@ class BlockProcessor:
    # --- External API
    async def fetch_and_process_blocks(self, caught_up_event):
-        '''Fetch, process and index blocks from the daemon.
+        """Fetch, process and index blocks from the daemon.
        Sets caught_up_event when first caught up. Flushes to disk
        and shuts down cleanly if cancelled.
@@ -645,7 +645,7 @@ class BlockProcessor:
        processed but not written to disk, it should write those to
        disk before exiting, as otherwise a significant amount of work
        could be lost.
-        '''
+        """
        self._caught_up_event = caught_up_event
        await self._first_open_dbs()
        try:
@@ -660,10 +660,10 @@ class BlockProcessor:
            self.db.close()
    def force_chain_reorg(self, count):
-        '''Force a reorg of the given number of blocks.
+        """Force a reorg of the given number of blocks.
        Returns True if a reorg is queued, false if not caught up.
-        '''
+        """
        if self._caught_up_event.is_set():
            self.reorg_count = count
            self.blocks_event.set()
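
The diff_pos helper in calc_reorg_range is cut off by the hunk boundary above. A self-contained version of what its docstring specifies, with the final return inferred from "if both lists match returns their length":

    def diff_pos(hashes1, hashes2):
        """Return the index of the first difference in the hash lists.
        If both lists match, return their length."""
        for n, (hash1, hash2) in enumerate(zip(hashes1, hashes2)):
            if hash1 != hash2:
                return n
        return min(len(hashes1), len(hashes2))

    # The reorg range starts at the first height whose hash differs.
    assert diff_pos([b'a', b'b', b'c'], [b'a', b'x', b'c']) == 1
    assert diff_pos([b'a', b'b'], [b'a', b'b']) == 2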

View file

@ -24,11 +24,11 @@
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''Module providing coin abstraction. """Module providing coin abstraction.
Anything coin-specific should go in this file and be subclassed where Anything coin-specific should go in this file and be subclassed where
necessary for appropriate handling. necessary for appropriate handling.
''' """
from collections import namedtuple from collections import namedtuple
import re import re
@ -55,11 +55,11 @@ OP_RETURN = OpCodes.OP_RETURN
class CoinError(Exception): class CoinError(Exception):
'''Exception raised for coin-related errors.''' """Exception raised for coin-related errors."""
class Coin: class Coin:
'''Base class of coin hierarchy.''' """Base class of coin hierarchy."""
REORG_LIMIT = 200 REORG_LIMIT = 200
# Not sure if these are coin-specific # Not sure if these are coin-specific
@ -88,9 +88,9 @@ class Coin:
@classmethod @classmethod
def lookup_coin_class(cls, name, net): def lookup_coin_class(cls, name, net):
'''Return a coin class given name and network. """Return a coin class given name and network.
Raise an exception if unrecognised.''' Raise an exception if unrecognised."""
req_attrs = ['TX_COUNT', 'TX_COUNT_HEIGHT', 'TX_PER_BLOCK'] req_attrs = ['TX_COUNT', 'TX_COUNT_HEIGHT', 'TX_PER_BLOCK']
for coin in util.subclasses(Coin): for coin in util.subclasses(Coin):
if (coin.NAME.lower() == name.lower() and if (coin.NAME.lower() == name.lower() and
@ -120,10 +120,10 @@ class Coin:
@classmethod @classmethod
def genesis_block(cls, block): def genesis_block(cls, block):
'''Check the Genesis block is the right one for this coin. """Check the Genesis block is the right one for this coin.
Return the block less its unspendable coinbase. Return the block less its unspendable coinbase.
''' """
header = cls.block_header(block, 0) header = cls.block_header(block, 0)
header_hex_hash = hash_to_hex_str(cls.header_hash(header)) header_hex_hash = hash_to_hex_str(cls.header_hash(header))
if header_hex_hash != cls.GENESIS_HASH: if header_hex_hash != cls.GENESIS_HASH:
@ -134,16 +134,16 @@ class Coin:
@classmethod @classmethod
def hashX_from_script(cls, script): def hashX_from_script(cls, script):
'''Returns a hashX from a script, or None if the script is provably """Returns a hashX from a script, or None if the script is provably
unspendable so the output can be dropped. unspendable so the output can be dropped.
''' """
if script and script[0] == OP_RETURN: if script and script[0] == OP_RETURN:
return None return None
return sha256(script).digest()[:HASHX_LEN] return sha256(script).digest()[:HASHX_LEN]
@staticmethod @staticmethod
def lookup_xverbytes(verbytes): def lookup_xverbytes(verbytes):
'''Return a (is_xpub, coin_class) pair given xpub/xprv verbytes.''' """Return a (is_xpub, coin_class) pair given xpub/xprv verbytes."""
# Order means BTC testnet will override NMC testnet # Order means BTC testnet will override NMC testnet
for coin in util.subclasses(Coin): for coin in util.subclasses(Coin):
if verbytes == coin.XPUB_VERBYTES: if verbytes == coin.XPUB_VERBYTES:
@ -154,23 +154,23 @@ class Coin:
@classmethod @classmethod
def address_to_hashX(cls, address): def address_to_hashX(cls, address):
'''Return a hashX given a coin address.''' """Return a hashX given a coin address."""
return cls.hashX_from_script(cls.pay_to_address_script(address)) return cls.hashX_from_script(cls.pay_to_address_script(address))
@classmethod @classmethod
def P2PKH_address_from_hash160(cls, hash160): def P2PKH_address_from_hash160(cls, hash160):
'''Return a P2PKH address given a public key.''' """Return a P2PKH address given a public key."""
assert len(hash160) == 20 assert len(hash160) == 20
return cls.ENCODE_CHECK(cls.P2PKH_VERBYTE + hash160) return cls.ENCODE_CHECK(cls.P2PKH_VERBYTE + hash160)
@classmethod @classmethod
def P2PKH_address_from_pubkey(cls, pubkey): def P2PKH_address_from_pubkey(cls, pubkey):
'''Return a coin address given a public key.''' """Return a coin address given a public key."""
return cls.P2PKH_address_from_hash160(hash160(pubkey)) return cls.P2PKH_address_from_hash160(hash160(pubkey))
@classmethod @classmethod
def P2SH_address_from_hash160(cls, hash160): def P2SH_address_from_hash160(cls, hash160):
'''Return a coin address given a hash160.''' """Return a coin address given a hash160."""
assert len(hash160) == 20 assert len(hash160) == 20
return cls.ENCODE_CHECK(cls.P2SH_VERBYTES[0] + hash160) return cls.ENCODE_CHECK(cls.P2SH_VERBYTES[0] + hash160)
@ -184,10 +184,10 @@ class Coin:
@classmethod @classmethod
def pay_to_address_script(cls, address): def pay_to_address_script(cls, address):
'''Return a pubkey script that pays to a pubkey hash. """Return a pubkey script that pays to a pubkey hash.
Pass the address (either P2PKH or P2SH) in base58 form. Pass the address (either P2PKH or P2SH) in base58 form.
''' """
raw = cls.DECODE_CHECK(address) raw = cls.DECODE_CHECK(address)
# Require version byte(s) plus hash160. # Require version byte(s) plus hash160.
@ -205,7 +205,7 @@ class Coin:
@classmethod @classmethod
def privkey_WIF(cls, privkey_bytes, compressed): def privkey_WIF(cls, privkey_bytes, compressed):
'''Return the private key encoded in Wallet Import Format.''' """Return the private key encoded in Wallet Import Format."""
payload = bytearray(cls.WIF_BYTE) + privkey_bytes payload = bytearray(cls.WIF_BYTE) + privkey_bytes
if compressed: if compressed:
payload.append(0x01) payload.append(0x01)
@ -213,48 +213,48 @@ class Coin:
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return hash''' """Given a header return hash"""
return double_sha256(header) return double_sha256(header)
@classmethod @classmethod
def header_prevhash(cls, header): def header_prevhash(cls, header):
'''Given a header return previous hash''' """Given a header return previous hash"""
return header[4:36] return header[4:36]
@classmethod @classmethod
def static_header_offset(cls, height): def static_header_offset(cls, height):
'''Given a header height return its offset in the headers file. """Given a header height return its offset in the headers file.
If header sizes change at some point, this is the only code If header sizes change at some point, this is the only code
that needs updating.''' that needs updating."""
assert cls.STATIC_BLOCK_HEADERS assert cls.STATIC_BLOCK_HEADERS
return height * cls.BASIC_HEADER_SIZE return height * cls.BASIC_HEADER_SIZE
@classmethod @classmethod
def static_header_len(cls, height): def static_header_len(cls, height):
'''Given a header height return its length.''' """Given a header height return its length."""
return (cls.static_header_offset(height + 1) return (cls.static_header_offset(height + 1)
- cls.static_header_offset(height)) - cls.static_header_offset(height))
@classmethod @classmethod
def block_header(cls, block, height): def block_header(cls, block, height):
'''Returns the block header given a block and its height.''' """Returns the block header given a block and its height."""
return block[:cls.static_header_len(height)] return block[:cls.static_header_len(height)]
@classmethod @classmethod
def block(cls, raw_block, height): def block(cls, raw_block, height):
'''Return a Block namedtuple given a raw block and its height.''' """Return a Block namedtuple given a raw block and its height."""
header = cls.block_header(raw_block, height) header = cls.block_header(raw_block, height)
txs = cls.DESERIALIZER(raw_block, start=len(header)).read_tx_block() txs = cls.DESERIALIZER(raw_block, start=len(header)).read_tx_block()
return Block(raw_block, header, txs) return Block(raw_block, header, txs)
@classmethod @classmethod
def decimal_value(cls, value): def decimal_value(cls, value):
'''Return the number of standard coin units as a Decimal given a """Return the number of standard coin units as a Decimal given a
quantity of smallest units. quantity of smallest units.
For example 1 BTC is returned for 100 million satoshis. For example 1 BTC is returned for 100 million satoshis.
''' """
return Decimal(value) / cls.VALUE_PER_COIN return Decimal(value) / cls.VALUE_PER_COIN
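A worked example of the conversion, assuming Bitcoin's VALUE_PER_COIN of 100,000,000 satoshis; Decimal keeps the result exact where a float would not:

from decimal import Decimal

VALUE_PER_COIN = 100_000_000                       # satoshis per BTC (assumption)
assert Decimal(123_456_789) / VALUE_PER_COIN == Decimal('1.23456789')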
@classmethod @classmethod
@ -274,12 +274,12 @@ class AuxPowMixin:
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return hash''' """Given a header return hash"""
return double_sha256(header[:cls.BASIC_HEADER_SIZE]) return double_sha256(header[:cls.BASIC_HEADER_SIZE])
@classmethod @classmethod
def block_header(cls, block, height): def block_header(cls, block, height):
'''Return the AuxPow block header bytes''' """Return the AuxPow block header bytes"""
deserializer = cls.DESERIALIZER(block) deserializer = cls.DESERIALIZER(block)
return deserializer.read_header(height, cls.BASIC_HEADER_SIZE) return deserializer.read_header(height, cls.BASIC_HEADER_SIZE)
@ -306,7 +306,7 @@ class EquihashMixin:
@classmethod @classmethod
def block_header(cls, block, height): def block_header(cls, block, height):
'''Return the block header bytes''' """Return the block header bytes"""
deserializer = cls.DESERIALIZER(block) deserializer = cls.DESERIALIZER(block)
return deserializer.read_header(height, cls.BASIC_HEADER_SIZE) return deserializer.read_header(height, cls.BASIC_HEADER_SIZE)
@ -318,7 +318,7 @@ class ScryptMixin:
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return the hash.''' """Given a header return the hash."""
if cls.HEADER_HASH is None: if cls.HEADER_HASH is None:
import scrypt import scrypt
cls.HEADER_HASH = lambda x: scrypt.hash(x, x, 1024, 1, 1, 32) cls.HEADER_HASH = lambda x: scrypt.hash(x, x, 1024, 1, 1, 32)
@ -432,7 +432,7 @@ class BitcoinGold(EquihashMixin, BitcoinMixin, Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return hash''' """Given a header return hash"""
height, = util.unpack_le_uint32_from(header, 68) height, = util.unpack_le_uint32_from(header, 68)
if height >= cls.FORK_HEIGHT: if height >= cls.FORK_HEIGHT:
return double_sha256(header) return double_sha256(header)
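The fork check above hinges on reading a little-endian uint32 embedded at byte offset 68 of the header. A toy demonstration of the same unpacking with the standard struct module (the fabricated 80-byte header here is purely illustrative):

import struct

header = bytes(68) + struct.pack('<I', 491_407) + bytes(8)   # toy header
height, = struct.unpack_from('<I', header, 68)               # LE uint32 at offset 68
assert height == 491_407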
@ -511,7 +511,7 @@ class Emercoin(Coin):
@classmethod @classmethod
def block_header(cls, block, height): def block_header(cls, block, height):
'''Returns the block header given a block and its height.''' """Returns the block header given a block and its height."""
deserializer = cls.DESERIALIZER(block) deserializer = cls.DESERIALIZER(block)
if deserializer.is_merged_block(): if deserializer.is_merged_block():
@ -520,7 +520,7 @@ class Emercoin(Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return hash''' """Given a header return hash"""
return double_sha256(header[:cls.BASIC_HEADER_SIZE]) return double_sha256(header[:cls.BASIC_HEADER_SIZE])
@ -543,7 +543,7 @@ class BitcoinTestnetMixin:
class BitcoinCashTestnet(BitcoinTestnetMixin, Coin): class BitcoinCashTestnet(BitcoinTestnetMixin, Coin):
'''Bitcoin Testnet for Bitcoin Cash daemons.''' """Bitcoin Testnet for Bitcoin Cash daemons."""
NAME = "BitcoinCash" NAME = "BitcoinCash"
PEERS = [ PEERS = [
'electrum-testnet-abc.criptolayer.net s50112', 'electrum-testnet-abc.criptolayer.net s50112',
@ -563,7 +563,7 @@ class BitcoinCashRegtest(BitcoinCashTestnet):
class BitcoinSegwitTestnet(BitcoinTestnetMixin, Coin): class BitcoinSegwitTestnet(BitcoinTestnetMixin, Coin):
'''Bitcoin Testnet for Core bitcoind >= 0.13.1.''' """Bitcoin Testnet for Core bitcoind >= 0.13.1."""
NAME = "BitcoinSegwit" NAME = "BitcoinSegwit"
DESERIALIZER = lib_tx.DeserializerSegWit DESERIALIZER = lib_tx.DeserializerSegWit
PEERS = [ PEERS = [
@ -588,7 +588,7 @@ class BitcoinSegwitRegtest(BitcoinSegwitTestnet):
class BitcoinNolnet(BitcoinCash): class BitcoinNolnet(BitcoinCash):
'''Bitcoin Unlimited nolimit testnet.''' """Bitcoin Unlimited nolimit testnet."""
NET = "nolnet" NET = "nolnet"
GENESIS_HASH = ('0000000057e31bd2066c939a63b7b862' GENESIS_HASH = ('0000000057e31bd2066c939a63b7b862'
'3bd0f10d8c001304bdfc1a7902ae6d35') '3bd0f10d8c001304bdfc1a7902ae6d35')
@ -878,7 +878,7 @@ class Motion(Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return the hash.''' """Given a header return the hash."""
import x16r_hash import x16r_hash
return x16r_hash.getPoWHash(header) return x16r_hash.getPoWHash(header)
@ -912,7 +912,7 @@ class Dash(Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return the hash.''' """Given a header return the hash."""
import x11_hash import x11_hash
return x11_hash.getPoWHash(header) return x11_hash.getPoWHash(header)
@ -1014,7 +1014,7 @@ class FairCoin(Coin):
@classmethod @classmethod
def block(cls, raw_block, height): def block(cls, raw_block, height):
'''Return a Block namedtuple given a raw block and its height.''' """Return a Block namedtuple given a raw block and its height."""
if height > 0: if height > 0:
return super().block(raw_block, height) return super().block(raw_block, height)
else: else:
@ -1465,7 +1465,7 @@ class Bitzeny(Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return the hash.''' """Given a header return the hash."""
import zny_yescrypt import zny_yescrypt
return zny_yescrypt.getPoWHash(header) return zny_yescrypt.getPoWHash(header)
@ -1513,7 +1513,7 @@ class Denarius(Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return the hash.''' """Given a header return the hash."""
import tribus_hash import tribus_hash
return tribus_hash.getPoWHash(header) return tribus_hash.getPoWHash(header)
@ -1552,11 +1552,11 @@ class Sibcoin(Dash):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
''' """
Given a header return the hash for sibcoin. Given a header return the hash for sibcoin.
Requires the `x11_gost_hash` module. Requires the `x11_gost_hash` module.
Source code: https://github.com/ivansib/x11_gost_hash Source code: https://github.com/ivansib/x11_gost_hash
''' """
import x11_gost_hash import x11_gost_hash
return x11_gost_hash.getPoWHash(header) return x11_gost_hash.getPoWHash(header)
@ -1724,7 +1724,7 @@ class BitcoinAtom(Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return hash''' """Given a header return hash"""
header_to_be_hashed = header[:cls.BASIC_HEADER_SIZE] header_to_be_hashed = header[:cls.BASIC_HEADER_SIZE]
# New block header format has some extra flags in the end # New block header format has some extra flags in the end
if len(header) == cls.HEADER_SIZE_POST_FORK: if len(header) == cls.HEADER_SIZE_POST_FORK:
@ -1737,7 +1737,7 @@ class BitcoinAtom(Coin):
@classmethod @classmethod
def block_header(cls, block, height): def block_header(cls, block, height):
'''Return the block header bytes''' """Return the block header bytes"""
deserializer = cls.DESERIALIZER(block) deserializer = cls.DESERIALIZER(block)
return deserializer.read_header(height, cls.BASIC_HEADER_SIZE) return deserializer.read_header(height, cls.BASIC_HEADER_SIZE)
@ -1777,12 +1777,12 @@ class Decred(Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return the hash.''' """Given a header return the hash."""
return cls.HEADER_HASH(header) return cls.HEADER_HASH(header)
@classmethod @classmethod
def block(cls, raw_block, height): def block(cls, raw_block, height):
'''Return a Block namedtuple given a raw block and its height.''' """Return a Block namedtuple given a raw block and its height."""
if height > 0: if height > 0:
return super().block(raw_block, height) return super().block(raw_block, height)
else: else:
@ -1837,11 +1837,11 @@ class Axe(Dash):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
''' """
Given a header return the hash for AXE. Given a header return the hash for AXE.
Requires the `axe_hash` module. Requires the `axe_hash` module.
Source code: https://github.com/AXErunners/axe_hash Source code: https://github.com/AXErunners/axe_hash
''' """
import x11_hash import x11_hash
return x11_hash.getPoWHash(header) return x11_hash.getPoWHash(header)
@ -1867,11 +1867,11 @@ class Xuez(Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
''' """
Given a header return the hash for Xuez. Given a header return the hash for Xuez.
Requires the `xevan_hash` module. Requires the `xevan_hash` module.
Source code: https://github.com/xuez/xuez Source code: https://github.com/xuez/xuez
''' """
version, = util.unpack_le_uint32_from(header) version, = util.unpack_le_uint32_from(header)
import xevan_hash import xevan_hash
@ -1915,7 +1915,7 @@ class Pac(Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return the hash.''' """Given a header return the hash."""
import x11_hash import x11_hash
return x11_hash.getPoWHash(header) return x11_hash.getPoWHash(header)
@ -1960,7 +1960,7 @@ class Polis(Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return the hash.''' """Given a header return the hash."""
import x11_hash import x11_hash
return x11_hash.getPoWHash(header) return x11_hash.getPoWHash(header)
@ -1989,7 +1989,7 @@ class ColossusXT(Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return the hash.''' """Given a header return the hash."""
import quark_hash import quark_hash
return quark_hash.getPoWHash(header) return quark_hash.getPoWHash(header)
@ -2018,7 +2018,7 @@ class GoByte(Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return the hash.''' """Given a header return the hash."""
import neoscrypt import neoscrypt
return neoscrypt.getPoWHash(header) return neoscrypt.getPoWHash(header)
@ -2047,7 +2047,7 @@ class Monoeci(Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return the hash.''' """Given a header return the hash."""
import x11_hash import x11_hash
return x11_hash.getPoWHash(header) return x11_hash.getPoWHash(header)
@ -2082,7 +2082,7 @@ class Minexcoin(EquihashMixin, Coin):
@classmethod @classmethod
def block_header(cls, block, height): def block_header(cls, block, height):
'''Return the block header bytes''' """Return the block header bytes"""
deserializer = cls.DESERIALIZER(block) deserializer = cls.DESERIALIZER(block)
return deserializer.read_header(height, cls.HEADER_SIZE_NO_SOLUTION) return deserializer.read_header(height, cls.HEADER_SIZE_NO_SOLUTION)
@ -2116,7 +2116,7 @@ class Groestlcoin(Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return the hash.''' """Given a header return the hash."""
return cls.grshash(header) return cls.grshash(header)
ENCODE_CHECK = partial(Base58.encode_check, hash_fn=grshash) ENCODE_CHECK = partial(Base58.encode_check, hash_fn=grshash)
@ -2224,7 +2224,7 @@ class Bitg(Coin):
@classmethod @classmethod
def header_hash(cls, header): def header_hash(cls, header):
'''Given a header return the hash.''' """Given a header return the hash."""
import quark_hash import quark_hash
return quark_hash.getPoWHash(header) return quark_hash.getPoWHash(header)


@ -5,8 +5,8 @@
# See the file "LICENCE" for information about the copyright # See the file "LICENCE" for information about the copyright
# and warranty status of this software. # and warranty status of this software.
'''Class for handling asynchronous connections to a blockchain """Class for handling asynchronous connections to a blockchain
daemon.''' daemon."""
import asyncio import asyncio
import itertools import itertools
@ -26,19 +26,19 @@ from torba.rpc import JSONRPC
class DaemonError(Exception): class DaemonError(Exception):
'''Raised when the daemon returns an error in its results.''' """Raised when the daemon returns an error in its results."""
class WarmingUpError(Exception): class WarmingUpError(Exception):
'''Internal - when the daemon is warming up.''' """Internal - when the daemon is warming up."""
class WorkQueueFullError(Exception): class WorkQueueFullError(Exception):
'''Internal - when the daemon's work queue is full.''' """Internal - when the daemon's work queue is full."""
class Daemon: class Daemon:
'''Handles connections to a daemon at the given URL.''' """Handles connections to a daemon at the given URL."""
WARMING_UP = -28 WARMING_UP = -28
id_counter = itertools.count() id_counter = itertools.count()
@ -57,7 +57,7 @@ class Daemon:
self.available_rpcs = {} self.available_rpcs = {}
def set_url(self, url): def set_url(self, url):
'''Set the URLs to the given list, and switch to the first one.''' """Set the URLs to the given list, and switch to the first one."""
urls = url.split(',') urls = url.split(',')
urls = [self.coin.sanitize_url(url) for url in urls] urls = [self.coin.sanitize_url(url) for url in urls]
for n, url in enumerate(urls): for n, url in enumerate(urls):
@ -68,19 +68,19 @@ class Daemon:
self.urls = urls self.urls = urls
def current_url(self): def current_url(self):
'''Returns the current daemon URL.''' """Returns the current daemon URL."""
return self.urls[self.url_index] return self.urls[self.url_index]
def logged_url(self, url=None): def logged_url(self, url=None):
'''The host and port part, for logging.''' """The host and port part, for logging."""
url = url or self.current_url() url = url or self.current_url()
return url[url.rindex('@') + 1:] return url[url.rindex('@') + 1:]
def failover(self): def failover(self):
'''Call to fail-over to the next daemon URL. """Call to fail-over to the next daemon URL.
Returns False if there is only one, otherwise True. Returns False if there is only one, otherwise True.
''' """
if len(self.urls) > 1: if len(self.urls) > 1:
self.url_index = (self.url_index + 1) % len(self.urls) self.url_index = (self.url_index + 1) % len(self.urls)
self.logger.info(f'failing over to {self.logged_url()}') self.logger.info(f'failing over to {self.logged_url()}')
@ -88,7 +88,7 @@ class Daemon:
return False return False
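The failover logic reduces to a round-robin index over the configured URLs, reporting whether there was anywhere else to go. A minimal standalone sketch of that behaviour:

class UrlRotor:
    def __init__(self, urls):
        self.urls = list(urls)
        self.index = 0

    def failover(self):
        # Rotate to the next URL; False means there is only one URL
        if len(self.urls) > 1:
            self.index = (self.index + 1) % len(self.urls)
            return True
        return False

    def current(self):
        return self.urls[self.index]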
def client_session(self): def client_session(self):
'''An aiohttp client session.''' """An aiohttp client session."""
return aiohttp.ClientSession() return aiohttp.ClientSession()
async def _send_data(self, data): async def _send_data(self, data):
@ -107,11 +107,11 @@ class Daemon:
raise DaemonError(text) raise DaemonError(text)
async def _send(self, payload, processor): async def _send(self, payload, processor):
'''Send a payload to be converted to JSON. """Send a payload to be converted to JSON.
Handles temporary connection issues. Daemon response errors Handles temporary connection issues. Daemon response errors
are raised through DaemonError. are raised through DaemonError.
''' """
def log_error(error): def log_error(error):
nonlocal last_error_log, retry nonlocal last_error_log, retry
now = time.time() now = time.time()
@ -154,7 +154,7 @@ class Daemon:
retry = max(min(self.max_retry, retry * 2), self.init_retry) retry = max(min(self.max_retry, retry * 2), self.init_retry)
async def _send_single(self, method, params=None): async def _send_single(self, method, params=None):
'''Send a single request to the daemon.''' """Send a single request to the daemon."""
def processor(result): def processor(result):
err = result['error'] err = result['error']
if not err: if not err:
@ -169,11 +169,11 @@ class Daemon:
return await self._send(payload, processor) return await self._send(payload, processor)
async def _send_vector(self, method, params_iterable, replace_errs=False): async def _send_vector(self, method, params_iterable, replace_errs=False):
'''Send several requests of the same method. """Send several requests of the same method.
The result will be an array of the same length as params_iterable. The result will be an array of the same length as params_iterable.
If replace_errs is true, any item with an error is returned as None, If replace_errs is true, any item with an error is returned as None,
otherwise an exception is raised.''' otherwise an exception is raised."""
def processor(result): def processor(result):
errs = [item['error'] for item in result if item['error']] errs = [item['error'] for item in result if item['error']]
if any(err.get('code') == self.WARMING_UP for err in errs): if any(err.get('code') == self.WARMING_UP for err in errs):
@ -189,10 +189,10 @@ class Daemon:
return [] return []
async def _is_rpc_available(self, method): async def _is_rpc_available(self, method):
'''Return whether given RPC method is available in the daemon. """Return whether given RPC method is available in the daemon.
Results are cached and the daemon will generally not be queried with Results are cached and the daemon will generally not be queried with
the same method more than once.''' the same method more than once."""
available = self.available_rpcs.get(method) available = self.available_rpcs.get(method)
if available is None: if available is None:
available = True available = True
@ -206,30 +206,30 @@ class Daemon:
return available return available
async def block_hex_hashes(self, first, count): async def block_hex_hashes(self, first, count):
'''Return the hex hashes of count blocks starting at height first.''' """Return the hex hashes of count blocks starting at height first."""
params_iterable = ((h, ) for h in range(first, first + count)) params_iterable = ((h, ) for h in range(first, first + count))
return await self._send_vector('getblockhash', params_iterable) return await self._send_vector('getblockhash', params_iterable)
async def deserialised_block(self, hex_hash): async def deserialised_block(self, hex_hash):
'''Return the deserialised block with the given hex hash.''' """Return the deserialised block with the given hex hash."""
return await self._send_single('getblock', (hex_hash, True)) return await self._send_single('getblock', (hex_hash, True))
async def raw_blocks(self, hex_hashes): async def raw_blocks(self, hex_hashes):
'''Return the raw binary blocks with the given hex hashes.''' """Return the raw binary blocks with the given hex hashes."""
params_iterable = ((h, False) for h in hex_hashes) params_iterable = ((h, False) for h in hex_hashes)
blocks = await self._send_vector('getblock', params_iterable) blocks = await self._send_vector('getblock', params_iterable)
# Convert hex string to bytes # Convert hex string to bytes
return [hex_to_bytes(block) for block in blocks] return [hex_to_bytes(block) for block in blocks]
async def mempool_hashes(self): async def mempool_hashes(self):
'''Update our record of the daemon's mempool hashes.''' """Update our record of the daemon's mempool hashes."""
return await self._send_single('getrawmempool') return await self._send_single('getrawmempool')
async def estimatefee(self, block_count): async def estimatefee(self, block_count):
'''Return the fee estimate for the block count. Units are whole """Return the fee estimate for the block count. Units are whole
currency units per KB, e.g. 0.00000995, or -1 if no estimate currency units per KB, e.g. 0.00000995, or -1 if no estimate
is available. is available.
''' """
args = (block_count, ) args = (block_count, )
if await self._is_rpc_available('estimatesmartfee'): if await self._is_rpc_available('estimatesmartfee'):
estimate = await self._send_single('estimatesmartfee', args) estimate = await self._send_single('estimatesmartfee', args)
@ -237,25 +237,25 @@ class Daemon:
return await self._send_single('estimatefee', args) return await self._send_single('estimatefee', args)
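The pattern here is feature detection with a cached result: try the modern estimatesmartfee RPC once, and fall back to the legacy estimatefee otherwise. A hedged sketch against the Daemon methods named in the diff; the 'feerate' key is how bitcoind's estimatesmartfee reports its result:

async def estimate_fee(daemon, block_count):
    args = (block_count,)
    if await daemon._is_rpc_available('estimatesmartfee'):
        estimate = await daemon._send_single('estimatesmartfee', args)
        return estimate.get('feerate', -1)         # -1 when no estimate is available
    return await daemon._send_single('estimatefee', args)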
async def getnetworkinfo(self): async def getnetworkinfo(self):
'''Return the result of the 'getnetworkinfo' RPC call.''' """Return the result of the 'getnetworkinfo' RPC call."""
return await self._send_single('getnetworkinfo') return await self._send_single('getnetworkinfo')
async def relayfee(self): async def relayfee(self):
'''The minimum fee a low-priority tx must pay in order to be accepted """The minimum fee a low-priority tx must pay in order to be accepted
to the daemon's memory pool.''' to the daemon's memory pool."""
network_info = await self.getnetworkinfo() network_info = await self.getnetworkinfo()
return network_info['relayfee'] return network_info['relayfee']
async def getrawtransaction(self, hex_hash, verbose=False): async def getrawtransaction(self, hex_hash, verbose=False):
'''Return the serialized raw transaction with the given hash.''' """Return the serialized raw transaction with the given hash."""
# Cast to int because some coin daemons are old and require it # Cast to int because some coin daemons are old and require it
return await self._send_single('getrawtransaction', return await self._send_single('getrawtransaction',
(hex_hash, int(verbose))) (hex_hash, int(verbose)))
async def getrawtransactions(self, hex_hashes, replace_errs=True): async def getrawtransactions(self, hex_hashes, replace_errs=True):
'''Return the serialized raw transactions with the given hashes. """Return the serialized raw transactions with the given hashes.
Replaces errors with None by default.''' Replaces errors with None by default."""
params_iterable = ((hex_hash, 0) for hex_hash in hex_hashes) params_iterable = ((hex_hash, 0) for hex_hash in hex_hashes)
txs = await self._send_vector('getrawtransaction', params_iterable, txs = await self._send_vector('getrawtransaction', params_iterable,
replace_errs=replace_errs) replace_errs=replace_errs)
@ -263,57 +263,57 @@ class Daemon:
return [hex_to_bytes(tx) if tx else None for tx in txs] return [hex_to_bytes(tx) if tx else None for tx in txs]
async def broadcast_transaction(self, raw_tx): async def broadcast_transaction(self, raw_tx):
'''Broadcast a transaction to the network.''' """Broadcast a transaction to the network."""
return await self._send_single('sendrawtransaction', (raw_tx, )) return await self._send_single('sendrawtransaction', (raw_tx, ))
async def height(self): async def height(self):
'''Query the daemon for its current height.''' """Query the daemon for its current height."""
self._height = await self._send_single('getblockcount') self._height = await self._send_single('getblockcount')
return self._height return self._height
def cached_height(self): def cached_height(self):
'''Return the cached daemon height. """Return the cached daemon height.
If the daemon has not been queried yet this returns None.''' If the daemon has not been queried yet this returns None."""
return self._height return self._height
class DashDaemon(Daemon): class DashDaemon(Daemon):
async def masternode_broadcast(self, params): async def masternode_broadcast(self, params):
'''Broadcast a masternode announcement to the network.''' """Broadcast a masternode announcement to the network."""
return await self._send_single('masternodebroadcast', params) return await self._send_single('masternodebroadcast', params)
async def masternode_list(self, params): async def masternode_list(self, params):
'''Return the masternode status.''' """Return the masternode status."""
return await self._send_single('masternodelist', params) return await self._send_single('masternodelist', params)
class FakeEstimateFeeDaemon(Daemon): class FakeEstimateFeeDaemon(Daemon):
'''Daemon that simulates estimatefee and relayfee RPC calls. A coin that """Daemon that simulates estimatefee and relayfee RPC calls. A coin that
wants to use this daemon must define ESTIMATE_FEE and RELAY_FEE.''' wants to use this daemon must define ESTIMATE_FEE and RELAY_FEE."""
async def estimatefee(self, block_count): async def estimatefee(self, block_count):
'''Return the fee estimate for the given parameters.''' """Return the fee estimate for the given parameters."""
return self.coin.ESTIMATE_FEE return self.coin.ESTIMATE_FEE
async def relayfee(self): async def relayfee(self):
'''The minimum fee a low-priority tx must pay in order to be accepted """The minimum fee a low-priority tx must pay in order to be accepted
to the daemon's memory pool.''' to the daemon's memory pool."""
return self.coin.RELAY_FEE return self.coin.RELAY_FEE
class LegacyRPCDaemon(Daemon): class LegacyRPCDaemon(Daemon):
'''Handles connections to a daemon at the given URL. """Handles connections to a daemon at the given URL.
This class is useful for daemons that don't have the new 'getblock' This class is useful for daemons that don't have the new 'getblock'
RPC call that returns the block in hex; the workaround is to manually RPC call that returns the block in hex; the workaround is to manually
recreate the block bytes. The recreated block bytes may not be exactly recreate the block bytes. The recreated block bytes may not be exactly
as in the underlying blockchain, but they are good enough for our indexing as in the underlying blockchain, but they are good enough for our indexing
purposes.''' purposes."""
async def raw_blocks(self, hex_hashes): async def raw_blocks(self, hex_hashes):
'''Return the raw binary blocks with the given hex hashes.''' """Return the raw binary blocks with the given hex hashes."""
params_iterable = ((h, ) for h in hex_hashes) params_iterable = ((h, ) for h in hex_hashes)
block_info = await self._send_vector('getblock', params_iterable) block_info = await self._send_vector('getblock', params_iterable)
@ -339,7 +339,7 @@ class LegacyRPCDaemon(Daemon):
]) ])
async def make_raw_block(self, b): async def make_raw_block(self, b):
'''Construct a raw block''' """Construct a raw block"""
header = await self.make_raw_header(b) header = await self.make_raw_header(b)
@ -365,7 +365,7 @@ class LegacyRPCDaemon(Daemon):
class DecredDaemon(Daemon): class DecredDaemon(Daemon):
async def raw_blocks(self, hex_hashes): async def raw_blocks(self, hex_hashes):
'''Return the raw binary blocks with the given hex hashes.''' """Return the raw binary blocks with the given hex hashes."""
params_iterable = ((h, False) for h in hex_hashes) params_iterable = ((h, False) for h in hex_hashes)
blocks = await self._send_vector('getblock', params_iterable) blocks = await self._send_vector('getblock', params_iterable)
@ -448,12 +448,12 @@ class DecredDaemon(Daemon):
class PreLegacyRPCDaemon(LegacyRPCDaemon): class PreLegacyRPCDaemon(LegacyRPCDaemon):
'''Handles connections to a daemon at the given URL. """Handles connections to a daemon at the given URL.
This class is useful for daemons that don't have the new 'getblock' This class is useful for daemons that don't have the new 'getblock'
RPC call that returns the block in hex, and need the False parameter RPC call that returns the block in hex, and need the False parameter
for getblock.''' for getblock."""
async def deserialised_block(self, hex_hash): async def deserialised_block(self, hex_hash):
'''Return the deserialised block with the given hex hash.''' """Return the deserialised block with the given hex hash."""
return await self._send_single('getblock', (hex_hash, False)) return await self._send_single('getblock', (hex_hash, False))


@ -6,7 +6,7 @@
# See the file "LICENCE" for information about the copyright # See the file "LICENCE" for information about the copyright
# and warranty status of this software. # and warranty status of this software.
'''Interface to the blockchain database.''' """Interface to the blockchain database."""
import asyncio import asyncio
@ -47,16 +47,16 @@ class FlushData:
class DB: class DB:
'''Simple wrapper of the backend database for querying. """Simple wrapper of the backend database for querying.
Performs no DB update, though the DB will be cleaned on opening if Performs no DB update, though the DB will be cleaned on opening if
it was shutdown uncleanly. it was shutdown uncleanly.
''' """
DB_VERSIONS = [6] DB_VERSIONS = [6]
class DBError(Exception): class DBError(Exception):
'''Raised on general DB errors generally indicating corruption.''' """Raised on general DB errors generally indicating corruption."""
def __init__(self, env): def __init__(self, env):
self.logger = util.class_logger(__name__, self.__class__.__name__) self.logger = util.class_logger(__name__, self.__class__.__name__)
@ -142,18 +142,18 @@ class DB:
await self._open_dbs(True, True) await self._open_dbs(True, True)
async def open_for_sync(self): async def open_for_sync(self):
'''Open the databases to sync to the daemon. """Open the databases to sync to the daemon.
When syncing we want to reserve a lot of open files for the When syncing we want to reserve a lot of open files for the
synchronization. When serving clients we want the open files for synchronization. When serving clients we want the open files for
serving network connections. serving network connections.
''' """
await self._open_dbs(True, False) await self._open_dbs(True, False)
async def open_for_serving(self): async def open_for_serving(self):
'''Open the databases for serving. If they are already open they are """Open the databases for serving. If they are already open they are
closed first. closed first.
''' """
if self.utxo_db: if self.utxo_db:
self.logger.info('closing DBs to re-open for serving') self.logger.info('closing DBs to re-open for serving')
self.utxo_db.close() self.utxo_db.close()
@ -176,7 +176,7 @@ class DB:
# Flushing # Flushing
def assert_flushed(self, flush_data): def assert_flushed(self, flush_data):
'''Asserts state is fully flushed.''' """Asserts state is fully flushed."""
assert flush_data.tx_count == self.fs_tx_count == self.db_tx_count assert flush_data.tx_count == self.fs_tx_count == self.db_tx_count
assert flush_data.height == self.fs_height == self.db_height assert flush_data.height == self.fs_height == self.db_height
assert flush_data.tip == self.db_tip assert flush_data.tip == self.db_tip
@ -188,8 +188,8 @@ class DB:
self.history.assert_flushed() self.history.assert_flushed()
def flush_dbs(self, flush_data, flush_utxos, estimate_txs_remaining): def flush_dbs(self, flush_data, flush_utxos, estimate_txs_remaining):
'''Flush out cached state. History is always flushed; UTXOs are """Flush out cached state. History is always flushed; UTXOs are
flushed if flush_utxos.''' flushed if flush_utxos."""
if flush_data.height == self.db_height: if flush_data.height == self.db_height:
self.assert_flushed(flush_data) self.assert_flushed(flush_data)
return return
@ -231,12 +231,12 @@ class DB:
f'ETA: {formatted_time(eta)}') f'ETA: {formatted_time(eta)}')
def flush_fs(self, flush_data): def flush_fs(self, flush_data):
'''Write headers, tx counts and block tx hashes to the filesystem. """Write headers, tx counts and block tx hashes to the filesystem.
The first height to write is self.fs_height + 1. The FS The first height to write is self.fs_height + 1. The FS
metadata is all append-only, so in a crash we just pick up metadata is all append-only, so in a crash we just pick up
again from the height stored in the DB. again from the height stored in the DB.
''' """
prior_tx_count = (self.tx_counts[self.fs_height] prior_tx_count = (self.tx_counts[self.fs_height]
if self.fs_height >= 0 else 0) if self.fs_height >= 0 else 0)
assert len(flush_data.block_tx_hashes) == len(flush_data.headers) assert len(flush_data.block_tx_hashes) == len(flush_data.headers)
@ -274,7 +274,7 @@ class DB:
self.history.flush() self.history.flush()
def flush_utxo_db(self, batch, flush_data): def flush_utxo_db(self, batch, flush_data):
'''Flush the cached DB writes and UTXO set to the batch.''' """Flush the cached DB writes and UTXO set to the batch."""
# Care is needed because the writes generated by flushing the # Care is needed because the writes generated by flushing the
# UTXO state may have keys in common with our write cache or # UTXO state may have keys in common with our write cache or
# may be in the DB already. # may be in the DB already.
@ -317,7 +317,7 @@ class DB:
self.db_tip = flush_data.tip self.db_tip = flush_data.tip
def flush_state(self, batch): def flush_state(self, batch):
'''Flush chain state to the batch.''' """Flush chain state to the batch."""
now = time.time() now = time.time()
self.wall_time += now - self.last_flush self.wall_time += now - self.last_flush
self.last_flush = now self.last_flush = now
@ -325,7 +325,7 @@ class DB:
self.write_utxo_state(batch) self.write_utxo_state(batch)
def flush_backup(self, flush_data, touched): def flush_backup(self, flush_data, touched):
'''Like flush_dbs() but when backing up. All UTXOs are flushed.''' """Like flush_dbs() but when backing up. All UTXOs are flushed."""
assert not flush_data.headers assert not flush_data.headers
assert not flush_data.block_tx_hashes assert not flush_data.block_tx_hashes
assert flush_data.height < self.db_height assert flush_data.height < self.db_height
@ -369,28 +369,28 @@ class DB:
- self.dynamic_header_offset(height) - self.dynamic_header_offset(height)
def backup_fs(self, height, tx_count): def backup_fs(self, height, tx_count):
'''Back up during a reorg. This just updates our pointers.''' """Back up during a reorg. This just updates our pointers."""
self.fs_height = height self.fs_height = height
self.fs_tx_count = tx_count self.fs_tx_count = tx_count
# Truncate header_mc: header count is 1 more than the height. # Truncate header_mc: header count is 1 more than the height.
self.header_mc.truncate(height + 1) self.header_mc.truncate(height + 1)
async def raw_header(self, height): async def raw_header(self, height):
'''Return the binary header at the given height.''' """Return the binary header at the given height."""
header, n = await self.read_headers(height, 1) header, n = await self.read_headers(height, 1)
if n != 1: if n != 1:
raise IndexError(f'height {height:,d} out of range') raise IndexError(f'height {height:,d} out of range')
return header return header
async def read_headers(self, start_height, count): async def read_headers(self, start_height, count):
'''Requires start_height >= 0, count >= 0. Reads as many headers as """Requires start_height >= 0, count >= 0. Reads as many headers as
are available starting at start_height up to count. This are available starting at start_height up to count. This
would be zero if start_height is beyond self.db_height, for would be zero if start_height is beyond self.db_height, for
example. example.
Returns a (binary, n) pair where binary is the concatenated Returns a (binary, n) pair where binary is the concatenated
binary headers, and n is the count of headers returned. binary headers, and n is the count of headers returned.
''' """
if start_height < 0 or count < 0: if start_height < 0 or count < 0:
raise self.DBError(f'{count:,d} headers starting at ' raise self.DBError(f'{count:,d} headers starting at '
f'{start_height:,d} not on disk') f'{start_height:,d} not on disk')
@ -407,9 +407,9 @@ class DB:
return await asyncio.get_event_loop().run_in_executor(None, read_headers) return await asyncio.get_event_loop().run_in_executor(None, read_headers)
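read_headers offloads the blocking disk read to the default thread pool so the event loop stays responsive. The same pattern in isolation, with a hypothetical blocking file read:

import asyncio

async def read_file_slice(path, offset, size):
    def blocking_read():
        with open(path, 'rb') as f:
            f.seek(offset)
            return f.read(size)
    # None selects the loop's default ThreadPoolExecutor
    return await asyncio.get_event_loop().run_in_executor(None, blocking_read)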
def fs_tx_hash(self, tx_num): def fs_tx_hash(self, tx_num):
'''Return a pair (tx_hash, tx_height) for the given tx number. """Return a pair (tx_hash, tx_height) for the given tx number.
If the tx_height is not on disk, returns (None, tx_height).''' If the tx_height is not on disk, returns (None, tx_height)."""
tx_height = bisect_right(self.tx_counts, tx_num) tx_height = bisect_right(self.tx_counts, tx_num)
if tx_height > self.db_height: if tx_height > self.db_height:
tx_hash = None tx_hash = None
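tx_counts is a cumulative array: entry h is the total number of transactions up to and including height h, so bisect_right maps a global tx number to the height of the block containing it. A toy check:

from bisect import bisect_right

tx_counts = [3, 7, 12]                    # blocks 0..2 hold txs 0-2, 3-6, 7-11
assert bisect_right(tx_counts, 0) == 0    # tx 0 lives in block 0
assert bisect_right(tx_counts, 3) == 1    # tx 3 is the first tx of block 1
assert bisect_right(tx_counts, 11) == 2   # tx 11 is the last tx of block 2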
@ -432,12 +432,12 @@ class DB:
return [self.coin.header_hash(header) for header in headers] return [self.coin.header_hash(header) for header in headers]
async def limited_history(self, hashX, *, limit=1000): async def limited_history(self, hashX, *, limit=1000):
'''Return an unpruned, sorted list of (tx_hash, height) tuples of """Return an unpruned, sorted list of (tx_hash, height) tuples of
confirmed transactions that touched the address, earliest in confirmed transactions that touched the address, earliest in
the blockchain first. Includes both spending and receiving the blockchain first. Includes both spending and receiving
transactions. By default returns at most 1000 entries. Set transactions. By default returns at most 1000 entries. Set
limit to None to get them all. limit to None to get them all.
''' """
def read_history(): def read_history():
tx_nums = list(self.history.get_txnums(hashX, limit)) tx_nums = list(self.history.get_txnums(hashX, limit))
fs_tx_hash = self.fs_tx_hash fs_tx_hash = self.fs_tx_hash
@ -454,19 +454,19 @@ class DB:
# -- Undo information # -- Undo information
def min_undo_height(self, max_height): def min_undo_height(self, max_height):
'''Returns a height from which we should store undo info.''' """Returns a height from which we should store undo info."""
return max_height - self.env.reorg_limit + 1 return max_height - self.env.reorg_limit + 1
def undo_key(self, height): def undo_key(self, height):
'''DB key for undo information at the given height.''' """DB key for undo information at the given height."""
return b'U' + pack('>I', height) return b'U' + pack('>I', height)
def read_undo_info(self, height): def read_undo_info(self, height):
'''Read undo information from a file for the current height.''' """Read undo information from a file for the current height."""
return self.utxo_db.get(self.undo_key(height)) return self.utxo_db.get(self.undo_key(height))
def flush_undo_infos(self, batch_put, undo_infos): def flush_undo_infos(self, batch_put, undo_infos):
'''undo_infos is a list of (undo_info, height) pairs.''' """undo_infos is a list of (undo_info, height) pairs."""
for undo_info, height in undo_infos: for undo_info, height in undo_infos:
batch_put(self.undo_key(height), b''.join(undo_info)) batch_put(self.undo_key(height), b''.join(undo_info))
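Undo keys pack the height big-endian after a b'U' prefix, so lexicographic key order in the DB matches numeric height order, which is what lets clear_excess_undo_info walk them with a simple prefix scan. A round-trip sketch:

from struct import pack, unpack

def undo_key(height):
    return b'U' + pack('>I', height)       # big-endian keeps keys height-ordered

key = undo_key(500_000)
assert key[:1] == b'U'
height, = unpack('>I', key[1:])
assert height == 500_000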
@ -477,13 +477,13 @@ class DB:
return f'{self.raw_block_prefix()}{height:d}' return f'{self.raw_block_prefix()}{height:d}'
def read_raw_block(self, height): def read_raw_block(self, height):
'''Returns a raw block read from disk. Raises FileNotFoundError """Returns a raw block read from disk. Raises FileNotFoundError
if the block isn't on-disk.''' if the block isn't on-disk."""
with util.open_file(self.raw_block_path(height)) as f: with util.open_file(self.raw_block_path(height)) as f:
return f.read(-1) return f.read(-1)
def write_raw_block(self, block, height): def write_raw_block(self, block, height):
'''Write a raw block to disk.''' """Write a raw block to disk."""
with util.open_truncate(self.raw_block_path(height)) as f: with util.open_truncate(self.raw_block_path(height)) as f:
f.write(block) f.write(block)
# Delete old blocks to prevent them accumulating # Delete old blocks to prevent them accumulating
@ -494,7 +494,7 @@ class DB:
pass pass
def clear_excess_undo_info(self): def clear_excess_undo_info(self):
'''Clear excess undo info. Only most recent N are kept.''' """Clear excess undo info. Only most recent N are kept."""
prefix = b'U' prefix = b'U'
min_height = self.min_undo_height(self.db_height) min_height = self.min_undo_height(self.db_height)
keys = [] keys = []
@ -578,7 +578,7 @@ class DB:
.format(util.formatted_time(self.wall_time))) .format(util.formatted_time(self.wall_time)))
def write_utxo_state(self, batch): def write_utxo_state(self, batch):
'''Write (UTXO) state to the batch.''' """Write (UTXO) state to the batch."""
state = { state = {
'genesis': self.coin.GENESIS_HASH, 'genesis': self.coin.GENESIS_HASH,
'height': self.db_height, 'height': self.db_height,
@ -597,7 +597,7 @@ class DB:
self.write_utxo_state(batch) self.write_utxo_state(batch)
async def all_utxos(self, hashX): async def all_utxos(self, hashX):
'''Return all UTXOs for an address sorted in no particular order.''' """Return all UTXOs for an address sorted in no particular order."""
def read_utxos(): def read_utxos():
utxos = [] utxos = []
utxos_append = utxos.append utxos_append = utxos.append
@ -621,15 +621,15 @@ class DB:
await sleep(0.25) await sleep(0.25)
async def lookup_utxos(self, prevouts): async def lookup_utxos(self, prevouts):
'''For each prevout, look it up in the DB and return a (hashX, """For each prevout, look it up in the DB and return a (hashX,
value) pair or None if not found. value) pair or None if not found.
Used by the mempool code. Used by the mempool code.
''' """
def lookup_hashXs(): def lookup_hashXs():
'''Return (hashX, suffix) pairs, or None if not found, """Return (hashX, suffix) pairs, or None if not found,
for each prevout. for each prevout.
''' """
def lookup_hashX(tx_hash, tx_idx): def lookup_hashX(tx_hash, tx_idx):
idx_packed = pack('<H', tx_idx) idx_packed = pack('<H', tx_idx)


@ -5,10 +5,10 @@
# See the file "LICENCE" for information about the copyright # See the file "LICENCE" for information about the copyright
# and warranty status of this software. # and warranty status of this software.
'''An enum-like type with reverse lookup. """An enum-like type with reverse lookup.
Source: Python Cookbook, http://code.activestate.com/recipes/67107/ Source: Python Cookbook, http://code.activestate.com/recipes/67107/
''' """
class EnumError(Exception): class EnumError(Exception):


@ -139,13 +139,13 @@ class Env:
raise self.Error('unknown event loop policy "{}"'.format(policy)) raise self.Error('unknown event loop policy "{}"'.format(policy))
def cs_host(self, *, for_rpc): def cs_host(self, *, for_rpc):
'''Returns the 'host' argument to pass to asyncio's create_server """Returns the 'host' argument to pass to asyncio's create_server
call. The result can be a single host name string, a list of call. The result can be a single host name string, a list of
host name strings, or an empty string to bind to all interfaces. host name strings, or an empty string to bind to all interfaces.
If for_rpc is True, the host to use for the RPC server is returned. If for_rpc is True, the host to use for the RPC server is returned.
Otherwise the host to use for SSL/TCP servers is returned. Otherwise the host to use for SSL/TCP servers is returned.
''' """
host = self.rpc_host if for_rpc else self.host host = self.rpc_host if for_rpc else self.host
result = [part.strip() for part in host.split(',')] result = [part.strip() for part in host.split(',')]
if len(result) == 1: if len(result) == 1:
@ -161,9 +161,9 @@ class Env:
return result return result
def sane_max_sessions(self): def sane_max_sessions(self):
'''Return the maximum number of sessions to permit. Normally this """Return the maximum number of sessions to permit. Normally this
is MAX_SESSIONS. However, to prevent open file exhaustion, adjust is MAX_SESSIONS. However, to prevent open file exhaustion, adjust
downwards if running with a small open file rlimit.''' downwards if running with a small open file rlimit."""
env_value = self.integer('MAX_SESSIONS', 1000) env_value = self.integer('MAX_SESSIONS', 1000)
nofile_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[0] nofile_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
# We give the DB 250 files; allow ElectrumX 100 for itself # We give the DB 250 files; allow ElectrumX 100 for itself
@ -209,8 +209,8 @@ class Env:
.format(host)) .format(host))
def port(port_kind): def port(port_kind):
'''Returns the clearnet identity port, if any and not zero, """Returns the clearnet identity port, if any and not zero,
otherwise the listening port.''' otherwise the listening port."""
result = 0 result = 0
if clearnet: if clearnet:
result = getattr(clearnet, port_kind) result = getattr(clearnet, port_kind)


@ -23,7 +23,7 @@
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''Cryptographic hash functions and related classes.''' """Cryptographic hash functions and related classes."""
import hashlib import hashlib
@ -39,53 +39,53 @@ HASHX_LEN = 11
def sha256(x): def sha256(x):
'''Simple wrapper of hashlib sha256.''' """Simple wrapper of hashlib sha256."""
return _sha256(x).digest() return _sha256(x).digest()
def ripemd160(x): def ripemd160(x):
'''Simple wrapper of hashlib ripemd160.''' """Simple wrapper of hashlib ripemd160."""
h = _new_hash('ripemd160') h = _new_hash('ripemd160')
h.update(x) h.update(x)
return h.digest() return h.digest()
def double_sha256(x): def double_sha256(x):
'''SHA-256 of SHA-256, as used extensively in bitcoin.''' """SHA-256 of SHA-256, as used extensively in bitcoin."""
return sha256(sha256(x)) return sha256(sha256(x))
def hmac_sha512(key, msg): def hmac_sha512(key, msg):
'''Use SHA-512 to provide an HMAC.''' """Use SHA-512 to provide an HMAC."""
return _new_hmac(key, msg, _sha512).digest() return _new_hmac(key, msg, _sha512).digest()
def hash160(x): def hash160(x):
'''RIPEMD-160 of SHA-256. """RIPEMD-160 of SHA-256.
Used to make bitcoin addresses from pubkeys.''' Used to make bitcoin addresses from pubkeys."""
return ripemd160(sha256(x)) return ripemd160(sha256(x))
def hash_to_hex_str(x): def hash_to_hex_str(x):
'''Convert a big-endian binary hash to displayed hex string. """Convert a big-endian binary hash to displayed hex string.
Display form of a binary hash is reversed and converted to hex. Display form of a binary hash is reversed and converted to hex.
''' """
return bytes(reversed(x)).hex() return bytes(reversed(x)).hex()
def hex_str_to_hash(x): def hex_str_to_hash(x):
'''Convert a displayed hex string to a binary hash.''' """Convert a displayed hex string to a binary hash."""
return bytes(reversed(hex_to_bytes(x))) return bytes(reversed(hex_to_bytes(x)))
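A quick round-trip demo of the two helpers: bitcoin displays hashes byte-reversed relative to their binary form, and the pair of functions are exact inverses:

def hash_to_hex_str(x):
    return bytes(reversed(x)).hex()

def hex_str_to_hash(x):
    return bytes(reversed(bytes.fromhex(x)))

h = bytes.fromhex('00' * 28 + 'deadbeef')
assert hash_to_hex_str(h) == 'efbeadde' + '00' * 28
assert hex_str_to_hash(hash_to_hex_str(h)) == h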
class Base58Error(Exception): class Base58Error(Exception):
'''Exception used for Base58 errors.''' """Exception used for Base58 errors."""
class Base58: class Base58:
'''Class providing base 58 functionality.''' """Class providing base 58 functionality."""
chars = '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz' chars = '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz'
assert len(chars) == 58 assert len(chars) == 58
@ -143,8 +143,8 @@ class Base58:
@staticmethod @staticmethod
def decode_check(txt, *, hash_fn=double_sha256): def decode_check(txt, *, hash_fn=double_sha256):
'''Decodes a Base58Check-encoded string to a payload. The version """Decodes a Base58Check-encoded string to a payload. The version
prefixes it.''' prefixes it."""
be_bytes = Base58.decode(txt) be_bytes = Base58.decode(txt)
result, check = be_bytes[:-4], be_bytes[-4:] result, check = be_bytes[:-4], be_bytes[-4:]
if check != hash_fn(result)[:4]: if check != hash_fn(result)[:4]:
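For completeness, a minimal sketch of the decode side: convert from base 58, restore leading zero bytes from leading '1' characters, then split off the 4-byte checksum and verify it with double-SHA256 (hash_fn's default above):

import hashlib

B58_CHARS = '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz'

def base58_decode_check(txt):
    n = 0
    for c in txt:
        n = n * 58 + B58_CHARS.index(c)
    pad = len(txt) - len(txt.lstrip('1'))          # leading '1's are zero bytes
    raw = b'\x00' * pad + n.to_bytes((n.bit_length() + 7) // 8, 'big')
    if len(raw) < 4:
        raise ValueError('string too short')
    payload, check = raw[:-4], raw[-4:]
    digest = hashlib.sha256(hashlib.sha256(payload).digest()).digest()
    if check != digest[:4]:
        raise ValueError('bad Base58Check checksum')
    return payload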


@ -6,7 +6,7 @@
# See the file "LICENCE" for information about the copyright # See the file "LICENCE" for information about the copyright
# and warranty status of this software. # and warranty status of this software.
'''History by script hash (address).''' """History by script hash (address)."""
import array import array
import ast import ast
@ -96,7 +96,7 @@ class History:
self.logger.info('deleted excess history entries') self.logger.info('deleted excess history entries')
def write_state(self, batch): def write_state(self, batch):
'''Write state to the history DB.''' """Write state to the history DB."""
state = { state = {
'flush_count': self.flush_count, 'flush_count': self.flush_count,
'comp_flush_count': self.comp_flush_count, 'comp_flush_count': self.comp_flush_count,
@ -174,10 +174,10 @@ class History:
self.logger.info(f'backing up removed {nremoves:,d} history entries') self.logger.info(f'backing up removed {nremoves:,d} history entries')
def get_txnums(self, hashX, limit=1000): def get_txnums(self, hashX, limit=1000):
'''Generator that returns an unpruned, sorted list of tx_nums in the """Generator that returns an unpruned, sorted list of tx_nums in the
history of a hashX. Includes both spending and receiving history of a hashX. Includes both spending and receiving
transactions. By default yields at most 1000 entries. Set transactions. By default yields at most 1000 entries. Set
limit to None to get them all. ''' limit to None to get them all. """
limit = util.resolve_limit(limit) limit = util.resolve_limit(limit)
for key, hist in self.db.iterator(prefix=hashX): for key, hist in self.db.iterator(prefix=hashX):
a = array.array('I') a = array.array('I')
@ -208,7 +208,7 @@ class History:
# flush_count is reset to comp_flush_count, and comp_flush_count to -1 # flush_count is reset to comp_flush_count, and comp_flush_count to -1
def _flush_compaction(self, cursor, write_items, keys_to_delete): def _flush_compaction(self, cursor, write_items, keys_to_delete):
'''Flush a single compaction pass as a batch.''' """Flush a single compaction pass as a batch."""
# Update compaction state # Update compaction state
if cursor == 65536: if cursor == 65536:
self.flush_count = self.comp_flush_count self.flush_count = self.comp_flush_count
@ -228,8 +228,8 @@ class History:
def _compact_hashX(self, hashX, hist_map, hist_list, def _compact_hashX(self, hashX, hist_map, hist_list,
write_items, keys_to_delete): write_items, keys_to_delete):
'''Compress history for a hashX. hist_list is an ordered list of """Compress history for a hashX. hist_list is an ordered list of
the histories to be compressed.''' the histories to be compressed."""
# History entries (tx numbers) are 4 bytes each. Distribute # History entries (tx numbers) are 4 bytes each. Distribute
# over rows of up to 50KB in size. A fixed row size means # over rows of up to 50KB in size. A fixed row size means
# future compactions will not need to update the first N - 1 # future compactions will not need to update the first N - 1
@ -263,8 +263,8 @@ class History:
return write_size return write_size
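The fixed-row strategy in the comment above can be sketched as a simple chunking of the concatenated history blob; because every row except the last is full, later compactions never have to rewrite the first N - 1 rows:

ROW_SIZE = 50_000   # ~50KB rows; each tx number is 4 bytes

def history_rows(history_blob):
    # Split into fixed-size rows; only the final, partial row can grow
    return [history_blob[i:i + ROW_SIZE]
            for i in range(0, len(history_blob), ROW_SIZE)]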
def _compact_prefix(self, prefix, write_items, keys_to_delete): def _compact_prefix(self, prefix, write_items, keys_to_delete):
'''Compact all history entries for hashXs beginning with the """Compact all history entries for hashXs beginning with the
given prefix. Update keys_to_delete and write_items.''' given prefix. Update keys_to_delete and write_items."""
prior_hashX = None prior_hashX = None
hist_map = {} hist_map = {}
hist_list = [] hist_list = []
@ -292,9 +292,9 @@ class History:
return write_size return write_size
def _compact_history(self, limit): def _compact_history(self, limit):
'''Inner loop of history compaction. Loops until limit bytes have """Inner loop of history compaction. Loops until limit bytes have
been processed. been processed.
''' """
keys_to_delete = set() keys_to_delete = set()
write_items = [] # A list of (key, value) pairs write_items = [] # A list of (key, value) pairs
write_size = 0 write_size = 0


@ -5,7 +5,7 @@
# See the file "LICENCE" for information about the copyright # See the file "LICENCE" for information about the copyright
# and warranty status of this software. # and warranty status of this software.
'''Mempool handling.''' """Mempool handling."""
import asyncio import asyncio
import itertools import itertools
@ -39,48 +39,48 @@ class MemPoolTxSummary:
class MemPoolAPI(ABC): class MemPoolAPI(ABC):
'''A concrete instance of this class is passed to the MemPool object """A concrete instance of this class is passed to the MemPool object
and used by it to query DB and blockchain state.''' and used by it to query DB and blockchain state."""
@abstractmethod @abstractmethod
async def height(self): async def height(self):
'''Query bitcoind for its height.''' """Query bitcoind for its height."""
@abstractmethod @abstractmethod
def cached_height(self): def cached_height(self):
'''Return the height of bitcoind the last time it was queried, """Return the height of bitcoind the last time it was queried,
for any reason, without actually querying it. for any reason, without actually querying it.
''' """
@abstractmethod @abstractmethod
async def mempool_hashes(self): async def mempool_hashes(self):
'''Query bitcoind for the hashes of all transactions in its """Query bitcoind for the hashes of all transactions in its
mempool, returned as a list.''' mempool, returned as a list."""
@abstractmethod @abstractmethod
async def raw_transactions(self, hex_hashes): async def raw_transactions(self, hex_hashes):
'''Query bitcoind for the serialized raw transactions with the given """Query bitcoind for the serialized raw transactions with the given
hashes. Missing transactions are returned as None. hashes. Missing transactions are returned as None.
hex_hashes is an iterable of hexadecimal hash strings.''' hex_hashes is an iterable of hexadecimal hash strings."""
@abstractmethod @abstractmethod
async def lookup_utxos(self, prevouts): async def lookup_utxos(self, prevouts):
'''Return a list of (hashX, value) pairs, one for each prevout if unspent, """Return a list of (hashX, value) pairs, one for each prevout if unspent,
otherwise return None if spent or not found. otherwise return None if spent or not found.
prevouts - an iterable of (hash, index) pairs prevouts - an iterable of (hash, index) pairs
''' """
@abstractmethod @abstractmethod
async def on_mempool(self, touched, height): async def on_mempool(self, touched, height):
'''Called each time the mempool is synchronized. touched is a set of """Called each time the mempool is synchronized. touched is a set of
hashXs touched since the previous call. height is the hashXs touched since the previous call. height is the
daemon's height at the time the mempool was obtained.''' daemon's height at the time the mempool was obtained."""
class MemPool: class MemPool:
'''Representation of the daemon's mempool. """Representation of the daemon's mempool.
coin - a coin class from coins.py coin - a coin class from coins.py
api - an object implementing MemPoolAPI api - an object implementing MemPoolAPI
@ -91,7 +91,7 @@ class MemPool:
tx: tx_hash -> MemPoolTx tx: tx_hash -> MemPoolTx
hashXs: hashX -> set of all hashes of txs touching the hashX hashXs: hashX -> set of all hashes of txs touching the hashX
''' """
def __init__(self, coin, api, refresh_secs=5.0, log_status_secs=120.0): def __init__(self, coin, api, refresh_secs=5.0, log_status_secs=120.0):
assert isinstance(api, MemPoolAPI) assert isinstance(api, MemPoolAPI)
@ -107,7 +107,7 @@ class MemPool:
self.lock = Lock() self.lock = Lock()
async def _logging(self, synchronized_event): async def _logging(self, synchronized_event):
'''Print regular logs of mempool stats.''' """Print regular logs of mempool stats."""
self.logger.info('beginning processing of daemon mempool. ' self.logger.info('beginning processing of daemon mempool. '
'This can take some time...') 'This can take some time...')
start = time.time() start = time.time()
@ -156,12 +156,12 @@ class MemPool:
self.cached_compact_histogram = compact self.cached_compact_histogram = compact
def _accept_transactions(self, tx_map, utxo_map, touched): def _accept_transactions(self, tx_map, utxo_map, touched):
'''Accept transactions in tx_map to the mempool if all their inputs """Accept transactions in tx_map to the mempool if all their inputs
can be found in the existing mempool or a utxo_map from the can be found in the existing mempool or a utxo_map from the
DB. DB.
Returns an (unprocessed tx_map, unspent utxo_map) pair. Returns an (unprocessed tx_map, unspent utxo_map) pair.
''' """
hashXs = self.hashXs hashXs = self.hashXs
txs = self.txs txs = self.txs
@ -200,7 +200,7 @@ class MemPool:
return deferred, {prevout: utxo_map[prevout] for prevout in unspent} return deferred, {prevout: utxo_map[prevout] for prevout in unspent}
async def _refresh_hashes(self, synchronized_event): async def _refresh_hashes(self, synchronized_event):
'''Refresh our view of the daemon's mempool.''' """Refresh our view of the daemon's mempool."""
while True: while True:
height = self.api.cached_height() height = self.api.cached_height()
hex_hashes = await self.api.mempool_hashes() hex_hashes = await self.api.mempool_hashes()
@ -256,7 +256,7 @@ class MemPool:
return touched return touched
async def _fetch_and_accept(self, hashes, all_hashes, touched): async def _fetch_and_accept(self, hashes, all_hashes, touched):
'''Fetch a list of mempool transactions.''' """Fetch a list of mempool transactions."""
hex_hashes_iter = (hash_to_hex_str(hash) for hash in hashes) hex_hashes_iter = (hash_to_hex_str(hash) for hash in hashes)
raw_txs = await self.api.raw_transactions(hex_hashes_iter) raw_txs = await self.api.raw_transactions(hex_hashes_iter)
@ -303,7 +303,7 @@ class MemPool:
# #
async def keep_synchronized(self, synchronized_event): async def keep_synchronized(self, synchronized_event):
'''Keep the mempool synchronized with the daemon.''' """Keep the mempool synchronized with the daemon."""
await asyncio.wait([ await asyncio.wait([
self._refresh_hashes(synchronized_event), self._refresh_hashes(synchronized_event),
self._refresh_histogram(synchronized_event), self._refresh_histogram(synchronized_event),
@ -311,10 +311,10 @@ class MemPool:
]) ])
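keep_synchronized loops forever, so it is meant to run as a background task; the event passed in is presumably set once the first full pass against the daemon completes. A sketch under those assumptions:

    import asyncio

    async def run_mempool(mempool):
        synchronized = asyncio.Event()
        task = asyncio.ensure_future(mempool.keep_synchronized(synchronized))
        await synchronized.wait()  # mempool now mirrors the daemon
        # ... serve queries, then on shutdown:
        task.cancel()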
async def balance_delta(self, hashX): async def balance_delta(self, hashX):
'''Return the unconfirmed amount in the mempool for hashX. """Return the unconfirmed amount in the mempool for hashX.
Can be positive or negative. Can be positive or negative.
''' """
value = 0 value = 0
if hashX in self.hashXs: if hashX in self.hashXs:
for hash in self.hashXs[hashX]: for hash in self.hashXs[hashX]:
@ -324,16 +324,16 @@ class MemPool:
return value return value
async def compact_fee_histogram(self): async def compact_fee_histogram(self):
'''Return a compact fee histogram of the current mempool.''' """Return a compact fee histogram of the current mempool."""
return self.cached_compact_histogram return self.cached_compact_histogram
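The cached list is handed straight back to clients. Electrum fee histograms are conventionally (fee_rate, vsize) pairs with fee rate descending; one plausible compaction scheme (not necessarily what _refresh_histogram implements) merges entries until each bin covers a minimum virtual size:

    def compact(fee_rate_vsize_pairs, bin_size=100_000):
        # Merge (fee_rate, vsize) pairs, highest fee rate first, into
        # bins of at least bin_size vbytes; each bin keeps the lowest
        # fee rate it absorbed.
        out, bin_vsize = [], 0
        for fee_rate, vsize in sorted(fee_rate_vsize_pairs, reverse=True):
            bin_vsize += vsize
            if bin_vsize >= bin_size:
                out.append((fee_rate, bin_vsize))
                bin_vsize = 0
        if bin_vsize:
            out.append((fee_rate, bin_vsize))
        return out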
async def potential_spends(self, hashX): async def potential_spends(self, hashX):
'''Return a set of (prev_hash, prev_idx) pairs from mempool """Return a set of (prev_hash, prev_idx) pairs from mempool
transactions that touch hashX. transactions that touch hashX.
None, some or all of these may be spends of the hashX, but all None, some or all of these may be spends of the hashX, but all
actual spends of it (in the DB or mempool) will be included. actual spends of it (in the DB or mempool) will be included.
''' """
result = set() result = set()
for tx_hash in self.hashXs.get(hashX, ()): for tx_hash in self.hashXs.get(hashX, ()):
tx = self.txs[tx_hash] tx = self.txs[tx_hash]
@ -341,7 +341,7 @@ class MemPool:
return result return result
async def transaction_summaries(self, hashX): async def transaction_summaries(self, hashX):
'''Return a list of MemPoolTxSummary objects for the hashX.''' """Return a list of MemPoolTxSummary objects for the hashX."""
result = [] result = []
for tx_hash in self.hashXs.get(hashX, ()): for tx_hash in self.hashXs.get(hashX, ()):
tx = self.txs[tx_hash] tx = self.txs[tx_hash]
@ -350,12 +350,12 @@ class MemPool:
return result return result
async def unordered_UTXOs(self, hashX): async def unordered_UTXOs(self, hashX):
'''Return an unordered list of UTXO named tuples from mempool """Return an unordered list of UTXO named tuples from mempool
transactions that pay to hashX. transactions that pay to hashX.
This does not consider if any other mempool transactions spend This does not consider if any other mempool transactions spend
the outputs. the outputs.
''' """
utxos = [] utxos = []
for tx_hash in self.hashXs.get(hashX, ()): for tx_hash in self.hashXs.get(hashX, ()):
tx = self.txs.get(tx_hash) tx = self.txs.get(tx_hash)


@ -24,7 +24,7 @@
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# and warranty status of this software. # and warranty status of this software.
'''Merkle trees, branches, proofs and roots.''' """Merkle trees, branches, proofs and roots."""
from asyncio import Event from asyncio import Event
from math import ceil, log from math import ceil, log
@ -33,12 +33,12 @@ from torba.server.hash import double_sha256
class Merkle: class Merkle:
'''Perform merkle tree calculations on binary hashes using a given hash """Perform merkle tree calculations on binary hashes using a given hash
function. function.
If the hash count is not even, the final hash is repeated when If the hash count is not even, the final hash is repeated when
calculating the next merkle layer up the tree. calculating the next merkle layer up the tree.
''' """
def __init__(self, hash_func=double_sha256): def __init__(self, hash_func=double_sha256):
self.hash_func = hash_func self.hash_func = hash_func
@ -47,7 +47,7 @@ class Merkle:
return self.branch_length(hash_count) + 1 return self.branch_length(hash_count) + 1
def branch_length(self, hash_count): def branch_length(self, hash_count):
'''Return the length of a merkle branch given the number of hashes.''' """Return the length of a merkle branch given the number of hashes."""
if not isinstance(hash_count, int): if not isinstance(hash_count, int):
raise TypeError('hash_count must be an integer') raise TypeError('hash_count must be an integer')
if hash_count < 1: if hash_count < 1:
@ -55,9 +55,9 @@ class Merkle:
return ceil(log(hash_count, 2)) return ceil(log(hash_count, 2))
def branch_and_root(self, hashes, index, length=None): def branch_and_root(self, hashes, index, length=None):
'''Return a (merkle branch, merkle_root) pair given hashes, and the """Return a (merkle branch, merkle_root) pair given hashes, and the
index of one of those hashes. index of one of those hashes.
''' """
hashes = list(hashes) hashes = list(hashes)
if not isinstance(index, int): if not isinstance(index, int):
raise TypeError('index must be an integer') raise TypeError('index must be an integer')
@ -86,12 +86,12 @@ class Merkle:
return branch, hashes[0] return branch, hashes[0]
def root(self, hashes, length=None): def root(self, hashes, length=None):
'''Return the merkle root of a non-empty iterable of binary hashes.''' """Return the merkle root of a non-empty iterable of binary hashes."""
branch, root = self.branch_and_root(hashes, 0, length) branch, root = self.branch_and_root(hashes, 0, length)
return root return root
def root_from_proof(self, hash, branch, index): def root_from_proof(self, hash, branch, index):
'''Return the merkle root given a hash, a merkle branch to it, and """Return the merkle root given a hash, a merkle branch to it, and
its index in the hashes array. its index in the hashes array.
branch is an iterable sorted deepest to shallowest. If the branch is an iterable sorted deepest to shallowest. If the
@ -102,7 +102,7 @@ class Merkle:
branch_length(). Unfortunately this is not easily done for branch_length(). Unfortunately this is not easily done for
bitcoin transactions as the number of transactions in a block bitcoin transactions as the number of transactions in a block
is unknown to an SPV client. is unknown to an SPV client.
''' """
hash_func = self.hash_func hash_func = self.hash_func
for elt in branch: for elt in branch:
if index & 1: if index & 1:
@ -115,8 +115,8 @@ class Merkle:
return hash return hash
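To make the duplication rule concrete, a short sketch (the merkle module path is assumed; double_sha256 is the import used above):

    from torba.server.hash import double_sha256
    from torba.server.merkle import Merkle

    merkle = Merkle()
    leaves = [double_sha256(bytes([i])) for i in range(3)]
    # With three leaves, the last is paired with itself on the first level.
    branch, root = merkle.branch_and_root(leaves, 1)
    assert merkle.root(leaves) == root
    assert merkle.root_from_proof(leaves[1], branch, 1) == root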
def level(self, hashes, depth_higher): def level(self, hashes, depth_higher):
'''Return a level of the merkle tree of hashes the given depth """Return a level of the merkle tree of hashes the given depth
higher than the bottom row of the original tree.''' higher than the bottom row of the original tree."""
size = 1 << depth_higher size = 1 << depth_higher
root = self.root root = self.root
return [root(hashes[n: n + size], depth_higher) return [root(hashes[n: n + size], depth_higher)
@ -124,7 +124,7 @@ class Merkle:
def branch_and_root_from_level(self, level, leaf_hashes, index, def branch_and_root_from_level(self, level, leaf_hashes, index,
depth_higher): depth_higher):
'''Return a (merkle branch, merkle_root) pair when a merkle-tree has a """Return a (merkle branch, merkle_root) pair when a merkle-tree has a
level cached. level cached.
To maximally reduce the amount of data hashed in computing a To maximally reduce the amount of data hashed in computing a
@ -140,7 +140,7 @@ class Merkle:
index is the index in the full list of hashes of the hash whose index is the index in the full list of hashes of the hash whose
merkle branch we want. merkle branch we want.
''' """
if not isinstance(level, list): if not isinstance(level, list):
raise TypeError("level must be a list") raise TypeError("level must be a list")
if not isinstance(leaf_hashes, list): if not isinstance(leaf_hashes, list):
@ -157,14 +157,14 @@ class Merkle:
class MerkleCache: class MerkleCache:
'''A cache to calculate merkle branches efficiently.''' """A cache to calculate merkle branches efficiently."""
def __init__(self, merkle, source_func): def __init__(self, merkle, source_func):
'''Initialise a cache of hashes taken from source_func: """Initialise a cache of hashes taken from source_func:
async def source_func(index, count): async def source_func(index, count):
... ...
''' """
self.merkle = merkle self.merkle = merkle
self.source_func = source_func self.source_func = source_func
self.length = 0 self.length = 0
@ -175,9 +175,9 @@ class MerkleCache:
return 1 << self.depth_higher return 1 << self.depth_higher
def _leaf_start(self, index): def _leaf_start(self, index):
'''Given a hash index, return the index of the first leaf in its """Given a hash index, return the index of the first leaf in its
segment, as needed to calculate a merkle branch. segment, as needed to calculate a merkle branch.
''' """
depth_higher = self.depth_higher depth_higher = self.depth_higher
return (index >> depth_higher) << depth_higher return (index >> depth_higher) << depth_higher
@ -185,7 +185,7 @@ class MerkleCache:
return self.merkle.level(hashes, self.depth_higher) return self.merkle.level(hashes, self.depth_higher)
async def _extend_to(self, length): async def _extend_to(self, length):
'''Extend the length of the cache if necessary.''' """Extend the length of the cache if necessary."""
if length <= self.length: if length <= self.length:
return return
# Start from the beginning of any final partial segment. # Start from the beginning of any final partial segment.
@ -196,8 +196,8 @@ class MerkleCache:
self.length = length self.length = length
async def _level_for(self, length): async def _level_for(self, length):
'''Return a merkle level for a truncation of the hashes to the """Return a merkle level for a truncation of the hashes to the
given length.''' given length."""
if length == self.length: if length == self.length:
return self.level return self.level
level = self.level[:length >> self.depth_higher] level = self.level[:length >> self.depth_higher]
@ -208,15 +208,15 @@ class MerkleCache:
return level return level
async def initialize(self, length): async def initialize(self, length):
'''Call to initialize the cache to a source of given length.''' """Call to initialize the cache to a source of given length."""
self.length = length self.length = length
self.depth_higher = self.merkle.tree_depth(length) // 2 self.depth_higher = self.merkle.tree_depth(length) // 2
self.level = self._level(await self.source_func(0, length)) self.level = self._level(await self.source_func(0, length))
self.initialized.set() self.initialized.set()
def truncate(self, length): def truncate(self, length):
'''Truncate the cache so it covers no more than length underlying """Truncate the cache so it covers no more than length underlying
hashes.''' hashes."""
if not isinstance(length, int): if not isinstance(length, int):
raise TypeError('length must be an integer') raise TypeError('length must be an integer')
if length <= 0: if length <= 0:
@ -228,11 +228,11 @@ class MerkleCache:
self.level[length >> self.depth_higher:] = [] self.level[length >> self.depth_higher:] = []
async def branch_and_root(self, length, index): async def branch_and_root(self, length, index):
'''Return a merkle branch and root. Length is the number of """Return a merkle branch and root. Length is the number of
hashes used to calculate the merkle root, index is the position hashes used to calculate the merkle root, index is the position
of the hash to calculate the branch of. of the hash to calculate the branch of.
index must be less than length, which must be at least 1.''' index must be less than length, which must be at least 1."""
if not isinstance(length, int): if not isinstance(length, int):
raise TypeError('length must be an integer') raise TypeError('length must be an integer')
if not isinstance(index, int): if not isinstance(index, int):
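A sketch of the intended wiring, with block_hashes standing in for a real asynchronous hash source:

    async def demo(block_hashes):
        async def source_func(index, count):
            return block_hashes[index:index + count]

        cache = MerkleCache(Merkle(), source_func)
        await cache.initialize(len(block_hashes))
        # Branch and root over every hash, for the hash at position 0.
        return await cache.branch_and_root(len(block_hashes), 0)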


@ -23,7 +23,7 @@
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''Representation of a peer server.''' """Representation of a peer server."""
from ipaddress import ip_address from ipaddress import ip_address
@ -47,8 +47,8 @@ class Peer:
def __init__(self, host, features, source='unknown', ip_addr=None, def __init__(self, host, features, source='unknown', ip_addr=None,
last_good=0, last_try=0, try_count=0): last_good=0, last_try=0, try_count=0):
'''Create a peer given a host name (or IP address as a string), """Create a peer given a host name (or IP address as a string),
a dictionary of features, and a record of the source.''' a dictionary of features, and a record of the source."""
assert isinstance(host, str) assert isinstance(host, str)
assert isinstance(features, dict) assert isinstance(features, dict)
assert host in features.get('hosts', {}) assert host in features.get('hosts', {})
@ -83,14 +83,14 @@ class Peer:
@classmethod @classmethod
def deserialize(cls, item): def deserialize(cls, item):
'''Deserialize from a dictionary.''' """Deserialize from a dictionary."""
return cls(**item) return cls(**item)
def matches(self, peers): def matches(self, peers):
'''Return peers whose host matches our hostname or IP address. """Return peers whose host matches our hostname or IP address.
Additionally include all peers whose IP address matches our Additionally include all peers whose IP address matches our
hostname if that is an IP address. hostname if that is an IP address.
''' """
candidates = (self.host.lower(), self.ip_addr) candidates = (self.host.lower(), self.ip_addr)
return [peer for peer in peers return [peer for peer in peers
if peer.host.lower() in candidates if peer.host.lower() in candidates
@ -100,7 +100,7 @@ class Peer:
return self.host return self.host
def update_features(self, features): def update_features(self, features):
'''Update features in-place.''' """Update features in-place."""
try: try:
tmp = Peer(self.host, features) tmp = Peer(self.host, features)
except Exception: except Exception:
@ -115,8 +115,8 @@ class Peer:
setattr(self, feature, getattr(peer, feature)) setattr(self, feature, getattr(peer, feature))
def connection_port_pairs(self): def connection_port_pairs(self):
'''Return a list of (kind, port) pairs to try when making a """Return a list of (kind, port) pairs to try when making a
connection.''' connection."""
# Use a list not a set - it's important to try the registered # Use a list not a set - it's important to try the registered
# ports first. # ports first.
pairs = [('SSL', self.ssl_port), ('TCP', self.tcp_port)] pairs = [('SSL', self.ssl_port), ('TCP', self.tcp_port)]
@ -125,13 +125,13 @@ class Peer:
return [pair for pair in pairs if pair[1]] return [pair for pair in pairs if pair[1]]
def mark_bad(self): def mark_bad(self):
'''Mark as bad to avoid reconnects but also to remember for a """Mark as bad to avoid reconnects but also to remember for a
while.''' while."""
self.bad = True self.bad = True
def check_ports(self, other): def check_ports(self, other):
'''Remember differing ports in case server operator changed them """Remember differing ports in case server operator changed them
or removed one.''' or removed one."""
if other.ssl_port != self.ssl_port: if other.ssl_port != self.ssl_port:
self.other_port_pairs.add(('SSL', other.ssl_port)) self.other_port_pairs.add(('SSL', other.ssl_port))
if other.tcp_port != self.tcp_port: if other.tcp_port != self.tcp_port:
@ -160,7 +160,7 @@ class Peer:
@cachedproperty @cachedproperty
def ip_address(self): def ip_address(self):
'''The host as a python ip_address object, or None.''' """The host as a python ip_address object, or None."""
try: try:
return ip_address(self.host) return ip_address(self.host)
except ValueError: except ValueError:
@ -174,7 +174,7 @@ class Peer:
return tuple(self.ip_addr.split('.')[:2]) return tuple(self.ip_addr.split('.')[:2])
def serialize(self): def serialize(self):
'''Serialize to a dictionary.''' """Serialize to a dictionary."""
return {attr: getattr(self, attr) for attr in self.ATTRS} return {attr: getattr(self, attr) for attr in self.ATTRS}
def _port(self, key): def _port(self, key):
@ -202,28 +202,28 @@ class Peer:
@cachedproperty @cachedproperty
def genesis_hash(self): def genesis_hash(self):
'''Returns the genesis hash as a string if known, otherwise None.''' """Returns the genesis hash as a string if known, otherwise None."""
return self._string('genesis_hash') return self._string('genesis_hash')
@cachedproperty @cachedproperty
def ssl_port(self): def ssl_port(self):
'''Returns None if no SSL port, otherwise the port as an integer.''' """Returns None if no SSL port, otherwise the port as an integer."""
return self._port('ssl_port') return self._port('ssl_port')
@cachedproperty @cachedproperty
def tcp_port(self): def tcp_port(self):
'''Returns None if no TCP port, otherwise the port as an integer.''' """Returns None if no TCP port, otherwise the port as an integer."""
return self._port('tcp_port') return self._port('tcp_port')
@cachedproperty @cachedproperty
def server_version(self): def server_version(self):
'''Returns the server version as a string if known, otherwise None.''' """Returns the server version as a string if known, otherwise None."""
return self._string('server_version') return self._string('server_version')
@cachedproperty @cachedproperty
def pruning(self): def pruning(self):
'''Returns the pruning level as an integer. None indicates no """Returns the pruning level as an integer. None indicates no
pruning.''' pruning."""
pruning = self._integer('pruning') pruning = self._integer('pruning')
if pruning and pruning > 0: if pruning and pruning > 0:
return pruning return pruning
@ -236,22 +236,22 @@ class Peer:
@cachedproperty @cachedproperty
def protocol_min(self): def protocol_min(self):
'''Minimum protocol version as a string, e.g., 1.0''' """Minimum protocol version as a string, e.g., 1.0"""
return self._protocol_version_string('protocol_min') return self._protocol_version_string('protocol_min')
@cachedproperty @cachedproperty
def protocol_max(self): def protocol_max(self):
'''Maximum protocol version as a string, e.g., 1.1''' """Maximum protocol version as a string, e.g., 1.1"""
return self._protocol_version_string('protocol_max') return self._protocol_version_string('protocol_max')
def to_tuple(self): def to_tuple(self):
'''The tuple (ip, host, details) expected in response """The tuple (ip, host, details) expected in response
to a peers subscription.''' to a peers subscription."""
details = self.real_name().split()[1:] details = self.real_name().split()[1:]
return (self.ip_addr or self.host, self.host, details) return (self.ip_addr or self.host, self.host, details)
def real_name(self): def real_name(self):
'''Real name of this peer as used on IRC.''' """Real name of this peer as used on IRC."""
def port_text(letter, port): def port_text(letter, port):
if port == self.DEFAULT_PORTS.get(letter): if port == self.DEFAULT_PORTS.get(letter):
return letter return letter
@ -268,12 +268,12 @@ class Peer:
@classmethod @classmethod
def from_real_name(cls, real_name, source): def from_real_name(cls, real_name, source):
'''Real name is a real name as on IRC, such as """Real name is a real name as on IRC, such as
"erbium1.sytes.net v1.0 s t" "erbium1.sytes.net v1.0 s t"
Returns an instance of this Peer class. Returns an instance of this Peer class.
''' """
host = 'nohost' host = 'nohost'
features = {} features = {}
ports = {} ports = {}
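For example, following the documented format (a sketch; a bare 's' or 't' selects the coin's default SSL or TCP port):

    peer = Peer.from_real_name('erbium1.sytes.net v1.0 s t', 'IRC')
    # peer.ssl_port and peer.tcp_port now hold the coin defaults, and
    # peer.real_name() should round-trip to the original string.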


@ -5,7 +5,7 @@
# See the file "LICENCE" for information about the copyright # See the file "LICENCE" for information about the copyright
# and warranty status of this software. # and warranty status of this software.
'''Peer management.''' """Peer management."""
import asyncio import asyncio
import random import random
@ -39,7 +39,7 @@ def assert_good(message, result, instance):
class PeerSession(RPCSession): class PeerSession(RPCSession):
'''An outgoing session to a peer.''' """An outgoing session to a peer."""
async def handle_request(self, request): async def handle_request(self, request):
# We subscribe so might be unlucky enough to get a notification... # We subscribe so might be unlucky enough to get a notification...
@ -51,11 +51,11 @@ class PeerSession(RPCSession):
class PeerManager: class PeerManager:
'''Looks after the DB of peer network servers. """Looks after the DB of peer network servers.
Attempts to maintain a connection with up to 8 peers. Attempts to maintain a connection with up to 8 peers.
Issues a 'peers.subscribe' RPC to them and tells them our data. Issues a 'peers.subscribe' RPC to them and tells them our data.
''' """
def __init__(self, env, db): def __init__(self, env, db):
self.logger = class_logger(__name__, self.__class__.__name__) self.logger = class_logger(__name__, self.__class__.__name__)
# Initialise the Peer class # Initialise the Peer class
@ -78,12 +78,12 @@ class PeerManager:
self.group = TaskGroup() self.group = TaskGroup()
def _my_clearnet_peer(self): def _my_clearnet_peer(self):
'''Returns the clearnet peer representing this server, if any.''' """Returns the clearnet peer representing this server, if any."""
clearnet = [peer for peer in self.myselves if not peer.is_tor] clearnet = [peer for peer in self.myselves if not peer.is_tor]
return clearnet[0] if clearnet else None return clearnet[0] if clearnet else None
def _set_peer_statuses(self): def _set_peer_statuses(self):
'''Set peer statuses.''' """Set peer statuses."""
cutoff = time.time() - STALE_SECS cutoff = time.time() - STALE_SECS
for peer in self.peers: for peer in self.peers:
if peer.bad: if peer.bad:
@ -96,10 +96,10 @@ class PeerManager:
peer.status = PEER_NEVER peer.status = PEER_NEVER
def _features_to_register(self, peer, remote_peers): def _features_to_register(self, peer, remote_peers):
'''If we should register ourselves to the remote peer, which has """If we should register ourselves to the remote peer, which has
reported the given list of known peers, return the clearnet reported the given list of known peers, return the clearnet
identity features to register, otherwise None. identity features to register, otherwise None.
''' """
# Announce ourself if not present. Don't if disabled, we # Announce ourself if not present. Don't if disabled, we
# are a non-public IP address, or to ourselves. # are a non-public IP address, or to ourselves.
if not self.env.peer_announce or peer in self.myselves: if not self.env.peer_announce or peer in self.myselves:
@ -114,7 +114,7 @@ class PeerManager:
return my.features return my.features
def _permit_new_onion_peer(self): def _permit_new_onion_peer(self):
'''Accept a new onion peer only once per random time interval.''' """Accept a new onion peer only once per random time interval."""
now = time.time() now = time.time()
if now < self.permit_onion_peer_time: if now < self.permit_onion_peer_time:
return False return False
@ -122,7 +122,7 @@ class PeerManager:
return True return True
async def _import_peers(self): async def _import_peers(self):
'''Import hard-coded peers from a file or the coin defaults.''' """Import hard-coded peers from a file or the coin defaults."""
imported_peers = self.myselves.copy() imported_peers = self.myselves.copy()
# Add the hard-coded ones unless only reporting ourself # Add the hard-coded ones unless only reporting ourself
if self.env.peer_discovery != self.env.PD_SELF: if self.env.peer_discovery != self.env.PD_SELF:
@ -131,12 +131,12 @@ class PeerManager:
await self._note_peers(imported_peers, limit=None) await self._note_peers(imported_peers, limit=None)
async def _detect_proxy(self): async def _detect_proxy(self):
'''Detect a proxy if we don't have one and some time has passed since """Detect a proxy if we don't have one and some time has passed since
the last attempt. the last attempt.
If found self.proxy is set to a SOCKSProxy instance, otherwise If found self.proxy is set to a SOCKSProxy instance, otherwise
None. None.
''' """
host = self.env.tor_proxy_host host = self.env.tor_proxy_host
if self.env.tor_proxy_port is None: if self.env.tor_proxy_port is None:
ports = [9050, 9150, 1080] ports = [9050, 9150, 1080]
@ -155,7 +155,7 @@ class PeerManager:
async def _note_peers(self, peers, limit=2, check_ports=False, async def _note_peers(self, peers, limit=2, check_ports=False,
source=None): source=None):
'''Add a limited number of peers that are not already present.''' """Add a limited number of peers that are not already present."""
new_peers = [] new_peers = []
for peer in peers: for peer in peers:
if not peer.is_public or (peer.is_tor and not self.proxy): if not peer.is_public or (peer.is_tor and not self.proxy):
@ -378,12 +378,12 @@ class PeerManager:
# External interface # External interface
# #
async def discover_peers(self): async def discover_peers(self):
'''Perform peer maintenance. This includes """Perform peer maintenance. This includes
1) Forgetting unreachable peers. 1) Forgetting unreachable peers.
2) Verifying connectivity of new peers. 2) Verifying connectivity of new peers.
3) Retrying old peers at regular intervals. 3) Retrying old peers at regular intervals.
''' """
if self.env.peer_discovery != self.env.PD_ON: if self.env.peer_discovery != self.env.PD_ON:
self.logger.info('peer discovery is disabled') self.logger.info('peer discovery is disabled')
return return
@ -395,7 +395,7 @@ class PeerManager:
self.group.add(self._import_peers()) self.group.add(self._import_peers())
def info(self): def info(self):
'''A summary of peer counts by status.''' """A summary of peer counts by status."""
self._set_peer_statuses() self._set_peer_statuses()
counter = Counter(peer.status for peer in self.peers) counter = Counter(peer.status for peer in self.peers)
return { return {
@ -407,11 +407,11 @@ class PeerManager:
} }
async def add_localRPC_peer(self, real_name): async def add_localRPC_peer(self, real_name):
'''Add a peer passed by the admin over LocalRPC.''' """Add a peer passed by the admin over LocalRPC."""
await self._note_peers([Peer.from_real_name(real_name, 'RPC')]) await self._note_peers([Peer.from_real_name(real_name, 'RPC')])
async def on_add_peer(self, features, source_info): async def on_add_peer(self, features, source_info):
'''Add a peer (but only if the peer resolves to the source).''' """Add a peer (but only if the peer resolves to the source)."""
if not source_info: if not source_info:
self.logger.info('ignored add_peer request: no source info') self.logger.info('ignored add_peer request: no source info')
return False return False
@ -449,12 +449,12 @@ class PeerManager:
return permit return permit
def on_peers_subscribe(self, is_tor): def on_peers_subscribe(self, is_tor):
'''Returns the server peers as a list of (ip, host, details) tuples. """Returns the server peers as a list of (ip, host, details) tuples.
We return all peers we've connected to in the last day. We return all peers we've connected to in the last day.
Additionally, if we don't have onion routing, we return a few Additionally, if we don't have onion routing, we return a few
hard-coded onion servers. hard-coded onion servers.
''' """
cutoff = time.time() - STALE_SECS cutoff = time.time() - STALE_SECS
recent = [peer for peer in self.peers recent = [peer for peer in self.peers
if peer.last_good > cutoff and if peer.last_good > cutoff and
@ -485,12 +485,12 @@ class PeerManager:
return [peer.to_tuple() for peer in peers] return [peer.to_tuple() for peer in peers]
def proxy_peername(self): def proxy_peername(self):
'''Return the peername of the proxy, if there is a proxy, otherwise """Return the peername of the proxy, if there is a proxy, otherwise
None.''' None."""
return self.proxy.peername if self.proxy else None return self.proxy.peername if self.proxy else None
def rpc_data(self): def rpc_data(self):
'''Peer data for the peers RPC method.''' """Peer data for the peers RPC method."""
self._set_peer_statuses() self._set_peer_statuses()
descs = ['good', 'stale', 'never', 'bad'] descs = ['good', 'stale', 'never', 'bad']


@ -24,7 +24,7 @@
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# and warranty status of this software. # and warranty status of this software.
'''Script-related classes and functions.''' """Script-related classes and functions."""
import struct import struct
@ -37,7 +37,7 @@ from torba.server.util import unpack_le_uint16_from, unpack_le_uint32_from, \
class ScriptError(Exception): class ScriptError(Exception):
'''Exception used for script errors.''' """Exception used for script errors."""
OpCodes = Enumeration("Opcodes", [ OpCodes = Enumeration("Opcodes", [
@ -92,9 +92,9 @@ def _match_ops(ops, pattern):
class ScriptPubKey: class ScriptPubKey:
'''A class for handling a tx output script that gives conditions """A class for handling a tx output script that gives conditions
necessary for spending. necessary for spending.
''' """
TO_ADDRESS_OPS = [OpCodes.OP_DUP, OpCodes.OP_HASH160, -1, TO_ADDRESS_OPS = [OpCodes.OP_DUP, OpCodes.OP_HASH160, -1,
OpCodes.OP_EQUALVERIFY, OpCodes.OP_CHECKSIG] OpCodes.OP_EQUALVERIFY, OpCodes.OP_CHECKSIG]
@ -106,7 +106,7 @@ class ScriptPubKey:
@classmethod @classmethod
def pay_to(cls, handlers, script): def pay_to(cls, handlers, script):
'''Parse a script, invoke the appropriate handler and """Parse a script, invoke the appropriate handler and
return the result. return the result.
One of the following handlers is invoked: One of the following handlers is invoked:
@ -115,7 +115,7 @@ class ScriptPubKey:
handlers.pubkey(pubkey) handlers.pubkey(pubkey)
handlers.unspendable() handlers.unspendable()
handlers.strange(script) handlers.strange(script)
''' """
try: try:
ops = Script.get_ops(script) ops = Script.get_ops(script)
except ScriptError: except ScriptError:
@ -163,7 +163,7 @@ class ScriptPubKey:
@classmethod @classmethod
def multisig_script(cls, m, pubkeys): def multisig_script(cls, m, pubkeys):
'''Returns the script for a pay-to-multisig transaction.''' """Returns the script for a pay-to-multisig transaction."""
n = len(pubkeys) n = len(pubkeys)
if not 1 <= m <= n <= 15: if not 1 <= m <= n <= 15:
raise ScriptError('{:d} of {:d} multisig script not possible' raise ScriptError('{:d} of {:d} multisig script not possible'
@ -218,7 +218,7 @@ class Script:
@classmethod @classmethod
def push_data(cls, data): def push_data(cls, data):
'''Returns the opcodes to push the data on the stack.''' """Returns the opcodes to push the data on the stack."""
assert isinstance(data, (bytes, bytearray)) assert isinstance(data, (bytes, bytearray))
n = len(data) n = len(data)


@ -5,7 +5,7 @@
# See the file "LICENCE" for information about the copyright # See the file "LICENCE" for information about the copyright
# and warranty status of this software. # and warranty status of this software.
'''Classes for local RPC server and remote client TCP/SSL servers.''' """Classes for local RPC server and remote client TCP/SSL servers."""
import asyncio import asyncio
import codecs import codecs
@ -48,8 +48,8 @@ def scripthash_to_hashX(scripthash):
def non_negative_integer(value): def non_negative_integer(value):
'''Return param value if it is, or can be converted to, a """Return param value if it is, or can be converted to, a
non-negative integer; otherwise raise an RPCError.''' non-negative integer; otherwise raise an RPCError."""
try: try:
value = int(value) value = int(value)
if value >= 0: if value >= 0:
@ -61,15 +61,15 @@ def non_negative_integer(value):
def assert_boolean(value): def assert_boolean(value):
'''Return param value if it is a boolean, otherwise raise an RPCError.''' """Return param value if it is a boolean, otherwise raise an RPCError."""
if value in (False, True): if value in (False, True):
return value return value
raise RPCError(BAD_REQUEST, f'{value} should be a boolean value') raise RPCError(BAD_REQUEST, f'{value} should be a boolean value')
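A quick illustration of the two validators above:

    non_negative_integer('3')   # returns 3
    non_negative_integer(-1)    # raises RPCError
    assert_boolean(True)        # returns True
    assert_boolean('yes')       # raises RPCError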
def assert_tx_hash(value): def assert_tx_hash(value):
'''Raise an RPCError if the value is not a valid transaction """Raise an RPCError if the value is not a valid transaction
hash.''' hash."""
try: try:
if len(util.hex_to_bytes(value)) == 32: if len(util.hex_to_bytes(value)) == 32:
return return
@ -79,7 +79,7 @@ def assert_tx_hash(value):
class Semaphores: class Semaphores:
'''For aiorpcX's semaphore handling.''' """For aiorpcX's semaphore handling."""
def __init__(self, semaphores): def __init__(self, semaphores):
self.semaphores = semaphores self.semaphores = semaphores
@ -104,7 +104,7 @@ class SessionGroup:
class SessionManager: class SessionManager:
'''Holds global state about all sessions.''' """Holds global state about all sessions."""
def __init__(self, env, db, bp, daemon, mempool, shutdown_event): def __init__(self, env, db, bp, daemon, mempool, shutdown_event):
env.max_send = max(350000, env.max_send) env.max_send = max(350000, env.max_send)
@ -159,9 +159,9 @@ class SessionManager:
self.logger.info(f'{kind} server listening on {host}:{port:d}') self.logger.info(f'{kind} server listening on {host}:{port:d}')
async def _start_external_servers(self): async def _start_external_servers(self):
'''Start listening on TCP and SSL ports, but only if the respective """Start listening on TCP and SSL ports, but only if the respective
port was given in the environment. port was given in the environment.
''' """
env = self.env env = self.env
host = env.cs_host(for_rpc=False) host = env.cs_host(for_rpc=False)
if env.tcp_port is not None: if env.tcp_port is not None:
@ -172,7 +172,7 @@ class SessionManager:
await self._start_server('SSL', host, env.ssl_port, ssl=sslc) await self._start_server('SSL', host, env.ssl_port, ssl=sslc)
async def _close_servers(self, kinds): async def _close_servers(self, kinds):
'''Close the servers of the given kinds (TCP etc.).''' """Close the servers of the given kinds (TCP etc.)."""
if kinds: if kinds:
self.logger.info('closing down {} listening servers' self.logger.info('closing down {} listening servers'
.format(', '.join(kinds))) .format(', '.join(kinds)))
@ -203,7 +203,7 @@ class SessionManager:
paused = False paused = False
async def _log_sessions(self): async def _log_sessions(self):
'''Periodically log sessions.''' """Periodically log sessions."""
log_interval = self.env.log_sessions log_interval = self.env.log_sessions
if log_interval: if log_interval:
while True: while True:
@ -247,7 +247,7 @@ class SessionManager:
return result return result
async def _clear_stale_sessions(self): async def _clear_stale_sessions(self):
'''Cut off sessions that haven't done anything for 10 minutes.''' """Cut off sessions that haven't done anything for 10 minutes."""
while True: while True:
await sleep(60) await sleep(60)
stale_cutoff = time.time() - self.env.session_timeout stale_cutoff = time.time() - self.env.session_timeout
@ -276,7 +276,7 @@ class SessionManager:
session.group = new_group session.group = new_group
def _get_info(self): def _get_info(self):
'''A summary of server state.''' """A summary of server state."""
group_map = self._group_map() group_map = self._group_map()
return { return {
'closing': len([s for s in self.sessions if s.is_closing()]), 'closing': len([s for s in self.sessions if s.is_closing()]),
@ -298,7 +298,7 @@ class SessionManager:
} }
def _session_data(self, for_log): def _session_data(self, for_log):
'''Returned to the RPC 'sessions' call.''' """Returned to the RPC 'sessions' call."""
now = time.time() now = time.time()
sessions = sorted(self.sessions, key=lambda s: s.start_time) sessions = sorted(self.sessions, key=lambda s: s.start_time)
return [(session.session_id, return [(session.session_id,
@ -315,7 +315,7 @@ class SessionManager:
for session in sessions] for session in sessions]
def _group_data(self): def _group_data(self):
'''Returned to the RPC 'groups' call.''' """Returned to the RPC 'groups' call."""
result = [] result = []
group_map = self._group_map() group_map = self._group_map()
for group, sessions in group_map.items(): for group, sessions in group_map.items():
@ -338,9 +338,9 @@ class SessionManager:
return electrum_header, raw_header return electrum_header, raw_header
async def _refresh_hsub_results(self, height): async def _refresh_hsub_results(self, height):
'''Refresh the cached header subscription responses to be for height, """Refresh the cached header subscription responses to be for height,
and record that as notified_height. and record that as notified_height.
''' """
# Paranoia: a reorg could race and leave db_height lower # Paranoia: a reorg could race and leave db_height lower
height = min(height, self.db.db_height) height = min(height, self.db.db_height)
electrum, raw = await self._electrum_and_raw_headers(height) electrum, raw = await self._electrum_and_raw_headers(height)
@ -350,39 +350,39 @@ class SessionManager:
# --- LocalRPC command handlers # --- LocalRPC command handlers
async def rpc_add_peer(self, real_name): async def rpc_add_peer(self, real_name):
'''Add a peer. """Add a peer.
real_name: "bch.electrumx.cash t50001 s50002" for example real_name: "bch.electrumx.cash t50001 s50002" for example
''' """
await self.peer_mgr.add_localRPC_peer(real_name) await self.peer_mgr.add_localRPC_peer(real_name)
return "peer '{}' added".format(real_name) return "peer '{}' added".format(real_name)
async def rpc_disconnect(self, session_ids): async def rpc_disconnect(self, session_ids):
'''Disconnect sessions. """Disconnect sessions.
session_ids: array of session IDs session_ids: array of session IDs
''' """
async def close(session): async def close(session):
'''Close the session's transport.''' """Close the session's transport."""
await session.close(force_after=2) await session.close(force_after=2)
return f'disconnected {session.session_id}' return f'disconnected {session.session_id}'
return await self._for_each_session(session_ids, close) return await self._for_each_session(session_ids, close)
async def rpc_log(self, session_ids): async def rpc_log(self, session_ids):
'''Toggle logging of sessions. """Toggle logging of sessions.
session_ids: array of session IDs session_ids: array of session IDs
''' """
async def toggle_logging(session): async def toggle_logging(session):
'''Toggle logging of the session.''' """Toggle logging of the session."""
session.toggle_logging() session.toggle_logging()
return f'log {session.session_id}: {session.log_me}' return f'log {session.session_id}: {session.log_me}'
return await self._for_each_session(session_ids, toggle_logging) return await self._for_each_session(session_ids, toggle_logging)
async def rpc_daemon_url(self, daemon_url): async def rpc_daemon_url(self, daemon_url):
'''Replace the daemon URL.''' """Replace the daemon URL."""
daemon_url = daemon_url or self.env.daemon_url daemon_url = daemon_url or self.env.daemon_url
try: try:
self.daemon.set_url(daemon_url) self.daemon.set_url(daemon_url)
@ -391,24 +391,24 @@ class SessionManager:
return f'now using daemon at {self.daemon.logged_url()}' return f'now using daemon at {self.daemon.logged_url()}'
async def rpc_stop(self): async def rpc_stop(self):
'''Shut down the server cleanly.''' """Shut down the server cleanly."""
self.shutdown_event.set() self.shutdown_event.set()
return 'stopping' return 'stopping'
async def rpc_getinfo(self): async def rpc_getinfo(self):
'''Return summary information about the server process.''' """Return summary information about the server process."""
return self._get_info() return self._get_info()
async def rpc_groups(self): async def rpc_groups(self):
'''Return statistics about the session groups.''' """Return statistics about the session groups."""
return self._group_data() return self._group_data()
async def rpc_peers(self): async def rpc_peers(self):
'''Return a list of data about server peers.''' """Return a list of data about server peers."""
return self.peer_mgr.rpc_data() return self.peer_mgr.rpc_data()
async def rpc_query(self, items, limit): async def rpc_query(self, items, limit):
'''Return a list of data about the given items from the DB.''' """Return a list of data about the given items from the DB."""
coin = self.env.coin coin = self.env.coin
db = self.db db = self.db
lines = [] lines = []
@ -459,14 +459,14 @@ class SessionManager:
return lines return lines
async def rpc_sessions(self): async def rpc_sessions(self):
'''Return statistics about connected sessions.''' """Return statistics about connected sessions."""
return self._session_data(for_log=False) return self._session_data(for_log=False)
async def rpc_reorg(self, count): async def rpc_reorg(self, count):
'''Force a reorg of the given number of blocks. """Force a reorg of the given number of blocks.
count: number of blocks to reorg count: number of blocks to reorg
''' """
count = non_negative_integer(count) count = non_negative_integer(count)
if not self.bp.force_chain_reorg(count): if not self.bp.force_chain_reorg(count):
raise RPCError(BAD_REQUEST, 'still catching up with daemon') raise RPCError(BAD_REQUEST, 'still catching up with daemon')
@ -475,8 +475,8 @@ class SessionManager:
# --- External Interface # --- External Interface
async def serve(self, notifications, server_listening_event): async def serve(self, notifications, server_listening_event):
'''Start the RPC server if enabled. When the event is triggered, """Start the RPC server if enabled. When the event is triggered,
start TCP and SSL servers.''' start TCP and SSL servers."""
try: try:
if self.env.rpc_port is not None: if self.env.rpc_port is not None:
await self._start_server('RPC', self.env.cs_host(for_rpc=True), await self._start_server('RPC', self.env.cs_host(for_rpc=True),
@ -515,18 +515,18 @@ class SessionManager:
]) ])
def session_count(self): def session_count(self):
'''The number of connections that we've sent something to.''' """The number of connections that we've sent something to."""
return len(self.sessions) return len(self.sessions)
async def daemon_request(self, method, *args): async def daemon_request(self, method, *args):
'''Catch a DaemonError and convert it to an RPCError.''' """Catch a DaemonError and convert it to an RPCError."""
try: try:
return await getattr(self.daemon, method)(*args) return await getattr(self.daemon, method)(*args)
except DaemonError as e: except DaemonError as e:
raise RPCError(DAEMON_ERROR, f'daemon error: {e!r}') from None raise RPCError(DAEMON_ERROR, f'daemon error: {e!r}') from None
async def raw_header(self, height): async def raw_header(self, height):
'''Return the binary header at the given height.''' """Return the binary header at the given height."""
try: try:
return await self.db.raw_header(height) return await self.db.raw_header(height)
except IndexError: except IndexError:
@ -534,7 +534,7 @@ class SessionManager:
'out of range') from None 'out of range') from None
async def electrum_header(self, height): async def electrum_header(self, height):
'''Return the deserialized header at the given height.''' """Return the deserialized header at the given height."""
electrum_header, _ = await self._electrum_and_raw_headers(height) electrum_header, _ = await self._electrum_and_raw_headers(height)
return electrum_header return electrum_header
@ -544,7 +544,7 @@ class SessionManager:
return hex_hash return hex_hash
async def limited_history(self, hashX): async def limited_history(self, hashX):
'''A caching layer.''' """A caching layer."""
hc = self.history_cache hc = self.history_cache
if hashX not in hc: if hashX not in hc:
# History DoS limit. Each element of history is about 99 # History DoS limit. Each element of history is about 99
@ -556,7 +556,7 @@ class SessionManager:
return hc[hashX] return hc[hashX]
async def _notify_sessions(self, height, touched): async def _notify_sessions(self, height, touched):
'''Notify sessions about height changes and touched addresses.''' """Notify sessions about height changes and touched addresses."""
height_changed = height != self.notified_height height_changed = height != self.notified_height
if height_changed: if height_changed:
await self._refresh_hsub_results(height) await self._refresh_hsub_results(height)
@ -579,7 +579,7 @@ class SessionManager:
return self.cur_group return self.cur_group
def remove_session(self, session): def remove_session(self, session):
'''Remove a session from our sessions list if there.''' """Remove a session from our sessions list if there."""
self.sessions.remove(session) self.sessions.remove(session)
self.session_event.set() self.session_event.set()
@ -593,11 +593,11 @@ class SessionManager:
class SessionBase(RPCSession): class SessionBase(RPCSession):
'''Base class of ElectrumX JSON sessions. """Base class of ElectrumX JSON sessions.
Each session runs its tasks in asynchronous parallelism with other Each session runs its tasks in asynchronous parallelism with other
sessions. sessions.
''' """
MAX_CHUNK_SIZE = 2016 MAX_CHUNK_SIZE = 2016
session_counter = itertools.count() session_counter = itertools.count()
@ -627,8 +627,8 @@ class SessionBase(RPCSession):
pass pass
def peer_address_str(self, *, for_log=True): def peer_address_str(self, *, for_log=True):
'''Returns the peer's IP address and port as a human-readable """Returns the peer's IP address and port as a human-readable
string, respecting anon logs if the output is for a log.''' string, respecting anon logs if the output is for a log."""
if for_log and self.anon_logs: if for_log and self.anon_logs:
return 'xx.xx.xx.xx:xx' return 'xx.xx.xx.xx:xx'
return super().peer_address_str() return super().peer_address_str()
@ -642,7 +642,7 @@ class SessionBase(RPCSession):
self.log_me = not self.log_me self.log_me = not self.log_me
def flags(self): def flags(self):
'''Status flags.''' """Status flags."""
status = self.kind[0] status = self.kind[0]
if self.is_closing(): if self.is_closing():
status += 'C' status += 'C'
@ -652,7 +652,7 @@ class SessionBase(RPCSession):
return status return status
def connection_made(self, transport): def connection_made(self, transport):
'''Handle an incoming client connection.''' """Handle an incoming client connection."""
super().connection_made(transport) super().connection_made(transport)
self.session_id = next(self.session_counter) self.session_id = next(self.session_counter)
context = {'conn_id': f'{self.session_id}'} context = {'conn_id': f'{self.session_id}'}
@ -662,7 +662,7 @@ class SessionBase(RPCSession):
f'{self.session_mgr.session_count():,d} total') f'{self.session_mgr.session_count():,d} total')
def connection_lost(self, exc): def connection_lost(self, exc):
'''Handle client disconnection.''' """Handle client disconnection."""
super().connection_lost(exc) super().connection_lost(exc)
self.session_mgr.remove_session(self) self.session_mgr.remove_session(self)
msg = '' msg = ''
@ -687,9 +687,9 @@ class SessionBase(RPCSession):
return 0 return 0
async def handle_request(self, request): async def handle_request(self, request):
'''Handle an incoming request. ElectrumX doesn't receive """Handle an incoming request. ElectrumX doesn't receive
notifications from client sessions. notifications from client sessions.
''' """
if isinstance(request, Request): if isinstance(request, Request):
handler = self.request_handlers.get(request.method) handler = self.request_handlers.get(request.method)
else: else:
@ -699,7 +699,7 @@ class SessionBase(RPCSession):
class ElectrumX(SessionBase): class ElectrumX(SessionBase):
'''A TCP server that handles incoming Electrum connections.''' """A TCP server that handles incoming Electrum connections."""
PROTOCOL_MIN = (1, 1) PROTOCOL_MIN = (1, 1)
PROTOCOL_MAX = (1, 4) PROTOCOL_MAX = (1, 4)
@ -722,7 +722,7 @@ class ElectrumX(SessionBase):
@classmethod @classmethod
def server_features(cls, env): def server_features(cls, env):
'''Return the server features dictionary.''' """Return the server features dictionary."""
min_str, max_str = cls.protocol_min_max_strings() min_str, max_str = cls.protocol_min_max_strings()
return { return {
'hosts': env.hosts_dict(), 'hosts': env.hosts_dict(),
@ -739,7 +739,7 @@ class ElectrumX(SessionBase):
@classmethod @classmethod
def server_version_args(cls): def server_version_args(cls):
'''The arguments to a server.version RPC call to a peer.''' """The arguments to a server.version RPC call to a peer."""
return [torba.__version__, cls.protocol_min_max_strings()] return [torba.__version__, cls.protocol_min_max_strings()]
def protocol_version_string(self): def protocol_version_string(self):
@ -749,9 +749,9 @@ class ElectrumX(SessionBase):
return len(self.hashX_subs) return len(self.hashX_subs)
async def notify(self, touched, height_changed): async def notify(self, touched, height_changed):
'''Notify the client about changes to touched addresses (from mempool """Notify the client about changes to touched addresses (from mempool
updates or new blocks) and height. updates or new blocks) and height.
''' """
if height_changed and self.subscribe_headers: if height_changed and self.subscribe_headers:
args = (await self.subscribe_headers_result(), ) args = (await self.subscribe_headers_result(), )
await self.send_notification('blockchain.headers.subscribe', args) await self.send_notification('blockchain.headers.subscribe', args)
@ -789,40 +789,40 @@ class ElectrumX(SessionBase):
self.logger.info(f'notified of {len(changed):,d} address{es}') self.logger.info(f'notified of {len(changed):,d} address{es}')
async def subscribe_headers_result(self): async def subscribe_headers_result(self):
'''The result of a header subscription or notification.''' """The result of a header subscription or notification."""
return self.session_mgr.hsub_results[self.subscribe_headers_raw] return self.session_mgr.hsub_results[self.subscribe_headers_raw]
async def _headers_subscribe(self, raw): async def _headers_subscribe(self, raw):
'''Subscribe to get headers of new blocks.''' """Subscribe to get headers of new blocks."""
self.subscribe_headers_raw = assert_boolean(raw) self.subscribe_headers_raw = assert_boolean(raw)
self.subscribe_headers = True self.subscribe_headers = True
return await self.subscribe_headers_result() return await self.subscribe_headers_result()
async def headers_subscribe(self): async def headers_subscribe(self):
'''Subscribe to get raw headers of new blocks.''' """Subscribe to get raw headers of new blocks."""
return await self._headers_subscribe(True) return await self._headers_subscribe(True)
async def headers_subscribe_True(self, raw=True): async def headers_subscribe_True(self, raw=True):
'''Subscribe to get headers of new blocks.''' """Subscribe to get headers of new blocks."""
return await self._headers_subscribe(raw) return await self._headers_subscribe(raw)
async def headers_subscribe_False(self, raw=False): async def headers_subscribe_False(self, raw=False):
'''Subscribe to get headers of new blocks.''' """Subscribe to get headers of new blocks."""
return await self._headers_subscribe(raw) return await self._headers_subscribe(raw)
async def add_peer(self, features): async def add_peer(self, features):
'''Add a peer (but only if the peer resolves to the source).''' """Add a peer (but only if the peer resolves to the source)."""
return await self.peer_mgr.on_add_peer(features, self.peer_address()) return await self.peer_mgr.on_add_peer(features, self.peer_address())
async def peers_subscribe(self): async def peers_subscribe(self):
'''Return the server peers as a list of (ip, host, details) tuples.''' """Return the server peers as a list of (ip, host, details) tuples."""
return self.peer_mgr.on_peers_subscribe(self.is_tor()) return self.peer_mgr.on_peers_subscribe(self.is_tor())
async def address_status(self, hashX): async def address_status(self, hashX):
'''Returns an address status. """Returns an address status.
Status is a hex string, but must be None if there is no history. Status is a hex string, but must be None if there is no history.
''' """
# Note history is ordered and mempool unordered in electrum-server # Note history is ordered and mempool unordered in electrum-server
# For mempool, height is -1 if it has unconfirmed inputs, otherwise 0 # For mempool, height is -1 if it has unconfirmed inputs, otherwise 0
db_history = await self.session_mgr.limited_history(hashX) db_history = await self.session_mgr.limited_history(hashX)
@ -847,8 +847,8 @@ class ElectrumX(SessionBase):
return status return status
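The hash itself follows the Electrum protocol's status rule: sha256 over the concatenated history, or None when the history is empty. A self-contained sketch of that rule:

    import hashlib

    def electrum_status(history):
        # history: (tx_hash_hex, height) pairs, confirmed entries first;
        # mempool entries use height -1 or 0 as noted above.
        concatenated = ''.join(f'{tx_hash}:{height}:'
                               for tx_hash, height in history)
        if not concatenated:
            return None
        return hashlib.sha256(concatenated.encode()).hexdigest()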
async def hashX_listunspent(self, hashX): async def hashX_listunspent(self, hashX):
'''Return the list of UTXOs of a script hash, including mempool """Return the list of UTXOs of a script hash, including mempool
effects.''' effects."""
utxos = await self.db.all_utxos(hashX) utxos = await self.db.all_utxos(hashX)
utxos = sorted(utxos) utxos = sorted(utxos)
utxos.extend(await self.mempool.unordered_UTXOs(hashX)) utxos.extend(await self.mempool.unordered_UTXOs(hashX))
@ -879,29 +879,29 @@ class ElectrumX(SessionBase):
raise RPCError(BAD_REQUEST, f'{address} is not a valid address') raise RPCError(BAD_REQUEST, f'{address} is not a valid address')
async def address_get_balance(self, address): async def address_get_balance(self, address):
'''Return the confirmed and unconfirmed balance of an address.''' """Return the confirmed and unconfirmed balance of an address."""
hashX = self.address_to_hashX(address) hashX = self.address_to_hashX(address)
return await self.get_balance(hashX) return await self.get_balance(hashX)
async def address_get_history(self, address): async def address_get_history(self, address):
'''Return the confirmed and unconfirmed history of an address.''' """Return the confirmed and unconfirmed history of an address."""
hashX = self.address_to_hashX(address) hashX = self.address_to_hashX(address)
return await self.confirmed_and_unconfirmed_history(hashX) return await self.confirmed_and_unconfirmed_history(hashX)
async def address_get_mempool(self, address): async def address_get_mempool(self, address):
'''Return the mempool transactions touching an address.''' """Return the mempool transactions touching an address."""
hashX = self.address_to_hashX(address) hashX = self.address_to_hashX(address)
return await self.unconfirmed_history(hashX) return await self.unconfirmed_history(hashX)
async def address_listunspent(self, address): async def address_listunspent(self, address):
'''Return the list of UTXOs of an address.''' """Return the list of UTXOs of an address."""
hashX = self.address_to_hashX(address) hashX = self.address_to_hashX(address)
return await self.hashX_listunspent(hashX) return await self.hashX_listunspent(hashX)
async def address_subscribe(self, address): async def address_subscribe(self, address):
'''Subscribe to an address. """Subscribe to an address.
address: the address to subscribe to''' address: the address to subscribe to"""
hashX = self.address_to_hashX(address) hashX = self.address_to_hashX(address)
return await self.hashX_subscribe(hashX, address) return await self.hashX_subscribe(hashX, address)
@ -912,7 +912,7 @@ class ElectrumX(SessionBase):
return {'confirmed': confirmed, 'unconfirmed': unconfirmed} return {'confirmed': confirmed, 'unconfirmed': unconfirmed}
async def scripthash_get_balance(self, scripthash): async def scripthash_get_balance(self, scripthash):
'''Return the confirmed and unconfirmed balance of a scripthash.''' """Return the confirmed and unconfirmed balance of a scripthash."""
hashX = scripthash_to_hashX(scripthash) hashX = scripthash_to_hashX(scripthash)
return await self.get_balance(hashX) return await self.get_balance(hashX)
@ -932,24 +932,24 @@ class ElectrumX(SessionBase):
return conf + await self.unconfirmed_history(hashX) return conf + await self.unconfirmed_history(hashX)
async def scripthash_get_history(self, scripthash): async def scripthash_get_history(self, scripthash):
'''Return the confirmed and unconfirmed history of a scripthash.''' """Return the confirmed and unconfirmed history of a scripthash."""
hashX = scripthash_to_hashX(scripthash) hashX = scripthash_to_hashX(scripthash)
return await self.confirmed_and_unconfirmed_history(hashX) return await self.confirmed_and_unconfirmed_history(hashX)
async def scripthash_get_mempool(self, scripthash): async def scripthash_get_mempool(self, scripthash):
'''Return the mempool transactions touching a scripthash.''' """Return the mempool transactions touching a scripthash."""
hashX = scripthash_to_hashX(scripthash) hashX = scripthash_to_hashX(scripthash)
return await self.unconfirmed_history(hashX) return await self.unconfirmed_history(hashX)
async def scripthash_listunspent(self, scripthash): async def scripthash_listunspent(self, scripthash):
'''Return the list of UTXOs of a scripthash.''' """Return the list of UTXOs of a scripthash."""
hashX = scripthash_to_hashX(scripthash) hashX = scripthash_to_hashX(scripthash)
return await self.hashX_listunspent(hashX) return await self.hashX_listunspent(hashX)
async def scripthash_subscribe(self, scripthash): async def scripthash_subscribe(self, scripthash):
'''Subscribe to a script hash. """Subscribe to a script hash.
scripthash: the SHA256 hash of the script to subscribe to.''' scripthash: the SHA256 hash of the script to subscribe to."""
hashX = scripthash_to_hashX(scripthash) hashX = scripthash_to_hashX(scripthash)
return await self.hashX_subscribe(hashX, scripthash) return await self.hashX_subscribe(hashX, scripthash)
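For reference, the scripthash these handlers take is, per the Electrum protocol, the SHA-256 digest of the raw output script, hex-encoded in reverse byte order. A self-contained sketch (hypothetical helper):

    import hashlib

    def electrum_scripthash(script: bytes) -> str:
        # SHA-256 of the output script, byte-reversed, then hex-encoded
        return hashlib.sha256(script).digest()[::-1].hex()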
@@ -968,8 +968,8 @@ class ElectrumX(SessionBase):
} }
async def block_header(self, height, cp_height=0): async def block_header(self, height, cp_height=0):
'''Return a raw block header as a hexadecimal string, or as a """Return a raw block header as a hexadecimal string, or as a
dictionary with a merkle proof.''' dictionary with a merkle proof."""
height = non_negative_integer(height) height = non_negative_integer(height)
cp_height = non_negative_integer(cp_height) cp_height = non_negative_integer(cp_height)
raw_header_hex = (await self.session_mgr.raw_header(height)).hex() raw_header_hex = (await self.session_mgr.raw_header(height)).hex()
@@ -980,18 +980,18 @@ class ElectrumX(SessionBase):
return result return result
async def block_header_13(self, height): async def block_header_13(self, height):
'''Return a raw block header as a hexadecimal string. """Return a raw block header as a hexadecimal string.
height: the header's height.''' height: the header's height."""
return await self.block_header(height) return await self.block_header(height)
async def block_headers(self, start_height, count, cp_height=0): async def block_headers(self, start_height, count, cp_height=0):
'''Return count concatenated block headers as hex for the main chain, """Return count concatenated block headers as hex for the main chain,
starting at start_height. starting at start_height.
start_height and count must be non-negative integers. At most start_height and count must be non-negative integers. At most
MAX_CHUNK_SIZE headers will be returned. MAX_CHUNK_SIZE headers will be returned.
''' """
start_height = non_negative_integer(start_height) start_height = non_negative_integer(start_height)
count = non_negative_integer(count) count = non_negative_integer(count)
cp_height = non_negative_integer(cp_height) cp_height = non_negative_integer(cp_height)
@@ -1009,9 +1009,9 @@ class ElectrumX(SessionBase):
return await self.block_headers(start_height, count) return await self.block_headers(start_height, count)
async def block_get_chunk(self, index): async def block_get_chunk(self, index):
'''Return a chunk of block headers as a hexadecimal string. """Return a chunk of block headers as a hexadecimal string.
index: the chunk index.''' index: the chunk index."""
index = non_negative_integer(index) index = non_negative_integer(index)
size = self.coin.CHUNK_SIZE size = self.coin.CHUNK_SIZE
start_height = index * size start_height = index * size
@@ -1019,15 +1019,15 @@ class ElectrumX(SessionBase):
return headers.hex() return headers.hex()
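The chunk arithmetic is easiest to see with concrete numbers; a sketch assuming coin.CHUNK_SIZE == 2016 (the difficulty-retarget interval bitcoin-family coins typically use):

    # chunk 0 covers heights 0..2015, chunk 3 covers heights 6048..8063
    index, size = 3, 2016
    start_height = index * size    # 6048
    count = size                   # one full chunk of headers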
async def block_get_header(self, height): async def block_get_header(self, height):
'''The deserialized header at a given height. """The deserialized header at a given height.
height: the header's height.''' height: the header's height."""
height = non_negative_integer(height) height = non_negative_integer(height)
return await self.session_mgr.electrum_header(height) return await self.session_mgr.electrum_header(height)
def is_tor(self): def is_tor(self):
'''Try to detect if the connection is to a Tor hidden service we are """Try to detect if the connection is to a Tor hidden service we are
running.''' running."""
peername = self.peer_mgr.proxy_peername() peername = self.peer_mgr.proxy_peername()
if not peername: if not peername:
return False return False
@@ -1051,11 +1051,11 @@ class ElectrumX(SessionBase):
return banner return banner
async def donation_address(self): async def donation_address(self):
'''Return the donation address as a string, empty if there is none.''' """Return the donation address as a string, empty if there is none."""
return self.env.donation_address return self.env.donation_address
async def banner(self): async def banner(self):
'''Return the server banner text.''' """Return the server banner text."""
banner = f'You are connected to a torba {torba.__version__} server.' banner = f'You are connected to a torba {torba.__version__} server.'
if self.is_tor(): if self.is_tor():
@@ -1074,31 +1074,31 @@ class ElectrumX(SessionBase):
return banner return banner
async def relayfee(self): async def relayfee(self):
'''The minimum fee a low-priority tx must pay in order to be accepted """The minimum fee a low-priority tx must pay in order to be accepted
to the daemon's memory pool.''' to the daemon's memory pool."""
return await self.daemon_request('relayfee') return await self.daemon_request('relayfee')
async def estimatefee(self, number): async def estimatefee(self, number):
'''The estimated transaction fee per kilobyte to be paid for a """The estimated transaction fee per kilobyte to be paid for a
transaction to be included within a certain number of blocks. transaction to be included within a certain number of blocks.
number: the number of blocks number: the number of blocks
''' """
number = non_negative_integer(number) number = non_negative_integer(number)
return await self.daemon_request('estimatefee', number) return await self.daemon_request('estimatefee', number)
async def ping(self): async def ping(self):
'''Serves as a connection keep-alive mechanism and lets the client """Serves as a connection keep-alive mechanism and lets the client
confirm the server is still responding. confirm the server is still responding.
''' """
return None return None
async def server_version(self, client_name='', protocol_version=None): async def server_version(self, client_name='', protocol_version=None):
'''Returns the server version as a string. """Returns the server version as a string.
client_name: a string identifying the client client_name: a string identifying the client
protocol_version: the protocol version spoken by the client protocol_version: the protocol version spoken by the client
''' """
if self.sv_seen and self.protocol_tuple >= (1, 4): if self.sv_seen and self.protocol_tuple >= (1, 4):
raise RPCError(BAD_REQUEST, f'server.version already sent') raise RPCError(BAD_REQUEST, f'server.version already sent')
self.sv_seen = True self.sv_seen = True
@@ -1129,9 +1129,9 @@ class ElectrumX(SessionBase):
return torba.__version__, self.protocol_version_string() return torba.__version__, self.protocol_version_string()
async def transaction_broadcast(self, raw_tx): async def transaction_broadcast(self, raw_tx):
'''Broadcast a raw transaction to the network. """Broadcast a raw transaction to the network.
raw_tx: the raw transaction as a hexadecimal string.''' raw_tx: the raw transaction as a hexadecimal string."""
# This returns errors as JSON RPC errors, as is natural # This returns errors as JSON RPC errors, as is natural
try: try:
hex_hash = await self.session_mgr.broadcast_transaction(raw_tx) hex_hash = await self.session_mgr.broadcast_transaction(raw_tx)
@@ -1146,11 +1146,11 @@ class ElectrumX(SessionBase):
f'network rules.\n\n{message}\n[{raw_tx}]') f'network rules.\n\n{message}\n[{raw_tx}]')
async def transaction_get(self, tx_hash, verbose=False): async def transaction_get(self, tx_hash, verbose=False):
'''Return the serialized raw transaction given its hash. """Return the serialized raw transaction given its hash.
tx_hash: the transaction hash as a hexadecimal string tx_hash: the transaction hash as a hexadecimal string
verbose: passed on to the daemon verbose: passed on to the daemon
''' """
assert_tx_hash(tx_hash) assert_tx_hash(tx_hash)
if verbose not in (True, False): if verbose not in (True, False):
raise RPCError(BAD_REQUEST, f'"verbose" must be a boolean') raise RPCError(BAD_REQUEST, f'"verbose" must be a boolean')
@@ -1158,12 +1158,12 @@ class ElectrumX(SessionBase):
return await self.daemon_request('getrawtransaction', tx_hash, verbose) return await self.daemon_request('getrawtransaction', tx_hash, verbose)
async def _block_hash_and_tx_hashes(self, height): async def _block_hash_and_tx_hashes(self, height):
'''Returns a pair (block_hash, tx_hashes) for the main chain block at """Returns a pair (block_hash, tx_hashes) for the main chain block at
the given height. the given height.
block_hash is a hexadecimal string, and tx_hashes is an block_hash is a hexadecimal string, and tx_hashes is an
ordered list of hexadecimal strings. ordered list of hexadecimal strings.
''' """
height = non_negative_integer(height) height = non_negative_integer(height)
hex_hashes = await self.daemon_request('block_hex_hashes', height, 1) hex_hashes = await self.daemon_request('block_hex_hashes', height, 1)
block_hash = hex_hashes[0] block_hash = hex_hashes[0]
@@ -1171,23 +1171,23 @@ class ElectrumX(SessionBase):
return block_hash, block['tx'] return block_hash, block['tx']
def _get_merkle_branch(self, tx_hashes, tx_pos): def _get_merkle_branch(self, tx_hashes, tx_pos):
'''Return a merkle branch to a transaction. """Return a merkle branch to a transaction.
tx_hashes: ordered list of hex strings of tx hashes in a block tx_hashes: ordered list of hex strings of tx hashes in a block
tx_pos: index of transaction in tx_hashes to create branch for tx_pos: index of transaction in tx_hashes to create branch for
''' """
hashes = [hex_str_to_hash(hash) for hash in tx_hashes] hashes = [hex_str_to_hash(hash) for hash in tx_hashes]
branch, root = self.db.merkle.branch_and_root(hashes, tx_pos) branch, root = self.db.merkle.branch_and_root(hashes, tx_pos)
branch = [hash_to_hex_str(hash) for hash in branch] branch = [hash_to_hex_str(hash) for hash in branch]
return branch return branch
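branch_and_root pairs double-SHA256 hashes level by level, duplicating the last hash of an odd-length level; a standalone sketch of the algorithm (not the Merkle implementation this module actually calls):

    import hashlib

    def double_sha256(b):
        return hashlib.sha256(hashlib.sha256(b).digest()).digest()

    def branch_and_root(hashes, index):
        # hashes: the block's tx hashes as bytes; index: target tx position
        hashes = list(hashes)
        branch = []
        while len(hashes) > 1:
            if len(hashes) % 2:
                hashes.append(hashes[-1])       # duplicate the odd one out
            branch.append(hashes[index ^ 1])    # sibling at this level
            index >>= 1
            hashes = [double_sha256(hashes[i] + hashes[i + 1])
                      for i in range(0, len(hashes), 2)]
        return branch, hashes[0]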
async def transaction_merkle(self, tx_hash, height): async def transaction_merkle(self, tx_hash, height):
'''Return the merkle branch to a confirmed transaction given its hash """Return the merkle branch to a confirmed transaction given its hash
and height. and height.
tx_hash: the transaction hash as a hexadecimal string tx_hash: the transaction hash as a hexadecimal string
height: the height of the block it is in height: the height of the block it is in
''' """
assert_tx_hash(tx_hash) assert_tx_hash(tx_hash)
block_hash, tx_hashes = await self._block_hash_and_tx_hashes(height) block_hash, tx_hashes = await self._block_hash_and_tx_hashes(height)
try: try:
@@ -1199,9 +1199,9 @@ class ElectrumX(SessionBase):
return {"block_height": height, "merkle": branch, "pos": pos} return {"block_height": height, "merkle": branch, "pos": pos}
async def transaction_id_from_pos(self, height, tx_pos, merkle=False): async def transaction_id_from_pos(self, height, tx_pos, merkle=False):
'''Return the txid and optionally a merkle proof, given """Return the txid and optionally a merkle proof, given
a block height and position in the block. a block height and position in the block.
''' """
tx_pos = non_negative_integer(tx_pos) tx_pos = non_negative_integer(tx_pos)
if merkle not in (True, False): if merkle not in (True, False):
raise RPCError(BAD_REQUEST, f'"merkle" must be a boolean') raise RPCError(BAD_REQUEST, f'"merkle" must be a boolean')
@@ -1279,7 +1279,7 @@ class ElectrumX(SessionBase):
class LocalRPC(SessionBase): class LocalRPC(SessionBase):
'''A local TCP RPC server session.''' """A local TCP RPC server session."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@@ -1291,7 +1291,7 @@ class LocalRPC(SessionBase):
class DashElectrumX(ElectrumX): class DashElectrumX(ElectrumX):
'''A TCP server that handles incoming Electrum Dash connections.''' """A TCP server that handles incoming Electrum Dash connections."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@@ -1307,7 +1307,7 @@ class DashElectrumX(ElectrumX):
}) })
async def notify(self, touched, height_changed): async def notify(self, touched, height_changed):
'''Notify the client about changes in the masternode list.''' """Notify the client about changes in the masternode list."""
await super().notify(touched, height_changed) await super().notify(touched, height_changed)
for mn in self.mns: for mn in self.mns:
status = await self.daemon_request('masternode_list', status = await self.daemon_request('masternode_list',
@@ -1317,10 +1317,10 @@ class DashElectrumX(ElectrumX):
# Masternode command handlers # Masternode command handlers
async def masternode_announce_broadcast(self, signmnb): async def masternode_announce_broadcast(self, signmnb):
'''Pass through the masternode announce message to be broadcast """Pass through the masternode announce message to be broadcast
by the daemon. by the daemon.
signmnb: signed masternode broadcast message.''' signmnb: signed masternode broadcast message."""
try: try:
return await self.daemon_request('masternode_broadcast', return await self.daemon_request('masternode_broadcast',
['relay', signmnb]) ['relay', signmnb])
@@ -1332,10 +1332,10 @@ class DashElectrumX(ElectrumX):
f'rejected.\n\n{message}\n[{signmnb}]') f'rejected.\n\n{message}\n[{signmnb}]')
async def masternode_subscribe(self, collateral): async def masternode_subscribe(self, collateral):
'''Returns the status of a masternode. """Returns the status of a masternode.
collateral: masternode collateral. collateral: masternode collateral.
''' """
result = await self.daemon_request('masternode_list', result = await self.daemon_request('masternode_list',
['status', collateral]) ['status', collateral])
if result is not None: if result is not None:
@@ -1344,20 +1344,20 @@ class DashElectrumX(ElectrumX):
return None return None
async def masternode_list(self, payees): async def masternode_list(self, payees):
''' """
Returns the list of masternodes. Returns the list of masternodes.
payees: a list of masternode payee addresses. payees: a list of masternode payee addresses.
''' """
if not isinstance(payees, list): if not isinstance(payees, list):
raise RPCError(BAD_REQUEST, 'expected a list of payees') raise RPCError(BAD_REQUEST, 'expected a list of payees')
def get_masternode_payment_queue(mns): def get_masternode_payment_queue(mns):
'''Returns the calculated position in the payment queue for all the """Returns the calculated position in the payment queue for all the
valid masternodes in the given mns list. valid masternodes in the given mns list.
mns: a list of masternodes information. mns: a list of masternodes information.
''' """
now = int(datetime.datetime.utcnow().strftime("%s")) now = int(datetime.datetime.utcnow().strftime("%s"))
mn_queue = [] mn_queue = []
@@ -1383,12 +1383,12 @@ class DashElectrumX(ElectrumX):
return mn_queue return mn_queue
def get_payment_position(payment_queue, address): def get_payment_position(payment_queue, address):
''' """
Returns the position in the payment queue for the given address. Returns the position in the payment queue for the given address.
payment_queue: the masternode payment queue. payment_queue: the masternode payment queue.
address: masternode payee address. address: masternode payee address.
''' """
position = -1 position = -1
for pos, mn in enumerate(payment_queue, start=1): for pos, mn in enumerate(payment_queue, start=1):
if mn[2] == address: if mn[2] == address:

View file

@@ -5,7 +5,7 @@
# See the file "LICENCE" for information about the copyright # See the file "LICENCE" for information about the copyright
# and warranty status of this software. # and warranty status of this software.
'''Backend database abstraction.''' """Backend database abstraction."""
import os import os
from functools import partial from functools import partial
@@ -14,7 +14,7 @@ from torba.server import util
def db_class(name): def db_class(name):
'''Returns a DB engine class.''' """Returns a DB engine class."""
for db_class in util.subclasses(Storage): for db_class in util.subclasses(Storage):
if db_class.__name__.lower() == name.lower(): if db_class.__name__.lower() == name.lower():
db_class.import_module() db_class.import_module()
@@ -23,7 +23,7 @@ def db_class(name):
class Storage: class Storage:
'''Abstract base class of the DB backend abstraction.''' """Abstract base class of the DB backend abstraction."""
def __init__(self, name, for_sync): def __init__(self, name, for_sync):
self.is_new = not os.path.exists(name) self.is_new = not os.path.exists(name)
@@ -32,15 +32,15 @@ class Storage:
@classmethod @classmethod
def import_module(cls): def import_module(cls):
'''Import the DB engine module.''' """Import the DB engine module."""
raise NotImplementedError raise NotImplementedError
def open(self, name, create): def open(self, name, create):
'''Open an existing database or create a new one.''' """Open an existing database or create a new one."""
raise NotImplementedError raise NotImplementedError
def close(self): def close(self):
'''Close an existing database.''' """Close an existing database."""
raise NotImplementedError raise NotImplementedError
def get(self, key): def get(self, key):
@@ -50,26 +50,26 @@ class Storage:
raise NotImplementedError raise NotImplementedError
def write_batch(self): def write_batch(self):
'''Return a context manager that provides `put` and `delete`. """Return a context manager that provides `put` and `delete`.
Changes should only be committed when the context manager Changes should only be committed when the context manager
closes without an exception. closes without an exception.
''' """
raise NotImplementedError raise NotImplementedError
def iterator(self, prefix=b'', reverse=False): def iterator(self, prefix=b'', reverse=False):
'''Return an iterator that yields (key, value) pairs from the """Return an iterator that yields (key, value) pairs from the
database sorted by key. database sorted by key.
If `prefix` is set, only keys starting with `prefix` will be If `prefix` is set, only keys starting with `prefix` will be
included. If `reverse` is True the items are returned in included. If `reverse` is True the items are returned in
reverse order. reverse order.
''' """
raise NotImplementedError raise NotImplementedError
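The write_batch and iterator contracts are easiest to see in a toy engine; a minimal in-memory sketch (illustrative only, not one of the shipped backends):

    import contextlib

    class MemoryStorage:
        def __init__(self):
            self.data = {}

        def get(self, key):
            return self.data.get(key)

        @contextlib.contextmanager
        def write_batch(self):
            staged, deleted = {}, set()

            class Batch:
                def put(self, key, value):
                    staged[key] = value
                    deleted.discard(key)
                def delete(self, key):
                    deleted.add(key)
                    staged.pop(key, None)

            yield Batch()
            # Reached only if the with-block raised nothing: commit.
            self.data.update(staged)
            for key in deleted:
                self.data.pop(key, None)

        def iterator(self, prefix=b'', reverse=False):
            keys = sorted(k for k in self.data if k.startswith(prefix))
            return ((k, self.data[k])
                    for k in (reversed(keys) if reverse else keys))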
class LevelDB(Storage): class LevelDB(Storage):
'''LevelDB database engine.''' """LevelDB database engine."""
@classmethod @classmethod
def import_module(cls): def import_module(cls):
@@ -90,7 +90,7 @@ class LevelDB(Storage):
class RocksDB(Storage): class RocksDB(Storage):
'''RocksDB database engine.''' """RocksDB database engine."""
@classmethod @classmethod
def import_module(cls): def import_module(cls):
@@ -122,7 +122,7 @@ class RocksDB(Storage):
class RocksDBWriteBatch: class RocksDBWriteBatch:
'''A write batch for RocksDB.''' """A write batch for RocksDB."""
def __init__(self, db): def __init__(self, db):
self.batch = RocksDB.module.WriteBatch() self.batch = RocksDB.module.WriteBatch()
@@ -137,7 +137,7 @@ class RocksDBWriteBatch:
class RocksDBIterator: class RocksDBIterator:
'''An iterator for RocksDB.''' """An iterator for RocksDB."""
def __init__(self, db, prefix, reverse): def __init__(self, db, prefix, reverse):
self.prefix = prefix self.prefix = prefix

View file

@@ -4,9 +4,9 @@ from torba.server import util
def sessions_lines(data): def sessions_lines(data):
'''A generator returning lines for a list of sessions. """A generator returning lines for a list of sessions.
data is the return value of rpc_sessions().''' data is the return value of rpc_sessions()."""
fmt = ('{:<6} {:<5} {:>17} {:>5} {:>5} {:>5} ' fmt = ('{:<6} {:<5} {:>17} {:>5} {:>5} {:>5} '
'{:>7} {:>7} {:>7} {:>7} {:>7} {:>9} {:>21}') '{:>7} {:>7} {:>7} {:>7} {:>7} {:>9} {:>21}')
yield fmt.format('ID', 'Flags', 'Client', 'Proto', yield fmt.format('ID', 'Flags', 'Client', 'Proto',
@@ -26,9 +26,9 @@ def sessions_lines(data):
def groups_lines(data): def groups_lines(data):
'''A generator returning lines for a list of groups. """A generator returning lines for a list of groups.
data is the return value of rpc_groups().''' data is the return value of rpc_groups()."""
fmt = ('{:<6} {:>9} {:>9} {:>6} {:>6} {:>8}' fmt = ('{:<6} {:>9} {:>9} {:>6} {:>6} {:>8}'
'{:>7} {:>9} {:>7} {:>9}') '{:>7} {:>9} {:>7} {:>9}')
@@ -49,9 +49,9 @@ def groups_lines(data):
def peers_lines(data): def peers_lines(data):
'''A generator returning lines for a list of peers. """A generator returning lines for a list of peers.
data is the return value of rpc_peers().''' data is the return value of rpc_peers()."""
def time_fmt(t): def time_fmt(t):
if not t: if not t:
return 'Never' return 'Never'

View file

@@ -25,7 +25,7 @@
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# and warranty status of this software. # and warranty status of this software.
'''Transaction-related classes and functions.''' """Transaction-related classes and functions."""
from collections import namedtuple from collections import namedtuple
@@ -42,7 +42,7 @@ MINUS_1 = 4294967295
class Tx(namedtuple("Tx", "version inputs outputs locktime")): class Tx(namedtuple("Tx", "version inputs outputs locktime")):
'''Class representing a transaction.''' """Class representing a transaction."""
def serialize(self): def serialize(self):
return b''.join(( return b''.join((
@@ -56,7 +56,7 @@ class Tx(namedtuple("Tx", "version inputs outputs locktime")):
class TxInput(namedtuple("TxInput", "prev_hash prev_idx script sequence")): class TxInput(namedtuple("TxInput", "prev_hash prev_idx script sequence")):
'''Class representing a transaction input.''' """Class representing a transaction input."""
def __str__(self): def __str__(self):
script = self.script.hex() script = self.script.hex()
prev_hash = hash_to_hex_str(self.prev_hash) prev_hash = hash_to_hex_str(self.prev_hash)
@@ -64,7 +64,7 @@ class TxInput(namedtuple("TxInput", "prev_hash prev_idx script sequence")):
.format(prev_hash, self.prev_idx, script, self.sequence)) .format(prev_hash, self.prev_idx, script, self.sequence))
def is_generation(self): def is_generation(self):
'''Test if an input is generation/coinbase-like.''' """Test if an input is generation/coinbase-like."""
return self.prev_idx == MINUS_1 and self.prev_hash == ZERO return self.prev_idx == MINUS_1 and self.prev_hash == ZERO
def serialize(self): def serialize(self):
@@ -86,14 +86,14 @@ class TxOutput(namedtuple("TxOutput", "value pk_script")):
class Deserializer: class Deserializer:
'''Deserializes blocks into transactions. """Deserializes blocks into transactions.
External entry points are read_tx(), read_tx_and_hash(), External entry points are read_tx(), read_tx_and_hash(),
read_tx_and_vsize() and read_tx_block(). read_tx_and_vsize() and read_tx_block().
This code is performance-sensitive as it is executed hundreds of This code is performance-sensitive as it is executed hundreds of
millions of times during sync. millions of times during sync.
''' """
TX_HASH_FN = staticmethod(double_sha256) TX_HASH_FN = staticmethod(double_sha256)
@@ -104,7 +104,7 @@ class Deserializer:
self.cursor = start self.cursor = start
def read_tx(self): def read_tx(self):
'''Return a deserialized transaction.''' """Return a deserialized transaction."""
return Tx( return Tx(
self._read_le_int32(), # version self._read_le_int32(), # version
self._read_inputs(), # inputs self._read_inputs(), # inputs
@@ -113,20 +113,20 @@ class Deserializer:
) )
def read_tx_and_hash(self): def read_tx_and_hash(self):
'''Return a (deserialized TX, tx_hash) pair. """Return a (deserialized TX, tx_hash) pair.
The hash needs to be reversed for human display; for efficiency The hash needs to be reversed for human display; for efficiency
we process it in the natural serialized order. we process it in the natural serialized order.
''' """
start = self.cursor start = self.cursor
return self.read_tx(), self.TX_HASH_FN(self.binary[start:self.cursor]) return self.read_tx(), self.TX_HASH_FN(self.binary[start:self.cursor])
def read_tx_and_vsize(self): def read_tx_and_vsize(self):
'''Return a (deserialized TX, vsize) pair.''' """Return a (deserialized TX, vsize) pair."""
return self.read_tx(), self.binary_length return self.read_tx(), self.binary_length
def read_tx_block(self): def read_tx_block(self):
'''Returns a list of (deserialized_tx, tx_hash) pairs.''' """Returns a list of (deserialized_tx, tx_hash) pairs."""
read = self.read_tx_and_hash read = self.read_tx_and_hash
# Some coins have excess data beyond the end of the transactions # Some coins have excess data beyond the end of the transactions
return [read() for _ in range(self._read_varint())] return [read() for _ in range(self._read_varint())]
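The legacy wire format these readers walk is: int32 version, a varint input count and the inputs, a varint output count and the outputs, then a uint32 locktime, with varints using the 0xfd/0xfe/0xff size prefixes. A compact standalone reader under those assumptions:

    import struct

    def read_varint(buf, pos):
        n = buf[pos]; pos += 1
        if n == 0xfd:
            return struct.unpack_from('<H', buf, pos)[0], pos + 2
        if n == 0xfe:
            return struct.unpack_from('<I', buf, pos)[0], pos + 4
        if n == 0xff:
            return struct.unpack_from('<Q', buf, pos)[0], pos + 8
        return n, pos

    def read_varbytes(buf, pos):
        n, pos = read_varint(buf, pos)
        return buf[pos:pos + n], pos + n

    def read_legacy_tx(buf, pos=0):
        version, = struct.unpack_from('<i', buf, pos); pos += 4
        n_in, pos = read_varint(buf, pos)
        inputs = []
        for _ in range(n_in):
            prev_hash = buf[pos:pos + 32]
            prev_idx, = struct.unpack_from('<I', buf, pos + 32); pos += 36
            script, pos = read_varbytes(buf, pos)
            sequence, = struct.unpack_from('<I', buf, pos); pos += 4
            inputs.append((prev_hash, prev_idx, script, sequence))
        n_out, pos = read_varint(buf, pos)
        outputs = []
        for _ in range(n_out):
            value, = struct.unpack_from('<q', buf, pos); pos += 8
            pk_script, pos = read_varbytes(buf, pos)
            outputs.append((value, pk_script))
        locktime, = struct.unpack_from('<I', buf, pos); pos += 4
        return (version, inputs, outputs, locktime), pos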
@@ -206,7 +206,7 @@ class Deserializer:
class TxSegWit(namedtuple("Tx", "version marker flag inputs outputs " class TxSegWit(namedtuple("Tx", "version marker flag inputs outputs "
"witness locktime")): "witness locktime")):
'''Class representing a SegWit transaction.''' """Class representing a SegWit transaction."""
class DeserializerSegWit(Deserializer): class DeserializerSegWit(Deserializer):
@@ -222,7 +222,7 @@ class DeserializerSegWit(Deserializer):
return [read_varbytes() for i in range(self._read_varint())] return [read_varbytes() for i in range(self._read_varint())]
def _read_tx_parts(self): def _read_tx_parts(self):
'''Return a (deserialized TX, tx_hash, vsize) tuple.''' """Return a (deserialized TX, tx_hash, vsize) tuple."""
start = self.cursor start = self.cursor
marker = self.binary[self.cursor + 4] marker = self.binary[self.cursor + 4]
if marker: if marker:
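The marker byte tested here is the BIP-144 SegWit signal; raw_tx below is a hypothetical full serialization:

    # A SegWit tx has a 0x00 marker (then a 0x01 flag) right after the
    # 4-byte version; a legacy tx has its nonzero input-count varint there.
    is_segwit = raw_tx[4] == 0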
@@ -269,7 +269,7 @@ class DeserializerAuxPow(Deserializer):
VERSION_AUXPOW = (1 << 8) VERSION_AUXPOW = (1 << 8)
def read_header(self, height, static_header_size): def read_header(self, height, static_header_size):
'''Return the AuxPow block header bytes.''' """Return the AuxPow block header bytes."""
start = self.cursor start = self.cursor
version = self._read_le_uint32() version = self._read_le_uint32()
if version & self.VERSION_AUXPOW: if version & self.VERSION_AUXPOW:
@@ -298,7 +298,7 @@ class DeserializerAuxPowSegWit(DeserializerSegWit, DeserializerAuxPow):
class DeserializerEquihash(Deserializer): class DeserializerEquihash(Deserializer):
def read_header(self, height, static_header_size): def read_header(self, height, static_header_size):
'''Return the block header bytes.''' """Return the block header bytes."""
start = self.cursor start = self.cursor
# We are going to calculate the block size then read it as bytes # We are going to calculate the block size then read it as bytes
self.cursor += static_header_size self.cursor += static_header_size
@@ -314,7 +314,7 @@ class DeserializerEquihashSegWit(DeserializerSegWit, DeserializerEquihash):
class TxJoinSplit(namedtuple("Tx", "version inputs outputs locktime")): class TxJoinSplit(namedtuple("Tx", "version inputs outputs locktime")):
'''Class representing a JoinSplit transaction.''' """Class representing a JoinSplit transaction."""
class DeserializerZcash(DeserializerEquihash): class DeserializerZcash(DeserializerEquihash):
@@ -365,7 +365,7 @@ class DeserializerZcash(DeserializerEquihash):
class TxTime(namedtuple("Tx", "version time inputs outputs locktime")): class TxTime(namedtuple("Tx", "version time inputs outputs locktime")):
'''Class representing transaction that has a time field.''' """Class representing transaction that has a time field."""
class DeserializerTxTime(Deserializer): class DeserializerTxTime(Deserializer):
@@ -406,7 +406,7 @@ class DeserializerTxTimeAuxPow(DeserializerTxTime):
return False return False
def read_header(self, height, static_header_size): def read_header(self, height, static_header_size):
'''Return the AuxPow block header bytes.''' """Return the AuxPow block header bytes."""
start = self.cursor start = self.cursor
version = self._read_le_uint32() version = self._read_le_uint32()
if version & self.VERSION_AUXPOW: if version & self.VERSION_AUXPOW:
@@ -433,7 +433,7 @@ class DeserializerBitcoinAtom(DeserializerSegWit):
FORK_BLOCK_HEIGHT = 505888 FORK_BLOCK_HEIGHT = 505888
def read_header(self, height, static_header_size): def read_header(self, height, static_header_size):
'''Return the block header bytes.''' """Return the block header bytes."""
header_len = static_header_size header_len = static_header_size
if height >= self.FORK_BLOCK_HEIGHT: if height >= self.FORK_BLOCK_HEIGHT:
header_len += 4 # flags header_len += 4 # flags
@@ -445,7 +445,7 @@ class DeserializerGroestlcoin(DeserializerSegWit):
class TxInputTokenPay(TxInput): class TxInputTokenPay(TxInput):
'''Class representing a TokenPay transaction input.''' """Class representing a TokenPay transaction input."""
OP_ANON_MARKER = 0xb9 OP_ANON_MARKER = 0xb9
# 2byte marker (cpubkey + sigc + sigr) # 2byte marker (cpubkey + sigc + sigr)
@@ -468,7 +468,7 @@ class TxInputTokenPay(TxInput):
class TxInputTokenPayStealth( class TxInputTokenPayStealth(
namedtuple("TxInput", "keyimage ringsize script sequence")): namedtuple("TxInput", "keyimage ringsize script sequence")):
'''Class representing a TokenPay stealth transaction input.''' """Class representing a TokenPay stealth transaction input."""
def __str__(self): def __str__(self):
script = self.script.hex() script = self.script.hex()
@@ -514,7 +514,7 @@ class DeserializerTokenPay(DeserializerTxTime):
# Decred # Decred
class TxInputDcr(namedtuple("TxInput", "prev_hash prev_idx tree sequence")): class TxInputDcr(namedtuple("TxInput", "prev_hash prev_idx tree sequence")):
'''Class representing a Decred transaction input.''' """Class representing a Decred transaction input."""
def __str__(self): def __str__(self):
prev_hash = hash_to_hex_str(self.prev_hash) prev_hash = hash_to_hex_str(self.prev_hash)
@@ -522,18 +522,18 @@ class TxInputDcr(namedtuple("TxInput", "prev_hash prev_idx tree sequence")):
.format(prev_hash, self.prev_idx, self.tree, self.sequence)) .format(prev_hash, self.prev_idx, self.tree, self.sequence))
def is_generation(self): def is_generation(self):
'''Test if an input is generation/coinbase-like.''' """Test if an input is generation/coinbase-like."""
return self.prev_idx == MINUS_1 and self.prev_hash == ZERO return self.prev_idx == MINUS_1 and self.prev_hash == ZERO
class TxOutputDcr(namedtuple("TxOutput", "value version pk_script")): class TxOutputDcr(namedtuple("TxOutput", "value version pk_script")):
'''Class representing a Decred transaction output.''' """Class representing a Decred transaction output."""
pass pass
class TxDcr(namedtuple("Tx", "version inputs outputs locktime expiry " class TxDcr(namedtuple("Tx", "version inputs outputs locktime expiry "
"witness")): "witness")):
'''Class representing a Decred transaction.''' """Class representing a Decred transaction."""
class DeserializerDecred(Deserializer): class DeserializerDecred(Deserializer):
@@ -559,14 +559,14 @@ class DeserializerDecred(Deserializer):
return tx, vsize return tx, vsize
def read_tx_block(self): def read_tx_block(self):
'''Returns a list of (deserialized_tx, tx_hash) pairs.''' """Returns a list of (deserialized_tx, tx_hash) pairs."""
read = self.read_tx_and_hash read = self.read_tx_and_hash
txs = [read() for _ in range(self._read_varint())] txs = [read() for _ in range(self._read_varint())]
stxs = [read() for _ in range(self._read_varint())] stxs = [read() for _ in range(self._read_varint())]
return txs + stxs return txs + stxs
def read_tx_tree(self): def read_tx_tree(self):
'''Returns a list of deserialized_tx without tx hashes.''' """Returns a list of deserialized_tx without tx hashes."""
read_tx = self.read_tx read_tx = self.read_tx
return [read_tx() for _ in range(self._read_varint())] return [read_tx() for _ in range(self._read_varint())]

View file

@@ -24,7 +24,7 @@
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# and warranty status of this software. # and warranty status of this software.
'''Miscellaneous utility classes and functions.''' """Miscellaneous utility classes and functions."""
import array import array
@@ -40,21 +40,21 @@ from struct import pack, Struct
class ConnectionLogger(logging.LoggerAdapter): class ConnectionLogger(logging.LoggerAdapter):
'''Prepends a connection identifier to a logging message.''' """Prepends a connection identifier to a logging message."""
def process(self, msg, kwargs): def process(self, msg, kwargs):
conn_id = self.extra.get('conn_id', 'unknown') conn_id = self.extra.get('conn_id', 'unknown')
return f'[{conn_id}] {msg}', kwargs return f'[{conn_id}] {msg}', kwargs
class CompactFormatter(logging.Formatter): class CompactFormatter(logging.Formatter):
'''Strips the module from the logger name to leave the class only.''' """Strips the module from the logger name to leave the class only."""
def format(self, record): def format(self, record):
record.name = record.name.rpartition('.')[-1] record.name = record.name.rpartition('.')[-1]
return super().format(record) return super().format(record)
def make_logger(name, *, handler, level): def make_logger(name, *, handler, level):
'''Return the root ElectrumX logger.''' """Return the root ElectrumX logger."""
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.addHandler(handler) logger.addHandler(handler)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
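A hypothetical wiring of these logging helpers, showing how the pieces compose:

    import logging
    import sys

    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(CompactFormatter('%(asctime)s %(name)s %(message)s'))
    logger = make_logger('server', handler=handler, level='INFO')
    session_log = ConnectionLogger(class_logger('server.session', 'ElectrumX'),
                                   {'conn_id': '42'})
    session_log.info('hello')   # logged as "[42] hello" under the name "ElectrumX"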
@@ -63,7 +63,7 @@ def make_logger(name, *, handler, level):
def class_logger(path, classname): def class_logger(path, classname):
'''Return a hierarchical logger for a class.''' """Return a hierarchical logger for a class."""
return logging.getLogger(path).getChild(classname) return logging.getLogger(path).getChild(classname)
@@ -83,8 +83,8 @@ class cachedproperty:
def formatted_time(t, sep=' '): def formatted_time(t, sep=' '):
'''Return a number of seconds as a string in days, hours, mins and """Return a number of seconds as a string in days, hours, mins and
maybe secs.''' maybe secs."""
t = int(t) t = int(t)
fmts = (('{:d}d', 86400), ('{:02d}h', 3600), ('{:02d}m', 60)) fmts = (('{:d}d', 86400), ('{:02d}h', 3600), ('{:02d}m', 60))
parts = [] parts = []
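Assuming the helper completes as in upstream ElectrumX (seconds are appended while fewer than three larger units are present), example outputs:

    print(formatted_time(93784))   # -> 1d 02h 03m
    print(formatted_time(3725))    # -> 01h 02m 05s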
@@ -136,7 +136,7 @@ def deep_getsizeof(obj):
def subclasses(base_class, strict=True): def subclasses(base_class, strict=True):
'''Return a list of subclasses of base_class in its module.''' """Return a list of subclasses of base_class in its module."""
def select(obj): def select(obj):
return (inspect.isclass(obj) and issubclass(obj, base_class) and return (inspect.isclass(obj) and issubclass(obj, base_class) and
(not strict or obj != base_class)) (not strict or obj != base_class))
@@ -146,7 +146,7 @@ def subclasses(base_class, strict=True):
def chunks(items, size): def chunks(items, size):
'''Break up items, an iterable, into chunks of length size.''' """Break up items, an iterable, into chunks of length size."""
for i in range(0, len(items), size): for i in range(0, len(items), size):
yield items[i: i + size] yield items[i: i + size]
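A one-line usage example:

    assert list(chunks([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]]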
@@ -159,19 +159,19 @@ def resolve_limit(limit):
def bytes_to_int(be_bytes): def bytes_to_int(be_bytes):
'''Interprets a big-endian sequence of bytes as an integer.''' """Interprets a big-endian sequence of bytes as an integer."""
return int.from_bytes(be_bytes, 'big') return int.from_bytes(be_bytes, 'big')
def int_to_bytes(value): def int_to_bytes(value):
'''Converts an integer to a big-endian sequence of bytes.''' """Converts an integer to a big-endian sequence of bytes."""
return value.to_bytes((value.bit_length() + 7) // 8, 'big') return value.to_bytes((value.bit_length() + 7) // 8, 'big')
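A round-trip example, including the zero edge case:

    assert bytes_to_int(b'\x01\x00') == 256
    assert int_to_bytes(256) == b'\x01\x00'
    assert int_to_bytes(0) == b''   # zero encodes to zero bytes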
def increment_byte_string(bs): def increment_byte_string(bs):
'''Return the lexicographically next byte string of the same length. """Return the lexicographically next byte string of the same length.
Return None if there is none (when the input is all 0xff bytes).''' Return None if there is none (when the input is all 0xff bytes)."""
for n in range(1, len(bs) + 1): for n in range(1, len(bs) + 1):
if bs[-n] != 0xff: if bs[-n] != 0xff:
return bs[:-n] + bytes([bs[-n] + 1]) + bytes(n - 1) return bs[:-n] + bytes([bs[-n] + 1]) + bytes(n - 1)
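Worked examples of the carry behaviour:

    assert increment_byte_string(b'\x00\xff') == b'\x01\x00'
    assert increment_byte_string(b'\xff\xff') is None   # all 0xff: no successor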
@@ -179,7 +179,7 @@ def increment_byte_string(bs):
class LogicalFile: class LogicalFile:
'''A logical binary file split across several separate files on disk.''' """A logical binary file split across several separate files on disk."""
def __init__(self, prefix, digits, file_size): def __init__(self, prefix, digits, file_size):
digit_fmt = '{' + ':0{:d}d'.format(digits) + '}' digit_fmt = '{' + ':0{:d}d'.format(digits) + '}'
@@ -187,10 +187,10 @@ class LogicalFile:
self.file_size = file_size self.file_size = file_size
def read(self, start, size=-1): def read(self, start, size=-1):
'''Read up to size bytes from the virtual file, starting at offset """Read up to size bytes from the virtual file, starting at offset
start, and return them. start, and return them.
If size is -1 all bytes are read.''' If size is -1 all bytes are read."""
parts = [] parts = []
while size != 0: while size != 0:
try: try:
@@ -207,7 +207,7 @@ class LogicalFile:
return b''.join(parts) return b''.join(parts)
def write(self, start, b): def write(self, start, b):
'''Write the bytes-like object, b, to the underlying virtual file.''' """Write the bytes-like object, b, to the underlying virtual file."""
while b: while b:
size = min(len(b), self.file_size - (start % self.file_size)) size = min(len(b), self.file_size - (start % self.file_size))
with self.open_file(start, True) as f: with self.open_file(start, True) as f:
@@ -216,10 +216,10 @@ class LogicalFile:
start += size start += size
def open_file(self, start, create): def open_file(self, start, create):
'''Open the virtual file and seek to start. Return a file handle. """Open the virtual file and seek to start. Return a file handle.
Raise FileNotFoundError if the file does not exist and create Raise FileNotFoundError if the file does not exist and create
is False. is False.
''' """
file_num, offset = divmod(start, self.file_size) file_num, offset = divmod(start, self.file_size)
filename = self.filename_fmt.format(file_num) filename = self.filename_fmt.format(file_num)
f = open_file(filename, create) f = open_file(filename, create)
@@ -228,7 +228,7 @@
def open_file(filename, create=False): def open_file(filename, create=False):
'''Open the named file. Return its handle.''' """Open the named file. Return its handle."""
try: try:
return open(filename, 'rb+') return open(filename, 'rb+')
except FileNotFoundError: except FileNotFoundError:
@@ -238,12 +238,12 @@ def open_file(filename, create=False):
def open_truncate(filename): def open_truncate(filename):
'''Open the named file, truncating it. Return its handle.''' """Open the named file, truncating it. Return its handle."""
return open(filename, 'wb+') return open(filename, 'wb+')
def address_string(address): def address_string(address):
'''Return an address as a correctly formatted string.''' """Return an address as a correctly formatted string."""
fmt = '{}:{:d}' fmt = '{}:{:d}'
host, port = address host, port = address
try: try:
@@ -273,9 +273,9 @@ def is_valid_hostname(hostname):
def protocol_tuple(s): def protocol_tuple(s):
'''Converts a protocol version number, such as "1.0", to a tuple (1, 0). """Converts a protocol version number, such as "1.0", to a tuple (1, 0).
If the version number is bad, (0, ) is returned, indicating version 0.''' If the version number is bad, (0, ) is returned, indicating version 0."""
try: try:
return tuple(int(part) for part in s.split('.')) return tuple(int(part) for part in s.split('.'))
except Exception: except Exception:
@@ -283,22 +283,22 @@ def protocol_tuple(s):
def version_string(ptuple): def version_string(ptuple):
'''Convert a version tuple such as (1, 2) to "1.2". """Convert a version tuple such as (1, 2) to "1.2".
There is always at least one dot, so (1, ) becomes "1.0".''' There is always at least one dot, so (1, ) becomes "1.0"."""
while len(ptuple) < 2: while len(ptuple) < 2:
ptuple += (0, ) ptuple += (0, )
return '.'.join(str(p) for p in ptuple) return '.'.join(str(p) for p in ptuple)
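Usage examples for this pair of helpers:

    assert protocol_tuple('1.4') == (1, 4)
    assert protocol_tuple('junk') == (0, )   # bad input, per the docstring
    assert version_string((1, )) == '1.0'
    assert version_string((1, 4, 2)) == '1.4.2'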
def protocol_version(client_req, min_tuple, max_tuple): def protocol_version(client_req, min_tuple, max_tuple):
'''Given a client's protocol version string, return a pair of """Given a client's protocol version string, return a pair of
protocol tuples: protocol tuples:
(negotiated version, client min request) (negotiated version, client min request)
If the request is unsupported, the negotiated protocol tuple is If the request is unsupported, the negotiated protocol tuple is
None. None.
''' """
if client_req is None: if client_req is None:
client_min = client_max = min_tuple client_min = client_max = min_tuple
else: else:
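In essence, the negotiation described by the docstring picks the highest version acceptable to both sides; a sketch under that reading (not the file's actual remainder):

    def negotiate(client_min, client_max, server_min, server_max):
        # All arguments are protocol tuples such as (1, 4).
        result = min(client_max, server_max)
        if result < max(client_min, server_min) or result == (0, ):
            return None    # the ranges do not overlap
        return result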