Merge pull request #734 from bhandras/external_db

wallet: allow using external wallet db
This commit is contained in:
Olaoluwa Osuntokun 2021-04-29 15:48:04 -07:00 committed by GitHub
commit a7a9234968
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 135 additions and 43 deletions

View file

@ -1813,8 +1813,6 @@ func testSync(tc *testContext) bool {
// It makes use of a test context because the address manager is persistent and // It makes use of a test context because the address manager is persistent and
// much of the testing involves having specific state. // much of the testing involves having specific state.
func TestManager(t *testing.T) { func TestManager(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
createdWatchingOnly bool createdWatchingOnly bool

View file

@ -6,6 +6,7 @@ package wallet
import ( import (
"errors" "errors"
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
@ -55,6 +56,9 @@ type Loader struct {
timeout time.Duration timeout time.Duration
recoveryWindow uint32 recoveryWindow uint32
wallet *Wallet wallet *Wallet
localDB bool
walletExists func() (bool, error)
walletCreated func(db walletdb.ReadWriteTx) error
db walletdb.DB db walletdb.DB
mu sync.Mutex mu sync.Mutex
} }
@ -72,18 +76,42 @@ func NewLoader(chainParams *chaincfg.Params, dbDirPath string,
noFreelistSync: noFreelistSync, noFreelistSync: noFreelistSync,
timeout: timeout, timeout: timeout,
recoveryWindow: recoveryWindow, recoveryWindow: recoveryWindow,
localDB: true,
} }
} }
// NewLoaderWithDB constructs a Loader with an externally provided DB. This way
// users are free to use their own walletdb implementation (eg. leveldb, etcd)
// to store the wallet. Given that the external DB may be shared an additional
// function is also passed which will override Loader.WalletExists().
func NewLoaderWithDB(chainParams *chaincfg.Params, recoveryWindow uint32,
db walletdb.DB, walletExists func() (bool, error)) (*Loader, error) {
if db == nil {
return nil, fmt.Errorf("no DB provided")
}
if walletExists == nil {
return nil, fmt.Errorf("unable to check if wallet exists")
}
return &Loader{
chainParams: chainParams,
recoveryWindow: recoveryWindow,
localDB: false,
walletExists: walletExists,
db: db,
}, nil
}
// onLoaded executes each added callback and prevents loader from loading any // onLoaded executes each added callback and prevents loader from loading any
// additional wallets. Requires mutex to be locked. // additional wallets. Requires mutex to be locked.
func (l *Loader) onLoaded(w *Wallet, db walletdb.DB) { func (l *Loader) onLoaded(w *Wallet) {
for _, fn := range l.callbacks { for _, fn := range l.callbacks {
fn(w) fn(w)
} }
l.wallet = w l.wallet = w
l.db = db
l.callbacks = nil // not needed anymore l.callbacks = nil // not needed anymore
} }
@ -102,6 +130,15 @@ func (l *Loader) RunAfterLoad(fn func(*Wallet)) {
} }
} }
// OnWalletCreated adds a function that will be executed the wallet structure
// is initialized in the wallet database. This is useful if users want to add
// extra fields in the same transaction (eg. to flag wallet existence).
func (l *Loader) OnWalletCreated(fn func(walletdb.ReadWriteTx) error) {
l.mu.Lock()
defer l.mu.Unlock()
l.walletCreated = fn
}
// CreateNewWallet creates a new wallet using the provided public and private // CreateNewWallet creates a new wallet using the provided public and private
// passphrases. The seed is optional. If non-nil, addresses are derived from // passphrases. The seed is optional. If non-nil, addresses are derived from
// this seed. If nil, a secure random seed is generated. // this seed. If nil, a secure random seed is generated.
@ -134,8 +171,7 @@ func (l *Loader) createNewWallet(pubPassphrase, privPassphrase,
return nil, ErrLoaded return nil, ErrLoaded
} }
dbPath := filepath.Join(l.dbDirPath, WalletDBName) exists, err := l.WalletExists()
exists, err := fileExists(dbPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -143,25 +179,35 @@ func (l *Loader) createNewWallet(pubPassphrase, privPassphrase,
return nil, ErrExists return nil, ErrExists
} }
if l.localDB {
dbPath := filepath.Join(l.dbDirPath, WalletDBName)
// Create the wallet database backed by bolt db. // Create the wallet database backed by bolt db.
err = os.MkdirAll(l.dbDirPath, 0700) err = os.MkdirAll(l.dbDirPath, 0700)
if err != nil { if err != nil {
return nil, err return nil, err
} }
db, err := walletdb.Create("bdb", dbPath, l.noFreelistSync, l.timeout) l.db, err = walletdb.Create(
"bdb", dbPath, l.noFreelistSync, l.timeout,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
// Initialize the newly created database for the wallet before opening. // Initialize the newly created database for the wallet before opening.
if isWatchingOnly { if isWatchingOnly {
err = CreateWatchingOnly(db, pubPassphrase, l.chainParams, bday) err := CreateWatchingOnlyWithCallback(
l.db, pubPassphrase, l.chainParams, bday,
l.walletCreated,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
err = Create( err := CreateWithCallback(
db, pubPassphrase, privPassphrase, seed, l.chainParams, bday, l.db, pubPassphrase, privPassphrase, seed,
l.chainParams, bday, l.walletCreated,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -169,13 +215,13 @@ func (l *Loader) createNewWallet(pubPassphrase, privPassphrase,
} }
// Open the newly-created wallet. // Open the newly-created wallet.
w, err := Open(db, pubPassphrase, nil, l.chainParams, l.recoveryWindow) w, err := Open(l.db, pubPassphrase, nil, l.chainParams, l.recoveryWindow)
if err != nil { if err != nil {
return nil, err return nil, err
} }
w.Start() w.Start()
l.onLoaded(w, db) l.onLoaded(w)
return w, nil return w, nil
} }
@ -197,18 +243,23 @@ func (l *Loader) OpenExistingWallet(pubPassphrase []byte, canConsolePrompt bool)
return nil, ErrLoaded return nil, ErrLoaded
} }
if l.localDB {
var err error
// Ensure that the network directory exists. // Ensure that the network directory exists.
if err := checkCreateDir(l.dbDirPath); err != nil { if err = checkCreateDir(l.dbDirPath); err != nil {
return nil, err return nil, err
} }
// Open the database using the boltdb backend. // Open the database using the boltdb backend.
dbPath := filepath.Join(l.dbDirPath, WalletDBName) dbPath := filepath.Join(l.dbDirPath, WalletDBName)
db, err := walletdb.Open("bdb", dbPath, l.noFreelistSync, l.timeout) l.db, err = walletdb.Open(
"bdb", dbPath, l.noFreelistSync, l.timeout,
)
if err != nil { if err != nil {
log.Errorf("Failed to open database: %v", err) log.Errorf("Failed to open database: %v", err)
return nil, err return nil, err
} }
}
var cbs *waddrmgr.OpenCallbacks var cbs *waddrmgr.OpenCallbacks
if canConsolePrompt { if canConsolePrompt {
@ -222,28 +273,35 @@ func (l *Loader) OpenExistingWallet(pubPassphrase []byte, canConsolePrompt bool)
ObtainPrivatePass: noConsole, ObtainPrivatePass: noConsole,
} }
} }
w, err := Open(db, pubPassphrase, cbs, l.chainParams, l.recoveryWindow) w, err := Open(l.db, pubPassphrase, cbs, l.chainParams, l.recoveryWindow)
if err != nil { if err != nil {
// If opening the wallet fails (e.g. because of wrong // If opening the wallet fails (e.g. because of wrong
// passphrase), we must close the backing database to // passphrase), we must close the backing database to
// allow future calls to walletdb.Open(). // allow future calls to walletdb.Open().
e := db.Close() if l.localDB {
e := l.db.Close()
if e != nil { if e != nil {
log.Warnf("Error closing database: %v", e) log.Warnf("Error closing database: %v", e)
} }
}
return nil, err return nil, err
} }
w.Start() w.Start()
l.onLoaded(w, db) l.onLoaded(w)
return w, nil return w, nil
} }
// WalletExists returns whether a file exists at the loader's database path. // WalletExists returns whether a file exists at the loader's database path.
// This may return an error for unexpected I/O failures. // This may return an error for unexpected I/O failures.
func (l *Loader) WalletExists() (bool, error) { func (l *Loader) WalletExists() (bool, error) {
if l.localDB {
dbPath := filepath.Join(l.dbDirPath, WalletDBName) dbPath := filepath.Join(l.dbDirPath, WalletDBName)
return fileExists(dbPath) return fileExists(dbPath)
}
return l.walletExists()
} }
// LoadedWallet returns the loaded wallet, if any, and a bool for whether the // LoadedWallet returns the loaded wallet, if any, and a bool for whether the
@ -270,10 +328,12 @@ func (l *Loader) UnloadWallet() error {
l.wallet.Stop() l.wallet.Stop()
l.wallet.WaitForShutdown() l.wallet.WaitForShutdown()
if l.localDB {
err := l.db.Close() err := l.db.Close()
if err != nil { if err != nil {
return err return err
} }
}
l.wallet = nil l.wallet = nil
l.db = nil l.db = nil

View file

@ -3665,6 +3665,29 @@ func (w *Wallet) Database() walletdb.DB {
return w.db return w.db
} }
// CreateWithCallback is the same as Create with an added callback that will be
// called in the same transaction the wallet structure is initialized.
func CreateWithCallback(db walletdb.DB, pubPass, privPass, seed []byte,
params *chaincfg.Params, birthday time.Time,
cb func(walletdb.ReadWriteTx) error) error {
return create(
db, pubPass, privPass, seed, params, birthday, false, cb,
)
}
// CreateWatchingOnlyWithCallback is the same as CreateWatchingOnly with an
// added callback that will be called in the same transaction the wallet
// structure is initialized.
func CreateWatchingOnlyWithCallback(db walletdb.DB, pubPass []byte,
params *chaincfg.Params, birthday time.Time,
cb func(walletdb.ReadWriteTx) error) error {
return create(
db, pubPass, nil, nil, params, birthday, true, cb,
)
}
// Create creates an new wallet, writing it to an empty database. If the passed // Create creates an new wallet, writing it to an empty database. If the passed
// seed is non-nil, it is used. Otherwise, a secure random seed of the // seed is non-nil, it is used. Otherwise, a secure random seed of the
// recommended length is generated. // recommended length is generated.
@ -3672,7 +3695,7 @@ func Create(db walletdb.DB, pubPass, privPass, seed []byte,
params *chaincfg.Params, birthday time.Time) error { params *chaincfg.Params, birthday time.Time) error {
return create( return create(
db, pubPass, privPass, seed, params, birthday, false, db, pubPass, privPass, seed, params, birthday, false, nil,
) )
} }
@ -3684,12 +3707,13 @@ func CreateWatchingOnly(db walletdb.DB, pubPass []byte,
params *chaincfg.Params, birthday time.Time) error { params *chaincfg.Params, birthday time.Time) error {
return create( return create(
db, pubPass, nil, nil, params, birthday, true, db, pubPass, nil, nil, params, birthday, true, nil,
) )
} }
func create(db walletdb.DB, pubPass, privPass, seed []byte, func create(db walletdb.DB, pubPass, privPass, seed []byte,
params *chaincfg.Params, birthday time.Time, isWatchingOnly bool) error { params *chaincfg.Params, birthday time.Time, isWatchingOnly bool,
cb func(walletdb.ReadWriteTx) error) error {
if !isWatchingOnly { if !isWatchingOnly {
// If a seed was provided, ensure that it is of valid length. Otherwise, // If a seed was provided, ensure that it is of valid length. Otherwise,
@ -3725,7 +3749,17 @@ func create(db walletdb.DB, pubPass, privPass, seed []byte,
if err != nil { if err != nil {
return err return err
} }
return wtxmgr.Create(txmgrNs)
err = wtxmgr.Create(txmgrNs)
if err != nil {
return err
}
if cb != nil {
return cb(tx)
}
return nil
}) })
} }