lbcd/message.go
David Hill b6e6fc25db Preallocate memory for the message header in WriteMessageN and the
serialization buffer in MsgTx's TxSha().

Benchmarking shows this is slightly faster due to avoiding the extra
garbage collection in addition to less peak memory usage.

Before: BenchmarkTxShaOld  500000         5626 ns/op
After:  BenchmarkTxShaNew  500000         5457 ns/op
2014-03-20 14:33:44 -04:00

348 lines
10 KiB
Go

// Copyright (c) 2013-2014 Conformal Systems LLC.
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
package btcwire
import (
"bytes"
"fmt"
"io"
"unicode/utf8"
)
// Sizes and limits shared by every bitcoin message.
const (
	// MessageHeaderSize is the number of bytes in a bitcoin message header:
	// bitcoin network (magic) 4 bytes + command 12 bytes + payload length
	// 4 bytes + checksum 4 bytes.
	MessageHeaderSize = 24

	// commandSize is the fixed width of the command field in the common
	// bitcoin message header.  Shorter commands must be zero padded.
	commandSize = 12

	// maxMessagePayload is the largest payload, in bytes, any message may
	// carry regardless of the limits individual message types impose.
	maxMessagePayload = 32 * 1024 * 1024 // 32MB
)
// Command strings carried in the bitcoin message header.  Each one names a
// message type and is what makeEmptyMessage switches on.
const (
	cmdVersion    = "version"
	cmdVerAck     = "verack"
	cmdGetAddr    = "getaddr"
	cmdAddr       = "addr"
	cmdGetBlocks  = "getblocks"
	cmdInv        = "inv"
	cmdGetData    = "getdata"
	cmdNotFound   = "notfound"
	cmdBlock      = "block"
	cmdTx         = "tx"
	cmdGetHeaders = "getheaders"
	cmdHeaders    = "headers"
	cmdPing       = "ping"
	cmdPong       = "pong"
	cmdAlert      = "alert"
	cmdMemPool    = "mempool"
)
// Message is the interface every bitcoin message type satisfies.  An
// implementation has complete control over how it represents its data, so it
// may hold additional or fewer fields than those which appear directly in the
// protocol encoded message.
type Message interface {
	// BtcDecode populates the message from the encoded payload read from r,
	// interpreted according to protocol version pver.
	BtcDecode(r io.Reader, pver uint32) error

	// BtcEncode writes the encoded payload of the message to w for protocol
	// version pver.
	BtcEncode(w io.Writer, pver uint32) error

	// Command returns the command string placed in the message header.
	Command() string

	// MaxPayloadLength returns the largest payload, in bytes, this message
	// type permits for protocol version pver.
	MaxPayloadLength(pver uint32) uint32
}
// makeEmptyMessage returns a zero-valued message of the concrete type that
// corresponds to command.  An unrecognized command yields a nil Message and an
// error naming the command.
func makeEmptyMessage(command string) (Message, error) {
	switch command {
	case cmdVersion:
		return &MsgVersion{}, nil

	case cmdVerAck:
		return &MsgVerAck{}, nil

	case cmdGetAddr:
		return &MsgGetAddr{}, nil

	case cmdAddr:
		return &MsgAddr{}, nil

	case cmdGetBlocks:
		return &MsgGetBlocks{}, nil

	case cmdBlock:
		return &MsgBlock{}, nil

	case cmdInv:
		return &MsgInv{}, nil

	case cmdGetData:
		return &MsgGetData{}, nil

	case cmdNotFound:
		return &MsgNotFound{}, nil

	case cmdTx:
		return &MsgTx{}, nil

	case cmdPing:
		return &MsgPing{}, nil

	case cmdPong:
		return &MsgPong{}, nil

	case cmdGetHeaders:
		return &MsgGetHeaders{}, nil

	case cmdHeaders:
		return &MsgHeaders{}, nil

	case cmdAlert:
		return &MsgAlert{}, nil

	case cmdMemPool:
		return &MsgMemPool{}, nil
	}

	return nil, fmt.Errorf("unhandled command [%s]", command)
}
// messageHeader defines the header structure for all bitcoin protocol messages.
// The size noted on each field is its encoded size; together they add up to
// MessageHeaderSize bytes.
type messageHeader struct {
// magic identifies the bitcoin network the message belongs to.
magic BitcoinNet // 4 bytes
// command names the message type.  It is zero padded to commandSize bytes
// when encoded and held here with the trailing zeros stripped.
command string // 12 bytes
// length is the number of payload bytes that follow the header.
length uint32 // 4 bytes
// checksum is the first four bytes of DoubleSha256 over the payload.
checksum [4]byte // 4 bytes
}
// readMessageHeader reads a bitcoin message header from r.
//
// It returns the number of bytes consumed from r, the decoded header, and any
// error from reading r.  When the read fails the header is nil and the count
// reflects how many header bytes were consumed before the failure.
func readMessageHeader(r io.Reader) (int, *messageHeader, error) {
	// Since readElements doesn't return the amount of bytes read, attempt
	// to read the entire header into a buffer first in case there is a
	// short read so the proper amount of read bytes are known.  This works
	// since the header is a fixed size.
	headerBytes := make([]byte, MessageHeaderSize)
	n, err := io.ReadFull(r, headerBytes)
	if err != nil {
		return n, nil, err
	}
	hr := bytes.NewBuffer(headerBytes)

	// Create and populate a messageHeader struct from the raw header bytes.
	// The buffer already holds a complete header, so a short read cannot
	// occur at this point; the outcome of readElements is not inspected.
	hdr := messageHeader{}
	var command [commandSize]byte
	readElements(hr, &hdr.magic, &command, &hdr.length, &hdr.checksum)

	// Strip trailing zeros from command string.  The cutset is spelled
	// "\x00" rather than string(0): both denote a single NUL byte, but the
	// integer-to-string conversion is rejected by go vet.
	hdr.command = string(bytes.TrimRight(command[:], "\x00"))

	return n, &hdr, nil
}
// discardInput reads n bytes from reader r in chunks and discards the read
// bytes.  This is used to skip payloads when various errors occur and helps
// prevent rogue nodes from causing massive memory allocation through forging
// header length.
//
// At most 10k is read at a time through a single scratch buffer that is never
// larger than n.  Discarding is best effort: nothing is reported to the
// caller, and the function stops at the first read error instead of issuing
// further reads against a reader that has already failed.
func discardInput(r io.Reader, n uint32) {
	const maxSize = 10 * 1024 // 10k at a time

	if n == 0 {
		return
	}

	// Size the scratch buffer to the smaller of the chunk size and the
	// amount actually requested.
	bufSize := uint32(maxSize)
	if n < bufSize {
		bufSize = n
	}
	buf := make([]byte, bufSize)

	for n > 0 {
		// Full chunks first, then one final partial chunk for the rest.
		chunk := buf
		if n < bufSize {
			chunk = buf[:n]
		}
		if _, err := io.ReadFull(r, chunk); err != nil {
			return
		}
		n -= uint32(len(chunk))
	}
}
// WriteMessageN serializes msg, prefixes it with a bitcoin message header for
// network btcnet, and writes both to w.  It returns the number of bytes
// written to w along with any error.  It behaves like WriteMessage except that
// the byte count is reported as well.
//
// Nothing is written when the command is longer than commandSize, when
// encoding fails, or when the encoded payload exceeds either the overall
// payload limit or the limit msg reports for protocol version pver.
func WriteMessageN(w io.Writer, msg Message, pver uint32, btcnet BitcoinNet) (int, error) {
	// The command must fit in the fixed-size, zero-padded header field.
	cmd := msg.Command()
	if len(cmd) > commandSize {
		return 0, messageError("WriteMessage", fmt.Sprintf(
			"command [%s] is too long [max %v]", cmd, commandSize))
	}
	var command [commandSize]byte
	copy(command[:], cmd)

	// Serialize the payload first so its length and checksum are known
	// before the header is built.
	var payloadBuf bytes.Buffer
	if err := msg.BtcEncode(&payloadBuf, pver); err != nil {
		return 0, err
	}
	payload := payloadBuf.Bytes()
	payloadLen := len(payload)

	// Overall limit that applies to every message.
	if payloadLen > maxMessagePayload {
		return 0, messageError("WriteMessage", fmt.Sprintf(
			"message payload is too large - encoded %d bytes, "+
				"but maximum message payload is %d bytes",
			payloadLen, maxMessagePayload))
	}

	// Limit specific to this message type.
	maxLen := msg.MaxPayloadLength(pver)
	if uint32(payloadLen) > maxLen {
		return 0, messageError("WriteMessage", fmt.Sprintf(
			"message payload is too large - encoded %d bytes, "+
				"but maximum message payload size for messages of "+
				"type [%s] is %d.", payloadLen, cmd, maxLen))
	}

	// Assemble the header fields.
	header := messageHeader{
		magic:   btcnet,
		command: cmd,
		length:  uint32(payloadLen),
	}
	copy(header.checksum[:], DoubleSha256(payload)[:4])

	// The header is encoded into a preallocated buffer rather than straight
	// to w because writeElements doesn't return the number of bytes written.
	headerBuf := bytes.NewBuffer(make([]byte, 0, MessageHeaderSize))
	writeElements(headerBuf, header.magic, command, header.length,
		header.checksum)

	// Header first, then payload; partial writes still count.
	written, err := w.Write(headerBuf.Bytes())
	if err != nil {
		return written, err
	}

	n, err := w.Write(payload)
	written += n
	return written, err
}
// WriteMessage writes msg to w, including the necessary header information,
// for protocol version pver and network btcnet.  It is WriteMessageN without
// the byte count: it exists mainly for backwards compatibility with the
// original API, and is also convenient for callers that don't care how many
// bytes were written.
func WriteMessage(w io.Writer, msg Message, pver uint32, btcnet BitcoinNet) error {
	if _, err := WriteMessageN(w, msg, pver, btcnet); err != nil {
		return err
	}
	return nil
}
// ReadMessageN reads, validates, and parses the next bitcoin Message from r for
// protocol version pver and network btcnet.  It returns the number of bytes
// read, the parsed Message, and the raw payload bytes the message was decoded
// from.  It behaves like ReadMessage except that the byte count is reported as
// well.
//
// On any failure the Message and payload are nil.  When the header is
// well-formed but the message is rejected (wrong network, invalid or unknown
// command, payload too large for the message type) the advertised payload is
// skipped via discardInput; those skipped bytes are not included in the
// returned count.
func ReadMessageN(r io.Reader, pver uint32, btcnet BitcoinNet) (int, Message, []byte, error) {
	bytesRead, hdr, err := readMessageHeader(r)
	if err != nil {
		return bytesRead, nil, nil, err
	}

	// A length beyond the global maximum is refused outright, without
	// attempting to skip the payload.
	if hdr.length > maxMessagePayload {
		return bytesRead, nil, nil, messageError("ReadMessage", fmt.Sprintf(
			"message payload is too large - header indicates %d bytes, "+
				"but max message payload is %d bytes.",
			hdr.length, maxMessagePayload))
	}

	// reject skips the payload announced by the header and reports reason
	// as a ReadMessage error.
	reject := func(reason string) (int, Message, []byte, error) {
		discardInput(r, hdr.length)
		return bytesRead, nil, nil, messageError("ReadMessage", reason)
	}

	// Messages from the wrong bitcoin network.
	if hdr.magic != btcnet {
		return reject(fmt.Sprintf("message from other network [%v]",
			hdr.magic))
	}

	// Malformed commands.
	command := hdr.command
	if !utf8.ValidString(command) {
		return reject(fmt.Sprintf("invalid command %v", []byte(command)))
	}

	// Pick the concrete message type for the command.
	msg, err := makeEmptyMessage(command)
	if err != nil {
		return reject(err.Error())
	}

	// Enforce the per-type maximum: a malicious client could otherwise send
	// a well-formed header with the length set to max numbers in order to
	// exhaust the machine's memory.
	maxLen := msg.MaxPayloadLength(pver)
	if hdr.length > maxLen {
		return reject(fmt.Sprintf("payload exceeds max length - header "+
			"indicates %v bytes, but max payload size for messages of "+
			"type [%v] is %v.", hdr.length, command, maxLen))
	}

	// Read the payload; bytes from a short read still count.
	payload := make([]byte, hdr.length)
	n, err := io.ReadFull(r, payload)
	bytesRead += n
	if err != nil {
		return bytesRead, nil, nil, err
	}

	// The header checksum must match the first four bytes of the payload's
	// double SHA-256.
	checksum := DoubleSha256(payload)[:4]
	if !bytes.Equal(checksum, hdr.checksum[:]) {
		return bytesRead, nil, nil, messageError("ReadMessage", fmt.Sprintf(
			"payload checksum failed - header indicates %v, but actual "+
				"checksum is %v.", hdr.checksum, checksum))
	}

	// Decode the payload into the message.
	if err := msg.BtcDecode(bytes.NewBuffer(payload), pver); err != nil {
		return bytesRead, nil, nil, err
	}

	return bytesRead, msg, payload, nil
}
// ReadMessage reads, validates, and parses the next bitcoin Message from r for
// protocol version pver and network btcnet, returning the parsed Message and
// the raw bytes which comprise it.  It is ReadMessageN without the byte count:
// it exists mainly for backwards compatibility with the original API, and is
// also convenient for callers that don't care how many bytes were read.
func ReadMessage(r io.Reader, pver uint32, btcnet BitcoinNet) (Message, []byte, error) {
	_, parsed, raw, err := ReadMessageN(r, pver, btcnet)
	return parsed, raw, err
}