407 lines
12 KiB
Go
407 lines
12 KiB
Go
package spvchain_test
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"math/rand"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/aakselrod/btctestlog"
|
|
"github.com/btcsuite/btcd/chaincfg"
|
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
|
"github.com/btcsuite/btcd/rpctest"
|
|
"github.com/btcsuite/btcd/wire"
|
|
"github.com/btcsuite/btclog"
|
|
"github.com/btcsuite/btcrpcclient"
|
|
"github.com/btcsuite/btcwallet/spvsvc/spvchain"
|
|
"github.com/btcsuite/btcwallet/waddrmgr"
|
|
"github.com/btcsuite/btcwallet/walletdb"
|
|
_ "github.com/btcsuite/btcwallet/walletdb/bdb"
|
|
)
|
|
|
|
const (
|
|
logLevel = btclog.TraceLvl
|
|
syncTimeout = 30 * time.Second
|
|
syncUpdate = time.Second
|
|
)
|
|
|
|
// TestSetup builds three SimNet nodes with competing chains and points an
// SPV ChainService at all of them. The ChainService is configured with
// checkpoints taken from h1. The test then verifies that the service follows
// h1's chain, keeps up as new blocks arrive, and follows reorgs between h1
// and h2.
func TestSetup(t *testing.T) {
	// Create a btcd SimNet node and generate 500 blocks
	h1, err := rpctest.New(&chaincfg.SimNetParams, nil, nil)
	if err != nil {
		t.Fatalf("Couldn't create harness: %s", err)
	}
	defer h1.TearDown()
	err = h1.SetUp(false, 0)
	if err != nil {
		t.Fatalf("Couldn't set up harness: %s", err)
	}
	_, err = h1.Node.Generate(500)
	if err != nil {
		t.Fatalf("Couldn't generate blocks: %s", err)
	}

	// Create a second btcd SimNet node
	h2, err := rpctest.New(&chaincfg.SimNetParams, nil, nil)
	if err != nil {
		t.Fatalf("Couldn't create harness: %s", err)
	}
	defer h2.TearDown()
	err = h2.SetUp(false, 0)
	if err != nil {
		t.Fatalf("Couldn't set up harness: %s", err)
	}

	// Create a third btcd SimNet node and generate 900 blocks
	h3, err := rpctest.New(&chaincfg.SimNetParams, nil, nil)
	if err != nil {
		t.Fatalf("Couldn't create harness: %s", err)
	}
	defer h3.TearDown()
	err = h3.SetUp(false, 0)
	if err != nil {
		t.Fatalf("Couldn't set up harness: %s", err)
	}
	_, err = h3.Node.Generate(900)
	if err != nil {
		t.Fatalf("Couldn't generate blocks: %s", err)
	}

	// Connect, sync, and disconnect h1 and h2
	err = csd([]*rpctest.Harness{h1, h2})
	if err != nil {
		t.Fatalf("Couldn't connect/sync/disconnect h1 and h2: %s", err)
	}

	// Generate 300 blocks on the first node and 350 on the second
	_, err = h1.Node.Generate(300)
	if err != nil {
		t.Fatalf("Couldn't generate blocks: %s", err)
	}
	_, err = h2.Node.Generate(350)
	if err != nil {
		t.Fatalf("Couldn't generate blocks: %s", err)
	}

	// Now we have a node with 800 blocks (h1), 850 blocks (h2), and
	// 900 blocks (h3). The chains of nodes h1 and h2 match up to block
	// 500. By default, a synchronizing wallet connected to all three
	// should synchronize to h3. However, we're going to take checkpoints
	// from h1 at 111, 333, 555, and 777, and add those to the
	// synchronizing wallet's chain parameters so that it should
	// disconnect from h3 at block 111, and from h2 at block 555, and
	// then synchronize to block 800 from h1. Order of connection is
	// unfortunately not guaranteed, so the reorg may not happen with every
	// test.

	// Copy parameters and insert checkpoints
	modParams := chaincfg.SimNetParams
	for _, height := range []int64{111, 333, 555, 777} {
		hash, err := h1.Node.GetBlockHash(height)
		if err != nil {
			t.Fatalf("Couldn't get block hash for height %d: %s",
				height, err)
		}
		modParams.Checkpoints = append(modParams.Checkpoints,
			chaincfg.Checkpoint{
				Hash:   hash,
				Height: int32(height),
			})
	}

	// Create a temporary directory, initialize an empty walletdb with an
	// SPV chain namespace, and create a configuration for the ChainService.
	tempDir, err := ioutil.TempDir("", "spvchain")
	if err != nil {
		t.Fatalf("Failed to create temporary directory: %s", err)
	}
	defer os.RemoveAll(tempDir)
	db, err := walletdb.Create("bdb", tempDir+"/weks.db")
	if err != nil {
		t.Fatalf("Error opening DB: %s\n", err)
	}
	// Defer Close only once Create has succeeded. Deferring it earlier
	// would invoke Close on a nil DB while t.Fatalf unwinds the test.
	defer db.Close()
	ns, err := db.Namespace([]byte("weks"))
	if err != nil {
		t.Fatalf("Error geting namespace: %s\n", err)
	}
	config := spvchain.Config{
		DataDir:     tempDir,
		Namespace:   ns,
		ChainParams: modParams,
		AddPeers: []string{
			h3.P2PAddress(),
			h2.P2PAddress(),
			h1.P2PAddress(),
		},
	}

	// Package-level knobs on spvchain; these affect every ChainService
	// created after this point.
	spvchain.Services = 0
	spvchain.MaxPeers = 3
	spvchain.BanDuration = 5 * time.Second
	spvchain.RequiredServices = wire.SFNodeNetwork
	spvchain.WaitForMoreCFHeaders = time.Second
	logger, err := btctestlog.NewTestLogger(t)
	if err != nil {
		t.Fatalf("Could not set up logger: %s", err)
	}
	chainLogger := btclog.NewSubsystemLogger(logger, "CHAIN: ")
	chainLogger.SetLevel(logLevel)
	spvchain.UseLogger(chainLogger)
	rpcLogger := btclog.NewSubsystemLogger(logger, "RPCC: ")
	rpcLogger.SetLevel(logLevel)
	btcrpcclient.UseLogger(rpcLogger)
	svc, err := spvchain.NewChainService(config)
	if err != nil {
		t.Fatalf("Error creating ChainService: %s", err)
	}
	svc.Start()
	defer svc.Stop()

	// Make sure the client synchronizes with the correct node
	err = waitForSync(t, svc, h1)
	if err != nil {
		t.Fatalf("Couldn't sync ChainService: %s", err)
	}

	// Generate 125 blocks on h1 to make sure it reorgs the other nodes.
	// Ensure the ChainService instance stays caught up.
	if _, err = h1.Node.Generate(125); err != nil {
		t.Fatalf("Couldn't generate blocks: %s", err)
	}
	err = waitForSync(t, svc, h1)
	if err != nil {
		t.Fatalf("Couldn't sync ChainService: %s", err)
	}

	// Connect/sync/disconnect h2 to make it reorg to the h1 chain.
	err = csd([]*rpctest.Harness{h1, h2})
	if err != nil {
		t.Fatalf("Couldn't sync h2 to h1: %s", err)
	}

	// Generate 3 blocks on h1, one at a time, to make sure the
	// ChainService instance stays caught up.
	for i := 0; i < 3; i++ {
		if _, err = h1.Node.Generate(1); err != nil {
			t.Fatalf("Couldn't generate blocks: %s", err)
		}
		err = waitForSync(t, svc, h1)
		if err != nil {
			t.Fatalf("Couldn't sync ChainService: %s", err)
		}
	}

	// Generate 5 blocks on h2 and wait for ChainService to sync to the
	// newly-best chain on h2.
	if _, err = h2.Node.Generate(5); err != nil {
		t.Fatalf("Couldn't generate blocks: %s", err)
	}
	err = waitForSync(t, svc, h2)
	if err != nil {
		t.Fatalf("Couldn't sync ChainService: %s", err)
	}

	// Generate 7 blocks on h1 and wait for ChainService to sync to the
	// newly-best chain on h1.
	if _, err = h1.Node.Generate(7); err != nil {
		t.Fatalf("Couldn't generate blocks: %s", err)
	}
	err = waitForSync(t, svc, h1)
	if err != nil {
		t.Fatalf("Couldn't sync ChainService: %s", err)
	}
}
|
|
|
|
// csd does a connect-sync-disconnect between nodes in order to support
|
|
// reorg testing. It brings up and tears down a temporary node, otherwise the
|
|
// nodes try to reconnect to each other which results in unintended reorgs.
|
|
func csd(harnesses []*rpctest.Harness) error {
|
|
hTemp, err := rpctest.New(&chaincfg.SimNetParams, nil, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Tear down node at the end of the function.
|
|
defer hTemp.TearDown()
|
|
err = hTemp.SetUp(false, 0)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, harness := range harnesses {
|
|
err = rpctest.ConnectNode(hTemp, harness)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return rpctest.JoinNodes(harnesses, rpctest.Blocks)
|
|
}
|
|
|
|
// waitForSync waits for the ChainService to sync to the current chain state
// of correctSyncNode, then cross-checks what it synced against that node.
//
// Polling happens every syncUpdate, and all waiting stages share a single
// syncTimeout budget. The stages are:
//  1. Wait until svc's best block hash equals the node's best block hash.
//     Fail early if svc's height exceeds the node's (wrong chain).
//  2. Require svc.IsCurrent().
//  3. Wait until svc has stored BOTH the basic and extended cfheaders for
//     the node's best block, and both match the node's values.
//  4. Compare the basic and extended cfheaders at every height from 0 to
//     svc's tip against the node.
//  5. Compare the basic and extended filters for up to 15 randomly chosen
//     heights against the node's RPC results.
//
// A non-nil error describes the first failure encountered.
func waitForSync(t *testing.T, svc *spvchain.ChainService,
	correctSyncNode *rpctest.Harness) error {
	knownBestHash, knownBestHeight, err :=
		correctSyncNode.Node.GetBestBlock()
	if err != nil {
		return err
	}
	if logLevel != btclog.Off {
		t.Logf("Syncing to %d (%s)", knownBestHeight, knownBestHash)
	}
	var haveBest *waddrmgr.BlockStamp
	haveBest, err = svc.BestSnapshot()
	if err != nil {
		return fmt.Errorf("Couldn't get best snapshot from "+
			"ChainService: %s", err)
	}
	var total time.Duration
	for haveBest.Hash != *knownBestHash {
		if total > syncTimeout {
			return fmt.Errorf("Timed out after %v waiting for "+
				"header synchronization.", syncTimeout)
		}
		if haveBest.Height > knownBestHeight {
			return fmt.Errorf("Synchronized to the wrong chain.")
		}
		time.Sleep(syncUpdate)
		total += syncUpdate
		haveBest, err = svc.BestSnapshot()
		if err != nil {
			return fmt.Errorf("Couldn't get best snapshot from "+
				"ChainService: %s", err)
		}
		if logLevel != btclog.Off {
			t.Logf("Synced to %d (%s)", haveBest.Height,
				haveBest.Hash)
		}
	}
	// Check if we're current.
	if !svc.IsCurrent() {
		return fmt.Errorf("ChainService doesn't see itself as current!")
	}
	// Check if we have all of the cfheaders.
	knownBasicHeader, err := correctSyncNode.Node.GetCFilterHeader(
		knownBestHash, false)
	if err != nil {
		return fmt.Errorf("Couldn't get latest basic header from "+
			"%s: %s", correctSyncNode.P2PAddress(), err)
	}
	knownExtHeader, err := correctSyncNode.Node.GetCFilterHeader(
		knownBestHash, true)
	if err != nil {
		return fmt.Errorf("Couldn't get latest extended header from "+
			"%s: %s", correctSyncNode.P2PAddress(), err)
	}
	// Guard the HeaderHashes[0] accesses below so an empty reply fails
	// the sync check instead of panicking.
	if len(knownBasicHeader.HeaderHashes) == 0 ||
		len(knownExtHeader.HeaderHashes) == 0 {
		return fmt.Errorf("Node %s returned no cfheaders for %s",
			correctSyncNode.P2PAddress(), knownBestHash)
	}
	wantBasicHeader := knownBasicHeader.HeaderHashes[0]
	wantExtHeader := knownExtHeader.HeaderHashes[0]
	haveBasicHeader := &chainhash.Hash{}
	haveExtHeader := &chainhash.Hash{}
	// Keep polling until BOTH headers match; stopping as soon as either
	// one matches would let a missing header slip through.
	for *haveBasicHeader != *wantBasicHeader ||
		*haveExtHeader != *wantExtHeader {
		if total > syncTimeout {
			return fmt.Errorf("Timed out after %v waiting for "+
				"cfheaders synchronization.", syncTimeout)
		}
		// Lookups are expected to fail until the cfheaders arrive.
		// On failure, keep the previous (non-nil) value so the loop
		// condition never dereferences a nil hash.
		if hdr, err := svc.GetBasicHeader(*knownBestHash); err == nil &&
			hdr != nil {
			haveBasicHeader = hdr
		}
		if hdr, err := svc.GetExtHeader(*knownBestHash); err == nil &&
			hdr != nil {
			haveExtHeader = hdr
		}
		time.Sleep(syncUpdate)
		total += syncUpdate
	}
	if logLevel != btclog.Off {
		t.Logf("Synced cfheaders to %d (%s)", haveBest.Height,
			haveBest.Hash)
	}
	// At this point, we know the latest cfheader is stored in the
	// ChainService database. We now compare each cfheader the
	// harness knows about to what's stored in the ChainService
	// database to see if we've missed anything or messed anything
	// up.
	for i := int32(0); i <= haveBest.Height; i++ {
		head, _, err := svc.GetBlockByHeight(uint32(i))
		if err != nil {
			return fmt.Errorf("Couldn't read block by "+
				"height: %s", err)
		}
		hash := head.BlockHash()
		haveBasicHeader, err = svc.GetBasicHeader(hash)
		if err != nil {
			return fmt.Errorf("Couldn't get basic header "+
				"for %d (%s) from DB: %s", i, hash, err)
		}
		haveExtHeader, err = svc.GetExtHeader(hash)
		if err != nil {
			return fmt.Errorf("Couldn't get extended "+
				"header for %d (%s) from DB: %s", i, hash, err)
		}
		knownBasicHeader, err =
			correctSyncNode.Node.GetCFilterHeader(&hash,
				false)
		if err != nil {
			return fmt.Errorf("Couldn't get basic header "+
				"for %d (%s) from node %s: %s", i, hash,
				correctSyncNode.P2PAddress(), err)
		}
		knownExtHeader, err =
			correctSyncNode.Node.GetCFilterHeader(&hash,
				true)
		if err != nil {
			return fmt.Errorf("Couldn't get extended "+
				"header for %d (%s) from node %s: %s", i,
				hash, correctSyncNode.P2PAddress(), err)
		}
		if len(knownBasicHeader.HeaderHashes) == 0 ||
			len(knownExtHeader.HeaderHashes) == 0 {
			return fmt.Errorf("Node %s returned no cfheaders "+
				"for %d (%s)", correctSyncNode.P2PAddress(),
				i, hash)
		}
		if *haveBasicHeader !=
			*knownBasicHeader.HeaderHashes[0] {
			return fmt.Errorf("Basic header for %d (%s) "+
				"doesn't match node %s. DB: %s, node: "+
				"%s", i, hash,
				correctSyncNode.P2PAddress(),
				haveBasicHeader,
				knownBasicHeader.HeaderHashes[0])
		}
		if *haveExtHeader !=
			*knownExtHeader.HeaderHashes[0] {
			return fmt.Errorf("Extended header for %d (%s)"+
				" doesn't match node %s. DB: %s, node:"+
				" %s", i, hash,
				correctSyncNode.P2PAddress(),
				haveExtHeader,
				knownExtHeader.HeaderHashes[0])
		}
	}
	// Test getting up to 15 random filters. Chains shorter than 15 blocks
	// offer fewer candidate heights, so cap the sample size rather than
	// indexing past the end of the permutation.
	heights := rand.Perm(int(haveBest.Height))
	numSamples := 15
	if numSamples > len(heights) {
		numSamples = len(heights)
	}
	for _, sampled := range heights[:numSamples] {
		height := uint32(sampled)
		block, _, err := svc.GetBlockByHeight(height)
		if err != nil {
			return fmt.Errorf("Get block by height %d:"+
				" %s", height, err)
		}
		blockHash := block.BlockHash()
		haveFilter := svc.GetCFilter(blockHash, false)
		if haveFilter == nil {
			return fmt.Errorf("Couldn't get basic "+
				"filter for block %d", height)
		}
		if logLevel != btclog.Off {
			t.Logf("%x", haveFilter.NBytes())
		}
		wantFilter, err := correctSyncNode.Node.GetCFilter(&blockHash,
			false)
		if err != nil {
			return fmt.Errorf("Couldn't get basic filter for "+
				"block %d via RPC: %s", height, err)
		}
		if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) {
			return fmt.Errorf("Basic filter from P2P network/DB"+
				" doesn't match RPC value for block %d", height)
		}
		haveFilter = svc.GetCFilter(blockHash, true)
		if haveFilter == nil {
			return fmt.Errorf("Couldn't get extended "+
				"filter for block %d", height)
		}
		if logLevel != btclog.Off {
			t.Logf("%x", haveFilter.NBytes())
		}
		wantFilter, err = correctSyncNode.Node.GetCFilter(&blockHash,
			true)
		if err != nil {
			return fmt.Errorf("Couldn't get extended filter for "+
				"block %d via RPC: %s", height, err)
		}
		if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) {
			return fmt.Errorf("Extended filter from P2P network/DB"+
				" doesn't match RPC value for block %d", height)
		}
	}
	return nil
}
|