Fix concurrency issues.
This commit is contained in:
parent
c7b26a11e2
commit
fe632ff233
4 changed files with 531 additions and 439 deletions
|
@ -1031,9 +1031,11 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
|
||||||
// Should probably use better isolation for this but we're in
|
// Should probably use better isolation for this but we're in
|
||||||
// the same package. One of the things to clean up when we do
|
// the same package. One of the things to clean up when we do
|
||||||
// more general cleanup.
|
// more general cleanup.
|
||||||
|
sp.mtxReqCFH.Lock()
|
||||||
sp.requestedCFHeaders[cfhReqB] = cfhCount
|
sp.requestedCFHeaders[cfhReqB] = cfhCount
|
||||||
sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, false)
|
|
||||||
sp.requestedCFHeaders[cfhReqE] = cfhCount
|
sp.requestedCFHeaders[cfhReqE] = cfhCount
|
||||||
|
sp.mtxReqCFH.Unlock()
|
||||||
|
sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, false)
|
||||||
sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, true)
|
sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, true)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -1078,13 +1080,20 @@ func (b *blockManager) QueueCFHeaders(cfheaders *wire.MsgCFHeaders,
|
||||||
extended: cfheaders.Extended,
|
extended: cfheaders.Extended,
|
||||||
stopHash: cfheaders.StopHash,
|
stopHash: cfheaders.StopHash,
|
||||||
}
|
}
|
||||||
if sp.requestedCFHeaders[req] != len(cfheaders.HeaderHashes) {
|
// TODO: Get rid of this by refactoring all of this using the query API
|
||||||
|
sp.mtxReqCFH.Lock()
|
||||||
|
expLen := sp.requestedCFHeaders[req]
|
||||||
|
sp.mtxReqCFH.Unlock()
|
||||||
|
if expLen != len(cfheaders.HeaderHashes) {
|
||||||
log.Warnf("Received cfheaders message doesn't match any "+
|
log.Warnf("Received cfheaders message doesn't match any "+
|
||||||
"getcfheaders request. Peer %s is probably on a "+
|
"getcfheaders request. Peer %s is probably on a "+
|
||||||
"different chain -- ignoring", sp.Addr())
|
"different chain -- ignoring", sp.Addr())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// TODO: Remove this by refactoring this section into a query client.
|
||||||
|
sp.mtxReqCFH.Lock()
|
||||||
delete(sp.requestedCFHeaders, req)
|
delete(sp.requestedCFHeaders, req)
|
||||||
|
sp.mtxReqCFH.Unlock()
|
||||||
|
|
||||||
// Track number of pending cfheaders messsages for both basic and
|
// Track number of pending cfheaders messsages for both basic and
|
||||||
// extended filters.
|
// extended filters.
|
||||||
|
|
|
@ -8,13 +8,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/addrmgr"
|
"github.com/btcsuite/btcd/addrmgr"
|
||||||
"github.com/btcsuite/btcd/blockchain"
|
|
||||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
|
||||||
"github.com/btcsuite/btcd/connmgr"
|
"github.com/btcsuite/btcd/connmgr"
|
||||||
"github.com/btcsuite/btcd/wire"
|
|
||||||
"github.com/btcsuite/btcutil"
|
|
||||||
"github.com/btcsuite/btcutil/gcs"
|
|
||||||
"github.com/btcsuite/btcutil/gcs/builder"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type getConnCountMsg struct {
|
type getConnCountMsg struct {
|
||||||
|
@ -54,19 +48,6 @@ type forAllPeersMsg struct {
|
||||||
closure func(*serverPeer)
|
closure func(*serverPeer)
|
||||||
}
|
}
|
||||||
|
|
||||||
type getCFilterMsg struct {
|
|
||||||
cfRequest
|
|
||||||
prevHeader *chainhash.Hash
|
|
||||||
curHeader *chainhash.Hash
|
|
||||||
reply chan *gcs.Filter
|
|
||||||
}
|
|
||||||
|
|
||||||
type getBlockMsg struct {
|
|
||||||
blockHeader *wire.BlockHeader
|
|
||||||
height uint32
|
|
||||||
reply chan *btcutil.Block
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: General - abstract out more of blockmanager into queries. It'll make
|
// TODO: General - abstract out more of blockmanager into queries. It'll make
|
||||||
// this way more maintainable and usable.
|
// this way more maintainable and usable.
|
||||||
|
|
||||||
|
@ -172,128 +153,12 @@ func (s *ChainService) handleQuery(state *peerState, querymsg interface{}) {
|
||||||
|
|
||||||
msg.reply <- errors.New("peer not found")
|
msg.reply <- errors.New("peer not found")
|
||||||
case forAllPeersMsg:
|
case forAllPeersMsg:
|
||||||
|
// TODO: Remove this when it's unnecessary due to wider use of
|
||||||
|
// queryPeers.
|
||||||
// Run the closure on all peers in the passed state.
|
// Run the closure on all peers in the passed state.
|
||||||
state.forAllPeers(msg.closure)
|
state.forAllPeers(msg.closure)
|
||||||
// Even though this is a query, there's no reply channel as the
|
// Even though this is a query, there's no reply channel as the
|
||||||
// forAllPeers method doesn't return anything. An error might be
|
// forAllPeers method doesn't return anything. An error might be
|
||||||
// useful in the future.
|
// useful in the future.
|
||||||
case getCFilterMsg:
|
|
||||||
found := false
|
|
||||||
state.queryPeers(
|
|
||||||
// Should we query this peer?
|
|
||||||
func(sp *serverPeer) bool {
|
|
||||||
// Don't send requests to disconnected peers.
|
|
||||||
return sp.Connected()
|
|
||||||
},
|
|
||||||
// Send a wire.GetCFilterMsg
|
|
||||||
wire.NewMsgGetCFilter(&msg.blockHash, msg.extended),
|
|
||||||
// Check responses and if we get one that matches,
|
|
||||||
// end the query early.
|
|
||||||
func(sp *serverPeer, resp wire.Message,
|
|
||||||
quit chan<- struct{}) {
|
|
||||||
switch response := resp.(type) {
|
|
||||||
// We're only interested in "cfilter" messages.
|
|
||||||
case *wire.MsgCFilter:
|
|
||||||
if len(response.Data) < 4 {
|
|
||||||
// Filter data is too short.
|
|
||||||
// Ignore this message.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
filter, err :=
|
|
||||||
gcs.FromNBytes(builder.DefaultP,
|
|
||||||
response.Data)
|
|
||||||
if err != nil {
|
|
||||||
// Malformed filter data. We
|
|
||||||
// can ignore this message.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if MakeHeaderForFilter(filter,
|
|
||||||
*msg.prevHeader) !=
|
|
||||||
*msg.curHeader {
|
|
||||||
// Filter data doesn't match
|
|
||||||
// the headers we know about.
|
|
||||||
// Ignore this response.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// At this point, the filter matches
|
|
||||||
// what we know about it and we declare
|
|
||||||
// it sane. We can kill the query and
|
|
||||||
// pass the response back to the caller.
|
|
||||||
found = true
|
|
||||||
close(quit)
|
|
||||||
msg.reply <- filter
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
// We timed out without finding a correct answer to our query.
|
|
||||||
if !found {
|
|
||||||
msg.reply <- nil
|
|
||||||
}
|
|
||||||
case getBlockMsg:
|
|
||||||
found := false
|
|
||||||
getData := wire.NewMsgGetData()
|
|
||||||
blockHash := msg.blockHeader.BlockHash()
|
|
||||||
getData.AddInvVect(wire.NewInvVect(wire.InvTypeBlock,
|
|
||||||
&blockHash))
|
|
||||||
state.queryPeers(
|
|
||||||
// Should we query this peer?
|
|
||||||
func(sp *serverPeer) bool {
|
|
||||||
// Don't send requests to disconnected peers.
|
|
||||||
return sp.Connected()
|
|
||||||
},
|
|
||||||
// Send a wire.GetCFilterMsg
|
|
||||||
getData,
|
|
||||||
// Check responses and if we get one that matches,
|
|
||||||
// end the query early.
|
|
||||||
func(sp *serverPeer, resp wire.Message,
|
|
||||||
quit chan<- struct{}) {
|
|
||||||
switch response := resp.(type) {
|
|
||||||
// We're only interested in "block" messages.
|
|
||||||
case *wire.MsgBlock:
|
|
||||||
// If this isn't our block, ignore it.
|
|
||||||
if response.BlockHash() !=
|
|
||||||
blockHash {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
block := btcutil.NewBlock(response)
|
|
||||||
// Only set height if btcutil hasn't
|
|
||||||
// automagically put one in.
|
|
||||||
if block.Height() ==
|
|
||||||
btcutil.BlockHeightUnknown {
|
|
||||||
block.SetHeight(
|
|
||||||
int32(msg.height))
|
|
||||||
}
|
|
||||||
// If this claims our block but doesn't
|
|
||||||
// pass the sanity check, the peer is
|
|
||||||
// trying to bamboozle us. Disconnect
|
|
||||||
// it.
|
|
||||||
if err := blockchain.CheckBlockSanity(
|
|
||||||
block,
|
|
||||||
// We don't need to check PoW
|
|
||||||
// because by the time we get
|
|
||||||
// here, it's been checked
|
|
||||||
// during header synchronization
|
|
||||||
s.chainParams.PowLimit,
|
|
||||||
s.timeSource,
|
|
||||||
); err != nil {
|
|
||||||
log.Warnf("Invalid block for "+
|
|
||||||
"%s received from %s "+
|
|
||||||
"-- disconnecting peer",
|
|
||||||
blockHash, sp.Addr())
|
|
||||||
sp.Disconnect()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
found = true
|
|
||||||
close(quit)
|
|
||||||
msg.reply <- block
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
// We timed out without finding a correct answer to our query.
|
|
||||||
if !found {
|
|
||||||
msg.reply <- nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/btcsuite/btcd/wire"
|
"github.com/btcsuite/btcd/wire"
|
||||||
"github.com/btcsuite/btcutil"
|
"github.com/btcsuite/btcutil"
|
||||||
"github.com/btcsuite/btcutil/gcs"
|
"github.com/btcsuite/btcutil/gcs"
|
||||||
|
"github.com/btcsuite/btcutil/gcs/builder"
|
||||||
"github.com/btcsuite/btcwallet/waddrmgr"
|
"github.com/btcsuite/btcwallet/waddrmgr"
|
||||||
"github.com/btcsuite/btcwallet/wallet"
|
"github.com/btcsuite/btcwallet/wallet"
|
||||||
"github.com/btcsuite/btcwallet/walletdb"
|
"github.com/btcsuite/btcwallet/walletdb"
|
||||||
|
@ -63,8 +64,14 @@ var (
|
||||||
// from DNS.
|
// from DNS.
|
||||||
DisableDNSSeed = false
|
DisableDNSSeed = false
|
||||||
|
|
||||||
// Timeout specifies how long to wait for a peer to answer a query.
|
// QueryTimeout specifies how long to wait for a peer to answer a query.
|
||||||
Timeout = time.Second * 5
|
QueryTimeout = time.Second * 3
|
||||||
|
|
||||||
|
// QueryNumRetries specifies how many times to retry sending a query to
|
||||||
|
// each peer before we've concluded we aren't going to get a valid
|
||||||
|
// response. This allows to make up for missed messages in some
|
||||||
|
// instances.
|
||||||
|
QueryNumRetries = 2
|
||||||
)
|
)
|
||||||
|
|
||||||
// updatePeerHeightsMsg is a message sent from the blockmanager to the server
|
// updatePeerHeightsMsg is a message sent from the blockmanager to the server
|
||||||
|
@ -110,137 +117,6 @@ func (ps *peerState) forAllPeers(closure func(sp *serverPeer)) {
|
||||||
ps.forAllOutboundPeers(closure)
|
ps.forAllOutboundPeers(closure)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query options can be modified per-query, unlike global options.
|
|
||||||
// TODO: Make more query options that override global options.
|
|
||||||
type queryOptions struct {
|
|
||||||
// queryTimeout lets the query know how long to wait for a peer to
|
|
||||||
// answer the query before moving onto the next peer.
|
|
||||||
queryTimeout time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultQueryOptions returns a queryOptions set to package-level defaults.
|
|
||||||
func defaultQueryOptions() *queryOptions {
|
|
||||||
return &queryOptions{
|
|
||||||
queryTimeout: Timeout,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryTimeout is a query option that lets the query know to ask each peer we're
|
|
||||||
// connected to for its opinion, if any. By default, we only ask peers until one
|
|
||||||
// gives us a valid response.
|
|
||||||
func QueryTimeout(timeout time.Duration) func(*queryOptions) {
|
|
||||||
return func(qo *queryOptions) {
|
|
||||||
qo.queryTimeout = timeout
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type spMsg struct {
|
|
||||||
sp *serverPeer
|
|
||||||
msg wire.Message
|
|
||||||
}
|
|
||||||
|
|
||||||
// queryPeers is a helper function that sends a query to one or more peers and
|
|
||||||
// waits for an answer. The timeout for queries is set by the QueryTimeout
|
|
||||||
// package-level variable.
|
|
||||||
func (ps *peerState) queryPeers(
|
|
||||||
// selectPeer is a closure which decides whether or not to send the
|
|
||||||
// query to the peer.
|
|
||||||
selectPeer func(sp *serverPeer) bool,
|
|
||||||
// queryMsg is the message to send to each peer selected by selectPeer.
|
|
||||||
queryMsg wire.Message,
|
|
||||||
// checkResponse is caled for every message within the timeout period.
|
|
||||||
// The quit channel lets the query know to terminate because the
|
|
||||||
// required response has been found. This is done by closing the
|
|
||||||
// channel.
|
|
||||||
checkResponse func(sp *serverPeer, resp wire.Message,
|
|
||||||
quit chan<- struct{}),
|
|
||||||
// options takes functional options for executing the query.
|
|
||||||
options ...func(*queryOptions),
|
|
||||||
) {
|
|
||||||
qo := defaultQueryOptions()
|
|
||||||
for _, option := range options {
|
|
||||||
option(qo)
|
|
||||||
}
|
|
||||||
// This will be shared state between the per-peer goroutines.
|
|
||||||
quit := make(chan struct{})
|
|
||||||
startQuery := make(chan struct{})
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
channel := make(chan spMsg)
|
|
||||||
|
|
||||||
// This goroutine will monitor all messages from all peers until the
|
|
||||||
// peer goroutines all exit.
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-quit:
|
|
||||||
close(channel)
|
|
||||||
ps.forAllPeers(
|
|
||||||
func(sp *serverPeer) {
|
|
||||||
sp.unsubscribeRecvMsgs(channel)
|
|
||||||
})
|
|
||||||
return
|
|
||||||
case sm := <-channel:
|
|
||||||
// TODO: This will get stuck if checkResponse
|
|
||||||
// gets stuck.
|
|
||||||
checkResponse(sm.sp, sm.msg, quit)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Start a goroutine for each peer that potentially queries each peer
|
|
||||||
ps.forAllPeers(func(sp *serverPeer) {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
if !selectPeer(sp) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
timeout := make(<-chan time.Time)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-timeout:
|
|
||||||
// After timeout, we return and notify
|
|
||||||
// another goroutine that we've done so.
|
|
||||||
// We only send if there's someone left
|
|
||||||
// to receive.
|
|
||||||
startQuery <- struct{}{}
|
|
||||||
return
|
|
||||||
case <-quit:
|
|
||||||
// After we're told to quit, we return.
|
|
||||||
return
|
|
||||||
case <-startQuery:
|
|
||||||
// We're the lucky peer whose turn it is
|
|
||||||
// to try to answer the current query.
|
|
||||||
// TODO: Fix this to support multiple
|
|
||||||
// queries at once. For now, we're
|
|
||||||
// relying on the query handling loop
|
|
||||||
// to make sure we don't interrupt
|
|
||||||
// another query. We need broadcast
|
|
||||||
// support in OnRead to do this right.
|
|
||||||
// TODO: Fix this to support either
|
|
||||||
// querying *all* peers simultaneously
|
|
||||||
// to avoid timeout delays, or starting
|
|
||||||
// with the syncPeer when not querying
|
|
||||||
// *all* peers.
|
|
||||||
sp.subscribeRecvMsg(channel)
|
|
||||||
sp.QueueMessage(queryMsg, nil)
|
|
||||||
timeout = time.After(qo.queryTimeout)
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
})
|
|
||||||
startQuery <- struct{}{}
|
|
||||||
wg.Wait()
|
|
||||||
// If we timed out and didn't quit, make sure our response monitor
|
|
||||||
// goroutine knows to quit.
|
|
||||||
select {
|
|
||||||
case <-quit:
|
|
||||||
default:
|
|
||||||
close(quit)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// cfhRequest records which cfheaders we've requested, and the order in which
|
// cfhRequest records which cfheaders we've requested, and the order in which
|
||||||
// we've requested them. Since there's no way to associate the cfheaders to the
|
// we've requested them. Since there's no way to associate the cfheaders to the
|
||||||
// actual block hashes based on the cfheaders message to keep it compact, we
|
// actual block hashes based on the cfheaders message to keep it compact, we
|
||||||
|
@ -250,12 +126,6 @@ type cfhRequest struct {
|
||||||
stopHash chainhash.Hash
|
stopHash chainhash.Hash
|
||||||
}
|
}
|
||||||
|
|
||||||
// cfRequest records which cfilters we've requested.
|
|
||||||
type cfRequest struct {
|
|
||||||
extended bool
|
|
||||||
blockHash chainhash.Hash
|
|
||||||
}
|
|
||||||
|
|
||||||
// serverPeer extends the peer to maintain state shared by the server and
|
// serverPeer extends the peer to maintain state shared by the server and
|
||||||
// the blockmanager.
|
// the blockmanager.
|
||||||
type serverPeer struct {
|
type serverPeer struct {
|
||||||
|
@ -264,24 +134,27 @@ type serverPeer struct {
|
||||||
|
|
||||||
*peer.Peer
|
*peer.Peer
|
||||||
|
|
||||||
connReq *connmgr.ConnReq
|
connReq *connmgr.ConnReq
|
||||||
server *ChainService
|
server *ChainService
|
||||||
persistent bool
|
persistent bool
|
||||||
continueHash *chainhash.Hash
|
continueHash *chainhash.Hash
|
||||||
relayMtx sync.Mutex
|
relayMtx sync.Mutex
|
||||||
requestQueue []*wire.InvVect
|
requestQueue []*wire.InvVect
|
||||||
requestedCFHeaders map[cfhRequest]int
|
knownAddresses map[string]struct{}
|
||||||
knownAddresses map[string]struct{}
|
banScore connmgr.DynamicBanScore
|
||||||
banScore connmgr.DynamicBanScore
|
quit chan struct{}
|
||||||
quit chan struct{}
|
|
||||||
// The following slice of channels is used to subscribe to messages from
|
// The following slice of channels is used to subscribe to messages from
|
||||||
// the peer. This allows broadcast to multiple subscribers at once,
|
// the peer. This allows broadcast to multiple subscribers at once,
|
||||||
// allowing for multiple queries to be going to multiple peers at any
|
// allowing for multiple queries to be going to multiple peers at any
|
||||||
// one time. The mutex is for subscribe/unsubscribe functionality.
|
// one time. The mutex is for subscribe/unsubscribe functionality.
|
||||||
// The sends on these channels WILL NOT block; any messages the channel
|
// The sends on these channels WILL NOT block; any messages the channel
|
||||||
// can't accept will be dropped silently.
|
// can't accept will be dropped silently.
|
||||||
recvSubscribers []chan<- spMsg
|
recvSubscribers map[spMsgSubscription]struct{}
|
||||||
mtxSubscribers sync.RWMutex
|
mtxSubscribers sync.RWMutex
|
||||||
|
// These are only necessary until the cfheaders logic is refactored as
|
||||||
|
// a query client.
|
||||||
|
requestedCFHeaders map[cfhRequest]int
|
||||||
|
mtxReqCFH sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// newServerPeer returns a new serverPeer instance. The peer needs to be set by
|
// newServerPeer returns a new serverPeer instance. The peer needs to be set by
|
||||||
|
@ -293,6 +166,7 @@ func newServerPeer(s *ChainService, isPersistent bool) *serverPeer {
|
||||||
requestedCFHeaders: make(map[cfhRequest]int),
|
requestedCFHeaders: make(map[cfhRequest]int),
|
||||||
knownAddresses: make(map[string]struct{}),
|
knownAddresses: make(map[string]struct{}),
|
||||||
quit: make(chan struct{}),
|
quit: make(chan struct{}),
|
||||||
|
recvSubscribers: make(map[spMsgSubscription]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -615,42 +489,38 @@ func (sp *serverPeer) OnRead(_ *peer.Peer, bytesRead int, msg wire.Message,
|
||||||
err error) {
|
err error) {
|
||||||
sp.server.AddBytesReceived(uint64(bytesRead))
|
sp.server.AddBytesReceived(uint64(bytesRead))
|
||||||
// Try to send a message to the subscriber channel if it isn't nil, but
|
// Try to send a message to the subscriber channel if it isn't nil, but
|
||||||
// don't block on failure.
|
// don't block on failure. Do this inside a goroutine to prevent the
|
||||||
|
// server from slowing down too fast.
|
||||||
sp.mtxSubscribers.RLock()
|
sp.mtxSubscribers.RLock()
|
||||||
defer sp.mtxSubscribers.RUnlock()
|
defer sp.mtxSubscribers.RUnlock()
|
||||||
for _, channel := range sp.recvSubscribers {
|
for subscription := range sp.recvSubscribers {
|
||||||
if channel != nil {
|
subscription.wg.Add(1)
|
||||||
|
go func(subscription spMsgSubscription) {
|
||||||
|
defer subscription.wg.Done()
|
||||||
select {
|
select {
|
||||||
case channel <- spMsg{
|
case <-subscription.quitChan:
|
||||||
sp: sp,
|
case subscription.msgChan <- spMsg{
|
||||||
msg: msg,
|
msg: msg,
|
||||||
|
sp: sp,
|
||||||
}:
|
}:
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
}
|
}(subscription)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// subscribeRecvMsg handles adding OnRead subscriptions to the server peer.
|
// subscribeRecvMsg handles adding OnRead subscriptions to the server peer.
|
||||||
func (sp *serverPeer) subscribeRecvMsg(channel chan<- spMsg) {
|
func (sp *serverPeer) subscribeRecvMsg(subscription spMsgSubscription) {
|
||||||
sp.mtxSubscribers.Lock()
|
sp.mtxSubscribers.Lock()
|
||||||
defer sp.mtxSubscribers.Unlock()
|
defer sp.mtxSubscribers.Unlock()
|
||||||
sp.recvSubscribers = append(sp.recvSubscribers, channel)
|
sp.recvSubscribers[subscription] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// unsubscribeRecvMsgs handles removing OnRead subscriptions from the server
|
// unsubscribeRecvMsgs handles removing OnRead subscriptions from the server
|
||||||
// peer.
|
// peer.
|
||||||
func (sp *serverPeer) unsubscribeRecvMsgs(channel chan<- spMsg) {
|
func (sp *serverPeer) unsubscribeRecvMsgs(subscription spMsgSubscription) {
|
||||||
sp.mtxSubscribers.Lock()
|
sp.mtxSubscribers.Lock()
|
||||||
defer sp.mtxSubscribers.Unlock()
|
defer sp.mtxSubscribers.Unlock()
|
||||||
var updatedSubscribers []chan<- spMsg
|
delete(sp.recvSubscribers, subscription)
|
||||||
for _, candidate := range sp.recvSubscribers {
|
|
||||||
if candidate != channel {
|
|
||||||
updatedSubscribers = append(updatedSubscribers,
|
|
||||||
candidate)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sp.recvSubscribers = updatedSubscribers
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnWrite is invoked when a peer sends a message and it is used to update
|
// OnWrite is invoked when a peer sends a message and it is used to update
|
||||||
|
@ -683,9 +553,6 @@ type ChainService struct {
|
||||||
timeSource blockchain.MedianTimeSource
|
timeSource blockchain.MedianTimeSource
|
||||||
services wire.ServiceFlag
|
services wire.ServiceFlag
|
||||||
|
|
||||||
cfilterRequests map[cfRequest][]chan *gcs.Filter
|
|
||||||
cfRequestHeaders map[cfRequest][2]*chainhash.Hash
|
|
||||||
|
|
||||||
userAgentName string
|
userAgentName string
|
||||||
userAgentVersion string
|
userAgentVersion string
|
||||||
}
|
}
|
||||||
|
@ -894,8 +761,6 @@ func NewChainService(cfg Config) (*ChainService, error) {
|
||||||
services: Services,
|
services: Services,
|
||||||
userAgentName: UserAgentName,
|
userAgentName: UserAgentName,
|
||||||
userAgentVersion: UserAgentVersion,
|
userAgentVersion: UserAgentVersion,
|
||||||
cfilterRequests: make(map[cfRequest][]chan *gcs.Filter),
|
|
||||||
cfRequestHeaders: make(map[cfRequest][2]*chainhash.Hash),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err := createSPVNS(s.namespace, &s.chainParams)
|
err := createSPVNS(s.namespace, &s.chainParams)
|
||||||
|
@ -1686,10 +1551,238 @@ func (s *ChainService) IsCurrent() bool {
|
||||||
return s.blockManager.IsCurrent()
|
return s.blockManager.IsCurrent()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Query options can be modified per-query, unlike global options.
|
||||||
|
// TODO: Make more query options that override global options.
|
||||||
|
type queryOptions struct {
|
||||||
|
// timeout lets the query know how long to wait for a peer to
|
||||||
|
// answer the query before moving onto the next peer.
|
||||||
|
timeout time.Duration
|
||||||
|
|
||||||
|
// numRetries tells the query how many times to retry asking each peer
|
||||||
|
// the query.
|
||||||
|
numRetries uint8
|
||||||
|
|
||||||
|
// doneChan lets the query signal the caller when it's done, in case
|
||||||
|
// it's run in a goroutine.
|
||||||
|
doneChan chan<- struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryOption is a functional option argument to any of the network query
|
||||||
|
// methods, such as GetBlockFromNetwork and GetCFilter (when that resorts to a
|
||||||
|
// network query).
|
||||||
|
type QueryOption func(*queryOptions)
|
||||||
|
|
||||||
|
// defaultQueryOptions returns a queryOptions set to package-level defaults.
|
||||||
|
func defaultQueryOptions() *queryOptions {
|
||||||
|
return &queryOptions{
|
||||||
|
timeout: QueryTimeout,
|
||||||
|
numRetries: uint8(QueryNumRetries),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timeout is a query option that lets the query know how long to wait for
|
||||||
|
// each peer we ask the query to answer it before moving on.
|
||||||
|
func Timeout(timeout time.Duration) QueryOption {
|
||||||
|
return func(qo *queryOptions) {
|
||||||
|
qo.timeout = timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NumRetries is a query option that lets the query know the maximum number of
|
||||||
|
// times each peer should be queried. The default is one.
|
||||||
|
func NumRetries(numRetries uint8) QueryOption {
|
||||||
|
return func(qo *queryOptions) {
|
||||||
|
qo.numRetries = numRetries
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoneChan allows the caller to pass a channel that will get closed when the
|
||||||
|
// query is finished.
|
||||||
|
func DoneChan(doneChan chan<- struct{}) QueryOption {
|
||||||
|
return func(qo *queryOptions) {
|
||||||
|
qo.doneChan = doneChan
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type spMsg struct {
|
||||||
|
sp *serverPeer
|
||||||
|
msg wire.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
type spMsgSubscription struct {
|
||||||
|
msgChan chan<- spMsg
|
||||||
|
quitChan <-chan struct{}
|
||||||
|
wg *sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
// queryPeers is a helper function that sends a query to one or more peers and
|
||||||
|
// waits for an answer. The timeout for queries is set by the QueryTimeout
|
||||||
|
// package-level variable.
|
||||||
|
func (s *ChainService) queryPeers(
|
||||||
|
// queryMsg is the message to send to each peer selected by selectPeer.
|
||||||
|
queryMsg wire.Message,
|
||||||
|
// checkResponse is caled for every message within the timeout period.
|
||||||
|
// The quit channel lets the query know to terminate because the
|
||||||
|
// required response has been found. This is done by closing the
|
||||||
|
// channel.
|
||||||
|
checkResponse func(sp *serverPeer, resp wire.Message,
|
||||||
|
quit chan<- struct{}),
|
||||||
|
// options takes functional options for executing the query.
|
||||||
|
options ...QueryOption,
|
||||||
|
) {
|
||||||
|
qo := defaultQueryOptions()
|
||||||
|
for _, option := range options {
|
||||||
|
option(qo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is done in a single-threaded query because the peerState is held
|
||||||
|
// in a single thread. This is the only part of the query framework that
|
||||||
|
// requires access to peerState, so it's done once per query.
|
||||||
|
peers := s.Peers()
|
||||||
|
|
||||||
|
// This will be shared state between the per-peer goroutines.
|
||||||
|
quit := make(chan struct{})
|
||||||
|
allQuit := make(chan struct{})
|
||||||
|
startQuery := make(chan struct{})
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
// Increase this number to be able to handle more queries at once as
|
||||||
|
// each channel gets results for all queries, otherwise messages can
|
||||||
|
// get mixed and there's a vicious cycle of retries causing a bigger
|
||||||
|
// message flood, more of which get missed.
|
||||||
|
msgChan := make(chan spMsg)
|
||||||
|
var subwg sync.WaitGroup
|
||||||
|
subscription := spMsgSubscription{
|
||||||
|
msgChan: msgChan,
|
||||||
|
quitChan: allQuit,
|
||||||
|
wg: &subwg,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start a goroutine for each peer that potentially queries that peer.
|
||||||
|
for _, sp := range peers {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(sp *serverPeer) {
|
||||||
|
numRetries := qo.numRetries
|
||||||
|
defer wg.Done()
|
||||||
|
defer sp.unsubscribeRecvMsgs(subscription)
|
||||||
|
// Should we do this when the goroutine gets a message
|
||||||
|
// via startQuery rather than at the launch of the
|
||||||
|
// goroutine?
|
||||||
|
if !sp.Connected() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
timeout := make(<-chan time.Time)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
// After timeout, we try to notify
|
||||||
|
// another of our peer goroutines to
|
||||||
|
// do a query until we get a signal to
|
||||||
|
// quit.
|
||||||
|
select {
|
||||||
|
case startQuery <- struct{}{}:
|
||||||
|
case <-quit:
|
||||||
|
return
|
||||||
|
case <-allQuit:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// At this point, we've sent startQuery.
|
||||||
|
// We return if we've run through this
|
||||||
|
// section of code numRetries times.
|
||||||
|
if numRetries--; numRetries == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-quit:
|
||||||
|
// After we're told to quit, we return.
|
||||||
|
return
|
||||||
|
case <-allQuit:
|
||||||
|
// After we're told to quit, we return.
|
||||||
|
return
|
||||||
|
case <-startQuery:
|
||||||
|
// We're the lucky peer whose turn it is
|
||||||
|
// to try to answer the current query.
|
||||||
|
// TODO: Fix this to support either
|
||||||
|
// querying *all* peers simultaneously
|
||||||
|
// to avoid timeout delays, or starting
|
||||||
|
// with the syncPeer when not querying
|
||||||
|
// *all* peers.
|
||||||
|
sp.subscribeRecvMsg(subscription)
|
||||||
|
// Don't want the peer hanging on send
|
||||||
|
// to the channel if we quit before
|
||||||
|
// reading the channel.
|
||||||
|
sentChan := make(chan struct{}, 1)
|
||||||
|
sp.QueueMessage(queryMsg, sentChan)
|
||||||
|
select {
|
||||||
|
case <-sentChan:
|
||||||
|
case <-quit:
|
||||||
|
return
|
||||||
|
case <-allQuit:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
timeout = time.After(qo.timeout)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(sp)
|
||||||
|
}
|
||||||
|
startQuery <- struct{}{}
|
||||||
|
|
||||||
|
// This goroutine will wait until all of the peer-query goroutines have
|
||||||
|
// terminated, and then initiate a query shutdown.
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
// If we timed out on each goroutine and didn't quit or time out
|
||||||
|
// on the main goroutine, make sure our main goroutine knows to
|
||||||
|
// quit.
|
||||||
|
select {
|
||||||
|
case <-allQuit:
|
||||||
|
default:
|
||||||
|
close(allQuit)
|
||||||
|
}
|
||||||
|
// Close the done channel, if any
|
||||||
|
if qo.doneChan != nil {
|
||||||
|
close(qo.doneChan)
|
||||||
|
}
|
||||||
|
// Wait until all goroutines started by subscriptions have
|
||||||
|
// exited after we closed allQuit before letting the message
|
||||||
|
// channel get garbage collected.
|
||||||
|
subwg.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Loop for any messages sent to us via our subscription channel and
|
||||||
|
// check them for whether they satisfy the query. Break the loop if it's
|
||||||
|
// time to quit.
|
||||||
|
timeout := time.After(time.Duration(len(peers)+1) *
|
||||||
|
qo.timeout * time.Duration(qo.numRetries))
|
||||||
|
checkResponses:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
// When we time out, close the allQuit channel
|
||||||
|
// if it hasn't already been closed.
|
||||||
|
select {
|
||||||
|
case <-allQuit:
|
||||||
|
default:
|
||||||
|
close(allQuit)
|
||||||
|
}
|
||||||
|
break checkResponses
|
||||||
|
case <-quit:
|
||||||
|
break checkResponses
|
||||||
|
case <-allQuit:
|
||||||
|
break checkResponses
|
||||||
|
case sm := <-msgChan:
|
||||||
|
// TODO: This will get stuck if checkResponse
|
||||||
|
// gets stuck. This is a caveat for callers that
|
||||||
|
// should be fixed before exposing this function
|
||||||
|
// for public use.
|
||||||
|
checkResponse(sm.sp, sm.msg, quit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetCFilter gets a cfilter from the database. Failing that, it requests the
|
// GetCFilter gets a cfilter from the database. Failing that, it requests the
|
||||||
// cfilter from the network and writes it to the database.
|
// cfilter from the network and writes it to the database.
|
||||||
func (s *ChainService) GetCFilter(blockHash chainhash.Hash,
|
func (s *ChainService) GetCFilter(blockHash chainhash.Hash,
|
||||||
extended bool) *gcs.Filter {
|
extended bool, options ...QueryOption) *gcs.Filter {
|
||||||
getFilter := s.GetBasicFilter
|
getFilter := s.GetBasicFilter
|
||||||
getHeader := s.GetBasicHeader
|
getHeader := s.GetBasicHeader
|
||||||
putFilter := s.putBasicFilter
|
putFilter := s.putBasicFilter
|
||||||
|
@ -1714,17 +1807,54 @@ func (s *ChainService) GetCFilter(blockHash chainhash.Hash,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
replyChan := make(chan *gcs.Filter)
|
s.queryPeers(
|
||||||
s.query <- getCFilterMsg{
|
// Send a wire.GetCFilterMsg
|
||||||
cfRequest: cfRequest{
|
wire.NewMsgGetCFilter(&blockHash, extended),
|
||||||
blockHash: blockHash,
|
// Check responses and if we get one that matches,
|
||||||
extended: extended,
|
// end the query early.
|
||||||
|
func(sp *serverPeer, resp wire.Message,
|
||||||
|
quit chan<- struct{}) {
|
||||||
|
switch response := resp.(type) {
|
||||||
|
// We're only interested in "cfilter" messages.
|
||||||
|
case *wire.MsgCFilter:
|
||||||
|
if len(response.Data) < 4 {
|
||||||
|
// Filter data is too short.
|
||||||
|
// Ignore this message.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if blockHash != response.BlockHash {
|
||||||
|
// The response doesn't match our
|
||||||
|
// request. Ignore this message.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
gotFilter, err :=
|
||||||
|
gcs.FromNBytes(builder.DefaultP,
|
||||||
|
response.Data)
|
||||||
|
if err != nil {
|
||||||
|
// Malformed filter data. We
|
||||||
|
// can ignore this message.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if MakeHeaderForFilter(gotFilter,
|
||||||
|
*prevHeader) !=
|
||||||
|
*curHeader {
|
||||||
|
// Filter data doesn't match
|
||||||
|
// the headers we know about.
|
||||||
|
// Ignore this response.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// At this point, the filter matches
|
||||||
|
// what we know about it and we declare
|
||||||
|
// it sane. We can kill the query and
|
||||||
|
// pass the response back to the caller.
|
||||||
|
close(quit)
|
||||||
|
filter = gotFilter
|
||||||
|
default:
|
||||||
|
}
|
||||||
},
|
},
|
||||||
prevHeader: prevHeader,
|
options...,
|
||||||
curHeader: curHeader,
|
)
|
||||||
reply: replyChan,
|
// If we've found a filter, write it to the database for next time.
|
||||||
}
|
|
||||||
filter = <-replyChan
|
|
||||||
if filter != nil {
|
if filter != nil {
|
||||||
putFilter(blockHash, filter)
|
putFilter(blockHash, filter)
|
||||||
log.Tracef("Wrote filter for block %s, extended: %t",
|
log.Tracef("Wrote filter for block %s, extended: %t",
|
||||||
|
@ -1736,20 +1866,72 @@ func (s *ChainService) GetCFilter(blockHash chainhash.Hash,
|
||||||
// GetBlockFromNetwork gets a block by requesting it from the network, one peer
|
// GetBlockFromNetwork gets a block by requesting it from the network, one peer
|
||||||
// at a time, until one answers.
|
// at a time, until one answers.
|
||||||
func (s *ChainService) GetBlockFromNetwork(
|
func (s *ChainService) GetBlockFromNetwork(
|
||||||
blockHash chainhash.Hash) *btcutil.Block {
|
blockHash chainhash.Hash, options ...QueryOption) *btcutil.Block {
|
||||||
blockHeader, height, err := s.GetBlockByHash(blockHash)
|
blockHeader, height, err := s.GetBlockByHash(blockHash)
|
||||||
if err != nil || blockHeader.BlockHash() != blockHash {
|
if err != nil || blockHeader.BlockHash() != blockHash {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
replyChan := make(chan *btcutil.Block)
|
getData := wire.NewMsgGetData()
|
||||||
s.query <- getBlockMsg{
|
getData.AddInvVect(wire.NewInvVect(wire.InvTypeBlock,
|
||||||
blockHeader: &blockHeader,
|
&blockHash))
|
||||||
height: height,
|
// The block is only updated from the checkResponse function argument,
|
||||||
reply: replyChan,
|
// which is always called single-threadedly. We don't check the block
|
||||||
}
|
// until after the query is finished, so we can just write to it
|
||||||
block := <-replyChan
|
// naively.
|
||||||
if block != nil {
|
var foundBlock *btcutil.Block
|
||||||
log.Tracef("Got block %s from network", blockHash)
|
s.queryPeers(
|
||||||
}
|
// Send a wire.GetCFilterMsg
|
||||||
return block
|
getData,
|
||||||
|
// Check responses and if we get one that matches,
|
||||||
|
// end the query early.
|
||||||
|
func(sp *serverPeer, resp wire.Message,
|
||||||
|
quit chan<- struct{}) {
|
||||||
|
switch response := resp.(type) {
|
||||||
|
// We're only interested in "block" messages.
|
||||||
|
case *wire.MsgBlock:
|
||||||
|
// If this isn't our block, ignore it.
|
||||||
|
if response.BlockHash() !=
|
||||||
|
blockHash {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
block := btcutil.NewBlock(response)
|
||||||
|
// Only set height if btcutil hasn't
|
||||||
|
// automagically put one in.
|
||||||
|
if block.Height() ==
|
||||||
|
btcutil.BlockHeightUnknown {
|
||||||
|
block.SetHeight(
|
||||||
|
int32(height))
|
||||||
|
}
|
||||||
|
// If this claims our block but doesn't
|
||||||
|
// pass the sanity check, the peer is
|
||||||
|
// trying to bamboozle us. Disconnect
|
||||||
|
// it.
|
||||||
|
if err := blockchain.CheckBlockSanity(
|
||||||
|
block,
|
||||||
|
// We don't need to check PoW
|
||||||
|
// because by the time we get
|
||||||
|
// here, it's been checked
|
||||||
|
// during header synchronization
|
||||||
|
s.chainParams.PowLimit,
|
||||||
|
s.timeSource,
|
||||||
|
); err != nil {
|
||||||
|
log.Warnf("Invalid block for %s "+
|
||||||
|
"received from %s -- "+
|
||||||
|
"disconnecting peer", blockHash,
|
||||||
|
sp.Addr())
|
||||||
|
sp.Disconnect()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// At this point, the block matches what we know
|
||||||
|
// about it and we declare it sane. We can kill
|
||||||
|
// the query and pass the response back to the
|
||||||
|
// caller.
|
||||||
|
close(quit)
|
||||||
|
foundBlock = block
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
},
|
||||||
|
options...,
|
||||||
|
)
|
||||||
|
return foundBlock
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/rand"
|
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -24,14 +23,32 @@ import (
|
||||||
_ "github.com/btcsuite/btcwallet/walletdb/bdb"
|
_ "github.com/btcsuite/btcwallet/walletdb/bdb"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var (
|
||||||
logLevel = btclog.TraceLvl
|
logLevel = btclog.TraceLvl
|
||||||
syncTimeout = 30 * time.Second
|
syncTimeout = 30 * time.Second
|
||||||
syncUpdate = time.Second
|
syncUpdate = time.Second
|
||||||
numTestBlocks = 50
|
// Don't set this too high for your platform, or the tests will miss
|
||||||
|
// messages.
|
||||||
|
// TODO: Make this a benchmark instead.
|
||||||
|
numQueryThreads = 50
|
||||||
|
queryOptions = []spvchain.QueryOption{
|
||||||
|
//spvchain.NumRetries(5),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSetup(t *testing.T) {
|
func TestSetup(t *testing.T) {
|
||||||
|
// Set up logging.
|
||||||
|
logger, err := btctestlog.NewTestLogger(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Could not set up logger: %s", err)
|
||||||
|
}
|
||||||
|
chainLogger := btclog.NewSubsystemLogger(logger, "CHAIN: ")
|
||||||
|
chainLogger.SetLevel(logLevel)
|
||||||
|
spvchain.UseLogger(chainLogger)
|
||||||
|
rpcLogger := btclog.NewSubsystemLogger(logger, "RPCC: ")
|
||||||
|
rpcLogger.SetLevel(logLevel)
|
||||||
|
btcrpcclient.UseLogger(rpcLogger)
|
||||||
|
|
||||||
// Create a btcd SimNet node and generate 500 blocks
|
// Create a btcd SimNet node and generate 500 blocks
|
||||||
h1, err := rpctest.New(&chaincfg.SimNetParams, nil, nil)
|
h1, err := rpctest.New(&chaincfg.SimNetParams, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -147,16 +164,6 @@ func TestSetup(t *testing.T) {
|
||||||
spvchain.BanDuration = 5 * time.Second
|
spvchain.BanDuration = 5 * time.Second
|
||||||
spvchain.RequiredServices = wire.SFNodeNetwork
|
spvchain.RequiredServices = wire.SFNodeNetwork
|
||||||
spvchain.WaitForMoreCFHeaders = time.Second
|
spvchain.WaitForMoreCFHeaders = time.Second
|
||||||
logger, err := btctestlog.NewTestLogger(t)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Could not set up logger: %s", err)
|
|
||||||
}
|
|
||||||
chainLogger := btclog.NewSubsystemLogger(logger, "CHAIN: ")
|
|
||||||
chainLogger.SetLevel(logLevel)
|
|
||||||
spvchain.UseLogger(chainLogger)
|
|
||||||
rpcLogger := btclog.NewSubsystemLogger(logger, "RPCC: ")
|
|
||||||
rpcLogger.SetLevel(logLevel)
|
|
||||||
btcrpcclient.UseLogger(rpcLogger)
|
|
||||||
svc, err := spvchain.NewChainService(config)
|
svc, err := spvchain.NewChainService(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error creating ChainService: %s", err)
|
t.Fatalf("Error creating ChainService: %s", err)
|
||||||
|
@ -170,6 +177,13 @@ func TestSetup(t *testing.T) {
|
||||||
t.Fatalf("Couldn't sync ChainService: %s", err)
|
t.Fatalf("Couldn't sync ChainService: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that we can get blocks and cfilters via P2P and decide which are
|
||||||
|
// valid and which aren't.
|
||||||
|
err = testRandomBlocks(t, svc, h1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Testing blocks and cfilters failed: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Generate 125 blocks on h1 to make sure it reorgs the other nodes.
|
// Generate 125 blocks on h1 to make sure it reorgs the other nodes.
|
||||||
// Ensure the ChainService instance stays caught up.
|
// Ensure the ChainService instance stays caught up.
|
||||||
h1.Node.Generate(125)
|
h1.Node.Generate(125)
|
||||||
|
@ -209,13 +223,6 @@ func TestSetup(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Couldn't sync ChainService: %s", err)
|
t.Fatalf("Couldn't sync ChainService: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that we can get blocks and cfilters via P2P and decide which are
|
|
||||||
// valid and which aren't.
|
|
||||||
err = testRandomBlocks(t, svc, h1)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Testing blocks and cfilters failed: %s", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// csd does a connect-sync-disconnect between nodes in order to support
|
// csd does a connect-sync-disconnect between nodes in order to support
|
||||||
|
@ -377,7 +384,7 @@ func waitForSync(t *testing.T, svc *spvchain.ChainService,
|
||||||
// can correctly get filters from them. We don't go through *all* the blocks
|
// can correctly get filters from them. We don't go through *all* the blocks
|
||||||
// because it can be a little slow, but we'll improve that soon-ish hopefully
|
// because it can be a little slow, but we'll improve that soon-ish hopefully
|
||||||
// to the point where we can do it.
|
// to the point where we can do it.
|
||||||
// TODO: Improve concurrency on framework side.
|
// TODO: Make this a benchmark instead.
|
||||||
func testRandomBlocks(t *testing.T, svc *spvchain.ChainService,
|
func testRandomBlocks(t *testing.T, svc *spvchain.ChainService,
|
||||||
correctSyncNode *rpctest.Harness) error {
|
correctSyncNode *rpctest.Harness) error {
|
||||||
var haveBest *waddrmgr.BlockStamp
|
var haveBest *waddrmgr.BlockStamp
|
||||||
|
@ -386,180 +393,200 @@ func testRandomBlocks(t *testing.T, svc *spvchain.ChainService,
|
||||||
return fmt.Errorf("Couldn't get best snapshot from "+
|
return fmt.Errorf("Couldn't get best snapshot from "+
|
||||||
"ChainService: %s", err)
|
"ChainService: %s", err)
|
||||||
}
|
}
|
||||||
// Keep track of an error channel
|
// Keep track of an error channel with enough buffer space to track one
|
||||||
errChan := make(chan error)
|
// error per block.
|
||||||
var lastErr error
|
errChan := make(chan error, haveBest.Height)
|
||||||
go func() {
|
// Test getting all of the blocks and filters.
|
||||||
for err := range errChan {
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("%s", err)
|
|
||||||
lastErr = fmt.Errorf("Couldn't validate all " +
|
|
||||||
"blocks, filters, and filter headers.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
// Test getting numTestBlocks random blocks and filters.
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
heights := rand.Perm(int(haveBest.Height))
|
workerQueue := make(chan struct{}, numQueryThreads)
|
||||||
for i := 0; i < numTestBlocks; i++ {
|
for i := int32(1); i <= haveBest.Height; i++ {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
height := uint32(heights[i])
|
height := uint32(i)
|
||||||
|
// Wait until there's room in the worker queue.
|
||||||
|
workerQueue <- struct{}{}
|
||||||
go func() {
|
go func() {
|
||||||
|
// On exit, open a spot in workerQueue and tell the
|
||||||
|
// wait group we're done.
|
||||||
|
defer func() {
|
||||||
|
<-workerQueue
|
||||||
|
}()
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
// Get block header from database.
|
// Get block header from database.
|
||||||
blockHeader, blockHeight, err := svc.GetBlockByHeight(height)
|
blockHeader, blockHeight, err := svc.GetBlockByHeight(
|
||||||
|
height)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("Couldn't get block "+
|
errChan <- fmt.Errorf("Couldn't get block "+
|
||||||
"header by height %d: %s", height, err)
|
"header by height %d: %s", height, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if blockHeight != height {
|
if blockHeight != height {
|
||||||
errChan <- fmt.Errorf("Block height retrieved from DB "+
|
errChan <- fmt.Errorf("Block height retrieved "+
|
||||||
"doesn't match expected height. Want: %d, "+
|
"from DB doesn't match expected "+
|
||||||
"have: %d", height, blockHeight)
|
"height. Want: %d, have: %d", height,
|
||||||
|
blockHeight)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
blockHash := blockHeader.BlockHash()
|
blockHash := blockHeader.BlockHash()
|
||||||
// Get block via RPC.
|
// Get block via RPC.
|
||||||
wantBlock, err := correctSyncNode.Node.GetBlock(&blockHash)
|
wantBlock, err := correctSyncNode.Node.GetBlock(
|
||||||
|
&blockHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("Couldn't get block %d (%s) by RPC",
|
errChan <- fmt.Errorf("Couldn't get block %d "+
|
||||||
height, blockHash)
|
"(%s) by RPC", height, blockHash)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Get block from network.
|
// Get block from network.
|
||||||
haveBlock := svc.GetBlockFromNetwork(blockHash)
|
haveBlock := svc.GetBlockFromNetwork(blockHash,
|
||||||
|
queryOptions...)
|
||||||
if haveBlock == nil {
|
if haveBlock == nil {
|
||||||
errChan <- fmt.Errorf("Couldn't get block %d (%s) from"+
|
errChan <- fmt.Errorf("Couldn't get block %d "+
|
||||||
"network", height, blockHash)
|
"(%s) from network", height, blockHash)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Check that network and RPC blocks match.
|
// Check that network and RPC blocks match.
|
||||||
if !reflect.DeepEqual(*haveBlock.MsgBlock(), *wantBlock) {
|
if !reflect.DeepEqual(*haveBlock.MsgBlock(),
|
||||||
errChan <- fmt.Errorf("Block from network doesn't match "+
|
*wantBlock) {
|
||||||
"block from RPC. Want: %s, RPC: %s, network: "+
|
errChan <- fmt.Errorf("Block from network "+
|
||||||
"%s", blockHash, wantBlock.BlockHash(),
|
"doesn't match block from RPC. Want: "+
|
||||||
|
"%s, RPC: %s, network: %s", blockHash,
|
||||||
|
wantBlock.BlockHash(),
|
||||||
haveBlock.MsgBlock().BlockHash())
|
haveBlock.MsgBlock().BlockHash())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Check that block height matches what we have.
|
// Check that block height matches what we have.
|
||||||
if int32(blockHeight) != haveBlock.Height() {
|
if int32(blockHeight) != haveBlock.Height() {
|
||||||
errChan <- fmt.Errorf("Block height from network doesn't "+
|
errChan <- fmt.Errorf("Block height from "+
|
||||||
"match expected height. Want: %s, network: %s",
|
"network doesn't match expected "+
|
||||||
|
"height. Want: %s, network: %s",
|
||||||
blockHeight, haveBlock.Height())
|
blockHeight, haveBlock.Height())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Get basic cfilter from network.
|
// Get basic cfilter from network.
|
||||||
haveFilter := svc.GetCFilter(blockHash, false)
|
haveFilter := svc.GetCFilter(blockHash, false,
|
||||||
|
queryOptions...)
|
||||||
if haveFilter == nil {
|
if haveFilter == nil {
|
||||||
errChan <- fmt.Errorf("Couldn't get basic "+
|
errChan <- fmt.Errorf("Couldn't get basic "+
|
||||||
"filter for block %d", height)
|
"filter for block %d", height)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Get basic cfilter from RPC.
|
// Get basic cfilter from RPC.
|
||||||
wantFilter, err := correctSyncNode.Node.GetCFilter(&blockHash,
|
wantFilter, err := correctSyncNode.Node.GetCFilter(
|
||||||
false)
|
&blockHash, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("Couldn't get basic filter for "+
|
errChan <- fmt.Errorf("Couldn't get basic "+
|
||||||
"block %d via RPC: %s", height, err)
|
"filter for block %d via RPC: %s",
|
||||||
|
height, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Check that network and RPC cfilters match.
|
// Check that network and RPC cfilters match.
|
||||||
if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) {
|
if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) {
|
||||||
errChan <- fmt.Errorf("Basic filter from P2P network/DB"+
|
errChan <- fmt.Errorf("Basic filter from P2P "+
|
||||||
" doesn't match RPC value for block %d", height)
|
"network/DB doesn't match RPC value "+
|
||||||
|
"for block %d", height)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Calculate basic filter from block.
|
// Calculate basic filter from block.
|
||||||
calcFilter, err := spvchain.BuildBasicFilter(
|
calcFilter, err := spvchain.BuildBasicFilter(
|
||||||
haveBlock.MsgBlock())
|
haveBlock.MsgBlock())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("Couldn't build basic filter for "+
|
errChan <- fmt.Errorf("Couldn't build basic "+
|
||||||
"block %d (%s): %s", height, blockHash, err)
|
"filter for block %d (%s): %s", height,
|
||||||
|
blockHash, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Check that the network value matches the calculated value
|
// Check that the network value matches the calculated
|
||||||
// from the block.
|
// value from the block.
|
||||||
if !reflect.DeepEqual(*haveFilter, *calcFilter) {
|
if !reflect.DeepEqual(*haveFilter, *calcFilter) {
|
||||||
errChan <- fmt.Errorf("Basic filter from P2P network/DB "+
|
errChan <- fmt.Errorf("Basic filter from P2P "+
|
||||||
"doesn't match calculated value for block %d",
|
"network/DB doesn't match calculated "+
|
||||||
height)
|
"value for block %d", height)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Get previous basic filter header from the database.
|
// Get previous basic filter header from the database.
|
||||||
prevHeader, err := svc.GetBasicHeader(blockHeader.PrevBlock)
|
prevHeader, err := svc.GetBasicHeader(
|
||||||
|
blockHeader.PrevBlock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("Couldn't get basic filter header "+
|
errChan <- fmt.Errorf("Couldn't get basic "+
|
||||||
"for block %d (%s) from DB: %s", height-1,
|
"filter header for block %d (%s) from "+
|
||||||
|
"DB: %s", height-1,
|
||||||
blockHeader.PrevBlock, err)
|
blockHeader.PrevBlock, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Get current basic filter header from the database.
|
// Get current basic filter header from the database.
|
||||||
curHeader, err := svc.GetBasicHeader(blockHash)
|
curHeader, err := svc.GetBasicHeader(blockHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("Couldn't get basic filter header "+
|
errChan <- fmt.Errorf("Couldn't get basic "+
|
||||||
"for block %d (%s) from DB: %s", height-1,
|
"filter header for block %d (%s) from "+
|
||||||
blockHash, err)
|
"DB: %s", height-1, blockHash, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Check that the filter and header line up.
|
// Check that the filter and header line up.
|
||||||
calcHeader := spvchain.MakeHeaderForFilter(calcFilter,
|
calcHeader := spvchain.MakeHeaderForFilter(calcFilter,
|
||||||
*prevHeader)
|
*prevHeader)
|
||||||
if !bytes.Equal(curHeader[:], calcHeader[:]) {
|
if !bytes.Equal(curHeader[:], calcHeader[:]) {
|
||||||
errChan <- fmt.Errorf("Filter header doesn't match. Want: "+
|
errChan <- fmt.Errorf("Filter header doesn't "+
|
||||||
"%s, got: %s", curHeader, calcHeader)
|
"match. Want: %s, got: %s", curHeader,
|
||||||
|
calcHeader)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Get extended cfilter from network
|
// Get extended cfilter from network
|
||||||
haveFilter = svc.GetCFilter(blockHash, true)
|
haveFilter = svc.GetCFilter(blockHash, true,
|
||||||
|
queryOptions...)
|
||||||
if haveFilter == nil {
|
if haveFilter == nil {
|
||||||
errChan <- fmt.Errorf("Couldn't get extended "+
|
errChan <- fmt.Errorf("Couldn't get extended "+
|
||||||
"filter for block %d", height)
|
"filter for block %d", height)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Get extended cfilter from RPC
|
// Get extended cfilter from RPC
|
||||||
wantFilter, err = correctSyncNode.Node.GetCFilter(&blockHash,
|
wantFilter, err = correctSyncNode.Node.GetCFilter(
|
||||||
true)
|
&blockHash, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("Couldn't get extended filter for "+
|
errChan <- fmt.Errorf("Couldn't get extended "+
|
||||||
"block %d via RPC: %s", height, err)
|
"filter for block %d via RPC: %s",
|
||||||
|
height, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Check that network and RPC cfilters match
|
// Check that network and RPC cfilters match
|
||||||
if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) {
|
if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) {
|
||||||
errChan <- fmt.Errorf("Extended filter from P2P network/DB"+
|
errChan <- fmt.Errorf("Extended filter from "+
|
||||||
" doesn't match RPC value for block %d", height)
|
"P2P network/DB doesn't match RPC "+
|
||||||
|
"value for block %d", height)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Calculate extended filter from block
|
// Calculate extended filter from block
|
||||||
calcFilter, err = spvchain.BuildExtFilter(
|
calcFilter, err = spvchain.BuildExtFilter(
|
||||||
haveBlock.MsgBlock())
|
haveBlock.MsgBlock())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("Couldn't build extended filter for "+
|
errChan <- fmt.Errorf("Couldn't build extended"+
|
||||||
"block %d (%s): %s", height, blockHash, err)
|
" filter for block %d (%s): %s", height,
|
||||||
|
blockHash, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Check that the network value matches the calculated value
|
// Check that the network value matches the calculated
|
||||||
// from the block.
|
// value from the block.
|
||||||
if !reflect.DeepEqual(*haveFilter, *calcFilter) {
|
if !reflect.DeepEqual(*haveFilter, *calcFilter) {
|
||||||
errChan <- fmt.Errorf("Extended filter from P2P network/DB"+
|
errChan <- fmt.Errorf("Extended filter from "+
|
||||||
" doesn't match calculated value for block %d",
|
"P2P network/DB doesn't match "+
|
||||||
height)
|
"calculated value for block %d", height)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Get previous extended filter header from the database.
|
// Get previous extended filter header from the
|
||||||
prevHeader, err = svc.GetExtHeader(blockHeader.PrevBlock)
|
// database.
|
||||||
|
prevHeader, err = svc.GetExtHeader(
|
||||||
|
blockHeader.PrevBlock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("Couldn't get extended filter header"+
|
errChan <- fmt.Errorf("Couldn't get extended "+
|
||||||
" for block %d (%s) from DB: %s", height-1,
|
"filter header for block %d (%s) from "+
|
||||||
|
"DB: %s", height-1,
|
||||||
blockHeader.PrevBlock, err)
|
blockHeader.PrevBlock, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Get current basic filter header from the database.
|
// Get current basic filter header from the database.
|
||||||
curHeader, err = svc.GetExtHeader(blockHash)
|
curHeader, err = svc.GetExtHeader(blockHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("Couldn't get extended filter header"+
|
errChan <- fmt.Errorf("Couldn't get extended "+
|
||||||
" for block %d (%s) from DB: %s", height-1,
|
"filter header for block %d (%s) from "+
|
||||||
|
"DB: %s", height-1,
|
||||||
blockHash, err)
|
blockHash, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -567,20 +594,29 @@ func testRandomBlocks(t *testing.T, svc *spvchain.ChainService,
|
||||||
calcHeader = spvchain.MakeHeaderForFilter(calcFilter,
|
calcHeader = spvchain.MakeHeaderForFilter(calcFilter,
|
||||||
*prevHeader)
|
*prevHeader)
|
||||||
if !bytes.Equal(curHeader[:], calcHeader[:]) {
|
if !bytes.Equal(curHeader[:], calcHeader[:]) {
|
||||||
errChan <- fmt.Errorf("Filter header doesn't match. Want: "+
|
errChan <- fmt.Errorf("Filter header doesn't "+
|
||||||
"%s, got: %s", curHeader, calcHeader)
|
"match. Want: %s, got: %s", curHeader,
|
||||||
|
calcHeader)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
// Wait for all queries to finish.
|
// Wait for all queries to finish.
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
if logLevel != btclog.Off {
|
|
||||||
t.Logf("Finished checking %d blocks and their cfilters",
|
|
||||||
numTestBlocks)
|
|
||||||
}
|
|
||||||
// Close the error channel to make the error monitoring goroutine
|
// Close the error channel to make the error monitoring goroutine
|
||||||
// finish.
|
// finish.
|
||||||
close(errChan)
|
close(errChan)
|
||||||
|
var lastErr error
|
||||||
|
for err := range errChan {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%s", err)
|
||||||
|
lastErr = fmt.Errorf("Couldn't validate all " +
|
||||||
|
"blocks, filters, and filter headers.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if logLevel != btclog.Off {
|
||||||
|
t.Logf("Finished checking %d blocks and their cfilters",
|
||||||
|
haveBest.Height)
|
||||||
|
}
|
||||||
return lastErr
|
return lastErr
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue