Merge pull request #722 from guggero/psbt-script-sig-fix

psbt: don't add scriptSig to txIn
This commit is contained in:
Olaoluwa Osuntokun 2020-10-05 11:48:31 -07:00 committed by GitHub
commit a7f551a630
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 156 additions and 151 deletions

View file

@ -91,8 +91,10 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32,
} }
packet.Inputs[idx].SighashType = txscript.SigHashAll packet.Inputs[idx].SighashType = txscript.SigHashAll
// We don't want to include the witness just yet. // We don't want to include the witness or any script
// just yet.
packet.UnsignedTx.TxIn[idx].Witness = wire.TxWitness{} packet.UnsignedTx.TxIn[idx].Witness = wire.TxWitness{}
packet.UnsignedTx.TxIn[idx].SignatureScript = nil
} }
return nil return nil

View file

@ -43,7 +43,7 @@ func TestFundPsbt(t *testing.T) {
} }
// Also create a nested P2WKH address we can use to send some coins to. // Also create a nested P2WKH address we can use to send some coins to.
addr, err = w.CurrentAddress(0, waddrmgr.KeyScopeBIP0084) addr, err = w.CurrentAddress(0, waddrmgr.KeyScopeBIP0049Plus)
if err != nil { if err != nil {
t.Fatalf("unable to get current address: %v", addr) t.Fatalf("unable to get current address: %v", addr)
} }
@ -53,14 +53,25 @@ func TestFundPsbt(t *testing.T) {
} }
// Register two big UTXO that will be used when funding the PSBT. // Register two big UTXO that will be used when funding the PSBT.
incomingTx := &wire.MsgTx{ incomingTx1 := &wire.MsgTx{
TxIn: []*wire.TxIn{{}}, TxIn: []*wire.TxIn{{}},
TxOut: []*wire.TxOut{ TxOut: []*wire.TxOut{wire.NewTxOut(1000000, p2wkhAddr)},
wire.NewTxOut(1000000, p2wkhAddr), }
wire.NewTxOut(1000000, np2wkhAddr), addUtxo(t, w, incomingTx1)
}, utxo1 := wire.OutPoint{
Hash: incomingTx1.TxHash(),
Index: 0,
}
incomingTx2 := &wire.MsgTx{
TxIn: []*wire.TxIn{{}},
TxOut: []*wire.TxOut{wire.NewTxOut(900000, np2wkhAddr)},
}
addUtxo(t, w, incomingTx2)
utxo2 := wire.OutPoint{
Hash: incomingTx2.TxHash(),
Index: 0,
} }
addUtxo(t, w, incomingTx)
testCases := []struct { testCases := []struct {
name string name string
@ -70,7 +81,8 @@ func TestFundPsbt(t *testing.T) {
validatePackage bool validatePackage bool
expectedFee int64 expectedFee int64
expectedChange int64 expectedChange int64
numExpectedInputs int expectedInputs []wire.OutPoint
additionalChecks func(*testing.T, *psbt.Packet, int32)
}{{ }{{
name: "no outputs provided", name: "no outputs provided",
packet: &psbt.Packet{ packet: &psbt.Packet{
@ -108,23 +120,34 @@ func TestFundPsbt(t *testing.T) {
feeRateSatPerKB: 2000, // 2 sat/byte feeRateSatPerKB: 2000, // 2 sat/byte
expectedErr: "", expectedErr: "",
validatePackage: true, validatePackage: true,
expectedChange: 1000000 - 150000 - 368,
expectedFee: 368, expectedFee: 368,
numExpectedInputs: 1, expectedChange: 1000000 - 150000 - 368,
expectedInputs: []wire.OutPoint{utxo1},
}, {
name: "large output, no inputs",
packet: &psbt.Packet{
UnsignedTx: &wire.MsgTx{
TxOut: []*wire.TxOut{{
PkScript: testScriptP2WSH,
Value: 1500000,
}},
},
Outputs: []psbt.POutput{{}},
},
feeRateSatPerKB: 4000, // 4 sat/byte
expectedErr: "",
validatePackage: true,
expectedFee: 980,
expectedChange: 1900000 - 1500000 - 980,
expectedInputs: []wire.OutPoint{utxo1, utxo2},
}, { }, {
name: "two outputs, two inputs", name: "two outputs, two inputs",
packet: &psbt.Packet{ packet: &psbt.Packet{
UnsignedTx: &wire.MsgTx{ UnsignedTx: &wire.MsgTx{
TxIn: []*wire.TxIn{{ TxIn: []*wire.TxIn{{
PreviousOutPoint: wire.OutPoint{ PreviousOutPoint: utxo1,
Hash: incomingTx.TxHash(),
Index: 0,
},
}, { }, {
PreviousOutPoint: wire.OutPoint{ PreviousOutPoint: utxo2,
Hash: incomingTx.TxHash(),
Index: 1,
},
}}, }},
TxOut: []*wire.TxOut{{ TxOut: []*wire.TxOut{{
PkScript: testScriptP2WSH, PkScript: testScriptP2WSH,
@ -140,123 +163,11 @@ func TestFundPsbt(t *testing.T) {
feeRateSatPerKB: 2000, // 2 sat/byte feeRateSatPerKB: 2000, // 2 sat/byte
expectedErr: "", expectedErr: "",
validatePackage: true, validatePackage: true,
expectedFee: 506, expectedFee: 552,
expectedChange: 2000000 - 150000 - 506, expectedChange: 1900000 - 150000 - 552,
numExpectedInputs: 2, expectedInputs: []wire.OutPoint{utxo1, utxo2},
}} additionalChecks: func(t *testing.T, packet *psbt.Packet,
changeIndex int32) {
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
changeIndex, err := w.FundPsbt(
tc.packet, 0, tc.feeRateSatPerKB,
)
// Make sure the error is what we expected.
if err == nil && tc.expectedErr != "" {
t.Fatalf("expected error '%s' but got nil",
tc.expectedErr)
}
if err != nil && tc.expectedErr == "" {
t.Fatalf("expected nil error but got '%v'", err)
}
if err != nil &&
!strings.Contains(err.Error(), tc.expectedErr) {
t.Fatalf("expected error '%s' but got '%v'",
tc.expectedErr, err)
}
if !tc.validatePackage {
return
}
// Check wire inputs.
packet := tc.packet
if len(packet.UnsignedTx.TxIn) != tc.numExpectedInputs {
t.Fatalf("expected %d inputs to be added, got "+
"%d", tc.numExpectedInputs,
len(packet.UnsignedTx.TxIn))
}
txIn := packet.UnsignedTx.TxIn[0]
if txIn.PreviousOutPoint.Hash != incomingTx.TxHash() {
t.Fatalf("unexpected UTXO prev outpoint "+
"hash, got %v wanted %v",
txIn.PreviousOutPoint.Hash,
incomingTx.TxHash())
}
if tc.numExpectedInputs > 1 {
txIn2 := packet.UnsignedTx.TxIn[1]
if txIn2.PreviousOutPoint.Hash != incomingTx.TxHash() {
t.Fatalf("unexpected UTXO prev outpoint "+
"hash, got %v wanted %v",
txIn2.PreviousOutPoint.Hash,
incomingTx.TxHash())
}
}
// Check partial inputs.
if len(packet.Inputs) != tc.numExpectedInputs {
t.Fatalf("expected %d partial input to be "+
"added, got %d", tc.numExpectedInputs,
len(packet.Inputs))
}
in := packet.Inputs[0]
if in.WitnessUtxo == nil {
t.Fatalf("partial input witness UTXO not set")
}
if !bytes.Equal(in.WitnessUtxo.PkScript, p2wkhAddr) {
t.Fatalf("unexpected witness UTXO script, "+
"got %x wanted %x",
in.WitnessUtxo.PkScript, p2wkhAddr)
}
if in.NonWitnessUtxo == nil {
t.Fatalf("partial input non-witness UTXO not " +
"set")
}
prevIdx := txIn.PreviousOutPoint.Index
nonWitnessOut := in.NonWitnessUtxo.TxOut[prevIdx]
if !bytes.Equal(nonWitnessOut.PkScript, p2wkhAddr) {
t.Fatalf("unexpected witness UTXO script, "+
"got %x wanted %x",
nonWitnessOut.PkScript, p2wkhAddr)
}
if in.SighashType != txscript.SigHashAll {
t.Fatalf("unexpected sighash flag, got %d "+
"wanted %d", in.SighashType,
txscript.SigHashAll)
}
if tc.numExpectedInputs > 1 {
in2 := packet.Inputs[1]
if in2.WitnessUtxo == nil {
t.Fatalf("partial input witness UTXO " +
"not set")
}
if !bytes.Equal(in2.WitnessUtxo.PkScript, np2wkhAddr) {
t.Fatalf("unexpected witness UTXO "+
"script, got %x wanted %x",
in2.WitnessUtxo.PkScript,
np2wkhAddr)
}
if in2.NonWitnessUtxo == nil {
t.Fatalf("partial input non-witness " +
"UTXO not set")
}
txIn2 := packet.UnsignedTx.TxIn[1]
prevIdx2 := txIn2.PreviousOutPoint.Index
nonWitnessOut2 := in2.NonWitnessUtxo.TxOut[prevIdx2]
if !bytes.Equal(nonWitnessOut2.PkScript, p2wkhAddr) {
t.Fatalf("unexpected witness UTXO script, "+
"got %x wanted %x",
nonWitnessOut2.PkScript, p2wkhAddr)
}
if in2.SighashType != txscript.SigHashAll {
t.Fatalf("unexpected sighash flag, "+
"got %d wanted %d",
in2.SighashType,
txscript.SigHashAll)
}
}
// Check outputs, find index for each of the 3 expected. // Check outputs, find index for each of the 3 expected.
txOuts := packet.UnsignedTx.TxOut txOuts := packet.UnsignedTx.TxOut
@ -301,8 +212,60 @@ func TestFundPsbt(t *testing.T) {
txOuts[p2wkhIndex].PkScript, txOuts[p2wkhIndex].PkScript,
txOuts[p2wshIndex].PkScript) txOuts[p2wshIndex].PkScript)
} }
},
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
changeIndex, err := w.FundPsbt(
tc.packet, 0, tc.feeRateSatPerKB,
)
// In any case, unlock the UTXO before continuing, we
// don't want to pollute other test iterations.
for _, in := range tc.packet.UnsignedTx.TxIn {
w.UnlockOutpoint(in.PreviousOutPoint)
}
// Make sure the error is what we expected.
if err == nil && tc.expectedErr != "" {
t.Fatalf("expected error '%s' but got nil",
tc.expectedErr)
}
if err != nil && tc.expectedErr == "" {
t.Fatalf("expected nil error but got '%v'", err)
}
if err != nil &&
!strings.Contains(err.Error(), tc.expectedErr) {
t.Fatalf("expected error '%s' but got '%v'",
tc.expectedErr, err)
}
if !tc.validatePackage {
return
}
// Check wire inputs.
packet := tc.packet
assertTxInputs(t, packet, tc.expectedInputs)
// Run any additional tests if available.
if tc.additionalChecks != nil {
tc.additionalChecks(t, packet, changeIndex)
}
// Finally, check the change output size and fee. // Finally, check the change output size and fee.
txOuts := packet.UnsignedTx.TxOut
totalOut := int64(0)
for _, txOut := range txOuts {
totalOut += txOut.Value
}
totalIn := int64(0)
for _, txIn := range packet.Inputs {
totalIn += txIn.WitnessUtxo.Value
}
fee := totalIn - totalOut fee := totalIn - totalOut
if fee != tc.expectedFee { if fee != tc.expectedFee {
t.Fatalf("unexpected fee, got %d wanted %d", t.Fatalf("unexpected fee, got %d wanted %d",
@ -318,6 +281,46 @@ func TestFundPsbt(t *testing.T) {
} }
} }
func assertTxInputs(t *testing.T, packet *psbt.Packet,
expected []wire.OutPoint) {
if len(packet.UnsignedTx.TxIn) != len(expected) {
t.Fatalf("expected %d inputs to be added, got %d",
len(expected), len(packet.UnsignedTx.TxIn))
}
// The order of the UTXOs is random, we need to loop through each of
// them to make sure they're found. We also check that no signature data
// was added yet.
for _, txIn := range packet.UnsignedTx.TxIn {
if !containsUtxo(expected, txIn.PreviousOutPoint) {
t.Fatalf("outpoint %v not found in list of expected "+
"UTXOs", txIn.PreviousOutPoint)
}
if len(txIn.SignatureScript) > 0 {
t.Fatalf("expected scriptSig to be empty on "+
"txin, got %x instead",
txIn.SignatureScript)
}
if len(txIn.Witness) > 0 {
t.Fatalf("expected witness to be empty on "+
"txin, got %v instead",
txIn.Witness)
}
}
}
func containsUtxo(list []wire.OutPoint, candidate wire.OutPoint) bool {
for _, utxo := range list {
if utxo == candidate {
return true
}
}
return false
}
// TestFinalizePsbt tests that a given PSBT packet can be finalized. // TestFinalizePsbt tests that a given PSBT packet can be finalized.
func TestFinalizePsbt(t *testing.T) { func TestFinalizePsbt(t *testing.T) {
w, cleanup := testWallet(t) w, cleanup := testWallet(t)