lbcd/message.go
Dave Collins 0400d0cec3 Remove dead error check in WriteMessage.
The io.Writer.Write function always returns a non-nil error if fewer bytes
are written than were provided, so there is no reason to further check
whether the payload length matches after a successful write.
2013-05-11 23:22:09 -05:00

281 lines
7.1 KiB
Go

// Copyright (c) 2013 Conformal Systems LLC.
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
package btcwire
import (
"bytes"
"fmt"
"io"
"unicode/utf8"
)
// commandSize is the fixed size of all commands in the common bitcoin message
// header. Shorter commands must be zero padded.
const commandSize = 12

// maxMessagePayload is the maximum number of bytes a message payload can be
// regardless of other individual limits imposed by messages themselves.
const maxMessagePayload = (1024 * 1024 * 32) // 32MB

// Commands used in bitcoin message headers which describe the type of message.
// Each command must be at most commandSize bytes long to fit in the header.
const (
	cmdVersion    = "version"
	cmdVerAck     = "verack"
	cmdGetAddr    = "getaddr"
	cmdAddr       = "addr"
	cmdGetBlocks  = "getblocks"
	cmdInv        = "inv"
	cmdGetData    = "getdata"
	cmdNotFound   = "notfound"
	cmdBlock      = "block"
	cmdTx         = "tx"
	cmdGetHeaders = "getheaders"
	cmdHeaders    = "headers"
	cmdPing       = "ping"
	cmdPong       = "pong"
	cmdAlert      = "alert"
	cmdMemPool    = "mempool"
)
// Message is an interface that describes a bitcoin message. A type that
// implements Message has complete control over the representation of its data
// and may therefore contain additional or fewer fields than those which
// are used directly in the protocol encoded message.
type Message interface {
	// BtcDecode fills in the message from its protocol encoding, read from
	// r and interpreted according to protocol version pver.
	BtcDecode(r io.Reader, pver uint32) error

	// BtcEncode writes the protocol encoding of the message to w using
	// protocol version pver.
	BtcEncode(w io.Writer, pver uint32) error

	// Command returns the command string carried in the message header.
	Command() string

	// MaxPayloadLength returns the largest payload, in bytes, accepted for
	// this message type under protocol version pver.
	MaxPayloadLength(pver uint32) uint32
}
// makeEmptyMessage returns a newly allocated, zero-valued message whose
// concrete type is selected by command. A nil message and an error are
// returned when command does not name a supported message type.
func makeEmptyMessage(command string) (Message, error) {
	switch command {
	case cmdVersion:
		return &MsgVersion{}, nil
	case cmdVerAck:
		return &MsgVerAck{}, nil
	case cmdGetAddr:
		return &MsgGetAddr{}, nil
	case cmdAddr:
		return &MsgAddr{}, nil
	case cmdGetBlocks:
		return &MsgGetBlocks{}, nil
	case cmdInv:
		return &MsgInv{}, nil
	case cmdGetData:
		return &MsgGetData{}, nil
	case cmdNotFound:
		return &MsgNotFound{}, nil
	case cmdBlock:
		return &MsgBlock{}, nil
	case cmdTx:
		return &MsgTx{}, nil
	case cmdGetHeaders:
		return &MsgGetHeaders{}, nil
	case cmdHeaders:
		return &MsgHeaders{}, nil
	case cmdPing:
		return &MsgPing{}, nil
	case cmdPong:
		return &MsgPong{}, nil
	case cmdAlert:
		return &MsgAlert{}, nil
	case cmdMemPool:
		return &MsgMemPool{}, nil
	}

	// Anything not matched above is a command this package cannot decode.
	return nil, fmt.Errorf("unhandled command [%s]", command)
}
// messageHeader is the decoded form of the header that precedes every bitcoin
// protocol message. Fields are listed in the order they appear on the wire,
// and the byte count noted for each is its encoded size.
type messageHeader struct {
	// magic identifies the bitcoin network the message is for (4 bytes).
	magic BitcoinNet

	// command names the message type with its zero padding removed
	// (12 bytes, zero padded, on the wire).
	command string

	// length is the number of payload bytes that follow the header (4 bytes).
	length uint32

	// checksum holds the first four bytes of DoubleSha256 of the payload
	// (4 bytes).
	checksum [4]byte
}
// readMessageHeader reads a bitcoin message header from r.
//
// The fields are read in wire order: network magic, the fixed-size zero-padded
// command, payload length and payload checksum. Trailing zero padding is
// stripped from the command before it is stored in the returned header. Any
// error from reading r is returned unchanged together with a nil header.
func readMessageHeader(r io.Reader) (*messageHeader, error) {
	var command [commandSize]byte
	hdr := messageHeader{}
	err := readElements(r, &hdr.magic, &command, &hdr.length, &hdr.checksum)
	if err != nil {
		return nil, err
	}

	// Strip trailing zero padding from the command. The cutset is spelled as
	// the explicit "\x00" literal instead of string(0), an integer-to-string
	// conversion that go vet reports as suspicious.
	hdr.command = string(bytes.TrimRight(command[:], "\x00"))

	return &hdr, nil
}
// discardInput reads n bytes from reader r in chunks and discards the read
// bytes. This is used to skip payloads when various errors occur and helps
// prevent rogue nodes from causing massive memory allocation through forging
// header length.
//
// Discarding is best effort and reports nothing. It stops early as soon as r
// returns an error, for example when the stream ends before n bytes are
// available, rather than issuing further reads against a reader that has
// already failed.
func discardInput(r io.Reader, n uint32) {
	const maxChunk = 10 * 1024 // 10k at a time

	if n == 0 {
		return
	}

	// Size the scratch buffer to the smaller of the chunk limit and n so that
	// small discards do not allocate a full chunk.
	bufSize := uint32(maxChunk)
	if n < bufSize {
		bufSize = n
	}
	buf := make([]byte, bufSize)

	for n > 0 {
		chunk := bufSize
		if n < chunk {
			chunk = n
		}
		if _, err := io.ReadFull(r, buf[:chunk]); err != nil {
			return
		}
		n -= chunk
	}
}
// WriteMessage writes a bitcoin Message to w including the necessary header
// information.
//
// The payload is produced by msg.BtcEncode for protocol version pver. It is
// preceded by a header carrying btcnet, the message command, the payload
// length and the payload checksum. Nothing is written to w if any of these
// holds:
//   - the command is longer than commandSize bytes
//   - encoding fails
//   - the encoded payload exceeds maxMessagePayload
//   - the encoded payload exceeds msg.MaxPayloadLength(pver)
//
// Errors returned by w are passed through unchanged.
func WriteMessage(w io.Writer, msg Message, pver uint32, btcnet BitcoinNet) error {
	// Enforce the fixed command field size. Bytes not covered by the copy
	// keep the array's zero value, which provides the required zero padding.
	var command [commandSize]byte
	cmd := msg.Command()
	if len(cmd) > commandSize {
		str := fmt.Sprintf("command [%s] is too long [max %v]",
			cmd, commandSize)
		return messageError("WriteMessage", str)
	}
	copy(command[:], cmd)

	// Encode the payload.
	var bw bytes.Buffer
	if err := msg.BtcEncode(&bw, pver); err != nil {
		return err
	}
	payload := bw.Bytes()
	lenp := len(payload)

	// Enforce maximum overall message payload.
	if lenp > maxMessagePayload {
		str := fmt.Sprintf("message payload is too large - encoded "+
			"%d bytes, but maximum message payload is %d bytes",
			lenp, maxMessagePayload)
		return messageError("WriteMessage", str)
	}

	// Also enforce the per-type payload limit. ReadMessage rejects (and
	// discards) any message whose header length exceeds MaxPayloadLength for
	// its type, so writing one would only produce something a ReadMessage
	// reader refuses. The uint32 conversion is safe because lenp was bounded
	// just above.
	mpl := msg.MaxPayloadLength(pver)
	if uint32(lenp) > mpl {
		str := fmt.Sprintf("message payload is too large - encoded "+
			"%d bytes, but maximum message payload size for "+
			"messages of type [%s] is %d.", lenp, cmd, mpl)
		return messageError("WriteMessage", str)
	}

	// Create header for the message.
	hdr := messageHeader{}
	hdr.magic = btcnet
	hdr.command = cmd
	hdr.length = uint32(lenp)
	copy(hdr.checksum[:], DoubleSha256(payload)[0:4])

	// Write header.
	err := writeElements(w, hdr.magic, command, hdr.length, hdr.checksum)
	if err != nil {
		return err
	}

	// Write payload. io.Writer implementations must return a non-nil error
	// on a short write, so no separate length check is needed.
	_, err = w.Write(payload)
	return err
}
// ReadMessage reads, validates, and parses the next bitcoin Message from r for
// the provided protocol version and bitcoin network.
//
// On success it returns the decoded message together with the raw payload
// bytes. Validation runs in this order:
//  1. the header length against the global payload cap
//  2. the network magic
//  3. the command encoding
//  4. whether the command is a known message type
//  5. the per-type payload limit
//  6. the payload checksum
//
// Errors from reading r or from decoding the payload are returned unchanged.
func ReadMessage(r io.Reader, pver uint32, btcnet BitcoinNet) (Message, []byte, error) {
	hdr, err := readMessageHeader(r)
	if err != nil {
		return nil, nil, err
	}

	// A length beyond the absolute cap is rejected immediately and the
	// payload is left unread.
	if hdr.length > maxMessagePayload {
		str := fmt.Sprintf("message payload is too large - header "+
			"indicates %d bytes, but max message payload is %d "+
			"bytes.", hdr.length, maxMessagePayload)
		return nil, nil, messageError("ReadMessage", str)
	}

	// reject handles the failures below: it first skips the hdr.length
	// payload bytes announced by the header, then reports str as a
	// ReadMessage error.
	reject := func(str string) (Message, []byte, error) {
		discardInput(r, hdr.length)
		return nil, nil, messageError("ReadMessage", str)
	}

	// Refuse traffic meant for a different bitcoin network.
	if hdr.magic != btcnet {
		return reject(fmt.Sprintf("message from other network [%v]", hdr.magic))
	}

	// The command must be valid UTF-8.
	cmd := hdr.command
	if !utf8.ValidString(cmd) {
		return reject(fmt.Sprintf("invalid command %v", []byte(cmd)))
	}

	// Allocate the concrete message type named by the command.
	msg, err := makeEmptyMessage(cmd)
	if err != nil {
		return reject(err.Error())
	}

	// Bound the payload by the limit of this particular message type before
	// allocating a buffer for it, so a well-formed header with a huge length
	// cannot force a large allocation.
	if mpl := msg.MaxPayloadLength(pver); hdr.length > mpl {
		return reject(fmt.Sprintf("payload exceeds max length - header "+
			"indicates %v bytes, but max payload size for "+
			"messages of type [%v] is %v.", hdr.length, cmd, mpl))
	}

	payload := make([]byte, hdr.length)
	if _, err := io.ReadFull(r, payload); err != nil {
		return nil, nil, err
	}

	// Compare the header checksum against the one computed from the payload.
	sum := DoubleSha256(payload)[:4]
	if !bytes.Equal(sum, hdr.checksum[:]) {
		str := fmt.Sprintf("payload checksum failed - header "+
			"indicates %v, but actual checksum is %v.",
			hdr.checksum, sum)
		return nil, nil, messageError("ReadMessage", str)
	}

	// Unmarshal the payload into the message.
	if err := msg.BtcDecode(bytes.NewBuffer(payload), pver); err != nil {
		return nil, nil, err
	}

	return msg, payload, nil
}