From 03a1e61d573f16737a19bd1d39bb4d22dc90acad Mon Sep 17 00:00:00 2001
From: Alex Grintsvayg <git@grin.io>
Date: Sun, 13 May 2018 21:17:29 -0400
Subject: [PATCH] added concurrent dht test

---
 dht/dht_test.go | 16 ++++++----------
 dht/node.go     | 11 +++++++----
 dht/testing.go  | 38 ++++++++++++++++++++++++--------------
 3 files changed, 37 insertions(+), 28 deletions(-)

diff --git a/dht/dht_test.go b/dht/dht_test.go
index 3b0a61d..437f1c8 100644
--- a/dht/dht_test.go
+++ b/dht/dht_test.go
@@ -1,19 +1,18 @@
 package dht
 
 import (
-	"math/rand"
 	"net"
 	"sync"
 	"testing"
 	"time"
 
-	log "github.com/sirupsen/logrus"
+	"github.com/lbryio/lbry.go/crypto"
 )
 
 // TODO: make a dht with X nodes, have them all join, then ensure that every node appears at least once in another node's routing table
 
 func TestNodeFinder_FindNodes(t *testing.T) {
-	bs, dhts := TestingCreateDHT(3)
+	bs, dhts := TestingCreateDHT(3, true, false)
 	defer func() {
 		for i := range dhts {
 			dhts[i].Shutdown()
@@ -64,7 +63,7 @@ func TestNodeFinder_FindNodes(t *testing.T) {
 }
 
 func TestNodeFinder_FindNodes_NoBootstrap(t *testing.T) {
-	dhts := TestingCreateDHTNoBootstrap(3, nil)
+	_, dhts := TestingCreateDHT(3, false, false)
 	defer func() {
 		for i := range dhts {
 			dhts[i].Shutdown()
@@ -79,7 +78,7 @@ func TestNodeFinder_FindNodes_NoBootstrap(t *testing.T) {
 }
 
 func TestNodeFinder_FindValue(t *testing.T) {
-	bs, dhts := TestingCreateDHT(3)
+	bs, dhts := TestingCreateDHT(3, true, false)
 	defer func() {
 		for i := range dhts {
 			dhts[i].Shutdown()
@@ -112,10 +111,8 @@ func TestNodeFinder_FindValue(t *testing.T) {
 }
 
 func TestDHT_LargeDHT(t *testing.T) {
-	rand.Seed(time.Now().UnixNano())
-	log.Println("if this takes longer than 20 seconds, its stuck. idk why it gets stuck sometimes, but its a bug.")
 	nodes := 100
-	bs, dhts := TestingCreateDHT(nodes)
+	bs, dhts := TestingCreateDHT(nodes, true, true)
 	defer func() {
 		for _, d := range dhts {
 			go d.Shutdown()
@@ -132,10 +129,9 @@ func TestDHT_LargeDHT(t *testing.T) {
 	}
 	for i := 0; i < numIDs; i++ {
 		go func(i int) {
-			r := rand.Intn(nodes)
 			wg.Add(1)
 			defer wg.Done()
-			dhts[r].Announce(ids[i])
+			dhts[int(crypto.RandInt64(int64(nodes)))].Announce(ids[i])
 		}(i)
 	}
 	wg.Wait()
diff --git a/dht/node.go b/dht/node.go
index 92b3e10..322d69f 100644
--- a/dht/node.go
+++ b/dht/node.go
@@ -112,7 +112,7 @@ func (n *Node) Connect(conn UDPConn) error {
 			}
 
 			n.conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) // need this to periodically check shutdown chan
-			n, raddr, err := n.conn.ReadFromUDP(buf)
+			bytesRead, raddr, err := n.conn.ReadFromUDP(buf)
 			if err != nil {
 				if e, ok := err.(net.Error); !ok || !e.Timeout() {
 					log.Errorf("udp read error: %v", err)
@@ -123,10 +123,13 @@ func (n *Node) Connect(conn UDPConn) error {
 				continue
 			}
 
-			data := make([]byte, n)
-			copy(data, buf[:n]) // slices use the same underlying array, so we need a new one for each packet
+			data := make([]byte, bytesRead)
+			copy(data, buf[:bytesRead]) // slices use the same underlying array, so we need a new one for each packet
 
-			packets <- packet{data: data, raddr: raddr}
+			select { // needs select here because packet consumer can quit and the packets channel gets filled up and blocks
+			case packets <- packet{data: data, raddr: raddr}:
+			case <-n.stop.Chan():
+			}
 		}
 	}()
 
diff --git a/dht/testing.go b/dht/testing.go
index 111138b..8e7ee18 100644
--- a/dht/testing.go
+++ b/dht/testing.go
@@ -13,21 +13,23 @@ import (
 var testingDHTIP = "127.0.0.1"
 var testingDHTFirstPort = 21000
 
-func TestingCreateDHT(numNodes int) (*BootstrapNode, []*DHT) {
-	bootstrapAddress := testingDHTIP + ":" + strconv.Itoa(testingDHTFirstPort)
-	bootstrapNode := NewBootstrapNode(RandomBitmapP(), 0, bootstrapDefaultRefreshDuration)
-	listener, err := net.ListenPacket(network, bootstrapAddress)
-	if err != nil {
-		panic(err)
+func TestingCreateDHT(numNodes int, bootstrap, concurrent bool) (*BootstrapNode, []*DHT) {
+	var bootstrapNode *BootstrapNode
+	var seeds []string
+
+	if bootstrap {
+		bootstrapAddress := testingDHTIP + ":" + strconv.Itoa(testingDHTFirstPort)
+		seeds = []string{bootstrapAddress}
+		bootstrapNode = NewBootstrapNode(RandomBitmapP(), 0, bootstrapDefaultRefreshDuration)
+		listener, err := net.ListenPacket(network, bootstrapAddress)
+		if err != nil {
+			panic(err)
+		}
+		bootstrapNode.Connect(listener.(*net.UDPConn))
 	}
-	bootstrapNode.Connect(listener.(*net.UDPConn))
 
-	return bootstrapNode, TestingCreateDHTNoBootstrap(numNodes, []string{bootstrapAddress})
-}
-
-func TestingCreateDHTNoBootstrap(numNodes int, seeds []string) []*DHT {
 	if numNodes < 1 {
-		return nil
+		return bootstrapNode, nil
 	}
 
 	firstPort := testingDHTFirstPort + 1
@@ -40,11 +42,19 @@ func TestingCreateDHTNoBootstrap(numNodes int, seeds []string) []*DHT {
 		}
 
 		go dht.Start()
-		dht.WaitUntilJoined()
+		if !concurrent {
+			dht.WaitUntilJoined()
+		}
 		dhts[i] = dht
 	}
 
-	return dhts
+	if concurrent {
+		for _, d := range dhts {
+			d.WaitUntilJoined()
+		}
+	}
+
+	return bootstrapNode, dhts
 }
 
 type timeoutErr struct {