diff --git a/spvsvc/spvchain/blockmanager.go b/spvsvc/spvchain/blockmanager.go index d85d224..e671caf 100644 --- a/spvsvc/spvchain/blockmanager.go +++ b/spvsvc/spvchain/blockmanager.go @@ -37,6 +37,7 @@ const ( maxTimeOffset = 2 * time.Hour ) +// TODO: Redo this using query API. var ( // WaitForMoreCFHeaders is a configurable time to wait for CFHeaders // messages from peers. It defaults to 3 seconds but can be increased @@ -1380,8 +1381,11 @@ func (b *blockManager) handleCFilterMsg(cfmsg *cfilterMsg) { } // Notify the ChainService of the newly-found filter. b.server.query <- processCFilterMsg{ - filter: filter, - extended: cfmsg.cfilter.Extended, + cfRequest: cfRequest{ + blockHash: cfmsg.cfilter.BlockHash, + extended: cfmsg.cfilter.Extended, + }, + filter: filter, } } diff --git a/spvsvc/spvchain/notifications.go b/spvsvc/spvchain/notifications.go index f31dbdf..418a84e 100644 --- a/spvsvc/spvchain/notifications.go +++ b/spvsvc/spvchain/notifications.go @@ -8,8 +8,11 @@ import ( "errors" "github.com/btcsuite/btcd/addrmgr" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/connmgr" + "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil/gcs" + "github.com/btcsuite/btcutil/gcs/builder" ) type getConnCountMsg struct { @@ -49,11 +52,21 @@ type forAllPeersMsg struct { closure func(*serverPeer) } -type processCFilterMsg struct { - filter *gcs.Filter - extended bool +type getCFilterMsg struct { + cfRequest + prevHeader *chainhash.Hash + curHeader *chainhash.Hash + reply chan *gcs.Filter } +type processCFilterMsg struct { + cfRequest + filter *gcs.Filter +} + +// TODO: General - abstract out more of blockmanager into queries. It'll make +// this way more maintainable and usable. + // handleQuery is the central handler for all queries and commands from other // goroutines related to peer state. func (s *ChainService) handleQuery(state *peerState, querymsg interface{}) { @@ -161,7 +174,107 @@ func (s *ChainService) handleQuery(state *peerState, querymsg interface{}) { // Even though this is a query, there's no reply channel as the // forAllPeers method doesn't return anything. An error might be // useful in the future. + case getCFilterMsg: + found := false + state.queryPeers( + // Should we query this peer? + func(sp *serverPeer) bool { + // Don't send requests to disconnected peers. + return sp.Connected() + }, + // Send a wire.GetCFilterMsg + wire.NewMsgGetCFilter(&msg.blockHash, msg.extended), + // Check responses and if we get one that matches, + // end the query early. + func(sp *serverPeer, resp wire.Message, + quit chan<- struct{}) { + switch response := resp.(type) { + // We're only interested in "cfilter" messages. + case *wire.MsgCFilter: + if len(response.Data) < 4 { + // Filter data is too short. + // Ignore this message. + return + } + filter, err := + gcs.FromNBytes(builder.DefaultP, + response.Data) + if err != nil { + // Malformed filter data. We + // can ignore this message. + return + } + if makeHeaderForFilter(filter, + *msg.prevHeader) != + *msg.curHeader { + // Filter data doesn't match + // the headers we know about. + // Ignore this response. + return + } + // At this point, the filter matches + // what we know about it and we declare + // it sane. We can kill the query and + // pass the response back to the caller. + found = true + close(quit) + msg.reply <- filter + default: + } + }, + ) + // We timed out without finding a correct answer to our query. + if !found { + msg.reply <- nil + } + /*sent := false + state.forAllPeers(func(sp *serverPeer) { + // Send to one peer at a time. No use flooding the + // network. + if sent { + return + } + // Don't send to a peer that's not connected. + if !sp.Connected() { + return + } + // Don't send to any peer from which we've already + // requested this cfilter. + if _, ok := sp.requestedCFilters[msg.cfRequest]; ok { + return + } + // Request a cfilter from the peer and mark sent as + // true so we don't ask any other peers unless + // necessary. + err := sp.pushGetCFilterMsg( + &msg.cfRequest.blockHash, + msg.cfRequest.extended) + if err == nil { + sent = true + } + + }) + if !sent { + msg.reply <- nil + s.signalAllCFilters(msg.cfRequest, nil) + return + } + // Record the required header information against which to check + // the cfilter. + s.cfRequestHeaders[msg.cfRequest] = [2]*chainhash.Hash{ + msg.prevHeader, + msg.curHeader, + }*/ + case processCFilterMsg: + s.signalAllCFilters(msg.cfRequest, msg.filter) } - //case processCFilterMsg: - // TODO: make this work +} + +func (s *ChainService) signalAllCFilters(req cfRequest, filter *gcs.Filter) { + go func() { + for _, replyChan := range s.cfilterRequests[req] { + replyChan <- filter + } + s.cfilterRequests[req] = make([]chan *gcs.Filter, 0) + }() } diff --git a/spvsvc/spvchain/spvchain.go b/spvsvc/spvchain/spvchain.go index 73e2353..5014776 100644 --- a/spvsvc/spvchain/spvchain.go +++ b/spvsvc/spvchain/spvchain.go @@ -24,6 +24,8 @@ import ( ) // These are exported variables so they can be changed by users. +// TODO: Export functional options for these as much as possible so they can be +// changed call-to-call. var ( // ConnectionRetryInterval is the base amount of time to wait in between // retries when connecting to persistent peers. It is adjusted by the @@ -60,6 +62,9 @@ var ( // DisableDNSSeed disables getting initial addresses for Bitcoin nodes // from DNS. DisableDNSSeed = false + + // Timeout specifies how long to wait for a peer to answer a query. + Timeout = time.Second * 5 ) // updatePeerHeightsMsg is a message sent from the blockmanager to the server @@ -105,6 +110,132 @@ func (ps *peerState) forAllPeers(closure func(sp *serverPeer)) { ps.forAllOutboundPeers(closure) } +// Query options can be modified per-query, unlike global options. +// TODO: Make more query options that override global options. +type queryOptions struct { + // queryTimeout lets the query know how long to wait for a peer to + // answer the query before moving onto the next peer. + queryTimeout time.Duration +} + +// defaultQueryOptions returns a queryOptions set to package-level defaults. +func defaultQueryOptions() *queryOptions { + return &queryOptions{ + queryTimeout: Timeout, + } +} + +// QueryTimeout is a query option that lets the query know to ask each peer we're +// connected to for its opinion, if any. By default, we only ask peers until one +// gives us a valid response. +func QueryTimeout(timeout time.Duration) func(*queryOptions) { + return func(qo *queryOptions) { + qo.queryTimeout = timeout + } +} + +type spMsg struct { + sp *serverPeer + msg wire.Message +} + +// queryPeers is a helper function that sends a query to one or more peers and +// waits for an answer. The timeout for queries is set by the QueryTimeout +// package-level variable. +func (ps *peerState) queryPeers( + // selectPeer is a closure which decides whether or not to send the + // query to the peer. + selectPeer func(sp *serverPeer) bool, + // queryMsg is the message to send to each peer selected by selectPeer. + queryMsg wire.Message, + // checkResponse is caled for every message within the timeout period. + // The quit channel lets the query know to terminate because the + // required response has been found. This is done by closing the + // channel. + checkResponse func(sp *serverPeer, resp wire.Message, + quit chan<- struct{}), + // options takes functional options for executing the query. + options ...func(*queryOptions), +) { + qo := defaultQueryOptions() + for _, option := range options { + option(qo) + } + // This will be shared state between the per-peer goroutines. + quit := make(chan struct{}) + startQuery := make(chan struct{}) + var wg sync.WaitGroup + channel := make(chan spMsg) + + // This goroutine will monitor all messages from all peers until the + // peer goroutines all exit. + go func() { + for { + select { + case <-quit: + close(channel) + ps.forAllPeers( + func(sp *serverPeer) { + sp.unsubscribeRecvMsgs(channel) + }) + return + case sm := <-channel: + // TODO: This will get stuck if checkResponse + // gets stuck. + checkResponse(sm.sp, sm.msg, quit) + } + } + }() + + // Start a goroutine for each peer that potentially queries each peer + ps.forAllPeers(func(sp *serverPeer) { + wg.Add(1) + go func() { + defer wg.Done() + if !selectPeer(sp) { + return + } + timeout := make(<-chan time.Time) + for { + select { + case <-timeout: + // After timeout, we return and notify + // another goroutine that we've done so. + // We only send if there's someone left + // to receive. + startQuery <- struct{}{} + return + case <-quit: + // After we're told to quit, we return. + return + case <-startQuery: + // We're the lucky peer whose turn it is + // to try to answer the current query. + // TODO: Fix this to support multiple + // queries at once. For now, we're + // relying on the query handling loop + // to make sure we don't interrupt + // another query. We need broadcast + // support in OnRead to do this right. + sp.subscribeRecvMsg(channel) + sp.QueueMessage(queryMsg, nil) + timeout = time.After(qo.queryTimeout) + default: + } + } + }() + }) + startQuery <- struct{}{} + wg.Wait() + // If we timed out and didn't quit, make sure our response monitor + // goroutine knows to quit. + select { + case <-quit: + default: + close(quit) + } +} + // cfhRequest records which cfheaders we've requested, and the order in which // we've requested them. Since there's no way to associate the cfheaders to the // actual block hashes based on the cfheaders message to keep it compact, we @@ -142,6 +273,14 @@ type serverPeer struct { quit chan struct{} // The following chans are used to sync blockmanager and server. blockProcessed chan struct{} + // The following slice of channels is used to subscribe to messages from + // the peer. This allows broadcast to multiple subscribers at once, + // allowing for multiple queries to be going to multiple peers at any + // one time. The mutex is for subscribe/unsubscribe functionality. + // The sends on these channels WILL NOT block; any messages the channel + // can't accept will be dropped silently. + recvSubscribers []chan<- spMsg + mtxSubscribers sync.RWMutex } // newServerPeer returns a new serverPeer instance. The peer needs to be set by @@ -522,8 +661,46 @@ func (sp *serverPeer) OnAddr(_ *peer.Peer, msg *wire.MsgAddr) { // OnRead is invoked when a peer receives a message and it is used to update // the bytes received by the server. -func (sp *serverPeer) OnRead(_ *peer.Peer, bytesRead int, msg wire.Message, err error) { +func (sp *serverPeer) OnRead(_ *peer.Peer, bytesRead int, msg wire.Message, + err error) { sp.server.AddBytesReceived(uint64(bytesRead)) + // Try to send a message to the subscriber channel if it isn't nil, but + // don't block on failure. + sp.mtxSubscribers.RLock() + defer sp.mtxSubscribers.RUnlock() + for _, channel := range sp.recvSubscribers { + if channel != nil { + select { + case channel <- spMsg{ + sp: sp, + msg: msg, + }: + default: + } + } + } +} + +// subscribeRecvMsg handles adding OnRead subscriptions to the server peer. +func (sp *serverPeer) subscribeRecvMsg(channel chan<- spMsg) { + sp.mtxSubscribers.Lock() + defer sp.mtxSubscribers.Unlock() + sp.recvSubscribers = append(sp.recvSubscribers, channel) +} + +// unsubscribeRecvMsgs handles removing OnRead subscriptions from the server +// peer. +func (sp *serverPeer) unsubscribeRecvMsgs(channel chan<- spMsg) { + sp.mtxSubscribers.Lock() + defer sp.mtxSubscribers.Unlock() + var updatedSubscribers []chan<- spMsg + for _, candidate := range sp.recvSubscribers { + if candidate != channel { + updatedSubscribers = append(updatedSubscribers, + candidate) + } + } + sp.recvSubscribers = updatedSubscribers } // OnWrite is invoked when a peer sends a message and it is used to update @@ -556,6 +733,9 @@ type ChainService struct { timeSource blockchain.MedianTimeSource services wire.ServiceFlag + cfilterRequests map[cfRequest][]chan *gcs.Filter + cfRequestHeaders map[cfRequest][2]*chainhash.Hash + userAgentName string userAgentVersion string } @@ -764,6 +944,8 @@ func NewChainService(cfg Config) (*ChainService, error) { services: Services, userAgentName: UserAgentName, userAgentVersion: UserAgentVersion, + cfilterRequests: make(map[cfRequest][]chan *gcs.Filter), + cfRequestHeaders: make(map[cfRequest][2]*chainhash.Hash), } err := createSPVNS(s.namespace, &s.chainParams) @@ -1555,3 +1737,50 @@ func (s *ChainService) rollbackToHeight(height uint32) (*waddrmgr.BlockStamp, er func (s *ChainService) IsCurrent() bool { return s.blockManager.IsCurrent() } + +// GetCFilter gets a cfilter from the database. Failing that, it requests the +// cfilter from the network and writes it to the database. +func (s *ChainService) GetCFilter(blockHash chainhash.Hash, + extended bool) *gcs.Filter { + getFilter := s.GetBasicFilter + getHeader := s.GetBasicHeader + putFilter := s.putBasicFilter + if extended { + getFilter = s.GetExtFilter + getHeader = s.GetExtHeader + putFilter = s.putExtFilter + } + filter, err := getFilter(blockHash) + if err == nil && filter != nil { + return filter + } + block, _, err := s.GetBlockByHash(blockHash) + if err != nil || block.BlockHash() != blockHash { + return nil + } + curHeader, err := getHeader(blockHash) + if err != nil { + return nil + } + prevHeader, err := getHeader(block.PrevBlock) + if err != nil { + return nil + } + replyChan := make(chan *gcs.Filter) + s.query <- getCFilterMsg{ + cfRequest: cfRequest{ + blockHash: blockHash, + extended: extended, + }, + prevHeader: prevHeader, + curHeader: curHeader, + reply: replyChan, + } + filter = <-replyChan + if filter != nil { + putFilter(blockHash, filter) + log.Tracef("Wrote filter for block %s, extended: %t", + blockHash, extended) + } + return filter +} diff --git a/spvsvc/spvchain/sync_test.go b/spvsvc/spvchain/sync_test.go index 2d6b3e8..72fb11f 100644 --- a/spvsvc/spvchain/sync_test.go +++ b/spvsvc/spvchain/sync_test.go @@ -1,17 +1,21 @@ package spvchain_test import ( + "bytes" "fmt" "io/ioutil" + "math/rand" "os" "testing" "time" "github.com/aakselrod/btctestlog" "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/rpctest" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog" + "github.com/btcsuite/btcrpcclient" "github.com/btcsuite/btcwallet/spvsvc/spvchain" "github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/walletdb" @@ -147,6 +151,9 @@ func TestSetup(t *testing.T) { chainLogger := btclog.NewSubsystemLogger(logger, "CHAIN: ") chainLogger.SetLevel(logLevel) spvchain.UseLogger(chainLogger) + rpcLogger := btclog.NewSubsystemLogger(logger, "RPCC: ") + rpcLogger.SetLevel(logLevel) + btcrpcclient.UseLogger(rpcLogger) svc, err := spvchain.NewChainService(config) if err != nil { t.Fatalf("Error creating ChainService: %s", err) @@ -279,99 +286,122 @@ func waitForSync(t *testing.T, svc *spvchain.ChainService, return fmt.Errorf("Couldn't get latest extended header from "+ "%s: %s", correctSyncNode.P2PAddress(), err) } - for total <= syncTimeout { + haveBasicHeader := &chainhash.Hash{} + haveExtHeader := &chainhash.Hash{} + for (*knownBasicHeader.HeaderHashes[0] != *haveBasicHeader) && + (*knownExtHeader.HeaderHashes[0] != *haveExtHeader) { + if total > syncTimeout { + return fmt.Errorf("Timed out after %v waiting for "+ + "cfheaders synchronization.", syncTimeout) + } + haveBasicHeader, _ = svc.GetBasicHeader(*knownBestHash) + haveExtHeader, _ = svc.GetExtHeader(*knownBestHash) time.Sleep(syncUpdate) total += syncUpdate - haveBasicHeader, err := svc.GetBasicHeader(*knownBestHash) - if err != nil { - if logLevel != btclog.Off { - t.Logf("Basic header unknown.") - } - continue - } - haveExtHeader, err := svc.GetExtHeader(*knownBestHash) - if err != nil { - if logLevel != btclog.Off { - t.Logf("Extended header unknown.") - } - continue - } - if *knownBasicHeader.HeaderHashes[0] != *haveBasicHeader { - return fmt.Errorf("Known basic header doesn't match "+ - "the basic header the ChainService has. Known:"+ - " %s, ChainService: %s", - knownBasicHeader.HeaderHashes[0], - haveBasicHeader) - } - if *knownExtHeader.HeaderHashes[0] != *haveExtHeader { - return fmt.Errorf("Known extended header doesn't "+ - "match the extended header the ChainService "+ - "has. Known: %s, ChainService: %s", - knownExtHeader.HeaderHashes[0], haveExtHeader) - } - // At this point, we know the latest cfheader is stored in the - // ChainService database. We now compare each cfheader the - // harness knows about to what's stored in the ChainService - // database to see if we've missed anything or messed anything - // up. - for i := int32(0); i <= haveBest.Height; i++ { - head, _, err := svc.GetBlockByHeight(uint32(i)) - if err != nil { - return fmt.Errorf("Couldn't read block by "+ - "height: %s", err) - } - hash := head.BlockHash() - haveBasicHeader, err := svc.GetBasicHeader(hash) - if err != nil { - return fmt.Errorf("Couldn't get basic header "+ - "for %d (%s) from DB", i, hash) - } - haveExtHeader, err := svc.GetExtHeader(hash) - if err != nil { - return fmt.Errorf("Couldn't get extended "+ - "header for %d (%s) from DB", i, hash) - } - knownBasicHeader, err := - correctSyncNode.Node.GetCFilterHeader(&hash, - false) - if err != nil { - return fmt.Errorf("Couldn't get basic header "+ - "for %d (%s) from node %s", i, hash, - correctSyncNode.P2PAddress()) - } - knownExtHeader, err := - correctSyncNode.Node.GetCFilterHeader(&hash, - true) - if err != nil { - return fmt.Errorf("Couldn't get extended "+ - "header for %d (%s) from node %s", i, - hash, correctSyncNode.P2PAddress()) - } - if *haveBasicHeader != - *knownBasicHeader.HeaderHashes[0] { - return fmt.Errorf("Basic header for %d (%s) "+ - "doesn't match node %s. DB: %s, node: "+ - "%s", i, hash, - correctSyncNode.P2PAddress(), - haveBasicHeader, - knownBasicHeader.HeaderHashes[0]) - } - if *haveExtHeader != - *knownExtHeader.HeaderHashes[0] { - return fmt.Errorf("Extended header for %d (%s)"+ - " doesn't match node %s. DB: %s, node:"+ - " %s", i, hash, - correctSyncNode.P2PAddress(), - haveExtHeader, - knownExtHeader.HeaderHashes[0]) - } - } - if logLevel != btclog.Off { - t.Logf("Synced cfheaders to %d (%s)", haveBest.Height, - haveBest.Hash) - } - return nil } - return fmt.Errorf("Timeout waiting for cfheaders synchronization after"+ - " %v", syncTimeout) + if logLevel != btclog.Off { + t.Logf("Synced cfheaders to %d (%s)", haveBest.Height, + haveBest.Hash) + } + // At this point, we know the latest cfheader is stored in the + // ChainService database. We now compare each cfheader the + // harness knows about to what's stored in the ChainService + // database to see if we've missed anything or messed anything + // up. + for i := int32(0); i <= haveBest.Height; i++ { + head, _, err := svc.GetBlockByHeight(uint32(i)) + if err != nil { + return fmt.Errorf("Couldn't read block by "+ + "height: %s", err) + } + hash := head.BlockHash() + haveBasicHeader, err = svc.GetBasicHeader(hash) + if err != nil { + return fmt.Errorf("Couldn't get basic header "+ + "for %d (%s) from DB", i, hash) + } + haveExtHeader, err = svc.GetExtHeader(hash) + if err != nil { + return fmt.Errorf("Couldn't get extended "+ + "header for %d (%s) from DB", i, hash) + } + knownBasicHeader, err = + correctSyncNode.Node.GetCFilterHeader(&hash, + false) + if err != nil { + return fmt.Errorf("Couldn't get basic header "+ + "for %d (%s) from node %s", i, hash, + correctSyncNode.P2PAddress()) + } + knownExtHeader, err = + correctSyncNode.Node.GetCFilterHeader(&hash, + true) + if err != nil { + return fmt.Errorf("Couldn't get extended "+ + "header for %d (%s) from node %s", i, + hash, correctSyncNode.P2PAddress()) + } + if *haveBasicHeader != + *knownBasicHeader.HeaderHashes[0] { + return fmt.Errorf("Basic header for %d (%s) "+ + "doesn't match node %s. DB: %s, node: "+ + "%s", i, hash, + correctSyncNode.P2PAddress(), + haveBasicHeader, + knownBasicHeader.HeaderHashes[0]) + } + if *haveExtHeader != + *knownExtHeader.HeaderHashes[0] { + return fmt.Errorf("Extended header for %d (%s)"+ + " doesn't match node %s. DB: %s, node:"+ + " %s", i, hash, + correctSyncNode.P2PAddress(), + haveExtHeader, + knownExtHeader.HeaderHashes[0]) + } + } + // Test getting 15 random filters. + heights := rand.Perm(int(haveBest.Height)) + for i := 0; i < 15; i++ { + height := uint32(heights[i]) + block, _, err := svc.GetBlockByHeight(height) + if err != nil { + return fmt.Errorf("Get block by height %d:"+ + " %s", height, err) + } + blockHash := block.BlockHash() + haveFilter := svc.GetCFilter(blockHash, false) + if haveFilter == nil { + return fmt.Errorf("Couldn't get basic "+ + "filter for block %d", height) + } + t.Logf("%x", haveFilter.NBytes()) + wantFilter, err := correctSyncNode.Node.GetCFilter(&blockHash, + false) + if err != nil { + return fmt.Errorf("Couldn't get basic filter for "+ + "block %d via RPC: %s", height, err) + } + if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) { + return fmt.Errorf("Basic filter from P2P network/DB"+ + " doesn't match RPC value for block %d", height) + } + haveFilter = svc.GetCFilter(blockHash, true) + if haveFilter == nil { + return fmt.Errorf("Couldn't get extended "+ + "filter for block %d", height) + } + t.Logf("%x", haveFilter.NBytes()) + wantFilter, err = correctSyncNode.Node.GetCFilter(&blockHash, + true) + if err != nil { + return fmt.Errorf("Couldn't get extended filter for "+ + "block %d via RPC: %s", height, err) + } + if !bytes.Equal(haveFilter.NBytes(), wantFilter.Data) { + return fmt.Errorf("Extended filter from P2P network/DB"+ + " doesn't match RPC value for block %d", height) + } + } + return nil }