add transaction manager, fix bencoding to support int keys, fix routing table bucketing

This commit is contained in:
Alex Grintsvayg 2018-03-23 19:18:00 -04:00
parent 883d76d8bb
commit 05e2d8529a
12 changed files with 664 additions and 330 deletions

View file

@ -1,11 +1,11 @@
package dht
import (
"crypto/rand"
"encoding/hex"
"math/rand"
"strconv"
"github.com/zeebo/bencode"
"github.com/lyoshenka/bencode"
)
type bitmap [nodeIDLength]byte
@ -45,7 +45,7 @@ func (b bitmap) Xor(other bitmap) bitmap {
}
// PrefixLen returns the number of leading 0 bits
func (b bitmap) PrefixLen() (ret int) {
func (b bitmap) PrefixLen() int {
for i := range b {
for j := 0; j < 8; j++ {
if (b[i]>>uint8(7-j))&0x1 != 0 {
@ -95,8 +95,9 @@ func newBitmapFromHex(hexStr string) bitmap {
func newRandomBitmap() bitmap {
var id bitmap
for k := range id {
id[k] = uint8(rand.Intn(256))
_, err := rand.Read(id[:])
if err != nil {
panic(err)
}
return id
}

View file

@ -3,7 +3,7 @@ package dht
import (
"testing"
"github.com/zeebo/bencode"
"github.com/lyoshenka/bencode"
)
func TestBitmap(t *testing.T) {

View file

@ -1,60 +0,0 @@
package dht
import (
"net"
"strconv"
"strings"
"time"
)
type UDPConn interface {
ReadFromUDP([]byte) (int, *net.UDPAddr, error)
WriteToUDP([]byte, *net.UDPAddr) (int, error)
SetWriteDeadline(time.Time) error
}
type testUDPPacket struct {
data []byte
addr *net.UDPAddr
}
type testUDPConn struct {
addr *net.UDPAddr
toRead chan testUDPPacket
writes chan testUDPPacket
}
func newTestUDPConn(addr string) *testUDPConn {
parts := strings.Split(addr, ":")
if len(parts) != 2 {
panic("addr needs ip and port")
}
port, err := strconv.Atoi(parts[1])
if err != nil {
panic(err)
}
return &testUDPConn{
addr: &net.UDPAddr{IP: net.IP(parts[0]), Port: port},
toRead: make(chan testUDPPacket),
writes: make(chan testUDPPacket),
}
}
func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
select {
case packet := <-t.toRead:
n := copy(b, packet.data)
return n, packet.addr, nil
//default:
// return 0, nil, nil
}
}
func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
t.writes <- testUDPPacket{data: b, addr: addr}
return len(b), nil
}
func (t testUDPConn) SetWriteDeadline(tm time.Time) error {
return nil
}

View file

@ -5,7 +5,7 @@ import (
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/zeebo/bencode"
"github.com/lyoshenka/bencode"
)
func TestDecode(t *testing.T) {

View file

@ -1,23 +1,24 @@
package dht
import (
"encoding/hex"
"net"
"reflect"
"strings"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/lbryio/errors.go"
log "github.com/sirupsen/logrus"
"github.com/spf13/cast"
"github.com/zeebo/bencode"
)
const network = "udp4"
const alpha = 3 // this is the constant alpha in the spec
const nodeIDLength = 48 // bytes. this is the constant B in the spec
const bucketSize = 8 // this is the constant k in the spec
const alpha = 3 // this is the constant alpha in the spec
const nodeIDLength = 48 // bytes. this is the constant B in the spec
const messageIDLength = 20 // bytes.
const bucketSize = 8 // this is the constant k in the spec
const udpRetry = 3
const udpTimeout = 10 * time.Second
const tExpire = 86400 * time.Second // the time after which a key/value pair expires; this is a time-to-live (TTL) from the original publication date
const tRefresh = 3600 * time.Second // the time after which an otherwise unaccessed bucket must be refreshed
@ -41,6 +42,8 @@ type Config struct {
SeedNodes []string
// the hex-encoded node id for this node. if string is empty, a random id will be generated
NodeID string
// print the state of the dht every minute
PrintState bool
}
// NewStandardConfig returns a Config pointer with default values.
@ -55,18 +58,26 @@ func NewStandardConfig() *Config {
}
}
// UDPConn allows using a mocked connection for testing sending/receiving data
type UDPConn interface {
ReadFromUDP([]byte) (int, *net.UDPAddr, error)
WriteToUDP([]byte, *net.UDPAddr) (int, error)
SetWriteDeadline(time.Time) error
}
// DHT represents a DHT node.
type DHT struct {
conf *Config
conn UDPConn
node *Node
routingTable *RoutingTable
packets chan packet
store *peerStore
conf *Config
conn UDPConn
node *Node
rt *RoutingTable
packets chan packet
store *peerStore
tm *transactionManager
}
// New returns a DHT pointer. If config is nil, then config will be set to the default config.
func New(config *Config) *DHT {
func New(config *Config) (*DHT, error) {
if config == nil {
config = NewStandardConfig()
}
@ -80,41 +91,51 @@ func New(config *Config) *DHT {
ip, port, err := net.SplitHostPort(config.Address)
if err != nil {
panic(err)
return nil, errors.Err(err)
} else if ip == "" {
panic("address does not contain an IP")
return nil, errors.Err("address does not contain an IP")
} else if port == "" {
panic("address does not contain a port")
return nil, errors.Err("address does not contain a port")
}
portInt, err := cast.ToIntE(port)
if err != nil {
panic(err)
return nil, errors.Err(err)
}
node := &Node{id: id, ip: net.ParseIP(ip), port: portInt}
if node.ip == nil {
panic("invalid ip")
return nil, errors.Err("invalid ip")
}
return &DHT{
conf: config,
node: node,
routingTable: newRoutingTable(node),
packets: make(chan packet),
store: newPeerStore(),
d := &DHT{
conf: config,
node: node,
rt: newRoutingTable(node),
packets: make(chan packet),
store: newPeerStore(),
}
d.tm = newTransactionManager(d)
return d, nil
}
// init initializes global variables.
func (dht *DHT) init() {
func (dht *DHT) init() error {
log.Info("Initializing DHT on " + dht.conf.Address)
log.Infof("Node ID is %s", dht.node.id.Hex())
listener, err := net.ListenPacket(network, dht.conf.Address)
if err != nil {
panic(err)
return errors.Err(err)
}
dht.conn = listener.(*net.UDPConn)
if dht.conf.PrintState {
go printState(dht)
}
return nil
}
// listen receives message from udp.
@ -159,201 +180,98 @@ func (dht *DHT) runHandler() {
for {
select {
case pkt = <-dht.packets:
handle(dht, pkt)
handlePacket(dht, pkt)
}
}
}
// Run starts the dht.
func (dht *DHT) Run() {
dht.init()
func (dht *DHT) Run() error {
err := dht.init()
if err != nil {
return err
}
dht.listen()
dht.join()
log.Info("DHT ready")
dht.runHandler()
return nil
}
// handle handles packets received from udp.
func handle(dht *DHT, pkt packet) {
//log.Infof("Received message from %s:%s : %s\n", pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), hex.EncodeToString(pkt.data))
var data map[string]interface{}
err := bencode.DecodeBytes(pkt.data, &data)
if err != nil {
log.Errorf("error decoding data: %s\n%s", err, pkt.data)
return
}
msgType, ok := data[headerTypeField]
if !ok {
log.Errorf("decoded data has no message type: %s", data)
return
}
switch msgType.(int64) {
case requestType:
request := Request{}
err = bencode.DecodeBytes(pkt.data, &request)
if err != nil {
log.Errorln(err)
return
}
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args))
handleRequest(dht, pkt.raddr, request)
case responseType:
response := Response{}
err = bencode.DecodeBytes(pkt.data, &response)
if err != nil {
return
}
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.Data)
handleResponse(dht, pkt.raddr, response)
case errorType:
e := Error{
ID: data[headerMessageIDField].(string),
NodeID: data[headerNodeIDField].(string),
ExceptionType: data[headerPayloadField].(string),
Response: getArgs(data[headerArgsField]),
}
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType)
handleError(dht, pkt.raddr, e)
default:
log.Errorf("Invalid message type: %s", msgType)
return
func printState(dht *DHT) {
t := time.NewTicker(60 * time.Second)
for {
log.Printf("DHT state at %s", time.Now().Format(time.RFC822Z))
log.Printf("Outstanding transactions: %d", dht.tm.Count())
log.Printf("Known nodes: %d", dht.store.CountKnownNodes())
log.Printf("Buckets: \n%s", dht.rt.BucketInfo())
<-t.C
}
}
// handleRequest handles the requests received from udp.
func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
if request.NodeID == dht.node.id.RawString() {
log.Warn("ignoring self-request")
return
}
switch request.Method {
case pingMethod:
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: pingSuccessResponse})
case storeMethod:
if request.StoreArgs.BlobHash == "" {
log.Errorln("blobhash is empty")
return // nothing to store
}
// TODO: we should be sending the IP in the request, not just using the sender's IP
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
dht.store.Insert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse})
case findNodeMethod:
log.Println("findnode")
if len(request.Args) < 1 {
log.Errorln("nothing to find")
return
}
if len(request.Args[0]) != nodeIDLength {
log.Errorln("invalid node id")
return
}
doFindNodes(dht, addr, request)
case findValueMethod:
log.Println("findvalue")
if len(request.Args) < 1 {
log.Errorln("nothing to find")
return
}
if len(request.Args[0]) != nodeIDLength {
log.Errorln("invalid node id")
return
}
if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 {
response := Response{ID: request.ID, NodeID: dht.node.id.RawString()}
response.FindValueKey = request.Args[0]
response.FindNodeData = nodes
send(dht, addr, response)
} else {
doFindNodes(dht, addr, request)
}
default:
// send(dht, addr, makeError(t, protocolError, "invalid q"))
log.Errorln("invalid request method")
return
}
node := &Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port}
dht.routingTable.Update(node)
}
func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) {
nodeID := newBitmapFromString(request.Args[0])
closestNodes := dht.routingTable.FindClosest(nodeID, bucketSize)
if len(closestNodes) > 0 {
response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))}
for i, n := range closestNodes {
response.FindNodeData[i] = *n
}
send(dht, addr, response)
}
}
// handleResponse handles responses received from udp.
func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) {
spew.Dump(response)
// TODO: find transaction by message id, pass along response
node := &Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port}
dht.routingTable.Update(node)
}
// handleError handles errors received from udp.
func handleError(dht *DHT, addr *net.UDPAddr, e Error) {
spew.Dump(e)
node := &Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port}
dht.routingTable.Update(node)
}
// send sends data to the udp.
func send(dht *DHT, addr *net.UDPAddr, data Message) error {
if req, ok := data.(Request); ok {
log.Debugf("[%s] query %s: sending request: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(req.ID))[:8], req.Method, argsToString(req.Args))
} else if res, ok := data.(Response); ok {
log.Debugf("[%s] query %s: sending response: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(res.ID))[:8], spew.Sdump(res.Data))
} else {
log.Debugf("[%s] %s", spew.Sdump(data))
}
encoded, err := bencode.EncodeBytes(data)
if err != nil {
return err
}
//log.Infof("Encoded: %s", string(encoded))
dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))
_, err = dht.conn.WriteToUDP(encoded, addr)
return err
}
func getArgs(argsInt interface{}) []string {
var args []string
if reflect.TypeOf(argsInt).Kind() == reflect.Slice {
v := reflect.ValueOf(argsInt)
for i := 0; i < v.Len(); i++ {
args = append(args, cast.ToString(v.Index(i).Interface()))
}
}
return args
}
func argsToString(args []string) string {
argsCopy := make([]string, len(args))
copy(argsCopy, args)
for k, v := range argsCopy {
if len(v) == nodeIDLength {
argsCopy[k] = hex.EncodeToString([]byte(v))[:8]
}
}
return strings.Join(argsCopy, ", ")
}
//func (dht *DHT) Get(hash bitmap) ([]Node, error) {
// return iterativeFindNode(dht, hash)
//}
//
//func iterativeFindNode(dht *DHT, hash bitmap) ([]Node, error) {
// shortlist := dht.rt.FindClosest(hash, alpha)
// if len(shortlist) == 0 {
// return nil, errors.Err("no nodes in routing table")
// }
//
// pending := make(chan *Node)
// contacted := make(map[bitmap]bool)
// contactedMutex := &sync.Mutex{}
// closestNodeMutex := &sync.Mutex{}
// closestNode := shortlist[0]
// wg := sync.WaitGroup{}
//
// for i := 0; i < alpha; i++ {
// wg.Add(1)
// go func() {
// defer wg.Done()
// for {
// node, ok := <-pending
// if !ok {
// return
// }
//
// contactedMutex.Lock()
// if _, ok := contacted[node.id]; ok {
// contactedMutex.Unlock()
// continue
// }
// contacted[node.id] = true
// contactedMutex.Unlock()
//
// res := dht.tm.Send(node, &Request{
// NodeID: dht.node.id.RawString(),
// Method: findNodeMethod,
// Args: []string{hash.RawString()},
// })
// if res == nil {
// // remove node from shortlist
// continue
// }
//
// for _, n := range res.FindNodeData {
// pending <- &n
// closestNodeMutex.Lock()
// if n.id.Xor(hash).Less(closestNode.id.Xor(hash)) {
// closestNode = &n
// }
// closestNodeMutex.Unlock()
// }
// }
// }()
// }
//
// for _, n := range shortlist {
// pending <- n
// }
//
// wg.Wait()
//
// return nil, nil
//}

