multi: use hidden varint for cftypes count; make filter type enum, not uint8
This commit is contained in:
parent
8d2ce855eb
commit
c7e7acc7fd
12 changed files with 85 additions and 55 deletions
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/btcsuite/btcutil/gcs"
|
||||
"github.com/btcsuite/btcutil/gcs/builder"
|
||||
"github.com/btcsuite/fastsha256"
|
||||
"github.com/roasbeef/btcd/wire"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -149,7 +150,7 @@ func (idx *CfIndex) Create(dbTx database.Tx) error {
|
|||
firstHeader := make([]byte, chainhash.HashSize)
|
||||
err = dbStoreFilterHeader(
|
||||
dbTx,
|
||||
cfHeaderKeys[0],
|
||||
cfHeaderKeys[wire.GCSFilterRegular],
|
||||
&idx.chainParams.GenesisBlock.Header.PrevBlock,
|
||||
firstHeader,
|
||||
)
|
||||
|
@ -159,7 +160,7 @@ func (idx *CfIndex) Create(dbTx database.Tx) error {
|
|||
|
||||
return dbStoreFilterHeader(
|
||||
dbTx,
|
||||
cfHeaderKeys[1],
|
||||
cfHeaderKeys[wire.GCSFilterExtended],
|
||||
&idx.chainParams.GenesisBlock.Header.PrevBlock,
|
||||
firstHeader,
|
||||
)
|
||||
|
@ -168,8 +169,8 @@ func (idx *CfIndex) Create(dbTx database.Tx) error {
|
|||
// storeFilter stores a given filter, and performs the steps needed to
|
||||
// generate the filter's header.
|
||||
func storeFilter(dbTx database.Tx, block *btcutil.Block, f *gcs.Filter,
|
||||
filterType uint8) error {
|
||||
if filterType > maxFilterType {
|
||||
filterType wire.FilterType) error {
|
||||
if uint8(filterType) > maxFilterType {
|
||||
return errors.New("unsupported filter type")
|
||||
}
|
||||
|
||||
|
@ -215,7 +216,8 @@ func (idx *CfIndex) ConnectBlock(dbTx database.Tx, block *btcutil.Block,
|
|||
return err
|
||||
}
|
||||
|
||||
if err := storeFilter(dbTx, block, f, 0); err != nil {
|
||||
if err := storeFilter(dbTx, block, f,
|
||||
wire.GCSFilterRegular); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -224,7 +226,7 @@ func (idx *CfIndex) ConnectBlock(dbTx database.Tx, block *btcutil.Block,
|
|||
return err
|
||||
}
|
||||
|
||||
return storeFilter(dbTx, block, f, 1)
|
||||
return storeFilter(dbTx, block, f, wire.GCSFilterExtended)
|
||||
}
|
||||
|
||||
// DisconnectBlock is invoked by the index manager when a block has been
|
||||
|
@ -252,10 +254,11 @@ func (idx *CfIndex) DisconnectBlock(dbTx database.Tx, block *btcutil.Block,
|
|||
|
||||
// FilterByBlockHash returns the serialized contents of a block's basic or
|
||||
// extended committed filter.
|
||||
func (idx *CfIndex) FilterByBlockHash(h *chainhash.Hash, filterType uint8) ([]byte, error) {
|
||||
func (idx *CfIndex) FilterByBlockHash(h *chainhash.Hash,
|
||||
filterType wire.FilterType) ([]byte, error) {
|
||||
var f []byte
|
||||
err := idx.db.View(func(dbTx database.Tx) error {
|
||||
if filterType > maxFilterType {
|
||||
if uint8(filterType) > maxFilterType {
|
||||
return errors.New("unsupported filter type")
|
||||
}
|
||||
|
||||
|
@ -268,15 +271,17 @@ func (idx *CfIndex) FilterByBlockHash(h *chainhash.Hash, filterType uint8) ([]by
|
|||
|
||||
// FilterHeaderByBlockHash returns the serialized contents of a block's basic
|
||||
// or extended committed filter header.
|
||||
func (idx *CfIndex) FilterHeaderByBlockHash(h *chainhash.Hash, filterType uint8) ([]byte, error) {
|
||||
func (idx *CfIndex) FilterHeaderByBlockHash(h *chainhash.Hash,
|
||||
filterType wire.FilterType) ([]byte, error) {
|
||||
var fh []byte
|
||||
err := idx.db.View(func(dbTx database.Tx) error {
|
||||
if filterType > 1 {
|
||||
if uint8(filterType) > maxFilterType {
|
||||
return errors.New("unsupported filter type")
|
||||
}
|
||||
|
||||
var err error
|
||||
fh, err = dbFetchFilterHeader(dbTx, cfHeaderKeys[filterType], h)
|
||||
fh, err = dbFetchFilterHeader(dbTx,
|
||||
cfHeaderKeys[filterType], h)
|
||||
return err
|
||||
})
|
||||
return fh, err
|
||||
|
|
|
@ -10,6 +10,8 @@ package btcjson
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/roasbeef/btcd/wire"
|
||||
)
|
||||
|
||||
// AddNodeSubCmd defines the type used in the addnode JSON-RPC command for the
|
||||
|
@ -282,12 +284,12 @@ func NewGetBlockTemplateCmd(request *TemplateRequest) *GetBlockTemplateCmd {
|
|||
// GetCFilterCmd defines the getcfilter JSON-RPC command.
|
||||
type GetCFilterCmd struct {
|
||||
Hash string
|
||||
FilterType uint8
|
||||
FilterType wire.FilterType
|
||||
}
|
||||
|
||||
// NewGetCFilterCmd returns a new instance which can be used to issue a
|
||||
// getcfilter JSON-RPC command.
|
||||
func NewGetCFilterCmd(hash string, filterType uint8) *GetCFilterCmd {
|
||||
func NewGetCFilterCmd(hash string, filterType wire.FilterType) *GetCFilterCmd {
|
||||
return &GetCFilterCmd{
|
||||
Hash: hash,
|
||||
FilterType: filterType,
|
||||
|
@ -297,12 +299,13 @@ func NewGetCFilterCmd(hash string, filterType uint8) *GetCFilterCmd {
|
|||
// GetCFilterHeaderCmd defines the getcfilterheader JSON-RPC command.
|
||||
type GetCFilterHeaderCmd struct {
|
||||
Hash string
|
||||
FilterType uint8
|
||||
FilterType wire.FilterType
|
||||
}
|
||||
|
||||
// NewGetCFilterHeaderCmd returns a new instance which can be used to issue a
|
||||
// getcfilterheader JSON-RPC command.
|
||||
func NewGetCFilterHeaderCmd(hash string, filterType uint8) *GetCFilterHeaderCmd {
|
||||
func NewGetCFilterHeaderCmd(hash string,
|
||||
filterType wire.FilterType) *GetCFilterHeaderCmd {
|
||||
return &GetCFilterHeaderCmd{
|
||||
Hash: hash,
|
||||
FilterType: filterType,
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcjson"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
)
|
||||
|
||||
// TestChainSvrCmds tests all of the chain server commands marshal and unmarshal
|
||||
|
@ -320,27 +321,33 @@ func TestChainSvrCmds(t *testing.T) {
|
|||
{
|
||||
name: "getcfilter",
|
||||
newCmd: func() (interface{}, error) {
|
||||
return btcjson.NewCmd("getcfilter", "123", 0)
|
||||
return btcjson.NewCmd("getcfilter", "123",
|
||||
wire.GCSFilterExtended)
|
||||
},
|
||||
staticCmd: func() interface{} {
|
||||
return btcjson.NewGetCFilterCmd("123", 0)
|
||||
return btcjson.NewGetCFilterCmd("123",
|
||||
wire.GCSFilterExtended)
|
||||
},
|
||||
marshalled: `{"jsonrpc":"1.0","method":"getcfilter","params":["123",0],"id":1}`,
|
||||
marshalled: `{"jsonrpc":"1.0","method":"getcfilter","params":["123",1],"id":1}`,
|
||||
unmarshalled: &btcjson.GetCFilterCmd{
|
||||
Hash: "123",
|
||||
Hash: "123",
|
||||
FilterType: wire.GCSFilterExtended,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "getcfilterheader",
|
||||
newCmd: func() (interface{}, error) {
|
||||
return btcjson.NewCmd("getcfilterheader", "123", 0)
|
||||
return btcjson.NewCmd("getcfilterheader", "123",
|
||||
wire.GCSFilterExtended)
|
||||
},
|
||||
staticCmd: func() interface{} {
|
||||
return btcjson.NewGetCFilterHeaderCmd("123", 0)
|
||||
return btcjson.NewGetCFilterHeaderCmd("123",
|
||||
wire.GCSFilterExtended)
|
||||
},
|
||||
marshalled: `{"jsonrpc":"1.0","method":"getcfilterheader","params":["123",0],"id":1}`,
|
||||
marshalled: `{"jsonrpc":"1.0","method":"getcfilterheader","params":["123",1],"id":1}`,
|
||||
unmarshalled: &btcjson.GetCFilterHeaderCmd{
|
||||
Hash: "123",
|
||||
Hash: "123",
|
||||
FilterType: wire.GCSFilterExtended,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
|
@ -542,7 +542,8 @@ func TestPeerListeners(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"OnGetCFilter",
|
||||
wire.NewMsgGetCFilter(&chainhash.Hash{}, 0),
|
||||
wire.NewMsgGetCFilter(&chainhash.Hash{},
|
||||
wire.GCSFilterRegular),
|
||||
},
|
||||
{
|
||||
"OnGetCFHeaders",
|
||||
|
@ -554,8 +555,8 @@ func TestPeerListeners(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"OnCFilter",
|
||||
wire.NewMsgCFilter(&chainhash.Hash{}, 1,
|
||||
[]byte("payload")),
|
||||
wire.NewMsgCFilter(&chainhash.Hash{},
|
||||
wire.GCSFilterRegular, []byte("payload")),
|
||||
},
|
||||
{
|
||||
"OnCFHeaders",
|
||||
|
@ -563,7 +564,8 @@ func TestPeerListeners(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"OnCFTypes",
|
||||
wire.NewMsgCFTypes([]uint8{0, 1}),
|
||||
wire.NewMsgCFTypes([]wire.FilterType{
|
||||
wire.GCSFilterRegular, wire.GCSFilterExtended}),
|
||||
},
|
||||
{
|
||||
"OnFeeFilter",
|
||||
|
|
|
@ -821,7 +821,7 @@ func (r FutureGetCFilterResult) Receive() (*wire.MsgCFilter, error) {
|
|||
//
|
||||
// See GetCFilter for the blocking version and more details.
|
||||
func (c *Client) GetCFilterAsync(blockHash *chainhash.Hash,
|
||||
filterType uint8) FutureGetCFilterResult {
|
||||
filterType wire.FilterType) FutureGetCFilterResult {
|
||||
hash := ""
|
||||
if blockHash != nil {
|
||||
hash = blockHash.String()
|
||||
|
@ -833,7 +833,7 @@ func (c *Client) GetCFilterAsync(blockHash *chainhash.Hash,
|
|||
|
||||
// GetCFilter returns a raw filter from the server given its block hash.
|
||||
func (c *Client) GetCFilter(blockHash *chainhash.Hash,
|
||||
filterType uint8) (*wire.MsgCFilter, error) {
|
||||
filterType wire.FilterType) (*wire.MsgCFilter, error) {
|
||||
return c.GetCFilterAsync(blockHash, filterType).Receive()
|
||||
}
|
||||
|
||||
|
@ -878,7 +878,7 @@ func (r FutureGetCFilterHeaderResult) Receive() (*wire.MsgCFHeaders, error) {
|
|||
//
|
||||
// See GetCFilterHeader for the blocking version and more details.
|
||||
func (c *Client) GetCFilterHeaderAsync(blockHash *chainhash.Hash,
|
||||
filterType uint8) FutureGetCFilterHeaderResult {
|
||||
filterType wire.FilterType) FutureGetCFilterHeaderResult {
|
||||
hash := ""
|
||||
if blockHash != nil {
|
||||
hash = blockHash.String()
|
||||
|
@ -891,6 +891,6 @@ func (c *Client) GetCFilterHeaderAsync(blockHash *chainhash.Hash,
|
|||
// GetCFilterHeader returns a raw filter header from the server given its block
|
||||
// hash.
|
||||
func (c *Client) GetCFilterHeader(blockHash *chainhash.Hash,
|
||||
filterType uint8) (*wire.MsgCFHeaders, error) {
|
||||
filterType wire.FilterType) (*wire.MsgCFHeaders, error) {
|
||||
return c.GetCFilterHeaderAsync(blockHash, filterType).Receive()
|
||||
}
|
||||
|
|
|
@ -881,7 +881,8 @@ func (sp *serverPeer) OnGetCFTypes(_ *peer.Peer, msg *wire.MsgGetCFTypes) {
|
|||
return
|
||||
}
|
||||
|
||||
cfTypesMsg := wire.NewMsgCFTypes([]uint8{0, 1})
|
||||
cfTypesMsg := wire.NewMsgCFTypes([]wire.FilterType{
|
||||
wire.GCSFilterRegular, wire.GCSFilterExtended})
|
||||
sp.QueueMessage(cfTypesMsg, nil)
|
||||
}
|
||||
|
||||
|
|
|
@ -69,12 +69,13 @@ func TestMessage(t *testing.T) {
|
|||
bh := NewBlockHeader(1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0)
|
||||
msgMerkleBlock := NewMsgMerkleBlock(bh)
|
||||
msgReject := NewMsgReject("block", RejectDuplicate, "duplicate block")
|
||||
msgGetCFilter := NewMsgGetCFilter(&chainhash.Hash{}, 0)
|
||||
msgGetCFilter := NewMsgGetCFilter(&chainhash.Hash{}, GCSFilterExtended)
|
||||
msgGetCFHeaders := NewMsgGetCFHeaders()
|
||||
msgGetCFTypes := NewMsgGetCFTypes()
|
||||
msgCFilter := NewMsgCFilter(&chainhash.Hash{}, 1, []byte("payload"))
|
||||
msgCFilter := NewMsgCFilter(&chainhash.Hash{}, GCSFilterExtended,
|
||||
[]byte("payload"))
|
||||
msgCFHeaders := NewMsgCFHeaders()
|
||||
msgCFTypes := NewMsgCFTypes([]uint8{2})
|
||||
msgCFTypes := NewMsgCFTypes([]FilterType{GCSFilterExtended})
|
||||
|
||||
tests := []struct {
|
||||
in Message // Value to encode
|
||||
|
|
|
@ -28,7 +28,7 @@ const (
|
|||
// MsgGetCFHeaders for details on requesting the headers.
|
||||
type MsgCFHeaders struct {
|
||||
StopHash chainhash.Hash
|
||||
FilterType uint8
|
||||
FilterType FilterType
|
||||
HeaderHashes []*chainhash.Hash
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ const (
|
|||
// getcfilter (MsgGetCFilter) message.
|
||||
type MsgCFilter struct {
|
||||
BlockHash chainhash.Hash
|
||||
FilterType uint8
|
||||
FilterType FilterType
|
||||
Data []byte
|
||||
}
|
||||
|
||||
|
@ -99,7 +99,7 @@ func (msg *MsgCFilter) MaxPayloadLength(pver uint32) uint32 {
|
|||
|
||||
// NewMsgCFilter returns a new bitcoin cfilter message that conforms to the
|
||||
// Message interface. See MsgCFilter for details.
|
||||
func NewMsgCFilter(blockHash *chainhash.Hash, filterType uint8,
|
||||
func NewMsgCFilter(blockHash *chainhash.Hash, filterType FilterType,
|
||||
data []byte) *MsgCFilter {
|
||||
return &MsgCFilter{
|
||||
BlockHash: *blockHash,
|
||||
|
|
|
@ -6,28 +6,40 @@ package wire
|
|||
|
||||
import "io"
|
||||
|
||||
// FilterType is used to represent a filter type.
|
||||
type FilterType uint8
|
||||
|
||||
const (
|
||||
// GCSFilterRegular is the regular filter type.
|
||||
GCSFilterRegular FilterType = iota
|
||||
|
||||
// GCSFilterExtended is the extended filter type.
|
||||
GCSFilterExtended
|
||||
)
|
||||
|
||||
// MsgCFTypes is the cftypes message.
|
||||
type MsgCFTypes struct {
|
||||
NumFilters uint8
|
||||
SupportedFilters []uint8
|
||||
SupportedFilters []FilterType
|
||||
}
|
||||
|
||||
// BtcDecode decodes r using the bitcoin protocol encoding into the receiver.
|
||||
// This is part of the Message interface implementation.
|
||||
func (msg *MsgCFTypes) BtcDecode(r io.Reader, pver uint32, _ MessageEncoding) error {
|
||||
// Read the number of filter types supported
|
||||
err := readElement(r, &msg.NumFilters)
|
||||
// Read the number of filter types supported.
|
||||
count, err := ReadVarInt(r, pver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read each filter type.
|
||||
msg.SupportedFilters = make([]uint8, msg.NumFilters)
|
||||
for i := uint8(0); i < msg.NumFilters; i++ {
|
||||
err = readElement(r, &msg.SupportedFilters[i])
|
||||
msg.SupportedFilters = make([]FilterType, count)
|
||||
for i := uint64(0); i < count; i++ {
|
||||
var filterType uint8
|
||||
err = readElement(r, &filterType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg.SupportedFilters[i] = FilterType(filterType)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -36,9 +48,8 @@ func (msg *MsgCFTypes) BtcDecode(r io.Reader, pver uint32, _ MessageEncoding) er
|
|||
// BtcEncode encodes the receiver to w using the bitcoin protocol encoding.
|
||||
// This is part of the Message interface implementation.
|
||||
func (msg *MsgCFTypes) BtcEncode(w io.Writer, pver uint32, _ MessageEncoding) error {
|
||||
// Write length of supported filters slice; don't trust that the caller
|
||||
// has set it correctly.
|
||||
err := writeElement(w, uint8(len(msg.SupportedFilters)))
|
||||
// Write length of supported filters slice. We assume it's deduplicated.
|
||||
err := WriteVarInt(w, pver, uint64(len(msg.SupportedFilters)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -78,15 +89,14 @@ func (msg *MsgCFTypes) Command() string {
|
|||
// MaxPayloadLength returns the maximum length the payload can be for the
|
||||
// receiver. This is part of the Message interface implementation.
|
||||
func (msg *MsgCFTypes) MaxPayloadLength(pver uint32) uint32 {
|
||||
// 1 byte for NumFilters, and 1 byte for up to 255 filter types.
|
||||
return 256
|
||||
// 2 bytes for filter count, and 1 byte for up to 256 filter types.
|
||||
return 258
|
||||
}
|
||||
|
||||
// NewMsgCFTypes returns a new bitcoin cftypes message that conforms to the
|
||||
// Message interface. See MsgCFTypes for details.
|
||||
func NewMsgCFTypes(filterTypes []uint8) *MsgCFTypes {
|
||||
func NewMsgCFTypes(filterTypes []FilterType) *MsgCFTypes {
|
||||
return &MsgCFTypes{
|
||||
NumFilters: uint8(len(filterTypes)),
|
||||
SupportedFilters: filterTypes,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@ import (
|
|||
type MsgGetCFHeaders struct {
|
||||
BlockLocatorHashes []*chainhash.Hash
|
||||
HashStop chainhash.Hash
|
||||
FilterType uint8
|
||||
FilterType FilterType
|
||||
}
|
||||
|
||||
// AddBlockLocatorHash adds a new block locator hash to the message.
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
// getcfilter message. It is used to request a committed filter for a block.
|
||||
type MsgGetCFilter struct {
|
||||
BlockHash chainhash.Hash
|
||||
FilterType uint8
|
||||
FilterType FilterType
|
||||
}
|
||||
|
||||
// BtcDecode decodes r using the bitcoin protocol encoding into the receiver.
|
||||
|
@ -53,7 +53,8 @@ func (msg *MsgGetCFilter) MaxPayloadLength(pver uint32) uint32 {
|
|||
// NewMsgGetCFilter returns a new bitcoin getcfilter message that conforms to
|
||||
// the Message interface using the passed parameters and defaults for the
|
||||
// remaining fields.
|
||||
func NewMsgGetCFilter(blockHash *chainhash.Hash, filterType uint8) *MsgGetCFilter {
|
||||
func NewMsgGetCFilter(blockHash *chainhash.Hash,
|
||||
filterType FilterType) *MsgGetCFilter {
|
||||
return &MsgGetCFilter{
|
||||
BlockHash: *blockHash,
|
||||
FilterType: filterType,
|
||||
|
|
Loading…
Reference in a new issue