Stop mocking global func() variables

Doing that may cause erratic test failures when we run them in parallel, so
move the functions the tests need to mock as struct fields that are not
shared across tests.
This commit is contained in:
Guilherme Salgado 2015-04-03 13:14:52 +01:00
parent 97e84fe212
commit fe0f60991a
4 changed files with 99 additions and 110 deletions

View file

@ -24,7 +24,6 @@ import (
"testing" "testing"
"github.com/btcsuite/btclog" "github.com/btcsuite/btclog"
"github.com/btcsuite/btcutil"
"github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/waddrmgr"
) )
@ -99,19 +98,3 @@ func TstCheckWithdrawalStatusMatches(t *testing.T, s1, s2 WithdrawalStatus) {
t.Fatalf("Wrong WithdrawalStatus; got %v, want %v", s1, s2) t.Fatalf("Wrong WithdrawalStatus; got %v, want %v", s1, s2)
} }
} }
// replaceCalculateTxFee replaces the calculateTxFee func with the given one
// and returns a function that restores it to the original one.
func replaceCalculateTxFee(f func(*withdrawalTx) btcutil.Amount) func() {
orig := calculateTxFee
calculateTxFee = f
return func() { calculateTxFee = orig }
}
// replaceCalculateTxSize replaces the calculateTxSize func with the given one
// and returns a function that restores it to the original one.
func replaceCalculateTxSize(f func(*withdrawalTx) int) func() {
orig := calculateTxSize
calculateTxSize = f
return func() { calculateTxSize = orig }
}

View file

