Merge pull request #721 from guggero/psbt-change-fix

wallet: use constant input source for change calculation
This commit is contained in:
Olaoluwa Osuntokun 2020-10-01 17:39:44 -07:00 committed by GitHub
commit e6d01202cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 90 additions and 27 deletions

View file

@ -19,7 +19,9 @@ import (
// FundPsbt creates a fully populated PSBT packet that contains enough inputs to // FundPsbt creates a fully populated PSBT packet that contains enough inputs to
// fund the outputs specified in the passed in packet with the specified fee // fund the outputs specified in the passed in packet with the specified fee
// rate. If there is change left, a change output from the wallet is added. // rate. If there is change left, a change output from the wallet is added and
// the index of the change output is returned. Otherwise no additional output
// is created and the index -1 is returned.
// //
// NOTE: If the packet doesn't contain any inputs, coin selection is performed // NOTE: If the packet doesn't contain any inputs, coin selection is performed
// automatically. If the packet does contain any inputs, it is assumed that full // automatically. If the packet does contain any inputs, it is assumed that full
@ -32,13 +34,13 @@ import (
// selected/validated inputs by this method. It is in the caller's // selected/validated inputs by this method. It is in the caller's
// responsibility to lock the inputs before handing the partial transaction out. // responsibility to lock the inputs before handing the partial transaction out.
func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
feeSatPerKB btcutil.Amount) error { feeSatPerKB btcutil.Amount) (int32, error) {
// Make sure the packet is well formed. We only require there to be at // Make sure the packet is well formed. We only require there to be at
// least one output but not necessarily any inputs. // least one output but not necessarily any inputs.
err := psbt.VerifyInputOutputLen(packet, false, true) err := psbt.VerifyInputOutputLen(packet, false, true)
if err != nil { if err != nil {
return err return 0, err
} }
txOut := packet.UnsignedTx.TxOut txOut := packet.UnsignedTx.TxOut
@ -53,7 +55,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
// dust. // dust.
err := txrules.CheckOutput(output, txrules.DefaultRelayFeePerKb) err := txrules.CheckOutput(output, txrules.DefaultRelayFeePerKb)
if err != nil { if err != nil {
return err return 0, err
} }
} }
@ -108,7 +110,8 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
false, false,
) )
if err != nil { if err != nil {
return fmt.Errorf("error creating funding TX: %v", err) return 0, fmt.Errorf("error creating funding TX: %v",
err)
} }
// Copy over the inputs now then collect all UTXO information // Copy over the inputs now then collect all UTXO information
@ -118,7 +121,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
packet.UnsignedTx.TxIn = tx.Tx.TxIn packet.UnsignedTx.TxIn = tx.Tx.TxIn
err = addInputInfo(tx.Tx.TxIn) err = addInputInfo(tx.Tx.TxIn)
if err != nil { if err != nil {
return err return 0, err
} }
// If there are inputs, we need to check if they're sufficient and add // If there are inputs, we need to check if they're sufficient and add
@ -127,11 +130,18 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
// Make sure all inputs provided are actually ours. // Make sure all inputs provided are actually ours.
err = addInputInfo(txIn) err = addInputInfo(txIn)
if err != nil { if err != nil {
return err return 0, err
} }
// We can leverage the fee calculation of the txauthor package // We can leverage the fee calculation of the txauthor package
// if we provide the selected UTXOs as a coin source. // if we provide the selected UTXOs as a coin source. We just
// need to make sure we always return the full list of user-
// selected UTXOs rather than a subset, otherwise our change
// amount will be off (in case the user selected multiple UTXOs
// that are large enough on their own). That's why we use our
// own static input source creator instead of the more generic
// makeInputSource() that selects a subset that is "large
// enough".
credits := make([]wtxmgr.Credit, len(txIn)) credits := make([]wtxmgr.Credit, len(txIn))
for idx, in := range txIn { for idx, in := range txIn {
utxo := packet.Inputs[idx].WitnessUtxo utxo := packet.Inputs[idx].WitnessUtxo
@ -141,13 +151,13 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
PkScript: utxo.PkScript, PkScript: utxo.PkScript,
} }
} }
inputSource := makeInputSource(credits) inputSource := constantInputSource(credits)
// We also need a change source which needs to be able to insert // We also need a change source which needs to be able to insert
// a new change addresse into the database. // a new change addresse into the database.
dbtx, err := w.db.BeginReadWriteTx() dbtx, err := w.db.BeginReadWriteTx()
if err != nil { if err != nil {
return err return 0, err
} }
_, changeSource := w.addrMgrWithChangeSource(dbtx, account) _, changeSource := w.addrMgrWithChangeSource(dbtx, account)
@ -159,24 +169,25 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
) )
if err != nil { if err != nil {
_ = dbtx.Rollback() _ = dbtx.Rollback()
return fmt.Errorf("fee estimation not successful: %v", return 0, fmt.Errorf("fee estimation not successful: "+
err) "%v", err)
} }
// The transaction could be created, let's commit the DB TX to // The transaction could be created, let's commit the DB TX to
// store the change address (if one was created). // store the change address (if one was created).
err = dbtx.Commit() err = dbtx.Commit()
if err != nil { if err != nil {
return fmt.Errorf("could not add change address to "+ return 0, fmt.Errorf("could not add change address to "+
"database: %v", err) "database: %v", err)
} }
} }
// If there is a change output, we need to copy it over to the PSBT now. // If there is a change output, we need to copy it over to the PSBT now.
var changeTxOut *wire.TxOut
if tx.ChangeIndex >= 0 { if tx.ChangeIndex >= 0 {
changeTxOut = tx.Tx.TxOut[tx.ChangeIndex]
packet.UnsignedTx.TxOut = append( packet.UnsignedTx.TxOut = append(
packet.UnsignedTx.TxOut, packet.UnsignedTx.TxOut, changeTxOut,
tx.Tx.TxOut[tx.ChangeIndex],
) )
packet.Outputs = append(packet.Outputs, psbt.POutput{}) packet.Outputs = append(packet.Outputs, psbt.POutput{})
} }
@ -186,10 +197,22 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
// partial inputs and outputs accordingly. // partial inputs and outputs accordingly.
err = psbt.InPlaceSort(packet) err = psbt.InPlaceSort(packet)
if err != nil { if err != nil {
return fmt.Errorf("could not sort PSBT: %v", err) return 0, fmt.Errorf("could not sort PSBT: %v", err)
} }
return nil // The change output index might have changed after the sorting. We need
// to find our index again.
changeIndex := int32(-1)
if changeTxOut != nil {
for idx, txOut := range packet.UnsignedTx.TxOut {
if psbt.TxOutsEqual(changeTxOut, txOut) {
changeIndex = int32(idx)
break
}
}
}
return changeIndex, nil
} }
// FinalizePsbt expects a partial transaction with all inputs and outputs fully // FinalizePsbt expects a partial transaction with all inputs and outputs fully
@ -303,3 +326,30 @@ func (w *Wallet) FinalizePsbt(packet *psbt.Packet) error {
return nil return nil
} }
// constantInputSource creates an input source function that always returns the
// static set of user-selected UTXOs.
func constantInputSource(eligible []wtxmgr.Credit) txauthor.InputSource {
// Current inputs and their total value. These won't change over
// different invocations as we want our inputs to remain static since
// they're selected by the user.
currentTotal := btcutil.Amount(0)
currentInputs := make([]*wire.TxIn, 0, len(eligible))
currentScripts := make([][]byte, 0, len(eligible))
currentInputValues := make([]btcutil.Amount, 0, len(eligible))
for _, credit := range eligible {
nextInput := wire.NewTxIn(&credit.OutPoint, nil, nil)
currentTotal += credit.Amount
currentInputs = append(currentInputs, nextInput)
currentScripts = append(currentScripts, credit.PkScript)
currentInputValues = append(currentInputValues, credit.Amount)
}
return func(target btcutil.Amount) (btcutil.Amount, []*wire.TxIn,
[]btcutil.Amount, [][]byte, error) {
return currentTotal, currentInputs, currentInputValues,
currentScripts, nil
}
}

