dht iterativefind, some tests

This commit is contained in:
Alex Grintsvayg 2018-03-28 21:05:27 -04:00
parent 05e2d8529a
commit 24c079a7dd
9 changed files with 500 additions and 161 deletions

View file

@ -18,6 +18,10 @@ func (b bitmap) Hex() string {
return hex.EncodeToString(b[0:nodeIDLength])
}
func (b bitmap) HexShort() string {
return hex.EncodeToString(b[0:nodeIDLength])[:8]
}
func (b bitmap) Equals(other bitmap) bool {
for k := range b {
if b[k] != other[k] {

View file

@ -1,10 +1,13 @@
package dht
import (
"context"
"net"
"sync"
"time"
"github.com/lbryio/errors.go"
"github.com/lbryio/lbry.go/stopOnce"
log "github.com/sirupsen/logrus"
"github.com/spf13/cast"
@ -62,6 +65,7 @@ func NewStandardConfig() *Config {
type UDPConn interface {
ReadFromUDP([]byte) (int, *net.UDPAddr, error)
WriteToUDP([]byte, *net.UDPAddr) (int, error)
SetReadDeadline(time.Time) error
SetWriteDeadline(time.Time) error
}
@ -74,6 +78,7 @@ type DHT struct {
packets chan packet
store *peerStore
tm *transactionManager
stop *stopOnce.Stopper
}
// New returns a DHT pointer. If config is nil, then config will be set to the default config.
@ -114,6 +119,7 @@ func New(config *Config) (*DHT, error) {
rt: newRoutingTable(node),
packets: make(chan packet),
store: newPeerStore(),
stop: stopOnce.New(),
}
d.tm = newTransactionManager(d)
return d, nil
@ -140,37 +146,52 @@ func (dht *DHT) init() error {
// listen receives message from udp.
func (dht *DHT) listen() {
go func() {
buf := make([]byte, 8192)
for {
n, raddr, err := dht.conn.ReadFromUDP(buf)
if err != nil {
log.Errorf("udp read error: %v", err)
continue
} else if raddr == nil {
log.Errorf("udp read with no raddr")
continue
}
dht.packets <- packet{data: buf[:n], raddr: raddr}
buf := make([]byte, 8192)
for {
select {
case <-dht.stop.Chan():
return
default:
}
}()
dht.conn.SetReadDeadline(time.Now().Add(2 * time.Second)) // need this to periodically check shutdown chan
n, raddr, err := dht.conn.ReadFromUDP(buf)
if err != nil {
if e, ok := err.(net.Error); !ok || !e.Timeout() {
log.Errorf("udp read error: %v", err)
}
continue
} else if raddr == nil {
log.Errorf("udp read with no raddr")
continue
}
dht.packets <- packet{data: buf[:n], raddr: raddr}
}
}
// join makes current node join the dht network.
func (dht *DHT) join() {
// get real node IDs and add them to the routing table
for _, addr := range dht.conf.SeedNodes {
raddr, err := net.ResolveUDPAddr(network, addr)
if err != nil {
log.Errorln(err)
continue
}
_ = raddr
tmpNode := Node{id: newRandomBitmap(), ip: raddr.IP, port: raddr.Port}
res := dht.tm.Send(tmpNode, &Request{Method: pingMethod})
if res == nil {
log.Errorf("[%s] join: no response from seed node %s", dht.node.id.HexShort(), addr)
}
}
// NOTE: Temporary node has NO node id.
//dht.transactionManager.findNode(
// &node{addr: raddr},
// dht.node.id.RawString(),
//)
// now call iterativeFind on yourself
_, err := dht.FindNodes(dht.node.id)
if err != nil {
log.Error(err)
}
}
@ -181,24 +202,33 @@ func (dht *DHT) runHandler() {
select {
case pkt = <-dht.packets:
handlePacket(dht, pkt)
case <-dht.stop.Chan():
return
}
}
}
// Run starts the dht.
func (dht *DHT) Run() error {
// Start starts the dht
func (dht *DHT) Start() error {
err := dht.init()
if err != nil {
return err
}
dht.listen()
go dht.listen()
go dht.runHandler()
dht.join()
log.Info("DHT ready")
dht.runHandler()
log.Infof("[%s] DHT ready", dht.node.id.HexShort())
return nil
}
// Shutdown shuts down the dht
func (dht *DHT) Shutdown() {
log.Infof("[%s] DHT shutting down", dht.node.id.HexShort())
dht.stop.Stop()
}
func printState(dht *DHT) {
t := time.NewTicker(60 * time.Second)
for {
@ -210,68 +240,239 @@ func printState(dht *DHT) {
}
}
//func (dht *DHT) Get(hash bitmap) ([]Node, error) {
// return iterativeFindNode(dht, hash)
//}
//
//func iterativeFindNode(dht *DHT, hash bitmap) ([]Node, error) {
// shortlist := dht.rt.FindClosest(hash, alpha)
// if len(shortlist) == 0 {
// return nil, errors.Err("no nodes in routing table")
// }
//
// pending := make(chan *Node)
// contacted := make(map[bitmap]bool)
// contactedMutex := &sync.Mutex{}
// closestNodeMutex := &sync.Mutex{}
// closestNode := shortlist[0]
// wg := sync.WaitGroup{}
//
// for i := 0; i < alpha; i++ {
// wg.Add(1)
// go func() {
// defer wg.Done()
// for {
// node, ok := <-pending
// if !ok {
// return
// }
//
// contactedMutex.Lock()
// if _, ok := contacted[node.id]; ok {
// contactedMutex.Unlock()
// continue
// }
// contacted[node.id] = true
// contactedMutex.Unlock()
//
// res := dht.tm.Send(node, &Request{
// NodeID: dht.node.id.RawString(),
// Method: findNodeMethod,
// Args: []string{hash.RawString()},
// })
// if res == nil {
// // remove node from shortlist
// continue
// }
//
// for _, n := range res.FindNodeData {
// pending <- &n
// closestNodeMutex.Lock()
// if n.id.Xor(hash).Less(closestNode.id.Xor(hash)) {
// closestNode = &n
// }
// closestNodeMutex.Unlock()
// }
// }
// }()
// }
//
// for _, n := range shortlist {
// pending <- n
// }
//
// wg.Wait()
//
// return nil, nil
//}
func (dht *DHT) FindNodes(hash bitmap) ([]Node, error) {
nf := newNodeFinder(dht, hash, false)
res, err := nf.Find()
if err != nil {
return nil, err
}
return res.Nodes, nil
}
func (dht *DHT) FindValue(hash bitmap) ([]Node, bool, error) {
nf := newNodeFinder(dht, hash, true)
res, err := nf.Find()
if err != nil {
return nil, false, err
}
return res.Nodes, res.Found, nil
}
type nodeFinder struct {
findValue bool // true if we're using findValue
target bitmap
dht *DHT
done *stopOnce.Stopper
findValueMutex *sync.Mutex
findValueResult []Node
activeNodesMutex *sync.Mutex
activeNodes []Node
shortlistMutex *sync.Mutex
shortlist []Node
contactedMutex *sync.RWMutex
contacted map[bitmap]bool
}
type findNodeResponse struct {
Found bool
Nodes []Node
}
func newNodeFinder(dht *DHT, target bitmap, findValue bool) *nodeFinder {
return &nodeFinder{
dht: dht,
target: target,
findValue: findValue,
findValueMutex: &sync.Mutex{},
activeNodesMutex: &sync.Mutex{},
contactedMutex: &sync.RWMutex{},
shortlistMutex: &sync.Mutex{},
contacted: make(map[bitmap]bool),
done: stopOnce.New(),
}
}
func (nf *nodeFinder) Find() (findNodeResponse, error) {
log.Debugf("[%s] starting an iterative Find() for %s (findValue is %t)", nf.dht.node.id.HexShort(), nf.target.HexShort(), nf.findValue)
nf.appendNewToShortlist(nf.dht.rt.GetClosest(nf.target, alpha))
if len(nf.shortlist) == 0 {
return findNodeResponse{}, errors.Err("no nodes in routing table")
}
wg := &sync.WaitGroup{}
for i := 0; i < alpha; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
nf.iterationWorker(i + 1)
}(i)
}
wg.Wait()
// TODO: what to do if we have less than K active nodes, shortlist is empty, but we
// TODO: have other nodes in our routing table whom we have not contacted. prolly contact them?
result := findNodeResponse{}
if nf.findValue && len(nf.findValueResult) > 0 {
result.Found = true
result.Nodes = nf.findValueResult
} else {
result.Nodes = nf.activeNodes
if len(result.Nodes) > bucketSize {
result.Nodes = result.Nodes[:bucketSize]
}
}
return result, nil
}
func (nf *nodeFinder) iterationWorker(num int) {
log.Debugf("[%s] starting worker %d", nf.dht.node.id.HexShort(), num)
defer func() { log.Debugf("[%s] stopping worker %d", nf.dht.node.id.HexShort(), num) }()
for {
maybeNode := nf.popFromShortlist()
if maybeNode == nil {
// TODO: block if there are pending requests out from other workers. there may be more shortlist values coming
log.Debugf("[%s] no more nodes in short list", nf.dht.node.id.HexShort())
return
}
node := *maybeNode
if node.id.Equals(nf.dht.node.id) {
continue // cannot contact self
}
req := &Request{Args: []string{nf.target.RawString()}}
if nf.findValue {
req.Method = findValueMethod
} else {
req.Method = findNodeMethod
}
log.Debugf("[%s] contacting %s", nf.dht.node.id.HexShort(), node.id.HexShort())
var res *Response
ctx, cancel := context.WithCancel(context.Background())
resCh := nf.dht.tm.SendAsync(ctx, node, req)
select {
case res = <-resCh:
case <-nf.done.Chan():
log.Debugf("[%s] worker %d: canceled", nf.dht.node.id.HexShort(), num)
cancel()
return
}
if res == nil {
// nothing to do, response timed out
} else if nf.findValue && res.FindValueKey != "" {
log.Debugf("[%s] worker %d: got value", nf.dht.node.id.HexShort(), num)
nf.findValueMutex.Lock()
nf.findValueResult = res.FindNodeData
nf.findValueMutex.Unlock()
nf.done.Stop()
return
} else {
log.Debugf("[%s] worker %d: got more contacts", nf.dht.node.id.HexShort(), num)
nf.insertIntoActiveList(node)
nf.markContacted(node)
nf.appendNewToShortlist(res.FindNodeData)
}
if nf.isSearchFinished() {
log.Debugf("[%s] worker %d: search is finished", nf.dht.node.id.HexShort(), num)
nf.done.Stop()
return
}
}
}
func (nf *nodeFinder) filterContacted(nodes []Node) []Node {
nf.contactedMutex.RLock()
defer nf.contactedMutex.RUnlock()
filtered := []Node{}
for _, n := range nodes {
if ok := nf.contacted[n.id]; !ok {
filtered = append(filtered, n)
}
}
return filtered
}
func (nf *nodeFinder) markContacted(node Node) {
nf.contactedMutex.Lock()
defer nf.contactedMutex.Unlock()
nf.contacted[node.id] = true
}
func (nf *nodeFinder) appendNewToShortlist(nodes []Node) {
nf.shortlistMutex.Lock()
defer nf.shortlistMutex.Unlock()
nf.shortlist = append(nf.shortlist, nf.filterContacted(nodes)...)
sortNodesInPlace(nf.shortlist, nf.target)
}
func (nf *nodeFinder) popFromShortlist() *Node {
nf.shortlistMutex.Lock()
defer nf.shortlistMutex.Unlock()
if len(nf.shortlist) == 0 {
return nil
}
first := nf.shortlist[0]
nf.shortlist = nf.shortlist[1:]
return &first
}
func (nf *nodeFinder) insertIntoActiveList(node Node) {
nf.activeNodesMutex.Lock()
defer nf.activeNodesMutex.Unlock()
inserted := false
for i, n := range nf.activeNodes {
if node.id.Xor(nf.target).Less(n.id.Xor(nf.target)) {
nf.activeNodes = append(nf.activeNodes[:i], append([]Node{node}, nf.activeNodes[i:]...)...)
inserted = true
}
}
if !inserted {
nf.activeNodes = append(nf.activeNodes, node)
}
}
func (nf *nodeFinder) isSearchFinished() bool {
if nf.findValue && len(nf.findValueResult) > 0 {
// if we have a result, always break
return true
}
select {
case <-nf.done.Chan():
return true
default:
}
nf.shortlistMutex.Lock()
defer nf.shortlistMutex.Unlock()
if len(nf.shortlist) == 0 {
// no more nodes to contact
return true
}
nf.activeNodesMutex.Lock()
defer nf.activeNodesMutex.Unlock()
if len(nf.activeNodes) >= bucketSize && nf.activeNodes[bucketSize-1].id.Xor(nf.target).Less(nf.shortlist[0].id.Xor(nf.target)) {
// we have at least K active nodes, and we don't have any closer nodes yet to contact
return true
}
return false
}

44
dht/dht_test.go Normal file
View file

@ -0,0 +1,44 @@
package dht
import (
"testing"
"time"
"github.com/davecgh/go-spew/spew"
)
func TestDHT_FindNodes(t *testing.T) {
//log.SetLevel(log.DebugLevel)
id1 := newRandomBitmap()
id2 := newRandomBitmap()
id3 := newRandomBitmap()
seedIP := "127.0.0.1:21216"
dht, err := New(&Config{Address: seedIP, NodeID: id1.Hex()})
if err != nil {
t.Fatal(err)
}
go dht.Start()
time.Sleep(1 * time.Second)
dht2, err := New(&Config{Address: "127.0.0.1:21217", NodeID: id2.Hex(), SeedNodes: []string{seedIP}})
if err != nil {
t.Fatal(err)
}
go dht2.Start()
time.Sleep(1 * time.Second) // give dhts a chance to connect
dht3, err := New(&Config{Address: "127.0.0.1:21218", NodeID: id3.Hex(), SeedNodes: []string{seedIP}})
if err != nil {
t.Fatal(err)
}
go dht3.Start()
time.Sleep(1 * time.Second) // give dhts a chance to connect
spew.Dump(dht3.FindNodes(id2))
}

View file

@ -183,7 +183,7 @@ func (r Response) ArgsDebug() string {
str += "for " + hex.EncodeToString([]byte(r.FindValueKey))[:8] + " "
}
for _, c := range r.FindNodeData {
str += c.Addr().String() + ":" + c.id.Hex()[:8] + ", "
str += c.Addr().String() + ":" + c.id.HexShort() + ", "
}
return str[:len(str)-2] // chomp off last ", "
}
@ -229,7 +229,22 @@ func (r *Response) UnmarshalBencode(b []byte) error {
err = bencode.DecodeBytes(raw.Data, &r.Data)
if err != nil {
err = bencode.DecodeBytes(raw.Data, r.FindNodeData)
var rawData map[string]bencode.RawMessage
err = bencode.DecodeBytes(raw.Data, &rawData)
if err != nil {
return err
}
var rawContacts bencode.RawMessage
var ok bool
if rawContacts, ok = rawData["contacts"]; !ok {
for k, v := range rawData {
r.FindValueKey = k
rawContacts = v
break
}
}
err = bencode.DecodeBytes(rawContacts, &r.FindNodeData)
if err != nil {
return err
}

View file

@ -94,11 +94,11 @@ func (n *Node) UnmarshalBencode(b []byte) error {
}
type SortedNode struct {
node *Node
node Node
xorDistanceToTarget bitmap
}
type byXorDistance []*SortedNode
type byXorDistance []SortedNode
func (a byXorDistance) Len() int { return len(a) }
func (a byXorDistance) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
@ -128,9 +128,18 @@ func (rt *RoutingTable) BucketInfo() string {
bucketInfo := []string{}
for i, b := range rt.buckets {
count := countInList(b)
count := 0
ids := ""
for curr := b.Front(); curr != nil; curr = curr.Next() {
count++
if ids != "" {
ids += ", "
}
ids += curr.Value.(Node).id.HexShort()
}
if count > 0 {
bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: %d", i, count))
bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: (%d) %s", i, count, ids))
}
}
if len(bucketInfo) == 0 {
@ -139,12 +148,12 @@ func (rt *RoutingTable) BucketInfo() string {
return strings.Join(bucketInfo, "\n")
}
func (rt *RoutingTable) Update(node *Node) {
func (rt *RoutingTable) Update(node Node) {
rt.lock.Lock()
defer rt.lock.Unlock()
bucketNum := bucketFor(rt.node.id, node.id)
bucket := rt.buckets[bucketNum]
element := findInList(bucket, rt.node.id)
element := findInList(bucket, node.id)
if element == nil {
if bucket.Len() >= bucketSize {
// TODO: Ping front node first. Only remove if it does not respond
@ -167,13 +176,19 @@ func (rt *RoutingTable) RemoveByID(id bitmap) {
}
}
func (rt *RoutingTable) FindClosest(target bitmap, limit int) []*Node {
func (rt *RoutingTable) GetClosest(target bitmap, limit int) []Node {
rt.lock.RLock()
defer rt.lock.RUnlock()
var toSort []*SortedNode
var toSort []SortedNode
var bucketNum int
if rt.node.id.Equals(target) {
bucketNum = 0
} else {
bucketNum = bucketFor(rt.node.id, target)
}
bucketNum := bucketFor(rt.node.id, target)
bucket := rt.buckets[bucketNum]
toSort = appendNodes(toSort, bucket.Front(), target)
@ -190,7 +205,7 @@ func (rt *RoutingTable) FindClosest(target bitmap, limit int) []*Node {
sort.Sort(byXorDistance(toSort))
var nodes []*Node
var nodes []Node
for _, c := range toSort {
nodes = append(nodes, c.node)
if len(nodes) >= limit {
@ -203,25 +218,17 @@ func (rt *RoutingTable) FindClosest(target bitmap, limit int) []*Node {
func findInList(bucket *list.List, value bitmap) *list.Element {
for curr := bucket.Front(); curr != nil; curr = curr.Next() {
if curr.Value.(*Node).id.Equals(value) {
if curr.Value.(Node).id.Equals(value) {
return curr
}
}
return nil
}
func countInList(bucket *list.List) int {
count := 0
for curr := bucket.Front(); curr != nil; curr = curr.Next() {
count++
}
return count
}
func appendNodes(nodes []*SortedNode, start *list.Element, target bitmap) []*SortedNode {
func appendNodes(nodes []SortedNode, start *list.Element, target bitmap) []SortedNode {
for curr := start; curr != nil; curr = curr.Next() {
node := curr.Value.(*Node)
nodes = append(nodes, &SortedNode{node, node.id.Xor(target)})
node := curr.Value.(Node)
nodes = append(nodes, SortedNode{node, node.id.Xor(target)})
}
return nodes
}
@ -232,3 +239,17 @@ func bucketFor(id bitmap, target bitmap) int {
}
return numBuckets - 1 - target.Xor(id).PrefixLen()
}
func sortNodesInPlace(nodes []Node, target bitmap) {
toSort := make([]SortedNode, len(nodes))
for i, n := range nodes {
toSort[i] = SortedNode{n, n.id.Xor(target)}
}
sort.Sort(byXorDistance(toSort))
for i, c := range toSort {
nodes[i] = c.node
}
}

View file

@ -6,15 +6,41 @@ import (
"testing"
)
func TestRoutingTable_bucketFor(t *testing.T) {
target := newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
var tests = []struct {
id bitmap
target bitmap
expected int
}{
{newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"), target, 0},
{newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"), target, 1},
{newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"), target, 1},
{newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004"), target, 2},
{newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005"), target, 2},
{newBitmapFromHex("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000f"), target, 3},
{newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010"), target, 4},
{newBitmapFromHex("F00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), target, 383},
{newBitmapFromHex("F0000000000000000000000000000000F0000000000000000000000000F0000000000000000000000000000000000000"), target, 383},
}
for _, tt := range tests {
bucket := bucketFor(tt.id, tt.target)
if bucket != tt.expected {
t.Errorf("bucketFor(%s, %s) => %d, want %d", tt.id.Hex(), tt.target.Hex(), bucket, tt.expected)
}
}
}
func TestRoutingTable(t *testing.T) {
n1 := newBitmapFromHex("FFFFFFFF0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
n2 := newBitmapFromHex("FFFFFFF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
n3 := newBitmapFromHex("111111110000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
rt := newRoutingTable(&Node{n1, net.ParseIP("127.0.0.1"), 8000})
rt.Update(&Node{n2, net.ParseIP("127.0.0.1"), 8001})
rt.Update(&Node{n3, net.ParseIP("127.0.0.1"), 8002})
rt.Update(Node{n2, net.ParseIP("127.0.0.1"), 8001})
rt.Update(Node{n3, net.ParseIP("127.0.0.1"), 8002})
contacts := rt.FindClosest(newBitmapFromHex("222222220000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 1)
contacts := rt.GetClosest(newBitmapFromHex("222222220000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 1)
if len(contacts) != 1 {
t.Fail()
return
@ -23,7 +49,7 @@ func TestRoutingTable(t *testing.T) {
t.Error(contacts[0])
}
contacts = rt.FindClosest(n2, 10)
contacts = rt.GetClosest(n2, 10)
if len(contacts) != 2 {
t.Error(len(contacts))
return

View file

@ -25,12 +25,13 @@ func newMessageID() string {
// handlePacke handles packets received from udp.
func handlePacket(dht *DHT, pkt packet) {
//log.Infof("Received message from %s:%s : %s\n", pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), hex.EncodeToString(pkt.data))
//log.Infof("[%s] Received message from %s:%s : %s\n", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), hex.EncodeToString(pkt.data))
var data map[string]interface{}
err := bencode.DecodeBytes(pkt.data, &data)
if err != nil {
log.Errorf("error decoding data: %s\n%s", err, pkt.data)
log.Errorf("error decoding data: %s", err)
log.Errorf(hex.EncodeToString(pkt.data))
return
}
@ -48,16 +49,17 @@ func handlePacket(dht *DHT, pkt packet) {
log.Errorln(err)
return
}
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args))
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args))
handleRequest(dht, pkt.raddr, request)
case responseType:
response := Response{}
err = bencode.DecodeBytes(pkt.data, &response)
if err != nil {
log.Errorln(err)
return
}
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.Data)
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.ArgsDebug())
handleResponse(dht, pkt.raddr, response)
case errorType:
@ -67,7 +69,7 @@ func handlePacket(dht *DHT, pkt packet) {
ExceptionType: data[headerPayloadField].(string),
Response: getArgs(data[headerArgsField]),
}
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType)
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType)
handleError(dht, pkt.raddr, e)
default:
@ -130,17 +132,17 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
return
}
node := &Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port}
node := Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port}
dht.rt.Update(node)
}
func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) {
nodeID := newBitmapFromString(request.Args[0])
closestNodes := dht.rt.FindClosest(nodeID, bucketSize)
closestNodes := dht.rt.GetClosest(nodeID, bucketSize)
if len(closestNodes) > 0 {
response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))}
for i, n := range closestNodes {
response.FindNodeData[i] = *n
response.FindNodeData[i] = n
}
send(dht, addr, response)
} else {
@ -155,25 +157,25 @@ func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) {
tx.res <- &response
}
node := &Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port}
node := Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port}
dht.rt.Update(node)
}
// handleError handles errors received from udp.
func handleError(dht *DHT, addr *net.UDPAddr, e Error) {
spew.Dump(e)
node := &Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port}
node := Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port}
dht.rt.Update(node)
}
// send sends data to the udp.
// send sends data to a udp address
func send(dht *DHT, addr *net.UDPAddr, data Message) error {
if req, ok := data.(Request); ok {
log.Debugf("[%s] query %s: sending request to %s : %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(req.ID))[:8], addr.String(), req.Method, argsToString(req.Args))
log.Debugf("[%s] query %s: sending request to %s : %s(%s)", dht.node.id.HexShort(), hex.EncodeToString([]byte(req.ID))[:8], addr.String(), req.Method, argsToString(req.Args))
} else if res, ok := data.(Response); ok {
log.Debugf("[%s] query %s: sending response to %s : %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(res.ID))[:8], addr.String(), res.ArgsDebug())
log.Debugf("[%s] query %s: sending response to %s : %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(res.ID))[:8], addr.String(), res.ArgsDebug())
} else {
log.Debugf("[%s] %s", dht.node.id.Hex()[:8], spew.Sdump(data))
log.Debugf("[%s] %s", dht.node.id.HexShort(), spew.Sdump(data))
}
encoded, err := bencode.EncodeBytes(data)
if err != nil {

View file

@ -53,6 +53,10 @@ func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
return len(b), nil
}
func (t testUDPConn) SetReadDeadline(tm time.Time) error {
return nil
}
func (t testUDPConn) SetWriteDeadline(tm time.Time) error {
return nil
}
@ -69,8 +73,9 @@ func TestPing(t *testing.T) {
t.Fatal(err)
}
dht.conn = conn
dht.listen()
go dht.listen()
go dht.runHandler()
defer dht.Shutdown()
messageID := newMessageID()
@ -164,8 +169,9 @@ func TestStore(t *testing.T) {
}
dht.conn = conn
dht.listen()
go dht.listen()
go dht.runHandler()
defer dht.Shutdown()
messageID := newMessageID()
blobHashToStore := newRandomBitmap().RawString()
@ -257,15 +263,16 @@ func TestFindNode(t *testing.T) {
t.Fatal(err)
}
dht.conn = conn
dht.listen()
go dht.listen()
go dht.runHandler()
defer dht.Shutdown()
nodesToInsert := 3
var nodes []Node
for i := 0; i < nodesToInsert; i++ {
n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
nodes = append(nodes, n)
dht.rt.Update(&n)
dht.rt.Update(n)
}
messageID := newMessageID()
@ -334,15 +341,16 @@ func TestFindValueExisting(t *testing.T) {
}
dht.conn = conn
dht.listen()
go dht.listen()
go dht.runHandler()
defer dht.Shutdown()
nodesToInsert := 3
var nodes []Node
for i := 0; i < nodesToInsert; i++ {
n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
nodes = append(nodes, n)
dht.rt.Update(&n)
dht.rt.Update(n)
}
//data, _ := hex.DecodeString("64313a30693065313a3132303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33393a66696e6456616c7565313a346c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565")
@ -418,15 +426,16 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
}
dht.conn = conn
dht.listen()
go dht.listen()
go dht.runHandler()
defer dht.Shutdown()
nodesToInsert := 3
var nodes []Node
for i := 0; i < nodesToInsert; i++ {
n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
nodes = append(nodes, n)
dht.rt.Update(&n)
dht.rt.Update(n)
}
messageID := newMessageID()

View file

@ -1,6 +1,7 @@
package dht
import (
"context"
"net"
"sync"
"time"
@ -10,7 +11,7 @@ import (
// query represents the query data included queried node and query-formed data.
type transaction struct {
node *Node
node Node
req *Request
res chan *Response
}
@ -60,37 +61,53 @@ func (tm *transactionManager) Find(id string, addr *net.UDPAddr) *transaction {
return t
}
func (tm *transactionManager) Send(node *Node, req *Request) *Response {
func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Request) <-chan *Response {
if node.id.Equals(tm.dht.node.id) {
log.Error("sending query to self")
return nil
}
req.ID = newMessageID()
trans := &transaction{
node: node,
req: req,
res: make(chan *Response),
}
ch := make(chan *Response, 1)
tm.insert(trans)
defer tm.delete(trans.req.ID)
go func() {
defer close(ch)
for i := 0; i < udpRetry; i++ {
if err := send(tm.dht, trans.node.Addr(), *trans.req); err != nil {
log.Error(err)
break
req.ID = newMessageID()
req.NodeID = tm.dht.node.id.RawString()
trans := &transaction{
node: node,
req: req,
res: make(chan *Response),
}
select {
case res := <-trans.res:
return res
case <-time.After(udpTimeout):
}
}
tm.insert(trans)
defer tm.delete(trans.req.ID)
tm.dht.rt.RemoveByID(trans.node.id)
return nil
for i := 0; i < udpRetry; i++ {
if err := send(tm.dht, trans.node.Addr(), *trans.req); err != nil {
log.Error(err)
continue // try again? return?
}
select {
case res := <-trans.res:
ch <- res
return
case <-ctx.Done():
return
case <-time.After(udpTimeout):
}
}
// if request timed out each time
tm.dht.rt.RemoveByID(trans.node.id)
}()
return ch
}
func (tm *transactionManager) Send(node Node, req *Request) *Response {
return <-tm.SendAsync(context.Background(), node, req)
}
// Count returns the number of transactions in the manager