diff --git a/votingpool/common_test.go b/votingpool/common_test.go index 82958d7..f7f83e2 100644 --- a/votingpool/common_test.go +++ b/votingpool/common_test.go @@ -24,7 +24,6 @@ import ( "testing" "github.com/btcsuite/btclog" - "github.com/btcsuite/btcutil" "github.com/btcsuite/btcwallet/waddrmgr" ) @@ -99,19 +98,3 @@ func TstCheckWithdrawalStatusMatches(t *testing.T, s1, s2 WithdrawalStatus) { t.Fatalf("Wrong WithdrawalStatus; got %v, want %v", s1, s2) } } - -// replaceCalculateTxFee replaces the calculateTxFee func with the given one -// and returns a function that restores it to the original one. -func replaceCalculateTxFee(f func(*withdrawalTx) btcutil.Amount) func() { - orig := calculateTxFee - calculateTxFee = f - return func() { calculateTxFee = orig } -} - -// replaceCalculateTxSize replaces the calculateTxSize func with the given one -// and returns a function that restores it to the original one. -func replaceCalculateTxSize(f func(*withdrawalTx) int) func() { - orig := calculateTxSize - calculateTxSize = f - return func() { calculateTxSize = orig } -} diff --git a/votingpool/factory_test.go b/votingpool/factory_test.go index 4cb6a9e..754cef7 100644 --- a/votingpool/factory_test.go +++ b/votingpool/factory_test.go @@ -54,7 +54,7 @@ func getUniqueID() uint32 { // createWithdrawalTx creates a withdrawalTx with the given input and output amounts. func createWithdrawalTx(t *testing.T, pool *Pool, inputAmounts []int64, outputAmounts []int64) *withdrawalTx { net := pool.Manager().ChainParams() - tx := newWithdrawalTx() + tx := newWithdrawalTx(defaultTxOptions) _, credits := TstCreateCreditsOnNewSeries(t, pool, inputAmounts) for _, c := range credits { tx.addInput(c) @@ -418,8 +418,8 @@ func TstNewChangeAddress(t *testing.T, p *Pool, seriesID uint32, idx Index) (add return addr } -func TstConstantFee(fee btcutil.Amount) func(tx *withdrawalTx) btcutil.Amount { - return func(tx *withdrawalTx) btcutil.Amount { return fee } +func TstConstantFee(fee btcutil.Amount) func() btcutil.Amount { + return func() btcutil.Amount { return fee } } func createAndFulfillWithdrawalRequests(t *testing.T, pool *Pool, roundID uint32) withdrawalInfo { diff --git a/votingpool/withdrawal.go b/votingpool/withdrawal.go index 5bd89fa..e635f2b 100644 --- a/votingpool/withdrawal.go +++ b/votingpool/withdrawal.go @@ -269,6 +269,10 @@ type withdrawal struct { pendingRequests []OutputRequest eligibleInputs []credit current *withdrawalTx + // txOptions is a function called for every new withdrawalTx created as + // part of this withdrawal. It is defined as a function field because it + // exists mainly so that tests can mock withdrawalTx fields. + txOptions func(tx *withdrawalTx) } // withdrawalTxOut wraps an OutputRequest and provides a separate amount field. @@ -301,10 +305,26 @@ type withdrawalTx struct { // changeOutput holds information about the change for this transaction. changeOutput *wire.TxOut + + // calculateSize returns the estimated serialized size (in bytes) of this + // tx. See calculateTxSize() for details on how that's done. We use a + // struct field instead of a method so that it can be replaced in tests. + calculateSize func() int + // calculateFee calculates the expected network fees for this tx. We use a + // struct field instead of a method so that it can be replaced in tests. + calculateFee func() btcutil.Amount } -func newWithdrawalTx() *withdrawalTx { - return &withdrawalTx{} +// newWithdrawalTx creates a new withdrawalTx and calls setOptions() +// passing the newly created tx. +func newWithdrawalTx(setOptions func(tx *withdrawalTx)) *withdrawalTx { + tx := &withdrawalTx{} + tx.calculateSize = func() int { return calculateTxSize(tx) } + tx.calculateFee = func() btcutil.Amount { + return btcutil.Amount(1+tx.calculateSize()/1000) * feeIncrement + } + setOptions(tx) + return tx } // ntxid returns the unique ID for this transaction. @@ -323,7 +343,7 @@ func (tx *withdrawalTx) isTooBig() bool { // In bitcoind a tx is considered standard only if smaller than // MAX_STANDARD_TX_SIZE; that's why we consider anything >= txMaxSize to // be too big. - return calculateTxSize(tx) >= txMaxSize + return tx.calculateSize() >= txMaxSize } // inputTotal returns the sum amount of all inputs in this tx. @@ -401,7 +421,7 @@ func (tx *withdrawalTx) removeInput() credit { // added after it's called. Also, callsites must make sure adding a change // output won't cause the tx to exceed the size limit. func (tx *withdrawalTx) addChange(pkScript []byte) bool { - tx.fee = calculateTxFee(tx) + tx.fee = tx.calculateFee() change := tx.inputTotal() - tx.outputTotal() - tx.fee log.Debugf("addChange: input total %v, output total %v, fee %v", tx.inputTotal(), tx.outputTotal(), tx.fee) @@ -430,7 +450,7 @@ func (tx *withdrawalTx) rollBackLastOutput() ([]credit, *withdrawalTxOut, error) var removedInputs []credit // Continue until sum(in) < sum(out) + fee - for tx.inputTotal() >= tx.outputTotal()+calculateTxFee(tx) { + for tx.inputTotal() >= tx.outputTotal()+tx.calculateFee() { removedInputs = append(removedInputs, tx.removeInput()) } @@ -440,6 +460,8 @@ func (tx *withdrawalTx) rollBackLastOutput() ([]credit, *withdrawalTxOut, error) return removedInputs, removedOutput, nil } +func defaultTxOptions(tx *withdrawalTx) {} + func newWithdrawal(roundID uint32, requests []OutputRequest, inputs []credit, changeStart ChangeAddress) *withdrawal { outputs := make(map[OutBailmentID]*WithdrawalOutput, len(requests)) @@ -452,10 +474,10 @@ func newWithdrawal(roundID uint32, requests []OutputRequest, inputs []credit, } return &withdrawal{ roundID: roundID, - current: newWithdrawalTx(), pendingRequests: requests, eligibleInputs: inputs, status: status, + txOptions: defaultTxOptions, } } @@ -553,7 +575,7 @@ func (w *withdrawal) fulfillNextRequest() error { return w.handleOversizeTx() } - fee := calculateTxFee(w.current) + fee := w.current.calculateFee() for w.current.inputTotal() < w.current.outputTotal()+fee { if len(w.eligibleInputs) == 0 { log.Debug("Splitting last output because we don't have enough inputs") @@ -563,7 +585,7 @@ func (w *withdrawal) fulfillNextRequest() error { break } w.current.addInput(w.popInput()) - fee = calculateTxFee(w.current) + fee = w.current.calculateFee() if w.current.isTooBig() { return w.handleOversizeTx() @@ -647,7 +669,7 @@ func (w *withdrawal) finalizeCurrentTx() error { } w.transactions = append(w.transactions, tx) - w.current = newWithdrawalTx() + w.current = newWithdrawalTx(w.txOptions) return nil } @@ -683,12 +705,13 @@ func (w *withdrawal) fulfillRequests() error { // Sort outputs by outBailmentID (hash(server ID, tx #)) sort.Sort(byOutBailmentID(w.pendingRequests)) + w.current = newWithdrawalTx(w.txOptions) for len(w.pendingRequests) > 0 { if err := w.fulfillNextRequest(); err != nil { return err } tx := w.current - if len(w.eligibleInputs) == 0 && tx.inputTotal() <= tx.outputTotal()+calculateTxFee(tx) { + if len(w.eligibleInputs) == 0 && tx.inputTotal() <= tx.outputTotal()+tx.calculateFee() { // We don't have more eligible inputs and all the inputs in the // current tx have been spent. break @@ -731,7 +754,7 @@ func (w *withdrawal) splitLastOutput() error { output := tx.outputs[len(tx.outputs)-1] log.Debugf("Splitting tx output for %s", output.request) origAmount := output.amount - spentAmount := tx.outputTotal() + calculateTxFee(tx) - output.amount + spentAmount := tx.outputTotal() + tx.calculateFee() - output.amount // This is how much we have left after satisfying all outputs except the last // one. IOW, all we have left for the last output, so we set that as the // amount of the tx's last output. @@ -993,16 +1016,9 @@ func validateSigScript(msgtx *wire.MsgTx, idx int, pkScript []byte) error { return nil } -// calculateTxFee calculates the expected network fees for a given tx. We use -// a variable instead of a function so that it can be replaced in tests. -var calculateTxFee = func(tx *withdrawalTx) btcutil.Amount { - return btcutil.Amount(1+calculateTxSize(tx)/1000) * feeIncrement -} - // calculateTxSize returns an estimate of the serialized size (in bytes) of the -// given transaction. It assumes all tx inputs are P2SH multi-sig. We use a -// variable instead of a function so that it can be replaced in tests. -var calculateTxSize = func(tx *withdrawalTx) int { +// given transaction. It assumes all tx inputs are P2SH multi-sig. +func calculateTxSize(tx *withdrawalTx) int { msgtx := tx.toMsgTx() // Assume that there will always be a change output, for simplicity. We // simulate that by simply copying the first output as all we care about is diff --git a/votingpool/withdrawal_wb_test.go b/votingpool/withdrawal_wb_test.go index 8352859..cec80e2 100644 --- a/votingpool/withdrawal_wb_test.go +++ b/votingpool/withdrawal_wb_test.go @@ -50,12 +50,13 @@ func TestOutputSplittingNotEnoughInputs(t *testing.T) { } seriesID, eligible := TstCreateCreditsOnNewSeries(t, pool, []int64{7}) w := newWithdrawal(0, requests, eligible, *TstNewChangeAddress(t, pool, seriesID, 0)) + w.txOptions = func(tx *withdrawalTx) { + // Trigger an output split because of lack of inputs by forcing a high fee. + // If we just started with not enough inputs for the requested outputs, + // fulfillRequests() would drop outputs until we had enough. + tx.calculateFee = TstConstantFee(3) + } - // Trigger an output split because of lack of inputs by forcing a high fee. - // If we just started with not enough inputs for the requested outputs, - // fulfillRequests() would drop outputs until we had enough. - restoreCalculateTxFee := replaceCalculateTxFee(TstConstantFee(3)) - defer restoreCalculateTxFee() if err := w.fulfillRequests(); err != nil { t.Fatal(err) } @@ -77,7 +78,7 @@ func TestOutputSplittingNotEnoughInputs(t *testing.T) { // The last output should have had its amount updated to whatever we had // left after satisfying all previous outputs. - newAmount := tx.inputTotal() - output1Amount - calculateTxFee(tx) + newAmount := tx.inputTotal() - output1Amount - tx.calculateFee() checkLastOutputWasSplit(t, w, tx, output2Amount, newAmount) } @@ -93,16 +94,16 @@ func TestOutputSplittingOversizeTx(t *testing.T) { seriesID, eligible := TstCreateCreditsOnNewSeries(t, pool, []int64{smallInput, bigInput}) changeStart := TstNewChangeAddress(t, pool, seriesID, 0) w := newWithdrawal(0, []OutputRequest{request}, eligible, *changeStart) - restoreCalculateTxFee := replaceCalculateTxFee(TstConstantFee(0)) - defer restoreCalculateTxFee() - restoreCalcTxSize := replaceCalculateTxSize(func(tx *withdrawalTx) int { - // Trigger an output split right after the second input is added. - if len(tx.inputs) == 2 { - return txMaxSize + 1 + w.txOptions = func(tx *withdrawalTx) { + tx.calculateFee = TstConstantFee(0) + tx.calculateSize = func() int { + // Trigger an output split right after the second input is added. + if len(tx.inputs) == 2 { + return txMaxSize + 1 + } + return txMaxSize - 1 } - return txMaxSize - 1 - }) - defer restoreCalcTxSize() + } if err := w.fulfillRequests(); err != nil { t.Fatal(err) @@ -179,7 +180,7 @@ func TestWithdrawalTxOutputs(t *testing.T) { // The created tx should include both eligible credits, so we expect it to have // an input amount of 2e6+4e6 satoshis. inputAmount := eligible[0].Amount + eligible[1].Amount - change := inputAmount - (outputs[0].Amount + outputs[1].Amount + calculateTxFee(tx)) + change := inputAmount - (outputs[0].Amount + outputs[1].Amount + tx.calculateFee()) expectedOutputs := append( outputs, TstNewOutputRequest(t, 3, changeStart.addr.String(), change, net)) msgtx := tx.toMsgTx() @@ -247,7 +248,7 @@ func TestFulfillRequestsNotEnoughCreditsForAllRequests(t *testing.T) { inputAmount := eligible[0].Amount + eligible[1].Amount // We expect it to include outputs for requests 1 and 2, plus a change output, but // output request #3 should not be there because we don't have enough credits. - change := inputAmount - (out1.Amount + out2.Amount + calculateTxFee(tx)) + change := inputAmount - (out1.Amount + out2.Amount + tx.calculateFee()) expectedOutputs := []OutputRequest{out1, out2} sort.Sort(byOutBailmentID(expectedOutputs)) expectedOutputs = append( @@ -279,8 +280,7 @@ func TestRollbackLastOutput(t *testing.T) { initialInputs := tx.inputs initialOutputs := tx.outputs - restoreCalcTxFee := replaceCalculateTxFee(TstConstantFee(1)) - defer restoreCalcTxFee() + tx.calculateFee = TstConstantFee(1) removedInputs, removedOutput, err := tx.rollBackLastOutput() if err != nil { t.Fatal("Unexpected error:", err) @@ -316,8 +316,7 @@ func TestRollbackLastOutputMultipleInputsRolledBack(t *testing.T) { initialInputs := tx.inputs initialOutputs := tx.outputs - restoreCalcTxFee := replaceCalculateTxFee(TstConstantFee(0)) - defer restoreCalcTxFee() + tx.calculateFee = TstConstantFee(0) removedInputs, _, err := tx.rollBackLastOutput() if err != nil { t.Fatal("Unexpected error:", err) @@ -348,8 +347,7 @@ func TestRollbackLastOutputNoInputsRolledBack(t *testing.T) { initialInputs := tx.inputs initialOutputs := tx.outputs - restoreCalcTxFee := replaceCalculateTxFee(TstConstantFee(1)) - defer restoreCalcTxFee() + tx.calculateFee = TstConstantFee(1) removedInputs, removedOutput, err := tx.rollBackLastOutput() if err != nil { t.Fatal("Unexpected error:", err) @@ -375,7 +373,7 @@ func TestRollbackLastOutputNoInputsRolledBack(t *testing.T) { // rollBackLastOutput returns an error if there are less than two // outputs in the transaction. func TestRollBackLastOutputInsufficientOutputs(t *testing.T) { - tx := newWithdrawalTx() + tx := newWithdrawalTx(defaultTxOptions) _, _, err := tx.rollBackLastOutput() TstCheckError(t, "", err, ErrPreconditionNotMet) @@ -402,16 +400,16 @@ func TestRollbackLastOutputWhenNewOutputAdded(t *testing.T) { changeStart := TstNewChangeAddress(t, pool, series, 0) w := newWithdrawal(0, requests, eligible, *changeStart) - restoreCalculateTxFee := replaceCalculateTxFee(TstConstantFee(0)) - defer restoreCalculateTxFee() - restoreCalcTxSize := replaceCalculateTxSize(func(tx *withdrawalTx) int { - // Trigger an output split right after the second output is added. - if len(tx.outputs) > 1 { - return txMaxSize + 1 + w.txOptions = func(tx *withdrawalTx) { + tx.calculateFee = TstConstantFee(0) + tx.calculateSize = func() int { + // Trigger an output split right after the second output is added. + if len(tx.outputs) > 1 { + return txMaxSize + 1 + } + return txMaxSize - 1 } - return txMaxSize - 1 - }) - defer restoreCalcTxSize() + } if err := w.fulfillRequests(); err != nil { t.Fatal("Unexpected error:", err) @@ -456,16 +454,16 @@ func TestRollbackLastOutputWhenNewInputAdded(t *testing.T) { changeStart := TstNewChangeAddress(t, pool, series, 0) w := newWithdrawal(0, requests, eligible, *changeStart) - restoreCalculateTxFee := replaceCalculateTxFee(TstConstantFee(0)) - defer restoreCalculateTxFee() - restoreCalcTxSize := replaceCalculateTxSize(func(tx *withdrawalTx) int { - // Make a transaction too big as soon as a fourth input is added to it. - if len(tx.inputs) > 3 { - return txMaxSize + 1 + w.txOptions = func(tx *withdrawalTx) { + tx.calculateFee = TstConstantFee(0) + tx.calculateSize = func() int { + // Make a transaction too big as soon as a fourth input is added to it. + if len(tx.inputs) > 3 { + return txMaxSize + 1 + } + return txMaxSize - 1 } - return txMaxSize - 1 - }) - defer restoreCalcTxSize() + } // The rollback should be triggered right after the 4th input is added in // order to fulfill the second request. @@ -559,8 +557,7 @@ func TestWithdrawalTxAddChange(t *testing.T) { input, output, fee := int64(4e6), int64(3e6), int64(10) tx := createWithdrawalTx(t, pool, []int64{input}, []int64{output}) - restoreCalcTxFee := replaceCalculateTxFee(TstConstantFee(btcutil.Amount(fee))) - defer restoreCalcTxFee() + tx.calculateFee = TstConstantFee(btcutil.Amount(fee)) if !tx.addChange([]byte{}) { t.Fatal("tx.addChange() returned false, meaning it did not add a change output") @@ -586,8 +583,7 @@ func TestWithdrawalTxAddChangeNoChange(t *testing.T) { input, output, fee := int64(4e6), int64(4e6), int64(0) tx := createWithdrawalTx(t, pool, []int64{input}, []int64{output}) - restoreCalcTxFee := replaceCalculateTxFee(TstConstantFee(btcutil.Amount(fee))) - defer restoreCalcTxFee() + tx.calculateFee = TstConstantFee(btcutil.Amount(fee)) if tx.addChange([]byte{}) { t.Fatal("tx.addChange() returned true, meaning it added a change output") @@ -1003,27 +999,24 @@ func TestTxTooBig(t *testing.T) { tx := createWithdrawalTx(t, pool, []int64{5}, []int64{1}) - restoreCalcTxSize := replaceCalculateTxSize(func(tx *withdrawalTx) int { return txMaxSize - 1 }) + tx.calculateSize = func() int { return txMaxSize - 1 } if tx.isTooBig() { t.Fatalf("Tx is smaller than max size (%d < %d) but was considered too big", - calculateTxSize(tx), txMaxSize) + tx.calculateSize(), txMaxSize) } - restoreCalcTxSize() // A tx whose size is equal to txMaxSize should be considered too big. - restoreCalcTxSize = replaceCalculateTxSize(func(tx *withdrawalTx) int { return txMaxSize }) + tx.calculateSize = func() int { return txMaxSize } if !tx.isTooBig() { t.Fatalf("Tx size is equal to the max size (%d == %d) but was not considered too big", - calculateTxSize(tx), txMaxSize) + tx.calculateSize(), txMaxSize) } - restoreCalcTxSize() - restoreCalcTxSize = replaceCalculateTxSize(func(tx *withdrawalTx) int { return txMaxSize + 1 }) + tx.calculateSize = func() int { return txMaxSize + 1 } if !tx.isTooBig() { t.Fatalf("Tx size is bigger than max size (%d > %d) but was not considered too big", - calculateTxSize(tx), txMaxSize) + tx.calculateSize(), txMaxSize) } - restoreCalcTxSize() } func TestTxSizeCalculation(t *testing.T) { @@ -1032,14 +1025,13 @@ func TestTxSizeCalculation(t *testing.T) { tx := createWithdrawalTx(t, pool, []int64{1, 5}, []int64{2}) - size := calculateTxSize(tx) + size := tx.calculateSize() // Now add a change output, get a msgtx, sign it and get its SerializedSize - // to compare with the value above. We need to replace the calculateTxFee - // function so that the tx.addChange() call below always adds a change + // to compare with the value above. We need to replace the calculateFee + // method so that the tx.addChange() call below always adds a change // output. - restoreCalcTxFee := replaceCalculateTxFee(TstConstantFee(1)) - defer restoreCalcTxFee() + tx.calculateFee = TstConstantFee(1) seriesID := tx.inputs[0].addr.SeriesID() tx.addChange(TstNewChangeAddress(t, pool, seriesID, 0).addr.ScriptAddress()) msgtx := tx.toMsgTx() @@ -1050,7 +1042,7 @@ func TestTxSizeCalculation(t *testing.T) { signTxAndValidate(t, pool.Manager(), msgtx, sigs[tx.ntxid()], tx.inputs) // ECDSA signatures have variable length (71-73 bytes) but in - // calculateTxSize() we use a dummy signature for the worst-case scenario (73 + // calculateSize() we use a dummy signature for the worst-case scenario (73 // bytes) so the estimate here can be up to 2 bytes bigger for every // signature in every input's SigScript. maxDiff := 2 * len(msgtx.TxIn) * int(tx.inputs[0].addr.series().reqSigs) @@ -1072,13 +1064,12 @@ func TestTxSizeCalculation(t *testing.T) { } func TestTxFeeEstimationForSmallTx(t *testing.T) { - tx := newWithdrawalTx() + tx := newWithdrawalTx(defaultTxOptions) // A tx that is smaller than 1000 bytes in size should have a fee of 10000 // satoshis. - restoreCalcTxSize := replaceCalculateTxSize(func(tx *withdrawalTx) int { return 999 }) - defer restoreCalcTxSize() - fee := calculateTxFee(tx) + tx.calculateSize = func() int { return 999 } + fee := tx.calculateFee() wantFee := btcutil.Amount(1e3) if fee != wantFee { @@ -1087,13 +1078,12 @@ func TestTxFeeEstimationForSmallTx(t *testing.T) { } func TestTxFeeEstimationForLargeTx(t *testing.T) { - tx := newWithdrawalTx() + tx := newWithdrawalTx(defaultTxOptions) // A tx that is larger than 1000 bytes in size should have a fee of 1e3 // satoshis plus 1e3 for every 1000 bytes. - restoreCalcTxSize := replaceCalculateTxSize(func(tx *withdrawalTx) int { return 3000 }) - defer restoreCalcTxSize() - fee := calculateTxFee(tx) + tx.calculateSize = func() int { return 3000 } + fee := tx.calculateFee() wantFee := btcutil.Amount(4e3) if fee != wantFee { @@ -1188,7 +1178,7 @@ func createWithdrawalTxWithStoreCredits(t *testing.T, store *wtxmgr.Store, pool def := TstCreateSeriesDef(t, pool, 2, masters) TstCreateSeries(t, pool, []TstSeriesDef{def}) net := pool.Manager().ChainParams() - tx := newWithdrawalTx() + tx := newWithdrawalTx(defaultTxOptions) for _, c := range TstCreateSeriesCreditsOnStore(t, pool, def.SeriesID, inputAmounts, store) { tx.addInput(c) }