diff --git a/rpcserver.go b/rpcserver.go index 311bbce..5f93ae0 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -1636,17 +1636,13 @@ func GetAddressesByAccount(w *wallet.Wallet, chainSvr *chain.Client, icmd interf return nil, err } - addrs, err := w.Manager.AllAccountAddresses(account) - if err != nil { - return nil, err - } - - addrStrs := make([]string, len(addrs)) - for i, addr := range addrs { - addrStrs[i] = addr.Address().EncodeAddress() - } - - return addrStrs, nil + var addrStrs []string + err = w.Manager.ForEachAccountAddress(account, + func(maddr waddrmgr.ManagedAddress) error { + addrStrs = append(addrStrs, maddr.Address().EncodeAddress()) + return nil + }) + return addrStrs, err } // GetBalance handles a getbalance request by returning the balance for an @@ -2263,7 +2259,11 @@ func ListAccounts(w *wallet.Wallet, chainSvr *chain.Client, icmd interface{}) (i cmd := icmd.(*btcjson.ListAccountsCmd) accountBalances := map[string]float64{} - accounts, err := w.Manager.AllAccounts() + var accounts []uint32 + err := w.Manager.ForEachAccount(func(account uint32) error { + accounts = append(accounts, account) + return nil + }) if err != nil { return nil, err } @@ -2302,7 +2302,11 @@ func ListLockUnspent(w *wallet.Wallet, chainSvr *chain.Client, icmd interface{}) func ListReceivedByAccount(w *wallet.Wallet, chainSvr *chain.Client, icmd interface{}) (interface{}, error) { cmd := icmd.(*btcjson.ListReceivedByAccountCmd) - accounts, err := w.Manager.AllAccounts() + var accounts []uint32 + err := w.Manager.ForEachAccount(func(account uint32) error { + accounts = append(accounts, account) + return nil + }) if err != nil { return nil, err } diff --git a/waddrmgr/db.go b/waddrmgr/db.go index 62a036c..4eaa3aa 100644 --- a/waddrmgr/db.go +++ b/waddrmgr/db.go @@ -577,23 +577,18 @@ func serializeBIP0044AccountRow(encryptedPubKey, return rawData } -// fetchAllAccounts loads information about all accounts from the database. -// The returned value is a slice of account numbers which can be used to load -// the respective account rows. -// TODO(tuxcanfly): Switch over to an iterator to support the maximum of 2^31-2 accounts -func fetchAllAccounts(tx walletdb.Tx) ([]uint32, error) { +// forEachAccount calls the given function with each account stored in +// the manager, breaking early on error. +func forEachAccount(tx walletdb.Tx, fn func(account uint32) error) error { bucket := tx.RootBucket().Bucket(acctBucketName) - var accounts []uint32 - err := bucket.ForEach(func(k, v []byte) error { + return bucket.ForEach(func(k, v []byte) error { // Skip buckets. if v == nil { return nil } - accounts = append(accounts, binary.LittleEndian.Uint32(k)) - return nil + return fn(binary.LittleEndian.Uint32(k)) }) - return accounts, err } // fetchLastAccount retreives the last account from the database. @@ -1187,19 +1182,17 @@ func fetchAddrAccount(tx walletdb.Tx, addressID []byte) (uint32, error) { return binary.LittleEndian.Uint32(val), nil } -// fetchAccountAddresses loads information about addresses of an account from the database. -// The returned value is a slice address rows for each specific address type. -// The caller should use type assertions to ascertain the types. -func fetchAccountAddresses(tx walletdb.Tx, account uint32) ([]interface{}, error) { +// forEachAccountAddress calls the given function with each address of +// the given account stored in the manager, breaking early on error. +func forEachAccountAddress(tx walletdb.Tx, account uint32, fn func(rowInterface interface{}) error) error { bucket := tx.RootBucket().Bucket(addrAcctIdxBucketName). Bucket(uint32ToBytes(account)) // if index bucket is missing the account, there hasn't been any address // entries yet if bucket == nil { - return nil, nil + return nil } - var addrs []interface{} err := bucket.ForEach(func(k, v []byte) error { // Skip buckets. if v == nil { @@ -1216,24 +1209,19 @@ func fetchAccountAddresses(tx walletdb.Tx, account uint32) ([]interface{}, error return err } - addrs = append(addrs, addrRow) - return nil + return fn(addrRow) }) if err != nil { - return nil, maybeConvertDbError(err) + return maybeConvertDbError(err) } - - return addrs, nil + return nil } -// fetchAllAddresses loads information about all addresses from the database. -// The returned value is a slice of address rows for each specific address type. -// The caller should use type assertions to ascertain the types. -// TODO(tuxcanfly): Switch over to an iterator to support the maximum of 2^62 - 2^32 - 2^31 + 2 addrs -func fetchAllAddresses(tx walletdb.Tx) ([]interface{}, error) { +// forEachActiveAddress calls the given function with each active address +// stored in the manager, breaking early on error. +func forEachActiveAddress(tx walletdb.Tx, fn func(rowInterface interface{}) error) error { bucket := tx.RootBucket().Bucket(addrBucketName) - var addrs []interface{} err := bucket.ForEach(func(k, v []byte) error { // Skip buckets. if v == nil { @@ -1253,14 +1241,12 @@ func fetchAllAddresses(tx walletdb.Tx) ([]interface{}, error) { return err } - addrs = append(addrs, addrRow) - return nil + return fn(addrRow) }) if err != nil { - return nil, maybeConvertDbError(err) + return maybeConvertDbError(err) } - - return addrs, nil + return nil } // deletePrivateKeys removes all private key material from the database. diff --git a/waddrmgr/error.go b/waddrmgr/error.go index 22dd588..2753576 100644 --- a/waddrmgr/error.go +++ b/waddrmgr/error.go @@ -131,6 +131,10 @@ const ( // ErrWrongNet indicates that the private key to be imported is not for the // the same network the account manager is configured for. ErrWrongNet + + // ErrCallBackBreak is used to break from a callback function passed + // down to the manager. + ErrCallBackBreak ) // Map of ErrorCode values back to their constant names for pretty printing. @@ -154,6 +158,7 @@ var errorCodeStrings = map[ErrorCode]string{ ErrTooManyAddresses: "ErrTooManyAddresses", ErrWrongPassphrase: "ErrWrongPassphrase", ErrWrongNet: "ErrWrongNet", + ErrCallBackBreak: "ErrCallBackBreak", } // String returns the ErrorCode as a human-readable name. @@ -195,3 +200,7 @@ func (e ManagerError) Error() string { func managerError(c ErrorCode, desc string, err error) ManagerError { return ManagerError{ErrorCode: c, Description: desc, Err: err} } + +// Break is a global err used to signal a break from the callback +// function by returning an error with the code ErrCallBackBreak +var Break = managerError(ErrCallBackBreak, "callback break", nil) diff --git a/waddrmgr/manager.go b/waddrmgr/manager.go index 9642655..97a7d94 100644 --- a/waddrmgr/manager.go +++ b/waddrmgr/manager.go @@ -1821,22 +1821,14 @@ func (m *Manager) AccountName(account uint32) (string, error) { return acctName, nil } -// AllAccounts returns a slice of all the accounts stored in the manager. -func (m *Manager) AllAccounts() ([]uint32, error) { +// ForEachAccount calls the given function with each account stored in the +// manager, breaking early on error. +func (m *Manager) ForEachAccount(fn func(account uint32) error) error { m.mtx.Lock() defer m.mtx.Unlock() - - var accounts []uint32 - err := m.namespace.View(func(tx walletdb.Tx) error { - var err error - accounts, err = fetchAllAccounts(tx) - return err + return m.namespace.View(func(tx walletdb.Tx) error { + return forEachAccount(tx, fn) }) - if err != nil { - return nil, err - } - - return accounts, nil } // LastAccount returns the last account stored in the manager. @@ -1853,73 +1845,58 @@ func (m *Manager) LastAccount() (uint32, error) { return account, err } -// AllAccountAddresses returns a slice of addresses of an account stored in the manager. -func (m *Manager) AllAccountAddresses(account uint32) ([]ManagedAddress, error) { +// ForEachAccountAddress calls the given function with each address of +// the given account stored in the manager, breaking early on error. +func (m *Manager) ForEachAccountAddress(account uint32, fn func(maddr ManagedAddress) error) error { m.mtx.Lock() defer m.mtx.Unlock() - // Load the raw address information from the database. - var rowInterfaces []interface{} - err := m.namespace.View(func(tx walletdb.Tx) error { - var err error - rowInterfaces, err = fetchAccountAddresses(tx, account) - return err - }) - if err != nil { - return nil, err - } - - addrs := make([]ManagedAddress, 0, len(rowInterfaces)) - for _, rowInterface := range rowInterfaces { - // Create a new managed address for the specific type of address - // based on type. + addrFn := func(rowInterface interface{}) error { managedAddr, err := m.rowInterfaceToManaged(rowInterface) if err != nil { - return nil, err + return err } - - addrs = append(addrs, managedAddr) + return fn(managedAddr) } - return addrs, nil + err := m.namespace.View(func(tx walletdb.Tx) error { + return forEachAccountAddress(tx, account, addrFn) + }) + if err != nil { + return maybeConvertDbError(err) + } + return nil } -// ActiveAccountAddresses returns a slice of active addresses of an account -// stored in the manager. +// ForEachActiveAccountAddress calls the given function with each active +// address of the given account stored in the manager, breaking early on +// error. // TODO(tuxcanfly): actually return only active addresses -func (m *Manager) ActiveAccountAddresses(account uint32) ([]ManagedAddress, error) { - return m.AllAccountAddresses(account) +func (m *Manager) ForEachActiveAccountAddress(account uint32, fn func(maddr ManagedAddress) error) error { + return m.ForEachAccountAddress(account, fn) } -// AllActiveAddresses returns a slice of all addresses stored in the manager. -func (m *Manager) AllActiveAddresses() ([]btcutil.Address, error) { +// ForEachActiveAddress calls the given function with each active address +// stored in the manager, breaking early on error. +func (m *Manager) ForEachActiveAddress(fn func(addr btcutil.Address) error) error { m.mtx.Lock() defer m.mtx.Unlock() - // Load the raw address information from the database. - var rowInterfaces []interface{} - err := m.namespace.View(func(tx walletdb.Tx) error { - var err error - rowInterfaces, err = fetchAllAddresses(tx) - return err - }) - if err != nil { - return nil, maybeConvertDbError(err) - } - - addrs := make([]btcutil.Address, 0, len(rowInterfaces)) - for _, rowInterface := range rowInterfaces { - // Create a new managed address for the specific type of address - // based on type. + addrFn := func(rowInterface interface{}) error { managedAddr, err := m.rowInterfaceToManaged(rowInterface) if err != nil { - return nil, err + return err } - - addrs = append(addrs, managedAddr.Address()) + return fn(managedAddr.Address()) } - return addrs, nil + err := m.namespace.View(func(tx walletdb.Tx) error { + return forEachActiveAddress(tx, addrFn) + }) + if err != nil { + return maybeConvertDbError(err) + } + return nil } // selectCryptoKey selects the appropriate crypto key based on the key type. An diff --git a/waddrmgr/manager_test.go b/waddrmgr/manager_test.go index 958acf3..9f3ee99 100644 --- a/waddrmgr/manager_test.go +++ b/waddrmgr/manager_test.go @@ -1267,9 +1267,10 @@ func testRenameAccount(tc *testContext) bool { return true } -// testAllAccounts tests the retrieve all accounts func of the address manager works -// as expected. -func testAllAccounts(tc *testContext) bool { +// testForEachAccount tests the retrieve all accounts func of the address +// manager works as expected. +func testForEachAccount(tc *testContext) bool { + prefix := testNamePrefix(tc) + " testForEachAccount" expectedAccounts := []uint32{0, 1} if !tc.create { // Existing wallet manager will have 3 accounts @@ -1277,38 +1278,47 @@ func testAllAccounts(tc *testContext) bool { } // Imported account expectedAccounts = append(expectedAccounts, waddrmgr.ImportedAddrAccount) - accounts, err := tc.manager.AllAccounts() + var accounts []uint32 + err := tc.manager.ForEachAccount(func(account uint32) error { + accounts = append(accounts, account) + return nil + }) if err != nil { - tc.t.Errorf("AllAccounts: unexpected error: %v", err) + tc.t.Errorf("%s: unexpected error: %v", prefix, err) return false } if len(accounts) != len(expectedAccounts) { - tc.t.Errorf("AllAccounts: unexpected number of accounts - got "+ - "%d, want %d", len(accounts), + tc.t.Errorf("%s: unexpected number of accounts - got "+ + "%d, want %d", prefix, len(accounts), len(expectedAccounts)) return false } for i, account := range accounts { if expectedAccounts[i] != account { - tc.t.Errorf("AllAccounts %s: "+ + tc.t.Errorf("%s #%d: "+ "account mismatch -- got %d, "+ - "want %d", i, account, expectedAccounts[i]) + "want %d", prefix, i, account, expectedAccounts[i]) } } return true } -// testAllAccountAddresses tests the account addresses returned by the manager -// API. -func testAllAccountAddresses(tc *testContext) bool { - prefix := testNamePrefix(tc) + " testAllAccountAddresses" +// testForEachAccountAddress tests that iterating through the given +// account addresses using the manager API works as expected. +func testForEachAccountAddress(tc *testContext) bool { + prefix := testNamePrefix(tc) + " testForEachAccountAddress" // Make a map of expected addresses expectedAddrMap := make(map[string]*expectedAddr, len(expectedAddrs)) for i := 0; i < len(expectedAddrs); i++ { expectedAddrMap[expectedAddrs[i].address] = &expectedAddrs[i] } - addrs, err := tc.manager.AllAccountAddresses(tc.account) + var addrs []waddrmgr.ManagedAddress + err := tc.manager.ForEachAccountAddress(tc.account, + func(maddr waddrmgr.ManagedAddress) error { + addrs = append(addrs, maddr) + return nil + }) if err != nil { tc.t.Errorf("%s: unexpected error: %v", prefix, err) return false @@ -1349,8 +1359,8 @@ func testManagerAPI(tc *testContext) { tc.account = 0 testNewAccount(tc) testLookupAccount(tc) - testAllAccounts(tc) - testAllAccountAddresses(tc) + testForEachAccount(tc) + testForEachAccountAddress(tc) // Rename account 1 "acct-create" tc.account = 1 diff --git a/wallet/wallet.go b/wallet/wallet.go index 912d264..93fe1fa 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -370,7 +370,11 @@ func (w *Wallet) SetChainSynced(synced bool) { // outputs. This is primarely intended to provide the parameters for a // rescan request. func (w *Wallet) activeData() ([]btcutil.Address, []wtxmgr.Credit, error) { - addrs, err := w.Manager.AllActiveAddresses() + var addrs []btcutil.Address + err := w.Manager.ForEachActiveAddress(func(addr btcutil.Address) error { + addrs = append(addrs, addr) + return nil + }) if err != nil { return nil, nil, err } @@ -669,20 +673,23 @@ func (w *Wallet) ChangePassphrase(old, new []byte) error { // a given account. It returns true if atleast one address in the account was // used and false if no address in the account was used. func (w *Wallet) AccountUsed(account uint32) (bool, error) { - addrs, err := w.Manager.AllAccountAddresses(account) - if err != nil { - return false, err + var used bool + var err error + merr := w.Manager.ForEachAccountAddress(account, + func(maddr waddrmgr.ManagedAddress) error { + used, err = maddr.Used() + if err != nil { + return err + } + if used { + return waddrmgr.Break + } + return nil + }) + if merr == waddrmgr.Break { + merr = nil } - for _, addr := range addrs { - used, err := addr.Used() - if err != nil { - return false, err - } - if used { - return true, nil - } - } - return false, nil + return used, merr } // CalculateBalance sums the amounts of all unspent transaction @@ -1203,24 +1210,19 @@ func (w *Wallet) ListUnspent(minconf, maxconf int32, // DumpPrivKeys returns the WIF-encoded private keys for all addresses with // private keys in a wallet. func (w *Wallet) DumpPrivKeys() ([]string, error) { - addrs, err := w.Manager.AllActiveAddresses() - if err != nil { - return nil, err - } - + var privkeys []string // Iterate over each active address, appending the private key to // privkeys. - privkeys := make([]string, 0, len(addrs)) - for _, addr := range addrs { + err := w.Manager.ForEachActiveAddress(func(addr btcutil.Address) error { ma, err := w.Manager.Address(addr) if err != nil { - return nil, err + return err } // Only those addresses with keys needed. pka, ok := ma.(waddrmgr.ManagedPubKeyAddress) if !ok { - continue + return nil } wif, err := pka.ExportPrivKey() @@ -1228,12 +1230,12 @@ func (w *Wallet) DumpPrivKeys() ([]string, error) { // It would be nice to zero out the array here. However, // since strings in go are immutable, and we have no // control over the caller I don't think we can. :( - return nil, err + return err } privkeys = append(privkeys, wif.String()) - } - - return privkeys, nil + return nil + }) + return privkeys, err } // DumpWIFPrivateKey returns the WIF encoded private key for a @@ -1436,16 +1438,15 @@ func (w *Wallet) ResendUnminedTxs() { // SortedActivePaymentAddresses returns a slice of all active payment // addresses in a wallet. func (w *Wallet) SortedActivePaymentAddresses() ([]string, error) { - addrs, err := w.Manager.AllActiveAddresses() + var addrStrs []string + err := w.Manager.ForEachActiveAddress(func(addr btcutil.Address) error { + addrStrs = append(addrStrs, addr.EncodeAddress()) + return nil + }) if err != nil { return nil, err } - addrStrs := make([]string, len(addrs)) - for i, addr := range addrs { - addrStrs[i] = addr.EncodeAddress() - } - sort.Sort(sort.StringSlice(addrStrs)) return addrStrs, nil }