From ea41fc5177afd691b092ce984979fe628fb85c5d Mon Sep 17 00:00:00 2001
From: Olaoluwa Osuntokun <laolu32@gmail.com>
Date: Wed, 15 Jan 2020 17:41:17 -0800
Subject: [PATCH] psbt: modify Extract method to return the transaction
 directly

In this commit, we modify the Extract method to return the transaction directly as in many cases a user will likely want to write the transaction to disk, or perform additional validation rather than obtain the raw bytes directly.
---
 psbt/extractor.go | 71 ++++++++++++++++++++++++++++-------------------
 1 file changed, 43 insertions(+), 28 deletions(-)

diff --git a/psbt/extractor.go b/psbt/extractor.go
index de80fec..dc7f10f 100644
--- a/psbt/extractor.go
+++ b/psbt/extractor.go
@@ -12,44 +12,63 @@ package psbt
 import (
 	"bytes"
 
+	"github.com/btcsuite/btcd/txscript"
 	"github.com/btcsuite/btcd/wire"
 )
 
-// Extract takes a finalized psbt and outputs a network serialization
-func Extract(p *Psbt) ([]byte, error) {
+// Extract takes a finalized psbt.Packet and outputs a finalized transaction
+// instance. Note that if the PSBT is in-complete, then an error
+// ErrIncompletePSBT will be returned. As the extracted transaction has been
+// fully finalized, it will be ready for network broadcast once returned.
+func Extract(p *Packet) (*wire.MsgTx, error) {
+	// If the packet isn't complete, then we'll return an error as it
+	// doesn't have all the required witness data.
 	if !p.IsComplete() {
 		return nil, ErrIncompletePSBT
 	}
-	var err error
-	// We take the existing UnsignedTx field and append SignatureScript
-	// and Witness as appropriate, then allow MsgTx to do the serialization
-	// for us.
-	newTx := p.UnsignedTx.Copy()
-	for i, tin := range newTx.TxIn {
+
+	// First, we'll make a copy of the underlying unsigned transaction (the
+	// initial template) so we don't mutate it during our activates below.
+	finalTx := p.UnsignedTx.Copy()
+
+	// For each input, we'll now populate any relevant witness and
+	// sigScript data.
+	for i, tin := range finalTx.TxIn {
+		// We'll grab the corresponding internal packet input which
+		// matches this materialized transaction input and emplace that
+		// final sigScript (if present).
 		pInput := p.Inputs[i]
 		if pInput.FinalScriptSig != nil {
 			tin.SignatureScript = pInput.FinalScriptSig
 		}
+
+		// Similarly, if there's a final witness, then we'll also need
+		// to extract that as well, parsing the lower-level transaction
+		// encoding.
 		if pInput.FinalScriptWitness != nil {
-			// to set the witness, need to re-deserialize the field
-			// For each input, the witness is encoded as a stack
-			// with one or more items. Therefore, we first read a
-			// varint which encodes the number of stack items.
-			r := bytes.NewReader(pInput.FinalScriptWitness)
-			witCount, err := wire.ReadVarInt(r, 0)
+			// In order to set the witness, need to re-deserialize
+			// the field as encoded within the PSBT packet.  For
+			// each input, the witness is encoded as a stack with
+			// one or more items.
+			witnessReader := bytes.NewReader(
+				pInput.FinalScriptWitness,
+			)
+
+			// First we extract the number of witness elements
+			// encoded in the above witnessReader.
+			witCount, err := wire.ReadVarInt(witnessReader, 0)
 			if err != nil {
 				return nil, err
 			}
 
-			// Then for witCount number of stack items, each item
-			// has a varint length prefix, followed by the witness
-			// item itself.
-			tin.Witness = make([][]byte, witCount)
+			// Now that we know how may inputs we'll need, we'll
+			// construct a packing slice, then read out each input
+			// (with a varint prefix) from the witnessReader.
+			tin.Witness = make(wire.TxWitness, witCount)
 			for j := uint64(0); j < witCount; j++ {
-				// the 10000 size limit is as per BIP141 for witness script;
-				// TODO this constant should be somewhere else in the lib,
-				// perhaps btcd/wire/common.go ?
-				wit, err := wire.ReadVarBytes(r, 0, 10000, "witness")
+				wit, err := wire.ReadVarBytes(
+					witnessReader, 0, txscript.MaxScriptSize, "witness",
+				)
 				if err != nil {
 					return nil, err
 				}
@@ -57,10 +76,6 @@ func Extract(p *Psbt) ([]byte, error) {
 			}
 		}
 	}
-	var networkSerializedTx bytes.Buffer
-	err = newTx.Serialize(&networkSerializedTx)
-	if err != nil {
-		return nil, err
-	}
-	return networkSerializedTx.Bytes(), nil
+
+	return finalTx, nil
 }