From e64d948093422e8ad11662ff241ba660db73d457 Mon Sep 17 00:00:00 2001 From: Josh Rickmar Date: Tue, 1 Jul 2014 10:02:13 -0500 Subject: [PATCH] Synchronize locking/unlocking of all keystores. This change fixes the asynchronous deferred locking that used to be performed after some timeout after a call to walletpassphrase by managing the locked state of each account in a new account manager goroutine. The timeouts for new unlock requests replace any running timeouts for older requests, rather than allowing previous timeouts to expire before the most recent one. Fixes #105. --- acctmgr.go | 210 +++++++++++++++++++++++++++-------------------- rpcserver.go | 43 ++-------- wallet/wallet.go | 18 ++-- 3 files changed, 139 insertions(+), 132 deletions(-) diff --git a/acctmgr.go b/acctmgr.go index 6e6dd86..7790d9b 100644 --- a/acctmgr.go +++ b/acctmgr.go @@ -31,6 +31,7 @@ import ( "path/filepath" "strings" "sync" + "time" ) // Errors relating to accounts. @@ -44,45 +45,56 @@ var ( // AcctMgr is the global account manager for all opened accounts. var AcctMgr = NewAccountManager() -type openAccountsCmd struct{} +type ( + openAccountsCmd struct{} -type accessAccountRequest struct { - name string - resp chan *Account + accessAccountRequest struct { + name string + resp chan *Account + } + + accessAllRequest struct { + resp chan []*Account + } + + accessAccountByAddressRequest struct { + address string + resp chan *Account + } + + markAddressForAccountCmd struct { + address string + account *Account + } + + addAccountCmd struct { + a *Account + } + + removeAccountCmd struct { + a *Account + } + + quitCmd struct{} +) + +type unlockRequest struct { + passphrase []byte + timeout time.Duration // Zero value prevents the timeout. + err chan error } -type accessAllRequest struct { - resp chan []*Account -} - -type accessAccountByAddressRequest struct { - address string - resp chan *Account -} - -type markAddressForAccountCmd struct { - address string - account *Account -} - -type addAccountCmd struct { - a *Account -} - -type removeAccountCmd struct { - a *Account -} - -type quitCmd struct{} - // AccountManager manages a collection of accounts. type AccountManager struct { // The accounts accessed through the account manager are not safe for // concurrent access. The account manager therefore contains a // binary semaphore channel to prevent incorrect access. - bsem chan struct{} - cmdChan chan interface{} - rescanMsgs chan RescanMsg + bsem chan struct{} + cmdChan chan interface{} + rescanMsgs chan RescanMsg + unlockRequests chan unlockRequest + lockRequests chan struct{} + unlockedState chan bool ds *DiskSyncer rm *RescanManager @@ -94,10 +106,14 @@ type AccountManager struct { // NewAccountManager returns a new AccountManager. func NewAccountManager() *AccountManager { am := &AccountManager{ - bsem: make(chan struct{}, 1), - cmdChan: make(chan interface{}), - rescanMsgs: make(chan RescanMsg, 1), - quit: make(chan struct{}), + bsem: make(chan struct{}, 1), + cmdChan: make(chan interface{}), + rescanMsgs: make(chan RescanMsg, 1), + unlockRequests: make(chan unlockRequest), + lockRequests: make(chan struct{}), + unlockedState: make(chan bool), + + quit: make(chan struct{}), } am.ds = NewDiskSyncer(am) am.rm = NewRescanManager(am.rescanMsgs) @@ -109,8 +125,9 @@ func (am *AccountManager) Start() { // Ready the semaphore - can't grab unless the manager has started. am.bsem <- struct{}{} - am.wg.Add(2) + am.wg.Add(3) go am.accountHandler() + go am.keystoreLocker() go am.rescanListener() go am.ds.Start() @@ -474,6 +491,53 @@ out: am.wg.Done() } +// keystoreLocker manages the lockedness state of all account keystores. +func (am *AccountManager) keystoreLocker() { + unlocked := false + var timeout <-chan time.Time +out: + for { + select { + case req := <-am.unlockRequests: + for _, a := range am.AllAccounts() { + if err := a.Unlock(req.passphrase); err != nil { + req.err <- err + continue out + } + } + unlocked = true + if req.timeout == 0 { + timeout = nil + } else { + timeout = time.After(req.timeout) + } + req.err <- nil + continue + + case am.unlockedState <- unlocked: + continue + + case <-am.quit: + break out + + case <-am.lockRequests: + case <-timeout: + } + + // Select statement fell through by an explicit lock or the + // timer expiring. Lock the keystores here. + timeout = nil + for _, a := range am.AllAccounts() { + if err := a.Lock(); err != nil { + log.Errorf("Could not lock wallet for account '%s': %v", + a.name, err) + } + } + unlocked = false + } + am.wg.Done() +} + // rescanListener listens for messages from the rescan manager and marks // accounts and addresses as synced. func (am *AccountManager) rescanListener() { @@ -782,75 +846,43 @@ func (am *AccountManager) CreateEncryptedWallet(passphrase []byte) error { // ChangePassphrase unlocks all account wallets with the old // passphrase, and re-encrypts each using the new passphrase. func (am *AccountManager) ChangePassphrase(old, new []byte) error { - accts := am.AllAccounts() - - for _, a := range accts { - if !a.IsLocked() { - if err := a.Wallet.Lock(); err != nil { - return err - } - } - - if err := a.Wallet.Unlock(old); err != nil { - return err - } - defer func(a *Account) { - if err := a.Lock(); err != nil { - log.Warnf("Cannot lock account: %v", err) - } - }(a) + // Keystores must be unlocked to change their passphrase. + err := am.UnlockWallets(old, 0) + if err != nil { + return err } + accts := am.AllAccounts() + // Change passphrase for each unlocked wallet. for _, a := range accts { - if err := a.Wallet.ChangePassphrase(new); err != nil { + err = a.Wallet.ChangePassphrase(new) + if err != nil { return err } } + am.LockWallets() + // Immediately write out to disk. return am.ds.WriteBatch(accts) } // LockWallets locks all managed account wallets. -func (am *AccountManager) LockWallets() error { - for _, a := range am.AllAccounts() { - if err := a.Lock(); err != nil { - return err - } - } - - return nil +func (am *AccountManager) LockWallets() { + am.lockRequests <- struct{}{} } -// UnlockWallets unlocks all managed account's wallets. If any wallet unlocks -// fail, all successfully unlocked wallets are locked again. -func (am *AccountManager) UnlockWallets(passphrase string) (err error) { - accts := am.AllAccounts() - - unlockedAccts := make([]*Account, 0, len(accts)) - defer func() { - // Lock all account wallets unlocked during this call - // if any of the unlocks failed. - if err != nil { - for _, ua := range unlockedAccts { - if err := ua.Lock(); err != nil { - log.Warnf("Cannot lock account '%s': %v", - ua.name, err) - } - } - } - }() - - for _, a := range accts { - if uErr := a.Unlock([]byte(passphrase)); uErr != nil { - err = fmt.Errorf("cannot unlock account %v: %v", - a.name, uErr) - return - } - unlockedAccts = append(unlockedAccts, a) +// UnlockWallets unlocks all managed account's wallets, locking them again after +// the timeout expires, or resetting a previous timeout if one is still running. +func (am *AccountManager) UnlockWallets(passphrase []byte, timeout time.Duration) error { + req := unlockRequest{ + passphrase: passphrase, + timeout: timeout, + err: make(chan error, 1), } - return + am.unlockRequests <- req + return <-req.err } // DumpKeys returns all WIF-encoded private keys associated with all diff --git a/rpcserver.go b/rpcserver.go index f6bce43..ed1c7e4 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -1675,7 +1675,7 @@ func ListAccounts(icmd btcjson.Cmd) (interface{}, error) { return AcctMgr.ListAccounts(cmd.MinConf), nil } -// ListLockUnspent handles a listlockunspent request by returning an array of +// ListLockUnspent handles a listlockunspent request by returning an slice of // all locked outpoints. func ListLockUnspent(icmd btcjson.Cmd) (interface{}, error) { // Due to our poor account support, this assumes only the default @@ -2691,32 +2691,17 @@ func VerifyMessage(icmd btcjson.Cmd) (interface{}, error) { // WalletIsLocked handles the walletislocked extension request by // returning the current lock state (false for unlocked, true for locked) -// of an account. An error is returned if the requested account does not -// exist. +// of an account. func WalletIsLocked(icmd btcjson.Cmd) (interface{}, error) { - // Type assert icmd to access parameters. - cmd, ok := icmd.(*btcws.WalletIsLockedCmd) - if !ok { - return nil, btcjson.ErrInternal - } - - a, err := AcctMgr.Account(cmd.Account) - if err != nil { - if err == ErrNotFound { - return nil, btcjson.ErrWalletInvalidAccountName - } - return nil, err - } - - return a.Wallet.IsLocked(), nil + return !<-AcctMgr.unlockedState, nil } // WalletLock handles a walletlock request by locking the all account // wallets, returning an error if any wallet is not encrypted (for example, // a watching-only wallet). func WalletLock(icmd btcjson.Cmd) (interface{}, error) { - err := AcctMgr.LockWallets() - return nil, err + AcctMgr.LockWallets() + return nil, nil } // WalletPassphrase responds to the walletpassphrase request by unlocking @@ -2729,21 +2714,9 @@ func WalletPassphrase(icmd btcjson.Cmd) (interface{}, error) { return nil, btcjson.ErrInternal } - if err := AcctMgr.UnlockWallets(cmd.Passphrase); err != nil { - return nil, err - } - - go func(timeout int64) { - time.Sleep(time.Second * time.Duration(timeout)) - AcctMgr.Grab() - defer AcctMgr.Release() - err := AcctMgr.LockWallets() - if err != nil { - log.Warnf("Cannot lock account wallets: %v", err) - } - }(cmd.Timeout) - - return nil, nil + timeout := time.Second * time.Duration(cmd.Timeout) + err := AcctMgr.UnlockWallets([]byte(cmd.Passphrase), timeout) + return nil, err } // WalletPassphraseChange responds to the walletpassphrasechange request diff --git a/wallet/wallet.go b/wallet/wallet.go index b5674b1..d4a6bf2 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -2367,14 +2367,6 @@ func (a *btcAddress) unlock(key []byte) (privKeyCT []byte, err error) { return nil, errors.New("unable to unlock unencrypted address") } - // If secret is already saved, return a copy without performing a full - // unlock. - if len(a.privKeyCT) == 32 { - privKeyCT := make([]byte, 32) - copy(privKeyCT, a.privKeyCT) - return privKeyCT, nil - } - // Decrypt private key with AES key. aesBlockDecrypter, err := aes.NewCipher(key) if err != nil { @@ -2384,6 +2376,16 @@ func (a *btcAddress) unlock(key []byte) (privKeyCT []byte, err error) { privkey := make([]byte, 32) aesDecrypter.XORKeyStream(privkey, a.privKey[:]) + // If secret is already saved, simply compare the bytes. + if len(a.privKeyCT) == 32 { + if !bytes.Equal(a.privKeyCT, privkey) { + return nil, ErrWrongPassphrase + } + privKeyCT := make([]byte, 32) + copy(privKeyCT, a.privKeyCT) + return privKeyCT, nil + } + x, y := btcec.S256().ScalarBaseMult(privkey) if x.Cmp(a.pubKey.X) != 0 || y.Cmp(a.pubKey.Y) != 0 { return nil, ErrWrongPassphrase