added routing table saving, bitmap operations, lots of tests

This commit is contained in:
Alex Grintsvayg 2018-05-19 13:05:30 -04:00
parent b9ee0b0644
commit 14cceda81e
15 changed files with 3858 additions and 244 deletions

View file

@ -1,8 +1,10 @@
package dht package dht
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"strings"
"github.com/lbryio/errors.go" "github.com/lbryio/errors.go"
"github.com/lyoshenka/bencode" "github.com/lyoshenka/bencode"
@ -14,6 +16,19 @@ func (b Bitmap) RawString() string {
return string(b[:]) return string(b[:])
} }
// BString returns the bitmap as a string of 0s and 1s
func (b Bitmap) BString() string {
var buf bytes.Buffer
for i := 0; i < nodeIDBits; i++ {
if b.Get(i) {
buf.WriteString("1")
} else {
buf.WriteString("0")
}
}
return buf.String()
}
func (b Bitmap) Hex() string { func (b Bitmap) Hex() string {
return hex.EncodeToString(b[:]) return hex.EncodeToString(b[:])
} }
@ -22,6 +37,14 @@ func (b Bitmap) HexShort() string {
return hex.EncodeToString(b[:4]) return hex.EncodeToString(b[:4])
} }
func (b Bitmap) HexSimplified() string {
simple := strings.TrimLeft(b.Hex(), "0")
if simple == "" {
simple = "0"
}
return simple
}
func (b Bitmap) Equals(other Bitmap) bool { func (b Bitmap) Equals(other Bitmap) bool {
for k := range b { for k := range b {
if b[k] != other[k] { if b[k] != other[k] {
@ -40,6 +63,35 @@ func (b Bitmap) Less(other interface{}) bool {
return false return false
} }
func (b Bitmap) LessOrEqual(other interface{}) bool {
if bm, ok := other.(Bitmap); ok && b.Equals(bm) {
return true
}
return b.Less(other)
}
func (b Bitmap) Greater(other interface{}) bool {
for k := range b {
if b[k] != other.(Bitmap)[k] {
return b[k] > other.(Bitmap)[k]
}
}
return false
}
func (b Bitmap) GreaterOrEqual(other interface{}) bool {
if bm, ok := other.(Bitmap); ok && b.Equals(bm) {
return true
}
return b.Greater(other)
}
func (b Bitmap) Copy() Bitmap {
var ret Bitmap
copy(ret[:], b[:])
return ret
}
func (b Bitmap) Xor(other Bitmap) Bitmap { func (b Bitmap) Xor(other Bitmap) Bitmap {
var ret Bitmap var ret Bitmap
for k := range b { for k := range b {
@ -48,6 +100,69 @@ func (b Bitmap) Xor(other Bitmap) Bitmap {
return ret return ret
} }
func (b Bitmap) And(other Bitmap) Bitmap {
var ret Bitmap
for k := range b {
ret[k] = b[k] & other[k]
}
return ret
}
func (b Bitmap) Or(other Bitmap) Bitmap {
var ret Bitmap
for k := range b {
ret[k] = b[k] | other[k]
}
return ret
}
func (b Bitmap) Not() Bitmap {
var ret Bitmap
for k := range b {
ret[k] = ^b[k]
}
return ret
}
func (b Bitmap) add(other Bitmap) (Bitmap, bool) {
var ret Bitmap
carry := false
for i := nodeIDBits - 1; i >= 0; i-- {
bBit := getBit(b[:], i)
oBit := getBit(other[:], i)
setBit(ret[:], i, bBit != oBit != carry)
carry = (bBit && oBit) || (bBit && carry) || (oBit && carry)
}
return ret, carry
}
func (b Bitmap) Add(other Bitmap) Bitmap {
ret, carry := b.add(other)
if carry {
panic("overflow in bitmap addition")
}
return ret
}
func (b Bitmap) Sub(other Bitmap) Bitmap {
if b.Less(other) {
panic("negative bitmaps not supported")
}
complement, _ := other.Not().add(BitmapFromShortHexP("1"))
ret, _ := b.add(complement)
return ret
}
func (b Bitmap) Get(n int) bool {
return getBit(b[:], n)
}
func (b Bitmap) Set(n int, one bool) Bitmap {
ret := b.Copy()
setBit(ret[:], n, one)
return ret
}
// PrefixLen returns the number of leading 0 bits // PrefixLen returns the number of leading 0 bits
func (b Bitmap) PrefixLen() int { func (b Bitmap) PrefixLen() int {
for i := range b { for i := range b {
@ -57,20 +172,46 @@ func (b Bitmap) PrefixLen() int {
} }
} }
} }
return numBuckets return nodeIDBits
} }
// ZeroPrefix returns a copy of b with the first n bits set to 0 // Prefix returns a copy of b with the first n bits set to 1 (if `one` is true) or 0 (if `one` is false)
// https://stackoverflow.com/a/23192263/182709 // https://stackoverflow.com/a/23192263/182709
func (b Bitmap) ZeroPrefix(n int) Bitmap { func (b Bitmap) Prefix(n int, one bool) Bitmap {
var ret Bitmap ret := b.Copy()
copy(ret[:], b[:])
Outer: Outer:
for i := range ret { for i := range ret {
for j := 0; j < 8; j++ { for j := 0; j < 8; j++ {
if i*8+j < n { if i*8+j < n {
if one {
ret[i] |= 1 << uint(7-j)
} else {
ret[i] &= ^(1 << uint(7-j)) ret[i] &= ^(1 << uint(7-j))
}
} else {
break Outer
}
}
}
return ret
}
// Syffix returns a copy of b with the last n bits set to 1 (if `one` is true) or 0 (if `one` is false)
// https://stackoverflow.com/a/23192263/182709
func (b Bitmap) Suffix(n int, one bool) Bitmap {
ret := b.Copy()
Outer:
for i := len(ret) - 1; i >= 0; i-- {
for j := 7; j >= 0; j-- {
if i*8+j >= nodeIDBits-n {
if one {
ret[i] |= 1 << uint(7-j)
} else {
ret[i] &= ^(1 << uint(7-j))
}
} else { } else {
break Outer break Outer
} }
@ -145,6 +286,18 @@ func BitmapFromHexP(hexStr string) Bitmap {
return bmp return bmp
} }
func BitmapFromShortHex(hexStr string) (Bitmap, error) {
return BitmapFromHex(strings.Repeat("0", nodeIDLength*2-len(hexStr)) + hexStr)
}
func BitmapFromShortHexP(hexStr string) Bitmap {
bmp, err := BitmapFromShortHex(hexStr)
if err != nil {
panic(err)
}
return bmp
}
func RandomBitmapP() Bitmap { func RandomBitmapP() Bitmap {
var id Bitmap var id Bitmap
_, err := rand.Read(id[:]) _, err := rand.Read(id[:])
@ -153,3 +306,28 @@ func RandomBitmapP() Bitmap {
} }
return id return id
} }
func RandomBitmapInRangeP(low, high Bitmap) Bitmap {
diff := high.Sub(low)
r := RandomBitmapP()
for r.Greater(diff) {
r = r.Sub(diff)
}
return r.Add(low)
}
func getBit(b []byte, n int) bool {
i := n / 8
j := n % 8
return b[i]&(1<<uint(7-j)) > 0
}
func setBit(b []byte, n int, one bool) {
i := n / 8
j := n % 8
if one {
b[i] |= 1 << uint(7-j)
} else {
b[i] &= ^(1 << uint(7-j))
}
}

View file

@ -1,6 +1,7 @@
package dht package dht
import ( import (
"fmt"
"testing" "testing"
"github.com/lyoshenka/bencode" "github.com/lyoshenka/bencode"
@ -51,6 +52,84 @@ func TestBitmap(t *testing.T) {
} }
} }
func TestBitmap_GetBit(t *testing.T) {
tt := []struct {
hex string
bit int
expected bool
panic bool
}{
//{hex: "0", bit: 385, one: true, expected: "1", panic:true}, // should error
//{hex: "0", bit: 384, one: true, expected: "1", panic:true},
{bit: 383, expected: false, panic: false},
{bit: 382, expected: true, panic: false},
{bit: 381, expected: false, panic: false},
{bit: 380, expected: true, panic: false},
}
b := BitmapFromShortHexP("a")
for _, test := range tt {
actual := getBit(b[:], test.bit)
if test.expected != actual {
t.Errorf("getting bit %d of %s: expected %t, got %t", test.bit, b.HexSimplified(), test.expected, actual)
}
}
}
func TestBitmap_SetBit(t *testing.T) {
tt := []struct {
hex string
bit int
one bool
expected string
panic bool
}{
{hex: "0", bit: 383, one: true, expected: "1", panic: false},
{hex: "0", bit: 382, one: true, expected: "2", panic: false},
{hex: "0", bit: 381, one: true, expected: "4", panic: false},
{hex: "0", bit: 385, one: true, expected: "1", panic: true},
{hex: "0", bit: 384, one: true, expected: "1", panic: true},
}
for _, test := range tt {
expected := BitmapFromShortHexP(test.expected)
actual := BitmapFromShortHexP(test.hex)
if test.panic {
assertPanic(t, fmt.Sprintf("setting bit %d to %t", test.bit, test.one), func() { setBit(actual[:], test.bit, test.one) })
} else {
setBit(actual[:], test.bit, test.one)
if !expected.Equals(actual) {
t.Errorf("setting bit %d to %t: expected %s, got %s", test.bit, test.one, test.expected, actual.HexSimplified())
}
}
}
}
func TestBitmap_FromHexShort(t *testing.T) {
tt := []struct {
short string
long string
}{
{short: "", long: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{short: "0", long: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{short: "00000", long: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{short: "9473745bc", long: "0000000000000000000000000000000000000000000000000000000000000000000000000000000000000009473745bc"},
{short: "09473745bc", long: "0000000000000000000000000000000000000000000000000000000000000000000000000000000000000009473745bc"},
{short: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
long: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"},
}
for _, test := range tt {
short := BitmapFromShortHexP(test.short)
long := BitmapFromHexP(test.long)
if !short.Equals(long) {
t.Errorf("short hex %s: expected %s, got %s", test.short, long.Hex(), short.Hex())
}
}
}
func TestBitmapMarshal(t *testing.T) { func TestBitmapMarshal(t *testing.T) {
b := BitmapFromStringP("123456789012345678901234567890123456789012345678") b := BitmapFromStringP("123456789012345678901234567890123456789012345678")
encoded, err := bencode.EncodeBytes(b) encoded, err := bencode.EncodeBytes(b)
@ -120,10 +199,10 @@ func TestBitmap_PrefixLen(t *testing.T) {
} }
} }
func TestBitmap_ZeroPrefix(t *testing.T) { func TestBitmap_Prefix(t *testing.T) {
original := BitmapFromHexP("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") allOne := BitmapFromHexP("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
tt := []struct { zerosTT := []struct {
zeros int zeros int
expected string expected string
}{ }{
@ -136,18 +215,162 @@ func TestBitmap_ZeroPrefix(t *testing.T) {
{zeros: 400, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"}, {zeros: 400, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
} }
for _, test := range tt { for _, test := range zerosTT {
expected := BitmapFromHexP(test.expected) expected := BitmapFromHexP(test.expected)
actual := original.ZeroPrefix(test.zeros) actual := allOne.Prefix(test.zeros, false)
if !actual.Equals(expected) { if !actual.Equals(expected) {
t.Errorf("%d zeros: got %s; expected %s", test.zeros, actual.Hex(), expected.Hex()) t.Errorf("%d zeros: got %s; expected %s", test.zeros, actual.Hex(), expected.Hex())
} }
} }
for i := 0; i < nodeIDLength*8; i++ { for i := 0; i < nodeIDLength*8; i++ {
b := original.ZeroPrefix(i) b := allOne.Prefix(i, false)
if b.PrefixLen() != i { if b.PrefixLen() != i {
t.Errorf("got prefix len %d; expected %d for %s", b.PrefixLen(), i, b.Hex()) t.Errorf("got prefix len %d; expected %d for %s", b.PrefixLen(), i, b.Hex())
} }
} }
allZero := BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
onesTT := []struct {
ones int
expected string
}{
{ones: -123, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{ones: 0, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{ones: 1, expected: "800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{ones: 69, expected: "fffffffffffffffff8000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{ones: 383, expected: "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"},
{ones: 384, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
{ones: 400, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
}
for _, test := range onesTT {
expected := BitmapFromHexP(test.expected)
actual := allZero.Prefix(test.ones, true)
if !actual.Equals(expected) {
t.Errorf("%d ones: got %s; expected %s", test.ones, actual.Hex(), expected.Hex())
}
}
}
func TestBitmap_Suffix(t *testing.T) {
allOne := BitmapFromHexP("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
zerosTT := []struct {
zeros int
expected string
}{
{zeros: -123, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
{zeros: 0, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
{zeros: 1, expected: "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"},
{zeros: 69, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe00000000000000000"},
{zeros: 383, expected: "800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{zeros: 384, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{zeros: 400, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
}
for _, test := range zerosTT {
expected := BitmapFromHexP(test.expected)
actual := allOne.Suffix(test.zeros, false)
if !actual.Equals(expected) {
t.Errorf("%d zeros: got %s; expected %s", test.zeros, actual.Hex(), expected.Hex())
}
}
for i := 0; i < nodeIDLength*8; i++ {
b := allOne.Prefix(i, false)
if b.PrefixLen() != i {
t.Errorf("got prefix len %d; expected %d for %s", b.PrefixLen(), i, b.Hex())
}
}
allZero := BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
onesTT := []struct {
ones int
expected string
}{
{ones: -123, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{ones: 0, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{ones: 1, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"},
{ones: 69, expected: "0000000000000000000000000000000000000000000000000000000000000000000000000000001fffffffffffffffff"},
{ones: 383, expected: "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
{ones: 384, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
{ones: 400, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
}
for _, test := range onesTT {
expected := BitmapFromHexP(test.expected)
actual := allZero.Suffix(test.ones, true)
if !actual.Equals(expected) {
t.Errorf("%d ones: got %s; expected %s", test.ones, actual.Hex(), expected.Hex())
}
}
}
func TestBitmap_Add(t *testing.T) {
tt := []struct {
a, b, sum string
panic bool
}{
{"0", "0", "0", false},
{"0", "1", "1", false},
{"1", "0", "1", false},
{"1", "1", "2", false},
{"8", "4", "c", false},
{"1000", "0010", "1010", false},
{"1111", "1111", "2222", false},
{"ffff", "1", "10000", false},
{"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", false},
{"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "1", "", true},
}
for _, test := range tt {
a := BitmapFromShortHexP(test.a)
b := BitmapFromShortHexP(test.b)
expected := BitmapFromShortHexP(test.sum)
if test.panic {
assertPanic(t, fmt.Sprintf("adding %s and %s", test.a, test.b), func() { a.Add(b) })
} else {
actual := a.Add(b)
if !expected.Equals(actual) {
t.Errorf("adding %s and %s; expected %s, got %s", test.a, test.b, test.sum, actual.HexSimplified())
}
}
}
}
func TestBitmap_Sub(t *testing.T) {
tt := []struct {
a, b, sum string
panic bool
}{
{"0", "0", "0", false},
{"1", "0", "1", false},
{"1", "1", "0", false},
{"8", "4", "4", false},
{"f", "9", "6", false},
{"f", "e", "1", false},
{"10", "f", "1", false},
{"2222", "1111", "1111", false},
{"ffff", "1", "fffe", false},
{"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", false},
{"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0", false},
{"0", "1", "", true},
}
for _, test := range tt {
a := BitmapFromShortHexP(test.a)
b := BitmapFromShortHexP(test.b)
expected := BitmapFromShortHexP(test.sum)
if test.panic {
assertPanic(t, fmt.Sprintf("subtracting %s - %s", test.a, test.b), func() { a.Sub(b) })
} else {
actual := a.Sub(b)
if !expected.Equals(actual) {
t.Errorf("subtracting %s - %s; expected %s, got %s", test.a, test.b, test.sum, actual.HexSimplified())
}
}
}
} }

View file

@ -1,7 +1,6 @@
package dht package dht
import ( import (
"context"
"math/rand" "math/rand"
"net" "net"
"sync" "sync"
@ -14,25 +13,6 @@ const (
bootstrapDefaultRefreshDuration = 15 * time.Minute bootstrapDefaultRefreshDuration = 15 * time.Minute
) )
type nullStore struct{}
func (n nullStore) Upsert(id Bitmap, c Contact) {}
func (n nullStore) Get(id Bitmap) []Contact { return nil }
func (n nullStore) CountStoredHashes() int { return 0 }
type nullRoutingTable struct{}
// TODO: the bootstrap logic *could* be implemented just in the routing table, without a custom request handler
// TODO: the only tricky part is triggering the ping when Fresh is called, as the rt doesnt have access to the node
func (n nullRoutingTable) Update(c Contact) {} // this
func (n nullRoutingTable) Fresh(c Contact) {} // this
func (n nullRoutingTable) Fail(c Contact) {} // this
func (n nullRoutingTable) GetClosest(id Bitmap, limit int) []Contact { return nil } // this
func (n nullRoutingTable) Count() int { return 0 }
func (n nullRoutingTable) GetIDsForRefresh(d time.Duration) []Bitmap { return nil }
func (n nullRoutingTable) BucketInfo() string { return "" }
type BootstrapNode struct { type BootstrapNode struct {
Node Node
@ -57,8 +37,6 @@ func NewBootstrapNode(id Bitmap, initialPingInterval, rePingInterval time.Durati
nodeKeys: make(map[Bitmap]int), nodeKeys: make(map[Bitmap]int),
} }
b.rt = &nullRoutingTable{}
b.store = &nullStore{}
b.requestHandler = b.handleRequest b.requestHandler = b.handleRequest
return b return b
@ -98,14 +76,14 @@ func (b *BootstrapNode) upsert(c Contact) {
b.nlock.Lock() b.nlock.Lock()
defer b.nlock.Unlock() defer b.nlock.Unlock()
if i, exists := b.nodeKeys[c.id]; exists { if i, exists := b.nodeKeys[c.ID]; exists {
log.Debugf("[%s] bootstrap: touching contact %s", b.id.HexShort(), b.nodes[i].contact.id.HexShort()) log.Debugf("[%s] bootstrap: touching contact %s", b.id.HexShort(), b.nodes[i].Contact.ID.HexShort())
b.nodes[i].Touch() b.nodes[i].Touch()
return return
} }
log.Debugf("[%s] bootstrap: adding new contact %s", b.id.HexShort(), c.id.HexShort()) log.Debugf("[%s] bootstrap: adding new contact %s", b.id.HexShort(), c.ID.HexShort())
b.nodeKeys[c.id] = len(b.nodes) b.nodeKeys[c.ID] = len(b.nodes)
b.nodes = append(b.nodes, peer{c, time.Now(), 0}) b.nodes = append(b.nodes, peer{c, time.Now(), 0})
} }
@ -114,14 +92,14 @@ func (b *BootstrapNode) remove(c Contact) {
b.nlock.Lock() b.nlock.Lock()
defer b.nlock.Unlock() defer b.nlock.Unlock()
i, exists := b.nodeKeys[c.id] i, exists := b.nodeKeys[c.ID]
if !exists { if !exists {
return return
} }
log.Debugf("[%s] bootstrap: removing contact %s", b.id.HexShort(), c.id.HexShort()) log.Debugf("[%s] bootstrap: removing contact %s", b.id.HexShort(), c.ID.HexShort())
b.nodes = append(b.nodes[:i], b.nodes[i+1:]...) b.nodes = append(b.nodes[:i], b.nodes[i+1:]...)
delete(b.nodeKeys, c.id) delete(b.nodeKeys, c.ID)
} }
// get returns up to `limit` random contacts from the list // get returns up to `limit` random contacts from the list
@ -135,7 +113,7 @@ func (b *BootstrapNode) get(limit int) []Contact {
ret := make([]Contact, limit) ret := make([]Contact, limit)
for i, k := range randKeys(len(b.nodes))[:limit] { for i, k := range randKeys(len(b.nodes))[:limit] {
ret[i] = b.nodes[k].contact ret[i] = b.nodes[k].Contact
} }
return ret return ret
@ -146,8 +124,7 @@ func (b *BootstrapNode) ping(c Contact) {
b.stopWG.Add(1) b.stopWG.Add(1)
defer b.stopWG.Done() defer b.stopWG.Done()
ctx, cancel := context.WithCancel(context.Background()) resCh, cancel := b.SendCancelable(c, Request{Method: pingMethod})
resCh := b.SendAsync(ctx, c, Request{Method: pingMethod})
var res *Response var res *Response
@ -171,7 +148,7 @@ func (b *BootstrapNode) check() {
for i := range b.nodes { for i := range b.nodes {
if !b.nodes[i].ActiveInLast(b.checkInterval) { if !b.nodes[i].ActiveInLast(b.checkInterval) {
go b.ping(b.nodes[i].contact) go b.ping(b.nodes[i].Contact)
} }
} }
} }
@ -196,7 +173,7 @@ func (b *BootstrapNode) handleRequest(addr *net.UDPAddr, request Request) {
go func() { go func() {
log.Debugf("[%s] bootstrap: queuing %s to ping", b.id.HexShort(), request.NodeID.HexShort()) log.Debugf("[%s] bootstrap: queuing %s to ping", b.id.HexShort(), request.NodeID.HexShort())
<-time.After(b.initialPingInterval) <-time.After(b.initialPingInterval)
b.ping(Contact{id: request.NodeID, ip: addr.IP, port: addr.Port}) b.ping(Contact{ID: request.NodeID, IP: addr.IP, Port: addr.Port})
}() }()
} }

View file

@ -1,7 +1,6 @@
package dht package dht
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"strings" "strings"
@ -28,6 +27,7 @@ const (
alpha = 3 // this is the constant alpha in the spec alpha = 3 // this is the constant alpha in the spec
bucketSize = 8 // this is the constant k in the spec bucketSize = 8 // this is the constant k in the spec
nodeIDLength = 48 // bytes. this is the constant B in the spec nodeIDLength = 48 // bytes. this is the constant B in the spec
nodeIDBits = nodeIDLength * 8 // number of bits in node ID
messageIDLength = 20 // bytes. messageIDLength = 20 // bytes.
udpRetry = 3 udpRetry = 3
@ -42,7 +42,6 @@ const (
tRepublish = 24 * time.Hour // the time after which the original publisher must republish a key/value pair tRepublish = 24 * time.Hour // the time after which the original publisher must republish a key/value pair
tNodeRefresh = 15 * time.Minute // the time after which a good node becomes questionable if it has not messaged us tNodeRefresh = 15 * time.Minute // the time after which a good node becomes questionable if it has not messaged us
numBuckets = nodeIDLength * 8
compactNodeInfoLength = nodeIDLength + 6 // nodeID + 4 for IP + 2 for port compactNodeInfoLength = nodeIDLength + 6 // nodeID + 4 for IP + 2 for port
tokenSecretRotationInterval = 5 * time.Minute // how often the token-generating secret is rotated tokenSecretRotationInterval = 5 * time.Minute // how often the token-generating secret is rotated
@ -102,7 +101,7 @@ func New(config *Config) (*DHT, error) {
d := &DHT{ d := &DHT{
conf: config, conf: config,
contact: contact, contact: contact,
node: NewNode(contact.id), node: NewNode(contact.ID),
stop: stopOnce.New(), stop: stopOnce.New(),
stopWG: &sync.WaitGroup{}, stopWG: &sync.WaitGroup{},
joined: make(chan struct{}), joined: make(chan struct{}),
@ -181,7 +180,7 @@ func (dht *DHT) Ping(addr string) error {
return err return err
} }
tmpNode := Contact{id: RandomBitmapP(), ip: raddr.IP, port: raddr.Port} tmpNode := Contact{ID: RandomBitmapP(), IP: raddr.IP, Port: raddr.Port}
res := dht.node.Send(tmpNode, Request{Method: pingMethod}) res := dht.node.Send(tmpNode, Request{Method: pingMethod})
if res == nil { if res == nil {
return errors.Err("no response from node %s", addr) return errors.Err("no response from node %s", addr)
@ -214,19 +213,25 @@ func (dht *DHT) Announce(hash Bitmap) error {
// TODO: if this node is closer than farthest peer, store locally and pop farthest peer // TODO: if this node is closer than farthest peer, store locally and pop farthest peer
for _, node := range res.Contacts { wg := &sync.WaitGroup{}
go dht.storeOnNode(hash, node) for _, c := range res.Contacts {
wg.Add(1)
go func(c Contact) {
dht.storeOnNode(hash, c)
wg.Done()
}(c)
} }
wg.Wait()
return nil return nil
} }
func (dht *DHT) storeOnNode(hash Bitmap, node Contact) { func (dht *DHT) storeOnNode(hash Bitmap, c Contact) {
dht.stopWG.Add(1) dht.stopWG.Add(1)
defer dht.stopWG.Done() defer dht.stopWG.Done()
ctx, cancel := context.WithCancel(context.Background()) resCh, cancel := dht.node.SendCancelable(c, Request{
resCh := dht.node.SendAsync(ctx, node, Request{
Method: findValueMethod, Method: findValueMethod,
Arg: &hash, Arg: &hash,
}) })
@ -244,15 +249,14 @@ func (dht *DHT) storeOnNode(hash Bitmap, node Contact) {
return // request timed out return // request timed out
} }
ctx, cancel = context.WithCancel(context.Background()) resCh, cancel = dht.node.SendCancelable(c, Request{
resCh = dht.node.SendAsync(ctx, node, Request{
Method: storeMethod, Method: storeMethod,
StoreArgs: &storeArgs{ StoreArgs: &storeArgs{
BlobHash: hash, BlobHash: hash,
Value: storeArgsValue{ Value: storeArgsValue{
Token: res.Token, Token: res.Token,
LbryID: dht.contact.id, LbryID: dht.contact.ID,
Port: dht.contact.port, Port: dht.contact.Port,
}, },
}, },
}) })
@ -276,18 +280,12 @@ func (dht *DHT) PrintState() {
} }
} }
func printNodeList(list []Contact) {
for i, n := range list {
log.Printf("%d) %s", i, n.String())
}
}
func getContact(nodeID, addr string) (Contact, error) { func getContact(nodeID, addr string) (Contact, error) {
var c Contact var c Contact
if nodeID == "" { if nodeID == "" {
c.id = RandomBitmapP() c.ID = RandomBitmapP()
} else { } else {
c.id = BitmapFromHexP(nodeID) c.ID = BitmapFromHexP(nodeID)
} }
ip, port, err := net.SplitHostPort(addr) ip, port, err := net.SplitHostPort(addr)
@ -299,12 +297,12 @@ func getContact(nodeID, addr string) (Contact, error) {
return c, errors.Err("address does not contain a port") return c, errors.Err("address does not contain a port")
} }
c.ip = net.ParseIP(ip) c.IP = net.ParseIP(ip)
if c.ip == nil { if c.IP == nil {
return c, errors.Err("invalid ip") return c, errors.Err("invalid ip")
} }
c.port, err = cast.ToIntE(port) c.Port, err = cast.ToIntE(port)
if err != nil { if err != nil {
return c, errors.Err(err) return c, errors.Err(err)
} }

View file

@ -36,13 +36,13 @@ func TestNodeFinder_FindNodes(t *testing.T) {
foundTwo := false foundTwo := false
for _, n := range foundNodes { for _, n := range foundNodes {
if n.id.Equals(bs.id) { if n.ID.Equals(bs.id) {
foundBootstrap = true foundBootstrap = true
} }
if n.id.Equals(dhts[0].node.id) { if n.ID.Equals(dhts[0].node.id) {
foundOne = true foundOne = true
} }
if n.id.Equals(dhts[1].node.id) { if n.ID.Equals(dhts[1].node.id) {
foundTwo = true foundTwo = true
} }
} }
@ -83,7 +83,7 @@ func TestNodeFinder_FindValue(t *testing.T) {
}() }()
blobHashToFind := RandomBitmapP() blobHashToFind := RandomBitmapP()
nodeToFind := Contact{id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4), port: 5678} nodeToFind := Contact{ID: RandomBitmapP(), IP: net.IPv4(1, 2, 3, 4), Port: 5678}
dhts[0].node.store.Upsert(blobHashToFind, nodeToFind) dhts[0].node.store.Upsert(blobHashToFind, nodeToFind)
nf := newContactFinder(dhts[2].node, blobHashToFind, true) nf := newContactFinder(dhts[2].node, blobHashToFind, true)
@ -101,8 +101,8 @@ func TestNodeFinder_FindValue(t *testing.T) {
t.Fatalf("expected one node, found %d", len(foundNodes)) t.Fatalf("expected one node, found %d", len(foundNodes))
} }
if !foundNodes[0].id.Equals(nodeToFind.id) { if !foundNodes[0].ID.Equals(nodeToFind.ID) {
t.Fatalf("found node id %s, expected %s", foundNodes[0].id.Hex(), nodeToFind.id.Hex()) t.Fatalf("found node id %s, expected %s", foundNodes[0].ID.Hex(), nodeToFind.ID.Hex())
} }
} }
@ -139,7 +139,7 @@ func TestDHT_LargeDHT(t *testing.T) {
c := d2.node.rt.GetClosest(d.node.id, 1) c := d2.node.rt.GetClosest(d.node.id, 1)
if len(c) > 1 { if len(c) > 1 {
t.Error("rt returned more than one node when only one requested") t.Error("rt returned more than one node when only one requested")
} else if len(c) == 1 && c[0].id.Equals(d.node.id) { } else if len(c) == 1 && c[0].ID.Equals(d.node.id) {
rtCounts[d.node.id]++ rtCounts[d.node.id]++
} }
} }

File diff suppressed because it is too large Load diff

View file

@ -240,7 +240,7 @@ func (r Response) ArgsDebug() string {
str += "|" str += "|"
for _, c := range r.Contacts { for _, c := range r.Contacts {
str += c.Addr().String() + ":" + c.id.HexShort() + "," str += c.Addr().String() + ":" + c.ID.HexShort() + ","
} }
str = strings.TrimRight(str, ",") + "|" str = strings.TrimRight(str, ",") + "|"

View file

@ -78,8 +78,8 @@ func TestBencodeFindNodesResponse(t *testing.T) {
ID: newMessageID(), ID: newMessageID(),
NodeID: RandomBitmapP(), NodeID: RandomBitmapP(),
Contacts: []Contact{ Contacts: []Contact{
{id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678}, {ID: RandomBitmapP(), IP: net.IPv4(1, 2, 3, 4).To4(), Port: 5678},
{id: RandomBitmapP(), ip: net.IPv4(4, 3, 2, 1).To4(), port: 8765}, {ID: RandomBitmapP(), IP: net.IPv4(4, 3, 2, 1).To4(), Port: 8765},
}, },
} }
@ -104,7 +104,7 @@ func TestBencodeFindValueResponse(t *testing.T) {
FindValueKey: RandomBitmapP().RawString(), FindValueKey: RandomBitmapP().RawString(),
Token: "arst", Token: "arst",
Contacts: []Contact{ Contacts: []Contact{
{id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678}, {ID: RandomBitmapP(), IP: net.IPv4(1, 2, 3, 4).To4(), Port: 5678},
}, },
} }

View file

@ -48,9 +48,9 @@ type Node struct {
transactions map[messageID]*transaction transactions map[messageID]*transaction
// routing table // routing table
rt RoutingTable rt *routingTable
// data store // data store
store Store store *contactStore
// overrides for request handlers // overrides for request handlers
requestHandler RequestHandlerFunc requestHandler RequestHandlerFunc
@ -238,7 +238,7 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) {
// TODO: we should be sending the IP in the request, not just using the sender's IP // TODO: we should be sending the IP in the request, not just using the sender's IP
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ??? // TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
if n.tokens.Verify(request.StoreArgs.Value.Token, request.NodeID, addr) { if n.tokens.Verify(request.StoreArgs.Value.Token, request.NodeID, addr) {
n.store.Upsert(request.StoreArgs.BlobHash, Contact{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port}) n.store.Upsert(request.StoreArgs.BlobHash, Contact{ID: request.StoreArgs.NodeID, IP: addr.IP, Port: request.StoreArgs.Value.Port})
n.sendMessage(addr, Response{ID: request.ID, NodeID: n.id, Data: storeSuccessResponse}) n.sendMessage(addr, Response{ID: request.ID, NodeID: n.id, Data: storeSuccessResponse})
} else { } else {
n.sendMessage(addr, Error{ID: request.ID, NodeID: n.id, ExceptionType: "invalid-token"}) n.sendMessage(addr, Error{ID: request.ID, NodeID: n.id, ExceptionType: "invalid-token"})
@ -280,7 +280,7 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) {
// the routing table must only contain "good" nodes, which are nodes that reply to our requests // the routing table must only contain "good" nodes, which are nodes that reply to our requests
// if a node is already good (aka in the table), its fine to refresh it // if a node is already good (aka in the table), its fine to refresh it
// http://www.bittorrent.org/beps/bep_0005.html#routing-table // http://www.bittorrent.org/beps/bep_0005.html#routing-table
n.rt.Fresh(Contact{id: request.NodeID, ip: addr.IP, port: addr.Port}) n.rt.Fresh(Contact{ID: request.NodeID, IP: addr.IP, Port: addr.Port})
} }
// handleResponse handles responses received from udp. // handleResponse handles responses received from udp.
@ -290,13 +290,13 @@ func (n *Node) handleResponse(addr *net.UDPAddr, response Response) {
tx.res <- response tx.res <- response
} }
n.rt.Update(Contact{id: response.NodeID, ip: addr.IP, port: addr.Port}) n.rt.Update(Contact{ID: response.NodeID, IP: addr.IP, Port: addr.Port})
} }
// handleError handles errors received from udp. // handleError handles errors received from udp.
func (n *Node) handleError(addr *net.UDPAddr, e Error) { func (n *Node) handleError(addr *net.UDPAddr, e Error) {
spew.Dump(e) spew.Dump(e)
n.rt.Fresh(Contact{id: e.NodeID, ip: addr.IP, port: addr.Port}) n.rt.Fresh(Contact{ID: e.NodeID, IP: addr.IP, Port: addr.Port})
} }
// send sends data to a udp address // send sends data to a udp address
@ -361,7 +361,7 @@ func (n *Node) txFind(id messageID, addr *net.UDPAddr) *transaction {
// SendAsync sends a transaction and returns a channel that will eventually contain the transaction response // SendAsync sends a transaction and returns a channel that will eventually contain the transaction response
// The response channel is closed when the transaction is completed or times out. // The response channel is closed when the transaction is completed or times out.
func (n *Node) SendAsync(ctx context.Context, contact Contact, req Request) <-chan *Response { func (n *Node) SendAsync(ctx context.Context, contact Contact, req Request) <-chan *Response {
if contact.id.Equals(n.id) { if contact.ID.Equals(n.id) {
log.Error("sending query to self") log.Error("sending query to self")
return nil return nil
} }
@ -413,6 +413,12 @@ func (n *Node) Send(contact Contact, req Request) *Response {
return <-n.SendAsync(context.Background(), contact, req) return <-n.SendAsync(context.Background(), contact, req)
} }
// SendCancelable sends the transaction asynchronously and allows the transaction to be canceled
func (n *Node) SendCancelable(contact Contact, req Request) (<-chan *Response, context.CancelFunc) {
ctx, cancel := context.WithCancel(context.Background())
return n.SendAsync(ctx, contact, req), cancel
}
// Count returns the number of transactions in the manager // Count returns the number of transactions in the manager
func (n *Node) CountActiveTransactions() int { func (n *Node) CountActiveTransactions() int {
n.txLock.Lock() n.txLock.Lock()

View file

@ -1,7 +1,6 @@
package dht package dht
import ( import (
"context"
"sort" "sort"
"sync" "sync"
"time" "time"
@ -113,7 +112,7 @@ func (cf *contactFinder) iterationWorker(num int) {
} else { } else {
contact := *maybeContact contact := *maybeContact
if contact.id.Equals(cf.node.id) { if contact.ID.Equals(cf.node.id) {
continue // cannot contact self continue // cannot contact self
} }
@ -124,13 +123,12 @@ func (cf *contactFinder) iterationWorker(num int) {
req.Method = findNodeMethod req.Method = findNodeMethod
} }
log.Debugf("[%s] worker %d: contacting %s", cf.node.id.HexShort(), num, contact.id.HexShort()) log.Debugf("[%s] worker %d: contacting %s", cf.node.id.HexShort(), num, contact.ID.HexShort())
cf.incrementOutstanding() cf.incrementOutstanding()
var res *Response var res *Response
ctx, cancel := context.WithCancel(context.Background()) resCh, cancel := cf.node.SendCancelable(contact, req)
resCh := cf.node.SendAsync(ctx, contact, req)
select { select {
case res = <-resCh: case res = <-resCh:
case <-cf.done.Chan(): case <-cf.done.Chan():
@ -141,7 +139,7 @@ func (cf *contactFinder) iterationWorker(num int) {
if res == nil { if res == nil {
// nothing to do, response timed out // nothing to do, response timed out
log.Debugf("[%s] worker %d: search canceled or timed out waiting for %s", cf.node.id.HexShort(), num, contact.id.HexShort()) log.Debugf("[%s] worker %d: search canceled or timed out waiting for %s", cf.node.id.HexShort(), num, contact.ID.HexShort())
} else if cf.findValue && res.FindValueKey != "" { } else if cf.findValue && res.FindValueKey != "" {
log.Debugf("[%s] worker %d: got value", cf.node.id.HexShort(), num) log.Debugf("[%s] worker %d: got value", cf.node.id.HexShort(), num)
cf.findValueMutex.Lock() cf.findValueMutex.Lock()
@ -171,9 +169,9 @@ func (cf *contactFinder) appendNewToShortlist(contacts []Contact) {
defer cf.shortlistMutex.Unlock() defer cf.shortlistMutex.Unlock()
for _, c := range contacts { for _, c := range contacts {
if _, ok := cf.shortlistAdded[c.id]; !ok { if _, ok := cf.shortlistAdded[c.ID]; !ok {
cf.shortlist = append(cf.shortlist, c) cf.shortlist = append(cf.shortlist, c)
cf.shortlistAdded[c.id] = true cf.shortlistAdded[c.ID] = true
} }
} }
@ -199,7 +197,7 @@ func (cf *contactFinder) insertIntoActiveList(contact Contact) {
inserted := false inserted := false
for i, n := range cf.activeContacts { for i, n := range cf.activeContacts {
if contact.id.Xor(cf.target).Less(n.id.Xor(cf.target)) { if contact.ID.Xor(cf.target).Less(n.ID.Xor(cf.target)) {
cf.activeContacts = append(cf.activeContacts[:i], append([]Contact{contact}, cf.activeContacts[i:]...)...) cf.activeContacts = append(cf.activeContacts[:i], append([]Contact{contact}, cf.activeContacts[i:]...)...)
inserted = true inserted = true
break break
@ -232,7 +230,7 @@ func (cf *contactFinder) isSearchFinished() bool {
cf.activeContactsMutex.Lock() cf.activeContactsMutex.Lock()
defer cf.activeContactsMutex.Unlock() defer cf.activeContactsMutex.Unlock()
if len(cf.activeContacts) >= bucketSize && cf.activeContacts[bucketSize-1].id.Xor(cf.target).Less(cf.shortlist[0].id.Xor(cf.target)) { if len(cf.activeContacts) >= bucketSize && cf.activeContacts[bucketSize-1].ID.Xor(cf.target).Less(cf.shortlist[0].ID.Xor(cf.target)) {
// we have at least K active contacts, and we don't have any closer contacts to ping // we have at least K active contacts, and we don't have any closer contacts to ping
return true return true
} }
@ -263,7 +261,7 @@ func sortInPlace(contacts []Contact, target Bitmap) {
toSort := make([]sortedContact, len(contacts)) toSort := make([]sortedContact, len(contacts))
for i, n := range contacts { for i, n := range contacts {
toSort[i] = sortedContact{n, n.id.Xor(target)} toSort[i] = sortedContact{n, n.ID.Xor(target)}
} }
sort.Sort(byXorDistance(toSort)) sort.Sort(byXorDistance(toSort))

View file

@ -198,7 +198,7 @@ func TestStore(t *testing.T) {
if len(items) != 1 { if len(items) != 1 {
t.Error("list created in store, but nothing in list") t.Error("list created in store, but nothing in list")
} }
if !items[0].id.Equals(testNodeID) { if !items[0].ID.Equals(testNodeID) {
t.Error("wrong value stored") t.Error("wrong value stored")
} }
} }
@ -223,7 +223,7 @@ func TestFindNode(t *testing.T) {
nodesToInsert := 3 nodesToInsert := 3
var nodes []Contact var nodes []Contact
for i := 0; i < nodesToInsert; i++ { for i := 0; i < nodesToInsert; i++ {
n := Contact{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} n := Contact{ID: RandomBitmapP(), IP: net.ParseIP("127.0.0.1"), Port: 10000 + i}
nodes = append(nodes, n) nodes = append(nodes, n)
dht.node.rt.Update(n) dht.node.rt.Update(n)
} }
@ -292,7 +292,7 @@ func TestFindValueExisting(t *testing.T) {
nodesToInsert := 3 nodesToInsert := 3
var nodes []Contact var nodes []Contact
for i := 0; i < nodesToInsert; i++ { for i := 0; i < nodesToInsert; i++ {
n := Contact{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} n := Contact{ID: RandomBitmapP(), IP: net.ParseIP("127.0.0.1"), Port: 10000 + i}
nodes = append(nodes, n) nodes = append(nodes, n)
dht.node.rt.Update(n) dht.node.rt.Update(n)
} }
@ -302,7 +302,7 @@ func TestFindValueExisting(t *testing.T) {
messageID := newMessageID() messageID := newMessageID()
valueToFind := RandomBitmapP() valueToFind := RandomBitmapP()
nodeToFind := Contact{id: RandomBitmapP(), ip: net.ParseIP("1.2.3.4"), port: 1286} nodeToFind := Contact{ID: RandomBitmapP(), IP: net.ParseIP("1.2.3.4"), Port: 1286}
dht.node.store.Upsert(valueToFind, nodeToFind) dht.node.store.Upsert(valueToFind, nodeToFind)
dht.node.store.Upsert(valueToFind, nodeToFind) dht.node.store.Upsert(valueToFind, nodeToFind)
dht.node.store.Upsert(valueToFind, nodeToFind) dht.node.store.Upsert(valueToFind, nodeToFind)
@ -378,7 +378,7 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
nodesToInsert := 3 nodesToInsert := 3
var nodes []Contact var nodes []Contact
for i := 0; i < nodesToInsert; i++ { for i := 0; i < nodesToInsert; i++ {
n := Contact{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} n := Contact{ID: RandomBitmapP(), IP: net.ParseIP("127.0.0.1"), Port: 10000 + i}
nodes = append(nodes, n) nodes = append(nodes, n)
dht.node.rt.Update(n) dht.node.rt.Update(n)
} }

View file

@ -2,10 +2,11 @@ package dht
import ( import (
"bytes" "bytes"
"container/list" "encoding/json"
"fmt" "fmt"
"net" "net"
"sort" "sort"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -16,32 +17,32 @@ import (
) )
type Contact struct { type Contact struct {
id Bitmap ID Bitmap
ip net.IP IP net.IP
port int Port int
} }
func (c Contact) Addr() *net.UDPAddr { func (c Contact) Addr() *net.UDPAddr {
return &net.UDPAddr{IP: c.ip, Port: c.port} return &net.UDPAddr{IP: c.IP, Port: c.Port}
} }
func (c Contact) String() string { func (c Contact) String() string {
return c.id.HexShort() + "@" + c.Addr().String() return c.ID.HexShort() + "@" + c.Addr().String()
} }
func (c Contact) MarshalCompact() ([]byte, error) { func (c Contact) MarshalCompact() ([]byte, error) {
if c.ip.To4() == nil { if c.IP.To4() == nil {
return nil, errors.Err("ip not set") return nil, errors.Err("ip not set")
} }
if c.port < 0 || c.port > 65535 { if c.Port < 0 || c.Port > 65535 {
return nil, errors.Err("invalid port") return nil, errors.Err("invalid port")
} }
var buf bytes.Buffer var buf bytes.Buffer
buf.Write(c.ip.To4()) buf.Write(c.IP.To4())
buf.WriteByte(byte(c.port >> 8)) buf.WriteByte(byte(c.Port >> 8))
buf.WriteByte(byte(c.port)) buf.WriteByte(byte(c.Port))
buf.Write(c.id[:]) buf.Write(c.ID[:])
if buf.Len() != compactNodeInfoLength { if buf.Len() != compactNodeInfoLength {
return nil, errors.Err("i dont know how this happened") return nil, errors.Err("i dont know how this happened")
@ -54,14 +55,14 @@ func (c *Contact) UnmarshalCompact(b []byte) error {
if len(b) != compactNodeInfoLength { if len(b) != compactNodeInfoLength {
return errors.Err("invalid compact length") return errors.Err("invalid compact length")
} }
c.ip = net.IPv4(b[0], b[1], b[2], b[3]).To4() c.IP = net.IPv4(b[0], b[1], b[2], b[3]).To4()
c.port = int(uint16(b[5]) | uint16(b[4])<<8) c.Port = int(uint16(b[5]) | uint16(b[4])<<8)
c.id = BitmapFromBytesP(b[6:]) c.ID = BitmapFromBytesP(b[6:])
return nil return nil
} }
func (c Contact) MarshalBencode() ([]byte, error) { func (c Contact) MarshalBencode() ([]byte, error) {
return bencode.EncodeBytes([]interface{}{c.id, c.ip.String(), c.port}) return bencode.EncodeBytes([]interface{}{c.ID, c.IP.String(), c.Port})
} }
func (c *Contact) UnmarshalBencode(b []byte) error { func (c *Contact) UnmarshalBencode(b []byte) error {
@ -75,7 +76,7 @@ func (c *Contact) UnmarshalBencode(b []byte) error {
return errors.Err("contact must have 3 elements; got %d", len(raw)) return errors.Err("contact must have 3 elements; got %d", len(raw))
} }
err = bencode.DecodeBytes(raw[0], &c.id) err = bencode.DecodeBytes(raw[0], &c.ID)
if err != nil { if err != nil {
return err return err
} }
@ -85,12 +86,12 @@ func (c *Contact) UnmarshalBencode(b []byte) error {
if err != nil { if err != nil {
return err return err
} }
c.ip = net.ParseIP(ipStr).To4() c.IP = net.ParseIP(ipStr).To4()
if c.ip == nil { if c.IP == nil {
return errors.Err("invalid IP") return errors.Err("invalid IP")
} }
err = bencode.DecodeBytes(raw[2], &c.port) err = bencode.DecodeBytes(raw[2], &c.Port)
if err != nil { if err != nil {
return err return err
} }
@ -113,52 +114,38 @@ func (a byXorDistance) Less(i, j int) bool {
// peer is a contact with extra freshness information // peer is a contact with extra freshness information
type peer struct { type peer struct {
contact Contact Contact Contact
lastActivity time.Time LastActivity time.Time
numFailures int NumFailures int
//<lastPublished>, //<lastPublished>,
//<originallyPublished> //<originallyPublished>
// <originalPublisherID> // <originalPublisherID>
} }
func (p *peer) Touch() { func (p *peer) Touch() {
p.lastActivity = time.Now() p.LastActivity = time.Now()
p.numFailures = 0 p.NumFailures = 0
} }
// ActiveSince returns whether a peer has responded in the last `d` duration // ActiveSince returns whether a peer has responded in the last `d` duration
// this is used to check if the peer is "good", meaning that we believe the peer will respond to our requests // this is used to check if the peer is "good", meaning that we believe the peer will respond to our requests
func (p *peer) ActiveInLast(d time.Duration) bool { func (p *peer) ActiveInLast(d time.Duration) bool {
return time.Now().Sub(p.lastActivity) > d return time.Now().Sub(p.LastActivity) > d
} }
// IsBad returns whether a peer is "bad", meaning that it has failed to respond to multiple pings in a row // IsBad returns whether a peer is "bad", meaning that it has failed to respond to multiple pings in a row
func (p *peer) IsBad(maxFalures int) bool { func (p *peer) IsBad(maxFalures int) bool {
return p.numFailures >= maxFalures return p.NumFailures >= maxFalures
} }
// Fail marks a peer as having failed to respond. It returns whether or not the peer should be removed from the routing table // Fail marks a peer as having failed to respond. It returns whether or not the peer should be removed from the routing table
func (p *peer) Fail() { func (p *peer) Fail() {
p.numFailures++ p.NumFailures++
}
// toPeer converts a generic *list.Element into a *peer
// this (along with newPeer) keeps all conversions between *list.Element and peer in one place
func toPeer(el *list.Element) *peer {
return el.Value.(*peer)
}
// newPeer creates a new peer from a contact
// this (along with toPeer) keeps all conversions between *list.Element and peer in one place
func newPeer(c Contact) peer {
return peer{
contact: c,
}
} }
type bucket struct { type bucket struct {
lock *sync.RWMutex lock *sync.RWMutex
peers *list.List peers []peer
lastUpdate time.Time lastUpdate time.Time
} }
@ -166,16 +153,16 @@ type bucket struct {
func (b bucket) Len() int { func (b bucket) Len() int {
b.lock.RLock() b.lock.RLock()
defer b.lock.RUnlock() defer b.lock.RUnlock()
return b.peers.Len() return len(b.peers)
} }
// Contacts returns a slice of the bucket's contacts // Contacts returns a slice of the bucket's contacts
func (b bucket) Contacts() []Contact { func (b bucket) Contacts() []Contact {
b.lock.RLock() b.lock.RLock()
defer b.lock.RUnlock() defer b.lock.RUnlock()
contacts := make([]Contact, b.peers.Len()) contacts := make([]Contact, len(b.peers))
for i, curr := 0, b.peers.Front(); curr != nil; i, curr = i+1, curr.Next() { for i := range b.peers {
contacts[i] = toPeer(curr).contact contacts[i] = b.peers[i].Contact
} }
return contacts return contacts
} }
@ -185,21 +172,21 @@ func (b *bucket) UpdateContact(c Contact, insertIfNew bool) {
b.lock.Lock() b.lock.Lock()
defer b.lock.Unlock() defer b.lock.Unlock()
element := find(c.id, b.peers) peerIndex := find(c.ID, b.peers)
if element != nil { if peerIndex >= 0 {
b.lastUpdate = time.Now() b.lastUpdate = time.Now()
toPeer(element).Touch() b.peers[peerIndex].Touch()
b.peers.MoveToBack(element) moveToBack(b.peers, peerIndex)
} else if insertIfNew { } else if insertIfNew {
hasRoom := true hasRoom := true
if b.peers.Len() >= bucketSize { if len(b.peers) >= bucketSize {
hasRoom = false hasRoom = false
for curr := b.peers.Front(); curr != nil; curr = curr.Next() { for i := range b.peers {
if toPeer(curr).IsBad(maxPeerFails) { if b.peers[i].IsBad(maxPeerFails) {
// TODO: Ping contact first. Only remove if it does not respond // TODO: Ping contact first. Only remove if it does not respond
b.peers.Remove(curr) b.peers = append(b.peers[:i], b.peers[i+1:]...)
hasRoom = true hasRoom = true
break break
} }
@ -208,9 +195,9 @@ func (b *bucket) UpdateContact(c Contact, insertIfNew bool) {
if hasRoom { if hasRoom {
b.lastUpdate = time.Now() b.lastUpdate = time.Now()
peer := newPeer(c) peer := peer{Contact: c}
peer.Touch() peer.Touch()
b.peers.PushBack(&peer) b.peers = append(b.peers, peer)
} }
} }
} }
@ -219,21 +206,21 @@ func (b *bucket) UpdateContact(c Contact, insertIfNew bool) {
func (b *bucket) FailContact(id Bitmap) { func (b *bucket) FailContact(id Bitmap) {
b.lock.Lock() b.lock.Lock()
defer b.lock.Unlock() defer b.lock.Unlock()
element := find(id, b.peers) i := find(id, b.peers)
if element != nil { if i >= 0 {
// BEP5 says not to remove the contact until the bucket is full and you try to insert // BEP5 says not to remove the contact until the bucket is full and you try to insert
toPeer(element).Fail() b.peers[i].Fail()
} }
} }
// find returns the contact in the bucket, or nil if the bucket does not contain the contact // find returns the contact in the bucket, or nil if the bucket does not contain the contact
func find(id Bitmap, peers *list.List) *list.Element { func find(id Bitmap, peers []peer) int {
for curr := peers.Front(); curr != nil; curr = curr.Next() { for i := range peers {
if toPeer(curr).contact.id.Equals(id) { if peers[i].Contact.ID.Equals(id) {
return curr return i
} }
} }
return nil return -1
} }
// NeedsRefresh returns true if bucket has not been updated in the last `refreshInterval`, false otherwise // NeedsRefresh returns true if bucket has not been updated in the last `refreshInterval`, false otherwise
@ -243,41 +230,31 @@ func (b *bucket) NeedsRefresh(refreshInterval time.Duration) bool {
return time.Now().Sub(b.lastUpdate) > refreshInterval return time.Now().Sub(b.lastUpdate) > refreshInterval
} }
type RoutingTable interface { type routingTable struct {
Update(Contact)
Fresh(Contact)
Fail(Contact)
GetClosest(Bitmap, int) []Contact
Count() int
GetIDsForRefresh(time.Duration) []Bitmap
BucketInfo() string // for debugging
}
type routingTableImpl struct {
id Bitmap id Bitmap
buckets [numBuckets]bucket buckets [nodeIDBits]bucket
} }
func newRoutingTable(id Bitmap) *routingTableImpl { func newRoutingTable(id Bitmap) *routingTable {
var rt routingTableImpl var rt routingTable
rt.id = id rt.id = id
for i := range rt.buckets { for i := range rt.buckets {
rt.buckets[i] = bucket{ rt.buckets[i] = bucket{
peers: list.New(), peers: make([]peer, 0, bucketSize),
lock: &sync.RWMutex{}, lock: &sync.RWMutex{},
} }
} }
return &rt return &rt
} }
func (rt *routingTableImpl) BucketInfo() string { func (rt *routingTable) BucketInfo() string {
var bucketInfo []string var bucketInfo []string
for i, b := range rt.buckets { for i, b := range rt.buckets {
if b.Len() > 0 { if b.Len() > 0 {
contacts := b.Contacts() contacts := b.Contacts()
s := make([]string, len(contacts)) s := make([]string, len(contacts))
for j, c := range contacts { for j, c := range contacts {
s[j] = c.id.HexShort() s[j] = c.ID.HexShort()
} }
bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: (%d) %s", i, len(contacts), strings.Join(s, ", "))) bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: (%d) %s", i, len(contacts), strings.Join(s, ", ")))
} }
@ -289,23 +266,23 @@ func (rt *routingTableImpl) BucketInfo() string {
} }
// Update inserts or refreshes a contact // Update inserts or refreshes a contact
func (rt *routingTableImpl) Update(c Contact) { func (rt *routingTable) Update(c Contact) {
rt.bucketFor(c.id).UpdateContact(c, true) rt.bucketFor(c.ID).UpdateContact(c, true)
} }
// Fresh refreshes a contact if its already in the routing table // Fresh refreshes a contact if its already in the routing table
func (rt *routingTableImpl) Fresh(c Contact) { func (rt *routingTable) Fresh(c Contact) {
rt.bucketFor(c.id).UpdateContact(c, false) rt.bucketFor(c.ID).UpdateContact(c, false)
} }
// FailContact marks a contact as having failed, and removes it if it failed too many times // FailContact marks a contact as having failed, and removes it if it failed too many times
func (rt *routingTableImpl) Fail(c Contact) { func (rt *routingTable) Fail(c Contact) {
rt.bucketFor(c.id).FailContact(c.id) rt.bucketFor(c.ID).FailContact(c.ID)
} }
// GetClosest returns the closest `limit` contacts from the routing table // GetClosest returns the closest `limit` contacts from the routing table
// It marks each bucket it accesses as having been accessed // It marks each bucket it accesses as having been accessed
func (rt *routingTableImpl) GetClosest(target Bitmap, limit int) []Contact { func (rt *routingTable) GetClosest(target Bitmap, limit int) []Contact {
var toSort []sortedContact var toSort []sortedContact
var bucketNum int var bucketNum int
@ -317,11 +294,11 @@ func (rt *routingTableImpl) GetClosest(target Bitmap, limit int) []Contact {
toSort = appendContacts(toSort, rt.buckets[bucketNum], target) toSort = appendContacts(toSort, rt.buckets[bucketNum], target)
for i := 1; (bucketNum-i >= 0 || bucketNum+i < numBuckets) && len(toSort) < limit; i++ { for i := 1; (bucketNum-i >= 0 || bucketNum+i < nodeIDBits) && len(toSort) < limit; i++ {
if bucketNum-i >= 0 { if bucketNum-i >= 0 {
toSort = appendContacts(toSort, rt.buckets[bucketNum-i], target) toSort = appendContacts(toSort, rt.buckets[bucketNum-i], target)
} }
if bucketNum+i < numBuckets { if bucketNum+i < nodeIDBits {
toSort = appendContacts(toSort, rt.buckets[bucketNum+i], target) toSort = appendContacts(toSort, rt.buckets[bucketNum+i], target)
} }
} }
@ -341,13 +318,13 @@ func (rt *routingTableImpl) GetClosest(target Bitmap, limit int) []Contact {
func appendContacts(contacts []sortedContact, b bucket, target Bitmap) []sortedContact { func appendContacts(contacts []sortedContact, b bucket, target Bitmap) []sortedContact {
for _, contact := range b.Contacts() { for _, contact := range b.Contacts() {
contacts = append(contacts, sortedContact{contact, contact.id.Xor(target)}) contacts = append(contacts, sortedContact{contact, contact.ID.Xor(target)})
} }
return contacts return contacts
} }
// Count returns the number of contacts in the routing table // Count returns the number of contacts in the routing table
func (rt *routingTableImpl) Count() int { func (rt *routingTable) Count() int {
count := 0 count := 0
for _, bucket := range rt.buckets { for _, bucket := range rt.buckets {
count = bucket.Len() count = bucket.Len()
@ -355,27 +332,99 @@ func (rt *routingTableImpl) Count() int {
return count return count
} }
func (rt *routingTableImpl) bucketNumFor(target Bitmap) int { type Range struct {
start Bitmap
end Bitmap
}
// BucketRanges returns a slice of ranges, where the `start` of each range is the smallest id that can
// go in that bucket, and the `end` is the largest id
func (rt *routingTable) BucketRanges() []Range {
ranges := make([]Range, len(rt.buckets))
for i := range rt.buckets {
ranges[i] = Range{
rt.id.Suffix(i, false).Set(nodeIDBits-1-i, !rt.id.Get(nodeIDBits-1-i)),
rt.id.Suffix(i, true).Set(nodeIDBits-1-i, !rt.id.Get(nodeIDBits-1-i)),
}
}
return ranges
}
func (rt *routingTable) bucketNumFor(target Bitmap) int {
if rt.id.Equals(target) { if rt.id.Equals(target) {
panic("routing table does not have a bucket for its own id") panic("routing table does not have a bucket for its own id")
} }
return numBuckets - 1 - target.Xor(rt.id).PrefixLen() return nodeIDBits - 1 - target.Xor(rt.id).PrefixLen()
} }
func (rt *routingTableImpl) bucketFor(target Bitmap) *bucket { func (rt *routingTable) bucketFor(target Bitmap) *bucket {
return &rt.buckets[rt.bucketNumFor(target)] return &rt.buckets[rt.bucketNumFor(target)]
} }
func (rt *routingTableImpl) GetIDsForRefresh(refreshInterval time.Duration) []Bitmap { func (rt *routingTable) GetIDsForRefresh(refreshInterval time.Duration) []Bitmap {
var bitmaps []Bitmap var bitmaps []Bitmap
for i, bucket := range rt.buckets { for i, bucket := range rt.buckets {
if bucket.NeedsRefresh(refreshInterval) { if bucket.NeedsRefresh(refreshInterval) {
bitmaps = append(bitmaps, RandomBitmapP().ZeroPrefix(i)) bitmaps = append(bitmaps, RandomBitmapP().Prefix(i, false))
} }
} }
return bitmaps return bitmaps
} }
const rtContactSep = "-"
type rtSave struct {
ID string `json:"id"`
Contacts []string `json:"contacts"`
}
func (rt *routingTable) MarshalJSON() ([]byte, error) {
var data rtSave
data.ID = rt.id.Hex()
for _, b := range rt.buckets {
for _, c := range b.Contacts() {
data.Contacts = append(data.Contacts, strings.Join([]string{c.ID.Hex(), c.IP.String(), strconv.Itoa(c.Port)}, rtContactSep))
}
}
return json.Marshal(data)
}
func (rt *routingTable) UnmarshalJSON(b []byte) error {
var data rtSave
err := json.Unmarshal(b, &data)
if err != nil {
return err
}
rt.id, err = BitmapFromHex(data.ID)
if err != nil {
return errors.Prefix("decoding ID", err)
}
for _, s := range data.Contacts {
parts := strings.Split(s, rtContactSep)
if len(parts) != 3 {
return errors.Err("decoding contact %s: wrong number of parts", s)
}
var c Contact
c.ID, err = BitmapFromHex(parts[0])
if err != nil {
return errors.Err("decoding contact %s: invalid ID: %s", s, err)
}
c.IP = net.ParseIP(parts[1])
if c.IP == nil {
return errors.Err("decoding contact %s: invalid IP", s)
}
c.Port, err = strconv.Atoi(parts[2])
if err != nil {
return errors.Err("decoding contact %s: invalid port: %s", s, err)
}
rt.Update(c)
}
return nil
}
// RoutingTableRefresh refreshes any buckets that need to be refreshed // RoutingTableRefresh refreshes any buckets that need to be refreshed
// It returns a channel that will be closed when the refresh is done // It returns a channel that will be closed when the refresh is done
func RoutingTableRefresh(n *Node, refreshInterval time.Duration, cancel <-chan struct{}) <-chan struct{} { func RoutingTableRefresh(n *Node, refreshInterval time.Duration, cancel <-chan struct{}) <-chan struct{} {
@ -411,3 +460,14 @@ func RoutingTableRefresh(n *Node, refreshInterval time.Duration, cancel <-chan s
return done return done
} }
func moveToBack(peers []peer, index int) {
if index < 0 || len(peers) <= index+1 {
return
}
p := peers[index]
for i := index; i < len(peers)-1; i++ {
peers[i] = peers[i+1]
}
peers[len(peers)-1] = p
}

View file

@ -1,9 +1,14 @@
package dht package dht
import ( import (
"encoding/json"
"net" "net"
"reflect" "reflect"
"strconv"
"strings"
"testing" "testing"
"github.com/sebdah/goldie"
) )
func TestRoutingTable_bucketFor(t *testing.T) { func TestRoutingTable_bucketFor(t *testing.T) {
@ -31,7 +36,7 @@ func TestRoutingTable_bucketFor(t *testing.T) {
} }
} }
func TestRoutingTable(t *testing.T) { func TestRoutingTable_GetClosest(t *testing.T) {
n1 := BitmapFromHexP("FFFFFFFF0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") n1 := BitmapFromHexP("FFFFFFFF0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
n2 := BitmapFromHexP("FFFFFFF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") n2 := BitmapFromHexP("FFFFFFF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
n3 := BitmapFromHexP("111111110000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") n3 := BitmapFromHexP("111111110000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
@ -44,7 +49,7 @@ func TestRoutingTable(t *testing.T) {
t.Fail() t.Fail()
return return
} }
if !contacts[0].id.Equals(n3) { if !contacts[0].ID.Equals(n3) {
t.Error(contacts[0]) t.Error(contacts[0])
} }
@ -53,19 +58,19 @@ func TestRoutingTable(t *testing.T) {
t.Error(len(contacts)) t.Error(len(contacts))
return return
} }
if !contacts[0].id.Equals(n2) { if !contacts[0].ID.Equals(n2) {
t.Error(contacts[0]) t.Error(contacts[0])
} }
if !contacts[1].id.Equals(n3) { if !contacts[1].ID.Equals(n3) {
t.Error(contacts[1]) t.Error(contacts[1])
} }
} }
func TestCompactEncoding(t *testing.T) { func TestCompactEncoding(t *testing.T) {
c := Contact{ c := Contact{
id: BitmapFromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41"), ID: BitmapFromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41"),
ip: net.ParseIP("1.2.3.4"), IP: net.ParseIP("1.2.3.4"),
port: int(55<<8 + 66), Port: int(55<<8 + 66),
} }
var compact []byte var compact []byte
@ -78,11 +83,117 @@ func TestCompactEncoding(t *testing.T) {
t.Fatalf("got length of %d; expected %d", len(compact), compactNodeInfoLength) t.Fatalf("got length of %d; expected %d", len(compact), compactNodeInfoLength)
} }
if !reflect.DeepEqual(compact, append([]byte{1, 2, 3, 4, 55, 66}, c.id[:]...)) { if !reflect.DeepEqual(compact, append([]byte{1, 2, 3, 4, 55, 66}, c.ID[:]...)) {
t.Errorf("compact bytes not encoded correctly") t.Errorf("compact bytes not encoded correctly")
} }
} }
func TestRoutingTableRefresh(t *testing.T) { func TestRoutingTable_Refresh(t *testing.T) {
t.Skip("TODO: test routing table refreshing") t.Skip("TODO: test routing table refreshing")
} }
func TestRoutingTable_MoveToBack(t *testing.T) {
tt := map[string]struct {
data []peer
index int
expected []peer
}{
"simpleMove": {
data: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
index: 1,
expected: []peer{{NumFailures: 0}, {NumFailures: 2}, {NumFailures: 3}, {NumFailures: 1}},
},
"moveFirst": {
data: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
index: 0,
expected: []peer{{NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}, {NumFailures: 0}},
},
"moveLast": {
data: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
index: 3,
expected: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
},
"largeIndex": {
data: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
index: 27,
expected: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
},
"negativeIndex": {
data: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
index: -12,
expected: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
},
}
for name, test := range tt {
moveToBack(test.data, test.index)
expected := make([]string, len(test.expected))
actual := make([]string, len(test.data))
for i := range actual {
actual[i] = strconv.Itoa(test.data[i].NumFailures)
expected[i] = strconv.Itoa(test.expected[i].NumFailures)
}
expJoin := strings.Join(expected, ",")
actJoin := strings.Join(actual, ",")
if actJoin != expJoin {
t.Errorf("%s failed: got %s; expected %s", name, actJoin, expJoin)
}
}
}
func TestRoutingTable_BucketRanges(t *testing.T) {
id := BitmapFromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41")
ranges := newRoutingTable(id).BucketRanges()
if !ranges[0].start.Equals(ranges[0].end) {
t.Error("first bucket should only fit exactly one id")
}
for i := 0; i < 1000; i++ {
randID := RandomBitmapP()
found := -1
for i, r := range ranges {
if r.start.LessOrEqual(randID) && r.end.GreaterOrEqual(randID) {
if found >= 0 {
t.Errorf("%s appears in buckets %d and %d", randID.Hex(), found, i)
} else {
found = i
}
}
}
if found < 0 {
t.Errorf("%s did not appear in any bucket", randID.Hex())
}
}
}
func TestRoutingTable_Save(t *testing.T) {
id := BitmapFromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41")
rt := newRoutingTable(id)
ranges := rt.BucketRanges()
for i, r := range ranges {
for j := 0; j < bucketSize; j++ {
toAdd := r.start.Add(BitmapFromShortHexP(strconv.Itoa(j)))
if toAdd.LessOrEqual(r.end) {
rt.Update(Contact{
ID: r.start.Add(BitmapFromShortHexP(strconv.Itoa(j))),
IP: net.ParseIP("1.2.3." + strconv.Itoa(j)),
Port: 1 + i*bucketSize + j,
})
}
}
}
data, err := json.MarshalIndent(rt, "", " ")
if err != nil {
t.Error(err)
}
goldie.Assert(t, t.Name(), data)
}
func TestRoutingTable_Load(t *testing.T) {
t.Skip("TODO")
}

View file

@ -2,13 +2,7 @@ package dht
import "sync" import "sync"
type Store interface { type contactStore struct {
Upsert(Bitmap, Contact)
Get(Bitmap) []Contact
CountStoredHashes() int
}
type storeImpl struct {
// map of blob hashes to (map of node IDs to bools) // map of blob hashes to (map of node IDs to bools)
hashes map[Bitmap]map[Bitmap]bool hashes map[Bitmap]map[Bitmap]bool
// stores the peers themselves, so they can be updated in one place // stores the peers themselves, so they can be updated in one place
@ -16,25 +10,25 @@ type storeImpl struct {
lock sync.RWMutex lock sync.RWMutex
} }
func newStore() *storeImpl { func newStore() *contactStore {
return &storeImpl{ return &contactStore{
hashes: make(map[Bitmap]map[Bitmap]bool), hashes: make(map[Bitmap]map[Bitmap]bool),
contacts: make(map[Bitmap]Contact), contacts: make(map[Bitmap]Contact),
} }
} }
func (s *storeImpl) Upsert(blobHash Bitmap, contact Contact) { func (s *contactStore) Upsert(blobHash Bitmap, contact Contact) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
if _, ok := s.hashes[blobHash]; !ok { if _, ok := s.hashes[blobHash]; !ok {
s.hashes[blobHash] = make(map[Bitmap]bool) s.hashes[blobHash] = make(map[Bitmap]bool)
} }
s.hashes[blobHash][contact.id] = true s.hashes[blobHash][contact.ID] = true
s.contacts[contact.id] = contact s.contacts[contact.ID] = contact
} }
func (s *storeImpl) Get(blobHash Bitmap) []Contact { func (s *contactStore) Get(blobHash Bitmap) []Contact {
s.lock.RLock() s.lock.RLock()
defer s.lock.RUnlock() defer s.lock.RUnlock()
@ -51,11 +45,11 @@ func (s *storeImpl) Get(blobHash Bitmap) []Contact {
return contacts return contacts
} }
func (s *storeImpl) RemoveTODO(contact Contact) { func (s *contactStore) RemoveTODO(contact Contact) {
// TODO: remove peer from everywhere // TODO: remove peer from everywhere
} }
func (s *storeImpl) CountStoredHashes() int { func (s *contactStore) CountStoredHashes() int {
s.lock.RLock() s.lock.RLock()
defer s.lock.RUnlock() defer s.lock.RUnlock()
return len(s.hashes) return len(s.hashes)

View file

@ -215,7 +215,7 @@ func verifyContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
continue continue
} }
for _, n := range nodes { for _, n := range nodes {
if n.id.RawString() == id { if n.ID.RawString() == id {
currNode = n currNode = n
currNodeFound = true currNodeFound = true
foundNodes[id] = true foundNodes[id] = true
@ -231,15 +231,15 @@ func verifyContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
ip, ok := contact[1].(string) ip, ok := contact[1].(string)
if !ok { if !ok {
t.Error("contact IP is not a string") t.Error("contact IP is not a string")
} else if !currNode.ip.Equal(net.ParseIP(ip)) { } else if !currNode.IP.Equal(net.ParseIP(ip)) {
t.Errorf("contact IP mismatch. got %s; expected %s", ip, currNode.ip.String()) t.Errorf("contact IP mismatch. got %s; expected %s", ip, currNode.IP.String())
} }
port, ok := contact[2].(int64) port, ok := contact[2].(int64)
if !ok { if !ok {
t.Error("contact port is not an int") t.Error("contact port is not an int")
} else if int(port) != currNode.port { } else if int(port) != currNode.Port {
t.Errorf("contact port mismatch. got %d; expected %d", port, currNode.port) t.Errorf("contact port mismatch. got %d; expected %d", port, currNode.Port)
} }
} }
} }
@ -269,29 +269,38 @@ func verifyCompactContacts(t *testing.T, contacts []interface{}, nodes []Contact
var currNode Contact var currNode Contact
currNodeFound := false currNodeFound := false
if _, ok := foundNodes[contact.id.Hex()]; ok { if _, ok := foundNodes[contact.ID.Hex()]; ok {
t.Errorf("contact %s appears multiple times", contact.id.Hex()) t.Errorf("contact %s appears multiple times", contact.ID.Hex())
continue continue
} }
for _, n := range nodes { for _, n := range nodes {
if n.id.Equals(contact.id) { if n.ID.Equals(contact.ID) {
currNode = n currNode = n
currNodeFound = true currNodeFound = true
foundNodes[contact.id.Hex()] = true foundNodes[contact.ID.Hex()] = true
break break
} }
} }
if !currNodeFound { if !currNodeFound {
t.Errorf("unexpected contact %s", contact.id.Hex()) t.Errorf("unexpected contact %s", contact.ID.Hex())
continue continue
} }
if !currNode.ip.Equal(contact.ip) { if !currNode.IP.Equal(contact.IP) {
t.Errorf("contact IP mismatch. got %s; expected %s", contact.ip.String(), currNode.ip.String()) t.Errorf("contact IP mismatch. got %s; expected %s", contact.IP.String(), currNode.IP.String())
} }
if contact.port != currNode.port { if contact.Port != currNode.Port {
t.Errorf("contact port mismatch. got %d; expected %d", contact.port, currNode.port) t.Errorf("contact port mismatch. got %d; expected %d", contact.Port, currNode.Port)
} }
} }
} }
func assertPanic(t *testing.T, text string, f func()) {
defer func() {
if r := recover(); r == nil {
t.Errorf("%s: did not panic as expected", text)
}
}()
f()
}