From 98ba16748e6ca077d53dab638decd1cca07409d2 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Wed, 28 Apr 2021 16:43:59 +0200 Subject: [PATCH] loader: add txn callback when wallet is created --- wallet/loader.go | 17 ++++++++++++++--- wallet/wallet.go | 42 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 52 insertions(+), 7 deletions(-) diff --git a/wallet/loader.go b/wallet/loader.go index 32f3bd5..89812a1 100644 --- a/wallet/loader.go +++ b/wallet/loader.go @@ -58,6 +58,7 @@ type Loader struct { wallet *Wallet localDB bool walletExists func() (bool, error) + walletCreated func(db walletdb.ReadWriteTx) error db walletdb.DB mu sync.Mutex } @@ -129,6 +130,15 @@ func (l *Loader) RunAfterLoad(fn func(*Wallet)) { } } +// OnWalletCreated adds a function that will be executed the wallet structure +// is initialized in the wallet database. This is useful if users want to add +// extra fields in the same transaction (eg. to flag wallet existence). +func (l *Loader) OnWalletCreated(fn func(walletdb.ReadWriteTx) error) { + l.mu.Lock() + defer l.mu.Unlock() + l.walletCreated = fn +} + // CreateNewWallet creates a new wallet using the provided public and private // passphrases. The seed is optional. If non-nil, addresses are derived from // this seed. If nil, a secure random seed is generated. @@ -187,16 +197,17 @@ func (l *Loader) createNewWallet(pubPassphrase, privPassphrase, // Initialize the newly created database for the wallet before opening. if isWatchingOnly { - err := CreateWatchingOnly( + err := CreateWatchingOnlyWithCallback( l.db, pubPassphrase, l.chainParams, bday, + l.walletCreated, ) if err != nil { return nil, err } } else { - err := Create( + err := CreateWithCallback( l.db, pubPassphrase, privPassphrase, seed, - l.chainParams, bday, + l.chainParams, bday, l.walletCreated, ) if err != nil { return nil, err diff --git a/wallet/wallet.go b/wallet/wallet.go index 549199f..7e85a89 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -3665,6 +3665,29 @@ func (w *Wallet) Database() walletdb.DB { return w.db } +// CreateWithCallback is the same as Create with an added callback that will be +// called in the same transaction the wallet structure is initialized. +func CreateWithCallback(db walletdb.DB, pubPass, privPass, seed []byte, + params *chaincfg.Params, birthday time.Time, + cb func(walletdb.ReadWriteTx) error) error { + + return create( + db, pubPass, privPass, seed, params, birthday, false, cb, + ) +} + +// CreateWatchingOnlyWithCallback is the same as CreateWatchingOnly with an +// added callback that will be called in the same transaction the wallet +// structure is initialized. +func CreateWatchingOnlyWithCallback(db walletdb.DB, pubPass []byte, + params *chaincfg.Params, birthday time.Time, + cb func(walletdb.ReadWriteTx) error) error { + + return create( + db, pubPass, nil, nil, params, birthday, true, cb, + ) +} + // Create creates an new wallet, writing it to an empty database. If the passed // seed is non-nil, it is used. Otherwise, a secure random seed of the // recommended length is generated. @@ -3672,7 +3695,7 @@ func Create(db walletdb.DB, pubPass, privPass, seed []byte, params *chaincfg.Params, birthday time.Time) error { return create( - db, pubPass, privPass, seed, params, birthday, false, + db, pubPass, privPass, seed, params, birthday, false, nil, ) } @@ -3684,12 +3707,13 @@ func CreateWatchingOnly(db walletdb.DB, pubPass []byte, params *chaincfg.Params, birthday time.Time) error { return create( - db, pubPass, nil, nil, params, birthday, true, + db, pubPass, nil, nil, params, birthday, true, nil, ) } func create(db walletdb.DB, pubPass, privPass, seed []byte, - params *chaincfg.Params, birthday time.Time, isWatchingOnly bool) error { + params *chaincfg.Params, birthday time.Time, isWatchingOnly bool, + cb func(walletdb.ReadWriteTx) error) error { if !isWatchingOnly { // If a seed was provided, ensure that it is of valid length. Otherwise, @@ -3725,7 +3749,17 @@ func create(db walletdb.DB, pubPass, privPass, seed []byte, if err != nil { return err } - return wtxmgr.Create(txmgrNs) + + err = wtxmgr.Create(txmgrNs) + if err != nil { + return err + } + + if cb != nil { + return cb(tx) + } + + return nil }) }