diff --git a/wallet/psbt.go b/wallet/psbt.go index 1b5b211..1a68a71 100644 --- a/wallet/psbt.go +++ b/wallet/psbt.go @@ -19,7 +19,9 @@ import ( // FundPsbt creates a fully populated PSBT packet that contains enough inputs to // fund the outputs specified in the passed in packet with the specified fee -// rate. If there is change left, a change output from the wallet is added. +// rate. If there is change left, a change output from the wallet is added and +// the index of the change output is returned. Otherwise no additional output +// is created and the index -1 is returned. // // NOTE: If the packet doesn't contain any inputs, coin selection is performed // automatically. If the packet does contain any inputs, it is assumed that full @@ -32,13 +34,13 @@ import ( // selected/validated inputs by this method. It is in the caller's // responsibility to lock the inputs before handing the partial transaction out. func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, - feeSatPerKB btcutil.Amount) error { + feeSatPerKB btcutil.Amount) (int32, error) { // Make sure the packet is well formed. We only require there to be at // least one output but not necessarily any inputs. err := psbt.VerifyInputOutputLen(packet, false, true) if err != nil { - return err + return 0, err } txOut := packet.UnsignedTx.TxOut @@ -53,7 +55,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, // dust. err := txrules.CheckOutput(output, txrules.DefaultRelayFeePerKb) if err != nil { - return err + return 0, err } } @@ -108,7 +110,8 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, false, ) if err != nil { - return fmt.Errorf("error creating funding TX: %v", err) + return 0, fmt.Errorf("error creating funding TX: %v", + err) } // Copy over the inputs now then collect all UTXO information @@ -118,7 +121,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, packet.UnsignedTx.TxIn = tx.Tx.TxIn err = addInputInfo(tx.Tx.TxIn) if err != nil { - return err + return 0, err } // If there are inputs, we need to check if they're sufficient and add @@ -127,11 +130,18 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, // Make sure all inputs provided are actually ours. err = addInputInfo(txIn) if err != nil { - return err + return 0, err } // We can leverage the fee calculation of the txauthor package - // if we provide the selected UTXOs as a coin source. + // if we provide the selected UTXOs as a coin source. We just + // need to make sure we always return the full list of user- + // selected UTXOs rather than a subset, otherwise our change + // amount will be off (in case the user selected multiple UTXOs + // that are large enough on their own). That's why we use our + // own static input source creator instead of the more generic + // makeInputSource() that selects a subset that is "large + // enough". credits := make([]wtxmgr.Credit, len(txIn)) for idx, in := range txIn { utxo := packet.Inputs[idx].WitnessUtxo @@ -141,13 +151,13 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, PkScript: utxo.PkScript, } } - inputSource := makeInputSource(credits) + inputSource := constantInputSource(credits) // We also need a change source which needs to be able to insert // a new change addresse into the database. dbtx, err := w.db.BeginReadWriteTx() if err != nil { - return err + return 0, err } _, changeSource := w.addrMgrWithChangeSource(dbtx, account) @@ -159,24 +169,25 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, ) if err != nil { _ = dbtx.Rollback() - return fmt.Errorf("fee estimation not successful: %v", - err) + return 0, fmt.Errorf("fee estimation not successful: "+ + "%v", err) } // The transaction could be created, let's commit the DB TX to // store the change address (if one was created). err = dbtx.Commit() if err != nil { - return fmt.Errorf("could not add change address to "+ + return 0, fmt.Errorf("could not add change address to "+ "database: %v", err) } } // If there is a change output, we need to copy it over to the PSBT now. + var changeTxOut *wire.TxOut if tx.ChangeIndex >= 0 { + changeTxOut = tx.Tx.TxOut[tx.ChangeIndex] packet.UnsignedTx.TxOut = append( - packet.UnsignedTx.TxOut, - tx.Tx.TxOut[tx.ChangeIndex], + packet.UnsignedTx.TxOut, changeTxOut, ) packet.Outputs = append(packet.Outputs, psbt.POutput{}) } @@ -186,10 +197,22 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, // partial inputs and outputs accordingly. err = psbt.InPlaceSort(packet) if err != nil { - return fmt.Errorf("could not sort PSBT: %v", err) + return 0, fmt.Errorf("could not sort PSBT: %v", err) } - return nil + // The change output index might have changed after the sorting. We need + // to find our index again. + changeIndex := int32(-1) + if changeTxOut != nil { + for idx, txOut := range packet.UnsignedTx.TxOut { + if psbt.TxOutsEqual(changeTxOut, txOut) { + changeIndex = int32(idx) + break + } + } + } + + return changeIndex, nil } // FinalizePsbt expects a partial transaction with all inputs and outputs fully @@ -303,3 +326,30 @@ func (w *Wallet) FinalizePsbt(packet *psbt.Packet) error { return nil } + +// constantInputSource creates an input source function that always returns the +// static set of user-selected UTXOs. +func constantInputSource(eligible []wtxmgr.Credit) txauthor.InputSource { + // Current inputs and their total value. These won't change over + // different invocations as we want our inputs to remain static since + // they're selected by the user. + currentTotal := btcutil.Amount(0) + currentInputs := make([]*wire.TxIn, 0, len(eligible)) + currentScripts := make([][]byte, 0, len(eligible)) + currentInputValues := make([]btcutil.Amount, 0, len(eligible)) + + for _, credit := range eligible { + nextInput := wire.NewTxIn(&credit.OutPoint, nil, nil) + currentTotal += credit.Amount + currentInputs = append(currentInputs, nextInput) + currentScripts = append(currentScripts, credit.PkScript) + currentInputValues = append(currentInputValues, credit.Amount) + } + + return func(target btcutil.Amount) (btcutil.Amount, []*wire.TxIn, + []btcutil.Amount, [][]byte, error) { + + return currentTotal, currentInputs, currentInputValues, + currentScripts, nil + } +} diff --git a/wallet/psbt_test.go b/wallet/psbt_test.go index 5d77808..b9747b2 100644 --- a/wallet/psbt_test.go +++ b/wallet/psbt_test.go @@ -68,6 +68,8 @@ func TestFundPsbt(t *testing.T) { feeRateSatPerKB btcutil.Amount expectedErr string validatePackage bool + expectedFee int64 + expectedChange int64 numExpectedInputs int }{{ name: "no outputs provided", @@ -106,6 +108,8 @@ func TestFundPsbt(t *testing.T) { feeRateSatPerKB: 2000, // 2 sat/byte expectedErr: "", validatePackage: true, + expectedChange: 1000000 - 150000 - 368, + expectedFee: 368, numExpectedInputs: 1, }, { name: "two outputs, two inputs", @@ -136,13 +140,17 @@ func TestFundPsbt(t *testing.T) { feeRateSatPerKB: 2000, // 2 sat/byte expectedErr: "", validatePackage: true, + expectedFee: 506, + expectedChange: 2000000 - 150000 - 506, numExpectedInputs: 2, }} for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - err := w.FundPsbt(tc.packet, 0, tc.feeRateSatPerKB) + changeIndex, err := w.FundPsbt( + tc.packet, 0, tc.feeRateSatPerKB, + ) // Make sure the error is what we expected. if err == nil && tc.expectedErr != "" { @@ -258,9 +266,10 @@ func TestFundPsbt(t *testing.T) { } p2wkhIndex := -1 p2wshIndex := -1 - changeIndex := -1 + totalOut := int64(0) for idx, txOut := range txOuts { script := txOut.PkScript + totalOut += txOut.Value switch { case bytes.Equal(script, testScriptP2WKH): @@ -269,10 +278,12 @@ func TestFundPsbt(t *testing.T) { case bytes.Equal(script, testScriptP2WSH): p2wshIndex = idx - default: - changeIndex = idx } } + totalIn := int64(0) + for _, txIn := range packet.Inputs { + totalIn += txIn.WitnessUtxo.Value + } // All outputs must be found. if p2wkhIndex < 0 || p2wshIndex < 0 || changeIndex < 0 { @@ -291,15 +302,17 @@ func TestFundPsbt(t *testing.T) { txOuts[p2wshIndex].PkScript) } - // Finally, check the change output size and that it - // belongs to the wallet. - expectedFee := int64(368) - expectedChange := 1000000 - 150000 - expectedFee - if txOuts[changeIndex].Value != expectedChange { + // Finally, check the change output size and fee. + fee := totalIn - totalOut + if fee != tc.expectedFee { + t.Fatalf("unexpected fee, got %d wanted %d", + fee, tc.expectedFee) + } + if txOuts[changeIndex].Value != tc.expectedChange { t.Fatalf("unexpected change output size, got "+ "%d wanted %d", txOuts[changeIndex].Value, - expectedChange) + tc.expectedChange) } }) }