Update wallet to use ForEach- style functions

This commit is contained in:
Javed Khan 2015-03-26 23:52:59 +05:30
parent fe0f60991a
commit fbf744bc5e
6 changed files with 140 additions and 153 deletions

View file

@ -1636,17 +1636,13 @@ func GetAddressesByAccount(w *wallet.Wallet, chainSvr *chain.Client, icmd interf
return nil, err return nil, err
} }
addrs, err := w.Manager.AllAccountAddresses(account) var addrStrs []string
if err != nil { err = w.Manager.ForEachAccountAddress(account,
return nil, err func(maddr waddrmgr.ManagedAddress) error {
} addrStrs = append(addrStrs, maddr.Address().EncodeAddress())
return nil
addrStrs := make([]string, len(addrs)) })
for i, addr := range addrs { return addrStrs, err
addrStrs[i] = addr.Address().EncodeAddress()
}
return addrStrs, nil
} }
// GetBalance handles a getbalance request by returning the balance for an // GetBalance handles a getbalance request by returning the balance for an
@ -2263,7 +2259,11 @@ func ListAccounts(w *wallet.Wallet, chainSvr *chain.Client, icmd interface{}) (i
cmd := icmd.(*btcjson.ListAccountsCmd) cmd := icmd.(*btcjson.ListAccountsCmd)
accountBalances := map[string]float64{} accountBalances := map[string]float64{}
accounts, err := w.Manager.AllAccounts() var accounts []uint32
err := w.Manager.ForEachAccount(func(account uint32) error {
accounts = append(accounts, account)
return nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -2302,7 +2302,11 @@ func ListLockUnspent(w *wallet.Wallet, chainSvr *chain.Client, icmd interface{})
func ListReceivedByAccount(w *wallet.Wallet, chainSvr *chain.Client, icmd interface{}) (interface{}, error) { func ListReceivedByAccount(w *wallet.Wallet, chainSvr *chain.Client, icmd interface{}) (interface{}, error) {
cmd := icmd.(*btcjson.ListReceivedByAccountCmd) cmd := icmd.(*btcjson.ListReceivedByAccountCmd)
accounts, err := w.Manager.AllAccounts() var accounts []uint32
err := w.Manager.ForEachAccount(func(account uint32) error {
accounts = append(accounts, account)
return nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -577,23 +577,18 @@ func serializeBIP0044AccountRow(encryptedPubKey,
return rawData return rawData
} }
// fetchAllAccounts loads information about all accounts from the database. // forEachAccount calls the given function with each account stored in
// The returned value is a slice of account numbers which can be used to load // the manager, breaking early on error.
// the respective account rows. func forEachAccount(tx walletdb.Tx, fn func(account uint32) error) error {
// TODO(tuxcanfly): Switch over to an iterator to support the maximum of 2^31-2 accounts
func fetchAllAccounts(tx walletdb.Tx) ([]uint32, error) {
bucket := tx.RootBucket().Bucket(acctBucketName) bucket := tx.RootBucket().Bucket(acctBucketName)
var accounts []uint32 return bucket.ForEach(func(k, v []byte) error {
err := bucket.ForEach(func(k, v []byte) error {
// Skip buckets. // Skip buckets.
if v == nil { if v == nil {
return nil return nil
} }
accounts = append(accounts, binary.LittleEndian.Uint32(k)) return fn(binary.LittleEndian.Uint32(k))
return nil
}) })
return accounts, err
} }
// fetchLastAccount retreives the last account from the database. // fetchLastAccount retreives the last account from the database.
@ -1187,19 +1182,17 @@ func fetchAddrAccount(tx walletdb.Tx, addressID []byte) (uint32, error) {
return binary.LittleEndian.Uint32(val), nil return binary.LittleEndian.Uint32(val), nil
} }
// fetchAccountAddresses loads information about addresses of an account from the database. // forEachAccountAddress calls the given function with each address of
// The returned value is a slice address rows for each specific address type. // the given account stored in the manager, breaking early on error.
// The caller should use type assertions to ascertain the types. func forEachAccountAddress(tx walletdb.Tx, account uint32, fn func(rowInterface interface{}) error) error {
func fetchAccountAddresses(tx walletdb.Tx, account uint32) ([]interface{}, error) {
bucket := tx.RootBucket().Bucket(addrAcctIdxBucketName). bucket := tx.RootBucket().Bucket(addrAcctIdxBucketName).
Bucket(uint32ToBytes(account)) Bucket(uint32ToBytes(account))
// if index bucket is missing the account, there hasn't been any address // if index bucket is missing the account, there hasn't been any address
// entries yet // entries yet
if bucket == nil { if bucket == nil {
return nil, nil return nil
} }
var addrs []interface{}
err := bucket.ForEach(func(k, v []byte) error { err := bucket.ForEach(func(k, v []byte) error {
// Skip buckets. // Skip buckets.
if v == nil { if v == nil {
@ -1216,24 +1209,19 @@ func fetchAccountAddresses(tx walletdb.Tx, account uint32) ([]interface{}, error
return err return err
} }
addrs = append(addrs, addrRow) return fn(addrRow)
return nil
}) })
if err != nil { if err != nil {
return nil, maybeConvertDbError(err) return maybeConvertDbError(err)
} }
return nil
return addrs, nil
} }
// fetchAllAddresses loads information about all addresses from the database. // forEachActiveAddress calls the given function with each active address
// The returned value is a slice of address rows for each specific address type. // stored in the manager, breaking early on error.
// The caller should use type assertions to ascertain the types. func forEachActiveAddress(tx walletdb.Tx, fn func(rowInterface interface{}) error) error {
// TODO(tuxcanfly): Switch over to an iterator to support the maximum of 2^62 - 2^32 - 2^31 + 2 addrs
func fetchAllAddresses(tx walletdb.Tx) ([]interface{}, error) {
bucket := tx.RootBucket().Bucket(addrBucketName) bucket := tx.RootBucket().Bucket(addrBucketName)
var addrs []interface{}
err := bucket.ForEach(func(k, v []byte) error { err := bucket.ForEach(func(k, v []byte) error {
// Skip buckets. // Skip buckets.
if v == nil { if v == nil {
@ -1253,14 +1241,12 @@ func fetchAllAddresses(tx walletdb.Tx) ([]interface{}, error) {
return err return err
} }
addrs = append(addrs, addrRow) return fn(addrRow)
return nil
}) })
if err != nil { if err != nil {
return nil, maybeConvertDbError(err) return maybeConvertDbError(err)
} }
return nil
return addrs, nil
} }
// deletePrivateKeys removes all private key material from the database. // deletePrivateKeys removes all private key material from the database.

View file

@ -131,6 +131,10 @@ const (
// ErrWrongNet indicates that the private key to be imported is not for the // ErrWrongNet indicates that the private key to be imported is not for the
// the same network the account manager is configured for. // the same network the account manager is configured for.
ErrWrongNet ErrWrongNet
// ErrCallBackBreak is used to break from a callback function passed
// down to the manager.
ErrCallBackBreak
) )
// Map of ErrorCode values back to their constant names for pretty printing. // Map of ErrorCode values back to their constant names for pretty printing.
@ -154,6 +158,7 @@ var errorCodeStrings = map[ErrorCode]string{
ErrTooManyAddresses: "ErrTooManyAddresses", ErrTooManyAddresses: "ErrTooManyAddresses",
ErrWrongPassphrase: "ErrWrongPassphrase", ErrWrongPassphrase: "ErrWrongPassphrase",
ErrWrongNet: "ErrWrongNet", ErrWrongNet: "ErrWrongNet",
ErrCallBackBreak: "ErrCallBackBreak",
} }
// String returns the ErrorCode as a human-readable name. // String returns the ErrorCode as a human-readable name.
@ -195,3 +200,7 @@ func (e ManagerError) Error() string {
func managerError(c ErrorCode, desc string, err error) ManagerError { func managerError(c ErrorCode, desc string, err error) ManagerError {
return ManagerError{ErrorCode: c, Description: desc, Err: err} return ManagerError{ErrorCode: c, Description: desc, Err: err}
} }
// Break is a global err used to signal a break from the callback
// function by returning an error with the code ErrCallBackBreak
var Break = managerError(ErrCallBackBreak, "callback break", nil)

View file

@ -1821,22 +1821,14 @@ func (m *Manager) AccountName(account uint32) (string, error) {
return acctName, nil return acctName, nil
} }
// AllAccounts returns a slice of all the accounts stored in the manager. // ForEachAccount calls the given function with each account stored in the
func (m *Manager) AllAccounts() ([]uint32, error) { // manager, breaking early on error.
func (m *Manager) ForEachAccount(fn func(account uint32) error) error {
m.mtx.Lock() m.mtx.Lock()
defer m.mtx.Unlock() defer m.mtx.Unlock()
return m.namespace.View(func(tx walletdb.Tx) error {
var accounts []uint32 return forEachAccount(tx, fn)
err := m.namespace.View(func(tx walletdb.Tx) error {
var err error
accounts, err = fetchAllAccounts(tx)
return err
}) })
if err != nil {
return nil, err
}
return accounts, nil
} }
// LastAccount returns the last account stored in the manager. // LastAccount returns the last account stored in the manager.
@ -1853,73 +1845,58 @@ func (m *Manager) LastAccount() (uint32, error) {
return account, err return account, err
} }
// AllAccountAddresses returns a slice of addresses of an account stored in the manager. // ForEachAccountAddress calls the given function with each address of
func (m *Manager) AllAccountAddresses(account uint32) ([]ManagedAddress, error) { // the given account stored in the manager, breaking early on error.
func (m *Manager) ForEachAccountAddress(account uint32, fn func(maddr ManagedAddress) error) error {
m.mtx.Lock() m.mtx.Lock()
defer m.mtx.Unlock() defer m.mtx.Unlock()
// Load the raw address information from the database. addrFn := func(rowInterface interface{}) error {
var rowInterfaces []interface{}
err := m.namespace.View(func(tx walletdb.Tx) error {
var err error
rowInterfaces, err = fetchAccountAddresses(tx, account)
return err
})
if err != nil {
return nil, err
}
addrs := make([]ManagedAddress, 0, len(rowInterfaces))
for _, rowInterface := range rowInterfaces {
// Create a new managed address for the specific type of address
// based on type.
managedAddr, err := m.rowInterfaceToManaged(rowInterface) managedAddr, err := m.rowInterfaceToManaged(rowInterface)
if err != nil { if err != nil {
return nil, err return err
} }
return fn(managedAddr)
addrs = append(addrs, managedAddr)
} }
return addrs, nil err := m.namespace.View(func(tx walletdb.Tx) error {
return forEachAccountAddress(tx, account, addrFn)
})
if err != nil {
return maybeConvertDbError(err)
}
return nil
} }
// ActiveAccountAddresses returns a slice of active addresses of an account // ForEachActiveAccountAddress calls the given function with each active
// stored in the manager. // address of the given account stored in the manager, breaking early on
// error.
// TODO(tuxcanfly): actually return only active addresses // TODO(tuxcanfly): actually return only active addresses
func (m *Manager) ActiveAccountAddresses(account uint32) ([]ManagedAddress, error) { func (m *Manager) ForEachActiveAccountAddress(account uint32, fn func(maddr ManagedAddress) error) error {
return m.AllAccountAddresses(account) return m.ForEachAccountAddress(account, fn)
} }
// AllActiveAddresses returns a slice of all addresses stored in the manager. // ForEachActiveAddress calls the given function with each active address
func (m *Manager) AllActiveAddresses() ([]btcutil.Address, error) { // stored in the manager, breaking early on error.
func (m *Manager) ForEachActiveAddress(fn func(addr btcutil.Address) error) error {
m.mtx.Lock() m.mtx.Lock()
defer m.mtx.Unlock() defer m.mtx.Unlock()
// Load the raw address information from the database. addrFn := func(rowInterface interface{}) error {
var rowInterfaces []interface{}
err := m.namespace.View(func(tx walletdb.Tx) error {
var err error
rowInterfaces, err = fetchAllAddresses(tx)
return err
})
if err != nil {
return nil, maybeConvertDbError(err)
}
addrs := make([]btcutil.Address, 0, len(rowInterfaces))
for _, rowInterface := range rowInterfaces {
// Create a new managed address for the specific type of address
// based on type.
managedAddr, err := m.rowInterfaceToManaged(rowInterface) managedAddr, err := m.rowInterfaceToManaged(rowInterface)
if err != nil { if err != nil {
return nil, err return err
} }
return fn(managedAddr.Address())
addrs = append(addrs, managedAddr.Address())
} }
return addrs, nil err := m.namespace.View(func(tx walletdb.Tx) error {
return forEachActiveAddress(tx, addrFn)
})
if err != nil {
return maybeConvertDbError(err)
}
return nil
} }
// selectCryptoKey selects the appropriate crypto key based on the key type. An // selectCryptoKey selects the appropriate crypto key based on the key type. An

