package dht

import (
	"net"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/lbryio/errors.go"
)

var testingDHTIP = "127.0.0.1"
var testingDHTFirstPort = 21000

// TestingCreateDHT starts a bootstrap node listening on
// testingDHTIP:testingDHTFirstPort and then numNodes DHT nodes on the
// following ports, each seeded with the bootstrap node's address.
// It panics if the bootstrap listener cannot be opened.
func TestingCreateDHT(numNodes int) (*BootstrapNode, []*DHT) {
	bootstrapAddress := testingDHTIP + ":" + strconv.Itoa(testingDHTFirstPort)
	bootstrapNode := NewBootstrapNode(RandomBitmapP(), 0, bootstrapDefaultRefreshDuration)
	listener, err := net.ListenPacket(network, bootstrapAddress)
	if err != nil {
		panic(err)
	}

	bootstrapNode.Connect(listener.(*net.UDPConn))

	return bootstrapNode, TestingCreateDHTNoBootstrap(numNodes, []string{bootstrapAddress})
}

// TestingCreateDHTNoBootstrap starts numNodes DHT nodes on consecutive ports
// beginning at testingDHTFirstPort+1, each configured with the given seed
// nodes and a random node ID. Nodes are started one at a time; the function
// waits for each to join before starting the next. It returns nil when
// numNodes < 1 and panics if a node cannot be created.
func TestingCreateDHTNoBootstrap(numNodes int, seeds []string) []*DHT {
	if numNodes < 1 {
		return nil
	}

	firstPort := testingDHTFirstPort + 1
	dhts := make([]*DHT, numNodes)

	for i := 0; i < numNodes; i++ {
		dht, err := New(&Config{Address: testingDHTIP + ":" + strconv.Itoa(firstPort+i), NodeID: RandomBitmapP().Hex(), SeedNodes: seeds})
		if err != nil {
			panic(err)
		}

		// NOTE(review): if Start returns an error, the go statement discards
		// it and WaitUntilJoined may block indefinitely — confirm Start's
		// failure behavior.
		go dht.Start()
		dht.WaitUntilJoined()
		dhts[i] = dht
	}

	return dhts
}

// timeoutErr wraps an error and reports itself as a timeout, mimicking the
// net.Error values a real connection returns when a deadline expires.
type timeoutErr struct {
	error
}

// Timeout always reports true.
func (e timeoutErr) Timeout() bool {
	return true
}

// Temporary always reports true.
func (e timeoutErr) Temporary() bool {
	return true
}

// TODO: just use a normal net.Conn instead of this mock conn

// testUDPPacket is a single datagram passed through a testUDPConn, together
// with the remote address it came from or is going to.
type testUDPPacket struct {
	data []byte
	addr *net.UDPAddr
}

// testUDPConn is an in-memory stand-in for a UDP connection. Tests feed
// inbound packets through toRead and observe outbound packets on writes.
// Both channels are unbuffered, so every read and write rendezvouses with
// the test.
type testUDPConn struct {
	addr         *net.UDPAddr
	toRead       chan testUDPPacket
	writes       chan testUDPPacket
	readDeadline time.Time
}

// newTestUDPConn returns a testUDPConn whose local address is parsed from
// addr, which must have the form "ip:port". It panics if addr is malformed.
func newTestUDPConn(addr string) *testUDPConn {
	parts := strings.Split(addr, ":")
	if len(parts) != 2 {
		panic("addr needs ip and port")
	}

	// Parse the textual IP; converting the string bytes directly would
	// produce a garbage net.IP that never equals a real address.
	ip := net.ParseIP(parts[0])
	if ip == nil {
		panic("invalid ip: " + parts[0])
	}

	port, err := strconv.Atoi(parts[1])
	if err != nil {
		panic(err)
	}

	return &testUDPConn{
		addr:   &net.UDPAddr{IP: ip, Port: port},
		toRead: make(chan testUDPPacket),
		writes: make(chan testUDPPacket),
	}
}

// ReadFromUDP blocks until the test sends a packet on toRead, copies its data
// into b (truncating if b is too short), and returns the byte count and the
// packet's address. If a read deadline is set and passes first, it returns a
// timeoutErr instead.
func (c *testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
	var timeoutCh <-chan time.Time // nil (blocks forever) when no deadline is set
	if !c.readDeadline.IsZero() {
		timeoutCh = time.After(time.Until(c.readDeadline))
	}

	select {
	case packet := <-c.toRead:
		n := copy(b, packet.data)
		return n, packet.addr, nil
	case <-timeoutCh:
		return 0, nil, timeoutErr{errors.Err("timeout")}
	}
}

// WriteToUDP hands a copy of b, addressed to addr, to the test via the writes
// channel. It blocks until the test receives it and always reports len(b)
// bytes written.
func (c *testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
	// A real UDP write is done with b once it returns, so callers may reuse
	// the buffer; copy so later mutations don't corrupt the queued packet.
	data := make([]byte, len(b))
	copy(data, b)
	c.writes <- testUDPPacket{data: data, addr: addr}
	return len(b), nil
}

// SetReadDeadline sets the deadline used by subsequent ReadFromUDP calls.
// A zero time disables the deadline.
func (c *testUDPConn) SetReadDeadline(tm time.Time) error {
	c.readDeadline = tm
	return nil
}

