diff --git a/spvsvc/spvchain/blockmanager.go b/spvsvc/spvchain/blockmanager.go index b0818cb..20b0b63 100644 --- a/spvsvc/spvchain/blockmanager.go +++ b/spvsvc/spvchain/blockmanager.go @@ -934,7 +934,7 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) { // with the rest of the headers in the message as if // nothing has happened. b.syncPeer = hmsg.peer - _, err = b.server.rollbackToHeight(backHeight) + _, err = b.server.rollBackToHeight(backHeight) if err != nil { log.Criticalf("Rollback failed: %s", err) @@ -974,7 +974,8 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) { } // Verify the header at the next checkpoint height matches. - if b.nextCheckpoint != nil && node.height == b.nextCheckpoint.Height { + if b.nextCheckpoint != nil && + node.height == b.nextCheckpoint.Height { nodeHash := node.header.BlockHash() if nodeHash.IsEqual(b.nextCheckpoint.Hash) { receivedCheckpoint = true @@ -988,12 +989,14 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) { "disconnecting", node.height, nodeHash, hmsg.peer.Addr(), b.nextCheckpoint.Hash) - prevCheckpoint := b.findPreviousHeaderCheckpoint(node.height) + prevCheckpoint := + b.findPreviousHeaderCheckpoint( + node.height) log.Infof("Rolling back to previous validated "+ "checkpoint at height %d/hash %s", prevCheckpoint.Height, prevCheckpoint.Hash) - _, err := b.server.rollbackToHeight(uint32( + _, err := b.server.rollBackToHeight(uint32( prevCheckpoint.Height)) if err != nil { log.Criticalf("Rollback failed: %s", @@ -1371,7 +1374,7 @@ func (b *blockManager) calcNextRequiredDifficulty(newBlockTime time.Time, // Get the block node at the previous retarget (targetTimespan days // worth of blocks). - firstNode, _, err := b.server.GetBlockByHeight( + firstNode, err := b.server.GetBlockByHeight( uint32(lastNode.height + 1 - b.blocksPerRetarget)) if err != nil { return 0, err @@ -1451,7 +1454,8 @@ func (b *blockManager) findPrevTestNetDifficulty(hList *list.List) (uint32, erro if el != nil { iterNode = el.Value.(*headerNode).header } else { - node, _, err := b.server.GetBlockByHeight(uint32(iterHeight)) + node, err := b.server.GetBlockByHeight( + uint32(iterHeight)) if err != nil { log.Errorf("GetBlockByHeight: %s", err) return 0, err diff --git a/spvsvc/spvchain/db.go b/spvsvc/spvchain/db.go index c36aef3..945a051 100644 --- a/spvsvc/spvchain/db.go +++ b/spvsvc/spvchain/db.go @@ -7,7 +7,6 @@ import ( "time" "github.com/btcsuite/btcd/blockchain" - "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil/gcs" @@ -30,7 +29,7 @@ var ( // Key names for various database fields. var ( // Bucket names. - spvBucketName = []byte("spv") + spvBucketName = []byte("spvchain") blockHeaderBucketName = []byte("bh") basicHeaderBucketName = []byte("bfh") basicFilterBucketName = []byte("bf") @@ -59,475 +58,725 @@ func uint64ToBytes(number uint64) []byte { return buf } +// dbUpdateOption is a function type for the kind of DB update to be done. +// These can call each other and dbViewOption functions; however, they cannot +// be called by dbViewOption functions. +type dbUpdateOption func(bucket walletdb.ReadWriteBucket) error + +// dbViewOption is a funciton type for the kind of data to be fetched from DB. +// These can call each other and can be called by dbUpdateOption functions; +// however, they cannot call dbUpdateOption functions. +type dbViewOption func(bucket walletdb.ReadBucket) error + // fetchDBVersion fetches the current manager version from the database. -func fetchDBVersion(tx walletdb.Tx) (uint32, error) { - bucket := tx.RootBucket().Bucket(spvBucketName) - verBytes := bucket.Get(dbVersionName) - if verBytes == nil { - return 0, fmt.Errorf("required version number not stored in " + - "database") +func (s *ChainService) fetchDBVersion() (uint32, error) { + var version uint32 + err := s.dbView(fetchDBVersion(&version)) + return version, err +} + +func fetchDBVersion(version *uint32) dbViewOption { + return func(bucket walletdb.ReadBucket) error { + verBytes := bucket.Get(dbVersionName) + if verBytes == nil { + return fmt.Errorf("required version number not " + + "stored in database") + } + *version = binary.LittleEndian.Uint32(verBytes) + return nil } - version := binary.LittleEndian.Uint32(verBytes) - return version, nil } // putDBVersion stores the provided version to the database. -func putDBVersion(tx walletdb.Tx, version uint32) error { - bucket := tx.RootBucket().Bucket(spvBucketName) +func (s *ChainService) putDBVersion(version uint32) error { + return s.dbUpdate(putDBVersion(version)) +} - verBytes := uint32ToBytes(version) - return bucket.Put(dbVersionName, verBytes) +func putDBVersion(version uint32) dbUpdateOption { + return func(bucket walletdb.ReadWriteBucket) error { + verBytes := uint32ToBytes(version) + return bucket.Put(dbVersionName, verBytes) + } } // putMaxBlockHeight stores the max block height to the database. -func putMaxBlockHeight(tx walletdb.Tx, maxBlockHeight uint32) error { - bucket := tx.RootBucket().Bucket(spvBucketName) +func (s *ChainService) putMaxBlockHeight(maxBlockHeight uint32) error { + return s.dbUpdate(putMaxBlockHeight(maxBlockHeight)) +} - maxBlockHeightBytes := uint32ToBytes(maxBlockHeight) - err := bucket.Put(maxBlockHeightName, maxBlockHeightBytes) - if err != nil { - return fmt.Errorf("failed to store max block height: %s", err) +func putMaxBlockHeight(maxBlockHeight uint32) dbUpdateOption { + return func(bucket walletdb.ReadWriteBucket) error { + maxBlockHeightBytes := uint32ToBytes(maxBlockHeight) + err := bucket.Put(maxBlockHeightName, maxBlockHeightBytes) + if err != nil { + return fmt.Errorf("failed to store max block height: %s", err) + } + return nil } - return nil } // putBlock stores the provided block header and height, keyed to the block // hash, in the database. -func putBlock(tx walletdb.Tx, header wire.BlockHeader, height uint32) error { - var buf bytes.Buffer - err := header.Serialize(&buf) - if err != nil { - return err - } - _, err = buf.Write(uint32ToBytes(height)) - if err != nil { - return err - } +func (s *ChainService) putBlock(header wire.BlockHeader, height uint32) error { + return s.dbUpdate(putBlock(header, height)) +} - bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(blockHeaderBucketName) - blockHash := header.BlockHash() - - err = bucket.Put(blockHash[:], buf.Bytes()) - if err != nil { - return fmt.Errorf("failed to store SPV block info: %s", err) +func putBlock(header wire.BlockHeader, height uint32) dbUpdateOption { + return func(bucket walletdb.ReadWriteBucket) error { + var buf bytes.Buffer + err := header.Serialize(&buf) + if err != nil { + return err + } + _, err = buf.Write(uint32ToBytes(height)) + if err != nil { + return err + } + blockHash := header.BlockHash() + bhBucket := bucket.NestedReadWriteBucket(blockHeaderBucketName) + err = bhBucket.Put(blockHash[:], buf.Bytes()) + if err != nil { + return fmt.Errorf("failed to store SPV block info: %s", + err) + } + err = bhBucket.Put(uint32ToBytes(height), blockHash[:]) + if err != nil { + return fmt.Errorf("failed to store block height info:"+ + " %s", err) + } + return nil } - - err = bucket.Put(uint32ToBytes(height), blockHash[:]) - if err != nil { - return fmt.Errorf("failed to store block height info: %s", err) - } - - return nil } // putFilter stores the provided filter, keyed to the block hash, in the // appropriate filter bucket in the database. -func putFilter(tx walletdb.Tx, blockHash chainhash.Hash, bucketName []byte, +func (s *ChainService) putFilter(blockHash chainhash.Hash, bucketName []byte, filter *gcs.Filter) error { - var buf bytes.Buffer - _, err := buf.Write(filter.NBytes()) - if err != nil { - return err + return s.dbUpdate(putFilter(blockHash, bucketName, filter)) +} + +func putFilter(blockHash chainhash.Hash, bucketName []byte, + filter *gcs.Filter) dbUpdateOption { + return func(bucket walletdb.ReadWriteBucket) error { + var buf bytes.Buffer + _, err := buf.Write(filter.NBytes()) + if err != nil { + return err + } + filterBucket := bucket.NestedReadWriteBucket(bucketName) + err = filterBucket.Put(blockHash[:], buf.Bytes()) + if err != nil { + return fmt.Errorf("failed to store filter: %s", err) + } + return nil } - - bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(bucketName) - - err = bucket.Put(blockHash[:], buf.Bytes()) - if err != nil { - return fmt.Errorf("failed to store filter: %s", err) - } - - return nil } // putBasicFilter stores the provided filter, keyed to the block hash, in the // basic filter bucket in the database. -func putBasicFilter(tx walletdb.Tx, blockHash chainhash.Hash, +func (s *ChainService) putBasicFilter(blockHash chainhash.Hash, filter *gcs.Filter) error { - return putFilter(tx, blockHash, basicFilterBucketName, filter) + return s.dbUpdate(putBasicFilter(blockHash, filter)) +} + +func putBasicFilter(blockHash chainhash.Hash, + filter *gcs.Filter) dbUpdateOption { + return putFilter(blockHash, basicFilterBucketName, filter) } // putExtFilter stores the provided filter, keyed to the block hash, in the // extended filter bucket in the database. -func putExtFilter(tx walletdb.Tx, blockHash chainhash.Hash, +func (s *ChainService) putExtFilter(blockHash chainhash.Hash, filter *gcs.Filter) error { - return putFilter(tx, blockHash, extFilterBucketName, filter) + return s.dbUpdate(putExtFilter(blockHash, filter)) +} + +func putExtFilter(blockHash chainhash.Hash, + filter *gcs.Filter) dbUpdateOption { + return putFilter(blockHash, extFilterBucketName, filter) } // putHeader stores the provided header, keyed to the block hash, in the // appropriate filter header bucket in the database. -func putHeader(tx walletdb.Tx, blockHash chainhash.Hash, bucketName []byte, +func (s *ChainService) putHeader(blockHash chainhash.Hash, bucketName []byte, filterTip chainhash.Hash) error { + return s.dbUpdate(putHeader(blockHash, bucketName, filterTip)) +} - bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(bucketName) - - err := bucket.Put(blockHash[:], filterTip[:]) - if err != nil { - return fmt.Errorf("failed to store filter header: %s", err) +func putHeader(blockHash chainhash.Hash, bucketName []byte, + filterTip chainhash.Hash) dbUpdateOption { + return func(bucket walletdb.ReadWriteBucket) error { + headerBucket := bucket.NestedReadWriteBucket(bucketName) + err := headerBucket.Put(blockHash[:], filterTip[:]) + if err != nil { + return fmt.Errorf("failed to store filter header: %s", err) + } + return nil } - - return nil } // putBasicHeader stores the provided header, keyed to the block hash, in the // basic filter header bucket in the database. -func putBasicHeader(tx walletdb.Tx, blockHash chainhash.Hash, +func (s *ChainService) putBasicHeader(blockHash chainhash.Hash, filterTip chainhash.Hash) error { - return putHeader(tx, blockHash, basicHeaderBucketName, filterTip) + return s.dbUpdate(putBasicHeader(blockHash, filterTip)) +} + +func putBasicHeader(blockHash chainhash.Hash, + filterTip chainhash.Hash) dbUpdateOption { + return putHeader(blockHash, basicHeaderBucketName, filterTip) } // putExtHeader stores the provided header, keyed to the block hash, in the // extended filter header bucket in the database. -func putExtHeader(tx walletdb.Tx, blockHash chainhash.Hash, +func (s *ChainService) putExtHeader(blockHash chainhash.Hash, filterTip chainhash.Hash) error { - return putHeader(tx, blockHash, extHeaderBucketName, filterTip) + return s.dbUpdate(putExtHeader(blockHash, filterTip)) +} + +func putExtHeader(blockHash chainhash.Hash, + filterTip chainhash.Hash) dbUpdateOption { + return putHeader(blockHash, extHeaderBucketName, filterTip) } // getFilter retreives the filter, keyed to the provided block hash, from the // appropriate filter bucket in the database. -func getFilter(tx walletdb.Tx, blockHash chainhash.Hash, +func (s *ChainService) getFilter(blockHash chainhash.Hash, bucketName []byte) (*gcs.Filter, error) { - bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(bucketName) + var filter gcs.Filter + err := s.dbView(getFilter(blockHash, bucketName, &filter)) + return &filter, err +} - filterBytes := bucket.Get(blockHash[:]) - if len(filterBytes) == 0 { - return nil, fmt.Errorf("failed to get filter") +func getFilter(blockHash chainhash.Hash, bucketName []byte, + filter *gcs.Filter) dbViewOption { + return func(bucket walletdb.ReadBucket) error { + filterBucket := bucket.NestedReadBucket(bucketName) + filterBytes := filterBucket.Get(blockHash[:]) + if len(filterBytes) == 0 { + return fmt.Errorf("failed to get filter") + } + calcFilter, err := gcs.FromNBytes(builder.DefaultP, filterBytes) + if calcFilter != nil { + *filter = *calcFilter + } + return err } - - return gcs.FromNBytes(builder.DefaultP, filterBytes) } -// getBasicFilter retrieves the filter, keyed to the provided block hash, from +// GetBasicFilter retrieves the filter, keyed to the provided block hash, from // the basic filter bucket in the database. -func getBasicFilter(tx walletdb.Tx, blockHash chainhash.Hash) (*gcs.Filter, +func (s *ChainService) GetBasicFilter(blockHash chainhash.Hash) (*gcs.Filter, error) { - return getFilter(tx, blockHash, basicFilterBucketName) + var filter gcs.Filter + err := s.dbView(getBasicFilter(blockHash, &filter)) + return &filter, err } -// getExtFilter retrieves the filter, keyed to the provided block hash, from +func getBasicFilter(blockHash chainhash.Hash, filter *gcs.Filter) dbViewOption { + return getFilter(blockHash, basicFilterBucketName, filter) +} + +// GetExtFilter retrieves the filter, keyed to the provided block hash, from // the extended filter bucket in the database. -func getExtFilter(tx walletdb.Tx, blockHash chainhash.Hash) (*gcs.Filter, +func (s *ChainService) GetExtFilter(blockHash chainhash.Hash) (*gcs.Filter, error) { - return getFilter(tx, blockHash, extFilterBucketName) + var filter gcs.Filter + err := s.dbView(getExtFilter(blockHash, &filter)) + return &filter, err +} + +func getExtFilter(blockHash chainhash.Hash, filter *gcs.Filter) dbViewOption { + return getFilter(blockHash, extFilterBucketName, filter) } // getHeader retrieves the header, keyed to the provided block hash, from the // appropriate filter header bucket in the database. -func getHeader(tx walletdb.Tx, blockHash chainhash.Hash, +func (s *ChainService) getHeader(blockHash chainhash.Hash, bucketName []byte) (*chainhash.Hash, error) { - - bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(bucketName) - - filterTip := bucket.Get(blockHash[:]) - if len(filterTip) == 0 { - return nil, fmt.Errorf("failed to get filter header") - } - - return chainhash.NewHash(filterTip) + var filterTip chainhash.Hash + err := s.dbView(getHeader(blockHash, bucketName, &filterTip)) + return &filterTip, err } -// getBasicHeader retrieves the header, keyed to the provided block hash, from +func getHeader(blockHash chainhash.Hash, bucketName []byte, + filterTip *chainhash.Hash) dbViewOption { + return func(bucket walletdb.ReadBucket) error { + headerBucket := bucket.NestedReadBucket(bucketName) + headerBytes := headerBucket.Get(blockHash[:]) + if len(filterTip) == 0 { + return fmt.Errorf("failed to get filter header") + } + calcFilterTip, err := chainhash.NewHash(headerBytes) + if calcFilterTip != nil { + *filterTip = *calcFilterTip + } + return err + } +} + +// GetBasicHeader retrieves the header, keyed to the provided block hash, from // the basic filter header bucket in the database. -func getBasicHeader(tx walletdb.Tx, blockHash chainhash.Hash) (*chainhash.Hash, - error) { - return getHeader(tx, blockHash, basicHeaderBucketName) +func (s *ChainService) GetBasicHeader(blockHash chainhash.Hash) ( + *chainhash.Hash, error) { + var filterTip chainhash.Hash + err := s.dbView(getBasicHeader(blockHash, &filterTip)) + return &filterTip, err } -// getExtHeader retrieves the header, keyed to the provided block hash, from the +func getBasicHeader(blockHash chainhash.Hash, + filterTip *chainhash.Hash) dbViewOption { + return getHeader(blockHash, basicHeaderBucketName, filterTip) +} + +// GetExtHeader retrieves the header, keyed to the provided block hash, from the // extended filter header bucket in the database. -func getExtHeader(tx walletdb.Tx, blockHash chainhash.Hash) (*chainhash.Hash, +func (s *ChainService) GetExtHeader(blockHash chainhash.Hash) (*chainhash.Hash, error) { - return getHeader(tx, blockHash, extHeaderBucketName) + var filterTip chainhash.Hash + err := s.dbView(getExtHeader(blockHash, &filterTip)) + return &filterTip, err } -// rollbackLastBlock rolls back the last known block and returns the BlockStamp +func getExtHeader(blockHash chainhash.Hash, + filterTip *chainhash.Hash) dbViewOption { + return getHeader(blockHash, extHeaderBucketName, filterTip) +} + +// rollBackLastBlock rolls back the last known block and returns the BlockStamp // representing the new last known block. -func rollbackLastBlock(tx walletdb.Tx) (*waddrmgr.BlockStamp, error) { - bs, err := syncedTo(tx) - if err != nil { - return nil, err - } - bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(blockHeaderBucketName) - err = bucket.Delete(bs.Hash[:]) - if err != nil { - return nil, err - } - err = bucket.Delete(uint32ToBytes(uint32(bs.Height))) - if err != nil { - return nil, err - } - err = putMaxBlockHeight(tx, uint32(bs.Height-1)) - if err != nil { - return nil, err - } - return syncedTo(tx) +func (s *ChainService) rollBackLastBlock() (*waddrmgr.BlockStamp, error) { + var bs waddrmgr.BlockStamp + err := s.dbUpdate(rollBackLastBlock(&bs)) + return &bs, err } -// getBlockByHash retrieves the block header, filter, and filter tip, based on +func rollBackLastBlock(bs *waddrmgr.BlockStamp) dbUpdateOption { + return func(bucket walletdb.ReadWriteBucket) error { + headerBucket := bucket.NestedReadWriteBucket( + blockHeaderBucketName) + var sync waddrmgr.BlockStamp + err := syncedTo(&sync)(bucket) + if err != nil { + return err + } + err = headerBucket.Delete(sync.Hash[:]) + if err != nil { + return err + } + err = headerBucket.Delete(uint32ToBytes(uint32(sync.Height))) + if err != nil { + return err + } + err = putMaxBlockHeight(uint32(sync.Height - 1))(bucket) + if err != nil { + return err + } + sync = waddrmgr.BlockStamp{} + err = syncedTo(&sync)(bucket) + if sync != (waddrmgr.BlockStamp{}) { + *bs = sync + } + return err + } +} + +// rollBackToHeight rolls back all blocks until it hits the specified height. +func (s *ChainService) rollBackToHeight(height uint32) (*waddrmgr.BlockStamp, error) { + var bs waddrmgr.BlockStamp + err := s.dbUpdate(rollBackToHeight(height, &bs)) + return &bs, err +} + +func rollBackToHeight(height uint32, bs *waddrmgr.BlockStamp) dbUpdateOption { + return func(bucket walletdb.ReadWriteBucket) error { + err := syncedTo(bs)(bucket) + if err != nil { + return err + } + for uint32(bs.Height) > height { + err = rollBackLastBlock(bs)(bucket) + if err != nil { + return err + } + } + return nil + } +} + +// GetBlockByHash retrieves the block header, filter, and filter tip, based on // the provided block hash, from the database. -func getBlockByHash(tx walletdb.Tx, blockHash chainhash.Hash) (wire.BlockHeader, - uint32, error) { - //chainhash.Hash, chainhash.Hash, - bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(blockHeaderBucketName) - blockBytes := bucket.Get(blockHash[:]) - if len(blockBytes) == 0 { - return wire.BlockHeader{}, 0, - fmt.Errorf("failed to retrieve block info for hash: %s", - blockHash) - } - - buf := bytes.NewReader(blockBytes[:wire.MaxBlockHeaderPayload]) +func (s *ChainService) GetBlockByHash(blockHash chainhash.Hash) ( + wire.BlockHeader, uint32, error) { var header wire.BlockHeader - err := header.Deserialize(buf) - if err != nil { - return wire.BlockHeader{}, 0, - fmt.Errorf("failed to deserialize block header for "+ - "hash: %s", blockHash) - } - - height := binary.LittleEndian.Uint32( - blockBytes[wire.MaxBlockHeaderPayload : wire.MaxBlockHeaderPayload+4]) - - return header, height, nil + var height uint32 + err := s.dbView(getBlockByHash(blockHash, &header, &height)) + return header, height, err } -// getBlockHashByHeight retrieves the hash of a block by its height. -func getBlockHashByHeight(tx walletdb.Tx, height uint32) (chainhash.Hash, +func getBlockByHash(blockHash chainhash.Hash, header *wire.BlockHeader, + height *uint32) dbViewOption { + return func(bucket walletdb.ReadBucket) error { + headerBucket := bucket.NestedReadBucket(blockHeaderBucketName) + blockBytes := headerBucket.Get(blockHash[:]) + if len(blockBytes) < wire.MaxBlockHeaderPayload+4 { + return fmt.Errorf("failed to retrieve block info for"+ + " hash %s: want %d bytes, got %d.", blockHash, + wire.MaxBlockHeaderPayload+4, len(blockBytes)) + } + buf := bytes.NewReader(blockBytes[:wire.MaxBlockHeaderPayload]) + err := header.Deserialize(buf) + if err != nil { + return fmt.Errorf("failed to deserialize block header "+ + "for hash: %s", blockHash) + } + *height = binary.LittleEndian.Uint32( + blockBytes[wire.MaxBlockHeaderPayload : wire.MaxBlockHeaderPayload+4]) + return nil + } +} + +// GetBlockHashByHeight retrieves the hash of a block by its height. +func (s *ChainService) GetBlockHashByHeight(height uint32) (chainhash.Hash, error) { - bucket := tx.RootBucket().Bucket(spvBucketName).Bucket(blockHeaderBucketName) - var hash chainhash.Hash - hashBytes := bucket.Get(uint32ToBytes(height)) - if hashBytes == nil { - return hash, fmt.Errorf("no block hash for height %d", height) - } - hash.SetBytes(hashBytes) - return hash, nil + var blockHash chainhash.Hash + err := s.dbView(getBlockHashByHeight(height, &blockHash)) + return blockHash, err } -// getBlockByHeight retrieves a block's information by its height. -func getBlockByHeight(tx walletdb.Tx, height uint32) (wire.BlockHeader, uint32, +func getBlockHashByHeight(height uint32, + blockHash *chainhash.Hash) dbViewOption { + return func(bucket walletdb.ReadBucket) error { + headerBucket := bucket.NestedReadBucket(blockHeaderBucketName) + hashBytes := headerBucket.Get(uint32ToBytes(height)) + if hashBytes == nil { + return fmt.Errorf("no block hash for height %d", height) + } + blockHash.SetBytes(hashBytes) + return nil + } +} + +// GetBlockByHeight retrieves a block's information by its height. +func (s *ChainService) GetBlockByHeight(height uint32) (wire.BlockHeader, error) { - // chainhash.Hash, chainhash.Hash - blockHash, err := getBlockHashByHeight(tx, height) - if err != nil { - return wire.BlockHeader{}, 0, err - } - - return getBlockByHash(tx, blockHash) + var header wire.BlockHeader + err := s.dbView(getBlockByHeight(height, &header)) + return header, err } -// syncedTo retrieves the most recent block's height and hash. -func syncedTo(tx walletdb.Tx) (*waddrmgr.BlockStamp, error) { - header, height, err := latestBlock(tx) - if err != nil { - return nil, err +func getBlockByHeight(height uint32, header *wire.BlockHeader) dbViewOption { + return func(bucket walletdb.ReadBucket) error { + var blockHash chainhash.Hash + err := getBlockHashByHeight(height, &blockHash)(bucket) + if err != nil { + return err + } + var gotHeight uint32 + err = getBlockByHash(blockHash, header, &gotHeight)(bucket) + if err != nil { + return err + } + if gotHeight != height { + return fmt.Errorf("Got height %d for block at "+ + "requested height %d", gotHeight, height) + } + return nil } - var blockStamp waddrmgr.BlockStamp - blockStamp.Hash = header.BlockHash() - blockStamp.Height = int32(height) - return &blockStamp, nil } -// latestBlock retrieves all the info about the latest stored block. -func latestBlock(tx walletdb.Tx) (wire.BlockHeader, uint32, error) { - bucket := tx.RootBucket().Bucket(spvBucketName) +// BestSnapshot is a synonym for SyncedTo +func (s *ChainService) BestSnapshot() (*waddrmgr.BlockStamp, error) { + return s.SyncedTo() +} - maxBlockHeightBytes := bucket.Get(maxBlockHeightName) - if maxBlockHeightBytes == nil { - return wire.BlockHeader{}, 0, - fmt.Errorf("no max block height stored") - } +// SyncedTo retrieves the most recent block's height and hash. +func (s *ChainService) SyncedTo() (*waddrmgr.BlockStamp, error) { + var bs waddrmgr.BlockStamp + err := s.dbView(syncedTo(&bs)) + return &bs, err +} - maxBlockHeight := binary.LittleEndian.Uint32(maxBlockHeightBytes) - header, height, err := getBlockByHeight(tx, maxBlockHeight) - if err != nil { - return wire.BlockHeader{}, 0, err +func syncedTo(bs *waddrmgr.BlockStamp) dbViewOption { + return func(bucket walletdb.ReadBucket) error { + var header wire.BlockHeader + var height uint32 + err := latestBlock(&header, &height)(bucket) + if err != nil { + return err + } + bs.Hash = header.BlockHash() + bs.Height = int32(height) + return nil } - if height != maxBlockHeight { - return wire.BlockHeader{}, 0, - fmt.Errorf("max block height inconsistent") +} + +// LatestBlock retrieves latest stored block's header and height. +func (s *ChainService) LatestBlock() (wire.BlockHeader, uint32, error) { + var bh wire.BlockHeader + var h uint32 + err := s.dbView(latestBlock(&bh, &h)) + return bh, h, err +} + +func latestBlock(header *wire.BlockHeader, height *uint32) dbViewOption { + return func(bucket walletdb.ReadBucket) error { + maxBlockHeightBytes := bucket.Get(maxBlockHeightName) + if maxBlockHeightBytes == nil { + return fmt.Errorf("no max block height stored") + } + *height = binary.LittleEndian.Uint32(maxBlockHeightBytes) + return getBlockByHeight(*height, header)(bucket) + } +} + +// BlockLocatorFromHash returns a block locator based on the provided hash. +func (s *ChainService) BlockLocatorFromHash(hash chainhash.Hash) ( + blockchain.BlockLocator, error) { + var locator blockchain.BlockLocator + err := s.dbView(blockLocatorFromHash(hash, &locator)) + return locator, err +} + +func blockLocatorFromHash(hash chainhash.Hash, + locator *blockchain.BlockLocator) dbViewOption { + return func(bucket walletdb.ReadBucket) error { + // Append the initial hash + *locator = append(*locator, &hash) + // If hash isn't found in DB or this is the genesis block, return + // the locator as is + var header wire.BlockHeader + var height uint32 + err := getBlockByHash(hash, &header, &height)(bucket) + if (err != nil) || (height == 0) { + return nil + } + + decrement := uint32(1) + for (height > 0) && (len(*locator) < wire.MaxBlockLocatorsPerMsg) { + // Decrement by 1 for the first 10 blocks, then double the + // jump until we get to the genesis hash + if len(*locator) > 10 { + decrement *= 2 + } + if decrement > height { + height = 0 + } else { + height -= decrement + } + var blockHash chainhash.Hash + err := getBlockHashByHeight(height, &blockHash)(bucket) + if err != nil { + return nil + } + *locator = append(*locator, &blockHash) + } + return nil + } +} + +// LatestBlockLocator returns the block locator for the latest known block +// stored in the database. +func (s *ChainService) LatestBlockLocator() (blockchain.BlockLocator, error) { + var locator blockchain.BlockLocator + err := s.dbView(latestBlockLocator(&locator)) + return locator, err +} + +func latestBlockLocator(locator *blockchain.BlockLocator) dbViewOption { + return func(bucket walletdb.ReadBucket) error { + var best waddrmgr.BlockStamp + err := syncedTo(&best)(bucket) + if err != nil { + return err + } + return blockLocatorFromHash(best.Hash, locator)(bucket) } - return header, height, nil } // CheckConnectivity cycles through all of the block headers, from last to // first, and makes sure they all connect to each other. -func CheckConnectivity(tx walletdb.Tx) error { - header, height, err := latestBlock(tx) - if err != nil { - return fmt.Errorf("Couldn't retrieve latest block: %s", err) - } - for height > 0 { - newheader, newheight, err := getBlockByHash(tx, - header.PrevBlock) - if err != nil { - return fmt.Errorf("Couldn't retrieve block %s: %s", - header.PrevBlock, err) - } - if newheader.BlockHash() != header.PrevBlock { - return fmt.Errorf("Block %s doesn't match block %s's "+ - "PrevBlock (%s)", newheader.BlockHash(), - header.BlockHash(), header.PrevBlock) - } - if newheight != height-1 { - return fmt.Errorf("Block %s doesn't have correct "+ - "height: want %d, got %d", - newheader.BlockHash(), height-1, newheight) - } - header = newheader - height = newheight - } - return nil +func (s *ChainService) CheckConnectivity() error { + return s.dbView(checkConnectivity()) } -// blockLocatorFromHash returns a block locator based on the provided hash. -func blockLocatorFromHash(tx walletdb.Tx, hash chainhash.Hash) blockchain.BlockLocator { - locator := make(blockchain.BlockLocator, 0, wire.MaxBlockLocatorsPerMsg) - locator = append(locator, &hash) - - // If hash isn't found in DB or this is the genesis block, return - // the locator as is - _, height, err := getBlockByHash(tx, hash) - if (err != nil) || (height == 0) { - return locator - } - - decrement := uint32(1) - for (height > 0) && (len(locator) < wire.MaxBlockLocatorsPerMsg) { - // Decrement by 1 for the first 10 blocks, then double the - // jump until we get to the genesis hash - if len(locator) > 10 { - decrement *= 2 - } - if decrement > height { - height = 0 - } else { - height -= decrement - } - blockHash, err := getBlockHashByHeight(tx, height) +func checkConnectivity() dbViewOption { + return func(bucket walletdb.ReadBucket) error { + var header wire.BlockHeader + var height uint32 + err := latestBlock(&header, &height)(bucket) if err != nil { - return locator + return fmt.Errorf("Couldn't retrieve latest block: %s", + err) } - locator = append(locator, &blockHash) + for height > 0 { + var newHeader wire.BlockHeader + var newHeight uint32 + err := getBlockByHash(header.PrevBlock, &newHeader, + &newHeight)(bucket) + if err != nil { + return fmt.Errorf("Couldn't retrieve block %s:"+ + " %s", header.PrevBlock, err) + } + if newHeader.BlockHash() != header.PrevBlock { + return fmt.Errorf("Block %s doesn't match "+ + "block %s's PrevBlock (%s)", + newHeader.BlockHash(), + header.BlockHash(), header.PrevBlock) + } + if newHeight != height-1 { + return fmt.Errorf("Block %s doesn't have "+ + "correct height: want %d, got %d", + newHeader.BlockHash(), height-1, + newHeight) + } + header = newHeader + height = newHeight + } + return nil } - - return locator } // createSPVNS creates the initial namespace structure needed for all of the // SPV-related data. This includes things such as all of the buckets as well as // the version and creation date. -func createSPVNS(namespace walletdb.Namespace, params *chaincfg.Params) error { - err := namespace.Update(func(tx walletdb.Tx) error { - rootBucket := tx.RootBucket() - spvBucket, err := rootBucket.CreateBucketIfNotExists(spvBucketName) - if err != nil { - return fmt.Errorf("failed to create main bucket: %s", - err) - } - - _, err = spvBucket.CreateBucketIfNotExists(blockHeaderBucketName) - if err != nil { - return fmt.Errorf("failed to create block header "+ - "bucket: %s", err) - } - - _, err = spvBucket.CreateBucketIfNotExists(basicFilterBucketName) - if err != nil { - return fmt.Errorf("failed to create basic filter "+ - "bucket: %s", err) - } - - _, err = spvBucket.CreateBucketIfNotExists(basicHeaderBucketName) - if err != nil { - return fmt.Errorf("failed to create basic header "+ - "bucket: %s", err) - } - - _, err = spvBucket.CreateBucketIfNotExists(extFilterBucketName) - if err != nil { - return fmt.Errorf("failed to create extended filter "+ - "bucket: %s", err) - } - - _, err = spvBucket.CreateBucketIfNotExists(extHeaderBucketName) - if err != nil { - return fmt.Errorf("failed to create extended header "+ - "bucket: %s", err) - } - - createDate := spvBucket.Get(dbCreateDateName) - if createDate != nil { - log.Info("Wallet SPV namespace already created.") - return nil - } - - log.Info("Creating wallet SPV namespace.") - - basicFilter, err := builder.BuildBasicFilter( - params.GenesisBlock) - if err != nil { - return err - } - - basicFilterTip := builder.MakeHeaderForFilter(basicFilter, - params.GenesisBlock.Header.PrevBlock) - - extFilter, err := builder.BuildExtFilter(params.GenesisBlock) - if err != nil { - return err - } - - extFilterTip := builder.MakeHeaderForFilter(extFilter, - params.GenesisBlock.Header.PrevBlock) - - err = putBlock(tx, params.GenesisBlock.Header, 0) - if err != nil { - return err - } - - err = putBasicFilter(tx, *params.GenesisHash, basicFilter) - if err != nil { - return err - } - - err = putBasicHeader(tx, *params.GenesisHash, basicFilterTip) - if err != nil { - return err - } - - err = putExtFilter(tx, *params.GenesisHash, extFilter) - if err != nil { - return err - } - - err = putExtHeader(tx, *params.GenesisHash, extFilterTip) - if err != nil { - return err - } - - err = putDBVersion(tx, latestDBVersion) - if err != nil { - return err - } - - err = putMaxBlockHeight(tx, 0) - if err != nil { - return err - } - - err = spvBucket.Put(dbCreateDateName, - uint64ToBytes(uint64(time.Now().Unix()))) - if err != nil { - return fmt.Errorf("failed to store database creation "+ - "time: %s", err) - } - - return nil - }) +func (s *ChainService) createSPVNS() error { + tx, err := s.db.BeginReadWriteTx() if err != nil { - return fmt.Errorf("failed to update database: %s", err) + return err } - return nil + spvBucket, err := tx.CreateTopLevelBucket(spvBucketName) + if err != nil { + return fmt.Errorf("failed to create main bucket: %s", err) + } + + _, err = spvBucket.CreateBucketIfNotExists(blockHeaderBucketName) + if err != nil { + return fmt.Errorf("failed to create block header bucket: %s", + err) + } + + _, err = spvBucket.CreateBucketIfNotExists(basicFilterBucketName) + if err != nil { + return fmt.Errorf("failed to create basic filter "+ + "bucket: %s", err) + } + + _, err = spvBucket.CreateBucketIfNotExists(basicHeaderBucketName) + if err != nil { + return fmt.Errorf("failed to create basic header "+ + "bucket: %s", err) + } + + _, err = spvBucket.CreateBucketIfNotExists(extFilterBucketName) + if err != nil { + return fmt.Errorf("failed to create extended filter "+ + "bucket: %s", err) + } + + _, err = spvBucket.CreateBucketIfNotExists(extHeaderBucketName) + if err != nil { + return fmt.Errorf("failed to create extended header "+ + "bucket: %s", err) + } + + createDate := spvBucket.Get(dbCreateDateName) + if createDate != nil { + log.Info("Wallet SPV namespace already created.") + return nil + } + + log.Info("Creating wallet SPV namespace.") + + basicFilter, err := builder.BuildBasicFilter( + s.chainParams.GenesisBlock) + if err != nil { + return err + } + + basicFilterTip := builder.MakeHeaderForFilter(basicFilter, + s.chainParams.GenesisBlock.Header.PrevBlock) + + extFilter, err := builder.BuildExtFilter( + s.chainParams.GenesisBlock) + if err != nil { + return err + } + + extFilterTip := builder.MakeHeaderForFilter(extFilter, + s.chainParams.GenesisBlock.Header.PrevBlock) + + err = putBlock(s.chainParams.GenesisBlock.Header, 0)(spvBucket) + if err != nil { + return err + } + + err = putBasicFilter(*s.chainParams.GenesisHash, basicFilter)(spvBucket) + if err != nil { + return err + } + + err = putBasicHeader(*s.chainParams.GenesisHash, basicFilterTip)( + spvBucket) + if err != nil { + return err + } + + err = putExtFilter(*s.chainParams.GenesisHash, extFilter)(spvBucket) + if err != nil { + return err + } + + err = putExtHeader(*s.chainParams.GenesisHash, extFilterTip)(spvBucket) + if err != nil { + return err + } + + err = putDBVersion(latestDBVersion)(spvBucket) + if err != nil { + return err + } + + err = putMaxBlockHeight(0)(spvBucket) + if err != nil { + return err + } + + err = spvBucket.Put(dbCreateDateName, + uint64ToBytes(uint64(time.Now().Unix()))) + if err != nil { + return fmt.Errorf("failed to store database creation "+ + "time: %s", err) + } + + return tx.Commit() +} + +// dbUpdate allows the passed function to update the ChainService DB bucket. +func (s *ChainService) dbUpdate(updateFunc dbUpdateOption) error { + tx, err := s.db.BeginReadWriteTx() + if err != nil { + tx.Rollback() + return err + } + bucket := tx.ReadWriteBucket(spvBucketName) + err = updateFunc(bucket) + if err != nil { + tx.Rollback() + return err + } + return tx.Commit() +} + +// dbView allows the passed function to read the ChainService DB bucket. +func (s *ChainService) dbView(viewFunc dbViewOption) error { + tx, err := s.db.BeginReadTx() + defer tx.Rollback() + if err != nil { + return err + } + bucket := tx.ReadBucket(spvBucketName) + return viewFunc(bucket) + } diff --git a/spvsvc/spvchain/notifications.go b/spvsvc/spvchain/notifications.go index c38c2a1..d887a4d 100644 --- a/spvsvc/spvchain/notifications.go +++ b/spvsvc/spvchain/notifications.go @@ -162,3 +162,112 @@ func (s *ChainService) handleQuery(state *peerState, querymsg interface{}) { // useful in the future. } } + +// ConnectedCount returns the number of currently connected peers. +func (s *ChainService) ConnectedCount() int32 { + replyChan := make(chan int32) + + s.query <- getConnCountMsg{reply: replyChan} + + return <-replyChan +} + +// OutboundGroupCount returns the number of peers connected to the given +// outbound group key. +func (s *ChainService) OutboundGroupCount(key string) int { + replyChan := make(chan int) + s.query <- getOutboundGroup{key: key, reply: replyChan} + return <-replyChan +} + +// AddedNodeInfo returns an array of btcjson.GetAddedNodeInfoResult structures +// describing the persistent (added) nodes. +func (s *ChainService) AddedNodeInfo() []*serverPeer { + replyChan := make(chan []*serverPeer) + s.query <- getAddedNodesMsg{reply: replyChan} + return <-replyChan +} + +// Peers returns an array of all connected peers. +func (s *ChainService) Peers() []*serverPeer { + replyChan := make(chan []*serverPeer) + + s.query <- getPeersMsg{reply: replyChan} + + return <-replyChan +} + +// DisconnectNodeByAddr disconnects a peer by target address. Both outbound and +// inbound nodes will be searched for the target node. An error message will +// be returned if the peer was not found. +func (s *ChainService) DisconnectNodeByAddr(addr string) error { + replyChan := make(chan error) + + s.query <- disconnectNodeMsg{ + cmp: func(sp *serverPeer) bool { return sp.Addr() == addr }, + reply: replyChan, + } + + return <-replyChan +} + +// DisconnectNodeByID disconnects a peer by target node id. Both outbound and +// inbound nodes will be searched for the target node. An error message will be +// returned if the peer was not found. +func (s *ChainService) DisconnectNodeByID(id int32) error { + replyChan := make(chan error) + + s.query <- disconnectNodeMsg{ + cmp: func(sp *serverPeer) bool { return sp.ID() == id }, + reply: replyChan, + } + + return <-replyChan +} + +// RemoveNodeByAddr removes a peer from the list of persistent peers if +// present. An error will be returned if the peer was not found. +func (s *ChainService) RemoveNodeByAddr(addr string) error { + replyChan := make(chan error) + + s.query <- removeNodeMsg{ + cmp: func(sp *serverPeer) bool { return sp.Addr() == addr }, + reply: replyChan, + } + + return <-replyChan +} + +// RemoveNodeByID removes a peer by node ID from the list of persistent peers +// if present. An error will be returned if the peer was not found. +func (s *ChainService) RemoveNodeByID(id int32) error { + replyChan := make(chan error) + + s.query <- removeNodeMsg{ + cmp: func(sp *serverPeer) bool { return sp.ID() == id }, + reply: replyChan, + } + + return <-replyChan +} + +// ConnectNode adds `addr' as a new outbound peer. If permanent is true then the +// peer will be persistent and reconnect if the connection is lost. +// It is an error to call this with an already existing peer. +func (s *ChainService) ConnectNode(addr string, permanent bool) error { + replyChan := make(chan error) + + s.query <- connectNodeMsg{addr: addr, permanent: permanent, reply: replyChan} + + return <-replyChan +} + +// ForAllPeers runs a closure over all peers (outbound and persistent) to which +// the ChainService is connected. Nothing is returned because the peerState's +// ForAllPeers method doesn't return anything as the closure passed to it +// doesn't return anything. +func (s *ChainService) ForAllPeers(closure func(sp *serverPeer)) { + s.query <- forAllPeersMsg{ + closure: closure, + } +} diff --git a/spvsvc/spvchain/query.go b/spvsvc/spvchain/query.go new file mode 100644 index 0000000..4f8d68b --- /dev/null +++ b/spvsvc/spvchain/query.go @@ -0,0 +1,409 @@ +package spvchain + +import ( + "sync" + "time" + + "github.com/btcsuite/btcd/blockchain" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/btcsuite/btcutil/gcs" + "github.com/btcsuite/btcutil/gcs/builder" +) + +var ( + // QueryTimeout specifies how long to wait for a peer to answer a query. + QueryTimeout = time.Second * 3 + + // QueryNumRetries specifies how many times to retry sending a query to + // each peer before we've concluded we aren't going to get a valid + // response. This allows to make up for missed messages in some + // instances. + QueryNumRetries = 2 +) + +// Query options can be modified per-query, unlike global options. +// TODO: Make more query options that override global options. +type queryOptions struct { + // timeout lets the query know how long to wait for a peer to + // answer the query before moving onto the next peer. + timeout time.Duration + + // numRetries tells the query how many times to retry asking each peer + // the query. + numRetries uint8 + + // doneChan lets the query signal the caller when it's done, in case + // it's run in a goroutine. + doneChan chan<- struct{} +} + +// QueryOption is a functional option argument to any of the network query +// methods, such as GetBlockFromNetwork and GetCFilter (when that resorts to a +// network query). +type QueryOption func(*queryOptions) + +// defaultQueryOptions returns a queryOptions set to package-level defaults. +func defaultQueryOptions() *queryOptions { + return &queryOptions{ + timeout: QueryTimeout, + numRetries: uint8(QueryNumRetries), + } +} + +// Timeout is a query option that lets the query know how long to wait for +// each peer we ask the query to answer it before moving on. +func Timeout(timeout time.Duration) QueryOption { + return func(qo *queryOptions) { + qo.timeout = timeout + } +} + +// NumRetries is a query option that lets the query know the maximum number of +// times each peer should be queried. The default is one. +func NumRetries(numRetries uint8) QueryOption { + return func(qo *queryOptions) { + qo.numRetries = numRetries + } +} + +// DoneChan allows the caller to pass a channel that will get closed when the +// query is finished. +func DoneChan(doneChan chan<- struct{}) QueryOption { + return func(qo *queryOptions) { + qo.doneChan = doneChan + } +} + +type spMsg struct { + sp *serverPeer + msg wire.Message +} + +type spMsgSubscription struct { + msgChan chan<- spMsg + quitChan <-chan struct{} + wg *sync.WaitGroup +} + +// queryPeers is a helper function that sends a query to one or more peers and +// waits for an answer. The timeout for queries is set by the QueryTimeout +// package-level variable. +func (s *ChainService) queryPeers( + // queryMsg is the message to send to each peer selected by selectPeer. + queryMsg wire.Message, + // checkResponse is caled for every message within the timeout period. + // The quit channel lets the query know to terminate because the + // required response has been found. This is done by closing the + // channel. + checkResponse func(sp *serverPeer, resp wire.Message, + quit chan<- struct{}), + // options takes functional options for executing the query. + options ...QueryOption, +) { + qo := defaultQueryOptions() + for _, option := range options { + option(qo) + } + + // This is done in a single-threaded query because the peerState is held + // in a single thread. This is the only part of the query framework that + // requires access to peerState, so it's done once per query. + peers := s.Peers() + + // This will be shared state between the per-peer goroutines. + quit := make(chan struct{}) + allQuit := make(chan struct{}) + startQuery := make(chan struct{}) + var wg sync.WaitGroup + // Increase this number to be able to handle more queries at once as + // each channel gets results for all queries, otherwise messages can + // get mixed and there's a vicious cycle of retries causing a bigger + // message flood, more of which get missed. + msgChan := make(chan spMsg) + var subwg sync.WaitGroup + subscription := spMsgSubscription{ + msgChan: msgChan, + quitChan: allQuit, + wg: &subwg, + } + + // Start a goroutine for each peer that potentially queries that peer. + for _, sp := range peers { + wg.Add(1) + go func(sp *serverPeer) { + numRetries := qo.numRetries + defer wg.Done() + defer sp.unsubscribeRecvMsgs(subscription) + // Should we do this when the goroutine gets a message + // via startQuery rather than at the launch of the + // goroutine? + if !sp.Connected() { + return + } + timeout := make(<-chan time.Time) + for { + select { + case <-timeout: + // After timeout, we try to notify + // another of our peer goroutines to + // do a query until we get a signal to + // quit. + select { + case startQuery <- struct{}{}: + case <-quit: + return + case <-allQuit: + return + } + // At this point, we've sent startQuery. + // We return if we've run through this + // section of code numRetries times. + if numRetries--; numRetries == 0 { + return + } + case <-quit: + // After we're told to quit, we return. + return + case <-allQuit: + // After we're told to quit, we return. + return + case <-startQuery: + // We're the lucky peer whose turn it is + // to try to answer the current query. + // TODO: Fix this to support either + // querying *all* peers simultaneously + // to avoid timeout delays, or starting + // with the syncPeer when not querying + // *all* peers. + sp.subscribeRecvMsg(subscription) + // Don't want the peer hanging on send + // to the channel if we quit before + // reading the channel. + sentChan := make(chan struct{}, 1) + sp.QueueMessage(queryMsg, sentChan) + select { + case <-sentChan: + case <-quit: + return + case <-allQuit: + return + } + timeout = time.After(qo.timeout) + default: + } + } + }(sp) + } + startQuery <- struct{}{} + + // This goroutine will wait until all of the peer-query goroutines have + // terminated, and then initiate a query shutdown. + go func() { + wg.Wait() + // If we timed out on each goroutine and didn't quit or time out + // on the main goroutine, make sure our main goroutine knows to + // quit. + select { + case <-allQuit: + default: + close(allQuit) + } + // Close the done channel, if any + if qo.doneChan != nil { + close(qo.doneChan) + } + // Wait until all goroutines started by subscriptions have + // exited after we closed allQuit before letting the message + // channel get garbage collected. + subwg.Wait() + }() + + // Loop for any messages sent to us via our subscription channel and + // check them for whether they satisfy the query. Break the loop if it's + // time to quit. + timeout := time.After(time.Duration(len(peers)+1) * + qo.timeout * time.Duration(qo.numRetries)) +checkResponses: + for { + select { + case <-timeout: + // When we time out, close the allQuit channel + // if it hasn't already been closed. + select { + case <-allQuit: + default: + close(allQuit) + } + break checkResponses + case <-quit: + break checkResponses + case <-allQuit: + break checkResponses + case sm := <-msgChan: + // TODO: This will get stuck if checkResponse + // gets stuck. This is a caveat for callers that + // should be fixed before exposing this function + // for public use. + checkResponse(sm.sp, sm.msg, quit) + } + } +} + +// GetCFilter gets a cfilter from the database. Failing that, it requests the +// cfilter from the network and writes it to the database. +func (s *ChainService) GetCFilter(blockHash chainhash.Hash, + extended bool, options ...QueryOption) *gcs.Filter { + getFilter := s.GetBasicFilter + getHeader := s.GetBasicHeader + putFilter := s.putBasicFilter + if extended { + getFilter = s.GetExtFilter + getHeader = s.GetExtHeader + putFilter = s.putExtFilter + } + filter, err := getFilter(blockHash) + if err == nil && filter != nil { + return filter + } + block, _, err := s.GetBlockByHash(blockHash) + if err != nil || block.BlockHash() != blockHash { + return nil + } + curHeader, err := getHeader(blockHash) + if err != nil { + return nil + } + prevHeader, err := getHeader(block.PrevBlock) + if err != nil { + return nil + } + s.queryPeers( + // Send a wire.GetCFilterMsg + wire.NewMsgGetCFilter(&blockHash, extended), + // Check responses and if we get one that matches, + // end the query early. + func(sp *serverPeer, resp wire.Message, + quit chan<- struct{}) { + switch response := resp.(type) { + // We're only interested in "cfilter" messages. + case *wire.MsgCFilter: + if len(response.Data) < 4 { + // Filter data is too short. + // Ignore this message. + return + } + if blockHash != response.BlockHash { + // The response doesn't match our + // request. Ignore this message. + return + } + gotFilter, err := + gcs.FromNBytes(builder.DefaultP, + response.Data) + if err != nil { + // Malformed filter data. We + // can ignore this message. + return + } + if builder.MakeHeaderForFilter(gotFilter, + *prevHeader) != + *curHeader { + // Filter data doesn't match + // the headers we know about. + // Ignore this response. + return + } + // At this point, the filter matches + // what we know about it and we declare + // it sane. We can kill the query and + // pass the response back to the caller. + close(quit) + filter = gotFilter + default: + } + }, + options..., + ) + // If we've found a filter, write it to the database for next time. + if filter != nil { + putFilter(blockHash, filter) + log.Tracef("Wrote filter for block %s, extended: %t", + blockHash, extended) + } + return filter +} + +// GetBlockFromNetwork gets a block by requesting it from the network, one peer +// at a time, until one answers. +func (s *ChainService) GetBlockFromNetwork( + blockHash chainhash.Hash, options ...QueryOption) *btcutil.Block { + blockHeader, height, err := s.GetBlockByHash(blockHash) + if err != nil || blockHeader.BlockHash() != blockHash { + return nil + } + getData := wire.NewMsgGetData() + getData.AddInvVect(wire.NewInvVect(wire.InvTypeBlock, + &blockHash)) + // The block is only updated from the checkResponse function argument, + // which is always called single-threadedly. We don't check the block + // until after the query is finished, so we can just write to it + // naively. + var foundBlock *btcutil.Block + s.queryPeers( + // Send a wire.GetCFilterMsg + getData, + // Check responses and if we get one that matches, + // end the query early. + func(sp *serverPeer, resp wire.Message, + quit chan<- struct{}) { + switch response := resp.(type) { + // We're only interested in "block" messages. + case *wire.MsgBlock: + // If this isn't our block, ignore it. + if response.BlockHash() != + blockHash { + return + } + block := btcutil.NewBlock(response) + // Only set height if btcutil hasn't + // automagically put one in. + if block.Height() == + btcutil.BlockHeightUnknown { + block.SetHeight( + int32(height)) + } + // If this claims our block but doesn't + // pass the sanity check, the peer is + // trying to bamboozle us. Disconnect + // it. + if err := blockchain.CheckBlockSanity( + block, + // We don't need to check PoW + // because by the time we get + // here, it's been checked + // during header synchronization + s.chainParams.PowLimit, + s.timeSource, + ); err != nil { + log.Warnf("Invalid block for %s "+ + "received from %s -- "+ + "disconnecting peer", blockHash, + sp.Addr()) + sp.Disconnect() + return + } + // At this point, the block matches what we know + // about it and we declare it sane. We can kill + // the query and pass the response back to the + // caller. + close(quit) + foundBlock = block + default: + } + }, + options..., + ) + return foundBlock +} diff --git a/spvsvc/spvchain/spvchain.go b/spvsvc/spvchain/spvchain.go index 29fe348..b6f73ce 100644 --- a/spvsvc/spvchain/spvchain.go +++ b/spvsvc/spvchain/spvchain.go @@ -17,9 +17,6 @@ import ( "github.com/btcsuite/btcd/peer" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" - "github.com/btcsuite/btcutil/gcs" - "github.com/btcsuite/btcutil/gcs/builder" - "github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/walletdb" ) @@ -63,15 +60,6 @@ var ( // DisableDNSSeed disables getting initial addresses for Bitcoin nodes // from DNS. DisableDNSSeed = false - - // QueryTimeout specifies how long to wait for a peer to answer a query. - QueryTimeout = time.Second * 3 - - // QueryNumRetries specifies how many times to retry sending a query to - // each peer before we've concluded we aren't going to get a valid - // response. This allows to make up for missed messages in some - // instances. - QueryNumRetries = 2 ) // updatePeerHeightsMsg is a message sent from the blockmanager to the server @@ -538,7 +526,7 @@ type ChainService struct { started int32 shutdown int32 - namespace walletdb.Namespace + db walletdb.DB chainParams chaincfg.Params addrManager *addrmgr.AddrManager connManager *connmgr.ConnManager @@ -562,39 +550,6 @@ func (s *ChainService) BanPeer(sp *serverPeer) { s.banPeers <- sp } -// BestSnapshot returns the best block hash and height known to the database. -func (s *ChainService) BestSnapshot() (*waddrmgr.BlockStamp, error) { - var best *waddrmgr.BlockStamp - var err error - err = s.namespace.View(func(tx walletdb.Tx) error { - best, err = syncedTo(tx) - return err - }) - if err != nil { - return nil, err - } - return best, nil -} - -// LatestBlockLocator returns the block locator for the latest known block -// stored in the database. -func (s *ChainService) LatestBlockLocator() (blockchain.BlockLocator, error) { - var locator blockchain.BlockLocator - var err error - err = s.namespace.View(func(tx walletdb.Tx) error { - best, err := syncedTo(tx) - if err != nil { - return err - } - locator = blockLocatorFromHash(tx, best.Hash) - return nil - }) - if err != nil { - return nil, err - } - return locator, nil -} - // AddPeer adds a new peer that has already been connected to the server. func (s *ChainService) AddPeer(sp *serverPeer) { s.newPeers <- sp @@ -735,7 +690,8 @@ cleanup: // Config is a struct detailing the configuration of the chain service. type Config struct { DataDir string - Namespace walletdb.Namespace + Database walletdb.DB + Namespace []byte ChainParams chaincfg.Params ConnectPeers []string AddPeers []string @@ -756,14 +712,14 @@ func NewChainService(cfg Config) (*ChainService, error) { query: make(chan interface{}), quit: make(chan struct{}), peerHeightsUpdate: make(chan updatePeerHeightsMsg), - namespace: cfg.Namespace, + db: cfg.Database, timeSource: blockchain.NewMedianTime(), services: Services, userAgentName: UserAgentName, userAgentVersion: UserAgentVersion, } - err := createSPVNS(s.namespace, &s.chainParams) + err := s.createSPVNS() if err != nil { return nil, err } @@ -1168,115 +1124,6 @@ func (s *ChainService) peerDoneHandler(sp *serverPeer) { close(sp.quit) } -// ConnectedCount returns the number of currently connected peers. -func (s *ChainService) ConnectedCount() int32 { - replyChan := make(chan int32) - - s.query <- getConnCountMsg{reply: replyChan} - - return <-replyChan -} - -// OutboundGroupCount returns the number of peers connected to the given -// outbound group key. -func (s *ChainService) OutboundGroupCount(key string) int { - replyChan := make(chan int) - s.query <- getOutboundGroup{key: key, reply: replyChan} - return <-replyChan -} - -// AddedNodeInfo returns an array of btcjson.GetAddedNodeInfoResult structures -// describing the persistent (added) nodes. -func (s *ChainService) AddedNodeInfo() []*serverPeer { - replyChan := make(chan []*serverPeer) - s.query <- getAddedNodesMsg{reply: replyChan} - return <-replyChan -} - -// Peers returns an array of all connected peers. -func (s *ChainService) Peers() []*serverPeer { - replyChan := make(chan []*serverPeer) - - s.query <- getPeersMsg{reply: replyChan} - - return <-replyChan -} - -// DisconnectNodeByAddr disconnects a peer by target address. Both outbound and -// inbound nodes will be searched for the target node. An error message will -// be returned if the peer was not found. -func (s *ChainService) DisconnectNodeByAddr(addr string) error { - replyChan := make(chan error) - - s.query <- disconnectNodeMsg{ - cmp: func(sp *serverPeer) bool { return sp.Addr() == addr }, - reply: replyChan, - } - - return <-replyChan -} - -// DisconnectNodeByID disconnects a peer by target node id. Both outbound and -// inbound nodes will be searched for the target node. An error message will be -// returned if the peer was not found. -func (s *ChainService) DisconnectNodeByID(id int32) error { - replyChan := make(chan error) - - s.query <- disconnectNodeMsg{ - cmp: func(sp *serverPeer) bool { return sp.ID() == id }, - reply: replyChan, - } - - return <-replyChan -} - -// RemoveNodeByAddr removes a peer from the list of persistent peers if -// present. An error will be returned if the peer was not found. -func (s *ChainService) RemoveNodeByAddr(addr string) error { - replyChan := make(chan error) - - s.query <- removeNodeMsg{ - cmp: func(sp *serverPeer) bool { return sp.Addr() == addr }, - reply: replyChan, - } - - return <-replyChan -} - -// RemoveNodeByID removes a peer by node ID from the list of persistent peers -// if present. An error will be returned if the peer was not found. -func (s *ChainService) RemoveNodeByID(id int32) error { - replyChan := make(chan error) - - s.query <- removeNodeMsg{ - cmp: func(sp *serverPeer) bool { return sp.ID() == id }, - reply: replyChan, - } - - return <-replyChan -} - -// ConnectNode adds `addr' as a new outbound peer. If permanent is true then the -// peer will be persistent and reconnect if the connection is lost. -// It is an error to call this with an already existing peer. -func (s *ChainService) ConnectNode(addr string, permanent bool) error { - replyChan := make(chan error) - - s.query <- connectNodeMsg{addr: addr, permanent: permanent, reply: replyChan} - - return <-replyChan -} - -// ForAllPeers runs a closure over all peers (outbound and persistent) to which -// the ChainService is connected. Nothing is returned because the peerState's -// ForAllPeers method doesn't return anything as the closure passed to it -// doesn't return anything. -func (s *ChainService) ForAllPeers(closure func(sp *serverPeer)) { - s.query <- forAllPeersMsg{ - closure: closure, - } -} - // UpdatePeerHeights updates the heights of all peers who have have announced // the latest connected main chain block, or a recognized orphan. These height // updates allow us to dynamically refresh peer heights, ensuring sync peer @@ -1376,562 +1223,8 @@ func (s *ChainService) Stop() error { return nil } -// GetBlockByHeight gets block information from the ChainService database by -// its height. -func (s *ChainService) GetBlockByHeight(height uint32) (wire.BlockHeader, - uint32, error) { - var bh wire.BlockHeader - var h uint32 - var err error - err = s.namespace.View(func(dbTx walletdb.Tx) error { - bh, h, err = getBlockByHeight(dbTx, height) - return err - }) - return bh, h, err -} - -// GetBlockByHash gets block information from the ChainService database by its -// hash. -func (s *ChainService) GetBlockByHash(hash chainhash.Hash) (wire.BlockHeader, - uint32, error) { - var bh wire.BlockHeader - var h uint32 - var err error - err = s.namespace.View(func(dbTx walletdb.Tx) error { - bh, h, err = getBlockByHash(dbTx, hash) - return err - }) - return bh, h, err -} - -// LatestBlock gets the latest block's information from the ChainService -// database. -func (s *ChainService) LatestBlock() (wire.BlockHeader, uint32, error) { - var bh wire.BlockHeader - var h uint32 - var err error - err = s.namespace.View(func(dbTx walletdb.Tx) error { - bh, h, err = latestBlock(dbTx) - return err - }) - return bh, h, err -} - -// putBlock puts a verified block header and height in the ChainService -// database. -func (s *ChainService) putBlock(header wire.BlockHeader, height uint32) error { - return s.namespace.Update(func(dbTx walletdb.Tx) error { - return putBlock(dbTx, header, height) - }) -} - -// putBasicHeader puts a verified basic filter header in the ChainService -// database. -func (s *ChainService) putBasicHeader(blockHash chainhash.Hash, - filterTip chainhash.Hash) error { - return s.namespace.Update(func(dbTx walletdb.Tx) error { - return putBasicHeader(dbTx, blockHash, filterTip) - }) -} - -// putExtHeader puts a verified extended filter header in the ChainService -// database. -func (s *ChainService) putExtHeader(blockHash chainhash.Hash, - filterTip chainhash.Hash) error { - return s.namespace.Update(func(dbTx walletdb.Tx) error { - return putExtHeader(dbTx, blockHash, filterTip) - }) -} - -// GetBasicHeader gets a verified basic filter header from the ChainService -// database. -func (s *ChainService) GetBasicHeader(blockHash chainhash.Hash) (*chainhash.Hash, - error) { - var filterTip *chainhash.Hash - var err error - err = s.namespace.View(func(dbTx walletdb.Tx) error { - filterTip, err = getBasicHeader(dbTx, blockHash) - return err - }) - return filterTip, err -} - -// GetExtHeader gets a verified extended filter header from the ChainService -// database. -func (s *ChainService) GetExtHeader(blockHash chainhash.Hash) (*chainhash.Hash, - error) { - var filterTip *chainhash.Hash - var err error - err = s.namespace.View(func(dbTx walletdb.Tx) error { - filterTip, err = getExtHeader(dbTx, blockHash) - return err - }) - return filterTip, err -} - -// putBasicFilter puts a verified basic filter in the ChainService database. -func (s *ChainService) putBasicFilter(blockHash chainhash.Hash, - filter *gcs.Filter) error { - return s.namespace.Update(func(dbTx walletdb.Tx) error { - return putBasicFilter(dbTx, blockHash, filter) - }) -} - -// putExtFilter puts a verified extended filter in the ChainService database. -func (s *ChainService) putExtFilter(blockHash chainhash.Hash, - filter *gcs.Filter) error { - return s.namespace.Update(func(dbTx walletdb.Tx) error { - return putExtFilter(dbTx, blockHash, filter) - }) -} - -// GetBasicFilter gets a verified basic filter from the ChainService database. -func (s *ChainService) GetBasicFilter(blockHash chainhash.Hash) (*gcs.Filter, - error) { - var filter *gcs.Filter - var err error - err = s.namespace.View(func(dbTx walletdb.Tx) error { - filter, err = getBasicFilter(dbTx, blockHash) - return err - }) - return filter, err -} - -// GetExtFilter gets a verified extended filter from the ChainService database. -func (s *ChainService) GetExtFilter(blockHash chainhash.Hash) (*gcs.Filter, - error) { - var filter *gcs.Filter - var err error - err = s.namespace.View(func(dbTx walletdb.Tx) error { - filter, err = getExtFilter(dbTx, blockHash) - return err - }) - return filter, err -} - -// putMaxBlockHeight puts the max block height to the ChainService database. -func (s *ChainService) putMaxBlockHeight(maxBlockHeight uint32) error { - return s.namespace.Update(func(dbTx walletdb.Tx) error { - return putMaxBlockHeight(dbTx, maxBlockHeight) - }) -} - -func (s *ChainService) rollbackLastBlock() (*waddrmgr.BlockStamp, error) { - var bs *waddrmgr.BlockStamp - var err error - err = s.namespace.Update(func(dbTx walletdb.Tx) error { - bs, err = rollbackLastBlock(dbTx) - return err - }) - return bs, err -} - -func (s *ChainService) rollbackToHeight(height uint32) (*waddrmgr.BlockStamp, error) { - var bs *waddrmgr.BlockStamp - var err error - err = s.namespace.Update(func(dbTx walletdb.Tx) error { - bs, err = syncedTo(dbTx) - if err != nil { - return err - } - for uint32(bs.Height) > height { - bs, err = rollbackLastBlock(dbTx) - if err != nil { - return err - } - } - return nil - }) - return bs, err -} - // IsCurrent lets the caller know whether the chain service's block manager // thinks its view of the network is current. func (s *ChainService) IsCurrent() bool { return s.blockManager.IsCurrent() } - -// Query options can be modified per-query, unlike global options. -// TODO: Make more query options that override global options. -type queryOptions struct { - // timeout lets the query know how long to wait for a peer to - // answer the query before moving onto the next peer. - timeout time.Duration - - // numRetries tells the query how many times to retry asking each peer - // the query. - numRetries uint8 - - // doneChan lets the query signal the caller when it's done, in case - // it's run in a goroutine. - doneChan chan<- struct{} -} - -// QueryOption is a functional option argument to any of the network query -// methods, such as GetBlockFromNetwork and GetCFilter (when that resorts to a -// network query). -type QueryOption func(*queryOptions) - -// defaultQueryOptions returns a queryOptions set to package-level defaults. -func defaultQueryOptions() *queryOptions { - return &queryOptions{ - timeout: QueryTimeout, - numRetries: uint8(QueryNumRetries), - } -} - -// Timeout is a query option that lets the query know how long to wait for -// each peer we ask the query to answer it before moving on. -func Timeout(timeout time.Duration) QueryOption { - return func(qo *queryOptions) { - qo.timeout = timeout - } -} - -// NumRetries is a query option that lets the query know the maximum number of -// times each peer should be queried. The default is one. -func NumRetries(numRetries uint8) QueryOption { - return func(qo *queryOptions) { - qo.numRetries = numRetries - } -} - -// DoneChan allows the caller to pass a channel that will get closed when the -// query is finished. -func DoneChan(doneChan chan<- struct{}) QueryOption { - return func(qo *queryOptions) { - qo.doneChan = doneChan - } -} - -type spMsg struct { - sp *serverPeer - msg wire.Message -} - -type spMsgSubscription struct { - msgChan chan<- spMsg - quitChan <-chan struct{} - wg *sync.WaitGroup -} - -// queryPeers is a helper function that sends a query to one or more peers and -// waits for an answer. The timeout for queries is set by the QueryTimeout -// package-level variable. -func (s *ChainService) queryPeers( - // queryMsg is the message to send to each peer selected by selectPeer. - queryMsg wire.Message, - // checkResponse is caled for every message within the timeout period. - // The quit channel lets the query know to terminate because the - // required response has been found. This is done by closing the - // channel. - checkResponse func(sp *serverPeer, resp wire.Message, - quit chan<- struct{}), - // options takes functional options for executing the query. - options ...QueryOption, -) { - qo := defaultQueryOptions() - for _, option := range options { - option(qo) - } - - // This is done in a single-threaded query because the peerState is held - // in a single thread. This is the only part of the query framework that - // requires access to peerState, so it's done once per query. - peers := s.Peers() - - // This will be shared state between the per-peer goroutines. - quit := make(chan struct{}) - allQuit := make(chan struct{}) - startQuery := make(chan struct{}) - var wg sync.WaitGroup - // Increase this number to be able to handle more queries at once as - // each channel gets results for all queries, otherwise messages can - // get mixed and there's a vicious cycle of retries causing a bigger - // message flood, more of which get missed. - msgChan := make(chan spMsg) - var subwg sync.WaitGroup - subscription := spMsgSubscription{ - msgChan: msgChan, - quitChan: allQuit, - wg: &subwg, - } - - // Start a goroutine for each peer that potentially queries that peer. - for _, sp := range peers { - wg.Add(1) - go func(sp *serverPeer) { - numRetries := qo.numRetries - defer wg.Done() - defer sp.unsubscribeRecvMsgs(subscription) - // Should we do this when the goroutine gets a message - // via startQuery rather than at the launch of the - // goroutine? - if !sp.Connected() { - return - } - timeout := make(<-chan time.Time) - for { - select { - case <-timeout: - // After timeout, we try to notify - // another of our peer goroutines to - // do a query until we get a signal to - // quit. - select { - case startQuery <- struct{}{}: - case <-quit: - return - case <-allQuit: - return - } - // At this point, we've sent startQuery. - // We return if we've run through this - // section of code numRetries times. - if numRetries--; numRetries == 0 { - return - } - case <-quit: - // After we're told to quit, we return. - return - case <-allQuit: - // After we're told to quit, we return. - return - case <-startQuery: - // We're the lucky peer whose turn it is - // to try to answer the current query. - // TODO: Fix this to support either - // querying *all* peers simultaneously - // to avoid timeout delays, or starting - // with the syncPeer when not querying - // *all* peers. - sp.subscribeRecvMsg(subscription) - // Don't want the peer hanging on send - // to the channel if we quit before - // reading the channel. - sentChan := make(chan struct{}, 1) - sp.QueueMessage(queryMsg, sentChan) - select { - case <-sentChan: - case <-quit: - return - case <-allQuit: - return - } - timeout = time.After(qo.timeout) - default: - } - } - }(sp) - } - startQuery <- struct{}{} - - // This goroutine will wait until all of the peer-query goroutines have - // terminated, and then initiate a query shutdown. - go func() { - wg.Wait() - // If we timed out on each goroutine and didn't quit or time out - // on the main goroutine, make sure our main goroutine knows to - // quit. - select { - case <-allQuit: - default: - close(allQuit) - } - // Close the done channel, if any - if qo.doneChan != nil { - close(qo.doneChan) - } - // Wait until all goroutines started by subscriptions have - // exited after we closed allQuit before letting the message - // channel get garbage collected. - subwg.Wait() - }() - - // Loop for any messages sent to us via our subscription channel and - // check them for whether they satisfy the query. Break the loop if it's - // time to quit. - timeout := time.After(time.Duration(len(peers)+1) * - qo.timeout * time.Duration(qo.numRetries)) -checkResponses: - for { - select { - case <-timeout: - // When we time out, close the allQuit channel - // if it hasn't already been closed. - select { - case <-allQuit: - default: - close(allQuit) - } - break checkResponses - case <-quit: - break checkResponses - case <-allQuit: - break checkResponses - case sm := <-msgChan: - // TODO: This will get stuck if checkResponse - // gets stuck. This is a caveat for callers that - // should be fixed before exposing this function - // for public use. - checkResponse(sm.sp, sm.msg, quit) - } - } -} - -// GetCFilter gets a cfilter from the database. Failing that, it requests the -// cfilter from the network and writes it to the database. -func (s *ChainService) GetCFilter(blockHash chainhash.Hash, - extended bool, options ...QueryOption) *gcs.Filter { - getFilter := s.GetBasicFilter - getHeader := s.GetBasicHeader - putFilter := s.putBasicFilter - if extended { - getFilter = s.GetExtFilter - getHeader = s.GetExtHeader - putFilter = s.putExtFilter - } - filter, err := getFilter(blockHash) - if err == nil && filter != nil { - return filter - } - block, _, err := s.GetBlockByHash(blockHash) - if err != nil || block.BlockHash() != blockHash { - return nil - } - curHeader, err := getHeader(blockHash) - if err != nil { - return nil - } - prevHeader, err := getHeader(block.PrevBlock) - if err != nil { - return nil - } - s.queryPeers( - // Send a wire.GetCFilterMsg - wire.NewMsgGetCFilter(&blockHash, extended), - // Check responses and if we get one that matches, - // end the query early. - func(sp *serverPeer, resp wire.Message, - quit chan<- struct{}) { - switch response := resp.(type) { - // We're only interested in "cfilter" messages. - case *wire.MsgCFilter: - if len(response.Data) < 4 { - // Filter data is too short. - // Ignore this message. - return - } - if blockHash != response.BlockHash { - // The response doesn't match our - // request. Ignore this message. - return - } - gotFilter, err := - gcs.FromNBytes(builder.DefaultP, - response.Data) - if err != nil { - // Malformed filter data. We - // can ignore this message. - return - } - if builder.MakeHeaderForFilter(gotFilter, - *prevHeader) != - *curHeader { - // Filter data doesn't match - // the headers we know about. - // Ignore this response. - return - } - // At this point, the filter matches - // what we know about it and we declare - // it sane. We can kill the query and - // pass the response back to the caller. - close(quit) - filter = gotFilter - default: - } - }, - options..., - ) - // If we've found a filter, write it to the database for next time. - if filter != nil { - putFilter(blockHash, filter) - log.Tracef("Wrote filter for block %s, extended: %t", - blockHash, extended) - } - return filter -} - -// GetBlockFromNetwork gets a block by requesting it from the network, one peer -// at a time, until one answers. -func (s *ChainService) GetBlockFromNetwork( - blockHash chainhash.Hash, options ...QueryOption) *btcutil.Block { - blockHeader, height, err := s.GetBlockByHash(blockHash) - if err != nil || blockHeader.BlockHash() != blockHash { - return nil - } - getData := wire.NewMsgGetData() - getData.AddInvVect(wire.NewInvVect(wire.InvTypeBlock, - &blockHash)) - // The block is only updated from the checkResponse function argument, - // which is always called single-threadedly. We don't check the block - // until after the query is finished, so we can just write to it - // naively. - var foundBlock *btcutil.Block - s.queryPeers( - // Send a wire.GetCFilterMsg - getData, - // Check responses and if we get one that matches, - // end the query early. - func(sp *serverPeer, resp wire.Message, - quit chan<- struct{}) { - switch response := resp.(type) { - // We're only interested in "block" messages. - case *wire.MsgBlock: - // If this isn't our block, ignore it. - if response.BlockHash() != - blockHash { - return - } - block := btcutil.NewBlock(response) - // Only set height if btcutil hasn't - // automagically put one in. - if block.Height() == - btcutil.BlockHeightUnknown { - block.SetHeight( - int32(height)) - } - // If this claims our block but doesn't - // pass the sanity check, the peer is - // trying to bamboozle us. Disconnect - // it. - if err := blockchain.CheckBlockSanity( - block, - // We don't need to check PoW - // because by the time we get - // here, it's been checked - // during header synchronization - s.chainParams.PowLimit, - s.timeSource, - ); err != nil { - log.Warnf("Invalid block for %s "+ - "received from %s -- "+ - "disconnecting peer", blockHash, - sp.Addr()) - sp.Disconnect() - return - } - // At this point, the block matches what we know - // about it and we declare it sane. We can kill - // the query and pass the response back to the - // caller. - close(quit) - foundBlock = block - default: - } - }, - options..., - ) - return foundBlock -} diff --git a/spvsvc/spvchain/sync_test.go b/spvsvc/spvchain/sync_test.go index dfe392d..a3a549c 100644 --- a/spvsvc/spvchain/sync_test.go +++ b/spvsvc/spvchain/sync_test.go @@ -25,7 +25,7 @@ import ( ) var ( - logLevel = btclog.TraceLvl + logLevel = btclog.Off syncTimeout = 30 * time.Second syncUpdate = time.Second // Don't set this too high for your platform, or the tests will miss @@ -145,13 +145,12 @@ func TestSetup(t *testing.T) { if err != nil { t.Fatalf("Error opening DB: %s\n", err) } - ns, err := db.Namespace([]byte("weks")) if err != nil { t.Fatalf("Error geting namespace: %s\n", err) } config := spvchain.Config{ DataDir: tempDir, - Namespace: ns, + Database: db, ChainParams: modParams, AddPeers: []string{ h3.P2PAddress(), @@ -327,7 +326,7 @@ func waitForSync(t *testing.T, svc *spvchain.ChainService, // database to see if we've missed anything or messed anything // up. for i := int32(0); i <= haveBest.Height; i++ { - head, _, err := svc.GetBlockByHeight(uint32(i)) + head, err := svc.GetBlockByHeight(uint32(i)) if err != nil { return fmt.Errorf("Couldn't read block by "+ "height: %s", err) @@ -413,20 +412,12 @@ func testRandomBlocks(t *testing.T, svc *spvchain.ChainService, }() defer wg.Done() // Get block header from database. - blockHeader, blockHeight, err := svc.GetBlockByHeight( - height) + blockHeader, err := svc.GetBlockByHeight(height) if err != nil { errChan <- fmt.Errorf("Couldn't get block "+ "header by height %d: %s", height, err) return } - if blockHeight != height { - errChan <- fmt.Errorf("Block height retrieved "+ - "from DB doesn't match expected "+ - "height. Want: %d, have: %d", height, - blockHeight) - return - } blockHash := blockHeader.BlockHash() // Get block via RPC. wantBlock, err := correctSyncNode.Node.GetBlock( @@ -455,11 +446,11 @@ func testRandomBlocks(t *testing.T, svc *spvchain.ChainService, return } // Check that block height matches what we have. - if int32(blockHeight) != haveBlock.Height() { + if int32(height) != haveBlock.Height() { errChan <- fmt.Errorf("Block height from "+ "network doesn't match expected "+ "height. Want: %s, network: %s", - blockHeight, haveBlock.Height()) + height, haveBlock.Height()) return } // Get basic cfilter from network.