From c7e7acc7fda9bd72574fa4a5af3396b0637cc72e Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 31 Oct 2017 00:24:57 -0600 Subject: [PATCH] multi: use hidden varint for cftypes count; make filter type enum, not uint8 --- blockchain/indexers/cfindex.go | 27 ++++++++++++++---------- btcjson/chainsvrcmds.go | 11 ++++++---- btcjson/chainsvrcmds_test.go | 23 +++++++++++++------- peer/peer_test.go | 10 +++++---- rpcclient/chain.go | 8 +++---- server.go | 3 ++- wire/message_test.go | 7 ++++--- wire/msgcfheaders.go | 2 +- wire/msgcfilter.go | 4 ++-- wire/msgcftypes.go | 38 +++++++++++++++++++++------------- wire/msggetcfheaders.go | 2 +- wire/msggetcfilter.go | 5 +++-- 12 files changed, 85 insertions(+), 55 deletions(-) diff --git a/blockchain/indexers/cfindex.go b/blockchain/indexers/cfindex.go index 0e7d3bbf..1572effd 100644 --- a/blockchain/indexers/cfindex.go +++ b/blockchain/indexers/cfindex.go @@ -15,6 +15,7 @@ import ( "github.com/btcsuite/btcutil/gcs" "github.com/btcsuite/btcutil/gcs/builder" "github.com/btcsuite/fastsha256" + "github.com/roasbeef/btcd/wire" ) const ( @@ -149,7 +150,7 @@ func (idx *CfIndex) Create(dbTx database.Tx) error { firstHeader := make([]byte, chainhash.HashSize) err = dbStoreFilterHeader( dbTx, - cfHeaderKeys[0], + cfHeaderKeys[wire.GCSFilterRegular], &idx.chainParams.GenesisBlock.Header.PrevBlock, firstHeader, ) @@ -159,7 +160,7 @@ func (idx *CfIndex) Create(dbTx database.Tx) error { return dbStoreFilterHeader( dbTx, - cfHeaderKeys[1], + cfHeaderKeys[wire.GCSFilterExtended], &idx.chainParams.GenesisBlock.Header.PrevBlock, firstHeader, ) @@ -168,8 +169,8 @@ func (idx *CfIndex) Create(dbTx database.Tx) error { // storeFilter stores a given filter, and performs the steps needed to // generate the filter's header. func storeFilter(dbTx database.Tx, block *btcutil.Block, f *gcs.Filter, - filterType uint8) error { - if filterType > maxFilterType { + filterType wire.FilterType) error { + if uint8(filterType) > maxFilterType { return errors.New("unsupported filter type") } @@ -215,7 +216,8 @@ func (idx *CfIndex) ConnectBlock(dbTx database.Tx, block *btcutil.Block, return err } - if err := storeFilter(dbTx, block, f, 0); err != nil { + if err := storeFilter(dbTx, block, f, + wire.GCSFilterRegular); err != nil { return err } @@ -224,7 +226,7 @@ func (idx *CfIndex) ConnectBlock(dbTx database.Tx, block *btcutil.Block, return err } - return storeFilter(dbTx, block, f, 1) + return storeFilter(dbTx, block, f, wire.GCSFilterExtended) } // DisconnectBlock is invoked by the index manager when a block has been @@ -252,10 +254,11 @@ func (idx *CfIndex) DisconnectBlock(dbTx database.Tx, block *btcutil.Block, // FilterByBlockHash returns the serialized contents of a block's basic or // extended committed filter. -func (idx *CfIndex) FilterByBlockHash(h *chainhash.Hash, filterType uint8) ([]byte, error) { +func (idx *CfIndex) FilterByBlockHash(h *chainhash.Hash, + filterType wire.FilterType) ([]byte, error) { var f []byte err := idx.db.View(func(dbTx database.Tx) error { - if filterType > maxFilterType { + if uint8(filterType) > maxFilterType { return errors.New("unsupported filter type") } @@ -268,15 +271,17 @@ func (idx *CfIndex) FilterByBlockHash(h *chainhash.Hash, filterType uint8) ([]by // FilterHeaderByBlockHash returns the serialized contents of a block's basic // or extended committed filter header. -func (idx *CfIndex) FilterHeaderByBlockHash(h *chainhash.Hash, filterType uint8) ([]byte, error) { +func (idx *CfIndex) FilterHeaderByBlockHash(h *chainhash.Hash, + filterType wire.FilterType) ([]byte, error) { var fh []byte err := idx.db.View(func(dbTx database.Tx) error { - if filterType > 1 { + if uint8(filterType) > maxFilterType { return errors.New("unsupported filter type") } var err error - fh, err = dbFetchFilterHeader(dbTx, cfHeaderKeys[filterType], h) + fh, err = dbFetchFilterHeader(dbTx, + cfHeaderKeys[filterType], h) return err }) return fh, err diff --git a/btcjson/chainsvrcmds.go b/btcjson/chainsvrcmds.go index e7d736f0..18933da0 100644 --- a/btcjson/chainsvrcmds.go +++ b/btcjson/chainsvrcmds.go @@ -10,6 +10,8 @@ package btcjson import ( "encoding/json" "fmt" + + "github.com/roasbeef/btcd/wire" ) // AddNodeSubCmd defines the type used in the addnode JSON-RPC command for the @@ -282,12 +284,12 @@ func NewGetBlockTemplateCmd(request *TemplateRequest) *GetBlockTemplateCmd { // GetCFilterCmd defines the getcfilter JSON-RPC command. type GetCFilterCmd struct { Hash string - FilterType uint8 + FilterType wire.FilterType } // NewGetCFilterCmd returns a new instance which can be used to issue a // getcfilter JSON-RPC command. -func NewGetCFilterCmd(hash string, filterType uint8) *GetCFilterCmd { +func NewGetCFilterCmd(hash string, filterType wire.FilterType) *GetCFilterCmd { return &GetCFilterCmd{ Hash: hash, FilterType: filterType, @@ -297,12 +299,13 @@ func NewGetCFilterCmd(hash string, filterType uint8) *GetCFilterCmd { // GetCFilterHeaderCmd defines the getcfilterheader JSON-RPC command. type GetCFilterHeaderCmd struct { Hash string - FilterType uint8 + FilterType wire.FilterType } // NewGetCFilterHeaderCmd returns a new instance which can be used to issue a // getcfilterheader JSON-RPC command. -func NewGetCFilterHeaderCmd(hash string, filterType uint8) *GetCFilterHeaderCmd { +func NewGetCFilterHeaderCmd(hash string, + filterType wire.FilterType) *GetCFilterHeaderCmd { return &GetCFilterHeaderCmd{ Hash: hash, FilterType: filterType, diff --git a/btcjson/chainsvrcmds_test.go b/btcjson/chainsvrcmds_test.go index 27fb9d6e..12fba7d9 100644 --- a/btcjson/chainsvrcmds_test.go +++ b/btcjson/chainsvrcmds_test.go @@ -12,6 +12,7 @@ import ( "testing" "github.com/btcsuite/btcd/btcjson" + "github.com/btcsuite/btcd/wire" ) // TestChainSvrCmds tests all of the chain server commands marshal and unmarshal @@ -320,27 +321,33 @@ func TestChainSvrCmds(t *testing.T) { { name: "getcfilter", newCmd: func() (interface{}, error) { - return btcjson.NewCmd("getcfilter", "123", 0) + return btcjson.NewCmd("getcfilter", "123", + wire.GCSFilterExtended) }, staticCmd: func() interface{} { - return btcjson.NewGetCFilterCmd("123", 0) + return btcjson.NewGetCFilterCmd("123", + wire.GCSFilterExtended) }, - marshalled: `{"jsonrpc":"1.0","method":"getcfilter","params":["123",0],"id":1}`, + marshalled: `{"jsonrpc":"1.0","method":"getcfilter","params":["123",1],"id":1}`, unmarshalled: &btcjson.GetCFilterCmd{ - Hash: "123", + Hash: "123", + FilterType: wire.GCSFilterExtended, }, }, { name: "getcfilterheader", newCmd: func() (interface{}, error) { - return btcjson.NewCmd("getcfilterheader", "123", 0) + return btcjson.NewCmd("getcfilterheader", "123", + wire.GCSFilterExtended) }, staticCmd: func() interface{} { - return btcjson.NewGetCFilterHeaderCmd("123", 0) + return btcjson.NewGetCFilterHeaderCmd("123", + wire.GCSFilterExtended) }, - marshalled: `{"jsonrpc":"1.0","method":"getcfilterheader","params":["123",0],"id":1}`, + marshalled: `{"jsonrpc":"1.0","method":"getcfilterheader","params":["123",1],"id":1}`, unmarshalled: &btcjson.GetCFilterHeaderCmd{ - Hash: "123", + Hash: "123", + FilterType: wire.GCSFilterExtended, }, }, { diff --git a/peer/peer_test.go b/peer/peer_test.go index 89d114b1..742f97c6 100644 --- a/peer/peer_test.go +++ b/peer/peer_test.go @@ -542,7 +542,8 @@ func TestPeerListeners(t *testing.T) { }, { "OnGetCFilter", - wire.NewMsgGetCFilter(&chainhash.Hash{}, 0), + wire.NewMsgGetCFilter(&chainhash.Hash{}, + wire.GCSFilterRegular), }, { "OnGetCFHeaders", @@ -554,8 +555,8 @@ func TestPeerListeners(t *testing.T) { }, { "OnCFilter", - wire.NewMsgCFilter(&chainhash.Hash{}, 1, - []byte("payload")), + wire.NewMsgCFilter(&chainhash.Hash{}, + wire.GCSFilterRegular, []byte("payload")), }, { "OnCFHeaders", @@ -563,7 +564,8 @@ func TestPeerListeners(t *testing.T) { }, { "OnCFTypes", - wire.NewMsgCFTypes([]uint8{0, 1}), + wire.NewMsgCFTypes([]wire.FilterType{ + wire.GCSFilterRegular, wire.GCSFilterExtended}), }, { "OnFeeFilter", diff --git a/rpcclient/chain.go b/rpcclient/chain.go index 3e5c69e7..41faea18 100644 --- a/rpcclient/chain.go +++ b/rpcclient/chain.go @@ -821,7 +821,7 @@ func (r FutureGetCFilterResult) Receive() (*wire.MsgCFilter, error) { // // See GetCFilter for the blocking version and more details. func (c *Client) GetCFilterAsync(blockHash *chainhash.Hash, - filterType uint8) FutureGetCFilterResult { + filterType wire.FilterType) FutureGetCFilterResult { hash := "" if blockHash != nil { hash = blockHash.String() @@ -833,7 +833,7 @@ func (c *Client) GetCFilterAsync(blockHash *chainhash.Hash, // GetCFilter returns a raw filter from the server given its block hash. func (c *Client) GetCFilter(blockHash *chainhash.Hash, - filterType uint8) (*wire.MsgCFilter, error) { + filterType wire.FilterType) (*wire.MsgCFilter, error) { return c.GetCFilterAsync(blockHash, filterType).Receive() } @@ -878,7 +878,7 @@ func (r FutureGetCFilterHeaderResult) Receive() (*wire.MsgCFHeaders, error) { // // See GetCFilterHeader for the blocking version and more details. func (c *Client) GetCFilterHeaderAsync(blockHash *chainhash.Hash, - filterType uint8) FutureGetCFilterHeaderResult { + filterType wire.FilterType) FutureGetCFilterHeaderResult { hash := "" if blockHash != nil { hash = blockHash.String() @@ -891,6 +891,6 @@ func (c *Client) GetCFilterHeaderAsync(blockHash *chainhash.Hash, // GetCFilterHeader returns a raw filter header from the server given its block // hash. func (c *Client) GetCFilterHeader(blockHash *chainhash.Hash, - filterType uint8) (*wire.MsgCFHeaders, error) { + filterType wire.FilterType) (*wire.MsgCFHeaders, error) { return c.GetCFilterHeaderAsync(blockHash, filterType).Receive() } diff --git a/server.go b/server.go index 450ebb0e..535240af 100644 --- a/server.go +++ b/server.go @@ -881,7 +881,8 @@ func (sp *serverPeer) OnGetCFTypes(_ *peer.Peer, msg *wire.MsgGetCFTypes) { return } - cfTypesMsg := wire.NewMsgCFTypes([]uint8{0, 1}) + cfTypesMsg := wire.NewMsgCFTypes([]wire.FilterType{ + wire.GCSFilterRegular, wire.GCSFilterExtended}) sp.QueueMessage(cfTypesMsg, nil) } diff --git a/wire/message_test.go b/wire/message_test.go index 7d8cf9cc..897f36be 100644 --- a/wire/message_test.go +++ b/wire/message_test.go @@ -69,12 +69,13 @@ func TestMessage(t *testing.T) { bh := NewBlockHeader(1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0) msgMerkleBlock := NewMsgMerkleBlock(bh) msgReject := NewMsgReject("block", RejectDuplicate, "duplicate block") - msgGetCFilter := NewMsgGetCFilter(&chainhash.Hash{}, 0) + msgGetCFilter := NewMsgGetCFilter(&chainhash.Hash{}, GCSFilterExtended) msgGetCFHeaders := NewMsgGetCFHeaders() msgGetCFTypes := NewMsgGetCFTypes() - msgCFilter := NewMsgCFilter(&chainhash.Hash{}, 1, []byte("payload")) + msgCFilter := NewMsgCFilter(&chainhash.Hash{}, GCSFilterExtended, + []byte("payload")) msgCFHeaders := NewMsgCFHeaders() - msgCFTypes := NewMsgCFTypes([]uint8{2}) + msgCFTypes := NewMsgCFTypes([]FilterType{GCSFilterExtended}) tests := []struct { in Message // Value to encode diff --git a/wire/msgcfheaders.go b/wire/msgcfheaders.go index 6e1ce2a5..0186a53f 100644 --- a/wire/msgcfheaders.go +++ b/wire/msgcfheaders.go @@ -28,7 +28,7 @@ const ( // MsgGetCFHeaders for details on requesting the headers. type MsgCFHeaders struct { StopHash chainhash.Hash - FilterType uint8 + FilterType FilterType HeaderHashes []*chainhash.Hash } diff --git a/wire/msgcfilter.go b/wire/msgcfilter.go index 22329938..3ce33180 100644 --- a/wire/msgcfilter.go +++ b/wire/msgcfilter.go @@ -22,7 +22,7 @@ const ( // getcfilter (MsgGetCFilter) message. type MsgCFilter struct { BlockHash chainhash.Hash - FilterType uint8 + FilterType FilterType Data []byte } @@ -99,7 +99,7 @@ func (msg *MsgCFilter) MaxPayloadLength(pver uint32) uint32 { // NewMsgCFilter returns a new bitcoin cfilter message that conforms to the // Message interface. See MsgCFilter for details. -func NewMsgCFilter(blockHash *chainhash.Hash, filterType uint8, +func NewMsgCFilter(blockHash *chainhash.Hash, filterType FilterType, data []byte) *MsgCFilter { return &MsgCFilter{ BlockHash: *blockHash, diff --git a/wire/msgcftypes.go b/wire/msgcftypes.go index 92a8f718..17388e02 100644 --- a/wire/msgcftypes.go +++ b/wire/msgcftypes.go @@ -6,28 +6,40 @@ package wire import "io" +// FilterType is used to represent a filter type. +type FilterType uint8 + +const ( + // GCSFilterRegular is the regular filter type. + GCSFilterRegular FilterType = iota + + // GCSFilterExtended is the extended filter type. + GCSFilterExtended +) + // MsgCFTypes is the cftypes message. type MsgCFTypes struct { - NumFilters uint8 - SupportedFilters []uint8 + SupportedFilters []FilterType } // BtcDecode decodes r using the bitcoin protocol encoding into the receiver. // This is part of the Message interface implementation. func (msg *MsgCFTypes) BtcDecode(r io.Reader, pver uint32, _ MessageEncoding) error { - // Read the number of filter types supported - err := readElement(r, &msg.NumFilters) + // Read the number of filter types supported. + count, err := ReadVarInt(r, pver) if err != nil { return err } // Read each filter type. - msg.SupportedFilters = make([]uint8, msg.NumFilters) - for i := uint8(0); i < msg.NumFilters; i++ { - err = readElement(r, &msg.SupportedFilters[i]) + msg.SupportedFilters = make([]FilterType, count) + for i := uint64(0); i < count; i++ { + var filterType uint8 + err = readElement(r, &filterType) if err != nil { return err } + msg.SupportedFilters[i] = FilterType(filterType) } return nil @@ -36,9 +48,8 @@ func (msg *MsgCFTypes) BtcDecode(r io.Reader, pver uint32, _ MessageEncoding) er // BtcEncode encodes the receiver to w using the bitcoin protocol encoding. // This is part of the Message interface implementation. func (msg *MsgCFTypes) BtcEncode(w io.Writer, pver uint32, _ MessageEncoding) error { - // Write length of supported filters slice; don't trust that the caller - // has set it correctly. - err := writeElement(w, uint8(len(msg.SupportedFilters))) + // Write length of supported filters slice. We assume it's deduplicated. + err := WriteVarInt(w, pver, uint64(len(msg.SupportedFilters))) if err != nil { return err } @@ -78,15 +89,14 @@ func (msg *MsgCFTypes) Command() string { // MaxPayloadLength returns the maximum length the payload can be for the // receiver. This is part of the Message interface implementation. func (msg *MsgCFTypes) MaxPayloadLength(pver uint32) uint32 { - // 1 byte for NumFilters, and 1 byte for up to 255 filter types. - return 256 + // 2 bytes for filter count, and 1 byte for up to 256 filter types. + return 258 } // NewMsgCFTypes returns a new bitcoin cftypes message that conforms to the // Message interface. See MsgCFTypes for details. -func NewMsgCFTypes(filterTypes []uint8) *MsgCFTypes { +func NewMsgCFTypes(filterTypes []FilterType) *MsgCFTypes { return &MsgCFTypes{ - NumFilters: uint8(len(filterTypes)), SupportedFilters: filterTypes, } } diff --git a/wire/msggetcfheaders.go b/wire/msggetcfheaders.go index d82ddcc8..d93e4e62 100644 --- a/wire/msggetcfheaders.go +++ b/wire/msggetcfheaders.go @@ -17,7 +17,7 @@ import ( type MsgGetCFHeaders struct { BlockLocatorHashes []*chainhash.Hash HashStop chainhash.Hash - FilterType uint8 + FilterType FilterType } // AddBlockLocatorHash adds a new block locator hash to the message. diff --git a/wire/msggetcfilter.go b/wire/msggetcfilter.go index 8ba677f5..527a391e 100644 --- a/wire/msggetcfilter.go +++ b/wire/msggetcfilter.go @@ -14,7 +14,7 @@ import ( // getcfilter message. It is used to request a committed filter for a block. type MsgGetCFilter struct { BlockHash chainhash.Hash - FilterType uint8 + FilterType FilterType } // BtcDecode decodes r using the bitcoin protocol encoding into the receiver. @@ -53,7 +53,8 @@ func (msg *MsgGetCFilter) MaxPayloadLength(pver uint32) uint32 { // NewMsgGetCFilter returns a new bitcoin getcfilter message that conforms to // the Message interface using the passed parameters and defaults for the // remaining fields. -func NewMsgGetCFilter(blockHash *chainhash.Hash, filterType uint8) *MsgGetCFilter { +func NewMsgGetCFilter(blockHash *chainhash.Hash, + filterType FilterType) *MsgGetCFilter { return &MsgGetCFilter{ BlockHash: *blockHash, FilterType: filterType,