Database update and refactor after rebase.

This commit is contained in:
Alex 2017-04-28 11:36:05 -06:00 committed by Olaoluwa Osuntokun
parent e273e178dd
commit 551a03107a
6 changed files with 1145 additions and 1090 deletions

View file

@ -934,7 +934,7 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
// with the rest of the headers in the message as if // with the rest of the headers in the message as if
// nothing has happened. // nothing has happened.
b.syncPeer = hmsg.peer b.syncPeer = hmsg.peer
_, err = b.server.rollbackToHeight(backHeight) _, err = b.server.rollBackToHeight(backHeight)
if err != nil { if err != nil {
log.Criticalf("Rollback failed: %s", log.Criticalf("Rollback failed: %s",
err) err)
@ -974,7 +974,8 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
} }
// Verify the header at the next checkpoint height matches. // Verify the header at the next checkpoint height matches.
if b.nextCheckpoint != nil && node.height == b.nextCheckpoint.Height { if b.nextCheckpoint != nil &&
node.height == b.nextCheckpoint.Height {
nodeHash := node.header.BlockHash() nodeHash := node.header.BlockHash()
if nodeHash.IsEqual(b.nextCheckpoint.Hash) { if nodeHash.IsEqual(b.nextCheckpoint.Hash) {
receivedCheckpoint = true receivedCheckpoint = true
@ -988,12 +989,14 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
"disconnecting", node.height, "disconnecting", node.height,
nodeHash, hmsg.peer.Addr(), nodeHash, hmsg.peer.Addr(),
b.nextCheckpoint.Hash) b.nextCheckpoint.Hash)
prevCheckpoint := b.findPreviousHeaderCheckpoint(node.height) prevCheckpoint :=
b.findPreviousHeaderCheckpoint(
node.height)
log.Infof("Rolling back to previous validated "+ log.Infof("Rolling back to previous validated "+
"checkpoint at height %d/hash %s", "checkpoint at height %d/hash %s",
prevCheckpoint.Height, prevCheckpoint.Height,
prevCheckpoint.Hash) prevCheckpoint.Hash)
_, err := b.server.rollbackToHeight(uint32( _, err := b.server.rollBackToHeight(uint32(
prevCheckpoint.Height)) prevCheckpoint.Height))
if err != nil { if err != nil {
log.Criticalf("Rollback failed: %s", log.Criticalf("Rollback failed: %s",
@ -1371,7 +1374,7 @@ func (b *blockManager) calcNextRequiredDifficulty(newBlockTime time.Time,
// Get the block node at the previous retarget (targetTimespan days // Get the block node at the previous retarget (targetTimespan days
// worth of blocks). // worth of blocks).
firstNode, _, err := b.server.GetBlockByHeight( firstNode, err := b.server.GetBlockByHeight(
uint32(lastNode.height + 1 - b.blocksPerRetarget)) uint32(lastNode.height + 1 - b.blocksPerRetarget))
if err != nil { if err != nil {
return 0, err return 0, err
@ -1451,7 +1454,8 @@ func (b *blockManager) findPrevTestNetDifficulty(hList *list.List) (uint32, erro
if el != nil { if el != nil {
iterNode = el.Value.(*headerNode).header iterNode = el.Value.(*headerNode).header
} else { } else {
node, _, err := b.server.GetBlockByHeight(uint32(iterHeight)) node, err := b.server.GetBlockByHeight(
uint32(iterHeight))
if err != nil { if err != nil {
log.Errorf("GetBlockByHeight: %s", err) log.Errorf("GetBlockByHeight: %s", err)
return 0, err return 0, err

View file

@ -7,7 +7,6 @@ import (
"time" "time"
"github.com/btcsuite/btcd/blockchain" "github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil/gcs" "github.com/btcsuite/btcutil/gcs"
@ -30,7 +29,7 @@ var (
// Key names for various database fields. // Key names for various database fields.
var ( var (
// Bucket names. // Bucket names.
spvBucketName = []byte("spv") spvBucketName = []byte("spvchain")
blockHeaderBucketName = []byte("bh") blockHeaderBucketName = []byte("bh")
basicHeaderBucketName = []byte("bfh") basicHeaderBucketName = []byte("bfh")
basicFilterBucketName = []byte("bf") basicFilterBucketName = []byte("bf")
@ -59,41 +58,71 @@ func uint64ToBytes(number uint64) []byte {
return buf return buf
} }
// dbUpdateOption is a function type for the kind of DB update to be done.
// These can call each other and dbViewOption functions; however, they cannot
// be called by dbViewOption functions.
type dbUpdateOption func(bucket walletdb.ReadWriteBucket) error
// dbViewOption is a funciton type for the kind of data to be fetched from DB.
// These can call each other and can be called by dbUpdateOption functions;
// however, they cannot call dbUpdateOption functions.
type dbViewOption func(bucket walletdb.ReadBucket) error
// fetchDBVersion fetches the current manager version from the database. // fetchDBVersion fetches the current manager version from the database.
func fetchDBVersion(tx walletdb.Tx) (uint32, error) { func (s *ChainService) fetchDBVersion() (uint32, error) {
bucket := tx.RootBucket().Bucket(spvBucketName) var version uint32
err := s.dbView(fetchDBVersion(&version))
return version, err
}
func fetchDBVersion(version *uint32) dbViewOption {
return func(bucket walletdb.ReadBucket) error {
verBytes := bucket.Get(dbVersionName) verBytes := bucket.Get(dbVersionName)
if verBytes == nil { if verBytes == nil {
return 0, fmt.Errorf("required version number not stored in " + return fmt.Errorf("required version number not " +
"database") "stored in database")
}
*version = binary.LittleEndian.Uint32(verBytes)
return nil
} }
version := binary.LittleEndian.Uint32(verBytes)
return version, nil
} }
// putDBVersion stores the provided version to the database. // putDBVersion stores the provided version to the database.
func putDBVersion(tx walletdb.Tx, version uint32) error { func (s *ChainService) putDBVersion(version uint32) error {
bucket := tx.RootBucket().Bucket(spvBucketName) return s.dbUpdate(putDBVersion(version))
}
func putDBVersion(version uint32) dbUpdateOption {
return func(bucket walletdb.ReadWriteBucket) error {
verBytes := uint32ToBytes(version) verBytes := uint32ToBytes(version)
return bucket.Put(dbVersionName, verBytes) return bucket.Put(dbVersionName, verBytes)
}
} }
// putMaxBlockHeight stores the max block height to the database. // putMaxBlockHeight stores the max block height to the database.
func putMaxBlockHeight(tx walletdb.Tx, maxBlockHeight uint32) error { func (s *ChainService) putMaxBlockHeight(maxBlockHeight uint32) error {
bucket := tx.RootBucket().Bucket(spvBucketName) return s.dbUpdate(putMaxBlockHeight(maxBlockHeight))
}
func putMaxBlockHeight(maxBlockHeight uint32) dbUpdateOption {
return func(bucket walletdb.ReadWriteBucket) error {
maxBlockHeightBytes := uint32ToBytes(maxBlockHeight) maxBlockHeightBytes := uint32ToBytes(maxBlockHeight)
err := bucket.Put(maxBlockHeightName, maxBlockHeightBytes) err := bucket.Put(maxBlockHeightName, maxBlockHeightBytes)
if err != nil { if err != nil {
return fmt.Errorf("failed to store max block height: %s", err) return fmt.Errorf("failed to store max block height: %s", err)
} }
return nil return nil
}
} }
// putBlock stores the provided block header and height, keyed to the block // putBlock stores the provided block header and height, keyed to the block
// hash, in the database. // hash, in the database.
func putBlock(tx walletdb.Tx, header wire.BlockHeader, height uint32) error { func (s *ChainService) putBlock(header wire.BlockHeader, height uint32) error {
return s.dbUpdate(putBlock(header, height))
}
func putBlock(header wire.BlockHeader, height uint32) dbUpdateOption {
return func(bucket walletdb.ReadWriteBucket) error {
var buf bytes.Buffer var buf bytes.Buffer
err := header.Serialize(&buf) err := header.Serialize(&buf)
if err != nil { if err != nil {
@ -103,300 +132,431 @@ func putBlock(tx walletdb.Tx, header wire.BlockHeader, height uint32) error {
if err != nil { if err != nil {
return err return err
} }
bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(blockHeaderBucketName)
blockHash := header.BlockHash() blockHash := header.BlockHash()
bhBucket := bucket.NestedReadWriteBucket(blockHeaderBucketName)
err = bucket.Put(blockHash[:], buf.Bytes()) err = bhBucket.Put(blockHash[:], buf.Bytes())
if err != nil { if err != nil {
return fmt.Errorf("failed to store SPV block info: %s", err) return fmt.Errorf("failed to store SPV block info: %s",
err)
} }
err = bhBucket.Put(uint32ToBytes(height), blockHash[:])
err = bucket.Put(uint32ToBytes(height), blockHash[:])
if err != nil { if err != nil {
return fmt.Errorf("failed to store block height info: %s", err) return fmt.Errorf("failed to store block height info:"+
" %s", err)
} }
return nil return nil
}
} }
// putFilter stores the provided filter, keyed to the block hash, in the // putFilter stores the provided filter, keyed to the block hash, in the
// appropriate filter bucket in the database. // appropriate filter bucket in the database.
func putFilter(tx walletdb.Tx, blockHash chainhash.Hash, bucketName []byte, func (s *ChainService) putFilter(blockHash chainhash.Hash, bucketName []byte,
filter *gcs.Filter) error { filter *gcs.Filter) error {
return s.dbUpdate(putFilter(blockHash, bucketName, filter))
}
func putFilter(blockHash chainhash.Hash, bucketName []byte,
filter *gcs.Filter) dbUpdateOption {
return func(bucket walletdb.ReadWriteBucket) error {
var buf bytes.Buffer var buf bytes.Buffer
_, err := buf.Write(filter.NBytes()) _, err := buf.Write(filter.NBytes())
if err != nil { if err != nil {
return err return err
} }
filterBucket := bucket.NestedReadWriteBucket(bucketName)
bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(bucketName) err = filterBucket.Put(blockHash[:], buf.Bytes())
err = bucket.Put(blockHash[:], buf.Bytes())
if err != nil { if err != nil {
return fmt.Errorf("failed to store filter: %s", err) return fmt.Errorf("failed to store filter: %s", err)
} }
return nil return nil
}
} }
// putBasicFilter stores the provided filter, keyed to the block hash, in the // putBasicFilter stores the provided filter, keyed to the block hash, in the
// basic filter bucket in the database. // basic filter bucket in the database.
func putBasicFilter(tx walletdb.Tx, blockHash chainhash.Hash, func (s *ChainService) putBasicFilter(blockHash chainhash.Hash,
filter *gcs.Filter) error { filter *gcs.Filter) error {
return putFilter(tx, blockHash, basicFilterBucketName, filter) return s.dbUpdate(putBasicFilter(blockHash, filter))
}
func putBasicFilter(blockHash chainhash.Hash,
filter *gcs.Filter) dbUpdateOption {
return putFilter(blockHash, basicFilterBucketName, filter)
} }
// putExtFilter stores the provided filter, keyed to the block hash, in the // putExtFilter stores the provided filter, keyed to the block hash, in the
// extended filter bucket in the database. // extended filter bucket in the database.
func putExtFilter(tx walletdb.Tx, blockHash chainhash.Hash, func (s *ChainService) putExtFilter(blockHash chainhash.Hash,
filter *gcs.Filter) error { filter *gcs.Filter) error {
return putFilter(tx, blockHash, extFilterBucketName, filter) return s.dbUpdate(putExtFilter(blockHash, filter))
}
func putExtFilter(blockHash chainhash.Hash,
filter *gcs.Filter) dbUpdateOption {
return putFilter(blockHash, extFilterBucketName, filter)
} }
// putHeader stores the provided header, keyed to the block hash, in the // putHeader stores the provided header, keyed to the block hash, in the
// appropriate filter header bucket in the database. // appropriate filter header bucket in the database.
func putHeader(tx walletdb.Tx, blockHash chainhash.Hash, bucketName []byte, func (s *ChainService) putHeader(blockHash chainhash.Hash, bucketName []byte,
filterTip chainhash.Hash) error { filterTip chainhash.Hash) error {
return s.dbUpdate(putHeader(blockHash, bucketName, filterTip))
}
bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(bucketName) func putHeader(blockHash chainhash.Hash, bucketName []byte,
filterTip chainhash.Hash) dbUpdateOption {
err := bucket.Put(blockHash[:], filterTip[:]) return func(bucket walletdb.ReadWriteBucket) error {
headerBucket := bucket.NestedReadWriteBucket(bucketName)
err := headerBucket.Put(blockHash[:], filterTip[:])
if err != nil { if err != nil {
return fmt.Errorf("failed to store filter header: %s", err) return fmt.Errorf("failed to store filter header: %s", err)
} }
return nil return nil
}
} }
// putBasicHeader stores the provided header, keyed to the block hash, in the // putBasicHeader stores the provided header, keyed to the block hash, in the
// basic filter header bucket in the database. // basic filter header bucket in the database.
func putBasicHeader(tx walletdb.Tx, blockHash chainhash.Hash, func (s *ChainService) putBasicHeader(blockHash chainhash.Hash,
filterTip chainhash.Hash) error { filterTip chainhash.Hash) error {
return putHeader(tx, blockHash, basicHeaderBucketName, filterTip) return s.dbUpdate(putBasicHeader(blockHash, filterTip))
}
func putBasicHeader(blockHash chainhash.Hash,
filterTip chainhash.Hash) dbUpdateOption {
return putHeader(blockHash, basicHeaderBucketName, filterTip)
} }
// putExtHeader stores the provided header, keyed to the block hash, in the // putExtHeader stores the provided header, keyed to the block hash, in the
// extended filter header bucket in the database. // extended filter header bucket in the database.
func putExtHeader(tx walletdb.Tx, blockHash chainhash.Hash, func (s *ChainService) putExtHeader(blockHash chainhash.Hash,
filterTip chainhash.Hash) error { filterTip chainhash.Hash) error {
return putHeader(tx, blockHash, extHeaderBucketName, filterTip) return s.dbUpdate(putExtHeader(blockHash, filterTip))
}
func putExtHeader(blockHash chainhash.Hash,
filterTip chainhash.Hash) dbUpdateOption {
return putHeader(blockHash, extHeaderBucketName, filterTip)
} }
// getFilter retreives the filter, keyed to the provided block hash, from the // getFilter retreives the filter, keyed to the provided block hash, from the
// appropriate filter bucket in the database. // appropriate filter bucket in the database.
func getFilter(tx walletdb.Tx, blockHash chainhash.Hash, func (s *ChainService) getFilter(blockHash chainhash.Hash,
bucketName []byte) (*gcs.Filter, error) { bucketName []byte) (*gcs.Filter, error) {
bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(bucketName) var filter gcs.Filter
err := s.dbView(getFilter(blockHash, bucketName, &filter))
return &filter, err
}
filterBytes := bucket.Get(blockHash[:]) func getFilter(blockHash chainhash.Hash, bucketName []byte,
filter *gcs.Filter) dbViewOption {
return func(bucket walletdb.ReadBucket) error {
filterBucket := bucket.NestedReadBucket(bucketName)
filterBytes := filterBucket.Get(blockHash[:])
if len(filterBytes) == 0 { if len(filterBytes) == 0 {
return nil, fmt.Errorf("failed to get filter") return fmt.Errorf("failed to get filter")
}
calcFilter, err := gcs.FromNBytes(builder.DefaultP, filterBytes)
if calcFilter != nil {
*filter = *calcFilter
}
return err
} }
return gcs.FromNBytes(builder.DefaultP, filterBytes)
} }
// getBasicFilter retrieves the filter, keyed to the provided block hash, from // GetBasicFilter retrieves the filter, keyed to the provided block hash, from
// the basic filter bucket in the database. // the basic filter bucket in the database.
func getBasicFilter(tx walletdb.Tx, blockHash chainhash.Hash) (*gcs.Filter, func (s *ChainService) GetBasicFilter(blockHash chainhash.Hash) (*gcs.Filter,
error) { error) {
return getFilter(tx, blockHash, basicFilterBucketName) var filter gcs.Filter
err := s.dbView(getBasicFilter(blockHash, &filter))
return &filter, err
} }
// getExtFilter retrieves the filter, keyed to the provided block hash, from func getBasicFilter(blockHash chainhash.Hash, filter *gcs.Filter) dbViewOption {
return getFilter(blockHash, basicFilterBucketName, filter)
}
// GetExtFilter retrieves the filter, keyed to the provided block hash, from
// the extended filter bucket in the database. // the extended filter bucket in the database.
func getExtFilter(tx walletdb.Tx, blockHash chainhash.Hash) (*gcs.Filter, func (s *ChainService) GetExtFilter(blockHash chainhash.Hash) (*gcs.Filter,
error) { error) {
return getFilter(tx, blockHash, extFilterBucketName) var filter gcs.Filter
err := s.dbView(getExtFilter(blockHash, &filter))
return &filter, err
}
func getExtFilter(blockHash chainhash.Hash, filter *gcs.Filter) dbViewOption {
return getFilter(blockHash, extFilterBucketName, filter)
} }
// getHeader retrieves the header, keyed to the provided block hash, from the // getHeader retrieves the header, keyed to the provided block hash, from the
// appropriate filter header bucket in the database. // appropriate filter header bucket in the database.
func getHeader(tx walletdb.Tx, blockHash chainhash.Hash, func (s *ChainService) getHeader(blockHash chainhash.Hash,
bucketName []byte) (*chainhash.Hash, error) { bucketName []byte) (*chainhash.Hash, error) {
var filterTip chainhash.Hash
err := s.dbView(getHeader(blockHash, bucketName, &filterTip))
return &filterTip, err
}
bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(bucketName) func getHeader(blockHash chainhash.Hash, bucketName []byte,
filterTip *chainhash.Hash) dbViewOption {
filterTip := bucket.Get(blockHash[:]) return func(bucket walletdb.ReadBucket) error {
headerBucket := bucket.NestedReadBucket(bucketName)
headerBytes := headerBucket.Get(blockHash[:])
if len(filterTip) == 0 { if len(filterTip) == 0 {
return nil, fmt.Errorf("failed to get filter header") return fmt.Errorf("failed to get filter header")
}
calcFilterTip, err := chainhash.NewHash(headerBytes)
if calcFilterTip != nil {
*filterTip = *calcFilterTip
}
return err
} }
return chainhash.NewHash(filterTip)
} }
// getBasicHeader retrieves the header, keyed to the provided block hash, from // GetBasicHeader retrieves the header, keyed to the provided block hash, from
// the basic filter header bucket in the database. // the basic filter header bucket in the database.
func getBasicHeader(tx walletdb.Tx, blockHash chainhash.Hash) (*chainhash.Hash, func (s *ChainService) GetBasicHeader(blockHash chainhash.Hash) (
error) { *chainhash.Hash, error) {
return getHeader(tx, blockHash, basicHeaderBucketName) var filterTip chainhash.Hash
err := s.dbView(getBasicHeader(blockHash, &filterTip))
return &filterTip, err
} }
// getExtHeader retrieves the header, keyed to the provided block hash, from the func getBasicHeader(blockHash chainhash.Hash,
filterTip *chainhash.Hash) dbViewOption {
return getHeader(blockHash, basicHeaderBucketName, filterTip)
}
// GetExtHeader retrieves the header, keyed to the provided block hash, from the
// extended filter header bucket in the database. // extended filter header bucket in the database.
func getExtHeader(tx walletdb.Tx, blockHash chainhash.Hash) (*chainhash.Hash, func (s *ChainService) GetExtHeader(blockHash chainhash.Hash) (*chainhash.Hash,
error) { error) {
return getHeader(tx, blockHash, extHeaderBucketName) var filterTip chainhash.Hash
err := s.dbView(getExtHeader(blockHash, &filterTip))
return &filterTip, err
} }
// rollbackLastBlock rolls back the last known block and returns the BlockStamp func getExtHeader(blockHash chainhash.Hash,
filterTip *chainhash.Hash) dbViewOption {
return getHeader(blockHash, extHeaderBucketName, filterTip)
}
// rollBackLastBlock rolls back the last known block and returns the BlockStamp
// representing the new last known block. // representing the new last known block.
func rollbackLastBlock(tx walletdb.Tx) (*waddrmgr.BlockStamp, error) { func (s *ChainService) rollBackLastBlock() (*waddrmgr.BlockStamp, error) {
bs, err := syncedTo(tx) var bs waddrmgr.BlockStamp
if err != nil { err := s.dbUpdate(rollBackLastBlock(&bs))
return nil, err return &bs, err
}
bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(blockHeaderBucketName)
err = bucket.Delete(bs.Hash[:])
if err != nil {
return nil, err
}
err = bucket.Delete(uint32ToBytes(uint32(bs.Height)))
if err != nil {
return nil, err
}
err = putMaxBlockHeight(tx, uint32(bs.Height-1))
if err != nil {
return nil, err
}
return syncedTo(tx)
} }
// getBlockByHash retrieves the block header, filter, and filter tip, based on func rollBackLastBlock(bs *waddrmgr.BlockStamp) dbUpdateOption {
// the provided block hash, from the database. return func(bucket walletdb.ReadWriteBucket) error {
func getBlockByHash(tx walletdb.Tx, blockHash chainhash.Hash) (wire.BlockHeader, headerBucket := bucket.NestedReadWriteBucket(
uint32, error) { blockHeaderBucketName)
//chainhash.Hash, chainhash.Hash, var sync waddrmgr.BlockStamp
bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(blockHeaderBucketName) err := syncedTo(&sync)(bucket)
blockBytes := bucket.Get(blockHash[:])
if len(blockBytes) == 0 {
return wire.BlockHeader{}, 0,
fmt.Errorf("failed to retrieve block info for hash: %s",
blockHash)
}
buf := bytes.NewReader(blockBytes[:wire.MaxBlockHeaderPayload])
var header wire.BlockHeader
err := header.Deserialize(buf)
if err != nil { if err != nil {
return wire.BlockHeader{}, 0, return err
fmt.Errorf("failed to deserialize block header for "+ }
"hash: %s", blockHash) err = headerBucket.Delete(sync.Hash[:])
if err != nil {
return err
}
err = headerBucket.Delete(uint32ToBytes(uint32(sync.Height)))
if err != nil {
return err
}
err = putMaxBlockHeight(uint32(sync.Height - 1))(bucket)
if err != nil {
return err
}
sync = waddrmgr.BlockStamp{}
err = syncedTo(&sync)(bucket)
if sync != (waddrmgr.BlockStamp{}) {
*bs = sync
}
return err
} }
height := binary.LittleEndian.Uint32(
blockBytes[wire.MaxBlockHeaderPayload : wire.MaxBlockHeaderPayload+4])
return header, height, nil
} }
// getBlockHashByHeight retrieves the hash of a block by its height. // rollBackToHeight rolls back all blocks until it hits the specified height.
func getBlockHashByHeight(tx walletdb.Tx, height uint32) (chainhash.Hash, func (s *ChainService) rollBackToHeight(height uint32) (*waddrmgr.BlockStamp, error) {
error) { var bs waddrmgr.BlockStamp
bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(blockHeaderBucketName) err := s.dbUpdate(rollBackToHeight(height, &bs))
var hash chainhash.Hash return &bs, err
hashBytes := bucket.Get(uint32ToBytes(height))
if hashBytes == nil {
return hash, fmt.Errorf("no block hash for height %d", height)
}
hash.SetBytes(hashBytes)
return hash, nil
} }
// getBlockByHeight retrieves a block's information by its height. func rollBackToHeight(height uint32, bs *waddrmgr.BlockStamp) dbUpdateOption {
func getBlockByHeight(tx walletdb.Tx, height uint32) (wire.BlockHeader, uint32, return func(bucket walletdb.ReadWriteBucket) error {
error) { err := syncedTo(bs)(bucket)
// chainhash.Hash, chainhash.Hash
blockHash, err := getBlockHashByHeight(tx, height)
if err != nil { if err != nil {
return wire.BlockHeader{}, 0, err return err
} }
for uint32(bs.Height) > height {
return getBlockByHash(tx, blockHash) err = rollBackLastBlock(bs)(bucket)
}
// syncedTo retrieves the most recent block's height and hash.
func syncedTo(tx walletdb.Tx) (*waddrmgr.BlockStamp, error) {
header, height, err := latestBlock(tx)
if err != nil { if err != nil {
return nil, err return err
} }
var blockStamp waddrmgr.BlockStamp
blockStamp.Hash = header.BlockHash()
blockStamp.Height = int32(height)
return &blockStamp, nil
}
// latestBlock retrieves all the info about the latest stored block.
func latestBlock(tx walletdb.Tx) (wire.BlockHeader, uint32, error) {
bucket := tx.RootBucket().Bucket(spvBucketName)
maxBlockHeightBytes := bucket.Get(maxBlockHeightName)
if maxBlockHeightBytes == nil {
return wire.BlockHeader{}, 0,
fmt.Errorf("no max block height stored")
}
maxBlockHeight := binary.LittleEndian.Uint32(maxBlockHeightBytes)
header, height, err := getBlockByHeight(tx, maxBlockHeight)
if err != nil {
return wire.BlockHeader{}, 0, err
}
if height != maxBlockHeight {
return wire.BlockHeader{}, 0,
fmt.Errorf("max block height inconsistent")
}
return header, height, nil
}
// CheckConnectivity cycles through all of the block headers, from last to
// first, and makes sure they all connect to each other.
func CheckConnectivity(tx walletdb.Tx) error {
header, height, err := latestBlock(tx)
if err != nil {
return fmt.Errorf("Couldn't retrieve latest block: %s", err)
}
for height > 0 {
newheader, newheight, err := getBlockByHash(tx,
header.PrevBlock)
if err != nil {
return fmt.Errorf("Couldn't retrieve block %s: %s",
header.PrevBlock, err)
}
if newheader.BlockHash() != header.PrevBlock {
return fmt.Errorf("Block %s doesn't match block %s's "+
"PrevBlock (%s)", newheader.BlockHash(),
header.BlockHash(), header.PrevBlock)
}
if newheight != height-1 {
return fmt.Errorf("Block %s doesn't have correct "+
"height: want %d, got %d",
newheader.BlockHash(), height-1, newheight)
}
header = newheader
height = newheight
} }
return nil return nil
}
} }
// blockLocatorFromHash returns a block locator based on the provided hash. // GetBlockByHash retrieves the block header, filter, and filter tip, based on
func blockLocatorFromHash(tx walletdb.Tx, hash chainhash.Hash) blockchain.BlockLocator { // the provided block hash, from the database.
locator := make(blockchain.BlockLocator, 0, wire.MaxBlockLocatorsPerMsg) func (s *ChainService) GetBlockByHash(blockHash chainhash.Hash) (
locator = append(locator, &hash) wire.BlockHeader, uint32, error) {
var header wire.BlockHeader
var height uint32
err := s.dbView(getBlockByHash(blockHash, &header, &height))
return header, height, err
}
func getBlockByHash(blockHash chainhash.Hash, header *wire.BlockHeader,
height *uint32) dbViewOption {
return func(bucket walletdb.ReadBucket) error {
headerBucket := bucket.NestedReadBucket(blockHeaderBucketName)
blockBytes := headerBucket.Get(blockHash[:])
if len(blockBytes) < wire.MaxBlockHeaderPayload+4 {
return fmt.Errorf("failed to retrieve block info for"+
" hash %s: want %d bytes, got %d.", blockHash,
wire.MaxBlockHeaderPayload+4, len(blockBytes))
}
buf := bytes.NewReader(blockBytes[:wire.MaxBlockHeaderPayload])
err := header.Deserialize(buf)
if err != nil {
return fmt.Errorf("failed to deserialize block header "+
"for hash: %s", blockHash)
}
*height = binary.LittleEndian.Uint32(
blockBytes[wire.MaxBlockHeaderPayload : wire.MaxBlockHeaderPayload+4])
return nil
}
}
// GetBlockHashByHeight retrieves the hash of a block by its height.
func (s *ChainService) GetBlockHashByHeight(height uint32) (chainhash.Hash,
error) {
var blockHash chainhash.Hash
err := s.dbView(getBlockHashByHeight(height, &blockHash))
return blockHash, err
}
func getBlockHashByHeight(height uint32,
blockHash *chainhash.Hash) dbViewOption {
return func(bucket walletdb.ReadBucket) error {
headerBucket := bucket.NestedReadBucket(blockHeaderBucketName)
hashBytes := headerBucket.Get(uint32ToBytes(height))
if hashBytes == nil {
return fmt.Errorf("no block hash for height %d", height)
}
blockHash.SetBytes(hashBytes)
return nil
}
}
// GetBlockByHeight retrieves a block's information by its height.
func (s *ChainService) GetBlockByHeight(height uint32) (wire.BlockHeader,
error) {
var header wire.BlockHeader
err := s.dbView(getBlockByHeight(height, &header))
return header, err
}
func getBlockByHeight(height uint32, header *wire.BlockHeader) dbViewOption {
return func(bucket walletdb.ReadBucket) error {
var blockHash chainhash.Hash
err := getBlockHashByHeight(height, &blockHash)(bucket)
if err != nil {
return err
}
var gotHeight uint32
err = getBlockByHash(blockHash, header, &gotHeight)(bucket)
if err != nil {
return err
}
if gotHeight != height {
return fmt.Errorf("Got height %d for block at "+
"requested height %d", gotHeight, height)
}
return nil
}
}
// BestSnapshot is a synonym for SyncedTo
func (s *ChainService) BestSnapshot() (*waddrmgr.BlockStamp, error) {
return s.SyncedTo()
}
// SyncedTo retrieves the most recent block's height and hash.
func (s *ChainService) SyncedTo() (*waddrmgr.BlockStamp, error) {
var bs waddrmgr.BlockStamp
err := s.dbView(syncedTo(&bs))
return &bs, err
}
func syncedTo(bs *waddrmgr.BlockStamp) dbViewOption {
return func(bucket walletdb.ReadBucket) error {
var header wire.BlockHeader
var height uint32
err := latestBlock(&header, &height)(bucket)
if err != nil {
return err
}
bs.Hash = header.BlockHash()
bs.Height = int32(height)
return nil
}
}
// LatestBlock retrieves latest stored block's header and height.
func (s *ChainService) LatestBlock() (wire.BlockHeader, uint32, error) {
var bh wire.BlockHeader
var h uint32
err := s.dbView(latestBlock(&bh, &h))
return bh, h, err
}
func latestBlock(header *wire.BlockHeader, height *uint32) dbViewOption {
return func(bucket walletdb.ReadBucket) error {
maxBlockHeightBytes := bucket.Get(maxBlockHeightName)
if maxBlockHeightBytes == nil {
return fmt.Errorf("no max block height stored")
}
*height = binary.LittleEndian.Uint32(maxBlockHeightBytes)
return getBlockByHeight(*height, header)(bucket)
}
}
// BlockLocatorFromHash returns a block locator based on the provided hash.
func (s *ChainService) BlockLocatorFromHash(hash chainhash.Hash) (
blockchain.BlockLocator, error) {
var locator blockchain.BlockLocator
err := s.dbView(blockLocatorFromHash(hash, &locator))
return locator, err
}
func blockLocatorFromHash(hash chainhash.Hash,
locator *blockchain.BlockLocator) dbViewOption {
return func(bucket walletdb.ReadBucket) error {
// Append the initial hash
*locator = append(*locator, &hash)
// If hash isn't found in DB or this is the genesis block, return // If hash isn't found in DB or this is the genesis block, return
// the locator as is // the locator as is
_, height, err := getBlockByHash(tx, hash) var header wire.BlockHeader
var height uint32
err := getBlockByHash(hash, &header, &height)(bucket)
if (err != nil) || (height == 0) { if (err != nil) || (height == 0) {
return locator return nil
} }
decrement := uint32(1) decrement := uint32(1)
for (height > 0) && (len(locator) < wire.MaxBlockLocatorsPerMsg) { for (height > 0) && (len(*locator) < wire.MaxBlockLocatorsPerMsg) {
// Decrement by 1 for the first 10 blocks, then double the // Decrement by 1 for the first 10 blocks, then double the
// jump until we get to the genesis hash // jump until we get to the genesis hash
if len(locator) > 10 { if len(*locator) > 10 {
decrement *= 2 decrement *= 2
} }
if decrement > height { if decrement > height {
@ -404,32 +564,97 @@ func blockLocatorFromHash(tx walletdb.Tx, hash chainhash.Hash) blockchain.BlockL
} else { } else {
height -= decrement height -= decrement
} }
blockHash, err := getBlockHashByHeight(tx, height) var blockHash chainhash.Hash
err := getBlockHashByHeight(height, &blockHash)(bucket)
if err != nil { if err != nil {
return locator return nil
} }
locator = append(locator, &blockHash) *locator = append(*locator, &blockHash)
} }
return nil
}
}
return locator // LatestBlockLocator returns the block locator for the latest known block
// stored in the database.
func (s *ChainService) LatestBlockLocator() (blockchain.BlockLocator, error) {
var locator blockchain.BlockLocator
err := s.dbView(latestBlockLocator(&locator))
return locator, err
}
func latestBlockLocator(locator *blockchain.BlockLocator) dbViewOption {
return func(bucket walletdb.ReadBucket) error {
var best waddrmgr.BlockStamp
err := syncedTo(&best)(bucket)
if err != nil {
return err
}
return blockLocatorFromHash(best.Hash, locator)(bucket)
}
}
// CheckConnectivity cycles through all of the block headers, from last to
// first, and makes sure they all connect to each other.
func (s *ChainService) CheckConnectivity() error {
return s.dbView(checkConnectivity())
}
func checkConnectivity() dbViewOption {
return func(bucket walletdb.ReadBucket) error {
var header wire.BlockHeader
var height uint32
err := latestBlock(&header, &height)(bucket)
if err != nil {
return fmt.Errorf("Couldn't retrieve latest block: %s",
err)
}
for height > 0 {
var newHeader wire.BlockHeader
var newHeight uint32
err := getBlockByHash(header.PrevBlock, &newHeader,
&newHeight)(bucket)
if err != nil {
return fmt.Errorf("Couldn't retrieve block %s:"+
" %s", header.PrevBlock, err)
}
if newHeader.BlockHash() != header.PrevBlock {
return fmt.Errorf("Block %s doesn't match "+
"block %s's PrevBlock (%s)",
newHeader.BlockHash(),
header.BlockHash(), header.PrevBlock)
}
if newHeight != height-1 {
return fmt.Errorf("Block %s doesn't have "+
"correct height: want %d, got %d",
newHeader.BlockHash(), height-1,
newHeight)
}
header = newHeader
height = newHeight
}
return nil
}
} }
// createSPVNS creates the initial namespace structure needed for all of the // createSPVNS creates the initial namespace structure needed for all of the
// SPV-related data. This includes things such as all of the buckets as well as // SPV-related data. This includes things such as all of the buckets as well as
// the version and creation date. // the version and creation date.
func createSPVNS(namespace walletdb.Namespace, params *chaincfg.Params) error { func (s *ChainService) createSPVNS() error {
err := namespace.Update(func(tx walletdb.Tx) error { tx, err := s.db.BeginReadWriteTx()
rootBucket := tx.RootBucket()
spvBucket, err := rootBucket.CreateBucketIfNotExists(spvBucketName)
if err != nil { if err != nil {
return fmt.Errorf("failed to create main bucket: %s", return err
err) }
spvBucket, err := tx.CreateTopLevelBucket(spvBucketName)
if err != nil {
return fmt.Errorf("failed to create main bucket: %s", err)
} }
_, err = spvBucket.CreateBucketIfNotExists(blockHeaderBucketName) _, err = spvBucket.CreateBucketIfNotExists(blockHeaderBucketName)
if err != nil { if err != nil {
return fmt.Errorf("failed to create block header "+ return fmt.Errorf("failed to create block header bucket: %s",
"bucket: %s", err) err)
} }
_, err = spvBucket.CreateBucketIfNotExists(basicFilterBucketName) _, err = spvBucket.CreateBucketIfNotExists(basicFilterBucketName)
@ -465,53 +690,55 @@ func createSPVNS(namespace walletdb.Namespace, params *chaincfg.Params) error {
log.Info("Creating wallet SPV namespace.") log.Info("Creating wallet SPV namespace.")
basicFilter, err := builder.BuildBasicFilter( basicFilter, err := builder.BuildBasicFilter(
params.GenesisBlock) s.chainParams.GenesisBlock)
if err != nil { if err != nil {
return err return err
} }
basicFilterTip := builder.MakeHeaderForFilter(basicFilter, basicFilterTip := builder.MakeHeaderForFilter(basicFilter,
params.GenesisBlock.Header.PrevBlock) s.chainParams.GenesisBlock.Header.PrevBlock)
extFilter, err := builder.BuildExtFilter(params.GenesisBlock) extFilter, err := builder.BuildExtFilter(
s.chainParams.GenesisBlock)
if err != nil { if err != nil {
return err return err
} }
extFilterTip := builder.MakeHeaderForFilter(extFilter, extFilterTip := builder.MakeHeaderForFilter(extFilter,
params.GenesisBlock.Header.PrevBlock) s.chainParams.GenesisBlock.Header.PrevBlock)
err = putBlock(tx, params.GenesisBlock.Header, 0) err = putBlock(s.chainParams.GenesisBlock.Header, 0)(spvBucket)
if err != nil { if err != nil {
return err return err
} }
err = putBasicFilter(tx, *params.GenesisHash, basicFilter) err = putBasicFilter(*s.chainParams.GenesisHash, basicFilter)(spvBucket)
if err != nil { if err != nil {
return err return err
} }
err = putBasicHeader(tx, *params.GenesisHash, basicFilterTip) err = putBasicHeader(*s.chainParams.GenesisHash, basicFilterTip)(
spvBucket)
if err != nil { if err != nil {
return err return err
} }
err = putExtFilter(tx, *params.GenesisHash, extFilter) err = putExtFilter(*s.chainParams.GenesisHash, extFilter)(spvBucket)
if err != nil { if err != nil {
return err return err
} }
err = putExtHeader(tx, *params.GenesisHash, extFilterTip) err = putExtHeader(*s.chainParams.GenesisHash, extFilterTip)(spvBucket)
if err != nil { if err != nil {
return err return err
} }
err = putDBVersion(tx, latestDBVersion) err = putDBVersion(latestDBVersion)(spvBucket)
if err != nil { if err != nil {
return err return err
} }
err = putMaxBlockHeight(tx, 0) err = putMaxBlockHeight(0)(spvBucket)
if err != nil { if err != nil {
return err return err
} }
@ -523,11 +750,33 @@ func createSPVNS(namespace walletdb.Namespace, params *chaincfg.Params) error {
"time: %s", err) "time: %s", err)
} }
return nil return tx.Commit()
}) }
if err != nil {
return fmt.Errorf("failed to update database: %s", err) // dbUpdate allows the passed function to update the ChainService DB bucket.
} func (s *ChainService) dbUpdate(updateFunc dbUpdateOption) error {
tx, err := s.db.BeginReadWriteTx()
return nil if err != nil {
tx.Rollback()
return err
}
bucket := tx.ReadWriteBucket(spvBucketName)
err = updateFunc(bucket)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit()
}
// dbView allows the passed function to read the ChainService DB bucket.
func (s *ChainService) dbView(viewFunc dbViewOption) error {
tx, err := s.db.BeginReadTx()
defer tx.Rollback()
if err != nil {
return err
}
bucket := tx.ReadBucket(spvBucketName)
return viewFunc(bucket)
} }

View file

@ -162,3 +162,112 @@ func (s *ChainService) handleQuery(state *peerState, querymsg interface{}) {
// useful in the future. // useful in the future.
} }
} }
// ConnectedCount returns the number of currently connected peers.
func (s *ChainService) ConnectedCount() int32 {
replyChan := make(chan int32)
s.query <- getConnCountMsg{reply: replyChan}
return <-replyChan
}
// OutboundGroupCount returns the number of peers connected to the given
// outbound group key.
func (s *ChainService) OutboundGroupCount(key string) int {
replyChan := make(chan int)
s.query <- getOutboundGroup{key: key, reply: replyChan}
return <-replyChan
}
// AddedNodeInfo returns an array of btcjson.GetAddedNodeInfoResult structures
// describing the persistent (added) nodes.
func (s *ChainService) AddedNodeInfo() []*serverPeer {
replyChan := make(chan []*serverPeer)
s.query <- getAddedNodesMsg{reply: replyChan}
return <-replyChan
}
// Peers returns an array of all connected peers.
func (s *ChainService) Peers() []*serverPeer {
replyChan := make(chan []*serverPeer)
s.query <- getPeersMsg{reply: replyChan}
return <-replyChan
}
// DisconnectNodeByAddr disconnects a peer by target address. Both outbound and
// inbound nodes will be searched for the target node. An error message will
// be returned if the peer was not found.
func (s *ChainService) DisconnectNodeByAddr(addr string) error {
replyChan := make(chan error)
s.query <- disconnectNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.Addr() == addr },
reply: replyChan,
}
return <-replyChan
}
// DisconnectNodeByID disconnects a peer by target node id. Both outbound and
// inbound nodes will be searched for the target node. An error message will be
// returned if the peer was not found.
func (s *ChainService) DisconnectNodeByID(id int32) error {
replyChan := make(chan error)
s.query <- disconnectNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.ID() == id },
reply: replyChan,
}
return <-replyChan
}
// RemoveNodeByAddr removes a peer from the list of persistent peers if
// present. An error will be returned if the peer was not found.
func (s *ChainService) RemoveNodeByAddr(addr string) error {
replyChan := make(chan error)
s.query <- removeNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.Addr() == addr },
reply: replyChan,
}
return <-replyChan
}
// RemoveNodeByID removes a peer by node ID from the list of persistent peers
// if present. An error will be returned if the peer was not found.
func (s *ChainService) RemoveNodeByID(id int32) error {
replyChan := make(chan error)
s.query <- removeNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.ID() == id },
reply: replyChan,
}
return <-replyChan
}
// ConnectNode adds `addr' as a new outbound peer. If permanent is true then the
// peer will be persistent and reconnect if the connection is lost.
// It is an error to call this with an already existing peer.
func (s *ChainService) ConnectNode(addr string, permanent bool) error {
replyChan := make(chan error)
s.query <- connectNodeMsg{addr: addr, permanent: permanent, reply: replyChan}
return <-replyChan
}
// ForAllPeers runs a closure over all peers (outbound and persistent) to which
// the ChainService is connected. Nothing is returned because the peerState's
// ForAllPeers method doesn't return anything as the closure passed to it
// doesn't return anything.
func (s *ChainService) ForAllPeers(closure func(sp *serverPeer)) {
s.query <- forAllPeersMsg{
closure: closure,
}
}

409
spvsvc/spvchain/query.go Normal file
View file

@ -0,0 +1,409 @@
package spvchain
import (
"sync"
"time"
"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/btcsuite/btcutil/gcs"
"github.com/btcsuite/btcutil/gcs/builder"
)
var (
// QueryTimeout specifies how long to wait for a peer to answer a query.
QueryTimeout = time.Second * 3
// QueryNumRetries specifies how many times to retry sending a query to
// each peer before we've concluded we aren't going to get a valid
// response. This allows to make up for missed messages in some
// instances.
QueryNumRetries = 2
)
// Query options can be modified per-query, unlike global options.
// TODO: Make more query options that override global options.
type queryOptions struct {
// timeout lets the query know how long to wait for a peer to
// answer the query before moving onto the next peer.
timeout time.Duration
// numRetries tells the query how many times to retry asking each peer
// the query.
numRetries uint8
// doneChan lets the query signal the caller when it's done, in case
// it's run in a goroutine.
doneChan chan<- struct{}
}
// QueryOption is a functional option argument to any of the network query
// methods, such as GetBlockFromNetwork and GetCFilter (when that resorts to a
// network query).
type QueryOption func(*queryOptions)
// defaultQueryOptions returns a queryOptions set to package-level defaults.
func defaultQueryOptions() *queryOptions {
return &queryOptions{
timeout: QueryTimeout,
numRetries: uint8(QueryNumRetries),
}
}
// Timeout is a query option that lets the query know how long to wait for
// each peer we ask the query to answer it before moving on.
func Timeout(timeout time.Duration) QueryOption {
return func(qo *queryOptions) {
qo.timeout = timeout
}
}
// NumRetries is a query option that lets the query know the maximum number of
// times each peer should be queried. The default is one.
func NumRetries(numRetries uint8) QueryOption {
return func(qo *queryOptions) {
qo.numRetries = numRetries
}
}
// DoneChan allows the caller to pass a channel that will get closed when the
// query is finished.
func DoneChan(doneChan chan<- struct{}) QueryOption {
return func(qo *queryOptions) {
qo.doneChan = doneChan
}
}
type spMsg struct {
sp *serverPeer
msg wire.Message
}
type spMsgSubscription struct {
msgChan chan<- spMsg
quitChan <-chan struct{}
wg *sync.WaitGroup
}
// queryPeers is a helper function that sends a query to one or more peers and
// waits for an answer. The timeout for queries is set by the QueryTimeout
// package-level variable.
func (s *ChainService) queryPeers(
// queryMsg is the message to send to each peer selected by selectPeer.
queryMsg wire.Message,
// checkResponse is caled for every message within the timeout period.
// The quit channel lets the query know to terminate because the
// required response has been found. This is done by closing the
// channel.
checkResponse func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}),
// options takes functional options for executing the query.
options ...QueryOption,
) {
qo := defaultQueryOptions()
for _, option := range options {
option(qo)
}
// This is done in a single-threaded query because the peerState is held
// in a single thread. This is the only part of the query framework that
// requires access to peerState, so it's done once per query.
peers := s.Peers()
// This will be shared state between the per-peer goroutines.
quit := make(chan struct{})
allQuit := make(chan struct{})
startQuery := make(chan struct{})
var wg sync.WaitGroup
// Increase this number to be able to handle more queries at once as
// each channel gets results for all queries, otherwise messages can
// get mixed and there's a vicious cycle of retries causing a bigger
// message flood, more of which get missed.
msgChan := make(chan spMsg)
var subwg sync.WaitGroup
subscription := spMsgSubscription{
msgChan: msgChan,
quitChan: allQuit,
wg: &subwg,
}
// Start a goroutine for each peer that potentially queries that peer.
for _, sp := range peers {
wg.Add(1)
go func(sp *serverPeer) {
numRetries := qo.numRetries
defer wg.Done()
defer sp.unsubscribeRecvMsgs(subscription)
// Should we do this when the goroutine gets a message
// via startQuery rather than at the launch of the
// goroutine?
if !sp.Connected() {
return
}
timeout := make(<-chan time.Time)
for {
select {
case <-timeout:
// After timeout, we try to notify
// another of our peer goroutines to
// do a query until we get a signal to
// quit.
select {
case startQuery <- struct{}{}:
case <-quit:
return
case <-allQuit:
return
}
// At this point, we've sent startQuery.
// We return if we've run through this
// section of code numRetries times.
if numRetries--; numRetries == 0 {
return
}
case <-quit:
// After we're told to quit, we return.
return
case <-allQuit:
// After we're told to quit, we return.
return
case <-startQuery:
// We're the lucky peer whose turn it is
// to try to answer the current query.
// TODO: Fix this to support either
// querying *all* peers simultaneously
// to avoid timeout delays, or starting
// with the syncPeer when not querying
// *all* peers.
sp.subscribeRecvMsg(subscription)
// Don't want the peer hanging on send
// to the channel if we quit before
// reading the channel.
sentChan := make(chan struct{}, 1)
sp.QueueMessage(queryMsg, sentChan)
select {
case <-sentChan:
case <-quit:
return
case <-allQuit:
return
}
timeout = time.After(qo.timeout)
default:
}
}
}(sp)
}
startQuery <- struct{}{}
// This goroutine will wait until all of the peer-query goroutines have
// terminated, and then initiate a query shutdown.
go func() {
wg.Wait()
// If we timed out on each goroutine and didn't quit or time out
// on the main goroutine, make sure our main goroutine knows to
// quit.
select {
case <-allQuit:
default:
close(allQuit)
}
// Close the done channel, if any
if qo.doneChan != nil {
close(qo.doneChan)
}
// Wait until all goroutines started by subscriptions have
// exited after we closed allQuit before letting the message
// channel get garbage collected.
subwg.Wait()
}()
// Loop for any messages sent to us via our subscription channel and
// check them for whether they satisfy the query. Break the loop if it's
// time to quit.
timeout := time.After(time.Duration(len(peers)+1) *
qo.timeout * time.Duration(qo.numRetries))
checkResponses:
for {
select {
case <-timeout:
// When we time out, close the allQuit channel
// if it hasn't already been closed.
select {
case <-allQuit:
default:
close(allQuit)
}
break checkResponses
case <-quit:
break checkResponses
case <-allQuit:
break checkResponses
case sm := <-msgChan:
// TODO: This will get stuck if checkResponse
// gets stuck. This is a caveat for callers that
// should be fixed before exposing this function
// for public use.
checkResponse(sm.sp, sm.msg, quit)
}
}
}
// GetCFilter gets a cfilter from the database. Failing that, it requests the
// cfilter from the network and writes it to the database.
func (s *ChainService) GetCFilter(blockHash chainhash.Hash,
extended bool, options ...QueryOption) *gcs.Filter {
getFilter := s.GetBasicFilter
getHeader := s.GetBasicHeader
putFilter := s.putBasicFilter
if extended {
getFilter = s.GetExtFilter
getHeader = s.GetExtHeader
putFilter = s.putExtFilter
}
filter, err := getFilter(blockHash)
if err == nil && filter != nil {
return filter
}
block, _, err := s.GetBlockByHash(blockHash)
if err != nil || block.BlockHash() != blockHash {
return nil
}
curHeader, err := getHeader(blockHash)
if err != nil {
return nil
}
prevHeader, err := getHeader(block.PrevBlock)
if err != nil {
return nil
}
s.queryPeers(
// Send a wire.GetCFilterMsg
wire.NewMsgGetCFilter(&blockHash, extended),
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "cfilter" messages.
case *wire.MsgCFilter:
if len(response.Data) < 4 {
// Filter data is too short.
// Ignore this message.
return
}
if blockHash != response.BlockHash {
// The response doesn't match our
// request. Ignore this message.
return
}
gotFilter, err :=
gcs.FromNBytes(builder.DefaultP,
response.Data)
if err != nil {
// Malformed filter data. We
// can ignore this message.
return
}
if builder.MakeHeaderForFilter(gotFilter,
*prevHeader) !=
*curHeader {
// Filter data doesn't match
// the headers we know about.
// Ignore this response.
return
}
// At this point, the filter matches
// what we know about it and we declare
// it sane. We can kill the query and
// pass the response back to the caller.
close(quit)
filter = gotFilter
default:
}
},
options...,
)
// If we've found a filter, write it to the database for next time.
if filter != nil {
putFilter(blockHash, filter)
log.Tracef("Wrote filter for block %s, extended: %t",
blockHash, extended)
}
return filter
}
// GetBlockFromNetwork gets a block by requesting it from the network, one peer
// at a time, until one answers.
func (s *ChainService) GetBlockFromNetwork(
blockHash chainhash.Hash, options ...QueryOption) *btcutil.Block {
blockHeader, height, err := s.GetBlockByHash(blockHash)
if err != nil || blockHeader.BlockHash() != blockHash {
return nil
}
getData := wire.NewMsgGetData()
getData.AddInvVect(wire.NewInvVect(wire.InvTypeBlock,
&blockHash))
// The block is only updated from the checkResponse function argument,
// which is always called single-threadedly. We don't check the block
// until after the query is finished, so we can just write to it
// naively.
var foundBlock *btcutil.Block
s.queryPeers(
// Send a wire.GetCFilterMsg
getData,
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "block" messages.
case *wire.MsgBlock:
// If this isn't our block, ignore it.
if response.BlockHash() !=
blockHash {
return
}
block := btcutil.NewBlock(response)
// Only set height if btcutil hasn't
// automagically put one in.
if block.Height() ==
btcutil.BlockHeightUnknown {
block.SetHeight(
int32(height))
}
// If this claims our block but doesn't
// pass the sanity check, the peer is
// trying to bamboozle us. Disconnect
// it.
if err := blockchain.CheckBlockSanity(
block,
// We don't need to check PoW
// because by the time we get
// here, it's been checked
// during header synchronization
s.chainParams.PowLimit,
s.timeSource,
); err != nil {
log.Warnf("Invalid block for %s "+
"received from %s -- "+
"disconnecting peer", blockHash,
sp.Addr())
sp.Disconnect()
return
}
// At this point, the block matches what we know
// about it and we declare it sane. We can kill
// the query and pass the response back to the
// caller.
close(quit)
foundBlock = block
default:
}
},
options...,
)
return foundBlock
}

