lbry-sdk/lbry/wallet/rpc/framing.py
2020-01-01 15:57:56 -05:00

240 lines
7.6 KiB
Python

# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""Framing of RPC messages carried over a byte stream.

A framer splits an incoming stream of bytes into discrete protocol
messages and wraps outgoing messages so they can be written to the stream.
"""
__all__ = (
    'FramerBase',
    'NewlineFramer',
    'BinaryFramer',
    'BitcoinFramer',
    'OversizedPayloadError',
    'BadChecksumError',
    'BadMagicError',
)
from hashlib import sha256 as _sha256
from struct import Struct
from asyncio import Queue
class FramerBase:
    """Interface shared by the framers in this module.

    A framer turns an incoming byte stream into protocol messages, buffering
    partial data when needed, and wraps outgoing messages for the stream.
    Subclasses are expected to provide all three operations below.
    """

    def frame(self, message):
        """Wrap *message* for transmission and return the framed result."""
        raise NotImplementedError()

    def received_bytes(self, data):
        """Accept a chunk of *data* read from the network."""
        raise NotImplementedError()

    async def receive_message(self):
        """Wait until a whole message has arrived and return it unframed."""
        raise NotImplementedError()
class NewlineFramer(FramerBase):
    """A framer for a protocol where messages are separated by newlines."""

    # The default max_size value is motivated by JSONRPC, where a
    # normal request will be 250 bytes or less, and a reasonable
    # batch may contain 4000 requests.
    def __init__(self, max_size=250 * 4000):
        """max_size - an anti-DoS measure. If the newline-free data buffered
        while waiting for a message's terminating newline exceeds max_size
        bytes, that data is dropped entirely, MemoryError is raised, and the
        framer discards input up to the next newline character to
        re-synchronize the stream.
        """
        self.max_size = max_size
        self.queue = Queue()
        # Incoming network chunks are queued as-is; they are split into
        # messages lazily by receive_message().
        self.received_bytes = self.queue.put_nowait
        # True after an oversized message was dropped, until its terminating
        # newline has been consumed.
        self.synchronizing = False
        # Bytes that followed the last consumed newline, not yet parsed.
        self.residual = b''

    def frame(self, message):
        """Return *message* (bytes) terminated by a newline."""
        return message + b'\n'

    async def receive_message(self):
        """Wait for the next newline-terminated message and return it,
        without the newline.

        Raises MemoryError when the newline-free chunks buffered for the
        current message total more than max_size bytes. The framer then
        discards input up to the next newline; later calls resume with the
        message that follows it.
        """
        parts = []
        buffer_size = 0
        while True:
            part = self.residual
            self.residual = b''
            if not part:
                part = await self.queue.get()

            npos = part.find(b'\n')
            if npos == -1:
                buffer_size += len(part)
                # While re-synchronizing this data is discarded anyway, so
                # only count it (to keep enforcing max_size) instead of
                # holding up to max_size bytes in memory.
                if not self.synchronizing:
                    parts.append(part)
                # Ignore over-sized messages; re-synchronize
                if buffer_size <= self.max_size:
                    continue
                self.synchronizing = True
                raise MemoryError(f'dropping message over {self.max_size:,d} '
                                  f'bytes and re-synchronizing')

            tail, self.residual = part[:npos], part[npos + 1:]
            if self.synchronizing:
                # This newline ends the dropped message: start over with
                # whatever follows it (a loop reset instead of recursion).
                self.synchronizing = False
                parts = []
                buffer_size = 0
                continue
            parts.append(tail)
            return b''.join(parts)
class ByteQueue:
    """A producer-consumer queue of bytes.

    Incoming network data is put as it arrives (put_nowait), and the
    consumer calls an async method waiting for data of a specific length.
    """

    def __init__(self):
        self.queue = Queue()
        # Bytes taken from the queue but not yet handed to a consumer. A
        # bytearray lets receive() drop a consumed prefix in place instead of
        # re-joining and re-slicing the whole remainder on every call.
        self._buffer = bytearray()
        self.put_nowait = self.queue.put_nowait

    @property
    def parts_len(self):
        """Number of bytes currently buffered and not yet received."""
        return len(self._buffer)

    async def receive(self, size):
        """Wait until at least *size* bytes are buffered, then remove the
        first *size* bytes from the buffer and return them as ``bytes``."""
        buffer = self._buffer
        while len(buffer) < size:
            buffer.extend(await self.queue.get())
        result = bytes(buffer[:size])
        del buffer[:size]
        return result
class BinaryFramer(FramerBase):
    """A framer for binary messaging protocols.

    Messages are ``(command, payload)`` pairs. Incoming network bytes are
    buffered in a ByteQueue. Subclasses define the wire header by
    implementing ``_checksum``, ``_build_header`` and ``_receive_header``.
    """

    def __init__(self):
        super().__init__()
        self.byte_queue = ByteQueue()
        # NOTE(review): not read anywhere in this class; kept in case callers
        # or subclasses rely on it.
        self.message_queue = Queue()
        # Incoming network bytes go straight into the byte queue.
        self.received_bytes = self.byte_queue.put_nowait

    def frame(self, message):
        """Return the framed bytes for *message*, a ``(command, payload)``
        pair: the header from ``_build_header`` followed by the payload."""
        command, payload = message
        return b''.join((
            self._build_header(command, payload),
            payload,
        ))

    async def receive_message(self):
        """Wait for the next message and return it as ``(command, payload)``.

        Reads a header via ``_receive_header``, then exactly the announced
        number of payload bytes, and verifies the payload checksum.

        Raises BadChecksumError(computed_checksum, header_checksum) on a
        mismatch; exceptions from ``_receive_header`` propagate unchanged.
        """
        command, payload_len, checksum = await self._receive_header()
        payload = await self.byte_queue.receive(payload_len)
        payload_checksum = self._checksum(payload)
        if payload_checksum != checksum:
            raise BadChecksumError(payload_checksum, checksum)
        return command, payload

    def _checksum(self, payload):
        """Return the checksum of *payload*. Subclasses must implement."""
        raise NotImplementedError

    def _build_header(self, command, payload):
        """Return the header bytes for *command* and *payload*.
        Subclasses must implement."""
        raise NotImplementedError

    async def _receive_header(self):
        """Read one header from ``self.byte_queue`` and return
        ``(command, payload_len, checksum)``. Subclasses must implement."""
        raise NotImplementedError
# Helpers

# Little-endian unsigned 32-bit integer codec.
struct_le_I = Struct('<I')
pack_le_uint32 = struct_le_I.pack


def sha256(x):
    """Return the raw 32-byte SHA-256 digest of the bytes *x*."""
    hasher = _sha256()
    hasher.update(x)
    return hasher.digest()


def double_sha256(x):
    """Return SHA-256 applied twice to *x*; the Bitcoin payload checksum is
    derived from this."""
    first_round = sha256(x)
    return sha256(first_round)
class BadChecksumError(Exception):
    """A received payload's checksum did not match the one in its header.

    Raised by BinaryFramer.receive_message with args
    (computed_checksum, header_checksum).
    """


class BadMagicError(Exception):
    """A received header carried an unexpected network magic.

    Raised by BitcoinFramer with args (received_magic, expected_magic).
    """


class OversizedPayloadError(Exception):
    """A received header announced a payload larger than is accepted.

    Raised by BitcoinFramer with args (command, payload_len).
    """
class BitcoinFramer(BinaryFramer):
    """Provides a framer of binary message payloads in the style of the
    Bitcoin network protocol.

    Each binary message has the following elements, in order:

       Magic    - to confirm network (currently unused for stream sync)
       Command  - padded command
       Length   - payload length in bytes
       Checksum - checksum of the payload
       Payload  - binary payload

    Call frame((command, payload)) to get a framed message.
    Pass incoming network bytes to received_bytes().
    Wait on receive_message() to get incoming (command, payload) pairs.
    """

    # Header layout: 4-byte magic, 12-byte NUL-padded command,
    # little-endian uint32 payload length, 4-byte checksum.
    _HEADER = Struct('<4s12sI4s')
    _MAGIC_SIZE = 4
    _COMMAND_SIZE = 12
    # Payloads larger than this are only accepted for the b'block' command,
    # and then only up to max_block_size.
    _MAX_PAYLOAD = 1024 * 1024

    def __init__(self, magic, max_block_size):
        """magic - network magic bytes written at the start of every outgoing
        header and required at the start of every incoming one; must be
        exactly 4 bytes long.
        max_block_size - upper bound for a b'block' payload that exceeds the
        general 1 MiB payload limit.

        Raises ValueError if magic is not 4 bytes long.
        """
        if len(magic) != self._MAGIC_SIZE:
            # A wrong-length magic would produce malformed outgoing headers
            # and could never match an incoming one, so fail up front.
            raise ValueError(f'magic must be {self._MAGIC_SIZE} bytes, '
                             f'got {len(magic)}')
        super().__init__()
        self._magic = magic
        self._max_block_size = max_block_size
        self._unpack = self._HEADER.unpack

    def _pad_command(self, command):
        """Return *command* (bytes) right-padded with NUL bytes to 12 bytes.

        Raises ValueError if *command* is longer than 12 bytes.
        """
        fill = self._COMMAND_SIZE - len(command)
        if fill < 0:
            raise ValueError(f'command {command} too long')
        return command + bytes(fill)

    def _checksum(self, payload):
        """Return the first 4 bytes of the double SHA-256 of *payload*."""
        return double_sha256(payload)[:4]

    def _build_header(self, command, payload):
        """Return the header (magic, padded command, little-endian payload
        length, checksum) for *command* and *payload*."""
        return b''.join((
            self._magic,
            self._pad_command(command),
            pack_le_uint32(len(payload)),
            self._checksum(payload),
        ))

    async def _receive_header(self):
        """Read one header and return ``(command, payload_len, checksum)``,
        with NUL padding stripped from the command.

        Raises BadMagicError(received, expected) if the magic is wrong, and
        OversizedPayloadError(command, payload_len) if the payload exceeds
        1 MiB, unless the command is b'block' and the payload fits within
        max_block_size.
        """
        header = await self.byte_queue.receive(self._HEADER.size)
        magic, command, payload_len, checksum = self._unpack(header)
        if magic != self._magic:
            raise BadMagicError(magic, self._magic)
        command = command.rstrip(b'\0')
        if payload_len > self._MAX_PAYLOAD:
            if command != b'block' or payload_len > self._max_block_size:
                raise OversizedPayloadError(command, payload_len)
        return command, payload_len, checksum