Merge pull request #746 from wpaulino/import-dry-run

wallet: add dry run implementation of ImportAccount
This commit is contained in:
Olaoluwa Osuntokun 2021-05-06 18:27:17 -07:00 committed by GitHub
commit 82fa030bda
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 119 additions and 23 deletions

View file

@ -2270,3 +2270,11 @@ func (s *ScopedKeyManager) cloneKeyWithVersion(key *hdkeychain.ExtendedKey) (
return key.CloneWithVersion(versionBytes[:]) return key.CloneWithVersion(versionBytes[:])
} }
// InvalidateAccountCache invalidates the cache for the given account, forcing a
// database read to retrieve the account information.
func (s *ScopedKeyManager) InvalidateAccountCache(account uint32) {
s.mtx.Lock()
defer s.mtx.Unlock()
delete(s.acctInfo, account)
}

View file

@ -192,6 +192,24 @@ func (w *Wallet) ImportAccount(name string, accountPubKey *hdkeychain.ExtendedKe
masterKeyFingerprint uint32, addrType *waddrmgr.AddressType) ( masterKeyFingerprint uint32, addrType *waddrmgr.AddressType) (
*waddrmgr.AccountProperties, error) { *waddrmgr.AccountProperties, error) {
var accountProps *waddrmgr.AccountProperties
err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
var err error
accountProps, err = w.importAccount(
ns, name, accountPubKey, masterKeyFingerprint, addrType,
)
return err
})
return accountProps, err
}
// importAccount is the internal implementation of ImportAccount -- one should
// reference its documentation for this method.
func (w *Wallet) importAccount(ns walletdb.ReadWriteBucket, name string,
accountPubKey *hdkeychain.ExtendedKey, masterKeyFingerprint uint32,
addrType *waddrmgr.AddressType) (*waddrmgr.AccountProperties, error) {
// Ensure we have a valid account public key. // Ensure we have a valid account public key.
if err := w.validateExtendedPubKey(accountPubKey, true); err != nil { if err := w.validateExtendedPubKey(accountPubKey, true); err != nil {
return nil, err return nil, err
@ -208,21 +226,80 @@ func (w *Wallet) ImportAccount(name string, accountPubKey *hdkeychain.ExtendedKe
return nil, err return nil, err
} }
// Store the account as watch-only within the database. account, err := scopedMgr.NewAccountWatchingOnly(
var accountProps *waddrmgr.AccountProperties ns, name, accountPubKey, masterKeyFingerprint, addrSchema,
err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error { )
ns := tx.ReadWriteBucket(waddrmgrNamespaceKey) if err != nil {
account, err := scopedMgr.NewAccountWatchingOnly( return nil, err
ns, name, accountPubKey, masterKeyFingerprint, }
addrSchema, return scopedMgr.AccountProperties(ns, account)
) }
if err != nil {
return err // ImportAccountDryRun serves as a dry run implementation of ImportAccount. This
} // method also returns the first N external and internal addresses, which can be
accountProps, err = scopedMgr.AccountProperties(ns, account) // presented to users to confirm whether the account has been imported
return err // correctly.
}) func (w *Wallet) ImportAccountDryRun(name string,
return accountProps, err accountPubKey *hdkeychain.ExtendedKey, masterKeyFingerprint uint32,
addrType *waddrmgr.AddressType, numAddrs uint32) (
*waddrmgr.AccountProperties, []waddrmgr.ManagedAddress,
[]waddrmgr.ManagedAddress, error) {
// Start a database transaction that we'll never commit and always
// rollback.
tx, err := w.db.BeginReadWriteTx()
if err != nil {
return nil, nil, nil, err
}
defer func() {
_ = tx.Rollback()
}()
ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
// Import the account as usual.
accountProps, err := w.importAccount(
ns, name, accountPubKey, masterKeyFingerprint, addrType,
)
if err != nil {
return nil, nil, nil, err
}
// Derive the external and internal addresses. Note that we could do
// this based on the provided accountPubKey alone, but we go through the
// ScopedKeyManager instead to ensure addresses will be derived as
// expected from the wallet's point-of-view.
manager, err := w.Manager.FetchScopedKeyManager(accountProps.KeyScope)
if err != nil {
return nil, nil, nil, err
}
// The importAccount method above will cache the imported account within
// the scoped manager. Since this is a dry-run attempt, we'll want to
// invalidate the cache for it.
defer manager.InvalidateAccountCache(accountProps.AccountNumber)
externalAddrs, err := manager.NextExternalAddresses(
ns, accountProps.AccountNumber, numAddrs,
)
if err != nil {
return nil, nil, nil, err
}
internalAddrs, err := manager.NextInternalAddresses(
ns, accountProps.AccountNumber, numAddrs,
)
if err != nil {
return nil, nil, nil, err
}
// Refresh the account's properties after generating the addresses.
accountProps, err = manager.AccountProperties(
ns, accountProps.AccountNumber,
)
if err != nil {
return nil, nil, nil, err
}
return accountProps, externalAddrs, internalAddrs, nil
} }
// ImportPublicKey imports a single derived public key into the address manager. // ImportPublicKey imports a single derived public key into the address manager.

