From 139197e545b9622e35e72b5b75a0a3b0a632a28b Mon Sep 17 00:00:00 2001
From: junderw <junderwood@bitcoinbank.co.jp>
Date: Mon, 26 Aug 2019 19:15:05 +0900
Subject: [PATCH] Add getFee and getVSize

---
 src/psbt.js     | 56 ++++++++++++++++++++++++++++++-----------
 test/psbt.js    | 10 ++++++++
 ts_src/psbt.ts  | 67 ++++++++++++++++++++++++++++++++++++++-----------
 types/psbt.d.ts |  2 ++
 4 files changed, 107 insertions(+), 28 deletions(-)

diff --git a/src/psbt.js b/src/psbt.js
index e134774..d9ed3b1 100644
--- a/src/psbt.js
+++ b/src/psbt.js
@@ -153,6 +153,8 @@ class Psbt {
     if (input.nonWitnessUtxo) {
       addNonWitnessTxCache(this.__CACHE, input, inputIndex);
     }
+    c.__FEE = undefined;
+    c.__VSIZE = undefined;
     c.__FEE_RATE = undefined;
     c.__EXTRACTED_TX = undefined;
     return this;
@@ -171,6 +173,8 @@ class Psbt {
     }
     const c = this.__CACHE;
     this.data.addOutput(outputData);
+    c.__FEE = undefined;
+    c.__VSIZE = undefined;
     c.__FEE_RATE = undefined;
     c.__EXTRACTED_TX = undefined;
     return this;
@@ -187,20 +191,23 @@ class Psbt {
     return tx;
   }
   getFeeRate() {
-    if (!this.data.inputs.every(isFinalized))
-      throw new Error('PSBT must be finalized to calculate fee rate');
-    const c = this.__CACHE;
-    if (c.__FEE_RATE) return c.__FEE_RATE;
-    let tx;
-    let mustFinalize = true;
-    if (c.__EXTRACTED_TX) {
-      tx = c.__EXTRACTED_TX;
-      mustFinalize = false;
-    } else {
-      tx = c.__TX.clone();
-    }
-    inputFinalizeGetAmts(this.data.inputs, tx, c, mustFinalize);
-    return c.__FEE_RATE;
+    return getTxCacheValue(
+      '__FEE_RATE',
+      'fee rate',
+      this.data.inputs,
+      this.__CACHE,
+    );
+  }
+  getFee() {
+    return getTxCacheValue('__FEE', 'fee', this.data.inputs, this.__CACHE);
+  }
+  getVSize() {
+    return getTxCacheValue(
+      '__VSIZE',
+      'virtual size',
+      this.data.inputs,
+      this.__CACHE,
+    );
   }
   finalizeAllInputs() {
     utils_1.checkForInput(this.data.inputs, 0); // making sure we have at least one
@@ -724,6 +731,25 @@ const checkWitnessScript = scriptCheckerFactory(
   payments.p2wsh,
   'Witness script',
 );
+function getTxCacheValue(key, name, inputs, c) {
+  if (!inputs.every(isFinalized))
+    throw new Error(`PSBT must be finalized to calculate ${name}`);
+  if (key === '__FEE_RATE' && c.__FEE_RATE) return c.__FEE_RATE;
+  if (key === '__FEE' && c.__FEE) return c.__FEE;
+  if (key === '__VSIZE' && c.__VSIZE) return c.__VSIZE;
+  let tx;
+  let mustFinalize = true;
+  if (c.__EXTRACTED_TX) {
+    tx = c.__EXTRACTED_TX;
+    mustFinalize = false;
+  } else {
+    tx = c.__TX.clone();
+  }
+  inputFinalizeGetAmts(inputs, tx, c, mustFinalize);
+  if (key === '__FEE_RATE') return c.__FEE_RATE;
+  else if (key === '__FEE') return c.__FEE;
+  else if (key === '__VSIZE') return c.__VSIZE;
+}
 function getFinalScripts(
   script,
   scriptType,
@@ -1124,6 +1150,8 @@ function inputFinalizeGetAmts(inputs, tx, cache, mustFinalize) {
     throw new Error('Outputs are spending more than Inputs');
   }
   const bytes = tx.virtualSize();
+  cache.__VSIZE = bytes;
+  cache.__FEE = fee;
   cache.__EXTRACTED_TX = tx;
   cache.__FEE_RATE = Math.floor(fee / bytes);
 }
diff --git a/test/psbt.js b/test/psbt.js
index 467e426..4d4b3b5 100644
--- a/test/psbt.js
+++ b/test/psbt.js
@@ -149,6 +149,16 @@ describe(`Psbt`, () => {
         const fr1 = psbt5.getFeeRate()
         const fr2 = psbt5.getFeeRate()
         assert.strictEqual(fr1, fr2)
+
+        const psbt6 =  Psbt.fromBase64(f.psbt)
+        const f1 = psbt6.getFee()
+        const f2 = psbt6.getFee()
+        assert.strictEqual(f1, f2)
+
+        const psbt7 =  Psbt.fromBase64(f.psbt)
+        const vs1 = psbt7.getVSize()
+        const vs2 = psbt7.getVSize()
+        assert.strictEqual(vs1, vs2)
       })
     })
   })
