lbcwallet/spvsvc/spvchain/sync_test.go

587 lines
18 KiB
Go
Raw Normal View History

package spvchain_test
import (
"bytes"
"fmt"
"io/ioutil"
"math/rand"
"os"
"reflect"
"sync"
"testing"
"time"
"github.com/aakselrod/btctestlog"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/rpctest"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btclog"
"github.com/btcsuite/btcrpcclient"
"github.com/btcsuite/btcwallet/spvsvc/spvchain"
"github.com/btcsuite/btcwallet/waddrmgr"
"github.com/btcsuite/btcwallet/walletdb"
_ "github.com/btcsuite/btcwallet/walletdb/bdb"
)
2017-04-18 01:55:42 +02:00
const (
logLevel = btclog.TraceLvl
syncTimeout = 30 * time.Second
syncUpdate = time.Second
numTestBlocks = 50
2017-04-18 01:55:42 +02:00
)
func TestSetup(t *testing.T) {
// Create a btcd SimNet node and generate 500 blocks
h1, err := rpctest.New(&chaincfg.SimNetParams, nil, nil)
if err != nil {
t.Fatalf("Couldn't create harness: %s", err)
}
defer h1.TearDown()
err = h1.SetUp(false, 0)
if err != nil {
t.Fatalf("Couldn't set up harness: %s", err)
}
_, err = h1.Node.Generate(500)
if err != nil {
t.Fatalf("Couldn't generate blocks: %s", err)
}
// Create a second btcd SimNet node
h2, err := rpctest.New(&chaincfg.SimNetParams, nil, nil)
if err != nil {
t.Fatalf("Couldn't create harness: %s", err)
}
defer h2.TearDown()
err = h2.SetUp(false, 0)
if err != nil {
t.Fatalf("Couldn't set up harness: %s", err)
}
// Create a third btcd SimNet node and generate 900 blocks
h3, err := rpctest.New(&chaincfg.SimNetParams, nil, nil)
if err != nil {
t.Fatalf("Couldn't create harness: %s", err)
}
defer h3.TearDown()
err = h3.SetUp(false, 0)
if err != nil {
t.Fatalf("Couldn't set up harness: %s", err)
}
_, err = h3.Node.Generate(900)
if err != nil {
t.Fatalf("Couldn't generate blocks: %s", err)
}
// Connect, sync, and disconnect h1 and h2
err = csd([]*rpctest.Harness{h1, h2})
if err != nil {
t.Fatalf("Couldn't connect/sync/disconnect h1 and h2: %s", err)
}
// Generate 300 blocks on the first node and 350 on the second
_, err = h1.Node.Generate(300)
if err != nil {
t.Fatalf("Couldn't generate blocks: %s", err)
}
_, err = h2.Node.Generate(350)
if err != nil {
t.Fatalf("Couldn't generate blocks: %s", err)
}
// Now we have a node with 800 blocks (h1), 850 blocks (h2), and
// 900 blocks (h3). The chains of nodes h1 and h2 match up to block
// 500. By default, a synchronizing wallet connected to all three
// should synchronize to h3. However, we're going to take checkpoints
// from h1 at 111, 333, 555, and 777, and add those to the
// synchronizing wallet's chain parameters so that it should
// disconnect from h3 at block 111, and from h2 at block 555, and
// then synchronize to block 800 from h1. Order of connection is
// unfortunately not guaranteed, so the reorg may not happen with every
// test.
// Copy parameters and insert checkpoints
modParams := chaincfg.SimNetParams
for _, height := range []int64{111, 333, 555, 777} {
hash, err := h1.Node.GetBlockHash(height)
if err != nil {
t.Fatalf("Couldn't get block hash for height %d: %s",
height, err)
}
modParams.Checkpoints = append(modParams.Checkpoints,
chaincfg.Checkpoint{
Hash: hash,
Height: int32(height),
})
}
// Create a temporary directory, initialize an empty walletdb with an
// SPV chain namespace, and create a configuration for the ChainService.
tempDir, err := ioutil.TempDir("", "spvchain")
if err != nil {
t.Fatalf("Failed to create temporary directory: %s", err)
}
defer os.RemoveAll(tempDir)
db, err := walletdb.Create("bdb", tempDir+"/weks.db")
defer db.Close()
if err != nil {
t.Fatalf("Error opening DB: %s\n", err)
}
ns, err := db.Namespace([]byte("weks"))
if err != nil {
t.Fatalf("Error geting namespace: %s\n", err)
}
config := spvchain.Config{
DataDir: tempDir,
Namespace: ns,
ChainParams: modParams,
AddPeers: []string{
h3.P2PAddress(),
h2.P2PAddress(),
h1.P2PAddress(),
},
}
spvchain.Services = 0
spvchain.MaxPeers = 3
spvchain.BanDuration = 5 * time.Second
spvchain.RequiredServices = wire.SFNodeNetwork
2017-04-18 01:55:42 +02:00
spvchain.WaitForMoreCFHeaders = time.Second
logger, err := btctestlog.NewTestLogger(t)
if err != nil {
t.Fatalf("Could not set up logger: %s", err)
}
chainLogger := btclog.NewSubsystemLogger(logger, "CHAIN: ")
chainLogger.SetLevel(logLevel)
spvchain.UseLogger(chainLogger)
rpcLogger := btclog.NewSubsystemLogger(logger, "RPCC: ")
rpcLogger.SetLevel(logLevel)
btcrpcclient.UseLogger(rpcLogger)
svc, err := spvchain.NewChainService(config)
if err != nil {
t.Fatalf("Error creating ChainService: %s", err)
}
svc.Start()
defer svc.Stop()
// Make sure the client synchronizes with the correct node
2017-04-18 01:55:42 +02:00
err = waitForSync(t, svc, h1)
if err != nil {
t.Fatalf("Couldn't sync ChainService: %s", err)
}
// Generate 125 blocks on h1 to make sure it reorgs the other nodes.
// Ensure the ChainService instance stays caught up.
h1.Node.Generate(125)
2017-04-18 01:55:42 +02:00
err = waitForSync(t, svc, h1)
if err != nil {
t.Fatalf("Couldn't sync ChainService: %s", err)
}
// Connect/sync/disconnect h2 to make it reorg to the h1 chain.
err = csd([]*rpctest.Harness{h1, h2})
if err != nil {
t.Fatalf("Couldn't sync h2 to h1: %s", err)
}
// Generate 3 blocks on h1, one at a time, to make sure the
// ChainService instance stays caught up.
for i := 0; i < 3; i++ {
h1.Node.Generate(1)
2017-04-18 01:55:42 +02:00
err = waitForSync(t, svc, h1)
if err != nil {
t.Fatalf("Couldn't sync ChainService: %s", err)
}
}
// Generate 5 blocks on h2 and wait for ChainService to sync to the
// newly-best chain on h2.
2017-04-18 01:55:42 +02:00
h2.Node.Generate(5)
err = waitForSync(t, svc, h2)
if err != nil {
t.Fatalf("Couldn't sync ChainService: %s", err)
}
// Generate 7 blocks on h1 and wait for ChainService to sync to the
// newly-best chain on h1.
h1.Node.Generate(7)
2017-04-18 01:55:42 +02:00
err = waitForSync(t, svc, h1)
if err != nil {
t.Fatalf("Couldn't sync ChainService: %s", err)
2017-04-18 01:55:42 +02:00
}
// Test that we can get blocks and cfilters via P2P and decide which are
// valid and which aren't.
err = testRandomBlocks(t, svc, h1)
if err != nil {
t.Fatalf("Testing blocks and cfilters failed: %s", err)
}
}
// csd does a connect-sync-disconnect between nodes in order to support
// reorg testing. It brings up and tears down a temporary node, otherwise the
// nodes try to reconnect to each other which results in unintended reorgs.
func csd(harnesses []*rpctest.Harness) error {
hTemp, err := rpctest.New(&chaincfg.SimNetParams, nil, nil)
if err != nil {
return err
}
// Tear down node at the end of the function.
defer hTemp.TearDown()
err = hTemp.SetUp(false, 0)
if err != nil {
return err
}
for _, harness := range harnesses {
err = rpctest.ConnectNode(hTemp, harness)
if err != nil {
return err
}
}
return rpctest.JoinNodes(harnesses, rpctest.Blocks)
}
// waitForSync waits for the ChainService to sync to the current chain state.
func waitForSync(t *testing.T, svc *spvchain.ChainService,
2017-04-18 01:55:42 +02:00
correctSyncNode *rpctest.Harness) error {
knownBestHash, knownBestHeight, err :=
correctSyncNode.Node.GetBestBlock()
if err != nil {
return err
}
if logLevel != btclog.Off {
t.Logf("Syncing to %d (%s)", knownBestHeight, knownBestHash)
}
var haveBest *waddrmgr.BlockStamp
haveBest, err = svc.BestSnapshot()
if err != nil {
return fmt.Errorf("Couldn't get best snapshot from "+
"ChainService: %s", err)
}
var total time.Duration
for haveBest.Hash != *knownBestHash {
2017-04-18 01:55:42 +02:00
if total > syncTimeout {
return fmt.Errorf("Timed out after %v waiting for "+
2017-04-18 01:55:42 +02:00
"header synchronization.", syncTimeout)
}
if haveBest.Height > knownBestHeight {
return fmt.Errorf("Synchronized to the wrong chain.")
}
2017-04-18 01:55:42 +02:00
time.Sleep(syncUpdate)
total += syncUpdate
haveBest, err = svc.BestSnapshot()
if err != nil {
return fmt.Errorf("Couldn't get best snapshot from "+
"ChainService: %s", err)
}
if logLevel != btclog.Off {
t.Logf("Synced to %d (%s)", haveBest.Height,
haveBest.Hash)
}
}
// Check if we're current.
if !svc.IsCurrent() {
return fmt.Errorf("ChainService doesn't see itself as current!")
}
// Check if we have all of the cfheaders.
knownBasicHeader, err := correctSyncNode.Node.GetCFilterHeader(
knownBestHash, false)
if err != nil {
return fmt.Errorf("Couldn't get latest basic header from "+
"%s: %s", correctSyncNode.P2PAddress(), err)
}
knownExtHeader, err := correctSyncNode.Node.GetCFilterHeader(
knownBestHash, true)
if err != nil {
return fmt.Errorf("Couldn't get latest extended header from "+
"%s: %s", correctSyncNode.P2PAddress(), err)
}
haveBasicHeader := &chainhash.Hash{}
haveExtHeader := &chainhash.Hash{}
for (*knownBasicHeader.HeaderHashes[0] != *haveBasicHeader) &&
(*knownExtHeader.HeaderHashes[0] != *haveExtHeader) {
if total > syncTimeout {
return fmt.Errorf("Timed out after %v waiting for "+
"cfheaders synchronization.", syncTimeout)
}
haveBasicHeader, _ = svc.GetBasicHeader(*knownBestHash)
haveExtHeader, _ = svc.GetExtHeader(*knownBestHash)
2017-04-18 01:55:42 +02:00
time.Sleep(syncUpdate)
total += syncUpdate
}
if logLevel != btclog.Off {
t.Logf("Synced cfheaders to %d (%s)", haveBest.Height,
haveBest.Hash)
}
// At this point, we know the latest cfheader is stored in the
// ChainService database. We now compare each cfheader the
// harness knows about to what's stored in the ChainService
// database to see if we've missed anything or messed anything
// up.
for i := int32(0); i <= haveBest.Height; i++ {
head, _, err := svc.GetBlockByHeight(uint32(i))
if err != nil {
return fmt.Errorf("Couldn't read block by "+
"height: %s", err)
}
hash := head.BlockHash()
haveBasicHeader, err = svc.GetBasicHeader(hash)
if err != nil {
return fmt.Errorf("Couldn't get basic header "+
"for %d (%s) from DB", i, hash)
}
haveExtHeader, err = svc.GetExtHeader(hash)
if err != nil {
return fmt.Errorf("Couldn't get extended "+
"header for %d (%s) from DB", i, hash)
}
knownBasicHeader, err =
correctSyncNode.Node.GetCFilterHeader(&hash,
false)
if err != nil {
return fmt.Errorf("Couldn't get basic header "+
"for %d (%s) from node %s", i, hash,
correctSyncNode.P2PAddress())
}
knownExtHeader, err =
correctSyncNode.Node.GetCFilterHeader(&hash,
true)
if err != nil {
return fmt.Errorf("Couldn't get extended "+
"header for %d (%s) from node %s", i,
hash, correctSyncNode.P2PAddress())
}
if *haveBasicHeader !=
*knownBasicHeader.HeaderHashes[0] {
return fmt.Errorf("Basic header for %d (%s) "+
"doesn't match node %s. DB: %s, node: "+
"%s", i, hash,
correctSyncNode.P2PAddress(),
haveBasicHeader,
knownBasicHeader.HeaderHashes[0])
}
if *haveExtHeader !=
*knownExtHeader.HeaderHashes[0] {
return fmt.Errorf("Extended header for %d (%s)"+
" doesn't match node %s. DB: %s, node:"+
" %s", i, hash,
correctSyncNode.P2PAddress(),
haveExtHeader,
knownExtHeader.HeaderHashes[0])
}
}
return nil
}
// testRandomBlocks goes through numTestBlocks random blocks and ensures we
// can correctly get filters from them. We don't go through *all* the blocks
// because it can be a little slow, but we'll improve that soon-ish hopefully
// to the point where we can do it.
// TODO: Improve concurrency on framework side.
func testRandomBlocks(t *testing.T, svc *spvchain.ChainService,
correctSyncNode *rpctest.Harness) error {
var haveBest *waddrmgr.BlockStamp
haveBest, err := svc.BestSnapshot()
if err != nil {
return fmt.Errorf("Couldn't get best snapshot from "+
"ChainService: %s", err)
}
// Keep track of an error channel
errChan := make(chan error)
var lastErr error
go func() {
for err := range errChan {
if err != nil {
t.Errorf("%s", err)
lastErr = fmt.Errorf("Couldn't validate all " +
"blocks, filters, and filter headers.")
}
}
}()
// Test getting numTestBlocks random blocks and filters.
var wg sync.WaitGroup
heights := rand.Perm(int(haveBest.Height))
for i := 0; i < numTestBlocks; i++ {
wg.Add(1)
height := uint32(heights[i])
go func() {
defer wg.Done()
// Get block header from database.
blockHeader, blockHeight, err := svc.GetBlockByHeight(height)
if err != nil {
errChan <- fmt.Errorf("Couldn't get block "+
"header by height %d: %s", height, err)
return
}
if blockHeight != height {
errChan <- fmt.Errorf("Block height retrieved from DB "+
"doesn't match expected height. Want: %d, "+
"have: %d", height, blockHeight)
return
}
blockHash := blockHeader.BlockHash()
// Get block via RPC.
wantBlock, err := correctSyncNode.Node.GetBlock(&blockHash)
if err != nil {
errChan <- fmt.Errorf("Couldn't get block %d (%s) by RPC",
height, blockHash)
return
}
// Get block from network.
haveBlock := svc.GetBlockFromNetwork(blockHash)
if haveBlock == nil {
errChan <- fmt.Errorf("Couldn't get block %d (%s) from"+
"network", height, blockHash)
return
}
// Check that network and RPC blocks match.
if !reflect.DeepEqual(*haveBlock.MsgBlock(), *wantBlock) {
errChan <- fmt.Errorf("Block from network doesn't match "+
"block from RPC. Want: %s, RPC: %s, network: "+
"%s", blockHash, wantBlock.BlockHash(),
haveBlock.MsgBlock().BlockHash())
return
}
// Check that block height matches what we have.
if int32(blockHeight) != haveBlock.Height() {
errChan <- fmt.Errorf("Block height from network doesn't "+
"match expected height. Want: %s, network: %s",
blockHeight, haveBlock.Height())
return
}
// Get basic cfilter from network.
haveFilter := svc.GetCFilter(blockHash, false)
if haveFilter == nil {
errChan <- fmt.Errorf("Couldn't get basic "+
"filter for block %d", height)
return
}
// Get basic cfilter from RPC.
wantFilter, err := correctSyncNode.Node.GetCFilter(&blockHash,
false)
if err != nil {
errChan <- fmt.Errorf("Couldn't get basic filter for "+
"block %d via RPC: %s", height, err)
return
}
// Check that network and RPC cfilters match.
if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) {
errChan <- fmt.Errorf("Basic filter from P2P network/DB"+
" doesn't match RPC value for block %d", height)
return
}
// Calculate basic filter from block.
calcFilter, err := spvchain.BuildBasicFilter(
haveBlock.MsgBlock())
if err != nil {
errChan <- fmt.Errorf("Couldn't build basic filter for "+
"block %d (%s): %s", height, blockHash, err)
return
}
// Check that the network value matches the calculated value
// from the block.
if !reflect.DeepEqual(*haveFilter, *calcFilter) {
errChan <- fmt.Errorf("Basic filter from P2P network/DB "+
"doesn't match calculated value for block %d",
height)
return
}
// Get previous basic filter header from the database.
prevHeader, err := svc.GetBasicHeader(blockHeader.PrevBlock)
if err != nil {
errChan <- fmt.Errorf("Couldn't get basic filter header "+
"for block %d (%s) from DB: %s", height-1,
blockHeader.PrevBlock, err)
return
}
// Get current basic filter header from the database.
curHeader, err := svc.GetBasicHeader(blockHash)
if err != nil {
errChan <- fmt.Errorf("Couldn't get basic filter header "+
"for block %d (%s) from DB: %s", height-1,
blockHash, err)
return
}
// Check that the filter and header line up.
calcHeader := spvchain.MakeHeaderForFilter(calcFilter,
*prevHeader)
if !bytes.Equal(curHeader[:], calcHeader[:]) {
errChan <- fmt.Errorf("Filter header doesn't match. Want: "+
"%s, got: %s", curHeader, calcHeader)
return
}
// Get extended cfilter from network
haveFilter = svc.GetCFilter(blockHash, true)
if haveFilter == nil {
errChan <- fmt.Errorf("Couldn't get extended "+
"filter for block %d", height)
return
}
// Get extended cfilter from RPC
wantFilter, err = correctSyncNode.Node.GetCFilter(&blockHash,
true)
if err != nil {
errChan <- fmt.Errorf("Couldn't get extended filter for "+
"block %d via RPC: %s", height, err)
return
}
// Check that network and RPC cfilters match
if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) {
errChan <- fmt.Errorf("Extended filter from P2P network/DB"+
" doesn't match RPC value for block %d", height)
return
}
// Calculate extended filter from block
calcFilter, err = spvchain.BuildExtFilter(
haveBlock.MsgBlock())
if err != nil {
errChan <- fmt.Errorf("Couldn't build extended filter for "+
"block %d (%s): %s", height, blockHash, err)
return
}
// Check that the network value matches the calculated value
// from the block.
if !reflect.DeepEqual(*haveFilter, *calcFilter) {
errChan <- fmt.Errorf("Extended filter from P2P network/DB"+
" doesn't match calculated value for block %d",
height)
return
}
// Get previous extended filter header from the database.
prevHeader, err = svc.GetExtHeader(blockHeader.PrevBlock)
if err != nil {
errChan <- fmt.Errorf("Couldn't get extended filter header"+
" for block %d (%s) from DB: %s", height-1,
blockHeader.PrevBlock, err)
return
}
// Get current basic filter header from the database.
curHeader, err = svc.GetExtHeader(blockHash)
if err != nil {
errChan <- fmt.Errorf("Couldn't get extended filter header"+
" for block %d (%s) from DB: %s", height-1,
blockHash, err)
return
}
// Check that the filter and header line up.
calcHeader = spvchain.MakeHeaderForFilter(calcFilter,
*prevHeader)
if !bytes.Equal(curHeader[:], calcHeader[:]) {
errChan <- fmt.Errorf("Filter header doesn't match. Want: "+
"%s, got: %s", curHeader, calcHeader)
return
}
}()
}
// Wait for all queries to finish.
wg.Wait()
if logLevel != btclog.Off {
t.Logf("Finished checking %d blocks and their cfilters",
numTestBlocks)
}
// Close the error channel to make the error monitoring goroutine
// finish.
close(errChan)
return lastErr
}