diff --git a/spvsvc/spvchain/blockmanager.go b/spvsvc/spvchain/blockmanager.go index a97781a..c2ccf88 100644 --- a/spvsvc/spvchain/blockmanager.go +++ b/spvsvc/spvchain/blockmanager.go @@ -64,6 +64,30 @@ type headersMsg struct { peer *serverPeer } +// cfheadersMsg packages a bitcoin cfheaders message and the peer it came from +// together so the block handler has access to that information. +type cfheadersMsg struct { + cfheaders *wire.MsgCFHeaders + peer *serverPeer +} + +// cfheadersProcessedMsg tells the block manager to try to see if there are +// enough samples of cfheaders messages to process the committed filter header +// chain. This is kind of a hack until these get soft-forked in, but we do +// verification to avoid getting bamboozled by malicious nodes. +type processCFHeadersMsg struct { + earliestNode *headerNode + stopHash chainhash.Hash + extended bool +} + +// cfilterMsg packages a bitcoin cfilter message and the peer it came from +// together so the block handler has access to that information. +type cfilterMsg struct { + cfilter *wire.MsgCFilter + peer *serverPeer +} + // donePeerMsg signifies a newly disconnected peer to the block handler. type donePeerMsg struct { peer *serverPeer @@ -133,6 +157,10 @@ type blockManager struct { nextCheckpoint *chaincfg.Checkpoint lastRequested chainhash.Hash + basicHeaders map[chainhash.Hash]map[chainhash.Hash][]*serverPeer + extendedHeaders map[chainhash.Hash]map[chainhash.Hash][]*serverPeer + lastFilterHeight int32 + minRetargetTimespan int64 // target timespan / adjustment factor maxRetargetTimespan int64 // target timespan * adjustment factor blocksPerRetarget int32 // target timespan / target time per block @@ -156,6 +184,12 @@ func newBlockManager(s *ChainService) (*blockManager, error) { blocksPerRetarget: int32(targetTimespan / targetTimePerBlock), minRetargetTimespan: targetTimespan / adjustmentFactor, maxRetargetTimespan: targetTimespan * adjustmentFactor, + basicHeaders: make( + map[chainhash.Hash]map[chainhash.Hash][]*serverPeer, + ), + extendedHeaders: make( + map[chainhash.Hash]map[chainhash.Hash][]*serverPeer, + ), } // Initialize the next checkpoint based on the current height. @@ -300,9 +334,18 @@ out: case *headersMsg: b.handleHeadersMsg(msg) + case *cfheadersMsg: + b.handleCFHeadersMsg(msg) + + case *cfilterMsg: + b.handleCFilterMsg(msg) + case *donePeerMsg: b.handleDonePeerMsg(candidatePeers, msg.peer) + case *processCFHeadersMsg: + b.handleProcessCFHeadersMsg(msg) + case getSyncPeerMsg: msg.reply <- b.syncPeer @@ -405,12 +448,24 @@ func (b *blockManager) resetHeaderState(newestHeader *wire.BlockHeader, newestHeight int32) { b.headerList.Init() b.startHeader = nil + b.basicHeaders = make( + map[chainhash.Hash]map[chainhash.Hash][]*serverPeer, + ) + b.extendedHeaders = make( + map[chainhash.Hash]map[chainhash.Hash][]*serverPeer, + ) // Add an entry for the latest known block into the header pool. // This allows the next downloaded header to prove it links to the chain // properly. node := headerNode{header: newestHeader, height: newestHeight} b.headerList.PushBack(&node) + b.basicHeaders[newestHeader.BlockHash()] = make( + map[chainhash.Hash][]*serverPeer, + ) + b.extendedHeaders[newestHeader.BlockHash()] = make( + map[chainhash.Hash][]*serverPeer, + ) } // startSync will choose the best peer among the available candidate peers to @@ -605,10 +660,14 @@ func (b *blockManager) handleInvMsg(imsg *invMsg) { // If this is the sync peer or we're current, get the headers // for the announced blocks and update the last announced block. if lastBlock != -1 && (imsg.peer == b.syncPeer || b.current()) { - lastHash := b.headerList.Back().Value.(*headerNode).header.BlockHash() + lastEl := b.headerList.Back() + var lastHash chainhash.Hash + if lastEl != nil { + lastHash = lastEl.Value.(*headerNode).header.BlockHash() + } // Only send getheaders if we don't already know about the last // block hash being announced. - if lastHash != invVects[lastBlock].Hash && + if lastHash != invVects[lastBlock].Hash && lastEl != nil && b.lastRequested != invVects[lastBlock].Hash { // Make a locator starting from the latest known header // we've processed. @@ -711,6 +770,12 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) { } hmsg.peer.UpdateLastBlockHeight(node.height) e := b.headerList.PushBack(&node) + b.basicHeaders[node.header.BlockHash()] = make( + map[chainhash.Hash][]*serverPeer, + ) + b.extendedHeaders[node.header.BlockHash()] = make( + map[chainhash.Hash][]*serverPeer, + ) if b.startHeader == nil { b.startHeader = e } @@ -849,9 +914,16 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) { // We also change the sync peer. Then we can continue // with the rest of the headers in the message as if // nothing has happened. + // TODO: Error handling, duh! b.syncPeer = hmsg.peer b.server.rollbackToHeight(backHeight) b.server.putBlock(*blockHeader, backHeight+1) + b.basicHeaders[node.header.BlockHash()] = make( + map[chainhash.Hash][]*serverPeer, + ) + b.extendedHeaders[node.header.BlockHash()] = make( + map[chainhash.Hash][]*serverPeer, + ) b.server.putMaxBlockHeight(backHeight + 1) b.resetHeaderState(&backHead, int32(backHeight)) b.headerList.PushBack(&headerNode{ @@ -912,9 +984,34 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) { //return } + // Send getcfheaders to each peer based on these headers. + cfhLocator := blockchain.BlockLocator([]*chainhash.Hash{ + &msg.Headers[0].PrevBlock, + }) + cfhStopHash := msg.Headers[len(msg.Headers)-1].BlockHash() + cfhCount := len(msg.Headers) + cfhReqB := cfhRequest{ + extended: false, + stopHash: cfhStopHash, + } + cfhReqE := cfhRequest{ + extended: true, + stopHash: cfhStopHash, + } + b.server.ForAllPeers(func(sp *serverPeer) { + // Should probably use better isolation for this but we're in + // the same package. One of the things to clean up when we do + // more general cleanup. + sp.requestedCFHeaders[cfhReqB] = cfhCount + sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, false) + sp.requestedCFHeaders[cfhReqE] = cfhCount + sp.pushGetCFHeadersMsg(cfhLocator, &cfhStopHash, true) + }) + // If not current, request the next batch of headers starting from the // latest known header and ending with the next checkpoint. - if !b.current() { + if !b.current() || b.server.chainParams.Net == + chaincfg.SimNetParams.Net { locator := blockchain.BlockLocator([]*chainhash.Hash{finalHash}) nextHash := zeroHash if b.nextCheckpoint != nil { @@ -924,11 +1021,201 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) { if err != nil { log.Warnf("Failed to send getheaders message to "+ "peer %s: %s", hmsg.peer.Addr(), err) + // Unnecessary but we might put other code after this + // eventually. return } } } +// QueueCFHeaders adds the passed headers message and peer to the block handling +// queue. +func (b *blockManager) QueueCFHeaders(cfheaders *wire.MsgCFHeaders, + sp *serverPeer) { + // No channel handling here because peers do not need to block on + // cfheaders messages. + if atomic.LoadInt32(&b.shutdown) != 0 { + return + } + + b.msgChan <- &cfheadersMsg{cfheaders: cfheaders, peer: sp} +} + +// handleCFHeadersMsg handles cfheaders messages from all peers. +func (b *blockManager) handleCFHeadersMsg(cfhmsg *cfheadersMsg) { + // Grab the matching request we sent, as this message should correspond + // to that, and delete it from the map on return as we're now handling + // it. + req := cfhRequest{ + extended: cfhmsg.cfheaders.Extended, + stopHash: cfhmsg.cfheaders.StopHash, + } + defer delete(cfhmsg.peer.requestedCFHeaders, req) + // Check that the count is correct. This works even when the map lookup + // fails as it returns 0 in that case. + headerList := cfhmsg.cfheaders.HeaderHashes + respLen := len(headerList) + if cfhmsg.peer.requestedCFHeaders[req] != respLen { + log.Warnf("Received cfheaders message doesn't match any "+ + "getcfheaders request. Peer %s is probably on a "+ + "different chain -- ignoring", cfhmsg.peer.Addr()) + return + } + if respLen == 0 { + return + } + // Find the block header matching the last filter header, if any. + el := b.headerList.Back() + for el != nil { + if el.Value.(*headerNode).header.BlockHash() == req.stopHash { + break + } + el = el.Prev() + } + // If nothing matched, there's nothing more to do. + if el == nil { + return + } + // Cycle through the filter header hashes and process them. + filterMap := b.basicHeaders + if req.extended { + filterMap = b.extendedHeaders + } + var node *headerNode + var hash chainhash.Hash + for i := respLen - 1; i >= 0 && el != nil; i-- { + // If there's no map for this header, the header is either no + // longer valid or has already been processed and committed to + // the database. Either way, break processing. + node = el.Value.(*headerNode) + hash = node.header.BlockHash() + if _, ok := filterMap[hash]; !ok { + break + } + // Process this header and set up the next iteration. + filterMap[hash][*headerList[i]] = append( + filterMap[hash][*headerList[i]], cfhmsg.peer, + ) + el = el.Prev() + } + b.msgChan <- &processCFHeadersMsg{ + earliestNode: node, + stopHash: req.stopHash, + extended: req.extended, + } + log.Tracef("Processed cfheaders starting at %s, ending at %s", + node.header.BlockHash(), req.stopHash) +} + +// handleProcessCFHeadersMsg checks to see if we have enough cfheaders to make +// a decision about what the correct headers are, makes that decision if +// possible, and downloads any cfilters and blocks necessary to make that +// decision. +func (b *blockManager) handleProcessCFHeadersMsg(msg *processCFHeadersMsg) { + // Assume we aren't ready to make a decision about correct headers yet. + ready := false + + // If we have started receiving cfheaders messages for blocks farther + // than the last set we haven't made a decision on, it's time to make + // a decision. + if msg.earliestNode.height > b.lastFilterHeight { + if b.lastFilterHeight != 0 { + ready = true + } + b.lastFilterHeight = msg.earliestNode.height + } + + // If there are no other messages left, we should go ahead and make a + // decision because we have all the info we're going to get. + // TODO: Instead of using just a channel to queue messages, create + // another goroutine that reads the channel and appends the messages to + // a slice. Then we can check the slice for only cfheaders messages. We + // might need to add DoS protection for that. + if len(b.msgChan) == 0 { + ready = true + } + + // Do nothing if we're not ready to make a decision yet. + if !ready { + return + } + + // At this point, we've got all the cfheaders messages we're going to + // get for the range of headers described by the passed message. We now + // iterate through all of those headers, looking for conflicts. If we + // find a conflict, we have to do additional checks; otherwise, we write + // the filter header to the database. + el := b.headerList.Front() + filterMap := b.basicHeaders + writeFunc := b.server.putBasicHeader + readFunc := b.server.GetBasicHeader + if msg.extended { + filterMap = b.extendedHeaders + writeFunc = b.server.putExtHeader + readFunc = b.server.GetExtHeader + } + for el != nil { + node := el.Value.(*headerNode) + hash := node.header.BlockHash() + if node.height >= msg.earliestNode.height { + blockMap := filterMap[hash] + switch len(blockMap) { + // This should only happen if the filter has already + // been written to the database or if there's a reorg. + case 0: + if _, err := readFunc(hash); err != nil { + // We don't have the filter stored in + // the DB, there's been a reorg. + log.Warnf("Somehow we have 0 cfheaders"+ + " for block %d (%s)", + node.height, hash) + return + } + // This is the normal case when nobody's trying to + // bamboozle us (or ALL our peers are). + case 1: + // This will only cycle once + for filterHash := range blockMap { + writeFunc(hash, filterHash) + } + // This is when we have conflicting information from + // multiple peers. + // TODO: Handle this case. + default: + } + } + + //elToRemove := el + el = el.Next() + //b.headerList.Remove(elToRemove) + //b.startHeader = el + + // If we've reached the end, we can return + if hash == msg.stopHash { + log.Tracef("Finished processing cfheaders messages up "+ + "to height %d/hash %s", node.height, hash) + return + } + } +} + +// QueueCFilter adds the passed cfilter message and peer to the block handling +// queue. +func (b *blockManager) QueueCFilter(cfilter *wire.MsgCFilter, sp *serverPeer) { + // No channel handling here because peers do not need to block on + // headers messages. + if atomic.LoadInt32(&b.shutdown) != 0 { + return + } + + b.msgChan <- &cfilterMsg{cfilter: cfilter, peer: sp} +} + +// handleCFilterMsg handles cfilter messages from all peers. +func (b *blockManager) handleCFilterMsg(cfmsg *cfilterMsg) { + +} + // checkHeaderSanity checks the PoW, and timestamp of a block header. func (b *blockManager) checkHeaderSanity(blockHeader *wire.BlockHeader, maxTimestamp time.Time, reorgAttempt bool) error { diff --git a/spvsvc/spvchain/db.go b/spvsvc/spvchain/db.go index cb82fd3..0cbd786 100644 --- a/spvsvc/spvchain/db.go +++ b/spvsvc/spvchain/db.go @@ -182,10 +182,40 @@ func putExtHeader(tx walletdb.Tx, blockHash chainhash.Hash, return putHeader(tx, blockHash, extHeaderBucketName, filterTip) } +// getHeader retrieves the provided filter, keyed to the block hash, from the +// appropriate filter bucket in the database. +func getHeader(tx walletdb.Tx, blockHash chainhash.Hash, + bucketName []byte) (*chainhash.Hash, error) { + + bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(bucketName) + + filterTip := bucket.Get(blockHash[:]) + if len(filterTip) == 0 { + return &chainhash.Hash{}, + fmt.Errorf("failed to get filter header") + } + + return chainhash.NewHash(filterTip) +} + +// getBasicHeader retrieves the provided filter, keyed to the block hash, from +// the basic filter bucket in the database. +func getBasicHeader(tx walletdb.Tx, blockHash chainhash.Hash) (*chainhash.Hash, + error) { + return getHeader(tx, blockHash, basicHeaderBucketName) +} + +// getExtHeader retrieves the provided filter, keyed to the block hash, from the +// extended filter bucket in the database. +func getExtHeader(tx walletdb.Tx, blockHash chainhash.Hash) (*chainhash.Hash, + error) { + return getHeader(tx, blockHash, extHeaderBucketName) +} + // rollbackLastBlock rolls back the last known block and returns the BlockStamp // representing the new last known block. func rollbackLastBlock(tx walletdb.Tx) (*waddrmgr.BlockStamp, error) { - bs, err := SyncedTo(tx) + bs, err := syncedTo(tx) if err != nil { return nil, err } @@ -202,12 +232,12 @@ func rollbackLastBlock(tx walletdb.Tx) (*waddrmgr.BlockStamp, error) { if err != nil { return nil, err } - return SyncedTo(tx) + return syncedTo(tx) } -// GetBlockByHash retrieves the block header, filter, and filter tip, based on +// getBlockByHash retrieves the block header, filter, and filter tip, based on // the provided block hash, from the database. -func GetBlockByHash(tx walletdb.Tx, blockHash chainhash.Hash) (wire.BlockHeader, +func getBlockByHash(tx walletdb.Tx, blockHash chainhash.Hash) (wire.BlockHeader, uint32, error) { //chainhash.Hash, chainhash.Hash, bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(blockHeaderBucketName) @@ -233,8 +263,8 @@ func GetBlockByHash(tx walletdb.Tx, blockHash chainhash.Hash) (wire.BlockHeader, return header, height, nil } -// GetBlockHashByHeight retrieves the hash of a block by its height. -func GetBlockHashByHeight(tx walletdb.Tx, height uint32) (chainhash.Hash, +// getBlockHashByHeight retrieves the hash of a block by its height. +func getBlockHashByHeight(tx walletdb.Tx, height uint32) (chainhash.Hash, error) { bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(blockHeaderBucketName) var hash chainhash.Hash @@ -246,21 +276,21 @@ func GetBlockHashByHeight(tx walletdb.Tx, height uint32) (chainhash.Hash, return hash, nil } -// GetBlockByHeight retrieves a block's information by its height. -func GetBlockByHeight(tx walletdb.Tx, height uint32) (wire.BlockHeader, uint32, +// getBlockByHeight retrieves a block's information by its height. +func getBlockByHeight(tx walletdb.Tx, height uint32) (wire.BlockHeader, uint32, error) { // chainhash.Hash, chainhash.Hash - blockHash, err := GetBlockHashByHeight(tx, height) + blockHash, err := getBlockHashByHeight(tx, height) if err != nil { return wire.BlockHeader{}, 0, err } - return GetBlockByHash(tx, blockHash) + return getBlockByHash(tx, blockHash) } -// SyncedTo retrieves the most recent block's height and hash. -func SyncedTo(tx walletdb.Tx) (*waddrmgr.BlockStamp, error) { - header, height, err := LatestBlock(tx) +// syncedTo retrieves the most recent block's height and hash. +func syncedTo(tx walletdb.Tx) (*waddrmgr.BlockStamp, error) { + header, height, err := latestBlock(tx) if err != nil { return nil, err } @@ -270,8 +300,8 @@ func SyncedTo(tx walletdb.Tx) (*waddrmgr.BlockStamp, error) { return &blockStamp, nil } -// LatestBlock retrieves all the info about the latest stored block. -func LatestBlock(tx walletdb.Tx) (wire.BlockHeader, uint32, error) { +// latestBlock retrieves all the info about the latest stored block. +func latestBlock(tx walletdb.Tx) (wire.BlockHeader, uint32, error) { bucket := tx.RootBucket().Bucket(spvBucketName) maxBlockHeightBytes := bucket.Get(maxBlockHeightName) @@ -281,7 +311,7 @@ func LatestBlock(tx walletdb.Tx) (wire.BlockHeader, uint32, error) { } maxBlockHeight := binary.LittleEndian.Uint32(maxBlockHeightBytes) - header, height, err := GetBlockByHeight(tx, maxBlockHeight) + header, height, err := getBlockByHeight(tx, maxBlockHeight) if err != nil { return wire.BlockHeader{}, 0, err } @@ -295,12 +325,12 @@ func LatestBlock(tx walletdb.Tx) (wire.BlockHeader, uint32, error) { // CheckConnectivity cycles through all of the block headers, from last to // first, and makes sure they all connect to each other. func CheckConnectivity(tx walletdb.Tx) error { - header, height, err := LatestBlock(tx) + header, height, err := latestBlock(tx) if err != nil { return fmt.Errorf("Couldn't retrieve latest block: %s", err) } for height > 0 { - newheader, newheight, err := GetBlockByHash(tx, + newheader, newheight, err := getBlockByHash(tx, header.PrevBlock) if err != nil { return fmt.Errorf("Couldn't retrieve block %s: %s", @@ -322,14 +352,14 @@ func CheckConnectivity(tx walletdb.Tx) error { return nil } -// BlockLocatorFromHash returns a block locator based on the provided hash. -func BlockLocatorFromHash(tx walletdb.Tx, hash chainhash.Hash) blockchain.BlockLocator { +// blockLocatorFromHash returns a block locator based on the provided hash. +func blockLocatorFromHash(tx walletdb.Tx, hash chainhash.Hash) blockchain.BlockLocator { locator := make(blockchain.BlockLocator, 0, wire.MaxBlockLocatorsPerMsg) locator = append(locator, &hash) // If hash isn't found in DB or this is the genesis block, return // the locator as is - _, height, err := GetBlockByHash(tx, hash) + _, height, err := getBlockByHash(tx, hash) if (err != nil) || (height == 0) { return locator } @@ -346,7 +376,7 @@ func BlockLocatorFromHash(tx walletdb.Tx, hash chainhash.Hash) blockchain.BlockL } else { height -= decrement } - blockHash, err := GetBlockHashByHeight(tx, height) + blockHash, err := getBlockHashByHeight(tx, height) if err != nil { return locator } diff --git a/spvsvc/spvchain/notifications.go b/spvsvc/spvchain/notifications.go index 2ba05c5..f6668f7 100644 --- a/spvsvc/spvchain/notifications.go +++ b/spvsvc/spvchain/notifications.go @@ -44,6 +44,10 @@ type removeNodeMsg struct { reply chan error } +type forAllPeersMsg struct { + closure func(*serverPeer) +} + // handleQuery is the central handler for all queries and commands from other // goroutines related to peer state. func (s *ChainService) handleQuery(state *peerState, querymsg interface{}) { @@ -145,5 +149,11 @@ func (s *ChainService) handleQuery(state *peerState, querymsg interface{}) { } msg.reply <- errors.New("peer not found") + case forAllPeersMsg: + // Run the closure on all peers in the passed state. + state.forAllPeers(msg.closure) + // Even though this is a query, there's no reply channel as the + // forAllPeers method doesn't return anything. An error might be + // useful in the future. } } diff --git a/spvsvc/spvchain/spvchain.go b/spvsvc/spvchain/spvchain.go index c990ab5..c96624c 100644 --- a/spvsvc/spvchain/spvchain.go +++ b/spvsvc/spvchain/spvchain.go @@ -104,6 +104,15 @@ func (ps *peerState) forAllPeers(closure func(sp *serverPeer)) { ps.forAllOutboundPeers(closure) } +// cfhRequest records which cfheaders we've requested, and the order in which +// we've requested them. Since there's no way to associate the cfheaders to the +// actual block hashes based on the cfheaders message to keep it compact, we +// track it this way. +type cfhRequest struct { + extended bool + stopHash chainhash.Hash +} + // serverPeer extends the peer to maintain state shared by the server and // the blockmanager. type serverPeer struct { @@ -112,17 +121,18 @@ type serverPeer struct { *peer.Peer - connReq *connmgr.ConnReq - server *ChainService - persistent bool - continueHash *chainhash.Hash - relayMtx sync.Mutex - requestQueue []*wire.InvVect - requestedFilters map[chainhash.Hash]bool - requestedBlocks map[chainhash.Hash]struct{} - knownAddresses map[string]struct{} - banScore connmgr.DynamicBanScore - quit chan struct{} + connReq *connmgr.ConnReq + server *ChainService + persistent bool + continueHash *chainhash.Hash + relayMtx sync.Mutex + requestQueue []*wire.InvVect + requestedCFilters map[chainhash.Hash]bool + requestedCFHeaders map[cfhRequest]int + requestedBlocks map[chainhash.Hash]struct{} + knownAddresses map[string]struct{} + banScore connmgr.DynamicBanScore + quit chan struct{} // The following chans are used to sync blockmanager and server. blockProcessed chan struct{} } @@ -131,13 +141,14 @@ type serverPeer struct { // the caller. func newServerPeer(s *ChainService, isPersistent bool) *serverPeer { return &serverPeer{ - server: s, - persistent: isPersistent, - requestedFilters: make(map[chainhash.Hash]bool), - requestedBlocks: make(map[chainhash.Hash]struct{}), - knownAddresses: make(map[string]struct{}), - quit: make(chan struct{}), - blockProcessed: make(chan struct{}, 1), + server: s, + persistent: isPersistent, + requestedCFilters: make(map[chainhash.Hash]bool), + requestedBlocks: make(map[chainhash.Hash]struct{}), + requestedCFHeaders: make(map[cfhRequest]int), + knownAddresses: make(map[string]struct{}), + quit: make(chan struct{}), + blockProcessed: make(chan struct{}, 1), } } @@ -199,7 +210,7 @@ func (sp *serverPeer) addBanScore(persistent, transient uint32, reason string) { // pushGetCFHeadersMsg sends a getcfheaders message for the provided block // locator and stop hash to the connected peer. func (sp *serverPeer) pushGetCFHeadersMsg(locator blockchain.BlockLocator, - stopHash *chainhash.Hash) error { + stopHash *chainhash.Hash, ext bool) error { msg := wire.NewMsgGetCFHeaders() msg.HashStop = *stopHash for _, hash := range locator { @@ -208,6 +219,7 @@ func (sp *serverPeer) pushGetCFHeadersMsg(locator blockchain.BlockLocator, return err } } + msg.Extended = ext sp.QueueMessage(msg, nil) return nil } @@ -424,14 +436,17 @@ func (sp *serverPeer) OnReject(_ *peer.Peer, msg *wire.MsgReject) { // OnCFHeaders is invoked when a peer receives a cfheaders bitcoin message and // is used to notify the server about a list of committed filter headers. -func (sp *serverPeer) OnCFHeaders(_ *peer.Peer, msg *wire.MsgCFHeaders) { - log.Trace("Got cfheaders message!") +func (sp *serverPeer) OnCFHeaders(p *peer.Peer, msg *wire.MsgCFHeaders) { + log.Tracef("Got cfheaders message with %d items from %s", + len(msg.HeaderHashes), p.Addr()) + sp.server.blockManager.QueueCFHeaders(msg, sp) } // OnCFilter is invoked when a peer receives a cfilter bitcoin message and is // used to notify the server about a committed filter. -func (sp *serverPeer) OnCFilter(_ *peer.Peer, msg *wire.MsgCFilter) { - +func (sp *serverPeer) OnCFilter(p *peer.Peer, msg *wire.MsgCFilter) { + log.Tracef("Got cfilter message from %s", p.Addr()) + sp.server.blockManager.QueueCFilter(msg, sp) } // OnAddr is invoked when a peer receives an addr bitcoin message and is @@ -534,7 +549,7 @@ func (s *ChainService) BestSnapshot() (*waddrmgr.BlockStamp, error) { var best *waddrmgr.BlockStamp var err error err = s.namespace.View(func(tx walletdb.Tx) error { - best, err = SyncedTo(tx) + best, err = syncedTo(tx) return err }) if err != nil { @@ -549,11 +564,11 @@ func (s *ChainService) LatestBlockLocator() (blockchain.BlockLocator, error) { var locator blockchain.BlockLocator var err error err = s.namespace.View(func(tx walletdb.Tx) error { - best, err := SyncedTo(tx) + best, err := syncedTo(tx) if err != nil { return err } - locator = BlockLocatorFromHash(tx, best.Hash) + locator = blockLocatorFromHash(tx, best.Hash) return nil }) if err != nil { @@ -1236,6 +1251,16 @@ func (s *ChainService) ConnectNode(addr string, permanent bool) error { return <-replyChan } +// ForAllPeers runs a closure over all peers (outbound and persistent) to which +// the ChainService is connected. Nothing is returned because the peerState's +// ForAllPeers method doesn't return anything as the closure passed to it +// doesn't return anything. +func (s *ChainService) ForAllPeers(closure func(sp *serverPeer)) { + s.query <- forAllPeersMsg{ + closure: closure, + } +} + // UpdatePeerHeights updates the heights of all peers who have have announced // the latest connected main chain block, or a recognized orphan. These height // updates allow us to dynamically refresh peer heights, ensuring sync peer @@ -1295,14 +1320,14 @@ out: // Drain channels before exiting so nothing is left waiting around // to send. -cleanup: + /*cleanup: for { select { //case <-s.modifyRebroadcastInv: default: break cleanup } - } + }*/ s.wg.Done() } @@ -1343,7 +1368,7 @@ func (s *ChainService) GetBlockByHeight(height uint32) (wire.BlockHeader, var h uint32 var err error err = s.namespace.View(func(dbTx walletdb.Tx) error { - bh, h, err = GetBlockByHeight(dbTx, height) + bh, h, err = getBlockByHeight(dbTx, height) return err }) return bh, h, err @@ -1357,7 +1382,7 @@ func (s *ChainService) GetBlockByHash(hash chainhash.Hash) (wire.BlockHeader, var h uint32 var err error err = s.namespace.View(func(dbTx walletdb.Tx) error { - bh, h, err = GetBlockByHash(dbTx, hash) + bh, h, err = getBlockByHash(dbTx, hash) return err }) return bh, h, err @@ -1370,7 +1395,7 @@ func (s *ChainService) LatestBlock() (wire.BlockHeader, uint32, error) { var h uint32 var err error err = s.namespace.View(func(dbTx walletdb.Tx) error { - bh, h, err = LatestBlock(dbTx) + bh, h, err = latestBlock(dbTx) return err }) return bh, h, err @@ -1384,6 +1409,50 @@ func (s *ChainService) putBlock(header wire.BlockHeader, height uint32) error { }) } +// putBasicHeader puts a verified basic filter header in the ChainService +// database. +func (s *ChainService) putBasicHeader(blockHash chainhash.Hash, + filterTip chainhash.Hash) error { + return s.namespace.Update(func(dbTx walletdb.Tx) error { + return putBasicHeader(dbTx, blockHash, filterTip) + }) +} + +// putExtHeader puts a verified extended filter header in the ChainService +// database. +func (s *ChainService) putExtHeader(blockHash chainhash.Hash, + filterTip chainhash.Hash) error { + return s.namespace.Update(func(dbTx walletdb.Tx) error { + return putExtHeader(dbTx, blockHash, filterTip) + }) +} + +// GetBasicHeader gets a verified basic filter header from the ChainService +// database. +func (s *ChainService) GetBasicHeader(blockHash chainhash.Hash) (*chainhash.Hash, + error) { + var filterTip *chainhash.Hash + var err error + err = s.namespace.View(func(dbTx walletdb.Tx) error { + filterTip, err = getBasicHeader(dbTx, blockHash) + return err + }) + return filterTip, err +} + +// GetExtHeader gets a verified extended filter header from the ChainService +// database. +func (s *ChainService) GetExtHeader(blockHash chainhash.Hash) (*chainhash.Hash, + error) { + var filterTip *chainhash.Hash + var err error + err = s.namespace.View(func(dbTx walletdb.Tx) error { + filterTip, err = getExtHeader(dbTx, blockHash) + return err + }) + return filterTip, err +} + // putMaxBlockHeight puts the max block height to the ChainService database. func (s *ChainService) putMaxBlockHeight(maxBlockHeight uint32) error { return s.namespace.Update(func(dbTx walletdb.Tx) error { @@ -1405,7 +1474,7 @@ func (s *ChainService) rollbackToHeight(height uint32) (*waddrmgr.BlockStamp, er var bs *waddrmgr.BlockStamp var err error err = s.namespace.Update(func(dbTx walletdb.Tx) error { - bs, err = SyncedTo(dbTx) + bs, err = syncedTo(dbTx) if err != nil { return err } diff --git a/spvsvc/spvchain/sync_test.go b/spvsvc/spvchain/sync_test.go index d121308..18f7d10 100644 --- a/spvsvc/spvchain/sync_test.go +++ b/spvsvc/spvchain/sync_test.go @@ -7,15 +7,19 @@ import ( "testing" "time" + "github.com/aakselrod/btctestlog" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/rpctest" "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btclog" "github.com/btcsuite/btcwallet/spvsvc/spvchain" "github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/walletdb" _ "github.com/btcsuite/btcwallet/walletdb/bdb" ) +var logLevel = btclog.TraceLvl + func TestSetup(t *testing.T) { // Create a btcd SimNet node and generate 500 blocks h1, err := rpctest.New(&chaincfg.SimNetParams, nil, nil) @@ -129,14 +133,15 @@ func TestSetup(t *testing.T) { spvchain.Services = 0 spvchain.MaxPeers = 3 + spvchain.BanDuration = 5 * time.Second spvchain.RequiredServices = wire.SFNodeNetwork - /*logger, err := btctestlog.NewTestLogger(t) + logger, err := btctestlog.NewTestLogger(t) if err != nil { t.Fatalf("Could not set up logger: %s", err) } chainLogger := btclog.NewSubsystemLogger(logger, "CHAIN: ") - chainLogger.SetLevel(btclog.InfoLvl) - spvchain.UseLogger(chainLogger) //*/ + chainLogger.SetLevel(logLevel) + spvchain.UseLogger(chainLogger) svc, err := spvchain.NewChainService(config) if err != nil { t.Fatalf("Error creating ChainService: %s", err) @@ -223,7 +228,9 @@ func waitForSync(t *testing.T, svc *spvchain.ChainService, if err != nil { return err } - //t.Logf("Syncing to %d (%s)", knownBestHeight, knownBestHash) + if logLevel != btclog.Off { + t.Logf("Syncing to %d (%s)", knownBestHeight, knownBestHash) + } var haveBest *waddrmgr.BlockStamp haveBest, err = svc.BestSnapshot() if err != nil { @@ -244,7 +251,10 @@ func waitForSync(t *testing.T, svc *spvchain.ChainService, return fmt.Errorf("Couldn't get best snapshot from "+ "ChainService: %s", err) } - //t.Logf("Synced to %d (%s)", haveBest.Height, haveBest.Hash) + if logLevel != btclog.Off { + t.Logf("Synced to %d (%s)", haveBest.Height, + haveBest.Hash) + } } // Check if we're current if !svc.IsCurrent() {