diff --git a/wallet/wallet.go b/wallet/wallet.go index 52617d6..87d2a1d 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -34,7 +34,6 @@ import ( "github.com/conformal/btcwire" "io" "math/big" - "sync" "time" ) @@ -487,10 +486,7 @@ type Wallet struct { txCommentMap map[transactionHashKey]comment // The rest of the fields in this struct are not serialized. - secret struct { - sync.Mutex - key []byte - } + secret []byte chainIdxMap map[int64]*btcutil.AddressPubKeyHash importedAddrs []*btcAddress lastChainIdx int64 @@ -571,6 +567,7 @@ func NewWallet(name, desc string, passphrase []byte, net btcwire.BitcoinNet, txCommentMap: make(map[transactionHashKey]comment), chainIdxMap: make(map[int64]*btcutil.AddressPubKeyHash), lastChainIdx: rootKeyChainIdx, + secret: aeskey, } copy(w.name[:], []byte(name)) copy(w.desc[:], []byte(desc)) @@ -580,7 +577,12 @@ func NewWallet(name, desc string, passphrase []byte, net btcwire.BitcoinNet, w.chainIdxMap[rootKeyChainIdx] = w.keyGenerator.address(net) // Fill keypool. - if err := w.extendKeypool(keypoolSize, aeskey, createdAt); err != nil { + if err := w.extendKeypool(keypoolSize, createdAt); err != nil { + return nil, err + } + + // Wallet must be returned locked. + if err := w.Lock(); err != nil { return nil, err } @@ -787,29 +789,22 @@ func (w *Wallet) Unlock(passphrase []byte) error { return err } - // If unlock was successful, create a copy for below and save the - // secret key. - keycopy := make([]byte, len(key)) - copy(keycopy, key) - w.secret.Lock() - w.secret.key = key - w.secret.Unlock() + // If unlock was successful, save the secret key. + w.secret = key - return w.createMissingPrivateKeys(keycopy) + return w.createMissingPrivateKeys() } // Lock performs a best try effort to remove and zero all secret keys // associated with the wallet. func (w *Wallet) Lock() (err error) { // Remove clear text passphrase from wallet. - w.secret.Lock() - if w.secret.key == nil { + if len(w.secret) != 32 { err = ErrWalletLocked } else { - zero(w.secret.key) - w.secret.key = nil + zero(w.secret) + w.secret = nil } - w.secret.Unlock() // Remove clear text private keys from all address entries. for _, addr := range w.addrMap { @@ -828,11 +823,8 @@ func zero(b []byte) { // IsLocked returns whether a wallet is unlocked (in which case the // key is saved in memory), or locked. -func (w *Wallet) IsLocked() (locked bool) { - w.secret.Lock() - locked = w.secret.key == nil - w.secret.Unlock() - return locked +func (w *Wallet) IsLocked() bool { + return len(w.secret) != 32 } // Version returns a wallet's version as a string and int. @@ -852,24 +844,13 @@ func (w *Wallet) NextChainedAddress(bs *BlockStamp, nextAPKH, ok := w.chainIdxMap[w.highestUsed+1] if !ok { // Extending the keypool requires an unlocked wallet. - var aeskey []byte - w.secret.Lock() - if len(w.secret.key) == 32 { - // Key is available, make a copy and extend - // keypool. - aeskey = make([]byte, 32) - copy(aeskey, w.secret.key) - w.secret.Unlock() - - err := w.extendKeypool(keypoolSize, aeskey, bs) - if err != nil { + if len(w.secret) == 32 { + // Key is available, extend keypool. + if err := w.extendKeypool(keypoolSize, bs); err != nil { return nil, err } } else { - w.secret.Unlock() - - err := w.extendLockedWallet(bs) - if err != nil { + if err := w.extendLockedWallet(bs); err != nil { return nil, err } } @@ -901,7 +882,7 @@ func (w *Wallet) LastChainedAddress() btcutil.Address { } // extendKeypool grows the keypool by n addresses. -func (w *Wallet) extendKeypool(n uint, aeskey []byte, bs *BlockStamp) error { +func (w *Wallet) extendKeypool(n uint, bs *BlockStamp) error { // Get last chained address. New chained addresses will be // chained off of this address's chaincode and private key. a := w.chainIdxMap[w.lastChainIdx] @@ -909,7 +890,7 @@ func (w *Wallet) extendKeypool(n uint, aeskey []byte, bs *BlockStamp) error { if !ok { return errors.New("expected last chained address not found") } - privkey, err := addr.unlock(aeskey) + privkey, err := addr.unlock(w.secret) if err != nil { return err } @@ -929,7 +910,7 @@ func (w *Wallet) extendKeypool(n uint, aeskey []byte, bs *BlockStamp) error { if err := newaddr.verifyKeypairs(); err != nil { return err } - if err = newaddr.encrypt(aeskey); err != nil { + if err = newaddr.encrypt(w.secret); err != nil { return err } a := newaddr.address(w.net) @@ -983,7 +964,7 @@ func (w *Wallet) extendLockedWallet(bs *BlockStamp) error { return nil } -func (w *Wallet) createMissingPrivateKeys(aeskey []byte) error { +func (w *Wallet) createMissingPrivateKeys() error { idx := w.missingKeysStart if idx == 0 { return nil @@ -995,7 +976,7 @@ func (w *Wallet) createMissingPrivateKeys(aeskey []byte) error { return errors.New("missing previous chained address") } prevAddr := w.addrMap[*apkh] - prevPrivKey, err := prevAddr.unlock(aeskey) + prevPrivKey, err := prevAddr.unlock(w.secret) if err != nil { return err } @@ -1017,7 +998,7 @@ func (w *Wallet) createMissingPrivateKeys(aeskey []byte) error { } addr := w.addrMap[*apkh] addr.privKeyCT = ithPrivKey - if err := addr.encrypt(aeskey); err != nil { + if err := addr.encrypt(w.secret); err != nil { return err } @@ -1065,21 +1046,15 @@ func (w *Wallet) AddressKey(a btcutil.Address) (key *ecdsa.PrivateKey, err error return nil, err } - // The wallet's secret will be zeroed on lock, so make a local - // copy. - localSecret := make([]byte, 32) - w.secret.Lock() - if len(w.secret.key) != 32 { - w.secret.Unlock() + // Wallet must be unlocked to decrypt the private key. + if len(w.secret) != 32 { return nil, ErrWalletLocked } - copy(localSecret, w.secret.key) - w.secret.Unlock() // Unlock address with wallet secret. unlock returns a copy of the // clear text private key, and may be used safely even during an address // lock. - privKeyCT, err := btcaddr.unlock(localSecret) + privKeyCT, err := btcaddr.unlock(w.secret) if err != nil { return nil, err } @@ -1226,15 +1201,10 @@ func (w *Wallet) ImportPrivateKey(privkey []byte, compressed bool, bs *BlockStam return "", ErrDuplicate } - // The wallet's secret will be zeroed on lock, so make a local copy. - w.secret.Lock() - if len(w.secret.key) != 32 { - w.secret.Unlock() + // The wallet must be unlocked to encrypt the imported private key. + if len(w.secret) != 32 { return "", ErrWalletLocked } - localSecret := make([]byte, 32) - copy(localSecret, w.secret.key) - w.secret.Unlock() // Create new address with this private key. btcaddr, err := newBtcAddress(privkey, nil, bs, compressed) @@ -1244,7 +1214,7 @@ func (w *Wallet) ImportPrivateKey(privkey []byte, compressed bool, bs *BlockStam btcaddr.chainIndex = importedKeyChainIdx // Encrypt imported address with the derived AES key. - if err = btcaddr.encrypt(localSecret); err != nil { + if err = btcaddr.encrypt(w.secret); err != nil { return "", err }