StartWithdrawal returns a previously saved WithdrawalStatus if parameters match
StartWithdrawal now persists the WithdrawalStatus before returning, and also returns a previously saved one in subsequent calls with the same parameters.
This commit is contained in:
parent
472d6b0c1e
commit
47ca1ca6e5
10 changed files with 631 additions and 5 deletions
|
@ -19,6 +19,7 @@ package votingpool
|
|||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
|
@ -63,6 +64,42 @@ func TstRunWithManagerUnlocked(t *testing.T, mgr *waddrmgr.Manager, callback fun
|
|||
callback()
|
||||
}
|
||||
|
||||
// TstCheckWithdrawalStatusMatches compares s1 and s2 using reflect.DeepEqual
|
||||
// and calls t.Fatal() if they're not identical.
|
||||
func TstCheckWithdrawalStatusMatches(t *testing.T, s1, s2 WithdrawalStatus) {
|
||||
if s1.Fees() != s2.Fees() {
|
||||
t.Fatalf("Wrong amount of network fees; want %d, got %d", s1.Fees(), s2.Fees())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(s1.Sigs(), s2.Sigs()) {
|
||||
t.Fatalf("Wrong tx signatures; got %x, want %x", s1.Sigs(), s2.Sigs())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(s1.NextInputAddr(), s2.NextInputAddr()) {
|
||||
t.Fatalf("Wrong NextInputAddr; got %v, want %v", s1.NextInputAddr(), s2.NextInputAddr())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(s1.NextChangeAddr(), s2.NextChangeAddr()) {
|
||||
t.Fatalf("Wrong NextChangeAddr; got %v, want %v", s1.NextChangeAddr(), s2.NextChangeAddr())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(s1.Outputs(), s2.Outputs()) {
|
||||
t.Fatalf("Wrong WithdrawalOutputs; got %v, want %v", s1.Outputs(), s2.Outputs())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(s1.transactions, s2.transactions) {
|
||||
t.Fatalf("Wrong transactions; got %v, want %v", s1.transactions, s2.transactions)
|
||||
}
|
||||
|
||||
// The above checks could be replaced by this one, but when they fail the
|
||||
// failure msg wouldn't give us much clue as to what is not equal, so we do
|
||||
// the individual checks above and use this one as a catch-all check in case
|
||||
// we forget to check any of the individual fields.
|
||||
if !reflect.DeepEqual(s1, s2) {
|
||||
t.Fatalf("Wrong WithdrawalStatus; got %v, want %v", s1, s2)
|
||||
}
|
||||
}
|
||||
|
||||
// replaceCalculateTxFee replaces the calculateTxFee func with the given one
|
||||
// and returns a function that restores it to the original one.
|
||||
func replaceCalculateTxFee(f func(*withdrawalTx) btcutil.Amount) func() {
|
||||
|
|
251
votingpool/db.go
251
votingpool/db.go
|
@ -19,8 +19,12 @@ package votingpool
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
|
||||
"github.com/btcsuite/btcd/txscript"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/btcsuite/btcwallet/snacl"
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
)
|
||||
|
@ -43,8 +47,9 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
usedAddrsBucketName = []byte("usedaddrs")
|
||||
seriesBucketName = []byte("series")
|
||||
usedAddrsBucketName = []byte("usedaddrs")
|
||||
seriesBucketName = []byte("series")
|
||||
withdrawalsBucketName = []byte("withdrawals")
|
||||
// string representing a non-existent private key
|
||||
seriesNullPrivKey = [seriesKeyLength]byte{}
|
||||
)
|
||||
|
@ -57,6 +62,61 @@ type dbSeriesRow struct {
|
|||
privKeysEncrypted [][]byte
|
||||
}
|
||||
|
||||
type dbWithdrawalRow struct {
|
||||
Requests []dbOutputRequest
|
||||
StartAddress dbWithdrawalAddress
|
||||
ChangeStart dbChangeAddress
|
||||
LastSeriesID uint32
|
||||
DustThreshold btcutil.Amount
|
||||
Status dbWithdrawalStatus
|
||||
}
|
||||
|
||||
type dbWithdrawalAddress struct {
|
||||
SeriesID uint32
|
||||
Branch Branch
|
||||
Index Index
|
||||
}
|
||||
|
||||
type dbChangeAddress struct {
|
||||
SeriesID uint32
|
||||
Index Index
|
||||
}
|
||||
|
||||
type dbOutputRequest struct {
|
||||
Addr string
|
||||
Amount btcutil.Amount
|
||||
Server string
|
||||
Transaction uint32
|
||||
}
|
||||
|
||||
type dbWithdrawalOutput struct {
|
||||
// We store the OutBailmentID here as we need a way to look up the
|
||||
// corresponding dbOutputRequest in dbWithdrawalRow when deserializing.
|
||||
OutBailmentID OutBailmentID
|
||||
Status outputStatus
|
||||
Outpoints []dbOutBailmentOutpoint
|
||||
}
|
||||
|
||||
type dbOutBailmentOutpoint struct {
|
||||
Ntxid Ntxid
|
||||
Index uint32
|
||||
Amount btcutil.Amount
|
||||
}
|
||||
|
||||
type dbChangeAwareTx struct {
|
||||
SerializedMsgTx []byte
|
||||
ChangeIdx int32
|
||||
}
|
||||
|
||||
type dbWithdrawalStatus struct {
|
||||
NextInputAddr dbWithdrawalAddress
|
||||
NextChangeAddr dbChangeAddress
|
||||
Fees btcutil.Amount
|
||||
Outputs map[OutBailmentID]dbWithdrawalOutput
|
||||
Sigs map[Ntxid]TxSigs
|
||||
Transactions map[Ntxid]dbChangeAwareTx
|
||||
}
|
||||
|
||||
// getUsedAddrBucketID returns the used addresses bucket ID for the given series
|
||||
// and branch. It has the form seriesID:branch.
|
||||
func getUsedAddrBucketID(seriesID uint32, branch Branch) []byte {
|
||||
|
@ -140,6 +200,11 @@ func putPool(tx walletdb.Tx, poolID []byte) error {
|
|||
return newError(ErrDatabase, fmt.Sprintf("cannot create used addrs bucket for pool %v",
|
||||
poolID), err)
|
||||
}
|
||||
_, err = poolBucket.CreateBucket(withdrawalsBucketName)
|
||||
if err != nil {
|
||||
return newError(
|
||||
ErrDatabase, fmt.Sprintf("cannot create withdrawals bucket for pool %v", poolID), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -339,6 +404,188 @@ func serializeSeriesRow(row *dbSeriesRow) ([]byte, error) {
|
|||
return serialized, nil
|
||||
}
|
||||
|
||||
// serializeWithdrawal constructs a dbWithdrawalRow and serializes it (using
|
||||
// encoding/gob) so that it can be stored in the DB.
|
||||
func serializeWithdrawal(requests []OutputRequest, startAddress WithdrawalAddress,
|
||||
lastSeriesID uint32, changeStart ChangeAddress, dustThreshold btcutil.Amount,
|
||||
status WithdrawalStatus) ([]byte, error) {
|
||||
|
||||
dbStartAddr := dbWithdrawalAddress{
|
||||
SeriesID: startAddress.SeriesID(),
|
||||
Branch: startAddress.Branch(),
|
||||
Index: startAddress.Index(),
|
||||
}
|
||||
dbChangeStart := dbChangeAddress{
|
||||
SeriesID: startAddress.SeriesID(),
|
||||
Index: startAddress.Index(),
|
||||
}
|
||||
dbRequests := make([]dbOutputRequest, len(requests))
|
||||
for i, request := range requests {
|
||||
dbRequests[i] = dbOutputRequest{
|
||||
Addr: request.Address.EncodeAddress(),
|
||||
Amount: request.Amount,
|
||||
Server: request.Server,
|
||||
Transaction: request.Transaction,
|
||||
}
|
||||
}
|
||||
dbOutputs := make(map[OutBailmentID]dbWithdrawalOutput, len(status.outputs))
|
||||
for oid, output := range status.outputs {
|
||||
dbOutpoints := make([]dbOutBailmentOutpoint, len(output.outpoints))
|
||||
for i, outpoint := range output.outpoints {
|
||||
dbOutpoints[i] = dbOutBailmentOutpoint{
|
||||
Ntxid: outpoint.ntxid,
|
||||
Index: outpoint.index,
|
||||
Amount: outpoint.amount,
|
||||
}
|
||||
}
|
||||
dbOutputs[oid] = dbWithdrawalOutput{
|
||||
OutBailmentID: output.request.outBailmentID(),
|
||||
Status: output.status,
|
||||
Outpoints: dbOutpoints,
|
||||
}
|
||||
}
|
||||
dbTransactions := make(map[Ntxid]dbChangeAwareTx, len(status.transactions))
|
||||
for ntxid, tx := range status.transactions {
|
||||
var buf bytes.Buffer
|
||||
buf.Grow(tx.SerializeSize())
|
||||
if err := tx.Serialize(&buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dbTransactions[ntxid] = dbChangeAwareTx{
|
||||
SerializedMsgTx: buf.Bytes(),
|
||||
ChangeIdx: tx.changeIdx,
|
||||
}
|
||||
}
|
||||
nextChange := status.nextChangeAddr
|
||||
dbStatus := dbWithdrawalStatus{
|
||||
NextChangeAddr: dbChangeAddress{
|
||||
SeriesID: nextChange.seriesID,
|
||||
Index: nextChange.index,
|
||||
},
|
||||
Fees: status.fees,
|
||||
Outputs: dbOutputs,
|
||||
Sigs: status.sigs,
|
||||
Transactions: dbTransactions,
|
||||
}
|
||||
row := dbWithdrawalRow{
|
||||
Requests: dbRequests,
|
||||
StartAddress: dbStartAddr,
|
||||
LastSeriesID: lastSeriesID,
|
||||
ChangeStart: dbChangeStart,
|
||||
DustThreshold: dustThreshold,
|
||||
Status: dbStatus,
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if err := gob.NewEncoder(&buf).Encode(row); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// deserializeWithdrawal deserializes the given byte slice into a dbWithdrawalRow,
|
||||
// converts it into an withdrawalInfo and returns it. This function must run
|
||||
// with the address manager unlocked.
|
||||
func deserializeWithdrawal(p *Pool, serialized []byte) (*withdrawalInfo, error) {
|
||||
var row dbWithdrawalRow
|
||||
if err := gob.NewDecoder(bytes.NewReader(serialized)).Decode(&row); err != nil {
|
||||
return nil, newError(ErrWithdrawalStorage, "cannot deserialize withdrawal information",
|
||||
err)
|
||||
}
|
||||
wInfo := &withdrawalInfo{
|
||||
lastSeriesID: row.LastSeriesID,
|
||||
dustThreshold: row.DustThreshold,
|
||||
}
|
||||
chainParams := p.Manager().ChainParams()
|
||||
wInfo.requests = make([]OutputRequest, len(row.Requests))
|
||||
// A map of requests indexed by OutBailmentID; needed to populate
|
||||
// WithdrawalStatus.Outputs later on.
|
||||
requestsByOID := make(map[OutBailmentID]OutputRequest)
|
||||
for i, req := range row.Requests {
|
||||
addr, err := btcutil.DecodeAddress(req.Addr, chainParams)
|
||||
if err != nil {
|
||||
return nil, newError(ErrWithdrawalStorage,
|
||||
"cannot deserialize addr for requested output", err)
|
||||
}
|
||||
pkScript, err := txscript.PayToAddrScript(addr)
|
||||
if err != nil {
|
||||
return nil, newError(ErrWithdrawalStorage, "invalid addr for requested output", err)
|
||||
}
|
||||
request := OutputRequest{
|
||||
Address: addr,
|
||||
Amount: req.Amount,
|
||||
PkScript: pkScript,
|
||||
Server: req.Server,
|
||||
Transaction: req.Transaction,
|
||||
}
|
||||
wInfo.requests[i] = request
|
||||
requestsByOID[request.outBailmentID()] = request
|
||||
}
|
||||
startAddr := row.StartAddress
|
||||
wAddr, err := p.WithdrawalAddress(startAddr.SeriesID, startAddr.Branch, startAddr.Index)
|
||||
if err != nil {
|
||||
return nil, newError(ErrWithdrawalStorage, "cannot deserialize startAddress", err)
|
||||
}
|
||||
wInfo.startAddress = *wAddr
|
||||
|
||||
cAddr, err := p.ChangeAddress(row.ChangeStart.SeriesID, row.ChangeStart.Index)
|
||||
if err != nil {
|
||||
return nil, newError(ErrWithdrawalStorage, "cannot deserialize changeStart", err)
|
||||
}
|
||||
wInfo.changeStart = *cAddr
|
||||
|
||||
// TODO: Copy over row.Status.nextInputAddr. Not done because StartWithdrawal
|
||||
// does not update that yet.
|
||||
nextChangeAddr := row.Status.NextChangeAddr
|
||||
cAddr, err = p.ChangeAddress(nextChangeAddr.SeriesID, nextChangeAddr.Index)
|
||||
if err != nil {
|
||||
return nil, newError(ErrWithdrawalStorage,
|
||||
"cannot deserialize nextChangeAddress for withdrawal", err)
|
||||
}
|
||||
wInfo.status = WithdrawalStatus{
|
||||
nextChangeAddr: *cAddr,
|
||||
fees: row.Status.Fees,
|
||||
outputs: make(map[OutBailmentID]*WithdrawalOutput, len(row.Status.Outputs)),
|
||||
sigs: row.Status.Sigs,
|
||||
transactions: make(map[Ntxid]changeAwareTx, len(row.Status.Transactions)),
|
||||
}
|
||||
for oid, output := range row.Status.Outputs {
|
||||
outpoints := make([]OutBailmentOutpoint, len(output.Outpoints))
|
||||
for i, outpoint := range output.Outpoints {
|
||||
outpoints[i] = OutBailmentOutpoint{
|
||||
ntxid: outpoint.Ntxid,
|
||||
index: outpoint.Index,
|
||||
amount: outpoint.Amount,
|
||||
}
|
||||
}
|
||||
wInfo.status.outputs[oid] = &WithdrawalOutput{
|
||||
request: requestsByOID[output.OutBailmentID],
|
||||
status: output.Status,
|
||||
outpoints: outpoints,
|
||||
}
|
||||
}
|
||||
for ntxid, tx := range row.Status.Transactions {
|
||||
msgtx := wire.NewMsgTx()
|
||||
if err := msgtx.Deserialize(bytes.NewBuffer(tx.SerializedMsgTx)); err != nil {
|
||||
return nil, newError(ErrWithdrawalStorage, "cannot deserialize transaction", err)
|
||||
}
|
||||
wInfo.status.transactions[ntxid] = changeAwareTx{
|
||||
MsgTx: msgtx,
|
||||
changeIdx: tx.ChangeIdx,
|
||||
}
|
||||
}
|
||||
return wInfo, nil
|
||||
}
|
||||
|
||||
func putWithdrawal(tx walletdb.Tx, poolID []byte, roundID uint32, serialized []byte) error {
|
||||
bucket := tx.RootBucket().Bucket(poolID)
|
||||
return bucket.Put(uint32ToBytes(roundID), serialized)
|
||||
}
|
||||
|
||||
func getWithdrawal(tx walletdb.Tx, poolID []byte, roundID uint32) []byte {
|
||||
bucket := tx.RootBucket().Bucket(poolID)
|
||||
return bucket.Get(uint32ToBytes(roundID))
|
||||
}
|
||||
|
||||
// uint32ToBytes converts a 32 bit unsigned integer into a 4-byte slice in
|
||||
// little-endian order: 1 -> [1 0 0 0].
|
||||
func uint32ToBytes(number uint32) []byte {
|
||||
|
|
|
@ -18,6 +18,7 @@ package votingpool
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
|
@ -80,3 +81,77 @@ func TestGetMaxUsedIdx(t *testing.T) {
|
|||
t.Fatalf("Wrong max idx; got %d, want %d", maxIdx, Index(3001))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithdrawalSerialization(t *testing.T) {
|
||||
tearDown, _, pool := TstCreatePool(t)
|
||||
defer tearDown()
|
||||
|
||||
roundID := uint32(0)
|
||||
wi := createAndFulfillWithdrawalRequests(t, pool, roundID)
|
||||
|
||||
serialized, err := serializeWithdrawal(wi.requests, wi.startAddress, wi.lastSeriesID,
|
||||
wi.changeStart, wi.dustThreshold, wi.status)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var wInfo *withdrawalInfo
|
||||
TstRunWithManagerUnlocked(t, pool.Manager(), func() {
|
||||
wInfo, err = deserializeWithdrawal(pool, serialized)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
if !reflect.DeepEqual(wInfo.startAddress, wi.startAddress) {
|
||||
t.Fatalf("Wrong startAddr; got %v, want %v", wInfo.startAddress, wi.startAddress)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(wInfo.changeStart, wi.changeStart) {
|
||||
t.Fatalf("Wrong changeStart; got %v, want %v", wInfo.changeStart, wi.changeStart)
|
||||
}
|
||||
|
||||
if wInfo.lastSeriesID != wi.lastSeriesID {
|
||||
t.Fatalf("Wrong LastSeriesID; got %d, want %d", wInfo.lastSeriesID, wi.lastSeriesID)
|
||||
}
|
||||
|
||||
if wInfo.dustThreshold != wi.dustThreshold {
|
||||
t.Fatalf("Wrong DustThreshold; got %d, want %d", wInfo.dustThreshold, wi.dustThreshold)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(wInfo.requests, wi.requests) {
|
||||
t.Fatalf("Wrong output requests; got %v, want %v", wInfo.requests, wi.requests)
|
||||
}
|
||||
|
||||
TstCheckWithdrawalStatusMatches(t, wInfo.status, wi.status)
|
||||
}
|
||||
|
||||
func TestPutAndGetWithdrawal(t *testing.T) {
|
||||
tearDown, _, pool := TstCreatePool(t)
|
||||
defer tearDown()
|
||||
|
||||
serialized := bytes.Repeat([]byte{1}, 10)
|
||||
poolID := []byte{0x00}
|
||||
roundID := uint32(0)
|
||||
err := pool.namespace.Update(
|
||||
func(tx walletdb.Tx) error {
|
||||
return putWithdrawal(tx, poolID, roundID, serialized)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var retrieved []byte
|
||||
err = pool.namespace.View(
|
||||
func(tx walletdb.Tx) error {
|
||||
retrieved = getWithdrawal(tx, poolID, roundID)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(retrieved, serialized) {
|
||||
t.Fatalf("Wrong value retrieved from DB; got %x, want %x", retrieved, serialized)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -151,6 +151,10 @@ const (
|
|||
// transactions.
|
||||
ErrWithdrawalTxStorage
|
||||
|
||||
// ErrWithdrawalStorage indicates an error occurred when serializing or
|
||||
// deserializing withdrawal information.
|
||||
ErrWithdrawalStorage
|
||||
|
||||
// lastErr is used for testing, making it possible to iterate over
|
||||
// the error codes in order to check that they all have proper
|
||||
// translations in errorCodeStrings.
|
||||
|
@ -192,6 +196,7 @@ var errorCodeStrings = map[ErrorCode]string{
|
|||
ErrInvalidScriptHash: "ErrInvalidScriptHash",
|
||||
ErrWithdrawFromUnusedAddr: "ErrWithdrawFromUnusedAddr",
|
||||
ErrWithdrawalTxStorage: "ErrWithdrawalTxStorage",
|
||||
ErrWithdrawalStorage: "ErrWithdrawalStorage",
|
||||
}
|
||||
|
||||
// String returns the ErrorCode as a human-readable name.
|
||||
|
|
|
@ -65,6 +65,7 @@ func TestErrorCodeStringer(t *testing.T) {
|
|||
{vp.ErrInvalidScriptHash, "ErrInvalidScriptHash"},
|
||||
{vp.ErrWithdrawFromUnusedAddr, "ErrWithdrawFromUnusedAddr"},
|
||||
{vp.ErrWithdrawalTxStorage, "ErrWithdrawalTxStorage"},
|
||||
{vp.ErrWithdrawalStorage, "ErrWithdrawalStorage"},
|
||||
{0xffff, "Unknown ErrorCode (65535)"},
|
||||
}
|
||||
|
||||
|
|
|
@ -421,3 +421,29 @@ func TstNewChangeAddress(t *testing.T, p *Pool, seriesID uint32, idx Index) (add
|
|||
func TstConstantFee(fee btcutil.Amount) func(tx *withdrawalTx) btcutil.Amount {
|
||||
return func(tx *withdrawalTx) btcutil.Amount { return fee }
|
||||
}
|
||||
|
||||
func createAndFulfillWithdrawalRequests(t *testing.T, pool *Pool, roundID uint32) withdrawalInfo {
|
||||
|
||||
params := pool.Manager().ChainParams()
|
||||
seriesID, eligible := TstCreateCreditsOnNewSeries(t, pool, []int64{2e6, 4e6})
|
||||
requests := []OutputRequest{
|
||||
TstNewOutputRequest(t, 1, "34eVkREKgvvGASZW7hkgE2uNc1yycntMK6", 3e6, params),
|
||||
TstNewOutputRequest(t, 2, "3PbExiaztsSYgh6zeMswC49hLUwhTQ86XG", 2e6, params),
|
||||
}
|
||||
changeStart := TstNewChangeAddress(t, pool, seriesID, 0)
|
||||
dustThreshold := btcutil.Amount(1e4)
|
||||
startAddr := TstNewWithdrawalAddress(t, pool, seriesID, 1, 0)
|
||||
lastSeriesID := seriesID
|
||||
w := newWithdrawal(roundID, requests, eligible, *changeStart)
|
||||
if err := w.fulfillRequests(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return withdrawalInfo{
|
||||
requests: requests,
|
||||
startAddress: *startAddr,
|
||||
changeStart: *changeStart,
|
||||
lastSeriesID: lastSeriesID,
|
||||
dustThreshold: dustThreshold,
|
||||
status: *w.status,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -96,7 +96,8 @@ func (vp *Pool) TstDecryptExtendedKey(keyType waddrmgr.CryptoKeyType, encrypted
|
|||
return vp.decryptExtendedKey(keyType, encrypted)
|
||||
}
|
||||
|
||||
// TstGetMsgTx returns the withdrawal transaction with the given ntxid.
|
||||
// TstGetMsgTx returns a copy of the withdrawal transaction with the given
|
||||
// ntxid.
|
||||
func (s *WithdrawalStatus) TstGetMsgTx(ntxid Ntxid) *wire.MsgTx {
|
||||
return s.transactions[ntxid].MsgTx
|
||||
return s.transactions[ntxid].MsgTx.Copy()
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"bytes"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
@ -28,6 +29,7 @@ import (
|
|||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/btcsuite/btcwallet/waddrmgr"
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
"github.com/btcsuite/btcwallet/wtxmgr"
|
||||
"github.com/btcsuite/fastsha256"
|
||||
)
|
||||
|
@ -84,7 +86,7 @@ type WithdrawalOutput struct {
|
|||
outpoints []OutBailmentOutpoint
|
||||
}
|
||||
|
||||
// OutBailmentOutpoint represents one of the outpoints created to fulfil an OutputRequest.
|
||||
// OutBailmentOutpoint represents one of the outpoints created to fulfill an OutputRequest.
|
||||
type OutBailmentOutpoint struct {
|
||||
ntxid Ntxid
|
||||
index uint32
|
||||
|
@ -110,6 +112,18 @@ type WithdrawalStatus struct {
|
|||
transactions map[Ntxid]changeAwareTx
|
||||
}
|
||||
|
||||
// withdrawalInfo contains all the details of an existing withdrawal, including
|
||||
// the original request parameters and the WithdrawalStatus returned by
|
||||
// StartWithdrawal.
|
||||
type withdrawalInfo struct {
|
||||
requests []OutputRequest
|
||||
startAddress WithdrawalAddress
|
||||
changeStart ChangeAddress
|
||||
lastSeriesID uint32
|
||||
dustThreshold btcutil.Amount
|
||||
status WithdrawalStatus
|
||||
}
|
||||
|
||||
// TxSigs is list of raw signatures (one for every pubkey in the multi-sig
|
||||
// script) for a given transaction input. They should match the order of pubkeys
|
||||
// in the script and an empty RawSig should be used when the private key for a
|
||||
|
@ -443,11 +457,21 @@ func newWithdrawal(roundID uint32, requests []OutputRequest, inputs []credit,
|
|||
// signature lists (one for every private key available to this wallet) for each
|
||||
// of those transaction's inputs. More details about the actual algorithm can be
|
||||
// found at http://opentransactions.org/wiki/index.php/Startwithdrawal
|
||||
// This method must be called with the address manager unlocked.
|
||||
func (p *Pool) StartWithdrawal(roundID uint32, requests []OutputRequest,
|
||||
startAddress WithdrawalAddress, lastSeriesID uint32, changeStart ChangeAddress,
|
||||
txStore *wtxmgr.Store, chainHeight int32, dustThreshold btcutil.Amount) (
|
||||
*WithdrawalStatus, error) {
|
||||
|
||||
status, err := getWithdrawalStatus(p, roundID, requests, startAddress, lastSeriesID,
|
||||
changeStart, dustThreshold)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if status != nil {
|
||||
return status, nil
|
||||
}
|
||||
|
||||
eligible, err := p.getEligibleInputs(txStore, startAddress, lastSeriesID, dustThreshold,
|
||||
chainHeight, eligibleInputMinConfirmations)
|
||||
if err != nil {
|
||||
|
@ -463,6 +487,19 @@ func (p *Pool) StartWithdrawal(roundID uint32, requests []OutputRequest,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
serialized, err := serializeWithdrawal(requests, startAddress, lastSeriesID, changeStart,
|
||||
dustThreshold, *w.status)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = p.namespace.Update(
|
||||
func(tx walletdb.Tx) error {
|
||||
return putWithdrawal(tx, p.ID, roundID, serialized)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return w.status, nil
|
||||
}
|
||||
|
||||
|
@ -720,6 +757,74 @@ func (s *WithdrawalStatus) updateStatusFor(tx *withdrawalTx) {
|
|||
}
|
||||
}
|
||||
|
||||
// match returns true if the given arguments match the fields in this
|
||||
// withdrawalInfo. For the requests slice, the order of the items does not
|
||||
// matter.
|
||||
func (wi *withdrawalInfo) match(requests []OutputRequest, startAddress WithdrawalAddress,
|
||||
lastSeriesID uint32, changeStart ChangeAddress, dustThreshold btcutil.Amount) bool {
|
||||
// Use reflect.DeepEqual to compare changeStart and startAddress as they're
|
||||
// structs that contain pointers and we want to compare their content and
|
||||
// not their address.
|
||||
if !reflect.DeepEqual(changeStart, wi.changeStart) {
|
||||
log.Debugf("withdrawal changeStart does not match: %v != %v", changeStart, wi.changeStart)
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(startAddress, wi.startAddress) {
|
||||
log.Debugf("withdrawal startAddr does not match: %v != %v", startAddress, wi.startAddress)
|
||||
return false
|
||||
}
|
||||
if lastSeriesID != wi.lastSeriesID {
|
||||
log.Debugf("withdrawal lastSeriesID does not match: %v != %v", lastSeriesID,
|
||||
wi.lastSeriesID)
|
||||
return false
|
||||
}
|
||||
if dustThreshold != wi.dustThreshold {
|
||||
log.Debugf("withdrawal dustThreshold does not match: %v != %v", dustThreshold,
|
||||
wi.dustThreshold)
|
||||
return false
|
||||
}
|
||||
r1 := make([]OutputRequest, len(requests))
|
||||
copy(r1, requests)
|
||||
r2 := make([]OutputRequest, len(wi.requests))
|
||||
copy(r2, wi.requests)
|
||||
sort.Sort(byOutBailmentID(r1))
|
||||
sort.Sort(byOutBailmentID(r2))
|
||||
if !reflect.DeepEqual(r1, r2) {
|
||||
log.Debugf("withdrawal requests does not match: %v != %v", requests, wi.requests)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// getWithdrawalStatus returns the existing WithdrawalStatus for the given
|
||||
// withdrawal parameters, if one exists. This function must be called with the
|
||||
// address manager unlocked.
|
||||
func getWithdrawalStatus(p *Pool, roundID uint32, requests []OutputRequest,
|
||||
startAddress WithdrawalAddress, lastSeriesID uint32, changeStart ChangeAddress,
|
||||
dustThreshold btcutil.Amount) (*WithdrawalStatus, error) {
|
||||
|
||||
var serialized []byte
|
||||
err := p.namespace.View(
|
||||
func(tx walletdb.Tx) error {
|
||||
serialized = getWithdrawal(tx, p.ID, roundID)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if bytes.Equal(serialized, []byte{}) {
|
||||
return nil, nil
|
||||
}
|
||||
wInfo, err := deserializeWithdrawal(p, serialized)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if wInfo.match(requests, startAddress, lastSeriesID, changeStart, dustThreshold) {
|
||||
return &wInfo.status, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// getRawSigs iterates over the inputs of each transaction given, constructing the
|
||||
// raw signatures for them using the private keys available to us.
|
||||
// It returns a map of ntxids to signature lists.
|
||||
|
|
|
@ -96,6 +96,18 @@ func TestStartWithdrawal(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
// Any subsequent StartWithdrawal() calls with the same parameters will
|
||||
// return the previously stored WithdrawalStatus.
|
||||
var status2 *vp.WithdrawalStatus
|
||||
vp.TstRunWithManagerUnlocked(t, mgr, func() {
|
||||
status2, err = pool.StartWithdrawal(0, requests, *startAddr, lastSeriesID, *changeStart,
|
||||
store, currentBlock, dustThreshold)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
vp.TstCheckWithdrawalStatusMatches(t, *status, *status2)
|
||||
}
|
||||
|
||||
func checkWithdrawalOutputs(
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
"github.com/btcsuite/btcutil"
|
||||
"github.com/btcsuite/btcutil/hdkeychain"
|
||||
"github.com/btcsuite/btcwallet/waddrmgr"
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
"github.com/btcsuite/btcwallet/wtxmgr"
|
||||
)
|
||||
|
||||
|
@ -660,6 +661,122 @@ func TestWithdrawalTxOutputTotal(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestWithdrawalInfoMatch(t *testing.T) {
|
||||
tearDown, _, pool := TstCreatePool(t)
|
||||
defer tearDown()
|
||||
|
||||
roundID := uint32(0)
|
||||
wi := createAndFulfillWithdrawalRequests(t, pool, roundID)
|
||||
|
||||
// Use freshly created values for requests, startAddress and changeStart
|
||||
// to simulate what would happen if we had recreated them from the
|
||||
// serialized data in the DB.
|
||||
requestsCopy := make([]OutputRequest, len(wi.requests))
|
||||
copy(requestsCopy, wi.requests)
|
||||
startAddr := TstNewWithdrawalAddress(t, pool, wi.startAddress.seriesID, wi.startAddress.branch,
|
||||
wi.startAddress.index)
|
||||
changeStart := TstNewChangeAddress(t, pool, wi.changeStart.seriesID, wi.changeStart.index)
|
||||
|
||||
// First check that it matches when all fields are identical.
|
||||
matches := wi.match(requestsCopy, *startAddr, wi.lastSeriesID, *changeStart, wi.dustThreshold)
|
||||
if !matches {
|
||||
t.Fatal("Should match when everything is identical.")
|
||||
}
|
||||
|
||||
// It also matches if the OutputRequests are not in the same order.
|
||||
diffOrderRequests := make([]OutputRequest, len(requestsCopy))
|
||||
copy(diffOrderRequests, requestsCopy)
|
||||
diffOrderRequests[0], diffOrderRequests[1] = requestsCopy[1], requestsCopy[0]
|
||||
matches = wi.match(diffOrderRequests, *startAddr, wi.lastSeriesID, *changeStart,
|
||||
wi.dustThreshold)
|
||||
if !matches {
|
||||
t.Fatal("Should match when requests are in different order.")
|
||||
}
|
||||
|
||||
// It should not match when the OutputRequests are not the same.
|
||||
diffRequests := diffOrderRequests
|
||||
diffRequests[0] = OutputRequest{}
|
||||
matches = wi.match(diffRequests, *startAddr, wi.lastSeriesID, *changeStart, wi.dustThreshold)
|
||||
if matches {
|
||||
t.Fatal("Should not match as requests is not equal.")
|
||||
}
|
||||
|
||||
// It should not match when lastSeriesID is not equal.
|
||||
matches = wi.match(requestsCopy, *startAddr, wi.lastSeriesID+1, *changeStart, wi.dustThreshold)
|
||||
if matches {
|
||||
t.Fatal("Should not match as lastSeriesID is not equal.")
|
||||
}
|
||||
|
||||
// It should not match when dustThreshold is not equal.
|
||||
matches = wi.match(requestsCopy, *startAddr, wi.lastSeriesID, *changeStart, wi.dustThreshold+1)
|
||||
if matches {
|
||||
t.Fatal("Should not match as dustThreshold is not equal.")
|
||||
}
|
||||
|
||||
// It should not match when startAddress is not equal.
|
||||
diffStartAddr := TstNewWithdrawalAddress(t, pool, startAddr.seriesID, startAddr.branch+1,
|
||||
startAddr.index)
|
||||
matches = wi.match(requestsCopy, *diffStartAddr, wi.lastSeriesID, *changeStart,
|
||||
wi.dustThreshold)
|
||||
if matches {
|
||||
t.Fatal("Should not match as startAddress is not equal.")
|
||||
}
|
||||
|
||||
// It should not match when changeStart is not equal.
|
||||
diffChangeStart := TstNewChangeAddress(t, pool, changeStart.seriesID, changeStart.index+1)
|
||||
matches = wi.match(requestsCopy, *startAddr, wi.lastSeriesID, *diffChangeStart,
|
||||
wi.dustThreshold)
|
||||
if matches {
|
||||
t.Fatal("Should not match as changeStart is not equal.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetWithdrawalStatus(t *testing.T) {
|
||||
tearDown, _, pool := TstCreatePool(t)
|
||||
defer tearDown()
|
||||
|
||||
roundID := uint32(0)
|
||||
wi := createAndFulfillWithdrawalRequests(t, pool, roundID)
|
||||
|
||||
serialized, err := serializeWithdrawal(wi.requests, wi.startAddress, wi.lastSeriesID,
|
||||
wi.changeStart, wi.dustThreshold, wi.status)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = pool.namespace.Update(
|
||||
func(tx walletdb.Tx) error {
|
||||
return putWithdrawal(tx, pool.ID, roundID, serialized)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Here we should get a WithdrawalStatus that matches wi.status.
|
||||
var status *WithdrawalStatus
|
||||
TstRunWithManagerUnlocked(t, pool.Manager(), func() {
|
||||
status, err = getWithdrawalStatus(pool, roundID, wi.requests, wi.startAddress,
|
||||
wi.lastSeriesID, wi.changeStart, wi.dustThreshold)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
TstCheckWithdrawalStatusMatches(t, wi.status, *status)
|
||||
|
||||
// Here we should get a nil WithdrawalStatus because the parameters are not
|
||||
// identical to those of the stored WithdrawalStatus with this roundID.
|
||||
dustThreshold := wi.dustThreshold + 1
|
||||
TstRunWithManagerUnlocked(t, pool.Manager(), func() {
|
||||
status, err = getWithdrawalStatus(pool, roundID, wi.requests, wi.startAddress,
|
||||
wi.lastSeriesID, wi.changeStart, dustThreshold)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if status != nil {
|
||||
t.Fatalf("Expected a nil status, got %v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignMultiSigUTXO(t *testing.T) {
|
||||
tearDown, pool, _ := TstCreatePoolAndTxStore(t)
|
||||
defer tearDown()
|
||||
|
|
Loading…
Reference in a new issue