From 3252f9fc110178edcb1650d830b2ea946b6c2501 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Mon, 5 Oct 2020 10:58:55 +0200 Subject: [PATCH] psbt: don't add scriptSig to txIn Because of an incorrect test, it wasn't discovered that the scriptSig field was being set on the unsigned TX inputs for a nested SegWit input. This commit fixes the bug and also refactors the test so it would have caught this specific bug. --- wallet/psbt.go | 4 +- wallet/psbt_test.go | 303 ++++++++++++++++++++++---------------------- 2 files changed, 156 insertions(+), 151 deletions(-) diff --git a/wallet/psbt.go b/wallet/psbt.go index 1a68a71..234f48a 100644 --- a/wallet/psbt.go +++ b/wallet/psbt.go @@ -91,8 +91,10 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, } packet.Inputs[idx].SighashType = txscript.SigHashAll - // We don't want to include the witness just yet. + // We don't want to include the witness or any script + // just yet. packet.UnsignedTx.TxIn[idx].Witness = wire.TxWitness{} + packet.UnsignedTx.TxIn[idx].SignatureScript = nil } return nil diff --git a/wallet/psbt_test.go b/wallet/psbt_test.go index b9747b2..af9896b 100644 --- a/wallet/psbt_test.go +++ b/wallet/psbt_test.go @@ -43,7 +43,7 @@ func TestFundPsbt(t *testing.T) { } // Also create a nested P2WKH address we can use to send some coins to. - addr, err = w.CurrentAddress(0, waddrmgr.KeyScopeBIP0084) + addr, err = w.CurrentAddress(0, waddrmgr.KeyScopeBIP0049Plus) if err != nil { t.Fatalf("unable to get current address: %v", addr) } @@ -53,24 +53,36 @@ func TestFundPsbt(t *testing.T) { } // Register two big UTXO that will be used when funding the PSBT. - incomingTx := &wire.MsgTx{ - TxIn: []*wire.TxIn{{}}, - TxOut: []*wire.TxOut{ - wire.NewTxOut(1000000, p2wkhAddr), - wire.NewTxOut(1000000, np2wkhAddr), - }, + incomingTx1 := &wire.MsgTx{ + TxIn: []*wire.TxIn{{}}, + TxOut: []*wire.TxOut{wire.NewTxOut(1000000, p2wkhAddr)}, + } + addUtxo(t, w, incomingTx1) + utxo1 := wire.OutPoint{ + Hash: incomingTx1.TxHash(), + Index: 0, + } + + incomingTx2 := &wire.MsgTx{ + TxIn: []*wire.TxIn{{}}, + TxOut: []*wire.TxOut{wire.NewTxOut(900000, np2wkhAddr)}, + } + addUtxo(t, w, incomingTx2) + utxo2 := wire.OutPoint{ + Hash: incomingTx2.TxHash(), + Index: 0, } - addUtxo(t, w, incomingTx) testCases := []struct { - name string - packet *psbt.Packet - feeRateSatPerKB btcutil.Amount - expectedErr string - validatePackage bool - expectedFee int64 - expectedChange int64 - numExpectedInputs int + name string + packet *psbt.Packet + feeRateSatPerKB btcutil.Amount + expectedErr string + validatePackage bool + expectedFee int64 + expectedChange int64 + expectedInputs []wire.OutPoint + additionalChecks func(*testing.T, *psbt.Packet, int32) }{{ name: "no outputs provided", packet: &psbt.Packet{ @@ -105,26 +117,37 @@ func TestFundPsbt(t *testing.T) { }, Outputs: []psbt.POutput{{}, {}}, }, - feeRateSatPerKB: 2000, // 2 sat/byte - expectedErr: "", - validatePackage: true, - expectedChange: 1000000 - 150000 - 368, - expectedFee: 368, - numExpectedInputs: 1, + feeRateSatPerKB: 2000, // 2 sat/byte + expectedErr: "", + validatePackage: true, + expectedFee: 368, + expectedChange: 1000000 - 150000 - 368, + expectedInputs: []wire.OutPoint{utxo1}, + }, { + name: "large output, no inputs", + packet: &psbt.Packet{ + UnsignedTx: &wire.MsgTx{ + TxOut: []*wire.TxOut{{ + PkScript: testScriptP2WSH, + Value: 1500000, + }}, + }, + Outputs: []psbt.POutput{{}}, + }, + feeRateSatPerKB: 4000, // 4 sat/byte + expectedErr: "", + validatePackage: true, + expectedFee: 980, + expectedChange: 1900000 - 1500000 - 980, + expectedInputs: []wire.OutPoint{utxo1, utxo2}, }, { name: "two outputs, two inputs", packet: &psbt.Packet{ UnsignedTx: &wire.MsgTx{ TxIn: []*wire.TxIn{{ - PreviousOutPoint: wire.OutPoint{ - Hash: incomingTx.TxHash(), - Index: 0, - }, + PreviousOutPoint: utxo1, }, { - PreviousOutPoint: wire.OutPoint{ - Hash: incomingTx.TxHash(), - Index: 1, - }, + PreviousOutPoint: utxo2, }}, TxOut: []*wire.TxOut{{ PkScript: testScriptP2WSH, @@ -137,126 +160,14 @@ func TestFundPsbt(t *testing.T) { Inputs: []psbt.PInput{{}, {}}, Outputs: []psbt.POutput{{}, {}}, }, - feeRateSatPerKB: 2000, // 2 sat/byte - expectedErr: "", - validatePackage: true, - expectedFee: 506, - expectedChange: 2000000 - 150000 - 506, - numExpectedInputs: 2, - }} - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - changeIndex, err := w.FundPsbt( - tc.packet, 0, tc.feeRateSatPerKB, - ) - - // Make sure the error is what we expected. - if err == nil && tc.expectedErr != "" { - t.Fatalf("expected error '%s' but got nil", - tc.expectedErr) - } - if err != nil && tc.expectedErr == "" { - t.Fatalf("expected nil error but got '%v'", err) - } - if err != nil && - !strings.Contains(err.Error(), tc.expectedErr) { - - t.Fatalf("expected error '%s' but got '%v'", - tc.expectedErr, err) - } - - if !tc.validatePackage { - return - } - - // Check wire inputs. - packet := tc.packet - if len(packet.UnsignedTx.TxIn) != tc.numExpectedInputs { - t.Fatalf("expected %d inputs to be added, got "+ - "%d", tc.numExpectedInputs, - len(packet.UnsignedTx.TxIn)) - } - txIn := packet.UnsignedTx.TxIn[0] - if txIn.PreviousOutPoint.Hash != incomingTx.TxHash() { - t.Fatalf("unexpected UTXO prev outpoint "+ - "hash, got %v wanted %v", - txIn.PreviousOutPoint.Hash, - incomingTx.TxHash()) - } - if tc.numExpectedInputs > 1 { - txIn2 := packet.UnsignedTx.TxIn[1] - if txIn2.PreviousOutPoint.Hash != incomingTx.TxHash() { - t.Fatalf("unexpected UTXO prev outpoint "+ - "hash, got %v wanted %v", - txIn2.PreviousOutPoint.Hash, - incomingTx.TxHash()) - } - } - - // Check partial inputs. - if len(packet.Inputs) != tc.numExpectedInputs { - t.Fatalf("expected %d partial input to be "+ - "added, got %d", tc.numExpectedInputs, - len(packet.Inputs)) - } - in := packet.Inputs[0] - if in.WitnessUtxo == nil { - t.Fatalf("partial input witness UTXO not set") - } - if !bytes.Equal(in.WitnessUtxo.PkScript, p2wkhAddr) { - t.Fatalf("unexpected witness UTXO script, "+ - "got %x wanted %x", - in.WitnessUtxo.PkScript, p2wkhAddr) - } - if in.NonWitnessUtxo == nil { - t.Fatalf("partial input non-witness UTXO not " + - "set") - } - prevIdx := txIn.PreviousOutPoint.Index - nonWitnessOut := in.NonWitnessUtxo.TxOut[prevIdx] - if !bytes.Equal(nonWitnessOut.PkScript, p2wkhAddr) { - t.Fatalf("unexpected witness UTXO script, "+ - "got %x wanted %x", - nonWitnessOut.PkScript, p2wkhAddr) - } - if in.SighashType != txscript.SigHashAll { - t.Fatalf("unexpected sighash flag, got %d "+ - "wanted %d", in.SighashType, - txscript.SigHashAll) - } - if tc.numExpectedInputs > 1 { - in2 := packet.Inputs[1] - if in2.WitnessUtxo == nil { - t.Fatalf("partial input witness UTXO " + - "not set") - } - if !bytes.Equal(in2.WitnessUtxo.PkScript, np2wkhAddr) { - t.Fatalf("unexpected witness UTXO "+ - "script, got %x wanted %x", - in2.WitnessUtxo.PkScript, - np2wkhAddr) - } - if in2.NonWitnessUtxo == nil { - t.Fatalf("partial input non-witness " + - "UTXO not set") - } - txIn2 := packet.UnsignedTx.TxIn[1] - prevIdx2 := txIn2.PreviousOutPoint.Index - nonWitnessOut2 := in2.NonWitnessUtxo.TxOut[prevIdx2] - if !bytes.Equal(nonWitnessOut2.PkScript, p2wkhAddr) { - t.Fatalf("unexpected witness UTXO script, "+ - "got %x wanted %x", - nonWitnessOut2.PkScript, p2wkhAddr) - } - if in2.SighashType != txscript.SigHashAll { - t.Fatalf("unexpected sighash flag, "+ - "got %d wanted %d", - in2.SighashType, - txscript.SigHashAll) - } - } + feeRateSatPerKB: 2000, // 2 sat/byte + expectedErr: "", + validatePackage: true, + expectedFee: 552, + expectedChange: 1900000 - 150000 - 552, + expectedInputs: []wire.OutPoint{utxo1, utxo2}, + additionalChecks: func(t *testing.T, packet *psbt.Packet, + changeIndex int32) { // Check outputs, find index for each of the 3 expected. txOuts := packet.UnsignedTx.TxOut @@ -301,8 +212,60 @@ func TestFundPsbt(t *testing.T) { txOuts[p2wkhIndex].PkScript, txOuts[p2wshIndex].PkScript) } + }, + }} + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + changeIndex, err := w.FundPsbt( + tc.packet, 0, tc.feeRateSatPerKB, + ) + + // In any case, unlock the UTXO before continuing, we + // don't want to pollute other test iterations. + for _, in := range tc.packet.UnsignedTx.TxIn { + w.UnlockOutpoint(in.PreviousOutPoint) + } + + // Make sure the error is what we expected. + if err == nil && tc.expectedErr != "" { + t.Fatalf("expected error '%s' but got nil", + tc.expectedErr) + } + if err != nil && tc.expectedErr == "" { + t.Fatalf("expected nil error but got '%v'", err) + } + if err != nil && + !strings.Contains(err.Error(), tc.expectedErr) { + + t.Fatalf("expected error '%s' but got '%v'", + tc.expectedErr, err) + } + + if !tc.validatePackage { + return + } + + // Check wire inputs. + packet := tc.packet + assertTxInputs(t, packet, tc.expectedInputs) + + // Run any additional tests if available. + if tc.additionalChecks != nil { + tc.additionalChecks(t, packet, changeIndex) + } // Finally, check the change output size and fee. + txOuts := packet.UnsignedTx.TxOut + totalOut := int64(0) + for _, txOut := range txOuts { + totalOut += txOut.Value + } + totalIn := int64(0) + for _, txIn := range packet.Inputs { + totalIn += txIn.WitnessUtxo.Value + } fee := totalIn - totalOut if fee != tc.expectedFee { t.Fatalf("unexpected fee, got %d wanted %d", @@ -318,6 +281,46 @@ func TestFundPsbt(t *testing.T) { } } +func assertTxInputs(t *testing.T, packet *psbt.Packet, + expected []wire.OutPoint) { + + if len(packet.UnsignedTx.TxIn) != len(expected) { + t.Fatalf("expected %d inputs to be added, got %d", + len(expected), len(packet.UnsignedTx.TxIn)) + } + + // The order of the UTXOs is random, we need to loop through each of + // them to make sure they're found. We also check that no signature data + // was added yet. + for _, txIn := range packet.UnsignedTx.TxIn { + if !containsUtxo(expected, txIn.PreviousOutPoint) { + t.Fatalf("outpoint %v not found in list of expected "+ + "UTXOs", txIn.PreviousOutPoint) + } + + if len(txIn.SignatureScript) > 0 { + t.Fatalf("expected scriptSig to be empty on "+ + "txin, got %x instead", + txIn.SignatureScript) + } + if len(txIn.Witness) > 0 { + t.Fatalf("expected witness to be empty on "+ + "txin, got %v instead", + txIn.Witness) + } + } +} + +func containsUtxo(list []wire.OutPoint, candidate wire.OutPoint) bool { + for _, utxo := range list { + if utxo == candidate { + return true + } + } + + return false +} + // TestFinalizePsbt tests that a given PSBT packet can be finalized. func TestFinalizePsbt(t *testing.T) { w, cleanup := testWallet(t)