psbt: refactor updater.go for consistent code style

This commit is contained in:
Olaoluwa Osuntokun 2020-01-15 17:43:44 -08:00
parent 2a3238c694
commit 6bd3b8034f
No known key found for this signature in database
GPG key ID: BC13F65E2DC84465

View file

@ -4,12 +4,11 @@
package psbt package psbt
// The Updater requires provision of a single PSBT and is able // The Updater requires provision of a single PSBT and is able to add data to
// to add data to both input and output sections. // both input and output sections. It can be called repeatedly to add more
// It can be called repeatedly to add more data. // data. It also allows addition of signatures via the addPartialSignature
// It also allows addition of signatures via the addPartialSignature // function; this is called internally to the package in the Sign() function of
// function; this is called internally to the package in the Sign() // Updater, located in signer.go
// function of Updater, located in signer.go
import ( import (
"bytes" "bytes"
@ -20,217 +19,259 @@ import (
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
) )
// Updater encapsulates the role 'Updater' // Updater encapsulates the role 'Updater' as specified in BIP174; it accepts
// as specified in BIP174; it accepts Psbt structs // Psbt structs and has methods to add fields to the inputs and outputs.
// and has methods to add fields to the inputs and outputs.
type Updater struct { type Updater struct {
Upsbt *Psbt Upsbt *Packet
} }
// NewUpdater returns a new instance of Updater, // NewUpdater returns a new instance of Updater, if the passed Psbt struct is
// if the passed Psbt struct is in a valid form, // in a valid form, else an error.
// else an error. func NewUpdater(p *Packet) (*Updater, error) {
func NewUpdater(p *Psbt) (*Updater, error) {
if err := p.SanityCheck(); err != nil { if err := p.SanityCheck(); err != nil {
return nil, err return nil, err
} }
return &Updater{Upsbt: p}, nil return &Updater{Upsbt: p}, nil
} }
// AddInNonWitnessUtxo adds the utxo information for an input which // AddInNonWitnessUtxo adds the utxo information for an input which is
// is non-witness. This requires provision of a full transaction // non-witness. This requires provision of a full transaction (which is the
// (which is the source of the corresponding prevOut), and the input // source of the corresponding prevOut), and the input index. If addition of
// index. If addition of this key-value pair to the Psbt fails, an // this key-value pair to the Psbt fails, an error is returned.
// error is returned.
func (p *Updater) AddInNonWitnessUtxo(tx *wire.MsgTx, inIndex int) error { func (p *Updater) AddInNonWitnessUtxo(tx *wire.MsgTx, inIndex int) error {
if inIndex > len(p.Upsbt.Inputs)-1 { if inIndex > len(p.Upsbt.Inputs)-1 {
return ErrInvalidPrevOutNonWitnessTransaction return ErrInvalidPrevOutNonWitnessTransaction
} }
p.Upsbt.Inputs[inIndex].NonWitnessUtxo = tx p.Upsbt.Inputs[inIndex].NonWitnessUtxo = tx
if err := p.Upsbt.SanityCheck(); err != nil { if err := p.Upsbt.SanityCheck(); err != nil {
return ErrInvalidPsbtFormat return ErrInvalidPsbtFormat
} }
return nil return nil
} }
// AddInWitnessUtxo adds the utxo information for an input which // AddInWitnessUtxo adds the utxo information for an input which is witness.
// is witness. This requires provision of a full transaction *output* // This requires provision of a full transaction *output* (which is the source
// (which is the source of the corresponding prevOut); not the full // of the corresponding prevOut); not the full transaction because BIP143 means
// transaction because BIP143 means the output information is sufficient, // the output information is sufficient, and the input index. If addition of
// and the input index. If addition of this key-value pair to the Psbt fails, // this key-value pair to the Psbt fails, an error is returned.
// an error is returned.
func (p *Updater) AddInWitnessUtxo(txout *wire.TxOut, inIndex int) error { func (p *Updater) AddInWitnessUtxo(txout *wire.TxOut, inIndex int) error {
if inIndex > len(p.Upsbt.Inputs)-1 { if inIndex > len(p.Upsbt.Inputs)-1 {
return ErrInvalidPsbtFormat return ErrInvalidPsbtFormat
} }
p.Upsbt.Inputs[inIndex].WitnessUtxo = txout p.Upsbt.Inputs[inIndex].WitnessUtxo = txout
if err := p.Upsbt.SanityCheck(); err != nil { if err := p.Upsbt.SanityCheck(); err != nil {
return ErrInvalidPsbtFormat return ErrInvalidPsbtFormat
} }
return nil return nil
} }
// addPartialSignature allows the Updater role to insert fields // addPartialSignature allows the Updater role to insert fields of type partial
// of type partial signature into a Psbt, consisting of both // signature into a Psbt, consisting of both the pubkey (as keydata) and the
// the pubkey (as keydata) and the ECDSA signature (as value). // ECDSA signature (as value). Note that the Signer role is encapsulated in
// Note that the Signer role is encapsulated in this function; // this function; signatures are only allowed to be added that follow the
// signatures are only allowed to be added that follow the sanity-check // sanity-check on signing rules explained in the BIP under `Signer`; if the
// on signing rules explained in the BIP under `Signer`; if the rules are not // rules are not satisfied, an ErrInvalidSignatureForInput is returned.
// satisfied, an ErrInvalidSignatureForInput is returned. //
// NOTE this function does *not* validate the ECDSA signature itself. // NOTE: This function does *not* validate the ECDSA signature itself.
func (p *Updater) addPartialSignature(inIndex int, sig []byte, func (p *Updater) addPartialSignature(inIndex int, sig []byte,
pubkey []byte) error { pubkey []byte) error {
partialSig := PartialSig{PubKey: pubkey, Signature: sig} partialSig := PartialSig{
//First validate the passed (sig, pub): PubKey: pubkey, Signature: sig,
}
// First validate the passed (sig, pub).
if !partialSig.checkValid() { if !partialSig.checkValid() {
return ErrInvalidPsbtFormat return ErrInvalidPsbtFormat
} }
pInput := p.Upsbt.Inputs[inIndex] pInput := p.Upsbt.Inputs[inIndex]
// First check; don't add duplicates // First check; don't add duplicates.
for _, x := range pInput.PartialSigs { for _, x := range pInput.PartialSigs {
if bytes.Equal(x.PubKey, partialSig.PubKey) { if bytes.Equal(x.PubKey, partialSig.PubKey) {
return ErrDuplicateKey return ErrDuplicateKey
} }
} }
// Sanity checks // Next, we perform a series of additional sanity checks.
if pInput.NonWitnessUtxo != nil { if pInput.NonWitnessUtxo != nil {
if len(p.Upsbt.UnsignedTx.TxIn) < inIndex+1 { if len(p.Upsbt.UnsignedTx.TxIn) < inIndex+1 {
return ErrInvalidPrevOutNonWitnessTransaction return ErrInvalidPrevOutNonWitnessTransaction
} }
if pInput.NonWitnessUtxo.TxHash() != if pInput.NonWitnessUtxo.TxHash() !=
p.Upsbt.UnsignedTx.TxIn[inIndex].PreviousOutPoint.Hash { p.Upsbt.UnsignedTx.TxIn[inIndex].PreviousOutPoint.Hash {
return ErrInvalidSignatureForInput return ErrInvalidSignatureForInput
} }
// To validate that the redeem script matches, we must pull out
// the scriptPubKey of the corresponding output and compare
// that with the P2SH scriptPubKey that is generated by
// redeemScript.
if pInput.RedeemScript != nil { if pInput.RedeemScript != nil {
// To validate that the redeem script matches, we must pull out the
// scriptPubKey of the corresponding output and compare that with
// the P2SH scriptPubKey that is generated by redeemScript:
outIndex := p.Upsbt.UnsignedTx.TxIn[inIndex].PreviousOutPoint.Index outIndex := p.Upsbt.UnsignedTx.TxIn[inIndex].PreviousOutPoint.Index
scriptPubKey := pInput.NonWitnessUtxo.TxOut[outIndex].PkScript scriptPubKey := pInput.NonWitnessUtxo.TxOut[outIndex].PkScript
scriptHash := btcutil.Hash160(pInput.RedeemScript) scriptHash := btcutil.Hash160(pInput.RedeemScript)
scriptHashScript, err := txscript.NewScriptBuilder().AddOp(
txscript.OP_HASH160).AddData(scriptHash).AddOp( scriptHashScript, err := txscript.NewScriptBuilder().
txscript.OP_EQUAL).Script() AddOp(txscript.OP_HASH160).
AddData(scriptHash).
AddOp(txscript.OP_EQUAL).
Script()
if err != nil { if err != nil {
return err return err
} }
if !bytes.Equal(scriptHashScript, scriptPubKey) { if !bytes.Equal(scriptHashScript, scriptPubKey) {
return ErrInvalidSignatureForInput return ErrInvalidSignatureForInput
} }
} }
} else if pInput.WitnessUtxo != nil { } else if pInput.WitnessUtxo != nil {
scriptPubKey := pInput.WitnessUtxo.PkScript scriptPubKey := pInput.WitnessUtxo.PkScript
var script []byte var script []byte
if pInput.RedeemScript != nil { if pInput.RedeemScript != nil {
scriptHash := btcutil.Hash160(pInput.RedeemScript) scriptHash := btcutil.Hash160(pInput.RedeemScript)
scriptHashScript, err := txscript.NewScriptBuilder().AddOp( scriptHashScript, err := txscript.NewScriptBuilder().
txscript.OP_HASH160).AddData(scriptHash).AddOp( AddOp(txscript.OP_HASH160).
txscript.OP_EQUAL).Script() AddData(scriptHash).
AddOp(txscript.OP_EQUAL).
Script()
if err != nil { if err != nil {
return err return err
} }
if !bytes.Equal(scriptHashScript, scriptPubKey) { if !bytes.Equal(scriptHashScript, scriptPubKey) {
return ErrInvalidSignatureForInput return ErrInvalidSignatureForInput
} }
script = pInput.RedeemScript script = pInput.RedeemScript
} else { } else {
script = scriptPubKey script = scriptPubKey
} }
// If a witnessScript field is present, this is a P2WSH, // If a witnessScript field is present, this is a P2WSH,
// whether nested or not (that is handled by the assignment to // whether nested or not (that is handled by the assignment to
// `script` above); in that case, sanity check that `script` // `script` above); in that case, sanity check that `script` is
// is the p2wsh of witnessScript. Contrariwise, if no witnessScript // the p2wsh of witnessScript. Contrariwise, if no
// field is present, this will be signed as p2wkh. // witnessScript field is present, this will be signed as
// p2wkh.
if pInput.WitnessScript != nil { if pInput.WitnessScript != nil {
witnessScriptHash := sha256.Sum256(pInput.WitnessScript) witnessScriptHash := sha256.Sum256(pInput.WitnessScript)
witnessScriptHashScript, err := txscript.NewScriptBuilder().AddOp( witnessScriptHashScript, err := txscript.NewScriptBuilder().
txscript.OP_0).AddData(witnessScriptHash[:]).Script() AddOp(txscript.OP_0).
AddData(witnessScriptHash[:]).
Script()
if err != nil { if err != nil {
return err return err
} }
if !bytes.Equal(script, witnessScriptHashScript[:]) { if !bytes.Equal(script, witnessScriptHashScript[:]) {
return ErrInvalidSignatureForInput return ErrInvalidSignatureForInput
} }
} else { // p2wkh } else {
// Otherwise, this is a p2wkh input.
pubkeyHash := btcutil.Hash160(pubkey) pubkeyHash := btcutil.Hash160(pubkey)
pubkeyHashScript, err := txscript.NewScriptBuilder().AddOp( pubkeyHashScript, err := txscript.NewScriptBuilder().
txscript.OP_0).AddData(pubkeyHash).Script() AddOp(txscript.OP_0).
AddData(pubkeyHash).
Script()
if err != nil { if err != nil {
return err return err
} }
// Validate that we're able to properly reconstruct the
// witness program.
if !bytes.Equal(pubkeyHashScript, script) { if !bytes.Equal(pubkeyHashScript, script) {
return ErrInvalidSignatureForInput return ErrInvalidSignatureForInput
} }
} }
} else { } else {
// attaching signature without utxo field is not allowed
// Attaching signature without utxo field is not allowed.
return ErrInvalidPsbtFormat return ErrInvalidPsbtFormat
} }
p.Upsbt.Inputs[inIndex].PartialSigs = append( p.Upsbt.Inputs[inIndex].PartialSigs = append(
p.Upsbt.Inputs[inIndex].PartialSigs, &partialSig) p.Upsbt.Inputs[inIndex].PartialSigs, &partialSig,
)
if err := p.Upsbt.SanityCheck(); err != nil { if err := p.Upsbt.SanityCheck(); err != nil {
return err return err
} }
// Addition of a non-duplicate-key partial signature
// cannot violate sanity-check rules. // Addition of a non-duplicate-key partial signature cannot violate
// sanity-check rules.
return nil return nil
} }
// AddInSighashType adds the sighash type information for an input. // AddInSighashType adds the sighash type information for an input. The
// The sighash type is passed as a 32 bit unsigned integer, along with the // sighash type is passed as a 32 bit unsigned integer, along with the index
// index for the input. An error is returned if addition of this key-value pair // for the input. An error is returned if addition of this key-value pair to
// to the Psbt fails. // the Psbt fails.
func (p *Updater) AddInSighashType(sighashType txscript.SigHashType, func (p *Updater) AddInSighashType(sighashType txscript.SigHashType,
inIndex int) error { inIndex int) error {
p.Upsbt.Inputs[inIndex].SighashType = sighashType p.Upsbt.Inputs[inIndex].SighashType = sighashType
if err := p.Upsbt.SanityCheck(); err != nil { if err := p.Upsbt.SanityCheck(); err != nil {
return err return err
} }
return nil return nil
} }
// AddInRedeemScript adds the redeem script information for an input. // AddInRedeemScript adds the redeem script information for an input. The
// The redeem script is passed serialized, as a byte slice, along with the // redeem script is passed serialized, as a byte slice, along with the index of
// index of the input. An error is returned if addition of this key-value pair // the input. An error is returned if addition of this key-value pair to the
// to the Psbt fails. // Psbt fails.
func (p *Updater) AddInRedeemScript(redeemScript []byte, func (p *Updater) AddInRedeemScript(redeemScript []byte,
inIndex int) error { inIndex int) error {
p.Upsbt.Inputs[inIndex].RedeemScript = redeemScript p.Upsbt.Inputs[inIndex].RedeemScript = redeemScript
if err := p.Upsbt.SanityCheck(); err != nil { if err := p.Upsbt.SanityCheck(); err != nil {
return ErrInvalidPsbtFormat return ErrInvalidPsbtFormat
} }
return nil return nil
} }
// AddInWitnessScript adds the witness script information for an input. // AddInWitnessScript adds the witness script information for an input. The
// The witness script is passed serialized, as a byte slice, along with the // witness script is passed serialized, as a byte slice, along with the index
// index of the input. An error is returned if addition of this key-value pair // of the input. An error is returned if addition of this key-value pair to the
// to the Psbt fails. // Psbt fails.
func (p *Updater) AddInWitnessScript(witnessScript []byte, func (p *Updater) AddInWitnessScript(witnessScript []byte,
inIndex int) error { inIndex int) error {
p.Upsbt.Inputs[inIndex].WitnessScript = witnessScript p.Upsbt.Inputs[inIndex].WitnessScript = witnessScript
if err := p.Upsbt.SanityCheck(); err != nil { if err := p.Upsbt.SanityCheck(); err != nil {
return err return err
} }
return nil return nil
} }
// AddInBip32Derivation takes a master key fingerprint // AddInBip32Derivation takes a master key fingerprint as defined in BIP32, a
// as defined in BIP32, a BIP32 path as a slice of uint32 values, // BIP32 path as a slice of uint32 values, and a serialized pubkey as a byte
// and a serialized pubkey as a byte slice, along with the // slice, along with the integer index of the input, and inserts this data into
// integer index of the input, and inserts this data into that input. // that input.
// NOTE that this can be called multiple times for the same input. //
// An error is returned if addition of this key-value pair // NOTE: This can be called multiple times for the same input. An error is
// to the Psbt fails. // returned if addition of this key-value pair to the Psbt fails.
func (p *Updater) AddInBip32Derivation(masterKeyFingerprint uint32, func (p *Updater) AddInBip32Derivation(masterKeyFingerprint uint32,
bip32Path []uint32, pubKeyData []byte, inIndex int) error { bip32Path []uint32, pubKeyData []byte, inIndex int) error {
bip32Derivation := Bip32Derivation{ bip32Derivation := Bip32Derivation{
PubKey: pubKeyData, PubKey: pubKeyData,
MasterKeyFingerprint: masterKeyFingerprint, MasterKeyFingerprint: masterKeyFingerprint,
@ -249,27 +290,32 @@ func (p *Updater) AddInBip32Derivation(masterKeyFingerprint uint32,
} }
p.Upsbt.Inputs[inIndex].Bip32Derivation = append( p.Upsbt.Inputs[inIndex].Bip32Derivation = append(
p.Upsbt.Inputs[inIndex].Bip32Derivation, &bip32Derivation) p.Upsbt.Inputs[inIndex].Bip32Derivation, &bip32Derivation,
)
if err := p.Upsbt.SanityCheck(); err != nil { if err := p.Upsbt.SanityCheck(); err != nil {
return err return err
} }
return nil return nil
} }
// AddOutBip32Derivation takes a master key fingerprint // AddOutBip32Derivation takes a master key fingerprint as defined in BIP32, a
// as defined in BIP32, a BIP32 path as a slice of uint32 values, // BIP32 path as a slice of uint32 values, and a serialized pubkey as a byte
// and a serialized pubkey as a byte slice, along with the // slice, along with the integer index of the output, and inserts this data
// integer index of the output, and inserts this data into that output. // into that output.
// NOTE that this can be called multiple times for the same output. //
// An error is returned if addition of this key-value pair // NOTE: That this can be called multiple times for the same output. An error
// to the Psbt fails. // is returned if addition of this key-value pair to the Psbt fails.
func (p *Updater) AddOutBip32Derivation(masterKeyFingerprint uint32, func (p *Updater) AddOutBip32Derivation(masterKeyFingerprint uint32,
bip32Path []uint32, pubKeyData []byte, outIndex int) error { bip32Path []uint32, pubKeyData []byte, outIndex int) error {
bip32Derivation := Bip32Derivation{ bip32Derivation := Bip32Derivation{
PubKey: pubKeyData, PubKey: pubKeyData,
MasterKeyFingerprint: masterKeyFingerprint, MasterKeyFingerprint: masterKeyFingerprint,
Bip32Path: bip32Path, Bip32Path: bip32Path,
} }
if !bip32Derivation.checkValid() { if !bip32Derivation.checkValid() {
return ErrInvalidPsbtFormat return ErrInvalidPsbtFormat
} }
@ -282,31 +328,40 @@ func (p *Updater) AddOutBip32Derivation(masterKeyFingerprint uint32,
} }
p.Upsbt.Outputs[outIndex].Bip32Derivation = append( p.Upsbt.Outputs[outIndex].Bip32Derivation = append(
p.Upsbt.Outputs[outIndex].Bip32Derivation, &bip32Derivation) p.Upsbt.Outputs[outIndex].Bip32Derivation, &bip32Derivation,
)
if err := p.Upsbt.SanityCheck(); err != nil { if err := p.Upsbt.SanityCheck(); err != nil {
return err return err
} }
return nil return nil
} }
// AddOutRedeemScript takes a redeem script as a byte slice // AddOutRedeemScript takes a redeem script as a byte slice and appends it to
// and appends it to the output at index outIndex. // the output at index outIndex.
func (p *Updater) AddOutRedeemScript(redeemScript []byte, func (p *Updater) AddOutRedeemScript(redeemScript []byte,
outIndex int) error { outIndex int) error {
p.Upsbt.Outputs[outIndex].RedeemScript = redeemScript p.Upsbt.Outputs[outIndex].RedeemScript = redeemScript
if err := p.Upsbt.SanityCheck(); err != nil { if err := p.Upsbt.SanityCheck(); err != nil {
return ErrInvalidPsbtFormat return ErrInvalidPsbtFormat
} }
return nil return nil
} }
// AddOutWitnessScript takes a witness script as a byte slice // AddOutWitnessScript takes a witness script as a byte slice and appends it to
// and appends it to the output at index outIndex. // the output at index outIndex.
func (p *Updater) AddOutWitnessScript(witnessScript []byte, func (p *Updater) AddOutWitnessScript(witnessScript []byte,
outIndex int) error { outIndex int) error {
p.Upsbt.Outputs[outIndex].WitnessScript = witnessScript p.Upsbt.Outputs[outIndex].WitnessScript = witnessScript
if err := p.Upsbt.SanityCheck(); err != nil { if err := p.Upsbt.SanityCheck(); err != nil {
return err return err
} }
return nil return nil
} }