bootstrap node, bucket grooming

This commit is contained in:
Alex Grintsvayg 2018-05-13 16:02:46 -04:00
parent 01476a2a8d
commit 2bf117b05f
15 changed files with 932 additions and 433 deletions

View file

@ -80,7 +80,7 @@ func dhtCmd(cmd *cobra.Command, args []string) {
//d.WaitUntilJoined()
nodes := 10
dhts := dht.TestingCreateDHT(nodes)
_, dhts := dht.TestingCreateDHT(nodes)
defer func() {
for _, d := range dhts {
go d.Shutdown()

View file

@ -60,6 +60,26 @@ func (b Bitmap) PrefixLen() int {
return numBuckets
}
// ZeroPrefix returns a copy of b with the first n bits set to 0
// https://stackoverflow.com/a/23192263/182709
func (b Bitmap) ZeroPrefix(n int) Bitmap {
var ret Bitmap
copy(ret[:], b[:])
Outer:
for i := range ret {
for j := 0; j < 8; j++ {
if i*8+j < n {
ret[i] &= ^(1 << uint(7-j))
} else {
break Outer
}
}
}
return ret
}
func (b Bitmap) MarshalBencode() ([]byte, error) {
str := string(b[:])
return bencode.EncodeBytes(str)

View file

@ -99,23 +99,55 @@ func TestBitmapMarshalEmbedded2(t *testing.T) {
func TestBitmap_PrefixLen(t *testing.T) {
tt := []struct {
str string
hex string
len int
}{
{len: 0, str: "F00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{len: 0, str: "800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{len: 1, str: "700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{len: 1, str: "400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{len: 384, str: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{len: 383, str: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"},
{len: 382, str: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"},
{len: 382, str: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"},
{len: 0, hex: "F00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{len: 0, hex: "800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{len: 1, hex: "700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{len: 1, hex: "400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{len: 384, hex: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{len: 383, hex: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"},
{len: 382, hex: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"},
{len: 382, hex: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"},
}
for _, test := range tt {
len := BitmapFromHexP(test.str).PrefixLen()
len := BitmapFromHexP(test.hex).PrefixLen()
if len != test.len {
t.Errorf("got prefix len %d; expected %d for %s", len, test.len, test.str)
t.Errorf("got prefix len %d; expected %d for %s", len, test.len, test.hex)
}
}
}
func TestBitmap_ZeroPrefix(t *testing.T) {
original := BitmapFromHexP("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
tt := []struct {
zeros int
expected string
}{
{zeros: -123, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
{zeros: 0, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
{zeros: 1, expected: "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
{zeros: 69, expected: "000000000000000007ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
{zeros: 383, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"},
{zeros: 384, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
{zeros: 400, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
}
for _, test := range tt {
expected := BitmapFromHexP(test.expected)
actual := original.ZeroPrefix(test.zeros)
if !actual.Equals(expected) {
t.Errorf("%d zeros: got %s; expected %s", test.zeros, actual.Hex(), expected.Hex())
}
}
for i := 0; i < nodeIDLength*8; i++ {
b := original.ZeroPrefix(i)
if b.PrefixLen() != i {
t.Errorf("got prefix len %d; expected %d for %s", b.PrefixLen(), i, b.Hex())
}
}
}

View file

@ -1,7 +1,212 @@
package dht
// DHT represents a DHT node.
import (
"context"
"math/rand"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
bootstrapDefaultRefreshDuration = 15 * time.Minute
)
type nullStore struct{}
func (n nullStore) Upsert(id Bitmap, c Contact) {}
func (n nullStore) Get(id Bitmap) []Contact { return nil }
func (n nullStore) CountStoredHashes() int { return 0 }
type nullRoutingTable struct{}
// TODO: the bootstrap logic *could* be implemented just in the routing table, without a custom request handler
// TODO: the only tricky part is triggering the ping when Fresh is called, as the rt doesnt have access to the node
func (n nullRoutingTable) Update(c Contact) {} // this
func (n nullRoutingTable) Fresh(c Contact) {} // this
func (n nullRoutingTable) Fail(c Contact) {} // this
func (n nullRoutingTable) GetClosest(id Bitmap, limit int) []Contact { return nil } // this
func (n nullRoutingTable) Count() int { return 0 }
func (n nullRoutingTable) GetIDsForRefresh(d time.Duration) []Bitmap { return nil }
func (n nullRoutingTable) BucketInfo() string { return "" }
type BootstrapNode struct {
// node
node *Node
Node
initialPingInterval time.Duration
checkInterval time.Duration
nlock *sync.RWMutex
nodes []peer
nodeKeys map[Bitmap]int
}
// New returns a BootstrapNode pointer.
func NewBootstrapNode(id Bitmap, initialPingInterval, rePingInterval time.Duration) *BootstrapNode {
b := &BootstrapNode{
Node: *NewNode(id),
initialPingInterval: initialPingInterval,
checkInterval: rePingInterval,
nlock: &sync.RWMutex{},
nodes: make([]peer, 0),
nodeKeys: make(map[Bitmap]int),
}
b.rt = &nullRoutingTable{}
b.store = &nullStore{}
b.requestHandler = b.handleRequest
return b
}
// Add manually adds a contact
func (b *BootstrapNode) Add(c Contact) {
b.upsert(c)
}
// Connect connects to the given connection and starts any background threads necessary
func (b *BootstrapNode) Connect(conn UDPConn) error {
err := b.Node.Connect(conn)
if err != nil {
return err
}
log.Debugf("[%s] bootstrap: node connected", b.id.HexShort())
go func() {
t := time.NewTicker(b.checkInterval / 5)
for {
select {
case <-t.C:
b.check()
case <-b.stop.Chan():
return
}
}
}()
return nil
}
// ypsert adds the contact to the list, or updates the lastPinged time
func (b *BootstrapNode) upsert(c Contact) {
b.nlock.Lock()
defer b.nlock.Unlock()
if i, exists := b.nodeKeys[c.id]; exists {
log.Debugf("[%s] bootstrap: touching contact %s", b.id.HexShort(), b.nodes[i].contact.id.HexShort())
b.nodes[i].Touch()
return
}
log.Debugf("[%s] bootstrap: adding new contact %s", b.id.HexShort(), c.id.HexShort())
b.nodeKeys[c.id] = len(b.nodes)
b.nodes = append(b.nodes, peer{c, time.Now(), 0})
}
// remove removes the contact from the list
func (b *BootstrapNode) remove(c Contact) {
b.nlock.Lock()
defer b.nlock.Unlock()
i, exists := b.nodeKeys[c.id]
if !exists {
return
}
log.Debugf("[%s] bootstrap: removing contact %s", b.id.HexShort(), c.id.HexShort())
b.nodes = append(b.nodes[:i], b.nodes[i+1:]...)
delete(b.nodeKeys, c.id)
}
// get returns up to `limit` random contacts from the list
func (b *BootstrapNode) get(limit int) []Contact {
b.nlock.RLock()
defer b.nlock.RUnlock()
if len(b.nodes) < limit {
limit = len(b.nodes)
}
ret := make([]Contact, limit)
for i, k := range randKeys(len(b.nodes))[:limit] {
ret[i] = b.nodes[k].contact
}
return ret
}
// ping pings a node. if the node responds, it is added to the list. otherwise, it is removed
func (b *BootstrapNode) ping(c Contact) {
b.stopWG.Add(1)
defer b.stopWG.Done()
ctx, cancel := context.WithCancel(context.Background())
resCh := b.SendAsync(ctx, c, Request{Method: pingMethod})
var res *Response
select {
case res = <-resCh:
case <-b.stop.Chan():
cancel()
return
}
if res != nil && res.Data == pingSuccessResponse {
b.upsert(c)
} else {
b.remove(c)
}
}
func (b *BootstrapNode) check() {
b.nlock.RLock()
defer b.nlock.RUnlock()
for i := range b.nodes {
if !b.nodes[i].ActiveInLast(b.checkInterval) {
go b.ping(b.nodes[i].contact)
}
}
}
// handleRequest handles the requests received from udp.
func (b *BootstrapNode) handleRequest(addr *net.UDPAddr, request Request) {
switch request.Method {
case pingMethod:
b.sendMessage(addr, Response{ID: request.ID, NodeID: b.id, Data: pingSuccessResponse})
case findNodeMethod:
if request.Arg == nil {
log.Errorln("request is missing arg")
return
}
b.sendMessage(addr, Response{
ID: request.ID,
NodeID: b.id,
Contacts: b.get(bucketSize),
})
}
go func() {
log.Debugf("[%s] bootstrap: queuing %s to ping", b.id.HexShort(), request.NodeID.HexShort())
<-time.After(b.initialPingInterval)
b.ping(Contact{id: request.NodeID, ip: addr.IP, port: addr.Port})
}()
}
func randKeys(max int) []int {
keys := make([]int, max)
for k := range keys {
keys[k] = k
}
rand.Shuffle(max, func(i, j int) {
keys[i], keys[j] = keys[j], keys[i]
})
return keys
}

20
dht/bootstrap_test.go Normal file
View file

@ -0,0 +1,20 @@
package dht
import (
"net"
"testing"
)
func TestBootstrapPing(t *testing.T) {
b := NewBootstrapNode(RandomBitmapP(), 10, bootstrapDefaultRefreshDuration)
listener, err := net.ListenPacket(network, "127.0.0.1:54320")
if err != nil {
panic(err)
}
b.Connect(listener.(*net.UDPConn))
defer b.Shutdown()
b.Shutdown()
}

View file

@ -34,6 +34,8 @@ const (
udpTimeout = 5 * time.Second
udpMaxMessageLength = 1024 // bytes. I think our longest message is ~676 bytes, so I rounded up
maxPeerFails = 3 // after this many failures, a peer is considered bad and will be removed from the routing table
tExpire = 24 * time.Hour // the time after which a key/value pair expires; this is a time-to-live (TTL) from the original publication date
tRefresh = 1 * time.Hour // the time after which an otherwise unaccessed bucket must be refreshed
tReplicate = 1 * time.Hour // the interval between Kademlia replication events, when a node is required to publish its entire database
@ -97,15 +99,10 @@ func New(config *Config) (*DHT, error) {
return nil, err
}
node, err := NewNode(contact.id)
if err != nil {
return nil, err
}
d := &DHT{
conf: config,
contact: contact,
node: node,
node: NewNode(contact.id),
stop: stopOnce.New(),
stopWG: &sync.WaitGroup{},
joined: make(chan struct{}),
@ -136,7 +133,8 @@ func (dht *DHT) join() {
}
// now call iterativeFind on yourself
_, err := dht.Get(dht.node.id)
nf := newContactFinder(dht.node, dht.node.id, false)
_, err := nf.Find()
if err != nil {
log.Errorf("[%s] join: %s", dht.node.id.HexShort(), err.Error())
}
@ -227,15 +225,18 @@ func (dht *DHT) storeOnNode(hash Bitmap, node Contact) {
dht.stopWG.Add(1)
defer dht.stopWG.Done()
resCh := dht.node.SendAsync(context.Background(), node, Request{
ctx, cancel := context.WithCancel(context.Background())
resCh := dht.node.SendAsync(ctx, node, Request{
Method: findValueMethod,
Arg: &hash,
})
var res *Response
select {
case res = <-resCh:
case <-dht.stop.Chan():
cancel()
return
}
@ -243,7 +244,8 @@ func (dht *DHT) storeOnNode(hash Bitmap, node Contact) {
return // request timed out
}
dht.node.SendAsync(context.Background(), node, Request{
ctx, cancel = context.WithCancel(context.Background())
resCh = dht.node.SendAsync(ctx, node, Request{
Method: storeMethod,
StoreArgs: &storeArgs{
BlobHash: hash,
@ -254,6 +256,14 @@ func (dht *DHT) storeOnNode(hash Bitmap, node Contact) {
},
},
})
go func() {
select {
case <-resCh:
case <-dht.stop.Chan():
cancel()
}
}()
}
func (dht *DHT) PrintState() {

View file

@ -13,11 +13,12 @@ import (
// TODO: make a dht with X nodes, have them all join, then ensure that every node appears at least once in another node's routing table
func TestNodeFinder_FindNodes(t *testing.T) {
dhts := TestingCreateDHT(3)
bs, dhts := TestingCreateDHT(3)
defer func() {
for i := range dhts {
dhts[i].Shutdown()
}
bs.Shutdown()
}()
nf := newContactFinder(dhts[2].node, RandomBitmapP(), false)
@ -31,38 +32,61 @@ func TestNodeFinder_FindNodes(t *testing.T) {
t.Fatal("something was found, but it should not have been")
}
if len(foundNodes) != 1 {
t.Errorf("expected 1 node, found %d", len(foundNodes))
if len(foundNodes) != 3 {
t.Errorf("expected 3 node, found %d", len(foundNodes))
}
foundBootstrap := false
foundOne := false
//foundTwo := false
foundTwo := false
for _, n := range foundNodes {
if n.id.Equals(bs.id) {
foundBootstrap = true
}
if n.id.Equals(dhts[0].node.id) {
foundOne = true
}
//if n.id.Equals(dhts[1].node.c.id) {
// foundTwo = true
//}
if n.id.Equals(dhts[1].node.id) {
foundTwo = true
}
}
if !foundBootstrap {
t.Errorf("did not find bootstrap node %s", bs.id.Hex())
}
if !foundOne {
t.Errorf("did not find first node %s", dhts[0].node.id.Hex())
}
//if !foundTwo {
// t.Errorf("did not find second node %s", dhts[1].node.c.id.Hex())
//}
if !foundTwo {
t.Errorf("did not find second node %s", dhts[1].node.id.Hex())
}
}
func TestNodeFinder_FindValue(t *testing.T) {
dhts := TestingCreateDHT(3)
func TestNodeFinder_FindNodes_NoBootstrap(t *testing.T) {
dhts := TestingCreateDHTNoBootstrap(3, nil)
defer func() {
for i := range dhts {
dhts[i].Shutdown()
}
}()
nf := newContactFinder(dhts[2].node, RandomBitmapP(), false)
_, err := nf.Find()
if err == nil {
t.Fatal("contact finder should have errored saying that there are no contacts in the routing table")
}
}
func TestNodeFinder_FindValue(t *testing.T) {
bs, dhts := TestingCreateDHT(3)
defer func() {
for i := range dhts {
dhts[i].Shutdown()
}
bs.Shutdown()
}()
blobHashToFind := RandomBitmapP()
nodeToFind := Contact{id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
dhts[0].node.store.Upsert(blobHashToFind, nodeToFind)
@ -91,11 +115,12 @@ func TestDHT_LargeDHT(t *testing.T) {
rand.Seed(time.Now().UnixNano())
log.Println("if this takes longer than 20 seconds, its stuck. idk why it gets stuck sometimes, but its a bug.")
nodes := 100
dhts := TestingCreateDHT(nodes)
bs, dhts := TestingCreateDHT(nodes)
defer func() {
for _, d := range dhts {
go d.Shutdown()
}
bs.Shutdown()
time.Sleep(1 * time.Second)
}()
@ -115,5 +140,5 @@ func TestDHT_LargeDHT(t *testing.T) {
}
wg.Wait()
dhts[1].PrintState()
dhts[len(dhts)-1].PrintState()
}

View file

@ -24,6 +24,7 @@ type packet struct {
}
// UDPConn allows using a mocked connection to test sending/receiving data
// TODO: stop mocking this and use the real thing
type UDPConn interface {
ReadFromUDP([]byte) (int, *net.UDPAddr, error)
WriteToUDP([]byte, *net.UDPAddr) (int, error)
@ -32,6 +33,8 @@ type UDPConn interface {
Close() error
}
type RequestHandlerFunc func(addr *net.UDPAddr, request Request)
type Node struct {
// the node's id
id Bitmap
@ -45,20 +48,24 @@ type Node struct {
transactions map[messageID]*transaction
// routing table
rt *routingTable
rt RoutingTable
// data store
store *peerStore
store Store
// overrides for request handlers
requestHandler RequestHandlerFunc
// stop the node neatly and clean up after itself
stop *stopOnce.Stopper
stopWG *sync.WaitGroup
}
// New returns a Node pointer.
func NewNode(id Bitmap) (*Node, error) {
n := &Node{
func NewNode(id Bitmap) *Node {
return &Node{
id: id,
rt: newRoutingTable(id),
store: newPeerStore(),
store: newStore(),
txLock: &sync.RWMutex{},
transactions: make(map[messageID]*transaction),
@ -67,11 +74,9 @@ func NewNode(id Bitmap) (*Node, error) {
stopWG: &sync.WaitGroup{},
tokens: &tokenManager{},
}
n.tokens.Start(tokenSecretRotationInterval)
return n, nil
}
// Connect connects to the given connection and starts any background threads necessary
func (n *Node) Connect(conn UDPConn) error {
n.conn = conn
@ -89,6 +94,8 @@ func (n *Node) Connect(conn UDPConn) error {
// }()
//}
n.tokens.Start(tokenSecretRotationInterval)
packets := make(chan packet)
go func() {
@ -139,6 +146,8 @@ func (n *Node) Connect(conn UDPConn) error {
}
}()
n.startRoutingTableGrooming()
return nil
}
@ -161,10 +170,9 @@ func (n *Node) handlePacket(pkt packet) {
return
}
// TODO: test this stuff more thoroughly
// the following is a bit of a hack, but it lets us avoid decoding every message twice
// it depends on the data being a dict with 0 as the first key (so it starts with "d1:0i") and the message type as the first value
// TODO: test this more thoroughly
switch pkt.data[5] {
case '0' + requestType:
@ -210,9 +218,15 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) {
return
}
// if a handler is overridden, call it instead
if n.requestHandler != nil {
n.requestHandler(addr, request)
return
}
switch request.Method {
default:
// n.send(addr, makeError(t, protocolError, "invalid q"))
//n.sendMessage(addr, Error{ID: request.ID, NodeID: n.id, ExceptionType: "invalid-request-method"})
log.Errorln("invalid request method")
return
case pingMethod:
@ -263,7 +277,7 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) {
// the routing table must only contain "good" nodes, which are nodes that reply to our requests
// if a node is already good (aka in the table), its fine to refresh it
// http://www.bittorrent.org/beps/bep_0005.html#routing-table
n.rt.UpdateIfExists(Contact{id: request.NodeID, ip: addr.IP, port: addr.Port})
n.rt.Fresh(Contact{id: request.NodeID, ip: addr.IP, port: addr.Port})
}
// handleResponse handles responses received from udp.
@ -279,7 +293,7 @@ func (n *Node) handleResponse(addr *net.UDPAddr, response Response) {
// handleError handles errors received from udp.
func (n *Node) handleError(addr *net.UDPAddr, e Error) {
spew.Dump(e)
n.rt.UpdateIfExists(Contact{id: e.NodeID, ip: addr.IP, port: addr.Port})
n.rt.Fresh(Contact{id: e.NodeID, ip: addr.IP, port: addr.Port})
}
// send sends data to a udp address
@ -383,8 +397,8 @@ func (n *Node) SendAsync(ctx context.Context, contact Contact, req Request) <-ch
}
}
// if request timed out each time
n.rt.Remove(tx.contact.id)
// notify routing table about a failure to respond
n.rt.Fail(tx.contact)
}()
return ch
@ -402,3 +416,19 @@ func (n *Node) CountActiveTransactions() int {
defer n.txLock.Unlock()
return len(n.transactions)
}
func (n *Node) startRoutingTableGrooming() {
n.stopWG.Add(1)
go func() {
defer n.stopWG.Done()
refreshTicker := time.NewTicker(tRefresh / 5) // how often to check for buckets that need to be refreshed
for {
select {
case <-refreshTicker.C:
RoutingTableRefresh(n, tRefresh, n.stop.Chan())
case <-n.stop.Chan():
return
}
}
}()
}

View file

@ -17,7 +17,8 @@ type contactFinder struct {
target Bitmap
node *Node
done *stopOnce.Stopper
done *stopOnce.Stopper
doneWG *sync.WaitGroup
findValueMutex *sync.Mutex
findValueResult []Contact
@ -48,10 +49,16 @@ func newContactFinder(node *Node, target Bitmap, findValue bool) *contactFinder
shortlistMutex: &sync.Mutex{},
shortlistAdded: make(map[Bitmap]bool),
done: stopOnce.New(),
doneWG: &sync.WaitGroup{},
outstandingRequestsMutex: &sync.RWMutex{},
}
}
func (cf *contactFinder) Cancel() {
cf.done.Stop()
cf.doneWG.Wait()
}
func (cf *contactFinder) Find() (findNodeResponse, error) {
if cf.findValue {
log.Debugf("[%s] starting an iterative Find for the value %s", cf.node.id.HexShort(), cf.target.HexShort())
@ -63,17 +70,15 @@ func (cf *contactFinder) Find() (findNodeResponse, error) {
return findNodeResponse{}, errors.Err("no contacts in routing table")
}
wg := &sync.WaitGroup{}
for i := 0; i < alpha; i++ {
wg.Add(1)
cf.doneWG.Add(1)
go func(i int) {
defer wg.Done()
defer cf.doneWG.Done()
cf.iterationWorker(i + 1)
}(i)
}
wg.Wait()
cf.doneWG.Wait()
// TODO: what to do if we have less than K active contacts, shortlist is empty, but we
// TODO: have other contacts in our routing table whom we have not contacted. prolly contact them
@ -133,7 +138,7 @@ func (cf *contactFinder) iterationWorker(num int) {
if res == nil {
// nothing to do, response timed out
log.Debugf("[%s] worker %d: timed out waiting for %s", cf.node.id.HexShort(), num, contact.id.HexShort())
log.Debugf("[%s] worker %d: search canceled or timed out waiting for %s", cf.node.id.HexShort(), num, contact.id.HexShort())
} else if cf.findValue && res.FindValueKey != "" {
log.Debugf("[%s] worker %d: got value", cf.node.id.HexShort(), num)
cf.findValueMutex.Lock()

View file

@ -2,93 +2,12 @@ package dht
import (
"net"
"strconv"
"strings"
"testing"
"time"
"github.com/lbryio/errors.go"
"github.com/lyoshenka/bencode"
)
type timeoutErr struct {
error
}
func (t timeoutErr) Timeout() bool {
return true
}
func (t timeoutErr) Temporary() bool {
return true
}
// TODO: just use a normal net.Conn instead of this mock conn
type testUDPPacket struct {
data []byte
addr *net.UDPAddr
}
type testUDPConn struct {
addr *net.UDPAddr
toRead chan testUDPPacket
writes chan testUDPPacket
readDeadline time.Time
}
func newTestUDPConn(addr string) *testUDPConn {
parts := strings.Split(addr, ":")
if len(parts) != 2 {
panic("addr needs ip and port")
}
port, err := strconv.Atoi(parts[1])
if err != nil {
panic(err)
}
return &testUDPConn{
addr: &net.UDPAddr{IP: net.IP(parts[0]), Port: port},
toRead: make(chan testUDPPacket),
writes: make(chan testUDPPacket),
}
}
func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
var timeoutCh <-chan time.Time
if !t.readDeadline.IsZero() {
timeoutCh = time.After(t.readDeadline.Sub(time.Now()))
}
select {
case packet := <-t.toRead:
n := copy(b, packet.data)
return n, packet.addr, nil
case <-timeoutCh:
return 0, nil, timeoutErr{errors.Err("timeout")}
}
}
func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
t.writes <- testUDPPacket{data: b, addr: addr}
return len(b), nil
}
func (t *testUDPConn) SetReadDeadline(tm time.Time) error {
t.readDeadline = tm
return nil
}
func (t *testUDPConn) SetWriteDeadline(tm time.Time) error {
return nil
}
func (t *testUDPConn) Close() error {
t.toRead = nil
t.writes = nil
return nil
}
func TestPing(t *testing.T) {
dhtNodeID := RandomBitmapP()
testNodeID := RandomBitmapP()
@ -271,7 +190,7 @@ func TestStore(t *testing.T) {
}
}
if len(dht.node.store.hashes) != 1 {
if dht.node.store.CountStoredHashes() != 1 {
t.Error("dht store has wrong number of items")
}
@ -517,164 +436,3 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
verifyContacts(t, contacts, nodes)
}
func verifyResponse(t *testing.T, resp map[string]interface{}, id messageID, dhtNodeID string) {
if len(resp) != 4 {
t.Errorf("expected 4 response fields, got %d", len(resp))
}
_, ok := resp[headerTypeField]
if !ok {
t.Error("missing type field")
} else {
rType, ok := resp[headerTypeField].(int64)
if !ok {
t.Error("type is not an integer")
} else if rType != responseType {
t.Error("unexpected response type")
}
}
_, ok = resp[headerMessageIDField]
if !ok {
t.Error("missing message id field")
} else {
rMessageID, ok := resp[headerMessageIDField].(string)
if !ok {
t.Error("message ID is not a string")
} else if rMessageID != string(id[:]) {
t.Error("unexpected message ID")
}
if len(rMessageID) != messageIDLength {
t.Errorf("message ID should be %d chars long", messageIDLength)
}
}
_, ok = resp[headerNodeIDField]
if !ok {
t.Error("missing node id field")
} else {
rNodeID, ok := resp[headerNodeIDField].(string)
if !ok {
t.Error("node ID is not a string")
} else if rNodeID != dhtNodeID {
t.Error("unexpected node ID")
}
if len(rNodeID) != nodeIDLength {
t.Errorf("node ID should be %d chars long", nodeIDLength)
}
}
}
func verifyContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
if len(contacts) != len(nodes) {
t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
return
}
foundNodes := make(map[string]bool)
for _, c := range contacts {
contact, ok := c.([]interface{})
if !ok {
t.Error("contact is not a list")
return
}
if len(contact) != 3 {
t.Error("contact must be 3 items")
return
}
var currNode Contact
currNodeFound := false
id, ok := contact[0].(string)
if !ok {
t.Error("contact id is not a string")
} else {
if _, ok := foundNodes[id]; ok {
t.Errorf("contact %s appears multiple times", id)
continue
}
for _, n := range nodes {
if n.id.RawString() == id {
currNode = n
currNodeFound = true
foundNodes[id] = true
break
}
}
if !currNodeFound {
t.Errorf("unexpected contact %s", id)
continue
}
}
ip, ok := contact[1].(string)
if !ok {
t.Error("contact IP is not a string")
} else if !currNode.ip.Equal(net.ParseIP(ip)) {
t.Errorf("contact IP mismatch. got %s; expected %s", ip, currNode.ip.String())
}
port, ok := contact[2].(int64)
if !ok {
t.Error("contact port is not an int")
} else if int(port) != currNode.port {
t.Errorf("contact port mismatch. got %d; expected %d", port, currNode.port)
}
}
}
func verifyCompactContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
if len(contacts) != len(nodes) {
t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
return
}
foundNodes := make(map[string]bool)
for _, c := range contacts {
compact, ok := c.(string)
if !ok {
t.Error("contact is not a string")
return
}
contact := Contact{}
err := contact.UnmarshalCompact([]byte(compact))
if err != nil {
t.Error(err)
return
}
var currNode Contact
currNodeFound := false
if _, ok := foundNodes[contact.id.Hex()]; ok {
t.Errorf("contact %s appears multiple times", contact.id.Hex())
continue
}
for _, n := range nodes {
if n.id.Equals(contact.id) {
currNode = n
currNodeFound = true
foundNodes[contact.id.Hex()] = true
break
}
}
if !currNodeFound {
t.Errorf("unexpected contact %s", contact.id.Hex())
continue
}
if !currNode.ip.Equal(contact.ip) {
t.Errorf("contact IP mismatch. got %s; expected %s", contact.ip.String(), currNode.ip.String())
}
if contact.port != currNode.port {
t.Errorf("contact port mismatch. got %d; expected %d", contact.port, currNode.port)
}
}
}

View file

@ -8,6 +8,7 @@ import (
"sort"
"strings"
"sync"
"time"
"github.com/lbryio/errors.go"
@ -110,31 +111,175 @@ func (a byXorDistance) Less(i, j int) bool {
return a[i].xorDistanceToTarget.Less(a[j].xorDistanceToTarget)
}
type routingTable struct {
id Bitmap
buckets [numBuckets]*list.List
lock *sync.RWMutex
// peer is a contact with extra freshness information
type peer struct {
contact Contact
lastActivity time.Time
numFailures int
//<lastPublished>,
//<originallyPublished>
// <originalPublisherID>
}
func newRoutingTable(id Bitmap) *routingTable {
var rt routingTable
for i := range rt.buckets {
rt.buckets[i] = list.New()
func (p *peer) Touch() {
p.lastActivity = time.Now()
p.numFailures = 0
}
// ActiveSince returns whether a peer has responded in the last `d` duration
// this is used to check if the peer is "good", meaning that we believe the peer will respond to our requests
func (p *peer) ActiveInLast(d time.Duration) bool {
return time.Now().Sub(p.lastActivity) > d
}
// IsBad returns whether a peer is "bad", meaning that it has failed to respond to multiple pings in a row
func (p *peer) IsBad(maxFalures int) bool {
return p.numFailures >= maxFalures
}
// Fail marks a peer as having failed to respond. It returns whether or not the peer should be removed from the routing table
func (p *peer) Fail() {
p.numFailures++
}
// toPeer converts a generic *list.Element into a *peer
// this (along with newPeer) keeps all conversions between *list.Element and peer in one place
func toPeer(el *list.Element) *peer {
return el.Value.(*peer)
}
// newPeer creates a new peer from a contact
// this (along with toPeer) keeps all conversions between *list.Element and peer in one place
func newPeer(c Contact) peer {
return peer{
contact: c,
}
}
type bucket struct {
lock *sync.RWMutex
peers *list.List
lastUpdate time.Time
}
// Len returns the number of peers in the bucket
func (b bucket) Len() int {
b.lock.RLock()
defer b.lock.RUnlock()
return b.peers.Len()
}
// Contacts returns a slice of the bucket's contacts
func (b bucket) Contacts() []Contact {
b.lock.RLock()
defer b.lock.RUnlock()
contacts := make([]Contact, b.peers.Len())
for i, curr := 0, b.peers.Front(); curr != nil; i, curr = i+1, curr.Next() {
contacts[i] = toPeer(curr).contact
}
return contacts
}
// UpdateContact marks a contact as having been successfully contacted. if insertIfNew and the contact is does not exist yet, it is inserted
func (b *bucket) UpdateContact(c Contact, insertIfNew bool) {
b.lock.Lock()
defer b.lock.Unlock()
element := find(c.id, b.peers)
if element != nil {
b.lastUpdate = time.Now()
toPeer(element).Touch()
b.peers.MoveToBack(element)
} else if insertIfNew {
hasRoom := true
if b.peers.Len() >= bucketSize {
hasRoom = false
for curr := b.peers.Front(); curr != nil; curr = curr.Next() {
if toPeer(curr).IsBad(maxPeerFails) {
// TODO: Ping contact first. Only remove if it does not respond
b.peers.Remove(curr)
hasRoom = true
break
}
}
}
if hasRoom {
b.lastUpdate = time.Now()
peer := newPeer(c)
peer.Touch()
b.peers.PushBack(&peer)
}
}
}
// FailContact marks a contact as having failed, and removes it if it failed too many times
func (b *bucket) FailContact(id Bitmap) {
b.lock.Lock()
defer b.lock.Unlock()
element := find(id, b.peers)
if element != nil {
// BEP5 says not to remove the contact until the bucket is full and you try to insert
toPeer(element).Fail()
}
}
// find returns the contact in the bucket, or nil if the bucket does not contain the contact
func find(id Bitmap, peers *list.List) *list.Element {
for curr := peers.Front(); curr != nil; curr = curr.Next() {
if toPeer(curr).contact.id.Equals(id) {
return curr
}
}
return nil
}
// NeedsRefresh returns true if bucket has not been updated in the last `refreshInterval`, false otherwise
func (b *bucket) NeedsRefresh(refreshInterval time.Duration) bool {
b.lock.RLock()
defer b.lock.RUnlock()
return time.Now().Sub(b.lastUpdate) > refreshInterval
}
type RoutingTable interface {
Update(Contact)
Fresh(Contact)
Fail(Contact)
GetClosest(Bitmap, int) []Contact
Count() int
GetIDsForRefresh(time.Duration) []Bitmap
BucketInfo() string // for debugging
}
type routingTableImpl struct {
id Bitmap
buckets [numBuckets]bucket
}
func newRoutingTable(id Bitmap) *routingTableImpl {
var rt routingTableImpl
rt.id = id
rt.lock = &sync.RWMutex{}
for i := range rt.buckets {
rt.buckets[i] = bucket{
peers: list.New(),
lock: &sync.RWMutex{},
}
}
return &rt
}
func (rt *routingTable) BucketInfo() string {
rt.lock.RLock()
defer rt.lock.RUnlock()
func (rt *routingTableImpl) BucketInfo() string {
var bucketInfo []string
for i, b := range rt.buckets {
contents := bucketContents(b)
if contents != "" {
bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: %s", i, contents))
if b.Len() > 0 {
contacts := b.Contacts()
s := make([]string, len(contacts))
for j, c := range contacts {
s[j] = c.id.HexShort()
}
bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: (%d) %s", i, len(contacts), strings.Join(s, ", ")))
}
}
if len(bucketInfo) == 0 {
@ -143,89 +288,41 @@ func (rt *routingTable) BucketInfo() string {
return strings.Join(bucketInfo, "\n")
}
func bucketContents(b *list.List) string {
count := 0
ids := ""
for curr := b.Front(); curr != nil; curr = curr.Next() {
count++
if ids != "" {
ids += ", "
}
ids += curr.Value.(Contact).id.HexShort()
}
if count > 0 {
return fmt.Sprintf("(%d) %s", count, ids)
} else {
return ""
}
}
// Update inserts or refreshes a contact
func (rt *routingTable) Update(c Contact) {
rt.lock.Lock()
defer rt.lock.Unlock()
bucketNum := rt.bucketFor(c.id)
bucket := rt.buckets[bucketNum]
element := findInList(bucket, c.id)
if element == nil {
if bucket.Len() >= bucketSize {
// TODO: Ping front contact first. Only remove if it does not respond
bucket.Remove(bucket.Front())
}
bucket.PushBack(c)
} else {
bucket.MoveToBack(element)
}
func (rt *routingTableImpl) Update(c Contact) {
rt.bucketFor(c.id).UpdateContact(c, true)
}
// UpdateIfExists refreshes a contact if its already in the routing table
func (rt *routingTable) UpdateIfExists(c Contact) {
rt.lock.Lock()
defer rt.lock.Unlock()
bucketNum := rt.bucketFor(c.id)
bucket := rt.buckets[bucketNum]
element := findInList(bucket, c.id)
if element != nil {
bucket.MoveToBack(element)
}
// Fresh refreshes a contact if its already in the routing table
func (rt *routingTableImpl) Fresh(c Contact) {
rt.bucketFor(c.id).UpdateContact(c, false)
}
func (rt *routingTable) Remove(id Bitmap) {
rt.lock.Lock()
defer rt.lock.Unlock()
bucketNum := rt.bucketFor(id)
bucket := rt.buckets[bucketNum]
element := findInList(bucket, rt.id)
if element != nil {
bucket.Remove(element)
}
// FailContact marks a contact as having failed, and removes it if it failed too many times
func (rt *routingTableImpl) Fail(c Contact) {
rt.bucketFor(c.id).FailContact(c.id)
}
func (rt *routingTable) GetClosest(target Bitmap, limit int) []Contact {
rt.lock.RLock()
defer rt.lock.RUnlock()
// GetClosest returns the closest `limit` contacts from the routing table
// It marks each bucket it accesses as having been accessed
func (rt *routingTableImpl) GetClosest(target Bitmap, limit int) []Contact {
var toSort []sortedContact
var bucketNum int
if rt.id.Equals(target) {
bucketNum = 0
} else {
bucketNum = rt.bucketFor(target)
bucketNum = rt.bucketNumFor(target)
}
bucket := rt.buckets[bucketNum]
toSort = appendContacts(toSort, bucket.Front(), target)
toSort = appendContacts(toSort, rt.buckets[bucketNum], target)
for i := 1; (bucketNum-i >= 0 || bucketNum+i < numBuckets) && len(toSort) < limit; i++ {
if bucketNum-i >= 0 {
bucket = rt.buckets[bucketNum-i]
toSort = appendContacts(toSort, bucket.Front(), target)
toSort = appendContacts(toSort, rt.buckets[bucketNum-i], target)
}
if bucketNum+i < numBuckets {
bucket = rt.buckets[bucketNum+i]
toSort = appendContacts(toSort, bucket.Front(), target)
toSort = appendContacts(toSort, rt.buckets[bucketNum+i], target)
}
}
@ -242,43 +339,75 @@ func (rt *routingTable) GetClosest(target Bitmap, limit int) []Contact {
return contacts
}
func appendContacts(contacts []sortedContact, start *list.Element, target Bitmap) []sortedContact {
for curr := start; curr != nil; curr = curr.Next() {
c := toContact(curr)
contacts = append(contacts, sortedContact{c, c.id.Xor(target)})
func appendContacts(contacts []sortedContact, b bucket, target Bitmap) []sortedContact {
for _, contact := range b.Contacts() {
contacts = append(contacts, sortedContact{contact, contact.id.Xor(target)})
}
return contacts
}
// Count returns the number of contacts in the routing table
func (rt *routingTable) Count() int {
rt.lock.RLock()
defer rt.lock.RUnlock()
func (rt *routingTableImpl) Count() int {
count := 0
for _, bucket := range rt.buckets {
for curr := bucket.Front(); curr != nil; curr = curr.Next() {
count++
}
count = bucket.Len()
}
return count
}
func (rt *routingTable) bucketFor(target Bitmap) int {
func (rt *routingTableImpl) bucketNumFor(target Bitmap) int {
if rt.id.Equals(target) {
panic("routing table does not have a bucket for its own id")
}
return numBuckets - 1 - target.Xor(rt.id).PrefixLen()
}
func findInList(bucket *list.List, value Bitmap) *list.Element {
for curr := bucket.Front(); curr != nil; curr = curr.Next() {
if toContact(curr).id.Equals(value) {
return curr
}
}
return nil
func (rt *routingTableImpl) bucketFor(target Bitmap) *bucket {
return &rt.buckets[rt.bucketNumFor(target)]
}
func toContact(el *list.Element) Contact {
return el.Value.(Contact)
func (rt *routingTableImpl) GetIDsForRefresh(refreshInterval time.Duration) []Bitmap {
var bitmaps []Bitmap
for i, bucket := range rt.buckets {
if bucket.NeedsRefresh(refreshInterval) {
bitmaps = append(bitmaps, RandomBitmapP().ZeroPrefix(i))
}
}
return bitmaps
}
// RoutingTableRefresh refreshes any buckets that need to be refreshed
// It returns a channel that will be closed when the refresh is done
func RoutingTableRefresh(n *Node, refreshInterval time.Duration, cancel <-chan struct{}) <-chan struct{} {
done := make(chan struct{})
var wg sync.WaitGroup
for _, id := range n.rt.GetIDsForRefresh(refreshInterval) {
wg.Add(1)
go func(id Bitmap) {
defer wg.Done()
nf := newContactFinder(n, id, false)
if cancel != nil {
go func() {
select {
case <-cancel:
nf.Cancel()
case <-done:
}
}()
}
nf.Find()
}(id)
}
go func() {
wg.Wait()
close(done)
}()
return done
}

View file

@ -7,27 +7,26 @@ import (
)
func TestRoutingTable_bucketFor(t *testing.T) {
target := BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
rt := newRoutingTable(BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"))
var tests = []struct {
id Bitmap
target Bitmap
expected int
}{
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"), target, 0},
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"), target, 1},
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"), target, 1},
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004"), target, 2},
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005"), target, 2},
{BitmapFromHexP("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000f"), target, 3},
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010"), target, 4},
{BitmapFromHexP("F00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), target, 383},
{BitmapFromHexP("F0000000000000000000000000000000F0000000000000000000000000F0000000000000000000000000000000000000"), target, 383},
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"), 0},
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"), 1},
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"), 1},
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004"), 2},
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005"), 2},
{BitmapFromHexP("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000f"), 3},
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010"), 4},
{BitmapFromHexP("F00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 383},
{BitmapFromHexP("F0000000000000000000000000000000F0000000000000000000000000F0000000000000000000000000000000000000"), 383},
}
for _, tt := range tests {
bucket := bucketFor(tt.id, tt.target)
bucket := rt.bucketNumFor(tt.id)
if bucket != tt.expected {
t.Errorf("bucketFor(%s, %s) => %d, want %d", tt.id.Hex(), tt.target.Hex(), bucket, tt.expected)
t.Errorf("bucketFor(%s, %s) => %d, want %d", tt.id.Hex(), rt.id.Hex(), bucket, tt.expected)
}
}
}
@ -83,3 +82,7 @@ func TestCompactEncoding(t *testing.T) {
t.Errorf("compact bytes not encoded correctly")
}
}
func TestRoutingTableRefresh(t *testing.T) {
t.Skip("TODO: test routing table refreshing")
}

View file

@ -2,7 +2,13 @@ package dht
import "sync"
type peerStore struct {
type Store interface {
Upsert(Bitmap, Contact)
Get(Bitmap) []Contact
CountStoredHashes() int
}
type storeImpl struct {
// map of blob hashes to (map of node IDs to bools)
hashes map[Bitmap]map[Bitmap]bool
// stores the peers themselves, so they can be updated in one place
@ -10,14 +16,14 @@ type peerStore struct {
lock sync.RWMutex
}
func newPeerStore() *peerStore {
return &peerStore{
func newStore() *storeImpl {
return &storeImpl{
hashes: make(map[Bitmap]map[Bitmap]bool),
contacts: make(map[Bitmap]Contact),
}
}
func (s *peerStore) Upsert(blobHash Bitmap, contact Contact) {
func (s *storeImpl) Upsert(blobHash Bitmap, contact Contact) {
s.lock.Lock()
defer s.lock.Unlock()
@ -28,7 +34,7 @@ func (s *peerStore) Upsert(blobHash Bitmap, contact Contact) {
s.contacts[contact.id] = contact
}
func (s *peerStore) Get(blobHash Bitmap) []Contact {
func (s *storeImpl) Get(blobHash Bitmap) []Contact {
s.lock.RLock()
defer s.lock.RUnlock()
@ -45,11 +51,11 @@ func (s *peerStore) Get(blobHash Bitmap) []Contact {
return contacts
}
func (s *peerStore) RemoveTODO(contact Contact) {
func (s *storeImpl) RemoveTODO(contact Contact) {
// TODO: remove peer from everywhere
}
func (s *peerStore) CountStoredHashes() int {
func (s *storeImpl) CountStoredHashes() int {
s.lock.RLock()
defer s.lock.RUnlock()
return len(s.hashes)

View file

@ -1,23 +1,40 @@
package dht
import "strconv"
import (
"net"
"strconv"
"strings"
"testing"
"time"
func TestingCreateDHT(numNodes int) []*DHT {
"github.com/lbryio/errors.go"
)
var testingDHTIP = "127.0.0.1"
var testingDHTFirstPort = 21000
func TestingCreateDHT(numNodes int) (*BootstrapNode, []*DHT) {
bootstrapAddress := testingDHTIP + ":" + strconv.Itoa(testingDHTFirstPort)
bootstrapNode := NewBootstrapNode(RandomBitmapP(), 0, bootstrapDefaultRefreshDuration)
listener, err := net.ListenPacket(network, bootstrapAddress)
if err != nil {
panic(err)
}
bootstrapNode.Connect(listener.(*net.UDPConn))
return bootstrapNode, TestingCreateDHTNoBootstrap(numNodes, []string{bootstrapAddress})
}
func TestingCreateDHTNoBootstrap(numNodes int, seeds []string) []*DHT {
if numNodes < 1 {
return nil
}
ip := "127.0.0.1"
firstPort := 21000
firstPort := testingDHTFirstPort + 1
dhts := make([]*DHT, numNodes)
for i := 0; i < numNodes; i++ {
seeds := []string{}
if i > 0 {
seeds = []string{ip + ":" + strconv.Itoa(firstPort)}
}
dht, err := New(&Config{Address: ip + ":" + strconv.Itoa(firstPort+i), NodeID: RandomBitmapP().Hex(), SeedNodes: seeds})
dht, err := New(&Config{Address: testingDHTIP + ":" + strconv.Itoa(firstPort+i), NodeID: RandomBitmapP().Hex(), SeedNodes: seeds})
if err != nil {
panic(err)
}
@ -29,3 +46,242 @@ func TestingCreateDHT(numNodes int) []*DHT {
return dhts
}
type timeoutErr struct {
error
}
func (t timeoutErr) Timeout() bool {
return true
}
func (t timeoutErr) Temporary() bool {
return true
}
// TODO: just use a normal net.Conn instead of this mock conn
type testUDPPacket struct {
data []byte
addr *net.UDPAddr
}
type testUDPConn struct {
addr *net.UDPAddr
toRead chan testUDPPacket
writes chan testUDPPacket
readDeadline time.Time
}
func newTestUDPConn(addr string) *testUDPConn {
parts := strings.Split(addr, ":")
if len(parts) != 2 {
panic("addr needs ip and port")
}
port, err := strconv.Atoi(parts[1])
if err != nil {
panic(err)
}
return &testUDPConn{
addr: &net.UDPAddr{IP: net.IP(parts[0]), Port: port},
toRead: make(chan testUDPPacket),
writes: make(chan testUDPPacket),
}
}
func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
var timeoutCh <-chan time.Time
if !t.readDeadline.IsZero() {
timeoutCh = time.After(t.readDeadline.Sub(time.Now()))
}
select {
case packet := <-t.toRead:
n := copy(b, packet.data)
return n, packet.addr, nil
case <-timeoutCh:
return 0, nil, timeoutErr{errors.Err("timeout")}
}
}
func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
t.writes <- testUDPPacket{data: b, addr: addr}
return len(b), nil
}
func (t *testUDPConn) SetReadDeadline(tm time.Time) error {
t.readDeadline = tm
return nil
}
func (t *testUDPConn) SetWriteDeadline(tm time.Time) error {
return nil
}
func (t *testUDPConn) Close() error {
t.toRead = nil
t.writes = nil
return nil
}
func verifyResponse(t *testing.T, resp map[string]interface{}, id messageID, dhtNodeID string) {
if len(resp) != 4 {
t.Errorf("expected 4 response fields, got %d", len(resp))
}
_, ok := resp[headerTypeField]
if !ok {
t.Error("missing type field")
} else {
rType, ok := resp[headerTypeField].(int64)
if !ok {
t.Error("type is not an integer")
} else if rType != responseType {
t.Error("unexpected response type")
}
}
_, ok = resp[headerMessageIDField]
if !ok {
t.Error("missing message id field")
} else {
rMessageID, ok := resp[headerMessageIDField].(string)
if !ok {
t.Error("message ID is not a string")
} else if rMessageID != string(id[:]) {
t.Error("unexpected message ID")
}
if len(rMessageID) != messageIDLength {
t.Errorf("message ID should be %d chars long", messageIDLength)
}
}
_, ok = resp[headerNodeIDField]
if !ok {
t.Error("missing node id field")
} else {
rNodeID, ok := resp[headerNodeIDField].(string)
if !ok {
t.Error("node ID is not a string")
} else if rNodeID != dhtNodeID {
t.Error("unexpected node ID")
}
if len(rNodeID) != nodeIDLength {
t.Errorf("node ID should be %d chars long", nodeIDLength)
}
}
}
func verifyContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
if len(contacts) != len(nodes) {
t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
return
}
foundNodes := make(map[string]bool)
for _, c := range contacts {
contact, ok := c.([]interface{})
if !ok {
t.Error("contact is not a list")
return
}
if len(contact) != 3 {
t.Error("contact must be 3 items")
return
}
var currNode Contact
currNodeFound := false
id, ok := contact[0].(string)
if !ok {
t.Error("contact id is not a string")
} else {
if _, ok := foundNodes[id]; ok {
t.Errorf("contact %s appears multiple times", id)
continue
}
for _, n := range nodes {
if n.id.RawString() == id {
currNode = n
currNodeFound = true
foundNodes[id] = true
break
}
}
if !currNodeFound {
t.Errorf("unexpected contact %s", id)
continue
}
}
ip, ok := contact[1].(string)
if !ok {
t.Error("contact IP is not a string")
} else if !currNode.ip.Equal(net.ParseIP(ip)) {
t.Errorf("contact IP mismatch. got %s; expected %s", ip, currNode.ip.String())
}
port, ok := contact[2].(int64)
if !ok {
t.Error("contact port is not an int")
} else if int(port) != currNode.port {
t.Errorf("contact port mismatch. got %d; expected %d", port, currNode.port)
}
}
}
func verifyCompactContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
if len(contacts) != len(nodes) {
t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
return
}
foundNodes := make(map[string]bool)
for _, c := range contacts {
compact, ok := c.(string)
if !ok {
t.Error("contact is not a string")
return
}
contact := Contact{}
err := contact.UnmarshalCompact([]byte(compact))
if err != nil {
t.Error(err)
return
}
var currNode Contact
currNodeFound := false
if _, ok := foundNodes[contact.id.Hex()]; ok {
t.Errorf("contact %s appears multiple times", contact.id.Hex())
continue
}
for _, n := range nodes {
if n.id.Equals(contact.id) {
currNode = n
currNodeFound = true
foundNodes[contact.id.Hex()] = true
break
}
}
if !currNodeFound {
t.Errorf("unexpected contact %s", contact.id.Hex())
continue
}
if !currNode.ip.Equal(contact.ip) {
t.Errorf("contact IP mismatch. got %s; expected %s", contact.ip.String(), currNode.ip.String())
}
if contact.port != currNode.port {
t.Errorf("contact port mismatch. got %d; expected %d", contact.port, currNode.port)
}
}
}

View file

@ -79,7 +79,7 @@ func (s *S3BlobStore) Get(hash string) ([]byte, error) {
log.Debugf("Getting %s from S3", hash[:8])
defer func(t time.Time) {
log.Debugf("Getting %s took %s", hash[:8], time.Since(t).String())
log.Debugf("Getting %s from S3 took %s", hash[:8], time.Since(t).String())
}(time.Now())
buf := &aws.WriteAtBuffer{}