findNode and findValue implemented

This commit is contained in:
Alex Grintsvayg 2018-04-03 12:14:04 -04:00
parent 24c079a7dd
commit a5ef461fc5
8 changed files with 292 additions and 93 deletions

View file

@ -13,6 +13,11 @@ import (
"github.com/spf13/cast" "github.com/spf13/cast"
) )
func init() {
//log.SetFormatter(&log.TextFormatter{ForceColors: true})
//log.SetLevel(log.DebugLevel)
}
const network = "udp4" const network = "udp4"
const alpha = 3 // this is the constant alpha in the spec const alpha = 3 // this is the constant alpha in the spec
@ -67,6 +72,7 @@ type UDPConn interface {
WriteToUDP([]byte, *net.UDPAddr) (int, error) WriteToUDP([]byte, *net.UDPAddr) (int, error)
SetReadDeadline(time.Time) error SetReadDeadline(time.Time) error
SetWriteDeadline(time.Time) error SetWriteDeadline(time.Time) error
Close() error
} }
// DHT represents a DHT node. // DHT represents a DHT node.
@ -79,6 +85,7 @@ type DHT struct {
store *peerStore store *peerStore
tm *transactionManager tm *transactionManager
stop *stopOnce.Stopper stop *stopOnce.Stopper
stopWG *sync.WaitGroup
} }
// New returns a DHT pointer. If config is nil, then config will be set to the default config. // New returns a DHT pointer. If config is nil, then config will be set to the default config.
@ -120,6 +127,7 @@ func New(config *Config) (*DHT, error) {
packets: make(chan packet), packets: make(chan packet),
store: newPeerStore(), store: newPeerStore(),
stop: stopOnce.New(), stop: stopOnce.New(),
stopWG: &sync.WaitGroup{},
} }
d.tm = newTransactionManager(d) d.tm = newTransactionManager(d)
return d, nil return d, nil
@ -127,8 +135,7 @@ func New(config *Config) (*DHT, error) {
// init initializes global variables. // init initializes global variables.
func (dht *DHT) init() error { func (dht *DHT) init() error {
log.Info("Initializing DHT on " + dht.conf.Address) log.Debugf("Initializing DHT on %s (node id %s)", dht.conf.Address, dht.node.id.HexShort())
log.Infof("Node ID is %s", dht.node.id.Hex())
listener, err := net.ListenPacket(network, dht.conf.Address) listener, err := net.ListenPacket(network, dht.conf.Address)
if err != nil { if err != nil {
@ -146,7 +153,11 @@ func (dht *DHT) init() error {
// listen receives message from udp. // listen receives message from udp.
func (dht *DHT) listen() { func (dht *DHT) listen() {
dht.stopWG.Add(1)
defer dht.stopWG.Done()
buf := make([]byte, 8192) buf := make([]byte, 8192)
for { for {
select { select {
case <-dht.stop.Chan(): case <-dht.stop.Chan():
@ -154,8 +165,7 @@ func (dht *DHT) listen() {
default: default:
} }
dht.conn.SetReadDeadline(time.Now().Add(2 * time.Second)) // need this to periodically check shutdown chan dht.conn.SetReadDeadline(time.Now().Add(1 * time.Second)) // need this to periodically check shutdown chan
n, raddr, err := dht.conn.ReadFromUDP(buf) n, raddr, err := dht.conn.ReadFromUDP(buf)
if err != nil { if err != nil {
if e, ok := err.(net.Error); !ok || !e.Timeout() { if e, ok := err.(net.Error); !ok || !e.Timeout() {
@ -167,12 +177,16 @@ func (dht *DHT) listen() {
continue continue
} }
dht.packets <- packet{data: buf[:n], raddr: raddr} data := make([]byte, n)
copy(data, buf[:n]) // slices use the same underlying array, so we need a new one for each packet
dht.packets <- packet{data: data, raddr: raddr}
} }
} }
// join makes current node join the dht network. // join makes current node join the dht network.
func (dht *DHT) join() { func (dht *DHT) join() {
log.Debugf("[%s] joining network", dht.node.id.HexShort())
// get real node IDs and add them to the routing table // get real node IDs and add them to the routing table
for _, addr := range dht.conf.SeedNodes { for _, addr := range dht.conf.SeedNodes {
raddr, err := net.ResolveUDPAddr(network, addr) raddr, err := net.ResolveUDPAddr(network, addr)
@ -191,11 +205,14 @@ func (dht *DHT) join() {
// now call iterativeFind on yourself // now call iterativeFind on yourself
_, err := dht.FindNodes(dht.node.id) _, err := dht.FindNodes(dht.node.id)
if err != nil { if err != nil {
log.Error(err) log.Errorf("[%s] join: %s", dht.node.id.HexShort(), err.Error())
} }
} }
func (dht *DHT) runHandler() { func (dht *DHT) runHandler() {
dht.stopWG.Add(1)
defer dht.stopWG.Done()
var pkt packet var pkt packet
for { for {
@ -209,10 +226,11 @@ func (dht *DHT) runHandler() {
} }
// Start starts the dht // Start starts the dht
func (dht *DHT) Start() error { func (dht *DHT) Start() {
err := dht.init() err := dht.init()
if err != nil { if err != nil {
return err log.Error(err)
return
} }
go dht.listen() go dht.listen()
@ -220,13 +238,15 @@ func (dht *DHT) Start() error {
dht.join() dht.join()
log.Infof("[%s] DHT ready", dht.node.id.HexShort()) log.Infof("[%s] DHT ready", dht.node.id.HexShort())
return nil
} }
// Shutdown shuts down the dht // Shutdown shuts down the dht
func (dht *DHT) Shutdown() { func (dht *DHT) Shutdown() {
log.Infof("[%s] DHT shutting down", dht.node.id.HexShort()) log.Debugf("[%s] DHT shutting down", dht.node.id.HexShort())
dht.stop.Stop() dht.stop.Stop()
dht.stopWG.Wait()
dht.conn.Close()
log.Infof("[%s] DHT stopped", dht.node.id.HexShort())
} }
func printState(dht *DHT) { func printState(dht *DHT) {
@ -271,10 +291,8 @@ type nodeFinder struct {
activeNodesMutex *sync.Mutex activeNodesMutex *sync.Mutex
activeNodes []Node activeNodes []Node
shortlistMutex *sync.Mutex shortlistContactedMutex *sync.Mutex
shortlist []Node shortlist []Node
contactedMutex *sync.RWMutex
contacted map[bitmap]bool contacted map[bitmap]bool
} }
@ -290,8 +308,7 @@ func newNodeFinder(dht *DHT, target bitmap, findValue bool) *nodeFinder {
findValue: findValue, findValue: findValue,
findValueMutex: &sync.Mutex{}, findValueMutex: &sync.Mutex{},
activeNodesMutex: &sync.Mutex{}, activeNodesMutex: &sync.Mutex{},
contactedMutex: &sync.RWMutex{}, shortlistContactedMutex: &sync.Mutex{},
shortlistMutex: &sync.Mutex{},
contacted: make(map[bitmap]bool), contacted: make(map[bitmap]bool),
done: stopOnce.New(), done: stopOnce.New(),
} }
@ -341,7 +358,7 @@ func (nf *nodeFinder) iterationWorker(num int) {
maybeNode := nf.popFromShortlist() maybeNode := nf.popFromShortlist()
if maybeNode == nil { if maybeNode == nil {
// TODO: block if there are pending requests out from other workers. there may be more shortlist values coming // TODO: block if there are pending requests out from other workers. there may be more shortlist values coming
log.Debugf("[%s] no more nodes in short list", nf.dht.node.id.HexShort()) log.Debugf("[%s] no more nodes in shortlist", nf.dht.node.id.HexShort())
return return
} }
node := *maybeNode node := *maybeNode
@ -382,7 +399,6 @@ func (nf *nodeFinder) iterationWorker(num int) {
} else { } else {
log.Debugf("[%s] worker %d: got more contacts", nf.dht.node.id.HexShort(), num) log.Debugf("[%s] worker %d: got more contacts", nf.dht.node.id.HexShort(), num)
nf.insertIntoActiveList(node) nf.insertIntoActiveList(node)
nf.markContacted(node)
nf.appendNewToShortlist(res.FindNodeData) nf.appendNewToShortlist(res.FindNodeData)
} }
@ -394,39 +410,32 @@ func (nf *nodeFinder) iterationWorker(num int) {
} }
} }
func (nf *nodeFinder) filterContacted(nodes []Node) []Node {
nf.contactedMutex.RLock()
defer nf.contactedMutex.RUnlock()
filtered := []Node{}
for _, n := range nodes {
if ok := nf.contacted[n.id]; !ok {
filtered = append(filtered, n)
}
}
return filtered
}
func (nf *nodeFinder) markContacted(node Node) {
nf.contactedMutex.Lock()
defer nf.contactedMutex.Unlock()
nf.contacted[node.id] = true
}
func (nf *nodeFinder) appendNewToShortlist(nodes []Node) { func (nf *nodeFinder) appendNewToShortlist(nodes []Node) {
nf.shortlistMutex.Lock() nf.shortlistContactedMutex.Lock()
defer nf.shortlistMutex.Unlock() defer nf.shortlistContactedMutex.Unlock()
nf.shortlist = append(nf.shortlist, nf.filterContacted(nodes)...)
notContacted := []Node{}
for _, n := range nodes {
if _, ok := nf.contacted[n.id]; !ok {
notContacted = append(notContacted, n)
}
}
nf.shortlist = append(nf.shortlist, notContacted...)
sortNodesInPlace(nf.shortlist, nf.target) sortNodesInPlace(nf.shortlist, nf.target)
} }
func (nf *nodeFinder) popFromShortlist() *Node { func (nf *nodeFinder) popFromShortlist() *Node {
nf.shortlistMutex.Lock() nf.shortlistContactedMutex.Lock()
defer nf.shortlistMutex.Unlock() defer nf.shortlistContactedMutex.Unlock()
if len(nf.shortlist) == 0 { if len(nf.shortlist) == 0 {
return nil return nil
} }
first := nf.shortlist[0] first := nf.shortlist[0]
nf.shortlist = nf.shortlist[1:] nf.shortlist = nf.shortlist[1:]
nf.contacted[first.id] = true
return &first return &first
} }
@ -448,7 +457,6 @@ func (nf *nodeFinder) insertIntoActiveList(node Node) {
func (nf *nodeFinder) isSearchFinished() bool { func (nf *nodeFinder) isSearchFinished() bool {
if nf.findValue && len(nf.findValueResult) > 0 { if nf.findValue && len(nf.findValueResult) > 0 {
// if we have a result, always break
return true return true
} }
@ -458,11 +466,10 @@ func (nf *nodeFinder) isSearchFinished() bool {
default: default:
} }
nf.shortlistMutex.Lock() nf.shortlistContactedMutex.Lock()
defer nf.shortlistMutex.Unlock() defer nf.shortlistContactedMutex.Unlock()
if len(nf.shortlist) == 0 { if len(nf.shortlist) == 0 {
// no more nodes to contact
return true return true
} }

View file

@ -1,6 +1,7 @@
package dht package dht
import ( import (
"net"
"testing" "testing"
"time" "time"
@ -8,19 +9,18 @@ import (
) )
func TestDHT_FindNodes(t *testing.T) { func TestDHT_FindNodes(t *testing.T) {
//log.SetLevel(log.DebugLevel)
id1 := newRandomBitmap() id1 := newRandomBitmap()
id2 := newRandomBitmap() id2 := newRandomBitmap()
id3 := newRandomBitmap() id3 := newRandomBitmap()
seedIP := "127.0.0.1:21216" seedIP := "127.0.0.1:21216"
dht, err := New(&Config{Address: seedIP, NodeID: id1.Hex()}) dht1, err := New(&Config{Address: seedIP, NodeID: id1.Hex()})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
go dht.Start() go dht1.Start()
defer dht1.Shutdown()
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
@ -29,6 +29,7 @@ func TestDHT_FindNodes(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
go dht2.Start() go dht2.Start()
defer dht2.Shutdown()
time.Sleep(1 * time.Second) // give dhts a chance to connect time.Sleep(1 * time.Second) // give dhts a chance to connect
@ -37,8 +38,93 @@ func TestDHT_FindNodes(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
go dht3.Start() go dht3.Start()
defer dht3.Shutdown()
time.Sleep(1 * time.Second) // give dhts a chance to connect time.Sleep(1 * time.Second) // give dhts a chance to connect
spew.Dump(dht3.FindNodes(id2)) foundNodes, err := dht3.FindNodes(id2)
if err != nil {
t.Fatal(err)
}
spew.Dump(foundNodes)
if len(foundNodes) != 2 {
t.Errorf("expected 2 nodes, found %d", len(foundNodes))
}
foundOne := false
foundTwo := false
for _, n := range foundNodes {
if n.id.Equals(id1) {
foundOne = true
}
if n.id.Equals(id2) {
foundTwo = true
}
}
if !foundOne {
t.Errorf("did not find node %s", id1.Hex())
}
if !foundTwo {
t.Errorf("did not find node %s", id2.Hex())
}
}
func TestDHT_FindValue(t *testing.T) {
id1 := newRandomBitmap()
id2 := newRandomBitmap()
id3 := newRandomBitmap()
seedIP := "127.0.0.1:21216"
dht1, err := New(&Config{Address: seedIP, NodeID: id1.Hex()})
if err != nil {
t.Fatal(err)
}
go dht1.Start()
defer dht1.Shutdown()
time.Sleep(1 * time.Second)
dht2, err := New(&Config{Address: "127.0.0.1:21217", NodeID: id2.Hex(), SeedNodes: []string{seedIP}})
if err != nil {
t.Fatal(err)
}
go dht2.Start()
defer dht2.Shutdown()
time.Sleep(1 * time.Second) // give dhts a chance to connect
dht3, err := New(&Config{Address: "127.0.0.1:21218", NodeID: id3.Hex(), SeedNodes: []string{seedIP}})
if err != nil {
t.Fatal(err)
}
go dht3.Start()
defer dht3.Shutdown()
time.Sleep(1 * time.Second) // give dhts a chance to connect
nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
dht1.store.Upsert(nodeToFind.id.RawString(), nodeToFind)
foundNodes, found, err := dht3.FindValue(nodeToFind.id)
if err != nil {
t.Fatal(err)
}
if !found {
t.Fatal("node was not found")
}
if len(foundNodes) != 1 {
t.Fatalf("expected one node, found %d", len(foundNodes))
}
if !foundNodes[0].id.Equals(nodeToFind.id) {
t.Fatalf("found node id %s, expected %s", foundNodes[0].id.Hex(), nodeToFind.id.Hex())
}
} }

View file

@ -2,6 +2,7 @@ package dht
import ( import (
"encoding/hex" "encoding/hex"
"strings"
"github.com/lbryio/errors.go" "github.com/lbryio/errors.go"
@ -174,18 +175,21 @@ type Response struct {
} }
func (r Response) ArgsDebug() string { func (r Response) ArgsDebug() string {
if len(r.FindNodeData) == 0 { if r.Data != "" {
return r.Data return r.Data
} }
str := "contacts " str := "contacts "
if r.FindValueKey != "" { if r.FindValueKey != "" {
str += "for " + hex.EncodeToString([]byte(r.FindValueKey))[:8] + " " str = "value for " + hex.EncodeToString([]byte(r.FindValueKey))[:8] + " "
} }
str += "|"
for _, c := range r.FindNodeData { for _, c := range r.FindNodeData {
str += c.Addr().String() + ":" + c.id.HexShort() + ", " str += c.Addr().String() + ":" + c.id.HexShort() + ","
} }
return str[:len(str)-2] // chomp off last ", " str = strings.TrimRight(str, ",") + "|"
return str
} }
func (r Response) MarshalBencode() ([]byte, error) { func (r Response) MarshalBencode() ([]byte, error) {
@ -235,19 +239,30 @@ func (r *Response) UnmarshalBencode(b []byte) error {
return err return err
} }
var rawContacts bencode.RawMessage if contacts, ok := rawData["contacts"]; ok {
var ok bool err = bencode.DecodeBytes(contacts, &r.FindNodeData)
if rawContacts, ok = rawData["contacts"]; !ok {
for k, v := range rawData {
r.FindValueKey = k
rawContacts = v
break
}
}
err = bencode.DecodeBytes(rawContacts, &r.FindNodeData)
if err != nil { if err != nil {
return err return err
} }
} else {
for k, v := range rawData {
r.FindValueKey = k
var compactNodes [][]byte
err = bencode.DecodeBytes(v, &compactNodes)
if err != nil {
return err
}
for _, compact := range compactNodes {
var uncompactedNode Node
err = uncompactedNode.UnmarshalCompact(compact)
if err != nil {
return err
}
r.FindNodeData = append(r.FindNodeData, uncompactedNode)
}
break
}
}
} }
return nil return nil

View file

@ -2,18 +2,17 @@ package dht
import ( import (
"encoding/hex" "encoding/hex"
"net"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"github.com/davecgh/go-spew/spew"
"github.com/lyoshenka/bencode" "github.com/lyoshenka/bencode"
log "github.com/sirupsen/logrus"
) )
func TestBencodeDecodeStoreArgs(t *testing.T) { func TestBencodeDecodeStoreArgs(t *testing.T) {
log.SetLevel(log.DebugLevel)
blobHash := "3214D6C2F77FCB5E8D5FC07EDAFBA614F031CE8B2EAB49F924F8143F6DFBADE048D918710072FB98AB1B52B58F4E1468" blobHash := "3214D6C2F77FCB5E8D5FC07EDAFBA614F031CE8B2EAB49F924F8143F6DFBADE048D918710072FB98AB1B52B58F4E1468"
lbryID := "7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B" lbryID := "7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B"
port := hex.EncodeToString([]byte("3333")) port := hex.EncodeToString([]byte("3333"))
@ -70,6 +69,72 @@ func TestBencodeDecodeStoreArgs(t *testing.T) {
t.Error(err) t.Error(err)
} else if !reflect.DeepEqual(reencoded, data) { } else if !reflect.DeepEqual(reencoded, data) {
t.Error("reencoded data does not match original") t.Error("reencoded data does not match original")
//spew.Dump(reencoded, data) spew.Dump(reencoded, data)
}
}
func TestBencodeFindNodesResponse(t *testing.T) {
res := Response{
ID: newMessageID(),
NodeID: newRandomBitmap().RawString(),
FindNodeData: []Node{
{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
{id: newRandomBitmap(), ip: net.IPv4(4, 3, 2, 1).To4(), port: 8765},
},
}
encoded, err := bencode.EncodeBytes(res)
if err != nil {
t.Fatal(err)
}
var res2 Response
err = bencode.DecodeBytes(encoded, &res2)
if err != nil {
t.Fatal(err)
}
compareResponses(t, res, res2)
}
func TestBencodeFindValueResponse(t *testing.T) {
res := Response{
ID: newMessageID(),
NodeID: newRandomBitmap().RawString(),
FindValueKey: newRandomBitmap().RawString(),
FindNodeData: []Node{
{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
},
}
encoded, err := bencode.EncodeBytes(res)
if err != nil {
t.Fatal(err)
}
var res2 Response
err = bencode.DecodeBytes(encoded, &res2)
if err != nil {
t.Fatal(err)
}
compareResponses(t, res, res2)
}
func compareResponses(t *testing.T, res, res2 Response) {
if res.ID != res2.ID {
t.Errorf("expected ID %s, got %s", res.ID, res2.ID)
}
if res.NodeID != res2.NodeID {
t.Errorf("expected NodeID %s, got %s", res.NodeID, res2.NodeID)
}
if res.Data != res2.Data {
t.Errorf("expected Data %s, got %s", res.Data, res2.Data)
}
if res.FindValueKey != res2.FindValueKey {
t.Errorf("expected FindValueKey %s, got %s", res.FindValueKey, res2.FindValueKey)
}
if !reflect.DeepEqual(res.FindNodeData, res2.FindNodeData) {
t.Errorf("expected FindNodeData %s, got %s", spew.Sdump(res.FindNodeData), spew.Sdump(res2.FindNodeData))
} }
} }

View file

@ -49,7 +49,7 @@ func (n *Node) UnmarshalCompact(b []byte) error {
if len(b) != compactNodeInfoLength { if len(b) != compactNodeInfoLength {
return errors.Err("invalid compact length") return errors.Err("invalid compact length")
} }
n.ip = net.IPv4(b[0], b[1], b[2], b[3]) n.ip = net.IPv4(b[0], b[1], b[2], b[3]).To4()
n.port = int(uint16(b[5]) | uint16(b[4])<<8) n.port = int(uint16(b[5]) | uint16(b[4])<<8)
n.id = newBitmapFromBytes(b[6:]) n.id = newBitmapFromBytes(b[6:])
return nil return nil

View file

@ -23,21 +23,20 @@ func newMessageID() string {
return string(buf) return string(buf)
} }
// handlePacke handles packets received from udp. // handlePacket handles packets received from udp.
func handlePacket(dht *DHT, pkt packet) { func handlePacket(dht *DHT, pkt packet) {
//log.Infof("[%s] Received message from %s:%s : %s\n", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), hex.EncodeToString(pkt.data)) //log.Debugf("[%s] Received message from %s:%s (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), len(pkt.data), hex.EncodeToString(pkt.data))
var data map[string]interface{} var data map[string]interface{}
err := bencode.DecodeBytes(pkt.data, &data) err := bencode.DecodeBytes(pkt.data, &data)
if err != nil { if err != nil {
log.Errorf("error decoding data: %s", err) log.Errorf("[%s] error decoding data: %s: (%d bytes) %s", dht.node.id.HexShort(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
log.Errorf(hex.EncodeToString(pkt.data))
return return
} }
msgType, ok := data[headerTypeField] msgType, ok := data[headerTypeField]
if !ok { if !ok {
log.Errorf("decoded data has no message type: %s", data) log.Errorf("[%s] decoded data has no message type: %s", dht.node.id.HexShort(), spew.Sdump(data))
return return
} }
@ -73,7 +72,7 @@ func handlePacket(dht *DHT, pkt packet) {
handleError(dht, pkt.raddr, e) handleError(dht, pkt.raddr, e)
default: default:
log.Errorf("Invalid message type: %s", msgType) log.Errorf("[%s] invalid message type: %s", dht.node.id.HexShort(), msgType)
return return
} }
} }
@ -170,18 +169,20 @@ func handleError(dht *DHT, addr *net.UDPAddr, e Error) {
// send sends data to a udp address // send sends data to a udp address
func send(dht *DHT, addr *net.UDPAddr, data Message) error { func send(dht *DHT, addr *net.UDPAddr, data Message) error {
if req, ok := data.(Request); ok {
log.Debugf("[%s] query %s: sending request to %s : %s(%s)", dht.node.id.HexShort(), hex.EncodeToString([]byte(req.ID))[:8], addr.String(), req.Method, argsToString(req.Args))
} else if res, ok := data.(Response); ok {
log.Debugf("[%s] query %s: sending response to %s : %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(res.ID))[:8], addr.String(), res.ArgsDebug())
} else {
log.Debugf("[%s] %s", dht.node.id.HexShort(), spew.Sdump(data))
}
encoded, err := bencode.EncodeBytes(data) encoded, err := bencode.EncodeBytes(data)
if err != nil { if err != nil {
return err return err
} }
//log.Infof("Encoded: %s", string(encoded))
if req, ok := data.(Request); ok {
log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)",
dht.node.id.HexShort(), hex.EncodeToString([]byte(req.ID))[:8], addr.String(), len(encoded), req.Method, argsToString(req.Args))
} else if res, ok := data.(Response); ok {
log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s",
dht.node.id.HexShort(), hex.EncodeToString([]byte(res.ID))[:8], addr.String(), len(encoded), res.ArgsDebug())
} else {
log.Debugf("[%s] (%d bytes) %s", dht.node.id.HexShort(), len(encoded), spew.Sdump(data))
}
dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15)) dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))

View file

@ -7,10 +7,22 @@ import (
"testing" "testing"
"time" "time"
"github.com/lbryio/errors.go"
"github.com/lyoshenka/bencode" "github.com/lyoshenka/bencode"
log "github.com/sirupsen/logrus"
) )
type timeoutErr struct {
error
}
func (t timeoutErr) Timeout() bool {
return true
}
func (t timeoutErr) Temporary() bool {
return true
}
type testUDPPacket struct { type testUDPPacket struct {
data []byte data []byte
addr *net.UDPAddr addr *net.UDPAddr
@ -20,6 +32,8 @@ type testUDPConn struct {
addr *net.UDPAddr addr *net.UDPAddr
toRead chan testUDPPacket toRead chan testUDPPacket
writes chan testUDPPacket writes chan testUDPPacket
readDeadline time.Time
} }
func newTestUDPConn(addr string) *testUDPConn { func newTestUDPConn(addr string) *testUDPConn {
@ -39,12 +53,17 @@ func newTestUDPConn(addr string) *testUDPConn {
} }
func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
var timeoutCh <-chan time.Time
if !t.readDeadline.IsZero() {
timeoutCh = time.After(t.readDeadline.Sub(time.Now()))
}
select { select {
case packet := <-t.toRead: case packet := <-t.toRead:
n := copy(b, packet.data) n := copy(b, packet.data)
return n, packet.addr, nil return n, packet.addr, nil
//default: case <-timeoutCh:
// return 0, nil, nil return 0, nil, timeoutErr{errors.Err("timeout")}
} }
} }
@ -53,16 +72,22 @@ func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
return len(b), nil return len(b), nil
} }
func (t testUDPConn) SetReadDeadline(tm time.Time) error { func (t *testUDPConn) SetReadDeadline(tm time.Time) error {
t.readDeadline = tm
return nil return nil
} }
func (t testUDPConn) SetWriteDeadline(tm time.Time) error { func (t *testUDPConn) SetWriteDeadline(tm time.Time) error {
return nil
}
func (t *testUDPConn) Close() error {
t.toRead = nil
t.writes = nil
return nil return nil
} }
func TestPing(t *testing.T) { func TestPing(t *testing.T) {
log.SetLevel(log.DebugLevel)
dhtNodeID := newRandomBitmap() dhtNodeID := newRandomBitmap()
testNodeID := newRandomBitmap() testNodeID := newRandomBitmap()

View file

@ -85,7 +85,7 @@ func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Req
for i := 0; i < udpRetry; i++ { for i := 0; i < udpRetry; i++ {
if err := send(tm.dht, trans.node.Addr(), *trans.req); err != nil { if err := send(tm.dht, trans.node.Addr(), *trans.req); err != nil {
log.Error(err) log.Errorf("send error: ", err.Error())
continue // try again? return? continue // try again? return?
} }