diff --git a/wallet/psbt.go b/wallet/psbt.go index 234f48a..3fd81c8 100644 --- a/wallet/psbt.go +++ b/wallet/psbt.go @@ -70,7 +70,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, addInputInfo := func(inputs []*wire.TxIn) error { packet.Inputs = make([]psbt.PInput, len(inputs)) for idx, in := range inputs { - tx, utxo, _, err := w.FetchInputInfo( + tx, utxo, derivationPath, _, err := w.FetchInputInfo( &in.PreviousOutPoint, ) if err != nil { @@ -91,6 +91,11 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, } packet.Inputs[idx].SighashType = txscript.SigHashAll + // Include the derivation path for each input. + packet.Inputs[idx].Bip32Derivation = []*psbt.Bip32Derivation{ + derivationPath, + } + // We don't want to include the witness or any script // just yet. packet.UnsignedTx.TxIn[idx].Witness = wire.TxWitness{} @@ -227,6 +232,8 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, // // NOTE: This method does NOT publish the transaction after it's been finalized // successfully. +// +// TODO: require account and check if watch only to avoid signing. func (w *Wallet) FinalizePsbt(packet *psbt.Packet) error { // Let's check that this is actually something we can and want to sign. // We need at least one input and one output. @@ -259,7 +266,7 @@ func (w *Wallet) FinalizePsbt(packet *psbt.Packet) error { // We can only sign this input if it's ours, so we try to map it // to a coin we own. If we can't, then we'll continue as it // isn't our input. - fullTx, txOut, _, err := w.FetchInputInfo( + fullTx, txOut, _, _, err := w.FetchInputInfo( &txIn.PreviousOutPoint, ) if err != nil { diff --git a/wallet/utxos.go b/wallet/utxos.go index a29d094..d22cc90 100644 --- a/wallet/utxos.go +++ b/wallet/utxos.go @@ -11,6 +11,8 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil/hdkeychain" + "github.com/btcsuite/btcutil/psbt" "github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/walletdb" ) @@ -105,15 +107,15 @@ func (w *Wallet) UnspentOutputs(policy OutputSelectionPolicy) ([]*TransactionOut // full transaction, the target txout and the number of confirmations are // returned. Otherwise, a non-nil error value of ErrNotMine is returned instead. func (w *Wallet) FetchInputInfo(prevOut *wire.OutPoint) (*wire.MsgTx, - *wire.TxOut, int64, error) { + *wire.TxOut, *psbt.Bip32Derivation, int64, error) { // We manually look up the output within the tx store. txid := &prevOut.Hash txDetail, err := UnstableAPI(w).TxDetails(txid) if err != nil { - return nil, nil, 0, err + return nil, nil, nil, 0, err } else if txDetail == nil { - return nil, nil, 0, ErrNotMine + return nil, nil, nil, 0, ErrNotMine } // With the output retrieved, we'll make an additional check to ensure @@ -122,19 +124,25 @@ func (w *Wallet) FetchInputInfo(prevOut *wire.OutPoint) (*wire.MsgTx, // like in the event of us being the sender of the transaction. numOutputs := uint32(len(txDetail.TxRecord.MsgTx.TxOut)) if prevOut.Index >= numOutputs { - return nil, nil, 0, fmt.Errorf("invalid output index %v for "+ + return nil, nil, nil, 0, fmt.Errorf("invalid output index %v for "+ "transaction with %v outputs", prevOut.Index, numOutputs) } pkScript := txDetail.TxRecord.MsgTx.TxOut[prevOut.Index].PkScript - if _, err := w.fetchOutputAddr(pkScript); err != nil { - return nil, nil, 0, err + addr, err := w.fetchOutputAddr(pkScript) + if err != nil { + return nil, nil, nil, 0, err } + pubKeyAddr, ok := addr.(waddrmgr.ManagedPubKeyAddress) + if !ok { + return nil, nil, nil, 0, err + } + keyScope, derivationPath, _ := pubKeyAddr.DerivationInfo() // Determine the number of confirmations the output currently has. _, currentHeight, err := w.chainClient.GetBestBlock() if err != nil { - return nil, nil, 0, fmt.Errorf("unable to retrieve current "+ + return nil, nil, nil, 0, fmt.Errorf("unable to retrieve current "+ "height: %v", err) } confs := int64(0) @@ -143,9 +151,19 @@ func (w *Wallet) FetchInputInfo(prevOut *wire.OutPoint) (*wire.MsgTx, } return &txDetail.TxRecord.MsgTx, &wire.TxOut{ - Value: txDetail.TxRecord.MsgTx.TxOut[prevOut.Index].Value, - PkScript: pkScript, - }, confs, nil + Value: txDetail.TxRecord.MsgTx.TxOut[prevOut.Index].Value, + PkScript: pkScript, + }, &psbt.Bip32Derivation{ + PubKey: pubKeyAddr.PubKey().SerializeCompressed(), + MasterKeyFingerprint: 0, // TODO + Bip32Path: []uint32{ + keyScope.Purpose + hdkeychain.HardenedKeyStart, + keyScope.Coin + hdkeychain.HardenedKeyStart, + derivationPath.Account, + derivationPath.Branch, + derivationPath.Index, + }, + }, confs, nil } // fetchOutputAddr attempts to fetch the managed address corresponding to the diff --git a/wallet/utxos_test.go b/wallet/utxos_test.go index 3181c80..adeef80 100644 --- a/wallet/utxos_test.go +++ b/wallet/utxos_test.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil/hdkeychain" "github.com/btcsuite/btcwallet/waddrmgr" ) @@ -43,7 +44,7 @@ func TestFetchInputInfo(t *testing.T) { Hash: incomingTx.TxHash(), Index: 0, } - tx, out, confirmations, err := w.FetchInputInfo(prevOut) + tx, out, derivationPath, confirmations, err := w.FetchInputInfo(prevOut) if err != nil { t.Fatalf("error fetching input info: %v", err) } @@ -54,6 +55,32 @@ func TestFetchInputInfo(t *testing.T) { t.Fatalf("unexpected TX out, got %v wanted %v", tx.TxOut[prevOut.Index].PkScript, utxOut) } + if len(derivationPath.Bip32Path) != 5 { + t.Fatalf("expected derivation path of length %v, got %v", 3, + len(derivationPath.Bip32Path)) + } + if derivationPath.Bip32Path[0] != waddrmgr.KeyScopeBIP0084.Purpose { + t.Fatalf("expected purpose %v, got %v", + waddrmgr.KeyScopeBIP0084.Purpose, + derivationPath.Bip32Path[0]) + } + if derivationPath.Bip32Path[1] != waddrmgr.KeyScopeBIP0084.Coin { + t.Fatalf("expected coin type %v, got %v", + waddrmgr.KeyScopeBIP0084.Coin, + derivationPath.Bip32Path[1]) + } + if derivationPath.Bip32Path[2] != hdkeychain.HardenedKeyStart { + t.Fatalf("expected account %v, got %v", + hdkeychain.HardenedKeyStart, derivationPath.Bip32Path[2]) + } + if derivationPath.Bip32Path[3] != 0 { + t.Fatalf("expected branch %v, got %v", 0, + derivationPath.Bip32Path[3]) + } + if derivationPath.Bip32Path[4] != 0 { + t.Fatalf("expected index %v, got %v", 0, + derivationPath.Bip32Path[4]) + } if confirmations != int64(0-testBlockHeight) { t.Fatalf("unexpected number of confirmations, got %d wanted %d", confirmations, 0-testBlockHeight)