From b318e99f4feed9c2ed2470237aa077aa3f6dd208 Mon Sep 17 00:00:00 2001 From: Wilmer Paulino Date: Mon, 15 Mar 2021 17:36:16 -0700 Subject: [PATCH] wallet: extend ChangeSource to support all key scopes --- cmd/sweepaccount/main.go | 13 +++++- go.mod | 1 + wallet/createtx.go | 73 +++++++++++++++++++++++++--------- wallet/psbt.go | 5 ++- wallet/txauthor/author.go | 32 +++++++++------ wallet/txauthor/author_test.go | 31 ++++++++------- wallet/txsizes/size.go | 24 ++++++++--- wallet/txsizes/size_test.go | 6 ++- 8 files changed, 131 insertions(+), 54 deletions(-) diff --git a/cmd/sweepaccount/main.go b/cmd/sweepaccount/main.go index c27ca8f..97a3841 100644 --- a/cmd/sweepaccount/main.go +++ b/cmd/sweepaccount/main.go @@ -22,6 +22,7 @@ import ( "github.com/btcsuite/btcwallet/netparams" "github.com/btcsuite/btcwallet/wallet/txauthor" "github.com/btcsuite/btcwallet/wallet/txrules" + "github.com/btcsuite/btcwallet/wallet/txsizes" "github.com/jessevdk/go-flags" ) @@ -190,14 +191,22 @@ func makeInputSource(outputs []btcjson.ListUnspentResult) txauthor.InputSource { // makeDestinationScriptSource creates a ChangeSource which is used to receive // all correlated previous input value. A non-change address is created by this // function. -func makeDestinationScriptSource(rpcClient *rpcclient.Client, accountName string) txauthor.ChangeSource { - return func() ([]byte, error) { +func makeDestinationScriptSource(rpcClient *rpcclient.Client, accountName string) *txauthor.ChangeSource { + + // GetNewAddress always returns a P2PKH address since it assumes + // BIP-0044. + newChangeScript := func() ([]byte, error) { destinationAddress, err := rpcClient.GetNewAddress(accountName) if err != nil { return nil, err } return txscript.PayToAddrScript(destinationAddress) } + + return &txauthor.ChangeSource{ + ScriptSize: txsizes.P2PKHPkScriptSize, + NewScript: newChangeScript, + } } func main() { diff --git a/go.mod b/go.mod index 82182d9..2f2df85 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/btcsuite/btcutil/psbt v1.0.3-0.20201208143702-a53e38424cce github.com/btcsuite/btcwallet/wallet/txauthor v1.0.0 github.com/btcsuite/btcwallet/wallet/txrules v1.0.0 + github.com/btcsuite/btcwallet/wallet/txsizes v1.0.0 github.com/btcsuite/btcwallet/walletdb v1.3.4 github.com/btcsuite/btcwallet/wtxmgr v1.2.0 github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792 diff --git a/wallet/createtx.go b/wallet/createtx.go index 4df8237..900be64 100644 --- a/wallet/createtx.go +++ b/wallet/createtx.go @@ -15,6 +15,7 @@ import ( "github.com/btcsuite/btcutil" "github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/wallet/txauthor" + "github.com/btcsuite/btcwallet/wallet/txsizes" "github.com/btcsuite/btcwallet/walletdb" "github.com/btcsuite/btcwallet/wtxmgr" ) @@ -123,9 +124,12 @@ func (w *Wallet) txToOutputs(outputs []*wire.TxOut, keyScope *waddrmgr.KeyScope, } defer func() { _ = dbtx.Rollback() }() - addrmgrNs, changeSource := w.addrMgrWithChangeSource( + addrmgrNs, changeSource, err := w.addrMgrWithChangeSource( dbtx, keyScope, account, ) + if err != nil { + return nil, err + } // Get current block's height and hash. bs, err := chainClient.BlockStamp() @@ -288,25 +292,54 @@ func (w *Wallet) findEligibleOutputs(dbtx walletdb.ReadTx, } // addrMgrWithChangeSource returns the address manager bucket and a change -// source function that returns change addresses from said address manager. The -// change addresses will come from the specified key scope and account, unless -// a key scope is not specified. In that case, change addresses will always -// come from the P2WKH key scope. +// source that returns change addresses from said address manager. The change +// addresses will come from the specified key scope and account, unless a key +// scope is not specified. In that case, change addresses will always come from +// the P2WKH key scope. func (w *Wallet) addrMgrWithChangeSource(dbtx walletdb.ReadWriteTx, - changeKeyScope *waddrmgr.KeyScope, account uint32) (walletdb.ReadWriteBucket, - txauthor.ChangeSource) { + changeKeyScope *waddrmgr.KeyScope, account uint32) ( + walletdb.ReadWriteBucket, *txauthor.ChangeSource, error) { + // Determine the address type for change addresses of the given account. + if changeKeyScope == nil { + changeKeyScope = &waddrmgr.KeyScopeBIP0084 + } + addrType := waddrmgr.ScopeAddrMap[*changeKeyScope].InternalAddrType + + // It's possible for the account to have an address schema override, so + // prefer that if it exists. addrmgrNs := dbtx.ReadWriteBucket(waddrmgrNamespaceKey) - changeSource := func() ([]byte, error) { - // Derive the change output script. We'll use the default key - // scope responsible for P2WPKH addresses to do so. As a hack to - // allow spending from the imported account, change addresses - // are created from account 0. - var changeAddr btcutil.Address - var err error - if changeKeyScope == nil { - changeKeyScope = &waddrmgr.KeyScopeBIP0084 - } + scopeMgr, err := w.Manager.FetchScopedKeyManager(*changeKeyScope) + if err != nil { + return nil, nil, err + } + accountInfo, err := scopeMgr.AccountProperties(addrmgrNs, account) + if err != nil { + return nil, nil, err + } + if accountInfo.AddrSchema != nil { + addrType = accountInfo.AddrSchema.InternalAddrType + } + + // Compute the expected size of the script for the change address type. + var scriptSize int + switch addrType { + case waddrmgr.PubKeyHash: + scriptSize = txsizes.P2PKHPkScriptSize + case waddrmgr.NestedWitnessPubKey: + scriptSize = txsizes.NestedP2WPKHPkScriptSize + case waddrmgr.WitnessPubKey: + scriptSize = txsizes.P2WPKHPkScriptSize + } + + newChangeScript := func() ([]byte, error) { + // Derive the change output script. As a hack to allow spending + // from the imported account, change addresses are created from + // account 0. + var ( + changeAddr btcutil.Address + err error + ) if account == waddrmgr.ImportedAddrAccount { changeAddr, err = w.newChangeAddress( addrmgrNs, 0, *changeKeyScope, @@ -321,7 +354,11 @@ func (w *Wallet) addrMgrWithChangeSource(dbtx walletdb.ReadWriteTx, } return txscript.PayToAddrScript(changeAddr) } - return addrmgrNs, changeSource + + return addrmgrNs, &txauthor.ChangeSource{ + ScriptSize: scriptSize, + NewScript: newChangeScript, + }, nil } // validateMsgTx verifies transaction input scripts for tx. All previous output diff --git a/wallet/psbt.go b/wallet/psbt.go index 58d8f71..fc9085b 100644 --- a/wallet/psbt.go +++ b/wallet/psbt.go @@ -173,9 +173,12 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, keyScope *waddrmgr.KeyScope, if err != nil { return 0, err } - _, changeSource := w.addrMgrWithChangeSource( + _, changeSource, err := w.addrMgrWithChangeSource( dbtx, keyScope, account, ) + if err != nil { + return 0, err + } // Ask the txauthor to create a transaction with our selected // coins. This will perform fee estimation and add a change diff --git a/wallet/txauthor/author.go b/wallet/txauthor/author.go index 5290c77..dbeaa05 100644 --- a/wallet/txauthor/author.go +++ b/wallet/txauthor/author.go @@ -60,8 +60,15 @@ type AuthoredTx struct { ChangeIndex int // negative if no change } -// ChangeSource provides P2PKH change output scripts for transaction creation. -type ChangeSource func() ([]byte, error) +// ChangeSource provides change output scripts for transaction creation. +type ChangeSource struct { + // NewScript is a closure that produces unique change output scripts per + // invocation. + NewScript func() ([]byte, error) + + // ScriptSize is the size in bytes of scripts produced by `NewScript`. + ScriptSize int +} // NewUnsignedTransaction creates an unsigned transaction paying to one or more // non-change outputs. An appropriate transaction fee is included based on the @@ -84,10 +91,12 @@ type ChangeSource func() ([]byte, error) // // BUGS: Fee estimation may be off when redeeming non-compressed P2PKH outputs. func NewUnsignedTransaction(outputs []*wire.TxOut, feeRatePerKb btcutil.Amount, - fetchInputs InputSource, fetchChange ChangeSource) (*AuthoredTx, error) { + fetchInputs InputSource, changeSource *ChangeSource) (*AuthoredTx, error) { targetAmount := SumOutputValues(outputs) - estimatedSize := txsizes.EstimateVirtualSize(0, 1, 0, outputs, true) + estimatedSize := txsizes.EstimateVirtualSize( + 0, 1, 0, outputs, changeSource.ScriptSize, + ) targetFee := txrules.FeeForSerializeSize(feeRatePerKb, estimatedSize) for { @@ -115,8 +124,9 @@ func NewUnsignedTransaction(outputs []*wire.TxOut, feeRatePerKb btcutil.Amount, } } - maxSignedSize := txsizes.EstimateVirtualSize(p2pkh, p2wpkh, - nested, outputs, true) + maxSignedSize := txsizes.EstimateVirtualSize( + p2pkh, p2wpkh, nested, outputs, changeSource.ScriptSize, + ) maxRequiredFee := txrules.FeeForSerializeSize(feeRatePerKb, maxSignedSize) remainingAmount := inputAmount - targetAmount if remainingAmount < maxRequiredFee { @@ -130,18 +140,16 @@ func NewUnsignedTransaction(outputs []*wire.TxOut, feeRatePerKb btcutil.Amount, TxOut: outputs, LockTime: 0, } + changeIndex := -1 changeAmount := inputAmount - targetAmount - maxRequiredFee if changeAmount != 0 && !txrules.IsDustAmount(changeAmount, - txsizes.P2WPKHPkScriptSize, txrules.DefaultRelayFeePerKb) { - changeScript, err := fetchChange() + changeSource.ScriptSize, txrules.DefaultRelayFeePerKb) { + + changeScript, err := changeSource.NewScript() if err != nil { return nil, err } - if len(changeScript) > txsizes.P2WPKHPkScriptSize { - return nil, errors.New("fee estimation requires change " + - "scripts no larger than P2WPKH output scripts") - } change := wire.NewTxOut(int64(changeAmount), changeScript) l := len(outputs) unsignedTransaction.TxOut = append(outputs[:l:l], change) diff --git a/wallet/txauthor/author_test.go b/wallet/txauthor/author_test.go index 0217417..2100518 100644 --- a/wallet/txauthor/author_test.go +++ b/wallet/txauthor/author_test.go @@ -61,7 +61,7 @@ func TestNewUnsignedTransaction(t *testing.T) { Outputs: p2pkhOutputs(1e6), RelayFee: 1e3, ChangeAmount: 1e8 - 1e6 - txrules.FeeForSerializeSize(1e3, - txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(1e6), true)), + txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(1e6), txsizes.P2WPKHPkScriptSize)), InputCount: 1, }, 2: { @@ -69,7 +69,7 @@ func TestNewUnsignedTransaction(t *testing.T) { Outputs: p2pkhOutputs(1e6), RelayFee: 1e4, ChangeAmount: 1e8 - 1e6 - txrules.FeeForSerializeSize(1e4, - txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(1e6), true)), + txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(1e6), txsizes.P2WPKHPkScriptSize)), InputCount: 1, }, 3: { @@ -77,7 +77,7 @@ func TestNewUnsignedTransaction(t *testing.T) { Outputs: p2pkhOutputs(1e6, 1e6, 1e6), RelayFee: 1e4, ChangeAmount: 1e8 - 3e6 - txrules.FeeForSerializeSize(1e4, - txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(1e6, 1e6, 1e6), true)), + txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(1e6, 1e6, 1e6), txsizes.P2WPKHPkScriptSize)), InputCount: 1, }, 4: { @@ -85,7 +85,7 @@ func TestNewUnsignedTransaction(t *testing.T) { Outputs: p2pkhOutputs(1e6, 1e6, 1e6), RelayFee: 2.55e3, ChangeAmount: 1e8 - 3e6 - txrules.FeeForSerializeSize(2.55e3, - txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(1e6, 1e6, 1e6), true)), + txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(1e6, 1e6, 1e6), txsizes.P2WPKHPkScriptSize)), InputCount: 1, }, @@ -93,7 +93,7 @@ func TestNewUnsignedTransaction(t *testing.T) { 5: { UnspentOutputs: p2pkhOutputs(1e8), Outputs: p2pkhOutputs(1e8 - 545 - txrules.FeeForSerializeSize(1e3, - txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(0), true))), + txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(0), txsizes.P2WPKHPkScriptSize))), RelayFee: 1e3, ChangeAmount: 545, InputCount: 1, @@ -101,7 +101,7 @@ func TestNewUnsignedTransaction(t *testing.T) { 6: { UnspentOutputs: p2pkhOutputs(1e8), Outputs: p2pkhOutputs(1e8 - 546 - txrules.FeeForSerializeSize(1e3, - txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(0), true))), + txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(0), txsizes.P2WPKHPkScriptSize))), RelayFee: 1e3, ChangeAmount: 546, InputCount: 1, @@ -111,7 +111,7 @@ func TestNewUnsignedTransaction(t *testing.T) { 7: { UnspentOutputs: p2pkhOutputs(1e8), Outputs: p2pkhOutputs(1e8 - 1392 - txrules.FeeForSerializeSize(2.55e3, - txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(0), true))), + txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(0), txsizes.P2WPKHPkScriptSize))), RelayFee: 2.55e3, ChangeAmount: 1392, InputCount: 1, @@ -119,7 +119,7 @@ func TestNewUnsignedTransaction(t *testing.T) { 8: { UnspentOutputs: p2pkhOutputs(1e8), Outputs: p2pkhOutputs(1e8 - 1393 - txrules.FeeForSerializeSize(2.55e3, - txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(0), true))), + txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(0), txsizes.P2WPKHPkScriptSize))), RelayFee: 2.55e3, ChangeAmount: 1393, InputCount: 1, @@ -131,7 +131,7 @@ func TestNewUnsignedTransaction(t *testing.T) { 9: { UnspentOutputs: p2pkhOutputs(1e8, 1e8), Outputs: p2pkhOutputs(1e8 - 546 - txrules.FeeForSerializeSize(1e3, - txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(0), true))), + txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(0), txsizes.P2WPKHPkScriptSize))), RelayFee: 1e3, ChangeAmount: 546, InputCount: 1, @@ -145,7 +145,7 @@ func TestNewUnsignedTransaction(t *testing.T) { 10: { UnspentOutputs: p2pkhOutputs(1e8, 1e8), Outputs: p2pkhOutputs(1e8 - 545 - txrules.FeeForSerializeSize(1e3, - txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(0), true))), + txsizes.EstimateVirtualSize(1, 0, 0, p2pkhOutputs(0), txsizes.P2WPKHPkScriptSize))), RelayFee: 1e3, ChangeAmount: 545, InputCount: 1, @@ -157,7 +157,7 @@ func TestNewUnsignedTransaction(t *testing.T) { Outputs: p2pkhOutputs(1e8), RelayFee: 1e3, ChangeAmount: 1e8 - txrules.FeeForSerializeSize(1e3, - txsizes.EstimateVirtualSize(2, 0, 0, p2pkhOutputs(1e8), true)), + txsizes.EstimateVirtualSize(2, 0, 0, p2pkhOutputs(1e8), txsizes.P2WPKHPkScriptSize)), InputCount: 2, }, @@ -172,9 +172,12 @@ func TestNewUnsignedTransaction(t *testing.T) { }, } - changeSource := func() ([]byte, error) { - // Only length matters for these tests. - return make([]byte, txsizes.P2WPKHPkScriptSize), nil + changeSource := &ChangeSource{ + NewScript: func() ([]byte, error) { + // Only length matters for these tests. + return make([]byte, txsizes.P2WPKHPkScriptSize), nil + }, + ScriptSize: txsizes.P2WPKHPkScriptSize, } for i, test := range tests { diff --git a/wallet/txsizes/size.go b/wallet/txsizes/size.go index da3d54c..4b47284 100644 --- a/wallet/txsizes/size.go +++ b/wallet/txsizes/size.go @@ -82,6 +82,16 @@ const ( // - 4 bytes sequence RedeemP2WPKHInputSize = 32 + 4 + 1 + RedeemP2WPKHScriptSize + 4 + // NestedP2WPKHPkScriptSize is the size of a transaction output script + // that pays to a pay-to-witness-key hash nested in P2SH (P2SH-P2WPKH). + // It is calculated as: + // + // - OP_HASH160 + // - OP_DATA_20 + // - 20 bytes script hash + // - OP_EQUAL + NestedP2WPKHPkScriptSize = 1 + 1 + 20 + 1 + // RedeemNestedP2WPKHScriptSize is the worst case size of a transaction // input script that redeems a pay-to-witness-key hash nested in P2SH // (P2SH-P2WPKH). It is calculated as: @@ -150,12 +160,14 @@ func EstimateSerializeSize(inputCount int, txOuts []*wire.TxOut, addChangeOutput // from txOuts. The estimate is incremented for an additional P2PKH // change output if addChangeOutput is true. func EstimateVirtualSize(numP2PKHIns, numP2WPKHIns, numNestedP2WPKHIns int, - txOuts []*wire.TxOut, addChangeOutput bool) int { - changeSize := 0 + txOuts []*wire.TxOut, changeScriptSize int) int { outputCount := len(txOuts) - if addChangeOutput { - // We are always using P2WPKH as change output. - changeSize = P2WPKHOutputSize + + changeOutputSize := 0 + if changeScriptSize > 0 { + changeOutputSize = 8 + + wire.VarIntSerializeSize(uint64(changeScriptSize)) + + changeScriptSize outputCount++ } @@ -170,7 +182,7 @@ func EstimateVirtualSize(numP2PKHIns, numP2WPKHIns, numNestedP2WPKHIns int, numP2WPKHIns*RedeemP2WPKHInputSize + numNestedP2WPKHIns*RedeemNestedP2WPKHInputSize + SumOutputSerializeSizes(txOuts) + - changeSize + changeOutputSize // If this transaction has any witness inputs, we must count the // witness data. diff --git a/wallet/txsizes/size_test.go b/wallet/txsizes/size_test.go index 6594c33..9f3cd0c 100644 --- a/wallet/txsizes/size_test.go +++ b/wallet/txsizes/size_test.go @@ -163,8 +163,12 @@ func TestEstimateVirtualSize(t *testing.T) { t.Fatalf("unable to get test tx: %v", err) } + changeScriptSize := 0 + if test.change { + changeScriptSize = P2WPKHPkScriptSize + } est := EstimateVirtualSize(test.p2pkhIns, test.p2wpkhIns, - test.nestedp2wpkhIns, tx.TxOut, test.change) + test.nestedp2wpkhIns, tx.TxOut, changeScriptSize) if est != test.result { t.Fatalf("expected estimated vsize to be %d, "+