// Copyright (c) 2013-2015 The btcsuite developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. package wire import ( "crypto/rand" "encoding/binary" "fmt" "io" "math" "github.com/btcsuite/fastsha256" ) // Maximum payload size for a variable length integer. const MaxVarIntPayload = 9 // errNonCanonicalVarInt is the common format string used for non-canonically // encoded variable length integer errors. var errNonCanonicalVarInt = "non-canonical varint %x - discriminant %x must " + "encode a value greater than %x" // readElement reads the next sequence of bytes from r using little endian // depending on the concrete type of element pointed to. func readElement(r io.Reader, element interface{}) error { var scratch [8]byte // Attempt to read the element based on the concrete type via fast // type assertions first. switch e := element.(type) { case *int32: b := scratch[0:4] _, err := io.ReadFull(r, b) if err != nil { return err } *e = int32(binary.LittleEndian.Uint32(b)) return nil case *uint32: b := scratch[0:4] _, err := io.ReadFull(r, b) if err != nil { return err } *e = binary.LittleEndian.Uint32(b) return nil case *int64: b := scratch[0:8] _, err := io.ReadFull(r, b) if err != nil { return err } *e = int64(binary.LittleEndian.Uint64(b)) return nil case *uint64: b := scratch[0:8] _, err := io.ReadFull(r, b) if err != nil { return err } *e = binary.LittleEndian.Uint64(b) return nil case *bool: b := scratch[0:1] _, err := io.ReadFull(r, b) if err != nil { return err } if b[0] == 0x00 { *e = false } else { *e = true } return nil // Message header checksum. case *[4]byte: _, err := io.ReadFull(r, e[:]) if err != nil { return err } return nil // Message header command. case *[CommandSize]uint8: _, err := io.ReadFull(r, e[:]) if err != nil { return err } return nil // IP address. case *[16]byte: _, err := io.ReadFull(r, e[:]) if err != nil { return err } return nil case *ShaHash: _, err := io.ReadFull(r, e[:]) if err != nil { return err } return nil case *ServiceFlag: b := scratch[0:8] _, err := io.ReadFull(r, b) if err != nil { return err } *e = ServiceFlag(binary.LittleEndian.Uint64(b)) return nil case *InvType: b := scratch[0:4] _, err := io.ReadFull(r, b) if err != nil { return err } *e = InvType(binary.LittleEndian.Uint32(b)) return nil case *BitcoinNet: b := scratch[0:4] _, err := io.ReadFull(r, b) if err != nil { return err } *e = BitcoinNet(binary.LittleEndian.Uint32(b)) return nil case *BloomUpdateType: b := scratch[0:1] _, err := io.ReadFull(r, b) if err != nil { return err } *e = BloomUpdateType(b[0]) return nil case *RejectCode: b := scratch[0:1] _, err := io.ReadFull(r, b) if err != nil { return err } *e = RejectCode(b[0]) return nil } // Fall back to the slower binary.Read if a fast path was not available // above. return binary.Read(r, binary.LittleEndian, element) } // readElements reads multiple items from r. It is equivalent to multiple // calls to readElement. func readElements(r io.Reader, elements ...interface{}) error { for _, element := range elements { err := readElement(r, element) if err != nil { return err } } return nil } // writeElement writes the little endian representation of element to w. func writeElement(w io.Writer, element interface{}) error { var scratch [8]byte // Attempt to write the element based on the concrete type via fast // type assertions first. switch e := element.(type) { case int32: b := scratch[0:4] binary.LittleEndian.PutUint32(b, uint32(e)) _, err := w.Write(b) if err != nil { return err } return nil case uint32: b := scratch[0:4] binary.LittleEndian.PutUint32(b, e) _, err := w.Write(b) if err != nil { return err } return nil case int64: b := scratch[0:8] binary.LittleEndian.PutUint64(b, uint64(e)) _, err := w.Write(b) if err != nil { return err } return nil case uint64: b := scratch[0:8] binary.LittleEndian.PutUint64(b, e) _, err := w.Write(b) if err != nil { return err } return nil case bool: b := scratch[0:1] if e == true { b[0] = 0x01 } else { b[0] = 0x00 } _, err := w.Write(b) if err != nil { return err } return nil // Message header checksum. case [4]byte: _, err := w.Write(e[:]) if err != nil { return err } return nil // Message header command. case [CommandSize]uint8: _, err := w.Write(e[:]) if err != nil { return err } return nil // IP address. case [16]byte: _, err := w.Write(e[:]) if err != nil { return err } return nil case *ShaHash: _, err := w.Write(e[:]) if err != nil { return err } return nil case ServiceFlag: b := scratch[0:8] binary.LittleEndian.PutUint64(b, uint64(e)) _, err := w.Write(b) if err != nil { return err } return nil case InvType: b := scratch[0:4] binary.LittleEndian.PutUint32(b, uint32(e)) _, err := w.Write(b) if err != nil { return err } return nil case BitcoinNet: b := scratch[0:4] binary.LittleEndian.PutUint32(b, uint32(e)) _, err := w.Write(b) if err != nil { return err } return nil case BloomUpdateType: b := scratch[0:1] b[0] = uint8(e) _, err := w.Write(b) if err != nil { return err } return nil case RejectCode: b := scratch[0:1] b[0] = uint8(e) _, err := w.Write(b) if err != nil { return err } return nil } // Fall back to the slower binary.Write if a fast path was not available // above. return binary.Write(w, binary.LittleEndian, element) } // writeElements writes multiple items to w. It is equivalent to multiple // calls to writeElement. func writeElements(w io.Writer, elements ...interface{}) error { for _, element := range elements { err := writeElement(w, element) if err != nil { return err } } return nil } // readVarInt reads a variable length integer from r and returns it as a uint64. func readVarInt(r io.Reader, pver uint32) (uint64, error) { var b [8]byte _, err := io.ReadFull(r, b[0:1]) if err != nil { return 0, err } var rv uint64 discriminant := uint8(b[0]) switch discriminant { case 0xff: _, err := io.ReadFull(r, b[:]) if err != nil { return 0, err } rv = binary.LittleEndian.Uint64(b[:]) // The encoding is not canonical if the value could have been // encoded using fewer bytes. min := uint64(0x100000000) if rv < min { return 0, messageError("readVarInt", fmt.Sprintf( errNonCanonicalVarInt, rv, discriminant, min)) } case 0xfe: _, err := io.ReadFull(r, b[0:4]) if err != nil { return 0, err } rv = uint64(binary.LittleEndian.Uint32(b[:])) // The encoding is not canonical if the value could have been // encoded using fewer bytes. min := uint64(0x10000) if rv < min { return 0, messageError("readVarInt", fmt.Sprintf( errNonCanonicalVarInt, rv, discriminant, min)) } case 0xfd: _, err := io.ReadFull(r, b[0:2]) if err != nil { return 0, err } rv = uint64(binary.LittleEndian.Uint16(b[:])) // The encoding is not canonical if the value could have been // encoded using fewer bytes. min := uint64(0xfd) if rv < min { return 0, messageError("readVarInt", fmt.Sprintf( errNonCanonicalVarInt, rv, discriminant, min)) } default: rv = uint64(discriminant) } return rv, nil } // writeVarInt serializes val to w using a variable number of bytes depending // on its value. func writeVarInt(w io.Writer, pver uint32, val uint64) error { if val < 0xfd { _, err := w.Write([]byte{uint8(val)}) return err } if val <= math.MaxUint16 { var buf [3]byte buf[0] = 0xfd binary.LittleEndian.PutUint16(buf[1:], uint16(val)) _, err := w.Write(buf[:]) return err } if val <= math.MaxUint32 { var buf [5]byte buf[0] = 0xfe binary.LittleEndian.PutUint32(buf[1:], uint32(val)) _, err := w.Write(buf[:]) return err } var buf [9]byte buf[0] = 0xff binary.LittleEndian.PutUint64(buf[1:], val) _, err := w.Write(buf[:]) return err } // VarIntSerializeSize returns the number of bytes it would take to serialize // val as a variable length integer. func VarIntSerializeSize(val uint64) int { // The value is small enough to be represented by itself, so it's // just 1 byte. if val < 0xfd { return 1 } // Discriminant 1 byte plus 2 bytes for the uint16. if val <= math.MaxUint16 { return 3 } // Discriminant 1 byte plus 4 bytes for the uint32. if val <= math.MaxUint32 { return 5 } // Discriminant 1 byte plus 8 bytes for the uint64. return 9 } // readVarString reads a variable length string from r and returns it as a Go // string. A varString is encoded as a varInt containing the length of the // string, and the bytes that represent the string itself. An error is returned // if the length is greater than the maximum block payload size, since it would // not be possible to put a varString of that size into a block anyways and it // also helps protect against memory exhaustion attacks and forced panics // through malformed messages. func readVarString(r io.Reader, pver uint32) (string, error) { count, err := readVarInt(r, pver) if err != nil { return "", err } // Prevent variable length strings that are larger than the maximum // message size. It would be possible to cause memory exhaustion and // panics without a sane upper bound on this count. if count > MaxMessagePayload { str := fmt.Sprintf("variable length string is too long "+ "[count %d, max %d]", count, MaxMessagePayload) return "", messageError("readVarString", str) } buf := make([]byte, count) _, err = io.ReadFull(r, buf) if err != nil { return "", err } return string(buf), nil } // writeVarString serializes str to w as a varInt containing the length of the // string followed by the bytes that represent the string itself. func writeVarString(w io.Writer, pver uint32, str string) error { err := writeVarInt(w, pver, uint64(len(str))) if err != nil { return err } _, err = w.Write([]byte(str)) if err != nil { return err } return nil } // readVarBytes reads a variable length byte array. A byte array is encoded // as a varInt containing the length of the array followed by the bytes // themselves. An error is returned if the length is greater than the // passed maxAllowed parameter which helps protect against memory exhuastion // attacks and forced panics thorugh malformed messages. The fieldName // parameter is only used for the error message so it provides more context in // the error. func readVarBytes(r io.Reader, pver uint32, maxAllowed uint32, fieldName string) ([]byte, error) { count, err := readVarInt(r, pver) if err != nil { return nil, err } // Prevent byte array larger than the max message size. It would // be possible to cause memory exhaustion and panics without a sane // upper bound on this count. if count > uint64(maxAllowed) { str := fmt.Sprintf("%s is larger than the max allowed size "+ "[count %d, max %d]", fieldName, count, maxAllowed) return nil, messageError("readVarBytes", str) } b := make([]byte, count) _, err = io.ReadFull(r, b) if err != nil { return nil, err } return b, nil } // writeVarInt serializes a variable length byte array to w as a varInt // containing the number of bytes, followed by the bytes themselves. func writeVarBytes(w io.Writer, pver uint32, bytes []byte) error { slen := uint64(len(bytes)) err := writeVarInt(w, pver, slen) if err != nil { return err } _, err = w.Write(bytes) if err != nil { return err } return nil } // randomUint64 returns a cryptographically random uint64 value. This // unexported version takes a reader primarily to ensure the error paths // can be properly tested by passing a fake reader in the tests. func randomUint64(r io.Reader) (uint64, error) { var b [8]byte _, err := io.ReadFull(r, b[:]) if err != nil { return 0, err } return binary.BigEndian.Uint64(b[:]), nil } // RandomUint64 returns a cryptographically random uint64 value. func RandomUint64() (uint64, error) { return randomUint64(rand.Reader) } // DoubleSha256 calculates sha256(sha256(b)) and returns the resulting bytes. func DoubleSha256(b []byte) []byte { first := fastsha256.Sum256(b) second := fastsha256.Sum256(first[:]) return second[:] } // DoubleSha256SH calculates sha256(sha256(b)) and returns the resulting bytes // as a ShaHash. func DoubleSha256SH(b []byte) ShaHash { first := fastsha256.Sum256(b) return ShaHash(fastsha256.Sum256(first[:])) }