View file

@ -17,9 +17,6 @@ import (
"github.com/btcsuite/btcd/peer" "github.com/btcsuite/btcd/peer"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/btcsuite/btcutil/gcs"
"github.com/btcsuite/btcutil/gcs/builder"
"github.com/btcsuite/btcwallet/waddrmgr"
"github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/wallet"
"github.com/btcsuite/btcwallet/walletdb" "github.com/btcsuite/btcwallet/walletdb"
) )
@ -63,15 +60,6 @@ var (
// DisableDNSSeed disables getting initial addresses for Bitcoin nodes // DisableDNSSeed disables getting initial addresses for Bitcoin nodes
// from DNS. // from DNS.
DisableDNSSeed = false DisableDNSSeed = false
// QueryTimeout specifies how long to wait for a peer to answer a query.
QueryTimeout = time.Second * 3
// QueryNumRetries specifies how many times to retry sending a query to
// each peer before we've concluded we aren't going to get a valid
// response. This allows to make up for missed messages in some
// instances.
QueryNumRetries = 2
) )
// updatePeerHeightsMsg is a message sent from the blockmanager to the server // updatePeerHeightsMsg is a message sent from the blockmanager to the server
@ -538,7 +526,7 @@ type ChainService struct {
started int32 started int32
shutdown int32 shutdown int32
namespace walletdb.Namespace db walletdb.DB
chainParams chaincfg.Params chainParams chaincfg.Params
addrManager *addrmgr.AddrManager addrManager *addrmgr.AddrManager
connManager *connmgr.ConnManager connManager *connmgr.ConnManager
@ -562,39 +550,6 @@ func (s *ChainService) BanPeer(sp *serverPeer) {
s.banPeers <- sp s.banPeers <- sp
} }
// BestSnapshot returns the best block hash and height known to the database.
func (s *ChainService) BestSnapshot() (*waddrmgr.BlockStamp, error) {
var best *waddrmgr.BlockStamp
var err error
err = s.namespace.View(func(tx walletdb.Tx) error {
best, err = syncedTo(tx)
return err
})
if err != nil {
return nil, err
}
return best, nil
}
// LatestBlockLocator returns the block locator for the latest known block
// stored in the database.
func (s *ChainService) LatestBlockLocator() (blockchain.BlockLocator, error) {
var locator blockchain.BlockLocator
var err error
err = s.namespace.View(func(tx walletdb.Tx) error {
best, err := syncedTo(tx)
if err != nil {
return err
}
locator = blockLocatorFromHash(tx, best.Hash)
return nil
})
if err != nil {
return nil, err
}
return locator, nil
}
// AddPeer adds a new peer that has already been connected to the server. // AddPeer adds a new peer that has already been connected to the server.
func (s *ChainService) AddPeer(sp *serverPeer) { func (s *ChainService) AddPeer(sp *serverPeer) {
s.newPeers <- sp s.newPeers <- sp
@ -735,7 +690,8 @@ cleanup:
// Config is a struct detailing the configuration of the chain service. // Config is a struct detailing the configuration of the chain service.
type Config struct { type Config struct {
DataDir string DataDir string
Namespace walletdb.Namespace Database walletdb.DB
Namespace []byte
ChainParams chaincfg.Params ChainParams chaincfg.Params
ConnectPeers []string ConnectPeers []string
AddPeers []string AddPeers []string
@ -756,14 +712,14 @@ func NewChainService(cfg Config) (*ChainService, error) {
query: make(chan interface{}), query: make(chan interface{}),
quit: make(chan struct{}), quit: make(chan struct{}),
peerHeightsUpdate: make(chan updatePeerHeightsMsg), peerHeightsUpdate: make(chan updatePeerHeightsMsg),
namespace: cfg.Namespace, db: cfg.Database,
timeSource: blockchain.NewMedianTime(), timeSource: blockchain.NewMedianTime(),
services: Services, services: Services,
userAgentName: UserAgentName, userAgentName: UserAgentName,
userAgentVersion: UserAgentVersion, userAgentVersion: UserAgentVersion,
} }
err := createSPVNS(s.namespace, &s.chainParams) err := s.createSPVNS()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1168,115 +1124,6 @@ func (s *ChainService) peerDoneHandler(sp *serverPeer) {
close(sp.quit) close(sp.quit)
} }
// ConnectedCount returns the number of currently connected peers.
func (s *ChainService) ConnectedCount() int32 {
replyChan := make(chan int32)
s.query <- getConnCountMsg{reply: replyChan}
return <-replyChan
}
// OutboundGroupCount returns the number of peers connected to the given
// outbound group key.
func (s *ChainService) OutboundGroupCount(key string) int {
replyChan := make(chan int)
s.query <- getOutboundGroup{key: key, reply: replyChan}
return <-replyChan
}
// AddedNodeInfo returns an array of btcjson.GetAddedNodeInfoResult structures
// describing the persistent (added) nodes.
func (s *ChainService) AddedNodeInfo() []*serverPeer {
replyChan := make(chan []*serverPeer)
s.query <- getAddedNodesMsg{reply: replyChan}
return <-replyChan
}
// Peers returns an array of all connected peers.
func (s *ChainService) Peers() []*serverPeer {
replyChan := make(chan []*serverPeer)
s.query <- getPeersMsg{reply: replyChan}
return <-replyChan
}
// DisconnectNodeByAddr disconnects a peer by target address. Both outbound and
// inbound nodes will be searched for the target node. An error message will
// be returned if the peer was not found.
func (s *ChainService) DisconnectNodeByAddr(addr string) error {
replyChan := make(chan error)
s.query <- disconnectNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.Addr() == addr },
reply: replyChan,
}
return <-replyChan
}
// DisconnectNodeByID disconnects a peer by target node id. Both outbound and
// inbound nodes will be searched for the target node. An error message will be
// returned if the peer was not found.
func (s *ChainService) DisconnectNodeByID(id int32) error {
replyChan := make(chan error)
s.query <- disconnectNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.ID() == id },
reply: replyChan,
}
return <-replyChan
}
// RemoveNodeByAddr removes a peer from the list of persistent peers if
// present. An error will be returned if the peer was not found.
func (s *ChainService) RemoveNodeByAddr(addr string) error {
replyChan := make(chan error)
s.query <- removeNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.Addr() == addr },
reply: replyChan,
}
return <-replyChan
}
// RemoveNodeByID removes a peer by node ID from the list of persistent peers
// if present. An error will be returned if the peer was not found.
func (s *ChainService) RemoveNodeByID(id int32) error {
replyChan := make(chan error)
s.query <- removeNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.ID() == id },
reply: replyChan,
}
return <-replyChan
}
// ConnectNode adds `addr' as a new outbound peer. If permanent is true then the
// peer will be persistent and reconnect if the connection is lost.
// It is an error to call this with an already existing peer.
func (s *ChainService) ConnectNode(addr string, permanent bool) error {
replyChan := make(chan error)
s.query <- connectNodeMsg{addr: addr, permanent: permanent, reply: replyChan}
return <-replyChan
}
// ForAllPeers runs a closure over all peers (outbound and persistent) to which
// the ChainService is connected. Nothing is returned because the peerState's
// ForAllPeers method doesn't return anything as the closure passed to it
// doesn't return anything.
func (s *ChainService) ForAllPeers(closure func(sp *serverPeer)) {
s.query <- forAllPeersMsg{
closure: closure,
}
}
// UpdatePeerHeights updates the heights of all peers who have have announced // UpdatePeerHeights updates the heights of all peers who have have announced
// the latest connected main chain block, or a recognized orphan. These height // the latest connected main chain block, or a recognized orphan. These height
// updates allow us to dynamically refresh peer heights, ensuring sync peer // updates allow us to dynamically refresh peer heights, ensuring sync peer
@ -1376,562 +1223,8 @@ func (s *ChainService) Stop() error {
return nil return nil
} }
// GetBlockByHeight gets block information from the ChainService database by
// its height.
func (s *ChainService) GetBlockByHeight(height uint32) (wire.BlockHeader,
uint32, error) {
var bh wire.BlockHeader
var h uint32
var err error
err = s.namespace.View(func(dbTx walletdb.Tx) error {
bh, h, err = getBlockByHeight(dbTx, height)
return err
})
return bh, h, err
}
// GetBlockByHash gets block information from the ChainService database by its
// hash.
func (s *ChainService) GetBlockByHash(hash chainhash.Hash) (wire.BlockHeader,
uint32, error) {
var bh wire.BlockHeader
var h uint32
var err error
err = s.namespace.View(func(dbTx walletdb.Tx) error {
bh, h, err = getBlockByHash(dbTx, hash)
return err
})
return bh, h, err
}
// LatestBlock gets the latest block's information from the ChainService
// database.
func (s *ChainService) LatestBlock() (wire.BlockHeader, uint32, error) {
var bh wire.BlockHeader
var h uint32
var err error
err = s.namespace.View(func(dbTx walletdb.Tx) error {
bh, h, err = latestBlock(dbTx)
return err
})
return bh, h, err
}
// putBlock puts a verified block header and height in the ChainService
// database.
func (s *ChainService) putBlock(header wire.BlockHeader, height uint32) error {
return s.namespace.Update(func(dbTx walletdb.Tx) error {
return putBlock(dbTx, header, height)
})
}
// putBasicHeader puts a verified basic filter header in the ChainService
// database.
func (s *ChainService) putBasicHeader(blockHash chainhash.Hash,
filterTip chainhash.Hash) error {
return s.namespace.Update(func(dbTx walletdb.Tx) error {
return putBasicHeader(dbTx, blockHash, filterTip)
})
}
// putExtHeader puts a verified extended filter header in the ChainService
// database.
func (s *ChainService) putExtHeader(blockHash chainhash.Hash,
filterTip chainhash.Hash) error {
return s.namespace.Update(func(dbTx walletdb.Tx) error {
return putExtHeader(dbTx, blockHash, filterTip)
})
}
// GetBasicHeader gets a verified basic filter header from the ChainService
// database.
func (s *ChainService) GetBasicHeader(blockHash chainhash.Hash) (*chainhash.Hash,
error) {
var filterTip *chainhash.Hash
var err error
err = s.namespace.View(func(dbTx walletdb.Tx) error {
filterTip, err = getBasicHeader(dbTx, blockHash)
return err
})
return filterTip, err
}
// GetExtHeader gets a verified extended filter header from the ChainService
// database.
func (s *ChainService) GetExtHeader(blockHash chainhash.Hash) (*chainhash.Hash,
error) {
var filterTip *chainhash.Hash
var err error
err = s.namespace.View(func(dbTx walletdb.Tx) error {
filterTip, err = getExtHeader(dbTx, blockHash)
return err
})
return filterTip, err
}
// putBasicFilter puts a verified basic filter in the ChainService database.
func (s *ChainService) putBasicFilter(blockHash chainhash.Hash,
filter *gcs.Filter) error {
return s.namespace.Update(func(dbTx walletdb.Tx) error {
return putBasicFilter(dbTx, blockHash, filter)
})
}
// putExtFilter puts a verified extended filter in the ChainService database.
func (s *ChainService) putExtFilter(blockHash chainhash.Hash,
filter *gcs.Filter) error {
return s.namespace.Update(func(dbTx walletdb.Tx) error {
return putExtFilter(dbTx, blockHash, filter)
})
}
// GetBasicFilter gets a verified basic filter from the ChainService database.
func (s *ChainService) GetBasicFilter(blockHash chainhash.Hash) (*gcs.Filter,
error) {
var filter *gcs.Filter
var err error
err = s.namespace.View(func(dbTx walletdb.Tx) error {
filter, err = getBasicFilter(dbTx, blockHash)
return err
})
return filter, err
}
// GetExtFilter gets a verified extended filter from the ChainService database.
func (s *ChainService) GetExtFilter(blockHash chainhash.Hash) (*gcs.Filter,
error) {
var filter *gcs.Filter
var err error
err = s.namespace.View(func(dbTx walletdb.Tx) error {
filter, err = getExtFilter(dbTx, blockHash)
return err
})
return filter, err
}
// putMaxBlockHeight puts the max block height to the ChainService database.
func (s *ChainService) putMaxBlockHeight(maxBlockHeight uint32) error {
return s.namespace.Update(func(dbTx walletdb.Tx) error {
return putMaxBlockHeight(dbTx, maxBlockHeight)
})
}
func (s *ChainService) rollbackLastBlock() (*waddrmgr.BlockStamp, error) {
var bs *waddrmgr.BlockStamp
var err error
err = s.namespace.Update(func(dbTx walletdb.Tx) error {
bs, err = rollbackLastBlock(dbTx)
return err
})
return bs, err
}
func (s *ChainService) rollbackToHeight(height uint32) (*waddrmgr.BlockStamp, error) {
var bs *waddrmgr.BlockStamp
var err error
err = s.namespace.Update(func(dbTx walletdb.Tx) error {
bs, err = syncedTo(dbTx)
if err != nil {
return err
}
for uint32(bs.Height) > height {
bs, err = rollbackLastBlock(dbTx)
if err != nil {
return err
}
}
return nil
})
return bs, err
}
// IsCurrent lets the caller know whether the chain service's block manager // IsCurrent lets the caller know whether the chain service's block manager
// thinks its view of the network is current. // thinks its view of the network is current.
func (s *ChainService) IsCurrent() bool { func (s *ChainService) IsCurrent() bool {
return s.blockManager.IsCurrent() return s.blockManager.IsCurrent()
} }
// Query options can be modified per-query, unlike global options.
// TODO: Make more query options that override global options.
type queryOptions struct {
// timeout lets the query know how long to wait for a peer to
// answer the query before moving onto the next peer.
timeout time.Duration
// numRetries tells the query how many times to retry asking each peer
// the query.
numRetries uint8
// doneChan lets the query signal the caller when it's done, in case
// it's run in a goroutine.
doneChan chan<- struct{}
}
// QueryOption is a functional option argument to any of the network query
// methods, such as GetBlockFromNetwork and GetCFilter (when that resorts to a
// network query).
type QueryOption func(*queryOptions)
// defaultQueryOptions returns a queryOptions set to package-level defaults.
func defaultQueryOptions() *queryOptions {
return &queryOptions{
timeout: QueryTimeout,
numRetries: uint8(QueryNumRetries),
}
}
// Timeout is a query option that lets the query know how long to wait for
// each peer we ask the query to answer it before moving on.
func Timeout(timeout time.Duration) QueryOption {
return func(qo *queryOptions) {
qo.timeout = timeout
}
}
// NumRetries is a query option that lets the query know the maximum number of
// times each peer should be queried. The default is one.
func NumRetries(numRetries uint8) QueryOption {
return func(qo *queryOptions) {
qo.numRetries = numRetries
}
}
// DoneChan allows the caller to pass a channel that will get closed when the
// query is finished.
func DoneChan(doneChan chan<- struct{}) QueryOption {
return func(qo *queryOptions) {
qo.doneChan = doneChan
}
}
type spMsg struct {
sp *serverPeer
msg wire.Message
}
type spMsgSubscription struct {
msgChan chan<- spMsg
quitChan <-chan struct{}
wg *sync.WaitGroup
}
// queryPeers is a helper function that sends a query to one or more peers and
// waits for an answer. The timeout for queries is set by the QueryTimeout
// package-level variable.
func (s *ChainService) queryPeers(
// queryMsg is the message to send to each peer selected by selectPeer.
queryMsg wire.Message,
// checkResponse is caled for every message within the timeout period.
// The quit channel lets the query know to terminate because the
// required response has been found. This is done by closing the
// channel.
checkResponse func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}),
// options takes functional options for executing the query.
options ...QueryOption,
) {
qo := defaultQueryOptions()
for _, option := range options {
option(qo)
}
// This is done in a single-threaded query because the peerState is held
// in a single thread. This is the only part of the query framework that
// requires access to peerState, so it's done once per query.
peers := s.Peers()
// This will be shared state between the per-peer goroutines.
quit := make(chan struct{})
allQuit := make(chan struct{})
startQuery := make(chan struct{})
var wg sync.WaitGroup
// Increase this number to be able to handle more queries at once as
// each channel gets results for all queries, otherwise messages can
// get mixed and there's a vicious cycle of retries causing a bigger
// message flood, more of which get missed.
msgChan := make(chan spMsg)
var subwg sync.WaitGroup
subscription := spMsgSubscription{
msgChan: msgChan,
quitChan: allQuit,
wg: &subwg,
}
// Start a goroutine for each peer that potentially queries that peer.
for _, sp := range peers {
wg.Add(1)
go func(sp *serverPeer) {
numRetries := qo.numRetries
defer wg.Done()
defer sp.unsubscribeRecvMsgs(subscription)
// Should we do this when the goroutine gets a message
// via startQuery rather than at the launch of the
// goroutine?
if !sp.Connected() {
return
}
timeout := make(<-chan time.Time)
for {
select {
case <-timeout:
// After timeout, we try to notify
// another of our peer goroutines to
// do a query until we get a signal to
// quit.
select {
case startQuery <- struct{}{}:
case <-quit:
return
case <-allQuit:
return
}
// At this point, we've sent startQuery.
// We return if we've run through this
// section of code numRetries times.
if numRetries--; numRetries == 0 {
return
}
case <-quit:
// After we're told to quit, we return.
return
case <-allQuit:
// After we're told to quit, we return.
return
case <-startQuery:
// We're the lucky peer whose turn it is
// to try to answer the current query.
// TODO: Fix this to support either
// querying *all* peers simultaneously
// to avoid timeout delays, or starting
// with the syncPeer when not querying
// *all* peers.
sp.subscribeRecvMsg(subscription)
// Don't want the peer hanging on send
// to the channel if we quit before
// reading the channel.
sentChan := make(chan struct{}, 1)
sp.QueueMessage(queryMsg, sentChan)
select {
case <-sentChan:
case <-quit:
return
case <-allQuit:
return
}
timeout = time.After(qo.timeout)
default:
}
}
}(sp)
}
startQuery <- struct{}{}
// This goroutine will wait until all of the peer-query goroutines have
// terminated, and then initiate a query shutdown.
go func() {
wg.Wait()
// If we timed out on each goroutine and didn't quit or time out
// on the main goroutine, make sure our main goroutine knows to
// quit.
select {
case <-allQuit:
default:
close(allQuit)
}
// Close the done channel, if any
if qo.doneChan != nil {
close(qo.doneChan)
}
// Wait until all goroutines started by subscriptions have
// exited after we closed allQuit before letting the message
// channel get garbage collected.
subwg.Wait()
}()
// Loop for any messages sent to us via our subscription channel and
// check them for whether they satisfy the query. Break the loop if it's
// time to quit.
timeout := time.After(time.Duration(len(peers)+1) *
qo.timeout * time.Duration(qo.numRetries))
checkResponses:
for {
select {
case <-timeout:
// When we time out, close the allQuit channel
// if it hasn't already been closed.
select {
case <-allQuit:
default:
close(allQuit)
}
break checkResponses
case <-quit:
break checkResponses
case <-allQuit:
break checkResponses
case sm := <-msgChan:
// TODO: This will get stuck if checkResponse
// gets stuck. This is a caveat for callers that
// should be fixed before exposing this function
// for public use.
checkResponse(sm.sp, sm.msg, quit)
}
}
}
// GetCFilter gets a cfilter from the database. Failing that, it requests the
// cfilter from the network and writes it to the database.
func (s *ChainService) GetCFilter(blockHash chainhash.Hash,
extended bool, options ...QueryOption) *gcs.Filter {
getFilter := s.GetBasicFilter
getHeader := s.GetBasicHeader
putFilter := s.putBasicFilter
if extended {
getFilter = s.GetExtFilter
getHeader = s.GetExtHeader
putFilter = s.putExtFilter
}
filter, err := getFilter(blockHash)
if err == nil && filter != nil {
return filter
}
block, _, err := s.GetBlockByHash(blockHash)
if err != nil || block.BlockHash() != blockHash {
return nil
}
curHeader, err := getHeader(blockHash)
if err != nil {
return nil
}
prevHeader, err := getHeader(block.PrevBlock)
if err != nil {
return nil
}
s.queryPeers(
// Send a wire.GetCFilterMsg
wire.NewMsgGetCFilter(&blockHash, extended),
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "cfilter" messages.
case *wire.MsgCFilter:
if len(response.Data) < 4 {
// Filter data is too short.
// Ignore this message.
return
}
if blockHash != response.BlockHash {
// The response doesn't match our
// request. Ignore this message.
return
}
gotFilter, err :=
gcs.FromNBytes(builder.DefaultP,
response.Data)
if err != nil {
// Malformed filter data. We
// can ignore this message.
return
}
if builder.MakeHeaderForFilter(gotFilter,
*prevHeader) !=
*curHeader {
// Filter data doesn't match
// the headers we know about.
// Ignore this response.
return
}
// At this point, the filter matches
// what we know about it and we declare
// it sane. We can kill the query and
// pass the response back to the caller.
close(quit)
filter = gotFilter
default:
}
},
options...,
)
// If we've found a filter, write it to the database for next time.
if filter != nil {
putFilter(blockHash, filter)
log.Tracef("Wrote filter for block %s, extended: %t",
blockHash, extended)
}
return filter
}
// GetBlockFromNetwork gets a block by requesting it from the network, one peer
// at a time, until one answers.
func (s *ChainService) GetBlockFromNetwork(
blockHash chainhash.Hash, options ...QueryOption) *btcutil.Block {
blockHeader, height, err := s.GetBlockByHash(blockHash)
if err != nil || blockHeader.BlockHash() != blockHash {
return nil
}
getData := wire.NewMsgGetData()
getData.AddInvVect(wire.NewInvVect(wire.InvTypeBlock,
&blockHash))
// The block is only updated from the checkResponse function argument,
// which is always called single-threadedly. We don't check the block
// until after the query is finished, so we can just write to it
// naively.
var foundBlock *btcutil.Block
s.queryPeers(
// Send a wire.GetCFilterMsg
getData,
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "block" messages.
case *wire.MsgBlock:
// If this isn't our block, ignore it.
if response.BlockHash() !=
blockHash {
return
}
block := btcutil.NewBlock(response)
// Only set height if btcutil hasn't
// automagically put one in.
if block.Height() ==
btcutil.BlockHeightUnknown {
block.SetHeight(
int32(height))
}
// If this claims our block but doesn't
// pass the sanity check, the peer is
// trying to bamboozle us. Disconnect
// it.
if err := blockchain.CheckBlockSanity(
block,
// We don't need to check PoW
// because by the time we get
// here, it's been checked
// during header synchronization
s.chainParams.PowLimit,
s.timeSource,
); err != nil {
log.Warnf("Invalid block for %s "+
"received from %s -- "+
"disconnecting peer", blockHash,
sp.Addr())
sp.Disconnect()
return
}
// At this point, the block matches what we know
// about it and we declare it sane. We can kill
// the query and pass the response back to the
// caller.
close(quit)
foundBlock = block
default:
}
},
options...,
)
return foundBlock
}

