wallet: return change index from FundPsbt
To make it easy to show the user what change output was created (if any) during the funding process, we return its index (or -1 if no change output was created).
This commit is contained in:
parent
9d8d984207
commit
34bfc5efb9
2 changed files with 40 additions and 19 deletions
|
@ -19,7 +19,9 @@ import (
|
|||
|
||||
// FundPsbt creates a fully populated PSBT packet that contains enough inputs to
|
||||
// fund the outputs specified in the passed in packet with the specified fee
|
||||
// rate. If there is change left, a change output from the wallet is added.
|
||||
// rate. If there is change left, a change output from the wallet is added and
|
||||
// the index of the change output is returned. Otherwise no additional output
|
||||
// is created and the index -1 is returned.
|
||||
//
|
||||
// NOTE: If the packet doesn't contain any inputs, coin selection is performed
|
||||
// automatically. If the packet does contain any inputs, it is assumed that full
|
||||
|
@ -32,13 +34,13 @@ import (
|
|||
// selected/validated inputs by this method. It is in the caller's
|
||||
// responsibility to lock the inputs before handing the partial transaction out.
|
||||
func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
||||
feeSatPerKB btcutil.Amount) error {
|
||||
feeSatPerKB btcutil.Amount) (int32, error) {
|
||||
|
||||
// Make sure the packet is well formed. We only require there to be at
|
||||
// least one output but not necessarily any inputs.
|
||||
err := psbt.VerifyInputOutputLen(packet, false, true)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
txOut := packet.UnsignedTx.TxOut
|
||||
|
@ -53,7 +55,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
|||
// dust.
|
||||
err := txrules.CheckOutput(output, txrules.DefaultRelayFeePerKb)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -108,7 +110,8 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
|||
false,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating funding TX: %v", err)
|
||||
return 0, fmt.Errorf("error creating funding TX: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
// Copy over the inputs now then collect all UTXO information
|
||||
|
@ -118,7 +121,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
|||
packet.UnsignedTx.TxIn = tx.Tx.TxIn
|
||||
err = addInputInfo(tx.Tx.TxIn)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// If there are inputs, we need to check if they're sufficient and add
|
||||
|
@ -127,7 +130,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
|||
// Make sure all inputs provided are actually ours.
|
||||
err = addInputInfo(txIn)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// We can leverage the fee calculation of the txauthor package
|
||||
|
@ -147,7 +150,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
|||
// a new change addresse into the database.
|
||||
dbtx, err := w.db.BeginReadWriteTx()
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
_, changeSource := w.addrMgrWithChangeSource(dbtx, account)
|
||||
|
||||
|
@ -159,24 +162,25 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
|||
)
|
||||
if err != nil {
|
||||
_ = dbtx.Rollback()
|
||||
return fmt.Errorf("fee estimation not successful: %v",
|
||||
err)
|
||||
return 0, fmt.Errorf("fee estimation not successful: "+
|
||||
"%v", err)
|
||||
}
|
||||
|
||||
// The transaction could be created, let's commit the DB TX to
|
||||
// store the change address (if one was created).
|
||||
err = dbtx.Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not add change address to "+
|
||||
return 0, fmt.Errorf("could not add change address to "+
|
||||
"database: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// If there is a change output, we need to copy it over to the PSBT now.
|
||||
var changeTxOut *wire.TxOut
|
||||
if tx.ChangeIndex >= 0 {
|
||||
changeTxOut = tx.Tx.TxOut[tx.ChangeIndex]
|
||||
packet.UnsignedTx.TxOut = append(
|
||||
packet.UnsignedTx.TxOut,
|
||||
tx.Tx.TxOut[tx.ChangeIndex],
|
||||
packet.UnsignedTx.TxOut, changeTxOut,
|
||||
)
|
||||
packet.Outputs = append(packet.Outputs, psbt.POutput{})
|
||||
}
|
||||
|
@ -186,10 +190,22 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
|||
// partial inputs and outputs accordingly.
|
||||
err = psbt.InPlaceSort(packet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not sort PSBT: %v", err)
|
||||
return 0, fmt.Errorf("could not sort PSBT: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
// The change output index might have changed after the sorting. We need
|
||||
// to find our index again.
|
||||
changeIndex := int32(-1)
|
||||
if changeTxOut != nil {
|
||||
for idx, txOut := range packet.UnsignedTx.TxOut {
|
||||
if psbt.TxOutsEqual(changeTxOut, txOut) {
|
||||
changeIndex = int32(idx)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return changeIndex, nil
|
||||
}
|
||||
|
||||
// FinalizePsbt expects a partial transaction with all inputs and outputs fully
|
||||
|
|
|
@ -142,7 +142,9 @@ func TestFundPsbt(t *testing.T) {
|
|||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := w.FundPsbt(tc.packet, 0, tc.feeRateSatPerKB)
|
||||
changeIndex, err := w.FundPsbt(
|
||||
tc.packet, 0, tc.feeRateSatPerKB,
|
||||
)
|
||||
|
||||
// Make sure the error is what we expected.
|
||||
if err == nil && tc.expectedErr != "" {
|
||||
|
@ -258,9 +260,10 @@ func TestFundPsbt(t *testing.T) {
|
|||
}
|
||||
p2wkhIndex := -1
|
||||
p2wshIndex := -1
|
||||
changeIndex := -1
|
||||
totalOut := int64(0)
|
||||
for idx, txOut := range txOuts {
|
||||
script := txOut.PkScript
|
||||
totalOut += txOut.Value
|
||||
|
||||
switch {
|
||||
case bytes.Equal(script, testScriptP2WKH):
|
||||
|
@ -269,10 +272,12 @@ func TestFundPsbt(t *testing.T) {
|
|||
case bytes.Equal(script, testScriptP2WSH):
|
||||
p2wshIndex = idx
|
||||
|
||||
default:
|
||||
changeIndex = idx
|
||||
}
|
||||
}
|
||||
totalIn := int64(0)
|
||||
for _, txIn := range packet.Inputs {
|
||||
totalIn += txIn.WitnessUtxo.Value
|
||||
}
|
||||
|
||||
// All outputs must be found.
|
||||
if p2wkhIndex < 0 || p2wshIndex < 0 || changeIndex < 0 {
|
||||
|
|
Loading…
Add table
Reference in a new issue