View file

@ -1267,9 +1267,10 @@ func testRenameAccount(tc *testContext) bool {
return true return true
} }
// testAllAccounts tests the retrieve all accounts func of the address manager works // testForEachAccount tests the retrieve all accounts func of the address
// as expected. // manager works as expected.
func testAllAccounts(tc *testContext) bool { func testForEachAccount(tc *testContext) bool {
prefix := testNamePrefix(tc) + " testForEachAccount"
expectedAccounts := []uint32{0, 1} expectedAccounts := []uint32{0, 1}
if !tc.create { if !tc.create {
// Existing wallet manager will have 3 accounts // Existing wallet manager will have 3 accounts
@ -1277,38 +1278,47 @@ func testAllAccounts(tc *testContext) bool {
} }
// Imported account // Imported account
expectedAccounts = append(expectedAccounts, waddrmgr.ImportedAddrAccount) expectedAccounts = append(expectedAccounts, waddrmgr.ImportedAddrAccount)
accounts, err := tc.manager.AllAccounts() var accounts []uint32
err := tc.manager.ForEachAccount(func(account uint32) error {
accounts = append(accounts, account)
return nil
})
if err != nil { if err != nil {
tc.t.Errorf("AllAccounts: unexpected error: %v", err) tc.t.Errorf("%s: unexpected error: %v", prefix, err)
return false return false
} }
if len(accounts) != len(expectedAccounts) { if len(accounts) != len(expectedAccounts) {
tc.t.Errorf("AllAccounts: unexpected number of accounts - got "+ tc.t.Errorf("%s: unexpected number of accounts - got "+
"%d, want %d", len(accounts), "%d, want %d", prefix, len(accounts),
len(expectedAccounts)) len(expectedAccounts))
return false return false
} }
for i, account := range accounts { for i, account := range accounts {
if expectedAccounts[i] != account { if expectedAccounts[i] != account {
tc.t.Errorf("AllAccounts %s: "+ tc.t.Errorf("%s #%d: "+
"account mismatch -- got %d, "+ "account mismatch -- got %d, "+
"want %d", i, account, expectedAccounts[i]) "want %d", prefix, i, account, expectedAccounts[i])
} }
} }
return true return true
} }
// testAllAccountAddresses tests the account addresses returned by the manager // testForEachAccountAddress tests that iterating through the given
// API. // account addresses using the manager API works as expected.
func testAllAccountAddresses(tc *testContext) bool { func testForEachAccountAddress(tc *testContext) bool {
prefix := testNamePrefix(tc) + " testAllAccountAddresses" prefix := testNamePrefix(tc) + " testForEachAccountAddress"
// Make a map of expected addresses // Make a map of expected addresses
expectedAddrMap := make(map[string]*expectedAddr, len(expectedAddrs)) expectedAddrMap := make(map[string]*expectedAddr, len(expectedAddrs))
for i := 0; i < len(expectedAddrs); i++ { for i := 0; i < len(expectedAddrs); i++ {
expectedAddrMap[expectedAddrs[i].address] = &expectedAddrs[i] expectedAddrMap[expectedAddrs[i].address] = &expectedAddrs[i]
} }
addrs, err := tc.manager.AllAccountAddresses(tc.account) var addrs []waddrmgr.ManagedAddress
err := tc.manager.ForEachAccountAddress(tc.account,
func(maddr waddrmgr.ManagedAddress) error {
addrs = append(addrs, maddr)
return nil
})
if err != nil { if err != nil {
tc.t.Errorf("%s: unexpected error: %v", prefix, err) tc.t.Errorf("%s: unexpected error: %v", prefix, err)
return false return false
@ -1349,8 +1359,8 @@ func testManagerAPI(tc *testContext) {
tc.account = 0 tc.account = 0
testNewAccount(tc) testNewAccount(tc)
testLookupAccount(tc) testLookupAccount(tc)
testAllAccounts(tc) testForEachAccount(tc)
testAllAccountAddresses(tc) testForEachAccountAddress(tc)
// Rename account 1 "acct-create" // Rename account 1 "acct-create"
tc.account = 1 tc.account = 1

View file

@ -370,7 +370,11 @@ func (w *Wallet) SetChainSynced(synced bool) {
// outputs. This is primarely intended to provide the parameters for a // outputs. This is primarely intended to provide the parameters for a
// rescan request. // rescan request.
func (w *Wallet) activeData() ([]btcutil.Address, []wtxmgr.Credit, error) { func (w *Wallet) activeData() ([]btcutil.Address, []wtxmgr.Credit, error) {
addrs, err := w.Manager.AllActiveAddresses() var addrs []btcutil.Address
err := w.Manager.ForEachActiveAddress(func(addr btcutil.Address) error {
addrs = append(addrs, addr)
return nil
})
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -669,20 +673,23 @@ func (w *Wallet) ChangePassphrase(old, new []byte) error {
// a given account. It returns true if atleast one address in the account was // a given account. It returns true if atleast one address in the account was
// used and false if no address in the account was used. // used and false if no address in the account was used.
func (w *Wallet) AccountUsed(account uint32) (bool, error) { func (w *Wallet) AccountUsed(account uint32) (bool, error) {
addrs, err := w.Manager.AllAccountAddresses(account) var used bool
if err != nil { var err error
return false, err merr := w.Manager.ForEachAccountAddress(account,
func(maddr waddrmgr.ManagedAddress) error {
used, err = maddr.Used()
if err != nil {
return err
}
if used {
return waddrmgr.Break
}
return nil
})
if merr == waddrmgr.Break {
merr = nil
} }
for _, addr := range addrs { return used, merr
used, err := addr.Used()
if err != nil {
return false, err
}
if used {
return true, nil
}
}
return false, nil
} }
// CalculateBalance sums the amounts of all unspent transaction // CalculateBalance sums the amounts of all unspent transaction
@ -1203,24 +1210,19 @@ func (w *Wallet) ListUnspent(minconf, maxconf int32,
// DumpPrivKeys returns the WIF-encoded private keys for all addresses with // DumpPrivKeys returns the WIF-encoded private keys for all addresses with
// private keys in a wallet. // private keys in a wallet.
func (w *Wallet) DumpPrivKeys() ([]string, error) { func (w *Wallet) DumpPrivKeys() ([]string, error) {
addrs, err := w.Manager.AllActiveAddresses() var privkeys []string
if err != nil {
return nil, err
}
// Iterate over each active address, appending the private key to // Iterate over each active address, appending the private key to
// privkeys. // privkeys.
privkeys := make([]string, 0, len(addrs)) err := w.Manager.ForEachActiveAddress(func(addr btcutil.Address) error {
for _, addr := range addrs {
ma, err := w.Manager.Address(addr) ma, err := w.Manager.Address(addr)
if err != nil { if err != nil {
return nil, err return err
} }
// Only those addresses with keys needed. // Only those addresses with keys needed.
pka, ok := ma.(waddrmgr.ManagedPubKeyAddress) pka, ok := ma.(waddrmgr.ManagedPubKeyAddress)
if !ok { if !ok {
continue return nil
} }
wif, err := pka.ExportPrivKey() wif, err := pka.ExportPrivKey()
@ -1228,12 +1230,12 @@ func (w *Wallet) DumpPrivKeys() ([]string, error) {
// It would be nice to zero out the array here. However, // It would be nice to zero out the array here. However,
// since strings in go are immutable, and we have no // since strings in go are immutable, and we have no
// control over the caller I don't think we can. :( // control over the caller I don't think we can. :(
return nil, err return err
} }
privkeys = append(privkeys, wif.String()) privkeys = append(privkeys, wif.String())
} return nil
})
return privkeys, nil return privkeys, err
} }
// DumpWIFPrivateKey returns the WIF encoded private key for a // DumpWIFPrivateKey returns the WIF encoded private key for a
@ -1436,16 +1438,15 @@ func (w *Wallet) ResendUnminedTxs() {
// SortedActivePaymentAddresses returns a slice of all active payment // SortedActivePaymentAddresses returns a slice of all active payment
// addresses in a wallet. // addresses in a wallet.
func (w *Wallet) SortedActivePaymentAddresses() ([]string, error) { func (w *Wallet) SortedActivePaymentAddresses() ([]string, error) {
addrs, err := w.Manager.AllActiveAddresses() var addrStrs []string
err := w.Manager.ForEachActiveAddress(func(addr btcutil.Address) error {
addrStrs = append(addrStrs, addr.EncodeAddress())
return nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
addrStrs := make([]string, len(addrs))
for i, addr := range addrs {
addrStrs[i] = addr.EncodeAddress()
}
sort.Sort(sort.StringSlice(addrStrs)) sort.Sort(sort.StringSlice(addrStrs))
return addrStrs, nil return addrStrs, nil
} }