103 lines
3.4 KiB
Go
103 lines
3.4 KiB
Go
|
package psbt
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"sort"
|
||
|
|
||
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||
|
)
|
||
|
|
||
|
// InPlaceSort modifies the passed packet's wire TX inputs and outputs to be
|
||
|
// sorted based on BIP 69. The sorting happens in a way that the packet's
|
||
|
// partial inputs and outputs are also modified to match the sorted TxIn and
|
||
|
// TxOuts of the wire transaction.
|
||
|
//
|
||
|
// WARNING: This function must NOT be called with packages that already contain
|
||
|
// (partial) witness data since it will mutate the transaction if it's not
|
||
|
// already sorted. This can cause issues if you mutate a tx in a block, for
|
||
|
// example, which would invalidate the block. It could also cause cached hashes,
|
||
|
// such as in a btcutil.Tx to become invalidated.
|
||
|
//
|
||
|
// The function should only be used if the caller is creating the transaction or
|
||
|
// is otherwise 100% positive mutating will not cause adverse affects due to
|
||
|
// other dependencies.
|
||
|
func InPlaceSort(packet *Packet) error {
|
||
|
// To make sure we don't run into any nil pointers or array index
|
||
|
// violations during sorting, do a very basic sanity check first.
|
||
|
err := VerifyInputOutputLen(packet, false, false)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
sort.Sort(&sortableInputs{p: packet})
|
||
|
sort.Sort(&sortableOutputs{p: packet})
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// sortableInputs is a simple wrapper around a packet that implements the
|
||
|
// sort.Interface for sorting the wire and partial inputs of a packet.
|
||
|
type sortableInputs struct {
|
||
|
p *Packet
|
||
|
}
|
||
|
|
||
|
// sortableOutputs is a simple wrapper around a packet that implements the
|
||
|
// sort.Interface for sorting the wire and partial outputs of a packet.
|
||
|
type sortableOutputs struct {
|
||
|
p *Packet
|
||
|
}
|
||
|
|
||
|
// For sortableInputs and sortableOutputs, three functions are needed to make
|
||
|
// them sortable with sort.Sort() -- Len, Less, and Swap.
|
||
|
// Len and Swap are trivial. Less is BIP 69 specific.
|
||
|
func (s *sortableInputs) Len() int { return len(s.p.UnsignedTx.TxIn) }
|
||
|
func (s sortableOutputs) Len() int { return len(s.p.UnsignedTx.TxOut) }
|
||
|
|
||
|
// Swap swaps two inputs.
|
||
|
func (s *sortableInputs) Swap(i, j int) {
|
||
|
tx := s.p.UnsignedTx
|
||
|
tx.TxIn[i], tx.TxIn[j] = tx.TxIn[j], tx.TxIn[i]
|
||
|
s.p.Inputs[i], s.p.Inputs[j] = s.p.Inputs[j], s.p.Inputs[i]
|
||
|
}
|
||
|
|
||
|
// Swap swaps two outputs.
|
||
|
func (s *sortableOutputs) Swap(i, j int) {
|
||
|
tx := s.p.UnsignedTx
|
||
|
tx.TxOut[i], tx.TxOut[j] = tx.TxOut[j], tx.TxOut[i]
|
||
|
s.p.Outputs[i], s.p.Outputs[j] = s.p.Outputs[j], s.p.Outputs[i]
|
||
|
}
|
||
|
|
||
|
// Less is the input comparison function. First sort based on input hash
|
||
|
// (reversed / rpc-style), then index.
|
||
|
func (s *sortableInputs) Less(i, j int) bool {
|
||
|
ins := s.p.UnsignedTx.TxIn
|
||
|
|
||
|
// Input hashes are the same, so compare the index.
|
||
|
ihash := ins[i].PreviousOutPoint.Hash
|
||
|
jhash := ins[j].PreviousOutPoint.Hash
|
||
|
if ihash == jhash {
|
||
|
return ins[i].PreviousOutPoint.Index <
|
||
|
ins[j].PreviousOutPoint.Index
|
||
|
}
|
||
|
|
||
|
// At this point, the hashes are not equal, so reverse them to
|
||
|
// big-endian and return the result of the comparison.
|
||
|
const hashSize = chainhash.HashSize
|
||
|
for b := 0; b < hashSize/2; b++ {
|
||
|
ihash[b], ihash[hashSize-1-b] = ihash[hashSize-1-b], ihash[b]
|
||
|
jhash[b], jhash[hashSize-1-b] = jhash[hashSize-1-b], jhash[b]
|
||
|
}
|
||
|
return bytes.Compare(ihash[:], jhash[:]) == -1
|
||
|
}
|
||
|
|
||
|
// Less is the output comparison function. First sort based on amount (smallest
|
||
|
// first), then PkScript.
|
||
|
func (s *sortableOutputs) Less(i, j int) bool {
|
||
|
outs := s.p.UnsignedTx.TxOut
|
||
|
|
||
|
if outs[i].Value == outs[j].Value {
|
||
|
return bytes.Compare(outs[i].PkScript, outs[j].PkScript) < 0
|
||
|
}
|
||
|
return outs[i].Value < outs[j].Value
|
||
|
}
|