diff --git a/spvsvc/spvchain/blockmanager.go b/spvsvc/spvchain/blockmanager.go index 6839872..d85d224 100644 --- a/spvsvc/spvchain/blockmanager.go +++ b/spvsvc/spvchain/blockmanager.go @@ -13,6 +13,8 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" + "github.com/btcsuite/btcutil/gcs" + "github.com/btcsuite/btcutil/gcs/builder" ) const ( @@ -1091,15 +1093,17 @@ func (b *blockManager) QueueCFHeaders(cfheaders *wire.MsgCFHeaders, // Check that the count is correct. This works even when the map lookup // fails as it returns 0 in that case. - if sp.requestedCFHeaders[cfhRequest{ + req := cfhRequest{ extended: cfheaders.Extended, stopHash: cfheaders.StopHash, - }] != len(cfheaders.HeaderHashes) { + } + if sp.requestedCFHeaders[req] != len(cfheaders.HeaderHashes) { log.Warnf("Received cfheaders message doesn't match any "+ "getcfheaders request. Peer %s is probably on a "+ "different chain -- ignoring", sp.Addr()) return } + delete(sp.requestedCFHeaders, req) // Track number of pending cfheaders messsages for both basic and // extended filters. @@ -1116,24 +1120,20 @@ func (b *blockManager) handleCFHeadersMsg(cfhmsg *cfheadersMsg) { // Grab the matching request we sent, as this message should correspond // to that, and delete it from the map on return as we're now handling // it. - req := cfhRequest{ - extended: cfhmsg.cfheaders.Extended, - stopHash: cfhmsg.cfheaders.StopHash, - } headerMap := b.basicHeaders pendingMsgs := &b.numBasicCFHeadersMsgs - if req.extended { + if cfhmsg.cfheaders.Extended { headerMap = b.extendedHeaders pendingMsgs = &b.numExtCFHeadersMsgs } - defer delete(cfhmsg.peer.requestedCFHeaders, req) atomic.AddInt32(pendingMsgs, -1) headerList := cfhmsg.cfheaders.HeaderHashes respLen := len(headerList) // Find the block header matching the last filter header, if any. el := b.headerList.Back() for el != nil { - if el.Value.(*headerNode).header.BlockHash() == req.stopHash { + if el.Value.(*headerNode).header.BlockHash() == + cfhmsg.cfheaders.StopHash { break } el = el.Prev() @@ -1166,12 +1166,13 @@ func (b *blockManager) handleCFHeadersMsg(cfhmsg *cfheadersMsg) { } b.intChan <- &processCFHeadersMsg{ earliestNode: node, - stopHash: req.stopHash, - extended: req.extended, + stopHash: cfhmsg.cfheaders.StopHash, + extended: cfhmsg.cfheaders.Extended, } log.Tracef("Processed cfheaders starting at %d(%s), ending at %s, from"+ " peer %s, extended: %t", node.height, node.header.BlockHash(), - req.stopHash, cfhmsg.peer.Addr(), req.extended) + cfhmsg.cfheaders.StopHash, cfhmsg.peer.Addr(), + cfhmsg.cfheaders.Extended) } // handleProcessCFHeadersMsg checks to see if we have enough cfheaders to make @@ -1278,7 +1279,7 @@ func (b *blockManager) handleProcessCFHeadersMsg(msg *processCFHeadersMsg) { *lastCFHeaderHeight = node.height // This is when we have conflicting information from // multiple peers. - // TODO: Handle this case. + // TODO: Handle this case as an adversarial condition. default: log.Warnf("Got more than 1 possible filter "+ "header for block %d (%s)", node.height, @@ -1311,12 +1312,77 @@ func (b *blockManager) QueueCFilter(cfilter *wire.MsgCFilter, sp *serverPeer) { return } + // Make sure we've actually requested this message. + req := cfRequest{ + extended: cfilter.Extended, + blockHash: cfilter.BlockHash, + } + if _, ok := sp.requestedCFilters[req]; !ok { + return + } + delete(sp.requestedCFilters, req) + b.peerChan <- &cfilterMsg{cfilter: cfilter, peer: sp} } // handleCFilterMsg handles cfilter messages from all peers. +// TODO: Refactor for checking adversarial conditions. func (b *blockManager) handleCFilterMsg(cfmsg *cfilterMsg) { - + readFunc := b.server.GetBasicHeader + putFunc := b.server.putBasicFilter + if cfmsg.cfilter.Extended { + readFunc = b.server.GetExtHeader + putFunc = b.server.putExtFilter + } + // Check that the cfilter we received fits correctly into the filter + // chain. + blockHeader, _, err := b.server.GetBlockByHash(cfmsg.cfilter.BlockHash) + if err != nil { + log.Warnf("Received cfilter for unknown block: %s, extended: "+ + "%t", cfmsg.cfilter.BlockHash, cfmsg.cfilter.Extended) + return + } + cfHeader, err := readFunc(cfmsg.cfilter.BlockHash) + if err != nil { + log.Warnf("Received cfilter for block with unknown cfheader: "+ + "%s, extended: %t", cfmsg.cfilter.BlockHash, + cfmsg.cfilter.Extended) + return + } + cfPrevHeader, err := readFunc(blockHeader.PrevBlock) + if err != nil { + log.Warnf("Received cfilter for block with unknown previous "+ + "cfheader: %s, extended: %t", blockHeader.PrevBlock, + cfmsg.cfilter.Extended) + return + } + filter, err := gcs.FromNBytes(builder.DefaultP, cfmsg.cfilter.Data) + if err != nil { + log.Warnf("Couldn't parse cfilter data for block: %s, "+ + "extended: %t", cfmsg.cfilter.BlockHash, + cfmsg.cfilter.Extended) + return + } + if makeHeaderForFilter(filter, *cfPrevHeader) != *cfHeader { + log.Warnf("Got cfilter that doesn't match cfheader chain for "+ + "block: %s, extended: %t", cfmsg.cfilter.BlockHash, + cfmsg.cfilter.Extended) + return + } + // Save the cfilter we received into the database. + err = putFunc(cfmsg.cfilter.BlockHash, filter) + if err != nil { + log.Warnf("Couldn't write cfilter to database for block: %s, "+ + "extended: %t", cfmsg.cfilter.BlockHash, + cfmsg.cfilter.Extended) + // Should we panic here? + return + } + // Notify the ChainService of the newly-found filter. + b.server.query <- processCFilterMsg{ + filter: filter, + extended: cfmsg.cfilter.Extended, + } } // checkHeaderSanity checks the PoW, and timestamp of a block header. diff --git a/spvsvc/spvchain/db.go b/spvsvc/spvchain/db.go index 0cbd786..5dcfecd 100644 --- a/spvsvc/spvchain/db.go +++ b/spvsvc/spvchain/db.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil/gcs" + "github.com/btcsuite/btcutil/gcs/builder" "github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/walletdb" ) @@ -153,8 +154,8 @@ func putExtFilter(tx walletdb.Tx, blockHash chainhash.Hash, return putFilter(tx, blockHash, extFilterBucketName, filter) } -// putHeader stores the provided filter, keyed to the block hash, in the -// appropriate filter bucket in the database. +// putHeader stores the provided header, keyed to the block hash, in the +// appropriate filter header bucket in the database. func putHeader(tx walletdb.Tx, blockHash chainhash.Hash, bucketName []byte, filterTip chainhash.Hash) error { @@ -168,22 +169,50 @@ func putHeader(tx walletdb.Tx, blockHash chainhash.Hash, bucketName []byte, return nil } -// putBasicHeader stores the provided filter, keyed to the block hash, in the -// basic filter bucket in the database. +// putBasicHeader stores the provided header, keyed to the block hash, in the +// basic filter header bucket in the database. func putBasicHeader(tx walletdb.Tx, blockHash chainhash.Hash, filterTip chainhash.Hash) error { return putHeader(tx, blockHash, basicHeaderBucketName, filterTip) } -// putExtHeader stores the provided filter, keyed to the block hash, in the -// extended filter bucket in the database. +// putExtHeader stores the provided header, keyed to the block hash, in the +// extended filter header bucket in the database. func putExtHeader(tx walletdb.Tx, blockHash chainhash.Hash, filterTip chainhash.Hash) error { return putHeader(tx, blockHash, extHeaderBucketName, filterTip) } -// getHeader retrieves the provided filter, keyed to the block hash, from the +// getFilter retreives the filter, keyed to the provided block hash, from the // appropriate filter bucket in the database. +func getFilter(tx walletdb.Tx, blockHash chainhash.Hash, + bucketName []byte) (*gcs.Filter, error) { + bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(bucketName) + + filterBytes := bucket.Get(blockHash[:]) + if len(filterBytes) == 0 { + return nil, fmt.Errorf("failed to get filter") + } + + return gcs.FromNBytes(builder.DefaultP, filterBytes) +} + +// getBasicFilter retrieves the filter, keyed to the provided block hash, from +// the basic filter bucket in the database. +func getBasicFilter(tx walletdb.Tx, blockHash chainhash.Hash) (*gcs.Filter, + error) { + return getFilter(tx, blockHash, basicFilterBucketName) +} + +// getExtFilter retrieves the filter, keyed to the provided block hash, from +// the extended filter bucket in the database. +func getExtFilter(tx walletdb.Tx, blockHash chainhash.Hash) (*gcs.Filter, + error) { + return getFilter(tx, blockHash, extFilterBucketName) +} + +// getHeader retrieves the header, keyed to the provided block hash, from the +// appropriate filter header bucket in the database. func getHeader(tx walletdb.Tx, blockHash chainhash.Hash, bucketName []byte) (*chainhash.Hash, error) { @@ -191,22 +220,21 @@ func getHeader(tx walletdb.Tx, blockHash chainhash.Hash, filterTip := bucket.Get(blockHash[:]) if len(filterTip) == 0 { - return &chainhash.Hash{}, - fmt.Errorf("failed to get filter header") + return nil, fmt.Errorf("failed to get filter header") } return chainhash.NewHash(filterTip) } -// getBasicHeader retrieves the provided filter, keyed to the block hash, from -// the basic filter bucket in the database. +// getBasicHeader retrieves the header, keyed to the provided block hash, from +// the basic filter header bucket in the database. func getBasicHeader(tx walletdb.Tx, blockHash chainhash.Hash) (*chainhash.Hash, error) { return getHeader(tx, blockHash, basicHeaderBucketName) } -// getExtHeader retrieves the provided filter, keyed to the block hash, from the -// extended filter bucket in the database. +// getExtHeader retrieves the header, keyed to the provided block hash, from the +// extended filter header bucket in the database. func getExtHeader(tx walletdb.Tx, blockHash chainhash.Hash) (*chainhash.Hash, error) { return getHeader(tx, blockHash, extHeaderBucketName) diff --git a/spvsvc/spvchain/notifications.go b/spvsvc/spvchain/notifications.go index f6668f7..f31dbdf 100644 --- a/spvsvc/spvchain/notifications.go +++ b/spvsvc/spvchain/notifications.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btcd/addrmgr" "github.com/btcsuite/btcd/connmgr" + "github.com/btcsuite/btcutil/gcs" ) type getConnCountMsg struct { @@ -48,6 +49,11 @@ type forAllPeersMsg struct { closure func(*serverPeer) } +type processCFilterMsg struct { + filter *gcs.Filter + extended bool +} + // handleQuery is the central handler for all queries and commands from other // goroutines related to peer state. func (s *ChainService) handleQuery(state *peerState, querymsg interface{}) { @@ -156,4 +162,6 @@ func (s *ChainService) handleQuery(state *peerState, querymsg interface{}) { // forAllPeers method doesn't return anything. An error might be // useful in the future. } + //case processCFilterMsg: + // TODO: make this work } diff --git a/spvsvc/spvchain/spvchain.go b/spvsvc/spvchain/spvchain.go index c96624c..73e2353 100644 --- a/spvsvc/spvchain/spvchain.go +++ b/spvsvc/spvchain/spvchain.go @@ -17,6 +17,7 @@ import ( "github.com/btcsuite/btcd/peer" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" + "github.com/btcsuite/btcutil/gcs" "github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/walletdb" @@ -113,6 +114,12 @@ type cfhRequest struct { stopHash chainhash.Hash } +// cfRequest records which cfilters we've requested. +type cfRequest struct { + extended bool + blockHash chainhash.Hash +} + // serverPeer extends the peer to maintain state shared by the server and // the blockmanager. type serverPeer struct { @@ -127,7 +134,7 @@ type serverPeer struct { continueHash *chainhash.Hash relayMtx sync.Mutex requestQueue []*wire.InvVect - requestedCFilters map[chainhash.Hash]bool + requestedCFilters map[cfRequest]struct{} requestedCFHeaders map[cfhRequest]int requestedBlocks map[chainhash.Hash]struct{} knownAddresses map[string]struct{} @@ -143,7 +150,7 @@ func newServerPeer(s *ChainService, isPersistent bool) *serverPeer { return &serverPeer{ server: s, persistent: isPersistent, - requestedCFilters: make(map[chainhash.Hash]bool), + requestedCFilters: make(map[cfRequest]struct{}), requestedBlocks: make(map[chainhash.Hash]struct{}), requestedCFHeaders: make(map[cfhRequest]int), knownAddresses: make(map[string]struct{}), @@ -224,6 +231,20 @@ func (sp *serverPeer) pushGetCFHeadersMsg(locator blockchain.BlockLocator, return nil } +// pushGetCFilterMsg sends a getcfilter message for the provided block hash to +// the connected peer. +func (sp *serverPeer) pushGetCFilterMsg(blockHash *chainhash.Hash, + ext bool) error { + req := cfRequest{ + extended: ext, + blockHash: *blockHash, + } + sp.requestedCFilters[req] = struct{}{} + msg := wire.NewMsgGetCFilter(blockHash, ext) + sp.QueueMessage(msg, nil) + return nil +} + // pushSendHeadersMsg sends a sendheaders message to the connected peer. func (sp *serverPeer) pushSendHeadersMsg() error { if sp.VersionKnown() { @@ -1453,6 +1474,46 @@ func (s *ChainService) GetExtHeader(blockHash chainhash.Hash) (*chainhash.Hash, return filterTip, err } +// putBasicFilter puts a verified basic filter in the ChainService database. +func (s *ChainService) putBasicFilter(blockHash chainhash.Hash, + filter *gcs.Filter) error { + return s.namespace.Update(func(dbTx walletdb.Tx) error { + return putBasicFilter(dbTx, blockHash, filter) + }) +} + +// putExtFilter puts a verified extended filter in the ChainService database. +func (s *ChainService) putExtFilter(blockHash chainhash.Hash, + filter *gcs.Filter) error { + return s.namespace.Update(func(dbTx walletdb.Tx) error { + return putExtFilter(dbTx, blockHash, filter) + }) +} + +// GetBasicFilter gets a verified basic filter from the ChainService database. +func (s *ChainService) GetBasicFilter(blockHash chainhash.Hash) (*gcs.Filter, + error) { + var filter *gcs.Filter + var err error + err = s.namespace.View(func(dbTx walletdb.Tx) error { + filter, err = getBasicFilter(dbTx, blockHash) + return err + }) + return filter, err +} + +// GetExtFilter gets a verified extended filter from the ChainService database. +func (s *ChainService) GetExtFilter(blockHash chainhash.Hash) (*gcs.Filter, + error) { + var filter *gcs.Filter + var err error + err = s.namespace.View(func(dbTx walletdb.Tx) error { + filter, err = getExtFilter(dbTx, blockHash) + return err + }) + return filter, err +} + // putMaxBlockHeight puts the max block height to the ChainService database. func (s *ChainService) putMaxBlockHeight(maxBlockHeight uint32) error { return s.namespace.Update(func(dbTx walletdb.Tx) error {