multi: use hidden varint for cftypes count; make filter type enum, not uint8

This commit is contained in:
Alex 2017-10-31 00:24:57 -06:00 committed by Olaoluwa Osuntokun
parent 8d2ce855eb
commit c7e7acc7fd
12 changed files with 85 additions and 55 deletions

View file

@ -15,6 +15,7 @@ import (
"github.com/btcsuite/btcutil/gcs"
"github.com/btcsuite/btcutil/gcs/builder"
"github.com/btcsuite/fastsha256"
"github.com/roasbeef/btcd/wire"
)
const (
@ -149,7 +150,7 @@ func (idx *CfIndex) Create(dbTx database.Tx) error {
firstHeader := make([]byte, chainhash.HashSize)
err = dbStoreFilterHeader(
dbTx,
cfHeaderKeys[0],
cfHeaderKeys[wire.GCSFilterRegular],
&idx.chainParams.GenesisBlock.Header.PrevBlock,
firstHeader,
)
@ -159,7 +160,7 @@ func (idx *CfIndex) Create(dbTx database.Tx) error {
return dbStoreFilterHeader(
dbTx,
cfHeaderKeys[1],
cfHeaderKeys[wire.GCSFilterExtended],
&idx.chainParams.GenesisBlock.Header.PrevBlock,
firstHeader,
)
@ -168,8 +169,8 @@ func (idx *CfIndex) Create(dbTx database.Tx) error {
// storeFilter stores a given filter, and performs the steps needed to
// generate the filter's header.
func storeFilter(dbTx database.Tx, block *btcutil.Block, f *gcs.Filter,
filterType uint8) error {
if filterType > maxFilterType {
filterType wire.FilterType) error {
if uint8(filterType) > maxFilterType {
return errors.New("unsupported filter type")
}
@ -215,7 +216,8 @@ func (idx *CfIndex) ConnectBlock(dbTx database.Tx, block *btcutil.Block,
return err
}
if err := storeFilter(dbTx, block, f, 0); err != nil {
if err := storeFilter(dbTx, block, f,
wire.GCSFilterRegular); err != nil {
return err
}
@ -224,7 +226,7 @@ func (idx *CfIndex) ConnectBlock(dbTx database.Tx, block *btcutil.Block,
return err
}
return storeFilter(dbTx, block, f, 1)
return storeFilter(dbTx, block, f, wire.GCSFilterExtended)
}
// DisconnectBlock is invoked by the index manager when a block has been
@ -252,10 +254,11 @@ func (idx *CfIndex) DisconnectBlock(dbTx database.Tx, block *btcutil.Block,
// FilterByBlockHash returns the serialized contents of a block's basic or
// extended committed filter.
func (idx *CfIndex) FilterByBlockHash(h *chainhash.Hash, filterType uint8) ([]byte, error) {
func (idx *CfIndex) FilterByBlockHash(h *chainhash.Hash,
filterType wire.FilterType) ([]byte, error) {
var f []byte
err := idx.db.View(func(dbTx database.Tx) error {
if filterType > maxFilterType {
if uint8(filterType) > maxFilterType {
return errors.New("unsupported filter type")
}
@ -268,15 +271,17 @@ func (idx *CfIndex) FilterByBlockHash(h *chainhash.Hash, filterType uint8) ([]by
// FilterHeaderByBlockHash returns the serialized contents of a block's basic
// or extended committed filter header.
func (idx *CfIndex) FilterHeaderByBlockHash(h *chainhash.Hash, filterType uint8) ([]byte, error) {
func (idx *CfIndex) FilterHeaderByBlockHash(h *chainhash.Hash,
filterType wire.FilterType) ([]byte, error) {
var fh []byte
err := idx.db.View(func(dbTx database.Tx) error {
if filterType > 1 {
if uint8(filterType) > maxFilterType {
return errors.New("unsupported filter type")
}
var err error
fh, err = dbFetchFilterHeader(dbTx, cfHeaderKeys[filterType], h)
fh, err = dbFetchFilterHeader(dbTx,
cfHeaderKeys[filterType], h)
return err
})
return fh, err

View file

@ -10,6 +10,8 @@ package btcjson
import (
"encoding/json"
"fmt"
"github.com/roasbeef/btcd/wire"
)
// AddNodeSubCmd defines the type used in the addnode JSON-RPC command for the
@ -282,12 +284,12 @@ func NewGetBlockTemplateCmd(request *TemplateRequest) *GetBlockTemplateCmd {
// GetCFilterCmd defines the getcfilter JSON-RPC command.
type GetCFilterCmd struct {
Hash string
FilterType uint8
FilterType wire.FilterType
}
// NewGetCFilterCmd returns a new instance which can be used to issue a
// getcfilter JSON-RPC command.
func NewGetCFilterCmd(hash string, filterType uint8) *GetCFilterCmd {
func NewGetCFilterCmd(hash string, filterType wire.FilterType) *GetCFilterCmd {
return &GetCFilterCmd{
Hash: hash,
FilterType: filterType,
@ -297,12 +299,13 @@ func NewGetCFilterCmd(hash string, filterType uint8) *GetCFilterCmd {
// GetCFilterHeaderCmd defines the getcfilterheader JSON-RPC command.
type GetCFilterHeaderCmd struct {
Hash string
FilterType uint8
FilterType wire.FilterType
}
// NewGetCFilterHeaderCmd returns a new instance which can be used to issue a
// getcfilterheader JSON-RPC command.
func NewGetCFilterHeaderCmd(hash string, filterType uint8) *GetCFilterHeaderCmd {
func NewGetCFilterHeaderCmd(hash string,
filterType wire.FilterType) *GetCFilterHeaderCmd {
return &GetCFilterHeaderCmd{
Hash: hash,
FilterType: filterType,

View file

@ -12,6 +12,7 @@ import (
"testing"
"github.com/btcsuite/btcd/btcjson"
"github.com/btcsuite/btcd/wire"
)
// TestChainSvrCmds tests all of the chain server commands marshal and unmarshal
@ -320,27 +321,33 @@ func TestChainSvrCmds(t *testing.T) {
{
name: "getcfilter",
newCmd: func() (interface{}, error) {
return btcjson.NewCmd("getcfilter", "123", 0)
return btcjson.NewCmd("getcfilter", "123",
wire.GCSFilterExtended)
},
staticCmd: func() interface{} {
return btcjson.NewGetCFilterCmd("123", 0)
return btcjson.NewGetCFilterCmd("123",
wire.GCSFilterExtended)
},
marshalled: `{"jsonrpc":"1.0","method":"getcfilter","params":["123",0],"id":1}`,
marshalled: `{"jsonrpc":"1.0","method":"getcfilter","params":["123",1],"id":1}`,
unmarshalled: &btcjson.GetCFilterCmd{
Hash: "123",
FilterType: wire.GCSFilterExtended,
},
},
{
name: "getcfilterheader",
newCmd: func() (interface{}, error) {
return btcjson.NewCmd("getcfilterheader", "123", 0)
return btcjson.NewCmd("getcfilterheader", "123",
wire.GCSFilterExtended)
},
staticCmd: func() interface{} {
return btcjson.NewGetCFilterHeaderCmd("123", 0)
return btcjson.NewGetCFilterHeaderCmd("123",
wire.GCSFilterExtended)
},
marshalled: `{"jsonrpc":"1.0","method":"getcfilterheader","params":["123",0],"id":1}`,
marshalled: `{"jsonrpc":"1.0","method":"getcfilterheader","params":["123",1],"id":1}`,
unmarshalled: &btcjson.GetCFilterHeaderCmd{
Hash: "123",
FilterType: wire.GCSFilterExtended,
},
},
{

View file

@ -542,7 +542,8 @@ func TestPeerListeners(t *testing.T) {
},
{
"OnGetCFilter",
wire.NewMsgGetCFilter(&chainhash.Hash{}, 0),
wire.NewMsgGetCFilter(&chainhash.Hash{},
wire.GCSFilterRegular),
},
{
"OnGetCFHeaders",
@ -554,8 +555,8 @@ func TestPeerListeners(t *testing.T) {
},
{
"OnCFilter",
wire.NewMsgCFilter(&chainhash.Hash{}, 1,
[]byte("payload")),
wire.NewMsgCFilter(&chainhash.Hash{},
wire.GCSFilterRegular, []byte("payload")),
},
{
"OnCFHeaders",
@ -563,7 +564,8 @@ func TestPeerListeners(t *testing.T) {
},
{
"OnCFTypes",
wire.NewMsgCFTypes([]uint8{0, 1}),
wire.NewMsgCFTypes([]wire.FilterType{
wire.GCSFilterRegular, wire.GCSFilterExtended}),
},
{
"OnFeeFilter",

View file

@ -821,7 +821,7 @@ func (r FutureGetCFilterResult) Receive() (*wire.MsgCFilter, error) {
//
// See GetCFilter for the blocking version and more details.
func (c *Client) GetCFilterAsync(blockHash *chainhash.Hash,
filterType uint8) FutureGetCFilterResult {
filterType wire.FilterType) FutureGetCFilterResult {
hash := ""
if blockHash != nil {
hash = blockHash.String()
@ -833,7 +833,7 @@ func (c *Client) GetCFilterAsync(blockHash *chainhash.Hash,
// GetCFilter returns a raw filter from the server given its block hash.
func (c *Client) GetCFilter(blockHash *chainhash.Hash,
filterType uint8) (*wire.MsgCFilter, error) {
filterType wire.FilterType) (*wire.MsgCFilter, error) {
return c.GetCFilterAsync(blockHash, filterType).Receive()
}
@ -878,7 +878,7 @@ func (r FutureGetCFilterHeaderResult) Receive() (*wire.MsgCFHeaders, error) {
//
// See GetCFilterHeader for the blocking version and more details.
func (c *Client) GetCFilterHeaderAsync(blockHash *chainhash.Hash,
filterType uint8) FutureGetCFilterHeaderResult {
filterType wire.FilterType) FutureGetCFilterHeaderResult {
hash := ""
if blockHash != nil {
hash = blockHash.String()
@ -891,6 +891,6 @@ func (c *Client) GetCFilterHeaderAsync(blockHash *chainhash.Hash,
// GetCFilterHeader returns a raw filter header from the server given its block
// hash.
func (c *Client) GetCFilterHeader(blockHash *chainhash.Hash,
filterType uint8) (*wire.MsgCFHeaders, error) {
filterType wire.FilterType) (*wire.MsgCFHeaders, error) {
return c.GetCFilterHeaderAsync(blockHash, filterType).Receive()
}

View file

@ -881,7 +881,8 @@ func (sp *serverPeer) OnGetCFTypes(_ *peer.Peer, msg *wire.MsgGetCFTypes) {
return
}
cfTypesMsg := wire.NewMsgCFTypes([]uint8{0, 1})
cfTypesMsg := wire.NewMsgCFTypes([]wire.FilterType{
wire.GCSFilterRegular, wire.GCSFilterExtended})
sp.QueueMessage(cfTypesMsg, nil)
}

View file

@ -69,12 +69,13 @@ func TestMessage(t *testing.T) {
bh := NewBlockHeader(1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0)
msgMerkleBlock := NewMsgMerkleBlock(bh)
msgReject := NewMsgReject("block", RejectDuplicate, "duplicate block")
msgGetCFilter := NewMsgGetCFilter(&chainhash.Hash{}, 0)
msgGetCFilter := NewMsgGetCFilter(&chainhash.Hash{}, GCSFilterExtended)
msgGetCFHeaders := NewMsgGetCFHeaders()
msgGetCFTypes := NewMsgGetCFTypes()
msgCFilter := NewMsgCFilter(&chainhash.Hash{}, 1, []byte("payload"))
msgCFilter := NewMsgCFilter(&chainhash.Hash{}, GCSFilterExtended,
[]byte("payload"))
msgCFHeaders := NewMsgCFHeaders()
msgCFTypes := NewMsgCFTypes([]uint8{2})
msgCFTypes := NewMsgCFTypes([]FilterType{GCSFilterExtended})
tests := []struct {
in Message // Value to encode

View file

@ -28,7 +28,7 @@ const (
// MsgGetCFHeaders for details on requesting the headers.
type MsgCFHeaders struct {
StopHash chainhash.Hash
FilterType uint8
FilterType FilterType
HeaderHashes []*chainhash.Hash
}

View file

@ -22,7 +22,7 @@ const (
// getcfilter (MsgGetCFilter) message.
type MsgCFilter struct {
BlockHash chainhash.Hash
FilterType uint8
FilterType FilterType
Data []byte
}
@ -99,7 +99,7 @@ func (msg *MsgCFilter) MaxPayloadLength(pver uint32) uint32 {
// NewMsgCFilter returns a new bitcoin cfilter message that conforms to the
// Message interface. See MsgCFilter for details.
func NewMsgCFilter(blockHash *chainhash.Hash, filterType uint8,
func NewMsgCFilter(blockHash *chainhash.Hash, filterType FilterType,
data []byte) *MsgCFilter {
return &MsgCFilter{
BlockHash: *blockHash,

View file

@ -6,28 +6,40 @@ package wire
import "io"
// FilterType is used to represent a filter type.
type FilterType uint8
const (
// GCSFilterRegular is the regular filter type.
GCSFilterRegular FilterType = iota
// GCSFilterExtended is the extended filter type.
GCSFilterExtended
)
// MsgCFTypes is the cftypes message.
type MsgCFTypes struct {
NumFilters uint8
SupportedFilters []uint8
SupportedFilters []FilterType
}
// BtcDecode decodes r using the bitcoin protocol encoding into the receiver.
// This is part of the Message interface implementation.
func (msg *MsgCFTypes) BtcDecode(r io.Reader, pver uint32, _ MessageEncoding) error {
// Read the number of filter types supported
err := readElement(r, &msg.NumFilters)
// Read the number of filter types supported.
count, err := ReadVarInt(r, pver)
if err != nil {
return err
}
// Read each filter type.
msg.SupportedFilters = make([]uint8, msg.NumFilters)
for i := uint8(0); i < msg.NumFilters; i++ {
err = readElement(r, &msg.SupportedFilters[i])
msg.SupportedFilters = make([]FilterType, count)
for i := uint64(0); i < count; i++ {
var filterType uint8
err = readElement(r, &filterType)
if err != nil {
return err
}
msg.SupportedFilters[i] = FilterType(filterType)
}
return nil
@ -36,9 +48,8 @@ func (msg *MsgCFTypes) BtcDecode(r io.Reader, pver uint32, _ MessageEncoding) er
// BtcEncode encodes the receiver to w using the bitcoin protocol encoding.
// This is part of the Message interface implementation.
func (msg *MsgCFTypes) BtcEncode(w io.Writer, pver uint32, _ MessageEncoding) error {
// Write length of supported filters slice; don't trust that the caller
// has set it correctly.
err := writeElement(w, uint8(len(msg.SupportedFilters)))
// Write length of supported filters slice. We assume it's deduplicated.
err := WriteVarInt(w, pver, uint64(len(msg.SupportedFilters)))
if err != nil {
return err
}
@ -78,15 +89,14 @@ func (msg *MsgCFTypes) Command() string {
// MaxPayloadLength returns the maximum length the payload can be for the
// receiver. This is part of the Message interface implementation.
func (msg *MsgCFTypes) MaxPayloadLength(pver uint32) uint32 {
// 1 byte for NumFilters, and 1 byte for up to 255 filter types.
return 256
// 2 bytes for filter count, and 1 byte for up to 256 filter types.
return 258
}
// NewMsgCFTypes returns a new bitcoin cftypes message that conforms to the
// Message interface. See MsgCFTypes for details.
func NewMsgCFTypes(filterTypes []uint8) *MsgCFTypes {
func NewMsgCFTypes(filterTypes []FilterType) *MsgCFTypes {
return &MsgCFTypes{
NumFilters: uint8(len(filterTypes)),
SupportedFilters: filterTypes,
}
}

View file

@ -17,7 +17,7 @@ import (
type MsgGetCFHeaders struct {
BlockLocatorHashes []*chainhash.Hash
HashStop chainhash.Hash
FilterType uint8
FilterType FilterType
}
// AddBlockLocatorHash adds a new block locator hash to the message.

View file

@ -14,7 +14,7 @@ import (
// getcfilter message. It is used to request a committed filter for a block.
type MsgGetCFilter struct {
BlockHash chainhash.Hash
FilterType uint8
FilterType FilterType
}
// BtcDecode decodes r using the bitcoin protocol encoding into the receiver.
@ -53,7 +53,8 @@ func (msg *MsgGetCFilter) MaxPayloadLength(pver uint32) uint32 {
// NewMsgGetCFilter returns a new bitcoin getcfilter message that conforms to
// the Message interface using the passed parameters and defaults for the
// remaining fields.
func NewMsgGetCFilter(blockHash *chainhash.Hash, filterType uint8) *MsgGetCFilter {
func NewMsgGetCFilter(blockHash *chainhash.Hash,
filterType FilterType) *MsgGetCFilter {
return &MsgGetCFilter{
BlockHash: *blockHash,
FilterType: filterType,