StartWithdrawal returns a previously saved WithdrawalStatus if parameters match

StartWithdrawal now persists the WithdrawalStatus before returning, and also
returns a previously saved one in subsequent calls with the same parameters.
This commit is contained in:
Guilherme Salgado 2015-02-23 16:07:12 +00:00 committed by Dave Collins
parent 472d6b0c1e
commit 47ca1ca6e5
10 changed files with 631 additions and 5 deletions

View file

@ -19,6 +19,7 @@ package votingpool
import ( import (
"fmt" "fmt"
"os" "os"
"reflect"
"runtime" "runtime"
"testing" "testing"
@ -63,6 +64,42 @@ func TstRunWithManagerUnlocked(t *testing.T, mgr *waddrmgr.Manager, callback fun
callback() callback()
} }
// TstCheckWithdrawalStatusMatches compares s1 and s2 using reflect.DeepEqual
// and calls t.Fatal() if they're not identical.
func TstCheckWithdrawalStatusMatches(t *testing.T, s1, s2 WithdrawalStatus) {
if s1.Fees() != s2.Fees() {
t.Fatalf("Wrong amount of network fees; want %d, got %d", s1.Fees(), s2.Fees())
}
if !reflect.DeepEqual(s1.Sigs(), s2.Sigs()) {
t.Fatalf("Wrong tx signatures; got %x, want %x", s1.Sigs(), s2.Sigs())
}
if !reflect.DeepEqual(s1.NextInputAddr(), s2.NextInputAddr()) {
t.Fatalf("Wrong NextInputAddr; got %v, want %v", s1.NextInputAddr(), s2.NextInputAddr())
}
if !reflect.DeepEqual(s1.NextChangeAddr(), s2.NextChangeAddr()) {
t.Fatalf("Wrong NextChangeAddr; got %v, want %v", s1.NextChangeAddr(), s2.NextChangeAddr())
}
if !reflect.DeepEqual(s1.Outputs(), s2.Outputs()) {
t.Fatalf("Wrong WithdrawalOutputs; got %v, want %v", s1.Outputs(), s2.Outputs())
}
if !reflect.DeepEqual(s1.transactions, s2.transactions) {
t.Fatalf("Wrong transactions; got %v, want %v", s1.transactions, s2.transactions)
}
// The above checks could be replaced by this one, but when they fail the
// failure msg wouldn't give us much clue as to what is not equal, so we do
// the individual checks above and use this one as a catch-all check in case
// we forget to check any of the individual fields.
if !reflect.DeepEqual(s1, s2) {
t.Fatalf("Wrong WithdrawalStatus; got %v, want %v", s1, s2)
}
}
// replaceCalculateTxFee replaces the calculateTxFee func with the given one // replaceCalculateTxFee replaces the calculateTxFee func with the given one
// and returns a function that restores it to the original one. // and returns a function that restores it to the original one.
func replaceCalculateTxFee(f func(*withdrawalTx) btcutil.Amount) func() { func replaceCalculateTxFee(f func(*withdrawalTx) btcutil.Amount) func() {

View file

@ -19,8 +19,12 @@ package votingpool
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"encoding/gob"
"fmt" "fmt"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/btcsuite/btcwallet/snacl" "github.com/btcsuite/btcwallet/snacl"
"github.com/btcsuite/btcwallet/walletdb" "github.com/btcsuite/btcwallet/walletdb"
) )
@ -45,6 +49,7 @@ const (
var ( var (
usedAddrsBucketName = []byte("usedaddrs") usedAddrsBucketName = []byte("usedaddrs")
seriesBucketName = []byte("series") seriesBucketName = []byte("series")
withdrawalsBucketName = []byte("withdrawals")
// string representing a non-existent private key // string representing a non-existent private key
seriesNullPrivKey = [seriesKeyLength]byte{} seriesNullPrivKey = [seriesKeyLength]byte{}
) )
@ -57,6 +62,61 @@ type dbSeriesRow struct {
privKeysEncrypted [][]byte privKeysEncrypted [][]byte
} }
type dbWithdrawalRow struct {
Requests []dbOutputRequest
StartAddress dbWithdrawalAddress
ChangeStart dbChangeAddress
LastSeriesID uint32
DustThreshold btcutil.Amount
Status dbWithdrawalStatus
}
type dbWithdrawalAddress struct {
SeriesID uint32
Branch Branch
Index Index
}
type dbChangeAddress struct {
SeriesID uint32
Index Index
}
type dbOutputRequest struct {
Addr string
Amount btcutil.Amount
Server string
Transaction uint32
}
type dbWithdrawalOutput struct {
// We store the OutBailmentID here as we need a way to look up the
// corresponding dbOutputRequest in dbWithdrawalRow when deserializing.
OutBailmentID OutBailmentID
Status outputStatus
Outpoints []dbOutBailmentOutpoint
}
type dbOutBailmentOutpoint struct {
Ntxid Ntxid
Index uint32
Amount btcutil.Amount
}
type dbChangeAwareTx struct {
SerializedMsgTx []byte
ChangeIdx int32
}
type dbWithdrawalStatus struct {
NextInputAddr dbWithdrawalAddress
NextChangeAddr dbChangeAddress
Fees btcutil.Amount
Outputs map[OutBailmentID]dbWithdrawalOutput
Sigs map[Ntxid]TxSigs
Transactions map[Ntxid]dbChangeAwareTx
}
// getUsedAddrBucketID returns the used addresses bucket ID for the given series // getUsedAddrBucketID returns the used addresses bucket ID for the given series
// and branch. It has the form seriesID:branch. // and branch. It has the form seriesID:branch.
func getUsedAddrBucketID(seriesID uint32, branch Branch) []byte { func getUsedAddrBucketID(seriesID uint32, branch Branch) []byte {
@ -140,6 +200,11 @@ func putPool(tx walletdb.Tx, poolID []byte) error {
return newError(ErrDatabase, fmt.Sprintf("cannot create used addrs bucket for pool %v", return newError(ErrDatabase, fmt.Sprintf("cannot create used addrs bucket for pool %v",
poolID), err) poolID), err)
} }
_, err = poolBucket.CreateBucket(withdrawalsBucketName)
if err != nil {
return newError(
ErrDatabase, fmt.Sprintf("cannot create withdrawals bucket for pool %v", poolID), err)
}
return nil return nil
} }
@ -339,6 +404,188 @@ func serializeSeriesRow(row *dbSeriesRow) ([]byte, error) {
return serialized, nil return serialized, nil
} }
// serializeWithdrawal constructs a dbWithdrawalRow and serializes it (using
// encoding/gob) so that it can be stored in the DB.
func serializeWithdrawal(requests []OutputRequest, startAddress WithdrawalAddress,
lastSeriesID uint32, changeStart ChangeAddress, dustThreshold btcutil.Amount,
status WithdrawalStatus) ([]byte, error) {
dbStartAddr := dbWithdrawalAddress{
SeriesID: startAddress.SeriesID(),
Branch: startAddress.Branch(),
Index: startAddress.Index(),
}
dbChangeStart := dbChangeAddress{
SeriesID: startAddress.SeriesID(),
Index: startAddress.Index(),
}
dbRequests := make([]dbOutputRequest, len(requests))
for i, request := range requests {
dbRequests[i] = dbOutputRequest{
Addr: request.Address.EncodeAddress(),
Amount: request.Amount,
Server: request.Server,
Transaction: request.Transaction,
}
}
dbOutputs := make(map[OutBailmentID]dbWithdrawalOutput, len(status.outputs))
for oid, output := range status.outputs {
dbOutpoints := make([]dbOutBailmentOutpoint, len(output.outpoints))
for i, outpoint := range output.outpoints {
dbOutpoints[i] = dbOutBailmentOutpoint{
Ntxid: outpoint.ntxid,
Index: outpoint.index,
Amount: outpoint.amount,
}
}
dbOutputs[oid] = dbWithdrawalOutput{
OutBailmentID: output.request.outBailmentID(),
Status: output.status,
Outpoints: dbOutpoints,
}
}
dbTransactions := make(map[Ntxid]dbChangeAwareTx, len(status.transactions))
for ntxid, tx := range status.transactions {
var buf bytes.Buffer
buf.Grow(tx.SerializeSize())
if err := tx.Serialize(&buf); err != nil {
return nil, err
}
dbTransactions[ntxid] = dbChangeAwareTx{
SerializedMsgTx: buf.Bytes(),
ChangeIdx: tx.changeIdx,
}
}
nextChange := status.nextChangeAddr
dbStatus := dbWithdrawalStatus{
NextChangeAddr: dbChangeAddress{
SeriesID: nextChange.seriesID,
Index: nextChange.index,
},
Fees: status.fees,
Outputs: dbOutputs,
Sigs: status.sigs,
Transactions: dbTransactions,
}
row := dbWithdrawalRow{
Requests: dbRequests,
StartAddress: dbStartAddr,
LastSeriesID: lastSeriesID,
ChangeStart: dbChangeStart,
DustThreshold: dustThreshold,
Status: dbStatus,
}
var buf bytes.Buffer
if err := gob.NewEncoder(&buf).Encode(row); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// deserializeWithdrawal deserializes the given byte slice into a dbWithdrawalRow,
// converts it into an withdrawalInfo and returns it. This function must run
// with the address manager unlocked.
func deserializeWithdrawal(p *Pool, serialized []byte) (*withdrawalInfo, error) {
var row dbWithdrawalRow
if err := gob.NewDecoder(bytes.NewReader(serialized)).Decode(&row); err != nil {
return nil, newError(ErrWithdrawalStorage, "cannot deserialize withdrawal information",
err)
}
wInfo := &withdrawalInfo{
lastSeriesID: row.LastSeriesID,
dustThreshold: row.DustThreshold,
}
chainParams := p.Manager().ChainParams()
wInfo.requests = make([]OutputRequest, len(row.Requests))
// A map of requests indexed by OutBailmentID; needed to populate
// WithdrawalStatus.Outputs later on.
requestsByOID := make(map[OutBailmentID]OutputRequest)
for i, req := range row.Requests {
addr, err := btcutil.DecodeAddress(req.Addr, chainParams)
if err != nil {
return nil, newError(ErrWithdrawalStorage,
"cannot deserialize addr for requested output", err)
}
pkScript, err := txscript.PayToAddrScript(addr)
if err != nil {
return nil, newError(ErrWithdrawalStorage, "invalid addr for requested output", err)
}
request := OutputRequest{
Address: addr,
Amount: req.Amount,
PkScript: pkScript,
Server: req.Server,
Transaction: req.Transaction,
}
wInfo.requests[i] = request
requestsByOID[request.outBailmentID()] = request
}
startAddr := row.StartAddress
wAddr, err := p.WithdrawalAddress(startAddr.SeriesID, startAddr.Branch, startAddr.Index)
if err != nil {
return nil, newError(ErrWithdrawalStorage, "cannot deserialize startAddress", err)
}
wInfo.startAddress = *wAddr
cAddr, err := p.ChangeAddress(row.ChangeStart.SeriesID, row.ChangeStart.Index)
if err != nil {
return nil, newError(ErrWithdrawalStorage, "cannot deserialize changeStart", err)
}
wInfo.changeStart = *cAddr
// TODO: Copy over row.Status.nextInputAddr. Not done because StartWithdrawal
// does not update that yet.
nextChangeAddr := row.Status.NextChangeAddr
cAddr, err = p.ChangeAddress(nextChangeAddr.SeriesID, nextChangeAddr.Index)
if err != nil {
return nil, newError(ErrWithdrawalStorage,
"cannot deserialize nextChangeAddress for withdrawal", err)
}
wInfo.status = WithdrawalStatus{
nextChangeAddr: *cAddr,
fees: row.Status.Fees,
outputs: make(map[OutBailmentID]*WithdrawalOutput, len(row.Status.Outputs)),
sigs: row.Status.Sigs,
transactions: make(map[Ntxid]changeAwareTx, len(row.Status.Transactions)),
}
for oid, output := range row.Status.Outputs {
outpoints := make([]OutBailmentOutpoint, len(output.Outpoints))
for i, outpoint := range output.Outpoints {
outpoints[i] = OutBailmentOutpoint{
ntxid: outpoint.Ntxid,
index: outpoint.Index,
amount: outpoint.Amount,
}
}
wInfo.status.outputs[oid] = &WithdrawalOutput{
request: requestsByOID[output.OutBailmentID],
status: output.Status,
outpoints: outpoints,
}
}
for ntxid, tx := range row.Status.Transactions {
msgtx := wire.NewMsgTx()
if err := msgtx.Deserialize(bytes.NewBuffer(tx.SerializedMsgTx)); err != nil {
return nil, newError(ErrWithdrawalStorage, "cannot deserialize transaction", err)
}
wInfo.status.transactions[ntxid] = changeAwareTx{
MsgTx: msgtx,
changeIdx: tx.ChangeIdx,
}
}
return wInfo, nil
}
func putWithdrawal(tx walletdb.Tx, poolID []byte, roundID uint32, serialized []byte) error {
bucket := tx.RootBucket().Bucket(poolID)
return bucket.Put(uint32ToBytes(roundID), serialized)
}
func getWithdrawal(tx walletdb.Tx, poolID []byte, roundID uint32) []byte {
bucket := tx.RootBucket().Bucket(poolID)
return bucket.Get(uint32ToBytes(roundID))
}
// uint32ToBytes converts a 32 bit unsigned integer into a 4-byte slice in // uint32ToBytes converts a 32 bit unsigned integer into a 4-byte slice in
// little-endian order: 1 -> [1 0 0 0]. // little-endian order: 1 -> [1 0 0 0].
func uint32ToBytes(number uint32) []byte { func uint32ToBytes(number uint32) []byte {

View file

@ -18,6 +18,7 @@ package votingpool
import ( import (
"bytes" "bytes"
"reflect"
"testing" "testing"
"github.com/btcsuite/btcwallet/walletdb" "github.com/btcsuite/btcwallet/walletdb"
@ -80,3 +81,77 @@ func TestGetMaxUsedIdx(t *testing.T) {
t.Fatalf("Wrong max idx; got %d, want %d", maxIdx, Index(3001)) t.Fatalf("Wrong max idx; got %d, want %d", maxIdx, Index(3001))
} }
} }
func TestWithdrawalSerialization(t *testing.T) {
tearDown, _, pool := TstCreatePool(t)
defer tearDown()
roundID := uint32(0)
wi := createAndFulfillWithdrawalRequests(t, pool, roundID)
serialized, err := serializeWithdrawal(wi.requests, wi.startAddress, wi.lastSeriesID,
wi.changeStart, wi.dustThreshold, wi.status)
if err != nil {
t.Fatal(err)
}
var wInfo *withdrawalInfo
TstRunWithManagerUnlocked(t, pool.Manager(), func() {
wInfo, err = deserializeWithdrawal(pool, serialized)
if err != nil {
t.Fatal(err)
}
})
if !reflect.DeepEqual(wInfo.startAddress, wi.startAddress) {
t.Fatalf("Wrong startAddr; got %v, want %v", wInfo.startAddress, wi.startAddress)
}
if !reflect.DeepEqual(wInfo.changeStart, wi.changeStart) {
t.Fatalf("Wrong changeStart; got %v, want %v", wInfo.changeStart, wi.changeStart)
}
if wInfo.lastSeriesID != wi.lastSeriesID {
t.Fatalf("Wrong LastSeriesID; got %d, want %d", wInfo.lastSeriesID, wi.lastSeriesID)
}
if wInfo.dustThreshold != wi.dustThreshold {
t.Fatalf("Wrong DustThreshold; got %d, want %d", wInfo.dustThreshold, wi.dustThreshold)
}
if !reflect.DeepEqual(wInfo.requests, wi.requests) {
t.Fatalf("Wrong output requests; got %v, want %v", wInfo.requests, wi.requests)
}
TstCheckWithdrawalStatusMatches(t, wInfo.status, wi.status)
}
func TestPutAndGetWithdrawal(t *testing.T) {
tearDown, _, pool := TstCreatePool(t)
defer tearDown()
serialized := bytes.Repeat([]byte{1}, 10)
poolID := []byte{0x00}
roundID := uint32(0)
err := pool.namespace.Update(
func(tx walletdb.Tx) error {
return putWithdrawal(tx, poolID, roundID, serialized)
})
if err != nil {
t.Fatal(err)
}
var retrieved []byte
err = pool.namespace.View(
func(tx walletdb.Tx) error {
retrieved = getWithdrawal(tx, poolID, roundID)
return nil
})
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(retrieved, serialized) {
t.Fatalf("Wrong value retrieved from DB; got %x, want %x", retrieved, serialized)
}
}

View file

@ -151,6 +151,10 @@ const (
// transactions. // transactions.
ErrWithdrawalTxStorage ErrWithdrawalTxStorage
// ErrWithdrawalStorage indicates an error occurred when serializing or
// deserializing withdrawal information.
ErrWithdrawalStorage
// lastErr is used for testing, making it possible to iterate over // lastErr is used for testing, making it possible to iterate over
// the error codes in order to check that they all have proper // the error codes in order to check that they all have proper
// translations in errorCodeStrings. // translations in errorCodeStrings.
@ -192,6 +196,7 @@ var errorCodeStrings = map[ErrorCode]string{
ErrInvalidScriptHash: "ErrInvalidScriptHash", ErrInvalidScriptHash: "ErrInvalidScriptHash",
ErrWithdrawFromUnusedAddr: "ErrWithdrawFromUnusedAddr", ErrWithdrawFromUnusedAddr: "ErrWithdrawFromUnusedAddr",
ErrWithdrawalTxStorage: "ErrWithdrawalTxStorage", ErrWithdrawalTxStorage: "ErrWithdrawalTxStorage",
ErrWithdrawalStorage: "ErrWithdrawalStorage",
} }
// String returns the ErrorCode as a human-readable name. // String returns the ErrorCode as a human-readable name.

View file

@ -65,6 +65,7 @@ func TestErrorCodeStringer(t *testing.T) {
{vp.ErrInvalidScriptHash, "ErrInvalidScriptHash"}, {vp.ErrInvalidScriptHash, "ErrInvalidScriptHash"},
{vp.ErrWithdrawFromUnusedAddr, "ErrWithdrawFromUnusedAddr"}, {vp.ErrWithdrawFromUnusedAddr, "ErrWithdrawFromUnusedAddr"},
{vp.ErrWithdrawalTxStorage, "ErrWithdrawalTxStorage"}, {vp.ErrWithdrawalTxStorage, "ErrWithdrawalTxStorage"},
{vp.ErrWithdrawalStorage, "ErrWithdrawalStorage"},
{0xffff, "Unknown ErrorCode (65535)"}, {0xffff, "Unknown ErrorCode (65535)"},
} }

View file

@ -421,3 +421,29 @@ func TstNewChangeAddress(t *testing.T, p *Pool, seriesID uint32, idx Index) (add
func TstConstantFee(fee btcutil.Amount) func(tx *withdrawalTx) btcutil.Amount { func TstConstantFee(fee btcutil.Amount) func(tx *withdrawalTx) btcutil.Amount {
return func(tx *withdrawalTx) btcutil.Amount { return fee } return func(tx *withdrawalTx) btcutil.Amount { return fee }
} }
func createAndFulfillWithdrawalRequests(t *testing.T, pool *Pool, roundID uint32) withdrawalInfo {
params := pool.Manager().ChainParams()
seriesID, eligible := TstCreateCreditsOnNewSeries(t, pool, []int64{2e6, 4e6})
requests := []OutputRequest{
TstNewOutputRequest(t, 1, "34eVkREKgvvGASZW7hkgE2uNc1yycntMK6", 3e6, params),
TstNewOutputRequest(t, 2, "3PbExiaztsSYgh6zeMswC49hLUwhTQ86XG", 2e6, params),
}
changeStart := TstNewChangeAddress(t, pool, seriesID, 0)
dustThreshold := btcutil.Amount(1e4)
startAddr := TstNewWithdrawalAddress(t, pool, seriesID, 1, 0)
lastSeriesID := seriesID
w := newWithdrawal(roundID, requests, eligible, *changeStart)
if err := w.fulfillRequests(); err != nil {
t.Fatal(err)
}
return withdrawalInfo{
requests: requests,
startAddress: *startAddr,
changeStart: *changeStart,
lastSeriesID: lastSeriesID,
dustThreshold: dustThreshold,
status: *w.status,
}
}

View file

@ -96,7 +96,8 @@ func (vp *Pool) TstDecryptExtendedKey(keyType waddrmgr.CryptoKeyType, encrypted
return vp.decryptExtendedKey(keyType, encrypted) return vp.decryptExtendedKey(keyType, encrypted)
} }
// TstGetMsgTx returns the withdrawal transaction with the given ntxid. // TstGetMsgTx returns a copy of the withdrawal transaction with the given
// ntxid.
func (s *WithdrawalStatus) TstGetMsgTx(ntxid Ntxid) *wire.MsgTx { func (s *WithdrawalStatus) TstGetMsgTx(ntxid Ntxid) *wire.MsgTx {
return s.transactions[ntxid].MsgTx return s.transactions[ntxid].MsgTx.Copy()
} }

View file

@ -20,6 +20,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"math" "math"
"reflect"
"sort" "sort"
"strconv" "strconv"
"time" "time"
@ -28,6 +29,7 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/waddrmgr"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/btcsuite/btcwallet/wtxmgr" "github.com/btcsuite/btcwallet/wtxmgr"
"github.com/btcsuite/fastsha256" "github.com/btcsuite/fastsha256"
) )
@ -84,7 +86,7 @@ type WithdrawalOutput struct {
outpoints []OutBailmentOutpoint outpoints []OutBailmentOutpoint
} }
// OutBailmentOutpoint represents one of the outpoints created to fulfil an OutputRequest. // OutBailmentOutpoint represents one of the outpoints created to fulfill an OutputRequest.
type OutBailmentOutpoint struct { type OutBailmentOutpoint struct {
ntxid Ntxid ntxid Ntxid
index uint32 index uint32
@ -110,6 +112,18 @@ type WithdrawalStatus struct {
transactions map[Ntxid]changeAwareTx transactions map[Ntxid]changeAwareTx
} }
// withdrawalInfo contains all the details of an existing withdrawal, including
// the original request parameters and the WithdrawalStatus returned by
// StartWithdrawal.
type withdrawalInfo struct {
requests []OutputRequest
startAddress WithdrawalAddress
changeStart ChangeAddress
lastSeriesID uint32
dustThreshold btcutil.Amount
status WithdrawalStatus
}
// TxSigs is list of raw signatures (one for every pubkey in the multi-sig // TxSigs is list of raw signatures (one for every pubkey in the multi-sig
// script) for a given transaction input. They should match the order of pubkeys // script) for a given transaction input. They should match the order of pubkeys
// in the script and an empty RawSig should be used when the private key for a // in the script and an empty RawSig should be used when the private key for a
@ -443,11 +457,21 @@ func newWithdrawal(roundID uint32, requests []OutputRequest, inputs []credit,
// signature lists (one for every private key available to this wallet) for each // signature lists (one for every private key available to this wallet) for each
// of those transaction's inputs. More details about the actual algorithm can be // of those transaction's inputs. More details about the actual algorithm can be
// found at http://opentransactions.org/wiki/index.php/Startwithdrawal // found at http://opentransactions.org/wiki/index.php/Startwithdrawal
// This method must be called with the address manager unlocked.
func (p *Pool) StartWithdrawal(roundID uint32, requests []OutputRequest, func (p *Pool) StartWithdrawal(roundID uint32, requests []OutputRequest,
startAddress WithdrawalAddress, lastSeriesID uint32, changeStart ChangeAddress, startAddress WithdrawalAddress, lastSeriesID uint32, changeStart ChangeAddress,
txStore *wtxmgr.Store, chainHeight int32, dustThreshold btcutil.Amount) ( txStore *wtxmgr.Store, chainHeight int32, dustThreshold btcutil.Amount) (
*WithdrawalStatus, error) { *WithdrawalStatus, error) {
status, err := getWithdrawalStatus(p, roundID, requests, startAddress, lastSeriesID,
changeStart, dustThreshold)
if err != nil {
return nil, err
}
if status != nil {
return status, nil
}
eligible, err := p.getEligibleInputs(txStore, startAddress, lastSeriesID, dustThreshold, eligible, err := p.getEligibleInputs(txStore, startAddress, lastSeriesID, dustThreshold,
chainHeight, eligibleInputMinConfirmations) chainHeight, eligibleInputMinConfirmations)
if err != nil { if err != nil {
@ -463,6 +487,19 @@ func (p *Pool) StartWithdrawal(roundID uint32, requests []OutputRequest,
return nil, err return nil, err
} }
serialized, err := serializeWithdrawal(requests, startAddress, lastSeriesID, changeStart,
dustThreshold, *w.status)
if err != nil {
return nil, err
}
err = p.namespace.Update(
func(tx walletdb.Tx) error {
return putWithdrawal(tx, p.ID, roundID, serialized)
})
if err != nil {
return nil, err
}
return w.status, nil return w.status, nil
} }
@ -720,6 +757,74 @@ func (s *WithdrawalStatus) updateStatusFor(tx *withdrawalTx) {
} }
} }
// match returns true if the given arguments match the fields in this
// withdrawalInfo. For the requests slice, the order of the items does not
// matter.
func (wi *withdrawalInfo) match(requests []OutputRequest, startAddress WithdrawalAddress,
lastSeriesID uint32, changeStart ChangeAddress, dustThreshold btcutil.Amount) bool {
// Use reflect.DeepEqual to compare changeStart and startAddress as they're
// structs that contain pointers and we want to compare their content and
// not their address.
if !reflect.DeepEqual(changeStart, wi.changeStart) {
log.Debugf("withdrawal changeStart does not match: %v != %v", changeStart, wi.changeStart)
return false
}
if !reflect.DeepEqual(startAddress, wi.startAddress) {
log.Debugf("withdrawal startAddr does not match: %v != %v", startAddress, wi.startAddress)
return false
}
if lastSeriesID != wi.lastSeriesID {
log.Debugf("withdrawal lastSeriesID does not match: %v != %v", lastSeriesID,
wi.lastSeriesID)
return false
}
if dustThreshold != wi.dustThreshold {
log.Debugf("withdrawal dustThreshold does not match: %v != %v", dustThreshold,
wi.dustThreshold)
return false
}
r1 := make([]OutputRequest, len(requests))
copy(r1, requests)
r2 := make([]OutputRequest, len(wi.requests))
copy(r2, wi.requests)
sort.Sort(byOutBailmentID(r1))
sort.Sort(byOutBailmentID(r2))
if !reflect.DeepEqual(r1, r2) {
log.Debugf("withdrawal requests does not match: %v != %v", requests, wi.requests)
return false
}
return true
}
// getWithdrawalStatus returns the existing WithdrawalStatus for the given
// withdrawal parameters, if one exists. This function must be called with the
// address manager unlocked.
func getWithdrawalStatus(p *Pool, roundID uint32, requests []OutputRequest,
startAddress WithdrawalAddress, lastSeriesID uint32, changeStart ChangeAddress,
dustThreshold btcutil.Amount) (*WithdrawalStatus, error) {
var serialized []byte
err := p.namespace.View(
func(tx walletdb.Tx) error {
serialized = getWithdrawal(tx, p.ID, roundID)
return nil
})
if err != nil {
return nil, err
}
if bytes.Equal(serialized, []byte{}) {
return nil, nil
}
wInfo, err := deserializeWithdrawal(p, serialized)
if err != nil {
return nil, err
}
if wInfo.match(requests, startAddress, lastSeriesID, changeStart, dustThreshold) {
return &wInfo.status, nil
}
return nil, nil
}
// getRawSigs iterates over the inputs of each transaction given, constructing the // getRawSigs iterates over the inputs of each transaction given, constructing the
// raw signatures for them using the private keys available to us. // raw signatures for them using the private keys available to us.
// It returns a map of ntxids to signature lists. // It returns a map of ntxids to signature lists.

