From 47ca1ca6e57e7fddd3c1fd2a6411e7b14a070c23 Mon Sep 17 00:00:00 2001 From: Guilherme Salgado Date: Mon, 23 Feb 2015 16:07:12 +0000 Subject: [PATCH] StartWithdrawal returns a previously saved WithdrawalStatus if parameters match StartWithdrawal now persists the WithdrawalStatus before returning, and also returns a previously saved one in subsequent calls with the same parameters. --- votingpool/common_test.go | 37 +++++ votingpool/db.go | 251 ++++++++++++++++++++++++++++++- votingpool/db_wb_test.go | 75 +++++++++ votingpool/error.go | 5 + votingpool/error_test.go | 1 + votingpool/factory_test.go | 26 ++++ votingpool/internal_test.go | 5 +- votingpool/withdrawal.go | 107 ++++++++++++- votingpool/withdrawal_test.go | 12 ++ votingpool/withdrawal_wb_test.go | 117 ++++++++++++++ 10 files changed, 631 insertions(+), 5 deletions(-) diff --git a/votingpool/common_test.go b/votingpool/common_test.go index 010a733..a638f26 100644 --- a/votingpool/common_test.go +++ b/votingpool/common_test.go @@ -19,6 +19,7 @@ package votingpool import ( "fmt" "os" + "reflect" "runtime" "testing" @@ -63,6 +64,42 @@ func TstRunWithManagerUnlocked(t *testing.T, mgr *waddrmgr.Manager, callback fun callback() } +// TstCheckWithdrawalStatusMatches compares s1 and s2 using reflect.DeepEqual +// and calls t.Fatal() if they're not identical. +func TstCheckWithdrawalStatusMatches(t *testing.T, s1, s2 WithdrawalStatus) { + if s1.Fees() != s2.Fees() { + t.Fatalf("Wrong amount of network fees; want %d, got %d", s1.Fees(), s2.Fees()) + } + + if !reflect.DeepEqual(s1.Sigs(), s2.Sigs()) { + t.Fatalf("Wrong tx signatures; got %x, want %x", s1.Sigs(), s2.Sigs()) + } + + if !reflect.DeepEqual(s1.NextInputAddr(), s2.NextInputAddr()) { + t.Fatalf("Wrong NextInputAddr; got %v, want %v", s1.NextInputAddr(), s2.NextInputAddr()) + } + + if !reflect.DeepEqual(s1.NextChangeAddr(), s2.NextChangeAddr()) { + t.Fatalf("Wrong NextChangeAddr; got %v, want %v", s1.NextChangeAddr(), s2.NextChangeAddr()) + } + + if !reflect.DeepEqual(s1.Outputs(), s2.Outputs()) { + t.Fatalf("Wrong WithdrawalOutputs; got %v, want %v", s1.Outputs(), s2.Outputs()) + } + + if !reflect.DeepEqual(s1.transactions, s2.transactions) { + t.Fatalf("Wrong transactions; got %v, want %v", s1.transactions, s2.transactions) + } + + // The above checks could be replaced by this one, but when they fail the + // failure msg wouldn't give us much clue as to what is not equal, so we do + // the individual checks above and use this one as a catch-all check in case + // we forget to check any of the individual fields. + if !reflect.DeepEqual(s1, s2) { + t.Fatalf("Wrong WithdrawalStatus; got %v, want %v", s1, s2) + } +} + // replaceCalculateTxFee replaces the calculateTxFee func with the given one // and returns a function that restores it to the original one. func replaceCalculateTxFee(f func(*withdrawalTx) btcutil.Amount) func() { diff --git a/votingpool/db.go b/votingpool/db.go index 60ee14e..416a142 100644 --- a/votingpool/db.go +++ b/votingpool/db.go @@ -19,8 +19,12 @@ package votingpool import ( "bytes" "encoding/binary" + "encoding/gob" "fmt" + "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" "github.com/btcsuite/btcwallet/snacl" "github.com/btcsuite/btcwallet/walletdb" ) @@ -43,8 +47,9 @@ const ( ) var ( - usedAddrsBucketName = []byte("usedaddrs") - seriesBucketName = []byte("series") + usedAddrsBucketName = []byte("usedaddrs") + seriesBucketName = []byte("series") + withdrawalsBucketName = []byte("withdrawals") // string representing a non-existent private key seriesNullPrivKey = [seriesKeyLength]byte{} ) @@ -57,6 +62,61 @@ type dbSeriesRow struct { privKeysEncrypted [][]byte } +type dbWithdrawalRow struct { + Requests []dbOutputRequest + StartAddress dbWithdrawalAddress + ChangeStart dbChangeAddress + LastSeriesID uint32 + DustThreshold btcutil.Amount + Status dbWithdrawalStatus +} + +type dbWithdrawalAddress struct { + SeriesID uint32 + Branch Branch + Index Index +} + +type dbChangeAddress struct { + SeriesID uint32 + Index Index +} + +type dbOutputRequest struct { + Addr string + Amount btcutil.Amount + Server string + Transaction uint32 +} + +type dbWithdrawalOutput struct { + // We store the OutBailmentID here as we need a way to look up the + // corresponding dbOutputRequest in dbWithdrawalRow when deserializing. + OutBailmentID OutBailmentID + Status outputStatus + Outpoints []dbOutBailmentOutpoint +} + +type dbOutBailmentOutpoint struct { + Ntxid Ntxid + Index uint32 + Amount btcutil.Amount +} + +type dbChangeAwareTx struct { + SerializedMsgTx []byte + ChangeIdx int32 +} + +type dbWithdrawalStatus struct { + NextInputAddr dbWithdrawalAddress + NextChangeAddr dbChangeAddress + Fees btcutil.Amount + Outputs map[OutBailmentID]dbWithdrawalOutput + Sigs map[Ntxid]TxSigs + Transactions map[Ntxid]dbChangeAwareTx +} + // getUsedAddrBucketID returns the used addresses bucket ID for the given series // and branch. It has the form seriesID:branch. func getUsedAddrBucketID(seriesID uint32, branch Branch) []byte { @@ -140,6 +200,11 @@ func putPool(tx walletdb.Tx, poolID []byte) error { return newError(ErrDatabase, fmt.Sprintf("cannot create used addrs bucket for pool %v", poolID), err) } + _, err = poolBucket.CreateBucket(withdrawalsBucketName) + if err != nil { + return newError( + ErrDatabase, fmt.Sprintf("cannot create withdrawals bucket for pool %v", poolID), err) + } return nil } @@ -339,6 +404,188 @@ func serializeSeriesRow(row *dbSeriesRow) ([]byte, error) { return serialized, nil } +// serializeWithdrawal constructs a dbWithdrawalRow and serializes it (using +// encoding/gob) so that it can be stored in the DB. +func serializeWithdrawal(requests []OutputRequest, startAddress WithdrawalAddress, + lastSeriesID uint32, changeStart ChangeAddress, dustThreshold btcutil.Amount, + status WithdrawalStatus) ([]byte, error) { + + dbStartAddr := dbWithdrawalAddress{ + SeriesID: startAddress.SeriesID(), + Branch: startAddress.Branch(), + Index: startAddress.Index(), + } + dbChangeStart := dbChangeAddress{ + SeriesID: startAddress.SeriesID(), + Index: startAddress.Index(), + } + dbRequests := make([]dbOutputRequest, len(requests)) + for i, request := range requests { + dbRequests[i] = dbOutputRequest{ + Addr: request.Address.EncodeAddress(), + Amount: request.Amount, + Server: request.Server, + Transaction: request.Transaction, + } + } + dbOutputs := make(map[OutBailmentID]dbWithdrawalOutput, len(status.outputs)) + for oid, output := range status.outputs { + dbOutpoints := make([]dbOutBailmentOutpoint, len(output.outpoints)) + for i, outpoint := range output.outpoints { + dbOutpoints[i] = dbOutBailmentOutpoint{ + Ntxid: outpoint.ntxid, + Index: outpoint.index, + Amount: outpoint.amount, + } + } + dbOutputs[oid] = dbWithdrawalOutput{ + OutBailmentID: output.request.outBailmentID(), + Status: output.status, + Outpoints: dbOutpoints, + } + } + dbTransactions := make(map[Ntxid]dbChangeAwareTx, len(status.transactions)) + for ntxid, tx := range status.transactions { + var buf bytes.Buffer + buf.Grow(tx.SerializeSize()) + if err := tx.Serialize(&buf); err != nil { + return nil, err + } + dbTransactions[ntxid] = dbChangeAwareTx{ + SerializedMsgTx: buf.Bytes(), + ChangeIdx: tx.changeIdx, + } + } + nextChange := status.nextChangeAddr + dbStatus := dbWithdrawalStatus{ + NextChangeAddr: dbChangeAddress{ + SeriesID: nextChange.seriesID, + Index: nextChange.index, + }, + Fees: status.fees, + Outputs: dbOutputs, + Sigs: status.sigs, + Transactions: dbTransactions, + } + row := dbWithdrawalRow{ + Requests: dbRequests, + StartAddress: dbStartAddr, + LastSeriesID: lastSeriesID, + ChangeStart: dbChangeStart, + DustThreshold: dustThreshold, + Status: dbStatus, + } + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(row); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// deserializeWithdrawal deserializes the given byte slice into a dbWithdrawalRow, +// converts it into an withdrawalInfo and returns it. This function must run +// with the address manager unlocked. +func deserializeWithdrawal(p *Pool, serialized []byte) (*withdrawalInfo, error) { + var row dbWithdrawalRow + if err := gob.NewDecoder(bytes.NewReader(serialized)).Decode(&row); err != nil { + return nil, newError(ErrWithdrawalStorage, "cannot deserialize withdrawal information", + err) + } + wInfo := &withdrawalInfo{ + lastSeriesID: row.LastSeriesID, + dustThreshold: row.DustThreshold, + } + chainParams := p.Manager().ChainParams() + wInfo.requests = make([]OutputRequest, len(row.Requests)) + // A map of requests indexed by OutBailmentID; needed to populate + // WithdrawalStatus.Outputs later on. + requestsByOID := make(map[OutBailmentID]OutputRequest) + for i, req := range row.Requests { + addr, err := btcutil.DecodeAddress(req.Addr, chainParams) + if err != nil { + return nil, newError(ErrWithdrawalStorage, + "cannot deserialize addr for requested output", err) + } + pkScript, err := txscript.PayToAddrScript(addr) + if err != nil { + return nil, newError(ErrWithdrawalStorage, "invalid addr for requested output", err) + } + request := OutputRequest{ + Address: addr, + Amount: req.Amount, + PkScript: pkScript, + Server: req.Server, + Transaction: req.Transaction, + } + wInfo.requests[i] = request + requestsByOID[request.outBailmentID()] = request + } + startAddr := row.StartAddress + wAddr, err := p.WithdrawalAddress(startAddr.SeriesID, startAddr.Branch, startAddr.Index) + if err != nil { + return nil, newError(ErrWithdrawalStorage, "cannot deserialize startAddress", err) + } + wInfo.startAddress = *wAddr + + cAddr, err := p.ChangeAddress(row.ChangeStart.SeriesID, row.ChangeStart.Index) + if err != nil { + return nil, newError(ErrWithdrawalStorage, "cannot deserialize changeStart", err) + } + wInfo.changeStart = *cAddr + + // TODO: Copy over row.Status.nextInputAddr. Not done because StartWithdrawal + // does not update that yet. + nextChangeAddr := row.Status.NextChangeAddr + cAddr, err = p.ChangeAddress(nextChangeAddr.SeriesID, nextChangeAddr.Index) + if err != nil { + return nil, newError(ErrWithdrawalStorage, + "cannot deserialize nextChangeAddress for withdrawal", err) + } + wInfo.status = WithdrawalStatus{ + nextChangeAddr: *cAddr, + fees: row.Status.Fees, + outputs: make(map[OutBailmentID]*WithdrawalOutput, len(row.Status.Outputs)), + sigs: row.Status.Sigs, + transactions: make(map[Ntxid]changeAwareTx, len(row.Status.Transactions)), + } + for oid, output := range row.Status.Outputs { + outpoints := make([]OutBailmentOutpoint, len(output.Outpoints)) + for i, outpoint := range output.Outpoints { + outpoints[i] = OutBailmentOutpoint{ + ntxid: outpoint.Ntxid, + index: outpoint.Index, + amount: outpoint.Amount, + } + } + wInfo.status.outputs[oid] = &WithdrawalOutput{ + request: requestsByOID[output.OutBailmentID], + status: output.Status, + outpoints: outpoints, + } + } + for ntxid, tx := range row.Status.Transactions { + msgtx := wire.NewMsgTx() + if err := msgtx.Deserialize(bytes.NewBuffer(tx.SerializedMsgTx)); err != nil { + return nil, newError(ErrWithdrawalStorage, "cannot deserialize transaction", err) + } + wInfo.status.transactions[ntxid] = changeAwareTx{ + MsgTx: msgtx, + changeIdx: tx.ChangeIdx, + } + } + return wInfo, nil +} + +func putWithdrawal(tx walletdb.Tx, poolID []byte, roundID uint32, serialized []byte) error { + bucket := tx.RootBucket().Bucket(poolID) + return bucket.Put(uint32ToBytes(roundID), serialized) +} + +func getWithdrawal(tx walletdb.Tx, poolID []byte, roundID uint32) []byte { + bucket := tx.RootBucket().Bucket(poolID) + return bucket.Get(uint32ToBytes(roundID)) +} + // uint32ToBytes converts a 32 bit unsigned integer into a 4-byte slice in // little-endian order: 1 -> [1 0 0 0]. func uint32ToBytes(number uint32) []byte { diff --git a/votingpool/db_wb_test.go b/votingpool/db_wb_test.go index 961c17c..443c3ed 100644 --- a/votingpool/db_wb_test.go +++ b/votingpool/db_wb_test.go @@ -18,6 +18,7 @@ package votingpool import ( "bytes" + "reflect" "testing" "github.com/btcsuite/btcwallet/walletdb" @@ -80,3 +81,77 @@ func TestGetMaxUsedIdx(t *testing.T) { t.Fatalf("Wrong max idx; got %d, want %d", maxIdx, Index(3001)) } } + +func TestWithdrawalSerialization(t *testing.T) { + tearDown, _, pool := TstCreatePool(t) + defer tearDown() + + roundID := uint32(0) + wi := createAndFulfillWithdrawalRequests(t, pool, roundID) + + serialized, err := serializeWithdrawal(wi.requests, wi.startAddress, wi.lastSeriesID, + wi.changeStart, wi.dustThreshold, wi.status) + if err != nil { + t.Fatal(err) + } + + var wInfo *withdrawalInfo + TstRunWithManagerUnlocked(t, pool.Manager(), func() { + wInfo, err = deserializeWithdrawal(pool, serialized) + if err != nil { + t.Fatal(err) + } + }) + + if !reflect.DeepEqual(wInfo.startAddress, wi.startAddress) { + t.Fatalf("Wrong startAddr; got %v, want %v", wInfo.startAddress, wi.startAddress) + } + + if !reflect.DeepEqual(wInfo.changeStart, wi.changeStart) { + t.Fatalf("Wrong changeStart; got %v, want %v", wInfo.changeStart, wi.changeStart) + } + + if wInfo.lastSeriesID != wi.lastSeriesID { + t.Fatalf("Wrong LastSeriesID; got %d, want %d", wInfo.lastSeriesID, wi.lastSeriesID) + } + + if wInfo.dustThreshold != wi.dustThreshold { + t.Fatalf("Wrong DustThreshold; got %d, want %d", wInfo.dustThreshold, wi.dustThreshold) + } + + if !reflect.DeepEqual(wInfo.requests, wi.requests) { + t.Fatalf("Wrong output requests; got %v, want %v", wInfo.requests, wi.requests) + } + + TstCheckWithdrawalStatusMatches(t, wInfo.status, wi.status) +} + +func TestPutAndGetWithdrawal(t *testing.T) { + tearDown, _, pool := TstCreatePool(t) + defer tearDown() + + serialized := bytes.Repeat([]byte{1}, 10) + poolID := []byte{0x00} + roundID := uint32(0) + err := pool.namespace.Update( + func(tx walletdb.Tx) error { + return putWithdrawal(tx, poolID, roundID, serialized) + }) + if err != nil { + t.Fatal(err) + } + + var retrieved []byte + err = pool.namespace.View( + func(tx walletdb.Tx) error { + retrieved = getWithdrawal(tx, poolID, roundID) + return nil + }) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(retrieved, serialized) { + t.Fatalf("Wrong value retrieved from DB; got %x, want %x", retrieved, serialized) + } +} diff --git a/votingpool/error.go b/votingpool/error.go index c2e9e1a..bcd6db8 100644 --- a/votingpool/error.go +++ b/votingpool/error.go @@ -151,6 +151,10 @@ const ( // transactions. ErrWithdrawalTxStorage + // ErrWithdrawalStorage indicates an error occurred when serializing or + // deserializing withdrawal information. + ErrWithdrawalStorage + // lastErr is used for testing, making it possible to iterate over // the error codes in order to check that they all have proper // translations in errorCodeStrings. @@ -192,6 +196,7 @@ var errorCodeStrings = map[ErrorCode]string{ ErrInvalidScriptHash: "ErrInvalidScriptHash", ErrWithdrawFromUnusedAddr: "ErrWithdrawFromUnusedAddr", ErrWithdrawalTxStorage: "ErrWithdrawalTxStorage", + ErrWithdrawalStorage: "ErrWithdrawalStorage", } // String returns the ErrorCode as a human-readable name. diff --git a/votingpool/error_test.go b/votingpool/error_test.go index 0c5e68e..3842031 100644 --- a/votingpool/error_test.go +++ b/votingpool/error_test.go @@ -65,6 +65,7 @@ func TestErrorCodeStringer(t *testing.T) { {vp.ErrInvalidScriptHash, "ErrInvalidScriptHash"}, {vp.ErrWithdrawFromUnusedAddr, "ErrWithdrawFromUnusedAddr"}, {vp.ErrWithdrawalTxStorage, "ErrWithdrawalTxStorage"}, + {vp.ErrWithdrawalStorage, "ErrWithdrawalStorage"}, {0xffff, "Unknown ErrorCode (65535)"}, } diff --git a/votingpool/factory_test.go b/votingpool/factory_test.go index 9f3e16e..b3e5d4b 100644 --- a/votingpool/factory_test.go +++ b/votingpool/factory_test.go @@ -421,3 +421,29 @@ func TstNewChangeAddress(t *testing.T, p *Pool, seriesID uint32, idx Index) (add func TstConstantFee(fee btcutil.Amount) func(tx *withdrawalTx) btcutil.Amount { return func(tx *withdrawalTx) btcutil.Amount { return fee } } + +func createAndFulfillWithdrawalRequests(t *testing.T, pool *Pool, roundID uint32) withdrawalInfo { + + params := pool.Manager().ChainParams() + seriesID, eligible := TstCreateCreditsOnNewSeries(t, pool, []int64{2e6, 4e6}) + requests := []OutputRequest{ + TstNewOutputRequest(t, 1, "34eVkREKgvvGASZW7hkgE2uNc1yycntMK6", 3e6, params), + TstNewOutputRequest(t, 2, "3PbExiaztsSYgh6zeMswC49hLUwhTQ86XG", 2e6, params), + } + changeStart := TstNewChangeAddress(t, pool, seriesID, 0) + dustThreshold := btcutil.Amount(1e4) + startAddr := TstNewWithdrawalAddress(t, pool, seriesID, 1, 0) + lastSeriesID := seriesID + w := newWithdrawal(roundID, requests, eligible, *changeStart) + if err := w.fulfillRequests(); err != nil { + t.Fatal(err) + } + return withdrawalInfo{ + requests: requests, + startAddress: *startAddr, + changeStart: *changeStart, + lastSeriesID: lastSeriesID, + dustThreshold: dustThreshold, + status: *w.status, + } +} diff --git a/votingpool/internal_test.go b/votingpool/internal_test.go index 6bfecd4..7ecf73c 100644 --- a/votingpool/internal_test.go +++ b/votingpool/internal_test.go @@ -96,7 +96,8 @@ func (vp *Pool) TstDecryptExtendedKey(keyType waddrmgr.CryptoKeyType, encrypted return vp.decryptExtendedKey(keyType, encrypted) } -// TstGetMsgTx returns the withdrawal transaction with the given ntxid. +// TstGetMsgTx returns a copy of the withdrawal transaction with the given +// ntxid. func (s *WithdrawalStatus) TstGetMsgTx(ntxid Ntxid) *wire.MsgTx { - return s.transactions[ntxid].MsgTx + return s.transactions[ntxid].MsgTx.Copy() } diff --git a/votingpool/withdrawal.go b/votingpool/withdrawal.go index ac426e0..87002ad 100644 --- a/votingpool/withdrawal.go +++ b/votingpool/withdrawal.go @@ -20,6 +20,7 @@ import ( "bytes" "fmt" "math" + "reflect" "sort" "strconv" "time" @@ -28,6 +29,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/btcsuite/btcwallet/waddrmgr" + "github.com/btcsuite/btcwallet/walletdb" "github.com/btcsuite/btcwallet/wtxmgr" "github.com/btcsuite/fastsha256" ) @@ -84,7 +86,7 @@ type WithdrawalOutput struct { outpoints []OutBailmentOutpoint } -// OutBailmentOutpoint represents one of the outpoints created to fulfil an OutputRequest. +// OutBailmentOutpoint represents one of the outpoints created to fulfill an OutputRequest. type OutBailmentOutpoint struct { ntxid Ntxid index uint32 @@ -110,6 +112,18 @@ type WithdrawalStatus struct { transactions map[Ntxid]changeAwareTx } +// withdrawalInfo contains all the details of an existing withdrawal, including +// the original request parameters and the WithdrawalStatus returned by +// StartWithdrawal. +type withdrawalInfo struct { + requests []OutputRequest + startAddress WithdrawalAddress + changeStart ChangeAddress + lastSeriesID uint32 + dustThreshold btcutil.Amount + status WithdrawalStatus +} + // TxSigs is list of raw signatures (one for every pubkey in the multi-sig // script) for a given transaction input. They should match the order of pubkeys // in the script and an empty RawSig should be used when the private key for a @@ -443,11 +457,21 @@ func newWithdrawal(roundID uint32, requests []OutputRequest, inputs []credit, // signature lists (one for every private key available to this wallet) for each // of those transaction's inputs. More details about the actual algorithm can be // found at http://opentransactions.org/wiki/index.php/Startwithdrawal +// This method must be called with the address manager unlocked. func (p *Pool) StartWithdrawal(roundID uint32, requests []OutputRequest, startAddress WithdrawalAddress, lastSeriesID uint32, changeStart ChangeAddress, txStore *wtxmgr.Store, chainHeight int32, dustThreshold btcutil.Amount) ( *WithdrawalStatus, error) { + status, err := getWithdrawalStatus(p, roundID, requests, startAddress, lastSeriesID, + changeStart, dustThreshold) + if err != nil { + return nil, err + } + if status != nil { + return status, nil + } + eligible, err := p.getEligibleInputs(txStore, startAddress, lastSeriesID, dustThreshold, chainHeight, eligibleInputMinConfirmations) if err != nil { @@ -463,6 +487,19 @@ func (p *Pool) StartWithdrawal(roundID uint32, requests []OutputRequest, return nil, err } + serialized, err := serializeWithdrawal(requests, startAddress, lastSeriesID, changeStart, + dustThreshold, *w.status) + if err != nil { + return nil, err + } + err = p.namespace.Update( + func(tx walletdb.Tx) error { + return putWithdrawal(tx, p.ID, roundID, serialized) + }) + if err != nil { + return nil, err + } + return w.status, nil } @@ -720,6 +757,74 @@ func (s *WithdrawalStatus) updateStatusFor(tx *withdrawalTx) { } } +// match returns true if the given arguments match the fields in this +// withdrawalInfo. For the requests slice, the order of the items does not +// matter. +func (wi *withdrawalInfo) match(requests []OutputRequest, startAddress WithdrawalAddress, + lastSeriesID uint32, changeStart ChangeAddress, dustThreshold btcutil.Amount) bool { + // Use reflect.DeepEqual to compare changeStart and startAddress as they're + // structs that contain pointers and we want to compare their content and + // not their address. + if !reflect.DeepEqual(changeStart, wi.changeStart) { + log.Debugf("withdrawal changeStart does not match: %v != %v", changeStart, wi.changeStart) + return false + } + if !reflect.DeepEqual(startAddress, wi.startAddress) { + log.Debugf("withdrawal startAddr does not match: %v != %v", startAddress, wi.startAddress) + return false + } + if lastSeriesID != wi.lastSeriesID { + log.Debugf("withdrawal lastSeriesID does not match: %v != %v", lastSeriesID, + wi.lastSeriesID) + return false + } + if dustThreshold != wi.dustThreshold { + log.Debugf("withdrawal dustThreshold does not match: %v != %v", dustThreshold, + wi.dustThreshold) + return false + } + r1 := make([]OutputRequest, len(requests)) + copy(r1, requests) + r2 := make([]OutputRequest, len(wi.requests)) + copy(r2, wi.requests) + sort.Sort(byOutBailmentID(r1)) + sort.Sort(byOutBailmentID(r2)) + if !reflect.DeepEqual(r1, r2) { + log.Debugf("withdrawal requests does not match: %v != %v", requests, wi.requests) + return false + } + return true +} + +// getWithdrawalStatus returns the existing WithdrawalStatus for the given +// withdrawal parameters, if one exists. This function must be called with the +// address manager unlocked. +func getWithdrawalStatus(p *Pool, roundID uint32, requests []OutputRequest, + startAddress WithdrawalAddress, lastSeriesID uint32, changeStart ChangeAddress, + dustThreshold btcutil.Amount) (*WithdrawalStatus, error) { + + var serialized []byte + err := p.namespace.View( + func(tx walletdb.Tx) error { + serialized = getWithdrawal(tx, p.ID, roundID) + return nil + }) + if err != nil { + return nil, err + } + if bytes.Equal(serialized, []byte{}) { + return nil, nil + } + wInfo, err := deserializeWithdrawal(p, serialized) + if err != nil { + return nil, err + } + if wInfo.match(requests, startAddress, lastSeriesID, changeStart, dustThreshold) { + return &wInfo.status, nil + } + return nil, nil +} + // getRawSigs iterates over the inputs of each transaction given, constructing the // raw signatures for them using the private keys available to us. // It returns a map of ntxids to signature lists. diff --git a/votingpool/withdrawal_test.go b/votingpool/withdrawal_test.go index 51bb654..91ecad2 100644 --- a/votingpool/withdrawal_test.go +++ b/votingpool/withdrawal_test.go @@ -96,6 +96,18 @@ func TestStartWithdrawal(t *testing.T) { t.Fatal(err) } }) + + // Any subsequent StartWithdrawal() calls with the same parameters will + // return the previously stored WithdrawalStatus. + var status2 *vp.WithdrawalStatus + vp.TstRunWithManagerUnlocked(t, mgr, func() { + status2, err = pool.StartWithdrawal(0, requests, *startAddr, lastSeriesID, *changeStart, + store, currentBlock, dustThreshold) + }) + if err != nil { + t.Fatal(err) + } + vp.TstCheckWithdrawalStatusMatches(t, *status, *status2) } func checkWithdrawalOutputs( diff --git a/votingpool/withdrawal_wb_test.go b/votingpool/withdrawal_wb_test.go index 894f0f5..73bc9f7 100644 --- a/votingpool/withdrawal_wb_test.go +++ b/votingpool/withdrawal_wb_test.go @@ -28,6 +28,7 @@ import ( "github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil/hdkeychain" "github.com/btcsuite/btcwallet/waddrmgr" + "github.com/btcsuite/btcwallet/walletdb" "github.com/btcsuite/btcwallet/wtxmgr" ) @@ -660,6 +661,122 @@ func TestWithdrawalTxOutputTotal(t *testing.T) { } } +func TestWithdrawalInfoMatch(t *testing.T) { + tearDown, _, pool := TstCreatePool(t) + defer tearDown() + + roundID := uint32(0) + wi := createAndFulfillWithdrawalRequests(t, pool, roundID) + + // Use freshly created values for requests, startAddress and changeStart + // to simulate what would happen if we had recreated them from the + // serialized data in the DB. + requestsCopy := make([]OutputRequest, len(wi.requests)) + copy(requestsCopy, wi.requests) + startAddr := TstNewWithdrawalAddress(t, pool, wi.startAddress.seriesID, wi.startAddress.branch, + wi.startAddress.index) + changeStart := TstNewChangeAddress(t, pool, wi.changeStart.seriesID, wi.changeStart.index) + + // First check that it matches when all fields are identical. + matches := wi.match(requestsCopy, *startAddr, wi.lastSeriesID, *changeStart, wi.dustThreshold) + if !matches { + t.Fatal("Should match when everything is identical.") + } + + // It also matches if the OutputRequests are not in the same order. + diffOrderRequests := make([]OutputRequest, len(requestsCopy)) + copy(diffOrderRequests, requestsCopy) + diffOrderRequests[0], diffOrderRequests[1] = requestsCopy[1], requestsCopy[0] + matches = wi.match(diffOrderRequests, *startAddr, wi.lastSeriesID, *changeStart, + wi.dustThreshold) + if !matches { + t.Fatal("Should match when requests are in different order.") + } + + // It should not match when the OutputRequests are not the same. + diffRequests := diffOrderRequests + diffRequests[0] = OutputRequest{} + matches = wi.match(diffRequests, *startAddr, wi.lastSeriesID, *changeStart, wi.dustThreshold) + if matches { + t.Fatal("Should not match as requests is not equal.") + } + + // It should not match when lastSeriesID is not equal. + matches = wi.match(requestsCopy, *startAddr, wi.lastSeriesID+1, *changeStart, wi.dustThreshold) + if matches { + t.Fatal("Should not match as lastSeriesID is not equal.") + } + + // It should not match when dustThreshold is not equal. + matches = wi.match(requestsCopy, *startAddr, wi.lastSeriesID, *changeStart, wi.dustThreshold+1) + if matches { + t.Fatal("Should not match as dustThreshold is not equal.") + } + + // It should not match when startAddress is not equal. + diffStartAddr := TstNewWithdrawalAddress(t, pool, startAddr.seriesID, startAddr.branch+1, + startAddr.index) + matches = wi.match(requestsCopy, *diffStartAddr, wi.lastSeriesID, *changeStart, + wi.dustThreshold) + if matches { + t.Fatal("Should not match as startAddress is not equal.") + } + + // It should not match when changeStart is not equal. + diffChangeStart := TstNewChangeAddress(t, pool, changeStart.seriesID, changeStart.index+1) + matches = wi.match(requestsCopy, *startAddr, wi.lastSeriesID, *diffChangeStart, + wi.dustThreshold) + if matches { + t.Fatal("Should not match as changeStart is not equal.") + } +} + +func TestGetWithdrawalStatus(t *testing.T) { + tearDown, _, pool := TstCreatePool(t) + defer tearDown() + + roundID := uint32(0) + wi := createAndFulfillWithdrawalRequests(t, pool, roundID) + + serialized, err := serializeWithdrawal(wi.requests, wi.startAddress, wi.lastSeriesID, + wi.changeStart, wi.dustThreshold, wi.status) + if err != nil { + t.Fatal(err) + } + err = pool.namespace.Update( + func(tx walletdb.Tx) error { + return putWithdrawal(tx, pool.ID, roundID, serialized) + }) + if err != nil { + t.Fatal(err) + } + + // Here we should get a WithdrawalStatus that matches wi.status. + var status *WithdrawalStatus + TstRunWithManagerUnlocked(t, pool.Manager(), func() { + status, err = getWithdrawalStatus(pool, roundID, wi.requests, wi.startAddress, + wi.lastSeriesID, wi.changeStart, wi.dustThreshold) + }) + if err != nil { + t.Fatal(err) + } + TstCheckWithdrawalStatusMatches(t, wi.status, *status) + + // Here we should get a nil WithdrawalStatus because the parameters are not + // identical to those of the stored WithdrawalStatus with this roundID. + dustThreshold := wi.dustThreshold + 1 + TstRunWithManagerUnlocked(t, pool.Manager(), func() { + status, err = getWithdrawalStatus(pool, roundID, wi.requests, wi.startAddress, + wi.lastSeriesID, wi.changeStart, dustThreshold) + }) + if err != nil { + t.Fatal(err) + } + if status != nil { + t.Fatalf("Expected a nil status, got %v", status) + } +} + func TestSignMultiSigUTXO(t *testing.T) { tearDown, pool, _ := TstCreatePoolAndTxStore(t) defer tearDown()