Serialize filter with N as a VarInt instead of fixed-size.

This commit is contained in:
Jim Posen 2018-01-17 14:49:07 -08:00 committed by Olaoluwa Osuntokun
parent 9da482119c
commit 884680ddbd
3 changed files with 145 additions and 47 deletions

View file

@ -371,21 +371,28 @@ func BuildExtFilter(block *wire.MsgBlock) (*gcs.Filter, error) {
} }
// GetFilterHash returns the double-SHA256 of the filter. // GetFilterHash returns the double-SHA256 of the filter.
func GetFilterHash(filter *gcs.Filter) chainhash.Hash { func GetFilterHash(filter *gcs.Filter) (chainhash.Hash, error) {
var zero chainhash.Hash var zero chainhash.Hash
if filter == nil { if filter == nil {
return zero return zero, nil
} }
hash1 := chainhash.HashH(filter.NBytes()) filterData, err := filter.NBytes()
return chainhash.HashH(hash1[:]) if err != nil {
return zero, err
}
return chainhash.DoubleHashH(filterData), nil
} }
// MakeHeaderForFilter makes a filter chain header for a filter, given the // MakeHeaderForFilter makes a filter chain header for a filter, given the
// filter and the previous filter chain header. // filter and the previous filter chain header.
func MakeHeaderForFilter(filter *gcs.Filter, prevHeader chainhash.Hash) chainhash.Hash { func MakeHeaderForFilter(filter *gcs.Filter, prevHeader chainhash.Hash) (chainhash.Hash, error) {
filterTip := make([]byte, 2*chainhash.HashSize) filterTip := make([]byte, 2*chainhash.HashSize)
filterHash := GetFilterHash(filter) filterHash, err := GetFilterHash(filter)
if err != nil {
return chainhash.Hash{}, err
}
// In the buffer we created above we'll compute hash || prevHash as an // In the buffer we created above we'll compute hash || prevHash as an
// intermediate value. // intermediate value.
@ -394,6 +401,5 @@ func MakeHeaderForFilter(filter *gcs.Filter, prevHeader chainhash.Hash) chainhas
// The final filter hash is the double-sha256 of the hash computed // The final filter hash is the double-sha256 of the hash computed
// above. // above.
hash1 := chainhash.HashH(filterTip) return chainhash.DoubleHashH(filterTip), nil
return chainhash.HashH(hash1[:])
} }

View file

@ -6,14 +6,14 @@
package gcs package gcs
import ( import (
"encoding/binary" "bytes"
"fmt" "fmt"
"io" "io"
"math"
"sort" "sort"
"github.com/aead/siphash" "github.com/aead/siphash"
"github.com/kkdai/bstream" "github.com/kkdai/bstream"
"github.com/roasbeef/btcd/wire"
) )
// Inspired by https://github.com/rasky/gcs // Inspired by https://github.com/rasky/gcs
@ -34,6 +34,10 @@ const (
// KeySize is the size of the byte array required for key material for // KeySize is the size of the byte array required for key material for
// the SipHash keyed hash function. // the SipHash keyed hash function.
KeySize = 16 KeySize = 16
// varIntProtoVer is the protocol version to use for serializing N as a
// VarInt.
varIntProtoVer uint32 = 0
) )
// fastReduction calculates a mapping that's more ore less equivalent to: x mod // fastReduction calculates a mapping that's more ore less equivalent to: x mod
@ -83,7 +87,6 @@ func fastReduction(v, nHi, nLo uint64) uint64 {
type Filter struct { type Filter struct {
n uint32 n uint32
p uint8 p uint8
modulusP uint64
modulusNP uint64 modulusNP uint64
filterData []byte filterData []byte
} }
@ -98,7 +101,7 @@ func BuildGCSFilter(P uint8, key [KeySize]byte, data [][]byte) (*Filter, error)
if len(data) == 0 { if len(data) == 0 {
return nil, ErrNoData return nil, ErrNoData
} }
if len(data) > math.MaxInt32 { if uint64(len(data)) >= (1 << 32) {
return nil, ErrNTooBig return nil, ErrNTooBig
} }
if P > 32 { if P > 32 {
@ -110,8 +113,7 @@ func BuildGCSFilter(P uint8, key [KeySize]byte, data [][]byte) (*Filter, error)
n: uint32(len(data)), n: uint32(len(data)),
p: P, p: P,
} }
f.modulusP = uint64(1 << f.p) f.modulusNP = uint64(f.n) << P
f.modulusNP = uint64(f.n) * f.modulusP
// Build the filter. // Build the filter.
values := make(uint64Slice, 0, len(data)) values := make(uint64Slice, 0, len(data))
@ -141,7 +143,7 @@ func BuildGCSFilter(P uint8, key [KeySize]byte, data [][]byte) (*Filter, error)
for _, v := range values { for _, v := range values {
// Calculate the difference between this value and the last, // Calculate the difference between this value and the last,
// modulo P. // modulo P.
remainder = (v - lastValue) & (f.modulusP - 1) remainder = (v - lastValue) & ((uint64(1) << P) - 1)
// Calculate the difference between this value and the last, // Calculate the difference between this value and the last,
// divided by P. // divided by P.
@ -181,8 +183,7 @@ func FromBytes(N uint32, P uint8, d []byte) (*Filter, error) {
n: N, n: N,
p: P, p: P,
} }
f.modulusP = uint64(1 << f.p) f.modulusNP = uint64(f.n) << P
f.modulusNP = uint64(f.n) * f.modulusP
// Copy the filter. // Copy the filter.
f.filterData = make([]byte, len(d)) f.filterData = make([]byte, len(d))
@ -194,7 +195,15 @@ func FromBytes(N uint32, P uint8, d []byte) (*Filter, error) {
// FromNBytes deserializes a GCS filter from a known P, and serialized N and // FromNBytes deserializes a GCS filter from a known P, and serialized N and
// filter as returned by NBytes(). // filter as returned by NBytes().
func FromNBytes(P uint8, d []byte) (*Filter, error) { func FromNBytes(P uint8, d []byte) (*Filter, error) {
return FromBytes(binary.BigEndian.Uint32(d[:4]), P, d[4:]) buffer := bytes.NewBuffer(d)
N, err := wire.ReadVarInt(buffer, varIntProtoVer)
if err != nil {
return nil, err
}
if N >= (1 << 32) {
return nil, ErrNTooBig
}
return FromBytes(uint32(N), P, buffer.Bytes())
} }
// FromPBytes deserializes a GCS filter from a known N, and serialized P and // FromPBytes deserializes a GCS filter from a known N, and serialized P and
@ -206,43 +215,82 @@ func FromPBytes(N uint32, d []byte) (*Filter, error) {
// FromNPBytes deserializes a GCS filter from a serialized N, P, and filter as // FromNPBytes deserializes a GCS filter from a serialized N, P, and filter as
// returned by NPBytes(). // returned by NPBytes().
func FromNPBytes(d []byte) (*Filter, error) { func FromNPBytes(d []byte) (*Filter, error) {
return FromBytes(binary.BigEndian.Uint32(d[:4]), d[4], d[5:]) buffer := bytes.NewBuffer(d)
N, err := wire.ReadVarInt(buffer, varIntProtoVer)
if err != nil {
return nil, err
}
if N >= (1 << 32) {
return nil, ErrNTooBig
}
P, err := buffer.ReadByte()
if err != nil {
return nil, err
}
return FromBytes(uint32(N), P, buffer.Bytes())
} }
// Bytes returns the serialized format of the GCS filter, which does not // Bytes returns the serialized format of the GCS filter, which does not
// include N or P (returned by separate methods) or the key used by SipHash. // include N or P (returned by separate methods) or the key used by SipHash.
func (f *Filter) Bytes() []byte { func (f *Filter) Bytes() ([]byte, error) {
filterData := make([]byte, len(f.filterData)) filterData := make([]byte, len(f.filterData))
copy(filterData, f.filterData) copy(filterData, f.filterData)
return filterData return filterData, nil
} }
// NBytes returns the serialized format of the GCS filter with N, which does // NBytes returns the serialized format of the GCS filter with N, which does
// not include P (returned by a separate method) or the key used by SipHash. // not include P (returned by a separate method) or the key used by SipHash.
func (f *Filter) NBytes() []byte { func (f *Filter) NBytes() ([]byte, error) {
filterData := make([]byte, len(f.filterData)+4) var buffer bytes.Buffer
binary.BigEndian.PutUint32(filterData[:4], f.n) buffer.Grow(wire.VarIntSerializeSize(uint64(f.n)) + len(f.filterData))
copy(filterData[4:], f.filterData)
return filterData err := wire.WriteVarInt(&buffer, varIntProtoVer, uint64(f.n))
if err != nil {
return nil, err
}
_, err = buffer.Write(f.filterData)
if err != nil {
return nil, err
}
return buffer.Bytes(), nil
} }
// PBytes returns the serialized format of the GCS filter with P, which does // PBytes returns the serialized format of the GCS filter with P, which does
// not include N (returned by a separate method) or the key used by SipHash. // not include N (returned by a separate method) or the key used by SipHash.
func (f *Filter) PBytes() []byte { func (f *Filter) PBytes() ([]byte, error) {
filterData := make([]byte, len(f.filterData)+1) filterData := make([]byte, len(f.filterData)+1)
filterData[0] = f.p filterData[0] = f.p
copy(filterData[1:], f.filterData) copy(filterData[1:], f.filterData)
return filterData return filterData, nil
} }
// NPBytes returns the serialized format of the GCS filter with N and P, which // NPBytes returns the serialized format of the GCS filter with N and P, which
// does not include the key used by SipHash. // does not include the key used by SipHash.
func (f *Filter) NPBytes() []byte { func (f *Filter) NPBytes() ([]byte, error) {
filterData := make([]byte, len(f.filterData)+5) var buffer bytes.Buffer
binary.BigEndian.PutUint32(filterData[:4], f.n) buffer.Grow(wire.VarIntSerializeSize(uint64(f.n)) + 1 + len(f.filterData))
filterData[4] = f.p
copy(filterData[5:], f.filterData) err := wire.WriteVarInt(&buffer, varIntProtoVer, uint64(f.n))
return filterData if err != nil {
return nil, err
}
err = buffer.WriteByte(f.p)
if err != nil {
return nil, err
}
_, err = buffer.Write(f.filterData)
if err != nil {
return nil, err
}
return buffer.Bytes(), nil
} }
// P returns the filter's collision probability as a negative power of 2 (that // P returns the filter's collision probability as a negative power of 2 (that
@ -261,7 +309,11 @@ func (f *Filter) N() uint32 {
func (f *Filter) Match(key [KeySize]byte, data []byte) (bool, error) { func (f *Filter) Match(key [KeySize]byte, data []byte) (bool, error) {
// Create a filter bitstream. // Create a filter bitstream.
filterData := f.Bytes() filterData, err := f.Bytes()
if err != nil {
return false, err
}
b := bstream.NewBStreamReader(filterData) b := bstream.NewBStreamReader(filterData)
// We take the high and low bits of modulusNP for the multiplication // We take the high and low bits of modulusNP for the multiplication
@ -310,7 +362,11 @@ func (f *Filter) MatchAny(key [KeySize]byte, data [][]byte) (bool, error) {
} }
// Create a filter bitstream. // Create a filter bitstream.
filterData := f.Bytes() filterData, err := f.Bytes()
if err != nil {
return false, err
}
b := bstream.NewBStreamReader(filterData) b := bstream.NewBStreamReader(filterData)
// Create an uncompressed filter of the search values. // Create an uncompressed filter of the search values.
@ -372,7 +428,7 @@ func (f *Filter) MatchAny(key [KeySize]byte, data [][]byte) (bool, error) {
// readFullUint64 reads a value represented by the sum of a unary multiple of // readFullUint64 reads a value represented by the sum of a unary multiple of
// the filter's P modulus (`2**P`) and a big-endian P-bit remainder. // the filter's P modulus (`2**P`) and a big-endian P-bit remainder.
func (f *Filter) readFullUint64(b *bstream.BStream) (uint64, error) { func (f *Filter) readFullUint64(b *bstream.BStream) (uint64, error) {
var v uint64 var quotient uint64
// Count the 1s until we reach a 0. // Count the 1s until we reach a 0.
c, err := b.ReadBit() c, err := b.ReadBit()
@ -380,7 +436,7 @@ func (f *Filter) readFullUint64(b *bstream.BStream) (uint64, error) {
return 0, err return 0, err
} }
for c { for c {
v++ quotient++
c, err = b.ReadBit() c, err = b.ReadBit()
if err != nil { if err != nil {
return 0, err return 0, err
@ -394,6 +450,6 @@ func (f *Filter) readFullUint64(b *bstream.BStream) (uint64, error) {
} }
// Add the multiple and the remainder. // Add the multiple and the remainder.
v = v*f.modulusP + remainder v := (quotient << f.p) + remainder
return v, nil return v, nil
} }

View file

@ -87,19 +87,35 @@ func TestGCSFilterBuild(t *testing.T) {
// TestGCSFilterCopy deserializes and serializes a filter to create a copy. // TestGCSFilterCopy deserializes and serializes a filter to create a copy.
func TestGCSFilterCopy(t *testing.T) { func TestGCSFilterCopy(t *testing.T) {
filter2, err = gcs.FromBytes(filter.N(), P, filter.Bytes()) serialized2, err := filter.Bytes()
if err != nil {
t.Fatalf("Filter Bytes() failed: %v", err)
}
filter2, err = gcs.FromBytes(filter.N(), P, serialized2)
if err != nil { if err != nil {
t.Fatalf("Filter copy failed: %s", err.Error()) t.Fatalf("Filter copy failed: %s", err.Error())
} }
filter3, err = gcs.FromNBytes(filter.P(), filter.NBytes()) serialized3, err := filter.NBytes()
if err != nil {
t.Fatalf("Filter NBytes() failed: %v", err)
}
filter3, err = gcs.FromNBytes(filter.P(), serialized3)
if err != nil { if err != nil {
t.Fatalf("Filter copy failed: %s", err.Error()) t.Fatalf("Filter copy failed: %s", err.Error())
} }
filter4, err = gcs.FromPBytes(filter.N(), filter.PBytes()) serialized4, err := filter.PBytes()
if err != nil {
t.Fatalf("Filter PBytes() failed: %v", err)
}
filter4, err = gcs.FromPBytes(filter.N(), serialized4)
if err != nil { if err != nil {
t.Fatalf("Filter copy failed: %s", err.Error()) t.Fatalf("Filter copy failed: %s", err.Error())
} }
filter5, err = gcs.FromNPBytes(filter.NPBytes()) serialized5, err := filter.NPBytes()
if err != nil {
t.Fatalf("Filter NPBytes() failed: %v", err)
}
filter5, err = gcs.FromNPBytes(serialized5)
if err != nil { if err != nil {
t.Fatalf("Filter copy failed: %s", err.Error()) t.Fatalf("Filter copy failed: %s", err.Error())
} }
@ -138,16 +154,36 @@ func TestGCSFilterMetadata(t *testing.T) {
if filter.N() != filter5.N() { if filter.N() != filter5.N() {
t.Fatal("N doesn't match between copied filters") t.Fatal("N doesn't match between copied filters")
} }
if !bytes.Equal(filter.Bytes(), filter2.Bytes()) { serialized, err := filter.Bytes()
if err != nil {
t.Fatalf("Filter Bytes() failed: %v", err)
}
serialized2, err := filter2.Bytes()
if err != nil {
t.Fatalf("Filter Bytes() failed: %v", err)
}
if !bytes.Equal(serialized, serialized2) {
t.Fatal("Bytes don't match between copied filters") t.Fatal("Bytes don't match between copied filters")
} }
if !bytes.Equal(filter.Bytes(), filter3.Bytes()) { serialized3, err := filter3.Bytes()
if err != nil {
t.Fatalf("Filter Bytes() failed: %v", err)
}
if !bytes.Equal(serialized, serialized3) {
t.Fatal("Bytes don't match between copied filters") t.Fatal("Bytes don't match between copied filters")
} }
if !bytes.Equal(filter.Bytes(), filter4.Bytes()) { serialized4, err := filter3.Bytes()
if err != nil {
t.Fatalf("Filter Bytes() failed: %v", err)
}
if !bytes.Equal(serialized, serialized4) {
t.Fatal("Bytes don't match between copied filters") t.Fatal("Bytes don't match between copied filters")
} }
if !bytes.Equal(filter.Bytes(), filter5.Bytes()) { serialized5, err := filter5.Bytes()
if err != nil {
t.Fatalf("Filter Bytes() failed: %v", err)
}
if !bytes.Equal(serialized, serialized5) {
t.Fatal("Bytes don't match between copied filters") t.Fatal("Bytes don't match between copied filters")
} }
} }