Serialize filter with N as a VarInt instead of fixed-size.
This commit is contained in:
parent
9da482119c
commit
884680ddbd
3 changed files with 145 additions and 47 deletions
|
@ -371,21 +371,28 @@ func BuildExtFilter(block *wire.MsgBlock) (*gcs.Filter, error) {
|
|||
}
|
||||
|
||||
// GetFilterHash returns the double-SHA256 of the filter.
|
||||
func GetFilterHash(filter *gcs.Filter) chainhash.Hash {
|
||||
func GetFilterHash(filter *gcs.Filter) (chainhash.Hash, error) {
|
||||
var zero chainhash.Hash
|
||||
if filter == nil {
|
||||
return zero
|
||||
return zero, nil
|
||||
}
|
||||
|
||||
hash1 := chainhash.HashH(filter.NBytes())
|
||||
return chainhash.HashH(hash1[:])
|
||||
filterData, err := filter.NBytes()
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
return chainhash.DoubleHashH(filterData), nil
|
||||
}
|
||||
|
||||
// MakeHeaderForFilter makes a filter chain header for a filter, given the
|
||||
// filter and the previous filter chain header.
|
||||
func MakeHeaderForFilter(filter *gcs.Filter, prevHeader chainhash.Hash) chainhash.Hash {
|
||||
func MakeHeaderForFilter(filter *gcs.Filter, prevHeader chainhash.Hash) (chainhash.Hash, error) {
|
||||
filterTip := make([]byte, 2*chainhash.HashSize)
|
||||
filterHash := GetFilterHash(filter)
|
||||
filterHash, err := GetFilterHash(filter)
|
||||
if err != nil {
|
||||
return chainhash.Hash{}, err
|
||||
}
|
||||
|
||||
// In the buffer we created above we'll compute hash || prevHash as an
|
||||
// intermediate value.
|
||||
|
@ -394,6 +401,5 @@ func MakeHeaderForFilter(filter *gcs.Filter, prevHeader chainhash.Hash) chainhas
|
|||
|
||||
// The final filter hash is the double-sha256 of the hash computed
|
||||
// above.
|
||||
hash1 := chainhash.HashH(filterTip)
|
||||
return chainhash.HashH(hash1[:])
|
||||
return chainhash.DoubleHashH(filterTip), nil
|
||||
}
|
||||
|
|
118
gcs/gcs.go
118
gcs/gcs.go
|
@ -6,14 +6,14 @@
|
|||
package gcs
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"sort"
|
||||
|
||||
"github.com/aead/siphash"
|
||||
"github.com/kkdai/bstream"
|
||||
"github.com/roasbeef/btcd/wire"
|
||||
)
|
||||
|
||||
// Inspired by https://github.com/rasky/gcs
|
||||
|
@ -34,6 +34,10 @@ const (
|
|||
// KeySize is the size of the byte array required for key material for
|
||||
// the SipHash keyed hash function.
|
||||
KeySize = 16
|
||||
|
||||
// varIntProtoVer is the protocol version to use for serializing N as a
|
||||
// VarInt.
|
||||
varIntProtoVer uint32 = 0
|
||||
)
|
||||
|
||||
// fastReduction calculates a mapping that's more ore less equivalent to: x mod
|
||||
|
@ -83,7 +87,6 @@ func fastReduction(v, nHi, nLo uint64) uint64 {
|
|||
type Filter struct {
|
||||
n uint32
|
||||
p uint8
|
||||
modulusP uint64
|
||||
modulusNP uint64
|
||||
filterData []byte
|
||||
}
|
||||
|
@ -98,7 +101,7 @@ func BuildGCSFilter(P uint8, key [KeySize]byte, data [][]byte) (*Filter, error)
|
|||
if len(data) == 0 {
|
||||
return nil, ErrNoData
|
||||
}
|
||||
if len(data) > math.MaxInt32 {
|
||||
if uint64(len(data)) >= (1 << 32) {
|
||||
return nil, ErrNTooBig
|
||||
}
|
||||
if P > 32 {
|
||||
|
@ -110,8 +113,7 @@ func BuildGCSFilter(P uint8, key [KeySize]byte, data [][]byte) (*Filter, error)
|
|||
n: uint32(len(data)),
|
||||
p: P,
|
||||
}
|
||||
f.modulusP = uint64(1 << f.p)
|
||||
f.modulusNP = uint64(f.n) * f.modulusP
|
||||
f.modulusNP = uint64(f.n) << P
|
||||
|
||||
// Build the filter.
|
||||
values := make(uint64Slice, 0, len(data))
|
||||
|
@ -141,7 +143,7 @@ func BuildGCSFilter(P uint8, key [KeySize]byte, data [][]byte) (*Filter, error)
|
|||
for _, v := range values {
|
||||
// Calculate the difference between this value and the last,
|
||||
// modulo P.
|
||||
remainder = (v - lastValue) & (f.modulusP - 1)
|
||||
remainder = (v - lastValue) & ((uint64(1) << P) - 1)
|
||||
|
||||
// Calculate the difference between this value and the last,
|
||||
// divided by P.
|
||||
|
@ -181,8 +183,7 @@ func FromBytes(N uint32, P uint8, d []byte) (*Filter, error) {
|
|||
n: N,
|
||||
p: P,
|
||||
}
|
||||
f.modulusP = uint64(1 << f.p)
|
||||
f.modulusNP = uint64(f.n) * f.modulusP
|
||||
f.modulusNP = uint64(f.n) << P
|
||||
|
||||
// Copy the filter.
|
||||
f.filterData = make([]byte, len(d))
|
||||
|
@ -194,7 +195,15 @@ func FromBytes(N uint32, P uint8, d []byte) (*Filter, error) {
|
|||
// FromNBytes deserializes a GCS filter from a known P, and serialized N and
|
||||
// filter as returned by NBytes().
|
||||
func FromNBytes(P uint8, d []byte) (*Filter, error) {
|
||||
return FromBytes(binary.BigEndian.Uint32(d[:4]), P, d[4:])
|
||||
buffer := bytes.NewBuffer(d)
|
||||
N, err := wire.ReadVarInt(buffer, varIntProtoVer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if N >= (1 << 32) {
|
||||
return nil, ErrNTooBig
|
||||
}
|
||||
return FromBytes(uint32(N), P, buffer.Bytes())
|
||||
}
|
||||
|
||||
// FromPBytes deserializes a GCS filter from a known N, and serialized P and
|
||||
|
@ -206,43 +215,82 @@ func FromPBytes(N uint32, d []byte) (*Filter, error) {
|
|||
// FromNPBytes deserializes a GCS filter from a serialized N, P, and filter as
|
||||
// returned by NPBytes().
|
||||
func FromNPBytes(d []byte) (*Filter, error) {
|
||||
return FromBytes(binary.BigEndian.Uint32(d[:4]), d[4], d[5:])
|
||||
buffer := bytes.NewBuffer(d)
|
||||
|
||||
N, err := wire.ReadVarInt(buffer, varIntProtoVer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if N >= (1 << 32) {
|
||||
return nil, ErrNTooBig
|
||||
}
|
||||
|
||||
P, err := buffer.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return FromBytes(uint32(N), P, buffer.Bytes())
|
||||
}
|
||||
|
||||
// Bytes returns the serialized format of the GCS filter, which does not
|
||||
// include N or P (returned by separate methods) or the key used by SipHash.
|
||||
func (f *Filter) Bytes() []byte {
|
||||
func (f *Filter) Bytes() ([]byte, error) {
|
||||
filterData := make([]byte, len(f.filterData))
|
||||
copy(filterData, f.filterData)
|
||||
return filterData
|
||||
return filterData, nil
|
||||
}
|
||||
|
||||
// NBytes returns the serialized format of the GCS filter with N, which does
|
||||
// not include P (returned by a separate method) or the key used by SipHash.
|
||||
func (f *Filter) NBytes() []byte {
|
||||
filterData := make([]byte, len(f.filterData)+4)
|
||||
binary.BigEndian.PutUint32(filterData[:4], f.n)
|
||||
copy(filterData[4:], f.filterData)
|
||||
return filterData
|
||||
func (f *Filter) NBytes() ([]byte, error) {
|
||||
var buffer bytes.Buffer
|
||||
buffer.Grow(wire.VarIntSerializeSize(uint64(f.n)) + len(f.filterData))
|
||||
|
||||
err := wire.WriteVarInt(&buffer, varIntProtoVer, uint64(f.n))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = buffer.Write(f.filterData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
// PBytes returns the serialized format of the GCS filter with P, which does
|
||||
// not include N (returned by a separate method) or the key used by SipHash.
|
||||
func (f *Filter) PBytes() []byte {
|
||||
func (f *Filter) PBytes() ([]byte, error) {
|
||||
filterData := make([]byte, len(f.filterData)+1)
|
||||
filterData[0] = f.p
|
||||
copy(filterData[1:], f.filterData)
|
||||
return filterData
|
||||
return filterData, nil
|
||||
}
|
||||
|
||||
// NPBytes returns the serialized format of the GCS filter with N and P, which
|
||||
// does not include the key used by SipHash.
|
||||
func (f *Filter) NPBytes() []byte {
|
||||
filterData := make([]byte, len(f.filterData)+5)
|
||||
binary.BigEndian.PutUint32(filterData[:4], f.n)
|
||||
filterData[4] = f.p
|
||||
copy(filterData[5:], f.filterData)
|
||||
return filterData
|
||||
func (f *Filter) NPBytes() ([]byte, error) {
|
||||
var buffer bytes.Buffer
|
||||
buffer.Grow(wire.VarIntSerializeSize(uint64(f.n)) + 1 + len(f.filterData))
|
||||
|
||||
err := wire.WriteVarInt(&buffer, varIntProtoVer, uint64(f.n))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = buffer.WriteByte(f.p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = buffer.Write(f.filterData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
// P returns the filter's collision probability as a negative power of 2 (that
|
||||
|
@ -261,7 +309,11 @@ func (f *Filter) N() uint32 {
|
|||
func (f *Filter) Match(key [KeySize]byte, data []byte) (bool, error) {
|
||||
|
||||
// Create a filter bitstream.
|
||||
filterData := f.Bytes()
|
||||
filterData, err := f.Bytes()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
b := bstream.NewBStreamReader(filterData)
|
||||
|
||||
// We take the high and low bits of modulusNP for the multiplication
|
||||
|
@ -310,7 +362,11 @@ func (f *Filter) MatchAny(key [KeySize]byte, data [][]byte) (bool, error) {
|
|||
}
|
||||
|
||||
// Create a filter bitstream.
|
||||
filterData := f.Bytes()
|
||||
filterData, err := f.Bytes()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
b := bstream.NewBStreamReader(filterData)
|
||||
|
||||
// Create an uncompressed filter of the search values.
|
||||
|
@ -372,7 +428,7 @@ func (f *Filter) MatchAny(key [KeySize]byte, data [][]byte) (bool, error) {
|
|||
// readFullUint64 reads a value represented by the sum of a unary multiple of
|
||||
// the filter's P modulus (`2**P`) and a big-endian P-bit remainder.
|
||||
func (f *Filter) readFullUint64(b *bstream.BStream) (uint64, error) {
|
||||
var v uint64
|
||||
var quotient uint64
|
||||
|
||||
// Count the 1s until we reach a 0.
|
||||
c, err := b.ReadBit()
|
||||
|
@ -380,7 +436,7 @@ func (f *Filter) readFullUint64(b *bstream.BStream) (uint64, error) {
|
|||
return 0, err
|
||||
}
|
||||
for c {
|
||||
v++
|
||||
quotient++
|
||||
c, err = b.ReadBit()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
@ -394,6 +450,6 @@ func (f *Filter) readFullUint64(b *bstream.BStream) (uint64, error) {
|
|||
}
|
||||
|
||||
// Add the multiple and the remainder.
|
||||
v = v*f.modulusP + remainder
|
||||
v := (quotient << f.p) + remainder
|
||||
return v, nil
|
||||
}
|
||||
|
|
|
@ -87,19 +87,35 @@ func TestGCSFilterBuild(t *testing.T) {
|
|||
|
||||
// TestGCSFilterCopy deserializes and serializes a filter to create a copy.
|
||||
func TestGCSFilterCopy(t *testing.T) {
|
||||
filter2, err = gcs.FromBytes(filter.N(), P, filter.Bytes())
|
||||
serialized2, err := filter.Bytes()
|
||||
if err != nil {
|
||||
t.Fatalf("Filter Bytes() failed: %v", err)
|
||||
}
|
||||
filter2, err = gcs.FromBytes(filter.N(), P, serialized2)
|
||||
if err != nil {
|
||||
t.Fatalf("Filter copy failed: %s", err.Error())
|
||||
}
|
||||
filter3, err = gcs.FromNBytes(filter.P(), filter.NBytes())
|
||||
serialized3, err := filter.NBytes()
|
||||
if err != nil {
|
||||
t.Fatalf("Filter NBytes() failed: %v", err)
|
||||
}
|
||||
filter3, err = gcs.FromNBytes(filter.P(), serialized3)
|
||||
if err != nil {
|
||||
t.Fatalf("Filter copy failed: %s", err.Error())
|
||||
}
|
||||
filter4, err = gcs.FromPBytes(filter.N(), filter.PBytes())
|
||||
serialized4, err := filter.PBytes()
|
||||
if err != nil {
|
||||
t.Fatalf("Filter PBytes() failed: %v", err)
|
||||
}
|
||||
filter4, err = gcs.FromPBytes(filter.N(), serialized4)
|
||||
if err != nil {
|
||||
t.Fatalf("Filter copy failed: %s", err.Error())
|
||||
}
|
||||
filter5, err = gcs.FromNPBytes(filter.NPBytes())
|
||||
serialized5, err := filter.NPBytes()
|
||||
if err != nil {
|
||||
t.Fatalf("Filter NPBytes() failed: %v", err)
|
||||
}
|
||||
filter5, err = gcs.FromNPBytes(serialized5)
|
||||
if err != nil {
|
||||
t.Fatalf("Filter copy failed: %s", err.Error())
|
||||
}
|
||||
|
@ -138,16 +154,36 @@ func TestGCSFilterMetadata(t *testing.T) {
|
|||
if filter.N() != filter5.N() {
|
||||
t.Fatal("N doesn't match between copied filters")
|
||||
}
|
||||
if !bytes.Equal(filter.Bytes(), filter2.Bytes()) {
|
||||
serialized, err := filter.Bytes()
|
||||
if err != nil {
|
||||
t.Fatalf("Filter Bytes() failed: %v", err)
|
||||
}
|
||||
serialized2, err := filter2.Bytes()
|
||||
if err != nil {
|
||||
t.Fatalf("Filter Bytes() failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(serialized, serialized2) {
|
||||
t.Fatal("Bytes don't match between copied filters")
|
||||
}
|
||||
if !bytes.Equal(filter.Bytes(), filter3.Bytes()) {
|
||||
serialized3, err := filter3.Bytes()
|
||||
if err != nil {
|
||||
t.Fatalf("Filter Bytes() failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(serialized, serialized3) {
|
||||
t.Fatal("Bytes don't match between copied filters")
|
||||
}
|
||||
if !bytes.Equal(filter.Bytes(), filter4.Bytes()) {
|
||||
serialized4, err := filter3.Bytes()
|
||||
if err != nil {
|
||||
t.Fatalf("Filter Bytes() failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(serialized, serialized4) {
|
||||
t.Fatal("Bytes don't match between copied filters")
|
||||
}
|
||||
if !bytes.Equal(filter.Bytes(), filter5.Bytes()) {
|
||||
serialized5, err := filter5.Bytes()
|
||||
if err != nil {
|
||||
t.Fatalf("Filter Bytes() failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(serialized, serialized5) {
|
||||
t.Fatal("Bytes don't match between copied filters")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue