diff --git a/waddrmgr/scoped_manager.go b/waddrmgr/scoped_manager.go index 327be2f..7b901db 100644 --- a/waddrmgr/scoped_manager.go +++ b/waddrmgr/scoped_manager.go @@ -2270,3 +2270,11 @@ func (s *ScopedKeyManager) cloneKeyWithVersion(key *hdkeychain.ExtendedKey) ( return key.CloneWithVersion(versionBytes[:]) } + +// InvalidateAccountCache invalidates the cache for the given account, forcing a +// database read to retrieve the account information. +func (s *ScopedKeyManager) InvalidateAccountCache(account uint32) { + s.mtx.Lock() + defer s.mtx.Unlock() + delete(s.acctInfo, account) +} diff --git a/wallet/import.go b/wallet/import.go index b1b699a..af2c3f0 100644 --- a/wallet/import.go +++ b/wallet/import.go @@ -192,6 +192,24 @@ func (w *Wallet) ImportAccount(name string, accountPubKey *hdkeychain.ExtendedKe masterKeyFingerprint uint32, addrType *waddrmgr.AddressType) ( *waddrmgr.AccountProperties, error) { + var accountProps *waddrmgr.AccountProperties + err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error { + ns := tx.ReadWriteBucket(waddrmgrNamespaceKey) + var err error + accountProps, err = w.importAccount( + ns, name, accountPubKey, masterKeyFingerprint, addrType, + ) + return err + }) + return accountProps, err +} + +// importAccount is the internal implementation of ImportAccount -- one should +// reference its documentation for this method. +func (w *Wallet) importAccount(ns walletdb.ReadWriteBucket, name string, + accountPubKey *hdkeychain.ExtendedKey, masterKeyFingerprint uint32, + addrType *waddrmgr.AddressType) (*waddrmgr.AccountProperties, error) { + // Ensure we have a valid account public key. if err := w.validateExtendedPubKey(accountPubKey, true); err != nil { return nil, err @@ -208,21 +226,80 @@ func (w *Wallet) ImportAccount(name string, accountPubKey *hdkeychain.ExtendedKe return nil, err } - // Store the account as watch-only within the database. - var accountProps *waddrmgr.AccountProperties - err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error { - ns := tx.ReadWriteBucket(waddrmgrNamespaceKey) - account, err := scopedMgr.NewAccountWatchingOnly( - ns, name, accountPubKey, masterKeyFingerprint, - addrSchema, - ) - if err != nil { - return err - } - accountProps, err = scopedMgr.AccountProperties(ns, account) - return err - }) - return accountProps, err + account, err := scopedMgr.NewAccountWatchingOnly( + ns, name, accountPubKey, masterKeyFingerprint, addrSchema, + ) + if err != nil { + return nil, err + } + return scopedMgr.AccountProperties(ns, account) +} + +// ImportAccountDryRun serves as a dry run implementation of ImportAccount. This +// method also returns the first N external and internal addresses, which can be +// presented to users to confirm whether the account has been imported +// correctly. +func (w *Wallet) ImportAccountDryRun(name string, + accountPubKey *hdkeychain.ExtendedKey, masterKeyFingerprint uint32, + addrType *waddrmgr.AddressType, numAddrs uint32) ( + *waddrmgr.AccountProperties, []waddrmgr.ManagedAddress, + []waddrmgr.ManagedAddress, error) { + + // Start a database transaction that we'll never commit and always + // rollback. + tx, err := w.db.BeginReadWriteTx() + if err != nil { + return nil, nil, nil, err + } + defer func() { + _ = tx.Rollback() + }() + ns := tx.ReadWriteBucket(waddrmgrNamespaceKey) + + // Import the account as usual. + accountProps, err := w.importAccount( + ns, name, accountPubKey, masterKeyFingerprint, addrType, + ) + if err != nil { + return nil, nil, nil, err + } + + // Derive the external and internal addresses. Note that we could do + // this based on the provided accountPubKey alone, but we go through the + // ScopedKeyManager instead to ensure addresses will be derived as + // expected from the wallet's point-of-view. + manager, err := w.Manager.FetchScopedKeyManager(accountProps.KeyScope) + if err != nil { + return nil, nil, nil, err + } + + // The importAccount method above will cache the imported account within + // the scoped manager. Since this is a dry-run attempt, we'll want to + // invalidate the cache for it. + defer manager.InvalidateAccountCache(accountProps.AccountNumber) + + externalAddrs, err := manager.NextExternalAddresses( + ns, accountProps.AccountNumber, numAddrs, + ) + if err != nil { + return nil, nil, nil, err + } + internalAddrs, err := manager.NextInternalAddresses( + ns, accountProps.AccountNumber, numAddrs, + ) + if err != nil { + return nil, nil, nil, err + } + + // Refresh the account's properties after generating the addresses. + accountProps, err = manager.AccountProperties( + ns, accountProps.AccountNumber, + ) + if err != nil { + return nil, nil, nil, err + } + + return accountProps, externalAddrs, internalAddrs, nil } // ImportPublicKey imports a single derived public key into the address manager. diff --git a/wallet/import_test.go b/wallet/import_test.go index 2d75cca..a0f6503 100644 --- a/wallet/import_test.go +++ b/wallet/import_test.go @@ -178,6 +178,17 @@ func testImportAccount(t *testing.T, w *Wallet, tc *testCase, watchOnly bool, acct3ExternalPub, err := acct3ExternalExtPub.ECPubKey() require.NoError(t, err) + // Do a dry run import first and check that it results in the expected + // addresses being derived. + _, extAddrs, intAddrs, err := w.ImportAccountDryRun( + name+"1", acct1Pub, root.ParentFingerprint(), &tc.addrType, 1, + ) + require.NoError(t, err) + require.Len(t, extAddrs, 1) + require.Equal(t, tc.expectedAddr, extAddrs[0].Address().String()) + require.Len(t, intAddrs, 1) + require.Equal(t, tc.expectedChangeAddr, intAddrs[0].Address().String()) + // Import the extended public keys into new accounts. acct1, err := w.ImportAccount( name+"1", acct1Pub, root.ParentFingerprint(), &tc.addrType, @@ -228,12 +239,12 @@ func testImportAccount(t *testing.T, w *Wallet, tc *testCase, watchOnly bool, require.Equal(t, uint32(0), acct2.ImportedKeyCount) // Test address derivation. - addr, err := w.NewAddress(acct1.AccountNumber, tc.expectedScope) + extAddr, err := w.NewAddress(acct1.AccountNumber, tc.expectedScope) require.NoError(t, err) - require.Equal(t, tc.expectedAddr, addr.String()) - addr, err = w.NewChangeAddress(acct1.AccountNumber, tc.expectedScope) + require.Equal(t, tc.expectedAddr, extAddr.String()) + intAddr, err := w.NewChangeAddress(acct1.AccountNumber, tc.expectedScope) require.NoError(t, err) - require.Equal(t, tc.expectedChangeAddr, addr.String()) + require.Equal(t, tc.expectedChangeAddr, intAddr.String()) // Make sure the key count was increased. acct1, err = w.AccountProperties(tc.expectedScope, acct1.AccountNumber) @@ -243,7 +254,7 @@ func testImportAccount(t *testing.T, w *Wallet, tc *testCase, watchOnly bool, require.Equal(t, uint32(0), acct1.ImportedKeyCount) // Make sure we can't get private keys for the imported accounts. - _, err = w.DumpWIFPrivateKey(addr) + _, err = w.DumpWIFPrivateKey(intAddr) require.True(t, waddrmgr.IsError(err, waddrmgr.ErrWatchingOnly)) // Get the address info for the single key we imported. @@ -258,13 +269,13 @@ func testImportAccount(t *testing.T, w *Wallet, tc *testCase, watchOnly bool, witnessProg, err := txscript.PayToAddrScript(witnessAddr) require.NoError(t, err) - addr, err = btcutil.NewAddressScriptHash( + intAddr, err = btcutil.NewAddressScriptHash( witnessProg, &chaincfg.TestNet3Params, ) require.NoError(t, err) case waddrmgr.WitnessPubKey: - addr, err = btcutil.NewAddressWitnessPubKeyHash( + intAddr, err = btcutil.NewAddressWitnessPubKeyHash( btcutil.Hash160(acct3ExternalPub.SerializeCompressed()), &chaincfg.TestNet3Params, ) @@ -274,7 +285,7 @@ func testImportAccount(t *testing.T, w *Wallet, tc *testCase, watchOnly bool, t.Fatalf("unhandled address type %v", tc.addrType) } - addrManaged, err := w.AddressInfo(addr) + addrManaged, err := w.AddressInfo(intAddr) require.NoError(t, err) require.Equal(t, true, addrManaged.Imported()) }