View file

@ -96,6 +96,18 @@ func TestStartWithdrawal(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
}) })
// Any subsequent StartWithdrawal() calls with the same parameters will
// return the previously stored WithdrawalStatus.
var status2 *vp.WithdrawalStatus
vp.TstRunWithManagerUnlocked(t, mgr, func() {
status2, err = pool.StartWithdrawal(0, requests, *startAddr, lastSeriesID, *changeStart,
store, currentBlock, dustThreshold)
})
if err != nil {
t.Fatal(err)
}
vp.TstCheckWithdrawalStatusMatches(t, *status, *status2)
} }
func checkWithdrawalOutputs( func checkWithdrawalOutputs(

View file

@ -28,6 +28,7 @@ import (
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/btcsuite/btcutil/hdkeychain" "github.com/btcsuite/btcutil/hdkeychain"
"github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/waddrmgr"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/btcsuite/btcwallet/wtxmgr" "github.com/btcsuite/btcwallet/wtxmgr"
) )
@ -660,6 +661,122 @@ func TestWithdrawalTxOutputTotal(t *testing.T) {
} }
} }
func TestWithdrawalInfoMatch(t *testing.T) {
tearDown, _, pool := TstCreatePool(t)
defer tearDown()
roundID := uint32(0)
wi := createAndFulfillWithdrawalRequests(t, pool, roundID)
// Use freshly created values for requests, startAddress and changeStart
// to simulate what would happen if we had recreated them from the
// serialized data in the DB.
requestsCopy := make([]OutputRequest, len(wi.requests))
copy(requestsCopy, wi.requests)
startAddr := TstNewWithdrawalAddress(t, pool, wi.startAddress.seriesID, wi.startAddress.branch,
wi.startAddress.index)
changeStart := TstNewChangeAddress(t, pool, wi.changeStart.seriesID, wi.changeStart.index)
// First check that it matches when all fields are identical.
matches := wi.match(requestsCopy, *startAddr, wi.lastSeriesID, *changeStart, wi.dustThreshold)
if !matches {
t.Fatal("Should match when everything is identical.")
}
// It also matches if the OutputRequests are not in the same order.
diffOrderRequests := make([]OutputRequest, len(requestsCopy))
copy(diffOrderRequests, requestsCopy)
diffOrderRequests[0], diffOrderRequests[1] = requestsCopy[1], requestsCopy[0]
matches = wi.match(diffOrderRequests, *startAddr, wi.lastSeriesID, *changeStart,
wi.dustThreshold)
if !matches {
t.Fatal("Should match when requests are in different order.")
}
// It should not match when the OutputRequests are not the same.
diffRequests := diffOrderRequests
diffRequests[0] = OutputRequest{}
matches = wi.match(diffRequests, *startAddr, wi.lastSeriesID, *changeStart, wi.dustThreshold)
if matches {
t.Fatal("Should not match as requests is not equal.")
}
// It should not match when lastSeriesID is not equal.
matches = wi.match(requestsCopy, *startAddr, wi.lastSeriesID+1, *changeStart, wi.dustThreshold)
if matches {
t.Fatal("Should not match as lastSeriesID is not equal.")
}
// It should not match when dustThreshold is not equal.
matches = wi.match(requestsCopy, *startAddr, wi.lastSeriesID, *changeStart, wi.dustThreshold+1)
if matches {
t.Fatal("Should not match as dustThreshold is not equal.")
}
// It should not match when startAddress is not equal.
diffStartAddr := TstNewWithdrawalAddress(t, pool, startAddr.seriesID, startAddr.branch+1,
startAddr.index)
matches = wi.match(requestsCopy, *diffStartAddr, wi.lastSeriesID, *changeStart,
wi.dustThreshold)
if matches {
t.Fatal("Should not match as startAddress is not equal.")
}
// It should not match when changeStart is not equal.
diffChangeStart := TstNewChangeAddress(t, pool, changeStart.seriesID, changeStart.index+1)
matches = wi.match(requestsCopy, *startAddr, wi.lastSeriesID, *diffChangeStart,
wi.dustThreshold)
if matches {
t.Fatal("Should not match as changeStart is not equal.")
}
}
func TestGetWithdrawalStatus(t *testing.T) {
tearDown, _, pool := TstCreatePool(t)
defer tearDown()
roundID := uint32(0)
wi := createAndFulfillWithdrawalRequests(t, pool, roundID)
serialized, err := serializeWithdrawal(wi.requests, wi.startAddress, wi.lastSeriesID,
wi.changeStart, wi.dustThreshold, wi.status)
if err != nil {
t.Fatal(err)
}
err = pool.namespace.Update(
func(tx walletdb.Tx) error {
return putWithdrawal(tx, pool.ID, roundID, serialized)
})
if err != nil {
t.Fatal(err)
}
// Here we should get a WithdrawalStatus that matches wi.status.
var status *WithdrawalStatus
TstRunWithManagerUnlocked(t, pool.Manager(), func() {
status, err = getWithdrawalStatus(pool, roundID, wi.requests, wi.startAddress,
wi.lastSeriesID, wi.changeStart, wi.dustThreshold)
})
if err != nil {
t.Fatal(err)
}
TstCheckWithdrawalStatusMatches(t, wi.status, *status)
// Here we should get a nil WithdrawalStatus because the parameters are not
// identical to those of the stored WithdrawalStatus with this roundID.
dustThreshold := wi.dustThreshold + 1
TstRunWithManagerUnlocked(t, pool.Manager(), func() {
status, err = getWithdrawalStatus(pool, roundID, wi.requests, wi.startAddress,
wi.lastSeriesID, wi.changeStart, dustThreshold)
})
if err != nil {
t.Fatal(err)
}
if status != nil {
t.Fatalf("Expected a nil status, got %v", status)
}
}
func TestSignMultiSigUTXO(t *testing.T) { func TestSignMultiSigUTXO(t *testing.T) {
tearDown, pool, _ := TstCreatePoolAndTxStore(t) tearDown, pool, _ := TstCreatePoolAndTxStore(t)
defer tearDown() defer tearDown()