psbt: add BIP 69 in-place sort
This commit is contained in:
parent
488d2cc834
commit
0b85b11dcc
2 changed files with 269 additions and 0 deletions
102
psbt/sort.go
Normal file
102
psbt/sort.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package psbt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sort"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
)
|
||||
|
||||
// InPlaceSort modifies the passed packet's wire TX inputs and outputs to be
|
||||
// sorted based on BIP 69. The sorting happens in a way that the packet's
|
||||
// partial inputs and outputs are also modified to match the sorted TxIn and
|
||||
// TxOuts of the wire transaction.
|
||||
//
|
||||
// WARNING: This function must NOT be called with packages that already contain
|
||||
// (partial) witness data since it will mutate the transaction if it's not
|
||||
// already sorted. This can cause issues if you mutate a tx in a block, for
|
||||
// example, which would invalidate the block. It could also cause cached hashes,
|
||||
// such as in a btcutil.Tx to become invalidated.
|
||||
//
|
||||
// The function should only be used if the caller is creating the transaction or
|
||||
// is otherwise 100% positive mutating will not cause adverse affects due to
|
||||
// other dependencies.
|
||||
func InPlaceSort(packet *Packet) error {
|
||||
// To make sure we don't run into any nil pointers or array index
|
||||
// violations during sorting, do a very basic sanity check first.
|
||||
err := VerifyInputOutputLen(packet, false, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sort.Sort(&sortableInputs{p: packet})
|
||||
sort.Sort(&sortableOutputs{p: packet})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sortableInputs is a simple wrapper around a packet that implements the
|
||||
// sort.Interface for sorting the wire and partial inputs of a packet.
|
||||
type sortableInputs struct {
|
||||
p *Packet
|
||||
}
|
||||
|
||||
// sortableOutputs is a simple wrapper around a packet that implements the
|
||||
// sort.Interface for sorting the wire and partial outputs of a packet.
|
||||
type sortableOutputs struct {
|
||||
p *Packet
|
||||
}
|
||||
|
||||
// For sortableInputs and sortableOutputs, three functions are needed to make
|
||||
// them sortable with sort.Sort() -- Len, Less, and Swap.
|
||||
// Len and Swap are trivial. Less is BIP 69 specific.
|
||||
func (s *sortableInputs) Len() int { return len(s.p.UnsignedTx.TxIn) }
|
||||
func (s sortableOutputs) Len() int { return len(s.p.UnsignedTx.TxOut) }
|
||||
|
||||
// Swap swaps two inputs.
|
||||
func (s *sortableInputs) Swap(i, j int) {
|
||||
tx := s.p.UnsignedTx
|
||||
tx.TxIn[i], tx.TxIn[j] = tx.TxIn[j], tx.TxIn[i]
|
||||
s.p.Inputs[i], s.p.Inputs[j] = s.p.Inputs[j], s.p.Inputs[i]
|
||||
}
|
||||
|
||||
// Swap swaps two outputs.
|
||||
func (s *sortableOutputs) Swap(i, j int) {
|
||||
tx := s.p.UnsignedTx
|
||||
tx.TxOut[i], tx.TxOut[j] = tx.TxOut[j], tx.TxOut[i]
|
||||
s.p.Outputs[i], s.p.Outputs[j] = s.p.Outputs[j], s.p.Outputs[i]
|
||||
}
|
||||
|
||||
// Less is the input comparison function. First sort based on input hash
|
||||
// (reversed / rpc-style), then index.
|
||||
func (s *sortableInputs) Less(i, j int) bool {
|
||||
ins := s.p.UnsignedTx.TxIn
|
||||
|
||||
// Input hashes are the same, so compare the index.
|
||||
ihash := ins[i].PreviousOutPoint.Hash
|
||||
jhash := ins[j].PreviousOutPoint.Hash
|
||||
if ihash == jhash {
|
||||
return ins[i].PreviousOutPoint.Index <
|
||||
ins[j].PreviousOutPoint.Index
|
||||
}
|
||||
|
||||
// At this point, the hashes are not equal, so reverse them to
|
||||
// big-endian and return the result of the comparison.
|
||||
const hashSize = chainhash.HashSize
|
||||
for b := 0; b < hashSize/2; b++ {
|
||||
ihash[b], ihash[hashSize-1-b] = ihash[hashSize-1-b], ihash[b]
|
||||
jhash[b], jhash[hashSize-1-b] = jhash[hashSize-1-b], jhash[b]
|
||||
}
|
||||
return bytes.Compare(ihash[:], jhash[:]) == -1
|
||||
}
|
||||
|
||||
// Less is the output comparison function. First sort based on amount (smallest
|
||||
// first), then PkScript.
|
||||
func (s *sortableOutputs) Less(i, j int) bool {
|
||||
outs := s.p.UnsignedTx.TxOut
|
||||
|
||||
if outs[i].Value == outs[j].Value {
|
||||
return bytes.Compare(outs[i].PkScript, outs[j].PkScript) < 0
|
||||
}
|
||||
return outs[i].Value < outs[j].Value
|
||||
}
|
167
psbt/sort_test.go
Normal file
167
psbt/sort_test.go
Normal file
|
@ -0,0 +1,167 @@
|
|||
package psbt
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
)
|
||||
|
||||
func TestInPlaceSort(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
packet *Packet
|
||||
expectedTxIn []*wire.TxIn
|
||||
expectedTxOut []*wire.TxOut
|
||||
expectedPIn []PInput
|
||||
expectedPOut []POutput
|
||||
expectErr bool
|
||||
}{{
|
||||
name: "packet nil",
|
||||
packet: nil,
|
||||
expectErr: true,
|
||||
}, {
|
||||
name: "no inputs or outputs",
|
||||
packet: &Packet{UnsignedTx: &wire.MsgTx{}},
|
||||
expectErr: false,
|
||||
}, {
|
||||
name: "inputs only",
|
||||
packet: &Packet{
|
||||
UnsignedTx: &wire.MsgTx{
|
||||
TxIn: []*wire.TxIn{{
|
||||
PreviousOutPoint: wire.OutPoint{
|
||||
Hash: chainhash.Hash{99, 88},
|
||||
Index: 7,
|
||||
},
|
||||
}, {
|
||||
PreviousOutPoint: wire.OutPoint{
|
||||
Hash: chainhash.Hash{77, 88},
|
||||
Index: 12,
|
||||
},
|
||||
}, {
|
||||
PreviousOutPoint: wire.OutPoint{
|
||||
Hash: chainhash.Hash{77, 88},
|
||||
Index: 7,
|
||||
},
|
||||
}},
|
||||
},
|
||||
// Abuse the SighashType as an index to make sure the
|
||||
// partial inputs are also sorted together with the wire
|
||||
// inputs.
|
||||
Inputs: []PInput{{
|
||||
SighashType: 0,
|
||||
}, {
|
||||
SighashType: 1,
|
||||
}, {
|
||||
SighashType: 2,
|
||||
}},
|
||||
},
|
||||
expectedTxIn: []*wire.TxIn{{
|
||||
PreviousOutPoint: wire.OutPoint{
|
||||
Hash: chainhash.Hash{77, 88},
|
||||
Index: 7,
|
||||
},
|
||||
}, {
|
||||
PreviousOutPoint: wire.OutPoint{
|
||||
Hash: chainhash.Hash{77, 88},
|
||||
Index: 12,
|
||||
},
|
||||
}, {
|
||||
PreviousOutPoint: wire.OutPoint{
|
||||
Hash: chainhash.Hash{99, 88},
|
||||
Index: 7,
|
||||
},
|
||||
}},
|
||||
expectedPIn: []PInput{{
|
||||
SighashType: 2,
|
||||
}, {
|
||||
SighashType: 1,
|
||||
}, {
|
||||
SighashType: 0,
|
||||
}},
|
||||
expectErr: false,
|
||||
}, {
|
||||
name: "outputs only",
|
||||
packet: &Packet{
|
||||
UnsignedTx: &wire.MsgTx{
|
||||
TxOut: []*wire.TxOut{{
|
||||
PkScript: []byte{99, 88},
|
||||
Value: 7,
|
||||
}, {
|
||||
PkScript: []byte{77, 88},
|
||||
Value: 12,
|
||||
}, {
|
||||
PkScript: []byte{77, 88},
|
||||
Value: 7,
|
||||
}},
|
||||
},
|
||||
// Abuse the RedeemScript as an index to make sure the
|
||||
// partial inputs are also sorted together with the wire
|
||||
// inputs.
|
||||
Outputs: []POutput{{
|
||||
RedeemScript: []byte{0},
|
||||
}, {
|
||||
RedeemScript: []byte{1},
|
||||
}, {
|
||||
RedeemScript: []byte{2},
|
||||
}},
|
||||
},
|
||||
expectedTxOut: []*wire.TxOut{{
|
||||
PkScript: []byte{77, 88},
|
||||
Value: 7,
|
||||
}, {
|
||||
PkScript: []byte{99, 88},
|
||||
Value: 7,
|
||||
}, {
|
||||
PkScript: []byte{77, 88},
|
||||
Value: 12,
|
||||
}},
|
||||
expectedPOut: []POutput{{
|
||||
RedeemScript: []byte{2},
|
||||
}, {
|
||||
RedeemScript: []byte{0},
|
||||
}, {
|
||||
RedeemScript: []byte{1},
|
||||
}},
|
||||
expectErr: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := tc.packet
|
||||
err := InPlaceSort(p)
|
||||
if (tc.expectErr && err == nil) ||
|
||||
(!tc.expectErr && err != nil) {
|
||||
|
||||
t.Fatalf("got error '%v' but wanted it to be "+
|
||||
"nil: %v", err, tc.expectErr)
|
||||
}
|
||||
|
||||
// Don't continue on this special test case.
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
tx := p.UnsignedTx
|
||||
if !reflect.DeepEqual(tx.TxIn, tc.expectedTxIn) {
|
||||
t.Fatalf("unexpected txin, got %#v wanted %#v",
|
||||
tx.TxIn, tc.expectedTxIn)
|
||||
}
|
||||
if !reflect.DeepEqual(tx.TxOut, tc.expectedTxOut) {
|
||||
t.Fatalf("unexpected txout, got %#v wanted %#v",
|
||||
tx.TxOut, tc.expectedTxOut)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(p.Inputs, tc.expectedPIn) {
|
||||
t.Fatalf("unexpected pin, got %#v wanted %#v",
|
||||
p.Inputs, tc.expectedPIn)
|
||||
}
|
||||
if !reflect.DeepEqual(p.Outputs, tc.expectedPOut) {
|
||||
t.Fatalf("unexpected pout, got %#v wanted %#v",
|
||||
p.Inputs, tc.expectedPOut)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue