diff --git a/psbt/utils.go b/psbt/utils.go index c002b57..494d040 100644 --- a/psbt/utils.go +++ b/psbt/utils.go @@ -301,6 +301,15 @@ func SumUtxoInputValues(packet *Packet) (int64, error) { // the UTXO resides in. utxOuts := in.NonWitnessUtxo.TxOut txIn := packet.UnsignedTx.TxIn[idx] + + // Check that utxOuts actually has enough space to + // contain the previous outpoint's index. + opIdx := txIn.PreviousOutPoint.Index + if opIdx >= uint32(len(utxOuts)) { + return 0, fmt.Errorf("input %d has malformed "+ + "TxOut field", idx) + } + inputSum += utxOuts[txIn.PreviousOutPoint.Index].Value default: diff --git a/psbt/utils_test.go b/psbt/utils_test.go index 8d6d2af..90593ff 100644 --- a/psbt/utils_test.go +++ b/psbt/utils_test.go @@ -53,6 +53,24 @@ func TestSumUtxoInputValues(t *testing.T) { if sum != (1234 + 6543) { t.Fatalf("unexpected sum, got %d wanted %d", sum, 1234+6543) } + + // Create a malformed packet where NonWitnessUtxo.TxOut does not + // contain the index specified by the PreviousOutPoint in the + // packet's Unsigned.TxIn field. + badOp := []*wire.OutPoint{{}, {Index: 500}} + malformedPacket, err := New(badOp, nil, 2, 0, []uint32{0, 0}) + if err != nil { + t.Fatalf("could not create malformed packet: %v", err) + } + malformedPacket.Inputs[0].WitnessUtxo = &wire.TxOut{Value: 1234} + malformedPacket.Inputs[1].NonWitnessUtxo = &wire.MsgTx{ + TxOut: []*wire.TxOut{{Value: 6543}}, + } + + _, err = SumUtxoInputValues(malformedPacket) + if err == nil { + t.Fatalf("expected sum of malformed packet to fail") + } } func TestTxOutsEqual(t *testing.T) {