diff --git a/ts_src/psbt.ts b/ts_src/psbt.ts
index 0431056..333fa20 100644
--- a/ts_src/psbt.ts
+++ b/ts_src/psbt.ts
@@ -194,6 +194,8 @@ export class Psbt {
     if (input.nonWitnessUtxo) {
       addNonWitnessTxCache(this.__CACHE, input, inputIndex);
     }
+    c.__FEE = undefined;
+    c.__VSIZE = undefined;
     c.__FEE_RATE = undefined;
     c.__EXTRACTED_TX = undefined;
     return this;
@@ -214,6 +216,8 @@ export class Psbt {
     }
     const c = this.__CACHE;
     this.data.addOutput(outputData);
+    c.__FEE = undefined;
+    c.__VSIZE = undefined;
     c.__FEE_RATE = undefined;
     c.__EXTRACTED_TX = undefined;
     return this;
@@ -232,20 +236,25 @@ export class Psbt {
   }
 
   getFeeRate(): number {
-    if (!this.data.inputs.every(isFinalized))
-      throw new Error('PSBT must be finalized to calculate fee rate');
-    const c = this.__CACHE;
-    if (c.__FEE_RATE) return c.__FEE_RATE;
-    let tx: Transaction;
-    let mustFinalize = true;
-    if (c.__EXTRACTED_TX) {
-      tx = c.__EXTRACTED_TX;
-      mustFinalize = false;
-    } else {
-      tx = c.__TX.clone();
-    }
-    inputFinalizeGetAmts(this.data.inputs, tx, c, mustFinalize);
-    return c.__FEE_RATE!;
+    return getTxCacheValue(
+      '__FEE_RATE',
+      'fee rate',
+      this.data.inputs,
+      this.__CACHE,
+    )!;
+  }
+
+  getFee(): number {
+    return getTxCacheValue('__FEE', 'fee', this.data.inputs, this.__CACHE)!;
+  }
+
+  getVSize(): number {
+    return getTxCacheValue(
+      '__VSIZE',
+      'virtual size',
+      this.data.inputs,
+      this.__CACHE,
+    )!;
   }
 
   finalizeAllInputs(): this {
@@ -610,6 +619,8 @@ interface PsbtCache {
   __TX_IN_CACHE: { [index: string]: number };
   __TX: Transaction;
   __FEE_RATE?: number;
+  __FEE?: number;
+  __VSIZE?: number;
   __EXTRACTED_TX?: Transaction;
 }
 
@@ -920,6 +931,32 @@ const checkWitnessScript = scriptCheckerFactory(
   'Witness script',
 );
 
+type TxCacheNumberKey = '__FEE_RATE' | '__FEE' | '__VSIZE';
+function getTxCacheValue(
+  key: TxCacheNumberKey,
+  name: string,
+  inputs: PsbtInput[],
+  c: PsbtCache,
+): number | undefined {
+  if (!inputs.every(isFinalized))
+    throw new Error(`PSBT must be finalized to calculate ${name}`);
+  if (key === '__FEE_RATE' && c.__FEE_RATE) return c.__FEE_RATE;
+  if (key === '__FEE' && c.__FEE) return c.__FEE;
+  if (key === '__VSIZE' && c.__VSIZE) return c.__VSIZE;
+  let tx: Transaction;
+  let mustFinalize = true;
+  if (c.__EXTRACTED_TX) {
+    tx = c.__EXTRACTED_TX;
+    mustFinalize = false;
+  } else {
+    tx = c.__TX.clone();
+  }
+  inputFinalizeGetAmts(inputs, tx, c, mustFinalize);
+  if (key === '__FEE_RATE') return c.__FEE_RATE!;
+  else if (key === '__FEE') return c.__FEE!;
+  else if (key === '__VSIZE') return c.__VSIZE!;
+}
+
 function getFinalScripts(
   script: Buffer,
   scriptType: string,
@@ -1398,6 +1435,8 @@ function inputFinalizeGetAmts(
     throw new Error('Outputs are spending more than Inputs');
   }
   const bytes = tx.virtualSize();
+  cache.__VSIZE = bytes;
+  cache.__FEE = fee;
   cache.__EXTRACTED_TX = tx;
   cache.__FEE_RATE = Math.floor(fee / bytes);
 }
diff --git a/types/psbt.d.ts b/types/psbt.d.ts
index 6a56636..dddedd7 100644
--- a/types/psbt.d.ts
+++ b/types/psbt.d.ts
@@ -57,6 +57,8 @@ export declare class Psbt {
     addOutput(outputData: PsbtOutputExtended): this;
     extractTransaction(disableFeeCheck?: boolean): Transaction;
     getFeeRate(): number;
+    getFee(): number;
+    getVSize(): number;
     finalizeAllInputs(): this;
     finalizeInput(inputIndex: number): this;
     validateSignaturesOfAllInputs(): boolean;