diff --git a/dht/dht_test.go b/dht/dht_test.go index 3b0a61d..437f1c8 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -1,19 +1,18 @@ package dht import ( - "math/rand" "net" "sync" "testing" "time" - log "github.com/sirupsen/logrus" + "github.com/lbryio/lbry.go/crypto" ) // TODO: make a dht with X nodes, have them all join, then ensure that every node appears at least once in another node's routing table func TestNodeFinder_FindNodes(t *testing.T) { - bs, dhts := TestingCreateDHT(3) + bs, dhts := TestingCreateDHT(3, true, false) defer func() { for i := range dhts { dhts[i].Shutdown() @@ -64,7 +63,7 @@ func TestNodeFinder_FindNodes(t *testing.T) { } func TestNodeFinder_FindNodes_NoBootstrap(t *testing.T) { - dhts := TestingCreateDHTNoBootstrap(3, nil) + _, dhts := TestingCreateDHT(3, false, false) defer func() { for i := range dhts { dhts[i].Shutdown() @@ -79,7 +78,7 @@ func TestNodeFinder_FindNodes_NoBootstrap(t *testing.T) { } func TestNodeFinder_FindValue(t *testing.T) { - bs, dhts := TestingCreateDHT(3) + bs, dhts := TestingCreateDHT(3, true, false) defer func() { for i := range dhts { dhts[i].Shutdown() @@ -112,10 +111,8 @@ func TestNodeFinder_FindValue(t *testing.T) { } func TestDHT_LargeDHT(t *testing.T) { - rand.Seed(time.Now().UnixNano()) - log.Println("if this takes longer than 20 seconds, its stuck. idk why it gets stuck sometimes, but its a bug.") nodes := 100 - bs, dhts := TestingCreateDHT(nodes) + bs, dhts := TestingCreateDHT(nodes, true, true) defer func() { for _, d := range dhts { go d.Shutdown() @@ -132,10 +129,9 @@ func TestDHT_LargeDHT(t *testing.T) { } for i := 0; i < numIDs; i++ { go func(i int) { - r := rand.Intn(nodes) wg.Add(1) defer wg.Done() - dhts[r].Announce(ids[i]) + dhts[int(crypto.RandInt64(int64(nodes)))].Announce(ids[i]) }(i) } wg.Wait() diff --git a/dht/node.go b/dht/node.go index 92b3e10..322d69f 100644 --- a/dht/node.go +++ b/dht/node.go @@ -112,7 +112,7 @@ func (n *Node) Connect(conn UDPConn) error { } n.conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) // need this to periodically check shutdown chan - n, raddr, err := n.conn.ReadFromUDP(buf) + bytesRead, raddr, err := n.conn.ReadFromUDP(buf) if err != nil { if e, ok := err.(net.Error); !ok || !e.Timeout() { log.Errorf("udp read error: %v", err) @@ -123,10 +123,13 @@ func (n *Node) Connect(conn UDPConn) error { continue } - data := make([]byte, n) - copy(data, buf[:n]) // slices use the same underlying array, so we need a new one for each packet + data := make([]byte, bytesRead) + copy(data, buf[:bytesRead]) // slices use the same underlying array, so we need a new one for each packet - packets <- packet{data: data, raddr: raddr} + select { // needs select here because packet consumer can quit and the packets channel gets filled up and blocks + case packets <- packet{data: data, raddr: raddr}: + case <-n.stop.Chan(): + } } }() diff --git a/dht/testing.go b/dht/testing.go index 111138b..8e7ee18 100644 --- a/dht/testing.go +++ b/dht/testing.go @@ -13,21 +13,23 @@ import ( var testingDHTIP = "127.0.0.1" var testingDHTFirstPort = 21000 -func TestingCreateDHT(numNodes int) (*BootstrapNode, []*DHT) { - bootstrapAddress := testingDHTIP + ":" + strconv.Itoa(testingDHTFirstPort) - bootstrapNode := NewBootstrapNode(RandomBitmapP(), 0, bootstrapDefaultRefreshDuration) - listener, err := net.ListenPacket(network, bootstrapAddress) - if err != nil { - panic(err) +func TestingCreateDHT(numNodes int, bootstrap, concurrent bool) (*BootstrapNode, []*DHT) { + var bootstrapNode *BootstrapNode + var seeds []string + + if bootstrap { + bootstrapAddress := testingDHTIP + ":" + strconv.Itoa(testingDHTFirstPort) + seeds = []string{bootstrapAddress} + bootstrapNode = NewBootstrapNode(RandomBitmapP(), 0, bootstrapDefaultRefreshDuration) + listener, err := net.ListenPacket(network, bootstrapAddress) + if err != nil { + panic(err) + } + bootstrapNode.Connect(listener.(*net.UDPConn)) } - bootstrapNode.Connect(listener.(*net.UDPConn)) - return bootstrapNode, TestingCreateDHTNoBootstrap(numNodes, []string{bootstrapAddress}) -} - -func TestingCreateDHTNoBootstrap(numNodes int, seeds []string) []*DHT { if numNodes < 1 { - return nil + return bootstrapNode, nil } firstPort := testingDHTFirstPort + 1 @@ -40,11 +42,19 @@ func TestingCreateDHTNoBootstrap(numNodes int, seeds []string) []*DHT { } go dht.Start() - dht.WaitUntilJoined() + if !concurrent { + dht.WaitUntilJoined() + } dhts[i] = dht } - return dhts + if concurrent { + for _, d := range dhts { + d.WaitUntilJoined() + } + } + + return bootstrapNode, dhts } type timeoutErr struct {