add transaction manager, fix bencoding to support int keys, fix routing table bucketing
This commit is contained in:
parent
883d76d8bb
commit
05e2d8529a
12 changed files with 664 additions and 330 deletions
|
@ -1,11 +1,11 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
|
||||
"github.com/zeebo/bencode"
|
||||
"github.com/lyoshenka/bencode"
|
||||
)
|
||||
|
||||
type bitmap [nodeIDLength]byte
|
||||
|
@ -45,7 +45,7 @@ func (b bitmap) Xor(other bitmap) bitmap {
|
|||
}
|
||||
|
||||
// PrefixLen returns the number of leading 0 bits
|
||||
func (b bitmap) PrefixLen() (ret int) {
|
||||
func (b bitmap) PrefixLen() int {
|
||||
for i := range b {
|
||||
for j := 0; j < 8; j++ {
|
||||
if (b[i]>>uint8(7-j))&0x1 != 0 {
|
||||
|
@ -95,8 +95,9 @@ func newBitmapFromHex(hexStr string) bitmap {
|
|||
|
||||
func newRandomBitmap() bitmap {
|
||||
var id bitmap
|
||||
for k := range id {
|
||||
id[k] = uint8(rand.Intn(256))
|
||||
_, err := rand.Read(id[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ package dht
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/zeebo/bencode"
|
||||
"github.com/lyoshenka/bencode"
|
||||
)
|
||||
|
||||
func TestBitmap(t *testing.T) {
|
||||
|
|
60
dht/conn.go
60
dht/conn.go
|
@ -1,60 +0,0 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UDPConn interface {
|
||||
ReadFromUDP([]byte) (int, *net.UDPAddr, error)
|
||||
WriteToUDP([]byte, *net.UDPAddr) (int, error)
|
||||
SetWriteDeadline(time.Time) error
|
||||
}
|
||||
|
||||
type testUDPPacket struct {
|
||||
data []byte
|
||||
addr *net.UDPAddr
|
||||
}
|
||||
|
||||
type testUDPConn struct {
|
||||
addr *net.UDPAddr
|
||||
toRead chan testUDPPacket
|
||||
writes chan testUDPPacket
|
||||
}
|
||||
|
||||
func newTestUDPConn(addr string) *testUDPConn {
|
||||
parts := strings.Split(addr, ":")
|
||||
if len(parts) != 2 {
|
||||
panic("addr needs ip and port")
|
||||
}
|
||||
port, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &testUDPConn{
|
||||
addr: &net.UDPAddr{IP: net.IP(parts[0]), Port: port},
|
||||
toRead: make(chan testUDPPacket),
|
||||
writes: make(chan testUDPPacket),
|
||||
}
|
||||
}
|
||||
|
||||
func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
|
||||
select {
|
||||
case packet := <-t.toRead:
|
||||
n := copy(b, packet.data)
|
||||
return n, packet.addr, nil
|
||||
//default:
|
||||
// return 0, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
|
||||
t.writes <- testUDPPacket{data: b, addr: addr}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (t testUDPConn) SetWriteDeadline(tm time.Time) error {
|
||||
return nil
|
||||
}
|
|
@ -5,7 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/zeebo/bencode"
|
||||
"github.com/lyoshenka/bencode"
|
||||
)
|
||||
|
||||
func TestDecode(t *testing.T) {
|
||||
|
|
342
dht/dht.go
342
dht/dht.go
|
@ -1,23 +1,24 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lbryio/errors.go"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cast"
|
||||
"github.com/zeebo/bencode"
|
||||
)
|
||||
|
||||
const network = "udp4"
|
||||
|
||||
const alpha = 3 // this is the constant alpha in the spec
|
||||
const nodeIDLength = 48 // bytes. this is the constant B in the spec
|
||||
const bucketSize = 8 // this is the constant k in the spec
|
||||
const alpha = 3 // this is the constant alpha in the spec
|
||||
const nodeIDLength = 48 // bytes. this is the constant B in the spec
|
||||
const messageIDLength = 20 // bytes.
|
||||
const bucketSize = 8 // this is the constant k in the spec
|
||||
|
||||
const udpRetry = 3
|
||||
const udpTimeout = 10 * time.Second
|
||||
|
||||
const tExpire = 86400 * time.Second // the time after which a key/value pair expires; this is a time-to-live (TTL) from the original publication date
|
||||
const tRefresh = 3600 * time.Second // the time after which an otherwise unaccessed bucket must be refreshed
|
||||
|
@ -41,6 +42,8 @@ type Config struct {
|
|||
SeedNodes []string
|
||||
// the hex-encoded node id for this node. if string is empty, a random id will be generated
|
||||
NodeID string
|
||||
// print the state of the dht every minute
|
||||
PrintState bool
|
||||
}
|
||||
|
||||
// NewStandardConfig returns a Config pointer with default values.
|
||||
|
@ -55,18 +58,26 @@ func NewStandardConfig() *Config {
|
|||
}
|
||||
}
|
||||
|
||||
// UDPConn allows using a mocked connection for testing sending/receiving data
|
||||
type UDPConn interface {
|
||||
ReadFromUDP([]byte) (int, *net.UDPAddr, error)
|
||||
WriteToUDP([]byte, *net.UDPAddr) (int, error)
|
||||
SetWriteDeadline(time.Time) error
|
||||
}
|
||||
|
||||
// DHT represents a DHT node.
|
||||
type DHT struct {
|
||||
conf *Config
|
||||
conn UDPConn
|
||||
node *Node
|
||||
routingTable *RoutingTable
|
||||
packets chan packet
|
||||
store *peerStore
|
||||
conf *Config
|
||||
conn UDPConn
|
||||
node *Node
|
||||
rt *RoutingTable
|
||||
packets chan packet
|
||||
store *peerStore
|
||||
tm *transactionManager
|
||||
}
|
||||
|
||||
// New returns a DHT pointer. If config is nil, then config will be set to the default config.
|
||||
func New(config *Config) *DHT {
|
||||
func New(config *Config) (*DHT, error) {
|
||||
if config == nil {
|
||||
config = NewStandardConfig()
|
||||
}
|
||||
|
@ -80,41 +91,51 @@ func New(config *Config) *DHT {
|
|||
|
||||
ip, port, err := net.SplitHostPort(config.Address)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return nil, errors.Err(err)
|
||||
} else if ip == "" {
|
||||
panic("address does not contain an IP")
|
||||
return nil, errors.Err("address does not contain an IP")
|
||||
} else if port == "" {
|
||||
panic("address does not contain a port")
|
||||
return nil, errors.Err("address does not contain a port")
|
||||
}
|
||||
|
||||
portInt, err := cast.ToIntE(port)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return nil, errors.Err(err)
|
||||
}
|
||||
|
||||
node := &Node{id: id, ip: net.ParseIP(ip), port: portInt}
|
||||
if node.ip == nil {
|
||||
panic("invalid ip")
|
||||
return nil, errors.Err("invalid ip")
|
||||
}
|
||||
return &DHT{
|
||||
conf: config,
|
||||
node: node,
|
||||
routingTable: newRoutingTable(node),
|
||||
packets: make(chan packet),
|
||||
store: newPeerStore(),
|
||||
|
||||
d := &DHT{
|
||||
conf: config,
|
||||
node: node,
|
||||
rt: newRoutingTable(node),
|
||||
packets: make(chan packet),
|
||||
store: newPeerStore(),
|
||||
}
|
||||
d.tm = newTransactionManager(d)
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// init initializes global variables.
|
||||
func (dht *DHT) init() {
|
||||
func (dht *DHT) init() error {
|
||||
log.Info("Initializing DHT on " + dht.conf.Address)
|
||||
log.Infof("Node ID is %s", dht.node.id.Hex())
|
||||
|
||||
listener, err := net.ListenPacket(network, dht.conf.Address)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return errors.Err(err)
|
||||
}
|
||||
|
||||
dht.conn = listener.(*net.UDPConn)
|
||||
|
||||
if dht.conf.PrintState {
|
||||
go printState(dht)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// listen receives message from udp.
|
||||
|
@ -159,201 +180,98 @@ func (dht *DHT) runHandler() {
|
|||
for {
|
||||
select {
|
||||
case pkt = <-dht.packets:
|
||||
handle(dht, pkt)
|
||||
handlePacket(dht, pkt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the dht.
|
||||
func (dht *DHT) Run() {
|
||||
dht.init()
|
||||
func (dht *DHT) Run() error {
|
||||
err := dht.init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dht.listen()
|
||||
dht.join()
|
||||
log.Info("DHT ready")
|
||||
dht.runHandler()
|
||||
return nil
|
||||
}
|
||||
|
||||
// handle handles packets received from udp.
|
||||
func handle(dht *DHT, pkt packet) {
|
||||
//log.Infof("Received message from %s:%s : %s\n", pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), hex.EncodeToString(pkt.data))
|
||||
|
||||
var data map[string]interface{}
|
||||
err := bencode.DecodeBytes(pkt.data, &data)
|
||||
if err != nil {
|
||||
log.Errorf("error decoding data: %s\n%s", err, pkt.data)
|
||||
return
|
||||
}
|
||||
|
||||
msgType, ok := data[headerTypeField]
|
||||
if !ok {
|
||||
log.Errorf("decoded data has no message type: %s", data)
|
||||
return
|
||||
}
|
||||
|
||||
switch msgType.(int64) {
|
||||
case requestType:
|
||||
request := Request{}
|
||||
err = bencode.DecodeBytes(pkt.data, &request)
|
||||
if err != nil {
|
||||
log.Errorln(err)
|
||||
return
|
||||
}
|
||||
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args))
|
||||
handleRequest(dht, pkt.raddr, request)
|
||||
|
||||
case responseType:
|
||||
response := Response{}
|
||||
err = bencode.DecodeBytes(pkt.data, &response)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.Data)
|
||||
handleResponse(dht, pkt.raddr, response)
|
||||
|
||||
case errorType:
|
||||
e := Error{
|
||||
ID: data[headerMessageIDField].(string),
|
||||
NodeID: data[headerNodeIDField].(string),
|
||||
ExceptionType: data[headerPayloadField].(string),
|
||||
Response: getArgs(data[headerArgsField]),
|
||||
}
|
||||
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType)
|
||||
handleError(dht, pkt.raddr, e)
|
||||
|
||||
default:
|
||||
log.Errorf("Invalid message type: %s", msgType)
|
||||
return
|
||||
func printState(dht *DHT) {
|
||||
t := time.NewTicker(60 * time.Second)
|
||||
for {
|
||||
log.Printf("DHT state at %s", time.Now().Format(time.RFC822Z))
|
||||
log.Printf("Outstanding transactions: %d", dht.tm.Count())
|
||||
log.Printf("Known nodes: %d", dht.store.CountKnownNodes())
|
||||
log.Printf("Buckets: \n%s", dht.rt.BucketInfo())
|
||||
<-t.C
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest handles the requests received from udp.
|
||||
func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||
if request.NodeID == dht.node.id.RawString() {
|
||||
log.Warn("ignoring self-request")
|
||||
return
|
||||
}
|
||||
|
||||
switch request.Method {
|
||||
case pingMethod:
|
||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: pingSuccessResponse})
|
||||
case storeMethod:
|
||||
if request.StoreArgs.BlobHash == "" {
|
||||
log.Errorln("blobhash is empty")
|
||||
return // nothing to store
|
||||
}
|
||||
// TODO: we should be sending the IP in the request, not just using the sender's IP
|
||||
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
|
||||
dht.store.Insert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
|
||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse})
|
||||
case findNodeMethod:
|
||||
log.Println("findnode")
|
||||
if len(request.Args) < 1 {
|
||||
log.Errorln("nothing to find")
|
||||
return
|
||||
}
|
||||
if len(request.Args[0]) != nodeIDLength {
|
||||
log.Errorln("invalid node id")
|
||||
return
|
||||
}
|
||||
doFindNodes(dht, addr, request)
|
||||
case findValueMethod:
|
||||
log.Println("findvalue")
|
||||
if len(request.Args) < 1 {
|
||||
log.Errorln("nothing to find")
|
||||
return
|
||||
}
|
||||
if len(request.Args[0]) != nodeIDLength {
|
||||
log.Errorln("invalid node id")
|
||||
return
|
||||
}
|
||||
|
||||
if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 {
|
||||
response := Response{ID: request.ID, NodeID: dht.node.id.RawString()}
|
||||
response.FindValueKey = request.Args[0]
|
||||
response.FindNodeData = nodes
|
||||
send(dht, addr, response)
|
||||
} else {
|
||||
doFindNodes(dht, addr, request)
|
||||
}
|
||||
|
||||
default:
|
||||
// send(dht, addr, makeError(t, protocolError, "invalid q"))
|
||||
log.Errorln("invalid request method")
|
||||
return
|
||||
}
|
||||
|
||||
node := &Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port}
|
||||
dht.routingTable.Update(node)
|
||||
}
|
||||
|
||||
func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||
nodeID := newBitmapFromString(request.Args[0])
|
||||
closestNodes := dht.routingTable.FindClosest(nodeID, bucketSize)
|
||||
if len(closestNodes) > 0 {
|
||||
response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))}
|
||||
for i, n := range closestNodes {
|
||||
response.FindNodeData[i] = *n
|
||||
}
|
||||
send(dht, addr, response)
|
||||
}
|
||||
}
|
||||
|
||||
// handleResponse handles responses received from udp.
|
||||
func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) {
|
||||
spew.Dump(response)
|
||||
|
||||
// TODO: find transaction by message id, pass along response
|
||||
|
||||
node := &Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port}
|
||||
dht.routingTable.Update(node)
|
||||
}
|
||||
|
||||
// handleError handles errors received from udp.
|
||||
func handleError(dht *DHT, addr *net.UDPAddr, e Error) {
|
||||
spew.Dump(e)
|
||||
node := &Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port}
|
||||
dht.routingTable.Update(node)
|
||||
}
|
||||
|
||||
// send sends data to the udp.
|
||||
func send(dht *DHT, addr *net.UDPAddr, data Message) error {
|
||||
if req, ok := data.(Request); ok {
|
||||
log.Debugf("[%s] query %s: sending request: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(req.ID))[:8], req.Method, argsToString(req.Args))
|
||||
} else if res, ok := data.(Response); ok {
|
||||
log.Debugf("[%s] query %s: sending response: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(res.ID))[:8], spew.Sdump(res.Data))
|
||||
} else {
|
||||
log.Debugf("[%s] %s", spew.Sdump(data))
|
||||
}
|
||||
encoded, err := bencode.EncodeBytes(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//log.Infof("Encoded: %s", string(encoded))
|
||||
|
||||
dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))
|
||||
|
||||
_, err = dht.conn.WriteToUDP(encoded, addr)
|
||||
return err
|
||||
}
|
||||
|
||||
func getArgs(argsInt interface{}) []string {
|
||||
var args []string
|
||||
if reflect.TypeOf(argsInt).Kind() == reflect.Slice {
|
||||
v := reflect.ValueOf(argsInt)
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
args = append(args, cast.ToString(v.Index(i).Interface()))
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func argsToString(args []string) string {
|
||||
argsCopy := make([]string, len(args))
|
||||
copy(argsCopy, args)
|
||||
for k, v := range argsCopy {
|
||||
if len(v) == nodeIDLength {
|
||||
argsCopy[k] = hex.EncodeToString([]byte(v))[:8]
|
||||
}
|
||||
}
|
||||
return strings.Join(argsCopy, ", ")
|
||||
}
|
||||
//func (dht *DHT) Get(hash bitmap) ([]Node, error) {
|
||||
// return iterativeFindNode(dht, hash)
|
||||
//}
|
||||
//
|
||||
//func iterativeFindNode(dht *DHT, hash bitmap) ([]Node, error) {
|
||||
// shortlist := dht.rt.FindClosest(hash, alpha)
|
||||
// if len(shortlist) == 0 {
|
||||
// return nil, errors.Err("no nodes in routing table")
|
||||
// }
|
||||
//
|
||||
// pending := make(chan *Node)
|
||||
// contacted := make(map[bitmap]bool)
|
||||
// contactedMutex := &sync.Mutex{}
|
||||
// closestNodeMutex := &sync.Mutex{}
|
||||
// closestNode := shortlist[0]
|
||||
// wg := sync.WaitGroup{}
|
||||
//
|
||||
// for i := 0; i < alpha; i++ {
|
||||
// wg.Add(1)
|
||||
// go func() {
|
||||
// defer wg.Done()
|
||||
// for {
|
||||
// node, ok := <-pending
|
||||
// if !ok {
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// contactedMutex.Lock()
|
||||
// if _, ok := contacted[node.id]; ok {
|
||||
// contactedMutex.Unlock()
|
||||
// continue
|
||||
// }
|
||||
// contacted[node.id] = true
|
||||
// contactedMutex.Unlock()
|
||||
//
|
||||
// res := dht.tm.Send(node, &Request{
|
||||
// NodeID: dht.node.id.RawString(),
|
||||
// Method: findNodeMethod,
|
||||
// Args: []string{hash.RawString()},
|
||||
// })
|
||||
// if res == nil {
|
||||
// // remove node from shortlist
|
||||
// continue
|
||||
// }
|
||||
//
|
||||
// for _, n := range res.FindNodeData {
|
||||
// pending <- &n
|
||||
// closestNodeMutex.Lock()
|
||||
// if n.id.Xor(hash).Less(closestNode.id.Xor(hash)) {
|
||||
// closestNode = &n
|
||||
// }
|
||||
// closestNodeMutex.Unlock()
|
||||
// }
|
||||
// }
|
||||
// }()
|
||||
// }
|
||||
//
|
||||
// for _, n := range shortlist {
|
||||
// pending <- n
|
||||
// }
|
||||
//
|
||||
// wg.Wait()
|
||||
//
|
||||
// return nil, nil
|
||||
//}
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/lbryio/errors.go"
|
||||
|
||||
"github.com/lyoshenka/bencode"
|
||||
"github.com/spf13/cast"
|
||||
"github.com/zeebo/bencode"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -171,6 +173,21 @@ type Response struct {
|
|||
FindValueKey string
|
||||
}
|
||||
|
||||
func (r Response) ArgsDebug() string {
|
||||
if len(r.FindNodeData) == 0 {
|
||||
return r.Data
|
||||
}
|
||||
|
||||
str := "contacts "
|
||||
if r.FindValueKey != "" {
|
||||
str += "for " + hex.EncodeToString([]byte(r.FindValueKey))[:8] + " "
|
||||
}
|
||||
for _, c := range r.FindNodeData {
|
||||
str += c.Addr().String() + ":" + c.id.Hex()[:8] + ", "
|
||||
}
|
||||
return str[:len(str)-2] // chomp off last ", "
|
||||
}
|
||||
|
||||
func (r Response) MarshalBencode() ([]byte, error) {
|
||||
data := map[string]interface{}{
|
||||
headerTypeField: responseType,
|
||||
|
|
|
@ -7,8 +7,8 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lyoshenka/bencode"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/zeebo/bencode"
|
||||
)
|
||||
|
||||
func TestBencodeDecodeStoreArgs(t *testing.T) {
|
||||
|
|
|
@ -3,12 +3,15 @@ package dht
|
|||
import (
|
||||
"bytes"
|
||||
"container/list"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/lbryio/errors.go"
|
||||
|
||||
"github.com/zeebo/bencode"
|
||||
"github.com/lyoshenka/bencode"
|
||||
)
|
||||
|
||||
type Node struct {
|
||||
|
@ -17,6 +20,10 @@ type Node struct {
|
|||
port int
|
||||
}
|
||||
|
||||
func (n Node) Addr() *net.UDPAddr {
|
||||
return &net.UDPAddr{IP: n.ip, Port: n.port}
|
||||
}
|
||||
|
||||
func (n Node) MarshalCompact() ([]byte, error) {
|
||||
if n.ip.To4() == nil {
|
||||
return nil, errors.Err("ip not set")
|
||||
|
@ -102,6 +109,7 @@ func (a byXorDistance) Less(i, j int) bool {
|
|||
type RoutingTable struct {
|
||||
node Node
|
||||
buckets [numBuckets]*list.List
|
||||
lock *sync.RWMutex
|
||||
}
|
||||
|
||||
func newRoutingTable(node *Node) *RoutingTable {
|
||||
|
@ -110,39 +118,73 @@ func newRoutingTable(node *Node) *RoutingTable {
|
|||
rt.buckets[i] = list.New()
|
||||
}
|
||||
rt.node = *node
|
||||
rt.lock = &sync.RWMutex{}
|
||||
return &rt
|
||||
}
|
||||
|
||||
func (rt *RoutingTable) BucketInfo() string {
|
||||
rt.lock.RLock()
|
||||
defer rt.lock.RUnlock()
|
||||
|
||||
bucketInfo := []string{}
|
||||
for i, b := range rt.buckets {
|
||||
count := countInList(b)
|
||||
if count > 0 {
|
||||
bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: %d", i, count))
|
||||
}
|
||||
}
|
||||
if len(bucketInfo) == 0 {
|
||||
return "buckets are empty"
|
||||
}
|
||||
return strings.Join(bucketInfo, "\n")
|
||||
}
|
||||
|
||||
func (rt *RoutingTable) Update(node *Node) {
|
||||
prefixLength := node.id.Xor(rt.node.id).PrefixLen()
|
||||
bucket := rt.buckets[prefixLength]
|
||||
rt.lock.Lock()
|
||||
defer rt.lock.Unlock()
|
||||
bucketNum := bucketFor(rt.node.id, node.id)
|
||||
bucket := rt.buckets[bucketNum]
|
||||
element := findInList(bucket, rt.node.id)
|
||||
if element == nil {
|
||||
if bucket.Len() <= bucketSize {
|
||||
bucket.PushBack(node)
|
||||
if bucket.Len() >= bucketSize {
|
||||
// TODO: Ping front node first. Only remove if it does not respond
|
||||
bucket.Remove(bucket.Front())
|
||||
}
|
||||
// TODO: Handle insertion when the list is full by evicting old elements if
|
||||
// they don't respond to a ping.
|
||||
bucket.PushBack(node)
|
||||
} else {
|
||||
bucket.MoveToBack(element)
|
||||
}
|
||||
}
|
||||
|
||||
func (rt *RoutingTable) FindClosest(target bitmap, count int) []*Node {
|
||||
func (rt *RoutingTable) RemoveByID(id bitmap) {
|
||||
rt.lock.Lock()
|
||||
defer rt.lock.Unlock()
|
||||
bucketNum := bucketFor(rt.node.id, id)
|
||||
bucket := rt.buckets[bucketNum]
|
||||
element := findInList(bucket, rt.node.id)
|
||||
if element != nil {
|
||||
bucket.Remove(element)
|
||||
}
|
||||
}
|
||||
|
||||
func (rt *RoutingTable) FindClosest(target bitmap, limit int) []*Node {
|
||||
rt.lock.RLock()
|
||||
defer rt.lock.RUnlock()
|
||||
|
||||
var toSort []*SortedNode
|
||||
|
||||
prefixLength := target.Xor(rt.node.id).PrefixLen()
|
||||
bucket := rt.buckets[prefixLength]
|
||||
toSort = appendNodes(toSort, bucket.Front(), nil, target)
|
||||
bucketNum := bucketFor(rt.node.id, target)
|
||||
bucket := rt.buckets[bucketNum]
|
||||
toSort = appendNodes(toSort, bucket.Front(), target)
|
||||
|
||||
for i := 1; (prefixLength-i >= 0 || prefixLength+i < numBuckets) && len(toSort) < count; i++ {
|
||||
if prefixLength-i >= 0 {
|
||||
bucket = rt.buckets[prefixLength-i]
|
||||
toSort = appendNodes(toSort, bucket.Front(), nil, target)
|
||||
for i := 1; (bucketNum-i >= 0 || bucketNum+i < numBuckets) && len(toSort) < limit; i++ {
|
||||
if bucketNum-i >= 0 {
|
||||
bucket = rt.buckets[bucketNum-i]
|
||||
toSort = appendNodes(toSort, bucket.Front(), target)
|
||||
}
|
||||
if prefixLength+i < numBuckets {
|
||||
bucket = rt.buckets[prefixLength+i]
|
||||
toSort = appendNodes(toSort, bucket.Front(), nil, target)
|
||||
if bucketNum+i < numBuckets {
|
||||
bucket = rt.buckets[bucketNum+i]
|
||||
toSort = appendNodes(toSort, bucket.Front(), target)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -151,6 +193,9 @@ func (rt *RoutingTable) FindClosest(target bitmap, count int) []*Node {
|
|||
var nodes []*Node
|
||||
for _, c := range toSort {
|
||||
nodes = append(nodes, c.node)
|
||||
if len(nodes) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nodes
|
||||
|
@ -165,10 +210,25 @@ func findInList(bucket *list.List, value bitmap) *list.Element {
|
|||
return nil
|
||||
}
|
||||
|
||||
func appendNodes(nodes []*SortedNode, start, end *list.Element, target bitmap) []*SortedNode {
|
||||
for curr := start; curr != end; curr = curr.Next() {
|
||||
func countInList(bucket *list.List) int {
|
||||
count := 0
|
||||
for curr := bucket.Front(); curr != nil; curr = curr.Next() {
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func appendNodes(nodes []*SortedNode, start *list.Element, target bitmap) []*SortedNode {
|
||||
for curr := start; curr != nil; curr = curr.Next() {
|
||||
node := curr.Value.(*Node)
|
||||
nodes = append(nodes, &SortedNode{node, node.id.Xor(target)})
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
func bucketFor(id bitmap, target bitmap) int {
|
||||
if id.Equals(target) {
|
||||
panic("nodes do not have a bucket for themselves")
|
||||
}
|
||||
return numBuckets - 1 - target.Xor(id).PrefixLen()
|
||||
}
|
||||
|
|
210
dht/rpc.go
Normal file
210
dht/rpc.go
Normal file
|
@ -0,0 +1,210 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lyoshenka/bencode"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func newMessageID() string {
|
||||
buf := make([]byte, messageIDLength)
|
||||
_, err := rand.Read(buf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
// handlePacke handles packets received from udp.
|
||||
func handlePacket(dht *DHT, pkt packet) {
|
||||
//log.Infof("Received message from %s:%s : %s\n", pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), hex.EncodeToString(pkt.data))
|
||||
|
||||
var data map[string]interface{}
|
||||
err := bencode.DecodeBytes(pkt.data, &data)
|
||||
if err != nil {
|
||||
log.Errorf("error decoding data: %s\n%s", err, pkt.data)
|
||||
return
|
||||
}
|
||||
|
||||
msgType, ok := data[headerTypeField]
|
||||
if !ok {
|
||||
log.Errorf("decoded data has no message type: %s", data)
|
||||
return
|
||||
}
|
||||
|
||||
switch msgType.(int64) {
|
||||
case requestType:
|
||||
request := Request{}
|
||||
err = bencode.DecodeBytes(pkt.data, &request)
|
||||
if err != nil {
|
||||
log.Errorln(err)
|
||||
return
|
||||
}
|
||||
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args))
|
||||
handleRequest(dht, pkt.raddr, request)
|
||||
|
||||
case responseType:
|
||||
response := Response{}
|
||||
err = bencode.DecodeBytes(pkt.data, &response)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.Data)
|
||||
handleResponse(dht, pkt.raddr, response)
|
||||
|
||||
case errorType:
|
||||
e := Error{
|
||||
ID: data[headerMessageIDField].(string),
|
||||
NodeID: data[headerNodeIDField].(string),
|
||||
ExceptionType: data[headerPayloadField].(string),
|
||||
Response: getArgs(data[headerArgsField]),
|
||||
}
|
||||
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType)
|
||||
handleError(dht, pkt.raddr, e)
|
||||
|
||||
default:
|
||||
log.Errorf("Invalid message type: %s", msgType)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest handles the requests received from udp.
|
||||
func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||
if request.NodeID == dht.node.id.RawString() {
|
||||
log.Warn("ignoring self-request")
|
||||
return
|
||||
}
|
||||
|
||||
switch request.Method {
|
||||
case pingMethod:
|
||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: pingSuccessResponse})
|
||||
case storeMethod:
|
||||
if request.StoreArgs.BlobHash == "" {
|
||||
log.Errorln("blobhash is empty")
|
||||
return // nothing to store
|
||||
}
|
||||
// TODO: we should be sending the IP in the request, not just using the sender's IP
|
||||
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
|
||||
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
|
||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse})
|
||||
case findNodeMethod:
|
||||
if len(request.Args) < 1 {
|
||||
log.Errorln("nothing to find")
|
||||
return
|
||||
}
|
||||
if len(request.Args[0]) != nodeIDLength {
|
||||
log.Errorln("invalid node id")
|
||||
return
|
||||
}
|
||||
doFindNodes(dht, addr, request)
|
||||
case findValueMethod:
|
||||
if len(request.Args) < 1 {
|
||||
log.Errorln("nothing to find")
|
||||
return
|
||||
}
|
||||
if len(request.Args[0]) != nodeIDLength {
|
||||
log.Errorln("invalid node id")
|
||||
return
|
||||
}
|
||||
|
||||
if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 {
|
||||
response := Response{ID: request.ID, NodeID: dht.node.id.RawString()}
|
||||
response.FindValueKey = request.Args[0]
|
||||
response.FindNodeData = nodes
|
||||
send(dht, addr, response)
|
||||
} else {
|
||||
doFindNodes(dht, addr, request)
|
||||
}
|
||||
|
||||
default:
|
||||
// send(dht, addr, makeError(t, protocolError, "invalid q"))
|
||||
log.Errorln("invalid request method")
|
||||
return
|
||||
}
|
||||
|
||||
node := &Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port}
|
||||
dht.rt.Update(node)
|
||||
}
|
||||
|
||||
func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||
nodeID := newBitmapFromString(request.Args[0])
|
||||
closestNodes := dht.rt.FindClosest(nodeID, bucketSize)
|
||||
if len(closestNodes) > 0 {
|
||||
response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))}
|
||||
for i, n := range closestNodes {
|
||||
response.FindNodeData[i] = *n
|
||||
}
|
||||
send(dht, addr, response)
|
||||
} else {
|
||||
log.Warn("no nodes in routing table")
|
||||
}
|
||||
}
|
||||
|
||||
// handleResponse handles responses received from udp.
|
||||
func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) {
|
||||
tx := dht.tm.Find(response.ID, addr)
|
||||
if tx != nil {
|
||||
tx.res <- &response
|
||||
}
|
||||
|
||||
node := &Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port}
|
||||
dht.rt.Update(node)
|
||||
}
|
||||
|
||||
// handleError handles errors received from udp.
|
||||
func handleError(dht *DHT, addr *net.UDPAddr, e Error) {
|
||||
spew.Dump(e)
|
||||
node := &Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port}
|
||||
dht.rt.Update(node)
|
||||
}
|
||||
|
||||
// send sends data to the udp.
|
||||
func send(dht *DHT, addr *net.UDPAddr, data Message) error {
|
||||
if req, ok := data.(Request); ok {
|
||||
log.Debugf("[%s] query %s: sending request to %s : %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(req.ID))[:8], addr.String(), req.Method, argsToString(req.Args))
|
||||
} else if res, ok := data.(Response); ok {
|
||||
log.Debugf("[%s] query %s: sending response to %s : %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(res.ID))[:8], addr.String(), res.ArgsDebug())
|
||||
} else {
|
||||
log.Debugf("[%s] %s", dht.node.id.Hex()[:8], spew.Sdump(data))
|
||||
}
|
||||
encoded, err := bencode.EncodeBytes(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//log.Infof("Encoded: %s", string(encoded))
|
||||
|
||||
dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))
|
||||
|
||||
_, err = dht.conn.WriteToUDP(encoded, addr)
|
||||
return err
|
||||
}
|
||||
|
||||
func getArgs(argsInt interface{}) []string {
|
||||
var args []string
|
||||
if reflect.TypeOf(argsInt).Kind() == reflect.Slice {
|
||||
v := reflect.ValueOf(argsInt)
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
args = append(args, cast.ToString(v.Index(i).Interface()))
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func argsToString(args []string) string {
|
||||
argsCopy := make([]string, len(args))
|
||||
copy(argsCopy, args)
|
||||
for k, v := range argsCopy {
|
||||
if len(v) == nodeIDLength {
|
||||
argsCopy[k] = hex.EncodeToString([]byte(v))[:8]
|
||||
}
|
||||
}
|
||||
return strings.Join(argsCopy, ", ")
|
||||
}
|
|
@ -2,13 +2,61 @@ package dht
|
|||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lyoshenka/bencode"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/zeebo/bencode"
|
||||
)
|
||||
|
||||
type testUDPPacket struct {
|
||||
data []byte
|
||||
addr *net.UDPAddr
|
||||
}
|
||||
|
||||
type testUDPConn struct {
|
||||
addr *net.UDPAddr
|
||||
toRead chan testUDPPacket
|
||||
writes chan testUDPPacket
|
||||
}
|
||||
|
||||
func newTestUDPConn(addr string) *testUDPConn {
|
||||
parts := strings.Split(addr, ":")
|
||||
if len(parts) != 2 {
|
||||
panic("addr needs ip and port")
|
||||
}
|
||||
port, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &testUDPConn{
|
||||
addr: &net.UDPAddr{IP: net.IP(parts[0]), Port: port},
|
||||
toRead: make(chan testUDPPacket),
|
||||
writes: make(chan testUDPPacket),
|
||||
}
|
||||
}
|
||||
|
||||
func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
|
||||
select {
|
||||
case packet := <-t.toRead:
|
||||
n := copy(b, packet.data)
|
||||
return n, packet.addr, nil
|
||||
//default:
|
||||
// return 0, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
|
||||
t.writes <- testUDPPacket{data: b, addr: addr}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (t testUDPConn) SetWriteDeadline(tm time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestPing(t *testing.T) {
|
||||
log.SetLevel(log.DebugLevel)
|
||||
dhtNodeID := newRandomBitmap()
|
||||
|
@ -16,12 +64,15 @@ func TestPing(t *testing.T) {
|
|||
|
||||
conn := newTestUDPConn("127.0.0.1:21217")
|
||||
|
||||
dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
|
||||
dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dht.conn = conn
|
||||
dht.listen()
|
||||
go dht.runHandler()
|
||||
|
||||
messageID := newRandomBitmap().RawString()
|
||||
messageID := newMessageID()
|
||||
|
||||
data, err := bencode.EncodeBytes(map[string]interface{}{
|
||||
headerTypeField: requestType,
|
||||
|
@ -107,12 +158,16 @@ func TestStore(t *testing.T) {
|
|||
|
||||
conn := newTestUDPConn("127.0.0.1:21217")
|
||||
|
||||
dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
|
||||
dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dht.conn = conn
|
||||
dht.listen()
|
||||
go dht.runHandler()
|
||||
|
||||
messageID := newRandomBitmap().RawString()
|
||||
messageID := newMessageID()
|
||||
blobHashToStore := newRandomBitmap().RawString()
|
||||
|
||||
storeRequest := Request{
|
||||
|
@ -178,7 +233,7 @@ func TestStore(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
if len(dht.store.data) != 1 {
|
||||
if len(dht.store.nodeIDs) != 1 {
|
||||
t.Error("dht store has wrong number of items")
|
||||
}
|
||||
|
||||
|
@ -197,7 +252,10 @@ func TestFindNode(t *testing.T) {
|
|||
|
||||
conn := newTestUDPConn("127.0.0.1:21217")
|
||||
|
||||
dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
|
||||
dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dht.conn = conn
|
||||
dht.listen()
|
||||
go dht.runHandler()
|
||||
|
@ -207,10 +265,10 @@ func TestFindNode(t *testing.T) {
|
|||
for i := 0; i < nodesToInsert; i++ {
|
||||
n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
|
||||
nodes = append(nodes, n)
|
||||
dht.routingTable.Update(&n)
|
||||
dht.rt.Update(&n)
|
||||
}
|
||||
|
||||
messageID := newRandomBitmap().RawString()
|
||||
messageID := newMessageID()
|
||||
blobHashToFind := newRandomBitmap().RawString()
|
||||
|
||||
request := Request{
|
||||
|
@ -270,7 +328,11 @@ func TestFindValueExisting(t *testing.T) {
|
|||
|
||||
conn := newTestUDPConn("127.0.0.1:21217")
|
||||
|
||||
dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
|
||||
dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dht.conn = conn
|
||||
dht.listen()
|
||||
go dht.runHandler()
|
||||
|
@ -280,16 +342,18 @@ func TestFindValueExisting(t *testing.T) {
|
|||
for i := 0; i < nodesToInsert; i++ {
|
||||
n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
|
||||
nodes = append(nodes, n)
|
||||
dht.routingTable.Update(&n)
|
||||
dht.rt.Update(&n)
|
||||
}
|
||||
|
||||
//data, _ := hex.DecodeString("64313a30693065313a3132303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33393a66696e6456616c7565313a346c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565")
|
||||
|
||||
messageID := newRandomBitmap().RawString()
|
||||
messageID := newMessageID()
|
||||
valueToFind := newRandomBitmap().RawString()
|
||||
|
||||
nodeToFind := Node{id: newRandomBitmap(), ip: net.ParseIP("1.2.3.4"), port: 1286}
|
||||
dht.store.Insert(valueToFind, nodeToFind)
|
||||
dht.store.Upsert(valueToFind, nodeToFind)
|
||||
dht.store.Upsert(valueToFind, nodeToFind)
|
||||
dht.store.Upsert(valueToFind, nodeToFind)
|
||||
|
||||
request := Request{
|
||||
ID: messageID,
|
||||
|
@ -348,7 +412,11 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
|
|||
|
||||
conn := newTestUDPConn("127.0.0.1:21217")
|
||||
|
||||
dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
|
||||
dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dht.conn = conn
|
||||
dht.listen()
|
||||
go dht.runHandler()
|
||||
|
@ -358,10 +426,10 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
|
|||
for i := 0; i < nodesToInsert; i++ {
|
||||
n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
|
||||
nodes = append(nodes, n)
|
||||
dht.routingTable.Update(&n)
|
||||
dht.rt.Update(&n)
|
||||
}
|
||||
|
||||
messageID := newRandomBitmap().RawString()
|
||||
messageID := newMessageID()
|
||||
valueToFind := newRandomBitmap().RawString()
|
||||
|
||||
request := Request{
|
||||
|
@ -442,6 +510,9 @@ func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNod
|
|||
} else if rMessageID != messageID {
|
||||
t.Error("unexpected message ID")
|
||||
}
|
||||
if len(rMessageID) != messageIDLength {
|
||||
t.Errorf("message ID should be %d chars long", messageIDLength)
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = resp[headerNodeIDField]
|
||||
|
@ -454,6 +525,9 @@ func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNod
|
|||
} else if rNodeID != dhtNodeID {
|
||||
t.Error("unexpected node ID")
|
||||
}
|
||||
if len(rNodeID) != nodeIDLength {
|
||||
t.Errorf("node ID should be %d chars long", nodeIDLength)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
39
dht/store.go
39
dht/store.go
|
@ -4,39 +4,52 @@ import "sync"
|
|||
|
||||
type peer struct {
|
||||
node Node
|
||||
//<lastPublished>,
|
||||
//<originallyPublished>
|
||||
// <originalPublisherID>
|
||||
}
|
||||
|
||||
type peerStore struct {
|
||||
data map[string][]peer
|
||||
lock sync.RWMutex
|
||||
nodeIDs map[string]map[bitmap]bool
|
||||
nodeInfo map[bitmap]peer
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func newPeerStore() *peerStore {
|
||||
return &peerStore{
|
||||
data: make(map[string][]peer),
|
||||
nodeIDs: make(map[string]map[bitmap]bool),
|
||||
nodeInfo: make(map[bitmap]peer),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *peerStore) Insert(key string, node Node) {
|
||||
func (s *peerStore) Upsert(key string, node Node) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
newPeer := peer{node: node}
|
||||
_, ok := s.data[key]
|
||||
if !ok {
|
||||
s.data[key] = []peer{newPeer}
|
||||
} else {
|
||||
s.data[key] = append(s.data[key], newPeer)
|
||||
if _, ok := s.nodeIDs[key]; !ok {
|
||||
s.nodeIDs[key] = make(map[bitmap]bool)
|
||||
}
|
||||
s.nodeIDs[key][node.id] = true
|
||||
s.nodeInfo[node.id] = peer{node: node}
|
||||
}
|
||||
|
||||
func (s *peerStore) Get(key string) []Node {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
var nodes []Node
|
||||
if peers, ok := s.data[key]; ok {
|
||||
for _, p := range peers {
|
||||
nodes = append(nodes, p.node)
|
||||
if ids, ok := s.nodeIDs[key]; ok {
|
||||
for id := range ids {
|
||||
peer, ok := s.nodeInfo[id]
|
||||
if !ok {
|
||||
panic("node id in IDs list, but not in nodeInfo")
|
||||
}
|
||||
nodes = append(nodes, peer.node)
|
||||
}
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
func (s *peerStore) CountKnownNodes() int {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
return len(s.nodeInfo)
|
||||
}
|
||||
|
|
101
dht/transaction_manager.go
Normal file
101
dht/transaction_manager.go
Normal file
|
@ -0,0 +1,101 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// query represents the query data included queried node and query-formed data.
|
||||
type transaction struct {
|
||||
node *Node
|
||||
req *Request
|
||||
res chan *Response
|
||||
}
|
||||
|
||||
// transactionManager represents the manager of transactions.
|
||||
type transactionManager struct {
|
||||
lock *sync.RWMutex
|
||||
transactions map[string]*transaction
|
||||
dht *DHT
|
||||
}
|
||||
|
||||
// newTransactionManager returns new transactionManager pointer.
|
||||
func newTransactionManager(dht *DHT) *transactionManager {
|
||||
return &transactionManager{
|
||||
lock: &sync.RWMutex{},
|
||||
transactions: make(map[string]*transaction),
|
||||
dht: dht,
|
||||
}
|
||||
}
|
||||
|
||||
// insert adds a transaction to transactionManager.
|
||||
func (tm *transactionManager) insert(trans *transaction) {
|
||||
tm.lock.Lock()
|
||||
defer tm.lock.Unlock()
|
||||
tm.transactions[trans.req.ID] = trans
|
||||
}
|
||||
|
||||
// delete removes a transaction from transactionManager.
|
||||
func (tm *transactionManager) delete(transID string) {
|
||||
tm.lock.Lock()
|
||||
defer tm.lock.Unlock()
|
||||
delete(tm.transactions, transID)
|
||||
}
|
||||
|
||||
// find transaction for id. optionally ensure that addr matches node from transaction
|
||||
func (tm *transactionManager) Find(id string, addr *net.UDPAddr) *transaction {
|
||||
tm.lock.RLock()
|
||||
defer tm.lock.RUnlock()
|
||||
|
||||
t, ok := tm.transactions[id]
|
||||
if !ok {
|
||||
return nil
|
||||
} else if addr != nil && t.node.Addr().String() != addr.String() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (tm *transactionManager) Send(node *Node, req *Request) *Response {
|
||||
if node.id.Equals(tm.dht.node.id) {
|
||||
log.Error("sending query to self")
|
||||
return nil
|
||||
}
|
||||
|
||||
req.ID = newMessageID()
|
||||
trans := &transaction{
|
||||
node: node,
|
||||
req: req,
|
||||
res: make(chan *Response),
|
||||
}
|
||||
|
||||
tm.insert(trans)
|
||||
defer tm.delete(trans.req.ID)
|
||||
|
||||
for i := 0; i < udpRetry; i++ {
|
||||
if err := send(tm.dht, trans.node.Addr(), *trans.req); err != nil {
|
||||
log.Error(err)
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case res := <-trans.res:
|
||||
return res
|
||||
case <-time.After(udpTimeout):
|
||||
}
|
||||
}
|
||||
|
||||
tm.dht.rt.RemoveByID(trans.node.id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count returns the number of transactions in the manager
|
||||
func (tm *transactionManager) Count() int {
|
||||
tm.lock.Lock()
|
||||
defer tm.lock.Unlock()
|
||||
return len(tm.transactions)
|
||||
}
|
Loading…
Reference in a new issue