lbry.go/dht/bitmap.go
2018-05-24 17:49:43 -04:00

335 lines
6 KiB
Go

package dht
import (
"bytes"
"crypto/rand"
"encoding/hex"
"strings"
"github.com/lbryio/lbry.go/errors"
"github.com/lyoshenka/bencode"
)
// TODO: http://roaringbitmap.org/
type Bitmap [nodeIDLength]byte
func (b Bitmap) RawString() string {
return string(b[:])
}
// BString returns the bitmap as a string of 0s and 1s
func (b Bitmap) BString() string {
var buf bytes.Buffer
for i := 0; i < nodeIDBits; i++ {
if b.Get(i) {
buf.WriteString("1")
} else {
buf.WriteString("0")
}
}
return buf.String()
}
func (b Bitmap) Hex() string {
return hex.EncodeToString(b[:])
}
func (b Bitmap) HexShort() string {
return hex.EncodeToString(b[:4])
}
func (b Bitmap) HexSimplified() string {
simple := strings.TrimLeft(b.Hex(), "0")
if simple == "" {
simple = "0"
}
return simple
}
func (b Bitmap) Equals(other Bitmap) bool {
for k := range b {
if b[k] != other[k] {
return false
}
}
return true
}
func (b Bitmap) Less(other interface{}) bool {
for k := range b {
if b[k] != other.(Bitmap)[k] {
return b[k] < other.(Bitmap)[k]
}
}
return false
}
func (b Bitmap) LessOrEqual(other interface{}) bool {
if bm, ok := other.(Bitmap); ok && b.Equals(bm) {
return true
}
return b.Less(other)
}
func (b Bitmap) Greater(other interface{}) bool {
for k := range b {
if b[k] != other.(Bitmap)[k] {
return b[k] > other.(Bitmap)[k]
}
}
return false
}
func (b Bitmap) GreaterOrEqual(other interface{}) bool {
if bm, ok := other.(Bitmap); ok && b.Equals(bm) {
return true
}
return b.Greater(other)
}
func (b Bitmap) Copy() Bitmap {
var ret Bitmap
copy(ret[:], b[:])
return ret
}
func (b Bitmap) Xor(other Bitmap) Bitmap {
var ret Bitmap
for k := range b {
ret[k] = b[k] ^ other[k]
}
return ret
}
func (b Bitmap) And(other Bitmap) Bitmap {
var ret Bitmap
for k := range b {
ret[k] = b[k] & other[k]
}
return ret
}
func (b Bitmap) Or(other Bitmap) Bitmap {
var ret Bitmap
for k := range b {
ret[k] = b[k] | other[k]
}
return ret
}
func (b Bitmap) Not() Bitmap {
var ret Bitmap
for k := range b {
ret[k] = ^b[k]
}
return ret
}
func (b Bitmap) add(other Bitmap) (Bitmap, bool) {
var ret Bitmap
carry := false
for i := nodeIDBits - 1; i >= 0; i-- {
bBit := getBit(b[:], i)
oBit := getBit(other[:], i)
setBit(ret[:], i, bBit != oBit != carry)
carry = (bBit && oBit) || (bBit && carry) || (oBit && carry)
}
return ret, carry
}
func (b Bitmap) Add(other Bitmap) Bitmap {
ret, carry := b.add(other)
if carry {
panic("overflow in bitmap addition")
}
return ret
}
func (b Bitmap) Sub(other Bitmap) Bitmap {
if b.Less(other) {
panic("negative bitmaps not supported")
}
complement, _ := other.Not().add(BitmapFromShortHexP("1"))
ret, _ := b.add(complement)
return ret
}
func (b Bitmap) Get(n int) bool {
return getBit(b[:], n)
}
func (b Bitmap) Set(n int, one bool) Bitmap {
ret := b.Copy()
setBit(ret[:], n, one)
return ret
}
// PrefixLen returns the number of leading 0 bits
func (b Bitmap) PrefixLen() int {
for i := range b {
for j := 0; j < 8; j++ {
if (b[i]>>uint8(7-j))&0x1 != 0 {
return i*8 + j
}
}
}
return nodeIDBits
}
// Prefix returns a copy of b with the first n bits set to 1 (if `one` is true) or 0 (if `one` is false)
// https://stackoverflow.com/a/23192263/182709
func (b Bitmap) Prefix(n int, one bool) Bitmap {
ret := b.Copy()
Outer:
for i := range ret {
for j := 0; j < 8; j++ {
if i*8+j < n {
if one {
ret[i] |= 1 << uint(7-j)
} else {
ret[i] &= ^(1 << uint(7-j))
}
} else {
break Outer
}
}
}
return ret
}
// Syffix returns a copy of b with the last n bits set to 1 (if `one` is true) or 0 (if `one` is false)
// https://stackoverflow.com/a/23192263/182709
func (b Bitmap) Suffix(n int, one bool) Bitmap {
ret := b.Copy()
Outer:
for i := len(ret) - 1; i >= 0; i-- {
for j := 7; j >= 0; j-- {
if i*8+j >= nodeIDBits-n {
if one {
ret[i] |= 1 << uint(7-j)
} else {
ret[i] &= ^(1 << uint(7-j))
}
} else {
break Outer
}
}
}
return ret
}
func (b Bitmap) MarshalBencode() ([]byte, error) {
str := string(b[:])
return bencode.EncodeBytes(str)
}
func (b *Bitmap) UnmarshalBencode(encoded []byte) error {
var str string
err := bencode.DecodeBytes(encoded, &str)
if err != nil {
return err
}
if len(str) != nodeIDLength {
return errors.Err("invalid bitmap length")
}
copy(b[:], str)
return nil
}
func BitmapFromBytes(data []byte) (Bitmap, error) {
var bmp Bitmap
if len(data) != len(bmp) {
return bmp, errors.Err("invalid bitmap of length %d", len(data))
}
copy(bmp[:], data)
return bmp, nil
}
func BitmapFromBytesP(data []byte) Bitmap {
bmp, err := BitmapFromBytes(data)
if err != nil {
panic(err)
}
return bmp
}
func BitmapFromString(data string) (Bitmap, error) {
return BitmapFromBytes([]byte(data))
}
func BitmapFromStringP(data string) Bitmap {
bmp, err := BitmapFromString(data)
if err != nil {
panic(err)
}
return bmp
}
func BitmapFromHex(hexStr string) (Bitmap, error) {
decoded, err := hex.DecodeString(hexStr)
if err != nil {
return Bitmap{}, errors.Err(err)
}
return BitmapFromBytes(decoded)
}
func BitmapFromHexP(hexStr string) Bitmap {
bmp, err := BitmapFromHex(hexStr)
if err != nil {
panic(err)
}
return bmp
}
func BitmapFromShortHex(hexStr string) (Bitmap, error) {
return BitmapFromHex(strings.Repeat("0", nodeIDLength*2-len(hexStr)) + hexStr)
}
func BitmapFromShortHexP(hexStr string) Bitmap {
bmp, err := BitmapFromShortHex(hexStr)
if err != nil {
panic(err)
}
return bmp
}
func RandomBitmapP() Bitmap {
var id Bitmap
_, err := rand.Read(id[:])
if err != nil {
panic(err)
}
return id
}
func RandomBitmapInRangeP(low, high Bitmap) Bitmap {
diff := high.Sub(low)
r := RandomBitmapP()
for r.Greater(diff) {
r = r.Sub(diff)
}
return r.Add(low)
}
func getBit(b []byte, n int) bool {
i := n / 8
j := n % 8
return b[i]&(1<<uint(7-j)) > 0
}
func setBit(b []byte, n int, one bool) {
i := n / 8
j := n % 8
if one {
b[i] |= 1 << uint(7-j)
} else {
b[i] &= ^(1 << uint(7-j))
}
}