Serialize filter with N as a VarInt instead of fixed-size.

This commit is contained in:
Jim Posen 2018-01-17 14:49:07 -08:00 committed by Olaoluwa Osuntokun
parent 9da482119c
commit 884680ddbd
3 changed files with 145 additions and 47 deletions

View file

@ -371,21 +371,28 @@ func BuildExtFilter(block *wire.MsgBlock) (*gcs.Filter, error) {
}
// GetFilterHash returns the double-SHA256 of the filter.
//
// The bytes that get hashed are the filter's N-prefixed serialization, i.e.
// the output of filter.NBytes(). A nil filter yields the zero hash. In the
// error-returning variant, a failure from NBytes() is passed back to the
// caller together with the zero hash.
//
// NOTE(review): this text is a rendered diff without +/- markers; where two
// variants of a line appear back to back, the first is presumably the removed
// one.
func GetFilterHash(filter *gcs.Filter) chainhash.Hash {
func GetFilterHash(filter *gcs.Filter) (chainhash.Hash, error) {
var zero chainhash.Hash
if filter == nil {
return zero
return zero, nil
}
hash1 := chainhash.HashH(filter.NBytes())
return chainhash.HashH(hash1[:])
filterData, err := filter.NBytes()
if err != nil {
return zero, err
}
// DoubleHashH presumably equals the two chained HashH calls it replaces —
// confirm against the chainhash package.
return chainhash.DoubleHashH(filterData), nil
}
// MakeHeaderForFilter makes a filter chain header for a filter, given the
// filter and the previous filter chain header.
func MakeHeaderForFilter(filter *gcs.Filter, prevHeader chainhash.Hash) chainhash.Hash {
func MakeHeaderForFilter(filter *gcs.Filter, prevHeader chainhash.Hash) (chainhash.Hash, error) {
filterTip := make([]byte, 2*chainhash.HashSize)
filterHash := GetFilterHash(filter)
filterHash, err := GetFilterHash(filter)
if err != nil {
return chainhash.Hash{}, err
}
// In the buffer we created above we'll compute hash || prevHash as an
// intermediate value.
@ -394,6 +401,5 @@ func MakeHeaderForFilter(filter *gcs.Filter, prevHeader chainhash.Hash) chainhas
// The final filter hash is the double-sha256 of the hash computed
// above.
hash1 := chainhash.HashH(filterTip)
return chainhash.HashH(hash1[:])
return chainhash.DoubleHashH(filterTip), nil
}

View file

@ -6,14 +6,14 @@
package gcs
import (
"encoding/binary"
"bytes"
"fmt"
"io"
"math"
"sort"
"github.com/aead/siphash"
"github.com/kkdai/bstream"
"github.com/roasbeef/btcd/wire"
)
// Inspired by https://github.com/rasky/gcs
@ -34,6 +34,10 @@ const (
// KeySize is the size of the byte array required for key material for
// the SipHash keyed hash function.
KeySize = 16
// varIntProtoVer is the protocol version to use for serializing N as a
// VarInt.
varIntProtoVer uint32 = 0
)
// fastReduction calculates a mapping that's more or less equivalent to: x mod
@ -83,7 +87,6 @@ func fastReduction(v, nHi, nLo uint64) uint64 {
type Filter struct {
// n is the number of items in the filter: len(data) when built by
// BuildGCSFilter, or the N handed to FromBytes.
n uint32
// p is the P parameter supplied at construction time.
p uint8
// NOTE(review): the "+87,6" hunk header says one line is dropped from this
// struct; presumably it is modulusP, since the other hunks replace every
// use of it with a shift by P.
modulusP uint64
// modulusNP is n << P (that is, n * 2^P), set by BuildGCSFilter and
// FromBytes.
modulusNP uint64
// filterData is the compressed filter body, without N or P.
filterData []byte
}
@ -98,7 +101,7 @@ func BuildGCSFilter(P uint8, key [KeySize]byte, data [][]byte) (*Filter, error)
if len(data) == 0 {
return nil, ErrNoData
}
if len(data) > math.MaxInt32 {
if uint64(len(data)) >= (1 << 32) {
return nil, ErrNTooBig
}
if P > 32 {
@ -110,8 +113,7 @@ func BuildGCSFilter(P uint8, key [KeySize]byte, data [][]byte) (*Filter, error)
n: uint32(len(data)),
p: P,
}
f.modulusP = uint64(1 << f.p)
f.modulusNP = uint64(f.n) * f.modulusP
f.modulusNP = uint64(f.n) << P
// Build the filter.
values := make(uint64Slice, 0, len(data))
@ -141,7 +143,7 @@ func BuildGCSFilter(P uint8, key [KeySize]byte, data [][]byte) (*Filter, error)
for _, v := range values {
// Calculate the difference between this value and the last,
// modulo 2^P.
remainder = (v - lastValue) & (f.modulusP - 1)
remainder = (v - lastValue) & ((uint64(1) << P) - 1)
// Calculate the difference between this value and the last,
// divided by 2^P.
@ -181,8 +183,7 @@ func FromBytes(N uint32, P uint8, d []byte) (*Filter, error) {
n: N,
p: P,
}
f.modulusP = uint64(1 << f.p)
f.modulusNP = uint64(f.n) * f.modulusP
f.modulusNP = uint64(f.n) << P
// Copy the filter.
f.filterData = make([]byte, len(d))
@ -194,7 +195,15 @@ func FromBytes(N uint32, P uint8, d []byte) (*Filter, error) {
// FromNBytes deserializes a GCS filter from a known P, and serialized N and
// filter as returned by NBytes().
//
// In the VarInt variant, d must start with N encoded as a VarInt (decoded with
// wire.ReadVarInt); every byte after that prefix is treated as the filter
// body. A read failure is returned unchanged, and ErrNTooBig is returned when
// the decoded N does not fit in 32 bits. The one-line variant instead reads N
// from the first four bytes of d as a big-endian uint32.
//
// NOTE(review): this text is a rendered diff without +/- markers; where two
// variants of a line appear back to back, the first is presumably the removed
// one.
func FromNBytes(P uint8, d []byte) (*Filter, error) {
return FromBytes(binary.BigEndian.Uint32(d[:4]), P, d[4:])
buffer := bytes.NewBuffer(d)
N, err := wire.ReadVarInt(buffer, varIntProtoVer)
if err != nil {
return nil, err
}
// N is decoded into a type wider than 32 bits; reject anything the uint32
// conversion below would truncate.
if N >= (1 << 32) {
return nil, ErrNTooBig
}
// buffer.Bytes() is the unread remainder of d, i.e. the filter body.
return FromBytes(uint32(N), P, buffer.Bytes())
}
// FromPBytes deserializes a GCS filter from a known N, and serialized P and
@ -206,43 +215,82 @@ func FromPBytes(N uint32, d []byte) (*Filter, error) {
// FromNPBytes deserializes a GCS filter from a serialized N, P, and filter as
// returned by NPBytes().
//
// In the VarInt variant, the expected layout of d is: N as a VarInt (decoded
// with wire.ReadVarInt), then one byte holding P, then the filter body. Read
// failures for N or P are returned unchanged, and ErrNTooBig is returned when
// the decoded N does not fit in 32 bits. The one-line variant instead expects
// a four-byte big-endian N, then P at d[4], then the body.
//
// NOTE(review): this text is a rendered diff without +/- markers; where two
// variants of a line appear back to back, the first is presumably the removed
// one.
func FromNPBytes(d []byte) (*Filter, error) {
return FromBytes(binary.BigEndian.Uint32(d[:4]), d[4], d[5:])
buffer := bytes.NewBuffer(d)
N, err := wire.ReadVarInt(buffer, varIntProtoVer)
if err != nil {
return nil, err
}
// N is decoded into a type wider than 32 bits; reject anything the uint32
// conversion below would truncate.
if N >= (1 << 32) {
return nil, ErrNTooBig
}
// P is the single byte that follows the VarInt; ReadByte fails if d ends
// right after N.
P, err := buffer.ReadByte()
if err != nil {
return nil, err
}
// buffer.Bytes() is whatever remains after N and P, i.e. the filter body.
return FromBytes(uint32(N), P, buffer.Bytes())
}
// Bytes returns the serialized format of the GCS filter, which does not
// include N or P (returned by separate methods) or the key used by SipHash.
//
// The result is a fresh copy of the internal buffer, so modifying it does not
// affect the filter. In the error-returning variant the error is always nil.
//
// NOTE(review): this text is a rendered diff without +/- markers; where two
// variants of a line appear back to back, the first is presumably the removed
// one.
func (f *Filter) Bytes() []byte {
func (f *Filter) Bytes() ([]byte, error) {
filterData := make([]byte, len(f.filterData))
copy(filterData, f.filterData)
return filterData
return filterData, nil
}
// NBytes returns the serialized format of the GCS filter with N, which does
// not include P (returned by a separate method) or the key used by SipHash.
//
// The first variant below writes N as a four-byte big-endian integer followed
// by the filter body. The second variant writes N as a VarInt (via
// wire.WriteVarInt) followed by the filter body, and returns any write error.
//
// NOTE(review): this text is a rendered diff without +/- markers; the first
// function body is presumably the removed one and the second its replacement.
func (f *Filter) NBytes() []byte {
filterData := make([]byte, len(f.filterData)+4)
binary.BigEndian.PutUint32(filterData[:4], f.n)
copy(filterData[4:], f.filterData)
return filterData
func (f *Filter) NBytes() ([]byte, error) {
var buffer bytes.Buffer
// Reserve room for the VarInt prefix plus the filter body up front.
buffer.Grow(wire.VarIntSerializeSize(uint64(f.n)) + len(f.filterData))
err := wire.WriteVarInt(&buffer, varIntProtoVer, uint64(f.n))
if err != nil {
return nil, err
}
_, err = buffer.Write(f.filterData)
if err != nil {
return nil, err
}
return buffer.Bytes(), nil
}
// PBytes returns the serialized format of the GCS filter with P, which does
// not include N (returned by a separate method) or the key used by SipHash.
//
// Layout: one byte holding P, followed by a copy of the filter body. In the
// error-returning variant the error is always nil.
//
// NOTE(review): this text is a rendered diff without +/- markers; where two
// variants of a line appear back to back, the first is presumably the removed
// one.
func (f *Filter) PBytes() []byte {
func (f *Filter) PBytes() ([]byte, error) {
// One extra leading byte for P.
filterData := make([]byte, len(f.filterData)+1)
filterData[0] = f.p
copy(filterData[1:], f.filterData)
return filterData
return filterData, nil
}
// NPBytes returns the serialized format of the GCS filter with N and P, which
// does not include the key used by SipHash.
//
// The first variant below writes a four-byte big-endian N, then one byte of P,
// then the filter body. The second variant writes N as a VarInt (via
// wire.WriteVarInt), then one byte of P, then the filter body, and returns any
// write error.
//
// NOTE(review): this text is a rendered diff without +/- markers; the first
// function body is presumably the removed one and the second its replacement.
func (f *Filter) NPBytes() []byte {
filterData := make([]byte, len(f.filterData)+5)
binary.BigEndian.PutUint32(filterData[:4], f.n)
filterData[4] = f.p
copy(filterData[5:], f.filterData)
return filterData
func (f *Filter) NPBytes() ([]byte, error) {
var buffer bytes.Buffer
// Reserve room for the VarInt prefix, the P byte, and the filter body.
buffer.Grow(wire.VarIntSerializeSize(uint64(f.n)) + 1 + len(f.filterData))
err := wire.WriteVarInt(&buffer, varIntProtoVer, uint64(f.n))
if err != nil {
return nil, err
}
err = buffer.WriteByte(f.p)
if err != nil {
return nil, err
}
_, err = buffer.Write(f.filterData)
if err != nil {
return nil, err
}
return buffer.Bytes(), nil
}
// P returns the filter's collision probability as a negative power of 2 (that
@ -261,7 +309,11 @@ func (f *Filter) N() uint32 {
func (f *Filter) Match(key [KeySize]byte, data []byte) (bool, error) {
// Create a filter bitstream.
filterData := f.Bytes()
filterData, err := f.Bytes()
if err != nil {
return false, err
}
b := bstream.NewBStreamReader(filterData)
// We take the high and low bits of modulusNP for the multiplication
@ -310,7 +362,11 @@ func (f *Filter) MatchAny(key [KeySize]byte, data [][]byte) (bool, error) {
}
// Create a filter bitstream.
filterData := f.Bytes()
filterData, err := f.Bytes()
if err != nil {
return false, err
}
b := bstream.NewBStreamReader(filterData)
// Create an uncompressed filter of the search values.
@ -372,7 +428,7 @@ func (f *Filter) MatchAny(key [KeySize]byte, data [][]byte) (bool, error) {
// readFullUint64 reads a value represented by the sum of a unary multiple of
// the filter's P modulus (`2**P`) and a big-endian P-bit remainder.
func (f *Filter) readFullUint64(b *bstream.BStream) (uint64, error) {
var v uint64
var quotient uint64
// Count the 1s until we reach a 0.
c, err := b.ReadBit()
@ -380,7 +436,7 @@ func (f *Filter) readFullUint64(b *bstream.BStream) (uint64, error) {
return 0, err
}
for c {
v++
quotient++
c, err = b.ReadBit()
if err != nil {
return 0, err
@ -394,6 +450,6 @@ func (f *Filter) readFullUint64(b *bstream.BStream) (uint64, error) {
}
// Add the multiple and the remainder.
v = v*f.modulusP + remainder
v := (quotient << f.p) + remainder
return v, nil
}

View file

@ -87,19 +87,35 @@ func TestGCSFilterBuild(t *testing.T) {
// TestGCSFilterCopy deserializes and serializes a filter to create a copy.
func TestGCSFilterCopy(t *testing.T) {
filter2, err = gcs.FromBytes(filter.N(), P, filter.Bytes())
serialized2, err := filter.Bytes()
if err != nil {
t.Fatalf("Filter Bytes() failed: %v", err)
}
filter2, err = gcs.FromBytes(filter.N(), P, serialized2)
if err != nil {
t.Fatalf("Filter copy failed: %s", err.Error())
}
filter3, err = gcs.FromNBytes(filter.P(), filter.NBytes())
serialized3, err := filter.NBytes()
if err != nil {
t.Fatalf("Filter NBytes() failed: %v", err)
}
filter3, err = gcs.FromNBytes(filter.P(), serialized3)
if err != nil {
t.Fatalf("Filter copy failed: %s", err.Error())
}
filter4, err = gcs.FromPBytes(filter.N(), filter.PBytes())
serialized4, err := filter.PBytes()
if err != nil {
t.Fatalf("Filter PBytes() failed: %v", err)
}
filter4, err = gcs.FromPBytes(filter.N(), serialized4)
if err != nil {
t.Fatalf("Filter copy failed: %s", err.Error())
}
filter5, err = gcs.FromNPBytes(filter.NPBytes())
serialized5, err := filter.NPBytes()
if err != nil {
t.Fatalf("Filter NPBytes() failed: %v", err)
}
filter5, err = gcs.FromNPBytes(serialized5)
if err != nil {
t.Fatalf("Filter copy failed: %s", err.Error())
}
@ -138,16 +154,36 @@ func TestGCSFilterMetadata(t *testing.T) {
if filter.N() != filter5.N() {
t.Fatal("N doesn't match between copied filters")
}
if !bytes.Equal(filter.Bytes(), filter2.Bytes()) {
serialized, err := filter.Bytes()
if err != nil {
t.Fatalf("Filter Bytes() failed: %v", err)
}
serialized2, err := filter2.Bytes()
if err != nil {
t.Fatalf("Filter Bytes() failed: %v", err)
}
if !bytes.Equal(serialized, serialized2) {
t.Fatal("Bytes don't match between copied filters")
}
if !bytes.Equal(filter.Bytes(), filter3.Bytes()) {
serialized3, err := filter3.Bytes()
if err != nil {
t.Fatalf("Filter Bytes() failed: %v", err)
}
if !bytes.Equal(serialized, serialized3) {
t.Fatal("Bytes don't match between copied filters")
}
if !bytes.Equal(filter.Bytes(), filter4.Bytes()) {
// Compare against filter4 (the copy rebuilt via FromPBytes), not filter3,
// which the preceding check already covers; otherwise filter4's bytes are
// never verified.
serialized4, err := filter4.Bytes()
if err != nil {
	t.Fatalf("Filter Bytes() failed: %v", err)
}
if !bytes.Equal(serialized, serialized4) {
	t.Fatal("Bytes don't match between copied filters")
}
if !bytes.Equal(filter.Bytes(), filter5.Bytes()) {
serialized5, err := filter5.Bytes()
if err != nil {
t.Fatalf("Filter Bytes() failed: %v", err)
}
if !bytes.Equal(serialized, serialized5) {
t.Fatal("Bytes don't match between copied filters")
}
}