Database update and refactor after rebase.

This commit is contained in:
Alex 2017-04-28 11:36:05 -06:00 committed by Olaoluwa Osuntokun
parent e273e178dd
commit 551a03107a
6 changed files with 1145 additions and 1090 deletions

View file

@ -934,7 +934,7 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
// with the rest of the headers in the message as if
// nothing has happened.
b.syncPeer = hmsg.peer
_, err = b.server.rollbackToHeight(backHeight)
_, err = b.server.rollBackToHeight(backHeight)
if err != nil {
log.Criticalf("Rollback failed: %s",
err)
@ -974,7 +974,8 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
}
// Verify the header at the next checkpoint height matches.
if b.nextCheckpoint != nil && node.height == b.nextCheckpoint.Height {
if b.nextCheckpoint != nil &&
node.height == b.nextCheckpoint.Height {
nodeHash := node.header.BlockHash()
if nodeHash.IsEqual(b.nextCheckpoint.Hash) {
receivedCheckpoint = true
@ -988,12 +989,14 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
"disconnecting", node.height,
nodeHash, hmsg.peer.Addr(),
b.nextCheckpoint.Hash)
prevCheckpoint := b.findPreviousHeaderCheckpoint(node.height)
prevCheckpoint :=
b.findPreviousHeaderCheckpoint(
node.height)
log.Infof("Rolling back to previous validated "+
"checkpoint at height %d/hash %s",
prevCheckpoint.Height,
prevCheckpoint.Hash)
_, err := b.server.rollbackToHeight(uint32(
_, err := b.server.rollBackToHeight(uint32(
prevCheckpoint.Height))
if err != nil {
log.Criticalf("Rollback failed: %s",
@ -1371,7 +1374,7 @@ func (b *blockManager) calcNextRequiredDifficulty(newBlockTime time.Time,
// Get the block node at the previous retarget (targetTimespan days
// worth of blocks).
firstNode, _, err := b.server.GetBlockByHeight(
firstNode, err := b.server.GetBlockByHeight(
uint32(lastNode.height + 1 - b.blocksPerRetarget))
if err != nil {
return 0, err
@ -1451,7 +1454,8 @@ func (b *blockManager) findPrevTestNetDifficulty(hList *list.List) (uint32, erro
if el != nil {
iterNode = el.Value.(*headerNode).header
} else {
node, _, err := b.server.GetBlockByHeight(uint32(iterHeight))
node, err := b.server.GetBlockByHeight(
uint32(iterHeight))
if err != nil {
log.Errorf("GetBlockByHeight: %s", err)
return 0, err

File diff suppressed because it is too large Load diff

View file

@ -162,3 +162,112 @@ func (s *ChainService) handleQuery(state *peerState, querymsg interface{}) {
// useful in the future.
}
}
// ConnectedCount returns the number of currently connected peers.
func (s *ChainService) ConnectedCount() int32 {
replyChan := make(chan int32)
s.query <- getConnCountMsg{reply: replyChan}
return <-replyChan
}
// OutboundGroupCount returns the number of peers connected to the given
// outbound group key.
func (s *ChainService) OutboundGroupCount(key string) int {
replyChan := make(chan int)
s.query <- getOutboundGroup{key: key, reply: replyChan}
return <-replyChan
}
// AddedNodeInfo returns an array of btcjson.GetAddedNodeInfoResult structures
// describing the persistent (added) nodes.
func (s *ChainService) AddedNodeInfo() []*serverPeer {
replyChan := make(chan []*serverPeer)
s.query <- getAddedNodesMsg{reply: replyChan}
return <-replyChan
}
// Peers returns an array of all connected peers.
func (s *ChainService) Peers() []*serverPeer {
replyChan := make(chan []*serverPeer)
s.query <- getPeersMsg{reply: replyChan}
return <-replyChan
}
// DisconnectNodeByAddr disconnects a peer by target address. Both outbound and
// inbound nodes will be searched for the target node. An error message will
// be returned if the peer was not found.
func (s *ChainService) DisconnectNodeByAddr(addr string) error {
replyChan := make(chan error)
s.query <- disconnectNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.Addr() == addr },
reply: replyChan,
}
return <-replyChan
}
// DisconnectNodeByID disconnects a peer by target node id. Both outbound and
// inbound nodes will be searched for the target node. An error message will be
// returned if the peer was not found.
func (s *ChainService) DisconnectNodeByID(id int32) error {
replyChan := make(chan error)
s.query <- disconnectNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.ID() == id },
reply: replyChan,
}
return <-replyChan
}
// RemoveNodeByAddr removes a peer from the list of persistent peers if
// present. An error will be returned if the peer was not found.
func (s *ChainService) RemoveNodeByAddr(addr string) error {
replyChan := make(chan error)
s.query <- removeNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.Addr() == addr },
reply: replyChan,
}
return <-replyChan
}
// RemoveNodeByID removes a peer by node ID from the list of persistent peers
// if present. An error will be returned if the peer was not found.
func (s *ChainService) RemoveNodeByID(id int32) error {
replyChan := make(chan error)
s.query <- removeNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.ID() == id },
reply: replyChan,
}
return <-replyChan
}
// ConnectNode adds `addr' as a new outbound peer. If permanent is true then the
// peer will be persistent and reconnect if the connection is lost.
// It is an error to call this with an already existing peer.
func (s *ChainService) ConnectNode(addr string, permanent bool) error {
replyChan := make(chan error)
s.query <- connectNodeMsg{addr: addr, permanent: permanent, reply: replyChan}
return <-replyChan
}
// ForAllPeers runs a closure over all peers (outbound and persistent) to which
// the ChainService is connected. Nothing is returned because the peerState's
// ForAllPeers method doesn't return anything as the closure passed to it
// doesn't return anything.
func (s *ChainService) ForAllPeers(closure func(sp *serverPeer)) {
s.query <- forAllPeersMsg{
closure: closure,
}
}

409
spvsvc/spvchain/query.go Normal file
View file

@ -0,0 +1,409 @@
package spvchain
import (
"sync"
"time"
"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/btcsuite/btcutil/gcs"
"github.com/btcsuite/btcutil/gcs/builder"
)
var (
// QueryTimeout specifies how long to wait for a peer to answer a query.
QueryTimeout = time.Second * 3
// QueryNumRetries specifies how many times to retry sending a query to
// each peer before we've concluded we aren't going to get a valid
// response. This allows to make up for missed messages in some
// instances.
QueryNumRetries = 2
)
// Query options can be modified per-query, unlike global options.
// TODO: Make more query options that override global options.
type queryOptions struct {
// timeout lets the query know how long to wait for a peer to
// answer the query before moving onto the next peer.
timeout time.Duration
// numRetries tells the query how many times to retry asking each peer
// the query.
numRetries uint8
// doneChan lets the query signal the caller when it's done, in case
// it's run in a goroutine.
doneChan chan<- struct{}
}
// QueryOption is a functional option argument to any of the network query
// methods, such as GetBlockFromNetwork and GetCFilter (when that resorts to a
// network query).
type QueryOption func(*queryOptions)
// defaultQueryOptions returns a queryOptions set to package-level defaults.
func defaultQueryOptions() *queryOptions {
return &queryOptions{
timeout: QueryTimeout,
numRetries: uint8(QueryNumRetries),
}
}
// Timeout is a query option that lets the query know how long to wait for
// each peer we ask the query to answer it before moving on.
func Timeout(timeout time.Duration) QueryOption {
return func(qo *queryOptions) {
qo.timeout = timeout
}
}
// NumRetries is a query option that lets the query know the maximum number of
// times each peer should be queried. The default is one.
func NumRetries(numRetries uint8) QueryOption {
return func(qo *queryOptions) {
qo.numRetries = numRetries
}
}
// DoneChan allows the caller to pass a channel that will get closed when the
// query is finished.
func DoneChan(doneChan chan<- struct{}) QueryOption {
return func(qo *queryOptions) {
qo.doneChan = doneChan
}
}
type spMsg struct {
sp *serverPeer
msg wire.Message
}
type spMsgSubscription struct {
msgChan chan<- spMsg
quitChan <-chan struct{}
wg *sync.WaitGroup
}
// queryPeers is a helper function that sends a query to one or more peers and
// waits for an answer. The timeout for queries is set by the QueryTimeout
// package-level variable.
func (s *ChainService) queryPeers(
// queryMsg is the message to send to each peer selected by selectPeer.
queryMsg wire.Message,
// checkResponse is caled for every message within the timeout period.
// The quit channel lets the query know to terminate because the
// required response has been found. This is done by closing the
// channel.
checkResponse func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}),
// options takes functional options for executing the query.
options ...QueryOption,
) {
qo := defaultQueryOptions()
for _, option := range options {
option(qo)
}
// This is done in a single-threaded query because the peerState is held
// in a single thread. This is the only part of the query framework that
// requires access to peerState, so it's done once per query.
peers := s.Peers()
// This will be shared state between the per-peer goroutines.
quit := make(chan struct{})
allQuit := make(chan struct{})
startQuery := make(chan struct{})
var wg sync.WaitGroup
// Increase this number to be able to handle more queries at once as
// each channel gets results for all queries, otherwise messages can
// get mixed and there's a vicious cycle of retries causing a bigger
// message flood, more of which get missed.
msgChan := make(chan spMsg)
var subwg sync.WaitGroup
subscription := spMsgSubscription{
msgChan: msgChan,
quitChan: allQuit,
wg: &subwg,
}
// Start a goroutine for each peer that potentially queries that peer.
for _, sp := range peers {
wg.Add(1)
go func(sp *serverPeer) {
numRetries := qo.numRetries
defer wg.Done()
defer sp.unsubscribeRecvMsgs(subscription)
// Should we do this when the goroutine gets a message
// via startQuery rather than at the launch of the
// goroutine?
if !sp.Connected() {
return
}
timeout := make(<-chan time.Time)
for {
select {
case <-timeout:
// After timeout, we try to notify
// another of our peer goroutines to
// do a query until we get a signal to
// quit.
select {
case startQuery <- struct{}{}:
case <-quit:
return
case <-allQuit:
return
}
// At this point, we've sent startQuery.
// We return if we've run through this
// section of code numRetries times.
if numRetries--; numRetries == 0 {
return
}
case <-quit:
// After we're told to quit, we return.
return
case <-allQuit:
// After we're told to quit, we return.
return
case <-startQuery:
// We're the lucky peer whose turn it is
// to try to answer the current query.
// TODO: Fix this to support either
// querying *all* peers simultaneously
// to avoid timeout delays, or starting
// with the syncPeer when not querying
// *all* peers.
sp.subscribeRecvMsg(subscription)
// Don't want the peer hanging on send
// to the channel if we quit before
// reading the channel.
sentChan := make(chan struct{}, 1)
sp.QueueMessage(queryMsg, sentChan)
select {
case <-sentChan:
case <-quit:
return
case <-allQuit:
return
}
timeout = time.After(qo.timeout)
default:
}
}
}(sp)
}
startQuery <- struct{}{}
// This goroutine will wait until all of the peer-query goroutines have
// terminated, and then initiate a query shutdown.
go func() {
wg.Wait()
// If we timed out on each goroutine and didn't quit or time out
// on the main goroutine, make sure our main goroutine knows to
// quit.
select {
case <-allQuit:
default:
close(allQuit)
}
// Close the done channel, if any
if qo.doneChan != nil {
close(qo.doneChan)
}
// Wait until all goroutines started by subscriptions have
// exited after we closed allQuit before letting the message
// channel get garbage collected.
subwg.Wait()
}()
// Loop for any messages sent to us via our subscription channel and
// check them for whether they satisfy the query. Break the loop if it's
// time to quit.
timeout := time.After(time.Duration(len(peers)+1) *
qo.timeout * time.Duration(qo.numRetries))
checkResponses:
for {
select {
case <-timeout:
// When we time out, close the allQuit channel
// if it hasn't already been closed.
select {
case <-allQuit:
default:
close(allQuit)
}
break checkResponses
case <-quit:
break checkResponses
case <-allQuit:
break checkResponses
case sm := <-msgChan:
// TODO: This will get stuck if checkResponse
// gets stuck. This is a caveat for callers that
// should be fixed before exposing this function
// for public use.
checkResponse(sm.sp, sm.msg, quit)
}
}
}
// GetCFilter gets a cfilter from the database. Failing that, it requests the
// cfilter from the network and writes it to the database.
func (s *ChainService) GetCFilter(blockHash chainhash.Hash,
extended bool, options ...QueryOption) *gcs.Filter {
getFilter := s.GetBasicFilter
getHeader := s.GetBasicHeader
putFilter := s.putBasicFilter
if extended {
getFilter = s.GetExtFilter
getHeader = s.GetExtHeader
putFilter = s.putExtFilter
}
filter, err := getFilter(blockHash)
if err == nil && filter != nil {
return filter
}
block, _, err := s.GetBlockByHash(blockHash)
if err != nil || block.BlockHash() != blockHash {
return nil
}
curHeader, err := getHeader(blockHash)
if err != nil {
return nil
}
prevHeader, err := getHeader(block.PrevBlock)
if err != nil {
return nil
}
s.queryPeers(
// Send a wire.GetCFilterMsg
wire.NewMsgGetCFilter(&blockHash, extended),
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "cfilter" messages.
case *wire.MsgCFilter:
if len(response.Data) < 4 {
// Filter data is too short.
// Ignore this message.
return
}
if blockHash != response.BlockHash {
// The response doesn't match our
// request. Ignore this message.
return
}
gotFilter, err :=
gcs.FromNBytes(builder.DefaultP,
response.Data)
if err != nil {
// Malformed filter data. We
// can ignore this message.
return
}
if builder.MakeHeaderForFilter(gotFilter,
*prevHeader) !=
*curHeader {
// Filter data doesn't match
// the headers we know about.
// Ignore this response.
return
}
// At this point, the filter matches
// what we know about it and we declare
// it sane. We can kill the query and
// pass the response back to the caller.
close(quit)
filter = gotFilter
default:
}
},
options...,
)
// If we've found a filter, write it to the database for next time.
if filter != nil {
putFilter(blockHash, filter)
log.Tracef("Wrote filter for block %s, extended: %t",
blockHash, extended)
}
return filter
}
// GetBlockFromNetwork gets a block by requesting it from the network, one peer
// at a time, until one answers.
func (s *ChainService) GetBlockFromNetwork(
blockHash chainhash.Hash, options ...QueryOption) *btcutil.Block {
blockHeader, height, err := s.GetBlockByHash(blockHash)
if err != nil || blockHeader.BlockHash() != blockHash {
return nil
}
getData := wire.NewMsgGetData()
getData.AddInvVect(wire.NewInvVect(wire.InvTypeBlock,
&blockHash))
// The block is only updated from the checkResponse function argument,
// which is always called single-threadedly. We don't check the block
// until after the query is finished, so we can just write to it
// naively.
var foundBlock *btcutil.Block
s.queryPeers(
// Send a wire.GetCFilterMsg
getData,
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "block" messages.
case *wire.MsgBlock:
// If this isn't our block, ignore it.
if response.BlockHash() !=
blockHash {
return
}
block := btcutil.NewBlock(response)
// Only set height if btcutil hasn't
// automagically put one in.
if block.Height() ==
btcutil.BlockHeightUnknown {
block.SetHeight(
int32(height))
}
// If this claims our block but doesn't
// pass the sanity check, the peer is
// trying to bamboozle us. Disconnect
// it.
if err := blockchain.CheckBlockSanity(
block,
// We don't need to check PoW
// because by the time we get
// here, it's been checked
// during header synchronization
s.chainParams.PowLimit,
s.timeSource,
); err != nil {
log.Warnf("Invalid block for %s "+
"received from %s -- "+
"disconnecting peer", blockHash,
sp.Addr())
sp.Disconnect()
return
}
// At this point, the block matches what we know
// about it and we declare it sane. We can kill
// the query and pass the response back to the
// caller.
close(quit)
foundBlock = block
default:
}
},
options...,
)
return foundBlock
}