View file

@ -1,10 +1,12 @@
package dht
import (
"encoding/hex"
"github.com/lbryio/errors.go"
"github.com/lyoshenka/bencode"
"github.com/spf13/cast"
"github.com/zeebo/bencode"
)
const (
@ -171,6 +173,21 @@ type Response struct {
FindValueKey string
}
func (r Response) ArgsDebug() string {
if len(r.FindNodeData) == 0 {
return r.Data
}
str := "contacts "
if r.FindValueKey != "" {
str += "for " + hex.EncodeToString([]byte(r.FindValueKey))[:8] + " "
}
for _, c := range r.FindNodeData {
str += c.Addr().String() + ":" + c.id.Hex()[:8] + ", "
}
return str[:len(str)-2] // chomp off last ", "
}
func (r Response) MarshalBencode() ([]byte, error) {
data := map[string]interface{}{
headerTypeField: responseType,

View file

@ -7,8 +7,8 @@ import (
"strings"
"testing"
"github.com/lyoshenka/bencode"
log "github.com/sirupsen/logrus"
"github.com/zeebo/bencode"
)
func TestBencodeDecodeStoreArgs(t *testing.T) {

View file

@ -3,12 +3,15 @@ package dht
import (
"bytes"
"container/list"
"fmt"
"net"
"sort"
"strings"
"sync"
"github.com/lbryio/errors.go"
"github.com/zeebo/bencode"
"github.com/lyoshenka/bencode"
)
type Node struct {
@ -17,6 +20,10 @@ type Node struct {
port int
}
func (n Node) Addr() *net.UDPAddr {
return &net.UDPAddr{IP: n.ip, Port: n.port}
}
func (n Node) MarshalCompact() ([]byte, error) {
if n.ip.To4() == nil {
return nil, errors.Err("ip not set")
@ -102,6 +109,7 @@ func (a byXorDistance) Less(i, j int) bool {
type RoutingTable struct {
node Node
buckets [numBuckets]*list.List
lock *sync.RWMutex
}
func newRoutingTable(node *Node) *RoutingTable {
@ -110,39 +118,73 @@ func newRoutingTable(node *Node) *RoutingTable {
rt.buckets[i] = list.New()
}
rt.node = *node
rt.lock = &sync.RWMutex{}
return &rt
}
func (rt *RoutingTable) BucketInfo() string {
rt.lock.RLock()
defer rt.lock.RUnlock()
bucketInfo := []string{}
for i, b := range rt.buckets {
count := countInList(b)
if count > 0 {
bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: %d", i, count))
}
}
if len(bucketInfo) == 0 {
return "buckets are empty"
}
return strings.Join(bucketInfo, "\n")
}
func (rt *RoutingTable) Update(node *Node) {
prefixLength := node.id.Xor(rt.node.id).PrefixLen()
bucket := rt.buckets[prefixLength]
rt.lock.Lock()
defer rt.lock.Unlock()
bucketNum := bucketFor(rt.node.id, node.id)
bucket := rt.buckets[bucketNum]
element := findInList(bucket, rt.node.id)
if element == nil {
if bucket.Len() <= bucketSize {
bucket.PushBack(node)
if bucket.Len() >= bucketSize {
// TODO: Ping front node first. Only remove if it does not respond
bucket.Remove(bucket.Front())
}
// TODO: Handle insertion when the list is full by evicting old elements if
// they don't respond to a ping.
bucket.PushBack(node)
} else {
bucket.MoveToBack(element)
}
}
func (rt *RoutingTable) FindClosest(target bitmap, count int) []*Node {
func (rt *RoutingTable) RemoveByID(id bitmap) {
rt.lock.Lock()
defer rt.lock.Unlock()
bucketNum := bucketFor(rt.node.id, id)
bucket := rt.buckets[bucketNum]
element := findInList(bucket, rt.node.id)
if element != nil {
bucket.Remove(element)
}
}
func (rt *RoutingTable) FindClosest(target bitmap, limit int) []*Node {
rt.lock.RLock()
defer rt.lock.RUnlock()
var toSort []*SortedNode
prefixLength := target.Xor(rt.node.id).PrefixLen()
bucket := rt.buckets[prefixLength]
toSort = appendNodes(toSort, bucket.Front(), nil, target)
bucketNum := bucketFor(rt.node.id, target)
bucket := rt.buckets[bucketNum]
toSort = appendNodes(toSort, bucket.Front(), target)
for i := 1; (prefixLength-i >= 0 || prefixLength+i < numBuckets) && len(toSort) < count; i++ {
if prefixLength-i >= 0 {
bucket = rt.buckets[prefixLength-i]
toSort = appendNodes(toSort, bucket.Front(), nil, target)
for i := 1; (bucketNum-i >= 0 || bucketNum+i < numBuckets) && len(toSort) < limit; i++ {
if bucketNum-i >= 0 {
bucket = rt.buckets[bucketNum-i]
toSort = appendNodes(toSort, bucket.Front(), target)
}
if prefixLength+i < numBuckets {
bucket = rt.buckets[prefixLength+i]
toSort = appendNodes(toSort, bucket.Front(), nil, target)
if bucketNum+i < numBuckets {
bucket = rt.buckets[bucketNum+i]
toSort = appendNodes(toSort, bucket.Front(), target)
}
}
@ -151,6 +193,9 @@ func (rt *RoutingTable) FindClosest(target bitmap, count int) []*Node {
var nodes []*Node
for _, c := range toSort {
nodes = append(nodes, c.node)
if len(nodes) >= limit {
break
}
}
return nodes
@ -165,10 +210,25 @@ func findInList(bucket *list.List, value bitmap) *list.Element {
return nil
}
func appendNodes(nodes []*SortedNode, start, end *list.Element, target bitmap) []*SortedNode {
for curr := start; curr != end; curr = curr.Next() {
func countInList(bucket *list.List) int {
count := 0
for curr := bucket.Front(); curr != nil; curr = curr.Next() {
count++
}
return count
}
func appendNodes(nodes []*SortedNode, start *list.Element, target bitmap) []*SortedNode {
for curr := start; curr != nil; curr = curr.Next() {
node := curr.Value.(*Node)
nodes = append(nodes, &SortedNode{node, node.id.Xor(target)})
}
return nodes
}
func bucketFor(id bitmap, target bitmap) int {
if id.Equals(target) {
panic("nodes do not have a bucket for themselves")
}
return numBuckets - 1 - target.Xor(id).PrefixLen()
}

210
dht/rpc.go Normal file
View file

@ -0,0 +1,210 @@
package dht
import (
"crypto/rand"
"encoding/hex"
"net"
"reflect"
"strings"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/lyoshenka/bencode"
log "github.com/sirupsen/logrus"
"github.com/spf13/cast"
)
func newMessageID() string {
buf := make([]byte, messageIDLength)
_, err := rand.Read(buf)
if err != nil {
panic(err)
}
return string(buf)
}
// handlePacke handles packets received from udp.
func handlePacket(dht *DHT, pkt packet) {
//log.Infof("Received message from %s:%s : %s\n", pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), hex.EncodeToString(pkt.data))
var data map[string]interface{}
err := bencode.DecodeBytes(pkt.data, &data)
if err != nil {
log.Errorf("error decoding data: %s\n%s", err, pkt.data)
return
}
msgType, ok := data[headerTypeField]
if !ok {
log.Errorf("decoded data has no message type: %s", data)
return
}
switch msgType.(int64) {
case requestType:
request := Request{}
err = bencode.DecodeBytes(pkt.data, &request)
if err != nil {
log.Errorln(err)
return
}
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args))
handleRequest(dht, pkt.raddr, request)
case responseType:
response := Response{}
err = bencode.DecodeBytes(pkt.data, &response)
if err != nil {
return
}
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.Data)
handleResponse(dht, pkt.raddr, response)
case errorType:
e := Error{
ID: data[headerMessageIDField].(string),
NodeID: data[headerNodeIDField].(string),
ExceptionType: data[headerPayloadField].(string),
Response: getArgs(data[headerArgsField]),
}
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType)
handleError(dht, pkt.raddr, e)
default:
log.Errorf("Invalid message type: %s", msgType)
return
}
}
// handleRequest handles the requests received from udp.
func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
if request.NodeID == dht.node.id.RawString() {
log.Warn("ignoring self-request")
return
}
switch request.Method {
case pingMethod:
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: pingSuccessResponse})
case storeMethod:
if request.StoreArgs.BlobHash == "" {
log.Errorln("blobhash is empty")
return // nothing to store
}
// TODO: we should be sending the IP in the request, not just using the sender's IP
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse})
case findNodeMethod:
if len(request.Args) < 1 {
log.Errorln("nothing to find")
return
}
if len(request.Args[0]) != nodeIDLength {
log.Errorln("invalid node id")
return
}
doFindNodes(dht, addr, request)
case findValueMethod:
if len(request.Args) < 1 {
log.Errorln("nothing to find")
return
}
if len(request.Args[0]) != nodeIDLength {
log.Errorln("invalid node id")
return
}
if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 {
response := Response{ID: request.ID, NodeID: dht.node.id.RawString()}
response.FindValueKey = request.Args[0]
response.FindNodeData = nodes
send(dht, addr, response)
} else {
doFindNodes(dht, addr, request)
}
default:
// send(dht, addr, makeError(t, protocolError, "invalid q"))
log.Errorln("invalid request method")
return
}
node := &Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port}
dht.rt.Update(node)
}
func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) {
nodeID := newBitmapFromString(request.Args[0])
closestNodes := dht.rt.FindClosest(nodeID, bucketSize)
if len(closestNodes) > 0 {
response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))}
for i, n := range closestNodes {
response.FindNodeData[i] = *n
}
send(dht, addr, response)
} else {
log.Warn("no nodes in routing table")
}
}
// handleResponse handles responses received from udp.
func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) {
tx := dht.tm.Find(response.ID, addr)
if tx != nil {
tx.res <- &response
}
node := &Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port}
dht.rt.Update(node)
}
// handleError handles errors received from udp.
func handleError(dht *DHT, addr *net.UDPAddr, e Error) {
spew.Dump(e)
node := &Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port}
dht.rt.Update(node)
}
// send sends data to the udp.
func send(dht *DHT, addr *net.UDPAddr, data Message) error {
if req, ok := data.(Request); ok {
log.Debugf("[%s] query %s: sending request to %s : %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(req.ID))[:8], addr.String(), req.Method, argsToString(req.Args))
} else if res, ok := data.(Response); ok {
log.Debugf("[%s] query %s: sending response to %s : %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(res.ID))[:8], addr.String(), res.ArgsDebug())
} else {
log.Debugf("[%s] %s", dht.node.id.Hex()[:8], spew.Sdump(data))
}
encoded, err := bencode.EncodeBytes(data)
if err != nil {
return err
}
//log.Infof("Encoded: %s", string(encoded))
dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))
_, err = dht.conn.WriteToUDP(encoded, addr)
return err
}
func getArgs(argsInt interface{}) []string {
var args []string
if reflect.TypeOf(argsInt).Kind() == reflect.Slice {
v := reflect.ValueOf(argsInt)
for i := 0; i < v.Len(); i++ {
args = append(args, cast.ToString(v.Index(i).Interface()))
}
}
return args
}
func argsToString(args []string) string {
argsCopy := make([]string, len(args))
copy(argsCopy, args)
for k, v := range argsCopy {
if len(v) == nodeIDLength {
argsCopy[k] = hex.EncodeToString([]byte(v))[:8]
}
}
return strings.Join(argsCopy, ", ")
}

