wallet: include BIP 32 derivation paths for inputs in PSBTs

Watch-only accounts are usually backed by an external hardware signer,
some of which require derivation paths to be populated for each relevant
input to sign.
This commit is contained in:
Wilmer Paulino 2021-02-17 16:24:08 -08:00
parent 07d4cc155e
commit 35b4b237c9
No known key found for this signature in database
GPG key ID: 6DF57B9F9514972F
3 changed files with 65 additions and 13 deletions

View file

@ -70,7 +70,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
addInputInfo := func(inputs []*wire.TxIn) error {
packet.Inputs = make([]psbt.PInput, len(inputs))
for idx, in := range inputs {
tx, utxo, _, err := w.FetchInputInfo(
tx, utxo, derivationPath, _, err := w.FetchInputInfo(
&in.PreviousOutPoint,
)
if err != nil {
@ -91,6 +91,11 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
}
packet.Inputs[idx].SighashType = txscript.SigHashAll
// Include the derivation path for each input.
packet.Inputs[idx].Bip32Derivation = []*psbt.Bip32Derivation{
derivationPath,
}
// We don't want to include the witness or any script
// just yet.
packet.UnsignedTx.TxIn[idx].Witness = wire.TxWitness{}
@ -227,6 +232,8 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
//
// NOTE: This method does NOT publish the transaction after it's been finalized
// successfully.
//
// TODO: require account and check if watch only to avoid signing.
func (w *Wallet) FinalizePsbt(packet *psbt.Packet) error {
// Let's check that this is actually something we can and want to sign.
// We need at least one input and one output.
@ -259,7 +266,7 @@ func (w *Wallet) FinalizePsbt(packet *psbt.Packet) error {
// We can only sign this input if it's ours, so we try to map it
// to a coin we own. If we can't, then we'll continue as it
// isn't our input.
fullTx, txOut, _, err := w.FetchInputInfo(
fullTx, txOut, _, _, err := w.FetchInputInfo(
&txIn.PreviousOutPoint,
)
if err != nil {

View file

@ -11,6 +11,8 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil/hdkeychain"
"github.com/btcsuite/btcutil/psbt"
"github.com/btcsuite/btcwallet/waddrmgr"
"github.com/btcsuite/btcwallet/walletdb"
)
@ -105,15 +107,15 @@ func (w *Wallet) UnspentOutputs(policy OutputSelectionPolicy) ([]*TransactionOut
// full transaction, the target txout and the number of confirmations are
// returned. Otherwise, a non-nil error value of ErrNotMine is returned instead.
func (w *Wallet) FetchInputInfo(prevOut *wire.OutPoint) (*wire.MsgTx,
*wire.TxOut, int64, error) {
*wire.TxOut, *psbt.Bip32Derivation, int64, error) {
// We manually look up the output within the tx store.
txid := &prevOut.Hash
txDetail, err := UnstableAPI(w).TxDetails(txid)
if err != nil {
return nil, nil, 0, err
return nil, nil, nil, 0, err
} else if txDetail == nil {
return nil, nil, 0, ErrNotMine
return nil, nil, nil, 0, ErrNotMine
}
// With the output retrieved, we'll make an additional check to ensure
@ -122,19 +124,25 @@ func (w *Wallet) FetchInputInfo(prevOut *wire.OutPoint) (*wire.MsgTx,
// like in the event of us being the sender of the transaction.
numOutputs := uint32(len(txDetail.TxRecord.MsgTx.TxOut))
if prevOut.Index >= numOutputs {
return nil, nil, 0, fmt.Errorf("invalid output index %v for "+
return nil, nil, nil, 0, fmt.Errorf("invalid output index %v for "+
"transaction with %v outputs", prevOut.Index,
numOutputs)
}
pkScript := txDetail.TxRecord.MsgTx.TxOut[prevOut.Index].PkScript
if _, err := w.fetchOutputAddr(pkScript); err != nil {
return nil, nil, 0, err
addr, err := w.fetchOutputAddr(pkScript)
if err != nil {
return nil, nil, nil, 0, err
}
pubKeyAddr, ok := addr.(waddrmgr.ManagedPubKeyAddress)
if !ok {
return nil, nil, nil, 0, err
}
keyScope, derivationPath, _ := pubKeyAddr.DerivationInfo()
// Determine the number of confirmations the output currently has.
_, currentHeight, err := w.chainClient.GetBestBlock()
if err != nil {
return nil, nil, 0, fmt.Errorf("unable to retrieve current "+
return nil, nil, nil, 0, fmt.Errorf("unable to retrieve current "+
"height: %v", err)
}
confs := int64(0)
@ -143,9 +151,19 @@ func (w *Wallet) FetchInputInfo(prevOut *wire.OutPoint) (*wire.MsgTx,
}
return &txDetail.TxRecord.MsgTx, &wire.TxOut{
Value: txDetail.TxRecord.MsgTx.TxOut[prevOut.Index].Value,
PkScript: pkScript,
}, confs, nil
Value: txDetail.TxRecord.MsgTx.TxOut[prevOut.Index].Value,
PkScript: pkScript,
}, &psbt.Bip32Derivation{
PubKey: pubKeyAddr.PubKey().SerializeCompressed(),
MasterKeyFingerprint: 0, // TODO
Bip32Path: []uint32{
keyScope.Purpose + hdkeychain.HardenedKeyStart,
keyScope.Coin + hdkeychain.HardenedKeyStart,
derivationPath.Account,
derivationPath.Branch,
derivationPath.Index,
},
}, confs, nil
}
// fetchOutputAddr attempts to fetch the managed address corresponding to the

View file

@ -10,6 +10,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil/hdkeychain"
"github.com/btcsuite/btcwallet/waddrmgr"
)
@ -43,7 +44,7 @@ func TestFetchInputInfo(t *testing.T) {
Hash: incomingTx.TxHash(),
Index: 0,
}
tx, out, confirmations, err := w.FetchInputInfo(prevOut)
tx, out, derivationPath, confirmations, err := w.FetchInputInfo(prevOut)
if err != nil {
t.Fatalf("error fetching input info: %v", err)
}
@ -54,6 +55,32 @@ func TestFetchInputInfo(t *testing.T) {
t.Fatalf("unexpected TX out, got %v wanted %v",
tx.TxOut[prevOut.Index].PkScript, utxOut)
}
if len(derivationPath.Bip32Path) != 5 {
t.Fatalf("expected derivation path of length %v, got %v", 3,
len(derivationPath.Bip32Path))
}
if derivationPath.Bip32Path[0] != waddrmgr.KeyScopeBIP0084.Purpose {
t.Fatalf("expected purpose %v, got %v",
waddrmgr.KeyScopeBIP0084.Purpose,
derivationPath.Bip32Path[0])
}
if derivationPath.Bip32Path[1] != waddrmgr.KeyScopeBIP0084.Coin {
t.Fatalf("expected coin type %v, got %v",
waddrmgr.KeyScopeBIP0084.Coin,
derivationPath.Bip32Path[1])
}
if derivationPath.Bip32Path[2] != hdkeychain.HardenedKeyStart {
t.Fatalf("expected account %v, got %v",
hdkeychain.HardenedKeyStart, derivationPath.Bip32Path[2])
}
if derivationPath.Bip32Path[3] != 0 {
t.Fatalf("expected branch %v, got %v", 0,
derivationPath.Bip32Path[3])
}
if derivationPath.Bip32Path[4] != 0 {
t.Fatalf("expected index %v, got %v", 0,
derivationPath.Bip32Path[4])
}
if confirmations != int64(0-testBlockHeight) {
t.Fatalf("unexpected number of confirmations, got %d wanted %d",
confirmations, 0-testBlockHeight)