psbt: add BIP 69 in-place sort

This commit is contained in:
Oliver Gugger 2020-06-19 14:44:50 +02:00
parent 488d2cc834
commit 0b85b11dcc
No known key found for this signature in database
GPG key ID: 8E4256593F177720
2 changed files with 269 additions and 0 deletions

102
psbt/sort.go Normal file
View file

@ -0,0 +1,102 @@
package psbt
import (
"bytes"
"sort"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// InPlaceSort modifies the passed packet's wire TX inputs and outputs to be
// sorted based on BIP 69. The sorting happens in a way that the packet's
// partial inputs and outputs are also modified to match the sorted TxIn and
// TxOuts of the wire transaction.
//
// WARNING: This function must NOT be called with packages that already contain
// (partial) witness data since it will mutate the transaction if it's not
// already sorted. This can cause issues if you mutate a tx in a block, for
// example, which would invalidate the block. It could also cause cached hashes,
// such as in a btcutil.Tx to become invalidated.
//
// The function should only be used if the caller is creating the transaction or
// is otherwise 100% positive mutating will not cause adverse affects due to
// other dependencies.
func InPlaceSort(packet *Packet) error {
// To make sure we don't run into any nil pointers or array index
// violations during sorting, do a very basic sanity check first.
err := VerifyInputOutputLen(packet, false, false)
if err != nil {
return err
}
sort.Sort(&sortableInputs{p: packet})
sort.Sort(&sortableOutputs{p: packet})
return nil
}
// sortableInputs is a simple wrapper around a packet that implements the
// sort.Interface for sorting the wire and partial inputs of a packet.
type sortableInputs struct {
p *Packet
}
// sortableOutputs is a simple wrapper around a packet that implements the
// sort.Interface for sorting the wire and partial outputs of a packet.
type sortableOutputs struct {
p *Packet
}
// For sortableInputs and sortableOutputs, three functions are needed to make
// them sortable with sort.Sort() -- Len, Less, and Swap.
// Len and Swap are trivial. Less is BIP 69 specific.
func (s *sortableInputs) Len() int { return len(s.p.UnsignedTx.TxIn) }
func (s sortableOutputs) Len() int { return len(s.p.UnsignedTx.TxOut) }
// Swap swaps two inputs.
func (s *sortableInputs) Swap(i, j int) {
tx := s.p.UnsignedTx
tx.TxIn[i], tx.TxIn[j] = tx.TxIn[j], tx.TxIn[i]
s.p.Inputs[i], s.p.Inputs[j] = s.p.Inputs[j], s.p.Inputs[i]
}
// Swap swaps two outputs.
func (s *sortableOutputs) Swap(i, j int) {
tx := s.p.UnsignedTx
tx.TxOut[i], tx.TxOut[j] = tx.TxOut[j], tx.TxOut[i]
s.p.Outputs[i], s.p.Outputs[j] = s.p.Outputs[j], s.p.Outputs[i]
}
// Less is the input comparison function. First sort based on input hash
// (reversed / rpc-style), then index.
func (s *sortableInputs) Less(i, j int) bool {
ins := s.p.UnsignedTx.TxIn
// Input hashes are the same, so compare the index.
ihash := ins[i].PreviousOutPoint.Hash
jhash := ins[j].PreviousOutPoint.Hash
if ihash == jhash {
return ins[i].PreviousOutPoint.Index <
ins[j].PreviousOutPoint.Index
}
// At this point, the hashes are not equal, so reverse them to
// big-endian and return the result of the comparison.
const hashSize = chainhash.HashSize
for b := 0; b < hashSize/2; b++ {
ihash[b], ihash[hashSize-1-b] = ihash[hashSize-1-b], ihash[b]
jhash[b], jhash[hashSize-1-b] = jhash[hashSize-1-b], jhash[b]
}
return bytes.Compare(ihash[:], jhash[:]) == -1
}
// Less is the output comparison function. First sort based on amount (smallest
// first), then PkScript.
func (s *sortableOutputs) Less(i, j int) bool {
outs := s.p.UnsignedTx.TxOut
if outs[i].Value == outs[j].Value {
return bytes.Compare(outs[i].PkScript, outs[j].PkScript) < 0
}
return outs[i].Value < outs[j].Value
}

167
psbt/sort_test.go Normal file
View file

@ -0,0 +1,167 @@
package psbt
import (
"reflect"
"testing"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
)
func TestInPlaceSort(t *testing.T) {
testCases := []struct {
name string
packet *Packet
expectedTxIn []*wire.TxIn
expectedTxOut []*wire.TxOut
expectedPIn []PInput
expectedPOut []POutput
expectErr bool
}{{
name: "packet nil",
packet: nil,
expectErr: true,
}, {
name: "no inputs or outputs",
packet: &Packet{UnsignedTx: &wire.MsgTx{}},
expectErr: false,
}, {
name: "inputs only",
packet: &Packet{
UnsignedTx: &wire.MsgTx{
TxIn: []*wire.TxIn{{
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{99, 88},
Index: 7,
},
}, {
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{77, 88},
Index: 12,
},
}, {
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{77, 88},
Index: 7,
},
}},
},
// Abuse the SighashType as an index to make sure the
// partial inputs are also sorted together with the wire
// inputs.
Inputs: []PInput{{
SighashType: 0,
}, {
SighashType: 1,
}, {
SighashType: 2,
}},
},
expectedTxIn: []*wire.TxIn{{
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{77, 88},
Index: 7,
},
}, {
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{77, 88},
Index: 12,
},
}, {
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{99, 88},
Index: 7,
},
}},
expectedPIn: []PInput{{
SighashType: 2,
}, {
SighashType: 1,
}, {
SighashType: 0,
}},
expectErr: false,
}, {
name: "outputs only",
packet: &Packet{
UnsignedTx: &wire.MsgTx{
TxOut: []*wire.TxOut{{
PkScript: []byte{99, 88},
Value: 7,
}, {
PkScript: []byte{77, 88},
Value: 12,
}, {
PkScript: []byte{77, 88},
Value: 7,
}},
},
// Abuse the RedeemScript as an index to make sure the
// partial inputs are also sorted together with the wire
// inputs.
Outputs: []POutput{{
RedeemScript: []byte{0},
}, {
RedeemScript: []byte{1},
}, {
RedeemScript: []byte{2},
}},
},
expectedTxOut: []*wire.TxOut{{
PkScript: []byte{77, 88},
Value: 7,
}, {
PkScript: []byte{99, 88},
Value: 7,
}, {
PkScript: []byte{77, 88},
Value: 12,
}},
expectedPOut: []POutput{{
RedeemScript: []byte{2},
}, {
RedeemScript: []byte{0},
}, {
RedeemScript: []byte{1},
}},
expectErr: false,
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
p := tc.packet
err := InPlaceSort(p)
if (tc.expectErr && err == nil) ||
(!tc.expectErr && err != nil) {
t.Fatalf("got error '%v' but wanted it to be "+
"nil: %v", err, tc.expectErr)
}
// Don't continue on this special test case.
if p == nil {
return
}
tx := p.UnsignedTx
if !reflect.DeepEqual(tx.TxIn, tc.expectedTxIn) {
t.Fatalf("unexpected txin, got %#v wanted %#v",
tx.TxIn, tc.expectedTxIn)
}
if !reflect.DeepEqual(tx.TxOut, tc.expectedTxOut) {
t.Fatalf("unexpected txout, got %#v wanted %#v",
tx.TxOut, tc.expectedTxOut)
}
if !reflect.DeepEqual(p.Inputs, tc.expectedPIn) {
t.Fatalf("unexpected pin, got %#v wanted %#v",
p.Inputs, tc.expectedPIn)
}
if !reflect.DeepEqual(p.Outputs, tc.expectedPOut) {
t.Fatalf("unexpected pout, got %#v wanted %#v",
p.Inputs, tc.expectedPOut)
}
})
}
}