psbt: add new utility functions

This commit is contained in:
Oliver Gugger 2020-06-19 13:33:43 +02:00
parent 8ec8bad266
commit 488d2cc834
No known key found for this signature in database
GPG key ID: 8E4256593F177720
2 changed files with 488 additions and 0 deletions

View file

@ -8,6 +8,7 @@ import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"sort"
@ -275,3 +276,138 @@ func readTxOut(txout []byte) (*wire.TxOut, error) {
return wire.NewTxOut(int64(valueSer), scriptPubKey), nil
}
// SumUtxoInputValues tries to extract the sum of all inputs specified in the
// UTXO fields of the PSBT. An error is returned if an input is specified that
// does not contain any UTXO information.
func SumUtxoInputValues(packet *Packet) (int64, error) {
// We take the TX ins of the unsigned TX as the truth for how many
// inputs there should be, as the fields in the extra data part of the
// PSBT can be empty.
if len(packet.UnsignedTx.TxIn) != len(packet.Inputs) {
return 0, fmt.Errorf("TX input length doesn't match PSBT " +
"input length")
}
inputSum := int64(0)
for idx, in := range packet.Inputs {
switch {
case in.WitnessUtxo != nil:
// Witness UTXOs only need to reference the TxOut.
inputSum += in.WitnessUtxo.Value
case in.NonWitnessUtxo != nil:
// Non-witness UTXOs reference to the whole transaction
// the UTXO resides in.
utxOuts := in.NonWitnessUtxo.TxOut
txIn := packet.UnsignedTx.TxIn[idx]
inputSum += utxOuts[txIn.PreviousOutPoint.Index].Value
default:
return 0, fmt.Errorf("input %d has no UTXO information",
idx)
}
}
return inputSum, nil
}
// TxOutsEqual returns true if two transaction outputs are equal.
func TxOutsEqual(out1, out2 *wire.TxOut) bool {
if out1 == nil || out2 == nil {
return out1 == out2
}
return out1.Value == out2.Value &&
bytes.Equal(out1.PkScript, out2.PkScript)
}
// VerifyOutputsEqual verifies that the two slices of transaction outputs are
// deep equal to each other. We do the length check and manual loop to provide
// better error messages to the user than just returning "not equal".
func VerifyOutputsEqual(outs1, outs2 []*wire.TxOut) error {
if len(outs1) != len(outs2) {
return fmt.Errorf("number of outputs are different")
}
for idx, out := range outs1 {
// There is a byte slice in the output so we can't use the
// equality operator.
if !TxOutsEqual(out, outs2[idx]) {
return fmt.Errorf("output %d is different", idx)
}
}
return nil
}
// VerifyInputPrevOutpointsEqual verifies that the previous outpoints of the
// two slices of transaction inputs are deep equal to each other. We do the
// length check and manual loop to provide better error messages to the user
// than just returning "not equal".
func VerifyInputPrevOutpointsEqual(ins1, ins2 []*wire.TxIn) error {
if len(ins1) != len(ins2) {
return fmt.Errorf("number of inputs are different")
}
for idx, in := range ins1 {
if in.PreviousOutPoint != ins2[idx].PreviousOutPoint {
return fmt.Errorf("previous outpoint of input %d is "+
"different", idx)
}
}
return nil
}
// VerifyInputOutputLen makes sure a packet is non-nil, contains a non-nil wire
// transaction and that the wire input/output lengths match the partial input/
// output lengths. A caller also can specify if they expect any inputs and/or
// outputs to be contained in the packet.
func VerifyInputOutputLen(packet *Packet, needInputs, needOutputs bool) error {
if packet == nil || packet.UnsignedTx == nil {
return fmt.Errorf("PSBT packet cannot be nil")
}
if len(packet.UnsignedTx.TxIn) != len(packet.Inputs) {
return fmt.Errorf("invalid PSBT, wire inputs don't match " +
"partial inputs")
}
if len(packet.UnsignedTx.TxOut) != len(packet.Outputs) {
return fmt.Errorf("invalid PSBT, wire outputs don't match " +
"partial outputs")
}
if needInputs && len(packet.UnsignedTx.TxIn) == 0 {
return fmt.Errorf("PSBT packet must contain at least one " +
"input")
}
if needOutputs && len(packet.UnsignedTx.TxOut) == 0 {
return fmt.Errorf("PSBT packet must contain at least one " +
"output")
}
return nil
}
// NewFromSignedTx is a utility function to create a packet from an
// already-signed transaction. Returned are: an unsigned transaction
// serialization, a list of scriptSigs, one per input, and a list of witnesses,
// one per input.
func NewFromSignedTx(tx *wire.MsgTx) (*Packet, [][]byte,
[]wire.TxWitness, error) {
scriptSigs := make([][]byte, 0, len(tx.TxIn))
witnesses := make([]wire.TxWitness, 0, len(tx.TxIn))
tx2 := tx.Copy()
// Blank out signature info in inputs
for i, tin := range tx2.TxIn {
tin.SignatureScript = nil
scriptSigs = append(scriptSigs, tx.TxIn[i].SignatureScript)
tin.Witness = nil
witnesses = append(witnesses, tx.TxIn[i].Witness)
}
// Outputs always contain: (value, scriptPubkey) so don't need
// amending. Now tx2 is tx with all signing data stripped out
unsignedPsbt, err := NewFromUnsignedTx(tx2)
if err != nil {
return nil, nil, nil, err
}
return unsignedPsbt, scriptSigs, witnesses, nil
}

