Fix concurrency issues.

Alex 2017-04-26 16:34:05 -06:00 committed by Olaoluwa Osuntokun
parent c7b26a11e2
commit fe632ff233
4 changed files with 531 additions and 439 deletions


@ -1031,9 +1031,11 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
// Should probably use better isolation for this but we're in // Should probably use better isolation for this but we're in
// the same package. One of the things to clean up when we do // the same package. One of the things to clean up when we do
// more general cleanup. // more general cleanup.
sp.mtxReqCFH.Lock()
sp.requestedCFHeaders[cfhReqB] = cfhCount sp.requestedCFHeaders[cfhReqB] = cfhCount
sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, false)
sp.requestedCFHeaders[cfhReqE] = cfhCount sp.requestedCFHeaders[cfhReqE] = cfhCount
sp.mtxReqCFH.Unlock()
sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, false)
sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, true) sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, true)
}) })
@ -1078,13 +1080,20 @@ func (b *blockManager) QueueCFHeaders(cfheaders *wire.MsgCFHeaders,
extended: cfheaders.Extended, extended: cfheaders.Extended,
stopHash: cfheaders.StopHash, stopHash: cfheaders.StopHash,
} }
if sp.requestedCFHeaders[req] != len(cfheaders.HeaderHashes) { // TODO: Get rid of this by refactoring all of this using the query API
sp.mtxReqCFH.Lock()
expLen := sp.requestedCFHeaders[req]
sp.mtxReqCFH.Unlock()
if expLen != len(cfheaders.HeaderHashes) {
log.Warnf("Received cfheaders message doesn't match any "+ log.Warnf("Received cfheaders message doesn't match any "+
"getcfheaders request. Peer %s is probably on a "+ "getcfheaders request. Peer %s is probably on a "+
"different chain -- ignoring", sp.Addr()) "different chain -- ignoring", sp.Addr())
return return
} }
// TODO: Remove this by refactoring this section into a query client.
sp.mtxReqCFH.Lock()
delete(sp.requestedCFHeaders, req) delete(sp.requestedCFHeaders, req)
sp.mtxReqCFH.Unlock()
// Track number of pending cfheaders messages for both basic and // Track number of pending cfheaders messages for both basic and
// extended filters. // extended filters.
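
The change above is a plain mutex guard around the shared requestedCFHeaders map, so the length check in QueueCFHeaders can no longer race with the writes made in handleHeadersMsg. The following is a minimal standalone sketch of that pattern only; cfhTracker and the string keys are simplified stand-ins for illustration, not the actual serverPeer and cfhRequest types.

package main

import (
    "fmt"
    "sync"
)

// cfhTracker is a hypothetical stand-in for the per-peer cfheaders state that
// this commit protects with mtxReqCFH.
type cfhTracker struct {
    mtx       sync.Mutex
    requested map[string]int
}

// record stores the expected header count under the lock, mirroring the
// Lock()/Unlock() bracket added around writes to requestedCFHeaders.
func (t *cfhTracker) record(key string, count int) {
    t.mtx.Lock()
    defer t.mtx.Unlock()
    t.requested[key] = count
}

// expected reads the count under the same lock, so the check made when a
// cfheaders message arrives can't race with the writes above.
func (t *cfhTracker) expected(key string) int {
    t.mtx.Lock()
    defer t.mtx.Unlock()
    return t.requested[key]
}

func main() {
    t := &cfhTracker{requested: make(map[string]int)}
    t.record("basic:stophash", 2000)
    fmt.Println(t.expected("basic:stophash"))
}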


@ -8,13 +8,7 @@ import (
"errors" "errors"
"github.com/btcsuite/btcd/addrmgr" "github.com/btcsuite/btcd/addrmgr"
"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/connmgr" "github.com/btcsuite/btcd/connmgr"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/btcsuite/btcutil/gcs"
"github.com/btcsuite/btcutil/gcs/builder"
) )
type getConnCountMsg struct { type getConnCountMsg struct {
@ -54,19 +48,6 @@ type forAllPeersMsg struct {
closure func(*serverPeer) closure func(*serverPeer)
} }
type getCFilterMsg struct {
cfRequest
prevHeader *chainhash.Hash
curHeader *chainhash.Hash
reply chan *gcs.Filter
}
type getBlockMsg struct {
blockHeader *wire.BlockHeader
height uint32
reply chan *btcutil.Block
}
// TODO: General - abstract out more of blockmanager into queries. It'll make // TODO: General - abstract out more of blockmanager into queries. It'll make
// this way more maintainable and usable. // this way more maintainable and usable.
@ -172,128 +153,12 @@ func (s *ChainService) handleQuery(state *peerState, querymsg interface{}) {
msg.reply <- errors.New("peer not found") msg.reply <- errors.New("peer not found")
case forAllPeersMsg: case forAllPeersMsg:
// TODO: Remove this when it's unnecessary due to wider use of
// queryPeers.
// Run the closure on all peers in the passed state. // Run the closure on all peers in the passed state.
state.forAllPeers(msg.closure) state.forAllPeers(msg.closure)
// Even though this is a query, there's no reply channel as the // Even though this is a query, there's no reply channel as the
// forAllPeers method doesn't return anything. An error might be // forAllPeers method doesn't return anything. An error might be
// useful in the future. // useful in the future.
case getCFilterMsg:
found := false
state.queryPeers(
// Should we query this peer?
func(sp *serverPeer) bool {
// Don't send requests to disconnected peers.
return sp.Connected()
},
// Send a wire.GetCFilterMsg
wire.NewMsgGetCFilter(&msg.blockHash, msg.extended),
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "cfilter" messages.
case *wire.MsgCFilter:
if len(response.Data) < 4 {
// Filter data is too short.
// Ignore this message.
return
}
filter, err :=
gcs.FromNBytes(builder.DefaultP,
response.Data)
if err != nil {
// Malformed filter data. We
// can ignore this message.
return
}
if MakeHeaderForFilter(filter,
*msg.prevHeader) !=
*msg.curHeader {
// Filter data doesn't match
// the headers we know about.
// Ignore this response.
return
}
// At this point, the filter matches
// what we know about it and we declare
// it sane. We can kill the query and
// pass the response back to the caller.
found = true
close(quit)
msg.reply <- filter
default:
}
},
)
// We timed out without finding a correct answer to our query.
if !found {
msg.reply <- nil
}
case getBlockMsg:
found := false
getData := wire.NewMsgGetData()
blockHash := msg.blockHeader.BlockHash()
getData.AddInvVect(wire.NewInvVect(wire.InvTypeBlock,
&blockHash))
state.queryPeers(
// Should we query this peer?
func(sp *serverPeer) bool {
// Don't send requests to disconnected peers.
return sp.Connected()
},
// Send a wire.GetCFilterMsg
getData,
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "block" messages.
case *wire.MsgBlock:
// If this isn't our block, ignore it.
if response.BlockHash() !=
blockHash {
return
}
block := btcutil.NewBlock(response)
// Only set height if btcutil hasn't
// automagically put one in.
if block.Height() ==
btcutil.BlockHeightUnknown {
block.SetHeight(
int32(msg.height))
}
// If this claims our block but doesn't
// pass the sanity check, the peer is
// trying to bamboozle us. Disconnect
// it.
if err := blockchain.CheckBlockSanity(
block,
// We don't need to check PoW
// because by the time we get
// here, it's been checked
// during header synchronization
s.chainParams.PowLimit,
s.timeSource,
); err != nil {
log.Warnf("Invalid block for "+
"%s received from %s "+
"-- disconnecting peer",
blockHash, sp.Addr())
sp.Disconnect()
return
}
found = true
close(quit)
msg.reply <- block
default:
}
},
)
// We timed out without finding a correct answer to our query.
if !found {
msg.reply <- nil
}
} }
} }
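
The cases deleted here had handleQuery answer cfilter and block requests over one-off reply channels; that work now lives in the ChainService query methods in the next file. The surviving cases still use the reply-channel pattern, sketched below with a hypothetical connCountQuery message and handleQueries loop; the real message type is getConnCountMsg and the real loop runs inside ChainService.

package main

import "fmt"

// connCountQuery is a hypothetical reply-channel query, modeled on the
// getConnCountMsg case that handleQuery still serves.
type connCountQuery struct {
    reply chan int32
}

// handleQueries owns the peer state, so it can answer each query without any
// extra locking; callers block on their own reply channel.
func handleQueries(queries <-chan interface{}, connected int32) {
    for q := range queries {
        switch msg := q.(type) {
        case connCountQuery:
            msg.reply <- connected
        }
    }
}

func main() {
    queries := make(chan interface{})
    go handleQueries(queries, 8)

    q := connCountQuery{reply: make(chan int32, 1)}
    queries <- q
    fmt.Println("connected peers:", <-q.reply)
}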


@ -18,6 +18,7 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/btcsuite/btcutil/gcs" "github.com/btcsuite/btcutil/gcs"
"github.com/btcsuite/btcutil/gcs/builder"
"github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/waddrmgr"
"github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/wallet"
"github.com/btcsuite/btcwallet/walletdb" "github.com/btcsuite/btcwallet/walletdb"
@ -63,8 +64,14 @@ var (
// from DNS. // from DNS.
DisableDNSSeed = false DisableDNSSeed = false
// Timeout specifies how long to wait for a peer to answer a query. // QueryTimeout specifies how long to wait for a peer to answer a query.
Timeout = time.Second * 5 QueryTimeout = time.Second * 3
// QueryNumRetries specifies how many times to retry sending a query to
// each peer before we've concluded we aren't going to get a valid
// response. This allows us to make up for missed messages in some
// instances.
QueryNumRetries = 2
) )
// updatePeerHeightsMsg is a message sent from the blockmanager to the server // updatePeerHeightsMsg is a message sent from the blockmanager to the server
@ -110,137 +117,6 @@ func (ps *peerState) forAllPeers(closure func(sp *serverPeer)) {
ps.forAllOutboundPeers(closure) ps.forAllOutboundPeers(closure)
} }
// Query options can be modified per-query, unlike global options.
// TODO: Make more query options that override global options.
type queryOptions struct {
// queryTimeout lets the query know how long to wait for a peer to
// answer the query before moving onto the next peer.
queryTimeout time.Duration
}
// defaultQueryOptions returns a queryOptions set to package-level defaults.
func defaultQueryOptions() *queryOptions {
return &queryOptions{
queryTimeout: Timeout,
}
}
// QueryTimeout is a query option that lets the query know to ask each peer we're
// connected to for its opinion, if any. By default, we only ask peers until one
// gives us a valid response.
func QueryTimeout(timeout time.Duration) func(*queryOptions) {
return func(qo *queryOptions) {
qo.queryTimeout = timeout
}
}
type spMsg struct {
sp *serverPeer
msg wire.Message
}
// queryPeers is a helper function that sends a query to one or more peers and
// waits for an answer. The timeout for queries is set by the QueryTimeout
// package-level variable.
func (ps *peerState) queryPeers(
// selectPeer is a closure which decides whether or not to send the
// query to the peer.
selectPeer func(sp *serverPeer) bool,
// queryMsg is the message to send to each peer selected by selectPeer.
queryMsg wire.Message,
// checkResponse is called for every message within the timeout period.
// The quit channel lets the query know to terminate because the
// required response has been found. This is done by closing the
// channel.
checkResponse func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}),
// options takes functional options for executing the query.
options ...func(*queryOptions),
) {
qo := defaultQueryOptions()
for _, option := range options {
option(qo)
}
// This will be shared state between the per-peer goroutines.
quit := make(chan struct{})
startQuery := make(chan struct{})
var wg sync.WaitGroup
channel := make(chan spMsg)
// This goroutine will monitor all messages from all peers until the
// peer goroutines all exit.
go func() {
for {
select {
case <-quit:
close(channel)
ps.forAllPeers(
func(sp *serverPeer) {
sp.unsubscribeRecvMsgs(channel)
})
return
case sm := <-channel:
// TODO: This will get stuck if checkResponse
// gets stuck.
checkResponse(sm.sp, sm.msg, quit)
}
}
}()
// Start a goroutine for each peer that potentially queries each peer
ps.forAllPeers(func(sp *serverPeer) {
wg.Add(1)
go func() {
defer wg.Done()
if !selectPeer(sp) {
return
}
timeout := make(<-chan time.Time)
for {
select {
case <-timeout:
// After timeout, we return and notify
// another goroutine that we've done so.
// We only send if there's someone left
// to receive.
startQuery <- struct{}{}
return
case <-quit:
// After we're told to quit, we return.
return
case <-startQuery:
// We're the lucky peer whose turn it is
// to try to answer the current query.
// TODO: Fix this to support multiple
// queries at once. For now, we're
// relying on the query handling loop
// to make sure we don't interrupt
// another query. We need broadcast
// support in OnRead to do this right.
// TODO: Fix this to support either
// querying *all* peers simultaneously
// to avoid timeout delays, or starting
// with the syncPeer when not querying
// *all* peers.
sp.subscribeRecvMsg(channel)
sp.QueueMessage(queryMsg, nil)
timeout = time.After(qo.queryTimeout)
default:
}
}
}()
})
startQuery <- struct{}{}
wg.Wait()
// If we timed out and didn't quit, make sure our response monitor
// goroutine knows to quit.
select {
case <-quit:
default:
close(quit)
}
}
// cfhRequest records which cfheaders we've requested, and the order in which // cfhRequest records which cfheaders we've requested, and the order in which
// we've requested them. Since there's no way to associate the cfheaders to the // we've requested them. Since there's no way to associate the cfheaders to the
// actual block hashes based on the cfheaders message to keep it compact, we // actual block hashes based on the cfheaders message to keep it compact, we
@ -250,12 +126,6 @@ type cfhRequest struct {
stopHash chainhash.Hash stopHash chainhash.Hash
} }
// cfRequest records which cfilters we've requested.
type cfRequest struct {
extended bool
blockHash chainhash.Hash
}
// serverPeer extends the peer to maintain state shared by the server and // serverPeer extends the peer to maintain state shared by the server and
// the blockmanager. // the blockmanager.
type serverPeer struct { type serverPeer struct {
@ -264,24 +134,27 @@ type serverPeer struct {
*peer.Peer *peer.Peer
connReq *connmgr.ConnReq connReq *connmgr.ConnReq
server *ChainService server *ChainService
persistent bool persistent bool
continueHash *chainhash.Hash continueHash *chainhash.Hash
relayMtx sync.Mutex relayMtx sync.Mutex
requestQueue []*wire.InvVect requestQueue []*wire.InvVect
requestedCFHeaders map[cfhRequest]int knownAddresses map[string]struct{}
knownAddresses map[string]struct{} banScore connmgr.DynamicBanScore
banScore connmgr.DynamicBanScore quit chan struct{}
quit chan struct{}
// The following slice of channels is used to subscribe to messages from // The following slice of channels is used to subscribe to messages from
// the peer. This allows broadcast to multiple subscribers at once, // the peer. This allows broadcast to multiple subscribers at once,
// allowing for multiple queries to be going to multiple peers at any // allowing for multiple queries to be going to multiple peers at any
// one time. The mutex is for subscribe/unsubscribe functionality. // one time. The mutex is for subscribe/unsubscribe functionality.
// The sends on these channels WILL NOT block; any messages the channel // The sends on these channels WILL NOT block; any messages the channel
// can't accept will be dropped silently. // can't accept will be dropped silently.
recvSubscribers []chan<- spMsg recvSubscribers map[spMsgSubscription]struct{}
mtxSubscribers sync.RWMutex mtxSubscribers sync.RWMutex
// These are only necessary until the cfheaders logic is refactored as
// a query client.
requestedCFHeaders map[cfhRequest]int
mtxReqCFH sync.Mutex
} }
// newServerPeer returns a new serverPeer instance. The peer needs to be set by // newServerPeer returns a new serverPeer instance. The peer needs to be set by
@ -293,6 +166,7 @@ func newServerPeer(s *ChainService, isPersistent bool) *serverPeer {
requestedCFHeaders: make(map[cfhRequest]int), requestedCFHeaders: make(map[cfhRequest]int),
knownAddresses: make(map[string]struct{}), knownAddresses: make(map[string]struct{}),
quit: make(chan struct{}), quit: make(chan struct{}),
recvSubscribers: make(map[spMsgSubscription]struct{}),
} }
} }
@ -615,42 +489,38 @@ func (sp *serverPeer) OnRead(_ *peer.Peer, bytesRead int, msg wire.Message,
err error) { err error) {
sp.server.AddBytesReceived(uint64(bytesRead)) sp.server.AddBytesReceived(uint64(bytesRead))
// Try to send a message to the subscriber channel if it isn't nil, but // Try to send a message to the subscriber channel if it isn't nil, but
// don't block on failure. // don't block on failure. Do this inside a goroutine to prevent the
// server from slowing down too much.
sp.mtxSubscribers.RLock() sp.mtxSubscribers.RLock()
defer sp.mtxSubscribers.RUnlock() defer sp.mtxSubscribers.RUnlock()
for _, channel := range sp.recvSubscribers { for subscription := range sp.recvSubscribers {
if channel != nil { subscription.wg.Add(1)
go func(subscription spMsgSubscription) {
defer subscription.wg.Done()
select { select {
case channel <- spMsg{ case <-subscription.quitChan:
sp: sp, case subscription.msgChan <- spMsg{
msg: msg, msg: msg,
sp: sp,
}: }:
default:
} }
} }(subscription)
} }
} }
// subscribeRecvMsg handles adding OnRead subscriptions to the server peer. // subscribeRecvMsg handles adding OnRead subscriptions to the server peer.
func (sp *serverPeer) subscribeRecvMsg(channel chan<- spMsg) { func (sp *serverPeer) subscribeRecvMsg(subscription spMsgSubscription) {
sp.mtxSubscribers.Lock() sp.mtxSubscribers.Lock()
defer sp.mtxSubscribers.Unlock() defer sp.mtxSubscribers.Unlock()
sp.recvSubscribers = append(sp.recvSubscribers, channel) sp.recvSubscribers[subscription] = struct{}{}
} }
// unsubscribeRecvMsgs handles removing OnRead subscriptions from the server // unsubscribeRecvMsgs handles removing OnRead subscriptions from the server
// peer. // peer.
func (sp *serverPeer) unsubscribeRecvMsgs(channel chan<- spMsg) { func (sp *serverPeer) unsubscribeRecvMsgs(subscription spMsgSubscription) {
sp.mtxSubscribers.Lock() sp.mtxSubscribers.Lock()
defer sp.mtxSubscribers.Unlock() defer sp.mtxSubscribers.Unlock()
var updatedSubscribers []chan<- spMsg delete(sp.recvSubscribers, subscription)
for _, candidate := range sp.recvSubscribers {
if candidate != channel {
updatedSubscribers = append(updatedSubscribers,
candidate)
}
}
sp.recvSubscribers = updatedSubscribers
} }
// OnWrite is invoked when a peer sends a message and it is used to update // OnWrite is invoked when a peer sends a message and it is used to update
@ -683,9 +553,6 @@ type ChainService struct {
timeSource blockchain.MedianTimeSource timeSource blockchain.MedianTimeSource
services wire.ServiceFlag services wire.ServiceFlag
cfilterRequests map[cfRequest][]chan *gcs.Filter
cfRequestHeaders map[cfRequest][2]*chainhash.Hash
userAgentName string userAgentName string
userAgentVersion string userAgentVersion string
} }
@ -894,8 +761,6 @@ func NewChainService(cfg Config) (*ChainService, error) {
services: Services, services: Services,
userAgentName: UserAgentName, userAgentName: UserAgentName,
userAgentVersion: UserAgentVersion, userAgentVersion: UserAgentVersion,
cfilterRequests: make(map[cfRequest][]chan *gcs.Filter),
cfRequestHeaders: make(map[cfRequest][2]*chainhash.Hash),
} }
err := createSPVNS(s.namespace, &s.chainParams) err := createSPVNS(s.namespace, &s.chainParams)
@ -1686,10 +1551,238 @@ func (s *ChainService) IsCurrent() bool {
return s.blockManager.IsCurrent() return s.blockManager.IsCurrent()
} }
// Query options can be modified per-query, unlike global options.
// TODO: Make more query options that override global options.
type queryOptions struct {
// timeout lets the query know how long to wait for a peer to
// answer the query before moving onto the next peer.
timeout time.Duration
// numRetries tells the query how many times to retry asking each peer
// the query.
numRetries uint8
// doneChan lets the query signal the caller when it's done, in case
// it's run in a goroutine.
doneChan chan<- struct{}
}
// QueryOption is a functional option argument to any of the network query
// methods, such as GetBlockFromNetwork and GetCFilter (when that resorts to a
// network query).
type QueryOption func(*queryOptions)
// defaultQueryOptions returns a queryOptions set to package-level defaults.
func defaultQueryOptions() *queryOptions {
return &queryOptions{
timeout: QueryTimeout,
numRetries: uint8(QueryNumRetries),
}
}
// Timeout is a query option that lets the query know how long to wait for
// each queried peer to answer before moving on to the next one.
func Timeout(timeout time.Duration) QueryOption {
return func(qo *queryOptions) {
qo.timeout = timeout
}
}
// NumRetries is a query option that lets the query know the maximum number of
// times each peer should be queried. The default is set by QueryNumRetries.
func NumRetries(numRetries uint8) QueryOption {
return func(qo *queryOptions) {
qo.numRetries = numRetries
}
}
// DoneChan allows the caller to pass a channel that will get closed when the
// query is finished.
func DoneChan(doneChan chan<- struct{}) QueryOption {
return func(qo *queryOptions) {
qo.doneChan = doneChan
}
}
type spMsg struct {
sp *serverPeer
msg wire.Message
}
type spMsgSubscription struct {
msgChan chan<- spMsg
quitChan <-chan struct{}
wg *sync.WaitGroup
}
// queryPeers is a helper function that sends a query to one or more peers and
// waits for an answer. The timeout for queries is set by the QueryTimeout
// package-level variable.
func (s *ChainService) queryPeers(
// queryMsg is the message to send to each peer selected by selectPeer.
queryMsg wire.Message,
// checkResponse is called for every message within the timeout period.
// The quit channel lets the query know to terminate because the
// required response has been found. This is done by closing the
// channel.
checkResponse func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}),
// options takes functional options for executing the query.
options ...QueryOption,
) {
qo := defaultQueryOptions()
for _, option := range options {
option(qo)
}
// This is done in a single-threaded query because the peerState is held
// in a single thread. This is the only part of the query framework that
// requires access to peerState, so it's done once per query.
peers := s.Peers()
// This will be shared state between the per-peer goroutines.
quit := make(chan struct{})
allQuit := make(chan struct{})
startQuery := make(chan struct{})
var wg sync.WaitGroup
// Increase this number to be able to handle more queries at once as
// each channel gets results for all queries, otherwise messages can
// get mixed and there's a vicious cycle of retries causing a bigger
// message flood, more of which get missed.
msgChan := make(chan spMsg)
var subwg sync.WaitGroup
subscription := spMsgSubscription{
msgChan: msgChan,
quitChan: allQuit,
wg: &subwg,
}
// Start a goroutine for each peer that potentially queries that peer.
for _, sp := range peers {
wg.Add(1)
go func(sp *serverPeer) {
numRetries := qo.numRetries
defer wg.Done()
defer sp.unsubscribeRecvMsgs(subscription)
// Should we do this when the goroutine gets a message
// via startQuery rather than at the launch of the
// goroutine?
if !sp.Connected() {
return
}
timeout := make(<-chan time.Time)
for {
select {
case <-timeout:
// After timeout, we try to notify
// another of our peer goroutines to
// do a query until we get a signal to
// quit.
select {
case startQuery <- struct{}{}:
case <-quit:
return
case <-allQuit:
return
}
// At this point, we've sent startQuery.
// We return if we've run through this
// section of code numRetries times.
if numRetries--; numRetries == 0 {
return
}
case <-quit:
// After we're told to quit, we return.
return
case <-allQuit:
// After we're told to quit, we return.
return
case <-startQuery:
// We're the lucky peer whose turn it is
// to try to answer the current query.
// TODO: Fix this to support either
// querying *all* peers simultaneously
// to avoid timeout delays, or starting
// with the syncPeer when not querying
// *all* peers.
sp.subscribeRecvMsg(subscription)
// Don't want the peer hanging on send
// to the channel if we quit before
// reading the channel.
sentChan := make(chan struct{}, 1)
sp.QueueMessage(queryMsg, sentChan)
select {
case <-sentChan:
case <-quit:
return
case <-allQuit:
return
}
timeout = time.After(qo.timeout)
default:
}
}
}(sp)
}
startQuery <- struct{}{}
// This goroutine will wait until all of the peer-query goroutines have
// terminated, and then initiate a query shutdown.
go func() {
wg.Wait()
// If we timed out on each goroutine and didn't quit or time out
// on the main goroutine, make sure our main goroutine knows to
// quit.
select {
case <-allQuit:
default:
close(allQuit)
}
// Close the done channel, if any
if qo.doneChan != nil {
close(qo.doneChan)
}
// Wait until all goroutines started by subscriptions have
// exited after we closed allQuit before letting the message
// channel get garbage collected.
subwg.Wait()
}()
// Loop for any messages sent to us via our subscription channel and
// check them for whether they satisfy the query. Break the loop if it's
// time to quit.
timeout := time.After(time.Duration(len(peers)+1) *
qo.timeout * time.Duration(qo.numRetries))
checkResponses:
for {
select {
case <-timeout:
// When we time out, close the allQuit channel
// if it hasn't already been closed.
select {
case <-allQuit:
default:
close(allQuit)
}
break checkResponses
case <-quit:
break checkResponses
case <-allQuit:
break checkResponses
case sm := <-msgChan:
// TODO: This will get stuck if checkResponse
// gets stuck. This is a caveat for callers that
// should be fixed before exposing this function
// for public use.
checkResponse(sm.sp, sm.msg, quit)
}
}
}
// GetCFilter gets a cfilter from the database. Failing that, it requests the // GetCFilter gets a cfilter from the database. Failing that, it requests the
// cfilter from the network and writes it to the database. // cfilter from the network and writes it to the database.
func (s *ChainService) GetCFilter(blockHash chainhash.Hash, func (s *ChainService) GetCFilter(blockHash chainhash.Hash,
extended bool) *gcs.Filter { extended bool, options ...QueryOption) *gcs.Filter {
getFilter := s.GetBasicFilter getFilter := s.GetBasicFilter
getHeader := s.GetBasicHeader getHeader := s.GetBasicHeader
putFilter := s.putBasicFilter putFilter := s.putBasicFilter
@ -1714,17 +1807,54 @@ func (s *ChainService) GetCFilter(blockHash chainhash.Hash,
if err != nil { if err != nil {
return nil return nil
} }
replyChan := make(chan *gcs.Filter) s.queryPeers(
s.query <- getCFilterMsg{ // Send a wire.GetCFilterMsg
cfRequest: cfRequest{ wire.NewMsgGetCFilter(&blockHash, extended),
blockHash: blockHash, // Check responses and if we get one that matches,
extended: extended, // end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "cfilter" messages.
case *wire.MsgCFilter:
if len(response.Data) < 4 {
// Filter data is too short.
// Ignore this message.
return
}
if blockHash != response.BlockHash {
// The response doesn't match our
// request. Ignore this message.
return
}
gotFilter, err :=
gcs.FromNBytes(builder.DefaultP,
response.Data)
if err != nil {
// Malformed filter data. We
// can ignore this message.
return
}
if MakeHeaderForFilter(gotFilter,
*prevHeader) !=
*curHeader {
// Filter data doesn't match
// the headers we know about.
// Ignore this response.
return
}
// At this point, the filter matches
// what we know about it and we declare
// it sane. We can kill the query and
// pass the response back to the caller.
close(quit)
filter = gotFilter
default:
}
}, },
prevHeader: prevHeader, options...,
curHeader: curHeader, )
reply: replyChan, // If we've found a filter, write it to the database for next time.
}
filter = <-replyChan
if filter != nil { if filter != nil {
putFilter(blockHash, filter) putFilter(blockHash, filter)
log.Tracef("Wrote filter for block %s, extended: %t", log.Tracef("Wrote filter for block %s, extended: %t",
@ -1736,20 +1866,72 @@ func (s *ChainService) GetCFilter(blockHash chainhash.Hash,
// GetBlockFromNetwork gets a block by requesting it from the network, one peer // GetBlockFromNetwork gets a block by requesting it from the network, one peer
// at a time, until one answers. // at a time, until one answers.
func (s *ChainService) GetBlockFromNetwork( func (s *ChainService) GetBlockFromNetwork(
blockHash chainhash.Hash) *btcutil.Block { blockHash chainhash.Hash, options ...QueryOption) *btcutil.Block {
blockHeader, height, err := s.GetBlockByHash(blockHash) blockHeader, height, err := s.GetBlockByHash(blockHash)
if err != nil || blockHeader.BlockHash() != blockHash { if err != nil || blockHeader.BlockHash() != blockHash {
return nil return nil
} }
replyChan := make(chan *btcutil.Block) getData := wire.NewMsgGetData()
s.query <- getBlockMsg{ getData.AddInvVect(wire.NewInvVect(wire.InvTypeBlock,
blockHeader: &blockHeader, &blockHash))
height: height, // The block is only updated from the checkResponse function argument,
reply: replyChan, // which is always called from a single goroutine. We don't check the block
} // until after the query is finished, so we can just write to it
block := <-replyChan // naively.
if block != nil { var foundBlock *btcutil.Block
log.Tracef("Got block %s from network", blockHash) s.queryPeers(
} // Send a wire.MsgGetData requesting the block
return block getData,
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "block" messages.
case *wire.MsgBlock:
// If this isn't our block, ignore it.
if response.BlockHash() !=
blockHash {
return
}
block := btcutil.NewBlock(response)
// Only set height if btcutil hasn't
// automagically put one in.
if block.Height() ==
btcutil.BlockHeightUnknown {
block.SetHeight(
int32(height))
}
// If this claims our block but doesn't
// pass the sanity check, the peer is
// trying to bamboozle us. Disconnect
// it.
if err := blockchain.CheckBlockSanity(
block,
// We don't need to check PoW
// because by the time we get
// here, it's been checked
// during header synchronization
s.chainParams.PowLimit,
s.timeSource,
); err != nil {
log.Warnf("Invalid block for %s "+
"received from %s -- "+
"disconnecting peer", blockHash,
sp.Addr())
sp.Disconnect()
return
}
// At this point, the block matches what we know
// about it and we declare it sane. We can kill
// the query and pass the response back to the
// caller.
close(quit)
foundBlock = block
default:
}
},
options...,
)
return foundBlock
} }
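
With the query logic moved onto ChainService, callers tune individual queries through QueryOption arguments instead of package-level globals. The fragment below is a usage sketch only, not code from this commit: it assumes an initialized *spvchain.ChainService named svc, a chainhash.Hash the service already has a header for, and the usual time, chainhash, gcs, and btcutil imports; fetchViaNetwork is a made-up helper name.

// fetchViaNetwork exercises the exported query API added above. Imports are
// elided; svc and hash are assumed to be valid.
func fetchViaNetwork(svc *spvchain.ChainService, hash chainhash.Hash) (*btcutil.Block, *gcs.Filter) {
    // Override the 3-second default timeout and the default retry count for
    // this query only.
    filter := svc.GetCFilter(hash, false,
        spvchain.Timeout(10*time.Second),
        spvchain.NumRetries(3),
    )

    // DoneChan is mainly useful when the call runs in its own goroutine; the
    // channel is closed once the underlying queryPeers call finishes.
    done := make(chan struct{})
    block := svc.GetBlockFromNetwork(hash, spvchain.DoneChan(done))
    <-done

    return block, filter
}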


@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math/rand"
"os" "os"
"reflect" "reflect"
"sync" "sync"
@ -24,14 +23,32 @@ import (
_ "github.com/btcsuite/btcwallet/walletdb/bdb" _ "github.com/btcsuite/btcwallet/walletdb/bdb"
) )
const ( var (
logLevel = btclog.TraceLvl logLevel = btclog.TraceLvl
syncTimeout = 30 * time.Second syncTimeout = 30 * time.Second
syncUpdate = time.Second syncUpdate = time.Second
numTestBlocks = 50 // Don't set this too high for your platform, or the tests will miss
// messages.
// TODO: Make this a benchmark instead.
numQueryThreads = 50
queryOptions = []spvchain.QueryOption{
//spvchain.NumRetries(5),
}
) )
func TestSetup(t *testing.T) { func TestSetup(t *testing.T) {
// Set up logging.
logger, err := btctestlog.NewTestLogger(t)
if err != nil {
t.Fatalf("Could not set up logger: %s", err)
}
chainLogger := btclog.NewSubsystemLogger(logger, "CHAIN: ")
chainLogger.SetLevel(logLevel)
spvchain.UseLogger(chainLogger)
rpcLogger := btclog.NewSubsystemLogger(logger, "RPCC: ")
rpcLogger.SetLevel(logLevel)
btcrpcclient.UseLogger(rpcLogger)
// Create a btcd SimNet node and generate 500 blocks // Create a btcd SimNet node and generate 500 blocks
h1, err := rpctest.New(&chaincfg.SimNetParams, nil, nil) h1, err := rpctest.New(&chaincfg.SimNetParams, nil, nil)
if err != nil { if err != nil {
@ -147,16 +164,6 @@ func TestSetup(t *testing.T) {
spvchain.BanDuration = 5 * time.Second spvchain.BanDuration = 5 * time.Second
spvchain.RequiredServices = wire.SFNodeNetwork spvchain.RequiredServices = wire.SFNodeNetwork
spvchain.WaitForMoreCFHeaders = time.Second spvchain.WaitForMoreCFHeaders = time.Second
logger, err := btctestlog.NewTestLogger(t)
if err != nil {
t.Fatalf("Could not set up logger: %s", err)
}
chainLogger := btclog.NewSubsystemLogger(logger, "CHAIN: ")
chainLogger.SetLevel(logLevel)
spvchain.UseLogger(chainLogger)
rpcLogger := btclog.NewSubsystemLogger(logger, "RPCC: ")
rpcLogger.SetLevel(logLevel)
btcrpcclient.UseLogger(rpcLogger)
svc, err := spvchain.NewChainService(config) svc, err := spvchain.NewChainService(config)
if err != nil { if err != nil {
t.Fatalf("Error creating ChainService: %s", err) t.Fatalf("Error creating ChainService: %s", err)
@ -170,6 +177,13 @@ func TestSetup(t *testing.T) {
t.Fatalf("Couldn't sync ChainService: %s", err) t.Fatalf("Couldn't sync ChainService: %s", err)
} }
// Test that we can get blocks and cfilters via P2P and decide which are
// valid and which aren't.
err = testRandomBlocks(t, svc, h1)
if err != nil {
t.Fatalf("Testing blocks and cfilters failed: %s", err)
}
// Generate 125 blocks on h1 to make sure it reorgs the other nodes. // Generate 125 blocks on h1 to make sure it reorgs the other nodes.
// Ensure the ChainService instance stays caught up. // Ensure the ChainService instance stays caught up.
h1.Node.Generate(125) h1.Node.Generate(125)
@ -209,13 +223,6 @@ func TestSetup(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Couldn't sync ChainService: %s", err) t.Fatalf("Couldn't sync ChainService: %s", err)
} }
// Test that we can get blocks and cfilters via P2P and decide which are
// valid and which aren't.
err = testRandomBlocks(t, svc, h1)
if err != nil {
t.Fatalf("Testing blocks and cfilters failed: %s", err)
}
} }
// csd does a connect-sync-disconnect between nodes in order to support // csd does a connect-sync-disconnect between nodes in order to support
@ -377,7 +384,7 @@ func waitForSync(t *testing.T, svc *spvchain.ChainService,
// can correctly get filters from them. We don't go through *all* the blocks // can correctly get filters from them. We don't go through *all* the blocks
// because it can be a little slow, but we'll improve that soon-ish hopefully // because it can be a little slow, but we'll improve that soon-ish hopefully
// to the point where we can do it. // to the point where we can do it.
// TODO: Improve concurrency on framework side. // TODO: Make this a benchmark instead.
func testRandomBlocks(t *testing.T, svc *spvchain.ChainService, func testRandomBlocks(t *testing.T, svc *spvchain.ChainService,
correctSyncNode *rpctest.Harness) error { correctSyncNode *rpctest.Harness) error {
var haveBest *waddrmgr.BlockStamp var haveBest *waddrmgr.BlockStamp
@ -386,180 +393,200 @@ func testRandomBlocks(t *testing.T, svc *spvchain.ChainService,
return fmt.Errorf("Couldn't get best snapshot from "+ return fmt.Errorf("Couldn't get best snapshot from "+
"ChainService: %s", err) "ChainService: %s", err)
} }
// Keep track of an error channel // Keep track of an error channel with enough buffer space to track one
errChan := make(chan error) // error per block.
var lastErr error errChan := make(chan error, haveBest.Height)
go func() { // Test getting all of the blocks and filters.
for err := range errChan {
if err != nil {
t.Errorf("%s", err)
lastErr = fmt.Errorf("Couldn't validate all " +
"blocks, filters, and filter headers.")
}
}
}()
// Test getting numTestBlocks random blocks and filters.
var wg sync.WaitGroup var wg sync.WaitGroup
heights := rand.Perm(int(haveBest.Height)) workerQueue := make(chan struct{}, numQueryThreads)
for i := 0; i < numTestBlocks; i++ { for i := int32(1); i <= haveBest.Height; i++ {
wg.Add(1) wg.Add(1)
height := uint32(heights[i]) height := uint32(i)
// Wait until there's room in the worker queue.
workerQueue <- struct{}{}
go func() { go func() {
// On exit, open a spot in workerQueue and tell the
// wait group we're done.
defer func() {
<-workerQueue
}()
defer wg.Done() defer wg.Done()
// Get block header from database. // Get block header from database.
blockHeader, blockHeight, err := svc.GetBlockByHeight(height) blockHeader, blockHeight, err := svc.GetBlockByHeight(
height)
if err != nil { if err != nil {
errChan <- fmt.Errorf("Couldn't get block "+ errChan <- fmt.Errorf("Couldn't get block "+
"header by height %d: %s", height, err) "header by height %d: %s", height, err)
return return
} }
if blockHeight != height { if blockHeight != height {
errChan <- fmt.Errorf("Block height retrieved from DB "+ errChan <- fmt.Errorf("Block height retrieved "+
"doesn't match expected height. Want: %d, "+ "from DB doesn't match expected "+
"have: %d", height, blockHeight) "height. Want: %d, have: %d", height,
blockHeight)
return return
} }
blockHash := blockHeader.BlockHash() blockHash := blockHeader.BlockHash()
// Get block via RPC. // Get block via RPC.
wantBlock, err := correctSyncNode.Node.GetBlock(&blockHash) wantBlock, err := correctSyncNode.Node.GetBlock(
&blockHash)
if err != nil { if err != nil {
errChan <- fmt.Errorf("Couldn't get block %d (%s) by RPC", errChan <- fmt.Errorf("Couldn't get block %d "+
height, blockHash) "(%s) by RPC", height, blockHash)
return return
} }
// Get block from network. // Get block from network.
haveBlock := svc.GetBlockFromNetwork(blockHash) haveBlock := svc.GetBlockFromNetwork(blockHash,
queryOptions...)
if haveBlock == nil { if haveBlock == nil {
errChan <- fmt.Errorf("Couldn't get block %d (%s) from"+ errChan <- fmt.Errorf("Couldn't get block %d "+
"network", height, blockHash) "(%s) from network", height, blockHash)
return return
} }
// Check that network and RPC blocks match. // Check that network and RPC blocks match.
if !reflect.DeepEqual(*haveBlock.MsgBlock(), *wantBlock) { if !reflect.DeepEqual(*haveBlock.MsgBlock(),
errChan <- fmt.Errorf("Block from network doesn't match "+ *wantBlock) {
"block from RPC. Want: %s, RPC: %s, network: "+ errChan <- fmt.Errorf("Block from network "+
"%s", blockHash, wantBlock.BlockHash(), "doesn't match block from RPC. Want: "+
"%s, RPC: %s, network: %s", blockHash,
wantBlock.BlockHash(),
haveBlock.MsgBlock().BlockHash()) haveBlock.MsgBlock().BlockHash())
return return
} }
// Check that block height matches what we have. // Check that block height matches what we have.
if int32(blockHeight) != haveBlock.Height() { if int32(blockHeight) != haveBlock.Height() {
errChan <- fmt.Errorf("Block height from network doesn't "+ errChan <- fmt.Errorf("Block height from "+
"match expected height. Want: %s, network: %s", "network doesn't match expected "+
"height. Want: %s, network: %s",
blockHeight, haveBlock.Height()) blockHeight, haveBlock.Height())
return return
} }
// Get basic cfilter from network. // Get basic cfilter from network.
haveFilter := svc.GetCFilter(blockHash, false) haveFilter := svc.GetCFilter(blockHash, false,
queryOptions...)
if haveFilter == nil { if haveFilter == nil {
errChan <- fmt.Errorf("Couldn't get basic "+ errChan <- fmt.Errorf("Couldn't get basic "+
"filter for block %d", height) "filter for block %d", height)
return return
} }
// Get basic cfilter from RPC. // Get basic cfilter from RPC.
wantFilter, err := correctSyncNode.Node.GetCFilter(&blockHash, wantFilter, err := correctSyncNode.Node.GetCFilter(
false) &blockHash, false)
if err != nil { if err != nil {
errChan <- fmt.Errorf("Couldn't get basic filter for "+ errChan <- fmt.Errorf("Couldn't get basic "+
"block %d via RPC: %s", height, err) "filter for block %d via RPC: %s",
height, err)
return return
} }
// Check that network and RPC cfilters match. // Check that network and RPC cfilters match.
if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) { if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) {
errChan <- fmt.Errorf("Basic filter from P2P network/DB"+ errChan <- fmt.Errorf("Basic filter from P2P "+
" doesn't match RPC value for block %d", height) "network/DB doesn't match RPC value "+
"for block %d", height)
return return
} }
// Calculate basic filter from block. // Calculate basic filter from block.
calcFilter, err := spvchain.BuildBasicFilter( calcFilter, err := spvchain.BuildBasicFilter(
haveBlock.MsgBlock()) haveBlock.MsgBlock())
if err != nil { if err != nil {
errChan <- fmt.Errorf("Couldn't build basic filter for "+ errChan <- fmt.Errorf("Couldn't build basic "+
"block %d (%s): %s", height, blockHash, err) "filter for block %d (%s): %s", height,
blockHash, err)
return return
} }
// Check that the network value matches the calculated value // Check that the network value matches the calculated
// from the block. // value from the block.
if !reflect.DeepEqual(*haveFilter, *calcFilter) { if !reflect.DeepEqual(*haveFilter, *calcFilter) {
errChan <- fmt.Errorf("Basic filter from P2P network/DB "+ errChan <- fmt.Errorf("Basic filter from P2P "+
"doesn't match calculated value for block %d", "network/DB doesn't match calculated "+
height) "value for block %d", height)
return return
} }
// Get previous basic filter header from the database. // Get previous basic filter header from the database.
prevHeader, err := svc.GetBasicHeader(blockHeader.PrevBlock) prevHeader, err := svc.GetBasicHeader(
blockHeader.PrevBlock)
if err != nil { if err != nil {
errChan <- fmt.Errorf("Couldn't get basic filter header "+ errChan <- fmt.Errorf("Couldn't get basic "+
"for block %d (%s) from DB: %s", height-1, "filter header for block %d (%s) from "+
"DB: %s", height-1,
blockHeader.PrevBlock, err) blockHeader.PrevBlock, err)
return return
} }
// Get current basic filter header from the database. // Get current basic filter header from the database.
curHeader, err := svc.GetBasicHeader(blockHash) curHeader, err := svc.GetBasicHeader(blockHash)
if err != nil { if err != nil {
errChan <- fmt.Errorf("Couldn't get basic filter header "+ errChan <- fmt.Errorf("Couldn't get basic "+
"for block %d (%s) from DB: %s", height-1, "filter header for block %d (%s) from "+
blockHash, err) "DB: %s", height-1, blockHash, err)
return return
} }
// Check that the filter and header line up. // Check that the filter and header line up.
calcHeader := spvchain.MakeHeaderForFilter(calcFilter, calcHeader := spvchain.MakeHeaderForFilter(calcFilter,
*prevHeader) *prevHeader)
if !bytes.Equal(curHeader[:], calcHeader[:]) { if !bytes.Equal(curHeader[:], calcHeader[:]) {
errChan <- fmt.Errorf("Filter header doesn't match. Want: "+ errChan <- fmt.Errorf("Filter header doesn't "+
"%s, got: %s", curHeader, calcHeader) "match. Want: %s, got: %s", curHeader,
calcHeader)
return return
} }
// Get extended cfilter from network // Get extended cfilter from network
haveFilter = svc.GetCFilter(blockHash, true) haveFilter = svc.GetCFilter(blockHash, true,
queryOptions...)
if haveFilter == nil { if haveFilter == nil {
errChan <- fmt.Errorf("Couldn't get extended "+ errChan <- fmt.Errorf("Couldn't get extended "+
"filter for block %d", height) "filter for block %d", height)
return return
} }
// Get extended cfilter from RPC // Get extended cfilter from RPC
wantFilter, err = correctSyncNode.Node.GetCFilter(&blockHash, wantFilter, err = correctSyncNode.Node.GetCFilter(
true) &blockHash, true)
if err != nil { if err != nil {
errChan <- fmt.Errorf("Couldn't get extended filter for "+ errChan <- fmt.Errorf("Couldn't get extended "+
"block %d via RPC: %s", height, err) "filter for block %d via RPC: %s",
height, err)
return return
} }
// Check that network and RPC cfilters match // Check that network and RPC cfilters match
if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) { if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) {
errChan <- fmt.Errorf("Extended filter from P2P network/DB"+ errChan <- fmt.Errorf("Extended filter from "+
" doesn't match RPC value for block %d", height) "P2P network/DB doesn't match RPC "+
"value for block %d", height)
return return
} }
// Calculate extended filter from block // Calculate extended filter from block
calcFilter, err = spvchain.BuildExtFilter( calcFilter, err = spvchain.BuildExtFilter(
haveBlock.MsgBlock()) haveBlock.MsgBlock())
if err != nil { if err != nil {
errChan <- fmt.Errorf("Couldn't build extended filter for "+ errChan <- fmt.Errorf("Couldn't build extended"+
"block %d (%s): %s", height, blockHash, err) " filter for block %d (%s): %s", height,
blockHash, err)
return return
} }
// Check that the network value matches the calculated value // Check that the network value matches the calculated
// from the block. // value from the block.
if !reflect.DeepEqual(*haveFilter, *calcFilter) { if !reflect.DeepEqual(*haveFilter, *calcFilter) {
errChan <- fmt.Errorf("Extended filter from P2P network/DB"+ errChan <- fmt.Errorf("Extended filter from "+
" doesn't match calculated value for block %d", "P2P network/DB doesn't match "+
height) "calculated value for block %d", height)
return return
} }
// Get previous extended filter header from the database. // Get previous extended filter header from the
prevHeader, err = svc.GetExtHeader(blockHeader.PrevBlock) // database.
prevHeader, err = svc.GetExtHeader(
blockHeader.PrevBlock)
if err != nil { if err != nil {
errChan <- fmt.Errorf("Couldn't get extended filter header"+ errChan <- fmt.Errorf("Couldn't get extended "+
" for block %d (%s) from DB: %s", height-1, "filter header for block %d (%s) from "+
"DB: %s", height-1,
blockHeader.PrevBlock, err) blockHeader.PrevBlock, err)
return return
} }
// Get current basic filter header from the database. // Get current basic filter header from the database.
curHeader, err = svc.GetExtHeader(blockHash) curHeader, err = svc.GetExtHeader(blockHash)
if err != nil { if err != nil {
errChan <- fmt.Errorf("Couldn't get extended filter header"+ errChan <- fmt.Errorf("Couldn't get extended "+
" for block %d (%s) from DB: %s", height-1, "filter header for block %d (%s) from "+
"DB: %s", height-1,
blockHash, err) blockHash, err)
return return
} }
@ -567,20 +594,29 @@ func testRandomBlocks(t *testing.T, svc *spvchain.ChainService,
calcHeader = spvchain.MakeHeaderForFilter(calcFilter, calcHeader = spvchain.MakeHeaderForFilter(calcFilter,
*prevHeader) *prevHeader)
if !bytes.Equal(curHeader[:], calcHeader[:]) { if !bytes.Equal(curHeader[:], calcHeader[:]) {
errChan <- fmt.Errorf("Filter header doesn't match. Want: "+ errChan <- fmt.Errorf("Filter header doesn't "+
"%s, got: %s", curHeader, calcHeader) "match. Want: %s, got: %s", curHeader,
calcHeader)
return return
} }
}() }()
} }
// Wait for all queries to finish. // Wait for all queries to finish.
wg.Wait() wg.Wait()
if logLevel != btclog.Off {
t.Logf("Finished checking %d blocks and their cfilters",
numTestBlocks)
}
// Close the error channel to make the error monitoring goroutine // Close the error channel to make the error monitoring goroutine
// finish. // finish.
close(errChan) close(errChan)
var lastErr error
for err := range errChan {
if err != nil {
t.Errorf("%s", err)
lastErr = fmt.Errorf("Couldn't validate all " +
"blocks, filters, and filter headers.")
}
}
if logLevel != btclog.Off {
t.Logf("Finished checking %d blocks and their cfilters",
haveBest.Height)
}
return lastErr return lastErr
} }
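
The reworked test walks every block height instead of a random sample and bounds concurrency with a buffered channel used as a semaphore (numQueryThreads slots). Below is a standalone sketch of that worker-queue pattern with the per-height work stubbed out; the constants and the stub body are placeholders, not the test's actual checks.

package main

import (
    "fmt"
    "sync"
)

func main() {
    const maxWorkers = 50 // mirrors numQueryThreads in the test above
    workerQueue := make(chan struct{}, maxWorkers)
    var wg sync.WaitGroup

    for height := 1; height <= 500; height++ {
        wg.Add(1)
        // Block here once maxWorkers goroutines are already in flight.
        workerQueue <- struct{}{}
        go func(height int) {
            defer wg.Done()
            defer func() { <-workerQueue }() // free the slot on exit
            // The real test fetches the block and both cfilters for this
            // height and compares them against the RPC node.
            _ = height
        }(height)
    }
    wg.Wait()
    fmt.Println("all heights processed")
}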