psbt: modify Extract method to return the transaction directly

In this commit, we modify the Extract method to return the transaction
directly as in many cases a user will likely want to write the
transaction to disk, or perform additional validation rather than obtain
the raw bytes directly.
This commit is contained in:
Olaoluwa Osuntokun 2020-01-15 17:41:17 -08:00
parent 57a6543394
commit 41cb8d70da
No known key found for this signature in database
GPG key ID: BC13F65E2DC84465

View file

@ -12,44 +12,63 @@ package psbt
import ( import (
"bytes" "bytes"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
) )
// Extract takes a finalized psbt and outputs a network serialization // Extract takes a finalized psbt.Packet and outputs a finalized transaction
func Extract(p *Psbt) ([]byte, error) { // instance. Note that if the PSBT is in-complete, then an error
// ErrIncompletePSBT will be returned. As the extracted transaction has been
// fully finalized, it will be ready for network broadcast once returned.
func Extract(p *Packet) (*wire.MsgTx, error) {
// If the packet isn't complete, then we'll return an error as it
// doesn't have all the required witness data.
if !p.IsComplete() { if !p.IsComplete() {
return nil, ErrIncompletePSBT return nil, ErrIncompletePSBT
} }
var err error
// We take the existing UnsignedTx field and append SignatureScript // First, we'll make a copy of the underlying unsigned transaction (the
// and Witness as appropriate, then allow MsgTx to do the serialization // initial template) so we don't mutate it during our activates below.
// for us. finalTx := p.UnsignedTx.Copy()
newTx := p.UnsignedTx.Copy()
for i, tin := range newTx.TxIn { // For each input, we'll now populate any relevant witness and
// sigScript data.
for i, tin := range finalTx.TxIn {
// We'll grab the corresponding internal packet input which
// matches this materialized transaction input and emplace that
// final sigScript (if present).
pInput := p.Inputs[i] pInput := p.Inputs[i]
if pInput.FinalScriptSig != nil { if pInput.FinalScriptSig != nil {
tin.SignatureScript = pInput.FinalScriptSig tin.SignatureScript = pInput.FinalScriptSig
} }
// Similarly, if there's a final witness, then we'll also need
// to extract that as well, parsing the lower-level transaction
// encoding.
if pInput.FinalScriptWitness != nil { if pInput.FinalScriptWitness != nil {
// to set the witness, need to re-deserialize the field // In order to set the witness, need to re-deserialize
// For each input, the witness is encoded as a stack // the field as encoded within the PSBT packet. For
// with one or more items. Therefore, we first read a // each input, the witness is encoded as a stack with
// varint which encodes the number of stack items. // one or more items.
r := bytes.NewReader(pInput.FinalScriptWitness) witnessReader := bytes.NewReader(
witCount, err := wire.ReadVarInt(r, 0) pInput.FinalScriptWitness,
)
// First we extract the number of witness elements
// encoded in the above witnessReader.
witCount, err := wire.ReadVarInt(witnessReader, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Then for witCount number of stack items, each item // Now that we know how may inputs we'll need, we'll
// has a varint length prefix, followed by the witness // construct a packing slice, then read out each input
// item itself. // (with a varint prefix) from the witnessReader.
tin.Witness = make([][]byte, witCount) tin.Witness = make(wire.TxWitness, witCount)
for j := uint64(0); j < witCount; j++ { for j := uint64(0); j < witCount; j++ {
// the 10000 size limit is as per BIP141 for witness script; wit, err := wire.ReadVarBytes(
// TODO this constant should be somewhere else in the lib, witnessReader, 0, txscript.MaxScriptSize, "witness",
// perhaps btcd/wire/common.go ? )
wit, err := wire.ReadVarBytes(r, 0, 10000, "witness")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -57,10 +76,6 @@ func Extract(p *Psbt) ([]byte, error) {
} }
} }
} }
var networkSerializedTx bytes.Buffer
err = newTx.Serialize(&networkSerializedTx) return finalTx, nil
if err != nil {
return nil, err
}
return networkSerializedTx.Bytes(), nil
} }