basic dht + tests
This commit is contained in:
parent
d989d42ad3
commit
a74f82e6b2
9 changed files with 971 additions and 0 deletions
85
dht/bitmap.go
Normal file
85
dht/bitmap.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type bitmap [nodeIDLength]byte
|
||||
|
||||
func (b bitmap) RawString() string {
|
||||
return string(b[0:nodeIDLength])
|
||||
}
|
||||
|
||||
func (b bitmap) Hex() string {
|
||||
return hex.EncodeToString(b[0:nodeIDLength])
|
||||
}
|
||||
|
||||
func (b bitmap) Equals(other bitmap) bool {
|
||||
for k := range b {
|
||||
if b[k] != other[k] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (b bitmap) Less(other interface{}) bool {
|
||||
for k := range b {
|
||||
if b[k] != other.(bitmap)[k] {
|
||||
return b[k] < other.(bitmap)[k]
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (b bitmap) Xor(other bitmap) bitmap {
|
||||
var ret bitmap
|
||||
for k := range b {
|
||||
ret[k] = b[k] ^ other[k]
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// PrefixLen returns the number of leading 0 bits
|
||||
func (b bitmap) PrefixLen() (ret int) {
|
||||
for i := range b {
|
||||
for j := 0; j < 8; j++ {
|
||||
if (b[i]>>uint8(7-j))&0x1 != 0 {
|
||||
return i*8 + j
|
||||
}
|
||||
}
|
||||
}
|
||||
return nodeIDLength*8 - 1
|
||||
}
|
||||
|
||||
func newBitmapFromBytes(data []byte) bitmap {
|
||||
if len(data) != nodeIDLength {
|
||||
panic("invalid bitmap of length " + strconv.Itoa(len(data)))
|
||||
}
|
||||
|
||||
var bmp bitmap
|
||||
copy(bmp[:], data)
|
||||
return bmp
|
||||
}
|
||||
|
||||
func newBitmapFromString(data string) bitmap {
|
||||
return newBitmapFromBytes([]byte(data))
|
||||
}
|
||||
|
||||
func newBitmapFromHex(hexStr string) bitmap {
|
||||
decoded, err := hex.DecodeString(hexStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return newBitmapFromBytes(decoded)
|
||||
}
|
||||
|
||||
func newRandomBitmap() bitmap {
|
||||
var id bitmap
|
||||
for k := range id {
|
||||
id[k] = uint8(rand.Intn(256))
|
||||
}
|
||||
return id
|
||||
}
|
48
dht/bitmap_test.go
Normal file
48
dht/bitmap_test.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package dht
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBitmap(t *testing.T) {
|
||||
a := bitmap{
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
|
||||
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
|
||||
}
|
||||
b := bitmap{
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
|
||||
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 46,
|
||||
}
|
||||
c := bitmap{
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
}
|
||||
|
||||
if !a.Equals(a) {
|
||||
t.Error("bitmap does not equal itself")
|
||||
}
|
||||
if a.Equals(b) {
|
||||
t.Error("bitmap equals another bitmap with different id")
|
||||
}
|
||||
|
||||
if !a.Xor(b).Equals(c) {
|
||||
t.Error(a.Xor(b))
|
||||
}
|
||||
|
||||
if c.PrefixLen() != 375 {
|
||||
t.Error(c.PrefixLen())
|
||||
}
|
||||
|
||||
if b.Less(a) {
|
||||
t.Error("bitmap fails lessThan test")
|
||||
}
|
||||
|
||||
id := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
if newBitmapFromHex(id).Hex() != id {
|
||||
t.Error(newBitmapFromHex(id).Hex())
|
||||
}
|
||||
}
|
60
dht/conn.go
Normal file
60
dht/conn.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UDPConn interface {
|
||||
ReadFromUDP([]byte) (int, *net.UDPAddr, error)
|
||||
WriteToUDP([]byte, *net.UDPAddr) (int, error)
|
||||
SetWriteDeadline(time.Time) error
|
||||
}
|
||||
|
||||
type testUDPPacket struct {
|
||||
data []byte
|
||||
addr *net.UDPAddr
|
||||
}
|
||||
|
||||
type testUDPConn struct {
|
||||
addr *net.UDPAddr
|
||||
toRead chan testUDPPacket
|
||||
writes chan testUDPPacket
|
||||
}
|
||||
|
||||
func newTestUDPConn(addr string) *testUDPConn {
|
||||
parts := strings.Split(addr, ":")
|
||||
if len(parts) != 2 {
|
||||
panic("addr needs ip and port")
|
||||
}
|
||||
port, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &testUDPConn{
|
||||
addr: &net.UDPAddr{IP: net.IP(parts[0]), Port: port},
|
||||
toRead: make(chan testUDPPacket),
|
||||
writes: make(chan testUDPPacket),
|
||||
}
|
||||
}
|
||||
|
||||
func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
|
||||
select {
|
||||
case packet := <-t.toRead:
|
||||
n := copy(b, packet.data)
|
||||
return n, packet.addr, nil
|
||||
//default:
|
||||
// return 0, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
|
||||
t.writes <- testUDPPacket{data: b, addr: addr}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (t testUDPConn) SetWriteDeadline(tm time.Time) error {
|
||||
return nil
|
||||
}
|
350
dht/dht.go
Normal file
350
dht/dht.go
Normal file
|
@ -0,0 +1,350 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cast"
|
||||
"github.com/zeebo/bencode"
|
||||
)
|
||||
|
||||
const network = "udp4"
|
||||
const bucketSize = 20
|
||||
const numBuckets = nodeIDLength * 8
|
||||
|
||||
// packet represents the information receive from udp.
|
||||
type packet struct {
|
||||
data []byte
|
||||
raddr *net.UDPAddr
|
||||
}
|
||||
|
||||
// Config represents the configure of dht.
|
||||
type Config struct {
|
||||
// this node's address. format is `ip:port`
|
||||
Address string
|
||||
// the seed nodes through which we can join in dht network
|
||||
SeedNodes []string
|
||||
// the hex-encoded node id for this node. if string is empty, a random id will be generated
|
||||
NodeID string
|
||||
}
|
||||
|
||||
// NewStandardConfig returns a Config pointer with default values.
|
||||
func NewStandardConfig() *Config {
|
||||
return &Config{
|
||||
Address: ":4444",
|
||||
SeedNodes: []string{
|
||||
"lbrynet1.lbry.io:4444",
|
||||
"lbrynet2.lbry.io:4444",
|
||||
"lbrynet3.lbry.io:4444",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// DHT represents a DHT node.
|
||||
type DHT struct {
|
||||
conf *Config
|
||||
conn UDPConn
|
||||
node *Node
|
||||
routingTable *RoutingTable
|
||||
packets chan packet
|
||||
}
|
||||
|
||||
// New returns a DHT pointer. If config is nil, then config will be set to the default config.
|
||||
func New(config *Config) *DHT {
|
||||
if config == nil {
|
||||
config = NewStandardConfig()
|
||||
}
|
||||
|
||||
var id bitmap
|
||||
if config.NodeID == "" {
|
||||
id = newRandomBitmap()
|
||||
} else {
|
||||
id = newBitmapFromHex(config.NodeID)
|
||||
}
|
||||
node := &Node{id: id, addr: config.Address}
|
||||
return &DHT{
|
||||
conf: config,
|
||||
node: node,
|
||||
routingTable: NewRoutingTable(node),
|
||||
packets: make(chan packet),
|
||||
}
|
||||
}
|
||||
|
||||
// init initializes global variables.
|
||||
func (dht *DHT) init() {
|
||||
log.Info("Initializing DHT on " + dht.conf.Address)
|
||||
log.Infof("Node ID is %s", dht.node.id.Hex())
|
||||
listener, err := net.ListenPacket(network, dht.conf.Address)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
dht.conn = listener.(*net.UDPConn)
|
||||
}
|
||||
|
||||
// listen receives message from udp.
|
||||
func (dht *DHT) listen() {
|
||||
go func() {
|
||||
buf := make([]byte, 8192)
|
||||
for {
|
||||
n, raddr, err := dht.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
log.Errorf("udp read error: %v", err)
|
||||
continue
|
||||
} else if raddr == nil {
|
||||
log.Errorf("udp read with no raddr")
|
||||
continue
|
||||
}
|
||||
dht.packets <- packet{data: buf[:n], raddr: raddr}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// join makes current node join the dht network.
|
||||
func (dht *DHT) join() {
|
||||
for _, addr := range dht.conf.SeedNodes {
|
||||
raddr, err := net.ResolveUDPAddr(network, addr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
_ = raddr
|
||||
|
||||
// NOTE: Temporary node has NO node id.
|
||||
//dht.transactionManager.findNode(
|
||||
// &node{addr: raddr},
|
||||
// dht.node.id.RawString(),
|
||||
//)
|
||||
}
|
||||
}
|
||||
|
||||
func (dht *DHT) runHandler() {
|
||||
var pkt packet
|
||||
|
||||
for {
|
||||
select {
|
||||
case pkt = <-dht.packets:
|
||||
handle(dht, pkt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the dht.
|
||||
func (dht *DHT) Run() {
|
||||
dht.init()
|
||||
dht.listen()
|
||||
dht.join()
|
||||
log.Info("DHT ready")
|
||||
dht.runHandler()
|
||||
}
|
||||
|
||||
// handle handles packets received from udp.
|
||||
func handle(dht *DHT, pkt packet) {
|
||||
//log.Infof("Received message from %s:%s : %s\n", pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), hex.EncodeToString(pkt.data))
|
||||
|
||||
var data map[string]interface{}
|
||||
err := bencode.DecodeBytes(pkt.data, &data)
|
||||
if err != nil {
|
||||
log.Errorf("Error decoding data: %s\n%s", err, pkt.data)
|
||||
return
|
||||
}
|
||||
|
||||
msgType, ok := data[headerTypeField]
|
||||
if !ok {
|
||||
log.Errorf("Decoded data has no message type: %s", data)
|
||||
return
|
||||
}
|
||||
|
||||
switch msgType.(int64) {
|
||||
case requestType:
|
||||
request := Request{
|
||||
ID: data[headerMessageIDField].(string),
|
||||
NodeID: data[headerNodeIDField].(string),
|
||||
Method: data[headerPayloadField].(string),
|
||||
Args: getArgs(data[headerArgsField]),
|
||||
}
|
||||
log.Infof("%s: Received from %s: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args))
|
||||
handleRequest(dht, pkt.raddr, request)
|
||||
|
||||
case responseType:
|
||||
response := Response{
|
||||
ID: data[headerMessageIDField].(string),
|
||||
NodeID: data[headerNodeIDField].(string),
|
||||
}
|
||||
|
||||
if reflect.TypeOf(data[headerPayloadField]).Kind() == reflect.String {
|
||||
response.Data = data[headerPayloadField].(string)
|
||||
} else {
|
||||
response.FindNodeData = getFindNodeResponse(data[headerPayloadField])
|
||||
}
|
||||
|
||||
handleResponse(dht, pkt.raddr, response)
|
||||
|
||||
case errorType:
|
||||
e := Error{
|
||||
ID: data[headerMessageIDField].(string),
|
||||
NodeID: data[headerNodeIDField].(string),
|
||||
ExceptionType: data[headerPayloadField].(string),
|
||||
Response: getArgs(data[headerArgsField]),
|
||||
}
|
||||
handleError(dht, pkt.raddr, e)
|
||||
|
||||
default:
|
||||
log.Errorf("Invalid message type: %s", msgType)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest handles the requests received from udp.
|
||||
func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool) {
|
||||
log.Infoln("handling request")
|
||||
if request.NodeID == dht.node.id.RawString() {
|
||||
log.Warn("ignoring self-request")
|
||||
return
|
||||
}
|
||||
|
||||
switch request.Method {
|
||||
case pingMethod:
|
||||
log.Println("ping")
|
||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: "pong"})
|
||||
case storeMethod:
|
||||
log.Println("store")
|
||||
case findNodeMethod:
|
||||
log.Println("findnode")
|
||||
//if len(request.Args) < 1 {
|
||||
// send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"No target"}})
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//target := request.Args[0]
|
||||
//if len(target) != nodeIDLength {
|
||||
// send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"Invalid target"}})
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//nodes := []findNodeDatum{}
|
||||
//targetID := newBitmapFromString(target)
|
||||
//
|
||||
//no, _ := dht.routingTable.GetNodeKBucktByID(targetID)
|
||||
//if no != nil {
|
||||
// nodes = []findNodeDatum{{ID: no.id.RawString(), IP: no.addr.IP.String(), Port: no.addr.Port}}
|
||||
//} else {
|
||||
// neighbors := dht.routingTable.GetNeighbors(targetID, dht.K)
|
||||
// for _, n := range neighbors {
|
||||
// nodes = append(nodes, findNodeDatum{ID: n.id.RawString(), IP: n.addr.IP.String(), Port: n.addr.Port})
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: nodes})
|
||||
|
||||
default:
|
||||
// send(dht, addr, makeError(t, protocolError, "invalid q"))
|
||||
return
|
||||
}
|
||||
|
||||
node := &Node{id: newBitmapFromString(request.NodeID), addr: addr.String()}
|
||||
dht.routingTable.Update(node)
|
||||
return true
|
||||
}
|
||||
|
||||
// handleResponse handles responses received from udp.
|
||||
func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) (success bool) {
|
||||
spew.Dump(response)
|
||||
|
||||
//switch trans.request.Method {
|
||||
//case pingMethod:
|
||||
//case findNodeMethod:
|
||||
// target := trans.request.Args[0]
|
||||
// if findOn(dht, response.FindNodeData, newBitmapFromString(target), findNodeMethod) != nil {
|
||||
// return
|
||||
// }
|
||||
//default:
|
||||
// return
|
||||
//}
|
||||
|
||||
node := &Node{id: newBitmapFromString(response.NodeID), addr: addr.String()}
|
||||
dht.routingTable.Update(node)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// handleError handles errors received from udp.
|
||||
func handleError(dht *DHT, addr *net.UDPAddr, e Error) (success bool) {
|
||||
spew.Dump(e)
|
||||
return true
|
||||
}
|
||||
|
||||
// send sends data to the udp.
|
||||
func send(dht *DHT, addr *net.UDPAddr, data Message) error {
|
||||
if req, ok := data.(Request); ok {
|
||||
log.Infof("%s: Sending %s(%s)", hex.EncodeToString([]byte(req.NodeID))[:8], req.Method, argsToString(req.Args))
|
||||
} else {
|
||||
log.Infof("%s: Sending %s", data.GetID(), spew.Sdump(data))
|
||||
}
|
||||
encoded, err := data.Encode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//log.Infof("Encoded: %s", string(encoded))
|
||||
|
||||
dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))
|
||||
|
||||
_, err = dht.conn.WriteToUDP(encoded, addr)
|
||||
return err
|
||||
}
|
||||
|
||||
func getFindNodeResponse(i interface{}) (data []findNodeDatum) {
|
||||
if reflect.TypeOf(i).Kind() != reflect.Slice {
|
||||
return
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(i)
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
if v.Index(i).Kind() != reflect.Interface {
|
||||
continue
|
||||
}
|
||||
|
||||
contact := v.Index(i).Elem()
|
||||
if contact.Type().Kind() != reflect.Slice || contact.Len() != 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
if contact.Index(0).Elem().Kind() != reflect.String ||
|
||||
contact.Index(1).Elem().Kind() != reflect.String ||
|
||||
!(contact.Index(2).Elem().Kind() == reflect.Int64 ||
|
||||
contact.Index(2).Elem().Kind() == reflect.Int) {
|
||||
continue
|
||||
}
|
||||
|
||||
data = append(data, findNodeDatum{
|
||||
ID: contact.Index(0).Elem().String(),
|
||||
IP: contact.Index(1).Elem().String(),
|
||||
Port: int(contact.Index(2).Elem().Int()),
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getArgs(argsInt interface{}) (args []string) {
|
||||
if reflect.TypeOf(argsInt).Kind() == reflect.Slice {
|
||||
v := reflect.ValueOf(argsInt)
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
args = append(args, cast.ToString(v.Index(i).Interface()))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func argsToString(args []string) string {
|
||||
for k, v := range args {
|
||||
if len(v) == nodeIDLength {
|
||||
args[k] = hex.EncodeToString([]byte(v))[:8]
|
||||
}
|
||||
}
|
||||
return strings.Join(args, ", ")
|
||||
}
|
193
dht/dht_test.go
Normal file
193
dht/dht_test.go
Normal file
|
@ -0,0 +1,193 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/zeebo/bencode"
|
||||
)
|
||||
|
||||
func TestPing(t *testing.T) {
|
||||
dhtNodeID := newRandomBitmap()
|
||||
testNodeID := newRandomBitmap()
|
||||
|
||||
conn := newTestUDPConn("127.0.0.1:21217")
|
||||
|
||||
dht := New(&Config{Address: ":21216", NodeID: dhtNodeID.Hex()})
|
||||
dht.conn = conn
|
||||
dht.listen()
|
||||
go dht.runHandler()
|
||||
|
||||
messageID := newRandomBitmap().RawString()
|
||||
|
||||
data, err := bencode.EncodeBytes(map[string]interface{}{
|
||||
headerTypeField: requestType,
|
||||
headerMessageIDField: messageID,
|
||||
headerNodeIDField: testNodeID.RawString(),
|
||||
headerPayloadField: "ping",
|
||||
headerArgsField: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
conn.toRead <- testUDPPacket{addr: conn.addr, data: data}
|
||||
timer := time.NewTimer(3 * time.Second)
|
||||
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.Error("timeout")
|
||||
case resp := <-conn.writes:
|
||||
var response map[string]interface{}
|
||||
err := bencode.DecodeBytes(resp.data, &response)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(response) != 4 {
|
||||
t.Errorf("expected 4 response fields, got %d", len(response))
|
||||
}
|
||||
|
||||
_, ok := response[headerTypeField]
|
||||
if !ok {
|
||||
t.Error("missing type field")
|
||||
} else {
|
||||
rType, ok := response[headerTypeField].(int64)
|
||||
if !ok {
|
||||
t.Error("type is not an integer")
|
||||
} else if rType != responseType {
|
||||
t.Error("unexpected response type")
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = response[headerMessageIDField]
|
||||
if !ok {
|
||||
t.Error("missing message id field")
|
||||
} else {
|
||||
rMessageID, ok := response[headerMessageIDField].(string)
|
||||
if !ok {
|
||||
t.Error("message ID is not a string")
|
||||
} else if rMessageID != messageID {
|
||||
t.Error("unexpected message ID")
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = response[headerNodeIDField]
|
||||
if !ok {
|
||||
t.Error("missing node id field")
|
||||
} else {
|
||||
rNodeID, ok := response[headerNodeIDField].(string)
|
||||
if !ok {
|
||||
t.Error("node ID is not a string")
|
||||
} else if rNodeID != dhtNodeID.RawString() {
|
||||
t.Error("unexpected node ID")
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = response[headerPayloadField]
|
||||
if !ok {
|
||||
t.Error("missing payload field")
|
||||
} else {
|
||||
rNodeID, ok := response[headerPayloadField].(string)
|
||||
if !ok {
|
||||
t.Error("payload is not a string")
|
||||
} else if rNodeID != pingSuccessResponse {
|
||||
t.Error("did not pong")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
dhtNodeID := newRandomBitmap()
|
||||
testNodeID := newRandomBitmap()
|
||||
|
||||
conn := newTestUDPConn("127.0.0.1:21217")
|
||||
|
||||
dht := New(&Config{Address: ":21216", NodeID: dhtNodeID.Hex()})
|
||||
dht.conn = conn
|
||||
dht.listen()
|
||||
go dht.runHandler()
|
||||
|
||||
messageID := newRandomBitmap().RawString()
|
||||
idToStore := newRandomBitmap().RawString()
|
||||
|
||||
data, err := bencode.EncodeBytes(map[string]interface{}{
|
||||
headerTypeField: requestType,
|
||||
headerMessageIDField: messageID,
|
||||
headerNodeIDField: testNodeID.RawString(),
|
||||
headerPayloadField: "store",
|
||||
headerArgsField: []string{idToStore},
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
conn.toRead <- testUDPPacket{addr: conn.addr, data: data}
|
||||
timer := time.NewTimer(3 * time.Second)
|
||||
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.Error("timeout")
|
||||
case resp := <-conn.writes:
|
||||
var response map[string]interface{}
|
||||
err := bencode.DecodeBytes(resp.data, &response)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(response) != 4 {
|
||||
t.Errorf("expected 4 response fields, got %d", len(response))
|
||||
}
|
||||
|
||||
_, ok := response[headerTypeField]
|
||||
if !ok {
|
||||
t.Error("missing type field")
|
||||
} else {
|
||||
rType, ok := response[headerTypeField].(int64)
|
||||
if !ok {
|
||||
t.Error("type is not an integer")
|
||||
} else if rType != responseType {
|
||||
t.Error("unexpected response type")
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = response[headerMessageIDField]
|
||||
if !ok {
|
||||
t.Error("missing message id field")
|
||||
} else {
|
||||
rMessageID, ok := response[headerMessageIDField].(string)
|
||||
if !ok {
|
||||
t.Error("message ID is not a string")
|
||||
} else if rMessageID != messageID {
|
||||
t.Error("unexpected message ID")
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = response[headerNodeIDField]
|
||||
if !ok {
|
||||
t.Error("missing node id field")
|
||||
} else {
|
||||
rNodeID, ok := response[headerNodeIDField].(string)
|
||||
if !ok {
|
||||
t.Error("node ID is not a string")
|
||||
} else if rNodeID != dhtNodeID.RawString() {
|
||||
t.Error("unexpected node ID")
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = response[headerPayloadField]
|
||||
if !ok {
|
||||
t.Error("missing payload field")
|
||||
} else {
|
||||
rNodeID, ok := response[headerPayloadField].(string)
|
||||
if !ok {
|
||||
t.Error("payload is not a string")
|
||||
} else if rNodeID != storeSuccessResponse {
|
||||
t.Error("did not return OK")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
103
dht/messages.go
Normal file
103
dht/messages.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package dht
|
||||
|
||||
import "github.com/zeebo/bencode"
|
||||
|
||||
const (
|
||||
pingMethod = "ping"
|
||||
storeMethod = "store"
|
||||
findNodeMethod = "findNode"
|
||||
findValueMethod = "findValue"
|
||||
)
|
||||
|
||||
const (
|
||||
pingSuccessResponse = "pong"
|
||||
storeSuccessResponse = "OK"
|
||||
)
|
||||
|
||||
const (
|
||||
requestType = 0
|
||||
responseType = 1
|
||||
errorType = 2
|
||||
)
|
||||
|
||||
const (
|
||||
// these are strings because bencode requires bytestring keys
|
||||
headerTypeField = "0"
|
||||
headerMessageIDField = "1"
|
||||
headerNodeIDField = "2"
|
||||
headerPayloadField = "3"
|
||||
headerArgsField = "4"
|
||||
)
|
||||
|
||||
type Message interface {
|
||||
GetID() string
|
||||
Encode() ([]byte, error)
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
ID string
|
||||
NodeID string
|
||||
Method string
|
||||
Args []string
|
||||
}
|
||||
|
||||
func (r Request) GetID() string { return r.ID }
|
||||
func (r Request) Encode() ([]byte, error) {
|
||||
return bencode.EncodeBytes(map[string]interface{}{
|
||||
headerTypeField: requestType,
|
||||
headerMessageIDField: r.ID,
|
||||
headerNodeIDField: r.NodeID,
|
||||
headerPayloadField: r.Method,
|
||||
headerArgsField: r.Args,
|
||||
})
|
||||
}
|
||||
|
||||
type findNodeDatum struct {
|
||||
ID string
|
||||
IP string
|
||||
Port int
|
||||
}
|
||||
type Response struct {
|
||||
ID string
|
||||
NodeID string
|
||||
Data string
|
||||
FindNodeData []findNodeDatum
|
||||
}
|
||||
|
||||
func (r Response) GetID() string { return r.ID }
|
||||
func (r Response) Encode() ([]byte, error) {
|
||||
data := map[string]interface{}{
|
||||
headerTypeField: responseType,
|
||||
headerMessageIDField: r.ID,
|
||||
headerNodeIDField: r.NodeID,
|
||||
}
|
||||
if r.Data != "" {
|
||||
data[headerPayloadField] = r.Data
|
||||
} else {
|
||||
var nodes []interface{}
|
||||
for _, n := range r.FindNodeData {
|
||||
nodes = append(nodes, []interface{}{n.ID, n.IP, n.Port})
|
||||
}
|
||||
data[headerPayloadField] = nodes
|
||||
}
|
||||
|
||||
return bencode.EncodeBytes(data)
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
ID string
|
||||
NodeID string
|
||||
Response []string
|
||||
ExceptionType string
|
||||
}
|
||||
|
||||
func (e Error) GetID() string { return e.ID }
|
||||
func (e Error) Encode() ([]byte, error) {
|
||||
return bencode.EncodeBytes(map[string]interface{}{
|
||||
headerTypeField: errorType,
|
||||
headerMessageIDField: e.ID,
|
||||
headerNodeIDField: e.NodeID,
|
||||
headerPayloadField: e.ExceptionType,
|
||||
headerArgsField: e.Response,
|
||||
})
|
||||
}
|
20
dht/node.go
Normal file
20
dht/node.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package dht
|
||||
|
||||
const nodeIDLength = 48 // bytes
|
||||
const compactNodeInfoLength = nodeIDLength + 6
|
||||
|
||||
type Node struct {
|
||||
id bitmap
|
||||
addr string
|
||||
}
|
||||
|
||||
type SortedNode struct {
|
||||
node *Node
|
||||
sortKey bitmap
|
||||
}
|
||||
|
||||
type byXorDistance []*SortedNode
|
||||
|
||||
func (a byXorDistance) Len() int { return len(a) }
|
||||
func (a byXorDistance) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a byXorDistance) Less(i, j int) bool { return a[i].sortKey.Less(a[j].sortKey) }
|
79
dht/routing_table.go
Normal file
79
dht/routing_table.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sort"
|
||||
)
|
||||
|
||||
type RoutingTable struct {
|
||||
node Node
|
||||
buckets [numBuckets]*list.List
|
||||
}
|
||||
|
||||
func NewRoutingTable(node *Node) *RoutingTable {
|
||||
var rt RoutingTable
|
||||
for i := range rt.buckets {
|
||||
rt.buckets[i] = list.New()
|
||||
}
|
||||
rt.node = *node
|
||||
return &rt
|
||||
}
|
||||
|
||||
func (rt *RoutingTable) Update(node *Node) {
|
||||
prefixLength := node.id.Xor(rt.node.id).PrefixLen()
|
||||
bucket := rt.buckets[prefixLength]
|
||||
element := findInList(bucket, rt.node.id)
|
||||
if element == nil {
|
||||
if bucket.Len() <= bucketSize {
|
||||
bucket.PushBack(node)
|
||||
}
|
||||
// TODO: Handle insertion when the list is full by evicting old elements if
|
||||
// they don't respond to a ping.
|
||||
} else {
|
||||
bucket.MoveToBack(element)
|
||||
}
|
||||
}
|
||||
|
||||
func (rt *RoutingTable) FindClosest(target bitmap, count int) []*Node {
|
||||
toSort := []*SortedNode{}
|
||||
|
||||
prefixLength := target.Xor(rt.node.id).PrefixLen()
|
||||
bucket := rt.buckets[prefixLength]
|
||||
appendNodes(bucket.Front(), nil, &toSort, target)
|
||||
|
||||
for i := 1; (prefixLength-i >= 0 || prefixLength+i < nodeIDLength*8) && len(toSort) < count; i++ {
|
||||
if prefixLength-i >= 0 {
|
||||
bucket = rt.buckets[prefixLength-i]
|
||||
appendNodes(bucket.Front(), nil, &toSort, target)
|
||||
}
|
||||
if prefixLength+i < nodeIDLength*8 {
|
||||
bucket = rt.buckets[prefixLength+i]
|
||||
appendNodes(bucket.Front(), nil, &toSort, target)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Sort(byXorDistance(toSort))
|
||||
|
||||
nodes := []*Node{}
|
||||
for _, c := range toSort {
|
||||
nodes = append(nodes, c.node)
|
||||
}
|
||||
|
||||
return nodes
|
||||
}
|
||||
|
||||
func findInList(bucket *list.List, value bitmap) *list.Element {
|
||||
for curr := bucket.Front(); curr != nil; curr = curr.Next() {
|
||||
if curr.Value.(*Node).id.Equals(value) {
|
||||
return curr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func appendNodes(start, end *list.Element, nodes *[]*SortedNode, target bitmap) {
|
||||
for curr := start; curr != end; curr = curr.Next() {
|
||||
node := curr.Value.(*Node)
|
||||
*nodes = append(*nodes, &SortedNode{node, node.id.Xor(target)})
|
||||
}
|
||||
}
|
33
dht/routing_table_test.go
Normal file
33
dht/routing_table_test.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package dht
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRoutingTable(t *testing.T) {
|
||||
n1 := newBitmapFromHex("FFFFFFFF0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
n2 := newBitmapFromHex("FFFFFFF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
n3 := newBitmapFromHex("111111110000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
rt := NewRoutingTable(&Node{n1, "localhost:8000"})
|
||||
rt.Update(&Node{n2, "localhost:8001"})
|
||||
rt.Update(&Node{n3, "localhost:8002"})
|
||||
|
||||
contacts := rt.FindClosest(newBitmapFromHex("222222220000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 1)
|
||||
if len(contacts) != 1 {
|
||||
t.Fail()
|
||||
return
|
||||
}
|
||||
if !contacts[0].id.Equals(n3) {
|
||||
t.Error(contacts[0])
|
||||
}
|
||||
|
||||
contacts = rt.FindClosest(n2, 10)
|
||||
if len(contacts) != 2 {
|
||||
t.Error(len(contacts))
|
||||
return
|
||||
}
|
||||
if !contacts[0].id.Equals(n2) {
|
||||
t.Error(contacts[0])
|
||||
}
|
||||
if !contacts[1].id.Equals(n3) {
|
||||
t.Error(contacts[1])
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue