diff --git a/waddrmgr/scoped_manager.go b/waddrmgr/scoped_manager.go index 7048776..4f2863e 100644 --- a/waddrmgr/scoped_manager.go +++ b/waddrmgr/scoped_manager.go @@ -53,6 +53,17 @@ type KeyScope struct { Coin uint32 } +// ScopedIndex is a tuple of KeyScope and child Index. This is used to compactly +// identify a particular child key, when the account and branch can be inferred +// from context. +type ScopedIndex struct { + // Scope is the BIP44 account' used to derive the child key. + Scope KeyScope + + // Index is the BIP44 address_index used to derive the child key. + Index uint32 +} + // String returns a human readable version describing the keypath encapsulated // by the target key scope. func (k *KeyScope) String() string { @@ -204,7 +215,7 @@ func (s *ScopedKeyManager) keyToManaged(derivedKey *hdkeychain.ExtendedKey, account, branch, index uint32) (ManagedAddress, error) { var addrType AddressType - if branch == internalBranch { + if branch == InternalBranch { addrType = s.addrSchema.InternalAddrType } else { addrType = s.addrSchema.ExternalAddrType @@ -233,7 +244,7 @@ func (s *ScopedKeyManager) keyToManaged(derivedKey *hdkeychain.ExtendedKey, s.deriveOnUnlock = append(s.deriveOnUnlock, &info) } - if branch == internalBranch { + if branch == InternalBranch { ma.internal = true } @@ -297,7 +308,7 @@ func (s *ScopedKeyManager) loadAccountInfo(ns walletdb.ReadBucket, row, ok := rowInterface.(*dbDefaultAccountRow) if !ok { str := fmt.Sprintf("unsupported account type %T", row) - err = managerError(ErrDatabase, str, nil) + return nil, managerError(ErrDatabase, str, nil) } // Use the crypto public key to decrypt the account public extended @@ -345,7 +356,7 @@ func (s *ScopedKeyManager) loadAccountInfo(ns walletdb.ReadBucket, } // Derive and cache the managed address for the last external address. - branch, index := externalBranch, row.nextExternalIndex + branch, index := ExternalBranch, row.nextExternalIndex if index > 0 { index-- } @@ -362,7 +373,7 @@ func (s *ScopedKeyManager) loadAccountInfo(ns walletdb.ReadBucket, acctInfo.lastExternalAddr = lastExtAddr // Derive and cache the managed address for the last internal address. - branch, index = internalBranch, row.nextInternalIndex + branch, index = InternalBranch, row.nextInternalIndex if index > 0 { index-- } @@ -670,9 +681,9 @@ func (s *ScopedKeyManager) nextAddresses(ns walletdb.ReadWriteBucket, // Choose the branch key and index depending on whether or not this is // an internal address. - branchNum, nextIndex := externalBranch, acctInfo.nextExternalIndex + branchNum, nextIndex := ExternalBranch, acctInfo.nextExternalIndex if internal { - branchNum = internalBranch + branchNum = InternalBranch nextIndex = acctInfo.nextInternalIndex } @@ -817,6 +828,183 @@ func (s *ScopedKeyManager) nextAddresses(ns walletdb.ReadWriteBucket, return managedAddresses, nil } +// extendAddresses ensures that all addresses up to and including the lastIndex +// are derived for either an internal or external branch. If the child at +// lastIndex is invalid, this method will proceed until the next valid child is +// found. An error is returned if method failed to properly extend addresses +// up to the requested index. +// +// This function MUST be called with the manager lock held for writes. +func (s *ScopedKeyManager) extendAddresses(ns walletdb.ReadWriteBucket, + account uint32, lastIndex uint32, internal bool) error { + + // The next address can only be generated for accounts that have + // already been created. + acctInfo, err := s.loadAccountInfo(ns, account) + if err != nil { + return err + } + + // Choose the account key to used based on whether the address manager + // is locked. + acctKey := acctInfo.acctKeyPub + if !s.rootManager.IsLocked() { + acctKey = acctInfo.acctKeyPriv + } + + // Choose the branch key and index depending on whether or not this is + // an internal address. + branchNum, nextIndex := ExternalBranch, acctInfo.nextExternalIndex + if internal { + branchNum = InternalBranch + nextIndex = acctInfo.nextInternalIndex + } + + addrType := s.addrSchema.ExternalAddrType + if internal { + addrType = s.addrSchema.InternalAddrType + } + + // If the last index requested is already lower than the next index, we + // can return early. + if lastIndex < nextIndex { + return nil + } + + // Ensure the requested number of addresses doesn't exceed the maximum + // allowed for this account. + if lastIndex > MaxAddressesPerAccount { + str := fmt.Sprintf("last index %d would exceed the maximum "+ + "allowed number of addresses per account of %d", + lastIndex, MaxAddressesPerAccount) + return managerError(ErrTooManyAddresses, str, nil) + } + + // Derive the appropriate branch key and ensure it is zeroed when done. + branchKey, err := acctKey.Child(branchNum) + if err != nil { + str := fmt.Sprintf("failed to derive extended key branch %d", + branchNum) + return managerError(ErrKeyChain, str, err) + } + defer branchKey.Zero() // Ensure branch key is zeroed when done. + + // Starting from this branch's nextIndex, derive all child indexes up to + // and including the requested lastIndex. If a invalid child is + // detected, this loop will continue deriving until it finds the next + // subsequent index. + addressInfo := make([]*unlockDeriveInfo, 0, lastIndex-nextIndex) + for nextIndex <= lastIndex { + // There is an extremely small chance that a particular child is + // invalid, so use a loop to derive the next valid child. + var nextKey *hdkeychain.ExtendedKey + for { + // Derive the next child in the external chain branch. + key, err := branchKey.Child(nextIndex) + if err != nil { + // When this particular child is invalid, skip to the + // next index. + if err == hdkeychain.ErrInvalidChild { + nextIndex++ + continue + } + + str := fmt.Sprintf("failed to generate child %d", + nextIndex) + return managerError(ErrKeyChain, str, err) + } + key.SetNet(s.rootManager.chainParams) + + nextIndex++ + nextKey = key + break + } + + // Create a new managed address based on the public or private + // key depending on whether the generated key is private. + // Also, zero the next key after creating the managed address + // from it. + addr, err := newManagedAddressFromExtKey( + s, account, nextKey, addrType, + ) + if err != nil { + return err + } + if internal { + addr.internal = true + } + managedAddr := addr + nextKey.Zero() + + info := unlockDeriveInfo{ + managedAddr: managedAddr, + branch: branchNum, + index: nextIndex - 1, + } + addressInfo = append(addressInfo, &info) + } + + // Now that all addresses have been successfully generated, update the + // database in a single transaction. + for _, info := range addressInfo { + ma := info.managedAddr + addressID := ma.Address().ScriptAddress() + + switch a := ma.(type) { + case *managedAddress: + err := putChainedAddress( + ns, &s.scope, addressID, account, ssFull, + info.branch, info.index, adtChain, + ) + if err != nil { + return maybeConvertDbError(err) + } + case *scriptAddress: + encryptedHash, err := s.rootManager.cryptoKeyPub.Encrypt(a.AddrHash()) + if err != nil { + str := fmt.Sprintf("failed to encrypt script hash %x", + a.AddrHash()) + return managerError(ErrCrypto, str, err) + } + + err = putScriptAddress( + ns, &s.scope, a.AddrHash(), ImportedAddrAccount, + ssNone, encryptedHash, a.scriptEncrypted, + ) + if err != nil { + return maybeConvertDbError(err) + } + } + } + + // Finally update the next address tracking and add the addresses to + // the cache after the newly generated addresses have been successfully + // added to the db. + for _, info := range addressInfo { + ma := info.managedAddr + s.addrs[addrKey(ma.Address().ScriptAddress())] = ma + + // Add the new managed address to the list of addresses that + // need their private keys derived when the address manager is + // next unlocked. + if s.rootManager.IsLocked() && !s.rootManager.WatchOnly() { + s.deriveOnUnlock = append(s.deriveOnUnlock, info) + } + } + + // Set the last address and next address for tracking. + ma := addressInfo[len(addressInfo)-1].managedAddr + if internal { + acctInfo.nextInternalIndex = nextIndex + acctInfo.lastInternalAddr = ma + } else { + acctInfo.nextExternalIndex = nextIndex + acctInfo.lastExternalAddr = ma + } + + return nil +} + // NextExternalAddresses returns the specified number of next chained addresses // that are intended for external use from the address manager. func (s *ScopedKeyManager) NextExternalAddresses(ns walletdb.ReadWriteBucket, @@ -851,6 +1039,42 @@ func (s *ScopedKeyManager) NextInternalAddresses(ns walletdb.ReadWriteBucket, return s.nextAddresses(ns, account, numAddresses, true) } +// ExtendExternalAddresses ensures that all valid external keys through +// lastIndex are derived and stored in the wallet. This is used to ensure that +// wallet's persistent state catches up to a external child that was found +// during recovery. +func (s *ScopedKeyManager) ExtendExternalAddresses(ns walletdb.ReadWriteBucket, + account uint32, lastIndex uint32) error { + + if account > MaxAccountNum { + err := managerError(ErrAccountNumTooHigh, errAcctTooHigh, nil) + return err + } + + s.mtx.Lock() + defer s.mtx.Unlock() + + return s.extendAddresses(ns, account, lastIndex, false) +} + +// ExtendInternalAddresses ensures that all valid internal keys through +// lastIndex are derived and stored in the wallet. This is used to ensure that +// wallet's persistent state catches up to an internal child that was found +// during recovery. +func (s *ScopedKeyManager) ExtendInternalAddresses(ns walletdb.ReadWriteBucket, + account uint32, lastIndex uint32) error { + + if account > MaxAccountNum { + err := managerError(ErrAccountNumTooHigh, errAcctTooHigh, nil) + return err + } + + s.mtx.Lock() + defer s.mtx.Unlock() + + return s.extendAddresses(ns, account, lastIndex, true) +} + // LastExternalAddress returns the most recently requested chained external // address from calling NextExternalAddress for the given account. The first // external address for the account will be returned if none have been