352
psbt/utils_test.go Normal file
View file

@ -0,0 +1,352 @@
package psbt
import (
"bytes"
"reflect"
"testing"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
)
func TestSumUtxoInputValues(t *testing.T) {
// Expect sum to fail for packet with non-matching txIn and PInputs.
tx := wire.NewMsgTx(2)
badPacket, err := NewFromUnsignedTx(tx)
if err != nil {
t.Fatalf("could not create packet from TX: %v", err)
}
badPacket.Inputs = append(badPacket.Inputs, PInput{})
_, err = SumUtxoInputValues(badPacket)
if err == nil {
t.Fatalf("expected sum of bad packet to fail")
}
// Expect sum to fail if any inputs don't have UTXO information added.
op := []*wire.OutPoint{{}, {}}
noUtxoInfoPacket, err := New(op, nil, 2, 0, []uint32{0, 0})
if err != nil {
t.Fatalf("could not create new packet: %v", err)
}
_, err = SumUtxoInputValues(noUtxoInfoPacket)
if err == nil {
t.Fatalf("expected sum of missing UTXO info to fail")
}
// Create a packet that is OK and contains both witness and non-witness
// UTXO information.
okPacket, err := New(op, nil, 2, 0, []uint32{0, 0})
if err != nil {
t.Fatalf("could not create new packet: %v", err)
}
okPacket.Inputs[0].WitnessUtxo = &wire.TxOut{Value: 1234}
okPacket.Inputs[1].NonWitnessUtxo = &wire.MsgTx{
TxOut: []*wire.TxOut{{Value: 6543}},
}
sum, err := SumUtxoInputValues(okPacket)
if err != nil {
t.Fatalf("could not sum input: %v", err)
}
if sum != (1234 + 6543) {
t.Fatalf("unexpected sum, got %d wanted %d", sum, 1234+6543)
}
}
func TestTxOutsEqual(t *testing.T) {
testCases := []struct {
name string
out1 *wire.TxOut
out2 *wire.TxOut
expectEqual bool
}{{
name: "both nil",
out1: nil,
out2: nil,
expectEqual: true,
}, {
name: "one nil",
out1: nil,
out2: &wire.TxOut{},
expectEqual: false,
}, {
name: "both empty",
out1: &wire.TxOut{},
out2: &wire.TxOut{},
expectEqual: true,
}, {
name: "one pk script set",
out1: &wire.TxOut{},
out2: &wire.TxOut{
PkScript: []byte("foo"),
},
expectEqual: false,
}, {
name: "both fully set",
out1: &wire.TxOut{
Value: 1234,
PkScript: []byte("bar"),
},
out2: &wire.TxOut{
Value: 1234,
PkScript: []byte("bar"),
},
expectEqual: true,
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
result := TxOutsEqual(tc.out1, tc.out2)
if result != tc.expectEqual {
t.Fatalf("unexpected result, got %v wanted %v",
result, tc.expectEqual)
}
})
}
}
func TestVerifyOutputsEqual(t *testing.T) {
testCases := []struct {
name string
outs1 []*wire.TxOut
outs2 []*wire.TxOut
expectErr bool
}{{
name: "both nil",
outs1: nil,
outs2: nil,
expectErr: false,
}, {
name: "one nil",
outs1: nil,
outs2: []*wire.TxOut{{}},
expectErr: true,
}, {
name: "both empty",
outs1: []*wire.TxOut{{}},
outs2: []*wire.TxOut{{}},
expectErr: false,
}, {
name: "one pk script set",
outs1: []*wire.TxOut{{}},
outs2: []*wire.TxOut{{
PkScript: []byte("foo"),
}},
expectErr: true,
}, {
name: "both fully set",
outs1: []*wire.TxOut{{
Value: 1234,
PkScript: []byte("bar"),
}, {}},
outs2: []*wire.TxOut{{
Value: 1234,
PkScript: []byte("bar"),
}, {}},
expectErr: false,
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
err := VerifyOutputsEqual(tc.outs1, tc.outs2)
if (tc.expectErr && err == nil) ||
(!tc.expectErr && err != nil) {
t.Fatalf("got error '%v' but wanted it to be "+
"nil: %v", err, tc.expectErr)
}
})
}
}
func TestVerifyInputPrevOutpointsEqual(t *testing.T) {
testCases := []struct {
name string
ins1 []*wire.TxIn
ins2 []*wire.TxIn
expectErr bool
}{{
name: "both nil",
ins1: nil,
ins2: nil,
expectErr: false,
}, {
name: "one nil",
ins1: nil,
ins2: []*wire.TxIn{{}},
expectErr: true,
}, {
name: "both empty",
ins1: []*wire.TxIn{{}},
ins2: []*wire.TxIn{{}},
expectErr: false,
}, {
name: "one previous output set",
ins1: []*wire.TxIn{{}},
ins2: []*wire.TxIn{{
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{11, 22, 33},
Index: 7,
},
}},
expectErr: true,
}, {
name: "both fully set",
ins1: []*wire.TxIn{{
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{11, 22, 33},
Index: 7,
},
}, {}},
ins2: []*wire.TxIn{{
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{11, 22, 33},
Index: 7,
},
}, {}},
expectErr: false,
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
err := VerifyInputPrevOutpointsEqual(tc.ins1, tc.ins2)
if (tc.expectErr && err == nil) ||
(!tc.expectErr && err != nil) {
t.Fatalf("got error '%v' but wanted it to be "+
"nil: %v", err, tc.expectErr)
}
})
}
}
func TestVerifyInputOutputLen(t *testing.T) {
testCases := []struct {
name string
packet *Packet
needInputs bool
needOutputs bool
expectErr bool
}{{
name: "packet nil",
packet: nil,
expectErr: true,
}, {
name: "wire tx nil",
packet: &Packet{},
expectErr: true,
}, {
name: "both empty don't need outputs",
packet: &Packet{
UnsignedTx: &wire.MsgTx{},
},
expectErr: false,
}, {
name: "both empty but need outputs",
packet: &Packet{
UnsignedTx: &wire.MsgTx{},
},
needOutputs: true,
expectErr: true,
}, {
name: "both empty but need inputs",
packet: &Packet{
UnsignedTx: &wire.MsgTx{},
},
needInputs: true,
expectErr: true,
}, {
name: "input len mismatch",
packet: &Packet{
UnsignedTx: &wire.MsgTx{
TxIn: []*wire.TxIn{{}},
},
},
needInputs: true,
expectErr: true,
}, {
name: "output len mismatch",
packet: &Packet{
UnsignedTx: &wire.MsgTx{
TxOut: []*wire.TxOut{{}},
},
},
needOutputs: true,
expectErr: true,
}, {
name: "all fully set",
packet: &Packet{
UnsignedTx: &wire.MsgTx{
TxIn: []*wire.TxIn{{}},
TxOut: []*wire.TxOut{{}},
},
Inputs: []PInput{{}},
Outputs: []POutput{{}},
},
needInputs: true,
needOutputs: true,
expectErr: false,
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
err := VerifyInputOutputLen(
tc.packet, tc.needInputs, tc.needOutputs,
)
if (tc.expectErr && err == nil) ||
(!tc.expectErr && err != nil) {
t.Fatalf("got error '%v' but wanted it to be "+
"nil: %v", err, tc.expectErr)
}
})
}
}
func TestNewFromSignedTx(t *testing.T) {
orig := &wire.MsgTx{
TxIn: []*wire.TxIn{{
PreviousOutPoint: wire.OutPoint{},
SignatureScript: []byte("script"),
Witness: [][]byte{[]byte("witness")},
Sequence: 1234,
}},
TxOut: []*wire.TxOut{{
PkScript: []byte{77, 88},
Value: 99,
}},
}
packet, scripts, witnesses, err := NewFromSignedTx(orig)
if err != nil {
t.Fatalf("could not create packet from signed TX: %v", err)
}
tx := packet.UnsignedTx
expectedTxIn := []*wire.TxIn{{
PreviousOutPoint: wire.OutPoint{},
Sequence: 1234,
}}
if !reflect.DeepEqual(tx.TxIn, expectedTxIn) {
t.Fatalf("unexpected txin, got %#v wanted %#v",
tx.TxIn, expectedTxIn)
}
if !reflect.DeepEqual(tx.TxOut, orig.TxOut) {
t.Fatalf("unexpected txout, got %#v wanted %#v",
tx.TxOut, orig.TxOut)
}
if len(scripts) != 1 || !bytes.Equal(scripts[0], []byte("script")) {
t.Fatalf("script not extracted correctly")
}
if len(witnesses) != 1 ||
!bytes.Equal(witnesses[0][0], []byte("witness")) {
t.Fatalf("witness not extracted correctly")
}
}