View file

@ -178,6 +178,17 @@ func testImportAccount(t *testing.T, w *Wallet, tc *testCase, watchOnly bool,
acct3ExternalPub, err := acct3ExternalExtPub.ECPubKey() acct3ExternalPub, err := acct3ExternalExtPub.ECPubKey()
require.NoError(t, err) require.NoError(t, err)
// Do a dry run import first and check that it results in the expected
// addresses being derived.
_, extAddrs, intAddrs, err := w.ImportAccountDryRun(
name+"1", acct1Pub, root.ParentFingerprint(), &tc.addrType, 1,
)
require.NoError(t, err)
require.Len(t, extAddrs, 1)
require.Equal(t, tc.expectedAddr, extAddrs[0].Address().String())
require.Len(t, intAddrs, 1)
require.Equal(t, tc.expectedChangeAddr, intAddrs[0].Address().String())
// Import the extended public keys into new accounts. // Import the extended public keys into new accounts.
acct1, err := w.ImportAccount( acct1, err := w.ImportAccount(
name+"1", acct1Pub, root.ParentFingerprint(), &tc.addrType, name+"1", acct1Pub, root.ParentFingerprint(), &tc.addrType,
@ -228,12 +239,12 @@ func testImportAccount(t *testing.T, w *Wallet, tc *testCase, watchOnly bool,
require.Equal(t, uint32(0), acct2.ImportedKeyCount) require.Equal(t, uint32(0), acct2.ImportedKeyCount)
// Test address derivation. // Test address derivation.
addr, err := w.NewAddress(acct1.AccountNumber, tc.expectedScope) extAddr, err := w.NewAddress(acct1.AccountNumber, tc.expectedScope)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, tc.expectedAddr, addr.String()) require.Equal(t, tc.expectedAddr, extAddr.String())
addr, err = w.NewChangeAddress(acct1.AccountNumber, tc.expectedScope) intAddr, err := w.NewChangeAddress(acct1.AccountNumber, tc.expectedScope)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, tc.expectedChangeAddr, addr.String()) require.Equal(t, tc.expectedChangeAddr, intAddr.String())
// Make sure the key count was increased. // Make sure the key count was increased.
acct1, err = w.AccountProperties(tc.expectedScope, acct1.AccountNumber) acct1, err = w.AccountProperties(tc.expectedScope, acct1.AccountNumber)
@ -243,7 +254,7 @@ func testImportAccount(t *testing.T, w *Wallet, tc *testCase, watchOnly bool,
require.Equal(t, uint32(0), acct1.ImportedKeyCount) require.Equal(t, uint32(0), acct1.ImportedKeyCount)
// Make sure we can't get private keys for the imported accounts. // Make sure we can't get private keys for the imported accounts.
_, err = w.DumpWIFPrivateKey(addr) _, err = w.DumpWIFPrivateKey(intAddr)
require.True(t, waddrmgr.IsError(err, waddrmgr.ErrWatchingOnly)) require.True(t, waddrmgr.IsError(err, waddrmgr.ErrWatchingOnly))
// Get the address info for the single key we imported. // Get the address info for the single key we imported.
@ -258,13 +269,13 @@ func testImportAccount(t *testing.T, w *Wallet, tc *testCase, watchOnly bool,
witnessProg, err := txscript.PayToAddrScript(witnessAddr) witnessProg, err := txscript.PayToAddrScript(witnessAddr)
require.NoError(t, err) require.NoError(t, err)
addr, err = btcutil.NewAddressScriptHash( intAddr, err = btcutil.NewAddressScriptHash(
witnessProg, &chaincfg.TestNet3Params, witnessProg, &chaincfg.TestNet3Params,
) )
require.NoError(t, err) require.NoError(t, err)
case waddrmgr.WitnessPubKey: case waddrmgr.WitnessPubKey:
addr, err = btcutil.NewAddressWitnessPubKeyHash( intAddr, err = btcutil.NewAddressWitnessPubKeyHash(
btcutil.Hash160(acct3ExternalPub.SerializeCompressed()), btcutil.Hash160(acct3ExternalPub.SerializeCompressed()),
&chaincfg.TestNet3Params, &chaincfg.TestNet3Params,
) )
@ -274,7 +285,7 @@ func testImportAccount(t *testing.T, w *Wallet, tc *testCase, watchOnly bool,
t.Fatalf("unhandled address type %v", tc.addrType) t.Fatalf("unhandled address type %v", tc.addrType)
} }
addrManaged, err := w.AddressInfo(addr) addrManaged, err := w.AddressInfo(intAddr)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, true, addrManaged.Imported()) require.Equal(t, true, addrManaged.Imported())
} }