Fix concurrency issues.

Alex authored 2017-04-26 16:34:05 -06:00, committed by Olaoluwa Osuntokun
parent c7b26a11e2
commit fe632ff233
4 changed files with 531 additions and 439 deletions

View file

@@ -1031,9 +1031,11 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
// Should probably use better isolation for this but we're in
// the same package. One of the things to clean up when we do
// more general cleanup.
sp.mtxReqCFH.Lock()
sp.requestedCFHeaders[cfhReqB] = cfhCount
sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, false)
sp.requestedCFHeaders[cfhReqE] = cfhCount
sp.mtxReqCFH.Unlock()
sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, false)
sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, true)
})
@@ -1078,13 +1080,20 @@ func (b *blockManager) QueueCFHeaders(cfheaders *wire.MsgCFHeaders,
extended: cfheaders.Extended,
stopHash: cfheaders.StopHash,
}
if sp.requestedCFHeaders[req] != len(cfheaders.HeaderHashes) {
// TODO: Get rid of this by refactoring all of this using the query API
sp.mtxReqCFH.Lock()
expLen := sp.requestedCFHeaders[req]
sp.mtxReqCFH.Unlock()
if expLen != len(cfheaders.HeaderHashes) {
log.Warnf("Received cfheaders message doesn't match any "+
"getcfheaders request. Peer %s is probably on a "+
"different chain -- ignoring", sp.Addr())
return
}
// TODO: Remove this by refactoring this section into a query client.
sp.mtxReqCFH.Lock()
delete(sp.requestedCFHeaders, req)
sp.mtxReqCFH.Unlock()
// Track number of pending cfheaders messages for both basic and
// extended filters.
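The same rule applies everywhere requestedCFHeaders is touched: every read or write happens while holding mtxReqCFH. A small helper could make that discipline harder to forget; a sketch only (addCFHeadersRequest is a hypothetical name, not part of this change):

// addCFHeadersRequest records how many cfheaders we expect back for a
// request, holding the mutex that guards the shared map.
func (sp *serverPeer) addCFHeadersRequest(req cfhRequest, count int) {
	sp.mtxReqCFH.Lock()
	defer sp.mtxReqCFH.Unlock()
	sp.requestedCFHeaders[req] = count
}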

View file

@@ -8,13 +8,7 @@ import (
"errors"
"github.com/btcsuite/btcd/addrmgr"
"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/connmgr"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/btcsuite/btcutil/gcs"
"github.com/btcsuite/btcutil/gcs/builder"
)
type getConnCountMsg struct {
@@ -54,19 +48,6 @@ type forAllPeersMsg struct {
closure func(*serverPeer)
}
type getCFilterMsg struct {
cfRequest
prevHeader *chainhash.Hash
curHeader *chainhash.Hash
reply chan *gcs.Filter
}
type getBlockMsg struct {
blockHeader *wire.BlockHeader
height uint32
reply chan *btcutil.Block
}
// TODO: General - abstract out more of blockmanager into queries. It'll make
// this way more maintainable and usable.
@@ -172,128 +153,12 @@ func (s *ChainService) handleQuery(state *peerState, querymsg interface{}) {
msg.reply <- errors.New("peer not found")
case forAllPeersMsg:
// TODO: Remove this when it's unnecessary due to wider use of
// queryPeers.
// Run the closure on all peers in the passed state.
state.forAllPeers(msg.closure)
// Even though this is a query, there's no reply channel as the
// forAllPeers method doesn't return anything. An error might be
// useful in the future.
case getCFilterMsg:
found := false
state.queryPeers(
// Should we query this peer?
func(sp *serverPeer) bool {
// Don't send requests to disconnected peers.
return sp.Connected()
},
// Send a wire.MsgGetCFilter.
wire.NewMsgGetCFilter(&msg.blockHash, msg.extended),
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "cfilter" messages.
case *wire.MsgCFilter:
if len(response.Data) < 4 {
// Filter data is too short.
// Ignore this message.
return
}
filter, err :=
gcs.FromNBytes(builder.DefaultP,
response.Data)
if err != nil {
// Malformed filter data. We
// can ignore this message.
return
}
if MakeHeaderForFilter(filter,
*msg.prevHeader) !=
*msg.curHeader {
// Filter data doesn't match
// the headers we know about.
// Ignore this response.
return
}
// At this point, the filter matches
// what we know about it and we declare
// it sane. We can kill the query and
// pass the response back to the caller.
found = true
close(quit)
msg.reply <- filter
default:
}
},
)
// We timed out without finding a correct answer to our query.
if !found {
msg.reply <- nil
}
case getBlockMsg:
found := false
getData := wire.NewMsgGetData()
blockHash := msg.blockHeader.BlockHash()
getData.AddInvVect(wire.NewInvVect(wire.InvTypeBlock,
&blockHash))
state.queryPeers(
// Should we query this peer?
func(sp *serverPeer) bool {
// Don't send requests to disconnected peers.
return sp.Connected()
},
// Send a wire.MsgGetData.
getData,
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "block" messages.
case *wire.MsgBlock:
// If this isn't our block, ignore it.
if response.BlockHash() !=
blockHash {
return
}
block := btcutil.NewBlock(response)
// Only set height if btcutil hasn't
// automagically put one in.
if block.Height() ==
btcutil.BlockHeightUnknown {
block.SetHeight(
int32(msg.height))
}
// If this claims our block but doesn't
// pass the sanity check, the peer is
// trying to bamboozle us. Disconnect
// it.
if err := blockchain.CheckBlockSanity(
block,
// We don't need to check PoW
// because by the time we get
// here, it's been checked
// during header synchronization
s.chainParams.PowLimit,
s.timeSource,
); err != nil {
log.Warnf("Invalid block for "+
"%s received from %s "+
"-- disconnecting peer",
blockHash, sp.Addr())
sp.Disconnect()
return
}
found = true
close(quit)
msg.reply <- block
default:
}
},
)
// We timed out without finding a correct answer to our query.
if !found {
msg.reply <- nil
}
}
}

View file

@@ -18,6 +18,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/btcsuite/btcutil/gcs"
"github.com/btcsuite/btcutil/gcs/builder"
"github.com/btcsuite/btcwallet/waddrmgr"
"github.com/btcsuite/btcwallet/wallet"
"github.com/btcsuite/btcwallet/walletdb"
@@ -63,8 +64,14 @@ var (
// from DNS.
DisableDNSSeed = false
// Timeout specifies how long to wait for a peer to answer a query.
Timeout = time.Second * 5
// QueryTimeout specifies how long to wait for a peer to answer a query.
QueryTimeout = time.Second * 3
// QueryNumRetries specifies how many times to retry sending a query to
// each peer before we've concluded we aren't going to get a valid
// response. This allows us to make up for missed messages in some
// instances.
QueryNumRetries = 2
)
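Both of these knobs are plain package-level variables, so a caller can tune them before starting the service. A minimal sketch from the caller's side (the specific values are illustrative, and config is assumed to be a spvchain.Config like the one the test below builds):

// Wait longer per peer and retry each peer more often, e.g. on a slow
// or lossy network, before the ChainService is constructed.
spvchain.QueryTimeout = 10 * time.Second // per-peer wait, up from 3s
spvchain.QueryNumRetries = 3             // per-peer retries, up from 2
svc, err := spvchain.NewChainService(config)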
// updatePeerHeightsMsg is a message sent from the blockmanager to the server
@@ -110,137 +117,6 @@ func (ps *peerState) forAllPeers(closure func(sp *serverPeer)) {
ps.forAllOutboundPeers(closure)
}
// Query options can be modified per-query, unlike global options.
// TODO: Make more query options that override global options.
type queryOptions struct {
// queryTimeout lets the query know how long to wait for a peer to
// answer the query before moving onto the next peer.
queryTimeout time.Duration
}
// defaultQueryOptions returns a queryOptions set to package-level defaults.
func defaultQueryOptions() *queryOptions {
return &queryOptions{
queryTimeout: Timeout,
}
}
// QueryTimeout is a query option that lets the query know how long to wait
// for each peer to answer the query before moving on to the next peer.
func QueryTimeout(timeout time.Duration) func(*queryOptions) {
return func(qo *queryOptions) {
qo.queryTimeout = timeout
}
}
type spMsg struct {
sp *serverPeer
msg wire.Message
}
// queryPeers is a helper function that sends a query to one or more peers and
// waits for an answer. The timeout for queries is set by the QueryTimeout
// package-level variable.
func (ps *peerState) queryPeers(
// selectPeer is a closure which decides whether or not to send the
// query to the peer.
selectPeer func(sp *serverPeer) bool,
// queryMsg is the message to send to each peer selected by selectPeer.
queryMsg wire.Message,
// checkResponse is called for every message within the timeout period.
// The quit channel lets the query know to terminate because the
// required response has been found. This is done by closing the
// channel.
checkResponse func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}),
// options takes functional options for executing the query.
options ...func(*queryOptions),
) {
qo := defaultQueryOptions()
for _, option := range options {
option(qo)
}
// This will be shared state between the per-peer goroutines.
quit := make(chan struct{})
startQuery := make(chan struct{})
var wg sync.WaitGroup
channel := make(chan spMsg)
// This goroutine will monitor all messages from all peers until the
// peer goroutines all exit.
go func() {
for {
select {
case <-quit:
close(channel)
ps.forAllPeers(
func(sp *serverPeer) {
sp.unsubscribeRecvMsgs(channel)
})
return
case sm := <-channel:
// TODO: This will get stuck if checkResponse
// gets stuck.
checkResponse(sm.sp, sm.msg, quit)
}
}
}()
// Start a goroutine for each peer that potentially queries each peer
ps.forAllPeers(func(sp *serverPeer) {
wg.Add(1)
go func() {
defer wg.Done()
if !selectPeer(sp) {
return
}
timeout := make(<-chan time.Time)
for {
select {
case <-timeout:
// After timeout, we return and notify
// another goroutine that we've done so.
// We only send if there's someone left
// to receive.
startQuery <- struct{}{}
return
case <-quit:
// After we're told to quit, we return.
return
case <-startQuery:
// We're the lucky peer whose turn it is
// to try to answer the current query.
// TODO: Fix this to support multiple
// queries at once. For now, we're
// relying on the query handling loop
// to make sure we don't interrupt
// another query. We need broadcast
// support in OnRead to do this right.
// TODO: Fix this to support either
// querying *all* peers simultaneously
// to avoid timeout delays, or starting
// with the syncPeer when not querying
// *all* peers.
sp.subscribeRecvMsg(channel)
sp.QueueMessage(queryMsg, nil)
timeout = time.After(qo.queryTimeout)
default:
}
}
}()
})
startQuery <- struct{}{}
wg.Wait()
// If we timed out and didn't quit, make sure our response monitor
// goroutine knows to quit.
select {
case <-quit:
default:
close(quit)
}
}
// cfhRequest records which cfheaders we've requested, and the order in which
// we've requested them. Since there's no way to associate the cfheaders to the
// actual block hashes based on the cfheaders message to keep it compact, we
@@ -250,12 +126,6 @@ type cfhRequest struct {
stopHash chainhash.Hash
}
// cfRequest records which cfilters we've requested.
type cfRequest struct {
extended bool
blockHash chainhash.Hash
}
// serverPeer extends the peer to maintain state shared by the server and
// the blockmanager.
type serverPeer struct {
@@ -264,24 +134,27 @@ type serverPeer struct {
*peer.Peer
connReq *connmgr.ConnReq
server *ChainService
persistent bool
continueHash *chainhash.Hash
relayMtx sync.Mutex
requestQueue []*wire.InvVect
requestedCFHeaders map[cfhRequest]int
knownAddresses map[string]struct{}
banScore connmgr.DynamicBanScore
quit chan struct{}
connReq *connmgr.ConnReq
server *ChainService
persistent bool
continueHash *chainhash.Hash
relayMtx sync.Mutex
requestQueue []*wire.InvVect
knownAddresses map[string]struct{}
banScore connmgr.DynamicBanScore
quit chan struct{}
// The following slice of channels is used to subscribe to messages from
// the peer. This allows broadcast to multiple subscribers at once,
// allowing for multiple queries to be going to multiple peers at any
// one time. The mutex is for subscribe/unsubscribe functionality.
// The sends on these channels WILL NOT block; any messages the channel
// can't accept will be dropped silently.
recvSubscribers []chan<- spMsg
recvSubscribers map[spMsgSubscription]struct{}
mtxSubscribers sync.RWMutex
// These are only necessary until the cfheaders logic is refactored as
// a query client.
requestedCFHeaders map[cfhRequest]int
mtxReqCFH sync.Mutex
}
// newServerPeer returns a new serverPeer instance. The peer needs to be set by
@@ -293,6 +166,7 @@ func newServerPeer(s *ChainService, isPersistent bool) *serverPeer {
requestedCFHeaders: make(map[cfhRequest]int),
knownAddresses: make(map[string]struct{}),
quit: make(chan struct{}),
recvSubscribers: make(map[spMsgSubscription]struct{}),
}
}
@@ -615,42 +489,38 @@ func (sp *serverPeer) OnRead(_ *peer.Peer, bytesRead int, msg wire.Message,
err error) {
sp.server.AddBytesReceived(uint64(bytesRead))
// Try to send a message to the subscriber channel if it isn't nil, but
// don't block on failure.
// don't block on failure. Do this inside a goroutine to keep slow
// subscribers from slowing down the server.
sp.mtxSubscribers.RLock()
defer sp.mtxSubscribers.RUnlock()
for _, channel := range sp.recvSubscribers {
if channel != nil {
for subscription := range sp.recvSubscribers {
subscription.wg.Add(1)
go func(subscription spMsgSubscription) {
defer subscription.wg.Done()
select {
case channel <- spMsg{
sp: sp,
case <-subscription.quitChan:
case subscription.msgChan <- spMsg{
msg: msg,
sp: sp,
}:
default:
}
}
}(subscription)
}
}
// subscribeRecvMsg handles adding OnRead subscriptions to the server peer.
func (sp *serverPeer) subscribeRecvMsg(channel chan<- spMsg) {
func (sp *serverPeer) subscribeRecvMsg(subscription spMsgSubscription) {
sp.mtxSubscribers.Lock()
defer sp.mtxSubscribers.Unlock()
sp.recvSubscribers = append(sp.recvSubscribers, channel)
sp.recvSubscribers[subscription] = struct{}{}
}
// unsubscribeRecvMsgs handles removing OnRead subscriptions from the server
// peer.
func (sp *serverPeer) unsubscribeRecvMsgs(channel chan<- spMsg) {
func (sp *serverPeer) unsubscribeRecvMsgs(subscription spMsgSubscription) {
sp.mtxSubscribers.Lock()
defer sp.mtxSubscribers.Unlock()
var updatedSubscribers []chan<- spMsg
for _, candidate := range sp.recvSubscribers {
if candidate != channel {
updatedSubscribers = append(updatedSubscribers,
candidate)
}
}
sp.recvSubscribers = updatedSubscribers
delete(sp.recvSubscribers, subscription)
}
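Taken together, spMsgSubscription, subscribeRecvMsg, and unsubscribeRecvMsgs give in-package callers a broadcast feed of a peer's incoming messages. A minimal sketch of the intended lifecycle (the pong predicate and the timeout handling are illustrative; queryPeers below is the real consumer of this API):

msgChan := make(chan spMsg)
quitChan := make(chan struct{})
var wg sync.WaitGroup
sub := spMsgSubscription{
	msgChan:  msgChan,
	quitChan: quitChan,
	wg:       &wg,
}
sp.subscribeRecvMsg(sub)
// Read messages until we see the one we care about or give up.
loop:
for {
	select {
	case sm := <-msgChan:
		if _, ok := sm.msg.(*wire.MsgPong); ok {
			break loop
		}
	case <-time.After(QueryTimeout):
		break loop
	}
}
// Signal any in-flight OnRead goroutines to stop trying to send, then
// unsubscribe and wait for them before msgChan goes out of scope.
close(quitChan)
sp.unsubscribeRecvMsgs(sub)
wg.Wait()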
// OnWrite is invoked when a peer sends a message and it is used to update
@@ -683,9 +553,6 @@ type ChainService struct {
timeSource blockchain.MedianTimeSource
services wire.ServiceFlag
cfilterRequests map[cfRequest][]chan *gcs.Filter
cfRequestHeaders map[cfRequest][2]*chainhash.Hash
userAgentName string
userAgentVersion string
}
@@ -894,8 +761,6 @@ func NewChainService(cfg Config) (*ChainService, error) {
services: Services,
userAgentName: UserAgentName,
userAgentVersion: UserAgentVersion,
cfilterRequests: make(map[cfRequest][]chan *gcs.Filter),
cfRequestHeaders: make(map[cfRequest][2]*chainhash.Hash),
}
err := createSPVNS(s.namespace, &s.chainParams)
@@ -1686,10 +1551,238 @@ func (s *ChainService) IsCurrent() bool {
return s.blockManager.IsCurrent()
}
// Query options can be modified per-query, unlike global options.
// TODO: Make more query options that override global options.
type queryOptions struct {
// timeout lets the query know how long to wait for a peer to
// answer the query before moving onto the next peer.
timeout time.Duration
// numRetries tells the query how many times to retry asking each peer
// the query.
numRetries uint8
// doneChan lets the query signal the caller when it's done, in case
// it's run in a goroutine.
doneChan chan<- struct{}
}
// QueryOption is a functional option argument to any of the network query
// methods, such as GetBlockFromNetwork and GetCFilter (when that resorts to a
// network query).
type QueryOption func(*queryOptions)
// defaultQueryOptions returns a queryOptions set to package-level defaults.
func defaultQueryOptions() *queryOptions {
return &queryOptions{
timeout: QueryTimeout,
numRetries: uint8(QueryNumRetries),
}
}
// Timeout is a query option that lets the query know how long to wait for
// each peer we ask to answer the query before moving on.
func Timeout(timeout time.Duration) QueryOption {
return func(qo *queryOptions) {
qo.timeout = timeout
}
}
// NumRetries is a query option that lets the query know the maximum number of
// times each peer should be queried. The default is set by the
// QueryNumRetries package-level variable.
func NumRetries(numRetries uint8) QueryOption {
return func(qo *queryOptions) {
qo.numRetries = numRetries
}
}
// DoneChan allows the caller to pass a channel that will get closed when the
// query is finished.
func DoneChan(doneChan chan<- struct{}) QueryOption {
return func(qo *queryOptions) {
qo.doneChan = doneChan
}
}
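From the caller's side these options are passed straight through the exported query methods. A minimal usage sketch, assuming an already-synced ChainService svc and a block hash of interest (the specific values are illustrative):

// Ask for the block with a longer per-peer timeout and more retries
// than the package-level defaults.
block := svc.GetBlockFromNetwork(blockHash,
	spvchain.Timeout(10*time.Second),
	spvchain.NumRetries(3),
)
if block == nil {
	// No connected peer returned a block that passed the sanity checks.
}

// The same options work for cfilter queries that fall through to the
// network.
filter := svc.GetCFilter(blockHash, false, spvchain.NumRetries(3))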
type spMsg struct {
sp *serverPeer
msg wire.Message
}
type spMsgSubscription struct {
msgChan chan<- spMsg
quitChan <-chan struct{}
wg *sync.WaitGroup
}
// queryPeers is a helper function that sends a query to one or more peers and
// waits for an answer. The timeout for queries is set by the QueryTimeout
// package-level variable.
func (s *ChainService) queryPeers(
// queryMsg is the message to send to each peer selected by selectPeer.
queryMsg wire.Message,
// checkResponse is called for every message within the timeout period.
// The quit channel lets the query know to terminate because the
// required response has been found. This is done by closing the
// channel.
checkResponse func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}),
// options takes functional options for executing the query.
options ...QueryOption,
) {
qo := defaultQueryOptions()
for _, option := range options {
option(qo)
}
// This is done in a single-threaded manner because the peerState is held
// in a single thread. This is the only part of the query framework that
// requires access to peerState, so it's done once per query.
peers := s.Peers()
// This will be shared state between the per-peer goroutines.
quit := make(chan struct{})
allQuit := make(chan struct{})
startQuery := make(chan struct{})
var wg sync.WaitGroup
// Increase this channel's buffer to be able to handle more queries at
// once: each channel gets results for all queries, so otherwise
// messages can get mixed up and missed, causing a vicious cycle of
// retries and an even bigger message flood, more of which gets missed.
msgChan := make(chan spMsg)
var subwg sync.WaitGroup
subscription := spMsgSubscription{
msgChan: msgChan,
quitChan: allQuit,
wg: &subwg,
}
// Start a goroutine for each peer that potentially queries that peer.
for _, sp := range peers {
wg.Add(1)
go func(sp *serverPeer) {
numRetries := qo.numRetries
defer wg.Done()
defer sp.unsubscribeRecvMsgs(subscription)
// Should we do this when the goroutine gets a message
// via startQuery rather than at the launch of the
// goroutine?
if !sp.Connected() {
return
}
timeout := make(<-chan time.Time)
for {
select {
case <-timeout:
// After timeout, we try to notify
// another of our peer goroutines to
// do a query until we get a signal to
// quit.
select {
case startQuery <- struct{}{}:
case <-quit:
return
case <-allQuit:
return
}
// At this point, we've sent startQuery.
// We return if we've run through this
// section of code numRetries times.
if numRetries--; numRetries == 0 {
return
}
case <-quit:
// After we're told to quit, we return.
return
case <-allQuit:
// After we're told to quit, we return.
return
case <-startQuery:
// We're the lucky peer whose turn it is
// to try to answer the current query.
// TODO: Fix this to support either
// querying *all* peers simultaneously
// to avoid timeout delays, or starting
// with the syncPeer when not querying
// *all* peers.
sp.subscribeRecvMsg(subscription)
// Don't want the peer hanging on send
// to the channel if we quit before
// reading the channel.
sentChan := make(chan struct{}, 1)
sp.QueueMessage(queryMsg, sentChan)
select {
case <-sentChan:
case <-quit:
return
case <-allQuit:
return
}
timeout = time.After(qo.timeout)
default:
}
}
}(sp)
}
startQuery <- struct{}{}
// This goroutine will wait until all of the peer-query goroutines have
// terminated, and then initiate a query shutdown.
go func() {
wg.Wait()
// If we timed out on each goroutine and didn't quit or time out
// on the main goroutine, make sure our main goroutine knows to
// quit.
select {
case <-allQuit:
default:
close(allQuit)
}
// Close the done channel, if any
if qo.doneChan != nil {
close(qo.doneChan)
}
// Wait until all goroutines started by subscriptions have
// exited after we closed allQuit before letting the message
// channel get garbage collected.
subwg.Wait()
}()
// Loop for any messages sent to us via our subscription channel and
// check them for whether they satisfy the query. Break the loop if it's
// time to quit.
timeout := time.After(time.Duration(len(peers)+1) *
qo.timeout * time.Duration(qo.numRetries))
checkResponses:
for {
select {
case <-timeout:
// When we time out, close the allQuit channel
// if it hasn't already been closed.
select {
case <-allQuit:
default:
close(allQuit)
}
break checkResponses
case <-quit:
break checkResponses
case <-allQuit:
break checkResponses
case sm := <-msgChan:
// TODO: This will get stuck if checkResponse
// gets stuck. This is a caveat for callers that
// should be fixed before exposing this function
// for public use.
checkResponse(sm.sp, sm.msg, quit)
}
}
}
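The DoneChan option is what lets a caller run queryPeers in the background without blocking on it. A sketch, where queryMsg and checkResponse stand in for a real query (both are placeholders, not part of this diff):

done := make(chan struct{})
go s.queryPeers(queryMsg, checkResponse, DoneChan(done))
// ... do other work while the query runs ...
<-done // queryPeers has exited; any result already went through checkResponse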
// GetCFilter gets a cfilter from the database. Failing that, it requests the
// cfilter from the network and writes it to the database.
func (s *ChainService) GetCFilter(blockHash chainhash.Hash,
extended bool) *gcs.Filter {
extended bool, options ...QueryOption) *gcs.Filter {
getFilter := s.GetBasicFilter
getHeader := s.GetBasicHeader
putFilter := s.putBasicFilter
@@ -1714,17 +1807,54 @@ func (s *ChainService) GetCFilter(blockHash chainhash.Hash,
if err != nil {
return nil
}
replyChan := make(chan *gcs.Filter)
s.query <- getCFilterMsg{
cfRequest: cfRequest{
blockHash: blockHash,
extended: extended,
s.queryPeers(
// Send a wire.MsgGetCFilter.
wire.NewMsgGetCFilter(&blockHash, extended),
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "cfilter" messages.
case *wire.MsgCFilter:
if len(response.Data) < 4 {
// Filter data is too short.
// Ignore this message.
return
}
if blockHash != response.BlockHash {
// The response doesn't match our
// request. Ignore this message.
return
}
gotFilter, err :=
gcs.FromNBytes(builder.DefaultP,
response.Data)
if err != nil {
// Malformed filter data. We
// can ignore this message.
return
}
if MakeHeaderForFilter(gotFilter,
*prevHeader) !=
*curHeader {
// Filter data doesn't match
// the headers we know about.
// Ignore this response.
return
}
// At this point, the filter matches
// what we know about it and we declare
// it sane. We can kill the query and
// pass the response back to the caller.
close(quit)
filter = gotFilter
default:
}
},
prevHeader: prevHeader,
curHeader: curHeader,
reply: replyChan,
}
filter = <-replyChan
options...,
)
// If we've found a filter, write it to the database for next time.
if filter != nil {
putFilter(blockHash, filter)
log.Tracef("Wrote filter for block %s, extended: %t",
@@ -1736,20 +1866,72 @@ func (s *ChainService) GetCFilter(blockHash chainhash.Hash,
// GetBlockFromNetwork gets a block by requesting it from the network, one peer
// at a time, until one answers.
func (s *ChainService) GetBlockFromNetwork(
blockHash chainhash.Hash) *btcutil.Block {
blockHash chainhash.Hash, options ...QueryOption) *btcutil.Block {
blockHeader, height, err := s.GetBlockByHash(blockHash)
if err != nil || blockHeader.BlockHash() != blockHash {
return nil
}
replyChan := make(chan *btcutil.Block)
s.query <- getBlockMsg{
blockHeader: &blockHeader,
height: height,
reply: replyChan,
}
block := <-replyChan
if block != nil {
log.Tracef("Got block %s from network", blockHash)
}
return block
getData := wire.NewMsgGetData()
getData.AddInvVect(wire.NewInvVect(wire.InvTypeBlock,
&blockHash))
// The block is only updated from the checkResponse function argument,
// which is only ever called from a single goroutine at a time. We
// don't check the block until after the query is finished, so we can
// just write to it naively.
var foundBlock *btcutil.Block
s.queryPeers(
// Send a wire.MsgGetData.
getData,
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "block" messages.
case *wire.MsgBlock:
// If this isn't our block, ignore it.
if response.BlockHash() !=
blockHash {
return
}
block := btcutil.NewBlock(response)
// Only set height if btcutil hasn't
// automagically put one in.
if block.Height() ==
btcutil.BlockHeightUnknown {
block.SetHeight(
int32(height))
}
// If this claims our block but doesn't
// pass the sanity check, the peer is
// trying to bamboozle us. Disconnect
// it.
if err := blockchain.CheckBlockSanity(
block,
// We don't need to check PoW
// because by the time we get
// here, it's been checked
// during header synchronization
s.chainParams.PowLimit,
s.timeSource,
); err != nil {
log.Warnf("Invalid block for %s "+
"received from %s -- "+
"disconnecting peer", blockHash,
sp.Addr())
sp.Disconnect()
return
}
// At this point, the block matches what we know
// about it and we declare it sane. We can kill
// the query and pass the response back to the
// caller.
close(quit)
foundBlock = block
default:
}
},
options...,
)
return foundBlock
}

View file

@@ -4,7 +4,6 @@ import (
"bytes"
"fmt"
"io/ioutil"
"math/rand"
"os"
"reflect"
"sync"
@@ -24,14 +23,32 @@ import (
_ "github.com/btcsuite/btcwallet/walletdb/bdb"
)
const (
logLevel = btclog.TraceLvl
syncTimeout = 30 * time.Second
syncUpdate = time.Second
numTestBlocks = 50
var (
logLevel = btclog.TraceLvl
syncTimeout = 30 * time.Second
syncUpdate = time.Second
// Don't set this too high for your platform, or the tests will miss
// messages.
// TODO: Make this a benchmark instead.
numQueryThreads = 50
queryOptions = []spvchain.QueryOption{
//spvchain.NumRetries(5),
}
)
func TestSetup(t *testing.T) {
// Set up logging.
logger, err := btctestlog.NewTestLogger(t)
if err != nil {
t.Fatalf("Could not set up logger: %s", err)
}
chainLogger := btclog.NewSubsystemLogger(logger, "CHAIN: ")
chainLogger.SetLevel(logLevel)
spvchain.UseLogger(chainLogger)
rpcLogger := btclog.NewSubsystemLogger(logger, "RPCC: ")
rpcLogger.SetLevel(logLevel)
btcrpcclient.UseLogger(rpcLogger)
// Create a btcd SimNet node and generate 500 blocks
h1, err := rpctest.New(&chaincfg.SimNetParams, nil, nil)
if err != nil {
@@ -147,16 +164,6 @@ func TestSetup(t *testing.T) {
spvchain.BanDuration = 5 * time.Second
spvchain.RequiredServices = wire.SFNodeNetwork
spvchain.WaitForMoreCFHeaders = time.Second
logger, err := btctestlog.NewTestLogger(t)
if err != nil {
t.Fatalf("Could not set up logger: %s", err)
}
chainLogger := btclog.NewSubsystemLogger(logger, "CHAIN: ")
chainLogger.SetLevel(logLevel)
spvchain.UseLogger(chainLogger)
rpcLogger := btclog.NewSubsystemLogger(logger, "RPCC: ")
rpcLogger.SetLevel(logLevel)
btcrpcclient.UseLogger(rpcLogger)
svc, err := spvchain.NewChainService(config)
if err != nil {
t.Fatalf("Error creating ChainService: %s", err)
@@ -170,6 +177,13 @@ func TestSetup(t *testing.T) {
t.Fatalf("Couldn't sync ChainService: %s", err)
}
// Test that we can get blocks and cfilters via P2P and decide which are
// valid and which aren't.
err = testRandomBlocks(t, svc, h1)
if err != nil {
t.Fatalf("Testing blocks and cfilters failed: %s", err)
}
// Generate 125 blocks on h1 to make sure it reorgs the other nodes.
// Ensure the ChainService instance stays caught up.
h1.Node.Generate(125)
@@ -209,13 +223,6 @@ func TestSetup(t *testing.T) {
if err != nil {
t.Fatalf("Couldn't sync ChainService: %s", err)
}
// Test that we can get blocks and cfilters via P2P and decide which are
// valid and which aren't.
err = testRandomBlocks(t, svc, h1)
if err != nil {
t.Fatalf("Testing blocks and cfilters failed: %s", err)
}
}
// csd does a connect-sync-disconnect between nodes in order to support
@@ -377,7 +384,7 @@ func waitForSync(t *testing.T, svc *spvchain.ChainService,
// can correctly get filters from them. We don't go through *all* the blocks
// because it can be a little slow, but we'll improve that soon-ish hopefully
// to the point where we can do it.
// TODO: Improve concurrency on framework side.
// TODO: Make this a benchmark instead.
func testRandomBlocks(t *testing.T, svc *spvchain.ChainService,
correctSyncNode *rpctest.Harness) error {
var haveBest *waddrmgr.BlockStamp
@@ -386,180 +393,200 @@ func testRandomBlocks(t *testing.T, svc *spvchain.ChainService,
return fmt.Errorf("Couldn't get best snapshot from "+
"ChainService: %s", err)
}
// Keep track of an error channel
errChan := make(chan error)
var lastErr error
go func() {
for err := range errChan {
if err != nil {
t.Errorf("%s", err)
lastErr = fmt.Errorf("Couldn't validate all " +
"blocks, filters, and filter headers.")
}
}
}()
// Test getting numTestBlocks random blocks and filters.
// Keep track of an error channel with enough buffer space to track one
// error per block.
errChan := make(chan error, haveBest.Height)
// Test getting all of the blocks and filters.
var wg sync.WaitGroup
heights := rand.Perm(int(haveBest.Height))
for i := 0; i < numTestBlocks; i++ {
workerQueue := make(chan struct{}, numQueryThreads)
for i := int32(1); i <= haveBest.Height; i++ {
wg.Add(1)
height := uint32(heights[i])
height := uint32(i)
// Wait until there's room in the worker queue.
workerQueue <- struct{}{}
go func() {
// On exit, open a spot in workerQueue and tell the
// wait group we're done.
defer func() {
<-workerQueue
}()
defer wg.Done()
// Get block header from database.
blockHeader, blockHeight, err := svc.GetBlockByHeight(height)
blockHeader, blockHeight, err := svc.GetBlockByHeight(
height)
if err != nil {
errChan <- fmt.Errorf("Couldn't get block "+
"header by height %d: %s", height, err)
return
}
if blockHeight != height {
errChan <- fmt.Errorf("Block height retrieved from DB "+
"doesn't match expected height. Want: %d, "+
"have: %d", height, blockHeight)
errChan <- fmt.Errorf("Block height retrieved "+
"from DB doesn't match expected "+
"height. Want: %d, have: %d", height,
blockHeight)
return
}
blockHash := blockHeader.BlockHash()
// Get block via RPC.
wantBlock, err := correctSyncNode.Node.GetBlock(&blockHash)
wantBlock, err := correctSyncNode.Node.GetBlock(
&blockHash)
if err != nil {
errChan <- fmt.Errorf("Couldn't get block %d (%s) by RPC",
height, blockHash)
errChan <- fmt.Errorf("Couldn't get block %d "+
"(%s) by RPC", height, blockHash)
return
}
// Get block from network.
haveBlock := svc.GetBlockFromNetwork(blockHash)
haveBlock := svc.GetBlockFromNetwork(blockHash,
queryOptions...)
if haveBlock == nil {
errChan <- fmt.Errorf("Couldn't get block %d (%s) from"+
"network", height, blockHash)
errChan <- fmt.Errorf("Couldn't get block %d "+
"(%s) from network", height, blockHash)
return
}
// Check that network and RPC blocks match.
if !reflect.DeepEqual(*haveBlock.MsgBlock(), *wantBlock) {
errChan <- fmt.Errorf("Block from network doesn't match "+
"block from RPC. Want: %s, RPC: %s, network: "+
"%s", blockHash, wantBlock.BlockHash(),
if !reflect.DeepEqual(*haveBlock.MsgBlock(),
*wantBlock) {
errChan <- fmt.Errorf("Block from network "+
"doesn't match block from RPC. Want: "+
"%s, RPC: %s, network: %s", blockHash,
wantBlock.BlockHash(),
haveBlock.MsgBlock().BlockHash())
return
}
// Check that block height matches what we have.
if int32(blockHeight) != haveBlock.Height() {
errChan <- fmt.Errorf("Block height from network doesn't "+
"match expected height. Want: %s, network: %s",
errChan <- fmt.Errorf("Block height from "+
"network doesn't match expected "+
"height. Want: %s, network: %s",
blockHeight, haveBlock.Height())
return
}
// Get basic cfilter from network.
haveFilter := svc.GetCFilter(blockHash, false)
haveFilter := svc.GetCFilter(blockHash, false,
queryOptions...)
if haveFilter == nil {
errChan <- fmt.Errorf("Couldn't get basic "+
"filter for block %d", height)
return
}
// Get basic cfilter from RPC.
wantFilter, err := correctSyncNode.Node.GetCFilter(&blockHash,
false)
wantFilter, err := correctSyncNode.Node.GetCFilter(
&blockHash, false)
if err != nil {
errChan <- fmt.Errorf("Couldn't get basic filter for "+
"block %d via RPC: %s", height, err)
errChan <- fmt.Errorf("Couldn't get basic "+
"filter for block %d via RPC: %s",
height, err)
return
}
// Check that network and RPC cfilters match.
if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) {
errChan <- fmt.Errorf("Basic filter from P2P network/DB"+
" doesn't match RPC value for block %d", height)
errChan <- fmt.Errorf("Basic filter from P2P "+
"network/DB doesn't match RPC value "+
"for block %d", height)
return
}
// Calculate basic filter from block.
calcFilter, err := spvchain.BuildBasicFilter(
haveBlock.MsgBlock())
if err != nil {
errChan <- fmt.Errorf("Couldn't build basic filter for "+
"block %d (%s): %s", height, blockHash, err)
errChan <- fmt.Errorf("Couldn't build basic "+
"filter for block %d (%s): %s", height,
blockHash, err)
return
}
// Check that the network value matches the calculated value
// from the block.
// Check that the network value matches the calculated
// value from the block.
if !reflect.DeepEqual(*haveFilter, *calcFilter) {
errChan <- fmt.Errorf("Basic filter from P2P network/DB "+
"doesn't match calculated value for block %d",
height)
errChan <- fmt.Errorf("Basic filter from P2P "+
"network/DB doesn't match calculated "+
"value for block %d", height)
return
}
// Get previous basic filter header from the database.
prevHeader, err := svc.GetBasicHeader(blockHeader.PrevBlock)
prevHeader, err := svc.GetBasicHeader(
blockHeader.PrevBlock)
if err != nil {
errChan <- fmt.Errorf("Couldn't get basic filter header "+
"for block %d (%s) from DB: %s", height-1,
errChan <- fmt.Errorf("Couldn't get basic "+
"filter header for block %d (%s) from "+
"DB: %s", height-1,
blockHeader.PrevBlock, err)
return
}
// Get current basic filter header from the database.
curHeader, err := svc.GetBasicHeader(blockHash)
if err != nil {
errChan <- fmt.Errorf("Couldn't get basic filter header "+
"for block %d (%s) from DB: %s", height-1,
blockHash, err)
errChan <- fmt.Errorf("Couldn't get basic "+
"filter header for block %d (%s) from "+
"DB: %s", height-1, blockHash, err)
return
}
// Check that the filter and header line up.
calcHeader := spvchain.MakeHeaderForFilter(calcFilter,
*prevHeader)
if !bytes.Equal(curHeader[:], calcHeader[:]) {
errChan <- fmt.Errorf("Filter header doesn't match. Want: "+
"%s, got: %s", curHeader, calcHeader)
errChan <- fmt.Errorf("Filter header doesn't "+
"match. Want: %s, got: %s", curHeader,
calcHeader)
return
}
// Get extended cfilter from network
haveFilter = svc.GetCFilter(blockHash, true)
haveFilter = svc.GetCFilter(blockHash, true,
queryOptions...)
if haveFilter == nil {
errChan <- fmt.Errorf("Couldn't get extended "+
"filter for block %d", height)
return
}
// Get extended cfilter from RPC
wantFilter, err = correctSyncNode.Node.GetCFilter(&blockHash,
true)
wantFilter, err = correctSyncNode.Node.GetCFilter(
&blockHash, true)
if err != nil {
errChan <- fmt.Errorf("Couldn't get extended filter for "+
"block %d via RPC: %s", height, err)
errChan <- fmt.Errorf("Couldn't get extended "+
"filter for block %d via RPC: %s",
height, err)
return
}
// Check that network and RPC cfilters match
if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) {
errChan <- fmt.Errorf("Extended filter from P2P network/DB"+
" doesn't match RPC value for block %d", height)
errChan <- fmt.Errorf("Extended filter from "+
"P2P network/DB doesn't match RPC "+
"value for block %d", height)
return
}
// Calculate extended filter from block
calcFilter, err = spvchain.BuildExtFilter(
haveBlock.MsgBlock())
if err != nil {
errChan <- fmt.Errorf("Couldn't build extended filter for "+
"block %d (%s): %s", height, blockHash, err)
errChan <- fmt.Errorf("Couldn't build extended"+
" filter for block %d (%s): %s", height,
blockHash, err)
return
}
// Check that the network value matches the calculated value
// from the block.
// Check that the network value matches the calculated
// value from the block.
if !reflect.DeepEqual(*haveFilter, *calcFilter) {
errChan <- fmt.Errorf("Extended filter from P2P network/DB"+
" doesn't match calculated value for block %d",
height)
errChan <- fmt.Errorf("Extended filter from "+
"P2P network/DB doesn't match "+
"calculated value for block %d", height)
return
}
// Get previous extended filter header from the database.
prevHeader, err = svc.GetExtHeader(blockHeader.PrevBlock)
// Get previous extended filter header from the
// database.
prevHeader, err = svc.GetExtHeader(
blockHeader.PrevBlock)
if err != nil {
errChan <- fmt.Errorf("Couldn't get extended filter header"+
" for block %d (%s) from DB: %s", height-1,
errChan <- fmt.Errorf("Couldn't get extended "+
"filter header for block %d (%s) from "+
"DB: %s", height-1,
blockHeader.PrevBlock, err)
return
}
// Get current extended filter header from the database.
curHeader, err = svc.GetExtHeader(blockHash)
if err != nil {
errChan <- fmt.Errorf("Couldn't get extended filter header"+
" for block %d (%s) from DB: %s", height-1,
errChan <- fmt.Errorf("Couldn't get extended "+
"filter header for block %d (%s) from "+
"DB: %s", height-1,
blockHash, err)
return
}
@@ -567,20 +594,29 @@ func testRandomBlocks(t *testing.T, svc *spvchain.ChainService,
calcHeader = spvchain.MakeHeaderForFilter(calcFilter,
*prevHeader)
if !bytes.Equal(curHeader[:], calcHeader[:]) {
errChan <- fmt.Errorf("Filter header doesn't match. Want: "+
"%s, got: %s", curHeader, calcHeader)
errChan <- fmt.Errorf("Filter header doesn't "+
"match. Want: %s, got: %s", curHeader,
calcHeader)
return
}
}()
}
// Wait for all queries to finish.
wg.Wait()
if logLevel != btclog.Off {
t.Logf("Finished checking %d blocks and their cfilters",
numTestBlocks)
}
// Close the error channel to make the error monitoring goroutine
// finish.
close(errChan)
var lastErr error
for err := range errChan {
if err != nil {
t.Errorf("%s", err)
lastErr = fmt.Errorf("Couldn't validate all " +
"blocks, filters, and filter headers.")
}
}
if logLevel != btclog.Off {
t.Logf("Finished checking %d blocks and their cfilters",
haveBest.Height)
}
return lastErr
}
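Stripped of the block and filter checks, the test's concurrency scaffolding is just a buffered error channel plus a semaphore-style worker queue. A distilled sketch of that pattern (numJobs, taken here as a uint32, and checkBlock stand in for the per-block work above):

errChan := make(chan error, numJobs)                // one slot per job
workerQueue := make(chan struct{}, numQueryThreads) // bounded concurrency
var wg sync.WaitGroup
for height := uint32(1); height <= numJobs; height++ {
	wg.Add(1)
	workerQueue <- struct{}{} // take a slot; blocks when all are busy
	go func(height uint32) {
		defer wg.Done()
		defer func() { <-workerQueue }() // free the slot on exit
		if err := checkBlock(height); err != nil {
			errChan <- err // buffered, so this never blocks
		}
	}(height)
}
wg.Wait()
close(errChan) // safe: every sender has finished
for err := range errChan {
	t.Errorf("%s", err)
}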