diff --git a/psbt/utils.go b/psbt/utils.go index a6f27e3..b52518c 100644 --- a/psbt/utils.go +++ b/psbt/utils.go @@ -8,6 +8,7 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "io" "sort" @@ -275,3 +276,138 @@ func readTxOut(txout []byte) (*wire.TxOut, error) { return wire.NewTxOut(int64(valueSer), scriptPubKey), nil } + +// SumUtxoInputValues tries to extract the sum of all inputs specified in the +// UTXO fields of the PSBT. An error is returned if an input is specified that +// does not contain any UTXO information. +func SumUtxoInputValues(packet *Packet) (int64, error) { + // We take the TX ins of the unsigned TX as the truth for how many + // inputs there should be, as the fields in the extra data part of the + // PSBT can be empty. + if len(packet.UnsignedTx.TxIn) != len(packet.Inputs) { + return 0, fmt.Errorf("TX input length doesn't match PSBT " + + "input length") + } + + inputSum := int64(0) + for idx, in := range packet.Inputs { + switch { + case in.WitnessUtxo != nil: + // Witness UTXOs only need to reference the TxOut. + inputSum += in.WitnessUtxo.Value + + case in.NonWitnessUtxo != nil: + // Non-witness UTXOs reference to the whole transaction + // the UTXO resides in. + utxOuts := in.NonWitnessUtxo.TxOut + txIn := packet.UnsignedTx.TxIn[idx] + inputSum += utxOuts[txIn.PreviousOutPoint.Index].Value + + default: + return 0, fmt.Errorf("input %d has no UTXO information", + idx) + } + } + return inputSum, nil +} + +// TxOutsEqual returns true if two transaction outputs are equal. +func TxOutsEqual(out1, out2 *wire.TxOut) bool { + if out1 == nil || out2 == nil { + return out1 == out2 + } + return out1.Value == out2.Value && + bytes.Equal(out1.PkScript, out2.PkScript) +} + +// VerifyOutputsEqual verifies that the two slices of transaction outputs are +// deep equal to each other. We do the length check and manual loop to provide +// better error messages to the user than just returning "not equal". +func VerifyOutputsEqual(outs1, outs2 []*wire.TxOut) error { + if len(outs1) != len(outs2) { + return fmt.Errorf("number of outputs are different") + } + for idx, out := range outs1 { + // There is a byte slice in the output so we can't use the + // equality operator. + if !TxOutsEqual(out, outs2[idx]) { + return fmt.Errorf("output %d is different", idx) + } + } + return nil +} + +// VerifyInputPrevOutpointsEqual verifies that the previous outpoints of the +// two slices of transaction inputs are deep equal to each other. We do the +// length check and manual loop to provide better error messages to the user +// than just returning "not equal". +func VerifyInputPrevOutpointsEqual(ins1, ins2 []*wire.TxIn) error { + if len(ins1) != len(ins2) { + return fmt.Errorf("number of inputs are different") + } + for idx, in := range ins1 { + if in.PreviousOutPoint != ins2[idx].PreviousOutPoint { + return fmt.Errorf("previous outpoint of input %d is "+ + "different", idx) + } + } + return nil +} + +// VerifyInputOutputLen makes sure a packet is non-nil, contains a non-nil wire +// transaction and that the wire input/output lengths match the partial input/ +// output lengths. A caller also can specify if they expect any inputs and/or +// outputs to be contained in the packet. +func VerifyInputOutputLen(packet *Packet, needInputs, needOutputs bool) error { + if packet == nil || packet.UnsignedTx == nil { + return fmt.Errorf("PSBT packet cannot be nil") + } + + if len(packet.UnsignedTx.TxIn) != len(packet.Inputs) { + return fmt.Errorf("invalid PSBT, wire inputs don't match " + + "partial inputs") + } + if len(packet.UnsignedTx.TxOut) != len(packet.Outputs) { + return fmt.Errorf("invalid PSBT, wire outputs don't match " + + "partial outputs") + } + + if needInputs && len(packet.UnsignedTx.TxIn) == 0 { + return fmt.Errorf("PSBT packet must contain at least one " + + "input") + } + if needOutputs && len(packet.UnsignedTx.TxOut) == 0 { + return fmt.Errorf("PSBT packet must contain at least one " + + "output") + } + + return nil +} + +// NewFromSignedTx is a utility function to create a packet from an +// already-signed transaction. Returned are: an unsigned transaction +// serialization, a list of scriptSigs, one per input, and a list of witnesses, +// one per input. +func NewFromSignedTx(tx *wire.MsgTx) (*Packet, [][]byte, + []wire.TxWitness, error) { + + scriptSigs := make([][]byte, 0, len(tx.TxIn)) + witnesses := make([]wire.TxWitness, 0, len(tx.TxIn)) + tx2 := tx.Copy() + + // Blank out signature info in inputs + for i, tin := range tx2.TxIn { + tin.SignatureScript = nil + scriptSigs = append(scriptSigs, tx.TxIn[i].SignatureScript) + tin.Witness = nil + witnesses = append(witnesses, tx.TxIn[i].Witness) + } + + // Outputs always contain: (value, scriptPubkey) so don't need + // amending. Now tx2 is tx with all signing data stripped out + unsignedPsbt, err := NewFromUnsignedTx(tx2) + if err != nil { + return nil, nil, nil, err + } + return unsignedPsbt, scriptSigs, witnesses, nil +} diff --git a/psbt/utils_test.go b/psbt/utils_test.go new file mode 100644 index 0000000..8d6d2af --- /dev/null +++ b/psbt/utils_test.go @@ -0,0 +1,352 @@ +package psbt + +import ( + "bytes" + "reflect" + "testing" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" +) + +func TestSumUtxoInputValues(t *testing.T) { + // Expect sum to fail for packet with non-matching txIn and PInputs. + tx := wire.NewMsgTx(2) + badPacket, err := NewFromUnsignedTx(tx) + if err != nil { + t.Fatalf("could not create packet from TX: %v", err) + } + badPacket.Inputs = append(badPacket.Inputs, PInput{}) + + _, err = SumUtxoInputValues(badPacket) + if err == nil { + t.Fatalf("expected sum of bad packet to fail") + } + + // Expect sum to fail if any inputs don't have UTXO information added. + op := []*wire.OutPoint{{}, {}} + noUtxoInfoPacket, err := New(op, nil, 2, 0, []uint32{0, 0}) + if err != nil { + t.Fatalf("could not create new packet: %v", err) + } + + _, err = SumUtxoInputValues(noUtxoInfoPacket) + if err == nil { + t.Fatalf("expected sum of missing UTXO info to fail") + } + + // Create a packet that is OK and contains both witness and non-witness + // UTXO information. + okPacket, err := New(op, nil, 2, 0, []uint32{0, 0}) + if err != nil { + t.Fatalf("could not create new packet: %v", err) + } + okPacket.Inputs[0].WitnessUtxo = &wire.TxOut{Value: 1234} + okPacket.Inputs[1].NonWitnessUtxo = &wire.MsgTx{ + TxOut: []*wire.TxOut{{Value: 6543}}, + } + + sum, err := SumUtxoInputValues(okPacket) + if err != nil { + t.Fatalf("could not sum input: %v", err) + } + if sum != (1234 + 6543) { + t.Fatalf("unexpected sum, got %d wanted %d", sum, 1234+6543) + } +} + +func TestTxOutsEqual(t *testing.T) { + testCases := []struct { + name string + out1 *wire.TxOut + out2 *wire.TxOut + expectEqual bool + }{{ + name: "both nil", + out1: nil, + out2: nil, + expectEqual: true, + }, { + name: "one nil", + out1: nil, + out2: &wire.TxOut{}, + expectEqual: false, + }, { + name: "both empty", + out1: &wire.TxOut{}, + out2: &wire.TxOut{}, + expectEqual: true, + }, { + name: "one pk script set", + out1: &wire.TxOut{}, + out2: &wire.TxOut{ + PkScript: []byte("foo"), + }, + expectEqual: false, + }, { + name: "both fully set", + out1: &wire.TxOut{ + Value: 1234, + PkScript: []byte("bar"), + }, + out2: &wire.TxOut{ + Value: 1234, + PkScript: []byte("bar"), + }, + expectEqual: true, + }} + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + result := TxOutsEqual(tc.out1, tc.out2) + if result != tc.expectEqual { + t.Fatalf("unexpected result, got %v wanted %v", + result, tc.expectEqual) + } + }) + } +} + +func TestVerifyOutputsEqual(t *testing.T) { + testCases := []struct { + name string + outs1 []*wire.TxOut + outs2 []*wire.TxOut + expectErr bool + }{{ + name: "both nil", + outs1: nil, + outs2: nil, + expectErr: false, + }, { + name: "one nil", + outs1: nil, + outs2: []*wire.TxOut{{}}, + expectErr: true, + }, { + name: "both empty", + outs1: []*wire.TxOut{{}}, + outs2: []*wire.TxOut{{}}, + expectErr: false, + }, { + name: "one pk script set", + outs1: []*wire.TxOut{{}}, + outs2: []*wire.TxOut{{ + PkScript: []byte("foo"), + }}, + expectErr: true, + }, { + name: "both fully set", + outs1: []*wire.TxOut{{ + Value: 1234, + PkScript: []byte("bar"), + }, {}}, + outs2: []*wire.TxOut{{ + Value: 1234, + PkScript: []byte("bar"), + }, {}}, + expectErr: false, + }} + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := VerifyOutputsEqual(tc.outs1, tc.outs2) + if (tc.expectErr && err == nil) || + (!tc.expectErr && err != nil) { + + t.Fatalf("got error '%v' but wanted it to be "+ + "nil: %v", err, tc.expectErr) + } + }) + } +} + +func TestVerifyInputPrevOutpointsEqual(t *testing.T) { + testCases := []struct { + name string + ins1 []*wire.TxIn + ins2 []*wire.TxIn + expectErr bool + }{{ + name: "both nil", + ins1: nil, + ins2: nil, + expectErr: false, + }, { + name: "one nil", + ins1: nil, + ins2: []*wire.TxIn{{}}, + expectErr: true, + }, { + name: "both empty", + ins1: []*wire.TxIn{{}}, + ins2: []*wire.TxIn{{}}, + expectErr: false, + }, { + name: "one previous output set", + ins1: []*wire.TxIn{{}}, + ins2: []*wire.TxIn{{ + PreviousOutPoint: wire.OutPoint{ + Hash: chainhash.Hash{11, 22, 33}, + Index: 7, + }, + }}, + expectErr: true, + }, { + name: "both fully set", + ins1: []*wire.TxIn{{ + PreviousOutPoint: wire.OutPoint{ + Hash: chainhash.Hash{11, 22, 33}, + Index: 7, + }, + }, {}}, + ins2: []*wire.TxIn{{ + PreviousOutPoint: wire.OutPoint{ + Hash: chainhash.Hash{11, 22, 33}, + Index: 7, + }, + }, {}}, + expectErr: false, + }} + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := VerifyInputPrevOutpointsEqual(tc.ins1, tc.ins2) + if (tc.expectErr && err == nil) || + (!tc.expectErr && err != nil) { + + t.Fatalf("got error '%v' but wanted it to be "+ + "nil: %v", err, tc.expectErr) + } + }) + } +} + +func TestVerifyInputOutputLen(t *testing.T) { + testCases := []struct { + name string + packet *Packet + needInputs bool + needOutputs bool + expectErr bool + }{{ + name: "packet nil", + packet: nil, + expectErr: true, + }, { + name: "wire tx nil", + packet: &Packet{}, + expectErr: true, + }, { + name: "both empty don't need outputs", + packet: &Packet{ + UnsignedTx: &wire.MsgTx{}, + }, + expectErr: false, + }, { + name: "both empty but need outputs", + packet: &Packet{ + UnsignedTx: &wire.MsgTx{}, + }, + needOutputs: true, + expectErr: true, + }, { + name: "both empty but need inputs", + packet: &Packet{ + UnsignedTx: &wire.MsgTx{}, + }, + needInputs: true, + expectErr: true, + }, { + name: "input len mismatch", + packet: &Packet{ + UnsignedTx: &wire.MsgTx{ + TxIn: []*wire.TxIn{{}}, + }, + }, + needInputs: true, + expectErr: true, + }, { + name: "output len mismatch", + packet: &Packet{ + UnsignedTx: &wire.MsgTx{ + TxOut: []*wire.TxOut{{}}, + }, + }, + needOutputs: true, + expectErr: true, + }, { + name: "all fully set", + packet: &Packet{ + UnsignedTx: &wire.MsgTx{ + TxIn: []*wire.TxIn{{}}, + TxOut: []*wire.TxOut{{}}, + }, + Inputs: []PInput{{}}, + Outputs: []POutput{{}}, + }, + needInputs: true, + needOutputs: true, + expectErr: false, + }} + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := VerifyInputOutputLen( + tc.packet, tc.needInputs, tc.needOutputs, + ) + if (tc.expectErr && err == nil) || + (!tc.expectErr && err != nil) { + + t.Fatalf("got error '%v' but wanted it to be "+ + "nil: %v", err, tc.expectErr) + } + }) + } +} + +func TestNewFromSignedTx(t *testing.T) { + orig := &wire.MsgTx{ + TxIn: []*wire.TxIn{{ + PreviousOutPoint: wire.OutPoint{}, + SignatureScript: []byte("script"), + Witness: [][]byte{[]byte("witness")}, + Sequence: 1234, + }}, + TxOut: []*wire.TxOut{{ + PkScript: []byte{77, 88}, + Value: 99, + }}, + } + + packet, scripts, witnesses, err := NewFromSignedTx(orig) + if err != nil { + t.Fatalf("could not create packet from signed TX: %v", err) + } + + tx := packet.UnsignedTx + expectedTxIn := []*wire.TxIn{{ + PreviousOutPoint: wire.OutPoint{}, + Sequence: 1234, + }} + if !reflect.DeepEqual(tx.TxIn, expectedTxIn) { + t.Fatalf("unexpected txin, got %#v wanted %#v", + tx.TxIn, expectedTxIn) + } + if !reflect.DeepEqual(tx.TxOut, orig.TxOut) { + t.Fatalf("unexpected txout, got %#v wanted %#v", + tx.TxOut, orig.TxOut) + } + if len(scripts) != 1 || !bytes.Equal(scripts[0], []byte("script")) { + t.Fatalf("script not extracted correctly") + } + if len(witnesses) != 1 || + !bytes.Equal(witnesses[0][0], []byte("witness")) { + + t.Fatalf("witness not extracted correctly") + } +}