loader: add txn callback when wallet is created

This commit is contained in:
Andras Banki-Horvath 2021-04-28 16:43:59 +02:00
parent a795db6b12
commit 98ba16748e
No known key found for this signature in database
GPG key ID: 80E5375C094198D8
2 changed files with 52 additions and 7 deletions

View file

@ -58,6 +58,7 @@ type Loader struct {
wallet *Wallet wallet *Wallet
localDB bool localDB bool
walletExists func() (bool, error) walletExists func() (bool, error)
walletCreated func(db walletdb.ReadWriteTx) error
db walletdb.DB db walletdb.DB
mu sync.Mutex mu sync.Mutex
} }
@ -129,6 +130,15 @@ func (l *Loader) RunAfterLoad(fn func(*Wallet)) {
} }
} }
// OnWalletCreated adds a function that will be executed the wallet structure
// is initialized in the wallet database. This is useful if users want to add
// extra fields in the same transaction (eg. to flag wallet existence).
func (l *Loader) OnWalletCreated(fn func(walletdb.ReadWriteTx) error) {
l.mu.Lock()
defer l.mu.Unlock()
l.walletCreated = fn
}
// CreateNewWallet creates a new wallet using the provided public and private // CreateNewWallet creates a new wallet using the provided public and private
// passphrases. The seed is optional. If non-nil, addresses are derived from // passphrases. The seed is optional. If non-nil, addresses are derived from
// this seed. If nil, a secure random seed is generated. // this seed. If nil, a secure random seed is generated.
@ -187,16 +197,17 @@ func (l *Loader) createNewWallet(pubPassphrase, privPassphrase,
// Initialize the newly created database for the wallet before opening. // Initialize the newly created database for the wallet before opening.
if isWatchingOnly { if isWatchingOnly {
err := CreateWatchingOnly( err := CreateWatchingOnlyWithCallback(
l.db, pubPassphrase, l.chainParams, bday, l.db, pubPassphrase, l.chainParams, bday,
l.walletCreated,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
err := Create( err := CreateWithCallback(
l.db, pubPassphrase, privPassphrase, seed, l.db, pubPassphrase, privPassphrase, seed,
l.chainParams, bday, l.chainParams, bday, l.walletCreated,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -3665,6 +3665,29 @@ func (w *Wallet) Database() walletdb.DB {
return w.db return w.db
} }
// CreateWithCallback is the same as Create with an added callback that will be
// called in the same transaction the wallet structure is initialized.
func CreateWithCallback(db walletdb.DB, pubPass, privPass, seed []byte,
params *chaincfg.Params, birthday time.Time,
cb func(walletdb.ReadWriteTx) error) error {
return create(
db, pubPass, privPass, seed, params, birthday, false, cb,
)
}
// CreateWatchingOnlyWithCallback is the same as CreateWatchingOnly with an
// added callback that will be called in the same transaction the wallet
// structure is initialized.
func CreateWatchingOnlyWithCallback(db walletdb.DB, pubPass []byte,
params *chaincfg.Params, birthday time.Time,
cb func(walletdb.ReadWriteTx) error) error {
return create(
db, pubPass, nil, nil, params, birthday, true, cb,
)
}
// Create creates an new wallet, writing it to an empty database. If the passed // Create creates an new wallet, writing it to an empty database. If the passed
// seed is non-nil, it is used. Otherwise, a secure random seed of the // seed is non-nil, it is used. Otherwise, a secure random seed of the
// recommended length is generated. // recommended length is generated.
@ -3672,7 +3695,7 @@ func Create(db walletdb.DB, pubPass, privPass, seed []byte,
params *chaincfg.Params, birthday time.Time) error { params *chaincfg.Params, birthday time.Time) error {
return create( return create(
db, pubPass, privPass, seed, params, birthday, false, db, pubPass, privPass, seed, params, birthday, false, nil,
) )
} }
@ -3684,12 +3707,13 @@ func CreateWatchingOnly(db walletdb.DB, pubPass []byte,
params *chaincfg.Params, birthday time.Time) error { params *chaincfg.Params, birthday time.Time) error {
return create( return create(
db, pubPass, nil, nil, params, birthday, true, db, pubPass, nil, nil, params, birthday, true, nil,
) )
} }
func create(db walletdb.DB, pubPass, privPass, seed []byte, func create(db walletdb.DB, pubPass, privPass, seed []byte,
params *chaincfg.Params, birthday time.Time, isWatchingOnly bool) error { params *chaincfg.Params, birthday time.Time, isWatchingOnly bool,
cb func(walletdb.ReadWriteTx) error) error {
if !isWatchingOnly { if !isWatchingOnly {
// If a seed was provided, ensure that it is of valid length. Otherwise, // If a seed was provided, ensure that it is of valid length. Otherwise,
@ -3725,7 +3749,17 @@ func create(db walletdb.DB, pubPass, privPass, seed []byte,
if err != nil { if err != nil {
return err return err
} }
return wtxmgr.Create(txmgrNs)
err = wtxmgr.Create(txmgrNs)
if err != nil {
return err
}
if cb != nil {
return cb(tx)
}
return nil
}) })
} }