View file

@ -25,7 +25,7 @@ import (
) )
var ( var (
logLevel = btclog.TraceLvl logLevel = btclog.Off
syncTimeout = 30 * time.Second syncTimeout = 30 * time.Second
syncUpdate = time.Second syncUpdate = time.Second
// Don't set this too high for your platform, or the tests will miss // Don't set this too high for your platform, or the tests will miss
@ -145,13 +145,12 @@ func TestSetup(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Error opening DB: %s\n", err) t.Fatalf("Error opening DB: %s\n", err)
} }
ns, err := db.Namespace([]byte("weks"))
if err != nil { if err != nil {
t.Fatalf("Error geting namespace: %s\n", err) t.Fatalf("Error geting namespace: %s\n", err)
} }
config := spvchain.Config{ config := spvchain.Config{
DataDir: tempDir, DataDir: tempDir,
Namespace: ns, Database: db,
ChainParams: modParams, ChainParams: modParams,
AddPeers: []string{ AddPeers: []string{
h3.P2PAddress(), h3.P2PAddress(),
@ -327,7 +326,7 @@ func waitForSync(t *testing.T, svc *spvchain.ChainService,
// database to see if we've missed anything or messed anything // database to see if we've missed anything or messed anything
// up. // up.
for i := int32(0); i <= haveBest.Height; i++ { for i := int32(0); i <= haveBest.Height; i++ {
head, _, err := svc.GetBlockByHeight(uint32(i)) head, err := svc.GetBlockByHeight(uint32(i))
if err != nil { if err != nil {
return fmt.Errorf("Couldn't read block by "+ return fmt.Errorf("Couldn't read block by "+
"height: %s", err) "height: %s", err)
@ -413,20 +412,12 @@ func testRandomBlocks(t *testing.T, svc *spvchain.ChainService,
}() }()
defer wg.Done() defer wg.Done()
// Get block header from database. // Get block header from database.
blockHeader, blockHeight, err := svc.GetBlockByHeight( blockHeader, err := svc.GetBlockByHeight(height)
height)
if err != nil { if err != nil {
errChan <- fmt.Errorf("Couldn't get block "+ errChan <- fmt.Errorf("Couldn't get block "+
"header by height %d: %s", height, err) "header by height %d: %s", height, err)
return return
} }
if blockHeight != height {
errChan <- fmt.Errorf("Block height retrieved "+
"from DB doesn't match expected "+
"height. Want: %d, have: %d", height,
blockHeight)
return
}
blockHash := blockHeader.BlockHash() blockHash := blockHeader.BlockHash()
// Get block via RPC. // Get block via RPC.
wantBlock, err := correctSyncNode.Node.GetBlock( wantBlock, err := correctSyncNode.Node.GetBlock(
@ -455,11 +446,11 @@ func testRandomBlocks(t *testing.T, svc *spvchain.ChainService,
return return
} }
// Check that block height matches what we have. // Check that block height matches what we have.
if int32(blockHeight) != haveBlock.Height() { if int32(height) != haveBlock.Height() {
errChan <- fmt.Errorf("Block height from "+ errChan <- fmt.Errorf("Block height from "+
"network doesn't match expected "+ "network doesn't match expected "+
"height. Want: %s, network: %s", "height. Want: %s, network: %s",
blockHeight, haveBlock.Height()) height, haveBlock.Height())
return return
} }
// Get basic cfilter from network. // Get basic cfilter from network.