added routing table saving, bitmap operations, lots of tests

This commit is contained in:
Alex Grintsvayg 2018-05-19 13:05:30 -04:00
parent b9ee0b0644
commit 14cceda81e
15 changed files with 3858 additions and 244 deletions

View file

@ -1,8 +1,10 @@
package dht
import (
"bytes"
"crypto/rand"
"encoding/hex"
"strings"
"github.com/lbryio/errors.go"
"github.com/lyoshenka/bencode"
@ -14,6 +16,19 @@ func (b Bitmap) RawString() string {
return string(b[:])
}
// BString returns the bitmap as a string of 0s and 1s
func (b Bitmap) BString() string {
var buf bytes.Buffer
for i := 0; i < nodeIDBits; i++ {
if b.Get(i) {
buf.WriteString("1")
} else {
buf.WriteString("0")
}
}
return buf.String()
}
func (b Bitmap) Hex() string {
return hex.EncodeToString(b[:])
}
@ -22,6 +37,14 @@ func (b Bitmap) HexShort() string {
return hex.EncodeToString(b[:4])
}
func (b Bitmap) HexSimplified() string {
simple := strings.TrimLeft(b.Hex(), "0")
if simple == "" {
simple = "0"
}
return simple
}
func (b Bitmap) Equals(other Bitmap) bool {
for k := range b {
if b[k] != other[k] {
@ -40,6 +63,35 @@ func (b Bitmap) Less(other interface{}) bool {
return false
}
func (b Bitmap) LessOrEqual(other interface{}) bool {
if bm, ok := other.(Bitmap); ok && b.Equals(bm) {
return true
}
return b.Less(other)
}
func (b Bitmap) Greater(other interface{}) bool {
for k := range b {
if b[k] != other.(Bitmap)[k] {
return b[k] > other.(Bitmap)[k]
}
}
return false
}
func (b Bitmap) GreaterOrEqual(other interface{}) bool {
if bm, ok := other.(Bitmap); ok && b.Equals(bm) {
return true
}
return b.Greater(other)
}
func (b Bitmap) Copy() Bitmap {
var ret Bitmap
copy(ret[:], b[:])
return ret
}
func (b Bitmap) Xor(other Bitmap) Bitmap {
var ret Bitmap
for k := range b {
@ -48,6 +100,69 @@ func (b Bitmap) Xor(other Bitmap) Bitmap {
return ret
}
func (b Bitmap) And(other Bitmap) Bitmap {
var ret Bitmap
for k := range b {
ret[k] = b[k] & other[k]
}
return ret
}
func (b Bitmap) Or(other Bitmap) Bitmap {
var ret Bitmap
for k := range b {
ret[k] = b[k] | other[k]
}
return ret
}
func (b Bitmap) Not() Bitmap {
var ret Bitmap
for k := range b {
ret[k] = ^b[k]
}
return ret
}
func (b Bitmap) add(other Bitmap) (Bitmap, bool) {
var ret Bitmap
carry := false
for i := nodeIDBits - 1; i >= 0; i-- {
bBit := getBit(b[:], i)
oBit := getBit(other[:], i)
setBit(ret[:], i, bBit != oBit != carry)
carry = (bBit && oBit) || (bBit && carry) || (oBit && carry)
}
return ret, carry
}
func (b Bitmap) Add(other Bitmap) Bitmap {
ret, carry := b.add(other)
if carry {
panic("overflow in bitmap addition")
}
return ret
}
func (b Bitmap) Sub(other Bitmap) Bitmap {
if b.Less(other) {
panic("negative bitmaps not supported")
}
complement, _ := other.Not().add(BitmapFromShortHexP("1"))
ret, _ := b.add(complement)
return ret
}
func (b Bitmap) Get(n int) bool {
return getBit(b[:], n)
}
func (b Bitmap) Set(n int, one bool) Bitmap {
ret := b.Copy()
setBit(ret[:], n, one)
return ret
}
// PrefixLen returns the number of leading 0 bits
func (b Bitmap) PrefixLen() int {
for i := range b {
@ -57,20 +172,46 @@ func (b Bitmap) PrefixLen() int {
}
}
}
return numBuckets
return nodeIDBits
}
// ZeroPrefix returns a copy of b with the first n bits set to 0
// Prefix returns a copy of b with the first n bits set to 1 (if `one` is true) or 0 (if `one` is false)
// https://stackoverflow.com/a/23192263/182709
func (b Bitmap) ZeroPrefix(n int) Bitmap {
var ret Bitmap
copy(ret[:], b[:])
func (b Bitmap) Prefix(n int, one bool) Bitmap {
ret := b.Copy()
Outer:
for i := range ret {
for j := 0; j < 8; j++ {
if i*8+j < n {
if one {
ret[i] |= 1 << uint(7-j)
} else {
ret[i] &= ^(1 << uint(7-j))
}
} else {
break Outer
}
}
}
return ret
}
// Syffix returns a copy of b with the last n bits set to 1 (if `one` is true) or 0 (if `one` is false)
// https://stackoverflow.com/a/23192263/182709
func (b Bitmap) Suffix(n int, one bool) Bitmap {
ret := b.Copy()
Outer:
for i := len(ret) - 1; i >= 0; i-- {
for j := 7; j >= 0; j-- {
if i*8+j >= nodeIDBits-n {
if one {
ret[i] |= 1 << uint(7-j)
} else {
ret[i] &= ^(1 << uint(7-j))
}
} else {
break Outer
}
@ -145,6 +286,18 @@ func BitmapFromHexP(hexStr string) Bitmap {
return bmp
}
func BitmapFromShortHex(hexStr string) (Bitmap, error) {
return BitmapFromHex(strings.Repeat("0", nodeIDLength*2-len(hexStr)) + hexStr)
}
func BitmapFromShortHexP(hexStr string) Bitmap {
bmp, err := BitmapFromShortHex(hexStr)
if err != nil {
panic(err)
}
return bmp
}
func RandomBitmapP() Bitmap {
var id Bitmap
_, err := rand.Read(id[:])
@ -153,3 +306,28 @@ func RandomBitmapP() Bitmap {
}
return id
}
func RandomBitmapInRangeP(low, high Bitmap) Bitmap {
diff := high.Sub(low)
r := RandomBitmapP()
for r.Greater(diff) {
r = r.Sub(diff)
}
return r.Add(low)
}
func getBit(b []byte, n int) bool {
i := n / 8
j := n % 8
return b[i]&(1<<uint(7-j)) > 0
}
func setBit(b []byte, n int, one bool) {
i := n / 8
j := n % 8
if one {
b[i] |= 1 << uint(7-j)
} else {
b[i] &= ^(1 << uint(7-j))
}
}

View file

@ -1,6 +1,7 @@
package dht
import (
"fmt"
"testing"
"github.com/lyoshenka/bencode"
@ -51,6 +52,84 @@ func TestBitmap(t *testing.T) {
}
}
func TestBitmap_GetBit(t *testing.T) {
tt := []struct {
hex string
bit int
expected bool
panic bool
}{
//{hex: "0", bit: 385, one: true, expected: "1", panic:true}, // should error
//{hex: "0", bit: 384, one: true, expected: "1", panic:true},
{bit: 383, expected: false, panic: false},
{bit: 382, expected: true, panic: false},
{bit: 381, expected: false, panic: false},
{bit: 380, expected: true, panic: false},
}
b := BitmapFromShortHexP("a")
for _, test := range tt {
actual := getBit(b[:], test.bit)
if test.expected != actual {
t.Errorf("getting bit %d of %s: expected %t, got %t", test.bit, b.HexSimplified(), test.expected, actual)
}
}
}
func TestBitmap_SetBit(t *testing.T) {
tt := []struct {
hex string
bit int
one bool
expected string
panic bool
}{
{hex: "0", bit: 383, one: true, expected: "1", panic: false},
{hex: "0", bit: 382, one: true, expected: "2", panic: false},
{hex: "0", bit: 381, one: true, expected: "4", panic: false},
{hex: "0", bit: 385, one: true, expected: "1", panic: true},
{hex: "0", bit: 384, one: true, expected: "1", panic: true},
}
for _, test := range tt {
expected := BitmapFromShortHexP(test.expected)
actual := BitmapFromShortHexP(test.hex)
if test.panic {
assertPanic(t, fmt.Sprintf("setting bit %d to %t", test.bit, test.one), func() { setBit(actual[:], test.bit, test.one) })
} else {
setBit(actual[:], test.bit, test.one)
if !expected.Equals(actual) {
t.Errorf("setting bit %d to %t: expected %s, got %s", test.bit, test.one, test.expected, actual.HexSimplified())
}
}
}
}
func TestBitmap_FromHexShort(t *testing.T) {
tt := []struct {
short string
long string
}{
{short: "", long: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{short: "0", long: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{short: "00000", long: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{short: "9473745bc", long: "0000000000000000000000000000000000000000000000000000000000000000000000000000000000000009473745bc"},
{short: "09473745bc", long: "0000000000000000000000000000000000000000000000000000000000000000000000000000000000000009473745bc"},
{short: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
long: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"},
}
for _, test := range tt {
short := BitmapFromShortHexP(test.short)
long := BitmapFromHexP(test.long)
if !short.Equals(long) {
t.Errorf("short hex %s: expected %s, got %s", test.short, long.Hex(), short.Hex())
}
}
}
func TestBitmapMarshal(t *testing.T) {
b := BitmapFromStringP("123456789012345678901234567890123456789012345678")
encoded, err := bencode.EncodeBytes(b)
@ -120,10 +199,10 @@ func TestBitmap_PrefixLen(t *testing.T) {
}
}
func TestBitmap_ZeroPrefix(t *testing.T) {
original := BitmapFromHexP("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
func TestBitmap_Prefix(t *testing.T) {
allOne := BitmapFromHexP("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
tt := []struct {
zerosTT := []struct {
zeros int
expected string
}{
@ -136,18 +215,162 @@ func TestBitmap_ZeroPrefix(t *testing.T) {
{zeros: 400, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
}
for _, test := range tt {
for _, test := range zerosTT {
expected := BitmapFromHexP(test.expected)
actual := original.ZeroPrefix(test.zeros)
actual := allOne.Prefix(test.zeros, false)
if !actual.Equals(expected) {
t.Errorf("%d zeros: got %s; expected %s", test.zeros, actual.Hex(), expected.Hex())
}
}
for i := 0; i < nodeIDLength*8; i++ {
b := original.ZeroPrefix(i)
b := allOne.Prefix(i, false)
if b.PrefixLen() != i {
t.Errorf("got prefix len %d; expected %d for %s", b.PrefixLen(), i, b.Hex())
}
}
allZero := BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
onesTT := []struct {
ones int
expected string
}{
{ones: -123, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{ones: 0, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{ones: 1, expected: "800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{ones: 69, expected: "fffffffffffffffff8000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{ones: 383, expected: "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"},
{ones: 384, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
{ones: 400, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
}
for _, test := range onesTT {
expected := BitmapFromHexP(test.expected)
actual := allZero.Prefix(test.ones, true)
if !actual.Equals(expected) {
t.Errorf("%d ones: got %s; expected %s", test.ones, actual.Hex(), expected.Hex())
}
}
}
func TestBitmap_Suffix(t *testing.T) {
allOne := BitmapFromHexP("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
zerosTT := []struct {
zeros int
expected string
}{
{zeros: -123, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
{zeros: 0, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
{zeros: 1, expected: "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"},
{zeros: 69, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe00000000000000000"},
{zeros: 383, expected: "800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{zeros: 384, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{zeros: 400, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
}
for _, test := range zerosTT {
expected := BitmapFromHexP(test.expected)
actual := allOne.Suffix(test.zeros, false)
if !actual.Equals(expected) {
t.Errorf("%d zeros: got %s; expected %s", test.zeros, actual.Hex(), expected.Hex())
}
}
for i := 0; i < nodeIDLength*8; i++ {
b := allOne.Prefix(i, false)
if b.PrefixLen() != i {
t.Errorf("got prefix len %d; expected %d for %s", b.PrefixLen(), i, b.Hex())
}
}
allZero := BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
onesTT := []struct {
ones int
expected string
}{
{ones: -123, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{ones: 0, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{ones: 1, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"},
{ones: 69, expected: "0000000000000000000000000000000000000000000000000000000000000000000000000000001fffffffffffffffff"},
{ones: 383, expected: "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
{ones: 384, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
{ones: 400, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
}
for _, test := range onesTT {
expected := BitmapFromHexP(test.expected)
actual := allZero.Suffix(test.ones, true)
if !actual.Equals(expected) {
t.Errorf("%d ones: got %s; expected %s", test.ones, actual.Hex(), expected.Hex())
}
}
}
func TestBitmap_Add(t *testing.T) {
tt := []struct {
a, b, sum string
panic bool
}{
{"0", "0", "0", false},
{"0", "1", "1", false},
{"1", "0", "1", false},
{"1", "1", "2", false},
{"8", "4", "c", false},
{"1000", "0010", "1010", false},
{"1111", "1111", "2222", false},
{"ffff", "1", "10000", false},
{"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", false},
{"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "1", "", true},
}
for _, test := range tt {
a := BitmapFromShortHexP(test.a)
b := BitmapFromShortHexP(test.b)
expected := BitmapFromShortHexP(test.sum)
if test.panic {
assertPanic(t, fmt.Sprintf("adding %s and %s", test.a, test.b), func() { a.Add(b) })
} else {
actual := a.Add(b)
if !expected.Equals(actual) {
t.Errorf("adding %s and %s; expected %s, got %s", test.a, test.b, test.sum, actual.HexSimplified())
}
}
}
}
func TestBitmap_Sub(t *testing.T) {
tt := []struct {
a, b, sum string
panic bool
}{
{"0", "0", "0", false},
{"1", "0", "1", false},
{"1", "1", "0", false},
{"8", "4", "4", false},
{"f", "9", "6", false},
{"f", "e", "1", false},
{"10", "f", "1", false},
{"2222", "1111", "1111", false},
{"ffff", "1", "fffe", false},
{"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", false},
{"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0", false},
{"0", "1", "", true},
}
for _, test := range tt {
a := BitmapFromShortHexP(test.a)
b := BitmapFromShortHexP(test.b)
expected := BitmapFromShortHexP(test.sum)
if test.panic {
assertPanic(t, fmt.Sprintf("subtracting %s - %s", test.a, test.b), func() { a.Sub(b) })
} else {
actual := a.Sub(b)
if !expected.Equals(actual) {
t.Errorf("subtracting %s - %s; expected %s, got %s", test.a, test.b, test.sum, actual.HexSimplified())
}
}
}
}

View file

@ -1,7 +1,6 @@
package dht
import (
"context"
"math/rand"
"net"
"sync"
@ -14,25 +13,6 @@ const (
bootstrapDefaultRefreshDuration = 15 * time.Minute
)
type nullStore struct{}
func (n nullStore) Upsert(id Bitmap, c Contact) {}
func (n nullStore) Get(id Bitmap) []Contact { return nil }
func (n nullStore) CountStoredHashes() int { return 0 }
type nullRoutingTable struct{}
// TODO: the bootstrap logic *could* be implemented just in the routing table, without a custom request handler
// TODO: the only tricky part is triggering the ping when Fresh is called, as the rt doesnt have access to the node
func (n nullRoutingTable) Update(c Contact) {} // this
func (n nullRoutingTable) Fresh(c Contact) {} // this
func (n nullRoutingTable) Fail(c Contact) {} // this
func (n nullRoutingTable) GetClosest(id Bitmap, limit int) []Contact { return nil } // this
func (n nullRoutingTable) Count() int { return 0 }
func (n nullRoutingTable) GetIDsForRefresh(d time.Duration) []Bitmap { return nil }
func (n nullRoutingTable) BucketInfo() string { return "" }
type BootstrapNode struct {
Node
@ -57,8 +37,6 @@ func NewBootstrapNode(id Bitmap, initialPingInterval, rePingInterval time.Durati
nodeKeys: make(map[Bitmap]int),
}
b.rt = &nullRoutingTable{}
b.store = &nullStore{}
b.requestHandler = b.handleRequest
return b
@ -98,14 +76,14 @@ func (b *BootstrapNode) upsert(c Contact) {
b.nlock.Lock()
defer b.nlock.Unlock()
if i, exists := b.nodeKeys[c.id]; exists {
log.Debugf("[%s] bootstrap: touching contact %s", b.id.HexShort(), b.nodes[i].contact.id.HexShort())
if i, exists := b.nodeKeys[c.ID]; exists {
log.Debugf("[%s] bootstrap: touching contact %s", b.id.HexShort(), b.nodes[i].Contact.ID.HexShort())
b.nodes[i].Touch()
return
}
log.Debugf("[%s] bootstrap: adding new contact %s", b.id.HexShort(), c.id.HexShort())
b.nodeKeys[c.id] = len(b.nodes)
log.Debugf("[%s] bootstrap: adding new contact %s", b.id.HexShort(), c.ID.HexShort())
b.nodeKeys[c.ID] = len(b.nodes)
b.nodes = append(b.nodes, peer{c, time.Now(), 0})
}
@ -114,14 +92,14 @@ func (b *BootstrapNode) remove(c Contact) {
b.nlock.Lock()
defer b.nlock.Unlock()
i, exists := b.nodeKeys[c.id]
i, exists := b.nodeKeys[c.ID]
if !exists {
return
}
log.Debugf("[%s] bootstrap: removing contact %s", b.id.HexShort(), c.id.HexShort())
log.Debugf("[%s] bootstrap: removing contact %s", b.id.HexShort(), c.ID.HexShort())
b.nodes = append(b.nodes[:i], b.nodes[i+1:]...)
delete(b.nodeKeys, c.id)
delete(b.nodeKeys, c.ID)
}
// get returns up to `limit` random contacts from the list
@ -135,7 +113,7 @@ func (b *BootstrapNode) get(limit int) []Contact {
ret := make([]Contact, limit)
for i, k := range randKeys(len(b.nodes))[:limit] {
ret[i] = b.nodes[k].contact
ret[i] = b.nodes[k].Contact
}
return ret
@ -146,8 +124,7 @@ func (b *BootstrapNode) ping(c Contact) {
b.stopWG.Add(1)
defer b.stopWG.Done()
ctx, cancel := context.WithCancel(context.Background())
resCh := b.SendAsync(ctx, c, Request{Method: pingMethod})
resCh, cancel := b.SendCancelable(c, Request{Method: pingMethod})
var res *Response
@ -171,7 +148,7 @@ func (b *BootstrapNode) check() {
for i := range b.nodes {
if !b.nodes[i].ActiveInLast(b.checkInterval) {
go b.ping(b.nodes[i].contact)
go b.ping(b.nodes[i].Contact)
}
}
}
@ -196,7 +173,7 @@ func (b *BootstrapNode) handleRequest(addr *net.UDPAddr, request Request) {
go func() {
log.Debugf("[%s] bootstrap: queuing %s to ping", b.id.HexShort(), request.NodeID.HexShort())
<-time.After(b.initialPingInterval)
b.ping(Contact{id: request.NodeID, ip: addr.IP, port: addr.Port})
b.ping(Contact{ID: request.NodeID, IP: addr.IP, Port: addr.Port})
}()
}

View file

@ -1,7 +1,6 @@
package dht
import (
"context"
"fmt"
"net"
"strings"
@ -28,6 +27,7 @@ const (
alpha = 3 // this is the constant alpha in the spec
bucketSize = 8 // this is the constant k in the spec
nodeIDLength = 48 // bytes. this is the constant B in the spec
nodeIDBits = nodeIDLength * 8 // number of bits in node ID
messageIDLength = 20 // bytes.
udpRetry = 3
@ -42,7 +42,6 @@ const (
tRepublish = 24 * time.Hour // the time after which the original publisher must republish a key/value pair
tNodeRefresh = 15 * time.Minute // the time after which a good node becomes questionable if it has not messaged us
numBuckets = nodeIDLength * 8
compactNodeInfoLength = nodeIDLength + 6 // nodeID + 4 for IP + 2 for port
tokenSecretRotationInterval = 5 * time.Minute // how often the token-generating secret is rotated
@ -102,7 +101,7 @@ func New(config *Config) (*DHT, error) {
d := &DHT{
conf: config,
contact: contact,
node: NewNode(contact.id),
node: NewNode(contact.ID),
stop: stopOnce.New(),
stopWG: &sync.WaitGroup{},
joined: make(chan struct{}),
@ -181,7 +180,7 @@ func (dht *DHT) Ping(addr string) error {
return err
}
tmpNode := Contact{id: RandomBitmapP(), ip: raddr.IP, port: raddr.Port}
tmpNode := Contact{ID: RandomBitmapP(), IP: raddr.IP, Port: raddr.Port}
res := dht.node.Send(tmpNode, Request{Method: pingMethod})
if res == nil {
return errors.Err("no response from node %s", addr)
@ -214,19 +213,25 @@ func (dht *DHT) Announce(hash Bitmap) error {
// TODO: if this node is closer than farthest peer, store locally and pop farthest peer
for _, node := range res.Contacts {
go dht.storeOnNode(hash, node)
wg := &sync.WaitGroup{}
for _, c := range res.Contacts {
wg.Add(1)
go func(c Contact) {
dht.storeOnNode(hash, c)
wg.Done()
}(c)
}
wg.Wait()
return nil
}
func (dht *DHT) storeOnNode(hash Bitmap, node Contact) {
func (dht *DHT) storeOnNode(hash Bitmap, c Contact) {
dht.stopWG.Add(1)
defer dht.stopWG.Done()
ctx, cancel := context.WithCancel(context.Background())
resCh := dht.node.SendAsync(ctx, node, Request{
resCh, cancel := dht.node.SendCancelable(c, Request{
Method: findValueMethod,
Arg: &hash,
})
@ -244,15 +249,14 @@ func (dht *DHT) storeOnNode(hash Bitmap, node Contact) {
return // request timed out
}
ctx, cancel = context.WithCancel(context.Background())
resCh = dht.node.SendAsync(ctx, node, Request{
resCh, cancel = dht.node.SendCancelable(c, Request{
Method: storeMethod,
StoreArgs: &storeArgs{
BlobHash: hash,
Value: storeArgsValue{
Token: res.Token,
LbryID: dht.contact.id,
Port: dht.contact.port,
LbryID: dht.contact.ID,
Port: dht.contact.Port,
},
},
})
@ -276,18 +280,12 @@ func (dht *DHT) PrintState() {
}
}
func printNodeList(list []Contact) {
for i, n := range list {
log.Printf("%d) %s", i, n.String())
}
}
func getContact(nodeID, addr string) (Contact, error) {
var c Contact
if nodeID == "" {
c.id = RandomBitmapP()
c.ID = RandomBitmapP()
} else {
c.id = BitmapFromHexP(nodeID)
c.ID = BitmapFromHexP(nodeID)
}
ip, port, err := net.SplitHostPort(addr)
@ -299,12 +297,12 @@ func getContact(nodeID, addr string) (Contact, error) {
return c, errors.Err("address does not contain a port")
}
c.ip = net.ParseIP(ip)
if c.ip == nil {
c.IP = net.ParseIP(ip)
if c.IP == nil {
return c, errors.Err("invalid ip")
}
c.port, err = cast.ToIntE(port)
c.Port, err = cast.ToIntE(port)
if err != nil {
return c, errors.Err(err)
}

View file

@ -36,13 +36,13 @@ func TestNodeFinder_FindNodes(t *testing.T) {
foundTwo := false
for _, n := range foundNodes {
if n.id.Equals(bs.id) {
if n.ID.Equals(bs.id) {
foundBootstrap = true
}
if n.id.Equals(dhts[0].node.id) {
if n.ID.Equals(dhts[0].node.id) {
foundOne = true
}
if n.id.Equals(dhts[1].node.id) {
if n.ID.Equals(dhts[1].node.id) {
foundTwo = true
}
}
@ -83,7 +83,7 @@ func TestNodeFinder_FindValue(t *testing.T) {
}()
blobHashToFind := RandomBitmapP()
nodeToFind := Contact{id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
nodeToFind := Contact{ID: RandomBitmapP(), IP: net.IPv4(1, 2, 3, 4), Port: 5678}
dhts[0].node.store.Upsert(blobHashToFind, nodeToFind)
nf := newContactFinder(dhts[2].node, blobHashToFind, true)
@ -101,8 +101,8 @@ func TestNodeFinder_FindValue(t *testing.T) {
t.Fatalf("expected one node, found %d", len(foundNodes))
}
if !foundNodes[0].id.Equals(nodeToFind.id) {
t.Fatalf("found node id %s, expected %s", foundNodes[0].id.Hex(), nodeToFind.id.Hex())
if !foundNodes[0].ID.Equals(nodeToFind.ID) {
t.Fatalf("found node id %s, expected %s", foundNodes[0].ID.Hex(), nodeToFind.ID.Hex())
}
}
@ -139,7 +139,7 @@ func TestDHT_LargeDHT(t *testing.T) {
c := d2.node.rt.GetClosest(d.node.id, 1)
if len(c) > 1 {
t.Error("rt returned more than one node when only one requested")
} else if len(c) == 1 && c[0].id.Equals(d.node.id) {
} else if len(c) == 1 && c[0].ID.Equals(d.node.id) {
rtCounts[d.node.id]++
}
}

File diff suppressed because it is too large Load diff

View file

@ -240,7 +240,7 @@ func (r Response) ArgsDebug() string {
str += "|"
for _, c := range r.Contacts {
str += c.Addr().String() + ":" + c.id.HexShort() + ","
str += c.Addr().String() + ":" + c.ID.HexShort() + ","
}
str = strings.TrimRight(str, ",") + "|"

View file

@ -78,8 +78,8 @@ func TestBencodeFindNodesResponse(t *testing.T) {
ID: newMessageID(),
NodeID: RandomBitmapP(),
Contacts: []Contact{
{id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
{id: RandomBitmapP(), ip: net.IPv4(4, 3, 2, 1).To4(), port: 8765},
{ID: RandomBitmapP(), IP: net.IPv4(1, 2, 3, 4).To4(), Port: 5678},
{ID: RandomBitmapP(), IP: net.IPv4(4, 3, 2, 1).To4(), Port: 8765},
},
}
@ -104,7 +104,7 @@ func TestBencodeFindValueResponse(t *testing.T) {
FindValueKey: RandomBitmapP().RawString(),
Token: "arst",
Contacts: []Contact{
{id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
{ID: RandomBitmapP(), IP: net.IPv4(1, 2, 3, 4).To4(), Port: 5678},
},
}

View file

@ -48,9 +48,9 @@ type Node struct {
transactions map[messageID]*transaction
// routing table
rt RoutingTable
rt *routingTable
// data store
store Store
store *contactStore
// overrides for request handlers
requestHandler RequestHandlerFunc
@ -238,7 +238,7 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) {
// TODO: we should be sending the IP in the request, not just using the sender's IP
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
if n.tokens.Verify(request.StoreArgs.Value.Token, request.NodeID, addr) {
n.store.Upsert(request.StoreArgs.BlobHash, Contact{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
n.store.Upsert(request.StoreArgs.BlobHash, Contact{ID: request.StoreArgs.NodeID, IP: addr.IP, Port: request.StoreArgs.Value.Port})
n.sendMessage(addr, Response{ID: request.ID, NodeID: n.id, Data: storeSuccessResponse})
} else {
n.sendMessage(addr, Error{ID: request.ID, NodeID: n.id, ExceptionType: "invalid-token"})
@ -280,7 +280,7 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) {
// the routing table must only contain "good" nodes, which are nodes that reply to our requests
// if a node is already good (aka in the table), its fine to refresh it
// http://www.bittorrent.org/beps/bep_0005.html#routing-table
n.rt.Fresh(Contact{id: request.NodeID, ip: addr.IP, port: addr.Port})
n.rt.Fresh(Contact{ID: request.NodeID, IP: addr.IP, Port: addr.Port})
}
// handleResponse handles responses received from udp.
@ -290,13 +290,13 @@ func (n *Node) handleResponse(addr *net.UDPAddr, response Response) {
tx.res <- response
}
n.rt.Update(Contact{id: response.NodeID, ip: addr.IP, port: addr.Port})
n.rt.Update(Contact{ID: response.NodeID, IP: addr.IP, Port: addr.Port})
}
// handleError handles errors received from udp.
func (n *Node) handleError(addr *net.UDPAddr, e Error) {
spew.Dump(e)
n.rt.Fresh(Contact{id: e.NodeID, ip: addr.IP, port: addr.Port})
n.rt.Fresh(Contact{ID: e.NodeID, IP: addr.IP, Port: addr.Port})
}
// send sends data to a udp address
@ -361,7 +361,7 @@ func (n *Node) txFind(id messageID, addr *net.UDPAddr) *transaction {
// SendAsync sends a transaction and returns a channel that will eventually contain the transaction response
// The response channel is closed when the transaction is completed or times out.
func (n *Node) SendAsync(ctx context.Context, contact Contact, req Request) <-chan *Response {
if contact.id.Equals(n.id) {
if contact.ID.Equals(n.id) {
log.Error("sending query to self")
return nil
}
@ -413,6 +413,12 @@ func (n *Node) Send(contact Contact, req Request) *Response {
return <-n.SendAsync(context.Background(), contact, req)
}
// SendCancelable sends the transaction asynchronously and allows the transaction to be canceled
func (n *Node) SendCancelable(contact Contact, req Request) (<-chan *Response, context.CancelFunc) {
ctx, cancel := context.WithCancel(context.Background())
return n.SendAsync(ctx, contact, req), cancel
}
// Count returns the number of transactions in the manager
func (n *Node) CountActiveTransactions() int {
n.txLock.Lock()

View file

@ -1,7 +1,6 @@
package dht
import (
"context"
"sort"
"sync"
"time"
@ -113,7 +112,7 @@ func (cf *contactFinder) iterationWorker(num int) {
} else {
contact := *maybeContact
if contact.id.Equals(cf.node.id) {
if contact.ID.Equals(cf.node.id) {
continue // cannot contact self
}
@ -124,13 +123,12 @@ func (cf *contactFinder) iterationWorker(num int) {
req.Method = findNodeMethod
}
log.Debugf("[%s] worker %d: contacting %s", cf.node.id.HexShort(), num, contact.id.HexShort())
log.Debugf("[%s] worker %d: contacting %s", cf.node.id.HexShort(), num, contact.ID.HexShort())
cf.incrementOutstanding()
var res *Response
ctx, cancel := context.WithCancel(context.Background())
resCh := cf.node.SendAsync(ctx, contact, req)
resCh, cancel := cf.node.SendCancelable(contact, req)
select {
case res = <-resCh:
case <-cf.done.Chan():
@ -141,7 +139,7 @@ func (cf *contactFinder) iterationWorker(num int) {
if res == nil {
// nothing to do, response timed out
log.Debugf("[%s] worker %d: search canceled or timed out waiting for %s", cf.node.id.HexShort(), num, contact.id.HexShort())
log.Debugf("[%s] worker %d: search canceled or timed out waiting for %s", cf.node.id.HexShort(), num, contact.ID.HexShort())
} else if cf.findValue && res.FindValueKey != "" {
log.Debugf("[%s] worker %d: got value", cf.node.id.HexShort(), num)
cf.findValueMutex.Lock()
@ -171,9 +169,9 @@ func (cf *contactFinder) appendNewToShortlist(contacts []Contact) {
defer cf.shortlistMutex.Unlock()
for _, c := range contacts {
if _, ok := cf.shortlistAdded[c.id]; !ok {
if _, ok := cf.shortlistAdded[c.ID]; !ok {
cf.shortlist = append(cf.shortlist, c)
cf.shortlistAdded[c.id] = true
cf.shortlistAdded[c.ID] = true
}
}
@ -199,7 +197,7 @@ func (cf *contactFinder) insertIntoActiveList(contact Contact) {
inserted := false
for i, n := range cf.activeContacts {
if contact.id.Xor(cf.target).Less(n.id.Xor(cf.target)) {
if contact.ID.Xor(cf.target).Less(n.ID.Xor(cf.target)) {
cf.activeContacts = append(cf.activeContacts[:i], append([]Contact{contact}, cf.activeContacts[i:]...)...)
inserted = true
break
@ -232,7 +230,7 @@ func (cf *contactFinder) isSearchFinished() bool {
cf.activeContactsMutex.Lock()
defer cf.activeContactsMutex.Unlock()
if len(cf.activeContacts) >= bucketSize && cf.activeContacts[bucketSize-1].id.Xor(cf.target).Less(cf.shortlist[0].id.Xor(cf.target)) {
if len(cf.activeContacts) >= bucketSize && cf.activeContacts[bucketSize-1].ID.Xor(cf.target).Less(cf.shortlist[0].ID.Xor(cf.target)) {
// we have at least K active contacts, and we don't have any closer contacts to ping
return true
}
@ -263,7 +261,7 @@ func sortInPlace(contacts []Contact, target Bitmap) {
toSort := make([]sortedContact, len(contacts))
for i, n := range contacts {
toSort[i] = sortedContact{n, n.id.Xor(target)}
toSort[i] = sortedContact{n, n.ID.Xor(target)}
}
sort.Sort(byXorDistance(toSort))

View file

@ -198,7 +198,7 @@ func TestStore(t *testing.T) {
if len(items) != 1 {
t.Error("list created in store, but nothing in list")
}
if !items[0].id.Equals(testNodeID) {
if !items[0].ID.Equals(testNodeID) {
t.Error("wrong value stored")
}
}
@ -223,7 +223,7 @@ func TestFindNode(t *testing.T) {
nodesToInsert := 3
var nodes []Contact
for i := 0; i < nodesToInsert; i++ {
n := Contact{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
n := Contact{ID: RandomBitmapP(), IP: net.ParseIP("127.0.0.1"), Port: 10000 + i}
nodes = append(nodes, n)
dht.node.rt.Update(n)
}
@ -292,7 +292,7 @@ func TestFindValueExisting(t *testing.T) {
nodesToInsert := 3
var nodes []Contact
for i := 0; i < nodesToInsert; i++ {
n := Contact{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
n := Contact{ID: RandomBitmapP(), IP: net.ParseIP("127.0.0.1"), Port: 10000 + i}
nodes = append(nodes, n)
dht.node.rt.Update(n)
}
@ -302,7 +302,7 @@ func TestFindValueExisting(t *testing.T) {
messageID := newMessageID()
valueToFind := RandomBitmapP()
nodeToFind := Contact{id: RandomBitmapP(), ip: net.ParseIP("1.2.3.4"), port: 1286}
nodeToFind := Contact{ID: RandomBitmapP(), IP: net.ParseIP("1.2.3.4"), Port: 1286}
dht.node.store.Upsert(valueToFind, nodeToFind)
dht.node.store.Upsert(valueToFind, nodeToFind)
dht.node.store.Upsert(valueToFind, nodeToFind)
@ -378,7 +378,7 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
nodesToInsert := 3
var nodes []Contact
for i := 0; i < nodesToInsert; i++ {
n := Contact{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
n := Contact{ID: RandomBitmapP(), IP: net.ParseIP("127.0.0.1"), Port: 10000 + i}
nodes = append(nodes, n)
dht.node.rt.Update(n)
}

View file

@ -2,10 +2,11 @@ package dht
import (
"bytes"
"container/list"
"encoding/json"
"fmt"
"net"
"sort"
"strconv"
"strings"
"sync"
"time"
@ -16,32 +17,32 @@ import (
)
type Contact struct {
id Bitmap
ip net.IP
port int
ID Bitmap
IP net.IP
Port int
}
func (c Contact) Addr() *net.UDPAddr {
return &net.UDPAddr{IP: c.ip, Port: c.port}
return &net.UDPAddr{IP: c.IP, Port: c.Port}
}
func (c Contact) String() string {
return c.id.HexShort() + "@" + c.Addr().String()
return c.ID.HexShort() + "@" + c.Addr().String()
}
func (c Contact) MarshalCompact() ([]byte, error) {
if c.ip.To4() == nil {
if c.IP.To4() == nil {
return nil, errors.Err("ip not set")
}
if c.port < 0 || c.port > 65535 {
if c.Port < 0 || c.Port > 65535 {
return nil, errors.Err("invalid port")
}
var buf bytes.Buffer
buf.Write(c.ip.To4())
buf.WriteByte(byte(c.port >> 8))
buf.WriteByte(byte(c.port))
buf.Write(c.id[:])
buf.Write(c.IP.To4())
buf.WriteByte(byte(c.Port >> 8))
buf.WriteByte(byte(c.Port))
buf.Write(c.ID[:])
if buf.Len() != compactNodeInfoLength {
return nil, errors.Err("i dont know how this happened")
@ -54,14 +55,14 @@ func (c *Contact) UnmarshalCompact(b []byte) error {
if len(b) != compactNodeInfoLength {
return errors.Err("invalid compact length")
}
c.ip = net.IPv4(b[0], b[1], b[2], b[3]).To4()
c.port = int(uint16(b[5]) | uint16(b[4])<<8)
c.id = BitmapFromBytesP(b[6:])
c.IP = net.IPv4(b[0], b[1], b[2], b[3]).To4()
c.Port = int(uint16(b[5]) | uint16(b[4])<<8)
c.ID = BitmapFromBytesP(b[6:])
return nil
}
func (c Contact) MarshalBencode() ([]byte, error) {
return bencode.EncodeBytes([]interface{}{c.id, c.ip.String(), c.port})
return bencode.EncodeBytes([]interface{}{c.ID, c.IP.String(), c.Port})
}
func (c *Contact) UnmarshalBencode(b []byte) error {
@ -75,7 +76,7 @@ func (c *Contact) UnmarshalBencode(b []byte) error {
return errors.Err("contact must have 3 elements; got %d", len(raw))
}
err = bencode.DecodeBytes(raw[0], &c.id)
err = bencode.DecodeBytes(raw[0], &c.ID)
if err != nil {
return err
}
@ -85,12 +86,12 @@ func (c *Contact) UnmarshalBencode(b []byte) error {
if err != nil {
return err
}
c.ip = net.ParseIP(ipStr).To4()
if c.ip == nil {
c.IP = net.ParseIP(ipStr).To4()
if c.IP == nil {
return errors.Err("invalid IP")
}
err = bencode.DecodeBytes(raw[2], &c.port)
err = bencode.DecodeBytes(raw[2], &c.Port)
if err != nil {
return err
}
@ -113,52 +114,38 @@ func (a byXorDistance) Less(i, j int) bool {
// peer is a contact with extra freshness information
type peer struct {
contact Contact
lastActivity time.Time
numFailures int
Contact Contact
LastActivity time.Time
NumFailures int
//<lastPublished>,
//<originallyPublished>
// <originalPublisherID>
}
func (p *peer) Touch() {
p.lastActivity = time.Now()
p.numFailures = 0
p.LastActivity = time.Now()
p.NumFailures = 0
}
// ActiveSince returns whether a peer has responded in the last `d` duration
// this is used to check if the peer is "good", meaning that we believe the peer will respond to our requests
func (p *peer) ActiveInLast(d time.Duration) bool {
return time.Now().Sub(p.lastActivity) > d
return time.Now().Sub(p.LastActivity) > d
}
// IsBad returns whether a peer is "bad", meaning that it has failed to respond to multiple pings in a row
func (p *peer) IsBad(maxFalures int) bool {
return p.numFailures >= maxFalures
return p.NumFailures >= maxFalures
}
// Fail marks a peer as having failed to respond. It returns whether or not the peer should be removed from the routing table
func (p *peer) Fail() {
p.numFailures++
}
// toPeer converts a generic *list.Element into a *peer
// this (along with newPeer) keeps all conversions between *list.Element and peer in one place
func toPeer(el *list.Element) *peer {
return el.Value.(*peer)
}
// newPeer creates a new peer from a contact
// this (along with toPeer) keeps all conversions between *list.Element and peer in one place
func newPeer(c Contact) peer {
return peer{
contact: c,
}
p.NumFailures++
}
type bucket struct {
lock *sync.RWMutex
peers *list.List
peers []peer
lastUpdate time.Time
}
@ -166,16 +153,16 @@ type bucket struct {
func (b bucket) Len() int {
b.lock.RLock()
defer b.lock.RUnlock()
return b.peers.Len()
return len(b.peers)
}
// Contacts returns a slice of the bucket's contacts
func (b bucket) Contacts() []Contact {
b.lock.RLock()
defer b.lock.RUnlock()
contacts := make([]Contact, b.peers.Len())
for i, curr := 0, b.peers.Front(); curr != nil; i, curr = i+1, curr.Next() {
contacts[i] = toPeer(curr).contact
contacts := make([]Contact, len(b.peers))
for i := range b.peers {
contacts[i] = b.peers[i].Contact
}
return contacts
}
@ -185,21 +172,21 @@ func (b *bucket) UpdateContact(c Contact, insertIfNew bool) {
b.lock.Lock()
defer b.lock.Unlock()
element := find(c.id, b.peers)
if element != nil {
peerIndex := find(c.ID, b.peers)
if peerIndex >= 0 {
b.lastUpdate = time.Now()
toPeer(element).Touch()
b.peers.MoveToBack(element)
b.peers[peerIndex].Touch()
moveToBack(b.peers, peerIndex)
} else if insertIfNew {
hasRoom := true
if b.peers.Len() >= bucketSize {
if len(b.peers) >= bucketSize {
hasRoom = false
for curr := b.peers.Front(); curr != nil; curr = curr.Next() {
if toPeer(curr).IsBad(maxPeerFails) {
for i := range b.peers {
if b.peers[i].IsBad(maxPeerFails) {
// TODO: Ping contact first. Only remove if it does not respond
b.peers.Remove(curr)
b.peers = append(b.peers[:i], b.peers[i+1:]...)
hasRoom = true
break
}
@ -208,9 +195,9 @@ func (b *bucket) UpdateContact(c Contact, insertIfNew bool) {
if hasRoom {
b.lastUpdate = time.Now()
peer := newPeer(c)
peer := peer{Contact: c}
peer.Touch()
b.peers.PushBack(&peer)
b.peers = append(b.peers, peer)
}
}
}
@ -219,21 +206,21 @@ func (b *bucket) UpdateContact(c Contact, insertIfNew bool) {
func (b *bucket) FailContact(id Bitmap) {
b.lock.Lock()
defer b.lock.Unlock()
element := find(id, b.peers)
if element != nil {
i := find(id, b.peers)
if i >= 0 {
// BEP5 says not to remove the contact until the bucket is full and you try to insert
toPeer(element).Fail()
b.peers[i].Fail()
}
}
// find returns the contact in the bucket, or nil if the bucket does not contain the contact
func find(id Bitmap, peers *list.List) *list.Element {
for curr := peers.Front(); curr != nil; curr = curr.Next() {
if toPeer(curr).contact.id.Equals(id) {
return curr
func find(id Bitmap, peers []peer) int {
for i := range peers {
if peers[i].Contact.ID.Equals(id) {
return i
}
}
return nil
return -1
}
// NeedsRefresh returns true if bucket has not been updated in the last `refreshInterval`, false otherwise
@ -243,41 +230,31 @@ func (b *bucket) NeedsRefresh(refreshInterval time.Duration) bool {
return time.Now().Sub(b.lastUpdate) > refreshInterval
}
type RoutingTable interface {
Update(Contact)
Fresh(Contact)
Fail(Contact)
GetClosest(Bitmap, int) []Contact
Count() int
GetIDsForRefresh(time.Duration) []Bitmap
BucketInfo() string // for debugging
}
type routingTableImpl struct {
type routingTable struct {
id Bitmap
buckets [numBuckets]bucket
buckets [nodeIDBits]bucket
}
func newRoutingTable(id Bitmap) *routingTableImpl {
var rt routingTableImpl
func newRoutingTable(id Bitmap) *routingTable {
var rt routingTable
rt.id = id
for i := range rt.buckets {
rt.buckets[i] = bucket{
peers: list.New(),
peers: make([]peer, 0, bucketSize),
lock: &sync.RWMutex{},
}
}
return &rt
}
func (rt *routingTableImpl) BucketInfo() string {
func (rt *routingTable) BucketInfo() string {
var bucketInfo []string
for i, b := range rt.buckets {
if b.Len() > 0 {
contacts := b.Contacts()
s := make([]string, len(contacts))
for j, c := range contacts {
s[j] = c.id.HexShort()
s[j] = c.ID.HexShort()
}
bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: (%d) %s", i, len(contacts), strings.Join(s, ", ")))
}
@ -289,23 +266,23 @@ func (rt *routingTableImpl) BucketInfo() string {
}
// Update inserts or refreshes a contact
func (rt *routingTableImpl) Update(c Contact) {
rt.bucketFor(c.id).UpdateContact(c, true)
func (rt *routingTable) Update(c Contact) {
rt.bucketFor(c.ID).UpdateContact(c, true)
}
// Fresh refreshes a contact if its already in the routing table
func (rt *routingTableImpl) Fresh(c Contact) {
rt.bucketFor(c.id).UpdateContact(c, false)
func (rt *routingTable) Fresh(c Contact) {
rt.bucketFor(c.ID).UpdateContact(c, false)
}
// FailContact marks a contact as having failed, and removes it if it failed too many times
func (rt *routingTableImpl) Fail(c Contact) {
rt.bucketFor(c.id).FailContact(c.id)
func (rt *routingTable) Fail(c Contact) {
rt.bucketFor(c.ID).FailContact(c.ID)
}
// GetClosest returns the closest `limit` contacts from the routing table
// It marks each bucket it accesses as having been accessed
func (rt *routingTableImpl) GetClosest(target Bitmap, limit int) []Contact {
func (rt *routingTable) GetClosest(target Bitmap, limit int) []Contact {
var toSort []sortedContact
var bucketNum int
@ -317,11 +294,11 @@ func (rt *routingTableImpl) GetClosest(target Bitmap, limit int) []Contact {
toSort = appendContacts(toSort, rt.buckets[bucketNum], target)
for i := 1; (bucketNum-i >= 0 || bucketNum+i < numBuckets) && len(toSort) < limit; i++ {
for i := 1; (bucketNum-i >= 0 || bucketNum+i < nodeIDBits) && len(toSort) < limit; i++ {
if bucketNum-i >= 0 {
toSort = appendContacts(toSort, rt.buckets[bucketNum-i], target)
}
if bucketNum+i < numBuckets {
if bucketNum+i < nodeIDBits {
toSort = appendContacts(toSort, rt.buckets[bucketNum+i], target)
}
}
@ -341,13 +318,13 @@ func (rt *routingTableImpl) GetClosest(target Bitmap, limit int) []Contact {
func appendContacts(contacts []sortedContact, b bucket, target Bitmap) []sortedContact {
for _, contact := range b.Contacts() {
contacts = append(contacts, sortedContact{contact, contact.id.Xor(target)})
contacts = append(contacts, sortedContact{contact, contact.ID.Xor(target)})
}
return contacts
}
// Count returns the number of contacts in the routing table
func (rt *routingTableImpl) Count() int {
func (rt *routingTable) Count() int {
count := 0
for _, bucket := range rt.buckets {
count = bucket.Len()
@ -355,27 +332,99 @@ func (rt *routingTableImpl) Count() int {
return count
}
func (rt *routingTableImpl) bucketNumFor(target Bitmap) int {
type Range struct {
start Bitmap
end Bitmap
}
// BucketRanges returns a slice of ranges, where the `start` of each range is the smallest id that can
// go in that bucket, and the `end` is the largest id
func (rt *routingTable) BucketRanges() []Range {
ranges := make([]Range, len(rt.buckets))
for i := range rt.buckets {
ranges[i] = Range{
rt.id.Suffix(i, false).Set(nodeIDBits-1-i, !rt.id.Get(nodeIDBits-1-i)),
rt.id.Suffix(i, true).Set(nodeIDBits-1-i, !rt.id.Get(nodeIDBits-1-i)),
}
}
return ranges
}
func (rt *routingTable) bucketNumFor(target Bitmap) int {
if rt.id.Equals(target) {
panic("routing table does not have a bucket for its own id")
}
return numBuckets - 1 - target.Xor(rt.id).PrefixLen()
return nodeIDBits - 1 - target.Xor(rt.id).PrefixLen()
}
func (rt *routingTableImpl) bucketFor(target Bitmap) *bucket {
func (rt *routingTable) bucketFor(target Bitmap) *bucket {
return &rt.buckets[rt.bucketNumFor(target)]
}
func (rt *routingTableImpl) GetIDsForRefresh(refreshInterval time.Duration) []Bitmap {
func (rt *routingTable) GetIDsForRefresh(refreshInterval time.Duration) []Bitmap {
var bitmaps []Bitmap
for i, bucket := range rt.buckets {
if bucket.NeedsRefresh(refreshInterval) {
bitmaps = append(bitmaps, RandomBitmapP().ZeroPrefix(i))
bitmaps = append(bitmaps, RandomBitmapP().Prefix(i, false))
}
}
return bitmaps
}
const rtContactSep = "-"
type rtSave struct {
ID string `json:"id"`
Contacts []string `json:"contacts"`
}
func (rt *routingTable) MarshalJSON() ([]byte, error) {
var data rtSave
data.ID = rt.id.Hex()
for _, b := range rt.buckets {
for _, c := range b.Contacts() {
data.Contacts = append(data.Contacts, strings.Join([]string{c.ID.Hex(), c.IP.String(), strconv.Itoa(c.Port)}, rtContactSep))
}
}
return json.Marshal(data)
}
func (rt *routingTable) UnmarshalJSON(b []byte) error {
var data rtSave
err := json.Unmarshal(b, &data)
if err != nil {
return err
}
rt.id, err = BitmapFromHex(data.ID)
if err != nil {
return errors.Prefix("decoding ID", err)
}
for _, s := range data.Contacts {
parts := strings.Split(s, rtContactSep)
if len(parts) != 3 {
return errors.Err("decoding contact %s: wrong number of parts", s)
}
var c Contact
c.ID, err = BitmapFromHex(parts[0])
if err != nil {
return errors.Err("decoding contact %s: invalid ID: %s", s, err)
}
c.IP = net.ParseIP(parts[1])
if c.IP == nil {
return errors.Err("decoding contact %s: invalid IP", s)
}
c.Port, err = strconv.Atoi(parts[2])
if err != nil {
return errors.Err("decoding contact %s: invalid port: %s", s, err)
}
rt.Update(c)
}
return nil
}
// RoutingTableRefresh refreshes any buckets that need to be refreshed
// It returns a channel that will be closed when the refresh is done
func RoutingTableRefresh(n *Node, refreshInterval time.Duration, cancel <-chan struct{}) <-chan struct{} {
@ -411,3 +460,14 @@ func RoutingTableRefresh(n *Node, refreshInterval time.Duration, cancel <-chan s
return done
}
func moveToBack(peers []peer, index int) {
if index < 0 || len(peers) <= index+1 {
return
}
p := peers[index]
for i := index; i < len(peers)-1; i++ {
peers[i] = peers[i+1]
}
peers[len(peers)-1] = p
}

View file

@ -1,9 +1,14 @@
package dht
import (
"encoding/json"
"net"
"reflect"
"strconv"
"strings"
"testing"
"github.com/sebdah/goldie"
)
func TestRoutingTable_bucketFor(t *testing.T) {
@ -31,7 +36,7 @@ func TestRoutingTable_bucketFor(t *testing.T) {
}
}
func TestRoutingTable(t *testing.T) {
func TestRoutingTable_GetClosest(t *testing.T) {
n1 := BitmapFromHexP("FFFFFFFF0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
n2 := BitmapFromHexP("FFFFFFF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
n3 := BitmapFromHexP("111111110000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
@ -44,7 +49,7 @@ func TestRoutingTable(t *testing.T) {
t.Fail()
return
}
if !contacts[0].id.Equals(n3) {
if !contacts[0].ID.Equals(n3) {
t.Error(contacts[0])
}
@ -53,19 +58,19 @@ func TestRoutingTable(t *testing.T) {
t.Error(len(contacts))
return
}
if !contacts[0].id.Equals(n2) {
if !contacts[0].ID.Equals(n2) {
t.Error(contacts[0])
}
if !contacts[1].id.Equals(n3) {
if !contacts[1].ID.Equals(n3) {
t.Error(contacts[1])
}
}
func TestCompactEncoding(t *testing.T) {
c := Contact{
id: BitmapFromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41"),
ip: net.ParseIP("1.2.3.4"),
port: int(55<<8 + 66),
ID: BitmapFromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41"),
IP: net.ParseIP("1.2.3.4"),
Port: int(55<<8 + 66),
}
var compact []byte
@ -78,11 +83,117 @@ func TestCompactEncoding(t *testing.T) {
t.Fatalf("got length of %d; expected %d", len(compact), compactNodeInfoLength)
}
if !reflect.DeepEqual(compact, append([]byte{1, 2, 3, 4, 55, 66}, c.id[:]...)) {
if !reflect.DeepEqual(compact, append([]byte{1, 2, 3, 4, 55, 66}, c.ID[:]...)) {
t.Errorf("compact bytes not encoded correctly")
}
}
func TestRoutingTableRefresh(t *testing.T) {
func TestRoutingTable_Refresh(t *testing.T) {
t.Skip("TODO: test routing table refreshing")
}
func TestRoutingTable_MoveToBack(t *testing.T) {
tt := map[string]struct {
data []peer
index int
expected []peer
}{
"simpleMove": {
data: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
index: 1,
expected: []peer{{NumFailures: 0}, {NumFailures: 2}, {NumFailures: 3}, {NumFailures: 1}},
},
"moveFirst": {
data: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
index: 0,
expected: []peer{{NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}, {NumFailures: 0}},
},
"moveLast": {
data: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
index: 3,
expected: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
},
"largeIndex": {
data: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
index: 27,
expected: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
},
"negativeIndex": {
data: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
index: -12,
expected: []peer{{NumFailures: 0}, {NumFailures: 1}, {NumFailures: 2}, {NumFailures: 3}},
},
}
for name, test := range tt {
moveToBack(test.data, test.index)
expected := make([]string, len(test.expected))
actual := make([]string, len(test.data))
for i := range actual {
actual[i] = strconv.Itoa(test.data[i].NumFailures)
expected[i] = strconv.Itoa(test.expected[i].NumFailures)
}
expJoin := strings.Join(expected, ",")
actJoin := strings.Join(actual, ",")
if actJoin != expJoin {
t.Errorf("%s failed: got %s; expected %s", name, actJoin, expJoin)
}
}
}
func TestRoutingTable_BucketRanges(t *testing.T) {
id := BitmapFromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41")
ranges := newRoutingTable(id).BucketRanges()
if !ranges[0].start.Equals(ranges[0].end) {
t.Error("first bucket should only fit exactly one id")
}
for i := 0; i < 1000; i++ {
randID := RandomBitmapP()
found := -1
for i, r := range ranges {
if r.start.LessOrEqual(randID) && r.end.GreaterOrEqual(randID) {
if found >= 0 {
t.Errorf("%s appears in buckets %d and %d", randID.Hex(), found, i)
} else {
found = i
}
}
}
if found < 0 {
t.Errorf("%s did not appear in any bucket", randID.Hex())
}
}
}
func TestRoutingTable_Save(t *testing.T) {
id := BitmapFromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41")
rt := newRoutingTable(id)
ranges := rt.BucketRanges()
for i, r := range ranges {
for j := 0; j < bucketSize; j++ {
toAdd := r.start.Add(BitmapFromShortHexP(strconv.Itoa(j)))
if toAdd.LessOrEqual(r.end) {
rt.Update(Contact{
ID: r.start.Add(BitmapFromShortHexP(strconv.Itoa(j))),
IP: net.ParseIP("1.2.3." + strconv.Itoa(j)),
Port: 1 + i*bucketSize + j,
})
}
}
}
data, err := json.MarshalIndent(rt, "", " ")
if err != nil {
t.Error(err)
}
goldie.Assert(t, t.Name(), data)
}
func TestRoutingTable_Load(t *testing.T) {
t.Skip("TODO")
}

View file

@ -2,13 +2,7 @@ package dht
import "sync"
type Store interface {
Upsert(Bitmap, Contact)
Get(Bitmap) []Contact
CountStoredHashes() int
}
type storeImpl struct {
type contactStore struct {
// map of blob hashes to (map of node IDs to bools)
hashes map[Bitmap]map[Bitmap]bool
// stores the peers themselves, so they can be updated in one place
@ -16,25 +10,25 @@ type storeImpl struct {
lock sync.RWMutex
}
func newStore() *storeImpl {
return &storeImpl{
func newStore() *contactStore {
return &contactStore{
hashes: make(map[Bitmap]map[Bitmap]bool),
contacts: make(map[Bitmap]Contact),
}
}
func (s *storeImpl) Upsert(blobHash Bitmap, contact Contact) {
func (s *contactStore) Upsert(blobHash Bitmap, contact Contact) {
s.lock.Lock()
defer s.lock.Unlock()
if _, ok := s.hashes[blobHash]; !ok {
s.hashes[blobHash] = make(map[Bitmap]bool)
}
s.hashes[blobHash][contact.id] = true
s.contacts[contact.id] = contact
s.hashes[blobHash][contact.ID] = true
s.contacts[contact.ID] = contact
}
func (s *storeImpl) Get(blobHash Bitmap) []Contact {
func (s *contactStore) Get(blobHash Bitmap) []Contact {
s.lock.RLock()
defer s.lock.RUnlock()
@ -51,11 +45,11 @@ func (s *storeImpl) Get(blobHash Bitmap) []Contact {
return contacts
}
func (s *storeImpl) RemoveTODO(contact Contact) {
func (s *contactStore) RemoveTODO(contact Contact) {
// TODO: remove peer from everywhere
}
func (s *storeImpl) CountStoredHashes() int {
func (s *contactStore) CountStoredHashes() int {
s.lock.RLock()
defer s.lock.RUnlock()
return len(s.hashes)

View file

@ -215,7 +215,7 @@ func verifyContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
continue
}
for _, n := range nodes {
if n.id.RawString() == id {
if n.ID.RawString() == id {
currNode = n
currNodeFound = true
foundNodes[id] = true
@ -231,15 +231,15 @@ func verifyContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
ip, ok := contact[1].(string)
if !ok {
t.Error("contact IP is not a string")
} else if !currNode.ip.Equal(net.ParseIP(ip)) {
t.Errorf("contact IP mismatch. got %s; expected %s", ip, currNode.ip.String())
} else if !currNode.IP.Equal(net.ParseIP(ip)) {
t.Errorf("contact IP mismatch. got %s; expected %s", ip, currNode.IP.String())
}
port, ok := contact[2].(int64)
if !ok {
t.Error("contact port is not an int")
} else if int(port) != currNode.port {
t.Errorf("contact port mismatch. got %d; expected %d", port, currNode.port)
} else if int(port) != currNode.Port {
t.Errorf("contact port mismatch. got %d; expected %d", port, currNode.Port)
}
}
}
@ -269,29 +269,38 @@ func verifyCompactContacts(t *testing.T, contacts []interface{}, nodes []Contact
var currNode Contact
currNodeFound := false
if _, ok := foundNodes[contact.id.Hex()]; ok {
t.Errorf("contact %s appears multiple times", contact.id.Hex())
if _, ok := foundNodes[contact.ID.Hex()]; ok {
t.Errorf("contact %s appears multiple times", contact.ID.Hex())
continue
}
for _, n := range nodes {
if n.id.Equals(contact.id) {
if n.ID.Equals(contact.ID) {
currNode = n
currNodeFound = true
foundNodes[contact.id.Hex()] = true
foundNodes[contact.ID.Hex()] = true
break
}
}
if !currNodeFound {
t.Errorf("unexpected contact %s", contact.id.Hex())
t.Errorf("unexpected contact %s", contact.ID.Hex())
continue
}
if !currNode.ip.Equal(contact.ip) {
t.Errorf("contact IP mismatch. got %s; expected %s", contact.ip.String(), currNode.ip.String())
if !currNode.IP.Equal(contact.IP) {
t.Errorf("contact IP mismatch. got %s; expected %s", contact.IP.String(), currNode.IP.String())
}
if contact.port != currNode.port {
t.Errorf("contact port mismatch. got %d; expected %d", contact.port, currNode.port)
if contact.Port != currNode.Port {
t.Errorf("contact port mismatch. got %d; expected %d", contact.Port, currNode.Port)
}
}
}
func assertPanic(t *testing.T, text string, f func()) {
defer func() {
if r := recover(); r == nil {
t.Errorf("%s: did not panic as expected", text)
}
}()
f()
}