findNode and findValue implemented

This commit is contained in:
Alex Grintsvayg 2018-04-03 12:14:04 -04:00
parent 24c079a7dd
commit a5ef461fc5
8 changed files with 292 additions and 93 deletions

View file

@ -13,6 +13,11 @@ import (
"github.com/spf13/cast"
)
func init() {
//log.SetFormatter(&log.TextFormatter{ForceColors: true})
//log.SetLevel(log.DebugLevel)
}
const network = "udp4"
const alpha = 3 // this is the constant alpha in the spec
@ -67,6 +72,7 @@ type UDPConn interface {
WriteToUDP([]byte, *net.UDPAddr) (int, error)
SetReadDeadline(time.Time) error
SetWriteDeadline(time.Time) error
Close() error
}
// DHT represents a DHT node.
@ -79,6 +85,7 @@ type DHT struct {
store *peerStore
tm *transactionManager
stop *stopOnce.Stopper
stopWG *sync.WaitGroup
}
// New returns a DHT pointer. If config is nil, then config will be set to the default config.
@ -120,6 +127,7 @@ func New(config *Config) (*DHT, error) {
packets: make(chan packet),
store: newPeerStore(),
stop: stopOnce.New(),
stopWG: &sync.WaitGroup{},
}
d.tm = newTransactionManager(d)
return d, nil
@ -127,8 +135,7 @@ func New(config *Config) (*DHT, error) {
// init initializes global variables.
func (dht *DHT) init() error {
log.Info("Initializing DHT on " + dht.conf.Address)
log.Infof("Node ID is %s", dht.node.id.Hex())
log.Debugf("Initializing DHT on %s (node id %s)", dht.conf.Address, dht.node.id.HexShort())
listener, err := net.ListenPacket(network, dht.conf.Address)
if err != nil {
@ -146,7 +153,11 @@ func (dht *DHT) init() error {
// listen receives message from udp.
func (dht *DHT) listen() {
dht.stopWG.Add(1)
defer dht.stopWG.Done()
buf := make([]byte, 8192)
for {
select {
case <-dht.stop.Chan():
@ -154,8 +165,7 @@ func (dht *DHT) listen() {
default:
}
dht.conn.SetReadDeadline(time.Now().Add(2 * time.Second)) // need this to periodically check shutdown chan
dht.conn.SetReadDeadline(time.Now().Add(1 * time.Second)) // need this to periodically check shutdown chan
n, raddr, err := dht.conn.ReadFromUDP(buf)
if err != nil {
if e, ok := err.(net.Error); !ok || !e.Timeout() {
@ -167,12 +177,16 @@ func (dht *DHT) listen() {
continue
}
dht.packets <- packet{data: buf[:n], raddr: raddr}
data := make([]byte, n)
copy(data, buf[:n]) // slices use the same underlying array, so we need a new one for each packet
dht.packets <- packet{data: data, raddr: raddr}
}
}
// join makes current node join the dht network.
func (dht *DHT) join() {
log.Debugf("[%s] joining network", dht.node.id.HexShort())
// get real node IDs and add them to the routing table
for _, addr := range dht.conf.SeedNodes {
raddr, err := net.ResolveUDPAddr(network, addr)
@ -191,11 +205,14 @@ func (dht *DHT) join() {
// now call iterativeFind on yourself
_, err := dht.FindNodes(dht.node.id)
if err != nil {
log.Error(err)
log.Errorf("[%s] join: %s", dht.node.id.HexShort(), err.Error())
}
}
func (dht *DHT) runHandler() {
dht.stopWG.Add(1)
defer dht.stopWG.Done()
var pkt packet
for {
@ -209,10 +226,11 @@ func (dht *DHT) runHandler() {
}
// Start starts the dht
func (dht *DHT) Start() error {
func (dht *DHT) Start() {
err := dht.init()
if err != nil {
return err
log.Error(err)
return
}
go dht.listen()
@ -220,13 +238,15 @@ func (dht *DHT) Start() error {
dht.join()
log.Infof("[%s] DHT ready", dht.node.id.HexShort())
return nil
}
// Shutdown shuts down the dht
func (dht *DHT) Shutdown() {
log.Infof("[%s] DHT shutting down", dht.node.id.HexShort())
log.Debugf("[%s] DHT shutting down", dht.node.id.HexShort())
dht.stop.Stop()
dht.stopWG.Wait()
dht.conn.Close()
log.Infof("[%s] DHT stopped", dht.node.id.HexShort())
}
func printState(dht *DHT) {
@ -271,11 +291,9 @@ type nodeFinder struct {
activeNodesMutex *sync.Mutex
activeNodes []Node
shortlistMutex *sync.Mutex
shortlist []Node
contactedMutex *sync.RWMutex
contacted map[bitmap]bool
shortlistContactedMutex *sync.Mutex
shortlist []Node
contacted map[bitmap]bool
}
type findNodeResponse struct {
@ -285,15 +303,14 @@ type findNodeResponse struct {
func newNodeFinder(dht *DHT, target bitmap, findValue bool) *nodeFinder {
return &nodeFinder{
dht: dht,
target: target,
findValue: findValue,
findValueMutex: &sync.Mutex{},
activeNodesMutex: &sync.Mutex{},
contactedMutex: &sync.RWMutex{},
shortlistMutex: &sync.Mutex{},
contacted: make(map[bitmap]bool),
done: stopOnce.New(),
dht: dht,
target: target,
findValue: findValue,
findValueMutex: &sync.Mutex{},
activeNodesMutex: &sync.Mutex{},
shortlistContactedMutex: &sync.Mutex{},
contacted: make(map[bitmap]bool),
done: stopOnce.New(),
}
}
@ -341,7 +358,7 @@ func (nf *nodeFinder) iterationWorker(num int) {
maybeNode := nf.popFromShortlist()
if maybeNode == nil {
// TODO: block if there are pending requests out from other workers. there may be more shortlist values coming
log.Debugf("[%s] no more nodes in short list", nf.dht.node.id.HexShort())
log.Debugf("[%s] no more nodes in shortlist", nf.dht.node.id.HexShort())
return
}
node := *maybeNode
@ -382,7 +399,6 @@ func (nf *nodeFinder) iterationWorker(num int) {
} else {
log.Debugf("[%s] worker %d: got more contacts", nf.dht.node.id.HexShort(), num)
nf.insertIntoActiveList(node)
nf.markContacted(node)
nf.appendNewToShortlist(res.FindNodeData)
}
@ -394,39 +410,32 @@ func (nf *nodeFinder) iterationWorker(num int) {
}
}
func (nf *nodeFinder) filterContacted(nodes []Node) []Node {
nf.contactedMutex.RLock()
defer nf.contactedMutex.RUnlock()
filtered := []Node{}
func (nf *nodeFinder) appendNewToShortlist(nodes []Node) {
nf.shortlistContactedMutex.Lock()
defer nf.shortlistContactedMutex.Unlock()
notContacted := []Node{}
for _, n := range nodes {
if ok := nf.contacted[n.id]; !ok {
filtered = append(filtered, n)
if _, ok := nf.contacted[n.id]; !ok {
notContacted = append(notContacted, n)
}
}
return filtered
}
func (nf *nodeFinder) markContacted(node Node) {
nf.contactedMutex.Lock()
defer nf.contactedMutex.Unlock()
nf.contacted[node.id] = true
}
func (nf *nodeFinder) appendNewToShortlist(nodes []Node) {
nf.shortlistMutex.Lock()
defer nf.shortlistMutex.Unlock()
nf.shortlist = append(nf.shortlist, nf.filterContacted(nodes)...)
nf.shortlist = append(nf.shortlist, notContacted...)
sortNodesInPlace(nf.shortlist, nf.target)
}
func (nf *nodeFinder) popFromShortlist() *Node {
nf.shortlistMutex.Lock()
defer nf.shortlistMutex.Unlock()
nf.shortlistContactedMutex.Lock()
defer nf.shortlistContactedMutex.Unlock()
if len(nf.shortlist) == 0 {
return nil
}
first := nf.shortlist[0]
nf.shortlist = nf.shortlist[1:]
nf.contacted[first.id] = true
return &first
}
@ -448,7 +457,6 @@ func (nf *nodeFinder) insertIntoActiveList(node Node) {
func (nf *nodeFinder) isSearchFinished() bool {
if nf.findValue && len(nf.findValueResult) > 0 {
// if we have a result, always break
return true
}
@ -458,11 +466,10 @@ func (nf *nodeFinder) isSearchFinished() bool {
default:
}
nf.shortlistMutex.Lock()
defer nf.shortlistMutex.Unlock()
nf.shortlistContactedMutex.Lock()
defer nf.shortlistContactedMutex.Unlock()
if len(nf.shortlist) == 0 {
// no more nodes to contact
return true
}

View file

@ -1,6 +1,7 @@
package dht
import (
"net"
"testing"
"time"
@ -8,19 +9,18 @@ import (
)
func TestDHT_FindNodes(t *testing.T) {
//log.SetLevel(log.DebugLevel)
id1 := newRandomBitmap()
id2 := newRandomBitmap()
id3 := newRandomBitmap()
seedIP := "127.0.0.1:21216"
dht, err := New(&Config{Address: seedIP, NodeID: id1.Hex()})
dht1, err := New(&Config{Address: seedIP, NodeID: id1.Hex()})
if err != nil {
t.Fatal(err)
}
go dht.Start()
go dht1.Start()
defer dht1.Shutdown()
time.Sleep(1 * time.Second)
@ -29,6 +29,7 @@ func TestDHT_FindNodes(t *testing.T) {
t.Fatal(err)
}
go dht2.Start()
defer dht2.Shutdown()
time.Sleep(1 * time.Second) // give dhts a chance to connect
@ -37,8 +38,93 @@ func TestDHT_FindNodes(t *testing.T) {
t.Fatal(err)
}
go dht3.Start()
defer dht3.Shutdown()
time.Sleep(1 * time.Second) // give dhts a chance to connect
spew.Dump(dht3.FindNodes(id2))
foundNodes, err := dht3.FindNodes(id2)
if err != nil {
t.Fatal(err)
}
spew.Dump(foundNodes)
if len(foundNodes) != 2 {
t.Errorf("expected 2 nodes, found %d", len(foundNodes))
}
foundOne := false
foundTwo := false
for _, n := range foundNodes {
if n.id.Equals(id1) {
foundOne = true
}
if n.id.Equals(id2) {
foundTwo = true
}
}
if !foundOne {
t.Errorf("did not find node %s", id1.Hex())
}
if !foundTwo {
t.Errorf("did not find node %s", id2.Hex())
}
}
func TestDHT_FindValue(t *testing.T) {
id1 := newRandomBitmap()
id2 := newRandomBitmap()
id3 := newRandomBitmap()
seedIP := "127.0.0.1:21216"
dht1, err := New(&Config{Address: seedIP, NodeID: id1.Hex()})
if err != nil {
t.Fatal(err)
}
go dht1.Start()
defer dht1.Shutdown()
time.Sleep(1 * time.Second)
dht2, err := New(&Config{Address: "127.0.0.1:21217", NodeID: id2.Hex(), SeedNodes: []string{seedIP}})
if err != nil {
t.Fatal(err)
}
go dht2.Start()
defer dht2.Shutdown()
time.Sleep(1 * time.Second) // give dhts a chance to connect
dht3, err := New(&Config{Address: "127.0.0.1:21218", NodeID: id3.Hex(), SeedNodes: []string{seedIP}})
if err != nil {
t.Fatal(err)
}
go dht3.Start()
defer dht3.Shutdown()
time.Sleep(1 * time.Second) // give dhts a chance to connect
nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
dht1.store.Upsert(nodeToFind.id.RawString(), nodeToFind)
foundNodes, found, err := dht3.FindValue(nodeToFind.id)
if err != nil {
t.Fatal(err)
}
if !found {
t.Fatal("node was not found")
}
if len(foundNodes) != 1 {
t.Fatalf("expected one node, found %d", len(foundNodes))
}
if !foundNodes[0].id.Equals(nodeToFind.id) {
t.Fatalf("found node id %s, expected %s", foundNodes[0].id.Hex(), nodeToFind.id.Hex())
}
}

View file

@ -2,6 +2,7 @@ package dht
import (
"encoding/hex"
"strings"
"github.com/lbryio/errors.go"
@ -174,18 +175,21 @@ type Response struct {
}
func (r Response) ArgsDebug() string {
if len(r.FindNodeData) == 0 {
if r.Data != "" {
return r.Data
}
str := "contacts "
if r.FindValueKey != "" {
str += "for " + hex.EncodeToString([]byte(r.FindValueKey))[:8] + " "
str = "value for " + hex.EncodeToString([]byte(r.FindValueKey))[:8] + " "
}
str += "|"
for _, c := range r.FindNodeData {
str += c.Addr().String() + ":" + c.id.HexShort() + ", "
str += c.Addr().String() + ":" + c.id.HexShort() + ","
}
return str[:len(str)-2] // chomp off last ", "
str = strings.TrimRight(str, ",") + "|"
return str
}
func (r Response) MarshalBencode() ([]byte, error) {
@ -235,19 +239,30 @@ func (r *Response) UnmarshalBencode(b []byte) error {
return err
}
var rawContacts bencode.RawMessage
var ok bool
if rawContacts, ok = rawData["contacts"]; !ok {
if contacts, ok := rawData["contacts"]; ok {
err = bencode.DecodeBytes(contacts, &r.FindNodeData)
if err != nil {
return err
}
} else {
for k, v := range rawData {
r.FindValueKey = k
rawContacts = v
var compactNodes [][]byte
err = bencode.DecodeBytes(v, &compactNodes)
if err != nil {
return err
}
for _, compact := range compactNodes {
var uncompactedNode Node
err = uncompactedNode.UnmarshalCompact(compact)
if err != nil {
return err
}
r.FindNodeData = append(r.FindNodeData, uncompactedNode)
}
break
}
}
err = bencode.DecodeBytes(rawContacts, &r.FindNodeData)
if err != nil {
return err
}
}
return nil

View file

@ -2,18 +2,17 @@ package dht
import (
"encoding/hex"
"net"
"reflect"
"strconv"
"strings"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/lyoshenka/bencode"
log "github.com/sirupsen/logrus"
)
func TestBencodeDecodeStoreArgs(t *testing.T) {
log.SetLevel(log.DebugLevel)
blobHash := "3214D6C2F77FCB5E8D5FC07EDAFBA614F031CE8B2EAB49F924F8143F6DFBADE048D918710072FB98AB1B52B58F4E1468"
lbryID := "7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B"
port := hex.EncodeToString([]byte("3333"))
@ -70,6 +69,72 @@ func TestBencodeDecodeStoreArgs(t *testing.T) {
t.Error(err)
} else if !reflect.DeepEqual(reencoded, data) {
t.Error("reencoded data does not match original")
//spew.Dump(reencoded, data)
spew.Dump(reencoded, data)
}
}
func TestBencodeFindNodesResponse(t *testing.T) {
res := Response{
ID: newMessageID(),
NodeID: newRandomBitmap().RawString(),
FindNodeData: []Node{
{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
{id: newRandomBitmap(), ip: net.IPv4(4, 3, 2, 1).To4(), port: 8765},
},
}
encoded, err := bencode.EncodeBytes(res)
if err != nil {
t.Fatal(err)
}
var res2 Response
err = bencode.DecodeBytes(encoded, &res2)
if err != nil {
t.Fatal(err)
}
compareResponses(t, res, res2)
}
func TestBencodeFindValueResponse(t *testing.T) {
res := Response{
ID: newMessageID(),
NodeID: newRandomBitmap().RawString(),
FindValueKey: newRandomBitmap().RawString(),
FindNodeData: []Node{
{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
},
}
encoded, err := bencode.EncodeBytes(res)
if err != nil {
t.Fatal(err)
}
var res2 Response
err = bencode.DecodeBytes(encoded, &res2)
if err != nil {
t.Fatal(err)
}
compareResponses(t, res, res2)
}
func compareResponses(t *testing.T, res, res2 Response) {
if res.ID != res2.ID {
t.Errorf("expected ID %s, got %s", res.ID, res2.ID)
}
if res.NodeID != res2.NodeID {
t.Errorf("expected NodeID %s, got %s", res.NodeID, res2.NodeID)
}
if res.Data != res2.Data {
t.Errorf("expected Data %s, got %s", res.Data, res2.Data)
}
if res.FindValueKey != res2.FindValueKey {
t.Errorf("expected FindValueKey %s, got %s", res.FindValueKey, res2.FindValueKey)
}
if !reflect.DeepEqual(res.FindNodeData, res2.FindNodeData) {
t.Errorf("expected FindNodeData %s, got %s", spew.Sdump(res.FindNodeData), spew.Sdump(res2.FindNodeData))
}
}

View file

@ -49,7 +49,7 @@ func (n *Node) UnmarshalCompact(b []byte) error {
if len(b) != compactNodeInfoLength {
return errors.Err("invalid compact length")
}
n.ip = net.IPv4(b[0], b[1], b[2], b[3])
n.ip = net.IPv4(b[0], b[1], b[2], b[3]).To4()
n.port = int(uint16(b[5]) | uint16(b[4])<<8)
n.id = newBitmapFromBytes(b[6:])
return nil

View file

@ -23,21 +23,20 @@ func newMessageID() string {
return string(buf)
}
// handlePacke handles packets received from udp.
// handlePacket handles packets received from udp.
func handlePacket(dht *DHT, pkt packet) {
//log.Infof("[%s] Received message from %s:%s : %s\n", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), hex.EncodeToString(pkt.data))
//log.Debugf("[%s] Received message from %s:%s (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), len(pkt.data), hex.EncodeToString(pkt.data))
var data map[string]interface{}
err := bencode.DecodeBytes(pkt.data, &data)
if err != nil {
log.Errorf("error decoding data: %s", err)
log.Errorf(hex.EncodeToString(pkt.data))
log.Errorf("[%s] error decoding data: %s: (%d bytes) %s", dht.node.id.HexShort(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
return
}
msgType, ok := data[headerTypeField]
if !ok {
log.Errorf("decoded data has no message type: %s", data)
log.Errorf("[%s] decoded data has no message type: %s", dht.node.id.HexShort(), spew.Sdump(data))
return
}
@ -73,7 +72,7 @@ func handlePacket(dht *DHT, pkt packet) {
handleError(dht, pkt.raddr, e)
default:
log.Errorf("Invalid message type: %s", msgType)
log.Errorf("[%s] invalid message type: %s", dht.node.id.HexShort(), msgType)
return
}
}
@ -170,18 +169,20 @@ func handleError(dht *DHT, addr *net.UDPAddr, e Error) {
// send sends data to a udp address
func send(dht *DHT, addr *net.UDPAddr, data Message) error {
if req, ok := data.(Request); ok {
log.Debugf("[%s] query %s: sending request to %s : %s(%s)", dht.node.id.HexShort(), hex.EncodeToString([]byte(req.ID))[:8], addr.String(), req.Method, argsToString(req.Args))
} else if res, ok := data.(Response); ok {
log.Debugf("[%s] query %s: sending response to %s : %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(res.ID))[:8], addr.String(), res.ArgsDebug())
} else {
log.Debugf("[%s] %s", dht.node.id.HexShort(), spew.Sdump(data))
}
encoded, err := bencode.EncodeBytes(data)
if err != nil {
return err
}
//log.Infof("Encoded: %s", string(encoded))
if req, ok := data.(Request); ok {
log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)",
dht.node.id.HexShort(), hex.EncodeToString([]byte(req.ID))[:8], addr.String(), len(encoded), req.Method, argsToString(req.Args))
} else if res, ok := data.(Response); ok {
log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s",
dht.node.id.HexShort(), hex.EncodeToString([]byte(res.ID))[:8], addr.String(), len(encoded), res.ArgsDebug())
} else {
log.Debugf("[%s] (%d bytes) %s", dht.node.id.HexShort(), len(encoded), spew.Sdump(data))
}
dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))

View file

@ -7,10 +7,22 @@ import (
"testing"
"time"
"github.com/lbryio/errors.go"
"github.com/lyoshenka/bencode"
log "github.com/sirupsen/logrus"
)
type timeoutErr struct {
error
}
func (t timeoutErr) Timeout() bool {
return true
}
func (t timeoutErr) Temporary() bool {
return true
}
type testUDPPacket struct {
data []byte
addr *net.UDPAddr
@ -20,6 +32,8 @@ type testUDPConn struct {
addr *net.UDPAddr
toRead chan testUDPPacket
writes chan testUDPPacket
readDeadline time.Time
}
func newTestUDPConn(addr string) *testUDPConn {
@ -39,12 +53,17 @@ func newTestUDPConn(addr string) *testUDPConn {
}
func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
var timeoutCh <-chan time.Time
if !t.readDeadline.IsZero() {
timeoutCh = time.After(t.readDeadline.Sub(time.Now()))
}
select {
case packet := <-t.toRead:
n := copy(b, packet.data)
return n, packet.addr, nil
//default:
// return 0, nil, nil
case <-timeoutCh:
return 0, nil, timeoutErr{errors.Err("timeout")}
}
}
@ -53,16 +72,22 @@ func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
return len(b), nil
}
func (t testUDPConn) SetReadDeadline(tm time.Time) error {
func (t *testUDPConn) SetReadDeadline(tm time.Time) error {
t.readDeadline = tm
return nil
}
func (t testUDPConn) SetWriteDeadline(tm time.Time) error {
func (t *testUDPConn) SetWriteDeadline(tm time.Time) error {
return nil
}
func (t *testUDPConn) Close() error {
t.toRead = nil
t.writes = nil
return nil
}
func TestPing(t *testing.T) {
log.SetLevel(log.DebugLevel)
dhtNodeID := newRandomBitmap()
testNodeID := newRandomBitmap()

View file

@ -85,7 +85,7 @@ func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Req
for i := 0; i < udpRetry; i++ {
if err := send(tm.dht, trans.node.Addr(), *trans.req); err != nil {
log.Error(err)
log.Errorf("send error: ", err.Error())
continue // try again? return?
}