lbry.go/dht/dht.go
2018-08-07 11:10:12 -04:00

236 lines
5.5 KiB
Go

package dht
import (
"fmt"
"net"
"strings"
"time"
"github.com/lbryio/reflector.go/dht/bits"
"github.com/lbryio/lbry.go/errors"
"github.com/lbryio/lbry.go/stop"
"github.com/sirupsen/logrus"
"github.com/spf13/cast"
)
// log is the package-wide logger. It is initialised to the logrus standard
// logger at package initialisation and can be replaced with UseLogger.
var log = logrus.StandardLogger()

// UseLogger makes the dht package write its log output to l.
//
// Passing nil restores the logrus standard logger instead of leaving the
// package with a nil logger that would panic on the next log call.
func UseLogger(l *logrus.Logger) {
	if l == nil {
		l = logrus.StandardLogger()
	}
	log = l
}
// DHT represents a DHT node. It holds the local contact, the Node used to send
// requests, and the state needed to join the network and to shut down.
type DHT struct {
	// conf is the configuration the DHT was created with (see New).
	conf *Config
	// contact is the local contact (ID, IP and port); set by connect.
	contact Contact
	// node is the underlying node; created by connect.
	node *Node
	// grp is the stop group that tracks the DHT's background goroutines and is
	// used to shut them down (see Start and Shutdown).
	grp *stop.Group
	// joined is closed when join returns, whether or not joining succeeded.
	joined chan struct{}
	// tokenCache caches store tokens; created by connect.
	tokenCache *tokenCache
	// announceAddRemove carries hashes that need to be put into the announce
	// queue or removed from the queue.
	announceAddRemove chan queueEdit
}
// New creates a DHT from config. A nil config is replaced by
// NewStandardConfig(). The returned DHT is not yet connected to the network;
// call Start to bring it online.
func New(config *Config) *DHT {
	cfg := config
	if cfg == nil {
		cfg = NewStandardConfig()
	}
	return &DHT{
		conf:              cfg,
		grp:               stop.New(),
		joined:            make(chan struct{}),
		announceAddRemove: make(chan queueEdit),
	}
}
// connect derives the local contact from the configured node ID and address,
// creates the Node and its token cache, and attaches the node to conn.
// It returns the first error encountered.
func (dht *DHT) connect(conn UDPConn) error {
	local, err := getContact(dht.conf.NodeID, dht.conf.Address)
	if err != nil {
		return err
	}

	dht.contact = local
	dht.node = NewNode(local.ID)
	dht.tokenCache = newTokenCache(dht.node, tokenSecretRotationInterval)

	return dht.node.Connect(conn)
}
// Start opens a packet listener on the configured address and connects the node
// to it. It then joins the network, blocking until join returns. Finally it
// launches the announcer and, when RPCPort > 0, the RPC server as background
// goroutines tracked by the stop group.
//
// If the listener is not a UDP connection, or the node cannot be connected, the
// listener is closed before the error is returned so the port is not left bound.
func (dht *DHT) Start() error {
	listener, err := net.ListenPacket(Network, dht.conf.Address)
	if err != nil {
		return errors.Err(err)
	}

	conn, ok := listener.(*net.UDPConn)
	if !ok {
		_ = listener.Close() // best effort; the type mismatch is the error to report
		return errors.Err("listener on %s is not a UDP connection", dht.conf.Address)
	}

	err = dht.connect(conn)
	if err != nil {
		// Release the socket so a failed Start does not keep the port bound.
		// Closing an already-closed UDPConn only returns an error, so this is
		// safe even if connect closed it.
		_ = conn.Close()
		return err
	}

	dht.join()
	log.Infof("[%s] DHT ready on %s (%d nodes found during join)",
		dht.node.id.HexShort(), dht.contact.Addr().String(), dht.node.rt.Count())

	dht.grp.Add(1)
	go func() {
		defer dht.grp.Done()
		dht.runAnnouncer()
	}()

	if dht.conf.RPCPort > 0 {
		dht.grp.Add(1)
		go func() {
			defer dht.grp.Done()
			dht.runRPCServer(dht.conf.RPCPort)
		}()
	}

	return nil
}
// join bootstraps this node into the DHT network.
//
// It pings every configured seed node, which learns their real node IDs and
// adds them to the routing table. If at least one seed answered, it runs an
// iterative find (FindContacts) on the node's own ID. Failures are logged, not
// returned. dht.joined is closed when join returns, so WaitUntilJoined callers
// are released even if joining failed.
func (dht *DHT) join() {
	defer close(dht.joined)

	id := dht.node.id.HexShort()
	log.Infof("[%s] joining DHT network", id)

	responses := 0
	for _, seed := range dht.conf.SeedNodes {
		if err := dht.Ping(seed); err != nil {
			log.Error(errors.Prefix(fmt.Sprintf("[%s] join", id), err))
			continue
		}
		responses++
	}

	if responses == 0 {
		log.Errorf("[%s] join: no nodes responded to initial ping", id)
		return
	}

	// Iterative find on our own ID (the Kademlia join step referenced below).
	if _, _, err := FindContacts(dht.node, dht.node.id, false, dht.grp.Child()); err != nil {
		log.Errorf("[%s] join: %s", id, err.Error())
	}

	// TODO: after joining, refresh all buckets further away than our closest neighbor
	// http://xlattice.sourceforge.net/components/protocol/kademlia/specs.html#join
}
// WaitUntilJoined blocks until join has returned, whether or not joining
// succeeded.
//
// It panics if the DHT was not created with New. In that case the joined
// channel is nil, and receiving from it would block forever.
func (dht *DHT) WaitUntilJoined() {
	done := dht.joined
	if done == nil {
		panic("dht not initialized")
	}
	<-done
}
// Shutdown stops the DHT's stop group, waits for its goroutines, and then shuts
// down the underlying node.
//
// It is safe to call on a DHT whose Start failed before a node was created.
// In that case only the stop group is stopped; previously this dereferenced a
// nil node and panicked.
func (dht *DHT) Shutdown() {
	if dht.node == nil {
		dht.grp.StopAndWait()
		return
	}

	id := dht.node.id.HexShort()
	log.Debugf("[%s] DHT shutting down", id)
	dht.grp.StopAndWait()
	dht.node.Shutdown()
	log.Debugf("[%s] DHT stopped", id)
}
// Ping sends a ping request to addr and reports whether it was answered.
//
// The remote's node ID is not known in advance, so the request goes to a
// throwaway contact with a random ID and the ID check is skipped. An error is
// returned if addr cannot be resolved or if no response comes back.
func (dht *DHT) Ping(addr string) error {
	udpAddr, err := net.ResolveUDPAddr(Network, addr)
	if err != nil {
		return err
	}

	probe := Contact{ID: bits.Rand(), IP: udpAddr.IP, Port: udpAddr.Port}
	opts := SendOptions{skipIDCheck: true}
	if resp := dht.node.Send(probe, Request{Method: pingMethod}, opts); resp != nil {
		return nil
	}
	return errors.Err("no response from node %s", addr)
}
// Get looks up hash in the DHT and returns the contacts that reported having
// the blob. When the lookup succeeds but nothing is found, it returns
// (nil, nil).
func (dht *DHT) Get(hash bits.Bitmap) ([]Contact, error) {
	contacts, found, err := FindContacts(dht.node, hash, true, dht.grp.Child())
	switch {
	case err != nil:
		return nil, err
	case !found:
		return nil, nil
	default:
		return contacts, nil
	}
}
// PrintState logs a snapshot of the DHT. It covers the local contact with the
// current time, the number of outstanding transactions, and the number of
// stored hashes. It ends with the routing table's bucket information, one log
// line per line.
func (dht *DHT) PrintState() {
	n := dht.node
	log.Printf("DHT node %s at %s", dht.contact.String(), time.Now().Format(time.RFC822Z))
	log.Printf("Outstanding transactions: %d", n.CountActiveTransactions())
	log.Printf("Stored hashes: %d", n.store.CountStoredHashes())
	log.Printf("Buckets:")

	bucketLines := strings.Split(n.rt.BucketInfo(), "\n")
	for i := range bucketLines {
		log.Println(bucketLines[i])
	}
}
// ID returns the local node ID. It is the zero Bitmap until Start has connected
// the DHT and set the local contact.
//
// The method uses a pointer receiver like every other DHT method. The previous
// value receiver copied the whole DHT struct on each call.
func (dht *DHT) ID() bits.Bitmap {
	return dht.contact.ID
}
// getContact builds the local Contact from a hex node ID and a "host:port"
// address.
//
// An empty nodeID yields a random ID. A non-empty one is decoded with
// bits.FromHexP. NOTE(review): presumably that panics on invalid hex; consider
// an error-returning variant.
//
// The host must be a literal IP address, not a hostname. The port must parse as
// an integer in the range 0-65535. On any error the returned Contact is the
// zero value, not a partially filled one.
func getContact(nodeID, addr string) (Contact, error) {
	host, portStr, err := net.SplitHostPort(addr)
	switch {
	case err != nil:
		return Contact{}, errors.Err(err)
	case host == "":
		return Contact{}, errors.Err("address does not contain an IP")
	case portStr == "":
		return Contact{}, errors.Err("address does not contain a port")
	}

	ip := net.ParseIP(host)
	if ip == nil {
		return Contact{}, errors.Err("invalid ip")
	}

	port, err := cast.ToIntE(portStr)
	if err != nil {
		return Contact{}, errors.Err(err)
	}
	// cast.ToIntE happily accepts negative or oversized numbers; neither is a
	// usable UDP port.
	if port < 0 || port > 65535 {
		return Contact{}, errors.Err("port %d out of range", port)
	}

	var id bits.Bitmap
	if nodeID == "" {
		id = bits.Rand()
	} else {
		id = bits.FromHexP(nodeID)
	}

	return Contact{ID: id, IP: ip, Port: port}, nil
}