From fe632ff233096e2afac29b32117874f9d92d5342 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 26 Apr 2017 16:34:05 -0600 Subject: [PATCH] Fix concurrency issues. --- spvsvc/spvchain/blockmanager.go | 13 +- spvsvc/spvchain/notifications.go | 139 +------- spvsvc/spvchain/spvchain.go | 574 ++++++++++++++++++++----------- spvsvc/spvchain/sync_test.go | 244 +++++++------ 4 files changed, 531 insertions(+), 439 deletions(-) diff --git a/spvsvc/spvchain/blockmanager.go b/spvsvc/spvchain/blockmanager.go index 1533cf7..b0818cb 100644 --- a/spvsvc/spvchain/blockmanager.go +++ b/spvsvc/spvchain/blockmanager.go @@ -1031,9 +1031,11 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) { // Should probably use better isolation for this but we're in // the same package. One of the things to clean up when we do // more general cleanup. + sp.mtxReqCFH.Lock() sp.requestedCFHeaders[cfhReqB] = cfhCount - sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, false) sp.requestedCFHeaders[cfhReqE] = cfhCount + sp.mtxReqCFH.Unlock() + sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, false) sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, true) }) @@ -1078,13 +1080,20 @@ func (b *blockManager) QueueCFHeaders(cfheaders *wire.MsgCFHeaders, extended: cfheaders.Extended, stopHash: cfheaders.StopHash, } - if sp.requestedCFHeaders[req] != len(cfheaders.HeaderHashes) { + // TODO: Get rid of this by refactoring all of this using the query API + sp.mtxReqCFH.Lock() + expLen := sp.requestedCFHeaders[req] + sp.mtxReqCFH.Unlock() + if expLen != len(cfheaders.HeaderHashes) { log.Warnf("Received cfheaders message doesn't match any "+ "getcfheaders request. Peer %s is probably on a "+ "different chain -- ignoring", sp.Addr()) return } + // TODO: Remove this by refactoring this section into a query client. + sp.mtxReqCFH.Lock() delete(sp.requestedCFHeaders, req) + sp.mtxReqCFH.Unlock() // Track number of pending cfheaders messsages for both basic and // extended filters. diff --git a/spvsvc/spvchain/notifications.go b/spvsvc/spvchain/notifications.go index cf6db03..c38c2a1 100644 --- a/spvsvc/spvchain/notifications.go +++ b/spvsvc/spvchain/notifications.go @@ -8,13 +8,7 @@ import ( "errors" "github.com/btcsuite/btcd/addrmgr" - "github.com/btcsuite/btcd/blockchain" - "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/connmgr" - "github.com/btcsuite/btcd/wire" - "github.com/btcsuite/btcutil" - "github.com/btcsuite/btcutil/gcs" - "github.com/btcsuite/btcutil/gcs/builder" ) type getConnCountMsg struct { @@ -54,19 +48,6 @@ type forAllPeersMsg struct { closure func(*serverPeer) } -type getCFilterMsg struct { - cfRequest - prevHeader *chainhash.Hash - curHeader *chainhash.Hash - reply chan *gcs.Filter -} - -type getBlockMsg struct { - blockHeader *wire.BlockHeader - height uint32 - reply chan *btcutil.Block -} - // TODO: General - abstract out more of blockmanager into queries. It'll make // this way more maintainable and usable. @@ -172,128 +153,12 @@ func (s *ChainService) handleQuery(state *peerState, querymsg interface{}) { msg.reply <- errors.New("peer not found") case forAllPeersMsg: + // TODO: Remove this when it's unnecessary due to wider use of + // queryPeers. // Run the closure on all peers in the passed state. state.forAllPeers(msg.closure) // Even though this is a query, there's no reply channel as the // forAllPeers method doesn't return anything. An error might be // useful in the future. - case getCFilterMsg: - found := false - state.queryPeers( - // Should we query this peer? 
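The blockmanager changes above close a race on the shared requestedCFHeaders map by taking the new mtxReqCFH mutex around every read, write, and delete. A minimal, self-contained sketch of that lock-around-map pattern, using placeholder names rather than the actual serverPeer fields from this patch:

package main

import (
	"fmt"
	"sync"
)

// cfhTracker is a simplified stand-in for the cfheaders bookkeeping the patch
// protects on serverPeer: every access to the map holds the mutex.
type cfhTracker struct {
	mtx       sync.Mutex
	requested map[string]int // keyed by a request identifier
}

// record notes how many header hashes we expect for a request.
func (t *cfhTracker) record(req string, count int) {
	t.mtx.Lock()
	defer t.mtx.Unlock()
	t.requested[req] = count
}

// expected returns the recorded count, holding the lock only for the read.
func (t *cfhTracker) expected(req string) int {
	t.mtx.Lock()
	defer t.mtx.Unlock()
	return t.requested[req]
}

// clear removes the entry once the matching response has been handled.
func (t *cfhTracker) clear(req string) {
	t.mtx.Lock()
	defer t.mtx.Unlock()
	delete(t.requested, req)
}

func main() {
	t := &cfhTracker{requested: make(map[string]int)}
	t.record("basic/stophash", 2000)
	fmt.Println(t.expected("basic/stophash")) // 2000
	t.clear("basic/stophash")
}
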
- func(sp *serverPeer) bool { - // Don't send requests to disconnected peers. - return sp.Connected() - }, - // Send a wire.GetCFilterMsg - wire.NewMsgGetCFilter(&msg.blockHash, msg.extended), - // Check responses and if we get one that matches, - // end the query early. - func(sp *serverPeer, resp wire.Message, - quit chan<- struct{}) { - switch response := resp.(type) { - // We're only interested in "cfilter" messages. - case *wire.MsgCFilter: - if len(response.Data) < 4 { - // Filter data is too short. - // Ignore this message. - return - } - filter, err := - gcs.FromNBytes(builder.DefaultP, - response.Data) - if err != nil { - // Malformed filter data. We - // can ignore this message. - return - } - if MakeHeaderForFilter(filter, - *msg.prevHeader) != - *msg.curHeader { - // Filter data doesn't match - // the headers we know about. - // Ignore this response. - return - } - // At this point, the filter matches - // what we know about it and we declare - // it sane. We can kill the query and - // pass the response back to the caller. - found = true - close(quit) - msg.reply <- filter - default: - } - }, - ) - // We timed out without finding a correct answer to our query. - if !found { - msg.reply <- nil - } - case getBlockMsg: - found := false - getData := wire.NewMsgGetData() - blockHash := msg.blockHeader.BlockHash() - getData.AddInvVect(wire.NewInvVect(wire.InvTypeBlock, - &blockHash)) - state.queryPeers( - // Should we query this peer? - func(sp *serverPeer) bool { - // Don't send requests to disconnected peers. - return sp.Connected() - }, - // Send a wire.GetCFilterMsg - getData, - // Check responses and if we get one that matches, - // end the query early. - func(sp *serverPeer, resp wire.Message, - quit chan<- struct{}) { - switch response := resp.(type) { - // We're only interested in "block" messages. - case *wire.MsgBlock: - // If this isn't our block, ignore it. - if response.BlockHash() != - blockHash { - return - } - block := btcutil.NewBlock(response) - // Only set height if btcutil hasn't - // automagically put one in. - if block.Height() == - btcutil.BlockHeightUnknown { - block.SetHeight( - int32(msg.height)) - } - // If this claims our block but doesn't - // pass the sanity check, the peer is - // trying to bamboozle us. Disconnect - // it. - if err := blockchain.CheckBlockSanity( - block, - // We don't need to check PoW - // because by the time we get - // here, it's been checked - // during header synchronization - s.chainParams.PowLimit, - s.timeSource, - ); err != nil { - log.Warnf("Invalid block for "+ - "%s received from %s "+ - "-- disconnecting peer", - blockHash, sp.Addr()) - sp.Disconnect() - return - } - found = true - close(quit) - msg.reply <- block - default: - } - }, - ) - // We timed out without finding a correct answer to our query. - if !found { - msg.reply <- nil - } } } diff --git a/spvsvc/spvchain/spvchain.go b/spvsvc/spvchain/spvchain.go index c663ff4..95337e3 100644 --- a/spvsvc/spvchain/spvchain.go +++ b/spvsvc/spvchain/spvchain.go @@ -18,6 +18,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil/gcs" + "github.com/btcsuite/btcutil/gcs/builder" "github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/walletdb" @@ -63,8 +64,14 @@ var ( // from DNS. DisableDNSSeed = false - // Timeout specifies how long to wait for a peer to answer a query. - Timeout = time.Second * 5 + // QueryTimeout specifies how long to wait for a peer to answer a query. 
+ QueryTimeout = time.Second * 3 + + // QueryNumRetries specifies how many times to retry sending a query to + // each peer before we've concluded we aren't going to get a valid + // response. This allows to make up for missed messages in some + // instances. + QueryNumRetries = 2 ) // updatePeerHeightsMsg is a message sent from the blockmanager to the server @@ -110,137 +117,6 @@ func (ps *peerState) forAllPeers(closure func(sp *serverPeer)) { ps.forAllOutboundPeers(closure) } -// Query options can be modified per-query, unlike global options. -// TODO: Make more query options that override global options. -type queryOptions struct { - // queryTimeout lets the query know how long to wait for a peer to - // answer the query before moving onto the next peer. - queryTimeout time.Duration -} - -// defaultQueryOptions returns a queryOptions set to package-level defaults. -func defaultQueryOptions() *queryOptions { - return &queryOptions{ - queryTimeout: Timeout, - } -} - -// QueryTimeout is a query option that lets the query know to ask each peer we're -// connected to for its opinion, if any. By default, we only ask peers until one -// gives us a valid response. -func QueryTimeout(timeout time.Duration) func(*queryOptions) { - return func(qo *queryOptions) { - qo.queryTimeout = timeout - } -} - -type spMsg struct { - sp *serverPeer - msg wire.Message -} - -// queryPeers is a helper function that sends a query to one or more peers and -// waits for an answer. The timeout for queries is set by the QueryTimeout -// package-level variable. -func (ps *peerState) queryPeers( - // selectPeer is a closure which decides whether or not to send the - // query to the peer. - selectPeer func(sp *serverPeer) bool, - // queryMsg is the message to send to each peer selected by selectPeer. - queryMsg wire.Message, - // checkResponse is caled for every message within the timeout period. - // The quit channel lets the query know to terminate because the - // required response has been found. This is done by closing the - // channel. - checkResponse func(sp *serverPeer, resp wire.Message, - quit chan<- struct{}), - // options takes functional options for executing the query. - options ...func(*queryOptions), -) { - qo := defaultQueryOptions() - for _, option := range options { - option(qo) - } - // This will be shared state between the per-peer goroutines. - quit := make(chan struct{}) - startQuery := make(chan struct{}) - var wg sync.WaitGroup - channel := make(chan spMsg) - - // This goroutine will monitor all messages from all peers until the - // peer goroutines all exit. - go func() { - for { - select { - case <-quit: - close(channel) - ps.forAllPeers( - func(sp *serverPeer) { - sp.unsubscribeRecvMsgs(channel) - }) - return - case sm := <-channel: - // TODO: This will get stuck if checkResponse - // gets stuck. - checkResponse(sm.sp, sm.msg, quit) - } - } - }() - - // Start a goroutine for each peer that potentially queries each peer - ps.forAllPeers(func(sp *serverPeer) { - wg.Add(1) - go func() { - defer wg.Done() - if !selectPeer(sp) { - return - } - timeout := make(<-chan time.Time) - for { - select { - case <-timeout: - // After timeout, we return and notify - // another goroutine that we've done so. - // We only send if there's someone left - // to receive. - startQuery <- struct{}{} - return - case <-quit: - // After we're told to quit, we return. - return - case <-startQuery: - // We're the lucky peer whose turn it is - // to try to answer the current query. 
- // TODO: Fix this to support multiple - // queries at once. For now, we're - // relying on the query handling loop - // to make sure we don't interrupt - // another query. We need broadcast - // support in OnRead to do this right. - // TODO: Fix this to support either - // querying *all* peers simultaneously - // to avoid timeout delays, or starting - // with the syncPeer when not querying - // *all* peers. - sp.subscribeRecvMsg(channel) - sp.QueueMessage(queryMsg, nil) - timeout = time.After(qo.queryTimeout) - default: - } - } - }() - }) - startQuery <- struct{}{} - wg.Wait() - // If we timed out and didn't quit, make sure our response monitor - // goroutine knows to quit. - select { - case <-quit: - default: - close(quit) - } -} - // cfhRequest records which cfheaders we've requested, and the order in which // we've requested them. Since there's no way to associate the cfheaders to the // actual block hashes based on the cfheaders message to keep it compact, we @@ -250,12 +126,6 @@ type cfhRequest struct { stopHash chainhash.Hash } -// cfRequest records which cfilters we've requested. -type cfRequest struct { - extended bool - blockHash chainhash.Hash -} - // serverPeer extends the peer to maintain state shared by the server and // the blockmanager. type serverPeer struct { @@ -264,24 +134,27 @@ type serverPeer struct { *peer.Peer - connReq *connmgr.ConnReq - server *ChainService - persistent bool - continueHash *chainhash.Hash - relayMtx sync.Mutex - requestQueue []*wire.InvVect - requestedCFHeaders map[cfhRequest]int - knownAddresses map[string]struct{} - banScore connmgr.DynamicBanScore - quit chan struct{} + connReq *connmgr.ConnReq + server *ChainService + persistent bool + continueHash *chainhash.Hash + relayMtx sync.Mutex + requestQueue []*wire.InvVect + knownAddresses map[string]struct{} + banScore connmgr.DynamicBanScore + quit chan struct{} // The following slice of channels is used to subscribe to messages from // the peer. This allows broadcast to multiple subscribers at once, // allowing for multiple queries to be going to multiple peers at any // one time. The mutex is for subscribe/unsubscribe functionality. // The sends on these channels WILL NOT block; any messages the channel // can't accept will be dropped silently. - recvSubscribers []chan<- spMsg + recvSubscribers map[spMsgSubscription]struct{} mtxSubscribers sync.RWMutex + // These are only necessary until the cfheaders logic is refactored as + // a query client. + requestedCFHeaders map[cfhRequest]int + mtxReqCFH sync.Mutex } // newServerPeer returns a new serverPeer instance. The peer needs to be set by @@ -293,6 +166,7 @@ func newServerPeer(s *ChainService, isPersistent bool) *serverPeer { requestedCFHeaders: make(map[cfhRequest]int), knownAddresses: make(map[string]struct{}), quit: make(chan struct{}), + recvSubscribers: make(map[spMsgSubscription]struct{}), } } @@ -615,42 +489,38 @@ func (sp *serverPeer) OnRead(_ *peer.Peer, bytesRead int, msg wire.Message, err error) { sp.server.AddBytesReceived(uint64(bytesRead)) // Try to send a message to the subscriber channel if it isn't nil, but - // don't block on failure. + // don't block on failure. Do this inside a goroutine to prevent the + // server from slowing down too fast. 
sp.mtxSubscribers.RLock() defer sp.mtxSubscribers.RUnlock() - for _, channel := range sp.recvSubscribers { - if channel != nil { + for subscription := range sp.recvSubscribers { + subscription.wg.Add(1) + go func(subscription spMsgSubscription) { + defer subscription.wg.Done() select { - case channel <- spMsg{ - sp: sp, + case <-subscription.quitChan: + case subscription.msgChan <- spMsg{ msg: msg, + sp: sp, }: - default: } - } + }(subscription) } } // subscribeRecvMsg handles adding OnRead subscriptions to the server peer. -func (sp *serverPeer) subscribeRecvMsg(channel chan<- spMsg) { +func (sp *serverPeer) subscribeRecvMsg(subscription spMsgSubscription) { sp.mtxSubscribers.Lock() defer sp.mtxSubscribers.Unlock() - sp.recvSubscribers = append(sp.recvSubscribers, channel) + sp.recvSubscribers[subscription] = struct{}{} } // unsubscribeRecvMsgs handles removing OnRead subscriptions from the server // peer. -func (sp *serverPeer) unsubscribeRecvMsgs(channel chan<- spMsg) { +func (sp *serverPeer) unsubscribeRecvMsgs(subscription spMsgSubscription) { sp.mtxSubscribers.Lock() defer sp.mtxSubscribers.Unlock() - var updatedSubscribers []chan<- spMsg - for _, candidate := range sp.recvSubscribers { - if candidate != channel { - updatedSubscribers = append(updatedSubscribers, - candidate) - } - } - sp.recvSubscribers = updatedSubscribers + delete(sp.recvSubscribers, subscription) } // OnWrite is invoked when a peer sends a message and it is used to update @@ -683,9 +553,6 @@ type ChainService struct { timeSource blockchain.MedianTimeSource services wire.ServiceFlag - cfilterRequests map[cfRequest][]chan *gcs.Filter - cfRequestHeaders map[cfRequest][2]*chainhash.Hash - userAgentName string userAgentVersion string } @@ -894,8 +761,6 @@ func NewChainService(cfg Config) (*ChainService, error) { services: Services, userAgentName: UserAgentName, userAgentVersion: UserAgentVersion, - cfilterRequests: make(map[cfRequest][]chan *gcs.Filter), - cfRequestHeaders: make(map[cfRequest][2]*chainhash.Hash), } err := createSPVNS(s.namespace, &s.chainParams) @@ -1686,10 +1551,238 @@ func (s *ChainService) IsCurrent() bool { return s.blockManager.IsCurrent() } +// Query options can be modified per-query, unlike global options. +// TODO: Make more query options that override global options. +type queryOptions struct { + // timeout lets the query know how long to wait for a peer to + // answer the query before moving onto the next peer. + timeout time.Duration + + // numRetries tells the query how many times to retry asking each peer + // the query. + numRetries uint8 + + // doneChan lets the query signal the caller when it's done, in case + // it's run in a goroutine. + doneChan chan<- struct{} +} + +// QueryOption is a functional option argument to any of the network query +// methods, such as GetBlockFromNetwork and GetCFilter (when that resorts to a +// network query). +type QueryOption func(*queryOptions) + +// defaultQueryOptions returns a queryOptions set to package-level defaults. +func defaultQueryOptions() *queryOptions { + return &queryOptions{ + timeout: QueryTimeout, + numRetries: uint8(QueryNumRetries), + } +} + +// Timeout is a query option that lets the query know how long to wait for +// each peer we ask the query to answer it before moving on. +func Timeout(timeout time.Duration) QueryOption { + return func(qo *queryOptions) { + qo.timeout = timeout + } +} + +// NumRetries is a query option that lets the query know the maximum number of +// times each peer should be queried. The default is one. 
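OnRead now fans each received message out to every registered spMsgSubscription in its own goroutine, gated on the subscription's quit channel and tracked by its WaitGroup, so a slow or departed subscriber can no longer stall the peer's read loop. Below is a stripped-down sketch of that publish/subscribe shape; the message type and all names are placeholders, and only the field layout mirrors spMsgSubscription and recvSubscribers from the patch:

package main

import (
	"fmt"
	"sync"
)

// subscription mirrors the shape of spMsgSubscription: a receive channel, a
// quit channel owned by the subscriber, and a WaitGroup tracking in-flight
// sends.
type subscription struct {
	msgChan  chan<- string
	quitChan <-chan struct{}
	wg       *sync.WaitGroup
}

// publisher mirrors serverPeer's recvSubscribers map and its RWMutex.
type publisher struct {
	mtx  sync.RWMutex
	subs map[subscription]struct{}
}

func (p *publisher) subscribe(s subscription) {
	p.mtx.Lock()
	defer p.mtx.Unlock()
	p.subs[s] = struct{}{}
}

func (p *publisher) unsubscribe(s subscription) {
	p.mtx.Lock()
	defer p.mtx.Unlock()
	delete(p.subs, s)
}

// broadcast hands each message to every subscriber in its own goroutine; a
// send is abandoned as soon as the subscriber's quit channel closes, so the
// caller (the read loop) never blocks on a slow subscriber.
func (p *publisher) broadcast(msg string) {
	p.mtx.RLock()
	defer p.mtx.RUnlock()
	for s := range p.subs {
		s.wg.Add(1)
		go func(s subscription) {
			defer s.wg.Done()
			select {
			case <-s.quitChan:
			case s.msgChan <- msg:
			}
		}(s)
	}
}

func main() {
	var wg sync.WaitGroup
	msgs := make(chan string)
	quit := make(chan struct{})
	p := &publisher{subs: make(map[subscription]struct{})}
	p.subscribe(subscription{msgChan: msgs, quitChan: quit, wg: &wg})
	p.broadcast("ping")
	fmt.Println(<-msgs) // ping
	close(quit)
	wg.Wait()
}
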
+func NumRetries(numRetries uint8) QueryOption { + return func(qo *queryOptions) { + qo.numRetries = numRetries + } +} + +// DoneChan allows the caller to pass a channel that will get closed when the +// query is finished. +func DoneChan(doneChan chan<- struct{}) QueryOption { + return func(qo *queryOptions) { + qo.doneChan = doneChan + } +} + +type spMsg struct { + sp *serverPeer + msg wire.Message +} + +type spMsgSubscription struct { + msgChan chan<- spMsg + quitChan <-chan struct{} + wg *sync.WaitGroup +} + +// queryPeers is a helper function that sends a query to one or more peers and +// waits for an answer. The timeout for queries is set by the QueryTimeout +// package-level variable. +func (s *ChainService) queryPeers( + // queryMsg is the message to send to each peer selected by selectPeer. + queryMsg wire.Message, + // checkResponse is caled for every message within the timeout period. + // The quit channel lets the query know to terminate because the + // required response has been found. This is done by closing the + // channel. + checkResponse func(sp *serverPeer, resp wire.Message, + quit chan<- struct{}), + // options takes functional options for executing the query. + options ...QueryOption, +) { + qo := defaultQueryOptions() + for _, option := range options { + option(qo) + } + + // This is done in a single-threaded query because the peerState is held + // in a single thread. This is the only part of the query framework that + // requires access to peerState, so it's done once per query. + peers := s.Peers() + + // This will be shared state between the per-peer goroutines. + quit := make(chan struct{}) + allQuit := make(chan struct{}) + startQuery := make(chan struct{}) + var wg sync.WaitGroup + // Increase this number to be able to handle more queries at once as + // each channel gets results for all queries, otherwise messages can + // get mixed and there's a vicious cycle of retries causing a bigger + // message flood, more of which get missed. + msgChan := make(chan spMsg) + var subwg sync.WaitGroup + subscription := spMsgSubscription{ + msgChan: msgChan, + quitChan: allQuit, + wg: &subwg, + } + + // Start a goroutine for each peer that potentially queries that peer. + for _, sp := range peers { + wg.Add(1) + go func(sp *serverPeer) { + numRetries := qo.numRetries + defer wg.Done() + defer sp.unsubscribeRecvMsgs(subscription) + // Should we do this when the goroutine gets a message + // via startQuery rather than at the launch of the + // goroutine? + if !sp.Connected() { + return + } + timeout := make(<-chan time.Time) + for { + select { + case <-timeout: + // After timeout, we try to notify + // another of our peer goroutines to + // do a query until we get a signal to + // quit. + select { + case startQuery <- struct{}{}: + case <-quit: + return + case <-allQuit: + return + } + // At this point, we've sent startQuery. + // We return if we've run through this + // section of code numRetries times. + if numRetries--; numRetries == 0 { + return + } + case <-quit: + // After we're told to quit, we return. + return + case <-allQuit: + // After we're told to quit, we return. + return + case <-startQuery: + // We're the lucky peer whose turn it is + // to try to answer the current query. + // TODO: Fix this to support either + // querying *all* peers simultaneously + // to avoid timeout delays, or starting + // with the syncPeer when not querying + // *all* peers. 
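The per-peer goroutines in queryPeers coordinate through the startQuery channel as a token: whichever goroutine receives it sends the query, and on timeout it hands the token to another goroutine until its retries run out or the quit channels close. The following is a heavily simplified sketch of that handoff, not the patch's implementation; peers, subscriptions, and wire messages are replaced by placeholders, and the allQuit bookkeeping is omitted:

package main

import (
	"fmt"
	"time"
)

// askPeer stands in for subscribing to a peer and queueing the query message;
// it is not part of the patch. Here only peer 2 ever produces an answer.
func askPeer(id int, answers chan<- int) {
	if id == 2 {
		go func() {
			time.Sleep(50 * time.Millisecond)
			answers <- id
		}()
	}
}

func main() {
	const numRetries = 2
	perPeerTimeout := 200 * time.Millisecond
	peers := []int{1, 2, 3}

	quit := make(chan struct{})       // closed once a valid answer arrives
	startQuery := make(chan struct{}) // the token: whoever receives it asks next
	answers := make(chan int)

	// One goroutine per peer, mirroring the shape of queryPeers: wait for the
	// token, send the query, and on timeout pass the token along, giving up
	// after numRetries turns.
	for _, id := range peers {
		go func(id int) {
			retries := numRetries
			var timeout <-chan time.Time // nil until we actually ask
			for {
				select {
				case <-quit:
					return
				case <-timeout:
					select {
					case startQuery <- struct{}{}:
					case <-quit:
						return
					}
					if retries--; retries == 0 {
						return
					}
				case <-startQuery:
					askPeer(id, answers)
					timeout = time.After(perPeerTimeout)
				}
			}
		}(id)
	}

	startQuery <- struct{}{} // hand out the first token

	select {
	case id := <-answers:
		close(quit)
		fmt.Println("answered by peer", id)
	case <-time.After(2 * time.Second): // the patch scales this by peers and retries
		close(quit)
		fmt.Println("query timed out")
	}
}
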
+ sp.subscribeRecvMsg(subscription) + // Don't want the peer hanging on send + // to the channel if we quit before + // reading the channel. + sentChan := make(chan struct{}, 1) + sp.QueueMessage(queryMsg, sentChan) + select { + case <-sentChan: + case <-quit: + return + case <-allQuit: + return + } + timeout = time.After(qo.timeout) + default: + } + } + }(sp) + } + startQuery <- struct{}{} + + // This goroutine will wait until all of the peer-query goroutines have + // terminated, and then initiate a query shutdown. + go func() { + wg.Wait() + // If we timed out on each goroutine and didn't quit or time out + // on the main goroutine, make sure our main goroutine knows to + // quit. + select { + case <-allQuit: + default: + close(allQuit) + } + // Close the done channel, if any + if qo.doneChan != nil { + close(qo.doneChan) + } + // Wait until all goroutines started by subscriptions have + // exited after we closed allQuit before letting the message + // channel get garbage collected. + subwg.Wait() + }() + + // Loop for any messages sent to us via our subscription channel and + // check them for whether they satisfy the query. Break the loop if it's + // time to quit. + timeout := time.After(time.Duration(len(peers)+1) * + qo.timeout * time.Duration(qo.numRetries)) +checkResponses: + for { + select { + case <-timeout: + // When we time out, close the allQuit channel + // if it hasn't already been closed. + select { + case <-allQuit: + default: + close(allQuit) + } + break checkResponses + case <-quit: + break checkResponses + case <-allQuit: + break checkResponses + case sm := <-msgChan: + // TODO: This will get stuck if checkResponse + // gets stuck. This is a caveat for callers that + // should be fixed before exposing this function + // for public use. + checkResponse(sm.sp, sm.msg, quit) + } + } +} + // GetCFilter gets a cfilter from the database. Failing that, it requests the // cfilter from the network and writes it to the database. func (s *ChainService) GetCFilter(blockHash chainhash.Hash, - extended bool) *gcs.Filter { + extended bool, options ...QueryOption) *gcs.Filter { getFilter := s.GetBasicFilter getHeader := s.GetBasicHeader putFilter := s.putBasicFilter @@ -1714,17 +1807,54 @@ func (s *ChainService) GetCFilter(blockHash chainhash.Hash, if err != nil { return nil } - replyChan := make(chan *gcs.Filter) - s.query <- getCFilterMsg{ - cfRequest: cfRequest{ - blockHash: blockHash, - extended: extended, + s.queryPeers( + // Send a wire.GetCFilterMsg + wire.NewMsgGetCFilter(&blockHash, extended), + // Check responses and if we get one that matches, + // end the query early. + func(sp *serverPeer, resp wire.Message, + quit chan<- struct{}) { + switch response := resp.(type) { + // We're only interested in "cfilter" messages. + case *wire.MsgCFilter: + if len(response.Data) < 4 { + // Filter data is too short. + // Ignore this message. + return + } + if blockHash != response.BlockHash { + // The response doesn't match our + // request. Ignore this message. + return + } + gotFilter, err := + gcs.FromNBytes(builder.DefaultP, + response.Data) + if err != nil { + // Malformed filter data. We + // can ignore this message. + return + } + if MakeHeaderForFilter(gotFilter, + *prevHeader) != + *curHeader { + // Filter data doesn't match + // the headers we know about. + // Ignore this response. + return + } + // At this point, the filter matches + // what we know about it and we declare + // it sane. We can kill the query and + // pass the response back to the caller. 
+ close(quit) + filter = gotFilter + default: + } }, - prevHeader: prevHeader, - curHeader: curHeader, - reply: replyChan, - } - filter = <-replyChan + options..., + ) + // If we've found a filter, write it to the database for next time. if filter != nil { putFilter(blockHash, filter) log.Tracef("Wrote filter for block %s, extended: %t", @@ -1736,20 +1866,72 @@ func (s *ChainService) GetCFilter(blockHash chainhash.Hash, // GetBlockFromNetwork gets a block by requesting it from the network, one peer // at a time, until one answers. func (s *ChainService) GetBlockFromNetwork( - blockHash chainhash.Hash) *btcutil.Block { + blockHash chainhash.Hash, options ...QueryOption) *btcutil.Block { blockHeader, height, err := s.GetBlockByHash(blockHash) if err != nil || blockHeader.BlockHash() != blockHash { return nil } - replyChan := make(chan *btcutil.Block) - s.query <- getBlockMsg{ - blockHeader: &blockHeader, - height: height, - reply: replyChan, - } - block := <-replyChan - if block != nil { - log.Tracef("Got block %s from network", blockHash) - } - return block + getData := wire.NewMsgGetData() + getData.AddInvVect(wire.NewInvVect(wire.InvTypeBlock, + &blockHash)) + // The block is only updated from the checkResponse function argument, + // which is always called single-threadedly. We don't check the block + // until after the query is finished, so we can just write to it + // naively. + var foundBlock *btcutil.Block + s.queryPeers( + // Send a wire.GetCFilterMsg + getData, + // Check responses and if we get one that matches, + // end the query early. + func(sp *serverPeer, resp wire.Message, + quit chan<- struct{}) { + switch response := resp.(type) { + // We're only interested in "block" messages. + case *wire.MsgBlock: + // If this isn't our block, ignore it. + if response.BlockHash() != + blockHash { + return + } + block := btcutil.NewBlock(response) + // Only set height if btcutil hasn't + // automagically put one in. + if block.Height() == + btcutil.BlockHeightUnknown { + block.SetHeight( + int32(height)) + } + // If this claims our block but doesn't + // pass the sanity check, the peer is + // trying to bamboozle us. Disconnect + // it. + if err := blockchain.CheckBlockSanity( + block, + // We don't need to check PoW + // because by the time we get + // here, it's been checked + // during header synchronization + s.chainParams.PowLimit, + s.timeSource, + ); err != nil { + log.Warnf("Invalid block for %s "+ + "received from %s -- "+ + "disconnecting peer", blockHash, + sp.Addr()) + sp.Disconnect() + return + } + // At this point, the block matches what we know + // about it and we declare it sane. We can kill + // the query and pass the response back to the + // caller. + close(quit) + foundBlock = block + default: + } + }, + options..., + ) + return foundBlock } diff --git a/spvsvc/spvchain/sync_test.go b/spvsvc/spvchain/sync_test.go index a3fcc1a..903de9c 100644 --- a/spvsvc/spvchain/sync_test.go +++ b/spvsvc/spvchain/sync_test.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "io/ioutil" - "math/rand" "os" "reflect" "sync" @@ -24,14 +23,32 @@ import ( _ "github.com/btcsuite/btcwallet/walletdb/bdb" ) -const ( - logLevel = btclog.TraceLvl - syncTimeout = 30 * time.Second - syncUpdate = time.Second - numTestBlocks = 50 +var ( + logLevel = btclog.TraceLvl + syncTimeout = 30 * time.Second + syncUpdate = time.Second + // Don't set this too high for your platform, or the tests will miss + // messages. + // TODO: Make this a benchmark instead. 
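With GetCFilter and GetBlockFromNetwork now taking variadic QueryOption arguments, a caller can tune the per-peer timeout and retry count per call. A hypothetical usage sketch follows; the spvchain import path is assumed from the repository layout, and an already-synced ChainService is taken as given:

package spvchainexample

import (
	"fmt"
	"time"

	"github.com/btcsuite/btcd/chaincfg/chainhash"
	// Assumed import path for the package this patch modifies.
	"github.com/btcsuite/btcwallet/spvsvc/spvchain"
)

// fetchFilterAndBlock is an illustrative caller, not code from this patch. It
// assumes a running *spvchain.ChainService and a block hash the service knows
// about, and shows the new functional options being threaded through the
// exported network-query methods.
func fetchFilterAndBlock(svc *spvchain.ChainService, hash chainhash.Hash) {
	opts := []spvchain.QueryOption{
		spvchain.Timeout(5 * time.Second), // wait longer per peer than QueryTimeout
		spvchain.NumRetries(3),            // ask each peer up to three times
	}

	// Basic filter: served from the database if present, otherwise fetched
	// from the network with the options above.
	if filter := svc.GetCFilter(hash, false, opts...); filter == nil {
		fmt.Println("no peer returned a filter matching our filter headers")
	}

	// Full block: queried peer by peer until one returns a block that passes
	// the sanity check.
	if block := svc.GetBlockFromNetwork(hash, opts...); block == nil {
		fmt.Println("no peer returned a sane block for", hash)
	}
}
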
+ numQueryThreads = 50 + queryOptions = []spvchain.QueryOption{ + //spvchain.NumRetries(5), + } ) func TestSetup(t *testing.T) { + // Set up logging. + logger, err := btctestlog.NewTestLogger(t) + if err != nil { + t.Fatalf("Could not set up logger: %s", err) + } + chainLogger := btclog.NewSubsystemLogger(logger, "CHAIN: ") + chainLogger.SetLevel(logLevel) + spvchain.UseLogger(chainLogger) + rpcLogger := btclog.NewSubsystemLogger(logger, "RPCC: ") + rpcLogger.SetLevel(logLevel) + btcrpcclient.UseLogger(rpcLogger) + // Create a btcd SimNet node and generate 500 blocks h1, err := rpctest.New(&chaincfg.SimNetParams, nil, nil) if err != nil { @@ -147,16 +164,6 @@ func TestSetup(t *testing.T) { spvchain.BanDuration = 5 * time.Second spvchain.RequiredServices = wire.SFNodeNetwork spvchain.WaitForMoreCFHeaders = time.Second - logger, err := btctestlog.NewTestLogger(t) - if err != nil { - t.Fatalf("Could not set up logger: %s", err) - } - chainLogger := btclog.NewSubsystemLogger(logger, "CHAIN: ") - chainLogger.SetLevel(logLevel) - spvchain.UseLogger(chainLogger) - rpcLogger := btclog.NewSubsystemLogger(logger, "RPCC: ") - rpcLogger.SetLevel(logLevel) - btcrpcclient.UseLogger(rpcLogger) svc, err := spvchain.NewChainService(config) if err != nil { t.Fatalf("Error creating ChainService: %s", err) @@ -170,6 +177,13 @@ func TestSetup(t *testing.T) { t.Fatalf("Couldn't sync ChainService: %s", err) } + // Test that we can get blocks and cfilters via P2P and decide which are + // valid and which aren't. + err = testRandomBlocks(t, svc, h1) + if err != nil { + t.Fatalf("Testing blocks and cfilters failed: %s", err) + } + // Generate 125 blocks on h1 to make sure it reorgs the other nodes. // Ensure the ChainService instance stays caught up. h1.Node.Generate(125) @@ -209,13 +223,6 @@ func TestSetup(t *testing.T) { if err != nil { t.Fatalf("Couldn't sync ChainService: %s", err) } - - // Test that we can get blocks and cfilters via P2P and decide which are - // valid and which aren't. - err = testRandomBlocks(t, svc, h1) - if err != nil { - t.Fatalf("Testing blocks and cfilters failed: %s", err) - } } // csd does a connect-sync-disconnect between nodes in order to support @@ -377,7 +384,7 @@ func waitForSync(t *testing.T, svc *spvchain.ChainService, // can correctly get filters from them. We don't go through *all* the blocks // because it can be a little slow, but we'll improve that soon-ish hopefully // to the point where we can do it. -// TODO: Improve concurrency on framework side. +// TODO: Make this a benchmark instead. func testRandomBlocks(t *testing.T, svc *spvchain.ChainService, correctSyncNode *rpctest.Harness) error { var haveBest *waddrmgr.BlockStamp @@ -386,180 +393,200 @@ func testRandomBlocks(t *testing.T, svc *spvchain.ChainService, return fmt.Errorf("Couldn't get best snapshot from "+ "ChainService: %s", err) } - // Keep track of an error channel - errChan := make(chan error) - var lastErr error - go func() { - for err := range errChan { - if err != nil { - t.Errorf("%s", err) - lastErr = fmt.Errorf("Couldn't validate all " + - "blocks, filters, and filter headers.") - } - } - }() - // Test getting numTestBlocks random blocks and filters. + // Keep track of an error channel with enough buffer space to track one + // error per block. + errChan := make(chan error, haveBest.Height) + // Test getting all of the blocks and filters. 
var wg sync.WaitGroup - heights := rand.Perm(int(haveBest.Height)) - for i := 0; i < numTestBlocks; i++ { + workerQueue := make(chan struct{}, numQueryThreads) + for i := int32(1); i <= haveBest.Height; i++ { wg.Add(1) - height := uint32(heights[i]) + height := uint32(i) + // Wait until there's room in the worker queue. + workerQueue <- struct{}{} go func() { + // On exit, open a spot in workerQueue and tell the + // wait group we're done. + defer func() { + <-workerQueue + }() defer wg.Done() // Get block header from database. - blockHeader, blockHeight, err := svc.GetBlockByHeight(height) + blockHeader, blockHeight, err := svc.GetBlockByHeight( + height) if err != nil { errChan <- fmt.Errorf("Couldn't get block "+ "header by height %d: %s", height, err) return } if blockHeight != height { - errChan <- fmt.Errorf("Block height retrieved from DB "+ - "doesn't match expected height. Want: %d, "+ - "have: %d", height, blockHeight) + errChan <- fmt.Errorf("Block height retrieved "+ + "from DB doesn't match expected "+ + "height. Want: %d, have: %d", height, + blockHeight) return } blockHash := blockHeader.BlockHash() // Get block via RPC. - wantBlock, err := correctSyncNode.Node.GetBlock(&blockHash) + wantBlock, err := correctSyncNode.Node.GetBlock( + &blockHash) if err != nil { - errChan <- fmt.Errorf("Couldn't get block %d (%s) by RPC", - height, blockHash) + errChan <- fmt.Errorf("Couldn't get block %d "+ + "(%s) by RPC", height, blockHash) return } // Get block from network. - haveBlock := svc.GetBlockFromNetwork(blockHash) + haveBlock := svc.GetBlockFromNetwork(blockHash, + queryOptions...) if haveBlock == nil { - errChan <- fmt.Errorf("Couldn't get block %d (%s) from"+ - "network", height, blockHash) + errChan <- fmt.Errorf("Couldn't get block %d "+ + "(%s) from network", height, blockHash) return } // Check that network and RPC blocks match. - if !reflect.DeepEqual(*haveBlock.MsgBlock(), *wantBlock) { - errChan <- fmt.Errorf("Block from network doesn't match "+ - "block from RPC. Want: %s, RPC: %s, network: "+ - "%s", blockHash, wantBlock.BlockHash(), + if !reflect.DeepEqual(*haveBlock.MsgBlock(), + *wantBlock) { + errChan <- fmt.Errorf("Block from network "+ + "doesn't match block from RPC. Want: "+ + "%s, RPC: %s, network: %s", blockHash, + wantBlock.BlockHash(), haveBlock.MsgBlock().BlockHash()) return } // Check that block height matches what we have. if int32(blockHeight) != haveBlock.Height() { - errChan <- fmt.Errorf("Block height from network doesn't "+ - "match expected height. Want: %s, network: %s", + errChan <- fmt.Errorf("Block height from "+ + "network doesn't match expected "+ + "height. Want: %s, network: %s", blockHeight, haveBlock.Height()) return } // Get basic cfilter from network. - haveFilter := svc.GetCFilter(blockHash, false) + haveFilter := svc.GetCFilter(blockHash, false, + queryOptions...) if haveFilter == nil { errChan <- fmt.Errorf("Couldn't get basic "+ "filter for block %d", height) return } // Get basic cfilter from RPC. - wantFilter, err := correctSyncNode.Node.GetCFilter(&blockHash, - false) + wantFilter, err := correctSyncNode.Node.GetCFilter( + &blockHash, false) if err != nil { - errChan <- fmt.Errorf("Couldn't get basic filter for "+ - "block %d via RPC: %s", height, err) + errChan <- fmt.Errorf("Couldn't get basic "+ + "filter for block %d via RPC: %s", + height, err) return } // Check that network and RPC cfilters match. 
if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) { - errChan <- fmt.Errorf("Basic filter from P2P network/DB"+ - " doesn't match RPC value for block %d", height) + errChan <- fmt.Errorf("Basic filter from P2P "+ + "network/DB doesn't match RPC value "+ + "for block %d", height) return } // Calculate basic filter from block. calcFilter, err := spvchain.BuildBasicFilter( haveBlock.MsgBlock()) if err != nil { - errChan <- fmt.Errorf("Couldn't build basic filter for "+ - "block %d (%s): %s", height, blockHash, err) + errChan <- fmt.Errorf("Couldn't build basic "+ + "filter for block %d (%s): %s", height, + blockHash, err) return } - // Check that the network value matches the calculated value - // from the block. + // Check that the network value matches the calculated + // value from the block. if !reflect.DeepEqual(*haveFilter, *calcFilter) { - errChan <- fmt.Errorf("Basic filter from P2P network/DB "+ - "doesn't match calculated value for block %d", - height) + errChan <- fmt.Errorf("Basic filter from P2P "+ + "network/DB doesn't match calculated "+ + "value for block %d", height) return } // Get previous basic filter header from the database. - prevHeader, err := svc.GetBasicHeader(blockHeader.PrevBlock) + prevHeader, err := svc.GetBasicHeader( + blockHeader.PrevBlock) if err != nil { - errChan <- fmt.Errorf("Couldn't get basic filter header "+ - "for block %d (%s) from DB: %s", height-1, + errChan <- fmt.Errorf("Couldn't get basic "+ + "filter header for block %d (%s) from "+ + "DB: %s", height-1, blockHeader.PrevBlock, err) return } // Get current basic filter header from the database. curHeader, err := svc.GetBasicHeader(blockHash) if err != nil { - errChan <- fmt.Errorf("Couldn't get basic filter header "+ - "for block %d (%s) from DB: %s", height-1, - blockHash, err) + errChan <- fmt.Errorf("Couldn't get basic "+ + "filter header for block %d (%s) from "+ + "DB: %s", height-1, blockHash, err) return } // Check that the filter and header line up. calcHeader := spvchain.MakeHeaderForFilter(calcFilter, *prevHeader) if !bytes.Equal(curHeader[:], calcHeader[:]) { - errChan <- fmt.Errorf("Filter header doesn't match. Want: "+ - "%s, got: %s", curHeader, calcHeader) + errChan <- fmt.Errorf("Filter header doesn't "+ + "match. Want: %s, got: %s", curHeader, + calcHeader) return } // Get extended cfilter from network - haveFilter = svc.GetCFilter(blockHash, true) + haveFilter = svc.GetCFilter(blockHash, true, + queryOptions...) 
if haveFilter == nil { errChan <- fmt.Errorf("Couldn't get extended "+ "filter for block %d", height) return } // Get extended cfilter from RPC - wantFilter, err = correctSyncNode.Node.GetCFilter(&blockHash, - true) + wantFilter, err = correctSyncNode.Node.GetCFilter( + &blockHash, true) if err != nil { - errChan <- fmt.Errorf("Couldn't get extended filter for "+ - "block %d via RPC: %s", height, err) + errChan <- fmt.Errorf("Couldn't get extended "+ + "filter for block %d via RPC: %s", + height, err) return } // Check that network and RPC cfilters match if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) { - errChan <- fmt.Errorf("Extended filter from P2P network/DB"+ - " doesn't match RPC value for block %d", height) + errChan <- fmt.Errorf("Extended filter from "+ + "P2P network/DB doesn't match RPC "+ + "value for block %d", height) return } // Calculate extended filter from block calcFilter, err = spvchain.BuildExtFilter( haveBlock.MsgBlock()) if err != nil { - errChan <- fmt.Errorf("Couldn't build extended filter for "+ - "block %d (%s): %s", height, blockHash, err) + errChan <- fmt.Errorf("Couldn't build extended"+ + " filter for block %d (%s): %s", height, + blockHash, err) return } - // Check that the network value matches the calculated value - // from the block. + // Check that the network value matches the calculated + // value from the block. if !reflect.DeepEqual(*haveFilter, *calcFilter) { - errChan <- fmt.Errorf("Extended filter from P2P network/DB"+ - " doesn't match calculated value for block %d", - height) + errChan <- fmt.Errorf("Extended filter from "+ + "P2P network/DB doesn't match "+ + "calculated value for block %d", height) return } - // Get previous extended filter header from the database. - prevHeader, err = svc.GetExtHeader(blockHeader.PrevBlock) + // Get previous extended filter header from the + // database. + prevHeader, err = svc.GetExtHeader( + blockHeader.PrevBlock) if err != nil { - errChan <- fmt.Errorf("Couldn't get extended filter header"+ - " for block %d (%s) from DB: %s", height-1, + errChan <- fmt.Errorf("Couldn't get extended "+ + "filter header for block %d (%s) from "+ + "DB: %s", height-1, blockHeader.PrevBlock, err) return } // Get current basic filter header from the database. curHeader, err = svc.GetExtHeader(blockHash) if err != nil { - errChan <- fmt.Errorf("Couldn't get extended filter header"+ - " for block %d (%s) from DB: %s", height-1, + errChan <- fmt.Errorf("Couldn't get extended "+ + "filter header for block %d (%s) from "+ + "DB: %s", height-1, blockHash, err) return } @@ -567,20 +594,29 @@ func testRandomBlocks(t *testing.T, svc *spvchain.ChainService, calcHeader = spvchain.MakeHeaderForFilter(calcFilter, *prevHeader) if !bytes.Equal(curHeader[:], calcHeader[:]) { - errChan <- fmt.Errorf("Filter header doesn't match. Want: "+ - "%s, got: %s", curHeader, calcHeader) + errChan <- fmt.Errorf("Filter header doesn't "+ + "match. Want: %s, got: %s", curHeader, + calcHeader) return } }() } // Wait for all queries to finish. wg.Wait() - if logLevel != btclog.Off { - t.Logf("Finished checking %d blocks and their cfilters", - numTestBlocks) - } // Close the error channel to make the error monitoring goroutine // finish. 
close(errChan) + var lastErr error + for err := range errChan { + if err != nil { + t.Errorf("%s", err) + lastErr = fmt.Errorf("Couldn't validate all " + + "blocks, filters, and filter headers.") + } + } + if logLevel != btclog.Off { + t.Logf("Finished checking %d blocks and their cfilters", + haveBest.Height) + } return lastErr }
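
The reworked test walks every block rather than a random sample, throttling the goroutines with a worker-queue semaphore and collecting failures in an error channel buffered to one slot per block, which is drained only after wg.Wait(). A distilled sketch of that shape, with the per-block validation replaced by a placeholder:

package main

import (
	"fmt"
	"sync"
)

// checkHeight stands in for the per-block validation in testRandomBlocks and
// is not part of the patch; it pretends a few heights fail.
func checkHeight(height uint32) error {
	if height%100 == 0 {
		return fmt.Errorf("pretend failure at height %d", height)
	}
	return nil
}

func main() {
	const bestHeight = 500
	const numWorkers = 50 // matches numQueryThreads in the test

	errChan := make(chan error, bestHeight) // one buffered slot per block
	workerQueue := make(chan struct{}, numWorkers)
	var wg sync.WaitGroup

	for i := uint32(1); i <= bestHeight; i++ {
		wg.Add(1)
		height := i
		workerQueue <- struct{}{} // wait for a free worker slot
		go func() {
			defer wg.Done()
			defer func() { <-workerQueue }() // free the slot on exit
			if err := checkHeight(height); err != nil {
				errChan <- err // never blocks thanks to the buffer
			}
		}()
	}

	wg.Wait()
	close(errChan)
	var lastErr error
	for err := range errChan {
		fmt.Println(err)
		lastErr = err
	}
	if lastErr != nil {
		fmt.Println("not all heights validated")
	}
}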