move most dht code into Node

This commit is contained in:
Alex Grintsvayg 2018-04-27 20:16:12 -04:00
parent 34ab2cd1ae
commit 079a6bf610
13 changed files with 758 additions and 727 deletions

7
dht/bootstrap.go Normal file
View file

@ -0,0 +1,7 @@
package dht
// DHT represents a DHT node.
type BootstrapNode struct {
// node
node *Node
}

View file

@ -11,9 +11,9 @@ import (
"github.com/lbryio/errors.go"
"github.com/lbryio/lbry.go/stopOnce"
"github.com/spf13/cast"
log "github.com/sirupsen/logrus"
"github.com/spf13/cast"
)
func init() {
@ -42,12 +42,6 @@ const compactNodeInfoLength = nodeIDLength + 6
const tokenSecretRotationInterval = 5 * time.Minute // how often the token-generating secret is rotated
// packet represents the information receive from udp.
type packet struct {
data []byte
raddr *net.UDPAddr
}
// Config represents the configure of dht.
type Config struct {
// this node's address. format is `ip:port`
@ -72,33 +66,14 @@ func NewStandardConfig() *Config {
}
}
// UDPConn allows using a mocked connection to test sending/receiving data
type UDPConn interface {
ReadFromUDP([]byte) (int, *net.UDPAddr, error)
WriteToUDP([]byte, *net.UDPAddr) (int, error)
SetReadDeadline(time.Time) error
SetWriteDeadline(time.Time) error
Close() error
}
// DHT represents a DHT node.
type DHT struct {
// config
conf *Config
// UDP connection for sending and receiving data
conn UDPConn
// the local dht node
// local contact
contact Contact
// node
node *Node
// routing table
rt *routingTable
// channel of incoming packets
packets chan packet
// data store
store *peerStore
// transaction manager
tm *transactionManager
// token manager
tokens *tokenManager
// stopper to shut down DHT
stop *stopOnce.Stopper
// wait group for all the things that need to be stopped when DHT shuts down
@ -113,107 +88,27 @@ func New(config *Config) (*DHT, error) {
config = NewStandardConfig()
}
var id Bitmap
if config.NodeID == "" {
id = RandomBitmapP()
} else {
id = BitmapFromHexP(config.NodeID)
}
ip, port, err := net.SplitHostPort(config.Address)
contact, err := getContact(config.NodeID, config.Address)
if err != nil {
return nil, errors.Err(err)
} else if ip == "" {
return nil, errors.Err("address does not contain an IP")
} else if port == "" {
return nil, errors.Err("address does not contain a port")
return nil, err
}
portInt, err := cast.ToIntE(port)
node, err := NewNode(contact.id)
if err != nil {
return nil, errors.Err(err)
}
node := &Node{id: id, ip: net.ParseIP(ip), port: portInt}
if node.ip == nil {
return nil, errors.Err("invalid ip")
return nil, err
}
d := &DHT{
conf: config,
contact: contact,
node: node,
rt: newRoutingTable(node),
packets: make(chan packet),
store: newPeerStore(),
stop: stopOnce.New(),
stopWG: &sync.WaitGroup{},
joined: make(chan struct{}),
tokens: &tokenManager{},
}
d.tm = newTransactionManager(d)
d.tokens.Start(tokenSecretRotationInterval)
return d, nil
}
// init initializes global variables.
func (dht *DHT) init() error {
listener, err := net.ListenPacket(network, dht.conf.Address)
if err != nil {
return errors.Err(err)
}
dht.conn = listener.(*net.UDPConn)
if dht.conf.PrintState > 0 {
go func() {
t := time.NewTicker(dht.conf.PrintState)
for {
dht.PrintState()
select {
case <-t.C:
case <-dht.stop.Chan():
return
}
}
}()
}
return nil
}
// listen receives message from udp.
func (dht *DHT) listen() {
dht.stopWG.Add(1)
defer dht.stopWG.Done()
buf := make([]byte, udpMaxMessageLength)
for {
select {
case <-dht.stop.Chan():
return
default:
}
dht.conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) // need this to periodically check shutdown chan
n, raddr, err := dht.conn.ReadFromUDP(buf)
if err != nil {
if e, ok := err.(net.Error); !ok || !e.Timeout() {
log.Errorf("udp read error: %v", err)
}
continue
} else if raddr == nil {
log.Errorf("udp read with no raddr")
continue
}
data := make([]byte, n)
copy(data, buf[:n]) // slices use the same underlying array, so we need a new one for each packet
dht.packets <- packet{data: data, raddr: raddr}
}
}
// join makes current node join the dht network.
func (dht *DHT) join() {
defer close(dht.joined) // if anyone's waiting for join to finish, they'll know its done
@ -243,34 +138,21 @@ func (dht *DHT) join() {
}
}
func (dht *DHT) runHandler() {
dht.stopWG.Add(1)
defer dht.stopWG.Done()
var pkt packet
for {
select {
case pkt = <-dht.packets:
handlePacket(dht, pkt)
case <-dht.stop.Chan():
return
}
}
}
// Start starts the dht
func (dht *DHT) Start() error {
err := dht.init()
listener, err := net.ListenPacket(network, dht.conf.Address)
if err != nil {
return errors.Err(err)
}
conn := listener.(*net.UDPConn)
err = dht.node.Connect(conn)
if err != nil {
return err
}
go dht.listen()
go dht.runHandler()
dht.join()
log.Debugf("[%s] DHT ready on %s (%d nodes found during join)", dht.node.id.HexShort(), dht.node.Addr().String(), dht.rt.Count())
log.Debugf("[%s] DHT ready on %s (%d nodes found during join)", dht.node.id.HexShort(), dht.contact.Addr().String(), dht.node.rt.Count())
return nil
}
@ -286,8 +168,7 @@ func (dht *DHT) Shutdown() {
log.Debugf("[%s] DHT shutting down", dht.node.id.HexShort())
dht.stop.Stop()
dht.stopWG.Wait()
dht.tokens.Stop()
dht.conn.Close()
dht.node.Shutdown()
log.Debugf("[%s] DHT stopped", dht.node.id.HexShort())
}
@ -298,8 +179,8 @@ func (dht *DHT) Ping(addr string) error {
return err
}
tmpNode := Node{id: RandomBitmapP(), ip: raddr.IP, port: raddr.Port}
res := dht.tm.Send(tmpNode, Request{Method: pingMethod})
tmpNode := Contact{id: RandomBitmapP(), ip: raddr.IP, port: raddr.Port}
res := dht.node.Send(tmpNode, Request{Method: pingMethod})
if res == nil {
return errors.Err("no response from node %s", addr)
}
@ -308,22 +189,22 @@ func (dht *DHT) Ping(addr string) error {
}
// Get returns the list of nodes that have the blob for the given hash
func (dht *DHT) Get(hash Bitmap) ([]Node, error) {
nf := newNodeFinder(dht, hash, true)
func (dht *DHT) Get(hash Bitmap) ([]Contact, error) {
nf := newContactFinder(dht.node, hash, true)
res, err := nf.Find()
if err != nil {
return nil, err
}
if res.Found {
return res.Nodes, nil
return res.Contacts, nil
}
return nil, nil
}
// Announce announces to the DHT that this node has the blob for the given hash
func (dht *DHT) Announce(hash Bitmap) error {
nf := newNodeFinder(dht, hash, false)
nf := newContactFinder(dht.node, hash, false)
res, err := nf.Find()
if err != nil {
return err
@ -331,18 +212,18 @@ func (dht *DHT) Announce(hash Bitmap) error {
// TODO: if this node is closer than farthest peer, store locally and pop farthest peer
for _, node := range res.Nodes {
for _, node := range res.Contacts {
go dht.storeOnNode(hash, node)
}
return nil
}
func (dht *DHT) storeOnNode(hash Bitmap, node Node) {
func (dht *DHT) storeOnNode(hash Bitmap, node Contact) {
dht.stopWG.Add(1)
defer dht.stopWG.Done()
resCh := dht.tm.SendAsync(context.Background(), node, Request{
resCh := dht.node.SendAsync(context.Background(), node, Request{
Method: findValueMethod,
Arg: &hash,
})
@ -358,30 +239,30 @@ func (dht *DHT) storeOnNode(hash Bitmap, node Node) {
return // request timed out
}
dht.tm.SendAsync(context.Background(), node, Request{
dht.node.SendAsync(context.Background(), node, Request{
Method: storeMethod,
StoreArgs: &storeArgs{
BlobHash: hash,
Value: storeArgsValue{
Token: res.Token,
LbryID: dht.node.id,
Port: dht.node.port,
LbryID: dht.contact.id,
Port: dht.contact.port,
},
},
})
}
func (dht *DHT) PrintState() {
log.Printf("DHT node %s at %s", dht.node.String(), time.Now().Format(time.RFC822Z))
log.Printf("Outstanding transactions: %d", dht.tm.Count())
log.Printf("Stored hashes: %d", dht.store.CountStoredHashes())
log.Printf("DHT node %s at %s", dht.contact.String(), time.Now().Format(time.RFC822Z))
log.Printf("Outstanding transactions: %d", dht.node.CountActiveTransactions())
log.Printf("Stored hashes: %d", dht.node.store.CountStoredHashes())
log.Printf("Buckets:")
for _, line := range strings.Split(dht.rt.BucketInfo(), "\n") {
for _, line := range strings.Split(dht.node.rt.BucketInfo(), "\n") {
log.Println(line)
}
}
func printNodeList(list []Node) {
func printNodeList(list []Contact) {
for i, n := range list {
log.Printf("%d) %s", i, n.String())
}
@ -414,3 +295,33 @@ func MakeTestDHT(numNodes int) []*DHT {
return dhts
}
func getContact(nodeID, addr string) (Contact, error) {
var c Contact
if nodeID == "" {
c.id = RandomBitmapP()
} else {
c.id = BitmapFromHexP(nodeID)
}
ip, port, err := net.SplitHostPort(addr)
if err != nil {
return c, errors.Err(err)
} else if ip == "" {
return c, errors.Err("address does not contain an IP")
} else if port == "" {
return c, errors.Err("address does not contain a port")
}
c.ip = net.ParseIP(ip)
if c.ip == nil {
return c, errors.Err("invalid ip")
}
c.port, err = cast.ToIntE(port)
if err != nil {
return c, errors.Err(err)
}
return c, nil
}

View file

@ -20,12 +20,12 @@ func TestNodeFinder_FindNodes(t *testing.T) {
}
}()
nf := newNodeFinder(dhts[2], RandomBitmapP(), false)
nf := newContactFinder(dhts[2].node, RandomBitmapP(), false)
res, err := nf.Find()
if err != nil {
t.Fatal(err)
}
foundNodes, found := res.Nodes, res.Found
foundNodes, found := res.Contacts, res.Found
if found {
t.Fatal("something was found, but it should not have been")
@ -42,7 +42,7 @@ func TestNodeFinder_FindNodes(t *testing.T) {
if n.id.Equals(dhts[0].node.id) {
foundOne = true
}
//if n.id.Equals(dhts[1].node.id) {
//if n.id.Equals(dhts[1].node.c.id) {
// foundTwo = true
//}
}
@ -51,7 +51,7 @@ func TestNodeFinder_FindNodes(t *testing.T) {
t.Errorf("did not find first node %s", dhts[0].node.id.Hex())
}
//if !foundTwo {
// t.Errorf("did not find second node %s", dhts[1].node.id.Hex())
// t.Errorf("did not find second node %s", dhts[1].node.c.id.Hex())
//}
}
@ -64,15 +64,15 @@ func TestNodeFinder_FindValue(t *testing.T) {
}()
blobHashToFind := RandomBitmapP()
nodeToFind := Node{id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
dhts[0].store.Upsert(blobHashToFind, nodeToFind)
nodeToFind := Contact{id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
dhts[0].node.store.Upsert(blobHashToFind, nodeToFind)
nf := newNodeFinder(dhts[2], blobHashToFind, true)
nf := newContactFinder(dhts[2].node, blobHashToFind, true)
res, err := nf.Find()
if err != nil {
t.Fatal(err)
}
foundNodes, found := res.Nodes, res.Found
foundNodes, found := res.Contacts, res.Found
if !found {
t.Fatal("node was not found")

View file

@ -223,7 +223,7 @@ type Response struct {
ID messageID
NodeID Bitmap
Data string
FindNodeData []Node
Contacts []Contact
FindValueKey string
Token string
}
@ -239,7 +239,7 @@ func (r Response) ArgsDebug() string {
}
str += "|"
for _, c := range r.FindNodeData {
for _, c := range r.Contacts {
str += c.Addr().String() + ":" + c.id.HexShort() + ","
}
str = strings.TrimRight(str, ",") + "|"
@ -268,8 +268,8 @@ func (r Response) MarshalBencode() ([]byte, error) {
}
var contacts [][]byte
for _, n := range r.FindNodeData {
compact, err := n.MarshalCompact()
for _, c := range r.Contacts {
compact, err := c.MarshalCompact()
if err != nil {
return nil, err
}
@ -282,12 +282,12 @@ func (r Response) MarshalBencode() ([]byte, error) {
} else if r.Token != "" {
// findValue failure falling back to findNode
data[headerPayloadField] = map[string]interface{}{
contactsField: r.FindNodeData,
contactsField: r.Contacts,
tokenField: r.Token,
}
} else {
// straight up findNode
data[headerPayloadField] = r.FindNodeData
data[headerPayloadField] = r.Contacts
}
return bencode.EncodeBytes(data)
@ -314,7 +314,7 @@ func (r *Response) UnmarshalBencode(b []byte) error {
}
// maybe data is a list of nodes (response to findNode)?
err = bencode.DecodeBytes(raw.Data, &r.FindNodeData)
err = bencode.DecodeBytes(raw.Data, &r.Contacts)
if err == nil {
return nil
}
@ -335,25 +335,25 @@ func (r *Response) UnmarshalBencode(b []byte) error {
}
if contacts, ok := rawData[contactsField]; ok {
err = bencode.DecodeBytes(contacts, &r.FindNodeData)
err = bencode.DecodeBytes(contacts, &r.Contacts)
if err != nil {
return err
}
} else {
for k, v := range rawData {
r.FindValueKey = k
var compactNodes [][]byte
err = bencode.DecodeBytes(v, &compactNodes)
var compactContacts [][]byte
err = bencode.DecodeBytes(v, &compactContacts)
if err != nil {
return err
}
for _, compact := range compactNodes {
var uncompactedNode Node
err = uncompactedNode.UnmarshalCompact(compact)
for _, compact := range compactContacts {
var c Contact
err = c.UnmarshalCompact(compact)
if err != nil {
return err
}
r.FindNodeData = append(r.FindNodeData, uncompactedNode)
r.Contacts = append(r.Contacts, c)
}
break
}

View file

@ -77,7 +77,7 @@ func TestBencodeFindNodesResponse(t *testing.T) {
res := Response{
ID: newMessageID(),
NodeID: RandomBitmapP(),
FindNodeData: []Node{
Contacts: []Contact{
{id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
{id: RandomBitmapP(), ip: net.IPv4(4, 3, 2, 1).To4(), port: 8765},
},
@ -103,7 +103,7 @@ func TestBencodeFindValueResponse(t *testing.T) {
NodeID: RandomBitmapP(),
FindValueKey: RandomBitmapP().RawString(),
Token: "arst",
FindNodeData: []Node{
Contacts: []Contact{
{id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
},
}
@ -182,7 +182,7 @@ func compareResponses(t *testing.T, res, res2 Response) {
if res.Token != res2.Token {
t.Errorf("expected Token %s, got %s", res.Token, res2.Token)
}
if !reflect.DeepEqual(res.FindNodeData, res2.FindNodeData) {
t.Errorf("expected FindNodeData %s, got %s", spew.Sdump(res.FindNodeData), spew.Sdump(res2.FindNodeData))
if !reflect.DeepEqual(res.Contacts, res2.Contacts) {
t.Errorf("expected FindNodeData %s, got %s", spew.Sdump(res.Contacts), spew.Sdump(res2.Contacts))
}
}

403
dht/node.go Normal file
View file

@ -0,0 +1,403 @@
package dht
import (
"context"
"encoding/hex"
"net"
"strings"
"sync"
"time"
"github.com/lbryio/errors.go"
"github.com/lbryio/lbry.go/stopOnce"
"github.com/lbryio/lbry.go/util"
"github.com/davecgh/go-spew/spew"
"github.com/lyoshenka/bencode"
log "github.com/sirupsen/logrus"
)
// packet represents the information receive from udp.
type packet struct {
data []byte
raddr *net.UDPAddr
}
// UDPConn allows using a mocked connection to test sending/receiving data
type UDPConn interface {
ReadFromUDP([]byte) (int, *net.UDPAddr, error)
WriteToUDP([]byte, *net.UDPAddr) (int, error)
SetReadDeadline(time.Time) error
SetWriteDeadline(time.Time) error
Close() error
}
type Node struct {
// TODO: replace Contact with id. ip and port aren't used except when connecting
id Bitmap
// UDP connection for sending and receiving data
conn UDPConn
// token manager
tokens *tokenManager
txLock *sync.RWMutex
transactions map[messageID]*transaction
// routing table
rt *routingTable
// data store
store *peerStore
stop *stopOnce.Stopper
stopWG *sync.WaitGroup
}
// New returns a Node pointer.
func NewNode(id Bitmap) (*Node, error) {
n := &Node{
id: id,
rt: newRoutingTable(id),
store: newPeerStore(),
txLock: &sync.RWMutex{},
transactions: make(map[messageID]*transaction),
stop: stopOnce.New(),
stopWG: &sync.WaitGroup{},
tokens: &tokenManager{},
}
n.tokens.Start(tokenSecretRotationInterval)
return n, nil
}
func (n *Node) Connect(conn UDPConn) error {
n.conn = conn
//if dht.conf.PrintState > 0 {
// go func() {
// t := time.NewTicker(dht.conf.PrintState)
// for {
// dht.PrintState()
// select {
// case <-t.C:
// case <-dht.stop.Chan():
// return
// }
// }
// }()
//}
packets := make(chan packet)
go func() {
n.stopWG.Add(1)
defer n.stopWG.Done()
buf := make([]byte, udpMaxMessageLength)
for {
select {
case <-n.stop.Chan():
return
default:
}
n.conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) // need this to periodically check shutdown chan
n, raddr, err := n.conn.ReadFromUDP(buf)
if err != nil {
if e, ok := err.(net.Error); !ok || !e.Timeout() {
log.Errorf("udp read error: %v", err)
}
continue
} else if raddr == nil {
log.Errorf("udp read with no raddr")
continue
}
data := make([]byte, n)
copy(data, buf[:n]) // slices use the same underlying array, so we need a new one for each packet
packets <- packet{data: data, raddr: raddr}
}
}()
go func() {
n.stopWG.Add(1)
defer n.stopWG.Done()
var pkt packet
for {
select {
case pkt = <-packets:
n.handlePacket(pkt)
case <-n.stop.Chan():
return
}
}
}()
return nil
}
// Shutdown shuts down the node
func (n *Node) Shutdown() {
log.Debugf("[%s] node shutting down", n.id.HexShort())
n.stop.Stop()
n.stopWG.Wait()
n.tokens.Stop()
n.conn.Close()
log.Debugf("[%s] node stopped", n.id.HexShort())
}
// handlePacket handles packets received from udp.
func (n *Node) handlePacket(pkt packet) {
//log.Debugf("[%s] Received message from %s (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), len(pkt.data), hex.EncodeToString(pkt.data))
if !util.InSlice(string(pkt.data[0:5]), []string{"d1:0i", "di0ei"}) {
log.Errorf("[%s] data is not a well-formatted dict: (%d bytes) %s", n.id.HexShort(), len(pkt.data), hex.EncodeToString(pkt.data))
return
}
// TODO: test this stuff more thoroughly
// the following is a bit of a hack, but it lets us avoid decoding every message twice
// it depends on the data being a dict with 0 as the first key (so it starts with "d1:0i") and the message type as the first value
switch pkt.data[5] {
case '0' + requestType:
request := Request{}
err := bencode.DecodeBytes(pkt.data, &request)
if err != nil {
log.Errorf("[%s] error decoding request from %s: %s: (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
return
}
log.Debugf("[%s] query %s: received request from %s: %s(%s)", n.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, request.ArgsDebug())
n.handleRequest(pkt.raddr, request)
case '0' + responseType:
response := Response{}
err := bencode.DecodeBytes(pkt.data, &response)
if err != nil {
log.Errorf("[%s] error decoding response from %s: %s: (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
return
}
log.Debugf("[%s] query %s: received response from %s: %s", n.id.HexShort(), response.ID.HexShort(), response.NodeID.HexShort(), response.ArgsDebug())
n.handleResponse(pkt.raddr, response)
case '0' + errorType:
e := Error{}
err := bencode.DecodeBytes(pkt.data, &e)
if err != nil {
log.Errorf("[%s] error decoding error from %s: %s: (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
return
}
log.Debugf("[%s] query %s: received error from %s: %s", n.id.HexShort(), e.ID.HexShort(), e.NodeID.HexShort(), e.ExceptionType)
n.handleError(pkt.raddr, e)
default:
log.Errorf("[%s] invalid message type: %s", n.id.HexShort(), pkt.data[5])
return
}
}
// handleRequest handles the requests received from udp.
func (n *Node) handleRequest(addr *net.UDPAddr, request Request) {
if request.NodeID.Equals(n.id) {
log.Warn("ignoring self-request")
return
}
switch request.Method {
default:
// n.send(addr, makeError(t, protocolError, "invalid q"))
log.Errorln("invalid request method")
return
case pingMethod:
n.send(addr, Response{ID: request.ID, NodeID: n.id, Data: pingSuccessResponse})
case storeMethod:
// TODO: we should be sending the IP in the request, not just using the sender's IP
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
if n.tokens.Verify(request.StoreArgs.Value.Token, request.NodeID, addr) {
n.store.Upsert(request.StoreArgs.BlobHash, Contact{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
n.send(addr, Response{ID: request.ID, NodeID: n.id, Data: storeSuccessResponse})
} else {
n.send(addr, Error{ID: request.ID, NodeID: n.id, ExceptionType: "invalid-token"})
}
case findNodeMethod:
if request.Arg == nil {
log.Errorln("request is missing arg")
return
}
n.send(addr, Response{
ID: request.ID,
NodeID: n.id,
Contacts: n.rt.GetClosest(*request.Arg, bucketSize),
})
case findValueMethod:
if request.Arg == nil {
log.Errorln("request is missing arg")
return
}
res := Response{
ID: request.ID,
NodeID: n.id,
Token: n.tokens.Get(request.NodeID, addr),
}
if contacts := n.store.Get(*request.Arg); len(contacts) > 0 {
res.FindValueKey = request.Arg.RawString()
res.Contacts = contacts
} else {
res.Contacts = n.rt.GetClosest(*request.Arg, bucketSize)
}
n.send(addr, res)
}
// nodes that send us requests should not be inserted, only refreshed.
// the routing table must only contain "good" nodes, which are nodes that reply to our requests
// if a node is already good (aka in the table), its fine to refresh it
// http://www.bittorrent.org/beps/bep_0005.html#routing-table
n.rt.UpdateIfExists(Contact{id: request.NodeID, ip: addr.IP, port: addr.Port})
}
// handleResponse handles responses received from udp.
func (n *Node) handleResponse(addr *net.UDPAddr, response Response) {
tx := n.txFind(response.ID, addr)
if tx != nil {
tx.res <- response
}
n.rt.Update(Contact{id: response.NodeID, ip: addr.IP, port: addr.Port})
}
// handleError handles errors received from udp.
func (n *Node) handleError(addr *net.UDPAddr, e Error) {
spew.Dump(e)
n.rt.UpdateIfExists(Contact{id: e.NodeID, ip: addr.IP, port: addr.Port})
}
// send sends data to a udp address
func (n *Node) send(addr *net.UDPAddr, data Message) error {
encoded, err := bencode.EncodeBytes(data)
if err != nil {
return errors.Err(err)
}
if req, ok := data.(Request); ok {
log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)",
n.id.HexShort(), req.ID.HexShort(), addr.String(), len(encoded), req.Method, req.ArgsDebug())
} else if res, ok := data.(Response); ok {
log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s",
n.id.HexShort(), res.ID.HexShort(), addr.String(), len(encoded), res.ArgsDebug())
} else {
log.Debugf("[%s] (%d bytes) %s", n.id.HexShort(), len(encoded), spew.Sdump(data))
}
n.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))
_, err = n.conn.WriteToUDP(encoded, addr)
return errors.Err(err)
}
// transaction represents a single query to the dht. it stores the queried contact, the request, and the response channel
type transaction struct {
contact Contact
req Request
res chan Response
}
// insert adds a transaction to the manager.
func (n *Node) txInsert(tx *transaction) {
n.txLock.Lock()
defer n.txLock.Unlock()
n.transactions[tx.req.ID] = tx
}
// delete removes a transaction from the manager.
func (n *Node) txDelete(id messageID) {
n.txLock.Lock()
defer n.txLock.Unlock()
delete(n.transactions, id)
}
// Find finds a transaction for the given id. it optionally ensures that addr matches contact from transaction
func (n *Node) txFind(id messageID, addr *net.UDPAddr) *transaction {
n.txLock.RLock()
defer n.txLock.RUnlock()
// TODO: also check that the response's nodeid matches the id you thought you sent to?
t, ok := n.transactions[id]
if !ok || (addr != nil && t.contact.Addr().String() != addr.String()) {
return nil
}
return t
}
// SendAsync sends a transaction and returns a channel that will eventually contain the transaction response
// The response channel is closed when the transaction is completed or times out.
func (n *Node) SendAsync(ctx context.Context, contact Contact, req Request) <-chan *Response {
if contact.id.Equals(n.id) {
log.Error("sending query to self")
return nil
}
ch := make(chan *Response, 1)
go func() {
defer close(ch)
req.ID = newMessageID()
req.NodeID = n.id
tx := &transaction{
contact: contact,
req: req,
res: make(chan Response),
}
n.txInsert(tx)
defer n.txDelete(tx.req.ID)
for i := 0; i < udpRetry; i++ {
if err := n.send(contact.Addr(), tx.req); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") { // this only happens on localhost. real UDP has no connections
log.Error("send error: ", err)
}
continue // try again? return?
}
select {
case res := <-tx.res:
ch <- &res
return
case <-ctx.Done():
return
case <-time.After(udpTimeout):
}
}
// if request timed out each time
n.rt.RemoveByID(tx.contact.id)
}()
return ch
}
// Send sends a transaction and blocks until the response is available. It returns a response, or nil
// if the transaction timed out.
func (n *Node) Send(contact Contact, req Request) *Response {
return <-n.SendAsync(context.Background(), contact, req)
}
// Count returns the number of transactions in the manager
func (n *Node) CountActiveTransactions() int {
n.txLock.Lock()
defer n.txLock.Unlock()
return len(n.transactions)
}

View file

@ -2,6 +2,7 @@ package dht
import (
"context"
"sort"
"sync"
"time"
@ -11,21 +12,21 @@ import (
log "github.com/sirupsen/logrus"
)
type nodeFinder struct {
type contactFinder struct {
findValue bool // true if we're using findValue
target Bitmap
dht *DHT
node *Node
done *stopOnce.Stopper
findValueMutex *sync.Mutex
findValueResult []Node
findValueResult []Contact
activeNodesMutex *sync.Mutex
activeNodes []Node
activeContactsMutex *sync.Mutex
activeContacts []Contact
shortlistMutex *sync.Mutex
shortlist []Node
shortlist []Contact
shortlistAdded map[Bitmap]bool
outstandingRequestsMutex *sync.RWMutex
@ -34,16 +35,16 @@ type nodeFinder struct {
type findNodeResponse struct {
Found bool
Nodes []Node
Contacts []Contact
}
func newNodeFinder(dht *DHT, target Bitmap, findValue bool) *nodeFinder {
return &nodeFinder{
dht: dht,
func newContactFinder(node *Node, target Bitmap, findValue bool) *contactFinder {
return &contactFinder{
node: node,
target: target,
findValue: findValue,
findValueMutex: &sync.Mutex{},
activeNodesMutex: &sync.Mutex{},
activeContactsMutex: &sync.Mutex{},
shortlistMutex: &sync.Mutex{},
shortlistAdded: make(map[Bitmap]bool),
done: stopOnce.New(),
@ -51,15 +52,15 @@ func newNodeFinder(dht *DHT, target Bitmap, findValue bool) *nodeFinder {
}
}
func (nf *nodeFinder) Find() (findNodeResponse, error) {
if nf.findValue {
log.Debugf("[%s] starting an iterative Find for the value %s", nf.dht.node.id.HexShort(), nf.target.HexShort())
func (cf *contactFinder) Find() (findNodeResponse, error) {
if cf.findValue {
log.Debugf("[%s] starting an iterative Find for the value %s", cf.node.id.HexShort(), cf.target.HexShort())
} else {
log.Debugf("[%s] starting an iterative Find for nodes near %s", nf.dht.node.id.HexShort(), nf.target.HexShort())
log.Debugf("[%s] starting an iterative Find for contacts near %s", cf.node.id.HexShort(), cf.target.HexShort())
}
nf.appendNewToShortlist(nf.dht.rt.GetClosest(nf.target, alpha))
if len(nf.shortlist) == 0 {
return findNodeResponse{}, errors.Err("no nodes in routing table")
cf.appendNewToShortlist(cf.node.rt.GetClosest(cf.target, alpha))
if len(cf.shortlist) == 0 {
return findNodeResponse{}, errors.Err("no contacts in routing table")
}
wg := &sync.WaitGroup{}
@ -68,163 +69,163 @@ func (nf *nodeFinder) Find() (findNodeResponse, error) {
wg.Add(1)
go func(i int) {
defer wg.Done()
nf.iterationWorker(i + 1)
cf.iterationWorker(i + 1)
}(i)
}
wg.Wait()
// TODO: what to do if we have less than K active nodes, shortlist is empty, but we
// TODO: have other nodes in our routing table whom we have not contacted. prolly contact them
// TODO: what to do if we have less than K active contacts, shortlist is empty, but we
// TODO: have other contacts in our routing table whom we have not contacted. prolly contact them
result := findNodeResponse{}
if nf.findValue && len(nf.findValueResult) > 0 {
if cf.findValue && len(cf.findValueResult) > 0 {
result.Found = true
result.Nodes = nf.findValueResult
result.Contacts = cf.findValueResult
} else {
result.Nodes = nf.activeNodes
if len(result.Nodes) > bucketSize {
result.Nodes = result.Nodes[:bucketSize]
result.Contacts = cf.activeContacts
if len(result.Contacts) > bucketSize {
result.Contacts = result.Contacts[:bucketSize]
}
}
return result, nil
}
func (nf *nodeFinder) iterationWorker(num int) {
log.Debugf("[%s] starting worker %d", nf.dht.node.id.HexShort(), num)
defer func() { log.Debugf("[%s] stopping worker %d", nf.dht.node.id.HexShort(), num) }()
func (cf *contactFinder) iterationWorker(num int) {
log.Debugf("[%s] starting worker %d", cf.node.id.HexShort(), num)
defer func() { log.Debugf("[%s] stopping worker %d", cf.node.id.HexShort(), num) }()
for {
maybeNode := nf.popFromShortlist()
if maybeNode == nil {
maybeContact := cf.popFromShortlist()
if maybeContact == nil {
// TODO: block if there are pending requests out from other workers. there may be more shortlist values coming
log.Debugf("[%s] worker %d: no nodes in shortlist, waiting...", nf.dht.node.id.HexShort(), num)
log.Debugf("[%s] worker %d: no contacts in shortlist, waiting...", cf.node.id.HexShort(), num)
time.Sleep(100 * time.Millisecond)
} else {
node := *maybeNode
contact := *maybeContact
if node.id.Equals(nf.dht.node.id) {
if contact.id.Equals(cf.node.id) {
continue // cannot contact self
}
req := Request{Arg: &nf.target}
if nf.findValue {
req := Request{Arg: &cf.target}
if cf.findValue {
req.Method = findValueMethod
} else {
req.Method = findNodeMethod
}
log.Debugf("[%s] worker %d: contacting %s", nf.dht.node.id.HexShort(), num, node.id.HexShort())
log.Debugf("[%s] worker %d: contacting %s", cf.node.id.HexShort(), num, contact.id.HexShort())
nf.incrementOutstanding()
cf.incrementOutstanding()
var res *Response
ctx, cancel := context.WithCancel(context.Background())
resCh := nf.dht.tm.SendAsync(ctx, node, req)
resCh := cf.node.SendAsync(ctx, contact, req)
select {
case res = <-resCh:
case <-nf.done.Chan():
log.Debugf("[%s] worker %d: canceled", nf.dht.node.id.HexShort(), num)
case <-cf.done.Chan():
log.Debugf("[%s] worker %d: canceled", cf.node.id.HexShort(), num)
cancel()
return
}
if res == nil {
// nothing to do, response timed out
log.Debugf("[%s] worker %d: timed out waiting for %s", nf.dht.node.id.HexShort(), num, node.id.HexShort())
} else if nf.findValue && res.FindValueKey != "" {
log.Debugf("[%s] worker %d: got value", nf.dht.node.id.HexShort(), num)
nf.findValueMutex.Lock()
nf.findValueResult = res.FindNodeData
nf.findValueMutex.Unlock()
nf.done.Stop()
log.Debugf("[%s] worker %d: timed out waiting for %s", cf.node.id.HexShort(), num, contact.id.HexShort())
} else if cf.findValue && res.FindValueKey != "" {
log.Debugf("[%s] worker %d: got value", cf.node.id.HexShort(), num)
cf.findValueMutex.Lock()
cf.findValueResult = res.Contacts
cf.findValueMutex.Unlock()
cf.done.Stop()
return
} else {
log.Debugf("[%s] worker %d: got contacts", nf.dht.node.id.HexShort(), num)
nf.insertIntoActiveList(node)
nf.appendNewToShortlist(res.FindNodeData)
log.Debugf("[%s] worker %d: got contacts", cf.node.id.HexShort(), num)
cf.insertIntoActiveList(contact)
cf.appendNewToShortlist(res.Contacts)
}
nf.decrementOutstanding() // this is all the way down here because we need to add to shortlist first
cf.decrementOutstanding() // this is all the way down here because we need to add to shortlist first
}
if nf.isSearchFinished() {
log.Debugf("[%s] worker %d: search is finished", nf.dht.node.id.HexShort(), num)
nf.done.Stop()
if cf.isSearchFinished() {
log.Debugf("[%s] worker %d: search is finished", cf.node.id.HexShort(), num)
cf.done.Stop()
return
}
}
}
func (nf *nodeFinder) appendNewToShortlist(nodes []Node) {
nf.shortlistMutex.Lock()
defer nf.shortlistMutex.Unlock()
func (cf *contactFinder) appendNewToShortlist(contacts []Contact) {
cf.shortlistMutex.Lock()
defer cf.shortlistMutex.Unlock()
for _, n := range nodes {
if _, ok := nf.shortlistAdded[n.id]; !ok {
nf.shortlist = append(nf.shortlist, n)
nf.shortlistAdded[n.id] = true
for _, c := range contacts {
if _, ok := cf.shortlistAdded[c.id]; !ok {
cf.shortlist = append(cf.shortlist, c)
cf.shortlistAdded[c.id] = true
}
}
sortNodesInPlace(nf.shortlist, nf.target)
sortInPlace(cf.shortlist, cf.target)
}
func (nf *nodeFinder) popFromShortlist() *Node {
nf.shortlistMutex.Lock()
defer nf.shortlistMutex.Unlock()
func (cf *contactFinder) popFromShortlist() *Contact {
cf.shortlistMutex.Lock()
defer cf.shortlistMutex.Unlock()
if len(nf.shortlist) == 0 {
if len(cf.shortlist) == 0 {
return nil
}
first := nf.shortlist[0]
nf.shortlist = nf.shortlist[1:]
first := cf.shortlist[0]
cf.shortlist = cf.shortlist[1:]
return &first
}
func (nf *nodeFinder) insertIntoActiveList(node Node) {
nf.activeNodesMutex.Lock()
defer nf.activeNodesMutex.Unlock()
func (cf *contactFinder) insertIntoActiveList(contact Contact) {
cf.activeContactsMutex.Lock()
defer cf.activeContactsMutex.Unlock()
inserted := false
for i, n := range nf.activeNodes {
if node.id.Xor(nf.target).Less(n.id.Xor(nf.target)) {
nf.activeNodes = append(nf.activeNodes[:i], append([]Node{node}, nf.activeNodes[i:]...)...)
for i, n := range cf.activeContacts {
if contact.id.Xor(cf.target).Less(n.id.Xor(cf.target)) {
cf.activeContacts = append(cf.activeContacts[:i], append([]Contact{contact}, cf.activeContacts[i:]...)...)
inserted = true
break
}
}
if !inserted {
nf.activeNodes = append(nf.activeNodes, node)
cf.activeContacts = append(cf.activeContacts, contact)
}
}
func (nf *nodeFinder) isSearchFinished() bool {
if nf.findValue && len(nf.findValueResult) > 0 {
func (cf *contactFinder) isSearchFinished() bool {
if cf.findValue && len(cf.findValueResult) > 0 {
return true
}
select {
case <-nf.done.Chan():
case <-cf.done.Chan():
return true
default:
}
if !nf.areRequestsOutstanding() {
nf.shortlistMutex.Lock()
defer nf.shortlistMutex.Unlock()
if !cf.areRequestsOutstanding() {
cf.shortlistMutex.Lock()
defer cf.shortlistMutex.Unlock()
if len(nf.shortlist) == 0 {
if len(cf.shortlist) == 0 {
return true
}
nf.activeNodesMutex.Lock()
defer nf.activeNodesMutex.Unlock()
cf.activeContactsMutex.Lock()
defer cf.activeContactsMutex.Unlock()
if len(nf.activeNodes) >= bucketSize && nf.activeNodes[bucketSize-1].id.Xor(nf.target).Less(nf.shortlist[0].id.Xor(nf.target)) {
// we have at least K active nodes, and we don't have any closer nodes yet to contact
if len(cf.activeContacts) >= bucketSize && cf.activeContacts[bucketSize-1].id.Xor(cf.target).Less(cf.shortlist[0].id.Xor(cf.target)) {
// we have at least K active contacts, and we don't have any closer contacts to ping
return true
}
}
@ -232,20 +233,34 @@ func (nf *nodeFinder) isSearchFinished() bool {
return false
}
func (nf *nodeFinder) incrementOutstanding() {
nf.outstandingRequestsMutex.Lock()
defer nf.outstandingRequestsMutex.Unlock()
nf.outstandingRequests++
func (cf *contactFinder) incrementOutstanding() {
cf.outstandingRequestsMutex.Lock()
defer cf.outstandingRequestsMutex.Unlock()
cf.outstandingRequests++
}
func (nf *nodeFinder) decrementOutstanding() {
nf.outstandingRequestsMutex.Lock()
defer nf.outstandingRequestsMutex.Unlock()
if nf.outstandingRequests > 0 {
nf.outstandingRequests--
func (cf *contactFinder) decrementOutstanding() {
cf.outstandingRequestsMutex.Lock()
defer cf.outstandingRequestsMutex.Unlock()
if cf.outstandingRequests > 0 {
cf.outstandingRequests--
}
}
func (nf *nodeFinder) areRequestsOutstanding() bool {
nf.outstandingRequestsMutex.RLock()
defer nf.outstandingRequestsMutex.RUnlock()
return nf.outstandingRequests > 0
func (cf *contactFinder) areRequestsOutstanding() bool {
cf.outstandingRequestsMutex.RLock()
defer cf.outstandingRequestsMutex.RUnlock()
return cf.outstandingRequests > 0
}
func sortInPlace(contacts []Contact, target Bitmap) {
toSort := make([]sortedContact, len(contacts))
for i, n := range contacts {
toSort[i] = sortedContact{n, n.id.Xor(target)}
}
sort.Sort(byXorDistance(toSort))
for i, c := range toSort {
contacts[i] = c.contact
}
}

View file

@ -97,9 +97,11 @@ func TestPing(t *testing.T) {
if err != nil {
t.Fatal(err)
}
dht.conn = conn
go dht.listen()
go dht.runHandler()
err = dht.node.Connect(conn)
if err != nil {
t.Fatal(err)
}
defer dht.Shutdown()
messageID := newMessageID()
@ -193,9 +195,10 @@ func TestStore(t *testing.T) {
t.Fatal(err)
}
dht.conn = conn
go dht.listen()
go dht.runHandler()
err = dht.node.Connect(conn)
if err != nil {
t.Fatal(err)
}
defer dht.Shutdown()
messageID := newMessageID()
@ -208,7 +211,7 @@ func TestStore(t *testing.T) {
StoreArgs: &storeArgs{
BlobHash: blobHashToStore,
Value: storeArgsValue{
Token: dht.tokens.Get(testNodeID, conn.addr),
Token: dht.node.tokens.Get(testNodeID, conn.addr),
LbryID: testNodeID,
Port: 9999,
},
@ -266,11 +269,11 @@ func TestStore(t *testing.T) {
}
}
if len(dht.store.hashes) != 1 {
if len(dht.node.store.hashes) != 1 {
t.Error("dht store has wrong number of items")
}
items := dht.store.Get(blobHashToStore)
items := dht.node.store.Get(blobHashToStore)
if len(items) != 1 {
t.Error("list created in store, but nothing in list")
}
@ -289,17 +292,19 @@ func TestFindNode(t *testing.T) {
if err != nil {
t.Fatal(err)
}
dht.conn = conn
go dht.listen()
go dht.runHandler()
err = dht.node.Connect(conn)
if err != nil {
t.Fatal(err)
}
defer dht.Shutdown()
nodesToInsert := 3
var nodes []Node
var nodes []Contact
for i := 0; i < nodesToInsert; i++ {
n := Node{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
n := Contact{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
nodes = append(nodes, n)
dht.rt.Update(n)
dht.node.rt.Update(n)
}
messageID := newMessageID()
@ -357,17 +362,18 @@ func TestFindValueExisting(t *testing.T) {
t.Fatal(err)
}
dht.conn = conn
go dht.listen()
go dht.runHandler()
err = dht.node.Connect(conn)
if err != nil {
t.Fatal(err)
}
defer dht.Shutdown()
nodesToInsert := 3
var nodes []Node
var nodes []Contact
for i := 0; i < nodesToInsert; i++ {
n := Node{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
n := Contact{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
nodes = append(nodes, n)
dht.rt.Update(n)
dht.node.rt.Update(n)
}
//data, _ := hex.DecodeString("64313a30693065313a3132303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33393a66696e6456616c7565313a346c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565")
@ -375,10 +381,10 @@ func TestFindValueExisting(t *testing.T) {
messageID := newMessageID()
valueToFind := RandomBitmapP()
nodeToFind := Node{id: RandomBitmapP(), ip: net.ParseIP("1.2.3.4"), port: 1286}
dht.store.Upsert(valueToFind, nodeToFind)
dht.store.Upsert(valueToFind, nodeToFind)
dht.store.Upsert(valueToFind, nodeToFind)
nodeToFind := Contact{id: RandomBitmapP(), ip: net.ParseIP("1.2.3.4"), port: 1286}
dht.node.store.Upsert(valueToFind, nodeToFind)
dht.node.store.Upsert(valueToFind, nodeToFind)
dht.node.store.Upsert(valueToFind, nodeToFind)
request := Request{
ID: messageID,
@ -428,7 +434,7 @@ func TestFindValueExisting(t *testing.T) {
t.Fatal("search results are not a list")
}
verifyCompactContacts(t, contacts, []Node{nodeToFind})
verifyCompactContacts(t, contacts, []Contact{nodeToFind})
}
func TestFindValueFallbackToFindNode(t *testing.T) {
@ -442,17 +448,18 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
t.Fatal(err)
}
dht.conn = conn
go dht.listen()
go dht.runHandler()
err = dht.node.Connect(conn)
if err != nil {
t.Fatal(err)
}
defer dht.Shutdown()
nodesToInsert := 3
var nodes []Node
var nodes []Contact
for i := 0; i < nodesToInsert; i++ {
n := Node{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
n := Contact{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
nodes = append(nodes, n)
dht.rt.Update(n)
dht.node.rt.Update(n)
}
messageID := newMessageID()
@ -557,7 +564,7 @@ func verifyResponse(t *testing.T, resp map[string]interface{}, id messageID, dht
}
}
func verifyContacts(t *testing.T, contacts []interface{}, nodes []Node) {
func verifyContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
if len(contacts) != len(nodes) {
t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
return
@ -577,7 +584,7 @@ func verifyContacts(t *testing.T, contacts []interface{}, nodes []Node) {
return
}
var currNode Node
var currNode Contact
currNodeFound := false
id, ok := contact[0].(string)
@ -618,7 +625,7 @@ func verifyContacts(t *testing.T, contacts []interface{}, nodes []Node) {
}
}
func verifyCompactContacts(t *testing.T, contacts []interface{}, nodes []Node) {
func verifyCompactContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
if len(contacts) != len(nodes) {
t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
return
@ -633,14 +640,14 @@ func verifyCompactContacts(t *testing.T, contacts []interface{}, nodes []Node) {
return
}
contact := Node{}
contact := Contact{}
err := contact.UnmarshalCompact([]byte(compact))
if err != nil {
t.Error(err)
return
}
var currNode Node
var currNode Contact
currNodeFound := false
if _, ok := foundNodes[contact.id.Hex()]; ok {

View file

@ -14,34 +14,33 @@ import (
"github.com/lyoshenka/bencode"
)
type Node struct {
type Contact struct {
id Bitmap
ip net.IP
port int
token string // this is set when the node is returned from a FindNode call
}
func (n Node) String() string {
return n.id.HexShort() + "@" + n.Addr().String()
func (c Contact) Addr() *net.UDPAddr {
return &net.UDPAddr{IP: c.ip, Port: c.port}
}
func (n Node) Addr() *net.UDPAddr {
return &net.UDPAddr{IP: n.ip, Port: n.port}
func (c Contact) String() string {
return c.id.HexShort() + "@" + c.Addr().String()
}
func (n Node) MarshalCompact() ([]byte, error) {
if n.ip.To4() == nil {
func (c Contact) MarshalCompact() ([]byte, error) {
if c.ip.To4() == nil {
return nil, errors.Err("ip not set")
}
if n.port < 0 || n.port > 65535 {
if c.port < 0 || c.port > 65535 {
return nil, errors.Err("invalid port")
}
var buf bytes.Buffer
buf.Write(n.ip.To4())
buf.WriteByte(byte(n.port >> 8))
buf.WriteByte(byte(n.port))
buf.Write(n.id[:])
buf.Write(c.ip.To4())
buf.WriteByte(byte(c.port >> 8))
buf.WriteByte(byte(c.port))
buf.Write(c.id[:])
if buf.Len() != compactNodeInfoLength {
return nil, errors.Err("i dont know how this happened")
@ -50,21 +49,21 @@ func (n Node) MarshalCompact() ([]byte, error) {
return buf.Bytes(), nil
}
func (n *Node) UnmarshalCompact(b []byte) error {
func (c *Contact) UnmarshalCompact(b []byte) error {
if len(b) != compactNodeInfoLength {
return errors.Err("invalid compact length")
}
n.ip = net.IPv4(b[0], b[1], b[2], b[3]).To4()
n.port = int(uint16(b[5]) | uint16(b[4])<<8)
n.id = BitmapFromBytesP(b[6:])
c.ip = net.IPv4(b[0], b[1], b[2], b[3]).To4()
c.port = int(uint16(b[5]) | uint16(b[4])<<8)
c.id = BitmapFromBytesP(b[6:])
return nil
}
func (n Node) MarshalBencode() ([]byte, error) {
return bencode.EncodeBytes([]interface{}{n.id, n.ip.String(), n.port})
func (c Contact) MarshalBencode() ([]byte, error) {
return bencode.EncodeBytes([]interface{}{c.id, c.ip.String(), c.port})
}
func (n *Node) UnmarshalBencode(b []byte) error {
func (c *Contact) UnmarshalBencode(b []byte) error {
var raw []bencode.RawMessage
err := bencode.DecodeBytes(b, &raw)
if err != nil {
@ -75,7 +74,7 @@ func (n *Node) UnmarshalBencode(b []byte) error {
return errors.Err("contact must have 3 elements; got %d", len(raw))
}
err = bencode.DecodeBytes(raw[0], &n.id)
err = bencode.DecodeBytes(raw[0], &c.id)
if err != nil {
return err
}
@ -85,12 +84,12 @@ func (n *Node) UnmarshalBencode(b []byte) error {
if err != nil {
return err
}
n.ip = net.ParseIP(ipStr).To4()
if n.ip == nil {
c.ip = net.ParseIP(ipStr).To4()
if c.ip == nil {
return errors.Err("invalid IP")
}
err = bencode.DecodeBytes(raw[2], &n.port)
err = bencode.DecodeBytes(raw[2], &c.port)
if err != nil {
return err
}
@ -98,12 +97,12 @@ func (n *Node) UnmarshalBencode(b []byte) error {
return nil
}
type sortedNode struct {
node Node
type sortedContact struct {
contact Contact
xorDistanceToTarget Bitmap
}
type byXorDistance []sortedNode
type byXorDistance []sortedContact
func (a byXorDistance) Len() int { return len(a) }
func (a byXorDistance) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
@ -112,17 +111,17 @@ func (a byXorDistance) Less(i, j int) bool {
}
type routingTable struct {
node Node
id Bitmap
buckets [numBuckets]*list.List
lock *sync.RWMutex
}
func newRoutingTable(node *Node) *routingTable {
func newRoutingTable(id Bitmap) *routingTable {
var rt routingTable
for i := range rt.buckets {
rt.buckets[i] = list.New()
}
rt.node = *node
rt.id = id
rt.lock = &sync.RWMutex{}
return &rt
}
@ -131,7 +130,7 @@ func (rt *routingTable) BucketInfo() string {
rt.lock.RLock()
defer rt.lock.RUnlock()
bucketInfo := []string{}
var bucketInfo []string
for i, b := range rt.buckets {
contents := bucketContents(b)
if contents != "" {
@ -152,7 +151,7 @@ func bucketContents(b *list.List) string {
if ids != "" {
ids += ", "
}
ids += curr.Value.(Node).id.HexShort()
ids += curr.Value.(Contact).id.HexShort()
}
if count > 0 {
@ -162,31 +161,31 @@ func bucketContents(b *list.List) string {
}
}
// Update inserts or refreshes a node
func (rt *routingTable) Update(node Node) {
// Update inserts or refreshes a contact
func (rt *routingTable) Update(c Contact) {
rt.lock.Lock()
defer rt.lock.Unlock()
bucketNum := bucketFor(rt.node.id, node.id)
bucketNum := bucketFor(rt.id, c.id)
bucket := rt.buckets[bucketNum]
element := findInList(bucket, node.id)
element := findInList(bucket, c.id)
if element == nil {
if bucket.Len() >= bucketSize {
// TODO: Ping front node first. Only remove if it does not respond
// TODO: Ping front contact first. Only remove if it does not respond
bucket.Remove(bucket.Front())
}
bucket.PushBack(node)
bucket.PushBack(c)
} else {
bucket.MoveToBack(element)
}
}
// UpdateIfExists refreshes a node if its already in the routing table
func (rt *routingTable) UpdateIfExists(node Node) {
// UpdateIfExists refreshes a contact if its already in the routing table
func (rt *routingTable) UpdateIfExists(c Contact) {
rt.lock.Lock()
defer rt.lock.Unlock()
bucketNum := bucketFor(rt.node.id, node.id)
bucketNum := bucketFor(rt.id, c.id)
bucket := rt.buckets[bucketNum]
element := findInList(bucket, node.id)
element := findInList(bucket, c.id)
if element != nil {
bucket.MoveToBack(element)
}
@ -195,55 +194,55 @@ func (rt *routingTable) UpdateIfExists(node Node) {
func (rt *routingTable) RemoveByID(id Bitmap) {
rt.lock.Lock()
defer rt.lock.Unlock()
bucketNum := bucketFor(rt.node.id, id)
bucketNum := bucketFor(rt.id, id)
bucket := rt.buckets[bucketNum]
element := findInList(bucket, rt.node.id)
element := findInList(bucket, rt.id)
if element != nil {
bucket.Remove(element)
}
}
func (rt *routingTable) GetClosest(target Bitmap, limit int) []Node {
func (rt *routingTable) GetClosest(target Bitmap, limit int) []Contact {
rt.lock.RLock()
defer rt.lock.RUnlock()
var toSort []sortedNode
var toSort []sortedContact
var bucketNum int
if rt.node.id.Equals(target) {
if rt.id.Equals(target) {
bucketNum = 0
} else {
bucketNum = bucketFor(rt.node.id, target)
bucketNum = bucketFor(rt.id, target)
}
bucket := rt.buckets[bucketNum]
toSort = appendNodes(toSort, bucket.Front(), target)
toSort = appendContacts(toSort, bucket.Front(), target)
for i := 1; (bucketNum-i >= 0 || bucketNum+i < numBuckets) && len(toSort) < limit; i++ {
if bucketNum-i >= 0 {
bucket = rt.buckets[bucketNum-i]
toSort = appendNodes(toSort, bucket.Front(), target)
toSort = appendContacts(toSort, bucket.Front(), target)
}
if bucketNum+i < numBuckets {
bucket = rt.buckets[bucketNum+i]
toSort = appendNodes(toSort, bucket.Front(), target)
toSort = appendContacts(toSort, bucket.Front(), target)
}
}
sort.Sort(byXorDistance(toSort))
var nodes []Node
for _, c := range toSort {
nodes = append(nodes, c.node)
if len(nodes) >= limit {
var contacts []Contact
for _, sorted := range toSort {
contacts = append(contacts, sorted.contact)
if len(contacts) >= limit {
break
}
}
return nodes
return contacts
}
// Count returns the number of nodes in the routing table
// Count returns the number of contacts in the routing table
func (rt *routingTable) Count() int {
rt.lock.RLock()
defer rt.lock.RUnlock()
@ -258,38 +257,24 @@ func (rt *routingTable) Count() int {
func findInList(bucket *list.List, value Bitmap) *list.Element {
for curr := bucket.Front(); curr != nil; curr = curr.Next() {
if curr.Value.(Node).id.Equals(value) {
if curr.Value.(Contact).id.Equals(value) {
return curr
}
}
return nil
}
func appendNodes(nodes []sortedNode, start *list.Element, target Bitmap) []sortedNode {
func appendContacts(contacts []sortedContact, start *list.Element, target Bitmap) []sortedContact {
for curr := start; curr != nil; curr = curr.Next() {
node := curr.Value.(Node)
nodes = append(nodes, sortedNode{node, node.id.Xor(target)})
c := curr.Value.(Contact)
contacts = append(contacts, sortedContact{c, c.id.Xor(target)})
}
return nodes
return contacts
}
func bucketFor(id Bitmap, target Bitmap) int {
if id.Equals(target) {
panic("nodes do not have a bucket for themselves")
panic("routing table does not have a bucket for its own id")
}
return numBuckets - 1 - target.Xor(id).PrefixLen()
}
func sortNodesInPlace(nodes []Node, target Bitmap) {
toSort := make([]sortedNode, len(nodes))
for i, n := range nodes {
toSort[i] = sortedNode{n, n.id.Xor(target)}
}
sort.Sort(byXorDistance(toSort))
for i, c := range toSort {
nodes[i] = c.node
}
}

View file

@ -36,9 +36,9 @@ func TestRoutingTable(t *testing.T) {
n1 := BitmapFromHexP("FFFFFFFF0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
n2 := BitmapFromHexP("FFFFFFF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
n3 := BitmapFromHexP("111111110000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
rt := newRoutingTable(&Node{n1, net.ParseIP("127.0.0.1"), 8000, ""})
rt.Update(Node{n2, net.ParseIP("127.0.0.1"), 8001, ""})
rt.Update(Node{n3, net.ParseIP("127.0.0.1"), 8002, ""})
rt := newRoutingTable(n1)
rt.Update(Contact{n2, net.ParseIP("127.0.0.1"), 8001})
rt.Update(Contact{n3, net.ParseIP("127.0.0.1"), 8002})
contacts := rt.GetClosest(BitmapFromHexP("222222220000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 1)
if len(contacts) != 1 {
@ -63,14 +63,14 @@ func TestRoutingTable(t *testing.T) {
}
func TestCompactEncoding(t *testing.T) {
n := Node{
c := Contact{
id: BitmapFromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41"),
ip: net.ParseIP("1.2.3.4"),
port: int(55<<8 + 66),
}
var compact []byte
compact, err := n.MarshalCompact()
compact, err := c.MarshalCompact()
if err != nil {
t.Fatal(err)
}
@ -79,7 +79,7 @@ func TestCompactEncoding(t *testing.T) {
t.Fatalf("got length of %d; expected %d", len(compact), compactNodeInfoLength)
}
if !reflect.DeepEqual(compact, append([]byte{1, 2, 3, 4, 55, 66}, n.id[:]...)) {
if !reflect.DeepEqual(compact, append([]byte{1, 2, 3, 4, 55, 66}, c.id[:]...)) {
t.Errorf("compact bytes not encoded correctly")
}
}

View file

@ -1,178 +0,0 @@
package dht
import (
"encoding/hex"
"net"
"time"
"github.com/lbryio/errors.go"
"github.com/lbryio/lbry.go/util"
"github.com/davecgh/go-spew/spew"
"github.com/lyoshenka/bencode"
log "github.com/sirupsen/logrus"
)
// handlePacket handles packets received from udp.
func handlePacket(dht *DHT, pkt packet) {
//log.Debugf("[%s] Received message from %s (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.String(), len(pkt.data), hex.EncodeToString(pkt.data))
if !util.InSlice(string(pkt.data[0:5]), []string{"d1:0i", "di0ei"}) {
log.Errorf("[%s] data is not a well-formatted dict: (%d bytes) %s", dht.node.id.HexShort(), len(pkt.data), hex.EncodeToString(pkt.data))
return
}
// TODO: test this stuff more thoroughly
// the following is a bit of a hack, but it lets us avoid decoding every message twice
// it depends on the data being a dict with 0 as the first key (so it starts with "d1:0i") and the message type as the first value
switch pkt.data[5] {
case '0' + requestType:
request := Request{}
err := bencode.DecodeBytes(pkt.data, &request)
if err != nil {
log.Errorf("[%s] error decoding request from %s: %s: (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
return
}
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, request.ArgsDebug())
handleRequest(dht, pkt.raddr, request)
case '0' + responseType:
response := Response{}
err := bencode.DecodeBytes(pkt.data, &response)
if err != nil {
log.Errorf("[%s] error decoding response from %s: %s: (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
return
}
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), response.ID.HexShort(), response.NodeID.HexShort(), response.ArgsDebug())
handleResponse(dht, pkt.raddr, response)
case '0' + errorType:
e := Error{}
err := bencode.DecodeBytes(pkt.data, &e)
if err != nil {
log.Errorf("[%s] error decoding error from %s: %s: (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data))
return
}
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), e.ID.HexShort(), e.NodeID.HexShort(), e.ExceptionType)
handleError(dht, pkt.raddr, e)
default:
log.Errorf("[%s] invalid message type: %s", dht.node.id.HexShort(), pkt.data[5])
return
}
}
// handleRequest handles the requests received from udp.
func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
if request.NodeID.Equals(dht.node.id) {
log.Warn("ignoring self-request")
return
}
switch request.Method {
default:
// send(dht, addr, makeError(t, protocolError, "invalid q"))
log.Errorln("invalid request method")
return
case pingMethod:
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: pingSuccessResponse})
case storeMethod:
// TODO: we should be sending the IP in the request, not just using the sender's IP
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
if dht.tokens.Verify(request.StoreArgs.Value.Token, request.NodeID, addr) {
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse})
} else {
send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id, ExceptionType: "invalid-token"})
}
case findNodeMethod:
if request.Arg == nil {
log.Errorln("request is missing arg")
return
}
send(dht, addr, getFindResponse(dht, request))
case findValueMethod:
if request.Arg == nil {
log.Errorln("request is missing arg")
return
}
if nodes := dht.store.Get(*request.Arg); len(nodes) > 0 {
send(dht, addr, Response{
ID: request.ID,
NodeID: dht.node.id,
FindValueKey: request.Arg.RawString(),
FindNodeData: nodes,
Token: dht.tokens.Get(request.NodeID, addr),
})
} else {
res := getFindResponse(dht, request)
res.Token = dht.tokens.Get(request.NodeID, addr)
send(dht, addr, res)
}
}
// nodes that send us requests should not be inserted, only refreshed.
// the routing table must only contain "good" nodes, which are nodes that reply to our requests
// if a node is already good (aka in the table), its fine to refresh it
// http://www.bittorrent.org/beps/bep_0005.html#routing-table
node := Node{id: request.NodeID, ip: addr.IP, port: addr.Port}
dht.rt.UpdateIfExists(node)
}
func getFindResponse(dht *DHT, request Request) Response {
closestNodes := dht.rt.GetClosest(*request.Arg, bucketSize)
response := Response{
ID: request.ID,
NodeID: dht.node.id,
FindNodeData: make([]Node, len(closestNodes)),
}
for i, n := range closestNodes {
response.FindNodeData[i] = n
}
return response
}
// handleResponse handles responses received from udp.
func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) {
tx := dht.tm.Find(response.ID, addr)
if tx != nil {
tx.res <- response
}
node := Node{id: response.NodeID, ip: addr.IP, port: addr.Port}
dht.rt.Update(node)
}
// handleError handles errors received from udp.
func handleError(dht *DHT, addr *net.UDPAddr, e Error) {
spew.Dump(e)
node := Node{id: e.NodeID, ip: addr.IP, port: addr.Port}
dht.rt.UpdateIfExists(node)
}
// send sends data to a udp address
func send(dht *DHT, addr *net.UDPAddr, data Message) error {
encoded, err := bencode.EncodeBytes(data)
if err != nil {
return errors.Err(err)
}
if req, ok := data.(Request); ok {
log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)",
dht.node.id.HexShort(), req.ID.HexShort(), addr.String(), len(encoded), req.Method, req.ArgsDebug())
} else if res, ok := data.(Response); ok {
log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s",
dht.node.id.HexShort(), res.ID.HexShort(), addr.String(), len(encoded), res.ArgsDebug())
} else {
log.Debugf("[%s] (%d bytes) %s", dht.node.id.HexShort(), len(encoded), spew.Sdump(data))
}
dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))
_, err = dht.conn.WriteToUDP(encoded, addr)
return errors.Err(err)
}

View file

@ -3,7 +3,7 @@ package dht
import "sync"
type peer struct {
node Node
contact Contact
//<lastPublished>,
//<originallyPublished>
// <originalPublisherID>
@ -12,42 +12,48 @@ type peer struct {
type peerStore struct {
// map of blob hashes to (map of node IDs to bools)
hashes map[Bitmap]map[Bitmap]bool
// map of node IDs to peers
nodeInfo map[Bitmap]peer
// stores the peers themselves, so they can be updated in one place
peers map[Bitmap]peer
lock sync.RWMutex
}
func newPeerStore() *peerStore {
return &peerStore{
hashes: make(map[Bitmap]map[Bitmap]bool),
nodeInfo: make(map[Bitmap]peer),
peers: make(map[Bitmap]peer),
}
}
func (s *peerStore) Upsert(blobHash Bitmap, node Node) {
func (s *peerStore) Upsert(blobHash Bitmap, contact Contact) {
s.lock.Lock()
defer s.lock.Unlock()
if _, ok := s.hashes[blobHash]; !ok {
s.hashes[blobHash] = make(map[Bitmap]bool)
}
s.hashes[blobHash][node.id] = true
s.nodeInfo[node.id] = peer{node: node}
s.hashes[blobHash][contact.id] = true
s.peers[contact.id] = peer{contact: contact}
}
func (s *peerStore) Get(blobHash Bitmap) []Node {
func (s *peerStore) Get(blobHash Bitmap) []Contact {
s.lock.RLock()
defer s.lock.RUnlock()
var nodes []Node
var contacts []Contact
if ids, ok := s.hashes[blobHash]; ok {
for id := range ids {
peer, ok := s.nodeInfo[id]
peer, ok := s.peers[id]
if !ok {
panic("node id in IDs list, but not in nodeInfo")
}
nodes = append(nodes, peer.node)
contacts = append(contacts, peer.contact)
}
}
return nodes
return contacts
}
func (s *peerStore) RemoveTODO(contact Contact) {
// TODO: remove peer from everywhere
}
func (s *peerStore) CountStoredHashes() int {

View file

@ -1,125 +0,0 @@
package dht
import (
"context"
"net"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// transaction represents a single query to the dht. it stores the queried node, the request, and the response channel
type transaction struct {
node Node
req Request
res chan Response
}
// transactionManager keeps track of the outstanding transactions
type transactionManager struct {
dht *DHT
lock *sync.RWMutex
transactions map[messageID]*transaction
}
// newTransactionManager returns a new transactionManager
func newTransactionManager(dht *DHT) *transactionManager {
return &transactionManager{
lock: &sync.RWMutex{},
transactions: make(map[messageID]*transaction),
dht: dht,
}
}
// insert adds a transaction to the manager.
func (tm *transactionManager) insert(tx *transaction) {
tm.lock.Lock()
defer tm.lock.Unlock()
tm.transactions[tx.req.ID] = tx
}
// delete removes a transaction from the manager.
func (tm *transactionManager) delete(id messageID) {
tm.lock.Lock()
defer tm.lock.Unlock()
delete(tm.transactions, id)
}
// Find finds a transaction for the given id. it optionally ensures that addr matches node from transaction
func (tm *transactionManager) Find(id messageID, addr *net.UDPAddr) *transaction {
tm.lock.RLock()
defer tm.lock.RUnlock()
// TODO: also check that the response's nodeid matches the id you thought you sent to?
t, ok := tm.transactions[id]
if !ok || (addr != nil && t.node.Addr().String() != addr.String()) {
return nil
}
return t
}
// SendAsync sends a transaction and returns a channel that will eventually contain the transaction response
// The response channel is closed when the transaction is completed or times out.
func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req Request) <-chan *Response {
if node.id.Equals(tm.dht.node.id) {
log.Error("sending query to self")
return nil
}
ch := make(chan *Response, 1)
go func() {
defer close(ch)
req.ID = newMessageID()
req.NodeID = tm.dht.node.id
tx := &transaction{
node: node,
req: req,
res: make(chan Response),
}
tm.insert(tx)
defer tm.delete(tx.req.ID)
for i := 0; i < udpRetry; i++ {
if err := send(tm.dht, node.Addr(), tx.req); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") { // this only happens on localhost. real UDP has no connections
log.Error("send error: ", err)
}
continue // try again? return?
}
select {
case res := <-tx.res:
ch <- &res
return
case <-ctx.Done():
return
case <-time.After(udpTimeout):
}
}
// if request timed out each time
tm.dht.rt.RemoveByID(tx.node.id)
}()
return ch
}
// Send sends a transaction and blocks until the response is available. It returns a response, or nil
// if the transaction timed out.
func (tm *transactionManager) Send(node Node, req Request) *Response {
return <-tm.SendAsync(context.Background(), node, req)
}
// Count returns the number of transactions in the manager
func (tm *transactionManager) Count() int {
tm.lock.Lock()
defer tm.lock.Unlock()
return len(tm.transactions)
}