lbcutil/gcs/gcs.go

455 lines
12 KiB
Go
Raw Normal View History

// Copyright (c) 2016-2017 The btcsuite developers
// Copyright (c) 2016-2017 The Lightning Network Developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
package gcs
import (
"bytes"
"fmt"
"io"
"sort"
"github.com/aead/siphash"
"github.com/kkdai/bstream"
"github.com/roasbeef/btcd/wire"
)
// Inspired by https://github.com/rasky/gcs
var (
// ErrNTooBig signifies that the filter can't handle N items. It is
// returned when the element count is 2**32 or larger and therefore cannot
// be stored in the filter's uint32 element counter.
ErrNTooBig = fmt.Errorf("N is too big to fit in uint32")
// ErrPTooBig signifies that the filter can't handle `1/2**P`
// collision probability. It is returned when P is greater than 32.
ErrPTooBig = fmt.Errorf("P is too big to fit in uint32")
)
const (
// KeySize is the size, in bytes, of the byte array required for key
// material for the SipHash keyed hash function.
KeySize = 16
// varIntProtoVer is the protocol version to use for serializing N as a
// VarInt. It is the version argument handed to wire.ReadVarInt and
// wire.WriteVarInt whenever N is decoded or encoded.
varIntProtoVer uint32 = 0
)
// fastReduction maps v into the range [0, N) in a way that is more or less
// equivalent to v mod N, but without a division. A modulo by a non-power of
// two is slow on many processors, so we use the "multiply-and-shift" trick
// described in:
// https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
//
//   - (v * N) >> 64
//
// That is, the result is the upper 64 bits of the 128-bit product of v and N.
// Since most processors don't expose 128-bit arithmetic natively, the product
// is assembled portably from four 64-bit partial products. The inputs are the
// number to reduce, v, and the modulus N already split by the caller into its
// high 32 bits (nHi) and its low 32 bits (nLo).
func fastReduction(v, nHi, nLo uint64) uint64 {
	const low32 = 0xFFFFFFFF

	// Split the value to reduce into its upper and lower 32-bit halves.
	vHi, vLo := v>>32, v&low32

	// Schoolbook multiplication: one partial product per pair of halves.
	hiHi := vHi * nHi
	hiLo := vHi * nLo
	loHi := vLo * nHi
	loLo := vLo * nLo

	// The low halves of the two middle products, together with the high
	// half of the lowest product, may overflow 32 bits; whatever spills
	// over is the carry into the upper 64 bits of the full product.
	carry := ((hiLo & low32) + (loHi & low32) + (loLo >> 32)) >> 32

	// Upper 64 bits: the highest product, the high halves of the middle
	// products, and the carry.
	return hiHi + (hiLo >> 32) + (loHi >> 32) + carry
}
// Filter describes an immutable filter that can be built from a set of data
// elements, serialized, deserialized, and queried in a thread-safe manner. The
// serialized form is compressed as a Golomb Coded Set (GCS), but does not
// include N or P to allow the user to encode the metadata separately if
// necessary. The hash function used is SipHash, a keyed function; the key used
// in building the filter is required in order to match filter values and is
// not included in the serialized form.
type Filter struct {
// n is the number of data elements the filter was built from.
n uint32
// p is the collision probability exponent (the probability is
// `1/(2**p)`); it is also the number of remainder bits stored for each
// encoded value.
p uint8
// modulusNP is n * 2**p (computed as uint64(n) << p), the size of the
// range into which element hashes are reduced.
modulusNP uint64
// filterData is the Golomb-coded bitstream of the sorted, reduced
// element hashes. It is nil for a filter built from no elements.
filterData []byte
}
// BuildGCSFilter builds a new GCS filter with the collision probability of
// `1/(2**P)`, key `key`, and including every `[]byte` in `data` as a member of
// the set.
func BuildGCSFilter(P uint8, key [KeySize]byte, data [][]byte) (*Filter, error) {
// Some initial parameter checks: make sure we have data from which to
// build the filter, and make sure our parameters will fit the hash
// function we're using.
if uint64(len(data)) >= (1 << 32) {
return nil, ErrNTooBig
}
if P > 32 {
return nil, ErrPTooBig
}
// Create the filter object and insert metadata.
f := Filter{
n: uint32(len(data)),
p: P,
}
f.modulusNP = uint64(f.n) << P
2018-01-17 23:56:28 +01:00
// Shortcut if the filter is empty.
if f.n == 0 {
return &f, nil
}
// Build the filter.
values := make(uint64Slice, 0, len(data))
b := bstream.NewBStreamWriter(0)
// Insert the hash (fast-ranged over a space of N*P) of each data
// element into a slice and sort the slice. This can be greatly
// optimized with native 128-bit multiplication, but we're going to be
// fully portable for now.
//
// First, we cache the high and low bits of modulusNP for the
// multiplication of 2 64-bit integers into a 128-bit integer.
nphi := f.modulusNP >> 32
nplo := uint64(uint32(f.modulusNP))
for _, d := range data {
// For each datum, we assign the initial hash to a uint64.
v := siphash.Sum64(d, &key)
v = fastReduction(v, nphi, nplo)
values = append(values, v)
}
sort.Sort(values)
// Write the sorted list of values into the filter bitstream,
// compressing it using Golomb coding.
var value, lastValue, remainder uint64
for _, v := range values {
// Calculate the difference between this value and the last,
// modulo P.
remainder = (v - lastValue) & ((uint64(1) << P) - 1)
// Calculate the difference between this value and the last,
// divided by P.
value = (v - lastValue - remainder) >> f.p
lastValue = v
// Write the P multiple into the bitstream in unary; the
// average should be around 1 (2 bits - 0b10).
for value > 0 {
b.WriteBit(true)
value--
}
b.WriteBit(false)
// Write the remainder as a big-endian integer with enough bits
// to represent the appropriate collision probability.
b.WriteBits(remainder, int(f.p))
}
// Copy the bitstream into the filter object and return the object.
f.filterData = b.Bytes()
return &f, nil
}
// FromBytes reconstructs a GCS filter from its element count N, its collision
// probability exponent P, and the serialized filter data as produced by
// Bytes(). The data is copied, so the caller keeps ownership of d.
//
// ErrPTooBig is returned when P is greater than 32.
func FromBytes(N uint32, P uint8, d []byte) (*Filter, error) {
	// Reject parameters the filter cannot represent.
	if P > 32 {
		return nil, ErrPTooBig
	}

	// Take a private copy of the serialized filter.
	payload := make([]byte, len(d))
	copy(payload, d)

	// Assemble the filter together with its derived modulus N * 2**P.
	return &Filter{
		n:          N,
		p:          P,
		modulusNP:  uint64(N) << P,
		filterData: payload,
	}, nil
}
// FromNBytes reconstructs a GCS filter from a known P and a buffer holding N
// (encoded as a VarInt) followed by the filter data, as produced by NBytes().
//
// A failure to decode the VarInt is returned as is; ErrNTooBig is returned
// when the decoded N does not fit in a uint32.
func FromNBytes(P uint8, d []byte) (*Filter, error) {
	r := bytes.NewReader(d)

	count, err := wire.ReadVarInt(r, varIntProtoVer)
	switch {
	case err != nil:
		return nil, err
	case count>>32 != 0:
		return nil, ErrNTooBig
	}

	// Whatever the VarInt did not consume is the filter itself.
	rest := d[len(d)-r.Len():]
	return FromBytes(uint32(count), P, rest)
}
// FromPBytes deserializes a GCS filter from a known N, and serialized P and
// filter as returned by PBytes().
//
// The first byte of d is P and the remaining bytes are the filter data. When
// d is empty there is no P byte to read, and io.ErrUnexpectedEOF is returned
// rather than panicking on the index. Errors from FromBytes (ErrPTooBig when
// P is greater than 32) are returned unchanged.
func FromPBytes(N uint32, d []byte) (*Filter, error) {
	if len(d) == 0 {
		return nil, io.ErrUnexpectedEOF
	}
	return FromBytes(N, d[0], d[1:])
}
// FromNPBytes reconstructs a GCS filter from a buffer holding N (encoded as a
// VarInt), then the single byte P, then the filter data, as produced by
// NPBytes().
//
// Failures to read N or P are returned as is; ErrNTooBig is returned when the
// decoded N does not fit in a uint32.
func FromNPBytes(d []byte) (*Filter, error) {
	src := bytes.NewBuffer(d)

	// N comes first, as a VarInt.
	count, err := wire.ReadVarInt(src, varIntProtoVer)
	if err != nil {
		return nil, err
	}
	if count>>32 != 0 {
		return nil, ErrNTooBig
	}

	// Then a single byte for P.
	prob, err := src.ReadByte()
	if err != nil {
		return nil, err
	}

	// The unread remainder is the filter itself.
	return FromBytes(uint32(count), prob, src.Bytes())
}
// Bytes returns a copy of the serialized GCS filter. The result contains
// neither N nor P (available through their own methods) nor the SipHash key.
// The returned error is always nil.
func (f *Filter) Bytes() ([]byte, error) {
	// Hand out a fresh slice so callers cannot alter the filter's data.
	dup := append(make([]byte, 0, len(f.filterData)), f.filterData...)
	return dup, nil
}
// NBytes returns the serialized GCS filter prefixed with N encoded as a
// VarInt. P (available through its own method) and the SipHash key are not
// part of the result.
func (f *Filter) NBytes() ([]byte, error) {
	count := uint64(f.n)

	// Reserve room for the VarInt and the filter up front.
	var out bytes.Buffer
	out.Grow(wire.VarIntSerializeSize(count) + len(f.filterData))

	if err := wire.WriteVarInt(&out, varIntProtoVer, count); err != nil {
		return nil, err
	}
	if _, err := out.Write(f.filterData); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// PBytes returns the serialized GCS filter prefixed with the single byte P.
// N (available through its own method) and the SipHash key are not part of
// the result. The returned error is always nil.
func (f *Filter) PBytes() ([]byte, error) {
	out := make([]byte, 0, 1+len(f.filterData))
	out = append(out, f.p)
	out = append(out, f.filterData...)
	return out, nil
}
// NPBytes returns the serialized GCS filter prefixed with N encoded as a
// VarInt followed by the single byte P. The SipHash key is not part of the
// result.
func (f *Filter) NPBytes() ([]byte, error) {
	count := uint64(f.n)

	// Reserve room for the VarInt, the P byte and the filter up front.
	var out bytes.Buffer
	out.Grow(wire.VarIntSerializeSize(count) + 1 + len(f.filterData))

	if err := wire.WriteVarInt(&out, varIntProtoVer, count); err != nil {
		return nil, err
	}
	if err := out.WriteByte(f.p); err != nil {
		return nil, err
	}
	if _, err := out.Write(f.filterData); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// P returns the filter's collision probability as a negative power of 2 (that
// is, a collision probability of `1/2**20` is represented as 20). This is the
// value the filter was built or deserialized with.
func (f *Filter) P() uint8 {
return f.p
}
// N returns the size of the data set used to build the filter. For a
// deserialized filter this is the N that was supplied or decoded.
func (f *Filter) N() uint32 {
return f.n
}
// Match checks whether a []byte value is likely (within collision probability)
// to be a member of the set represented by the filter.
//
// key must be the SipHash key the filter was built with. The result is false
// for a filter with no elements. Running out of filter data before a decision
// is reached is treated as "no match"; any other decoding error is returned.
func (f *Filter) Match(key [KeySize]byte, data []byte) (bool, error) {
	// Create a filter bitstream over a copy of the filter data, so that
	// reading never touches the filter's own bytes.
	filterData, err := f.Bytes()
	if err != nil {
		return false, err
	}
	b := bstream.NewBStreamReader(filterData)

	// We take the high and low bits of modulusNP for the multiplication
	// of 2 64-bit integers into a 128-bit integer.
	nphi := f.modulusNP >> 32
	nplo := uint64(uint32(f.modulusNP))

	// Then we hash our search term with the same parameters as the filter.
	term := siphash.Sum64(data, &key)
	term = fastReduction(term, nphi, nplo)

	// Decode at most N values from the filter. Every value is read before
	// it is compared, so a term that reduces to zero is still checked
	// against the first element instead of being rejected without looking
	// at the filter at all.
	var value uint64
	for i := uint32(0); i < f.n; i++ {
		// Read the difference between previous and new value from
		// bitstream.
		delta, err := f.readFullUint64(b)
		if err != nil {
			if err == io.EOF {
				return false, nil
			}
			return false, err
		}

		// Add the difference to the previous value.
		value += delta
		switch {
		case value == term:
			return true, nil
		case value > term:
			// The values only ever grow, so nothing further along
			// can match.
			return false, nil
		}
	}
	return false, nil
}
// MatchAny returns checks whether any []byte value is likely (within collision
// probability) to be a member of the set represented by the filter faster than
// calling Match() for each value individually.
func (f *Filter) MatchAny(key [KeySize]byte, data [][]byte) (bool, error) {
// Basic sanity check.
if len(data) == 0 {
2018-01-17 23:56:28 +01:00
return false, nil
}
// Create a filter bitstream.
filterData, err := f.Bytes()
if err != nil {
return false, err
}
b := bstream.NewBStreamReader(filterData)
// Create an uncompressed filter of the search values.
values := make(uint64Slice, 0, len(data))
// First, we cache the high and low bits of modulusNP for the
// multiplication of 2 64-bit integers into a 128-bit integer.
nphi := f.modulusNP >> 32
nplo := uint64(uint32(f.modulusNP))
for _, d := range data {
// For each datum, we assign the initial hash to a uint64.
v := siphash.Sum64(d, &key)
// We'll then reduce the value down to the range of our
// modulus.
v = fastReduction(v, nphi, nplo)
values = append(values, v)
}
sort.Sort(values)
// Zip down the filters, comparing values until we either run out of
// values to compare in one of the filters or we reach a matching
// value.
var lastValue1, lastValue2 uint64
lastValue2 = values[0]
i := 1
for lastValue1 != lastValue2 {
// Check which filter to advance to make sure we're comparing
// the right values.
switch {
case lastValue1 > lastValue2:
// Advance filter created from search terms or return
// false if we're at the end because nothing matched.
if i < len(values) {
lastValue2 = values[i]
i++
} else {
return false, nil
}
case lastValue2 > lastValue1:
// Advance filter we're searching or return false if
// we're at the end because nothing matched.
value, err := f.readFullUint64(b)
if err != nil {
if err == io.EOF {
return false, nil
}
return false, err
}
lastValue1 += value
}
}
// If we've made it this far, an element matched between filters so we
// return true.
return true, nil
}
// readFullUint64 decodes one Golomb-coded value from b: a quotient written in
// unary (a run of 1 bits closed by a 0 bit), counted in multiples of the
// filter's P modulus (`2**P`), followed by a big-endian P-bit remainder. Any
// error reported by the bitstream is returned together with a zero value.
func (f *Filter) readFullUint64(b *bstream.BStream) (uint64, error) {
	// Count the 1 bits that precede the terminating 0 bit.
	var quotient uint64
	for {
		bit, err := b.ReadBit()
		if err != nil {
			return 0, err
		}
		if !bit {
			break
		}
		quotient++
	}

	// The next P bits hold the remainder.
	remainder, err := b.ReadBits(int(f.p))
	if err != nil {
		return 0, err
	}

	// Recombine the multiple of 2**P with the remainder.
	return (quotient << f.p) + remainder, nil
}