# Copyright (c) 2018, Neil Booth # # All rights reserved. # # The MIT License (MIT) # # Permission is hereby granted, free of charge, to any person obtaining # a copy of this software and associated documentation files (the # "Software"), to deal in the Software without restriction, including # without limitation the rights to use, copy, modify, merge, publish, # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: # # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE # LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """RPC message framing in a byte stream.""" __all__ = ('FramerBase', 'NewlineFramer', 'BinaryFramer', 'BitcoinFramer', 'OversizedPayloadError', 'BadChecksumError', 'BadMagicError') from hashlib import sha256 as _sha256 from struct import Struct from asyncio import Queue class FramerBase: """Abstract base class for a framer. A framer breaks an incoming byte stream into protocol messages, buffering if necessary. It also frames outgoing messages into a byte stream. """ def frame(self, message): """Return the framed message.""" raise NotImplementedError def received_bytes(self, data): """Pass incoming network bytes.""" raise NotImplementedError async def receive_message(self): """Wait for a complete unframed message to arrive, and return it.""" raise NotImplementedError class NewlineFramer(FramerBase): """A framer for a protocol where messages are separated by newlines.""" # The default max_size value is motivated by JSONRPC, where a # normal request will be 250 bytes or less, and a reasonable # batch may contain 4000 requests. def __init__(self, max_size=250 * 4000): """max_size - an anti-DoS measure. If, after processing an incoming message, buffered data would exceed max_size bytes, that buffered data is dropped entirely and the framer waits for a newline character to re-synchronize the stream. """ self.max_size = max_size self.queue = Queue() self.received_bytes = self.queue.put_nowait self.synchronizing = False self.residual = b'' def frame(self, message): return message + b'\n' async def receive_message(self): parts = [] buffer_size = 0 while True: part = self.residual self.residual = b'' if not part: part = await self.queue.get() npos = part.find(b'\n') if npos == -1: parts.append(part) buffer_size += len(part) # Ignore over-sized messages; re-synchronize if buffer_size <= self.max_size: continue self.synchronizing = True raise MemoryError(f'dropping message over {self.max_size:,d} ' f'bytes and re-synchronizing') tail, self.residual = part[:npos], part[npos + 1:] if self.synchronizing: self.synchronizing = False return await self.receive_message() else: parts.append(tail) return b''.join(parts) class ByteQueue: """A producer-comsumer queue. Incoming network data is put as it arrives, and the consumer calls an async method waiting for data of a specific length.""" def __init__(self): self.queue = Queue() self.parts = [] self.parts_len = 0 self.put_nowait = self.queue.put_nowait async def receive(self, size): while self.parts_len < size: part = await self.queue.get() self.parts.append(part) self.parts_len += len(part) self.parts_len -= size whole = b''.join(self.parts) self.parts = [whole[size:]] return whole[:size] class BinaryFramer: """A framer for binary messaging protocols.""" def __init__(self): self.byte_queue = ByteQueue() self.message_queue = Queue() self.received_bytes = self.byte_queue.put_nowait def frame(self, message): command, payload = message return b''.join(( self._build_header(command, payload), payload )) async def receive_message(self): command, payload_len, checksum = await self._receive_header() payload = await self.byte_queue.receive(payload_len) payload_checksum = self._checksum(payload) if payload_checksum != checksum: raise BadChecksumError(payload_checksum, checksum) return command, payload def _checksum(self, payload): raise NotImplementedError def _build_header(self, command, payload): raise NotImplementedError async def _receive_header(self): raise NotImplementedError # Helpers struct_le_I = Struct('<I') pack_le_uint32 = struct_le_I.pack def sha256(x): """Simple wrapper of hashlib sha256.""" return _sha256(x).digest() def double_sha256(x): """SHA-256 of SHA-256, as used extensively in bitcoin.""" return sha256(sha256(x)) class BadChecksumError(Exception): pass class BadMagicError(Exception): pass class OversizedPayloadError(Exception): pass class BitcoinFramer(BinaryFramer): """Provides a framer of binary message payloads in the style of the Bitcoin network protocol. Each binary message has the following elements, in order: Magic - to confirm network (currently unused for stream sync) Command - padded command Length - payload length in bytes Checksum - checksum of the payload Payload - binary payload Call frame(command, payload) to get a framed message. Pass incoming network bytes to received_bytes(). Wait on receive_message() to get incoming (command, payload) pairs. """ def __init__(self, magic, max_block_size): def pad_command(command): fill = 12 - len(command) if fill < 0: raise ValueError(f'command {command} too long') return command + bytes(fill) super().__init__() self._magic = magic self._max_block_size = max_block_size self._pad_command = pad_command self._unpack = Struct(f'<4s12sI4s').unpack def _checksum(self, payload): return double_sha256(payload)[:4] def _build_header(self, command, payload): return b''.join(( self._magic, self._pad_command(command), pack_le_uint32(len(payload)), self._checksum(payload) )) async def _receive_header(self): header = await self.byte_queue.receive(24) magic, command, payload_len, checksum = self._unpack(header) if magic != self._magic: raise BadMagicError(magic, self._magic) command = command.rstrip(b'\0') if payload_len > 1024 * 1024: if command != b'block' or payload_len > self._max_block_size: raise OversizedPayloadError(command, payload_len) return command, payload_len, checksum