Synchronize locking/unlocking of all keystores.

This change fixes the asynchronous deferred locking that used to be
performed after some timeout after a call to walletpassphrase by
managing the locked state of each account in a new account manager
goroutine.  The timeouts for new unlock requests replace any running
timeouts for older requests, rather than allowing previous timeouts to
expire before the most recent one.

Fixes #105.
This commit is contained in:
Josh Rickmar 2014-07-01 10:02:13 -05:00
parent 478a7ec867
commit e64d948093
3 changed files with 139 additions and 132 deletions

View file

@ -31,6 +31,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
"time"
) )
// Errors relating to accounts. // Errors relating to accounts.
@ -44,45 +45,56 @@ var (
// AcctMgr is the global account manager for all opened accounts. // AcctMgr is the global account manager for all opened accounts.
var AcctMgr = NewAccountManager() var AcctMgr = NewAccountManager()
type openAccountsCmd struct{} type (
openAccountsCmd struct{}
type accessAccountRequest struct { accessAccountRequest struct {
name string name string
resp chan *Account resp chan *Account
}
accessAllRequest struct {
resp chan []*Account
}
accessAccountByAddressRequest struct {
address string
resp chan *Account
}
markAddressForAccountCmd struct {
address string
account *Account
}
addAccountCmd struct {
a *Account
}
removeAccountCmd struct {
a *Account
}
quitCmd struct{}
)
type unlockRequest struct {
passphrase []byte
timeout time.Duration // Zero value prevents the timeout.
err chan error
} }
type accessAllRequest struct {
resp chan []*Account
}
type accessAccountByAddressRequest struct {
address string
resp chan *Account
}
type markAddressForAccountCmd struct {
address string
account *Account
}
type addAccountCmd struct {
a *Account
}
type removeAccountCmd struct {
a *Account
}
type quitCmd struct{}
// AccountManager manages a collection of accounts. // AccountManager manages a collection of accounts.
type AccountManager struct { type AccountManager struct {
// The accounts accessed through the account manager are not safe for // The accounts accessed through the account manager are not safe for
// concurrent access. The account manager therefore contains a // concurrent access. The account manager therefore contains a
// binary semaphore channel to prevent incorrect access. // binary semaphore channel to prevent incorrect access.
bsem chan struct{} bsem chan struct{}
cmdChan chan interface{} cmdChan chan interface{}
rescanMsgs chan RescanMsg rescanMsgs chan RescanMsg
unlockRequests chan unlockRequest
lockRequests chan struct{}
unlockedState chan bool
ds *DiskSyncer ds *DiskSyncer
rm *RescanManager rm *RescanManager
@ -94,10 +106,14 @@ type AccountManager struct {
// NewAccountManager returns a new AccountManager. // NewAccountManager returns a new AccountManager.
func NewAccountManager() *AccountManager { func NewAccountManager() *AccountManager {
am := &AccountManager{ am := &AccountManager{
bsem: make(chan struct{}, 1), bsem: make(chan struct{}, 1),
cmdChan: make(chan interface{}), cmdChan: make(chan interface{}),
rescanMsgs: make(chan RescanMsg, 1), rescanMsgs: make(chan RescanMsg, 1),
quit: make(chan struct{}), unlockRequests: make(chan unlockRequest),
lockRequests: make(chan struct{}),
unlockedState: make(chan bool),
quit: make(chan struct{}),
} }
am.ds = NewDiskSyncer(am) am.ds = NewDiskSyncer(am)
am.rm = NewRescanManager(am.rescanMsgs) am.rm = NewRescanManager(am.rescanMsgs)
@ -109,8 +125,9 @@ func (am *AccountManager) Start() {
// Ready the semaphore - can't grab unless the manager has started. // Ready the semaphore - can't grab unless the manager has started.
am.bsem <- struct{}{} am.bsem <- struct{}{}
am.wg.Add(2) am.wg.Add(3)
go am.accountHandler() go am.accountHandler()
go am.keystoreLocker()
go am.rescanListener() go am.rescanListener()
go am.ds.Start() go am.ds.Start()
@ -474,6 +491,53 @@ out:
am.wg.Done() am.wg.Done()
} }
// keystoreLocker manages the lockedness state of all account keystores.
func (am *AccountManager) keystoreLocker() {
unlocked := false
var timeout <-chan time.Time
out:
for {
select {
case req := <-am.unlockRequests:
for _, a := range am.AllAccounts() {
if err := a.Unlock(req.passphrase); err != nil {
req.err <- err
continue out
}
}
unlocked = true
if req.timeout == 0 {
timeout = nil
} else {
timeout = time.After(req.timeout)
}
req.err <- nil
continue
case am.unlockedState <- unlocked:
continue
case <-am.quit:
break out
case <-am.lockRequests:
case <-timeout:
}
// Select statement fell through by an explicit lock or the
// timer expiring. Lock the keystores here.
timeout = nil
for _, a := range am.AllAccounts() {
if err := a.Lock(); err != nil {
log.Errorf("Could not lock wallet for account '%s': %v",
a.name, err)
}
}
unlocked = false
}
am.wg.Done()
}
// rescanListener listens for messages from the rescan manager and marks // rescanListener listens for messages from the rescan manager and marks
// accounts and addresses as synced. // accounts and addresses as synced.
func (am *AccountManager) rescanListener() { func (am *AccountManager) rescanListener() {
@ -782,75 +846,43 @@ func (am *AccountManager) CreateEncryptedWallet(passphrase []byte) error {
// ChangePassphrase unlocks all account wallets with the old // ChangePassphrase unlocks all account wallets with the old
// passphrase, and re-encrypts each using the new passphrase. // passphrase, and re-encrypts each using the new passphrase.
func (am *AccountManager) ChangePassphrase(old, new []byte) error { func (am *AccountManager) ChangePassphrase(old, new []byte) error {
accts := am.AllAccounts() // Keystores must be unlocked to change their passphrase.
err := am.UnlockWallets(old, 0)
for _, a := range accts { if err != nil {
if !a.IsLocked() { return err
if err := a.Wallet.Lock(); err != nil {
return err
}
}
if err := a.Wallet.Unlock(old); err != nil {
return err
}
defer func(a *Account) {
if err := a.Lock(); err != nil {
log.Warnf("Cannot lock account: %v", err)
}
}(a)
} }
accts := am.AllAccounts()
// Change passphrase for each unlocked wallet. // Change passphrase for each unlocked wallet.
for _, a := range accts { for _, a := range accts {
if err := a.Wallet.ChangePassphrase(new); err != nil { err = a.Wallet.ChangePassphrase(new)
if err != nil {
return err return err
} }
} }
am.LockWallets()
// Immediately write out to disk. // Immediately write out to disk.
return am.ds.WriteBatch(accts) return am.ds.WriteBatch(accts)
} }
// LockWallets locks all managed account wallets. // LockWallets locks all managed account wallets.
func (am *AccountManager) LockWallets() error { func (am *AccountManager) LockWallets() {
for _, a := range am.AllAccounts() { am.lockRequests <- struct{}{}
if err := a.Lock(); err != nil {
return err
}
}
return nil
} }
// UnlockWallets unlocks all managed account's wallets. If any wallet unlocks // UnlockWallets unlocks all managed account's wallets, locking them again after
// fail, all successfully unlocked wallets are locked again. // the timeout expires, or resetting a previous timeout if one is still running.
func (am *AccountManager) UnlockWallets(passphrase string) (err error) { func (am *AccountManager) UnlockWallets(passphrase []byte, timeout time.Duration) error {
accts := am.AllAccounts() req := unlockRequest{
passphrase: passphrase,
unlockedAccts := make([]*Account, 0, len(accts)) timeout: timeout,
defer func() { err: make(chan error, 1),
// Lock all account wallets unlocked during this call
// if any of the unlocks failed.
if err != nil {
for _, ua := range unlockedAccts {
if err := ua.Lock(); err != nil {
log.Warnf("Cannot lock account '%s': %v",
ua.name, err)
}
}
}
}()
for _, a := range accts {
if uErr := a.Unlock([]byte(passphrase)); uErr != nil {
err = fmt.Errorf("cannot unlock account %v: %v",
a.name, uErr)
return
}
unlockedAccts = append(unlockedAccts, a)
} }
return am.unlockRequests <- req
return <-req.err
} }
// DumpKeys returns all WIF-encoded private keys associated with all // DumpKeys returns all WIF-encoded private keys associated with all

View file

@ -1675,7 +1675,7 @@ func ListAccounts(icmd btcjson.Cmd) (interface{}, error) {
return AcctMgr.ListAccounts(cmd.MinConf), nil return AcctMgr.ListAccounts(cmd.MinConf), nil
} }
// ListLockUnspent handles a listlockunspent request by returning an array of // ListLockUnspent handles a listlockunspent request by returning an slice of
// all locked outpoints. // all locked outpoints.
func ListLockUnspent(icmd btcjson.Cmd) (interface{}, error) { func ListLockUnspent(icmd btcjson.Cmd) (interface{}, error) {
// Due to our poor account support, this assumes only the default // Due to our poor account support, this assumes only the default
@ -2691,32 +2691,17 @@ func VerifyMessage(icmd btcjson.Cmd) (interface{}, error) {
// WalletIsLocked handles the walletislocked extension request by // WalletIsLocked handles the walletislocked extension request by
// returning the current lock state (false for unlocked, true for locked) // returning the current lock state (false for unlocked, true for locked)
// of an account. An error is returned if the requested account does not // of an account.
// exist.
func WalletIsLocked(icmd btcjson.Cmd) (interface{}, error) { func WalletIsLocked(icmd btcjson.Cmd) (interface{}, error) {
// Type assert icmd to access parameters. return !<-AcctMgr.unlockedState, nil
cmd, ok := icmd.(*btcws.WalletIsLockedCmd)
if !ok {
return nil, btcjson.ErrInternal
}
a, err := AcctMgr.Account(cmd.Account)
if err != nil {
if err == ErrNotFound {
return nil, btcjson.ErrWalletInvalidAccountName
}
return nil, err
}
return a.Wallet.IsLocked(), nil
} }
// WalletLock handles a walletlock request by locking the all account // WalletLock handles a walletlock request by locking the all account
// wallets, returning an error if any wallet is not encrypted (for example, // wallets, returning an error if any wallet is not encrypted (for example,
// a watching-only wallet). // a watching-only wallet).
func WalletLock(icmd btcjson.Cmd) (interface{}, error) { func WalletLock(icmd btcjson.Cmd) (interface{}, error) {
err := AcctMgr.LockWallets() AcctMgr.LockWallets()
return nil, err return nil, nil
} }
// WalletPassphrase responds to the walletpassphrase request by unlocking // WalletPassphrase responds to the walletpassphrase request by unlocking
@ -2729,21 +2714,9 @@ func WalletPassphrase(icmd btcjson.Cmd) (interface{}, error) {
return nil, btcjson.ErrInternal return nil, btcjson.ErrInternal
} }
if err := AcctMgr.UnlockWallets(cmd.Passphrase); err != nil { timeout := time.Second * time.Duration(cmd.Timeout)
return nil, err err := AcctMgr.UnlockWallets([]byte(cmd.Passphrase), timeout)
} return nil, err
go func(timeout int64) {
time.Sleep(time.Second * time.Duration(timeout))
AcctMgr.Grab()
defer AcctMgr.Release()
err := AcctMgr.LockWallets()
if err != nil {
log.Warnf("Cannot lock account wallets: %v", err)
}
}(cmd.Timeout)
return nil, nil
} }
// WalletPassphraseChange responds to the walletpassphrasechange request // WalletPassphraseChange responds to the walletpassphrasechange request

View file

@ -2367,14 +2367,6 @@ func (a *btcAddress) unlock(key []byte) (privKeyCT []byte, err error) {
return nil, errors.New("unable to unlock unencrypted address") return nil, errors.New("unable to unlock unencrypted address")
} }
// If secret is already saved, return a copy without performing a full
// unlock.
if len(a.privKeyCT) == 32 {
privKeyCT := make([]byte, 32)
copy(privKeyCT, a.privKeyCT)
return privKeyCT, nil
}
// Decrypt private key with AES key. // Decrypt private key with AES key.
aesBlockDecrypter, err := aes.NewCipher(key) aesBlockDecrypter, err := aes.NewCipher(key)
if err != nil { if err != nil {
@ -2384,6 +2376,16 @@ func (a *btcAddress) unlock(key []byte) (privKeyCT []byte, err error) {
privkey := make([]byte, 32) privkey := make([]byte, 32)
aesDecrypter.XORKeyStream(privkey, a.privKey[:]) aesDecrypter.XORKeyStream(privkey, a.privKey[:])
// If secret is already saved, simply compare the bytes.
if len(a.privKeyCT) == 32 {
if !bytes.Equal(a.privKeyCT, privkey) {
return nil, ErrWrongPassphrase
}
privKeyCT := make([]byte, 32)
copy(privKeyCT, a.privKeyCT)
return privKeyCT, nil
}
x, y := btcec.S256().ScalarBaseMult(privkey) x, y := btcec.S256().ScalarBaseMult(privkey)
if x.Cmp(a.pubKey.X) != 0 || y.Cmp(a.pubKey.Y) != 0 { if x.Cmp(a.pubKey.X) != 0 || y.Cmp(a.pubKey.Y) != 0 {
return nil, ErrWrongPassphrase return nil, ErrWrongPassphrase