@ -54,7 +54,7 @@ func getUniqueID() uint32 {
// createWithdrawalTx creates a withdrawalTx with the given input and output amounts. // createWithdrawalTx creates a withdrawalTx with the given input and output amounts.
func createWithdrawalTx(t *testing.T, pool *Pool, inputAmounts []int64, outputAmounts []int64) *withdrawalTx { func createWithdrawalTx(t *testing.T, pool *Pool, inputAmounts []int64, outputAmounts []int64) *withdrawalTx {
net := pool.Manager().ChainParams() net := pool.Manager().ChainParams()
tx := newWithdrawalTx() tx := newWithdrawalTx(defaultTxOptions)
_, credits := TstCreateCreditsOnNewSeries(t, pool, inputAmounts) _, credits := TstCreateCreditsOnNewSeries(t, pool, inputAmounts)
for _, c := range credits { for _, c := range credits {
tx.addInput(c) tx.addInput(c)
@ -418,8 +418,8 @@ func TstNewChangeAddress(t *testing.T, p *Pool, seriesID uint32, idx Index) (add
return addr return addr
} }
func TstConstantFee(fee btcutil.Amount) func(tx *withdrawalTx) btcutil.Amount { func TstConstantFee(fee btcutil.Amount) func() btcutil.Amount {
return func(tx *withdrawalTx) btcutil.Amount { return fee } return func() btcutil.Amount { return fee }
} }
func createAndFulfillWithdrawalRequests(t *testing.T, pool *Pool, roundID uint32) withdrawalInfo { func createAndFulfillWithdrawalRequests(t *testing.T, pool *Pool, roundID uint32) withdrawalInfo {

View file

@ -269,6 +269,10 @@ type withdrawal struct {
pendingRequests []OutputRequest pendingRequests []OutputRequest
eligibleInputs []credit eligibleInputs []credit
current *withdrawalTx current *withdrawalTx
// txOptions is a function called for every new withdrawalTx created as
// part of this withdrawal. It is defined as a function field because it
// exists mainly so that tests can mock withdrawalTx fields.
txOptions func(tx *withdrawalTx)
} }
// withdrawalTxOut wraps an OutputRequest and provides a separate amount field. // withdrawalTxOut wraps an OutputRequest and provides a separate amount field.
@ -301,10 +305,26 @@ type withdrawalTx struct {
// changeOutput holds information about the change for this transaction. // changeOutput holds information about the change for this transaction.
changeOutput *wire.TxOut changeOutput *wire.TxOut
// calculateSize returns the estimated serialized size (in bytes) of this
// tx. See calculateTxSize() for details on how that's done. We use a
// struct field instead of a method so that it can be replaced in tests.
calculateSize func() int
// calculateFee calculates the expected network fees for this tx. We use a
// struct field instead of a method so that it can be replaced in tests.
calculateFee func() btcutil.Amount
} }
func newWithdrawalTx() *withdrawalTx { // newWithdrawalTx creates a new withdrawalTx and calls setOptions()
return &withdrawalTx{} // passing the newly created tx.
func newWithdrawalTx(setOptions func(tx *withdrawalTx)) *withdrawalTx {
tx := &withdrawalTx{}
tx.calculateSize = func() int { return calculateTxSize(tx) }
tx.calculateFee = func() btcutil.Amount {
return btcutil.Amount(1+tx.calculateSize()/1000) * feeIncrement
}
setOptions(tx)
return tx
} }
// ntxid returns the unique ID for this transaction. // ntxid returns the unique ID for this transaction.
@ -323,7 +343,7 @@ func (tx *withdrawalTx) isTooBig() bool {
// In bitcoind a tx is considered standard only if smaller than // In bitcoind a tx is considered standard only if smaller than
// MAX_STANDARD_TX_SIZE; that's why we consider anything >= txMaxSize to // MAX_STANDARD_TX_SIZE; that's why we consider anything >= txMaxSize to
// be too big. // be too big.
return calculateTxSize(tx) >= txMaxSize return tx.calculateSize() >= txMaxSize
} }
// inputTotal returns the sum amount of all inputs in this tx. // inputTotal returns the sum amount of all inputs in this tx.
@ -401,7 +421,7 @@ func (tx *withdrawalTx) removeInput() credit {
// added after it's called. Also, callsites must make sure adding a change // added after it's called. Also, callsites must make sure adding a change
// output won't cause the tx to exceed the size limit. // output won't cause the tx to exceed the size limit.
func (tx *withdrawalTx) addChange(pkScript []byte) bool { func (tx *withdrawalTx) addChange(pkScript []byte) bool {
tx.fee = calculateTxFee(tx) tx.fee = tx.calculateFee()
change := tx.inputTotal() - tx.outputTotal() - tx.fee change := tx.inputTotal() - tx.outputTotal() - tx.fee
log.Debugf("addChange: input total %v, output total %v, fee %v", tx.inputTotal(), log.Debugf("addChange: input total %v, output total %v, fee %v", tx.inputTotal(),
tx.outputTotal(), tx.fee) tx.outputTotal(), tx.fee)
@ -430,7 +450,7 @@ func (tx *withdrawalTx) rollBackLastOutput() ([]credit, *withdrawalTxOut, error)
var removedInputs []credit var removedInputs []credit
// Continue until sum(in) < sum(out) + fee // Continue until sum(in) < sum(out) + fee
for tx.inputTotal() >= tx.outputTotal()+calculateTxFee(tx) { for tx.inputTotal() >= tx.outputTotal()+tx.calculateFee() {
removedInputs = append(removedInputs, tx.removeInput()) removedInputs = append(removedInputs, tx.removeInput())
} }
@ -440,6 +460,8 @@ func (tx *withdrawalTx) rollBackLastOutput() ([]credit, *withdrawalTxOut, error)
return removedInputs, removedOutput, nil return removedInputs, removedOutput, nil
} }
func defaultTxOptions(tx *withdrawalTx) {}
func newWithdrawal(roundID uint32, requests []OutputRequest, inputs []credit, func newWithdrawal(roundID uint32, requests []OutputRequest, inputs []credit,
changeStart ChangeAddress) *withdrawal { changeStart ChangeAddress) *withdrawal {
outputs := make(map[OutBailmentID]*WithdrawalOutput, len(requests)) outputs := make(map[OutBailmentID]*WithdrawalOutput, len(requests))
@ -452,10 +474,10 @@ func newWithdrawal(roundID uint32, requests []OutputRequest, inputs []credit,
} }
return &withdrawal{ return &withdrawal{
roundID: roundID, roundID: roundID,
current: newWithdrawalTx(),
pendingRequests: requests, pendingRequests: requests,
eligibleInputs: inputs, eligibleInputs: inputs,
status: status, status: status,
txOptions: defaultTxOptions,
} }
} }
@ -553,7 +575,7 @@ func (w *withdrawal) fulfillNextRequest() error {
return w.handleOversizeTx() return w.handleOversizeTx()
} }
fee := calculateTxFee(w.current) fee := w.current.calculateFee()
for w.current.inputTotal() < w.current.outputTotal()+fee { for w.current.inputTotal() < w.current.outputTotal()+fee {
if len(w.eligibleInputs) == 0 { if len(w.eligibleInputs) == 0 {
log.Debug("Splitting last output because we don't have enough inputs") log.Debug("Splitting last output because we don't have enough inputs")
@ -563,7 +585,7 @@ func (w *withdrawal) fulfillNextRequest() error {
break break
} }
w.current.addInput(w.popInput()) w.current.addInput(w.popInput())
fee = calculateTxFee(w.current) fee = w.current.calculateFee()
if w.current.isTooBig() { if w.current.isTooBig() {
return w.handleOversizeTx() return w.handleOversizeTx()
@ -647,7 +669,7 @@ func (w *withdrawal) finalizeCurrentTx() error {
} }
w.transactions = append(w.transactions, tx) w.transactions = append(w.transactions, tx)
w.current = newWithdrawalTx() w.current = newWithdrawalTx(w.txOptions)
return nil return nil
} }
@ -683,12 +705,13 @@ func (w *withdrawal) fulfillRequests() error {
// Sort outputs by outBailmentID (hash(server ID, tx #)) // Sort outputs by outBailmentID (hash(server ID, tx #))
sort.Sort(byOutBailmentID(w.pendingRequests)) sort.Sort(byOutBailmentID(w.pendingRequests))
w.current = newWithdrawalTx(w.txOptions)
for len(w.pendingRequests) > 0 { for len(w.pendingRequests) > 0 {
if err := w.fulfillNextRequest(); err != nil { if err := w.fulfillNextRequest(); err != nil {
return err return err
} }
tx := w.current tx := w.current
if len(w.eligibleInputs) == 0 && tx.inputTotal() <= tx.outputTotal()+calculateTxFee(tx) { if len(w.eligibleInputs) == 0 && tx.inputTotal() <= tx.outputTotal()+tx.calculateFee() {
// We don't have more eligible inputs and all the inputs in the // We don't have more eligible inputs and all the inputs in the
// current tx have been spent. // current tx have been spent.
break break
@ -731,7 +754,7 @@ func (w *withdrawal) splitLastOutput() error {
output := tx.outputs[len(tx.outputs)-1] output := tx.outputs[len(tx.outputs)-1]
log.Debugf("Splitting tx output for %s", output.request) log.Debugf("Splitting tx output for %s", output.request)
origAmount := output.amount origAmount := output.amount
spentAmount := tx.outputTotal() + calculateTxFee(tx) - output.amount spentAmount := tx.outputTotal() + tx.calculateFee() - output.amount
// This is how much we have left after satisfying all outputs except the last // This is how much we have left after satisfying all outputs except the last
// one. IOW, all we have left for the last output, so we set that as the // one. IOW, all we have left for the last output, so we set that as the
// amount of the tx's last output. // amount of the tx's last output.
@ -993,16 +1016,9 @@ func validateSigScript(msgtx *wire.MsgTx, idx int, pkScript []byte) error {
return nil return nil
} }
// calculateTxFee calculates the expected network fees for a given tx. We use
// a variable instead of a function so that it can be replaced in tests.
var calculateTxFee = func(tx *withdrawalTx) btcutil.Amount {
return btcutil.Amount(1+calculateTxSize(tx)/1000) * feeIncrement
}
// calculateTxSize returns an estimate of the serialized size (in bytes) of the // calculateTxSize returns an estimate of the serialized size (in bytes) of the
// given transaction. It assumes all tx inputs are P2SH multi-sig. We use a // given transaction. It assumes all tx inputs are P2SH multi-sig.
// variable instead of a function so that it can be replaced in tests. func calculateTxSize(tx *withdrawalTx) int {
var calculateTxSize = func(tx *withdrawalTx) int {
msgtx := tx.toMsgTx() msgtx := tx.toMsgTx()
// Assume that there will always be a change output, for simplicity. We // Assume that there will always be a change output, for simplicity. We
// simulate that by simply copying the first output as all we care about is // simulate that by simply copying the first output as all we care about is

View file

@ -50,12 +50,13 @@ func TestOutputSplittingNotEnoughInputs(t *testing.T) {
} }
seriesID, eligible := TstCreateCreditsOnNewSeries(t, pool, []int64{7}) seriesID, eligible := TstCreateCreditsOnNewSeries(t, pool, []int64{7})
w := newWithdrawal(0, requests, eligible, *TstNewChangeAddress(t, pool, seriesID, 0)) w := newWithdrawal(0, requests, eligible, *TstNewChangeAddress(t, pool, seriesID, 0))
w.txOptions = func(tx *withdrawalTx) {
// Trigger an output split because of lack of inputs by forcing a high fee.
// If we just started with not enough inputs for the requested outputs,
// fulfillRequests() would drop outputs until we had enough.
tx.calculateFee = TstConstantFee(3)
}
// Trigger an output split because of lack of inputs by forcing a high fee.
// If we just started with not enough inputs for the requested outputs,
// fulfillRequests() would drop outputs until we had enough.
restoreCalculateTxFee := replaceCalculateTxFee(TstConstantFee(3))
defer restoreCalculateTxFee()
if err := w.fulfillRequests(); err != nil { if err := w.fulfillRequests(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -77,7 +78,7 @@ func TestOutputSplittingNotEnoughInputs(t *testing.T) {
// The last output should have had its amount updated to whatever we had // The last output should have had its amount updated to whatever we had
// left after satisfying all previous outputs. // left after satisfying all previous outputs.
newAmount := tx.inputTotal() - output1Amount - calculateTxFee(tx) newAmount := tx.inputTotal() - output1Amount - tx.calculateFee()
checkLastOutputWasSplit(t, w, tx, output2Amount, newAmount) checkLastOutputWasSplit(t, w, tx, output2Amount, newAmount)
} }
@ -93,16 +94,16 @@ func TestOutputSplittingOversizeTx(t *testing.T) {
seriesID, eligible := TstCreateCreditsOnNewSeries(t, pool, []int64{smallInput, bigInput}) seriesID, eligible := TstCreateCreditsOnNewSeries(t, pool, []int64{smallInput, bigInput})
changeStart := TstNewChangeAddress(t, pool, seriesID, 0) changeStart := TstNewChangeAddress(t, pool, seriesID, 0)
w := newWithdrawal(0, []OutputRequest{request}, eligible, *changeStart) w := newWithdrawal(0, []OutputRequest{request}, eligible, *changeStart)
restoreCalculateTxFee := replaceCalculateTxFee(TstConstantFee(0)) w.txOptions = func(tx *withdrawalTx) {
defer restoreCalculateTxFee() tx.calculateFee = TstConstantFee(0)
restoreCalcTxSize := replaceCalculateTxSize(func(tx *withdrawalTx) int { tx.calculateSize = func() int {
// Trigger an output split right after the second input is added. // Trigger an output split right after the second input is added.
if len(tx.inputs) == 2 { if len(tx.inputs) == 2 {
return txMaxSize + 1 return txMaxSize + 1
}
return txMaxSize - 1
} }
return txMaxSize - 1 }
})
defer restoreCalcTxSize()
if err := w.fulfillRequests(); err != nil { if err := w.fulfillRequests(); err != nil {
t.Fatal(err) t.Fatal(err)
@ -179,7 +180,7 @@ func TestWithdrawalTxOutputs(t *testing.T) {
// The created tx should include both eligible credits, so we expect it to have // The created tx should include both eligible credits, so we expect it to have
// an input amount of 2e6+4e6 satoshis. // an input amount of 2e6+4e6 satoshis.
inputAmount := eligible[0].Amount + eligible[1].Amount inputAmount := eligible[0].Amount + eligible[1].Amount
change := inputAmount - (outputs[0].Amount + outputs[1].Amount + calculateTxFee(tx)) change := inputAmount - (outputs[0].Amount + outputs[1].Amount + tx.calculateFee())
expectedOutputs := append( expectedOutputs := append(
outputs, TstNewOutputRequest(t, 3, changeStart.addr.String(), change, net)) outputs, TstNewOutputRequest(t, 3, changeStart.addr.String(), change, net))
msgtx := tx.toMsgTx() msgtx := tx.toMsgTx()
@ -247,7 +248,7 @@ func TestFulfillRequestsNotEnoughCreditsForAllRequests(t *testing.T) {
inputAmount := eligible[0].Amount + eligible[1].Amount inputAmount := eligible[0].Amount + eligible[1].Amount
// We expect it to include outputs for requests 1 and 2, plus a change output, but // We expect it to include outputs for requests 1 and 2, plus a change output, but
// output request #3 should not be there because we don't have enough credits. // output request #3 should not be there because we don't have enough credits.
change := inputAmount - (out1.Amount + out2.Amount + calculateTxFee(tx)) change := inputAmount - (out1.Amount + out2.Amount + tx.calculateFee())
expectedOutputs := []OutputRequest{out1, out2} expectedOutputs := []OutputRequest{out1, out2}
sort.Sort(byOutBailmentID(expectedOutputs)) sort.Sort(byOutBailmentID(expectedOutputs))
expectedOutputs = append( expectedOutputs = append(
@ -279,8 +280,7 @@ func TestRollbackLastOutput(t *testing.T) {
initialInputs := tx.inputs initialInputs := tx.inputs
initialOutputs := tx.outputs initialOutputs := tx.outputs
restoreCalcTxFee := replaceCalculateTxFee(TstConstantFee(1)) tx.calculateFee = TstConstantFee(1)
defer restoreCalcTxFee()
removedInputs, removedOutput, err := tx.rollBackLastOutput() removedInputs, removedOutput, err := tx.rollBackLastOutput()
if err != nil { if err != nil {
t.Fatal("Unexpected error:", err) t.Fatal("Unexpected error:", err)
@ -316,8 +316,7 @@ func TestRollbackLastOutputMultipleInputsRolledBack(t *testing.T) {
initialInputs := tx.inputs initialInputs := tx.inputs
initialOutputs := tx.outputs initialOutputs := tx.outputs
restoreCalcTxFee := replaceCalculateTxFee(TstConstantFee(0)) tx.calculateFee = TstConstantFee(0)
defer restoreCalcTxFee()
removedInputs, _, err := tx.rollBackLastOutput() removedInputs, _, err := tx.rollBackLastOutput()
if err != nil { if err != nil {
t.Fatal("Unexpected error:", err) t.Fatal("Unexpected error:", err)
@ -348,8 +347,7 @@ func TestRollbackLastOutputNoInputsRolledBack(t *testing.T) {
initialInputs := tx.inputs initialInputs := tx.inputs
initialOutputs := tx.outputs initialOutputs := tx.outputs
restoreCalcTxFee := replaceCalculateTxFee(TstConstantFee(1)) tx.calculateFee = TstConstantFee(1)
defer restoreCalcTxFee()
removedInputs, removedOutput, err := tx.rollBackLastOutput() removedInputs, removedOutput, err := tx.rollBackLastOutput()
if err != nil { if err != nil {
t.Fatal("Unexpected error:", err) t.Fatal("Unexpected error:", err)
@ -375,7 +373,7 @@ func TestRollbackLastOutputNoInputsRolledBack(t *testing.T) {
// rollBackLastOutput returns an error if there are less than two // rollBackLastOutput returns an error if there are less than two
// outputs in the transaction. // outputs in the transaction.
func TestRollBackLastOutputInsufficientOutputs(t *testing.T) { func TestRollBackLastOutputInsufficientOutputs(t *testing.T) {
tx := newWithdrawalTx() tx := newWithdrawalTx(defaultTxOptions)
_, _, err := tx.rollBackLastOutput() _, _, err := tx.rollBackLastOutput()
TstCheckError(t, "", err, ErrPreconditionNotMet) TstCheckError(t, "", err, ErrPreconditionNotMet)
@ -402,16 +400,16 @@ func TestRollbackLastOutputWhenNewOutputAdded(t *testing.T) {
changeStart := TstNewChangeAddress(t, pool, series, 0) changeStart := TstNewChangeAddress(t, pool, series, 0)
w := newWithdrawal(0, requests, eligible, *changeStart) w := newWithdrawal(0, requests, eligible, *changeStart)
restoreCalculateTxFee := replaceCalculateTxFee(TstConstantFee(0)) w.txOptions = func(tx *withdrawalTx) {
defer restoreCalculateTxFee() tx.calculateFee = TstConstantFee(0)
restoreCalcTxSize := replaceCalculateTxSize(func(tx *withdrawalTx) int { tx.calculateSize = func() int {
// Trigger an output split right after the second output is added. // Trigger an output split right after the second output is added.
if len(tx.outputs) > 1 { if len(tx.outputs) > 1 {
return txMaxSize + 1 return txMaxSize + 1
}
return txMaxSize - 1
} }
return txMaxSize - 1 }
})
defer restoreCalcTxSize()
if err := w.fulfillRequests(); err != nil { if err := w.fulfillRequests(); err != nil {
t.Fatal("Unexpected error:", err) t.Fatal("Unexpected error:", err)
@ -456,16 +454,16 @@ func TestRollbackLastOutputWhenNewInputAdded(t *testing.T) {
changeStart := TstNewChangeAddress(t, pool, series, 0) changeStart := TstNewChangeAddress(t, pool, series, 0)
w := newWithdrawal(0, requests, eligible, *changeStart) w := newWithdrawal(0, requests, eligible, *changeStart)
restoreCalculateTxFee := replaceCalculateTxFee(TstConstantFee(0)) w.txOptions = func(tx *withdrawalTx) {
defer restoreCalculateTxFee() tx.calculateFee = TstConstantFee(0)
restoreCalcTxSize := replaceCalculateTxSize(func(tx *withdrawalTx) int { tx.calculateSize = func() int {
// Make a transaction too big as soon as a fourth input is added to it. // Make a transaction too big as soon as a fourth input is added to it.
if len(tx.inputs) > 3 { if len(tx.inputs) > 3 {
return txMaxSize + 1 return txMaxSize + 1
}
return txMaxSize - 1
} }
return txMaxSize - 1 }
})
defer restoreCalcTxSize()
// The rollback should be triggered right after the 4th input is added in // The rollback should be triggered right after the 4th input is added in
// order to fulfill the second request. // order to fulfill the second request.
@ -559,8 +557,7 @@ func TestWithdrawalTxAddChange(t *testing.T) {
input, output, fee := int64(4e6), int64(3e6), int64(10) input, output, fee := int64(4e6), int64(3e6), int64(10)
tx := createWithdrawalTx(t, pool, []int64{input}, []int64{output}) tx := createWithdrawalTx(t, pool, []int64{input}, []int64{output})
restoreCalcTxFee := replaceCalculateTxFee(TstConstantFee(btcutil.Amount(fee))) tx.calculateFee = TstConstantFee(btcutil.Amount(fee))
defer restoreCalcTxFee()
if !tx.addChange([]byte{}) { if !tx.addChange([]byte{}) {
t.Fatal("tx.addChange() returned false, meaning it did not add a change output") t.Fatal("tx.addChange() returned false, meaning it did not add a change output")
@ -586,8 +583,7 @@ func TestWithdrawalTxAddChangeNoChange(t *testing.T) {
input, output, fee := int64(4e6), int64(4e6), int64(0) input, output, fee := int64(4e6), int64(4e6), int64(0)
tx := createWithdrawalTx(t, pool, []int64{input}, []int64{output}) tx := createWithdrawalTx(t, pool, []int64{input}, []int64{output})
restoreCalcTxFee := replaceCalculateTxFee(TstConstantFee(btcutil.Amount(fee))) tx.calculateFee = TstConstantFee(btcutil.Amount(fee))
defer restoreCalcTxFee()
if tx.addChange([]byte{}) { if tx.addChange([]byte{}) {
t.Fatal("tx.addChange() returned true, meaning it added a change output") t.Fatal("tx.addChange() returned true, meaning it added a change output")
@ -1003,27 +999,24 @@ func TestTxTooBig(t *testing.T) {
tx := createWithdrawalTx(t, pool, []int64{5}, []int64{1}) tx := createWithdrawalTx(t, pool, []int64{5}, []int64{1})
restoreCalcTxSize := replaceCalculateTxSize(func(tx *withdrawalTx) int { return txMaxSize - 1 }) tx.calculateSize = func() int { return txMaxSize - 1 }
if tx.isTooBig() { if tx.isTooBig() {
t.Fatalf("Tx is smaller than max size (%d < %d) but was considered too big", t.Fatalf("Tx is smaller than max size (%d < %d) but was considered too big",
calculateTxSize(tx), txMaxSize) tx.calculateSize(), txMaxSize)
} }
restoreCalcTxSize()
// A tx whose size is equal to txMaxSize should be considered too big. // A tx whose size is equal to txMaxSize should be considered too big.
restoreCalcTxSize = replaceCalculateTxSize(func(tx *withdrawalTx) int { return txMaxSize }) tx.calculateSize = func() int { return txMaxSize }
if !tx.isTooBig() { if !tx.isTooBig() {
t.Fatalf("Tx size is equal to the max size (%d == %d) but was not considered too big", t.Fatalf("Tx size is equal to the max size (%d == %d) but was not considered too big",
calculateTxSize(tx), txMaxSize) tx.calculateSize(), txMaxSize)
} }
restoreCalcTxSize()
restoreCalcTxSize = replaceCalculateTxSize(func(tx *withdrawalTx) int { return txMaxSize + 1 }) tx.calculateSize = func() int { return txMaxSize + 1 }
if !tx.isTooBig() { if !tx.isTooBig() {
t.Fatalf("Tx size is bigger than max size (%d > %d) but was not considered too big", t.Fatalf("Tx size is bigger than max size (%d > %d) but was not considered too big",
calculateTxSize(tx), txMaxSize) tx.calculateSize(), txMaxSize)
} }
restoreCalcTxSize()
} }
func TestTxSizeCalculation(t *testing.T) { func TestTxSizeCalculation(t *testing.T) {
@ -1032,14 +1025,13 @@ func TestTxSizeCalculation(t *testing.T) {
tx := createWithdrawalTx(t, pool, []int64{1, 5}, []int64{2}) tx := createWithdrawalTx(t, pool, []int64{1, 5}, []int64{2})
size := calculateTxSize(tx) size := tx.calculateSize()
// Now add a change output, get a msgtx, sign it and get its SerializedSize // Now add a change output, get a msgtx, sign it and get its SerializedSize
// to compare with the value above. We need to replace the calculateTxFee // to compare with the value above. We need to replace the calculateFee
// function so that the tx.addChange() call below always adds a change // method so that the tx.addChange() call below always adds a change
// output. // output.
restoreCalcTxFee := replaceCalculateTxFee(TstConstantFee(1)) tx.calculateFee = TstConstantFee(1)
defer restoreCalcTxFee()
seriesID := tx.inputs[0].addr.SeriesID() seriesID := tx.inputs[0].addr.SeriesID()
tx.addChange(TstNewChangeAddress(t, pool, seriesID, 0).addr.ScriptAddress()) tx.addChange(TstNewChangeAddress(t, pool, seriesID, 0).addr.ScriptAddress())
msgtx := tx.toMsgTx() msgtx := tx.toMsgTx()
@ -1050,7 +1042,7 @@ func TestTxSizeCalculation(t *testing.T) {
signTxAndValidate(t, pool.Manager(), msgtx, sigs[tx.ntxid()], tx.inputs) signTxAndValidate(t, pool.Manager(), msgtx, sigs[tx.ntxid()], tx.inputs)
// ECDSA signatures have variable length (71-73 bytes) but in // ECDSA signatures have variable length (71-73 bytes) but in
// calculateTxSize() we use a dummy signature for the worst-case scenario (73 // calculateSize() we use a dummy signature for the worst-case scenario (73
// bytes) so the estimate here can be up to 2 bytes bigger for every // bytes) so the estimate here can be up to 2 bytes bigger for every
// signature in every input's SigScript. // signature in every input's SigScript.
maxDiff := 2 * len(msgtx.TxIn) * int(tx.inputs[0].addr.series().reqSigs) maxDiff := 2 * len(msgtx.TxIn) * int(tx.inputs[0].addr.series().reqSigs)
@ -1072,13 +1064,12 @@ func TestTxSizeCalculation(t *testing.T) {
} }
func TestTxFeeEstimationForSmallTx(t *testing.T) { func TestTxFeeEstimationForSmallTx(t *testing.T) {
tx := newWithdrawalTx() tx := newWithdrawalTx(defaultTxOptions)
// A tx that is smaller than 1000 bytes in size should have a fee of 10000 // A tx that is smaller than 1000 bytes in size should have a fee of 10000
// satoshis. // satoshis.
restoreCalcTxSize := replaceCalculateTxSize(func(tx *withdrawalTx) int { return 999 }) tx.calculateSize = func() int { return 999 }
defer restoreCalcTxSize() fee := tx.calculateFee()
fee := calculateTxFee(tx)
wantFee := btcutil.Amount(1e3) wantFee := btcutil.Amount(1e3)
if fee != wantFee { if fee != wantFee {
@ -1087,13 +1078,12 @@ func TestTxFeeEstimationForSmallTx(t *testing.T) {
} }
func TestTxFeeEstimationForLargeTx(t *testing.T) { func TestTxFeeEstimationForLargeTx(t *testing.T) {
tx := newWithdrawalTx() tx := newWithdrawalTx(defaultTxOptions)
// A tx that is larger than 1000 bytes in size should have a fee of 1e3 // A tx that is larger than 1000 bytes in size should have a fee of 1e3
// satoshis plus 1e3 for every 1000 bytes. // satoshis plus 1e3 for every 1000 bytes.
restoreCalcTxSize := replaceCalculateTxSize(func(tx *withdrawalTx) int { return 3000 }) tx.calculateSize = func() int { return 3000 }
defer restoreCalcTxSize() fee := tx.calculateFee()
fee := calculateTxFee(tx)
wantFee := btcutil.Amount(4e3) wantFee := btcutil.Amount(4e3)
if fee != wantFee { if fee != wantFee {
@ -1188,7 +1178,7 @@ func createWithdrawalTxWithStoreCredits(t *testing.T, store *wtxmgr.Store, pool
def := TstCreateSeriesDef(t, pool, 2, masters) def := TstCreateSeriesDef(t, pool, 2, masters)
TstCreateSeries(t, pool, []TstSeriesDef{def}) TstCreateSeries(t, pool, []TstSeriesDef{def})
net := pool.Manager().ChainParams() net := pool.Manager().ChainParams()
tx := newWithdrawalTx() tx := newWithdrawalTx(defaultTxOptions)
for _, c := range TstCreateSeriesCreditsOnStore(t, pool, def.SeriesID, inputAmounts, store) { for _, c := range TstCreateSeriesCreditsOnStore(t, pool, def.SeriesID, inputAmounts, store) {
tx.addInput(c) tx.addInput(c)
} }