Working cfilter download (slows down tests).

Alex 2017-04-24 21:44:45 -06:00 committed by Olaoluwa Osuntokun
parent 125d47b55c
commit 6a1cb8c846
4 changed files with 476 additions and 100 deletions


@@ -37,6 +37,7 @@ const (
maxTimeOffset = 2 * time.Hour
)
// TODO: Redo this using query API.
var (
// WaitForMoreCFHeaders is a configurable time to wait for CFHeaders
// messages from peers. It defaults to 3 seconds but can be increased
@@ -1380,8 +1381,11 @@ func (b *blockManager) handleCFilterMsg(cfmsg *cfilterMsg) {
}
// Notify the ChainService of the newly-found filter.
b.server.query <- processCFilterMsg{
filter: filter,
extended: cfmsg.cfilter.Extended,
cfRequest: cfRequest{
blockHash: cfmsg.cfilter.BlockHash,
extended: cfmsg.cfilter.Extended,
},
filter: filter,
}
}


@@ -8,8 +8,11 @@ import (
"errors"
"github.com/btcsuite/btcd/addrmgr"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/connmgr"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil/gcs"
"github.com/btcsuite/btcutil/gcs/builder"
)
type getConnCountMsg struct {
@@ -49,11 +52,21 @@ type forAllPeersMsg struct {
closure func(*serverPeer)
}
type processCFilterMsg struct {
filter *gcs.Filter
extended bool
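// getCFilterMsg is sent to the query handler to fetch the committed filter
// for the block identified by the embedded cfRequest from the network.
// prevHeader and curHeader are the filter headers used to verify the
// downloaded filter before it's sent back on the reply channel.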
type getCFilterMsg struct {
cfRequest
prevHeader *chainhash.Hash
curHeader *chainhash.Hash
reply chan *gcs.Filter
}
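// processCFilterMsg is sent by the block manager when a filter arrives from
// a peer so the ChainService can notify any callers waiting on it.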
type processCFilterMsg struct {
cfRequest
filter *gcs.Filter
}
// TODO: General - abstract out more of blockmanager into queries. It'll make
// this way more maintainable and usable.
// handleQuery is the central handler for all queries and commands from other
// goroutines related to peer state.
func (s *ChainService) handleQuery(state *peerState, querymsg interface{}) {
@@ -161,7 +174,107 @@ func (s *ChainService) handleQuery(state *peerState, querymsg interface{}) {
// Even though this is a query, there's no reply channel as the
// forAllPeers method doesn't return anything. An error might be
// useful in the future.
case getCFilterMsg:
found := false
state.queryPeers(
// Should we query this peer?
func(sp *serverPeer) bool {
// Don't send requests to disconnected peers.
return sp.Connected()
},
// Send a wire.MsgGetCFilter message.
wire.NewMsgGetCFilter(&msg.blockHash, msg.extended),
// Check responses and if we get one that matches,
// end the query early.
func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}) {
switch response := resp.(type) {
// We're only interested in "cfilter" messages.
case *wire.MsgCFilter:
if len(response.Data) < 4 {
// Filter data is too short.
// Ignore this message.
return
}
filter, err :=
gcs.FromNBytes(builder.DefaultP,
response.Data)
if err != nil {
// Malformed filter data. We
// can ignore this message.
return
}
if makeHeaderForFilter(filter,
*msg.prevHeader) !=
*msg.curHeader {
// Filter data doesn't match
// the headers we know about.
// Ignore this response.
return
}
// At this point, the filter matches
// what we know about it and we declare
// it sane. We can kill the query and
// pass the response back to the caller.
found = true
close(quit)
msg.reply <- filter
default:
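// Ignore any other message types.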
}
},
)
// We timed out without finding a correct answer to our query.
if !found {
msg.reply <- nil
}
/*sent := false
state.forAllPeers(func(sp *serverPeer) {
// Send to one peer at a time. No use flooding the
// network.
if sent {
return
}
// Don't send to a peer that's not connected.
if !sp.Connected() {
return
}
// Don't send to any peer from which we've already
// requested this cfilter.
if _, ok := sp.requestedCFilters[msg.cfRequest]; ok {
return
}
// Request a cfilter from the peer and mark sent as
// true so we don't ask any other peers unless
// necessary.
err := sp.pushGetCFilterMsg(
&msg.cfRequest.blockHash,
msg.cfRequest.extended)
if err == nil {
sent = true
}
})
if !sent {
msg.reply <- nil
s.signalAllCFilters(msg.cfRequest, nil)
return
}
// Record the required header information against which to check
// the cfilter.
s.cfRequestHeaders[msg.cfRequest] = [2]*chainhash.Hash{
msg.prevHeader,
msg.curHeader,
}*/
case processCFilterMsg:
s.signalAllCFilters(msg.cfRequest, msg.filter)
}
//case processCFilterMsg:
// TODO: make this work
}
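// signalAllCFilters sends the passed filter (which may be nil on failure) to
// every channel waiting on the given cfRequest and then clears the list of
// waiters.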
func (s *ChainService) signalAllCFilters(req cfRequest, filter *gcs.Filter) {
go func() {
for _, replyChan := range s.cfilterRequests[req] {
replyChan <- filter
}
s.cfilterRequests[req] = make([]chan *gcs.Filter, 0)
}()
}


@@ -24,6 +24,8 @@ import (
)
// These are exported variables so they can be changed by users.
// TODO: Export functional options for these as much as possible so they can be
// changed call-to-call.
var (
// ConnectionRetryInterval is the base amount of time to wait in between
// retries when connecting to persistent peers. It is adjusted by the
@@ -60,6 +62,9 @@ var (
// DisableDNSSeed disables getting initial addresses for Bitcoin nodes
// from DNS.
DisableDNSSeed = false
// Timeout specifies how long to wait for a peer to answer a query.
Timeout = time.Second * 5
)
// updatePeerHeightsMsg is a message sent from the blockmanager to the server
@@ -105,6 +110,132 @@ func (ps *peerState) forAllPeers(closure func(sp *serverPeer)) {
ps.forAllOutboundPeers(closure)
}
// Query options can be modified per-query, unlike global options.
// TODO: Make more query options that override global options.
type queryOptions struct {
// queryTimeout lets the query know how long to wait for a peer to
// answer the query before moving on to the next peer.
queryTimeout time.Duration
}
// defaultQueryOptions returns a queryOptions set to package-level defaults.
func defaultQueryOptions() *queryOptions {
return &queryOptions{
queryTimeout: Timeout,
}
}
// QueryTimeout is a query option that lets the query know how long to wait
// for a peer to answer before moving on to the next peer. It defaults to the
// package-level Timeout variable.
func QueryTimeout(timeout time.Duration) func(*queryOptions) {
return func(qo *queryOptions) {
qo.queryTimeout = timeout
}
}
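// spMsg couples a received wire message with the peer it came from so that
// subscribers can tell which peer sent which response.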
type spMsg struct {
sp *serverPeer
msg wire.Message
}
// queryPeers is a helper function that sends a query to one or more peers and
// waits for an answer. The per-peer timeout defaults to the package-level
// Timeout variable and can be overridden with the QueryTimeout option.
func (ps *peerState) queryPeers(
// selectPeer is a closure which decides whether or not to send the
// query to the peer.
selectPeer func(sp *serverPeer) bool,
// queryMsg is the message to send to each peer selected by selectPeer.
queryMsg wire.Message,
// checkResponse is called for every message within the timeout period.
// The quit channel lets the query know to terminate because the
// required response has been found. This is done by closing the
// channel.
checkResponse func(sp *serverPeer, resp wire.Message,
quit chan<- struct{}),
// options takes functional options for executing the query.
options ...func(*queryOptions),
) {
qo := defaultQueryOptions()
for _, option := range options {
option(qo)
}
// This will be shared state between the per-peer goroutines.
quit := make(chan struct{})
startQuery := make(chan struct{})
var wg sync.WaitGroup
channel := make(chan spMsg)
// This goroutine will monitor all messages from all peers until the
// peer goroutines all exit.
go func() {
for {
select {
case <-quit:
close(channel)
ps.forAllPeers(
func(sp *serverPeer) {
sp.unsubscribeRecvMsgs(channel)
})
return
case sm := <-channel:
// TODO: This will get stuck if checkResponse
// gets stuck.
checkResponse(sm.sp, sm.msg, quit)
}
}
}()
// Start a goroutine for each peer that potentially queries that peer.
ps.forAllPeers(func(sp *serverPeer) {
wg.Add(1)
go func() {
defer wg.Done()
if !selectPeer(sp) {
return
}
timeout := make(<-chan time.Time)
for {
select {
case <-timeout:
// After timeout, we return and notify
// another goroutine that we've done so.
// We only send if there's someone left
// to receive.
startQuery <- struct{}{}
return
case <-quit:
// After we're told to quit, we return.
return
case <-startQuery:
// We're the lucky peer whose turn it is
// to try to answer the current query.
// TODO: Fix this to support multiple
// queries at once. For now, we're
// relying on the query handling loop
// to make sure we don't interrupt
// another query. We need broadcast
// support in OnRead to do this right.
sp.subscribeRecvMsg(channel)
sp.QueueMessage(queryMsg, nil)
timeout = time.After(qo.queryTimeout)
default:
}
}
}()
})
startQuery <- struct{}{}
wg.Wait()
// If we timed out and didn't quit, make sure our response monitor
// goroutine knows to quit.
select {
case <-quit:
default:
close(quit)
}
}
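// Illustrative only, not part of this commit: a minimal sketch of how a
// caller inside the query handler (where state is the *peerState) might use
// queryPeers with the QueryTimeout option, pinging all connected peers and
// stopping at the first matching pong. The nonce and the 10-second timeout
// are arbitrary values chosen for the example.
//
//	nonce := uint64(42)
//	found := false
//	state.queryPeers(
//		func(sp *serverPeer) bool {
//			return sp.Connected()
//		},
//		wire.NewMsgPing(nonce),
//		func(sp *serverPeer, resp wire.Message, quit chan<- struct{}) {
//			if pong, ok := resp.(*wire.MsgPong); ok && pong.Nonce == nonce {
//				found = true
//				close(quit)
//			}
//		},
//		QueryTimeout(10*time.Second),
//	)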
// cfhRequest records which cfheaders we've requested, and the order in which
// we've requested them. Since there's no way to associate the cfheaders to the
// actual block hashes based on the cfheaders message to keep it compact, we
@@ -142,6 +273,14 @@ type serverPeer struct {
quit chan struct{}
// The following chans are used to sync blockmanager and server.
blockProcessed chan struct{}
// The following slice of channels is used to subscribe to messages from
// the peer. This allows broadcasting to multiple subscribers at once,
// allowing multiple queries to be in flight to multiple peers at any
// one time. The mutex protects subscribe/unsubscribe operations.
// The sends on these channels WILL NOT block; any messages the channel
// can't accept will be dropped silently.
recvSubscribers []chan<- spMsg
mtxSubscribers sync.RWMutex
}
// newServerPeer returns a new serverPeer instance. The peer needs to be set by
@@ -522,8 +661,46 @@ func (sp *serverPeer) OnAddr(_ *peer.Peer, msg *wire.MsgAddr) {
// OnRead is invoked when a peer receives a message and it is used to update
// the bytes received by the server.
func (sp *serverPeer) OnRead(_ *peer.Peer, bytesRead int, msg wire.Message, err error) {
func (sp *serverPeer) OnRead(_ *peer.Peer, bytesRead int, msg wire.Message,
err error) {
sp.server.AddBytesReceived(uint64(bytesRead))
// Try to send a message to the subscriber channel if it isn't nil, but
// don't block on failure.
sp.mtxSubscribers.RLock()
defer sp.mtxSubscribers.RUnlock()
for _, channel := range sp.recvSubscribers {
if channel != nil {
select {
case channel <- spMsg{
sp: sp,
msg: msg,
}:
default:
}
}
}
}
// subscribeRecvMsg handles adding OnRead subscriptions to the server peer.
func (sp *serverPeer) subscribeRecvMsg(channel chan<- spMsg) {
sp.mtxSubscribers.Lock()
defer sp.mtxSubscribers.Unlock()
sp.recvSubscribers = append(sp.recvSubscribers, channel)
}
// unsubscribeRecvMsgs handles removing OnRead subscriptions from the server
// peer.
func (sp *serverPeer) unsubscribeRecvMsgs(channel chan<- spMsg) {
sp.mtxSubscribers.Lock()
defer sp.mtxSubscribers.Unlock()
var updatedSubscribers []chan<- spMsg
for _, candidate := range sp.recvSubscribers {
if candidate != channel {
updatedSubscribers = append(updatedSubscribers,
candidate)
}
}
sp.recvSubscribers = updatedSubscribers
}
// OnWrite is invoked when a peer sends a message and it is used to update
@@ -556,6 +733,9 @@ type ChainService struct {
timeSource blockchain.MedianTimeSource
services wire.ServiceFlag
cfilterRequests map[cfRequest][]chan *gcs.Filter
cfRequestHeaders map[cfRequest][2]*chainhash.Hash
userAgentName string
userAgentVersion string
}
@@ -764,6 +944,8 @@ func NewChainService(cfg Config) (*ChainService, error) {
services: Services,
userAgentName: UserAgentName,
userAgentVersion: UserAgentVersion,
cfilterRequests: make(map[cfRequest][]chan *gcs.Filter),
cfRequestHeaders: make(map[cfRequest][2]*chainhash.Hash),
}
err := createSPVNS(s.namespace, &s.chainParams)
@@ -1555,3 +1737,50 @@ func (s *ChainService) rollbackToHeight(height uint32) (*waddrmgr.BlockStamp, er
func (s *ChainService) IsCurrent() bool {
return s.blockManager.IsCurrent()
}
// GetCFilter gets a cfilter from the database. Failing that, it requests the
// cfilter from the network and writes it to the database.
func (s *ChainService) GetCFilter(blockHash chainhash.Hash,
extended bool) *gcs.Filter {
getFilter := s.GetBasicFilter
getHeader := s.GetBasicHeader
putFilter := s.putBasicFilter
if extended {
getFilter = s.GetExtFilter
getHeader = s.GetExtHeader
putFilter = s.putExtFilter
}
filter, err := getFilter(blockHash)
if err == nil && filter != nil {
return filter
}
block, _, err := s.GetBlockByHash(blockHash)
if err != nil || block.BlockHash() != blockHash {
return nil
}
curHeader, err := getHeader(blockHash)
if err != nil {
return nil
}
prevHeader, err := getHeader(block.PrevBlock)
if err != nil {
return nil
}
replyChan := make(chan *gcs.Filter)
s.query <- getCFilterMsg{
cfRequest: cfRequest{
blockHash: blockHash,
extended: extended,
},
prevHeader: prevHeader,
curHeader: curHeader,
reply: replyChan,
}
filter = <-replyChan
if filter != nil {
putFilter(blockHash, filter)
log.Tracef("Wrote filter for block %s, extended: %t",
blockHash, extended)
}
return filter
}
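// Illustrative only, not part of this commit: a rough sketch of how client
// code might use GetCFilter to check whether a block is relevant to a
// script. It assumes the gcs/builder DeriveKey helper and the gcs.Filter
// Match method; blockHash and pkScript are placeholders supplied by the
// caller.
//
//	filter := svc.GetCFilter(blockHash, false)
//	if filter != nil {
//		key := builder.DeriveKey(&blockHash)
//		match, err := filter.Match(key, pkScript)
//		if err == nil && match {
//			// The block may contain an output paying to pkScript;
//			// fetch the block to confirm.
//		}
//	}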


@@ -1,17 +1,21 @@
package spvchain_test
import (
"bytes"
"fmt"
"io/ioutil"
"math/rand"
"os"
"testing"
"time"
"github.com/aakselrod/btctestlog"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/rpctest"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btclog"
"github.com/btcsuite/btcrpcclient"
"github.com/btcsuite/btcwallet/spvsvc/spvchain"
"github.com/btcsuite/btcwallet/waddrmgr"
"github.com/btcsuite/btcwallet/walletdb"
@@ -147,6 +151,9 @@ func TestSetup(t *testing.T) {
chainLogger := btclog.NewSubsystemLogger(logger, "CHAIN: ")
chainLogger.SetLevel(logLevel)
spvchain.UseLogger(chainLogger)
rpcLogger := btclog.NewSubsystemLogger(logger, "RPCC: ")
rpcLogger.SetLevel(logLevel)
btcrpcclient.UseLogger(rpcLogger)
svc, err := spvchain.NewChainService(config)
if err != nil {
t.Fatalf("Error creating ChainService: %s", err)
@@ -279,99 +286,122 @@ func waitForSync(t *testing.T, svc *spvchain.ChainService,
return fmt.Errorf("Couldn't get latest extended header from "+
"%s: %s", correctSyncNode.P2PAddress(), err)
}
for total <= syncTimeout {
haveBasicHeader := &chainhash.Hash{}
haveExtHeader := &chainhash.Hash{}
for (*knownBasicHeader.HeaderHashes[0] != *haveBasicHeader) &&
(*knownExtHeader.HeaderHashes[0] != *haveExtHeader) {
if total > syncTimeout {
return fmt.Errorf("Timed out after %v waiting for "+
"cfheaders synchronization.", syncTimeout)
}
haveBasicHeader, _ = svc.GetBasicHeader(*knownBestHash)
haveExtHeader, _ = svc.GetExtHeader(*knownBestHash)
time.Sleep(syncUpdate)
total += syncUpdate
haveBasicHeader, err := svc.GetBasicHeader(*knownBestHash)
if err != nil {
if logLevel != btclog.Off {
t.Logf("Basic header unknown.")
}
continue
}
haveExtHeader, err := svc.GetExtHeader(*knownBestHash)
if err != nil {
if logLevel != btclog.Off {
t.Logf("Extended header unknown.")
}
continue
}
if *knownBasicHeader.HeaderHashes[0] != *haveBasicHeader {
return fmt.Errorf("Known basic header doesn't match "+
"the basic header the ChainService has. Known:"+
" %s, ChainService: %s",
knownBasicHeader.HeaderHashes[0],
haveBasicHeader)
}
if *knownExtHeader.HeaderHashes[0] != *haveExtHeader {
return fmt.Errorf("Known extended header doesn't "+
"match the extended header the ChainService "+
"has. Known: %s, ChainService: %s",
knownExtHeader.HeaderHashes[0], haveExtHeader)
}
// At this point, we know the latest cfheader is stored in the
// ChainService database. We now compare each cfheader the
// harness knows about to what's stored in the ChainService
// database to see if we've missed anything or messed anything
// up.
for i := int32(0); i <= haveBest.Height; i++ {
head, _, err := svc.GetBlockByHeight(uint32(i))
if err != nil {
return fmt.Errorf("Couldn't read block by "+
"height: %s", err)
}
hash := head.BlockHash()
haveBasicHeader, err := svc.GetBasicHeader(hash)
if err != nil {
return fmt.Errorf("Couldn't get basic header "+
"for %d (%s) from DB", i, hash)
}
haveExtHeader, err := svc.GetExtHeader(hash)
if err != nil {
return fmt.Errorf("Couldn't get extended "+
"header for %d (%s) from DB", i, hash)
}
knownBasicHeader, err :=
correctSyncNode.Node.GetCFilterHeader(&hash,
false)
if err != nil {
return fmt.Errorf("Couldn't get basic header "+
"for %d (%s) from node %s", i, hash,
correctSyncNode.P2PAddress())
}
knownExtHeader, err :=
correctSyncNode.Node.GetCFilterHeader(&hash,
true)
if err != nil {
return fmt.Errorf("Couldn't get extended "+
"header for %d (%s) from node %s", i,
hash, correctSyncNode.P2PAddress())
}
if *haveBasicHeader !=
*knownBasicHeader.HeaderHashes[0] {
return fmt.Errorf("Basic header for %d (%s) "+
"doesn't match node %s. DB: %s, node: "+
"%s", i, hash,
correctSyncNode.P2PAddress(),
haveBasicHeader,
knownBasicHeader.HeaderHashes[0])
}
if *haveExtHeader !=
*knownExtHeader.HeaderHashes[0] {
return fmt.Errorf("Extended header for %d (%s)"+
" doesn't match node %s. DB: %s, node:"+
" %s", i, hash,
correctSyncNode.P2PAddress(),
haveExtHeader,
knownExtHeader.HeaderHashes[0])
}
}
if logLevel != btclog.Off {
t.Logf("Synced cfheaders to %d (%s)", haveBest.Height,
haveBest.Hash)
}
return nil
}
return fmt.Errorf("Timeout waiting for cfheaders synchronization after"+
" %v", syncTimeout)
if logLevel != btclog.Off {
t.Logf("Synced cfheaders to %d (%s)", haveBest.Height,
haveBest.Hash)
}
// At this point, we know the latest cfheader is stored in the
// ChainService database. We now compare each cfheader the
// harness knows about to what's stored in the ChainService
// database to see if we've missed anything or messed anything
// up.
for i := int32(0); i <= haveBest.Height; i++ {
head, _, err := svc.GetBlockByHeight(uint32(i))
if err != nil {
return fmt.Errorf("Couldn't read block by "+
"height: %s", err)
}
hash := head.BlockHash()
haveBasicHeader, err = svc.GetBasicHeader(hash)
if err != nil {
return fmt.Errorf("Couldn't get basic header "+
"for %d (%s) from DB", i, hash)
}
haveExtHeader, err = svc.GetExtHeader(hash)
if err != nil {
return fmt.Errorf("Couldn't get extended "+
"header for %d (%s) from DB", i, hash)
}
knownBasicHeader, err =
correctSyncNode.Node.GetCFilterHeader(&hash,
false)
if err != nil {
return fmt.Errorf("Couldn't get basic header "+
"for %d (%s) from node %s", i, hash,
correctSyncNode.P2PAddress())
}
knownExtHeader, err =
correctSyncNode.Node.GetCFilterHeader(&hash,
true)
if err != nil {
return fmt.Errorf("Couldn't get extended "+
"header for %d (%s) from node %s", i,
hash, correctSyncNode.P2PAddress())
}
if *haveBasicHeader !=
*knownBasicHeader.HeaderHashes[0] {
return fmt.Errorf("Basic header for %d (%s) "+
"doesn't match node %s. DB: %s, node: "+
"%s", i, hash,
correctSyncNode.P2PAddress(),
haveBasicHeader,
knownBasicHeader.HeaderHashes[0])
}
if *haveExtHeader !=
*knownExtHeader.HeaderHashes[0] {
return fmt.Errorf("Extended header for %d (%s)"+
" doesn't match node %s. DB: %s, node:"+
" %s", i, hash,
correctSyncNode.P2PAddress(),
haveExtHeader,
knownExtHeader.HeaderHashes[0])
}
}
// Test getting 15 random filters.
heights := rand.Perm(int(haveBest.Height))
for i := 0; i < 15; i++ {
height := uint32(heights[i])
block, _, err := svc.GetBlockByHeight(height)
if err != nil {
return fmt.Errorf("Get block by height %d:"+
" %s", height, err)
}
blockHash := block.BlockHash()
haveFilter := svc.GetCFilter(blockHash, false)
if haveFilter == nil {
return fmt.Errorf("Couldn't get basic "+
"filter for block %d", height)
}
t.Logf("%x", haveFilter.NBytes())
wantFilter, err := correctSyncNode.Node.GetCFilter(&blockHash,
false)
if err != nil {
return fmt.Errorf("Couldn't get basic filter for "+
"block %d via RPC: %s", height, err)
}
if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) {
return fmt.Errorf("Basic filter from P2P network/DB"+
" doesn't match RPC value for block %d", height)
}
haveFilter = svc.GetCFilter(blockHash, true)
if haveFilter == nil {
return fmt.Errorf("Couldn't get extended "+
"filter for block %d", height)
}
t.Logf("%x", haveFilter.NBytes())
wantFilter, err = correctSyncNode.Node.GetCFilter(&blockHash,
true)
if err != nil {
return fmt.Errorf("Couldn't get extended filter for "+
"block %d via RPC: %s", height, err)
}
if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) {
return fmt.Errorf("Extended filter from P2P network/DB"+
" doesn't match RPC value for block %d", height)
}
}
return nil
}