View file

@ -17,9 +17,6 @@ import (
"github.com/btcsuite/btcd/peer"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/btcsuite/btcutil/gcs"
"github.com/btcsuite/btcutil/gcs/builder"
"github.com/btcsuite/btcwallet/waddrmgr"
"github.com/btcsuite/btcwallet/wallet"
"github.com/btcsuite/btcwallet/walletdb"
)
@ -63,15 +60,6 @@ var (
// DisableDNSSeed disables getting initial addresses for Bitcoin nodes
// from DNS.
DisableDNSSeed = false
// QueryTimeout specifies how long to wait for a peer to answer a query.
QueryTimeout = time.Second * 3
// QueryNumRetries specifies how many times to retry sending a query to
// each peer before we've concluded we aren't going to get a valid
// response. This allows to make up for missed messages in some
// instances.
QueryNumRetries = 2
)
// updatePeerHeightsMsg is a message sent from the blockmanager to the server
@ -538,7 +526,7 @@ type ChainService struct {
started int32
shutdown int32
namespace walletdb.Namespace
db walletdb.DB
chainParams chaincfg.Params
addrManager *addrmgr.AddrManager
connManager *connmgr.ConnManager
@ -562,39 +550,6 @@ func (s *ChainService) BanPeer(sp *serverPeer) {
s.banPeers <- sp
}
// BestSnapshot returns the best block hash and height known to the database.
func (s *ChainService) BestSnapshot() (*waddrmgr.BlockStamp, error) {
var best *waddrmgr.BlockStamp
var err error
err = s.namespace.View(func(tx walletdb.Tx) error {
best, err = syncedTo(tx)
return err
})
if err != nil {
return nil, err
}
return best, nil
}
// LatestBlockLocator returns the block locator for the latest known block
// stored in the database.
func (s *ChainService) LatestBlockLocator() (blockchain.BlockLocator, error) {
var locator blockchain.BlockLocator
var err error
err = s.namespace.View(func(tx walletdb.Tx) error {
best, err := syncedTo(tx)
if err != nil {
return err
}
locator = blockLocatorFromHash(tx, best.Hash)
return nil
})
if err != nil {
return nil, err
}
return locator, nil
}
// AddPeer adds a new peer that has already been connected to the server.
func (s *ChainService) AddPeer(sp *serverPeer) {
s.newPeers <- sp
@ -735,7 +690,8 @@ cleanup:
// Config is a struct detailing the configuration of the chain service.
type Config struct {
DataDir string
Namespace walletdb.Namespace
Database walletdb.DB
Namespace []byte
ChainParams chaincfg.Params
ConnectPeers []string
AddPeers []string
@ -756,14 +712,14 @@ func NewChainService(cfg Config) (*ChainService, error) {
query: make(chan interface{}),
quit: make(chan struct{}),
peerHeightsUpdate: make(chan updatePeerHeightsMsg),
namespace: cfg.Namespace,
db: cfg.Database,
timeSource: blockchain.NewMedianTime(),
services: Services,
userAgentName: UserAgentName,
userAgentVersion: UserAgentVersion,
}
err := createSPVNS(s.namespace, &s.chainParams)
err := s.createSPVNS()
if err != nil {
return nil, err
}
@ -1168,115 +1124,6 @@ func (s *ChainService) peerDoneHandler(sp *serverPeer) {
close(sp.quit)
}
// ConnectedCount returns the number of currently connected peers.
func (s *ChainService) ConnectedCount() int32 {
replyChan := make(chan int32)
s.query <- getConnCountMsg{reply: replyChan}
return <-replyChan
}
// OutboundGroupCount returns the number of peers connected to the given
// outbound group key.
func (s *ChainService) OutboundGroupCount(key string) int {
replyChan := make(chan int)
s.query <- getOutboundGroup{key: key, reply: replyChan}
return <-replyChan
}
// AddedNodeInfo returns an array of btcjson.GetAddedNodeInfoResult structures
// describing the persistent (added) nodes.
func (s *ChainService) AddedNodeInfo() []*serverPeer {
replyChan := make(chan []*serverPeer)
s.query <- getAddedNodesMsg{reply: replyChan}
return <-replyChan
}
// Peers returns an array of all connected peers.
func (s *ChainService) Peers() []*serverPeer {
replyChan := make(chan []*serverPeer)
s.query <- getPeersMsg{reply: replyChan}
return <-replyChan
}
// DisconnectNodeByAddr disconnects a peer by target address. Both outbound and
// inbound nodes will be searched for the target node. An error message will
// be returned if the peer was not found.
func (s *ChainService) DisconnectNodeByAddr(addr string) error {
replyChan := make(chan error)
s.query <- disconnectNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.Addr() == addr },
reply: replyChan,
}
return <-replyChan
}
// DisconnectNodeByID disconnects a peer by target node id. Both outbound and
// inbound nodes will be searched for the target node. An error message will be
// returned if the peer was not found.
func (s *ChainService) DisconnectNodeByID(id int32) error {
replyChan := make(chan error)
s.query <- disconnectNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.ID() == id },
reply: replyChan,
}
return <-replyChan
}
// RemoveNodeByAddr removes a peer from the list of persistent peers if
// present. An error will be returned if the peer was not found.
func (s *ChainService) RemoveNodeByAddr(addr string) error {
replyChan := make(chan error)
s.query <- removeNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.Addr() == addr },
reply: replyChan,
}
return <-replyChan
}
// RemoveNodeByID removes a peer by node ID from the list of persistent peers
// if present. An error will be returned if the peer was not found.
func (s *ChainService) RemoveNodeByID(id int32) error {
replyChan := make(chan error)
s.query <- removeNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.ID() == id },
reply: replyChan,
}
return <-replyChan
}
// ConnectNode adds `addr' as a new outbound peer. If permanent is true then the
// peer will be persistent and reconnect if the connection is lost.
// It is an error to call this with an already existing peer.
func (s *ChainService) ConnectNode(addr string, permanent bool) error {
replyChan := make(chan error)
s.query <- connectNodeMsg{addr: addr, permanent: permanent, reply: replyChan}
return <-replyChan
}
// ForAllPeers runs a closure over all peers (outbound and persistent) to which
// the ChainService is connected. Nothing is returned because the peerState's
// ForAllPeers method doesn't return anything as the closure passed to it
// doesn't return anything.
func (s *ChainService) ForAllPeers(closure func(sp *serverPeer)) {
s.query <- forAllPeersMsg{
closure: closure,
}
}
// UpdatePeerHeights updates the heights of all peers who have have announced
// the latest connected main chain block, or a recognized orphan. These height
// updates allow us to dynamically refresh peer heights, ensuring sync peer
@ -1376,562 +1223,8 @@ func (s *ChainService) Stop() error {
return nil
}
// GetBlockByHeight gets block information from the ChainService database by
// its height.
func (s *ChainService) GetBlockByHeight(height uint32) (wire.BlockHeader,
uint32, error) {
var bh wire.BlockHeader
var h uint32
var err error
err = s.namespace.View(func(dbTx walletdb.Tx) error {
bh, h, err = getBlockByHeight(dbTx, height)
return err
})
return bh, h, err
}
// GetBlockByHash gets block information from the ChainService database by its
// hash.
func (s *ChainService) GetBlockByHash(hash chainhash.Hash) (wire.BlockHeader,
uint32, error) {
var bh wire.BlockHeader
var h uint32
var err error
err = s.namespace.View(func(dbTx walletdb.Tx) error {
bh, h, err = getBlockByHash(dbTx, hash)
return err
})
return bh, h, err
}
// LatestBlock gets the latest block's information from the ChainService
// database.
func (s *ChainService) LatestBlock() (wire.BlockHeader, uint32, error) {
var bh wire.BlockHeader
var h uint32
var err error
err = s.namespace.View(func(dbTx walletdb.Tx) error {
bh, h, err = latestBlock(dbTx)
return err
})
return bh, h, err
}
// putBlock puts a verified block header and height in the ChainService
// database.
func (s *ChainService) putBlock(header wire.BlockHeader, height uint32) error {
return s.namespace.Update(func(dbTx walletdb.Tx) error {
return putBlock(dbTx, header, height)
})
}
// putBasicHeader puts a verified basic filter header in the ChainService
// database.
func (s *ChainService) putBasicHeader(blockHash chainhash.Hash,
filterTip chainhash.Hash) error {
return s.namespace.Update(func(dbTx walletdb.Tx) error {
return putBasicHeader(dbTx, blockHash, filterTip)
})
}
// putExtHeader puts a verified extended filter header in the ChainService
// database.
func (s *ChainService) putExtHeader(blockHash chainhash.Hash,
filterTip chainhash.Hash) error {
return s.namespace.Update(func(dbTx walletdb.Tx) error {
return putExtHeader(dbTx, blockHash, filterTip)
})
}
// GetBasicHeader gets a verified basic filter header from the ChainService
// database.
func (s *ChainService) GetBasicHeader(blockHash chainhash.Hash) (*chainhash.Hash,
error) {
var filterTip *chainhash.Hash
var err error
err = s.namespace.View(func(dbTx walletdb.Tx) error {
filterTip, err = getBasicHeader(dbTx, blockHash)
return err
})
return filterTip, err
}
// GetExtHeader gets a verified extended filter header from the ChainService
// database.
func (s *ChainService) GetExtHeader(blockHash chainhash.Hash) (*chainhash.Hash,
error) {
var filterTip *chainhash.Hash
var err error
err = s.namespace.View(func(dbTx walletdb.Tx) error {
filterTip, err = getExtHeader(dbTx, blockHash)
return err
})
return filterTip, err
}
// putBasicFilter puts a verified basic filter in the ChainService database.
func (s *ChainService) putBasicFilter(blockHash chainhash.Hash,
filter *gcs.Filter) error {
return s.namespace.Update(func(dbTx walletdb.Tx) error {
return putBasicFilter(dbTx, blockHash, filter)
})
}
// putExtFilter puts a verified extended filter in the ChainService database.
func (s *ChainService) putExtFilter(blockHash chainhash.Hash,
filter *gcs.Filter) error {
return s.namespace.Update(func(dbTx walletdb.Tx) error {
return putExtFilter(dbTx, blockHash, filter)
})
}
// GetBasicFilter gets a verified basic filter from the ChainService database.
func (s *ChainService) GetBasicFilter(blockHash chainhash.Hash) (*gcs.Filter,
error) {
var filter *gcs.Filter
var err error
err = s.namespace.View(func(dbTx walletdb.Tx) error {
filter, err = getBasicFilter(dbTx, blockHash)
return err
})
return filter, err
}
// GetExtFilter gets a verified extended filter from the ChainService database.
func (s *ChainService) GetExtFilter(blockHash chainhash.Hash) (*gcs.Filter,
error) {
var filter *gcs.Filter
var err error
err = s.namespace.View(func(dbTx walletdb.Tx) error {
filter, err = getExtFilter(dbTx, blockHash)
return err
})
return filter, err
}
// putMaxBlockHeight puts the max block height to the ChainService database.
func (s *ChainService) putMaxBlockHeight(maxBlockHeight uint32) error {
return s.namespace.Update(func(dbTx walletdb.Tx) error {
return putMaxBlockHeight(dbTx, maxBlockHeight)
})
}
func (s *ChainService) rollbackLastBlock() (*waddrmgr.BlockStamp, error) {
var bs *waddrmgr.BlockStamp
var err error
err = s.namespace.Update(func(dbTx walletdb.Tx) error {
bs, err = rollbackLastBlock(dbTx)
return err
})
return bs, err
}
func (s *ChainService) rollbackToHeight(height uint32) (*waddrmgr.BlockStamp, error) {
var bs *waddrmgr.BlockStamp
var err error
err = s.namespace.Update(func(dbTx walletdb.Tx) error {
bs, err = syncedTo(dbTx)
if err != nil {
return err
}
for uint32(bs.Height) > height {
bs, err = rollbackLastBlock(dbTx)
if err != nil {
return err
}
}
return nil
})
return bs, err
}
// IsCurrent lets the caller know whether the chain service's block manager
// thinks its view of the network is current.
func (s *ChainService) IsCurrent() bool {
return s.blockManager.IsCurrent()
}
// Query options can be modified per-query, unlike global options.
// TODO: Make more query options that override global options.
type queryOptions struct {
// timeout lets the query know how long to wait for a peer to
// answer the query before moving onto the next peer.
timeout time.Duration
// numRetries tells the query how many times to retry asking each peer
// the query.
numRetries uint8
// doneChan lets the query signal the caller when it's done, in case
// it's run in a goroutine.
doneChan chan<- struct{}
}
// QueryOption is a functional option argument to any of the network query
// methods, such as GetBlockFromNetwork and GetCFilter (when that resorts to a
// network query).
type QueryOption func(*queryOptions)
// defaultQueryOptions returns a queryOptions set to package-level defaults.
func defaultQueryOptions() *queryOptions {
return &queryOptions{
timeout: QueryTimeout,
numRetries: uint8(QueryNumRetries),
}
}
// Timeout is a query option that lets the query know how long to wait for
// each peer we ask the query to answer it before moving on.
func Timeout(timeout time.Duration) QueryOption {
return func(qo *queryOptions) {
qo.timeout = timeout
}
}
// NumRetries is a query option that lets the query know the maximum number of
// times each peer should be queried. The default is one.
func NumRetries(numRetries uint8) QueryOption {
return func(qo *queryOptions) {
qo.numRetries = numRetries
}
}
// DoneChan allows the caller to pass a channel that will get closed when the
// query is finished.
func DoneChan(doneChan chan<- struct{}) QueryOption {
return func(qo *queryOptions) {
qo.doneChan = doneChan
}
}
type spMsg struct {
sp *serverPeer
msg wire.Message
}
type spMsgSubscription struct {
msgChan chan<- spMsg
quitChan <-chan struct{}
wg *sync.WaitGroup
}
// queryPeers is a helper function that sends a query to one or more peers and
// waits for an answer. The timeout for queries is set by the QueryTimeout
// package-level variable.
func (s *ChainService) queryPeers(
// queryMsg is the message to send to each peer selected by selectPeer.
queryMsg wire.Message,
// checkResponse is caled for every message within the timeout period.
// The quit channel lets the query know to terminate because the
// required response has been found. This is done by closing the
// channel.
checkResponse func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}),
// options takes functional options for executing the query.
options ...QueryOption,
) {
qo := defaultQueryOptions()
for _, option := range options {
option(qo)
}
// This is done in a single-threaded query because the peerState is held
// in a single thread. This is the only part of the query framework that
// requires access to peerState, so it's done once per query.
peers := s.Peers()
// This will be shared state between the per-peer goroutines.
quit := make(chan struct{})
allQuit := make(chan struct{})
startQuery := make(chan struct{})
var wg sync.WaitGroup
// Increase this number to be able to handle more queries at once as
// each channel gets results for all queries, otherwise messages can
// get mixed and there's a vicious cycle of retries causing a bigger
// message flood, more of which get missed.
msgChan := make(chan spMsg)
var subwg sync.WaitGroup
subscription := spMsgSubscription{
msgChan: msgChan,
quitChan: allQuit,
wg: &subwg,
}
// Start a goroutine for each peer that potentially queries that peer.
for _, sp := range peers {
wg.Add(1)
go func(sp *serverPeer) {
numRetries := qo.numRetries
defer wg.Done()
defer sp.unsubscribeRecvMsgs(subscription)
// Should we do this when the goroutine gets a message
// via startQuery rather than at the launch of the
// goroutine?
if !sp.Connected() {
return
}
timeout := make(<-chan time.Time)
for {
select {
case <-timeout:
// After timeout, we try to notify
// another of our peer goroutines to
// do a query until we get a signal to
// quit.
select {
case startQuery <- struct{}{}:
case <-quit:
return
case <-allQuit:
return
}
// At this point, we've sent startQuery.
// We return if we've run through this
// section of code numRetries times.
if numRetries--; numRetries == 0 {
return
}
case <-quit:
// After we're told to quit, we return.
return
case <-allQuit:
// After we're told to quit, we return.
return
case <-startQuery:
// We're the lucky peer whose turn it is
// to try to answer the current query.
// TODO: Fix this to support either
// querying *all* peers simultaneously
// to avoid timeout delays, or starting
// with the syncPeer when not querying
// *all* peers.
sp.subscribeRecvMsg(subscription)
// Don't want the peer hanging on send
// to the channel if we quit before
// reading the channel.
sentChan := make(chan struct{}, 1)
sp.QueueMessage(queryMsg, sentChan)
select {
case <-sentChan:
case <-quit:
return
case <-allQuit:
return
}
timeout = time.After(qo.timeout)
default:
}
}
}(sp)
}
startQuery <- struct{}{}
// This goroutine will wait until all of the peer-query goroutines have
// terminated, and then initiate a query shutdown.
go func() {
wg.Wait()
// If we timed out on each goroutine and didn't quit or time out
// on the main goroutine, make sure our main goroutine knows to
// quit.
select {
case <-allQuit:
default:
close(allQuit)
}
// Close the done channel, if any
if qo.doneChan != nil {
close(qo.doneChan)
}
// Wait until all goroutines started by subscriptions have
// exited after we closed allQuit before letting the message
// channel get garbage collected.
subwg.Wait()
}()
// Loop for any messages sent to us via our subscription channel and
// check them for whether they satisfy the query. Break the loop if it's
// time to quit.
timeout := time.After(time.Duration(len(peers)+1) *
qo.timeout * time.Duration(qo.numRetries))
checkResponses:
for {
select {
case <-timeout:
// When we time out, close the allQuit channel
// if it hasn't already been closed.
select {
case <-allQuit:
default:
close(allQuit)
}
break checkResponses
case <-quit:
break checkResponses
case <-allQuit:
break checkResponses
case sm := <-msgChan:
// TODO: This will get stuck if checkResponse
// gets stuck. This is a caveat for callers that
// should be fixed before exposing this function
// for public use.
checkResponse(sm.sp, sm.msg, quit)
}
}
}
// GetCFilter gets a cfilter from the database. Failing that, it requests the
// cfilter from the network and writes it to the database.
func (s *ChainService) GetCFilter(blockHash chainhash.Hash,
extended bool, options ...QueryOption) *gcs.Filter {
getFilter := s.GetBasicFilter
getHeader := s.GetBasicHeader
putFilter := s.putBasicFilter
if extended {
getFilter = s.GetExtFilter
getHeader = s.GetExtHeader
putFilter = s.putExtFilter
}
filter, err := getFilter(blockHash)
if err == nil && filter != nil {
return filter
}
block, _, err := s.GetBlockByHash(blockHash)
if err != nil || block.BlockHash() != blockHash {
return nil
}
curHeader, err := getHeader(blockHash)
if err != nil {
return nil
}
prevHeader, err := getHeader(block.PrevBlock)
if err != nil {
return nil
}
s.queryPeers(
// Send a wire.GetCFilterMsg
wire.NewMsgGetCFilter(&blockHash, extended),
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "cfilter" messages.
case *wire.MsgCFilter:
if len(response.Data) < 4 {
// Filter data is too short.
// Ignore this message.
return
}
if blockHash != response.BlockHash {
// The response doesn't match our
// request. Ignore this message.
return
}
gotFilter, err :=
gcs.FromNBytes(builder.DefaultP,
response.Data)
if err != nil {
// Malformed filter data. We
// can ignore this message.
return
}
if builder.MakeHeaderForFilter(gotFilter,
*prevHeader) !=
*curHeader {
// Filter data doesn't match
// the headers we know about.
// Ignore this response.
return
}
// At this point, the filter matches
// what we know about it and we declare
// it sane. We can kill the query and
// pass the response back to the caller.
close(quit)
filter = gotFilter
default:
}
},
options...,
)
// If we've found a filter, write it to the database for next time.
if filter != nil {
putFilter(blockHash, filter)
log.Tracef("Wrote filter for block %s, extended: %t",
blockHash, extended)
}
return filter
}
// GetBlockFromNetwork gets a block by requesting it from the network, one peer
// at a time, until one answers.
func (s *ChainService) GetBlockFromNetwork(
blockHash chainhash.Hash, options ...QueryOption) *btcutil.Block {
blockHeader, height, err := s.GetBlockByHash(blockHash)
if err != nil || blockHeader.BlockHash() != blockHash {
return nil
}
getData := wire.NewMsgGetData()
getData.AddInvVect(wire.NewInvVect(wire.InvTypeBlock,
&blockHash))
// The block is only updated from the checkResponse function argument,
// which is always called single-threadedly. We don't check the block
// until after the query is finished, so we can just write to it
// naively.
var foundBlock *btcutil.Block
s.queryPeers(
// Send a wire.GetCFilterMsg
getData,
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "block" messages.
case *wire.MsgBlock:
// If this isn't our block, ignore it.
if response.BlockHash() !=
blockHash {
return
}
block := btcutil.NewBlock(response)
// Only set height if btcutil hasn't
// automagically put one in.
if block.Height() ==
btcutil.BlockHeightUnknown {
block.SetHeight(
int32(height))
}
// If this claims our block but doesn't
// pass the sanity check, the peer is
// trying to bamboozle us. Disconnect
// it.
if err := blockchain.CheckBlockSanity(
block,
// We don't need to check PoW
// because by the time we get
// here, it's been checked
// during header synchronization
s.chainParams.PowLimit,
s.timeSource,
); err != nil {
log.Warnf("Invalid block for %s "+
"received from %s -- "+
"disconnecting peer", blockHash,
sp.Addr())
sp.Disconnect()
return
}
// At this point, the block matches what we know
// about it and we declare it sane. We can kill
// the query and pass the response back to the
// caller.
close(quit)
foundBlock = block
default:
}
},
options...,
)
return foundBlock
}

View file

@ -25,7 +25,7 @@ import (
)
var (
logLevel = btclog.TraceLvl
logLevel = btclog.Off
syncTimeout = 30 * time.Second
syncUpdate = time.Second
// Don't set this too high for your platform, or the tests will miss
@ -145,13 +145,12 @@ func TestSetup(t *testing.T) {
if err != nil {
t.Fatalf("Error opening DB: %s\n", err)
}
ns, err := db.Namespace([]byte("weks"))
if err != nil {
t.Fatalf("Error geting namespace: %s\n", err)
}
config := spvchain.Config{
DataDir: tempDir,
Namespace: ns,
Database: db,
ChainParams: modParams,
AddPeers: []string{
h3.P2PAddress(),
@ -327,7 +326,7 @@ func waitForSync(t *testing.T, svc *spvchain.ChainService,
// database to see if we've missed anything or messed anything
// up.
for i := int32(0); i <= haveBest.Height; i++ {
head, _, err := svc.GetBlockByHeight(uint32(i))
head, err := svc.GetBlockByHeight(uint32(i))
if err != nil {
return fmt.Errorf("Couldn't read block by "+
"height: %s", err)
@ -413,20 +412,12 @@ func testRandomBlocks(t *testing.T, svc *spvchain.ChainService,
}()
defer wg.Done()
// Get block header from database.
blockHeader, blockHeight, err := svc.GetBlockByHeight(
height)
blockHeader, err := svc.GetBlockByHeight(height)
if err != nil {
errChan <- fmt.Errorf("Couldn't get block "+
"header by height %d: %s", height, err)
return
}
if blockHeight != height {
errChan <- fmt.Errorf("Block height retrieved "+
"from DB doesn't match expected "+
"height. Want: %d, have: %d", height,
blockHeight)
return
}
blockHash := blockHeader.BlockHash()
// Get block via RPC.
wantBlock, err := correctSyncNode.Node.GetBlock(
@ -455,11 +446,11 @@ func testRandomBlocks(t *testing.T, svc *spvchain.ChainService,
return
}
// Check that block height matches what we have.
if int32(blockHeight) != haveBlock.Height() {
if int32(height) != haveBlock.Height() {
errChan <- fmt.Errorf("Block height from "+
"network doesn't match expected "+
"height. Want: %s, network: %s",
blockHeight, haveBlock.Height())
height, haveBlock.Height())
return
}
// Get basic cfilter from network.