// SetWriteDeadline is a no-op; writes in this mock have no deadline.
func (c *testUDPConn) SetWriteDeadline(tm time.Time) error {
	return nil
}

// Close nils out both channels. Afterwards, writes block forever, and reads
// block until the read deadline (if any) fires.
func (c *testUDPConn) Close() error {
	c.toRead = nil
	c.writes = nil
	return nil
}

// verifyResponse checks that resp is a well-formed response message. It must
// have exactly four fields, a type equal to responseType, a message ID equal
// to id, and a node ID equal to dhtNodeID with the expected lengths.
func verifyResponse(t *testing.T, resp map[string]interface{}, id messageID, dhtNodeID string) {
	t.Helper()

	if len(resp) != 4 {
		t.Errorf("expected 4 response fields, got %d", len(resp))
	}

	if v, ok := resp[headerTypeField]; !ok {
		t.Error("missing type field")
	} else if rType, ok := v.(int64); !ok {
		t.Error("type is not an integer")
	} else if rType != responseType {
		t.Error("unexpected response type")
	}

	if v, ok := resp[headerMessageIDField]; !ok {
		t.Error("missing message id field")
	} else if rMessageID, ok := v.(string); !ok {
		t.Error("message ID is not a string")
	} else {
		if rMessageID != string(id[:]) {
			t.Error("unexpected message ID")
		}
		if len(rMessageID) != messageIDLength {
			t.Errorf("message ID should be %d chars long", messageIDLength)
		}
	}

	if v, ok := resp[headerNodeIDField]; !ok {
		t.Error("missing node id field")
	} else if rNodeID, ok := v.(string); !ok {
		t.Error("node ID is not a string")
	} else {
		if rNodeID != dhtNodeID {
			t.Error("unexpected node ID")
		}
		if len(rNodeID) != nodeIDLength {
			t.Errorf("node ID should be %d chars long", nodeIDLength)
		}
	}
}

// verifyContacts checks that contacts, a list of [rawID, ip, port] triples,
// matches nodes exactly: the same count, no duplicates, and every entry
// corresponding to a node with the same ID, IP and port. A malformed list
// entry stops verification.
func verifyContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
	t.Helper()

	if len(contacts) != len(nodes) {
		t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
		return
	}

	foundNodes := make(map[string]bool, len(contacts))
	for _, c := range contacts {
		contact, ok := c.([]interface{})
		if !ok {
			t.Error("contact is not a list")
			return
		}
		if len(contact) != 3 {
			t.Error("contact must be 3 items")
			return
		}

		id, ok := contact[0].(string)
		if !ok {
			// Without an id there is no expected node to compare the IP and
			// port against, so checking them would only report noise.
			t.Error("contact id is not a string")
			continue
		}
		if foundNodes[id] {
			t.Errorf("contact %x appears multiple times", id)
			continue
		}

		var currNode Contact
		currNodeFound := false
		for _, n := range nodes {
			if n.id.RawString() == id {
				currNode = n
				currNodeFound = true
				break
			}
		}
		if !currNodeFound {
			t.Errorf("unexpected contact %x", id)
			continue
		}
		foundNodes[id] = true

		ip, ok := contact[1].(string)
		if !ok {
			t.Error("contact IP is not a string")
		} else if !currNode.ip.Equal(net.ParseIP(ip)) {
			t.Errorf("contact IP mismatch. got %s; expected %s", ip, currNode.ip.String())
		}

		port, ok := contact[2].(int64)
		if !ok {
			t.Error("contact port is not an int")
		} else if int(port) != currNode.port {
			t.Errorf("contact port mismatch. got %d; expected %d", port, currNode.port)
		}
	}
}

// verifyCompactContacts checks that contacts, a list of compact-encoded
// contact strings, matches nodes exactly: the same count, no duplicates, and
// every entry decoding to a node with the same ID, IP and port. A
// non-string or undecodable entry stops verification.
func verifyCompactContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
	t.Helper()

	if len(contacts) != len(nodes) {
		t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
		return
	}

	foundNodes := make(map[string]bool, len(contacts))
	for _, c := range contacts {
		compact, ok := c.(string)
		if !ok {
			t.Error("contact is not a string")
			return
		}

		contact := Contact{}
		if err := contact.UnmarshalCompact([]byte(compact)); err != nil {
			t.Error(err)
			return
		}

		hexID := contact.id.Hex()
		if foundNodes[hexID] {
			t.Errorf("contact %s appears multiple times", hexID)
			continue
		}

		var currNode Contact
		currNodeFound := false
		for _, n := range nodes {
			if n.id.Equals(contact.id) {
				currNode = n
				currNodeFound = true
				break
			}
		}
		if !currNodeFound {
			t.Errorf("unexpected contact %s", hexID)
			continue
		}
		foundNodes[hexID] = true

		if !currNode.ip.Equal(contact.ip) {
			t.Errorf("contact IP mismatch. got %s; expected %s", contact.ip.String(), currNode.ip.String())
		}

		if contact.port != currNode.port {
			t.Errorf("contact port mismatch. got %d; expected %d", contact.port, currNode.port)
		}
	}
}