Merge pull request #721 from guggero/psbt-change-fix
wallet: use constant input source for change calculation
This commit is contained in:
commit
e6d01202cb
2 changed files with 90 additions and 27 deletions
|
@ -19,7 +19,9 @@ import (
|
||||||
|
|
||||||
// FundPsbt creates a fully populated PSBT packet that contains enough inputs to
|
// FundPsbt creates a fully populated PSBT packet that contains enough inputs to
|
||||||
// fund the outputs specified in the passed in packet with the specified fee
|
// fund the outputs specified in the passed in packet with the specified fee
|
||||||
// rate. If there is change left, a change output from the wallet is added.
|
// rate. If there is change left, a change output from the wallet is added and
|
||||||
|
// the index of the change output is returned. Otherwise no additional output
|
||||||
|
// is created and the index -1 is returned.
|
||||||
//
|
//
|
||||||
// NOTE: If the packet doesn't contain any inputs, coin selection is performed
|
// NOTE: If the packet doesn't contain any inputs, coin selection is performed
|
||||||
// automatically. If the packet does contain any inputs, it is assumed that full
|
// automatically. If the packet does contain any inputs, it is assumed that full
|
||||||
|
@ -32,13 +34,13 @@ import (
|
||||||
// selected/validated inputs by this method. It is in the caller's
|
// selected/validated inputs by this method. It is in the caller's
|
||||||
// responsibility to lock the inputs before handing the partial transaction out.
|
// responsibility to lock the inputs before handing the partial transaction out.
|
||||||
func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
||||||
feeSatPerKB btcutil.Amount) error {
|
feeSatPerKB btcutil.Amount) (int32, error) {
|
||||||
|
|
||||||
// Make sure the packet is well formed. We only require there to be at
|
// Make sure the packet is well formed. We only require there to be at
|
||||||
// least one output but not necessarily any inputs.
|
// least one output but not necessarily any inputs.
|
||||||
err := psbt.VerifyInputOutputLen(packet, false, true)
|
err := psbt.VerifyInputOutputLen(packet, false, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
txOut := packet.UnsignedTx.TxOut
|
txOut := packet.UnsignedTx.TxOut
|
||||||
|
@ -53,7 +55,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
||||||
// dust.
|
// dust.
|
||||||
err := txrules.CheckOutput(output, txrules.DefaultRelayFeePerKb)
|
err := txrules.CheckOutput(output, txrules.DefaultRelayFeePerKb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,7 +110,8 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating funding TX: %v", err)
|
return 0, fmt.Errorf("error creating funding TX: %v",
|
||||||
|
err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy over the inputs now then collect all UTXO information
|
// Copy over the inputs now then collect all UTXO information
|
||||||
|
@ -118,7 +121,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
||||||
packet.UnsignedTx.TxIn = tx.Tx.TxIn
|
packet.UnsignedTx.TxIn = tx.Tx.TxIn
|
||||||
err = addInputInfo(tx.Tx.TxIn)
|
err = addInputInfo(tx.Tx.TxIn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there are inputs, we need to check if they're sufficient and add
|
// If there are inputs, we need to check if they're sufficient and add
|
||||||
|
@ -127,11 +130,18 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
||||||
// Make sure all inputs provided are actually ours.
|
// Make sure all inputs provided are actually ours.
|
||||||
err = addInputInfo(txIn)
|
err = addInputInfo(txIn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// We can leverage the fee calculation of the txauthor package
|
// We can leverage the fee calculation of the txauthor package
|
||||||
// if we provide the selected UTXOs as a coin source.
|
// if we provide the selected UTXOs as a coin source. We just
|
||||||
|
// need to make sure we always return the full list of user-
|
||||||
|
// selected UTXOs rather than a subset, otherwise our change
|
||||||
|
// amount will be off (in case the user selected multiple UTXOs
|
||||||
|
// that are large enough on their own). That's why we use our
|
||||||
|
// own static input source creator instead of the more generic
|
||||||
|
// makeInputSource() that selects a subset that is "large
|
||||||
|
// enough".
|
||||||
credits := make([]wtxmgr.Credit, len(txIn))
|
credits := make([]wtxmgr.Credit, len(txIn))
|
||||||
for idx, in := range txIn {
|
for idx, in := range txIn {
|
||||||
utxo := packet.Inputs[idx].WitnessUtxo
|
utxo := packet.Inputs[idx].WitnessUtxo
|
||||||
|
@ -141,13 +151,13 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
||||||
PkScript: utxo.PkScript,
|
PkScript: utxo.PkScript,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
inputSource := makeInputSource(credits)
|
inputSource := constantInputSource(credits)
|
||||||
|
|
||||||
// We also need a change source which needs to be able to insert
|
// We also need a change source which needs to be able to insert
|
||||||
// a new change addresse into the database.
|
// a new change addresse into the database.
|
||||||
dbtx, err := w.db.BeginReadWriteTx()
|
dbtx, err := w.db.BeginReadWriteTx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
_, changeSource := w.addrMgrWithChangeSource(dbtx, account)
|
_, changeSource := w.addrMgrWithChangeSource(dbtx, account)
|
||||||
|
|
||||||
|
@ -159,24 +169,25 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = dbtx.Rollback()
|
_ = dbtx.Rollback()
|
||||||
return fmt.Errorf("fee estimation not successful: %v",
|
return 0, fmt.Errorf("fee estimation not successful: "+
|
||||||
err)
|
"%v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The transaction could be created, let's commit the DB TX to
|
// The transaction could be created, let's commit the DB TX to
|
||||||
// store the change address (if one was created).
|
// store the change address (if one was created).
|
||||||
err = dbtx.Commit()
|
err = dbtx.Commit()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not add change address to "+
|
return 0, fmt.Errorf("could not add change address to "+
|
||||||
"database: %v", err)
|
"database: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there is a change output, we need to copy it over to the PSBT now.
|
// If there is a change output, we need to copy it over to the PSBT now.
|
||||||
|
var changeTxOut *wire.TxOut
|
||||||
if tx.ChangeIndex >= 0 {
|
if tx.ChangeIndex >= 0 {
|
||||||
|
changeTxOut = tx.Tx.TxOut[tx.ChangeIndex]
|
||||||
packet.UnsignedTx.TxOut = append(
|
packet.UnsignedTx.TxOut = append(
|
||||||
packet.UnsignedTx.TxOut,
|
packet.UnsignedTx.TxOut, changeTxOut,
|
||||||
tx.Tx.TxOut[tx.ChangeIndex],
|
|
||||||
)
|
)
|
||||||
packet.Outputs = append(packet.Outputs, psbt.POutput{})
|
packet.Outputs = append(packet.Outputs, psbt.POutput{})
|
||||||
}
|
}
|
||||||
|
@ -186,10 +197,22 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
|
||||||
// partial inputs and outputs accordingly.
|
// partial inputs and outputs accordingly.
|
||||||
err = psbt.InPlaceSort(packet)
|
err = psbt.InPlaceSort(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not sort PSBT: %v", err)
|
return 0, fmt.Errorf("could not sort PSBT: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
// The change output index might have changed after the sorting. We need
|
||||||
|
// to find our index again.
|
||||||
|
changeIndex := int32(-1)
|
||||||
|
if changeTxOut != nil {
|
||||||
|
for idx, txOut := range packet.UnsignedTx.TxOut {
|
||||||
|
if psbt.TxOutsEqual(changeTxOut, txOut) {
|
||||||
|
changeIndex = int32(idx)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return changeIndex, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FinalizePsbt expects a partial transaction with all inputs and outputs fully
|
// FinalizePsbt expects a partial transaction with all inputs and outputs fully
|
||||||
|
@ -303,3 +326,30 @@ func (w *Wallet) FinalizePsbt(packet *psbt.Packet) error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// constantInputSource creates an input source function that always returns the
|
||||||
|
// static set of user-selected UTXOs.
|
||||||
|
func constantInputSource(eligible []wtxmgr.Credit) txauthor.InputSource {
|
||||||
|
// Current inputs and their total value. These won't change over
|
||||||
|
// different invocations as we want our inputs to remain static since
|
||||||
|
// they're selected by the user.
|
||||||
|
currentTotal := btcutil.Amount(0)
|
||||||
|
currentInputs := make([]*wire.TxIn, 0, len(eligible))
|
||||||
|
currentScripts := make([][]byte, 0, len(eligible))
|
||||||
|
currentInputValues := make([]btcutil.Amount, 0, len(eligible))
|
||||||
|
|
||||||
|
for _, credit := range eligible {
|
||||||
|
nextInput := wire.NewTxIn(&credit.OutPoint, nil, nil)
|
||||||
|
currentTotal += credit.Amount
|
||||||
|
currentInputs = append(currentInputs, nextInput)
|
||||||
|
currentScripts = append(currentScripts, credit.PkScript)
|
||||||
|
currentInputValues = append(currentInputValues, credit.Amount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(target btcutil.Amount) (btcutil.Amount, []*wire.TxIn,
|
||||||
|
[]btcutil.Amount, [][]byte, error) {
|
||||||
|
|
||||||
|
return currentTotal, currentInputs, currentInputValues,
|
||||||
|
currentScripts, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -68,6 +68,8 @@ func TestFundPsbt(t *testing.T) {
|
||||||
feeRateSatPerKB btcutil.Amount
|
feeRateSatPerKB btcutil.Amount
|
||||||
expectedErr string
|
expectedErr string
|
||||||
validatePackage bool
|
validatePackage bool
|
||||||
|
expectedFee int64
|
||||||
|
expectedChange int64
|
||||||
numExpectedInputs int
|
numExpectedInputs int
|
||||||
}{{
|
}{{
|
||||||
name: "no outputs provided",
|
name: "no outputs provided",
|
||||||
|
@ -106,6 +108,8 @@ func TestFundPsbt(t *testing.T) {
|
||||||
feeRateSatPerKB: 2000, // 2 sat/byte
|
feeRateSatPerKB: 2000, // 2 sat/byte
|
||||||
expectedErr: "",
|
expectedErr: "",
|
||||||
validatePackage: true,
|
validatePackage: true,
|
||||||
|
expectedChange: 1000000 - 150000 - 368,
|
||||||
|
expectedFee: 368,
|
||||||
numExpectedInputs: 1,
|
numExpectedInputs: 1,
|
||||||
}, {
|
}, {
|
||||||
name: "two outputs, two inputs",
|
name: "two outputs, two inputs",
|
||||||
|
@ -136,13 +140,17 @@ func TestFundPsbt(t *testing.T) {
|
||||||
feeRateSatPerKB: 2000, // 2 sat/byte
|
feeRateSatPerKB: 2000, // 2 sat/byte
|
||||||
expectedErr: "",
|
expectedErr: "",
|
||||||
validatePackage: true,
|
validatePackage: true,
|
||||||
|
expectedFee: 506,
|
||||||
|
expectedChange: 2000000 - 150000 - 506,
|
||||||
numExpectedInputs: 2,
|
numExpectedInputs: 2,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
tc := tc
|
tc := tc
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
err := w.FundPsbt(tc.packet, 0, tc.feeRateSatPerKB)
|
changeIndex, err := w.FundPsbt(
|
||||||
|
tc.packet, 0, tc.feeRateSatPerKB,
|
||||||
|
)
|
||||||
|
|
||||||
// Make sure the error is what we expected.
|
// Make sure the error is what we expected.
|
||||||
if err == nil && tc.expectedErr != "" {
|
if err == nil && tc.expectedErr != "" {
|
||||||
|
@ -258,9 +266,10 @@ func TestFundPsbt(t *testing.T) {
|
||||||
}
|
}
|
||||||
p2wkhIndex := -1
|
p2wkhIndex := -1
|
||||||
p2wshIndex := -1
|
p2wshIndex := -1
|
||||||
changeIndex := -1
|
totalOut := int64(0)
|
||||||
for idx, txOut := range txOuts {
|
for idx, txOut := range txOuts {
|
||||||
script := txOut.PkScript
|
script := txOut.PkScript
|
||||||
|
totalOut += txOut.Value
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case bytes.Equal(script, testScriptP2WKH):
|
case bytes.Equal(script, testScriptP2WKH):
|
||||||
|
@ -269,10 +278,12 @@ func TestFundPsbt(t *testing.T) {
|
||||||
case bytes.Equal(script, testScriptP2WSH):
|
case bytes.Equal(script, testScriptP2WSH):
|
||||||
p2wshIndex = idx
|
p2wshIndex = idx
|
||||||
|
|
||||||
default:
|
|
||||||
changeIndex = idx
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
totalIn := int64(0)
|
||||||
|
for _, txIn := range packet.Inputs {
|
||||||
|
totalIn += txIn.WitnessUtxo.Value
|
||||||
|
}
|
||||||
|
|
||||||
// All outputs must be found.
|
// All outputs must be found.
|
||||||
if p2wkhIndex < 0 || p2wshIndex < 0 || changeIndex < 0 {
|
if p2wkhIndex < 0 || p2wshIndex < 0 || changeIndex < 0 {
|
||||||
|
@ -291,15 +302,17 @@ func TestFundPsbt(t *testing.T) {
|
||||||
txOuts[p2wshIndex].PkScript)
|
txOuts[p2wshIndex].PkScript)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finally, check the change output size and that it
|
// Finally, check the change output size and fee.
|
||||||
// belongs to the wallet.
|
fee := totalIn - totalOut
|
||||||
expectedFee := int64(368)
|
if fee != tc.expectedFee {
|
||||||
expectedChange := 1000000 - 150000 - expectedFee
|
t.Fatalf("unexpected fee, got %d wanted %d",
|
||||||
if txOuts[changeIndex].Value != expectedChange {
|
fee, tc.expectedFee)
|
||||||
|
}
|
||||||
|
if txOuts[changeIndex].Value != tc.expectedChange {
|
||||||
t.Fatalf("unexpected change output size, got "+
|
t.Fatalf("unexpected change output size, got "+
|
||||||
"%d wanted %d",
|
"%d wanted %d",
|
||||||
txOuts[changeIndex].Value,
|
txOuts[changeIndex].Value,
|
||||||
expectedChange)
|
tc.expectedChange)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue