more improvements

This commit is contained in:
Alex Grintsvayg 2018-04-05 11:35:57 -04:00
parent d9acce359f
commit c254243716
11 changed files with 302 additions and 246 deletions

View file

@ -1,6 +1,8 @@
package cmd
import (
"time"
"github.com/lbryio/reflector.go/dht"
log "github.com/sirupsen/logrus"
@ -20,7 +22,7 @@ func dhtCmd(cmd *cobra.Command, args []string) {
dht, err := dht.New(&dht.Config{
Address: "127.0.0.1:21216",
SeedNodes: []string{"127.0.0.1:21215"},
PrintState: true,
PrintState: 30 * time.Second,
})
if err != nil {
log.Fatal(err)

View file

@ -1,7 +1,9 @@
package dht
import (
"context"
"net"
"strings"
"sync"
"time"
@ -49,8 +51,8 @@ type Config struct {
SeedNodes []string
// the hex-encoded node id for this node. if string is empty, a random id will be generated
NodeID string
// print the state of the dht every minute
PrintState bool
// print the state of the dht every X time
PrintState time.Duration
}
// NewStandardConfig returns a Config pointer with default values.
@ -76,15 +78,26 @@ type UDPConn interface {
// DHT represents a DHT node.
type DHT struct {
// config
conf *Config
// UDP connection for sending and receiving data
conn UDPConn
// the local dht node
node *Node
rt *RoutingTable
// routing table
rt *routingTable
// channel of incoming packets
packets chan packet
// data store
store *peerStore
// transaction manager
tm *transactionManager
// stopper to shut down DHT
stop *stopOnce.Stopper
// wait group for all the things that need to be stopped when DHT shuts down
stopWG *sync.WaitGroup
// channel is closed when DHT joins network
joined chan struct{}
}
// New returns a DHT pointer. If config is nil, then config will be set to the default config.
@ -127,6 +140,7 @@ func New(config *Config) (*DHT, error) {
store: newPeerStore(),
stop: stopOnce.New(),
stopWG: &sync.WaitGroup{},
joined: make(chan struct{}),
}
d.tm = newTransactionManager(d)
return d, nil
@ -141,8 +155,14 @@ func (dht *DHT) init() error {
dht.conn = listener.(*net.UDPConn)
if dht.conf.PrintState {
go printState(dht)
if dht.conf.PrintState > 0 {
go func() {
t := time.NewTicker(dht.conf.PrintState)
for {
dht.PrintState()
<-t.C
}
}()
}
return nil
@ -153,7 +173,7 @@ func (dht *DHT) listen() {
dht.stopWG.Add(1)
defer dht.stopWG.Done()
buf := make([]byte, 8192)
buf := make([]byte, 16384)
for {
select {
@ -162,7 +182,7 @@ func (dht *DHT) listen() {
default:
}
dht.conn.SetReadDeadline(time.Now().Add(1 * time.Second)) // need this to periodically check shutdown chan
dht.conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) // need this to periodically check shutdown chan
n, raddr, err := dht.conn.ReadFromUDP(buf)
if err != nil {
if e, ok := err.(net.Error); !ok || !e.Timeout() {
@ -204,6 +224,8 @@ func (dht *DHT) join() {
if err != nil {
log.Errorf("[%s] join: %s", dht.node.id.HexShort(), err.Error())
}
close(dht.joined) // if anyone's waiting for join to finish, they'll know its done
}
func (dht *DHT) runHandler() {
@ -234,7 +256,14 @@ func (dht *DHT) Start() {
go dht.runHandler()
dht.join()
log.Infof("[%s] DHT ready on %s", dht.node.id.HexShort(), dht.node.Addr().String())
log.Debugf("[%s] DHT ready on %s", dht.node.id.HexShort(), dht.node.Addr().String())
}
func (dht *DHT) WaitUntilJoined() {
if dht.joined == nil {
panic("dht not initialized")
}
<-dht.joined
}
// Shutdown shuts down the dht
@ -243,7 +272,7 @@ func (dht *DHT) Shutdown() {
dht.stop.Stop()
dht.stopWG.Wait()
dht.conn.Close()
log.Infof("[%s] DHT stopped", dht.node.id.HexShort())
log.Debugf("[%s] DHT stopped", dht.node.id.HexShort())
}
// Get returns the list of nodes that have the blob for the given hash
@ -269,7 +298,7 @@ func (dht *DHT) Announce(hash bitmap) error {
}
for _, node := range res.Nodes {
send(dht, node.Addr(), Request{
dht.tm.SendAsync(context.Background(), node, Request{
Method: storeMethod,
StoreArgs: &storeArgs{
BlobHash: hash,
@ -285,13 +314,18 @@ func (dht *DHT) Announce(hash bitmap) error {
return nil
}
func printState(dht *DHT) {
t := time.NewTicker(60 * time.Second)
for {
func (dht *DHT) PrintState() {
log.Printf("DHT state at %s", time.Now().Format(time.RFC822Z))
log.Printf("Outstanding transactions: %d", dht.tm.Count())
log.Printf("Known nodes: %d", dht.store.CountKnownNodes())
log.Printf("Buckets: \n%s", dht.rt.BucketInfo())
<-t.C
log.Printf("Stored hashes: %d", dht.store.CountStoredHashes())
log.Printf("Buckets:")
for _, line := range strings.Split(dht.rt.BucketInfo(), "\n") {
log.Println(line)
}
}
func printNodeList(list []Node) {
for i, n := range list {
log.Printf("%d) %s %s:%d", i, n.id.HexShort(), n.ip.String(), n.port)
}
}

146
dht/dht_test.go Normal file
View file

@ -0,0 +1,146 @@
package dht
import (
"math/rand"
"net"
"strconv"
"sync"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
func TestNodeFinder_FindNodes(t *testing.T) {
dhts := makeDHT(t, 3)
defer func() {
for i := range dhts {
dhts[i].Shutdown()
}
}()
nf := newNodeFinder(dhts[2], newRandomBitmap(), false)
res, err := nf.Find()
if err != nil {
t.Fatal(err)
}
foundNodes, found := res.Nodes, res.Found
if found {
t.Fatal("something was found, but it should not have been")
}
if len(foundNodes) != 2 {
t.Errorf("expected 2 nodes, found %d", len(foundNodes))
}
foundOne := false
foundTwo := false
for _, n := range foundNodes {
if n.id.Equals(dhts[0].node.id) {
foundOne = true
}
if n.id.Equals(dhts[1].node.id) {
foundTwo = true
}
}
if !foundOne {
t.Errorf("did not find node %s", dhts[0].node.id.Hex())
}
if !foundTwo {
t.Errorf("did not find node %s", dhts[1].node.id.Hex())
}
}
func TestNodeFinder_FindValue(t *testing.T) {
dhts := makeDHT(t, 3)
defer func() {
for i := range dhts {
dhts[i].Shutdown()
}
}()
blobHashToFind := newRandomBitmap()
nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
dhts[0].store.Upsert(blobHashToFind, nodeToFind)
nf := newNodeFinder(dhts[2], blobHashToFind, true)
res, err := nf.Find()
if err != nil {
t.Fatal(err)
}
foundNodes, found := res.Nodes, res.Found
if !found {
t.Fatal("node was not found")
}
if len(foundNodes) != 1 {
t.Fatalf("expected one node, found %d", len(foundNodes))
}
if !foundNodes[0].id.Equals(nodeToFind.id) {
t.Fatalf("found node id %s, expected %s", foundNodes[0].id.Hex(), nodeToFind.id.Hex())
}
}
func TestDHT_LargeDHT(t *testing.T) {
rand.Seed(time.Now().UnixNano())
log.Println("if this takes longer than 20 seconds, its stuck. idk why it gets stuck sometimes, but its a bug.")
nodes := 100
dhts := makeDHT(t, nodes)
defer func() {
for _, d := range dhts {
go d.Shutdown()
}
time.Sleep(1 * time.Second)
}()
wg := &sync.WaitGroup{}
numIDs := nodes / 2
ids := make([]bitmap, numIDs)
for i := 0; i < numIDs; i++ {
ids[i] = newRandomBitmap()
}
for i := 0; i < numIDs; i++ {
go func(i int) {
r := rand.Intn(nodes)
wg.Add(1)
defer wg.Done()
dhts[r].Announce(ids[i])
}(i)
}
wg.Wait()
dhts[1].PrintState()
}
func makeDHT(t *testing.T, numNodes int) []*DHT {
if numNodes < 1 {
return nil
}
ip := "127.0.0.1"
firstPort := 21000
dhts := make([]*DHT, numNodes)
for i := 0; i < numNodes; i++ {
seeds := []string{}
if i > 0 {
seeds = []string{ip + ":" + strconv.Itoa(firstPort)}
}
dht, err := New(&Config{Address: ip + ":" + strconv.Itoa(firstPort+i), NodeID: newRandomBitmap().Hex(), SeedNodes: seeds})
if err != nil {
t.Fatal(err)
}
go dht.Start()
dht.WaitUntilJoined()
dhts[i] = dht
}
return dhts
}

View file

@ -4,6 +4,7 @@ import (
"crypto/rand"
"encoding/hex"
"reflect"
"strconv"
"strings"
"github.com/lbryio/errors.go"
@ -132,7 +133,9 @@ func (r *Request) UnmarshalBencode(b []byte) error {
}
func (r Request) ArgsDebug() string {
if r.Arg != nil {
if r.StoreArgs != nil {
return r.StoreArgs.BlobHash.HexShort() + ", " + r.StoreArgs.Value.LbryID.HexShort() + ":" + strconv.Itoa(r.StoreArgs.Value.Port)
} else if r.Arg != nil {
return r.Arg.HexShort()
}
return ""

View file

@ -3,6 +3,7 @@ package dht
import (
"context"
"sync"
"time"
"github.com/lbryio/errors.go"
"github.com/lbryio/lbry.go/stopOnce"
@ -23,9 +24,9 @@ type nodeFinder struct {
activeNodesMutex *sync.Mutex
activeNodes []Node
shortlistContactedMutex *sync.Mutex
shortlistMutex *sync.Mutex
shortlist []Node
contacted map[bitmap]bool
shortlistAdded map[bitmap]bool
}
type findNodeResponse struct {
@ -40,8 +41,8 @@ func newNodeFinder(dht *DHT, target bitmap, findValue bool) *nodeFinder {
findValue: findValue,
findValueMutex: &sync.Mutex{},
activeNodesMutex: &sync.Mutex{},
shortlistContactedMutex: &sync.Mutex{},
contacted: make(map[bitmap]bool),
shortlistMutex: &sync.Mutex{},
shortlistAdded: make(map[bitmap]bool),
done: stopOnce.New(),
}
}
@ -91,8 +92,8 @@ func (nf *nodeFinder) iterationWorker(num int) {
if maybeNode == nil {
// TODO: block if there are pending requests out from other workers. there may be more shortlist values coming
log.Debugf("[%s] no more nodes in shortlist", nf.dht.node.id.HexShort())
return
}
time.Sleep(10 * time.Millisecond)
} else {
node := *maybeNode
if node.id.Equals(nf.dht.node.id) {
@ -133,6 +134,7 @@ func (nf *nodeFinder) iterationWorker(num int) {
nf.insertIntoActiveList(node)
nf.appendNewToShortlist(res.FindNodeData)
}
}
if nf.isSearchFinished() {
log.Debugf("[%s] worker %d: search is finished", nf.dht.node.id.HexShort(), num)
@ -143,23 +145,22 @@ func (nf *nodeFinder) iterationWorker(num int) {
}
func (nf *nodeFinder) appendNewToShortlist(nodes []Node) {
nf.shortlistContactedMutex.Lock()
defer nf.shortlistContactedMutex.Unlock()
nf.shortlistMutex.Lock()
defer nf.shortlistMutex.Unlock()
notContacted := []Node{}
for _, n := range nodes {
if _, ok := nf.contacted[n.id]; !ok {
notContacted = append(notContacted, n)
if _, ok := nf.shortlistAdded[n.id]; !ok {
nf.shortlist = append(nf.shortlist, n)
nf.shortlistAdded[n.id] = true
}
}
nf.shortlist = append(nf.shortlist, notContacted...)
sortNodesInPlace(nf.shortlist, nf.target)
}
func (nf *nodeFinder) popFromShortlist() *Node {
nf.shortlistContactedMutex.Lock()
defer nf.shortlistContactedMutex.Unlock()
nf.shortlistMutex.Lock()
defer nf.shortlistMutex.Unlock()
if len(nf.shortlist) == 0 {
return nil
@ -167,7 +168,6 @@ func (nf *nodeFinder) popFromShortlist() *Node {
first := nf.shortlist[0]
nf.shortlist = nf.shortlist[1:]
nf.contacted[first.id] = true
return &first
}
@ -180,6 +180,7 @@ func (nf *nodeFinder) insertIntoActiveList(node Node) {
if node.id.Xor(nf.target).Less(n.id.Xor(nf.target)) {
nf.activeNodes = append(nf.activeNodes[:i], append([]Node{node}, nf.activeNodes[i:]...)...)
inserted = true
break
}
}
if !inserted {
@ -198,8 +199,8 @@ func (nf *nodeFinder) isSearchFinished() bool {
default:
}
nf.shortlistContactedMutex.Lock()
defer nf.shortlistContactedMutex.Unlock()
nf.shortlistMutex.Lock()
defer nf.shortlistMutex.Unlock()
if len(nf.shortlist) == 0 {
return true

View file

@ -1,134 +0,0 @@
package dht
import (
"net"
"testing"
"time"
)
func TestNodeFinder_FindNodes(t *testing.T) {
id1 := newRandomBitmap()
id2 := newRandomBitmap()
id3 := newRandomBitmap()
seedIP := "127.0.0.1:21216"
dht1, err := New(&Config{Address: seedIP, NodeID: id1.Hex()})
if err != nil {
t.Fatal(err)
}
go dht1.Start()
defer dht1.Shutdown()
time.Sleep(1 * time.Second) // give dhts a chance to connect
dht2, err := New(&Config{Address: "127.0.0.1:21217", NodeID: id2.Hex(), SeedNodes: []string{seedIP}})
if err != nil {
t.Fatal(err)
}
go dht2.Start()
defer dht2.Shutdown()
time.Sleep(1 * time.Second) // give dhts a chance to connect
dht3, err := New(&Config{Address: "127.0.0.1:21218", NodeID: id3.Hex(), SeedNodes: []string{seedIP}})
if err != nil {
t.Fatal(err)
}
go dht3.Start()
defer dht3.Shutdown()
time.Sleep(1 * time.Second) // give dhts a chance to connect
nf := newNodeFinder(dht3, newRandomBitmap(), false)
res, err := nf.Find()
if err != nil {
t.Fatal(err)
}
foundNodes, found := res.Nodes, res.Found
if found {
t.Fatal("something was found, but it should not have been")
}
if len(foundNodes) != 2 {
t.Errorf("expected 2 nodes, found %d", len(foundNodes))
}
foundOne := false
foundTwo := false
for _, n := range foundNodes {
if n.id.Equals(id1) {
foundOne = true
}
if n.id.Equals(id2) {
foundTwo = true
}
}
if !foundOne {
t.Errorf("did not find node %s", id1.Hex())
}
if !foundTwo {
t.Errorf("did not find node %s", id2.Hex())
}
}
func TestNodeFinder_FindValue(t *testing.T) {
id1 := newRandomBitmap()
id2 := newRandomBitmap()
id3 := newRandomBitmap()
seedIP := "127.0.0.1:21216"
dht1, err := New(&Config{Address: seedIP, NodeID: id1.Hex()})
if err != nil {
t.Fatal(err)
}
go dht1.Start()
defer dht1.Shutdown()
time.Sleep(1 * time.Second)
dht2, err := New(&Config{Address: "127.0.0.1:21217", NodeID: id2.Hex(), SeedNodes: []string{seedIP}})
if err != nil {
t.Fatal(err)
}
go dht2.Start()
defer dht2.Shutdown()
time.Sleep(1 * time.Second) // give dhts a chance to connect
dht3, err := New(&Config{Address: "127.0.0.1:21218", NodeID: id3.Hex(), SeedNodes: []string{seedIP}})
if err != nil {
t.Fatal(err)
}
go dht3.Start()
defer dht3.Shutdown()
time.Sleep(1 * time.Second) // give dhts a chance to connect
blobHashToFind := newRandomBitmap()
nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
dht1.store.Upsert(blobHashToFind, nodeToFind)
nf := newNodeFinder(dht3, blobHashToFind, true)
res, err := nf.Find()
if err != nil {
t.Fatal(err)
}
foundNodes, found := res.Nodes, res.Found
if !found {
t.Fatal("node was not found")
}
if len(foundNodes) != 1 {
t.Fatalf("expected one node, found %d", len(foundNodes))
}
if !foundNodes[0].id.Equals(nodeToFind.id) {
t.Fatalf("found node id %s, expected %s", foundNodes[0].id.Hex(), nodeToFind.id.Hex())
}
}

View file

@ -93,12 +93,12 @@ func (n *Node) UnmarshalBencode(b []byte) error {
return nil
}
type SortedNode struct {
type sortedNode struct {
node Node
xorDistanceToTarget bitmap
}
type byXorDistance []SortedNode
type byXorDistance []sortedNode
func (a byXorDistance) Len() int { return len(a) }
func (a byXorDistance) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
@ -106,14 +106,14 @@ func (a byXorDistance) Less(i, j int) bool {
return a[i].xorDistanceToTarget.Less(a[j].xorDistanceToTarget)
}
type RoutingTable struct {
type routingTable struct {
node Node
buckets [numBuckets]*list.List
lock *sync.RWMutex
}
func newRoutingTable(node *Node) *RoutingTable {
var rt RoutingTable
func newRoutingTable(node *Node) *routingTable {
var rt routingTable
for i := range rt.buckets {
rt.buckets[i] = list.New()
}
@ -122,7 +122,7 @@ func newRoutingTable(node *Node) *RoutingTable {
return &rt
}
func (rt *RoutingTable) BucketInfo() string {
func (rt *routingTable) BucketInfo() string {
rt.lock.RLock()
defer rt.lock.RUnlock()
@ -148,7 +148,7 @@ func (rt *RoutingTable) BucketInfo() string {
return strings.Join(bucketInfo, "\n")
}
func (rt *RoutingTable) Update(node Node) {
func (rt *routingTable) Update(node Node) {
rt.lock.Lock()
defer rt.lock.Unlock()
bucketNum := bucketFor(rt.node.id, node.id)
@ -165,7 +165,7 @@ func (rt *RoutingTable) Update(node Node) {
}
}
func (rt *RoutingTable) RemoveByID(id bitmap) {
func (rt *routingTable) RemoveByID(id bitmap) {
rt.lock.Lock()
defer rt.lock.Unlock()
bucketNum := bucketFor(rt.node.id, id)
@ -176,11 +176,11 @@ func (rt *RoutingTable) RemoveByID(id bitmap) {
}
}
func (rt *RoutingTable) GetClosest(target bitmap, limit int) []Node {
func (rt *routingTable) GetClosest(target bitmap, limit int) []Node {
rt.lock.RLock()
defer rt.lock.RUnlock()
var toSort []SortedNode
var toSort []sortedNode
var bucketNum int
if rt.node.id.Equals(target) {
@ -225,10 +225,10 @@ func findInList(bucket *list.List, value bitmap) *list.Element {
return nil
}
func appendNodes(nodes []SortedNode, start *list.Element, target bitmap) []SortedNode {
func appendNodes(nodes []sortedNode, start *list.Element, target bitmap) []sortedNode {
for curr := start; curr != nil; curr = curr.Next() {
node := curr.Value.(Node)
nodes = append(nodes, SortedNode{node, node.id.Xor(target)})
nodes = append(nodes, sortedNode{node, node.id.Xor(target)})
}
return nodes
}
@ -241,10 +241,10 @@ func bucketFor(id bitmap, target bitmap) int {
}
func sortNodesInPlace(nodes []Node, target bitmap) {
toSort := make([]SortedNode, len(nodes))
toSort := make([]sortedNode, len(nodes))
for i, n := range nodes {
toSort[i] = SortedNode{n, n.id.Xor(target)}
toSort[i] = sortedNode{n, n.id.Xor(target)}
}
sort.Sort(byXorDistance(toSort))

View file

@ -5,6 +5,7 @@ import (
"net"
"time"
"github.com/lbryio/errors.go"
"github.com/lbryio/lbry.go/util"
"github.com/davecgh/go-spew/spew"
@ -146,7 +147,7 @@ func handleError(dht *DHT, addr *net.UDPAddr, e Error) {
func send(dht *DHT, addr *net.UDPAddr, data Message) error {
encoded, err := bencode.EncodeBytes(data)
if err != nil {
return err
return errors.Err(err)
}
if req, ok := data.(Request); ok {
@ -162,5 +163,5 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))
_, err = dht.conn.WriteToUDP(encoded, addr)
return err
return errors.Err(err)
}

View file

@ -266,7 +266,7 @@ func TestStore(t *testing.T) {
}
}
if len(dht.store.nodeIDs) != 1 {
if len(dht.store.hashes) != 1 {
t.Error("dht store has wrong number of items")
}

View file

@ -11,7 +11,7 @@ type peer struct {
type peerStore struct {
// map of blob hashes to (map of node IDs to bools)
nodeIDs map[bitmap]map[bitmap]bool
hashes map[bitmap]map[bitmap]bool
// map of node IDs to peers
nodeInfo map[bitmap]peer
lock sync.RWMutex
@ -19,7 +19,7 @@ type peerStore struct {
func newPeerStore() *peerStore {
return &peerStore{
nodeIDs: make(map[bitmap]map[bitmap]bool),
hashes: make(map[bitmap]map[bitmap]bool),
nodeInfo: make(map[bitmap]peer),
}
}
@ -27,10 +27,10 @@ func newPeerStore() *peerStore {
func (s *peerStore) Upsert(blobHash bitmap, node Node) {
s.lock.Lock()
defer s.lock.Unlock()
if _, ok := s.nodeIDs[blobHash]; !ok {
s.nodeIDs[blobHash] = make(map[bitmap]bool)
if _, ok := s.hashes[blobHash]; !ok {
s.hashes[blobHash] = make(map[bitmap]bool)
}
s.nodeIDs[blobHash][node.id] = true
s.hashes[blobHash][node.id] = true
s.nodeInfo[node.id] = peer{node: node}
}
@ -38,7 +38,7 @@ func (s *peerStore) Get(blobHash bitmap) []Node {
s.lock.RLock()
defer s.lock.RUnlock()
var nodes []Node
if ids, ok := s.nodeIDs[blobHash]; ok {
if ids, ok := s.hashes[blobHash]; ok {
for id := range ids {
peer, ok := s.nodeInfo[id]
if !ok {
@ -50,8 +50,8 @@ func (s *peerStore) Get(blobHash bitmap) []Node {
return nodes
}
func (s *peerStore) CountKnownNodes() int {
func (s *peerStore) CountStoredHashes() int {
s.lock.RLock()
defer s.lock.RUnlock()
return len(s.nodeInfo)
return len(s.hashes)
}

View file

@ -3,6 +3,7 @@ package dht
import (
"context"
"net"
"strings"
"sync"
"time"
@ -85,7 +86,9 @@ func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req Requ
for i := 0; i < udpRetry; i++ {
if err := send(tm.dht, node.Addr(), tx.req); err != nil {
log.Errorf("send error: ", err.Error())
if !strings.Contains(err.Error(), "use of closed network connection") { // this only happens on localhost. real UDP has no connections
log.Error("send error: ", err)
}
continue // try again? return?
}