diff --git a/chain/bitcoind_client.go b/chain/bitcoind_client.go index 2e42e7d..6b0584d 100644 --- a/chain/bitcoind_client.go +++ b/chain/bitcoind_client.go @@ -29,6 +29,10 @@ var ( // BitcoindClient represents a persistent client connection to a bitcoind server // for information regarding the current best block chain. type BitcoindClient struct { + // notifyBlocks signals whether the client is sending block + // notifications to the caller. This must be used atomically. + notifyBlocks uint32 + started int32 // To be used atomically. stopped int32 // To be used atomically. @@ -52,10 +56,6 @@ type BitcoindClient struct { bestBlockMtx sync.RWMutex bestBlock waddrmgr.BlockStamp - // notifyBlocks signals whether the client is sending block - // notifications to the caller. - notifyBlocks uint32 - // rescanUpdate is a channel will be sent items that we should match // transactions against while processing a chain rescan to determine if // they are relevant to the client. @@ -174,6 +174,20 @@ func (c *BitcoindClient) GetBlockHeaderVerbose( return c.chainConn.client.GetBlockHeaderVerbose(hash) } +// IsCurrent returns whether the chain backend considers its view of the network +// as "current". +func (c *BitcoindClient) IsCurrent() bool { + bestHash, _, err := c.GetBestBlock() + if err != nil { + return false + } + bestHeader, err := c.GetBlockHeader(bestHash) + if err != nil { + return false + } + return bestHeader.Timestamp.After(time.Now().Add(-isCurrentDelta)) +} + // GetRawTransactionVerbose returns a transaction from the tx hash. func (c *BitcoindClient) GetRawTransactionVerbose( hash *chainhash.Hash) (*btcjson.TxRawResult, error) { @@ -251,7 +265,42 @@ func (c *BitcoindClient) NotifyTx(txids []chainhash.Hash) error { // // NOTE: This is part of the chain.Interface interface. func (c *BitcoindClient) NotifyBlocks() error { - atomic.StoreUint32(&c.notifyBlocks, 1) + // We'll guard the goroutine being spawned below by the notifyBlocks + // variable we'll use atomically. We'll make sure to reset it in case of + // a failure before spawning the goroutine so that it can be retried. + if !atomic.CompareAndSwapUint32(&c.notifyBlocks, 0, 1) { + return nil + } + + // Re-evaluate our known best block since it's possible that blocks have + // occurred between now and when the client was created. This ensures we + // don't detect a new notified block as a potential reorg. + bestHash, bestHeight, err := c.GetBestBlock() + if err != nil { + atomic.StoreUint32(&c.notifyBlocks, 0) + return fmt.Errorf("unable to retrieve best block: %v", err) + } + bestHeader, err := c.GetBlockHeaderVerbose(bestHash) + if err != nil { + atomic.StoreUint32(&c.notifyBlocks, 0) + return fmt.Errorf("unable to retrieve header for best block: "+ + "%v", err) + } + + c.bestBlockMtx.Lock() + c.bestBlock.Hash = *bestHash + c.bestBlock.Height = bestHeight + c.bestBlock.Timestamp = time.Unix(bestHeader.Time, 0) + c.bestBlockMtx.Unlock() + + // Include the client in the set of rescan clients of the backing + // bitcoind connection in order to receive ZMQ event notifications for + // new blocks and transactions. + c.chainConn.AddClient(c) + + c.wg.Add(1) + go c.ntfnHandler() + return nil } @@ -423,14 +472,8 @@ func (c *BitcoindClient) Start() error { } c.bestBlockMtx.Unlock() - // Once the client has started successfully, we'll include it in the set - // of rescan clients of the backing bitcoind connection in order to - // received ZMQ event notifications. - c.chainConn.AddClient(c) - - c.wg.Add(2) + c.wg.Add(1) go c.rescanHandler() - go c.ntfnHandler() return nil } @@ -562,9 +605,9 @@ func (c *BitcoindClient) ntfnHandler() { // successor, so we'll update our best block to reflect // this and determine if this new block matches any of // our existing filters. - c.bestBlockMtx.Lock() + c.bestBlockMtx.RLock() bestBlock := c.bestBlock - c.bestBlockMtx.Unlock() + c.bestBlockMtx.RUnlock() if newBlock.Header.PrevBlock == bestBlock.Hash { newBlockHeight := bestBlock.Height + 1 _ = c.filterBlock(newBlock, newBlockHeight, true) @@ -720,8 +763,6 @@ func (c *BitcoindClient) onRescanProgress(hash *chainhash.Hash, height int32, func (c *BitcoindClient) onRescanFinished(hash *chainhash.Hash, height int32, timestamp time.Time) { - log.Infof("Rescan finished at %d (%s)", height, hash) - select { case c.notificationQueue.ChanIn() <- &RescanFinished{ Hash: hash, @@ -748,8 +789,8 @@ func (c *BitcoindClient) reorg(currentBlock waddrmgr.BlockStamp, bestHash, err) } - log.Debugf("Possible reorg at block: height=%v, hash=%s", bestHash, - bestHeight) + log.Debugf("Possible reorg at block: height=%v, hash=%v", bestHeight, + bestHash) if bestHeight < currentBlock.Height { log.Debugf("Detected multiple reorgs: best_height=%v below "+ diff --git a/chain/interface.go b/chain/interface.go index 885f632..da23e14 100644 --- a/chain/interface.go +++ b/chain/interface.go @@ -10,6 +10,11 @@ import ( "github.com/btcsuite/btcwallet/wtxmgr" ) +// isCurrentDelta is the delta duration we'll use from the present time to +// determine if a backend is considered "current", i.e. synced to the tip of +// the chain. +const isCurrentDelta = 2 * time.Hour + // BackEnds returns a list of the available back ends. // TODO: Refactor each into a driver and use dynamic registration. func BackEnds() []string { @@ -31,6 +36,7 @@ type Interface interface { GetBlock(*chainhash.Hash) (*wire.MsgBlock, error) GetBlockHash(int64) (*chainhash.Hash, error) GetBlockHeader(*chainhash.Hash) (*wire.BlockHeader, error) + IsCurrent() bool FilterBlocks(*FilterBlocksRequest) (*FilterBlocksResponse, error) BlockStamp() (*waddrmgr.BlockStamp, error) SendRawTransaction(*wire.MsgTx, bool) (*chainhash.Hash, error) diff --git a/chain/neutrino.go b/chain/neutrino.go index c87e89e..d74a6a1 100644 --- a/chain/neutrino.go +++ b/chain/neutrino.go @@ -157,6 +157,12 @@ func (s *NeutrinoClient) GetBlockHeader( return s.CS.GetBlockHeader(blockHash) } +// IsCurrent returns whether the chain backend considers its view of the network +// as "current". +func (s *NeutrinoClient) IsCurrent() bool { + return s.CS.IsCurrent() +} + // SendRawTransaction replicates the RPC client's SendRawTransaction command. func (s *NeutrinoClient) SendRawTransaction(tx *wire.MsgTx, allowHighFees bool) ( *chainhash.Hash, error) { diff --git a/chain/rpc.go b/chain/rpc.go index cea23ba..af9bbf5 100644 --- a/chain/rpc.go +++ b/chain/rpc.go @@ -140,6 +140,20 @@ func (c *RPCClient) Stop() { c.quitMtx.Unlock() } +// IsCurrent returns whether the chain backend considers its view of the network +// as "current". +func (c *RPCClient) IsCurrent() bool { + bestHash, _, err := c.GetBestBlock() + if err != nil { + return false + } + bestHeader, err := c.GetBlockHeader(bestHash) + if err != nil { + return false + } + return bestHeader.Timestamp.After(time.Now().Add(-isCurrentDelta)) +} + // Rescan wraps the normal Rescan command with an additional paramter that // allows us to map an oupoint to the address in the chain that it pays to. // This is useful when using BIP 158 filters as they include the prev pkScript diff --git a/waddrmgr/db.go b/waddrmgr/db.go index 6a6678f..7e9069f 100644 --- a/waddrmgr/db.go +++ b/waddrmgr/db.go @@ -16,6 +16,13 @@ import ( "github.com/btcsuite/btcwallet/walletdb" ) +const ( + // MaxReorgDepth represents the maximum number of block hashes we'll + // keep within the wallet at any given point in order to recover from + // long reorgs. + MaxReorgDepth = 10000 +) + var ( // LatestMgrVersion is the most recent manager version. LatestMgrVersion = getLatestVersion() @@ -1832,40 +1839,45 @@ func fetchSyncedTo(ns walletdb.ReadBucket) (*BlockStamp, error) { // PutSyncedTo stores the provided synced to blockstamp to the database. func PutSyncedTo(ns walletdb.ReadWriteBucket, bs *BlockStamp) error { - bucket := ns.NestedReadWriteBucket(syncBucketName) errStr := fmt.Sprintf("failed to store sync information %v", bs.Hash) // If the block height is greater than zero, check that the previous - // block height exists. This prevents reorg issues in the future. - // We use BigEndian so that keys/values are added to the bucket in - // order, making writes more efficient for some database backends. + // block height exists. This prevents reorg issues in the future. We use + // BigEndian so that keys/values are added to the bucket in order, + // making writes more efficient for some database backends. if bs.Height > 0 { - if _, err := fetchBlockHash(ns, bs.Height-1); err != nil { - return managerError(ErrDatabase, errStr, err) + // We'll only check the previous block height exists if we've + // determined our birthday block. This is needed as we'll no + // longer store _all_ block hashes of the chain, so we only + // expect the previous block to exist once our initial sync has + // completed, which is dictated by our birthday block being set. + if _, err := FetchBirthdayBlock(ns); err == nil { + _, err := fetchBlockHash(ns, bs.Height-1) + if err != nil { + return managerError(ErrBlockNotFound, errStr, err) + } } } // Store the block hash by block height. - height := make([]byte, 4) - binary.BigEndian.PutUint32(height, uint32(bs.Height)) - err := bucket.Put(height, bs.Hash[0:32]) - if err != nil { + if err := addBlockHash(ns, bs.Height, bs.Hash); err != nil { return managerError(ErrDatabase, errStr, err) } - // The serialized synced to format is: - // - // - // 4 bytes block height + 32 bytes hash length + 4 byte timestamp length - buf := make([]byte, 40) - binary.LittleEndian.PutUint32(buf[0:4], uint32(bs.Height)) - copy(buf[4:36], bs.Hash[0:32]) - binary.LittleEndian.PutUint32(buf[36:], uint32(bs.Timestamp.Unix())) + // Remove the stale height if any, as we should only store MaxReorgDepth + // block hashes at any given point. + staleHeight := staleHeight(bs.Height) + if staleHeight > 0 { + if err := deleteBlockHash(ns, staleHeight); err != nil { + return managerError(ErrDatabase, errStr, err) + } + } - err = bucket.Put(syncedToName, buf) - if err != nil { + // Finally, we can update the syncedTo value. + if err := updateSyncedTo(ns, bs); err != nil { return managerError(ErrDatabase, errStr, err) } + return nil } @@ -1893,6 +1905,62 @@ func fetchBlockHash(ns walletdb.ReadBucket, height int32) (*chainhash.Hash, erro return &hash, nil } +// addBlockHash adds a block hash entry to the index within the syncBucket. +func addBlockHash(ns walletdb.ReadWriteBucket, height int32, hash chainhash.Hash) error { + var rawHeight [4]byte + binary.BigEndian.PutUint32(rawHeight[:], uint32(height)) + bucket := ns.NestedReadWriteBucket(syncBucketName) + if err := bucket.Put(rawHeight[:], hash[:]); err != nil { + errStr := fmt.Sprintf("failed to add hash %v", hash) + return managerError(ErrDatabase, errStr, err) + } + return nil +} + +// deleteBlockHash deletes the block hash entry within the syncBucket for the +// given height. +func deleteBlockHash(ns walletdb.ReadWriteBucket, height int32) error { + var rawHeight [4]byte + binary.BigEndian.PutUint32(rawHeight[:], uint32(height)) + bucket := ns.NestedReadWriteBucket(syncBucketName) + if err := bucket.Delete(rawHeight[:]); err != nil { + errStr := fmt.Sprintf("failed to delete hash for height %v", + height) + return managerError(ErrDatabase, errStr, err) + } + return nil +} + +// updateSyncedTo updates the value behind the syncedToName key to the given +// block. +func updateSyncedTo(ns walletdb.ReadWriteBucket, bs *BlockStamp) error { + // The serialized synced to format is: + // + // + // 4 bytes block height + 32 bytes hash length + 4 byte timestamp length + var serializedStamp [40]byte + binary.LittleEndian.PutUint32(serializedStamp[0:4], uint32(bs.Height)) + copy(serializedStamp[4:36], bs.Hash[0:32]) + binary.LittleEndian.PutUint32( + serializedStamp[36:], uint32(bs.Timestamp.Unix()), + ) + + bucket := ns.NestedReadWriteBucket(syncBucketName) + if err := bucket.Put(syncedToName, serializedStamp[:]); err != nil { + errStr := "failed to update synced to value" + return managerError(ErrDatabase, errStr, err) + } + + return nil +} + +// staleHeight returns the stale height for the given height. The stale height +// indicates the height we should remove in order to maintain a maximum of +// MaxReorgDepth block hashes. +func staleHeight(height int32) int32 { + return height - MaxReorgDepth +} + // FetchStartBlock loads the start block stamp for the manager from the // database. func FetchStartBlock(ns walletdb.ReadBucket) (*BlockStamp, error) { diff --git a/waddrmgr/db_test.go b/waddrmgr/db_test.go new file mode 100644 index 0000000..c29cb85 --- /dev/null +++ b/waddrmgr/db_test.go @@ -0,0 +1,131 @@ +package waddrmgr + +import ( + "encoding/binary" + "fmt" + "testing" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcwallet/walletdb" +) + +// TestStoreMaxReorgDepth ensures that we can only store up to MaxReorgDepth +// blocks at any given time. +func TestStoreMaxReorgDepth(t *testing.T) { + t.Parallel() + + teardown, db, _ := setupManager(t) + defer teardown() + + // We'll start the test by simulating a synced chain where we start from + // 1000 and end at 109999. + const ( + startHeight = 1000 + numBlocks = MaxReorgDepth - 1 + ) + + blocks := make([]*BlockStamp, 0, numBlocks) + for i := int32(startHeight); i <= startHeight+numBlocks; i++ { + var hash chainhash.Hash + binary.BigEndian.PutUint32(hash[:], uint32(i)) + blocks = append(blocks, &BlockStamp{ + Hash: hash, + Height: i, + }) + } + + // We'll write all of the blocks to the database. + err := walletdb.Update(db, func(tx walletdb.ReadWriteTx) error { + ns := tx.ReadWriteBucket(waddrmgrNamespaceKey) + for _, block := range blocks { + if err := PutSyncedTo(ns, block); err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + + // We should be able to retrieve them all as we have MaxReorgDepth + // blocks. + err = walletdb.View(db, func(tx walletdb.ReadTx) error { + ns := tx.ReadBucket(waddrmgrNamespaceKey) + syncedTo, err := fetchSyncedTo(ns) + if err != nil { + return err + } + lastBlock := blocks[len(blocks)-1] + if syncedTo.Height != lastBlock.Height { + return fmt.Errorf("expected synced to block height "+ + "%v, got %v", lastBlock.Height, syncedTo.Height) + } + if syncedTo.Hash != lastBlock.Hash { + return fmt.Errorf("expected synced to block hash %v, "+ + "got %v", lastBlock.Hash, syncedTo.Hash) + } + + firstBlock := blocks[0] + hash, err := fetchBlockHash(ns, firstBlock.Height) + if err != nil { + return err + } + if *hash != firstBlock.Hash { + return fmt.Errorf("expected hash %v for height %v, "+ + "got %v", firstBlock.Hash, firstBlock.Height, + hash) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + + // Then, we'll create a new block which we'll use to extend the chain. + lastBlock := blocks[len(blocks)-1] + newBlockHeight := lastBlock.Height + 1 + var newBlockHash chainhash.Hash + binary.BigEndian.PutUint32(newBlockHash[:], uint32(newBlockHeight)) + newBlock := &BlockStamp{Height: newBlockHeight, Hash: newBlockHash} + + err = walletdb.Update(db, func(tx walletdb.ReadWriteTx) error { + ns := tx.ReadWriteBucket(waddrmgrNamespaceKey) + return PutSyncedTo(ns, newBlock) + }) + if err != nil { + t.Fatal(err) + } + + // Extending the chain would cause us to exceed our MaxReorgDepth blocks + // stored, so we should see the first block we ever added to now be + // removed. + err = walletdb.View(db, func(tx walletdb.ReadTx) error { + ns := tx.ReadBucket(waddrmgrNamespaceKey) + syncedTo, err := fetchSyncedTo(ns) + if err != nil { + return err + } + if syncedTo.Height != newBlock.Height { + return fmt.Errorf("expected synced to block height "+ + "%v, got %v", newBlock.Height, syncedTo.Height) + } + if syncedTo.Hash != newBlock.Hash { + return fmt.Errorf("expected synced to block hash %v, "+ + "got %v", newBlock.Hash, syncedTo.Hash) + } + + firstBlock := blocks[0] + _, err = fetchBlockHash(ns, firstBlock.Height) + if !IsError(err, ErrBlockNotFound) { + return fmt.Errorf("expected ErrBlockNotFound, got %v", + err) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } +} diff --git a/waddrmgr/migrations.go b/waddrmgr/migrations.go index dceae04..bf73c63 100644 --- a/waddrmgr/migrations.go +++ b/waddrmgr/migrations.go @@ -31,6 +31,10 @@ var versions = []migration.Version{ Number: 7, Migration: resetSyncedBlockToBirthday, }, + { + Number: 8, + Migration: storeMaxReorgDepth, + }, } // getLatestVersion returns the version number of the latest database version. @@ -372,3 +376,37 @@ func resetSyncedBlockToBirthday(ns walletdb.ReadWriteBucket) error { return PutSyncedTo(ns, &birthdayBlock) } + +// storeMaxReorgDepth is a migration responsible for allowing the wallet to only +// maintain MaxReorgDepth block hashes stored in order to recover from long +// reorgs. +func storeMaxReorgDepth(ns walletdb.ReadWriteBucket) error { + // Retrieve the current tip of the wallet. We'll use this to determine + // the highest stale height we currently have stored within it. + syncedTo, err := fetchSyncedTo(ns) + if err != nil { + return err + } + maxStaleHeight := staleHeight(syncedTo.Height) + + // It's possible for this height to be non-sensical if we have less than + // MaxReorgDepth blocks stored, so we can end the migration now. + if maxStaleHeight < 1 { + return nil + } + + log.Infof("Removing block hash entries beyond maximum reorg depth of "+ + "%v from current tip %v", MaxReorgDepth, syncedTo.Height) + + // Otherwise, since we currently store all block hashes of the chain + // before this migration, we'll remove all stale block hash entries + // above the genesis block. This would leave us with only MaxReorgDepth + // blocks stored. + for height := maxStaleHeight; height > 0; height-- { + if err := deleteBlockHash(ns, height); err != nil { + return err + } + } + + return nil +} diff --git a/waddrmgr/migrations_test.go b/waddrmgr/migrations_test.go index 181f55e..9e8a870 100644 --- a/waddrmgr/migrations_test.go +++ b/waddrmgr/migrations_test.go @@ -2,12 +2,14 @@ package waddrmgr import ( "bytes" + "encoding/binary" "errors" "fmt" "testing" "time" "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcwallet/walletdb" ) @@ -296,3 +298,154 @@ func TestMigrationResetSyncedBlockToBirthdayWithNoBirthdayBlock(t *testing.T) { true, ) } + +// TestMigrationStoreMaxReorgDepth ensures that the storeMaxReorgDepth migration +// works as expected under different sync scenarios. +func TestMigrationStoreMaxReorgDepth(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + numBlocks int32 + }{ + { + name: "genesis only", + numBlocks: 0, + }, + { + name: "below max reorg depth", + numBlocks: MaxReorgDepth - 1, + }, + { + name: "above max reorg depth", + numBlocks: MaxReorgDepth + 1, + }, + { + name: "double max reorg depth", + numBlocks: MaxReorgDepth * 2, + }, + } + + for _, testCase := range testCases { + success := t.Run(testCase.name, func(t *testing.T) { + // We'll start the test by creating the number of blocks + // we'll add to the chain. We start from height 1 as the + // genesis block (height 0) is already included when the + // address manager is created. + blocks := make([]*BlockStamp, 0, testCase.numBlocks) + for i := int32(1); i <= testCase.numBlocks; i++ { + var hash chainhash.Hash + binary.BigEndian.PutUint32(hash[:], uint32(i)) + blocks = append(blocks, &BlockStamp{ + Hash: hash, + Height: i, + }) + } + + // Before the migration, we'll go ahead and add all of + // the blocks created. This simulates the behavior of an + // existing synced chain. We won't use PutSyncedTo as + // that would remove the stale entries on its own. + beforeMigration := func(ns walletdb.ReadWriteBucket) error { + if testCase.numBlocks == 0 { + return nil + } + + // Write all the block hash entries. + for _, block := range blocks { + err := addBlockHash( + ns, block.Height, block.Hash, + ) + if err != nil { + return err + } + err = updateSyncedTo(ns, block) + if err != nil { + return err + } + } + + // Check to make sure they've been added + // properly. + for _, block := range blocks { + hash, err := fetchBlockHash( + ns, block.Height, + ) + if err != nil { + return err + } + if *hash != block.Hash { + return fmt.Errorf("expected "+ + "hash %v for height "+ + "%v, got %v", + block.Hash, + block.Height, hash) + } + } + block, err := fetchSyncedTo(ns) + if err != nil { + return err + } + expectedBlock := blocks[len(blocks)-1] + if block.Height != block.Height { + return fmt.Errorf("expected synced to "+ + "block height %v, got %v", + expectedBlock.Height, + block.Height) + } + if block.Hash != block.Hash { + return fmt.Errorf("expected synced to "+ + "block hash %v, got %v", + expectedBlock.Hash, + block.Hash) + } + + return nil + } + + // After the migration, we'll ensure we're unable to + // find all the block hashes that should have been + // removed. + afterMigration := func(ns walletdb.ReadWriteBucket) error { + maxStaleHeight := staleHeight(testCase.numBlocks) + for _, block := range blocks { + if block.Height <= maxStaleHeight { + _, err := fetchBlockHash( + ns, block.Height, + ) + if IsError(err, ErrBlockNotFound) { + continue + } + return fmt.Errorf("expected "+ + "ErrBlockNotFound for "+ + "height %v, got %v", + block.Height, err) + } + + hash, err := fetchBlockHash( + ns, block.Height, + ) + if err != nil { + return err + } + if *hash != block.Hash { + return fmt.Errorf("expected "+ + "hash %v for height "+ + "%v, got %v", + block.Hash, + block.Height, hash) + } + } + return nil + } + + applyMigration( + t, beforeMigration, afterMigration, + storeMaxReorgDepth, false, + ) + }) + if !success { + return + } + } +} diff --git a/wallet/chainntfns.go b/wallet/chainntfns.go index b719cf9..32f2acd 100644 --- a/wallet/chainntfns.go +++ b/wallet/chainntfns.go @@ -7,7 +7,6 @@ package wallet import ( "bytes" "fmt" - "strings" "time" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -35,17 +34,6 @@ func (w *Wallet) handleChainNotifications() { return } - sync := func(w *Wallet, birthdayStamp *waddrmgr.BlockStamp) { - // At the moment there is no recourse if the rescan fails for - // some reason, however, the wallet will not be marked synced - // and many methods will error early since the wallet is known - // to be out of date. - err := w.syncWithChain(birthdayStamp) - if err != nil && !w.ShuttingDown() { - log.Warnf("Unable to synchronize wallet to chain: %v", err) - } - } - catchUpHashes := func(w *Wallet, client chain.Interface, height int32) error { // TODO(aakselrod): There's a race conditon here, which @@ -119,29 +107,31 @@ func (w *Wallet) handleChainNotifications() { chainClient, birthdayStore, ) if err != nil && !waddrmgr.IsError(err, waddrmgr.ErrBirthdayBlockNotSet) { - err := fmt.Errorf("unable to sanity "+ + panic(fmt.Errorf("Unable to sanity "+ "check wallet birthday block: %v", - err) - log.Error(err) - panic(err) + err)) } - go sync(w, birthdayBlock) + err = w.syncWithChain(birthdayBlock) + if err != nil && !w.ShuttingDown() { + panic(fmt.Errorf("Unable to synchronize "+ + "wallet to chain: %v", err)) + } case chain.BlockConnected: err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error { return w.connectBlock(tx, wtxmgr.BlockMeta(n)) }) - notificationName = "blockconnected" + notificationName = "block connected" case chain.BlockDisconnected: err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error { return w.disconnectBlock(tx, wtxmgr.BlockMeta(n)) }) - notificationName = "blockdisconnected" + notificationName = "block disconnected" case chain.RelevantTx: err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error { return w.addRelevantTx(tx, n.TxRecord, n.Block) }) - notificationName = "recvtx/redeemingtx" + notificationName = "relevant transaction" case chain.FilteredBlockConnected: // Atomically update for the whole block. if len(n.RelevantTxs) > 0 { @@ -158,13 +148,13 @@ func (w *Wallet) handleChainNotifications() { return nil }) } - notificationName = "filteredblockconnected" + notificationName = "filtered block connected" // The following require some database maintenance, but also // need to be reported to the wallet's rescan goroutine. case *chain.RescanProgress: err = catchUpHashes(w, chainClient, n.Height) - notificationName = "rescanprogress" + notificationName = "rescan progress" select { case w.rescanNotifications <- n: case <-w.quitChan(): @@ -172,7 +162,7 @@ func (w *Wallet) handleChainNotifications() { } case *chain.RescanFinished: err = catchUpHashes(w, chainClient, n.Height) - notificationName = "rescanprogress" + notificationName = "rescan finished" w.SetChainSynced(true) select { case w.rescanNotifications <- n: @@ -181,17 +171,24 @@ func (w *Wallet) handleChainNotifications() { } } if err != nil { - // On out-of-sync blockconnected notifications, only - // send a debug message. - errStr := "Failed to process consensus server " + - "notification (name: `%s`, detail: `%v`)" - if notificationName == "blockconnected" && - strings.Contains(err.Error(), - "couldn't get hash from database") { - log.Debugf(errStr, notificationName, err) - } else { - log.Errorf(errStr, notificationName, err) + // If we received a block connected notification + // while rescanning, then we can ignore logging + // the error as we'll properly catch up once we + // process the RescanFinished notification. + if notificationName == "block connected" && + waddrmgr.IsError(err, waddrmgr.ErrBlockNotFound) && + !w.ChainSynced() { + + log.Debugf("Received block connected "+ + "notification for height %v "+ + "while rescanning", + n.(chain.BlockConnected).Height) + continue } + + log.Errorf("Unable to process chain backend "+ + "%v notification: %v", notificationName, + err) } case <-w.quit: return @@ -473,140 +470,16 @@ func birthdaySanityCheck(chainConn chainConn, return &birthdayBlock, nil } - log.Debugf("Starting sanity check for the wallet's birthday block "+ - "from: height=%d, hash=%v", birthdayBlock.Height, - birthdayBlock.Hash) - - // Now, we'll need to determine if our block correctly reflects our - // timestamp. To do so, we'll fetch the block header and check its - // timestamp in the event that the birthday block's timestamp was not - // set (this is possible if it was set through the migration, since we - // do not store block timestamps). - candidate := birthdayBlock - header, err := chainConn.GetBlockHeader(&candidate.Hash) - if err != nil { - return nil, fmt.Errorf("unable to get header for block hash "+ - "%v: %v", candidate.Hash, err) - } - candidate.Timestamp = header.Timestamp - - // We'll go back a day worth of blocks in the chain until we find a - // block whose timestamp is below our birthday timestamp. - heightDelta := int32(144) - for birthdayTimestamp.Before(candidate.Timestamp) { - // If the birthday block has reached genesis, then we can exit - // our search as there exists no data before this point. - if candidate.Height == 0 { - break - } - - // To prevent requesting blocks out of range, we'll use a lower - // bound of the first block in the chain. - newCandidateHeight := int64(candidate.Height - heightDelta) - if newCandidateHeight < 0 { - newCandidateHeight = 0 - } - - // Then, we'll fetch the current candidate's hash and header to - // determine if it is valid. - hash, err := chainConn.GetBlockHash(newCandidateHeight) - if err != nil { - return nil, fmt.Errorf("unable to get block hash for "+ - "height %d: %v", candidate.Height, err) - } - header, err := chainConn.GetBlockHeader(hash) - if err != nil { - return nil, fmt.Errorf("unable to get header for "+ - "block hash %v: %v", candidate.Hash, err) - } - - candidate.Hash = *hash - candidate.Height = int32(newCandidateHeight) - candidate.Timestamp = header.Timestamp - - log.Debugf("Checking next birthday block candidate: "+ - "height=%d, hash=%v, timestamp=%v", - candidate.Height, candidate.Hash, - candidate.Timestamp) - } - - // To ensure we have a reasonable birthday block, we'll make sure it - // respects our birthday timestamp and it is within a reasonable delta. - // The birthday has already been adjusted to two days in the past of the - // actual birthday, so we'll make our expected delta to be within two - // hours of it to account for the network-adjusted time and prevent - // fetching more unnecessary blocks. - _, bestHeight, err := chainConn.GetBestBlock() + // Otherwise, we'll attempt to locate a better one now that we have + // access to the chain. + newBirthdayBlock, err := locateBirthdayBlock(chainConn, birthdayTimestamp) if err != nil { return nil, err } - timestampDelta := birthdayTimestamp.Sub(candidate.Timestamp) - for timestampDelta > birthdayBlockDelta { - // We'll determine the height for our next candidate and make - // sure it is not out of range. If it is, we'll lower our height - // delta until finding a height within range. - newHeight := candidate.Height + heightDelta - if newHeight > bestHeight { - heightDelta /= 2 - // If we've exhausted all of our possible options at a - // later height, then we can assume the current birthday - // block is our best estimate. - if heightDelta == 0 { - break - } - - continue - } - - // We'll fetch the header for the next candidate and compare its - // timestamp. - hash, err := chainConn.GetBlockHash(int64(newHeight)) - if err != nil { - return nil, fmt.Errorf("unable to get block hash for "+ - "height %d: %v", candidate.Height, err) - } - header, err := chainConn.GetBlockHeader(hash) - if err != nil { - return nil, fmt.Errorf("unable to get header for "+ - "block hash %v: %v", hash, err) - } - - log.Debugf("Checking next birthday block candidate: "+ - "height=%d, hash=%v, timestamp=%v", newHeight, hash, - header.Timestamp) - - // If this block has exceeded our birthday timestamp, we'll look - // for the next candidate with a lower height delta. - if birthdayTimestamp.Before(header.Timestamp) { - heightDelta /= 2 - - // If we've exhausted all of our possible options at a - // later height, then we can assume the current birthday - // block is our best estimate. - if heightDelta == 0 { - break - } - - continue - } - - // Otherwise, this is a valid candidate, so we'll check to see - // if it meets our expected timestamp delta. - candidate.Hash = *hash - candidate.Height = newHeight - candidate.Timestamp = header.Timestamp - timestampDelta = birthdayTimestamp.Sub(header.Timestamp) - } - - // At this point, we've found a new, better candidate, so we'll write it - // to disk. - log.Debugf("Found a new valid wallet birthday block: height=%d, hash=%v", - candidate.Height, candidate.Hash) - - if err := birthdayStore.SetBirthdayBlock(candidate); err != nil { + if err := birthdayStore.SetBirthdayBlock(*newBirthdayBlock); err != nil { return nil, err } - return &candidate, nil + return newBirthdayBlock, nil } diff --git a/wallet/chainntfns_test.go b/wallet/chainntfns_test.go index ba7f8aa..8d62390 100644 --- a/wallet/chainntfns_test.go +++ b/wallet/chainntfns_test.go @@ -13,14 +13,16 @@ import ( _ "github.com/btcsuite/btcwallet/walletdb/bdb" ) +const ( + // defaultBlockInterval is the default time interval between any two + // blocks in a mocked chain. + defaultBlockInterval = 10 * time.Minute +) + var ( // chainParams are the chain parameters used throughout the wallet // tests. chainParams = chaincfg.MainNetParams - - // blockInterval is the time interval between any two blocks in a mocked - // chain. - blockInterval = 10 * time.Minute ) // mockChainConn is a mock in-memory implementation of the chainConn interface @@ -36,9 +38,11 @@ type mockChainConn struct { var _ chainConn = (*mockChainConn)(nil) // createMockChainConn creates a new mock chain connection backed by a chain -// with N blocks. Each block has a timestamp that is exactly 10 minutes after +// with N blocks. Each block has a timestamp that is exactly blockInterval after // the previous block's timestamp. -func createMockChainConn(genesis *wire.MsgBlock, n uint32) *mockChainConn { +func createMockChainConn(genesis *wire.MsgBlock, n uint32, + blockInterval time.Duration) *mockChainConn { + c := &mockChainConn{ chainTip: n, blockHashes: make(map[uint32]chainhash.Hash), @@ -163,7 +167,9 @@ func TestBirthdaySanityCheckVerifiedBirthdayBlock(t *testing.T) { t.Parallel() const chainTip = 5000 - chainConn := createMockChainConn(chainParams.GenesisBlock, chainTip) + chainConn := createMockChainConn( + chainParams.GenesisBlock, chainTip, defaultBlockInterval, + ) expectedBirthdayBlock := waddrmgr.BlockStamp{Height: 1337} // Our birthday store reflects that our birthday block has already been @@ -205,10 +211,12 @@ func TestBirthdaySanityCheckLowerEstimate(t *testing.T) { // We'll start by defining our birthday timestamp to be around the // timestamp of the 1337th block. genesisTimestamp := chainParams.GenesisBlock.Header.Timestamp - birthday := genesisTimestamp.Add(1337 * blockInterval) + birthday := genesisTimestamp.Add(1337 * defaultBlockInterval) // We'll establish a connection to a mock chain of 5000 blocks. - chainConn := createMockChainConn(chainParams.GenesisBlock, 5000) + chainConn := createMockChainConn( + chainParams.GenesisBlock, 5000, defaultBlockInterval, + ) // Our birthday store will reflect that our birthday block is currently // set as the genesis block. This value is too low and should be @@ -256,10 +264,12 @@ func TestBirthdaySanityCheckHigherEstimate(t *testing.T) { // We'll start by defining our birthday timestamp to be around the // timestamp of the 1337th block. genesisTimestamp := chainParams.GenesisBlock.Header.Timestamp - birthday := genesisTimestamp.Add(1337 * blockInterval) + birthday := genesisTimestamp.Add(1337 * defaultBlockInterval) // We'll establish a connection to a mock chain of 5000 blocks. - chainConn := createMockChainConn(chainParams.GenesisBlock, 5000) + chainConn := createMockChainConn( + chainParams.GenesisBlock, 5000, defaultBlockInterval, + ) // Our birthday store will reflect that our birthday block is currently // set as the chain tip. This value is too high and should be adjusted diff --git a/wallet/mock.go b/wallet/mock.go index a626515..09b5e65 100644 --- a/wallet/mock.go +++ b/wallet/mock.go @@ -41,6 +41,10 @@ func (m *mockChainClient) GetBlockHeader(*chainhash.Hash) (*wire.BlockHeader, return nil, nil } +func (m *mockChainClient) IsCurrent() bool { + return false +} + func (m *mockChainClient) FilterBlocks(*chain.FilterBlocksRequest) ( *chain.FilterBlocksResponse, error) { return nil, nil diff --git a/wallet/rescan.go b/wallet/rescan.go index e3deb5d..1e5dbf4 100644 --- a/wallet/rescan.go +++ b/wallet/rescan.go @@ -56,7 +56,11 @@ type rescanBatch struct { func (w *Wallet) SubmitRescan(job *RescanJob) <-chan error { errChan := make(chan error, 1) job.err = errChan - w.rescanAddJob <- job + select { + case w.rescanAddJob <- job: + case <-w.quitChan(): + errChan <- ErrWalletShuttingDown + } return errChan } @@ -103,10 +107,11 @@ func (b *rescanBatch) done(err error) { // submissions, and possibly batching many waiting requests together so they // can be handled by a single rescan after the current one completes. func (w *Wallet) rescanBatchHandler() { + defer w.wg.Done() + var curBatch, nextBatch *rescanBatch quit := w.quitChan() -out: for { select { case job := <-w.rescanAddJob: @@ -114,7 +119,12 @@ out: // Set current batch as this job and send // request. curBatch = job.batch() - w.rescanBatch <- curBatch + select { + case w.rescanBatch <- curBatch: + case <-quit: + job.err <- ErrWalletShuttingDown + return + } } else { // Create next batch if it doesn't exist, or // merge the job. @@ -134,9 +144,16 @@ out: "currently running") continue } - w.rescanProgress <- &RescanProgressMsg{ + select { + case w.rescanProgress <- &RescanProgressMsg{ Addresses: curBatch.addrs, Notification: n, + }: + case <-quit: + for _, errChan := range curBatch.errChans { + errChan <- ErrWalletShuttingDown + } + return } case *chain.RescanFinished: @@ -146,15 +163,29 @@ out: "currently running") continue } - w.rescanFinished <- &RescanFinishedMsg{ + select { + case w.rescanFinished <- &RescanFinishedMsg{ Addresses: curBatch.addrs, Notification: n, + }: + case <-quit: + for _, errChan := range curBatch.errChans { + errChan <- ErrWalletShuttingDown + } + return } curBatch, nextBatch = nextBatch, nil if curBatch != nil { - w.rescanBatch <- curBatch + select { + case w.rescanBatch <- curBatch: + case <-quit: + for _, errChan := range curBatch.errChans { + errChan <- ErrWalletShuttingDown + } + return + } } default: @@ -163,11 +194,9 @@ out: } case <-quit: - break out + return } } - - w.wg.Done() } // rescanProgressHandler handles notifications for partially and fully completed @@ -280,5 +309,10 @@ func (w *Wallet) rescanWithTarget(addrs []btcutil.Address, } // Submit merged job and block until rescan completes. - return <-w.SubmitRescan(job) + select { + case err := <-w.SubmitRescan(job): + return err + case <-w.quitChan(): + return ErrWalletShuttingDown + } } diff --git a/wallet/wallet.go b/wallet/wallet.go index 1c5755f..374436d 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -54,13 +54,18 @@ const ( recoveryBatchSize = 2000 ) -// ErrNotSynced describes an error where an operation cannot complete -// due wallet being out of sync (and perhaps currently syncing with) -// the remote chain server. -var ErrNotSynced = errors.New("wallet is not synchronized with the chain server") - -// Namespace bucket keys. var ( + // ErrNotSynced describes an error where an operation cannot complete + // due wallet being out of sync (and perhaps currently syncing with) + // the remote chain server. + ErrNotSynced = errors.New("wallet is not synchronized with the chain server") + + // ErrWalletShuttingDown is an error returned when we attempt to make a + // request to the wallet but it is in the process of or has already shut + // down. + ErrWalletShuttingDown = errors.New("wallet shutting down") + + // Namespace bucket keys. waddrmgrNamespaceKey = []byte("waddrmgr") wtxmgrNamespaceKey = []byte("wtxmgr") ) @@ -322,28 +327,75 @@ func (w *Wallet) activeData(dbtx walletdb.ReadTx) ([]btcutil.Address, []wtxmgr.C // finished. The birthday block can be passed in, if set, to ensure we can // properly detect if it gets rolled back. func (w *Wallet) syncWithChain(birthdayStamp *waddrmgr.BlockStamp) error { - // To start, if we've yet to find our birthday stamp, we'll do so now. + chainClient, err := w.requireChainClient() + if err != nil { + return err + } + + // We'll wait until the backend is synced to ensure we get the latest + // MaxReorgDepth blocks to store. We don't do this for development + // environments as we can't guarantee a lively chain. + if !w.isDevEnv() { + log.Debug("Waiting for chain backend to sync to tip") + if err := w.waitUntilBackendSynced(chainClient); err != nil { + return err + } + log.Debug("Chain backend synced to tip!") + } + + // If we've yet to find our birthday block, we'll do so now. if birthdayStamp == nil { var err error - birthdayStamp, err = w.syncToBirthday() + birthdayStamp, err = locateBirthdayBlock( + chainClient, w.Manager.Birthday(), + ) + if err != nil { + return fmt.Errorf("unable to locate birthday block: %v", + err) + } + + // We'll also determine our initial sync starting height. This + // is needed as the wallet can now begin storing blocks from an + // arbitrary height, rather than all the blocks from genesis, so + // we persist this height to ensure we don't store any blocks + // before it. + startHeight, _, err := w.getSyncRange(chainClient, birthdayStamp) if err != nil { return err } + startHash, err := chainClient.GetBlockHash(int64(startHeight)) + if err != nil { + return err + } + startHeader, err := chainClient.GetBlockHeader(startHash) + if err != nil { + return err + } + + err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error { + ns := tx.ReadWriteBucket(waddrmgrNamespaceKey) + err := w.Manager.SetSyncedTo(ns, &waddrmgr.BlockStamp{ + Hash: *startHash, + Height: startHeight, + Timestamp: startHeader.Timestamp, + }) + if err != nil { + return err + } + return w.Manager.SetBirthdayBlock(ns, *birthdayStamp, true) + }) + if err != nil { + return fmt.Errorf("unable to persist initial sync "+ + "data: %v", err) + } } // If the wallet requested an on-chain recovery of its funds, we'll do // so now. if w.recoveryWindow > 0 { - // We'll start the recovery from our birthday unless we were - // in the middle of a previous recovery attempt. If that's the - // case, we'll resume from that point. - startHeight := birthdayStamp.Height - walletHeight := w.Manager.SyncedTo().Height - if walletHeight > startHeight { - startHeight = walletHeight - } - if err := w.recovery(startHeight); err != nil { - return err + if err := w.recovery(chainClient, birthdayStamp); err != nil { + return fmt.Errorf("unable to perform wallet recovery: "+ + "%v", err) } } @@ -352,11 +404,6 @@ func (w *Wallet) syncWithChain(birthdayStamp *waddrmgr.BlockStamp) error { // before catching up with the rescan. rollback := false rollbackStamp := w.Manager.SyncedTo() - chainClient, err := w.requireChainClient() - if err != nil { - return err - } - err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error { addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey) txmgrNs := tx.ReadWriteBucket(wtxmgrNamespaceKey) @@ -432,9 +479,9 @@ func (w *Wallet) syncWithChain(birthdayStamp *waddrmgr.BlockStamp) error { return err } - // Finally, we'll trigger a wallet rescan from the currently synced tip - // and request notifications for transactions sending to all wallet - // addresses and spending all wallet UTXOs. + // Finally, we'll trigger a wallet rescan and request notifications for + // transactions sending to all wallet addresses and spending all wallet + // UTXOs. var ( addrs []btcutil.Address unspent []wtxmgr.Credit @@ -462,195 +509,111 @@ func (w *Wallet) isDevEnv() bool { return true } -// scanChain is a helper method that scans the chain from the starting height -// until the tip of the chain. The onBlock callback can be used to perform -// certain operations for every block that we process as we scan the chain. -func (w *Wallet) scanChain(startHeight int32, - onBlock func(int32, *chainhash.Hash, *wire.BlockHeader) error) error { +// waitUntilBackendSynced blocks until the chain backend considers itself +// "current". +func (w *Wallet) waitUntilBackendSynced(chainClient chain.Interface) error { + // We'll poll every second to determine if our chain considers itself + // "current". + t := time.NewTicker(time.Second) + defer t.Stop() - chainClient, err := w.requireChainClient() - if err != nil { - return err - } - - // isCurrent is a helper function that we'll use to determine if the - // chain backend is currently synced. When running with a btcd or - // bitcoind backend, it will use the height of the latest checkpoint as - // its lower bound. - var latestCheckptHeight int32 - if len(w.chainParams.Checkpoints) > 0 { - latestCheckptHeight = w.chainParams. - Checkpoints[len(w.chainParams.Checkpoints)-1].Height - } - isCurrent := func(bestHeight int32) bool { - // If the best height is zero, we assume the chain backend is - // still looking for peers to sync to in the case of a global - // network, e.g., testnet and mainnet. - if bestHeight == 0 && !w.isDevEnv() { - return false - } - - switch c := chainClient.(type) { - case *chain.NeutrinoClient: - return c.CS.IsCurrent() - } - return bestHeight >= latestCheckptHeight - } - - // Determine the latest height known to the chain backend and begin - // scanning the chain from the start height up until this point. - _, bestHeight, err := chainClient.GetBestBlock() - if err != nil { - return err - } - - for height := startHeight; height <= bestHeight; height++ { - hash, err := chainClient.GetBlockHash(int64(height)) - if err != nil { - return err - } - header, err := chainClient.GetBlockHeader(hash) - if err != nil { - return err - } - - if err := onBlock(height, hash, header); err != nil { - return err - } - - // If we've reached our best height, we'll wait for blocks at - // tip to ensure we go through all existent blocks in the chain. - // We'll update our bestHeight before checking if we're current - // with the chain to ensure we process any additional blocks - // that came in while we were scanning from our starting point. - for height == bestHeight { - time.Sleep(100 * time.Millisecond) - _, bestHeight, err = chainClient.GetBestBlock() - if err != nil { - return err - } - if isCurrent(bestHeight) { - break + for { + select { + case <-t.C: + if chainClient.IsCurrent() { + return nil } + case <-w.quitChan(): + return ErrWalletShuttingDown } } - - return nil } -// syncToBirthday attempts to sync the wallet's point of view of the chain until -// it finds the first block whose timestamp is above the wallet's birthday. The -// wallet's birthday is already two days in the past of its actual birthday, so -// this is relatively safe to do. -func (w *Wallet) syncToBirthday() (*waddrmgr.BlockStamp, error) { - var birthdayStamp *waddrmgr.BlockStamp - birthday := w.Manager.Birthday() +// locateBirthdayBlock returns a block that meets the given birthday timestamp +// by a margin of +/-2 hours. This is safe to do as the timestamp is already 2 +// days in the past of the actual timestamp. +func locateBirthdayBlock(chainClient chainConn, + birthday time.Time) (*waddrmgr.BlockStamp, error) { - tx, err := w.db.BeginReadWriteTx() + // Retrieve the lookup range for our block. + startHeight := int32(0) + _, bestHeight, err := chainClient.GetBestBlock() if err != nil { return nil, err } - ns := tx.ReadWriteBucket(waddrmgrNamespaceKey) - // We'll begin scanning the chain from our last sync point until finding - // the first block with a timestamp greater than our birthday. We'll use - // this block to represent our birthday stamp. errDone is an error we'll - // use to signal that we've found it and no longer need to keep scanning - // the chain. - errDone := errors.New("done") - err = w.scanChain(w.Manager.SyncedTo().Height, func(height int32, - hash *chainhash.Hash, header *wire.BlockHeader) error { + log.Debugf("Locating suitable block for birthday %v between blocks "+ + "%v-%v", birthday, startHeight, bestHeight) - if header.Timestamp.After(birthday) { - log.Debugf("Found birthday block: height=%d, hash=%v", - height, hash) + var ( + birthdayBlock *waddrmgr.BlockStamp + left, right = startHeight, bestHeight + ) - birthdayStamp = &waddrmgr.BlockStamp{ - Hash: *hash, - Height: height, - Timestamp: header.Timestamp, - } - - err := w.Manager.SetBirthdayBlock( - ns, *birthdayStamp, true, - ) - if err != nil { - return err - } - } - - err = w.Manager.SetSyncedTo(ns, &waddrmgr.BlockStamp{ - Hash: *hash, - Height: height, - Timestamp: header.Timestamp, - }) - if err != nil { - return err - } - - // Checkpoint our state every 10K blocks. - if height%10000 == 0 { - if err := tx.Commit(); err != nil { - return err - } - - log.Infof("Caught up to height %d", height) - - tx, err = w.db.BeginReadWriteTx() - if err != nil { - return err - } - ns = tx.ReadWriteBucket(waddrmgrNamespaceKey) - } - - // If we've found our birthday, we can return errDone to signal - // that we should stop scanning the chain and persist our state. - if birthdayStamp != nil { - return errDone - } - - return nil - }) - if err != nil && err != errDone { - tx.Rollback() - return nil, err - } - - // If a birthday stamp has yet to be found, we'll return an error - // indicating so, but only if this is a live chain like it is the case - // with testnet and mainnet. - if birthdayStamp == nil && !w.isDevEnv() { - tx.Rollback() - return nil, fmt.Errorf("did not find a suitable birthday "+ - "block with a timestamp greater than %v", birthday) - } - - // Otherwise, if we're in a development environment and we've yet to - // find a birthday block due to the chain not being current, we'll - // use the last block we've synced to as our birthday to proceed. - if birthdayStamp == nil { - syncedTo := w.Manager.SyncedTo() - err := w.Manager.SetBirthdayBlock(ns, syncedTo, true) + // Binary search for a block that meets the birthday timestamp by a + // margin of +/-2 hours. + for { + // Retrieve the timestamp for the block halfway through our + // range. + mid := left + (right-left)/2 + hash, err := chainClient.GetBlockHash(int64(mid)) if err != nil { return nil, err } - birthdayStamp = &syncedTo + header, err := chainClient.GetBlockHeader(hash) + if err != nil { + return nil, err + } + + log.Debugf("Checking candidate block: height=%v, hash=%v, "+ + "timestamp=%v", mid, hash, header.Timestamp) + + // If the search happened to reach either of our range extremes, + // then we'll just use that as there's nothing left to search. + if mid == startHeight || mid == bestHeight || mid == left { + birthdayBlock = &waddrmgr.BlockStamp{ + Hash: *hash, + Height: int32(mid), + Timestamp: header.Timestamp, + } + break + } + + // The block's timestamp is more than 2 hours after the + // birthday, so look for a lower block. + if header.Timestamp.Sub(birthday) > birthdayBlockDelta { + right = mid + continue + } + + // The birthday is more than 2 hours before the block's + // timestamp, so look for a higher block. + if header.Timestamp.Sub(birthday) < -birthdayBlockDelta { + left = mid + continue + } + + birthdayBlock = &waddrmgr.BlockStamp{ + Hash: *hash, + Height: int32(mid), + Timestamp: header.Timestamp, + } + break } - if err := tx.Commit(); err != nil { - tx.Rollback() - return nil, err - } + log.Debugf("Found birthday block: height=%d, hash=%v, timestamp=%v", + birthdayBlock.Height, birthdayBlock.Hash, + birthdayBlock.Timestamp) - return birthdayStamp, nil + return birthdayBlock, nil } // recovery attempts to recover any unspent outputs that pay to any of our -// addresses starting from the specified height. -// -// NOTE: The starting height must be at least the height of the wallet's -// birthday or later. -func (w *Wallet) recovery(startHeight int32) error { +// addresses starting from our birthday, or the wallet's tip (if higher), which +// would indicate resuming a recovery after a restart. +func (w *Wallet) recovery(chainClient chain.Interface, + birthdayBlock *waddrmgr.BlockStamp) error { + log.Infof("RECOVERY MODE ENABLED -- rescanning for used addresses "+ "with recovery_window=%d", w.recoveryWindow) @@ -667,110 +630,129 @@ func (w *Wallet) recovery(startHeight int32) error { if err != nil { return err } - tx, err := w.db.BeginReadWriteTx() - if err != nil { - return err - } - txMgrNS := tx.ReadBucket(wtxmgrNamespaceKey) - credits, err := w.TxStore.UnspentOutputs(txMgrNS) - if err != nil { - tx.Rollback() - return err - } - addrMgrNS := tx.ReadWriteBucket(waddrmgrNamespaceKey) - err = recoveryMgr.Resurrect(addrMgrNS, scopedMgrs, credits) - if err != nil { - tx.Rollback() - return err - } - - // We'll also retrieve our chain backend client in order to filter the - // blocks as we go. - chainClient, err := w.requireChainClient() - if err != nil { - tx.Rollback() - return err - } - - // We'll begin scanning the chain from the specified starting height. - // Since we assume that the lowest height we start with will at least be - // that of our birthday, we can just add every block we process from - // this point forward to the recovery batch. - err = w.scanChain(startHeight, func(height int32, - hash *chainhash.Hash, header *wire.BlockHeader) error { - - recoveryMgr.AddToBlockBatch(hash, height, header.Timestamp) - - // We'll checkpoint our current batch every 2K blocks, so we'll - // need to start a new database transaction. If our current - // batch is empty, then this will act as a NOP. - if height%recoveryBatchSize == 0 { - blockBatch := recoveryMgr.BlockBatch() - err := w.recoverDefaultScopes( - chainClient, tx, addrMgrNS, blockBatch, - recoveryMgr.State(), - ) - if err != nil { - return err - } - - // Clear the batch of all processed blocks. - recoveryMgr.ResetBlockBatch() - - if err := tx.Commit(); err != nil { - return err - } - - log.Infof("Recovered addresses from blocks %d-%d", - blockBatch[0].Height, - blockBatch[len(blockBatch)-1].Height) - - tx, err = w.db.BeginReadWriteTx() - if err != nil { - return err - } - addrMgrNS = tx.ReadWriteBucket(waddrmgrNamespaceKey) + err = walletdb.View(w.db, func(tx walletdb.ReadTx) error { + txMgrNS := tx.ReadBucket(wtxmgrNamespaceKey) + credits, err := w.TxStore.UnspentOutputs(txMgrNS) + if err != nil { + return err } + addrMgrNS := tx.ReadBucket(waddrmgrNamespaceKey) + return recoveryMgr.Resurrect(addrMgrNS, scopedMgrs, credits) + }) + if err != nil { + return err + } - // Since the recovery in a way acts as a rescan, we'll update - // the wallet's tip to point to the current block so that we - // don't unnecessarily rescan the same block again later on. - return w.Manager.SetSyncedTo(addrMgrNS, &waddrmgr.BlockStamp{ + // We'll then need to determine the range of our recovery. This properly + // handles the case where we resume a previous recovery attempt after a + // restart. + startHeight, bestHeight, err := w.getSyncRange(chainClient, birthdayBlock) + if err != nil { + return err + } + + // Now we can begin scanning the chain from the specified starting + // height. Since the recovery process itself acts as rescan, we'll also + // update our wallet's synced state along the way to reflect the blocks + // we process and prevent rescanning them later on. + // + // NOTE: We purposefully don't update our best height since we assume + // that a wallet rescan will be performed from the wallet's tip, which + // will be of bestHeight after completing the recovery process. + var blocks []*waddrmgr.BlockStamp + for height := startHeight; height <= bestHeight; height++ { + hash, err := chainClient.GetBlockHash(int64(height)) + if err != nil { + return err + } + header, err := chainClient.GetBlockHeader(hash) + if err != nil { + return err + } + blocks = append(blocks, &waddrmgr.BlockStamp{ Hash: *hash, Height: height, Timestamp: header.Timestamp, }) - }) - if err != nil { - tx.Rollback() - return err - } - // Now that we've reached the chain tip, we can process our final batch - // with the remaining blocks if it did not reach its maximum size. - blockBatch := recoveryMgr.BlockBatch() - err = w.recoverDefaultScopes( - chainClient, tx, addrMgrNS, blockBatch, recoveryMgr.State(), - ) - if err != nil { - tx.Rollback() - return err - } + // It's possible for us to run into blocks before our birthday + // if our birthday is after our reorg safe height, so we'll make + // sure to not add those to the batch. + if height >= birthdayBlock.Height { + recoveryMgr.AddToBlockBatch( + hash, height, header.Timestamp, + ) + } - // With the recovery complete, we can persist our new state and exit. - if err := tx.Commit(); err != nil { - tx.Rollback() - return err - } + // We'll perform our recovery in batches of 2000 blocks. It's + // possible for us to reach our best height without exceeding + // the recovery batch size, so we can proceed to commit our + // state to disk. + recoveryBatch := recoveryMgr.BlockBatch() + if len(recoveryBatch) == recoveryBatchSize || height == bestHeight { + err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error { + ns := tx.ReadWriteBucket(waddrmgrNamespaceKey) + for _, block := range blocks { + err := w.Manager.SetSyncedTo(ns, block) + if err != nil { + return err + } + } + return w.recoverDefaultScopes( + chainClient, tx, ns, recoveryBatch, + recoveryMgr.State(), + ) + }) + if err != nil { + return err + } - if len(blockBatch) > 0 { - log.Infof("Recovered addresses from blocks %d-%d", blockBatch[0].Height, - blockBatch[len(blockBatch)-1].Height) + if len(recoveryBatch) > 0 { + log.Infof("Recovered addresses from blocks "+ + "%d-%d", recoveryBatch[0].Height, + recoveryBatch[len(recoveryBatch)-1].Height) + } + + // Clear the batch of all processed blocks to reuse the + // same memory for future batches. + blocks = blocks[:0] + recoveryMgr.ResetBlockBatch() + } } return nil } +// getSyncRange determines the best height range to sync with the chain to +// ensure we don't rescan blocks more than once. +func (w *Wallet) getSyncRange(chainClient chain.Interface, + birthdayBlock *waddrmgr.BlockStamp) (int32, int32, error) { + + // The wallet requires to store up to MaxReorgDepth blocks, so we'll + // start from there, unless our birthday is before it. + _, bestHeight, err := chainClient.GetBestBlock() + if err != nil { + return 0, 0, err + } + startHeight := bestHeight - waddrmgr.MaxReorgDepth + 1 + if startHeight < 0 { + startHeight = 0 + } + if birthdayBlock.Height < startHeight { + startHeight = birthdayBlock.Height + } + + // If the wallet's tip has surpassed our starting height, then we'll + // start there as we don't need to rescan blocks we've already + // processed. + walletHeight := w.Manager.SyncedTo().Height + if walletHeight > startHeight { + startHeight = walletHeight + } + + return startHeight, bestHeight, nil +} + // defaultScopeManagers fetches the ScopedKeyManagers from the wallet using the // default set of key scopes. func (w *Wallet) defaultScopeManagers() ( diff --git a/wallet/wallet_test.go b/wallet/wallet_test.go new file mode 100644 index 0000000..f70f43c --- /dev/null +++ b/wallet/wallet_test.go @@ -0,0 +1,85 @@ +package wallet + +import ( + "testing" + "time" +) + +// TestLocateBirthdayBlock ensures we can properly map a block in the chain to a +//timestamp. +func TestLocateBirthdayBlock(t *testing.T) { + t.Parallel() + + // We'll use test chains of 30 blocks with a duration between two + // consecutive blocks being slightly greater than the largest margin + // allowed by locateBirthdayBlock. Doing so lets us test the method more + // effectively as there is only one block within the chain that can map + // to a timestamp (this does not apply to the first and last blocks, + // which can map to many timestamps beyond either end of chain). + const ( + numBlocks = 30 + blockInterval = birthdayBlockDelta + 1 + ) + + genesisTimestamp := chainParams.GenesisBlock.Header.Timestamp + + testCases := []struct { + name string + birthday time.Time + birthdayHeight int32 + }{ + { + name: "left-right-left-left", + birthday: genesisTimestamp.Add(8 * blockInterval), + birthdayHeight: 8, + }, + { + name: "right-right-right-left", + birthday: genesisTimestamp.Add(27 * blockInterval), + birthdayHeight: 27, + }, + { + name: "before start height", + birthday: genesisTimestamp.Add(-blockInterval), + birthdayHeight: 0, + }, + { + name: "start height", + birthday: genesisTimestamp, + birthdayHeight: 0, + }, + { + name: "end height", + birthday: genesisTimestamp.Add(numBlocks * blockInterval), + birthdayHeight: numBlocks - 1, + }, + { + name: "after end height", + birthday: genesisTimestamp.Add(2 * numBlocks * blockInterval), + birthdayHeight: numBlocks - 1, + }, + } + + for _, testCase := range testCases { + success := t.Run(testCase.name, func(t *testing.T) { + chainConn := createMockChainConn( + chainParams.GenesisBlock, numBlocks, blockInterval, + ) + birthdayBlock, err := locateBirthdayBlock( + chainConn, testCase.birthday, + ) + if err != nil { + t.Fatalf("unable to locate birthday block: %v", + err) + } + if birthdayBlock.Height != testCase.birthdayHeight { + t.Fatalf("expected birthday block with height "+ + "%d, got %d", testCase.birthdayHeight, + birthdayBlock.Height) + } + }) + if !success { + break + } + } +}