diff --git a/common.go b/common.go index 203a3a4e..eb337dc9 100644 --- a/common.go +++ b/common.go @@ -281,6 +281,32 @@ func writeElements(w io.Writer, elements ...interface{}) error { return nil } +// readVarBytes reads a variable length byte array +func readVarBytes(r io.Reader, pver uint32, maxAllowed uint32, + fieldName string) ([]byte, error) { + + count, err := readVarInt(r, pver) + if err != nil { + return nil, err + } + + // Prevent byte array larger than the max message size. It would + // be possible to cause memory exhaustion and panics without a sane + // upper bound on this count. + if count > uint64(maxAllowed) { + str := fmt.Sprintf("%s is larger than the max allowed size "+ + "[count %d, max %d]", fieldName, count, maxAllowed) + return nil, messageError("readVarBytes", str) + } + + b := make([]byte, count) + _, err = io.ReadFull(r, b) + if err != nil { + return nil, err + } + return b, nil +} + // readVarInt reads a variable length integer from r and returns it as a uint64. func readVarInt(r io.Reader, pver uint32) (uint64, error) { var b [8]byte @@ -320,6 +346,21 @@ func readVarInt(r io.Reader, pver uint32) (uint64, error) { return rv, nil } +// writeVarBytes writes a variable length byte array +func writeVarBytes(w io.Writer, pver uint32, bytes []byte) error { + slen := uint64(len(bytes)) + err := writeVarInt(w, pver, slen) + if err != nil { + return err + } + + _, err = w.Write(bytes) + if err != nil { + return err + } + return nil +} + // writeVarInt serializes val to w using a variable number of bytes depending // on its value. func writeVarInt(w io.Writer, pver uint32, val uint64) error { diff --git a/common_test.go b/common_test.go index 577d3190..3a534d47 100644 --- a/common_test.go +++ b/common_test.go @@ -501,6 +501,139 @@ func TestVarStringOverflowErrors(t *testing.T) { } +// TestVarBytesWire tests wire encode and decode for variable length byte array. +func TestVarBytesWire(t *testing.T) { + pver := btcwire.ProtocolVersion + + // bytes256 is a byte array that takes a 2-byte varint to encode. + bytes256 := bytes.Repeat([]byte{0x01}, 256) + + tests := []struct { + in []byte // Byte Array to write + buf []byte // Wire encoding + pver uint32 // Protocol version for wire encoding + }{ + // Latest protocol version. + // Empty byte array + {[]byte{}, []byte{0x00}, pver}, + // Single byte varint + byte array + {[]byte{0x01}, []byte{0x01, 0x01}, pver}, + // 2-byte varint + byte array + {bytes256, append([]byte{0xfd, 0x00, 0x01}, bytes256...), pver}, + } + + t.Logf("Running %d tests", len(tests)) + for i, test := range tests { + // Encode to wire format. + var buf bytes.Buffer + err := btcwire.TstWriteVarBytes(&buf, test.pver, test.in) + if err != nil { + t.Errorf("writeVarBytes #%d error %v", i, err) + continue + } + if !bytes.Equal(buf.Bytes(), test.buf) { + t.Errorf("writeVarBytes #%d\n got: %s want: %s", i, + spew.Sdump(buf.Bytes()), spew.Sdump(test.buf)) + continue + } + + // Decode from wire format. + rbuf := bytes.NewBuffer(test.buf) + val, err := btcwire.TstReadVarBytes(rbuf, test.pver, btcwire.MaxMessagePayload, + "alert serialized payload") + if err != nil { + t.Errorf("readVarBytes #%d error %v", i, err) + continue + } + if !bytes.Equal(buf.Bytes(), test.buf) { + t.Errorf("readVarBytes #%d\n got: %s want: %s", i, + val, test.buf) + continue + } + } +} + +// TestVarBytesWireErrors performs negative tests against wire encode and +// decode of variable length byte arrays to confirm error paths work correctly. +func TestVarBytesWireErrors(t *testing.T) { + pver := btcwire.ProtocolVersion + + // bytes256 is a byte array that takes a 2-byte varint to encode. + bytes256 := bytes.Repeat([]byte{0x01}, 256) + + tests := []struct { + in []byte // Byte Array to write + buf []byte // Wire encoding + pver uint32 // Protocol version for wire encoding + max int // Max size of fixed buffer to induce errors + writeErr error // Expected write error + readErr error // Expected read error + }{ + // Latest protocol version with intentional read/write errors. + // Force errors on empty byte array. + {[]byte{}, []byte{0x00}, pver, 0, io.ErrShortWrite, io.EOF}, + // Force error on single byte varint + byte array. + {[]byte{0x01, 0x02, 0x03}, []byte{0x04}, pver, 2, io.ErrShortWrite, io.ErrUnexpectedEOF}, + // Force errors on 2-byte varint + byte array. + {bytes256, []byte{0xfd}, pver, 2, io.ErrShortWrite, io.ErrUnexpectedEOF}, + } + + t.Logf("Running %d tests", len(tests)) + for i, test := range tests { + // Encode to wire format. + w := newFixedWriter(test.max) + err := btcwire.TstWriteVarBytes(w, test.pver, test.in) + if err != test.writeErr { + t.Errorf("writeVarBytes #%d wrong error got: %v, want: %v", + i, err, test.writeErr) + continue + } + + // Decode from wire format. + r := newFixedReader(test.max, test.buf) + _, err = btcwire.TstReadVarBytes(r, test.pver, btcwire.MaxMessagePayload, + "alert serialized payload") + if err != test.readErr { + t.Errorf("readVarBytes #%d wrong error got: %v, want: %v", + i, err, test.readErr) + continue + } + } +} + +// TestVarBytesOverflowErrors performs tests to ensure deserializing variable +// length byte arrays intentionally crafted to use large values for the array +// length are handled properly. This could otherwise potentially be used as an +// attack vector. +func TestVarBytesOverflowErrors(t *testing.T) { + pver := btcwire.ProtocolVersion + + tests := []struct { + buf []byte // Wire encoding + pver uint32 // Protocol version for wire encoding + err error // Expected error + }{ + {[]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + pver, &btcwire.MessageError{}}, + {[]byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, + pver, &btcwire.MessageError{}}, + } + + t.Logf("Running %d tests", len(tests)) + for i, test := range tests { + // Decode from wire format. + rbuf := bytes.NewBuffer(test.buf) + _, err := btcwire.TstReadVarBytes(rbuf, test.pver, btcwire.MaxMessagePayload, + "alert serialized payload") + if reflect.TypeOf(err) != reflect.TypeOf(test.err) { + t.Errorf("readVarBytes #%d wrong error got: %v, "+ + "want: %v", i, err, reflect.TypeOf(test.err)) + continue + } + } + +} + // TestRandomUint64 exercises the randomness of the random number generator on // the system by ensuring the probability of the generated numbers. If the RNG // is evenly distributed as a proper cryptographic RNG should be, there really diff --git a/internal_test.go b/internal_test.go index 2bb0c380..031d3068 100644 --- a/internal_test.go +++ b/internal_test.go @@ -19,6 +19,14 @@ import ( // the test package. const MaxMessagePayload uint32 = maxMessagePayload +// MaxCountSetCancel makes the internal maxCountSetCancel constant available to +// the test package. +const MaxCountSetCancel uint32 = maxCountSetCancel + +// MaxCountSetSubVer makes the internal maxCountSetSubVer constant available to +// the test package. +const MaxCountSetSubVer uint32 = maxCountSetSubVer + // CommandSize makes the internal commandSize constant available to the test // package. const CommandSize = commandSize @@ -65,6 +73,18 @@ func TstWriteVarString(w io.Writer, pver uint32, str string) error { return writeVarString(w, pver, str) } +// TstReadVarBytes makes the internal readVarBytes function available to the +// test package. +func TstReadVarBytes(r io.Reader, pver uint32, maxAllowed uint32, fieldName string) ([]byte, error) { + return readVarBytes(r, pver, maxAllowed, fieldName) +} + +// TstWriteVarBytes makes the internal writeVarBytes function available to the +// test package. +func TstWriteVarBytes(w io.Writer, pver uint32, bytes []byte) error { + return writeVarBytes(w, pver, bytes) +} + // TstReadNetAddress makes the internal readNetAddress function available to // the test package. func TstReadNetAddress(r io.Reader, pver uint32, na *NetAddress, ts bool) error { diff --git a/message_test.go b/message_test.go index a33fc977..c79f4e8d 100644 --- a/message_test.go +++ b/message_test.go @@ -66,7 +66,7 @@ func TestMessage(t *testing.T) { msgPong := btcwire.NewMsgPong(123123) msgGetHeaders := btcwire.NewMsgGetHeaders() msgHeaders := btcwire.NewMsgHeaders() - msgAlert := btcwire.NewMsgAlert("payload", "signature") + msgAlert := btcwire.NewMsgAlert([]byte("payload"), []byte("signature")) msgMemPool := btcwire.NewMsgMemPool() tests := []struct { diff --git a/msgalert.go b/msgalert.go index fdb1d66d..2bd45dfd 100644 --- a/msgalert.go +++ b/msgalert.go @@ -5,9 +5,317 @@ package btcwire import ( + "bytes" + "fmt" "io" ) +// MsgAlert contains a payload and a signature: +// +// =============================================== +// | Field | Data Type | Size | +// =============================================== +// | payload | []uchar | ? | +// ----------------------------------------------- +// | signature | []uchar | ? | +// ----------------------------------------------- +// +// Here payload is an Alert serialized into a byte array to ensure that +// versions using incompatible alert formats can still relay +// alerts among one another. +// +// An Alert is the payload deserialized as follows: +// +// =============================================== +// | Field | Data Type | Size | +// =============================================== +// | Version | int32 | 4 | +// ----------------------------------------------- +// | RelayUntil | int64 | 8 | +// ----------------------------------------------- +// | Expiration | int64 | 8 | +// ----------------------------------------------- +// | ID | int32 | 4 | +// ----------------------------------------------- +// | Cancel | int32 | 4 | +// ----------------------------------------------- +// | SetCancel | set | ? | +// ----------------------------------------------- +// | MinVer | int32 | 4 | +// ----------------------------------------------- +// | MaxVer | int32 | 4 | +// ----------------------------------------------- +// | SetSubVer | set | ? | +// ----------------------------------------------- +// | Priority | int32 | 4 | +// ----------------------------------------------- +// | Comment | string | ? | +// ----------------------------------------------- +// | StatusBar | string | ? | +// ----------------------------------------------- +// | Reserved | string | ? | +// ----------------------------------------------- +// | Total (Fixed) | 45 | +// ----------------------------------------------- +// +// note: +// * string is a VarString i.e VarInt length followed by the string itself +// * set is a VarInt followed by as many number of strings +// * set is a VarInt followed by as many number of ints +// * fixedAlertSize = 40 + 5*min(VarInt) = 40 + 5*1 = 45 + +// Now we can define bounds on Alert size, SetCancel and SetSubVer + +// Fixed size of the alert payload +const fixedAlertSize = 45 + +// Max size of the ECDSA signature +// note: since this size is fixed and < 255, size of VarInt +// required = 1 (fits in uint8) +const maxSignatureSize = 72 + +// Maximum size of the alert +// MessagePayload = VarInt(Alert) + Alert + VarInt(Signature) + Signature +// maxMessagePayload = maxAlertSize + max(VarInt) + maxSignatureSize + 1 +const maxAlertSize = maxMessagePayload - maxSignatureSize - MaxVarIntPayload - 1 + +// Maximum number of Cancel IDs from SetCancel to read +// maxAlertSize = fixedAlertSize + max(SetCancel) + max(SetSubVer) + 3*(string) +// for caculating maximum number of Cancel IDs, set all other variable sizes to 0 +// maxAlertSize = fixedAlertSize + (MaxVarIntPayload-1) + x*sizeOf(int32) +// x = (maxAlertSize - fixedAlertSize - MaxVarIntPayload + 1) / 4 +const maxCountSetCancel = (maxAlertSize - fixedAlertSize - MaxVarIntPayload + 1) / 4 + +// Maximum number of subversions from SetSubVer to read +// maxAlertSize = fixedAlertSize + max(SetCancel) + max(SetSubVer) + 3*(string) +// for caculating maximum number of subversions, set all other variable sizes to 0 +// maxAlertSize = fixedAlertSize + (MaxVarIntPayload-1) + x*sizeOf(string) +// x = (maxAlertSize - fixedAlertSize - MaxVarIntPayload + 1) / sizeOf(string) +// subversion would typically be something like "/Satoshi:0.7.2/" (15 bytes) +// so assuming < 255 bytes, sizeOf(string) = sizeOf(uint8) + 255 = 256 +const maxCountSetSubVer = (maxAlertSize - fixedAlertSize - MaxVarIntPayload + 1) / 256 + +// Alert contains the data deserialized from the MsgAlert payload +type Alert struct { + + // Alert format version + Version int32 + + // Timestamp beyond which nodes should stop relaying this alert + RelayUntil int64 + + // Timestamp beyond which this alert is no longer in effect and + // should be ignored + Expiration int64 + + // A unique ID number for this alert + ID int32 + + // All alerts with an ID less than or equal to this number should + // cancelled, deleted and not accepted in the future + Cancel int32 + + // All alert IDs contained in this set should be cancelled as above + SetCancel []int32 + + // This alert only applies to versions greater than or equal to this + // version. Other versions should still relay it. + MinVer int32 + + // This alert only applies to versions less than or equal to this version. + // Other versions should still relay it. + MaxVer int32 + + // If this set contains any elements, then only nodes that have their + // subVer contained in this set are affected by the alert. Other versions + // should still relay it. + SetSubVer []string + + // Relative priority compared to other alerts + Priority int32 + + // A comment on the alert that is not displayed + Comment string + + // The alert message that is displayed to the user + StatusBar string + + // Reserved + Reserved string +} + +// NewAlert returns an new Alert with values provided +func NewAlert(version int32, relayuntil int64, expiration int64, + id int32, cancel int32, setcancel []int32, minver int32, + maxver int32, setsubver []string, priority int32, comment string, + statusbar string, reserved string) *Alert { + return &Alert{ + Version: version, + RelayUntil: relayuntil, + Expiration: expiration, + ID: id, + Cancel: cancel, + SetCancel: setcancel, + MinVer: minver, + MaxVer: maxver, + SetSubVer: setsubver, + Priority: priority, + Comment: comment, + StatusBar: statusbar, + Reserved: reserved, + } +} + +// NewAlertFromPayload returns an Alert with values deserialized +// from the serializedpayload +func NewAlertFromPayload(serializedpayload []byte, pver uint32) (*Alert, error) { + var alert Alert + r := bytes.NewReader(serializedpayload) + err := alert.Deserialize(r, pver) + if err != nil { + return nil, err + } + return &alert, nil +} + +// Serialize writes a serialized byte array of the Alert +func (alert *Alert) Serialize(w io.Writer, pver uint32) error { + err := writeElements(w, &alert.Version, + &alert.RelayUntil, &alert.Expiration, &alert.ID, &alert.Cancel) + if err != nil { + return err + } + + count := len(alert.SetCancel) + if count > maxCountSetCancel { + str := fmt.Sprintf("too many cancel alert IDs for alert "+ + "[count %v, max %v]", count, maxCountSetCancel) + return messageError("Alert.Serialize", str) + } + err = writeVarInt(w, pver, uint64(count)) + if err != nil { + return err + } + for i := 0; i < int(count); i++ { + err = writeElement(w, &alert.SetCancel[i]) + if err != nil { + return err + } + } + + err = writeElements(w, &alert.MinVer, &alert.MaxVer) + if err != nil { + return err + } + + count = len(alert.SetSubVer) + if count > maxCountSetSubVer { + str := fmt.Sprintf("too many sub versions for alert "+ + "[count %v, max %v]", count, maxCountSetSubVer) + return messageError("Alert.Serialize", str) + } + err = writeVarInt(w, pver, uint64(count)) + if err != nil { + return err + } + for i := 0; i < int(count); i++ { + err = writeVarString(w, pver, alert.SetSubVer[i]) + if err != nil { + return err + } + } + + err = writeElement(w, &alert.Priority) + if err != nil { + return err + } + err = writeVarString(w, pver, alert.Comment) + if err != nil { + return err + } + err = writeVarString(w, pver, alert.StatusBar) + if err != nil { + return err + } + err = writeVarString(w, pver, alert.Reserved) + if err != nil { + return err + } + return nil +} + +// Deserialize reads a byte array, deserializes +// it and updates the Alert +func (alert *Alert) Deserialize(r io.Reader, pver uint32) error { + err := readElements(r, &alert.Version, &alert.RelayUntil, + &alert.Expiration, &alert.ID, &alert.Cancel) + if err != nil { + return err + } + + // SetCancel: first read a VarInt that contains + // count - the number of Cancel IDs, then + // iterate count times and read them + count, err := readVarInt(r, pver) + if err != nil { + return err + } + if count > maxCountSetCancel { + str := fmt.Sprintf("too many cancel alert IDs for alert "+ + "[count %v, max %v]", count, maxCountSetCancel) + return messageError("Alert.Deserialize", str) + } + alert.SetCancel = make([]int32, count) + for i := 0; i < int(count); i++ { + err := readElement(r, &alert.SetCancel[i]) + if err != nil { + return err + } + } + + err = readElements(r, &alert.MinVer, &alert.MaxVer) + if err != nil { + return err + } + + // SetSubVer: similar to SetCancel + // but read count number of sub-version strings + count, err = readVarInt(r, pver) + if err != nil { + return err + } + if count > maxCountSetSubVer { + str := fmt.Sprintf("too many sub versions for alert "+ + "[count %v, max %v]", count, maxCountSetSubVer) + return messageError("Alert.Deserialize", str) + } + alert.SetSubVer = make([]string, count) + for i := 0; i < int(count); i++ { + alert.SetSubVer[i], err = readVarString(r, pver) + if err != nil { + return err + } + } + + err = readElement(r, &alert.Priority) + if err != nil { + return err + } + alert.Comment, err = readVarString(r, pver) + if err != nil { + return err + } + alert.StatusBar, err = readVarString(r, pver) + if err != nil { + return err + } + alert.Reserved, err = readVarString(r, pver) + if err != nil { + return err + } + return nil +} + // MsgAlert implements the Message interface and defines a bitcoin alert // message. // @@ -15,24 +323,36 @@ import ( // display if the signature matches the key. bitcoind/bitcoin-qt only checks // against a signature from the core developers. type MsgAlert struct { - // PayloadBlob is the alert payload serialized as a string so that the + // SerializedPayload is the alert payload serialized as a string so that the // version can change but the Alert can still be passed on by older // clients. - PayloadBlob string + SerializedPayload []byte // Signature is the ECDSA signature of the message. - Signature string + Signature []byte + + // Deserialized Payload + Payload *Alert } // BtcDecode decodes r using the bitcoin protocol encoding into the receiver. // This is part of the Message interface implementation. func (msg *MsgAlert) BtcDecode(r io.Reader, pver uint32) error { var err error - msg.PayloadBlob, err = readVarString(r, pver) + + msg.SerializedPayload, err = readVarBytes(r, pver, maxMessagePayload, + "alert serialized payload") if err != nil { return err } - msg.Signature, err = readVarString(r, pver) + + msg.Payload, err = NewAlertFromPayload(msg.SerializedPayload, pver) + if err != nil { + msg.Payload = nil + } + + msg.Signature, err = readVarBytes(r, pver, maxMessagePayload, + "alert signature") if err != nil { return err } @@ -44,15 +364,33 @@ func (msg *MsgAlert) BtcDecode(r io.Reader, pver uint32) error { // This is part of the Message interface implementation. func (msg *MsgAlert) BtcEncode(w io.Writer, pver uint32) error { var err error - err = writeVarString(w, pver, msg.PayloadBlob) + var serializedpayload []byte + if msg.Payload != nil { + // try to Serialize Payload if possible + r := new(bytes.Buffer) + err = msg.Payload.Serialize(r, pver) + if err != nil { + // Serialize failed - ignore & fallback + // to SerializedPayload + serializedpayload = msg.SerializedPayload + } else { + serializedpayload = r.Bytes() + } + } else { + serializedpayload = msg.SerializedPayload + } + slen := uint64(len(serializedpayload)) + if slen == 0 { + return messageError("MsgAlert.BtcEncode", "empty serialized payload") + } + err = writeVarBytes(w, pver, serializedpayload) if err != nil { return err } - err = writeVarString(w, pver, msg.Signature) + err = writeVarBytes(w, pver, msg.Signature) if err != nil { return err } - return nil } @@ -72,9 +410,10 @@ func (msg *MsgAlert) MaxPayloadLength(pver uint32) uint32 { // NewMsgAlert returns a new bitcoin alert message that conforms to the Message // interface. See MsgAlert for details. -func NewMsgAlert(payloadblob string, signature string) *MsgAlert { +func NewMsgAlert(serializedpayload []byte, signature []byte) *MsgAlert { return &MsgAlert{ - PayloadBlob: payloadblob, - Signature: signature, + SerializedPayload: serializedpayload, + Signature: signature, + Payload: nil, } } diff --git a/msgalert_test.go b/msgalert_test.go index 3e9e3ce5..1ae012c9 100644 --- a/msgalert_test.go +++ b/msgalert_test.go @@ -13,19 +13,19 @@ import ( "testing" ) -// TestAlert tests the MsgAlert API. -func TestAlert(t *testing.T) { +// TestMsgAlert tests the MsgAlert API. +func TestMsgAlert(t *testing.T) { pver := btcwire.ProtocolVersion - payloadblob := "some message" - signature := "some sig" + serializedpayload := []byte("some message") + signature := []byte("some sig") // Ensure we get the same payload and signature back out. - msg := btcwire.NewMsgAlert(payloadblob, signature) - if msg.PayloadBlob != payloadblob { - t.Errorf("NewMsgAlert: wrong payloadblob - got %v, want %v", - msg.PayloadBlob, payloadblob) + msg := btcwire.NewMsgAlert(serializedpayload, signature) + if !reflect.DeepEqual(msg.SerializedPayload, serializedpayload) { + t.Errorf("NewMsgAlert: wrong serializedpayload - got %v, want %v", + msg.SerializedPayload, serializedpayload) } - if msg.Signature != signature { + if !reflect.DeepEqual(msg.Signature, signature) { t.Errorf("NewMsgAlert: wrong signature - got %v, want %v", msg.Signature, signature) } @@ -46,14 +46,46 @@ func TestAlert(t *testing.T) { maxPayload, wantPayload) } - return + // Test BtcEncode with Payload == nil + var buf bytes.Buffer + err := msg.BtcEncode(&buf, pver) + if err != nil { + t.Error(err.Error()) + } + // expected = 0x0c + serializedpayload + 0x08 + signature + expectedBuf := append([]byte{0x0c}, serializedpayload...) + expectedBuf = append(expectedBuf, []byte{0x08}...) + expectedBuf = append(expectedBuf, signature...) + if !bytes.Equal(buf.Bytes(), expectedBuf) { + t.Errorf("BtcEncode got: %s want: %s", + spew.Sdump(buf.Bytes()), spew.Sdump(expectedBuf)) + } + + // Test BtcEncode with Payload != nil + // note: Payload is an empty Alert but not nil + msg.Payload = new(btcwire.Alert) + buf = *new(bytes.Buffer) + err = msg.BtcEncode(&buf, pver) + if err != nil { + t.Error(err.Error()) + } + // empty Alert is 45 null bytes, see Alert comments + // for details + // expected = 0x2d + 45*0x00 + 0x08 + signature + expectedBuf = append([]byte{0x2d}, bytes.Repeat([]byte{0x00}, 45)...) + expectedBuf = append(expectedBuf, []byte{0x08}...) + expectedBuf = append(expectedBuf, signature...) + if !bytes.Equal(buf.Bytes(), expectedBuf) { + t.Errorf("BtcEncode got: %s want: %s", + spew.Sdump(buf.Bytes()), spew.Sdump(expectedBuf)) + } } -// TestAlertWire tests the MsgAlert wire encode and decode for various protocol +// TestMsgAlertWire tests the MsgAlert wire encode and decode for various protocol // versions. -func TestAlertWire(t *testing.T) { - baseAlert := btcwire.NewMsgAlert("some payload", "somesig") - baseAlertEncoded := []byte{ +func TestMsgAlertWire(t *testing.T) { + baseMsgAlert := btcwire.NewMsgAlert([]byte("some payload"), []byte("somesig")) + baseMsgAlertEncoded := []byte{ 0x0c, // Varint for payload length 0x73, 0x6f, 0x6d, 0x65, 0x20, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, // "some payload" @@ -69,41 +101,41 @@ func TestAlertWire(t *testing.T) { }{ // Latest protocol version. { - baseAlert, - baseAlert, - baseAlertEncoded, + baseMsgAlert, + baseMsgAlert, + baseMsgAlertEncoded, btcwire.ProtocolVersion, }, // Protocol version BIP0035Version. { - baseAlert, - baseAlert, - baseAlertEncoded, + baseMsgAlert, + baseMsgAlert, + baseMsgAlertEncoded, btcwire.BIP0035Version, }, // Protocol version BIP0031Version. { - baseAlert, - baseAlert, - baseAlertEncoded, + baseMsgAlert, + baseMsgAlert, + baseMsgAlertEncoded, btcwire.BIP0031Version, }, // Protocol version NetAddressTimeVersion. { - baseAlert, - baseAlert, - baseAlertEncoded, + baseMsgAlert, + baseMsgAlert, + baseMsgAlertEncoded, btcwire.NetAddressTimeVersion, }, // Protocol version MultipleAddressVersion. { - baseAlert, - baseAlert, - baseAlertEncoded, + baseMsgAlert, + baseMsgAlert, + baseMsgAlertEncoded, btcwire.MultipleAddressVersion, }, } @@ -139,13 +171,13 @@ func TestAlertWire(t *testing.T) { } } -// TestAlertWireErrors performs negative tests against wire encode and decode +// TestMsgAlertWireErrors performs negative tests against wire encode and decode // of MsgAlert to confirm error paths work correctly. -func TestAlertWireErrors(t *testing.T) { +func TestMsgAlertWireErrors(t *testing.T) { pver := btcwire.ProtocolVersion - baseAlert := btcwire.NewMsgAlert("some payload", "somesig") - baseAlertEncoded := []byte{ + baseMsgAlert := btcwire.NewMsgAlert([]byte("some payload"), []byte("somesig")) + baseMsgAlertEncoded := []byte{ 0x0c, // Varint for payload length 0x73, 0x6f, 0x6d, 0x65, 0x20, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, // "some payload" @@ -162,13 +194,13 @@ func TestAlertWireErrors(t *testing.T) { readErr error // Expected read error }{ // Force error in payload length. - {baseAlert, baseAlertEncoded, pver, 0, io.ErrShortWrite, io.EOF}, + {baseMsgAlert, baseMsgAlertEncoded, pver, 0, io.ErrShortWrite, io.EOF}, // Force error in payload. - {baseAlert, baseAlertEncoded, pver, 1, io.ErrShortWrite, io.EOF}, + {baseMsgAlert, baseMsgAlertEncoded, pver, 1, io.ErrShortWrite, io.EOF}, // Force error in signature length. - {baseAlert, baseAlertEncoded, pver, 13, io.ErrShortWrite, io.EOF}, + {baseMsgAlert, baseMsgAlertEncoded, pver, 13, io.ErrShortWrite, io.EOF}, // Force error in signature. - {baseAlert, baseAlertEncoded, pver, 14, io.ErrShortWrite, io.EOF}, + {baseMsgAlert, baseMsgAlertEncoded, pver, 14, io.ErrShortWrite, io.EOF}, } t.Logf("Running %d tests", len(tests)) @@ -212,4 +244,223 @@ func TestAlertWireErrors(t *testing.T) { } } } + + // Test Error on empty Payload + baseMsgAlert.SerializedPayload = []byte{} + w := new(bytes.Buffer) + err := baseMsgAlert.BtcEncode(w, pver) + if _, ok := err.(*btcwire.MessageError); !ok { + t.Errorf("MsgAlert.BtcEncode wrong error got: %T, want: %T", + err, btcwire.MessageError{}) + } + + // Test Payload Serialize error + // overflow the max number of elements in SetCancel + baseMsgAlert.Payload = new(btcwire.Alert) + baseMsgAlert.Payload.SetCancel = make([]int32, btcwire.MaxCountSetCancel+1) + buf := *new(bytes.Buffer) + err = baseMsgAlert.BtcEncode(&buf, pver) + if _, ok := err.(*btcwire.MessageError); !ok { + t.Errorf("MsgAlert.BtcEncode wrong error got: %T, want: %T", + err, btcwire.MessageError{}) + } + + // overflow the max number of elements in SetSubVer + baseMsgAlert.Payload = new(btcwire.Alert) + baseMsgAlert.Payload.SetSubVer = make([]string, btcwire.MaxCountSetSubVer+1) + buf = *new(bytes.Buffer) + err = baseMsgAlert.BtcEncode(&buf, pver) + if _, ok := err.(*btcwire.MessageError); !ok { + t.Errorf("MsgAlert.BtcEncode wrong error got: %T, want: %T", + err, btcwire.MessageError{}) + } +} + +// TestAlert tests serialization and deserialization +// of the payload to Alert +func TestAlert(t *testing.T) { + pver := btcwire.ProtocolVersion + alert := btcwire.NewAlert( + 1, 1337093712, 1368628812, 1015, + 1013, []int32{1014}, 0, 40599, []string{"/Satoshi:0.7.2/"}, 5000, "", + "URGENT: upgrade required, see http://bitcoin.org/dos for details", "", + ) + w := new(bytes.Buffer) + err := alert.Serialize(w, pver) + if err != nil { + t.Error(err.Error()) + } + serializedpayload := w.Bytes() + newAlert, err := btcwire.NewAlertFromPayload(serializedpayload, pver) + if err != nil { + t.Error(err.Error()) + } + + if alert.Version != newAlert.Version { + t.Errorf("NewAlertFromPayload: wrong Version - got %v, want %v ", + alert.Version, newAlert.Version) + } + if alert.RelayUntil != newAlert.RelayUntil { + t.Errorf("NewAlertFromPayload: wrong RelayUntil - got %v, want %v ", + alert.RelayUntil, newAlert.RelayUntil) + } + if alert.Expiration != newAlert.Expiration { + t.Errorf("NewAlertFromPayload: wrong Expiration - got %v, want %v ", + alert.Expiration, newAlert.Expiration) + } + if alert.ID != newAlert.ID { + t.Errorf("NewAlertFromPayload: wrong ID - got %v, want %v ", + alert.ID, newAlert.ID) + } + if alert.Cancel != newAlert.Cancel { + t.Errorf("NewAlertFromPayload: wrong Cancel - got %v, want %v ", + alert.Cancel, newAlert.Cancel) + } + if len(alert.SetCancel) != len(newAlert.SetCancel) { + t.Errorf("NewAlertFromPayload: wrong number of SetCancel - got %v, want %v ", + len(alert.SetCancel), len(newAlert.SetCancel)) + } + for i := 0; i < len(alert.SetCancel); i++ { + if alert.SetCancel[i] != newAlert.SetCancel[i] { + t.Errorf("NewAlertFromPayload: wrong SetCancel[%v] - got %v, want %v ", + len(alert.SetCancel), alert.SetCancel[i], newAlert.SetCancel[i]) + } + } + if alert.MinVer != newAlert.MinVer { + t.Errorf("NewAlertFromPayload: wrong MinVer - got %v, want %v ", + alert.MinVer, newAlert.MinVer) + } + if alert.MaxVer != newAlert.MaxVer { + t.Errorf("NewAlertFromPayload: wrong MaxVer - got %v, want %v ", + alert.MaxVer, newAlert.MaxVer) + } + if len(alert.SetSubVer) != len(newAlert.SetSubVer) { + t.Errorf("NewAlertFromPayload: wrong number of SetSubVer - got %v, want %v ", + len(alert.SetSubVer), len(newAlert.SetSubVer)) + } + for i := 0; i < len(alert.SetSubVer); i++ { + if alert.SetSubVer[i] != newAlert.SetSubVer[i] { + t.Errorf("NewAlertFromPayload: wrong SetSubVer[%v] - got %v, want %v ", + len(alert.SetSubVer), alert.SetSubVer[i], newAlert.SetSubVer[i]) + } + } + if alert.Priority != newAlert.Priority { + t.Errorf("NewAlertFromPayload: wrong Priority - got %v, want %v ", + alert.Priority, newAlert.Priority) + } + if alert.Comment != newAlert.Comment { + t.Errorf("NewAlertFromPayload: wrong Comment - got %v, want %v ", + alert.Comment, newAlert.Comment) + } + if alert.StatusBar != newAlert.StatusBar { + t.Errorf("NewAlertFromPayload: wrong StatusBar - got %v, want %v ", + alert.StatusBar, newAlert.StatusBar) + } + if alert.Reserved != newAlert.Reserved { + t.Errorf("NewAlertFromPayload: wrong Reserved - got %v, want %v ", + alert.Reserved, newAlert.Reserved) + } +} + +// TestAlertErrors performs negative tests against payload serialization, +// deserialization of Alert to confirm error paths work correctly. +func TestAlertErrors(t *testing.T) { + pver := btcwire.ProtocolVersion + + baseAlert := btcwire.NewAlert( + 1, 1337093712, 1368628812, 1015, + 1013, []int32{1014}, 0, 40599, []string{"/Satoshi:0.7.2/"}, 5000, "", + "URGENT", "", + ) + baseAlertEncoded := []byte{ + 0x01, 0x00, 0x00, 0x00, 0x50, 0x6e, 0xb2, 0x4f, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x9e, 0x93, 0x51, //|....Pn.O....L..Q| + 0x00, 0x00, 0x00, 0x00, 0xf7, 0x03, 0x00, 0x00, 0xf5, 0x03, 0x00, 0x00, 0x01, 0xf6, 0x03, 0x00, //|................| + 0x00, 0x00, 0x00, 0x00, 0x00, 0x97, 0x9e, 0x00, 0x00, 0x01, 0x0f, 0x2f, 0x53, 0x61, 0x74, 0x6f, //|.........../Sato| + 0x73, 0x68, 0x69, 0x3a, 0x30, 0x2e, 0x37, 0x2e, 0x32, 0x2f, 0x88, 0x13, 0x00, 0x00, 0x00, 0x06, //|shi:0.7.2/......| + 0x55, 0x52, 0x47, 0x45, 0x4e, 0x54, 0x00, //|URGENT.| + } + tests := []struct { + in *btcwire.Alert // Value to encode + buf []byte // Wire encoding + pver uint32 // Protocol version for wire encoding + max int // Max size of fixed buffer to induce errors + writeErr error // Expected write error + readErr error // Expected read error + }{ + // Force error in Version + {baseAlert, baseAlertEncoded, pver, 0, io.ErrShortWrite, io.EOF}, + // Force error in SetCancel VarInt. + {baseAlert, baseAlertEncoded, pver, 28, io.ErrShortWrite, io.EOF}, + // Force error in SetCancel ints. + {baseAlert, baseAlertEncoded, pver, 29, io.ErrShortWrite, io.EOF}, + // Force error in MinVer + {baseAlert, baseAlertEncoded, pver, 40, io.ErrShortWrite, io.EOF}, + // Force error in SetSubVer string VarInt. + {baseAlert, baseAlertEncoded, pver, 41, io.ErrShortWrite, io.EOF}, + // Force error in SetSubVer strings. + {baseAlert, baseAlertEncoded, pver, 48, io.ErrShortWrite, io.EOF}, + // Force error in Priority + {baseAlert, baseAlertEncoded, pver, 60, io.ErrShortWrite, io.EOF}, + // Force error in Comment string. + {baseAlert, baseAlertEncoded, pver, 62, io.ErrShortWrite, io.EOF}, + // Force error in StatusBar string. + {baseAlert, baseAlertEncoded, pver, 64, io.ErrShortWrite, io.EOF}, + // Force error in Reserved string. + {baseAlert, baseAlertEncoded, pver, 70, io.ErrShortWrite, io.EOF}, + } + + t.Logf("Running %d tests", len(tests)) + for i, test := range tests { + w := newFixedWriter(test.max) + err := test.in.Serialize(w, test.pver) + if reflect.TypeOf(err) != reflect.TypeOf(test.writeErr) { + t.Errorf("Alert.Serialize #%d wrong error got: %v, want: %v", + i, err, test.writeErr) + continue + } + + var alert btcwire.Alert + r := newFixedReader(test.max, test.buf) + err = alert.Deserialize(r, test.pver) + if reflect.TypeOf(err) != reflect.TypeOf(test.readErr) { + t.Errorf("Alert.Deserialize #%d wrong error got: %v, want: %v", + i, err, test.readErr) + continue + } + } + + // overflow the max number of elements in SetCancel + // maxCountSetCancel + 1 == 8388575 == \xdf\xff\x7f\x00 + // replace bytes 29-33 + badAlertEncoded := []byte{ + 0x01, 0x00, 0x00, 0x00, 0x50, 0x6e, 0xb2, 0x4f, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x9e, 0x93, 0x51, //|....Pn.O....L..Q| + 0x00, 0x00, 0x00, 0x00, 0xf7, 0x03, 0x00, 0x00, 0xf5, 0x03, 0x00, 0x00, 0xfe, 0xdf, 0xff, 0x7f, //|................| + 0x00, 0x00, 0x00, 0x00, 0x00, 0x97, 0x9e, 0x00, 0x00, 0x01, 0x0f, 0x2f, 0x53, 0x61, 0x74, 0x6f, //|.........../Sato| + 0x73, 0x68, 0x69, 0x3a, 0x30, 0x2e, 0x37, 0x2e, 0x32, 0x2f, 0x88, 0x13, 0x00, 0x00, 0x00, 0x06, //|shi:0.7.2/......| + 0x55, 0x52, 0x47, 0x45, 0x4e, 0x54, 0x00, //|URGENT.| + } + var alert btcwire.Alert + r := bytes.NewBuffer(badAlertEncoded) + err := alert.Deserialize(r, pver) + if _, ok := err.(*btcwire.MessageError); !ok { + t.Errorf("Alert.Deserialize wrong error got: %T, want: %T", + err, btcwire.MessageError{}) + } + + // overflow the max number of elements in SetSubVer + // maxCountSetSubVer + 1 == 131071 + 1 == \x00\x00\x02\x00 + // replace bytes 42-46 + badAlertEncoded = []byte{ + 0x01, 0x00, 0x00, 0x00, 0x50, 0x6e, 0xb2, 0x4f, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x9e, 0x93, 0x51, //|....Pn.O....L..Q| + 0x00, 0x00, 0x00, 0x00, 0xf7, 0x03, 0x00, 0x00, 0xf5, 0x03, 0x00, 0x00, 0x01, 0xf6, 0x03, 0x00, //|................| + 0x00, 0x00, 0x00, 0x00, 0x00, 0x97, 0x9e, 0x00, 0x00, 0xfe, 0x00, 0x00, 0x02, 0x00, 0x74, 0x6f, //|.........../Sato| + 0x73, 0x68, 0x69, 0x3a, 0x30, 0x2e, 0x37, 0x2e, 0x32, 0x2f, 0x88, 0x13, 0x00, 0x00, 0x00, 0x06, //|shi:0.7.2/......| + 0x55, 0x52, 0x47, 0x45, 0x4e, 0x54, 0x00, //|URGENT.| + } + r = bytes.NewBuffer(badAlertEncoded) + err = alert.Deserialize(r, pver) + if _, ok := err.(*btcwire.MessageError); !ok { + t.Errorf("Alert.Deserialize wrong error got: %T, want: %T", + err, btcwire.MessageError{}) + } } diff --git a/msgtx.go b/msgtx.go index e31040b5..fc93faa4 100644 --- a/msgtx.go +++ b/msgtx.go @@ -470,28 +470,12 @@ func readTxIn(r io.Reader, pver uint32, version uint32, ti *TxIn) error { } ti.PreviousOutpoint = op - count, err := readVarInt(r, pver) + ti.SignatureScript, err = readVarBytes(r, pver, maxMessagePayload, + "transaction input signature script") if err != nil { return err } - // Prevent signature script larger than the max message size. It would - // be possible to cause memory exhaustion and panics without a sane - // upper bound on this count. - if count > uint64(maxMessagePayload) { - str := fmt.Sprintf("transaction input signature script is "+ - "larger than max message size [count %d, max %d]", - count, maxMessagePayload) - return messageError("MsgTx.BtcDecode", str) - } - - b := make([]byte, count) - _, err = io.ReadFull(r, b) - if err != nil { - return err - } - ti.SignatureScript = b - var buf [4]byte _, err = io.ReadFull(r, buf[:]) if err != nil { @@ -510,13 +494,7 @@ func writeTxIn(w io.Writer, pver uint32, version uint32, ti *TxIn) error { return err } - slen := uint64(len(ti.SignatureScript)) - err = writeVarInt(w, pver, slen) - if err != nil { - return err - } - - _, err = w.Write(ti.SignatureScript) + err = writeVarBytes(w, pver, ti.SignatureScript) if err != nil { return err } @@ -541,28 +519,12 @@ func readTxOut(r io.Reader, pver uint32, version uint32, to *TxOut) error { } to.Value = int64(binary.LittleEndian.Uint64(buf[:])) - count, err := readVarInt(r, pver) + to.PkScript, err = readVarBytes(r, pver, maxMessagePayload, + "transaction output public key script") if err != nil { return err } - // Prevent public key script larger than the max message size. It would - // be possible to cause memory exhaustion and panics without a sane - // upper bound on this count. - if count > uint64(maxMessagePayload) { - str := fmt.Sprintf("transaction output public key script is "+ - "larger than max message size [count %d, max %d]", - count, maxMessagePayload) - return messageError("MsgTx.BtcDecode", str) - } - - b := make([]byte, count) - _, err = io.ReadFull(r, b) - if err != nil { - return err - } - to.PkScript = b - return nil } @@ -576,16 +538,9 @@ func writeTxOut(w io.Writer, pver uint32, version uint32, to *TxOut) error { return err } - pkLen := uint64(len(to.PkScript)) - err = writeVarInt(w, pver, pkLen) + err = writeVarBytes(w, pver, to.PkScript) if err != nil { return err } - - _, err = w.Write(to.PkScript) - if err != nil { - return err - } - return nil }