diff --git a/dht/dht.go b/dht/dht.go index a6688bd..bd56323 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -13,6 +13,11 @@ import ( "github.com/spf13/cast" ) +func init() { + //log.SetFormatter(&log.TextFormatter{ForceColors: true}) + //log.SetLevel(log.DebugLevel) +} + const network = "udp4" const alpha = 3 // this is the constant alpha in the spec @@ -67,6 +72,7 @@ type UDPConn interface { WriteToUDP([]byte, *net.UDPAddr) (int, error) SetReadDeadline(time.Time) error SetWriteDeadline(time.Time) error + Close() error } // DHT represents a DHT node. @@ -79,6 +85,7 @@ type DHT struct { store *peerStore tm *transactionManager stop *stopOnce.Stopper + stopWG *sync.WaitGroup } // New returns a DHT pointer. If config is nil, then config will be set to the default config. @@ -120,6 +127,7 @@ func New(config *Config) (*DHT, error) { packets: make(chan packet), store: newPeerStore(), stop: stopOnce.New(), + stopWG: &sync.WaitGroup{}, } d.tm = newTransactionManager(d) return d, nil @@ -127,8 +135,7 @@ func New(config *Config) (*DHT, error) { // init initializes global variables. func (dht *DHT) init() error { - log.Info("Initializing DHT on " + dht.conf.Address) - log.Infof("Node ID is %s", dht.node.id.Hex()) + log.Debugf("Initializing DHT on %s (node id %s)", dht.conf.Address, dht.node.id.HexShort()) listener, err := net.ListenPacket(network, dht.conf.Address) if err != nil { @@ -146,7 +153,11 @@ func (dht *DHT) init() error { // listen receives message from udp. 
func (dht *DHT) listen() { + dht.stopWG.Add(1) + defer dht.stopWG.Done() + buf := make([]byte, 8192) + for { select { case <-dht.stop.Chan(): @@ -154,8 +165,7 @@ func (dht *DHT) listen() { default: } - dht.conn.SetReadDeadline(time.Now().Add(2 * time.Second)) // need this to periodically check shutdown chan - + dht.conn.SetReadDeadline(time.Now().Add(1 * time.Second)) // need this to periodically check shutdown chan n, raddr, err := dht.conn.ReadFromUDP(buf) if err != nil { if e, ok := err.(net.Error); !ok || !e.Timeout() { @@ -167,12 +177,16 @@ func (dht *DHT) listen() { continue } - dht.packets <- packet{data: buf[:n], raddr: raddr} + data := make([]byte, n) + copy(data, buf[:n]) // slices use the same underlying array, so we need a new one for each packet + + dht.packets <- packet{data: data, raddr: raddr} } } // join makes current node join the dht network. func (dht *DHT) join() { + log.Debugf("[%s] joining network", dht.node.id.HexShort()) // get real node IDs and add them to the routing table for _, addr := range dht.conf.SeedNodes { raddr, err := net.ResolveUDPAddr(network, addr) @@ -191,11 +205,14 @@ func (dht *DHT) join() { // now call iterativeFind on yourself _, err := dht.FindNodes(dht.node.id) if err != nil { - log.Error(err) + log.Errorf("[%s] join: %s", dht.node.id.HexShort(), err.Error()) } } func (dht *DHT) runHandler() { + dht.stopWG.Add(1) + defer dht.stopWG.Done() + var pkt packet for { @@ -209,10 +226,11 @@ func (dht *DHT) runHandler() { } // Start starts the dht -func (dht *DHT) Start() error { +func (dht *DHT) Start() { err := dht.init() if err != nil { - return err + log.Error(err) + return } go dht.listen() @@ -220,13 +238,15 @@ func (dht *DHT) Start() error { dht.join() log.Infof("[%s] DHT ready", dht.node.id.HexShort()) - return nil } // Shutdown shuts down the dht func (dht *DHT) Shutdown() { - log.Infof("[%s] DHT shutting down", dht.node.id.HexShort()) + log.Debugf("[%s] DHT shutting down", dht.node.id.HexShort()) dht.stop.Stop() + 
dht.stopWG.Wait() + dht.conn.Close() + log.Infof("[%s] DHT stopped", dht.node.id.HexShort()) } func printState(dht *DHT) { @@ -271,11 +291,9 @@ type nodeFinder struct { activeNodesMutex *sync.Mutex activeNodes []Node - shortlistMutex *sync.Mutex - shortlist []Node - - contactedMutex *sync.RWMutex - contacted map[bitmap]bool + shortlistContactedMutex *sync.Mutex + shortlist []Node + contacted map[bitmap]bool } type findNodeResponse struct { @@ -285,15 +303,14 @@ type findNodeResponse struct { func newNodeFinder(dht *DHT, target bitmap, findValue bool) *nodeFinder { return &nodeFinder{ - dht: dht, - target: target, - findValue: findValue, - findValueMutex: &sync.Mutex{}, - activeNodesMutex: &sync.Mutex{}, - contactedMutex: &sync.RWMutex{}, - shortlistMutex: &sync.Mutex{}, - contacted: make(map[bitmap]bool), - done: stopOnce.New(), + dht: dht, + target: target, + findValue: findValue, + findValueMutex: &sync.Mutex{}, + activeNodesMutex: &sync.Mutex{}, + shortlistContactedMutex: &sync.Mutex{}, + contacted: make(map[bitmap]bool), + done: stopOnce.New(), } } @@ -341,7 +358,7 @@ func (nf *nodeFinder) iterationWorker(num int) { maybeNode := nf.popFromShortlist() if maybeNode == nil { // TODO: block if there are pending requests out from other workers. 
there may be more shortlist values coming - log.Debugf("[%s] no more nodes in short list", nf.dht.node.id.HexShort()) + log.Debugf("[%s] no more nodes in shortlist", nf.dht.node.id.HexShort()) return } node := *maybeNode @@ -382,7 +399,6 @@ func (nf *nodeFinder) iterationWorker(num int) { } else { log.Debugf("[%s] worker %d: got more contacts", nf.dht.node.id.HexShort(), num) nf.insertIntoActiveList(node) - nf.markContacted(node) nf.appendNewToShortlist(res.FindNodeData) } @@ -394,39 +410,32 @@ func (nf *nodeFinder) iterationWorker(num int) { } } -func (nf *nodeFinder) filterContacted(nodes []Node) []Node { - nf.contactedMutex.RLock() - defer nf.contactedMutex.RUnlock() - filtered := []Node{} +func (nf *nodeFinder) appendNewToShortlist(nodes []Node) { + nf.shortlistContactedMutex.Lock() + defer nf.shortlistContactedMutex.Unlock() + + notContacted := []Node{} for _, n := range nodes { - if ok := nf.contacted[n.id]; !ok { - filtered = append(filtered, n) + if _, ok := nf.contacted[n.id]; !ok { + notContacted = append(notContacted, n) } } - return filtered -} -func (nf *nodeFinder) markContacted(node Node) { - nf.contactedMutex.Lock() - defer nf.contactedMutex.Unlock() - nf.contacted[node.id] = true -} - -func (nf *nodeFinder) appendNewToShortlist(nodes []Node) { - nf.shortlistMutex.Lock() - defer nf.shortlistMutex.Unlock() - nf.shortlist = append(nf.shortlist, nf.filterContacted(nodes)...) + nf.shortlist = append(nf.shortlist, notContacted...) 
sortNodesInPlace(nf.shortlist, nf.target) } func (nf *nodeFinder) popFromShortlist() *Node { - nf.shortlistMutex.Lock() - defer nf.shortlistMutex.Unlock() + nf.shortlistContactedMutex.Lock() + defer nf.shortlistContactedMutex.Unlock() + if len(nf.shortlist) == 0 { return nil } + first := nf.shortlist[0] nf.shortlist = nf.shortlist[1:] + nf.contacted[first.id] = true return &first } @@ -448,7 +457,6 @@ func (nf *nodeFinder) insertIntoActiveList(node Node) { func (nf *nodeFinder) isSearchFinished() bool { if nf.findValue && len(nf.findValueResult) > 0 { - // if we have a result, always break return true } @@ -458,11 +466,10 @@ func (nf *nodeFinder) isSearchFinished() bool { default: } - nf.shortlistMutex.Lock() - defer nf.shortlistMutex.Unlock() + nf.shortlistContactedMutex.Lock() + defer nf.shortlistContactedMutex.Unlock() if len(nf.shortlist) == 0 { - // no more nodes to contact return true } diff --git a/dht/dht_test.go b/dht/dht_test.go index 6866adc..179da59 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -1,6 +1,7 @@ package dht import ( + "net" "testing" "time" @@ -8,19 +9,18 @@ import ( ) func TestDHT_FindNodes(t *testing.T) { - //log.SetLevel(log.DebugLevel) - id1 := newRandomBitmap() id2 := newRandomBitmap() id3 := newRandomBitmap() seedIP := "127.0.0.1:21216" - dht, err := New(&Config{Address: seedIP, NodeID: id1.Hex()}) + dht1, err := New(&Config{Address: seedIP, NodeID: id1.Hex()}) if err != nil { t.Fatal(err) } - go dht.Start() + go dht1.Start() + defer dht1.Shutdown() time.Sleep(1 * time.Second) @@ -29,6 +29,7 @@ func TestDHT_FindNodes(t *testing.T) { t.Fatal(err) } go dht2.Start() + defer dht2.Shutdown() time.Sleep(1 * time.Second) // give dhts a chance to connect @@ -37,8 +38,93 @@ func TestDHT_FindNodes(t *testing.T) { t.Fatal(err) } go dht3.Start() + defer dht3.Shutdown() time.Sleep(1 * time.Second) // give dhts a chance to connect - spew.Dump(dht3.FindNodes(id2)) + foundNodes, err := dht3.FindNodes(id2) + + if err != nil { + t.Fatal(err) + } 
+ + spew.Dump(foundNodes) + + if len(foundNodes) != 2 { + t.Errorf("expected 2 nodes, found %d", len(foundNodes)) + } + + foundOne := false + foundTwo := false + + for _, n := range foundNodes { + if n.id.Equals(id1) { + foundOne = true + } + if n.id.Equals(id2) { + foundTwo = true + } + } + + if !foundOne { + t.Errorf("did not find node %s", id1.Hex()) + } + if !foundTwo { + t.Errorf("did not find node %s", id2.Hex()) + } +} + +func TestDHT_FindValue(t *testing.T) { + id1 := newRandomBitmap() + id2 := newRandomBitmap() + id3 := newRandomBitmap() + + seedIP := "127.0.0.1:21216" + + dht1, err := New(&Config{Address: seedIP, NodeID: id1.Hex()}) + if err != nil { + t.Fatal(err) + } + go dht1.Start() + defer dht1.Shutdown() + + time.Sleep(1 * time.Second) + + dht2, err := New(&Config{Address: "127.0.0.1:21217", NodeID: id2.Hex(), SeedNodes: []string{seedIP}}) + if err != nil { + t.Fatal(err) + } + go dht2.Start() + defer dht2.Shutdown() + + time.Sleep(1 * time.Second) // give dhts a chance to connect + + dht3, err := New(&Config{Address: "127.0.0.1:21218", NodeID: id3.Hex(), SeedNodes: []string{seedIP}}) + if err != nil { + t.Fatal(err) + } + go dht3.Start() + defer dht3.Shutdown() + + time.Sleep(1 * time.Second) // give dhts a chance to connect + + nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678} + dht1.store.Upsert(nodeToFind.id.RawString(), nodeToFind) + + foundNodes, found, err := dht3.FindValue(nodeToFind.id) + if err != nil { + t.Fatal(err) + } + + if !found { + t.Fatal("node was not found") + } + + if len(foundNodes) != 1 { + t.Fatalf("expected one node, found %d", len(foundNodes)) + } + + if !foundNodes[0].id.Equals(nodeToFind.id) { + t.Fatalf("found node id %s, expected %s", foundNodes[0].id.Hex(), nodeToFind.id.Hex()) + } } diff --git a/dht/message.go b/dht/message.go index facabda..cdc1dea 100644 --- a/dht/message.go +++ b/dht/message.go @@ -2,6 +2,7 @@ package dht import ( "encoding/hex" + "strings" 
"github.com/lbryio/errors.go" @@ -174,18 +175,21 @@ type Response struct { } func (r Response) ArgsDebug() string { - if len(r.FindNodeData) == 0 { + if r.Data != "" { return r.Data } str := "contacts " if r.FindValueKey != "" { - str += "for " + hex.EncodeToString([]byte(r.FindValueKey))[:8] + " " + str = "value for " + hex.EncodeToString([]byte(r.FindValueKey))[:8] + " " } + + str += "|" for _, c := range r.FindNodeData { - str += c.Addr().String() + ":" + c.id.HexShort() + ", " + str += c.Addr().String() + ":" + c.id.HexShort() + "," } - return str[:len(str)-2] // chomp off last ", " + str = strings.TrimRight(str, ",") + "|" + return str } func (r Response) MarshalBencode() ([]byte, error) { @@ -235,19 +239,30 @@ func (r *Response) UnmarshalBencode(b []byte) error { return err } - var rawContacts bencode.RawMessage - var ok bool - if rawContacts, ok = rawData["contacts"]; !ok { + if contacts, ok := rawData["contacts"]; ok { + err = bencode.DecodeBytes(contacts, &r.FindNodeData) + if err != nil { + return err + } + } else { for k, v := range rawData { r.FindValueKey = k - rawContacts = v + var compactNodes [][]byte + err = bencode.DecodeBytes(v, &compactNodes) + if err != nil { + return err + } + for _, compact := range compactNodes { + var uncompactedNode Node + err = uncompactedNode.UnmarshalCompact(compact) + if err != nil { + return err + } + r.FindNodeData = append(r.FindNodeData, uncompactedNode) + } break } } - err = bencode.DecodeBytes(rawContacts, &r.FindNodeData) - if err != nil { - return err - } } return nil diff --git a/dht/message_test.go b/dht/message_test.go index 4dc4367..de31e6e 100644 --- a/dht/message_test.go +++ b/dht/message_test.go @@ -2,18 +2,17 @@ package dht import ( "encoding/hex" + "net" "reflect" "strconv" "strings" "testing" + "github.com/davecgh/go-spew/spew" "github.com/lyoshenka/bencode" - log "github.com/sirupsen/logrus" ) func TestBencodeDecodeStoreArgs(t *testing.T) { - log.SetLevel(log.DebugLevel) - blobHash := 
"3214D6C2F77FCB5E8D5FC07EDAFBA614F031CE8B2EAB49F924F8143F6DFBADE048D918710072FB98AB1B52B58F4E1468" lbryID := "7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B" port := hex.EncodeToString([]byte("3333")) @@ -70,6 +69,72 @@ func TestBencodeDecodeStoreArgs(t *testing.T) { t.Error(err) } else if !reflect.DeepEqual(reencoded, data) { t.Error("reencoded data does not match original") - //spew.Dump(reencoded, data) + spew.Dump(reencoded, data) + } +} + +func TestBencodeFindNodesResponse(t *testing.T) { + res := Response{ + ID: newMessageID(), + NodeID: newRandomBitmap().RawString(), + FindNodeData: []Node{ + {id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678}, + {id: newRandomBitmap(), ip: net.IPv4(4, 3, 2, 1).To4(), port: 8765}, + }, + } + + encoded, err := bencode.EncodeBytes(res) + if err != nil { + t.Fatal(err) + } + + var res2 Response + err = bencode.DecodeBytes(encoded, &res2) + if err != nil { + t.Fatal(err) + } + + compareResponses(t, res, res2) +} + +func TestBencodeFindValueResponse(t *testing.T) { + res := Response{ + ID: newMessageID(), + NodeID: newRandomBitmap().RawString(), + FindValueKey: newRandomBitmap().RawString(), + FindNodeData: []Node{ + {id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678}, + }, + } + + encoded, err := bencode.EncodeBytes(res) + if err != nil { + t.Fatal(err) + } + + var res2 Response + err = bencode.DecodeBytes(encoded, &res2) + if err != nil { + t.Fatal(err) + } + + compareResponses(t, res, res2) +} + +func compareResponses(t *testing.T, res, res2 Response) { + if res.ID != res2.ID { + t.Errorf("expected ID %s, got %s", res.ID, res2.ID) + } + if res.NodeID != res2.NodeID { + t.Errorf("expected NodeID %s, got %s", res.NodeID, res2.NodeID) + } + if res.Data != res2.Data { + t.Errorf("expected Data %s, got %s", res.Data, res2.Data) + } + if res.FindValueKey != res2.FindValueKey { + t.Errorf("expected FindValueKey %s, got %s", res.FindValueKey, 
res2.FindValueKey) + } + if !reflect.DeepEqual(res.FindNodeData, res2.FindNodeData) { + t.Errorf("expected FindNodeData %s, got %s", spew.Sdump(res.FindNodeData), spew.Sdump(res2.FindNodeData)) } } diff --git a/dht/routing_table.go b/dht/routing_table.go index 3d914c7..0290b90 100644 --- a/dht/routing_table.go +++ b/dht/routing_table.go @@ -49,7 +49,7 @@ func (n *Node) UnmarshalCompact(b []byte) error { if len(b) != compactNodeInfoLength { return errors.Err("invalid compact length") } - n.ip = net.IPv4(b[0], b[1], b[2], b[3]) + n.ip = net.IPv4(b[0], b[1], b[2], b[3]).To4() n.port = int(uint16(b[5]) | uint16(b[4])<<8) n.id = newBitmapFromBytes(b[6:]) return nil diff --git a/dht/rpc.go b/dht/rpc.go index ac53fd9..286f60b 100644 --- a/dht/rpc.go +++ b/dht/rpc.go @@ -23,21 +23,20 @@ func newMessageID() string { return string(buf) } -// handlePacke handles packets received from udp. +// handlePacket handles packets received from udp. func handlePacket(dht *DHT, pkt packet) { - //log.Infof("[%s] Received message from %s:%s : %s\n", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), hex.EncodeToString(pkt.data)) + //log.Debugf("[%s] Received message from %s:%s (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), len(pkt.data), hex.EncodeToString(pkt.data)) var data map[string]interface{} err := bencode.DecodeBytes(pkt.data, &data) if err != nil { - log.Errorf("error decoding data: %s", err) - log.Errorf(hex.EncodeToString(pkt.data)) + log.Errorf("[%s] error decoding data: %s: (%d bytes) %s", dht.node.id.HexShort(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data)) return } msgType, ok := data[headerTypeField] if !ok { - log.Errorf("decoded data has no message type: %s", data) + log.Errorf("[%s] decoded data has no message type: %s", dht.node.id.HexShort(), spew.Sdump(data)) return } @@ -73,7 +72,7 @@ func handlePacket(dht *DHT, pkt packet) { handleError(dht, pkt.raddr, e) default: - 
log.Errorf("Invalid message type: %s", msgType) + log.Errorf("[%s] invalid message type: %s", dht.node.id.HexShort(), msgType) return } } @@ -170,18 +169,20 @@ func handleError(dht *DHT, addr *net.UDPAddr, e Error) { // send sends data to a udp address func send(dht *DHT, addr *net.UDPAddr, data Message) error { - if req, ok := data.(Request); ok { - log.Debugf("[%s] query %s: sending request to %s : %s(%s)", dht.node.id.HexShort(), hex.EncodeToString([]byte(req.ID))[:8], addr.String(), req.Method, argsToString(req.Args)) - } else if res, ok := data.(Response); ok { - log.Debugf("[%s] query %s: sending response to %s : %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(res.ID))[:8], addr.String(), res.ArgsDebug()) - } else { - log.Debugf("[%s] %s", dht.node.id.HexShort(), spew.Sdump(data)) - } encoded, err := bencode.EncodeBytes(data) if err != nil { return err } - //log.Infof("Encoded: %s", string(encoded)) + + if req, ok := data.(Request); ok { + log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)", + dht.node.id.HexShort(), hex.EncodeToString([]byte(req.ID))[:8], addr.String(), len(encoded), req.Method, argsToString(req.Args)) + } else if res, ok := data.(Response); ok { + log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s", + dht.node.id.HexShort(), hex.EncodeToString([]byte(res.ID))[:8], addr.String(), len(encoded), res.ArgsDebug()) + } else { + log.Debugf("[%s] (%d bytes) %s", dht.node.id.HexShort(), len(encoded), spew.Sdump(data)) + } dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15)) diff --git a/dht/rpc_test.go b/dht/rpc_test.go index 6742d87..70df70a 100644 --- a/dht/rpc_test.go +++ b/dht/rpc_test.go @@ -7,10 +7,22 @@ import ( "testing" "time" + "github.com/lbryio/errors.go" "github.com/lyoshenka/bencode" - log "github.com/sirupsen/logrus" ) +type timeoutErr struct { + error +} + +func (t timeoutErr) Timeout() bool { + return true +} + +func (t timeoutErr) Temporary() bool { + return true +} + type testUDPPacket 
struct { data []byte addr *net.UDPAddr @@ -20,6 +32,8 @@ type testUDPConn struct { addr *net.UDPAddr toRead chan testUDPPacket writes chan testUDPPacket + + readDeadline time.Time } func newTestUDPConn(addr string) *testUDPConn { @@ -39,12 +53,17 @@ func newTestUDPConn(addr string) *testUDPConn { } func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { + var timeoutCh <-chan time.Time + if !t.readDeadline.IsZero() { + timeoutCh = time.After(t.readDeadline.Sub(time.Now())) + } + select { case packet := <-t.toRead: n := copy(b, packet.data) return n, packet.addr, nil - //default: - // return 0, nil, nil + case <-timeoutCh: + return 0, nil, timeoutErr{errors.Err("timeout")} } } @@ -53,16 +72,22 @@ func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { return len(b), nil } -func (t testUDPConn) SetReadDeadline(tm time.Time) error { +func (t *testUDPConn) SetReadDeadline(tm time.Time) error { + t.readDeadline = tm return nil } -func (t testUDPConn) SetWriteDeadline(tm time.Time) error { +func (t *testUDPConn) SetWriteDeadline(tm time.Time) error { + return nil +} + +func (t *testUDPConn) Close() error { + t.toRead = nil + t.writes = nil return nil } func TestPing(t *testing.T) { - log.SetLevel(log.DebugLevel) dhtNodeID := newRandomBitmap() testNodeID := newRandomBitmap() diff --git a/dht/transaction_manager.go b/dht/transaction_manager.go index 122a1a2..061b2ae 100644 --- a/dht/transaction_manager.go +++ b/dht/transaction_manager.go @@ -85,7 +85,7 @@ func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Req for i := 0; i < udpRetry; i++ { if err := send(tm.dht, trans.node.Addr(), *trans.req); err != nil { - log.Error(err) + log.Errorf("send error: %s", err.Error()) continue // try again? return? }