multi: use hidden varint for cftypes count; make filter type enum, not uint8

This commit is contained in:
Alex 2017-10-31 00:24:57 -06:00 committed by Olaoluwa Osuntokun
parent 8d2ce855eb
commit c7e7acc7fd
12 changed files with 85 additions and 55 deletions

View file

@ -15,6 +15,7 @@ import (
"github.com/btcsuite/btcutil/gcs" "github.com/btcsuite/btcutil/gcs"
"github.com/btcsuite/btcutil/gcs/builder" "github.com/btcsuite/btcutil/gcs/builder"
"github.com/btcsuite/fastsha256" "github.com/btcsuite/fastsha256"
"github.com/roasbeef/btcd/wire"
) )
const ( const (
@ -149,7 +150,7 @@ func (idx *CfIndex) Create(dbTx database.Tx) error {
firstHeader := make([]byte, chainhash.HashSize) firstHeader := make([]byte, chainhash.HashSize)
err = dbStoreFilterHeader( err = dbStoreFilterHeader(
dbTx, dbTx,
cfHeaderKeys[0], cfHeaderKeys[wire.GCSFilterRegular],
&idx.chainParams.GenesisBlock.Header.PrevBlock, &idx.chainParams.GenesisBlock.Header.PrevBlock,
firstHeader, firstHeader,
) )
@ -159,7 +160,7 @@ func (idx *CfIndex) Create(dbTx database.Tx) error {
return dbStoreFilterHeader( return dbStoreFilterHeader(
dbTx, dbTx,
cfHeaderKeys[1], cfHeaderKeys[wire.GCSFilterExtended],
&idx.chainParams.GenesisBlock.Header.PrevBlock, &idx.chainParams.GenesisBlock.Header.PrevBlock,
firstHeader, firstHeader,
) )
@ -168,8 +169,8 @@ func (idx *CfIndex) Create(dbTx database.Tx) error {
// storeFilter stores a given filter, and performs the steps needed to // storeFilter stores a given filter, and performs the steps needed to
// generate the filter's header. // generate the filter's header.
func storeFilter(dbTx database.Tx, block *btcutil.Block, f *gcs.Filter, func storeFilter(dbTx database.Tx, block *btcutil.Block, f *gcs.Filter,
filterType uint8) error { filterType wire.FilterType) error {
if filterType > maxFilterType { if uint8(filterType) > maxFilterType {
return errors.New("unsupported filter type") return errors.New("unsupported filter type")
} }
@ -215,7 +216,8 @@ func (idx *CfIndex) ConnectBlock(dbTx database.Tx, block *btcutil.Block,
return err return err
} }
if err := storeFilter(dbTx, block, f, 0); err != nil { if err := storeFilter(dbTx, block, f,
wire.GCSFilterRegular); err != nil {
return err return err
} }
@ -224,7 +226,7 @@ func (idx *CfIndex) ConnectBlock(dbTx database.Tx, block *btcutil.Block,
return err return err
} }
return storeFilter(dbTx, block, f, 1) return storeFilter(dbTx, block, f, wire.GCSFilterExtended)
} }
// DisconnectBlock is invoked by the index manager when a block has been // DisconnectBlock is invoked by the index manager when a block has been
@ -252,10 +254,11 @@ func (idx *CfIndex) DisconnectBlock(dbTx database.Tx, block *btcutil.Block,
// FilterByBlockHash returns the serialized contents of a block's basic or // FilterByBlockHash returns the serialized contents of a block's basic or
// extended committed filter. // extended committed filter.
func (idx *CfIndex) FilterByBlockHash(h *chainhash.Hash, filterType uint8) ([]byte, error) { func (idx *CfIndex) FilterByBlockHash(h *chainhash.Hash,
filterType wire.FilterType) ([]byte, error) {
var f []byte var f []byte
err := idx.db.View(func(dbTx database.Tx) error { err := idx.db.View(func(dbTx database.Tx) error {
if filterType > maxFilterType { if uint8(filterType) > maxFilterType {
return errors.New("unsupported filter type") return errors.New("unsupported filter type")
} }
@ -268,15 +271,17 @@ func (idx *CfIndex) FilterByBlockHash(h *chainhash.Hash, filterType uint8) ([]by
// FilterHeaderByBlockHash returns the serialized contents of a block's basic // FilterHeaderByBlockHash returns the serialized contents of a block's basic
// or extended committed filter header. // or extended committed filter header.
func (idx *CfIndex) FilterHeaderByBlockHash(h *chainhash.Hash, filterType uint8) ([]byte, error) { func (idx *CfIndex) FilterHeaderByBlockHash(h *chainhash.Hash,
filterType wire.FilterType) ([]byte, error) {
var fh []byte var fh []byte
err := idx.db.View(func(dbTx database.Tx) error { err := idx.db.View(func(dbTx database.Tx) error {
if filterType > 1 { if uint8(filterType) > maxFilterType {
return errors.New("unsupported filter type") return errors.New("unsupported filter type")
} }
var err error var err error
fh, err = dbFetchFilterHeader(dbTx, cfHeaderKeys[filterType], h) fh, err = dbFetchFilterHeader(dbTx,
cfHeaderKeys[filterType], h)
return err return err
}) })
return fh, err return fh, err

View file

@ -10,6 +10,8 @@ package btcjson
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/roasbeef/btcd/wire"
) )
// AddNodeSubCmd defines the type used in the addnode JSON-RPC command for the // AddNodeSubCmd defines the type used in the addnode JSON-RPC command for the
@ -282,12 +284,12 @@ func NewGetBlockTemplateCmd(request *TemplateRequest) *GetBlockTemplateCmd {
// GetCFilterCmd defines the getcfilter JSON-RPC command. // GetCFilterCmd defines the getcfilter JSON-RPC command.
type GetCFilterCmd struct { type GetCFilterCmd struct {
Hash string Hash string
FilterType uint8 FilterType wire.FilterType
} }
// NewGetCFilterCmd returns a new instance which can be used to issue a // NewGetCFilterCmd returns a new instance which can be used to issue a
// getcfilter JSON-RPC command. // getcfilter JSON-RPC command.
func NewGetCFilterCmd(hash string, filterType uint8) *GetCFilterCmd { func NewGetCFilterCmd(hash string, filterType wire.FilterType) *GetCFilterCmd {
return &GetCFilterCmd{ return &GetCFilterCmd{
Hash: hash, Hash: hash,
FilterType: filterType, FilterType: filterType,
@ -297,12 +299,13 @@ func NewGetCFilterCmd(hash string, filterType uint8) *GetCFilterCmd {
// GetCFilterHeaderCmd defines the getcfilterheader JSON-RPC command. // GetCFilterHeaderCmd defines the getcfilterheader JSON-RPC command.
type GetCFilterHeaderCmd struct { type GetCFilterHeaderCmd struct {
Hash string Hash string
FilterType uint8 FilterType wire.FilterType
} }
// NewGetCFilterHeaderCmd returns a new instance which can be used to issue a // NewGetCFilterHeaderCmd returns a new instance which can be used to issue a
// getcfilterheader JSON-RPC command. // getcfilterheader JSON-RPC command.
func NewGetCFilterHeaderCmd(hash string, filterType uint8) *GetCFilterHeaderCmd { func NewGetCFilterHeaderCmd(hash string,
filterType wire.FilterType) *GetCFilterHeaderCmd {
return &GetCFilterHeaderCmd{ return &GetCFilterHeaderCmd{
Hash: hash, Hash: hash,
FilterType: filterType, FilterType: filterType,

View file

@ -12,6 +12,7 @@ import (
"testing" "testing"
"github.com/btcsuite/btcd/btcjson" "github.com/btcsuite/btcd/btcjson"
"github.com/btcsuite/btcd/wire"
) )
// TestChainSvrCmds tests all of the chain server commands marshal and unmarshal // TestChainSvrCmds tests all of the chain server commands marshal and unmarshal
@ -320,27 +321,33 @@ func TestChainSvrCmds(t *testing.T) {
{ {
name: "getcfilter", name: "getcfilter",
newCmd: func() (interface{}, error) { newCmd: func() (interface{}, error) {
return btcjson.NewCmd("getcfilter", "123", 0) return btcjson.NewCmd("getcfilter", "123",
wire.GCSFilterExtended)
}, },
staticCmd: func() interface{} { staticCmd: func() interface{} {
return btcjson.NewGetCFilterCmd("123", 0) return btcjson.NewGetCFilterCmd("123",
wire.GCSFilterExtended)
}, },
marshalled: `{"jsonrpc":"1.0","method":"getcfilter","params":["123",0],"id":1}`, marshalled: `{"jsonrpc":"1.0","method":"getcfilter","params":["123",1],"id":1}`,
unmarshalled: &btcjson.GetCFilterCmd{ unmarshalled: &btcjson.GetCFilterCmd{
Hash: "123", Hash: "123",
FilterType: wire.GCSFilterExtended,
}, },
}, },
{ {
name: "getcfilterheader", name: "getcfilterheader",
newCmd: func() (interface{}, error) { newCmd: func() (interface{}, error) {
return btcjson.NewCmd("getcfilterheader", "123", 0) return btcjson.NewCmd("getcfilterheader", "123",
wire.GCSFilterExtended)
}, },
staticCmd: func() interface{} { staticCmd: func() interface{} {
return btcjson.NewGetCFilterHeaderCmd("123", 0) return btcjson.NewGetCFilterHeaderCmd("123",
wire.GCSFilterExtended)
}, },
marshalled: `{"jsonrpc":"1.0","method":"getcfilterheader","params":["123",0],"id":1}`, marshalled: `{"jsonrpc":"1.0","method":"getcfilterheader","params":["123",1],"id":1}`,
unmarshalled: &btcjson.GetCFilterHeaderCmd{ unmarshalled: &btcjson.GetCFilterHeaderCmd{
Hash: "123", Hash: "123",
FilterType: wire.GCSFilterExtended,
}, },
}, },
{ {

View file

@ -542,7 +542,8 @@ func TestPeerListeners(t *testing.T) {
}, },
{ {
"OnGetCFilter", "OnGetCFilter",
wire.NewMsgGetCFilter(&chainhash.Hash{}, 0), wire.NewMsgGetCFilter(&chainhash.Hash{},
wire.GCSFilterRegular),
}, },
{ {
"OnGetCFHeaders", "OnGetCFHeaders",
@ -554,8 +555,8 @@ func TestPeerListeners(t *testing.T) {
}, },
{ {
"OnCFilter", "OnCFilter",
wire.NewMsgCFilter(&chainhash.Hash{}, 1, wire.NewMsgCFilter(&chainhash.Hash{},
[]byte("payload")), wire.GCSFilterRegular, []byte("payload")),
}, },
{ {
"OnCFHeaders", "OnCFHeaders",
@ -563,7 +564,8 @@ func TestPeerListeners(t *testing.T) {
}, },
{ {
"OnCFTypes", "OnCFTypes",
wire.NewMsgCFTypes([]uint8{0, 1}), wire.NewMsgCFTypes([]wire.FilterType{
wire.GCSFilterRegular, wire.GCSFilterExtended}),
}, },
{ {
"OnFeeFilter", "OnFeeFilter",

View file

@ -821,7 +821,7 @@ func (r FutureGetCFilterResult) Receive() (*wire.MsgCFilter, error) {
// //
// See GetCFilter for the blocking version and more details. // See GetCFilter for the blocking version and more details.
func (c *Client) GetCFilterAsync(blockHash *chainhash.Hash, func (c *Client) GetCFilterAsync(blockHash *chainhash.Hash,
filterType uint8) FutureGetCFilterResult { filterType wire.FilterType) FutureGetCFilterResult {
hash := "" hash := ""
if blockHash != nil { if blockHash != nil {
hash = blockHash.String() hash = blockHash.String()
@ -833,7 +833,7 @@ func (c *Client) GetCFilterAsync(blockHash *chainhash.Hash,
// GetCFilter returns a raw filter from the server given its block hash. // GetCFilter returns a raw filter from the server given its block hash.
func (c *Client) GetCFilter(blockHash *chainhash.Hash, func (c *Client) GetCFilter(blockHash *chainhash.Hash,
filterType uint8) (*wire.MsgCFilter, error) { filterType wire.FilterType) (*wire.MsgCFilter, error) {
return c.GetCFilterAsync(blockHash, filterType).Receive() return c.GetCFilterAsync(blockHash, filterType).Receive()
} }
@ -878,7 +878,7 @@ func (r FutureGetCFilterHeaderResult) Receive() (*wire.MsgCFHeaders, error) {
// //
// See GetCFilterHeader for the blocking version and more details. // See GetCFilterHeader for the blocking version and more details.
func (c *Client) GetCFilterHeaderAsync(blockHash *chainhash.Hash, func (c *Client) GetCFilterHeaderAsync(blockHash *chainhash.Hash,
filterType uint8) FutureGetCFilterHeaderResult { filterType wire.FilterType) FutureGetCFilterHeaderResult {
hash := "" hash := ""
if blockHash != nil { if blockHash != nil {
hash = blockHash.String() hash = blockHash.String()
@ -891,6 +891,6 @@ func (c *Client) GetCFilterHeaderAsync(blockHash *chainhash.Hash,
// GetCFilterHeader returns a raw filter header from the server given its block // GetCFilterHeader returns a raw filter header from the server given its block
// hash. // hash.
func (c *Client) GetCFilterHeader(blockHash *chainhash.Hash, func (c *Client) GetCFilterHeader(blockHash *chainhash.Hash,
filterType uint8) (*wire.MsgCFHeaders, error) { filterType wire.FilterType) (*wire.MsgCFHeaders, error) {
return c.GetCFilterHeaderAsync(blockHash, filterType).Receive() return c.GetCFilterHeaderAsync(blockHash, filterType).Receive()
} }

View file

@ -881,7 +881,8 @@ func (sp *serverPeer) OnGetCFTypes(_ *peer.Peer, msg *wire.MsgGetCFTypes) {
return return
} }
cfTypesMsg := wire.NewMsgCFTypes([]uint8{0, 1}) cfTypesMsg := wire.NewMsgCFTypes([]wire.FilterType{
wire.GCSFilterRegular, wire.GCSFilterExtended})
sp.QueueMessage(cfTypesMsg, nil) sp.QueueMessage(cfTypesMsg, nil)
} }

View file

@ -69,12 +69,13 @@ func TestMessage(t *testing.T) {
bh := NewBlockHeader(1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0) bh := NewBlockHeader(1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0)
msgMerkleBlock := NewMsgMerkleBlock(bh) msgMerkleBlock := NewMsgMerkleBlock(bh)
msgReject := NewMsgReject("block", RejectDuplicate, "duplicate block") msgReject := NewMsgReject("block", RejectDuplicate, "duplicate block")
msgGetCFilter := NewMsgGetCFilter(&chainhash.Hash{}, 0) msgGetCFilter := NewMsgGetCFilter(&chainhash.Hash{}, GCSFilterExtended)
msgGetCFHeaders := NewMsgGetCFHeaders() msgGetCFHeaders := NewMsgGetCFHeaders()
msgGetCFTypes := NewMsgGetCFTypes() msgGetCFTypes := NewMsgGetCFTypes()
msgCFilter := NewMsgCFilter(&chainhash.Hash{}, 1, []byte("payload")) msgCFilter := NewMsgCFilter(&chainhash.Hash{}, GCSFilterExtended,
[]byte("payload"))
msgCFHeaders := NewMsgCFHeaders() msgCFHeaders := NewMsgCFHeaders()
msgCFTypes := NewMsgCFTypes([]uint8{2}) msgCFTypes := NewMsgCFTypes([]FilterType{GCSFilterExtended})
tests := []struct { tests := []struct {
in Message // Value to encode in Message // Value to encode

View file

@ -28,7 +28,7 @@ const (
// MsgGetCFHeaders for details on requesting the headers. // MsgGetCFHeaders for details on requesting the headers.
type MsgCFHeaders struct { type MsgCFHeaders struct {
StopHash chainhash.Hash StopHash chainhash.Hash
FilterType uint8 FilterType FilterType
HeaderHashes []*chainhash.Hash HeaderHashes []*chainhash.Hash
} }

View file

@ -22,7 +22,7 @@ const (
// getcfilter (MsgGetCFilter) message. // getcfilter (MsgGetCFilter) message.
type MsgCFilter struct { type MsgCFilter struct {
BlockHash chainhash.Hash BlockHash chainhash.Hash
FilterType uint8 FilterType FilterType
Data []byte Data []byte
} }
@ -99,7 +99,7 @@ func (msg *MsgCFilter) MaxPayloadLength(pver uint32) uint32 {
// NewMsgCFilter returns a new bitcoin cfilter message that conforms to the // NewMsgCFilter returns a new bitcoin cfilter message that conforms to the
// Message interface. See MsgCFilter for details. // Message interface. See MsgCFilter for details.
func NewMsgCFilter(blockHash *chainhash.Hash, filterType uint8, func NewMsgCFilter(blockHash *chainhash.Hash, filterType FilterType,
data []byte) *MsgCFilter { data []byte) *MsgCFilter {
return &MsgCFilter{ return &MsgCFilter{
BlockHash: *blockHash, BlockHash: *blockHash,

View file

@ -6,28 +6,40 @@ package wire
import "io" import "io"
// FilterType is used to represent a filter type.
type FilterType uint8
const (
// GCSFilterRegular is the regular filter type.
GCSFilterRegular FilterType = iota
// GCSFilterExtended is the extended filter type.
GCSFilterExtended
)
// MsgCFTypes is the cftypes message. // MsgCFTypes is the cftypes message.
type MsgCFTypes struct { type MsgCFTypes struct {
NumFilters uint8 SupportedFilters []FilterType
SupportedFilters []uint8
} }
// BtcDecode decodes r using the bitcoin protocol encoding into the receiver. // BtcDecode decodes r using the bitcoin protocol encoding into the receiver.
// This is part of the Message interface implementation. // This is part of the Message interface implementation.
func (msg *MsgCFTypes) BtcDecode(r io.Reader, pver uint32, _ MessageEncoding) error { func (msg *MsgCFTypes) BtcDecode(r io.Reader, pver uint32, _ MessageEncoding) error {
// Read the number of filter types supported // Read the number of filter types supported.
err := readElement(r, &msg.NumFilters) count, err := ReadVarInt(r, pver)
if err != nil { if err != nil {
return err return err
} }
// Read each filter type. // Read each filter type.
msg.SupportedFilters = make([]uint8, msg.NumFilters) msg.SupportedFilters = make([]FilterType, count)
for i := uint8(0); i < msg.NumFilters; i++ { for i := uint64(0); i < count; i++ {
err = readElement(r, &msg.SupportedFilters[i]) var filterType uint8
err = readElement(r, &filterType)
if err != nil { if err != nil {
return err return err
} }
msg.SupportedFilters[i] = FilterType(filterType)
} }
return nil return nil
@ -36,9 +48,8 @@ func (msg *MsgCFTypes) BtcDecode(r io.Reader, pver uint32, _ MessageEncoding) er
// BtcEncode encodes the receiver to w using the bitcoin protocol encoding. // BtcEncode encodes the receiver to w using the bitcoin protocol encoding.
// This is part of the Message interface implementation. // This is part of the Message interface implementation.
func (msg *MsgCFTypes) BtcEncode(w io.Writer, pver uint32, _ MessageEncoding) error { func (msg *MsgCFTypes) BtcEncode(w io.Writer, pver uint32, _ MessageEncoding) error {
// Write length of supported filters slice; don't trust that the caller // Write length of supported filters slice. We assume it's deduplicated.
// has set it correctly. err := WriteVarInt(w, pver, uint64(len(msg.SupportedFilters)))
err := writeElement(w, uint8(len(msg.SupportedFilters)))
if err != nil { if err != nil {
return err return err
} }
@ -78,15 +89,14 @@ func (msg *MsgCFTypes) Command() string {
// MaxPayloadLength returns the maximum length the payload can be for the // MaxPayloadLength returns the maximum length the payload can be for the
// receiver. This is part of the Message interface implementation. // receiver. This is part of the Message interface implementation.
func (msg *MsgCFTypes) MaxPayloadLength(pver uint32) uint32 { func (msg *MsgCFTypes) MaxPayloadLength(pver uint32) uint32 {
// 1 byte for NumFilters, and 1 byte for up to 255 filter types. // 2 bytes for filter count, and 1 byte for up to 256 filter types.
return 256 return 258
} }
// NewMsgCFTypes returns a new bitcoin cftypes message that conforms to the // NewMsgCFTypes returns a new bitcoin cftypes message that conforms to the
// Message interface. See MsgCFTypes for details. // Message interface. See MsgCFTypes for details.
func NewMsgCFTypes(filterTypes []uint8) *MsgCFTypes { func NewMsgCFTypes(filterTypes []FilterType) *MsgCFTypes {
return &MsgCFTypes{ return &MsgCFTypes{
NumFilters: uint8(len(filterTypes)),
SupportedFilters: filterTypes, SupportedFilters: filterTypes,
} }
} }

View file

@ -17,7 +17,7 @@ import (
type MsgGetCFHeaders struct { type MsgGetCFHeaders struct {
BlockLocatorHashes []*chainhash.Hash BlockLocatorHashes []*chainhash.Hash
HashStop chainhash.Hash HashStop chainhash.Hash
FilterType uint8 FilterType FilterType
} }
// AddBlockLocatorHash adds a new block locator hash to the message. // AddBlockLocatorHash adds a new block locator hash to the message.

View file

@ -14,7 +14,7 @@ import (
// getcfilter message. It is used to request a committed filter for a block. // getcfilter message. It is used to request a committed filter for a block.
type MsgGetCFilter struct { type MsgGetCFilter struct {
BlockHash chainhash.Hash BlockHash chainhash.Hash
FilterType uint8 FilterType FilterType
} }
// BtcDecode decodes r using the bitcoin protocol encoding into the receiver. // BtcDecode decodes r using the bitcoin protocol encoding into the receiver.
@ -53,7 +53,8 @@ func (msg *MsgGetCFilter) MaxPayloadLength(pver uint32) uint32 {
// NewMsgGetCFilter returns a new bitcoin getcfilter message that conforms to // NewMsgGetCFilter returns a new bitcoin getcfilter message that conforms to
// the Message interface using the passed parameters and defaults for the // the Message interface using the passed parameters and defaults for the
// remaining fields. // remaining fields.
func NewMsgGetCFilter(blockHash *chainhash.Hash, filterType uint8) *MsgGetCFilter { func NewMsgGetCFilter(blockHash *chainhash.Hash,
filterType FilterType) *MsgGetCFilter {
return &MsgGetCFilter{ return &MsgGetCFilter{
BlockHash: *blockHash, BlockHash: *blockHash,
FilterType: filterType, FilterType: filterType,