# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

"""RPC message framing in a byte stream.

A framer splits incoming network bytes into protocol messages and frames
outgoing messages for transmission.  Incoming bytes are handed over
synchronously via ``received_bytes``; complete messages are awaited
with the coroutine ``receive_message``.
"""

__all__ = ('FramerBase', 'NewlineFramer', 'BinaryFramer', 'BitcoinFramer',
           'OversizedPayloadError', 'BadChecksumError', 'BadMagicError')

from hashlib import sha256 as _sha256
from struct import Struct
from asyncio import Queue


class FramerBase:
    """Abstract base class for a framer.

    A framer breaks an incoming byte stream into protocol messages,
    buffering if necessary.  It also frames outgoing messages into a
    byte stream.  Every method here raises NotImplementedError and must
    be provided by a subclass.
    """

    def frame(self, message):
        """Return the framed message, ready to be written to the stream."""
        raise NotImplementedError

    def received_bytes(self, data):
        """Pass incoming network bytes to the framer."""
        raise NotImplementedError

    async def receive_message(self):
        """Wait for a complete unframed message to arrive, and return it."""
        raise NotImplementedError


class NewlineFramer(FramerBase):
    """A framer for a protocol where messages are separated by newlines."""

    # The default max_size value is motivated by JSONRPC, where a
    # normal request will be 250 bytes or less, and a reasonable
    # batch may contain 4000 requests.
    def __init__(self, max_size=250 * 4000):
        """Create a newline framer.

        max_size - an anti-DoS limit on buffered data.  While waiting for
        a newline, if the bytes accumulated for the current message
        exceed max_size, receive_message() drops them, raises
        MemoryError, and the framer then discards input up to and
        including the next newline to re-synchronize the stream.
        """
        self.max_size = max_size
        # Chunks of incoming bytes, in arrival order.
        self.queue = Queue()
        # Instance attribute shadowing FramerBase.received_bytes: incoming
        # data is pushed straight onto the queue, synchronously.
        self.received_bytes = self.queue.put_nowait
        # True after an over-sized message has been rejected, until the
        # newline that ends it has been seen and discarded.
        self.synchronizing = False
        # Bytes left over after the last newline consumed; read before
        # anything new is taken from the queue.
        self.residual = b''

    def frame(self, message):
        """Return message with a terminating newline appended."""
        return message + b'\n'

    async def receive_message(self):
        """Wait for and return the next newline-terminated message.

        The returned bytes exclude the newline.  Bytes that follow the
        newline in the same chunk are kept in self.residual and consumed
        first by the next call.

        Raises MemoryError if the bytes buffered while waiting for a
        newline exceed max_size.  The framer is then in re-synchronizing
        mode: the next call discards everything up to and including the
        next newline, and returns the message that follows it.
        """
        parts = []
        buffer_size = 0
        while True:
            # Consume leftover bytes from a previous chunk before waiting
            # on the queue for fresh data.
            part = self.residual
            self.residual = b''
            if not part:
                part = await self.queue.get()

            npos = part.find(b'\n')
            if npos == -1:
                parts.append(part)
                buffer_size += len(part)
                # Ignore over-sized messages; re-synchronize
                # NOTE(review): the limit is only checked for chunks that
                # contain no newline.  A message whose newline arrives in
                # the same chunk that takes it past max_size is returned,
                # not rejected.
                if buffer_size <= self.max_size:
                    continue
                self.synchronizing = True
                raise MemoryError(f'dropping message over {self.max_size:,d} '
                                  f'bytes and re-synchronizing')

            tail, self.residual = part[:npos], part[npos + 1:]
            if self.synchronizing:
                # This newline ends the message that was rejected earlier.
                # Throw it away (including parts gathered in this call)
                # and read the next full message.
                self.synchronizing = False
                return await self.receive_message()
            else:
                parts.append(tail)
                return b''.join(parts)


class ByteQueue:
    """A producer-consumer queue.

    Incoming network data is put as it arrives, and the consumer calls
    an async method waiting for data of a specific length."""

    def __init__(self):
        # Chunks that have been put but not yet moved into self.parts.
        self.queue = Queue()
        # Buffered chunks not yet returned by receive(), and their
        # total length in bytes.
        self.parts = []
        self.parts_len = 0
        # Producer entry point: enqueue a chunk without awaiting.
        self.put_nowait = self.queue.put_nowait

    async def receive(self, size):
        """Wait until at least size bytes are buffered, then remove and
        return exactly the first size bytes, as a bytes object.

        Surplus bytes remain buffered for later calls.
        """
        while self.parts_len < size:
            part = await self.queue.get()
            self.parts.append(part)
            self.parts_len += len(part)
        self.parts_len -= size
        # NOTE(review): every call re-joins the buffer and copies the
        # unconsumed remainder (whole[size:]).  Many small receives from
        # one large chunk therefore copy that chunk repeatedly.
        whole = b''.join(self.parts)
        self.parts = [whole[size:]]
        return whole[:size]


class BinaryFramer:
    """A framer for binary messaging protocols.

    Messages are (command, payload) pairs.  Subclasses define the wire
    format by implementing _build_header, _receive_header and _checksum.
    Note that this class does not derive from FramerBase.
    """

    def __init__(self):
        # Incoming bytes are buffered here and read by exact length.
        self.byte_queue = ByteQueue()
        # NOTE(review): message_queue is not referenced by any visible
        # code; confirm whether subclasses or callers use it.
        self.message_queue = Queue()
        self.received_bytes = self.byte_queue.put_nowait

    def frame(self, message):
        """Return the bytes for message, a (command, payload) pair: the
        header from _build_header followed by the payload."""
        command, payload = message
        return b''.join((
            self._build_header(command, payload),
            payload
        ))

    async def receive_message(self):
        """Wait for and return the next (command, payload) pair.

        A header is read with _receive_header, which returns
        (command, payload_len, checksum).  Exactly payload_len payload
        bytes are then read.  Raises BadChecksumError if _checksum(payload)
        differs from the header's checksum.
        """
        command, payload_len, checksum = await self._receive_header()
        payload = await self.byte_queue.receive(payload_len)
        payload_checksum = self._checksum(payload)
        if payload_checksum != checksum:
            raise BadChecksumError(payload_checksum, checksum)
        return command, payload

    def _checksum(self, payload):
        """Return the checksum of payload; subclasses must implement."""
        raise NotImplementedError

    def _build_header(self, command, payload):
        """Return the header bytes that precede payload on the wire;
        subclasses must implement."""
        raise NotImplementedError

    async def _receive_header(self):
        """Read one header and return (command, payload_len, checksum);
        subclasses must implement (presumably by reading from
        self.byte_queue)."""
        raise NotImplementedError


# Helpers
#
# NOTE(review): this copy of the file is corrupted from here on.  Text is
# missing between "Struct('" and "1024 * 1024" on the next line (possibly
# a span running from a "<" to a ">" was stripped as markup).  As a
# result, the helper definitions are not visible here.  Neither are
# BitcoinFramer, OversizedPayloadError, BadChecksumError and
# BadMagicError, all of which are named in __all__, nor the start of
# the method whose tail survives below.  The surviving fragment is kept
# verbatim; restore the missing text from the original source before
# using this module.
struct_le_I = Struct(' 1024 * 1024: if command != b'block' or payload_len > self._max_block_size: raise OversizedPayloadError(command, payload_len) return command, payload_len, checksum