diff --git a/gcs/builder/builder.go b/gcs/builder/builder.go index e2aefc9..7f79b58 100644 --- a/gcs/builder/builder.go +++ b/gcs/builder/builder.go @@ -371,21 +371,28 @@ func BuildExtFilter(block *wire.MsgBlock) (*gcs.Filter, error) { } // GetFilterHash returns the double-SHA256 of the filter. -func GetFilterHash(filter *gcs.Filter) chainhash.Hash { +func GetFilterHash(filter *gcs.Filter) (chainhash.Hash, error) { var zero chainhash.Hash if filter == nil { - return zero + return zero, nil } - hash1 := chainhash.HashH(filter.NBytes()) - return chainhash.HashH(hash1[:]) + filterData, err := filter.NBytes() + if err != nil { + return zero, err + } + + return chainhash.DoubleHashH(filterData), nil } // MakeHeaderForFilter makes a filter chain header for a filter, given the // filter and the previous filter chain header. -func MakeHeaderForFilter(filter *gcs.Filter, prevHeader chainhash.Hash) chainhash.Hash { +func MakeHeaderForFilter(filter *gcs.Filter, prevHeader chainhash.Hash) (chainhash.Hash, error) { filterTip := make([]byte, 2*chainhash.HashSize) - filterHash := GetFilterHash(filter) + filterHash, err := GetFilterHash(filter) + if err != nil { + return chainhash.Hash{}, err + } // In the buffer we created above we'll compute hash || prevHash as an // intermediate value. @@ -394,6 +401,5 @@ func MakeHeaderForFilter(filter *gcs.Filter, prevHeader chainhash.Hash) chainhas // The final filter hash is the double-sha256 of the hash computed // above. - hash1 := chainhash.HashH(filterTip) - return chainhash.HashH(hash1[:]) + return chainhash.DoubleHashH(filterTip), nil } diff --git a/gcs/gcs.go b/gcs/gcs.go index a33d871..b0f735e 100644 --- a/gcs/gcs.go +++ b/gcs/gcs.go @@ -6,14 +6,14 @@ package gcs import ( - "encoding/binary" + "bytes" "fmt" "io" - "math" "sort" "github.com/aead/siphash" "github.com/kkdai/bstream" + "github.com/roasbeef/btcd/wire" ) // Inspired by https://github.com/rasky/gcs @@ -34,6 +34,10 @@ const ( // KeySize is the size of the byte array required for key material for // the SipHash keyed hash function. KeySize = 16 + + // varIntProtoVer is the protocol version to use for serializing N as a + // VarInt. + varIntProtoVer uint32 = 0 ) // fastReduction calculates a mapping that's more ore less equivalent to: x mod @@ -83,7 +87,6 @@ func fastReduction(v, nHi, nLo uint64) uint64 { type Filter struct { n uint32 p uint8 - modulusP uint64 modulusNP uint64 filterData []byte } @@ -98,7 +101,7 @@ func BuildGCSFilter(P uint8, key [KeySize]byte, data [][]byte) (*Filter, error) if len(data) == 0 { return nil, ErrNoData } - if len(data) > math.MaxInt32 { + if uint64(len(data)) >= (1 << 32) { return nil, ErrNTooBig } if P > 32 { @@ -110,8 +113,7 @@ func BuildGCSFilter(P uint8, key [KeySize]byte, data [][]byte) (*Filter, error) n: uint32(len(data)), p: P, } - f.modulusP = uint64(1 << f.p) - f.modulusNP = uint64(f.n) * f.modulusP + f.modulusNP = uint64(f.n) << P // Build the filter. values := make(uint64Slice, 0, len(data)) @@ -141,7 +143,7 @@ func BuildGCSFilter(P uint8, key [KeySize]byte, data [][]byte) (*Filter, error) for _, v := range values { // Calculate the difference between this value and the last, // modulo P. - remainder = (v - lastValue) & (f.modulusP - 1) + remainder = (v - lastValue) & ((uint64(1) << P) - 1) // Calculate the difference between this value and the last, // divided by P. @@ -181,8 +183,7 @@ func FromBytes(N uint32, P uint8, d []byte) (*Filter, error) { n: N, p: P, } - f.modulusP = uint64(1 << f.p) - f.modulusNP = uint64(f.n) * f.modulusP + f.modulusNP = uint64(f.n) << P // Copy the filter. f.filterData = make([]byte, len(d)) @@ -194,7 +195,15 @@ func FromBytes(N uint32, P uint8, d []byte) (*Filter, error) { // FromNBytes deserializes a GCS filter from a known P, and serialized N and // filter as returned by NBytes(). func FromNBytes(P uint8, d []byte) (*Filter, error) { - return FromBytes(binary.BigEndian.Uint32(d[:4]), P, d[4:]) + buffer := bytes.NewBuffer(d) + N, err := wire.ReadVarInt(buffer, varIntProtoVer) + if err != nil { + return nil, err + } + if N >= (1 << 32) { + return nil, ErrNTooBig + } + return FromBytes(uint32(N), P, buffer.Bytes()) } // FromPBytes deserializes a GCS filter from a known N, and serialized P and @@ -206,43 +215,82 @@ func FromPBytes(N uint32, d []byte) (*Filter, error) { // FromNPBytes deserializes a GCS filter from a serialized N, P, and filter as // returned by NPBytes(). func FromNPBytes(d []byte) (*Filter, error) { - return FromBytes(binary.BigEndian.Uint32(d[:4]), d[4], d[5:]) + buffer := bytes.NewBuffer(d) + + N, err := wire.ReadVarInt(buffer, varIntProtoVer) + if err != nil { + return nil, err + } + if N >= (1 << 32) { + return nil, ErrNTooBig + } + + P, err := buffer.ReadByte() + if err != nil { + return nil, err + } + + return FromBytes(uint32(N), P, buffer.Bytes()) } // Bytes returns the serialized format of the GCS filter, which does not // include N or P (returned by separate methods) or the key used by SipHash. -func (f *Filter) Bytes() []byte { +func (f *Filter) Bytes() ([]byte, error) { filterData := make([]byte, len(f.filterData)) copy(filterData, f.filterData) - return filterData + return filterData, nil } // NBytes returns the serialized format of the GCS filter with N, which does // not include P (returned by a separate method) or the key used by SipHash. -func (f *Filter) NBytes() []byte { - filterData := make([]byte, len(f.filterData)+4) - binary.BigEndian.PutUint32(filterData[:4], f.n) - copy(filterData[4:], f.filterData) - return filterData +func (f *Filter) NBytes() ([]byte, error) { + var buffer bytes.Buffer + buffer.Grow(wire.VarIntSerializeSize(uint64(f.n)) + len(f.filterData)) + + err := wire.WriteVarInt(&buffer, varIntProtoVer, uint64(f.n)) + if err != nil { + return nil, err + } + + _, err = buffer.Write(f.filterData) + if err != nil { + return nil, err + } + + return buffer.Bytes(), nil } // PBytes returns the serialized format of the GCS filter with P, which does // not include N (returned by a separate method) or the key used by SipHash. -func (f *Filter) PBytes() []byte { +func (f *Filter) PBytes() ([]byte, error) { filterData := make([]byte, len(f.filterData)+1) filterData[0] = f.p copy(filterData[1:], f.filterData) - return filterData + return filterData, nil } // NPBytes returns the serialized format of the GCS filter with N and P, which // does not include the key used by SipHash. -func (f *Filter) NPBytes() []byte { - filterData := make([]byte, len(f.filterData)+5) - binary.BigEndian.PutUint32(filterData[:4], f.n) - filterData[4] = f.p - copy(filterData[5:], f.filterData) - return filterData +func (f *Filter) NPBytes() ([]byte, error) { + var buffer bytes.Buffer + buffer.Grow(wire.VarIntSerializeSize(uint64(f.n)) + 1 + len(f.filterData)) + + err := wire.WriteVarInt(&buffer, varIntProtoVer, uint64(f.n)) + if err != nil { + return nil, err + } + + err = buffer.WriteByte(f.p) + if err != nil { + return nil, err + } + + _, err = buffer.Write(f.filterData) + if err != nil { + return nil, err + } + + return buffer.Bytes(), nil } // P returns the filter's collision probability as a negative power of 2 (that @@ -261,7 +309,11 @@ func (f *Filter) N() uint32 { func (f *Filter) Match(key [KeySize]byte, data []byte) (bool, error) { // Create a filter bitstream. - filterData := f.Bytes() + filterData, err := f.Bytes() + if err != nil { + return false, err + } + b := bstream.NewBStreamReader(filterData) // We take the high and low bits of modulusNP for the multiplication @@ -310,7 +362,11 @@ func (f *Filter) MatchAny(key [KeySize]byte, data [][]byte) (bool, error) { } // Create a filter bitstream. - filterData := f.Bytes() + filterData, err := f.Bytes() + if err != nil { + return false, err + } + b := bstream.NewBStreamReader(filterData) // Create an uncompressed filter of the search values. @@ -372,7 +428,7 @@ func (f *Filter) MatchAny(key [KeySize]byte, data [][]byte) (bool, error) { // readFullUint64 reads a value represented by the sum of a unary multiple of // the filter's P modulus (`2**P`) and a big-endian P-bit remainder. func (f *Filter) readFullUint64(b *bstream.BStream) (uint64, error) { - var v uint64 + var quotient uint64 // Count the 1s until we reach a 0. c, err := b.ReadBit() @@ -380,7 +436,7 @@ func (f *Filter) readFullUint64(b *bstream.BStream) (uint64, error) { return 0, err } for c { - v++ + quotient++ c, err = b.ReadBit() if err != nil { return 0, err @@ -394,6 +450,6 @@ func (f *Filter) readFullUint64(b *bstream.BStream) (uint64, error) { } // Add the multiple and the remainder. - v = v*f.modulusP + remainder + v := (quotient << f.p) + remainder return v, nil } diff --git a/gcs/gcs_test.go b/gcs/gcs_test.go index ed62de4..fb604d8 100644 --- a/gcs/gcs_test.go +++ b/gcs/gcs_test.go @@ -87,19 +87,35 @@ func TestGCSFilterBuild(t *testing.T) { // TestGCSFilterCopy deserializes and serializes a filter to create a copy. func TestGCSFilterCopy(t *testing.T) { - filter2, err = gcs.FromBytes(filter.N(), P, filter.Bytes()) + serialized2, err := filter.Bytes() + if err != nil { + t.Fatalf("Filter Bytes() failed: %v", err) + } + filter2, err = gcs.FromBytes(filter.N(), P, serialized2) if err != nil { t.Fatalf("Filter copy failed: %s", err.Error()) } - filter3, err = gcs.FromNBytes(filter.P(), filter.NBytes()) + serialized3, err := filter.NBytes() + if err != nil { + t.Fatalf("Filter NBytes() failed: %v", err) + } + filter3, err = gcs.FromNBytes(filter.P(), serialized3) if err != nil { t.Fatalf("Filter copy failed: %s", err.Error()) } - filter4, err = gcs.FromPBytes(filter.N(), filter.PBytes()) + serialized4, err := filter.PBytes() + if err != nil { + t.Fatalf("Filter PBytes() failed: %v", err) + } + filter4, err = gcs.FromPBytes(filter.N(), serialized4) if err != nil { t.Fatalf("Filter copy failed: %s", err.Error()) } - filter5, err = gcs.FromNPBytes(filter.NPBytes()) + serialized5, err := filter.NPBytes() + if err != nil { + t.Fatalf("Filter NPBytes() failed: %v", err) + } + filter5, err = gcs.FromNPBytes(serialized5) if err != nil { t.Fatalf("Filter copy failed: %s", err.Error()) } @@ -138,16 +154,36 @@ func TestGCSFilterMetadata(t *testing.T) { if filter.N() != filter5.N() { t.Fatal("N doesn't match between copied filters") } - if !bytes.Equal(filter.Bytes(), filter2.Bytes()) { + serialized, err := filter.Bytes() + if err != nil { + t.Fatalf("Filter Bytes() failed: %v", err) + } + serialized2, err := filter2.Bytes() + if err != nil { + t.Fatalf("Filter Bytes() failed: %v", err) + } + if !bytes.Equal(serialized, serialized2) { t.Fatal("Bytes don't match between copied filters") } - if !bytes.Equal(filter.Bytes(), filter3.Bytes()) { + serialized3, err := filter3.Bytes() + if err != nil { + t.Fatalf("Filter Bytes() failed: %v", err) + } + if !bytes.Equal(serialized, serialized3) { t.Fatal("Bytes don't match between copied filters") } - if !bytes.Equal(filter.Bytes(), filter4.Bytes()) { + serialized4, err := filter3.Bytes() + if err != nil { + t.Fatalf("Filter Bytes() failed: %v", err) + } + if !bytes.Equal(serialized, serialized4) { t.Fatal("Bytes don't match between copied filters") } - if !bytes.Equal(filter.Bytes(), filter5.Bytes()) { + serialized5, err := filter5.Bytes() + if err != nil { + t.Fatalf("Filter Bytes() failed: %v", err) + } + if !bytes.Equal(serialized, serialized5) { t.Fatal("Bytes don't match between copied filters") } }