psbt: add BIP 69 in-place sort
This commit is contained in:
parent
488d2cc834
commit
0b85b11dcc
2 changed files with 269 additions and 0 deletions
102
psbt/sort.go
Normal file
102
psbt/sort.go
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
package psbt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InPlaceSort modifies the passed packet's wire TX inputs and outputs to be
|
||||||
|
// sorted based on BIP 69. The sorting happens in a way that the packet's
|
||||||
|
// partial inputs and outputs are also modified to match the sorted TxIn and
|
||||||
|
// TxOuts of the wire transaction.
|
||||||
|
//
|
||||||
|
// WARNING: This function must NOT be called with packages that already contain
|
||||||
|
// (partial) witness data since it will mutate the transaction if it's not
|
||||||
|
// already sorted. This can cause issues if you mutate a tx in a block, for
|
||||||
|
// example, which would invalidate the block. It could also cause cached hashes,
|
||||||
|
// such as in a btcutil.Tx to become invalidated.
|
||||||
|
//
|
||||||
|
// The function should only be used if the caller is creating the transaction or
|
||||||
|
// is otherwise 100% positive mutating will not cause adverse affects due to
|
||||||
|
// other dependencies.
|
||||||
|
func InPlaceSort(packet *Packet) error {
|
||||||
|
// To make sure we don't run into any nil pointers or array index
|
||||||
|
// violations during sorting, do a very basic sanity check first.
|
||||||
|
err := VerifyInputOutputLen(packet, false, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Sort(&sortableInputs{p: packet})
|
||||||
|
sort.Sort(&sortableOutputs{p: packet})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sortableInputs is a simple wrapper around a packet that implements the
|
||||||
|
// sort.Interface for sorting the wire and partial inputs of a packet.
|
||||||
|
type sortableInputs struct {
|
||||||
|
p *Packet
|
||||||
|
}
|
||||||
|
|
||||||
|
// sortableOutputs is a simple wrapper around a packet that implements the
|
||||||
|
// sort.Interface for sorting the wire and partial outputs of a packet.
|
||||||
|
type sortableOutputs struct {
|
||||||
|
p *Packet
|
||||||
|
}
|
||||||
|
|
||||||
|
// For sortableInputs and sortableOutputs, three functions are needed to make
|
||||||
|
// them sortable with sort.Sort() -- Len, Less, and Swap.
|
||||||
|
// Len and Swap are trivial. Less is BIP 69 specific.
|
||||||
|
func (s *sortableInputs) Len() int { return len(s.p.UnsignedTx.TxIn) }
|
||||||
|
func (s sortableOutputs) Len() int { return len(s.p.UnsignedTx.TxOut) }
|
||||||
|
|
||||||
|
// Swap swaps two inputs.
|
||||||
|
func (s *sortableInputs) Swap(i, j int) {
|
||||||
|
tx := s.p.UnsignedTx
|
||||||
|
tx.TxIn[i], tx.TxIn[j] = tx.TxIn[j], tx.TxIn[i]
|
||||||
|
s.p.Inputs[i], s.p.Inputs[j] = s.p.Inputs[j], s.p.Inputs[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap swaps two outputs.
|
||||||
|
func (s *sortableOutputs) Swap(i, j int) {
|
||||||
|
tx := s.p.UnsignedTx
|
||||||
|
tx.TxOut[i], tx.TxOut[j] = tx.TxOut[j], tx.TxOut[i]
|
||||||
|
s.p.Outputs[i], s.p.Outputs[j] = s.p.Outputs[j], s.p.Outputs[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Less is the input comparison function. First sort based on input hash
|
||||||
|
// (reversed / rpc-style), then index.
|
||||||
|
func (s *sortableInputs) Less(i, j int) bool {
|
||||||
|
ins := s.p.UnsignedTx.TxIn
|
||||||
|
|
||||||
|
// Input hashes are the same, so compare the index.
|
||||||
|
ihash := ins[i].PreviousOutPoint.Hash
|
||||||
|
jhash := ins[j].PreviousOutPoint.Hash
|
||||||
|
if ihash == jhash {
|
||||||
|
return ins[i].PreviousOutPoint.Index <
|
||||||
|
ins[j].PreviousOutPoint.Index
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point, the hashes are not equal, so reverse them to
|
||||||
|
// big-endian and return the result of the comparison.
|
||||||
|
const hashSize = chainhash.HashSize
|
||||||
|
for b := 0; b < hashSize/2; b++ {
|
||||||
|
ihash[b], ihash[hashSize-1-b] = ihash[hashSize-1-b], ihash[b]
|
||||||
|
jhash[b], jhash[hashSize-1-b] = jhash[hashSize-1-b], jhash[b]
|
||||||
|
}
|
||||||
|
return bytes.Compare(ihash[:], jhash[:]) == -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Less is the output comparison function. First sort based on amount (smallest
|
||||||
|
// first), then PkScript.
|
||||||
|
func (s *sortableOutputs) Less(i, j int) bool {
|
||||||
|
outs := s.p.UnsignedTx.TxOut
|
||||||
|
|
||||||
|
if outs[i].Value == outs[j].Value {
|
||||||
|
return bytes.Compare(outs[i].PkScript, outs[j].PkScript) < 0
|
||||||
|
}
|
||||||
|
return outs[i].Value < outs[j].Value
|
||||||
|
}
|
167
psbt/sort_test.go
Normal file
167
psbt/sort_test.go
Normal file
|
@ -0,0 +1,167 @@
|
||||||
|
package psbt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
|
"github.com/btcsuite/btcd/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInPlaceSort(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
packet *Packet
|
||||||
|
expectedTxIn []*wire.TxIn
|
||||||
|
expectedTxOut []*wire.TxOut
|
||||||
|
expectedPIn []PInput
|
||||||
|
expectedPOut []POutput
|
||||||
|
expectErr bool
|
||||||
|
}{{
|
||||||
|
name: "packet nil",
|
||||||
|
packet: nil,
|
||||||
|
expectErr: true,
|
||||||
|
}, {
|
||||||
|
name: "no inputs or outputs",
|
||||||
|
packet: &Packet{UnsignedTx: &wire.MsgTx{}},
|
||||||
|
expectErr: false,
|
||||||
|
}, {
|
||||||
|
name: "inputs only",
|
||||||
|
packet: &Packet{
|
||||||
|
UnsignedTx: &wire.MsgTx{
|
||||||
|
TxIn: []*wire.TxIn{{
|
||||||
|
PreviousOutPoint: wire.OutPoint{
|
||||||
|
Hash: chainhash.Hash{99, 88},
|
||||||
|
Index: 7,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
PreviousOutPoint: wire.OutPoint{
|
||||||
|
Hash: chainhash.Hash{77, 88},
|
||||||
|
Index: 12,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
PreviousOutPoint: wire.OutPoint{
|
||||||
|
Hash: chainhash.Hash{77, 88},
|
||||||
|
Index: 7,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
// Abuse the SighashType as an index to make sure the
|
||||||
|
// partial inputs are also sorted together with the wire
|
||||||
|
// inputs.
|
||||||
|
Inputs: []PInput{{
|
||||||
|
SighashType: 0,
|
||||||
|
}, {
|
||||||
|
SighashType: 1,
|
||||||
|
}, {
|
||||||
|
SighashType: 2,
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
expectedTxIn: []*wire.TxIn{{
|
||||||
|
PreviousOutPoint: wire.OutPoint{
|
||||||
|
Hash: chainhash.Hash{77, 88},
|
||||||
|
Index: 7,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
PreviousOutPoint: wire.OutPoint{
|
||||||
|
Hash: chainhash.Hash{77, 88},
|
||||||
|
Index: 12,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
PreviousOutPoint: wire.OutPoint{
|
||||||
|
Hash: chainhash.Hash{99, 88},
|
||||||
|
Index: 7,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
expectedPIn: []PInput{{
|
||||||
|
SighashType: 2,
|
||||||
|
}, {
|
||||||
|
SighashType: 1,
|
||||||
|
}, {
|
||||||
|
SighashType: 0,
|
||||||
|
}},
|
||||||
|
expectErr: false,
|
||||||
|
}, {
|
||||||
|
name: "outputs only",
|
||||||
|
packet: &Packet{
|
||||||
|
UnsignedTx: &wire.MsgTx{
|
||||||
|
TxOut: []*wire.TxOut{{
|
||||||
|
PkScript: []byte{99, 88},
|
||||||
|
Value: 7,
|
||||||
|
}, {
|
||||||
|
PkScript: []byte{77, 88},
|
||||||
|
Value: 12,
|
||||||
|
}, {
|
||||||
|
PkScript: []byte{77, 88},
|
||||||
|
Value: 7,
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
// Abuse the RedeemScript as an index to make sure the
|
||||||
|
// partial inputs are also sorted together with the wire
|
||||||
|
// inputs.
|
||||||
|
Outputs: []POutput{{
|
||||||
|
RedeemScript: []byte{0},
|
||||||
|
}, {
|
||||||
|
RedeemScript: []byte{1},
|
||||||
|
}, {
|
||||||
|
RedeemScript: []byte{2},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
expectedTxOut: []*wire.TxOut{{
|
||||||
|
PkScript: []byte{77, 88},
|
||||||
|
Value: 7,
|
||||||
|
}, {
|
||||||
|
PkScript: []byte{99, 88},
|
||||||
|
Value: 7,
|
||||||
|
}, {
|
||||||
|
PkScript: []byte{77, 88},
|
||||||
|
Value: 12,
|
||||||
|
}},
|
||||||
|
expectedPOut: []POutput{{
|
||||||
|
RedeemScript: []byte{2},
|
||||||
|
}, {
|
||||||
|
RedeemScript: []byte{0},
|
||||||
|
}, {
|
||||||
|
RedeemScript: []byte{1},
|
||||||
|
}},
|
||||||
|
expectErr: false,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
p := tc.packet
|
||||||
|
err := InPlaceSort(p)
|
||||||
|
if (tc.expectErr && err == nil) ||
|
||||||
|
(!tc.expectErr && err != nil) {
|
||||||
|
|
||||||
|
t.Fatalf("got error '%v' but wanted it to be "+
|
||||||
|
"nil: %v", err, tc.expectErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't continue on this special test case.
|
||||||
|
if p == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := p.UnsignedTx
|
||||||
|
if !reflect.DeepEqual(tx.TxIn, tc.expectedTxIn) {
|
||||||
|
t.Fatalf("unexpected txin, got %#v wanted %#v",
|
||||||
|
tx.TxIn, tc.expectedTxIn)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(tx.TxOut, tc.expectedTxOut) {
|
||||||
|
t.Fatalf("unexpected txout, got %#v wanted %#v",
|
||||||
|
tx.TxOut, tc.expectedTxOut)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(p.Inputs, tc.expectedPIn) {
|
||||||
|
t.Fatalf("unexpected pin, got %#v wanted %#v",
|
||||||
|
p.Inputs, tc.expectedPIn)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(p.Outputs, tc.expectedPOut) {
|
||||||
|
t.Fatalf("unexpected pout, got %#v wanted %#v",
|
||||||
|
p.Inputs, tc.expectedPOut)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue