From b8c3be740f72389285b6f489228e8f61f7ada177 Mon Sep 17 00:00:00 2001
From: pedro martelletto <pedro@ambientworks.net>
Date: Thu, 2 Feb 2017 10:42:44 +0000
Subject: [PATCH] Add CFilterHeader p2p counterparts

---
 peer/peer.go                |  9 +++++
 peer/peer_test.go           |  7 ++++
 server.go                   | 65 +++++++++++++++++++++++++------------
 wire/message.go             | 58 +++++++++++++++++++--------------
 wire/msgcfilterheader.go    | 63 +++++++++++++++++++++++++++++++++++
 wire/msggetcfilter.go       |  2 +-
 wire/msggetcfilterheader.go | 59 +++++++++++++++++++++++++++++++++
 7 files changed, 216 insertions(+), 47 deletions(-)
 create mode 100644 wire/msgcfilterheader.go
 create mode 100644 wire/msggetcfilterheader.go

diff --git a/peer/peer.go b/peer/peer.go
index feefed6c..488f3bcc 100644
--- a/peer/peer.go
+++ b/peer/peer.go
@@ -148,6 +148,10 @@ type MessageListeners struct {
 	// message.
 	OnGetCFilter func(p *Peer, msg *wire.MsgGetCFilter)
 
+	// OnGetCFilterHeader is invoked when a peer receives a
+	// getcfilterheader bitcoin message.
+	OnGetCFilterHeader func(p *Peer, msg *wire.MsgGetCFilterHeader)
+
 	// OnFeeFilter is invoked when a peer receives a feefilter bitcoin message.
 	OnFeeFilter func(p *Peer, msg *wire.MsgFeeFilter)
 
@@ -1588,6 +1592,11 @@ out:
 				p.cfg.Listeners.OnGetCFilter(p, msg)
 			}
 
+		case *wire.MsgGetCFilterHeader:
+			if p.cfg.Listeners.OnGetCFilterHeader != nil {
+				p.cfg.Listeners.OnGetCFilterHeader(p, msg)
+			}
+
 		case *wire.MsgFeeFilter:
 			if p.cfg.Listeners.OnFeeFilter != nil {
 				p.cfg.Listeners.OnFeeFilter(p, msg)
diff --git a/peer/peer_test.go b/peer/peer_test.go
index bfe2987b..c913f79e 100644
--- a/peer/peer_test.go
+++ b/peer/peer_test.go
@@ -402,6 +402,9 @@ func TestPeerListeners(t *testing.T) {
 			OnGetCFilter: func(p *peer.Peer, msg *wire.MsgGetCFilter) {
 				ok <- msg
 			},
+			OnGetCFilterHeader: func(p *peer.Peer, msg *wire.MsgGetCFilterHeader) {
+				ok <- msg
+			},
 			OnFeeFilter: func(p *peer.Peer, msg *wire.MsgFeeFilter) {
 				ok <- msg
 			},
@@ -529,6 +532,10 @@ func TestPeerListeners(t *testing.T) {
 			"OnGetCFilter",
 			wire.NewMsgGetCFilter(&chainhash.Hash{}),
 		},
+		{
+			"OnGetCFilterHeader",
+			wire.NewMsgGetCFilterHeader(&chainhash.Hash{}),
+		},
 		{
 			"OnFeeFilter",
 			wire.NewMsgFeeFilter(15000),
diff --git a/server.go b/server.go
index 3ca3882a..91b614cb 100644
--- a/server.go
+++ b/server.go
@@ -750,16 +750,38 @@ func (sp *serverPeer) OnGetCFilter(_ *peer.Peer, msg *wire.MsgGetCFilter) {
 		msg.Extended)
 
 	if len(filterBytes) > 0 {
-		peerLog.Infof("Obtained CB filter for %v", msg.BlockHash)
+		peerLog.Infof("Obtained CF for %v", msg.BlockHash)
 	} else {
-		peerLog.Infof("Could not obtain CB filter for %v: %v",
-			msg.BlockHash, err)
+		peerLog.Infof("Could not obtain CF for %v: %v", msg.BlockHash,
+			err)
 	}
 
 	filterMsg := wire.NewMsgCFilter(filterBytes)
 	sp.QueueMessage(filterMsg, nil)
 }
 
+// OnGetCFilterHeader is invoked when a peer receives a getcfilterheader bitcoin
+// message.
+func (sp *serverPeer) OnGetCFilterHeader(_ *peer.Peer, msg *wire.MsgGetCFilterHeader) {
+	// Ignore getcfilterheader requests if not in sync.
+	if !sp.server.blockManager.IsCurrent() {
+		return
+	}
+
+	headerBytes, err := sp.server.cfIndex.FilterHeaderByBlockHash(
+		&msg.BlockHash, msg.Extended)
+
+	if len(headerBytes) > 0 {
+		peerLog.Infof("Obtained CF header for %v", msg.BlockHash)
+	} else {
+		peerLog.Infof("Could not obtain CF header for %v: %v",
+			msg.BlockHash, err)
+	}
+
+	headerMsg := wire.NewMsgCFilterHeader(headerBytes)
+	sp.QueueMessage(headerMsg, nil)
+}
+
 // enforceNodeBloomFlag disconnects the peer if the server is not configured to
 // allow bloom filters.  Additionally, if the peer has negotiated to a protocol
 // version  that is high enough to observe the bloom filter service support bit,
@@ -1598,24 +1620,25 @@ func disconnectPeer(peerList map[int32]*serverPeer, compareFunc func(*serverPeer
 func newPeerConfig(sp *serverPeer) *peer.Config {
 	return &peer.Config{
 		Listeners: peer.MessageListeners{
-			OnVersion:     sp.OnVersion,
-			OnMemPool:     sp.OnMemPool,
-			OnTx:          sp.OnTx,
-			OnBlock:       sp.OnBlock,
-			OnInv:         sp.OnInv,
-			OnHeaders:     sp.OnHeaders,
-			OnGetData:     sp.OnGetData,
-			OnGetBlocks:   sp.OnGetBlocks,
-			OnGetHeaders:  sp.OnGetHeaders,
-			OnGetCFilter:  sp.OnGetCFilter,
-			OnFeeFilter:   sp.OnFeeFilter,
-			OnFilterAdd:   sp.OnFilterAdd,
-			OnFilterClear: sp.OnFilterClear,
-			OnFilterLoad:  sp.OnFilterLoad,
-			OnGetAddr:     sp.OnGetAddr,
-			OnAddr:        sp.OnAddr,
-			OnRead:        sp.OnRead,
-			OnWrite:       sp.OnWrite,
+			OnVersion:           sp.OnVersion,
+			OnMemPool:           sp.OnMemPool,
+			OnTx:                sp.OnTx,
+			OnBlock:             sp.OnBlock,
+			OnInv:               sp.OnInv,
+			OnHeaders:           sp.OnHeaders,
+			OnGetData:           sp.OnGetData,
+			OnGetBlocks:         sp.OnGetBlocks,
+			OnGetHeaders:        sp.OnGetHeaders,
+			OnGetCFilter:        sp.OnGetCFilter,
+			OnGetCFilterHeader:  sp.OnGetCFilterHeader,
+			OnFeeFilter:         sp.OnFeeFilter,
+			OnFilterAdd:         sp.OnFilterAdd,
+			OnFilterClear:       sp.OnFilterClear,
+			OnFilterLoad:        sp.OnFilterLoad,
+			OnGetAddr:           sp.OnGetAddr,
+			OnAddr:              sp.OnAddr,
+			OnRead:              sp.OnRead,
+			OnWrite:             sp.OnWrite,
 
 			// Note: The reference client currently bans peers that send alerts
 			// not signed with its key.  We could verify against their key, but
diff --git a/wire/message.go b/wire/message.go
index 168d881d..ade36386 100644
--- a/wire/message.go
+++ b/wire/message.go
@@ -28,31 +28,33 @@ const MaxMessagePayload = (1024 * 1024 * 32) // 32MB
 
 // Commands used in bitcoin message headers which describe the type of message.
 const (
-	CmdVersion     = "version"
-	CmdVerAck      = "verack"
-	CmdGetAddr     = "getaddr"
-	CmdAddr        = "addr"
-	CmdGetBlocks   = "getblocks"
-	CmdInv         = "inv"
-	CmdGetData     = "getdata"
-	CmdNotFound    = "notfound"
-	CmdBlock       = "block"
-	CmdTx          = "tx"
-	CmdGetHeaders  = "getheaders"
-	CmdHeaders     = "headers"
-	CmdPing        = "ping"
-	CmdPong        = "pong"
-	CmdAlert       = "alert"
-	CmdMemPool     = "mempool"
-	CmdFilterAdd   = "filteradd"
-	CmdFilterClear = "filterclear"
-	CmdFilterLoad  = "filterload"
-	CmdMerkleBlock = "merkleblock"
-	CmdReject      = "reject"
-	CmdSendHeaders = "sendheaders"
-	CmdFeeFilter   = "feefilter"
-	CmdGetCFilter  = "getcfilter"
-	CmdCFilter     = "cfilter"
+	CmdVersion           = "version"
+	CmdVerAck            = "verack"
+	CmdGetAddr           = "getaddr"
+	CmdAddr              = "addr"
+	CmdGetBlocks         = "getblocks"
+	CmdInv               = "inv"
+	CmdGetData           = "getdata"
+	CmdNotFound          = "notfound"
+	CmdBlock             = "block"
+	CmdTx                = "tx"
+	CmdGetHeaders        = "getheaders"
+	CmdHeaders           = "headers"
+	CmdPing              = "ping"
+	CmdPong              = "pong"
+	CmdAlert             = "alert"
+	CmdMemPool           = "mempool"
+	CmdFilterAdd         = "filteradd"
+	CmdFilterClear       = "filterclear"
+	CmdFilterLoad        = "filterload"
+	CmdMerkleBlock       = "merkleblock"
+	CmdReject            = "reject"
+	CmdSendHeaders       = "sendheaders"
+	CmdFeeFilter         = "feefilter"
+	CmdGetCFilter        = "getcfilter"
+	CmdGetCFilterHeader  = "getcfilterheader"
+	CmdCFilter           = "cfilter"
+	CmdCFilterHeader     = "cfilterheader"
 )
 
 // MessageEncoding represents the wire message encoding format to be used.
@@ -161,9 +163,15 @@ func makeEmptyMessage(command string) (Message, error) {
 	case CmdGetCFilter:
 		msg = &MsgGetCFilter{}
 
+	case CmdGetCFilterHeader:
+		msg = &MsgGetCFilterHeader{}
+
 	case CmdCFilter:
 		msg = &MsgCFilter{}
 
+	case CmdCFilterHeader:
+		msg = &MsgCFilterHeader{}
+
 	default:
 		return nil, fmt.Errorf("unhandled command [%s]", command)
 	}
diff --git a/wire/msgcfilterheader.go b/wire/msgcfilterheader.go
new file mode 100644
index 00000000..ce8d102a
--- /dev/null
+++ b/wire/msgcfilterheader.go
@@ -0,0 +1,63 @@
+// Copyright (c) 2017 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wire
+
+import (
+	"fmt"
+	"github.com/btcsuite/fastsha256"
+	"io"
+)
+
+const (
+	// MaxCFilterHeaderDataSize is the maximum byte size of a committed
+	// filter header.
+	MaxCFilterHeaderDataSize = fastsha256.Size
+)
+type MsgCFilterHeader struct {
+	Data []byte
+}
+
+// BtcDecode decodes r using the bitcoin protocol encoding into the receiver.
+// This is part of the Message interface implementation.
+func (msg *MsgCFilterHeader) BtcDecode(r io.Reader, pver uint32) error {
+	var err error
+	msg.Data, err = ReadVarBytes(r, pver, MaxCFilterHeaderDataSize,
+		"cf header data")
+	return err
+}
+
+// BtcEncode encodes the receiver to w using the bitcoin protocol encoding.
+// This is part of the Message interface implementation.
+func (msg *MsgCFilterHeader) BtcEncode(w io.Writer, pver uint32) error {
+	size := len(msg.Data)
+	if size > MaxCFilterHeaderDataSize {
+		str := fmt.Sprintf("cf header size too large for message " +
+			"[size %v, max %v]", size, MaxCFilterHeaderDataSize)
+		return messageError("MsgCFilterHeader.BtcEncode", str)
+	}
+
+	return WriteVarBytes(w, pver, msg.Data)
+}
+
+// Command returns the protocol command string for the message.  This is part
+// of the Message interface implementation.
+func (msg *MsgCFilterHeader) Command() string {
+	return CmdCFilterHeader
+}
+
+// MaxPayloadLength returns the maximum length the payload can be for the
+// receiver.  This is part of the Message interface implementation.
+func (msg *MsgCFilterHeader) MaxPayloadLength(pver uint32) uint32 {
+	return uint32(VarIntSerializeSize(MaxCFilterHeaderDataSize)) +
+		MaxCFilterHeaderDataSize
+}
+
+// NewMsgFilterAdd returns a new bitcoin cfilterheader message that conforms to
+// the Message interface. See MsgCFilterHeader for details.
+func NewMsgCFilterHeader(data []byte) *MsgCFilterHeader {
+	return &MsgCFilterHeader{
+		Data: data,
+	}
+}
diff --git a/wire/msggetcfilter.go b/wire/msggetcfilter.go
index b8aa18e7..2b775103 100644
--- a/wire/msggetcfilter.go
+++ b/wire/msggetcfilter.go
@@ -47,7 +47,7 @@ func (msg *MsgGetCFilter) MaxPayloadLength(pver uint32) uint32 {
 	return 4 + chainhash.HashSize + 1
 }
 
-// NewMsgGetCFilter returns a new bitcoin getblocks message that conforms to
+// NewMsgGetCFilter returns a new bitcoin getcfilter message that conforms to
 // the Message interface using the passed parameters and defaults for the
 // remaining fields.
 func NewMsgGetCFilter(blockHash *chainhash.Hash, extended bool) *MsgGetCFilter {
diff --git a/wire/msggetcfilterheader.go b/wire/msggetcfilterheader.go
new file mode 100644
index 00000000..d23e1d18
--- /dev/null
+++ b/wire/msggetcfilterheader.go
@@ -0,0 +1,59 @@
+// Copyright (c) 2017 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wire
+
+import (
+	"io"
+
+	"github.com/btcsuite/btcd/chaincfg/chainhash"
+)
+
+type MsgGetCFilterHeader struct {
+	ProtocolVersion    uint32
+	BlockHash          chainhash.Hash
+	Extended           bool
+}
+
+func (msg *MsgGetCFilterHeader) BtcDecode(r io.Reader, pver uint32) error {
+	err := readElement(r, &msg.BlockHash)
+	if err != nil {
+		return err
+	}
+	return readElement(r, &msg.Extended)
+}
+
+// BtcEncode encodes the receiver to w using the bitcoin protocol encoding.
+// This is part of the Message interface implementation.
+func (msg *MsgGetCFilterHeader) BtcEncode(w io.Writer, pver uint32) error {
+	err := writeElement(w, &msg.BlockHash)
+	if err != nil {
+		return err
+	}
+	return writeElement(w, &msg.Extended)
+}
+
+// Command returns the protocol command string for the message.  This is part
+// of the Message interface implementation.
+func (msg *MsgGetCFilterHeader) Command() string {
+	return CmdGetCFilterHeader
+}
+
+// MaxPayloadLength returns the maximum length the payload can be for the
+// receiver.  This is part of the Message interface implementation.
+func (msg *MsgGetCFilterHeader) MaxPayloadLength(pver uint32) uint32 {
+	// Protocol version 4 bytes + block hash + Extended flag.
+	return 4 + chainhash.HashSize + 1
+}
+
+// NewMsgGetCFilterHeader returns a new bitcoin getcfilterheader message that
+// conforms to the Message interface using the passed parameters and defaults
+// for the remaining fields.
+func NewMsgGetCFilterHeader(blockHash *chainhash.Hash, extended bool) *MsgGetCFilterHeader {
+	return &MsgGetCFilterHeader{
+		ProtocolVersion:     ProtocolVersion,
+		BlockHash:          *blockHash,
+		Extended:            extended,
+	}
+}