View file

@ -2,13 +2,61 @@ package dht
import (
"net"
"strconv"
"strings"
"testing"
"time"
"github.com/lyoshenka/bencode"
log "github.com/sirupsen/logrus"
"github.com/zeebo/bencode"
)
type testUDPPacket struct {
data []byte
addr *net.UDPAddr
}
type testUDPConn struct {
addr *net.UDPAddr
toRead chan testUDPPacket
writes chan testUDPPacket
}
func newTestUDPConn(addr string) *testUDPConn {
parts := strings.Split(addr, ":")
if len(parts) != 2 {
panic("addr needs ip and port")
}
port, err := strconv.Atoi(parts[1])
if err != nil {
panic(err)
}
return &testUDPConn{
addr: &net.UDPAddr{IP: net.IP(parts[0]), Port: port},
toRead: make(chan testUDPPacket),
writes: make(chan testUDPPacket),
}
}
func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
select {
case packet := <-t.toRead:
n := copy(b, packet.data)
return n, packet.addr, nil
//default:
// return 0, nil, nil
}
}
func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
t.writes <- testUDPPacket{data: b, addr: addr}
return len(b), nil
}
func (t testUDPConn) SetWriteDeadline(tm time.Time) error {
return nil
}
func TestPing(t *testing.T) {
log.SetLevel(log.DebugLevel)
dhtNodeID := newRandomBitmap()
@ -16,12 +64,15 @@ func TestPing(t *testing.T) {
conn := newTestUDPConn("127.0.0.1:21217")
dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
if err != nil {
t.Fatal(err)
}
dht.conn = conn
dht.listen()
go dht.runHandler()
messageID := newRandomBitmap().RawString()
messageID := newMessageID()
data, err := bencode.EncodeBytes(map[string]interface{}{
headerTypeField: requestType,
@ -107,12 +158,16 @@ func TestStore(t *testing.T) {
conn := newTestUDPConn("127.0.0.1:21217")
dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
if err != nil {
t.Fatal(err)
}
dht.conn = conn
dht.listen()
go dht.runHandler()
messageID := newRandomBitmap().RawString()
messageID := newMessageID()
blobHashToStore := newRandomBitmap().RawString()
storeRequest := Request{
@ -178,7 +233,7 @@ func TestStore(t *testing.T) {
}
}
if len(dht.store.data) != 1 {
if len(dht.store.nodeIDs) != 1 {
t.Error("dht store has wrong number of items")
}
@ -197,7 +252,10 @@ func TestFindNode(t *testing.T) {
conn := newTestUDPConn("127.0.0.1:21217")
dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
if err != nil {
t.Fatal(err)
}
dht.conn = conn
dht.listen()
go dht.runHandler()
@ -207,10 +265,10 @@ func TestFindNode(t *testing.T) {
for i := 0; i < nodesToInsert; i++ {
n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
nodes = append(nodes, n)
dht.routingTable.Update(&n)
dht.rt.Update(&n)
}
messageID := newRandomBitmap().RawString()
messageID := newMessageID()
blobHashToFind := newRandomBitmap().RawString()
request := Request{
@ -270,7 +328,11 @@ func TestFindValueExisting(t *testing.T) {
conn := newTestUDPConn("127.0.0.1:21217")
dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
if err != nil {
t.Fatal(err)
}
dht.conn = conn
dht.listen()
go dht.runHandler()
@ -280,16 +342,18 @@ func TestFindValueExisting(t *testing.T) {
for i := 0; i < nodesToInsert; i++ {
n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
nodes = append(nodes, n)
dht.routingTable.Update(&n)
dht.rt.Update(&n)
}
//data, _ := hex.DecodeString("64313a30693065313a3132303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33393a66696e6456616c7565313a346c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565")
messageID := newRandomBitmap().RawString()
messageID := newMessageID()
valueToFind := newRandomBitmap().RawString()
nodeToFind := Node{id: newRandomBitmap(), ip: net.ParseIP("1.2.3.4"), port: 1286}
dht.store.Insert(valueToFind, nodeToFind)
dht.store.Upsert(valueToFind, nodeToFind)
dht.store.Upsert(valueToFind, nodeToFind)
dht.store.Upsert(valueToFind, nodeToFind)
request := Request{
ID: messageID,
@ -348,7 +412,11 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
conn := newTestUDPConn("127.0.0.1:21217")
dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()})
if err != nil {
t.Fatal(err)
}
dht.conn = conn
dht.listen()
go dht.runHandler()
@ -358,10 +426,10 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
for i := 0; i < nodesToInsert; i++ {
n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i}
nodes = append(nodes, n)
dht.routingTable.Update(&n)
dht.rt.Update(&n)
}
messageID := newRandomBitmap().RawString()
messageID := newMessageID()
valueToFind := newRandomBitmap().RawString()
request := Request{
@ -442,6 +510,9 @@ func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNod
} else if rMessageID != messageID {
t.Error("unexpected message ID")
}
if len(rMessageID) != messageIDLength {
t.Errorf("message ID should be %d chars long", messageIDLength)
}
}
_, ok = resp[headerNodeIDField]
@ -454,6 +525,9 @@ func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNod
} else if rNodeID != dhtNodeID {
t.Error("unexpected node ID")
}
if len(rNodeID) != nodeIDLength {
t.Errorf("node ID should be %d chars long", nodeIDLength)
}
}
}

View file

@ -4,39 +4,52 @@ import "sync"
type peer struct {
node Node
//<lastPublished>,
//<originallyPublished>
// <originalPublisherID>
}
type peerStore struct {
data map[string][]peer
lock sync.RWMutex
nodeIDs map[string]map[bitmap]bool
nodeInfo map[bitmap]peer
lock sync.RWMutex
}
func newPeerStore() *peerStore {
return &peerStore{
data: make(map[string][]peer),
nodeIDs: make(map[string]map[bitmap]bool),
nodeInfo: make(map[bitmap]peer),
}
}
func (s *peerStore) Insert(key string, node Node) {
func (s *peerStore) Upsert(key string, node Node) {
s.lock.Lock()
defer s.lock.Unlock()
newPeer := peer{node: node}
_, ok := s.data[key]
if !ok {
s.data[key] = []peer{newPeer}
} else {
s.data[key] = append(s.data[key], newPeer)
if _, ok := s.nodeIDs[key]; !ok {
s.nodeIDs[key] = make(map[bitmap]bool)
}
s.nodeIDs[key][node.id] = true
s.nodeInfo[node.id] = peer{node: node}
}
func (s *peerStore) Get(key string) []Node {
s.lock.RLock()
defer s.lock.RUnlock()
var nodes []Node
if peers, ok := s.data[key]; ok {
for _, p := range peers {
nodes = append(nodes, p.node)
if ids, ok := s.nodeIDs[key]; ok {
for id := range ids {
peer, ok := s.nodeInfo[id]
if !ok {
panic("node id in IDs list, but not in nodeInfo")
}
nodes = append(nodes, peer.node)
}
}
return nodes
}
func (s *peerStore) CountKnownNodes() int {
s.lock.RLock()
defer s.lock.RUnlock()
return len(s.nodeInfo)
}

101
dht/transaction_manager.go Normal file
View file

@ -0,0 +1,101 @@
package dht
import (
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// query represents the query data included queried node and query-formed data.
type transaction struct {
node *Node
req *Request
res chan *Response
}
// transactionManager represents the manager of transactions.
type transactionManager struct {
lock *sync.RWMutex
transactions map[string]*transaction
dht *DHT
}
// newTransactionManager returns new transactionManager pointer.
func newTransactionManager(dht *DHT) *transactionManager {
return &transactionManager{
lock: &sync.RWMutex{},
transactions: make(map[string]*transaction),
dht: dht,
}
}
// insert adds a transaction to transactionManager.
func (tm *transactionManager) insert(trans *transaction) {
tm.lock.Lock()
defer tm.lock.Unlock()
tm.transactions[trans.req.ID] = trans
}
// delete removes a transaction from transactionManager.
func (tm *transactionManager) delete(transID string) {
tm.lock.Lock()
defer tm.lock.Unlock()
delete(tm.transactions, transID)
}
// find transaction for id. optionally ensure that addr matches node from transaction
func (tm *transactionManager) Find(id string, addr *net.UDPAddr) *transaction {
tm.lock.RLock()
defer tm.lock.RUnlock()
t, ok := tm.transactions[id]
if !ok {
return nil
} else if addr != nil && t.node.Addr().String() != addr.String() {
return nil
}
return t
}
func (tm *transactionManager) Send(node *Node, req *Request) *Response {
if node.id.Equals(tm.dht.node.id) {
log.Error("sending query to self")
return nil
}
req.ID = newMessageID()
trans := &transaction{
node: node,
req: req,
res: make(chan *Response),
}
tm.insert(trans)
defer tm.delete(trans.req.ID)
for i := 0; i < udpRetry; i++ {
if err := send(tm.dht, trans.node.Addr(), *trans.req); err != nil {
log.Error(err)
break
}
select {
case res := <-trans.res:
return res
case <-time.After(udpTimeout):
}
}
tm.dht.rt.RemoveByID(trans.node.id)
return nil
}
// Count returns the number of transactions in the manager
func (tm *transactionManager) Count() int {
tm.lock.Lock()
defer tm.lock.Unlock()
return len(tm.transactions)
}