View file

@ -68,6 +68,8 @@ func TestFundPsbt(t *testing.T) {
feeRateSatPerKB btcutil.Amount feeRateSatPerKB btcutil.Amount
expectedErr string expectedErr string
validatePackage bool validatePackage bool
expectedFee int64
expectedChange int64
numExpectedInputs int numExpectedInputs int
}{{ }{{
name: "no outputs provided", name: "no outputs provided",
@ -106,6 +108,8 @@ func TestFundPsbt(t *testing.T) {
feeRateSatPerKB: 2000, // 2 sat/byte feeRateSatPerKB: 2000, // 2 sat/byte
expectedErr: "", expectedErr: "",
validatePackage: true, validatePackage: true,
expectedChange: 1000000 - 150000 - 368,
expectedFee: 368,
numExpectedInputs: 1, numExpectedInputs: 1,
}, { }, {
name: "two outputs, two inputs", name: "two outputs, two inputs",
@ -136,13 +140,17 @@ func TestFundPsbt(t *testing.T) {
feeRateSatPerKB: 2000, // 2 sat/byte feeRateSatPerKB: 2000, // 2 sat/byte
expectedErr: "", expectedErr: "",
validatePackage: true, validatePackage: true,
expectedFee: 506,
expectedChange: 2000000 - 150000 - 506,
numExpectedInputs: 2, numExpectedInputs: 2,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
err := w.FundPsbt(tc.packet, 0, tc.feeRateSatPerKB) changeIndex, err := w.FundPsbt(
tc.packet, 0, tc.feeRateSatPerKB,
)
// Make sure the error is what we expected. // Make sure the error is what we expected.
if err == nil && tc.expectedErr != "" { if err == nil && tc.expectedErr != "" {
@ -258,9 +266,10 @@ func TestFundPsbt(t *testing.T) {
} }
p2wkhIndex := -1 p2wkhIndex := -1
p2wshIndex := -1 p2wshIndex := -1
changeIndex := -1 totalOut := int64(0)
for idx, txOut := range txOuts { for idx, txOut := range txOuts {
script := txOut.PkScript script := txOut.PkScript
totalOut += txOut.Value
switch { switch {
case bytes.Equal(script, testScriptP2WKH): case bytes.Equal(script, testScriptP2WKH):
@ -269,10 +278,12 @@ func TestFundPsbt(t *testing.T) {
case bytes.Equal(script, testScriptP2WSH): case bytes.Equal(script, testScriptP2WSH):
p2wshIndex = idx p2wshIndex = idx
default:
changeIndex = idx
} }
} }
totalIn := int64(0)
for _, txIn := range packet.Inputs {
totalIn += txIn.WitnessUtxo.Value
}
// All outputs must be found. // All outputs must be found.
if p2wkhIndex < 0 || p2wshIndex < 0 || changeIndex < 0 { if p2wkhIndex < 0 || p2wshIndex < 0 || changeIndex < 0 {
@ -291,15 +302,17 @@ func TestFundPsbt(t *testing.T) {
txOuts[p2wshIndex].PkScript) txOuts[p2wshIndex].PkScript)
} }
// Finally, check the change output size and that it // Finally, check the change output size and fee.
// belongs to the wallet. fee := totalIn - totalOut
expectedFee := int64(368) if fee != tc.expectedFee {
expectedChange := 1000000 - 150000 - expectedFee t.Fatalf("unexpected fee, got %d wanted %d",
if txOuts[changeIndex].Value != expectedChange { fee, tc.expectedFee)
}
if txOuts[changeIndex].Value != tc.expectedChange {
t.Fatalf("unexpected change output size, got "+ t.Fatalf("unexpected change output size, got "+
"%d wanted %d", "%d wanted %d",
txOuts[changeIndex].Value, txOuts[changeIndex].Value,
expectedChange) tc.expectedChange)
} }
}) })
} }