Merge remote-tracking branch 'reflector/dht'
* reflector/dht: (75 commits) some linting fixed a few reflector issues, added some tests final fix fixed or silenced the last few things to get this building fixed some linting errors fix rpc server update dependencies, only run short tests in travis fix stuck goroutine announce still needs tests, but i tested a lot by hand and its good hash announcer / rate limiter refactor contact sort more handle peer port correctly Revert "add tcp port mapping to data store" iterative find value rpc command add jack.lbry.tech as a known node for debugging add tcp port mapping to data store bucket splitting is solid add dht start command, run a jsonrpc server to interact with the node grin's cleanup and some WIP ...
This commit is contained in:
commit
bf61bd7b92
25 changed files with 8467 additions and 0 deletions
399
dht/bits/bitmap.go
Normal file
399
dht/bits/bitmap.go
Normal file
|
@ -0,0 +1,399 @@
|
|||
package bits
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/lbryio/lbry.go/errors"
|
||||
|
||||
"github.com/lyoshenka/bencode"
|
||||
)
|
||||
|
||||
// TODO: http://roaringbitmap.org/
|
||||
|
||||
const (
|
||||
NumBytes = 48 // bytes
|
||||
NumBits = NumBytes * 8
|
||||
)
|
||||
|
||||
// Bitmap is a generalized representation of an identifier or data that can be sorted, compared fast. Used by the DHT
|
||||
// package as a way to handle the unique identifiers of a DHT node.
|
||||
type Bitmap [NumBytes]byte
|
||||
|
||||
func (b Bitmap) RawString() string {
|
||||
return string(b[:])
|
||||
}
|
||||
|
||||
func (b Bitmap) String() string {
|
||||
return b.Hex()
|
||||
}
|
||||
|
||||
// BString returns the bitmap as a string of 0s and 1s
|
||||
func (b Bitmap) BString() string {
|
||||
var s string
|
||||
for _, byte := range b {
|
||||
s += strconv.FormatInt(int64(byte), 2)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Hex returns a hexadecimal representation of the bitmap.
|
||||
func (b Bitmap) Hex() string {
|
||||
return hex.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
// HexShort returns a hexadecimal representation of the first 4 bytes.
|
||||
func (b Bitmap) HexShort() string {
|
||||
return hex.EncodeToString(b[:4])
|
||||
}
|
||||
|
||||
// HexSimplified returns the hexadecimal representation with all leading 0's removed
|
||||
func (b Bitmap) HexSimplified() string {
|
||||
simple := strings.TrimLeft(b.Hex(), "0")
|
||||
if simple == "" {
|
||||
simple = "0"
|
||||
}
|
||||
return simple
|
||||
}
|
||||
|
||||
func (b Bitmap) Big() *big.Int {
|
||||
i := new(big.Int)
|
||||
i.SetString(b.Hex(), 16)
|
||||
return i
|
||||
}
|
||||
|
||||
// Cmp compares b and other and returns:
|
||||
//
|
||||
// -1 if b < other
|
||||
// 0 if b == other
|
||||
// +1 if b > other
|
||||
//
|
||||
func (b Bitmap) Cmp(other Bitmap) int {
|
||||
for k := range b {
|
||||
if b[k] < other[k] {
|
||||
return -1
|
||||
} else if b[k] > other[k] {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Closer returns true if dist(b,x) < dist(b,y)
|
||||
func (b Bitmap) Closer(x, y Bitmap) bool {
|
||||
return x.Xor(b).Cmp(y.Xor(b)) < 0
|
||||
}
|
||||
|
||||
// Equals returns true if every byte in bitmap are equal, false otherwise
|
||||
func (b Bitmap) Equals(other Bitmap) bool {
|
||||
return b.Cmp(other) == 0
|
||||
}
|
||||
|
||||
// Copy returns a duplicate value for the bitmap.
|
||||
func (b Bitmap) Copy() Bitmap {
|
||||
var ret Bitmap
|
||||
copy(ret[:], b[:])
|
||||
return ret
|
||||
}
|
||||
|
||||
// Xor returns a diff bitmap. If they are equal, the returned bitmap will be all 0's. If 100% unique the returned
|
||||
// bitmap will be all 1's.
|
||||
func (b Bitmap) Xor(other Bitmap) Bitmap {
|
||||
var ret Bitmap
|
||||
for k := range b {
|
||||
ret[k] = b[k] ^ other[k]
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// And returns a comparison bitmap, that for each byte returns the AND true table result
|
||||
func (b Bitmap) And(other Bitmap) Bitmap {
|
||||
var ret Bitmap
|
||||
for k := range b {
|
||||
ret[k] = b[k] & other[k]
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Or returns a comparison bitmap, that for each byte returns the OR true table result
|
||||
func (b Bitmap) Or(other Bitmap) Bitmap {
|
||||
var ret Bitmap
|
||||
for k := range b {
|
||||
ret[k] = b[k] | other[k]
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Not returns a complimentary bitmap that is an inverse. So b.NOT.NOT = b
|
||||
func (b Bitmap) Not() Bitmap {
|
||||
var ret Bitmap
|
||||
for k := range b {
|
||||
ret[k] = ^b[k]
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (b Bitmap) add(other Bitmap) (Bitmap, bool) {
|
||||
var ret Bitmap
|
||||
carry := false
|
||||
for i := NumBits - 1; i >= 0; i-- {
|
||||
bBit := getBit(b[:], i)
|
||||
oBit := getBit(other[:], i)
|
||||
setBit(ret[:], i, bBit != oBit != carry)
|
||||
carry = (bBit && oBit) || (bBit && carry) || (oBit && carry)
|
||||
}
|
||||
return ret, carry
|
||||
}
|
||||
|
||||
// Add returns a bitmap that treats both bitmaps as numbers and adding them together. Since the size of a bitmap is
|
||||
// limited, an overflow is possible when adding bitmaps.
|
||||
func (b Bitmap) Add(other Bitmap) Bitmap {
|
||||
ret, carry := b.add(other)
|
||||
if carry {
|
||||
panic("overflow in bitmap addition. limited to " + strconv.Itoa(NumBits) + " bits.")
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Sub returns a bitmap that treats both bitmaps as numbers and subtracts then via the inverse of the other and adding
|
||||
// then together a + (-b). Negative bitmaps are not supported so other must be greater than this.
|
||||
func (b Bitmap) Sub(other Bitmap) Bitmap {
|
||||
if b.Cmp(other) < 0 {
|
||||
// ToDo: Why is this not supported? Should it say not implemented? BitMap might have a generic use case outside of dht.
|
||||
panic("negative bitmaps not supported")
|
||||
}
|
||||
complement, _ := other.Not().add(FromShortHexP("1"))
|
||||
ret, _ := b.add(complement)
|
||||
return ret
|
||||
}
|
||||
|
||||
// Get returns the binary bit at the position passed.
|
||||
func (b Bitmap) Get(n int) bool {
|
||||
return getBit(b[:], n)
|
||||
}
|
||||
|
||||
// Set sets the binary bit at the position passed.
|
||||
func (b Bitmap) Set(n int, one bool) Bitmap {
|
||||
ret := b.Copy()
|
||||
setBit(ret[:], n, one)
|
||||
return ret
|
||||
}
|
||||
|
||||
// PrefixLen returns the number of leading 0 bits
|
||||
func (b Bitmap) PrefixLen() int {
|
||||
for i := range b {
|
||||
for j := 0; j < 8; j++ {
|
||||
if (b[i]>>uint8(7-j))&0x1 != 0 {
|
||||
return i*8 + j
|
||||
}
|
||||
}
|
||||
}
|
||||
return NumBits
|
||||
}
|
||||
|
||||
// Prefix returns a copy of b with the first n bits set to 1 (if `one` is true) or 0 (if `one` is false)
|
||||
// https://stackoverflow.com/a/23192263/182709
|
||||
func (b Bitmap) Prefix(n int, one bool) Bitmap {
|
||||
ret := b.Copy()
|
||||
|
||||
Outer:
|
||||
for i := range ret {
|
||||
for j := 0; j < 8; j++ {
|
||||
if i*8+j < n {
|
||||
if one {
|
||||
ret[i] |= 1 << uint(7-j)
|
||||
} else {
|
||||
ret[i] &= ^(1 << uint(7-j))
|
||||
}
|
||||
} else {
|
||||
break Outer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// Suffix returns a copy of b with the last n bits set to 1 (if `one` is true) or 0 (if `one` is false)
|
||||
// https://stackoverflow.com/a/23192263/182709
|
||||
func (b Bitmap) Suffix(n int, one bool) Bitmap {
|
||||
ret := b.Copy()
|
||||
|
||||
Outer:
|
||||
for i := len(ret) - 1; i >= 0; i-- {
|
||||
for j := 7; j >= 0; j-- {
|
||||
if i*8+j >= NumBits-n {
|
||||
if one {
|
||||
ret[i] |= 1 << uint(7-j)
|
||||
} else {
|
||||
ret[i] &= ^(1 << uint(7-j))
|
||||
}
|
||||
} else {
|
||||
break Outer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// MarshalBencode implements the Marshaller(bencode)/Message interface.
|
||||
func (b Bitmap) MarshalBencode() ([]byte, error) {
|
||||
str := string(b[:])
|
||||
return bencode.EncodeBytes(str)
|
||||
}
|
||||
|
||||
// UnmarshalBencode implements the Marshaller(bencode)/Message interface.
|
||||
func (b *Bitmap) UnmarshalBencode(encoded []byte) error {
|
||||
var str string
|
||||
err := bencode.DecodeBytes(encoded, &str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(str) != NumBytes {
|
||||
return errors.Err("invalid bitmap length")
|
||||
}
|
||||
copy(b[:], str)
|
||||
return nil
|
||||
}
|
||||
|
||||
// FromBytes returns a bitmap as long as the byte array is of a specific length specified in the parameters.
|
||||
func FromBytes(data []byte) (Bitmap, error) {
|
||||
var bmp Bitmap
|
||||
|
||||
if len(data) != len(bmp) {
|
||||
return bmp, errors.Err("invalid bitmap of length %d", len(data))
|
||||
}
|
||||
|
||||
copy(bmp[:], data)
|
||||
return bmp, nil
|
||||
}
|
||||
|
||||
// FromBytesP returns a bitmap as long as the byte array is of a specific length specified in the parameters
|
||||
// otherwise it wil panic.
|
||||
func FromBytesP(data []byte) Bitmap {
|
||||
bmp, err := FromBytes(data)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return bmp
|
||||
}
|
||||
|
||||
//FromString returns a bitmap by converting the string to bytes and creating from bytes as long as the byte array
|
||||
// is of a specific length specified in the parameters
|
||||
func FromString(data string) (Bitmap, error) {
|
||||
return FromBytes([]byte(data))
|
||||
}
|
||||
|
||||
//FromStringP returns a bitmap by converting the string to bytes and creating from bytes as long as the byte array
|
||||
// is of a specific length specified in the parameters otherwise it wil panic.
|
||||
func FromStringP(data string) Bitmap {
|
||||
bmp, err := FromString(data)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return bmp
|
||||
}
|
||||
|
||||
//FromHex returns a bitmap by converting the hex string to bytes and creating from bytes as long as the byte array
|
||||
// is of a specific length specified in the parameters
|
||||
func FromHex(hexStr string) (Bitmap, error) {
|
||||
decoded, err := hex.DecodeString(hexStr)
|
||||
if err != nil {
|
||||
return Bitmap{}, errors.Err(err)
|
||||
}
|
||||
return FromBytes(decoded)
|
||||
}
|
||||
|
||||
//FromHexP returns a bitmap by converting the hex string to bytes and creating from bytes as long as the byte array
|
||||
// is of a specific length specified in the parameters otherwise it wil panic.
|
||||
func FromHexP(hexStr string) Bitmap {
|
||||
bmp, err := FromHex(hexStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return bmp
|
||||
}
|
||||
|
||||
//FromShortHex returns a bitmap by converting the hex string to bytes, adding the leading zeros prefix to the
|
||||
// hex string and creating from bytes as long as the byte array is of a specific length specified in the parameters
|
||||
func FromShortHex(hexStr string) (Bitmap, error) {
|
||||
return FromHex(strings.Repeat("0", NumBytes*2-len(hexStr)) + hexStr)
|
||||
}
|
||||
|
||||
//FromShortHexP returns a bitmap by converting the hex string to bytes, adding the leading zeros prefix to the
|
||||
// hex string and creating from bytes as long as the byte array is of a specific length specified in the parameters
|
||||
// otherwise it wil panic.
|
||||
func FromShortHexP(hexStr string) Bitmap {
|
||||
bmp, err := FromShortHex(hexStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return bmp
|
||||
}
|
||||
|
||||
func FromBigP(b *big.Int) Bitmap {
|
||||
return FromShortHexP(b.Text(16))
|
||||
}
|
||||
|
||||
// MaxP returns a bitmap with all bits set to 1
|
||||
func MaxP() Bitmap {
|
||||
return FromHexP(strings.Repeat("f", NumBytes*2))
|
||||
}
|
||||
|
||||
// Rand generates a cryptographically random bitmap with the confines of the parameters specified.
|
||||
func Rand() Bitmap {
|
||||
var id Bitmap
|
||||
_, err := rand.Read(id[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// RandInRangeP generates a cryptographically random bitmap and while it is greater than the high threshold
|
||||
// bitmap will subtract the diff between high and low until it is no longer greater that the high.
|
||||
func RandInRangeP(low, high Bitmap) Bitmap {
|
||||
diff := high.Sub(low)
|
||||
r := Rand()
|
||||
for r.Cmp(diff) > 0 {
|
||||
r = r.Sub(diff)
|
||||
}
|
||||
//ToDo - Adding the low at this point doesn't gurantee it will be within the range. Consider bitmaps as numbers and
|
||||
// I have a range of 50-100. If get to say 60, and add 50, I would be at 110. Should protect against this?
|
||||
return r.Add(low)
|
||||
}
|
||||
|
||||
func getBit(b []byte, n int) bool {
|
||||
i := n / 8
|
||||
j := n % 8
|
||||
return b[i]&(1<<uint(7-j)) > 0
|
||||
}
|
||||
|
||||
func setBit(b []byte, n int, one bool) {
|
||||
i := n / 8
|
||||
j := n % 8
|
||||
if one {
|
||||
b[i] |= 1 << uint(7-j)
|
||||
} else {
|
||||
b[i] &= ^(1 << uint(7-j))
|
||||
}
|
||||
}
|
||||
|
||||
// Closest returns the closest bitmap to target. if no bitmaps are provided, target itself is returned
|
||||
func Closest(target Bitmap, bitmaps ...Bitmap) Bitmap {
|
||||
if len(bitmaps) == 0 {
|
||||
return target
|
||||
}
|
||||
|
||||
var closest *Bitmap
|
||||
for _, b := range bitmaps {
|
||||
if closest == nil || target.Closer(b, *closest) {
|
||||
closest = &b
|
||||
}
|
||||
}
|
||||
return *closest
|
||||
}
|
386
dht/bits/bitmap_test.go
Normal file
386
dht/bits/bitmap_test.go
Normal file
|
@ -0,0 +1,386 @@
|
|||
package bits
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/lyoshenka/bencode"
|
||||
)
|
||||
|
||||
func TestBitmap(t *testing.T) {
|
||||
a := Bitmap{
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
|
||||
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
|
||||
}
|
||||
b := Bitmap{
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
|
||||
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 46,
|
||||
}
|
||||
c := Bitmap{
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
}
|
||||
|
||||
if !a.Equals(a) {
|
||||
t.Error("bitmap does not equal itself")
|
||||
}
|
||||
if a.Equals(b) {
|
||||
t.Error("bitmap equals another bitmap with different id")
|
||||
}
|
||||
|
||||
if !a.Xor(b).Equals(c) {
|
||||
t.Error(a.Xor(b))
|
||||
}
|
||||
|
||||
if c.PrefixLen() != 375 {
|
||||
t.Error(c.PrefixLen())
|
||||
}
|
||||
|
||||
if b.Cmp(a) < 0 {
|
||||
t.Error("bitmap fails Cmp test")
|
||||
}
|
||||
|
||||
if a.Closer(c, b) || !a.Closer(b, c) || c.Closer(a, b) || c.Closer(b, c) {
|
||||
t.Error("bitmap fails Closer test")
|
||||
}
|
||||
|
||||
id := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
if FromHexP(id).Hex() != id {
|
||||
t.Error(FromHexP(id).Hex())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBitmap_GetBit(t *testing.T) {
|
||||
tt := []struct {
|
||||
bit int
|
||||
expected bool
|
||||
panic bool
|
||||
}{
|
||||
{bit: 383, expected: false, panic: false},
|
||||
{bit: 382, expected: true, panic: false},
|
||||
{bit: 381, expected: false, panic: false},
|
||||
{bit: 380, expected: true, panic: false},
|
||||
}
|
||||
|
||||
b := FromShortHexP("a")
|
||||
|
||||
for _, test := range tt {
|
||||
actual := getBit(b[:], test.bit)
|
||||
if test.expected != actual {
|
||||
t.Errorf("getting bit %d of %s: expected %t, got %t", test.bit, b.HexSimplified(), test.expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBitmap_SetBit(t *testing.T) {
|
||||
tt := []struct {
|
||||
hex string
|
||||
bit int
|
||||
one bool
|
||||
expected string
|
||||
panic bool
|
||||
}{
|
||||
{hex: "0", bit: 383, one: true, expected: "1", panic: false},
|
||||
{hex: "0", bit: 382, one: true, expected: "2", panic: false},
|
||||
{hex: "0", bit: 381, one: true, expected: "4", panic: false},
|
||||
{hex: "0", bit: 385, one: true, expected: "1", panic: true},
|
||||
{hex: "0", bit: 384, one: true, expected: "1", panic: true},
|
||||
}
|
||||
|
||||
for _, test := range tt {
|
||||
expected := FromShortHexP(test.expected)
|
||||
actual := FromShortHexP(test.hex)
|
||||
if test.panic {
|
||||
assertPanic(t, fmt.Sprintf("setting bit %d to %t", test.bit, test.one), func() { setBit(actual[:], test.bit, test.one) })
|
||||
} else {
|
||||
setBit(actual[:], test.bit, test.one)
|
||||
if !expected.Equals(actual) {
|
||||
t.Errorf("setting bit %d to %t: expected %s, got %s", test.bit, test.one, test.expected, actual.HexSimplified())
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBitmap_FromHexShort(t *testing.T) {
|
||||
tt := []struct {
|
||||
short string
|
||||
long string
|
||||
}{
|
||||
{short: "", long: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{short: "0", long: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{short: "00000", long: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{short: "9473745bc", long: "0000000000000000000000000000000000000000000000000000000000000000000000000000000000000009473745bc"},
|
||||
{short: "09473745bc", long: "0000000000000000000000000000000000000000000000000000000000000000000000000000000000000009473745bc"},
|
||||
{short: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
|
||||
long: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"},
|
||||
}
|
||||
|
||||
for _, test := range tt {
|
||||
short := FromShortHexP(test.short)
|
||||
long := FromHexP(test.long)
|
||||
if !short.Equals(long) {
|
||||
t.Errorf("short hex %s: expected %s, got %s", test.short, long.Hex(), short.Hex())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBitmapMarshal(t *testing.T) {
|
||||
b := FromStringP("123456789012345678901234567890123456789012345678")
|
||||
encoded, err := bencode.EncodeBytes(b)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if string(encoded) != "48:123456789012345678901234567890123456789012345678" {
|
||||
t.Error("encoding does not match expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBitmapMarshalEmbedded(t *testing.T) {
|
||||
e := struct {
|
||||
A string
|
||||
B Bitmap
|
||||
C int
|
||||
}{
|
||||
A: "1",
|
||||
B: FromStringP("222222222222222222222222222222222222222222222222"),
|
||||
C: 3,
|
||||
}
|
||||
|
||||
encoded, err := bencode.EncodeBytes(e)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if string(encoded) != "d1:A1:11:B48:2222222222222222222222222222222222222222222222221:Ci3ee" {
|
||||
t.Error("encoding does not match expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBitmapMarshalEmbedded2(t *testing.T) {
|
||||
encoded, err := bencode.EncodeBytes([]interface{}{
|
||||
FromStringP("333333333333333333333333333333333333333333333333"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if string(encoded) != "l48:333333333333333333333333333333333333333333333333e" {
|
||||
t.Error("encoding does not match expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBitmap_PrefixLen(t *testing.T) {
|
||||
tt := []struct {
|
||||
hex string
|
||||
len int
|
||||
}{
|
||||
{len: 0, hex: "F00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{len: 0, hex: "800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{len: 1, hex: "700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{len: 1, hex: "400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{len: 384, hex: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{len: 383, hex: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"},
|
||||
{len: 382, hex: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"},
|
||||
{len: 382, hex: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"},
|
||||
}
|
||||
|
||||
for _, test := range tt {
|
||||
len := FromHexP(test.hex).PrefixLen()
|
||||
if len != test.len {
|
||||
t.Errorf("got prefix len %d; expected %d for %s", len, test.len, test.hex)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBitmap_Prefix(t *testing.T) {
|
||||
allOne := FromHexP("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
|
||||
|
||||
zerosTT := []struct {
|
||||
zeros int
|
||||
expected string
|
||||
}{
|
||||
{zeros: -123, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
|
||||
{zeros: 0, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
|
||||
{zeros: 1, expected: "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
|
||||
{zeros: 69, expected: "000000000000000007ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
|
||||
{zeros: 383, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"},
|
||||
{zeros: 384, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{zeros: 400, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
}
|
||||
|
||||
for _, test := range zerosTT {
|
||||
expected := FromHexP(test.expected)
|
||||
actual := allOne.Prefix(test.zeros, false)
|
||||
if !actual.Equals(expected) {
|
||||
t.Errorf("%d zeros: got %s; expected %s", test.zeros, actual.Hex(), expected.Hex())
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < NumBits; i++ {
|
||||
b := allOne.Prefix(i, false)
|
||||
if b.PrefixLen() != i {
|
||||
t.Errorf("got prefix len %d; expected %d for %s", b.PrefixLen(), i, b.Hex())
|
||||
}
|
||||
}
|
||||
|
||||
allZero := FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
|
||||
onesTT := []struct {
|
||||
ones int
|
||||
expected string
|
||||
}{
|
||||
{ones: -123, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{ones: 0, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{ones: 1, expected: "800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{ones: 69, expected: "fffffffffffffffff8000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{ones: 383, expected: "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"},
|
||||
{ones: 384, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
|
||||
{ones: 400, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
|
||||
}
|
||||
|
||||
for _, test := range onesTT {
|
||||
expected := FromHexP(test.expected)
|
||||
actual := allZero.Prefix(test.ones, true)
|
||||
if !actual.Equals(expected) {
|
||||
t.Errorf("%d ones: got %s; expected %s", test.ones, actual.Hex(), expected.Hex())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBitmap_Suffix(t *testing.T) {
|
||||
allOne := FromHexP("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
|
||||
|
||||
zerosTT := []struct {
|
||||
zeros int
|
||||
expected string
|
||||
}{
|
||||
{zeros: -123, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
|
||||
{zeros: 0, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
|
||||
{zeros: 1, expected: "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"},
|
||||
{zeros: 69, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe00000000000000000"},
|
||||
{zeros: 383, expected: "800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{zeros: 384, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{zeros: 400, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
}
|
||||
|
||||
for _, test := range zerosTT {
|
||||
expected := FromHexP(test.expected)
|
||||
actual := allOne.Suffix(test.zeros, false)
|
||||
if !actual.Equals(expected) {
|
||||
t.Errorf("%d zeros: got %s; expected %s", test.zeros, actual.Hex(), expected.Hex())
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < NumBits; i++ {
|
||||
b := allOne.Prefix(i, false)
|
||||
if b.PrefixLen() != i {
|
||||
t.Errorf("got prefix len %d; expected %d for %s", b.PrefixLen(), i, b.Hex())
|
||||
}
|
||||
}
|
||||
|
||||
allZero := FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
|
||||
onesTT := []struct {
|
||||
ones int
|
||||
expected string
|
||||
}{
|
||||
{ones: -123, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{ones: 0, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
|
||||
{ones: 1, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"},
|
||||
{ones: 69, expected: "0000000000000000000000000000000000000000000000000000000000000000000000000000001fffffffffffffffff"},
|
||||
{ones: 383, expected: "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
|
||||
{ones: 384, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
|
||||
{ones: 400, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
|
||||
}
|
||||
|
||||
for _, test := range onesTT {
|
||||
expected := FromHexP(test.expected)
|
||||
actual := allZero.Suffix(test.ones, true)
|
||||
if !actual.Equals(expected) {
|
||||
t.Errorf("%d ones: got %s; expected %s", test.ones, actual.Hex(), expected.Hex())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBitmap_Add(t *testing.T) {
|
||||
tt := []struct {
|
||||
a, b, sum string
|
||||
panic bool
|
||||
}{
|
||||
{"0", "0", "0", false},
|
||||
{"0", "1", "1", false},
|
||||
{"1", "0", "1", false},
|
||||
{"1", "1", "2", false},
|
||||
{"8", "4", "c", false},
|
||||
{"1000", "0010", "1010", false},
|
||||
{"1111", "1111", "2222", false},
|
||||
{"ffff", "1", "10000", false},
|
||||
{"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", false},
|
||||
{"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "1", "", true},
|
||||
}
|
||||
|
||||
for _, test := range tt {
|
||||
a := FromShortHexP(test.a)
|
||||
b := FromShortHexP(test.b)
|
||||
expected := FromShortHexP(test.sum)
|
||||
if test.panic {
|
||||
assertPanic(t, fmt.Sprintf("adding %s and %s", test.a, test.b), func() { a.Add(b) })
|
||||
} else {
|
||||
actual := a.Add(b)
|
||||
if !expected.Equals(actual) {
|
||||
t.Errorf("adding %s and %s; expected %s, got %s", test.a, test.b, test.sum, actual.HexSimplified())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBitmap_Sub(t *testing.T) {
|
||||
tt := []struct {
|
||||
a, b, sum string
|
||||
panic bool
|
||||
}{
|
||||
{"0", "0", "0", false},
|
||||
{"1", "0", "1", false},
|
||||
{"1", "1", "0", false},
|
||||
{"8", "4", "4", false},
|
||||
{"f", "9", "6", false},
|
||||
{"f", "e", "1", false},
|
||||
{"10", "f", "1", false},
|
||||
{"2222", "1111", "1111", false},
|
||||
{"ffff", "1", "fffe", false},
|
||||
{"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", false},
|
||||
{"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0", false},
|
||||
{"0", "1", "", true},
|
||||
}
|
||||
|
||||
for _, test := range tt {
|
||||
a := FromShortHexP(test.a)
|
||||
b := FromShortHexP(test.b)
|
||||
expected := FromShortHexP(test.sum)
|
||||
if test.panic {
|
||||
assertPanic(t, fmt.Sprintf("subtracting %s - %s", test.a, test.b), func() { a.Sub(b) })
|
||||
} else {
|
||||
actual := a.Sub(b)
|
||||
if !expected.Equals(actual) {
|
||||
t.Errorf("subtracting %s - %s; expected %s, got %s", test.a, test.b, test.sum, actual.HexSimplified())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertPanic(t *testing.T, text string, f func()) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("%s: did not panic as expected", text)
|
||||
}
|
||||
}()
|
||||
f()
|
||||
}
|
65
dht/bits/range.go
Normal file
65
dht/bits/range.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package bits
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
|
||||
"github.com/lbryio/errors.go"
|
||||
)
|
||||
|
||||
// Range has a start and end
|
||||
type Range struct {
|
||||
Start Bitmap
|
||||
End Bitmap
|
||||
}
|
||||
|
||||
func MaxRange() Range {
|
||||
return Range{
|
||||
Start: Bitmap{},
|
||||
End: MaxP(),
|
||||
}
|
||||
}
|
||||
|
||||
// IntervalP divides the range into `num` intervals and returns the `n`th one
|
||||
// intervals are approximately the same size, but may not be exact because of rounding issues
|
||||
// the first interval always starts at the beginning of the range, and the last interval always ends at the end
|
||||
func (r Range) IntervalP(n, num int) Range {
|
||||
if num < 1 || n < 1 || n > num {
|
||||
panic(errors.Err("invalid interval %d of %d", n, num))
|
||||
}
|
||||
|
||||
start := r.intervalStart(n, num)
|
||||
end := r.End.Big()
|
||||
if n < num {
|
||||
end = r.intervalStart(n+1, num)
|
||||
end.Sub(end, big.NewInt(1))
|
||||
}
|
||||
|
||||
return Range{FromBigP(start), FromBigP(end)}
|
||||
}
|
||||
|
||||
func (r Range) intervalStart(n, num int) *big.Int {
|
||||
// formula:
|
||||
// size = (end - start) / num
|
||||
// rem = (end - start) % num
|
||||
// intervalStart = rangeStart + (size * n-1) + ((rem * n-1) % num)
|
||||
|
||||
size := new(big.Int)
|
||||
rem := new(big.Int)
|
||||
size.Sub(r.End.Big(), r.Start.Big()).DivMod(size, big.NewInt(int64(num)), rem)
|
||||
|
||||
size.Mul(size, big.NewInt(int64(n-1)))
|
||||
rem.Mul(rem, big.NewInt(int64(n-1))).Mod(rem, big.NewInt(int64(num)))
|
||||
|
||||
start := r.Start.Big()
|
||||
start.Add(start, size).Add(start, rem)
|
||||
|
||||
return start
|
||||
}
|
||||
|
||||
func (r Range) IntervalSize() *big.Int {
|
||||
return (&big.Int{}).Sub(r.End.Big(), r.Start.Big())
|
||||
}
|
||||
|
||||
func (r Range) Contains(b Bitmap) bool {
|
||||
return r.Start.Cmp(b) <= 0 && r.End.Cmp(b) >= 0
|
||||
}
|
48
dht/bits/range_test.go
Normal file
48
dht/bits/range_test.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package bits
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMaxRange(t *testing.T) {
|
||||
start := FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
end := FromHexP("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
|
||||
r := MaxRange()
|
||||
|
||||
if !r.Start.Equals(start) {
|
||||
t.Error("max range does not start at the beginning")
|
||||
}
|
||||
if !r.End.Equals(end) {
|
||||
t.Error("max range does not end at the end")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRange_IntervalP(t *testing.T) {
|
||||
max := MaxRange()
|
||||
|
||||
numIntervals := 97
|
||||
expectedAvg := (&big.Int{}).Div(max.IntervalSize(), big.NewInt(int64(numIntervals)))
|
||||
maxDiff := big.NewInt(int64(numIntervals))
|
||||
|
||||
var lastEnd Bitmap
|
||||
|
||||
for i := 1; i <= numIntervals; i++ {
|
||||
ival := max.IntervalP(i, numIntervals)
|
||||
if i == 1 && !ival.Start.Equals(max.Start) {
|
||||
t.Error("first interval does not start at 0")
|
||||
}
|
||||
if i == numIntervals && !ival.End.Equals(max.End) {
|
||||
t.Error("last interval does not end at max")
|
||||
}
|
||||
if i > 1 && !ival.Start.Equals(lastEnd.Add(FromShortHexP("1"))) {
|
||||
t.Errorf("interval %d of %d: last end was %s, this start is %s", i, numIntervals, lastEnd.Hex(), ival.Start.Hex())
|
||||
}
|
||||
|
||||
if ival.IntervalSize().Cmp((&big.Int{}).Add(expectedAvg, maxDiff)) > 0 || ival.IntervalSize().Cmp((&big.Int{}).Sub(expectedAvg, maxDiff)) < 0 {
|
||||
t.Errorf("interval %d of %d: interval size is outside the normal range", i, numIntervals)
|
||||
}
|
||||
|
||||
lastEnd = ival.End
|
||||
}
|
||||
}
|
212
dht/bootstrap.go
Normal file
212
dht/bootstrap.go
Normal file
|
@ -0,0 +1,212 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
)
|
||||
|
||||
const (
|
||||
bootstrapDefaultRefreshDuration = 15 * time.Minute
|
||||
)
|
||||
|
||||
// BootstrapNode is a configured node setup for testing.
|
||||
type BootstrapNode struct {
|
||||
Node
|
||||
|
||||
initialPingInterval time.Duration
|
||||
checkInterval time.Duration
|
||||
|
||||
nlock *sync.RWMutex
|
||||
peers map[bits.Bitmap]*peer
|
||||
nodeIDs []bits.Bitmap // necessary for efficient random ID selection
|
||||
}
|
||||
|
||||
// NewBootstrapNode returns a BootstrapNode pointer.
|
||||
func NewBootstrapNode(id bits.Bitmap, initialPingInterval, rePingInterval time.Duration) *BootstrapNode {
|
||||
b := &BootstrapNode{
|
||||
Node: *NewNode(id),
|
||||
|
||||
initialPingInterval: initialPingInterval,
|
||||
checkInterval: rePingInterval,
|
||||
|
||||
nlock: &sync.RWMutex{},
|
||||
peers: make(map[bits.Bitmap]*peer),
|
||||
nodeIDs: make([]bits.Bitmap, 0),
|
||||
}
|
||||
|
||||
b.requestHandler = b.handleRequest
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Add manually adds a contact
|
||||
func (b *BootstrapNode) Add(c Contact) {
|
||||
b.upsert(c)
|
||||
}
|
||||
|
||||
// Connect connects to the given connection and starts any background threads necessary
|
||||
func (b *BootstrapNode) Connect(conn UDPConn) error {
|
||||
err := b.Node.Connect(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("[%s] bootstrap: node connected", b.id.HexShort())
|
||||
|
||||
go func() {
|
||||
t := time.NewTicker(b.checkInterval / 5)
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
b.check()
|
||||
case <-b.grp.Ch():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// upsert adds the contact to the list, or updates the lastPinged time
|
||||
func (b *BootstrapNode) upsert(c Contact) {
|
||||
b.nlock.Lock()
|
||||
defer b.nlock.Unlock()
|
||||
|
||||
if peer, exists := b.peers[c.ID]; exists {
|
||||
log.Debugf("[%s] bootstrap: touching contact %s", b.id.HexShort(), peer.Contact.ID.HexShort())
|
||||
peer.Touch()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("[%s] bootstrap: adding new contact %s", b.id.HexShort(), c.ID.HexShort())
|
||||
b.peers[c.ID] = &peer{c, b.id.Xor(c.ID), time.Now(), 0}
|
||||
b.nodeIDs = append(b.nodeIDs, c.ID)
|
||||
}
|
||||
|
||||
// remove removes the contact from the list
|
||||
func (b *BootstrapNode) remove(c Contact) {
|
||||
b.nlock.Lock()
|
||||
defer b.nlock.Unlock()
|
||||
|
||||
_, exists := b.peers[c.ID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("[%s] bootstrap: removing contact %s", b.id.HexShort(), c.ID.HexShort())
|
||||
delete(b.peers, c.ID)
|
||||
for i := range b.nodeIDs {
|
||||
if b.nodeIDs[i].Equals(c.ID) {
|
||||
b.nodeIDs = append(b.nodeIDs[:i], b.nodeIDs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// get returns up to `limit` random contacts from the list
|
||||
func (b *BootstrapNode) get(limit int) []Contact {
|
||||
b.nlock.RLock()
|
||||
defer b.nlock.RUnlock()
|
||||
|
||||
if len(b.peers) < limit {
|
||||
limit = len(b.peers)
|
||||
}
|
||||
|
||||
ret := make([]Contact, limit)
|
||||
for i, k := range randKeys(len(b.nodeIDs))[:limit] {
|
||||
ret[i] = b.peers[b.nodeIDs[k]].Contact
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// ping pings a node. if the node responds, it is added to the list. otherwise, it is removed
|
||||
func (b *BootstrapNode) ping(c Contact) {
|
||||
log.Debugf("[%s] bootstrap: pinging %s", b.id.HexShort(), c.ID.HexShort())
|
||||
b.grp.Add(1)
|
||||
defer b.grp.Done()
|
||||
|
||||
resCh := b.SendAsync(c, Request{Method: pingMethod})
|
||||
|
||||
var res *Response
|
||||
|
||||
select {
|
||||
case res = <-resCh:
|
||||
case <-b.grp.Ch():
|
||||
return
|
||||
}
|
||||
|
||||
if res != nil && res.Data == pingSuccessResponse {
|
||||
b.upsert(c)
|
||||
} else {
|
||||
b.remove(c)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BootstrapNode) check() {
|
||||
b.nlock.RLock()
|
||||
defer b.nlock.RUnlock()
|
||||
|
||||
for i := range b.peers {
|
||||
if !b.peers[i].ActiveInLast(b.checkInterval) {
|
||||
go b.ping(b.peers[i].Contact)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest handles the requests received from udp.
|
||||
func (b *BootstrapNode) handleRequest(addr *net.UDPAddr, request Request) {
|
||||
switch request.Method {
|
||||
case pingMethod:
|
||||
err := b.sendMessage(addr, Response{ID: request.ID, NodeID: b.id, Data: pingSuccessResponse})
|
||||
if err != nil {
|
||||
log.Error("error sending response message - ", err)
|
||||
}
|
||||
case findNodeMethod:
|
||||
if request.Arg == nil {
|
||||
log.Errorln("request is missing arg")
|
||||
return
|
||||
}
|
||||
|
||||
err := b.sendMessage(addr, Response{
|
||||
ID: request.ID,
|
||||
NodeID: b.id,
|
||||
Contacts: b.get(bucketSize),
|
||||
})
|
||||
if err != nil {
|
||||
log.Error("error sending 'findnodemethod' response message - ", err)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
b.nlock.RLock()
|
||||
_, exists := b.peers[request.NodeID]
|
||||
b.nlock.RUnlock()
|
||||
if !exists {
|
||||
log.Debugf("[%s] bootstrap: queuing %s to ping", b.id.HexShort(), request.NodeID.HexShort())
|
||||
<-time.After(b.initialPingInterval)
|
||||
b.nlock.RLock()
|
||||
_, exists = b.peers[request.NodeID]
|
||||
b.nlock.RUnlock()
|
||||
if !exists {
|
||||
b.ping(Contact{ID: request.NodeID, IP: addr.IP, Port: addr.Port})
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func randKeys(max int) []int {
|
||||
keys := make([]int, max)
|
||||
for k := range keys {
|
||||
keys[k] = k
|
||||
}
|
||||
rand.Shuffle(max, func(i, j int) {
|
||||
keys[i], keys[j] = keys[j], keys[i]
|
||||
})
|
||||
return keys
|
||||
}
|
24
dht/bootstrap_test.go
Normal file
24
dht/bootstrap_test.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
)
|
||||
|
||||
func TestBootstrapPing(t *testing.T) {
|
||||
b := NewBootstrapNode(bits.Rand(), 10, bootstrapDefaultRefreshDuration)
|
||||
|
||||
listener, err := net.ListenPacket(Network, "127.0.0.1:54320")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = b.Connect(listener.(*net.UDPConn))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
b.Shutdown()
|
||||
}
|
76
dht/config.go
Normal file
76
dht/config.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
peerproto "github.com/lbryio/reflector.go/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
Network = "udp4"
|
||||
DefaultPort = 4444
|
||||
|
||||
DefaultAnnounceRate = 10 // send at most this many announces per second
|
||||
DefaultReannounceTime = 50 * time.Minute // should be a bit less than hash expiration time
|
||||
|
||||
// TODO: all these constants should be defaults, and should be used to set values in the standard Config. then the code should use values in the config
|
||||
// TODO: alternatively, have a global Config for constants. at least that way tests can modify the values
|
||||
alpha = 3 // this is the constant alpha in the spec
|
||||
bucketSize = 8 // this is the constant k in the spec
|
||||
nodeIDLength = bits.NumBytes // bytes. this is the constant B in the spec
|
||||
messageIDLength = 20 // bytes.
|
||||
|
||||
udpRetry = 1
|
||||
udpTimeout = 5 * time.Second
|
||||
udpMaxMessageLength = 4096 // bytes. I think our longest message is ~676 bytes, so I rounded up to 1024
|
||||
// scratch that. a findValue could return more than K results if a lot of nodes are storing that value, so we need more buffer
|
||||
|
||||
maxPeerFails = 3 // after this many failures, a peer is considered bad and will be removed from the routing table
|
||||
//tExpire = 60 * time.Minute // the time after which a key/value pair expires; this is a time-to-live (TTL) from the original publication date
|
||||
tRefresh = 1 * time.Hour // the time after which an otherwise unaccessed bucket must be refreshed
|
||||
//tReplicate = 1 * time.Hour // the interval between Kademlia replication events, when a node is required to publish its entire database
|
||||
//tNodeRefresh = 15 * time.Minute // the time after which a good node becomes questionable if it has not messaged us
|
||||
|
||||
compactNodeInfoLength = nodeIDLength + 6 // nodeID + 4 for IP + 2 for port
|
||||
|
||||
tokenSecretRotationInterval = 5 * time.Minute // how often the token-generating secret is rotated
|
||||
)
|
||||
|
||||
// Config represents the configure of dht.
|
||||
type Config struct {
|
||||
// this node's address. format is `ip:port`
|
||||
Address string
|
||||
// the seed nodes through which we can join in dht network
|
||||
SeedNodes []string
|
||||
// the hex-encoded node id for this node. if string is empty, a random id will be generated
|
||||
NodeID string
|
||||
// print the state of the dht every X time
|
||||
PrintState time.Duration
|
||||
// the port that clients can use to download blobs using the LBRY peer protocol
|
||||
PeerProtocolPort int
|
||||
// if nonzero, an RPC server will listen to requests on this port and respond to them
|
||||
RPCPort int
|
||||
// the time after which the original publisher must reannounce a key/value pair
|
||||
ReannounceTime time.Duration
|
||||
// send at most this many announces per second
|
||||
AnnounceRate int
|
||||
// channel that will receive notifications about announcements
|
||||
AnnounceNotificationCh chan announceNotification
|
||||
}
|
||||
|
||||
// NewStandardConfig returns a Config pointer with default values.
|
||||
func NewStandardConfig() *Config {
|
||||
return &Config{
|
||||
Address: "0.0.0.0:" + strconv.Itoa(DefaultPort),
|
||||
SeedNodes: []string{
|
||||
"lbrynet1.lbry.io:4444",
|
||||
"lbrynet2.lbry.io:4444",
|
||||
"lbrynet3.lbry.io:4444",
|
||||
},
|
||||
PeerProtocolPort: peerproto.DefaultPort,
|
||||
ReannounceTime: DefaultReannounceTime,
|
||||
AnnounceRate: DefaultAnnounceRate,
|
||||
}
|
||||
}
|
118
dht/contact.go
Normal file
118
dht/contact.go
Normal file
|
@ -0,0 +1,118 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
"github.com/lbryio/lbry.go/errors"
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
|
||||
"github.com/lyoshenka/bencode"
|
||||
)
|
||||
|
||||
// TODO: if routing table is ever empty (aka the node is isolated), it should re-bootstrap
|
||||
|
||||
// Contact contains information for contacting another node on the network
|
||||
type Contact struct {
|
||||
ID bits.Bitmap
|
||||
IP net.IP
|
||||
Port int // the udp port used for the dht
|
||||
PeerPort int // the tcp port a peer can be contacted on for blob requests
|
||||
}
|
||||
|
||||
// Equals returns true if two contacts are the same.
|
||||
func (c Contact) Equals(other Contact, checkID bool) bool {
|
||||
return c.IP.Equal(other.IP) && c.Port == other.Port && (!checkID || c.ID == other.ID)
|
||||
}
|
||||
|
||||
// Addr returns the address of the contact.
|
||||
func (c Contact) Addr() *net.UDPAddr {
|
||||
return &net.UDPAddr{IP: c.IP, Port: c.Port}
|
||||
}
|
||||
|
||||
// String returns a short string representation of the contact
|
||||
func (c Contact) String() string {
|
||||
str := c.ID.HexShort() + "@" + c.Addr().String()
|
||||
if c.PeerPort != 0 {
|
||||
str += "(" + strconv.Itoa(c.PeerPort) + ")"
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// MarshalCompact returns a compact byteslice representation of the contact
|
||||
// NOTE: The compact representation always uses the tcp PeerPort, not the udp Port. This is dumb, but that's how the python daemon does it
|
||||
func (c Contact) MarshalCompact() ([]byte, error) {
|
||||
if c.IP.To4() == nil {
|
||||
return nil, errors.Err("ip not set")
|
||||
}
|
||||
if c.PeerPort < 0 || c.PeerPort > 65535 {
|
||||
return nil, errors.Err("invalid port")
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.Write(c.IP.To4())
|
||||
buf.WriteByte(byte(c.PeerPort >> 8))
|
||||
buf.WriteByte(byte(c.PeerPort))
|
||||
buf.Write(c.ID[:])
|
||||
|
||||
if buf.Len() != compactNodeInfoLength {
|
||||
return nil, errors.Err("i dont know how this happened")
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// UnmarshalCompact unmarshals the compact byteslice representation of a contact.
|
||||
// NOTE: The compact representation always uses the tcp PeerPort, not the udp Port. This is dumb, but that's how the python daemon does it
|
||||
func (c *Contact) UnmarshalCompact(b []byte) error {
|
||||
if len(b) != compactNodeInfoLength {
|
||||
return errors.Err("invalid compact length")
|
||||
}
|
||||
c.IP = net.IPv4(b[0], b[1], b[2], b[3]).To4()
|
||||
c.PeerPort = int(uint16(b[5]) | uint16(b[4])<<8)
|
||||
c.ID = bits.FromBytesP(b[6:])
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalBencode returns the serialized byte slice representation of a contact.
|
||||
func (c Contact) MarshalBencode() ([]byte, error) {
|
||||
return bencode.EncodeBytes([]interface{}{c.ID, c.IP.String(), c.Port})
|
||||
}
|
||||
|
||||
// UnmarshalBencode unmarshals the serialized byte slice into the appropriate fields of the contact.
|
||||
func (c *Contact) UnmarshalBencode(b []byte) error {
|
||||
var raw []bencode.RawMessage
|
||||
err := bencode.DecodeBytes(b, &raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(raw) != 3 {
|
||||
return errors.Err("contact must have 3 elements; got %d", len(raw))
|
||||
}
|
||||
|
||||
err = bencode.DecodeBytes(raw[0], &c.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ipStr string
|
||||
err = bencode.DecodeBytes(raw[1], &ipStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.IP = net.ParseIP(ipStr).To4()
|
||||
if c.IP == nil {
|
||||
return errors.Err("invalid IP")
|
||||
}
|
||||
|
||||
return bencode.DecodeBytes(raw[2], &c.Port)
|
||||
}
|
||||
|
||||
func sortByDistance(contacts []Contact, target bits.Bitmap) {
|
||||
sort.Slice(contacts, func(i, j int) bool {
|
||||
return contacts[i].ID.Xor(target).Cmp(contacts[j].ID.Xor(target)) < 0
|
||||
})
|
||||
}
|
31
dht/contact_test.go
Normal file
31
dht/contact_test.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
)
|
||||
|
||||
func TestCompactEncoding(t *testing.T) {
|
||||
c := Contact{
|
||||
ID: bits.FromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
PeerPort: int(55<<8 + 66),
|
||||
}
|
||||
|
||||
var compact []byte
|
||||
compact, err := c.MarshalCompact()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(compact) != compactNodeInfoLength {
|
||||
t.Fatalf("got length of %d; expected %d", len(compact), compactNodeInfoLength)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(compact, append([]byte{1, 2, 3, 4, 55, 66}, c.ID[:]...)) {
|
||||
t.Errorf("compact bytes not encoded correctly")
|
||||
}
|
||||
}
|
232
dht/dht.go
Normal file
232
dht/dht.go
Normal file
|
@ -0,0 +1,232 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
|
||||
"github.com/lbryio/lbry.go/errors"
|
||||
"github.com/lbryio/lbry.go/stop"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
var log *logrus.Logger
|
||||
|
||||
func UseLogger(l *logrus.Logger) {
|
||||
log = l
|
||||
}
|
||||
|
||||
func init() {
|
||||
log = logrus.StandardLogger()
|
||||
//log.SetFormatter(&log.TextFormatter{ForceColors: true})
|
||||
//log.SetLevel(log.DebugLevel)
|
||||
}
|
||||
|
||||
// DHT represents a DHT node.
|
||||
type DHT struct {
|
||||
// config
|
||||
conf *Config
|
||||
// local contact
|
||||
contact Contact
|
||||
// node
|
||||
node *Node
|
||||
// stopGroup to shut down DHT
|
||||
grp *stop.Group
|
||||
// channel is closed when DHT joins network
|
||||
joined chan struct{}
|
||||
// cache for store tokens
|
||||
tokenCache *tokenCache
|
||||
// hashes that need to be put into the announce queue or removed from the queue
|
||||
announceAddRemove chan queueEdit
|
||||
}
|
||||
|
||||
// New returns a DHT pointer. If config is nil, then config will be set to the default config.
|
||||
func New(config *Config) *DHT {
|
||||
if config == nil {
|
||||
config = NewStandardConfig()
|
||||
}
|
||||
|
||||
d := &DHT{
|
||||
conf: config,
|
||||
grp: stop.New(),
|
||||
joined: make(chan struct{}),
|
||||
announceAddRemove: make(chan queueEdit),
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func (dht *DHT) connect(conn UDPConn) error {
|
||||
contact, err := getContact(dht.conf.NodeID, dht.conf.Address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dht.contact = contact
|
||||
dht.node = NewNode(contact.ID)
|
||||
dht.tokenCache = newTokenCache(dht.node, tokenSecretRotationInterval)
|
||||
|
||||
return dht.node.Connect(conn)
|
||||
}
|
||||
|
||||
// Start starts the dht
|
||||
func (dht *DHT) Start() error {
|
||||
listener, err := net.ListenPacket(Network, dht.conf.Address)
|
||||
if err != nil {
|
||||
return errors.Err(err)
|
||||
}
|
||||
conn := listener.(*net.UDPConn)
|
||||
|
||||
err = dht.connect(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dht.join()
|
||||
log.Infof("[%s] DHT ready on %s (%d nodes found during join)",
|
||||
dht.node.id.HexShort(), dht.contact.Addr().String(), dht.node.rt.Count())
|
||||
|
||||
dht.grp.Add(1)
|
||||
go func() {
|
||||
dht.runAnnouncer()
|
||||
dht.grp.Done()
|
||||
}()
|
||||
|
||||
if dht.conf.RPCPort > 0 {
|
||||
dht.grp.Add(1)
|
||||
go func() {
|
||||
dht.runRPCServer(dht.conf.RPCPort)
|
||||
dht.grp.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// join makes current node join the dht network.
|
||||
func (dht *DHT) join() {
|
||||
defer close(dht.joined) // if anyone's waiting for join to finish, they'll know its done
|
||||
|
||||
log.Infof("[%s] joining DHT network", dht.node.id.HexShort())
|
||||
|
||||
// ping nodes, which gets their real node IDs and adds them to the routing table
|
||||
atLeastOneNodeResponded := false
|
||||
for _, addr := range dht.conf.SeedNodes {
|
||||
err := dht.Ping(addr)
|
||||
if err != nil {
|
||||
log.Error(errors.Prefix(fmt.Sprintf("[%s] join", dht.node.id.HexShort()), err))
|
||||
} else {
|
||||
atLeastOneNodeResponded = true
|
||||
}
|
||||
}
|
||||
|
||||
if !atLeastOneNodeResponded {
|
||||
log.Errorf("[%s] join: no nodes responded to initial ping", dht.node.id.HexShort())
|
||||
return
|
||||
}
|
||||
|
||||
// now call iterativeFind on yourself
|
||||
_, _, err := FindContacts(dht.node, dht.node.id, false, dht.grp.Child())
|
||||
if err != nil {
|
||||
log.Errorf("[%s] join: %s", dht.node.id.HexShort(), err.Error())
|
||||
}
|
||||
|
||||
// TODO: after joining, refresh all buckets further away than our closest neighbor
|
||||
// http://xlattice.sourceforge.net/components/protocol/kademlia/specs.html#join
|
||||
}
|
||||
|
||||
// WaitUntilJoined blocks until the node joins the network.
|
||||
func (dht *DHT) WaitUntilJoined() {
|
||||
if dht.joined == nil {
|
||||
panic("dht not initialized")
|
||||
}
|
||||
<-dht.joined
|
||||
}
|
||||
|
||||
// Shutdown shuts down the dht
|
||||
func (dht *DHT) Shutdown() {
|
||||
log.Debugf("[%s] DHT shutting down", dht.node.id.HexShort())
|
||||
dht.grp.StopAndWait()
|
||||
dht.node.Shutdown()
|
||||
log.Debugf("[%s] DHT stopped", dht.node.id.HexShort())
|
||||
}
|
||||
|
||||
// Ping pings a given address, creates a temporary contact for sending a message, and returns an error if communication
|
||||
// fails.
|
||||
func (dht *DHT) Ping(addr string) error {
|
||||
raddr, err := net.ResolveUDPAddr(Network, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmpNode := Contact{ID: bits.Rand(), IP: raddr.IP, Port: raddr.Port}
|
||||
res := dht.node.Send(tmpNode, Request{Method: pingMethod}, SendOptions{skipIDCheck: true})
|
||||
if res == nil {
|
||||
return errors.Err("no response from node %s", addr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the list of nodes that have the blob for the given hash
|
||||
func (dht *DHT) Get(hash bits.Bitmap) ([]Contact, error) {
|
||||
contacts, found, err := FindContacts(dht.node, hash, true, dht.grp.Child())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if found {
|
||||
return contacts, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// PrintState prints the current state of the DHT including address, nr outstanding transactions, stored hashes as well
|
||||
// as current bucket information.
|
||||
func (dht *DHT) PrintState() {
|
||||
log.Printf("DHT node %s at %s", dht.contact.String(), time.Now().Format(time.RFC822Z))
|
||||
log.Printf("Outstanding transactions: %d", dht.node.CountActiveTransactions())
|
||||
log.Printf("Stored hashes: %d", dht.node.store.CountStoredHashes())
|
||||
log.Printf("Buckets:")
|
||||
for _, line := range strings.Split(dht.node.rt.BucketInfo(), "\n") {
|
||||
log.Println(line)
|
||||
}
|
||||
}
|
||||
|
||||
func (dht DHT) ID() bits.Bitmap {
|
||||
return dht.contact.ID
|
||||
}
|
||||
|
||||
func getContact(nodeID, addr string) (Contact, error) {
|
||||
var c Contact
|
||||
if nodeID == "" {
|
||||
c.ID = bits.Rand()
|
||||
} else {
|
||||
c.ID = bits.FromHexP(nodeID)
|
||||
}
|
||||
|
||||
ip, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return c, errors.Err(err)
|
||||
} else if ip == "" {
|
||||
return c, errors.Err("address does not contain an IP")
|
||||
} else if port == "" {
|
||||
return c, errors.Err("address does not contain a port")
|
||||
}
|
||||
|
||||
c.IP = net.ParseIP(ip)
|
||||
if c.IP == nil {
|
||||
return c, errors.Err("invalid ip")
|
||||
}
|
||||
|
||||
c.Port, err = cast.ToIntE(port)
|
||||
if err != nil {
|
||||
return c, errors.Err(err)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
214
dht/dht_announce.go
Normal file
214
dht/dht_announce.go
Normal file
|
@ -0,0 +1,214 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"container/ring"
|
||||
"context"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lbryio/lbry.go/errors"
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type queueEdit struct {
|
||||
hash bits.Bitmap
|
||||
add bool
|
||||
}
|
||||
|
||||
const (
|
||||
announceStarted = "started"
|
||||
announceFinishd = "finished"
|
||||
)
|
||||
|
||||
type announceNotification struct {
|
||||
hash bits.Bitmap
|
||||
action string
|
||||
err error
|
||||
}
|
||||
|
||||
// Add adds the hash to the list of hashes this node is announcing
|
||||
func (dht *DHT) Add(hash bits.Bitmap) {
|
||||
dht.announceAddRemove <- queueEdit{hash: hash, add: true}
|
||||
}
|
||||
|
||||
// Remove removes the hash from the list of hashes this node is announcing
|
||||
func (dht *DHT) Remove(hash bits.Bitmap) {
|
||||
dht.announceAddRemove <- queueEdit{hash: hash, add: false}
|
||||
}
|
||||
|
||||
func (dht *DHT) runAnnouncer() {
|
||||
type hashAndTime struct {
|
||||
hash bits.Bitmap
|
||||
lastAnnounce time.Time
|
||||
}
|
||||
|
||||
var queue *ring.Ring
|
||||
hashes := make(map[bits.Bitmap]*ring.Ring)
|
||||
|
||||
var announceNextHash <-chan time.Time
|
||||
timer := time.NewTimer(math.MaxInt64)
|
||||
timer.Stop()
|
||||
|
||||
limitCh := make(chan time.Time)
|
||||
dht.grp.Add(1)
|
||||
go func() {
|
||||
defer dht.grp.Done()
|
||||
limiter := rate.NewLimiter(rate.Limit(dht.conf.AnnounceRate), dht.conf.AnnounceRate)
|
||||
for {
|
||||
err := limiter.Wait(context.Background()) // TODO: should use grp.ctx somehow? so when grp is closed, wait returns
|
||||
if err != nil {
|
||||
log.Error(errors.Prefix("rate limiter", err))
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case limitCh <- time.Now():
|
||||
case <-dht.grp.Ch():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
maintenance := time.NewTicker(1 * time.Minute)
|
||||
|
||||
// TODO: work to space hash announces out so they aren't bunched up around the reannounce time. track time since last announce. if its been more than the ideal time (reannounce time / numhashes), start announcing hashes early
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-dht.grp.Ch():
|
||||
return
|
||||
|
||||
case <-maintenance.C:
|
||||
maxAnnounce := dht.conf.AnnounceRate * int(dht.conf.ReannounceTime.Seconds())
|
||||
if len(hashes) > maxAnnounce {
|
||||
// TODO: send this to slack
|
||||
log.Warnf("DHT has %d hashes, but can only announce %d hashes in the %s reannounce window. Raise the announce rate or spawn more nodes.",
|
||||
len(hashes), maxAnnounce, dht.conf.ReannounceTime.String())
|
||||
}
|
||||
|
||||
case change := <-dht.announceAddRemove:
|
||||
if change.add {
|
||||
if _, exists := hashes[change.hash]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
r := ring.New(1)
|
||||
r.Value = hashAndTime{hash: change.hash}
|
||||
if queue != nil {
|
||||
queue.Prev().Link(r)
|
||||
}
|
||||
queue = r
|
||||
hashes[change.hash] = r
|
||||
announceNextHash = limitCh // announce next hash ASAP
|
||||
} else {
|
||||
r, exists := hashes[change.hash]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
delete(hashes, change.hash)
|
||||
|
||||
if len(hashes) == 0 {
|
||||
queue = ring.New(0)
|
||||
announceNextHash = nil // no hashes to announce, wait indefinitely
|
||||
} else {
|
||||
if r == queue {
|
||||
queue = queue.Next() // don't lose our pointer
|
||||
}
|
||||
r.Prev().Link(r.Next())
|
||||
}
|
||||
}
|
||||
|
||||
case <-announceNextHash:
|
||||
dht.grp.Add(1)
|
||||
ht := queue.Value.(hashAndTime)
|
||||
|
||||
if !ht.lastAnnounce.IsZero() {
|
||||
nextAnnounce := ht.lastAnnounce.Add(dht.conf.ReannounceTime)
|
||||
if nextAnnounce.After(time.Now()) {
|
||||
timer.Reset(time.Until(nextAnnounce))
|
||||
announceNextHash = timer.C // wait until next hash should be announced
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if dht.conf.AnnounceNotificationCh != nil {
|
||||
dht.conf.AnnounceNotificationCh <- announceNotification{
|
||||
hash: ht.hash,
|
||||
action: announceStarted,
|
||||
}
|
||||
}
|
||||
|
||||
go func(hash bits.Bitmap) {
|
||||
defer dht.grp.Done()
|
||||
err := dht.announce(hash)
|
||||
if err != nil {
|
||||
log.Error(errors.Prefix("announce", err))
|
||||
}
|
||||
|
||||
if dht.conf.AnnounceNotificationCh != nil {
|
||||
dht.conf.AnnounceNotificationCh <- announceNotification{
|
||||
hash: ht.hash,
|
||||
action: announceFinishd,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
}(ht.hash)
|
||||
|
||||
queue.Value = hashAndTime{hash: ht.hash, lastAnnounce: time.Now()}
|
||||
queue = queue.Next()
|
||||
announceNextHash = limitCh // announce next hash ASAP
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Announce announces to the DHT that this node has the blob for the given hash
|
||||
func (dht *DHT) announce(hash bits.Bitmap) error {
|
||||
contacts, _, err := FindContacts(dht.node, hash, false, dht.grp.Child())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// self-store if we found less than K contacts, or we're closer than the farthest contact
|
||||
if len(contacts) < bucketSize {
|
||||
contacts = append(contacts, dht.contact)
|
||||
} else if hash.Closer(dht.node.id, contacts[bucketSize-1].ID) {
|
||||
contacts[bucketSize-1] = dht.contact
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
for _, c := range contacts {
|
||||
wg.Add(1)
|
||||
go func(c Contact) {
|
||||
dht.store(hash, c)
|
||||
wg.Done()
|
||||
}(c)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dht *DHT) store(hash bits.Bitmap, c Contact) {
|
||||
if dht.contact.ID == c.ID {
|
||||
// self-store
|
||||
c.PeerPort = dht.conf.PeerProtocolPort
|
||||
dht.node.Store(hash, c)
|
||||
return
|
||||
}
|
||||
|
||||
dht.node.SendAsync(c, Request{
|
||||
Method: storeMethod,
|
||||
StoreArgs: &storeArgs{
|
||||
BlobHash: hash,
|
||||
Value: storeArgsValue{
|
||||
Token: dht.tokenCache.Get(c, hash, dht.grp.Ch()),
|
||||
LbryID: dht.contact.ID,
|
||||
Port: dht.conf.PeerProtocolPort,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
181
dht/dht_test.go
Normal file
181
dht/dht_test.go
Normal file
|
@ -0,0 +1,181 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
)
|
||||
|
||||
func TestNodeFinder_FindNodes(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow nodeFinder test")
|
||||
}
|
||||
|
||||
bs, dhts := TestingCreateNetwork(t, 3, true, false)
|
||||
defer func() {
|
||||
for i := range dhts {
|
||||
dhts[i].Shutdown()
|
||||
}
|
||||
bs.Shutdown()
|
||||
}()
|
||||
|
||||
contacts, found, err := FindContacts(dhts[2].node, bits.Rand(), false, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if found {
|
||||
t.Fatal("something was found, but it should not have been")
|
||||
}
|
||||
|
||||
if len(contacts) != 3 {
|
||||
t.Errorf("expected 3 node, found %d", len(contacts))
|
||||
}
|
||||
|
||||
foundBootstrap := false
|
||||
foundOne := false
|
||||
foundTwo := false
|
||||
|
||||
for _, n := range contacts {
|
||||
if n.ID.Equals(bs.id) {
|
||||
foundBootstrap = true
|
||||
}
|
||||
if n.ID.Equals(dhts[0].node.id) {
|
||||
foundOne = true
|
||||
}
|
||||
if n.ID.Equals(dhts[1].node.id) {
|
||||
foundTwo = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundBootstrap {
|
||||
t.Errorf("did not find bootstrap node %s", bs.id.Hex())
|
||||
}
|
||||
if !foundOne {
|
||||
t.Errorf("did not find first node %s", dhts[0].node.id.Hex())
|
||||
}
|
||||
if !foundTwo {
|
||||
t.Errorf("did not find second node %s", dhts[1].node.id.Hex())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeFinder_FindNodes_NoBootstrap(t *testing.T) {
|
||||
_, dhts := TestingCreateNetwork(t, 3, false, false)
|
||||
defer func() {
|
||||
for i := range dhts {
|
||||
dhts[i].Shutdown()
|
||||
}
|
||||
}()
|
||||
|
||||
_, _, err := FindContacts(dhts[2].node, bits.Rand(), false, nil)
|
||||
if err == nil {
|
||||
t.Fatal("contact finder should have errored saying that there are no contacts in the routing table")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeFinder_FindValue(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow nodeFinder test")
|
||||
}
|
||||
|
||||
bs, dhts := TestingCreateNetwork(t, 3, true, false)
|
||||
defer func() {
|
||||
for i := range dhts {
|
||||
dhts[i].Shutdown()
|
||||
}
|
||||
bs.Shutdown()
|
||||
}()
|
||||
|
||||
blobHashToFind := bits.Rand()
|
||||
nodeToFind := Contact{ID: bits.Rand(), IP: net.IPv4(1, 2, 3, 4), Port: 5678}
|
||||
dhts[0].node.store.Upsert(blobHashToFind, nodeToFind)
|
||||
|
||||
contacts, found, err := FindContacts(dhts[2].node, blobHashToFind, true, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Fatal("node was not found")
|
||||
}
|
||||
|
||||
if len(contacts) != 1 {
|
||||
t.Fatalf("expected one node, found %d", len(contacts))
|
||||
}
|
||||
|
||||
if !contacts[0].ID.Equals(nodeToFind.ID) {
|
||||
t.Fatalf("found node id %s, expected %s", contacts[0].ID.Hex(), nodeToFind.ID.Hex())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDHT_LargeDHT(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping large DHT test")
|
||||
}
|
||||
|
||||
nodes := 100
|
||||
bs, dhts := TestingCreateNetwork(t, nodes, true, true)
|
||||
defer func() {
|
||||
for _, d := range dhts {
|
||||
go d.Shutdown()
|
||||
}
|
||||
bs.Shutdown()
|
||||
time.Sleep(1 * time.Second)
|
||||
}()
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
ids := make([]bits.Bitmap, nodes)
|
||||
for i := range ids {
|
||||
ids[i] = bits.Rand()
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
err := dhts[index].announce(ids[index])
|
||||
if err != nil {
|
||||
t.Error("error announcing random bitmap - ", err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// check that each node is in at learst 1 other routing table
|
||||
rtCounts := make(map[bits.Bitmap]int)
|
||||
for _, d := range dhts {
|
||||
for _, d2 := range dhts {
|
||||
if d.node.id.Equals(d2.node.id) {
|
||||
continue
|
||||
}
|
||||
c := d2.node.rt.GetClosest(d.node.id, 1)
|
||||
if len(c) > 1 {
|
||||
t.Error("rt returned more than one node when only one requested")
|
||||
} else if len(c) == 1 && c[0].ID.Equals(d.node.id) {
|
||||
rtCounts[d.node.id]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range rtCounts {
|
||||
if v == 0 {
|
||||
t.Errorf("%s was not in any routing tables", k.HexShort())
|
||||
}
|
||||
}
|
||||
|
||||
// check that each ID is stored by at least 3 nodes
|
||||
storeCounts := make(map[bits.Bitmap]int)
|
||||
for _, d := range dhts {
|
||||
for _, id := range ids {
|
||||
if len(d.node.store.Get(id)) > 0 {
|
||||
storeCounts[id]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range storeCounts {
|
||||
if v == 0 {
|
||||
t.Errorf("%s was not stored by any nodes", k.HexShort())
|
||||
}
|
||||
}
|
||||
}
|
3060
dht/fixtures/TestRoutingTable_Save.golden
Normal file
3060
dht/fixtures/TestRoutingTable_Save.golden
Normal file
File diff suppressed because it is too large
Load diff
463
dht/message.go
Normal file
463
dht/message.go
Normal file
|
@ -0,0 +1,463 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/lbryio/lbry.go/errors"
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
|
||||
"github.com/lyoshenka/bencode"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
const (
|
||||
pingMethod = "ping"
|
||||
storeMethod = "store"
|
||||
findNodeMethod = "findNode"
|
||||
findValueMethod = "findValue"
|
||||
)
|
||||
|
||||
const (
|
||||
pingSuccessResponse = "pong"
|
||||
storeSuccessResponse = "OK"
|
||||
)
|
||||
|
||||
const (
|
||||
requestType = 0
|
||||
responseType = 1
|
||||
errorType = 2
|
||||
)
|
||||
|
||||
const (
|
||||
// these are strings because bencode requires bytestring keys
|
||||
headerTypeField = "0"
|
||||
headerMessageIDField = "1" // message id is 20 bytes long
|
||||
headerNodeIDField = "2" // node id is 48 bytes long
|
||||
headerPayloadField = "3"
|
||||
headerArgsField = "4"
|
||||
contactsField = "contacts"
|
||||
tokenField = "token"
|
||||
protocolVersionField = "protocolVersion"
|
||||
)
|
||||
|
||||
// Message is a DHT message
|
||||
type Message interface {
|
||||
bencode.Marshaler
|
||||
}
|
||||
|
||||
type messageID [messageIDLength]byte
|
||||
|
||||
// HexShort returns the first 8 hex characters of the hex encoded message id.
|
||||
func (m messageID) HexShort() string {
|
||||
return hex.EncodeToString(m[:])[:8]
|
||||
}
|
||||
|
||||
// UnmarshalBencode takes a byte slice and unmarshals the message id.
|
||||
func (m *messageID) UnmarshalBencode(encoded []byte) error {
|
||||
var str string
|
||||
err := bencode.DecodeBytes(encoded, &str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
copy(m[:], str)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshallBencode returns the encoded byte slice of the message id.
|
||||
func (m messageID) MarshalBencode() ([]byte, error) {
|
||||
str := string(m[:])
|
||||
return bencode.EncodeBytes(str)
|
||||
}
|
||||
|
||||
func newMessageID() messageID {
|
||||
var m messageID
|
||||
_, err := rand.Read(m[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Request represents a DHT request message
|
||||
type Request struct {
|
||||
ID messageID
|
||||
NodeID bits.Bitmap
|
||||
Method string
|
||||
Arg *bits.Bitmap
|
||||
StoreArgs *storeArgs
|
||||
ProtocolVersion int
|
||||
}
|
||||
|
||||
// MarshalBencode returns the serialized byte slice representation of the request
|
||||
func (r Request) MarshalBencode() ([]byte, error) {
|
||||
var args interface{}
|
||||
if r.StoreArgs != nil {
|
||||
args = r.StoreArgs
|
||||
} else if r.Arg != nil {
|
||||
args = []bits.Bitmap{*r.Arg}
|
||||
} else {
|
||||
args = []string{} // request must always have keys 0-4, so we use an empty list for PING
|
||||
}
|
||||
return bencode.EncodeBytes(map[string]interface{}{
|
||||
headerTypeField: requestType,
|
||||
headerMessageIDField: r.ID,
|
||||
headerNodeIDField: r.NodeID,
|
||||
headerPayloadField: r.Method,
|
||||
headerArgsField: args,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalBencode unmarshals the serialized byte slice into the appropriate fields of the request.
|
||||
func (r *Request) UnmarshalBencode(b []byte) error {
|
||||
var raw struct {
|
||||
ID messageID `bencode:"1"`
|
||||
NodeID bits.Bitmap `bencode:"2"`
|
||||
Method string `bencode:"3"`
|
||||
Args bencode.RawMessage `bencode:"4"`
|
||||
}
|
||||
err := bencode.DecodeBytes(b, &raw)
|
||||
if err != nil {
|
||||
return errors.Prefix("request unmarshal", err)
|
||||
}
|
||||
|
||||
r.ID = raw.ID
|
||||
r.NodeID = raw.NodeID
|
||||
r.Method = raw.Method
|
||||
|
||||
if r.Method == storeMethod {
|
||||
r.StoreArgs = &storeArgs{} // bencode wont find the unmarshaler on a null pointer. need to fix it.
|
||||
err = bencode.DecodeBytes(raw.Args, &r.StoreArgs)
|
||||
if err != nil {
|
||||
return errors.Prefix("request unmarshal", err)
|
||||
}
|
||||
} else if len(raw.Args) > 2 { // 2 because an empty list is `le`
|
||||
r.Arg, r.ProtocolVersion, err = processArgsAndProtoVersion(raw.Args)
|
||||
if err != nil {
|
||||
return errors.Prefix("request unmarshal", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func processArgsAndProtoVersion(raw bencode.RawMessage) (arg *bits.Bitmap, version int, err error) {
|
||||
var args []bencode.RawMessage
|
||||
err = bencode.DecodeBytes(raw, &args)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
var extras map[string]int
|
||||
err = bencode.DecodeBytes(args[len(args)-1], &extras)
|
||||
if err == nil {
|
||||
if v, exists := extras[protocolVersionField]; exists {
|
||||
version = v
|
||||
args = args[:len(args)-1]
|
||||
}
|
||||
}
|
||||
|
||||
if len(args) > 0 {
|
||||
var b bits.Bitmap
|
||||
err = bencode.DecodeBytes(args[0], &b)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
arg = &b
|
||||
}
|
||||
|
||||
return arg, version, nil
|
||||
}
|
||||
|
||||
func (r Request) argsDebug() string {
|
||||
if r.StoreArgs != nil {
|
||||
return r.StoreArgs.BlobHash.HexShort() + ", " + r.StoreArgs.Value.LbryID.HexShort() + ":" + strconv.Itoa(r.StoreArgs.Value.Port)
|
||||
} else if r.Arg != nil {
|
||||
return r.Arg.HexShort()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type storeArgsValue struct {
|
||||
Token string `bencode:"token"`
|
||||
LbryID bits.Bitmap `bencode:"lbryid"`
|
||||
Port int `bencode:"port"`
|
||||
}
|
||||
|
||||
type storeArgs struct {
|
||||
BlobHash bits.Bitmap
|
||||
Value storeArgsValue
|
||||
NodeID bits.Bitmap // original publisher id? I think this is getting fixed in the new dht stuff
|
||||
SelfStore bool // this is an int on the wire
|
||||
}
|
||||
|
||||
// MarshalBencode returns the serialized byte slice representation of the storage arguments.
|
||||
func (s storeArgs) MarshalBencode() ([]byte, error) {
|
||||
encodedValue, err := bencode.EncodeString(s.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
selfStoreStr := 0
|
||||
if s.SelfStore {
|
||||
selfStoreStr = 1
|
||||
}
|
||||
|
||||
return bencode.EncodeBytes([]interface{}{
|
||||
s.BlobHash,
|
||||
bencode.RawMessage(encodedValue),
|
||||
s.NodeID,
|
||||
selfStoreStr,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalBencode unmarshals the serialized byte slice into the appropriate fields of the store arguments.
|
||||
func (s *storeArgs) UnmarshalBencode(b []byte) error {
|
||||
var argsInt []bencode.RawMessage
|
||||
err := bencode.DecodeBytes(b, &argsInt)
|
||||
if err != nil {
|
||||
return errors.Prefix("storeArgs unmarshal", err)
|
||||
}
|
||||
|
||||
if len(argsInt) != 4 {
|
||||
return errors.Err("unexpected number of fields for store args. got " + cast.ToString(len(argsInt)))
|
||||
}
|
||||
|
||||
err = bencode.DecodeBytes(argsInt[0], &s.BlobHash)
|
||||
if err != nil {
|
||||
return errors.Prefix("storeArgs unmarshal", err)
|
||||
}
|
||||
|
||||
err = bencode.DecodeBytes(argsInt[1], &s.Value)
|
||||
if err != nil {
|
||||
return errors.Prefix("storeArgs unmarshal", err)
|
||||
}
|
||||
|
||||
err = bencode.DecodeBytes(argsInt[2], &s.NodeID)
|
||||
if err != nil {
|
||||
return errors.Prefix("storeArgs unmarshal", err)
|
||||
}
|
||||
|
||||
var selfStore int
|
||||
err = bencode.DecodeBytes(argsInt[3], &selfStore)
|
||||
if err != nil {
|
||||
return errors.Prefix("storeArgs unmarshal", err)
|
||||
}
|
||||
if selfStore == 0 {
|
||||
s.SelfStore = false
|
||||
} else if selfStore == 1 {
|
||||
s.SelfStore = true
|
||||
} else {
|
||||
return errors.Err("selfstore must be 1 or 0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Response represents a DHT response message
|
||||
type Response struct {
|
||||
ID messageID
|
||||
NodeID bits.Bitmap
|
||||
Data string
|
||||
Contacts []Contact
|
||||
FindValueKey string
|
||||
Token string
|
||||
ProtocolVersion int
|
||||
}
|
||||
|
||||
func (r Response) argsDebug() string {
|
||||
if r.Data != "" {
|
||||
return r.Data
|
||||
}
|
||||
|
||||
str := "contacts "
|
||||
if r.FindValueKey != "" {
|
||||
str = "value for " + hex.EncodeToString([]byte(r.FindValueKey))[:8] + " "
|
||||
}
|
||||
|
||||
str += "|"
|
||||
for _, c := range r.Contacts {
|
||||
str += c.String() + ","
|
||||
}
|
||||
str = strings.TrimRight(str, ",") + "|"
|
||||
|
||||
if r.Token != "" {
|
||||
str += " token: " + hex.EncodeToString([]byte(r.Token))[:8]
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
// MarshalBencode returns the serialized byte slice representation of the response.
|
||||
func (r Response) MarshalBencode() ([]byte, error) {
|
||||
data := map[string]interface{}{
|
||||
headerTypeField: responseType,
|
||||
headerMessageIDField: r.ID,
|
||||
headerNodeIDField: r.NodeID,
|
||||
}
|
||||
|
||||
if r.Data != "" {
|
||||
// ping or store
|
||||
data[headerPayloadField] = r.Data
|
||||
} else if r.FindValueKey != "" {
|
||||
// findValue success
|
||||
if r.Token == "" {
|
||||
return nil, errors.Err("response to findValue must have a token")
|
||||
}
|
||||
|
||||
var contacts [][]byte
|
||||
for _, c := range r.Contacts {
|
||||
compact, err := c.MarshalCompact()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
contacts = append(contacts, compact)
|
||||
}
|
||||
data[headerPayloadField] = map[string]interface{}{
|
||||
r.FindValueKey: contacts,
|
||||
tokenField: r.Token,
|
||||
}
|
||||
} else if r.Token != "" {
|
||||
// findValue failure falling back to findNode
|
||||
data[headerPayloadField] = map[string]interface{}{
|
||||
contactsField: r.Contacts,
|
||||
tokenField: r.Token,
|
||||
}
|
||||
} else {
|
||||
// straight up findNode
|
||||
data[headerPayloadField] = r.Contacts
|
||||
}
|
||||
|
||||
return bencode.EncodeBytes(data)
|
||||
}
|
||||
|
||||
// UnmarshalBencode unmarshals the serialized byte slice into the appropriate fields of the store arguments.
|
||||
func (r *Response) UnmarshalBencode(b []byte) error {
|
||||
var raw struct {
|
||||
ID messageID `bencode:"1"`
|
||||
NodeID bits.Bitmap `bencode:"2"`
|
||||
Data bencode.RawMessage `bencode:"3"`
|
||||
}
|
||||
err := bencode.DecodeBytes(b, &raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.ID = raw.ID
|
||||
r.NodeID = raw.NodeID
|
||||
|
||||
// maybe data is a string (response to ping or store)?
|
||||
err = bencode.DecodeBytes(raw.Data, &r.Data)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// maybe data is a list of contacts (response to findNode)?
|
||||
err = bencode.DecodeBytes(raw.Data, &r.Contacts)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// it must be a response to findValue
|
||||
var rawData map[string]bencode.RawMessage
|
||||
err = bencode.DecodeBytes(raw.Data, &rawData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if token, ok := rawData[tokenField]; ok {
|
||||
err = bencode.DecodeBytes(token, &r.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delete(rawData, tokenField) // so it doesnt mess up findValue key finding below
|
||||
}
|
||||
|
||||
if protocolVersion, ok := rawData[protocolVersionField]; ok {
|
||||
err = bencode.DecodeBytes(protocolVersion, &r.ProtocolVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delete(rawData, protocolVersionField) // so it doesnt mess up findValue key finding below
|
||||
}
|
||||
|
||||
if contacts, ok := rawData[contactsField]; ok {
|
||||
err = bencode.DecodeBytes(contacts, &r.Contacts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
for k, v := range rawData {
|
||||
r.FindValueKey = k
|
||||
var compactContacts [][]byte
|
||||
err = bencode.DecodeBytes(v, &compactContacts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, compact := range compactContacts {
|
||||
var c Contact
|
||||
err = c.UnmarshalCompact(compact)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Contacts = append(r.Contacts, c)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Error represents a DHT error response
|
||||
type Error struct {
|
||||
ID messageID
|
||||
NodeID bits.Bitmap
|
||||
ExceptionType string
|
||||
Response []string
|
||||
}
|
||||
|
||||
// MarshalBencode returns the serialized byte slice representation of an error message.
|
||||
func (e Error) MarshalBencode() ([]byte, error) {
|
||||
return bencode.EncodeBytes(map[string]interface{}{
|
||||
headerTypeField: errorType,
|
||||
headerMessageIDField: e.ID,
|
||||
headerNodeIDField: e.NodeID,
|
||||
headerPayloadField: e.ExceptionType,
|
||||
headerArgsField: e.Response,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalBencode unmarshals the serialized byte slice into the appropriate fields of the error message.
|
||||
func (e *Error) UnmarshalBencode(b []byte) error {
|
||||
var raw struct {
|
||||
ID messageID `bencode:"1"`
|
||||
NodeID bits.Bitmap `bencode:"2"`
|
||||
ExceptionType string `bencode:"3"`
|
||||
Args interface{} `bencode:"4"`
|
||||
}
|
||||
err := bencode.DecodeBytes(b, &raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.ID = raw.ID
|
||||
e.NodeID = raw.NodeID
|
||||
e.ExceptionType = raw.ExceptionType
|
||||
|
||||
if reflect.TypeOf(raw.Args).Kind() == reflect.Slice {
|
||||
v := reflect.ValueOf(raw.Args)
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
e.Response = append(e.Response, cast.ToString(v.Index(i).Interface()))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
223
dht/message_test.go
Normal file
223
dht/message_test.go
Normal file
File diff suppressed because one or more lines are too long
474
dht/node.go
Normal file
474
dht/node.go
Normal file
|
@ -0,0 +1,474 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lbryio/errors.go"
|
||||
"github.com/lbryio/lbry.go/stop"
|
||||
"github.com/lbryio/lbry.go/util"
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lyoshenka/bencode"
|
||||
)
|
||||
|
||||
// packet represents the information receive from udp.
|
||||
type packet struct {
|
||||
data []byte
|
||||
raddr *net.UDPAddr
|
||||
}
|
||||
|
||||
// UDPConn allows using a mocked connection to test sending/receiving data
|
||||
// TODO: stop mocking this and use the real thing
|
||||
type UDPConn interface {
|
||||
ReadFromUDP([]byte) (int, *net.UDPAddr, error)
|
||||
WriteToUDP([]byte, *net.UDPAddr) (int, error)
|
||||
SetReadDeadline(time.Time) error
|
||||
SetWriteDeadline(time.Time) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// RequestHandlerFunc is exported handler for requests.
|
||||
type RequestHandlerFunc func(addr *net.UDPAddr, request Request)
|
||||
|
||||
// Node is a type representation of a node on the network.
|
||||
type Node struct {
|
||||
// the node's id
|
||||
id bits.Bitmap
|
||||
// UDP connection for sending and receiving data
|
||||
conn UDPConn
|
||||
// true if we've closed the connection on purpose
|
||||
connClosed bool
|
||||
// token manager
|
||||
tokens *tokenManager
|
||||
|
||||
// map of outstanding transactions + mutex
|
||||
txLock *sync.RWMutex
|
||||
transactions map[messageID]*transaction
|
||||
|
||||
// routing table
|
||||
rt *routingTable
|
||||
// data store
|
||||
store *contactStore
|
||||
|
||||
// overrides for request handlers
|
||||
requestHandler RequestHandlerFunc
|
||||
|
||||
// stop the node neatly and clean up after itself
|
||||
grp *stop.Group
|
||||
}
|
||||
|
||||
// NewNode returns an initialized Node's pointer.
|
||||
func NewNode(id bits.Bitmap) *Node {
|
||||
return &Node{
|
||||
id: id,
|
||||
rt: newRoutingTable(id),
|
||||
store: newStore(),
|
||||
|
||||
txLock: &sync.RWMutex{},
|
||||
transactions: make(map[messageID]*transaction),
|
||||
|
||||
grp: stop.New(),
|
||||
tokens: &tokenManager{},
|
||||
}
|
||||
}
|
||||
|
||||
// Connect connects to the given connection and starts any background threads necessary
|
||||
func (n *Node) Connect(conn UDPConn) error {
|
||||
n.conn = conn
|
||||
|
||||
n.tokens.Start(tokenSecretRotationInterval)
|
||||
|
||||
go func() {
|
||||
// stop tokens and close the connection when we're shutting down
|
||||
<-n.grp.Ch()
|
||||
n.tokens.Stop()
|
||||
n.connClosed = true
|
||||
err := n.conn.Close()
|
||||
if err != nil {
|
||||
log.Error("error closing node connection on shutdown - ", err)
|
||||
}
|
||||
}()
|
||||
|
||||
packets := make(chan packet)
|
||||
|
||||
n.grp.Add(1)
|
||||
go func() {
|
||||
defer n.grp.Done()
|
||||
|
||||
buf := make([]byte, udpMaxMessageLength)
|
||||
|
||||
for {
|
||||
bytesRead, raddr, err := n.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
if n.connClosed {
|
||||
return
|
||||
}
|
||||
log.Errorf("udp read error: %v", err)
|
||||
continue
|
||||
} else if raddr == nil {
|
||||
log.Errorf("udp read with no raddr")
|
||||
continue
|
||||
}
|
||||
|
||||
data := make([]byte, bytesRead)
|
||||
copy(data, buf[:bytesRead]) // slices use the same underlying array, so we need a new one for each packet
|
||||
|
||||
select { // needs select here because packet consumer can quit and the packets channel gets filled up and blocks
|
||||
case packets <- packet{data: data, raddr: raddr}:
|
||||
case <-n.grp.Ch():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
n.grp.Add(1)
|
||||
go func() {
|
||||
defer n.grp.Done()
|
||||
|
||||
var pkt packet
|
||||
|
||||
for {
|
||||
select {
|
||||
case pkt = <-packets:
|
||||
n.handlePacket(pkt)
|
||||
case <-n.grp.Ch():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// TODO: turn this back on when you're sure it works right
|
||||
n.grp.Add(1)
|
||||
go func() {
|
||||
defer n.grp.Done()
|
||||
n.startRoutingTableGrooming()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down the node
|
||||
func (n *Node) Shutdown() {
|
||||
log.Debugf("[%s] node shutting down", n.id.HexShort())
|
||||
n.grp.StopAndWait()
|
||||
log.Debugf("[%s] node stopped", n.id.HexShort())
|
||||
}
|
||||
|
||||
// handlePacket handles packets received from udp.
|
||||
func (n *Node) handlePacket(pkt packet) {
|
||||
//log.Debugf("[%s] Received message from %s (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), len(pkt.data), hex.EncodeToString(pkt.data))
|
||||
|
||||
if !util.InSlice(string(pkt.data[0:5]), []string{"d1:0i", "di0ei"}) {
|
||||
log.Errorf("[%s] data is not a well-formatted dict: (%d bytes) %s", n.id.HexShort(), len(pkt.data), hex.EncodeToString(pkt.data))
|
||||
return
|
||||
}
|
||||
|
||||
// the following is a bit of a hack, but it lets us avoid decoding every message twice
|
||||
// it depends on the data being a dict with 0 as the first key (so it starts with "d1:0i") and the message type as the first value
|
||||
// TODO: test this more thoroughly
|
||||
|
||||
switch pkt.data[5] {
|
||||
case '0' + requestType:
|
||||
request := Request{}
|
||||
err := bencode.DecodeBytes(pkt.data, &request)
|
||||
if err != nil {
|
||||
log.Errorf("[%s] error decoding request from %s: %s: (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
|
||||
return
|
||||
}
|
||||
log.Debugf("[%s] query %s: received request from %s: %s(%s)", n.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, request.argsDebug())
|
||||
n.handleRequest(pkt.raddr, request)
|
||||
|
||||
case '0' + responseType:
|
||||
response := Response{}
|
||||
err := bencode.DecodeBytes(pkt.data, &response)
|
||||
if err != nil {
|
||||
log.Errorf("[%s] error decoding response from %s: %s: (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
|
||||
return
|
||||
}
|
||||
log.Debugf("[%s] query %s: received response from %s: %s", n.id.HexShort(), response.ID.HexShort(), response.NodeID.HexShort(), response.argsDebug())
|
||||
n.handleResponse(pkt.raddr, response)
|
||||
|
||||
case '0' + errorType:
|
||||
e := Error{}
|
||||
err := bencode.DecodeBytes(pkt.data, &e)
|
||||
if err != nil {
|
||||
log.Errorf("[%s] error decoding error from %s: %s: (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
|
||||
return
|
||||
}
|
||||
log.Debugf("[%s] query %s: received error from %s: %s", n.id.HexShort(), e.ID.HexShort(), e.NodeID.HexShort(), e.ExceptionType)
|
||||
n.handleError(pkt.raddr, e)
|
||||
|
||||
default:
|
||||
log.Errorf("[%s] invalid message type: %s", n.id.HexShort(), string(pkt.data[5]))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest handles the requests received from udp.
|
||||
func (n *Node) handleRequest(addr *net.UDPAddr, request Request) {
|
||||
if request.NodeID.Equals(n.id) {
|
||||
log.Warn("ignoring self-request")
|
||||
return
|
||||
}
|
||||
|
||||
// if a handler is overridden, call it instead
|
||||
if n.requestHandler != nil {
|
||||
n.requestHandler(addr, request)
|
||||
return
|
||||
}
|
||||
|
||||
switch request.Method {
|
||||
default:
|
||||
//n.sendMessage(addr, Error{ID: request.ID, NodeID: n.id, ExceptionType: "invalid-request-method"})
|
||||
log.Errorln("invalid request method")
|
||||
return
|
||||
case pingMethod:
|
||||
err := n.sendMessage(addr, Response{ID: request.ID, NodeID: n.id, Data: pingSuccessResponse})
|
||||
if err != nil {
|
||||
log.Error("error sending 'pingmethod' response message - ", err)
|
||||
}
|
||||
case storeMethod:
|
||||
// TODO: we should be sending the IP in the request, not just using the sender's IP
|
||||
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
|
||||
if n.tokens.Verify(request.StoreArgs.Value.Token, request.NodeID, addr) {
|
||||
n.Store(request.StoreArgs.BlobHash, Contact{ID: request.StoreArgs.NodeID, IP: addr.IP, Port: addr.Port, PeerPort: request.StoreArgs.Value.Port})
|
||||
|
||||
err := n.sendMessage(addr, Response{ID: request.ID, NodeID: n.id, Data: storeSuccessResponse})
|
||||
if err != nil {
|
||||
log.Error("error sending 'storemethod' response message - ", err)
|
||||
}
|
||||
} else {
|
||||
err := n.sendMessage(addr, Error{ID: request.ID, NodeID: n.id, ExceptionType: "invalid-token"})
|
||||
if err != nil {
|
||||
log.Error("error sending 'storemethod'response message for invalid-token - ", err)
|
||||
}
|
||||
}
|
||||
case findNodeMethod:
|
||||
if request.Arg == nil {
|
||||
log.Errorln("request is missing arg")
|
||||
return
|
||||
}
|
||||
err := n.sendMessage(addr, Response{
|
||||
ID: request.ID,
|
||||
NodeID: n.id,
|
||||
Contacts: n.rt.GetClosest(*request.Arg, bucketSize),
|
||||
})
|
||||
if err != nil {
|
||||
log.Error("error sending 'findnodemethod' response message - ", err)
|
||||
}
|
||||
|
||||
case findValueMethod:
|
||||
if request.Arg == nil {
|
||||
log.Errorln("request is missing arg")
|
||||
return
|
||||
}
|
||||
|
||||
res := Response{
|
||||
ID: request.ID,
|
||||
NodeID: n.id,
|
||||
Token: n.tokens.Get(request.NodeID, addr),
|
||||
}
|
||||
|
||||
if contacts := n.store.Get(*request.Arg); len(contacts) > 0 {
|
||||
res.FindValueKey = request.Arg.RawString()
|
||||
res.Contacts = contacts
|
||||
} else {
|
||||
res.Contacts = n.rt.GetClosest(*request.Arg, bucketSize)
|
||||
}
|
||||
|
||||
err := n.sendMessage(addr, res)
|
||||
if err != nil {
|
||||
log.Error("error sending 'findvaluemethod' response message - ", err)
|
||||
}
|
||||
}
|
||||
|
||||
// nodes that send us requests should not be inserted, only refreshed.
|
||||
// the routing table must only contain "good" nodes, which are nodes that reply to our requests
|
||||
// if a node is already good (aka in the table), its fine to refresh it
|
||||
// http://www.bittorrent.org/beps/bep_0005.html#routing-table
|
||||
n.rt.Fresh(Contact{ID: request.NodeID, IP: addr.IP, Port: addr.Port})
|
||||
}
|
||||
|
||||
// handleResponse handles responses received from udp.
|
||||
func (n *Node) handleResponse(addr *net.UDPAddr, response Response) {
|
||||
tx := n.txFind(response.ID, Contact{ID: response.NodeID, IP: addr.IP, Port: addr.Port})
|
||||
if tx != nil {
|
||||
select {
|
||||
case tx.res <- response:
|
||||
default:
|
||||
//log.Errorf("[%s] query %s: response received, but tx has no listener or multiple responses to the same tx", n.id.HexShort(), response.ID.HexShort())
|
||||
}
|
||||
}
|
||||
|
||||
n.rt.Update(Contact{ID: response.NodeID, IP: addr.IP, Port: addr.Port})
|
||||
}
|
||||
|
||||
// handleError handles errors received from udp.
|
||||
func (n *Node) handleError(addr *net.UDPAddr, e Error) {
|
||||
spew.Dump(e)
|
||||
n.rt.Fresh(Contact{ID: e.NodeID, IP: addr.IP, Port: addr.Port})
|
||||
}
|
||||
|
||||
// send sends data to a udp address
|
||||
func (n *Node) sendMessage(addr *net.UDPAddr, data Message) error {
|
||||
encoded, err := bencode.EncodeBytes(data)
|
||||
if err != nil {
|
||||
return errors.Err(err)
|
||||
}
|
||||
|
||||
if req, ok := data.(Request); ok {
|
||||
log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)",
|
||||
n.id.HexShort(), req.ID.HexShort(), addr.String(), len(encoded), req.Method, req.argsDebug())
|
||||
} else if res, ok := data.(Response); ok {
|
||||
log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s",
|
||||
n.id.HexShort(), res.ID.HexShort(), addr.String(), len(encoded), res.argsDebug())
|
||||
} else {
|
||||
log.Debugf("[%s] (%d bytes) %s", n.id.HexShort(), len(encoded), spew.Sdump(data))
|
||||
}
|
||||
|
||||
err = n.conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
if err != nil {
|
||||
if n.connClosed {
|
||||
return nil
|
||||
}
|
||||
log.Error("error setting write deadline - ", err)
|
||||
}
|
||||
|
||||
_, err = n.conn.WriteToUDP(encoded, addr)
|
||||
return errors.Err(err)
|
||||
}
|
||||
|
||||
// transaction represents a single query to the dht. it stores the queried contact, the request, and the response channel
|
||||
type transaction struct {
|
||||
contact Contact
|
||||
req Request
|
||||
res chan Response
|
||||
skipIDCheck bool
|
||||
}
|
||||
|
||||
// insert adds a transaction to the manager.
|
||||
func (n *Node) txInsert(tx *transaction) {
|
||||
n.txLock.Lock()
|
||||
defer n.txLock.Unlock()
|
||||
n.transactions[tx.req.ID] = tx
|
||||
}
|
||||
|
||||
// delete removes a transaction from the manager.
|
||||
func (n *Node) txDelete(id messageID) {
|
||||
n.txLock.Lock()
|
||||
defer n.txLock.Unlock()
|
||||
delete(n.transactions, id)
|
||||
}
|
||||
|
||||
// Find finds a transaction for the given id and contact
|
||||
func (n *Node) txFind(id messageID, c Contact) *transaction {
|
||||
n.txLock.RLock()
|
||||
defer n.txLock.RUnlock()
|
||||
|
||||
t, ok := n.transactions[id]
|
||||
if !ok || !t.contact.Equals(c, !t.skipIDCheck) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// SendOptions controls the behavior of send calls
|
||||
type SendOptions struct {
|
||||
skipIDCheck bool
|
||||
}
|
||||
|
||||
// SendAsync sends a transaction and returns a channel that will eventually contain the transaction response
|
||||
// The response channel is closed when the transaction is completed or times out.
|
||||
func (n *Node) SendAsync(contact Contact, req Request, options ...SendOptions) <-chan *Response {
|
||||
ch := make(chan *Response, 1)
|
||||
|
||||
if contact.ID.Equals(n.id) {
|
||||
log.Error("sending query to self")
|
||||
close(ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
req.ID = newMessageID()
|
||||
req.NodeID = n.id
|
||||
tx := &transaction{
|
||||
contact: contact,
|
||||
req: req,
|
||||
res: make(chan Response),
|
||||
}
|
||||
|
||||
if len(options) > 0 && options[0].skipIDCheck {
|
||||
tx.skipIDCheck = true
|
||||
}
|
||||
|
||||
n.txInsert(tx)
|
||||
defer n.txDelete(tx.req.ID)
|
||||
|
||||
for i := 0; i < udpRetry; i++ {
|
||||
err := n.sendMessage(contact.Addr(), tx.req)
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), "use of closed network connection") { // this only happens on localhost. real UDP has no connections
|
||||
log.Error("send error: ", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case res := <-tx.res:
|
||||
ch <- &res
|
||||
return
|
||||
case <-n.grp.Ch():
|
||||
return
|
||||
case <-time.After(udpTimeout):
|
||||
}
|
||||
}
|
||||
|
||||
// notify routing table about a failure to respond
|
||||
n.rt.Fail(tx.contact)
|
||||
}()
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
// Send sends a transaction and blocks until the response is available. It returns a response, or nil
|
||||
// if the transaction timed out.
|
||||
func (n *Node) Send(contact Contact, req Request, options ...SendOptions) *Response {
|
||||
return <-n.SendAsync(contact, req, options...)
|
||||
}
|
||||
|
||||
// CountActiveTransactions returns the number of transactions in the manager
|
||||
func (n *Node) CountActiveTransactions() int {
|
||||
n.txLock.Lock()
|
||||
defer n.txLock.Unlock()
|
||||
return len(n.transactions)
|
||||
}
|
||||
|
||||
func (n *Node) startRoutingTableGrooming() {
|
||||
refreshTicker := time.NewTicker(tRefresh / 5) // how often to check for buckets that need to be refreshed
|
||||
for {
|
||||
select {
|
||||
case <-refreshTicker.C:
|
||||
RoutingTableRefresh(n, tRefresh, n.grp.Child())
|
||||
case <-n.grp.Ch():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store stores a node contact in the node's contact store.
|
||||
func (n *Node) Store(hash bits.Bitmap, c Contact) {
|
||||
n.store.Upsert(hash, c)
|
||||
}
|
||||
|
||||
//AddKnownNode adds a known-good node to the routing table
|
||||
func (n *Node) AddKnownNode(c Contact) {
|
||||
n.rt.Update(c)
|
||||
}
|
338
dht/node_finder.go
Normal file
338
dht/node_finder.go
Normal file
|
@ -0,0 +1,338 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lbryio/lbry.go/crypto"
|
||||
"github.com/lbryio/lbry.go/errors"
|
||||
"github.com/lbryio/lbry.go/stop"
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/uber-go/atomic"
|
||||
)
|
||||
|
||||
// TODO: iterativeFindValue may be stopping early. if it gets a response with one peer, it should keep going because other nodes may know about more peers that have that blob
|
||||
// TODO: or, it should try a tcp handshake with peers as it finds them, to make sure they are still online and have the blob
|
||||
|
||||
var cfLog *logrus.Logger
|
||||
|
||||
func init() {
|
||||
cfLog = logrus.StandardLogger()
|
||||
}
|
||||
|
||||
func NodeFinderUseLogger(l *logrus.Logger) {
|
||||
cfLog = l
|
||||
}
|
||||
|
||||
type contactFinder struct {
|
||||
findValue bool // true if we're using findValue
|
||||
target bits.Bitmap
|
||||
node *Node
|
||||
|
||||
grp *stop.Group
|
||||
|
||||
findValueMutex *sync.Mutex
|
||||
findValueResult []Contact
|
||||
|
||||
activeContactsMutex *sync.Mutex
|
||||
activeContacts []Contact
|
||||
|
||||
shortlistMutex *sync.Mutex
|
||||
shortlist []Contact
|
||||
shortlistAdded map[bits.Bitmap]bool
|
||||
|
||||
closestContactMutex *sync.RWMutex
|
||||
closestContact *Contact
|
||||
notGettingCloser *atomic.Bool
|
||||
}
|
||||
|
||||
func FindContacts(node *Node, target bits.Bitmap, findValue bool, parentGrp *stop.Group) ([]Contact, bool, error) {
|
||||
cf := &contactFinder{
|
||||
node: node,
|
||||
target: target,
|
||||
findValue: findValue,
|
||||
findValueMutex: &sync.Mutex{},
|
||||
activeContactsMutex: &sync.Mutex{},
|
||||
shortlistMutex: &sync.Mutex{},
|
||||
shortlistAdded: make(map[bits.Bitmap]bool),
|
||||
grp: stop.New(parentGrp),
|
||||
closestContactMutex: &sync.RWMutex{},
|
||||
notGettingCloser: atomic.NewBool(false),
|
||||
}
|
||||
|
||||
return cf.Find()
|
||||
}
|
||||
|
||||
func (cf *contactFinder) Stop() {
|
||||
cf.grp.StopAndWait()
|
||||
}
|
||||
|
||||
func (cf *contactFinder) Find() ([]Contact, bool, error) {
|
||||
if cf.findValue {
|
||||
cf.debug("starting iterativeFindValue")
|
||||
} else {
|
||||
cf.debug("starting iterativeFindNode")
|
||||
}
|
||||
|
||||
cf.appendNewToShortlist(cf.node.rt.GetClosest(cf.target, alpha))
|
||||
if len(cf.shortlist) == 0 {
|
||||
return nil, false, errors.Err("[%s] find %s: no contacts in routing table", cf.node.id.HexShort(), cf.target.HexShort())
|
||||
}
|
||||
|
||||
go cf.cycle(false)
|
||||
timeout := 5 * time.Second
|
||||
CycleLoop:
|
||||
for {
|
||||
select {
|
||||
case <-time.After(timeout):
|
||||
go cf.cycle(false)
|
||||
case <-cf.grp.Ch():
|
||||
break CycleLoop
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: what to do if we have less than K active contacts, shortlist is empty, but we have other contacts in our routing table whom we have not contacted. prolly contact them
|
||||
|
||||
var contacts []Contact
|
||||
var found bool
|
||||
if cf.findValue && len(cf.findValueResult) > 0 {
|
||||
contacts = cf.findValueResult
|
||||
found = true
|
||||
} else {
|
||||
contacts = cf.activeContacts
|
||||
if len(contacts) > bucketSize {
|
||||
contacts = contacts[:bucketSize]
|
||||
}
|
||||
}
|
||||
|
||||
cf.Stop()
|
||||
return contacts, found, nil
|
||||
}
|
||||
|
||||
// cycle does a single cycle of sending alpha probes and checking results against closestNode
|
||||
func (cf *contactFinder) cycle(bigCycle bool) {
|
||||
cycleID := crypto.RandString(6)
|
||||
if bigCycle {
|
||||
cf.debug("LAUNCHING CYCLE %s, AND ITS A BIG CYCLE", cycleID)
|
||||
} else {
|
||||
cf.debug("LAUNCHING CYCLE %s", cycleID)
|
||||
}
|
||||
defer cf.debug("CYCLE %s DONE", cycleID)
|
||||
|
||||
cf.closestContactMutex.RLock()
|
||||
closestContact := cf.closestContact
|
||||
cf.closestContactMutex.RUnlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
ch := make(chan *Contact)
|
||||
|
||||
limit := alpha
|
||||
if bigCycle {
|
||||
limit = bucketSize
|
||||
}
|
||||
|
||||
for i := 0; i < limit; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ch <- cf.probe(cycleID)
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
foundCloser := false
|
||||
for {
|
||||
c, more := <-ch
|
||||
if !more {
|
||||
break
|
||||
}
|
||||
if c != nil && (closestContact == nil || cf.target.Closer(c.ID, closestContact.ID)) {
|
||||
if closestContact != nil {
|
||||
cf.debug("|%s| best contact improved: %s -> %s", cycleID, closestContact.ID.HexShort(), c.ID.HexShort())
|
||||
} else {
|
||||
cf.debug("|%s| best contact starting at %s", cycleID, c.ID.HexShort())
|
||||
}
|
||||
foundCloser = true
|
||||
closestContact = c
|
||||
}
|
||||
}
|
||||
|
||||
if cf.isSearchFinished() {
|
||||
cf.grp.Stop()
|
||||
return
|
||||
}
|
||||
|
||||
if foundCloser {
|
||||
cf.closestContactMutex.Lock()
|
||||
// have to check again after locking in case other probes found a closer one in the meantime
|
||||
if cf.closestContact == nil || cf.target.Closer(closestContact.ID, cf.closestContact.ID) {
|
||||
cf.closestContact = closestContact
|
||||
}
|
||||
cf.closestContactMutex.Unlock()
|
||||
go cf.cycle(false)
|
||||
} else if !bigCycle {
|
||||
cf.debug("|%s| no improvement, running big cycle", cycleID)
|
||||
go cf.cycle(true)
|
||||
} else {
|
||||
// big cycle ran and there was no improvement, so we're done
|
||||
cf.debug("|%s| big cycle ran, still no improvement", cycleID)
|
||||
cf.notGettingCloser.Store(true)
|
||||
}
|
||||
}
|
||||
|
||||
// probe sends a single probe, updates the lists, and returns the closest contact it found
|
||||
func (cf *contactFinder) probe(cycleID string) *Contact {
|
||||
maybeContact := cf.popFromShortlist()
|
||||
if maybeContact == nil {
|
||||
cf.debug("|%s| no contacts in shortlist, returning", cycleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
c := *maybeContact
|
||||
|
||||
if c.ID.Equals(cf.node.id) {
|
||||
return nil
|
||||
}
|
||||
|
||||
cf.debug("|%s| probe %s: launching", cycleID, c.ID.HexShort())
|
||||
|
||||
req := Request{Arg: &cf.target}
|
||||
if cf.findValue {
|
||||
req.Method = findValueMethod
|
||||
} else {
|
||||
req.Method = findNodeMethod
|
||||
}
|
||||
|
||||
var res *Response
|
||||
resCh := cf.node.SendAsync(c, req)
|
||||
select {
|
||||
case res = <-resCh:
|
||||
case <-cf.grp.Ch():
|
||||
cf.debug("|%s| probe %s: canceled", cycleID, c.ID.HexShort())
|
||||
return nil
|
||||
}
|
||||
|
||||
if res == nil {
|
||||
cf.debug("|%s| probe %s: req canceled or timed out", cycleID, c.ID.HexShort())
|
||||
return nil
|
||||
}
|
||||
|
||||
if cf.findValue && res.FindValueKey != "" {
|
||||
cf.debug("|%s| probe %s: got value", cycleID, c.ID.HexShort())
|
||||
cf.findValueMutex.Lock()
|
||||
cf.findValueResult = res.Contacts
|
||||
cf.findValueMutex.Unlock()
|
||||
cf.grp.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
cf.debug("|%s| probe %s: got %s", cycleID, c.ID.HexShort(), res.argsDebug())
|
||||
cf.insertIntoActiveList(c)
|
||||
cf.appendNewToShortlist(res.Contacts)
|
||||
|
||||
cf.activeContactsMutex.Lock()
|
||||
contacts := cf.activeContacts
|
||||
if len(contacts) > bucketSize {
|
||||
contacts = contacts[:bucketSize]
|
||||
}
|
||||
contactsStr := ""
|
||||
for _, c := range contacts {
|
||||
contactsStr += c.ID.HexShort() + ", "
|
||||
}
|
||||
cf.activeContactsMutex.Unlock()
|
||||
|
||||
return cf.closest(res.Contacts...)
|
||||
}
|
||||
|
||||
// appendNewToShortlist appends any new contacts to the shortlist and sorts it by distance
|
||||
// contacts that have already been added to the shortlist in the past are ignored
|
||||
func (cf *contactFinder) appendNewToShortlist(contacts []Contact) {
|
||||
cf.shortlistMutex.Lock()
|
||||
defer cf.shortlistMutex.Unlock()
|
||||
|
||||
for _, c := range contacts {
|
||||
if _, ok := cf.shortlistAdded[c.ID]; !ok {
|
||||
cf.shortlist = append(cf.shortlist, c)
|
||||
cf.shortlistAdded[c.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
sortByDistance(cf.shortlist, cf.target)
|
||||
}
|
||||
|
||||
// popFromShortlist pops the first contact off the shortlist and returns it
|
||||
func (cf *contactFinder) popFromShortlist() *Contact {
|
||||
cf.shortlistMutex.Lock()
|
||||
defer cf.shortlistMutex.Unlock()
|
||||
|
||||
if len(cf.shortlist) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
first := cf.shortlist[0]
|
||||
cf.shortlist = cf.shortlist[1:]
|
||||
return &first
|
||||
}
|
||||
|
||||
// insertIntoActiveList inserts the contact into appropriate place in the list of active contacts (sorted by distance)
|
||||
func (cf *contactFinder) insertIntoActiveList(contact Contact) {
|
||||
cf.activeContactsMutex.Lock()
|
||||
defer cf.activeContactsMutex.Unlock()
|
||||
|
||||
inserted := false
|
||||
for i, n := range cf.activeContacts {
|
||||
if cf.target.Closer(contact.ID, n.ID) {
|
||||
cf.activeContacts = append(cf.activeContacts[:i], append([]Contact{contact}, cf.activeContacts[i:]...)...)
|
||||
inserted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !inserted {
|
||||
cf.activeContacts = append(cf.activeContacts, contact)
|
||||
}
|
||||
}
|
||||
|
||||
// isSearchFinished returns true if the search is done and should be stopped
|
||||
func (cf *contactFinder) isSearchFinished() bool {
|
||||
if cf.findValue && len(cf.findValueResult) > 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
select {
|
||||
case <-cf.grp.Ch():
|
||||
return true
|
||||
default:
|
||||
}
|
||||
|
||||
if cf.notGettingCloser.Load() {
|
||||
return true
|
||||
}
|
||||
|
||||
cf.activeContactsMutex.Lock()
|
||||
defer cf.activeContactsMutex.Unlock()
|
||||
return len(cf.activeContacts) >= bucketSize
|
||||
}
|
||||
|
||||
func (cf *contactFinder) debug(format string, args ...interface{}) {
|
||||
args = append([]interface{}{cf.node.id.HexShort()}, append([]interface{}{cf.target.HexShort()}, args...)...)
|
||||
cfLog.Debugf("[%s] find %s: "+format, args...)
|
||||
}
|
||||
|
||||
func (cf *contactFinder) closest(contacts ...Contact) *Contact {
|
||||
if len(contacts) == 0 {
|
||||
return nil
|
||||
}
|
||||
closest := contacts[0]
|
||||
for _, c := range contacts {
|
||||
if cf.target.Closer(c.ID, closest.ID) {
|
||||
closest = c
|
||||
}
|
||||
}
|
||||
return &closest
|
||||
}
|
422
dht/node_test.go
Normal file
422
dht/node_test.go
Normal file
|
@ -0,0 +1,422 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
"github.com/lyoshenka/bencode"
|
||||
)
|
||||
|
||||
func TestPing(t *testing.T) {
|
||||
dhtNodeID := bits.Rand()
|
||||
testNodeID := bits.Rand()
|
||||
|
||||
conn := newTestUDPConn("127.0.0.1:21217")
|
||||
|
||||
dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
|
||||
|
||||
err := dht.connect(conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer dht.Shutdown()
|
||||
|
||||
messageID := newMessageID()
|
||||
|
||||
data, err := bencode.EncodeBytes(map[string]interface{}{
|
||||
headerTypeField: requestType,
|
||||
headerMessageIDField: messageID,
|
||||
headerNodeIDField: testNodeID.RawString(),
|
||||
headerPayloadField: "ping",
|
||||
headerArgsField: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
conn.toRead <- testUDPPacket{addr: conn.addr, data: data}
|
||||
timer := time.NewTimer(3 * time.Second)
|
||||
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.Error("timeout")
|
||||
case resp := <-conn.writes:
|
||||
var response map[string]interface{}
|
||||
err := bencode.DecodeBytes(resp.data, &response)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(response) != 4 {
|
||||
t.Errorf("expected 4 response fields, got %d", len(response))
|
||||
}
|
||||
|
||||
_, ok := response[headerTypeField]
|
||||
if !ok {
|
||||
t.Error("missing type field")
|
||||
} else {
|
||||
rType, ok := response[headerTypeField].(int64)
|
||||
if !ok {
|
||||
t.Error("type is not an integer")
|
||||
} else if rType != responseType {
|
||||
t.Error("unexpected response type")
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = response[headerMessageIDField]
|
||||
if !ok {
|
||||
t.Error("missing message id field")
|
||||
} else {
|
||||
rMessageID, ok := response[headerMessageIDField].(string)
|
||||
if !ok {
|
||||
t.Error("message ID is not a string")
|
||||
} else if rMessageID != string(messageID[:]) {
|
||||
t.Error("unexpected message ID")
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = response[headerNodeIDField]
|
||||
if !ok {
|
||||
t.Error("missing node id field")
|
||||
} else {
|
||||
rNodeID, ok := response[headerNodeIDField].(string)
|
||||
if !ok {
|
||||
t.Error("node ID is not a string")
|
||||
} else if rNodeID != dhtNodeID.RawString() {
|
||||
t.Error("unexpected node ID")
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = response[headerPayloadField]
|
||||
if !ok {
|
||||
t.Error("missing payload field")
|
||||
} else {
|
||||
rNodeID, ok := response[headerPayloadField].(string)
|
||||
if !ok {
|
||||
t.Error("payload is not a string")
|
||||
} else if rNodeID != pingSuccessResponse {
|
||||
t.Error("did not pong")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
dhtNodeID := bits.Rand()
|
||||
testNodeID := bits.Rand()
|
||||
|
||||
conn := newTestUDPConn("127.0.0.1:21217")
|
||||
|
||||
dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
|
||||
|
||||
err := dht.connect(conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer dht.Shutdown()
|
||||
|
||||
messageID := newMessageID()
|
||||
blobHashToStore := bits.Rand()
|
||||
|
||||
storeRequest := Request{
|
||||
ID: messageID,
|
||||
NodeID: testNodeID,
|
||||
Method: storeMethod,
|
||||
StoreArgs: &storeArgs{
|
||||
BlobHash: blobHashToStore,
|
||||
Value: storeArgsValue{
|
||||
Token: dht.node.tokens.Get(testNodeID, conn.addr),
|
||||
LbryID: testNodeID,
|
||||
Port: 9999,
|
||||
},
|
||||
NodeID: testNodeID,
|
||||
},
|
||||
}
|
||||
|
||||
_ = "64 " + // start message
|
||||
"313A30 693065" + // type: 0
|
||||
"313A31 3230 3A 6EB490B5788B63F0F7E6D92352024D0CBDEC2D3A" + // message id
|
||||
"313A32 3438 3A 7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B" + // node id
|
||||
"313A33 35 3A 73746F7265" + // method
|
||||
"313A34 6C" + // start args list
|
||||
"3438 3A 3214D6C2F77FCB5E8D5FC07EDAFBA614F031CE8B2EAB49F924F8143F6DFBADE048D918710072FB98AB1B52B58F4E1468" + // block hash
|
||||
"64" + // start value dict
|
||||
"363A6C6272796964 3438 3A 7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B" + // lbry id
|
||||
"343A706F7274 69 33333333 65" + // port
|
||||
"353A746F6B656E 3438 3A 17C2D8E1E48EF21567FE4AD5C8ED944B798D3B65AB58D0C9122AD6587D1B5FED472EA2CB12284CEFA1C21EFF302322BD" + // token
|
||||
"65" + // end value dict
|
||||
"3438 3A 7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B" + // node id
|
||||
"693065" + // self store (integer)
|
||||
"65" + // end args list
|
||||
"65" // end message
|
||||
|
||||
data, err := bencode.EncodeBytes(storeRequest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conn.toRead <- testUDPPacket{addr: conn.addr, data: data}
|
||||
timer := time.NewTimer(3 * time.Second)
|
||||
|
||||
var response map[string]interface{}
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.Fatal("timeout")
|
||||
case resp := <-conn.writes:
|
||||
err := bencode.DecodeBytes(resp.data, &response)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
verifyResponse(t, response, messageID, dhtNodeID.RawString())
|
||||
|
||||
_, ok := response[headerPayloadField]
|
||||
if !ok {
|
||||
t.Error("missing payload field")
|
||||
} else {
|
||||
rNodeID, ok := response[headerPayloadField].(string)
|
||||
if !ok {
|
||||
t.Error("payload is not a string")
|
||||
} else if rNodeID != storeSuccessResponse {
|
||||
t.Error("did not return OK")
|
||||
}
|
||||
}
|
||||
|
||||
if dht.node.store.CountStoredHashes() != 1 {
|
||||
t.Error("dht store has wrong number of items")
|
||||
}
|
||||
|
||||
items := dht.node.store.Get(blobHashToStore)
|
||||
if len(items) != 1 {
|
||||
t.Error("list created in store, but nothing in list")
|
||||
}
|
||||
if !items[0].ID.Equals(testNodeID) {
|
||||
t.Error("wrong value stored")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindNode(t *testing.T) {
|
||||
dhtNodeID := bits.Rand()
|
||||
testNodeID := bits.Rand()
|
||||
|
||||
conn := newTestUDPConn("127.0.0.1:21217")
|
||||
|
||||
dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
|
||||
|
||||
err := dht.connect(conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer dht.Shutdown()
|
||||
|
||||
nodesToInsert := 3
|
||||
var nodes []Contact
|
||||
for i := 0; i < nodesToInsert; i++ {
|
||||
n := Contact{ID: bits.Rand(), IP: net.ParseIP("127.0.0.1"), Port: 10000 + i}
|
||||
nodes = append(nodes, n)
|
||||
dht.node.rt.Update(n)
|
||||
}
|
||||
|
||||
messageID := newMessageID()
|
||||
blobHashToFind := bits.Rand()
|
||||
|
||||
request := Request{
|
||||
ID: messageID,
|
||||
NodeID: testNodeID,
|
||||
Method: findNodeMethod,
|
||||
Arg: &blobHashToFind,
|
||||
}
|
||||
|
||||
data, err := bencode.EncodeBytes(request)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conn.toRead <- testUDPPacket{addr: conn.addr, data: data}
|
||||
timer := time.NewTimer(3 * time.Second)
|
||||
|
||||
var response map[string]interface{}
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.Fatal("timeout")
|
||||
case resp := <-conn.writes:
|
||||
err := bencode.DecodeBytes(resp.data, &response)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
verifyResponse(t, response, messageID, dhtNodeID.RawString())
|
||||
|
||||
_, ok := response[headerPayloadField]
|
||||
if !ok {
|
||||
t.Fatal("missing payload field")
|
||||
}
|
||||
|
||||
contacts, ok := response[headerPayloadField].([]interface{})
|
||||
if !ok {
|
||||
t.Fatal("payload is not a list")
|
||||
}
|
||||
|
||||
verifyContacts(t, contacts, nodes)
|
||||
}
|
||||
|
||||
func TestFindValueExisting(t *testing.T) {
|
||||
dhtNodeID := bits.Rand()
|
||||
testNodeID := bits.Rand()
|
||||
|
||||
conn := newTestUDPConn("127.0.0.1:21217")
|
||||
|
||||
dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
|
||||
|
||||
err := dht.connect(conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer dht.Shutdown()
|
||||
|
||||
nodesToInsert := 3
|
||||
for i := 0; i < nodesToInsert; i++ {
|
||||
n := Contact{ID: bits.Rand(), IP: net.ParseIP("127.0.0.1"), Port: 10000 + i}
|
||||
dht.node.rt.Update(n)
|
||||
}
|
||||
|
||||
//data, _ := hex.DecodeString("64313a30693065313a3132303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33393a66696e6456616c7565313a346c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565")
|
||||
|
||||
messageID := newMessageID()
|
||||
valueToFind := bits.Rand()
|
||||
|
||||
nodeToFind := Contact{ID: bits.Rand(), IP: net.ParseIP("1.2.3.4"), PeerPort: 1286}
|
||||
dht.node.store.Upsert(valueToFind, nodeToFind)
|
||||
dht.node.store.Upsert(valueToFind, nodeToFind)
|
||||
dht.node.store.Upsert(valueToFind, nodeToFind)
|
||||
|
||||
request := Request{
|
||||
ID: messageID,
|
||||
NodeID: testNodeID,
|
||||
Method: findValueMethod,
|
||||
Arg: &valueToFind,
|
||||
}
|
||||
|
||||
data, err := bencode.EncodeBytes(request)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conn.toRead <- testUDPPacket{addr: conn.addr, data: data}
|
||||
timer := time.NewTimer(3 * time.Second)
|
||||
|
||||
var response map[string]interface{}
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.Fatal("timeout")
|
||||
case resp := <-conn.writes:
|
||||
err := bencode.DecodeBytes(resp.data, &response)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
verifyResponse(t, response, messageID, dhtNodeID.RawString())
|
||||
|
||||
_, ok := response[headerPayloadField]
|
||||
if !ok {
|
||||
t.Fatal("missing payload field")
|
||||
}
|
||||
|
||||
payload, ok := response[headerPayloadField].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("payload is not a dictionary")
|
||||
}
|
||||
|
||||
compactContacts, ok := payload[valueToFind.RawString()]
|
||||
if !ok {
|
||||
t.Fatal("payload is missing key for search value")
|
||||
}
|
||||
|
||||
contacts, ok := compactContacts.([]interface{})
|
||||
if !ok {
|
||||
t.Fatal("search results are not a list")
|
||||
}
|
||||
|
||||
verifyCompactContacts(t, contacts, []Contact{nodeToFind})
|
||||
}
|
||||
|
||||
func TestFindValueFallbackToFindNode(t *testing.T) {
|
||||
dhtNodeID := bits.Rand()
|
||||
testNodeID := bits.Rand()
|
||||
|
||||
conn := newTestUDPConn("127.0.0.1:21217")
|
||||
|
||||
dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
|
||||
|
||||
err := dht.connect(conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer dht.Shutdown()
|
||||
|
||||
nodesToInsert := 3
|
||||
var nodes []Contact
|
||||
for i := 0; i < nodesToInsert; i++ {
|
||||
n := Contact{ID: bits.Rand(), IP: net.ParseIP("127.0.0.1"), Port: 10000 + i}
|
||||
nodes = append(nodes, n)
|
||||
dht.node.rt.Update(n)
|
||||
}
|
||||
|
||||
messageID := newMessageID()
|
||||
valueToFind := bits.Rand()
|
||||
|
||||
request := Request{
|
||||
ID: messageID,
|
||||
NodeID: testNodeID,
|
||||
Method: findValueMethod,
|
||||
Arg: &valueToFind,
|
||||
}
|
||||
|
||||
data, err := bencode.EncodeBytes(request)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conn.toRead <- testUDPPacket{addr: conn.addr, data: data}
|
||||
timer := time.NewTimer(3 * time.Second)
|
||||
|
||||
var response map[string]interface{}
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.Fatal("timeout")
|
||||
case resp := <-conn.writes:
|
||||
err := bencode.DecodeBytes(resp.data, &response)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
verifyResponse(t, response, messageID, dhtNodeID.RawString())
|
||||
|
||||
_, ok := response[headerPayloadField]
|
||||
if !ok {
|
||||
t.Fatal("missing payload field")
|
||||
}
|
||||
|
||||
payload, ok := response[headerPayloadField].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("payload is not a dictionary")
|
||||
}
|
||||
|
||||
contactsList, ok := payload[contactsField]
|
||||
if !ok {
|
||||
t.Fatal("payload is missing 'contacts' key")
|
||||
}
|
||||
|
||||
contacts, ok := contactsList.([]interface{})
|
||||
if !ok {
|
||||
t.Fatal("'contacts' is not a list")
|
||||
}
|
||||
|
||||
verifyContacts(t, contacts, nodes)
|
||||
}
|
463
dht/routing_table.go
Normal file
463
dht/routing_table.go
Normal file
|
@ -0,0 +1,463 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lbryio/lbry.go/errors"
|
||||
"github.com/lbryio/lbry.go/stop"
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
)
|
||||
|
||||
// TODO: if routing table is ever empty (aka the node is isolated), it should re-bootstrap
|
||||
|
||||
// TODO: use a tree with bucket splitting instead of a fixed bucket list. include jack's optimization (see link in commit mesg)
|
||||
// https://github.com/lbryio/lbry/pull/1211/commits/341b27b6d21ac027671d42458826d02735aaae41
|
||||
|
||||
// peer is a contact with extra information
|
||||
type peer struct {
|
||||
Contact Contact
|
||||
Distance bits.Bitmap
|
||||
LastActivity time.Time
|
||||
// LastReplied time.Time
|
||||
// LastRequested time.Time
|
||||
// LastFailure time.Time
|
||||
// SecondLastFailure time.Time
|
||||
NumFailures int
|
||||
|
||||
//<lastPublished>,
|
||||
//<originallyPublished>
|
||||
// <originalPublisherID>
|
||||
}
|
||||
|
||||
func (p *peer) Touch() {
|
||||
p.LastActivity = time.Now()
|
||||
p.NumFailures = 0
|
||||
}
|
||||
|
||||
// ActiveSince returns whether a peer has responded in the last `d` duration
|
||||
// this is used to check if the peer is "good", meaning that we believe the peer will respond to our requests
|
||||
func (p *peer) ActiveInLast(d time.Duration) bool {
|
||||
return time.Since(p.LastActivity) < d
|
||||
}
|
||||
|
||||
// IsBad returns whether a peer is "bad", meaning that it has failed to respond to multiple pings in a row
|
||||
func (p *peer) IsBad(maxFalures int) bool {
|
||||
return p.NumFailures >= maxFalures
|
||||
}
|
||||
|
||||
// Fail marks a peer as having failed to respond. It returns whether or not the peer should be removed from the routing table
|
||||
func (p *peer) Fail() {
|
||||
p.NumFailures++
|
||||
}
|
||||
|
||||
type bucket struct {
|
||||
lock *sync.RWMutex
|
||||
peers []peer
|
||||
lastUpdate time.Time
|
||||
Range bits.Range // capitalized because `range` is a keyword
|
||||
}
|
||||
|
||||
func newBucket(r bits.Range) *bucket {
|
||||
return &bucket{
|
||||
peers: make([]peer, 0, bucketSize),
|
||||
lock: &sync.RWMutex{},
|
||||
Range: r,
|
||||
}
|
||||
}
|
||||
|
||||
// Len returns the number of peers in the bucket
|
||||
func (b bucket) Len() int {
|
||||
b.lock.RLock()
|
||||
defer b.lock.RUnlock()
|
||||
return len(b.peers)
|
||||
}
|
||||
|
||||
func (b bucket) Has(c Contact) bool {
|
||||
b.lock.RLock()
|
||||
defer b.lock.RUnlock()
|
||||
for _, p := range b.peers {
|
||||
if p.Contact.Equals(c, true) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Contacts returns a slice of the bucket's contacts
|
||||
func (b bucket) Contacts() []Contact {
|
||||
b.lock.RLock()
|
||||
defer b.lock.RUnlock()
|
||||
contacts := make([]Contact, len(b.peers))
|
||||
for i := range b.peers {
|
||||
contacts[i] = b.peers[i].Contact
|
||||
}
|
||||
return contacts
|
||||
}
|
||||
|
||||
// UpdatePeer marks a contact as having been successfully contacted. if insertIfNew and the contact is does not exist yet, it is inserted
|
||||
func (b *bucket) UpdatePeer(p peer, insertIfNew bool) error {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
if !b.Range.Contains(p.Distance) {
|
||||
return errors.Err("this bucket range does not cover this peer")
|
||||
}
|
||||
|
||||
peerIndex := find(p.Contact.ID, b.peers)
|
||||
if peerIndex >= 0 {
|
||||
b.lastUpdate = time.Now()
|
||||
b.peers[peerIndex].Touch()
|
||||
moveToBack(b.peers, peerIndex)
|
||||
} else if insertIfNew {
|
||||
hasRoom := true
|
||||
|
||||
if len(b.peers) >= bucketSize {
|
||||
hasRoom = false
|
||||
for i := range b.peers {
|
||||
if b.peers[i].IsBad(maxPeerFails) {
|
||||
// TODO: Ping contact first. Only remove if it does not respond
|
||||
b.peers = append(b.peers[:i], b.peers[i+1:]...)
|
||||
hasRoom = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasRoom {
|
||||
b.lastUpdate = time.Now()
|
||||
p.Touch()
|
||||
b.peers = append(b.peers, p)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FailContact marks a contact as having failed, and removes it if it failed too many times
|
||||
func (b *bucket) FailContact(id bits.Bitmap) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
i := find(id, b.peers)
|
||||
if i >= 0 {
|
||||
// BEP5 says not to remove the contact until the bucket is full and you try to insert
|
||||
b.peers[i].Fail()
|
||||
}
|
||||
}
|
||||
|
||||
// find returns the contact in the bucket, or nil if the bucket does not contain the contact
|
||||
func find(id bits.Bitmap, peers []peer) int {
|
||||
for i := range peers {
|
||||
if peers[i].Contact.ID.Equals(id) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// NeedsRefresh returns true if bucket has not been updated in the last `refreshInterval`, false otherwise
|
||||
func (b *bucket) NeedsRefresh(refreshInterval time.Duration) bool {
|
||||
b.lock.RLock()
|
||||
defer b.lock.RUnlock()
|
||||
return time.Since(b.lastUpdate) > refreshInterval
|
||||
}
|
||||
|
||||
func (b *bucket) Split() (*bucket, *bucket) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
left := newBucket(b.Range.IntervalP(1, 2))
|
||||
right := newBucket(b.Range.IntervalP(2, 2))
|
||||
left.lastUpdate = b.lastUpdate
|
||||
right.lastUpdate = b.lastUpdate
|
||||
|
||||
for _, p := range b.peers {
|
||||
if left.Range.Contains(p.Distance) {
|
||||
left.peers = append(left.peers, p)
|
||||
} else {
|
||||
right.peers = append(right.peers, p)
|
||||
}
|
||||
}
|
||||
|
||||
if len(b.peers) > 1 {
|
||||
if len(left.peers) == 0 {
|
||||
left, right = right.Split()
|
||||
left.Range.Start = b.Range.Start
|
||||
} else if len(right.peers) == 0 {
|
||||
left, right = left.Split()
|
||||
right.Range.End = b.Range.End
|
||||
}
|
||||
}
|
||||
|
||||
return left, right
|
||||
}
|
||||
|
||||
type routingTable struct {
|
||||
id bits.Bitmap
|
||||
buckets []*bucket
|
||||
mu *sync.RWMutex // this mutex is write-locked only when CHANGING THE NUMBER OF BUCKETS in the table
|
||||
}
|
||||
|
||||
func newRoutingTable(id bits.Bitmap) *routingTable {
|
||||
rt := routingTable{
|
||||
id: id,
|
||||
mu: &sync.RWMutex{},
|
||||
}
|
||||
rt.reset()
|
||||
return &rt
|
||||
}
|
||||
|
||||
func (rt *routingTable) reset() {
|
||||
rt.mu.Lock()
|
||||
defer rt.mu.Unlock()
|
||||
rt.buckets = []*bucket{newBucket(bits.MaxRange())}
|
||||
}
|
||||
|
||||
func (rt *routingTable) BucketInfo() string {
|
||||
rt.mu.RLock()
|
||||
defer rt.mu.RUnlock()
|
||||
|
||||
var bucketInfo []string
|
||||
for i, b := range rt.buckets {
|
||||
if b.Len() > 0 {
|
||||
contacts := b.Contacts()
|
||||
s := make([]string, len(contacts))
|
||||
for j, c := range contacts {
|
||||
s[j] = c.ID.HexShort()
|
||||
}
|
||||
bucketInfo = append(bucketInfo, fmt.Sprintf("bucket %d: (%d) %s", i, len(contacts), strings.Join(s, ", ")))
|
||||
}
|
||||
}
|
||||
if len(bucketInfo) == 0 {
|
||||
return "buckets are empty"
|
||||
}
|
||||
return strings.Join(bucketInfo, "\n")
|
||||
}
|
||||
|
||||
// Update inserts or refreshes a contact
|
||||
func (rt *routingTable) Update(c Contact) {
|
||||
rt.mu.Lock() // write lock, because updates may cause bucket splits
|
||||
defer rt.mu.Unlock()
|
||||
|
||||
b := rt.bucketFor(c.ID)
|
||||
|
||||
if rt.shouldSplit(b, c) {
|
||||
left, right := b.Split()
|
||||
|
||||
for i := range rt.buckets {
|
||||
if rt.buckets[i].Range.Start.Equals(left.Range.Start) {
|
||||
rt.buckets = append(rt.buckets[:i], append([]*bucket{left, right}, rt.buckets[i+1:]...)...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if left.Range.Contains(c.ID) {
|
||||
b = left
|
||||
} else {
|
||||
b = right
|
||||
}
|
||||
}
|
||||
|
||||
err := b.UpdatePeer(peer{Contact: c, Distance: rt.id.Xor(c.ID)}, true)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fresh refreshes a contact if its already in the routing table
|
||||
func (rt *routingTable) Fresh(c Contact) {
|
||||
rt.mu.RLock()
|
||||
defer rt.mu.RUnlock()
|
||||
err := rt.bucketFor(c.ID).UpdatePeer(peer{Contact: c, Distance: rt.id.Xor(c.ID)}, false)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// FailContact marks a contact as having failed, and removes it if it failed too many times
|
||||
func (rt *routingTable) Fail(c Contact) {
|
||||
rt.mu.RLock()
|
||||
defer rt.mu.RUnlock()
|
||||
rt.bucketFor(c.ID).FailContact(c.ID)
|
||||
}
|
||||
|
||||
// GetClosest returns the closest `limit` contacts from the routing table.
|
||||
// This is a locking wrapper around getClosest()
|
||||
func (rt *routingTable) GetClosest(target bits.Bitmap, limit int) []Contact {
|
||||
rt.mu.RLock()
|
||||
defer rt.mu.RUnlock()
|
||||
return rt.getClosest(target, limit)
|
||||
}
|
||||
|
||||
// getClosest returns the closest `limit` contacts from the routing table
|
||||
func (rt *routingTable) getClosest(target bits.Bitmap, limit int) []Contact {
|
||||
var contacts []Contact
|
||||
for _, b := range rt.buckets {
|
||||
contacts = append(contacts, b.Contacts()...)
|
||||
}
|
||||
|
||||
sortByDistance(contacts, target)
|
||||
if len(contacts) > limit {
|
||||
contacts = contacts[:limit]
|
||||
}
|
||||
|
||||
return contacts
|
||||
}
|
||||
|
||||
// Count returns the number of contacts in the routing table
|
||||
func (rt *routingTable) Count() int {
|
||||
rt.mu.RLock()
|
||||
defer rt.mu.RUnlock()
|
||||
count := 0
|
||||
for _, bucket := range rt.buckets {
|
||||
count += bucket.Len()
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// Len returns the number of buckets in the routing table
|
||||
func (rt *routingTable) Len() int {
|
||||
rt.mu.RLock()
|
||||
defer rt.mu.RUnlock()
|
||||
return len(rt.buckets)
|
||||
}
|
||||
|
||||
func (rt *routingTable) bucketFor(target bits.Bitmap) *bucket {
|
||||
if rt.id.Equals(target) {
|
||||
panic("routing table does not have a bucket for its own id")
|
||||
}
|
||||
distance := target.Xor(rt.id)
|
||||
for _, b := range rt.buckets {
|
||||
if b.Range.Contains(distance) {
|
||||
return b
|
||||
}
|
||||
}
|
||||
panic("target is not contained in any buckets")
|
||||
}
|
||||
|
||||
func (rt *routingTable) shouldSplit(b *bucket, c Contact) bool {
|
||||
if b.Has(c) {
|
||||
return false
|
||||
}
|
||||
if b.Len() >= bucketSize {
|
||||
if b.Range.Start.Equals(bits.Bitmap{}) { // this is the bucket covering our node id
|
||||
return true
|
||||
}
|
||||
kClosest := rt.getClosest(rt.id, bucketSize)
|
||||
kthClosest := kClosest[len(kClosest)-1]
|
||||
if rt.id.Closer(c.ID, kthClosest.ID) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
//func (rt *routingTable) printBucketInfo() {
|
||||
// fmt.Printf("there are %d contacts in %d buckets\n", rt.Count(), rt.Len())
|
||||
// for i, b := range rt.buckets {
|
||||
// fmt.Printf("bucket %d, %d contacts\n", i+1, len(b.peers))
|
||||
// fmt.Printf(" start : %s\n", b.Range.Start.String())
|
||||
// fmt.Printf(" stop : %s\n", b.Range.End.String())
|
||||
// fmt.Println("")
|
||||
// }
|
||||
//}
|
||||
|
||||
func (rt *routingTable) GetIDsForRefresh(refreshInterval time.Duration) []bits.Bitmap {
|
||||
var bitmaps []bits.Bitmap
|
||||
for i, bucket := range rt.buckets {
|
||||
if bucket.NeedsRefresh(refreshInterval) {
|
||||
bitmaps = append(bitmaps, bits.Rand().Prefix(i, false))
|
||||
}
|
||||
}
|
||||
return bitmaps
|
||||
}
|
||||
|
||||
const rtContactSep = "-"
|
||||
|
||||
type rtSave struct {
|
||||
ID string `json:"id"`
|
||||
Contacts []string `json:"contacts"`
|
||||
}
|
||||
|
||||
func (rt *routingTable) MarshalJSON() ([]byte, error) {
|
||||
var data rtSave
|
||||
data.ID = rt.id.Hex()
|
||||
for _, b := range rt.buckets {
|
||||
for _, c := range b.Contacts() {
|
||||
data.Contacts = append(data.Contacts, strings.Join([]string{c.ID.Hex(), c.IP.String(), strconv.Itoa(c.Port)}, rtContactSep))
|
||||
}
|
||||
}
|
||||
return json.Marshal(data)
|
||||
}
|
||||
|
||||
func (rt *routingTable) UnmarshalJSON(b []byte) error {
|
||||
var data rtSave
|
||||
err := json.Unmarshal(b, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rt.id, err = bits.FromHex(data.ID)
|
||||
if err != nil {
|
||||
return errors.Prefix("decoding ID", err)
|
||||
}
|
||||
rt.reset()
|
||||
|
||||
for _, s := range data.Contacts {
|
||||
parts := strings.Split(s, rtContactSep)
|
||||
if len(parts) != 3 {
|
||||
return errors.Err("decoding contact %s: wrong number of parts", s)
|
||||
}
|
||||
var c Contact
|
||||
c.ID, err = bits.FromHex(parts[0])
|
||||
if err != nil {
|
||||
return errors.Err("decoding contact %s: invalid ID: %s", s, err)
|
||||
}
|
||||
c.IP = net.ParseIP(parts[1])
|
||||
if c.IP == nil {
|
||||
return errors.Err("decoding contact %s: invalid IP", s)
|
||||
}
|
||||
c.Port, err = strconv.Atoi(parts[2])
|
||||
if err != nil {
|
||||
return errors.Err("decoding contact %s: invalid port: %s", s, err)
|
||||
}
|
||||
rt.Update(c)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RoutingTableRefresh refreshes any buckets that need to be refreshed
|
||||
func RoutingTableRefresh(n *Node, refreshInterval time.Duration, parentGrp *stop.Group) {
|
||||
done := stop.New()
|
||||
|
||||
for _, id := range n.rt.GetIDsForRefresh(refreshInterval) {
|
||||
done.Add(1)
|
||||
go func(id bits.Bitmap) {
|
||||
defer done.Done()
|
||||
_, _, err := FindContacts(n, id, false, parentGrp)
|
||||
if err != nil {
|
||||
log.Error("error finding contact during routing table refresh - ", err)
|
||||
}
|
||||
}(id)
|
||||
}
|
||||
|
||||
done.Wait()
|
||||
done.Stop()
|
||||
}
|
||||
|
||||
func moveToBack(peers []peer, index int) {
|
||||
if index < 0 || len(peers) <= index+1 {
|
||||
return
|
||||
}
|
||||
p := peers[index]
|
||||
for i := index; i < len(peers)-1; i++ {
|
||||
peers[i] = peers[i+1]
|
||||
}
|
||||
peers[len(peers)-1] = p
|
||||
}
|
328
dht/routing_table_test.go
Normal file
328
dht/routing_table_test.go
Normal file
|
@ -0,0 +1,328 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math/big"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
|
||||
"github.com/sebdah/goldie"
|
||||
)
|
||||
|
||||
func TestBucket_Split(t *testing.T) {
|
||||
rt := newRoutingTable(bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"))
|
||||
if len(rt.buckets) != 1 {
|
||||
t.Errorf("there should only be one bucket so far")
|
||||
}
|
||||
if len(rt.buckets[0].peers) != 0 {
|
||||
t.Errorf("there should be no contacts yet")
|
||||
}
|
||||
|
||||
var tests = []struct {
|
||||
name string
|
||||
id bits.Bitmap
|
||||
expectedBucketCount int
|
||||
expectedTotalContacts int
|
||||
}{
|
||||
//fill first bucket
|
||||
{"b1-one", bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100"), 1, 1},
|
||||
{"b1-two", bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000200"), 1, 2},
|
||||
{"b1-three", bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000300"), 1, 3},
|
||||
{"b1-four", bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000400"), 1, 4},
|
||||
{"b1-five", bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000500"), 1, 5},
|
||||
{"b1-six", bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000600"), 1, 6},
|
||||
{"b1-seven", bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000700"), 1, 7},
|
||||
{"b1-eight", bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800"), 1, 8},
|
||||
|
||||
// split off second bucket and fill it
|
||||
{"b2-one", bits.FromHexP("001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 2, 9},
|
||||
{"b2-two", bits.FromHexP("002000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 2, 10},
|
||||
{"b2-three", bits.FromHexP("003000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 2, 11},
|
||||
{"b2-four", bits.FromHexP("004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 2, 12},
|
||||
{"b2-five", bits.FromHexP("005000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 2, 13},
|
||||
{"b2-six", bits.FromHexP("006000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 2, 14},
|
||||
{"b2-seven", bits.FromHexP("007000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 2, 15},
|
||||
|
||||
// at this point there are two buckets. the first has 7 contacts, the second has 8
|
||||
|
||||
// inserts into the second bucket should be skipped
|
||||
{"dont-split", bits.FromHexP("009000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 2, 15},
|
||||
|
||||
// ... unless the ID is closer than the kth-closest contact
|
||||
{"split-kth-closest", bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"), 2, 16},
|
||||
|
||||
{"b3-two", bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"), 3, 17},
|
||||
{"b3-three", bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"), 3, 18},
|
||||
}
|
||||
|
||||
for i, testCase := range tests {
|
||||
rt.Update(Contact{testCase.id, net.ParseIP("127.0.0.1"), 8000 + i, 0})
|
||||
|
||||
if len(rt.buckets) != testCase.expectedBucketCount {
|
||||
t.Errorf("failed test case %s. there should be %d buckets, got %d", testCase.name, testCase.expectedBucketCount, len(rt.buckets))
|
||||
}
|
||||
if rt.Count() != testCase.expectedTotalContacts {
|
||||
t.Errorf("failed test case %s. there should be %d contacts, got %d", testCase.name, testCase.expectedTotalContacts, rt.Count())
|
||||
}
|
||||
}
|
||||
|
||||
var testRanges = []struct {
|
||||
id bits.Bitmap
|
||||
expected int
|
||||
}{
|
||||
{bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"), 0},
|
||||
{bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005"), 0},
|
||||
{bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000410"), 1},
|
||||
{bits.FromHexP("0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000007f0"), 1},
|
||||
{bits.FromHexP("F00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800"), 2},
|
||||
{bits.FromHexP("F00000000000000000000000000000000000000000000000000F00000000000000000000000000000000000000000000"), 2},
|
||||
{bits.FromHexP("F0000000000000000000000000000000F0000000000000000000000000F0000000000000000000000000000000000000"), 2},
|
||||
}
|
||||
|
||||
for _, tt := range testRanges {
|
||||
bucket := bucketNumFor(rt, tt.id)
|
||||
if bucket != tt.expected {
|
||||
t.Errorf("bucketFor(%s, %s) => got %d, expected %d", tt.id.Hex(), rt.id.Hex(), bucket, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func bucketNumFor(rt *routingTable, target bits.Bitmap) int {
|
||||
if rt.id.Equals(target) {
|
||||
panic("routing table does not have a bucket for its own id")
|
||||
}
|
||||
distance := target.Xor(rt.id)
|
||||
for i := range rt.buckets {
|
||||
if rt.buckets[i].Range.Contains(distance) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
panic("target is not contained in any buckets")
|
||||
}
|
||||
|
||||
func TestBucket_Split_Continuous(t *testing.T) {
|
||||
b := newBucket(bits.MaxRange())
|
||||
|
||||
left, right := b.Split()
|
||||
|
||||
if !left.Range.Start.Equals(b.Range.Start) {
|
||||
t.Errorf("left bucket start does not align with original bucket start. got %s, expected %s", left.Range.Start, b.Range.Start)
|
||||
}
|
||||
|
||||
if !right.Range.End.Equals(b.Range.End) {
|
||||
t.Errorf("right bucket end does not align with original bucket end. got %s, expected %s", right.Range.End, b.Range.End)
|
||||
}
|
||||
|
||||
leftEndNext := (&big.Int{}).Add(left.Range.End.Big(), big.NewInt(1))
|
||||
if !bits.FromBigP(leftEndNext).Equals(right.Range.Start) {
|
||||
t.Errorf("there's a gap between left bucket end and right bucket start. end is %s, start is %s", left.Range.End, right.Range.Start)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBucket_Split_KthClosest_DoSplit(t *testing.T) {
|
||||
rt := newRoutingTable(bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"))
|
||||
|
||||
// add 4 low IDs
|
||||
rt.Update(Contact{bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"), net.ParseIP("127.0.0.1"), 8001, 0})
|
||||
rt.Update(Contact{bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"), net.ParseIP("127.0.0.1"), 8002, 0})
|
||||
rt.Update(Contact{bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"), net.ParseIP("127.0.0.1"), 8003, 0})
|
||||
rt.Update(Contact{bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004"), net.ParseIP("127.0.0.1"), 8004, 0})
|
||||
|
||||
// add 4 high IDs
|
||||
rt.Update(Contact{bits.FromHexP("800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8001, 0})
|
||||
rt.Update(Contact{bits.FromHexP("900000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8002, 0})
|
||||
rt.Update(Contact{bits.FromHexP("a00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8003, 0})
|
||||
rt.Update(Contact{bits.FromHexP("b00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8004, 0})
|
||||
|
||||
// split the bucket and fill the high bucket
|
||||
rt.Update(Contact{bits.FromHexP("c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8005, 0})
|
||||
rt.Update(Contact{bits.FromHexP("d00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8006, 0})
|
||||
rt.Update(Contact{bits.FromHexP("e00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8007, 0})
|
||||
rt.Update(Contact{bits.FromHexP("f00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8008, 0})
|
||||
|
||||
// add a high ID. it should split because the high ID is closer than the Kth closest ID
|
||||
rt.Update(Contact{bits.FromHexP("910000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.1"), 8009, 0})
|
||||
|
||||
if len(rt.buckets) != 3 {
|
||||
t.Errorf("expected 3 buckets, got %d", len(rt.buckets))
|
||||
}
|
||||
if rt.Count() != 13 {
|
||||
t.Errorf("expected 13 contacts, got %d", rt.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBucket_Split_KthClosest_DontSplit(t *testing.T) {
|
||||
rt := newRoutingTable(bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"))
|
||||
|
||||
// add 4 low IDs
|
||||
rt.Update(Contact{bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"), net.ParseIP("127.0.0.1"), 8001, 0})
|
||||
rt.Update(Contact{bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"), net.ParseIP("127.0.0.1"), 8002, 0})
|
||||
rt.Update(Contact{bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"), net.ParseIP("127.0.0.1"), 8003, 0})
|
||||
rt.Update(Contact{bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004"), net.ParseIP("127.0.0.1"), 8004, 0})
|
||||
|
||||
// add 4 high IDs
|
||||
rt.Update(Contact{bits.FromHexP("800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8001, 0})
|
||||
rt.Update(Contact{bits.FromHexP("900000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8002, 0})
|
||||
rt.Update(Contact{bits.FromHexP("a00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8003, 0})
|
||||
rt.Update(Contact{bits.FromHexP("b00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8004, 0})
|
||||
|
||||
// split the bucket and fill the high bucket
|
||||
rt.Update(Contact{bits.FromHexP("c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8005, 0})
|
||||
rt.Update(Contact{bits.FromHexP("d00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8006, 0})
|
||||
rt.Update(Contact{bits.FromHexP("e00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8007, 0})
|
||||
rt.Update(Contact{bits.FromHexP("f00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.2"), 8008, 0})
|
||||
|
||||
// add a really high ID. this should not split because its not closer than the Kth closest ID
|
||||
rt.Update(Contact{bits.FromHexP("ffff00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), net.ParseIP("127.0.0.1"), 8009, 0})
|
||||
|
||||
if len(rt.buckets) != 2 {
|
||||
t.Errorf("expected 2 buckets, got %d", len(rt.buckets))
|
||||
}
|
||||
if rt.Count() != 12 {
|
||||
t.Errorf("expected 12 contacts, got %d", rt.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutingTable_GetClosest(t *testing.T) {
|
||||
n1 := bits.FromHexP("FFFFFFFF0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
n2 := bits.FromHexP("FFFFFFF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
n3 := bits.FromHexP("111111110000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
rt := newRoutingTable(n1)
|
||||
rt.Update(Contact{n2, net.ParseIP("127.0.0.1"), 8001, 0})
|
||||
rt.Update(Contact{n3, net.ParseIP("127.0.0.1"), 8002, 0})
|
||||
|
||||
contacts := rt.GetClosest(bits.FromHexP("222222220000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 1)
|
||||
if len(contacts) != 1 {
|
||||
t.Fail()
|
||||
return
|
||||
}
|
||||
if !contacts[0].ID.Equals(n3) {
|
||||
t.Error(contacts[0])
|
||||
}
|
||||
contacts = rt.GetClosest(n2, 10)
|
||||
if len(contacts) != 2 {
|
||||
t.Error(len(contacts))
|
||||
return
|
||||
}
|
||||
if !contacts[0].ID.Equals(n2) {
|
||||
t.Error(contacts[0])
|
||||
}
|
||||
if !contacts[1].ID.Equals(n3) {
|
||||
t.Error(contacts[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutingTable_GetClosest_Empty(t *testing.T) {
|
||||
n1 := bits.FromShortHexP("1")
|
||||
rt := newRoutingTable(n1)
|
||||
|
||||
contacts := rt.GetClosest(bits.FromShortHexP("a"), 3)
|
||||
if len(contacts) != 0 {
|
||||
t.Error("there shouldn't be any contacts")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutingTable_Refresh(t *testing.T) {
|
||||
t.Skip("TODO: test routing table refreshing")
|
||||
}
|
||||
|
||||
func TestRoutingTable_MoveToBack(t *testing.T) {
|
||||
tt := map[string]struct {
|
||||
data []peer
|
||||
index int
|
||||
expected []peer
|
||||
}{
|
||||
"simpleMove": {
|
||||
data: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
|
||||
index: 1,
|
||||
expected: []peer{{NumFailures: 0}, {NumFailures: 2}, {NumFailures: 3}, {NumFailures: 1}},
|
||||
},
|
||||
"moveFirst": {
|
||||
data: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
|
||||
index: 0,
|
||||
expected: []peer{{NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}, {NumFailures: 0}},
|
||||
},
|
||||
"moveLast": {
|
||||
data: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
|
||||
index: 3,
|
||||
expected: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
|
||||
},
|
||||
"largeIndex": {
|
||||
data: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
|
||||
index: 27,
|
||||
expected: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
|
||||
},
|
||||
"negativeIndex": {
|
||||
data: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
|
||||
index: -12,
|
||||
expected: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tt {
|
||||
moveToBack(test.data, test.index)
|
||||
expected := make([]string, len(test.expected))
|
||||
actual := make([]string, len(test.data))
|
||||
for i := range actual {
|
||||
actual[i] = strconv.Itoa(test.data[i].NumFailures)
|
||||
expected[i] = strconv.Itoa(test.expected[i].NumFailures)
|
||||
}
|
||||
|
||||
expJoin := strings.Join(expected, ",")
|
||||
actJoin := strings.Join(actual, ",")
|
||||
|
||||
if actJoin != expJoin {
|
||||
t.Errorf("%s failed: got %s; expected %s", name, actJoin, expJoin)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutingTable_Save(t *testing.T) {
|
||||
t.Skip("fix me")
|
||||
id := bits.FromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41")
|
||||
rt := newRoutingTable(id)
|
||||
|
||||
for i, b := range rt.buckets {
|
||||
for j := 0; j < bucketSize; j++ {
|
||||
toAdd := b.Range.Start.Add(bits.FromShortHexP(strconv.Itoa(j)))
|
||||
if toAdd.Cmp(b.Range.End) <= 0 {
|
||||
rt.Update(Contact{
|
||||
ID: b.Range.Start.Add(bits.FromShortHexP(strconv.Itoa(j))),
|
||||
IP: net.ParseIP("1.2.3." + strconv.Itoa(j)),
|
||||
Port: 1 + i*bucketSize + j,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(rt, "", " ")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
goldie.Assert(t, t.Name(), data)
|
||||
}
|
||||
|
||||
func TestRoutingTable_Load_ID(t *testing.T) {
|
||||
t.Skip("fix me")
|
||||
id := "1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41"
|
||||
data := []byte(`{"id": "` + id + `","contacts": []}`)
|
||||
|
||||
rt := routingTable{}
|
||||
err := json.Unmarshal(data, &rt)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if rt.id.Hex() != id {
|
||||
t.Error("id mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutingTable_Load_Contacts(t *testing.T) {
|
||||
t.Skip("TODO")
|
||||
}
|
187
dht/rpc.go
Normal file
187
dht/rpc.go
Normal file
|
@ -0,0 +1,187 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/lbryio/lbry.go/errors"
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
rpc2 "github.com/gorilla/rpc/v2"
|
||||
"github.com/gorilla/rpc/v2/json"
|
||||
)
|
||||
|
||||
type rpcReceiver struct {
|
||||
dht *DHT
|
||||
}
|
||||
|
||||
type RpcPingArgs struct {
|
||||
Address string
|
||||
}
|
||||
|
||||
func (rpc *rpcReceiver) Ping(r *http.Request, args *RpcPingArgs, result *string) error {
|
||||
if args.Address == "" {
|
||||
return errors.Err("no address given")
|
||||
}
|
||||
|
||||
err := rpc.dht.Ping(args.Address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*result = pingSuccessResponse
|
||||
return nil
|
||||
}
|
||||
|
||||
type RpcFindArgs struct {
|
||||
Key string
|
||||
NodeID string
|
||||
IP string
|
||||
Port int
|
||||
}
|
||||
|
||||
func (rpc *rpcReceiver) FindNode(r *http.Request, args *RpcFindArgs, result *[]Contact) error {
|
||||
key, err := bits.FromHex(args.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
toQuery, err := bits.FromHex(args.NodeID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c := Contact{ID: toQuery, IP: net.ParseIP(args.IP), Port: args.Port}
|
||||
req := Request{Method: findNodeMethod, Arg: &key}
|
||||
|
||||
nodeResponse := rpc.dht.node.Send(c, req)
|
||||
if nodeResponse != nil && nodeResponse.Contacts != nil {
|
||||
*result = nodeResponse.Contacts
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RpcFindValueResult struct {
|
||||
Contacts []Contact
|
||||
Value string
|
||||
}
|
||||
|
||||
func (rpc *rpcReceiver) FindValue(r *http.Request, args *RpcFindArgs, result *RpcFindValueResult) error {
|
||||
key, err := bits.FromHex(args.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
toQuery, err := bits.FromHex(args.NodeID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c := Contact{ID: toQuery, IP: net.ParseIP(args.IP), Port: args.Port}
|
||||
req := Request{Arg: &key, Method: findValueMethod}
|
||||
|
||||
nodeResponse := rpc.dht.node.Send(c, req)
|
||||
if nodeResponse != nil && nodeResponse.FindValueKey != "" {
|
||||
*result = RpcFindValueResult{Value: nodeResponse.FindValueKey}
|
||||
return nil
|
||||
}
|
||||
if nodeResponse != nil && nodeResponse.Contacts != nil {
|
||||
*result = RpcFindValueResult{Contacts: nodeResponse.Contacts}
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.Err("not sure what happened")
|
||||
}
|
||||
|
||||
type RpcIterativeFindValueArgs struct {
|
||||
Key string
|
||||
}
|
||||
|
||||
type RpcIterativeFindValueResult struct {
|
||||
Contacts []Contact
|
||||
FoundValue bool
|
||||
}
|
||||
|
||||
func (rpc *rpcReceiver) IterativeFindValue(r *http.Request, args *RpcIterativeFindValueArgs, result *RpcIterativeFindValueResult) error {
|
||||
key, err := bits.FromHex(args.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
foundContacts, found, err := FindContacts(rpc.dht.node, key, false, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result.Contacts = foundContacts
|
||||
result.FoundValue = found
|
||||
return nil
|
||||
}
|
||||
|
||||
type RpcBucketResponse struct {
|
||||
Start string
|
||||
End string
|
||||
NumContacts int
|
||||
Contacts []Contact
|
||||
}
|
||||
|
||||
type RpcRoutingTableResponse struct {
|
||||
NodeID string
|
||||
NumBuckets int
|
||||
Buckets []RpcBucketResponse
|
||||
}
|
||||
|
||||
func (rpc *rpcReceiver) GetRoutingTable(r *http.Request, args *struct{}, result *RpcRoutingTableResponse) error {
|
||||
result.NodeID = rpc.dht.node.id.String()
|
||||
result.NumBuckets = len(rpc.dht.node.rt.buckets)
|
||||
for _, b := range rpc.dht.node.rt.buckets {
|
||||
result.Buckets = append(result.Buckets, RpcBucketResponse{
|
||||
Start: b.Range.Start.String(),
|
||||
End: b.Range.End.String(),
|
||||
NumContacts: b.Len(),
|
||||
Contacts: b.Contacts(),
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rpc *rpcReceiver) AddKnownNode(r *http.Request, args *Contact, result *string) error {
|
||||
rpc.dht.node.AddKnownNode(*args)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dht *DHT) runRPCServer(port int) {
|
||||
addr := "0.0.0.0:" + strconv.Itoa(port)
|
||||
|
||||
s := rpc2.NewServer()
|
||||
s.RegisterCodec(json.NewCodec(), "application/json")
|
||||
s.RegisterCodec(json.NewCodec(), "application/json;charset=UTF-8")
|
||||
err := s.RegisterService(&rpcReceiver{dht: dht}, "rpc")
|
||||
if err != nil {
|
||||
log.Error(errors.Prefix("registering rpc service", err))
|
||||
return
|
||||
}
|
||||
|
||||
handler := mux.NewRouter()
|
||||
handler.Handle("/", s)
|
||||
server := &http.Server{Addr: addr, Handler: handler}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
log.Printf("RPC server listening on %s", addr)
|
||||
err := server.ListenAndServe()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
log.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-dht.grp.Ch()
|
||||
err = server.Shutdown(context.Background())
|
||||
if err != nil {
|
||||
log.Error(errors.Prefix("shutting down rpc service", err))
|
||||
return
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
62
dht/store.go
Normal file
62
dht/store.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
)
|
||||
|
||||
// TODO: expire stored data after tExpire time
|
||||
|
||||
type contactStore struct {
|
||||
// map of blob hashes to (map of node IDs to bools)
|
||||
hashes map[bits.Bitmap]map[bits.Bitmap]bool
|
||||
// stores the peers themselves, so they can be updated in one place
|
||||
contacts map[bits.Bitmap]Contact
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func newStore() *contactStore {
|
||||
return &contactStore{
|
||||
hashes: make(map[bits.Bitmap]map[bits.Bitmap]bool),
|
||||
contacts: make(map[bits.Bitmap]Contact),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *contactStore) Upsert(blobHash bits.Bitmap, contact Contact) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if _, ok := s.hashes[blobHash]; !ok {
|
||||
s.hashes[blobHash] = make(map[bits.Bitmap]bool)
|
||||
}
|
||||
s.hashes[blobHash][contact.ID] = true
|
||||
s.contacts[contact.ID] = contact
|
||||
}
|
||||
|
||||
func (s *contactStore) Get(blobHash bits.Bitmap) []Contact {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
var contacts []Contact
|
||||
if ids, ok := s.hashes[blobHash]; ok {
|
||||
for id := range ids {
|
||||
contact, ok := s.contacts[id]
|
||||
if !ok {
|
||||
panic("node id in IDs list, but not in nodeInfo")
|
||||
}
|
||||
contacts = append(contacts, contact)
|
||||
}
|
||||
}
|
||||
return contacts
|
||||
}
|
||||
|
||||
func (s *contactStore) RemoveTODO(contact Contact) {
|
||||
// TODO: remove peer from everywhere
|
||||
}
|
||||
|
||||
func (s *contactStore) CountStoredHashes() int {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
return len(s.hashes)
|
||||
}
|
312
dht/testing.go
Normal file
312
dht/testing.go
Normal file
|
@ -0,0 +1,312 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lbryio/lbry.go/errors"
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
)
|
||||
|
||||
var testingDHTIP = "127.0.0.1"
|
||||
var testingDHTFirstPort = 21000
|
||||
|
||||
// TestingCreateNetwork initializes a testable DHT network with a specific number of nodes, with bootstrap and concurrent options.
|
||||
func TestingCreateNetwork(t *testing.T, numNodes int, bootstrap, concurrent bool) (*BootstrapNode, []*DHT) {
|
||||
var bootstrapNode *BootstrapNode
|
||||
var seeds []string
|
||||
|
||||
if bootstrap {
|
||||
bootstrapAddress := testingDHTIP + ":" + strconv.Itoa(testingDHTFirstPort)
|
||||
seeds = []string{bootstrapAddress}
|
||||
bootstrapNode = NewBootstrapNode(bits.Rand(), 0, bootstrapDefaultRefreshDuration)
|
||||
listener, err := net.ListenPacket(Network, bootstrapAddress)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = bootstrapNode.Connect(listener.(*net.UDPConn))
|
||||
if err != nil {
|
||||
t.Error("error connecting bootstrap node - ", err)
|
||||
}
|
||||
}
|
||||
|
||||
if numNodes < 1 {
|
||||
return bootstrapNode, nil
|
||||
}
|
||||
|
||||
firstPort := testingDHTFirstPort + 1
|
||||
dhts := make([]*DHT, numNodes)
|
||||
|
||||
for i := 0; i < numNodes; i++ {
|
||||
c := NewStandardConfig()
|
||||
c.NodeID = bits.Rand().Hex()
|
||||
c.Address = testingDHTIP + ":" + strconv.Itoa(firstPort+i)
|
||||
c.SeedNodes = seeds
|
||||
dht := New(c)
|
||||
|
||||
go func() {
|
||||
err := dht.Start()
|
||||
if err != nil {
|
||||
t.Error("error starting dht - ", err)
|
||||
}
|
||||
}()
|
||||
if !concurrent {
|
||||
dht.WaitUntilJoined()
|
||||
}
|
||||
dhts[i] = dht
|
||||
}
|
||||
|
||||
if concurrent {
|
||||
for _, d := range dhts {
|
||||
d.WaitUntilJoined()
|
||||
}
|
||||
}
|
||||
|
||||
return bootstrapNode, dhts
|
||||
}
|
||||
|
||||
type timeoutErr struct {
|
||||
error
|
||||
}
|
||||
|
||||
func (t timeoutErr) Timeout() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t timeoutErr) Temporary() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO: just use a normal net.Conn instead of this mock conn
|
||||
|
||||
type testUDPPacket struct {
|
||||
data []byte
|
||||
addr *net.UDPAddr
|
||||
}
|
||||
|
||||
type testUDPConn struct {
|
||||
addr *net.UDPAddr
|
||||
toRead chan testUDPPacket
|
||||
writes chan testUDPPacket
|
||||
|
||||
readDeadline time.Time
|
||||
}
|
||||
|
||||
func newTestUDPConn(addr string) *testUDPConn {
|
||||
parts := strings.Split(addr, ":")
|
||||
if len(parts) != 2 {
|
||||
panic("addr needs ip and port")
|
||||
}
|
||||
port, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &testUDPConn{
|
||||
addr: &net.UDPAddr{IP: net.IP(parts[0]), Port: port},
|
||||
toRead: make(chan testUDPPacket),
|
||||
writes: make(chan testUDPPacket),
|
||||
}
|
||||
}
|
||||
|
||||
func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
|
||||
var timeoutCh <-chan time.Time
|
||||
if !t.readDeadline.IsZero() {
|
||||
timeoutCh = time.After(time.Until(t.readDeadline))
|
||||
}
|
||||
|
||||
select {
|
||||
case packet, ok := <-t.toRead:
|
||||
if !ok {
|
||||
return 0, nil, errors.Err("conn closed")
|
||||
}
|
||||
n := copy(b, packet.data)
|
||||
return n, packet.addr, nil
|
||||
case <-timeoutCh:
|
||||
return 0, nil, timeoutErr{errors.Err("timeout")}
|
||||
}
|
||||
}
|
||||
|
||||
func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
|
||||
t.writes <- testUDPPacket{data: b, addr: addr}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (t *testUDPConn) SetReadDeadline(tm time.Time) error {
|
||||
t.readDeadline = tm
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testUDPConn) SetWriteDeadline(tm time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testUDPConn) Close() error {
|
||||
close(t.toRead)
|
||||
t.writes = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyResponse(t *testing.T, resp map[string]interface{}, id messageID, dhtNodeID string) {
|
||||
if len(resp) != 4 {
|
||||
t.Errorf("expected 4 response fields, got %d", len(resp))
|
||||
}
|
||||
|
||||
_, ok := resp[headerTypeField]
|
||||
if !ok {
|
||||
t.Error("missing type field")
|
||||
} else {
|
||||
rType, ok := resp[headerTypeField].(int64)
|
||||
if !ok {
|
||||
t.Error("type is not an integer")
|
||||
} else if rType != responseType {
|
||||
t.Error("unexpected response type")
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = resp[headerMessageIDField]
|
||||
if !ok {
|
||||
t.Error("missing message id field")
|
||||
} else {
|
||||
rMessageID, ok := resp[headerMessageIDField].(string)
|
||||
if !ok {
|
||||
t.Error("message ID is not a string")
|
||||
} else if rMessageID != string(id[:]) {
|
||||
t.Error("unexpected message ID")
|
||||
}
|
||||
if len(rMessageID) != messageIDLength {
|
||||
t.Errorf("message ID should be %d chars long", messageIDLength)
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = resp[headerNodeIDField]
|
||||
if !ok {
|
||||
t.Error("missing node id field")
|
||||
} else {
|
||||
rNodeID, ok := resp[headerNodeIDField].(string)
|
||||
if !ok {
|
||||
t.Error("node ID is not a string")
|
||||
} else if rNodeID != dhtNodeID {
|
||||
t.Error("unexpected node ID")
|
||||
}
|
||||
if len(rNodeID) != nodeIDLength {
|
||||
t.Errorf("node ID should be %d chars long", nodeIDLength)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func verifyContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
|
||||
if len(contacts) != len(nodes) {
|
||||
t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
|
||||
return
|
||||
}
|
||||
|
||||
foundNodes := make(map[string]bool)
|
||||
|
||||
for _, c := range contacts {
|
||||
contact, ok := c.([]interface{})
|
||||
if !ok {
|
||||
t.Error("contact is not a list")
|
||||
return
|
||||
}
|
||||
|
||||
if len(contact) != 3 {
|
||||
t.Error("contact must be 3 items")
|
||||
return
|
||||
}
|
||||
|
||||
var currNode Contact
|
||||
currNodeFound := false
|
||||
|
||||
id, ok := contact[0].(string)
|
||||
if !ok {
|
||||
t.Error("contact id is not a string")
|
||||
} else {
|
||||
if _, ok := foundNodes[id]; ok {
|
||||
t.Errorf("contact %s appears multiple times", id)
|
||||
continue
|
||||
}
|
||||
for _, n := range nodes {
|
||||
if n.ID.RawString() == id {
|
||||
currNode = n
|
||||
currNodeFound = true
|
||||
foundNodes[id] = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !currNodeFound {
|
||||
t.Errorf("unexpected contact %s", id)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
ip, ok := contact[1].(string)
|
||||
if !ok {
|
||||
t.Error("contact IP is not a string")
|
||||
} else if !currNode.IP.Equal(net.ParseIP(ip)) {
|
||||
t.Errorf("contact IP mismatch. got %s; expected %s", ip, currNode.IP.String())
|
||||
}
|
||||
|
||||
port, ok := contact[2].(int64)
|
||||
if !ok {
|
||||
t.Error("contact port is not an int")
|
||||
} else if int(port) != currNode.Port {
|
||||
t.Errorf("contact port mismatch. got %d; expected %d", port, currNode.Port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func verifyCompactContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
|
||||
if len(contacts) != len(nodes) {
|
||||
t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
|
||||
return
|
||||
}
|
||||
|
||||
foundNodes := make(map[string]bool)
|
||||
|
||||
for _, c := range contacts {
|
||||
compact, ok := c.(string)
|
||||
if !ok {
|
||||
t.Error("contact is not a string")
|
||||
return
|
||||
}
|
||||
|
||||
contact := Contact{}
|
||||
err := contact.UnmarshalCompact([]byte(compact))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var currNode Contact
|
||||
currNodeFound := false
|
||||
|
||||
if _, ok := foundNodes[contact.ID.Hex()]; ok {
|
||||
t.Errorf("contact %s appears multiple times", contact.ID.Hex())
|
||||
continue
|
||||
}
|
||||
for _, n := range nodes {
|
||||
if n.ID.Equals(contact.ID) {
|
||||
currNode = n
|
||||
currNodeFound = true
|
||||
foundNodes[contact.ID.Hex()] = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !currNodeFound {
|
||||
t.Errorf("unexpected contact %s", contact.ID.Hex())
|
||||
continue
|
||||
}
|
||||
|
||||
if !currNode.IP.Equal(contact.IP) {
|
||||
t.Errorf("contact IP mismatch. got %s; expected %s", contact.IP.String(), currNode.IP.String())
|
||||
}
|
||||
|
||||
if contact.Port != currNode.Port {
|
||||
t.Errorf("contact port mismatch. got %d; expected %d", contact.Port, currNode.Port)
|
||||
}
|
||||
}
|
||||
}
|
71
dht/token_cache.go
Normal file
71
dht/token_cache.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
|
||||
"github.com/lbryio/lbry.go/stop"
|
||||
)
|
||||
|
||||
// TODO: this should be moved out of dht and into node, and it should be completely hidden inside node. dht should not need to know about tokens
|
||||
|
||||
type tokenCacheEntry struct {
|
||||
token string
|
||||
receivedAt time.Time
|
||||
}
|
||||
|
||||
type tokenCache struct {
|
||||
node *Node
|
||||
tokens map[string]tokenCacheEntry
|
||||
expiration time.Duration
|
||||
lock *sync.RWMutex
|
||||
}
|
||||
|
||||
func newTokenCache(node *Node, expiration time.Duration) *tokenCache {
|
||||
tc := &tokenCache{}
|
||||
tc.node = node
|
||||
tc.tokens = make(map[string]tokenCacheEntry)
|
||||
tc.expiration = expiration
|
||||
tc.lock = &sync.RWMutex{}
|
||||
return tc
|
||||
}
|
||||
|
||||
// TODO: if store fails, get new token. can happen if a node restarts but we have the token cached
|
||||
|
||||
func (tc *tokenCache) Get(c Contact, hash bits.Bitmap, cancelCh stop.Chan) string {
|
||||
tc.lock.RLock()
|
||||
token, exists := tc.tokens[c.String()]
|
||||
tc.lock.RUnlock()
|
||||
|
||||
if exists && time.Since(token.receivedAt) < tc.expiration {
|
||||
return token.token
|
||||
}
|
||||
|
||||
resCh := tc.node.SendAsync(c, Request{
|
||||
Method: findValueMethod,
|
||||
Arg: &hash,
|
||||
})
|
||||
|
||||
var res *Response
|
||||
|
||||
select {
|
||||
case res = <-resCh:
|
||||
case <-cancelCh:
|
||||
return ""
|
||||
}
|
||||
|
||||
if res == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
tc.lock.Lock()
|
||||
tc.tokens[c.String()] = tokenCacheEntry{
|
||||
token: res.Token,
|
||||
receivedAt: time.Now(),
|
||||
}
|
||||
tc.lock.Unlock()
|
||||
|
||||
return res.Token
|
||||
}
|
78
dht/token_manager.go
Normal file
78
dht/token_manager.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lbryio/lbry.go/stop"
|
||||
"github.com/lbryio/reflector.go/dht/bits"
|
||||
)
|
||||
|
||||
type tokenManager struct {
|
||||
secret []byte
|
||||
prevSecret []byte
|
||||
lock *sync.RWMutex
|
||||
stop *stop.Group
|
||||
}
|
||||
|
||||
func (tm *tokenManager) Start(interval time.Duration) {
|
||||
tm.secret = make([]byte, 64)
|
||||
tm.prevSecret = make([]byte, 64)
|
||||
tm.lock = &sync.RWMutex{}
|
||||
tm.stop = stop.New()
|
||||
|
||||
tm.rotateSecret()
|
||||
|
||||
tm.stop.Add(1)
|
||||
go func() {
|
||||
defer tm.stop.Done()
|
||||
tick := time.NewTicker(interval)
|
||||
for {
|
||||
select {
|
||||
case <-tick.C:
|
||||
tm.rotateSecret()
|
||||
case <-tm.stop.Ch():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (tm *tokenManager) Stop() {
|
||||
tm.stop.StopAndWait()
|
||||
}
|
||||
|
||||
func (tm *tokenManager) Get(nodeID bits.Bitmap, addr *net.UDPAddr) string {
|
||||
return genToken(tm.secret, nodeID, addr)
|
||||
}
|
||||
|
||||
func (tm *tokenManager) Verify(token string, nodeID bits.Bitmap, addr *net.UDPAddr) bool {
|
||||
return token == genToken(tm.secret, nodeID, addr) || token == genToken(tm.prevSecret, nodeID, addr)
|
||||
}
|
||||
|
||||
func genToken(secret []byte, nodeID bits.Bitmap, addr *net.UDPAddr) string {
|
||||
buf := bytes.Buffer{}
|
||||
buf.Write(nodeID[:])
|
||||
buf.Write(addr.IP)
|
||||
buf.WriteString(strconv.Itoa(addr.Port))
|
||||
buf.Write(secret)
|
||||
t := sha256.Sum256(buf.Bytes())
|
||||
return string(t[:])
|
||||
}
|
||||
|
||||
func (tm *tokenManager) rotateSecret() {
|
||||
tm.lock.Lock()
|
||||
defer tm.lock.Unlock()
|
||||
|
||||
copy(tm.prevSecret, tm.secret)
|
||||
|
||||
_, err := rand.Read(tm.secret)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue